From c22e9e4459ff3c371a3c38a98d394edfcd95c369 Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 16 Mar 2020 17:56:49 +0000 Subject: [PATCH 001/623] [CI] Update Dockerfile and Docker image --- utils/docker/Dockerfile.base | 14 +++++++++++++ utils/docker/Dockerfile.build | 8 ++++++++ utils/docker/build.sh | 38 +++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+) create mode 100644 utils/docker/Dockerfile.base create mode 100644 utils/docker/Dockerfile.build create mode 100755 utils/docker/build.sh diff --git a/utils/docker/Dockerfile.base b/utils/docker/Dockerfile.base new file mode 100644 index 00000000..d1abe185 --- /dev/null +++ b/utils/docker/Dockerfile.base @@ -0,0 +1,14 @@ +FROM ubuntu:18.04 + +MAINTAINER hydai hydai@secondstate.io + +RUN apt update && apt install -y \ + cmake \ + curl \ + g++ \ + libboost-all-dev + +RUN curl -sL https://deb.nodesource.com/setup_12.x | bash \ + && apt install -y nodejs + +RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/Dockerfile.build b/utils/docker/Dockerfile.build new file mode 100644 index 00000000..3b5e1d1b --- /dev/null +++ b/utils/docker/Dockerfile.build @@ -0,0 +1,8 @@ +ARG BASE=secondstate/ssvm:ubuntu-base +FROM ${BASE} + +RUN apt update && apt install -y \ + software-properties-common \ + llvm-8-dev + +RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/build.sh b/utils/docker/build.sh new file mode 100755 index 00000000..d5e63bfd --- /dev/null +++ b/utils/docker/build.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +NAME=${1:+$1/}ssvm +INTERMEDIATES=() +IMAGES=() + +set -e + +function docker_build +{ + local FILENAME=$1; shift + local TAG=$1; shift + local NAME_TAG=${NAME}:${TAG} + echo "Building docker image \"${NAME_TAG}\" from file \"${FILENAME}\"." + + ( set -x; docker build "$@" -f "${FILENAME}" -t "${NAME_TAG}" . ) + + if [[ "${TAG}" == im-* ]]; then + INTERMEDIATES+=( "${NAME_TAG}" ) + else + IMAGES+=( "${NAME_TAG}" ) + fi +} + +# Build all images. +docker_build Dockerfile.base ubuntu-base +docker_build Dockerfile.build ubuntu-build \ + --build-arg "BASE=${NAME}:ubuntu-base" + +# Remove intermediate images. +for NAME_TAG in "${INTERMEDIATES[@]}"; do + ( set -x; docker rmi "${NAME_TAG}" ) +done + +# Push all images. +for NAME_TAG in "${IMAGES[@]}"; do + ( set -x; docker push "${NAME_TAG}" ) +done From 2209285c282ff25c6b6d72c1cae413b96ad49535 Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 19 Mar 2020 10:34:30 +0000 Subject: [PATCH 002/623] [CI] Update llvm dependency from 8 to 9. And enable -fPIC flag --- utils/docker/Dockerfile.base | 1 - utils/docker/Dockerfile.build | 8 -------- utils/docker/Dockerfile.build-clang | 24 ++++++++++++++++++++++++ utils/docker/Dockerfile.build-gcc | 17 +++++++++++++++++ utils/docker/build.sh | 8 ++++++-- 5 files changed, 47 insertions(+), 11 deletions(-) delete mode 100644 utils/docker/Dockerfile.build create mode 100644 utils/docker/Dockerfile.build-clang create mode 100644 utils/docker/Dockerfile.build-gcc diff --git a/utils/docker/Dockerfile.base b/utils/docker/Dockerfile.base index d1abe185..98948809 100644 --- a/utils/docker/Dockerfile.base +++ b/utils/docker/Dockerfile.base @@ -5,7 +5,6 @@ MAINTAINER hydai hydai@secondstate.io RUN apt update && apt install -y \ cmake \ curl \ - g++ \ libboost-all-dev RUN curl -sL https://deb.nodesource.com/setup_12.x | bash \ diff --git a/utils/docker/Dockerfile.build b/utils/docker/Dockerfile.build deleted file mode 100644 index 3b5e1d1b..00000000 --- a/utils/docker/Dockerfile.build +++ /dev/null @@ -1,8 +0,0 @@ -ARG BASE=secondstate/ssvm:ubuntu-base -FROM ${BASE} - -RUN apt update && apt install -y \ - software-properties-common \ - llvm-8-dev - -RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/Dockerfile.build-clang b/utils/docker/Dockerfile.build-clang new file mode 100644 index 00000000..a7b54cb4 --- /dev/null +++ b/utils/docker/Dockerfile.build-clang @@ -0,0 +1,24 @@ +ARG BASE=secondstate/ssvm:ubuntu-base +FROM ${BASE} + +RUN apt update && apt install -y \ + software-properties-common \ + wget + +RUN wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|apt-key add - + +RUN apt update && apt install -y \ + libllvm9 \ + llvm-9 \ + llvm-9-dev \ + llvm-9-runtime \ + libclang-common-9-dev # for yaml-bench + +RUN apt install -y \ + clang-9 \ + clang-tools-9 + +RUN rm -rf /var/lib/apt/lists/* + +ENV CC=/usr/bin/clang-9 +ENV CXX=/usr/bin/clang++-9 diff --git a/utils/docker/Dockerfile.build-gcc b/utils/docker/Dockerfile.build-gcc new file mode 100644 index 00000000..0fca0847 --- /dev/null +++ b/utils/docker/Dockerfile.build-gcc @@ -0,0 +1,17 @@ +ARG BASE=secondstate/ssvm:ubuntu-base +FROM ${BASE} + +RUN apt update && apt install -y \ + software-properties-common \ + wget + +RUN wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|apt-key add - + +RUN apt update && apt install -y \ + libllvm9 \ + llvm-9 \ + llvm-9-dev \ + llvm-9-runtime \ + libclang-common-9-dev # for yaml-bench + +RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/build.sh b/utils/docker/build.sh index d5e63bfd..3102865b 100755 --- a/utils/docker/build.sh +++ b/utils/docker/build.sh @@ -23,8 +23,12 @@ function docker_build } # Build all images. -docker_build Dockerfile.base ubuntu-base -docker_build Dockerfile.build ubuntu-build \ +docker_build Dockerfile.base ubuntu-base +docker_build Dockerfile.build-clang ubuntu-build-clang \ + --build-arg "BASE=${NAME}:ubuntu-base" +docker_build Dockerfile.build-clang latest \ + --build-arg "BASE=${NAME}:ubuntu-base" +docker_build Dockerfile.build-gcc ubuntu-build-gcc \ --build-arg "BASE=${NAME}:ubuntu-base" # Remove intermediate images. From 787f6d43315c7c528af2b6827affc654415b58d3 Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 20 Mar 2020 09:30:32 +0000 Subject: [PATCH 003/623] [CI] Update gcc from 7 to 8 --- utils/docker/Dockerfile.build-gcc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/utils/docker/Dockerfile.build-gcc b/utils/docker/Dockerfile.build-gcc index 0fca0847..5dc74d1a 100644 --- a/utils/docker/Dockerfile.build-gcc +++ b/utils/docker/Dockerfile.build-gcc @@ -14,4 +14,11 @@ RUN apt update && apt install -y \ llvm-9-runtime \ libclang-common-9-dev # for yaml-bench +RUN apt update && apt install -y \ + gcc-8 \ + g++-8 + RUN rm -rf /var/lib/apt/lists/* + +ENV CC=gcc-8 +ENV CXX=g++-8 From abe30d43e0c1aaa360f250af6ba59f43e356daab Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 13 May 2020 09:53:31 +0000 Subject: [PATCH 004/623] [CI] Add lld-9 dependency --- utils/docker/Dockerfile.build-clang | 1 + utils/docker/Dockerfile.build-gcc | 1 + 2 files changed, 2 insertions(+) diff --git a/utils/docker/Dockerfile.build-clang b/utils/docker/Dockerfile.build-clang index a7b54cb4..f16a04c0 100644 --- a/utils/docker/Dockerfile.build-clang +++ b/utils/docker/Dockerfile.build-clang @@ -11,6 +11,7 @@ RUN apt update && apt install -y \ libllvm9 \ llvm-9 \ llvm-9-dev \ + liblld-9-dev \ llvm-9-runtime \ libclang-common-9-dev # for yaml-bench diff --git a/utils/docker/Dockerfile.build-gcc b/utils/docker/Dockerfile.build-gcc index 5dc74d1a..1a4d960c 100644 --- a/utils/docker/Dockerfile.build-gcc +++ b/utils/docker/Dockerfile.build-gcc @@ -11,6 +11,7 @@ RUN apt update && apt install -y \ libllvm9 \ llvm-9 \ llvm-9-dev \ + liblld-9-dev \ llvm-9-runtime \ libclang-common-9-dev # for yaml-bench From a27b725bd7a380ed47b139b87555f6a625dcb9af Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 15 May 2020 10:19:55 +0000 Subject: [PATCH 005/623] [CI] Update Cmake from 3.10 to 3.17 --- utils/docker/Dockerfile.base | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/utils/docker/Dockerfile.base b/utils/docker/Dockerfile.base index 98948809..b4179b3f 100644 --- a/utils/docker/Dockerfile.base +++ b/utils/docker/Dockerfile.base @@ -2,9 +2,22 @@ FROM ubuntu:18.04 MAINTAINER hydai hydai@secondstate.io +RUN apt update && apt install -y \ + apt-transport-https \ + ca-certificates \ + gnupg \ + software-properties-common \ + wget + +RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | \ + gpg --dearmor - | \ + tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null + +RUN apt-add-repository 'deb https://apt.kitware.com/ubuntu/ bionic main' -y RUN apt update && apt install -y \ cmake \ curl \ + git \ libboost-all-dev RUN curl -sL https://deb.nodesource.com/setup_12.x | bash \ From a9a6e82eb551cf2043bd2cbafea723d95871763c Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 13 Jul 2020 12:42:49 +0000 Subject: [PATCH 006/623] [CI] Update base image to Ubuntu 20.04 --- utils/docker/Dockerfile.base | 20 ++++++-------------- utils/docker/Dockerfile.build-clang | 21 +++++---------------- utils/docker/Dockerfile.build-gcc | 22 ++++++---------------- 3 files changed, 17 insertions(+), 46 deletions(-) diff --git a/utils/docker/Dockerfile.base b/utils/docker/Dockerfile.base index b4179b3f..ab323566 100644 --- a/utils/docker/Dockerfile.base +++ b/utils/docker/Dockerfile.base @@ -1,26 +1,18 @@ -FROM ubuntu:18.04 +FROM ubuntu:20.04 MAINTAINER hydai hydai@secondstate.io +ENV DEBIAN_FRONTEND=noninteractive -RUN apt update && apt install -y \ - apt-transport-https \ - ca-certificates \ - gnupg \ +RUN apt update && apt upgrade -y \ + && apt install -y \ software-properties-common \ - wget - -RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | \ - gpg --dearmor - | \ - tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null - -RUN apt-add-repository 'deb https://apt.kitware.com/ubuntu/ bionic main' -y -RUN apt update && apt install -y \ + wget \ cmake \ curl \ git \ libboost-all-dev -RUN curl -sL https://deb.nodesource.com/setup_12.x | bash \ +RUN curl -sL https://deb.nodesource.com/setup_14.x | bash \ && apt install -y nodejs RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/Dockerfile.build-clang b/utils/docker/Dockerfile.build-clang index f16a04c0..23355d88 100644 --- a/utils/docker/Dockerfile.build-clang +++ b/utils/docker/Dockerfile.build-clang @@ -2,24 +2,13 @@ ARG BASE=secondstate/ssvm:ubuntu-base FROM ${BASE} RUN apt update && apt install -y \ - software-properties-common \ - wget - -RUN wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|apt-key add - + llvm-dev \ + liblld-10-dev RUN apt update && apt install -y \ - libllvm9 \ - llvm-9 \ - llvm-9-dev \ - liblld-9-dev \ - llvm-9-runtime \ - libclang-common-9-dev # for yaml-bench - -RUN apt install -y \ - clang-9 \ - clang-tools-9 + clang RUN rm -rf /var/lib/apt/lists/* -ENV CC=/usr/bin/clang-9 -ENV CXX=/usr/bin/clang++-9 +ENV CC=/usr/bin/clang +ENV CXX=/usr/bin/clang++ diff --git a/utils/docker/Dockerfile.build-gcc b/utils/docker/Dockerfile.build-gcc index 1a4d960c..014594ed 100644 --- a/utils/docker/Dockerfile.build-gcc +++ b/utils/docker/Dockerfile.build-gcc @@ -2,24 +2,14 @@ ARG BASE=secondstate/ssvm:ubuntu-base FROM ${BASE} RUN apt update && apt install -y \ - software-properties-common \ - wget - -RUN wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|apt-key add - - -RUN apt update && apt install -y \ - libllvm9 \ - llvm-9 \ - llvm-9-dev \ - liblld-9-dev \ - llvm-9-runtime \ - libclang-common-9-dev # for yaml-bench + llvm-dev \ + liblld-10-dev RUN apt update && apt install -y \ - gcc-8 \ - g++-8 + gcc \ + g++ RUN rm -rf /var/lib/apt/lists/* -ENV CC=gcc-8 -ENV CXX=g++-8 +ENV CC=gcc +ENV CXX=g++ From 257887107532331818219724ad1e7558cf1e8e94 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 15 Dec 2020 03:42:28 +0000 Subject: [PATCH 007/623] [Misc] Add ninja-build on docker image * Change CI workflows to use ninja generator --- utils/docker/Dockerfile.base | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/docker/Dockerfile.base b/utils/docker/Dockerfile.base index ab323566..41b365a3 100644 --- a/utils/docker/Dockerfile.base +++ b/utils/docker/Dockerfile.base @@ -8,6 +8,7 @@ RUN apt update && apt upgrade -y \ software-properties-common \ wget \ cmake \ + ninja-build \ curl \ git \ libboost-all-dev From 323d5a5f91f76eedb4891be7b87c3739219ab0da Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Fri, 18 Dec 2020 08:03:12 +0000 Subject: [PATCH 008/623] [Misc] Build manylinux* package * Add qemu for aarch64 --- utils/docker/Dockerfile.manylinux1_x86_64 | 76 +++++++++++++++++++ utils/docker/Dockerfile.manylinux2010_x86_64 | 75 ++++++++++++++++++ utils/docker/Dockerfile.manylinux2014_aarch64 | 74 ++++++++++++++++++ utils/docker/Dockerfile.manylinux2014_x86_64 | 75 ++++++++++++++++++ utils/docker/SHA256SUM | 10 +++ utils/docker/build-manylinux.sh | 14 ++++ utils/docker/llvm-glibc-2.5.patch | 39 ++++++++++ 7 files changed, 363 insertions(+) create mode 100644 utils/docker/Dockerfile.manylinux1_x86_64 create mode 100644 utils/docker/Dockerfile.manylinux2010_x86_64 create mode 100644 utils/docker/Dockerfile.manylinux2014_aarch64 create mode 100644 utils/docker/Dockerfile.manylinux2014_x86_64 create mode 100644 utils/docker/SHA256SUM create mode 100755 utils/docker/build-manylinux.sh create mode 100644 utils/docker/llvm-glibc-2.5.patch diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 new file mode 100644 index 00000000..122b0f1a --- /dev/null +++ b/utils/docker/Dockerfile.manylinux1_x86_64 @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +FROM quay.io/pypa/manylinux1_x86_64 + +MAINTAINER hydai hydai@secondstate.io + +ADD SHA256SUM llvm-glibc-2.5.patch build-manylinux.sh /root + +RUN cd && yum check-update && yum install -y xz && \ + export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ + 'import multiprocessing; print(multiprocessing.cpu_count())') && \ + export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ + export PKG_CONFIG_PATH=/toolchain/lib64/pkgconfig && \ + export CC=gcc && \ + export CXX=g++ && \ + export CPPFLAGS=-I/toolchain/include && \ + export LDFLAGS=-L/toolchain/lib64 && \ + curl -s -L -O --remote-name-all \ + https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ + https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ + https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ + http://isl.gforge.inria.fr/isl-0.23.tar.xz \ + https://github.com/facebook/zstd/releases/download/v1.4.7/zstd-1.4.7.tar.gz \ + https://ftp.gnu.org/gnu/gcc/gcc-10.2.0/gcc-10.2.0.tar.xz \ + https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5.tar.gz \ + https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/lld-11.0.0.src.tar.xz && \ + sha256sum -c SHA256SUM && \ + xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ + xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ + gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ + xz -dc isl-0.23.tar.xz | tar -xf - && \ + gzip -dc zstd-1.4.7.tar.gz | tar -xf - && \ + xz -dc gcc-10.2.0.tar.xz | tar -xf - && \ + gzip -dc cmake-3.13.5.tar.gz | tar -xf - && \ + gzip -dc v1.10.2.tar.gz | tar -xf - && \ + xz -dc llvm-11.0.0.src.tar.xz | tar -xf - && \ + xz -dc lld-11.0.0.src.tar.xz | tar -xf - && \ + mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../isl-0.23/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ + cd zstd-1.4.7 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ + mkdir build && cd build && ../gcc-10.2.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ + --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ + --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ + --disable-libmpx --disable-libsanitizer --disable-libunwind-exceptions \ + --disable-multilib --enable-__cxa_atexit --enable-gnu-indirect-function \ + --disable-gnu-unique-object --enable-initfini-array --enable-languages="c,c++,lto" \ + --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ + --with-default-libstdcxx-abi="gcc4-compatible" \ + --with-gcc-major-version-only --with-linker-hash-style="gnu" \ + --with-tune="generic" && \ + make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mv -v /toolchain/lib64/libstdc++.so.6.0.28 /toolchain/lib64/libstdc++.so.6.0.28.backup && \ + echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.8 libstdc++.a )" \ + > /toolchain/lib64/libstdc++.so.6.0.28 && \ + export PATH="/toolchain/bin:$PATH" && \ + mkdir build && cd build && /opt/python/cp39-cp39/bin/python \ + ../ninja-1.10.2/configure.py --bootstrap \ + --with-python=/opt/python/cp39-cp39/bin/python && \ + cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ + mkdir build && cd build && ../cmake-3.13.5/configure --prefix=/toolchain \ + --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mv -v lld-11.0.0.src lld && \ + cd llvm-11.0.0.src && patch -p1 < ../llvm-glibc-2.5.patch && cd - && \ + cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/toolchain \ + -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ + -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ + -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ + llvm-11.0.0.src && cmake --build build --target install && rm -rf build && \ + rm -rf * + +RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 new file mode 100644 index 00000000..5a03df34 --- /dev/null +++ b/utils/docker/Dockerfile.manylinux2010_x86_64 @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +FROM quay.io/pypa/manylinux2010_x86_64 + +MAINTAINER hydai hydai@secondstate.io + +ADD SHA256SUM build-manylinux.sh /root + +RUN cd && yum check-update && yum install -y xz && \ + export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ + 'import multiprocessing; print(multiprocessing.cpu_count())') && \ + export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ + export PKG_CONFIG_PATH=/toolchain/lib64/pkgconfig && \ + export CC=gcc && \ + export CXX=g++ && \ + export CPPFLAGS=-I/toolchain/include && \ + export LDFLAGS=-L/toolchain/lib64 && \ + curl -s -L -O --remote-name-all \ + https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ + https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ + https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ + http://isl.gforge.inria.fr/isl-0.23.tar.xz \ + https://github.com/facebook/zstd/releases/download/v1.4.7/zstd-1.4.7.tar.gz \ + https://ftp.gnu.org/gnu/gcc/gcc-10.2.0/gcc-10.2.0.tar.xz \ + https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5.tar.gz \ + https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/lld-11.0.0.src.tar.xz && \ + sha256sum -c SHA256SUM && \ + xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ + xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ + gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ + xz -dc isl-0.23.tar.xz | tar -xf - && \ + gzip -dc zstd-1.4.7.tar.gz | tar -xf - && \ + xz -dc gcc-10.2.0.tar.xz | tar -xf - && \ + gzip -dc cmake-3.13.5.tar.gz | tar -xf - && \ + gzip -dc v1.10.2.tar.gz | tar -xf - && \ + xz -dc llvm-11.0.0.src.tar.xz | tar -xf - && \ + xz -dc lld-11.0.0.src.tar.xz | tar -xf - && \ + mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../isl-0.23/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ + cd zstd-1.4.7 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ + mkdir build && cd build && ../gcc-10.2.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ + --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ + --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ + --disable-libmpx --disable-libsanitizer --disable-libunwind-exceptions \ + --disable-multilib --enable-__cxa_atexit --enable-gnu-indirect-function \ + --enable-gnu-unique-object --enable-initfini-array --enable-languages="c,c++,lto" \ + --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ + --with-default-libstdcxx-abi="gcc4-compatible" \ + --with-gcc-major-version-only --with-linker-hash-style="gnu" \ + --with-tune="generic" && \ + make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mv -v /toolchain/lib64/libstdc++.so.6.0.28 /toolchain/lib64/libstdc++.so.6.0.28.backup && \ + echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.13 libstdc++.a )" \ + > /toolchain/lib64/libstdc++.so.6.0.28 && \ + export PATH="/toolchain/bin:$PATH" && \ + mkdir build && cd build && /opt/python/cp39-cp39/bin/python \ + ../ninja-1.10.2/configure.py --bootstrap \ + --with-python=/opt/python/cp39-cp39/bin/python && \ + cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ + mkdir build && cd build && ../cmake-3.13.5/configure --prefix=/toolchain \ + --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mv -v lld-11.0.0.src lld && \ + cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/toolchain \ + -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ + -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ + -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ + llvm-11.0.0.src && cmake --build build --target install && rm -rf build && \ + rm -rf * + +RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 new file mode 100644 index 00000000..cf0008db --- /dev/null +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +FROM quay.io/pypa/manylinux2014_aarch64 + +MAINTAINER hydai hydai@secondstate.io + +ADD SHA256SUM build-manylinux.sh /root + +RUN cd && yum check-update && yum install -y xz && \ + export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ + 'import multiprocessing; print(multiprocessing.cpu_count())') && \ + export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ + export PKG_CONFIG_PATH=/toolchain/lib64/pkgconfig && \ + export CC=gcc && \ + export CXX=g++ && \ + export CPPFLAGS=-I/toolchain/include && \ + export LDFLAGS=-L/toolchain/lib64 && \ + curl -s -L -O --remote-name-all \ + https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ + https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ + https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ + http://isl.gforge.inria.fr/isl-0.23.tar.xz \ + https://github.com/facebook/zstd/releases/download/v1.4.7/zstd-1.4.7.tar.gz \ + https://ftp.gnu.org/gnu/gcc/gcc-10.2.0/gcc-10.2.0.tar.xz \ + https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5.tar.gz \ + https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/lld-11.0.0.src.tar.xz && \ + sha256sum -c SHA256SUM && \ + xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ + xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ + gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ + xz -dc isl-0.23.tar.xz | tar -xf - && \ + gzip -dc zstd-1.4.7.tar.gz | tar -xf - && \ + xz -dc gcc-10.2.0.tar.xz | tar -xf - && \ + gzip -dc cmake-3.13.5.tar.gz | tar -xf - && \ + gzip -dc v1.10.2.tar.gz | tar -xf - && \ + xz -dc llvm-11.0.0.src.tar.xz | tar -xf - && \ + xz -dc lld-11.0.0.src.tar.xz | tar -xf - && \ + mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../isl-0.23/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ + cd zstd-1.4.7 && make -s $ZSTDFLAGS -j 1 && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ + mkdir build && cd build && ../gcc-10.2.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ + --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ + --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ + --disable-libmpx --disable-libsanitizer --disable-libunwind-exceptions \ + --disable-multilib --enable-__cxa_atexit --enable-gnu-indirect-function \ + --enable-gnu-unique-object --enable-initfini-array --enable-languages="c,c++,lto" \ + --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ + --with-default-libstdcxx-abi="gcc4-compatible" \ + --with-gcc-major-version-only --with-linker-hash-style="gnu" && \ + make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mv -v /toolchain/lib64/libstdc++.so.6.0.28 /toolchain/lib64/libstdc++.so.6.0.28.backup && \ + echo -e "OUTPUT_FORMAT(elf64-aarch64)\nINPUT ( libstdc++.so.6.0.19 libstdc++.a )" \ + > /toolchain/lib64/libstdc++.so.6.0.28 && \ + export PATH="/toolchain/bin:$PATH" && \ + mkdir build && cd build && /opt/python/cp39-cp39/bin/python \ + ../ninja-1.10.2/configure.py --bootstrap \ + --with-python=/opt/python/cp39-cp39/bin/python && \ + cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ + mkdir build && cd build && ../cmake-3.13.5/configure --prefix=/toolchain \ + --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mv -v lld-11.0.0.src lld && \ + cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/toolchain \ + -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ + -DLLVM_TARGETS_TO_BUILD="AArch64" -DLLVM_ENABLE_PROJECTS=lld \ + -DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-redhat-linux-gnu" \ + llvm-11.0.0.src && cmake --build build --target install && rm -rf build && \ + rm -rf * + +RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 new file mode 100644 index 00000000..6f5a3aeb --- /dev/null +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +FROM quay.io/pypa/manylinux2014_x86_64 + +MAINTAINER hydai hydai@secondstate.io + +ADD SHA256SUM build-manylinux.sh /root + +RUN cd && yum check-update && yum install -y xz && \ + export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ + 'import multiprocessing; print(multiprocessing.cpu_count())') && \ + export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ + export PKG_CONFIG_PATH=/toolchain/lib64/pkgconfig && \ + export CC=gcc && \ + export CXX=g++ && \ + export CPPFLAGS=-I/toolchain/include && \ + export LDFLAGS=-L/toolchain/lib64 && \ + curl -s -L -O --remote-name-all \ + https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ + https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ + https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ + http://isl.gforge.inria.fr/isl-0.23.tar.xz \ + https://github.com/facebook/zstd/releases/download/v1.4.7/zstd-1.4.7.tar.gz \ + https://ftp.gnu.org/gnu/gcc/gcc-10.2.0/gcc-10.2.0.tar.xz \ + https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5.tar.gz \ + https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/lld-11.0.0.src.tar.xz && \ + sha256sum -c SHA256SUM && \ + xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ + xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ + gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ + xz -dc isl-0.23.tar.xz | tar -xf - && \ + gzip -dc zstd-1.4.7.tar.gz | tar -xf - && \ + xz -dc gcc-10.2.0.tar.xz | tar -xf - && \ + gzip -dc cmake-3.13.5.tar.gz | tar -xf - && \ + gzip -dc v1.10.2.tar.gz | tar -xf - && \ + xz -dc llvm-11.0.0.src.tar.xz | tar -xf - && \ + xz -dc lld-11.0.0.src.tar.xz | tar -xf - && \ + mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../isl-0.23/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ + cd zstd-1.4.7 && make -s $ZSTDFLAGS -j 1 && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ + mkdir build && cd build && ../gcc-10.2.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ + --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ + --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ + --disable-libmpx --disable-libsanitizer --disable-libunwind-exceptions \ + --disable-multilib --enable-__cxa_atexit --enable-gnu-indirect-function \ + --enable-gnu-unique-object --enable-initfini-array --enable-languages="c,c++,lto" \ + --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ + --with-default-libstdcxx-abi="gcc4-compatible" \ + --with-gcc-major-version-only --with-linker-hash-style="gnu" \ + --with-tune="generic" && \ + make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mv -v /toolchain/lib64/libstdc++.so.6.0.28 /toolchain/lib64/libstdc++.so.6.0.28.backup && \ + echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.19 libstdc++.a )" \ + > /toolchain/lib64/libstdc++.so.6.0.28 && \ + export PATH="/toolchain/bin:$PATH" && \ + mkdir build && cd build && /opt/python/cp39-cp39/bin/python \ + ../ninja-1.10.2/configure.py --bootstrap \ + --with-python=/opt/python/cp39-cp39/bin/python && \ + cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ + mkdir build && cd build && ../cmake-3.13.5/configure --prefix=/toolchain \ + --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mv -v lld-11.0.0.src lld && \ + cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/toolchain \ + -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ + -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ + -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ + llvm-11.0.0.src && cmake --build build --target install && rm -rf build && \ + rm -rf * + +RUN yum clean all diff --git a/utils/docker/SHA256SUM b/utils/docker/SHA256SUM new file mode 100644 index 00000000..08fe0f41 --- /dev/null +++ b/utils/docker/SHA256SUM @@ -0,0 +1,10 @@ +526db6a4b47772d1943b2f86de693e712f9dacf3d7c13b19197c9bef133766a5 cmake-3.13.5.tar.gz +b8dd4368bb9c7f0b98188317ee0254dd8cc99d1e3a18d0ff146c855fe16c1d8c gcc-10.2.0.tar.xz +fd4829912cddd12f84181c3451cc752be224643e87fac497b69edddadc49b4f2 gmp-6.2.1.tar.xz +5efc53efaef151301f4e7dde3856b66812d8153dede24fab17673f801c8698f2 isl-0.23.tar.xz +efe7be4a7b7cdc6f3bcf222827c6f837439e6e656d12d6c885d5c8a80ff4fd1c lld-11.0.0.src.tar.xz +913f68c898dfb4a03b397c5e11c6a2f39d0f22ed7665c9cefa87a34423a72469 llvm-11.0.0.src.tar.xz +17503d2c395dfcf106b622dc142683c1199431d095367c6aacba6eec30340459 mpc-1.2.1.tar.gz +0c98a3f1732ff6ca4ea690552079da9c597872d30e96ec28414ee23c95558a7f mpfr-4.1.0.tar.xz +ce35865411f0490368a8fc383f29071de6690cbadc27704734978221f25e2bed v1.10.2.tar.gz +192cbb1274a9672cbcceaf47b5c4e9e59691ca60a357f1d4a8b2dfa2c365d757 zstd-1.4.7.tar.gz diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh new file mode 100755 index 00000000..e7db8a4f --- /dev/null +++ b/utils/docker/build-manylinux.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +export PATH="/toolchain/bin:$PATH" +export CC=gcc +export CXX=g++ + +cd +curl -s -L -O --remote-name-all https://dl.bintray.com/boostorg/release/1.75.0/source/boost_1_75_0.tar.bz2 +bzip2 -dc boost_1_75_0.tar.bz2 | tar -xf - +cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DBUILD_PACKAGE=RPM -DBOOST_INCLUDEDIR=$(pwd)/boost_1_75_0/ /ssvm +cmake --build build +cmake --build build --target package +cp -v build/SSVM-*.rpm /ssvm/SSVM-manylinux1.rpm diff --git a/utils/docker/llvm-glibc-2.5.patch b/utils/docker/llvm-glibc-2.5.patch new file mode 100644 index 00000000..e2afabb4 --- /dev/null +++ b/utils/docker/llvm-glibc-2.5.patch @@ -0,0 +1,39 @@ +--- a/lib/Support/Host.cpp 2020-12-17 20:09:25.321395012 +0000 ++++ b/lib/Support/Host.cpp 2020-12-17 20:29:40.296551916 +0000 +@@ -1224,6 +1224,15 @@ StringRef sys::getHostCPUName() { return + #endif + + #if defined(__linux__) && (defined(__i386__) || defined(__x86_64__)) ++#if !defined(CPU_COUNT) ++static inline auto CPU_COUNT(const cpu_set_t *Set) noexcept { ++ int Count = 0; ++ for (const auto &Bits : Set->__bits) { ++ Count += __builtin_popcountl(Bits); ++ } ++ return Count; ++} ++#endif + // On Linux, the number of physical cores can be computed from /proc/cpuinfo, + // using the number of unique physical/core id pairs. The following + // implementation reads the /proc/cpuinfo format on an x86_64 system. + +--- a/lib/Support/Unix/Threading.inc 2020-12-17 20:09:25.325395024 +0000 ++++ b/lib/Support/Unix/Threading.inc 2020-12-17 20:24:57.267834738 +0000 +@@ -281,6 +281,16 @@ SetThreadPriorityResult llvm::set_thread + + #include + ++#if !defined(CPU_COUNT) ++static inline auto CPU_COUNT(const cpu_set_t *Set) noexcept { ++ int Count = 0; ++ for (const auto &Bits : Set->__bits) { ++ Count += __builtin_popcountl(Bits); ++ } ++ return Count; ++} ++#endif ++ + int computeHostNumHardwareThreads() { + #ifdef __linux__ + cpu_set_t Set; + From f726d1e68cce50203a7aaf87dafe033a4caf527f Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 22 Dec 2020 19:05:38 +0000 Subject: [PATCH 009/623] [Misc] Upgrade cmake to 3.19.2 --- utils/docker/Dockerfile.manylinux1_x86_64 | 13 +- utils/docker/Dockerfile.manylinux2010_x86_64 | 8 +- utils/docker/Dockerfile.manylinux2014_aarch64 | 8 +- utils/docker/Dockerfile.manylinux2014_x86_64 | 8 +- utils/docker/SHA256SUM | 2 +- utils/docker/cmake-glibc-2.5.patch | 193 ++++++++++++++++++ 6 files changed, 213 insertions(+), 19 deletions(-) create mode 100644 utils/docker/cmake-glibc-2.5.patch diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 index 122b0f1a..08134b1f 100644 --- a/utils/docker/Dockerfile.manylinux1_x86_64 +++ b/utils/docker/Dockerfile.manylinux1_x86_64 @@ -3,9 +3,9 @@ FROM quay.io/pypa/manylinux1_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM llvm-glibc-2.5.patch build-manylinux.sh /root +ADD SHA256SUM llvm-glibc-2.5.patch cmake-glibc-2.5.patch build-manylinux.sh /root -RUN cd && yum check-update && yum install -y xz && \ +RUN cd && yum check-update && yum install -y xz openssl-devel && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ @@ -21,7 +21,7 @@ RUN cd && yum check-update && yum install -y xz && \ http://isl.gforge.inria.fr/isl-0.23.tar.xz \ https://github.com/facebook/zstd/releases/download/v1.4.7/zstd-1.4.7.tar.gz \ https://ftp.gnu.org/gnu/gcc/gcc-10.2.0/gcc-10.2.0.tar.xz \ - https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5.tar.gz \ + https://github.com/Kitware/CMake/releases/download/v3.19.2/cmake-3.19.2.tar.gz \ https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/lld-11.0.0.src.tar.xz && \ @@ -32,10 +32,12 @@ RUN cd && yum check-update && yum install -y xz && \ xz -dc isl-0.23.tar.xz | tar -xf - && \ gzip -dc zstd-1.4.7.tar.gz | tar -xf - && \ xz -dc gcc-10.2.0.tar.xz | tar -xf - && \ - gzip -dc cmake-3.13.5.tar.gz | tar -xf - && \ + gzip -dc cmake-3.19.2.tar.gz | tar -xf - && \ gzip -dc v1.10.2.tar.gz | tar -xf - && \ xz -dc llvm-11.0.0.src.tar.xz | tar -xf - && \ xz -dc lld-11.0.0.src.tar.xz | tar -xf - && \ + cd llvm-11.0.0.src && patch -p1 < ../llvm-glibc-2.5.patch && cd - && \ + cd cmake-3.19.2 && patch -p1 < ../cmake-glibc-2.5.patch && cd - && \ mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ @@ -61,10 +63,9 @@ RUN cd && yum check-update && yum install -y xz && \ ../ninja-1.10.2/configure.py --bootstrap \ --with-python=/opt/python/cp39-cp39/bin/python && \ cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.13.5/configure --prefix=/toolchain \ + mkdir build && cd build && ../cmake-3.19.2/configure --prefix=/toolchain \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mv -v lld-11.0.0.src lld && \ - cd llvm-11.0.0.src && patch -p1 < ../llvm-glibc-2.5.patch && cd - && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/toolchain \ -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 index 5a03df34..3bd7a955 100644 --- a/utils/docker/Dockerfile.manylinux2010_x86_64 +++ b/utils/docker/Dockerfile.manylinux2010_x86_64 @@ -5,7 +5,7 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM build-manylinux.sh /root -RUN cd && yum check-update && yum install -y xz && \ +RUN cd && yum check-update && yum install -y xz openssl-devel && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ @@ -21,7 +21,7 @@ RUN cd && yum check-update && yum install -y xz && \ http://isl.gforge.inria.fr/isl-0.23.tar.xz \ https://github.com/facebook/zstd/releases/download/v1.4.7/zstd-1.4.7.tar.gz \ https://ftp.gnu.org/gnu/gcc/gcc-10.2.0/gcc-10.2.0.tar.xz \ - https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5.tar.gz \ + https://github.com/Kitware/CMake/releases/download/v3.19.2/cmake-3.19.2.tar.gz \ https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/lld-11.0.0.src.tar.xz && \ @@ -32,7 +32,7 @@ RUN cd && yum check-update && yum install -y xz && \ xz -dc isl-0.23.tar.xz | tar -xf - && \ gzip -dc zstd-1.4.7.tar.gz | tar -xf - && \ xz -dc gcc-10.2.0.tar.xz | tar -xf - && \ - gzip -dc cmake-3.13.5.tar.gz | tar -xf - && \ + gzip -dc cmake-3.19.2.tar.gz | tar -xf - && \ gzip -dc v1.10.2.tar.gz | tar -xf - && \ xz -dc llvm-11.0.0.src.tar.xz | tar -xf - && \ xz -dc lld-11.0.0.src.tar.xz | tar -xf - && \ @@ -61,7 +61,7 @@ RUN cd && yum check-update && yum install -y xz && \ ../ninja-1.10.2/configure.py --bootstrap \ --with-python=/opt/python/cp39-cp39/bin/python && \ cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.13.5/configure --prefix=/toolchain \ + mkdir build && cd build && ../cmake-3.19.2/configure --prefix=/toolchain \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mv -v lld-11.0.0.src lld && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index cf0008db..f129f132 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -5,7 +5,7 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM build-manylinux.sh /root -RUN cd && yum check-update && yum install -y xz && \ +RUN cd && yum check-update && yum install -y xz openssl-devel && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ @@ -21,7 +21,7 @@ RUN cd && yum check-update && yum install -y xz && \ http://isl.gforge.inria.fr/isl-0.23.tar.xz \ https://github.com/facebook/zstd/releases/download/v1.4.7/zstd-1.4.7.tar.gz \ https://ftp.gnu.org/gnu/gcc/gcc-10.2.0/gcc-10.2.0.tar.xz \ - https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5.tar.gz \ + https://github.com/Kitware/CMake/releases/download/v3.19.2/cmake-3.19.2.tar.gz \ https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/lld-11.0.0.src.tar.xz && \ @@ -32,7 +32,7 @@ RUN cd && yum check-update && yum install -y xz && \ xz -dc isl-0.23.tar.xz | tar -xf - && \ gzip -dc zstd-1.4.7.tar.gz | tar -xf - && \ xz -dc gcc-10.2.0.tar.xz | tar -xf - && \ - gzip -dc cmake-3.13.5.tar.gz | tar -xf - && \ + gzip -dc cmake-3.19.2.tar.gz | tar -xf - && \ gzip -dc v1.10.2.tar.gz | tar -xf - && \ xz -dc llvm-11.0.0.src.tar.xz | tar -xf - && \ xz -dc lld-11.0.0.src.tar.xz | tar -xf - && \ @@ -60,7 +60,7 @@ RUN cd && yum check-update && yum install -y xz && \ ../ninja-1.10.2/configure.py --bootstrap \ --with-python=/opt/python/cp39-cp39/bin/python && \ cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.13.5/configure --prefix=/toolchain \ + mkdir build && cd build && ../cmake-3.19.2/configure --prefix=/toolchain \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mv -v lld-11.0.0.src lld && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index 6f5a3aeb..323ea181 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -5,7 +5,7 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM build-manylinux.sh /root -RUN cd && yum check-update && yum install -y xz && \ +RUN cd && yum check-update && yum install -y xz openssl-devel && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ @@ -21,7 +21,7 @@ RUN cd && yum check-update && yum install -y xz && \ http://isl.gforge.inria.fr/isl-0.23.tar.xz \ https://github.com/facebook/zstd/releases/download/v1.4.7/zstd-1.4.7.tar.gz \ https://ftp.gnu.org/gnu/gcc/gcc-10.2.0/gcc-10.2.0.tar.xz \ - https://github.com/Kitware/CMake/releases/download/v3.13.5/cmake-3.13.5.tar.gz \ + https://github.com/Kitware/CMake/releases/download/v3.19.2/cmake-3.19.2.tar.gz \ https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/lld-11.0.0.src.tar.xz && \ @@ -32,7 +32,7 @@ RUN cd && yum check-update && yum install -y xz && \ xz -dc isl-0.23.tar.xz | tar -xf - && \ gzip -dc zstd-1.4.7.tar.gz | tar -xf - && \ xz -dc gcc-10.2.0.tar.xz | tar -xf - && \ - gzip -dc cmake-3.13.5.tar.gz | tar -xf - && \ + gzip -dc cmake-3.19.2.tar.gz | tar -xf - && \ gzip -dc v1.10.2.tar.gz | tar -xf - && \ xz -dc llvm-11.0.0.src.tar.xz | tar -xf - && \ xz -dc lld-11.0.0.src.tar.xz | tar -xf - && \ @@ -61,7 +61,7 @@ RUN cd && yum check-update && yum install -y xz && \ ../ninja-1.10.2/configure.py --bootstrap \ --with-python=/opt/python/cp39-cp39/bin/python && \ cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.13.5/configure --prefix=/toolchain \ + mkdir build && cd build && ../cmake-3.19.2/configure --prefix=/toolchain \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mv -v lld-11.0.0.src lld && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ diff --git a/utils/docker/SHA256SUM b/utils/docker/SHA256SUM index 08fe0f41..398f25a4 100644 --- a/utils/docker/SHA256SUM +++ b/utils/docker/SHA256SUM @@ -1,4 +1,4 @@ -526db6a4b47772d1943b2f86de693e712f9dacf3d7c13b19197c9bef133766a5 cmake-3.13.5.tar.gz +e3e0fd3b23b7fb13e1a856581078e0776ffa2df4e9d3164039c36d3315e0c7f0 cmake-3.19.2.tar.gz b8dd4368bb9c7f0b98188317ee0254dd8cc99d1e3a18d0ff146c855fe16c1d8c gcc-10.2.0.tar.xz fd4829912cddd12f84181c3451cc752be224643e87fac497b69edddadc49b4f2 gmp-6.2.1.tar.xz 5efc53efaef151301f4e7dde3856b66812d8153dede24fab17673f801c8698f2 isl-0.23.tar.xz diff --git a/utils/docker/cmake-glibc-2.5.patch b/utils/docker/cmake-glibc-2.5.patch new file mode 100644 index 00000000..bada2453 --- /dev/null +++ b/utils/docker/cmake-glibc-2.5.patch @@ -0,0 +1,193 @@ +diff -rup a/Utilities/cmlibuv/src/unix/async.c b/Utilities/cmlibuv/src/unix/async.c +--- a/Utilities/cmlibuv/src/unix/async.c 2020-12-16 12:35:29.000000000 +0000 ++++ b/Utilities/cmlibuv/src/unix/async.c 2020-12-22 18:36:14.000000000 +0000 +@@ -34,7 +34,7 @@ + #include + #include /* sched_yield() */ + +-#ifdef __linux__ ++#if defined(__linux__) && __GLIBC_PREREQ(2, 8) + #include + #endif + +@@ -175,7 +175,7 @@ static void uv__async_send(uv_loop_t* lo + len = 1; + fd = loop->async_wfd; + +-#if defined(__linux__) ++#if defined(__linux__) && __GLIBC_PREREQ(2, 8) + if (fd == -1) { + static const uint64_t val = 1; + buf = &val; +@@ -206,7 +206,7 @@ static int uv__async_start(uv_loop_t* lo + if (loop->async_io_watcher.fd != -1) + return 0; + +-#ifdef __linux__ ++#if defined(__linux__) && __GLIBC_PREREQ(2, 8) + err = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); + if (err < 0) + return UV__ERR(errno); +diff -rup a/Utilities/cmlibuv/src/unix/core.c b/Utilities/cmlibuv/src/unix/core.c +--- a/Utilities/cmlibuv/src/unix/core.c 2020-12-16 12:35:29.000000000 +0000 ++++ b/Utilities/cmlibuv/src/unix/core.c 2020-12-22 18:07:40.000000000 +0000 +@@ -88,7 +88,9 @@ extern char** environ; + + #if defined(__linux__) + # include +-# define uv__accept4 accept4 ++# if __GLIBC_PREREQ(2, 10) ++# define uv__accept4 accept4 ++# endif + #endif + + static int uv__run_pending(uv_loop_t* loop); +@@ -1032,7 +1034,7 @@ int uv__open_cloexec(const char* path, i + + + int uv__dup2_cloexec(int oldfd, int newfd) { +-#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__linux__) ++#if defined(__FreeBSD__) || defined(__NetBSD__) || (defined(__linux__) && __GLIBC_PREREQ(2, 9)) + int r; + + r = dup3(oldfd, newfd, O_CLOEXEC); +diff -rup a/Utilities/cmlibuv/src/unix/fs.c b/Utilities/cmlibuv/src/unix/fs.c +--- a/Utilities/cmlibuv/src/unix/fs.c 2020-12-16 12:35:29.000000000 +0000 ++++ b/Utilities/cmlibuv/src/unix/fs.c 2020-12-22 18:10:56.000000000 +0000 +@@ -224,7 +224,7 @@ UV_UNUSED(static struct timeval uv__fs_t + } + + static ssize_t uv__fs_futime(uv_fs_t* req) { +-#if defined(__linux__) \ ++#if (defined(__linux__) && __GLIBC_PREREQ(2, 6)) \ + || defined(_AIX71) \ + || defined(__HAIKU__) + /* utimesat() has nanosecond resolution but we stick to microseconds +@@ -234,7 +234,8 @@ static ssize_t uv__fs_futime(uv_fs_t* re + ts[0] = uv__fs_to_timespec(req->atime); + ts[1] = uv__fs_to_timespec(req->mtime); + return futimens(req->file, ts); +-#elif defined(__APPLE__) \ ++#elif (defined(__linux__) && !__GLIBC_PREREQ(2, 6)) \ ++ || defined(__APPLE__) \ + || defined(__DragonFly__) \ + || defined(__FreeBSD__) \ + || defined(__FreeBSD_kernel__) \ +@@ -1016,7 +1017,7 @@ ok: + + + static ssize_t uv__fs_utime(uv_fs_t* req) { +-#if defined(__linux__) \ ++#if (defined(__linux__) && __GLIBC_PREREQ(2, 6)) \ + || defined(_AIX71) \ + || defined(__sun) \ + || defined(__HAIKU__) +@@ -1027,7 +1028,8 @@ static ssize_t uv__fs_utime(uv_fs_t* req + ts[0] = uv__fs_to_timespec(req->atime); + ts[1] = uv__fs_to_timespec(req->mtime); + return utimensat(AT_FDCWD, req->path, ts, 0); +-#elif defined(__APPLE__) \ ++#elif (defined(__linux__) && !__GLIBC_PREREQ(2, 6)) \ ++ || defined(__APPLE__) \ + || defined(__DragonFly__) \ + || defined(__FreeBSD__) \ + || defined(__FreeBSD_kernel__) \ +@@ -1059,7 +1061,7 @@ static ssize_t uv__fs_utime(uv_fs_t* req + + + static ssize_t uv__fs_lutime(uv_fs_t* req) { +-#if defined(__linux__) || \ ++#if (defined(__linux__) && __GLIBC_PREREQ(2, 6)) || \ + defined(_AIX71) || \ + defined(__sun) || \ + defined(__HAIKU__) +@@ -1067,7 +1069,8 @@ static ssize_t uv__fs_lutime(uv_fs_t* re + ts[0] = uv__fs_to_timespec(req->atime); + ts[1] = uv__fs_to_timespec(req->mtime); + return utimensat(AT_FDCWD, req->path, ts, AT_SYMLINK_NOFOLLOW); +-#elif defined(__APPLE__) || \ ++#elif (defined(__linux__) && !__GLIBC_PREREQ(2, 6)) || \ ++ defined(__APPLE__) || \ + defined(__DragonFly__) || \ + defined(__FreeBSD__) || \ + defined(__FreeBSD_kernel__) || \ +diff -rup a/Utilities/cmlibuv/src/unix/linux-core.c b/Utilities/cmlibuv/src/unix/linux-core.c +--- a/Utilities/cmlibuv/src/unix/linux-core.c 2020-12-16 12:35:29.000000000 +0000 ++++ b/Utilities/cmlibuv/src/unix/linux-core.c 2020-12-22 18:13:06.000000000 +0000 +@@ -85,7 +85,12 @@ static uint64_t read_cpufreq(unsigned in + + int uv__platform_loop_init(uv_loop_t* loop) { + int fd; ++#if __GLIBC_PREREQ(2, 9) + fd = epoll_create1(O_CLOEXEC); ++#else ++ fd = -1; ++ errno = ENOSYS; ++#endif + + /* epoll_create1() can fail either because it's not implemented (old kernel) + * or because it doesn't understand the O_CLOEXEC flag. +@@ -311,11 +316,16 @@ void uv__io_poll(uv_loop_t* loop, int ti + abort(); + + if (no_epoll_wait != 0 || (sigmask != 0 && no_epoll_pwait == 0)) { ++#if __GLIBC_PREREQ(2, 6) + nfds = epoll_pwait(loop->backend_fd, + events, + ARRAY_SIZE(events), + timeout, + &sigset); ++#else ++ nfds = -1; ++ errno = ENOSYS; ++#endif + if (nfds == -1 && errno == ENOSYS) { + uv__store_relaxed(&no_epoll_pwait_cached, 1); + no_epoll_pwait = 1; +diff -rup a/Utilities/cmlibuv/src/unix/linux-inotify.c b/Utilities/cmlibuv/src/unix/linux-inotify.c +--- a/Utilities/cmlibuv/src/unix/linux-inotify.c 2020-12-16 12:35:29.000000000 +0000 ++++ b/Utilities/cmlibuv/src/unix/linux-inotify.c 2020-12-22 18:16:16.000000000 +0000 +@@ -71,10 +71,22 @@ static int init_inotify(uv_loop_t* loop) + if (loop->inotify_fd != -1) + return 0; + ++#if __GLIBC_PREREQ(2, 6) + fd = inotify_init1(IN_NONBLOCK | IN_CLOEXEC); ++#else ++ fd = inotify_init(); ++#endif ++ + if (fd < 0) + return UV__ERR(errno); + ++#if !__GLIBC_PREREQ(2, 6) ++ if (uv__nonblock(fd, 1) || uv__cloexec(fd, 1)) { ++ uv__close(fd); ++ return UV__ERR(errno); ++ } ++#endif ++ + loop->inotify_fd = fd; + uv__io_init(&loop->inotify_read_watcher, uv__inotify_read, loop->inotify_fd); + uv__io_start(loop, &loop->inotify_read_watcher, POLLIN); +diff -rup a/Utilities/cmlibuv/src/unix/process.c b/Utilities/cmlibuv/src/unix/process.c +--- a/Utilities/cmlibuv/src/unix/process.c 2020-12-16 12:35:29.000000000 +0000 ++++ b/Utilities/cmlibuv/src/unix/process.c 2020-12-22 18:23:18.000000000 +0000 +@@ -124,7 +124,7 @@ static void uv__chld(uv_signal_t* handle + + + static int uv__make_socketpair(int fds[2]) { +-#if defined(__FreeBSD__) || defined(__linux__) ++#if defined(__FreeBSD__) || (defined(__linux__) && __GLIBC_PREREQ(2, 9)) + if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, fds)) + return UV__ERR(errno); + +@@ -151,7 +151,7 @@ static int uv__make_socketpair(int fds[2 + + + int uv__make_pipe(int fds[2], int flags) { +-#if defined(__FreeBSD__) || defined(__linux__) ++#if defined(__FreeBSD__) || (defined(__linux__) && __GLIBC_PREREQ(2, 9)) + if (pipe2(fds, flags | O_CLOEXEC)) + return UV__ERR(errno); + From 85a8673fb28d55f878f2b6a37c084a4a57f48c8f Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Thu, 24 Dec 2020 06:14:35 +0000 Subject: [PATCH 010/623] [Misc] Add manylinux* package on release --- utils/docker/Dockerfile.manylinux1_x86_64 | 1 + utils/docker/Dockerfile.manylinux2010_x86_64 | 1 + utils/docker/Dockerfile.manylinux2014_aarch64 | 1 + utils/docker/Dockerfile.manylinux2014_x86_64 | 1 + utils/docker/build-manylinux.sh | 9 +++++++-- 5 files changed, 11 insertions(+), 2 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 index 08134b1f..064ea304 100644 --- a/utils/docker/Dockerfile.manylinux1_x86_64 +++ b/utils/docker/Dockerfile.manylinux1_x86_64 @@ -72,6 +72,7 @@ RUN cd && yum check-update && yum install -y xz openssl-devel && \ -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ llvm-11.0.0.src && cmake --build build --target install && rm -rf build && \ + mv build-manylinux.sh /toolchain/bin/ && \ rm -rf * RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 index 3bd7a955..294e5cdc 100644 --- a/utils/docker/Dockerfile.manylinux2010_x86_64 +++ b/utils/docker/Dockerfile.manylinux2010_x86_64 @@ -70,6 +70,7 @@ RUN cd && yum check-update && yum install -y xz openssl-devel && \ -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ llvm-11.0.0.src && cmake --build build --target install && rm -rf build && \ + mv build-manylinux.sh /toolchain/bin/ && \ rm -rf * RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index f129f132..92524b0a 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -69,6 +69,7 @@ RUN cd && yum check-update && yum install -y xz openssl-devel && \ -DLLVM_TARGETS_TO_BUILD="AArch64" -DLLVM_ENABLE_PROJECTS=lld \ -DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-redhat-linux-gnu" \ llvm-11.0.0.src && cmake --build build --target install && rm -rf build && \ + mv build-manylinux.sh /toolchain/bin/ && \ rm -rf * RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index 323ea181..18f4daba 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -70,6 +70,7 @@ RUN cd && yum check-update && yum install -y xz openssl-devel && \ -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ llvm-11.0.0.src && cmake --build build --target install && rm -rf build && \ + mv build-manylinux.sh /toolchain/bin/ && \ rm -rf * RUN yum clean all diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index e7db8a4f..259e0711 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -7,8 +7,13 @@ export CXX=g++ cd curl -s -L -O --remote-name-all https://dl.bintray.com/boostorg/release/1.75.0/source/boost_1_75_0.tar.bz2 +echo 953db31e016db7bb207f11432bef7df100516eeb746843fa0486a222e3fd49cb boost_1_75_0.tar.bz2 | sha256sum -c bzip2 -dc boost_1_75_0.tar.bz2 | tar -xf - -cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DBUILD_PACKAGE=RPM -DBOOST_INCLUDEDIR=$(pwd)/boost_1_75_0/ /ssvm +cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DBUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM" -DBOOST_INCLUDEDIR=$(pwd)/boost_1_75_0/ /ssvm cmake --build build cmake --build build --target package -cp -v build/SSVM-*.rpm /ssvm/SSVM-manylinux1.rpm +cp -v build/SSVM-*.tar.gz /ssvm/SSVM.tar.gz +cp -v build/SSVM-*.tar.bz2 /ssvm/SSVM.tar.bz2 +cp -v build/SSVM-*.tar.xz /ssvm/SSVM.tar.xz +cp -v build/SSVM-*.tar.zst /ssvm/SSVM.tar.zst +cp -v build/SSVM-*.rpm /ssvm/SSVM.rpm From de28c6375861a76127c09f96e1a0553285671eb6 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 19 Jan 2021 15:30:29 +0000 Subject: [PATCH 011/623] [CI] Create a CI docker env for building the remaining docker images --- utils/docker/Dockerfile.ci-image-base | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 utils/docker/Dockerfile.ci-image-base diff --git a/utils/docker/Dockerfile.ci-image-base b/utils/docker/Dockerfile.ci-image-base new file mode 100644 index 00000000..d381de58 --- /dev/null +++ b/utils/docker/Dockerfile.ci-image-base @@ -0,0 +1,23 @@ +FROM ubuntu:20.04 + +MAINTAINER hydai hydai@secondstate.io +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt update && apt upgrade -y \ + && apt install -y \ + apt-transport-https \ + ca-certificates \ + curl \ + gnupg-agent \ + software-properties-common + +RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add - +RUN add-apt-repository \ + "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) \ + stable" + +RUN apt update && apt install -y \ + docker-ce \ + docker-ce-cli \ + containerd.io From 8bdf8305ae423656807498c07c9b2fa9eb46a006 Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 3 Feb 2021 09:26:48 +0000 Subject: [PATCH 012/623] [CI] Ignore exit code 100 and move rpm-build dependencies to the dockerfile --- utils/docker/Dockerfile.manylinux1_x86_64 | 2 +- utils/docker/Dockerfile.manylinux2010_x86_64 | 2 +- utils/docker/Dockerfile.manylinux2014_aarch64 | 2 +- utils/docker/Dockerfile.manylinux2014_x86_64 | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 index 064ea304..3fd86ec6 100644 --- a/utils/docker/Dockerfile.manylinux1_x86_64 +++ b/utils/docker/Dockerfile.manylinux1_x86_64 @@ -5,7 +5,7 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM llvm-glibc-2.5.patch cmake-glibc-2.5.patch build-manylinux.sh /root -RUN cd && yum check-update && yum install -y xz openssl-devel && \ +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 index 294e5cdc..f9f134d1 100644 --- a/utils/docker/Dockerfile.manylinux2010_x86_64 +++ b/utils/docker/Dockerfile.manylinux2010_x86_64 @@ -5,7 +5,7 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM build-manylinux.sh /root -RUN cd && yum check-update && yum install -y xz openssl-devel && \ +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index 92524b0a..7bbd8d7e 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -5,7 +5,7 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM build-manylinux.sh /root -RUN cd && yum check-update && yum install -y xz openssl-devel && \ +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index 18f4daba..a6f2b411 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -5,7 +5,7 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM build-manylinux.sh /root -RUN cd && yum check-update && yum install -y xz openssl-devel && \ +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ From 97c183c2d89d0c13f8d68e7f40e17fb82d431930 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Mon, 3 May 2021 10:06:56 +0000 Subject: [PATCH 013/623] [Misc] Rename SSVM to WasmEdge --- utils/docker/Dockerfile.build-clang | 2 +- utils/docker/Dockerfile.build-gcc | 2 +- utils/docker/build-manylinux.sh | 12 ++++++------ utils/docker/build.sh | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/utils/docker/Dockerfile.build-clang b/utils/docker/Dockerfile.build-clang index 23355d88..06fed0dc 100644 --- a/utils/docker/Dockerfile.build-clang +++ b/utils/docker/Dockerfile.build-clang @@ -1,4 +1,4 @@ -ARG BASE=secondstate/ssvm:ubuntu-base +ARG BASE=secondstate/wasmedge:ubuntu-base FROM ${BASE} RUN apt update && apt install -y \ diff --git a/utils/docker/Dockerfile.build-gcc b/utils/docker/Dockerfile.build-gcc index 014594ed..5ba1028b 100644 --- a/utils/docker/Dockerfile.build-gcc +++ b/utils/docker/Dockerfile.build-gcc @@ -1,4 +1,4 @@ -ARG BASE=secondstate/ssvm:ubuntu-base +ARG BASE=secondstate/wasmedge:ubuntu-base FROM ${BASE} RUN apt update && apt install -y \ diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index 259e0711..eba0bd5d 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -9,11 +9,11 @@ cd curl -s -L -O --remote-name-all https://dl.bintray.com/boostorg/release/1.75.0/source/boost_1_75_0.tar.bz2 echo 953db31e016db7bb207f11432bef7df100516eeb746843fa0486a222e3fd49cb boost_1_75_0.tar.bz2 | sha256sum -c bzip2 -dc boost_1_75_0.tar.bz2 | tar -xf - -cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DBUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM" -DBOOST_INCLUDEDIR=$(pwd)/boost_1_75_0/ /ssvm +cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DBUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM" -DBOOST_INCLUDEDIR=$(pwd)/boost_1_75_0/ /wasmedge cmake --build build cmake --build build --target package -cp -v build/SSVM-*.tar.gz /ssvm/SSVM.tar.gz -cp -v build/SSVM-*.tar.bz2 /ssvm/SSVM.tar.bz2 -cp -v build/SSVM-*.tar.xz /ssvm/SSVM.tar.xz -cp -v build/SSVM-*.tar.zst /ssvm/SSVM.tar.zst -cp -v build/SSVM-*.rpm /ssvm/SSVM.rpm +cp -v build/WasmEdge-*.tar.gz /wasmedge/WasmEdge.tar.gz +cp -v build/WasmEdge-*.tar.bz2 /wasmedge/WasmEdge.tar.bz2 +cp -v build/WasmEdge-*.tar.xz /wasmedge/WasmEdge.tar.xz +cp -v build/WasmEdge-*.tar.zst /wasmedge/WasmEdge.tar.zst +cp -v build/WasmEdge-*.rpm /wasmedge/WasmEdge.rpm diff --git a/utils/docker/build.sh b/utils/docker/build.sh index 3102865b..eb3d58f2 100755 --- a/utils/docker/build.sh +++ b/utils/docker/build.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -NAME=${1:+$1/}ssvm +NAME=${1:+$1/}wasmedge INTERMEDIATES=() IMAGES=() From 0189f1f4acee6f4207fa139c19ac6c1b5b37cfe4 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 4 May 2021 09:17:10 +0000 Subject: [PATCH 014/623] [CI] Update CI and docker image links for WasmEdge (formerly SSVM) --- utils/docker/Dockerfile.build-clang | 2 +- utils/docker/Dockerfile.build-gcc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/docker/Dockerfile.build-clang b/utils/docker/Dockerfile.build-clang index 06fed0dc..d510c531 100644 --- a/utils/docker/Dockerfile.build-clang +++ b/utils/docker/Dockerfile.build-clang @@ -1,4 +1,4 @@ -ARG BASE=secondstate/wasmedge:ubuntu-base +ARG BASE=wasmedge/wasmedge:ubuntu-base FROM ${BASE} RUN apt update && apt install -y \ diff --git a/utils/docker/Dockerfile.build-gcc b/utils/docker/Dockerfile.build-gcc index 5ba1028b..06c7f8e2 100644 --- a/utils/docker/Dockerfile.build-gcc +++ b/utils/docker/Dockerfile.build-gcc @@ -1,4 +1,4 @@ -ARG BASE=secondstate/wasmedge:ubuntu-base +ARG BASE=wasmedge/wasmedge:ubuntu-base FROM ${BASE} RUN apt update && apt install -y \ From 08026ea04d12082f6c7c5dbe5d4eab01bd8916e5 Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 13 May 2021 09:20:07 +0000 Subject: [PATCH 015/623] [Misc] Update manylinux* dockerfiles --- utils/docker/Dockerfile.manylinux1_x86_64 | 2 +- utils/docker/Dockerfile.manylinux2010_x86_64 | 2 +- utils/docker/Dockerfile.manylinux2014_aarch64 | 2 +- utils/docker/Dockerfile.manylinux2014_x86_64 | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 index 3fd86ec6..ad48f868 100644 --- a/utils/docker/Dockerfile.manylinux1_x86_64 +++ b/utils/docker/Dockerfile.manylinux1_x86_64 @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux1_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM llvm-glibc-2.5.patch cmake-glibc-2.5.patch build-manylinux.sh /root +ADD SHA256SUM llvm-glibc-2.5.patch cmake-glibc-2.5.patch build-manylinux.sh /root/ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 index f9f134d1..a1ace4c3 100644 --- a/utils/docker/Dockerfile.manylinux2010_x86_64 +++ b/utils/docker/Dockerfile.manylinux2010_x86_64 @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux2010_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM build-manylinux.sh /root +ADD SHA256SUM build-manylinux.sh /root/ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index 7bbd8d7e..46b350f2 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux2014_aarch64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM build-manylinux.sh /root +ADD SHA256SUM build-manylinux.sh /root/ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index a6f2b411..32694e54 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux2014_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM build-manylinux.sh /root +ADD SHA256SUM build-manylinux.sh /root/ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ From eadf87ba25e24095e177afbc1ab58e3fa13d6374 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Fri, 14 May 2021 04:41:12 +0000 Subject: [PATCH 016/623] [Misc] Update manylinux* dockerfiles for new gcc and llvm * Fix compile error on gcc-4.8.8 for gcc-11.1.0 * Fix compile error on llvm --- utils/docker/Dockerfile.manylinux1_x86_64 | 48 ++++++++++--------- utils/docker/Dockerfile.manylinux2010_x86_64 | 43 +++++++++-------- utils/docker/Dockerfile.manylinux2014_aarch64 | 43 +++++++++-------- utils/docker/Dockerfile.manylinux2014_x86_64 | 43 +++++++++-------- utils/docker/SHA256SUM | 12 ++--- utils/docker/gcc-4.8.2-gcc-11.patch | 38 +++++++++++++++ utils/docker/llvm-fix-missing-include.patch | 10 ++++ utils/docker/llvm-glibc-2.5.patch | 2 +- 8 files changed, 146 insertions(+), 93 deletions(-) create mode 100644 utils/docker/gcc-4.8.2-gcc-11.patch create mode 100644 utils/docker/llvm-fix-missing-include.patch diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 index ad48f868..73cd463e 100644 --- a/utils/docker/Dockerfile.manylinux1_x86_64 +++ b/utils/docker/Dockerfile.manylinux1_x86_64 @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux1_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM llvm-glibc-2.5.patch cmake-glibc-2.5.patch build-manylinux.sh /root/ +ADD SHA256SUM gcc-4.8.2-gcc-11.patch llvm-fix-missing-include.patch llvm-glibc-2.5.patch cmake-glibc-2.5.patch build-manylinux.sh /root/ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ @@ -18,33 +18,35 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ - http://isl.gforge.inria.fr/isl-0.23.tar.xz \ - https://github.com/facebook/zstd/releases/download/v1.4.7/zstd-1.4.7.tar.gz \ - https://ftp.gnu.org/gnu/gcc/gcc-10.2.0/gcc-10.2.0.tar.xz \ - https://github.com/Kitware/CMake/releases/download/v3.19.2/cmake-3.19.2.tar.gz \ + http://isl.gforge.inria.fr/isl-0.24.tar.xz \ + https://github.com/facebook/zstd/releases/download/v1.4.9/zstd-1.4.9.tar.gz \ + https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ + https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/lld-11.0.0.src.tar.xz && \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/llvm-11.1.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/lld-11.1.0.src.tar.xz && \ sha256sum -c SHA256SUM && \ xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ - xz -dc isl-0.23.tar.xz | tar -xf - && \ - gzip -dc zstd-1.4.7.tar.gz | tar -xf - && \ - xz -dc gcc-10.2.0.tar.xz | tar -xf - && \ - gzip -dc cmake-3.19.2.tar.gz | tar -xf - && \ + xz -dc isl-0.24.tar.xz | tar -xf - && \ + gzip -dc zstd-1.4.9.tar.gz | tar -xf - && \ + xz -dc gcc-11.1.0.tar.xz | tar -xf - && \ + gzip -dc cmake-3.20.2.tar.gz | tar -xf - && \ gzip -dc v1.10.2.tar.gz | tar -xf - && \ - xz -dc llvm-11.0.0.src.tar.xz | tar -xf - && \ - xz -dc lld-11.0.0.src.tar.xz | tar -xf - && \ - cd llvm-11.0.0.src && patch -p1 < ../llvm-glibc-2.5.patch && cd - && \ - cd cmake-3.19.2 && patch -p1 < ../cmake-glibc-2.5.patch && cd - && \ + xz -dc llvm-11.1.0.src.tar.xz | tar -xf - && \ + xz -dc lld-11.1.0.src.tar.xz | tar -xf - && \ + cd gcc-11.1.0 && patch -p1 < ../gcc-4.8.2-gcc-11.patch && cd - && \ + cd llvm-11.1.0.src && patch -p1 < ../llvm-fix-missing-include.patch && cd - && \ + cd llvm-11.1.0.src && patch -p1 < ../llvm-glibc-2.5.patch && cd - && \ + cd cmake-3.20.2 && patch -p1 < ../cmake-glibc-2.5.patch && cd - && \ mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../isl-0.23/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ - cd zstd-1.4.7 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ - mkdir build && cd build && ../gcc-10.2.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ + cd zstd-1.4.9 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ + mkdir build && cd build && ../gcc-11.1.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ --disable-libmpx --disable-libsanitizer --disable-libunwind-exceptions \ @@ -55,23 +57,23 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil --with-gcc-major-version-only --with-linker-hash-style="gnu" \ --with-tune="generic" && \ make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v /toolchain/lib64/libstdc++.so.6.0.28 /toolchain/lib64/libstdc++.so.6.0.28.backup && \ + mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.8 libstdc++.a )" \ - > /toolchain/lib64/libstdc++.so.6.0.28 && \ + > /toolchain/lib64/libstdc++.so.6.0.29 && \ export PATH="/toolchain/bin:$PATH" && \ mkdir build && cd build && /opt/python/cp39-cp39/bin/python \ ../ninja-1.10.2/configure.py --bootstrap \ --with-python=/opt/python/cp39-cp39/bin/python && \ cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.19.2/configure --prefix=/toolchain \ + mkdir build && cd build && ../cmake-3.20.2/configure --prefix=/toolchain \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v lld-11.0.0.src lld && \ + mv -v lld-11.1.0.src lld && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/toolchain \ -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ - llvm-11.0.0.src && cmake --build build --target install && rm -rf build && \ + llvm-11.1.0.src && cmake --build build --target install && rm -rf build && \ mv build-manylinux.sh /toolchain/bin/ && \ rm -rf * diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 index a1ace4c3..4c6310b6 100644 --- a/utils/docker/Dockerfile.manylinux2010_x86_64 +++ b/utils/docker/Dockerfile.manylinux2010_x86_64 @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux2010_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM build-manylinux.sh /root/ +ADD SHA256SUM llvm-fix-missing-include.patch build-manylinux.sh /root/ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ @@ -18,31 +18,32 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ - http://isl.gforge.inria.fr/isl-0.23.tar.xz \ - https://github.com/facebook/zstd/releases/download/v1.4.7/zstd-1.4.7.tar.gz \ - https://ftp.gnu.org/gnu/gcc/gcc-10.2.0/gcc-10.2.0.tar.xz \ - https://github.com/Kitware/CMake/releases/download/v3.19.2/cmake-3.19.2.tar.gz \ + http://isl.gforge.inria.fr/isl-0.24.tar.xz \ + https://github.com/facebook/zstd/releases/download/v1.4.9/zstd-1.4.9.tar.gz \ + https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ + https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/lld-11.0.0.src.tar.xz && \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/llvm-11.1.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/lld-11.1.0.src.tar.xz && \ sha256sum -c SHA256SUM && \ xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ - xz -dc isl-0.23.tar.xz | tar -xf - && \ - gzip -dc zstd-1.4.7.tar.gz | tar -xf - && \ - xz -dc gcc-10.2.0.tar.xz | tar -xf - && \ - gzip -dc cmake-3.19.2.tar.gz | tar -xf - && \ + xz -dc isl-0.24.tar.xz | tar -xf - && \ + gzip -dc zstd-1.4.9.tar.gz | tar -xf - && \ + xz -dc gcc-11.1.0.tar.xz | tar -xf - && \ + gzip -dc cmake-3.20.2.tar.gz | tar -xf - && \ gzip -dc v1.10.2.tar.gz | tar -xf - && \ - xz -dc llvm-11.0.0.src.tar.xz | tar -xf - && \ - xz -dc lld-11.0.0.src.tar.xz | tar -xf - && \ + xz -dc llvm-11.1.0.src.tar.xz | tar -xf - && \ + xz -dc lld-11.1.0.src.tar.xz | tar -xf - && \ + cd llvm-11.1.0.src && patch -p1 < ../llvm-fix-missing-include.patch && cd - && \ mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../isl-0.23/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ - cd zstd-1.4.7 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ - mkdir build && cd build && ../gcc-10.2.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ + cd zstd-1.4.9 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ + mkdir build && cd build && ../gcc-11.1.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ --disable-libmpx --disable-libsanitizer --disable-libunwind-exceptions \ @@ -53,23 +54,23 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil --with-gcc-major-version-only --with-linker-hash-style="gnu" \ --with-tune="generic" && \ make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v /toolchain/lib64/libstdc++.so.6.0.28 /toolchain/lib64/libstdc++.so.6.0.28.backup && \ + mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.13 libstdc++.a )" \ - > /toolchain/lib64/libstdc++.so.6.0.28 && \ + > /toolchain/lib64/libstdc++.so.6.0.29 && \ export PATH="/toolchain/bin:$PATH" && \ mkdir build && cd build && /opt/python/cp39-cp39/bin/python \ ../ninja-1.10.2/configure.py --bootstrap \ --with-python=/opt/python/cp39-cp39/bin/python && \ cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.19.2/configure --prefix=/toolchain \ + mkdir build && cd build && ../cmake-3.20.2/configure --prefix=/toolchain \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v lld-11.0.0.src lld && \ + mv -v lld-11.1.0.src lld && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/toolchain \ -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ - llvm-11.0.0.src && cmake --build build --target install && rm -rf build && \ + llvm-11.1.0.src && cmake --build build --target install && rm -rf build && \ mv build-manylinux.sh /toolchain/bin/ && \ rm -rf * diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index 46b350f2..e549d64e 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux2014_aarch64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM build-manylinux.sh /root/ +ADD SHA256SUM llvm-fix-missing-include.patch build-manylinux.sh /root/ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ @@ -18,31 +18,32 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ - http://isl.gforge.inria.fr/isl-0.23.tar.xz \ - https://github.com/facebook/zstd/releases/download/v1.4.7/zstd-1.4.7.tar.gz \ - https://ftp.gnu.org/gnu/gcc/gcc-10.2.0/gcc-10.2.0.tar.xz \ - https://github.com/Kitware/CMake/releases/download/v3.19.2/cmake-3.19.2.tar.gz \ + http://isl.gforge.inria.fr/isl-0.24.tar.xz \ + https://github.com/facebook/zstd/releases/download/v1.4.9/zstd-1.4.9.tar.gz \ + https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ + https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/lld-11.0.0.src.tar.xz && \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/llvm-11.1.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/lld-11.1.0.src.tar.xz && \ sha256sum -c SHA256SUM && \ xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ - xz -dc isl-0.23.tar.xz | tar -xf - && \ - gzip -dc zstd-1.4.7.tar.gz | tar -xf - && \ - xz -dc gcc-10.2.0.tar.xz | tar -xf - && \ - gzip -dc cmake-3.19.2.tar.gz | tar -xf - && \ + xz -dc isl-0.24.tar.xz | tar -xf - && \ + gzip -dc zstd-1.4.9.tar.gz | tar -xf - && \ + xz -dc gcc-11.1.0.tar.xz | tar -xf - && \ + gzip -dc cmake-3.20.2.tar.gz | tar -xf - && \ gzip -dc v1.10.2.tar.gz | tar -xf - && \ - xz -dc llvm-11.0.0.src.tar.xz | tar -xf - && \ - xz -dc lld-11.0.0.src.tar.xz | tar -xf - && \ + xz -dc llvm-11.1.0.src.tar.xz | tar -xf - && \ + xz -dc lld-11.1.0.src.tar.xz | tar -xf - && \ + cd llvm-11.1.0.src && patch -p1 < ../llvm-fix-missing-include.patch && cd - && \ mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../isl-0.23/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ - cd zstd-1.4.7 && make -s $ZSTDFLAGS -j 1 && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ - mkdir build && cd build && ../gcc-10.2.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ + cd zstd-1.4.9 && make -s $ZSTDFLAGS -j 1 && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ + mkdir build && cd build && ../gcc-11.1.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ --disable-libmpx --disable-libsanitizer --disable-libunwind-exceptions \ @@ -52,23 +53,23 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil --with-default-libstdcxx-abi="gcc4-compatible" \ --with-gcc-major-version-only --with-linker-hash-style="gnu" && \ make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v /toolchain/lib64/libstdc++.so.6.0.28 /toolchain/lib64/libstdc++.so.6.0.28.backup && \ + mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ echo -e "OUTPUT_FORMAT(elf64-aarch64)\nINPUT ( libstdc++.so.6.0.19 libstdc++.a )" \ - > /toolchain/lib64/libstdc++.so.6.0.28 && \ + > /toolchain/lib64/libstdc++.so.6.0.29 && \ export PATH="/toolchain/bin:$PATH" && \ mkdir build && cd build && /opt/python/cp39-cp39/bin/python \ ../ninja-1.10.2/configure.py --bootstrap \ --with-python=/opt/python/cp39-cp39/bin/python && \ cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.19.2/configure --prefix=/toolchain \ + mkdir build && cd build && ../cmake-3.20.2/configure --prefix=/toolchain \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v lld-11.0.0.src lld && \ + mv -v lld-11.1.0.src lld && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/toolchain \ -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ -DLLVM_TARGETS_TO_BUILD="AArch64" -DLLVM_ENABLE_PROJECTS=lld \ -DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-redhat-linux-gnu" \ - llvm-11.0.0.src && cmake --build build --target install && rm -rf build && \ + llvm-11.1.0.src && cmake --build build --target install && rm -rf build && \ mv build-manylinux.sh /toolchain/bin/ && \ rm -rf * diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index 32694e54..cfbe5331 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux2014_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM build-manylinux.sh /root/ +ADD SHA256SUM llvm-fix-missing-include.patch build-manylinux.sh /root/ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ @@ -18,31 +18,32 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ - http://isl.gforge.inria.fr/isl-0.23.tar.xz \ - https://github.com/facebook/zstd/releases/download/v1.4.7/zstd-1.4.7.tar.gz \ - https://ftp.gnu.org/gnu/gcc/gcc-10.2.0/gcc-10.2.0.tar.xz \ - https://github.com/Kitware/CMake/releases/download/v3.19.2/cmake-3.19.2.tar.gz \ + http://isl.gforge.inria.fr/isl-0.24.tar.xz \ + https://github.com/facebook/zstd/releases/download/v1.4.9/zstd-1.4.9.tar.gz \ + https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ + https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/llvm-11.0.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.0/lld-11.0.0.src.tar.xz && \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/llvm-11.1.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/lld-11.1.0.src.tar.xz && \ sha256sum -c SHA256SUM && \ xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ - xz -dc isl-0.23.tar.xz | tar -xf - && \ - gzip -dc zstd-1.4.7.tar.gz | tar -xf - && \ - xz -dc gcc-10.2.0.tar.xz | tar -xf - && \ - gzip -dc cmake-3.19.2.tar.gz | tar -xf - && \ + xz -dc isl-0.24.tar.xz | tar -xf - && \ + gzip -dc zstd-1.4.9.tar.gz | tar -xf - && \ + xz -dc gcc-11.1.0.tar.xz | tar -xf - && \ + gzip -dc cmake-3.20.2.tar.gz | tar -xf - && \ gzip -dc v1.10.2.tar.gz | tar -xf - && \ - xz -dc llvm-11.0.0.src.tar.xz | tar -xf - && \ - xz -dc lld-11.0.0.src.tar.xz | tar -xf - && \ + xz -dc llvm-11.1.0.src.tar.xz | tar -xf - && \ + xz -dc lld-11.1.0.src.tar.xz | tar -xf - && \ + cd llvm-11.1.0.src && patch -p1 < ../llvm-fix-missing-include.patch && cd - && \ mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../isl-0.23/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ - cd zstd-1.4.7 && make -s $ZSTDFLAGS -j 1 && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ - mkdir build && cd build && ../gcc-10.2.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ + cd zstd-1.4.9 && make -s $ZSTDFLAGS -j 1 && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ + mkdir build && cd build && ../gcc-11.1.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ --disable-libmpx --disable-libsanitizer --disable-libunwind-exceptions \ @@ -53,23 +54,23 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil --with-gcc-major-version-only --with-linker-hash-style="gnu" \ --with-tune="generic" && \ make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v /toolchain/lib64/libstdc++.so.6.0.28 /toolchain/lib64/libstdc++.so.6.0.28.backup && \ + mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.19 libstdc++.a )" \ - > /toolchain/lib64/libstdc++.so.6.0.28 && \ + > /toolchain/lib64/libstdc++.so.6.0.29 && \ export PATH="/toolchain/bin:$PATH" && \ mkdir build && cd build && /opt/python/cp39-cp39/bin/python \ ../ninja-1.10.2/configure.py --bootstrap \ --with-python=/opt/python/cp39-cp39/bin/python && \ cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.19.2/configure --prefix=/toolchain \ + mkdir build && cd build && ../cmake-3.20.2/configure --prefix=/toolchain \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v lld-11.0.0.src lld && \ + mv -v lld-11.1.0.src lld && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/toolchain \ -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ - llvm-11.0.0.src && cmake --build build --target install && rm -rf build && \ + llvm-11.1.0.src && cmake --build build --target install && rm -rf build && \ mv build-manylinux.sh /toolchain/bin/ && \ rm -rf * diff --git a/utils/docker/SHA256SUM b/utils/docker/SHA256SUM index 398f25a4..5dc84484 100644 --- a/utils/docker/SHA256SUM +++ b/utils/docker/SHA256SUM @@ -1,10 +1,10 @@ -e3e0fd3b23b7fb13e1a856581078e0776ffa2df4e9d3164039c36d3315e0c7f0 cmake-3.19.2.tar.gz -b8dd4368bb9c7f0b98188317ee0254dd8cc99d1e3a18d0ff146c855fe16c1d8c gcc-10.2.0.tar.xz +aecf6ecb975179eb3bb6a4a50cae192d41e92b9372b02300f9e8f1d5f559544e cmake-3.20.2.tar.gz +4c4a6fb8a8396059241c2e674b85b351c26a5d678274007f076957afa1cc9ddf gcc-11.1.0.tar.xz fd4829912cddd12f84181c3451cc752be224643e87fac497b69edddadc49b4f2 gmp-6.2.1.tar.xz -5efc53efaef151301f4e7dde3856b66812d8153dede24fab17673f801c8698f2 isl-0.23.tar.xz -efe7be4a7b7cdc6f3bcf222827c6f837439e6e656d12d6c885d5c8a80ff4fd1c lld-11.0.0.src.tar.xz -913f68c898dfb4a03b397c5e11c6a2f39d0f22ed7665c9cefa87a34423a72469 llvm-11.0.0.src.tar.xz +043105cc544f416b48736fff8caf077fb0663a717d06b1113f16e391ac99ebad isl-0.24.tar.xz +017a788cbe1ecc4a949abf10755870519086d058a2e99f438829aef24f0c66ce lld-11.1.0.src.tar.xz +ce8508e318a01a63d4e8b3090ab2ded3c598a50258cc49e2625b9120d4c03ea5 llvm-11.1.0.src.tar.xz 17503d2c395dfcf106b622dc142683c1199431d095367c6aacba6eec30340459 mpc-1.2.1.tar.gz 0c98a3f1732ff6ca4ea690552079da9c597872d30e96ec28414ee23c95558a7f mpfr-4.1.0.tar.xz ce35865411f0490368a8fc383f29071de6690cbadc27704734978221f25e2bed v1.10.2.tar.gz -192cbb1274a9672cbcceaf47b5c4e9e59691ca60a357f1d4a8b2dfa2c365d757 zstd-1.4.7.tar.gz +29ac74e19ea28659017361976240c4b5c5c24db3b89338731a6feb97c038d293 zstd-1.4.9.tar.gz diff --git a/utils/docker/gcc-4.8.2-gcc-11.patch b/utils/docker/gcc-4.8.2-gcc-11.patch new file mode 100644 index 00000000..00e7cbdb --- /dev/null +++ b/utils/docker/gcc-4.8.2-gcc-11.patch @@ -0,0 +1,38 @@ +--- a/gcc/splay-tree-utils.h 2021-05-14 06:09:43.289274290 +0000 ++++ b/gcc/splay-tree-utils.h 2021-05-14 06:24:13.628159368 +0000 +@@ -105,7 +105,11 @@ template + class base_splay_tree : protected Accessors + { + public: ++#if __GNUC__ > 4 + using typename Accessors::node_type; ++#else ++ using node_type = typename Accessors::node_type; ++#endif + + // INDEX is either 0 or 1. If INDEX is 0, insert CHILD immediately + // before NODE, otherwise insert CHILD immediately after NODE. +@@ -148,7 +152,11 @@ class rooted_splay_tree : public base_sp + using parent = base_splay_tree; + + public: ++#if __GNUC__ > 4 + using typename Accessors::node_type; ++#else ++ using node_type = typename Accessors::node_type; ++#endif + + protected: + // The root of the splay tree, or node_type () if the tree is empty. +@@ -409,7 +417,11 @@ class rootless_splay_tree + public: + using rooted = rooted_splay_tree; + ++#if __GNUC__ > 4 + using typename Accessors::node_type; ++#else ++ using node_type = typename Accessors::node_type; ++#endif + + // Remove NODE from the splay tree. Return the node that replaces it, + // or null if NODE had no children. diff --git a/utils/docker/llvm-fix-missing-include.patch b/utils/docker/llvm-fix-missing-include.patch new file mode 100644 index 00000000..20411a3f --- /dev/null +++ b/utils/docker/llvm-fix-missing-include.patch @@ -0,0 +1,10 @@ +--- a/utils/benchmark/src/benchmark_register.h ++++ b/utils/benchmark/src/benchmark_register.h +@@ -1,6 +1,7 @@ + #ifndef BENCHMARK_REGISTER_H + #define BENCHMARK_REGISTER_H + ++#include + #include + + #include "check.h" diff --git a/utils/docker/llvm-glibc-2.5.patch b/utils/docker/llvm-glibc-2.5.patch index e2afabb4..54cac8ea 100644 --- a/utils/docker/llvm-glibc-2.5.patch +++ b/utils/docker/llvm-glibc-2.5.patch @@ -1,6 +1,6 @@ --- a/lib/Support/Host.cpp 2020-12-17 20:09:25.321395012 +0000 +++ b/lib/Support/Host.cpp 2020-12-17 20:29:40.296551916 +0000 -@@ -1224,6 +1224,15 @@ StringRef sys::getHostCPUName() { return +@@ -1225,6 +1225,15 @@ StringRef sys::getHostCPUName() { return #endif #if defined(__linux__) && (defined(__i386__) || defined(__x86_64__)) From 5267b324c22482ff62b981966dc741308bc86111 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Mon, 17 May 2021 08:08:33 +0000 Subject: [PATCH 017/623] [Misc] Update boost to 1.76 --- utils/docker/build-manylinux.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index eba0bd5d..046ddbbe 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -6,10 +6,10 @@ export CC=gcc export CXX=g++ cd -curl -s -L -O --remote-name-all https://dl.bintray.com/boostorg/release/1.75.0/source/boost_1_75_0.tar.bz2 -echo 953db31e016db7bb207f11432bef7df100516eeb746843fa0486a222e3fd49cb boost_1_75_0.tar.bz2 | sha256sum -c -bzip2 -dc boost_1_75_0.tar.bz2 | tar -xf - -cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DBUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM" -DBOOST_INCLUDEDIR=$(pwd)/boost_1_75_0/ /wasmedge +curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.bz2 +echo f0397ba6e982c4450f27bf32a2a83292aba035b827a5623a14636ea583318c41 boost_1_76_0.tar.bz2 | sha256sum -c +bzip2 -dc boost_1_76_0.tar.bz2 | tar -xf - +cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DBUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM" -DBOOST_INCLUDEDIR=$(pwd)/boost_1_76_0/ /wasmedge cmake --build build cmake --build build --target package cp -v build/WasmEdge-*.tar.gz /wasmedge/WasmEdge.tar.gz From 480eafa257bf9ecde076f188d1d678ee64714fec Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Mon, 17 May 2021 08:12:12 +0000 Subject: [PATCH 018/623] [Misc] Move environment variable to Dockerfile --- utils/docker/Dockerfile.manylinux1_x86_64 | 3 +++ utils/docker/Dockerfile.manylinux2010_x86_64 | 3 +++ utils/docker/Dockerfile.manylinux2014_aarch64 | 3 +++ utils/docker/Dockerfile.manylinux2014_x86_64 | 3 +++ utils/docker/build-manylinux.sh | 4 ---- 5 files changed, 12 insertions(+), 4 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 index 73cd463e..99ed5e9b 100644 --- a/utils/docker/Dockerfile.manylinux1_x86_64 +++ b/utils/docker/Dockerfile.manylinux1_x86_64 @@ -78,3 +78,6 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil rm -rf * RUN yum clean all +ENV PATH /toolchain/bin:$PATH +ENV CC gcc +ENV CXX g++ diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 index 4c6310b6..30361906 100644 --- a/utils/docker/Dockerfile.manylinux2010_x86_64 +++ b/utils/docker/Dockerfile.manylinux2010_x86_64 @@ -75,3 +75,6 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil rm -rf * RUN yum clean all +ENV PATH /toolchain/bin:$PATH +ENV CC gcc +ENV CXX g++ diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index e549d64e..fc9e1ce4 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -74,3 +74,6 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil rm -rf * RUN yum clean all +ENV PATH /toolchain/bin:$PATH +ENV CC gcc +ENV CXX g++ diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index cfbe5331..72a4bfee 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -75,3 +75,6 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil rm -rf * RUN yum clean all +ENV PATH /toolchain/bin:$PATH +ENV CC gcc +ENV CXX g++ diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index 046ddbbe..7eb613cf 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -1,10 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -export PATH="/toolchain/bin:$PATH" -export CC=gcc -export CXX=g++ - cd curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.bz2 echo f0397ba6e982c4450f27bf32a2a83292aba035b827a5623a14636ea583318c41 boost_1_76_0.tar.bz2 | sha256sum -c From eccc0e9f92e8d262dbdf6ccf6cb69023376c61ac Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Mon, 17 May 2021 18:55:43 +0000 Subject: [PATCH 019/623] [Misc] Use helper script to build --- utils/docker/Dockerfile.manylinux1_x86_64 | 3 +-- utils/docker/Dockerfile.manylinux2010_x86_64 | 3 +-- utils/docker/Dockerfile.manylinux2014_aarch64 | 3 +-- utils/docker/Dockerfile.manylinux2014_x86_64 | 3 +-- utils/docker/build-manylinux.sh | 19 ++++++++++++------- 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 index 99ed5e9b..5e236965 100644 --- a/utils/docker/Dockerfile.manylinux1_x86_64 +++ b/utils/docker/Dockerfile.manylinux1_x86_64 @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux1_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM gcc-4.8.2-gcc-11.patch llvm-fix-missing-include.patch llvm-glibc-2.5.patch cmake-glibc-2.5.patch build-manylinux.sh /root/ +ADD SHA256SUM gcc-4.8.2-gcc-11.patch llvm-fix-missing-include.patch llvm-glibc-2.5.patch cmake-glibc-2.5.patch /root/ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ @@ -74,7 +74,6 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ llvm-11.1.0.src && cmake --build build --target install && rm -rf build && \ - mv build-manylinux.sh /toolchain/bin/ && \ rm -rf * RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 index 30361906..8100b613 100644 --- a/utils/docker/Dockerfile.manylinux2010_x86_64 +++ b/utils/docker/Dockerfile.manylinux2010_x86_64 @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux2010_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM llvm-fix-missing-include.patch build-manylinux.sh /root/ +ADD SHA256SUM llvm-fix-missing-include.patch /root/ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ @@ -71,7 +71,6 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ llvm-11.1.0.src && cmake --build build --target install && rm -rf build && \ - mv build-manylinux.sh /toolchain/bin/ && \ rm -rf * RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index fc9e1ce4..870a36ee 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux2014_aarch64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM llvm-fix-missing-include.patch build-manylinux.sh /root/ +ADD SHA256SUM llvm-fix-missing-include.patch /root/ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ @@ -70,7 +70,6 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil -DLLVM_TARGETS_TO_BUILD="AArch64" -DLLVM_ENABLE_PROJECTS=lld \ -DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-redhat-linux-gnu" \ llvm-11.1.0.src && cmake --build build --target install && rm -rf build && \ - mv build-manylinux.sh /toolchain/bin/ && \ rm -rf * RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index 72a4bfee..ebb16c78 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -3,7 +3,7 @@ FROM quay.io/pypa/manylinux2014_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM llvm-fix-missing-include.patch build-manylinux.sh /root/ +ADD SHA256SUM llvm-fix-missing-include.patch /root/ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ @@ -71,7 +71,6 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ llvm-11.1.0.src && cmake --build build --target install && rm -rf build && \ - mv build-manylinux.sh /toolchain/bin/ && \ rm -rf * RUN yum clean all diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index 7eb613cf..449ddea9 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -1,15 +1,20 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -cd +export PATH="/toolchain/bin:$PATH" +export CC=gcc +export CXX=g++ +export CPPFLAGS=-I/toolchain/include +export LDFLAGS=-L/toolchain/lib64 curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.bz2 echo f0397ba6e982c4450f27bf32a2a83292aba035b827a5623a14636ea583318c41 boost_1_76_0.tar.bz2 | sha256sum -c bzip2 -dc boost_1_76_0.tar.bz2 | tar -xf - -cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DBUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM" -DBOOST_INCLUDEDIR=$(pwd)/boost_1_76_0/ /wasmedge +if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DBUILD_TOOL_WASMEDGE_STATIC=OFF -DBUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_76_0/; then + echo === CMakeOutput.log === + cat build/CMakeFiles/CMakeOutput.log + echo === CMakeError.log === + cat build/CMakeFiles/CMakeError.log + exit 1 +fi cmake --build build cmake --build build --target package -cp -v build/WasmEdge-*.tar.gz /wasmedge/WasmEdge.tar.gz -cp -v build/WasmEdge-*.tar.bz2 /wasmedge/WasmEdge.tar.bz2 -cp -v build/WasmEdge-*.tar.xz /wasmedge/WasmEdge.tar.xz -cp -v build/WasmEdge-*.tar.zst /wasmedge/WasmEdge.tar.zst -cp -v build/WasmEdge-*.rpm /wasmedge/WasmEdge.rpm From eeba286fd44e110dbd6fda621a287c6d2ed2f386 Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 24 May 2021 08:12:10 +0000 Subject: [PATCH 020/623] [Deps] Bump LLVM to 12.0.0 Signed-off-by: hydai --- utils/docker/Dockerfile.base | 2 +- utils/docker/Dockerfile.build-clang | 2 +- utils/docker/Dockerfile.build-gcc | 2 +- utils/docker/Dockerfile.ci-image-base | 2 +- utils/docker/build.sh | 1 + 5 files changed, 5 insertions(+), 4 deletions(-) diff --git a/utils/docker/Dockerfile.base b/utils/docker/Dockerfile.base index 41b365a3..3e28e74c 100644 --- a/utils/docker/Dockerfile.base +++ b/utils/docker/Dockerfile.base @@ -1,4 +1,4 @@ -FROM ubuntu:20.04 +FROM ubuntu:21.04 MAINTAINER hydai hydai@secondstate.io ENV DEBIAN_FRONTEND=noninteractive diff --git a/utils/docker/Dockerfile.build-clang b/utils/docker/Dockerfile.build-clang index d510c531..bfca6ce1 100644 --- a/utils/docker/Dockerfile.build-clang +++ b/utils/docker/Dockerfile.build-clang @@ -3,7 +3,7 @@ FROM ${BASE} RUN apt update && apt install -y \ llvm-dev \ - liblld-10-dev + liblld-12-dev RUN apt update && apt install -y \ clang diff --git a/utils/docker/Dockerfile.build-gcc b/utils/docker/Dockerfile.build-gcc index 06c7f8e2..38c91ff2 100644 --- a/utils/docker/Dockerfile.build-gcc +++ b/utils/docker/Dockerfile.build-gcc @@ -3,7 +3,7 @@ FROM ${BASE} RUN apt update && apt install -y \ llvm-dev \ - liblld-10-dev + liblld-12-dev RUN apt update && apt install -y \ gcc \ diff --git a/utils/docker/Dockerfile.ci-image-base b/utils/docker/Dockerfile.ci-image-base index d381de58..670f08b1 100644 --- a/utils/docker/Dockerfile.ci-image-base +++ b/utils/docker/Dockerfile.ci-image-base @@ -1,4 +1,4 @@ -FROM ubuntu:20.04 +FROM ubuntu:21.04 MAINTAINER hydai hydai@secondstate.io ENV DEBIAN_FRONTEND=noninteractive diff --git a/utils/docker/build.sh b/utils/docker/build.sh index eb3d58f2..cfc88878 100755 --- a/utils/docker/build.sh +++ b/utils/docker/build.sh @@ -24,6 +24,7 @@ function docker_build # Build all images. docker_build Dockerfile.base ubuntu-base +docker_build Dockerfile.ci-image-base ci-image-base docker_build Dockerfile.build-clang ubuntu-build-clang \ --build-arg "BASE=${NAME}:ubuntu-base" docker_build Dockerfile.build-clang latest \ From 1959ef5a44032229b407695204349aa79c96629b Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 25 May 2021 12:52:45 +0000 Subject: [PATCH 021/623] [Misc] Upgrade llvm to 12.0.0, zstd to 1.5.0 * Add generic flags to gcc * Add dpkg for packaging * Add libunwind for header dependency * Move env flags to header Signed-off-by: Shen-Ta Hsieh --- utils/docker/Dockerfile.manylinux1_x86_64 | 50 ++++++++++--------- utils/docker/Dockerfile.manylinux2010_x86_64 | 48 +++++++++--------- utils/docker/Dockerfile.manylinux2014_aarch64 | 49 +++++++++--------- utils/docker/Dockerfile.manylinux2014_x86_64 | 48 +++++++++--------- utils/docker/SHA256SUM | 7 +-- utils/docker/build-manylinux.sh | 4 +- utils/docker/llvm-fix-missing-include.patch | 10 ---- 7 files changed, 108 insertions(+), 108 deletions(-) delete mode 100644 utils/docker/llvm-fix-missing-include.patch diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 index 5e236965..cbaacf2d 100644 --- a/utils/docker/Dockerfile.manylinux1_x86_64 +++ b/utils/docker/Dockerfile.manylinux1_x86_64 @@ -3,49 +3,52 @@ FROM quay.io/pypa/manylinux1_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM gcc-4.8.2-gcc-11.patch llvm-fix-missing-include.patch llvm-glibc-2.5.patch cmake-glibc-2.5.patch /root/ +ADD SHA256SUM gcc-4.8.2-gcc-11.patch llvm-glibc-2.5.patch cmake-glibc-2.5.patch /root/ -RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ +ENV PATH /toolchain/bin:$PATH +ENV CC gcc +ENV CXX g++ +ENV CPPFLAGS -I/toolchain/include +ENV LDFLAGS -L/toolchain/lib64 +ENV PKG_CONFIG_PATH /toolchain/lib64/pkgconfig + +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build dpkg && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ - export PKG_CONFIG_PATH=/toolchain/lib64/pkgconfig && \ - export CC=gcc && \ - export CXX=g++ && \ - export CPPFLAGS=-I/toolchain/include && \ - export LDFLAGS=-L/toolchain/lib64 && \ curl -s -L -O --remote-name-all \ https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ http://isl.gforge.inria.fr/isl-0.24.tar.xz \ - https://github.com/facebook/zstd/releases/download/v1.4.9/zstd-1.4.9.tar.gz \ + https://github.com/facebook/zstd/releases/download/v1.5.0/zstd-1.5.0.tar.gz \ https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/llvm-11.1.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/lld-11.1.0.src.tar.xz && \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/llvm-12.0.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/lld-12.0.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/libunwind-12.0.0.src.tar.xz && \ sha256sum -c SHA256SUM && \ xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ xz -dc isl-0.24.tar.xz | tar -xf - && \ - gzip -dc zstd-1.4.9.tar.gz | tar -xf - && \ + gzip -dc zstd-1.5.0.tar.gz | tar -xf - && \ xz -dc gcc-11.1.0.tar.xz | tar -xf - && \ gzip -dc cmake-3.20.2.tar.gz | tar -xf - && \ gzip -dc v1.10.2.tar.gz | tar -xf - && \ - xz -dc llvm-11.1.0.src.tar.xz | tar -xf - && \ - xz -dc lld-11.1.0.src.tar.xz | tar -xf - && \ + xz -dc llvm-12.0.0.src.tar.xz | tar -xf - && \ + xz -dc lld-12.0.0.src.tar.xz | tar -xf - && \ + xz -dc libunwind-12.0.0.src.tar.xz | tar -xf - && \ cd gcc-11.1.0 && patch -p1 < ../gcc-4.8.2-gcc-11.patch && cd - && \ - cd llvm-11.1.0.src && patch -p1 < ../llvm-fix-missing-include.patch && cd - && \ - cd llvm-11.1.0.src && patch -p1 < ../llvm-glibc-2.5.patch && cd - && \ + cd llvm-12.0.0.src && patch -p1 < ../llvm-glibc-2.5.patch && cd - && \ cd cmake-3.20.2 && patch -p1 < ../cmake-glibc-2.5.patch && cd - && \ mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ - cd zstd-1.4.9 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ + cd zstd-1.5.0 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ mkdir build && cd build && ../gcc-11.1.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ @@ -55,7 +58,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ --with-default-libstdcxx-abi="gcc4-compatible" \ --with-gcc-major-version-only --with-linker-hash-style="gnu" \ - --with-tune="generic" && \ + --with-arch="x86-64" --with-tune="generic" && \ make -s -j $CPU && make -s install && cd - && rm -rf build && \ mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.8 libstdc++.a )" \ @@ -67,16 +70,15 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ mkdir build && cd build && ../cmake-3.20.2/configure --prefix=/toolchain \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v lld-11.1.0.src lld && \ + mv -v llvm-12.0.0.src llvm && \ + mv -v lld-12.0.0.src lld && \ + mv -v libunwind-12.0.0.src libunwind && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/toolchain \ -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ - -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ - llvm-11.1.0.src && cmake --build build --target install && rm -rf build && \ - rm -rf * + -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" llvm && \ + cmake --build build --target install && \ + rm -rf build && rm -rf * RUN yum clean all -ENV PATH /toolchain/bin:$PATH -ENV CC gcc -ENV CXX g++ diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 index 8100b613..f477dc84 100644 --- a/utils/docker/Dockerfile.manylinux2010_x86_64 +++ b/utils/docker/Dockerfile.manylinux2010_x86_64 @@ -3,46 +3,49 @@ FROM quay.io/pypa/manylinux2010_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM llvm-fix-missing-include.patch /root/ +ADD SHA256SUM /root/ -RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ +ENV PATH /toolchain/bin:$PATH +ENV CC gcc +ENV CXX g++ +ENV CPPFLAGS -I/toolchain/include +ENV LDFLAGS -L/toolchain/lib64 +ENV PKG_CONFIG_PATH /toolchain/lib64/pkgconfig + +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build dpkg && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ - export PKG_CONFIG_PATH=/toolchain/lib64/pkgconfig && \ - export CC=gcc && \ - export CXX=g++ && \ - export CPPFLAGS=-I/toolchain/include && \ - export LDFLAGS=-L/toolchain/lib64 && \ curl -s -L -O --remote-name-all \ https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ http://isl.gforge.inria.fr/isl-0.24.tar.xz \ - https://github.com/facebook/zstd/releases/download/v1.4.9/zstd-1.4.9.tar.gz \ + https://github.com/facebook/zstd/releases/download/v1.5.0/zstd-1.5.0.tar.gz \ https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/llvm-11.1.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/lld-11.1.0.src.tar.xz && \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/llvm-12.0.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/lld-12.0.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/libunwind-12.0.0.src.tar.xz && \ sha256sum -c SHA256SUM && \ xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ xz -dc isl-0.24.tar.xz | tar -xf - && \ - gzip -dc zstd-1.4.9.tar.gz | tar -xf - && \ + gzip -dc zstd-1.5.0.tar.gz | tar -xf - && \ xz -dc gcc-11.1.0.tar.xz | tar -xf - && \ gzip -dc cmake-3.20.2.tar.gz | tar -xf - && \ gzip -dc v1.10.2.tar.gz | tar -xf - && \ - xz -dc llvm-11.1.0.src.tar.xz | tar -xf - && \ - xz -dc lld-11.1.0.src.tar.xz | tar -xf - && \ - cd llvm-11.1.0.src && patch -p1 < ../llvm-fix-missing-include.patch && cd - && \ + xz -dc llvm-12.0.0.src.tar.xz | tar -xf - && \ + xz -dc lld-12.0.0.src.tar.xz | tar -xf - && \ + xz -dc libunwind-12.0.0.src.tar.xz | tar -xf - && \ mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ - cd zstd-1.4.9 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ + cd zstd-1.5.0 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ mkdir build && cd build && ../gcc-11.1.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ @@ -52,7 +55,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ --with-default-libstdcxx-abi="gcc4-compatible" \ --with-gcc-major-version-only --with-linker-hash-style="gnu" \ - --with-tune="generic" && \ + --with-arch="x86-64" --with-tune="generic" && \ make -s -j $CPU && make -s install && cd - && rm -rf build && \ mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.13 libstdc++.a )" \ @@ -64,16 +67,15 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ mkdir build && cd build && ../cmake-3.20.2/configure --prefix=/toolchain \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v lld-11.1.0.src lld && \ + mv -v llvm-12.0.0.src llvm && \ + mv -v lld-12.0.0.src lld && \ + mv -v libunwind-12.0.0.src libunwind && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/toolchain \ -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ - -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ - llvm-11.1.0.src && cmake --build build --target install && rm -rf build && \ - rm -rf * + -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" llvm && \ + cmake --build build --target install && \ + rm -rf build && rm -rf * RUN yum clean all -ENV PATH /toolchain/bin:$PATH -ENV CC gcc -ENV CXX g++ diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index 870a36ee..08f13585 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -3,46 +3,49 @@ FROM quay.io/pypa/manylinux2014_aarch64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM llvm-fix-missing-include.patch /root/ +ADD SHA256SUM /root/ -RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ +ENV PATH /toolchain/bin:$PATH +ENV CC gcc +ENV CXX g++ +ENV CPPFLAGS -I/toolchain/include +ENV LDFLAGS -L/toolchain/lib64 +ENV PKG_CONFIG_PATH /toolchain/lib64/pkgconfig + +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build dpkg && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ - export PKG_CONFIG_PATH=/toolchain/lib64/pkgconfig && \ - export CC=gcc && \ - export CXX=g++ && \ - export CPPFLAGS=-I/toolchain/include && \ - export LDFLAGS=-L/toolchain/lib64 && \ curl -s -L -O --remote-name-all \ https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ http://isl.gforge.inria.fr/isl-0.24.tar.xz \ - https://github.com/facebook/zstd/releases/download/v1.4.9/zstd-1.4.9.tar.gz \ + https://github.com/facebook/zstd/releases/download/v1.5.0/zstd-1.5.0.tar.gz \ https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/llvm-11.1.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/lld-11.1.0.src.tar.xz && \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/llvm-12.0.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/lld-12.0.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/libunwind-12.0.0.src.tar.xz && \ sha256sum -c SHA256SUM && \ xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ xz -dc isl-0.24.tar.xz | tar -xf - && \ - gzip -dc zstd-1.4.9.tar.gz | tar -xf - && \ + gzip -dc zstd-1.5.0.tar.gz | tar -xf - && \ xz -dc gcc-11.1.0.tar.xz | tar -xf - && \ gzip -dc cmake-3.20.2.tar.gz | tar -xf - && \ gzip -dc v1.10.2.tar.gz | tar -xf - && \ - xz -dc llvm-11.1.0.src.tar.xz | tar -xf - && \ - xz -dc lld-11.1.0.src.tar.xz | tar -xf - && \ - cd llvm-11.1.0.src && patch -p1 < ../llvm-fix-missing-include.patch && cd - && \ + xz -dc llvm-12.0.0.src.tar.xz | tar -xf - && \ + xz -dc lld-12.0.0.src.tar.xz | tar -xf - && \ + xz -dc libunwind-12.0.0.src.tar.xz | tar -xf - && \ mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ - cd zstd-1.4.9 && make -s $ZSTDFLAGS -j 1 && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ + cd zstd-1.5.0 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ mkdir build && cd build && ../gcc-11.1.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ @@ -51,7 +54,8 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil --enable-gnu-unique-object --enable-initfini-array --enable-languages="c,c++,lto" \ --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ --with-default-libstdcxx-abi="gcc4-compatible" \ - --with-gcc-major-version-only --with-linker-hash-style="gnu" && \ + --with-gcc-major-version-only --with-linker-hash-style="gnu" \ + --with-arch="x86-64" --with-tune="generic" && \ make -s -j $CPU && make -s install && cd - && rm -rf build && \ mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ echo -e "OUTPUT_FORMAT(elf64-aarch64)\nINPUT ( libstdc++.so.6.0.19 libstdc++.a )" \ @@ -63,16 +67,15 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ mkdir build && cd build && ../cmake-3.20.2/configure --prefix=/toolchain \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v lld-11.1.0.src lld && \ + mv -v llvm-12.0.0.src llvm && \ + mv -v lld-12.0.0.src lld && \ + mv -v libunwind-12.0.0.src libunwind && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/toolchain \ -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ -DLLVM_TARGETS_TO_BUILD="AArch64" -DLLVM_ENABLE_PROJECTS=lld \ - -DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-redhat-linux-gnu" \ - llvm-11.1.0.src && cmake --build build --target install && rm -rf build && \ - rm -rf * + -DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-redhat-linux-gnu" llvm && \ + cmake --build build --target install && \ + rm -rf build && rm -rf * RUN yum clean all -ENV PATH /toolchain/bin:$PATH -ENV CC gcc -ENV CXX g++ diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index ebb16c78..f5872673 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -3,46 +3,49 @@ FROM quay.io/pypa/manylinux2014_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM llvm-fix-missing-include.patch /root/ +ADD SHA256SUM /root/ -RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ +ENV PATH /toolchain/bin:$PATH +ENV CC gcc +ENV CXX g++ +ENV CPPFLAGS -I/toolchain/include +ENV LDFLAGS -L/toolchain/lib64 +ENV PKG_CONFIG_PATH /toolchain/lib64/pkgconfig + +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build dpkg && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ - export PKG_CONFIG_PATH=/toolchain/lib64/pkgconfig && \ - export CC=gcc && \ - export CXX=g++ && \ - export CPPFLAGS=-I/toolchain/include && \ - export LDFLAGS=-L/toolchain/lib64 && \ curl -s -L -O --remote-name-all \ https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ http://isl.gforge.inria.fr/isl-0.24.tar.xz \ - https://github.com/facebook/zstd/releases/download/v1.4.9/zstd-1.4.9.tar.gz \ + https://github.com/facebook/zstd/releases/download/v1.5.0/zstd-1.5.0.tar.gz \ https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/llvm-11.1.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-11.1.0/lld-11.1.0.src.tar.xz && \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/llvm-12.0.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/lld-12.0.0.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/libunwind-12.0.0.src.tar.xz && \ sha256sum -c SHA256SUM && \ xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ xz -dc isl-0.24.tar.xz | tar -xf - && \ - gzip -dc zstd-1.4.9.tar.gz | tar -xf - && \ + gzip -dc zstd-1.5.0.tar.gz | tar -xf - && \ xz -dc gcc-11.1.0.tar.xz | tar -xf - && \ gzip -dc cmake-3.20.2.tar.gz | tar -xf - && \ gzip -dc v1.10.2.tar.gz | tar -xf - && \ - xz -dc llvm-11.1.0.src.tar.xz | tar -xf - && \ - xz -dc lld-11.1.0.src.tar.xz | tar -xf - && \ - cd llvm-11.1.0.src && patch -p1 < ../llvm-fix-missing-include.patch && cd - && \ + xz -dc llvm-12.0.0.src.tar.xz | tar -xf - && \ + xz -dc lld-12.0.0.src.tar.xz | tar -xf - && \ + xz -dc libunwind-12.0.0.src.tar.xz | tar -xf - && \ mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ - cd zstd-1.4.9 && make -s $ZSTDFLAGS -j 1 && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ + cd zstd-1.5.0 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ mkdir build && cd build && ../gcc-11.1.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ @@ -52,7 +55,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ --with-default-libstdcxx-abi="gcc4-compatible" \ --with-gcc-major-version-only --with-linker-hash-style="gnu" \ - --with-tune="generic" && \ + --with-arch="x86-64" --with-tune="generic" && \ make -s -j $CPU && make -s install && cd - && rm -rf build && \ mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.19 libstdc++.a )" \ @@ -64,16 +67,15 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ mkdir build && cd build && ../cmake-3.20.2/configure --prefix=/toolchain \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v lld-11.1.0.src lld && \ + mv -v llvm-12.0.0.src llvm && \ + mv -v lld-12.0.0.src lld && \ + mv -v libunwind-12.0.0.src libunwind && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/toolchain \ -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ - -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ - llvm-11.1.0.src && cmake --build build --target install && rm -rf build && \ - rm -rf * + -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" llvm && \ + cmake --build build --target install && \ + rm -rf build && rm -rf * RUN yum clean all -ENV PATH /toolchain/bin:$PATH -ENV CC gcc -ENV CXX g++ diff --git a/utils/docker/SHA256SUM b/utils/docker/SHA256SUM index 5dc84484..66d8051f 100644 --- a/utils/docker/SHA256SUM +++ b/utils/docker/SHA256SUM @@ -2,9 +2,10 @@ aecf6ecb975179eb3bb6a4a50cae192d41e92b9372b02300f9e8f1d5f559544e cmake-3.20.2.t 4c4a6fb8a8396059241c2e674b85b351c26a5d678274007f076957afa1cc9ddf gcc-11.1.0.tar.xz fd4829912cddd12f84181c3451cc752be224643e87fac497b69edddadc49b4f2 gmp-6.2.1.tar.xz 043105cc544f416b48736fff8caf077fb0663a717d06b1113f16e391ac99ebad isl-0.24.tar.xz -017a788cbe1ecc4a949abf10755870519086d058a2e99f438829aef24f0c66ce lld-11.1.0.src.tar.xz -ce8508e318a01a63d4e8b3090ab2ded3c598a50258cc49e2625b9120d4c03ea5 llvm-11.1.0.src.tar.xz +9ed2a5b28853f7f58be9d04836ff43d6e4132df5a2c058b690dc3e9d75bd1cf5 libunwind-12.0.0.src.tar.xz +2cb7d497f3ce33ce8a2c50ad26ec93a8c45f57268d4d96953cd0f25566f753fd lld-12.0.0.src.tar.xz +49dc47c8697a1a0abd4ee51629a696d7bfe803662f2a7252a3b16fc75f3a8b50 llvm-12.0.0.src.tar.xz 17503d2c395dfcf106b622dc142683c1199431d095367c6aacba6eec30340459 mpc-1.2.1.tar.gz 0c98a3f1732ff6ca4ea690552079da9c597872d30e96ec28414ee23c95558a7f mpfr-4.1.0.tar.xz ce35865411f0490368a8fc383f29071de6690cbadc27704734978221f25e2bed v1.10.2.tar.gz -29ac74e19ea28659017361976240c4b5c5c24db3b89338731a6feb97c038d293 zstd-1.4.9.tar.gz +5194fbfa781fcf45b98c5e849651aa7b3b0a008c6b72d4a0db760f3002291e94 zstd-1.5.0.tar.gz diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index 449ddea9..d081b353 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -7,9 +7,9 @@ export CXX=g++ export CPPFLAGS=-I/toolchain/include export LDFLAGS=-L/toolchain/lib64 curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.bz2 -echo f0397ba6e982c4450f27bf32a2a83292aba035b827a5623a14636ea583318c41 boost_1_76_0.tar.bz2 | sha256sum -c +echo "f0397ba6e982c4450f27bf32a2a83292aba035b827a5623a14636ea583318c41 boost_1_76_0.tar.bz2" | sha256sum -c bzip2 -dc boost_1_76_0.tar.bz2 | tar -xf - -if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DBUILD_TOOL_WASMEDGE_STATIC=OFF -DBUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_76_0/; then +if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DBUILD_TOOL_WASMEDGE_STATIC=OFF -DBUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_76_0/; then echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log echo === CMakeError.log === diff --git a/utils/docker/llvm-fix-missing-include.patch b/utils/docker/llvm-fix-missing-include.patch deleted file mode 100644 index 20411a3f..00000000 --- a/utils/docker/llvm-fix-missing-include.patch +++ /dev/null @@ -1,10 +0,0 @@ ---- a/utils/benchmark/src/benchmark_register.h -+++ b/utils/benchmark/src/benchmark_register.h -@@ -1,6 +1,7 @@ - #ifndef BENCHMARK_REGISTER_H - #define BENCHMARK_REGISTER_H - -+#include - #include - - #include "check.h" From 0309fc0cb27e5d594fde78700868382faff3b195 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Fri, 28 May 2021 09:57:45 +0000 Subject: [PATCH 022/623] [Docker] Provide cpu infomation for gmp to build more generic code Signed-off-by: Shen-Ta Hsieh --- utils/docker/Dockerfile.manylinux1_x86_64 | 2 +- utils/docker/Dockerfile.manylinux2010_x86_64 | 2 +- utils/docker/Dockerfile.manylinux2014_aarch64 | 2 +- utils/docker/Dockerfile.manylinux2014_x86_64 | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 index cbaacf2d..89338988 100644 --- a/utils/docker/Dockerfile.manylinux1_x86_64 +++ b/utils/docker/Dockerfile.manylinux1_x86_64 @@ -43,7 +43,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil cd gcc-11.1.0 && patch -p1 < ../gcc-4.8.2-gcc-11.patch && cd - && \ cd llvm-12.0.0.src && patch -p1 < ../llvm-glibc-2.5.patch && cd - && \ cd cmake-3.20.2 && patch -p1 < ../cmake-glibc-2.5.patch && cd - && \ - mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../gmp-6.2.1/configure --build=x86_64-pc-linux-gnu $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 index f477dc84..226656e1 100644 --- a/utils/docker/Dockerfile.manylinux2010_x86_64 +++ b/utils/docker/Dockerfile.manylinux2010_x86_64 @@ -40,7 +40,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil xz -dc llvm-12.0.0.src.tar.xz | tar -xf - && \ xz -dc lld-12.0.0.src.tar.xz | tar -xf - && \ xz -dc libunwind-12.0.0.src.tar.xz | tar -xf - && \ - mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../gmp-6.2.1/configure --build=x86_64-pc-linux-gnu $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index 08f13585..23d9b811 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -40,7 +40,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil xz -dc llvm-12.0.0.src.tar.xz | tar -xf - && \ xz -dc lld-12.0.0.src.tar.xz | tar -xf - && \ xz -dc libunwind-12.0.0.src.tar.xz | tar -xf - && \ - mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../gmp-6.2.1/configure --build=aarch64-redhat-linux-gnu $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index f5872673..cf009f81 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -40,7 +40,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil xz -dc llvm-12.0.0.src.tar.xz | tar -xf - && \ xz -dc lld-12.0.0.src.tar.xz | tar -xf - && \ xz -dc libunwind-12.0.0.src.tar.xz | tar -xf - && \ - mkdir build && cd build && ../gmp-6.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ + mkdir build && cd build && ../gmp-6.2.1/configure --build=x86_64-pc-linux-gnu $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ From 55319125a7dc9ba28b9d7b10eccd9b438a6c7ba8 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 15 Jun 2021 00:30:41 +0000 Subject: [PATCH 023/623] [WASI] Add wasi-test for testing basic WASI interface * Run wasi-test on Github Action Signed-off-by: Shen-Ta Hsieh --- ...-Disable-other-tests-except-wasmedge.patch | 43 ++++++++ utils/wasi-test/run-wasi-test.sh | 100 ++++++++++++++++++ 2 files changed, 143 insertions(+) create mode 100644 utils/wasi-test/0001-PATCH-Disable-other-tests-except-wasmedge.patch create mode 100755 utils/wasi-test/run-wasi-test.sh diff --git a/utils/wasi-test/0001-PATCH-Disable-other-tests-except-wasmedge.patch b/utils/wasi-test/0001-PATCH-Disable-other-tests-except-wasmedge.patch new file mode 100644 index 00000000..e6abfe00 --- /dev/null +++ b/utils/wasi-test/0001-PATCH-Disable-other-tests-except-wasmedge.patch @@ -0,0 +1,43 @@ +From a33de43ac84703faa5a0ee4d7403360e0efb1922 Mon Sep 17 00:00:00 2001 +From: Shen-Ta Hsieh +Date: Sat, 19 Jun 2021 10:03:12 +0800 +Subject: [PATCH] [PATCH] Disable other tests except wasmedge + +--- + compat.py | 12 +++++++----- + 1 file changed, 7 insertions(+), 5 deletions(-) + +diff --git a/compat.py b/compat.py +index 9341520..da54714 100755 +--- a/compat.py ++++ b/compat.py +@@ -28,8 +28,14 @@ def load_config(filepath): + return config + + def test(cmd, config, cwd): ++ print(' '.join(cmd)) + result = subprocess.run(cmd, cwd=cwd, encoding='utf8', input=config.get('stdin'), timeout=config.get('timeout', 5), capture_output=True) +- assert_result(result, config) ++ try: ++ assert_result(result, config) ++ except: ++ print('==stderr==') ++ print(result.stderr) ++ raise + + def test_deno(filepath, config, cwd): + cmd = ['deno', 'run'] +@@ -202,10 +208,6 @@ def main(): + inputs.extend(sorted(glob.glob("target/wasm32-wasi/**/*.wasm"))) + + tests = { +- "deno": test_deno, +- "node": test_node, +- "wasmer": test_wasmer, +- "wasmtime": test_wasmtime, + "wasmedge": test_wasmedge, + } + +-- +2.31.1 + diff --git a/utils/wasi-test/run-wasi-test.sh b/utils/wasi-test/run-wasi-test.sh new file mode 100755 index 00000000..181d3f2f --- /dev/null +++ b/utils/wasi-test/run-wasi-test.sh @@ -0,0 +1,100 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 +# Test WasmEdge WASI layer. +# The testcase is from https://github.com/khronosproject/wasi-test + +set -Eeuo pipefail +trap cleanup SIGINT SIGTERM ERR EXIT + +script_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd -P) +current_dir=$(pwd -P) + +usage() { + cat <&2 -e "${1-}" +} + +die() { + local msg=$1 + local code=${2-1} # default exit status 1 + msg "$msg" + exit "$code" +} + +parse_params() { + while :; do + case "${1-}" in + -h | --help) usage ;; + -v | --verbose) set -x ;; + -?*) die "Unknown option: $1" ;; + *) break ;; + esac + shift + done + + if ! command -v realpath &> /dev/null; then + realpath() { + readlink -f -- "$@" + } + fi + + local wasmedge_path=$(realpath "${1-}") + msg "path = $wasmedge_path" + if [[ x"$wasmedge_path" != x ]]; then + export PATH="$wasmedge_path:$PATH" + fi + return 0 +} + +check_command() { + if ! command -v "$1" &> /dev/null; then + die "$1 not found!" + exit 1 + fi + return 0 +} + +parse_params "$@" +check_command git +check_command python3 +check_command wasmedgec +check_command wasmedge + +msg "Cloning git repo..." +git clone https://github.com/khronosproject/wasi-test.git --depth 1 +cd wasi-test + +msg "Applying patch..." +git apply "$script_dir"/0001-PATCH-Disable-other-tests-except-wasmedge.patch + +if command -v cargo &> /dev/null; then + msg "Building wasm files..." + cargo build --release --target wasm32-wasi +else + curl -L -O https://github.com/khronosproject/wasi-test-suite/archive/refs/heads/master.tar.gz + mkdir -p target/wasm32-wasi + tar -xf master.tar.gz -C target/wasm32-wasi +fi + +msg "Running tests..." +python3 compat.py From 334428b1b7497bab0ec90cb74be79922bb1c48c8 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Mon, 5 Jul 2021 09:26:28 +0800 Subject: [PATCH 024/623] [Misc] Change all CMake global properties to target specified properties * Add namespace to all cmake options Signed-off-by: Shen-Ta Hsieh --- utils/docker/build-manylinux.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index d081b353..d7f93df3 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -9,7 +9,7 @@ export LDFLAGS=-L/toolchain/lib64 curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.bz2 echo "f0397ba6e982c4450f27bf32a2a83292aba035b827a5623a14636ea583318c41 boost_1_76_0.tar.bz2" | sha256sum -c bzip2 -dc boost_1_76_0.tar.bz2 | tar -xf - -if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DBUILD_TOOL_WASMEDGE_STATIC=OFF -DBUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_76_0/; then +if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_BUILD_TOOL_WASMEDGE_STATIC=OFF -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_76_0/; then echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log echo === CMakeError.log === From ed33ffb0412f8be75437bd629022104051bbc7ad Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 21 Jul 2021 17:18:43 +0800 Subject: [PATCH 025/623] [CI] Update manylinux2014 aarch64 dockerfile Signed-off-by: hydai --- utils/docker/Dockerfile.manylinux2014_aarch64 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index 23d9b811..d3bf07a2 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -12,7 +12,7 @@ ENV CPPFLAGS -I/toolchain/include ENV LDFLAGS -L/toolchain/lib64 ENV PKG_CONFIG_PATH /toolchain/lib64/pkgconfig -RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build dpkg && \ +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ @@ -55,7 +55,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ --with-default-libstdcxx-abi="gcc4-compatible" \ --with-gcc-major-version-only --with-linker-hash-style="gnu" \ - --with-arch="x86-64" --with-tune="generic" && \ + --with-arch="armv8-a" && \ make -s -j $CPU && make -s install && cd - && rm -rf build && \ mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ echo -e "OUTPUT_FORMAT(elf64-aarch64)\nINPUT ( libstdc++.so.6.0.19 libstdc++.a )" \ From 74a99dd73616f24a2bb014950c9e7390186fb015 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 27 Jul 2021 15:39:56 +0800 Subject: [PATCH 026/623] [CI] Enable ubuntu 20.04 x86_64 build Signed-off-by: hydai --- utils/docker/Dockerfile.ubuntu2004_x86_64 | 25 +++++++++++++++++++++++ utils/docker/test-ubuntu.sh | 17 +++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 utils/docker/Dockerfile.ubuntu2004_x86_64 create mode 100755 utils/docker/test-ubuntu.sh diff --git a/utils/docker/Dockerfile.ubuntu2004_x86_64 b/utils/docker/Dockerfile.ubuntu2004_x86_64 new file mode 100644 index 00000000..fedd5f46 --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu2004_x86_64 @@ -0,0 +1,25 @@ +FROM ubuntu:20.04 + +MAINTAINER hydai hydai@secondstate.io +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt update && apt upgrade -y \ + && apt install -y \ + software-properties-common \ + wget \ + cmake \ + ninja-build \ + curl \ + git \ + libboost-all-dev \ + llvm-10-dev \ + liblld-10-dev \ + gcc \ + rpm \ + dpkg-dev \ + g++ + +RUN rm -rf /var/lib/apt/lists/* + +ENV CC=gcc +ENV CXX=g++ diff --git a/utils/docker/test-ubuntu.sh b/utils/docker/test-ubuntu.sh new file mode 100755 index 00000000..b7d20b63 --- /dev/null +++ b/utils/docker/test-ubuntu.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +export CC=gcc +export CXX=g++ +curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.bz2 +echo "f0397ba6e982c4450f27bf32a2a83292aba035b827a5623a14636ea583318c41 boost_1_76_0.tar.bz2" | sha256sum -c +bzip2 -dc boost_1_76_0.tar.bz2 | tar -xf - +if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Debug -DWASMEDGE_BUILD_TESTS=ON -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_76_0/; then + echo === CMakeOutput.log === + cat build/CMakeFiles/CMakeOutput.log + echo === CMakeError.log === + cat build/CMakeFiles/CMakeError.log + exit 1 +fi +cmake --build build +LD_LIBRARY_PATH=$(pwd)/build/lib/api cmake --build build --target test From de258ae5d895da034ff6b33ec0be2c3e1c96dd54 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 27 Jul 2021 19:01:42 +0800 Subject: [PATCH 027/623] [Misc] Removed binfmt support Signed-off-by: hydai --- utils/docker/build-manylinux.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index d7f93df3..bff400d2 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -9,7 +9,7 @@ export LDFLAGS=-L/toolchain/lib64 curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.bz2 echo "f0397ba6e982c4450f27bf32a2a83292aba035b827a5623a14636ea583318c41 boost_1_76_0.tar.bz2" | sha256sum -c bzip2 -dc boost_1_76_0.tar.bz2 | tar -xf - -if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_BUILD_TOOL_WASMEDGE_STATIC=OFF -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_76_0/; then +if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_76_0/; then echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log echo === CMakeError.log === From d9b2188fb523b832701551e9a6e2c47fdce54836 Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 4 Aug 2021 18:36:00 +0800 Subject: [PATCH 028/623] [Misc] Install dpkg-dev to enable dpkg-shlibdeps when creating the deb release Signed-off-by: hydai --- utils/docker/Dockerfile.base | 8 ++++---- utils/docker/Dockerfile.build-clang | 4 ---- utils/docker/Dockerfile.build-gcc | 4 ---- utils/docker/Dockerfile.ubuntu2004_x86_64 | 1 + 4 files changed, 5 insertions(+), 12 deletions(-) diff --git a/utils/docker/Dockerfile.base b/utils/docker/Dockerfile.base index 3e28e74c..af96e5aa 100644 --- a/utils/docker/Dockerfile.base +++ b/utils/docker/Dockerfile.base @@ -6,14 +6,14 @@ ENV DEBIAN_FRONTEND=noninteractive RUN apt update && apt upgrade -y \ && apt install -y \ software-properties-common \ + dpkg-dev \ wget \ cmake \ ninja-build \ curl \ git \ - libboost-all-dev - -RUN curl -sL https://deb.nodesource.com/setup_14.x | bash \ - && apt install -y nodejs + libboost-all-dev \ + llvm-dev \ + liblld-12-dev RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/Dockerfile.build-clang b/utils/docker/Dockerfile.build-clang index bfca6ce1..2cbbf7be 100644 --- a/utils/docker/Dockerfile.build-clang +++ b/utils/docker/Dockerfile.build-clang @@ -1,10 +1,6 @@ ARG BASE=wasmedge/wasmedge:ubuntu-base FROM ${BASE} -RUN apt update && apt install -y \ - llvm-dev \ - liblld-12-dev - RUN apt update && apt install -y \ clang diff --git a/utils/docker/Dockerfile.build-gcc b/utils/docker/Dockerfile.build-gcc index 38c91ff2..12064114 100644 --- a/utils/docker/Dockerfile.build-gcc +++ b/utils/docker/Dockerfile.build-gcc @@ -1,10 +1,6 @@ ARG BASE=wasmedge/wasmedge:ubuntu-base FROM ${BASE} -RUN apt update && apt install -y \ - llvm-dev \ - liblld-12-dev - RUN apt update && apt install -y \ gcc \ g++ diff --git a/utils/docker/Dockerfile.ubuntu2004_x86_64 b/utils/docker/Dockerfile.ubuntu2004_x86_64 index fedd5f46..96e1c5f1 100644 --- a/utils/docker/Dockerfile.ubuntu2004_x86_64 +++ b/utils/docker/Dockerfile.ubuntu2004_x86_64 @@ -11,6 +11,7 @@ RUN apt update && apt upgrade -y \ ninja-build \ curl \ git \ + dpkg-dev \ libboost-all-dev \ llvm-10-dev \ liblld-10-dev \ From 5a2307ecd311ea8ed9b1fbe39a0b0cd9d9fa5306 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 3 Aug 2021 16:03:46 +0800 Subject: [PATCH 029/623] [Misc] Add armhf dockerfile Signed-off-by: hydai --- utils/docker/Dockerfile.ubuntu2104_armv7l | 42 +++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 utils/docker/Dockerfile.ubuntu2104_armv7l diff --git a/utils/docker/Dockerfile.ubuntu2104_armv7l b/utils/docker/Dockerfile.ubuntu2104_armv7l new file mode 100644 index 00000000..6c82c4f4 --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu2104_armv7l @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +FROM arm32v7/ubuntu:hirsute + +MAINTAINER hydai hydai@secondstate.io +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt update && apt upgrade -y \ + && apt install -y \ + build-essential \ + cmake \ + curl \ + dpkg-dev \ + gcc \ + gcc-multilib \ + g++ \ + g++-multilib \ + git \ + llvm-12-dev \ + libboost-all-dev \ + liblld-12-dev \ + libssl-dev \ + ninja-build \ + software-properties-common \ + python3 \ + rpm \ + wget \ + xz-utils + +# CMake build from source to avoid compiler_id_detection fails when using QEMU user-mode emulation +# See: https://gitlab.kitware.com/cmake/cmake/-/issues/20568 + +#RUN wget https://github.com/Kitware/CMake/releases/download/v3.21.1/cmake-3.21.1.tar.gz --no-check-certificate && \ +# tar zxvf cmake-3.21.1.tar.gz && \ +# cd cmake-3.21.1 && \ +# ./configure && \ +# make install -j$(nproc) && \ +# cd .. && rm -rf cmake-3.21.1 + +RUN rm -rf /var/lib/apt/lists/* + +ENV CC=gcc +ENV CXX=g++ From dd6df3f2d81217f1d74d20bc54118de72def6204 Mon Sep 17 00:00:00 2001 From: Michael Yuan Date: Tue, 14 Sep 2021 01:44:56 -0500 Subject: [PATCH 030/623] [Utils] Add appdev Docker files (#411) Signed-off-by: Michael Yuan --- utils/docker/Dockerfile.appdev_aarch64 | 26 +++++++++ utils/docker/Dockerfile.appdev_x86_64 | 32 +++++++++++ utils/docker/build-appdev.md | 77 ++++++++++++++++++++++++++ 3 files changed, 135 insertions(+) create mode 100644 utils/docker/Dockerfile.appdev_aarch64 create mode 100644 utils/docker/Dockerfile.appdev_x86_64 create mode 100644 utils/docker/build-appdev.md diff --git a/utils/docker/Dockerfile.appdev_aarch64 b/utils/docker/Dockerfile.appdev_aarch64 new file mode 100644 index 00000000..249cc98c --- /dev/null +++ b/utils/docker/Dockerfile.appdev_aarch64 @@ -0,0 +1,26 @@ +FROM ubuntu:20.04 + +RUN apt-get update &&\ + DEBIAN_FRONTEND=noninteractive apt-get install -y wget git curl software-properties-common golang + +RUN wget -qO- https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash + +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain stable -y +ENV PATH=/root/.cargo/bin:$PATH +RUN rustup override set 1.50.0 &&\ + rustup target add wasm32-wasi + +RUN curl https://raw.githubusercontent.com/second-state/rustwasmc/master/installer/init.sh -sSf | sh + +RUN curl -sL https://deb.nodesource.com/setup_12.x | bash - +RUN apt-get install -y nodejs +RUN npm install wasmedge-core + +RUN mkdir -p /root/examples +WORKDIR /root/examples +RUN wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/hello.wasm &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/qjs.wasm &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/hello.js &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/repl.js + +ENTRYPOINT ["/bin/bash", "-l"] diff --git a/utils/docker/Dockerfile.appdev_x86_64 b/utils/docker/Dockerfile.appdev_x86_64 new file mode 100644 index 00000000..61e1d0cc --- /dev/null +++ b/utils/docker/Dockerfile.appdev_x86_64 @@ -0,0 +1,32 @@ +FROM ubuntu:20.04 + +RUN apt-get update &&\ + DEBIAN_FRONTEND=noninteractive apt-get install -y wget git curl software-properties-common golang + +RUN wget -qO- https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- -e all + +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain stable -y +ENV PATH=/root/.cargo/bin:$PATH +RUN rustup override set 1.50.0 &&\ + rustup target add wasm32-wasi + +RUN curl https://raw.githubusercontent.com/second-state/rustwasmc/master/installer/init.sh -sSf | sh + +RUN curl -sL https://deb.nodesource.com/setup_12.x | bash - +RUN apt-get install -y nodejs +RUN npm install wasmedge-core &&\ + npm install wasmedge-extensions + +RUN mkdir -p /root/examples +WORKDIR /root/examples +RUN wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/hello.wasm &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/qjs.wasm &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/hello.js &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/repl.js &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/qjs_tf.wasm &&\ + wget -O tf_image_classify.js https://raw.githubusercontent.com/second-state/wasmedge-quickjs/main/example_js/tensorflow_lite_demo/main.js &&\ + wget https://github.com/second-state/wasmedge-quickjs/raw/main/example_js/tensorflow_lite_demo/lite-model_aiy_vision_classifier_food_V1_1.tflite &&\ + wget https://github.com/second-state/wasmedge-quickjs/raw/main/example_js/tensorflow_lite_demo/food.jpg &&\ + wget https://github.com/second-state/wasmedge-quickjs/raw/main/example_js/tensorflow_lite_demo/aiy_food_V1_labelmap.txt + +ENTRYPOINT ["/bin/bash", "-l"] diff --git a/utils/docker/build-appdev.md b/utils/docker/build-appdev.md new file mode 100644 index 00000000..a40888f8 --- /dev/null +++ b/utils/docker/build-appdev.md @@ -0,0 +1,77 @@ +# Use the appdev Docker images + +The `appdev` Docker images provide a complete WasmEdge application development environment. To use it, do the following. + +### On x86_64 machines + +``` +$ docker pull wasmedge/appdev_x86_64:0.8.2 +$ docker run --rm -v $(pwd):/app -it wasmedge/appdev_x86_64:0.8.2 +(docker) # +``` + +### On arm64 machines + +``` +$ docker pull wasmedge/appdev_aarch64:0.8.2 +$ docker run --rm -v $(pwd):/app -it wasmedge/appdev_aarch64:0.8.2 +(docker) # +``` + +It installs the following components. + +* WasmEdge CLI and shared libraries +* WasmEdge with Tensorflow extension CLI and libraries (x86_64 only) +* Golang +* Rust +* Node.js with WasmEdge addons +* Examples in the `/root/examples/` folder + +## Examples + +Hello World. [See more simple examples](https://github.com/WasmEdge/WasmEdge/tree/master/tools/wasmedge/examples) + +``` +$ wasmedge hello.wasm world +hello +world +``` + +Use AOT to run it *much faster*. + +``` +$ wasmedgec hello.wasm hello.so +$ wasmedge hello.so world +hello +world +``` + +Here are some JavaScript examples. [See more](https://github.com/WasmEdge/WasmEdge/tree/master/tools/wasmedge/examples/js) + +``` +$ wasmedge --dir .:. qjs.wasm hello.js 1 2 3 +Hello 1 2 3 + +$ wasmedge-tensorflow-lite --dir .:. qjs_tf.wasm tf_image_classify.js +label: Hot dog +confidence: 0.8941176470588236 +``` + +## Build and publish the appdev images + +Run these commands to build and publish the `appdev` Docker images. + +### Build on an x86_64 machine + +``` +docker build -t wasmedge/appdev_x86_64:0.8.2 -f Dockerfile.appdev_x86_64 ./ +docker image push wasmedge/appdev_x86_64:0.8.2 +``` + +### Build on an ARM64 / aarch64 machine + +``` +docker build -t wasmedge/appdev_aarch64:0.8.2 -f Dockerfile.appdev_aarch64 ./ +docker image push wasmedge/appdev_aarch64:0.8.2 +``` + From 23d87005a03d7ede9eb8a3263d592c1287f4a5ee Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 15 Sep 2021 18:42:41 +0800 Subject: [PATCH 031/623] [Misc] Rollback the development environment from 21.04 to 20.04 LTS Signed-off-by: hydai --- utils/docker/Dockerfile.base | 4 ++-- utils/docker/Dockerfile.build-clang | 6 +++--- utils/docker/Dockerfile.ci-image-base | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/utils/docker/Dockerfile.base b/utils/docker/Dockerfile.base index af96e5aa..fb56f33c 100644 --- a/utils/docker/Dockerfile.base +++ b/utils/docker/Dockerfile.base @@ -1,4 +1,4 @@ -FROM ubuntu:21.04 +FROM ubuntu:20.04 MAINTAINER hydai hydai@secondstate.io ENV DEBIAN_FRONTEND=noninteractive @@ -13,7 +13,7 @@ RUN apt update && apt upgrade -y \ curl \ git \ libboost-all-dev \ - llvm-dev \ + llvm-12-dev \ liblld-12-dev RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/Dockerfile.build-clang b/utils/docker/Dockerfile.build-clang index 2cbbf7be..eaebfab6 100644 --- a/utils/docker/Dockerfile.build-clang +++ b/utils/docker/Dockerfile.build-clang @@ -2,9 +2,9 @@ ARG BASE=wasmedge/wasmedge:ubuntu-base FROM ${BASE} RUN apt update && apt install -y \ - clang + clang-12 RUN rm -rf /var/lib/apt/lists/* -ENV CC=/usr/bin/clang -ENV CXX=/usr/bin/clang++ +ENV CC=/usr/bin/clang-12 +ENV CXX=/usr/bin/clang++-12 diff --git a/utils/docker/Dockerfile.ci-image-base b/utils/docker/Dockerfile.ci-image-base index 670f08b1..d381de58 100644 --- a/utils/docker/Dockerfile.ci-image-base +++ b/utils/docker/Dockerfile.ci-image-base @@ -1,4 +1,4 @@ -FROM ubuntu:21.04 +FROM ubuntu:20.04 MAINTAINER hydai hydai@secondstate.io ENV DEBIAN_FRONTEND=noninteractive From 9fde9d659790fe6befde863e426772e77918eb68 Mon Sep 17 00:00:00 2001 From: Michael Yuan Date: Wed, 10 Nov 2021 23:38:29 -0600 Subject: [PATCH 032/623] [Utils] Update appdev Docker images for the latest install script (#623) Signed-off-by: Michael Yuan --- utils/docker/Dockerfile.appdev_aarch64 | 6 +++--- utils/docker/Dockerfile.appdev_x86_64 | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/utils/docker/Dockerfile.appdev_aarch64 b/utils/docker/Dockerfile.appdev_aarch64 index 249cc98c..5a749fa0 100644 --- a/utils/docker/Dockerfile.appdev_aarch64 +++ b/utils/docker/Dockerfile.appdev_aarch64 @@ -3,8 +3,6 @@ FROM ubuntu:20.04 RUN apt-get update &&\ DEBIAN_FRONTEND=noninteractive apt-get install -y wget git curl software-properties-common golang -RUN wget -qO- https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash - RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain stable -y ENV PATH=/root/.cargo/bin:$PATH RUN rustup override set 1.50.0 &&\ @@ -14,7 +12,9 @@ RUN curl https://raw.githubusercontent.com/second-state/rustwasmc/master/install RUN curl -sL https://deb.nodesource.com/setup_12.x | bash - RUN apt-get install -y nodejs -RUN npm install wasmedge-core +RUN npm install wasmedge-core@0.8.3 + +RUN wget -qO- https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- -e all -v 0.8.2 RUN mkdir -p /root/examples WORKDIR /root/examples diff --git a/utils/docker/Dockerfile.appdev_x86_64 b/utils/docker/Dockerfile.appdev_x86_64 index 61e1d0cc..9d7d998e 100644 --- a/utils/docker/Dockerfile.appdev_x86_64 +++ b/utils/docker/Dockerfile.appdev_x86_64 @@ -3,8 +3,6 @@ FROM ubuntu:20.04 RUN apt-get update &&\ DEBIAN_FRONTEND=noninteractive apt-get install -y wget git curl software-properties-common golang -RUN wget -qO- https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- -e all - RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain stable -y ENV PATH=/root/.cargo/bin:$PATH RUN rustup override set 1.50.0 &&\ @@ -14,8 +12,9 @@ RUN curl https://raw.githubusercontent.com/second-state/rustwasmc/master/install RUN curl -sL https://deb.nodesource.com/setup_12.x | bash - RUN apt-get install -y nodejs -RUN npm install wasmedge-core &&\ - npm install wasmedge-extensions +RUN npm install wasmedge-extensions@0.8.3 + +RUN wget -qO- https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- -e all -v 0.8.2 RUN mkdir -p /root/examples WORKDIR /root/examples From 7f8ef7d0ec8f286217caa41b49d2467c89cc2d0a Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 22 Nov 2021 16:48:09 +0800 Subject: [PATCH 033/623] [CI] Make ctest output failure logs Signed-off-by: hydai --- utils/docker/test-ubuntu.sh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/utils/docker/test-ubuntu.sh b/utils/docker/test-ubuntu.sh index b7d20b63..8d8aa273 100755 --- a/utils/docker/test-ubuntu.sh +++ b/utils/docker/test-ubuntu.sh @@ -14,4 +14,7 @@ if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Debug -DWASMEDGE_BUILD_TESTS=ON -D exit 1 fi cmake --build build -LD_LIBRARY_PATH=$(pwd)/build/lib/api cmake --build build --target test +export LD_LIBRARY_PATH="$(pwd)/build/lib/api:$LD_LIBRARY_PATH" +cd build +ctest --output-on-failure +cd - From 7f6f42b4a200974468435706b83257b06143b0f9 Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 9 Dec 2021 18:40:49 +0800 Subject: [PATCH 034/623] [Utils] Update isl link Signed-off-by: hydai --- utils/docker/Dockerfile.manylinux1_x86_64 | 2 +- utils/docker/Dockerfile.manylinux2010_x86_64 | 2 +- utils/docker/Dockerfile.manylinux2014_aarch64 | 2 +- utils/docker/Dockerfile.manylinux2014_x86_64 | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 index 89338988..cb35c079 100644 --- a/utils/docker/Dockerfile.manylinux1_x86_64 +++ b/utils/docker/Dockerfile.manylinux1_x86_64 @@ -20,7 +20,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ - http://isl.gforge.inria.fr/isl-0.24.tar.xz \ + https://libisl.sourceforge.io/isl-0.24.tar.xz \ https://github.com/facebook/zstd/releases/download/v1.5.0/zstd-1.5.0.tar.gz \ https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 index 226656e1..3fc939d9 100644 --- a/utils/docker/Dockerfile.manylinux2010_x86_64 +++ b/utils/docker/Dockerfile.manylinux2010_x86_64 @@ -20,7 +20,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ - http://isl.gforge.inria.fr/isl-0.24.tar.xz \ + https://libisl.sourceforge.io/isl-0.24.tar.xz \ https://github.com/facebook/zstd/releases/download/v1.5.0/zstd-1.5.0.tar.gz \ https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index d3bf07a2..bff6d015 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -20,7 +20,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ - http://isl.gforge.inria.fr/isl-0.24.tar.xz \ + https://libisl.sourceforge.io/isl-0.24.tar.xz \ https://github.com/facebook/zstd/releases/download/v1.5.0/zstd-1.5.0.tar.gz \ https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index cf009f81..a08bf798 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -20,7 +20,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ - http://isl.gforge.inria.fr/isl-0.24.tar.xz \ + https://libisl.sourceforge.io/isl-0.24.tar.xz \ https://github.com/facebook/zstd/releases/download/v1.5.0/zstd-1.5.0.tar.gz \ https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ From 4f85ac347cda4697422d582e66394ecd296f828d Mon Sep 17 00:00:00 2001 From: Michael Yuan Date: Wed, 15 Dec 2021 02:23:30 +0000 Subject: [PATCH 035/623] [Utils] Update appdev Docker images to 0.9.0 Signed-off-by: Michael Yuan --- utils/docker/Dockerfile.appdev_aarch64 | 9 +++------ utils/docker/Dockerfile.appdev_x86_64 | 9 +++------ utils/docker/build-appdev.md | 16 ++++++++-------- 3 files changed, 14 insertions(+), 20 deletions(-) diff --git a/utils/docker/Dockerfile.appdev_aarch64 b/utils/docker/Dockerfile.appdev_aarch64 index 5a749fa0..ee56f550 100644 --- a/utils/docker/Dockerfile.appdev_aarch64 +++ b/utils/docker/Dockerfile.appdev_aarch64 @@ -1,20 +1,17 @@ -FROM ubuntu:20.04 +FROM ubuntu:21.04 RUN apt-get update &&\ DEBIAN_FRONTEND=noninteractive apt-get install -y wget git curl software-properties-common golang RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain stable -y ENV PATH=/root/.cargo/bin:$PATH -RUN rustup override set 1.50.0 &&\ - rustup target add wasm32-wasi +RUN rustup target add wasm32-wasi RUN curl https://raw.githubusercontent.com/second-state/rustwasmc/master/installer/init.sh -sSf | sh RUN curl -sL https://deb.nodesource.com/setup_12.x | bash - RUN apt-get install -y nodejs -RUN npm install wasmedge-core@0.8.3 - -RUN wget -qO- https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- -e all -v 0.8.2 +RUN npm install wasmedge-core RUN mkdir -p /root/examples WORKDIR /root/examples diff --git a/utils/docker/Dockerfile.appdev_x86_64 b/utils/docker/Dockerfile.appdev_x86_64 index 9d7d998e..56e3b1be 100644 --- a/utils/docker/Dockerfile.appdev_x86_64 +++ b/utils/docker/Dockerfile.appdev_x86_64 @@ -1,20 +1,17 @@ -FROM ubuntu:20.04 +FROM ubuntu:21.04 RUN apt-get update &&\ DEBIAN_FRONTEND=noninteractive apt-get install -y wget git curl software-properties-common golang RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain stable -y ENV PATH=/root/.cargo/bin:$PATH -RUN rustup override set 1.50.0 &&\ - rustup target add wasm32-wasi +RUN rustup target add wasm32-wasi RUN curl https://raw.githubusercontent.com/second-state/rustwasmc/master/installer/init.sh -sSf | sh RUN curl -sL https://deb.nodesource.com/setup_12.x | bash - RUN apt-get install -y nodejs -RUN npm install wasmedge-extensions@0.8.3 - -RUN wget -qO- https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- -e all -v 0.8.2 +RUN npm install wasmedge-extensions RUN mkdir -p /root/examples WORKDIR /root/examples diff --git a/utils/docker/build-appdev.md b/utils/docker/build-appdev.md index a40888f8..314a21e1 100644 --- a/utils/docker/build-appdev.md +++ b/utils/docker/build-appdev.md @@ -5,16 +5,16 @@ The `appdev` Docker images provide a complete WasmEdge application development e ### On x86_64 machines ``` -$ docker pull wasmedge/appdev_x86_64:0.8.2 -$ docker run --rm -v $(pwd):/app -it wasmedge/appdev_x86_64:0.8.2 +$ docker pull wasmedge/appdev_x86_64:0.9.0 +$ docker run --rm -v $(pwd):/app -it wasmedge/appdev_x86_64:0.9.0 (docker) # ``` ### On arm64 machines ``` -$ docker pull wasmedge/appdev_aarch64:0.8.2 -$ docker run --rm -v $(pwd):/app -it wasmedge/appdev_aarch64:0.8.2 +$ docker pull wasmedge/appdev_aarch64:0.9.0 +$ docker run --rm -v $(pwd):/app -it wasmedge/appdev_aarch64:0.9.0 (docker) # ``` @@ -64,14 +64,14 @@ Run these commands to build and publish the `appdev` Docker images. ### Build on an x86_64 machine ``` -docker build -t wasmedge/appdev_x86_64:0.8.2 -f Dockerfile.appdev_x86_64 ./ -docker image push wasmedge/appdev_x86_64:0.8.2 +docker build -t wasmedge/appdev_x86_64:0.9.0 -f Dockerfile.appdev_x86_64 ./ +docker image push wasmedge/appdev_x86_64:0.9.0 ``` ### Build on an ARM64 / aarch64 machine ``` -docker build -t wasmedge/appdev_aarch64:0.8.2 -f Dockerfile.appdev_aarch64 ./ -docker image push wasmedge/appdev_aarch64:0.8.2 +docker build -t wasmedge/appdev_aarch64:0.9.0 -f Dockerfile.appdev_aarch64 ./ +docker image push wasmedge/appdev_aarch64:0.9.0 ``` From 22336a7c16b1a6559887f3eeb989a9354e9bee81 Mon Sep 17 00:00:00 2001 From: "Shen-Ta Hsieh(BestSteve)" Date: Tue, 4 Jan 2022 16:51:21 +0800 Subject: [PATCH 036/623] [Misc] Add copyright text (#964) * Happy new year! Signed-off-by: Shen-Ta Hsieh --- utils/docker/Dockerfile.manylinux1_x86_64 | 2 ++ utils/docker/Dockerfile.manylinux2010_x86_64 | 2 ++ utils/docker/Dockerfile.manylinux2014_aarch64 | 2 ++ utils/docker/Dockerfile.manylinux2014_x86_64 | 2 ++ utils/docker/Dockerfile.ubuntu2104_armv7l | 2 ++ utils/docker/build-manylinux.sh | 1 + utils/docker/build.sh | 2 ++ utils/docker/test-ubuntu.sh | 1 + utils/wasi-test/run-wasi-test.sh | 2 ++ 9 files changed, 16 insertions(+) diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 index cb35c079..003456e3 100644 --- a/utils/docker/Dockerfile.manylinux1_x86_64 +++ b/utils/docker/Dockerfile.manylinux1_x86_64 @@ -1,4 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + FROM quay.io/pypa/manylinux1_x86_64 MAINTAINER hydai hydai@secondstate.io diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 index 3fc939d9..e17232ad 100644 --- a/utils/docker/Dockerfile.manylinux2010_x86_64 +++ b/utils/docker/Dockerfile.manylinux2010_x86_64 @@ -1,4 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + FROM quay.io/pypa/manylinux2010_x86_64 MAINTAINER hydai hydai@secondstate.io diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index bff6d015..a2b01c38 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -1,4 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + FROM quay.io/pypa/manylinux2014_aarch64 MAINTAINER hydai hydai@secondstate.io diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index a08bf798..d968a789 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -1,4 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + FROM quay.io/pypa/manylinux2014_x86_64 MAINTAINER hydai hydai@secondstate.io diff --git a/utils/docker/Dockerfile.ubuntu2104_armv7l b/utils/docker/Dockerfile.ubuntu2104_armv7l index 6c82c4f4..265f0fd9 100644 --- a/utils/docker/Dockerfile.ubuntu2104_armv7l +++ b/utils/docker/Dockerfile.ubuntu2104_armv7l @@ -1,4 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + FROM arm32v7/ubuntu:hirsute MAINTAINER hydai hydai@secondstate.io diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index bff400d2..402be47d 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -1,5 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC export PATH="/toolchain/bin:$PATH" export CC=gcc diff --git a/utils/docker/build.sh b/utils/docker/build.sh index cfc88878..5c4ed150 100755 --- a/utils/docker/build.sh +++ b/utils/docker/build.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + NAME=${1:+$1/}wasmedge INTERMEDIATES=() IMAGES=() diff --git a/utils/docker/test-ubuntu.sh b/utils/docker/test-ubuntu.sh index 8d8aa273..24d54f5c 100755 --- a/utils/docker/test-ubuntu.sh +++ b/utils/docker/test-ubuntu.sh @@ -1,5 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC export CC=gcc export CXX=g++ diff --git a/utils/wasi-test/run-wasi-test.sh b/utils/wasi-test/run-wasi-test.sh index 181d3f2f..0acc1c59 100755 --- a/utils/wasi-test/run-wasi-test.sh +++ b/utils/wasi-test/run-wasi-test.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + # Test WasmEdge WASI layer. # The testcase is from https://github.com/khronosproject/wasi-test From f8f1fc0d4a2f422b9ff9ed79cd19bbf754bd5c92 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 9 Feb 2022 21:45:16 +0800 Subject: [PATCH 037/623] [Docs] Fix markdown errors and update the WasmEdge versions to `0.9.1`. Signed-off-by: YiYing He --- utils/docker/build-appdev.md | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/utils/docker/build-appdev.md b/utils/docker/build-appdev.md index 314a21e1..f9ecbe72 100644 --- a/utils/docker/build-appdev.md +++ b/utils/docker/build-appdev.md @@ -2,17 +2,17 @@ The `appdev` Docker images provide a complete WasmEdge application development environment. To use it, do the following. -### On x86_64 machines +## On x86_64 machines -``` +```bash $ docker pull wasmedge/appdev_x86_64:0.9.0 $ docker run --rm -v $(pwd):/app -it wasmedge/appdev_x86_64:0.9.0 (docker) # ``` -### On arm64 machines +## On arm64 machines -``` +```bash $ docker pull wasmedge/appdev_aarch64:0.9.0 $ docker run --rm -v $(pwd):/app -it wasmedge/appdev_aarch64:0.9.0 (docker) # @@ -31,7 +31,7 @@ It installs the following components. Hello World. [See more simple examples](https://github.com/WasmEdge/WasmEdge/tree/master/tools/wasmedge/examples) -``` +```bash $ wasmedge hello.wasm world hello world @@ -39,7 +39,7 @@ world Use AOT to run it *much faster*. -``` +```bash $ wasmedgec hello.wasm hello.so $ wasmedge hello.so world hello @@ -48,7 +48,7 @@ world Here are some JavaScript examples. [See more](https://github.com/WasmEdge/WasmEdge/tree/master/tools/wasmedge/examples/js) -``` +```bash $ wasmedge --dir .:. qjs.wasm hello.js 1 2 3 Hello 1 2 3 @@ -63,15 +63,14 @@ Run these commands to build and publish the `appdev` Docker images. ### Build on an x86_64 machine -``` +```bash docker build -t wasmedge/appdev_x86_64:0.9.0 -f Dockerfile.appdev_x86_64 ./ docker image push wasmedge/appdev_x86_64:0.9.0 ``` ### Build on an ARM64 / aarch64 machine -``` +```bash docker build -t wasmedge/appdev_aarch64:0.9.0 -f Dockerfile.appdev_aarch64 ./ docker image push wasmedge/appdev_aarch64:0.9.0 ``` - From ebe751d73cb9da646e84c32727009d3a33a5952a Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Fri, 29 Apr 2022 21:08:07 +0800 Subject: [PATCH 038/623] [Misc] Update boost in manylinux build * Set git config `safe.directory` for root directory owned by someone else Signed-off-by: Shen-Ta Hsieh --- utils/docker/build-manylinux.sh | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index 402be47d..225571ce 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -7,10 +7,11 @@ export CC=gcc export CXX=g++ export CPPFLAGS=-I/toolchain/include export LDFLAGS=-L/toolchain/lib64 -curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.bz2 -echo "f0397ba6e982c4450f27bf32a2a83292aba035b827a5623a14636ea583318c41 boost_1_76_0.tar.bz2" | sha256sum -c -bzip2 -dc boost_1_76_0.tar.bz2 | tar -xf - -if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_76_0/; then +curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.79.0/source/boost_1_79_0.tar.bz2 +echo "475d589d51a7f8b3ba2ba4eda022b170e562ca3b760ee922c146b6c65856ef39 boost_1_79_0.tar.bz2" | sha256sum -c +git config --global --add safe.directory $(pwd) +bzip2 -dc boost_1_79_0.tar.bz2 | tar -xf - +if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_79_0/; then echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log echo === CMakeError.log === From 251f88f16deb2b7dd1ef03dab1dd25e0861e79b6 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Fri, 22 Apr 2022 04:51:45 +0800 Subject: [PATCH 039/623] [Plugin] Move WasmEdgeProcess to plugins directory * Install to global plugin directory Signed-off-by: Shen-Ta Hsieh --- plugins/CMakeLists.txt | 4 + plugins/wasmedge_process/CMakeLists.txt | 65 ++++ plugins/wasmedge_process/processbase.h | 23 ++ plugins/wasmedge_process/processenv.cpp | 60 ++++ plugins/wasmedge_process/processenv.h | 52 ++++ plugins/wasmedge_process/processfunc.cpp | 343 +++++++++++++++++++++ plugins/wasmedge_process/processfunc.h | 106 +++++++ plugins/wasmedge_process/processmodule.cpp | 39 +++ plugins/wasmedge_process/processmodule.h | 23 ++ 9 files changed, 715 insertions(+) create mode 100644 plugins/CMakeLists.txt create mode 100644 plugins/wasmedge_process/CMakeLists.txt create mode 100644 plugins/wasmedge_process/processbase.h create mode 100644 plugins/wasmedge_process/processenv.cpp create mode 100644 plugins/wasmedge_process/processenv.h create mode 100644 plugins/wasmedge_process/processfunc.cpp create mode 100644 plugins/wasmedge_process/processfunc.h create mode 100644 plugins/wasmedge_process/processmodule.cpp create mode 100644 plugins/wasmedge_process/processmodule.h diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt new file mode 100644 index 00000000..6725682f --- /dev/null +++ b/plugins/CMakeLists.txt @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +add_subdirectory(wasmedge_process) diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt new file mode 100644 index 00000000..8a9e4b62 --- /dev/null +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_library(wasmedgePluginWasmEdgeProcess + SHARED + processenv.cpp + processfunc.cpp + processmodule.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeProcess + PUBLIC + -DWASMEDGE_PLUGIN +) + +if(CMAKE_SYSTEM_NAME MATCHES "Darwin") + target_link_options(wasmedgePluginWasmEdgeProcess + PUBLIC + -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterC1EPKNS0_6Plugin16PluginDescriptorE + -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterD1Ev + ) +endif() + +target_include_directories(wasmedgePluginWasmEdgeProcess + PUBLIC + $ +) + +target_link_libraries(wasmedgePluginWasmEdgeProcess + PUBLIC + wasmedgeCommon + wasmedgeSystem +) + +install(TARGETS wasmedgePluginWasmEdgeProcess DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) + +wasmedge_add_library(wasmedgeHostModuleWasmEdgeProcess + processenv.cpp + processfunc.cpp + processmodule.cpp +) + +target_include_directories(wasmedgeHostModuleWasmEdgeProcess + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries(wasmedgeHostModuleWasmEdgeProcess + PUBLIC + wasmedgeCommon + wasmedgeSystem + wasmedgePlugin +) + +if(CMAKE_SYSTEM_NAME MATCHES "Linux") + target_link_options(wasmedgeHostModuleWasmEdgeProcess + PUBLIC + -u_ZN8WasmEdge4Host26WasmEdgeProcessEnvironment8RegisterE + ) +elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") + target_link_options(wasmedgeHostModuleWasmEdgeProcess + PUBLIC + -u__ZN8WasmEdge4Host26WasmEdgeProcessEnvironment8RegisterE + ) +endif() diff --git a/plugins/wasmedge_process/processbase.h b/plugins/wasmedge_process/processbase.h new file mode 100644 index 00000000..7db2c860 --- /dev/null +++ b/plugins/wasmedge_process/processbase.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "common/errcode.h" +#include "processenv.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template class WasmEdgeProcess : public Runtime::HostFunction { +public: + WasmEdgeProcess(WasmEdgeProcessEnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WasmEdgeProcessEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/processenv.cpp b/plugins/wasmedge_process/processenv.cpp new file mode 100644 index 00000000..b314addc --- /dev/null +++ b/plugins/wasmedge_process/processenv.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "processenv.h" +#include "po/helper.h" +#include "processmodule.h" +#include + +namespace WasmEdge { +namespace Host { + +using namespace std::literals::string_view_literals; + +PO::List WasmEdgeProcessEnvironment::AllowCmd( + PO::Description( + "Allow commands called from wasmedge_process host functions. Each command can be specified as --allow-command `COMMAND`."sv), + PO::MetaVar("COMMANDS"sv)); + +PO::Option WasmEdgeProcessEnvironment::AllowCmdAll(PO::Description( + "Allow all commands called from wasmedge_process host functions."sv)); + +WasmEdgeProcessEnvironment::WasmEdgeProcessEnvironment() noexcept + : AllowedCmd(AllowCmd.value().begin(), AllowCmd.value().end()), + AllowedAll(AllowCmdAll.value()) {} + +namespace { + +void addOptions(PO::ArgumentParser &Parser) noexcept { + Parser.add_option("allow-command"sv, WasmEdgeProcessEnvironment::AllowCmd) + .add_option("allow-command-all"sv, + WasmEdgeProcessEnvironment::AllowCmdAll); +} + +Runtime::Instance::ModuleInstance *create(void) noexcept { + return new WasmEdgeProcessModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_process", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 10, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_process", + .Description = "", + .Create = create, + }, + }, + .AddOptions = addOptions, +}; + +} // namespace + +Plugin::PluginRegister WasmEdgeProcessEnvironment::Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/processenv.h b/plugins/wasmedge_process/processenv.h new file mode 100644 index 00000000..5bbdd7ec --- /dev/null +++ b/plugins/wasmedge_process/processenv.h @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "po/argument_parser.h" +#include "po/list.h" +#include "po/option.h" +#include +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { + +class WasmEdgeProcessEnvironment { +public: + WasmEdgeProcessEnvironment() noexcept; + + /// Default timeout in milliseconds. + static inline const uint32_t DEFAULT_TIMEOUT = 10000; + /// Default polling time in milliseconds. + static inline const uint32_t DEFAULT_POLLTIME = 1; + + /// Commands + std::string Name; + std::vector Args; + std::unordered_map Envs; + + /// IO + std::vector StdIn; + std::vector StdOut; + std::vector StdErr; + + /// Configurations + uint32_t TimeOut = DEFAULT_TIMEOUT; /// Timeout in milliseconds. + std::unordered_set AllowedCmd; /// Programs in white list. + bool AllowedAll; /// Flag to allow all programs. + + /// Results + uint32_t ExitCode = 0; + + static PO::List AllowCmd; + static PO::Option AllowCmdAll; + static Plugin::PluginRegister Register; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/processfunc.cpp b/plugins/wasmedge_process/processfunc.cpp new file mode 100644 index 00000000..d10735c8 --- /dev/null +++ b/plugins/wasmedge_process/processfunc.cpp @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "processfunc.h" + +#if WASMEDGE_OS_LINUX || WASMEDGE_OS_MACOS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#elif WASMEDGE_OS_WINDOWS +#endif + +namespace WasmEdge { +namespace Host { + +Expect +WasmEdgeProcessSetProgName::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t NamePtr, uint32_t NameLen) { + // Check memory instance from module. + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + + char *Buf = MemInst->getPointer(NamePtr); + std::copy_n(Buf, NameLen, std::back_inserter(Env.Name)); + return {}; +} + +Expect +WasmEdgeProcessAddArg::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t ArgPtr, uint32_t ArgLen) { + // Check memory instance from module. + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + + char *Buf = MemInst->getPointer(ArgPtr); + std::string NewArg; + std::copy_n(Buf, ArgLen, std::back_inserter(NewArg)); + Env.Args.push_back(std::move(NewArg)); + return {}; +} + +Expect +WasmEdgeProcessAddEnv::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t EnvNamePtr, uint32_t EnvNameLen, + uint32_t EnvValPtr, uint32_t EnvValLen) { + // Check memory instance from module. + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + + char *EnvBuf = MemInst->getPointer(EnvNamePtr); + char *ValBuf = MemInst->getPointer(EnvValPtr); + std::string NewEnv, NewVal; + std::copy_n(EnvBuf, EnvNameLen, std::back_inserter(NewEnv)); + std::copy_n(ValBuf, EnvValLen, std::back_inserter(NewVal)); + Env.Envs.emplace(std::move(NewEnv), std::move(NewVal)); + return {}; +} + +Expect +WasmEdgeProcessAddStdIn::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t BufPtr, uint32_t BufLen) { + // Check memory instance from module. + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + + uint8_t *Buf = MemInst->getPointer(BufPtr); + Env.StdIn.reserve(Env.StdIn.size() + BufLen); + std::copy_n(Buf, BufLen, std::back_inserter(Env.StdIn)); + return {}; +} + +Expect +WasmEdgeProcessSetTimeOut::body(Runtime::Instance::MemoryInstance *, + uint32_t Time) { + Env.TimeOut = Time; + return {}; +} + +Expect WasmEdgeProcessRun::body(Runtime::Instance::MemoryInstance *) { +#if WASMEDGE_OS_LINUX || WASMEDGE_OS_MACOS + // Clear outputs. + Env.StdOut.clear(); + Env.StdErr.clear(); + Env.ExitCode = static_cast(-1); + + // Check white list of commands. + if (!Env.AllowedAll && + Env.AllowedCmd.find(Env.Name) == Env.AllowedCmd.end()) { + std::string Msg = "Permission denied: Command \""; + Msg.append(Env.Name); + Msg.append("\" is not in the white list. Please use --allow-command="); + Msg.append(Env.Name); + Msg.append(" or --allow-command-all to add \""); + Msg.append(Env.Name); + Msg.append("\" command into the white list.\n"); + Env.Name.clear(); + Env.Args.clear(); + Env.Envs.clear(); + Env.StdIn.clear(); + Env.StdErr.reserve(Msg.length()); + std::copy_n(Msg.c_str(), Msg.length(), std::back_inserter(Env.StdErr)); + Env.ExitCode = static_cast(INT8_C(-1)); + Env.TimeOut = Env.DEFAULT_TIMEOUT; + return Env.ExitCode; + } + + // Create pipes for stdin, stdout, and stderr. + int FDStdIn[2], FDStdOut[2], FDStdErr[2]; + if (pipe(FDStdIn) == -1) { + // Create stdin pipe failed. + return Env.ExitCode; + } + if (pipe(FDStdOut) == -1) { + // Create stdout pipe failed. + close(FDStdIn[0]); + close(FDStdIn[1]); + return Env.ExitCode; + } + if (pipe(FDStdErr) == -1) { + // Create stderr pipe failed. + close(FDStdIn[0]); + close(FDStdIn[1]); + close(FDStdOut[0]); + close(FDStdOut[1]); + return Env.ExitCode; + } + + // Create a child process for executing command. + pid_t PID = fork(); + if (PID == -1) { + // Create process failed. + close(FDStdIn[0]); + close(FDStdIn[1]); + close(FDStdOut[0]); + close(FDStdOut[1]); + close(FDStdErr[0]); + close(FDStdErr[1]); + return Env.ExitCode; + } else if (PID == 0) { + // Child process. Setup pipes. + dup2(FDStdIn[0], 0); + dup2(FDStdOut[1], 1); + dup2(FDStdErr[1], 2); + close(FDStdIn[0]); + close(FDStdIn[1]); + close(FDStdOut[0]); + close(FDStdOut[1]); + close(FDStdErr[0]); + close(FDStdErr[1]); + + // Prepare arguments and environment variables. + std::vector EnvStr; + for (auto &It : Env.Envs) { + EnvStr.push_back(It.first + "=" + It.second); + } + std::vector Argv, Envp; + Argv.push_back(Env.Name.data()); + std::transform(Env.Args.begin(), Env.Args.end(), std::back_inserter(Argv), + [](std::string &S) { return S.data(); }); + std::transform(EnvStr.begin(), EnvStr.end(), std::back_inserter(Envp), + [](std::string &S) { return S.data(); }); + Argv.push_back(nullptr); + Envp.push_back(nullptr); +#if defined(__GLIBC_PREREQ) +#if __GLIBC_PREREQ(2, 11) + if (execvpe(Env.Name.c_str(), &Argv[0], &Envp[0]) == -1) { +#else + if (execve(Env.Name.c_str(), &Argv[0], &Envp[0]) == -1) { +#endif +#else + if (execve(Env.Name.c_str(), &Argv[0], &Envp[0]) == -1) { +#endif + switch (errno) { + case EACCES: + spdlog::error("Permission denied."); + break; + case ENOENT: + spdlog::error("Command not found."); + break; + default: + spdlog::error("Unknown error."); + break; + } + _exit(-1); + } + } else { + // Parent process. Close unused file descriptors. + close(FDStdIn[0]); + close(FDStdOut[1]); + close(FDStdErr[1]); + + // Send inputs. + uint32_t WBytes = 0; + while (WBytes < Env.StdIn.size()) { + uint32_t WriteNum = + std::min(static_cast(PIPE_BUF), Env.StdIn.size() - WBytes); + if (auto Res = write(FDStdIn[1], &Env.StdIn[WBytes], WriteNum); Res > 0) { + WBytes += Res; + } else { + break; + } + } + close(FDStdIn[1]); + + // Waiting for child process and get outputs. + uint8_t Buf[PIPE_BUF]; + ssize_t RBytes; + int ChildStat; + struct timeval TStart, TCurr; + gettimeofday(&TStart, NULL); + while (true) { + gettimeofday(&TCurr, NULL); + if ((TCurr.tv_sec - TStart.tv_sec) * 1000U + + (TCurr.tv_usec - TStart.tv_usec) / 1000000U > + Env.TimeOut) { + // Over timeout. Interrupt child process. + kill(PID, SIGKILL); + Env.ExitCode = static_cast(ETIMEDOUT); + break; + } + + // Wait for child process. + pid_t WPID = waitpid(PID, &ChildStat, WNOHANG); + if (WPID == -1) { + // waitpid failed. + Env.ExitCode = static_cast(EINVAL); + break; + } else if (WPID > 0) { + // Child process returned. + Env.ExitCode = static_cast(WEXITSTATUS(ChildStat)); + break; + } + + // Read stdout and stderr. + fd_set FDSet; + int NFD = std::max(FDStdOut[0], FDStdErr[0]) + 1; + FD_ZERO(&FDSet); + FD_SET(FDStdOut[0], &FDSet); + FD_SET(FDStdErr[0], &FDSet); + struct timeval TSelect = {.tv_sec = 0, .tv_usec = 0}; + if (select(NFD, &FDSet, NULL, NULL, &TSelect) > 0) { + if (FD_ISSET(FDStdOut[0], &FDSet)) { + if (RBytes = read(FDStdOut[0], Buf, sizeof(Buf)); RBytes > 0) { + Env.StdOut.reserve(Env.StdOut.size() + RBytes); + std::copy_n(Buf, RBytes, std::back_inserter(Env.StdOut)); + } + } + if (FD_ISSET(FDStdErr[0], &FDSet)) { + if (RBytes = read(FDStdErr[0], Buf, sizeof(Buf)); RBytes > 0) { + Env.StdErr.reserve(Env.StdErr.size() + RBytes); + std::copy_n(Buf, RBytes, std::back_inserter(Env.StdErr)); + } + } + } + usleep(Env.DEFAULT_POLLTIME * 1000); + } + + // Read remained stdout and stderr. + do { + RBytes = read(FDStdOut[0], Buf, sizeof(Buf)); + if (RBytes > 0) { + Env.StdOut.reserve(Env.StdOut.size() + RBytes); + std::copy_n(Buf, RBytes, std::back_inserter(Env.StdOut)); + } + } while (RBytes > 0); + do { + RBytes = read(FDStdErr[0], Buf, sizeof(Buf)); + if (RBytes > 0) { + Env.StdErr.reserve(Env.StdErr.size() + RBytes); + std::copy_n(Buf, RBytes, std::back_inserter(Env.StdErr)); + } + } while (RBytes > 0); + close(FDStdOut[0]); + close(FDStdErr[0]); + } + + // Reset inputs. + Env.Name.clear(); + Env.Args.clear(); + Env.Envs.clear(); + Env.StdIn.clear(); + Env.TimeOut = Env.DEFAULT_TIMEOUT; + return Env.ExitCode; +#elif WASMEDGE_OS_WINDOWS + spdlog::error("wasmedge_process doesn't support windows now."); + return Unexpect(ErrCode::ExecutionFailed); +#endif +} + +Expect +WasmEdgeProcessGetExitCode::body(Runtime::Instance::MemoryInstance *) { + return Env.ExitCode; +} + +Expect +WasmEdgeProcessGetStdOutLen::body(Runtime::Instance::MemoryInstance *) { + return static_cast(Env.StdOut.size()); +} + +Expect +WasmEdgeProcessGetStdOut::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t BufPtr) { + // Check memory instance from module. + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + + char *Buf = MemInst->getPointer(BufPtr); + std::copy_n(Env.StdOut.begin(), Env.StdOut.size(), Buf); + return {}; +} + +Expect +WasmEdgeProcessGetStdErrLen::body(Runtime::Instance::MemoryInstance *) { + return static_cast(Env.StdErr.size()); +} + +Expect +WasmEdgeProcessGetStdErr::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t BufPtr) { + // Check memory instance from module. + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + + char *Buf = MemInst->getPointer(BufPtr); + std::copy_n(Env.StdErr.begin(), Env.StdErr.size(), Buf); + return {}; +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/processfunc.h b/plugins/wasmedge_process/processfunc.h new file mode 100644 index 00000000..aef2e4bc --- /dev/null +++ b/plugins/wasmedge_process/processfunc.h @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "common/defines.h" +#include "processbase.h" +#include + +namespace WasmEdge { +namespace Host { + +class WasmEdgeProcessSetProgName + : public WasmEdgeProcess { +public: + WasmEdgeProcessSetProgName(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t NamePtr, uint32_t NameLen); +}; + +class WasmEdgeProcessAddArg : public WasmEdgeProcess { +public: + WasmEdgeProcessAddArg(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *MemInst, uint32_t ArgPtr, + uint32_t ArgLen); +}; + +class WasmEdgeProcessAddEnv : public WasmEdgeProcess { +public: + WasmEdgeProcessAddEnv(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t EnvNamePtr, uint32_t EnvNameLen, + uint32_t EnvValPtr, uint32_t EnvValLen); +}; + +class WasmEdgeProcessAddStdIn + : public WasmEdgeProcess { +public: + WasmEdgeProcessAddStdIn(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *MemInst, uint32_t BufPtr, + uint32_t BufLen); +}; + +class WasmEdgeProcessSetTimeOut + : public WasmEdgeProcess { +public: + WasmEdgeProcessSetTimeOut(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *MemInst, uint32_t Time); +}; + +class WasmEdgeProcessRun : public WasmEdgeProcess { +public: + WasmEdgeProcessRun(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *MemInst); +}; + +class WasmEdgeProcessGetExitCode + : public WasmEdgeProcess { +public: + WasmEdgeProcessGetExitCode(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *MemInst); +}; + +class WasmEdgeProcessGetStdOutLen + : public WasmEdgeProcess { +public: + WasmEdgeProcessGetStdOutLen(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *MemInst); +}; + +class WasmEdgeProcessGetStdOut + : public WasmEdgeProcess { +public: + WasmEdgeProcessGetStdOut(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t BufPtr); +}; + +class WasmEdgeProcessGetStdErrLen + : public WasmEdgeProcess { +public: + WasmEdgeProcessGetStdErrLen(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *MemInst); +}; + +class WasmEdgeProcessGetStdErr + : public WasmEdgeProcess { +public: + WasmEdgeProcessGetStdErr(WasmEdgeProcessEnvironment &HostEnv) + : WasmEdgeProcess(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t BufPtr); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/processmodule.cpp b/plugins/wasmedge_process/processmodule.cpp new file mode 100644 index 00000000..163a1cf2 --- /dev/null +++ b/plugins/wasmedge_process/processmodule.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "processmodule.h" +#include "processfunc.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasmEdgeProcessModule::WasmEdgeProcessModule() + : ModuleInstance("wasmedge_process") { + addHostFunc("wasmedge_process_set_prog_name", + std::make_unique(Env)); + addHostFunc("wasmedge_process_add_arg", + std::make_unique(Env)); + addHostFunc("wasmedge_process_add_env", + std::make_unique(Env)); + addHostFunc("wasmedge_process_add_stdin", + std::make_unique(Env)); + addHostFunc("wasmedge_process_set_timeout", + std::make_unique(Env)); + addHostFunc("wasmedge_process_run", + std::make_unique(Env)); + addHostFunc("wasmedge_process_get_exit_code", + std::make_unique(Env)); + addHostFunc("wasmedge_process_get_stdout_len", + std::make_unique(Env)); + addHostFunc("wasmedge_process_get_stdout", + std::make_unique(Env)); + addHostFunc("wasmedge_process_get_stderr_len", + std::make_unique(Env)); + addHostFunc("wasmedge_process_get_stderr", + std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_process/processmodule.h b/plugins/wasmedge_process/processmodule.h new file mode 100644 index 00000000..491ae2ce --- /dev/null +++ b/plugins/wasmedge_process/processmodule.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "processenv.h" +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeProcessModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeProcessModule(); + + WasmEdgeProcessEnvironment &getEnv() { return Env; } + +private: + WasmEdgeProcessEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge From 3b26fc5d935d8419db506f99a53f9cb09619c2fb Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 23 May 2022 23:13:39 +0800 Subject: [PATCH 040/623] [Examples] Move the examples in the `tools` into the `examples`. Signed-off-by: YiYing He --- utils/docker/Dockerfile.appdev_aarch64 | 8 ++++---- utils/docker/Dockerfile.appdev_x86_64 | 10 +++++----- utils/docker/build-appdev.md | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/utils/docker/Dockerfile.appdev_aarch64 b/utils/docker/Dockerfile.appdev_aarch64 index ee56f550..0fc33825 100644 --- a/utils/docker/Dockerfile.appdev_aarch64 +++ b/utils/docker/Dockerfile.appdev_aarch64 @@ -15,9 +15,9 @@ RUN npm install wasmedge-core RUN mkdir -p /root/examples WORKDIR /root/examples -RUN wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/hello.wasm &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/qjs.wasm &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/hello.js &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/repl.js +RUN wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/wasm/hello.wasm &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/qjs.wasm &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/hello.js &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/repl.js ENTRYPOINT ["/bin/bash", "-l"] diff --git a/utils/docker/Dockerfile.appdev_x86_64 b/utils/docker/Dockerfile.appdev_x86_64 index 56e3b1be..e97b245a 100644 --- a/utils/docker/Dockerfile.appdev_x86_64 +++ b/utils/docker/Dockerfile.appdev_x86_64 @@ -15,11 +15,11 @@ RUN npm install wasmedge-extensions RUN mkdir -p /root/examples WORKDIR /root/examples -RUN wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/hello.wasm &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/qjs.wasm &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/hello.js &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/repl.js &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/tools/wasmedge/examples/js/qjs_tf.wasm &&\ +RUN wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/wasm/hello.wasm &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/qjs.wasm &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/hello.js &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/repl.js &&\ + wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/qjs_tf.wasm &&\ wget -O tf_image_classify.js https://raw.githubusercontent.com/second-state/wasmedge-quickjs/main/example_js/tensorflow_lite_demo/main.js &&\ wget https://github.com/second-state/wasmedge-quickjs/raw/main/example_js/tensorflow_lite_demo/lite-model_aiy_vision_classifier_food_V1_1.tflite &&\ wget https://github.com/second-state/wasmedge-quickjs/raw/main/example_js/tensorflow_lite_demo/food.jpg &&\ diff --git a/utils/docker/build-appdev.md b/utils/docker/build-appdev.md index f9ecbe72..74cdf34f 100644 --- a/utils/docker/build-appdev.md +++ b/utils/docker/build-appdev.md @@ -29,7 +29,7 @@ It installs the following components. ## Examples -Hello World. [See more simple examples](https://github.com/WasmEdge/WasmEdge/tree/master/tools/wasmedge/examples) +Hello World. [See more simple examples](https://github.com/WasmEdge/WasmEdge/tree/master/examples/wasm) ```bash $ wasmedge hello.wasm world @@ -46,7 +46,7 @@ hello world ``` -Here are some JavaScript examples. [See more](https://github.com/WasmEdge/WasmEdge/tree/master/tools/wasmedge/examples/js) +Here are some JavaScript examples. [See more](https://github.com/WasmEdge/WasmEdge/tree/master/examples/js) ```bash $ wasmedge --dir .:. qjs.wasm hello.js 1 2 3 From 942865b968d5f61ffbdc7151711352113720eae0 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 16 Jun 2022 22:10:52 +0800 Subject: [PATCH 041/623] [CI] Only install wasmedge_process on Linux. Signed-off-by: YiYing He --- plugins/wasmedge_process/CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt index 8a9e4b62..4f41af5f 100644 --- a/plugins/wasmedge_process/CMakeLists.txt +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -32,7 +32,10 @@ target_link_libraries(wasmedgePluginWasmEdgeProcess wasmedgeSystem ) -install(TARGETS wasmedgePluginWasmEdgeProcess DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +# Only Linux systems support wasmedge_process now. +if(CMAKE_SYSTEM_NAME MATCHES "Linux") + install(TARGETS wasmedgePluginWasmEdgeProcess DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +endif() wasmedge_add_library(wasmedgeHostModuleWasmEdgeProcess processenv.cpp From dd901cc754f5445d4b5a1d102108649eade7e4c9 Mon Sep 17 00:00:00 2001 From: Jianbai Ye Date: Thu, 10 Mar 2022 20:52:32 +0800 Subject: [PATCH 042/623] [WASI] support OpenVINO backend for WASI-NN (#1340) Signed-off-by: Jianbai Ye --- utils/docker/build-wasinn-ubuntu-openvino.sh | 16 ++++++++++++++++ utils/docker/test-wasinn-ubuntu-openvino.sh | 7 +++++++ 2 files changed, 23 insertions(+) create mode 100644 utils/docker/build-wasinn-ubuntu-openvino.sh create mode 100644 utils/docker/test-wasinn-ubuntu-openvino.sh diff --git a/utils/docker/build-wasinn-ubuntu-openvino.sh b/utils/docker/build-wasinn-ubuntu-openvino.sh new file mode 100644 index 00000000..544cfc79 --- /dev/null +++ b/utils/docker/build-wasinn-ubuntu-openvino.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +set -e +echo "Installing OpenVINO with version ${OPENVINO_VERSION}" +curl -sSL https://apt.repos.intel.com/openvino/$OPENVINO_YEAR/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR >./GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR +apt-key add ./GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR +echo "deb https://apt.repos.intel.com/openvino/$OPENVINO_YEAR all main" | tee /etc/apt/sources.list.d/intel-openvino-$OPENVINO_YEAR.list +apt update +apt install -y intel-openvino-runtime-ubuntu20-$OPENVINO_VERSION +source /opt/intel/openvino_2021/bin/setupvars.sh +ldconfig + +cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE -DWASMEDGE_BUILD_TESTS=ON -DWASMEDGE_WASINN_BUILD_OPENVINO=ON . +cmake --build build diff --git a/utils/docker/test-wasinn-ubuntu-openvino.sh b/utils/docker/test-wasinn-ubuntu-openvino.sh new file mode 100644 index 00000000..6500a296 --- /dev/null +++ b/utils/docker/test-wasinn-ubuntu-openvino.sh @@ -0,0 +1,7 @@ +source /opt/intel/openvino_2021/bin/setupvars.sh +ldconfig +export LD_LIBRARY_PATH="$(pwd)/build/lib/api:$LD_LIBRARY_PATH" + +cd build +ctest +cd - From 607297318f5ccc4cfe325461ca3fe2987c4355d3 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Tue, 7 Jun 2022 11:47:55 +0800 Subject: [PATCH 043/623] [Test] Update the WASI-NN test scripts. Signed-off-by: YiYing He --- utils/docker/build-wasinn-ubuntu-openvino.sh | 15 +++++++++----- utils/docker/test-ubuntu.sh | 21 -------------------- utils/docker/test-wasinn-ubuntu-openvino.sh | 0 3 files changed, 10 insertions(+), 26 deletions(-) mode change 100644 => 100755 utils/docker/build-wasinn-ubuntu-openvino.sh delete mode 100755 utils/docker/test-ubuntu.sh mode change 100644 => 100755 utils/docker/test-wasinn-ubuntu-openvino.sh diff --git a/utils/docker/build-wasinn-ubuntu-openvino.sh b/utils/docker/build-wasinn-ubuntu-openvino.sh old mode 100644 new mode 100755 index 544cfc79..6f3bc7a3 --- a/utils/docker/build-wasinn-ubuntu-openvino.sh +++ b/utils/docker/build-wasinn-ubuntu-openvino.sh @@ -4,13 +4,18 @@ set -e echo "Installing OpenVINO with version ${OPENVINO_VERSION}" -curl -sSL https://apt.repos.intel.com/openvino/$OPENVINO_YEAR/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR >./GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR -apt-key add ./GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR -echo "deb https://apt.repos.intel.com/openvino/$OPENVINO_YEAR all main" | tee /etc/apt/sources.list.d/intel-openvino-$OPENVINO_YEAR.list +curl -sSL https://apt.repos.intel.com/openvino/$OPENVINO_YEAR/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR | gpg --dearmor > /usr/share/keyrings/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR.gpg +echo "deb [signed-by=/usr/share/keyrings/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR.gpg] https://apt.repos.intel.com/openvino/$OPENVINO_YEAR all main" | tee /etc/apt/sources.list.d/intel-openvino-$OPENVINO_YEAR.list apt update apt install -y intel-openvino-runtime-ubuntu20-$OPENVINO_VERSION source /opt/intel/openvino_2021/bin/setupvars.sh ldconfig - -cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE -DWASMEDGE_BUILD_TESTS=ON -DWASMEDGE_WASINN_BUILD_OPENVINO=ON . +git config --global --add safe.directory $(pwd) +if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE -DWASMEDGE_BUILD_TESTS=ON -DWASMEDGE_WASINN_BUILD_OPENVINO=ON .; then + echo === CMakeOutput.log === + cat build/CMakeFiles/CMakeOutput.log + echo === CMakeError.log === + cat build/CMakeFiles/CMakeError.log + exit 1 +fi cmake --build build diff --git a/utils/docker/test-ubuntu.sh b/utils/docker/test-ubuntu.sh deleted file mode 100755 index 24d54f5c..00000000 --- a/utils/docker/test-ubuntu.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env bash -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC - -export CC=gcc -export CXX=g++ -curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.bz2 -echo "f0397ba6e982c4450f27bf32a2a83292aba035b827a5623a14636ea583318c41 boost_1_76_0.tar.bz2" | sha256sum -c -bzip2 -dc boost_1_76_0.tar.bz2 | tar -xf - -if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Debug -DWASMEDGE_BUILD_TESTS=ON -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_76_0/; then - echo === CMakeOutput.log === - cat build/CMakeFiles/CMakeOutput.log - echo === CMakeError.log === - cat build/CMakeFiles/CMakeError.log - exit 1 -fi -cmake --build build -export LD_LIBRARY_PATH="$(pwd)/build/lib/api:$LD_LIBRARY_PATH" -cd build -ctest --output-on-failure -cd - diff --git a/utils/docker/test-wasinn-ubuntu-openvino.sh b/utils/docker/test-wasinn-ubuntu-openvino.sh old mode 100644 new mode 100755 From b46996d86670007381de1efb8562afdbecd688d1 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 15 Jun 2022 16:02:33 +0800 Subject: [PATCH 044/623] [WASI] Implement the plug-in version of wasi-nn. Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 1 + plugins/wasi_nn/CMakeLists.txt | 83 +++++ plugins/wasi_nn/wasinnbase.h | 23 ++ plugins/wasi_nn/wasinnenv.cpp | 40 ++ plugins/wasi_nn/wasinnenv.h | 127 +++++++ plugins/wasi_nn/wasinnfunc.cpp | 601 +++++++++++++++++++++++++++++++ plugins/wasi_nn/wasinnfunc.h | 52 +++ plugins/wasi_nn/wasinnmodule.cpp | 20 + plugins/wasi_nn/wasinnmodule.h | 21 ++ 9 files changed, 968 insertions(+) create mode 100644 plugins/wasi_nn/CMakeLists.txt create mode 100644 plugins/wasi_nn/wasinnbase.h create mode 100644 plugins/wasi_nn/wasinnenv.cpp create mode 100644 plugins/wasi_nn/wasinnenv.h create mode 100644 plugins/wasi_nn/wasinnfunc.cpp create mode 100644 plugins/wasi_nn/wasinnfunc.h create mode 100644 plugins/wasi_nn/wasinnmodule.cpp create mode 100644 plugins/wasi_nn/wasinnmodule.h diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 6725682f..c935c71b 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -2,3 +2,4 @@ # SPDX-FileCopyrightText: 2019-2022 Second State INC add_subdirectory(wasmedge_process) +add_subdirectory(wasi_nn) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt new file mode 100644 index 00000000..94c7a150 --- /dev/null +++ b/plugins/wasi_nn/CMakeLists.txt @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_library(wasmedgePluginWasiNN + SHARED + wasinnenv.cpp + wasinnfunc.cpp + wasinnmodule.cpp +) + +target_compile_options(wasmedgePluginWasiNN + PUBLIC + -DWASMEDGE_PLUGIN +) + +if(CMAKE_SYSTEM_NAME MATCHES "Darwin") + target_link_options(wasmedgePluginWasiNN + PUBLIC + -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterC1EPKNS0_6Plugin16PluginDescriptorE + -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterD1Ev + ) +endif() + +target_include_directories(wasmedgePluginWasiNN + PUBLIC + $ +) + +target_link_libraries(wasmedgePluginWasiNN + PUBLIC + wasmedgeCommon + wasmedgeSystem +) + +# Only Linux systems support wasi-nn with openVINO now. +if(CMAKE_SYSTEM_NAME MATCHES "Linux" AND WASMEDGE_WASINN_BUILD_OPENVINO) + install(TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +endif() + +wasmedge_add_library(wasmedgeHostModuleWasiNN + wasinnenv.cpp + wasinnfunc.cpp + wasinnmodule.cpp +) + +target_include_directories(wasmedgeHostModuleWasiNN + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries(wasmedgeHostModuleWasiNN + PUBLIC + wasmedgeCommon + wasmedgeSystem + wasmedgePlugin +) + +if(CMAKE_SYSTEM_NAME MATCHES "Linux") + target_link_options(wasmedgeHostModuleWasiNN + PUBLIC + -u_ZN8WasmEdge4Host6WASINN17WasiNNEnvironment8RegisterE + ) +elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") + target_link_options(wasmedgeHostModuleWasiNN + PUBLIC + -u__ZN8WasmEdge4Host6WASINN17WasiNNEnvironment8RegisterE + ) +endif() + +# add backends building flags +if(WASMEDGE_WASINN_BUILD_OPENVINO) + message(STATUS "Build OpenVINO backend for WASI-NN") + add_definitions(-DWASMEDGE_WASINN_BUILD_OPENVINO) + find_package(InferenceEngine REQUIRED) + target_link_libraries(wasmedgeHostModuleWasiNN + PUBLIC + ${InferenceEngine_LIBRARIES} + ) + target_link_libraries(wasmedgePluginWasiNN + PUBLIC + ${InferenceEngine_LIBRARIES} + ) +endif() diff --git a/plugins/wasi_nn/wasinnbase.h b/plugins/wasi_nn/wasinnbase.h new file mode 100644 index 00000000..b5daaf84 --- /dev/null +++ b/plugins/wasi_nn/wasinnbase.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "common/errcode.h" +#include "runtime/hostfunc.h" +#include "wasinnenv.h" + +namespace WasmEdge { +namespace Host { + +template class WasiNN : public Runtime::HostFunction { +public: + WasiNN(WASINN::WasiNNEnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WASINN::WasiNNEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp new file mode 100644 index 00000000..b5b2cd0c --- /dev/null +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "wasinnenv.h" +#include "wasinnmodule.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +void addOptions(PO::ArgumentParser &) noexcept {} + +Runtime::Instance::ModuleInstance *create(void) noexcept { + return new WasiNNModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasi_nn", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 10, 1, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasi_nn", + .Description = "", + .Create = create, + }, + }, + .AddOptions = addOptions, +}; + +} // namespace + +Plugin::PluginRegister WASINN::WasiNNEnvironment::Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h new file mode 100644 index 00000000..c61f9d20 --- /dev/null +++ b/plugins/wasi_nn/wasinnenv.h @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include +#include + +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#include "common/log.h" +#include +#endif + +namespace WasmEdge { +namespace Host { +namespace WASINN { + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + MissingMemory = 2, // Caller module is missing a memory export. + Busy = 3 // Device or resource busy. +}; + +enum class Backend : uint8_t { + OpenVINO = 0, +}; + +class Graph { +public: +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO + Graph() = delete; + Graph(Backend BE) noexcept + : GraphBackend(BE), OpenVINONetwork(nullptr), + OpenVINOExecNetwork(nullptr), OpenVINOWeightBlob(nullptr) {} + ~Graph() noexcept { + if (OpenVINONetwork) { + ie_network_free(&OpenVINONetwork); + } + if (OpenVINOExecNetwork) { + ie_exec_network_free(&OpenVINOExecNetwork); + } + if (OpenVINOWeightBlob) { + ie_blob_free(&OpenVINOWeightBlob); + } + for (auto &I : OpenVINOInputNames) { + if (I) { + ie_network_name_free(&I); + } + } + for (auto &I : OpenVINOOutputNames) { + if (I) { + ie_network_name_free(&I); + } + } + } +#else + Graph() noexcept = default; +#endif + + Backend GraphBackend; +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO + ie_network_t *OpenVINONetwork; + ie_executable_network_t *OpenVINOExecNetwork; + ie_blob_t *OpenVINOWeightBlob; + std::vector OpenVINOInputNames; + std::vector OpenVINOOutputNames; +#endif +}; + +class Context { +public: + Context() = delete; +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO + Context(Graph &G, ie_infer_request_t *InferReq) noexcept + : GraphRef(G), OpenVINOInferRequest(InferReq) {} + ~Context() noexcept { + if (OpenVINOInferRequest) { + ie_infer_request_free(&OpenVINOInferRequest); + } + } +#else + Context(Graph &G) noexcept : GraphRef(G) {} +#endif + + Graph &GraphRef; +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO + ie_infer_request_t *OpenVINOInferRequest; +#endif +}; + +class WasiNNEnvironment { +public: + WasiNNEnvironment() noexcept { +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO + if (ie_core_create("", &OpenVINOCore) != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Error happened when initializing OpenVINO core."); + } +#endif + NNGraph.reserve(16U); + NNContext.reserve(16U); + } + ~WasiNNEnvironment() noexcept { + NNContext.clear(); + NNGraph.clear(); +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO + if (OpenVINOCore) { + ie_core_free(&OpenVINOCore); + } +#endif + } + + std::vector NNGraph; + std::vector NNContext; +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO + ie_core_t *OpenVINOCore = nullptr; +#endif + + static Plugin::PluginRegister Register; +}; + +} // namespace WASINN +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp new file mode 100644 index 00000000..cf840b4b --- /dev/null +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -0,0 +1,601 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "wasinnfunc.h" +#include "common/log.h" + +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#include +#include + +#include +#endif + +namespace WasmEdge { +namespace Host { + +Expect WasiNNLoad::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t BuilderPtr [[maybe_unused]], + uint32_t BuilderLen [[maybe_unused]], + uint32_t Encoding, + uint32_t Target [[maybe_unused]], + uint32_t GraphIdPtr [[maybe_unused]]) { + // Check memory instance from module. + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + + if (Encoding == static_cast(WASINN::Backend::OpenVINO)) { +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO + // The OpenVINO core must be initialized in constructor. + if (unlikely(Env.OpenVINOCore == nullptr)) { + spdlog::error("[WASI-NN] OpenVINO core not initialized."); + return static_cast(WASINN::ErrNo::MissingMemory); + } + + // Check the return value: GraphIdPtr should be valid. + uint32_t *GraphId = MemInst->getPointer(GraphIdPtr, 1); + if (unlikely(GraphId == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the return GraphID memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + // The graph builder length must be 2. + if (BuilderLen != 2) { + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 2", + BuilderLen); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + // Get and check the device name string. + std::string DeviceName; + switch (Target) { + case 0: + DeviceName = "CPU"; + break; + case 1: + DeviceName = "GPU"; + break; + case 2: + DeviceName = "TPU"; + break; + default: + DeviceName = ""; + } + if (DeviceName.length() == 0) { + spdlog::error("[WASI-NN] Device target {:d} not support!", Target); + return static_cast(WASINN::ErrNo::InvalidArgument); + } else { + spdlog::debug("[WASI-NN] Using device: {:s}", DeviceName); + } + + // Get the graph builders. + // GraphBuilders' Layout: + // | builder-0 | builder-0 len | builder-1 | builder-1 len | ... + uint32_t *GraphBuilders = + MemInst->getPointer(BuilderPtr, BuilderLen * 2); + if (unlikely(GraphBuilders == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the GraphBuilder memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + // Get the XML and Weight raw buffer from memory instance. + // Builder-0: the XML string + // Builder-1: the Weight binary + uint32_t XMLStringLen = GraphBuilders[1]; + uint32_t WeightsBinLen = GraphBuilders[3]; + uint8_t *XMLPtr = + MemInst->getPointer(GraphBuilders[0], XMLStringLen); + uint8_t *BinPtr = + MemInst->getPointer(GraphBuilders[2], WeightsBinLen); + if (unlikely(XMLPtr == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the XML memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + if (unlikely(BinPtr == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Weignt memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + // Add a new graph. + Env.NNGraph.emplace_back(static_cast(Encoding)); + auto &Graph = Env.NNGraph.back(); + + // Create the weights blob memory. + tensor_desc_t WeightsDesc{ + layout_e::ANY, {1, {WeightsBinLen}}, precision_e::U8}; + IEStatusCode Status = + ie_blob_make_memory(&WeightsDesc, &(Graph.OpenVINOWeightBlob)); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to create the model's weight blob, error code: {}", + Status); + Env.NNGraph.pop_back(); + return static_cast(WASINN::ErrNo::Busy); + } + + // Copy the weights buffer to the blob. + ie_blob_buffer_t BlobBuffer; + Status = ie_blob_get_buffer(Graph.OpenVINOWeightBlob, &BlobBuffer); + if (unlikely(Status != IEStatusCode::OK)) { + spdlog::error( + "[WASI-NN] Unable to find the weight blob's buffer, error code: {}", + Status); + Env.NNGraph.pop_back(); + return static_cast(WASINN::ErrNo::MissingMemory); + } + std::copy_n(BinPtr, WeightsBinLen, + static_cast(BlobBuffer.buffer)); + + // Read network from memory. + Status = ie_core_read_network_from_memory( + Env.OpenVINOCore, XMLPtr, XMLStringLen, Graph.OpenVINOWeightBlob, + &(Graph.OpenVINONetwork)); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to read network from the XML and " + "Weights, error code: {}", + Status); + Env.NNGraph.pop_back(); + return static_cast(WASINN::ErrNo::Busy); + } + + // Get the network input and output size. + size_t NetworkInputSize = 0; + Status = + ie_network_get_inputs_number(Graph.OpenVINONetwork, &NetworkInputSize); + if (unlikely(Status != IEStatusCode::OK)) { + spdlog::error("[WASI-NN] Unable to get the inputs number from the " + "network, error code: {}", + Status); + Env.NNGraph.pop_back(); + return static_cast(WASINN::ErrNo::MissingMemory); + } + spdlog::debug("[WASI-NN] Got input size: {}", NetworkInputSize); + size_t NetworkOutputSize = 0; + Status = ie_network_get_outputs_number(Graph.OpenVINONetwork, + &NetworkOutputSize); + if (unlikely(Status != IEStatusCode::OK)) { + spdlog::error("[WASI-NN] Unable to get the outputs number from the " + "network, error code: {}", + Status); + Env.NNGraph.pop_back(); + return static_cast(WASINN::ErrNo::MissingMemory); + } + spdlog::debug("[WASI-NN] Got output size: {}", NetworkOutputSize); + + // Get and store the input and output names. + Graph.OpenVINOInputNames.resize(NetworkInputSize, nullptr); + for (size_t I = 0; I < NetworkInputSize; I++) { + Status = ie_network_get_input_name(Graph.OpenVINONetwork, I, + &(Graph.OpenVINOInputNames[I])); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to find input name correctly with " + "Index {}, error code: {}", + I, Status); + Env.NNGraph.pop_back(); + return static_cast(WASINN::ErrNo::MissingMemory); + } + spdlog::debug("[WASI-NN] Got input name: {}", + Graph.OpenVINOInputNames[I]); + } + Graph.OpenVINOOutputNames.resize(NetworkOutputSize, nullptr); + for (size_t I = 0; I < NetworkOutputSize; I++) { + Status = ie_network_get_output_name(Graph.OpenVINONetwork, I, + &(Graph.OpenVINOOutputNames[I])); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to find output name correctly with " + "Index {}, error code: {}", + I, Status); + Env.NNGraph.pop_back(); + return static_cast(WASINN::ErrNo::MissingMemory); + } + spdlog::debug("[WASI-NN] Got output name: {}", + Graph.OpenVINOOutputNames[I]); + } + + // Set the input layout. + // FIXME: this is a temporary workaround. We need a more eligant way to + // specify the layout in the long run. However, without this newer versions + // of OpenVINO will fail due to parameter mismatch. + for (size_t I = 0; I < NetworkInputSize; I++) { + // More layouts should be supported. + Status = ie_network_set_input_layout( + Graph.OpenVINONetwork, Graph.OpenVINOInputNames[I], layout_e::NHWC); + spdlog::debug("[WASI-NN] Setting [{}] to NHWC", + Graph.OpenVINOInputNames[I]); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to set input layout with the input " + "name {}, error code: {}", + Graph.OpenVINOInputNames[I], Status); + Env.NNGraph.pop_back(); + return static_cast(WASINN::ErrNo::MissingMemory); + } + } + + // Load network. + ie_config_t Config = {nullptr, nullptr, nullptr}; + Status = ie_core_load_network(Env.OpenVINOCore, Graph.OpenVINONetwork, + DeviceName.c_str(), &Config, + &(Graph.OpenVINOExecNetwork)); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to create executable Network, error code: {}", + Status); + Env.NNGraph.pop_back(); + return static_cast(WASINN::ErrNo::Busy); + } + + // Store the loaded graph. + *GraphId = Env.NNGraph.size() - 1; + + return static_cast(WASINN::ErrNo::Success); +#else + spdlog::error("[WASI-NN] OpenVINO backend is not built. use " + "-DWASMEDGE_WASINN_BUILD_OPENVINO=ON" + "to build it."); +#endif + } else { + spdlog::error("[WASI-NN] Current backend is not supported."); + } + return static_cast(WASINN::ErrNo::InvalidArgument); +} + +Expect +WasiNNInitExecCtx::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t GraphId, + uint32_t ContextPtr [[maybe_unused]]) { + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + + if (Env.NNGraph.size() <= GraphId) { + spdlog::error("[WASI-NN] init_execution_context: Graph Id does not exist."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + if (Env.NNGraph[GraphId].GraphBackend == WASINN::Backend::OpenVINO) { +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO + // Check the return value: Context should be valid. + uint32_t *Context = MemInst->getPointer(ContextPtr, 1); + if (unlikely(Context == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Context memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + // Check the network and the execution network with the graph ID. + if (Env.NNGraph[GraphId].OpenVINONetwork == nullptr || + Env.NNGraph[GraphId].OpenVINOExecNetwork == nullptr) { + spdlog::error("[WASI-NN] Model for Graph:{} is empty!", GraphId); + return static_cast(WASINN::ErrNo::MissingMemory); + } + + // Create the infer request. + ie_infer_request_t *InferRequest = nullptr; + IEStatusCode Status = ie_exec_network_create_infer_request( + Env.NNGraph[GraphId].OpenVINOExecNetwork, &InferRequest); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to create openvino session"); + return static_cast(WASINN::ErrNo::Busy); + } + + *Context = Env.NNContext.size(); + Env.NNContext.emplace_back(Env.NNGraph[GraphId], InferRequest); + + return static_cast(WASINN::ErrNo::Success); +#else + spdlog::error("[WASI-NN] OpenVINO backend is not built. define " + "-DWASMEDGE_WASINN_BUILD_OPENVINO " + "to build it."); +#endif + } else { + spdlog::error("[WASI-NN] Current backend is not supported."); + } + return static_cast(WASINN::ErrNo::InvalidArgument); +} + +Expect +WasiNNSetInput::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t Context, uint32_t Index [[maybe_unused]], + uint32_t TensorPtr [[maybe_unused]]) { + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + + if (Env.NNContext.size() <= Context) { + spdlog::error("[WASI-NN] set_input: Execution Context does not exist."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + auto &CxtRef = Env.NNContext[Context]; + if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::OpenVINO) { +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO + // Check the infer request and the network. + auto *Network = CxtRef.GraphRef.OpenVINONetwork; + if (Network == nullptr || CxtRef.OpenVINOInferRequest == nullptr) { + spdlog::error("[WASI-NN] The founded openvino session is empty"); + return static_cast(WASINN::ErrNo::MissingMemory); + } + + // Check the input index. + if (CxtRef.GraphRef.OpenVINOInputNames.size() <= Index) { + spdlog::error( + "[WASI-NN] The input index {} exceeds the inputs number {}.", Index, + CxtRef.GraphRef.OpenVINOInputNames.size()); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + char *InputName = CxtRef.GraphRef.OpenVINOInputNames[Index]; + + // Get the tensor. + // Tensor's Layout: + // | dim buf | dim buf len | rtype | data buf | data buf len | + uint32_t *Tensor = MemInst->getPointer(TensorPtr, 5); + if (unlikely(Tensor == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Tensor memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t DimensionLen = Tensor[1]; + if (DimensionLen > 8) { + spdlog::error( + "[WASI-NN] Tensor dimension is out of range, expect it under 8-dim, " + "but got {}-dim.", + DimensionLen); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t *DimensionBuf = + MemInst->getPointer(Tensor[0], DimensionLen); + if (unlikely(DimensionBuf == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t TensorDataLen = Tensor[4]; + uint8_t *TensorDataBuf = + MemInst->getPointer(Tensor[3], TensorDataLen); + if (unlikely(TensorDataBuf == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t RType = Tensor[2]; + if (RType != 1) { + spdlog::error( + "[WASI-NN] Only F32 inputs and outputs are supported for now."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + // Set the input resize algorithm. + // Mark the input as resizable by setting a resize algorithm. + // In this case we will be able to set an input blob of any shape to an + // infer request. Resizing and layout conversions are executed automatically + // when inferring. + IEStatusCode Status = ie_network_set_input_resize_algorithm( + Network, InputName, RESIZE_BILINEAR); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to set input resize correctly, error code: {}", + Status); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + // Set the input layout. + // More layouts should be supported. + Status = ie_network_set_input_layout(Network, InputName, layout_e::NHWC); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to set input layout correctly, error code: {}", + Status); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + // Set the input precision. + // More types should be supported. + Status = + ie_network_set_input_precision(Network, InputName, precision_e::FP32); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to set input precision correctly, error code: {}", + Status); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + // Set the dimensions and the tensor description. + dimensions_t Dimens; + Dimens.ranks = DimensionLen; + for (size_t I = 0; I < Dimens.ranks; I++) { + Dimens.dims[I] = static_cast(DimensionBuf[I]); + } + tensor_desc_t TensorDesc = {layout_e::NHWC, Dimens, precision_e::FP32}; + + // Create the input blob memory. + ie_blob_t *InputBlob = nullptr; + Status = ie_blob_make_memory(&TensorDesc, &InputBlob); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to allocated input tensor correctly, " + "error code: {}", + Status); + return static_cast(WASINN::ErrNo::Busy); + } + + // Get the blob buffer size and compare with the tensor size. + int BlobSize; + Status = ie_blob_size(InputBlob, &BlobSize); + if (unlikely(Status != IEStatusCode::OK)) { + spdlog::error( + "[WASI-NN] Unable to get the input blob size, error code: {}", + Status); + return static_cast(WASINN::ErrNo::Busy); + } + if (unlikely(static_cast(BlobSize * 4) != TensorDataLen)) { + spdlog::error( + "[WASI-NN] Blob size {} and the Tensor size {} not matched.", + BlobSize * 4, TensorDataLen); + } + + // Copy the data into the input blob buffer. + ie_blob_buffer_t BlobBuffer; + Status = ie_blob_get_buffer(InputBlob, &BlobBuffer); + if (unlikely(Status != IEStatusCode::OK)) { + spdlog::error("[WASI-NN] Unable to find input tensor buffer"); + ie_blob_free(&InputBlob); + return static_cast(WASINN::ErrNo::MissingMemory); + } + std::copy_n(TensorDataBuf, TensorDataLen, + static_cast(BlobBuffer.buffer)); + + // Set input blob. + Status = ie_infer_request_set_blob(CxtRef.OpenVINOInferRequest, InputName, + InputBlob); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to set input tensor to model correctly, " + "error code: {}", + Status); + ie_blob_free(&InputBlob); + return static_cast(WASINN::ErrNo::Busy); + } + + ie_blob_free(&InputBlob); + + return static_cast(WASINN::ErrNo::Success); +#else + spdlog::error("[WASI-NN] OpenVINO backend is not built, use " + "-DWASMEDGE_WASINN_BUILD_OPENVINO=ON" + "to build it."); +#endif + } else { + spdlog::error("[WASI-NN] Current backend is not supported."); + } + return static_cast(WASINN::ErrNo::InvalidArgument); +} + +Expect +WasiNNGetOuput::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t Context, uint32_t Index [[maybe_unused]], + uint32_t OutBufferPtr [[maybe_unused]], + uint32_t OutBufferMaxSize [[maybe_unused]], + uint32_t BytesWrittenPtr [[maybe_unused]]) { + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + + if (Env.NNContext.size() <= Context) { + spdlog::error("[WASI-NN] get_output: Execution Context does not exist"); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + auto &CxtRef = Env.NNContext[Context]; + if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::OpenVINO) { +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO + auto *Network = CxtRef.GraphRef.OpenVINONetwork; + + // Check the output index. + if (CxtRef.GraphRef.OpenVINOOutputNames.size() <= Index) { + spdlog::error( + "[WASI-NN] The output index {} exceeds the outputs number {}.", Index, + CxtRef.GraphRef.OpenVINOOutputNames.size()); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + char *OutputName = CxtRef.GraphRef.OpenVINOOutputNames[Index]; + + // Set output precision. + IEStatusCode Status = + ie_network_set_output_precision(Network, OutputName, precision_e::FP32); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to set output precision correctly with Index:{}", + Index); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + // Get output blob buffer. + ie_blob_t *OutputBlob = nullptr; + Status = ie_infer_request_get_blob(CxtRef.OpenVINOInferRequest, OutputName, + &OutputBlob); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to retrieve output tensor correctly", + Index); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + // Get the blob size and copy the output buffer. + int BlobSize; + Status = ie_blob_size(OutputBlob, &BlobSize); + ie_blob_buffer_t BlobCBuffer; + Status = ie_blob_get_cbuffer(OutputBlob, &BlobCBuffer); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to retrieve output tensor correctly", + Index); + ie_blob_free(&OutputBlob); + return static_cast(WASINN::ErrNo::MissingMemory); + } + uint32_t BytesToWrite = + std::min(static_cast(BlobSize * 4), OutBufferMaxSize); + uint8_t *OutBuffer = + MemInst->getPointer(OutBufferPtr, BytesToWrite); + if (unlikely(OutBuffer == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the Output Buffer memory."); + ie_blob_free(&OutputBlob); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + std::copy_n(static_cast(BlobCBuffer.cbuffer), BytesToWrite, + OutBuffer); + + // Write the bytes written result. + uint32_t *BytesWritten = + MemInst->getPointer(BytesWrittenPtr, 1); + if (unlikely(BytesWritten == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the BytesWritten memory."); + ie_blob_free(&OutputBlob); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + *BytesWritten = BytesToWrite; + + ie_blob_free(&OutputBlob); + + return static_cast(WASINN::ErrNo::Success); +#else + spdlog::error("[WASI-NN] OpenVINO backend is not built. use " + "-DWASMEDGE_WASINN_BUILD_OPENVINO=ON" + "to build it."); +#endif + } else { + spdlog::error("[WASI-NN] Current backend is not supported."); + } + return static_cast(WASINN::ErrNo::InvalidArgument); +} + +Expect WasiNNCompute::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t Context) { + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + + if (Env.NNContext.size() <= Context) { + spdlog::error("[WASI-NN] compute: Execution Context does not exist."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + auto &CxtRef = Env.NNContext[Context]; + if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::OpenVINO) { +#ifdef WASMEDGE_WASINN_BUILD_OPENVINO + IEStatusCode Status = ie_infer_request_infer(CxtRef.OpenVINOInferRequest); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to perform computation correctly, error code: {}", + Status); + return static_cast(WASINN::ErrNo::Busy); + } + return static_cast(WASINN::ErrNo::Success); +#else + spdlog::error("[WASI-NN] OpenVINO backend is not built. use " + "-DWASMEDGE_WASINN_BUILD_OPENVINO=ON" + "to build it."); +#endif + } else { + spdlog::error("[WASI-NN] Current backend is not supported."); + } + + return static_cast(WASINN::ErrNo::InvalidArgument); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h new file mode 100644 index 00000000..49925acf --- /dev/null +++ b/plugins/wasi_nn/wasinnfunc.h @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "runtime/instance/memory.h" +#include "wasinnbase.h" + +#include + +namespace WasmEdge { +namespace Host { + +class WasiNNLoad : public WasiNN { +public: + WasiNNLoad(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *, + uint32_t BuilderPtr, uint32_t BuilderLen, + uint32_t Encoding, uint32_t Target, + uint32_t GraphIdPtr); +}; + +class WasiNNInitExecCtx : public WasiNN { +public: + WasiNNInitExecCtx(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *, uint32_t GraphId, + uint32_t ContextPtr); +}; + +class WasiNNSetInput : public WasiNN { +public: + WasiNNSetInput(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *, uint32_t Context, + uint32_t Index, uint32_t TensorPtr); +}; + +class WasiNNGetOuput : public WasiNN { +public: + WasiNNGetOuput(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *, uint32_t Context, + uint32_t Index, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); +}; + +class WasiNNCompute : public WasiNN { +public: + WasiNNCompute(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *, uint32_t Context); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnmodule.cpp b/plugins/wasi_nn/wasinnmodule.cpp new file mode 100644 index 00000000..56cf5314 --- /dev/null +++ b/plugins/wasi_nn/wasinnmodule.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "wasinnmodule.h" +#include "wasinnfunc.h" + +namespace WasmEdge { +namespace Host { + +WasiNNModule::WasiNNModule() : ModuleInstance("wasi_ephemeral_nn") { + addHostFunc("load", std::make_unique(Env)); + addHostFunc("init_execution_context", + std::make_unique(Env)); + addHostFunc("set_input", std::make_unique(Env)); + addHostFunc("get_output", std::make_unique(Env)); + addHostFunc("compute", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnmodule.h b/plugins/wasi_nn/wasinnmodule.h new file mode 100644 index 00000000..486cac12 --- /dev/null +++ b/plugins/wasi_nn/wasinnmodule.h @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "runtime/instance/module.h" +#include "wasinnenv.h" + +namespace WasmEdge { +namespace Host { + +class WasiNNModule : public Runtime::Instance::ModuleInstance { +public: + WasiNNModule(); + +private: + WASINN::WasiNNEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge From 085d347efd40c83992b9d84efc70286ee99a7d7c Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 23 Jun 2022 19:35:53 +0800 Subject: [PATCH 045/623] [Plugin] Accept the null option adder. Signed-off-by: YiYing He --- plugins/wasi_nn/wasinnenv.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index b5b2cd0c..148c8e5b 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -9,8 +9,6 @@ namespace Host { namespace { -void addOptions(PO::ArgumentParser &) noexcept {} - Runtime::Instance::ModuleInstance *create(void) noexcept { return new WasiNNModule; } @@ -29,7 +27,7 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .Create = create, }, }, - .AddOptions = addOptions, + .AddOptions = nullptr, }; } // namespace From 16e339931da53d21d834c1a2424a0eefae627da0 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 27 Jun 2022 16:58:30 +0800 Subject: [PATCH 046/623] [WASI] Refine the cmake options for WASI-NN. Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 4 ++- plugins/wasi_nn/CMakeLists.txt | 38 ++++++++++---------- plugins/wasi_nn/wasinnenv.h | 16 ++++----- plugins/wasi_nn/wasinnfunc.cpp | 22 ++++++------ utils/docker/build-wasinn-ubuntu-openvino.sh | 12 ++++++- 5 files changed, 53 insertions(+), 39 deletions(-) diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index c935c71b..178c0ffc 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -2,4 +2,6 @@ # SPDX-FileCopyrightText: 2019-2022 Second State INC add_subdirectory(wasmedge_process) -add_subdirectory(wasi_nn) +if(WASMEDGE_WASINN_BACKEND) + add_subdirectory(wasi_nn) +endif() diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 94c7a150..584dbb2b 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -32,10 +32,7 @@ target_link_libraries(wasmedgePluginWasiNN wasmedgeSystem ) -# Only Linux systems support wasi-nn with openVINO now. -if(CMAKE_SYSTEM_NAME MATCHES "Linux" AND WASMEDGE_WASINN_BUILD_OPENVINO) - install(TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) -endif() +install(TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) wasmedge_add_library(wasmedgeHostModuleWasiNN wasinnenv.cpp @@ -67,17 +64,22 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") ) endif() -# add backends building flags -if(WASMEDGE_WASINN_BUILD_OPENVINO) - message(STATUS "Build OpenVINO backend for WASI-NN") - add_definitions(-DWASMEDGE_WASINN_BUILD_OPENVINO) - find_package(InferenceEngine REQUIRED) - target_link_libraries(wasmedgeHostModuleWasiNN - PUBLIC - ${InferenceEngine_LIBRARIES} - ) - target_link_libraries(wasmedgePluginWasiNN - PUBLIC - ${InferenceEngine_LIBRARIES} - ) -endif() +# Add backends building flags. +foreach(BACKEND ${WASMEDGE_WASINN_BACKEND}) + if(BACKEND MATCHES "OpenVINO") + message(STATUS "Build ${BACKEND} backend for WASI-NN") + find_package(InferenceEngine REQUIRED) + add_definitions(-DWASMEDGE_WASINN_BACKEND_OPENVINO) + target_link_libraries(wasmedgePluginWasiNN + PUBLIC + ${InferenceEngine_LIBRARIES} + ) + target_link_libraries(wasmedgeHostModuleWasiNN + PUBLIC + ${InferenceEngine_LIBRARIES} + ) + else() + # Add the other backends here. + message(FATAL_ERROR "WASI-NN backend ${BACKEND} not found or unimplemented.") + endif() +endforeach() diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index c61f9d20..69255d96 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -8,7 +8,7 @@ #include #include -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO #include "common/log.h" #include #endif @@ -30,7 +30,7 @@ enum class Backend : uint8_t { class Graph { public: -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO Graph() = delete; Graph(Backend BE) noexcept : GraphBackend(BE), OpenVINONetwork(nullptr), @@ -61,7 +61,7 @@ class Graph { #endif Backend GraphBackend; -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO ie_network_t *OpenVINONetwork; ie_executable_network_t *OpenVINOExecNetwork; ie_blob_t *OpenVINOWeightBlob; @@ -73,7 +73,7 @@ class Graph { class Context { public: Context() = delete; -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO Context(Graph &G, ie_infer_request_t *InferReq) noexcept : GraphRef(G), OpenVINOInferRequest(InferReq) {} ~Context() noexcept { @@ -86,7 +86,7 @@ class Context { #endif Graph &GraphRef; -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO ie_infer_request_t *OpenVINOInferRequest; #endif }; @@ -94,7 +94,7 @@ class Context { class WasiNNEnvironment { public: WasiNNEnvironment() noexcept { -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO if (ie_core_create("", &OpenVINOCore) != IEStatusCode::OK) { spdlog::error( "[WASI-NN] Error happened when initializing OpenVINO core."); @@ -106,7 +106,7 @@ class WasiNNEnvironment { ~WasiNNEnvironment() noexcept { NNContext.clear(); NNGraph.clear(); -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO if (OpenVINOCore) { ie_core_free(&OpenVINOCore); } @@ -115,7 +115,7 @@ class WasiNNEnvironment { std::vector NNGraph; std::vector NNContext; -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO ie_core_t *OpenVINOCore = nullptr; #endif diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index cf840b4b..1566d3ba 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -4,7 +4,7 @@ #include "wasinnfunc.h" #include "common/log.h" -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO #include #include @@ -26,7 +26,7 @@ Expect WasiNNLoad::body(Runtime::Instance::MemoryInstance *MemInst, } if (Encoding == static_cast(WASINN::Backend::OpenVINO)) { -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO // The OpenVINO core must be initialized in constructor. if (unlikely(Env.OpenVINOCore == nullptr)) { spdlog::error("[WASI-NN] OpenVINO core not initialized."); @@ -232,7 +232,7 @@ Expect WasiNNLoad::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-DWASMEDGE_WASINN_BUILD_OPENVINO=ON" + "-DWASMEDGE_WASINN_BACKEND_OPENVINO=ON" "to build it."); #endif } else { @@ -255,7 +255,7 @@ WasiNNInitExecCtx::body(Runtime::Instance::MemoryInstance *MemInst, } if (Env.NNGraph[GraphId].GraphBackend == WASINN::Backend::OpenVINO) { -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO // Check the return value: Context should be valid. uint32_t *Context = MemInst->getPointer(ContextPtr, 1); if (unlikely(Context == nullptr)) { @@ -285,7 +285,7 @@ WasiNNInitExecCtx::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built. define " - "-DWASMEDGE_WASINN_BUILD_OPENVINO " + "-DWASMEDGE_WASINN_BACKEND_OPENVINO " "to build it."); #endif } else { @@ -309,7 +309,7 @@ WasiNNSetInput::body(Runtime::Instance::MemoryInstance *MemInst, auto &CxtRef = Env.NNContext[Context]; if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::OpenVINO) { -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO // Check the infer request and the network. auto *Network = CxtRef.GraphRef.OpenVINONetwork; if (Network == nullptr || CxtRef.OpenVINOInferRequest == nullptr) { @@ -457,7 +457,7 @@ WasiNNSetInput::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built, use " - "-DWASMEDGE_WASINN_BUILD_OPENVINO=ON" + "-DWASMEDGE_WASINN_BACKEND_OPENVINO=ON" "to build it."); #endif } else { @@ -483,7 +483,7 @@ WasiNNGetOuput::body(Runtime::Instance::MemoryInstance *MemInst, auto &CxtRef = Env.NNContext[Context]; if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::OpenVINO) { -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO auto *Network = CxtRef.GraphRef.OpenVINONetwork; // Check the output index. @@ -554,7 +554,7 @@ WasiNNGetOuput::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-DWASMEDGE_WASINN_BUILD_OPENVINO=ON" + "-DWASMEDGE_WASINN_BACKEND_OPENVINO=ON" "to build it."); #endif } else { @@ -576,7 +576,7 @@ Expect WasiNNCompute::body(Runtime::Instance::MemoryInstance *MemInst, auto &CxtRef = Env.NNContext[Context]; if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::OpenVINO) { -#ifdef WASMEDGE_WASINN_BUILD_OPENVINO +#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO IEStatusCode Status = ie_infer_request_infer(CxtRef.OpenVINOInferRequest); if (Status != IEStatusCode::OK) { spdlog::error( @@ -587,7 +587,7 @@ Expect WasiNNCompute::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-DWASMEDGE_WASINN_BUILD_OPENVINO=ON" + "-DWASMEDGE_WASINN_BACKEND_OPENVINO=ON" "to build it."); #endif } else { diff --git a/utils/docker/build-wasinn-ubuntu-openvino.sh b/utils/docker/build-wasinn-ubuntu-openvino.sh index 6f3bc7a3..ce6e1621 100755 --- a/utils/docker/build-wasinn-ubuntu-openvino.sh +++ b/utils/docker/build-wasinn-ubuntu-openvino.sh @@ -2,6 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2022 Second State INC +if [[ ! -v "${OPENVINO_VERSION}" ]]; then + OPENVINO_VERSION="2021.4.582" +fi +if [[ ! -v "${OPENVINO_YEAR}" ]]; then + OPENVINO_YEAR="2021" +fi +if [[ ! -v "${CMAKE_BUILD_TYPE}" ]]; then + CMAKE_BUILD_TYPE=Release +fi + set -e echo "Installing OpenVINO with version ${OPENVINO_VERSION}" curl -sSL https://apt.repos.intel.com/openvino/$OPENVINO_YEAR/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR | gpg --dearmor > /usr/share/keyrings/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR.gpg @@ -11,7 +21,7 @@ apt install -y intel-openvino-runtime-ubuntu20-$OPENVINO_VERSION source /opt/intel/openvino_2021/bin/setupvars.sh ldconfig git config --global --add safe.directory $(pwd) -if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE -DWASMEDGE_BUILD_TESTS=ON -DWASMEDGE_WASINN_BUILD_OPENVINO=ON .; then +if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE -DWASMEDGE_BUILD_TESTS=ON -DWASMEDGE_WASINN_BACKEND="OpenVINO" .; then echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log echo === CMakeError.log === From 766f6a24aa89b30e8915159a3ff63571ab2147fb Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 29 Jun 2022 15:57:28 +0800 Subject: [PATCH 047/623] [Utils] Move the WASI-NN build scripts into `utils/wasi-nn`. Signed-off-by: YiYing He --- .../build-wasinn-ubuntu-openvino.sh | 0 utils/wasi-nn/download_openvino_fixtures.sh | 14 ++++++++++++++ .../test-wasinn-ubuntu-openvino.sh | 0 3 files changed, 14 insertions(+) rename utils/{docker => wasi-nn}/build-wasinn-ubuntu-openvino.sh (100%) create mode 100755 utils/wasi-nn/download_openvino_fixtures.sh rename utils/{docker => wasi-nn}/test-wasinn-ubuntu-openvino.sh (100%) diff --git a/utils/docker/build-wasinn-ubuntu-openvino.sh b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh similarity index 100% rename from utils/docker/build-wasinn-ubuntu-openvino.sh rename to utils/wasi-nn/build-wasinn-ubuntu-openvino.sh diff --git a/utils/wasi-nn/download_openvino_fixtures.sh b/utils/wasi-nn/download_openvino_fixtures.sh new file mode 100755 index 00000000..97e44e4b --- /dev/null +++ b/utils/wasi-nn/download_openvino_fixtures.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +DOWNLOAD_TO=$1 +FIXTURE=https://github.com/intel/openvino-rs/raw/v0.3.3/crates/openvino/tests/fixtures/mobilenet/ + +if [ ! -f $DOWNLOAD_TO/mobilenet.bin ]; then + wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/mobilenet.bin +fi +if [ ! -f $DOWNLOAD_TO/mobilenet.xml ]; then + wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/mobilenet.xml +fi +if [ ! -f $DOWNLOAD_TO/tensor-1x224x224x3-f32.bgr ]; then + wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/tensor-1x224x224x3-f32.bgr +fi diff --git a/utils/docker/test-wasinn-ubuntu-openvino.sh b/utils/wasi-nn/test-wasinn-ubuntu-openvino.sh similarity index 100% rename from utils/docker/test-wasinn-ubuntu-openvino.sh rename to utils/wasi-nn/test-wasinn-ubuntu-openvino.sh From 5185a401984245db48ee7c3e4cc93354bac4ea8e Mon Sep 17 00:00:00 2001 From: Yukang Date: Tue, 5 Jul 2022 01:05:18 +0800 Subject: [PATCH 048/623] [Docs] Fix outdated contributing page (#1618) * [Docs] Fix outdated contributing page * [Docs]: fix contributing page * [Misc]: fix docker images Signed-off-by: yukang --- utils/docker/Dockerfile.ubuntu2004_x86_64 | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/docker/Dockerfile.ubuntu2004_x86_64 b/utils/docker/Dockerfile.ubuntu2004_x86_64 index 96e1c5f1..dda1cb9c 100644 --- a/utils/docker/Dockerfile.ubuntu2004_x86_64 +++ b/utils/docker/Dockerfile.ubuntu2004_x86_64 @@ -13,8 +13,8 @@ RUN apt update && apt upgrade -y \ git \ dpkg-dev \ libboost-all-dev \ - llvm-10-dev \ - liblld-10-dev \ + llvm-12-dev \ + liblld-12-dev \ gcc \ rpm \ dpkg-dev \ From 7ed7356711c9d640fc79bb83efc9065d1908769f Mon Sep 17 00:00:00 2001 From: YiYing He Date: Fri, 1 Jul 2022 13:56:49 +0800 Subject: [PATCH 049/623] [Test] Move the wasmedge_process tests to depend on plugin. Signed-off-by: YiYing He --- plugins/wasmedge_process/CMakeLists.txt | 31 +------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt index 4f41af5f..69dcab5c 100644 --- a/plugins/wasmedge_process/CMakeLists.txt +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -24,6 +24,7 @@ endif() target_include_directories(wasmedgePluginWasmEdgeProcess PUBLIC $ + ${CMAKE_CURRENT_SOURCE_DIR} ) target_link_libraries(wasmedgePluginWasmEdgeProcess @@ -36,33 +37,3 @@ target_link_libraries(wasmedgePluginWasmEdgeProcess if(CMAKE_SYSTEM_NAME MATCHES "Linux") install(TARGETS wasmedgePluginWasmEdgeProcess DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) endif() - -wasmedge_add_library(wasmedgeHostModuleWasmEdgeProcess - processenv.cpp - processfunc.cpp - processmodule.cpp -) - -target_include_directories(wasmedgeHostModuleWasmEdgeProcess - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR} -) - -target_link_libraries(wasmedgeHostModuleWasmEdgeProcess - PUBLIC - wasmedgeCommon - wasmedgeSystem - wasmedgePlugin -) - -if(CMAKE_SYSTEM_NAME MATCHES "Linux") - target_link_options(wasmedgeHostModuleWasmEdgeProcess - PUBLIC - -u_ZN8WasmEdge4Host26WasmEdgeProcessEnvironment8RegisterE - ) -elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") - target_link_options(wasmedgeHostModuleWasmEdgeProcess - PUBLIC - -u__ZN8WasmEdge4Host26WasmEdgeProcessEnvironment8RegisterE - ) -endif() From 067bc70f55e2718d5ebab5f97c99e9ab14afaadf Mon Sep 17 00:00:00 2001 From: YiYing He Date: Fri, 1 Jul 2022 16:30:23 +0800 Subject: [PATCH 050/623] [Test] Move the WASI-NN tests to depend on plugin. Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 35 +--------------------------------- plugins/wasi_nn/wasinnmodule.h | 2 ++ 2 files changed, 3 insertions(+), 34 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 584dbb2b..5600d8c5 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -24,6 +24,7 @@ endif() target_include_directories(wasmedgePluginWasiNN PUBLIC $ + ${CMAKE_CURRENT_SOURCE_DIR} ) target_link_libraries(wasmedgePluginWasiNN @@ -34,36 +35,6 @@ target_link_libraries(wasmedgePluginWasiNN install(TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) -wasmedge_add_library(wasmedgeHostModuleWasiNN - wasinnenv.cpp - wasinnfunc.cpp - wasinnmodule.cpp -) - -target_include_directories(wasmedgeHostModuleWasiNN - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR} -) - -target_link_libraries(wasmedgeHostModuleWasiNN - PUBLIC - wasmedgeCommon - wasmedgeSystem - wasmedgePlugin -) - -if(CMAKE_SYSTEM_NAME MATCHES "Linux") - target_link_options(wasmedgeHostModuleWasiNN - PUBLIC - -u_ZN8WasmEdge4Host6WASINN17WasiNNEnvironment8RegisterE - ) -elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") - target_link_options(wasmedgeHostModuleWasiNN - PUBLIC - -u__ZN8WasmEdge4Host6WASINN17WasiNNEnvironment8RegisterE - ) -endif() - # Add backends building flags. foreach(BACKEND ${WASMEDGE_WASINN_BACKEND}) if(BACKEND MATCHES "OpenVINO") @@ -74,10 +45,6 @@ foreach(BACKEND ${WASMEDGE_WASINN_BACKEND}) PUBLIC ${InferenceEngine_LIBRARIES} ) - target_link_libraries(wasmedgeHostModuleWasiNN - PUBLIC - ${InferenceEngine_LIBRARIES} - ) else() # Add the other backends here. message(FATAL_ERROR "WASI-NN backend ${BACKEND} not found or unimplemented.") diff --git a/plugins/wasi_nn/wasinnmodule.h b/plugins/wasi_nn/wasinnmodule.h index 486cac12..0c18bd16 100644 --- a/plugins/wasi_nn/wasinnmodule.h +++ b/plugins/wasi_nn/wasinnmodule.h @@ -13,6 +13,8 @@ class WasiNNModule : public Runtime::Instance::ModuleInstance { public: WasiNNModule(); + WASINN::WasiNNEnvironment &getEnv() { return Env; } + private: WASINN::WasiNNEnvironment Env; }; From 05ecbc65d4a5aab7ba353f9958c93d57f778bafa Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 4 Jul 2022 09:20:01 +0800 Subject: [PATCH 051/623] [Plugin] Not to build and test wasmedge_process on unsupported platforms. Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 6 +++++- plugins/wasmedge_process/CMakeLists.txt | 5 +---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 178c0ffc..cfc5b056 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -1,7 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2022 Second State INC -add_subdirectory(wasmedge_process) +# Only Linux systems support wasmedge_process now. +if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasmedge_process) +endif() + if(WASMEDGE_WASINN_BACKEND) add_subdirectory(wasi_nn) endif() diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt index 69dcab5c..82852d81 100644 --- a/plugins/wasmedge_process/CMakeLists.txt +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -33,7 +33,4 @@ target_link_libraries(wasmedgePluginWasmEdgeProcess wasmedgeSystem ) -# Only Linux systems support wasmedge_process now. -if(CMAKE_SYSTEM_NAME MATCHES "Linux") - install(TARGETS wasmedgePluginWasmEdgeProcess DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) -endif() +install(TARGETS wasmedgePluginWasmEdgeProcess DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) From a455211b43fc5af9e299d992b9c56b08602f18e2 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 6 Jul 2022 17:10:19 +0800 Subject: [PATCH 052/623] [Docker] Release DockerSlim images Signed-off-by: dm4 --- utils/docker/Dockerfile.release | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 utils/docker/Dockerfile.release diff --git a/utils/docker/Dockerfile.release b/utils/docker/Dockerfile.release new file mode 100644 index 00000000..225b9549 --- /dev/null +++ b/utils/docker/Dockerfile.release @@ -0,0 +1,10 @@ +FROM ubuntu:20.04 +ARG VERSION +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt update && \ + apt install -y curl git && \ + curl -sSf https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- -p /usr/local -e all -v $VERSION + +WORKDIR /app +CMD ["/usr/local/bin/wasmedge"] From 473acb94dd686bd2a8f71c9339f789506f49da8d Mon Sep 17 00:00:00 2001 From: YiYing He Date: Tue, 12 Jul 2022 03:52:04 +0800 Subject: [PATCH 053/623] [Misc] Add the `PLUGIN` prefix for wasi-nn options and update the WASI-NN docs. Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 2 +- plugins/wasi_nn/CMakeLists.txt | 4 ++-- plugins/wasi_nn/wasinnenv.h | 16 +++++++------- plugins/wasi_nn/wasinnfunc.cpp | 22 +++++++++---------- utils/wasi-nn/build-wasinn-ubuntu-openvino.sh | 2 +- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index cfc5b056..d87e230f 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -6,6 +6,6 @@ if(CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasmedge_process) endif() -if(WASMEDGE_WASINN_BACKEND) +if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) add_subdirectory(wasi_nn) endif() diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 5600d8c5..b13cc3cc 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -36,11 +36,11 @@ target_link_libraries(wasmedgePluginWasiNN install(TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) # Add backends building flags. -foreach(BACKEND ${WASMEDGE_WASINN_BACKEND}) +foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) if(BACKEND MATCHES "OpenVINO") message(STATUS "Build ${BACKEND} backend for WASI-NN") find_package(InferenceEngine REQUIRED) - add_definitions(-DWASMEDGE_WASINN_BACKEND_OPENVINO) + add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) target_link_libraries(wasmedgePluginWasiNN PUBLIC ${InferenceEngine_LIBRARIES} diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 69255d96..a163f01e 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -8,7 +8,7 @@ #include #include -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO #include "common/log.h" #include #endif @@ -30,7 +30,7 @@ enum class Backend : uint8_t { class Graph { public: -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO Graph() = delete; Graph(Backend BE) noexcept : GraphBackend(BE), OpenVINONetwork(nullptr), @@ -61,7 +61,7 @@ class Graph { #endif Backend GraphBackend; -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO ie_network_t *OpenVINONetwork; ie_executable_network_t *OpenVINOExecNetwork; ie_blob_t *OpenVINOWeightBlob; @@ -73,7 +73,7 @@ class Graph { class Context { public: Context() = delete; -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO Context(Graph &G, ie_infer_request_t *InferReq) noexcept : GraphRef(G), OpenVINOInferRequest(InferReq) {} ~Context() noexcept { @@ -86,7 +86,7 @@ class Context { #endif Graph &GraphRef; -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO ie_infer_request_t *OpenVINOInferRequest; #endif }; @@ -94,7 +94,7 @@ class Context { class WasiNNEnvironment { public: WasiNNEnvironment() noexcept { -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO if (ie_core_create("", &OpenVINOCore) != IEStatusCode::OK) { spdlog::error( "[WASI-NN] Error happened when initializing OpenVINO core."); @@ -106,7 +106,7 @@ class WasiNNEnvironment { ~WasiNNEnvironment() noexcept { NNContext.clear(); NNGraph.clear(); -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO if (OpenVINOCore) { ie_core_free(&OpenVINOCore); } @@ -115,7 +115,7 @@ class WasiNNEnvironment { std::vector NNGraph; std::vector NNContext; -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO ie_core_t *OpenVINOCore = nullptr; #endif diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 1566d3ba..b402fa08 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -4,7 +4,7 @@ #include "wasinnfunc.h" #include "common/log.h" -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO #include #include @@ -26,7 +26,7 @@ Expect WasiNNLoad::body(Runtime::Instance::MemoryInstance *MemInst, } if (Encoding == static_cast(WASINN::Backend::OpenVINO)) { -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO // The OpenVINO core must be initialized in constructor. if (unlikely(Env.OpenVINOCore == nullptr)) { spdlog::error("[WASI-NN] OpenVINO core not initialized."); @@ -232,7 +232,7 @@ Expect WasiNNLoad::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-DWASMEDGE_WASINN_BACKEND_OPENVINO=ON" + "-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO=ON" "to build it."); #endif } else { @@ -255,7 +255,7 @@ WasiNNInitExecCtx::body(Runtime::Instance::MemoryInstance *MemInst, } if (Env.NNGraph[GraphId].GraphBackend == WASINN::Backend::OpenVINO) { -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO // Check the return value: Context should be valid. uint32_t *Context = MemInst->getPointer(ContextPtr, 1); if (unlikely(Context == nullptr)) { @@ -285,7 +285,7 @@ WasiNNInitExecCtx::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built. define " - "-DWASMEDGE_WASINN_BACKEND_OPENVINO " + "-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO " "to build it."); #endif } else { @@ -309,7 +309,7 @@ WasiNNSetInput::body(Runtime::Instance::MemoryInstance *MemInst, auto &CxtRef = Env.NNContext[Context]; if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::OpenVINO) { -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO // Check the infer request and the network. auto *Network = CxtRef.GraphRef.OpenVINONetwork; if (Network == nullptr || CxtRef.OpenVINOInferRequest == nullptr) { @@ -457,7 +457,7 @@ WasiNNSetInput::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built, use " - "-DWASMEDGE_WASINN_BACKEND_OPENVINO=ON" + "-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO=ON" "to build it."); #endif } else { @@ -483,7 +483,7 @@ WasiNNGetOuput::body(Runtime::Instance::MemoryInstance *MemInst, auto &CxtRef = Env.NNContext[Context]; if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::OpenVINO) { -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO auto *Network = CxtRef.GraphRef.OpenVINONetwork; // Check the output index. @@ -554,7 +554,7 @@ WasiNNGetOuput::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-DWASMEDGE_WASINN_BACKEND_OPENVINO=ON" + "-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO=ON" "to build it."); #endif } else { @@ -576,7 +576,7 @@ Expect WasiNNCompute::body(Runtime::Instance::MemoryInstance *MemInst, auto &CxtRef = Env.NNContext[Context]; if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::OpenVINO) { -#ifdef WASMEDGE_WASINN_BACKEND_OPENVINO +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO IEStatusCode Status = ie_infer_request_infer(CxtRef.OpenVINOInferRequest); if (Status != IEStatusCode::OK) { spdlog::error( @@ -587,7 +587,7 @@ Expect WasiNNCompute::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-DWASMEDGE_WASINN_BACKEND_OPENVINO=ON" + "-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO=ON" "to build it."); #endif } else { diff --git a/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh index ce6e1621..8899ec77 100755 --- a/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh +++ b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh @@ -21,7 +21,7 @@ apt install -y intel-openvino-runtime-ubuntu20-$OPENVINO_VERSION source /opt/intel/openvino_2021/bin/setupvars.sh ldconfig git config --global --add safe.directory $(pwd) -if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE -DWASMEDGE_BUILD_TESTS=ON -DWASMEDGE_WASINN_BACKEND="OpenVINO" .; then +if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE -DWASMEDGE_BUILD_TESTS=ON -DWASMEDGE_PLUGIN_WASI_NN_BACKEND="OpenVINO" .; then echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log echo === CMakeError.log === From 248dc7c1b500f144908abcfeff7ab21ddfbdcae9 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 13 Jul 2022 14:40:56 +0800 Subject: [PATCH 054/623] [Test] Move the plugin tests into the `plugins` folder. Signed-off-by: YiYing He --- test/plugins/CMakeLists.txt | 9 + test/plugins/wasi_nn/CMakeLists.txt | 41 ++ test/plugins/wasi_nn/wasi_nn.cpp | 475 ++++++++++++++++ test/plugins/wasmedge_process/CMakeLists.txt | 15 + .../wasmedge_process/wasmedge_process.cpp | 523 ++++++++++++++++++ 5 files changed, 1063 insertions(+) create mode 100644 test/plugins/CMakeLists.txt create mode 100644 test/plugins/wasi_nn/CMakeLists.txt create mode 100644 test/plugins/wasi_nn/wasi_nn.cpp create mode 100644 test/plugins/wasmedge_process/CMakeLists.txt create mode 100644 test/plugins/wasmedge_process/wasmedge_process.cpp diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt new file mode 100644 index 00000000..7f920c0e --- /dev/null +++ b/test/plugins/CMakeLists.txt @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasmedge_process) +endif() +if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) + add_subdirectory(wasi_nn) +endif() diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt new file mode 100644 index 00000000..24b6e98e --- /dev/null +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_executable(wasiNNTests + wasi_nn.cpp +) + +# Prepare the testing data for each backends. +foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) + if(BACKEND MATCHES "OpenVINO") + message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures") + execute_process( + COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download_openvino_fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures + RESULT_VARIABLE DOWNLOAD_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE) + file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/mobilenet.bin CHECKSUM_WEIGHT) + file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/mobilenet.xml CHECKSUM_DESCRIP) + file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/tensor-1x224x224x3-f32.bgr CHECKSUM_TENSOR) + if(NOT CHECKSUM_WEIGHT STREQUAL "ae096b1f735f1e8e54bac8b2a42303bd") + message(FATAL_ERROR "mobilenet.bin downloaded with wrong md5") + endif() + if(NOT CHECKSUM_DESCRIP STREQUAL "4ea3a14273587ce5c1662018878f9f90") + message(FATAL_ERROR "mobilenet.xml downloaded with wrong md5") + endif() + if(NOT CHECKSUM_TENSOR STREQUAL "bfca546f4a3b5e6da49b7bd728e2799a") + message(FATAL_ERROR "tensor-1x224x224x3-f32.bgr downloaded with wrong md5") + endif() + add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) + else() + # Add the other backend test files fetching here. + endif() +endforeach() + +target_link_libraries(wasiNNTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} + wasmedgePlugin + wasmedgePluginWasiNN +) + +add_test(wasiNNTests wasiNNTests) diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp new file mode 100644 index 00000000..15631ff5 --- /dev/null +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -0,0 +1,475 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/types.h" +#include "runtime/instance/module.h" +#include "wasinnfunc.h" +#include "wasinnmodule.h" + +#include +#include +#include +#include +#include +#include + +using WasmEdge::Host::WASINN::ErrNo; +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasi_nn/" + "libwasmedgePluginWasiNN" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_nn"sv)) { + if (const auto *Module = Plugin->findModule("wasi_nn"sv)) { + return Module->create().release(); + } + } + return nullptr; +} + +inline std::vector readEntireFile(const std::string &Path) { + std::ifstream Fin(Path, std::ios::binary | std::ios::ate); + if (!Fin) { + return {}; + } + Fin.seekg(0, std::ios::end); + std::vector Buf(static_cast(Fin.tellg())); + Fin.seekg(0, std::ios::beg); + if (!Fin.read(reinterpret_cast(Buf.data()), + static_cast(Buf.size()))) { + return {}; + } + Fin.close(); + return Buf; +} + +template +void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + std::vector Binaries, uint32_t Ptr) noexcept { + std::copy(Binaries.begin(), Binaries.end(), MemInst.getPointer(Ptr)); +} + +void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Value, uint32_t &Ptr) { + uint32_t *BufPtr = MemInst.getPointer(Ptr); + *BufPtr = Value; + Ptr += 4; +} + +void writeFatPointer(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t PtrVal, uint32_t PtrSize, uint32_t &Ptr) { + writeUInt32(MemInst, PtrVal, Ptr); + writeUInt32(MemInst, PtrSize, Ptr); +} + +template +std::vector classSort(const std::vector &Array) { + std::vector Indices(Array.size()); + std::iota(Indices.begin(), Indices.end(), 0); + std::sort(Indices.begin(), Indices.end(), + [&Array](int Left, int Right) -> bool { + // Sort indices according to corresponding array element. + return Array[Left] >= Array[Right]; + }); + return Indices; +} +} // namespace + +TEST(WasiNNTest, OpenVINOBackend) { + // Create the wasmedge_process module instance. + auto *NNMod = dynamic_cast(createModule()); + EXPECT_FALSE(NNMod == nullptr); + + // Create the memory instance. + WasmEdge::Runtime::Instance::MemoryInstance MemInst( + WasmEdge::AST::MemoryType(400)); + + // Load the files. + std::vector TensorData = + readEntireFile("./wasinn_openvino_fixtures/tensor-1x224x224x3-f32.bgr"); + std::vector XmlRead = + readEntireFile("./wasinn_openvino_fixtures/mobilenet.xml"); + std::vector WeightRead = + readEntireFile("./wasinn_openvino_fixtures/mobilenet.bin"); + + std::vector TensorDim{1, 3, 224, 224}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(410 * 65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Temp. values. + std::vector NNGraphTmp; + std::vector NNContextTmp; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // OpenVINO WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE(HostFuncLoad.run( + &MemInst, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Busy)); + } + + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + &MemInst, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + &MemInst, + std::initializer_list{ + OutBoundPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: laod -- OpenVINO model xml ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, XmlRead.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr + XmlRead.size(), WeightRead.size(), + BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run( + &MemInst, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- OpenVINO model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, XmlRead.size(), BuilderPtr); + writeFatPointer(MemInst, OutBoundPtr, WeightRead.size(), BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run( + &MemInst, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- wrong builders' length. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, XmlRead.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr + XmlRead.size(), WeightRead.size(), + BuilderPtr); + writeBinaries(MemInst, XmlRead, StorePtr); + writeBinaries(MemInst, WeightRead, StorePtr + XmlRead.size()); + StorePtr += (XmlRead.size() + WeightRead.size()); + { + EXPECT_TRUE(HostFuncLoad.run( + &MemInst, + std::initializer_list{ + LoadEntryPtr, UINT32_C(4), UINT32_C(0), UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- unsupported device. CPU 0, GPU 1, TPU 2 + { + EXPECT_TRUE(HostFuncLoad.run( + &MemInst, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(3), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- load successfully. + { + EXPECT_TRUE(HostFuncLoad.run( + &MemInst, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: load -- load second graph. + { + EXPECT_TRUE(HostFuncLoad.run( + &MemInst, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // OpenVINO WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + &MemInst, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Swap to the tmp. env. + NNGraphTmp.emplace_back(WasmEdge::Host::WASINN::Backend::OpenVINO); + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + // Test: init_execution_context -- graph id exceeds. + { + EXPECT_TRUE(HostFuncInit.run( + &MemInst, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::MissingMemory)); + } + // Swap back. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: init_execution_context -- init context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + &MemInst, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: init_execution_context -- init second context. + { + EXPECT_TRUE(HostFuncInit.run( + &MemInst, + std::initializer_list{UINT32_C(1), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // OpenVINO WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Swap to the tmp. env. + NNContextTmp.emplace_back(NNGraphTmp[0], nullptr); + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(&MemInst, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: set_input -- empty context. + { + EXPECT_TRUE( + HostFuncSetInput.run(&MemInst, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::MissingMemory)); + } + // Swap back. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: set_input -- input index exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(&MemInst, + std::initializer_list{ + UINT32_C(0), UINT32_C(10), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: set_input -- tensor type not FP32. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(2), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(&MemInst, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: set_input -- set input successfully. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(&MemInst, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // OpenVINO WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + &MemInst, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Swap to the tmp. env. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + // Test: compute -- empty context. + { + EXPECT_TRUE(HostFuncCompute.run( + &MemInst, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Busy)); + } + // Swap back. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + &MemInst, std::initializer_list{UINT32_C(1)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // OpenVINO WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + &MemInst, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + &MemInst, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output index exceeds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + &MemInst, + std::initializer_list{ + UINT32_C(0), UINT32_C(10), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + &MemInst, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), UINT32_C(4004)); + std::vector OutputClassification( + MemInst.getPointer(StorePtr, 1001) + 1, + MemInst.getPointer(StorePtr, 1001) + 1001); + std::vector SortedIndex, CorrectClasses{963, 762, 909, 926, 567}; + SortedIndex = classSort(OutputClassification); + // The probability of class i is placed at buffer[i]. + for (size_t I = 0; I < CorrectClasses.size(); I++) { + EXPECT_EQ(SortedIndex[I], CorrectClasses[I]); + } + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO diff --git a/test/plugins/wasmedge_process/CMakeLists.txt b/test/plugins/wasmedge_process/CMakeLists.txt new file mode 100644 index 00000000..9f7d1309 --- /dev/null +++ b/test/plugins/wasmedge_process/CMakeLists.txt @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_executable(wasmedgeProcessTests + wasmedge_process.cpp +) + +target_link_libraries(wasmedgeProcessTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} + wasmedgePlugin + wasmedgePluginWasmEdgeProcess +) + +add_test(wasmedgeProcessTests wasmedgeProcessTests) diff --git a/test/plugins/wasmedge_process/wasmedge_process.cpp b/test/plugins/wasmedge_process/wasmedge_process.cpp new file mode 100644 index 00000000..42bfb7a0 --- /dev/null +++ b/test/plugins/wasmedge_process/wasmedge_process.cpp @@ -0,0 +1,523 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/defines.h" +#include "processfunc.h" +#include "processmodule.h" +#include "runtime/instance/module.h" + +#include +#include +#include +#include +#include +#include + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_process/" + "libwasmedgePluginWasmEdgeProcess" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_process"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_process"sv)) { + return Module->create().release(); + } + } + return nullptr; +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, uint32_t Cnt, uint8_t C = 0) noexcept { + std::fill_n(MemInst.getPointer(Offset), Cnt, C); +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, const std::string &Str) noexcept { + char *Buf = MemInst.getPointer(Offset); + std::copy_n(Str.c_str(), Str.length(), Buf); +} +} // namespace + +TEST(WasmEdgeProcessTest, SetProgName) { + // Create the wasmedge_process module instance. + auto *ProcMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ProcMod == nullptr); + + // Create the memory instance. + WasmEdge::Runtime::Instance::MemoryInstance MemInst( + WasmEdge::AST::MemoryType(1)); + + // Clear the memory[0, 64]. + fillMemContent(MemInst, 0, 64); + // Set the memory[0, 4] as string "echo". + fillMemContent(MemInst, 0, std::string("echo")); + + // Get the function "wasmedge_process_set_prog_name". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_set_prog_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = + dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Run function successfully. + EXPECT_TRUE(HostFuncInst.run( + &MemInst, + std::initializer_list{UINT32_C(0), UINT32_C(4)}, + {})); + EXPECT_EQ(ProcMod->getEnv().Name, "echo"); + + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE(HostFuncInst.run( + nullptr, + std::initializer_list{UINT32_C(0), UINT32_C(4)}, + {})); + + delete ProcMod; +} + +TEST(WasmEdgeProcessTest, AddArg) { + // Create the wasmedge_process module instance. + auto *ProcMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ProcMod == nullptr); + + // Create the memory instance. + WasmEdge::Runtime::Instance::MemoryInstance MemInst( + WasmEdge::AST::MemoryType(1)); + + // Clear the memory[0, 64]. + fillMemContent(MemInst, 0, 64); + // Set the memory[0, 4] as string "echo". + fillMemContent(MemInst, 0, std::string("arg1")); + // Set the memory[4, 8] as string "arg2". + fillMemContent(MemInst, 4, std::string("arg2")); + // Set the memory[30, 41] as string "--final-arg". + fillMemContent(MemInst, 30, std::string("--final-arg")); + + // Get the function "wasmedge_process_add_arg". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_add_arg"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Run function successfully to add "arg1". + EXPECT_TRUE(HostFuncInst.run( + &MemInst, + std::initializer_list{UINT32_C(0), UINT32_C(4)}, + {})); + EXPECT_EQ(ProcMod->getEnv().Args.size(), 1U); + EXPECT_EQ(ProcMod->getEnv().Args[0], "arg1"); + + // Test: Run function successfully to add "arg2". + EXPECT_TRUE(HostFuncInst.run( + &MemInst, + std::initializer_list{UINT32_C(4), UINT32_C(4)}, + {})); + EXPECT_EQ(ProcMod->getEnv().Args.size(), 2U); + EXPECT_EQ(ProcMod->getEnv().Args[1], "arg2"); + + // Test: Run function successfully to add "--final-arg". + EXPECT_TRUE(HostFuncInst.run( + &MemInst, + std::initializer_list{UINT32_C(30), UINT32_C(11)}, + {})); + EXPECT_EQ(ProcMod->getEnv().Args.size(), 3U); + EXPECT_EQ(ProcMod->getEnv().Args[2], "--final-arg"); + + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE(HostFuncInst.run( + nullptr, + std::initializer_list{UINT32_C(0), UINT32_C(4)}, + {})); + + delete ProcMod; +} + +TEST(WasmEdgeProcessTest, AddEnv) { + // Create the wasmedge_process module instance. + auto *ProcMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ProcMod == nullptr); + + // Create the memory instance. + WasmEdge::Runtime::Instance::MemoryInstance MemInst( + WasmEdge::AST::MemoryType(1)); + + // Clear the memory[0, 256]. + fillMemContent(MemInst, 0, 256); + // Set the memory[0, 4] as string "ENV1". + fillMemContent(MemInst, 0, std::string("ENV1")); + // Set the memory[4, 10] as string "VALUE1". + fillMemContent(MemInst, 4, std::string("VALUE1")); + // Set the memory[30, 45] as string "LD_LIBRARY_PATH". + fillMemContent(MemInst, 30, std::string("LD_LIBRARY_PATH")); + // Set the memory[50, 64] as string "/usr/local/lib". + fillMemContent(MemInst, 50, std::string("/usr/local/lib")); + + // Get the function "wasmedge_process_add_env". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_add_env"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Run function successfully to add "ENV1", "VALUE1". + EXPECT_TRUE( + HostFuncInst.run(&MemInst, + std::initializer_list{ + UINT32_C(0), UINT32_C(4), UINT32_C(4), UINT32_C(6)}, + {})); + EXPECT_EQ(ProcMod->getEnv().Envs.size(), 1U); + EXPECT_EQ(ProcMod->getEnv().Envs["ENV1"], "VALUE1"); + + // Test: Run function successfully to add "LD_LIBRARY_PATH", "/usr/local/lib". + EXPECT_TRUE(HostFuncInst.run( + &MemInst, + std::initializer_list{UINT32_C(30), UINT32_C(15), + UINT32_C(50), UINT32_C(14)}, + {})); + EXPECT_EQ(ProcMod->getEnv().Envs.size(), 2U); + EXPECT_EQ(ProcMod->getEnv().Envs["LD_LIBRARY_PATH"], "/usr/local/lib"); + + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE( + HostFuncInst.run(nullptr, + std::initializer_list{ + UINT32_C(0), UINT32_C(4), UINT32_C(4), UINT32_C(6)}, + {})); + + delete ProcMod; +} + +TEST(WasmEdgeProcessTest, AddStdIn) { + // Create the wasmedge_process module instance. + auto *ProcMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ProcMod == nullptr); + + // Create the memory instance. + WasmEdge::Runtime::Instance::MemoryInstance MemInst( + WasmEdge::AST::MemoryType(1)); + + // Clear the memory[0, 64]. + fillMemContent(MemInst, 0, 64); + // Set the memory[0, 4] as string "\01\02\03\04". + fillMemContent(MemInst, 0, std::string("\01\02\03\04")); + // Set the memory[30, 46] as string "hello, wasmedge\n". + fillMemContent(MemInst, 30, std::string("hello, wasmedge\n")); + + // Get the function "wasmedge_process_add_stdin". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_add_stdin"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Run function successfully to add "\01\02\03\04". + EXPECT_TRUE(HostFuncInst.run( + &MemInst, + std::initializer_list{UINT32_C(0), UINT32_C(4)}, + {})); + EXPECT_EQ(ProcMod->getEnv().StdIn.size(), 4U); + EXPECT_EQ(ProcMod->getEnv().StdIn, + std::vector({0x01, 0x02, 0x03, 0x04})); + + // Test: Run function successfully to add "hello, wasmedge\n". + EXPECT_TRUE(HostFuncInst.run( + &MemInst, + std::initializer_list{UINT32_C(30), UINT32_C(16)}, + {})); + EXPECT_EQ(ProcMod->getEnv().StdIn.size(), 20U); + EXPECT_EQ(ProcMod->getEnv().StdIn, + std::vector({0x01, 0x02, 0x03, 0x04, 'h', 'e', 'l', + 'l', 'o', ',', ' ', 'w', 'a', 's', + 'm', 'e', 'd', 'g', 'e', '\n'})); + + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE(HostFuncInst.run( + nullptr, + std::initializer_list{UINT32_C(0), UINT32_C(4)}, + {})); + + delete ProcMod; +} + +TEST(WasmEdgeProcessTest, SetTimeOut) { + // Create the wasmedge_process module instance. + auto *ProcMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ProcMod == nullptr); + + // Get the function "wasmedge_process_set_timeout". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_set_timeout"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = + dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Run function successfully to set timeout 100. + EXPECT_TRUE(HostFuncInst.run( + nullptr, std::initializer_list{UINT32_C(100)}, {})); + EXPECT_EQ(ProcMod->getEnv().TimeOut, 100U); + + delete ProcMod; +} + +TEST(WasmEdgeProcessTest, Run) { + // Create the wasmedge_process module instance. + auto *ProcMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ProcMod == nullptr); + + // Create the memory instance. + WasmEdge::Runtime::Instance::MemoryInstance MemInst( + WasmEdge::AST::MemoryType(1)); + + // Clear the memory[0, 64]. + fillMemContent(MemInst, 0, 64); + // Set the memory[0, 4] as string "\01\02\03\04". + fillMemContent(MemInst, 0, std::string("\01\02\03\04")); + // Set the memory[30, 46] as string "hello, wasmedge\n". + fillMemContent(MemInst, 30, std::string("hello, wasmedge\n")); + + // Get the function "wasmedge_process_run". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_run"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = dynamic_cast( + FuncInst->getHostFunc()); + + // Return value. + std::array RetVal; + + // Test: Run function failed to run "c++" without allowing all commands. + ProcMod->getEnv().AllowedAll = false; + ProcMod->getEnv().Name = "c++"; + EXPECT_TRUE(HostFuncInst.run(nullptr, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), -1); + EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 0); + EXPECT_TRUE(ProcMod->getEnv().StdErr.size() > 0); + std::string ErrStr = + "Permission denied: Command \"c++\" is not in the white list. Please use " + "--allow-command=c++ or --allow-command-all to add \"c++\" command into " + "the white list.\n"; + EXPECT_TRUE(std::equal(ProcMod->getEnv().StdErr.begin(), + ProcMod->getEnv().StdErr.end(), ErrStr.begin())); + + // Test: Run function successfully to run "c++" with allowing all commands. + ProcMod->getEnv().AllowedAll = true; + ProcMod->getEnv().Name = "c++"; + EXPECT_TRUE(HostFuncInst.run(nullptr, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 1); + EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 0); + EXPECT_TRUE(ProcMod->getEnv().StdErr.size() > 0); + + // Test: Run function successfully to run "c++" with allowing this command. + ProcMod->getEnv().AllowedAll = false; + ProcMod->getEnv().AllowedCmd.insert("c++"); + ProcMod->getEnv().Name = "c++"; + EXPECT_TRUE(HostFuncInst.run(nullptr, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 1); + EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 0); + EXPECT_TRUE(ProcMod->getEnv().StdErr.size() > 0); + + // Test: Run function successfully to run "/bin/echo" with allowing this + // command. + ProcMod->getEnv().AllowedAll = false; + ProcMod->getEnv().AllowedCmd.clear(); + ProcMod->getEnv().AllowedCmd.insert("/bin/echo"); + ProcMod->getEnv().Name = "/bin/echo"; + ProcMod->getEnv().Args.push_back("123456 test"); + EXPECT_TRUE(HostFuncInst.run(nullptr, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 0); + EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 12); + EXPECT_TRUE(ProcMod->getEnv().StdErr.size() == 0); + std::string OutStr = "123456 test\n"; + EXPECT_TRUE(std::equal(ProcMod->getEnv().StdOut.begin(), + ProcMod->getEnv().StdOut.end(), OutStr.begin())); + + delete ProcMod; +} + +TEST(WasmEdgeProcessTest, GetExitCode) { + // Create the wasmedge_process module instance. + auto *ProcMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ProcMod == nullptr); + + // Get the function "wasmedge_process_get_exit_code". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_get_exit_code"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = + dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Run function successfully to get exit code. + std::array RetVal; + EXPECT_TRUE(HostFuncInst.run(nullptr, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 0); + + delete ProcMod; +} + +TEST(WasmEdgeProcessTest, GetStdOut) { + // Create the wasmedge_process module instance. + auto *ProcMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ProcMod == nullptr); + + // Create the memory instance. + WasmEdge::Runtime::Instance::MemoryInstance MemInst( + WasmEdge::AST::MemoryType(1)); + + // Clear the memory[0, 256]. + fillMemContent(MemInst, 0, 256); + + // Get the function "wasmedge_process_run". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_run"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncRun = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "wasmedge_process_run". + FuncInst = ProcMod->findFuncExports("wasmedge_process_get_stdout_len"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetStdOutLen = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "wasmedge_process_run". + FuncInst = ProcMod->findFuncExports("wasmedge_process_get_stdout"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetStdOut = + dynamic_cast( + FuncInst->getHostFunc()); + + // Return value. + std::array RetVal; + + // Run the command "echo $(pwd)". + ProcMod->getEnv().Name = "echo"; + ProcMod->getEnv().AllowedCmd.insert("echo"); + ProcMod->getEnv().Args.push_back("$(pwd)"); + EXPECT_TRUE(HostFuncRun.run(nullptr, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 0U); + + // Test: Run wasmedge_process_get_stdout_len successfully. + EXPECT_TRUE(HostFuncGetStdOutLen.run(nullptr, {}, RetVal)); + uint32_t Len = RetVal[0].get(); + EXPECT_TRUE(Len > 0U); + + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE(HostFuncGetStdOut.run( + nullptr, std::initializer_list{UINT32_C(0)}, {})); + + // Test: Run wasmedge_process_get_stdout successfully. + EXPECT_TRUE(HostFuncGetStdOut.run( + &MemInst, std::initializer_list{UINT32_C(0)}, {})); + EXPECT_TRUE(std::equal(ProcMod->getEnv().StdOut.begin(), + ProcMod->getEnv().StdOut.end(), + MemInst.getPointer(0))); + + delete ProcMod; +} + +TEST(WasmEdgeProcessTest, GetStdErr) { + // Create the wasmedge_process module instance. + auto *ProcMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ProcMod == nullptr); + + // Create the memory instance. + WasmEdge::Runtime::Instance::MemoryInstance MemInst( + WasmEdge::AST::MemoryType(1)); + + // Clear the memory[0, 256]. + fillMemContent(MemInst, 0, 256); + + // Get the function "wasmedge_process_run". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_run"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncRun = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "wasmedge_process_run". + FuncInst = ProcMod->findFuncExports("wasmedge_process_get_stderr_len"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetStdErrLen = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "wasmedge_process_run". + FuncInst = ProcMod->findFuncExports("wasmedge_process_get_stderr"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetStdErr = + dynamic_cast( + FuncInst->getHostFunc()); + + // Return value. + std::array RetVal; + + // Run the command "c++". + ProcMod->getEnv().Name = "c++"; + ProcMod->getEnv().AllowedCmd.insert("c++"); + EXPECT_TRUE(HostFuncRun.run(nullptr, {}, RetVal)); + EXPECT_NE(RetVal[0].get(), 0U); + + // Test: Run wasmedge_process_get_stdout_len successfully. + EXPECT_TRUE(HostFuncGetStdErrLen.run(nullptr, {}, RetVal)); + uint32_t Len = RetVal[0].get(); + EXPECT_TRUE(Len > 0U); + + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE(HostFuncGetStdErr.run( + nullptr, std::initializer_list{UINT32_C(0)}, {})); + + // Test: Run wasmedge_process_get_stdout successfully. + EXPECT_TRUE(HostFuncGetStdErr.run( + &MemInst, std::initializer_list{UINT32_C(0)}, {})); + EXPECT_TRUE(std::equal(ProcMod->getEnv().StdOut.begin(), + ProcMod->getEnv().StdOut.end(), + MemInst.getPointer(0))); + + delete ProcMod; +} + +TEST(WasmEdgeProcessTest, Module) { + // Create the wasmedge_process module instance. + auto *ProcMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ProcMod == nullptr); + EXPECT_EQ(ProcMod->getEnv().ExitCode, 0U); + EXPECT_EQ(ProcMod->getFuncExportNum(), 11U); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_set_prog_name"), + nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_add_arg"), nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_add_env"), nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_add_stdin"), nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_set_timeout"), nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_run"), nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_get_exit_code"), + nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_get_stdout_len"), + nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_get_stdout"), nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_get_stderr_len"), + nullptr); + EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_get_stderr"), nullptr); + delete ProcMod; +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 0e308f844b18fefda67a414558088e53fb75d4a1 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Tue, 12 Jul 2022 04:02:54 +0800 Subject: [PATCH 055/623] [Utils] Add the option to assign the build type when building manylinux. Signed-off-by: YiYing He --- utils/docker/build-manylinux.sh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index 225571ce..db29bb15 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -2,6 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2022 Second State INC +if [[ "$#" -lt 1 ]]; then + export CMAKE_BUILD_TYPE=Release +else + export CMAKE_BUILD_TYPE=$1 +fi export PATH="/toolchain/bin:$PATH" export CC=gcc export CXX=g++ @@ -11,7 +16,7 @@ curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/relea echo "475d589d51a7f8b3ba2ba4eda022b170e562ca3b760ee922c146b6c65856ef39 boost_1_79_0.tar.bz2" | sha256sum -c git config --global --add safe.directory $(pwd) bzip2 -dc boost_1_79_0.tar.bz2 | tar -xf - -if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_79_0/; then +if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_79_0/; then echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log echo === CMakeError.log === From beadc2ee2820b2146ef942f5cb7775123181b6c0 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Tue, 12 Jul 2022 04:03:41 +0800 Subject: [PATCH 056/623] [Utils] Separate the openVINO installation and WasmEdge building. Signed-off-by: YiYing He --- test/plugins/wasi_nn/CMakeLists.txt | 2 +- utils/wasi-nn/build-wasinn-ubuntu-openvino.sh | 16 -------------- ...tures.sh => download-openvino-fixtures.sh} | 0 utils/wasi-nn/install-openvino.sh | 22 +++++++++++++++++++ 4 files changed, 23 insertions(+), 17 deletions(-) rename utils/wasi-nn/{download_openvino_fixtures.sh => download-openvino-fixtures.sh} (100%) create mode 100755 utils/wasi-nn/install-openvino.sh diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 24b6e98e..bc6193a0 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -10,7 +10,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) if(BACKEND MATCHES "OpenVINO") message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures") execute_process( - COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download_openvino_fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures + COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-openvino-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures RESULT_VARIABLE DOWNLOAD_ERROR OUTPUT_STRIP_TRAILING_WHITESPACE) file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/mobilenet.bin CHECKSUM_WEIGHT) diff --git a/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh index 8899ec77..719b3599 100755 --- a/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh +++ b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh @@ -2,22 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2022 Second State INC -if [[ ! -v "${OPENVINO_VERSION}" ]]; then - OPENVINO_VERSION="2021.4.582" -fi -if [[ ! -v "${OPENVINO_YEAR}" ]]; then - OPENVINO_YEAR="2021" -fi -if [[ ! -v "${CMAKE_BUILD_TYPE}" ]]; then - CMAKE_BUILD_TYPE=Release -fi - -set -e -echo "Installing OpenVINO with version ${OPENVINO_VERSION}" -curl -sSL https://apt.repos.intel.com/openvino/$OPENVINO_YEAR/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR | gpg --dearmor > /usr/share/keyrings/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR.gpg -echo "deb [signed-by=/usr/share/keyrings/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR.gpg] https://apt.repos.intel.com/openvino/$OPENVINO_YEAR all main" | tee /etc/apt/sources.list.d/intel-openvino-$OPENVINO_YEAR.list -apt update -apt install -y intel-openvino-runtime-ubuntu20-$OPENVINO_VERSION source /opt/intel/openvino_2021/bin/setupvars.sh ldconfig git config --global --add safe.directory $(pwd) diff --git a/utils/wasi-nn/download_openvino_fixtures.sh b/utils/wasi-nn/download-openvino-fixtures.sh similarity index 100% rename from utils/wasi-nn/download_openvino_fixtures.sh rename to utils/wasi-nn/download-openvino-fixtures.sh diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh new file mode 100755 index 00000000..7c13f261 --- /dev/null +++ b/utils/wasi-nn/install-openvino.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +if [[ ! -v "${OPENVINO_VERSION}" ]]; then + OPENVINO_VERSION="2021.4.582" +fi +if [[ ! -v "${OPENVINO_YEAR}" ]]; then + OPENVINO_YEAR="2021" +fi +if [[ ! -v "${CMAKE_BUILD_TYPE}" ]]; then + CMAKE_BUILD_TYPE=Release +fi + +set -e +echo "Installing OpenVINO with version ${OPENVINO_VERSION}" +curl -sSL https://apt.repos.intel.com/openvino/$OPENVINO_YEAR/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR | gpg --dearmor > /usr/share/keyrings/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR.gpg +echo "deb [signed-by=/usr/share/keyrings/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR.gpg] https://apt.repos.intel.com/openvino/$OPENVINO_YEAR all main" | tee /etc/apt/sources.list.d/intel-openvino-$OPENVINO_YEAR.list +apt update +apt install -y intel-openvino-runtime-ubuntu20-$OPENVINO_VERSION +source /opt/intel/openvino_2021/bin/setupvars.sh +ldconfig From 7cccd2cfa07563b4209ca2ceeff213ac37dd734c Mon Sep 17 00:00:00 2001 From: sonder-joker Date: Wed, 13 Jul 2022 13:37:03 +0800 Subject: [PATCH 057/623] [WASI] wasi-crypto host functions and definitions. 1. Scripts to fetch witx file and generate header. 2. Host module and host function definitions. 3. Export to CLI tool. Signed-off-by: sonder-joker --- thirdparty/wasi_crypto/api.hpp | 701 ++++++++++++++++++++++++++++++++ utils/docker/build-manylinux.sh | 56 ++- 2 files changed, 751 insertions(+), 6 deletions(-) create mode 100644 thirdparty/wasi_crypto/api.hpp diff --git a/thirdparty/wasi_crypto/api.hpp b/thirdparty/wasi_crypto/api.hpp new file mode 100644 index 00000000..115423af --- /dev/null +++ b/thirdparty/wasi_crypto/api.hpp @@ -0,0 +1,701 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +/** + * THIS FILE IS AUTO-GENERATED from the following files: + * proposal_kx.witx, proposal_asymmetric_common.witx, proposal_common.witx, proposal_signatures.witx, proposal_symmetric.witx, proposal_external_secrets.witx + * + * @file + * This file describes the [WASI] interface, consisting of functions, types, + * and defined values (macros). + * + * The interface described here is greatly inspired by [CloudABI]'s clean, + * thoughtfully-designed, capability-oriented, POSIX-style API. + * + * [CloudABI]: https://github.com/NuxiNL/cloudlibc + * [WASI]: https://github.com/WebAssembly/WASI/ + */ + +#pragma once + +#include +#include +#include + +using const_uint8_t_ptr = uint32_t; +using uint8_t_ptr = uint32_t; + +#define DEFINE_ENUM_OPERATORS(type) \ + inline constexpr type operator~(type a) noexcept { \ + return static_cast(~static_cast>(a)); \ + } \ + inline constexpr type operator|(type a, type b) noexcept { \ + return static_cast(static_cast>(a) | \ + static_cast>(b)); \ + } \ + inline constexpr type &operator|=(type &a, type b) noexcept { \ + a = a | b; \ + return a; \ + } \ + inline constexpr type operator&(type a, type b) noexcept { \ + return static_cast(static_cast>(a) & \ + static_cast>(b)); \ + } \ + inline constexpr type &operator&=(type &a, type b) noexcept { \ + a = a & b; \ + return a; \ + } + +static_assert(alignof(int8_t) == 1, "non-wasi data layout"); +static_assert(alignof(uint8_t) == 1, "non-wasi data layout"); +static_assert(alignof(int16_t) == 2, "non-wasi data layout"); +static_assert(alignof(uint16_t) == 2, "non-wasi data layout"); +static_assert(alignof(int32_t) == 4, "non-wasi data layout"); +static_assert(alignof(uint32_t) == 4, "non-wasi data layout"); +static_assert(alignof(int64_t) == 8, "non-wasi data layout"); +static_assert(alignof(uint64_t) == 8, "non-wasi data layout"); +static_assert(alignof(const_uint8_t_ptr) == 4, "non-wasi data layout"); +static_assert(alignof(uint8_t_ptr) == 4, "non-wasi data layout"); + +/** + * Error codes. + */ +enum __wasi_crypto_errno_e_t : uint16_t { + /** + * Operation succeeded. + */ + __WASI_CRYPTO_ERRNO_SUCCESS = 0, + + /** + * An error occurred when trying to during a conversion from a host type to a guest type. + * + * Only an internal bug can throw this error. + */ + __WASI_CRYPTO_ERRNO_GUEST_ERROR = 1, + + /** + * The requested operation is valid, but not implemented by the host. + */ + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED = 2, + + /** + * The requested feature is not supported by the chosen algorithm. + */ + __WASI_CRYPTO_ERRNO_UNSUPPORTED_FEATURE = 3, + + /** + * The requested operation is valid, but was administratively prohibited. + */ + __WASI_CRYPTO_ERRNO_PROHIBITED_OPERATION = 4, + + /** + * Unsupported encoding for an import or export operation. + */ + __WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING = 5, + + /** + * The requested algorithm is not supported by the host. + */ + __WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM = 6, + + /** + * The requested option is not supported by the currently selected algorithm. + */ + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION = 7, + + /** + * An invalid or incompatible key was supplied. + * + * The key may not be valid, or was generated for a different algorithm or parameters set. + */ + __WASI_CRYPTO_ERRNO_INVALID_KEY = 8, + + /** + * The currently selected algorithm doesn't support the requested output length. + * + * This error is thrown by non-extensible hash functions, when requesting an output size larger than they produce out of a single block. + */ + __WASI_CRYPTO_ERRNO_INVALID_LENGTH = 9, + + /** + * A signature or authentication tag verification failed. + */ + __WASI_CRYPTO_ERRNO_VERIFICATION_FAILED = 10, + + /** + * A secure random numbers generator is not available. + * + * The requested operation requires random numbers, but the host cannot securely generate them at the moment. + */ + __WASI_CRYPTO_ERRNO_RNG_ERROR = 11, + + /** + * An error was returned by the underlying cryptography library. + * + * The host may be running out of memory, parameters may be incompatible with the chosen implementation of an algorithm or another unexpected error may have happened. + * + * Ideally, the specification should provide enough details and guidance to make this error impossible to ever be thrown. + * + * Realistically, the WASI crypto module cannot possibly cover all possible error types implementations can return, especially since some of these may be language-specific. + * This error can thus be thrown when other error types are not suitable, and when the original error comes from the cryptographic primitives themselves and not from the WASI module. + */ + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE = 12, + + /** + * The supplied signature is invalid, or incompatible with the chosen algorithm. + */ + __WASI_CRYPTO_ERRNO_INVALID_SIGNATURE = 13, + + /** + * An attempt was made to close a handle that was already closed. + */ + __WASI_CRYPTO_ERRNO_CLOSED = 14, + + /** + * A function was called with an unassigned handle, a closed handle, or handle of an unexpected type. + */ + __WASI_CRYPTO_ERRNO_INVALID_HANDLE = 15, + + /** + * The host needs to copy data to a guest-allocated buffer, but that buffer is too small. + */ + __WASI_CRYPTO_ERRNO_OVERFLOW = 16, + + /** + * An internal error occurred. + * + * This error is reserved to internal consistency checks, and must only be sent if the internal state of the host remains safe after an inconsistency was detected. + */ + __WASI_CRYPTO_ERRNO_INTERNAL_ERROR = 17, + + /** + * Too many handles are currently open, and a new one cannot be created. + * + * Implementations are free to represent handles as they want, and to enforce limits to limit resources usage. + */ + __WASI_CRYPTO_ERRNO_TOO_MANY_HANDLES = 18, + + /** + * A key was provided, but the chosen algorithm doesn't support keys. + * + * This is returned by symmetric operations. + * + * Many hash functions, in particular, do not support keys without being used in particular constructions. + * Blindly ignoring a key provided by mistake while trying to open a context for such as function could cause serious security vulnerabilities. + * + * These functions must refuse to create the context and return this error instead. + */ + __WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED = 19, + + /** + * A key is required for the chosen algorithm, but none was given. + */ + __WASI_CRYPTO_ERRNO_KEY_REQUIRED = 20, + + /** + * The provided authentication tag is invalid or incompatible with the current algorithm. + * + * This error is returned by decryption functions and tag verification functions. + * + * Unlike `verification_failed`, this error code is returned when the tag cannot possibly verify for any input. + */ + __WASI_CRYPTO_ERRNO_INVALID_TAG = 21, + + /** + * The requested operation is incompatible with the current scheme. + * + * For example, the `symmetric_state_encrypt()` function cannot complete if the selected construction is a key derivation function. + * This error code will be returned instead. + */ + __WASI_CRYPTO_ERRNO_INVALID_OPERATION = 22, + + /** + * A nonce is required. + * + * Most encryption schemes require a nonce. + * + * In the absence of a nonce, the WASI cryptography module can automatically generate one, if that can be done safely. The nonce can be retrieved later with the `symmetric_state_option_get()` function using the `nonce` parameter. + * If automatically generating a nonce cannot be done safely, the module never falls back to an insecure option and requests an explicit nonce by throwing that error. + */ + __WASI_CRYPTO_ERRNO_NONCE_REQUIRED = 23, + + /** + * The provided nonce doesn't have a correct size for the given cipher. + */ + __WASI_CRYPTO_ERRNO_INVALID_NONCE = 24, + + /** + * The named option was not set. + * + * The caller tried to read the value of an option that was not set. + * This error is used to make the distinction between an empty option, and an option that was not set and left to its default value. + */ + __WASI_CRYPTO_ERRNO_OPTION_NOT_SET = 25, + + /** + * A key or key pair matching the requested identifier cannot be found using the supplied information. + * + * This error is returned by a secrets manager via the `keypair_from_id()` function. + */ + __WASI_CRYPTO_ERRNO_NOT_FOUND = 26, + + /** + * The algorithm requires parameters that haven't been set. + * + * Non-generic options are required and must be given by building an `options` set and giving that object to functions instantiating that algorithm. + */ + __WASI_CRYPTO_ERRNO_PARAMETERS_MISSING = 27, + + /** + * A requested computation is not done yet, and additional calls to the function are required. + * + * Some functions, such as functions generating key pairs and password stretching functions, can take a long time to complete. + * + * In order to avoid a host call to be blocked for too long, these functions can return prematurely, requiring additional calls with the same parameters until they complete. + */ + __WASI_CRYPTO_ERRNO_IN_PROGRESS = 28, + + /** + * Multiple keys have been provided, but they do not share the same type. + * + * This error is returned when trying to build a key pair from a public key and a secret key that were created for different and incompatible algorithms. + */ + __WASI_CRYPTO_ERRNO_INCOMPATIBLE_KEYS = 29, + + /** + * A managed key or secret expired and cannot be used any more. + */ + __WASI_CRYPTO_ERRNO_EXPIRED = 30, + +}; +static_assert(sizeof(__wasi_crypto_errno_e_t) == 2, "witx calculated size"); +static_assert(alignof(__wasi_crypto_errno_e_t) == 2, "witx calculated align"); + +/** + * Encoding to use for importing or exporting a key pair. + */ +enum __wasi_keypair_encoding_e_t : uint16_t { + /** + * Raw bytes. + */ + __WASI_KEYPAIR_ENCODING_RAW = 0, + + /** + * PCSK8/DER encoding. + */ + __WASI_KEYPAIR_ENCODING_PKCS8 = 1, + + /** + * PEM encoding. + */ + __WASI_KEYPAIR_ENCODING_PEM = 2, + + /** + * PCSK8/DER encoding with compressed coordinates. + */ + __WASI_KEYPAIR_ENCODING_COMPRESSED_PKCS8 = 3, + + /** + * PEM encoding with compressed coordinates. + */ + __WASI_KEYPAIR_ENCODING_COMPRESSED_PEM = 4, + + /** + * Implementation-defined encoding. + */ + __WASI_KEYPAIR_ENCODING_LOCAL = 5, + +}; +static_assert(sizeof(__wasi_keypair_encoding_e_t) == 2, "witx calculated size"); +static_assert(alignof(__wasi_keypair_encoding_e_t) == 2, "witx calculated align"); + +/** + * Encoding to use for importing or exporting a public key. + */ +enum __wasi_publickey_encoding_e_t : uint16_t { + /** + * Raw bytes. + */ + __WASI_PUBLICKEY_ENCODING_RAW = 0, + + /** + * PKCS8/DER encoding. + */ + __WASI_PUBLICKEY_ENCODING_PKCS8 = 1, + + /** + * PEM encoding. + */ + __WASI_PUBLICKEY_ENCODING_PEM = 2, + + /** + * SEC-1 encoding. + */ + __WASI_PUBLICKEY_ENCODING_SEC = 3, + + /** + * Compressed SEC-1 encoding. + */ + __WASI_PUBLICKEY_ENCODING_COMPRESSED_SEC = 4, + + /** + * PKCS8/DER encoding with compressed coordinates. + */ + __WASI_PUBLICKEY_ENCODING_COMPRESSED_PKCS8 = 5, + + /** + * PEM encoding with compressed coordinates. + */ + __WASI_PUBLICKEY_ENCODING_COMPRESSED_PEM = 6, + + /** + * Implementation-defined encoding. + */ + __WASI_PUBLICKEY_ENCODING_LOCAL = 7, + +}; +static_assert(sizeof(__wasi_publickey_encoding_e_t) == 2, "witx calculated size"); +static_assert(alignof(__wasi_publickey_encoding_e_t) == 2, "witx calculated align"); + +/** + * Encoding to use for importing or exporting a secret key. + */ +enum __wasi_secretkey_encoding_e_t : uint16_t { + /** + * Raw bytes. + */ + __WASI_SECRETKEY_ENCODING_RAW = 0, + + /** + * PKCS8/DER encoding. + */ + __WASI_SECRETKEY_ENCODING_PKCS8 = 1, + + /** + * PEM encoding. + */ + __WASI_SECRETKEY_ENCODING_PEM = 2, + + /** + * SEC-1 encoding. + */ + __WASI_SECRETKEY_ENCODING_SEC = 3, + + /** + * Implementation-defined encoding. + */ + __WASI_SECRETKEY_ENCODING_LOCAL = 4, + +}; +static_assert(sizeof(__wasi_secretkey_encoding_e_t) == 2, "witx calculated size"); +static_assert(alignof(__wasi_secretkey_encoding_e_t) == 2, "witx calculated align"); + +/** + * Encoding to use for importing or exporting a signature. + */ +enum __wasi_signature_encoding_e_t : uint16_t { + /** + * Raw bytes. + */ + __WASI_SIGNATURE_ENCODING_RAW = 0, + + /** + * DER encoding. + */ + __WASI_SIGNATURE_ENCODING_DER = 1, + +}; +static_assert(sizeof(__wasi_signature_encoding_e_t) == 2, "witx calculated size"); +static_assert(alignof(__wasi_signature_encoding_e_t) == 2, "witx calculated align"); + +/** + * An algorithm category. + */ +enum __wasi_algorithm_type_e_t : uint16_t { + __WASI_ALGORITHM_TYPE_SIGNATURES = 0, + + __WASI_ALGORITHM_TYPE_SYMMETRIC = 1, + + __WASI_ALGORITHM_TYPE_KEY_EXCHANGE = 2, + +}; +static_assert(sizeof(__wasi_algorithm_type_e_t) == 2, "witx calculated size"); +static_assert(alignof(__wasi_algorithm_type_e_t) == 2, "witx calculated align"); + +/** + * Version of a managed key. + * + * A version can be an arbitrary `u64` integer, with the expection of some reserved values. + */ +using __wasi_version_t = uint64_t; + +static_assert(sizeof(__wasi_version_t) == 8, "witx calculated size"); +static_assert(alignof(__wasi_version_t) == 8, "witx calculated align"); + +/** + * Size of a value. + */ +using __wasi_size_t = uint32_t; + +static_assert(sizeof(__wasi_size_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_size_t) == 4, "witx calculated align"); + +/** + * A UNIX timestamp, in seconds since 01/01/1970. + */ +using __wasi_timestamp_t = uint64_t; + +static_assert(sizeof(__wasi_timestamp_t) == 8, "witx calculated size"); +static_assert(alignof(__wasi_timestamp_t) == 8, "witx calculated align"); + +/** + * Handle for functions returning output whose size may be large or not known in advance. + * + * An `array_output` object contains a host-allocated byte array. + * + * A guest can get the size of that array after a function returns in order to then allocate a buffer of the correct size. + * In addition, the content of such an object can be consumed by a guest in a streaming fashion. + * + * An `array_output` handle is automatically closed after its full content has been consumed. + */ +using __wasi_array_output_t = int32_t; + +static_assert(sizeof(__wasi_array_output_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_array_output_t) == 4, "witx calculated align"); + +/** + * A set of options. + * + * This type is used to set non-default parameters. + * + * The exact set of allowed options depends on the algorithm being used. + */ +using __wasi_options_t = int32_t; + +static_assert(sizeof(__wasi_options_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_options_t) == 4, "witx calculated align"); + +/** + * A handle to the optional secrets management facilities offered by a host. + * + * This is used to generate, retrieve and invalidate managed keys. + */ +using __wasi_secrets_manager_t = int32_t; + +static_assert(sizeof(__wasi_secrets_manager_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_secrets_manager_t) == 4, "witx calculated align"); + +/** + * A key pair. + */ +using __wasi_keypair_t = int32_t; + +static_assert(sizeof(__wasi_keypair_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_keypair_t) == 4, "witx calculated align"); + +/** + * A state to absorb data to be signed. + * + * After a signature has been computed or verified, the state remains valid for further operations. + * + * A subsequent signature would sign all the data accumulated since the creation of the state object. + */ +using __wasi_signature_state_t = int32_t; + +static_assert(sizeof(__wasi_signature_state_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_signature_state_t) == 4, "witx calculated align"); + +/** + * A signature. + */ +using __wasi_signature_t = int32_t; + +static_assert(sizeof(__wasi_signature_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_signature_t) == 4, "witx calculated align"); + +/** + * A public key, for key exchange and signature verification. + */ +using __wasi_publickey_t = int32_t; + +static_assert(sizeof(__wasi_publickey_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_publickey_t) == 4, "witx calculated align"); + +/** + * A secret key, for key exchange mechanisms. + */ +using __wasi_secretkey_t = int32_t; + +static_assert(sizeof(__wasi_secretkey_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_secretkey_t) == 4, "witx calculated align"); + +/** + * A state to absorb signed data to be verified. + */ +using __wasi_signature_verification_state_t = int32_t; + +static_assert(sizeof(__wasi_signature_verification_state_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_signature_verification_state_t) == 4, "witx calculated align"); + +/** + * A state to perform symmetric operations. + * + * The state is not reset nor invalidated after an option has been performed. + * Incremental updates and sessions are thus supported. + */ +using __wasi_symmetric_state_t = int32_t; + +static_assert(sizeof(__wasi_symmetric_state_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_symmetric_state_t) == 4, "witx calculated align"); + +/** + * A symmetric key. + * + * The key can be imported from raw bytes, or can be a reference to a managed key. + * + * If it was imported, the host will wipe it from memory as soon as the handle is closed. + */ +using __wasi_symmetric_key_t = int32_t; + +static_assert(sizeof(__wasi_symmetric_key_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_symmetric_key_t) == 4, "witx calculated align"); + +/** + * An authentication tag. + * + * This is an object returned by functions computing authentication tags. + * + * A tag can be compared against another tag (directly supplied as raw bytes) in constant time with the `symmetric_tag_verify()` function. + * + * This object type can't be directly created from raw bytes. They are only returned by functions computing MACs. + * + * The host is reponsible for securely wiping them from memory on close. + */ +using __wasi_symmetric_tag_t = int32_t; + +static_assert(sizeof(__wasi_symmetric_tag_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_symmetric_tag_t) == 4, "witx calculated align"); + +/** + * Options index, only required by the Interface Types translation layer. + */ +enum __wasi_opt_options_u_e_t : uint8_t { + __WASI_OPT_OPTIONS_U_SOME = 0, + + __WASI_OPT_OPTIONS_U_NONE = 1, + +}; +static_assert(sizeof(__wasi_opt_options_u_e_t) == 1, "witx calculated size"); +static_assert(alignof(__wasi_opt_options_u_e_t) == 1, "witx calculated align"); + +/** + * An optional options set. + * + * This union simulates an `Option` type to make the `options` parameter of some functions optional. + */ +union __wasi_opt_options_u_t { + __wasi_options_t some; +}; +struct __wasi_opt_options_t { + __wasi_opt_options_u_e_t tag; + __wasi_opt_options_u_t u; +}; + +static_assert(sizeof(__wasi_opt_options_t) == 8, "witx calculated size"); +static_assert(alignof(__wasi_opt_options_t) == 4, "witx calculated align"); +static_assert(offsetof(__wasi_opt_options_t, u) == 4, "witx calculated union offset"); + +/** + * Symmetric key index, only required by the Interface Types translation layer. + */ +enum __wasi_opt_symmetric_key_u_e_t : uint8_t { + __WASI_OPT_SYMMETRIC_KEY_U_SOME = 0, + + __WASI_OPT_SYMMETRIC_KEY_U_NONE = 1, + +}; +static_assert(sizeof(__wasi_opt_symmetric_key_u_e_t) == 1, "witx calculated size"); +static_assert(alignof(__wasi_opt_symmetric_key_u_e_t) == 1, "witx calculated align"); + +/** + * An optional symmetric key. + * + * This union simulates an `Option` type to make the `symmetric_key` parameter of some functions optional. + */ +union __wasi_opt_symmetric_key_u_t { + __wasi_symmetric_key_t some; +}; +struct __wasi_opt_symmetric_key_t { + __wasi_opt_symmetric_key_u_e_t tag; + __wasi_opt_symmetric_key_u_t u; +}; + +static_assert(sizeof(__wasi_opt_symmetric_key_t) == 8, "witx calculated size"); +static_assert(alignof(__wasi_opt_symmetric_key_t) == 4, "witx calculated align"); +static_assert(offsetof(__wasi_opt_symmetric_key_t, u) == 4, "witx calculated union offset"); + +using __wasi_u64_t = uint64_t; + +static_assert(sizeof(__wasi_u64_t) == 8, "witx calculated size"); +static_assert(alignof(__wasi_u64_t) == 8, "witx calculated align"); + +/** + * `$kx_keypair` is just an alias for `$keypair` + * + * However, bindings may want to define a specialized type `kx_keypair` as a super class of `keypair`. + */ +using __wasi_kx_keypair_t = __wasi_keypair_t; + +static_assert(sizeof(__wasi_kx_keypair_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_kx_keypair_t) == 4, "witx calculated align"); + +/** + * `$kx_publickey` is just an alias for `$publickey` + * + * However, bindings may want to define a specialized type `kx_publickey` as a super class of `publickey`, with additional methods such as `dh`. + */ +using __wasi_kx_publickey_t = __wasi_publickey_t; + +static_assert(sizeof(__wasi_kx_publickey_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_kx_publickey_t) == 4, "witx calculated align"); + +/** + * `$kx_secretkey` is just an alias for `$secretkey` + * + * However, bindings may want to define a specialized type `kx_secretkey` as a super class of `secretkeykey`, with additional methods such as `dh`. + */ +using __wasi_kx_secretkey_t = __wasi_secretkey_t; + +static_assert(sizeof(__wasi_kx_secretkey_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_kx_secretkey_t) == 4, "witx calculated align"); + +/** + * `$signature_keypair` is just an alias for `$keypair` + * + * However, bindings may want to define a specialized type `signature_keypair` as a super class of `keypair`, with additional methods such as `sign`. + */ +using __wasi_signature_keypair_t = __wasi_keypair_t; + +static_assert(sizeof(__wasi_signature_keypair_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_signature_keypair_t) == 4, "witx calculated align"); + +/** + * `$signature_publickey` is just an alias for `$publickey` + * + * However, bindings may want to define a specialized type `signature_publickey` as a super class of `publickey`, with additional methods such as `verify`. + */ +using __wasi_signature_publickey_t = __wasi_publickey_t; + +static_assert(sizeof(__wasi_signature_publickey_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_signature_publickey_t) == 4, "witx calculated align"); + +/** + * `$signature_secretkey` is just an alias for `$secretkey` + * + * However, bindings may want to define a specialized type `signature_secretkey` as a super class of `secretkey`. + */ +using __wasi_signature_secretkey_t = __wasi_secretkey_t; + +static_assert(sizeof(__wasi_signature_secretkey_t) == 4, "witx calculated size"); +static_assert(alignof(__wasi_signature_secretkey_t) == 4, "witx calculated align"); + diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index db29bb15..e3fc59bb 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -2,11 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2022 Second State INC -if [[ "$#" -lt 1 ]]; then - export CMAKE_BUILD_TYPE=Release -else - export CMAKE_BUILD_TYPE=$1 -fi export PATH="/toolchain/bin:$PATH" export CC=gcc export CXX=g++ @@ -16,7 +11,56 @@ curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/relea echo "475d589d51a7f8b3ba2ba4eda022b170e562ca3b760ee922c146b6c65856ef39 boost_1_79_0.tar.bz2" | sha256sum -c git config --global --add safe.directory $(pwd) bzip2 -dc boost_1_79_0.tar.bz2 | tar -xf - -if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_79_0/; then + +CMAKE_BUILD_TYPE="Release" +WASMEDGE_BUILD_WASI_CRYPTO="OFF" + +for i in "$@"; do + case $i in + -DCMAKE_BUILD_TYPE=*) + CMAKE_BUILD_TYPE="${i#*=}" + shift + ;; + -DWASMEDGE_BUILD_WASI_CRYPTO=*) + WASMEDGE_BUILD_WASI_CRYPTO=$(echo ${i#*=} | tr '[:lower:]' '[:upper:]') + shift + ;; + *) + ;; + esac +done + +CMAKE_OPTS="-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}" +if [ ${WASMEDGE_BUILD_WASI_CRYPTO} == "ON" ]; then + echo "Building wasi-crypto..." + # install openssl + curl -s -L -O --remote-name-all https://www.openssl.org/source/openssl-1.1.1n.tar.gz + echo "40dceb51a4f6a5275bde0e6bf20ef4b91bfc32ed57c0552e2e8e15463372b17a openssl-1.1.1n.tar.gz" | sha256sum -c + tar -xf openssl-1.1.1n.tar.gz + cd ./openssl-1.1.1n + # openssl configure need newer perl + curl -s -L -O --remote-name-all https://www.cpan.org/src/5.0/perl-5.34.0.tar.gz + tar -xf perl-5.34.0.tar.gz + cd perl-5.34.0 + mkdir localperl + ./Configure -des -Dprefix=$(pwd)/localperl/ + make -j + # too long! + # make test + make install + export PATH="$(pwd)/localperl/bin/:$PATH" + cd .. + # configure by previous perl + mkdir openssl + ./perl-5.34.0/localperl/bin/perl ./config --prefix=$(pwd)/openssl --openssldir=$(pwd)/openssl + make -j + make test + make install + cd .. + CMAKE_OPTS="${CMAKE_OPTS} -DWASMEDGE_BUILD_WASI_CRYPTO=ON -DOPENSSL_ROOT_DIR=$(pwd)/openssl-1.1.1n/openssl" +fi + +if ! cmake -Bbuild -GNinja ${CMAKE_OPTS} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_79_0/; then echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log echo === CMakeError.log === From e49cee83d45623a81f543bc437ff22a9179f3600 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 14 Jul 2022 13:39:11 +0800 Subject: [PATCH 058/623] [Misc] Refactor the WASI-Crypto to plugin. Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 8 +- plugins/wasi_crypto/CMakeLists.txt | 75 + plugins/wasi_crypto/asymmetric_common/ctx.cpp | 197 +++ plugins/wasi_crypto/asymmetric_common/ecdsa.h | 408 +++++ .../wasi_crypto/asymmetric_common/func.cpp | 517 ++++++ plugins/wasi_crypto/asymmetric_common/func.h | 188 +++ .../wasi_crypto/asymmetric_common/keypair.cpp | 123 ++ .../wasi_crypto/asymmetric_common/keypair.h | 52 + .../asymmetric_common/publickey.cpp | 37 + .../wasi_crypto/asymmetric_common/publickey.h | 41 + .../wasi_crypto/asymmetric_common/registed.h | 48 + .../asymmetric_common/secretkey.cpp | 42 + .../wasi_crypto/asymmetric_common/secretkey.h | 42 + plugins/wasi_crypto/common/array_output.cpp | 32 + plugins/wasi_crypto/common/array_output.h | 69 + plugins/wasi_crypto/common/ctx.cpp | 87 + plugins/wasi_crypto/common/func.cpp | 215 +++ plugins/wasi_crypto/common/func.h | 103 ++ plugins/wasi_crypto/common/options.cpp | 55 + plugins/wasi_crypto/common/options.h | 62 + plugins/wasi_crypto/ctx.cpp | 38 + plugins/wasi_crypto/ctx.h | 344 ++++ plugins/wasi_crypto/kx/ctx.cpp | 69 + plugins/wasi_crypto/kx/dh/ecdsa.cpp | 39 + plugins/wasi_crypto/kx/dh/ecdsa.h | 64 + plugins/wasi_crypto/kx/dh/x25519.cpp | 181 +++ plugins/wasi_crypto/kx/dh/x25519.h | 109 ++ plugins/wasi_crypto/kx/func.cpp | 81 + plugins/wasi_crypto/kx/func.h | 52 + plugins/wasi_crypto/kx/kx.cpp | 55 + plugins/wasi_crypto/kx/kx.h | 57 + plugins/wasi_crypto/kx/options.cpp | 37 + plugins/wasi_crypto/kx/options.h | 47 + plugins/wasi_crypto/kx/registed.h | 42 + plugins/wasi_crypto/module.cpp | 161 ++ plugins/wasi_crypto/module.h | 35 + plugins/wasi_crypto/signatures/ctx.cpp | 128 ++ plugins/wasi_crypto/signatures/ecdsa.cpp | 125 ++ plugins/wasi_crypto/signatures/ecdsa.h | 118 ++ plugins/wasi_crypto/signatures/eddsa.cpp | 250 +++ plugins/wasi_crypto/signatures/eddsa.h | 165 ++ plugins/wasi_crypto/signatures/func.cpp | 225 +++ plugins/wasi_crypto/signatures/func.h | 113 ++ plugins/wasi_crypto/signatures/options.cpp | 37 + plugins/wasi_crypto/signatures/options.h | 46 + plugins/wasi_crypto/signatures/registed.h | 52 + plugins/wasi_crypto/signatures/rsa.cpp | 375 +++++ plugins/wasi_crypto/signatures/rsa.h | 224 +++ plugins/wasi_crypto/signatures/signatures.cpp | 35 + plugins/wasi_crypto/signatures/signatures.h | 38 + plugins/wasi_crypto/signatures/signstate.cpp | 43 + plugins/wasi_crypto/signatures/signstate.h | 47 + .../signatures/verificationstate.cpp | 71 + .../signatures/verificationstate.h | 47 + plugins/wasi_crypto/symmetric/aeads.cpp | 231 +++ plugins/wasi_crypto/symmetric/aeads.h | 184 +++ plugins/wasi_crypto/symmetric/ctx.cpp | 312 ++++ plugins/wasi_crypto/symmetric/func.cpp | 674 ++++++++ plugins/wasi_crypto/symmetric/func.h | 253 +++ plugins/wasi_crypto/symmetric/hash.cpp | 94 ++ plugins/wasi_crypto/symmetric/hash.h | 139 ++ plugins/wasi_crypto/symmetric/kdf.cpp | 156 ++ plugins/wasi_crypto/symmetric/kdf.h | 249 +++ plugins/wasi_crypto/symmetric/key.cpp | 44 + plugins/wasi_crypto/symmetric/key.h | 42 + plugins/wasi_crypto/symmetric/mac.cpp | 98 ++ plugins/wasi_crypto/symmetric/mac.h | 150 ++ plugins/wasi_crypto/symmetric/options.cpp | 115 ++ plugins/wasi_crypto/symmetric/options.h | 68 + plugins/wasi_crypto/symmetric/registed.h | 47 + plugins/wasi_crypto/symmetric/state.cpp | 232 +++ plugins/wasi_crypto/symmetric/state.h | 101 ++ plugins/wasi_crypto/symmetric/tag.cpp | 35 + plugins/wasi_crypto/symmetric/tag.h | 57 + plugins/wasi_crypto/utils/error.h | 51 + plugins/wasi_crypto/utils/evp_wrapper.cpp | 218 +++ plugins/wasi_crypto/utils/evp_wrapper.h | 139 ++ plugins/wasi_crypto/utils/handles_manager.h | 202 +++ plugins/wasi_crypto/utils/hostfunction.cpp | 164 ++ plugins/wasi_crypto/utils/hostfunction.h | 160 ++ plugins/wasi_crypto/utils/optional.h | 132 ++ plugins/wasi_crypto/utils/secret_vec.h | 79 + test/plugins/CMakeLists.txt | 5 +- test/plugins/wasi_crypto/CMakeLists.txt | 24 + test/plugins/wasi_crypto/aeads.cpp | 116 ++ test/plugins/wasi_crypto/asymmetric.cpp | 605 +++++++ test/plugins/wasi_crypto/common.cpp | 129 ++ test/plugins/wasi_crypto/hash.cpp | 120 ++ test/plugins/wasi_crypto/helper.cpp | 1436 +++++++++++++++++ test/plugins/wasi_crypto/helper.h | 361 +++++ test/plugins/wasi_crypto/kdf.cpp | 88 + test/plugins/wasi_crypto/kx.cpp | 128 ++ test/plugins/wasi_crypto/mac.cpp | 136 ++ test/plugins/wasi_crypto/notimplement.cpp | 47 + test/plugins/wasi_crypto/signatures.cpp | 107 ++ utils/docker/build-manylinux.sh | 10 +- 96 files changed, 13681 insertions(+), 8 deletions(-) create mode 100644 plugins/wasi_crypto/CMakeLists.txt create mode 100644 plugins/wasi_crypto/asymmetric_common/ctx.cpp create mode 100644 plugins/wasi_crypto/asymmetric_common/ecdsa.h create mode 100644 plugins/wasi_crypto/asymmetric_common/func.cpp create mode 100644 plugins/wasi_crypto/asymmetric_common/func.h create mode 100644 plugins/wasi_crypto/asymmetric_common/keypair.cpp create mode 100644 plugins/wasi_crypto/asymmetric_common/keypair.h create mode 100644 plugins/wasi_crypto/asymmetric_common/publickey.cpp create mode 100644 plugins/wasi_crypto/asymmetric_common/publickey.h create mode 100644 plugins/wasi_crypto/asymmetric_common/registed.h create mode 100644 plugins/wasi_crypto/asymmetric_common/secretkey.cpp create mode 100644 plugins/wasi_crypto/asymmetric_common/secretkey.h create mode 100644 plugins/wasi_crypto/common/array_output.cpp create mode 100644 plugins/wasi_crypto/common/array_output.h create mode 100644 plugins/wasi_crypto/common/ctx.cpp create mode 100644 plugins/wasi_crypto/common/func.cpp create mode 100644 plugins/wasi_crypto/common/func.h create mode 100644 plugins/wasi_crypto/common/options.cpp create mode 100644 plugins/wasi_crypto/common/options.h create mode 100644 plugins/wasi_crypto/ctx.cpp create mode 100644 plugins/wasi_crypto/ctx.h create mode 100644 plugins/wasi_crypto/kx/ctx.cpp create mode 100644 plugins/wasi_crypto/kx/dh/ecdsa.cpp create mode 100644 plugins/wasi_crypto/kx/dh/ecdsa.h create mode 100644 plugins/wasi_crypto/kx/dh/x25519.cpp create mode 100644 plugins/wasi_crypto/kx/dh/x25519.h create mode 100644 plugins/wasi_crypto/kx/func.cpp create mode 100644 plugins/wasi_crypto/kx/func.h create mode 100644 plugins/wasi_crypto/kx/kx.cpp create mode 100644 plugins/wasi_crypto/kx/kx.h create mode 100644 plugins/wasi_crypto/kx/options.cpp create mode 100644 plugins/wasi_crypto/kx/options.h create mode 100644 plugins/wasi_crypto/kx/registed.h create mode 100644 plugins/wasi_crypto/module.cpp create mode 100644 plugins/wasi_crypto/module.h create mode 100644 plugins/wasi_crypto/signatures/ctx.cpp create mode 100644 plugins/wasi_crypto/signatures/ecdsa.cpp create mode 100644 plugins/wasi_crypto/signatures/ecdsa.h create mode 100644 plugins/wasi_crypto/signatures/eddsa.cpp create mode 100644 plugins/wasi_crypto/signatures/eddsa.h create mode 100644 plugins/wasi_crypto/signatures/func.cpp create mode 100644 plugins/wasi_crypto/signatures/func.h create mode 100644 plugins/wasi_crypto/signatures/options.cpp create mode 100644 plugins/wasi_crypto/signatures/options.h create mode 100644 plugins/wasi_crypto/signatures/registed.h create mode 100644 plugins/wasi_crypto/signatures/rsa.cpp create mode 100644 plugins/wasi_crypto/signatures/rsa.h create mode 100644 plugins/wasi_crypto/signatures/signatures.cpp create mode 100644 plugins/wasi_crypto/signatures/signatures.h create mode 100644 plugins/wasi_crypto/signatures/signstate.cpp create mode 100644 plugins/wasi_crypto/signatures/signstate.h create mode 100644 plugins/wasi_crypto/signatures/verificationstate.cpp create mode 100644 plugins/wasi_crypto/signatures/verificationstate.h create mode 100644 plugins/wasi_crypto/symmetric/aeads.cpp create mode 100644 plugins/wasi_crypto/symmetric/aeads.h create mode 100644 plugins/wasi_crypto/symmetric/ctx.cpp create mode 100644 plugins/wasi_crypto/symmetric/func.cpp create mode 100644 plugins/wasi_crypto/symmetric/func.h create mode 100644 plugins/wasi_crypto/symmetric/hash.cpp create mode 100644 plugins/wasi_crypto/symmetric/hash.h create mode 100644 plugins/wasi_crypto/symmetric/kdf.cpp create mode 100644 plugins/wasi_crypto/symmetric/kdf.h create mode 100644 plugins/wasi_crypto/symmetric/key.cpp create mode 100644 plugins/wasi_crypto/symmetric/key.h create mode 100644 plugins/wasi_crypto/symmetric/mac.cpp create mode 100644 plugins/wasi_crypto/symmetric/mac.h create mode 100644 plugins/wasi_crypto/symmetric/options.cpp create mode 100644 plugins/wasi_crypto/symmetric/options.h create mode 100644 plugins/wasi_crypto/symmetric/registed.h create mode 100644 plugins/wasi_crypto/symmetric/state.cpp create mode 100644 plugins/wasi_crypto/symmetric/state.h create mode 100644 plugins/wasi_crypto/symmetric/tag.cpp create mode 100644 plugins/wasi_crypto/symmetric/tag.h create mode 100644 plugins/wasi_crypto/utils/error.h create mode 100644 plugins/wasi_crypto/utils/evp_wrapper.cpp create mode 100644 plugins/wasi_crypto/utils/evp_wrapper.h create mode 100644 plugins/wasi_crypto/utils/handles_manager.h create mode 100644 plugins/wasi_crypto/utils/hostfunction.cpp create mode 100644 plugins/wasi_crypto/utils/hostfunction.h create mode 100644 plugins/wasi_crypto/utils/optional.h create mode 100644 plugins/wasi_crypto/utils/secret_vec.h create mode 100644 test/plugins/wasi_crypto/CMakeLists.txt create mode 100644 test/plugins/wasi_crypto/aeads.cpp create mode 100644 test/plugins/wasi_crypto/asymmetric.cpp create mode 100644 test/plugins/wasi_crypto/common.cpp create mode 100644 test/plugins/wasi_crypto/hash.cpp create mode 100644 test/plugins/wasi_crypto/helper.cpp create mode 100644 test/plugins/wasi_crypto/helper.h create mode 100644 test/plugins/wasi_crypto/kdf.cpp create mode 100644 test/plugins/wasi_crypto/kx.cpp create mode 100644 test/plugins/wasi_crypto/mac.cpp create mode 100644 test/plugins/wasi_crypto/notimplement.cpp create mode 100644 test/plugins/wasi_crypto/signatures.cpp diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index d87e230f..e9244d3d 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -2,10 +2,14 @@ # SPDX-FileCopyrightText: 2019-2022 Second State INC # Only Linux systems support wasmedge_process now. -if(CMAKE_SYSTEM_NAME MATCHES "Linux") +if (CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasmedge_process) endif() -if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) +if (WASMEDGE_PLUGIN_WASI_NN_BACKEND) add_subdirectory(wasi_nn) endif() + +if (WASMEDGE_PLUGIN_WASI_CRYPTO) + add_subdirectory(wasi_crypto) +endif() diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt new file mode 100644 index 00000000..dd9ecf19 --- /dev/null +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +find_package(OpenSSL REQUIRED) + +wasmedge_add_library(wasmedgePluginWasiCrypto + SHARED + module.cpp + ctx.cpp + asymmetric_common/ctx.cpp + asymmetric_common/func.cpp + asymmetric_common/keypair.cpp + asymmetric_common/publickey.cpp + asymmetric_common/secretkey.cpp + common/array_output.cpp + common/ctx.cpp + common/func.cpp + common/options.cpp + kx/ctx.cpp + kx/dh/ecdsa.cpp + kx/dh/x25519.cpp + kx/func.cpp + kx/kx.cpp + kx/options.cpp + signatures/ctx.cpp + signatures/ecdsa.cpp + signatures/eddsa.cpp + signatures/func.cpp + signatures/options.cpp + signatures/rsa.cpp + signatures/signatures.cpp + signatures/signstate.cpp + signatures/verificationstate.cpp + symmetric/aeads.cpp + symmetric/ctx.cpp + symmetric/func.cpp + symmetric/hash.cpp + symmetric/kdf.cpp + symmetric/key.cpp + symmetric/mac.cpp + symmetric/options.cpp + symmetric/state.cpp + symmetric/tag.cpp + utils/evp_wrapper.cpp + utils/hostfunction.cpp +) + +target_compile_options(wasmedgePluginWasiCrypto + PUBLIC + -DWASMEDGE_PLUGIN +) + +if(CMAKE_SYSTEM_NAME MATCHES "Darwin") + target_link_options(wasmedgePluginWasiCrypto + PUBLIC + -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterC1EPKNS0_6Plugin16PluginDescriptorE + -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterD1Ev + ) +endif() + +target_include_directories(wasmedgePluginWasiCrypto + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/thirdparty +) + +target_link_libraries(wasmedgePluginWasiCrypto + PUBLIC + OpenSSL::Crypto + wasmedgeCommon + wasmedgeSystem +) + +install(TARGETS wasmedgePluginWasiCrypto DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasi_crypto/asymmetric_common/ctx.cpp b/plugins/wasi_crypto/asymmetric_common/ctx.cpp new file mode 100644 index 00000000..551bd5cc --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/ctx.cpp @@ -0,0 +1,197 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "ctx.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +WasiCryptoExpect<__wasi_array_output_t> +Context::publickeyExport(__wasi_publickey_t PkHandle, + __wasi_publickey_encoding_e_t Encoding) noexcept { + return PublicKeyManager.get(PkHandle) + .and_then([Encoding](auto &&Pk) { + return AsymmetricCommon::pkExportData(std::forward(Pk), + Encoding); + }) + .and_then([this](auto &&Data) { + return ArrayOutputManager.registerManager( + std::forward(Data)); + }); +} + +WasiCryptoExpect +Context::publickeyVerify(__wasi_publickey_t PkHandle) noexcept { + return PublicKeyManager.get(PkHandle).and_then(AsymmetricCommon::pkVerify); +} + +WasiCryptoExpect +Context::publickeyClose(__wasi_publickey_t PkHandle) noexcept { + return PublicKeyManager.close(PkHandle); +} + +WasiCryptoExpect<__wasi_array_output_t> +Context::secretkeyExport(__wasi_secretkey_t SkHandle, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + return SecretKeyManager.get(SkHandle) + .and_then([Encoding](auto &&Sk) { + return AsymmetricCommon::skExportData(std::forward(Sk), + Encoding); + }) + .and_then([this](auto &&Data) noexcept { + return ArrayOutputManager.registerManager( + std::forward(Data)); + }); +} + +WasiCryptoExpect +Context::secretkeyClose(__wasi_secretkey_t SkHandle) noexcept { + return SecretKeyManager.close(SkHandle); +} + +WasiCryptoExpect<__wasi_publickey_t> +Context::publickeyFromSecretkey(__wasi_secretkey_t SkHandle) noexcept { + return SecretKeyManager.get(SkHandle) + .and_then(AsymmetricCommon::skPublicKey) + .and_then([this](auto &&Pk) noexcept { + return PublicKeyManager.registerManager(std::forward(Pk)); + }); +} + +WasiCryptoExpect<__wasi_array_output_t> +Context::keypairExport(__wasi_keypair_t KpHandle, + __wasi_keypair_encoding_e_t Encoding) noexcept { + return KeyPairManager.get(KpHandle) + .and_then([Encoding](auto &&Kp) noexcept { + return AsymmetricCommon::kpExportData(std::forward(Kp), + Encoding); + }) + .and_then([this](auto &&Data) noexcept { + return ArrayOutputManager.registerManager( + std::forward(Data)); + }); +} + +WasiCryptoExpect<__wasi_publickey_t> +Context::keypairPublickey(__wasi_keypair_t KpHandle) noexcept { + return KeyPairManager.get(KpHandle) + .and_then(AsymmetricCommon::kpPublicKey) + .and_then([this](auto &&Pk) noexcept { + return PublicKeyManager.registerManager(std::forward(Pk)); + }); +} + +WasiCryptoExpect<__wasi_secretkey_t> +Context::keypairSecretkey(__wasi_keypair_t KpHandle) noexcept { + return KeyPairManager.get(KpHandle) + .and_then(AsymmetricCommon::kpSecretKey) + .and_then([this](auto &&Sk) noexcept { + return SecretKeyManager.registerManager(std::forward(Sk)); + }); +} + +WasiCryptoExpect +Context::keypairClose(__wasi_keypair_t KpHandle) noexcept { + return KeyPairManager.close(KpHandle); +} + +WasiCryptoExpect<__wasi_keypair_t> +Context::keypairFromPkAndSk(__wasi_publickey_t PkHandle, + __wasi_secretkey_t SkHandle) noexcept { + auto Pk = PublicKeyManager.get(PkHandle); + if (!Pk) { + return WasiCryptoUnexpect(Pk); + } + + auto Sk = SecretKeyManager.get(SkHandle); + if (!Sk) { + return WasiCryptoUnexpect(Sk); + } + + return AsymmetricCommon::kpFromPkAndSk(*Pk, *Sk).and_then( + [this](auto &&Kp) noexcept { + return KeyPairManager.registerManager(std::forward(Kp)); + }); +} + +WasiCryptoExpect<__wasi_keypair_t> +Context::keypairGenerate(AsymmetricCommon::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept { + return mapAndTransposeOptional( + OptOptionsHandle, + [this](__wasi_options_t OptionsHandle) noexcept { + return OptionsManager.get(OptionsHandle); + }) + .and_then([Alg](auto &&OptOptions) noexcept { + return AsymmetricCommon::generateKp( + Alg, asOptionalRef(std::forward(OptOptions))); + }) + .and_then([this](auto &&Kp) noexcept { + return KeyPairManager.registerManager(std::forward(Kp)); + }); +} + +WasiCryptoExpect<__wasi_keypair_t> +Context::keypairImport(AsymmetricCommon::Algorithm Alg, + Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept { + return AsymmetricCommon::importKp(Alg, Encoded, Encoding) + .and_then([this](auto &&Kp) noexcept { + return KeyPairManager.registerManager(std::forward(Kp)); + }); +} + +WasiCryptoExpect<__wasi_publickey_t> +Context::publickeyImport(AsymmetricCommon::Algorithm Alg, + Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept { + return AsymmetricCommon::importPk(Alg, Encoded, Encoding) + .and_then([this](auto &&Pk) noexcept { + return PublicKeyManager.registerManager(std::forward(Pk)); + }); +} + +WasiCryptoExpect<__wasi_secretkey_t> +Context::secretkeyImport(AsymmetricCommon::Algorithm Alg, + Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + return AsymmetricCommon::importSk(Alg, Encoded, Encoding) + .and_then([this](auto &&Sk) noexcept { + return SecretKeyManager.registerManager(std::forward(Sk)); + }); +} + +WasiCryptoExpect<__wasi_keypair_t> +Context::keypairGenerateManaged(__wasi_secrets_manager_t, + AsymmetricCommon::Algorithm, + __wasi_opt_options_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect Context::keypairStoreManaged(__wasi_secrets_manager_t, + __wasi_keypair_t, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect<__wasi_version_t> +Context::keypairReplaceManaged(__wasi_secrets_manager_t, __wasi_keypair_t, + __wasi_keypair_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect> +Context::keypairId(__wasi_keypair_t, Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect<__wasi_keypair_t> +Context::keypairFromId(__wasi_secrets_manager_t, Span, + __wasi_version_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/ecdsa.h b/plugins/wasi_crypto/asymmetric_common/ecdsa.h new file mode 100644 index 00000000..2941c12c --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/ecdsa.h @@ -0,0 +1,408 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric_common/ecdsa.h - Ecdsa alg-===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of ecdsa algorithm. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "kx/options.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +template +class Ecdsa { + inline static const size_t UnCompressedPkSize = 65; + inline static const size_t CompressedPkSize = 33; + constexpr static size_t getRawPkSize(bool Compressed) { + return Compressed ? CompressedPkSize : UnCompressedPkSize; + } + + inline static const size_t SkSize = 32; + + constexpr static point_conversion_form_t getForm(bool Compressed) noexcept { + return Compressed ? POINT_CONVERSION_COMPRESSED + : POINT_CONVERSION_UNCOMPRESSED; + } + +public: + class PublicKeyBase { + public: + PublicKeyBase(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + PublicKeyBase(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_PKCS8: + return importPkcs8(Encoded, false); + case __WASI_PUBLICKEY_ENCODING_COMPRESSED_PKCS8: + return importPkcs8(Encoded, true); + case __WASI_PUBLICKEY_ENCODING_PEM: + return importPem(Encoded, false); + case __WASI_PUBLICKEY_ENCODING_COMPRESSED_PEM: + return importPem(Encoded, true); + case __WASI_PUBLICKEY_ENCODING_SEC: + return importSec(Encoded, false); + case __WASI_PUBLICKEY_ENCODING_COMPRESSED_SEC: + return importSec(Encoded, true); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } + } + + WasiCryptoExpect> + exportData(__wasi_publickey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_SEC: + return exportSec(false); + case __WASI_PUBLICKEY_ENCODING_COMPRESSED_SEC: + return exportSec(true); + case __WASI_PUBLICKEY_ENCODING_PEM: + return exportPem(false); + case __WASI_PUBLICKEY_ENCODING_COMPRESSED_PEM: + return exportPem(true); + case __WASI_PUBLICKEY_ENCODING_PKCS8: + return exportPkcs8(false); + case __WASI_PUBLICKEY_ENCODING_COMPRESSED_PKCS8: + return exportPkcs8(true); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } + } + + WasiCryptoExpect verify() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + } + + protected: + static WasiCryptoExpect importPkcs8(Span Encoded, + bool Compressed) noexcept { + return checkValid(EvpPkeyPtr{d2iPUBKEY(Encoded)}, Compressed); + } + + static WasiCryptoExpect importPem(Span Encoded, + bool Compressed) noexcept { + return checkValid(EvpPkeyPtr{pemReadPUBKEY(Encoded)}, Compressed); + } + + static WasiCryptoExpect importSec(Span Encoded, + bool Compressed) noexcept { + EcKeyPtr EcCtx{EC_KEY_new_by_curve_name(CurveNid)}; + EcPointPtr Pk{EC_POINT_new(EC_KEY_get0_group(EcCtx.get()))}; + ensureOrReturn(EC_POINT_oct2point(EC_KEY_get0_group(EcCtx.get()), + Pk.get(), Encoded.data(), + Encoded.size(), nullptr), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + opensslCheck(EC_KEY_set_public_key(EcCtx.get(), Pk.get())); + + EvpPkeyPtr Ctx{EVP_PKEY_new()}; + opensslCheck(EVP_PKEY_set1_EC_KEY(Ctx.get(), EcCtx.get())); + + return checkValid(std::move(Ctx), Compressed); + } + + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx, + bool) noexcept { + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); + ensureOrReturn(EcCtx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + const EC_GROUP *Group = EC_KEY_get0_group(EcCtx); + ensureOrReturn(Group, __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(EC_GROUP_get_curve_name(Group) == CurveNid, + __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {std::move(Ctx)}; + } + + WasiCryptoExpect> + exportSec(bool Compressed) const noexcept { + EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); + std::vector Res(getRawPkSize(Compressed)); + opensslCheck(EC_POINT_point2oct( + EC_KEY_get0_group(EcCtx), EC_KEY_get0_public_key(EcCtx), + getForm(Compressed), Res.data(), Res.size(), nullptr)); + return Res; + } + + WasiCryptoExpect> + exportPem(bool Compressed) const noexcept { + EC_KEY_set_conv_form(EVP_PKEY_get0_EC_KEY(Ctx.get()), + getForm(Compressed)); + + return pemWritePUBKEY(Ctx.get()); + } + + WasiCryptoExpect> + exportPkcs8(bool Compressed) const noexcept { + EC_KEY_set_conv_form(EVP_PKEY_get0_EC_KEY(Ctx.get()), + getForm(Compressed)); + + return i2dPUBKEY(Ctx.get()); + } + + SharedEvpPkey Ctx; + }; + + class SecretKeyBase { + public: + SecretKeyBase(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + SecretKeyBase(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_RAW: + return importRaw(Encoded); + case __WASI_SECRETKEY_ENCODING_PKCS8: + return importPkcs8(Encoded); + case __WASI_SECRETKEY_ENCODING_PEM: + return importPem(Encoded); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } + } + + WasiCryptoExpect + exportData(__wasi_secretkey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_RAW: + return exportRaw(); + case __WASI_SECRETKEY_ENCODING_PKCS8: + return exportPkcs8(); + case __WASI_SECRETKEY_ENCODING_PEM: + return exportPem(); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } + } + + WasiCryptoExpect publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; + } + + WasiCryptoExpect toKeyPair(const PublicKey &) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + } + + protected: + static WasiCryptoExpect + importPkcs8(Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPrivateKey(Encoded)}); + } + + static WasiCryptoExpect + importPem(Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{pemReadPrivateKey(Encoded)}); + } + + static WasiCryptoExpect + importRaw(Span Encoded) noexcept { + EcKeyPtr EcCtx{EC_KEY_new_by_curve_name(CurveNid)}; + ensureOrReturn(Encoded.size() == SkSize, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + BnPtr Sk{ + BN_bin2bn(Encoded.data(), static_cast(Encoded.size()), nullptr)}; + ensureOrReturn(EC_KEY_set_private_key(EcCtx.get(), Sk.get()), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + + EvpPkeyPtr Ctx{EVP_PKEY_new()}; + opensslCheck(EVP_PKEY_set1_EC_KEY(Ctx.get(), EcCtx.get())); + + return Ctx; + } + + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept { + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); + ensureOrReturn(EcCtx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + const EC_GROUP *Group = EC_KEY_get0_group(EcCtx); + ensureOrReturn(Group, __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(EC_GROUP_get_curve_name(Group) == CurveNid, + __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {std::move(Ctx)}; + } + + WasiCryptoExpect exportPkcs8() const noexcept { + return i2dPrivateKey(Ctx.get()); + } + + WasiCryptoExpect exportPem() const noexcept { + return pemWritePrivateKey(Ctx.get()); + } + + WasiCryptoExpect exportRaw() const noexcept { + // Must equal to SkSize, not check. + const BIGNUM *Sk = + EC_KEY_get0_private_key(EVP_PKEY_get0_EC_KEY(Ctx.get())); + SecretVec Res(SkSize); + opensslCheck(BN_bn2bin(Sk, Res.data())); + + return Res; + } + + SharedEvpPkey Ctx; + }; + + class KeyPairBase { + public: + KeyPairBase(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + KeyPairBase(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + generate(OptionalRef) noexcept { + EvpPkeyCtxPtr ParamCtx{EVP_PKEY_CTX_new_id(EVP_PKEY_EC, nullptr)}; + EVP_PKEY_keygen_init(ParamCtx.get()); + EVP_PKEY_CTX_set_ec_paramgen_curve_nid(ParamCtx.get(), CurveNid); + + EVP_PKEY *Key = nullptr; + opensslCheck(EVP_PKEY_keygen(ParamCtx.get(), &Key)); + + return EvpPkeyPtr{Key}; + } + + static WasiCryptoExpect + import(Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_RAW: + return importRaw(Encoded); + case __WASI_KEYPAIR_ENCODING_PKCS8: + return importPkcs8(Encoded, false); + case __WASI_KEYPAIR_ENCODING_PEM: + return importPem(Encoded, false); + case __WASI_KEYPAIR_ENCODING_COMPRESSED_PKCS8: + return importPkcs8(Encoded, true); + case __WASI_KEYPAIR_ENCODING_COMPRESSED_PEM: + return importPem(Encoded, true); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } + } + + WasiCryptoExpect publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; + } + + WasiCryptoExpect secretKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; + } + + WasiCryptoExpect + exportData(__wasi_keypair_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_RAW: + return exportRaw(); + case __WASI_KEYPAIR_ENCODING_PKCS8: + return exportPkcs8(); + case __WASI_KEYPAIR_ENCODING_PEM: + return exportPem(); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } + } + + protected: + static WasiCryptoExpect importPkcs8(Span Encoded, + bool Compressed) noexcept { + return checkValid(EvpPkeyPtr{d2iPrivateKey(Encoded)}, Compressed); + } + + static WasiCryptoExpect importPem(Span Encoded, + bool Compressed) noexcept { + return checkValid(EvpPkeyPtr{pemReadPrivateKey(Encoded)}, Compressed); + } + + static WasiCryptoExpect + importRaw(Span Encoded) noexcept { + ensureOrReturn(Encoded.size() == SkSize, __WASI_CRYPTO_ERRNO_INVALID_KEY); + EcKeyPtr EcCtx{EC_KEY_new_by_curve_name(CurveNid)}; + BnPtr Sk{ + BN_bin2bn(Encoded.data(), static_cast(Encoded.size()), nullptr)}; + ensureOrReturn(EC_KEY_set_private_key(EcCtx.get(), Sk.get()), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + + // Calculate and set Pk. + EcPointPtr Pk{EC_POINT_new(EC_KEY_get0_group(EcCtx.get()))}; + opensslCheck(EC_POINT_mul(EC_KEY_get0_group(EcCtx.get()), Pk.get(), + Sk.get(), nullptr, nullptr, nullptr)); + opensslCheck(EC_KEY_set_public_key(EcCtx.get(), Pk.get())); + + EvpPkeyPtr Ctx{EVP_PKEY_new()}; + opensslCheck(EVP_PKEY_set1_EC_KEY(Ctx.get(), EcCtx.get())); + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + return Ctx; + } + + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx, + bool) noexcept { + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); + ensureOrReturn(EcCtx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + // Curve id check. + const EC_GROUP *Group = EC_KEY_get0_group(EcCtx); + ensureOrReturn(Group, __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(EC_GROUP_get_curve_name(Group) == CurveNid, + __WASI_CRYPTO_ERRNO_INVALID_KEY); + // Have public key. + ensureOrReturn(EC_KEY_get0_public_key(EcCtx), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {std::move(Ctx)}; + } + + WasiCryptoExpect exportPkcs8() const noexcept { + return i2dPrivateKey(Ctx.get()); + } + + WasiCryptoExpect exportPem() const noexcept { + return pemWritePrivateKey(Ctx.get()); + } + + WasiCryptoExpect exportRaw() const noexcept { + const BIGNUM *Sk = + EC_KEY_get0_private_key(EVP_PKEY_get0_EC_KEY(Ctx.get())); + SecretVec Res(SkSize); + opensslCheck(BN_bn2bin(Sk, Res.data())); + + return Res; + } + + SharedEvpPkey Ctx; + }; +}; + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/func.cpp b/plugins/wasi_crypto/asymmetric_common/func.cpp new file mode 100644 index 00000000..7574cd25 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/func.cpp @@ -0,0 +1,517 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "asymmetric_common/func.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +Expect +KeypairGenerate::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptOptionsHandlePtr, + uint32_t /* Out */ KpHandlePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); + checkExist(Alg); + + AsymmetricCommon::Algorithm WasiAlg; + if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( + [Alg, WasiAlgLen](auto WasiAlgType) { + return tryFrom(WasiAlgType, {Alg, WasiAlgLen}); + }); + unlikely(!Res)) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const OptOptionsHandle = + MemInst->getPointer(OptOptionsHandlePtr); + checkExist(OptOptionsHandle); + + auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); + checkExist(KpHandle); + + if (auto Res = Ctx.keypairGenerate(WasiAlg, *OptOptionsHandle); + unlikely(!Res)) { + return Res.error(); + } else { + *KpHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairImport::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgType, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ KpHandlePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); + checkExist(Alg); + + AsymmetricCommon::Algorithm WasiAlg; + if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( + [Alg, WasiAlgLen](auto WasiAlgType) { + return tryFrom(WasiAlgType, {Alg, WasiAlgLen}); + }); + unlikely(!Res)) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + const __wasi_size_t WasiEncodedLen = EncodedLen; + auto *const Encoded = + MemInst->getPointer(EncodedPtr, WasiEncodedLen); + checkExist(Encoded); + + const auto WasiEncoding = cast<__wasi_keypair_encoding_e_t>(Encoding); + checkExist(WasiEncoding); + + auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); + checkExist(KpHandle); + + if (auto Res = + Ctx.keypairImport(WasiAlg, {Encoded, WasiEncodedLen}, *WasiEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + *KpHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairGenerateManaged::body( + Runtime::Instance::MemoryInstance *MemInst, int32_t SecretsManagerHandle, + uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptOptionsHandlePtr, uint32_t KpHandlePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); + checkExist(Alg); + + AsymmetricCommon::Algorithm WasiAlg; + if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( + [Alg, WasiAlgLen](auto WasiAlgType) { + return tryFrom(WasiAlgType, {Alg, WasiAlgLen}); + }); + unlikely(!Res)) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const OptOptionsHandle = + MemInst->getPointer(OptOptionsHandlePtr); + checkExist(OptOptionsHandle); + + auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); + checkExist(KpHandle); + + if (auto Res = Ctx.keypairGenerateManaged(SecretsManagerHandle, WasiAlg, + *OptOptionsHandle); + unlikely(!Res)) { + return Res.error(); + } else { + *KpHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +KeypairStoreManaged::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, int32_t KpHandle, + uint32_t KpIdPtr, uint32_t KpIdMaxLen) { + checkExist(MemInst); + + const __wasi_size_t WasiKpIdMaxLen = KpIdMaxLen; + auto *const KpId = MemInst->getPointer(KpIdPtr, WasiKpIdMaxLen); + checkExist(KpId); + + if (auto Res = Ctx.keypairStoreManaged(SecretsManagerHandle, KpHandle, + {KpId, WasiKpIdMaxLen}); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairReplaceManaged::body( + Runtime::Instance::MemoryInstance *MemInst, int32_t SecretsManagerHandle, + int32_t OldKpHandle, int32_t NewKpHandle, uint32_t /* Out */ KpVersionPtr) { + checkExist(MemInst); + + auto *const KpVersion = MemInst->getPointer<__wasi_version_t *>(KpVersionPtr); + checkExist(KpVersion); + + if (auto Res = Ctx.keypairReplaceManaged(SecretsManagerHandle, OldKpHandle, + NewKpHandle); + unlikely(!Res)) { + return Res.error(); + } else { + *KpVersion = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairId::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KpHandle, uint32_t KpIdPtr, + uint32_t KpIdMaxLen, + uint32_t /* Out */ SizePtr, + uint32_t /* Out */ KpVersionPtr) { + checkExist(MemInst); + + const __wasi_size_t WasiKpIdMaxLen = KpIdMaxLen; + auto *const KpId = MemInst->getPointer(KpIdPtr, WasiKpIdMaxLen); + checkExist(KpId); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + auto *const Version = MemInst->getPointer<__wasi_version_t *>(KpVersionPtr); + checkExist(Version); + + if (auto Res = Ctx.keypairId(KpHandle, {KpId, WasiKpIdMaxLen}); + unlikely(!Res)) { + return Res.error(); + } else { + auto [ResSize, ResVersion] = *Res; + + auto SafeResSize = toWasiSize(ResSize); + if (unlikely(!SafeResSize)) { + return SafeResSize.error(); + } + + *Size = *SafeResSize; + + *Version = ResVersion; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairFromId::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, + uint32_t KpIdPtr, uint32_t KpIdLen, + uint64_t KpVersion, + uint32_t /* Out */ KpHandlePtr) { + + checkExist(MemInst); + + const __wasi_size_t WasiKpIdLen = KpIdLen; + auto *const KpId = MemInst->getPointer(KpIdPtr, WasiKpIdLen); + checkExist(KpId); + + auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); + checkExist(KpHandle); + + if (auto Res = Ctx.keypairFromId(SecretsManagerHandle, {KpId, WasiKpIdLen}, + KpVersion); + unlikely(!Res)) { + return Res.error(); + } else { + *KpHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +KeypairFromPkAndSk::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t PkHandle, int32_t SkHandle, + uint32_t /* Out */ KpHandlePtr) { + checkExist(MemInst); + + auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); + checkExist(KpHandle); + + if (auto Res = Ctx.keypairFromPkAndSk(PkHandle, SkHandle); unlikely(!Res)) { + return Res.error(); + } else { + *KpHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairExport::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KpHandle, uint32_t KpEncoding, + uint32_t /* Out */ ArrayOutputHandlePtr) { + checkExist(MemInst); + + __wasi_keypair_encoding_e_t WasiKpEncoding; + if (auto Res = cast<__wasi_keypair_encoding_e_t>(KpEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + WasiKpEncoding = *Res; + } + + auto *const ArrayOutputHandle = + MemInst->getPointer<__wasi_array_output_t *>(ArrayOutputHandlePtr); + checkExist(ArrayOutputHandle); + + if (auto Res = Ctx.keypairExport(KpHandle, WasiKpEncoding); unlikely(!Res)) { + return Res.error(); + } else { + *ArrayOutputHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +KeypairPublickey::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KpHandle, uint32_t /* Out */ PkHandlePtr) { + checkExist(MemInst); + + auto *const PkHandle = MemInst->getPointer<__wasi_keypair_t *>(PkHandlePtr); + checkExist(PkHandle); + + if (auto Res = Ctx.keypairPublickey(KpHandle); unlikely(!Res)) { + return Res.error(); + } else { + *PkHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +KeypairSecretkey::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KpHandle, uint32_t /* Out */ SkHandlePtr) { + checkExist(MemInst); + + auto *const SkHandle = MemInst->getPointer<__wasi_keypair_t *>(SkHandlePtr); + checkExist(SkHandle); + + if (auto Res = Ctx.keypairSecretkey(KpHandle); unlikely(!Res)) { + return Res.error(); + } else { + *SkHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeypairClose::body(Runtime::Instance::MemoryInstance *, + int32_t KpHandle) { + if (auto Res = Ctx.keypairClose(KpHandle); unlikely(!Res)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +PublickeyImport::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, + uint32_t EncodedPtr, uint32_t EncodedLen, + uint32_t Encoding, uint32_t /* Out */ PkHandlePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); + checkExist(Alg); + + AsymmetricCommon::Algorithm WasiAlg; + if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( + [Alg, WasiAlgLen](auto WasiAlgType) { + return tryFrom(WasiAlgType, {Alg, WasiAlgLen}); + }); + unlikely(!Res)) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + const __wasi_size_t WasiEncodedLen = EncodedLen; + auto *const Encoded = + MemInst->getPointer(EncodedPtr, WasiEncodedLen); + checkExist(Encoded); + + __wasi_publickey_encoding_e_t WasiPkEncoding; + if (auto Res = cast<__wasi_publickey_encoding_e_t>(Encoding); !Res) { + return Res.error(); + } else { + WasiPkEncoding = *Res; + } + + auto *const PkHandle = MemInst->getPointer<__wasi_publickey_t *>(PkHandlePtr); + checkExist(PkHandle); + + if (auto Res = Ctx.publickeyImport(WasiAlg, {Encoded, WasiEncodedLen}, + WasiPkEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + *PkHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +PublickeyExport::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t PkHandle, uint32_t PkEncoding, + uint32_t /* Out */ ArrayOutputHandlePtr) { + checkExist(MemInst); + + __wasi_publickey_encoding_e_t WasiPkEncoding; + if (auto Res = cast<__wasi_publickey_encoding_e_t>(PkEncoding); !Res) { + return Res.error(); + } else { + WasiPkEncoding = *Res; + } + + auto *const ArrayOutputHandle = + MemInst->getPointer<__wasi_array_output_t *>(ArrayOutputHandlePtr); + checkExist(ArrayOutputHandle); + + if (auto Res = Ctx.publickeyExport(PkHandle, WasiPkEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + *ArrayOutputHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect PublickeyVerify::body(Runtime::Instance::MemoryInstance *, + int32_t PkHandle) { + if (auto Res = Ctx.publickeyVerify(PkHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +PublickeyFromSecretkey::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SkHandle, uint32_t /* Out */ PkHandlePtr) { + checkExist(MemInst); + + auto *const PkHandle = MemInst->getPointer<__wasi_publickey_t *>(PkHandlePtr); + checkExist(PkHandle); + + if (auto Res = Ctx.publickeyFromSecretkey(SkHandle); unlikely(!Res)) { + return Res.error(); + } else { + *PkHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect PublickeyClose::body(Runtime::Instance::MemoryInstance *, + int32_t PkHandle) { + if (auto Res = Ctx.publickeyClose(PkHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +SecretkeyImport::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, + uint32_t EncodedPtr, uint32_t EncodedLen, + uint32_t Encoding, uint32_t /* Out */ SkHandlePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); + checkExist(Alg); + + AsymmetricCommon::Algorithm WasiAlg; + if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( + [Alg, WasiAlgLen](auto WasiAlgType) { + return tryFrom(WasiAlgType, {Alg, WasiAlgLen}); + }); + unlikely(!Res)) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + const __wasi_size_t WasiEncodedLen = EncodedLen; + auto *const Encoded = + MemInst->getPointer(EncodedPtr, WasiEncodedLen); + checkExist(Encoded); + + auto WasiEncoding = cast<__wasi_secretkey_encoding_e_t>(Encoding); + if (!WasiEncoding) { + return WasiEncoding.error(); + } + + auto *const SkHandle = MemInst->getPointer<__wasi_secretkey_t *>(SkHandlePtr); + checkExist(SkHandle); + + if (auto Res = Ctx.secretkeyImport(WasiAlg, {Encoded, WasiEncodedLen}, + *WasiEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + *SkHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +SecretkeyExport::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SkHandle, uint32_t SkEncoding, + uint32_t /* Out */ ArrayOutputHandlePtr) { + checkExist(MemInst); + + __wasi_secretkey_encoding_e_t WasiSkEncoding; + if (auto Res = cast<__wasi_secretkey_encoding_e_t>(SkEncoding); !Res) { + return Res.error(); + } else { + WasiSkEncoding = *Res; + } + + auto *const ArrayOutputHandle = + MemInst->getPointer<__wasi_array_output_t *>(ArrayOutputHandlePtr); + checkExist(ArrayOutputHandle); + + if (auto Res = Ctx.secretkeyExport(SkHandle, WasiSkEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + *ArrayOutputHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect SecretkeyClose::body(Runtime::Instance::MemoryInstance *, + int32_t Sk) { + if (auto Res = Ctx.secretkeyClose(Sk); unlikely(!Res)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/func.h b/plugins/wasi_crypto/asymmetric_common/func.h new file mode 100644 index 00000000..2a5be27d --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/func.h @@ -0,0 +1,188 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric_common/func.h -------------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the asymmetric common host functions of wasi-crypto. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "utils/hostfunction.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +class KeypairGenerate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptOptionsHandlePtr, + uint32_t /* Out */ KpHandlePtr); +}; + +class KeypairImport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, + uint32_t EncodedPtr, uint32_t EncodedLen, + uint32_t Encoding, uint32_t /* Out */ KpHandlePtr); +}; + +class KeypairGenerateManaged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, uint32_t AlgType, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptOptionsHandlePtr, uint32_t KpHandlePtr); +}; + +class KeypairStoreManaged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, int32_t KpHandle, + uint32_t KpIdPtr, uint32_t KpIdMaxLen); +}; + +class KeypairReplaceManaged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, int32_t OldKpHandle, + int32_t NewKpHandle, uint32_t /* Out */ KpVersionPtr); +}; + +class KeypairId : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KpHandle, uint32_t KpIdPtr, uint32_t KpIdMaxLen, + uint32_t /* Out */ SizePtr, + uint32_t /* Out */ KpVersionPtr); +}; + +class KeypairFromId : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, uint32_t KpIdPtr, + uint32_t KpIdLen, uint64_t KpVersion, + uint32_t /* Out */ KpHandlePtr); +}; + +class KeypairFromPkAndSk : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t PkHandle, int32_t SkHandle, + uint32_t /* Out */ KpHandlePtr); +}; + +class KeypairExport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KpHandle, uint32_t KpEncoding, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class KeypairPublickey : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KpHandle, uint32_t /* Out */ PkHandlePtr); +}; + +class KeypairSecretkey : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KpHandle, uint32_t /* Out */ SkHandlePtr); +}; + +class KeypairClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KpHandle); +}; + +class PublickeyImport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, + uint32_t EncodedPtr, uint32_t EncodedLen, + uint32_t Encoding, uint32_t /* Out */ PkHandlePtr); +}; + +class PublickeyExport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t PkHandle, uint32_t PkEncoding, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class PublickeyVerify : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t PkHandle); +}; + +class PublickeyFromSecretkey : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SkHandle, uint32_t /* Out */ PkHandlePtr); +}; + +class PublickeyClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t PkHandle); +}; + +class SecretkeyImport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, + uint32_t EncodedPtr, uint32_t EncodedLen, + uint32_t Encoding, uint32_t /* Out */ SkHandlePtr); +}; + +class SecretkeyExport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SkHandle, uint32_t SkEncoding, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class SecretkeyClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SkHandle); +}; + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/keypair.cpp b/plugins/wasi_crypto/asymmetric_common/keypair.cpp new file mode 100644 index 00000000..ec543c10 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/keypair.cpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "asymmetric_common/keypair.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +WasiCryptoExpect +importKp(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept { + return std::visit( + [=](auto Factory) noexcept -> WasiCryptoExpect { + return decltype(Factory)::KeyPair::import(Encoded, Encoding); + }, + Alg); +} + +namespace { +/// Correspond signatures: +/// WasiCryptoExpect generate(OptionalRef); +/// is used to get the `OptionsType`. +template struct KpGenerateTrait; +template +struct KpGenerateTrait (*)( + OptionalRef) noexcept> { + using Options = OptionsType; +}; +template +using OptionsType = + typename KpGenerateTrait::Options; +} // namespace + +WasiCryptoExpect +generateKp(AsymmetricCommon::Algorithm Alg, + OptionalRef OptOptions) noexcept { + return std::visit( + [=](auto Factory) noexcept -> WasiCryptoExpect { + using RequiredOptionsType = OptionsType; + return transposeOptionalRef( + OptOptions, + [](auto &&Options) noexcept + -> WasiCryptoExpect> { + using InOptionsType = std::decay_t; + if constexpr (std::is_same_v) { + return Options; + } else { + return WasiCryptoUnexpect( + __WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + }) + .and_then([](auto OptRequiredOptions) noexcept { + return decltype(Factory)::KeyPair::generate(OptRequiredOptions); + }); + }, + Alg); +} + +namespace { +/// Correspond signatures: +/// WasiCryptoExpect Sk::toKeyPair(const PublicKeyType&); +/// is used to get the `PublicKeyType`. +template struct KpFromPkAndSkTrait; +template +struct KpFromPkAndSkTrait (SecretKeyType::*)( + const PublicKeyType &) const noexcept> { + using PublicKey = PublicKeyType; +}; +template +using PkType = typename KpFromPkAndSkTrait::PublicKey; +} // namespace + +WasiCryptoExpect kpFromPkAndSk(const PkVariant &PkVariant, + const SkVariant &SkVariant) noexcept { + return std::visit( + [](const auto &Pk, + const auto &Sk) noexcept -> WasiCryptoExpect { + using RequiredPkType = PkType>; + using InPkType = std::decay_t; + if constexpr (std::is_same_v) { + return Sk.toKeyPair(Pk); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_KEY); + } + }, + PkVariant, SkVariant); +} + +WasiCryptoExpect +kpExportData(const KpVariant &KpVariant, + __wasi_keypair_encoding_e_t Encoding) noexcept { + return std::visit( + [Encoding](const auto &Kp) noexcept { return Kp.exportData(Encoding); }, + KpVariant); +} + +WasiCryptoExpect kpPublicKey(const KpVariant &KpVariant) noexcept { + return std::visit( + [](const auto &Kp) noexcept { + return Kp.publicKey().map([](auto &&Pk) noexcept { + return PkVariant{std::forward(Pk)}; + }); + }, + KpVariant); +} + +WasiCryptoExpect kpSecretKey(const KpVariant &KpVariant) noexcept { + return std::visit( + [](const auto &Kp) noexcept { + return Kp.secretKey().map([](auto &&Sk) noexcept { + return SkVariant{std::forward(Sk)}; + }); + }, + KpVariant); +} + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/keypair.h b/plugins/wasi_crypto/asymmetric_common/keypair.h new file mode 100644 index 00000000..1cdd138f --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/keypair.h @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric_common/keypair.h ----------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the asymmetric common Keypair of wasi-crypto. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "asymmetric_common/publickey.h" +#include "asymmetric_common/registed.h" +#include "asymmetric_common/secretkey.h" +#include "common/options.h" +#include "utils/error.h" +#include "utils/optional.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +using KpVariant = RegistedAlg::KpVariant; + +WasiCryptoExpect +importKp(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect +generateKp(AsymmetricCommon::Algorithm Alg, + OptionalRef OptOptions) noexcept; + +WasiCryptoExpect +kpExportData(const KpVariant &KpVariant, + __wasi_keypair_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect kpPublicKey(const KpVariant &KpVariant) noexcept; + +WasiCryptoExpect kpSecretKey(const KpVariant &KpVariant) noexcept; + +WasiCryptoExpect kpFromPkAndSk(const PkVariant &Pk, + const SkVariant &Sk) noexcept; +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/publickey.cpp b/plugins/wasi_crypto/asymmetric_common/publickey.cpp new file mode 100644 index 00000000..5e9f1bae --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/publickey.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "asymmetric_common/publickey.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +WasiCryptoExpect +importPk(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept { + return std::visit( + [=](auto Factory) noexcept -> WasiCryptoExpect { + return decltype(Factory)::PublicKey::import(Encoded, Encoding); + }, + Alg); +} + +WasiCryptoExpect> +pkExportData(const PkVariant &PkVariant, + __wasi_publickey_encoding_e_t Encoding) noexcept { + return std::visit( + [Encoding](const auto &Pk) noexcept { return Pk.exportData(Encoding); }, + PkVariant); +} + +WasiCryptoExpect pkVerify(const PkVariant &PkVariant) noexcept { + return std::visit([](const auto &Pk) noexcept { return Pk.verify(); }, + PkVariant); +} + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/publickey.h b/plugins/wasi_crypto/asymmetric_common/publickey.h new file mode 100644 index 00000000..36dd7bf3 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/publickey.h @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric_common/publickey.h --------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the asymmetric common PubicKey of wasi-crypto. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "asymmetric_common/registed.h" +#include "utils/error.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +using PkVariant = RegistedAlg::PkVariant; + +WasiCryptoExpect +importPk(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect> +pkExportData(const PkVariant &PkVariant, + __wasi_publickey_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect pkVerify(const PkVariant &PkVariant) noexcept; + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/registed.h b/plugins/wasi_crypto/asymmetric_common/registed.h new file mode 100644 index 00000000..560476de --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/registed.h @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric/registed.h - Registed -----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the register asymmetric common algorithm definitions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "kx/registed.h" +#include "signatures/registed.h" +#include "utils/error.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +template struct Registed { + using PkVariant = std::variant; + using SkVariant = std::variant; + using KpVariant = std::variant; + using Variant = std::variant; +}; + +template +struct Registed, Kx::Registed> { + using Alg = Registed; +}; + +/// Combine the signatures and kx algoritms. +using RegistedAlg = Registed::Alg; + +using Algorithm = RegistedAlg::Variant; + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/secretkey.cpp b/plugins/wasi_crypto/asymmetric_common/secretkey.cpp new file mode 100644 index 00000000..f76b2180 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/secretkey.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "asymmetric_common/secretkey.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +WasiCryptoExpect +importSk(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + return std::visit( + [=](auto Factory) noexcept -> WasiCryptoExpect { + return decltype(Factory)::SecretKey::import(Encoded, Encoding); + }, + Alg); +} + +WasiCryptoExpect +skExportData(const SkVariant &SkVariant, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + return std::visit( + [Encoding](const auto &Sk) noexcept { return Sk.exportData(Encoding); }, + SkVariant); +} + +WasiCryptoExpect skPublicKey(const SkVariant &SkVariant) noexcept { + return std::visit( + [](const auto &Sk) noexcept { + return Sk.publicKey().map([](auto &&Pk) noexcept { + return PkVariant{std::forward(Pk)}; + }); + }, + SkVariant); +} + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/secretkey.h b/plugins/wasi_crypto/asymmetric_common/secretkey.h new file mode 100644 index 00000000..16214c0b --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/secretkey.h @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric_common/secretkey.h --------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the asymmetric common SecretKey of wasi-crypto. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "asymmetric_common/publickey.h" +#include "asymmetric_common/registed.h" +#include "wasi_crypto/api.hpp" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace AsymmetricCommon { + +using SkVariant = RegistedAlg::SkVariant; + +WasiCryptoExpect +importSk(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect +skExportData(const SkVariant &SkVariant, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect skPublicKey(const SkVariant &SkVariant) noexcept; + +} // namespace AsymmetricCommon +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/array_output.cpp b/plugins/wasi_crypto/common/array_output.cpp new file mode 100644 index 00000000..79ec82e7 --- /dev/null +++ b/plugins/wasi_crypto/common/array_output.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/array_output.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Common { + +std::tuple ArrayOutput::pull(Span Buf) noexcept { + std::scoped_lock Lock{Mutex}; + + using DataPosT = decltype(Data)::difference_type; + + size_t OutputSize = std::min(Buf.size(), Data.size() - Pos); + + std::copy(Data.begin() + static_cast(Pos), + Data.begin() + static_cast(Pos + OutputSize), + Buf.begin()); + Pos += OutputSize; + + return {OutputSize, Pos + OutputSize == Data.size()}; +} + +} // namespace Common +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/array_output.h b/plugins/wasi_crypto/common/array_output.h new file mode 100644 index 00000000..51238281 --- /dev/null +++ b/plugins/wasi_crypto/common/array_output.h @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/common/array_output.h - ArrayOutput --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the ArrayOutput class definition. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include +#include +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Common { + +/// Functions returning arrays whose size is not constant or too large to be +/// safely allocated on the stack return a handle to an ArrayOutput type. +/// +/// More detail: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#array-outputs +class ArrayOutput { +public: + ArrayOutput(const ArrayOutput &) noexcept = delete; + ArrayOutput &operator=(const ArrayOutput &) noexcept = delete; + ArrayOutput &operator=(ArrayOutput &&) noexcept = delete; + ArrayOutput(ArrayOutput &&) noexcept = delete; + + ArrayOutput(std::vector &&Data) noexcept : Data(std::move(Data)) {} + + ArrayOutput(SecretVec &&Data) noexcept : Data(std::move(Data)) {} + + /// Copy the content to the @param Buf buffer. + /// Multiple calls are possible, the total number of bytes to be read is + /// guaranteed to always match the data size. + /// + /// @returns the number of bytes read. If all pull, return true. + std::tuple pull(Span Buf) noexcept; + + /// Return ArrayOutput data size. + size_t len() const noexcept { return Data.size(); } + +private: + const SecretVec Data; + size_t Pos = 0; + std::mutex Mutex; +}; + +} // namespace Common +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/ctx.cpp b/plugins/wasi_crypto/common/ctx.cpp new file mode 100644 index 00000000..03996922 --- /dev/null +++ b/plugins/wasi_crypto/common/ctx.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "ctx.h" +#include "common/array_output.h" +#include "common/options.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +WasiCryptoExpect +Context::arrayOutputLen(__wasi_array_output_t ArrayOutputHandle) noexcept { + return ArrayOutputManager.get(ArrayOutputHandle) + .map(&Common::ArrayOutput::len); +} + +WasiCryptoExpect +Context::arrayOutputPull(__wasi_array_output_t ArrayOutputHandle, + Span Buf) noexcept { + return ArrayOutputManager.get(ArrayOutputHandle) + .map([=](Common::ArrayOutput &ArrayOutput) noexcept { + auto [Size, AlreadyConsumed] = ArrayOutput.pull(Buf); + if (AlreadyConsumed) { + ArrayOutputManager.close(ArrayOutputHandle); + } + return Size; + }); +} + +WasiCryptoExpect<__wasi_options_t> +Context::optionsOpen(__wasi_algorithm_type_e_t AlgType) noexcept { + return OptionsManager.registerManager(Common::optionsOpen(AlgType)); +} + +WasiCryptoExpect +Context::optionsClose(__wasi_options_t OptionsHandle) noexcept { + return OptionsManager.close(OptionsHandle); +} + +WasiCryptoExpect Context::optionsSet(__wasi_options_t OptionsHandle, + std::string_view Name, + Span Value) noexcept { + return OptionsManager.get(OptionsHandle) + .and_then([Name, Value](auto &&Options) noexcept { + return Common::optionsSet(Options, Name, Value); + }); +} + +WasiCryptoExpect Context::optionsSetU64(__wasi_options_t OptionsHandle, + std::string_view Name, + uint64_t Value) noexcept { + return OptionsManager.get(OptionsHandle) + .and_then([Name, Value](auto &&Options) noexcept { + return Common::optionsSetU64(Options, Name, Value); + }); +} + +WasiCryptoExpect +Context::optionsSetGuestBuffer(__wasi_options_t OptionsHandle, + std::string_view Name, + Span Buf) noexcept { + return OptionsManager.get(OptionsHandle) + .and_then([Name, Buf](auto &&Options) noexcept { + return Common::optionsSetGuestBuffer(Options, Name, Buf); + }); +} + +WasiCryptoExpect<__wasi_secrets_manager_t> +Context::secretsManagerOpen(__wasi_opt_options_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect +Context::secretsManagerClose(__wasi_secrets_manager_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect +Context::secretsManagerInvalidate(__wasi_secrets_manager_t, Span, + __wasi_version_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/func.cpp b/plugins/wasi_crypto/common/func.cpp new file mode 100644 index 00000000..1da1a208 --- /dev/null +++ b/plugins/wasi_crypto/common/func.cpp @@ -0,0 +1,215 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/func.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Common { + +Expect +ArrayOutputLen::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t ArrayOutputHandle, uint32_t /* Out */ SizePtr) { + // Check memory instance from module. + checkExist(MemInst); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = Ctx.arrayOutputLen(ArrayOutputHandle).and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +ArrayOutputPull::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t ArrayOutputHandle, uint32_t BufPtr, + uint32_t BufLen, uint32_t /* Out */ SizePtr) { + // Check memory instance from module. + checkExist(MemInst); + + const __wasi_size_t WasiBufLen = BufLen; + auto *const Buf = MemInst->getPointer(BufPtr, WasiBufLen); + checkExist(Buf); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = Ctx.arrayOutputPull(ArrayOutputHandle, {Buf, WasiBufLen}) + .and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect OptionsOpen::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgType, + uint32_t /* Out */ OptionsHandlePtr) { + // Check memory instance from module. + checkExist(MemInst); + + __wasi_algorithm_type_e_t WasiAlgType; + if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType); unlikely(!Res)) { + return Res.error(); + } else { + WasiAlgType = *Res; + } + + auto *const OptionsHandle = + MemInst->getPointer<__wasi_options_t *>(OptionsHandlePtr); + checkExist(OptionsHandle); + + if (auto Res = Ctx.optionsOpen(WasiAlgType); unlikely(!Res)) { + return Res.error(); + } else { + *OptionsHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect OptionsClose::body(Runtime::Instance::MemoryInstance *, + int32_t OptionsHandle) { + + if (auto Res = Ctx.optionsClose(OptionsHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect OptionsSet::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t OptionsHandle, uint32_t NamePtr, + uint32_t NameLen, uint32_t ValuePtr, + uint32_t ValueLen) { + // Check memory instance from module. + checkExist(MemInst); + + const __wasi_size_t WasiNameLen = NameLen; + auto *const Name = MemInst->getPointer(NamePtr, WasiNameLen); + checkExist(Name); + + const __wasi_size_t WasiValueLen = ValueLen; + auto *const Value = + MemInst->getPointer(ValuePtr, WasiValueLen); + checkExist(Value); + + if (auto Res = Ctx.optionsSet(OptionsHandle, {Name, WasiNameLen}, + {Value, WasiValueLen}); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect OptionsSetU64::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t OptionsHandle, uint32_t NamePtr, + uint32_t NameLen, uint64_t Value) { + // Check memory instance from module. + checkExist(MemInst); + + const __wasi_size_t WasiNameLen = NameLen; + auto *const Name = MemInst->getPointer(NamePtr, WasiNameLen); + checkExist(Name); + + if (auto Res = Ctx.optionsSetU64(OptionsHandle, {Name, WasiNameLen}, Value); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect OptionsSetGuestBuffer::body( + Runtime::Instance::MemoryInstance *MemInst, int32_t OptionsHandle, + uint32_t NamePtr, uint32_t NameLen, uint32_t BufPtr, uint32_t BufLen) { + // Check memory instance from module. + checkExist(MemInst); + + const __wasi_size_t WasiNameLen = NameLen; + auto *const Name = MemInst->getPointer(NamePtr, WasiNameLen); + checkExist(Name); + + const __wasi_size_t WasiBufLen = BufLen; + auto *const Buf = MemInst->getPointer(BufPtr, WasiBufLen); + checkExist(Buf); + + if (auto Res = Ctx.optionsSetGuestBuffer(OptionsHandle, {Name, WasiNameLen}, + {Buf, WasiBufLen}); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +SecretsManagerOpen::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t OptOptionsHandlePtr, + uint32_t /* Out */ SecretsManagerHandlePtr) { + // Check memory instance from module. + checkExist(MemInst); + + auto *const OptOptionsHandle = + MemInst->getPointer(OptOptionsHandlePtr); + checkExist(OptOptionsHandle); + + auto *const SecretsManagerHandle = + MemInst->getPointer<__wasi_secrets_manager_t *>(SecretsManagerHandlePtr); + checkExist(SecretsManagerHandle); + + if (auto Res = Ctx.secretsManagerOpen(*OptOptionsHandle); unlikely(!Res)) { + return Res.error(); + } else { + *SecretsManagerHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect SecretsManagerClose::body(Runtime::Instance::MemoryInstance *, + int32_t SecretsManagerHandle) { + if (auto Res = Ctx.secretsManagerClose(SecretsManagerHandle); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +SecretsManagerInvalidate::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, uint32_t KeyIdPtr, + uint32_t KeyIdLen, uint64_t Version) { + // Check memory instance from module. + checkExist(MemInst); + + const __wasi_size_t WasiKeyIdLen = KeyIdLen; + auto *const KeyId = + MemInst->getPointer(KeyIdPtr, WasiKeyIdLen); + checkExist(KeyId); + + if (auto Res = Ctx.secretsManagerInvalidate(SecretsManagerHandle, + {KeyId, WasiKeyIdLen}, Version); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +} // namespace Common +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/func.h b/plugins/wasi_crypto/common/func.h new file mode 100644 index 00000000..c08bbede --- /dev/null +++ b/plugins/wasi_crypto/common/func.h @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/common/func.h - Common func ----------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the common host functions of wasi-crypto. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/hostfunction.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Common { + +class ArrayOutputLen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t ArrayOutputHandle, uint32_t /* Out */ SizePtr); +}; + +class ArrayOutputPull : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t ArrayOutputHandle, uint32_t BufPtr, + uint32_t BufLen, uint32_t /* Out */ SizePtr); +}; + +class OptionsOpen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgType, uint32_t /* Out */ OptionsHandlePtr); +}; + +class OptionsClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t OptionsHandle); +}; + +class OptionsSet : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t OptionsHandle, uint32_t NamePtr, + uint32_t NameLen, uint32_t ValuePtr, uint32_t ValueLen); +}; + +class OptionsSetU64 : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t OptionsHandle, uint32_t NamePtr, + uint32_t NameLen, uint64_t Value); +}; + +class OptionsSetGuestBuffer : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t OptionsHandle, uint32_t NamePtr, + uint32_t NameLen, uint32_t BufPtr, uint32_t BufLen); +}; + +class SecretsManagerOpen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t OptOptionsHandlePtr, + uint32_t /* Out */ SecretsManagerHandlePtr); +}; + +class SecretsManagerClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle); +}; + +class SecretsManagerInvalidate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, uint32_t KeyIdPtr, + uint32_t KeyIdLen, uint64_t Version); +}; + +} // namespace Common +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/options.cpp b/plugins/wasi_crypto/common/options.cpp new file mode 100644 index 00000000..f93f2fe0 --- /dev/null +++ b/plugins/wasi_crypto/common/options.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/options.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Common { + +Options optionsOpen(__wasi_algorithm_type_e_t Alg) noexcept { + switch (Alg) { + case __WASI_ALGORITHM_TYPE_SIGNATURES: + return Options{std::in_place_type}; + case __WASI_ALGORITHM_TYPE_SYMMETRIC: + return Options{std::in_place_type}; + case __WASI_ALGORITHM_TYPE_KEY_EXCHANGE: + return Options{std::in_place_type}; + default: + assumingUnreachable(); + } +} + +WasiCryptoExpect optionsSet(Options &Options, std::string_view Name, + Span Value) noexcept { + return std::visit( + [Name, Value](auto &Option) noexcept { return Option.set(Name, Value); }, + Options); +} + +WasiCryptoExpect optionsSetU64(Options &Options, std::string_view Name, + uint64_t Value) noexcept { + return std::visit( + [Name, Value](auto &Option) noexcept { + return Option.setU64(Name, Value); + }, + Options); +} + +WasiCryptoExpect optionsSetGuestBuffer(Options &Options, + std::string_view Name, + Span Value) noexcept { + return std::visit( + [Name, Value](auto &Option) noexcept { + return Option.setGuestBuffer(Name, Value); + }, + Options); +} + +} // namespace Common +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/options.h b/plugins/wasi_crypto/common/options.h new file mode 100644 index 00000000..de22b64d --- /dev/null +++ b/plugins/wasi_crypto/common/options.h @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/common/options.h - Options definition ===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Options definition of wasi-crypto. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "kx/options.h" +#include "signatures/options.h" +#include "symmetric/options.h" +#include "wasi_crypto/api.hpp" + +#include "common/span.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Common { + +/// Some functions support options. For example, options can be used to access +/// the features that are only relevant to specific ciphers and hash functions. +/// +/// Options are represented as a (key, value) map, keys are strings. They are +/// attached to a context, such as a cipher state. Applications can set, and +/// also read the value associated with a key in order to either get the default +/// value or obtain the runtime information. +/// +/// More detail: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#options +using Options = + std::variant; + +Options optionsOpen(__wasi_algorithm_type_e_t Alg) noexcept; + +/// Set byte vectors. +WasiCryptoExpect optionsSet(Options &Options, std::string_view Name, + Span Value) noexcept; + +/// Set unsigned integers. +WasiCryptoExpect optionsSetU64(Options &Options, std::string_view Name, + uint64_t Value) noexcept; + +/// Set memory buffers. +WasiCryptoExpect optionsSetGuestBuffer(Options &Options, + std::string_view Name, + Span Value) noexcept; +} // namespace Common +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/ctx.cpp b/plugins/wasi_crypto/ctx.cpp new file mode 100644 index 00000000..a7457632 --- /dev/null +++ b/plugins/wasi_crypto/ctx.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "ctx.h" +#include "module.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +Runtime::Instance::ModuleInstance *create(void) noexcept { + return new WasiCryptoModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasi_crypto", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 10, 1, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasi_crypto", + .Description = "", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +} // namespace + +Plugin::PluginRegister WasiCrypto::Context::Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/ctx.h b/plugins/wasi_crypto/ctx.h new file mode 100644 index 00000000..da07acd5 --- /dev/null +++ b/plugins/wasi_crypto/ctx.h @@ -0,0 +1,344 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/ctx.h - Context class definition -----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the wasi-crypto context. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "asymmetric_common/keypair.h" +#include "asymmetric_common/publickey.h" +#include "asymmetric_common/secretkey.h" +#include "common/array_output.h" +#include "common/options.h" +#include "common/span.h" +#include "kx/registed.h" +#include "signatures/registed.h" +#include "signatures/signatures.h" +#include "signatures/signstate.h" +#include "signatures/verificationstate.h" +#include "symmetric/key.h" +#include "symmetric/registed.h" +#include "symmetric/state.h" +#include "symmetric/tag.h" +#include "utils/error.h" +#include "utils/handles_manager.h" + +#include "plugin/plugin.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +class Context { +public: + // Common + + WasiCryptoExpect + arrayOutputLen(__wasi_array_output_t ArrayOutputHandle) noexcept; + + WasiCryptoExpect + arrayOutputPull(__wasi_array_output_t ArrayOutputHandle, + Span Buf) noexcept; + + WasiCryptoExpect<__wasi_options_t> + optionsOpen(__wasi_algorithm_type_e_t AlgType) noexcept; + + WasiCryptoExpect optionsClose(__wasi_options_t OptionsHandle) noexcept; + + WasiCryptoExpect optionsSet(__wasi_options_t OptionsHandle, + std::string_view Name, + Span Value) noexcept; + + WasiCryptoExpect optionsSetU64(__wasi_options_t OptionsHandle, + std::string_view Name, + uint64_t Value) noexcept; + + WasiCryptoExpect optionsSetGuestBuffer(__wasi_options_t OptionsHandle, + std::string_view Name, + Span Buf) noexcept; + + WasiCryptoExpect<__wasi_secrets_manager_t> + secretsManagerOpen(__wasi_opt_options_t OptOptionsHandle) noexcept; + + WasiCryptoExpect + secretsManagerClose(__wasi_secrets_manager_t SecretsManagerHandle) noexcept; + + WasiCryptoExpect + secretsManagerInvalidate(__wasi_secrets_manager_t SecretsManagerHandle, + Span KeyId, + __wasi_version_t Version) noexcept; + + // Symmetric + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyGenerate(Symmetric::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept; + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyImport(Symmetric::Algorithm Alg, + Span Raw) noexcept; + + WasiCryptoExpect<__wasi_array_output_t> + symmetricKeyExport(__wasi_symmetric_key_t KeyHandle) noexcept; + + WasiCryptoExpect + symmetricKeyClose(__wasi_symmetric_key_t KeyHandle) noexcept; + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyGenerateManaged(__wasi_secrets_manager_t SecretsManagerHandle, + Symmetric::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept; + + WasiCryptoExpect + symmetricKeyStoreManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_symmetric_key_t KeyHandle, + Span KeyId) noexcept; + + WasiCryptoExpect<__wasi_version_t> + symmetricKeyReplaceManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_symmetric_key_t OldKeyHandle, + __wasi_symmetric_key_t NewKeyHandle) noexcept; + + WasiCryptoExpect> + symmetricKeyId(__wasi_symmetric_key_t KeyHandle, + Span KeyId) noexcept; + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyFromId(__wasi_secrets_manager_t SecretsManagerHandle, + Span KeyId, __wasi_version_t KeyVersion) noexcept; + + WasiCryptoExpect<__wasi_symmetric_state_t> + symmetricStateOpen(Symmetric::Algorithm Alg, + __wasi_opt_symmetric_key_t OptKeyHandle, + __wasi_opt_options_t OptOptionsHandle) noexcept; + + WasiCryptoExpect<__wasi_symmetric_state_t> + symmetricStateClone(__wasi_symmetric_state_t StateHandle) noexcept; + + WasiCryptoExpect + symmetricStateOptionsGet(__wasi_symmetric_state_t StateHandle, + std::string_view Name, Span Value) noexcept; + + WasiCryptoExpect + symmetricStateOptionsGetU64(__wasi_symmetric_state_t StateHandle, + std::string_view Name) noexcept; + + WasiCryptoExpect + symmetricStateClose(__wasi_symmetric_state_t StateHandle) noexcept; + + WasiCryptoExpect + symmetricStateAbsorb(__wasi_symmetric_state_t StateHandle, + Span Data) noexcept; + + WasiCryptoExpect + symmetricStateSqueeze(__wasi_symmetric_state_t StateHandle, + Span Out) noexcept; + + WasiCryptoExpect<__wasi_symmetric_tag_t> + symmetricStateSqueezeTag(__wasi_symmetric_state_t StateHandle) noexcept; + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricStateSqueezeKey(__wasi_symmetric_state_t StateHandle, + Symmetric::Algorithm Alg) noexcept; + + WasiCryptoExpect + symmetricStateMaxTagLen(__wasi_symmetric_state_t StateHandle) noexcept; + + WasiCryptoExpect + symmetricStateEncrypt(__wasi_symmetric_state_t StateHandle, Span Out, + Span Data) noexcept; + + WasiCryptoExpect<__wasi_symmetric_tag_t> + symmetricStateEncryptDetached(__wasi_symmetric_state_t StateHandle, + Span Out, + Span Data) noexcept; + + WasiCryptoExpect + symmetricStateDecrypt(__wasi_symmetric_state_t StateHandle, Span Out, + Span Data) noexcept; + + WasiCryptoExpect + symmetricStateDecryptDetached(__wasi_symmetric_state_t StateHandle, + Span Out, Span Data, + Span RawTag) noexcept; + + WasiCryptoExpect + symmetricStateRatchet(__wasi_symmetric_state_t StateHandle) noexcept; + + WasiCryptoExpect + symmetricTagLen(__wasi_symmetric_tag_t TagHandle) noexcept; + + WasiCryptoExpect symmetricTagPull(__wasi_symmetric_tag_t TagHandle, + Span Buf) noexcept; + + WasiCryptoExpect + symmetricTagVerify(__wasi_symmetric_tag_t TagHandle, + Span RawTag) noexcept; + + WasiCryptoExpect + symmetricTagClose(__wasi_symmetric_tag_t TagHandle) noexcept; + + // Asymmetric_common + + WasiCryptoExpect<__wasi_keypair_t> + keypairGenerate(AsymmetricCommon::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept; + + WasiCryptoExpect<__wasi_keypair_t> + keypairImport(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect<__wasi_keypair_t> + keypairGenerateManaged(__wasi_secrets_manager_t SecretsManagerHandle, + AsymmetricCommon::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept; + + WasiCryptoExpect + keypairStoreManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_keypair_t KpHandle, Span KpId) noexcept; + + WasiCryptoExpect<__wasi_version_t> + keypairReplaceManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_keypair_t OldKpHandle, + __wasi_keypair_t NewKpHandle) noexcept; + + WasiCryptoExpect> + keypairId(__wasi_keypair_t KpHandle, Span KpId) noexcept; + + WasiCryptoExpect<__wasi_keypair_t> + keypairFromId(__wasi_secrets_manager_t SecretsManagerHandle, + Span KpId, + __wasi_version_t KpIdVersion) noexcept; + + WasiCryptoExpect<__wasi_keypair_t> + keypairFromPkAndSk(__wasi_publickey_t PkHandle, + __wasi_secretkey_t SkHandle) noexcept; + + WasiCryptoExpect<__wasi_array_output_t> + keypairExport(__wasi_keypair_t KpHandle, + __wasi_keypair_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect<__wasi_publickey_t> + keypairPublickey(__wasi_keypair_t KpHandle) noexcept; + + WasiCryptoExpect<__wasi_secretkey_t> + keypairSecretkey(__wasi_keypair_t KpHandle) noexcept; + + WasiCryptoExpect keypairClose(__wasi_keypair_t KpHandle) noexcept; + + WasiCryptoExpect<__wasi_publickey_t> + publickeyImport(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect<__wasi_array_output_t> + publickeyExport(__wasi_publickey_t PkHandle, + __wasi_publickey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect publickeyVerify(__wasi_publickey_t PkHandle) noexcept; + + WasiCryptoExpect<__wasi_publickey_t> + publickeyFromSecretkey(__wasi_secretkey_t SkHandle) noexcept; + + WasiCryptoExpect publickeyClose(__wasi_publickey_t PkHandle) noexcept; + + WasiCryptoExpect<__wasi_secretkey_t> + secretkeyImport(AsymmetricCommon::Algorithm Alg, Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect<__wasi_array_output_t> + secretkeyExport(__wasi_secretkey_t SkHandle, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect secretkeyClose(__wasi_secretkey_t SkHandle) noexcept; + + // Key_exchange + + WasiCryptoExpect<__wasi_array_output_t> + kxDh(__wasi_kx_publickey_t PkHandle, __wasi_kx_secretkey_t SkHandle) noexcept; + + WasiCryptoExpect> + kxEncapsulate(__wasi_kx_publickey_t PkHandle) noexcept; + + WasiCryptoExpect<__wasi_array_output_t> + kxDecapsulate(__wasi_kx_secretkey_t SkHandle, + Span EncapsulatedSecret) noexcept; + + // Signature + + WasiCryptoExpect<__wasi_array_output_t> + signatureExport(__wasi_signature_t SigHandle, + __wasi_signature_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect<__wasi_signature_t> + signatureImport(Signatures::Algorithm Alg, Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect signatureClose(__wasi_signature_t SigHandle) noexcept; + + WasiCryptoExpect<__wasi_signature_state_t> + signatureStateOpen(__wasi_signature_keypair_t KpHandle) noexcept; + + WasiCryptoExpect + signatureStateUpdate(__wasi_signature_state_t StateHandle, + Span Input) noexcept; + + WasiCryptoExpect<__wasi_signature_t> + signatureStateSign(__wasi_signature_state_t StateHandle) noexcept; + + WasiCryptoExpect + signatureStateClose(__wasi_signature_state_t StateHandle) noexcept; + + WasiCryptoExpect<__wasi_signature_verification_state_t> + signatureVerificationStateOpen( + __wasi_signature_publickey_t PkHandle) noexcept; + + WasiCryptoExpect signatureVerificationStateUpdate( + __wasi_signature_verification_state_t StateHandle, + Span Input) noexcept; + + WasiCryptoExpect signatureVerificationStateVerify( + __wasi_signature_verification_state_t StateHandle, + __wasi_signature_t SigHandle) noexcept; + + WasiCryptoExpect signatureVerificationStateClose( + __wasi_signature_verification_state_t StateHandle) noexcept; + +private: + RefHandlesManager<__wasi_array_output_t, Common::ArrayOutput> + ArrayOutputManager{0x00}; + RcHandlesManager<__wasi_options_t, Common::Options> OptionsManager{0x01}; + RefHandlesManager<__wasi_symmetric_tag_t, Symmetric::Tag> SymmetricTagManager{ + 0xa}; + RcHandlesManager<__wasi_symmetric_key_t, Symmetric::KeyVariant> + SymmetricKeyManager{0x09}; + RcHandlesManager<__wasi_symmetric_state_t, Symmetric::StateVariant> + SymmetricStateManager{0x08}; + RcHandlesManager<__wasi_publickey_t, AsymmetricCommon::PkVariant> + PublicKeyManager{0x03}; + RcHandlesManager<__wasi_secretkey_t, AsymmetricCommon::SkVariant> + SecretKeyManager{0x04}; + RcHandlesManager<__wasi_keypair_t, AsymmetricCommon::KpVariant> + KeyPairManager{0x05}; + RcHandlesManager<__wasi_signature_t, Signatures::SigVariant> SignatureManager{ + 0x06}; + RcHandlesManager<__wasi_signature_state_t, Signatures::SignStateVariant> + SignStateManager{0x07}; + RcHandlesManager<__wasi_signature_state_t, + Signatures::VerificationStateVariant> + VerificationStateManager{0x02}; + + static Plugin::PluginRegister Register; +}; + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/ctx.cpp b/plugins/wasi_crypto/kx/ctx.cpp new file mode 100644 index 00000000..f5bb349b --- /dev/null +++ b/plugins/wasi_crypto/kx/ctx.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "ctx.h" +#include "kx/kx.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +WasiCryptoExpect<__wasi_array_output_t> +Context::kxDh(__wasi_kx_publickey_t PkHandle, + __wasi_kx_secretkey_t SkHandle) noexcept { + auto Sk = SecretKeyManager.getAs(SkHandle); + if (!Sk) { + return WasiCryptoUnexpect(Sk); + } + + auto Pk = PublicKeyManager.getAs(PkHandle); + if (!Pk) { + return WasiCryptoUnexpect(Pk); + } + + return Kx::dh(*Pk, *Sk).and_then([this](auto &&Data) { + return ArrayOutputManager.registerManager( + std::forward(Data)); + }); +} + +WasiCryptoExpect> +Context::kxEncapsulate(__wasi_kx_publickey_t PkHandle) noexcept { + auto EncapsulatedSecret = + PublicKeyManager.getAs(PkHandle).and_then( + [](auto &&KxPk) noexcept { return Kx::encapsulate(KxPk); }); + if (!EncapsulatedSecret) { + return WasiCryptoUnexpect(EncapsulatedSecret); + } + + auto SecretHandle = + ArrayOutputManager.registerManager(std::move(EncapsulatedSecret->Secret)); + if (!SecretHandle) { + return WasiCryptoUnexpect(SecretHandle); + } + + auto EncapsulatedSecretHandle = ArrayOutputManager.registerManager( + std::move(EncapsulatedSecret->EncapsulatedSecretData)); + if (!EncapsulatedSecretHandle) { + return WasiCryptoUnexpect(EncapsulatedSecretHandle); + } + + return std::tuple(*SecretHandle, *EncapsulatedSecretHandle); +} + +WasiCryptoExpect<__wasi_array_output_t> +Context::kxDecapsulate(__wasi_kx_secretkey_t SkHandle, + Span EncapsulatedSecret) noexcept { + return SecretKeyManager.getAs(SkHandle) + .and_then([EncapsulatedSecret](auto &&KxSk) noexcept { + return Kx::decapsulate(KxSk, EncapsulatedSecret); + }) + .and_then([this](auto &&Secret) noexcept { + return ArrayOutputManager.registerManager( + std::forward(Secret)); + }); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/dh/ecdsa.cpp b/plugins/wasi_crypto/kx/dh/ecdsa.cpp new file mode 100644 index 00000000..b0fd4f5a --- /dev/null +++ b/plugins/wasi_crypto/kx/dh/ecdsa.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "kx/dh/ecdsa.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +namespace { +inline const size_t SharedSecretSize = 32; +} // namespace + +WasiCryptoExpect +Ecdsa::SecretKey::dh(const PublicKey &Pk) const noexcept { + EvpPkeyCtxPtr SkCtx{EVP_PKEY_CTX_new(Ctx.get(), nullptr)}; + opensslCheck(EVP_PKEY_derive_init(SkCtx.get())); + + // Set peer key. + opensslCheck(EVP_PKEY_derive_set_peer(SkCtx.get(), Pk.raw().get())); + + // Generate shared secret. + SecretVec Res(SharedSecretSize); + size_t Size = SharedSecretSize; + ensureOrReturn(EVP_PKEY_derive(SkCtx.get(), Res.data(), &Size), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(Size == SharedSecretSize, + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Res; +} + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/dh/ecdsa.h b/plugins/wasi_crypto/kx/dh/ecdsa.h new file mode 100644 index 00000000..2ffbbd82 --- /dev/null +++ b/plugins/wasi_crypto/kx/dh/ecdsa.h @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/dh/ecdsa.h - Ecdsa alg implement --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of ecdsa algorithm. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "asymmetric_common/ecdsa.h" +#include "kx/options.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +class Ecdsa { + +public: + class PublicKey; + class SecretKey; + class KeyPair; + using Base = AsymmetricCommon::Ecdsa; + class PublicKey : public Base::PublicKeyBase { + public: + using Base::PublicKeyBase::PublicKeyBase; + + const auto &raw() const { return Ctx; } + }; + + class SecretKey : public Base::SecretKeyBase { + public: + using Base::SecretKeyBase::SecretKeyBase; + + WasiCryptoExpect dh(const PublicKey &Pk) const noexcept; + }; + + class KeyPair : public Base::KeyPairBase { + public: + using Base::KeyPairBase::KeyPairBase; + }; +}; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/dh/x25519.cpp b/plugins/wasi_crypto/kx/dh/x25519.cpp new file mode 100644 index 00000000..a2460217 --- /dev/null +++ b/plugins/wasi_crypto/kx/dh/x25519.cpp @@ -0,0 +1,181 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "kx/dh/x25519.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +namespace { +inline const size_t PkSize = 32; +inline const size_t SkSize = 32; +inline const size_t KpSize = 64; +inline const size_t SharedSecretSize = 32; +} // namespace + +WasiCryptoExpect> X25519::PublicKey::exportData( + __wasi_publickey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_RAW: { + std::vector Res(PkSize); + + size_t Size = PkSize; + opensslCheck(EVP_PKEY_get_raw_public_key(Ctx.get(), Res.data(), &Size)); + ensureOrReturn(Size == PkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Res; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect X25519::PublicKey::verify() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect X25519::SecretKey::exportData( + __wasi_secretkey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_RAW: { + SecretVec Res(SkSize); + + size_t Size = SkSize; + opensslCheck(EVP_PKEY_get_raw_private_key(Ctx.get(), Res.data(), &Size)); + ensureOrReturn(Size == SkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Res; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +X25519::SecretKey::publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +WasiCryptoExpect +X25519::SecretKey::dh(const PublicKey &Pk) const noexcept { + EvpPkeyCtxPtr SkCtx{EVP_PKEY_CTX_new(Ctx.get(), nullptr)}; + opensslCheck(EVP_PKEY_derive_init(SkCtx.get())); + + // Set peer key. + opensslCheck(EVP_PKEY_derive_set_peer(SkCtx.get(), Pk.raw().get())); + + // Generate shared secret. + SecretVec Res(SharedSecretSize); + size_t Size = SharedSecretSize; + ensureOrReturn(EVP_PKEY_derive(SkCtx.get(), Res.data(), &Size), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(Size == SharedSecretSize, + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Res; +} + +WasiCryptoExpect +X25519::SecretKey::toKeyPair(const PublicKey &) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect +X25519::KeyPair::publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +WasiCryptoExpect +X25519::KeyPair::secretKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +WasiCryptoExpect X25519::KeyPair::exportData( + __wasi_keypair_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_RAW: { + SecretVec Res(KpSize); + + size_t Size = PkSize; + opensslCheck(EVP_PKEY_get_raw_public_key(Ctx.get(), Res.data(), &Size)); + ensureOrReturn(Size == PkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + Size = SkSize; + opensslCheck( + EVP_PKEY_get_raw_private_key(Ctx.get(), Res.data() + PkSize, &Size)); + ensureOrReturn(Size == SkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return Res; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +X25519::PublicKey::import(Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_RAW: { + EvpPkeyPtr Pk{EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, nullptr, + Encoded.data(), Encoded.size())}; + ensureOrReturn(Pk, __WASI_CRYPTO_ERRNO_INVALID_KEY); + return Pk; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +X25519::SecretKey::import(Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_RAW: { + EvpPkeyPtr Sk{EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, nullptr, + Encoded.data(), Encoded.size())}; + ensureOrReturn(Sk, __WASI_CRYPTO_ERRNO_INVALID_KEY); + return Sk; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +X25519::KeyPair::generate(OptionalRef) noexcept { + EvpPkeyCtxPtr Ctx{EVP_PKEY_CTX_new_id(EVP_PKEY_X25519, nullptr)}; + opensslCheck(EVP_PKEY_keygen_init(Ctx.get())); + + EVP_PKEY *Kp = nullptr; + opensslCheck(EVP_PKEY_keygen(Ctx.get(), &Kp)); + + return EvpPkeyPtr{Kp}; +} + +WasiCryptoExpect +X25519::KeyPair::import(Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_RAW: { + ensureOrReturn(Encoded.size() == KpSize, __WASI_CRYPTO_ERRNO_INVALID_KEY); + // PublicKey can auto generate from SecretKey. + EvpPkeyPtr Sk{EVP_PKEY_new_raw_private_key( + EVP_PKEY_X25519, nullptr, Encoded.data() + PkSize, SkSize)}; + ensureOrReturn(Sk, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Sk; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/dh/x25519.h b/plugins/wasi_crypto/kx/dh/x25519.h new file mode 100644 index 00000000..f598b026 --- /dev/null +++ b/plugins/wasi_crypto/kx/dh/x25519.h @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/dh/x25519.h - X25519 alg implement ===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of x25519 algorithm. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "kx/options.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +class X25519 { +public: + class PublicKey { + public: + PublicKey(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + PublicKey(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect> + exportData(__wasi_publickey_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect verify() const noexcept; + + const auto &raw() const { return Ctx; } + + private: + SharedEvpPkey Ctx; + }; + + class KeyPair; + + class SecretKey { + public: + SecretKey(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + SecretKey(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect + exportData(__wasi_secretkey_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect publicKey() const noexcept; + + WasiCryptoExpect dh(const PublicKey &Pk) const noexcept; + + WasiCryptoExpect toKeyPair(const PublicKey &Pk) const noexcept; + + private: + SharedEvpPkey Ctx; + }; + + class KeyPair { + public: + KeyPair(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + KeyPair(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + generate(OptionalRef Options) noexcept; + + static WasiCryptoExpect + import(Span Raw, + __wasi_keypair_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect publicKey() const noexcept; + + WasiCryptoExpect secretKey() const noexcept; + + WasiCryptoExpect + exportData(__wasi_keypair_encoding_e_t Encoding) const noexcept; + + private: + SharedEvpPkey Ctx; + }; +}; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/func.cpp b/plugins/wasi_crypto/kx/func.cpp new file mode 100644 index 00000000..471d3a97 --- /dev/null +++ b/plugins/wasi_crypto/kx/func.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "kx/func.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +Expect Dh::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t PkHandle, int32_t SkHandle, + uint32_t /* Out */ SharedSecretPtr) { + checkExist(MemInst); + + auto *const SharedSecret = + MemInst->getPointer<__wasi_array_output_t *>(SharedSecretPtr); + checkExist(SharedSecret); + + if (auto Res = Ctx.kxDh(PkHandle, SkHandle); unlikely(!Res)) { + return Res.error(); + } else { + *SharedSecret = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect Encapsulate::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t PkHandle, + uint32_t /* Out */ SecretPtr, + uint32_t /* Out */ EncapsulatedSecretPtr) { + checkExist(MemInst); + + auto *const Secret = MemInst->getPointer<__wasi_array_output_t *>(SecretPtr); + checkExist(Secret); + + auto *const EncapsulatedSecret = + MemInst->getPointer<__wasi_array_output_t *>(EncapsulatedSecretPtr); + checkExist(EncapsulatedSecret); + + if (auto Res = Ctx.kxEncapsulate(PkHandle); unlikely(!Res)) { + return Res.error(); + } else { + std::tie(*Secret, *EncapsulatedSecret) = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect Decapsulate::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SkHandle, + uint32_t EncapsulatedSecretPtr, + uint32_t EncapsulatedSecretLen, + uint32_t /* Out */ SecretPtr) { + checkExist(MemInst); + + const __wasi_size_t WasiEncapsulatedSecretLen = EncapsulatedSecretLen; + auto *const EncapsulatedSecret = MemInst->getPointer( + EncapsulatedSecretPtr, WasiEncapsulatedSecretLen); + + checkExist(EncapsulatedSecret); + + auto *const Secret = MemInst->getPointer<__wasi_array_output_t *>(SecretPtr); + checkExist(Secret); + + if (auto Res = Ctx.kxDecapsulate( + SkHandle, {EncapsulatedSecret, WasiEncapsulatedSecretLen}); + unlikely(!Res)) { + return Res.error(); + } else { + *Secret = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/func.h b/plugins/wasi_crypto/kx/func.h new file mode 100644 index 00000000..b2d4b945 --- /dev/null +++ b/plugins/wasi_crypto/kx/func.h @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/func.h - Key Exchange funcs -------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Key Exchange host functions of wasi-crypto. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/hostfunction.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +class Dh : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t PkHandle, int32_t SkHandle, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class Encapsulate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t PkHandle, uint32_t /* Out */ SecretPtr, + uint32_t /* Out */ EncapsulatedSecretPtr); +}; + +class Decapsulate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SkHandle, uint32_t EncapsulatedSecretPtr, + uint32_t EncapsulatedSecretLen, + uint32_t /* Out */ SecretPtr); +}; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/kx.cpp b/plugins/wasi_crypto/kx/kx.cpp new file mode 100644 index 00000000..cd785fc5 --- /dev/null +++ b/plugins/wasi_crypto/kx/kx.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "kx/kx.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +namespace { +template struct DhTrait; + +template +struct DhTrait (SkType::*)(const PkType &) + const noexcept> { + using Pk = PkType; +}; +template using PkType = typename DhTrait::Pk; +} // namespace + +WasiCryptoExpect dh(const PkVariant &PkVariant, + const SkVariant &SkVariant) noexcept { + return std::visit( + [](const auto &Pk, + const auto &Sk) noexcept -> WasiCryptoExpect { + using InPkType = std::decay_t; + using ExpectPkType = PkType>; + if constexpr (std::is_same_v) { + return Sk.dh(Pk); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_KEY); + } + }, + PkVariant, SkVariant); +} + +WasiCryptoExpect +encapsulate(PkVariant &PkVariant) noexcept { + return std::visit( + [](auto &&) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + }, + PkVariant); +} + +WasiCryptoExpect> +decapsulate(SkVariant &, Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/kx.h b/plugins/wasi_crypto/kx/kx.h new file mode 100644 index 00000000..cf81bf62 --- /dev/null +++ b/plugins/wasi_crypto/kx/kx.h @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/kx.h - Key Exchange related -------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the key exchange related functions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "kx/registed.h" +#include "utils/error.h" + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +using PkVariant = RegistedAlg::PkVariant; +using SkVariant = RegistedAlg::SkVariant; + +/// Diffie-Hellman based key agreement. +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#diffie-hellman-based-key-agreement +WasiCryptoExpect dh(const PkVariant &PkVariant, + const SkVariant &SkVariant) noexcept; + +/// Key encapsulation mechanisms. +/// +/// More detailed +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#key-encapsulation-mechanisms +struct EncapsulatedSecret { + std::vector EncapsulatedSecretData; + std::vector Secret; +}; + +WasiCryptoExpect encapsulate(PkVariant &PkVariant) noexcept; + +WasiCryptoExpect> +decapsulate(SkVariant &SkVariant, + Span EncapsulatedSecret) noexcept; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/options.cpp b/plugins/wasi_crypto/kx/options.cpp new file mode 100644 index 00000000..da383199 --- /dev/null +++ b/plugins/wasi_crypto/kx/options.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "kx/options.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +WasiCryptoExpect Options::set(std::string_view, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::setU64(std::string_view, uint64_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::setGuestBuffer(std::string_view, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::get(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::getU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/options.h b/plugins/wasi_crypto/kx/options.h new file mode 100644 index 00000000..4074579c --- /dev/null +++ b/plugins/wasi_crypto/kx/options.h @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/options.h - Key exchange Options --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Key Exchange Options class definition. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" + +#include "common/span.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +class Options { +public: + WasiCryptoExpect set(std::string_view Name, + Span Value) noexcept; + + WasiCryptoExpect setU64(std::string_view Name, uint64_t Value) noexcept; + + WasiCryptoExpect setGuestBuffer(std::string_view Name, + Span Buffer) noexcept; + + WasiCryptoExpect get(std::string_view Name, + Span Value) const noexcept; + + WasiCryptoExpect getU64(std::string_view Name) const noexcept; +}; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/kx/registed.h b/plugins/wasi_crypto/kx/registed.h new file mode 100644 index 00000000..74fca12a --- /dev/null +++ b/plugins/wasi_crypto/kx/registed.h @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/kx/registed.h - Registed -------------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the register key exchange algorithm definitions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "kx/dh/ecdsa.h" +#include "kx/dh/x25519.h" +#include "utils/error.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Kx { + +template struct Registed { + using PkVariant = std::variant; + using SkVariant = std::variant; + using KpVariant = std::variant; + using Variant = std::variant; +}; + +using RegistedAlg = Registed; + +using Algorithm = RegistedAlg::Variant; + +} // namespace Kx +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/module.cpp b/plugins/wasi_crypto/module.cpp new file mode 100644 index 00000000..9999cac4 --- /dev/null +++ b/plugins/wasi_crypto/module.cpp @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "module.h" +#include "asymmetric_common/func.h" +#include "common/func.h" +#include "kx/func.h" +#include "signatures/func.h" +#include "symmetric/func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiCryptoModule::WasiCryptoModule() : ModuleInstance("wasi_ephemeral_crypto") { + using namespace WasiCrypto; + + // common + addHostFunc("array_output_len", + std::make_unique(Ctx)); + addHostFunc("array_output_pull", + std::make_unique(Ctx)); + addHostFunc("options_open", std::make_unique(Ctx)); + addHostFunc("options_close", std::make_unique(Ctx)); + addHostFunc("options_set", std::make_unique(Ctx)); + addHostFunc("options_set_u64", std::make_unique(Ctx)); + addHostFunc("options_set_guest_buffer", + std::make_unique(Ctx)); + addHostFunc("secrets_manager_open", + std::make_unique(Ctx)); + addHostFunc("secrets_manager_close", + std::make_unique(Ctx)); + addHostFunc("secrets_manager_invalidate", + std::make_unique(Ctx)); + + // symmetric + addHostFunc("symmetric_key_generate", + std::make_unique(Ctx)); + addHostFunc("symmetric_key_import", + std::make_unique(Ctx)); + addHostFunc("symmetric_key_export", + std::make_unique(Ctx)); + addHostFunc("symmetric_key_close", + std::make_unique(Ctx)); + addHostFunc("symmetric_key_generate_managed", + std::make_unique(Ctx)); + addHostFunc("symmetric_key_store_managed", + std::make_unique(Ctx)); + addHostFunc("symmetric_key_replace_managed", + std::make_unique(Ctx)); + addHostFunc("symmetric_key_id", std::make_unique(Ctx)); + addHostFunc("symmetric_key_from_id", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_open", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_clone", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_options_get", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_options_get_u64", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_close", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_absorb", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_squeeze", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_squeeze_tag", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_squeeze_key", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_max_tag_len", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_encrypt", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_encrypt_detached", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_decrypt", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_decrypt_detached", + std::make_unique(Ctx)); + addHostFunc("symmetric_state_ratchet", + std::make_unique(Ctx)); + addHostFunc("symmetric_tag_len", std::make_unique(Ctx)); + addHostFunc("symmetric_tag_pull", std::make_unique(Ctx)); + addHostFunc("symmetric_tag_verify", + std::make_unique(Ctx)); + addHostFunc("symmetric_tag_close", + std::make_unique(Ctx)); + + // asymmetric + addHostFunc("keypair_generate", + std::make_unique(Ctx)); + addHostFunc("keypair_import", + std::make_unique(Ctx)); + addHostFunc("keypair_generate_managed", + std::make_unique(Ctx)); + addHostFunc("keypair_store_managed", + std::make_unique(Ctx)); + addHostFunc("keypair_replace_managed", + std::make_unique(Ctx)); + addHostFunc("keypair_id", std::make_unique(Ctx)); + addHostFunc("keypair_from_id", + std::make_unique(Ctx)); + addHostFunc("keypair_from_pk_and_sk", + std::make_unique(Ctx)); + addHostFunc("keypair_export", + std::make_unique(Ctx)); + addHostFunc("keypair_publickey", + std::make_unique(Ctx)); + addHostFunc("keypair_secretkey", + std::make_unique(Ctx)); + addHostFunc("keypair_close", + std::make_unique(Ctx)); + addHostFunc("publickey_import", + std::make_unique(Ctx)); + addHostFunc("publickey_export", + std::make_unique(Ctx)); + addHostFunc("publickey_verify", + std::make_unique(Ctx)); + addHostFunc("publickey_from_secretkey", + std::make_unique(Ctx)); + addHostFunc("publickey_close", + std::make_unique(Ctx)); + addHostFunc("secretkey_import", + std::make_unique(Ctx)); + addHostFunc("secretkey_export", + std::make_unique(Ctx)); + addHostFunc("secretkey_close", + std::make_unique(Ctx)); + + // kx + addHostFunc("kx_dh", std::make_unique(Ctx)); + addHostFunc("kx_encapsulate", std::make_unique(Ctx)); + addHostFunc("kx_decapsulate", std::make_unique(Ctx)); + + // signature + addHostFunc("signature_export", std::make_unique(Ctx)); + addHostFunc("signature_import", std::make_unique(Ctx)); + addHostFunc("signature_state_open", + std::make_unique(Ctx)); + addHostFunc("signature_state_update", + std::make_unique(Ctx)); + addHostFunc("signature_state_sign", + std::make_unique(Ctx)); + addHostFunc("signature_state_close", + std::make_unique(Ctx)); + addHostFunc("signature_verification_state_open", + std::make_unique(Ctx)); + addHostFunc("signature_verification_state_update", + std::make_unique(Ctx)); + addHostFunc("signature_verification_state_verify", + std::make_unique(Ctx)); + addHostFunc("signature_verification_state_close", + std::make_unique(Ctx)); + addHostFunc("signature_close", std::make_unique(Ctx)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/module.h b/plugins/wasi_crypto/module.h new file mode 100644 index 00000000..b709325e --- /dev/null +++ b/plugins/wasi_crypto/module.h @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/module.h - Module class definition ---===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the wasi-crypto module class. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "ctx.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiCryptoModule : public Runtime::Instance::ModuleInstance { +public: + WasiCryptoModule(); + + WasiCrypto::Context &getContext() { return Ctx; } + +private: + WasiCrypto::Context Ctx; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/ctx.cpp b/plugins/wasi_crypto/signatures/ctx.cpp new file mode 100644 index 00000000..c5213f9c --- /dev/null +++ b/plugins/wasi_crypto/signatures/ctx.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "ctx.h" +#include "signatures/signatures.h" +#include "signatures/signstate.h" +#include "signatures/verificationstate.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +WasiCryptoExpect<__wasi_array_output_t> +Context::signatureExport(__wasi_signature_t SigHandle, + __wasi_signature_encoding_e_t Encoding) noexcept { + return SignatureManager.get(SigHandle) + .and_then([Encoding](auto &&SigVariant) noexcept { + return Signatures::sigExportData( + std::forward(SigVariant), Encoding); + }) + .and_then([this](auto &&Data) noexcept { + return ArrayOutputManager.registerManager( + std::forward(Data)); + }); +} + +WasiCryptoExpect +Context::signatureClose(__wasi_signature_t SigHandle) noexcept { + return SignatureManager.close(SigHandle); +} + +WasiCryptoExpect<__wasi_array_output_t> +Context::signatureImport(Signatures::Algorithm Alg, Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept { + return Signatures::sigImport(Alg, Encoded, Encoding) + .and_then([this](auto &&Sig) noexcept { + return SignatureManager.registerManager( + std::forward(Sig)); + }); +} + +WasiCryptoExpect<__wasi_signature_state_t> +Context::signatureStateOpen(__wasi_signature_keypair_t KpHandle) noexcept { + return KeyPairManager.getAs(KpHandle) + .and_then([](auto &&KpVariant) noexcept { + return Signatures::sigStateOpen( + std::forward(KpVariant)); + }) + .and_then([this](auto &&SignStateVariant) noexcept { + return SignStateManager.registerManager( + std::forward(SignStateVariant)); + }); +} + +WasiCryptoExpect +Context::signatureStateUpdate(__wasi_signature_state_t StateHandle, + Span Input) noexcept { + return SignStateManager.get(StateHandle) + .and_then([Input](auto &&SignStateVariant) noexcept { + return Signatures::sigStateUpdate(SignStateVariant, Input); + }); +} + +WasiCryptoExpect<__wasi_signature_t> +Context::signatureStateSign(__wasi_signature_state_t StateHandle) noexcept { + return SignStateManager.get(StateHandle) + .and_then([](auto &&SignStateVariant) noexcept { + return Signatures::sigStateSign(SignStateVariant); + }) + .and_then([this](auto &&Signature) noexcept { + return SignatureManager.registerManager( + std::forward(Signature)); + }); +} + +WasiCryptoExpect +Context::signatureStateClose(__wasi_signature_state_t StateHandle) noexcept { + return SignStateManager.close(StateHandle); +} + +WasiCryptoExpect<__wasi_signature_verification_state_t> +Context::signatureVerificationStateOpen( + __wasi_signature_publickey_t PkHandle) noexcept { + return PublicKeyManager.getAs(PkHandle) + .and_then([](auto &&PkVariant) noexcept { + return Signatures::verificationStateOpen( + std::forward(PkVariant)); + }) + .and_then([this](auto &&VerificationStateVariant) noexcept { + return VerificationStateManager.registerManager( + std::forward( + VerificationStateVariant)); + }); +} + +WasiCryptoExpect Context::signatureVerificationStateUpdate( + __wasi_signature_verification_state_t VerificationHandle, + Span Input) noexcept { + return VerificationStateManager.get(VerificationHandle) + .and_then([Input](auto &&VerificationStateVariant) noexcept { + return Signatures::verificationStateUpdate(VerificationStateVariant, + Input); + }); +} + +WasiCryptoExpect Context::signatureVerificationStateVerify( + __wasi_signature_verification_state_t VerificationHandle, + __wasi_signature_t SigHandle) noexcept { + auto Verification = VerificationStateManager.get(VerificationHandle); + if (!Verification) { + return WasiCryptoUnexpect(Verification); + } + + auto Sig = SignatureManager.get(SigHandle); + if (!Sig) { + return WasiCryptoUnexpect(Sig); + } + + return Signatures::verificationStateVerify(*Verification, *Sig); +} + +WasiCryptoExpect Context::signatureVerificationStateClose( + __wasi_signature_verification_state_t VerificationHandle) noexcept { + return VerificationStateManager.close(VerificationHandle); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/ecdsa.cpp b/plugins/wasi_crypto/signatures/ecdsa.cpp new file mode 100644 index 00000000..f03ed1b3 --- /dev/null +++ b/plugins/wasi_crypto/signatures/ecdsa.cpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "signatures/ecdsa.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +namespace { +inline const size_t RawSigSize = 64; +} // namespace + +template +WasiCryptoExpect::VerificationState> +Ecdsa::PublicKey::openVerificationState() const noexcept { + EvpMdCtxPtr SignCtx{EVP_MD_CTX_create()}; + opensslCheck(EVP_DigestVerifyInit(SignCtx.get(), nullptr, EVP_sha256(), + nullptr, this->Ctx.get())); + return SignCtx; +} + +template +WasiCryptoExpect::SignState> +Ecdsa::KeyPair::openSignState() const noexcept { + EvpMdCtxPtr SignCtx{EVP_MD_CTX_create()}; + opensslCheck(EVP_DigestSignInit(SignCtx.get(), nullptr, EVP_sha256(), nullptr, + this->Ctx.get())); + + return SignCtx; +} + +template +WasiCryptoExpect::Signature> +Ecdsa::Signature::import( + Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SIGNATURE_ENCODING_RAW: { + ensureOrReturn(Encoded.size() == RawSigSize, + __WASI_CRYPTO_ERRNO_INVALID_SIGNATURE); + EcdsaSigPtr Sig{o2iEcdsaSig(Encoded)}; + ensureOrReturn(Sig, __WASI_CRYPTO_ERRNO_INVALID_SIGNATURE); + return i2dEcdsaSig(Sig.get()); + } + case __WASI_SIGNATURE_ENCODING_DER: { + return std::vector(Encoded.begin(), Encoded.end()); + } + default: + assumingUnreachable(); + } +} + +template +WasiCryptoExpect> Ecdsa::Signature::exportData( + __wasi_signature_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SIGNATURE_ENCODING_RAW: { + EcdsaSigPtr Sig{d2iEcdsaSig(Data)}; + ensureOrReturn(Sig, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return i2oEcdsaSig(Sig.get()); + } + case __WASI_SIGNATURE_ENCODING_DER: { + return Data; + } + default: + assumingUnreachable(); + } +} + +template +WasiCryptoExpect +Ecdsa::SignState::update(Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_DigestSignUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect::Signature> +Ecdsa::SignState::sign() noexcept { + size_t Size; + // For ecdsa, OpenSSL produce a der format signatures which means the size is + // not fixed. Here is an answer talk about it: + // https://bitcoin.stackexchange.com/questions/77191/what-is-the-maximum-size-of-a-der-encoded-ecdsa-signature + // So instead of fixing size, just read. + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_DigestSignFinal(Ctx->RawCtx.get(), nullptr, &Size)); + + std::vector Res(Size); + opensslCheck(EVP_DigestSignFinal(Ctx->RawCtx.get(), Res.data(), &Size)); + + return Res; +} + +template +WasiCryptoExpect +Ecdsa::VerificationState::update(Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_DigestVerifyUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect +Ecdsa::VerificationState::verify(const Signature &Sig) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + ensureOrReturn(EVP_DigestVerifyFinal(Ctx->RawCtx.get(), Sig.ref().data(), + Sig.ref().size()), + __WASI_CRYPTO_ERRNO_VERIFICATION_FAILED); + return {}; +} + +template class Ecdsa; +template class Ecdsa; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/ecdsa.h b/plugins/wasi_crypto/signatures/ecdsa.h new file mode 100644 index 00000000..8e8c57bb --- /dev/null +++ b/plugins/wasi_crypto/signatures/ecdsa.h @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/ecdsa.h - Ecdsa alg -------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of ecdsa algorithm. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "asymmetric_common/ecdsa.h" +#include "signatures/options.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +template class Ecdsa { +public: + class PublicKey; + class KeyPair; + class SecretKey; + using Base = + AsymmetricCommon::Ecdsa; + class Signature { + public: + Signature(std::vector Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect> + exportData(__wasi_signature_encoding_e_t Encoding) const noexcept; + + const std::vector &ref() const { return Data; } + + private: + // Inner represent as der because OpenSSL use der format for evp interface. + std::vector Data; + }; + + class SignState { + public: + SignState(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + WasiCryptoExpect update(Span Input) noexcept; + + WasiCryptoExpect sign() noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr RawCtx) noexcept : RawCtx(std::move(RawCtx)) {} + std::mutex Mutex; + EvpMdCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + + class VerificationState { + public: + VerificationState(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + WasiCryptoExpect update(Span Input) noexcept; + + WasiCryptoExpect verify(const Signature &Sig) noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr RawCtx) noexcept : RawCtx(std::move(RawCtx)) {} + std::mutex Mutex; + EvpMdCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + + class PublicKey : public Base::PublicKeyBase { + public: + using Base::PublicKeyBase::PublicKeyBase; + + WasiCryptoExpect openVerificationState() const noexcept; + }; + + class SecretKey : public Base::SecretKeyBase { + public: + using Base::SecretKeyBase::SecretKeyBase; + }; + + class KeyPair : public Base::KeyPairBase { + public: + using Base::KeyPairBase::KeyPairBase; + + WasiCryptoExpect openSignState() const noexcept; + }; +}; + +using EcdsaP256 = Ecdsa; +using EcdsaK256 = Ecdsa; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/eddsa.cpp b/plugins/wasi_crypto/signatures/eddsa.cpp new file mode 100644 index 00000000..73d3d6db --- /dev/null +++ b/plugins/wasi_crypto/signatures/eddsa.cpp @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "signatures/eddsa.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +namespace { +inline constexpr size_t PkSize = 32; +inline constexpr size_t SkSize = 32; +inline constexpr size_t KpSize = 64; +inline constexpr size_t SigSize = 64; +} // namespace + +WasiCryptoExpect +Eddsa::PublicKey::import(Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_RAW: { + EvpPkeyPtr Ctx{EVP_PKEY_new_raw_public_key(EVP_PKEY_ED25519, nullptr, + Encoded.data(), Encoded.size())}; + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + return Ctx; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect Eddsa::PublicKey::verify() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect> Eddsa::PublicKey::exportData( + __wasi_publickey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_RAW: { + size_t Size = PkSize; + std::vector Res(PkSize); + opensslCheck(EVP_PKEY_get_raw_public_key(Ctx.get(), Res.data(), &Size)); + ensureOrReturn(Size == PkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return Res; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +Eddsa::PublicKey::openVerificationState() const noexcept { + EvpMdCtxPtr SignCtx{EVP_MD_CTX_create()}; + + opensslCheck(EVP_DigestVerifyInit(SignCtx.get(), nullptr, nullptr, nullptr, + Ctx.get())); + return SignCtx; +} + +WasiCryptoExpect +Eddsa::SecretKey::import(Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_RAW: { + EvpPkeyPtr Ctx{EVP_PKEY_new_raw_private_key( + EVP_PKEY_ED25519, nullptr, Encoded.data(), Encoded.size())}; + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + + return Ctx; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +Eddsa::SecretKey::publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +WasiCryptoExpect +Eddsa::SecretKey::toKeyPair(const PublicKey &) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect Eddsa::SecretKey::exportData( + __wasi_secretkey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_RAW: { + size_t Size = SkSize; + SecretVec Res(SkSize); + opensslCheck(EVP_PKEY_get_raw_private_key(Ctx.get(), Res.data(), &Size)); + ensureOrReturn(Size == SkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return Res; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +Eddsa::KeyPair::generate(OptionalRef) noexcept { + // Generate Key. + EvpPkeyCtxPtr KCtx{EVP_PKEY_CTX_new_id(EVP_PKEY_ED25519, nullptr)}; + opensslCheck(KCtx); + opensslCheck(EVP_PKEY_keygen_init(KCtx.get())); + + EVP_PKEY *Key = nullptr; + opensslCheck(EVP_PKEY_keygen(KCtx.get(), &Key)); + + return EvpPkeyPtr{Key}; +} + +// Refer to: https://github.com/openssl/openssl/issues/8960 +WasiCryptoExpect +Eddsa::KeyPair::import(Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_RAW: { + ensureOrReturn(Encoded.size() == KpSize, __WASI_CRYPTO_ERRNO_INVALID_KEY); + // PublicKey can auto generate from SecretKey. + EvpPkeyPtr SkCtx{EVP_PKEY_new_raw_private_key(EVP_PKEY_ED25519, nullptr, + Encoded.data(), SkSize)}; + ensureOrReturn(SkCtx, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return SkCtx; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect Eddsa::KeyPair::exportData( + __wasi_keypair_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_RAW: { + SecretVec Res(KpSize); + size_t Size = SkSize; + opensslCheck(EVP_PKEY_get_raw_private_key(Ctx.get(), Res.data(), &Size)); + ensureOrReturn(Size == SkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + Size = PkSize; + opensslCheck( + EVP_PKEY_get_raw_public_key(Ctx.get(), Res.data() + SkSize, &Size)); + ensureOrReturn(Size == PkSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return Res; + } + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect Eddsa::KeyPair::publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +WasiCryptoExpect Eddsa::KeyPair::secretKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +WasiCryptoExpect +Eddsa::KeyPair::openSignState() const noexcept { + EvpMdCtxPtr SignCtx{EVP_MD_CTX_create()}; + opensslCheck(SignCtx); + + opensslCheck( + EVP_DigestSignInit(SignCtx.get(), nullptr, nullptr, nullptr, Ctx.get())); + + return SignCtx; +} + +WasiCryptoExpect +Eddsa::Signature::import(Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SIGNATURE_ENCODING_RAW: + ensureOrReturn(Encoded.size() == SigSize, + __WASI_CRYPTO_ERRNO_INVALID_SIGNATURE); + return std::vector(Encoded.begin(), Encoded.end()); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect> Eddsa::Signature::exportData( + __wasi_signature_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SIGNATURE_ENCODING_RAW: + return Data; + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +WasiCryptoExpect +Eddsa::SignState::update(Span Input) noexcept { + // Notice: Ecdsa is oneshot in OpenSSL, we need a cache for updating instead + // of calling `EVP_DigestSignUpdate`. + std::scoped_lock Lock{Ctx->Mutex}; + + Ctx->Data.insert(Ctx->Data.end(), Input.begin(), Input.end()); + return {}; +} + +WasiCryptoExpect Eddsa::SignState::sign() noexcept { + size_t Size = SigSize; + std::vector Res(Size); + + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_DigestSign(Ctx->RawCtx.get(), Res.data(), &Size, + Ctx->Data.data(), Ctx->Data.size())); + ensureOrReturn(Size == SigSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return Res; +} + +WasiCryptoExpect +Eddsa::VerificationState::update(Span Input) noexcept { + // Also oneshot. + std::scoped_lock Lock{Ctx->Mutex}; + + Ctx->Data.insert(Ctx->Data.end(), Input.begin(), Input.end()); + return {}; +} + +WasiCryptoExpect +Eddsa::VerificationState::verify(const Signature &Sig) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + // The invokation to EVP_DigestVerifyFinal() internally finalizes a copy of + // the digest context. This means that EVP_VerifyUpdate() and + // EVP_VerifyFinal() can be called later to digest and verify the additional + // data. + ensureOrReturn(EVP_DigestVerify(Ctx->RawCtx.get(), Sig.ref().data(), + Sig.ref().size(), Ctx->Data.data(), + Ctx->Data.size()), + __WASI_CRYPTO_ERRNO_VERIFICATION_FAILED); + + return {}; +} + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/eddsa.h b/plugins/wasi_crypto/signatures/eddsa.h new file mode 100644 index 00000000..70a17538 --- /dev/null +++ b/plugins/wasi_crypto/signatures/eddsa.h @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/eddsa.h - Eddsa alg -------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of eddsa algorithm. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "signatures/options.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +class Eddsa { +public: + class Signature { + public: + Signature(std::vector &&Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect> + exportData(__wasi_signature_encoding_e_t Encoding) const noexcept; + + const std::vector &ref() const { return Data; } + + private: + std::vector Data; + }; + + class SignState { + public: + SignState(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + WasiCryptoExpect update(Span Input) noexcept; + + WasiCryptoExpect sign() noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr Ctx) noexcept : RawCtx(std::move(Ctx)) {} + std::mutex Mutex; + std::vector Data; + EvpMdCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + + class VerificationState { + public: + VerificationState(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + WasiCryptoExpect update(Span Input) noexcept; + + WasiCryptoExpect verify(const Signature &Sig) noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr Ctx) noexcept : RawCtx(std::move(Ctx)) {} + std::mutex Mutex; + std::vector Data; + EvpMdCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + + class KeyPair; + + class PublicKey { + public: + PublicKey(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + PublicKey(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect verify() const noexcept; + + WasiCryptoExpect> + exportData(__wasi_publickey_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect openVerificationState() const noexcept; + + private: + SharedEvpPkey Ctx; + }; + + class SecretKey { + public: + SecretKey(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + SecretKey(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect publicKey() const noexcept; + + WasiCryptoExpect toKeyPair(const PublicKey &Pk) const noexcept; + + WasiCryptoExpect + exportData(__wasi_secretkey_encoding_e_t Encoding) const noexcept; + + private: + SharedEvpPkey Ctx; + }; + + class KeyPair { + public: + KeyPair(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + KeyPair(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + generate(OptionalRef Options) noexcept; + + static WasiCryptoExpect + import(Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect + exportData(__wasi_keypair_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect publicKey() const noexcept; + + WasiCryptoExpect secretKey() const noexcept; + + WasiCryptoExpect openSignState() const noexcept; + + private: + SharedEvpPkey Ctx; + }; +}; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/func.cpp b/plugins/wasi_crypto/signatures/func.cpp new file mode 100644 index 00000000..c78df90e --- /dev/null +++ b/plugins/wasi_crypto/signatures/func.cpp @@ -0,0 +1,225 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "signatures/func.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +Expect Export::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigHandle, uint32_t Encoding, + uint32_t /* Out */ ArrayOutputHandlePtr) { + checkExist(MemInst); + + __wasi_signature_encoding_e_t WasiEncoding; + if (auto Res = cast<__wasi_signature_encoding_e_t>(Encoding); + unlikely(!Res)) { + return Res.error(); + } else { + WasiEncoding = *Res; + } + + auto *const ArrayOutput = + MemInst->getPointer<__wasi_array_output_t *>(ArrayOutputHandlePtr); + checkExist(ArrayOutput); + + if (auto Res = Ctx.signatureExport(SigHandle, WasiEncoding); unlikely(!Res)) { + return Res.error(); + } else { + *ArrayOutput = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect Import::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t EncodedPtr, uint32_t EncodedLen, + uint32_t Encoding, + uint32_t /* Out */ SigHandlePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); + checkExist(Alg); + + Algorithm WasiAlg; + if (auto Res = tryFrom({Alg, AlgLen}); unlikely(!Res)) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + const __wasi_size_t WasiEncodedLen = EncodedLen; + auto *const Encoded = + MemInst->getPointer(EncodedPtr, WasiEncodedLen); + checkExist(Encoded); + + __wasi_signature_encoding_e_t WasiEncoding; + if (auto Res = cast<__wasi_signature_encoding_e_t>(Encoding); + unlikely(!Res)) { + return Res.error(); + } else { + WasiEncoding = *Res; + } + + auto *const SigHandle = + MemInst->getPointer<__wasi_signature_t *>(SigHandlePtr); + checkExist(SigHandle); + + if (auto Res = + Ctx.signatureImport(WasiAlg, {Encoded, WasiEncodedLen}, WasiEncoding); + unlikely(!Res)) { + return Res.error(); + } else { + *SigHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateOpen::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KpHandle, + uint32_t /* Out */ SigStatePtr) { + checkExist(MemInst); + + auto *const SigState = + MemInst->getPointer<__wasi_signature_state_t *>(SigStatePtr); + checkExist(SigState); + + if (auto Res = Ctx.signatureStateOpen(KpHandle); unlikely(!Res)) { + return Res.error(); + } else { + *SigState = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateUpdate::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigStateHandle, uint32_t InputPtr, + uint32_t InputSize) { + checkExist(MemInst); + + const __wasi_size_t WasiInputSize = InputSize; + auto *const Input = + MemInst->getPointer(InputPtr, WasiInputSize); + checkExist(Input); + + if (auto Res = + Ctx.signatureStateUpdate(SigStateHandle, {Input, WasiInputSize}); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateSign::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigStateHandle, + uint32_t /* Out */ ArrayOutputHandlePtr) { + checkExist(MemInst); + + auto *const ArrayOutputHandle = + MemInst->getPointer<__wasi_array_output_t *>(ArrayOutputHandlePtr); + checkExist(ArrayOutputHandle); + + if (auto Res = Ctx.signatureStateSign(SigStateHandle); unlikely(!Res)) { + return Res.error(); + } else { + *ArrayOutputHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateClose::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigStateHandle) { + checkExist(MemInst); + + if (auto Res = Ctx.signatureStateClose(SigStateHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +VerificationStateOpen::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigPkHandle, + uint32_t /* Out */ VerificationStateHandlePtr) { + checkExist(MemInst); + + auto *const VerificationStateHandle = + MemInst->getPointer<__wasi_signature_state_t *>( + VerificationStateHandlePtr); + checkExist(VerificationStateHandle); + + if (auto Res = Ctx.signatureVerificationStateOpen(SigPkHandle); + unlikely(!Res)) { + return Res.error(); + } else { + *VerificationStateHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +VerificationStateUpdate::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigStateHandle, uint32_t InputPtr, + uint32_t InputSize) { + checkExist(MemInst); + + const __wasi_size_t WasiInputSize = InputSize; + auto *const Input = MemInst->getPointer(InputPtr, InputSize); + checkExist(Input); + + if (auto Res = Ctx.signatureVerificationStateUpdate(SigStateHandle, + {Input, WasiInputSize}); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +VerificationStateVerify::body(Runtime::Instance::MemoryInstance *, + int32_t VerificationStateHandle, + int32_t SigHandle) { + if (auto Res = Ctx.signatureVerificationStateVerify(VerificationStateHandle, + SigHandle); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +VerificationStateClose::body(Runtime::Instance::MemoryInstance *, + int32_t VerificationStateHandle) { + if (auto Res = Ctx.signatureVerificationStateClose(VerificationStateHandle); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect Close::body(Runtime::Instance::MemoryInstance *, + int32_t SigHandle) { + if (auto Res = Ctx.signatureClose(SigHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/func.h b/plugins/wasi_crypto/signatures/func.h new file mode 100644 index 00000000..6c0599f1 --- /dev/null +++ b/plugins/wasi_crypto/signatures/func.h @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/func.h - Signatures func --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the signatures host functions of wasi-crypto. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/hostfunction.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +class Export : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigHandle, uint32_t Encoding, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class Import : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgPtr, uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ SigHandlePtr); +}; + +class StateOpen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KpHandle, uint32_t /* Out */ SigStatePtr); +}; + +class StateUpdate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigStateHandle, uint32_t InputPtr, + uint32_t InputSize); +}; + +class StateSign : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigStateHandle, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class StateClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigStateHandle); +}; + +class VerificationStateOpen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigPkHandle, + uint32_t /* Out */ VerificationStateHandlePtr); +}; + +class VerificationStateUpdate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigStateHandle, uint32_t InputPtr, + uint32_t InputSize); +}; + +class VerificationStateVerify : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t VerificationStateHandle, int32_t SigHandle); +}; + +class VerificationStateClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t VerificationStateHandle); +}; + +class Close : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SigHandle); +}; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/options.cpp b/plugins/wasi_crypto/signatures/options.cpp new file mode 100644 index 00000000..761ca439 --- /dev/null +++ b/plugins/wasi_crypto/signatures/options.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "signatures/options.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +WasiCryptoExpect Options::set(std::string_view, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::setU64(std::string_view, uint64_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::setGuestBuffer(std::string_view, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::get(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +WasiCryptoExpect Options::getU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); +} + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/options.h b/plugins/wasi_crypto/signatures/options.h new file mode 100644 index 00000000..a229dc9b --- /dev/null +++ b/plugins/wasi_crypto/signatures/options.h @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/options.h - Options -------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Signatures Options class definition. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "utils/error.h" + +#include "common/span.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +class Options { +public: + WasiCryptoExpect set(std::string_view Name, + Span Value) noexcept; + + WasiCryptoExpect setU64(std::string_view Name, uint64_t Value) noexcept; + + WasiCryptoExpect setGuestBuffer(std::string_view Name, + Span Buffer) noexcept; + + WasiCryptoExpect get(std::string_view Name, + Span Value) const noexcept; + + WasiCryptoExpect getU64(std::string_view Name) const noexcept; +}; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/registed.h b/plugins/wasi_crypto/signatures/registed.h new file mode 100644 index 00000000..c80c66e9 --- /dev/null +++ b/plugins/wasi_crypto/signatures/registed.h @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/registed.h - Registed -----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the register signature algorithm definitions. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "signatures/ecdsa.h" +#include "signatures/eddsa.h" +#include "signatures/rsa.h" +#include "utils/error.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +template struct Registed { + using PkVariant = std::variant; + using SkVariant = std::variant; + using KpVariant = std::variant; + using Variant = std::variant; + using SigVariant = std::variant; + using SignStateVariant = std::variant; + using VerificationStateVariant = + std::variant; +}; + +using RegistedAlg = + Registed; + +using Algorithm = RegistedAlg::Variant; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/rsa.cpp b/plugins/wasi_crypto/signatures/rsa.cpp new file mode 100644 index 00000000..6493a16d --- /dev/null +++ b/plugins/wasi_crypto/signatures/rsa.cpp @@ -0,0 +1,375 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "signatures/rsa.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +template +WasiCryptoExpect::PublicKey> +Rsa::PublicKey::import( + Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_PKCS8: + return importPkcs8(Encoded); + case __WASI_PUBLICKEY_ENCODING_PEM: + return importPem(Encoded); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template +WasiCryptoExpect::PublicKey> +Rsa::PublicKey::importPkcs8( + Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPUBKEY(Encoded)}); +} + +template +WasiCryptoExpect::PublicKey> +Rsa::PublicKey::importPem( + Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{pemReadPUBKEY(Encoded)}); +} + +template +WasiCryptoExpect +Rsa::PublicKey::checkValid(EvpPkeyPtr Ctx) noexcept { + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + RSA *RsaKey = EVP_PKEY_get0_RSA(Ctx.get()); + ensureOrReturn(RsaKey, __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(RSA_bits(RsaKey) == KeyBits, __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {std::move(Ctx)}; +} + +template +WasiCryptoExpect +Rsa::PublicKey::verify() const noexcept { + ensureOrReturn(RSA_check_key(EVP_PKEY_get0_RSA(Ctx.get())), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {}; +} + +template +WasiCryptoExpect> +Rsa::PublicKey::exportData( + __wasi_publickey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_PUBLICKEY_ENCODING_PKCS8: + return exportPkcs8(); + case __WASI_PUBLICKEY_ENCODING_PEM: + return exportPem(); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template +WasiCryptoExpect> +Rsa::PublicKey::exportPkcs8() const noexcept { + return i2dPUBKEY(Ctx.get()); +} + +template +WasiCryptoExpect> +Rsa::PublicKey::exportPem() const noexcept { + return pemWritePUBKEY(Ctx.get()); +} + +template +WasiCryptoExpect::VerificationState> +Rsa::PublicKey::openVerificationState() + const noexcept { + EvpMdCtxPtr SignCtx{EVP_MD_CTX_create()}; + opensslCheck(EVP_DigestVerifyInit( + SignCtx.get(), nullptr, EVP_get_digestbynid(ShaNid), nullptr, Ctx.get())); + opensslCheck(EVP_PKEY_CTX_set_rsa_padding(EVP_MD_CTX_pkey_ctx(SignCtx.get()), + PadMode)); + return SignCtx; +} + +template +WasiCryptoExpect::SecretKey> +Rsa::SecretKey::import( + Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_PKCS8: + return importPkcs8(Encoded); + case __WASI_SECRETKEY_ENCODING_PEM: + return importPem(Encoded); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template +WasiCryptoExpect::SecretKey> +Rsa::SecretKey::importPem( + Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{pemReadPrivateKey(Encoded)}); +} + +template +WasiCryptoExpect::SecretKey> +Rsa::SecretKey::importPkcs8( + Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPrivateKey(Encoded)}); +} + +template +WasiCryptoExpect +Rsa::SecretKey::checkValid(EvpPkeyPtr Ctx) noexcept { + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + RSA *RsaKey = EVP_PKEY_get0_RSA(Ctx.get()); + ensureOrReturn(RsaKey, __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(RSA_bits(RsaKey) == KeyBits, __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {std::move(Ctx)}; +} + +template +WasiCryptoExpect::KeyPair> +Rsa::SecretKey::toKeyPair( + const PublicKey &) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +template +WasiCryptoExpect +Rsa::SecretKey::exportData( + __wasi_secretkey_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SECRETKEY_ENCODING_PKCS8: + return exportPkcs8(); + case __WASI_SECRETKEY_ENCODING_PEM: + return exportPem(); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template +WasiCryptoExpect::PublicKey> +Rsa::SecretKey::publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +template +WasiCryptoExpect +Rsa::SecretKey::exportPem() const noexcept { + return pemWritePrivateKey(Ctx.get()); +} + +template +WasiCryptoExpect +Rsa::SecretKey::exportPkcs8() const noexcept { + return i2dPrivateKey(Ctx.get()); +} + +template +WasiCryptoExpect::KeyPair> +Rsa::KeyPair::import( + Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_PKCS8: + return importPkcs8(Encoded); + case __WASI_KEYPAIR_ENCODING_PEM: + return importPem(Encoded); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template +WasiCryptoExpect::KeyPair> +Rsa::KeyPair::importPem( + Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPrivateKey(Encoded)}); +} + +template +WasiCryptoExpect::KeyPair> +Rsa::KeyPair::importPkcs8( + Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPrivateKey(Encoded)}); +} + +template +WasiCryptoExpect +Rsa::KeyPair::checkValid(EvpPkeyPtr Ctx) noexcept { + ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); + RSA *RsaKey = EVP_PKEY_get0_RSA(Ctx.get()); + ensureOrReturn(RsaKey, __WASI_CRYPTO_ERRNO_INVALID_KEY); + ensureOrReturn(RSA_bits(RsaKey) == KeyBits, __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {std::move(Ctx)}; +} + +template +WasiCryptoExpect::SignState> +Rsa::KeyPair::openSignState() const noexcept { + EvpMdCtxPtr SignCtx{EVP_MD_CTX_create()}; + opensslCheck(EVP_DigestSignInit( + SignCtx.get(), nullptr, EVP_get_digestbynid(ShaNid), nullptr, Ctx.get())); + opensslCheck(EVP_PKEY_CTX_set_rsa_padding(EVP_MD_CTX_pkey_ctx(SignCtx.get()), + PadMode)); + return SignCtx; +} + +template +WasiCryptoExpect::KeyPair> +Rsa::KeyPair::generate( + OptionalRef) noexcept { + EvpPkeyCtxPtr Ctx{EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr)}; + EVP_PKEY_keygen_init(Ctx.get()); + EVP_PKEY_CTX_set_rsa_padding(Ctx.get(), PadMode); + EVP_PKEY_CTX_set_rsa_keygen_bits(Ctx.get(), KeyBits); + EVP_PKEY_CTX_set_signature_md(Ctx.get(), getShaCtx()); + + EVP_PKEY *Res = nullptr; + EVP_PKEY_keygen(Ctx.get(), &Res); + return EvpPkeyPtr{Res}; +} + +template +WasiCryptoExpect Rsa::KeyPair::exportData( + __wasi_keypair_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_KEYPAIR_ENCODING_PKCS8: + return exportPkcs8(); + case __WASI_KEYPAIR_ENCODING_PEM: + return exportPem(); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template +WasiCryptoExpect +Rsa::KeyPair::exportPem() const noexcept { + return pemWritePrivateKey(Ctx.get()); +} + +template +WasiCryptoExpect +Rsa::KeyPair::exportPkcs8() const noexcept { + return i2dPrivateKey(Ctx.get()); +} + +template +WasiCryptoExpect::PublicKey> +Rsa::KeyPair::publicKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +template +WasiCryptoExpect::SecretKey> +Rsa::KeyPair::secretKey() const noexcept { + // Since the inner is always `const`, we just increase the ref count. + return Ctx; +} + +template +WasiCryptoExpect::Signature> +Rsa::Signature::import( + Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept { + switch (Encoding) { + case __WASI_SIGNATURE_ENCODING_RAW: + ensureOrReturn(Encoded.size() == getSigSize(), + __WASI_CRYPTO_ERRNO_INVALID_SIGNATURE); + return std::vector(Encoded.begin(), Encoded.end()); + case __WASI_SIGNATURE_ENCODING_DER: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + default: + assumingUnreachable(); + } +} + +template +WasiCryptoExpect> +Rsa::Signature::exportData( + __wasi_signature_encoding_e_t Encoding) const noexcept { + switch (Encoding) { + case __WASI_SIGNATURE_ENCODING_RAW: + return Data; + case __WASI_SIGNATURE_ENCODING_DER: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + default: + assumingUnreachable(); + } +} + +template +WasiCryptoExpect Rsa::SignState::update( + Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_DigestSignUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect::Signature> +Rsa::SignState::sign() noexcept { + size_t Size = getSigSize(); + std::vector Res(Size); + + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_DigestSignFinal(Ctx->RawCtx.get(), Res.data(), &Size)); + ensureOrReturn(Size == getSigSize(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Res; +} + +template +WasiCryptoExpect Rsa::VerificationState::update( + Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_DigestVerifyUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect Rsa::VerificationState::verify( + const Signature &Sig) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + ensureOrReturn(EVP_DigestVerifyFinal(Ctx->RawCtx.get(), Sig.ref().data(), + Sig.ref().size()), + __WASI_CRYPTO_ERRNO_VERIFICATION_FAILED); + + return {}; +} + +template class Rsa; +template class Rsa; +template class Rsa; + +template class Rsa; +template class Rsa; + +template class Rsa; + +template class Rsa; +template class Rsa; +template class Rsa; + +template class Rsa; +template class Rsa; + +template class Rsa; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/rsa.h b/plugins/wasi_crypto/signatures/rsa.h new file mode 100644 index 00000000..980136e7 --- /dev/null +++ b/plugins/wasi_crypto/signatures/rsa.h @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/rsa.h - Rsa alg implement -===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of rsa and related classes. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "signatures/options.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +template class Rsa { +public: + class Signature { + public: + Signature(std::vector &&Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect> + exportData(__wasi_signature_encoding_e_t Encoding) const noexcept; + + const std::vector &ref() const { return Data; } + + private: + std::vector Data; + }; + + class SignState { + public: + SignState(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + WasiCryptoExpect update(Span Data) noexcept; + + WasiCryptoExpect sign() noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr RawCtx) noexcept : RawCtx(std::move(RawCtx)) {} + std::mutex Mutex; + EvpMdCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + + class VerificationState { + public: + VerificationState(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + WasiCryptoExpect update(Span Data) noexcept; + + WasiCryptoExpect verify(const Signature &Sig) noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr RawCtx) noexcept : RawCtx(std::move(RawCtx)) {} + std::mutex Mutex; + EvpMdCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + + class PublicKey { + public: + PublicKey(EvpPkeyPtr Ctx) noexcept : Ctx(std::move(Ctx)) {} + + PublicKey(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_publickey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect verify() const noexcept; + + WasiCryptoExpect> + exportData(__wasi_publickey_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect openVerificationState() const noexcept; + + private: + static WasiCryptoExpect + importPem(Span Encoded) noexcept; + + static WasiCryptoExpect + importPkcs8(Span Encoded) noexcept; + + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept; + + WasiCryptoExpect> exportPem() const noexcept; + + WasiCryptoExpect> exportPkcs8() const noexcept; + + SharedEvpPkey Ctx; + }; + + class KeyPair; + + class SecretKey { + public: + SecretKey(EvpPkeyPtr Ctx) : Ctx(std::move(Ctx)) {} + + SecretKey(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_secretkey_encoding_e_t Encoding) noexcept; + + WasiCryptoExpect + exportData(__wasi_secretkey_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect publicKey() const noexcept; + + WasiCryptoExpect toKeyPair(const PublicKey &Pk) const noexcept; + + private: + static WasiCryptoExpect + importPem(Span Encoded) noexcept; + + static WasiCryptoExpect + importPkcs8(Span Encoded) noexcept; + + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept; + + WasiCryptoExpect exportPem() const noexcept; + + WasiCryptoExpect exportPkcs8() const noexcept; + + SharedEvpPkey Ctx; + }; + + class KeyPair { + public: + KeyPair(EvpPkeyPtr Ctx) : Ctx(std::move(Ctx)) {} + + KeyPair(SharedEvpPkey Ctx) noexcept : Ctx(std::move(Ctx)) {} + + static WasiCryptoExpect + import(Span Encoded, + __wasi_keypair_encoding_e_t Encoding) noexcept; + + static WasiCryptoExpect + generate(OptionalRef OptOptions) noexcept; + + WasiCryptoExpect + exportData(__wasi_keypair_encoding_e_t Encoding) const noexcept; + + WasiCryptoExpect publicKey() const noexcept; + + WasiCryptoExpect secretKey() const noexcept; + + WasiCryptoExpect openSignState() const noexcept; + + private: + static WasiCryptoExpect + importPem(Span Encoded) noexcept; + + static WasiCryptoExpect + importPkcs8(Span Encoded) noexcept; + + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept; + + WasiCryptoExpect exportPem() const noexcept; + + WasiCryptoExpect exportPkcs8() const noexcept; + + SharedEvpPkey Ctx; + }; + +private: + static constexpr size_t getSigSize() { return KeyBits / 8; } + + static void *getShaCtx() { + return static_cast( + const_cast(EVP_get_digestbynid(ShaNid))); + } +}; + +using RSA_PKCS1_2048_SHA256 = Rsa; +using RSA_PKCS1_2048_SHA384 = Rsa; +using RSA_PKCS1_2048_SHA512 = Rsa; + +using RSA_PKCS1_3072_SHA384 = Rsa; +using RSA_PKCS1_3072_SHA512 = Rsa; + +using RSA_PKCS1_4096_SHA512 = Rsa; + +using RSA_PSS_2048_SHA256 = Rsa; +using RSA_PSS_2048_SHA384 = Rsa; +using RSA_PSS_2048_SHA512 = Rsa; + +using RSA_PSS_3072_SHA384 = Rsa; +using RSA_PSS_3072_SHA512 = Rsa; + +using RSA_PSS_4096_SHA512 = Rsa; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/signatures.cpp b/plugins/wasi_crypto/signatures/signatures.cpp new file mode 100644 index 00000000..29fae93a --- /dev/null +++ b/plugins/wasi_crypto/signatures/signatures.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "signatures/signatures.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +WasiCryptoExpect +sigImport(Algorithm Alg, Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept { + return std::visit( + [=](auto Factory) noexcept { + return decltype(Factory)::Signature::import(Encoded, Encoding) + .map([](auto &&Sig) noexcept { + return SigVariant{std::forward(Sig)}; + }); + }, + Alg); +} + +WasiCryptoExpect> +sigExportData(const SigVariant &SigVariant, + __wasi_signature_encoding_e_t Encoding) noexcept { + return std::visit( + [Encoding](auto &Sig) noexcept { return Sig.exportData(Encoding); }, + SigVariant); +} + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/signatures.h b/plugins/wasi_crypto/signatures/signatures.h new file mode 100644 index 00000000..44de8564 --- /dev/null +++ b/plugins/wasi_crypto/signatures/signatures.h @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/signatures.h - Signatures -===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the implementation of signatures. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "signatures/registed.h" +#include "utils/error.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +using SigVariant = RegistedAlg::SigVariant; + +WasiCryptoExpect +sigImport(Algorithm Alg, Span Encoded, + __wasi_signature_encoding_e_t Encoding) noexcept; + +WasiCryptoExpect> +sigExportData(const SigVariant &SigVariant, + __wasi_signature_encoding_e_t Encoding) noexcept; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/signstate.cpp b/plugins/wasi_crypto/signatures/signstate.cpp new file mode 100644 index 00000000..bf12ddbc --- /dev/null +++ b/plugins/wasi_crypto/signatures/signstate.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "signatures/signstate.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +WasiCryptoExpect +sigStateOpen(const KpVariant &KpVariant) noexcept { + return std::visit( + [](const auto &Kp) noexcept { + return Kp.openSignState().map([](auto &&SignState) noexcept { + return SignStateVariant{std::forward(SignState)}; + }); + }, + KpVariant); +} + +WasiCryptoExpect sigStateUpdate(SignStateVariant &SignStateVariant, + Span Input) noexcept { + return std::visit( + [Input](auto &SignState) noexcept { return SignState.update(Input); }, + SignStateVariant); +} + +WasiCryptoExpect +sigStateSign(SignStateVariant &SignStateVariant) noexcept { + return std::visit( + [](auto &SignState) noexcept { + return SignState.sign().map([](auto &&Sig) noexcept { + return SigVariant{std::forward(Sig)}; + }); + }, + SignStateVariant); +} + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/signstate.h b/plugins/wasi_crypto/signatures/signstate.h new file mode 100644 index 00000000..be92cac0 --- /dev/null +++ b/plugins/wasi_crypto/signatures/signstate.h @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/signstate.h - SignState ---===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the implementation of sign state. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "signatures/registed.h" +#include "signatures/signatures.h" +#include "utils/error.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +/// Signatures computation. +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#signature-creation +using SignStateVariant = RegistedAlg::SignStateVariant; +using KpVariant = RegistedAlg::KpVariant; + +WasiCryptoExpect +sigStateOpen(const KpVariant &PkVariant) noexcept; + +WasiCryptoExpect sigStateUpdate(SignStateVariant &SignStateVariant, + Span Input) noexcept; + +WasiCryptoExpect +sigStateSign(SignStateVariant &SignStateVariant) noexcept; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/verificationstate.cpp b/plugins/wasi_crypto/signatures/verificationstate.cpp new file mode 100644 index 00000000..547bbdb8 --- /dev/null +++ b/plugins/wasi_crypto/signatures/verificationstate.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "signatures/verificationstate.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +WasiCryptoExpect +verificationStateOpen(const PkVariant &PkVariant) noexcept { + return std::visit( + [](const auto &Pk) noexcept { + return Pk.openVerificationState().map( + [](auto &&VerificationState) noexcept { + return VerificationStateVariant{ + std::forward(VerificationState)}; + }); + }, + PkVariant); +} + +WasiCryptoExpect +verificationStateUpdate(VerificationStateVariant &VerificationStateVariant, + Span Input) noexcept { + return std::visit( + [Input](auto &VerificationState) noexcept { + return VerificationState.update(Input); + }, + VerificationStateVariant); +} + +namespace { +/// Correspond signatures: +/// WasiCryptoExpect VerificationStateType::verify(const SigType&); +/// is used to get `SigType`. +template struct VerifyTrait; +template +struct VerifyTrait (VerificationStateType::*)( + const SigType &) noexcept> { + using Sig = SigType; +}; + +template +using SigType = typename VerifyTrait::Sig; +} // namespace + +WasiCryptoExpect +verificationStateVerify(VerificationStateVariant &VerificationStateVariant, + const SigVariant &SigVariant) noexcept { + return std::visit( + [](auto &VerificationState, + const auto &Sig) noexcept -> WasiCryptoExpect { + using RequiredSigType = + SigType>; + using InSigType = std::decay_t; + + if constexpr (std::is_same_v) { + return VerificationState.verify(Sig); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_SIGNATURE); + } + }, + VerificationStateVariant, SigVariant); +} + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/verificationstate.h b/plugins/wasi_crypto/signatures/verificationstate.h new file mode 100644 index 00000000..565c1ffe --- /dev/null +++ b/plugins/wasi_crypto/signatures/verificationstate.h @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/verificationstate.h -------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the verification state related functions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "signatures/registed.h" +#include "signatures/signatures.h" +#include "utils/error.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Signatures { + +/// Signatures verify +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#signature-verification +using VerificationStateVariant = RegistedAlg::VerificationStateVariant; +using PkVariant = RegistedAlg::PkVariant; + +WasiCryptoExpect +verificationStateOpen(const PkVariant &PkVariant) noexcept; + +WasiCryptoExpect +verificationStateUpdate(VerificationStateVariant &VerificationStateVariant, + Span Input) noexcept; + +WasiCryptoExpect +verificationStateVerify(VerificationStateVariant &VerificationStateVariant, + const SigVariant &SigVariant) noexcept; + +} // namespace Signatures +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/aeads.cpp b/plugins/wasi_crypto/symmetric/aeads.cpp new file mode 100644 index 00000000..99a09d06 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/aeads.cpp @@ -0,0 +1,231 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "symmetric/aeads.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" + +#include +#include + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { +using namespace std::literals; + +template +constexpr size_t Cipher::getKeySize() noexcept { + static_assert(CipherNid == NID_aes_128_gcm || CipherNid == NID_aes_256_gcm || + CipherNid == NID_chacha20_poly1305); + if constexpr (CipherNid == NID_aes_128_gcm) { + return 16; + } + if constexpr (CipherNid == NID_aes_256_gcm) { + return 32; + } + if constexpr (CipherNid == NID_chacha20_poly1305) { + return 32; + } +} + +template +WasiCryptoExpect Cipher::State::maxTagLen() const noexcept { + return getTagSize(); +} + +template +constexpr size_t Cipher::getTagSize() noexcept { + return 16; +} + +template +WasiCryptoExpect::Key> +Cipher::Key::generate(OptionalRef) noexcept { + return SecretVec::random(); +} + +template +WasiCryptoExpect::Key> +Cipher::Key::import(Span Raw) noexcept { + return SecretVec{Raw}; +} + +template +WasiCryptoExpect::State> +Cipher::State::open(const Key &Key, + OptionalRef OptOption) noexcept { + ensureOrReturn(OptOption, __WASI_CRYPTO_ERRNO_NONCE_REQUIRED); + + std::array Nonce; + if (auto Res = OptOption->get("nonce"sv, Nonce); !Res) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_NONCE); + } else { + ensureOrReturn(*Res == NonceSize, __WASI_CRYPTO_ERRNO_INVALID_NONCE); + } + + ensureOrReturn(getKeySize() == Key.ref().size(), + __WASI_CRYPTO_ERRNO_INVALID_HANDLE); + + EvpCipherCtxPtr Ctx{EVP_CIPHER_CTX_new()}; + opensslCheck(EVP_CipherInit_ex(Ctx.get(), EVP_get_cipherbynid(CipherNid), + nullptr, Key.ref().data(), Nonce.data(), + Mode::Unchanged)); + + return State{std::move(Ctx), Nonce}; +} + +template +WasiCryptoExpect +Cipher::State::optionsGet(std::string_view Name, + Span Value) const noexcept { + ensureOrReturn(Name == "nonce"sv, __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + ensureOrReturn(NonceSize <= Value.size(), __WASI_CRYPTO_ERRNO_OVERFLOW); + + std::copy(Ctx->Nonce.begin(), Ctx->Nonce.end(), Value.begin()); + return NonceSize; +} + +// https://wiki.openssl.org/index.php/EVP_Authenticated_Encryption_and_Decryption +template +WasiCryptoExpect +Cipher::State::absorb(Span Data) noexcept { + ensureOrReturn(Data.size() <= + static_cast(std::numeric_limits::max()), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + int DataSize = static_cast(Data.size()); + + int ActualAbsorbSize; + { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_CipherUpdate(Ctx->RawCtx.get(), nullptr, &ActualAbsorbSize, + Data.data(), DataSize)); + } + ensureOrReturn(ActualAbsorbSize == DataSize, + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return {}; +} + +template +WasiCryptoExpect +Cipher::State::encrypt(Span Out, + Span Data) noexcept { + return encryptImpl(Out.first(Data.size()), Out.last(getTagSize()), Data); +} + +template +WasiCryptoExpect +Cipher::State::encryptDetached(Span Out, + Span Data) noexcept { + SecretVec Tag(getTagSize()); + if (auto Res = encryptImpl(Out, Tag, Data); !Res) { + return WasiCryptoUnexpect(Res); + } + return Tag; +} + +template +WasiCryptoExpect +Cipher::State::encryptImpl(Span Out, Span Tag, + Span Data) noexcept { + ensureOrReturn(Data.size() <= + static_cast(std::numeric_limits::max()), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + int DataSize = static_cast(Data.size()); + + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_CipherInit_ex(Ctx->RawCtx.get(), nullptr, nullptr, nullptr, + nullptr, Mode::Encrypt)); + + int ActualUpdateSize; + opensslCheck(EVP_CipherUpdate(Ctx->RawCtx.get(), Out.data(), + &ActualUpdateSize, Data.data(), DataSize)); + + int ActualFinalSize; + ensureOrReturn( + EVP_CipherFinal_ex(Ctx->RawCtx.get(), nullptr, &ActualFinalSize), + __WASI_CRYPTO_ERRNO_INTERNAL_ERROR); + + ensureOrReturn(static_cast(ActualUpdateSize + ActualFinalSize) == + Out.size(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + opensslCheck(EVP_CIPHER_CTX_ctrl(Ctx->RawCtx.get(), EVP_CTRL_AEAD_GET_TAG, + static_cast(getTagSize()), Tag.data())); + + return Out.size() + getTagSize(); +} + +template +WasiCryptoExpect +Cipher::State::decrypt(Span Out, + Span Data) noexcept { + return decryptImpl(Out, Data.first(Out.size()), Data.last(getTagSize())); +} + +template +WasiCryptoExpect +Cipher::State::decryptDetached(Span Out, + Span Data, + Span RawTag) noexcept { + return decryptImpl(Out, Data, RawTag); +} + +template +WasiCryptoExpect +Cipher::State::decryptImpl(Span Out, + Span Data, + Span RawTag) noexcept { + ensureOrReturn(Data.size() <= + static_cast(std::numeric_limits::max()), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + int DataSize = static_cast(Data.size()); + + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_CipherInit_ex(Ctx->RawCtx.get(), nullptr, nullptr, nullptr, + nullptr, Mode::Decrypt)); + int ActualUpdateSize; + opensslCheck(EVP_CipherUpdate(Ctx->RawCtx.get(), Out.data(), + &ActualUpdateSize, Data.data(), DataSize)); + + opensslCheck(EVP_CIPHER_CTX_ctrl(Ctx->RawCtx.get(), EVP_CTRL_AEAD_SET_TAG, + static_cast(getTagSize()), + const_cast(RawTag.data()))); + + int ActualFinalSize; + if (!EVP_CipherFinal_ex(Ctx->RawCtx.get(), nullptr, &ActualFinalSize)) { + OPENSSL_cleanse(Out.data(), Out.size()); + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_TAG); + } + ensureOrReturn(ActualFinalSize + ActualUpdateSize == DataSize, + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Out.size(); +} + +template +WasiCryptoExpect::State> +Cipher::State::clone() const noexcept { + EvpCipherCtxPtr CloneCtx{EVP_CIPHER_CTX_new()}; + + { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_CIPHER_CTX_copy(CloneCtx.get(), Ctx->RawCtx.get())); + } + + return State{std::move(CloneCtx), Ctx->Nonce}; +} + +template class Cipher; +template class Cipher; +template class Cipher; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/aeads.h b/plugins/wasi_crypto/symmetric/aeads.h new file mode 100644 index 00000000..96d6b245 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/aeads.h @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/aeads.h - Aeads related ----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of Aeads and related classes. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/options.h" +#include "symmetric/tag.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Aeads invalid operations, every Aeads state should inherent from this class. +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#aeads +template class AEADsState { +public: + WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + WasiCryptoExpect squeeze(Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect ratchet() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeKey() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeTag() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } +}; + +template class Cipher { + static inline constexpr size_t NonceSize = 12; + +public: + class Key { + public: + Key(SecretVec Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect import(Span Data) noexcept; + + static WasiCryptoExpect + generate(OptionalRef Options) noexcept; + + SecretVec exportData() const noexcept { return Data; } + + const SecretVec &ref() const noexcept { return Data; } + + private: + SecretVec Data; + }; + + class State : public AEADsState { + public: + /// There are four inputs for authenticated encryption: + /// @param[in] Key The secret key for encrypt + /// @param[in] OptOption `Must` contain an Nonce (Initialization vector). + static WasiCryptoExpect + open(const Key &Key, OptionalRef OptOption) noexcept; + + State(EvpCipherCtxPtr Ctx, std::array Nonce) noexcept + : Ctx(std::make_shared(std::move(Ctx), Nonce)) {} + + WasiCryptoExpect optionsGet(std::string_view Name, + Span Value) const noexcept; + + /// Absorbs additional data. Multiple calls to absorb() MUST be equivalent + /// to a single call with a concatenation of the inputs. + /// + /// @param Data Additional data + /// @return Nothing or WasiCrypto error + WasiCryptoExpect absorb(Span Data) noexcept; + + /// Return the length required to encode the authentication tag and the + /// optional padding bytes. The returned length MUST be constant for a given + /// algorithm. Guest applications are expected to provide an output buffer + /// whose size is the size of the message, plus the max_tag_len() output + /// value. + /// + /// @return the length required to encode the authentication tag + /// and optional padding bytes. + WasiCryptoExpect maxTagLen() const noexcept; + + /// Check Out.size() == Data.size() + maxTagLen(), then call + /// encryptUnchecked(Out, Data) or return error if not equal. + /// + /// @param Out The encrypted data text + /// @param Data The data to be encrypted + /// @return Tag's size or + /// `__WASI_CRYPTO_ERRNO_OVERFLOW`/`__WASI_CRYPTO_ERRNO_INVALID_LENGTH` + WasiCryptoExpect encrypt(Span Out, + Span Data) noexcept; + + /// Check Out.size() == Data.size(), then call + /// encryptDetachedUnchecked(Out, Data) or error if not equal + /// + /// @param Out The encrypted data text + /// @param Data The data to be encrypted + /// @return Tag + /// or `__WASI_CRYPTO_ERRNO_OVERFLOW`/`__WASI_CRYPTO_ERRNO_INVALID_LENGTH` + WasiCryptoExpect encryptDetached(Span Out, + Span Data) noexcept; + + /// Check Out.size() = Data.size() + maxTagLen(), then call + /// decryptDetachedUnchecked(Out, Data) or error if not equal + /// + /// @param Out The decrypted data text + /// @param Data The data to be decrypted + /// @return Size or + /// `__WASI_CRYPTO_ERRNO_OVERFLOW`/`__WASI_CRYPTO_ERRNO_INVALID_LENGTH` + WasiCryptoExpect decrypt(Span Out, + Span Data) noexcept; + + /// Check Out.size() == Data.size(), then call + /// encryptDetachedUnchecked(Out, Data) or error if not equal + /// + /// @param Out The decrypted data text + /// @param Data The data to be decrypted + /// @return Size or + /// `__WASI_CRYPTO_ERRNO_OVERFLOW`/`__WASI_CRYPTO_ERRNO_INVALID_LENGTH` + WasiCryptoExpect + decryptDetached(Span Out, Span Data, + Span RawTag) noexcept; + + WasiCryptoExpect clone() const noexcept; + + private: + WasiCryptoExpect encryptImpl(Span Out, Span Tag, + Span Data) noexcept; + + WasiCryptoExpect decryptImpl(Span Out, + Span Data, + Span RawTag) noexcept; + struct Inner { + Inner(EvpCipherCtxPtr RawCtx, + std::array Nonce) noexcept + : RawCtx(std::move(RawCtx)), Nonce(Nonce) {} + EvpCipherCtxPtr RawCtx; + const std::array Nonce; + std::mutex Mutex; + }; + std::shared_ptr Ctx; + }; + +private: + enum Mode { Unchanged = -1, Decrypt = 0, Encrypt = 1 }; + + constexpr static size_t getKeySize() noexcept; + + constexpr static size_t getTagSize() noexcept; +}; + +using Aes128Gcm = Cipher; +using Aes256Gcm = Cipher; +using ChaCha20Poly1305 = Cipher; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/ctx.cpp b/plugins/wasi_crypto/symmetric/ctx.cpp new file mode 100644 index 00000000..4afa2351 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/ctx.cpp @@ -0,0 +1,312 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "ctx.h" +#include "symmetric/key.h" +#include "symmetric/state.h" +#include "symmetric/tag.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +WasiCryptoExpect +Context::symmetricTagLen(__wasi_symmetric_tag_t TagHandle) noexcept { + return SymmetricTagManager.get(TagHandle).map(&Symmetric::Tag::len); +} + +WasiCryptoExpect +Context::symmetricTagPull(__wasi_symmetric_tag_t TagHandle, + Span Buf) noexcept { + return SymmetricTagManager.get(TagHandle).and_then( + [Buf](const Symmetric::Tag &Tag) noexcept { return Tag.pull(Buf); }); +} + +WasiCryptoExpect +Context::symmetricTagVerify(__wasi_symmetric_tag_t TagHandle, + Span RawTag) noexcept { + return SymmetricTagManager.get(TagHandle).and_then( + [RawTag](const Symmetric::Tag &Tag) noexcept { + return Tag.verify(RawTag); + }); +} + +WasiCryptoExpect +Context::symmetricTagClose(__wasi_symmetric_tag_t TagHandle) noexcept { + return SymmetricTagManager.close(TagHandle); +} + +WasiCryptoExpect<__wasi_array_output_t> +Context::symmetricKeyExport(__wasi_symmetric_key_t KeyHandle) noexcept { + return SymmetricKeyManager.get(KeyHandle) + .map(Symmetric::keyExportData) + .and_then([this](auto &&Data) noexcept { + return ArrayOutputManager.registerManager( + std::forward(Data)); + }); +} + +WasiCryptoExpect +Context::symmetricKeyClose(__wasi_symmetric_key_t SymmetricKey) noexcept { + return SymmetricKeyManager.close(SymmetricKey); +} + +WasiCryptoExpect +Context::symmetricStateOptionsGet(__wasi_symmetric_state_t StateHandle, + std::string_view Name, + Span Value) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([=](auto &&State) noexcept { + return Symmetric::stateOptionsGet(std::forward(State), + Name, Value); + }); +} + +WasiCryptoExpect +Context::symmetricStateOptionsGetU64(__wasi_symmetric_state_t StateHandle, + std::string_view Name) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([Name](auto &&State) noexcept { + return Symmetric::stateOptionsGetU64( + std::forward(State), Name); + }); +} + +WasiCryptoExpect +Context::symmetricStateClose(__wasi_symmetric_state_t StateHandle) noexcept { + return SymmetricStateManager.close(StateHandle); +} + +WasiCryptoExpect +Context::symmetricStateAbsorb(__wasi_symmetric_state_t StateHandle, + Span Data) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([Data](auto &&State) noexcept { + return Symmetric::stateAbsorb(State, Data); + }); +} + +WasiCryptoExpect +Context::symmetricStateSqueeze(__wasi_symmetric_state_t StateHandle, + Span Out) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([Out](auto &&State) noexcept { + return Symmetric::stateSqueeze(State, Out); + }); +} + +WasiCryptoExpect<__wasi_symmetric_tag_t> Context::symmetricStateSqueezeTag( + __wasi_symmetric_state_t StateHandle) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([](auto &&State) noexcept { + return Symmetric::stateSqueezeTag(State); + }) + .and_then([this](auto &&Tag) { + return SymmetricTagManager.registerManager( + std::forward(Tag)); + }); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +Context::symmetricStateSqueezeKey(__wasi_symmetric_state_t StateHandle, + Symmetric::Algorithm Alg) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([Alg](auto &&State) noexcept { + return Symmetric::stateSqueezeKey(State, Alg); + }) + .and_then([this](auto &&Key) { + return SymmetricKeyManager.registerManager( + std::forward(Key)); + }); +} + +WasiCryptoExpect Context::symmetricStateMaxTagLen( + __wasi_symmetric_state_t StateHandle) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then(&Symmetric::stateMaxTagLen); +} + +WasiCryptoExpect +Context::symmetricStateEncrypt(__wasi_symmetric_state_t StateHandle, + Span Out, + Span Data) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([=](auto &&State) noexcept { + return Symmetric::stateEncrypt(State, Out, Data); + }); +} + +WasiCryptoExpect<__wasi_symmetric_tag_t> +Context::symmetricStateEncryptDetached(__wasi_symmetric_state_t StateHandle, + Span Out, + Span Data) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([=](auto &&State) noexcept { + return Symmetric::stateEncryptDetached(State, Out, Data); + }) + .and_then([this](auto &&Tag) noexcept { + return SymmetricTagManager.registerManager( + std::forward(Tag)); + }); +} + +WasiCryptoExpect +Context::symmetricStateDecrypt(__wasi_symmetric_state_t StateHandle, + Span Out, + Span Data) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([=](auto &&State) noexcept { + return Symmetric::stateDecrypt(State, Out, Data); + }); +} + +WasiCryptoExpect Context::symmetricStateDecryptDetached( + __wasi_symmetric_state_t StateHandle, Span Out, + Span Data, Span RawTag) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then([=](auto &&State) noexcept { + return Symmetric::stateDecryptDetached(State, Out, Data, RawTag); + }); +} + +WasiCryptoExpect +Context::symmetricStateRatchet(__wasi_symmetric_state_t StateHandle) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then( + [](auto &&State) noexcept { return Symmetric::stateRatchet(State); }); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +Context::symmetricKeyImport(Symmetric::Algorithm Alg, + Span Raw) noexcept { + return Symmetric::importKey(Alg, Raw).and_then([this](auto &&Key) noexcept { + return SymmetricKeyManager.registerManager( + std::forward(Key)); + }); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +Context::symmetricKeyGenerate(Symmetric::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept { + auto OptOptionsResult = mapAndTransposeOptional( + OptOptionsHandle, [this](__wasi_options_t OptionsHandle) noexcept { + return OptionsManager.get(OptionsHandle); + }); + if (!OptOptionsResult) { + return WasiCryptoUnexpect(OptOptionsResult); + } + + // Refer to OptOptionsResult if it's a Symmetric::Options. + auto OptSymmetricOptionsResult = transposeOptionalToRef( + *OptOptionsResult, + [](const auto &Options) noexcept + -> WasiCryptoExpect> { + auto *SymmetricOptions = std::get_if(&Options); + if (!SymmetricOptions) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + return SymmetricOptions; + }); + if (!OptSymmetricOptionsResult) { + return WasiCryptoUnexpect(OptSymmetricOptionsResult); + } + + return Symmetric::generateKey(Alg, *OptSymmetricOptionsResult) + .and_then([this](auto &&Key) noexcept { + return SymmetricKeyManager.registerManager( + std::forward(Key)); + }); +} + +WasiCryptoExpect<__wasi_symmetric_state_t> +Context::symmetricStateOpen(Symmetric::Algorithm Alg, + __wasi_opt_symmetric_key_t OptKeyHandle, + __wasi_opt_options_t OptOptionsHandle) noexcept { + // Copy from KeyManager. + auto OptKeyResult = + mapAndTransposeOptional(OptKeyHandle, + [this](__wasi_symmetric_key_t KeyHandle) noexcept + -> WasiCryptoExpect { + return SymmetricKeyManager.get(KeyHandle); + }); + if (!OptKeyResult) { + return WasiCryptoUnexpect(OptKeyResult); + } + + // Copy from OptionsManager. + auto OptOptionsResult = mapAndTransposeOptional( + OptOptionsHandle, [this](__wasi_options_t OptionsHandle) noexcept { + return OptionsManager.get(OptionsHandle); + }); + if (!OptOptionsResult) { + return WasiCryptoUnexpect(OptOptionsResult); + } + + // Refer to OptOptionsResult if it's a Smmetric::Options. + auto OptSymmetricOptionsResult = transposeOptionalToRef( + *OptOptionsResult, + [](const auto &Options) noexcept + -> WasiCryptoExpect> { + auto *SymmetricOptions = std::get_if(&Options); + if (!SymmetricOptions) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + return SymmetricOptions; + }); + if (!OptSymmetricOptionsResult) { + return WasiCryptoUnexpect(OptSymmetricOptionsResult); + } + + return Symmetric::openState(Alg, asOptionalRef(*OptKeyResult), + *OptSymmetricOptionsResult) + .and_then([this](auto &&State) noexcept { + return SymmetricStateManager.registerManager( + std::forward(State)); + }); +} + +WasiCryptoExpect<__wasi_symmetric_state_t> +Context::symmetricStateClone(__wasi_symmetric_state_t StateHandle) noexcept { + return SymmetricStateManager.get(StateHandle) + .and_then(&Symmetric::stateClone) + .and_then([this](auto &&State) noexcept { + return SymmetricStateManager.registerManager( + std::forward(State)); + }); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +Context::symmetricKeyGenerateManaged(__wasi_secrets_manager_t, + Symmetric::Algorithm, + __wasi_opt_options_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect Context::symmetricKeyStoreManaged( + __wasi_secrets_manager_t, __wasi_symmetric_key_t, Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect<__wasi_version_t> +Context::symmetricKeyReplaceManaged(__wasi_secrets_manager_t, + __wasi_symmetric_key_t, + __wasi_symmetric_key_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect> +Context::symmetricKeyId(__wasi_symmetric_key_t, Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +Context::symmetricKeyFromId(__wasi_secrets_manager_t, Span, + __wasi_version_t) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/func.cpp b/plugins/wasi_crypto/symmetric/func.cpp new file mode 100644 index 00000000..2d4864e2 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/func.cpp @@ -0,0 +1,674 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "symmetric/func.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +Expect KeyGenerate::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptOptionsPtr, + uint32_t /* Out */ KeyHandlePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); + checkExist(Alg); + Algorithm WasiAlg; + if (auto Res = tryFrom({Alg, WasiAlgLen}); !Res) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const OptOptions = + MemInst->getPointer(OptOptionsPtr); + checkExist(OptOptions); + + auto *const KeyHandle = + MemInst->getPointer<__wasi_symmetric_key_t *>(KeyHandlePtr); + checkExist(KeyHandle); + + if (auto Res = Ctx.symmetricKeyGenerate(WasiAlg, *OptOptions); + unlikely(!Res)) { + return Res.error(); + } else { + *KeyHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyImport::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t RawPtr, uint32_t RawLen, + uint32_t /* Out */ KeyPtr) { + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); + checkExist(Alg); + + Algorithm WasiAlg; + if (auto Res = tryFrom({Alg, WasiAlgLen}); !Res) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const Key = MemInst->getPointer<__wasi_symmetric_key_t *>(KeyPtr); + checkExist(Key); + + const __wasi_size_t WasiRawLen = RawLen; + auto *Raw = MemInst->getPointer(RawPtr, WasiRawLen); + checkExist(Raw); + + if (auto Res = Ctx.symmetricKeyImport(WasiAlg, {Raw, WasiRawLen}); + unlikely(!Res)) { + return Res.error(); + } else { + *Key = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyExport::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KeyHandle, + uint32_t /* Out */ ArrayOutputHandlePtr) { + checkExist(MemInst); + + auto *const ArrayOutputHandle = + MemInst->getPointer<__wasi_array_output_t *>(ArrayOutputHandlePtr); + checkExist(ArrayOutputHandle); + + if (auto Res = Ctx.symmetricKeyExport(KeyHandle); unlikely(!Res)) { + return Res.error(); + } else { + *ArrayOutputHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyClose::body(Runtime::Instance::MemoryInstance *, + int32_t KeyHandle) { + if (auto Res = Ctx.symmetricKeyClose(KeyHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +KeyGenerateManaged::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t OptOptionsPtr, + uint32_t /* Out */ KeyHandlePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); + checkExist(Alg); + Algorithm WasiAlg; + if (auto Res = tryFrom({Alg, WasiAlgLen}); !Res) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const OptOptions = + MemInst->getPointer(OptOptionsPtr); + checkExist(OptOptions); + + auto *const KeyHandle = + MemInst->getPointer<__wasi_symmetric_key_t *>(KeyHandlePtr); + checkExist(KeyHandle); + + if (auto Res = Ctx.symmetricKeyGenerateManaged(SecretsManagerHandle, WasiAlg, + *OptOptions); + unlikely(!Res)) { + return Res.error(); + } else { + *KeyHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +KeyStoreManaged::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, int32_t KeyHandle, + uint32_t KeyIdPtr, uint32_t KeyIdMaxLen) { + checkExist(MemInst); + + const __wasi_size_t WasiKeyIdMaxLen = KeyIdMaxLen; + auto *const KeyId = MemInst->getPointer(KeyIdPtr, WasiKeyIdMaxLen); + checkExist(KeyId); + + if (auto Res = Ctx.symmetricKeyStoreManaged(SecretsManagerHandle, KeyHandle, + {KeyId, WasiKeyIdMaxLen}); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +KeyReplaceManaged::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, int32_t OldKeyHandle, + int32_t NewKeyHandle, + uint32_t /* Out */ KeyVersionPtr) { + checkExist(MemInst); + + auto *const KeyVersion = + MemInst->getPointer<__wasi_version_t *>(KeyVersionPtr); + checkExist(KeyVersion); + + if (auto Res = Ctx.symmetricKeyReplaceManaged(SecretsManagerHandle, + OldKeyHandle, NewKeyHandle); + unlikely(!Res)) { + return Res.error(); + } else { + *KeyVersion = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyId::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KeyHandle, uint32_t KeyIdPtr, + uint32_t KeyIdMaxLen, uint32_t /* Out */ SizePtr, + uint32_t /* Out */ KeyVersionPtr) { + checkExist(MemInst); + + const __wasi_size_t WasiKeyIdMaxLen = KeyIdMaxLen; + auto *const KeyId = MemInst->getPointer(KeyIdPtr, WasiKeyIdMaxLen); + checkExist(KeyId); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + auto *const KeyVersion = + MemInst->getPointer<__wasi_version_t *>(KeyVersionPtr); + if (unlikely(KeyVersion == nullptr)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + if (auto Res = Ctx.symmetricKeyId(KeyHandle, {KeyId, WasiKeyIdMaxLen}); + unlikely(!Res)) { + return Res.error(); + } else { + auto [SizeRes, VersionRes] = *Res; + auto SafeSizeRes = toWasiSize(SizeRes); + if (unlikely(!SafeSizeRes)) { + return SafeSizeRes.error(); + } + + *KeyVersion = VersionRes; + *Size = *SafeSizeRes; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect KeyFromId::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, + uint32_t KeyIdPtr, uint32_t KeyIdLen, + uint64_t KeyVersion, + uint32_t /* Out */ KeyHandlePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiKeyIdLen = KeyIdLen; + auto *const KeyId = MemInst->getPointer(KeyIdPtr, WasiKeyIdLen); + checkExist(KeyId); + + auto *const KeyHandle = + MemInst->getPointer<__wasi_symmetric_key_t *>(KeyHandlePtr); + checkExist(KeyHandle); + + if (auto Res = Ctx.symmetricKeyFromId(SecretsManagerHandle, + {KeyId, WasiKeyIdLen}, KeyVersion); + unlikely(!Res)) { + return Res.error(); + } else { + *KeyHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateOpen::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptKeyHandlePtr, + uint32_t OptOptionsPtr, + uint32_t /* Out */ StatePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); + checkExist(Alg); + Algorithm WasiAlg; + if (auto Res = tryFrom({Alg, WasiAlgLen}); !Res) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const OptKeyHandle = + MemInst->getPointer(OptKeyHandlePtr); + checkExist(OptKeyHandle); + + auto *const OptOptions = + MemInst->getPointer(OptOptionsPtr); + checkExist(OptOptions); + + auto *const State = MemInst->getPointer<__wasi_symmetric_state_t *>(StatePtr); + if (unlikely(State == nullptr)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + if (auto Res = Ctx.symmetricStateOpen(WasiAlg, *OptKeyHandle, *OptOptions); + unlikely(!Res)) { + return Res.error(); + } else { + *State = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateClone::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, + uint32_t /* Out */ StatePtr) { + checkExist(MemInst); + + auto *const State = MemInst->getPointer<__wasi_symmetric_state_t *>(StatePtr); + if (unlikely(State == nullptr)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + if (auto Res = Ctx.symmetricStateClone(StateHandle); unlikely(!Res)) { + return Res.error(); + } else { + *State = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +StateOptionsGet::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t NamePtr, uint32_t NameLen, + uint32_t ValuePtr, uint32_t ValueLen, + uint32_t /* Out */ SizePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiNameLen = NameLen; + auto *const Name = MemInst->getPointer(NamePtr, WasiNameLen); + checkExist(Name); + + const __wasi_size_t WasiValueLen = ValueLen; + auto *const Value = MemInst->getPointer(ValuePtr, ValueLen); + checkExist(Value); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = Ctx.symmetricStateOptionsGet(StateHandle, {Name, WasiNameLen}, + {Value, WasiValueLen}) + .and_then(toWasiSize); + !Res) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +StateOptionsGetU64::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t NamePtr, + uint32_t NameLen, uint32_t /* Out */ U64Ptr) { + checkExist(MemInst); + + const __wasi_size_t WasiNameLen = NameLen; + auto *const Name = MemInst->getPointer(NamePtr, WasiNameLen); + checkExist(Name); + + auto *const U64 = MemInst->getPointer(U64Ptr); + checkExist(U64); + + if (auto Res = + Ctx.symmetricStateOptionsGetU64(StateHandle, {Name, WasiNameLen}); + unlikely(!Res)) { + return Res.error(); + } else { + *U64 = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateClose::body(Runtime::Instance::MemoryInstance *, + int32_t StateHandle) { + if (auto Res = Ctx.symmetricStateClose(StateHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateAbsorb::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t DataPtr, + uint32_t DataLen) { + checkExist(MemInst); + + const __wasi_size_t WasiDataLen = DataLen; + auto *const Data = MemInst->getPointer(DataPtr, WasiDataLen); + checkExist(Data); + + if (auto Res = Ctx.symmetricStateAbsorb(StateHandle, {Data, WasiDataLen}); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateSqueeze::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t OutPtr, + uint32_t OutLen) { + checkExist(MemInst); + + const __wasi_size_t WasiOutLen = OutLen; + auto *const Out = MemInst->getPointer(OutPtr, WasiOutLen); + checkExist(Out); + + if (auto Res = Ctx.symmetricStateSqueeze(StateHandle, {Out, WasiOutLen}); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +StateSqueezeTag::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t /* Out */ TagHandlePtr) { + checkExist(MemInst); + + auto *const TagHandle = + MemInst->getPointer<__wasi_symmetric_tag_t *>(TagHandlePtr); + checkExist(TagHandle); + + if (auto Res = Ctx.symmetricStateSqueezeTag(StateHandle); unlikely(!Res)) { + return Res.error(); + } else { + *TagHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +StateSqueezeKey::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t AlgPtr, uint32_t AlgLen, + uint32_t /* Out */ KeyHandlePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiAlgLen = AlgLen; + auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); + checkExist(Alg); + Algorithm WasiAlg; + if (auto Res = tryFrom({Alg, WasiAlgLen}); !Res) { + return Res.error(); + } else { + WasiAlg = *Res; + } + + auto *const KeyHandle = + MemInst->getPointer<__wasi_symmetric_key_t *>(KeyHandlePtr); + checkExist(KeyHandle); + + if (auto Res = Ctx.symmetricStateSqueezeKey(StateHandle, WasiAlg); + unlikely(!Res)) { + return Res.error(); + } else { + *KeyHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +StateMaxTagLen::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t /* Out */ SizePtr) { + checkExist(MemInst); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = Ctx.symmetricStateMaxTagLen(StateHandle).and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateEncrypt::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t OutPtr, + uint32_t OutLen, uint32_t DataPtr, + uint32_t DataLen, + uint32_t /* Out */ SizePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiOutLen = OutLen; + auto *const Out = MemInst->getPointer(OutPtr, WasiOutLen); + checkExist(Out); + + const __wasi_size_t WasiDataLen = DataLen; + auto *Data = MemInst->getPointer(DataPtr, WasiDataLen); + checkExist(Data); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = Ctx.symmetricStateEncrypt(StateHandle, {Out, WasiOutLen}, + {Data, WasiDataLen}) + .and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect +StateEncryptDetached::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t OutPtr, + uint32_t OutLen, uint32_t DataPtr, uint32_t DataLen, + uint32_t /* Out */ TagHandlePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiOutLen = OutLen; + auto *const Out = MemInst->getPointer(OutPtr, WasiOutLen); + checkExist(Out); + + const __wasi_size_t WasiDataLen = DataLen; + auto *const Data = MemInst->getPointer(DataPtr, WasiDataLen); + checkExist(Data); + + auto *const TagHandle = + MemInst->getPointer<__wasi_symmetric_tag_t *>(TagHandlePtr); + checkExist(TagHandle); + + if (auto Res = Ctx.symmetricStateEncryptDetached( + StateHandle, {Out, WasiOutLen}, {Data, WasiDataLen}); + unlikely(!Res)) { + return Res.error(); + } else { + *TagHandle = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateDecrypt::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t OutPtr, + uint32_t OutLen, uint32_t DataPtr, + uint32_t DataLen, + uint32_t /* Out */ SizePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiOutLen = OutLen; + auto *const Out = MemInst->getPointer(OutPtr, WasiOutLen); + checkExist(Out); + + const __wasi_size_t WasiDataLen = DataLen; + auto *const Data = MemInst->getPointer(DataPtr, WasiDataLen); + checkExist(Data); + + auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + if (unlikely(Size == nullptr)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + if (auto Res = Ctx.symmetricStateDecrypt(StateHandle, {Out, WasiOutLen}, + {Data, WasiDataLen}) + .and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateDecryptDetached::body( + Runtime::Instance::MemoryInstance *MemInst, int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen, uint32_t DataPtr, uint32_t DataLen, + uint32_t RawTagPtr, uint32_t RawTagLen, uint32_t /* Out */ SizePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiOutLen = OutLen; + auto *const Out = MemInst->getPointer(OutPtr, WasiOutLen); + checkExist(Out); + + const __wasi_size_t WasiDataLen = DataLen; + auto *const Data = MemInst->getPointer(DataPtr, WasiDataLen); + checkExist(Data); + + const __wasi_size_t WasiRawTagLen = RawTagLen; + auto *RawTag = MemInst->getPointer(RawTagPtr, WasiRawTagLen); + checkExist(RawTag); + + auto *Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = Ctx.symmetricStateDecryptDetached( + StateHandle, {Out, WasiOutLen}, {Data, WasiDataLen}, + {RawTag, WasiRawTagLen}) + .and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect StateRatchet::body(Runtime::Instance::MemoryInstance *, + int32_t StateHandle) { + + if (auto Res = Ctx.symmetricStateRatchet(StateHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect TagLen::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t TagHandle, uint32_t /* Out */ SizePtr) { + checkExist(MemInst); + + auto *Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + if (unlikely(Size == nullptr)) { + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; + } + + if (auto Res = Ctx.symmetricTagLen(TagHandle).and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect TagPull::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t TagHandle, uint32_t BufPtr, + uint32_t BufLen, uint32_t /* Out */ SizePtr) { + checkExist(MemInst); + + const __wasi_size_t WasiBufLen = BufLen; + auto *Buf = MemInst->getPointer(BufPtr, WasiBufLen); + checkExist(Buf); + + auto *Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); + checkExist(Size); + + if (auto Res = Ctx.symmetricTagPull(TagHandle, {Buf, WasiBufLen}) + .and_then(toWasiSize); + unlikely(!Res)) { + return Res.error(); + } else { + *Size = *Res; + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect TagVerify::body(Runtime::Instance::MemoryInstance *MemInst, + int32_t TagHandle, uint32_t RawTagPtr, + uint32_t RawTagLen) { + checkExist(MemInst); + + const __wasi_size_t WasiRawTagLen = RawTagLen; + auto *RawTag = MemInst->getPointer(RawTagPtr, WasiRawTagLen); + checkExist(RawTag); + + if (auto Res = Ctx.symmetricTagVerify(TagHandle, {RawTag, RawTagLen}); + unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +Expect TagClose::body(Runtime::Instance::MemoryInstance *, + int32_t TagHandle) { + if (auto Res = Ctx.symmetricTagClose(TagHandle); unlikely(!Res)) { + return Res.error(); + } + + return __WASI_CRYPTO_ERRNO_SUCCESS; +} + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/func.h b/plugins/wasi_crypto/symmetric/func.h new file mode 100644 index 00000000..03fa19d3 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/func.h @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/func.h - Symmetric funcs ---===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the symmetric host functions of wasi-crypto. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "utils/hostfunction.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +class KeyGenerate : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptOptionsPtr, + uint32_t /* Out */ KeyHandlePtr); +}; + +class KeyImport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgPtr, uint32_t AlgLen, uint32_t RawPtr, + uint32_t RawLen, uint32_t /* Out */ KeyHandlePtr); +}; + +class KeyExport : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KeyHandle, + uint32_t /* Out */ ArrayOutputHandlePtr); +}; + +class KeyClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KeyHandle); +}; + +class KeyGenerateManaged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t OptOptionsPtr, + uint32_t /* Out */ KeyHandlePtr); +}; + +class KeyStoreManaged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, int32_t KeyHandle, + uint32_t KeyIdPtr, uint32_t KeyIdMaxLen); +}; + +class KeyReplaceManaged : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, int32_t OldKeyHandle, + int32_t NewKeyHandle, uint32_t /* Out */ KeyVersionPtr); +}; + +class KeyId : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t KeyHandle, uint32_t KeyIdPtr, + uint32_t KeyIdMaxLen, uint32_t /* Out */ SizePtr, + uint32_t /* Out */ KeyVersionPtr); +}; + +class KeyFromId : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t SecretsManagerHandle, uint32_t KeyIdPtr, + uint32_t KeyIdLen, uint64_t KeyVersion, + uint32_t /* Out */ KeyHandlePtr); +}; + +class StateOpen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptKeyHandlePtr, uint32_t OptOptionsPtr, + uint32_t /* Out */ StatePtr); +}; + +class StateClone : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t /* Out */ StatePtr); +}; + +class StateOptionsGet : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t NamePtr, uint32_t NameLen, + uint32_t ValuePtr, uint32_t ValueLen, + uint32_t /* Out */ SizePtr); +}; + +class StateOptionsGetU64 : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t NamePtr, uint32_t NameLen, + uint32_t /* Out */ U64Ptr); +}; + +class StateClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle); +}; + +class StateAbsorb : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t DataPtr, + uint32_t DataLen); +}; + +class StateSqueeze : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen); +}; + +class StateSqueezeTag : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t /* Out */ TagHandlePtr); +}; + +class StateSqueezeKey : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t AlgPtr, uint32_t AlgLen, + uint32_t /* Out */ KeyHandlePtr); +}; + +class StateMaxTagLen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t /* Out */ SizePtr); +}; + +class StateEncrypt : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen, + uint32_t DataPtr, uint32_t DataLen, + uint32_t /* Out */ SizePtr); +}; + +class StateEncryptDetached : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen, + uint32_t DataPtr, uint32_t DataLen, + uint32_t /* Out */ TagHandlePtr); +}; + +class StateDecrypt : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen, + uint32_t DataPtr, uint32_t DataLen, + uint32_t /* Out */ SizePtr); +}; + +class StateDecryptDetached : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen, + uint32_t DataPtr, uint32_t DataLen, uint32_t RawTagPtr, + uint32_t RawTagLen, uint32_t /* Out */ SizePtr); +}; + +class StateRatchet : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t StateHandle); +}; + +class TagLen : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t TagHandle, uint32_t /* Out */ SizePtr); +}; + +class TagPull : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t TagHandle, uint32_t BufPtr, uint32_t BufLen, + uint32_t /* Out */ SizePtr); +}; + +class TagVerify : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t TagHandle, uint32_t RawTagPtr, + uint32_t RawTagLen); +}; + +class TagClose : public HostFunction { +public: + using HostFunction::HostFunction; + Expect body(Runtime::Instance::MemoryInstance *MemInst, + int32_t TagHandle); +}; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/hash.cpp b/plugins/wasi_crypto/symmetric/hash.cpp new file mode 100644 index 00000000..dd283f34 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/hash.cpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "symmetric/hash.h" +#include "utils/evp_wrapper.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +template constexpr size_t Sha2::getDigestSize() noexcept { + static_assert(ShaNid == NID_sha256 || ShaNid == NID_sha512 || + ShaNid == NID_sha512_256); + if constexpr (ShaNid == NID_sha256) + return 32; + if constexpr (ShaNid == NID_sha512) + return 64; + if constexpr (ShaNid == NID_sha512_256) + return 32; +} + +template +WasiCryptoExpect::State> +Sha2::State::open(OptionalRef) noexcept { + EvpMdCtxPtr Ctx{EVP_MD_CTX_new()}; + opensslCheck(EVP_DigestInit(Ctx.get(), EVP_get_digestbynid(ShaNid))); + return Ctx; +} + +template +WasiCryptoExpect +Sha2::State::absorb(Span Data) noexcept { + std::unique_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_DigestUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect +Sha2::State::squeeze(Span Out) noexcept { + ensureOrReturn(getDigestSize() >= Out.size(), + __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + + EvpMdCtxPtr CopyCtx{EVP_MD_CTX_new()}; + + { + std::shared_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_MD_CTX_copy_ex(CopyCtx.get(), Ctx->RawCtx.get())); + } + + if (getDigestSize() == Out.size()) { + unsigned int Size; + opensslCheck(EVP_DigestFinal_ex(CopyCtx.get(), Out.data(), &Size)); + ensureOrReturn(Size == getDigestSize(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } + + if (getDigestSize() > Out.size()) { + unsigned int Size; + std::array Cache; + opensslCheck(EVP_DigestFinal_ex(CopyCtx.get(), Cache.data(), &Size)); + ensureOrReturn(Size == getDigestSize(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + std::copy(Cache.begin(), Cache.begin() + static_cast(Out.size()), + Out.data()); + } + + return {}; +} + +template +WasiCryptoExpect::State> +Sha2::State::clone() const noexcept { + EvpMdCtxPtr CloneCtx{EVP_MD_CTX_new()}; + + { + std::shared_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_MD_CTX_copy_ex(CloneCtx.get(), Ctx->RawCtx.get())); + } + + return CloneCtx; +} + +template class Sha2; +template class Sha2; +template class Sha2; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/hash.h b/plugins/wasi_crypto/symmetric/hash.h new file mode 100644 index 00000000..49288f22 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/hash.h @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/hash.h - Hash related ------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of Hash and related classes. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/options.h" +#include "symmetric/tag.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Hash never have key, just a placement, every hash key should inherent from +/// this class. +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#hash-functions +template class HashKey { +public: + static WasiCryptoExpect import(Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED); + } + + static WasiCryptoExpect generate(OptionalRef) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED); + } + + SecretVec exportData() const noexcept { assumingUnreachable(); } +}; + +/// Hash invalid operations, every hash state should inherent from this class. +template class HashState { +public: + /// Current hash not support any options. + WasiCryptoExpect optionsGet(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + /// Current hash not support any options. + WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + WasiCryptoExpect ratchet() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encryptDetached(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decryptDetached(Span, Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeKey() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeTag() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect maxTagLen() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } +}; + +template class Sha2 { +public: + /// In fact, sha2 key will never produce. This design is for removing the + /// forwarding declaration. + class Key : public HashKey {}; + + class State : public HashState { + public: + State(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + static WasiCryptoExpect + open(OptionalRef OptOption) noexcept; + + WasiCryptoExpect absorb(Span Data) noexcept; + + WasiCryptoExpect squeeze(Span Out) noexcept; + + WasiCryptoExpect clone() const noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr Ctx) noexcept : RawCtx(std::move(Ctx)) {} + EvpMdCtxPtr RawCtx; + std::shared_mutex Mutex; + }; + std::shared_ptr Ctx; + }; + +private: + /// Return the sha digest size. + constexpr static size_t getDigestSize() noexcept; +}; + +using Sha256 = Sha2; +using Sha512 = Sha2; +using Sha512_256 = Sha2; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/kdf.cpp b/plugins/wasi_crypto/symmetric/kdf.cpp new file mode 100644 index 00000000..6d23a868 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/kdf.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "symmetric/kdf.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/secret_vec.h" + +#include +#include + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +template constexpr uint32_t Hkdf::getKeySize() noexcept { + static_assert(ShaNid == NID_sha256 || ShaNid == NID_sha512); + + if constexpr (ShaNid == NID_sha256) + return 32; + if constexpr (ShaNid == NID_sha512) + return 64; +} + +template constexpr void *Hkdf::getShaCtx() noexcept { + return static_cast(const_cast(EVP_get_digestbynid(ShaNid))); +} + +template +WasiCryptoExpect::Expand::Key> +Hkdf::Expand::Key::generate(OptionalRef) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_FEATURE); +} + +template +WasiCryptoExpect::Expand::Key> +Hkdf::Expand::Key::import(Span Raw) noexcept { + ensureOrReturn(Raw.size() == getKeySize(), __WASI_CRYPTO_ERRNO_INVALID_KEY); + return SecretVec{Raw}; +} + +template +WasiCryptoExpect::Expand::State> +Hkdf::Expand::State::open(const Key &Key, + OptionalRef) noexcept { + return openStateImpl(Key.ref(), EVP_PKEY_HKDEF_MODE_EXPAND_ONLY); +} + +template +WasiCryptoExpect +Hkdf::Expand::State::absorb(Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_PKEY_CTX_add1_hkdf_info(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect +Hkdf::Expand::State::squeeze(Span Out) noexcept { + size_t KeyLen = Out.size(); + + { + std::scoped_lock Lock{Ctx->Mutex}; + ensureOrReturn(EVP_PKEY_derive(Ctx->RawCtx.get(), Out.data(), &KeyLen), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + } + + ensureOrReturn(KeyLen == getKeySize(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return {}; +} + +template +WasiCryptoExpect::Expand::State> +Hkdf::Expand::State::clone() const noexcept { + // not supported for a keygen operation. + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +template +WasiCryptoExpect::Extract::Key> +Hkdf::Extract::Key::generate(OptionalRef) noexcept { + return SecretVec::random(); +} + +template +WasiCryptoExpect::Extract::Key> +Hkdf::Extract::Key::import(Span Raw) noexcept { + return SecretVec{Raw}; +} + +template +WasiCryptoExpect::Extract::State> +Hkdf::Extract::State::open(const Key &Key, + OptionalRef) noexcept { + return openStateImpl(Key.ref(), EVP_PKEY_HKDEF_MODE_EXTRACT_ONLY); +} + +template +WasiCryptoExpect +Hkdf::Extract::State::absorb(Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + + Ctx->Salt.insert(Ctx->Salt.end(), Data.begin(), Data.end()); + return {}; +} + +template +WasiCryptoExpect::Expand::Key> +Hkdf::Extract::State::squeezeKey() noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + + opensslCheck(EVP_PKEY_CTX_set1_hkdf_salt(Ctx->RawCtx.get(), Ctx->Salt.data(), + Ctx->Salt.size())); + + SecretVec Data(getKeySize()); + + size_t ActualOutSize; + opensslCheck(EVP_PKEY_derive(Ctx->RawCtx.get(), Data.data(), &ActualOutSize)); + ensureOrReturn(ActualOutSize == getKeySize(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Data; +} + +template +WasiCryptoExpect::Extract::State> +Hkdf::Extract::State::clone() const noexcept { + // not supported for a keygen operation. + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +template +WasiCryptoExpect +Hkdf::openStateImpl(Span Key, int Mode) noexcept { + EvpPkeyCtxPtr Ctx{EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, nullptr)}; + opensslCheck(EVP_PKEY_derive_init(Ctx.get())); + opensslCheck(EVP_PKEY_CTX_set_hkdf_md(Ctx.get(), getShaCtx())); + opensslCheck(EVP_PKEY_CTX_hkdf_mode(Ctx.get(), Mode)); + ensureOrReturn(EVP_PKEY_CTX_set1_hkdf_key(Ctx.get(), Key.data(), Key.size()), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + + return Ctx; +} + +template class Hkdf::Extract; +template class Hkdf::Extract; +template class Hkdf::Expand; +template class Hkdf::Expand; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/kdf.h b/plugins/wasi_crypto/symmetric/kdf.h new file mode 100644 index 00000000..5ecc528f --- /dev/null +++ b/plugins/wasi_crypto/symmetric/kdf.h @@ -0,0 +1,249 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/kdf.h - Kdf related --------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of Key derivation and related classes. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/options.h" +#include "symmetric/tag.h" +#include "utils/error.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Expand invalid operations, every expand state should inherent from this +/// class. +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#key-derivation-using-extract-and-expand +template class ExpandState { +public: + /// Current kdf not support any options. + WasiCryptoExpect optionsGet(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + /// Current kdf not support any options. + WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + WasiCryptoExpect ratchet() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encryptDetached(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decryptDetached(Span, Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeKey() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeTag() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect maxTagLen() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } +}; + +/// Extract invalid operations, every extract state should inherent from this +/// class. +template class ExtractState { +public: + /// Current kdf not support any options. + WasiCryptoExpect optionsGet(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + /// Current kdf not support any options. + WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + WasiCryptoExpect squeeze(Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect ratchet() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encryptDetached(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decryptDetached(Span, Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeTag() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect maxTagLen() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } +}; + +template class Hkdf { +public: + class Expand { + public: + class Key { + public: + Key(SecretVec Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect import(Span Data) noexcept; + + static WasiCryptoExpect + generate(OptionalRef Options) noexcept; + + SecretVec exportData() const noexcept { return Data; } + + const SecretVec &ref() const noexcept { return Data; } + + private: + SecretVec Data; + }; + + class State : public ExpandState { + public: + static WasiCryptoExpect + open(const Key &Key, OptionalRef OptOption) noexcept; + + State(EvpPkeyCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + /// absorb info information. + WasiCryptoExpect absorb(Span Data) noexcept; + + /// derivation + WasiCryptoExpect squeeze(Span Out) noexcept; + + WasiCryptoExpect clone() const noexcept; + + private: + struct Inner { + Inner(EvpPkeyCtxPtr RawCtx) : RawCtx(std::move(RawCtx)) {} + EvpPkeyCtxPtr RawCtx; + std::mutex Mutex; + }; + std::shared_ptr Ctx; + }; + }; + + class Extract { + public: + class Key { + public: + Key(SecretVec Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect import(Span Data) noexcept; + + static WasiCryptoExpect + generate(OptionalRef Options) noexcept; + + SecretVec exportData() const noexcept { return Data; } + + const SecretVec &ref() const noexcept { return Data; } + + private: + SecretVec Data; + }; + + class State : public ExtractState { + public: + State(EvpPkeyCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + static WasiCryptoExpect + open(const Key &Key, OptionalRef OptOption) noexcept; + + /// Absorbs the salt information. + WasiCryptoExpect absorb(Span Data) noexcept; + + /// Returns the PRK, whose algorithm type is set to the EXPAND counterpart + /// of the EXTRACT operation. + WasiCryptoExpect squeezeKey() noexcept; + + WasiCryptoExpect clone() const noexcept; + + private: + struct Inner { + Inner(EvpPkeyCtxPtr RawCtx) : RawCtx(std::move(RawCtx)) {} + std::mutex Mutex; + std::vector Salt; + EvpPkeyCtxPtr RawCtx; + }; + std::shared_ptr Ctx; + }; + }; + +private: + constexpr static uint32_t getKeySize() noexcept; + + constexpr static void *getShaCtx() noexcept; + + static WasiCryptoExpect openStateImpl(Span Key, + int Mode) noexcept; +}; + +using HkdfSha256Extract = Hkdf::Extract; +using HkdfSha512Extract = Hkdf::Extract; +using HkdfSha256Expand = Hkdf::Expand; +using HkdfSha512Expand = Hkdf::Expand; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/key.cpp b/plugins/wasi_crypto/symmetric/key.cpp new file mode 100644 index 00000000..33797042 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/key.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "symmetric/key.h" +#include "utils/error.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +WasiCryptoExpect importKey(Algorithm Alg, + Span Data) noexcept { + return std::visit( + [Data](auto Factory) noexcept { + return decltype(Factory)::Key::import(Data).map( + [](auto &&Key) noexcept { + return KeyVariant{std::forward(Key)}; + }); + }, + Alg); +} + +WasiCryptoExpect +generateKey(Algorithm Alg, OptionalRef OptOptions) noexcept { + return std::visit( + [OptOptions](auto Factory) noexcept { + return decltype(Factory)::Key::generate(OptOptions) + .map([](auto &&Key) noexcept { + return KeyVariant{std::forward(Key)}; + }); + }, + Alg); +} + +SecretVec keyExportData(const KeyVariant &KeyVariant) noexcept { + return std::visit([](const auto &Key) noexcept { return Key.exportData(); }, + KeyVariant); +} + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/key.h b/plugins/wasi_crypto/symmetric/key.h new file mode 100644 index 00000000..b5f62150 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/key.h @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/key.h - Symmetric Key class ===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Symmetric Key class definition. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/registed.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Object represents a key and an algorithm. +/// +/// More detail: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#symmetric-keys-1 +using KeyVariant = RegistedAlg::Key; + +WasiCryptoExpect importKey(Algorithm Alg, + Span Data) noexcept; + +WasiCryptoExpect +generateKey(Algorithm Alg, OptionalRef OptOptions) noexcept; + +/// Get the inner represent. +SecretVec keyExportData(const KeyVariant &Key) noexcept; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/mac.cpp b/plugins/wasi_crypto/symmetric/mac.cpp new file mode 100644 index 00000000..d6349cf2 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/mac.cpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "symmetric/mac.h" +#include "utils/secret_vec.h" + +#include + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +template constexpr size_t Hmac::getKeySize() noexcept { + static_assert(ShaNid == NID_sha256 || ShaNid == NID_sha512); + if constexpr (ShaNid == NID_sha256) { + return 32; + } + if constexpr (ShaNid == NID_sha512) { + return 64; + } +} + +template +WasiCryptoExpect::Key> +Hmac::Key::generate(OptionalRef) noexcept { + return SecretVec::random(); +} + +template +WasiCryptoExpect::Key> +Hmac::Key::import(Span Raw) noexcept { + return SecretVec{Raw}; +} + +template +WasiCryptoExpect::State> +Hmac::State::open(const Key &Key, OptionalRef) noexcept { + EvpPkeyPtr HmacKey{EVP_PKEY_new_raw_private_key( + EVP_PKEY_HMAC, nullptr, Key.ref().data(), Key.ref().size())}; + opensslCheck(HmacKey); + + EvpMdCtxPtr Ctx{EVP_MD_CTX_new()}; + + opensslCheck(EVP_DigestSignInit( + Ctx.get(), nullptr, EVP_get_digestbynid(ShaNid), nullptr, HmacKey.get())); + + return Ctx; +} + +template +WasiCryptoExpect +Hmac::State::absorb(Span Data) noexcept { + std::scoped_lock Lock{Ctx->Mutex}; + + opensslCheck( + EVP_DigestSignUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect Hmac::State::squeezeTag() noexcept { + SecretVec Res(getKeySize()); + + size_t ActualOutSize; + { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_DigestSignFinal(Ctx->RawCtx.get(), Res.data(), &ActualOutSize)); + } + + ensureOrReturn(ActualOutSize == getKeySize(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return Res; +} + +template +WasiCryptoExpect::State> +Hmac::State::clone() const noexcept { + EvpMdCtxPtr CloneCtx{EVP_MD_CTX_new()}; + + { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_MD_CTX_copy_ex(CloneCtx.get(), Ctx->RawCtx.get())); + } + + return CloneCtx; +} + +template class Hmac; +template class Hmac; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasi_crypto/symmetric/mac.h b/plugins/wasi_crypto/symmetric/mac.h new file mode 100644 index 00000000..f1247b95 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/mac.h @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/mac.h - Mac related --------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of Mac and related classes. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/options.h" +#include "symmetric/tag.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" +#include "utils/secret_vec.h" + +#include +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Mac invalid operation, every mac state should inherent from this class +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#message-authentication-codes +template class MacState { +public: + /// Current mac not support any options. + WasiCryptoExpect optionsGet(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + /// Current mac not support any options. + WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + WasiCryptoExpect squeeze(Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect ratchet() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encryptDetached(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decryptDetached(Span, Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeKey() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect maxTagLen() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } +}; + +template class Hmac { +public: + class Key { + public: + Key(SecretVec Data) noexcept : Data(std::move(Data)) {} + + static WasiCryptoExpect import(Span Data) noexcept; + + static WasiCryptoExpect + generate(OptionalRef Options) noexcept; + + SecretVec exportData() const noexcept { return Data; } + + const SecretVec &ref() const noexcept { return Data; } + + private: + SecretVec Data; + }; + + class State : public MacState { + public: + State(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + static WasiCryptoExpect + open(const Key &Key, OptionalRef OptOption) noexcept; + + /// Adds input data to the state. + /// + /// @param[in] Data the input data. + /// @return Nothing or WasiCrypto error. + WasiCryptoExpect absorb(Span Data) noexcept; + + /// Authenticates the input received up to the function call. + /// If the finalization is required, the implementation MUST duplicate the + /// internal state and apply the finalization on the copy, leaving the state + /// unchanged from the guest perspective. + /// + /// @return Nothing or WasiCrypto error. + WasiCryptoExpect squeezeTag() noexcept; + + WasiCryptoExpect clone() const noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr RawCtx) noexcept : RawCtx(std::move(RawCtx)) {} + EvpMdCtxPtr RawCtx; + std::mutex Mutex; + }; + std::shared_ptr Ctx; + }; + +private: + constexpr static size_t getKeySize() noexcept; +}; + +using HmacSha256 = Hmac; +using HmacSha512 = Hmac; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/options.cpp b/plugins/wasi_crypto/symmetric/options.cpp new file mode 100644 index 00000000..22309e57 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/options.cpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "symmetric/options.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { +using namespace std::literals; + +namespace { +constexpr std::array ValidNames{"context"sv, "salt"sv, + "nonce"sv}; + +std::string toLower(std::string_view Name) noexcept { + std::string Ret{Name}; + std::transform(Ret.begin(), Ret.end(), Ret.begin(), + [](char C) { return std::tolower(C); }); + return Ret; +} + +bool isValidName(std::string_view Name) noexcept { + return std::find(ValidNames.begin(), ValidNames.end(), Name) != + ValidNames.end(); +} + +constexpr std::array ValidU64Names{ + "memory_limit"sv, "ops_limit"sv, "parallelism"sv}; + +bool isValidU64Name(std::string_view Name) noexcept { + return std::find(ValidU64Names.begin(), ValidU64Names.end(), Name) != + ValidU64Names.end(); +} + +} // namespace + +WasiCryptoExpect Options::set(std::string_view Name, + Span Value) noexcept { + std::string ActuallyName = toLower(Name); + + ensureOrReturn(isValidName(ActuallyName), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + { + std::unique_lock Lock{Inner->Mutex}; + Inner->ValueMap.insert_or_assign(ActuallyName, + std::vector(Value.begin(), Value.end())); + } + return {}; +} + +WasiCryptoExpect Options::setU64(std::string_view Name, + uint64_t Value) noexcept { + std::string ActuallyName = toLower(Name); + ensureOrReturn(isValidU64Name(ActuallyName), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + { + std::unique_lock Lock{Inner->Mutex}; + Inner->U64ValueMap.insert_or_assign(ActuallyName, Value); + } + return {}; +} + +WasiCryptoExpect Options::setGuestBuffer(std::string_view Name, + Span Buffer) noexcept { + std::string ActuallyName = toLower(Name); + ensureOrReturn(ActuallyName == "buffer"sv, + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + { + std::unique_lock Lock{Inner->Mutex}; + Inner->GuestBuffer = Buffer; + } + return {}; +} + +WasiCryptoExpect Options::get(std::string_view Name, + Span Value) const noexcept { + std::string ActuallyName = toLower(Name); + ensureOrReturn(isValidName(ActuallyName), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + { + std::shared_lock Lock{Inner->Mutex}; + if (auto It = Inner->ValueMap.find(ActuallyName); + It != Inner->ValueMap.end()) { + ensureOrReturn(It->second.size() <= Value.size(), + __WASI_CRYPTO_ERRNO_OVERFLOW); + std::copy(It->second.begin(), It->second.end(), Value.begin()); + return It->second.size(); + } + } + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_OPTION_NOT_SET); +} + +WasiCryptoExpect +Options::getU64(std::string_view Name) const noexcept { + std::string ActuallyName = toLower(Name); + ensureOrReturn(isValidU64Name(ActuallyName), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + { + std::shared_lock Lock{Inner->Mutex}; + if (auto It = Inner->U64ValueMap.find(ActuallyName); + It != Inner->U64ValueMap.end()) { + return It->second; + } + } + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_OPTION_NOT_SET); +} + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/options.h b/plugins/wasi_crypto/symmetric/options.h new file mode 100644 index 00000000..9d0a6ec3 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/options.h @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/options.h - Options --------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Symmetric Options class definition. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" + +#include "common/span.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Options for symmetric state and key. +/// +/// More detail: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#options-1 +class Options { +public: + WasiCryptoExpect set(std::string_view Name, + Span Value) noexcept; + + WasiCryptoExpect setU64(std::string_view Name, uint64_t Value) noexcept; + + WasiCryptoExpect setGuestBuffer(std::string_view Name, + Span Buffer) noexcept; + + WasiCryptoExpect get(std::string_view Name, + Span Value) const noexcept; + + WasiCryptoExpect getU64(std::string_view Name) const noexcept; + +private: + struct DataType { + std::map> ValueMap; + std::map U64ValueMap; + std::optional> GuestBuffer; + mutable std::shared_mutex Mutex; + }; + + std::shared_ptr Inner = std::make_shared(); +}; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/registed.h b/plugins/wasi_crypto/symmetric/registed.h new file mode 100644 index 00000000..197478c4 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/registed.h @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/registed.h - Registed ------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the register symmetric algorithm definitions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/aeads.h" +#include "symmetric/hash.h" +#include "symmetric/kdf.h" +#include "symmetric/mac.h" +#include "utils/error.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Registed algorithm +template struct Registed { + using Key = std::variant; + using State = std::variant; + using Variant = std::variant; +}; + +using RegistedAlg = + Registed; + +using Algorithm = RegistedAlg::Variant; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/state.cpp b/plugins/wasi_crypto/symmetric/state.cpp new file mode 100644 index 00000000..77852f58 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/state.cpp @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "symmetric/state.h" +#include "utils/error.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { +namespace { +WasiCryptoExpect checkedAdd(size_t A, size_t B) { + size_t Res; + ensureOrReturn(!__builtin_add_overflow(A, B, &Res), + __WASI_CRYPTO_ERRNO_OVERFLOW); + return Res; +} + +/// Correspond signatures: +template struct StateOpenTrait; + +/// WasiCryptoExpect open(const KeyType &, OptionalRef); +template +struct StateOpenTrait (*)( + const KeyType &, OptionalRef) noexcept> { + static inline constexpr bool NeedKey = true; + using Key = KeyType; +}; + +/// WasiCryptoExpect open(OptionalRef); +template +struct StateOpenTrait (*)( + OptionalRef) noexcept> { + static inline constexpr bool NeedKey = false; +}; + +template +using GetStateOpenTrait = StateOpenTrait; +} // namespace + +WasiCryptoExpect +openState(Algorithm Alg, OptionalRef OptKeyVariant, + OptionalRef OptOptions) noexcept { + return std::visit( + [=](auto Factory) noexcept -> WasiCryptoExpect { + using StateOpen = GetStateOpenTrait; + if constexpr (StateOpen::NeedKey) { + using RequiredKeyType = typename StateOpen::Key; + // Need key. Not have key, fail. + if (unlikely(!OptKeyVariant)) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_KEY_REQUIRED); + } + return std::visit( + [OptOptions](const auto &Key) -> WasiCryptoExpect { + using InKeyType = std::decay_t; + if constexpr (!std::is_same_v) { + // Key type not same. + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_KEY); + } else { + // Key type fitted. + return decltype(Factory)::State::open(Key, OptOptions) + .map([](auto &&State) noexcept { + return StateVariant{ + std::forward(State)}; + }); + } + }, + *OptKeyVariant); + + } else { + // Not need key. Have key, fail. + if (unlikely(!!OptKeyVariant)) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED); + } + return decltype(Factory)::State::open(OptOptions) + .map([](auto &&State) noexcept { + return StateVariant{std::forward(State)}; + }); + } + }, + Alg); +} + +WasiCryptoExpect stateOptionsGet(const StateVariant &StateVariant, + std::string_view Name, + Span Value) noexcept { + return std::visit( + [=](const auto &State) noexcept { return State.optionsGet(Name, Value); }, + StateVariant); +} + +WasiCryptoExpect stateOptionsGetU64(const StateVariant &StateVariant, + std::string_view Name) noexcept { + return std::visit( + [Name](const auto &State) noexcept { return State.optionsGetU64(Name); }, + StateVariant); +} + +WasiCryptoExpect stateAbsorb(StateVariant &StateVariant, + Span Data) noexcept { + return std::visit([Data](auto &State) noexcept { return State.absorb(Data); }, + StateVariant); +} + +WasiCryptoExpect stateSqueeze(StateVariant &StateVariant, + Span Out) noexcept { + return std::visit([Out](auto &State) noexcept { return State.squeeze(Out); }, + StateVariant); +} + +WasiCryptoExpect stateSqueezeTag(StateVariant &StateVariant) noexcept { + return std::visit([](auto &State) noexcept { return State.squeezeTag(); }, + StateVariant); +} + +namespace { +template struct GetSqueezeKeyTypeTrait; +template +struct GetSqueezeKeyTypeTrait ( + StateType::*)() noexcept> { + using Key = KeyType; +}; +template +using GetSqueezeKeyType = + typename GetSqueezeKeyTypeTrait::Key; + +} // namespace + +WasiCryptoExpect stateSqueezeKey(StateVariant &StateVariant, + Algorithm KeyAlg) noexcept { + return std::visit( + [](auto &State, auto Alg) noexcept -> WasiCryptoExpect { + if constexpr (std::is_same_v< + GetSqueezeKeyType>, + typename decltype(Alg)::Key>) { + return State.squeezeKey().map([](auto &&Key) { + return KeyVariant{std::forward(Key)}; + }); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); + } + }, + StateVariant, KeyAlg); +} + +WasiCryptoExpect +stateMaxTagLen(const StateVariant &StateVariant) noexcept { + return std::visit( + [](const auto &State) noexcept { return State.maxTagLen(); }, + StateVariant); +} + +WasiCryptoExpect stateEncrypt(StateVariant &StateVariant, + Span Out, + Span Data) noexcept { + return std::visit( + [=](auto &State) noexcept -> WasiCryptoExpect { + return State.maxTagLen() + .and_then([DataSize = Data.size()](size_t TagLen) noexcept { + return checkedAdd(DataSize, TagLen); + }) + .and_then([Out, Data, &State](size_t ActualDataLen) noexcept + -> WasiCryptoExpect { + ensureOrReturn(Out.size() == ActualDataLen, + __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + return State.encrypt(Out, Data); + }); + }, + StateVariant); +} + +WasiCryptoExpect stateEncryptDetached(StateVariant &StateVariant, + Span Out, + Span Data) noexcept { + ensureOrReturn(Data.size() == Out.size(), __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + return std::visit( + [=](auto &State) noexcept { return State.encryptDetached(Out, Data); }, + StateVariant); +} + +WasiCryptoExpect stateDecrypt(StateVariant &StateVariant, + Span Out, + Span Data) noexcept { + return std::visit( + [=](auto &State) noexcept -> WasiCryptoExpect { + return State.maxTagLen() + .and_then([OutSize = Out.size()](size_t TagLen) noexcept { + return checkedAdd(OutSize, TagLen); + }) + .and_then([Out, Data, &State](size_t ActualOutLen) noexcept + -> WasiCryptoExpect { + ensureOrReturn(Data.size() == ActualOutLen, + __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + return State.decrypt(Out, Data); + }); + }, + StateVariant); +} + +WasiCryptoExpect +stateDecryptDetached(StateVariant &StateVariant, Span Out, + Span Data, + Span RawTag) noexcept { + ensureOrReturn(Data.size() == Out.size(), __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + return std::visit( + [=](auto &State) noexcept { + return State.decryptDetached(Out, Data, RawTag); + }, + StateVariant); +} + +WasiCryptoExpect stateRatchet(StateVariant &StateVariant) noexcept { + return std::visit([](auto &State) noexcept { return State.ratchet(); }, + StateVariant); +} + +WasiCryptoExpect +stateClone(const StateVariant &ClonedStateVariant) noexcept { + return std::visit( + [](const auto &ClonedState) noexcept -> WasiCryptoExpect { + return ClonedState.clone().map([](auto &&NewState) noexcept { + return StateVariant{std::forward(NewState)}; + }); + }, + ClonedStateVariant); +} + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/state.h b/plugins/wasi_crypto/symmetric/state.h new file mode 100644 index 00000000..22c94e5c --- /dev/null +++ b/plugins/wasi_crypto/symmetric/state.h @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/state.h - Symmetric State --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the symmetric state related classes, and provide a +/// unified interface which can be used to implement the algorithm operations. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/key.h" +#include "symmetric/registed.h" +#include "symmetric/tag.h" +#include "utils/error.h" + +#include "common/span.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// State created from key, and performs symmetric operations with using the +/// underlying algorithms. +/// +/// More detail: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#state +using StateVariant = RegistedAlg::State; + +WasiCryptoExpect +openState(Algorithm Alg, OptionalRef OptKeyVariant, + OptionalRef OptOptions) noexcept; + +WasiCryptoExpect stateOptionsGet(const StateVariant &StateVariant, + std::string_view Name, + Span Value) noexcept; + +WasiCryptoExpect stateOptionsGetU64(const StateVariant &StateVariant, + std::string_view Name) noexcept; + +/// Absorb data into the state. +WasiCryptoExpect stateAbsorb(StateVariant &StateVariant, + Span Data) noexcept; + +/// Squeeze bytes from the state. +WasiCryptoExpect stateSqueeze(StateVariant &StateVariant, + Span Out) noexcept; + +/// Compute and return a tag for all the data injected into the state so far. +WasiCryptoExpect stateSqueezeTag(StateVariant &StateVariant) noexcept; + +/// Use the current state to produce a key for a target algorithm. +WasiCryptoExpect stateSqueezeKey(StateVariant &StateVariant, + Algorithm KeyAlg) noexcept; + +/// Encrypt data with an attached tag. +WasiCryptoExpect stateEncrypt(StateVariant &StateVariant, + Span Out, + Span Data) noexcept; + +/// Encrypt data and return the ciphertext and the authentication tag +/// separately. +WasiCryptoExpect stateEncryptDetached(StateVariant &StateVariant, + Span Out, + Span Data) noexcept; + +/// Decrypt a ciphertext with an attached tag. +WasiCryptoExpect stateDecrypt(StateVariant &StateVariant, + Span Out, + Span Data) noexcept; + +/// Verify an authentication tag and decrypt the corresponding ciphertext if +/// the verification passes. +WasiCryptoExpect +stateDecryptDetached(StateVariant &StateVariant, Span Out, + Span Data, + Span RawTag) noexcept; + +/// Returns the length required to encode the authentication tag and optional +/// padding bytes. +WasiCryptoExpect +stateMaxTagLen(const StateVariant &StateVariant) noexcept; + +/// Make the state impossible to recover the previous state. +WasiCryptoExpect stateRatchet(StateVariant &StateVariant) noexcept; + +/// Clone the state. +WasiCryptoExpect +stateClone(const StateVariant &ClonedStateVariant) noexcept; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/tag.cpp b/plugins/wasi_crypto/symmetric/tag.cpp new file mode 100644 index 00000000..2fee4fb9 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/tag.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "symmetric/tag.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +WasiCryptoExpect Tag::verify(Span RawTag) const noexcept { + ensureOrReturn(!CRYPTO_memcmp(RawTag.data(), Data.data(), RawTag.size()), + __WASI_CRYPTO_ERRNO_INVALID_TAG); + + return {}; +} + +WasiCryptoExpect Tag::pull(Span Raw) const noexcept { + if (Raw.size() > Data.size()) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_LENGTH); + } + if (Raw.size() < Data.size()) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_OVERFLOW); + } + + std::copy(Data.begin(), Data.end(), Raw.begin()); + return Data.size(); +} + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/tag.h b/plugins/wasi_crypto/symmetric/tag.h new file mode 100644 index 00000000..5d2f3fad --- /dev/null +++ b/plugins/wasi_crypto/symmetric/tag.h @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/tag.h - Symmetric Tag class ===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the Symmetric Tag definition. +/// +//===----------------------------------------------------------------------===// +#pragma once + +#include "utils/error.h" +#include "utils/secret_vec.h" + +#include "common/span.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Authentication tag, that can be verified without channels using the provided +/// APIs. Very small and no streaming. +/// +/// More detail: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#authentication-tags +class Tag { +public: + Tag(Tag &&Data) noexcept = default; + Tag &operator=(Tag &&Data) noexcept = default; + Tag(const Tag &Data) noexcept = delete; + Tag &operator=(const Tag &Data) noexcept = delete; + + Tag(SecretVec &&Data) noexcept : Data(std::move(Data)) {} + + size_t len() const noexcept { return Data.size(); } + + /// The function MUST return `__WASI_CRYPTO_ERRNO_INVALID_TAG` if the + /// tags don't match. + WasiCryptoExpect verify(Span RawTag) const noexcept; + + WasiCryptoExpect pull(Span Raw) const noexcept; + +private: + SecretVec Data; +}; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/error.h b/plugins/wasi_crypto/utils/error.h new file mode 100644 index 00000000..ba1bc1fc --- /dev/null +++ b/plugins/wasi_crypto/utils/error.h @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/utils/error.h - Error definition -----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the wasi-crypto error handling related functions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "common/errcode.h" +#include "common/expected.h" +#include "wasi_crypto/api.hpp" + +#include +#include + +/// Ensure the Expr is true or return ErrorCode. +#define ensureOrReturn(Expr, ErrorCode) \ + do { \ + if (!(Expr)) { \ + return WasiCryptoUnexpect((ErrorCode)); \ + } \ + } while (0) + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +/// Type aliasing for Expected. +template +using WasiCryptoExpect = Expected; + +/// Helper function for Unexpected. +constexpr auto WasiCryptoUnexpect(__wasi_crypto_errno_e_t Val) noexcept { + return Unexpected<__wasi_crypto_errno_e_t>(Val); +} +template +constexpr auto WasiCryptoUnexpect(const WasiCryptoExpect &Val) noexcept { + return Unexpected<__wasi_crypto_errno_e_t>(Val.error()); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/evp_wrapper.cpp b/plugins/wasi_crypto/utils/evp_wrapper.cpp new file mode 100644 index 00000000..e5badb14 --- /dev/null +++ b/plugins/wasi_crypto/utils/evp_wrapper.cpp @@ -0,0 +1,218 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "utils/evp_wrapper.h" +#include "utils/error.h" + +#include +#include + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +EVP_PKEY *pemReadPUBKEY(Span Encoded) { + BioPtr Bio{BIO_new(BIO_s_mem())}; + + if (size_t Size; + BIO_write_ex(Bio.get(), Encoded.data(), Encoded.size(), &Size)) { + if (Size != Encoded.size()) { + return nullptr; + } + } else { + return nullptr; + } + + return PEM_read_bio_PUBKEY(Bio.get(), nullptr, nullptr, nullptr); +} + +WasiCryptoExpect> pemWritePUBKEY(EVP_PKEY *Key) { + BioPtr Bio{BIO_new(BIO_s_mem())}; + opensslCheck(PEM_write_bio_PUBKEY(Bio.get(), Key)); + + BUF_MEM *Mem = nullptr; + opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); + std::vector Ret(Mem->length); + + if (size_t Size; BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size)) { + ensureOrReturn(Size == Ret.size(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } + + return Ret; +} + +EVP_PKEY *pemReadPrivateKey(Span Encoded) { + BioPtr Bio{BIO_new(BIO_s_mem())}; + + if (size_t Size; + BIO_write_ex(Bio.get(), Encoded.data(), Encoded.size(), &Size)) { + if (Size != Encoded.size()) { + return nullptr; + } + } else { + return nullptr; + } + + return PEM_read_bio_PrivateKey(Bio.get(), nullptr, nullptr, nullptr); +} + +WasiCryptoExpect pemWritePrivateKey(EVP_PKEY *Key) { + BioPtr Bio{BIO_new(BIO_s_mem())}; + opensslCheck(PEM_write_bio_PrivateKey(Bio.get(), Key, nullptr, nullptr, 0, + nullptr, nullptr)); + + BUF_MEM *Mem = nullptr; + opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); + SecretVec Ret(Mem->length); + + if (size_t Size; BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size)) { + ensureOrReturn(Size == Ret.size(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } + + return Ret; +} + +EVP_PKEY *d2iPUBKEY(Span Encoded) { + BioPtr Bio{BIO_new(BIO_s_mem())}; + + if (size_t Size; + BIO_write_ex(Bio.get(), Encoded.data(), Encoded.size(), &Size)) { + if (Size != Encoded.size()) { + return nullptr; + } + } else { + return nullptr; + } + + return d2i_PUBKEY_bio(Bio.get(), nullptr); +} + +WasiCryptoExpect> i2dPUBKEY(EVP_PKEY *Key) { + BioPtr Bio{BIO_new(BIO_s_mem())}; + opensslCheck(i2d_PUBKEY_bio(Bio.get(), Key)); + + BUF_MEM *Mem = nullptr; + opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); + std::vector Ret(Mem->length); + + if (size_t Size; BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size)) { + ensureOrReturn(Size == Ret.size(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } + + return Ret; +} + +EVP_PKEY *d2iPrivateKey(Span Encoded) { + BioPtr Bio{BIO_new(BIO_s_mem())}; + + if (size_t Size; + BIO_write_ex(Bio.get(), Encoded.data(), Encoded.size(), &Size)) { + if (Size != Encoded.size()) { + return nullptr; + } + } else { + return nullptr; + } + + return d2i_PrivateKey_bio(Bio.get(), nullptr); +} + +WasiCryptoExpect i2dPrivateKey(EVP_PKEY *Key) { + BioPtr Bio{BIO_new(BIO_s_mem())}; + opensslCheck(i2d_PrivateKey_bio(Bio.get(), Key)); + + BUF_MEM *Mem = nullptr; + opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); + SecretVec Ret(Mem->length); + + if (size_t Size; BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size)) { + ensureOrReturn(Size == Ret.size(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } + + return Ret; +} + +ECDSA_SIG *d2iEcdsaSig(Span Encoded) { + if (Encoded.size() > static_cast(std::numeric_limits::max())) { + return nullptr; + } + auto *Data = Encoded.data(); + return d2i_ECDSA_SIG(nullptr, &Data, static_cast(Encoded.size())); +} + +WasiCryptoExpect> i2dEcdsaSig(ECDSA_SIG *Sig) { + int SigSize = i2d_ECDSA_SIG(Sig, nullptr); + ensureOrReturn(SigSize >= 0, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + std::vector Res(static_cast(SigSize)); + + auto *Data = Res.data(); + auto NewSize = i2d_ECDSA_SIG(Sig, &Data); + ensureOrReturn(NewSize == SigSize, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + + return Res; +} + +ECDSA_SIG *o2iEcdsaSig(Span Encoded) { + if (Encoded.size() > static_cast(std::numeric_limits::max())) { + return nullptr; + } + int EncodedSize = static_cast(Encoded.size()); + BnPtr R{BN_bin2bn(Encoded.data(), EncodedSize / 2, nullptr)}; + BnPtr S{ + BN_bin2bn(Encoded.data() + EncodedSize / 2, EncodedSize / 2, nullptr)}; + EcdsaSigPtr Sig{ECDSA_SIG_new()}; + if (!ECDSA_SIG_set0(Sig.get(), R.release(), S.release())) { + return nullptr; + } + + return Sig.release(); +} + +WasiCryptoExpect> i2oEcdsaSig(ECDSA_SIG *Sig) { + auto *R = ECDSA_SIG_get0_r(Sig); + auto *S = ECDSA_SIG_get0_s(Sig); + auto RSize = static_cast(BN_num_bytes(R)); + auto SSize = static_cast(BN_num_bytes(S)); + std::vector Res(RSize + SSize); + opensslCheck(BN_bn2bin(R, Res.data())); + opensslCheck(BN_bn2bin(S, Res.data() + RSize)); + + return Res; +} + +SharedEvpPkey::~SharedEvpPkey() noexcept { + if (Pkey != nullptr) { + EVP_PKEY_free(Pkey); + Pkey = nullptr; + } +} + +SharedEvpPkey::SharedEvpPkey(const SharedEvpPkey &Rhs) noexcept + : Pkey(Rhs.Pkey) { + if (Rhs.Pkey != nullptr) { + EVP_PKEY_up_ref(Pkey); + } +} + +SharedEvpPkey::SharedEvpPkey(SharedEvpPkey &&Rhs) noexcept : Pkey(Rhs.Pkey) { + Rhs.Pkey = nullptr; +} + +EVP_PKEY *SharedEvpPkey::get() const noexcept { return Pkey; } + +SharedEvpPkey::operator bool() const noexcept { return Pkey != nullptr; } + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/evp_wrapper.h b/plugins/wasi_crypto/utils/evp_wrapper.h new file mode 100644 index 00000000..601fd7a1 --- /dev/null +++ b/plugins/wasi_crypto/utils/evp_wrapper.h @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/utils/evp_wrapper.h - Evp Wrapper ----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of OpenSSL evp relative function. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" +#include "utils/secret_vec.h" + +#include "common/log.h" +#include "common/span.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +/// Helper alias for OpenSSL relative struct. +template using Deleter = std::integral_constant; + +template +using OpenSSLUniquePtr = std::unique_ptr>; + +using EvpMdCtxPtr = OpenSSLUniquePtr; +using EvpPkeyCtxPtr = OpenSSLUniquePtr; +using EvpCipherCtxPtr = OpenSSLUniquePtr; +using EvpPkeyPtr = OpenSSLUniquePtr; +using BioPtr = OpenSSLUniquePtr; +using EcKeyPtr = OpenSSLUniquePtr; +using BnPtr = OpenSSLUniquePtr; +using EcPointPtr = OpenSSLUniquePtr; +using EcdsaSigPtr = OpenSSLUniquePtr; +using RsaPtr = OpenSSLUniquePtr; + +/// OpenSSL functions always return 1 for success and 0/NULL for failure. This +/// is used to reduce repeating checking. +#ifdef NDEBUG +#define opensslCheck(Cond) \ + do { \ + if (!(Cond)) { \ + ERR_print_errors_cb( \ + [](const char *_Str, size_t, void *) { \ + spdlog::error(_Str); \ + return 1; \ + }, \ + nullptr); \ + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); \ + } \ + } while (0) +#else +#define opensslCheck(Cond) \ + (static_cast(Cond) \ + ? static_cast(0) \ + : (ERR_print_errors_cb( \ + [](const char *_Str, size_t, void *) { \ + spdlog::error(_Str); \ + return 1; \ + }, \ + nullptr), \ + OPENSSL_die("assertion failed: " #Cond, __FILE__, __LINE__))) +#endif + +/// OpenSSL encoding parse api is too confusing, simplify them. +/// For example, `PEM_read_bio_PUBKEY` is equal to `pemReadPUBKEY`. + +// ------------------------------------------------------------------------- // +EVP_PKEY *pemReadPUBKEY(Span Encoded); + +WasiCryptoExpect> pemWritePUBKEY(EVP_PKEY *Key); + +EVP_PKEY *pemReadPrivateKey(Span Encoded); + +WasiCryptoExpect pemWritePrivateKey(EVP_PKEY *Key); + +EVP_PKEY *d2iPUBKEY(Span Encoded); + +WasiCryptoExpect> i2dPUBKEY(EVP_PKEY *Key); + +EVP_PKEY *d2iPrivateKey(Span Encoded); + +WasiCryptoExpect i2dPrivateKey(EVP_PKEY *Key); + +ECDSA_SIG *d2iEcdsaSig(Span Encoded); + +WasiCryptoExpect> i2dEcdsaSig(ECDSA_SIG *Sig); + +// ------------------------------------------------------------------------- // +// Transform raw represent ecdsa ( r | s) to ECDSA_SIG. Need to check `nullptr`. +ECDSA_SIG *o2iEcdsaSig(Span Encoded); + +// Transform ECDSA_SIG to raw represent ( r | s). +WasiCryptoExpect> i2oEcdsaSig(ECDSA_SIG *Sig); + +// This is a wrapper for EVP_PKEY, since EVP_PKEY inner use lock to guarantee +// thread-safe `EVP_PKEY_up_ref` (you will find them in crypto/evp/p_lib.c in +// OpenSSL v1.1.1), use shared_ptr for `EVP_PKEY` is wasted. +// It only provide limits function to correct use. +class SharedEvpPkey { +public: + SharedEvpPkey(EvpPkeyPtr Pkey) noexcept : Pkey(Pkey.release()) {} + ~SharedEvpPkey() noexcept; + + SharedEvpPkey(const SharedEvpPkey &Rhs) noexcept; + SharedEvpPkey(SharedEvpPkey &&Rhs) noexcept; + // Assigning to existing SharedEvpPkey is not thread-safe, delete them. + SharedEvpPkey &operator=(const SharedEvpPkey &Rhs) noexcept = delete; + SharedEvpPkey &operator=(SharedEvpPkey &&Rhs) noexcept = delete; + + EVP_PKEY *get() const noexcept; + + explicit operator bool() const noexcept; + +private: + EVP_PKEY *Pkey; +}; + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/handles_manager.h b/plugins/wasi_crypto/utils/handles_manager.h new file mode 100644 index 00000000..ba01f157 --- /dev/null +++ b/plugins/wasi_crypto/utils/handles_manager.h @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/utils/handles_manager.h --------------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the class definitions of the WasiCrypto HandlesManager. +/// It controls the handle and the inner states. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace detail { + +/// The Handles Manager base class. +/// +/// @tparam HandleType This is the type of handle, notice it must be `32-bit +/// long`. +/// @tparam ManagerType The managed content type. +/// +/// HandlesManager uses handle as index to represent the managed contents. +/// +/// Referenced from: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/implementations/hostcalls/rust/src/handles.rs +template class BaseHandlesManager { +public: + BaseHandlesManager(const BaseHandlesManager &) noexcept = delete; + BaseHandlesManager &operator=(const BaseHandlesManager &) noexcept = delete; + BaseHandlesManager(BaseHandlesManager &&) noexcept = delete; + BaseHandlesManager &operator=(BaseHandlesManager &&) noexcept = delete; + + /// @param TypeID A unique number + BaseHandlesManager(uint8_t TypeID) noexcept : LastHandle{TypeID, 0} {} + + WasiCryptoExpect close(HandleType Handle) noexcept { + std::unique_lock Lock{Mutex}; + + if (!Map.erase(HandleWrapper(Handle))) + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_CLOSED); + + return {}; + } + + /// Constructor a new manager. + template + WasiCryptoExpect registerManager(Args &&...Manager) noexcept { + std::unique_lock Lock{Mutex}; + + // Find a handle that can be used and emplace. + // Assume that LastHandle is 0x01000000, NextHandle is 0x01000001. + auto NextHandle = LastHandle.nextHandle(); + while (true) { + // Try to emplace NextHandle. + if (Map.try_emplace(NextHandle, std::forward(Manager)...).second) { + // If success, emplace which indicate the NextHandle not exists in the + // managed content. Update the last handle and return it. + LastHandle = NextHandle; + return LastHandle.Handle; + } + // Otherwise, the NextHandle Map already exists a content, call NextHandle + // and loop. + NextHandle = NextHandle.nextHandle(); + + // If after looping `many times(2^24 - 1)`, we get 0x01000000 again. + if (NextHandle == LastHandle) { + // It indicates the hashmap is full. + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_TOO_MANY_HANDLES); + } + } + } + +protected: + /// The handle internal representation as [-TypeID-|------CurrentNumber------] + union HandleWrapper { + static_assert(sizeof(HandleType) == 4, "HandleType must be 4 byte"); + HandleWrapper(uint8_t TypeID, uint32_t CurrentNumber) noexcept + : TypeID(TypeID), CurrentNumber(CurrentNumber) {} + explicit HandleWrapper(HandleType Handle) : Handle(Handle) {} + + HandleWrapper nextHandle() noexcept { + return {TypeID, static_cast(CurrentNumber + 1)}; + } + + struct Hash { + size_t operator()(const HandleWrapper &Wrapper) const noexcept { + return static_cast(Wrapper.Handle); + } + }; + + bool operator==(const HandleWrapper &Wrapper) const noexcept { + return Wrapper.Handle == this->Handle; + } + + struct { + uint8_t TypeID : 8; + uint32_t CurrentNumber : 24; + }; + HandleType Handle; + }; + + std::shared_mutex Mutex; + HandleWrapper LastHandle; + std::unordered_map + Map; +}; + +template struct IsVariantMember; +template +struct IsVariantMember> + : public std::disjunction...> {}; + +} // namespace detail + +/// ManagerType need reference count. +template , bool> = + false> +class RcHandlesManager + : public detail::BaseHandlesManager { + using HandleWrapper = + typename detail::BaseHandlesManager::HandleWrapper; + +public: + /// Get the return copy. + WasiCryptoExpect get(HandleType Handle) noexcept { + std::shared_lock Lock{this->Mutex}; + + auto HandleValue = this->Map.find(HandleWrapper(Handle)); + if (HandleValue == this->Map.end()) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + return HandleValue->second; + } + + /// Get as different variant type. + template + WasiCryptoExpect getAs(HandleType Handle) noexcept { + std::shared_lock Lock{this->Mutex}; + + auto HandleValue = this->Map.find(HandleWrapper(Handle)); + if (HandleValue == this->Map.end()) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + return std::visit( + [](auto &&Value) noexcept -> WasiCryptoExpect { + using T = std::decay_t; + if constexpr (detail::IsVariantMember::value) { + + return Value; + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + }, + HandleValue->second); + } +}; + +/// ManagerType just use reference. +template +class RefHandlesManager + : public detail::BaseHandlesManager { + using HandleWrapper = + typename detail::BaseHandlesManager::HandleWrapper; + +public: + /// Get the return reference. + WasiCryptoExpect> + get(HandleType Handle) noexcept { + std::shared_lock Lock{this->Mutex}; + + auto HandleValue = this->Map.find(HandleWrapper(Handle)); + if (HandleValue == this->Map.end()) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); + } + return HandleValue->second; + } +}; + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/hostfunction.cpp b/plugins/wasi_crypto/utils/hostfunction.cpp new file mode 100644 index 00000000..be77896d --- /dev/null +++ b/plugins/wasi_crypto/utils/hostfunction.cpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "utils/hostfunction.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +using namespace std::literals; + +namespace { +std::string toUpper(std::string_view Name) noexcept { + std::string Ret{Name}; + std::transform(Ret.begin(), Ret.end(), Ret.begin(), + [](char C) { return std::toupper(C); }); + return Ret; +} +} // namespace + +WasiCryptoExpect +tryFrom(__wasi_algorithm_type_e_t AlgType, + std::string_view RawAlgStr) noexcept { + std::string AlgStr = toUpper(RawAlgStr); + // Delegate to sig and kx. + switch (AlgType) { + case __WASI_ALGORITHM_TYPE_SIGNATURES: { + return tryFrom(AlgStr).map([](auto Alg) noexcept { + return std::visit( + [](auto Factory) noexcept -> AsymmetricCommon::Algorithm { + return Factory; + }, + Alg); + }); + } + case __WASI_ALGORITHM_TYPE_KEY_EXCHANGE: { + return tryFrom(AlgStr).map([](auto Alg) noexcept { + return std::visit( + [](auto Factory) noexcept -> AsymmetricCommon::Algorithm { + return Factory; + }, + Alg); + }); + } + case __WASI_ALGORITHM_TYPE_SYMMETRIC: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + default: + assumingUnreachable(); + } +} + +template <> +WasiCryptoExpect tryFrom(std::string_view RawAlgStr) noexcept { + using namespace Kx; + std::string AlgStr = toUpper(RawAlgStr); + if (AlgStr == "X25519"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "P256-SHA256"sv) { + return Algorithm{std::in_place_type}; + } + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); +} + +template <> +WasiCryptoExpect +tryFrom(std::string_view RawAlgStr) noexcept { + using namespace Symmetric; + std::string AlgStr = toUpper(RawAlgStr); + if (AlgStr == "SHA-256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "SHA-512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "SHA-512/256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "HMAC/SHA-256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "HMAC/SHA-512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "HKDF-EXPAND/SHA-256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "HKDF-EXTRACT/SHA-256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "HKDF-EXPAND/SHA-512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "HKDF-EXTRACT/SHA-512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "AES-128-GCM"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "AES-256-GCM"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "CHACHA20-POLY1305"sv) { + return Algorithm{std::in_place_type}; + } + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); +} + +template <> +WasiCryptoExpect +tryFrom(std::string_view RawAlgStr) noexcept { + using namespace Signatures; + std::string AlgStr = toUpper(RawAlgStr); + if (AlgStr == "ECDSA_P256_SHA256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "ECDSA_K256_SHA256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "ED25519"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PKCS1_2048_SHA256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PKCS1_2048_SHA384"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PKCS1_2048_SHA512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PKCS1_3072_SHA384"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PKCS1_3072_SHA512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PKCS1_4096_SHA512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PSS_2048_SHA256"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PSS_2048_SHA384"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PSS_2048_SHA512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PSS_3072_SHA384"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PSS_3072_SHA512"sv) { + return Algorithm{std::in_place_type}; + } + if (AlgStr == "RSA_PSS_4096_SHA512"sv) { + return Algorithm{std::in_place_type}; + } + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/hostfunction.h b/plugins/wasi_crypto/utils/hostfunction.h new file mode 100644 index 00000000..78c1f375 --- /dev/null +++ b/plugins/wasi_crypto/utils/hostfunction.h @@ -0,0 +1,160 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/hostfunc.h - HostFunction class ------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the HostFunction classes and some helper functions +/// interact with wasi. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "ctx.h" +#include "symmetric/registed.h" +#include "utils/error.h" + +#include "runtime/hostfunc.h" +#include "runtime/instance/memory.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +/// The wasi-crypto's base HostFunction class, every wasi-crypto HostFunction +/// should inherit from this class. +template class HostFunction : public Runtime::HostFunction { +public: + HostFunction(Context &Ctx) : Runtime::HostFunction(0), Ctx(Ctx) {} + +protected: + Context &Ctx; +}; + +/// Cast wasi enum to c++ enum. +template constexpr WasiCryptoExpect cast(uint64_t) noexcept; + +template struct WasiRawType { + using Type = std::underlying_type_t; +}; +template <> struct WasiRawType { using Type = uint8_t; }; +template <> struct WasiRawType { using Type = uint16_t; }; +template <> struct WasiRawType { using Type = uint32_t; }; +template <> struct WasiRawType { using Type = uint64_t; }; + +template using WasiRawTypeT = typename WasiRawType::Type; + +template <> +constexpr WasiCryptoExpect<__wasi_algorithm_type_e_t> +cast(uint64_t AlgType) noexcept { + switch (static_cast>(AlgType)) { + case __WASI_ALGORITHM_TYPE_SIGNATURES: + case __WASI_ALGORITHM_TYPE_SYMMETRIC: + case __WASI_ALGORITHM_TYPE_KEY_EXCHANGE: + return static_cast<__wasi_algorithm_type_e_t>(AlgType); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); + } +} + +template <> +constexpr WasiCryptoExpect<__wasi_keypair_encoding_e_t> +cast(uint64_t Encoding) noexcept { + switch (static_cast>(Encoding)) { + case __WASI_KEYPAIR_ENCODING_RAW: + case __WASI_KEYPAIR_ENCODING_PKCS8: + case __WASI_KEYPAIR_ENCODING_PEM: + case __WASI_KEYPAIR_ENCODING_COMPRESSED_PKCS8: + case __WASI_KEYPAIR_ENCODING_COMPRESSED_PEM: + case __WASI_KEYPAIR_ENCODING_LOCAL: + return static_cast<__wasi_keypair_encoding_e_t>(Encoding); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template <> +constexpr WasiCryptoExpect<__wasi_publickey_encoding_e_t> +cast(uint64_t Encoding) noexcept { + switch (static_cast>(Encoding)) { + case __WASI_PUBLICKEY_ENCODING_RAW: + case __WASI_PUBLICKEY_ENCODING_PKCS8: + case __WASI_PUBLICKEY_ENCODING_PEM: + case __WASI_PUBLICKEY_ENCODING_SEC: + case __WASI_PUBLICKEY_ENCODING_COMPRESSED_SEC: + case __WASI_PUBLICKEY_ENCODING_COMPRESSED_PKCS8: + case __WASI_PUBLICKEY_ENCODING_COMPRESSED_PEM: + case __WASI_PUBLICKEY_ENCODING_LOCAL: + return static_cast<__wasi_publickey_encoding_e_t>(Encoding); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template <> +constexpr WasiCryptoExpect<__wasi_secretkey_encoding_e_t> +cast(uint64_t Encoding) noexcept { + switch (static_cast>(Encoding)) { + case __WASI_SECRETKEY_ENCODING_RAW: + case __WASI_SECRETKEY_ENCODING_PKCS8: + case __WASI_SECRETKEY_ENCODING_PEM: + case __WASI_SECRETKEY_ENCODING_SEC: + case __WASI_SECRETKEY_ENCODING_LOCAL: + return static_cast<__wasi_secretkey_encoding_e_t>(Encoding); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +template <> +constexpr WasiCryptoExpect<__wasi_signature_encoding_e_t> +cast(uint64_t Encoding) noexcept { + switch (static_cast>(Encoding)) { + case __WASI_SIGNATURE_ENCODING_RAW: + case __WASI_SIGNATURE_ENCODING_DER: + return static_cast<__wasi_signature_encoding_e_t>(Encoding); + default: + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); + } +} + +/// Cast c++ size_t to wasi size_t. +constexpr WasiCryptoExpect<__wasi_size_t> toWasiSize(size_t Size) noexcept { + ensureOrReturn(Size <= std::numeric_limits<__wasi_size_t>::max(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + return static_cast<__wasi_size_t>(Size); +} + +/// Convert string_view to inner Alg representation. +template WasiCryptoExpect tryFrom(std::string_view) noexcept; + +template <> +WasiCryptoExpect +tryFrom(std::string_view RawAlgStr) noexcept; + +WasiCryptoExpect +tryFrom(__wasi_algorithm_type_e_t AlgType, std::string_view RawAlgStr) noexcept; + +template <> +WasiCryptoExpect tryFrom(std::string_view RawAlgStr) noexcept; + +template <> +WasiCryptoExpect +tryFrom(std::string_view RawAlgStr) noexcept; + +/// Check exist or return `_algorithm_failure`. +#define checkExist(Expr) \ + do { \ + if (unlikely(!(Expr))) { \ + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; \ + } \ + } while (0) + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/optional.h b/plugins/wasi_crypto/utils/optional.h new file mode 100644 index 00000000..ea8e93f6 --- /dev/null +++ b/plugins/wasi_crypto/utils/optional.h @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/utils/handles_manager.h - OptionalRef ===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of the OptionalRef and some helper +/// functions. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" + +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +template using OptionalRef = T *; + +namespace detail { + +template struct IsOptional : std::false_type {}; +template struct IsOptional> : std::true_type { + using Type = T; +}; + +template struct IsExpected : std::false_type {}; +template struct IsExpected> : std::true_type { + using Type = T; +}; + +template struct IsOptionalRef : std::false_type {}; +template struct IsOptionalRef> : std::true_type { + using Type = T; +}; + +template struct IsExpectedOptionalRef : std::false_type {}; +template +struct IsExpectedOptionalRef>> + : std::true_type { + using Type = T; +}; + +} // namespace detail + +/// std::optional -> (T -> WasiCrypto) -> +/// WasiCryptoExpect> +template +inline auto mapAndTransposeOptional(const __wasi_opt_options_t Optional, + F &&Function) noexcept + -> std::enable_if_t< + detail::IsExpected(Function), Optional.u.some))>>::value, + WasiCryptoExpect(Function), Optional.u.some))>>::Type>>> { + if (Optional.tag == __WASI_OPT_OPTIONS_U_NONE) + return std::nullopt; + + return std::invoke(std::forward(Function), Optional.u.some); +} + +template +inline auto mapAndTransposeOptional(const __wasi_opt_symmetric_key_t Optional, + F &&Function) noexcept + -> std::enable_if_t< + detail::IsExpected(Function), Optional.u.some))>>::value, + WasiCryptoExpect(Function), Optional.u.some))>>::Type>>> { + if (Optional.tag == __WASI_OPT_SYMMETRIC_KEY_U_NONE) + return std::nullopt; + + return std::invoke(std::forward(Function), Optional.u.some); +} + +/// std::optional -> (T -> WasiCryptoExpect>) -> +/// WasiCryptoExpect> +template < + typename O, typename F, + typename = std::enable_if_t>::value>> +inline auto transposeOptionalToRef(O &&Optional, F &&Function) noexcept + -> WasiCryptoExpect(Function), *std::forward(Optional)))>>::Type>> { + if (!Optional) + return nullptr; + + return std::invoke(std::forward(Function), *Optional); +} + +/// OptionalRef -> (T -> WasiCryptoExpect>) -> +/// WasiCryptoExpect> +template < + typename O, typename F, + typename = std::enable_if_t>::value>> +inline auto transposeOptionalRef(O &&Optional, F &&Function) noexcept + -> WasiCryptoExpect(Function), *std::forward(Optional)))>>::Type>> { + if (!Optional) + return nullptr; + + return std::invoke(std::forward(Function), *Optional); +} + +/// std::optional -> OptionalRef +template >::value>> +inline auto asOptionalRef(O &&Optional) noexcept + -> OptionalRef>::Type> { + if (!Optional) + return nullptr; + + return &*Optional; +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/secret_vec.h b/plugins/wasi_crypto/utils/secret_vec.h new file mode 100644 index 00000000..b329046b --- /dev/null +++ b/plugins/wasi_crypto/utils/secret_vec.h @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/utils/secret_vec.h - Secret Vec def --===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the definition of the secret vec. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "utils/error.h" + +#include "common/span.h" + +#include +#include + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +/// A vector wrapper, but swipe the secret key info on destory. +class SecretVec { +public: + SecretVec(const SecretVec &) = default; + SecretVec &operator=(const SecretVec &) = default; + SecretVec &operator=(SecretVec &&) noexcept = default; + SecretVec(SecretVec &&) noexcept = default; + + SecretVec(Span Data) noexcept + : Data(Data.begin(), Data.end()) {} + + SecretVec(size_t Size) noexcept : Data(Size) {} + + ~SecretVec() noexcept { OPENSSL_cleanse(Data.data(), Data.size()); } + + auto begin() noexcept { return Data.begin(); } + auto begin() const noexcept { return Data.begin(); } + + auto end() noexcept { return Data.end(); } + auto end() const noexcept { return Data.end(); } + + auto size() const noexcept { return Data.size(); } + + auto data() noexcept { return Data.data(); } + auto data() const noexcept { return Data.data(); } + + using difference_type = std::vector::difference_type; + + /// Generate random size vector. Notice that the size shouldn't beyond + /// std::numeric_limits::max() because of the limitations of openssl. + template static WasiCryptoExpect random() noexcept { + static_assert( + Size <= std::numeric_limits::max(), + "Random key size shouldn't beyond std::numeric_limits::max()"); + + SecretVec Res(Size); + ensureOrReturn(RAND_bytes(Res.data(), static_cast(Size)), + __WASI_CRYPTO_ERRNO_RNG_ERROR); + return Res; + } + +private: + std::vector Data; +}; + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 7f920c0e..1b8f8f04 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -4,6 +4,9 @@ if (CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasmedge_process) endif() -if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) +if (WASMEDGE_PLUGIN_WASI_NN_BACKEND) add_subdirectory(wasi_nn) endif() +if (WASMEDGE_PLUGIN_WASI_CRYPTO) + add_subdirectory(wasi_crypto) +endif() diff --git a/test/plugins/wasi_crypto/CMakeLists.txt b/test/plugins/wasi_crypto/CMakeLists.txt new file mode 100644 index 00000000..acbed2b8 --- /dev/null +++ b/test/plugins/wasi_crypto/CMakeLists.txt @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_executable(wasiCryptoTests + aeads.cpp + asymmetric.cpp + common.cpp + hash.cpp + helper.cpp + kdf.cpp + kx.cpp + mac.cpp + notimplement.cpp + signatures.cpp +) + +target_link_libraries(wasiCryptoTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} + wasmedgePlugin + wasmedgePluginWasiCrypto +) + +add_test(wasiCryptoTests wasiCryptoTests) diff --git a/test/plugins/wasi_crypto/aeads.cpp b/test/plugins/wasi_crypto/aeads.cpp new file mode 100644 index 00000000..8d449adc --- /dev/null +++ b/test/plugins/wasi_crypto/aeads.cpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Aeads) { + auto AeadsTest = [this](std::string_view Name, + const std::vector &Nonce, size_t MaxTagSize, + const std::vector &Msg) { + SCOPED_TRACE(Name); + + WASI_CRYPTO_EXPECT_SUCCESS(KeyHandle, + symmetricKeyGenerate(Name, std::nullopt)); + WASI_CRYPTO_EXPECT_SUCCESS(OptionsHandle, + optionsOpen(__WASI_ALGORITHM_TYPE_SYMMETRIC)); + // Repeatedly set and overwrite the previous option. + WASI_CRYPTO_EXPECT_TRUE(optionsSet(OptionsHandle, "nonce"sv, "nonce"_u8)); + WASI_CRYPTO_EXPECT_TRUE(optionsSet(OptionsHandle, "nonce"sv, Nonce)); + WASI_CRYPTO_EXPECT_SUCCESS( + State1Handle, symmetricStateOpen(Name, KeyHandle, OptionsHandle)); + + // State nonce equal to the previous set one. + std::vector ObservedNonce(Nonce.size()); + symmetricStateOptionsGet(State1Handle, "nonce"sv, ObservedNonce); + EXPECT_EQ(ObservedNonce, Nonce); + WASI_CRYPTO_EXPECT_SUCCESS(TagSize, symmetricStateMaxTagLen(State1Handle)); + EXPECT_EQ(TagSize, MaxTagSize); + + std::vector CiphertextWithTag(Msg.size() + MaxTagSize); + + WASI_CRYPTO_EXPECT_SUCCESS( + OutTagSize, + symmetricStateEncrypt(State1Handle, CiphertextWithTag, Msg)); + EXPECT_EQ(OutTagSize, CiphertextWithTag.size()); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(State1Handle)); + + { + WASI_CRYPTO_EXPECT_SUCCESS( + State2Handle, symmetricStateOpen(Name, KeyHandle, OptionsHandle)); + std::vector Msg2(Msg.size()); + WASI_CRYPTO_EXPECT_SUCCESS( + OutputMsg2Size, + symmetricStateDecrypt(State2Handle, Msg2, CiphertextWithTag)); + EXPECT_EQ(OutputMsg2Size, Msg2.size()); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(State2Handle)); + EXPECT_EQ(Msg2, Msg); + } + + WASI_CRYPTO_EXPECT_SUCCESS( + State3Handle, symmetricStateOpen(Name, KeyHandle, OptionsHandle)); + std::vector Ciphertext(Msg.size()); + WASI_CRYPTO_EXPECT_SUCCESS(TagHandle, symmetricStateEncryptDetached( + State3Handle, Ciphertext, Msg)); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(State3Handle)); + std::vector Tag(MaxTagSize); + WASI_CRYPTO_EXPECT_SUCCESS(OutputTagSize, symmetricTagPull(TagHandle, Tag)); + EXPECT_EQ(OutputTagSize, MaxTagSize); + EXPECT_EQ(Tag, + std::vector( + CiphertextWithTag.begin() + + static_cast( + Msg.size()), + CiphertextWithTag.end())); + + WASI_CRYPTO_EXPECT_SUCCESS( + State4Handle, symmetricStateOpen(Name, KeyHandle, OptionsHandle)); + std::vector Msg3(Msg.size()); + symmetricStateDecryptDetached(State4Handle, Msg3, Ciphertext, Tag); + EXPECT_EQ("test"_u8, Msg3); + WASI_CRYPTO_EXPECT_TRUE(optionsClose(OptionsHandle)); + + { + // Some error cases checking. + EXPECT_TRUE( + symmetricStateOpen(Name, InvaildHandle, std::nullopt).error() == + __WASI_CRYPTO_ERRNO_INVALID_HANDLE); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOptionsGet(State4Handle, "foo"sv, {}), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOptionsGetU64(State4Handle, "foo"sv), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueezeTag(State4Handle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueezeKey(State4Handle, Name), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueeze(State4Handle, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateRatchet(State4Handle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + { + // Clone checking. + WASI_CRYPTO_EXPECT_SUCCESS(NewStateHandle, + symmetricStateClone(State4Handle)); + EXPECT_NE(State4Handle, NewStateHandle); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(NewStateHandle)); + } + + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(State4Handle)); + }; + + AeadsTest("AES-128-GCM"sv, std::vector(12, 42), 16, "test"_u8); + AeadsTest("AES-256-GCM"sv, std::vector(12, 42), 16, "test"_u8); + AeadsTest("CHACHA20-POLY1305"sv, std::vector(12, 42), 16, "test"_u8); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/asymmetric.cpp b/test/plugins/wasi_crypto/asymmetric.cpp new file mode 100644 index 00000000..c34f3a8d --- /dev/null +++ b/test/plugins/wasi_crypto/asymmetric.cpp @@ -0,0 +1,605 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Asymmetric) { + auto EncodingCheck = + [this](std::string_view Alg, __wasi_algorithm_type_e_t AlgType, + std::map<__wasi_publickey_encoding_e_t, std::vector> + SupportPk, + std::map<__wasi_secretkey_encoding_e_t, std::vector> + SupportSk, + std::map<__wasi_keypair_encoding_e_t, std::vector> + SupportKp) { + SCOPED_TRACE(Alg); + + // Function checking. + { + WASI_CRYPTO_EXPECT_SUCCESS( + KpHandle, keypairGenerate(AlgType, Alg, std::nullopt)); + WASI_CRYPTO_EXPECT_SUCCESS(PkHandle, keypairPublickey(KpHandle)); + WASI_CRYPTO_EXPECT_SUCCESS(SkHandle, keypairSecretkey(KpHandle)); + WASI_CRYPTO_EXPECT_TRUE(keypairClose(KpHandle)); + WASI_CRYPTO_EXPECT_TRUE(publickeyClose(PkHandle)); + WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(SkHandle)); + } + + // Encoding checking. + for (auto &&[PkEncoding, Pk] : SupportPk) { + SCOPED_TRACE("Public key encoding"); + SCOPED_TRACE(PkEncoding); + WASI_CRYPTO_EXPECT_SUCCESS( + PkHandle, publickeyImport(AlgType, Alg, Pk, PkEncoding)); + + std::vector ExportPk(Pk.size()); + WASI_CRYPTO_EXPECT_SUCCESS(PkOutputHandle, + publickeyExport(PkHandle, PkEncoding)); + WASI_CRYPTO_EXPECT_TRUE(arrayOutputPull(PkOutputHandle, ExportPk)); + EXPECT_EQ(ExportPk, Pk); + } + for (auto &&[SkEncoding, Sk] : SupportSk) { + SCOPED_TRACE("Secret key encoding"); + SCOPED_TRACE(SkEncoding); + WASI_CRYPTO_EXPECT_SUCCESS( + SkHandle, secretkeyImport(AlgType, Alg, Sk, SkEncoding)); + + std::vector ExportSk(Sk.size()); + WASI_CRYPTO_EXPECT_SUCCESS(SkOutputHandle, + secretkeyExport(SkHandle, SkEncoding)); + WASI_CRYPTO_EXPECT_TRUE(arrayOutputPull(SkOutputHandle, ExportSk)); + EXPECT_EQ(ExportSk, Sk); + } + for (auto &&[KpEncoding, Kp] : SupportKp) { + SCOPED_TRACE("Key Pair encoding"); + SCOPED_TRACE(KpEncoding); + WASI_CRYPTO_EXPECT_SUCCESS( + KpHandle, keypairImport(AlgType, Alg, Kp, KpEncoding)); + + std::vector ExportKp(Kp.size()); + WASI_CRYPTO_EXPECT_SUCCESS(KpOutputHandle, + keypairExport(KpHandle, KpEncoding)); + WASI_CRYPTO_EXPECT_TRUE(arrayOutputPull(KpOutputHandle, ExportKp)); + EXPECT_EQ(ExportKp, Kp); + } + }; + EncodingCheck( + "X25519"sv, __WASI_ALGORITHM_TYPE_KEY_EXCHANGE, + {{__WASI_PUBLICKEY_ENCODING_RAW, + "8520f0098930a754748b7ddcb43ef75a0dbf3a0d26381af4eba4a98eaa9b4e6a"_u8v}}, + {{__WASI_SECRETKEY_ENCODING_RAW, + "77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a"_u8v}}, + {{__WASI_KEYPAIR_ENCODING_RAW, + "8520f0098930a754748b7ddcb43ef75a0dbf3a0d26381af4eba4a98eaa9b4e6a77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a"_u8v}}); + EncodingCheck( + "ECDSA_P256_SHA256"sv, __WASI_ALGORITHM_TYPE_SIGNATURES, + {{__WASI_PUBLICKEY_ENCODING_SEC, + "0460FED4BA255A9D31C961EB74C6356D68C049B8923B61FA6CE669622E60F29FB67903FE1008B8BC99A41AE9E95628BC64F2F1B20C2D7E9F5177A3C294D4462299"_u8v}, + {__WASI_PUBLICKEY_ENCODING_PKCS8, + "3059301306072a8648ce3d020106082a8648ce3d0301070342000460FED4BA255A9D31C961EB74C6356D68C049B8923B61FA6CE669622E60F29FB67903FE1008B8BC99A41AE9E95628BC64F2F1B20C2D7E9F5177A3C294D4462299"_u8v}, + {__WASI_PUBLICKEY_ENCODING_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEYP7UuiVanTHJYet0xjVtaMBJuJI7\n" + "Yfps5mliLmDyn7Z5A/4QCLi8maQa6elWKLxk8vGyDC1+n1F3o8KU1EYimQ==\n" + "-----END PUBLIC KEY-----\n"_u8}, + {__WASI_PUBLICKEY_ENCODING_COMPRESSED_SEC, + "0360FED4BA255A9D31C961EB74C6356D68C049B8923B61FA6CE669622E60F29FB6"_u8v}, + {__WASI_PUBLICKEY_ENCODING_COMPRESSED_PKCS8, + "3039301306072a8648ce3d020106082a8648ce3d0301070322000360FED4BA255A9D31C961EB74C6356D68C049B8923B61FA6CE669622E60F29FB6"_u8v}, + {__WASI_PUBLICKEY_ENCODING_COMPRESSED_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MDkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDIgADYP7UuiVanTHJYet0xjVtaMBJuJI7\n" + "Yfps5mliLmDyn7Y=\n" + "-----END PUBLIC KEY-----\n"_u8}}, + {{__WASI_SECRETKEY_ENCODING_RAW, + "C9AFA9D845BA75166B5C215767B1D6934E50C3DB36E89B127B8A622B120F6721"_u8v}, + {__WASI_SECRETKEY_ENCODING_PEM, + "-----BEGIN PRIVATE KEY-----\n" + "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgya+p2EW6dRZrXCFX\n" + "Z7HWk05Qw9s26JsSe4piKxIPZyGhRANCAARg/tS6JVqdMclh63TGNW1owEm4kjth\n" + "+mzmaWIuYPKftnkD/hAIuLyZpBrp6VYovGTy8bIMLX6fUXejwpTURiKZ\n" + "-----END PRIVATE KEY-----\n"_u8}}, + {}); + EncodingCheck( + "ECDSA_K256_SHA256"sv, __WASI_ALGORITHM_TYPE_SIGNATURES, + {{__WASI_PUBLICKEY_ENCODING_SEC, + "04b838ff44e5bc177bf21189d0766082fc9d843226887fc9760371100b7ee20a6ff0c9d75bfba7b31a6bca1974496eeb56de357071955d83c4b1badaa0b21832e9"_u8v}, + {__WASI_PUBLICKEY_ENCODING_PKCS8, + "3056301006072a8648ce3d020106052b8104000a03420004b838ff44e5bc177bf21189d0766082fc9d843226887fc9760371100b7ee20a6ff0c9d75bfba7b31a6bca1974496eeb56de357071955d83c4b1badaa0b21832e9"_u8v}, + {__WASI_PUBLICKEY_ENCODING_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MFYwEAYHKoZIzj0CAQYFK4EEAAoDQgAEuDj/ROW8F3vyEYnQdmCC/J2EMiaIf8l2\n" + "A3EQC37iCm/wyddb+6ezGmvKGXRJbutW3jVwcZVdg8Sxutqgshgy6Q==\n" + "-----END PUBLIC KEY-----"_u8}, + {__WASI_PUBLICKEY_ENCODING_COMPRESSED_SEC, + "03b838ff44e5bc177bf21189d0766082fc9d843226887fc9760371100b7ee20a6f"_u8v}, + {__WASI_PUBLICKEY_ENCODING_COMPRESSED_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MDYwEAYHKoZIzj0CAQYFK4EEAAoDIgADuDj/ROW8F3vyEYnQdmCC/J2EMiaIf8l2\n" + "A3EQC37iCm8=\n" + "-----END PUBLIC KEY-----\n"_u8}, + {__WASI_PUBLICKEY_ENCODING_COMPRESSED_PKCS8, + "3036301006072a8648ce3d020106052b8104000a03220003b838ff44e5bc177bf21189d0766082fc9d843226887fc9760371100b7ee20a6f"_u8v}}, + {{__WASI_SECRETKEY_ENCODING_RAW, + "b9aa5c28ef96d750e47f4ba44d5d6a7ac3ab6988d292e7819e362a4b0ac8e250"_u8v}, + {__WASI_SECRETKEY_ENCODING_PKCS8, + "30740201010420b9aa5c28ef96d750e47f4ba44d5d6a7ac3ab6988d292e7819e362a4b" + "0ac8e250a00706052b8104000aa144034200047fef8e21686370c7d343992f14b2d45a" + "262cd6a5c75032736fcbb02f46a99edf0e1d114cdc93956cc75648bfd38fa832a82135" + "d5c2ba634766a8753f6d88aae5"_u8v}, + {__WASI_SECRETKEY_ENCODING_PEM, + "-----BEGIN PRIVATE KEY-----\n" + "MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQguapcKO+W11Dkf0ukTV1q\n" + "esOraYjSkueBnjYqSwrI4lChRANCAAR/744haGNwx9NDmS8UstRaJizWpcdQMnNv\n" + "y7AvRqme3w4dEUzck5Vsx1ZIv9OPqDKoITXVwrpjR2aodT9tiKrl\n" + "-----END PRIVATE KEY-----\n"_u8}}, + {}); + EncodingCheck( + "Ed25519"sv, __WASI_ALGORITHM_TYPE_SIGNATURES, + {{__WASI_PUBLICKEY_ENCODING_RAW, + "d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a"_u8v}}, + {{__WASI_SECRETKEY_ENCODING_RAW, + "9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60"_u8v}}, + {{__WASI_KEYPAIR_ENCODING_RAW, + "9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a"_u8v}}); + + auto RsaCheck = + [EncodingCheck]( + std::string_view Bit, + std::map<__wasi_publickey_encoding_e_t, std::vector> + SupportPk, + std::map<__wasi_secretkey_encoding_e_t, std::vector> + SupportSk, + std::map<__wasi_keypair_encoding_e_t, std::vector> + SupportKp) { + if (Bit == "2048"sv) { + EncodingCheck("RSA_PSS_2048_SHA256"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PSS_2048_SHA384"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PSS_2048_SHA512"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PKCS1_2048_SHA256"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PKCS1_2048_SHA384"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PKCS1_2048_SHA512"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + } + if (Bit == "3072"sv) { + EncodingCheck("RSA_PSS_3072_SHA384"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PSS_3072_SHA512"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PKCS1_3072_SHA384"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PKCS1_3072_SHA512"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + } + if (Bit == "4096"sv) { + EncodingCheck("RSA_PSS_4096_SHA512"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + EncodingCheck("RSA_PKCS1_4096_SHA512"sv, + __WASI_ALGORITHM_TYPE_SIGNATURES, SupportPk, SupportSk, + SupportKp); + } + }; + + RsaCheck( + "2048"sv, + {{__WASI_PUBLICKEY_ENCODING_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAtoGqtGXL6bqbDK0IJ2pJ\n" + "fM7hfMpyYVXWxtEdv5oNErvGWpHKiH1NIeQxn0ks1QFmkHoK8Quvzyn5/anHLxiZ\n" + "ecWbzRA01MTfs1y7HMBlrToZhRTZPFe7VRGJ+Liv1jpIiRHPzqabZrssgs3Kj5fG\n" + "0EYXITaQRfn0kfZcYtJLYi0OvU18DBi64MrLwABr3wqn2UZgMgiw3MhKvyXRybca\n" + "x0ASO1RTxAJIm21XuFWTztHcJBvl66ygDAzzRdOJyPWvG+TuhNXvZ7dtA0N4iU8p\n" + "SwJljzLEzWzwKOgAizx3Q3EdS+9P+pTdKtei9UGWVunoj46kCw+0QasQE958NPa3\n" + "uQIDAQAB\n" + "-----END PUBLIC KEY-----\n"_u8}, + {__WASI_PUBLICKEY_ENCODING_PKCS8, + "30820122300d06092a864886f70d01010105000382010f003082010a0282" + "010100b681aab465cbe9ba9b0cad08276a497ccee17cca726155d6c6d11d" + "bf9a0d12bbc65a91ca887d4d21e4319f492cd50166907a0af10bafcf29f9" + "fda9c72f189979c59bcd1034d4c4dfb35cbb1cc065ad3a198514d93c57bb" + "551189f8b8afd63a488911cfcea69b66bb2c82cdca8f97c6d04617213690" + "45f9f491f65c62d24b622d0ebd4d7c0c18bae0cacbc0006bdf0aa7d94660" + "3208b0dcc84abf25d1c9b71ac740123b5453c402489b6d57b85593ced1dc" + "241be5ebaca00c0cf345d389c8f5af1be4ee84d5ef67b76d034378894f29" + "4b02658f32c4cd6cf028e8008b3c7743711d4bef4ffa94dd2ad7a2f54196" + "56e9e88f8ea40b0fb441ab1013de7c34f6b7b90203010001"_u8v}}, + {{__WASI_SECRETKEY_ENCODING_PEM, + "-----BEGIN PRIVATE KEY-----\n" + "MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC2gaq0ZcvpupsM\n" + "rQgnakl8zuF8ynJhVdbG0R2/mg0Su8ZakcqIfU0h5DGfSSzVAWaQegrxC6/PKfn9\n" + "qccvGJl5xZvNEDTUxN+zXLscwGWtOhmFFNk8V7tVEYn4uK/WOkiJEc/OpptmuyyC\n" + "zcqPl8bQRhchNpBF+fSR9lxi0ktiLQ69TXwMGLrgysvAAGvfCqfZRmAyCLDcyEq/\n" + "JdHJtxrHQBI7VFPEAkibbVe4VZPO0dwkG+XrrKAMDPNF04nI9a8b5O6E1e9nt20D\n" + "Q3iJTylLAmWPMsTNbPAo6ACLPHdDcR1L70/6lN0q16L1QZZW6eiPjqQLD7RBqxAT\n" + "3nw09re5AgMBAAECggEAJds7p3O+GltEshpqKJLZb3QSPapYk2wUwuS5gPbZY1tj\n" + "x4GaOzmSeEc3K80n6X8C4VEPV/SOoTAZ1M4UrOYzX5jnul90NfYoWLIRdeNKs+Xr\n" + "STmL3gJsrzaWIetdPdiVFymEq17PuT11/CPnsmVPLgB7573Dq2AvpN8vRqhMTq6j\n" + "6JoaErGi6nqhfOZ0J7GfwKEJtMGPWbOwyjsMpLxrW3PAy0YH99UZD8Ocxw39hERH\n" + "8i9SGSjk1ty5/NnpC6K8wJmb7BOUXFl1g00FkH9nI/rydjV3Xc6L+vLeLrTvDE59\n" + "uPtRKoFSCbHmxoARkzm8RapX+R/tsJOUxwwPRMjaAQKBgQDsja822AFkjKoDtK4k\n" + "R6R9FKcvW/mUVeKz4ldMQynA+5adU2cyNI7UVRWYcbJI7eSefIaaFe/aHhA2NSny\n" + "83+IhcPfa+z10TBz93CcLs7szYgBddrjYsQ8Q7gc73dGBRqm4oQW+HpuFqPlU5z1\n" + "NZ7eBtlJJBCdxpilR/bC+GyQYQKBgQDFgo2vg1dMKEIm7YfRSXojpRUcAxS4qCJe\n" + "hOcZePkkd1mHmwBkOkX6MMlUozNiCXTmwXjRMWRgZJktbH0I9pwSHCEzJbmfnF5W\n" + "GUwVdfCEvZ4gIcgDZyhsm4A9+L2qC5Nm3dr30NR7c34K1ZpiK3AJTWoI19uEmY46\n" + "1ryOGH5GWQKBgQCxKdAPGCm636q5ScmefFWSJDSuQIkkckpuhNby095inUqJG5zP\n" + "OhO6rNqWqJhpDFpL5GF+520Sg6+KmbiIL5vVaLFxFEiNNhW+1JPvNRNewPPafCTq\n" + "Zd8ob2NlsGc49rumP0HEXmZ7KtOm/j8wWu9Xw/NaVvtm3wUVzFbgYOQWIQKBgQC+\n" + "T2mOcJOxQilbsQxpUM9rgSmx8BYLR5a2VIEJPlNyG74ct/HMoYnD5TZZY1ejY1FM\n" + "96ceiuUZLFWcOyjPdjA0Ev66deNCND2B4KY7F4VFoh+2/lXnUYLWA4+yJvc53iWN\n" + "vL+8gW/78/DDJ8a2SPyPOhStqLBQOFWfxEGy+U7TIQKBgA+cdl5pm1YcVNhcfGbx\n" + "zjN75MLGkq9QfL6zxqWIv4xUsjyYkwGgqwazMeNmi5KvhgpfUM8A8tJQixXmq/oe\n" + "m8MDtOovmQ3I1S6jYK7V8wyxyqgj/2oetPhRIjvht8IaHuFKQkjv6424vdoukuoM\n" + "3o5BHs2kyvh9kuSthBY9XZnN\n" + "-----END PRIVATE KEY-----\n"_u8}, + {__WASI_SECRETKEY_ENCODING_PKCS8, + "308204a40201000282010100b681aab465cbe9ba9b0cad08276a497ccee1" + "7cca726155d6c6d11dbf9a0d12bbc65a91ca887d4d21e4319f492cd50166" + "907a0af10bafcf29f9fda9c72f189979c59bcd1034d4c4dfb35cbb1cc065" + "ad3a198514d93c57bb551189f8b8afd63a488911cfcea69b66bb2c82cdca" + "8f97c6d0461721369045f9f491f65c62d24b622d0ebd4d7c0c18bae0cacb" + "c0006bdf0aa7d946603208b0dcc84abf25d1c9b71ac740123b5453c40248" + "9b6d57b85593ced1dc241be5ebaca00c0cf345d389c8f5af1be4ee84d5ef" + "67b76d034378894f294b02658f32c4cd6cf028e8008b3c7743711d4bef4f" + "fa94dd2ad7a2f5419656e9e88f8ea40b0fb441ab1013de7c34f6b7b90203" + "0100010282010025db3ba773be1a5b44b21a6a2892d96f74123daa58936c" + "14c2e4b980f6d9635b63c7819a3b39927847372bcd27e97f02e1510f57f4" + "8ea13019d4ce14ace6335f98e7ba5f7435f62858b21175e34ab3e5eb4939" + "8bde026caf369621eb5d3dd895172984ab5ecfb93d75fc23e7b2654f2e00" + "7be7bdc3ab602fa4df2f46a84c4eaea3e89a1a12b1a2ea7aa17ce67427b1" + "9fc0a109b4c18f59b3b0ca3b0ca4bc6b5b73c0cb4607f7d5190fc39cc70d" + "fd844447f22f521928e4d6dcb9fcd9e90ba2bcc0999bec13945c5975834d" + "05907f6723faf27635775dce8bfaf2de2eb4ef0c4e7db8fb512a815209b1" + "e6c680119339bc45aa57f91fedb09394c70c0f44c8da0102818100ec8daf" + "36d801648caa03b4ae2447a47d14a72f5bf99455e2b3e2574c4329c0fb96" + "9d536732348ed455159871b248ede49e7c869a15efda1e10363529f2f37f" + "8885c3df6becf5d13073f7709c2eceeccd880175dae362c43c43b81cef77" + "46051aa6e28416f87a6e16a3e5539cf5359ede06d94924109dc698a547f6" + "c2f86c906102818100c5828daf83574c284226ed87d1497a23a5151c0314" + "b8a8225e84e71978f9247759879b00643a45fa30c954a333620974e6c178" + "d131646064992d6c7d08f69c121c213325b99f9c5e56194c1575f084bd9e" + "2021c80367286c9b803df8bdaa0b9366dddaf7d0d47b737e0ad59a622b70" + "094d6a08d7db84998e3ad6bc8e187e465902818100b129d00f1829badfaa" + "b949c99e7c55922434ae408924724a6e84d6f2d3de629d4a891b9ccf3a13" + "baacda96a898690c5a4be4617ee76d1283af8a99b8882f9bd568b1711448" + "8d3615bed493ef35135ec0f3da7c24ea65df286f6365b06738f6bba63f41" + "c45e667b2ad3a6fe3f305aef57c3f35a56fb66df0515cc56e060e4162102" + "818100be4f698e7093b142295bb10c6950cf6b8129b1f0160b4796b65481" + "093e53721bbe1cb7f1cca189c3e536596357a363514cf7a71e8ae5192c55" + "9c3b28cf76303412feba75e342343d81e0a63b178545a21fb6fe55e75182" + "d6038fb226f739de258dbcbfbc816ffbf3f0c327c6b648fc8f3a14ada8b0" + "5038559fc441b2f94ed3210281800f9c765e699b561c54d85c7c66f1ce33" + "7be4c2c692af507cbeb3c6a588bf8c54b23c989301a0ab06b331e3668b92" + "af860a5f50cf00f2d2508b15e6abfa1e9bc303b4ea2f990dc8d52ea360ae" + "d5f30cb1caa823ff6a1eb4f851223be1b7c21a1ee14a4248efeb8db8bdda" + "2e92ea0cde8e411ecda4caf87d92e4ad84163d5d99cd"_u8v}}, + {}); + RsaCheck( + "3072"sv, + {{__WASI_PUBLICKEY_ENCODING_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MIIBojANBgkqhkiG9w0BAQEFAAOCAY8AMIIBigKCAYEAnvafMCYCGPGOhpvpA3FG\n" + "8EtN6P/wIQv+fC3MuqrOo3JBYyYVGIxMrpxX/iGEDOEDWkq9Lp7ANFKMu2P2cR+O\n" + "udZ+7WvIB8nINcvPZCGdqndOL9PoJV2zJveQ6FVMZpbh7Nc+dj5Mc2WFiE1LwkEe\n" + "aZD4c5r5UsT2CUvfdDTUoRqNnYVAjnQWfurrfo/o9gjietrvvGKT/dOtdDh/WvEl\n" + "3+HRwVrEeffsO/1tEPbAPHHvtnd1gTNJbNtzVAAUsHT7a4OxPK7JS3hmtv6JRp/Z\n" + "j9VUkDie86a1nD7dFW+S7a3N0W/0EMBiWJduiYye1Qitf2Uf0Dpo78J8lnJce7zR\n" + "1UwQc6EIuNlZh4EODIzz+Pm39locuuFgVVq38dcNStCeSX03GuL75SN3sBT6vXms\n" + "PeEhjhQxxlkGICbfuhCk1TPHj2Q8UROUAvZzbgspbhrtYYaVGrOc+eWBsqlgRzzV\n" + "bF//cJ9yRZy0PzsQSVTonAGpNXLHKTpB438hz5auwfNfAgMBAAE=\n" + "-----END PUBLIC KEY-----\n"_u8}, + {__WASI_PUBLICKEY_ENCODING_PKCS8, + "308201a2300d06092a864886f70d01010105000382018f003082018a0282" + "0181009ef69f30260218f18e869be9037146f04b4de8fff0210bfe7c2dcc" + "baaacea37241632615188c4cae9c57fe21840ce1035a4abd2e9ec034528c" + "bb63f6711f8eb9d67eed6bc807c9c835cbcf64219daa774e2fd3e8255db3" + "26f790e8554c6696e1ecd73e763e4c736585884d4bc2411e6990f8739af9" + "52c4f6094bdf7434d4a11a8d9d85408e74167eeaeb7e8fe8f608e27adaef" + "bc6293fdd3ad74387f5af125dfe1d1c15ac479f7ec3bfd6d10f6c03c71ef" + "b677758133496cdb73540014b074fb6b83b13caec94b7866b6fe89469fd9" + "8fd55490389ef3a6b59c3edd156f92edadcdd16ff410c06258976e898c9e" + "d508ad7f651fd03a68efc27c96725c7bbcd1d54c1073a108b8d95987810e" + "0c8cf3f8f9b7f65a1cbae160555ab7f1d70d4ad09e497d371ae2fbe52377" + "b014fabd79ac3de1218e1431c659062026dfba10a4d533c78f643c511394" + "02f6736e0b296e1aed6186951ab39cf9e581b2a960473cd56c5fff709f72" + "459cb43f3b104954e89c01a93572c7293a41e37f21cf96aec1f35f020301" + "0001"_u8v}}, + {{__WASI_SECRETKEY_ENCODING_PEM, + "-----BEGIN PRIVATE KEY-----\n" + "MIIG/AIBADANBgkqhkiG9w0BAQEFAASCBuYwggbiAgEAAoIBgQCe9p8wJgIY8Y6G\n" + "m+kDcUbwS03o//AhC/58Lcy6qs6jckFjJhUYjEyunFf+IYQM4QNaSr0unsA0Uoy7\n" + "Y/ZxH4651n7ta8gHycg1y89kIZ2qd04v0+glXbMm95DoVUxmluHs1z52PkxzZYWI\n" + "TUvCQR5pkPhzmvlSxPYJS990NNShGo2dhUCOdBZ+6ut+j+j2COJ62u+8YpP90610\n" + "OH9a8SXf4dHBWsR59+w7/W0Q9sA8ce+2d3WBM0ls23NUABSwdPtrg7E8rslLeGa2\n" + "/olGn9mP1VSQOJ7zprWcPt0Vb5Ltrc3Rb/QQwGJYl26JjJ7VCK1/ZR/QOmjvwnyW\n" + "clx7vNHVTBBzoQi42VmHgQ4MjPP4+bf2Why64WBVWrfx1w1K0J5JfTca4vvlI3ew\n" + "FPq9eaw94SGOFDHGWQYgJt+6EKTVM8ePZDxRE5QC9nNuCyluGu1hhpUas5z55YGy\n" + "qWBHPNVsX/9wn3JFnLQ/OxBJVOicAak1cscpOkHjfyHPlq7B818CAwEAAQKCAYBK\n" + "w+QLWVUTNkm6tgnaPKUIz+JM/FOMt39yGHh6M2wNI+ftIjQ534MRfSdFt63MAOj6\n" + "xrxD+RadhVX7rQB0JEuUzHXWZSMnxpgL9VgN2GG3k3WKuTgumutwIHBfVf8hIUYR\n" + "hwsxwgtjGxS7Dt/a9ZXAQRcaCIHLlCfEJ5NprI91Vm/U7p92YNNTzloEpNsFHRio\n" + "f+DR0euZLr4eM5RyyYjuy99D+dT/KMRLUt7BY8z2oQAF6hmyMtUOBgkwMPmKJPp9\n" + "xBLHH/25jcY9t3+PM9vT3xkrJ3GrcF9kHHd4R03lLdxuvf0B/AvHlZUphwYcf6ET\n" + "eKS1H9t1CII2Z0ZKHlv2C4Zy3ZH06yIFPNjJI9wCJcq9QnH1SN3Kf1Jak7iXT0xC\n" + "rOrUE4N2mcp8nF3HhgeZZ0vPCW/NBvwHmZgdDIrZI7m5aGq+MZr2JOqrcHgVDMC6\n" + "KCvMpkJRN8Ine6ICpBKd4tl79mrRBbXWK/hCEIwSG9o7C8CPV6NrnZMfgVMUqgEC\n" + "gcEAzWoCtDSsn08CUivWoLdwcvXkCXWihRYM45EwSY3Zkz1YLsjp7NNGoX0mOzWW\n" + "aDa6Ofm2KSQv2pBw3oVcR7ns/IUg/2GSP+v1aZpWGx7wPUu3dbuMfpnt8PKUazbL\n" + "JoBF83bMXZN5LKijEW4HWqO5WsjUwwLUJm0ouvSabnMof6sCeviiDe38Be2lXGfJ\n" + "jCsj9hvmGUqPdFCQJ9c6IjXaP20qvOYNDdVYJ5AJLZywfH7RDGHd4+GH5GToHXTr\n" + "9h7hAoHBAMYcM7jj2NKaNkEYnf7YTnjNymdDYLwm7UCFDKHmluYzleNHuyQu53Y+\n" + "hSyNhKRRVtdxuSKybDPKzTSw5gN4YP1kCGITcDTIQkt9W70Z/U4JL5WSaTyFelZY\n" + "Wg8MACPZhIDexZ+f+bNA82VBinE40kTXAsp+I/dNZHusZJnHR0Q0bnarfCz8UuQe\n" + "qFO4vVH71MpJvthVD/wHuhHZT/NpRXDwgrkKVOC8VwK5U2ZTtVf8uppdsBr0gmyX\n" + "FHI+kz2aPwKBwQCF1E2SrsbQvB8c/ibFav4+R+mcKCIMZ0NaeFtncJ2SimMLiCav\n" + "/y6DRBBGfzFREGbgIssFnuf2lCiVMXnf2UiHdQz8lcs9DjRD6yOyY8PNi6kpcVml\n" + "mhAl7UW5XGea2/O3HW0kglJuQCiN0IvGB+lZNoM30n350yC4PWjoEOsP0pC5IYgj\n" + "XyvViPE1dQEg63Jwg9i0HZm9BEgHTPg5FbDtpeg0TgWvP5JBpFv2daGeWtlEIfb4\n" + "4xUwPnXjyyt4nMECgb9cFr/0MfWX8BdIKylGTUYs4Xw0hB1zWKTwWOiGWanLWC9U\n" + "dwOGzkbJsEY3b5E40JaNj09/0XB6osrAs3o4IrzzDIzZCjAeWPh4Hs2GGY6lt59m\n" + "56gDeghkGq3CUNG/2Fy/is5SZQqtSIPbjZvNBZy4Yzno5rnROyh6VKhu0zNNgRHY\n" + "F96hCql9YMLeKAHZGjbP0XflF6VWgkD8CwgfHdApr6MUYLkTvnizy3H5HvAs9k3H\n" + "c8Vowj/eOlxGvs+y0wKBwHOIGyQ49ethiJMUyt9orVfV04luQgvOOi2VRu8ZQesp\n" + "hA3tuBVilKq7G9EnH2EHy9/xpOql4SzqvboCivh4oKf0W8UyAKQ3Yv7/h4T/obRY\n" + "QzPUB3vc657B+5DzQEhjnP88BF6poznANEmuK9YJ4OXeA3RQaQiIEIFbY+GgaO7u\n" + "znZv+HE2UJcfnbGW3m/naIKfIhcIaq2bfWOG5hYd1L6Dazkfttc2GtRYaBpyo0bz\n" + "gpvbM5ETOLFocm1MM1hcOA==\n" + "-----END PRIVATE KEY-----\n"_u8}, + {__WASI_SECRETKEY_ENCODING_PKCS8, + "308206e202010002820181009ef69f30260218f18e869be9037146f04b4d" + "e8fff0210bfe7c2dccbaaacea37241632615188c4cae9c57fe21840ce103" + "5a4abd2e9ec034528cbb63f6711f8eb9d67eed6bc807c9c835cbcf64219d" + "aa774e2fd3e8255db326f790e8554c6696e1ecd73e763e4c736585884d4b" + "c2411e6990f8739af952c4f6094bdf7434d4a11a8d9d85408e74167eeaeb" + "7e8fe8f608e27adaefbc6293fdd3ad74387f5af125dfe1d1c15ac479f7ec" + "3bfd6d10f6c03c71efb677758133496cdb73540014b074fb6b83b13caec9" + "4b7866b6fe89469fd98fd55490389ef3a6b59c3edd156f92edadcdd16ff4" + "10c06258976e898c9ed508ad7f651fd03a68efc27c96725c7bbcd1d54c10" + "73a108b8d95987810e0c8cf3f8f9b7f65a1cbae160555ab7f1d70d4ad09e" + "497d371ae2fbe52377b014fabd79ac3de1218e1431c659062026dfba10a4" + "d533c78f643c51139402f6736e0b296e1aed6186951ab39cf9e581b2a960" + "473cd56c5fff709f72459cb43f3b104954e89c01a93572c7293a41e37f21" + "cf96aec1f35f0203010001028201804ac3e40b5955133649bab609da3ca5" + "08cfe24cfc538cb77f7218787a336c0d23e7ed223439df83117d2745b7ad" + "cc00e8fac6bc43f9169d8555fbad0074244b94cc75d6652327c6980bf558" + "0dd861b793758ab9382e9aeb7020705f55ff21214611870b31c20b631b14" + "bb0edfdaf595c041171a0881cb9427c4279369ac8f75566fd4ee9f7660d3" + "53ce5a04a4db051d18a87fe0d1d1eb992ebe1e339472c988eecbdf43f9d4" + "ff28c44b52dec163ccf6a10005ea19b232d50e06093030f98a24fa7dc412" + "c71ffdb98dc63db77f8f33dbd3df192b2771ab705f641c7778474de52ddc" + "6ebdfd01fc0bc795952987061c7fa11378a4b51fdb7508823667464a1e5b" + "f60b8672dd91f4eb22053cd8c923dc0225cabd4271f548ddca7f525a93b8" + "974f4c42acead413837699ca7c9c5dc7860799674bcf096fcd06fc079998" + "1d0c8ad923b9b9686abe319af624eaab7078150cc0ba282bcca6425137c2" + "277ba202a4129de2d97bf66ad105b5d62bf842108c121bda3b0bc08f57a3" + "6b9d931f815314aa010281c100cd6a02b434ac9f4f02522bd6a0b77072f5" + "e40975a285160ce39130498dd9933d582ec8e9ecd346a17d263b35966836" + "ba39f9b629242fda9070de855c47b9ecfc8520ff61923febf5699a561b1e" + "f03d4bb775bb8c7e99edf0f2946b36cb268045f376cc5d93792ca8a3116e" + "075aa3b95ac8d4c302d4266d28baf49a6e73287fab027af8a20dedfc05ed" + "a55c67c98c2b23f61be6194a8f74509027d73a2235da3f6d2abce60d0dd5" + "582790092d9cb07c7ed10c61dde3e187e464e81d74ebf61ee10281c100c6" + "1c33b8e3d8d29a3641189dfed84e78cdca674360bc26ed40850ca1e696e6" + "3395e347bb242ee7763e852c8d84a45156d771b922b26c33cacd34b0e603" + "7860fd640862137034c8424b7d5bbd19fd4e092f9592693c857a56585a0f" + "0c0023d98480dec59f9ff9b340f365418a7138d244d702ca7e23f74d647b" + "ac6499c74744346e76ab7c2cfc52e41ea853b8bd51fbd4ca49bed8550ffc" + "07ba11d94ff3694570f082b90a54e0bc5702b9536653b557fcba9a5db01a" + "f4826c9714723e933d9a3f0281c10085d44d92aec6d0bc1f1cfe26c56afe" + "3e47e99c28220c67435a785b67709d928a630b8826afff2e834410467f31" + "511066e022cb059ee7f69428953179dfd94887750cfc95cb3d0e3443eb23" + "b263c3cd8ba9297159a59a1025ed45b95c679adbf3b71d6d2482526e4028" + "8dd08bc607e959368337d27df9d320b83d68e810eb0fd290b92188235f2b" + "d588f135750120eb727083d8b41d99bd0448074cf83915b0eda5e8344e05" + "af3f9241a45bf675a19e5ad94421f6f8e315303e75e3cb2b789cc10281bf" + "5c16bff431f597f017482b29464d462ce17c34841d7358a4f058e88659a9" + "cb582f54770386ce46c9b046376f9138d0968d8f4f7fd1707aa2cac0b37a" + "3822bcf30c8cd90a301e58f8781ecd86198ea5b79f66e7a8037a08641aad" + "c250d1bfd85cbf8ace52650aad4883db8d9bcd059cb86339e8e6b9d13b28" + "7a54a86ed3334d8111d817dea10aa97d60c2de2801d91a36cfd177e517a5" + "568240fc0b081f1dd029afa31460b913be78b3cb71f91ef02cf64dc773c5" + "68c23fde3a5c46becfb2d30281c073881b2438f5eb61889314cadf68ad57" + "d5d3896e420bce3a2d9546ef1941eb29840dedb8156294aabb1bd1271f61" + "07cbdff1a4eaa5e12ceabdba028af878a0a7f45bc53200a43762feff8784" + "ffa1b4584333d4077bdceb9ec1fb90f34048639cff3c045ea9a339c03449" + "ae2bd609e0e5de03745069088810815b63e1a068eeeece766ff871365097" + "1f9db196de6fe768829f2217086aad9b7d6386e6161dd4be836b391fb6d7" + "361ad458681a72a346f3829bdb33911338b168726d4c33585c38"_u8v}}, + {}); + RsaCheck( + "4096"sv, + {{__WASI_PUBLICKEY_ENCODING_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEArcIlHK14T1jiaKxiiS3E\n" + "jOcQvxErqnc/aQLPXNIx/d21D1/Ii1vODoSTt5KAvWE87ah0rE5oAs0wMVRWSAV+\n" + "evUAG+WVjBiN8crGilTVWyylCNVBM2Kwrt95rDYLz68OL7ZOgJ393KfZ2HbMq51Q\n" + "d6AJsvTkEGVdbbO3q9Opx1htoVm4Crac3TgKpNQceQ6i7b3gFBMaIjGNhD8fuWgK\n" + "5FAYQgstzURXB5B/vzgOSgkG70q1Ksr8iq5q5UpQRoVaUwq7TXvFR0hARoE0lS7T\n" + "WeAA5GksGwLYPDfg7WvZn1LmCLU0OAAdPJAegFOZa87rKZTo8PXWMgbQM15/rz1G\n" + "DFVbJRToGr1pgO0lNZ19uBFfmStcQowHNR1is9GCv08yb6ZQBszA3JhGet2Aaqy9\n" + "HgN6nS8bzJnYljz3q5MTwzLvtBYl7Bw4MBDypbPnqHeGyUY1LJ8BkRbX0b39FPbl\n" + "9ywWXDnjJ8sly1D3FQmtOiN54grrYW8Ee9kTPQchdgogOX1pJRTRuEnGJI5aRiDE\n" + "gVkrx12jSymdeUpuy7UdvPkWKQYzwgABgNc7RZs6+qWkQZBsU96jSxMayKTsP5U6\n" + "fT2rkSAsmXH6HPjtD2gvpbvl32MUBGlbNOUYh4gePv01GiztuONCjASfhRvjJagK\n" + "4zn18PfVWWpgqas4rrFDyLECAwEAAQ==\n" + "-----END PUBLIC KEY-----\n"_u8}, + {__WASI_PUBLICKEY_ENCODING_PKCS8, + "30820222300d06092a864886f70d01010105000382020f003082020a0282" + "020100adc2251cad784f58e268ac62892dc48ce710bf112baa773f6902cf" + "5cd231fdddb50f5fc88b5bce0e8493b79280bd613ceda874ac4e6802cd30" + "31545648057e7af5001be5958c188df1cac68a54d55b2ca508d5413362b0" + "aedf79ac360bcfaf0e2fb64e809dfddca7d9d876ccab9d5077a009b2f4e4" + "10655d6db3b7abd3a9c7586da159b80ab69cdd380aa4d41c790ea2edbde0" + "14131a22318d843f1fb9680ae45018420b2dcd445707907fbf380e4a0906" + "ef4ab52acafc8aae6ae54a5046855a530abb4d7bc5474840468134952ed3" + "59e000e4692c1b02d83c37e0ed6bd99f52e608b53438001d3c901e805399" + "6bceeb2994e8f0f5d63206d0335e7faf3d460c555b2514e81abd6980ed25" + "359d7db8115f992b5c428c07351d62b3d182bf4f326fa65006ccc0dc9846" + "7add806aacbd1e037a9d2f1bcc99d8963cf7ab9313c332efb41625ec1c38" + "3010f2a5b3e7a87786c946352c9f019116d7d1bdfd14f6e5f72c165c39e3" + "27cb25cb50f71509ad3a2379e20aeb616f047bd9133d0721760a20397d69" + "2514d1b849c6248e5a4620c481592bc75da34b299d794a6ecbb51dbcf916" + "290633c2000180d73b459b3afaa5a441906c53dea34b131ac8a4ec3f953a" + "7d3dab91202c9971fa1cf8ed0f682fa5bbe5df631404695b34e51887881e" + "3efd351a2cedb8e3428c049f851be325a80ae339f5f0f7d5596a60a9ab38" + "aeb143c8b10203010001"_u8v}}, + {{__WASI_SECRETKEY_ENCODING_PEM, + "-----BEGIN PRIVATE KEY-----\n" + "MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCtwiUcrXhPWOJo\n" + "rGKJLcSM5xC/ESuqdz9pAs9c0jH93bUPX8iLW84OhJO3koC9YTztqHSsTmgCzTAx\n" + "VFZIBX569QAb5ZWMGI3xysaKVNVbLKUI1UEzYrCu33msNgvPrw4vtk6Anf3cp9nY\n" + "dsyrnVB3oAmy9OQQZV1ts7er06nHWG2hWbgKtpzdOAqk1Bx5DqLtveAUExoiMY2E\n" + "Px+5aArkUBhCCy3NRFcHkH+/OA5KCQbvSrUqyvyKrmrlSlBGhVpTCrtNe8VHSEBG\n" + "gTSVLtNZ4ADkaSwbAtg8N+Dta9mfUuYItTQ4AB08kB6AU5lrzusplOjw9dYyBtAz\n" + "Xn+vPUYMVVslFOgavWmA7SU1nX24EV+ZK1xCjAc1HWKz0YK/TzJvplAGzMDcmEZ6\n" + "3YBqrL0eA3qdLxvMmdiWPPerkxPDMu+0FiXsHDgwEPKls+eod4bJRjUsnwGRFtfR\n" + "vf0U9uX3LBZcOeMnyyXLUPcVCa06I3niCuthbwR72RM9ByF2CiA5fWklFNG4ScYk\n" + "jlpGIMSBWSvHXaNLKZ15Sm7LtR28+RYpBjPCAAGA1ztFmzr6paRBkGxT3qNLExrI\n" + "pOw/lTp9PauRICyZcfoc+O0PaC+lu+XfYxQEaVs05RiHiB4+/TUaLO2440KMBJ+F\n" + "G+MlqArjOfXw99VZamCpqziusUPIsQIDAQABAoICAQCfy0GyA932qrlcpdvgaBSv\n" + "t/fwnuvXUt8fxZPJuwx6eR//yYh2kLEJLOdkFPkMMJaFwTu7EkgY+3ZshzDp/xN4\n" + "JEQ7Y4GKWzJ+wIqhwK6NsJr9apERnpr5107gDrwB/O1A95luMt25xStUJLzIvl24\n" + "BZel2gy6/11Se8pX3MnwJ+R6VDYqtBHCZ71yJBcjRVCU7t9Z1s9bztJkYmDcc1BA\n" + "81+7rOgsM8MNk9fHlNefQnn8Kmo9tntVVl28DAGTOSP95oqmEUM18L4bmMswvuVj\n" + "a9umMwp6tL0DdCgIb/ysxuIB9BLXxVMd1TQXs8oOGTavAODQaGTZkOZ7t1YZZHI7\n" + "c8rJwOycVKDE9Y0prtI5oyyUF4kbOkICO7aHiOaXRX1JVsUVHV2ck5/UQfAyq2XZ\n" + "/Q99j8rHJYHmCOBoy7ECQ87SStGU/ypyl3PC31af8lTEQpN+aj/B2yq2NaaEfouw\n" + "fW9kNNMkfizJQ6aB4uYya63vZBdAYlZ8T3KbgiEFTPO6Am03BipReChDtfQPlNlx\n" + "eKcYxKXwEp+ddDu+fNmXowpJeiuFbc3/tSS8WvfMQDmVu0ikOG5MfDy1tL22F42w\n" + "lhNnVtt1lwN2op3/WfyfyroNqMA+pT2BqnjzAH/cAhKnI6DU+8bTG8jmPpMnftXH\n" + "k91OUUjucDe+pLLJ0gdE+QKCAQEA56Wq/2x8drGlAYhoRmS8lJldtU2zGUCXnBHk\n" + "roZVOkRvG0bDZTokBkn7ZHhH8/uCIjIUaNJRbpFDzlbtShxYphzunIrJq8OhSoAE\n" + "vKOqVII5rDQ9WlZvIIZ46SQcg6uFbA6l95hi2WPwjl5Enzc3sw1i6fEv4eziPWFa\n" + "JEHlZaonuIEjo8uHKZBhgG3nW5w5oYcPzTOM3ogca2DM1xa1hBzJi0P28Zn+NJBt\n" + "JPnfZiyHlzsXdAtlJtvTvpVknIpwnjfL8vM9RCk1+EfWXGqgrn/pI1IlfQn1ZPRd\n" + "/xJ/+qwZN2StReHjU3Z+KjNR+roALpBV/AhkPd36rmhoXqHPTwKCAQEAwAaDfT2U\n" + "6MzV0cudB1Q1UxrdFqsyxycIWWkxenn/LdOULecN0ujoW8TSTz4PWsXRDAhtdRS0\n" + "nsg4XalKb9UYnqzohNQcqv0rJR/M1HaZ4coUACZcICfiAdFWP1SiFZgu9PJW/EUJ\n" + "u2z08t35iOGwGLxfHFPoiZaomAVlyL+kkEXA8AWwms5QCZjf+IZuT/RbUK4qgzcR\n" + "xbYsv0cjmymCObrFNQjjsedQDygOmj0rA8X1Q/tNfKsx86YJV2CsErWJW4RUZP/h\n" + "Ws9kR40l2NvJnWZEHsnIRNPFgJ90NavXvUwY4VEt0D9An9jiWXLXF5NpChFiWB6/\n" + "8EvgwiydQhrn/wKCAQEAyk5xbOm+OZsj1JbhGrlXyR+4K2NUizVSM0edRJ6lSGID\n" + "9vpyI7IHTEbIexJhJL//AwZhtLoZzEqpwUdBrXvcIBccfTLotk4ASyRK/sShOXUS\n" + "EUb+Xismmm1Wo6aaEJR3zcttPzOjAOC7clr562M6DfIe9NljTBip7ZlcNFYolgVo\n" + "80Y1bhOOU8p4nMVfTS6/Vkayki/3U1HkIBNGUoLOvDa3/hy5Sn+G9zk7WROw+3bg\n" + "ZD+DWCGrkahi4Qtv9xchC80HHYM5epHTRKbYm5W0BzJG1kYj33QXELgqb14kzzQG\n" + "Qc53VZTWCEpwHUL80dAn4ILF1XsusKlxCWi93gfLGQKCAQBwhCCNxQS4+DUdjgI/\n" + "5h6syGPdwYiqWvuwcEv2qP9V2dDMqMNX3vMvun9EwWd718drFpEUdoJzO3yTnPup\n" + "1aJsb4J7OlJl+pxKT3zUzX3TaHYZtGBs0xHB4Oh5iVzD7H0vN8SyYr2WHfzVRi3N\n" + "//gQNmhAkAYEgMve7+K5I1oI02Z+/caCnvsU9Iff9t0yaksLVlJAuobmY52KouOB\n" + "KmxM6VxefAv3FUO67czIoajPuDHDmL/JmgJV8ucsVM/e0pJeloZg+/IPJNBsgI85\n" + "p2dWnDK0G6YGdlQWztfoDv4FxE4b0FZY3IdAYnQW14yjGtQEezU1zybGZZ+YB05K\n" + "Crv/AoIBADlvlifpD0dkIQ2udfG7wUi0OXmj9sepQ73k0n8ULzX1PV3oMyknGs54\n" + "0KlSKGgg/47YezvI4BupBZ0zVox0Ztgilg2oHdQlmEAGGiDBYNLO/Nd9FotytcAu\n" + "gjyanngI5IaXzKVEAW6QFZTkeIavl/S44NtB39MjM7tRaaJNu60PaZ/UOlicyyNj\n" + "RZPS2JrYmymSo9La27si9yY9L/gXCw0yL2/3sbR5cPEkoFN0o9XGDs5BWsnvMLzH\n" + "UwtbtmC2Pw4kp8kzRk21cH75Yl/wg9Oir95uL0C8w4B7FacCPijNPCPsvmYJi9Cg\n" + "Kah5KuUlkoLBFrMqKJCKqhvP72HZUp8=\n" + "-----END PRIVATE KEY-----\n"_u8}, + {__WASI_SECRETKEY_ENCODING_PKCS8, + "308209290201000282020100adc2251cad784f58e268ac62892dc48ce710" + "bf112baa773f6902cf5cd231fdddb50f5fc88b5bce0e8493b79280bd613c" + "eda874ac4e6802cd3031545648057e7af5001be5958c188df1cac68a54d5" + "5b2ca508d5413362b0aedf79ac360bcfaf0e2fb64e809dfddca7d9d876cc" + "ab9d5077a009b2f4e410655d6db3b7abd3a9c7586da159b80ab69cdd380a" + "a4d41c790ea2edbde014131a22318d843f1fb9680ae45018420b2dcd4457" + "07907fbf380e4a0906ef4ab52acafc8aae6ae54a5046855a530abb4d7bc5" + "474840468134952ed359e000e4692c1b02d83c37e0ed6bd99f52e608b534" + "38001d3c901e8053996bceeb2994e8f0f5d63206d0335e7faf3d460c555b" + "2514e81abd6980ed25359d7db8115f992b5c428c07351d62b3d182bf4f32" + "6fa65006ccc0dc98467add806aacbd1e037a9d2f1bcc99d8963cf7ab9313" + "c332efb41625ec1c383010f2a5b3e7a87786c946352c9f019116d7d1bdfd" + "14f6e5f72c165c39e327cb25cb50f71509ad3a2379e20aeb616f047bd913" + "3d0721760a20397d692514d1b849c6248e5a4620c481592bc75da34b299d" + "794a6ecbb51dbcf916290633c2000180d73b459b3afaa5a441906c53dea3" + "4b131ac8a4ec3f953a7d3dab91202c9971fa1cf8ed0f682fa5bbe5df6314" + "04695b34e51887881e3efd351a2cedb8e3428c049f851be325a80ae339f5" + "f0f7d5596a60a9ab38aeb143c8b1020301000102820201009fcb41b203dd" + "f6aab95ca5dbe06814afb7f7f09eebd752df1fc593c9bb0c7a791fffc988" + "7690b1092ce76414f90c309685c13bbb124818fb766c8730e9ff13782444" + "3b63818a5b327ec08aa1c0ae8db09afd6a91119e9af9d74ee00ebc01fced" + "40f7996e32ddb9c52b5424bcc8be5db80597a5da0cbaff5d527bca57dcc9" + "f027e47a54362ab411c267bd72241723455094eedf59d6cf5bced2646260" + "dc735040f35fbbace82c33c30d93d7c794d79f4279fc2a6a3db67b55565d" + "bc0c01933923fde68aa6114335f0be1b98cb30bee5636bdba6330a7ab4bd" + "037428086ffcacc6e201f412d7c5531dd53417b3ca0e1936af00e0d06864" + "d990e67bb7561964723b73cac9c0ec9c54a0c4f58d29aed239a32c941789" + "1b3a42023bb68788e697457d4956c5151d5d9c939fd441f032ab65d9fd0f" + "7d8fcac72581e608e068cbb10243ced24ad194ff2a729773c2df569ff254" + "c442937e6a3fc1db2ab635a6847e8bb07d6f6434d3247e2cc943a681e2e6" + "326badef64174062567c4f729b8221054cf3ba026d37062a51782843b5f4" + "0f94d97178a718c4a5f0129f9d743bbe7cd997a30a497a2b856dcdffb524" + "bc5af7cc403995bb48a4386e4c7c3cb5b4bdb6178db096136756db759703" + "76a29dff59fc9fcaba0da8c03ea53d81aa78f3007fdc0212a723a0d4fbc6" + "d31bc8e63e93277ed5c793dd4e5148ee7037bea4b2c9d20744f902820101" + "00e7a5aaff6c7c76b1a50188684664bc94995db54db31940979c11e4ae86" + "553a446f1b46c3653a240649fb647847f3fb8222321468d2516e9143ce56" + "ed4a1c58a61cee9c8ac9abc3a14a8004bca3aa548239ac343d5a566f2086" + "78e9241c83ab856c0ea5f79862d963f08e5e449f3737b30d62e9f12fe1ec" + "e23d615a2441e565aa27b88123a3cb87299061806de75b9c39a1870fcd33" + "8cde881c6b60ccd716b5841cc98b43f6f199fe34906d24f9df662c87973b" + "17740b6526dbd3be95649c8a709e37cbf2f33d442935f847d65c6aa0ae7f" + "e92352257d09f564f45dff127ffaac193764ad45e1e353767e2a3351faba" + "002e9055fc08643dddfaae68685ea1cf4f0282010100c006837d3d94e8cc" + "d5d1cb9d075435531add16ab32c727085969317a79ff2dd3942de70dd2e8" + "e85bc4d24f3e0f5ac5d10c086d7514b49ec8385da94a6fd5189eace884d4" + "1caafd2b251fccd47699e1ca1400265c2027e201d1563f54a215982ef4f2" + "56fc4509bb6cf4f2ddf988e1b018bc5f1c53e88996a8980565c8bfa49045" + "c0f005b09ace500998dff8866e4ff45b50ae2a833711c5b62cbf47239b29" + "8239bac53508e3b1e7500f280e9a3d2b03c5f543fb4d7cab31f3a6095760" + "ac12b5895b845464ffe15acf64478d25d8dbc99d66441ec9c844d3c5809f" + "7435abd7bd4c18e1512dd03f409fd8e25972d71793690a1162581ebff04b" + "e0c22c9d421ae7ff0282010100ca4e716ce9be399b23d496e11ab957c91f" + "b82b63548b355233479d449ea5486203f6fa7223b2074c46c87b126124bf" + "ff030661b4ba19cc4aa9c14741ad7bdc20171c7d32e8b64e004b244afec4" + "a13975121146fe5e2b269a6d56a3a69a109477cdcb6d3f33a300e0bb725a" + "f9eb633a0df21ef4d9634c18a9ed995c345628960568f346356e138e53ca" + "789cc55f4d2ebf5646b2922ff75351e42013465282cebc36b7fe1cb94a7f" + "86f7393b5913b0fb76e0643f835821ab91a862e10b6ff717210bcd071d83" + "397a91d344a6d89b95b4073246d64623df741710b82a6f5e24cf340641ce" + "775594d6084a701d42fcd1d027e082c5d57b2eb0a9710968bdde07cb1902" + "8201007084208dc504b8f8351d8e023fe61eacc863ddc188aa5afbb0704b" + "f6a8ff55d9d0cca8c357def32fba7f44c1677bd7c76b1691147682733b7c" + "939cfba9d5a26c6f827b3a5265fa9c4a4f7cd4cd7dd3687619b4606cd311" + "c1e0e879895cc3ec7d2f37c4b262bd961dfcd5462dcdfff8103668409006" + "0480cbdeefe2b9235a08d3667efdc6829efb14f487dff6dd326a4b0b5652" + "40ba86e6639d8aa2e3812a6c4ce95c5e7c0bf71543baedccc8a1a8cfb831" + "c398bfc99a0255f2e72c54cfded2925e968660fbf20f24d06c808f39a767" + "569c32b41ba606765416ced7e80efe05c44e1bd05658dc8740627416d78c" + "a31ad4047b3535cf26c6659f98074e4a0abbff02820100396f9627e90f47" + "64210dae75f1bbc148b43979a3f6c7a943bde4d27f142f35f53d5de83329" + "271ace78d0a952286820ff8ed87b3bc8e01ba9059d33568c7466d822960d" + "a81dd4259840061a20c160d2cefcd77d168b72b5c02e823c9a9e7808e486" + "97cca544016e901594e47886af97f4b8e0db41dfd32333bb5169a24dbbad" + "0f699fd43a589ccb23634593d2d89ad89b2992a3d2dadbbb22f7263d2ff8" + "170b0d322f6ff7b1b47970f124a05374a3d5c60ece415ac9ef30bcc7530b" + "5bb660b63f0e24a7c933464db5707ef9625ff083d3a2afde6e2f40bcc380" + "7b15a7023e28cd3c23ecbe66098bd0a029a8792ae5259282c116b32a2890" + "8aaa1bcfef61d9529f"_u8v}}, + {}); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/common.cpp b/test/plugins/wasi_crypto/common.cpp new file mode 100644 index 00000000..63156564 --- /dev/null +++ b/test/plugins/wasi_crypto/common.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/func.h" +#include "helper.h" + +namespace { +template +T *getHostFunc(WasmEdge::Host::WasiCryptoModule *Mod, const char *Name) { + if (Mod) { + auto *FuncInst = Mod->findFuncExports(Name); + if (FuncInst && FuncInst->isHostFunction()) { + return dynamic_cast(&FuncInst->getHostFunc()); + } + } + return nullptr; +} +} // namespace + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Options) { + // Symmetric options. + { + // Open options. + WASI_CRYPTO_EXPECT_SUCCESS(SymmetricOptionsHandle, + optionsOpen(__WASI_ALGORITHM_TYPE_SYMMETRIC)); + + // Set options. + WASI_CRYPTO_EXPECT_TRUE( + optionsSet(SymmetricOptionsHandle, "context"sv, "foo"_u8)); + WASI_CRYPTO_EXPECT_TRUE( + optionsSet(SymmetricOptionsHandle, "salt"sv, "foo"_u8)); + WASI_CRYPTO_EXPECT_TRUE( + optionsSet(SymmetricOptionsHandle, "nonce"sv, "foo"_u8)); + WASI_CRYPTO_EXPECT_TRUE( + optionsSetU64(SymmetricOptionsHandle, "memory_limit"sv, 0)); + WASI_CRYPTO_EXPECT_TRUE( + optionsSetU64(SymmetricOptionsHandle, "ops_limit"sv, 0)); + WASI_CRYPTO_EXPECT_TRUE( + optionsSetU64(SymmetricOptionsHandle, "parallelism"sv, 0)); + + // Unsupport options. + WASI_CRYPTO_EXPECT_FAILURE( + optionsSet(SymmetricOptionsHandle, "foo"sv, "foo"_u8), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE( + optionsSetU64(SymmetricOptionsHandle, "foo"sv, 0), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + writeDummyMemoryContent(); + writeString("foo"sv, 0); + uint32_t NameSize = 3; + auto *Func = getHostFunc( + WasiCryptoMod, "options_set_guest_buffer"); + ASSERT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + SymmetricOptionsHandle, 0, NameSize, 0, NameSize}, + Errno)); + EXPECT_EQ(Errno[0].get(), __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + WASI_CRYPTO_EXPECT_TRUE(optionsClose(SymmetricOptionsHandle)); + } + + // Signature options. + { + // Open options. + WASI_CRYPTO_EXPECT_SUCCESS(SigOptionsHandle, + optionsOpen(__WASI_ALGORITHM_TYPE_SIGNATURES)); + + // Unsupport options. + WASI_CRYPTO_EXPECT_FAILURE(optionsSet(SigOptionsHandle, "foo"sv, "foo"_u8), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + WASI_CRYPTO_EXPECT_FAILURE(optionsSetU64(SigOptionsHandle, "foo"sv, 0), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + writeDummyMemoryContent(); + writeString("foo"sv, 0); + uint32_t NameSize = 3; + auto *Func = getHostFunc( + WasiCryptoMod, "options_set_guest_buffer"); + ASSERT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + SigOptionsHandle, 0, NameSize, 0, NameSize}, + Errno)); + EXPECT_EQ(Errno[0].get(), __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + // Close options. + WASI_CRYPTO_EXPECT_TRUE(optionsClose(SigOptionsHandle)); + } + + // Key exchange options. + { + // Open options. + WASI_CRYPTO_EXPECT_SUCCESS(KxOptionsHandle, + optionsOpen(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE)); + // Unsupport options. + WASI_CRYPTO_EXPECT_FAILURE(optionsSet(KxOptionsHandle, "foo"sv, "foo"_u8), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + WASI_CRYPTO_EXPECT_FAILURE(optionsSetU64(KxOptionsHandle, "foo"sv, 0), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + writeDummyMemoryContent(); + writeString("foo"sv, 0); + uint32_t NameSize = 3; + auto *Func = getHostFunc( + WasiCryptoMod, "options_set_guest_buffer"); + ASSERT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + KxOptionsHandle, 0, NameSize, 0, NameSize}, + Errno)); + EXPECT_EQ(Errno[0].get(), __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + + // Close options. + WASI_CRYPTO_EXPECT_TRUE(optionsClose(KxOptionsHandle)); + } +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/hash.cpp b/test/plugins/wasi_crypto/hash.cpp new file mode 100644 index 00000000..825e0432 --- /dev/null +++ b/test/plugins/wasi_crypto/hash.cpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Hash) { + auto HashTest = [this](std::string_view Name, + const std::vector &AbsorbData1, + const std::vector &AbsorbData2, + const std::vector &ExpectedSqueezeData1, + const std::vector &ExpectedSqueezeData2, + const std::vector &TruncatedSquueezeData) { + WASI_CRYPTO_EXPECT_SUCCESS( + StateHandle, symmetricStateOpen(Name, std::nullopt, std::nullopt)); + + SCOPED_TRACE(Name); + { + // "data" + std::vector SqueezeContent(ExpectedSqueezeData1.size()); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(StateHandle, AbsorbData1)); + WASI_CRYPTO_EXPECT_TRUE( + symmetricStateSqueeze(StateHandle, SqueezeContent)); + EXPECT_EQ(SqueezeContent, ExpectedSqueezeData1); + } + + { + // "datamore_data" + std::vector SqueezeContent(ExpectedSqueezeData2.size()); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(StateHandle, AbsorbData2)); + WASI_CRYPTO_EXPECT_TRUE( + symmetricStateSqueeze(StateHandle, SqueezeContent)); + EXPECT_EQ(SqueezeContent, ExpectedSqueezeData2); + } + + { + // Smaller than the hash function output size. Truncate the output. + std::vector SqueezeContent(TruncatedSquueezeData.size()); + WASI_CRYPTO_EXPECT_TRUE( + symmetricStateSqueeze(StateHandle, SqueezeContent)); + EXPECT_EQ(SqueezeContent, TruncatedSquueezeData); + } + + { + // Requested size exceeds the returned invalid_length. + std::vector SqueezeContent(ExpectedSqueezeData1.size() + 1); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateSqueeze(StateHandle, SqueezeContent), + __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + } + + { + // Clone checking. + WASI_CRYPTO_EXPECT_SUCCESS(NewStateHandle, + symmetricStateClone(StateHandle)); + EXPECT_NE(StateHandle, NewStateHandle); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(NewStateHandle)); + } + + { + // Some error cases checking. + WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyGenerate(Name, std::nullopt), + __WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOpen(Name, InvaildHandle, std::nullopt), + __WASI_CRYPTO_ERRNO_INVALID_HANDLE); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOptionsGet(StateHandle, "foo"sv, {}), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOptionsGetU64(StateHandle, "foo"sv), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueezeTag(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueezeKey(StateHandle, Name), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateMaxTagLen(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateEncrypt(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateEncryptDetached(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateDecrypt(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateDecryptDetached(StateHandle, {}, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateRatchet(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + // Close. + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(StateHandle)); + }; + + HashTest( + "SHA-256"sv, "data"_u8, "more_data"_u8, + "3a6eb0790f39ac87c94f3856b2dd2c5d110e6811602261a9a923d3bb23adc8b7"_u8v, + "13c40eec22541a155e172010c7fd6ef654e4e138a0c20923f9a91062a27f57b6"_u8v, + "13c40eec22541a155e172010c7fd6ef654e4e138a0c20923f9a91062a27f57"_u8v); + HashTest( + "SHA-512"sv, "data"_u8, "more_data"_u8, + "77c7ce9a5d86bb386d443bb96390faa120633158699c8844c30b13ab0bf92760b7e4416aea397db91b4ac0e5dd56b8ef7e4b066162ab1fdc088319ce6defc876"_u8v, + "78d0b55eeb3a07754f0967a6e960b5b7488b09ec4d2a62d832a45d80f814aef88e5414e2115165012ac592ff050651e956089a5aacd4ea52cf247c3cc2f6add2"_u8v, + "78d0b55eeb3a07754f0967a6e960b5b7488b09ec4d2a62d832a45d80f814aef88e5414e2115165012ac592ff050651e956089a5aacd4ea52cf247c3cc2f6ad"_u8v); + HashTest( + "SHA-512/256"sv, "data"_u8, "more_data"_u8, + "99902eaf90e92264667843cde66675ed94caa361634bad57874642aa364aa968"_u8v, + "d1def71920a44d8b6c83b2eaa99379a16047cc82cec8d80689fbf02fbd062481"_u8v, + "d1def71920a44d8b6c83b2eaa99379a16047cc82cec8d80689fbf02fbd0624"_u8v); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/helper.cpp b/test/plugins/wasi_crypto/helper.cpp new file mode 100644 index 00000000..cc7223f0 --- /dev/null +++ b/test/plugins/wasi_crypto/helper.cpp @@ -0,0 +1,1436 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "helper.h" +#include "asymmetric_common/func.h" +#include "common/func.h" +#include "kx/func.h" +#include "signatures/func.h" +#include "symmetric/func.h" +#include "utils/error.h" + +#include +#include +#include + +#define ensureOrReturnOnTest(Expr) \ + do { \ + if ((static_cast<__wasi_crypto_errno_e_t>(Expr) != \ + __WASI_CRYPTO_ERRNO_SUCCESS)) { \ + return WasiCryptoUnexpect(static_cast<__wasi_crypto_errno_e_t>(Expr)); \ + } \ + } while (0) + +namespace { +template +T *getHostFunc(WasmEdge::Host::WasiCryptoModule *Mod, const char *Name) { + if (Mod) { + auto *FuncInst = Mod->findFuncExports(Name); + if (FuncInst && FuncInst->isHostFunction()) { + return dynamic_cast(&FuncInst->getHostFunc()); + } + } + return nullptr; +} +} // namespace + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +std::vector operator"" _u8(const char *Str, std::size_t Len) { + return std::vector{reinterpret_cast(Str), + reinterpret_cast(Str) + Len}; +} + +std::vector operator"" _u8v(const char *Str, std::size_t Len) { + std::vector Res; + Res.reserve(Len / 2); + for (size_t I = 0; I < Len; I += 2) { + std::string Tran{Str + I, 2}; + uint8_t Byte = static_cast(std::strtol(Tran.c_str(), nullptr, 16)); + Res.push_back(Byte); + } + return Res; +} + +void WasiCryptoTest::writeDummyMemoryContent() { + std::fill_n(MemInst.getPointer(0), 64, UINT8_C(0xa5)); +} + +void WasiCryptoTest::writeString(std::string_view String, uint32_t Ptr) { + std::copy(String.begin(), String.end(), MemInst.getPointer(Ptr)); +} + +void WasiCryptoTest::writeSpan(Span Content, uint32_t Ptr) { + std::copy(Content.begin(), Content.end(), MemInst.getPointer(Ptr)); +} + +void WasiCryptoTest::writeOptKey(std::optional OptKey, uint32_t Ptr) { + __wasi_opt_symmetric_key_t Key; + if (OptKey) { + Key.tag = __WASI_OPT_SYMMETRIC_KEY_U_SOME; + Key.u = {*OptKey}; + } else { + Key.tag = __WASI_OPT_SYMMETRIC_KEY_U_NONE; + } + auto *BeginPlace = MemInst.getPointer<__wasi_opt_symmetric_key_t *>(Ptr); + *BeginPlace = Key; +} + +void WasiCryptoTest::writeOptOptions(std::optional<__wasi_options_t> OptOptions, + uint32_t Ptr) { + __wasi_opt_options_t Options; + if (OptOptions) { + Options.tag = __WASI_OPT_OPTIONS_U_SOME; + Options.u = {*OptOptions}; + } else { + Options.tag = __WASI_OPT_OPTIONS_U_NONE; + } + auto *BeginPlace = MemInst.getPointer<__wasi_opt_options_t *>(Ptr); + *BeginPlace = Options; +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::arrayOutputLen(__wasi_array_output_t ArrayOutputHandle) { + writeDummyMemoryContent(); + + auto *Func = + getHostFunc(WasiCryptoMod, "array_output_len"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{ArrayOutputHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_size_t *>(0); +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::arrayOutputPull(__wasi_array_output_t ArrayOutputHandle, + Span Buf) { + writeDummyMemoryContent(); + writeSpan(Buf, 0); + uint32_t BufSize = static_cast(Buf.size()); + + auto *Func = + getHostFunc(WasiCryptoMod, "array_output_pull"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + ArrayOutputHandle, 0, BufSize, BufSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst.getPointer(0), + MemInst.getPointer(BufSize), Buf.begin()); + return *MemInst.getPointer<__wasi_size_t *>(BufSize); +} + +WasiCryptoExpect<__wasi_options_t> +WasiCryptoTest::optionsOpen(__wasi_algorithm_type_e_t AlgorithmType) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoMod, "options_open"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + static_cast(AlgorithmType), 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_options_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::optionsClose(__wasi_options_t OptionsHandle) { + writeDummyMemoryContent(); + + auto *Func = + getHostFunc(WasiCryptoMod, "options_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{OptionsHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect +WasiCryptoTest::optionsSet(__wasi_options_t OptionsHandle, + std::string_view Name, Span Value) { + writeDummyMemoryContent(); + writeString(Name, 0); + uint32_t NameSize = static_cast(Name.size()); + writeSpan(Value, NameSize); + uint32_t ValueSize = static_cast(Value.size()); + + auto *Func = getHostFunc(WasiCryptoMod, "options_set"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + OptionsHandle, 0, NameSize, NameSize, ValueSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect +WasiCryptoTest::optionsSetU64(__wasi_options_t OptionsHandle, + std::string_view Name, uint64_t Value) { + writeDummyMemoryContent(); + + writeString(Name, 0); + uint32_t NameSize = static_cast(Name.size()); + + auto *Func = + getHostFunc(WasiCryptoMod, "options_set_u64"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + OptionsHandle, 0, NameSize, Value}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_secrets_manager_t> WasiCryptoTest::secretsManagerOpen( + std::optional<__wasi_options_t> OptOptionsHandle) { + writeDummyMemoryContent(); + writeOptOptions(OptOptionsHandle, 0); + + auto *Func = getHostFunc(WasiCryptoMod, + "secrets_manager_open"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{0, 8}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_secrets_manager_t *>(8); +} + +WasiCryptoExpect WasiCryptoTest::secretsManagerClose( + __wasi_secrets_manager_t SecretsManagerHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "secrets_manager_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{SecretsManagerHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect WasiCryptoTest::secretsManagerInvalidate( + __wasi_secrets_manager_t SecretsManagerHandle, Span KeyId, + __wasi_version_t Version) { + writeDummyMemoryContent(); + writeSpan(KeyId, 0); + uint32_t KeyIdSize = static_cast(KeyId.size()); + + auto *Func = getHostFunc( + WasiCryptoMod, "secrets_manager_invalidate"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + SecretsManagerHandle, 0, KeyIdSize, Version}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_symmetric_key_t> WasiCryptoTest::symmetricKeyGenerate( + std::string_view Alg, std::optional<__wasi_options_t> OptOptionsHandle) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeOptOptions(OptOptionsHandle, AlgSize); + + auto *Func = getHostFunc(WasiCryptoMod, + "symmetric_key_generate"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + 0, AlgSize, AlgSize, AlgSize + 8}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_symmetric_key_t *>(AlgSize + 8); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +WasiCryptoTest::symmetricKeyImport(std::string_view Alg, + Span Raw) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeSpan(Raw, AlgSize); + uint32_t RawSize = static_cast(Raw.size()); + + auto *Func = + getHostFunc(WasiCryptoMod, "symmetric_key_import"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + 0, AlgSize, AlgSize, RawSize, AlgSize + RawSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_symmetric_key_t *>(AlgSize + RawSize); +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::symmetricKeyExport(__wasi_symmetric_key_t KeyHandle) { + writeDummyMemoryContent(); + + auto *Func = + getHostFunc(WasiCryptoMod, "symmetric_key_export"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{KeyHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_array_output_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::symmetricKeyClose(__wasi_symmetric_key_t KeyHandle) { + writeDummyMemoryContent(); + + auto *Func = + getHostFunc(WasiCryptoMod, "symmetric_key_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{KeyHandle}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +WasiCryptoTest::symmetricKeyGenerateManaged( + __wasi_secrets_manager_t SecretsManagerHandle, std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeOptOptions(OptOptionsHandle, AlgSize); + + auto *Func = getHostFunc( + WasiCryptoMod, "symmetric_key_generate_managed"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE( + Func->run(&MemInst, + std::initializer_list{ + SecretsManagerHandle, 0, AlgSize, AlgSize, AlgSize + 8}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_keypair_t *>(AlgSize + 8); +} + +WasiCryptoExpect WasiCryptoTest::symmetricKeyStoreManaged( + __wasi_secrets_manager_t SecretsManagerHandle, + __wasi_symmetric_key_t KeyHandle, Span KeyId) { + writeDummyMemoryContent(); + writeSpan(KeyId, 0); + uint32_t KpIdSize = static_cast(KeyId.size()); + + auto *Func = getHostFunc( + WasiCryptoMod, "symmetric_key_store_managed"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + SecretsManagerHandle, KeyHandle, 0, KpIdSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst.getPointer(0), + MemInst.getPointer(KpIdSize), KeyId.begin()); + + return {}; +} + +WasiCryptoExpect<__wasi_version_t> WasiCryptoTest::symmetricKeyReplaceManaged( + __wasi_secrets_manager_t SecretsManagerHandle, + __wasi_symmetric_key_t OldKeyHandle, __wasi_symmetric_key_t NewKeyHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "symmetric_key_replace_managed"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE( + Func->run(&MemInst, + std::initializer_list{ + SecretsManagerHandle, OldKeyHandle, NewKeyHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_version_t *>(0); +} + +WasiCryptoExpect> +WasiCryptoTest::symmetricKeyId(__wasi_symmetric_key_t KeyHandle, + Span KeyId) { + writeDummyMemoryContent(); + writeSpan(KeyId, 0); + uint32_t KeyIdSize = static_cast(KeyId.size()); + + auto *Func = getHostFunc(WasiCryptoMod, "symmetric_key_id"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + KeyHandle, 0, KeyIdSize, KeyIdSize, KeyIdSize + 1}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst.getPointer(0), + MemInst.getPointer(KeyIdSize), KeyId.begin()); + + return std::make_tuple( + *MemInst.getPointer(KeyIdSize), + *MemInst.getPointer<__wasi_version_t *>(KeyIdSize + 1)); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> WasiCryptoTest::symmetricKeyFromId( + __wasi_secrets_manager_t SecretsManagerHandle, Span KeyId, + __wasi_version_t KeyVersion) { + writeDummyMemoryContent(); + writeSpan(KeyId, 0); + uint32_t KeyIdSize = static_cast(KeyId.size()); + + auto *Func = + getHostFunc(WasiCryptoMod, "symmetric_key_from_id"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE( + Func->run(&MemInst, + std::initializer_list{ + SecretsManagerHandle, 0, KeyIdSize, KeyVersion, KeyIdSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_symmetric_key_t *>(KeyIdSize); +} + +WasiCryptoExpect<__wasi_symmetric_state_t> WasiCryptoTest::symmetricStateOpen( + std::string_view Alg, std::optional<__wasi_symmetric_key_t> OptKeyHandle, + std::optional<__wasi_options_t> OptOptionsHandle) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeOptKey(OptKeyHandle, AlgSize); + writeOptOptions(OptOptionsHandle, AlgSize + 8); + + auto *Func = + getHostFunc(WasiCryptoMod, "symmetric_state_open"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + 0, AlgSize, AlgSize, AlgSize + 8, AlgSize + 16}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_symmetric_state_t *>(AlgSize + 16); +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::symmetricStateOptionsGet(__wasi_symmetric_state_t StateHandle, + std::string_view Name, + Span Value) { + writeDummyMemoryContent(); + writeString(Name, 0); + uint32_t NameSize = static_cast(Name.size()); + writeSpan(Value, NameSize); + uint32_t ValueSize = static_cast(Value.size()); + + auto *Func = getHostFunc( + WasiCryptoMod, "symmetric_state_options_get"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{ + StateHandle, 0, NameSize, NameSize, ValueSize, NameSize + ValueSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst.getPointer(NameSize), + MemInst.getPointer(NameSize + ValueSize), Value.begin()); + + return *MemInst.getPointer<__wasi_size_t *>(NameSize + ValueSize); +} + +WasiCryptoExpect WasiCryptoTest::symmetricStateOptionsGetU64( + __wasi_symmetric_state_t StateHandle, std::string_view Name) { + writeDummyMemoryContent(); + writeString(Name, 0); + uint32_t NameSize = static_cast(Name.size()); + + auto *Func = getHostFunc( + WasiCryptoMod, "symmetric_state_options_get_u64"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + StateHandle, 0, NameSize, NameSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer(NameSize); +} + +WasiCryptoExpect +WasiCryptoTest::symmetricStateClose(__wasi_symmetric_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoMod, + "symmetric_state_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{StateHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect +WasiCryptoTest::symmetricStateAbsorb(__wasi_symmetric_state_t StateHandle, + Span Data) { + writeDummyMemoryContent(); + writeSpan(Data, 0); + uint32_t DataSize = static_cast(Data.size()); + + auto *Func = getHostFunc(WasiCryptoMod, + "symmetric_state_absorb"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{StateHandle, 0, DataSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_symmetric_state_t> +WasiCryptoTest::symmetricStateClone(__wasi_symmetric_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoMod, + "symmetric_state_clone"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{StateHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_symmetric_state_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::symmetricStateSqueeze(__wasi_symmetric_state_t StateHandle, + Span Out) { + writeDummyMemoryContent(); + writeSpan(Out, 0); + uint32_t OutSize = static_cast(Out.size()); + + auto *Func = getHostFunc(WasiCryptoMod, + "symmetric_state_squeeze"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{StateHandle, 0, OutSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst.getPointer(0), + MemInst.getPointer(OutSize), Out.begin()); + + return {}; +} + +WasiCryptoExpect<__wasi_symmetric_tag_t> +WasiCryptoTest::symmetricStateSqueezeTag(__wasi_symmetric_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "symmetric_state_squeeze_tag"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{StateHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_symmetric_tag_t *>(0); +} + +WasiCryptoExpect<__wasi_symmetric_key_t> +WasiCryptoTest::symmetricStateSqueezeKey(__wasi_symmetric_state_t StateHandle, + std::string_view Alg) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + + auto *Func = getHostFunc( + WasiCryptoMod, "symmetric_state_squeeze_key"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + StateHandle, 0, AlgSize, AlgSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_symmetric_key_t *>(AlgSize); +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::symmetricStateMaxTagLen(__wasi_symmetric_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "symmetric_state_max_tag_len"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{StateHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_size_t *>(0); +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::symmetricStateEncrypt(__wasi_symmetric_state_t StateHandle, + Span Out, + Span Data) { + writeDummyMemoryContent(); + writeSpan(Out, 0); + uint32_t OutSize = static_cast(Out.size()); + writeSpan(Data, OutSize); + uint32_t DataSize = static_cast(Data.size()); + + auto *Func = getHostFunc(WasiCryptoMod, + "symmetric_state_encrypt"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{ + StateHandle, 0, OutSize, OutSize, DataSize, OutSize + DataSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst.getPointer(0), + MemInst.getPointer(OutSize), Out.begin()); + + return *MemInst.getPointer<__wasi_size_t *>(OutSize + DataSize); +} + +WasiCryptoExpect<__wasi_symmetric_tag_t> +WasiCryptoTest::symmetricStateEncryptDetached( + __wasi_symmetric_state_t StateHandle, Span Out, + Span Data) { + writeDummyMemoryContent(); + writeSpan(Out, 0); + uint32_t OutSize = static_cast(Out.size()); + writeSpan(Data, OutSize); + uint32_t DataSize = static_cast(Data.size()); + + auto *Func = getHostFunc( + WasiCryptoMod, "symmetric_state_encrypt_detached"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{ + StateHandle, 0, OutSize, OutSize, DataSize, OutSize + DataSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst.getPointer(0), + MemInst.getPointer(OutSize), Out.begin()); + + return *MemInst.getPointer<__wasi_symmetric_tag_t *>(OutSize + DataSize); +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::symmetricStateDecrypt(__wasi_symmetric_state_t StateHandle, + Span Out, + Span Data) { + writeDummyMemoryContent(); + writeSpan(Out, 0); + uint32_t OutSize = static_cast(Out.size()); + writeSpan(Data, OutSize); + uint32_t DataSize = static_cast(Data.size()); + + auto *Func = getHostFunc(WasiCryptoMod, + "symmetric_state_decrypt"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{ + StateHandle, 0, OutSize, OutSize, DataSize, OutSize + DataSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst.getPointer(0), + MemInst.getPointer(OutSize), Out.begin()); + + return *MemInst.getPointer<__wasi_size_t *>(OutSize + DataSize); +} + +WasiCryptoExpect<__wasi_size_t> WasiCryptoTest::symmetricStateDecryptDetached( + __wasi_symmetric_state_t StateHandle, Span Out, + Span Data, Span RawTag) { + writeDummyMemoryContent(); + writeSpan(Out, 0); + uint32_t OutSize = static_cast(Out.size()); + writeSpan(Data, OutSize); + uint32_t DataSize = static_cast(Data.size()); + writeSpan(RawTag, OutSize + DataSize); + uint32_t RawTagSize = static_cast(RawTag.size()); + + auto *Func = getHostFunc( + WasiCryptoMod, "symmetric_state_decrypt_detached"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + StateHandle, 0, OutSize, OutSize, DataSize, + OutSize + DataSize, RawTagSize, + OutSize + DataSize + RawTagSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst.getPointer(0), + MemInst.getPointer(OutSize), Out.begin()); + std::copy(MemInst.getPointer(OutSize + DataSize), + MemInst.getPointer(OutSize + DataSize + RawTagSize), + RawTag.begin()); + + return *MemInst.getPointer<__wasi_size_t *>(OutSize + DataSize + RawTagSize); +} + +WasiCryptoExpect +WasiCryptoTest::symmetricStateRatchet(__wasi_symmetric_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoMod, + "symmetric_state_ratchet"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{StateHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::symmetricMaxTagLen(__wasi_symmetric_tag_t TagHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "symmetric_state_max_tag_len"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{TagHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_size_t *>(0); +} + +WasiCryptoExpect<__wasi_size_t> +WasiCryptoTest::symmetricTagPull(__wasi_symmetric_tag_t TagHandle, + Span Buf) { + writeDummyMemoryContent(); + writeSpan(Buf, 0); + uint32_t BufSize = static_cast(Buf.size()); + + auto *Func = + getHostFunc(WasiCryptoMod, "symmetric_tag_pull"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + TagHandle, 0, BufSize, BufSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst.getPointer(0), + MemInst.getPointer(BufSize), Buf.begin()); + + return *MemInst.getPointer<__wasi_size_t *>(BufSize); +} + +WasiCryptoExpect +WasiCryptoTest::symmetricTagVerify(__wasi_symmetric_tag_t TagHandle, + Span RawTag) { + writeDummyMemoryContent(); + writeSpan(RawTag, 0); + uint32_t RawTagSize = static_cast(RawTag.size()); + + auto *Func = + getHostFunc(WasiCryptoMod, "symmetric_tag_verify"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{TagHandle, 0, RawTagSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect +WasiCryptoTest::symmetricTagClose(__wasi_symmetric_tag_t TagHandle) { + writeDummyMemoryContent(); + + auto *Func = + getHostFunc(WasiCryptoMod, "symmetric_tag_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{TagHandle}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_keypair_t> WasiCryptoTest::keypairGenerate( + __wasi_algorithm_type_e_t AlgType, std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeOptOptions(OptOptionsHandle, AlgSize); + + auto *Func = getHostFunc( + WasiCryptoMod, "keypair_generate"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{ + static_cast(AlgType), 0, AlgSize, AlgSize, AlgSize + 8}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_keypair_t *>(AlgSize + 8); +} + +WasiCryptoExpect<__wasi_keypair_t> +WasiCryptoTest::keypairImport(__wasi_algorithm_type_e_t AlgType, + std::string_view Alg, Span Encoded, + __wasi_keypair_encoding_e_t Encoding) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeSpan(Encoded, AlgSize); + uint32_t EncodedSize = static_cast(Encoded.size()); + + auto *Func = getHostFunc(WasiCryptoMod, + "keypair_import"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + static_cast(AlgType), 0, AlgSize, AlgSize, + EncodedSize, static_cast(Encoding), + AlgSize + EncodedSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_keypair_t *>(AlgSize + EncodedSize); +} + +WasiCryptoExpect<__wasi_keypair_t> WasiCryptoTest::keypairGenerateManaged( + __wasi_secrets_manager_t SecretsManagerHandle, + __wasi_algorithm_type_e_t AlgType, std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeOptOptions(OptOptionsHandle, AlgSize); + + auto *Func = getHostFunc( + WasiCryptoMod, "keypair_generate_managed"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE( + Func->run(&MemInst, + std::initializer_list{ + SecretsManagerHandle, static_cast(AlgType), 0, + AlgSize, AlgSize, AlgSize + 8}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_keypair_t *>(AlgSize + 8); +} + +WasiCryptoExpect WasiCryptoTest::keypairStoreManaged( + __wasi_secrets_manager_t SecretsManagerHandle, __wasi_keypair_t KpHandle, + Span KpId) { + writeDummyMemoryContent(); + writeSpan(KpId, 0); + uint32_t KpIdSize = static_cast(KpId.size()); + + auto *Func = getHostFunc( + WasiCryptoMod, "keypair_store_managed"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + SecretsManagerHandle, KpHandle, 0, KpIdSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst.getPointer(0), + MemInst.getPointer(KpIdSize), KpId.begin()); + + return {}; +} + +WasiCryptoExpect<__wasi_version_t> WasiCryptoTest::keypairReplaceManaged( + __wasi_secrets_manager_t SecretsManagerHandle, __wasi_keypair_t OldKpHandle, + __wasi_keypair_t NewKpHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "keypair_replace_managed"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + SecretsManagerHandle, OldKpHandle, NewKpHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_version_t *>(0); +} + +WasiCryptoExpect> +WasiCryptoTest::keypairId(__wasi_keypair_t KpHandle, Span KpId) { + writeDummyMemoryContent(); + writeSpan(KpId, 0); + uint32_t KpIdSize = static_cast(KpId.size()); + + auto *Func = + getHostFunc(WasiCryptoMod, "keypair_id"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + KpHandle, 0, KpIdSize, KpIdSize, KpIdSize + 1}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + std::copy(MemInst.getPointer(0), + MemInst.getPointer(KpIdSize), KpId.begin()); + + return std::make_tuple(*MemInst.getPointer(KpIdSize), + *MemInst.getPointer<__wasi_version_t *>(KpIdSize + 1)); +} + +WasiCryptoExpect<__wasi_keypair_t> +WasiCryptoTest::keypairFromId(__wasi_secrets_manager_t SecretsManagerHandle, + Span KpId, + __wasi_version_t KpIdVersion) { + writeDummyMemoryContent(); + writeSpan(KpId, 0); + uint32_t KpIdSize = static_cast(KpId.size()); + + auto *Func = getHostFunc(WasiCryptoMod, + "keypair_from_id"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE( + Func->run(&MemInst, + std::initializer_list{ + SecretsManagerHandle, 0, KpIdSize, KpIdVersion, KpIdSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_keypair_t *>(KpIdSize); +} + +WasiCryptoExpect<__wasi_keypair_t> +WasiCryptoTest::keypairFromPkAndSk(__wasi_publickey_t PkHandle, + __wasi_secretkey_t SkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "keypair_from_pk_and_sk"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{PkHandle, SkHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_keypair_t *>(0); +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::keypairExport(__wasi_keypair_t KpHandle, + __wasi_keypair_encoding_e_t Encoding) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoMod, + "keypair_export"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + KpHandle, static_cast(Encoding), 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_array_output_t *>(0); +} + +WasiCryptoExpect<__wasi_publickey_t> +WasiCryptoTest::keypairPublickey(__wasi_keypair_t KpHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "keypair_publickey"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{KpHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_publickey_t *>(0); +} + +WasiCryptoExpect<__wasi_secretkey_t> +WasiCryptoTest::keypairSecretkey(__wasi_keypair_t KpHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "keypair_secretkey"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{KpHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_secretkey_t *>(0); +} + +WasiCryptoExpect WasiCryptoTest::keypairClose(__wasi_keypair_t KpHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoMod, + "keypair_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{KpHandle}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_publickey_t> WasiCryptoTest::publickeyImport( + __wasi_algorithm_type_e_t AlgType, std::string_view Alg, + Span Encoded, __wasi_publickey_encoding_e_t Encoding) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeSpan(Encoded, AlgSize); + uint32_t EncodedSize = static_cast(Encoded.size()); + + auto *Func = getHostFunc( + WasiCryptoMod, "publickey_import"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + static_cast(AlgType), 0, AlgSize, AlgSize, + EncodedSize, static_cast(Encoding), + AlgSize + EncodedSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_publickey_t *>(AlgSize + EncodedSize); +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::publickeyExport(__wasi_publickey_t PkHandle, + __wasi_publickey_encoding_e_t Encoding) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "publickey_export"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + PkHandle, static_cast(Encoding), 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_array_output_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::publickeyVerify(__wasi_publickey_t PkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "publickey_verify"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{PkHandle}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_publickey_t> +WasiCryptoTest::publickeyFromSecretkey(__wasi_secretkey_t SkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "publickey_from_secretkey"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{SkHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_publickey_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::publickeyClose(__wasi_publickey_t PkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoMod, + "publickey_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{PkHandle}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_secretkey_t> WasiCryptoTest::secretkeyImport( + __wasi_algorithm_type_e_t AlgType, std::string_view Alg, + Span Encoded, __wasi_secretkey_encoding_e_t Encoding) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeSpan(Encoded, AlgSize); + uint32_t EncodedSize = static_cast(Encoded.size()); + + auto *Func = getHostFunc( + WasiCryptoMod, "secretkey_import"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + static_cast(AlgType), 0, AlgSize, AlgSize, + EncodedSize, static_cast(Encoding), + AlgSize + EncodedSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_secretkey_t *>(AlgSize + EncodedSize); +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::secretkeyExport(__wasi_secretkey_t SkHandle, + __wasi_secretkey_encoding_e_t Encoding) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "secretkey_export"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + SkHandle, static_cast(Encoding), 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_publickey_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::secretkeyClose(__wasi_secretkey_t SkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoMod, + "secretkey_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{SkHandle}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::kxDh(__wasi_kx_publickey_t PkHandle, + __wasi_kx_secretkey_t SkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoMod, "kx_dh"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{PkHandle, SkHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_array_output_t *>(0); +} + +WasiCryptoExpect> +WasiCryptoTest::kxEncapsulate(__wasi_kx_publickey_t PkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoMod, "kx_encapsulate"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{PkHandle, 0, 1}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return std::make_tuple(*MemInst.getPointer<__wasi_array_output_t *>(0), + *MemInst.getPointer<__wasi_array_output_t *>(1)); +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::kxDecapsulate(__wasi_kx_secretkey_t SkHandle, + Span EncapsulatedSecret) { + writeDummyMemoryContent(); + writeSpan(EncapsulatedSecret, 0); + uint32_t EncapsulatedSecretSize = + static_cast(EncapsulatedSecret.size()); + + auto *Func = getHostFunc(WasiCryptoMod, "kx_decapsulate"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{ + SkHandle, 0, EncapsulatedSecretSize, EncapsulatedSecretSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_array_output_t *>(EncapsulatedSecretSize); +} + +WasiCryptoExpect<__wasi_array_output_t> +WasiCryptoTest::signatureExport(__wasi_signature_t SigHandle, + __wasi_signature_encoding_e_t Encoding) { + writeDummyMemoryContent(); + + auto *Func = + getHostFunc(WasiCryptoMod, "signature_export"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run(&MemInst, + std::initializer_list{ + SigHandle, static_cast(Encoding), 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_array_output_t *>(0); +} + +WasiCryptoExpect<__wasi_signature_t> +WasiCryptoTest::signatureImport(std::string_view Alg, + Span Encoded, + __wasi_signature_encoding_e_t Encoding) { + writeDummyMemoryContent(); + writeString(Alg, 0); + uint32_t AlgSize = static_cast(Alg.size()); + writeSpan(Encoded, AlgSize); + uint32_t EncodedSize = static_cast(Encoded.size()); + + auto *Func = + getHostFunc(WasiCryptoMod, "signature_import"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE( + Func->run(&MemInst, + std::initializer_list{ + 0, AlgSize, AlgSize, EncodedSize, + static_cast(Encoding), AlgSize + EncodedSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_signature_t *>(AlgSize + EncodedSize); +} + +WasiCryptoExpect +WasiCryptoTest::signatureClose(__wasi_signature_t SigHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoMod, "signature_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{SigHandle}, Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_signature_state_t> +WasiCryptoTest::signatureStateOpen(__wasi_signature_keypair_t KpHandle) { + writeDummyMemoryContent(); + + auto *Func = + getHostFunc(WasiCryptoMod, "signature_state_open"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{KpHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_signature_state_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::signatureStateUpdate(__wasi_signature_state_t StateHandle, + Span Input) { + writeDummyMemoryContent(); + writeSpan(Input, 0); + uint32_t InputSize = static_cast(Input.size()); + + auto *Func = getHostFunc(WasiCryptoMod, + "signature_state_update"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{StateHandle, 0, InputSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_signature_t> +WasiCryptoTest::signatureStateSign(__wasi_signature_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = + getHostFunc(WasiCryptoMod, "signature_state_sign"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{StateHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_signature_t *>(0); +} + +WasiCryptoExpect +WasiCryptoTest::signatureStateClose(__wasi_signature_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc(WasiCryptoMod, + "signature_state_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{StateHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect<__wasi_signature_verification_state_t> +WasiCryptoTest::signatureVerificationStateOpen( + __wasi_signature_publickey_t PkHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "signature_verification_state_open"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{PkHandle, 0}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return *MemInst.getPointer<__wasi_signature_verification_state_t *>(0); +} + +WasiCryptoExpect WasiCryptoTest::signatureVerificationStateUpdate( + __wasi_signature_verification_state_t StateHandle, + Span Input) { + writeDummyMemoryContent(); + writeSpan(Input, 0); + uint32_t InputSize = static_cast(Input.size()); + + auto *Func = getHostFunc( + WasiCryptoMod, "signature_verification_state_update"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{StateHandle, 0, InputSize}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect WasiCryptoTest::signatureVerificationStateVerify( + __wasi_signature_verification_state_t StateHandle, + __wasi_signature_t SigHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "signature_verification_state_verify"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, + std::initializer_list{StateHandle, SigHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +WasiCryptoExpect WasiCryptoTest::signatureVerificationStateClose( + __wasi_signature_verification_state_t StateHandle) { + writeDummyMemoryContent(); + + auto *Func = getHostFunc( + WasiCryptoMod, "signature_verification_state_close"); + EXPECT_NE(Func, nullptr); + EXPECT_TRUE(Func->run( + &MemInst, std::initializer_list{StateHandle}, + Errno)); + ensureOrReturnOnTest(Errno[0].get()); + + return {}; +} + +// WasiCryptoExpect<__wasi_secretkey_t> WasiCryptoTest::secretkeyImport( +// __wasi_algorithm_type_e_t AlgType, std::string_view AlgStr, +// Span Encoded, __wasi_secretkey_encoding_e_t Encoding) { +// writeString(AlgStr, 0); +// writeSpan(Encoded, AlgStr.size()); +// auto Res = +// testRun( +// {static_cast(AlgType), 0, AlgStr.size(), AlgStr.size(), +// Encoded.size(), static_cast(Encoding), +// AlgStr.size() + Encoded.size()}) +// .value(); +// if (Res != __WASI_CRYPTO_ERRNO_SUCCESS) { +// return WasiCryptoUnexpect(Res); +// } +// return *MemInst.getPointer<__wasi_signature_keypair_t *>(AlgStr.size() + +// Encoded.size()); +// } + +// WasiCryptoExpect<__wasi_array_output_t> +// WasiCryptoTest::secretkeyExport(__wasi_secretkey_t SkHandle, +// __wasi_secretkey_encoding_e_t SkEncoding) { +// auto Res = testRun( +// {SkHandle, static_cast(SkEncoding), 0}) +// .value(); +// if (Res != __WASI_CRYPTO_ERRNO_SUCCESS) { +// return WasiCryptoUnexpect(Res); +// } +// return *MemInst.getPointer<__wasi_signature_keypair_t *>(0); +// } + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/helper.h b/test/plugins/wasi_crypto/helper.h new file mode 100644 index 00000000..a2f0b336 --- /dev/null +++ b/test/plugins/wasi_crypto/helper.h @@ -0,0 +1,361 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "asymmetric_common/func.h" +#include "common/func.h" +#include "ctx.h" +#include "helper.h" +#include "module.h" +#include "utils/error.h" + +#include "common/span.h" +#include "common/types.h" +#include "runtime/instance/memory.h" +#include "gtest/gtest.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define WASI_CRYPTO_EXPECT_SUCCESS(Expr, Function) \ + auto &&Expr##__Result = Function; \ + EXPECT_TRUE(Expr##__Result) \ + << "Wasi Crypto Error code: " << Errno[0].get(); \ + auto &&Expr = Expr##__Result.value() + +#define WASI_CRYPTO_EXPECT_FAILURE(Function, ErrorCode) \ + do { \ + auto Result = Function; \ + EXPECT_FALSE(Result) << "The function result should be error but success"; \ + EXPECT_EQ(Result.error(), ErrorCode); \ + } while (0) + +#define WASI_CRYPTO_EXPECT_TRUE(Function) \ + EXPECT_TRUE(Function) << "Wasi Crypto Error code: " << Errno[0].get() + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { + +std::vector operator"" _u8(const char *Str, std::size_t Len); + +std::vector operator"" _u8v(const char *Str, std::size_t Len); + +/// Designed for testing. +class WasiCryptoTest : public ::testing::Test { +public: + WasiCryptoTest() { + using namespace std::literals::string_view_literals; + Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasi_crypto/" + "libwasmedgePluginWasiCrypto" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_crypto"sv)) { + if (const auto *Module = Plugin->findModule("wasi_crypto"sv)) { + WasiCryptoMod = dynamic_cast( + Module->create().release()); + } + } + } + + ~WasiCryptoTest() override { + if (WasiCryptoMod) { + delete WasiCryptoMod; + } + } + +protected: + void writeDummyMemoryContent(); + + void writeString(std::string_view String, uint32_t Ptr); + + void writeSpan(Span Content, uint32_t Ptr); + + void writeOptKey(std::optional OptKey, uint32_t Ptr); + + void writeOptOptions(std::optional<__wasi_options_t> OptOptions, + uint32_t Ptr); + + // Common + + WasiCryptoExpect<__wasi_size_t> + arrayOutputLen(__wasi_array_output_t ArrayOutputHandle); + + WasiCryptoExpect<__wasi_size_t> + arrayOutputPull(__wasi_array_output_t ArrayOutputHandle, Span Buf); + + WasiCryptoExpect<__wasi_options_t> + optionsOpen(__wasi_algorithm_type_e_t AlgType); + + WasiCryptoExpect optionsClose(__wasi_options_t OptionsHandle); + + WasiCryptoExpect optionsSet(__wasi_options_t OptionsHandle, + std::string_view Name, + Span Value); + + WasiCryptoExpect optionsSetU64(__wasi_options_t OptionsHandle, + std::string_view Name, uint64_t Value); + + // Not supported, Buf placing must be on page. + // WasiCryptoExpect + // optionsSetGuestBuffer(__wasi_options_t OptionsHandle, + // std::string_view Name, Span Buf); + + WasiCryptoExpect<__wasi_secrets_manager_t> + secretsManagerOpen(std::optional<__wasi_options_t> OptOptionsHandle); + + WasiCryptoExpect + secretsManagerClose(__wasi_secrets_manager_t SecretsManagerHandle); + + WasiCryptoExpect + secretsManagerInvalidate(__wasi_secrets_manager_t SecretsManagerHandle, + Span KeyId, __wasi_version_t Version); + + // Symmetric + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyGenerate(std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle); + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyImport(std::string_view Alg, Span Raw); + + WasiCryptoExpect<__wasi_array_output_t> + symmetricKeyExport(__wasi_symmetric_key_t KeyHandle); + + WasiCryptoExpect symmetricKeyClose(__wasi_symmetric_key_t KeyHandle); + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyGenerateManaged(__wasi_secrets_manager_t SecretsManagerHandle, + std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle); + + WasiCryptoExpect + symmetricKeyStoreManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_symmetric_key_t KeyHandle, + Span KeyId); + + WasiCryptoExpect<__wasi_version_t> + symmetricKeyReplaceManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_symmetric_key_t OldKeyHandle, + __wasi_symmetric_key_t NewKeyHandle); + + WasiCryptoExpect> + symmetricKeyId(__wasi_symmetric_key_t KeyHandle, Span KeyId); + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricKeyFromId(__wasi_secrets_manager_t SecretsManagerHandle, + Span KeyId, __wasi_version_t KeyVersion); + + WasiCryptoExpect<__wasi_symmetric_state_t> + symmetricStateOpen(std::string_view Alg, + std::optional<__wasi_symmetric_key_t> OptKeyHandle, + std::optional<__wasi_options_t> OptOptionsHandle); + + WasiCryptoExpect<__wasi_size_t> + symmetricStateOptionsGet(__wasi_symmetric_state_t StateHandle, + std::string_view Name, Span Value); + + WasiCryptoExpect + symmetricStateOptionsGetU64(__wasi_symmetric_state_t StateHandle, + std::string_view Name); + + WasiCryptoExpect + symmetricStateClose(__wasi_symmetric_state_t StateHandle); + + WasiCryptoExpect + symmetricStateAbsorb(__wasi_symmetric_state_t StateHandle, + Span Data); + + WasiCryptoExpect<__wasi_symmetric_state_t> + symmetricStateClone(__wasi_symmetric_state_t StateHandle); + + WasiCryptoExpect + symmetricStateSqueeze(__wasi_symmetric_state_t StateHandle, + Span Out); + + WasiCryptoExpect<__wasi_symmetric_tag_t> + symmetricStateSqueezeTag(__wasi_symmetric_state_t StateHandle); + + WasiCryptoExpect<__wasi_symmetric_key_t> + symmetricStateSqueezeKey(__wasi_symmetric_state_t StateHandle, + std::string_view Alg); + + WasiCryptoExpect<__wasi_size_t> + symmetricStateMaxTagLen(__wasi_symmetric_state_t StateHandle); + + WasiCryptoExpect<__wasi_size_t> + symmetricStateEncrypt(__wasi_symmetric_state_t StateHandle, Span Out, + Span Data); + + WasiCryptoExpect<__wasi_symmetric_tag_t> + symmetricStateEncryptDetached(__wasi_symmetric_state_t StateHandle, + Span Out, Span Data); + + WasiCryptoExpect<__wasi_size_t> + symmetricStateDecrypt(__wasi_symmetric_state_t StateHandle, Span Out, + Span Data); + + WasiCryptoExpect<__wasi_size_t> + symmetricStateDecryptDetached(__wasi_symmetric_state_t StateHandle, + Span Out, Span Data, + Span RawTag); + + WasiCryptoExpect + symmetricStateRatchet(__wasi_symmetric_state_t StateHandle); + + WasiCryptoExpect<__wasi_size_t> + symmetricMaxTagLen(__wasi_symmetric_tag_t TagHandle); + + WasiCryptoExpect<__wasi_size_t> + symmetricTagPull(__wasi_symmetric_tag_t TagHandle, Span Buf); + + WasiCryptoExpect symmetricTagVerify(__wasi_symmetric_tag_t TagHandle, + Span RawTag); + + WasiCryptoExpect symmetricTagClose(__wasi_symmetric_tag_t TagHandle); + + // Asymmetric_common + + WasiCryptoExpect<__wasi_keypair_t> + keypairGenerate(__wasi_algorithm_type_e_t AlgType, std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle); + + WasiCryptoExpect<__wasi_keypair_t> + keypairImport(__wasi_algorithm_type_e_t AlgType, std::string_view Alg, + Span Encoded, + __wasi_keypair_encoding_e_t Encoding); + + WasiCryptoExpect<__wasi_keypair_t> + keypairGenerateManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_algorithm_type_e_t AlgType, + std::string_view Alg, + std::optional<__wasi_options_t> OptOptionsHandle); + + WasiCryptoExpect + keypairStoreManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_keypair_t KpHandle, Span KpId); + + WasiCryptoExpect<__wasi_version_t> + keypairReplaceManaged(__wasi_secrets_manager_t SecretsManagerHandle, + __wasi_keypair_t OldKpHandle, + __wasi_keypair_t NewKpHandle); + + WasiCryptoExpect> + keypairId(__wasi_keypair_t KpHandle, Span KpId); + + WasiCryptoExpect<__wasi_keypair_t> + keypairFromId(__wasi_secrets_manager_t SecretsManagerHandle, + Span KpId, __wasi_version_t KpIdVersion); + + WasiCryptoExpect<__wasi_keypair_t> + keypairFromPkAndSk(__wasi_publickey_t PkHandle, __wasi_secretkey_t SkHandle); + + WasiCryptoExpect<__wasi_array_output_t> + keypairExport(__wasi_keypair_t KpHandle, + __wasi_keypair_encoding_e_t Encoding); + + WasiCryptoExpect<__wasi_publickey_t> + keypairPublickey(__wasi_keypair_t KpHandle); + + WasiCryptoExpect<__wasi_secretkey_t> + keypairSecretkey(__wasi_keypair_t KpHandle); + + WasiCryptoExpect keypairClose(__wasi_keypair_t KpHandle); + + WasiCryptoExpect<__wasi_publickey_t> + publickeyImport(__wasi_algorithm_type_e_t AlgType, std::string_view Alg, + Span Encoded, + __wasi_publickey_encoding_e_t Encoding); + + WasiCryptoExpect<__wasi_array_output_t> + publickeyExport(__wasi_publickey_t PkHandle, + __wasi_publickey_encoding_e_t Encoding); + + WasiCryptoExpect publickeyVerify(__wasi_publickey_t PkHandle); + + WasiCryptoExpect<__wasi_publickey_t> + publickeyFromSecretkey(__wasi_secretkey_t SkHandle); + + WasiCryptoExpect publickeyClose(__wasi_publickey_t PkHandle); + + WasiCryptoExpect<__wasi_secretkey_t> + secretkeyImport(__wasi_algorithm_type_e_t AlgType, std::string_view Alg, + Span Encoded, + __wasi_secretkey_encoding_e_t Encoding); + + WasiCryptoExpect<__wasi_array_output_t> + secretkeyExport(__wasi_secretkey_t SkHandle, + __wasi_secretkey_encoding_e_t Encoding); + + WasiCryptoExpect secretkeyClose(__wasi_secretkey_t SkHandle); + + // Key_exchange + + WasiCryptoExpect<__wasi_array_output_t> kxDh(__wasi_kx_publickey_t PkHandle, + __wasi_kx_secretkey_t SkHandle); + + WasiCryptoExpect> + kxEncapsulate(__wasi_kx_publickey_t PkHandle); + + WasiCryptoExpect<__wasi_array_output_t> + kxDecapsulate(__wasi_kx_secretkey_t SkHandle, + Span EncapsulatedSecret); + + // Signature + + WasiCryptoExpect<__wasi_array_output_t> + signatureExport(__wasi_signature_t SigHandle, + __wasi_signature_encoding_e_t Encoding); + + WasiCryptoExpect<__wasi_signature_t> + signatureImport(std::string_view Alg, Span Encoded, + __wasi_signature_encoding_e_t Encoding); + + WasiCryptoExpect signatureClose(__wasi_signature_t SigHandle); + + WasiCryptoExpect<__wasi_signature_state_t> + signatureStateOpen(__wasi_signature_keypair_t KpHandle); + + WasiCryptoExpect + signatureStateUpdate(__wasi_signature_state_t StateHandle, + Span Input); + + WasiCryptoExpect<__wasi_signature_t> + signatureStateSign(__wasi_signature_state_t StateHandle); + + WasiCryptoExpect + signatureStateClose(__wasi_signature_state_t StateHandle); + + WasiCryptoExpect<__wasi_signature_verification_state_t> + signatureVerificationStateOpen(__wasi_signature_publickey_t PkHandle); + + WasiCryptoExpect signatureVerificationStateUpdate( + __wasi_signature_verification_state_t StateHandle, + Span Input); + + WasiCryptoExpect signatureVerificationStateVerify( + __wasi_signature_verification_state_t StateHandle, + __wasi_signature_t SigHandle); + + WasiCryptoExpect signatureVerificationStateClose( + __wasi_signature_verification_state_t StateHandle); + + int32_t InvaildHandle = 9999; + WasmEdge::Runtime::Instance::MemoryInstance MemInst{ + WasmEdge::AST::MemoryType(1)}; + std::array Errno; + + Host::WasiCryptoModule *WasiCryptoMod = nullptr; +}; + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/kdf.cpp b/test/plugins/wasi_crypto/kdf.cpp new file mode 100644 index 00000000..6e3fa5ed --- /dev/null +++ b/test/plugins/wasi_crypto/kdf.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Kdf) { + auto KdfTest = [this](std::string_view ExtractAlg, std::string_view ExpandAlg, + const std::vector &Key, + const std::vector &Salt, + const std::vector &Info, size_t KeySize) { + WASI_CRYPTO_EXPECT_SUCCESS(KeyHandle, symmetricKeyImport(ExtractAlg, Key)); + WASI_CRYPTO_EXPECT_SUCCESS( + ExtractStateHandle, + symmetricStateOpen(ExtractAlg, KeyHandle, std::nullopt)); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(ExtractStateHandle, Salt)); + WASI_CRYPTO_EXPECT_SUCCESS( + PrkHandle, symmetricStateSqueezeKey(ExtractStateHandle, ExpandAlg)); + WASI_CRYPTO_EXPECT_TRUE(symmetricKeyClose(KeyHandle)); + WASI_CRYPTO_EXPECT_SUCCESS( + ExpandStateHandle, + symmetricStateOpen(ExpandAlg, PrkHandle, std::nullopt)); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(ExpandStateHandle, Info)); + std::vector SqueezeKey(KeySize); + WASI_CRYPTO_EXPECT_TRUE( + symmetricStateSqueeze(ExpandStateHandle, SqueezeKey)); + + auto BothInvalid = [this](std::string_view Name, + __wasi_symmetric_state_t StateHandle) { + EXPECT_TRUE( + symmetricStateOpen(Name, InvaildHandle, std::nullopt).error() == + __WASI_CRYPTO_ERRNO_INVALID_HANDLE); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOptionsGet(StateHandle, "foo"sv, {}), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOptionsGetU64(StateHandle, "foo"sv), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueezeTag(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateMaxTagLen(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateEncrypt(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateEncryptDetached(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateDecrypt(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateDecryptDetached(StateHandle, {}, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateRatchet(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + + // Clone checking. + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateClone(StateHandle), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + }; + BothInvalid(ExpandAlg, ExtractStateHandle); + BothInvalid(ExtractAlg, ExpandStateHandle); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueeze(ExtractStateHandle, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateSqueezeKey(ExpandStateHandle, ExpandAlg), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(ExtractStateHandle)); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(ExpandStateHandle)); + + WASI_CRYPTO_EXPECT_SUCCESS(NewKeyHandle, + symmetricKeyGenerate(ExtractAlg, std::nullopt)); + WASI_CRYPTO_EXPECT_TRUE(symmetricKeyClose(NewKeyHandle)); + WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyGenerate(ExpandAlg, std::nullopt), + __WASI_CRYPTO_ERRNO_UNSUPPORTED_FEATURE); + }; + KdfTest("HKDF-EXTRACT/SHA-256"sv, "HKDF-EXPAND/SHA-256"sv, "IKM"_u8, + "salt"_u8, "info"_u8, 32); + KdfTest("HKDF-EXTRACT/SHA-512"sv, "HKDF-EXPAND/SHA-512"sv, "IKM"_u8, + "salt"_u8, "info"_u8, 64); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/kx.cpp b/test/plugins/wasi_crypto/kx.cpp new file mode 100644 index 00000000..d3b0d80c --- /dev/null +++ b/test/plugins/wasi_crypto/kx.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, KxDh) { + + auto KxDhTest = [this](std::string_view Alg, const std::vector &Pk1, + const std::vector &Sk1, + const std::vector &Pk2, + const std::vector &Sk2, + const std::vector &SharedSecret) { + SCOPED_TRACE(Alg); + WASI_CRYPTO_EXPECT_SUCCESS( + Pk1Handle, publickeyImport(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE, Alg, Pk1, + __WASI_PUBLICKEY_ENCODING_RAW)); + WASI_CRYPTO_EXPECT_SUCCESS( + Sk1Handle, secretkeyImport(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE, Alg, Sk1, + __WASI_SECRETKEY_ENCODING_RAW)); + WASI_CRYPTO_EXPECT_SUCCESS( + Pk2Handle, publickeyImport(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE, Alg, Pk2, + __WASI_PUBLICKEY_ENCODING_RAW)); + WASI_CRYPTO_EXPECT_SUCCESS( + Sk2Handle, secretkeyImport(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE, Alg, Sk2, + __WASI_SECRETKEY_ENCODING_RAW)); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey1Handle, kxDh(Pk1Handle, Sk2Handle)); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey1Size, + arrayOutputLen(SharedKey1Handle)); + EXPECT_EQ(SharedKey1Size, 32); + std::vector SharedKey1(32); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey1PullSize, + arrayOutputPull(SharedKey1Handle, SharedKey1)); + EXPECT_EQ(SharedKey1PullSize, 32); + EXPECT_EQ(SharedKey1, SharedSecret); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey2Handle, kxDh(Pk2Handle, Sk1Handle)); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey2Size, + arrayOutputLen(SharedKey2Handle)); + EXPECT_EQ(SharedKey2Size, 32); + std::vector SharedKey2(32); + WASI_CRYPTO_EXPECT_TRUE(arrayOutputPull(SharedKey2Handle, SharedKey2)); + EXPECT_EQ(SharedKey2, SharedSecret); + + /// It's only supported in OpenSSL 3.0. + /// See: https://github.com/openssl/openssl/issues/7616 + WASI_CRYPTO_EXPECT_FAILURE(kxEncapsulate(Pk1Handle), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(kxDecapsulate(Sk1Handle, {}), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_TRUE(publickeyClose(Pk1Handle)); + WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(Sk2Handle)); + + WASI_CRYPTO_EXPECT_TRUE(publickeyClose(Pk2Handle)); + WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(Sk1Handle)); + }; + + // From: https://datatracker.ietf.org/doc/html/rfc7748#section-6.1 + KxDhTest( + "X25519"sv, + "8520f0098930a754748b7ddcb43ef75a0dbf3a0d26381af4eba4a98eaa9b4e6a"_u8v, + "77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a"_u8v, + "de9edb7d7b7dc1b4d35b61c2ece435373f8343c85b78674dadfc7e146f882b4f"_u8v, + "5dab087e624a8a4b79e17f8b83800ee66f3bb1292618b6fd1c2f8b27ff88e0eb"_u8v, + "4a5d9d5ba4ce2de1728e3bf480350f25e07e21c947d19e3376f09b3c1e161742"_u8v); + + auto NewKxDhTest = [this](std::string_view Alg) { + SCOPED_TRACE(Alg); + + WASI_CRYPTO_EXPECT_SUCCESS( + Kp1Handle, + keypairGenerate(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE, Alg, std::nullopt)); + WASI_CRYPTO_EXPECT_SUCCESS( + Kp2Handle, + keypairGenerate(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE, Alg, std::nullopt)); + + WASI_CRYPTO_EXPECT_SUCCESS(Pk1Handle, keypairPublickey(Kp1Handle)); + WASI_CRYPTO_EXPECT_SUCCESS(Sk1Handle, keypairSecretkey(Kp1Handle)); + WASI_CRYPTO_EXPECT_SUCCESS(Pk2Handle, keypairPublickey(Kp2Handle)); + WASI_CRYPTO_EXPECT_SUCCESS(Sk2Handle, keypairSecretkey(Kp2Handle)); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey1Handle, kxDh(Pk1Handle, Sk2Handle)); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey1Size, + arrayOutputLen(SharedKey1Handle)); + EXPECT_EQ(SharedKey1Size, 32); + std::vector SharedKey1(32); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey1PullSize, + arrayOutputPull(SharedKey1Handle, SharedKey1)); + EXPECT_EQ(SharedKey1PullSize, 32); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey2Handle, kxDh(Pk2Handle, Sk1Handle)); + + WASI_CRYPTO_EXPECT_SUCCESS(SharedKey2Size, + arrayOutputLen(SharedKey2Handle)); + EXPECT_EQ(SharedKey2Size, 32); + std::vector SharedKey2(32); + WASI_CRYPTO_EXPECT_TRUE(arrayOutputPull(SharedKey2Handle, SharedKey2)); + + EXPECT_EQ(SharedKey1, SharedKey2); + + /// It's only supported in OpenSSL 3.0. + /// See: https://github.com/openssl/openssl/issues/7616 + WASI_CRYPTO_EXPECT_FAILURE(kxEncapsulate(Pk1Handle), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(kxDecapsulate(Sk1Handle, {}), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_TRUE(publickeyClose(Pk1Handle)); + WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(Sk2Handle)); + + WASI_CRYPTO_EXPECT_TRUE(publickeyClose(Pk2Handle)); + WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(Sk1Handle)); + }; + NewKxDhTest("P256-SHA256"sv); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/mac.cpp b/test/plugins/wasi_crypto/mac.cpp new file mode 100644 index 00000000..2e8521fc --- /dev/null +++ b/test/plugins/wasi_crypto/mac.cpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Mac) { + auto MacTest = [this](std::string_view Name, + const std::vector &ImportKey, + const std::vector &AbsorbData1, + const std::vector &AbsorbData2, + const std::vector &ExpectedTag1, + const std::vector &ExpectedTag2) { + SCOPED_TRACE(Name); + // Generate key hmac. + { + WASI_CRYPTO_EXPECT_SUCCESS(KeyHandle, + symmetricKeyGenerate(Name, std::nullopt)); + + // Key size checking. + WASI_CRYPTO_EXPECT_SUCCESS(KeyOutputHandle, + symmetricKeyExport(KeyHandle)); + WASI_CRYPTO_EXPECT_SUCCESS(KeySize, arrayOutputLen(KeyOutputHandle)); + EXPECT_EQ(KeySize, ImportKey.size()); + + WASI_CRYPTO_EXPECT_SUCCESS( + StateHandle, symmetricStateOpen(Name, KeyHandle, std::nullopt)); + WASI_CRYPTO_EXPECT_TRUE(symmetricKeyClose(KeyHandle)); + + // Equivalent to a single call. + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(StateHandle, AbsorbData1)); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(StateHandle, AbsorbData2)); + + WASI_CRYPTO_EXPECT_SUCCESS(TagHandle, + symmetricStateSqueezeTag(StateHandle)); + std::vector Tag(ExpectedTag1.size()); + + WASI_CRYPTO_EXPECT_SUCCESS(TagPullSize, symmetricTagPull(TagHandle, Tag)); + EXPECT_EQ(TagPullSize, ExpectedTag1.size()); + WASI_CRYPTO_EXPECT_TRUE(symmetricTagClose(TagHandle)); + + WASI_CRYPTO_EXPECT_SUCCESS(NewTagHandle, + symmetricStateSqueezeTag(StateHandle)); + + WASI_CRYPTO_EXPECT_TRUE(symmetricTagVerify(NewTagHandle, Tag)); + WASI_CRYPTO_EXPECT_TRUE(symmetricTagClose(NewTagHandle)); + + // Error case checking. + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateOpen(Name, std::nullopt, std::nullopt), + __WASI_CRYPTO_ERRNO_KEY_REQUIRED); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueeze(StateHandle, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateSqueezeKey(StateHandle, Name), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateMaxTagLen(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateEncrypt(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateEncryptDetached(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateDecrypt(StateHandle, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE( + symmetricStateDecryptDetached(StateHandle, {}, {}, {}), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateRatchet(StateHandle), + __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + + // Clone checking. + WASI_CRYPTO_EXPECT_SUCCESS(NewStateHandle, + symmetricStateClone(StateHandle)); + EXPECT_NE(StateHandle, NewStateHandle); + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(NewStateHandle)); + + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(StateHandle)); + } + + // Import key hmac. + { + WASI_CRYPTO_EXPECT_SUCCESS(KeyHandle, + symmetricKeyImport(Name, ImportKey)); + WASI_CRYPTO_EXPECT_SUCCESS( + StateHandle, symmetricStateOpen(Name, KeyHandle, std::nullopt)); + WASI_CRYPTO_EXPECT_TRUE(symmetricKeyClose(KeyHandle)); + { + // Absorb "data". + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(StateHandle, AbsorbData1)); + + // SqueezeTag "data". + WASI_CRYPTO_EXPECT_SUCCESS(TagHandle, + symmetricStateSqueezeTag(StateHandle)); + std::vector Tag(ExpectedTag1.size()); + WASI_CRYPTO_EXPECT_SUCCESS(TagPullSize, + symmetricTagPull(TagHandle, Tag)); + EXPECT_EQ(TagPullSize, Tag.size()); + EXPECT_EQ(Tag, ExpectedTag1); + } + + { + // Abosorb "more_data". + WASI_CRYPTO_EXPECT_TRUE(symmetricStateAbsorb(StateHandle, AbsorbData2)); + + // SqueezeTag "datamore_data". + WASI_CRYPTO_EXPECT_SUCCESS(TagHandle, + symmetricStateSqueezeTag(StateHandle)); + std::vector Tag(ExpectedTag2.size()); + WASI_CRYPTO_EXPECT_SUCCESS(TagPullSize, + symmetricTagPull(TagHandle, Tag)); + EXPECT_EQ(TagPullSize, Tag.size()); + EXPECT_EQ(Tag, ExpectedTag2); + } + WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(StateHandle)); + } + }; + MacTest( + "HMAC/SHA-256"sv, "00000000000000000000000000000000"_u8, "data"_u8, + "more_data"_u8, + "7f12a3d914ec4d1ee67dd35ff04df5a725d11a6bb78a4aafd1093f5bfbd86887"_u8v, + "77af4875ffb3932cba0c8bc5da18410c42c85eeb07072918629675e054fbc42d"_u8v); + MacTest( + "HMAC/SHA-512"sv, + "0000000000000000000000000000000000000000000000000000000000000000"_u8, + "data"_u8, "more_data"_u8, + "52fbafda16189e63730604e49c747c8281d2420e7aae34c927927e7c3cddfcea62fea554d1962a0c0d1c8177884787d8b2a88bd396d5780e3fb82b11ab33c5cc"_u8v, + "36d2dbfb50768b963fe243535bcda302750297b361b7eb079978b27177adc40338dab5c244ae90e2f11a3518ac31126a52eb5ec715c0a9476b98f73e7ff7682e"_u8v); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/notimplement.cpp b/test/plugins/wasi_crypto/notimplement.cpp new file mode 100644 index 00000000..b87c0b06 --- /dev/null +++ b/test/plugins/wasi_crypto/notimplement.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "helper.h" + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, NotImplement) { + WASI_CRYPTO_EXPECT_FAILURE( + symmetricKeyGenerateManaged(1, "SHA-256"sv, std::nullopt), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyStoreManaged(1, 1, {}), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyReplaceManaged(1, 1, 1), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyId(1, {}), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyFromId(1, {}, 1), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + + EXPECT_EQ(keypairGenerateManaged(1, __WASI_ALGORITHM_TYPE_SIGNATURES, + "Ed25519"sv, std::nullopt) + .error(), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(keypairStoreManaged(1, 1, {}), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(keypairReplaceManaged(1, 1, 1), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(keypairId(1, {}), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(keypairFromId(1, {}, 1), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + + WASI_CRYPTO_EXPECT_FAILURE(secretsManagerOpen(std::nullopt), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(secretsManagerClose(InvaildHandle), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE(secretsManagerInvalidate(InvaildHandle, {}, 0), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/signatures.cpp b/test/plugins/wasi_crypto/signatures.cpp new file mode 100644 index 00000000..f59ff6fd --- /dev/null +++ b/test/plugins/wasi_crypto/signatures.cpp @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "helper.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +using namespace std::literals; + +TEST_F(WasiCryptoTest, Signatures) { + // Use the generated data to sign and verfiy. + auto SigTest = [this](__wasi_algorithm_type_e_t AlgType, + std::string_view Alg) { + SCOPED_TRACE(Alg); + WASI_CRYPTO_EXPECT_SUCCESS(KpHandle, + keypairGenerate(AlgType, Alg, std::nullopt)); + WASI_CRYPTO_EXPECT_SUCCESS(StateHandle, signatureStateOpen(KpHandle)); + WASI_CRYPTO_EXPECT_TRUE(signatureStateUpdate(StateHandle, "test"_u8)); + WASI_CRYPTO_EXPECT_TRUE(signatureStateUpdate(StateHandle, "test"_u8)); + WASI_CRYPTO_EXPECT_SUCCESS(SigHandle, signatureStateSign(StateHandle)); + WASI_CRYPTO_EXPECT_TRUE(signatureStateClose(StateHandle)); + + WASI_CRYPTO_EXPECT_SUCCESS(PkHandle, keypairPublickey(KpHandle)); + WASI_CRYPTO_EXPECT_SUCCESS(VerifictionStateHandle, + signatureVerificationStateOpen(PkHandle)); + WASI_CRYPTO_EXPECT_TRUE( + signatureVerificationStateUpdate(VerifictionStateHandle, "test"_u8)); + WASI_CRYPTO_EXPECT_TRUE( + signatureVerificationStateUpdate(VerifictionStateHandle, "test"_u8)); + WASI_CRYPTO_EXPECT_TRUE( + signatureVerificationStateVerify(VerifictionStateHandle, SigHandle)); + WASI_CRYPTO_EXPECT_TRUE( + signatureVerificationStateClose(VerifictionStateHandle)); + }; + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "ECDSA_P256_SHA256"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "ECDSA_K256_SHA256"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "Ed25519"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_2048_SHA256"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_2048_SHA384"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_2048_SHA512"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_3072_SHA384"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_3072_SHA512"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_4096_SHA512"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_2048_SHA256"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_2048_SHA384"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_2048_SHA512"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_3072_SHA384"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_3072_SHA512"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_4096_SHA512"sv); + + auto SigEncodingTest = + [this]( + std::string_view Alg, + std::map<__wasi_signature_encoding_e_t, std::vector> Data) { + SCOPED_TRACE(Alg); + for (auto &[Encoding, Sig] : Data) { + SCOPED_TRACE(Encoding); + WASI_CRYPTO_EXPECT_SUCCESS(SigHandle, + signatureImport(Alg, Sig, Encoding)); + WASI_CRYPTO_EXPECT_SUCCESS(ExportSigHandle, + signatureExport(SigHandle, Encoding)); + WASI_CRYPTO_EXPECT_SUCCESS(ExportSigSize, + arrayOutputLen(ExportSigHandle)); + std::vector ExportSig(ExportSigSize); + WASI_CRYPTO_EXPECT_TRUE(arrayOutputPull(ExportSigHandle, ExportSig)); + EXPECT_EQ(Sig, ExportSig); + } + }; + SigEncodingTest( + "ECDSA_K256_SHA256"sv, + {{__WASI_SIGNATURE_ENCODING_RAW, + "9D92E9FDCA3DDF2E1DDCA1E3B7A79A250B6E4AFFCABF5F9FF4D960B152AB8300E9EB978BD3DA89C42BBFE5A2C2AEB0AF1DD178FB4BCD0833B587D118F59BBB4D"_u8v}, + {__WASI_SIGNATURE_ENCODING_DER, + "30460221009d92e9fdca3ddf2e1ddca1e3b7a79a250b6e4affcabf5f9f" + "f4d960b152ab8300022100e9eb978bd3da89c42bbfe5a2c2aeb0af1dd1" + "78fb4bcd0833b587d118f59bbb4d"_u8v}}); + SigEncodingTest( + "ECDSA_P256_SHA256"sv, + {{__WASI_SIGNATURE_ENCODING_RAW, + "80D5D4769AE4F3998DD6B8B01177DE855204122A361F2189F9567C806DE2673E2FBFD3FF018338875B1D144F583EB6E8DC16CF6EEB2BB5C19A3202464ABB58BD"_u8v}, + {__WASI_SIGNATURE_ENCODING_DER, + "304502210080d5d4769ae4f3998dd6b8b01177de855204122a361f2189" + "f9567c806de2673e02202fbfd3ff018338875b1d144f583eb6e8dc16cf" + "6eeb2bb5c19a3202464abb58bd"_u8v}}); + SigEncodingTest( + "Ed25519"sv, + {{__WASI_SIGNATURE_ENCODING_RAW, + "d4fbdb52bfa726b44d1786a8c0d171c3e62ca83c9e5bbe63de0bb2483f8fd6cc1429ab72cafc41ab56af02ff8fcc43b99bfe4c7ae940f60f38ebaa9d311c4007"_u8v}}); + SigEncodingTest( + "RSA_PSS_2048_SHA256"sv, + {{__WASI_SIGNATURE_ENCODING_RAW, + "4f01e0c12b08625ecac89a69231906edf826380f37c959a96690d046316d68ff" + "ce9d5c471694fcebfc6b45534864689256e4fc81c78e583f675d0c94b4496474" + "51e81beff01a11a516d5e5ce3f1a910437cb8a3a5096b19fb15f4524a35b23d8" + "9cdba12cf5b71aac1047b28c562df7c5542c34ce23a182cf7e0e231934b17294" + "799d44877a1d68ef1b8f073619b7618e6b7c22db20030d98cf591ffc3d4da5f5" + "8613ecd5ecfc3b40a1d02f40891ca43695cd4c088b05a8054c89c595a47e2748" + "16f35384226f74459ee63e25a1bfc03c360490552ec38343f8ace502f065303b" + "00bc0ec320711b211fde92e57feb9013c3609342495ec0d7cabdec21e54acc38"_u8v}}); +} + +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index e3fc59bb..c041c552 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -13,7 +13,7 @@ git config --global --add safe.directory $(pwd) bzip2 -dc boost_1_79_0.tar.bz2 | tar -xf - CMAKE_BUILD_TYPE="Release" -WASMEDGE_BUILD_WASI_CRYPTO="OFF" +WASMEDGE_PLUGIN_WASI_CRYPTO="OFF" for i in "$@"; do case $i in @@ -21,8 +21,8 @@ for i in "$@"; do CMAKE_BUILD_TYPE="${i#*=}" shift ;; - -DWASMEDGE_BUILD_WASI_CRYPTO=*) - WASMEDGE_BUILD_WASI_CRYPTO=$(echo ${i#*=} | tr '[:lower:]' '[:upper:]') + -DWASMEDGE_PLUGIN_WASI_CRYPTO=*) + WASMEDGE_PLUGIN_WASI_CRYPTO=$(echo ${i#*=} | tr '[:lower:]' '[:upper:]') shift ;; *) @@ -31,7 +31,7 @@ for i in "$@"; do done CMAKE_OPTS="-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}" -if [ ${WASMEDGE_BUILD_WASI_CRYPTO} == "ON" ]; then +if [ ${WASMEDGE_PLUGIN_WASI_CRYPTO} == "ON" ]; then echo "Building wasi-crypto..." # install openssl curl -s -L -O --remote-name-all https://www.openssl.org/source/openssl-1.1.1n.tar.gz @@ -57,7 +57,7 @@ if [ ${WASMEDGE_BUILD_WASI_CRYPTO} == "ON" ]; then make test make install cd .. - CMAKE_OPTS="${CMAKE_OPTS} -DWASMEDGE_BUILD_WASI_CRYPTO=ON -DOPENSSL_ROOT_DIR=$(pwd)/openssl-1.1.1n/openssl" + CMAKE_OPTS="${CMAKE_OPTS} -DWASMEDGE_PLUGIN_WASI_CRYPTO=ON -DOPENSSL_ROOT_DIR=$(pwd)/openssl-1.1.1n/openssl" fi if ! cmake -Bbuild -GNinja ${CMAKE_OPTS} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_79_0/; then From 1e5f88b7f04562a5406686830adae6fac819b26e Mon Sep 17 00:00:00 2001 From: YiYing He Date: Tue, 19 Jul 2022 08:37:09 +0800 Subject: [PATCH 059/623] [CI] Move out the wasi-crypto testing. Signed-off-by: YiYing He --- plugins/wasi_crypto/CMakeLists.txt | 1 + utils/docker/build-manylinux.sh | 76 +++++++++++++++--------------- utils/wasi-crypto/build-openssl.sh | 29 ++++++++++++ 3 files changed, 68 insertions(+), 38 deletions(-) create mode 100755 utils/wasi-crypto/build-openssl.sh diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt index dd9ecf19..5f885d5a 100644 --- a/plugins/wasi_crypto/CMakeLists.txt +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2022 Second State INC +set(OPENSSL_USE_STATIC_LIBS ON) find_package(OpenSSL REQUIRED) wasmedge_add_library(wasmedgePluginWasiCrypto diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index c041c552..673104bf 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -13,59 +13,59 @@ git config --global --add safe.directory $(pwd) bzip2 -dc boost_1_79_0.tar.bz2 | tar -xf - CMAKE_BUILD_TYPE="Release" -WASMEDGE_PLUGIN_WASI_CRYPTO="OFF" +IS_BUILD_TARGET=true +IS_NINJA=true +CMAKE_OPTS="" for i in "$@"; do case $i in - -DCMAKE_BUILD_TYPE=*) - CMAKE_BUILD_TYPE="${i#*=}" + --release|--Release) + CMAKE_BUILD_TYPE="Release" shift ;; - -DWASMEDGE_PLUGIN_WASI_CRYPTO=*) - WASMEDGE_PLUGIN_WASI_CRYPTO=$(echo ${i#*=} | tr '[:lower:]' '[:upper:]') + --debug|--Debug) + CMAKE_BUILD_TYPE="Debug" + shift + ;; + --not-build) + IS_BUILD_TARGET=false + shift + ;; + --not-ninja) + IS_NINJA=false shift ;; *) + CMAKE_OPTS="${CMAKE_OPTS} $i" + shift ;; esac done -CMAKE_OPTS="-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}" -if [ ${WASMEDGE_PLUGIN_WASI_CRYPTO} == "ON" ]; then - echo "Building wasi-crypto..." - # install openssl - curl -s -L -O --remote-name-all https://www.openssl.org/source/openssl-1.1.1n.tar.gz - echo "40dceb51a4f6a5275bde0e6bf20ef4b91bfc32ed57c0552e2e8e15463372b17a openssl-1.1.1n.tar.gz" | sha256sum -c - tar -xf openssl-1.1.1n.tar.gz - cd ./openssl-1.1.1n - # openssl configure need newer perl - curl -s -L -O --remote-name-all https://www.cpan.org/src/5.0/perl-5.34.0.tar.gz - tar -xf perl-5.34.0.tar.gz - cd perl-5.34.0 - mkdir localperl - ./Configure -des -Dprefix=$(pwd)/localperl/ - make -j - # too long! - # make test - make install - export PATH="$(pwd)/localperl/bin/:$PATH" - cd .. - # configure by previous perl - mkdir openssl - ./perl-5.34.0/localperl/bin/perl ./config --prefix=$(pwd)/openssl --openssldir=$(pwd)/openssl - make -j - make test - make install - cd .. - CMAKE_OPTS="${CMAKE_OPTS} -DWASMEDGE_PLUGIN_WASI_CRYPTO=ON -DOPENSSL_ROOT_DIR=$(pwd)/openssl-1.1.1n/openssl" -fi - -if ! cmake -Bbuild -GNinja ${CMAKE_OPTS} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_79_0/; then +if $IS_NINJA; then + if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_79_0/ ${CMAKE_OPTS} .; then echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log echo === CMakeError.log === cat build/CMakeFiles/CMakeError.log exit 1 + fi +else + rm -rf build + mkdir build + cd build + if ! cmake -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/../boost_1_79_0/ ${CMAKE_OPTS} ..; then + cd .. + echo === CMakeOutput.log === + cat build/CMakeFiles/CMakeOutput.log + echo === CMakeError.log === + cat build/CMakeFiles/CMakeError.log + exit 1 + fi + cd .. +fi + +if ${IS_BUILD_TARGET}; then + cmake --build build + cmake --build build --target package fi -cmake --build build -cmake --build build --target package diff --git a/utils/wasi-crypto/build-openssl.sh b/utils/wasi-crypto/build-openssl.sh new file mode 100755 index 00000000..547170a3 --- /dev/null +++ b/utils/wasi-crypto/build-openssl.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +echo "Building OpenSSL for wasi-crypto..." +# Get OpenSSL source +curl -s -L -O --remote-name-all https://www.openssl.org/source/openssl-1.1.1n.tar.gz +echo "40dceb51a4f6a5275bde0e6bf20ef4b91bfc32ed57c0552e2e8e15463372b17a openssl-1.1.1n.tar.gz" | sha256sum -c +tar -xf openssl-1.1.1n.tar.gz +cd ./openssl-1.1.1n +# OpenSSL configure need newer perl +curl -s -L -O --remote-name-all https://www.cpan.org/src/5.0/perl-5.34.0.tar.gz +tar -xf perl-5.34.0.tar.gz +cd perl-5.34.0 +mkdir localperl +./Configure -des -Dprefix=$(pwd)/localperl/ +make -j +# too long! +# make test +make install +export PATH="$(pwd)/localperl/bin/:$PATH" +cd .. +# Configure by previous perl +mkdir openssl +./perl-5.34.0/localperl/bin/perl ./config --prefix=$(pwd)/openssl --openssldir=$(pwd)/openssl +make -j +make test +make install +cd .. From 50e49eec3699d097ff16a48e91602bc16aec3b39 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 25 Jul 2022 20:51:55 +0800 Subject: [PATCH 060/623] [WASI] Refactor to correct the wasi-crypto module names. Signed-off-by: YiYing He --- plugins/wasi_crypto/CMakeLists.txt | 6 +- .../wasi_crypto/asymmetric_common/module.cpp | 60 ++++++ .../wasi_crypto/asymmetric_common/module.h | 37 ++++ plugins/wasi_crypto/common/module.cpp | 36 ++++ plugins/wasi_crypto/common/module.h | 35 ++++ plugins/wasi_crypto/ctx.cpp | 51 ++++- plugins/wasi_crypto/ctx.h | 24 ++- plugins/wasi_crypto/kx/module.cpp | 22 ++ plugins/wasi_crypto/{ => kx}/module.h | 12 +- plugins/wasi_crypto/module.cpp | 161 --------------- plugins/wasi_crypto/signatures/ctx.cpp | 1 + plugins/wasi_crypto/signatures/module.cpp | 43 ++++ plugins/wasi_crypto/signatures/module.h | 36 ++++ plugins/wasi_crypto/symmetric/module.cpp | 73 +++++++ plugins/wasi_crypto/symmetric/module.h | 36 ++++ test/plugins/wasi_crypto/common.cpp | 9 +- test/plugins/wasi_crypto/helper.cpp | 191 +++++++++--------- test/plugins/wasi_crypto/helper.h | 55 ++++- 18 files changed, 606 insertions(+), 282 deletions(-) create mode 100644 plugins/wasi_crypto/asymmetric_common/module.cpp create mode 100644 plugins/wasi_crypto/asymmetric_common/module.h create mode 100644 plugins/wasi_crypto/common/module.cpp create mode 100644 plugins/wasi_crypto/common/module.h create mode 100644 plugins/wasi_crypto/kx/module.cpp rename plugins/wasi_crypto/{ => kx}/module.h (55%) delete mode 100644 plugins/wasi_crypto/module.cpp create mode 100644 plugins/wasi_crypto/signatures/module.cpp create mode 100644 plugins/wasi_crypto/signatures/module.h create mode 100644 plugins/wasi_crypto/symmetric/module.cpp create mode 100644 plugins/wasi_crypto/symmetric/module.h diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt index 5f885d5a..36b59520 100644 --- a/plugins/wasi_crypto/CMakeLists.txt +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -6,27 +6,30 @@ find_package(OpenSSL REQUIRED) wasmedge_add_library(wasmedgePluginWasiCrypto SHARED - module.cpp ctx.cpp asymmetric_common/ctx.cpp asymmetric_common/func.cpp asymmetric_common/keypair.cpp + asymmetric_common/module.cpp asymmetric_common/publickey.cpp asymmetric_common/secretkey.cpp common/array_output.cpp common/ctx.cpp common/func.cpp + common/module.cpp common/options.cpp kx/ctx.cpp kx/dh/ecdsa.cpp kx/dh/x25519.cpp kx/func.cpp kx/kx.cpp + kx/module.cpp kx/options.cpp signatures/ctx.cpp signatures/ecdsa.cpp signatures/eddsa.cpp signatures/func.cpp + signatures/module.cpp signatures/options.cpp signatures/rsa.cpp signatures/signatures.cpp @@ -39,6 +42,7 @@ wasmedge_add_library(wasmedgePluginWasiCrypto symmetric/kdf.cpp symmetric/key.cpp symmetric/mac.cpp + symmetric/module.cpp symmetric/options.cpp symmetric/state.cpp symmetric/tag.cpp diff --git a/plugins/wasi_crypto/asymmetric_common/module.cpp b/plugins/wasi_crypto/asymmetric_common/module.cpp new file mode 100644 index 00000000..816f2232 --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/module.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "asymmetric_common/module.h" +#include "asymmetric_common/func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiCryptoAsymmetricCommonModule::WasiCryptoAsymmetricCommonModule( + std::shared_ptr C) + : ModuleInstance("wasi_ephemeral_crypto_asymmetric_common"), Ctx(C) { + using namespace WasiCrypto; + + addHostFunc("keypair_generate", + std::make_unique(*Ctx)); + addHostFunc("keypair_import", + std::make_unique(*Ctx)); + addHostFunc("keypair_generate_managed", + std::make_unique(*Ctx)); + addHostFunc("keypair_store_managed", + std::make_unique(*Ctx)); + addHostFunc("keypair_replace_managed", + std::make_unique(*Ctx)); + addHostFunc("keypair_id", + std::make_unique(*Ctx)); + addHostFunc("keypair_from_id", + std::make_unique(*Ctx)); + addHostFunc("keypair_from_pk_and_sk", + std::make_unique(*Ctx)); + addHostFunc("keypair_export", + std::make_unique(*Ctx)); + addHostFunc("keypair_publickey", + std::make_unique(*Ctx)); + addHostFunc("keypair_secretkey", + std::make_unique(*Ctx)); + addHostFunc("keypair_close", + std::make_unique(*Ctx)); + addHostFunc("publickey_import", + std::make_unique(*Ctx)); + addHostFunc("publickey_export", + std::make_unique(*Ctx)); + addHostFunc("publickey_verify", + std::make_unique(*Ctx)); + addHostFunc("publickey_from_secretkey", + std::make_unique(*Ctx)); + addHostFunc("publickey_close", + std::make_unique(*Ctx)); + addHostFunc("secretkey_import", + std::make_unique(*Ctx)); + addHostFunc("secretkey_export", + std::make_unique(*Ctx)); + addHostFunc("secretkey_close", + std::make_unique(*Ctx)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/asymmetric_common/module.h b/plugins/wasi_crypto/asymmetric_common/module.h new file mode 100644 index 00000000..4e5281af --- /dev/null +++ b/plugins/wasi_crypto/asymmetric_common/module.h @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/asymmetric_common/module.h - Asym ----===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the wasi-crypto asymmetric_common +/// module class. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "ctx.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiCryptoAsymmetricCommonModule + : public Runtime::Instance::ModuleInstance { +public: + WasiCryptoAsymmetricCommonModule(std::shared_ptr); + + WasiCrypto::Context &getContext() { return *Ctx.get(); } + +private: + std::shared_ptr Ctx; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/module.cpp b/plugins/wasi_crypto/common/module.cpp new file mode 100644 index 00000000..bbb0107b --- /dev/null +++ b/plugins/wasi_crypto/common/module.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/module.h" +#include "common/func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiCryptoCommonModule::WasiCryptoCommonModule( + std::shared_ptr C) + : ModuleInstance("wasi_ephemeral_crypto_common"), Ctx(C) { + using namespace WasiCrypto; + + addHostFunc("array_output_len", + std::make_unique(*Ctx)); + addHostFunc("array_output_pull", + std::make_unique(*Ctx)); + addHostFunc("options_open", std::make_unique(*Ctx)); + addHostFunc("options_close", std::make_unique(*Ctx)); + addHostFunc("options_set", std::make_unique(*Ctx)); + addHostFunc("options_set_u64", std::make_unique(*Ctx)); + addHostFunc("options_set_guest_buffer", + std::make_unique(*Ctx)); + addHostFunc("secrets_manager_open", + std::make_unique(*Ctx)); + addHostFunc("secrets_manager_close", + std::make_unique(*Ctx)); + addHostFunc("secrets_manager_invalidate", + std::make_unique(*Ctx)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/common/module.h b/plugins/wasi_crypto/common/module.h new file mode 100644 index 00000000..cfa05726 --- /dev/null +++ b/plugins/wasi_crypto/common/module.h @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/common/module.h - Common Module ------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the wasi-crypto common module class. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "ctx.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiCryptoCommonModule : public Runtime::Instance::ModuleInstance { +public: + WasiCryptoCommonModule(std::shared_ptr); + + WasiCrypto::Context &getContext() { return *Ctx.get(); } + +private: + std::shared_ptr Ctx; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/ctx.cpp b/plugins/wasi_crypto/ctx.cpp index a7457632..f233a504 100644 --- a/plugins/wasi_crypto/ctx.cpp +++ b/plugins/wasi_crypto/ctx.cpp @@ -2,15 +2,32 @@ // SPDX-FileCopyrightText: 2019-2022 Second State INC #include "ctx.h" -#include "module.h" +#include "asymmetric_common/module.h" +#include "common/module.h" +#include "kx/module.h" +#include "signatures/module.h" +#include "symmetric/module.h" namespace WasmEdge { namespace Host { namespace { -Runtime::Instance::ModuleInstance *create(void) noexcept { - return new WasiCryptoModule; +Runtime::Instance::ModuleInstance *createAsymmetricCommon(void) noexcept { + return new WasiCryptoAsymmetricCommonModule( + WasiCrypto::Context::getInstance()); +} +Runtime::Instance::ModuleInstance *createCommon(void) noexcept { + return new WasiCryptoCommonModule(WasiCrypto::Context::getInstance()); +} +Runtime::Instance::ModuleInstance *createKx(void) noexcept { + return new WasiCryptoKxModule(WasiCrypto::Context::getInstance()); +} +Runtime::Instance::ModuleInstance *createSignatures(void) noexcept { + return new WasiCryptoSignaturesModule(WasiCrypto::Context::getInstance()); +} +Runtime::Instance::ModuleInstance *createSymmetric(void) noexcept { + return new WasiCryptoSymmetricModule(WasiCrypto::Context::getInstance()); } Plugin::Plugin::PluginDescriptor Descriptor{ @@ -18,13 +35,33 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .Description = "", .APIVersion = Plugin::Plugin::CurrentAPIVersion, .Version = {0, 10, 1, 0}, - .ModuleCount = 1, + .ModuleCount = 5, .ModuleDescriptions = (Plugin::PluginModule::ModuleDescriptor[]){ { - .Name = "wasi_crypto", + .Name = "wasi_crypto_asymmetric_common", + .Description = "", + .Create = createAsymmetricCommon, + }, + { + .Name = "wasi_crypto_common", + .Description = "", + .Create = createCommon, + }, + { + .Name = "wasi_crypto_kx", + .Description = "", + .Create = createKx, + }, + { + .Name = "wasi_crypto_signatures", + .Description = "", + .Create = createSignatures, + }, + { + .Name = "wasi_crypto_symmetric", .Description = "", - .Create = create, + .Create = createSymmetric, }, }, .AddOptions = nullptr, @@ -33,6 +70,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ } // namespace Plugin::PluginRegister WasiCrypto::Context::Register(&Descriptor); +std::shared_mutex WasiCrypto::Context::Mutex; +std::weak_ptr WasiCrypto::Context::Instance; } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_crypto/ctx.h b/plugins/wasi_crypto/ctx.h index da07acd5..9a60c21e 100644 --- a/plugins/wasi_crypto/ctx.h +++ b/plugins/wasi_crypto/ctx.h @@ -19,7 +19,6 @@ #include "asymmetric_common/secretkey.h" #include "common/array_output.h" #include "common/options.h" -#include "common/span.h" #include "kx/registed.h" #include "signatures/registed.h" #include "signatures/signatures.h" @@ -32,14 +31,33 @@ #include "utils/error.h" #include "utils/handles_manager.h" +#include "common/span.h" #include "plugin/plugin.h" +#include +#include +#include + namespace WasmEdge { namespace Host { namespace WasiCrypto { class Context { public: + // Singleton + static std::shared_ptr getInstance() noexcept { + std::unique_lock Lock(Mutex); + std::shared_ptr CtxPtr = Instance.lock(); + if (!CtxPtr) { + CtxPtr.reset(new Context()); + Instance = CtxPtr; + } + return CtxPtr; + } + + Context(const Context &) = delete; + void operator=(const Context &) = delete; + // Common WasiCryptoExpect @@ -313,6 +331,8 @@ class Context { __wasi_signature_verification_state_t StateHandle) noexcept; private: + Context() noexcept {} + RefHandlesManager<__wasi_array_output_t, Common::ArrayOutput> ArrayOutputManager{0x00}; RcHandlesManager<__wasi_options_t, Common::Options> OptionsManager{0x01}; @@ -336,6 +356,8 @@ class Context { Signatures::VerificationStateVariant> VerificationStateManager{0x02}; + static std::shared_mutex Mutex; + static std::weak_ptr Instance; static Plugin::PluginRegister Register; }; diff --git a/plugins/wasi_crypto/kx/module.cpp b/plugins/wasi_crypto/kx/module.cpp new file mode 100644 index 00000000..38f35cf6 --- /dev/null +++ b/plugins/wasi_crypto/kx/module.cpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "kx/module.h" +#include "kx/func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiCryptoKxModule::WasiCryptoKxModule(std::shared_ptr C) + : ModuleInstance("wasi_ephemeral_crypto_kx"), Ctx(C) { + using namespace WasiCrypto; + + addHostFunc("kx_dh", std::make_unique(*Ctx)); + addHostFunc("kx_encapsulate", std::make_unique(*Ctx)); + addHostFunc("kx_decapsulate", std::make_unique(*Ctx)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/module.h b/plugins/wasi_crypto/kx/module.h similarity index 55% rename from plugins/wasi_crypto/module.h rename to plugins/wasi_crypto/kx/module.h index b709325e..85992f0a 100644 --- a/plugins/wasi_crypto/module.h +++ b/plugins/wasi_crypto/kx/module.h @@ -1,14 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2022 Second State INC -//===-- wasmedge/plugins/wasi_crypto/module.h - Module class definition ---===// +//===-- wasmedge/plugins/wasi_crypto/kx/module.h - Kx Module --------------===// // // Part of the WasmEdge Project. // //===----------------------------------------------------------------------===// /// /// \file -/// This file contains the declaration of the wasi-crypto module class. +/// This file contains the declaration of the wasi-crypto Kx module class. /// //===----------------------------------------------------------------------===// @@ -21,14 +21,14 @@ namespace WasmEdge { namespace Host { -class WasiCryptoModule : public Runtime::Instance::ModuleInstance { +class WasiCryptoKxModule : public Runtime::Instance::ModuleInstance { public: - WasiCryptoModule(); + WasiCryptoKxModule(std::shared_ptr); - WasiCrypto::Context &getContext() { return Ctx; } + WasiCrypto::Context &getContext() { return *Ctx.get(); } private: - WasiCrypto::Context Ctx; + std::shared_ptr Ctx; }; } // namespace Host diff --git a/plugins/wasi_crypto/module.cpp b/plugins/wasi_crypto/module.cpp deleted file mode 100644 index 9999cac4..00000000 --- a/plugins/wasi_crypto/module.cpp +++ /dev/null @@ -1,161 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#include "module.h" -#include "asymmetric_common/func.h" -#include "common/func.h" -#include "kx/func.h" -#include "signatures/func.h" -#include "symmetric/func.h" - -#include - -namespace WasmEdge { -namespace Host { - -WasiCryptoModule::WasiCryptoModule() : ModuleInstance("wasi_ephemeral_crypto") { - using namespace WasiCrypto; - - // common - addHostFunc("array_output_len", - std::make_unique(Ctx)); - addHostFunc("array_output_pull", - std::make_unique(Ctx)); - addHostFunc("options_open", std::make_unique(Ctx)); - addHostFunc("options_close", std::make_unique(Ctx)); - addHostFunc("options_set", std::make_unique(Ctx)); - addHostFunc("options_set_u64", std::make_unique(Ctx)); - addHostFunc("options_set_guest_buffer", - std::make_unique(Ctx)); - addHostFunc("secrets_manager_open", - std::make_unique(Ctx)); - addHostFunc("secrets_manager_close", - std::make_unique(Ctx)); - addHostFunc("secrets_manager_invalidate", - std::make_unique(Ctx)); - - // symmetric - addHostFunc("symmetric_key_generate", - std::make_unique(Ctx)); - addHostFunc("symmetric_key_import", - std::make_unique(Ctx)); - addHostFunc("symmetric_key_export", - std::make_unique(Ctx)); - addHostFunc("symmetric_key_close", - std::make_unique(Ctx)); - addHostFunc("symmetric_key_generate_managed", - std::make_unique(Ctx)); - addHostFunc("symmetric_key_store_managed", - std::make_unique(Ctx)); - addHostFunc("symmetric_key_replace_managed", - std::make_unique(Ctx)); - addHostFunc("symmetric_key_id", std::make_unique(Ctx)); - addHostFunc("symmetric_key_from_id", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_open", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_clone", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_options_get", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_options_get_u64", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_close", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_absorb", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_squeeze", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_squeeze_tag", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_squeeze_key", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_max_tag_len", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_encrypt", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_encrypt_detached", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_decrypt", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_decrypt_detached", - std::make_unique(Ctx)); - addHostFunc("symmetric_state_ratchet", - std::make_unique(Ctx)); - addHostFunc("symmetric_tag_len", std::make_unique(Ctx)); - addHostFunc("symmetric_tag_pull", std::make_unique(Ctx)); - addHostFunc("symmetric_tag_verify", - std::make_unique(Ctx)); - addHostFunc("symmetric_tag_close", - std::make_unique(Ctx)); - - // asymmetric - addHostFunc("keypair_generate", - std::make_unique(Ctx)); - addHostFunc("keypair_import", - std::make_unique(Ctx)); - addHostFunc("keypair_generate_managed", - std::make_unique(Ctx)); - addHostFunc("keypair_store_managed", - std::make_unique(Ctx)); - addHostFunc("keypair_replace_managed", - std::make_unique(Ctx)); - addHostFunc("keypair_id", std::make_unique(Ctx)); - addHostFunc("keypair_from_id", - std::make_unique(Ctx)); - addHostFunc("keypair_from_pk_and_sk", - std::make_unique(Ctx)); - addHostFunc("keypair_export", - std::make_unique(Ctx)); - addHostFunc("keypair_publickey", - std::make_unique(Ctx)); - addHostFunc("keypair_secretkey", - std::make_unique(Ctx)); - addHostFunc("keypair_close", - std::make_unique(Ctx)); - addHostFunc("publickey_import", - std::make_unique(Ctx)); - addHostFunc("publickey_export", - std::make_unique(Ctx)); - addHostFunc("publickey_verify", - std::make_unique(Ctx)); - addHostFunc("publickey_from_secretkey", - std::make_unique(Ctx)); - addHostFunc("publickey_close", - std::make_unique(Ctx)); - addHostFunc("secretkey_import", - std::make_unique(Ctx)); - addHostFunc("secretkey_export", - std::make_unique(Ctx)); - addHostFunc("secretkey_close", - std::make_unique(Ctx)); - - // kx - addHostFunc("kx_dh", std::make_unique(Ctx)); - addHostFunc("kx_encapsulate", std::make_unique(Ctx)); - addHostFunc("kx_decapsulate", std::make_unique(Ctx)); - - // signature - addHostFunc("signature_export", std::make_unique(Ctx)); - addHostFunc("signature_import", std::make_unique(Ctx)); - addHostFunc("signature_state_open", - std::make_unique(Ctx)); - addHostFunc("signature_state_update", - std::make_unique(Ctx)); - addHostFunc("signature_state_sign", - std::make_unique(Ctx)); - addHostFunc("signature_state_close", - std::make_unique(Ctx)); - addHostFunc("signature_verification_state_open", - std::make_unique(Ctx)); - addHostFunc("signature_verification_state_update", - std::make_unique(Ctx)); - addHostFunc("signature_verification_state_verify", - std::make_unique(Ctx)); - addHostFunc("signature_verification_state_close", - std::make_unique(Ctx)); - addHostFunc("signature_close", std::make_unique(Ctx)); -} - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/ctx.cpp b/plugins/wasi_crypto/signatures/ctx.cpp index c5213f9c..17070d44 100644 --- a/plugins/wasi_crypto/signatures/ctx.cpp +++ b/plugins/wasi_crypto/signatures/ctx.cpp @@ -1,4 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC #include "ctx.h" #include "signatures/signatures.h" diff --git a/plugins/wasi_crypto/signatures/module.cpp b/plugins/wasi_crypto/signatures/module.cpp new file mode 100644 index 00000000..44a3fef7 --- /dev/null +++ b/plugins/wasi_crypto/signatures/module.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "module.h" +#include "asymmetric_common/func.h" +#include "common/func.h" +#include "kx/func.h" +#include "signatures/func.h" +#include "symmetric/func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiCryptoSignaturesModule::WasiCryptoSignaturesModule( + std::shared_ptr C) + : ModuleInstance("wasi_ephemeral_crypto_signatures"), Ctx(C) { + using namespace WasiCrypto; + + addHostFunc("signature_export", std::make_unique(*Ctx)); + addHostFunc("signature_import", std::make_unique(*Ctx)); + addHostFunc("signature_state_open", + std::make_unique(*Ctx)); + addHostFunc("signature_state_update", + std::make_unique(*Ctx)); + addHostFunc("signature_state_sign", + std::make_unique(*Ctx)); + addHostFunc("signature_state_close", + std::make_unique(*Ctx)); + addHostFunc("signature_verification_state_open", + std::make_unique(*Ctx)); + addHostFunc("signature_verification_state_update", + std::make_unique(*Ctx)); + addHostFunc("signature_verification_state_verify", + std::make_unique(*Ctx)); + addHostFunc("signature_verification_state_close", + std::make_unique(*Ctx)); + addHostFunc("signature_close", std::make_unique(*Ctx)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/signatures/module.h b/plugins/wasi_crypto/signatures/module.h new file mode 100644 index 00000000..71e8540f --- /dev/null +++ b/plugins/wasi_crypto/signatures/module.h @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/signatures/module.h - Module ---------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the wasi-crypto signatures module +/// class. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "ctx.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiCryptoSignaturesModule : public Runtime::Instance::ModuleInstance { +public: + WasiCryptoSignaturesModule(std::shared_ptr); + + WasiCrypto::Context &getContext() { return *Ctx.get(); } + +private: + std::shared_ptr Ctx; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/module.cpp b/plugins/wasi_crypto/symmetric/module.cpp new file mode 100644 index 00000000..402c4b19 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/module.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "symmetric/module.h" +#include "symmetric/func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiCryptoSymmetricModule::WasiCryptoSymmetricModule( + std::shared_ptr C) + : ModuleInstance("wasi_ephemeral_crypto_symmetric"), Ctx(C) { + using namespace WasiCrypto; + + addHostFunc("symmetric_key_generate", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_import", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_export", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_close", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_generate_managed", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_store_managed", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_replace_managed", + std::make_unique(*Ctx)); + addHostFunc("symmetric_key_id", std::make_unique(*Ctx)); + addHostFunc("symmetric_key_from_id", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_open", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_clone", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_options_get", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_options_get_u64", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_close", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_absorb", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_squeeze", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_squeeze_tag", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_squeeze_key", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_max_tag_len", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_encrypt", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_encrypt_detached", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_decrypt", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_decrypt_detached", + std::make_unique(*Ctx)); + addHostFunc("symmetric_state_ratchet", + std::make_unique(*Ctx)); + addHostFunc("symmetric_tag_len", std::make_unique(*Ctx)); + addHostFunc("symmetric_tag_pull", std::make_unique(*Ctx)); + addHostFunc("symmetric_tag_verify", + std::make_unique(*Ctx)); + addHostFunc("symmetric_tag_close", + std::make_unique(*Ctx)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/module.h b/plugins/wasi_crypto/symmetric/module.h new file mode 100644 index 00000000..af4f23d4 --- /dev/null +++ b/plugins/wasi_crypto/symmetric/module.h @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/module.h - Module ----------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of the wasi-crypto symmetric module +/// class. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "ctx.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiCryptoSymmetricModule : public Runtime::Instance::ModuleInstance { +public: + WasiCryptoSymmetricModule(std::shared_ptr); + + WasiCrypto::Context &getContext() { return *Ctx.get(); } + +private: + std::shared_ptr Ctx; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasi_crypto/common.cpp b/test/plugins/wasi_crypto/common.cpp index 63156564..791481c6 100644 --- a/test/plugins/wasi_crypto/common.cpp +++ b/test/plugins/wasi_crypto/common.cpp @@ -5,8 +5,7 @@ #include "helper.h" namespace { -template -T *getHostFunc(WasmEdge::Host::WasiCryptoModule *Mod, const char *Name) { +template T *getHostFunc(M *Mod, const char *Name) { if (Mod) { auto *FuncInst = Mod->findFuncExports(Name); if (FuncInst && FuncInst->isHostFunction()) { @@ -55,7 +54,7 @@ TEST_F(WasiCryptoTest, Options) { writeString("foo"sv, 0); uint32_t NameSize = 3; auto *Func = getHostFunc( - WasiCryptoMod, "options_set_guest_buffer"); + WasiCryptoCommonMod, "options_set_guest_buffer"); ASSERT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -83,7 +82,7 @@ TEST_F(WasiCryptoTest, Options) { writeString("foo"sv, 0); uint32_t NameSize = 3; auto *Func = getHostFunc( - WasiCryptoMod, "options_set_guest_buffer"); + WasiCryptoCommonMod, "options_set_guest_buffer"); ASSERT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -111,7 +110,7 @@ TEST_F(WasiCryptoTest, Options) { writeString("foo"sv, 0); uint32_t NameSize = 3; auto *Func = getHostFunc( - WasiCryptoMod, "options_set_guest_buffer"); + WasiCryptoCommonMod, "options_set_guest_buffer"); ASSERT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ diff --git a/test/plugins/wasi_crypto/helper.cpp b/test/plugins/wasi_crypto/helper.cpp index cc7223f0..6fe86695 100644 --- a/test/plugins/wasi_crypto/helper.cpp +++ b/test/plugins/wasi_crypto/helper.cpp @@ -22,8 +22,7 @@ } while (0) namespace { -template -T *getHostFunc(WasmEdge::Host::WasiCryptoModule *Mod, const char *Name) { +template T *getHostFunc(M *Mod, const char *Name) { if (Mod) { auto *FuncInst = Mod->findFuncExports(Name); if (FuncInst && FuncInst->isHostFunction()) { @@ -95,8 +94,8 @@ WasiCryptoExpect<__wasi_size_t> WasiCryptoTest::arrayOutputLen(__wasi_array_output_t ArrayOutputHandle) { writeDummyMemoryContent(); - auto *Func = - getHostFunc(WasiCryptoMod, "array_output_len"); + auto *Func = getHostFunc(WasiCryptoCommonMod, + "array_output_len"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, @@ -114,8 +113,8 @@ WasiCryptoTest::arrayOutputPull(__wasi_array_output_t ArrayOutputHandle, writeSpan(Buf, 0); uint32_t BufSize = static_cast(Buf.size()); - auto *Func = - getHostFunc(WasiCryptoMod, "array_output_pull"); + auto *Func = getHostFunc(WasiCryptoCommonMod, + "array_output_pull"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -132,7 +131,8 @@ WasiCryptoExpect<__wasi_options_t> WasiCryptoTest::optionsOpen(__wasi_algorithm_type_e_t AlgorithmType) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoMod, "options_open"); + auto *Func = + getHostFunc(WasiCryptoCommonMod, "options_open"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -148,7 +148,7 @@ WasiCryptoTest::optionsClose(__wasi_options_t OptionsHandle) { writeDummyMemoryContent(); auto *Func = - getHostFunc(WasiCryptoMod, "options_close"); + getHostFunc(WasiCryptoCommonMod, "options_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{OptionsHandle}, @@ -167,7 +167,8 @@ WasiCryptoTest::optionsSet(__wasi_options_t OptionsHandle, writeSpan(Value, NameSize); uint32_t ValueSize = static_cast(Value.size()); - auto *Func = getHostFunc(WasiCryptoMod, "options_set"); + auto *Func = + getHostFunc(WasiCryptoCommonMod, "options_set"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -186,8 +187,8 @@ WasiCryptoTest::optionsSetU64(__wasi_options_t OptionsHandle, writeString(Name, 0); uint32_t NameSize = static_cast(Name.size()); - auto *Func = - getHostFunc(WasiCryptoMod, "options_set_u64"); + auto *Func = getHostFunc(WasiCryptoCommonMod, + "options_set_u64"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -203,7 +204,7 @@ WasiCryptoExpect<__wasi_secrets_manager_t> WasiCryptoTest::secretsManagerOpen( writeDummyMemoryContent(); writeOptOptions(OptOptionsHandle, 0); - auto *Func = getHostFunc(WasiCryptoMod, + auto *Func = getHostFunc(WasiCryptoCommonMod, "secrets_manager_open"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( @@ -218,7 +219,7 @@ WasiCryptoExpect WasiCryptoTest::secretsManagerClose( writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "secrets_manager_close"); + WasiCryptoCommonMod, "secrets_manager_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, @@ -237,7 +238,7 @@ WasiCryptoExpect WasiCryptoTest::secretsManagerInvalidate( uint32_t KeyIdSize = static_cast(KeyId.size()); auto *Func = getHostFunc( - WasiCryptoMod, "secrets_manager_invalidate"); + WasiCryptoCommonMod, "secrets_manager_invalidate"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -255,7 +256,7 @@ WasiCryptoExpect<__wasi_symmetric_key_t> WasiCryptoTest::symmetricKeyGenerate( uint32_t AlgSize = static_cast(Alg.size()); writeOptOptions(OptOptionsHandle, AlgSize); - auto *Func = getHostFunc(WasiCryptoMod, + auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_key_generate"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, @@ -276,8 +277,8 @@ WasiCryptoTest::symmetricKeyImport(std::string_view Alg, writeSpan(Raw, AlgSize); uint32_t RawSize = static_cast(Raw.size()); - auto *Func = - getHostFunc(WasiCryptoMod, "symmetric_key_import"); + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_key_import"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -292,8 +293,8 @@ WasiCryptoExpect<__wasi_array_output_t> WasiCryptoTest::symmetricKeyExport(__wasi_symmetric_key_t KeyHandle) { writeDummyMemoryContent(); - auto *Func = - getHostFunc(WasiCryptoMod, "symmetric_key_export"); + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_key_export"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{KeyHandle, 0}, @@ -307,8 +308,8 @@ WasiCryptoExpect WasiCryptoTest::symmetricKeyClose(__wasi_symmetric_key_t KeyHandle) { writeDummyMemoryContent(); - auto *Func = - getHostFunc(WasiCryptoMod, "symmetric_key_close"); + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_key_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{KeyHandle}, Errno)); @@ -327,7 +328,7 @@ WasiCryptoTest::symmetricKeyGenerateManaged( writeOptOptions(OptOptionsHandle, AlgSize); auto *Func = getHostFunc( - WasiCryptoMod, "symmetric_key_generate_managed"); + WasiCryptoSymmMod, "symmetric_key_generate_managed"); EXPECT_NE(Func, nullptr); EXPECT_TRUE( Func->run(&MemInst, @@ -347,7 +348,7 @@ WasiCryptoExpect WasiCryptoTest::symmetricKeyStoreManaged( uint32_t KpIdSize = static_cast(KeyId.size()); auto *Func = getHostFunc( - WasiCryptoMod, "symmetric_key_store_managed"); + WasiCryptoSymmMod, "symmetric_key_store_managed"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -367,7 +368,7 @@ WasiCryptoExpect<__wasi_version_t> WasiCryptoTest::symmetricKeyReplaceManaged( writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "symmetric_key_replace_managed"); + WasiCryptoSymmMod, "symmetric_key_replace_managed"); EXPECT_NE(Func, nullptr); EXPECT_TRUE( Func->run(&MemInst, @@ -386,7 +387,8 @@ WasiCryptoTest::symmetricKeyId(__wasi_symmetric_key_t KeyHandle, writeSpan(KeyId, 0); uint32_t KeyIdSize = static_cast(KeyId.size()); - auto *Func = getHostFunc(WasiCryptoMod, "symmetric_key_id"); + auto *Func = + getHostFunc(WasiCryptoSymmMod, "symmetric_key_id"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -409,8 +411,8 @@ WasiCryptoExpect<__wasi_symmetric_key_t> WasiCryptoTest::symmetricKeyFromId( writeSpan(KeyId, 0); uint32_t KeyIdSize = static_cast(KeyId.size()); - auto *Func = - getHostFunc(WasiCryptoMod, "symmetric_key_from_id"); + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_key_from_id"); EXPECT_NE(Func, nullptr); EXPECT_TRUE( Func->run(&MemInst, @@ -431,8 +433,8 @@ WasiCryptoExpect<__wasi_symmetric_state_t> WasiCryptoTest::symmetricStateOpen( writeOptKey(OptKeyHandle, AlgSize); writeOptOptions(OptOptionsHandle, AlgSize + 8); - auto *Func = - getHostFunc(WasiCryptoMod, "symmetric_state_open"); + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_state_open"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -454,7 +456,7 @@ WasiCryptoTest::symmetricStateOptionsGet(__wasi_symmetric_state_t StateHandle, uint32_t ValueSize = static_cast(Value.size()); auto *Func = getHostFunc( - WasiCryptoMod, "symmetric_state_options_get"); + WasiCryptoSymmMod, "symmetric_state_options_get"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, @@ -476,7 +478,7 @@ WasiCryptoExpect WasiCryptoTest::symmetricStateOptionsGetU64( uint32_t NameSize = static_cast(Name.size()); auto *Func = getHostFunc( - WasiCryptoMod, "symmetric_state_options_get_u64"); + WasiCryptoSymmMod, "symmetric_state_options_get_u64"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -491,7 +493,7 @@ WasiCryptoExpect WasiCryptoTest::symmetricStateClose(__wasi_symmetric_state_t StateHandle) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoMod, + auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_state_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( @@ -509,7 +511,7 @@ WasiCryptoTest::symmetricStateAbsorb(__wasi_symmetric_state_t StateHandle, writeSpan(Data, 0); uint32_t DataSize = static_cast(Data.size()); - auto *Func = getHostFunc(WasiCryptoMod, + auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_state_absorb"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( @@ -525,7 +527,7 @@ WasiCryptoExpect<__wasi_symmetric_state_t> WasiCryptoTest::symmetricStateClone(__wasi_symmetric_state_t StateHandle) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoMod, + auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_state_clone"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( @@ -543,7 +545,7 @@ WasiCryptoTest::symmetricStateSqueeze(__wasi_symmetric_state_t StateHandle, writeSpan(Out, 0); uint32_t OutSize = static_cast(Out.size()); - auto *Func = getHostFunc(WasiCryptoMod, + auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_state_squeeze"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( @@ -563,7 +565,7 @@ WasiCryptoTest::symmetricStateSqueezeTag(__wasi_symmetric_state_t StateHandle) { writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "symmetric_state_squeeze_tag"); + WasiCryptoSymmMod, "symmetric_state_squeeze_tag"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{StateHandle, 0}, @@ -581,7 +583,7 @@ WasiCryptoTest::symmetricStateSqueezeKey(__wasi_symmetric_state_t StateHandle, uint32_t AlgSize = static_cast(Alg.size()); auto *Func = getHostFunc( - WasiCryptoMod, "symmetric_state_squeeze_key"); + WasiCryptoSymmMod, "symmetric_state_squeeze_key"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -597,7 +599,7 @@ WasiCryptoTest::symmetricStateMaxTagLen(__wasi_symmetric_state_t StateHandle) { writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "symmetric_state_max_tag_len"); + WasiCryptoSymmMod, "symmetric_state_max_tag_len"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{StateHandle, 0}, @@ -617,7 +619,7 @@ WasiCryptoTest::symmetricStateEncrypt(__wasi_symmetric_state_t StateHandle, writeSpan(Data, OutSize); uint32_t DataSize = static_cast(Data.size()); - auto *Func = getHostFunc(WasiCryptoMod, + auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_state_encrypt"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( @@ -644,7 +646,7 @@ WasiCryptoTest::symmetricStateEncryptDetached( uint32_t DataSize = static_cast(Data.size()); auto *Func = getHostFunc( - WasiCryptoMod, "symmetric_state_encrypt_detached"); + WasiCryptoSymmMod, "symmetric_state_encrypt_detached"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, @@ -669,7 +671,7 @@ WasiCryptoTest::symmetricStateDecrypt(__wasi_symmetric_state_t StateHandle, writeSpan(Data, OutSize); uint32_t DataSize = static_cast(Data.size()); - auto *Func = getHostFunc(WasiCryptoMod, + auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_state_decrypt"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( @@ -697,7 +699,7 @@ WasiCryptoExpect<__wasi_size_t> WasiCryptoTest::symmetricStateDecryptDetached( uint32_t RawTagSize = static_cast(RawTag.size()); auto *Func = getHostFunc( - WasiCryptoMod, "symmetric_state_decrypt_detached"); + WasiCryptoSymmMod, "symmetric_state_decrypt_detached"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -720,7 +722,7 @@ WasiCryptoExpect WasiCryptoTest::symmetricStateRatchet(__wasi_symmetric_state_t StateHandle) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoMod, + auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_state_ratchet"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( @@ -736,7 +738,7 @@ WasiCryptoTest::symmetricMaxTagLen(__wasi_symmetric_tag_t TagHandle) { writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "symmetric_state_max_tag_len"); + WasiCryptoSymmMod, "symmetric_state_max_tag_len"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{TagHandle, 0}, @@ -754,7 +756,7 @@ WasiCryptoTest::symmetricTagPull(__wasi_symmetric_tag_t TagHandle, uint32_t BufSize = static_cast(Buf.size()); auto *Func = - getHostFunc(WasiCryptoMod, "symmetric_tag_pull"); + getHostFunc(WasiCryptoSymmMod, "symmetric_tag_pull"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -775,8 +777,8 @@ WasiCryptoTest::symmetricTagVerify(__wasi_symmetric_tag_t TagHandle, writeSpan(RawTag, 0); uint32_t RawTagSize = static_cast(RawTag.size()); - auto *Func = - getHostFunc(WasiCryptoMod, "symmetric_tag_verify"); + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_tag_verify"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, @@ -791,8 +793,8 @@ WasiCryptoExpect WasiCryptoTest::symmetricTagClose(__wasi_symmetric_tag_t TagHandle) { writeDummyMemoryContent(); - auto *Func = - getHostFunc(WasiCryptoMod, "symmetric_tag_close"); + auto *Func = getHostFunc(WasiCryptoSymmMod, + "symmetric_tag_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{TagHandle}, Errno)); @@ -810,7 +812,7 @@ WasiCryptoExpect<__wasi_keypair_t> WasiCryptoTest::keypairGenerate( writeOptOptions(OptOptionsHandle, AlgSize); auto *Func = getHostFunc( - WasiCryptoMod, "keypair_generate"); + WasiCryptoAsymCommonMod, "keypair_generate"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, @@ -832,8 +834,8 @@ WasiCryptoTest::keypairImport(__wasi_algorithm_type_e_t AlgType, writeSpan(Encoded, AlgSize); uint32_t EncodedSize = static_cast(Encoded.size()); - auto *Func = getHostFunc(WasiCryptoMod, - "keypair_import"); + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_import"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -856,7 +858,7 @@ WasiCryptoExpect<__wasi_keypair_t> WasiCryptoTest::keypairGenerateManaged( writeOptOptions(OptOptionsHandle, AlgSize); auto *Func = getHostFunc( - WasiCryptoMod, "keypair_generate_managed"); + WasiCryptoAsymCommonMod, "keypair_generate_managed"); EXPECT_NE(Func, nullptr); EXPECT_TRUE( Func->run(&MemInst, @@ -877,7 +879,7 @@ WasiCryptoExpect WasiCryptoTest::keypairStoreManaged( uint32_t KpIdSize = static_cast(KpId.size()); auto *Func = getHostFunc( - WasiCryptoMod, "keypair_store_managed"); + WasiCryptoAsymCommonMod, "keypair_store_managed"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -897,7 +899,7 @@ WasiCryptoExpect<__wasi_version_t> WasiCryptoTest::keypairReplaceManaged( writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "keypair_replace_managed"); + WasiCryptoAsymCommonMod, "keypair_replace_managed"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -914,8 +916,8 @@ WasiCryptoTest::keypairId(__wasi_keypair_t KpHandle, Span KpId) { writeSpan(KpId, 0); uint32_t KpIdSize = static_cast(KpId.size()); - auto *Func = - getHostFunc(WasiCryptoMod, "keypair_id"); + auto *Func = getHostFunc(WasiCryptoAsymCommonMod, + "keypair_id"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -938,8 +940,8 @@ WasiCryptoTest::keypairFromId(__wasi_secrets_manager_t SecretsManagerHandle, writeSpan(KpId, 0); uint32_t KpIdSize = static_cast(KpId.size()); - auto *Func = getHostFunc(WasiCryptoMod, - "keypair_from_id"); + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_from_id"); EXPECT_NE(Func, nullptr); EXPECT_TRUE( Func->run(&MemInst, @@ -957,7 +959,7 @@ WasiCryptoTest::keypairFromPkAndSk(__wasi_publickey_t PkHandle, writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "keypair_from_pk_and_sk"); + WasiCryptoAsymCommonMod, "keypair_from_pk_and_sk"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, @@ -973,8 +975,8 @@ WasiCryptoTest::keypairExport(__wasi_keypair_t KpHandle, __wasi_keypair_encoding_e_t Encoding) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoMod, - "keypair_export"); + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_export"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -990,7 +992,7 @@ WasiCryptoTest::keypairPublickey(__wasi_keypair_t KpHandle) { writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "keypair_publickey"); + WasiCryptoAsymCommonMod, "keypair_publickey"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{KpHandle, 0}, @@ -1005,7 +1007,7 @@ WasiCryptoTest::keypairSecretkey(__wasi_keypair_t KpHandle) { writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "keypair_secretkey"); + WasiCryptoAsymCommonMod, "keypair_secretkey"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{KpHandle, 0}, @@ -1018,8 +1020,8 @@ WasiCryptoTest::keypairSecretkey(__wasi_keypair_t KpHandle) { WasiCryptoExpect WasiCryptoTest::keypairClose(__wasi_keypair_t KpHandle) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoMod, - "keypair_close"); + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "keypair_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{KpHandle}, Errno)); @@ -1038,7 +1040,7 @@ WasiCryptoExpect<__wasi_publickey_t> WasiCryptoTest::publickeyImport( uint32_t EncodedSize = static_cast(Encoded.size()); auto *Func = getHostFunc( - WasiCryptoMod, "publickey_import"); + WasiCryptoAsymCommonMod, "publickey_import"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -1057,7 +1059,7 @@ WasiCryptoTest::publickeyExport(__wasi_publickey_t PkHandle, writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "publickey_export"); + WasiCryptoAsymCommonMod, "publickey_export"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -1073,7 +1075,7 @@ WasiCryptoTest::publickeyVerify(__wasi_publickey_t PkHandle) { writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "publickey_verify"); + WasiCryptoAsymCommonMod, "publickey_verify"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{PkHandle}, Errno)); @@ -1087,7 +1089,7 @@ WasiCryptoTest::publickeyFromSecretkey(__wasi_secretkey_t SkHandle) { writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "publickey_from_secretkey"); + WasiCryptoAsymCommonMod, "publickey_from_secretkey"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{SkHandle, 0}, @@ -1101,8 +1103,8 @@ WasiCryptoExpect WasiCryptoTest::publickeyClose(__wasi_publickey_t PkHandle) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoMod, - "publickey_close"); + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "publickey_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{PkHandle}, Errno)); @@ -1121,7 +1123,7 @@ WasiCryptoExpect<__wasi_secretkey_t> WasiCryptoTest::secretkeyImport( uint32_t EncodedSize = static_cast(Encoded.size()); auto *Func = getHostFunc( - WasiCryptoMod, "secretkey_import"); + WasiCryptoAsymCommonMod, "secretkey_import"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -1140,7 +1142,7 @@ WasiCryptoTest::secretkeyExport(__wasi_secretkey_t SkHandle, writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "secretkey_export"); + WasiCryptoAsymCommonMod, "secretkey_export"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -1155,8 +1157,8 @@ WasiCryptoExpect WasiCryptoTest::secretkeyClose(__wasi_secretkey_t SkHandle) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoMod, - "secretkey_close"); + auto *Func = getHostFunc( + WasiCryptoAsymCommonMod, "secretkey_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{SkHandle}, Errno)); @@ -1170,7 +1172,7 @@ WasiCryptoTest::kxDh(__wasi_kx_publickey_t PkHandle, __wasi_kx_secretkey_t SkHandle) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoMod, "kx_dh"); + auto *Func = getHostFunc(WasiCryptoKxMod, "kx_dh"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, @@ -1185,7 +1187,7 @@ WasiCryptoExpect> WasiCryptoTest::kxEncapsulate(__wasi_kx_publickey_t PkHandle) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoMod, "kx_encapsulate"); + auto *Func = getHostFunc(WasiCryptoKxMod, "kx_encapsulate"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{PkHandle, 0, 1}, @@ -1204,7 +1206,7 @@ WasiCryptoTest::kxDecapsulate(__wasi_kx_secretkey_t SkHandle, uint32_t EncapsulatedSecretSize = static_cast(EncapsulatedSecret.size()); - auto *Func = getHostFunc(WasiCryptoMod, "kx_decapsulate"); + auto *Func = getHostFunc(WasiCryptoKxMod, "kx_decapsulate"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, @@ -1221,8 +1223,8 @@ WasiCryptoTest::signatureExport(__wasi_signature_t SigHandle, __wasi_signature_encoding_e_t Encoding) { writeDummyMemoryContent(); - auto *Func = - getHostFunc(WasiCryptoMod, "signature_export"); + auto *Func = getHostFunc(WasiCryptoSignMod, + "signature_export"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run(&MemInst, std::initializer_list{ @@ -1243,8 +1245,8 @@ WasiCryptoTest::signatureImport(std::string_view Alg, writeSpan(Encoded, AlgSize); uint32_t EncodedSize = static_cast(Encoded.size()); - auto *Func = - getHostFunc(WasiCryptoMod, "signature_import"); + auto *Func = getHostFunc(WasiCryptoSignMod, + "signature_import"); EXPECT_NE(Func, nullptr); EXPECT_TRUE( Func->run(&MemInst, @@ -1261,7 +1263,8 @@ WasiCryptoExpect WasiCryptoTest::signatureClose(__wasi_signature_t SigHandle) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoMod, "signature_close"); + auto *Func = getHostFunc(WasiCryptoSignMod, + "signature_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{SigHandle}, Errno)); @@ -1274,8 +1277,8 @@ WasiCryptoExpect<__wasi_signature_state_t> WasiCryptoTest::signatureStateOpen(__wasi_signature_keypair_t KpHandle) { writeDummyMemoryContent(); - auto *Func = - getHostFunc(WasiCryptoMod, "signature_state_open"); + auto *Func = getHostFunc(WasiCryptoSignMod, + "signature_state_open"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{KpHandle, 0}, @@ -1292,7 +1295,7 @@ WasiCryptoTest::signatureStateUpdate(__wasi_signature_state_t StateHandle, writeSpan(Input, 0); uint32_t InputSize = static_cast(Input.size()); - auto *Func = getHostFunc(WasiCryptoMod, + auto *Func = getHostFunc(WasiCryptoSignMod, "signature_state_update"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( @@ -1308,8 +1311,8 @@ WasiCryptoExpect<__wasi_signature_t> WasiCryptoTest::signatureStateSign(__wasi_signature_state_t StateHandle) { writeDummyMemoryContent(); - auto *Func = - getHostFunc(WasiCryptoMod, "signature_state_sign"); + auto *Func = getHostFunc(WasiCryptoSignMod, + "signature_state_sign"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{StateHandle, 0}, @@ -1323,7 +1326,7 @@ WasiCryptoExpect WasiCryptoTest::signatureStateClose(__wasi_signature_state_t StateHandle) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoMod, + auto *Func = getHostFunc(WasiCryptoSignMod, "signature_state_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( @@ -1340,7 +1343,7 @@ WasiCryptoTest::signatureVerificationStateOpen( writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "signature_verification_state_open"); + WasiCryptoSignMod, "signature_verification_state_open"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{PkHandle, 0}, @@ -1358,7 +1361,7 @@ WasiCryptoExpect WasiCryptoTest::signatureVerificationStateUpdate( uint32_t InputSize = static_cast(Input.size()); auto *Func = getHostFunc( - WasiCryptoMod, "signature_verification_state_update"); + WasiCryptoSignMod, "signature_verification_state_update"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, @@ -1375,7 +1378,7 @@ WasiCryptoExpect WasiCryptoTest::signatureVerificationStateVerify( writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "signature_verification_state_verify"); + WasiCryptoSignMod, "signature_verification_state_verify"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, @@ -1391,7 +1394,7 @@ WasiCryptoExpect WasiCryptoTest::signatureVerificationStateClose( writeDummyMemoryContent(); auto *Func = getHostFunc( - WasiCryptoMod, "signature_verification_state_close"); + WasiCryptoSignMod, "signature_verification_state_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( &MemInst, std::initializer_list{StateHandle}, diff --git a/test/plugins/wasi_crypto/helper.h b/test/plugins/wasi_crypto/helper.h index a2f0b336..92fd6379 100644 --- a/test/plugins/wasi_crypto/helper.h +++ b/test/plugins/wasi_crypto/helper.h @@ -3,11 +3,13 @@ #pragma once -#include "asymmetric_common/func.h" -#include "common/func.h" +#include "asymmetric_common/module.h" +#include "common/module.h" #include "ctx.h" #include "helper.h" -#include "module.h" +#include "kx/module.h" +#include "signatures/module.h" +#include "symmetric/module.h" #include "utils/error.h" #include "common/span.h" @@ -58,16 +60,49 @@ class WasiCryptoTest : public ::testing::Test { "../../../plugins/wasi_crypto/" "libwasmedgePluginWasiCrypto" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_crypto"sv)) { - if (const auto *Module = Plugin->findModule("wasi_crypto"sv)) { - WasiCryptoMod = dynamic_cast( + if (const auto *Module = + Plugin->findModule("wasi_crypto_asymmetric_common"sv)) { + WasiCryptoAsymCommonMod = + dynamic_cast( + Module->create().release()); + } + if (const auto *Module = Plugin->findModule("wasi_crypto_common"sv)) { + WasiCryptoCommonMod = + dynamic_cast( + Module->create().release()); + } + if (const auto *Module = Plugin->findModule("wasi_crypto_kx"sv)) { + WasiCryptoKxMod = dynamic_cast( Module->create().release()); } + if (const auto *Module = Plugin->findModule("wasi_crypto_signatures"sv)) { + WasiCryptoSignMod = + dynamic_cast( + Module->create().release()); + } + if (const auto *Module = Plugin->findModule("wasi_crypto_symmetric"sv)) { + WasiCryptoSymmMod = + dynamic_cast( + Module->create().release()); + } } } ~WasiCryptoTest() override { - if (WasiCryptoMod) { - delete WasiCryptoMod; + if (WasiCryptoAsymCommonMod) { + delete WasiCryptoAsymCommonMod; + } + if (WasiCryptoCommonMod) { + delete WasiCryptoCommonMod; + } + if (WasiCryptoKxMod) { + delete WasiCryptoKxMod; + } + if (WasiCryptoSignMod) { + delete WasiCryptoSignMod; + } + if (WasiCryptoSymmMod) { + delete WasiCryptoSymmMod; } } @@ -353,7 +388,11 @@ class WasiCryptoTest : public ::testing::Test { WasmEdge::AST::MemoryType(1)}; std::array Errno; - Host::WasiCryptoModule *WasiCryptoMod = nullptr; + Host::WasiCryptoAsymmetricCommonModule *WasiCryptoAsymCommonMod = nullptr; + Host::WasiCryptoCommonModule *WasiCryptoCommonMod = nullptr; + Host::WasiCryptoKxModule *WasiCryptoKxMod = nullptr; + Host::WasiCryptoSignaturesModule *WasiCryptoSignMod = nullptr; + Host::WasiCryptoSymmetricModule *WasiCryptoSymmMod = nullptr; }; } // namespace WasiCrypto From e79ba1218141f708e8f4e4dfc8999f1f42d6fc53 Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 29 Jul 2022 16:32:33 +0800 Subject: [PATCH 061/623] [CI] Publish docker slim image in release.yml Signed-off-by: dm4 --- utils/docker/Dockerfile.release | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/utils/docker/Dockerfile.release b/utils/docker/Dockerfile.release index 225b9549..e2656cc9 100644 --- a/utils/docker/Dockerfile.release +++ b/utils/docker/Dockerfile.release @@ -1,10 +1,12 @@ FROM ubuntu:20.04 ARG VERSION -ENV DEBIAN_FRONTEND=noninteractive -RUN apt update && \ - apt install -y curl git && \ - curl -sSf https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh | bash -s -- -p /usr/local -e all -v $VERSION +ADD WasmEdge-$VERSION-Linux.tar.gz /tmp/ +RUN cp -rf /tmp/WasmEdge-$VERSION-Linux/bin/* /usr/local/bin && \ + cp -rf /tmp/WasmEdge-$VERSION-Linux/lib64/* /usr/local/lib && \ + cp -rf /tmp/WasmEdge-$VERSION-Linux/include/* /usr/local/include && \ + ldconfig /usr/local/lib +RUN rm -rf /tmp/WasmEdge-$VERSION-Linux WORKDIR /app CMD ["/usr/local/bin/wasmedge"] From 940563690b56a9bffbbecdaa26bd2e6f75f02f58 Mon Sep 17 00:00:00 2001 From: Zhou Zhou Date: Tue, 2 Aug 2022 14:59:55 +0800 Subject: [PATCH 062/623] [Plugin] implement the host function for https request (#1666) * [Executor] implement the host function for https * [Misc] clang format and tidy * [Docs] add the doc for networking for https * [Test] add the tests for networking for https * [Plugin] update the host function to expose the received content * [Tests] update the tests for get_rcv and get_rcv_len function * [Docs] update the docs for get_rcv and get_rcv_len function * [Misc] add plugin name to Error message line * [Docs] correct the grammar mistake Signed-off-by: zhouzhou --- plugins/CMakeLists.txt | 6 ++ plugins/httpsreq/CMakeLists.txt | 38 +++++++ plugins/httpsreq/httpsreqbase.h | 23 ++++ plugins/httpsreq/httpsreqenv.cpp | 38 +++++++ plugins/httpsreq/httpsreqenv.h | 29 ++++++ plugins/httpsreq/httpsreqfunc.cpp | 142 +++++++++++++++++++++++++ plugins/httpsreq/httpsreqfunc.h | 35 +++++++ plugins/httpsreq/httpsreqmodule.cpp | 18 ++++ plugins/httpsreq/httpsreqmodule.h | 24 +++++ test/plugins/CMakeLists.txt | 5 + test/plugins/httpsreq/CMakeLists.txt | 15 +++ test/plugins/httpsreq/httpsreq.cpp | 150 +++++++++++++++++++++++++++ 12 files changed, 523 insertions(+) create mode 100644 plugins/httpsreq/CMakeLists.txt create mode 100644 plugins/httpsreq/httpsreqbase.h create mode 100644 plugins/httpsreq/httpsreqenv.cpp create mode 100644 plugins/httpsreq/httpsreqenv.h create mode 100644 plugins/httpsreq/httpsreqfunc.cpp create mode 100644 plugins/httpsreq/httpsreqfunc.h create mode 100644 plugins/httpsreq/httpsreqmodule.cpp create mode 100644 plugins/httpsreq/httpsreqmodule.h create mode 100644 test/plugins/httpsreq/CMakeLists.txt create mode 100644 test/plugins/httpsreq/httpsreq.cpp diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index e9244d3d..5f228529 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -10,6 +10,12 @@ if (WASMEDGE_PLUGIN_WASI_NN_BACKEND) add_subdirectory(wasi_nn) endif() +if(WASMEDGE_PLUGIN_HTTPSREQ) + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(httpsreq) + endif() +endif() + if (WASMEDGE_PLUGIN_WASI_CRYPTO) add_subdirectory(wasi_crypto) endif() diff --git a/plugins/httpsreq/CMakeLists.txt b/plugins/httpsreq/CMakeLists.txt new file mode 100644 index 00000000..b71cd710 --- /dev/null +++ b/plugins/httpsreq/CMakeLists.txt @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_library(wasmedgePluginHttpsReq + SHARED + httpsreqenv.cpp + httpsreqfunc.cpp + httpsreqmodule.cpp +) + +target_compile_options(wasmedgePluginHttpsReq + PUBLIC + -DWASMEDGE_PLUGIN +) + + +if(CMAKE_SYSTEM_NAME MATCHES "Darwin") + target_link_options(wasmedgePluginHttpsReq + PUBLIC + -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterC1EPKNS0_6Plugin16PluginDescriptorE + -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterD1Ev + ) +endif() + +target_include_directories(wasmedgePluginHttpsReq + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries(wasmedgePluginHttpsReq + PUBLIC + wasmedgeCommon + wasmedgeSystem + -L/usr/lib -lssl -lcrypto +) + +install(TARGETS wasmedgePluginHttpsReq DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/httpsreq/httpsreqbase.h b/plugins/httpsreq/httpsreqbase.h new file mode 100644 index 00000000..6cafe9b7 --- /dev/null +++ b/plugins/httpsreq/httpsreqbase.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "common/errcode.h" +#include "httpsreqenv.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template class HttpsReq : public Runtime::HostFunction { +public: + HttpsReq(HttpsReqEnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + HttpsReqEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/httpsreq/httpsreqenv.cpp b/plugins/httpsreq/httpsreqenv.cpp new file mode 100644 index 00000000..6b2cc318 --- /dev/null +++ b/plugins/httpsreq/httpsreqenv.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "httpsreqenv.h" +#include "httpsreqmodule.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +Runtime::Instance::ModuleInstance *create(void) noexcept { + return new HttpsReqModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "https_req", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 10, 1, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "https_req", + .Description = "", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +} // namespace + +Plugin::PluginRegister HttpsReqEnvironment::Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/httpsreq/httpsreqenv.h b/plugins/httpsreq/httpsreqenv.h new file mode 100644 index 00000000..047604cd --- /dev/null +++ b/plugins/httpsreq/httpsreqenv.h @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "po/argument_parser.h" +#include "po/list.h" +#include "po/option.h" +#include +#include +#include + +namespace WasmEdge { +namespace Host { + +class HttpsReqEnvironment { +public: + std::string Host; + uint32_t Port; + std::string BodyStr; + std::string Rcv; + + /// Initial Configurations + static Plugin::PluginRegister Register; +}; + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/httpsreq/httpsreqfunc.cpp b/plugins/httpsreq/httpsreqfunc.cpp new file mode 100644 index 00000000..680818a1 --- /dev/null +++ b/plugins/httpsreq/httpsreqfunc.cpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "httpsreqfunc.h" +#include +#include +#include +#include +#include +#include +#include +#include + +// Some of the code was taken from this post: +// https://stackoverflow.com/questions/52727565/client-in-c-use-gethostbyname-or-getaddrinfo + +namespace WasmEdge { +namespace Host { + +Expect SendData::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t HostPtr, uint32_t HostLen, uint32_t Port, + uint32_t BodyPtr, uint32_t BodyLen) { + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + + char *Host = MemInst->getPointer(HostPtr); + std::string NewHost; + std::copy_n(Host, HostLen, std::back_inserter(NewHost)); + Env.Host = std::move(NewHost); + + Env.Port = Port; + + char *BodyStr = MemInst->getPointer(BodyPtr); + std::string NewBodyStr; + std::copy_n(BodyStr, BodyLen, std::back_inserter(NewBodyStr)); + Env.BodyStr = std::move(NewBodyStr); + + const SSL_METHOD *Method = TLS_client_method(); + SSL_CTX *Ctx = SSL_CTX_new(Method); + + if (Ctx == nullptr) { + ERR_print_errors_fp(stderr); + exit(EXIT_FAILURE); + } + + SSL *Ssl = SSL_new(Ctx); + if (Ssl == nullptr) { + fprintf(stderr, "[Httpsreq plugin] SSL_new() failed\n"); + exit(EXIT_FAILURE); + } + + // open connection + int Sfd, Err; + struct addrinfo Hints = {}, *Addrs; + char PortStr[16] = {}; + + Hints.ai_family = AF_INET; + Hints.ai_socktype = SOCK_STREAM; + Hints.ai_protocol = IPPROTO_TCP; + + std::sprintf(PortStr, "%d", Port); + + Err = getaddrinfo(Env.Host.c_str(), PortStr, &Hints, &Addrs); + if (Err != 0) { + fprintf(stderr, "[Httpsreq plugin] %s: %s\n", Env.Host.c_str(), + gai_strerror(Err)); + abort(); + } + + for (struct addrinfo *Addr = Addrs; Addr != NULL; Addr = Addr->ai_next) { + Sfd = socket(Addr->ai_family, Addr->ai_socktype, Addr->ai_protocol); + if (Sfd == -1) { + Err = errno; + break; + } + + if (connect(Sfd, Addr->ai_addr, Addr->ai_addrlen) == 0) + break; + Err = errno; + close(Sfd); + Sfd = -1; + } + + freeaddrinfo(Addrs); + + if (Sfd == -1) { + fprintf(stderr, "[Httpsreq plugin] %s: %s\n", Env.Host.c_str(), + strerror(Err)); + abort(); + } + + SSL_set_fd(Ssl, Sfd); + + const int Status = SSL_connect(Ssl); + if (Status != 1) { + SSL_get_error(Ssl, Status); + ERR_print_errors_fp(stderr); + fprintf(stderr, + "[Httpsreq plugin] SSL_connect failed with SSL_get_error code %d\n", + Status); + exit(EXIT_FAILURE); + } + + SSL_write(Ssl, BodyStr, strlen(Env.BodyStr.c_str())); + + // Receive + char Buffer[1024]; + int Nbytes = 0; + Env.Rcv = ""; + while (true) { + Nbytes = SSL_read(Ssl, Buffer, 1024); + if (Nbytes <= 0) { + break; + } + std::string Buf(Buffer, Nbytes); + Env.Rcv = Env.Rcv + Buf; + } + + SSL_free(Ssl); + close(Sfd); + SSL_CTX_free(Ctx); + + return {}; +} + +Expect HttpsReqGetRcv::body(Runtime::Instance::MemoryInstance *MemInst, + uint32_t BufPtr) { + if (MemInst == nullptr) { + return Unexpect(ErrCode::ExecutionFailed); + } + char *Buf = MemInst->getPointer(BufPtr); + std::copy_n(Env.Rcv.begin(), Env.Rcv.size(), Buf); + return {}; +} + +Expect HttpsReqGetRcvLen::body(Runtime::Instance::MemoryInstance *) { + return static_cast(Env.Rcv.size()); +} + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/httpsreq/httpsreqfunc.h b/plugins/httpsreq/httpsreqfunc.h new file mode 100644 index 00000000..bc391fab --- /dev/null +++ b/plugins/httpsreq/httpsreqfunc.h @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "common/defines.h" +#include "httpsreqbase.h" + +#include + +namespace WasmEdge { +namespace Host { + +class SendData : public HttpsReq { +public: + SendData(HttpsReqEnvironment &HostEnv) : HttpsReq(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *, uint32_t HostPtr, + uint32_t HostLen, uint32_t Port, uint32_t BodyPtr, + uint32_t BodyLen); +}; + +class HttpsReqGetRcv : public HttpsReq { +public: + HttpsReqGetRcv(HttpsReqEnvironment &HostEnv) : HttpsReq(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *, uint32_t BufPtr); +}; + +class HttpsReqGetRcvLen : public HttpsReq { +public: + HttpsReqGetRcvLen(HttpsReqEnvironment &HostEnv) : HttpsReq(HostEnv) {} + Expect body(Runtime::Instance::MemoryInstance *); +}; + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/httpsreq/httpsreqmodule.cpp b/plugins/httpsreq/httpsreqmodule.cpp new file mode 100644 index 00000000..868f4458 --- /dev/null +++ b/plugins/httpsreq/httpsreqmodule.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "httpsreqmodule.h" +#include "httpsreqfunc.h" + +namespace WasmEdge { +namespace Host { + +/// Register your functions in module. +HttpsReqModule::HttpsReqModule() : ModuleInstance("httpsreq") { + addHostFunc("send_data", std::make_unique(Env)); + addHostFunc("get_rcv", std::make_unique(Env)); + addHostFunc("get_rcv_len", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/httpsreq/httpsreqmodule.h b/plugins/httpsreq/httpsreqmodule.h new file mode 100644 index 00000000..60f545ef --- /dev/null +++ b/plugins/httpsreq/httpsreqmodule.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "httpsreqenv.h" +#include "runtime/instance/module.h" +#include + +namespace WasmEdge { +namespace Host { + +class HttpsReqModule : public Runtime::Instance::ModuleInstance { +public: + HttpsReqModule(); + + HttpsReqEnvironment &getEnv() { return Env; } + +private: + HttpsReqEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 1b8f8f04..5a52aab1 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -7,6 +7,11 @@ endif() if (WASMEDGE_PLUGIN_WASI_NN_BACKEND) add_subdirectory(wasi_nn) endif() +if(WASMEDGE_PLUGIN_HTTPSREQ) + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(httpsreq) + endif() +endif() if (WASMEDGE_PLUGIN_WASI_CRYPTO) add_subdirectory(wasi_crypto) endif() diff --git a/test/plugins/httpsreq/CMakeLists.txt b/test/plugins/httpsreq/CMakeLists.txt new file mode 100644 index 00000000..54f6279b --- /dev/null +++ b/test/plugins/httpsreq/CMakeLists.txt @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_executable(wasmedgeHttpsReqTests + httpsreq.cpp +) + +target_link_libraries(wasmedgeHttpsReqTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} + wasmedgePlugin + wasmedgePluginHttpsReq +) + +add_test(wasmedgeHttpsReqTests wasmedgeHttpsReqTests) \ No newline at end of file diff --git a/test/plugins/httpsreq/httpsreq.cpp b/test/plugins/httpsreq/httpsreq.cpp new file mode 100644 index 00000000..35fae957 --- /dev/null +++ b/test/plugins/httpsreq/httpsreq.cpp @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/defines.h" +#include "httpsreqfunc.h" +#include "httpsreqmodule.h" +#include "runtime/instance/module.h" + +#include +#include +#include +#include +#include +#include + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/httpsreq/" + "libwasmedgePluginHttpsReq" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("https_req"sv)) { + if (const auto *Module = Plugin->findModule("https_req"sv)) { + return Module->create().release(); + } + } + return nullptr; +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, uint32_t Cnt, uint8_t C = 0) noexcept { + std::fill_n(MemInst.getPointer(Offset), Cnt, C); +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, const std::string &Str) noexcept { + char *Buf = MemInst.getPointer(Offset); + std::copy_n(Str.c_str(), Str.length(), Buf); +} + +} // namespace + +TEST(wasmedgeHttpsReqTests, SendData) { + // Create the httpsreq module instance. + auto *ProcMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ProcMod == nullptr); + + // Create the memory instance. + WasmEdge::Runtime::Instance::MemoryInstance MemInst( + WasmEdge::AST::MemoryType(1)); + + // Clear the memory[0, 256]. + fillMemContent(MemInst, 0, 256); + // Set the memory[0, 11] as string "echo". + fillMemContent(MemInst, 0, std::string("httpbin.org")); + // Set the memory[30, 116] as string "GET / HTTP/1.1\nHost: + // httpbin.org\r\nConnection: Close\r\nReferer: https://httpbin.org/\r\n\r\n". + fillMemContent(MemInst, 30, + std::string("GET / HTTP/1.1\nHost: httpbin.org\r\nConnection: " + "Close\r\nReferer: https://httpbin.org/\r\n\r\n")); + + // Get the function "send_data" + auto *FuncInst = ProcMod->findFuncExports("send_data"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = + dynamic_cast(FuncInst->getHostFunc()); + + // Test: Run function successfully for get requests + EXPECT_TRUE(HostFuncInst.run( + &MemInst, + std::initializer_list{ + UINT32_C(0), UINT32_C(11), UINT32_C(443), UINT32_C(30), UINT32_C(86)}, + {})); + EXPECT_TRUE(ProcMod->getEnv().Host == "httpbin.org"); + EXPECT_TRUE(ProcMod->getEnv().Body == + "GET / HTTP/1.1\nHost: httpbin.org\r\nConnection: " + "Close\r\nReferer: https://httpbin.org/\r\n\r\n"); +} + +TEST(wasmedgeHttpsReqTests, GetRcv) { + // Create the httpsreq module instance. + auto *ProcMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ProcMod == nullptr); + + // Create the memory instance. + WasmEdge::Runtime::Instance::MemoryInstance MemInst( + WasmEdge::AST::MemoryType(1)); + + fillMemContent(MemInst, 0, 256); + // Set the memory[0, 11] as string "echo". + fillMemContent(MemInst, 0, std::string("httpbin.org")); + // Set the memory[30, 116] as string "GET / HTTP/1.1\nHost: + // httpbin.org\r\nConnection: Close\r\nReferer: https://httpbin.org/\r\n\r\n". + fillMemContent(MemInst, 30, + std::string("GET / HTTP/1.1\nHost: httpbin.org\r\nConnection: " + "Close\r\nReferer: https://httpbin.org/\r\n\r\n")); + + // Get the function "send_data" + auto *FuncInst = ProcMod->findFuncExports("send_data"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSendData = + dynamic_cast(FuncInst->getHostFunc()); + + // Get the function "get_rcv_len" + FuncInst = ProcMod->findFuncExports("https_req_get_rcv_len"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetRcvLen = dynamic_cast( + FuncInst->getHostFunc()); + + // Get the function "get_rcv" + FuncInst = ProcMod->findFuncExports("https_req_get_rcv"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetRcv = + dynamic_cast(FuncInst->getHostFunc()); + + // Test: Run function successfully for get requests + EXPECT_TRUE(HostFuncSendData.run( + &MemInst, + std::initializer_list{ + UINT32_C(0), UINT32_C(11), UINT32_C(443), UINT32_C(30), UINT32_C(86)}, + {})); + EXPECT_TRUE(ProcMod->getEnv().Host == "httpbin.org"); + EXPECT_TRUE(ProcMod->getEnv().Body == + "GET / HTTP/1.1\nHost: httpbin.org\r\nConnection: " + "Close\r\nReferer: https://httpbin.org/\r\n\r\n"); + + // Test: Run function successfully for getrcvlen + std::array RetVal; + EXPECT_TRUE(HostFuncGetRcvLen.run(nullptr, {}, RetVal)); + uint32_t Len = RetVal[0].get(); + EXPECT_TRUE(Len > 0U); + + // Test Run function successfully for getrcv + EXPECT_TRUE(HostFuncGetRcv.run( + &MemInst, std::initializer_list{UINT32_C(0)}, {})); + EXPECT_TRUE(std::equal(ProcMod->getEnv().Rcv.begin(), + ProcMod->getEnv().Rcv.end(), + MemInst.getPointer(0))); +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 9ff18f319c924538031b377990cbe4e3c219bcf5 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Thu, 18 Aug 2022 14:05:14 +0800 Subject: [PATCH 063/623] [Plugin] Remove weak symobol definition * Make plugin depends on wasmedge library Signed-off-by: Shen-Ta Hsieh --- plugins/httpsreq/CMakeLists.txt | 19 ++++++++++++++++--- plugins/wasi_crypto/CMakeLists.txt | 14 ++++++++++++-- plugins/wasi_nn/CMakeLists.txt | 16 +++++++++++----- plugins/wasmedge_process/CMakeLists.txt | 16 +++++++++++----- 4 files changed, 50 insertions(+), 15 deletions(-) diff --git a/plugins/httpsreq/CMakeLists.txt b/plugins/httpsreq/CMakeLists.txt index b71cd710..e7cfca95 100644 --- a/plugins/httpsreq/CMakeLists.txt +++ b/plugins/httpsreq/CMakeLists.txt @@ -1,6 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2022 Second State INC +set(OPENSSL_USE_STATIC_LIBS ON) +find_package(OpenSSL REQUIRED) + wasmedge_add_library(wasmedgePluginHttpsReq SHARED httpsreqenv.cpp @@ -30,9 +33,19 @@ target_include_directories(wasmedgePluginHttpsReq target_link_libraries(wasmedgePluginHttpsReq PUBLIC - wasmedgeCommon - wasmedgeSystem - -L/usr/lib -lssl -lcrypto + OpenSSL::Crypto + OpenSSL::SSL ) +if(WASMEDGE_LINK_PUGLINS_STATIC) + target_link_libraries(wasmedgePluginHttpsReq + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginHttpsReq + PRIVATE + wasmedge_c_shared + ) +endif() install(TARGETS wasmedgePluginHttpsReq DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt index 36b59520..3b770f07 100644 --- a/plugins/wasi_crypto/CMakeLists.txt +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -72,9 +72,19 @@ target_include_directories(wasmedgePluginWasiCrypto target_link_libraries(wasmedgePluginWasiCrypto PUBLIC + wasmedge_c_shared OpenSSL::Crypto - wasmedgeCommon - wasmedgeSystem ) +if(WASMEDGE_LINK_PUGLINS_STATIC) + target_link_libraries(wasmedgePluginWasiCrypto + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasiCrypto + PRIVATE + wasmedge_c_shared + ) +endif() install(TARGETS wasmedgePluginWasiCrypto DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index b13cc3cc..63807824 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -27,11 +27,17 @@ target_include_directories(wasmedgePluginWasiNN ${CMAKE_CURRENT_SOURCE_DIR} ) -target_link_libraries(wasmedgePluginWasiNN - PUBLIC - wasmedgeCommon - wasmedgeSystem -) +if(WASMEDGE_LINK_PUGLINS_STATIC) + target_link_libraries(wasmedgePluginWasiNN + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasiNN + PRIVATE + wasmedge_c_shared + ) +endif() install(TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt index 82852d81..50f00713 100644 --- a/plugins/wasmedge_process/CMakeLists.txt +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -27,10 +27,16 @@ target_include_directories(wasmedgePluginWasmEdgeProcess ${CMAKE_CURRENT_SOURCE_DIR} ) -target_link_libraries(wasmedgePluginWasmEdgeProcess - PUBLIC - wasmedgeCommon - wasmedgeSystem -) +if(WASMEDGE_LINK_PUGLINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeProcess + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeProcess + PRIVATE + wasmedge_c_shared + ) +endif() install(TARGETS wasmedgePluginWasmEdgeProcess DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) From 2fddde68de83c46ac18b30bb57e0ba2438c8211c Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 4 Aug 2022 11:26:52 +0800 Subject: [PATCH 064/623] [Common] Extend the error code as a struct to support user-defined errors. Signed-off-by: YiYing He --- plugins/httpsreq/httpsreqfunc.cpp | 4 ++-- plugins/wasi_nn/wasinnfunc.cpp | 10 +++++----- plugins/wasmedge_process/processfunc.cpp | 14 +++++++------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/plugins/httpsreq/httpsreqfunc.cpp b/plugins/httpsreq/httpsreqfunc.cpp index 680818a1..a2846d8f 100644 --- a/plugins/httpsreq/httpsreqfunc.cpp +++ b/plugins/httpsreq/httpsreqfunc.cpp @@ -21,7 +21,7 @@ Expect SendData::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t HostPtr, uint32_t HostLen, uint32_t Port, uint32_t BodyPtr, uint32_t BodyLen) { if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } char *Host = MemInst->getPointer(HostPtr); @@ -127,7 +127,7 @@ Expect SendData::body(Runtime::Instance::MemoryInstance *MemInst, Expect HttpsReqGetRcv::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t BufPtr) { if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } char *Buf = MemInst->getPointer(BufPtr); std::copy_n(Env.Rcv.begin(), Env.Rcv.size(), Buf); diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index b402fa08..20949124 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -22,7 +22,7 @@ Expect WasiNNLoad::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t GraphIdPtr [[maybe_unused]]) { // Check memory instance from module. if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } if (Encoding == static_cast(WASINN::Backend::OpenVINO)) { @@ -246,7 +246,7 @@ WasiNNInitExecCtx::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t GraphId, uint32_t ContextPtr [[maybe_unused]]) { if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } if (Env.NNGraph.size() <= GraphId) { @@ -299,7 +299,7 @@ WasiNNSetInput::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t Context, uint32_t Index [[maybe_unused]], uint32_t TensorPtr [[maybe_unused]]) { if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } if (Env.NNContext.size() <= Context) { @@ -473,7 +473,7 @@ WasiNNGetOuput::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t OutBufferMaxSize [[maybe_unused]], uint32_t BytesWrittenPtr [[maybe_unused]]) { if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } if (Env.NNContext.size() <= Context) { @@ -566,7 +566,7 @@ WasiNNGetOuput::body(Runtime::Instance::MemoryInstance *MemInst, Expect WasiNNCompute::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t Context) { if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } if (Env.NNContext.size() <= Context) { diff --git a/plugins/wasmedge_process/processfunc.cpp b/plugins/wasmedge_process/processfunc.cpp index d10735c8..72b69e00 100644 --- a/plugins/wasmedge_process/processfunc.cpp +++ b/plugins/wasmedge_process/processfunc.cpp @@ -25,7 +25,7 @@ WasmEdgeProcessSetProgName::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t NamePtr, uint32_t NameLen) { // Check memory instance from module. if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } char *Buf = MemInst->getPointer(NamePtr); @@ -38,7 +38,7 @@ WasmEdgeProcessAddArg::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t ArgPtr, uint32_t ArgLen) { // Check memory instance from module. if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } char *Buf = MemInst->getPointer(ArgPtr); @@ -54,7 +54,7 @@ WasmEdgeProcessAddEnv::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t EnvValPtr, uint32_t EnvValLen) { // Check memory instance from module. if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } char *EnvBuf = MemInst->getPointer(EnvNamePtr); @@ -71,7 +71,7 @@ WasmEdgeProcessAddStdIn::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t BufPtr, uint32_t BufLen) { // Check memory instance from module. if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } uint8_t *Buf = MemInst->getPointer(BufPtr); @@ -294,7 +294,7 @@ Expect WasmEdgeProcessRun::body(Runtime::Instance::MemoryInstance *) { return Env.ExitCode; #elif WASMEDGE_OS_WINDOWS spdlog::error("wasmedge_process doesn't support windows now."); - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); #endif } @@ -313,7 +313,7 @@ WasmEdgeProcessGetStdOut::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t BufPtr) { // Check memory instance from module. if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } char *Buf = MemInst->getPointer(BufPtr); @@ -331,7 +331,7 @@ WasmEdgeProcessGetStdErr::body(Runtime::Instance::MemoryInstance *MemInst, uint32_t BufPtr) { // Check memory instance from module. if (MemInst == nullptr) { - return Unexpect(ErrCode::ExecutionFailed); + return Unexpect(ErrCode::Value::HostFuncError); } char *Buf = MemInst->getPointer(BufPtr); From 390c49301b321491fc16e5b6cb2377daeeefdde4 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Fri, 19 Aug 2022 04:58:36 +0800 Subject: [PATCH 065/623] [Plugin] Update the host functions for the API change. Signed-off-by: YiYing He --- plugins/httpsreq/httpsreqbase.h | 1 + plugins/httpsreq/httpsreqfunc.cpp | 8 +- plugins/httpsreq/httpsreqfunc.h | 6 +- .../wasi_crypto/asymmetric_common/func.cpp | 110 ++++++++------ plugins/wasi_crypto/asymmetric_common/func.h | 81 +++++----- plugins/wasi_crypto/common/func.cpp | 44 +++--- plugins/wasi_crypto/common/func.h | 22 +-- plugins/wasi_crypto/kx/func.cpp | 11 +- plugins/wasi_crypto/kx/func.h | 12 +- plugins/wasi_crypto/signatures/func.cpp | 39 ++--- plugins/wasi_crypto/signatures/func.h | 30 ++-- plugins/wasi_crypto/symmetric/func.cpp | 143 ++++++++++-------- plugins/wasi_crypto/symmetric/func.h | 121 +++++++-------- plugins/wasi_crypto/utils/hostfunction.h | 2 +- plugins/wasi_nn/wasinnbase.h | 1 + plugins/wasi_nn/wasinnfunc.cpp | 28 ++-- plugins/wasi_nn/wasinnfunc.h | 13 +- plugins/wasmedge_process/processbase.h | 1 + plugins/wasmedge_process/processfunc.cpp | 50 +++--- plugins/wasmedge_process/processfunc.h | 30 ++-- 20 files changed, 399 insertions(+), 354 deletions(-) diff --git a/plugins/httpsreq/httpsreqbase.h b/plugins/httpsreq/httpsreqbase.h index 6cafe9b7..e7fcd4bd 100644 --- a/plugins/httpsreq/httpsreqbase.h +++ b/plugins/httpsreq/httpsreqbase.h @@ -5,6 +5,7 @@ #include "common/errcode.h" #include "httpsreqenv.h" +#include "runtime/callingframe.h" #include "runtime/hostfunc.h" namespace WasmEdge { diff --git a/plugins/httpsreq/httpsreqfunc.cpp b/plugins/httpsreq/httpsreqfunc.cpp index a2846d8f..ff989357 100644 --- a/plugins/httpsreq/httpsreqfunc.cpp +++ b/plugins/httpsreq/httpsreqfunc.cpp @@ -17,9 +17,10 @@ namespace WasmEdge { namespace Host { -Expect SendData::body(Runtime::Instance::MemoryInstance *MemInst, +Expect SendData::body(const Runtime::CallingFrame &Frame, uint32_t HostPtr, uint32_t HostLen, uint32_t Port, uint32_t BodyPtr, uint32_t BodyLen) { + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } @@ -124,8 +125,9 @@ Expect SendData::body(Runtime::Instance::MemoryInstance *MemInst, return {}; } -Expect HttpsReqGetRcv::body(Runtime::Instance::MemoryInstance *MemInst, +Expect HttpsReqGetRcv::body(const Runtime::CallingFrame &Frame, uint32_t BufPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } @@ -134,7 +136,7 @@ Expect HttpsReqGetRcv::body(Runtime::Instance::MemoryInstance *MemInst, return {}; } -Expect HttpsReqGetRcvLen::body(Runtime::Instance::MemoryInstance *) { +Expect HttpsReqGetRcvLen::body(const Runtime::CallingFrame &) { return static_cast(Env.Rcv.size()); } diff --git a/plugins/httpsreq/httpsreqfunc.h b/plugins/httpsreq/httpsreqfunc.h index bc391fab..8f085ba5 100644 --- a/plugins/httpsreq/httpsreqfunc.h +++ b/plugins/httpsreq/httpsreqfunc.h @@ -14,7 +14,7 @@ namespace Host { class SendData : public HttpsReq { public: SendData(HttpsReqEnvironment &HostEnv) : HttpsReq(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *, uint32_t HostPtr, + Expect body(const Runtime::CallingFrame &, uint32_t HostPtr, uint32_t HostLen, uint32_t Port, uint32_t BodyPtr, uint32_t BodyLen); }; @@ -22,13 +22,13 @@ class SendData : public HttpsReq { class HttpsReqGetRcv : public HttpsReq { public: HttpsReqGetRcv(HttpsReqEnvironment &HostEnv) : HttpsReq(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *, uint32_t BufPtr); + Expect body(const Runtime::CallingFrame &, uint32_t BufPtr); }; class HttpsReqGetRcvLen : public HttpsReq { public: HttpsReqGetRcvLen(HttpsReqEnvironment &HostEnv) : HttpsReq(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *); + Expect body(const Runtime::CallingFrame &); }; } // namespace Host diff --git a/plugins/wasi_crypto/asymmetric_common/func.cpp b/plugins/wasi_crypto/asymmetric_common/func.cpp index 7574cd25..a45443d0 100644 --- a/plugins/wasi_crypto/asymmetric_common/func.cpp +++ b/plugins/wasi_crypto/asymmetric_common/func.cpp @@ -10,11 +10,12 @@ namespace Host { namespace WasiCrypto { namespace AsymmetricCommon { -Expect -KeypairGenerate::body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, - uint32_t OptOptionsHandlePtr, - uint32_t /* Out */ KpHandlePtr) { +Expect KeypairGenerate::body(const Runtime::CallingFrame &Frame, + uint32_t AlgType, uint32_t AlgPtr, + uint32_t AlgLen, + uint32_t OptOptionsHandlePtr, + uint32_t /* Out */ KpHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; @@ -49,11 +50,12 @@ KeypairGenerate::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect KeypairImport::body(Runtime::Instance::MemoryInstance *MemInst, +Expect KeypairImport::body(const Runtime::CallingFrame &Frame, uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, uint32_t EncodedPtr, uint32_t EncodedLen, uint32_t Encoding, uint32_t /* Out */ KpHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; @@ -94,9 +96,10 @@ Expect KeypairImport::body(Runtime::Instance::MemoryInstance *MemInst, } Expect KeypairGenerateManaged::body( - Runtime::Instance::MemoryInstance *MemInst, int32_t SecretsManagerHandle, + const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, uint32_t OptOptionsHandlePtr, uint32_t KpHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; @@ -132,10 +135,11 @@ Expect KeypairGenerateManaged::body( return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -KeypairStoreManaged::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t SecretsManagerHandle, int32_t KpHandle, - uint32_t KpIdPtr, uint32_t KpIdMaxLen) { +Expect KeypairStoreManaged::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, + int32_t KpHandle, uint32_t KpIdPtr, + uint32_t KpIdMaxLen) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiKpIdMaxLen = KpIdMaxLen; @@ -151,9 +155,12 @@ KeypairStoreManaged::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect KeypairReplaceManaged::body( - Runtime::Instance::MemoryInstance *MemInst, int32_t SecretsManagerHandle, - int32_t OldKpHandle, int32_t NewKpHandle, uint32_t /* Out */ KpVersionPtr) { +Expect KeypairReplaceManaged::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, + int32_t OldKpHandle, + int32_t NewKpHandle, + uint32_t /* Out */ KpVersionPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const KpVersion = MemInst->getPointer<__wasi_version_t *>(KpVersionPtr); @@ -170,11 +177,12 @@ Expect KeypairReplaceManaged::body( return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect KeypairId::body(Runtime::Instance::MemoryInstance *MemInst, +Expect KeypairId::body(const Runtime::CallingFrame &Frame, int32_t KpHandle, uint32_t KpIdPtr, uint32_t KpIdMaxLen, uint32_t /* Out */ SizePtr, uint32_t /* Out */ KpVersionPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiKpIdMaxLen = KpIdMaxLen; @@ -206,12 +214,12 @@ Expect KeypairId::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect KeypairFromId::body(Runtime::Instance::MemoryInstance *MemInst, +Expect KeypairFromId::body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, uint32_t KpIdPtr, uint32_t KpIdLen, uint64_t KpVersion, uint32_t /* Out */ KpHandlePtr) { - + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiKpIdLen = KpIdLen; @@ -232,10 +240,10 @@ Expect KeypairFromId::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -KeypairFromPkAndSk::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t PkHandle, int32_t SkHandle, - uint32_t /* Out */ KpHandlePtr) { +Expect KeypairFromPkAndSk::body(const Runtime::CallingFrame &Frame, + int32_t PkHandle, int32_t SkHandle, + uint32_t /* Out */ KpHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); @@ -250,9 +258,10 @@ KeypairFromPkAndSk::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect KeypairExport::body(Runtime::Instance::MemoryInstance *MemInst, +Expect KeypairExport::body(const Runtime::CallingFrame &Frame, int32_t KpHandle, uint32_t KpEncoding, uint32_t /* Out */ ArrayOutputHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); __wasi_keypair_encoding_e_t WasiKpEncoding; @@ -276,9 +285,10 @@ Expect KeypairExport::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -KeypairPublickey::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t KpHandle, uint32_t /* Out */ PkHandlePtr) { +Expect KeypairPublickey::body(const Runtime::CallingFrame &Frame, + int32_t KpHandle, + uint32_t /* Out */ PkHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const PkHandle = MemInst->getPointer<__wasi_keypair_t *>(PkHandlePtr); @@ -293,9 +303,10 @@ KeypairPublickey::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -KeypairSecretkey::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t KpHandle, uint32_t /* Out */ SkHandlePtr) { +Expect KeypairSecretkey::body(const Runtime::CallingFrame &Frame, + int32_t KpHandle, + uint32_t /* Out */ SkHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const SkHandle = MemInst->getPointer<__wasi_keypair_t *>(SkHandlePtr); @@ -310,7 +321,7 @@ KeypairSecretkey::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect KeypairClose::body(Runtime::Instance::MemoryInstance *, +Expect KeypairClose::body(const Runtime::CallingFrame &, int32_t KpHandle) { if (auto Res = Ctx.keypairClose(KpHandle); unlikely(!Res)) { return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; @@ -319,11 +330,12 @@ Expect KeypairClose::body(Runtime::Instance::MemoryInstance *, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -PublickeyImport::body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, - uint32_t EncodedPtr, uint32_t EncodedLen, - uint32_t Encoding, uint32_t /* Out */ PkHandlePtr) { +Expect PublickeyImport::body(const Runtime::CallingFrame &Frame, + uint32_t AlgType, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ PkHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; @@ -368,9 +380,10 @@ PublickeyImport::body(Runtime::Instance::MemoryInstance *MemInst, } Expect -PublickeyExport::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t PkHandle, uint32_t PkEncoding, +PublickeyExport::body(const Runtime::CallingFrame &Frame, int32_t PkHandle, + uint32_t PkEncoding, uint32_t /* Out */ ArrayOutputHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); __wasi_publickey_encoding_e_t WasiPkEncoding; @@ -394,7 +407,7 @@ PublickeyExport::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect PublickeyVerify::body(Runtime::Instance::MemoryInstance *, +Expect PublickeyVerify::body(const Runtime::CallingFrame &, int32_t PkHandle) { if (auto Res = Ctx.publickeyVerify(PkHandle); unlikely(!Res)) { return Res.error(); @@ -404,8 +417,9 @@ Expect PublickeyVerify::body(Runtime::Instance::MemoryInstance *, } Expect -PublickeyFromSecretkey::body(Runtime::Instance::MemoryInstance *MemInst, +PublickeyFromSecretkey::body(const Runtime::CallingFrame &Frame, int32_t SkHandle, uint32_t /* Out */ PkHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const PkHandle = MemInst->getPointer<__wasi_publickey_t *>(PkHandlePtr); @@ -420,7 +434,7 @@ PublickeyFromSecretkey::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect PublickeyClose::body(Runtime::Instance::MemoryInstance *, +Expect PublickeyClose::body(const Runtime::CallingFrame &, int32_t PkHandle) { if (auto Res = Ctx.publickeyClose(PkHandle); unlikely(!Res)) { return Res.error(); @@ -429,11 +443,12 @@ Expect PublickeyClose::body(Runtime::Instance::MemoryInstance *, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -SecretkeyImport::body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, - uint32_t EncodedPtr, uint32_t EncodedLen, - uint32_t Encoding, uint32_t /* Out */ SkHandlePtr) { +Expect SecretkeyImport::body(const Runtime::CallingFrame &Frame, + uint32_t AlgType, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ SkHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; @@ -476,9 +491,10 @@ SecretkeyImport::body(Runtime::Instance::MemoryInstance *MemInst, } Expect -SecretkeyExport::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t SkHandle, uint32_t SkEncoding, +SecretkeyExport::body(const Runtime::CallingFrame &Frame, int32_t SkHandle, + uint32_t SkEncoding, uint32_t /* Out */ ArrayOutputHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); __wasi_secretkey_encoding_e_t WasiSkEncoding; @@ -502,7 +518,7 @@ SecretkeyExport::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect SecretkeyClose::body(Runtime::Instance::MemoryInstance *, +Expect SecretkeyClose::body(const Runtime::CallingFrame &, int32_t Sk) { if (auto Res = Ctx.secretkeyClose(Sk); unlikely(!Res)) { return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; diff --git a/plugins/wasi_crypto/asymmetric_common/func.h b/plugins/wasi_crypto/asymmetric_common/func.h index 2a5be27d..1157751b 100644 --- a/plugins/wasi_crypto/asymmetric_common/func.h +++ b/plugins/wasi_crypto/asymmetric_common/func.h @@ -25,8 +25,8 @@ namespace AsymmetricCommon { class KeypairGenerate : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgType, + uint32_t AlgPtr, uint32_t AlgLen, uint32_t OptOptionsHandlePtr, uint32_t /* Out */ KpHandlePtr); }; @@ -34,16 +34,16 @@ class KeypairGenerate : public HostFunction { class KeypairImport : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, - uint32_t EncodedPtr, uint32_t EncodedLen, - uint32_t Encoding, uint32_t /* Out */ KpHandlePtr); + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgType, + uint32_t AlgPtr, uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ KpHandlePtr); }; class KeypairGenerateManaged : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, uint32_t OptOptionsHandlePtr, uint32_t KpHandlePtr); @@ -52,7 +52,7 @@ class KeypairGenerateManaged : public HostFunction { class KeypairStoreManaged : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, int32_t KpHandle, uint32_t KpIdPtr, uint32_t KpIdMaxLen); }; @@ -60,7 +60,7 @@ class KeypairStoreManaged : public HostFunction { class KeypairReplaceManaged : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, int32_t OldKpHandle, int32_t NewKpHandle, uint32_t /* Out */ KpVersionPtr); }; @@ -68,8 +68,8 @@ class KeypairReplaceManaged : public HostFunction { class KeypairId : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t KpHandle, uint32_t KpIdPtr, uint32_t KpIdMaxLen, + Expect body(const Runtime::CallingFrame &Frame, int32_t KpHandle, + uint32_t KpIdPtr, uint32_t KpIdMaxLen, uint32_t /* Out */ SizePtr, uint32_t /* Out */ KpVersionPtr); }; @@ -77,7 +77,7 @@ class KeypairId : public HostFunction { class KeypairFromId : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, uint32_t KpIdPtr, uint32_t KpIdLen, uint64_t KpVersion, uint32_t /* Out */ KpHandlePtr); @@ -86,100 +86,95 @@ class KeypairFromId : public HostFunction { class KeypairFromPkAndSk : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t PkHandle, int32_t SkHandle, - uint32_t /* Out */ KpHandlePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t PkHandle, + int32_t SkHandle, uint32_t /* Out */ KpHandlePtr); }; class KeypairExport : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t KpHandle, uint32_t KpEncoding, + Expect body(const Runtime::CallingFrame &Frame, int32_t KpHandle, + uint32_t KpEncoding, uint32_t /* Out */ ArrayOutputHandlePtr); }; class KeypairPublickey : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t KpHandle, uint32_t /* Out */ PkHandlePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t KpHandle, + uint32_t /* Out */ PkHandlePtr); }; class KeypairSecretkey : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t KpHandle, uint32_t /* Out */ SkHandlePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t KpHandle, + uint32_t /* Out */ SkHandlePtr); }; class KeypairClose : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t KpHandle); + Expect body(const Runtime::CallingFrame &Frame, int32_t KpHandle); }; class PublickeyImport : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, - uint32_t EncodedPtr, uint32_t EncodedLen, - uint32_t Encoding, uint32_t /* Out */ PkHandlePtr); + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgType, + uint32_t AlgPtr, uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ PkHandlePtr); }; class PublickeyExport : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t PkHandle, uint32_t PkEncoding, + Expect body(const Runtime::CallingFrame &Frame, int32_t PkHandle, + uint32_t PkEncoding, uint32_t /* Out */ ArrayOutputHandlePtr); }; class PublickeyVerify : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t PkHandle); + Expect body(const Runtime::CallingFrame &Frame, int32_t PkHandle); }; class PublickeyFromSecretkey : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t SkHandle, uint32_t /* Out */ PkHandlePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t SkHandle, + uint32_t /* Out */ PkHandlePtr); }; class PublickeyClose : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t PkHandle); + Expect body(const Runtime::CallingFrame &Frame, int32_t PkHandle); }; class SecretkeyImport : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t AlgType, uint32_t AlgPtr, uint32_t AlgLen, - uint32_t EncodedPtr, uint32_t EncodedLen, - uint32_t Encoding, uint32_t /* Out */ SkHandlePtr); + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgType, + uint32_t AlgPtr, uint32_t AlgLen, uint32_t EncodedPtr, + uint32_t EncodedLen, uint32_t Encoding, + uint32_t /* Out */ SkHandlePtr); }; class SecretkeyExport : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t SkHandle, uint32_t SkEncoding, + Expect body(const Runtime::CallingFrame &Frame, int32_t SkHandle, + uint32_t SkEncoding, uint32_t /* Out */ ArrayOutputHandlePtr); }; class SecretkeyClose : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t SkHandle); + Expect body(const Runtime::CallingFrame &Frame, int32_t SkHandle); }; } // namespace AsymmetricCommon diff --git a/plugins/wasi_crypto/common/func.cpp b/plugins/wasi_crypto/common/func.cpp index 1da1a208..73676491 100644 --- a/plugins/wasi_crypto/common/func.cpp +++ b/plugins/wasi_crypto/common/func.cpp @@ -8,10 +8,11 @@ namespace Host { namespace WasiCrypto { namespace Common { -Expect -ArrayOutputLen::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t ArrayOutputHandle, uint32_t /* Out */ SizePtr) { +Expect ArrayOutputLen::body(const Runtime::CallingFrame &Frame, + int32_t ArrayOutputHandle, + uint32_t /* Out */ SizePtr) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); @@ -27,11 +28,12 @@ ArrayOutputLen::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -ArrayOutputPull::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t ArrayOutputHandle, uint32_t BufPtr, - uint32_t BufLen, uint32_t /* Out */ SizePtr) { +Expect ArrayOutputPull::body(const Runtime::CallingFrame &Frame, + int32_t ArrayOutputHandle, + uint32_t BufPtr, uint32_t BufLen, + uint32_t /* Out */ SizePtr) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiBufLen = BufLen; @@ -52,10 +54,11 @@ ArrayOutputPull::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect OptionsOpen::body(Runtime::Instance::MemoryInstance *MemInst, +Expect OptionsOpen::body(const Runtime::CallingFrame &Frame, uint32_t AlgType, uint32_t /* Out */ OptionsHandlePtr) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); __wasi_algorithm_type_e_t WasiAlgType; @@ -78,9 +81,8 @@ Expect OptionsOpen::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect OptionsClose::body(Runtime::Instance::MemoryInstance *, +Expect OptionsClose::body(const Runtime::CallingFrame &, int32_t OptionsHandle) { - if (auto Res = Ctx.optionsClose(OptionsHandle); unlikely(!Res)) { return Res.error(); } @@ -88,11 +90,12 @@ Expect OptionsClose::body(Runtime::Instance::MemoryInstance *, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect OptionsSet::body(Runtime::Instance::MemoryInstance *MemInst, +Expect OptionsSet::body(const Runtime::CallingFrame &Frame, int32_t OptionsHandle, uint32_t NamePtr, uint32_t NameLen, uint32_t ValuePtr, uint32_t ValueLen) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiNameLen = NameLen; @@ -113,10 +116,11 @@ Expect OptionsSet::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect OptionsSetU64::body(Runtime::Instance::MemoryInstance *MemInst, +Expect OptionsSetU64::body(const Runtime::CallingFrame &Frame, int32_t OptionsHandle, uint32_t NamePtr, uint32_t NameLen, uint64_t Value) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiNameLen = NameLen; @@ -131,10 +135,12 @@ Expect OptionsSetU64::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect OptionsSetGuestBuffer::body( - Runtime::Instance::MemoryInstance *MemInst, int32_t OptionsHandle, - uint32_t NamePtr, uint32_t NameLen, uint32_t BufPtr, uint32_t BufLen) { +Expect OptionsSetGuestBuffer::body(const Runtime::CallingFrame &Frame, + int32_t OptionsHandle, + uint32_t NamePtr, uint32_t NameLen, + uint32_t BufPtr, uint32_t BufLen) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiNameLen = NameLen; @@ -155,10 +161,11 @@ Expect OptionsSetGuestBuffer::body( } Expect -SecretsManagerOpen::body(Runtime::Instance::MemoryInstance *MemInst, +SecretsManagerOpen::body(const Runtime::CallingFrame &Frame, uint32_t OptOptionsHandlePtr, uint32_t /* Out */ SecretsManagerHandlePtr) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const OptOptionsHandle = @@ -178,7 +185,7 @@ SecretsManagerOpen::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect SecretsManagerClose::body(Runtime::Instance::MemoryInstance *, +Expect SecretsManagerClose::body(const Runtime::CallingFrame &, int32_t SecretsManagerHandle) { if (auto Res = Ctx.secretsManagerClose(SecretsManagerHandle); unlikely(!Res)) { @@ -189,10 +196,11 @@ Expect SecretsManagerClose::body(Runtime::Instance::MemoryInstance *, } Expect -SecretsManagerInvalidate::body(Runtime::Instance::MemoryInstance *MemInst, +SecretsManagerInvalidate::body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, uint32_t KeyIdPtr, uint32_t KeyIdLen, uint64_t Version) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiKeyIdLen = KeyIdLen; diff --git a/plugins/wasi_crypto/common/func.h b/plugins/wasi_crypto/common/func.h index c08bbede..d29a9f3e 100644 --- a/plugins/wasi_crypto/common/func.h +++ b/plugins/wasi_crypto/common/func.h @@ -24,14 +24,14 @@ namespace Common { class ArrayOutputLen : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t ArrayOutputHandle, uint32_t /* Out */ SizePtr); }; class ArrayOutputPull : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t ArrayOutputHandle, uint32_t BufPtr, uint32_t BufLen, uint32_t /* Out */ SizePtr); }; @@ -39,21 +39,21 @@ class ArrayOutputPull : public HostFunction { class OptionsOpen : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t AlgType, uint32_t /* Out */ OptionsHandlePtr); + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgType, + uint32_t /* Out */ OptionsHandlePtr); }; class OptionsClose : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t OptionsHandle); }; class OptionsSet : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t OptionsHandle, uint32_t NamePtr, uint32_t NameLen, uint32_t ValuePtr, uint32_t ValueLen); }; @@ -61,7 +61,7 @@ class OptionsSet : public HostFunction { class OptionsSetU64 : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t OptionsHandle, uint32_t NamePtr, uint32_t NameLen, uint64_t Value); }; @@ -69,7 +69,7 @@ class OptionsSetU64 : public HostFunction { class OptionsSetGuestBuffer : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t OptionsHandle, uint32_t NamePtr, uint32_t NameLen, uint32_t BufPtr, uint32_t BufLen); }; @@ -77,7 +77,7 @@ class OptionsSetGuestBuffer : public HostFunction { class SecretsManagerOpen : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, uint32_t OptOptionsHandlePtr, uint32_t /* Out */ SecretsManagerHandlePtr); }; @@ -85,14 +85,14 @@ class SecretsManagerOpen : public HostFunction { class SecretsManagerClose : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle); }; class SecretsManagerInvalidate : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, uint32_t KeyIdPtr, uint32_t KeyIdLen, uint64_t Version); }; diff --git a/plugins/wasi_crypto/kx/func.cpp b/plugins/wasi_crypto/kx/func.cpp index 471d3a97..d0533546 100644 --- a/plugins/wasi_crypto/kx/func.cpp +++ b/plugins/wasi_crypto/kx/func.cpp @@ -8,9 +8,10 @@ namespace Host { namespace WasiCrypto { namespace Kx { -Expect Dh::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t PkHandle, int32_t SkHandle, +Expect Dh::body(const Runtime::CallingFrame &Frame, int32_t PkHandle, + int32_t SkHandle, uint32_t /* Out */ SharedSecretPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const SharedSecret = @@ -26,10 +27,11 @@ Expect Dh::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect Encapsulate::body(Runtime::Instance::MemoryInstance *MemInst, +Expect Encapsulate::body(const Runtime::CallingFrame &Frame, int32_t PkHandle, uint32_t /* Out */ SecretPtr, uint32_t /* Out */ EncapsulatedSecretPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const Secret = MemInst->getPointer<__wasi_array_output_t *>(SecretPtr); @@ -48,11 +50,12 @@ Expect Encapsulate::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect Decapsulate::body(Runtime::Instance::MemoryInstance *MemInst, +Expect Decapsulate::body(const Runtime::CallingFrame &Frame, int32_t SkHandle, uint32_t EncapsulatedSecretPtr, uint32_t EncapsulatedSecretLen, uint32_t /* Out */ SecretPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiEncapsulatedSecretLen = EncapsulatedSecretLen; diff --git a/plugins/wasi_crypto/kx/func.h b/plugins/wasi_crypto/kx/func.h index b2d4b945..c0f83789 100644 --- a/plugins/wasi_crypto/kx/func.h +++ b/plugins/wasi_crypto/kx/func.h @@ -24,24 +24,24 @@ namespace Kx { class Dh : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t PkHandle, int32_t SkHandle, + Expect body(const Runtime::CallingFrame &Frame, int32_t PkHandle, + int32_t SkHandle, uint32_t /* Out */ ArrayOutputHandlePtr); }; class Encapsulate : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t PkHandle, uint32_t /* Out */ SecretPtr, + Expect body(const Runtime::CallingFrame &Frame, int32_t PkHandle, + uint32_t /* Out */ SecretPtr, uint32_t /* Out */ EncapsulatedSecretPtr); }; class Decapsulate : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t SkHandle, uint32_t EncapsulatedSecretPtr, + Expect body(const Runtime::CallingFrame &Frame, int32_t SkHandle, + uint32_t EncapsulatedSecretPtr, uint32_t EncapsulatedSecretLen, uint32_t /* Out */ SecretPtr); }; diff --git a/plugins/wasi_crypto/signatures/func.cpp b/plugins/wasi_crypto/signatures/func.cpp index c78df90e..f5e3069a 100644 --- a/plugins/wasi_crypto/signatures/func.cpp +++ b/plugins/wasi_crypto/signatures/func.cpp @@ -8,9 +8,10 @@ namespace Host { namespace WasiCrypto { namespace Signatures { -Expect Export::body(Runtime::Instance::MemoryInstance *MemInst, +Expect Export::body(const Runtime::CallingFrame &Frame, int32_t SigHandle, uint32_t Encoding, uint32_t /* Out */ ArrayOutputHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); __wasi_signature_encoding_e_t WasiEncoding; @@ -34,11 +35,12 @@ Expect Export::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect Import::body(Runtime::Instance::MemoryInstance *MemInst, +Expect Import::body(const Runtime::CallingFrame &Frame, uint32_t AlgPtr, uint32_t AlgLen, uint32_t EncodedPtr, uint32_t EncodedLen, uint32_t Encoding, uint32_t /* Out */ SigHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; @@ -80,9 +82,10 @@ Expect Import::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect StateOpen::body(Runtime::Instance::MemoryInstance *MemInst, +Expect StateOpen::body(const Runtime::CallingFrame &Frame, int32_t KpHandle, uint32_t /* Out */ SigStatePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const SigState = @@ -98,9 +101,10 @@ Expect StateOpen::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect StateUpdate::body(Runtime::Instance::MemoryInstance *MemInst, +Expect StateUpdate::body(const Runtime::CallingFrame &Frame, int32_t SigStateHandle, uint32_t InputPtr, uint32_t InputSize) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiInputSize = InputSize; @@ -117,9 +121,10 @@ Expect StateUpdate::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect StateSign::body(Runtime::Instance::MemoryInstance *MemInst, +Expect StateSign::body(const Runtime::CallingFrame &Frame, int32_t SigStateHandle, uint32_t /* Out */ ArrayOutputHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const ArrayOutputHandle = @@ -135,8 +140,9 @@ Expect StateSign::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect StateClose::body(Runtime::Instance::MemoryInstance *MemInst, +Expect StateClose::body(const Runtime::CallingFrame &Frame, int32_t SigStateHandle) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); if (auto Res = Ctx.signatureStateClose(SigStateHandle); unlikely(!Res)) { @@ -147,9 +153,10 @@ Expect StateClose::body(Runtime::Instance::MemoryInstance *MemInst, } Expect -VerificationStateOpen::body(Runtime::Instance::MemoryInstance *MemInst, +VerificationStateOpen::body(const Runtime::CallingFrame &Frame, int32_t SigPkHandle, uint32_t /* Out */ VerificationStateHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const VerificationStateHandle = @@ -168,9 +175,10 @@ VerificationStateOpen::body(Runtime::Instance::MemoryInstance *MemInst, } Expect -VerificationStateUpdate::body(Runtime::Instance::MemoryInstance *MemInst, +VerificationStateUpdate::body(const Runtime::CallingFrame &Frame, int32_t SigStateHandle, uint32_t InputPtr, uint32_t InputSize) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiInputSize = InputSize; @@ -186,10 +194,9 @@ VerificationStateUpdate::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -VerificationStateVerify::body(Runtime::Instance::MemoryInstance *, - int32_t VerificationStateHandle, - int32_t SigHandle) { +Expect VerificationStateVerify::body(const Runtime::CallingFrame &, + int32_t VerificationStateHandle, + int32_t SigHandle) { if (auto Res = Ctx.signatureVerificationStateVerify(VerificationStateHandle, SigHandle); unlikely(!Res)) { @@ -199,9 +206,8 @@ VerificationStateVerify::body(Runtime::Instance::MemoryInstance *, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -VerificationStateClose::body(Runtime::Instance::MemoryInstance *, - int32_t VerificationStateHandle) { +Expect VerificationStateClose::body(const Runtime::CallingFrame &, + int32_t VerificationStateHandle) { if (auto Res = Ctx.signatureVerificationStateClose(VerificationStateHandle); unlikely(!Res)) { return Res.error(); @@ -210,8 +216,7 @@ VerificationStateClose::body(Runtime::Instance::MemoryInstance *, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect Close::body(Runtime::Instance::MemoryInstance *, - int32_t SigHandle) { +Expect Close::body(const Runtime::CallingFrame &, int32_t SigHandle) { if (auto Res = Ctx.signatureClose(SigHandle); unlikely(!Res)) { return Res.error(); } diff --git a/plugins/wasi_crypto/signatures/func.h b/plugins/wasi_crypto/signatures/func.h index 6c0599f1..56683080 100644 --- a/plugins/wasi_crypto/signatures/func.h +++ b/plugins/wasi_crypto/signatures/func.h @@ -26,16 +26,16 @@ namespace Signatures { class Export : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t SigHandle, uint32_t Encoding, + Expect body(const Runtime::CallingFrame &Frame, int32_t SigHandle, + uint32_t Encoding, uint32_t /* Out */ ArrayOutputHandlePtr); }; class Import : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t AlgPtr, uint32_t AlgLen, uint32_t EncodedPtr, + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t EncodedPtr, uint32_t EncodedLen, uint32_t Encoding, uint32_t /* Out */ SigHandlePtr); }; @@ -43,14 +43,14 @@ class Import : public HostFunction { class StateOpen : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t KpHandle, uint32_t /* Out */ SigStatePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t KpHandle, + uint32_t /* Out */ SigStatePtr); }; class StateUpdate : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SigStateHandle, uint32_t InputPtr, uint32_t InputSize); }; @@ -58,7 +58,7 @@ class StateUpdate : public HostFunction { class StateSign : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SigStateHandle, uint32_t /* Out */ ArrayOutputHandlePtr); }; @@ -66,22 +66,21 @@ class StateSign : public HostFunction { class StateClose : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SigStateHandle); }; class VerificationStateOpen : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t SigPkHandle, + Expect body(const Runtime::CallingFrame &Frame, int32_t SigPkHandle, uint32_t /* Out */ VerificationStateHandlePtr); }; class VerificationStateUpdate : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SigStateHandle, uint32_t InputPtr, uint32_t InputSize); }; @@ -89,22 +88,21 @@ class VerificationStateUpdate : public HostFunction { class VerificationStateVerify : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t VerificationStateHandle, int32_t SigHandle); }; class VerificationStateClose : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t VerificationStateHandle); }; class Close : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t SigHandle); + Expect body(const Runtime::CallingFrame &Frame, int32_t SigHandle); }; } // namespace Signatures diff --git a/plugins/wasi_crypto/symmetric/func.cpp b/plugins/wasi_crypto/symmetric/func.cpp index 2d4864e2..315b7c4a 100644 --- a/plugins/wasi_crypto/symmetric/func.cpp +++ b/plugins/wasi_crypto/symmetric/func.cpp @@ -8,10 +8,11 @@ namespace Host { namespace WasiCrypto { namespace Symmetric { -Expect KeyGenerate::body(Runtime::Instance::MemoryInstance *MemInst, +Expect KeyGenerate::body(const Runtime::CallingFrame &Frame, uint32_t AlgPtr, uint32_t AlgLen, uint32_t OptOptionsPtr, uint32_t /* Out */ KeyHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; @@ -42,10 +43,11 @@ Expect KeyGenerate::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect KeyImport::body(Runtime::Instance::MemoryInstance *MemInst, +Expect KeyImport::body(const Runtime::CallingFrame &Frame, uint32_t AlgPtr, uint32_t AlgLen, uint32_t RawPtr, uint32_t RawLen, uint32_t /* Out */ KeyPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; @@ -76,9 +78,10 @@ Expect KeyImport::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect KeyExport::body(Runtime::Instance::MemoryInstance *MemInst, +Expect KeyExport::body(const Runtime::CallingFrame &Frame, int32_t KeyHandle, uint32_t /* Out */ ArrayOutputHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const ArrayOutputHandle = @@ -94,7 +97,7 @@ Expect KeyExport::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect KeyClose::body(Runtime::Instance::MemoryInstance *, +Expect KeyClose::body(const Runtime::CallingFrame &, int32_t KeyHandle) { if (auto Res = Ctx.symmetricKeyClose(KeyHandle); unlikely(!Res)) { return Res.error(); @@ -103,11 +106,12 @@ Expect KeyClose::body(Runtime::Instance::MemoryInstance *, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -KeyGenerateManaged::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t SecretsManagerHandle, uint32_t AlgPtr, - uint32_t AlgLen, uint32_t OptOptionsPtr, - uint32_t /* Out */ KeyHandlePtr) { +Expect KeyGenerateManaged::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, + uint32_t AlgPtr, uint32_t AlgLen, + uint32_t OptOptionsPtr, + uint32_t /* Out */ KeyHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; @@ -139,10 +143,11 @@ KeyGenerateManaged::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -KeyStoreManaged::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t SecretsManagerHandle, int32_t KeyHandle, - uint32_t KeyIdPtr, uint32_t KeyIdMaxLen) { +Expect KeyStoreManaged::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, + int32_t KeyHandle, uint32_t KeyIdPtr, + uint32_t KeyIdMaxLen) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiKeyIdMaxLen = KeyIdMaxLen; @@ -158,11 +163,12 @@ KeyStoreManaged::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -KeyReplaceManaged::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t SecretsManagerHandle, int32_t OldKeyHandle, - int32_t NewKeyHandle, - uint32_t /* Out */ KeyVersionPtr) { +Expect KeyReplaceManaged::body(const Runtime::CallingFrame &Frame, + int32_t SecretsManagerHandle, + int32_t OldKeyHandle, + int32_t NewKeyHandle, + uint32_t /* Out */ KeyVersionPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const KeyVersion = @@ -180,10 +186,11 @@ KeyReplaceManaged::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect KeyId::body(Runtime::Instance::MemoryInstance *MemInst, +Expect KeyId::body(const Runtime::CallingFrame &Frame, int32_t KeyHandle, uint32_t KeyIdPtr, uint32_t KeyIdMaxLen, uint32_t /* Out */ SizePtr, uint32_t /* Out */ KeyVersionPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiKeyIdMaxLen = KeyIdMaxLen; @@ -216,11 +223,12 @@ Expect KeyId::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect KeyFromId::body(Runtime::Instance::MemoryInstance *MemInst, +Expect KeyFromId::body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, uint32_t KeyIdPtr, uint32_t KeyIdLen, uint64_t KeyVersion, uint32_t /* Out */ KeyHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiKeyIdLen = KeyIdLen; @@ -242,11 +250,12 @@ Expect KeyFromId::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect StateOpen::body(Runtime::Instance::MemoryInstance *MemInst, +Expect StateOpen::body(const Runtime::CallingFrame &Frame, uint32_t AlgPtr, uint32_t AlgLen, uint32_t OptKeyHandlePtr, uint32_t OptOptionsPtr, uint32_t /* Out */ StatePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; @@ -282,9 +291,10 @@ Expect StateOpen::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect StateClone::body(Runtime::Instance::MemoryInstance *MemInst, +Expect StateClone::body(const Runtime::CallingFrame &Frame, int32_t StateHandle, uint32_t /* Out */ StatePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const State = MemInst->getPointer<__wasi_symmetric_state_t *>(StatePtr); @@ -301,11 +311,12 @@ Expect StateClone::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -StateOptionsGet::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t NamePtr, uint32_t NameLen, - uint32_t ValuePtr, uint32_t ValueLen, - uint32_t /* Out */ SizePtr) { +Expect StateOptionsGet::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, uint32_t NamePtr, + uint32_t NameLen, uint32_t ValuePtr, + uint32_t ValueLen, + uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiNameLen = NameLen; @@ -331,10 +342,11 @@ StateOptionsGet::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -StateOptionsGetU64::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t NamePtr, - uint32_t NameLen, uint32_t /* Out */ U64Ptr) { +Expect StateOptionsGetU64::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, uint32_t NamePtr, + uint32_t NameLen, + uint32_t /* Out */ U64Ptr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiNameLen = NameLen; @@ -355,7 +367,7 @@ StateOptionsGetU64::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect StateClose::body(Runtime::Instance::MemoryInstance *, +Expect StateClose::body(const Runtime::CallingFrame &, int32_t StateHandle) { if (auto Res = Ctx.symmetricStateClose(StateHandle); unlikely(!Res)) { return Res.error(); @@ -364,9 +376,10 @@ Expect StateClose::body(Runtime::Instance::MemoryInstance *, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect StateAbsorb::body(Runtime::Instance::MemoryInstance *MemInst, +Expect StateAbsorb::body(const Runtime::CallingFrame &Frame, int32_t StateHandle, uint32_t DataPtr, uint32_t DataLen) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiDataLen = DataLen; @@ -381,9 +394,10 @@ Expect StateAbsorb::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect StateSqueeze::body(Runtime::Instance::MemoryInstance *MemInst, +Expect StateSqueeze::body(const Runtime::CallingFrame &Frame, int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiOutLen = OutLen; @@ -398,9 +412,10 @@ Expect StateSqueeze::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -StateSqueezeTag::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t /* Out */ TagHandlePtr) { +Expect StateSqueezeTag::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, + uint32_t /* Out */ TagHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const TagHandle = @@ -416,10 +431,11 @@ StateSqueezeTag::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -StateSqueezeKey::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t AlgPtr, uint32_t AlgLen, - uint32_t /* Out */ KeyHandlePtr) { +Expect StateSqueezeKey::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, uint32_t AlgPtr, + uint32_t AlgLen, + uint32_t /* Out */ KeyHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; @@ -446,9 +462,10 @@ StateSqueezeKey::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -StateMaxTagLen::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t /* Out */ SizePtr) { +Expect StateMaxTagLen::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, + uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); @@ -464,11 +481,12 @@ StateMaxTagLen::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect StateEncrypt::body(Runtime::Instance::MemoryInstance *MemInst, +Expect StateEncrypt::body(const Runtime::CallingFrame &Frame, int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen, uint32_t DataPtr, uint32_t DataLen, uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiOutLen = OutLen; @@ -494,11 +512,12 @@ Expect StateEncrypt::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect -StateEncryptDetached::body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t OutPtr, - uint32_t OutLen, uint32_t DataPtr, uint32_t DataLen, - uint32_t /* Out */ TagHandlePtr) { +Expect StateEncryptDetached::body(const Runtime::CallingFrame &Frame, + int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen, + uint32_t DataPtr, uint32_t DataLen, + uint32_t /* Out */ TagHandlePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiOutLen = OutLen; @@ -524,11 +543,12 @@ StateEncryptDetached::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect StateDecrypt::body(Runtime::Instance::MemoryInstance *MemInst, +Expect StateDecrypt::body(const Runtime::CallingFrame &Frame, int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen, uint32_t DataPtr, uint32_t DataLen, uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiOutLen = OutLen; @@ -557,9 +577,10 @@ Expect StateDecrypt::body(Runtime::Instance::MemoryInstance *MemInst, } Expect StateDecryptDetached::body( - Runtime::Instance::MemoryInstance *MemInst, int32_t StateHandle, - uint32_t OutPtr, uint32_t OutLen, uint32_t DataPtr, uint32_t DataLen, - uint32_t RawTagPtr, uint32_t RawTagLen, uint32_t /* Out */ SizePtr) { + const Runtime::CallingFrame &Frame, int32_t StateHandle, uint32_t OutPtr, + uint32_t OutLen, uint32_t DataPtr, uint32_t DataLen, uint32_t RawTagPtr, + uint32_t RawTagLen, uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiOutLen = OutLen; @@ -590,9 +611,8 @@ Expect StateDecryptDetached::body( return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect StateRatchet::body(Runtime::Instance::MemoryInstance *, +Expect StateRatchet::body(const Runtime::CallingFrame &, int32_t StateHandle) { - if (auto Res = Ctx.symmetricStateRatchet(StateHandle); unlikely(!Res)) { return Res.error(); } @@ -600,8 +620,9 @@ Expect StateRatchet::body(Runtime::Instance::MemoryInstance *, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect TagLen::body(Runtime::Instance::MemoryInstance *MemInst, +Expect TagLen::body(const Runtime::CallingFrame &Frame, int32_t TagHandle, uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); auto *Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); @@ -619,9 +640,10 @@ Expect TagLen::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect TagPull::body(Runtime::Instance::MemoryInstance *MemInst, +Expect TagPull::body(const Runtime::CallingFrame &Frame, int32_t TagHandle, uint32_t BufPtr, uint32_t BufLen, uint32_t /* Out */ SizePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiBufLen = BufLen; @@ -642,9 +664,10 @@ Expect TagPull::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect TagVerify::body(Runtime::Instance::MemoryInstance *MemInst, +Expect TagVerify::body(const Runtime::CallingFrame &Frame, int32_t TagHandle, uint32_t RawTagPtr, uint32_t RawTagLen) { + auto *MemInst = Frame.getMemoryByIndex(0); checkExist(MemInst); const __wasi_size_t WasiRawTagLen = RawTagLen; @@ -659,7 +682,7 @@ Expect TagVerify::body(Runtime::Instance::MemoryInstance *MemInst, return __WASI_CRYPTO_ERRNO_SUCCESS; } -Expect TagClose::body(Runtime::Instance::MemoryInstance *, +Expect TagClose::body(const Runtime::CallingFrame &, int32_t TagHandle) { if (auto Res = Ctx.symmetricTagClose(TagHandle); unlikely(!Res)) { return Res.error(); diff --git a/plugins/wasi_crypto/symmetric/func.h b/plugins/wasi_crypto/symmetric/func.h index 03fa19d3..a122784f 100644 --- a/plugins/wasi_crypto/symmetric/func.h +++ b/plugins/wasi_crypto/symmetric/func.h @@ -25,39 +25,36 @@ namespace Symmetric { class KeyGenerate : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t AlgPtr, uint32_t AlgLen, - uint32_t OptOptionsPtr, + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t OptOptionsPtr, uint32_t /* Out */ KeyHandlePtr); }; class KeyImport : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t AlgPtr, uint32_t AlgLen, uint32_t RawPtr, - uint32_t RawLen, uint32_t /* Out */ KeyHandlePtr); + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t RawPtr, uint32_t RawLen, + uint32_t /* Out */ KeyHandlePtr); }; class KeyExport : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t KeyHandle, + Expect body(const Runtime::CallingFrame &Frame, int32_t KeyHandle, uint32_t /* Out */ ArrayOutputHandlePtr); }; class KeyClose : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t KeyHandle); + Expect body(const Runtime::CallingFrame &Frame, int32_t KeyHandle); }; class KeyGenerateManaged : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, uint32_t AlgPtr, uint32_t AlgLen, uint32_t OptOptionsPtr, uint32_t /* Out */ KeyHandlePtr); @@ -66,7 +63,7 @@ class KeyGenerateManaged : public HostFunction { class KeyStoreManaged : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, int32_t KeyHandle, uint32_t KeyIdPtr, uint32_t KeyIdMaxLen); }; @@ -74,7 +71,7 @@ class KeyStoreManaged : public HostFunction { class KeyReplaceManaged : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, int32_t OldKeyHandle, int32_t NewKeyHandle, uint32_t /* Out */ KeyVersionPtr); }; @@ -82,16 +79,16 @@ class KeyReplaceManaged : public HostFunction { class KeyId : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t KeyHandle, uint32_t KeyIdPtr, - uint32_t KeyIdMaxLen, uint32_t /* Out */ SizePtr, + Expect body(const Runtime::CallingFrame &Frame, int32_t KeyHandle, + uint32_t KeyIdPtr, uint32_t KeyIdMaxLen, + uint32_t /* Out */ SizePtr, uint32_t /* Out */ KeyVersionPtr); }; class KeyFromId : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t SecretsManagerHandle, uint32_t KeyIdPtr, uint32_t KeyIdLen, uint64_t KeyVersion, uint32_t /* Out */ KeyHandlePtr); @@ -100,151 +97,143 @@ class KeyFromId : public HostFunction { class StateOpen : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t AlgPtr, uint32_t AlgLen, - uint32_t OptKeyHandlePtr, uint32_t OptOptionsPtr, - uint32_t /* Out */ StatePtr); + Expect body(const Runtime::CallingFrame &Frame, uint32_t AlgPtr, + uint32_t AlgLen, uint32_t OptKeyHandlePtr, + uint32_t OptOptionsPtr, uint32_t /* Out */ StatePtr); }; class StateClone : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t /* Out */ StatePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t /* Out */ StatePtr); }; class StateOptionsGet : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t NamePtr, uint32_t NameLen, - uint32_t ValuePtr, uint32_t ValueLen, - uint32_t /* Out */ SizePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t NamePtr, uint32_t NameLen, uint32_t ValuePtr, + uint32_t ValueLen, uint32_t /* Out */ SizePtr); }; class StateOptionsGetU64 : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t NamePtr, uint32_t NameLen, + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t NamePtr, uint32_t NameLen, uint32_t /* Out */ U64Ptr); }; class StateClose : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle); }; class StateAbsorb : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t DataPtr, - uint32_t DataLen); + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t DataPtr, uint32_t DataLen); }; class StateSqueeze : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen); + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen); }; class StateSqueezeTag : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t /* Out */ TagHandlePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t /* Out */ TagHandlePtr); }; class StateSqueezeKey : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t AlgPtr, uint32_t AlgLen, + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t AlgPtr, uint32_t AlgLen, uint32_t /* Out */ KeyHandlePtr); }; class StateMaxTagLen : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t /* Out */ SizePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t /* Out */ SizePtr); }; class StateEncrypt : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen, - uint32_t DataPtr, uint32_t DataLen, - uint32_t /* Out */ SizePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen, uint32_t DataPtr, + uint32_t DataLen, uint32_t /* Out */ SizePtr); }; class StateEncryptDetached : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen, - uint32_t DataPtr, uint32_t DataLen, - uint32_t /* Out */ TagHandlePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen, uint32_t DataPtr, + uint32_t DataLen, uint32_t /* Out */ TagHandlePtr); }; class StateDecrypt : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen, - uint32_t DataPtr, uint32_t DataLen, - uint32_t /* Out */ SizePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen, uint32_t DataPtr, + uint32_t DataLen, uint32_t /* Out */ SizePtr); }; class StateDecryptDetached : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t StateHandle, uint32_t OutPtr, uint32_t OutLen, - uint32_t DataPtr, uint32_t DataLen, uint32_t RawTagPtr, + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle, + uint32_t OutPtr, uint32_t OutLen, uint32_t DataPtr, + uint32_t DataLen, uint32_t RawTagPtr, uint32_t RawTagLen, uint32_t /* Out */ SizePtr); }; class StateRatchet : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, + Expect body(const Runtime::CallingFrame &Frame, int32_t StateHandle); }; class TagLen : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t TagHandle, uint32_t /* Out */ SizePtr); + Expect body(const Runtime::CallingFrame &Frame, int32_t TagHandle, + uint32_t /* Out */ SizePtr); }; class TagPull : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t TagHandle, uint32_t BufPtr, uint32_t BufLen, + Expect body(const Runtime::CallingFrame &Frame, int32_t TagHandle, + uint32_t BufPtr, uint32_t BufLen, uint32_t /* Out */ SizePtr); }; class TagVerify : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t TagHandle, uint32_t RawTagPtr, - uint32_t RawTagLen); + Expect body(const Runtime::CallingFrame &Frame, int32_t TagHandle, + uint32_t RawTagPtr, uint32_t RawTagLen); }; class TagClose : public HostFunction { public: using HostFunction::HostFunction; - Expect body(Runtime::Instance::MemoryInstance *MemInst, - int32_t TagHandle); + Expect body(const Runtime::CallingFrame &Frame, int32_t TagHandle); }; } // namespace Symmetric diff --git a/plugins/wasi_crypto/utils/hostfunction.h b/plugins/wasi_crypto/utils/hostfunction.h index 78c1f375..65a701af 100644 --- a/plugins/wasi_crypto/utils/hostfunction.h +++ b/plugins/wasi_crypto/utils/hostfunction.h @@ -19,8 +19,8 @@ #include "symmetric/registed.h" #include "utils/error.h" +#include "runtime/callingframe.h" #include "runtime/hostfunc.h" -#include "runtime/instance/memory.h" namespace WasmEdge { namespace Host { diff --git a/plugins/wasi_nn/wasinnbase.h b/plugins/wasi_nn/wasinnbase.h index b5daaf84..698b6049 100644 --- a/plugins/wasi_nn/wasinnbase.h +++ b/plugins/wasi_nn/wasinnbase.h @@ -4,6 +4,7 @@ #pragma once #include "common/errcode.h" +#include "runtime/callingframe.h" #include "runtime/hostfunc.h" #include "wasinnenv.h" diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 20949124..f5807d16 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -14,13 +14,14 @@ namespace WasmEdge { namespace Host { -Expect WasiNNLoad::body(Runtime::Instance::MemoryInstance *MemInst, +Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr [[maybe_unused]], uint32_t BuilderLen [[maybe_unused]], uint32_t Encoding, uint32_t Target [[maybe_unused]], uint32_t GraphIdPtr [[maybe_unused]]) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } @@ -241,10 +242,10 @@ Expect WasiNNLoad::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::InvalidArgument); } -Expect -WasiNNInitExecCtx::body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t GraphId, - uint32_t ContextPtr [[maybe_unused]]) { +Expect WasiNNInitExecCtx::body(const Runtime::CallingFrame &Frame, + uint32_t GraphId, + uint32_t ContextPtr [[maybe_unused]]) { + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } @@ -294,10 +295,11 @@ WasiNNInitExecCtx::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::InvalidArgument); } -Expect -WasiNNSetInput::body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t Context, uint32_t Index [[maybe_unused]], - uint32_t TensorPtr [[maybe_unused]]) { +Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, + uint32_t Context, + uint32_t Index [[maybe_unused]], + uint32_t TensorPtr [[maybe_unused]]) { + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } @@ -467,11 +469,12 @@ WasiNNSetInput::body(Runtime::Instance::MemoryInstance *MemInst, } Expect -WasiNNGetOuput::body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t Context, uint32_t Index [[maybe_unused]], +WasiNNGetOuput::body(const Runtime::CallingFrame &Frame, uint32_t Context, + uint32_t Index [[maybe_unused]], uint32_t OutBufferPtr [[maybe_unused]], uint32_t OutBufferMaxSize [[maybe_unused]], uint32_t BytesWrittenPtr [[maybe_unused]]) { + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } @@ -563,8 +566,9 @@ WasiNNGetOuput::body(Runtime::Instance::MemoryInstance *MemInst, return static_cast(WASINN::ErrNo::InvalidArgument); } -Expect WasiNNCompute::body(Runtime::Instance::MemoryInstance *MemInst, +Expect WasiNNCompute::body(const Runtime::CallingFrame &Frame, uint32_t Context) { + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h index 49925acf..d8d9ffd1 100644 --- a/plugins/wasi_nn/wasinnfunc.h +++ b/plugins/wasi_nn/wasinnfunc.h @@ -14,30 +14,29 @@ namespace Host { class WasiNNLoad : public WasiNN { public: WasiNNLoad(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *, - uint32_t BuilderPtr, uint32_t BuilderLen, - uint32_t Encoding, uint32_t Target, + Expect body(const Runtime::CallingFrame &, uint32_t BuilderPtr, + uint32_t BuilderLen, uint32_t Encoding, uint32_t Target, uint32_t GraphIdPtr); }; class WasiNNInitExecCtx : public WasiNN { public: WasiNNInitExecCtx(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *, uint32_t GraphId, + Expect body(const Runtime::CallingFrame &, uint32_t GraphId, uint32_t ContextPtr); }; class WasiNNSetInput : public WasiNN { public: WasiNNSetInput(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *, uint32_t Context, + Expect body(const Runtime::CallingFrame &, uint32_t Context, uint32_t Index, uint32_t TensorPtr); }; class WasiNNGetOuput : public WasiNN { public: WasiNNGetOuput(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *, uint32_t Context, + Expect body(const Runtime::CallingFrame &, uint32_t Context, uint32_t Index, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); }; @@ -45,7 +44,7 @@ class WasiNNGetOuput : public WasiNN { class WasiNNCompute : public WasiNN { public: WasiNNCompute(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *, uint32_t Context); + Expect body(const Runtime::CallingFrame &, uint32_t Context); }; } // namespace Host diff --git a/plugins/wasmedge_process/processbase.h b/plugins/wasmedge_process/processbase.h index 7db2c860..ba98ae74 100644 --- a/plugins/wasmedge_process/processbase.h +++ b/plugins/wasmedge_process/processbase.h @@ -5,6 +5,7 @@ #include "common/errcode.h" #include "processenv.h" +#include "runtime/callingframe.h" #include "runtime/hostfunc.h" namespace WasmEdge { diff --git a/plugins/wasmedge_process/processfunc.cpp b/plugins/wasmedge_process/processfunc.cpp index 72b69e00..3852c596 100644 --- a/plugins/wasmedge_process/processfunc.cpp +++ b/plugins/wasmedge_process/processfunc.cpp @@ -21,9 +21,10 @@ namespace WasmEdge { namespace Host { Expect -WasmEdgeProcessSetProgName::body(Runtime::Instance::MemoryInstance *MemInst, +WasmEdgeProcessSetProgName::body(const Runtime::CallingFrame &Frame, uint32_t NamePtr, uint32_t NameLen) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } @@ -33,10 +34,10 @@ WasmEdgeProcessSetProgName::body(Runtime::Instance::MemoryInstance *MemInst, return {}; } -Expect -WasmEdgeProcessAddArg::body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t ArgPtr, uint32_t ArgLen) { +Expect WasmEdgeProcessAddArg::body(const Runtime::CallingFrame &Frame, + uint32_t ArgPtr, uint32_t ArgLen) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } @@ -48,11 +49,13 @@ WasmEdgeProcessAddArg::body(Runtime::Instance::MemoryInstance *MemInst, return {}; } -Expect -WasmEdgeProcessAddEnv::body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t EnvNamePtr, uint32_t EnvNameLen, - uint32_t EnvValPtr, uint32_t EnvValLen) { +Expect WasmEdgeProcessAddEnv::body(const Runtime::CallingFrame &Frame, + uint32_t EnvNamePtr, + uint32_t EnvNameLen, + uint32_t EnvValPtr, + uint32_t EnvValLen) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } @@ -66,10 +69,10 @@ WasmEdgeProcessAddEnv::body(Runtime::Instance::MemoryInstance *MemInst, return {}; } -Expect -WasmEdgeProcessAddStdIn::body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t BufPtr, uint32_t BufLen) { +Expect WasmEdgeProcessAddStdIn::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr, uint32_t BufLen) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } @@ -80,14 +83,13 @@ WasmEdgeProcessAddStdIn::body(Runtime::Instance::MemoryInstance *MemInst, return {}; } -Expect -WasmEdgeProcessSetTimeOut::body(Runtime::Instance::MemoryInstance *, - uint32_t Time) { +Expect WasmEdgeProcessSetTimeOut::body(const Runtime::CallingFrame &, + uint32_t Time) { Env.TimeOut = Time; return {}; } -Expect WasmEdgeProcessRun::body(Runtime::Instance::MemoryInstance *) { +Expect WasmEdgeProcessRun::body(const Runtime::CallingFrame &) { #if WASMEDGE_OS_LINUX || WASMEDGE_OS_MACOS // Clear outputs. Env.StdOut.clear(); @@ -299,19 +301,19 @@ Expect WasmEdgeProcessRun::body(Runtime::Instance::MemoryInstance *) { } Expect -WasmEdgeProcessGetExitCode::body(Runtime::Instance::MemoryInstance *) { +WasmEdgeProcessGetExitCode::body(const Runtime::CallingFrame &) { return Env.ExitCode; } Expect -WasmEdgeProcessGetStdOutLen::body(Runtime::Instance::MemoryInstance *) { +WasmEdgeProcessGetStdOutLen::body(const Runtime::CallingFrame &) { return static_cast(Env.StdOut.size()); } -Expect -WasmEdgeProcessGetStdOut::body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t BufPtr) { +Expect WasmEdgeProcessGetStdOut::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } @@ -322,14 +324,14 @@ WasmEdgeProcessGetStdOut::body(Runtime::Instance::MemoryInstance *MemInst, } Expect -WasmEdgeProcessGetStdErrLen::body(Runtime::Instance::MemoryInstance *) { +WasmEdgeProcessGetStdErrLen::body(const Runtime::CallingFrame &) { return static_cast(Env.StdErr.size()); } -Expect -WasmEdgeProcessGetStdErr::body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t BufPtr) { +Expect WasmEdgeProcessGetStdErr::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr) { // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } diff --git a/plugins/wasmedge_process/processfunc.h b/plugins/wasmedge_process/processfunc.h index aef2e4bc..072d0c54 100644 --- a/plugins/wasmedge_process/processfunc.h +++ b/plugins/wasmedge_process/processfunc.h @@ -15,15 +15,15 @@ class WasmEdgeProcessSetProgName public: WasmEdgeProcessSetProgName(WasmEdgeProcessEnvironment &HostEnv) : WasmEdgeProcess(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t NamePtr, uint32_t NameLen); + Expect body(const Runtime::CallingFrame &Frame, uint32_t NamePtr, + uint32_t NameLen); }; class WasmEdgeProcessAddArg : public WasmEdgeProcess { public: WasmEdgeProcessAddArg(WasmEdgeProcessEnvironment &HostEnv) : WasmEdgeProcess(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *MemInst, uint32_t ArgPtr, + Expect body(const Runtime::CallingFrame &Frame, uint32_t ArgPtr, uint32_t ArgLen); }; @@ -31,9 +31,9 @@ class WasmEdgeProcessAddEnv : public WasmEdgeProcess { public: WasmEdgeProcessAddEnv(WasmEdgeProcessEnvironment &HostEnv) : WasmEdgeProcess(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t EnvNamePtr, uint32_t EnvNameLen, - uint32_t EnvValPtr, uint32_t EnvValLen); + Expect body(const Runtime::CallingFrame &Frame, uint32_t EnvNamePtr, + uint32_t EnvNameLen, uint32_t EnvValPtr, + uint32_t EnvValLen); }; class WasmEdgeProcessAddStdIn @@ -41,7 +41,7 @@ class WasmEdgeProcessAddStdIn public: WasmEdgeProcessAddStdIn(WasmEdgeProcessEnvironment &HostEnv) : WasmEdgeProcess(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *MemInst, uint32_t BufPtr, + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr, uint32_t BufLen); }; @@ -50,14 +50,14 @@ class WasmEdgeProcessSetTimeOut public: WasmEdgeProcessSetTimeOut(WasmEdgeProcessEnvironment &HostEnv) : WasmEdgeProcess(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *MemInst, uint32_t Time); + Expect body(const Runtime::CallingFrame &Frame, uint32_t Time); }; class WasmEdgeProcessRun : public WasmEdgeProcess { public: WasmEdgeProcessRun(WasmEdgeProcessEnvironment &HostEnv) : WasmEdgeProcess(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *MemInst); + Expect body(const Runtime::CallingFrame &Frame); }; class WasmEdgeProcessGetExitCode @@ -65,7 +65,7 @@ class WasmEdgeProcessGetExitCode public: WasmEdgeProcessGetExitCode(WasmEdgeProcessEnvironment &HostEnv) : WasmEdgeProcess(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *MemInst); + Expect body(const Runtime::CallingFrame &Frame); }; class WasmEdgeProcessGetStdOutLen @@ -73,7 +73,7 @@ class WasmEdgeProcessGetStdOutLen public: WasmEdgeProcessGetStdOutLen(WasmEdgeProcessEnvironment &HostEnv) : WasmEdgeProcess(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *MemInst); + Expect body(const Runtime::CallingFrame &Frame); }; class WasmEdgeProcessGetStdOut @@ -81,8 +81,7 @@ class WasmEdgeProcessGetStdOut public: WasmEdgeProcessGetStdOut(WasmEdgeProcessEnvironment &HostEnv) : WasmEdgeProcess(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t BufPtr); + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr); }; class WasmEdgeProcessGetStdErrLen @@ -90,7 +89,7 @@ class WasmEdgeProcessGetStdErrLen public: WasmEdgeProcessGetStdErrLen(WasmEdgeProcessEnvironment &HostEnv) : WasmEdgeProcess(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *MemInst); + Expect body(const Runtime::CallingFrame &Frame); }; class WasmEdgeProcessGetStdErr @@ -98,8 +97,7 @@ class WasmEdgeProcessGetStdErr public: WasmEdgeProcessGetStdErr(WasmEdgeProcessEnvironment &HostEnv) : WasmEdgeProcess(HostEnv) {} - Expect body(Runtime::Instance::MemoryInstance *MemInst, - uint32_t BufPtr); + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr); }; } // namespace Host From 3d901a7bd8e7ffbc6ff013704d67a213e2042d42 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Fri, 19 Aug 2022 04:26:32 +0800 Subject: [PATCH 066/623] [Test] Update tests for the host function API change. Signed-off-by: YiYing He --- test/plugins/httpsreq/httpsreq.cpp | 69 ++-- test/plugins/wasi_crypto/common.cpp | 6 +- test/plugins/wasi_crypto/helper.cpp | 332 +++++++++--------- test/plugins/wasi_crypto/helper.h | 18 +- test/plugins/wasi_nn/wasi_nn.cpp | 66 ++-- .../wasmedge_process/wasmedge_process.cpp | 173 +++++---- 6 files changed, 373 insertions(+), 291 deletions(-) diff --git a/test/plugins/httpsreq/httpsreq.cpp b/test/plugins/httpsreq/httpsreq.cpp index 35fae957..5f02c394 100644 --- a/test/plugins/httpsreq/httpsreq.cpp +++ b/test/plugins/httpsreq/httpsreq.cpp @@ -42,17 +42,23 @@ void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, TEST(wasmedgeHttpsReqTests, SendData) { // Create the httpsreq module instance. - auto *ProcMod = + auto *HttpMod = dynamic_cast(createModule()); - EXPECT_FALSE(ProcMod == nullptr); - - // Create the memory instance. - WasmEdge::Runtime::Instance::MemoryInstance MemInst( - WasmEdge::AST::MemoryType(1)); + EXPECT_FALSE(HttpMod == nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Clear the memory[0, 256]. fillMemContent(MemInst, 0, 256); - // Set the memory[0, 11] as string "echo". + // Set the memory[0, 11] as string "httpbin.org". fillMemContent(MemInst, 0, std::string("httpbin.org")); // Set the memory[30, 116] as string "GET / HTTP/1.1\nHost: // httpbin.org\r\nConnection: Close\r\nReferer: https://httpbin.org/\r\n\r\n". @@ -61,7 +67,7 @@ TEST(wasmedgeHttpsReqTests, SendData) { "Close\r\nReferer: https://httpbin.org/\r\n\r\n")); // Get the function "send_data" - auto *FuncInst = ProcMod->findFuncExports("send_data"); + auto *FuncInst = HttpMod->findFuncExports("send_data"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); auto &HostFuncInst = @@ -69,28 +75,34 @@ TEST(wasmedgeHttpsReqTests, SendData) { // Test: Run function successfully for get requests EXPECT_TRUE(HostFuncInst.run( - &MemInst, + CallFrame, std::initializer_list{ UINT32_C(0), UINT32_C(11), UINT32_C(443), UINT32_C(30), UINT32_C(86)}, {})); - EXPECT_TRUE(ProcMod->getEnv().Host == "httpbin.org"); - EXPECT_TRUE(ProcMod->getEnv().Body == + EXPECT_TRUE(HttpMod->getEnv().Host == "httpbin.org"); + EXPECT_TRUE(HttpMod->getEnv().BodyStr == "GET / HTTP/1.1\nHost: httpbin.org\r\nConnection: " "Close\r\nReferer: https://httpbin.org/\r\n\r\n"); } TEST(wasmedgeHttpsReqTests, GetRcv) { // Create the httpsreq module instance. - auto *ProcMod = + auto *HttpMod = dynamic_cast(createModule()); - EXPECT_FALSE(ProcMod == nullptr); - - // Create the memory instance. - WasmEdge::Runtime::Instance::MemoryInstance MemInst( - WasmEdge::AST::MemoryType(1)); + EXPECT_FALSE(HttpMod == nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); fillMemContent(MemInst, 0, 256); - // Set the memory[0, 11] as string "echo". + // Set the memory[0, 11] as string "httpbin.org". fillMemContent(MemInst, 0, std::string("httpbin.org")); // Set the memory[30, 116] as string "GET / HTTP/1.1\nHost: // httpbin.org\r\nConnection: Close\r\nReferer: https://httpbin.org/\r\n\r\n". @@ -99,21 +111,21 @@ TEST(wasmedgeHttpsReqTests, GetRcv) { "Close\r\nReferer: https://httpbin.org/\r\n\r\n")); // Get the function "send_data" - auto *FuncInst = ProcMod->findFuncExports("send_data"); + auto *FuncInst = HttpMod->findFuncExports("send_data"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); auto &HostFuncSendData = dynamic_cast(FuncInst->getHostFunc()); // Get the function "get_rcv_len" - FuncInst = ProcMod->findFuncExports("https_req_get_rcv_len"); + FuncInst = HttpMod->findFuncExports("https_req_get_rcv_len"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); auto &HostFuncGetRcvLen = dynamic_cast( FuncInst->getHostFunc()); // Get the function "get_rcv" - FuncInst = ProcMod->findFuncExports("https_req_get_rcv"); + FuncInst = HttpMod->findFuncExports("https_req_get_rcv"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); auto &HostFuncGetRcv = @@ -121,26 +133,27 @@ TEST(wasmedgeHttpsReqTests, GetRcv) { // Test: Run function successfully for get requests EXPECT_TRUE(HostFuncSendData.run( - &MemInst, + CallFrame, std::initializer_list{ UINT32_C(0), UINT32_C(11), UINT32_C(443), UINT32_C(30), UINT32_C(86)}, {})); - EXPECT_TRUE(ProcMod->getEnv().Host == "httpbin.org"); - EXPECT_TRUE(ProcMod->getEnv().Body == + EXPECT_TRUE(HttpMod->getEnv().Host == "httpbin.org"); + EXPECT_TRUE(HttpMod->getEnv().BodyStr == "GET / HTTP/1.1\nHost: httpbin.org\r\nConnection: " "Close\r\nReferer: https://httpbin.org/\r\n\r\n"); // Test: Run function successfully for getrcvlen std::array RetVal; - EXPECT_TRUE(HostFuncGetRcvLen.run(nullptr, {}, RetVal)); + EXPECT_TRUE(HostFuncGetRcvLen.run( + WasmEdge::Runtime::CallingFrame(nullptr, nullptr), {}, RetVal)); uint32_t Len = RetVal[0].get(); EXPECT_TRUE(Len > 0U); // Test Run function successfully for getrcv EXPECT_TRUE(HostFuncGetRcv.run( - &MemInst, std::initializer_list{UINT32_C(0)}, {})); - EXPECT_TRUE(std::equal(ProcMod->getEnv().Rcv.begin(), - ProcMod->getEnv().Rcv.end(), + CallFrame, std::initializer_list{UINT32_C(0)}, {})); + EXPECT_TRUE(std::equal(HttpMod->getEnv().Rcv.begin(), + HttpMod->getEnv().Rcv.end(), MemInst.getPointer(0))); } diff --git a/test/plugins/wasi_crypto/common.cpp b/test/plugins/wasi_crypto/common.cpp index 791481c6..11e57d36 100644 --- a/test/plugins/wasi_crypto/common.cpp +++ b/test/plugins/wasi_crypto/common.cpp @@ -56,7 +56,7 @@ TEST_F(WasiCryptoTest, Options) { auto *Func = getHostFunc( WasiCryptoCommonMod, "options_set_guest_buffer"); ASSERT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ SymmetricOptionsHandle, 0, NameSize, 0, NameSize}, Errno)); @@ -84,7 +84,7 @@ TEST_F(WasiCryptoTest, Options) { auto *Func = getHostFunc( WasiCryptoCommonMod, "options_set_guest_buffer"); ASSERT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ SigOptionsHandle, 0, NameSize, 0, NameSize}, Errno)); @@ -112,7 +112,7 @@ TEST_F(WasiCryptoTest, Options) { auto *Func = getHostFunc( WasiCryptoCommonMod, "options_set_guest_buffer"); ASSERT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ KxOptionsHandle, 0, NameSize, 0, NameSize}, Errno)); diff --git a/test/plugins/wasi_crypto/helper.cpp b/test/plugins/wasi_crypto/helper.cpp index 6fe86695..51b84955 100644 --- a/test/plugins/wasi_crypto/helper.cpp +++ b/test/plugins/wasi_crypto/helper.cpp @@ -54,15 +54,16 @@ std::vector operator"" _u8v(const char *Str, std::size_t Len) { } void WasiCryptoTest::writeDummyMemoryContent() { - std::fill_n(MemInst.getPointer(0), 64, UINT8_C(0xa5)); + std::fill_n(MemInst->getPointer(0), 64, UINT8_C(0xa5)); } void WasiCryptoTest::writeString(std::string_view String, uint32_t Ptr) { - std::copy(String.begin(), String.end(), MemInst.getPointer(Ptr)); + std::copy(String.begin(), String.end(), MemInst->getPointer(Ptr)); } void WasiCryptoTest::writeSpan(Span Content, uint32_t Ptr) { - std::copy(Content.begin(), Content.end(), MemInst.getPointer(Ptr)); + std::copy(Content.begin(), Content.end(), + MemInst->getPointer(Ptr)); } void WasiCryptoTest::writeOptKey(std::optional OptKey, uint32_t Ptr) { @@ -73,7 +74,7 @@ void WasiCryptoTest::writeOptKey(std::optional OptKey, uint32_t Ptr) { } else { Key.tag = __WASI_OPT_SYMMETRIC_KEY_U_NONE; } - auto *BeginPlace = MemInst.getPointer<__wasi_opt_symmetric_key_t *>(Ptr); + auto *BeginPlace = MemInst->getPointer<__wasi_opt_symmetric_key_t *>(Ptr); *BeginPlace = Key; } @@ -86,7 +87,7 @@ void WasiCryptoTest::writeOptOptions(std::optional<__wasi_options_t> OptOptions, } else { Options.tag = __WASI_OPT_OPTIONS_U_NONE; } - auto *BeginPlace = MemInst.getPointer<__wasi_opt_options_t *>(Ptr); + auto *BeginPlace = MemInst->getPointer<__wasi_opt_options_t *>(Ptr); *BeginPlace = Options; } @@ -98,12 +99,12 @@ WasiCryptoTest::arrayOutputLen(__wasi_array_output_t ArrayOutputHandle) { "array_output_len"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{ArrayOutputHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_size_t *>(0); + return *MemInst->getPointer<__wasi_size_t *>(0); } WasiCryptoExpect<__wasi_size_t> @@ -116,15 +117,15 @@ WasiCryptoTest::arrayOutputPull(__wasi_array_output_t ArrayOutputHandle, auto *Func = getHostFunc(WasiCryptoCommonMod, "array_output_pull"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ ArrayOutputHandle, 0, BufSize, BufSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - std::copy(MemInst.getPointer(0), - MemInst.getPointer(BufSize), Buf.begin()); - return *MemInst.getPointer<__wasi_size_t *>(BufSize); + std::copy(MemInst->getPointer(0), + MemInst->getPointer(BufSize), Buf.begin()); + return *MemInst->getPointer<__wasi_size_t *>(BufSize); } WasiCryptoExpect<__wasi_options_t> @@ -134,13 +135,13 @@ WasiCryptoTest::optionsOpen(__wasi_algorithm_type_e_t AlgorithmType) { auto *Func = getHostFunc(WasiCryptoCommonMod, "options_open"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ static_cast(AlgorithmType), 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_options_t *>(0); + return *MemInst->getPointer<__wasi_options_t *>(0); } WasiCryptoExpect @@ -151,7 +152,7 @@ WasiCryptoTest::optionsClose(__wasi_options_t OptionsHandle) { getHostFunc(WasiCryptoCommonMod, "options_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{OptionsHandle}, + CallFrame, std::initializer_list{OptionsHandle}, Errno)); ensureOrReturnOnTest(Errno[0].get()); @@ -170,7 +171,7 @@ WasiCryptoTest::optionsSet(__wasi_options_t OptionsHandle, auto *Func = getHostFunc(WasiCryptoCommonMod, "options_set"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ OptionsHandle, 0, NameSize, NameSize, ValueSize}, Errno)); @@ -190,7 +191,7 @@ WasiCryptoTest::optionsSetU64(__wasi_options_t OptionsHandle, auto *Func = getHostFunc(WasiCryptoCommonMod, "options_set_u64"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ OptionsHandle, 0, NameSize, Value}, Errno)); @@ -208,10 +209,10 @@ WasiCryptoExpect<__wasi_secrets_manager_t> WasiCryptoTest::secretsManagerOpen( "secrets_manager_open"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{0, 8}, Errno)); + CallFrame, std::initializer_list{0, 8}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_secrets_manager_t *>(8); + return *MemInst->getPointer<__wasi_secrets_manager_t *>(8); } WasiCryptoExpect WasiCryptoTest::secretsManagerClose( @@ -222,7 +223,7 @@ WasiCryptoExpect WasiCryptoTest::secretsManagerClose( WasiCryptoCommonMod, "secrets_manager_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{SecretsManagerHandle}, Errno)); ensureOrReturnOnTest(Errno[0].get()); @@ -240,7 +241,7 @@ WasiCryptoExpect WasiCryptoTest::secretsManagerInvalidate( auto *Func = getHostFunc( WasiCryptoCommonMod, "secrets_manager_invalidate"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ SecretsManagerHandle, 0, KeyIdSize, Version}, Errno)); @@ -259,13 +260,13 @@ WasiCryptoExpect<__wasi_symmetric_key_t> WasiCryptoTest::symmetricKeyGenerate( auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_key_generate"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ 0, AlgSize, AlgSize, AlgSize + 8}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_symmetric_key_t *>(AlgSize + 8); + return *MemInst->getPointer<__wasi_symmetric_key_t *>(AlgSize + 8); } WasiCryptoExpect<__wasi_symmetric_key_t> @@ -280,13 +281,13 @@ WasiCryptoTest::symmetricKeyImport(std::string_view Alg, auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_key_import"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ 0, AlgSize, AlgSize, RawSize, AlgSize + RawSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_symmetric_key_t *>(AlgSize + RawSize); + return *MemInst->getPointer<__wasi_symmetric_key_t *>(AlgSize + RawSize); } WasiCryptoExpect<__wasi_array_output_t> @@ -297,11 +298,11 @@ WasiCryptoTest::symmetricKeyExport(__wasi_symmetric_key_t KeyHandle) { "symmetric_key_export"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{KeyHandle, 0}, + CallFrame, std::initializer_list{KeyHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_array_output_t *>(0); + return *MemInst->getPointer<__wasi_array_output_t *>(0); } WasiCryptoExpect @@ -311,8 +312,9 @@ WasiCryptoTest::symmetricKeyClose(__wasi_symmetric_key_t KeyHandle) { auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_key_close"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{KeyHandle}, Errno)); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{KeyHandle}, + Errno)); ensureOrReturnOnTest(Errno[0].get()); return {}; @@ -331,13 +333,13 @@ WasiCryptoTest::symmetricKeyGenerateManaged( WasiCryptoSymmMod, "symmetric_key_generate_managed"); EXPECT_NE(Func, nullptr); EXPECT_TRUE( - Func->run(&MemInst, + Func->run(CallFrame, std::initializer_list{ SecretsManagerHandle, 0, AlgSize, AlgSize, AlgSize + 8}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_keypair_t *>(AlgSize + 8); + return *MemInst->getPointer<__wasi_keypair_t *>(AlgSize + 8); } WasiCryptoExpect WasiCryptoTest::symmetricKeyStoreManaged( @@ -350,14 +352,14 @@ WasiCryptoExpect WasiCryptoTest::symmetricKeyStoreManaged( auto *Func = getHostFunc( WasiCryptoSymmMod, "symmetric_key_store_managed"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ SecretsManagerHandle, KeyHandle, 0, KpIdSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - std::copy(MemInst.getPointer(0), - MemInst.getPointer(KpIdSize), KeyId.begin()); + std::copy(MemInst->getPointer(0), + MemInst->getPointer(KpIdSize), KeyId.begin()); return {}; } @@ -371,13 +373,13 @@ WasiCryptoExpect<__wasi_version_t> WasiCryptoTest::symmetricKeyReplaceManaged( WasiCryptoSymmMod, "symmetric_key_replace_managed"); EXPECT_NE(Func, nullptr); EXPECT_TRUE( - Func->run(&MemInst, + Func->run(CallFrame, std::initializer_list{ SecretsManagerHandle, OldKeyHandle, NewKeyHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_version_t *>(0); + return *MemInst->getPointer<__wasi_version_t *>(0); } WasiCryptoExpect> @@ -390,18 +392,18 @@ WasiCryptoTest::symmetricKeyId(__wasi_symmetric_key_t KeyHandle, auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_key_id"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ KeyHandle, 0, KeyIdSize, KeyIdSize, KeyIdSize + 1}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - std::copy(MemInst.getPointer(0), - MemInst.getPointer(KeyIdSize), KeyId.begin()); + std::copy(MemInst->getPointer(0), + MemInst->getPointer(KeyIdSize), KeyId.begin()); return std::make_tuple( - *MemInst.getPointer(KeyIdSize), - *MemInst.getPointer<__wasi_version_t *>(KeyIdSize + 1)); + *MemInst->getPointer(KeyIdSize), + *MemInst->getPointer<__wasi_version_t *>(KeyIdSize + 1)); } WasiCryptoExpect<__wasi_symmetric_key_t> WasiCryptoTest::symmetricKeyFromId( @@ -415,13 +417,13 @@ WasiCryptoExpect<__wasi_symmetric_key_t> WasiCryptoTest::symmetricKeyFromId( "symmetric_key_from_id"); EXPECT_NE(Func, nullptr); EXPECT_TRUE( - Func->run(&MemInst, + Func->run(CallFrame, std::initializer_list{ SecretsManagerHandle, 0, KeyIdSize, KeyVersion, KeyIdSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_symmetric_key_t *>(KeyIdSize); + return *MemInst->getPointer<__wasi_symmetric_key_t *>(KeyIdSize); } WasiCryptoExpect<__wasi_symmetric_state_t> WasiCryptoTest::symmetricStateOpen( @@ -436,13 +438,13 @@ WasiCryptoExpect<__wasi_symmetric_state_t> WasiCryptoTest::symmetricStateOpen( auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_state_open"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ 0, AlgSize, AlgSize, AlgSize + 8, AlgSize + 16}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_symmetric_state_t *>(AlgSize + 16); + return *MemInst->getPointer<__wasi_symmetric_state_t *>(AlgSize + 16); } WasiCryptoExpect<__wasi_size_t> @@ -459,16 +461,17 @@ WasiCryptoTest::symmetricStateOptionsGet(__wasi_symmetric_state_t StateHandle, WasiCryptoSymmMod, "symmetric_state_options_get"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{ StateHandle, 0, NameSize, NameSize, ValueSize, NameSize + ValueSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - std::copy(MemInst.getPointer(NameSize), - MemInst.getPointer(NameSize + ValueSize), Value.begin()); + std::copy(MemInst->getPointer(NameSize), + MemInst->getPointer(NameSize + ValueSize), + Value.begin()); - return *MemInst.getPointer<__wasi_size_t *>(NameSize + ValueSize); + return *MemInst->getPointer<__wasi_size_t *>(NameSize + ValueSize); } WasiCryptoExpect WasiCryptoTest::symmetricStateOptionsGetU64( @@ -480,13 +483,13 @@ WasiCryptoExpect WasiCryptoTest::symmetricStateOptionsGetU64( auto *Func = getHostFunc( WasiCryptoSymmMod, "symmetric_state_options_get_u64"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ StateHandle, 0, NameSize, NameSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer(NameSize); + return *MemInst->getPointer(NameSize); } WasiCryptoExpect @@ -497,7 +500,7 @@ WasiCryptoTest::symmetricStateClose(__wasi_symmetric_state_t StateHandle) { "symmetric_state_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{StateHandle}, + CallFrame, std::initializer_list{StateHandle}, Errno)); ensureOrReturnOnTest(Errno[0].get()); @@ -515,7 +518,7 @@ WasiCryptoTest::symmetricStateAbsorb(__wasi_symmetric_state_t StateHandle, "symmetric_state_absorb"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{StateHandle, 0, DataSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); @@ -531,11 +534,11 @@ WasiCryptoTest::symmetricStateClone(__wasi_symmetric_state_t StateHandle) { "symmetric_state_clone"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{StateHandle, 0}, + CallFrame, std::initializer_list{StateHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_symmetric_state_t *>(0); + return *MemInst->getPointer<__wasi_symmetric_state_t *>(0); } WasiCryptoExpect @@ -549,13 +552,13 @@ WasiCryptoTest::symmetricStateSqueeze(__wasi_symmetric_state_t StateHandle, "symmetric_state_squeeze"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{StateHandle, 0, OutSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - std::copy(MemInst.getPointer(0), - MemInst.getPointer(OutSize), Out.begin()); + std::copy(MemInst->getPointer(0), + MemInst->getPointer(OutSize), Out.begin()); return {}; } @@ -568,11 +571,11 @@ WasiCryptoTest::symmetricStateSqueezeTag(__wasi_symmetric_state_t StateHandle) { WasiCryptoSymmMod, "symmetric_state_squeeze_tag"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{StateHandle, 0}, + CallFrame, std::initializer_list{StateHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_symmetric_tag_t *>(0); + return *MemInst->getPointer<__wasi_symmetric_tag_t *>(0); } WasiCryptoExpect<__wasi_symmetric_key_t> @@ -585,13 +588,13 @@ WasiCryptoTest::symmetricStateSqueezeKey(__wasi_symmetric_state_t StateHandle, auto *Func = getHostFunc( WasiCryptoSymmMod, "symmetric_state_squeeze_key"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ StateHandle, 0, AlgSize, AlgSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_symmetric_key_t *>(AlgSize); + return *MemInst->getPointer<__wasi_symmetric_key_t *>(AlgSize); } WasiCryptoExpect<__wasi_size_t> @@ -602,11 +605,11 @@ WasiCryptoTest::symmetricStateMaxTagLen(__wasi_symmetric_state_t StateHandle) { WasiCryptoSymmMod, "symmetric_state_max_tag_len"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{StateHandle, 0}, + CallFrame, std::initializer_list{StateHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_size_t *>(0); + return *MemInst->getPointer<__wasi_size_t *>(0); } WasiCryptoExpect<__wasi_size_t> @@ -623,16 +626,16 @@ WasiCryptoTest::symmetricStateEncrypt(__wasi_symmetric_state_t StateHandle, "symmetric_state_encrypt"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{ StateHandle, 0, OutSize, OutSize, DataSize, OutSize + DataSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - std::copy(MemInst.getPointer(0), - MemInst.getPointer(OutSize), Out.begin()); + std::copy(MemInst->getPointer(0), + MemInst->getPointer(OutSize), Out.begin()); - return *MemInst.getPointer<__wasi_size_t *>(OutSize + DataSize); + return *MemInst->getPointer<__wasi_size_t *>(OutSize + DataSize); } WasiCryptoExpect<__wasi_symmetric_tag_t> @@ -649,16 +652,16 @@ WasiCryptoTest::symmetricStateEncryptDetached( WasiCryptoSymmMod, "symmetric_state_encrypt_detached"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{ StateHandle, 0, OutSize, OutSize, DataSize, OutSize + DataSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - std::copy(MemInst.getPointer(0), - MemInst.getPointer(OutSize), Out.begin()); + std::copy(MemInst->getPointer(0), + MemInst->getPointer(OutSize), Out.begin()); - return *MemInst.getPointer<__wasi_symmetric_tag_t *>(OutSize + DataSize); + return *MemInst->getPointer<__wasi_symmetric_tag_t *>(OutSize + DataSize); } WasiCryptoExpect<__wasi_size_t> @@ -675,16 +678,16 @@ WasiCryptoTest::symmetricStateDecrypt(__wasi_symmetric_state_t StateHandle, "symmetric_state_decrypt"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{ StateHandle, 0, OutSize, OutSize, DataSize, OutSize + DataSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - std::copy(MemInst.getPointer(0), - MemInst.getPointer(OutSize), Out.begin()); + std::copy(MemInst->getPointer(0), + MemInst->getPointer(OutSize), Out.begin()); - return *MemInst.getPointer<__wasi_size_t *>(OutSize + DataSize); + return *MemInst->getPointer<__wasi_size_t *>(OutSize + DataSize); } WasiCryptoExpect<__wasi_size_t> WasiCryptoTest::symmetricStateDecryptDetached( @@ -701,7 +704,7 @@ WasiCryptoExpect<__wasi_size_t> WasiCryptoTest::symmetricStateDecryptDetached( auto *Func = getHostFunc( WasiCryptoSymmMod, "symmetric_state_decrypt_detached"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ StateHandle, 0, OutSize, OutSize, DataSize, OutSize + DataSize, RawTagSize, @@ -709,13 +712,13 @@ WasiCryptoExpect<__wasi_size_t> WasiCryptoTest::symmetricStateDecryptDetached( Errno)); ensureOrReturnOnTest(Errno[0].get()); - std::copy(MemInst.getPointer(0), - MemInst.getPointer(OutSize), Out.begin()); - std::copy(MemInst.getPointer(OutSize + DataSize), - MemInst.getPointer(OutSize + DataSize + RawTagSize), + std::copy(MemInst->getPointer(0), + MemInst->getPointer(OutSize), Out.begin()); + std::copy(MemInst->getPointer(OutSize + DataSize), + MemInst->getPointer(OutSize + DataSize + RawTagSize), RawTag.begin()); - return *MemInst.getPointer<__wasi_size_t *>(OutSize + DataSize + RawTagSize); + return *MemInst->getPointer<__wasi_size_t *>(OutSize + DataSize + RawTagSize); } WasiCryptoExpect @@ -726,7 +729,7 @@ WasiCryptoTest::symmetricStateRatchet(__wasi_symmetric_state_t StateHandle) { "symmetric_state_ratchet"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{StateHandle}, + CallFrame, std::initializer_list{StateHandle}, Errno)); ensureOrReturnOnTest(Errno[0].get()); @@ -741,11 +744,11 @@ WasiCryptoTest::symmetricMaxTagLen(__wasi_symmetric_tag_t TagHandle) { WasiCryptoSymmMod, "symmetric_state_max_tag_len"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{TagHandle, 0}, + CallFrame, std::initializer_list{TagHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_size_t *>(0); + return *MemInst->getPointer<__wasi_size_t *>(0); } WasiCryptoExpect<__wasi_size_t> @@ -758,16 +761,16 @@ WasiCryptoTest::symmetricTagPull(__wasi_symmetric_tag_t TagHandle, auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_tag_pull"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ TagHandle, 0, BufSize, BufSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - std::copy(MemInst.getPointer(0), - MemInst.getPointer(BufSize), Buf.begin()); + std::copy(MemInst->getPointer(0), + MemInst->getPointer(BufSize), Buf.begin()); - return *MemInst.getPointer<__wasi_size_t *>(BufSize); + return *MemInst->getPointer<__wasi_size_t *>(BufSize); } WasiCryptoExpect @@ -781,7 +784,7 @@ WasiCryptoTest::symmetricTagVerify(__wasi_symmetric_tag_t TagHandle, "symmetric_tag_verify"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{TagHandle, 0, RawTagSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); @@ -796,8 +799,9 @@ WasiCryptoTest::symmetricTagClose(__wasi_symmetric_tag_t TagHandle) { auto *Func = getHostFunc(WasiCryptoSymmMod, "symmetric_tag_close"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{TagHandle}, Errno)); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{TagHandle}, + Errno)); ensureOrReturnOnTest(Errno[0].get()); return {}; @@ -815,13 +819,13 @@ WasiCryptoExpect<__wasi_keypair_t> WasiCryptoTest::keypairGenerate( WasiCryptoAsymCommonMod, "keypair_generate"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{ static_cast(AlgType), 0, AlgSize, AlgSize, AlgSize + 8}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_keypair_t *>(AlgSize + 8); + return *MemInst->getPointer<__wasi_keypair_t *>(AlgSize + 8); } WasiCryptoExpect<__wasi_keypair_t> @@ -837,7 +841,7 @@ WasiCryptoTest::keypairImport(__wasi_algorithm_type_e_t AlgType, auto *Func = getHostFunc( WasiCryptoAsymCommonMod, "keypair_import"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ static_cast(AlgType), 0, AlgSize, AlgSize, EncodedSize, static_cast(Encoding), @@ -845,7 +849,7 @@ WasiCryptoTest::keypairImport(__wasi_algorithm_type_e_t AlgType, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_keypair_t *>(AlgSize + EncodedSize); + return *MemInst->getPointer<__wasi_keypair_t *>(AlgSize + EncodedSize); } WasiCryptoExpect<__wasi_keypair_t> WasiCryptoTest::keypairGenerateManaged( @@ -861,14 +865,14 @@ WasiCryptoExpect<__wasi_keypair_t> WasiCryptoTest::keypairGenerateManaged( WasiCryptoAsymCommonMod, "keypair_generate_managed"); EXPECT_NE(Func, nullptr); EXPECT_TRUE( - Func->run(&MemInst, + Func->run(CallFrame, std::initializer_list{ SecretsManagerHandle, static_cast(AlgType), 0, AlgSize, AlgSize, AlgSize + 8}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_keypair_t *>(AlgSize + 8); + return *MemInst->getPointer<__wasi_keypair_t *>(AlgSize + 8); } WasiCryptoExpect WasiCryptoTest::keypairStoreManaged( @@ -881,14 +885,14 @@ WasiCryptoExpect WasiCryptoTest::keypairStoreManaged( auto *Func = getHostFunc( WasiCryptoAsymCommonMod, "keypair_store_managed"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ SecretsManagerHandle, KpHandle, 0, KpIdSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - std::copy(MemInst.getPointer(0), - MemInst.getPointer(KpIdSize), KpId.begin()); + std::copy(MemInst->getPointer(0), + MemInst->getPointer(KpIdSize), KpId.begin()); return {}; } @@ -901,13 +905,13 @@ WasiCryptoExpect<__wasi_version_t> WasiCryptoTest::keypairReplaceManaged( auto *Func = getHostFunc( WasiCryptoAsymCommonMod, "keypair_replace_managed"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ SecretsManagerHandle, OldKpHandle, NewKpHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_version_t *>(0); + return *MemInst->getPointer<__wasi_version_t *>(0); } WasiCryptoExpect> @@ -919,17 +923,18 @@ WasiCryptoTest::keypairId(__wasi_keypair_t KpHandle, Span KpId) { auto *Func = getHostFunc(WasiCryptoAsymCommonMod, "keypair_id"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ KpHandle, 0, KpIdSize, KpIdSize, KpIdSize + 1}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - std::copy(MemInst.getPointer(0), - MemInst.getPointer(KpIdSize), KpId.begin()); + std::copy(MemInst->getPointer(0), + MemInst->getPointer(KpIdSize), KpId.begin()); - return std::make_tuple(*MemInst.getPointer(KpIdSize), - *MemInst.getPointer<__wasi_version_t *>(KpIdSize + 1)); + return std::make_tuple( + *MemInst->getPointer(KpIdSize), + *MemInst->getPointer<__wasi_version_t *>(KpIdSize + 1)); } WasiCryptoExpect<__wasi_keypair_t> @@ -944,13 +949,13 @@ WasiCryptoTest::keypairFromId(__wasi_secrets_manager_t SecretsManagerHandle, WasiCryptoAsymCommonMod, "keypair_from_id"); EXPECT_NE(Func, nullptr); EXPECT_TRUE( - Func->run(&MemInst, + Func->run(CallFrame, std::initializer_list{ SecretsManagerHandle, 0, KpIdSize, KpIdVersion, KpIdSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_keypair_t *>(KpIdSize); + return *MemInst->getPointer<__wasi_keypair_t *>(KpIdSize); } WasiCryptoExpect<__wasi_keypair_t> @@ -962,12 +967,12 @@ WasiCryptoTest::keypairFromPkAndSk(__wasi_publickey_t PkHandle, WasiCryptoAsymCommonMod, "keypair_from_pk_and_sk"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{PkHandle, SkHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_keypair_t *>(0); + return *MemInst->getPointer<__wasi_keypair_t *>(0); } WasiCryptoExpect<__wasi_array_output_t> @@ -978,13 +983,13 @@ WasiCryptoTest::keypairExport(__wasi_keypair_t KpHandle, auto *Func = getHostFunc( WasiCryptoAsymCommonMod, "keypair_export"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ KpHandle, static_cast(Encoding), 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_array_output_t *>(0); + return *MemInst->getPointer<__wasi_array_output_t *>(0); } WasiCryptoExpect<__wasi_publickey_t> @@ -995,11 +1000,11 @@ WasiCryptoTest::keypairPublickey(__wasi_keypair_t KpHandle) { WasiCryptoAsymCommonMod, "keypair_publickey"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{KpHandle, 0}, + CallFrame, std::initializer_list{KpHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_publickey_t *>(0); + return *MemInst->getPointer<__wasi_publickey_t *>(0); } WasiCryptoExpect<__wasi_secretkey_t> @@ -1010,11 +1015,11 @@ WasiCryptoTest::keypairSecretkey(__wasi_keypair_t KpHandle) { WasiCryptoAsymCommonMod, "keypair_secretkey"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{KpHandle, 0}, + CallFrame, std::initializer_list{KpHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_secretkey_t *>(0); + return *MemInst->getPointer<__wasi_secretkey_t *>(0); } WasiCryptoExpect WasiCryptoTest::keypairClose(__wasi_keypair_t KpHandle) { @@ -1024,7 +1029,7 @@ WasiCryptoExpect WasiCryptoTest::keypairClose(__wasi_keypair_t KpHandle) { WasiCryptoAsymCommonMod, "keypair_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{KpHandle}, Errno)); + CallFrame, std::initializer_list{KpHandle}, Errno)); ensureOrReturnOnTest(Errno[0].get()); return {}; @@ -1042,7 +1047,7 @@ WasiCryptoExpect<__wasi_publickey_t> WasiCryptoTest::publickeyImport( auto *Func = getHostFunc( WasiCryptoAsymCommonMod, "publickey_import"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ static_cast(AlgType), 0, AlgSize, AlgSize, EncodedSize, static_cast(Encoding), @@ -1050,7 +1055,7 @@ WasiCryptoExpect<__wasi_publickey_t> WasiCryptoTest::publickeyImport( Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_publickey_t *>(AlgSize + EncodedSize); + return *MemInst->getPointer<__wasi_publickey_t *>(AlgSize + EncodedSize); } WasiCryptoExpect<__wasi_array_output_t> @@ -1061,13 +1066,13 @@ WasiCryptoTest::publickeyExport(__wasi_publickey_t PkHandle, auto *Func = getHostFunc( WasiCryptoAsymCommonMod, "publickey_export"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ PkHandle, static_cast(Encoding), 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_array_output_t *>(0); + return *MemInst->getPointer<__wasi_array_output_t *>(0); } WasiCryptoExpect @@ -1078,7 +1083,7 @@ WasiCryptoTest::publickeyVerify(__wasi_publickey_t PkHandle) { WasiCryptoAsymCommonMod, "publickey_verify"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{PkHandle}, Errno)); + CallFrame, std::initializer_list{PkHandle}, Errno)); ensureOrReturnOnTest(Errno[0].get()); return {}; @@ -1092,11 +1097,11 @@ WasiCryptoTest::publickeyFromSecretkey(__wasi_secretkey_t SkHandle) { WasiCryptoAsymCommonMod, "publickey_from_secretkey"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{SkHandle, 0}, + CallFrame, std::initializer_list{SkHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_publickey_t *>(0); + return *MemInst->getPointer<__wasi_publickey_t *>(0); } WasiCryptoExpect @@ -1107,7 +1112,7 @@ WasiCryptoTest::publickeyClose(__wasi_publickey_t PkHandle) { WasiCryptoAsymCommonMod, "publickey_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{PkHandle}, Errno)); + CallFrame, std::initializer_list{PkHandle}, Errno)); ensureOrReturnOnTest(Errno[0].get()); return {}; @@ -1125,7 +1130,7 @@ WasiCryptoExpect<__wasi_secretkey_t> WasiCryptoTest::secretkeyImport( auto *Func = getHostFunc( WasiCryptoAsymCommonMod, "secretkey_import"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ static_cast(AlgType), 0, AlgSize, AlgSize, EncodedSize, static_cast(Encoding), @@ -1133,7 +1138,7 @@ WasiCryptoExpect<__wasi_secretkey_t> WasiCryptoTest::secretkeyImport( Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_secretkey_t *>(AlgSize + EncodedSize); + return *MemInst->getPointer<__wasi_secretkey_t *>(AlgSize + EncodedSize); } WasiCryptoExpect<__wasi_array_output_t> @@ -1144,13 +1149,13 @@ WasiCryptoTest::secretkeyExport(__wasi_secretkey_t SkHandle, auto *Func = getHostFunc( WasiCryptoAsymCommonMod, "secretkey_export"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ SkHandle, static_cast(Encoding), 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_publickey_t *>(0); + return *MemInst->getPointer<__wasi_publickey_t *>(0); } WasiCryptoExpect @@ -1161,7 +1166,7 @@ WasiCryptoTest::secretkeyClose(__wasi_secretkey_t SkHandle) { WasiCryptoAsymCommonMod, "secretkey_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{SkHandle}, Errno)); + CallFrame, std::initializer_list{SkHandle}, Errno)); ensureOrReturnOnTest(Errno[0].get()); return {}; @@ -1175,12 +1180,12 @@ WasiCryptoTest::kxDh(__wasi_kx_publickey_t PkHandle, auto *Func = getHostFunc(WasiCryptoKxMod, "kx_dh"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{PkHandle, SkHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_array_output_t *>(0); + return *MemInst->getPointer<__wasi_array_output_t *>(0); } WasiCryptoExpect> @@ -1190,12 +1195,12 @@ WasiCryptoTest::kxEncapsulate(__wasi_kx_publickey_t PkHandle) { auto *Func = getHostFunc(WasiCryptoKxMod, "kx_encapsulate"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{PkHandle, 0, 1}, + CallFrame, std::initializer_list{PkHandle, 0, 1}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return std::make_tuple(*MemInst.getPointer<__wasi_array_output_t *>(0), - *MemInst.getPointer<__wasi_array_output_t *>(1)); + return std::make_tuple(*MemInst->getPointer<__wasi_array_output_t *>(0), + *MemInst->getPointer<__wasi_array_output_t *>(1)); } WasiCryptoExpect<__wasi_array_output_t> @@ -1209,13 +1214,13 @@ WasiCryptoTest::kxDecapsulate(__wasi_kx_secretkey_t SkHandle, auto *Func = getHostFunc(WasiCryptoKxMod, "kx_decapsulate"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{ SkHandle, 0, EncapsulatedSecretSize, EncapsulatedSecretSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_array_output_t *>(EncapsulatedSecretSize); + return *MemInst->getPointer<__wasi_array_output_t *>(EncapsulatedSecretSize); } WasiCryptoExpect<__wasi_array_output_t> @@ -1223,16 +1228,16 @@ WasiCryptoTest::signatureExport(__wasi_signature_t SigHandle, __wasi_signature_encoding_e_t Encoding) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoSignMod, - "signature_export"); + auto *Func = + getHostFunc(WasiCryptoSignMod, "signature_export"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run(&MemInst, + EXPECT_TRUE(Func->run(CallFrame, std::initializer_list{ SigHandle, static_cast(Encoding), 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_array_output_t *>(0); + return *MemInst->getPointer<__wasi_array_output_t *>(0); } WasiCryptoExpect<__wasi_signature_t> @@ -1245,29 +1250,30 @@ WasiCryptoTest::signatureImport(std::string_view Alg, writeSpan(Encoded, AlgSize); uint32_t EncodedSize = static_cast(Encoded.size()); - auto *Func = getHostFunc(WasiCryptoSignMod, - "signature_import"); + auto *Func = + getHostFunc(WasiCryptoSignMod, "signature_import"); EXPECT_NE(Func, nullptr); EXPECT_TRUE( - Func->run(&MemInst, + Func->run(CallFrame, std::initializer_list{ 0, AlgSize, AlgSize, EncodedSize, static_cast(Encoding), AlgSize + EncodedSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_signature_t *>(AlgSize + EncodedSize); + return *MemInst->getPointer<__wasi_signature_t *>(AlgSize + EncodedSize); } WasiCryptoExpect WasiCryptoTest::signatureClose(__wasi_signature_t SigHandle) { writeDummyMemoryContent(); - auto *Func = getHostFunc(WasiCryptoSignMod, - "signature_close"); + auto *Func = + getHostFunc(WasiCryptoSignMod, "signature_close"); EXPECT_NE(Func, nullptr); - EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{SigHandle}, Errno)); + EXPECT_TRUE(Func->run(CallFrame, + std::initializer_list{SigHandle}, + Errno)); ensureOrReturnOnTest(Errno[0].get()); return {}; @@ -1281,11 +1287,11 @@ WasiCryptoTest::signatureStateOpen(__wasi_signature_keypair_t KpHandle) { "signature_state_open"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{KpHandle, 0}, + CallFrame, std::initializer_list{KpHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_signature_state_t *>(0); + return *MemInst->getPointer<__wasi_signature_state_t *>(0); } WasiCryptoExpect @@ -1299,7 +1305,7 @@ WasiCryptoTest::signatureStateUpdate(__wasi_signature_state_t StateHandle, "signature_state_update"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{StateHandle, 0, InputSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); @@ -1315,11 +1321,11 @@ WasiCryptoTest::signatureStateSign(__wasi_signature_state_t StateHandle) { "signature_state_sign"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{StateHandle, 0}, + CallFrame, std::initializer_list{StateHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_signature_t *>(0); + return *MemInst->getPointer<__wasi_signature_t *>(0); } WasiCryptoExpect @@ -1330,7 +1336,7 @@ WasiCryptoTest::signatureStateClose(__wasi_signature_state_t StateHandle) { "signature_state_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{StateHandle}, + CallFrame, std::initializer_list{StateHandle}, Errno)); ensureOrReturnOnTest(Errno[0].get()); @@ -1346,11 +1352,11 @@ WasiCryptoTest::signatureVerificationStateOpen( WasiCryptoSignMod, "signature_verification_state_open"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{PkHandle, 0}, + CallFrame, std::initializer_list{PkHandle, 0}, Errno)); ensureOrReturnOnTest(Errno[0].get()); - return *MemInst.getPointer<__wasi_signature_verification_state_t *>(0); + return *MemInst->getPointer<__wasi_signature_verification_state_t *>(0); } WasiCryptoExpect WasiCryptoTest::signatureVerificationStateUpdate( @@ -1364,7 +1370,7 @@ WasiCryptoExpect WasiCryptoTest::signatureVerificationStateUpdate( WasiCryptoSignMod, "signature_verification_state_update"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{StateHandle, 0, InputSize}, Errno)); ensureOrReturnOnTest(Errno[0].get()); @@ -1381,7 +1387,7 @@ WasiCryptoExpect WasiCryptoTest::signatureVerificationStateVerify( WasiCryptoSignMod, "signature_verification_state_verify"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, + CallFrame, std::initializer_list{StateHandle, SigHandle}, Errno)); ensureOrReturnOnTest(Errno[0].get()); @@ -1397,7 +1403,7 @@ WasiCryptoExpect WasiCryptoTest::signatureVerificationStateClose( WasiCryptoSignMod, "signature_verification_state_close"); EXPECT_NE(Func, nullptr); EXPECT_TRUE(Func->run( - &MemInst, std::initializer_list{StateHandle}, + CallFrame, std::initializer_list{StateHandle}, Errno)); ensureOrReturnOnTest(Errno[0].get()); @@ -1418,7 +1424,7 @@ WasiCryptoExpect WasiCryptoTest::signatureVerificationStateClose( // if (Res != __WASI_CRYPTO_ERRNO_SUCCESS) { // return WasiCryptoUnexpect(Res); // } -// return *MemInst.getPointer<__wasi_signature_keypair_t *>(AlgStr.size() + +// return *MemInst->getPointer<__wasi_signature_keypair_t *>(AlgStr.size() + // Encoded.size()); // } @@ -1431,7 +1437,7 @@ WasiCryptoExpect WasiCryptoTest::signatureVerificationStateClose( // if (Res != __WASI_CRYPTO_ERRNO_SUCCESS) { // return WasiCryptoUnexpect(Res); // } -// return *MemInst.getPointer<__wasi_signature_keypair_t *>(0); +// return *MemInst->getPointer<__wasi_signature_keypair_t *>(0); // } } // namespace WasiCrypto diff --git a/test/plugins/wasi_crypto/helper.h b/test/plugins/wasi_crypto/helper.h index 92fd6379..beb1a339 100644 --- a/test/plugins/wasi_crypto/helper.h +++ b/test/plugins/wasi_crypto/helper.h @@ -14,7 +14,8 @@ #include "common/span.h" #include "common/types.h" -#include "runtime/instance/memory.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" #include "gtest/gtest.h" #include @@ -54,7 +55,12 @@ std::vector operator"" _u8v(const char *Str, std::size_t Len); /// Designed for testing. class WasiCryptoTest : public ::testing::Test { public: - WasiCryptoTest() { + WasiCryptoTest() : Mod(""), CallFrame(nullptr, &Mod) { + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + MemInst = Mod.findMemoryExports("memory"); + using namespace std::literals::string_view_literals; Plugin::Plugin::load(std::filesystem::u8path( "../../../plugins/wasi_crypto/" @@ -384,8 +390,12 @@ class WasiCryptoTest : public ::testing::Test { __wasi_signature_verification_state_t StateHandle); int32_t InvaildHandle = 9999; - WasmEdge::Runtime::Instance::MemoryInstance MemInst{ - WasmEdge::AST::MemoryType(1)}; + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod; + WasmEdge::Runtime::Instance::MemoryInstance *MemInst; + WasmEdge::Runtime::CallingFrame CallFrame; + std::array Errno; Host::WasiCryptoAsymmetricCommonModule *WasiCryptoAsymCommonMod = nullptr; diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 15631ff5..0edf04c5 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -81,11 +81,17 @@ std::vector classSort(const std::vector &Array) { TEST(WasiNNTest, OpenVINOBackend) { // Create the wasmedge_process module instance. auto *NNMod = dynamic_cast(createModule()); - EXPECT_FALSE(NNMod == nullptr); - - // Create the memory instance. - WasmEdge::Runtime::Instance::MemoryInstance MemInst( - WasmEdge::AST::MemoryType(400)); + ASSERT_TRUE(NNMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(400))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Load the files. std::vector TensorData = @@ -144,7 +150,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: load -- meaningless binaries. { EXPECT_TRUE(HostFuncLoad.run( - &MemInst, + CallFrame, std::initializer_list{ LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, Errno)); @@ -154,7 +160,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: load -- graph id ptr out of bounds. { EXPECT_TRUE(HostFuncLoad.run( - &MemInst, + CallFrame, std::initializer_list{ LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), OutBoundPtr}, Errno)); @@ -165,7 +171,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: load -- graph builder ptr out of bounds. { EXPECT_TRUE(HostFuncLoad.run( - &MemInst, + CallFrame, std::initializer_list{ OutBoundPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, Errno)); @@ -180,7 +186,7 @@ TEST(WasiNNTest, OpenVINOBackend) { BuilderPtr); { EXPECT_TRUE(HostFuncLoad.run( - &MemInst, + CallFrame, std::initializer_list{ LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, Errno)); @@ -194,7 +200,7 @@ TEST(WasiNNTest, OpenVINOBackend) { writeFatPointer(MemInst, OutBoundPtr, WeightRead.size(), BuilderPtr); { EXPECT_TRUE(HostFuncLoad.run( - &MemInst, + CallFrame, std::initializer_list{ LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, Errno)); @@ -212,7 +218,7 @@ TEST(WasiNNTest, OpenVINOBackend) { StorePtr += (XmlRead.size() + WeightRead.size()); { EXPECT_TRUE(HostFuncLoad.run( - &MemInst, + CallFrame, std::initializer_list{ LoadEntryPtr, UINT32_C(4), UINT32_C(0), UINT32_C(0), BuilderPtr}, Errno)); @@ -223,7 +229,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: load -- unsupported device. CPU 0, GPU 1, TPU 2 { EXPECT_TRUE(HostFuncLoad.run( - &MemInst, + CallFrame, std::initializer_list{ LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(3), BuilderPtr}, Errno)); @@ -234,7 +240,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: load -- load successfully. { EXPECT_TRUE(HostFuncLoad.run( - &MemInst, + CallFrame, std::initializer_list{ LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, Errno)); @@ -246,7 +252,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: load -- load second graph. { EXPECT_TRUE(HostFuncLoad.run( - &MemInst, + CallFrame, std::initializer_list{ LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, Errno)); @@ -259,7 +265,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: init_execution_context -- graph id invalid. { EXPECT_TRUE(HostFuncInit.run( - &MemInst, + CallFrame, std::initializer_list{UINT32_C(2), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), @@ -273,7 +279,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: init_execution_context -- graph id exceeds. { EXPECT_TRUE(HostFuncInit.run( - &MemInst, + CallFrame, std::initializer_list{UINT32_C(0), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), @@ -286,7 +292,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: init_execution_context -- init context successfully. { EXPECT_TRUE(HostFuncInit.run( - &MemInst, + CallFrame, std::initializer_list{UINT32_C(0), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); @@ -297,7 +303,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: init_execution_context -- init second context. { EXPECT_TRUE(HostFuncInit.run( - &MemInst, + CallFrame, std::initializer_list{UINT32_C(1), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); @@ -321,7 +327,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: set_input -- context id exceeds. { EXPECT_TRUE( - HostFuncSetInput.run(&MemInst, + HostFuncSetInput.run(CallFrame, std::initializer_list{ UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, Errno)); @@ -332,7 +338,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: set_input -- empty context. { EXPECT_TRUE( - HostFuncSetInput.run(&MemInst, + HostFuncSetInput.run(CallFrame, std::initializer_list{ UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, Errno)); @@ -346,7 +352,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: set_input -- input index exceeds. { EXPECT_TRUE( - HostFuncSetInput.run(&MemInst, + HostFuncSetInput.run(CallFrame, std::initializer_list{ UINT32_C(0), UINT32_C(10), SetInputEntryPtr}, Errno)); @@ -362,7 +368,7 @@ TEST(WasiNNTest, OpenVINOBackend) { BuilderPtr); { EXPECT_TRUE( - HostFuncSetInput.run(&MemInst, + HostFuncSetInput.run(CallFrame, std::initializer_list{ UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, Errno)); @@ -378,7 +384,7 @@ TEST(WasiNNTest, OpenVINOBackend) { BuilderPtr); { EXPECT_TRUE( - HostFuncSetInput.run(&MemInst, + HostFuncSetInput.run(CallFrame, std::initializer_list{ UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, Errno)); @@ -390,7 +396,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: compute -- context id exceeds. { EXPECT_TRUE(HostFuncCompute.run( - &MemInst, std::initializer_list{UINT32_C(3)}, + CallFrame, std::initializer_list{UINT32_C(3)}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -402,7 +408,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: compute -- empty context. { EXPECT_TRUE(HostFuncCompute.run( - &MemInst, std::initializer_list{UINT32_C(0)}, + CallFrame, std::initializer_list{UINT32_C(0)}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Busy)); } @@ -413,7 +419,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: compute -- compute successfully. { EXPECT_TRUE(HostFuncCompute.run( - &MemInst, std::initializer_list{UINT32_C(1)}, + CallFrame, std::initializer_list{UINT32_C(1)}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); } @@ -422,7 +428,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: get_output -- output bytes ptr out of bounds. { EXPECT_TRUE(HostFuncGetOutput.run( - &MemInst, + CallFrame, std::initializer_list{ UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, Errno)); @@ -433,7 +439,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: get_output -- output buffer ptr out of bounds. { EXPECT_TRUE(HostFuncGetOutput.run( - &MemInst, + CallFrame, std::initializer_list{ UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, Errno)); @@ -444,7 +450,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: get_output -- output index exceeds. { EXPECT_TRUE(HostFuncGetOutput.run( - &MemInst, + CallFrame, std::initializer_list{ UINT32_C(0), UINT32_C(10), StorePtr, 65532, BuilderPtr}, Errno)); @@ -455,7 +461,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: get_output -- get output successfully. { EXPECT_TRUE(HostFuncGetOutput.run( - &MemInst, + CallFrame, std::initializer_list{ UINT32_C(1), UINT32_C(0), StorePtr, 65532, BuilderPtr}, Errno)); diff --git a/test/plugins/wasmedge_process/wasmedge_process.cpp b/test/plugins/wasmedge_process/wasmedge_process.cpp index 42bfb7a0..9c25f412 100644 --- a/test/plugins/wasmedge_process/wasmedge_process.cpp +++ b/test/plugins/wasmedge_process/wasmedge_process.cpp @@ -14,6 +14,8 @@ #include namespace { +WasmEdge::Runtime::CallingFrame DummyCallFrame(nullptr, nullptr); + WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( @@ -44,11 +46,17 @@ TEST(WasmEdgeProcessTest, SetProgName) { // Create the wasmedge_process module instance. auto *ProcMod = dynamic_cast(createModule()); - EXPECT_FALSE(ProcMod == nullptr); - - // Create the memory instance. - WasmEdge::Runtime::Instance::MemoryInstance MemInst( - WasmEdge::AST::MemoryType(1)); + ASSERT_TRUE(ProcMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Clear the memory[0, 64]. fillMemContent(MemInst, 0, 64); @@ -65,14 +73,14 @@ TEST(WasmEdgeProcessTest, SetProgName) { // Test: Run function successfully. EXPECT_TRUE(HostFuncInst.run( - &MemInst, + CallFrame, std::initializer_list{UINT32_C(0), UINT32_C(4)}, {})); EXPECT_EQ(ProcMod->getEnv().Name, "echo"); // Test: Run function with nullptr memory instance -- fail EXPECT_FALSE(HostFuncInst.run( - nullptr, + DummyCallFrame, std::initializer_list{UINT32_C(0), UINT32_C(4)}, {})); @@ -83,11 +91,17 @@ TEST(WasmEdgeProcessTest, AddArg) { // Create the wasmedge_process module instance. auto *ProcMod = dynamic_cast(createModule()); - EXPECT_FALSE(ProcMod == nullptr); - - // Create the memory instance. - WasmEdge::Runtime::Instance::MemoryInstance MemInst( - WasmEdge::AST::MemoryType(1)); + ASSERT_TRUE(ProcMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Clear the memory[0, 64]. fillMemContent(MemInst, 0, 64); @@ -107,7 +121,7 @@ TEST(WasmEdgeProcessTest, AddArg) { // Test: Run function successfully to add "arg1". EXPECT_TRUE(HostFuncInst.run( - &MemInst, + CallFrame, std::initializer_list{UINT32_C(0), UINT32_C(4)}, {})); EXPECT_EQ(ProcMod->getEnv().Args.size(), 1U); @@ -115,7 +129,7 @@ TEST(WasmEdgeProcessTest, AddArg) { // Test: Run function successfully to add "arg2". EXPECT_TRUE(HostFuncInst.run( - &MemInst, + CallFrame, std::initializer_list{UINT32_C(4), UINT32_C(4)}, {})); EXPECT_EQ(ProcMod->getEnv().Args.size(), 2U); @@ -123,7 +137,7 @@ TEST(WasmEdgeProcessTest, AddArg) { // Test: Run function successfully to add "--final-arg". EXPECT_TRUE(HostFuncInst.run( - &MemInst, + CallFrame, std::initializer_list{UINT32_C(30), UINT32_C(11)}, {})); EXPECT_EQ(ProcMod->getEnv().Args.size(), 3U); @@ -131,7 +145,7 @@ TEST(WasmEdgeProcessTest, AddArg) { // Test: Run function with nullptr memory instance -- fail EXPECT_FALSE(HostFuncInst.run( - nullptr, + DummyCallFrame, std::initializer_list{UINT32_C(0), UINT32_C(4)}, {})); @@ -142,11 +156,17 @@ TEST(WasmEdgeProcessTest, AddEnv) { // Create the wasmedge_process module instance. auto *ProcMod = dynamic_cast(createModule()); - EXPECT_FALSE(ProcMod == nullptr); - - // Create the memory instance. - WasmEdge::Runtime::Instance::MemoryInstance MemInst( - WasmEdge::AST::MemoryType(1)); + ASSERT_TRUE(ProcMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Clear the memory[0, 256]. fillMemContent(MemInst, 0, 256); @@ -168,7 +188,7 @@ TEST(WasmEdgeProcessTest, AddEnv) { // Test: Run function successfully to add "ENV1", "VALUE1". EXPECT_TRUE( - HostFuncInst.run(&MemInst, + HostFuncInst.run(CallFrame, std::initializer_list{ UINT32_C(0), UINT32_C(4), UINT32_C(4), UINT32_C(6)}, {})); @@ -177,7 +197,7 @@ TEST(WasmEdgeProcessTest, AddEnv) { // Test: Run function successfully to add "LD_LIBRARY_PATH", "/usr/local/lib". EXPECT_TRUE(HostFuncInst.run( - &MemInst, + CallFrame, std::initializer_list{UINT32_C(30), UINT32_C(15), UINT32_C(50), UINT32_C(14)}, {})); @@ -186,7 +206,7 @@ TEST(WasmEdgeProcessTest, AddEnv) { // Test: Run function with nullptr memory instance -- fail EXPECT_FALSE( - HostFuncInst.run(nullptr, + HostFuncInst.run(DummyCallFrame, std::initializer_list{ UINT32_C(0), UINT32_C(4), UINT32_C(4), UINT32_C(6)}, {})); @@ -198,11 +218,17 @@ TEST(WasmEdgeProcessTest, AddStdIn) { // Create the wasmedge_process module instance. auto *ProcMod = dynamic_cast(createModule()); - EXPECT_FALSE(ProcMod == nullptr); - - // Create the memory instance. - WasmEdge::Runtime::Instance::MemoryInstance MemInst( - WasmEdge::AST::MemoryType(1)); + ASSERT_TRUE(ProcMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Clear the memory[0, 64]. fillMemContent(MemInst, 0, 64); @@ -220,7 +246,7 @@ TEST(WasmEdgeProcessTest, AddStdIn) { // Test: Run function successfully to add "\01\02\03\04". EXPECT_TRUE(HostFuncInst.run( - &MemInst, + CallFrame, std::initializer_list{UINT32_C(0), UINT32_C(4)}, {})); EXPECT_EQ(ProcMod->getEnv().StdIn.size(), 4U); @@ -229,7 +255,7 @@ TEST(WasmEdgeProcessTest, AddStdIn) { // Test: Run function successfully to add "hello, wasmedge\n". EXPECT_TRUE(HostFuncInst.run( - &MemInst, + CallFrame, std::initializer_list{UINT32_C(30), UINT32_C(16)}, {})); EXPECT_EQ(ProcMod->getEnv().StdIn.size(), 20U); @@ -240,7 +266,7 @@ TEST(WasmEdgeProcessTest, AddStdIn) { // Test: Run function with nullptr memory instance -- fail EXPECT_FALSE(HostFuncInst.run( - nullptr, + DummyCallFrame, std::initializer_list{UINT32_C(0), UINT32_C(4)}, {})); @@ -251,7 +277,7 @@ TEST(WasmEdgeProcessTest, SetTimeOut) { // Create the wasmedge_process module instance. auto *ProcMod = dynamic_cast(createModule()); - EXPECT_FALSE(ProcMod == nullptr); + ASSERT_TRUE(ProcMod != nullptr); // Get the function "wasmedge_process_set_timeout". auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_set_timeout"); @@ -263,7 +289,8 @@ TEST(WasmEdgeProcessTest, SetTimeOut) { // Test: Run function successfully to set timeout 100. EXPECT_TRUE(HostFuncInst.run( - nullptr, std::initializer_list{UINT32_C(100)}, {})); + DummyCallFrame, + std::initializer_list{UINT32_C(100)}, {})); EXPECT_EQ(ProcMod->getEnv().TimeOut, 100U); delete ProcMod; @@ -273,11 +300,17 @@ TEST(WasmEdgeProcessTest, Run) { // Create the wasmedge_process module instance. auto *ProcMod = dynamic_cast(createModule()); - EXPECT_FALSE(ProcMod == nullptr); - - // Create the memory instance. - WasmEdge::Runtime::Instance::MemoryInstance MemInst( - WasmEdge::AST::MemoryType(1)); + ASSERT_TRUE(ProcMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Clear the memory[0, 64]. fillMemContent(MemInst, 0, 64); @@ -299,7 +332,7 @@ TEST(WasmEdgeProcessTest, Run) { // Test: Run function failed to run "c++" without allowing all commands. ProcMod->getEnv().AllowedAll = false; ProcMod->getEnv().Name = "c++"; - EXPECT_TRUE(HostFuncInst.run(nullptr, {}, RetVal)); + EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); EXPECT_EQ(RetVal[0].get(), -1); EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 0); EXPECT_TRUE(ProcMod->getEnv().StdErr.size() > 0); @@ -313,7 +346,7 @@ TEST(WasmEdgeProcessTest, Run) { // Test: Run function successfully to run "c++" with allowing all commands. ProcMod->getEnv().AllowedAll = true; ProcMod->getEnv().Name = "c++"; - EXPECT_TRUE(HostFuncInst.run(nullptr, {}, RetVal)); + EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); EXPECT_EQ(RetVal[0].get(), 1); EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 0); EXPECT_TRUE(ProcMod->getEnv().StdErr.size() > 0); @@ -322,7 +355,7 @@ TEST(WasmEdgeProcessTest, Run) { ProcMod->getEnv().AllowedAll = false; ProcMod->getEnv().AllowedCmd.insert("c++"); ProcMod->getEnv().Name = "c++"; - EXPECT_TRUE(HostFuncInst.run(nullptr, {}, RetVal)); + EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); EXPECT_EQ(RetVal[0].get(), 1); EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 0); EXPECT_TRUE(ProcMod->getEnv().StdErr.size() > 0); @@ -334,7 +367,7 @@ TEST(WasmEdgeProcessTest, Run) { ProcMod->getEnv().AllowedCmd.insert("/bin/echo"); ProcMod->getEnv().Name = "/bin/echo"; ProcMod->getEnv().Args.push_back("123456 test"); - EXPECT_TRUE(HostFuncInst.run(nullptr, {}, RetVal)); + EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); EXPECT_EQ(RetVal[0].get(), 0); EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 12); EXPECT_TRUE(ProcMod->getEnv().StdErr.size() == 0); @@ -349,7 +382,7 @@ TEST(WasmEdgeProcessTest, GetExitCode) { // Create the wasmedge_process module instance. auto *ProcMod = dynamic_cast(createModule()); - EXPECT_FALSE(ProcMod == nullptr); + ASSERT_TRUE(ProcMod != nullptr); // Get the function "wasmedge_process_get_exit_code". auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_get_exit_code"); @@ -361,7 +394,7 @@ TEST(WasmEdgeProcessTest, GetExitCode) { // Test: Run function successfully to get exit code. std::array RetVal; - EXPECT_TRUE(HostFuncInst.run(nullptr, {}, RetVal)); + EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); EXPECT_EQ(RetVal[0].get(), 0); delete ProcMod; @@ -371,11 +404,17 @@ TEST(WasmEdgeProcessTest, GetStdOut) { // Create the wasmedge_process module instance. auto *ProcMod = dynamic_cast(createModule()); - EXPECT_FALSE(ProcMod == nullptr); - - // Create the memory instance. - WasmEdge::Runtime::Instance::MemoryInstance MemInst( - WasmEdge::AST::MemoryType(1)); + ASSERT_TRUE(ProcMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Clear the memory[0, 256]. fillMemContent(MemInst, 0, 256); @@ -408,21 +447,22 @@ TEST(WasmEdgeProcessTest, GetStdOut) { ProcMod->getEnv().Name = "echo"; ProcMod->getEnv().AllowedCmd.insert("echo"); ProcMod->getEnv().Args.push_back("$(pwd)"); - EXPECT_TRUE(HostFuncRun.run(nullptr, {}, RetVal)); + EXPECT_TRUE(HostFuncRun.run(DummyCallFrame, {}, RetVal)); EXPECT_EQ(RetVal[0].get(), 0U); // Test: Run wasmedge_process_get_stdout_len successfully. - EXPECT_TRUE(HostFuncGetStdOutLen.run(nullptr, {}, RetVal)); + EXPECT_TRUE(HostFuncGetStdOutLen.run(DummyCallFrame, {}, RetVal)); uint32_t Len = RetVal[0].get(); EXPECT_TRUE(Len > 0U); // Test: Run function with nullptr memory instance -- fail EXPECT_FALSE(HostFuncGetStdOut.run( - nullptr, std::initializer_list{UINT32_C(0)}, {})); + DummyCallFrame, std::initializer_list{UINT32_C(0)}, + {})); // Test: Run wasmedge_process_get_stdout successfully. EXPECT_TRUE(HostFuncGetStdOut.run( - &MemInst, std::initializer_list{UINT32_C(0)}, {})); + CallFrame, std::initializer_list{UINT32_C(0)}, {})); EXPECT_TRUE(std::equal(ProcMod->getEnv().StdOut.begin(), ProcMod->getEnv().StdOut.end(), MemInst.getPointer(0))); @@ -434,11 +474,17 @@ TEST(WasmEdgeProcessTest, GetStdErr) { // Create the wasmedge_process module instance. auto *ProcMod = dynamic_cast(createModule()); - EXPECT_FALSE(ProcMod == nullptr); - - // Create the memory instance. - WasmEdge::Runtime::Instance::MemoryInstance MemInst( - WasmEdge::AST::MemoryType(1)); + ASSERT_TRUE(ProcMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Clear the memory[0, 256]. fillMemContent(MemInst, 0, 256); @@ -470,21 +516,22 @@ TEST(WasmEdgeProcessTest, GetStdErr) { // Run the command "c++". ProcMod->getEnv().Name = "c++"; ProcMod->getEnv().AllowedCmd.insert("c++"); - EXPECT_TRUE(HostFuncRun.run(nullptr, {}, RetVal)); + EXPECT_TRUE(HostFuncRun.run(DummyCallFrame, {}, RetVal)); EXPECT_NE(RetVal[0].get(), 0U); // Test: Run wasmedge_process_get_stdout_len successfully. - EXPECT_TRUE(HostFuncGetStdErrLen.run(nullptr, {}, RetVal)); + EXPECT_TRUE(HostFuncGetStdErrLen.run(DummyCallFrame, {}, RetVal)); uint32_t Len = RetVal[0].get(); EXPECT_TRUE(Len > 0U); // Test: Run function with nullptr memory instance -- fail EXPECT_FALSE(HostFuncGetStdErr.run( - nullptr, std::initializer_list{UINT32_C(0)}, {})); + DummyCallFrame, std::initializer_list{UINT32_C(0)}, + {})); // Test: Run wasmedge_process_get_stdout successfully. EXPECT_TRUE(HostFuncGetStdErr.run( - &MemInst, std::initializer_list{UINT32_C(0)}, {})); + CallFrame, std::initializer_list{UINT32_C(0)}, {})); EXPECT_TRUE(std::equal(ProcMod->getEnv().StdOut.begin(), ProcMod->getEnv().StdOut.end(), MemInst.getPointer(0))); From 2bdaa9895e8a390feef673fd94040e05b9eba11e Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 24 Aug 2022 03:45:12 +0800 Subject: [PATCH 067/623] [Utils] Remove manylinux1 dockerfile and related patches about glibc and gcc Signed-off-by: hydai --- utils/docker/Dockerfile.manylinux1_x86_64 | 86 ---------- utils/docker/cmake-glibc-2.5.patch | 193 ---------------------- utils/docker/gcc-4.8.2-gcc-11.patch | 38 ----- utils/docker/llvm-glibc-2.5.patch | 39 ----- 4 files changed, 356 deletions(-) delete mode 100644 utils/docker/Dockerfile.manylinux1_x86_64 delete mode 100644 utils/docker/cmake-glibc-2.5.patch delete mode 100644 utils/docker/gcc-4.8.2-gcc-11.patch delete mode 100644 utils/docker/llvm-glibc-2.5.patch diff --git a/utils/docker/Dockerfile.manylinux1_x86_64 b/utils/docker/Dockerfile.manylinux1_x86_64 deleted file mode 100644 index 003456e3..00000000 --- a/utils/docker/Dockerfile.manylinux1_x86_64 +++ /dev/null @@ -1,86 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC - -FROM quay.io/pypa/manylinux1_x86_64 - -MAINTAINER hydai hydai@secondstate.io - -ADD SHA256SUM gcc-4.8.2-gcc-11.patch llvm-glibc-2.5.patch cmake-glibc-2.5.patch /root/ - -ENV PATH /toolchain/bin:$PATH -ENV CC gcc -ENV CXX g++ -ENV CPPFLAGS -I/toolchain/include -ENV LDFLAGS -L/toolchain/lib64 -ENV PKG_CONFIG_PATH /toolchain/lib64/pkgconfig - -RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build dpkg && \ - export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ - 'import multiprocessing; print(multiprocessing.cpu_count())') && \ - export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ - curl -s -L -O --remote-name-all \ - https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ - https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ - https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ - https://libisl.sourceforge.io/isl-0.24.tar.xz \ - https://github.com/facebook/zstd/releases/download/v1.5.0/zstd-1.5.0.tar.gz \ - https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ - https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ - https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/llvm-12.0.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/lld-12.0.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/libunwind-12.0.0.src.tar.xz && \ - sha256sum -c SHA256SUM && \ - xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ - xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ - gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ - xz -dc isl-0.24.tar.xz | tar -xf - && \ - gzip -dc zstd-1.5.0.tar.gz | tar -xf - && \ - xz -dc gcc-11.1.0.tar.xz | tar -xf - && \ - gzip -dc cmake-3.20.2.tar.gz | tar -xf - && \ - gzip -dc v1.10.2.tar.gz | tar -xf - && \ - xz -dc llvm-12.0.0.src.tar.xz | tar -xf - && \ - xz -dc lld-12.0.0.src.tar.xz | tar -xf - && \ - xz -dc libunwind-12.0.0.src.tar.xz | tar -xf - && \ - cd gcc-11.1.0 && patch -p1 < ../gcc-4.8.2-gcc-11.patch && cd - && \ - cd llvm-12.0.0.src && patch -p1 < ../llvm-glibc-2.5.patch && cd - && \ - cd cmake-3.20.2 && patch -p1 < ../cmake-glibc-2.5.patch && cd - && \ - mkdir build && cd build && ../gmp-6.2.1/configure --build=x86_64-pc-linux-gnu $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ - cd zstd-1.5.0 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ - mkdir build && cd build && ../gcc-11.1.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ - --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ - --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ - --disable-libmpx --disable-libsanitizer --disable-libunwind-exceptions \ - --disable-multilib --enable-__cxa_atexit --enable-gnu-indirect-function \ - --disable-gnu-unique-object --enable-initfini-array --enable-languages="c,c++,lto" \ - --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ - --with-default-libstdcxx-abi="gcc4-compatible" \ - --with-gcc-major-version-only --with-linker-hash-style="gnu" \ - --with-arch="x86-64" --with-tune="generic" && \ - make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ - echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.8 libstdc++.a )" \ - > /toolchain/lib64/libstdc++.so.6.0.29 && \ - export PATH="/toolchain/bin:$PATH" && \ - mkdir build && cd build && /opt/python/cp39-cp39/bin/python \ - ../ninja-1.10.2/configure.py --bootstrap \ - --with-python=/opt/python/cp39-cp39/bin/python && \ - cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.20.2/configure --prefix=/toolchain \ - --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v llvm-12.0.0.src llvm && \ - mv -v lld-12.0.0.src lld && \ - mv -v libunwind-12.0.0.src libunwind && \ - cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_INSTALL_PREFIX=/toolchain \ - -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ - -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ - -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" llvm && \ - cmake --build build --target install && \ - rm -rf build && rm -rf * - -RUN yum clean all diff --git a/utils/docker/cmake-glibc-2.5.patch b/utils/docker/cmake-glibc-2.5.patch deleted file mode 100644 index bada2453..00000000 --- a/utils/docker/cmake-glibc-2.5.patch +++ /dev/null @@ -1,193 +0,0 @@ -diff -rup a/Utilities/cmlibuv/src/unix/async.c b/Utilities/cmlibuv/src/unix/async.c ---- a/Utilities/cmlibuv/src/unix/async.c 2020-12-16 12:35:29.000000000 +0000 -+++ b/Utilities/cmlibuv/src/unix/async.c 2020-12-22 18:36:14.000000000 +0000 -@@ -34,7 +34,7 @@ - #include - #include /* sched_yield() */ - --#ifdef __linux__ -+#if defined(__linux__) && __GLIBC_PREREQ(2, 8) - #include - #endif - -@@ -175,7 +175,7 @@ static void uv__async_send(uv_loop_t* lo - len = 1; - fd = loop->async_wfd; - --#if defined(__linux__) -+#if defined(__linux__) && __GLIBC_PREREQ(2, 8) - if (fd == -1) { - static const uint64_t val = 1; - buf = &val; -@@ -206,7 +206,7 @@ static int uv__async_start(uv_loop_t* lo - if (loop->async_io_watcher.fd != -1) - return 0; - --#ifdef __linux__ -+#if defined(__linux__) && __GLIBC_PREREQ(2, 8) - err = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); - if (err < 0) - return UV__ERR(errno); -diff -rup a/Utilities/cmlibuv/src/unix/core.c b/Utilities/cmlibuv/src/unix/core.c ---- a/Utilities/cmlibuv/src/unix/core.c 2020-12-16 12:35:29.000000000 +0000 -+++ b/Utilities/cmlibuv/src/unix/core.c 2020-12-22 18:07:40.000000000 +0000 -@@ -88,7 +88,9 @@ extern char** environ; - - #if defined(__linux__) - # include --# define uv__accept4 accept4 -+# if __GLIBC_PREREQ(2, 10) -+# define uv__accept4 accept4 -+# endif - #endif - - static int uv__run_pending(uv_loop_t* loop); -@@ -1032,7 +1034,7 @@ int uv__open_cloexec(const char* path, i - - - int uv__dup2_cloexec(int oldfd, int newfd) { --#if defined(__FreeBSD__) || defined(__NetBSD__) || defined(__linux__) -+#if defined(__FreeBSD__) || defined(__NetBSD__) || (defined(__linux__) && __GLIBC_PREREQ(2, 9)) - int r; - - r = dup3(oldfd, newfd, O_CLOEXEC); -diff -rup a/Utilities/cmlibuv/src/unix/fs.c b/Utilities/cmlibuv/src/unix/fs.c ---- a/Utilities/cmlibuv/src/unix/fs.c 2020-12-16 12:35:29.000000000 +0000 -+++ b/Utilities/cmlibuv/src/unix/fs.c 2020-12-22 18:10:56.000000000 +0000 -@@ -224,7 +224,7 @@ UV_UNUSED(static struct timeval uv__fs_t - } - - static ssize_t uv__fs_futime(uv_fs_t* req) { --#if defined(__linux__) \ -+#if (defined(__linux__) && __GLIBC_PREREQ(2, 6)) \ - || defined(_AIX71) \ - || defined(__HAIKU__) - /* utimesat() has nanosecond resolution but we stick to microseconds -@@ -234,7 +234,8 @@ static ssize_t uv__fs_futime(uv_fs_t* re - ts[0] = uv__fs_to_timespec(req->atime); - ts[1] = uv__fs_to_timespec(req->mtime); - return futimens(req->file, ts); --#elif defined(__APPLE__) \ -+#elif (defined(__linux__) && !__GLIBC_PREREQ(2, 6)) \ -+ || defined(__APPLE__) \ - || defined(__DragonFly__) \ - || defined(__FreeBSD__) \ - || defined(__FreeBSD_kernel__) \ -@@ -1016,7 +1017,7 @@ ok: - - - static ssize_t uv__fs_utime(uv_fs_t* req) { --#if defined(__linux__) \ -+#if (defined(__linux__) && __GLIBC_PREREQ(2, 6)) \ - || defined(_AIX71) \ - || defined(__sun) \ - || defined(__HAIKU__) -@@ -1027,7 +1028,8 @@ static ssize_t uv__fs_utime(uv_fs_t* req - ts[0] = uv__fs_to_timespec(req->atime); - ts[1] = uv__fs_to_timespec(req->mtime); - return utimensat(AT_FDCWD, req->path, ts, 0); --#elif defined(__APPLE__) \ -+#elif (defined(__linux__) && !__GLIBC_PREREQ(2, 6)) \ -+ || defined(__APPLE__) \ - || defined(__DragonFly__) \ - || defined(__FreeBSD__) \ - || defined(__FreeBSD_kernel__) \ -@@ -1059,7 +1061,7 @@ static ssize_t uv__fs_utime(uv_fs_t* req - - - static ssize_t uv__fs_lutime(uv_fs_t* req) { --#if defined(__linux__) || \ -+#if (defined(__linux__) && __GLIBC_PREREQ(2, 6)) || \ - defined(_AIX71) || \ - defined(__sun) || \ - defined(__HAIKU__) -@@ -1067,7 +1069,8 @@ static ssize_t uv__fs_lutime(uv_fs_t* re - ts[0] = uv__fs_to_timespec(req->atime); - ts[1] = uv__fs_to_timespec(req->mtime); - return utimensat(AT_FDCWD, req->path, ts, AT_SYMLINK_NOFOLLOW); --#elif defined(__APPLE__) || \ -+#elif (defined(__linux__) && !__GLIBC_PREREQ(2, 6)) || \ -+ defined(__APPLE__) || \ - defined(__DragonFly__) || \ - defined(__FreeBSD__) || \ - defined(__FreeBSD_kernel__) || \ -diff -rup a/Utilities/cmlibuv/src/unix/linux-core.c b/Utilities/cmlibuv/src/unix/linux-core.c ---- a/Utilities/cmlibuv/src/unix/linux-core.c 2020-12-16 12:35:29.000000000 +0000 -+++ b/Utilities/cmlibuv/src/unix/linux-core.c 2020-12-22 18:13:06.000000000 +0000 -@@ -85,7 +85,12 @@ static uint64_t read_cpufreq(unsigned in - - int uv__platform_loop_init(uv_loop_t* loop) { - int fd; -+#if __GLIBC_PREREQ(2, 9) - fd = epoll_create1(O_CLOEXEC); -+#else -+ fd = -1; -+ errno = ENOSYS; -+#endif - - /* epoll_create1() can fail either because it's not implemented (old kernel) - * or because it doesn't understand the O_CLOEXEC flag. -@@ -311,11 +316,16 @@ void uv__io_poll(uv_loop_t* loop, int ti - abort(); - - if (no_epoll_wait != 0 || (sigmask != 0 && no_epoll_pwait == 0)) { -+#if __GLIBC_PREREQ(2, 6) - nfds = epoll_pwait(loop->backend_fd, - events, - ARRAY_SIZE(events), - timeout, - &sigset); -+#else -+ nfds = -1; -+ errno = ENOSYS; -+#endif - if (nfds == -1 && errno == ENOSYS) { - uv__store_relaxed(&no_epoll_pwait_cached, 1); - no_epoll_pwait = 1; -diff -rup a/Utilities/cmlibuv/src/unix/linux-inotify.c b/Utilities/cmlibuv/src/unix/linux-inotify.c ---- a/Utilities/cmlibuv/src/unix/linux-inotify.c 2020-12-16 12:35:29.000000000 +0000 -+++ b/Utilities/cmlibuv/src/unix/linux-inotify.c 2020-12-22 18:16:16.000000000 +0000 -@@ -71,10 +71,22 @@ static int init_inotify(uv_loop_t* loop) - if (loop->inotify_fd != -1) - return 0; - -+#if __GLIBC_PREREQ(2, 6) - fd = inotify_init1(IN_NONBLOCK | IN_CLOEXEC); -+#else -+ fd = inotify_init(); -+#endif -+ - if (fd < 0) - return UV__ERR(errno); - -+#if !__GLIBC_PREREQ(2, 6) -+ if (uv__nonblock(fd, 1) || uv__cloexec(fd, 1)) { -+ uv__close(fd); -+ return UV__ERR(errno); -+ } -+#endif -+ - loop->inotify_fd = fd; - uv__io_init(&loop->inotify_read_watcher, uv__inotify_read, loop->inotify_fd); - uv__io_start(loop, &loop->inotify_read_watcher, POLLIN); -diff -rup a/Utilities/cmlibuv/src/unix/process.c b/Utilities/cmlibuv/src/unix/process.c ---- a/Utilities/cmlibuv/src/unix/process.c 2020-12-16 12:35:29.000000000 +0000 -+++ b/Utilities/cmlibuv/src/unix/process.c 2020-12-22 18:23:18.000000000 +0000 -@@ -124,7 +124,7 @@ static void uv__chld(uv_signal_t* handle - - - static int uv__make_socketpair(int fds[2]) { --#if defined(__FreeBSD__) || defined(__linux__) -+#if defined(__FreeBSD__) || (defined(__linux__) && __GLIBC_PREREQ(2, 9)) - if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, fds)) - return UV__ERR(errno); - -@@ -151,7 +151,7 @@ static int uv__make_socketpair(int fds[2 - - - int uv__make_pipe(int fds[2], int flags) { --#if defined(__FreeBSD__) || defined(__linux__) -+#if defined(__FreeBSD__) || (defined(__linux__) && __GLIBC_PREREQ(2, 9)) - if (pipe2(fds, flags | O_CLOEXEC)) - return UV__ERR(errno); - diff --git a/utils/docker/gcc-4.8.2-gcc-11.patch b/utils/docker/gcc-4.8.2-gcc-11.patch deleted file mode 100644 index 00e7cbdb..00000000 --- a/utils/docker/gcc-4.8.2-gcc-11.patch +++ /dev/null @@ -1,38 +0,0 @@ ---- a/gcc/splay-tree-utils.h 2021-05-14 06:09:43.289274290 +0000 -+++ b/gcc/splay-tree-utils.h 2021-05-14 06:24:13.628159368 +0000 -@@ -105,7 +105,11 @@ template - class base_splay_tree : protected Accessors - { - public: -+#if __GNUC__ > 4 - using typename Accessors::node_type; -+#else -+ using node_type = typename Accessors::node_type; -+#endif - - // INDEX is either 0 or 1. If INDEX is 0, insert CHILD immediately - // before NODE, otherwise insert CHILD immediately after NODE. -@@ -148,7 +152,11 @@ class rooted_splay_tree : public base_sp - using parent = base_splay_tree; - - public: -+#if __GNUC__ > 4 - using typename Accessors::node_type; -+#else -+ using node_type = typename Accessors::node_type; -+#endif - - protected: - // The root of the splay tree, or node_type () if the tree is empty. -@@ -409,7 +417,11 @@ class rootless_splay_tree - public: - using rooted = rooted_splay_tree; - -+#if __GNUC__ > 4 - using typename Accessors::node_type; -+#else -+ using node_type = typename Accessors::node_type; -+#endif - - // Remove NODE from the splay tree. Return the node that replaces it, - // or null if NODE had no children. diff --git a/utils/docker/llvm-glibc-2.5.patch b/utils/docker/llvm-glibc-2.5.patch deleted file mode 100644 index 54cac8ea..00000000 --- a/utils/docker/llvm-glibc-2.5.patch +++ /dev/null @@ -1,39 +0,0 @@ ---- a/lib/Support/Host.cpp 2020-12-17 20:09:25.321395012 +0000 -+++ b/lib/Support/Host.cpp 2020-12-17 20:29:40.296551916 +0000 -@@ -1225,6 +1225,15 @@ StringRef sys::getHostCPUName() { return - #endif - - #if defined(__linux__) && (defined(__i386__) || defined(__x86_64__)) -+#if !defined(CPU_COUNT) -+static inline auto CPU_COUNT(const cpu_set_t *Set) noexcept { -+ int Count = 0; -+ for (const auto &Bits : Set->__bits) { -+ Count += __builtin_popcountl(Bits); -+ } -+ return Count; -+} -+#endif - // On Linux, the number of physical cores can be computed from /proc/cpuinfo, - // using the number of unique physical/core id pairs. The following - // implementation reads the /proc/cpuinfo format on an x86_64 system. - ---- a/lib/Support/Unix/Threading.inc 2020-12-17 20:09:25.325395024 +0000 -+++ b/lib/Support/Unix/Threading.inc 2020-12-17 20:24:57.267834738 +0000 -@@ -281,6 +281,16 @@ SetThreadPriorityResult llvm::set_thread - - #include - -+#if !defined(CPU_COUNT) -+static inline auto CPU_COUNT(const cpu_set_t *Set) noexcept { -+ int Count = 0; -+ for (const auto &Bits : Set->__bits) { -+ Count += __builtin_popcountl(Bits); -+ } -+ return Count; -+} -+#endif -+ - int computeHostNumHardwareThreads() { - #ifdef __linux__ - cpu_set_t Set; - From 40bdcaa07c7458ae7662ede2802670271f1dea81 Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 24 Aug 2022 03:45:33 +0800 Subject: [PATCH 068/623] [Utils] Remove manylinux2010 dockerfile Signed-off-by: hydai --- utils/docker/Dockerfile.manylinux2010_x86_64 | 83 -------------------- 1 file changed, 83 deletions(-) delete mode 100644 utils/docker/Dockerfile.manylinux2010_x86_64 diff --git a/utils/docker/Dockerfile.manylinux2010_x86_64 b/utils/docker/Dockerfile.manylinux2010_x86_64 deleted file mode 100644 index e17232ad..00000000 --- a/utils/docker/Dockerfile.manylinux2010_x86_64 +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC - -FROM quay.io/pypa/manylinux2010_x86_64 - -MAINTAINER hydai hydai@secondstate.io - -ADD SHA256SUM /root/ - -ENV PATH /toolchain/bin:$PATH -ENV CC gcc -ENV CXX g++ -ENV CPPFLAGS -I/toolchain/include -ENV LDFLAGS -L/toolchain/lib64 -ENV PKG_CONFIG_PATH /toolchain/lib64/pkgconfig - -RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build dpkg && \ - export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ - 'import multiprocessing; print(multiprocessing.cpu_count())') && \ - export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ - curl -s -L -O --remote-name-all \ - https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ - https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ - https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ - https://libisl.sourceforge.io/isl-0.24.tar.xz \ - https://github.com/facebook/zstd/releases/download/v1.5.0/zstd-1.5.0.tar.gz \ - https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ - https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ - https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/llvm-12.0.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/lld-12.0.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/libunwind-12.0.0.src.tar.xz && \ - sha256sum -c SHA256SUM && \ - xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ - xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ - gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ - xz -dc isl-0.24.tar.xz | tar -xf - && \ - gzip -dc zstd-1.5.0.tar.gz | tar -xf - && \ - xz -dc gcc-11.1.0.tar.xz | tar -xf - && \ - gzip -dc cmake-3.20.2.tar.gz | tar -xf - && \ - gzip -dc v1.10.2.tar.gz | tar -xf - && \ - xz -dc llvm-12.0.0.src.tar.xz | tar -xf - && \ - xz -dc lld-12.0.0.src.tar.xz | tar -xf - && \ - xz -dc libunwind-12.0.0.src.tar.xz | tar -xf - && \ - mkdir build && cd build && ../gmp-6.2.1/configure --build=x86_64-pc-linux-gnu $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ - cd zstd-1.5.0 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ - mkdir build && cd build && ../gcc-11.1.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ - --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ - --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ - --disable-libmpx --disable-libsanitizer --disable-libunwind-exceptions \ - --disable-multilib --enable-__cxa_atexit --enable-gnu-indirect-function \ - --enable-gnu-unique-object --enable-initfini-array --enable-languages="c,c++,lto" \ - --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ - --with-default-libstdcxx-abi="gcc4-compatible" \ - --with-gcc-major-version-only --with-linker-hash-style="gnu" \ - --with-arch="x86-64" --with-tune="generic" && \ - make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ - echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.13 libstdc++.a )" \ - > /toolchain/lib64/libstdc++.so.6.0.29 && \ - export PATH="/toolchain/bin:$PATH" && \ - mkdir build && cd build && /opt/python/cp39-cp39/bin/python \ - ../ninja-1.10.2/configure.py --bootstrap \ - --with-python=/opt/python/cp39-cp39/bin/python && \ - cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.20.2/configure --prefix=/toolchain \ - --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v llvm-12.0.0.src llvm && \ - mv -v lld-12.0.0.src lld && \ - mv -v libunwind-12.0.0.src libunwind && \ - cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_INSTALL_PREFIX=/toolchain \ - -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ - -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ - -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" llvm && \ - cmake --build build --target install && \ - rm -rf build && rm -rf * - -RUN yum clean all From 9f9448cf701a17ba9e8e5ee3fe3faaa7997aaa38 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 24 Aug 2022 03:09:11 +0800 Subject: [PATCH 069/623] [Plugins] Refine the host function headers. Signed-off-by: YiYing He --- plugins/httpsreq/httpsreqbase.h | 5 +++-- plugins/httpsreq/httpsreqenv.h | 7 ++----- plugins/httpsreq/httpsreqfunc.cpp | 12 ++++++++---- plugins/httpsreq/httpsreqfunc.h | 5 +---- plugins/httpsreq/httpsreqmodule.cpp | 4 +++- plugins/httpsreq/httpsreqmodule.h | 4 +++- plugins/wasi_nn/wasinnbase.h | 1 - plugins/wasi_nn/wasinnfunc.h | 2 +- plugins/wasmedge_process/processbase.h | 4 ++-- plugins/wasmedge_process/processenv.cpp | 4 +++- plugins/wasmedge_process/processenv.h | 1 + plugins/wasmedge_process/processfunc.cpp | 2 ++ plugins/wasmedge_process/processfunc.h | 4 +++- plugins/wasmedge_process/processmodule.h | 1 + 14 files changed, 33 insertions(+), 23 deletions(-) diff --git a/plugins/httpsreq/httpsreqbase.h b/plugins/httpsreq/httpsreqbase.h index e7fcd4bd..c8b020f7 100644 --- a/plugins/httpsreq/httpsreqbase.h +++ b/plugins/httpsreq/httpsreqbase.h @@ -3,8 +3,9 @@ #pragma once -#include "common/errcode.h" #include "httpsreqenv.h" + +#include "common/errcode.h" #include "runtime/callingframe.h" #include "runtime/hostfunc.h" @@ -21,4 +22,4 @@ template class HttpsReq : public Runtime::HostFunction { }; } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/httpsreq/httpsreqenv.h b/plugins/httpsreq/httpsreqenv.h index 047604cd..17495598 100644 --- a/plugins/httpsreq/httpsreqenv.h +++ b/plugins/httpsreq/httpsreqenv.h @@ -4,12 +4,9 @@ #pragma once #include "plugin/plugin.h" -#include "po/argument_parser.h" -#include "po/list.h" -#include "po/option.h" + #include #include -#include namespace WasmEdge { namespace Host { @@ -26,4 +23,4 @@ class HttpsReqEnvironment { }; } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/httpsreq/httpsreqfunc.cpp b/plugins/httpsreq/httpsreqfunc.cpp index ff989357..e3cd55bf 100644 --- a/plugins/httpsreq/httpsreqfunc.cpp +++ b/plugins/httpsreq/httpsreqfunc.cpp @@ -2,13 +2,17 @@ // SPDX-FileCopyrightText: 2019-2022 Second State INC #include "httpsreqfunc.h" -#include -#include -#include + #include #include + +#include +#include +#include +#include #include #include +#include #include // Some of the code was taken from this post: @@ -141,4 +145,4 @@ Expect HttpsReqGetRcvLen::body(const Runtime::CallingFrame &) { } } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/httpsreq/httpsreqfunc.h b/plugins/httpsreq/httpsreqfunc.h index 8f085ba5..300458af 100644 --- a/plugins/httpsreq/httpsreqfunc.h +++ b/plugins/httpsreq/httpsreqfunc.h @@ -3,11 +3,8 @@ #pragma once -#include "common/defines.h" #include "httpsreqbase.h" -#include - namespace WasmEdge { namespace Host { @@ -32,4 +29,4 @@ class HttpsReqGetRcvLen : public HttpsReq { }; } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/httpsreq/httpsreqmodule.cpp b/plugins/httpsreq/httpsreqmodule.cpp index 868f4458..c95111fb 100644 --- a/plugins/httpsreq/httpsreqmodule.cpp +++ b/plugins/httpsreq/httpsreqmodule.cpp @@ -4,6 +4,8 @@ #include "httpsreqmodule.h" #include "httpsreqfunc.h" +#include + namespace WasmEdge { namespace Host { @@ -15,4 +17,4 @@ HttpsReqModule::HttpsReqModule() : ModuleInstance("httpsreq") { } } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/httpsreq/httpsreqmodule.h b/plugins/httpsreq/httpsreqmodule.h index 60f545ef..89ac524c 100644 --- a/plugins/httpsreq/httpsreqmodule.h +++ b/plugins/httpsreq/httpsreqmodule.h @@ -4,7 +4,9 @@ #pragma once #include "httpsreqenv.h" + #include "runtime/instance/module.h" + #include namespace WasmEdge { @@ -21,4 +23,4 @@ class HttpsReqModule : public Runtime::Instance::ModuleInstance { }; } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnbase.h b/plugins/wasi_nn/wasinnbase.h index 698b6049..b5daaf84 100644 --- a/plugins/wasi_nn/wasinnbase.h +++ b/plugins/wasi_nn/wasinnbase.h @@ -4,7 +4,6 @@ #pragma once #include "common/errcode.h" -#include "runtime/callingframe.h" #include "runtime/hostfunc.h" #include "wasinnenv.h" diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h index d8d9ffd1..2a1e7462 100644 --- a/plugins/wasi_nn/wasinnfunc.h +++ b/plugins/wasi_nn/wasinnfunc.h @@ -3,7 +3,7 @@ #pragma once -#include "runtime/instance/memory.h" +#include "runtime/callingframe.h" #include "wasinnbase.h" #include diff --git a/plugins/wasmedge_process/processbase.h b/plugins/wasmedge_process/processbase.h index ba98ae74..12c39357 100644 --- a/plugins/wasmedge_process/processbase.h +++ b/plugins/wasmedge_process/processbase.h @@ -3,9 +3,9 @@ #pragma once -#include "common/errcode.h" #include "processenv.h" -#include "runtime/callingframe.h" + +#include "common/errcode.h" #include "runtime/hostfunc.h" namespace WasmEdge { diff --git a/plugins/wasmedge_process/processenv.cpp b/plugins/wasmedge_process/processenv.cpp index b314addc..483c3fdd 100644 --- a/plugins/wasmedge_process/processenv.cpp +++ b/plugins/wasmedge_process/processenv.cpp @@ -2,8 +2,10 @@ // SPDX-FileCopyrightText: 2019-2022 Second State INC #include "processenv.h" -#include "po/helper.h" #include "processmodule.h" + +#include "po/helper.h" + #include namespace WasmEdge { diff --git a/plugins/wasmedge_process/processenv.h b/plugins/wasmedge_process/processenv.h index 5bbdd7ec..f1b4ddbc 100644 --- a/plugins/wasmedge_process/processenv.h +++ b/plugins/wasmedge_process/processenv.h @@ -7,6 +7,7 @@ #include "po/argument_parser.h" #include "po/list.h" #include "po/option.h" + #include #include #include diff --git a/plugins/wasmedge_process/processfunc.cpp b/plugins/wasmedge_process/processfunc.cpp index 3852c596..ed69a03d 100644 --- a/plugins/wasmedge_process/processfunc.cpp +++ b/plugins/wasmedge_process/processfunc.cpp @@ -3,6 +3,8 @@ #include "processfunc.h" +#include "common/defines.h" + #if WASMEDGE_OS_LINUX || WASMEDGE_OS_MACOS #include #include diff --git a/plugins/wasmedge_process/processfunc.h b/plugins/wasmedge_process/processfunc.h index 072d0c54..f23a4c41 100644 --- a/plugins/wasmedge_process/processfunc.h +++ b/plugins/wasmedge_process/processfunc.h @@ -3,8 +3,10 @@ #pragma once -#include "common/defines.h" #include "processbase.h" + +#include "runtime/callingframe.h" + #include namespace WasmEdge { diff --git a/plugins/wasmedge_process/processmodule.h b/plugins/wasmedge_process/processmodule.h index 491ae2ce..0a8e5bac 100644 --- a/plugins/wasmedge_process/processmodule.h +++ b/plugins/wasmedge_process/processmodule.h @@ -4,6 +4,7 @@ #pragma once #include "processenv.h" + #include "runtime/instance/module.h" namespace WasmEdge { From 21b11caf94231c6ecae2b5fc2378ffe92437973b Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 24 Aug 2022 05:15:02 +0800 Subject: [PATCH 070/623] [Misc] Use wasmedge_shared instead of wasmedge_c_shared Signed-off-by: hydai --- plugins/httpsreq/CMakeLists.txt | 2 +- plugins/wasi_crypto/CMakeLists.txt | 4 ++-- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasmedge_process/CMakeLists.txt | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/plugins/httpsreq/CMakeLists.txt b/plugins/httpsreq/CMakeLists.txt index e7cfca95..d32cf3e6 100644 --- a/plugins/httpsreq/CMakeLists.txt +++ b/plugins/httpsreq/CMakeLists.txt @@ -44,7 +44,7 @@ if(WASMEDGE_LINK_PUGLINS_STATIC) else() target_link_libraries(wasmedgePluginHttpsReq PRIVATE - wasmedge_c_shared + wasmedge_shared ) endif() diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt index 3b770f07..4eacb87e 100644 --- a/plugins/wasi_crypto/CMakeLists.txt +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -72,7 +72,7 @@ target_include_directories(wasmedgePluginWasiCrypto target_link_libraries(wasmedgePluginWasiCrypto PUBLIC - wasmedge_c_shared + wasmedge_shared OpenSSL::Crypto ) if(WASMEDGE_LINK_PUGLINS_STATIC) @@ -83,7 +83,7 @@ if(WASMEDGE_LINK_PUGLINS_STATIC) else() target_link_libraries(wasmedgePluginWasiCrypto PRIVATE - wasmedge_c_shared + wasmedge_shared ) endif() diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 63807824..efce12db 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -35,7 +35,7 @@ if(WASMEDGE_LINK_PUGLINS_STATIC) else() target_link_libraries(wasmedgePluginWasiNN PRIVATE - wasmedge_c_shared + wasmedge_shared ) endif() diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt index 50f00713..8c3db48b 100644 --- a/plugins/wasmedge_process/CMakeLists.txt +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -35,7 +35,7 @@ if(WASMEDGE_LINK_PUGLINS_STATIC) else() target_link_libraries(wasmedgePluginWasmEdgeProcess PRIVATE - wasmedge_c_shared + wasmedge_shared ) endif() From bcb4afb6e3cbcca428811054010d8438d7de980e Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 26 Aug 2022 16:03:28 +0800 Subject: [PATCH 071/623] [Docker] Remove appdev images Signed-off-by: dm4 --- utils/docker/Dockerfile.appdev_aarch64 | 23 -------- utils/docker/Dockerfile.appdev_x86_64 | 28 ---------- utils/docker/build-appdev.md | 76 -------------------------- 3 files changed, 127 deletions(-) delete mode 100644 utils/docker/Dockerfile.appdev_aarch64 delete mode 100644 utils/docker/Dockerfile.appdev_x86_64 delete mode 100644 utils/docker/build-appdev.md diff --git a/utils/docker/Dockerfile.appdev_aarch64 b/utils/docker/Dockerfile.appdev_aarch64 deleted file mode 100644 index 0fc33825..00000000 --- a/utils/docker/Dockerfile.appdev_aarch64 +++ /dev/null @@ -1,23 +0,0 @@ -FROM ubuntu:21.04 - -RUN apt-get update &&\ - DEBIAN_FRONTEND=noninteractive apt-get install -y wget git curl software-properties-common golang - -RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain stable -y -ENV PATH=/root/.cargo/bin:$PATH -RUN rustup target add wasm32-wasi - -RUN curl https://raw.githubusercontent.com/second-state/rustwasmc/master/installer/init.sh -sSf | sh - -RUN curl -sL https://deb.nodesource.com/setup_12.x | bash - -RUN apt-get install -y nodejs -RUN npm install wasmedge-core - -RUN mkdir -p /root/examples -WORKDIR /root/examples -RUN wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/wasm/hello.wasm &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/qjs.wasm &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/hello.js &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/repl.js - -ENTRYPOINT ["/bin/bash", "-l"] diff --git a/utils/docker/Dockerfile.appdev_x86_64 b/utils/docker/Dockerfile.appdev_x86_64 deleted file mode 100644 index e97b245a..00000000 --- a/utils/docker/Dockerfile.appdev_x86_64 +++ /dev/null @@ -1,28 +0,0 @@ -FROM ubuntu:21.04 - -RUN apt-get update &&\ - DEBIAN_FRONTEND=noninteractive apt-get install -y wget git curl software-properties-common golang - -RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain stable -y -ENV PATH=/root/.cargo/bin:$PATH -RUN rustup target add wasm32-wasi - -RUN curl https://raw.githubusercontent.com/second-state/rustwasmc/master/installer/init.sh -sSf | sh - -RUN curl -sL https://deb.nodesource.com/setup_12.x | bash - -RUN apt-get install -y nodejs -RUN npm install wasmedge-extensions - -RUN mkdir -p /root/examples -WORKDIR /root/examples -RUN wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/wasm/hello.wasm &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/qjs.wasm &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/hello.js &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/repl.js &&\ - wget https://github.com/WasmEdge/WasmEdge/raw/master/examples/js/qjs_tf.wasm &&\ - wget -O tf_image_classify.js https://raw.githubusercontent.com/second-state/wasmedge-quickjs/main/example_js/tensorflow_lite_demo/main.js &&\ - wget https://github.com/second-state/wasmedge-quickjs/raw/main/example_js/tensorflow_lite_demo/lite-model_aiy_vision_classifier_food_V1_1.tflite &&\ - wget https://github.com/second-state/wasmedge-quickjs/raw/main/example_js/tensorflow_lite_demo/food.jpg &&\ - wget https://github.com/second-state/wasmedge-quickjs/raw/main/example_js/tensorflow_lite_demo/aiy_food_V1_labelmap.txt - -ENTRYPOINT ["/bin/bash", "-l"] diff --git a/utils/docker/build-appdev.md b/utils/docker/build-appdev.md deleted file mode 100644 index 74cdf34f..00000000 --- a/utils/docker/build-appdev.md +++ /dev/null @@ -1,76 +0,0 @@ -# Use the appdev Docker images - -The `appdev` Docker images provide a complete WasmEdge application development environment. To use it, do the following. - -## On x86_64 machines - -```bash -$ docker pull wasmedge/appdev_x86_64:0.9.0 -$ docker run --rm -v $(pwd):/app -it wasmedge/appdev_x86_64:0.9.0 -(docker) # -``` - -## On arm64 machines - -```bash -$ docker pull wasmedge/appdev_aarch64:0.9.0 -$ docker run --rm -v $(pwd):/app -it wasmedge/appdev_aarch64:0.9.0 -(docker) # -``` - -It installs the following components. - -* WasmEdge CLI and shared libraries -* WasmEdge with Tensorflow extension CLI and libraries (x86_64 only) -* Golang -* Rust -* Node.js with WasmEdge addons -* Examples in the `/root/examples/` folder - -## Examples - -Hello World. [See more simple examples](https://github.com/WasmEdge/WasmEdge/tree/master/examples/wasm) - -```bash -$ wasmedge hello.wasm world -hello -world -``` - -Use AOT to run it *much faster*. - -```bash -$ wasmedgec hello.wasm hello.so -$ wasmedge hello.so world -hello -world -``` - -Here are some JavaScript examples. [See more](https://github.com/WasmEdge/WasmEdge/tree/master/examples/js) - -```bash -$ wasmedge --dir .:. qjs.wasm hello.js 1 2 3 -Hello 1 2 3 - -$ wasmedge-tensorflow-lite --dir .:. qjs_tf.wasm tf_image_classify.js -label: Hot dog -confidence: 0.8941176470588236 -``` - -## Build and publish the appdev images - -Run these commands to build and publish the `appdev` Docker images. - -### Build on an x86_64 machine - -```bash -docker build -t wasmedge/appdev_x86_64:0.9.0 -f Dockerfile.appdev_x86_64 ./ -docker image push wasmedge/appdev_x86_64:0.9.0 -``` - -### Build on an ARM64 / aarch64 machine - -```bash -docker build -t wasmedge/appdev_aarch64:0.9.0 -f Dockerfile.appdev_aarch64 ./ -docker image push wasmedge/appdev_aarch64:0.9.0 -``` From d018077ed7f40a28b67cc2a4a55d2bd25d7bf904 Mon Sep 17 00:00:00 2001 From: Zhou Zhou Date: Thu, 1 Sep 2022 15:39:11 +0800 Subject: [PATCH 072/623] [Plugin] refine the httpsreq plugin implementation (#1802) * [Rename] refine the httpsreq (#1768) (#1788) * [Misc] refine the httpsreq (#1768) (#1788) * [Misc] modify the head File * [Test] rename the httpsreq module name in tests * [Misc] refine the implementations * [Misc] refine the details * [Misc] rename Wasmedge to WasmEdge * [Docs] update the docs for httpsreq plugin Signed-off-by: zhouzhou --- plugins/CMakeLists.txt | 2 +- plugins/httpsreq/httpsreqfunc.h | 32 -------- plugins/httpsreq/httpsreqmodule.cpp | 20 ----- .../CMakeLists.txt | 0 .../httpsreqbase.h | 6 +- .../httpsreqenv.cpp | 8 +- .../httpsreqenv.h | 5 +- .../httpsreqfunc.cpp | 74 +++++++++---------- plugins/wasmedge_httpsreq/httpsreqfunc.h | 37 ++++++++++ plugins/wasmedge_httpsreq/httpsreqmodule.cpp | 22 ++++++ .../httpsreqmodule.h | 10 +-- test/plugins/CMakeLists.txt | 2 +- .../CMakeLists.txt | 0 .../httpsreq.cpp | 62 ++++++++-------- 14 files changed, 135 insertions(+), 145 deletions(-) delete mode 100644 plugins/httpsreq/httpsreqfunc.h delete mode 100644 plugins/httpsreq/httpsreqmodule.cpp rename plugins/{httpsreq => wasmedge_httpsreq}/CMakeLists.txt (100%) rename plugins/{httpsreq => wasmedge_httpsreq}/httpsreqbase.h (68%) rename plugins/{httpsreq => wasmedge_httpsreq}/httpsreqenv.cpp (78%) rename plugins/{httpsreq => wasmedge_httpsreq}/httpsreqenv.h (80%) rename plugins/{httpsreq => wasmedge_httpsreq}/httpsreqfunc.cpp (53%) create mode 100644 plugins/wasmedge_httpsreq/httpsreqfunc.h create mode 100644 plugins/wasmedge_httpsreq/httpsreqmodule.cpp rename plugins/{httpsreq => wasmedge_httpsreq}/httpsreqmodule.h (60%) rename test/plugins/{httpsreq => wasmedge_httpsreq}/CMakeLists.txt (100%) rename test/plugins/{httpsreq => wasmedge_httpsreq}/httpsreq.cpp (70%) diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 5f228529..5bc0f845 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -12,7 +12,7 @@ endif() if(WASMEDGE_PLUGIN_HTTPSREQ) if(CMAKE_SYSTEM_NAME MATCHES "Linux") - add_subdirectory(httpsreq) + add_subdirectory(wasmedge_httpsreq) endif() endif() diff --git a/plugins/httpsreq/httpsreqfunc.h b/plugins/httpsreq/httpsreqfunc.h deleted file mode 100644 index 300458af..00000000 --- a/plugins/httpsreq/httpsreqfunc.h +++ /dev/null @@ -1,32 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#pragma once - -#include "httpsreqbase.h" - -namespace WasmEdge { -namespace Host { - -class SendData : public HttpsReq { -public: - SendData(HttpsReqEnvironment &HostEnv) : HttpsReq(HostEnv) {} - Expect body(const Runtime::CallingFrame &, uint32_t HostPtr, - uint32_t HostLen, uint32_t Port, uint32_t BodyPtr, - uint32_t BodyLen); -}; - -class HttpsReqGetRcv : public HttpsReq { -public: - HttpsReqGetRcv(HttpsReqEnvironment &HostEnv) : HttpsReq(HostEnv) {} - Expect body(const Runtime::CallingFrame &, uint32_t BufPtr); -}; - -class HttpsReqGetRcvLen : public HttpsReq { -public: - HttpsReqGetRcvLen(HttpsReqEnvironment &HostEnv) : HttpsReq(HostEnv) {} - Expect body(const Runtime::CallingFrame &); -}; - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/httpsreq/httpsreqmodule.cpp b/plugins/httpsreq/httpsreqmodule.cpp deleted file mode 100644 index c95111fb..00000000 --- a/plugins/httpsreq/httpsreqmodule.cpp +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#include "httpsreqmodule.h" -#include "httpsreqfunc.h" - -#include - -namespace WasmEdge { -namespace Host { - -/// Register your functions in module. -HttpsReqModule::HttpsReqModule() : ModuleInstance("httpsreq") { - addHostFunc("send_data", std::make_unique(Env)); - addHostFunc("get_rcv", std::make_unique(Env)); - addHostFunc("get_rcv_len", std::make_unique(Env)); -} - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/httpsreq/CMakeLists.txt b/plugins/wasmedge_httpsreq/CMakeLists.txt similarity index 100% rename from plugins/httpsreq/CMakeLists.txt rename to plugins/wasmedge_httpsreq/CMakeLists.txt diff --git a/plugins/httpsreq/httpsreqbase.h b/plugins/wasmedge_httpsreq/httpsreqbase.h similarity index 68% rename from plugins/httpsreq/httpsreqbase.h rename to plugins/wasmedge_httpsreq/httpsreqbase.h index c8b020f7..d11b96c2 100644 --- a/plugins/httpsreq/httpsreqbase.h +++ b/plugins/wasmedge_httpsreq/httpsreqbase.h @@ -12,13 +12,13 @@ namespace WasmEdge { namespace Host { -template class HttpsReq : public Runtime::HostFunction { +template class WasmEdgeHttpsReq : public Runtime::HostFunction { public: - HttpsReq(HttpsReqEnvironment &HostEnv) + WasmEdgeHttpsReq(WasmEdgeHttpsReqEnvironment &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} protected: - HttpsReqEnvironment &Env; + WasmEdgeHttpsReqEnvironment &Env; }; } // namespace Host diff --git a/plugins/httpsreq/httpsreqenv.cpp b/plugins/wasmedge_httpsreq/httpsreqenv.cpp similarity index 78% rename from plugins/httpsreq/httpsreqenv.cpp rename to plugins/wasmedge_httpsreq/httpsreqenv.cpp index 6b2cc318..3587ed27 100644 --- a/plugins/httpsreq/httpsreqenv.cpp +++ b/plugins/wasmedge_httpsreq/httpsreqenv.cpp @@ -10,11 +10,11 @@ namespace Host { namespace { Runtime::Instance::ModuleInstance *create(void) noexcept { - return new HttpsReqModule; + return new WasmEdgeHttpsReqModule; } Plugin::Plugin::PluginDescriptor Descriptor{ - .Name = "https_req", + .Name = "wasmedge_httpsreq", .Description = "", .APIVersion = Plugin::Plugin::CurrentAPIVersion, .Version = {0, 10, 1, 0}, @@ -22,7 +22,7 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .ModuleDescriptions = (Plugin::PluginModule::ModuleDescriptor[]){ { - .Name = "https_req", + .Name = "wasmedge_httpsreq", .Description = "", .Create = create, }, @@ -32,7 +32,7 @@ Plugin::Plugin::PluginDescriptor Descriptor{ } // namespace -Plugin::PluginRegister HttpsReqEnvironment::Register(&Descriptor); +Plugin::PluginRegister WasmEdgeHttpsReqEnvironment::Register(&Descriptor); } // namespace Host } // namespace WasmEdge diff --git a/plugins/httpsreq/httpsreqenv.h b/plugins/wasmedge_httpsreq/httpsreqenv.h similarity index 80% rename from plugins/httpsreq/httpsreqenv.h rename to plugins/wasmedge_httpsreq/httpsreqenv.h index 17495598..7b1781b1 100644 --- a/plugins/httpsreq/httpsreqenv.h +++ b/plugins/wasmedge_httpsreq/httpsreqenv.h @@ -11,11 +11,8 @@ namespace WasmEdge { namespace Host { -class HttpsReqEnvironment { +class WasmEdgeHttpsReqEnvironment { public: - std::string Host; - uint32_t Port; - std::string BodyStr; std::string Rcv; /// Initial Configurations diff --git a/plugins/httpsreq/httpsreqfunc.cpp b/plugins/wasmedge_httpsreq/httpsreqfunc.cpp similarity index 53% rename from plugins/httpsreq/httpsreqfunc.cpp rename to plugins/wasmedge_httpsreq/httpsreqfunc.cpp index e3cd55bf..aa6d32f2 100644 --- a/plugins/httpsreq/httpsreqfunc.cpp +++ b/plugins/wasmedge_httpsreq/httpsreqfunc.cpp @@ -15,62 +15,57 @@ #include #include -// Some of the code was taken from this post: -// https://stackoverflow.com/questions/52727565/client-in-c-use-gethostbyname-or-getaddrinfo - namespace WasmEdge { namespace Host { -Expect SendData::body(const Runtime::CallingFrame &Frame, - uint32_t HostPtr, uint32_t HostLen, uint32_t Port, - uint32_t BodyPtr, uint32_t BodyLen) { +Expect WasmEdgeHttpsReqSendData::body(const Runtime::CallingFrame &Frame, + uint32_t HostPtr, uint32_t HostLen, + uint32_t Port, uint32_t BodyPtr, + uint32_t BodyLen) { auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } - char *Host = MemInst->getPointer(HostPtr); - std::string NewHost; - std::copy_n(Host, HostLen, std::back_inserter(NewHost)); - Env.Host = std::move(NewHost); - - Env.Port = Port; - - char *BodyStr = MemInst->getPointer(BodyPtr); - std::string NewBodyStr; - std::copy_n(BodyStr, BodyLen, std::back_inserter(NewBodyStr)); - Env.BodyStr = std::move(NewBodyStr); + const char *Host = MemInst->getPointer(HostPtr); + const char *Body = MemInst->getPointer(BodyPtr); + if (Host == nullptr) { + spdlog::error("[WasmEdge Httpsreq] Fail to get Host"); + return Unexpect(ErrCode::Value::HostFuncError); + } + if (Body == nullptr) { + spdlog::error("[WasmEdge Httpsreq] Fail to get Body"); + return Unexpect(ErrCode::Value::HostFuncError); + } + std::string HostStr, BodyStr, PortStr = std::to_string(Port); + std::copy_n(Host, HostLen, std::back_inserter(HostStr)); + std::copy_n(Body, BodyLen, std::back_inserter(BodyStr)); const SSL_METHOD *Method = TLS_client_method(); SSL_CTX *Ctx = SSL_CTX_new(Method); - if (Ctx == nullptr) { ERR_print_errors_fp(stderr); - exit(EXIT_FAILURE); + spdlog::error("[WasmEdge Httpsreq] SSL_CTX_new() failed"); + return Unexpect(ErrCode::Value::HostFuncError); } - SSL *Ssl = SSL_new(Ctx); if (Ssl == nullptr) { - fprintf(stderr, "[Httpsreq plugin] SSL_new() failed\n"); - exit(EXIT_FAILURE); + spdlog::error("[WasmEdge Httpsreq] SSL_new() failed"); + return Unexpect(ErrCode::Value::HostFuncError); } // open connection int Sfd, Err; struct addrinfo Hints = {}, *Addrs; - char PortStr[16] = {}; Hints.ai_family = AF_INET; Hints.ai_socktype = SOCK_STREAM; Hints.ai_protocol = IPPROTO_TCP; - std::sprintf(PortStr, "%d", Port); - - Err = getaddrinfo(Env.Host.c_str(), PortStr, &Hints, &Addrs); + Err = getaddrinfo(HostStr.c_str(), PortStr.c_str(), &Hints, &Addrs); if (Err != 0) { - fprintf(stderr, "[Httpsreq plugin] %s: %s\n", Env.Host.c_str(), - gai_strerror(Err)); - abort(); + spdlog::error("[WasmEdge Httpsreq] {}", gai_strerror(Err)); + return Unexpect(ErrCode::Value::HostFuncError); } for (struct addrinfo *Addr = Addrs; Addr != NULL; Addr = Addr->ai_next) { @@ -79,7 +74,6 @@ Expect SendData::body(const Runtime::CallingFrame &Frame, Err = errno; break; } - if (connect(Sfd, Addr->ai_addr, Addr->ai_addrlen) == 0) break; Err = errno; @@ -90,9 +84,8 @@ Expect SendData::body(const Runtime::CallingFrame &Frame, freeaddrinfo(Addrs); if (Sfd == -1) { - fprintf(stderr, "[Httpsreq plugin] %s: %s\n", Env.Host.c_str(), - strerror(Err)); - abort(); + spdlog::error("[WasmEdge Httpsreq] {}", strerror(Err)); + return Unexpect(ErrCode::Value::HostFuncError); } SSL_set_fd(Ssl, Sfd); @@ -101,13 +94,11 @@ Expect SendData::body(const Runtime::CallingFrame &Frame, if (Status != 1) { SSL_get_error(Ssl, Status); ERR_print_errors_fp(stderr); - fprintf(stderr, - "[Httpsreq plugin] SSL_connect failed with SSL_get_error code %d\n", - Status); - exit(EXIT_FAILURE); + spdlog::error("[WasmEdge Httpsreq] SSL_get_error code {}", Status); + return Unexpect(ErrCode::Value::HostFuncError); } - SSL_write(Ssl, BodyStr, strlen(Env.BodyStr.c_str())); + SSL_write(Ssl, BodyStr.c_str(), BodyLen); // Receive char Buffer[1024]; @@ -129,8 +120,8 @@ Expect SendData::body(const Runtime::CallingFrame &Frame, return {}; } -Expect HttpsReqGetRcv::body(const Runtime::CallingFrame &Frame, - uint32_t BufPtr) { +Expect WasmEdgeHttpsReqGetRcv::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr) { auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -140,7 +131,8 @@ Expect HttpsReqGetRcv::body(const Runtime::CallingFrame &Frame, return {}; } -Expect HttpsReqGetRcvLen::body(const Runtime::CallingFrame &) { +Expect +WasmEdgeHttpsReqGetRcvLen::body(const Runtime::CallingFrame &) { return static_cast(Env.Rcv.size()); } diff --git a/plugins/wasmedge_httpsreq/httpsreqfunc.h b/plugins/wasmedge_httpsreq/httpsreqfunc.h new file mode 100644 index 00000000..566a1a7a --- /dev/null +++ b/plugins/wasmedge_httpsreq/httpsreqfunc.h @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "httpsreqbase.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeHttpsReqSendData + : public WasmEdgeHttpsReq { +public: + WasmEdgeHttpsReqSendData(WasmEdgeHttpsReqEnvironment &HostEnv) + : WasmEdgeHttpsReq(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t HostPtr, + uint32_t HostLen, uint32_t Port, uint32_t BodyPtr, + uint32_t BodyLen); +}; + +class WasmEdgeHttpsReqGetRcv : public WasmEdgeHttpsReq { +public: + WasmEdgeHttpsReqGetRcv(WasmEdgeHttpsReqEnvironment &HostEnv) + : WasmEdgeHttpsReq(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr); +}; + +class WasmEdgeHttpsReqGetRcvLen + : public WasmEdgeHttpsReq { +public: + WasmEdgeHttpsReqGetRcvLen(WasmEdgeHttpsReqEnvironment &HostEnv) + : WasmEdgeHttpsReq(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_httpsreq/httpsreqmodule.cpp b/plugins/wasmedge_httpsreq/httpsreqmodule.cpp new file mode 100644 index 00000000..d8e43208 --- /dev/null +++ b/plugins/wasmedge_httpsreq/httpsreqmodule.cpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "httpsreqmodule.h" +#include "httpsreqfunc.h" + +namespace WasmEdge { +namespace Host { + +/// Register your functions in module. +WasmEdgeHttpsReqModule::WasmEdgeHttpsReqModule() + : ModuleInstance("wasmedge_httpsreq") { + addHostFunc("wasmedge_httpsreq_send_data", + std::make_unique(Env)); + addHostFunc("wasmedge_httpsreq_get_rcv", + std::make_unique(Env)); + addHostFunc("wasmedge_httpsreq_get_rcv_len", + std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/httpsreq/httpsreqmodule.h b/plugins/wasmedge_httpsreq/httpsreqmodule.h similarity index 60% rename from plugins/httpsreq/httpsreqmodule.h rename to plugins/wasmedge_httpsreq/httpsreqmodule.h index 89ac524c..4931fc54 100644 --- a/plugins/httpsreq/httpsreqmodule.h +++ b/plugins/wasmedge_httpsreq/httpsreqmodule.h @@ -4,22 +4,20 @@ #pragma once #include "httpsreqenv.h" - #include "runtime/instance/module.h" - #include namespace WasmEdge { namespace Host { -class HttpsReqModule : public Runtime::Instance::ModuleInstance { +class WasmEdgeHttpsReqModule : public Runtime::Instance::ModuleInstance { public: - HttpsReqModule(); + WasmEdgeHttpsReqModule(); - HttpsReqEnvironment &getEnv() { return Env; } + WasmEdgeHttpsReqEnvironment &getEnv() { return Env; } private: - HttpsReqEnvironment Env; + WasmEdgeHttpsReqEnvironment Env; }; } // namespace Host diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 5a52aab1..8f5cebe0 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -9,7 +9,7 @@ if (WASMEDGE_PLUGIN_WASI_NN_BACKEND) endif() if(WASMEDGE_PLUGIN_HTTPSREQ) if(CMAKE_SYSTEM_NAME MATCHES "Linux") - add_subdirectory(httpsreq) + add_subdirectory(wasmedge_httpsreq) endif() endif() if (WASMEDGE_PLUGIN_WASI_CRYPTO) diff --git a/test/plugins/httpsreq/CMakeLists.txt b/test/plugins/wasmedge_httpsreq/CMakeLists.txt similarity index 100% rename from test/plugins/httpsreq/CMakeLists.txt rename to test/plugins/wasmedge_httpsreq/CMakeLists.txt diff --git a/test/plugins/httpsreq/httpsreq.cpp b/test/plugins/wasmedge_httpsreq/httpsreq.cpp similarity index 70% rename from test/plugins/httpsreq/httpsreq.cpp rename to test/plugins/wasmedge_httpsreq/httpsreq.cpp index 5f02c394..6306211a 100644 --- a/test/plugins/httpsreq/httpsreq.cpp +++ b/test/plugins/wasmedge_httpsreq/httpsreq.cpp @@ -17,10 +17,11 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/httpsreq/" + "../../../plugins/wasmedge_httpsreq/" "libwasmedgePluginHttpsReq" WASMEDGE_LIB_EXTENSION)); - if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("https_req"sv)) { - if (const auto *Module = Plugin->findModule("https_req"sv)) { + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_httpsreq"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_httpsreq"sv)) { return Module->create().release(); } } @@ -41,9 +42,9 @@ void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, } // namespace TEST(wasmedgeHttpsReqTests, SendData) { - // Create the httpsreq module instance. + // Create the wasmedge httpsreq module instance. auto *HttpMod = - dynamic_cast(createModule()); + dynamic_cast(createModule()); EXPECT_FALSE(HttpMod == nullptr); // Create the calling frame with memory instance. @@ -67,11 +68,11 @@ TEST(wasmedgeHttpsReqTests, SendData) { "Close\r\nReferer: https://httpbin.org/\r\n\r\n")); // Get the function "send_data" - auto *FuncInst = HttpMod->findFuncExports("send_data"); + auto *FuncInst = HttpMod->findFuncExports("wasmedge_httpsreq_send_data"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncInst = - dynamic_cast(FuncInst->getHostFunc()); + auto &HostFuncInst = dynamic_cast( + FuncInst->getHostFunc()); // Test: Run function successfully for get requests EXPECT_TRUE(HostFuncInst.run( @@ -79,16 +80,13 @@ TEST(wasmedgeHttpsReqTests, SendData) { std::initializer_list{ UINT32_C(0), UINT32_C(11), UINT32_C(443), UINT32_C(30), UINT32_C(86)}, {})); - EXPECT_TRUE(HttpMod->getEnv().Host == "httpbin.org"); - EXPECT_TRUE(HttpMod->getEnv().BodyStr == - "GET / HTTP/1.1\nHost: httpbin.org\r\nConnection: " - "Close\r\nReferer: https://httpbin.org/\r\n\r\n"); + delete HttpMod; } TEST(wasmedgeHttpsReqTests, GetRcv) { // Create the httpsreq module instance. auto *HttpMod = - dynamic_cast(createModule()); + dynamic_cast(createModule()); EXPECT_FALSE(HttpMod == nullptr); // Create the calling frame with memory instance. @@ -102,6 +100,7 @@ TEST(wasmedgeHttpsReqTests, GetRcv) { WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); fillMemContent(MemInst, 0, 256); + // Set the memory[0, 11] as string "httpbin.org". fillMemContent(MemInst, 0, std::string("httpbin.org")); // Set the memory[30, 116] as string "GET / HTTP/1.1\nHost: @@ -111,25 +110,27 @@ TEST(wasmedgeHttpsReqTests, GetRcv) { "Close\r\nReferer: https://httpbin.org/\r\n\r\n")); // Get the function "send_data" - auto *FuncInst = HttpMod->findFuncExports("send_data"); + auto *FuncInst = HttpMod->findFuncExports("wasmedge_httpsreq_send_data"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); auto &HostFuncSendData = - dynamic_cast(FuncInst->getHostFunc()); + dynamic_cast( + FuncInst->getHostFunc()); // Get the function "get_rcv_len" - FuncInst = HttpMod->findFuncExports("https_req_get_rcv_len"); + FuncInst = HttpMod->findFuncExports("wasmedge_httpsreq_get_rcv_len"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncGetRcvLen = dynamic_cast( - FuncInst->getHostFunc()); + auto &HostFuncGetRcvLen = + dynamic_cast( + FuncInst->getHostFunc()); // Get the function "get_rcv" - FuncInst = HttpMod->findFuncExports("https_req_get_rcv"); + FuncInst = HttpMod->findFuncExports("wasmedge_httpsreq_get_rcv"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncGetRcv = - dynamic_cast(FuncInst->getHostFunc()); + auto &HostFuncGetRcv = dynamic_cast( + FuncInst->getHostFunc()); // Test: Run function successfully for get requests EXPECT_TRUE(HostFuncSendData.run( @@ -137,24 +138,19 @@ TEST(wasmedgeHttpsReqTests, GetRcv) { std::initializer_list{ UINT32_C(0), UINT32_C(11), UINT32_C(443), UINT32_C(30), UINT32_C(86)}, {})); - EXPECT_TRUE(HttpMod->getEnv().Host == "httpbin.org"); - EXPECT_TRUE(HttpMod->getEnv().BodyStr == - "GET / HTTP/1.1\nHost: httpbin.org\r\nConnection: " - "Close\r\nReferer: https://httpbin.org/\r\n\r\n"); // Test: Run function successfully for getrcvlen std::array RetVal; - EXPECT_TRUE(HostFuncGetRcvLen.run( - WasmEdge::Runtime::CallingFrame(nullptr, nullptr), {}, RetVal)); + EXPECT_TRUE(HostFuncGetRcvLen.run(CallFrame, {}, RetVal)); uint32_t Len = RetVal[0].get(); EXPECT_TRUE(Len > 0U); - // Test Run function successfully for getrcv - EXPECT_TRUE(HostFuncGetRcv.run( - CallFrame, std::initializer_list{UINT32_C(0)}, {})); - EXPECT_TRUE(std::equal(HttpMod->getEnv().Rcv.begin(), - HttpMod->getEnv().Rcv.end(), - MemInst.getPointer(0))); + // Test: Run function with nullptr memory instance -- fail + EXPECT_FALSE(HostFuncGetRcv.run( + WasmEdge::Runtime::CallingFrame(nullptr, nullptr), + std::initializer_list{UINT32_C(0)}, {})); + + delete HttpMod; } GTEST_API_ int main(int argc, char **argv) { From 2779107174104f13b7f731e0594ef569f1bd1ba1 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Wed, 7 Sep 2022 00:30:54 +0800 Subject: [PATCH 073/623] [Plugin] Support openssl3 on wasi-crypto plugin * Use `OPENSSL_API_COMPAT` macro for old api interface Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_crypto/CMakeLists.txt | 1 + plugins/wasi_crypto/asymmetric_common/ecdsa.h | 18 ++++++++++-------- plugins/wasi_crypto/signatures/rsa.cpp | 6 +++--- plugins/wasi_crypto/signatures/rsa.h | 5 +---- plugins/wasi_crypto/symmetric/kdf.cpp | 5 +++-- plugins/wasi_crypto/symmetric/kdf.h | 2 +- plugins/wasi_crypto/utils/evp_wrapper.h | 2 ++ 7 files changed, 21 insertions(+), 18 deletions(-) diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt index 4eacb87e..2452d74e 100644 --- a/plugins/wasi_crypto/CMakeLists.txt +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -53,6 +53,7 @@ wasmedge_add_library(wasmedgePluginWasiCrypto target_compile_options(wasmedgePluginWasiCrypto PUBLIC -DWASMEDGE_PLUGIN + -DOPENSSL_API_COMPAT=0x10100000L ) if(CMAKE_SYSTEM_NAME MATCHES "Darwin") diff --git a/plugins/wasi_crypto/asymmetric_common/ecdsa.h b/plugins/wasi_crypto/asymmetric_common/ecdsa.h index 2941c12c..d79274cd 100644 --- a/plugins/wasi_crypto/asymmetric_common/ecdsa.h +++ b/plugins/wasi_crypto/asymmetric_common/ecdsa.h @@ -127,7 +127,7 @@ class Ecdsa { static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx, bool) noexcept { ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); - EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); + const EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); ensureOrReturn(EcCtx, __WASI_CRYPTO_ERRNO_INVALID_KEY); const EC_GROUP *Group = EC_KEY_get0_group(EcCtx); @@ -139,7 +139,7 @@ class Ecdsa { WasiCryptoExpect> exportSec(bool Compressed) const noexcept { - EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); + const EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); std::vector Res(getRawPkSize(Compressed)); opensslCheck(EC_POINT_point2oct( EC_KEY_get0_group(EcCtx), EC_KEY_get0_public_key(EcCtx), @@ -149,16 +149,18 @@ class Ecdsa { WasiCryptoExpect> exportPem(bool Compressed) const noexcept { - EC_KEY_set_conv_form(EVP_PKEY_get0_EC_KEY(Ctx.get()), - getForm(Compressed)); + EC_KEY_set_conv_form( + const_cast(EVP_PKEY_get0_EC_KEY(Ctx.get())), + getForm(Compressed)); return pemWritePUBKEY(Ctx.get()); } WasiCryptoExpect> exportPkcs8(bool Compressed) const noexcept { - EC_KEY_set_conv_form(EVP_PKEY_get0_EC_KEY(Ctx.get()), - getForm(Compressed)); + EC_KEY_set_conv_form( + const_cast(EVP_PKEY_get0_EC_KEY(Ctx.get())), + getForm(Compressed)); return i2dPUBKEY(Ctx.get()); } @@ -239,7 +241,7 @@ class Ecdsa { static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept { ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); - EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); + const EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); ensureOrReturn(EcCtx, __WASI_CRYPTO_ERRNO_INVALID_KEY); const EC_GROUP *Group = EC_KEY_get0_group(EcCtx); @@ -367,7 +369,7 @@ class Ecdsa { static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx, bool) noexcept { ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); - EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); + const EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); ensureOrReturn(EcCtx, __WASI_CRYPTO_ERRNO_INVALID_KEY); // Curve id check. diff --git a/plugins/wasi_crypto/signatures/rsa.cpp b/plugins/wasi_crypto/signatures/rsa.cpp index 6493a16d..7657f392 100644 --- a/plugins/wasi_crypto/signatures/rsa.cpp +++ b/plugins/wasi_crypto/signatures/rsa.cpp @@ -43,7 +43,7 @@ template WasiCryptoExpect Rsa::PublicKey::checkValid(EvpPkeyPtr Ctx) noexcept { ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); - RSA *RsaKey = EVP_PKEY_get0_RSA(Ctx.get()); + const RSA *RsaKey = EVP_PKEY_get0_RSA(Ctx.get()); ensureOrReturn(RsaKey, __WASI_CRYPTO_ERRNO_INVALID_KEY); ensureOrReturn(RSA_bits(RsaKey) == KeyBits, __WASI_CRYPTO_ERRNO_INVALID_KEY); return {std::move(Ctx)}; @@ -128,7 +128,7 @@ template WasiCryptoExpect Rsa::SecretKey::checkValid(EvpPkeyPtr Ctx) noexcept { ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); - RSA *RsaKey = EVP_PKEY_get0_RSA(Ctx.get()); + const RSA *RsaKey = EVP_PKEY_get0_RSA(Ctx.get()); ensureOrReturn(RsaKey, __WASI_CRYPTO_ERRNO_INVALID_KEY); ensureOrReturn(RSA_bits(RsaKey) == KeyBits, __WASI_CRYPTO_ERRNO_INVALID_KEY); return {std::move(Ctx)}; @@ -207,7 +207,7 @@ template WasiCryptoExpect Rsa::KeyPair::checkValid(EvpPkeyPtr Ctx) noexcept { ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); - RSA *RsaKey = EVP_PKEY_get0_RSA(Ctx.get()); + const RSA *RsaKey = EVP_PKEY_get0_RSA(Ctx.get()); ensureOrReturn(RsaKey, __WASI_CRYPTO_ERRNO_INVALID_KEY); ensureOrReturn(RSA_bits(RsaKey) == KeyBits, __WASI_CRYPTO_ERRNO_INVALID_KEY); return {std::move(Ctx)}; diff --git a/plugins/wasi_crypto/signatures/rsa.h b/plugins/wasi_crypto/signatures/rsa.h index 980136e7..85c9b312 100644 --- a/plugins/wasi_crypto/signatures/rsa.h +++ b/plugins/wasi_crypto/signatures/rsa.h @@ -194,10 +194,7 @@ template class Rsa { private: static constexpr size_t getSigSize() { return KeyBits / 8; } - static void *getShaCtx() { - return static_cast( - const_cast(EVP_get_digestbynid(ShaNid))); - } + static const EVP_MD *getShaCtx() { return EVP_get_digestbynid(ShaNid); } }; using RSA_PKCS1_2048_SHA256 = Rsa; diff --git a/plugins/wasi_crypto/symmetric/kdf.cpp b/plugins/wasi_crypto/symmetric/kdf.cpp index 6d23a868..5b00c430 100644 --- a/plugins/wasi_crypto/symmetric/kdf.cpp +++ b/plugins/wasi_crypto/symmetric/kdf.cpp @@ -24,8 +24,9 @@ template constexpr uint32_t Hkdf::getKeySize() noexcept { return 64; } -template constexpr void *Hkdf::getShaCtx() noexcept { - return static_cast(const_cast(EVP_get_digestbynid(ShaNid))); +template +constexpr const EVP_MD *Hkdf::getShaCtx() noexcept { + return EVP_get_digestbynid(ShaNid); } template diff --git a/plugins/wasi_crypto/symmetric/kdf.h b/plugins/wasi_crypto/symmetric/kdf.h index 5ecc528f..73c8897c 100644 --- a/plugins/wasi_crypto/symmetric/kdf.h +++ b/plugins/wasi_crypto/symmetric/kdf.h @@ -232,7 +232,7 @@ template class Hkdf { private: constexpr static uint32_t getKeySize() noexcept; - constexpr static void *getShaCtx() noexcept; + constexpr static const EVP_MD *getShaCtx() noexcept; static WasiCryptoExpect openStateImpl(Span Key, int Mode) noexcept; diff --git a/plugins/wasi_crypto/utils/evp_wrapper.h b/plugins/wasi_crypto/utils/evp_wrapper.h index 601fd7a1..af4d0994 100644 --- a/plugins/wasi_crypto/utils/evp_wrapper.h +++ b/plugins/wasi_crypto/utils/evp_wrapper.h @@ -21,11 +21,13 @@ #include "common/span.h" #include +#include #include #include #include #include #include +#include #include #include From 474a0f10222ac25da7c286432a2cb279dc4b0519 Mon Sep 17 00:00:00 2001 From: sonder-joker Date: Wed, 7 Sep 2022 21:56:01 +0800 Subject: [PATCH 074/623] [WASI] Sync wasi-crypto. Signed-off-by: sonder-joker --- plugins/wasi_crypto/asymmetric_common/ecdsa.h | 62 +++++++------------ plugins/wasi_crypto/utils/hostfunction.h | 5 -- test/plugins/wasi_crypto/asymmetric.cpp | 20 +----- thirdparty/wasi_crypto/api.hpp | 29 +-------- 4 files changed, 25 insertions(+), 91 deletions(-) diff --git a/plugins/wasi_crypto/asymmetric_common/ecdsa.h b/plugins/wasi_crypto/asymmetric_common/ecdsa.h index d79274cd..9f4fc960 100644 --- a/plugins/wasi_crypto/asymmetric_common/ecdsa.h +++ b/plugins/wasi_crypto/asymmetric_common/ecdsa.h @@ -57,17 +57,11 @@ class Ecdsa { __wasi_publickey_encoding_e_t Encoding) noexcept { switch (Encoding) { case __WASI_PUBLICKEY_ENCODING_PKCS8: - return importPkcs8(Encoded, false); - case __WASI_PUBLICKEY_ENCODING_COMPRESSED_PKCS8: - return importPkcs8(Encoded, true); + return importPkcs8(Encoded); case __WASI_PUBLICKEY_ENCODING_PEM: - return importPem(Encoded, false); - case __WASI_PUBLICKEY_ENCODING_COMPRESSED_PEM: - return importPem(Encoded, true); + return importPem(Encoded); case __WASI_PUBLICKEY_ENCODING_SEC: - return importSec(Encoded, false); - case __WASI_PUBLICKEY_ENCODING_COMPRESSED_SEC: - return importSec(Encoded, true); + return importSec(Encoded); default: return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); } @@ -78,16 +72,10 @@ class Ecdsa { switch (Encoding) { case __WASI_PUBLICKEY_ENCODING_SEC: return exportSec(false); - case __WASI_PUBLICKEY_ENCODING_COMPRESSED_SEC: - return exportSec(true); case __WASI_PUBLICKEY_ENCODING_PEM: return exportPem(false); - case __WASI_PUBLICKEY_ENCODING_COMPRESSED_PEM: - return exportPem(true); case __WASI_PUBLICKEY_ENCODING_PKCS8: return exportPkcs8(false); - case __WASI_PUBLICKEY_ENCODING_COMPRESSED_PKCS8: - return exportPkcs8(true); default: return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); } @@ -98,18 +86,18 @@ class Ecdsa { } protected: - static WasiCryptoExpect importPkcs8(Span Encoded, - bool Compressed) noexcept { - return checkValid(EvpPkeyPtr{d2iPUBKEY(Encoded)}, Compressed); + static WasiCryptoExpect + importPkcs8(Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPUBKEY(Encoded)}); } - static WasiCryptoExpect importPem(Span Encoded, - bool Compressed) noexcept { - return checkValid(EvpPkeyPtr{pemReadPUBKEY(Encoded)}, Compressed); + static WasiCryptoExpect + importPem(Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{pemReadPUBKEY(Encoded)}); } - static WasiCryptoExpect importSec(Span Encoded, - bool Compressed) noexcept { + static WasiCryptoExpect + importSec(Span Encoded) noexcept { EcKeyPtr EcCtx{EC_KEY_new_by_curve_name(CurveNid)}; EcPointPtr Pk{EC_POINT_new(EC_KEY_get0_group(EcCtx.get()))}; ensureOrReturn(EC_POINT_oct2point(EC_KEY_get0_group(EcCtx.get()), @@ -121,11 +109,10 @@ class Ecdsa { EvpPkeyPtr Ctx{EVP_PKEY_new()}; opensslCheck(EVP_PKEY_set1_EC_KEY(Ctx.get(), EcCtx.get())); - return checkValid(std::move(Ctx), Compressed); + return checkValid(std::move(Ctx)); } - static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx, - bool) noexcept { + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept { ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); const EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); ensureOrReturn(EcCtx, __WASI_CRYPTO_ERRNO_INVALID_KEY); @@ -297,13 +284,9 @@ class Ecdsa { case __WASI_KEYPAIR_ENCODING_RAW: return importRaw(Encoded); case __WASI_KEYPAIR_ENCODING_PKCS8: - return importPkcs8(Encoded, false); + return importPkcs8(Encoded); case __WASI_KEYPAIR_ENCODING_PEM: - return importPem(Encoded, false); - case __WASI_KEYPAIR_ENCODING_COMPRESSED_PKCS8: - return importPkcs8(Encoded, true); - case __WASI_KEYPAIR_ENCODING_COMPRESSED_PEM: - return importPem(Encoded, true); + return importPem(Encoded); default: return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ENCODING); } @@ -334,14 +317,14 @@ class Ecdsa { } protected: - static WasiCryptoExpect importPkcs8(Span Encoded, - bool Compressed) noexcept { - return checkValid(EvpPkeyPtr{d2iPrivateKey(Encoded)}, Compressed); + static WasiCryptoExpect + importPkcs8(Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{d2iPrivateKey(Encoded)}); } - static WasiCryptoExpect importPem(Span Encoded, - bool Compressed) noexcept { - return checkValid(EvpPkeyPtr{pemReadPrivateKey(Encoded)}, Compressed); + static WasiCryptoExpect + importPem(Span Encoded) noexcept { + return checkValid(EvpPkeyPtr{pemReadPrivateKey(Encoded)}); } static WasiCryptoExpect @@ -366,8 +349,7 @@ class Ecdsa { return Ctx; } - static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx, - bool) noexcept { + static WasiCryptoExpect checkValid(EvpPkeyPtr Ctx) noexcept { ensureOrReturn(Ctx, __WASI_CRYPTO_ERRNO_INVALID_KEY); const EC_KEY *EcCtx = EVP_PKEY_get0_EC_KEY(Ctx.get()); ensureOrReturn(EcCtx, __WASI_CRYPTO_ERRNO_INVALID_KEY); diff --git a/plugins/wasi_crypto/utils/hostfunction.h b/plugins/wasi_crypto/utils/hostfunction.h index 65a701af..dcbd0a5a 100644 --- a/plugins/wasi_crypto/utils/hostfunction.h +++ b/plugins/wasi_crypto/utils/hostfunction.h @@ -69,8 +69,6 @@ cast(uint64_t Encoding) noexcept { case __WASI_KEYPAIR_ENCODING_RAW: case __WASI_KEYPAIR_ENCODING_PKCS8: case __WASI_KEYPAIR_ENCODING_PEM: - case __WASI_KEYPAIR_ENCODING_COMPRESSED_PKCS8: - case __WASI_KEYPAIR_ENCODING_COMPRESSED_PEM: case __WASI_KEYPAIR_ENCODING_LOCAL: return static_cast<__wasi_keypair_encoding_e_t>(Encoding); default: @@ -86,9 +84,6 @@ cast(uint64_t Encoding) noexcept { case __WASI_PUBLICKEY_ENCODING_PKCS8: case __WASI_PUBLICKEY_ENCODING_PEM: case __WASI_PUBLICKEY_ENCODING_SEC: - case __WASI_PUBLICKEY_ENCODING_COMPRESSED_SEC: - case __WASI_PUBLICKEY_ENCODING_COMPRESSED_PKCS8: - case __WASI_PUBLICKEY_ENCODING_COMPRESSED_PEM: case __WASI_PUBLICKEY_ENCODING_LOCAL: return static_cast<__wasi_publickey_encoding_e_t>(Encoding); default: diff --git a/test/plugins/wasi_crypto/asymmetric.cpp b/test/plugins/wasi_crypto/asymmetric.cpp index c34f3a8d..79820665 100644 --- a/test/plugins/wasi_crypto/asymmetric.cpp +++ b/test/plugins/wasi_crypto/asymmetric.cpp @@ -86,15 +86,6 @@ TEST_F(WasiCryptoTest, Asymmetric) { "-----BEGIN PUBLIC KEY-----\n" "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEYP7UuiVanTHJYet0xjVtaMBJuJI7\n" "Yfps5mliLmDyn7Z5A/4QCLi8maQa6elWKLxk8vGyDC1+n1F3o8KU1EYimQ==\n" - "-----END PUBLIC KEY-----\n"_u8}, - {__WASI_PUBLICKEY_ENCODING_COMPRESSED_SEC, - "0360FED4BA255A9D31C961EB74C6356D68C049B8923B61FA6CE669622E60F29FB6"_u8v}, - {__WASI_PUBLICKEY_ENCODING_COMPRESSED_PKCS8, - "3039301306072a8648ce3d020106082a8648ce3d0301070322000360FED4BA255A9D31C961EB74C6356D68C049B8923B61FA6CE669622E60F29FB6"_u8v}, - {__WASI_PUBLICKEY_ENCODING_COMPRESSED_PEM, - "-----BEGIN PUBLIC KEY-----\n" - "MDkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDIgADYP7UuiVanTHJYet0xjVtaMBJuJI7\n" - "Yfps5mliLmDyn7Y=\n" "-----END PUBLIC KEY-----\n"_u8}}, {{__WASI_SECRETKEY_ENCODING_RAW, "C9AFA9D845BA75166B5C215767B1D6934E50C3DB36E89B127B8A622B120F6721"_u8v}, @@ -115,16 +106,7 @@ TEST_F(WasiCryptoTest, Asymmetric) { "-----BEGIN PUBLIC KEY-----\n" "MFYwEAYHKoZIzj0CAQYFK4EEAAoDQgAEuDj/ROW8F3vyEYnQdmCC/J2EMiaIf8l2\n" "A3EQC37iCm/wyddb+6ezGmvKGXRJbutW3jVwcZVdg8Sxutqgshgy6Q==\n" - "-----END PUBLIC KEY-----"_u8}, - {__WASI_PUBLICKEY_ENCODING_COMPRESSED_SEC, - "03b838ff44e5bc177bf21189d0766082fc9d843226887fc9760371100b7ee20a6f"_u8v}, - {__WASI_PUBLICKEY_ENCODING_COMPRESSED_PEM, - "-----BEGIN PUBLIC KEY-----\n" - "MDYwEAYHKoZIzj0CAQYFK4EEAAoDIgADuDj/ROW8F3vyEYnQdmCC/J2EMiaIf8l2\n" - "A3EQC37iCm8=\n" - "-----END PUBLIC KEY-----\n"_u8}, - {__WASI_PUBLICKEY_ENCODING_COMPRESSED_PKCS8, - "3036301006072a8648ce3d020106052b8104000a03220003b838ff44e5bc177bf21189d0766082fc9d843226887fc9760371100b7ee20a6f"_u8v}}, + "-----END PUBLIC KEY-----"_u8}}, {{__WASI_SECRETKEY_ENCODING_RAW, "b9aa5c28ef96d750e47f4ba44d5d6a7ac3ab6988d292e7819e362a4b0ac8e250"_u8v}, {__WASI_SECRETKEY_ENCODING_PKCS8, diff --git a/thirdparty/wasi_crypto/api.hpp b/thirdparty/wasi_crypto/api.hpp index 115423af..8d6f6081 100644 --- a/thirdparty/wasi_crypto/api.hpp +++ b/thirdparty/wasi_crypto/api.hpp @@ -290,20 +290,10 @@ enum __wasi_keypair_encoding_e_t : uint16_t { */ __WASI_KEYPAIR_ENCODING_PEM = 2, - /** - * PCSK8/DER encoding with compressed coordinates. - */ - __WASI_KEYPAIR_ENCODING_COMPRESSED_PKCS8 = 3, - - /** - * PEM encoding with compressed coordinates. - */ - __WASI_KEYPAIR_ENCODING_COMPRESSED_PEM = 4, - /** * Implementation-defined encoding. */ - __WASI_KEYPAIR_ENCODING_LOCAL = 5, + __WASI_KEYPAIR_ENCODING_LOCAL = 3, }; static_assert(sizeof(__wasi_keypair_encoding_e_t) == 2, "witx calculated size"); @@ -333,25 +323,10 @@ enum __wasi_publickey_encoding_e_t : uint16_t { */ __WASI_PUBLICKEY_ENCODING_SEC = 3, - /** - * Compressed SEC-1 encoding. - */ - __WASI_PUBLICKEY_ENCODING_COMPRESSED_SEC = 4, - - /** - * PKCS8/DER encoding with compressed coordinates. - */ - __WASI_PUBLICKEY_ENCODING_COMPRESSED_PKCS8 = 5, - - /** - * PEM encoding with compressed coordinates. - */ - __WASI_PUBLICKEY_ENCODING_COMPRESSED_PEM = 6, - /** * Implementation-defined encoding. */ - __WASI_PUBLICKEY_ENCODING_LOCAL = 7, + __WASI_PUBLICKEY_ENCODING_LOCAL = 4, }; static_assert(sizeof(__wasi_publickey_encoding_e_t) == 2, "witx calculated size"); From ec7fc0b9a8acb29d3bcc2de1a46776018ce63b0a Mon Sep 17 00:00:00 2001 From: sonder-joker Date: Sun, 11 Sep 2022 15:47:30 +0800 Subject: [PATCH 075/623] [WASI] Add p384 https://github.com/WebAssembly/wasi-crypto/pull/62 Signed-off-by: sonder-joker --- plugins/wasi_crypto/kx/dh/ecdsa.cpp | 8 ++++++-- plugins/wasi_crypto/kx/dh/ecdsa.h | 11 +++++++---- plugins/wasi_crypto/kx/registed.h | 2 +- plugins/wasi_crypto/signatures/ecdsa.cpp | 1 + plugins/wasi_crypto/signatures/ecdsa.h | 1 + plugins/wasi_crypto/signatures/registed.h | 2 +- plugins/wasi_crypto/utils/hostfunction.cpp | 8 +++++++- test/plugins/wasi_crypto/asymmetric.cpp | 16 ++++++++++++++++ test/plugins/wasi_crypto/kx.cpp | 1 + test/plugins/wasi_crypto/signatures.cpp | 1 + 10 files changed, 42 insertions(+), 9 deletions(-) diff --git a/plugins/wasi_crypto/kx/dh/ecdsa.cpp b/plugins/wasi_crypto/kx/dh/ecdsa.cpp index b0fd4f5a..ef4536a6 100644 --- a/plugins/wasi_crypto/kx/dh/ecdsa.cpp +++ b/plugins/wasi_crypto/kx/dh/ecdsa.cpp @@ -14,9 +14,10 @@ namespace { inline const size_t SharedSecretSize = 32; } // namespace +template WasiCryptoExpect -Ecdsa::SecretKey::dh(const PublicKey &Pk) const noexcept { - EvpPkeyCtxPtr SkCtx{EVP_PKEY_CTX_new(Ctx.get(), nullptr)}; +Ecdsa::SecretKey::dh(const PublicKey &Pk) const noexcept { + EvpPkeyCtxPtr SkCtx{EVP_PKEY_CTX_new(this->Ctx.get(), nullptr)}; opensslCheck(EVP_PKEY_derive_init(SkCtx.get())); // Set peer key. @@ -33,6 +34,9 @@ Ecdsa::SecretKey::dh(const PublicKey &Pk) const noexcept { return Res; } +template class Ecdsa; +template class Ecdsa; + } // namespace Kx } // namespace WasiCrypto } // namespace Host diff --git a/plugins/wasi_crypto/kx/dh/ecdsa.h b/plugins/wasi_crypto/kx/dh/ecdsa.h index 2ffbbd82..d15cbb3c 100644 --- a/plugins/wasi_crypto/kx/dh/ecdsa.h +++ b/plugins/wasi_crypto/kx/dh/ecdsa.h @@ -30,19 +30,19 @@ namespace Host { namespace WasiCrypto { namespace Kx { -class Ecdsa { +template class Ecdsa { public: class PublicKey; class SecretKey; class KeyPair; - using Base = AsymmetricCommon::Ecdsa; + using Base = + AsymmetricCommon::Ecdsa; class PublicKey : public Base::PublicKeyBase { public: using Base::PublicKeyBase::PublicKeyBase; - const auto &raw() const { return Ctx; } + const auto &raw() const { return this->Ctx; } }; class SecretKey : public Base::SecretKeyBase { @@ -58,6 +58,9 @@ class Ecdsa { }; }; +using EcdsaP256 = Ecdsa; +using EcdsaP384 = Ecdsa; + } // namespace Kx } // namespace WasiCrypto } // namespace Host diff --git a/plugins/wasi_crypto/kx/registed.h b/plugins/wasi_crypto/kx/registed.h index 74fca12a..ac9e1713 100644 --- a/plugins/wasi_crypto/kx/registed.h +++ b/plugins/wasi_crypto/kx/registed.h @@ -32,7 +32,7 @@ template struct Registed { using Variant = std::variant; }; -using RegistedAlg = Registed; +using RegistedAlg = Registed; using Algorithm = RegistedAlg::Variant; diff --git a/plugins/wasi_crypto/signatures/ecdsa.cpp b/plugins/wasi_crypto/signatures/ecdsa.cpp index f03ed1b3..f295910c 100644 --- a/plugins/wasi_crypto/signatures/ecdsa.cpp +++ b/plugins/wasi_crypto/signatures/ecdsa.cpp @@ -118,6 +118,7 @@ Ecdsa::VerificationState::verify(const Signature &Sig) noexcept { template class Ecdsa; template class Ecdsa; +template class Ecdsa; } // namespace Signatures } // namespace WasiCrypto diff --git a/plugins/wasi_crypto/signatures/ecdsa.h b/plugins/wasi_crypto/signatures/ecdsa.h index 8e8c57bb..4c8e373e 100644 --- a/plugins/wasi_crypto/signatures/ecdsa.h +++ b/plugins/wasi_crypto/signatures/ecdsa.h @@ -111,6 +111,7 @@ template class Ecdsa { using EcdsaP256 = Ecdsa; using EcdsaK256 = Ecdsa; +using EcdsaP384 = Ecdsa; } // namespace Signatures } // namespace WasiCrypto diff --git a/plugins/wasi_crypto/signatures/registed.h b/plugins/wasi_crypto/signatures/registed.h index c80c66e9..9b736ed0 100644 --- a/plugins/wasi_crypto/signatures/registed.h +++ b/plugins/wasi_crypto/signatures/registed.h @@ -37,7 +37,7 @@ template struct Registed { }; using RegistedAlg = - Registed tryFrom(std::string_view RawAlgStr) noexcept { return Algorithm{std::in_place_type}; } if (AlgStr == "P256-SHA256"sv) { - return Algorithm{std::in_place_type}; + return Algorithm{std::in_place_type}; + } + if (AlgStr == "P384-SHA384"sv) { + return Algorithm{std::in_place_type}; } return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); } @@ -117,6 +120,9 @@ tryFrom(std::string_view RawAlgStr) noexcept { if (AlgStr == "ECDSA_K256_SHA256"sv) { return Algorithm{std::in_place_type}; } + if (AlgStr == "ECDSA_P384_SHA384"sv) { + return Algorithm{std::in_place_type}; + } if (AlgStr == "ED25519"sv) { return Algorithm{std::in_place_type}; } diff --git a/test/plugins/wasi_crypto/asymmetric.cpp b/test/plugins/wasi_crypto/asymmetric.cpp index 79820665..9a3389d9 100644 --- a/test/plugins/wasi_crypto/asymmetric.cpp +++ b/test/plugins/wasi_crypto/asymmetric.cpp @@ -121,6 +121,22 @@ TEST_F(WasiCryptoTest, Asymmetric) { "y7AvRqme3w4dEUzck5Vsx1ZIv9OPqDKoITXVwrpjR2aodT9tiKrl\n" "-----END PRIVATE KEY-----\n"_u8}}, {}); + EncodingCheck( + "ECDSA_P384_SHA384"sv, __WASI_ALGORITHM_TYPE_SIGNATURES, + {{__WASI_PUBLICKEY_ENCODING_PEM, + "-----BEGIN PUBLIC KEY-----\n" + "MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEUIwwEUeBvHWigFKAYgPs93Pb21mUr2Oa\n" + "CTSiuK0Hmjc25OO5m4Gk92s4HG22R5596FB7nXlZljqEnETE4xDc4Dugv5ZzlTCf\n" + "HLeAvmfv+hFMRzroZ+GS1Xwgl434yH9r\n" + "-----END PUBLIC KEY-----\n"_u8}}, + {{__WASI_SECRETKEY_ENCODING_PEM, + "-----BEGIN PRIVATE KEY-----\n" + "MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDAHfLyuXwp7DoNPIvxg\n" + "B5k8zOAyXHFpJ4FF7CIg4zH/UBFb5m8AyT+c9rvvyVcHlEKhZANiAARQjDARR4G8\n" + "daKAUoBiA+z3c9vbWZSvY5oJNKK4rQeaNzbk47mbgaT3azgcbbZHnn3oUHudeVmW\n" + "OoScRMTjENzgO6C/lnOVMJ8ct4C+Z+/6EUxHOuhn4ZLVfCCXjfjIf2s=\n" + "-----END PRIVATE KEY-----\n"_u8}}, + {}); EncodingCheck( "Ed25519"sv, __WASI_ALGORITHM_TYPE_SIGNATURES, {{__WASI_PUBLICKEY_ENCODING_RAW, diff --git a/test/plugins/wasi_crypto/kx.cpp b/test/plugins/wasi_crypto/kx.cpp index d3b0d80c..de107fcb 100644 --- a/test/plugins/wasi_crypto/kx.cpp +++ b/test/plugins/wasi_crypto/kx.cpp @@ -121,6 +121,7 @@ TEST_F(WasiCryptoTest, KxDh) { WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(Sk1Handle)); }; NewKxDhTest("P256-SHA256"sv); + NewKxDhTest("P384-SHA384"sv); } } // namespace WasiCrypto diff --git a/test/plugins/wasi_crypto/signatures.cpp b/test/plugins/wasi_crypto/signatures.cpp index f59ff6fd..705038bc 100644 --- a/test/plugins/wasi_crypto/signatures.cpp +++ b/test/plugins/wasi_crypto/signatures.cpp @@ -37,6 +37,7 @@ TEST_F(WasiCryptoTest, Signatures) { }; SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "ECDSA_P256_SHA256"sv); SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "ECDSA_K256_SHA256"sv); + SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "ECDSA_P384_SHA384"sv); SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "Ed25519"sv); SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_2048_SHA256"sv); SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PKCS1_2048_SHA384"sv); From 3766960b23b5cd8ed268f2072a0b1bdd13aa902a Mon Sep 17 00:00:00 2001 From: Gustavo Ye Date: Mon, 22 Aug 2022 17:04:22 +0800 Subject: [PATCH 076/623] [WASI] Support PyTorch Backend of WASI-NN proposal (#1654) Signed-off-by: Gustavo Ye --- plugins/wasi_nn/CMakeLists.txt | 11 +- plugins/wasi_nn/wasinnenv.h | 52 ++- plugins/wasi_nn/wasinnfunc.cpp | 291 +++++++++++++--- test/plugins/wasi_nn/CMakeLists.txt | 17 + test/plugins/wasi_nn/wasi_nn.cpp | 375 ++++++++++++++++++++- utils/wasi-nn/download-pytorch-fixtures.sh | 12 + utils/wasi-nn/install-pytorch.sh | 5 + 7 files changed, 689 insertions(+), 74 deletions(-) create mode 100644 utils/wasi-nn/download-pytorch-fixtures.sh create mode 100644 utils/wasi-nn/install-pytorch.sh diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index efce12db..af002928 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -44,13 +44,22 @@ install(TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedg # Add backends building flags. foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) if(BACKEND MATCHES "OpenVINO") - message(STATUS "Build ${BACKEND} backend for WASI-NN") + message(STATUS "Build OpenVINO backend for WASI-NN") find_package(InferenceEngine REQUIRED) add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) target_link_libraries(wasmedgePluginWasiNN PUBLIC ${InferenceEngine_LIBRARIES} ) + elseif(BACKEND MATCHES "PyTorch") + message(STATUS "Build PyTorch backend for WASI-NN") + find_package(Torch REQUIRED) + add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + target_link_libraries(wasmedgePluginWasiNN + PUBLIC + ${TORCH_LIBRARIES} + ) else() # Add the other backends here. message(FATAL_ERROR "WASI-NN backend ${BACKEND} not found or unimplemented.") diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index a163f01e..eb525836 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -3,15 +3,17 @@ #pragma once +#include "common/log.h" #include "plugin/plugin.h" - #include #include #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO -#include "common/log.h" #include #endif +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH +#include +#endif namespace WasmEdge { namespace Host { @@ -26,16 +28,21 @@ enum class ErrNo : uint32_t { enum class Backend : uint8_t { OpenVINO = 0, + PyTorch = 1, }; class Graph { public: -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO Graph() = delete; - Graph(Backend BE) noexcept - : GraphBackend(BE), OpenVINONetwork(nullptr), - OpenVINOExecNetwork(nullptr), OpenVINOWeightBlob(nullptr) {} + Graph(Backend BE) noexcept : GraphBackend(BE) { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO + OpenVINONetwork = nullptr; + OpenVINOExecNetwork = nullptr; + OpenVINOWeightBlob = nullptr; +#endif + } ~Graph() noexcept { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO if (OpenVINONetwork) { ie_network_free(&OpenVINONetwork); } @@ -55,10 +62,8 @@ class Graph { ie_network_name_free(&I); } } - } -#else - Graph() noexcept = default; #endif + } Backend GraphBackend; #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO @@ -68,26 +73,43 @@ class Graph { std::vector OpenVINOInputNames; std::vector OpenVINOOutputNames; #endif +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + torch::jit::Module TorchModel; +#endif }; class Context { public: Context() = delete; + + Context(Graph &G) noexcept : GraphRef(G) { + if (G.GraphBackend == Backend::OpenVINO) { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - Context(Graph &G, ie_infer_request_t *InferReq) noexcept - : GraphRef(G), OpenVINOInferRequest(InferReq) {} + IEStatusCode Status = ie_exec_network_create_infer_request( + G.OpenVINOExecNetwork, &OpenVINOInferRequest); + if (Status != IEStatusCode::OK) { + OpenVINOInferRequest = nullptr; + spdlog::error("[WASI-NN] Unable to create infer request for OpenVINO"); + } +#endif + } + } + ~Context() noexcept { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO if (OpenVINOInferRequest) { ie_infer_request_free(&OpenVINOInferRequest); } - } -#else - Context(Graph &G) noexcept : GraphRef(G) {} #endif + } Graph &GraphRef; #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - ie_infer_request_t *OpenVINOInferRequest; + ie_infer_request_t *OpenVINOInferRequest = nullptr; +#endif +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + std::vector TorchInputs; + std::vector TorchOutputs; #endif }; diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index f5807d16..b55ba289 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -4,16 +4,43 @@ #include "wasinnfunc.h" #include "common/log.h" +#include + #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO #include -#include #include #endif +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH +#include + +#include +#endif + namespace WasmEdge { namespace Host { +namespace { +std::string FindDevice(const uint32_t Target) { + std::string DeviceName; + switch (Target) { + case 0: + DeviceName = "CPU"; + break; + // case 1: + // DeviceName = "GPU"; + // break; + // case 2: + // DeviceName = "TPU"; + // break; + default: + DeviceName = ""; + } + return DeviceName; +} +} // namespace + Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr [[maybe_unused]], uint32_t BuilderLen [[maybe_unused]], @@ -25,7 +52,12 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } - + // Check the return value: GraphIdPtr should be valid. + uint32_t *GraphId = MemInst->getPointer(GraphIdPtr, 1); + if (unlikely(GraphId == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the return GraphID memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } if (Encoding == static_cast(WASINN::Backend::OpenVINO)) { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO // The OpenVINO core must be initialized in constructor. @@ -34,14 +66,6 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, return static_cast(WASINN::ErrNo::MissingMemory); } - // Check the return value: GraphIdPtr should be valid. - uint32_t *GraphId = MemInst->getPointer(GraphIdPtr, 1); - if (unlikely(GraphId == nullptr)) { - spdlog::error( - "[WASI-NN] Failed when accessing the return GraphID memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - // The graph builder length must be 2. if (BuilderLen != 2) { spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 2", @@ -51,25 +75,12 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, // Get and check the device name string. std::string DeviceName; - switch (Target) { - case 0: - DeviceName = "CPU"; - break; - case 1: - DeviceName = "GPU"; - break; - case 2: - DeviceName = "TPU"; - break; - default: - DeviceName = ""; - } + DeviceName = FindDevice(Target); if (DeviceName.length() == 0) { - spdlog::error("[WASI-NN] Device target {:d} not support!", Target); + spdlog::error("[WASI-NN] PyTorch backend only support CPU target"); return static_cast(WASINN::ErrNo::InvalidArgument); - } else { - spdlog::debug("[WASI-NN] Using device: {:s}", DeviceName); } + spdlog::debug("[WASI-NN] Using device: {:s}", DeviceName); // Get the graph builders. // GraphBuilders' Layout: @@ -95,7 +106,7 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, return static_cast(WASINN::ErrNo::InvalidArgument); } if (unlikely(BinPtr == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Weignt memory."); + spdlog::error("[WASI-NN] Failed when accessing the Weight memory."); return static_cast(WASINN::ErrNo::InvalidArgument); } @@ -233,9 +244,58 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO=ON" - "to build it."); + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."); #endif + } else if (Encoding == static_cast(WASINN::Backend::PyTorch)) { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + std::string DeviceName; + DeviceName = FindDevice(Target); + if (DeviceName.length() == 0) { + spdlog::error("[WASI-NN] PyTorch backend only support CPU target"); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + // The graph builder length must be 2. + if (BuilderLen != 1) { + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1", + BuilderLen); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t *GraphBuilders = + MemInst->getPointer(BuilderPtr, BuilderLen * 2); + if (unlikely(GraphBuilders == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the GraphBuilder memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + uint32_t BinLen = GraphBuilders[1]; + uint8_t *BinPtr = MemInst->getPointer(GraphBuilders[0], BinLen); + if (unlikely(BinPtr == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Weight memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + // Add a new graph. + Env.NNGraph.emplace_back(static_cast(Encoding)); + auto &Graph = Env.NNGraph.back(); + std::string BinString((char *)BinPtr, BinLen); + std::stringstream BinRead; + BinRead.str(BinString); + + try { + Graph.TorchModel = torch::jit::load(BinRead); + } catch (const c10::Error &e) { + spdlog::error("[WASI-NN] Failed when load the TorchScript model."); + Env.NNGraph.pop_back(); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + // Store the loaded graph. + *GraphId = Env.NNGraph.size() - 1; + return static_cast(WASINN::ErrNo::Success); + +#else + spdlog::error("[WASI-NN] PyTorch backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH } else { spdlog::error("[WASI-NN] Current backend is not supported."); } @@ -254,16 +314,14 @@ Expect WasiNNInitExecCtx::body(const Runtime::CallingFrame &Frame, spdlog::error("[WASI-NN] init_execution_context: Graph Id does not exist."); return static_cast(WASINN::ErrNo::InvalidArgument); } - + // Check the return value: Context should be valid. + uint32_t *Context = MemInst->getPointer(ContextPtr, 1); + if (unlikely(Context == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Context memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } if (Env.NNGraph[GraphId].GraphBackend == WASINN::Backend::OpenVINO) { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - // Check the return value: Context should be valid. - uint32_t *Context = MemInst->getPointer(ContextPtr, 1); - if (unlikely(Context == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Context memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - // Check the network and the execution network with the graph ID. if (Env.NNGraph[GraphId].OpenVINONetwork == nullptr || Env.NNGraph[GraphId].OpenVINOExecNetwork == nullptr) { @@ -271,23 +329,31 @@ Expect WasiNNInitExecCtx::body(const Runtime::CallingFrame &Frame, return static_cast(WASINN::ErrNo::MissingMemory); } - // Create the infer request. - ie_infer_request_t *InferRequest = nullptr; - IEStatusCode Status = ie_exec_network_create_infer_request( - Env.NNGraph[GraphId].OpenVINOExecNetwork, &InferRequest); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to create openvino session"); + // Create context. + Env.NNContext.emplace_back(Env.NNGraph[GraphId]); + auto &NewContext = Env.NNContext.back(); + if (NewContext.OpenVINOInferRequest == nullptr) { + spdlog::error("[WASI-NN] Unable to create openvino context"); + Env.NNContext.pop_back(); return static_cast(WASINN::ErrNo::Busy); } - *Context = Env.NNContext.size(); - Env.NNContext.emplace_back(Env.NNGraph[GraphId], InferRequest); + *Context = Env.NNContext.size() - 1; + return static_cast(WASINN::ErrNo::Success); +#else + spdlog::error("[WASI-NN] OpenVINO backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."); +#endif + } else if (Env.NNGraph[GraphId].GraphBackend == WASINN::Backend::PyTorch) { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + Env.NNContext.emplace_back(Env.NNGraph[GraphId]); + *Context = Env.NNContext.size() - 1; return static_cast(WASINN::ErrNo::Success); + #else - spdlog::error("[WASI-NN] OpenVINO backend is not built. define " - "-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO " - "to build it."); + spdlog::error("[WASI-NN] PyTorch backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); #endif } else { spdlog::error("[WASI-NN] Current backend is not supported."); @@ -458,9 +524,53 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, return static_cast(WASINN::ErrNo::Success); #else - spdlog::error("[WASI-NN] OpenVINO backend is not built, use " - "-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO=ON" - "to build it."); + spdlog::error("[WASI-NN] OpenVINO backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."); +#endif + } else if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::PyTorch) { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + if (Index >= CxtRef.TorchInputs.size()) { + CxtRef.TorchInputs.resize(Index + 1); + } + uint32_t *Tensor = MemInst->getPointer(TensorPtr, 5); + if (unlikely(Tensor == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Tensor memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t DimensionLen = Tensor[1]; + uint32_t *DimensionBuf = + MemInst->getPointer(Tensor[0], DimensionLen); + if (unlikely(DimensionBuf == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t TensorDataLen = Tensor[4]; + uint8_t *TensorDataBuf = + MemInst->getPointer(Tensor[3], TensorDataLen); + if (unlikely(TensorDataBuf == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t RType = Tensor[2]; + if (RType != 1) { + spdlog::error( + "[WASI-NN] Only F32 inputs and outputs are supported for now."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + auto Options = + torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); + std::vector Dims; + for (size_t I = 0; I < DimensionLen; I++) { + Dims.push_back(static_cast(DimensionBuf[I])); + } + torch::Tensor InTensor = torch::from_blob( + reinterpret_cast(TensorDataBuf), Dims, Options); + + CxtRef.TorchInputs[Index] = InTensor.clone(); + return static_cast(WASINN::ErrNo::Success); +#else + spdlog::error("[WASI-NN] PyTorch backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); #endif } else { spdlog::error("[WASI-NN] Current backend is not supported."); @@ -557,8 +667,46 @@ WasiNNGetOuput::body(const Runtime::CallingFrame &Frame, uint32_t Context, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO=ON" - "to build it."); + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."); +#endif + } else if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::PyTorch) { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + if (CxtRef.TorchOutputs.size() <= Index) { + spdlog::error( + "[WASI-NN] The output index {} exceeds the outputs number {}.", Index, + CxtRef.TorchOutputs.size()); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + torch::Tensor OutTensor = + CxtRef.TorchOutputs[Index].toType(torch::kFloat32); + float *TensorBuffer = OutTensor.data_ptr(); + + size_t BlobSize = 1; + for (auto I : OutTensor.sizes()) { + BlobSize *= I; + } + uint32_t BytesToWrite = + std::min(static_cast(BlobSize * 4), OutBufferMaxSize); + uint8_t *OutBuffer = + MemInst->getPointer(OutBufferPtr, BytesToWrite); + if (unlikely(OutBuffer == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the Output Buffer memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + std::copy_n(reinterpret_cast(TensorBuffer), BytesToWrite, + OutBuffer); + uint32_t *BytesWritten = + MemInst->getPointer(BytesWrittenPtr, 1); + if (unlikely(BytesWritten == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the BytesWritten memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + *BytesWritten = BytesToWrite; + return static_cast(WASINN::ErrNo::Success); +#else + spdlog::error("[WASI-NN] PyTorch backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); #endif } else { spdlog::error("[WASI-NN] Current backend is not supported."); @@ -591,8 +739,41 @@ Expect WasiNNCompute::body(const Runtime::CallingFrame &Frame, return static_cast(WASINN::ErrNo::Success); #else spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO=ON" - "to build it."); + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."); +#endif + } else if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::PyTorch) { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + if (CxtRef.TorchInputs.size() == 0) { + spdlog::error("[WASI-NN] Input is not set!"); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + for (size_t I = 0; I < CxtRef.TorchInputs.size(); I++) { + torch::jit::IValue InTensor = CxtRef.TorchInputs[I]; + if (InTensor.isNone()) { + spdlog::error("[WASI-NN] Input [{}] is not set!", I); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + } + torch::jit::IValue RawOutput = + CxtRef.GraphRef.TorchModel.forward(CxtRef.TorchInputs); + // TODO: more output type should be supported here + if (RawOutput.isTensorList()) { + auto OutTensors = RawOutput.toTensorVector(); + for (auto &OneOf : OutTensors) { + CxtRef.TorchOutputs.push_back(OneOf.clone()); + } + } else if (RawOutput.isTensor()) { + auto OutTensor = RawOutput.toTensor(); + CxtRef.TorchOutputs.push_back(OutTensor.clone()); + } else { + spdlog::error("[WASI-NN] PyTorch backend only supports output a tensor " + "or a list of tensor"); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + return static_cast(WASINN::ErrNo::Success); +#else + spdlog::error("[WASI-NN] PyTorch backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); #endif } else { spdlog::error("[WASI-NN] Current backend is not supported."); diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index bc6193a0..cba10683 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -26,6 +26,23 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) message(FATAL_ERROR "tensor-1x224x224x3-f32.bgr downloaded with wrong md5") endif() add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) + elseif(BACKEND MATCHES "PyTorch") + message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures") + execute_process( + COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-pytorch-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures + RESULT_VARIABLE DOWNLOAD_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE) + file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures/mobilenet.pt CHECKSUM_WEIGHT) + file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures/image-1-3-244-244.rgb CHECKSUM_IMAGE) + if(NOT CHECKSUM_WEIGHT STREQUAL "4823fc625c9ab452c64ad614eb52f2f0") + message(FATAL_ERROR "mobilenet.pt downloaded with wrong md5") + endif() + if(NOT CHECKSUM_IMAGE STREQUAL "551caa6f3b66c1d953655228462570a1") + message(FATAL_ERROR "image-1-3-244-244.rgb downloaded with wrong md5") + endif() + add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) + find_package(Torch REQUIRED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") else() # Add the other backend test files fetching here. endif() diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 0edf04c5..1c573e9c 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -13,9 +13,11 @@ #include #include +using WasmEdge::Host::WASINN::Backend; using WasmEdge::Host::WASINN::ErrNo; -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO +#if defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; @@ -77,7 +79,9 @@ std::vector classSort(const std::vector &Array) { return Indices; } } // namespace +#endif +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO TEST(WasiNNTest, OpenVINOBackend) { // Create the wasmedge_process module instance. auto *NNMod = dynamic_cast(createModule()); @@ -273,7 +277,7 @@ TEST(WasiNNTest, OpenVINOBackend) { } // Swap to the tmp. env. - NNGraphTmp.emplace_back(WasmEdge::Host::WASINN::Backend::OpenVINO); + NNGraphTmp.emplace_back(Backend::OpenVINO); NNGraphTmp.swap(NNMod->getEnv().NNGraph); NNContextTmp.swap(NNMod->getEnv().NNContext); // Test: init_execution_context -- graph id exceeds. @@ -321,7 +325,7 @@ TEST(WasiNNTest, OpenVINOBackend) { writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); // Swap to the tmp. env. - NNContextTmp.emplace_back(NNGraphTmp[0], nullptr); + NNContextTmp.emplace_back(NNGraphTmp[0]); NNGraphTmp.swap(NNMod->getEnv().NNGraph); NNContextTmp.swap(NNMod->getEnv().NNContext); // Test: set_input -- context id exceeds. @@ -479,3 +483,368 @@ TEST(WasiNNTest, OpenVINOBackend) { } } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH +TEST(WasiNNTest, PyTorchBackend) { + // Create the wasmedge_process module instance. + auto *NNMod = dynamic_cast(createModule()); + EXPECT_FALSE(NNMod == nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(400))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::vector TensorData = + readEntireFile("./wasinn_pytorch_fixtures/image-1-3-244-244.rgb"); + std::vector WeightRead = + readEntireFile("./wasinn_pytorch_fixtures/mobilenet.pt"); + + std::vector TensorDim{1, 3, 224, 224}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(410 * 65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Temp. values. + std::vector NNGraphTmp; + std::vector NNContextTmp; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // Torch WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::PyTorch), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::PyTorch), + UINT32_C(0), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), + static_cast(Backend::PyTorch), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- Torch model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, WeightRead.size(), BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::PyTorch), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- wrong builders' length. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); + writeBinaries(MemInst, WeightRead, StorePtr); + StorePtr += WeightRead.size(); + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), + static_cast(Backend::PyTorch), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- unsupported device. CPU 0, GPU 1, TPU 2 + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::PyTorch), + UINT32_C(3), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- load successfully. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::PyTorch), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: load -- load second graph. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::PyTorch), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // Torch WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Swap to the tmp. env. + NNGraphTmp.emplace_back(Backend::PyTorch); + // Test: init_execution_context -- graph id exceeds. + // TODO: not null test for pytorch now + // NNGraphTmp.swap(NNMod->getEnv().NNGraph); + // NNContextTmp.swap(NNMod->getEnv().NNContext); + // { + // EXPECT_TRUE(HostFuncInit.run( + // CallFrame, + // std::initializer_list{UINT32_C(0), + // BuilderPtr}, Errno)); + // EXPECT_EQ(Errno[0].get(), + // static_cast(ErrNo::MissingMemory)); + // } + // Swap back. + // NNGraphTmp.swap(NNMod->getEnv().NNGraph); + // NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: init_execution_context -- init context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: init_execution_context -- init second context. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(1), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // Torch WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + NNContextTmp.emplace_back(NNGraphTmp[0]); + + // Test: set_input -- tensor type not FP32. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(2), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: set_input -- set input successfully. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // Torch WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Swap to the tmp. env. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + // Test: compute -- empty context. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Swap back. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(1)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Torch WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output index exceeds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(10), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), UINT32_C(4000)); + std::vector OutputClassification( + MemInst.getPointer(StorePtr, 1000), + MemInst.getPointer(StorePtr, 1000) + 1000); + std::vector SortedIndex, CorrectClasses{954, 940, 951, 950, 953}; + SortedIndex = classSort(OutputClassification); + // The probability of class i is placed at buffer[i]. + for (size_t I = 0; I < CorrectClasses.size(); I++) { + EXPECT_EQ(SortedIndex[I], CorrectClasses[I]); + } + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH \ No newline at end of file diff --git a/utils/wasi-nn/download-pytorch-fixtures.sh b/utils/wasi-nn/download-pytorch-fixtures.sh new file mode 100644 index 00000000..bb20c471 --- /dev/null +++ b/utils/wasi-nn/download-pytorch-fixtures.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +DOWNLOAD_TO=$1 +FIXTURE=https://github.com/gusye1234/torchscript_fixtures/raw/main/mobilenet + +if [ ! -f $DOWNLOAD_TO/mobilenet.pt ]; then + wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/mobilenet.pt +fi + +if [ ! -f $DOWNLOAD_TO/image-1-3-244-244.rgb ]; then + wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/image-1-3-244-244.rgb +fi diff --git a/utils/wasi-nn/install-pytorch.sh b/utils/wasi-nn/install-pytorch.sh new file mode 100644 index 00000000..36d75519 --- /dev/null +++ b/utils/wasi-nn/install-pytorch.sh @@ -0,0 +1,5 @@ +if [ ! -d ./libtorch ]; then + curl -s -L -O --remote-name-all https://download.pytorch.org/libtorch/lts/1.8/cpu/libtorch-cxx11-abi-shared-with-deps-1.8.2%2Bcpu.zip + echo "b76d6dd4380e2233ce6f7654e672e13aae7c871231d223a4267ef018dcbfb616 libtorch-cxx11-abi-shared-with-deps-1.8.2%2Bcpu.zip" | sha256sum -c + unzip -q "libtorch-cxx11-abi-shared-with-deps-1.8.2%2Bcpu.zip" +fi From 45e9cb7d88eb11868e92112ca4f6daddd9ccb6cc Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 14 Sep 2022 21:49:59 +0800 Subject: [PATCH 077/623] [WASI] Refine the WASI-NN pytorch scripts Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 1 - test/plugins/wasi_nn/CMakeLists.txt | 7 ++-- test/plugins/wasi_nn/wasi_nn.cpp | 2 +- utils/wasi-nn/build-wasinn-ubuntu-openvino.sh | 4 +++ utils/wasi-nn/download-openvino-fixtures.sh | 2 ++ utils/wasi-nn/download-pytorch-fixtures.sh | 9 ++--- utils/wasi-nn/install-openvino.sh | 3 -- utils/wasi-nn/install-pytorch.sh | 34 ++++++++++++++++--- 8 files changed, 45 insertions(+), 17 deletions(-) mode change 100644 => 100755 utils/wasi-nn/download-pytorch-fixtures.sh mode change 100644 => 100755 utils/wasi-nn/install-pytorch.sh diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index af002928..ad84db4f 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -55,7 +55,6 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) message(STATUS "Build PyTorch backend for WASI-NN") find_package(Torch REQUIRED) add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") target_link_libraries(wasmedgePluginWasiNN PUBLIC ${TORCH_LIBRARIES} diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index cba10683..6895f53a 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -33,16 +33,15 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) RESULT_VARIABLE DOWNLOAD_ERROR OUTPUT_STRIP_TRAILING_WHITESPACE) file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures/mobilenet.pt CHECKSUM_WEIGHT) - file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures/image-1-3-244-244.rgb CHECKSUM_IMAGE) - if(NOT CHECKSUM_WEIGHT STREQUAL "4823fc625c9ab452c64ad614eb52f2f0") + file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures/image-1x3x224x224.rgb CHECKSUM_IMAGE) + if(NOT CHECKSUM_WEIGHT STREQUAL "234f446d2446e0f6fd8ed700c0b4b63b") message(FATAL_ERROR "mobilenet.pt downloaded with wrong md5") endif() if(NOT CHECKSUM_IMAGE STREQUAL "551caa6f3b66c1d953655228462570a1") - message(FATAL_ERROR "image-1-3-244-244.rgb downloaded with wrong md5") + message(FATAL_ERROR "image-1x3x224x224.rgb downloaded with wrong md5") endif() add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) find_package(Torch REQUIRED) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") else() # Add the other backend test files fetching here. endif() diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 1c573e9c..19ba3670 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -502,7 +502,7 @@ TEST(WasiNNTest, PyTorchBackend) { // Load the files. std::vector TensorData = - readEntireFile("./wasinn_pytorch_fixtures/image-1-3-244-244.rgb"); + readEntireFile("./wasinn_pytorch_fixtures/image-1x3x224x224.rgb"); std::vector WeightRead = readEntireFile("./wasinn_pytorch_fixtures/mobilenet.pt"); diff --git a/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh index 719b3599..eac9e848 100755 --- a/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh +++ b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh @@ -2,6 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2022 Second State INC +if [[ ! -v "${CMAKE_BUILD_TYPE}" ]]; then + CMAKE_BUILD_TYPE=Release +fi + source /opt/intel/openvino_2021/bin/setupvars.sh ldconfig git config --global --add safe.directory $(pwd) diff --git a/utils/wasi-nn/download-openvino-fixtures.sh b/utils/wasi-nn/download-openvino-fixtures.sh index 97e44e4b..2f57f688 100755 --- a/utils/wasi-nn/download-openvino-fixtures.sh +++ b/utils/wasi-nn/download-openvino-fixtures.sh @@ -1,4 +1,6 @@ #!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC DOWNLOAD_TO=$1 FIXTURE=https://github.com/intel/openvino-rs/raw/v0.3.3/crates/openvino/tests/fixtures/mobilenet/ diff --git a/utils/wasi-nn/download-pytorch-fixtures.sh b/utils/wasi-nn/download-pytorch-fixtures.sh old mode 100644 new mode 100755 index bb20c471..809cd22d --- a/utils/wasi-nn/download-pytorch-fixtures.sh +++ b/utils/wasi-nn/download-pytorch-fixtures.sh @@ -1,12 +1,13 @@ #!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC DOWNLOAD_TO=$1 -FIXTURE=https://github.com/gusye1234/torchscript_fixtures/raw/main/mobilenet - +FIXTURE=https://github.com/second-state/WasmEdge-WASINN-examples/raw/master/pytorch-mobilenet-image/ if [ ! -f $DOWNLOAD_TO/mobilenet.pt ]; then wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/mobilenet.pt fi -if [ ! -f $DOWNLOAD_TO/image-1-3-244-244.rgb ]; then - wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/image-1-3-244-244.rgb +if [ ! -f $DOWNLOAD_TO/image-1x3x224x224.rgb ]; then + wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/image-1x3x224x224.rgb fi diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh index 7c13f261..00d3c13b 100755 --- a/utils/wasi-nn/install-openvino.sh +++ b/utils/wasi-nn/install-openvino.sh @@ -8,9 +8,6 @@ fi if [[ ! -v "${OPENVINO_YEAR}" ]]; then OPENVINO_YEAR="2021" fi -if [[ ! -v "${CMAKE_BUILD_TYPE}" ]]; then - CMAKE_BUILD_TYPE=Release -fi set -e echo "Installing OpenVINO with version ${OPENVINO_VERSION}" diff --git a/utils/wasi-nn/install-pytorch.sh b/utils/wasi-nn/install-pytorch.sh old mode 100644 new mode 100755 index 36d75519..2c8a3a7f --- a/utils/wasi-nn/install-pytorch.sh +++ b/utils/wasi-nn/install-pytorch.sh @@ -1,5 +1,31 @@ -if [ ! -d ./libtorch ]; then - curl -s -L -O --remote-name-all https://download.pytorch.org/libtorch/lts/1.8/cpu/libtorch-cxx11-abi-shared-with-deps-1.8.2%2Bcpu.zip - echo "b76d6dd4380e2233ce6f7654e672e13aae7c871231d223a4267ef018dcbfb616 libtorch-cxx11-abi-shared-with-deps-1.8.2%2Bcpu.zip" | sha256sum -c - unzip -q "libtorch-cxx11-abi-shared-with-deps-1.8.2%2Bcpu.zip" +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +if [[ ! -n ${PYTORCH_VERSION} ]]; then + PYTORCH_VERSION="1.8.2" +fi + +if [[ ! -n ${PYTORCH_INSTALL_TO} ]]; then + PYTORCH_INSTALL_TO=. +fi + +PYTORCH_LINK="libtorch-cxx11-abi" +PYTORCH_SHA="b76d6dd4380e2233ce6f7654e672e13aae7c871231d223a4267ef018dcbfb616" + +for i in "$@"; do + case $i in + --disable-cxx11-abi) + PYTORCH_LINK="libtorch" + PYTORCH_SHA="b5ddadc9addc054d8503f4086546f0cbcfdc3fc70087863bbd7b0e3300e3247f" + shift + ;; + esac +done + +if [ ! -d ${PYTORCH_INSTALL_TO}/libtorch ]; then + curl -s -L -O --remote-name-all https://download.pytorch.org/libtorch/lts/1.8/cpu/${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip + echo "${PYTORCH_SHA} ${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" | sha256sum -c + unzip -q "${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" -d ${PYTORCH_INSTALL_TO} + rm -f "${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" fi From fe597d1e5e55030215109f0c9bee1ad3acc002c3 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Mon, 6 Jun 2022 18:29:27 +0800 Subject: [PATCH 078/623] [Plugin] Add plugin in c interface Signed-off-by: Shen-Ta Hsieh --- plugins/CMakeLists.txt | 2 + plugins/test/CMakeLists.txt | 28 +++++++++ plugins/test/test.c | 75 +++++++++++++++++++++++ plugins/wasi_crypto/ctx.cpp | 15 +++-- plugins/wasi_nn/wasinnenv.cpp | 3 +- plugins/wasmedge_httpsreq/httpsreqenv.cpp | 3 +- plugins/wasmedge_process/processenv.cpp | 6 +- 7 files changed, 123 insertions(+), 9 deletions(-) create mode 100644 plugins/test/CMakeLists.txt create mode 100644 plugins/test/test.c diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 5bc0f845..2e878ecc 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -19,3 +19,5 @@ endif() if (WASMEDGE_PLUGIN_WASI_CRYPTO) add_subdirectory(wasi_crypto) endif() + +add_subdirectory(test) diff --git a/plugins/test/CMakeLists.txt b/plugins/test/CMakeLists.txt new file mode 100644 index 00000000..abc1ba64 --- /dev/null +++ b/plugins/test/CMakeLists.txt @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_library(wasmedgePluginTest + SHARED + test.c +) + +set_target_properties(wasmedgePluginTest PROPERTIES + C_STANDARD 11 +) + +target_compile_options(wasmedgePluginTest + PUBLIC + -DWASMEDGE_PLUGIN +) + +if(WASMEDGE_LINK_PUGLINS_STATIC) + target_link_libraries(wasmedgePluginTest + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginTest + PRIVATE + wasmedge_shared + ) +endif() diff --git a/plugins/test/test.c b/plugins/test/test.c new file mode 100644 index 00000000..bf8e95d6 --- /dev/null +++ b/plugins/test/test.c @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "wasmedge/wasmedge.h" +#include + +static int32_t TestingOption; +static const int32_t TestingOptionDefaultValue = 42; + +static WasmEdge_Result Test(void *Data __attribute__((unused)), + const WasmEdge_CallingFrameContext *CallFrameCxt + __attribute__((unused)), + const WasmEdge_Value *In __attribute__((unused)), + WasmEdge_Value *Out __attribute__((unused))) { + return WasmEdge_Result_Success; +} + +static WasmEdge_ModuleInstanceContext * +CreateTestModule(const struct WasmEdge_ModuleDescriptor *Desc + __attribute__((unused))) { + WasmEdge_ModuleInstanceContext *Mod; + + { + WasmEdge_String ModuleName = WasmEdge_StringCreateByCString("test"); + Mod = WasmEdge_ModuleInstanceCreate(ModuleName); + WasmEdge_StringDelete(ModuleName); + } + + { + WasmEdge_FunctionTypeContext *FType = + WasmEdge_FunctionTypeCreate(NULL, 0, NULL, 0); + WasmEdge_FunctionInstanceContext *Func = + WasmEdge_FunctionInstanceCreate(FType, Test, NULL, 0); + WasmEdge_FunctionTypeDelete(FType); + WasmEdge_String FName = WasmEdge_StringCreateByCString("test"); + WasmEdge_ModuleInstanceAddFunction(Mod, FName, Func); + WasmEdge_StringDelete(FName); + } + + return Mod; +} + +static WasmEdge_ProgramOption PODesc[] = {{ + .Name = "test", + .Description = "testing option", + .Type = WasmEdge_ProgramOptionType_Int32, + .Storage = &TestingOption, + .DefaultValue = &TestingOptionDefaultValue, +}}; +static WasmEdge_ModuleDescriptor ModuleDesc[] = {{ + .Name = "test", + .Description = "testing module", + .Create = CreateTestModule, +}}; +static WasmEdge_PluginDescriptor Desc[] = {{ + .Name = "test", + .Description = "testing plugin", + .APIVersion = WasmEdge_Plugin_CurrentAPIVersion, + .Version = + { + .Major = 0, + .Minor = 0, + .Patch = 0, + .Build = 0, + }, + .ModuleCount = 1, + .ProgramOptionCount = 1, + .ModuleDescriptions = ModuleDesc, + .ProgramOptions = PODesc, +}}; + +WASMEDGE_CAPI_PLUGIN_EXPORT const WasmEdge_PluginDescriptor * +WasmEdge_Plugin_GetDescriptor(void) { + return Desc; +} diff --git a/plugins/wasi_crypto/ctx.cpp b/plugins/wasi_crypto/ctx.cpp index f233a504..7ff586f6 100644 --- a/plugins/wasi_crypto/ctx.cpp +++ b/plugins/wasi_crypto/ctx.cpp @@ -13,20 +13,25 @@ namespace Host { namespace { -Runtime::Instance::ModuleInstance *createAsymmetricCommon(void) noexcept { +Runtime::Instance::ModuleInstance *createAsymmetricCommon( + const Plugin::PluginModule::ModuleDescriptor *) noexcept { return new WasiCryptoAsymmetricCommonModule( WasiCrypto::Context::getInstance()); } -Runtime::Instance::ModuleInstance *createCommon(void) noexcept { +Runtime::Instance::ModuleInstance * +createCommon(const Plugin::PluginModule::ModuleDescriptor *) noexcept { return new WasiCryptoCommonModule(WasiCrypto::Context::getInstance()); } -Runtime::Instance::ModuleInstance *createKx(void) noexcept { +Runtime::Instance::ModuleInstance * +createKx(const Plugin::PluginModule::ModuleDescriptor *) noexcept { return new WasiCryptoKxModule(WasiCrypto::Context::getInstance()); } -Runtime::Instance::ModuleInstance *createSignatures(void) noexcept { +Runtime::Instance::ModuleInstance * +createSignatures(const Plugin::PluginModule::ModuleDescriptor *) noexcept { return new WasiCryptoSignaturesModule(WasiCrypto::Context::getInstance()); } -Runtime::Instance::ModuleInstance *createSymmetric(void) noexcept { +Runtime::Instance::ModuleInstance * +createSymmetric(const Plugin::PluginModule::ModuleDescriptor *) noexcept { return new WasiCryptoSymmetricModule(WasiCrypto::Context::getInstance()); } diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 148c8e5b..2ad93957 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -9,7 +9,8 @@ namespace Host { namespace { -Runtime::Instance::ModuleInstance *create(void) noexcept { +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { return new WasiNNModule; } diff --git a/plugins/wasmedge_httpsreq/httpsreqenv.cpp b/plugins/wasmedge_httpsreq/httpsreqenv.cpp index 3587ed27..857d93a8 100644 --- a/plugins/wasmedge_httpsreq/httpsreqenv.cpp +++ b/plugins/wasmedge_httpsreq/httpsreqenv.cpp @@ -9,7 +9,8 @@ namespace Host { namespace { -Runtime::Instance::ModuleInstance *create(void) noexcept { +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { return new WasmEdgeHttpsReqModule; } diff --git a/plugins/wasmedge_process/processenv.cpp b/plugins/wasmedge_process/processenv.cpp index 483c3fdd..8c769a33 100644 --- a/plugins/wasmedge_process/processenv.cpp +++ b/plugins/wasmedge_process/processenv.cpp @@ -27,13 +27,15 @@ WasmEdgeProcessEnvironment::WasmEdgeProcessEnvironment() noexcept namespace { -void addOptions(PO::ArgumentParser &Parser) noexcept { +void addOptions(const Plugin::Plugin::PluginDescriptor *, + PO::ArgumentParser &Parser) noexcept { Parser.add_option("allow-command"sv, WasmEdgeProcessEnvironment::AllowCmd) .add_option("allow-command-all"sv, WasmEdgeProcessEnvironment::AllowCmdAll); } -Runtime::Instance::ModuleInstance *create(void) noexcept { +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { return new WasmEdgeProcessModule; } From acaeff72e686bf09789e85c09c25c0060666bacd Mon Sep 17 00:00:00 2001 From: Gustavo Ye Date: Thu, 20 Oct 2022 13:51:01 +0800 Subject: [PATCH 079/623] [WASI] Support tensorflow lite backend for WASI-NN (#1846) Signed-off-by: Gustavo Ye --- plugins/wasi_nn/CMakeLists.txt | 112 ++++++- plugins/wasi_nn/wasinnenv.h | 28 ++ plugins/wasi_nn/wasinnfunc.cpp | 244 ++++++++++++-- test/plugins/wasi_nn/CMakeLists.txt | 19 ++ test/plugins/wasi_nn/wasi_nn.cpp | 369 +++++++++++++++++++++- utils/wasi-nn/download-tflite-fixtures.sh | 13 + utils/wasi-nn/install-pytorch.sh | 18 +- 7 files changed, 762 insertions(+), 41 deletions(-) create mode 100644 utils/wasi-nn/download-tflite-fixtures.sh diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index ad84db4f..ff83409a 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -43,24 +43,126 @@ install(TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedg # Add backends building flags. foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) - if(BACKEND MATCHES "OpenVINO") - message(STATUS "Build OpenVINO backend for WASI-NN") + if(BACKEND STREQUAL "OpenVINO") + message(STATUS "WASI-NN: Build OpenVINO backend for WASI-NN") find_package(InferenceEngine REQUIRED) add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) target_link_libraries(wasmedgePluginWasiNN PUBLIC ${InferenceEngine_LIBRARIES} ) - elseif(BACKEND MATCHES "PyTorch") - message(STATUS "Build PyTorch backend for WASI-NN") + elseif(BACKEND STREQUAL "PyTorch") + message(STATUS "WASI-NN: Build PyTorch backend for WASI-NN") find_package(Torch REQUIRED) add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) target_link_libraries(wasmedgePluginWasiNN PUBLIC ${TORCH_LIBRARIES} ) + elseif(BACKEND STREQUAL "Tensorflowlite") + message(STATUS "WASI-NN: Build Tensorflow lite backend for WASI-NN") + add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) + + if(NOT WASMEDGE_DEPS_VERSION) + set(WASMEDGE_DEPS_VERSION "0.11.0-rc.1") + endif() + + # Clone required shared libraries + if(ANDROID) + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + set(WASMEDGE_TENSORFLOW_SYSTEM_NAME "android_aarch64") + set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_HASH "a25dafad049cbc998c1f9682c57aec22b2fe5799eeffdd4ed19793a734cde8a4") + elseif() + message(FATAL_ERROR "Unsupported architecture: ${CMAKE_SYSTEM_PROCESSOR}") + endif() + elseif(APPLE) + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64") + set(WASMEDGE_TENSORFLOW_SYSTEM_NAME "darwin_x86_64") + set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_HASH "2593772df440a768e79d87e74a860378f46fb0b7d1e7805879ab2ec26a093b57") + else() + message(FATAL_ERROR "Unsupported architecture: ${CMAKE_SYSTEM_PROCESSOR}") + endif() + elseif(UNIX) + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64") + set(WASMEDGE_TENSORFLOW_SYSTEM_NAME "manylinux2014_x86_64") + set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_HASH "43b2a782efb58b047c6d33f64d7ac711b24426959f91287d910edb8937c11dea") + elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + set(WASMEDGE_TENSORFLOW_SYSTEM_NAME "manylinux2014_aarch64") + set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_HASH "1f47dcd05f32907848253e0f4b0eb3a6276802dae41d2b7de61214b75ba02395") + else() + message(FATAL_ERROR "Unsupported architecture: ${CMAKE_SYSTEM_PROCESSOR}") + endif() + else() + message(FATAL_ERROR "Unsupported system: ${CMAKE_SYSTEM_NAME}") + endif() + + include(FetchContent) + + # Fetch Tensorflow-lite library. + FetchContent_Declare( + wasmedgetensorflowdepslite + URL "https://github.com/second-state/WasmEdge-tensorflow-deps/releases/download/${WASMEDGE_DEPS_VERSION}/WasmEdge-tensorflow-deps-TFLite-${WASMEDGE_DEPS_VERSION}-${WASMEDGE_TENSORFLOW_SYSTEM_NAME}.tar.gz" + URL_HASH "SHA256=${WASMEDGE_TENSORFLOW_DEPS_TFLITE_HASH}" + ) + FetchContent_GetProperties(wasmedgetensorflowdepslite) + + if(NOT wasmedgetensorflowdepslite_POPULATED) + message(STATUS "Downloading dependency: libtensorflowlite") + FetchContent_Populate(wasmedgetensorflowdepslite) + message(STATUS "Downloading dependency: libtensorflowlite - done") + endif() + + # Setup Tensorflow-lite library. + if(APPLE) + set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_LIB + "${wasmedgetensorflowdepslite_SOURCE_DIR}/libtensorflowlite_c.dylib" + ) + elseif(UNIX) + set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_LIB + "${wasmedgetensorflowdepslite_SOURCE_DIR}/libtensorflowlite_c.so" + ) + endif() + + include(FetchContent) + FetchContent_Declare( + wasmedge_tensorflow_deps + GIT_REPOSITORY https://github.com/second-state/WasmEdge-tensorflow-deps.git + GIT_TAG ${WASMEDGE_DEPS_VERSION} + ) + FetchContent_GetProperties(wasmedge_tensorflow_deps) + + if(NOT wasmedge_tensorflow_deps_POPULATED) + message(STATUS "Fetching WasmEdge-tensorflow-dep repository") + FetchContent_Populate(wasmedge_tensorflow_deps) + message(STATUS "Fetching WasmEdge-tensorflow-dep repository - done") + endif() + + # -hardcode remove in future------------------------------- + # set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_HASH "43b2a782efb58b047c6d33f64d7ac711b24426959f91287d910edb8937c11dea") + # set(wasmedge_tensorflow_deps_SOURCE_DIR "${CMAKE_SOURCE_DIR}/utils/WasmEdge-tensorflow-deps") + + # set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_LIB + # "${wasmedge_tensorflow_deps_SOURCE_DIR}/shares/libtensorflowlite_c.so" + # ) + + # -------------------------------- + set(WASMEDGE_TENSORFLOW_DEPS_PATH ${wasmedge_tensorflow_deps_SOURCE_DIR}) + set(WASMEDGE_TENSORFLOW_DEPS_BIN_PATH ${CMAKE_CURRENT_BINARY_DIR}/utils/WasmEdge-tensorflow-deps) + + message(STATUS "WASI-NN: Set WasmEdge-tensorflow deps source path: ${WASMEDGE_TENSORFLOW_DEPS_PATH}") + message(STATUS "WASI-NN: Set WasmEdge-tensorflow deps binary path: ${WASMEDGE_TENSORFLOW_DEPS_BIN_PATH}") + message(STATUS "WASI-NN: Set WasmEdge-tensorflowlite share path: ${WASMEDGE_TENSORFLOW_DEPS_TFLITE_LIB}") + add_subdirectory(${WASMEDGE_TENSORFLOW_DEPS_PATH} ${WASMEDGE_TENSORFLOW_DEPS_BIN_PATH}) + target_include_directories(wasmedgePluginWasiNN + PUBLIC + ${TENSORFLOW_INCLUDE} + ) + target_link_libraries(wasmedgePluginWasiNN + PUBLIC + ${WASMEDGE_TENSORFLOW_DEPS_TFLITE_LIB} + ) else() # Add the other backends here. - message(FATAL_ERROR "WASI-NN backend ${BACKEND} not found or unimplemented.") + message(FATAL_ERROR "WASI-NN: backend ${BACKEND} not found or unimplemented.") endif() endforeach() diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index eb525836..8352470f 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -15,6 +15,10 @@ #include #endif +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE +#include "tensorflow/lite/c/c_api.h" +#endif + namespace WasmEdge { namespace Host { namespace WASINN { @@ -26,9 +30,13 @@ enum class ErrNo : uint32_t { Busy = 3 // Device or resource busy. }; +enum class TensorType : uint32_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; + enum class Backend : uint8_t { OpenVINO = 0, PyTorch = 1, + Tensorflow = 2, + TensorflowLite = 3 }; class Graph { @@ -62,6 +70,14 @@ class Graph { ie_network_name_free(&I); } } +#endif +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE + if (TFLiteMod) { + TfLiteModelDelete(TFLiteMod); + } + if (TFLiteOps) { + TfLiteInterpreterOptionsDelete(TFLiteOps); + } #endif } @@ -76,6 +92,10 @@ class Graph { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH torch::jit::Module TorchModel; #endif +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE + TfLiteModel *TFLiteMod = nullptr; + TfLiteInterpreterOptions *TFLiteOps = nullptr; +#endif }; class Context { @@ -100,6 +120,11 @@ class Context { if (OpenVINOInferRequest) { ie_infer_request_free(&OpenVINOInferRequest); } +#endif +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE + if (TFLiteInterp) { + TfLiteInterpreterDelete(TFLiteInterp); + } #endif } @@ -111,6 +136,9 @@ class Context { std::vector TorchInputs; std::vector TorchOutputs; #endif +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE + TfLiteInterpreter *TFLiteInterp = nullptr; +#endif }; class WasiNNEnvironment { diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index b55ba289..96dc418d 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -18,11 +18,15 @@ #include #endif +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE +#include "tensorflow/lite/c/c_api.h" +#endif + namespace WasmEdge { namespace Host { namespace { -std::string FindDevice(const uint32_t Target) { +[[maybe_unused]] std::string FindDevice(const uint32_t Target) { std::string DeviceName; switch (Target) { case 0: @@ -39,13 +43,13 @@ std::string FindDevice(const uint32_t Target) { } return DeviceName; } + } // namespace Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr [[maybe_unused]], uint32_t BuilderLen [[maybe_unused]], - uint32_t Encoding, - uint32_t Target [[maybe_unused]], + uint32_t Encoding, uint32_t Target, uint32_t GraphIdPtr [[maybe_unused]]) { // Check memory instance from module. auto *MemInst = Frame.getMemoryByIndex(0); @@ -58,6 +62,15 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, spdlog::error("[WASI-NN] Failed when accessing the return GraphID memory."); return static_cast(WASINN::ErrNo::InvalidArgument); } + // Get and check the device name string. + std::string DeviceName; + DeviceName = FindDevice(Target); + if (unlikely(DeviceName.length() == 0)) { + spdlog::error("[WASI-NN] Only support CPU target"); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + spdlog::debug("[WASI-NN] Using device: {:s}", DeviceName); + if (Encoding == static_cast(WASINN::Backend::OpenVINO)) { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO // The OpenVINO core must be initialized in constructor. @@ -73,15 +86,6 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, return static_cast(WASINN::ErrNo::InvalidArgument); } - // Get and check the device name string. - std::string DeviceName; - DeviceName = FindDevice(Target); - if (DeviceName.length() == 0) { - spdlog::error("[WASI-NN] PyTorch backend only support CPU target"); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - spdlog::debug("[WASI-NN] Using device: {:s}", DeviceName); - // Get the graph builders. // GraphBuilders' Layout: // | builder-0 | builder-0 len | builder-1 | builder-1 len | ... @@ -248,13 +252,6 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, #endif } else if (Encoding == static_cast(WASINN::Backend::PyTorch)) { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH - std::string DeviceName; - DeviceName = FindDevice(Target); - if (DeviceName.length() == 0) { - spdlog::error("[WASI-NN] PyTorch backend only support CPU target"); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - // The graph builder length must be 2. if (BuilderLen != 1) { spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1", @@ -296,6 +293,48 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, spdlog::error("[WASI-NN] PyTorch backend is not built. use " "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + } else if (Encoding == + static_cast(WASINN::Backend::TensorflowLite)) { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE + // The graph builder length must be 1. + if (BuilderLen != 1) { + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1", + BuilderLen); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t *GraphBuilders = + MemInst->getPointer(BuilderPtr, BuilderLen * 2); + if (unlikely(GraphBuilders == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the GraphBuilder memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t BinLen = GraphBuilders[1]; + char *BinPtr = MemInst->getPointer(GraphBuilders[0], BinLen); + if (unlikely(BinPtr == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Weight memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + // Add a new graph. + Env.NNGraph.emplace_back(static_cast(Encoding)); + auto &Graph = Env.NNGraph.back(); + + Graph.TFLiteMod = TfLiteModelCreate(BinPtr, BinLen); + if (unlikely(Graph.TFLiteMod == nullptr)) { + spdlog::error("[WASI-NN] Cannot import TFLite model"); + Env.NNGraph.pop_back(); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + Graph.TFLiteOps = TfLiteInterpreterOptionsCreate(); + TfLiteInterpreterOptionsSetNumThreads(Graph.TFLiteOps, 2); + + // Store the loaded graph. + *GraphId = Env.NNGraph.size() - 1; + return static_cast(WASINN::ErrNo::Success); +#else + spdlog::error( + "[WASI-NN] TensorflowLite backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."); +#endif } else { spdlog::error("[WASI-NN] Current backend is not supported."); } @@ -354,6 +393,35 @@ Expect WasiNNInitExecCtx::body(const Runtime::CallingFrame &Frame, #else spdlog::error("[WASI-NN] PyTorch backend is not built. use " "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); +#endif + } else if (Env.NNGraph[GraphId].GraphBackend == + WASINN::Backend::TensorflowLite) { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE + // Check the network and the execution network with the graph ID. + if (Env.NNGraph[GraphId].TFLiteMod == nullptr || + Env.NNGraph[GraphId].TFLiteOps == nullptr) { + spdlog::error("[WASI-NN] Model for Graph:{} is missing!", GraphId); + return static_cast(WASINN::ErrNo::MissingMemory); + } + + Env.NNContext.emplace_back(Env.NNGraph[GraphId]); + const auto Graph = Env.NNGraph[GraphId]; + auto &NewContext = Env.NNContext.back(); + NewContext.TFLiteInterp = + TfLiteInterpreterCreate(Graph.TFLiteMod, Graph.TFLiteOps); + if (unlikely(NewContext.TFLiteInterp == nullptr)) { + spdlog::error("[WASI-NN] Cannot create TFLite interpreter."); + Env.NNContext.pop_back(); + return static_cast(WASINN::ErrNo::Busy); + } + TfLiteInterpreterAllocateTensors(NewContext.TFLiteInterp); + + *Context = Env.NNContext.size() - 1; + return static_cast(WASINN::ErrNo::Success); +#else + spdlog::error( + "[WASI-NN] TensorflowLite backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."); #endif } else { spdlog::error("[WASI-NN] Current backend is not supported."); @@ -423,8 +491,8 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."); return static_cast(WASINN::ErrNo::InvalidArgument); } - uint32_t RType = Tensor[2]; - if (RType != 1) { + WASINN::TensorType RType = static_cast(Tensor[2]); + if (RType != WASINN::TensorType::F32) { spdlog::error( "[WASI-NN] Only F32 inputs and outputs are supported for now."); return static_cast(WASINN::ErrNo::InvalidArgument); @@ -551,8 +619,8 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."); return static_cast(WASINN::ErrNo::InvalidArgument); } - uint32_t RType = Tensor[2]; - if (RType != 1) { + WASINN::TensorType RType = static_cast(Tensor[2]); + if (RType != WASINN::TensorType::F32) { spdlog::error( "[WASI-NN] Only F32 inputs and outputs are supported for now."); return static_cast(WASINN::ErrNo::InvalidArgument); @@ -571,6 +639,82 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, #else spdlog::error("[WASI-NN] PyTorch backend is not built. use " "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); +#endif + } else if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::TensorflowLite) { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE + uint32_t InCnt = TfLiteInterpreterGetInputTensorCount(CxtRef.TFLiteInterp); + if (Index >= InCnt) { + spdlog::error("[WASI-NN] Invalid index id {} for the input, only {} " + "inputs are allowed", + Index, InCnt); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t *Tensor = MemInst->getPointer(TensorPtr, 5); + if (unlikely(Tensor == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Tensor memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t DimensionLen = Tensor[1]; + std::vector TFDimension(DimensionLen); + uint32_t *DimensionBuf = + MemInst->getPointer(Tensor[0], DimensionLen); + for (uint32_t I = 0; I < DimensionLen; I++) { + TFDimension.push_back(static_cast(DimensionBuf[I])); + } + if (unlikely(DimensionBuf == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + uint32_t TensorDataLen = Tensor[4]; + uint32_t TensorType = Tensor[2]; + uint8_t *TensorDataBuf = + MemInst->getPointer(Tensor[3], TensorDataLen); + if (unlikely(TensorDataBuf == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + auto *HoldTensor = + TfLiteInterpreterGetInputTensor(CxtRef.TFLiteInterp, Index); + TfLiteType LiteType = TfLiteTensorType(HoldTensor); + WASINN::TensorType NNType; + + switch (LiteType) { + case TfLiteType::kTfLiteUInt8: + NNType = WASINN::TensorType::U8; + break; + case TfLiteType::kTfLiteFloat16: + NNType = WASINN::TensorType::F16; + break; + case TfLiteType::kTfLiteFloat32: + NNType = WASINN::TensorType::F32; + break; + case TfLiteType::kTfLiteInt32: + NNType = WASINN::TensorType::I32; + break; + default: + spdlog::error("[WASI-NN] Unsupported TFLite type: {}", LiteType); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + + if (unlikely(TensorType != static_cast(NNType))) { + spdlog::error("[WASI-NN] Expect tensor type {}, but got {}", + static_cast(NNType), TensorType); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + TfLiteStatus Stat = + TfLiteTensorCopyFromBuffer(HoldTensor, TensorDataBuf, TensorDataLen); + if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { + spdlog::error("[WASI-NN] Copy tensor memory failed"); + return static_cast(WASINN::ErrNo::Busy); + } + + return static_cast(WASINN::ErrNo::Success); + +#else + spdlog::error( + "[WASI-NN] TensorflowLite backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."); #endif } else { spdlog::error("[WASI-NN] Current backend is not supported."); @@ -707,6 +851,42 @@ WasiNNGetOuput::body(const Runtime::CallingFrame &Frame, uint32_t Context, #else spdlog::error("[WASI-NN] PyTorch backend is not built. use " "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); +#endif + } else if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::TensorflowLite) { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE + uint32_t OutCnt = + TfLiteInterpreterGetOutputTensorCount(CxtRef.TFLiteInterp); + if (Index >= OutCnt) { + spdlog::error("[WASI-NN] Invalid index id {} for the input, only {} " + "outputs are allowed", + Index, OutCnt); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + const TfLiteTensor *HoldTensor = + TfLiteInterpreterGetOutputTensor(CxtRef.TFLiteInterp, Index); + const uint32_t BlobSize = TfLiteTensorByteSize(HoldTensor); + uint32_t BytesToWrite = std::min(BlobSize, OutBufferMaxSize); + uint8_t *OutBuffer = + MemInst->getPointer(OutBufferPtr, BytesToWrite); + if (unlikely(OutBuffer == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the Output Buffer memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + TfLiteTensorCopyToBuffer(HoldTensor, OutBuffer, BytesToWrite); + uint32_t *BytesWritten = + MemInst->getPointer(BytesWrittenPtr, 1); + if (unlikely(BytesWritten == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the BytesWritten memory."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + *BytesWritten = BytesToWrite; + return static_cast(WASINN::ErrNo::Success); + +#else + spdlog::error( + "[WASI-NN] Tensorflowlite backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."); #endif } else { spdlog::error("[WASI-NN] Current backend is not supported."); @@ -774,6 +954,24 @@ Expect WasiNNCompute::body(const Runtime::CallingFrame &Frame, #else spdlog::error("[WASI-NN] PyTorch backend is not built. use " "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); +#endif + } else if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::TensorflowLite) { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE + // Run session + if (unlikely(CxtRef.TFLiteInterp == nullptr)) { + spdlog::error("[WASI-NN] Tensorflow Lite context empty"); + return static_cast(WASINN::ErrNo::MissingMemory); + } + TfLiteStatus Stat = TfLiteInterpreterInvoke(CxtRef.TFLiteInterp); + if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { + spdlog::error("[WASI-NN] Invokation failed."); + return static_cast(WASINN::ErrNo::Busy); + } + return static_cast(WASINN::ErrNo::Success); +#else + spdlog::error( + "[WASI-NN] Tensorflowlite backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."); #endif } else { spdlog::error("[WASI-NN] Current backend is not supported."); diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 6895f53a..e3d55a9a 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -42,6 +42,25 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) endif() add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) find_package(Torch REQUIRED) + elseif(BACKEND STREQUAL "Tensorflowlite") + message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures") + execute_process( + COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-tflite-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures + RESULT_VARIABLE DOWNLOAD_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE) + file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures/lite-model_aiy_vision_classifier_birds_V1_3.tflite CHECKSUM_WEIGHT) + file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures/birdx224x224x3.rgb CHECKSUM_IMAGE) + if(NOT CHECKSUM_WEIGHT STREQUAL "3e59cc3a99afeeb819c2c38b319a7938") + message(FATAL_ERROR "downloaded tflite model with wrong md5") + endif() + if(NOT CHECKSUM_IMAGE STREQUAL "ad51c39cfe35d2ef35c4052b78cb3c55") + message(FATAL_ERROR "downloaded bird.jpg fixture with wrong md5") + endif() + add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) + target_include_directories(wasiNNTests + PUBLIC + ${TENSORFLOW_INCLUDE} + ) else() # Add the other backend test files fetching here. endif() diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 19ba3670..d6f74924 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -17,7 +17,8 @@ using WasmEdge::Host::WASINN::Backend; using WasmEdge::Host::WASINN::ErrNo; #if defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) || \ - defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; @@ -72,9 +73,9 @@ std::vector classSort(const std::vector &Array) { std::vector Indices(Array.size()); std::iota(Indices.begin(), Indices.end(), 0); std::sort(Indices.begin(), Indices.end(), - [&Array](int Left, int Right) -> bool { + [&Array](size_t Left, size_t Right) -> bool { // Sort indices according to corresponding array element. - return Array[Left] >= Array[Right]; + return Array[Left] > Array[Right]; }); return Indices; } @@ -847,4 +848,364 @@ TEST(WasiNNTest, PyTorchBackend) { } } } -#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH \ No newline at end of file +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE +TEST(WasiNNTest, TFLiteBackend) { + // Create the wasmedge_process module instance. + auto *NNMod = dynamic_cast(createModule()); + EXPECT_FALSE(NNMod == nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(400))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::vector TensorData = + readEntireFile("./wasinn_tflite_fixtures/birdx224x224x3.rgb"); + std::vector WeightRead = + readEntireFile("./wasinn_tflite_fixtures/" + "lite-model_aiy_vision_classifier_birds_V1_3.tflite"); + spdlog::info("Read {}", TensorData.size()); + std::vector TensorDim{1, 3, 224, 224}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(410 * 65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Temp. values. + std::vector NNGraphTmp; + std::vector NNContextTmp; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // Torch WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + UINT32_C(0), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, WeightRead.size(), BuilderPtr); + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- wrong builders' length. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); + writeBinaries(MemInst, WeightRead, StorePtr); + StorePtr += WeightRead.size(); + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), + static_cast(Backend::TensorflowLite), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- unsupported device. CPU 0, GPU 1, TPU 2 + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + UINT32_C(3), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- load successfully. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: load -- load second graph. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::TensorflowLite), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Swap to the tmp. env. + // Test: init_execution_context -- graph id exceeds. + NNGraphTmp.emplace_back(Backend::TensorflowLite); + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::MissingMemory)); + } + // Swap back. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: init_execution_context -- init context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: init_execution_context -- init second context. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(1), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // Torch WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + NNContextTmp.emplace_back(NNGraphTmp[0]); + + // Test: set_input -- set input successfully. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + // Tensor type U8 + writeUInt32(MemInst, UINT32_C(2), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // Torch WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Swap to the tmp. env. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + // Test: compute -- empty context. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::MissingMemory)); + } + // Swap back. + NNGraphTmp.swap(NNMod->getEnv().NNGraph); + NNContextTmp.swap(NNMod->getEnv().NNContext); + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(1)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + // WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output index exceeds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(10), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), UINT32_C(965)); + std::vector OutputClassification( + MemInst.getPointer(StorePtr, 965), + MemInst.getPointer(StorePtr, 965) + 965); + std::vector SortedIndex, CorrectClasses{166, 158, 34, 778, 819}; + // FIXME: classSort causing segmentation fault + SortedIndex = classSort(OutputClassification); + + // The probability of class i is placed at buffer[i]. + for (size_t I = 0; I < CorrectClasses.size(); I++) { + EXPECT_EQ(OutputClassification[SortedIndex[I]], + OutputClassification[CorrectClasses[I]]); + } + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE \ No newline at end of file diff --git a/utils/wasi-nn/download-tflite-fixtures.sh b/utils/wasi-nn/download-tflite-fixtures.sh new file mode 100644 index 00000000..e74e6511 --- /dev/null +++ b/utils/wasi-nn/download-tflite-fixtures.sh @@ -0,0 +1,13 @@ +TODIR=$1 + +FIXTURE=https://raw.githubusercontent.com/gusye1234/WasmEdge-WASINN-examples/demo-tflite-image/tflite-birds_v1-image +if [ ! -d $TODIR ]; then + mkdir $TODIR +fi + +if [ ! -f $TODIR/lite-model_aiy_vision_classifier_birds_V1_3.tflite ]; then + curl -o $TODIR/lite-model_aiy_vision_classifier_birds_V1_3.tflite $FIXTURE/lite-model_aiy_vision_classifier_birds_V1_3.tflite +fi +if [ ! -f $TODIR/birdx224x224x3.rgb ]; then + curl -o $TODIR/birdx224x224x3.rgb $FIXTURE/birdx224x224x3.rgb +fi diff --git a/utils/wasi-nn/install-pytorch.sh b/utils/wasi-nn/install-pytorch.sh index 2c8a3a7f..faec86f2 100755 --- a/utils/wasi-nn/install-pytorch.sh +++ b/utils/wasi-nn/install-pytorch.sh @@ -15,17 +15,17 @@ PYTORCH_SHA="b76d6dd4380e2233ce6f7654e672e13aae7c871231d223a4267ef018dcbfb616" for i in "$@"; do case $i in - --disable-cxx11-abi) - PYTORCH_LINK="libtorch" - PYTORCH_SHA="b5ddadc9addc054d8503f4086546f0cbcfdc3fc70087863bbd7b0e3300e3247f" - shift - ;; + --disable-cxx11-abi) + PYTORCH_LINK="libtorch" + PYTORCH_SHA="b5ddadc9addc054d8503f4086546f0cbcfdc3fc70087863bbd7b0e3300e3247f" + shift + ;; esac done if [ ! -d ${PYTORCH_INSTALL_TO}/libtorch ]; then - curl -s -L -O --remote-name-all https://download.pytorch.org/libtorch/lts/1.8/cpu/${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip - echo "${PYTORCH_SHA} ${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" | sha256sum -c - unzip -q "${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" -d ${PYTORCH_INSTALL_TO} - rm -f "${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" + curl -s -L -O --remote-name-all https://download.pytorch.org/libtorch/lts/1.8/cpu/${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip + echo "${PYTORCH_SHA} ${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" | sha256sum -c + unzip -q "${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" -d ${PYTORCH_INSTALL_TO} + rm -f "${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" fi From e63ced7d3d4362c45ab3c2fa52212679c220ae9c Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 20 Oct 2022 14:37:38 +0800 Subject: [PATCH 080/623] [CI] Refine the plugin build and release CI for WASI-NN. Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 19 ++++++----------- utils/wasi-nn/download-openvino-fixtures.sh | 23 ++++++++++++++------- utils/wasi-nn/download-pytorch-fixtures.sh | 16 +++++++++----- utils/wasi-nn/download-tflite-fixtures.sh | 12 ++++++++--- 4 files changed, 42 insertions(+), 28 deletions(-) mode change 100644 => 100755 utils/wasi-nn/download-tflite-fixtures.sh diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index ff83409a..dae0cb94 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -43,7 +43,8 @@ install(TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedg # Add backends building flags. foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) - if(BACKEND STREQUAL "OpenVINO") + string(TOLOWER ${BACKEND} BACKEND) + if(BACKEND STREQUAL "openvino") message(STATUS "WASI-NN: Build OpenVINO backend for WASI-NN") find_package(InferenceEngine REQUIRED) add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) @@ -51,7 +52,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) PUBLIC ${InferenceEngine_LIBRARIES} ) - elseif(BACKEND STREQUAL "PyTorch") + elseif(BACKEND STREQUAL "pytorch") message(STATUS "WASI-NN: Build PyTorch backend for WASI-NN") find_package(Torch REQUIRED) add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) @@ -59,12 +60,13 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) PUBLIC ${TORCH_LIBRARIES} ) - elseif(BACKEND STREQUAL "Tensorflowlite") + elseif(BACKEND STREQUAL "tensorflowlite") message(STATUS "WASI-NN: Build Tensorflow lite backend for WASI-NN") + # TODO: Move these complicated steps into a helper cmake. add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) if(NOT WASMEDGE_DEPS_VERSION) - set(WASMEDGE_DEPS_VERSION "0.11.0-rc.1") + set(WASMEDGE_DEPS_VERSION "0.11.1") endif() # Clone required shared libraries @@ -137,15 +139,6 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) message(STATUS "Fetching WasmEdge-tensorflow-dep repository - done") endif() - # -hardcode remove in future------------------------------- - # set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_HASH "43b2a782efb58b047c6d33f64d7ac711b24426959f91287d910edb8937c11dea") - # set(wasmedge_tensorflow_deps_SOURCE_DIR "${CMAKE_SOURCE_DIR}/utils/WasmEdge-tensorflow-deps") - - # set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_LIB - # "${wasmedge_tensorflow_deps_SOURCE_DIR}/shares/libtensorflowlite_c.so" - # ) - - # -------------------------------- set(WASMEDGE_TENSORFLOW_DEPS_PATH ${wasmedge_tensorflow_deps_SOURCE_DIR}) set(WASMEDGE_TENSORFLOW_DEPS_BIN_PATH ${CMAKE_CURRENT_BINARY_DIR}/utils/WasmEdge-tensorflow-deps) diff --git a/utils/wasi-nn/download-openvino-fixtures.sh b/utils/wasi-nn/download-openvino-fixtures.sh index 2f57f688..02a243da 100755 --- a/utils/wasi-nn/download-openvino-fixtures.sh +++ b/utils/wasi-nn/download-openvino-fixtures.sh @@ -2,15 +2,24 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2022 Second State INC -DOWNLOAD_TO=$1 +TODIR=$1 +if [[ $# -eq 0 ]]; then + TODIR=. +fi FIXTURE=https://github.com/intel/openvino-rs/raw/v0.3.3/crates/openvino/tests/fixtures/mobilenet/ +if [ ! -d $TODIR ]; then + mkdir $TODIR +fi +if [ ! -d $TODIR ]; then + mkdir $TODIR +fi -if [ ! -f $DOWNLOAD_TO/mobilenet.bin ]; then - wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/mobilenet.bin +if [ ! -f $TODIR/mobilenet.bin ]; then + curl -sL $FIXTURE/mobilenet.bin -o $TODIR/mobilenet.bin fi -if [ ! -f $DOWNLOAD_TO/mobilenet.xml ]; then - wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/mobilenet.xml +if [ ! -f $TODIR/mobilenet.xml ]; then + curl -sL $FIXTURE/mobilenet.xml -o $TODIR/mobilenet.xml fi -if [ ! -f $DOWNLOAD_TO/tensor-1x224x224x3-f32.bgr ]; then - wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/tensor-1x224x224x3-f32.bgr +if [ ! -f $TODIR/tensor-1x224x224x3-f32.bgr ]; then + curl -sL $FIXTURE/tensor-1x224x224x3-f32.bgr -o $TODIR/tensor-1x224x224x3-f32.bgr fi diff --git a/utils/wasi-nn/download-pytorch-fixtures.sh b/utils/wasi-nn/download-pytorch-fixtures.sh index 809cd22d..6a6aab91 100755 --- a/utils/wasi-nn/download-pytorch-fixtures.sh +++ b/utils/wasi-nn/download-pytorch-fixtures.sh @@ -2,12 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2022 Second State INC -DOWNLOAD_TO=$1 +TODIR=$1 +if [[ $# -eq 0 ]]; then + TODIR=. +fi FIXTURE=https://github.com/second-state/WasmEdge-WASINN-examples/raw/master/pytorch-mobilenet-image/ -if [ ! -f $DOWNLOAD_TO/mobilenet.pt ]; then - wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/mobilenet.pt +if [ ! -d $TODIR ]; then + mkdir $TODIR fi -if [ ! -f $DOWNLOAD_TO/image-1x3x224x224.rgb ]; then - wget -q --no-clobber --directory-prefix=$DOWNLOAD_TO $FIXTURE/image-1x3x224x224.rgb +if [ ! -f $TODIR/mobilenet.pt ]; then + curl -sL $FIXTURE/mobilenet.pt -o $TODIR/mobilenet.pt +fi +if [ ! -f $TODIR/image-1x3x224x224.rgb ]; then + curl -sL $FIXTURE/image-1x3x224x224.rgb -o $TODIR/image-1x3x224x224.rgb fi diff --git a/utils/wasi-nn/download-tflite-fixtures.sh b/utils/wasi-nn/download-tflite-fixtures.sh old mode 100644 new mode 100755 index e74e6511..959d7fee --- a/utils/wasi-nn/download-tflite-fixtures.sh +++ b/utils/wasi-nn/download-tflite-fixtures.sh @@ -1,13 +1,19 @@ -TODIR=$1 +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC +TODIR=$1 +if [[ $# -eq 0 ]]; then + TODIR=. +fi FIXTURE=https://raw.githubusercontent.com/gusye1234/WasmEdge-WASINN-examples/demo-tflite-image/tflite-birds_v1-image if [ ! -d $TODIR ]; then mkdir $TODIR fi if [ ! -f $TODIR/lite-model_aiy_vision_classifier_birds_V1_3.tflite ]; then - curl -o $TODIR/lite-model_aiy_vision_classifier_birds_V1_3.tflite $FIXTURE/lite-model_aiy_vision_classifier_birds_V1_3.tflite + curl -sL $FIXTURE/lite-model_aiy_vision_classifier_birds_V1_3.tflite -o $TODIR/lite-model_aiy_vision_classifier_birds_V1_3.tflite fi if [ ! -f $TODIR/birdx224x224x3.rgb ]; then - curl -o $TODIR/birdx224x224x3.rgb $FIXTURE/birdx224x224x3.rgb + curl -sL $FIXTURE/birdx224x224x3.rgb -o $TODIR/birdx224x224x3.rgb fi From e0667c0b526d7defa2f2a0bdae1afb31a6d86fdf Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 20 Oct 2022 17:40:52 +0800 Subject: [PATCH 081/623] [WASI] Fix the order of graph-encoding in WASI-NN. Signed-off-by: YiYing He --- plugins/wasi_nn/wasinnenv.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 8352470f..4ec819ad 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -34,9 +34,10 @@ enum class TensorType : uint32_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; enum class Backend : uint8_t { OpenVINO = 0, - PyTorch = 1, + ONNX = 1, Tensorflow = 2, - TensorflowLite = 3 + PyTorch = 3, + TensorflowLite = 4 }; class Graph { From 9f0dcb1765f0001cb4809065386fab7d3308d9fa Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 20 Oct 2022 21:52:59 +0800 Subject: [PATCH 082/623] [WASI] Fix the WASI-NN tensorflow-lite backend segmentation fault. Signed-off-by: YiYing He --- plugins/wasi_nn/wasinnenv.h | 4 ---- plugins/wasi_nn/wasinnfunc.cpp | 10 +++++----- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 4ec819ad..6c50467f 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -76,9 +76,6 @@ class Graph { if (TFLiteMod) { TfLiteModelDelete(TFLiteMod); } - if (TFLiteOps) { - TfLiteInterpreterOptionsDelete(TFLiteOps); - } #endif } @@ -95,7 +92,6 @@ class Graph { #endif #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE TfLiteModel *TFLiteMod = nullptr; - TfLiteInterpreterOptions *TFLiteOps = nullptr; #endif }; diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 96dc418d..f092f968 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -324,8 +324,6 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, Env.NNGraph.pop_back(); return static_cast(WASINN::ErrNo::InvalidArgument); } - Graph.TFLiteOps = TfLiteInterpreterOptionsCreate(); - TfLiteInterpreterOptionsSetNumThreads(Graph.TFLiteOps, 2); // Store the loaded graph. *GraphId = Env.NNGraph.size() - 1; @@ -398,8 +396,7 @@ Expect WasiNNInitExecCtx::body(const Runtime::CallingFrame &Frame, WASINN::Backend::TensorflowLite) { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE // Check the network and the execution network with the graph ID. - if (Env.NNGraph[GraphId].TFLiteMod == nullptr || - Env.NNGraph[GraphId].TFLiteOps == nullptr) { + if (Env.NNGraph[GraphId].TFLiteMod == nullptr) { spdlog::error("[WASI-NN] Model for Graph:{} is missing!", GraphId); return static_cast(WASINN::ErrNo::MissingMemory); } @@ -407,8 +404,11 @@ Expect WasiNNInitExecCtx::body(const Runtime::CallingFrame &Frame, Env.NNContext.emplace_back(Env.NNGraph[GraphId]); const auto Graph = Env.NNGraph[GraphId]; auto &NewContext = Env.NNContext.back(); + auto *TFLiteOps = TfLiteInterpreterOptionsCreate(); + TfLiteInterpreterOptionsSetNumThreads(TFLiteOps, 2); NewContext.TFLiteInterp = - TfLiteInterpreterCreate(Graph.TFLiteMod, Graph.TFLiteOps); + TfLiteInterpreterCreate(Graph.TFLiteMod, TFLiteOps); + TfLiteInterpreterOptionsDelete(TFLiteOps); if (unlikely(NewContext.TFLiteInterp == nullptr)) { spdlog::error("[WASI-NN] Cannot create TFLite interpreter."); Env.NNContext.pop_back(); From 1f3fe9338bcb77a9d7aef76cc17a703e7f8fee0a Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 28 Oct 2022 10:08:52 +0800 Subject: [PATCH 083/623] [CI] Add network libs into slim images Signed-off-by: dm4 --- utils/docker/Dockerfile.release | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/docker/Dockerfile.release b/utils/docker/Dockerfile.release index e2656cc9..4b4f51e2 100644 --- a/utils/docker/Dockerfile.release +++ b/utils/docker/Dockerfile.release @@ -1,6 +1,7 @@ FROM ubuntu:20.04 ARG VERSION +RUN apt-get update && apt-get install -y netbase ADD WasmEdge-$VERSION-Linux.tar.gz /tmp/ RUN cp -rf /tmp/WasmEdge-$VERSION-Linux/bin/* /usr/local/bin && \ cp -rf /tmp/WasmEdge-$VERSION-Linux/lib64/* /usr/local/lib && \ From 1413fdf760bb5ca7c723947090847ca3ab11f9df Mon Sep 17 00:00:00 2001 From: Harry Chiang Date: Tue, 29 Nov 2022 22:22:37 +0800 Subject: [PATCH 084/623] [MISC] Fix typo in cmake (#2129) Signed-off-by: Harry Chiang --- plugins/test/CMakeLists.txt | 2 +- plugins/wasi_crypto/CMakeLists.txt | 2 +- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasmedge_httpsreq/CMakeLists.txt | 2 +- plugins/wasmedge_process/CMakeLists.txt | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/plugins/test/CMakeLists.txt b/plugins/test/CMakeLists.txt index abc1ba64..6cf8cf5b 100644 --- a/plugins/test/CMakeLists.txt +++ b/plugins/test/CMakeLists.txt @@ -15,7 +15,7 @@ target_compile_options(wasmedgePluginTest -DWASMEDGE_PLUGIN ) -if(WASMEDGE_LINK_PUGLINS_STATIC) +if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmedgePluginTest PRIVATE wasmedgeCAPI diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt index 2452d74e..ebec934b 100644 --- a/plugins/wasi_crypto/CMakeLists.txt +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -76,7 +76,7 @@ target_link_libraries(wasmedgePluginWasiCrypto wasmedge_shared OpenSSL::Crypto ) -if(WASMEDGE_LINK_PUGLINS_STATIC) +if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmedgePluginWasiCrypto PRIVATE wasmedgeCAPI diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index dae0cb94..74517d89 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -27,7 +27,7 @@ target_include_directories(wasmedgePluginWasiNN ${CMAKE_CURRENT_SOURCE_DIR} ) -if(WASMEDGE_LINK_PUGLINS_STATIC) +if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmedgePluginWasiNN PRIVATE wasmedgeCAPI diff --git a/plugins/wasmedge_httpsreq/CMakeLists.txt b/plugins/wasmedge_httpsreq/CMakeLists.txt index d32cf3e6..755769df 100644 --- a/plugins/wasmedge_httpsreq/CMakeLists.txt +++ b/plugins/wasmedge_httpsreq/CMakeLists.txt @@ -36,7 +36,7 @@ target_link_libraries(wasmedgePluginHttpsReq OpenSSL::Crypto OpenSSL::SSL ) -if(WASMEDGE_LINK_PUGLINS_STATIC) +if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmedgePluginHttpsReq PRIVATE wasmedgeCAPI diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt index 8c3db48b..98693eba 100644 --- a/plugins/wasmedge_process/CMakeLists.txt +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -27,7 +27,7 @@ target_include_directories(wasmedgePluginWasmEdgeProcess ${CMAKE_CURRENT_SOURCE_DIR} ) -if(WASMEDGE_LINK_PUGLINS_STATIC) +if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmedgePluginWasmEdgeProcess PRIVATE wasmedgeCAPI From b8be9bdb8efa613090b899641228d6f0ba96557f Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 30 Nov 2022 14:07:43 +0800 Subject: [PATCH 085/623] [Plugin] Fix the wasi_nn::TensorType from u32 to u8. Signed-off-by: YiYing He --- plugins/wasi_nn/wasinnenv.h | 2 +- plugins/wasi_nn/wasinnfunc.cpp | 21 ++++++++++----------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 6c50467f..241f0d5f 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -30,7 +30,7 @@ enum class ErrNo : uint32_t { Busy = 3 // Device or resource busy. }; -enum class TensorType : uint32_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; +enum class TensorType : uint8_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; enum class Backend : uint8_t { OpenVINO = 0, diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index f092f968..ac65ac0c 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -666,7 +666,6 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, return static_cast(WASINN::ErrNo::InvalidArgument); } uint32_t TensorDataLen = Tensor[4]; - uint32_t TensorType = Tensor[2]; uint8_t *TensorDataBuf = MemInst->getPointer(Tensor[3], TensorDataLen); if (unlikely(TensorDataBuf == nullptr)) { @@ -674,32 +673,32 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, return static_cast(WASINN::ErrNo::InvalidArgument); } + WASINN::TensorType RType = static_cast(Tensor[2]); auto *HoldTensor = TfLiteInterpreterGetInputTensor(CxtRef.TFLiteInterp, Index); - TfLiteType LiteType = TfLiteTensorType(HoldTensor); - WASINN::TensorType NNType; - - switch (LiteType) { + WASINN::TensorType LiteType; + switch (TfLiteTensorType(HoldTensor)) { case TfLiteType::kTfLiteUInt8: - NNType = WASINN::TensorType::U8; + LiteType = WASINN::TensorType::U8; break; case TfLiteType::kTfLiteFloat16: - NNType = WASINN::TensorType::F16; + LiteType = WASINN::TensorType::F16; break; case TfLiteType::kTfLiteFloat32: - NNType = WASINN::TensorType::F32; + LiteType = WASINN::TensorType::F32; break; case TfLiteType::kTfLiteInt32: - NNType = WASINN::TensorType::I32; + LiteType = WASINN::TensorType::I32; break; default: spdlog::error("[WASI-NN] Unsupported TFLite type: {}", LiteType); return static_cast(WASINN::ErrNo::InvalidArgument); } - if (unlikely(TensorType != static_cast(NNType))) { + if (unlikely(LiteType != RType)) { spdlog::error("[WASI-NN] Expect tensor type {}, but got {}", - static_cast(NNType), TensorType); + static_cast(LiteType), + static_cast(RType)); return static_cast(WASINN::ErrNo::InvalidArgument); } TfLiteStatus Stat = From ee0d1d08bc41efe2bfd061de0c6987d3d8f912b2 Mon Sep 17 00:00:00 2001 From: Puelloc Date: Thu, 22 Dec 2022 17:25:30 +0800 Subject: [PATCH 086/623] [WASI-crypto] fix: `symmetric_state_squeeze` for arbitrary-long output (#2177) Fixes #2176 Signed-off-by: Puelloc --- plugins/wasi_crypto/symmetric/kdf.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/wasi_crypto/symmetric/kdf.cpp b/plugins/wasi_crypto/symmetric/kdf.cpp index 5b00c430..f9c09d6e 100644 --- a/plugins/wasi_crypto/symmetric/kdf.cpp +++ b/plugins/wasi_crypto/symmetric/kdf.cpp @@ -69,7 +69,6 @@ Hkdf::Expand::State::squeeze(Span Out) noexcept { __WASI_CRYPTO_ERRNO_INVALID_KEY); } - ensureOrReturn(KeyLen == getKeySize(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); return {}; } From c3b34de24be1f15cfaf6a81ee2182054e2b31d92 Mon Sep 17 00:00:00 2001 From: Xiongsheng Wang <12643705+RRobot-lm@users.noreply.github.com> Date: Tue, 27 Dec 2022 23:46:13 +0800 Subject: [PATCH 087/623] [Misc] clean up typos and enable general linter (#2165) * [Chore] clean up typos and enable general linter Signed-off-by: RRobot-lm --- plugins/wasi_crypto/asymmetric_common/keypair.h | 2 +- .../wasi_crypto/asymmetric_common/publickey.h | 2 +- .../{registed.h => registered.h} | 16 ++++++++-------- .../wasi_crypto/asymmetric_common/secretkey.h | 2 +- plugins/wasi_crypto/ctx.h | 6 +++--- plugins/wasi_crypto/kx/kx.h | 2 +- .../wasi_crypto/kx/{registed.h => registered.h} | 6 +++--- plugins/wasi_crypto/signatures/eddsa.cpp | 2 +- .../signatures/{registed.h => registered.h} | 17 +++++++++-------- plugins/wasi_crypto/signatures/signatures.h | 2 +- plugins/wasi_crypto/signatures/signstate.h | 2 +- .../wasi_crypto/signatures/verificationstate.h | 2 +- plugins/wasi_crypto/symmetric/key.h | 2 +- .../symmetric/{registed.h => registered.h} | 12 ++++++------ plugins/wasi_crypto/symmetric/state.h | 2 +- plugins/wasi_crypto/utils/hostfunction.h | 2 +- plugins/wasi_crypto/utils/secret_vec.h | 2 +- plugins/wasi_nn/wasinnfunc.cpp | 2 +- test/plugins/wasi_crypto/common.cpp | 6 +++--- test/plugins/wasi_crypto/signatures.cpp | 2 +- test/plugins/wasi_nn/wasi_nn.cpp | 2 +- 21 files changed, 47 insertions(+), 46 deletions(-) rename plugins/wasi_crypto/asymmetric_common/{registed.h => registered.h} (68%) rename plugins/wasi_crypto/kx/{registed.h => registered.h} (83%) rename plugins/wasi_crypto/signatures/{registed.h => registered.h} (69%) rename plugins/wasi_crypto/symmetric/{registed.h => registered.h} (72%) diff --git a/plugins/wasi_crypto/asymmetric_common/keypair.h b/plugins/wasi_crypto/asymmetric_common/keypair.h index 1cdd138f..65a32172 100644 --- a/plugins/wasi_crypto/asymmetric_common/keypair.h +++ b/plugins/wasi_crypto/asymmetric_common/keypair.h @@ -15,7 +15,7 @@ #pragma once #include "asymmetric_common/publickey.h" -#include "asymmetric_common/registed.h" +#include "asymmetric_common/registered.h" #include "asymmetric_common/secretkey.h" #include "common/options.h" #include "utils/error.h" diff --git a/plugins/wasi_crypto/asymmetric_common/publickey.h b/plugins/wasi_crypto/asymmetric_common/publickey.h index 36dd7bf3..1e4915cb 100644 --- a/plugins/wasi_crypto/asymmetric_common/publickey.h +++ b/plugins/wasi_crypto/asymmetric_common/publickey.h @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// #pragma once -#include "asymmetric_common/registed.h" +#include "asymmetric_common/registered.h" #include "utils/error.h" #include diff --git a/plugins/wasi_crypto/asymmetric_common/registed.h b/plugins/wasi_crypto/asymmetric_common/registered.h similarity index 68% rename from plugins/wasi_crypto/asymmetric_common/registed.h rename to plugins/wasi_crypto/asymmetric_common/registered.h index 560476de..989130ac 100644 --- a/plugins/wasi_crypto/asymmetric_common/registed.h +++ b/plugins/wasi_crypto/asymmetric_common/registered.h @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2022 Second State INC -//===-- wasmedge/plugins/wasi_crypto/asymmetric/registed.h - Registed -----===// +//===-- wasmedge/plugins/wasi_crypto/asymmetric/registered.h - Registered -===// // // Part of the WasmEdge Project. // @@ -14,8 +14,8 @@ #pragma once -#include "kx/registed.h" -#include "signatures/registed.h" +#include "kx/registered.h" +#include "signatures/registered.h" #include "utils/error.h" #include @@ -25,7 +25,7 @@ namespace Host { namespace WasiCrypto { namespace AsymmetricCommon { -template struct Registed { +template struct Registered { using PkVariant = std::variant; using SkVariant = std::variant; using KpVariant = std::variant; @@ -33,12 +33,12 @@ template struct Registed { }; template -struct Registed, Kx::Registed> { - using Alg = Registed; +struct Registered, Kx::Registered> { + using Alg = Registered; }; -/// Combine the signatures and kx algoritms. -using RegistedAlg = Registed::Alg; +/// Combine the signatures and kx algorithms. +using RegistedAlg = Registered::Alg; using Algorithm = RegistedAlg::Variant; diff --git a/plugins/wasi_crypto/asymmetric_common/secretkey.h b/plugins/wasi_crypto/asymmetric_common/secretkey.h index 16214c0b..25115a8b 100644 --- a/plugins/wasi_crypto/asymmetric_common/secretkey.h +++ b/plugins/wasi_crypto/asymmetric_common/secretkey.h @@ -14,7 +14,7 @@ #pragma once #include "asymmetric_common/publickey.h" -#include "asymmetric_common/registed.h" +#include "asymmetric_common/registered.h" #include "wasi_crypto/api.hpp" #include diff --git a/plugins/wasi_crypto/ctx.h b/plugins/wasi_crypto/ctx.h index 9a60c21e..f5eda8a7 100644 --- a/plugins/wasi_crypto/ctx.h +++ b/plugins/wasi_crypto/ctx.h @@ -19,13 +19,13 @@ #include "asymmetric_common/secretkey.h" #include "common/array_output.h" #include "common/options.h" -#include "kx/registed.h" -#include "signatures/registed.h" +#include "kx/registered.h" +#include "signatures/registered.h" #include "signatures/signatures.h" #include "signatures/signstate.h" #include "signatures/verificationstate.h" #include "symmetric/key.h" -#include "symmetric/registed.h" +#include "symmetric/registered.h" #include "symmetric/state.h" #include "symmetric/tag.h" #include "utils/error.h" diff --git a/plugins/wasi_crypto/kx/kx.h b/plugins/wasi_crypto/kx/kx.h index cf81bf62..eba5a4c3 100644 --- a/plugins/wasi_crypto/kx/kx.h +++ b/plugins/wasi_crypto/kx/kx.h @@ -14,7 +14,7 @@ #pragma once -#include "kx/registed.h" +#include "kx/registered.h" #include "utils/error.h" #include diff --git a/plugins/wasi_crypto/kx/registed.h b/plugins/wasi_crypto/kx/registered.h similarity index 83% rename from plugins/wasi_crypto/kx/registed.h rename to plugins/wasi_crypto/kx/registered.h index ac9e1713..655fc392 100644 --- a/plugins/wasi_crypto/kx/registed.h +++ b/plugins/wasi_crypto/kx/registered.h @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2022 Second State INC -//===-- wasmedge/plugins/wasi_crypto/kx/registed.h - Registed -------------===// +//===-- wasmedge/plugins/wasi_crypto/kx/registered.h - Registered ---------===// // // Part of the WasmEdge Project. // @@ -25,14 +25,14 @@ namespace Host { namespace WasiCrypto { namespace Kx { -template struct Registed { +template struct Registered { using PkVariant = std::variant; using SkVariant = std::variant; using KpVariant = std::variant; using Variant = std::variant; }; -using RegistedAlg = Registed; +using RegistedAlg = Registered; using Algorithm = RegistedAlg::Variant; diff --git a/plugins/wasi_crypto/signatures/eddsa.cpp b/plugins/wasi_crypto/signatures/eddsa.cpp index 73d3d6db..820204b8 100644 --- a/plugins/wasi_crypto/signatures/eddsa.cpp +++ b/plugins/wasi_crypto/signatures/eddsa.cpp @@ -232,7 +232,7 @@ Eddsa::VerificationState::update(Span Input) noexcept { WasiCryptoExpect Eddsa::VerificationState::verify(const Signature &Sig) noexcept { std::scoped_lock Lock{Ctx->Mutex}; - // The invokation to EVP_DigestVerifyFinal() internally finalizes a copy of + // The invocation to EVP_DigestVerifyFinal() internally finalizes a copy of // the digest context. This means that EVP_VerifyUpdate() and // EVP_VerifyFinal() can be called later to digest and verify the additional // data. diff --git a/plugins/wasi_crypto/signatures/registed.h b/plugins/wasi_crypto/signatures/registered.h similarity index 69% rename from plugins/wasi_crypto/signatures/registed.h rename to plugins/wasi_crypto/signatures/registered.h index 9b736ed0..5d436ea3 100644 --- a/plugins/wasi_crypto/signatures/registed.h +++ b/plugins/wasi_crypto/signatures/registered.h @@ -1,7 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2022 Second State INC -//===-- wasmedge/plugins/wasi_crypto/signatures/registed.h - Registed -----===// +//===-- wasmedge/plugins/wasi_crypto/signatures/registered.h - Registered +//-----===// // // Part of the WasmEdge Project. // @@ -25,7 +26,7 @@ namespace Host { namespace WasiCrypto { namespace Signatures { -template struct Registed { +template struct Registered { using PkVariant = std::variant; using SkVariant = std::variant; using KpVariant = std::variant; @@ -37,12 +38,12 @@ template struct Registed { }; using RegistedAlg = - Registed; + Registered; using Algorithm = RegistedAlg::Variant; diff --git a/plugins/wasi_crypto/signatures/signatures.h b/plugins/wasi_crypto/signatures/signatures.h index 44de8564..6e548c0f 100644 --- a/plugins/wasi_crypto/signatures/signatures.h +++ b/plugins/wasi_crypto/signatures/signatures.h @@ -14,7 +14,7 @@ #pragma once -#include "signatures/registed.h" +#include "signatures/registered.h" #include "utils/error.h" namespace WasmEdge { diff --git a/plugins/wasi_crypto/signatures/signstate.h b/plugins/wasi_crypto/signatures/signstate.h index be92cac0..02fcd7a5 100644 --- a/plugins/wasi_crypto/signatures/signstate.h +++ b/plugins/wasi_crypto/signatures/signstate.h @@ -14,7 +14,7 @@ #pragma once -#include "signatures/registed.h" +#include "signatures/registered.h" #include "signatures/signatures.h" #include "utils/error.h" diff --git a/plugins/wasi_crypto/signatures/verificationstate.h b/plugins/wasi_crypto/signatures/verificationstate.h index 565c1ffe..3e326daa 100644 --- a/plugins/wasi_crypto/signatures/verificationstate.h +++ b/plugins/wasi_crypto/signatures/verificationstate.h @@ -14,7 +14,7 @@ #pragma once -#include "signatures/registed.h" +#include "signatures/registered.h" #include "signatures/signatures.h" #include "utils/error.h" diff --git a/plugins/wasi_crypto/symmetric/key.h b/plugins/wasi_crypto/symmetric/key.h index b5f62150..bf2d8745 100644 --- a/plugins/wasi_crypto/symmetric/key.h +++ b/plugins/wasi_crypto/symmetric/key.h @@ -14,7 +14,7 @@ #pragma once -#include "symmetric/registed.h" +#include "symmetric/registered.h" namespace WasmEdge { namespace Host { diff --git a/plugins/wasi_crypto/symmetric/registed.h b/plugins/wasi_crypto/symmetric/registered.h similarity index 72% rename from plugins/wasi_crypto/symmetric/registed.h rename to plugins/wasi_crypto/symmetric/registered.h index 197478c4..0785fc6c 100644 --- a/plugins/wasi_crypto/symmetric/registed.h +++ b/plugins/wasi_crypto/symmetric/registered.h @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2022 Second State INC -//===-- wasmedge/plugins/wasi_crypto/symmetric/registed.h - Registed ------===// +//===-- wasmedge/plugins/wasi_crypto/symmetric/registered.h - Registered --===// // // Part of the WasmEdge Project. // @@ -27,17 +27,17 @@ namespace Host { namespace WasiCrypto { namespace Symmetric { -/// Registed algorithm -template struct Registed { +/// Registered algorithm +template struct Registered { using Key = std::variant; using State = std::variant; using Variant = std::variant; }; using RegistedAlg = - Registed; + Registered; using Algorithm = RegistedAlg::Variant; diff --git a/plugins/wasi_crypto/symmetric/state.h b/plugins/wasi_crypto/symmetric/state.h index 22c94e5c..d567cf2e 100644 --- a/plugins/wasi_crypto/symmetric/state.h +++ b/plugins/wasi_crypto/symmetric/state.h @@ -16,7 +16,7 @@ #pragma once #include "symmetric/key.h" -#include "symmetric/registed.h" +#include "symmetric/registered.h" #include "symmetric/tag.h" #include "utils/error.h" diff --git a/plugins/wasi_crypto/utils/hostfunction.h b/plugins/wasi_crypto/utils/hostfunction.h index dcbd0a5a..668e5c61 100644 --- a/plugins/wasi_crypto/utils/hostfunction.h +++ b/plugins/wasi_crypto/utils/hostfunction.h @@ -16,7 +16,7 @@ #pragma once #include "ctx.h" -#include "symmetric/registed.h" +#include "symmetric/registered.h" #include "utils/error.h" #include "runtime/callingframe.h" diff --git a/plugins/wasi_crypto/utils/secret_vec.h b/plugins/wasi_crypto/utils/secret_vec.h index b329046b..a66bed9f 100644 --- a/plugins/wasi_crypto/utils/secret_vec.h +++ b/plugins/wasi_crypto/utils/secret_vec.h @@ -29,7 +29,7 @@ namespace WasmEdge { namespace Host { namespace WasiCrypto { -/// A vector wrapper, but swipe the secret key info on destory. +/// A vector wrapper, but swipe the secret key info on destroy. class SecretVec { public: SecretVec(const SecretVec &) = default; diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index ac65ac0c..9087cfeb 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -963,7 +963,7 @@ Expect WasiNNCompute::body(const Runtime::CallingFrame &Frame, } TfLiteStatus Stat = TfLiteInterpreterInvoke(CxtRef.TFLiteInterp); if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { - spdlog::error("[WASI-NN] Invokation failed."); + spdlog::error("[WASI-NN] Invocation failed."); return static_cast(WASINN::ErrNo::Busy); } return static_cast(WASINN::ErrNo::Success); diff --git a/test/plugins/wasi_crypto/common.cpp b/test/plugins/wasi_crypto/common.cpp index 11e57d36..d50a527c 100644 --- a/test/plugins/wasi_crypto/common.cpp +++ b/test/plugins/wasi_crypto/common.cpp @@ -42,7 +42,7 @@ TEST_F(WasiCryptoTest, Options) { WASI_CRYPTO_EXPECT_TRUE( optionsSetU64(SymmetricOptionsHandle, "parallelism"sv, 0)); - // Unsupport options. + // Unsupported options. WASI_CRYPTO_EXPECT_FAILURE( optionsSet(SymmetricOptionsHandle, "foo"sv, "foo"_u8), __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); @@ -71,7 +71,7 @@ TEST_F(WasiCryptoTest, Options) { WASI_CRYPTO_EXPECT_SUCCESS(SigOptionsHandle, optionsOpen(__WASI_ALGORITHM_TYPE_SIGNATURES)); - // Unsupport options. + // Unsupported options. WASI_CRYPTO_EXPECT_FAILURE(optionsSet(SigOptionsHandle, "foo"sv, "foo"_u8), __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); @@ -99,7 +99,7 @@ TEST_F(WasiCryptoTest, Options) { // Open options. WASI_CRYPTO_EXPECT_SUCCESS(KxOptionsHandle, optionsOpen(__WASI_ALGORITHM_TYPE_KEY_EXCHANGE)); - // Unsupport options. + // Unsupported options. WASI_CRYPTO_EXPECT_FAILURE(optionsSet(KxOptionsHandle, "foo"sv, "foo"_u8), __WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); diff --git a/test/plugins/wasi_crypto/signatures.cpp b/test/plugins/wasi_crypto/signatures.cpp index 705038bc..8cd2f083 100644 --- a/test/plugins/wasi_crypto/signatures.cpp +++ b/test/plugins/wasi_crypto/signatures.cpp @@ -11,7 +11,7 @@ namespace WasiCrypto { using namespace std::literals; TEST_F(WasiCryptoTest, Signatures) { - // Use the generated data to sign and verfiy. + // Use the generated data to sign and verify. auto SigTest = [this](__wasi_algorithm_type_e_t AlgType, std::string_view Alg) { SCOPED_TRACE(Alg); diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index d6f74924..f5a481ca 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -184,7 +184,7 @@ TEST(WasiNNTest, OpenVINOBackend) { static_cast(ErrNo::InvalidArgument)); } - // Test: laod -- OpenVINO model xml ptr out of bounds. + // Test: load -- OpenVINO model xml ptr out of bounds. BuilderPtr = LoadEntryPtr; writeFatPointer(MemInst, OutBoundPtr, XmlRead.size(), BuilderPtr); writeFatPointer(MemInst, StorePtr + XmlRead.size(), WeightRead.size(), From a0c734032f4c32bdb5e9e329d7b05ab7faebf265 Mon Sep 17 00:00:00 2001 From: Puelloc Date: Tue, 10 Jan 2023 21:49:50 +0800 Subject: [PATCH 088/623] [WASI-crypto] fix: wasi-crypto keypair import read pem as pkcs8 (#2211) Fixes #2210 Signed-off-by: Puelloc --- plugins/wasi_crypto/signatures/rsa.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_crypto/signatures/rsa.cpp b/plugins/wasi_crypto/signatures/rsa.cpp index 7657f392..2ecd585c 100644 --- a/plugins/wasi_crypto/signatures/rsa.cpp +++ b/plugins/wasi_crypto/signatures/rsa.cpp @@ -193,7 +193,7 @@ template WasiCryptoExpect::KeyPair> Rsa::KeyPair::importPem( Span Encoded) noexcept { - return checkValid(EvpPkeyPtr{d2iPrivateKey(Encoded)}); + return checkValid(EvpPkeyPtr{pemReadPrivateKey(Encoded)}); } template From df62acdf7ca7aec02f04f9c546a8faa76a6d2419 Mon Sep 17 00:00:00 2001 From: Puelloc Date: Tue, 10 Jan 2023 23:01:56 +0800 Subject: [PATCH 089/623] [WASI-crypto] fix: wasi-crypto keypair_generate for rsa-pss (#2213) * fix: wasi-crypto keypair_generate for rsa-pss Fixes #2212 Signed-off-by: Puelloc --- plugins/wasi_crypto/signatures/rsa.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_crypto/signatures/rsa.cpp b/plugins/wasi_crypto/signatures/rsa.cpp index 2ecd585c..6e7fc8f0 100644 --- a/plugins/wasi_crypto/signatures/rsa.cpp +++ b/plugins/wasi_crypto/signatures/rsa.cpp @@ -228,9 +228,10 @@ template WasiCryptoExpect::KeyPair> Rsa::KeyPair::generate( OptionalRef) noexcept { - EvpPkeyCtxPtr Ctx{EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr)}; + const auto Id = + PadMode == RSA_PKCS1_PADDING ? EVP_PKEY_RSA : EVP_PKEY_RSA_PSS; + EvpPkeyCtxPtr Ctx{EVP_PKEY_CTX_new_id(Id, nullptr)}; EVP_PKEY_keygen_init(Ctx.get()); - EVP_PKEY_CTX_set_rsa_padding(Ctx.get(), PadMode); EVP_PKEY_CTX_set_rsa_keygen_bits(Ctx.get(), KeyBits); EVP_PKEY_CTX_set_signature_md(Ctx.get(), getShaCtx()); From 3ed8632a8e8fd3f36b94198fd23f44f92db78b7c Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 13 Feb 2023 17:05:51 +0800 Subject: [PATCH 090/623] [Test] Add the plug-in unit tests for C++ API. Signed-off-by: YiYing He --- test/plugins/CMakeLists.txt | 4 ++ test/plugins/unittest/CMakeLists.txt | 57 +++++++++++++++++ test/plugins/unittest/testplugin.cpp | 57 +++++++++++++++++ test/plugins/unittest/testplugin.h | 94 ++++++++++++++++++++++++++++ test/plugins/unittest/unittest.cpp | 89 ++++++++++++++++++++++++++ 5 files changed, 301 insertions(+) create mode 100644 test/plugins/unittest/CMakeLists.txt create mode 100644 test/plugins/unittest/testplugin.cpp create mode 100644 test/plugins/unittest/testplugin.h create mode 100644 test/plugins/unittest/unittest.cpp diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 8f5cebe0..420cb655 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -15,3 +15,7 @@ endif() if (WASMEDGE_PLUGIN_WASI_CRYPTO) add_subdirectory(wasi_crypto) endif() + +if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(unittest) +endif() diff --git a/test/plugins/unittest/CMakeLists.txt b/test/plugins/unittest/CMakeLists.txt new file mode 100644 index 00000000..35d0e5d0 --- /dev/null +++ b/test/plugins/unittest/CMakeLists.txt @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +# The test plugin module +wasmedge_add_library(wasmedgePluginTestModule + SHARED + testplugin.cpp +) + +target_compile_options(wasmedgePluginTestModule + PUBLIC + -DWASMEDGE_PLUGIN +) + +if(CMAKE_SYSTEM_NAME MATCHES "Darwin") + target_link_options(wasmedgePluginTestModule + PUBLIC + -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterC1EPKNS0_6Plugin16PluginDescriptorE + -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterD1Ev + ) +endif() + +target_include_directories(wasmedgePluginTestModule + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginTestModule + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginTestModule + PRIVATE + wasmedge_shared + ) +endif() + +# The test executable +wasmedge_add_executable(wasmedgePluginUnittests + unittest.cpp +) + +target_include_directories(wasmedgePluginUnittests + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries(wasmedgePluginUnittests + PRIVATE + ${GTEST_BOTH_LIBRARIES} + wasmedgePlugin +) + +add_test(wasmedgePluginUnittests wasmedgePluginUnittests) diff --git a/test/plugins/unittest/testplugin.cpp b/test/plugins/unittest/testplugin.cpp new file mode 100644 index 00000000..1ed3b670 --- /dev/null +++ b/test/plugins/unittest/testplugin.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "testplugin.h" +#include "po/helper.h" + +#include + +namespace WasmEdge { +namespace Host { + +using namespace std::literals::string_view_literals; + +PO::List + WasmEdgePluginTestEnv::CmdArgs(PO::Description("Test for args."sv), + PO::MetaVar("ARG"sv)); + +PO::Option + WasmEdgePluginTestEnv::CmdName(PO::Description("Test for input name."sv), + PO::DefaultValue(std::string(""))); + +namespace { + +void addOptions(const Plugin::Plugin::PluginDescriptor *, + PO::ArgumentParser &Parser) noexcept { + Parser.add_option("arg"sv, WasmEdgePluginTestEnv::CmdArgs) + .add_option("name"sv, WasmEdgePluginTestEnv::CmdName); +} + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgePluginTestModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_plugintest", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 10, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_plugintest", + .Description = "This is for the plugin tests in WasmEdge.", + .Create = create, + }, + }, + .AddOptions = addOptions, +}; + +} // namespace + +Plugin::PluginRegister WasmEdgePluginTestEnv::Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/unittest/testplugin.h b/test/plugins/unittest/testplugin.h new file mode 100644 index 00000000..d01b9de8 --- /dev/null +++ b/test/plugins/unittest/testplugin.h @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "po/argument_parser.h" +#include "po/list.h" +#include "po/option.h" + +#include +#include + +namespace WasmEdge { +namespace Host { + +class WasmEdgePluginTestEnv { +public: + WasmEdgePluginTestEnv() noexcept = default; + + static PO::List CmdArgs; + static PO::Option CmdName; + static Plugin::PluginRegister Register; +}; + +template +class WasmEdgePluginTestFunc : public Runtime::HostFunction { +public: + WasmEdgePluginTestFunc(WasmEdgePluginTestEnv &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WasmEdgePluginTestEnv &Env; +}; + +class WasmEdgePluginTestFuncAdd + : public WasmEdgePluginTestFunc { +public: + WasmEdgePluginTestFuncAdd(WasmEdgePluginTestEnv &HostEnv) + : WasmEdgePluginTestFunc(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t A, uint32_t B) { + return A + B; + } +}; + +class WasmEdgePluginTestFuncSub + : public WasmEdgePluginTestFunc { +public: + WasmEdgePluginTestFuncSub(WasmEdgePluginTestEnv &HostEnv) + : WasmEdgePluginTestFunc(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t A, uint32_t B) { + return A - B; + } +}; + +class WasmEdgePluginTestFuncArgLen + : public WasmEdgePluginTestFunc { +public: + WasmEdgePluginTestFuncArgLen(WasmEdgePluginTestEnv &HostEnv) + : WasmEdgePluginTestFunc(HostEnv) {} + Expect body(const Runtime::CallingFrame &) { + return static_cast(Env.CmdArgs.value().size()); + } +}; + +class WasmEdgePluginTestFuncNameSize + : public WasmEdgePluginTestFunc { +public: + WasmEdgePluginTestFuncNameSize(WasmEdgePluginTestEnv &HostEnv) + : WasmEdgePluginTestFunc(HostEnv) {} + Expect body(const Runtime::CallingFrame &) { + return static_cast(Env.CmdName.value().size()); + } +}; + +class WasmEdgePluginTestModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgePluginTestModule() + : Runtime::Instance::ModuleInstance("wasmedge_plugintest") { + addHostFunc("add", std::make_unique(Env)); + addHostFunc("sub", std::make_unique(Env)); + addHostFunc("arg_len", std::make_unique(Env)); + addHostFunc("name_size", + std::make_unique(Env)); + } + + WasmEdgePluginTestEnv &getEnv() { return Env; } + +private: + WasmEdgePluginTestEnv Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/unittest/unittest.cpp b/test/plugins/unittest/unittest.cpp new file mode 100644 index 00000000..1294ec00 --- /dev/null +++ b/test/plugins/unittest/unittest.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/defines.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "testplugin.h" + +#include +#include +#include +#include +#include +#include + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "./libwasmedgePluginTestModule" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_plugintest"sv)) { + WasmEdge::PO::ArgumentParser Parser; + Plugin->registerOptions(Parser); + Parser.set_raw_value("name"sv, std::string("test_name")); + Parser.set_raw_value>( + "arg"sv, std::vector({"arg0", "arg1", "arg2", "arg3"})); + if (const auto *Module = Plugin->findModule("wasmedge_plugintest"sv)) { + return Module->create().release(); + } + } + return nullptr; +} +} // namespace + +TEST(wasmedgePluginTests, CPP_Run) { + // Create the wasmedge_plugintest module instance. + auto *TestMod = + dynamic_cast(createModule()); + EXPECT_FALSE(TestMod == nullptr); + + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + std::array RetVal; + + // Get the function "arg_len". + auto *FuncInst1 = TestMod->findFuncExports("arg_len"); + EXPECT_NE(FuncInst1, nullptr); + EXPECT_TRUE(FuncInst1->isHostFunction()); + auto &HostFuncInst1 = + dynamic_cast( + FuncInst1->getHostFunc()); + + // Test: Run function successfully. + EXPECT_TRUE(HostFuncInst1.run(CallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 4); + + // Get the function "name_size". + auto *FuncInst2 = TestMod->findFuncExports("name_size"); + EXPECT_NE(FuncInst2, nullptr); + EXPECT_TRUE(FuncInst2->isHostFunction()); + auto &HostFuncInst2 = + dynamic_cast( + FuncInst2->getHostFunc()); + + // Test: Run function successfully. + EXPECT_TRUE(HostFuncInst2.run(CallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 9); + + delete TestMod; +} + +TEST(wasmedgePluginTests, CPP_Module) { + // Create the wasmedge_plugintest module instance. + auto *TestMod = + dynamic_cast(createModule()); + EXPECT_FALSE(TestMod == nullptr); + EXPECT_EQ(TestMod->getFuncExportNum(), 4U); + EXPECT_NE(TestMod->findFuncExports("add"), nullptr); + EXPECT_NE(TestMod->findFuncExports("sub"), nullptr); + EXPECT_NE(TestMod->findFuncExports("arg_len"), nullptr); + EXPECT_NE(TestMod->findFuncExports("name_size"), nullptr); + delete TestMod; +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From b97b51153722a6ab97935e3d0d143857a7a57acd Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 2 Feb 2023 13:22:23 +0800 Subject: [PATCH 091/623] [CI] Separate and stand alone the wasmedge process plugin. Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 27 ++++++++++++++++----------- test/plugins/CMakeLists.txt | 15 ++++++++++----- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 2e878ecc..c33c4128 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -1,23 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2022 Second State INC -# Only Linux systems support wasmedge_process now. -if (CMAKE_SYSTEM_NAME MATCHES "Linux") - add_subdirectory(wasmedge_process) +if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) + add_subdirectory(wasi_nn) endif() -if (WASMEDGE_PLUGIN_WASI_NN_BACKEND) - add_subdirectory(wasi_nn) +if(WASMEDGE_PLUGIN_WASI_CRYPTO) + add_subdirectory(wasi_crypto) endif() -if(WASMEDGE_PLUGIN_HTTPSREQ) +if(WASMEDGE_PLUGIN_PROCESS) + # Only Linux systems support wasmedge_process now. if(CMAKE_SYSTEM_NAME MATCHES "Linux") - add_subdirectory(wasmedge_httpsreq) + add_subdirectory(wasmedge_process) + else() + message(WARNING "Only Linux platforms support WasmEdge_Process plug-in now.") endif() endif() -if (WASMEDGE_PLUGIN_WASI_CRYPTO) - add_subdirectory(wasi_crypto) +if(WASMEDGE_PLUGIN_HTTPSREQ) + # Only Linux systems support wasmedge_httpsreq now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasmedge_httpsreq) + else() + message(WARNING "Only Linux platforms support WasmEdge_HttpsReq plug-in now.") + endif() endif() - -add_subdirectory(test) diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 420cb655..2ab4a70d 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -1,18 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2022 Second State INC -if (CMAKE_SYSTEM_NAME MATCHES "Linux") - add_subdirectory(wasmedge_process) +if(WASMEDGE_PLUGIN_PROCESS) + if (CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasmedge_process) + endif() endif() -if (WASMEDGE_PLUGIN_WASI_NN_BACKEND) + +if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) add_subdirectory(wasi_nn) endif() -if(WASMEDGE_PLUGIN_HTTPSREQ) + +if(WASMEDGE_PLUGIN_HTTPSREQ) if(CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasmedge_httpsreq) endif() endif() -if (WASMEDGE_PLUGIN_WASI_CRYPTO) + +if(WASMEDGE_PLUGIN_WASI_CRYPTO) add_subdirectory(wasi_crypto) endif() From 366489ac989252ba177e03757830e43831c67a99 Mon Sep 17 00:00:00 2001 From: Puelloc Date: Fri, 17 Feb 2023 02:26:00 +0800 Subject: [PATCH 092/623] [WASI-Crypto] use i2d_PKCS8PrivateKey_bio and fix test (#2283) Signed-off-by: Puellaquae Co-authored-by: hydai --- plugins/wasi_crypto/asymmetric_common/ecdsa.h | 18 +++++++++++++++++- test/plugins/wasi_crypto/asymmetric.cpp | 9 +++++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/plugins/wasi_crypto/asymmetric_common/ecdsa.h b/plugins/wasi_crypto/asymmetric_common/ecdsa.h index 9f4fc960..5e170cf0 100644 --- a/plugins/wasi_crypto/asymmetric_common/ecdsa.h +++ b/plugins/wasi_crypto/asymmetric_common/ecdsa.h @@ -239,7 +239,23 @@ class Ecdsa { } WasiCryptoExpect exportPkcs8() const noexcept { - return i2dPrivateKey(Ctx.get()); + EVP_PKEY *Key = Ctx.get(); + BioPtr Bio{BIO_new(BIO_s_mem())}; + opensslCheck(i2d_PKCS8PrivateKey_bio(Bio.get(), Key, nullptr, nullptr, 0, + nullptr, nullptr)); + + BUF_MEM *Mem = nullptr; + opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); + SecretVec Ret(Mem->length); + + if (size_t Size; BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size)) { + ensureOrReturn(Size == Ret.size(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } + + return Ret; } WasiCryptoExpect exportPem() const noexcept { diff --git a/test/plugins/wasi_crypto/asymmetric.cpp b/test/plugins/wasi_crypto/asymmetric.cpp index 9a3389d9..c218216a 100644 --- a/test/plugins/wasi_crypto/asymmetric.cpp +++ b/test/plugins/wasi_crypto/asymmetric.cpp @@ -110,10 +110,11 @@ TEST_F(WasiCryptoTest, Asymmetric) { {{__WASI_SECRETKEY_ENCODING_RAW, "b9aa5c28ef96d750e47f4ba44d5d6a7ac3ab6988d292e7819e362a4b0ac8e250"_u8v}, {__WASI_SECRETKEY_ENCODING_PKCS8, - "30740201010420b9aa5c28ef96d750e47f4ba44d5d6a7ac3ab6988d292e7819e362a4b" - "0ac8e250a00706052b8104000aa144034200047fef8e21686370c7d343992f14b2d45a" - "262cd6a5c75032736fcbb02f46a99edf0e1d114cdc93956cc75648bfd38fa832a82135" - "d5c2ba634766a8753f6d88aae5"_u8v}, + "308184020100301006072a8648ce3d020106052b8104000a046d306b02010104" + "207778b8225c02cc7f2ebcd0a47e2c4fcebd6716a329bdf2e4f961fa35041cba" + "97a1440342000434e2dea3923666bc28779bcd84fba5b4ee97bb8f6ec3cdc0d8" + "6609f6c8b8b9ca81592cdf4d3aeccdacb092e94e8f814265f46e3eefb49ad43c" + "3968e69d4faef4"_u8v}, {__WASI_SECRETKEY_ENCODING_PEM, "-----BEGIN PRIVATE KEY-----\n" "MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQguapcKO+W11Dkf0ukTV1q\n" From c27a38877608cfa86e23073a696ea2d666887998 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Fri, 17 Mar 2023 16:31:53 +0800 Subject: [PATCH 093/623] [Test] Add the C API plugin tests. Signed-off-by: YiYing He --- plugins/test/CMakeLists.txt | 28 ----- plugins/test/test.c | 75 ----------- test/plugins/unittest/CMakeLists.txt | 93 ++++++++++---- test/plugins/unittest/testplugin.c | 96 ++++++++++++++ test/plugins/unittest/testplugin.cpp | 4 +- test/plugins/unittest/testplugin.h | 2 +- test/plugins/unittest/unittest.cpp | 89 ------------- test/plugins/unittest/unittest_c.cpp | 167 +++++++++++++++++++++++++ test/plugins/unittest/unittest_cpp.cpp | 120 ++++++++++++++++++ 9 files changed, 457 insertions(+), 217 deletions(-) delete mode 100644 plugins/test/CMakeLists.txt delete mode 100644 plugins/test/test.c create mode 100644 test/plugins/unittest/testplugin.c delete mode 100644 test/plugins/unittest/unittest.cpp create mode 100644 test/plugins/unittest/unittest_c.cpp create mode 100644 test/plugins/unittest/unittest_cpp.cpp diff --git a/plugins/test/CMakeLists.txt b/plugins/test/CMakeLists.txt deleted file mode 100644 index 6cf8cf5b..00000000 --- a/plugins/test/CMakeLists.txt +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC - -wasmedge_add_library(wasmedgePluginTest - SHARED - test.c -) - -set_target_properties(wasmedgePluginTest PROPERTIES - C_STANDARD 11 -) - -target_compile_options(wasmedgePluginTest - PUBLIC - -DWASMEDGE_PLUGIN -) - -if(WASMEDGE_LINK_PLUGINS_STATIC) - target_link_libraries(wasmedgePluginTest - PRIVATE - wasmedgeCAPI - ) -else() - target_link_libraries(wasmedgePluginTest - PRIVATE - wasmedge_shared - ) -endif() diff --git a/plugins/test/test.c b/plugins/test/test.c deleted file mode 100644 index bf8e95d6..00000000 --- a/plugins/test/test.c +++ /dev/null @@ -1,75 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#include "wasmedge/wasmedge.h" -#include - -static int32_t TestingOption; -static const int32_t TestingOptionDefaultValue = 42; - -static WasmEdge_Result Test(void *Data __attribute__((unused)), - const WasmEdge_CallingFrameContext *CallFrameCxt - __attribute__((unused)), - const WasmEdge_Value *In __attribute__((unused)), - WasmEdge_Value *Out __attribute__((unused))) { - return WasmEdge_Result_Success; -} - -static WasmEdge_ModuleInstanceContext * -CreateTestModule(const struct WasmEdge_ModuleDescriptor *Desc - __attribute__((unused))) { - WasmEdge_ModuleInstanceContext *Mod; - - { - WasmEdge_String ModuleName = WasmEdge_StringCreateByCString("test"); - Mod = WasmEdge_ModuleInstanceCreate(ModuleName); - WasmEdge_StringDelete(ModuleName); - } - - { - WasmEdge_FunctionTypeContext *FType = - WasmEdge_FunctionTypeCreate(NULL, 0, NULL, 0); - WasmEdge_FunctionInstanceContext *Func = - WasmEdge_FunctionInstanceCreate(FType, Test, NULL, 0); - WasmEdge_FunctionTypeDelete(FType); - WasmEdge_String FName = WasmEdge_StringCreateByCString("test"); - WasmEdge_ModuleInstanceAddFunction(Mod, FName, Func); - WasmEdge_StringDelete(FName); - } - - return Mod; -} - -static WasmEdge_ProgramOption PODesc[] = {{ - .Name = "test", - .Description = "testing option", - .Type = WasmEdge_ProgramOptionType_Int32, - .Storage = &TestingOption, - .DefaultValue = &TestingOptionDefaultValue, -}}; -static WasmEdge_ModuleDescriptor ModuleDesc[] = {{ - .Name = "test", - .Description = "testing module", - .Create = CreateTestModule, -}}; -static WasmEdge_PluginDescriptor Desc[] = {{ - .Name = "test", - .Description = "testing plugin", - .APIVersion = WasmEdge_Plugin_CurrentAPIVersion, - .Version = - { - .Major = 0, - .Minor = 0, - .Patch = 0, - .Build = 0, - }, - .ModuleCount = 1, - .ProgramOptionCount = 1, - .ModuleDescriptions = ModuleDesc, - .ProgramOptions = PODesc, -}}; - -WASMEDGE_CAPI_PLUGIN_EXPORT const WasmEdge_PluginDescriptor * -WasmEdge_Plugin_GetDescriptor(void) { - return Desc; -} diff --git a/test/plugins/unittest/CMakeLists.txt b/test/plugins/unittest/CMakeLists.txt index 35d0e5d0..2de4833c 100644 --- a/test/plugins/unittest/CMakeLists.txt +++ b/test/plugins/unittest/CMakeLists.txt @@ -1,57 +1,106 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2022 Second State INC -# The test plugin module -wasmedge_add_library(wasmedgePluginTestModule +# The test plugin module in C API +wasmedge_add_library(wasmedgePluginTestModuleC + SHARED + testplugin.c +) + +set_target_properties(wasmedgePluginTestModuleC PROPERTIES + C_STANDARD 11 +) + +target_compile_options(wasmedgePluginTestModuleC + PUBLIC + -DWASMEDGE_PLUGIN +) + +# The test plugin module in C++ API +wasmedge_add_library(wasmedgePluginTestModuleCPP SHARED testplugin.cpp ) -target_compile_options(wasmedgePluginTestModule +target_compile_options(wasmedgePluginTestModuleCPP PUBLIC -DWASMEDGE_PLUGIN ) if(CMAKE_SYSTEM_NAME MATCHES "Darwin") - target_link_options(wasmedgePluginTestModule + target_link_options(wasmedgePluginTestModuleCPP PUBLIC -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterC1EPKNS0_6Plugin16PluginDescriptorE -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterD1Ev ) endif() -target_include_directories(wasmedgePluginTestModule +target_include_directories(wasmedgePluginTestModuleCPP PUBLIC $ ${CMAKE_CURRENT_SOURCE_DIR} ) -if(WASMEDGE_LINK_PLUGINS_STATIC) - target_link_libraries(wasmedgePluginTestModule - PRIVATE - wasmedgeCAPI - ) -else() - target_link_libraries(wasmedgePluginTestModule - PRIVATE - wasmedge_shared - ) -endif() +# The test executable for C API +wasmedge_add_executable(wasmedgePluginUnittestsC + unittest_c.cpp +) + +target_include_directories(wasmedgePluginUnittestsC + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries(wasmedgePluginUnittestsC + PRIVATE + ${GTEST_BOTH_LIBRARIES} + wasmedgePlugin +) -# The test executable -wasmedge_add_executable(wasmedgePluginUnittests - unittest.cpp +# The test executable for C++ API +wasmedge_add_executable(wasmedgePluginUnittestsCPP + unittest_cpp.cpp ) -target_include_directories(wasmedgePluginUnittests +target_include_directories(wasmedgePluginUnittestsCPP PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ) -target_link_libraries(wasmedgePluginUnittests +target_link_libraries(wasmedgePluginUnittestsCPP PRIVATE ${GTEST_BOTH_LIBRARIES} wasmedgePlugin ) -add_test(wasmedgePluginUnittests wasmedgePluginUnittests) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginTestModuleC + PRIVATE + wasmedgeCAPI + ) + target_link_libraries(wasmedgePluginTestModuleCPP + PRIVATE + wasmedgeCAPI + ) + target_link_libraries(wasmedgePluginUnittestsC + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginTestModuleC + PRIVATE + wasmedge_shared + ) + target_link_libraries(wasmedgePluginTestModuleCPP + PRIVATE + wasmedge_shared + ) + target_link_libraries(wasmedgePluginUnittestsC + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgePluginUnittestsC wasmedgePluginUnittestsC) +add_test(wasmedgePluginUnittestsCPP wasmedgePluginUnittestsCPP) diff --git a/test/plugins/unittest/testplugin.c b/test/plugins/unittest/testplugin.c new file mode 100644 index 00000000..7b58cfaf --- /dev/null +++ b/test/plugins/unittest/testplugin.c @@ -0,0 +1,96 @@ + +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "wasmedge/wasmedge.h" + +#include + +static WasmEdge_String NameString; +static const char NameCString[] = "name"; +static const WasmEdge_String NameStringDefaultValue = {.Buf = NameCString, + .Length = 4}; + +WasmEdge_Result HostFuncAdd(void *Data __attribute__((unused)), + const WasmEdge_CallingFrameContext *CallFrameCxt + __attribute__((unused)), + const WasmEdge_Value *In, WasmEdge_Value *Out) { + int32_t Val1 = WasmEdge_ValueGetI32(In[0]); + int32_t Val2 = WasmEdge_ValueGetI32(In[1]); + Out[0] = WasmEdge_ValueGenI32(Val1 + Val2); + return WasmEdge_Result_Success; +} + +WasmEdge_Result HostFuncSub(void *Data __attribute__((unused)), + const WasmEdge_CallingFrameContext *CallFrameCxt + __attribute__((unused)), + const WasmEdge_Value *In, WasmEdge_Value *Out) { + int32_t Val1 = WasmEdge_ValueGetI32(In[0]); + int32_t Val2 = WasmEdge_ValueGetI32(In[1]); + Out[0] = WasmEdge_ValueGenI32(Val1 - Val2); + return WasmEdge_Result_Success; +} + +WasmEdge_ModuleInstanceContext * +CreateTestModule(const struct WasmEdge_ModuleDescriptor *Desc) { + WasmEdge_String ModuleName = + WasmEdge_StringCreateByCString(Desc->Name); + WasmEdge_ModuleInstanceContext *Mod = + WasmEdge_ModuleInstanceCreate(ModuleName); + WasmEdge_StringDelete(ModuleName); + + WasmEdge_String FuncName; + WasmEdge_FunctionTypeContext *FType; + WasmEdge_FunctionInstanceContext *FuncCxt; + enum WasmEdge_ValType ParamTypes[2], ReturnTypes[1]; + ParamTypes[0] = WasmEdge_ValType_I32; + ParamTypes[1] = WasmEdge_ValType_I32; + ReturnTypes[0] = WasmEdge_ValType_I32; + + FType = WasmEdge_FunctionTypeCreate(ParamTypes, 2, ReturnTypes, 1); + FuncName = WasmEdge_StringCreateByCString("add"); + FuncCxt = WasmEdge_FunctionInstanceCreate(FType, HostFuncAdd, NULL, 0); + WasmEdge_ModuleInstanceAddFunction(Mod, FuncName, FuncCxt); + WasmEdge_StringDelete(FuncName); + FuncName = WasmEdge_StringCreateByCString("sub"); + FuncCxt = WasmEdge_FunctionInstanceCreate(FType, HostFuncSub, NULL, 0); + WasmEdge_ModuleInstanceAddFunction(Mod, FuncName, FuncCxt); + WasmEdge_StringDelete(FuncName); + WasmEdge_FunctionTypeDelete(FType); + + return Mod; +} + +static WasmEdge_ProgramOption PODesc[] = {{ + .Name = "name", + .Description = "test name string", + .Type = WasmEdge_ProgramOptionType_String, + .Storage = &NameString, + .DefaultValue = &NameStringDefaultValue, +}}; +static WasmEdge_ModuleDescriptor ModuleDesc[] = {{ + .Name = "wasmedge_plugintest_c_module", + .Description = "This is for the plugin tests in WasmEdge C API.", + .Create = CreateTestModule, +}}; +static WasmEdge_PluginDescriptor Desc = { + .Name = "wasmedge_plugintest_c", + .Description = "", + .APIVersion = WasmEdge_Plugin_CurrentAPIVersion, + .Version = + { + .Major = 0, + .Minor = 10, + .Patch = 0, + .Build = 0, + }, + .ModuleCount = 1, + .ProgramOptionCount = 1, + .ModuleDescriptions = ModuleDesc, + .ProgramOptions = PODesc, +}; + +WASMEDGE_CAPI_PLUGIN_EXPORT const WasmEdge_PluginDescriptor * +WasmEdge_Plugin_GetDescriptor(void) { + return &Desc; +} diff --git a/test/plugins/unittest/testplugin.cpp b/test/plugins/unittest/testplugin.cpp index 1ed3b670..21431b14 100644 --- a/test/plugins/unittest/testplugin.cpp +++ b/test/plugins/unittest/testplugin.cpp @@ -33,7 +33,7 @@ create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { } Plugin::Plugin::PluginDescriptor Descriptor{ - .Name = "wasmedge_plugintest", + .Name = "wasmedge_plugintest_cpp", .Description = "", .APIVersion = Plugin::Plugin::CurrentAPIVersion, .Version = {0, 10, 0, 0}, @@ -41,7 +41,7 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .ModuleDescriptions = (Plugin::PluginModule::ModuleDescriptor[]){ { - .Name = "wasmedge_plugintest", + .Name = "wasmedge_plugintest_cpp_module", .Description = "This is for the plugin tests in WasmEdge.", .Create = create, }, diff --git a/test/plugins/unittest/testplugin.h b/test/plugins/unittest/testplugin.h index d01b9de8..214ab049 100644 --- a/test/plugins/unittest/testplugin.h +++ b/test/plugins/unittest/testplugin.h @@ -76,7 +76,7 @@ class WasmEdgePluginTestFuncNameSize class WasmEdgePluginTestModule : public Runtime::Instance::ModuleInstance { public: WasmEdgePluginTestModule() - : Runtime::Instance::ModuleInstance("wasmedge_plugintest") { + : Runtime::Instance::ModuleInstance("wasmedge_plugintest_cpp_module") { addHostFunc("add", std::make_unique(Env)); addHostFunc("sub", std::make_unique(Env)); addHostFunc("arg_len", std::make_unique(Env)); diff --git a/test/plugins/unittest/unittest.cpp b/test/plugins/unittest/unittest.cpp deleted file mode 100644 index 1294ec00..00000000 --- a/test/plugins/unittest/unittest.cpp +++ /dev/null @@ -1,89 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#include "common/defines.h" -#include "runtime/callingframe.h" -#include "runtime/instance/module.h" -#include "testplugin.h" - -#include -#include -#include -#include -#include -#include - -namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { - using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "./libwasmedgePluginTestModule" WASMEDGE_LIB_EXTENSION)); - if (const auto *Plugin = - WasmEdge::Plugin::Plugin::find("wasmedge_plugintest"sv)) { - WasmEdge::PO::ArgumentParser Parser; - Plugin->registerOptions(Parser); - Parser.set_raw_value("name"sv, std::string("test_name")); - Parser.set_raw_value>( - "arg"sv, std::vector({"arg0", "arg1", "arg2", "arg3"})); - if (const auto *Module = Plugin->findModule("wasmedge_plugintest"sv)) { - return Module->create().release(); - } - } - return nullptr; -} -} // namespace - -TEST(wasmedgePluginTests, CPP_Run) { - // Create the wasmedge_plugintest module instance. - auto *TestMod = - dynamic_cast(createModule()); - EXPECT_FALSE(TestMod == nullptr); - - WasmEdge::Runtime::Instance::ModuleInstance Mod(""); - WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); - std::array RetVal; - - // Get the function "arg_len". - auto *FuncInst1 = TestMod->findFuncExports("arg_len"); - EXPECT_NE(FuncInst1, nullptr); - EXPECT_TRUE(FuncInst1->isHostFunction()); - auto &HostFuncInst1 = - dynamic_cast( - FuncInst1->getHostFunc()); - - // Test: Run function successfully. - EXPECT_TRUE(HostFuncInst1.run(CallFrame, {}, RetVal)); - EXPECT_EQ(RetVal[0].get(), 4); - - // Get the function "name_size". - auto *FuncInst2 = TestMod->findFuncExports("name_size"); - EXPECT_NE(FuncInst2, nullptr); - EXPECT_TRUE(FuncInst2->isHostFunction()); - auto &HostFuncInst2 = - dynamic_cast( - FuncInst2->getHostFunc()); - - // Test: Run function successfully. - EXPECT_TRUE(HostFuncInst2.run(CallFrame, {}, RetVal)); - EXPECT_EQ(RetVal[0].get(), 9); - - delete TestMod; -} - -TEST(wasmedgePluginTests, CPP_Module) { - // Create the wasmedge_plugintest module instance. - auto *TestMod = - dynamic_cast(createModule()); - EXPECT_FALSE(TestMod == nullptr); - EXPECT_EQ(TestMod->getFuncExportNum(), 4U); - EXPECT_NE(TestMod->findFuncExports("add"), nullptr); - EXPECT_NE(TestMod->findFuncExports("sub"), nullptr); - EXPECT_NE(TestMod->findFuncExports("arg_len"), nullptr); - EXPECT_NE(TestMod->findFuncExports("name_size"), nullptr); - delete TestMod; -} - -GTEST_API_ int main(int argc, char **argv) { - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/test/plugins/unittest/unittest_c.cpp b/test/plugins/unittest/unittest_c.cpp new file mode 100644 index 00000000..cc36965a --- /dev/null +++ b/test/plugins/unittest/unittest_c.cpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/defines.h" +#include "wasmedge/wasmedge.h" + +#include + +#include + +namespace { +WasmEdge_ModuleInstanceContext *createModuleC() { + WasmEdge_PluginLoadFromPath( + "./libwasmedgePluginTestModuleC" WASMEDGE_LIB_EXTENSION); + WasmEdge_String Str = WasmEdge_StringCreateByCString("wasmedge_plugintest_c"); + const WasmEdge_PluginContext *PluginCxt = WasmEdge_PluginFind(Str); + WasmEdge_StringDelete(Str); + if (!PluginCxt) { + return nullptr; + } + + Str = WasmEdge_StringCreateByCString("wasmedge_plugintest_c_module"); + WasmEdge_ModuleInstanceContext *ModCxt = + WasmEdge_PluginCreateModule(PluginCxt, Str); + WasmEdge_StringDelete(Str); + return ModCxt; +} + +WasmEdge_ModuleInstanceContext *createModuleCPP() { + WasmEdge_PluginLoadFromPath( + "./libwasmedgePluginTestModuleCPP" WASMEDGE_LIB_EXTENSION); + WasmEdge_String Str = + WasmEdge_StringCreateByCString("wasmedge_plugintest_cpp"); + const WasmEdge_PluginContext *PluginCxt = WasmEdge_PluginFind(Str); + WasmEdge_StringDelete(Str); + if (!PluginCxt) { + return nullptr; + } + + Str = WasmEdge_StringCreateByCString("wasmedge_plugintest_cpp_module"); + WasmEdge_ModuleInstanceContext *ModCxt = + WasmEdge_PluginCreateModule(PluginCxt, Str); + WasmEdge_StringDelete(Str); + return ModCxt; +} +} // namespace + +TEST(wasmedgePluginTests, C_Run) { + auto *ExecCxt = WasmEdge_ExecutorCreate(nullptr, nullptr); + auto *StoreCxt = WasmEdge_StoreCreate(); + WasmEdge_Result Res; + WasmEdge_FunctionInstanceContext *FuncCxt; + WasmEdge_String FuncName; + WasmEdge_Value Params[2], Returns[1]; + + // Create the wasmedge_plugintest_c_module module instance. + auto *ModInstC = createModuleC(); + EXPECT_FALSE(ModInstC == nullptr); + + // Create the wasmedge_plugintest_cpp_module module instance. + auto *ModInstCPP = createModuleCPP(); + EXPECT_FALSE(ModInstCPP == nullptr); + + // Test: Run the function "wasmedge_plugintest_c.add". + FuncName = WasmEdge_StringCreateByCString("add"); + FuncCxt = WasmEdge_ModuleInstanceFindFunction(ModInstC, FuncName); + WasmEdge_StringDelete(FuncName); + EXPECT_NE(FuncCxt, nullptr); + Params[0] = WasmEdge_ValueGenI32(111); + Params[1] = WasmEdge_ValueGenI32(333); + Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, Params, 2, Returns, 1); + EXPECT_TRUE(WasmEdge_ResultOK(Res)); + EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 444); + + // Test: Run the function "wasmedge_plugintest_c.sub". + FuncName = WasmEdge_StringCreateByCString("sub"); + FuncCxt = WasmEdge_ModuleInstanceFindFunction(ModInstC, FuncName); + WasmEdge_StringDelete(FuncName); + EXPECT_NE(FuncCxt, nullptr); + Params[0] = WasmEdge_ValueGenI32(666); + Params[1] = WasmEdge_ValueGenI32(555); + Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, Params, 2, Returns, 1); + EXPECT_TRUE(WasmEdge_ResultOK(Res)); + EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 111); + + // Test: Run the function "wasmedge_plugintest_cpp.add". + FuncName = WasmEdge_StringCreateByCString("add"); + FuncCxt = WasmEdge_ModuleInstanceFindFunction(ModInstCPP, FuncName); + WasmEdge_StringDelete(FuncName); + EXPECT_NE(FuncCxt, nullptr); + Params[0] = WasmEdge_ValueGenI32(111); + Params[1] = WasmEdge_ValueGenI32(333); + Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, Params, 2, Returns, 1); + EXPECT_TRUE(WasmEdge_ResultOK(Res)); + EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 444); + + // Test: Run the function "wasmedge_plugintest_cpp.sub". + FuncName = WasmEdge_StringCreateByCString("sub"); + FuncCxt = WasmEdge_ModuleInstanceFindFunction(ModInstCPP, FuncName); + WasmEdge_StringDelete(FuncName); + EXPECT_NE(FuncCxt, nullptr); + Params[0] = WasmEdge_ValueGenI32(666); + Params[1] = WasmEdge_ValueGenI32(555); + Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, Params, 2, Returns, 1); + EXPECT_TRUE(WasmEdge_ResultOK(Res)); + EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 111); + + // Test: Run the function "wasmedge_plugintest_cpp.arg_len". + FuncName = WasmEdge_StringCreateByCString("arg_len"); + FuncCxt = WasmEdge_ModuleInstanceFindFunction(ModInstCPP, FuncName); + WasmEdge_StringDelete(FuncName); + EXPECT_NE(FuncCxt, nullptr); + Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, nullptr, 0, Returns, 1); + EXPECT_TRUE(WasmEdge_ResultOK(Res)); + EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 0); + + // Test: Run the function "wasmedge_plugintest_cpp.name_size". + FuncName = WasmEdge_StringCreateByCString("name_size"); + FuncCxt = WasmEdge_ModuleInstanceFindFunction(ModInstCPP, FuncName); + WasmEdge_StringDelete(FuncName); + EXPECT_NE(FuncCxt, nullptr); + Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, nullptr, 0, Returns, 1); + EXPECT_TRUE(WasmEdge_ResultOK(Res)); + EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 0); + + WasmEdge_ExecutorDelete(ExecCxt); + WasmEdge_StoreDelete(StoreCxt); + WasmEdge_ModuleInstanceDelete(ModInstC); + WasmEdge_ModuleInstanceDelete(ModInstCPP); +} + +TEST(wasmedgePluginTests, C_Module) { + WasmEdge_String NameBuf[16]; + + // Create the wasmedge_plugintest_c_module module instance. + auto *ModInstC = createModuleC(); + ASSERT_FALSE(ModInstC == nullptr); + EXPECT_EQ(WasmEdge_ModuleInstanceListFunctionLength(ModInstC), 2U); + std::memset(NameBuf, 0, sizeof(WasmEdge_String) * 16); + EXPECT_EQ(WasmEdge_ModuleInstanceListFunction(ModInstC, NameBuf, 16), 2U); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[0], WasmEdge_StringWrap("add", 3U))); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[1], WasmEdge_StringWrap("sub", 3U))); + WasmEdge_ModuleInstanceDelete(ModInstC); + + // Create the wasmedge_plugintest_cpp_module module instance. + auto *ModInstCPP = createModuleCPP(); + ASSERT_FALSE(ModInstCPP == nullptr); + EXPECT_EQ(WasmEdge_ModuleInstanceListFunctionLength(ModInstCPP), 4U); + std::memset(NameBuf, 0, sizeof(WasmEdge_String) * 16); + EXPECT_EQ(WasmEdge_ModuleInstanceListFunction(ModInstCPP, NameBuf, 16), 4U); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[0], WasmEdge_StringWrap("add", 3U))); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[1], WasmEdge_StringWrap("arg_len", 7U))); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[2], WasmEdge_StringWrap("name_size", 9U))); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[3], WasmEdge_StringWrap("sub", 3U))); + WasmEdge_ModuleInstanceDelete(ModInstCPP); +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/unittest/unittest_cpp.cpp b/test/plugins/unittest/unittest_cpp.cpp new file mode 100644 index 00000000..00efc45e --- /dev/null +++ b/test/plugins/unittest/unittest_cpp.cpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/defines.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" + +#include "testplugin.h" + +#include +#include +#include +#include +#include +#include + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModuleC() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "./libwasmedgePluginTestModuleC" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_plugintest_c"sv)) { + if (const auto *Module = + Plugin->findModule("wasmedge_plugintest_c_module"sv)) { + return Module->create().release(); + } + } + return nullptr; +} + +WasmEdge::Runtime::Instance::ModuleInstance *createModuleCPP() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "./libwasmedgePluginTestModuleCPP" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_plugintest_cpp"sv)) { + WasmEdge::PO::ArgumentParser Parser; + Plugin->registerOptions(Parser); + Parser.set_raw_value("name"sv, std::string("test_name")); + Parser.set_raw_value>( + "arg"sv, std::vector({"arg0", "arg1", "arg2", "arg3"})); + if (const auto *Module = + Plugin->findModule("wasmedge_plugintest_cpp_module"sv)) { + return Module->create().release(); + } + } + return nullptr; +} +} // namespace + +TEST(wasmedgePluginTests, CPP_Run) { + // Create the wasmedge_plugintest_cpp_module module instance. + auto *TestModCPP = dynamic_cast( + createModuleCPP()); + ASSERT_FALSE(TestModCPP == nullptr); + + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + std::array RetVal; + + // Get the function "arg_len". + auto *FuncInst1 = TestModCPP->findFuncExports("arg_len"); + EXPECT_NE(FuncInst1, nullptr); + EXPECT_TRUE(FuncInst1->isHostFunction()); + auto &HostFuncInst1 = + dynamic_cast( + FuncInst1->getHostFunc()); + + // Test: Run function successfully. + EXPECT_TRUE(HostFuncInst1.run(CallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 4); + + // Get the function "name_size". + auto *FuncInst2 = TestModCPP->findFuncExports("name_size"); + EXPECT_NE(FuncInst2, nullptr); + EXPECT_TRUE(FuncInst2->isHostFunction()); + auto &HostFuncInst2 = + dynamic_cast( + FuncInst2->getHostFunc()); + + // Test: Run function successfully. + EXPECT_TRUE(HostFuncInst2.run(CallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 9); + + delete TestModCPP; + + // Create the wasmedge_plugintest_c_module module instance. + auto *TestModC = createModuleC(); + ASSERT_FALSE(TestModC == nullptr); + // The host functions are implemented in the C API. + // Therefore not test to invoke them here. + delete TestModC; +} + +TEST(wasmedgePluginTests, CPP_Module) { + // Create the wasmedge_plugintest_cpp_module module instance. + auto *TestModCPP = dynamic_cast( + createModuleCPP()); + ASSERT_FALSE(TestModCPP == nullptr); + EXPECT_EQ(TestModCPP->getFuncExportNum(), 4U); + EXPECT_NE(TestModCPP->findFuncExports("add"), nullptr); + EXPECT_NE(TestModCPP->findFuncExports("sub"), nullptr); + EXPECT_NE(TestModCPP->findFuncExports("arg_len"), nullptr); + EXPECT_NE(TestModCPP->findFuncExports("name_size"), nullptr); + delete TestModCPP; + + // Create the wasmedge_plugintest_c_module module instance. + auto *TestModC = createModuleC(); + ASSERT_FALSE(TestModC == nullptr); + EXPECT_EQ(TestModC->getFuncExportNum(), 2U); + EXPECT_NE(TestModC->findFuncExports("add"), nullptr); + EXPECT_NE(TestModC->findFuncExports("sub"), nullptr); + delete TestModC; +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 6bc14a78ec10065366846472eed28af03667b31b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=BE=AE?= <1067852565@qq.com> Date: Wed, 5 Apr 2023 08:52:57 +0800 Subject: [PATCH 094/623] [Plugin] implement the host function for eBPF programs (#2314) * update: wasm-bpf plugin Signed-off-by: officeyutong * tidy: remove useless comments Signed-off-by: officeyutong * update: cache support for bpf_map* at map operation Signed-off-by: officeyutong * update: don't bundle libbpf anymore Signed-off-by: officeyutong * update: formatted files with clang-format Signed-off-by: officeyutong * update: write more comments Signed-off-by: officeyutong * fix: rename wasm-bpf to wasm_bpf Signed-off-by: yunwei37 <1067852565@qq.com> Signed-off-by: officeyutong * fix: use ExternalProject_Add for building libbpf Signed-off-by: yunwei37 <1067852565@qq.com> Signed-off-by: officeyutong * fix: remove unnecessary includes and fix LICENSE Signed-off-by: yunwei37 <1067852565@qq.com> Signed-off-by: officeyutong * fix: refactor wasm-bpf functions Signed-off-by: yunwei37 <1067852565@qq.com> Signed-off-by: officeyutong * fix: fix compile error for bpf functions Signed-off-by: yunwei37 <1067852565@qq.com> Signed-off-by: officeyutong * fix: use pragma once instead of include Signed-off-by: yunwei37 <1067852565@qq.com> Signed-off-by: officeyutong * fix: use namespave host and WasmEdge Signed-off-by: yunwei37 <1067852565@qq.com> Signed-off-by: officeyutong * fix: a series of small issues Signed-off-by: yunwei37 <1067852565@qq.com> Signed-off-by: officeyutong * fix: convert int to int32_t Signed-off-by: yunwei37 <1067852565@qq.com> Signed-off-by: officeyutong * fix: move getMemoryByIndex to the front of func Signed-off-by: yunwei37 <1067852565@qq.com> Signed-off-by: officeyutong * use `reinterpret_cast` when casting pointers to integers Signed-off-by: officeyutong * Split `wasm-bpf-module.cpp` into header and source Signed-off-by: officeyutong * Link shared object of `libbpf` Signed-off-by: officeyutong * Add module function test for `wasm_bpf` Signed-off-by: officeyutong * Format files just written Signed-off-by: officeyutong * get the missing `wasm_bpf_program::map_ptr_by_fd` back Signed-off-by: officeyutong * fix `wasm_bpf_program::attach_bpf_program` Signed-off-by: officeyutong * add an check to poll Signed-off-by: officeyutong * fix wasm-bpf.cpp bug Signed-off-by: officeyutong * fix another wasm-bpf.cpp bug Signed-off-by: officeyutong * [test] add a test that run polling ebpf program Signed-off-by: officeyutong * [test] add `runqlat` test Signed-off-by: officeyutong * fix typo Signed-off-by: officeyutong * reduce size of WasmEdge_String names[] to 1 Signed-off-by: officeyutong * use std::size instead of sizeof(a)/sizeof(a[0]) Signed-off-by: officeyutong * format the code Signed-off-by: officeyutong * use static_cast instead of (xxx) Signed-off-by: officeyutong * use static_cast Signed-off-by: officeyutong * Update ways to get libbpf Signed-off-by: officeyutong * Use pkg-conf to get libelf and libz Signed-off-by: officeyutong * remove blank lines between comment and code Signed-off-by: officeyutong * use `\brief` instead of `@brief` Signed-off-by: officeyutong * expand a macro that is used only once Signed-off-by: officeyutong * remove some checks to the front of the function Signed-off-by: officeyutong * fix an invalid macro.. Signed-off-by: officeyutong * add a line of debug log when invalid fd is provided Signed-off-by: officeyutong * fix a format error Signed-off-by: officeyutong * Add a simple test to test if ringbuf works Signed-off-by: officeyutong * add a test about ebpf maps Signed-off-by: officeyutong * fix spell Signed-off-by: officeyutong * plugin: fix related comments and docs Signed-off-by: yunwei37 <1067852565@qq.com> Signed-off-by: officeyutong * add prefix for logging Signed-off-by: officeyutong * remove `struct` in variable decl Signed-off-by: officeyutong * use static_cast Signed-off-by: officeyutong * formatted the code Signed-off-by: officeyutong * add logs when provided invalid map operation Signed-off-by: officeyutong * fix typo Signed-off-by: officeyutong * remove binaries from commit tree Signed-off-by: officeyutong * fix license issue in `runqlat.bpf.c` Signed-off-by: officeyutong * Download & build bpf objs Signed-off-by: officeyutong --------- Signed-off-by: officeyutong Signed-off-by: yunwei37 <1067852565@qq.com> Co-authored-by: officeyutong --- plugins/CMakeLists.txt | 9 + plugins/wasm_bpf/CMakeLists.txt | 152 +++++ plugins/wasm_bpf/README.md | 109 ++++ plugins/wasm_bpf/bpf-api.h | 100 +++ plugins/wasm_bpf/func-attach-bpf-program.cpp | 31 + plugins/wasm_bpf/func-attach-bpf-program.h | 30 + plugins/wasm_bpf/func-bpf-buffer-poll.cpp | 49 ++ plugins/wasm_bpf/func-bpf-buffer-poll.h | 41 ++ plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp | 28 + plugins/wasm_bpf/func-bpf-map-fd-by-name.h | 30 + plugins/wasm_bpf/func-bpf-map-operate.cpp | 68 ++ plugins/wasm_bpf/func-bpf-map-operate.h | 30 + plugins/wasm_bpf/func-close-bpf-object.cpp | 21 + plugins/wasm_bpf/func-close-bpf-object.h | 29 + plugins/wasm_bpf/func-load-bpf-object.cpp | 32 + plugins/wasm_bpf/func-load-bpf-object.h | 34 + plugins/wasm_bpf/state.h | 25 + plugins/wasm_bpf/util.cpp | 25 + plugins/wasm_bpf/util.h | 29 + plugins/wasm_bpf/wasm-bpf-module.cpp | 59 ++ plugins/wasm_bpf/wasm-bpf-module.h | 16 + plugins/wasm_bpf/wasm-bpf.cpp | 258 ++++++++ test/plugins/CMakeLists.txt | 6 + test/plugins/wasm_bpf/CMakeLists.txt | 24 + test/plugins/wasm_bpf/assets/.gitignore | 4 + test/plugins/wasm_bpf/assets/CMakeLists.txt | 41 ++ test/plugins/wasm_bpf/assets/README.md | 11 + .../wasm_bpf/assets/bpf-sources/.gitignore | 4 + .../wasm_bpf/assets/bpf-sources/Makefile | 20 + .../assets/bpf-sources/simple_map.bpf.c | 48 ++ .../assets/bpf-sources/simple_ringbuf.bpf.c | 38 ++ test/plugins/wasm_bpf/simple_map_test.cpp | 291 +++++++++ test/plugins/wasm_bpf/simple_ringbuf_test.cpp | 245 ++++++++ test/plugins/wasm_bpf/wasm_bpf.cpp | 580 ++++++++++++++++++ 34 files changed, 2517 insertions(+) create mode 100644 plugins/wasm_bpf/CMakeLists.txt create mode 100644 plugins/wasm_bpf/README.md create mode 100644 plugins/wasm_bpf/bpf-api.h create mode 100644 plugins/wasm_bpf/func-attach-bpf-program.cpp create mode 100644 plugins/wasm_bpf/func-attach-bpf-program.h create mode 100644 plugins/wasm_bpf/func-bpf-buffer-poll.cpp create mode 100644 plugins/wasm_bpf/func-bpf-buffer-poll.h create mode 100644 plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp create mode 100644 plugins/wasm_bpf/func-bpf-map-fd-by-name.h create mode 100644 plugins/wasm_bpf/func-bpf-map-operate.cpp create mode 100644 plugins/wasm_bpf/func-bpf-map-operate.h create mode 100644 plugins/wasm_bpf/func-close-bpf-object.cpp create mode 100644 plugins/wasm_bpf/func-close-bpf-object.h create mode 100644 plugins/wasm_bpf/func-load-bpf-object.cpp create mode 100644 plugins/wasm_bpf/func-load-bpf-object.h create mode 100644 plugins/wasm_bpf/state.h create mode 100644 plugins/wasm_bpf/util.cpp create mode 100644 plugins/wasm_bpf/util.h create mode 100644 plugins/wasm_bpf/wasm-bpf-module.cpp create mode 100644 plugins/wasm_bpf/wasm-bpf-module.h create mode 100644 plugins/wasm_bpf/wasm-bpf.cpp create mode 100644 test/plugins/wasm_bpf/CMakeLists.txt create mode 100644 test/plugins/wasm_bpf/assets/.gitignore create mode 100644 test/plugins/wasm_bpf/assets/CMakeLists.txt create mode 100644 test/plugins/wasm_bpf/assets/README.md create mode 100644 test/plugins/wasm_bpf/assets/bpf-sources/.gitignore create mode 100644 test/plugins/wasm_bpf/assets/bpf-sources/Makefile create mode 100644 test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c create mode 100644 test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c create mode 100644 test/plugins/wasm_bpf/simple_map_test.cpp create mode 100644 test/plugins/wasm_bpf/simple_ringbuf_test.cpp create mode 100644 test/plugins/wasm_bpf/wasm_bpf.cpp diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index c33c4128..83c6d793 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -26,3 +26,12 @@ if(WASMEDGE_PLUGIN_HTTPSREQ) message(WARNING "Only Linux platforms support WasmEdge_HttpsReq plug-in now.") endif() endif() + +if(WASMEDGE_PLUGIN_WASM_BPF) + # Only Linux systems support wasm_bpf now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasm_bpf) + else() + message(WARNING "Only Linux platforms support wasm_bpf plug-in now.") + endif() +endif() diff --git a/plugins/wasm_bpf/CMakeLists.txt b/plugins/wasm_bpf/CMakeLists.txt new file mode 100644 index 00000000..e64aa599 --- /dev/null +++ b/plugins/wasm_bpf/CMakeLists.txt @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +# Try to get libbpf use the following order +# - PkgConfig +# - ${LIBBPF_SOURCE_DIR} +# - FetchContent + +message(STATUS "Trying to get libbpf..") +set(LIBBPF_FOUND FALSE) + +# A wrapper function to add libbpf located at a local path as a dependency +function(AddLibbpfAsExternal SOURCE_ROOT) + include(ExternalProject) + ExternalProject_Add(libbpf + PREFIX libbpf + SOURCE_DIR ${SOURCE_ROOT} + CONFIGURE_COMMAND "" + BUILD_COMMAND "make" "-C" "${SOURCE_ROOT}/src" + INSTALL_COMMAND "" + BUILD_IN_SOURCE TRUE + ) + set(LIBBPF_SO_PATH ${SOURCE_ROOT}/src/libbpf.so) + set(LIBBPF_INCLUDE_DIRS ${SOURCE_ROOT}/src PARENT_SCOPE) + set(LIBBPF_LIBRARIES ${LIBBPF_SO_PATH} PARENT_SCOPE) + set(LIBBPF_TARGET_NAME libbpf PARENT_SCOPE) + file(COPY_FILE ${LIBBPF_SO_PATH} ${CMAKE_CURRENT_BINARY_DIR}/libbpf.so) + + # Copy libbpf.so to the place where libwasmedgePluginWasmBpf.so exists + message(STATUS "Copied libbpf.so from ${LIBBPF_SO_PATH} to ${CMAKE_CURRENT_BINARY_DIR}/libbpf.so") +endfunction() + +# Try PkgConfig +if(NOT ${LIBBPF_FOUND}) + find_package(PkgConfig) + + if(PkgConfig_FOUND) + message(STATUS "Try to get libbpf through PkgConfig") + + # It will set LIBBPF_FOUND for us + pkg_check_modules(LIBBPF libbpf>=1.2 IMPORTED_TARGET) + set(LIBBPF_TARGET_NAME "PkgConfig::LIBBPF") + + if(${LIBBPF_FOUND}) + message(STATUS "libbpf found using PkgConfig") + else() + message(STATUS "libbpf not found using pkgconfig") + endif() + else() + message(STATUS "PkgConfig not found") + endif() +endif() + +# Try LIBBPF_SOURCE_DIR +if(NOT ${LIBBPF_FOUND}) + message(STATUS "Try to get libbpf through the pre-defined LIBBPF_SOURCE_DIR") + + if(DEFINED LIBBPF_SOURCE_DIR) + AddLibbpfAsExternal(${LIBBPF_SOURCE_DIR}) + set(LIBBPF_FOUND TRUE) + message(STATUS "libbpf found using LIBBPF_SOURCE_DIR") + else() + message(STATUS "LIBBPF_SOURCE_DIR not defined") + endif() +endif() + +# Try FetchContent +if(NOT ${LIBBPF_FOUND}) + message(STATUS "Try to get libbpf through FetchContent") + include(FetchContent) + FetchContent_Declare( + libbpf + GIT_REPOSITORY https://github.com/libbpf/libbpf + GIT_TAG cf46d44f0a06aa8b9400691ea3eb86ca4f066d5c + ) + FetchContent_GetProperties(libbpf) + + if(NOT libbpf_POPULATED) + message(STATUS "Fetching libbpf..") + FetchContent_Populate(libbpf) + message(STATUS "Fetched libbpf") + endif() + + set(LIBBPF_DOWNLOAD_SOURCE_DIR "${libbpf_SOURCE_DIR}") + message(DEBUG "libbpf saved at: ${LIBBPF_DOWNLOAD_SOURCE_DIR}") + AddLibbpfAsExternal(${LIBBPF_DOWNLOAD_SOURCE_DIR}) + set(LIBBPF_FOUND TRUE) +endif() + +# If we cannot find libbpf.. +if(NOT ${LIBBPF_FOUND}) + message(FATAL_ERROR "Could not find libbpf") +endif() + +message(DEBUG "LIBBPF_INCLUDE_DIRS=${LIBBPF_INCLUDE_DIRS}") +message(DEBUG "LIBBPF_LIBRARIES=${LIBBPF_LIBRARIES}") +message(DEBUG "LIBBPF_TARGET_NAME=${LIBBPF_TARGET_NAME}") + +# Find the dependencies `libelf` and `libz` of libbpf +find_package(PkgConfig) + +pkg_check_modules(LIBBPF_DEP REQUIRED libelf zlib) + +message(DEBUG "LIBBPF_DEP_LIBRARIES=${LIBBPF_DEP_LIBRARIES}") + +wasmedge_add_library(wasmedgePluginWasmBpf + SHARED + wasm-bpf-module.cpp + func-load-bpf-object.cpp + func-close-bpf-object.cpp + func-attach-bpf-program.cpp + func-bpf-buffer-poll.cpp + func-bpf-map-fd-by-name.cpp + func-bpf-map-operate.cpp + wasm-bpf.cpp + util.cpp +) + +add_dependencies(wasmedgePluginWasmBpf ${LIBBPF_TARGET_NAME}) +target_link_libraries(wasmedgePluginWasmBpf PUBLIC ${LIBBPF_LIBRARIES} ${LIBBPF_DEP_LIBRARIES}) +target_include_directories(wasmedgePluginWasmBpf PUBLIC ${LIBBPF_INCLUDE_DIRS}) + +set_target_properties(wasmedgePluginWasmBpf PROPERTIES + CXX_STANDARD 17 +) + +target_compile_options(wasmedgePluginWasmBpf + PUBLIC + -DWASMEDGE_PLUGIN + -fPIC +) + +target_include_directories(wasmedgePluginWasmBpf + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} + ${LIBBPF_INCLUDE_DIRS} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmBpf + PRIVATE + wasmedgeCAPI + ${LIBBPF_LIBRARIES} + ) +else() + target_link_libraries(wasmedgePluginWasmBpf + PRIVATE + wasmedge_shared + ${LIBBPF_LIBRARIES} + ) +endif() diff --git a/plugins/wasm_bpf/README.md b/plugins/wasm_bpf/README.md new file mode 100644 index 00000000..05a2755f --- /dev/null +++ b/plugins/wasm_bpf/README.md @@ -0,0 +1,109 @@ +# wasm_bpf Plugin + +This plugin added six host functions that give you Wasm application access to eBPF. + +Six functions are listed here. And all of them are in the module `wasm_bpf`, if you loaded this plugin. +```c +/// lookup a bpf map fd by name. +i32 wasm_bpf_map_fd_by_name(u64 obj, u32 name); +/// detach and close a bpf program. +i32 wasm_close_bpf_object(u64 obj); +/// CO-RE load a bpf object into the kernel. +u64 wasm_load_bpf_object(u32 obj_buf, u32 obj_buf_sz); +/// attach a bpf program to a kernel hook. +i32 wasm_attach_bpf_program(u64 obj, u32 name, + u32 attach_target); +/// poll a bpf buffer, and call a wasm callback indicated by sample_func. +/// the first time to call this function will open and create a bpf buffer. +i32 wasm_bpf_buffer_poll(u64 program, i32 fd, u32 sample_func, + u32 ctx, u32 data, i32 max_size, + i32 timeout_ms); +/// lookup, update, delete, and get_next_key operations on a bpf map. +i32 wasm_bpf_map_operate(u64 fd, i32 cmd, u32 key, u32 value, + u32 next_key, u64 flags); +``` + +- `iXX` denotes signed integer with `XX` bits +- `uXX` denotes unsigned integer with `XX` bits + +## How to compile this plugin + +### Install dependencies + +See the https://wasmedge.org/book/en/contribute/build_from_src/linux.html for how to build `WasmEdge` from source. + +#### libbpf + +This plugin requires `libbpf >= 1.2` + +Follow [https://github.com/libbpf/libbpf#building-libbpf](https://github.com/libbpf/libbpf#building-libbpf) to build and install `libbpf`. + +### Build `wasm_bpf` plug-in + +Run the following commands at the root of the `WasmEdge` project: + +- Note: It's important to set `WASMEDGE_PLUGIN_WASM_BPF` to `TRUE` in the command line. This toggle controls the build of `wasm_bpf` plugin. + +``` +cmake -DWASMEDGE_PLUGIN_WASM_BPF:BOOL=TRUE -B ./build -G "Unix Makefiles" +cmake --build ./build +``` + +## How to use this plugin + +You can either download the examples or build them by yourself. + +### Download the examples + +```sh +wget https://eunomia-bpf.github.io/wasm-bpf/examples/runqlat/runqlat.wasm +``` + +### build the examples + +Examples of wasm-bpf programs can be found in [wasm-bpf](https://github.com/eunomia-bpf/wasm-bpf/tree/main/examples) repo. You can build them by running the following commands: + +```sh +# install the wasi-sdk if you don't have it +wget https://github.com/WebAssembly/wasi-sdk/releases/download/wasi-sdk-17/wasi-sdk-17.0-linux.tar.gz +tar -zxf wasi-sdk-17.0-linux.tar.gz +sudo mkdir -p /opt/wasi-sdk/ && sudo mv wasi-sdk-17.0/* /opt/wasi-sdk/ + +# build the examples +git clone https://github.com/eunomia-bpf/wasm-bpf +cd wasm-bpf/examples +git submodule update --init --recursive + +# for example, build the execve example +cd execve && make +``` + +All examples are: + +```console +$ ls +bootstrap execve go-execve go-lsm lsm opensnoop runqlat rust-bootstrap sockfilter sockops +``` + +### run the examples + +After building, you can find the plug-in `./build/plugins/wasm_bpf/libwasmedgePluginWasmBpf.so` and the WasmEdge CLI tool at `./build/tools/wasmedge/wasmedge`. + +Set `WASMEDGE_PLUGIN_PATH=./build/plugins/wasm_bpf/` and run wasmedge: + +```console +# WASMEDGE_PLUGIN_PATH=./build/plugins/wasm_bpf/ ./build/tools/wasmedge/wasmedge execve.wasm + +[289150] node -> /bin/sh -c which ps +[289151] sh -> which ps +[289152] node -> /bin/sh -c /usr/bin/ps -ax -o pid=,ppid=,pcpu=,pmem=,c +[289153] sh -> /usr/bin/ps -ax -o pid=,ppid=,pcpu=,pmem=,command= +[289154] node -> /bin/sh -c "/root/.vscode-server-insiders/bin/96a795cc +[289155] sh -> /root/.vscode-server-insiders/bin/96a795cc0 245632 245678 289148 +[289156] cpuUsage.sh -> sed -n s/^cpu\s//p /proc/stat +[289157] cpuUsage.sh -> cat /proc/245632/stat +[289158] cpuUsage.sh -> cat /proc/245678/stat +[289159] cpuUsage.sh -> cat /proc/289148/stat +[289160] cpuUsage.sh -> sleep 1 +^C +``` \ No newline at end of file diff --git a/plugins/wasm_bpf/bpf-api.h b/plugins/wasm_bpf/bpf-api.h new file mode 100644 index 00000000..9039cf69 --- /dev/null +++ b/plugins/wasm_bpf/bpf-api.h @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "executor/executor.h" +#include "runtime/instance/module.h" +#include "wasmedge/wasmedge.h" + +extern "C" { +#include +#include +} + +#define POLL_TIMEOUT_MS 100 +#define PERF_BUFFER_PAGES 64 +#define DEBUG_LIBBPF_RUNTIME 0 +#define DEBUG_PRINT_BUFFER_SIZE 1024 + +namespace WasmEdge { +namespace Host { + +/// \brief init libbpf callbacks +void init_libbpf(void); + +typedef int32_t (*bpf_buffer_sample_fn)(void *ctx, void *data, size_t size); + +/// An absraction of a bpf ring buffer or perf buffer +/// see https://github.com/iovisor/bcc/blob/master/libbpf-tools/compat.c +class bpf_buffer { +protected: + bpf_buffer_sample_fn fn; + WasmEdge_ExecutorContext *wasm_executor; + const WasmEdge_ModuleInstanceContext *wasm_module_instance; + uint32_t wasm_ctx; + uint32_t wasm_sample_function; + void *poll_data; + size_t max_poll_size; + uint32_t wasm_buf_ptr; + +public: + /// sample callback which calls the wasm handler indirectly + int32_t bpf_buffer_sample(void *data, size_t size); + /// Check if the bpf buffer is valid + /// + /// a valid module instance should have only one table and a sample function + bool is_valid() const; + /// set the wasm callback parameters + void + set_callback_params(WasmEdge_ExecutorContext *executor, + const WasmEdge_ModuleInstanceContext *module_instance, + uint32_t sample_func, void *data, size_t max_size, + uint32_t ctx, uint32_t buf_ptr); + /// polling the bpf buffer + virtual int32_t bpf_buffer__poll(int32_t timeout_ms) = 0; + /// open the bpf buffer map + virtual int32_t bpf_buffer__open(int32_t fd, bpf_buffer_sample_fn sample_cb, + void *ctx) = 0; + virtual ~bpf_buffer() noexcept = default; +}; + +/// bpf program instance +class wasm_bpf_program { + std::unique_ptr obj{nullptr, + bpf_object__close}; + std::unique_ptr buffer; + std::unordered_set> + links; + +public: + /// Find a bpf map fd by name + int32_t bpf_map_fd_by_name(const char *name); + /// Load a bpf object from a buffer into the kernel + int32_t load_bpf_object(const void *obj_buf, size_t obj_buf_sz); + /// Attach a bpf program to a target (e.g. a kernel function on a kprobe) + int32_t attach_bpf_program(const char *name, const char *attach_target); + /// Poll the bpf buffer to get data from the kernel + int32_t bpf_buffer_poll(WasmEdge_ExecutorContext *executor, + const WasmEdge_ModuleInstanceContext *module_instance, + int32_t fd, int32_t sample_func, uint32_t ctx, + void *buffer_data, size_t max_size, + int32_t timeout_ms, uint32_t wasm_buf_ptr); + /// Get the bpf map pointer by fd + bpf_map *map_ptr_by_fd(int32_t fd); +}; + +enum bpf_map_cmd { + _BPF_MAP_LOOKUP_ELEM = 1, + _BPF_MAP_UPDATE_ELEM, + _BPF_MAP_DELETE_ELEM, + _BPF_MAP_GET_NEXT_KEY, +}; + +/// Operate on a bpf map. +int32_t bpf_map_operate(int32_t fd, int32_t cmd, void *key, void *value, + void *next_key, uint64_t flags); +using handle_t = int64_t; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-attach-bpf-program.cpp b/plugins/wasm_bpf/func-attach-bpf-program.cpp new file mode 100644 index 00000000..be4d6ade --- /dev/null +++ b/plugins/wasm_bpf/func-attach-bpf-program.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "func-attach-bpf-program.h" +#include "util.h" + +namespace WasmEdge { +namespace Host { + +Expect AttachBpfProgram::body(const Runtime::CallingFrame &Frame, + handle_t program, uint32_t name, + uint32_t attach_target) { + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + std::shared_lock lock(state->lock); + auto program_ptr = state->handles.find(program); + if (program_ptr == state->handles.end()) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + const char *name_str = nullptr; + const char *attach_target_str = nullptr; + checkAndSetCstr(memory, name, name_str); + checkAndSetCstr(memory, attach_target, attach_target_str); + return program_ptr->second->attach_bpf_program(name_str, attach_target_str); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-attach-bpf-program.h b/plugins/wasm_bpf/func-attach-bpf-program.h new file mode 100644 index 00000000..df5abedd --- /dev/null +++ b/plugins/wasm_bpf/func-attach-bpf-program.h @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "bpf-api.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "state.h" + +namespace WasmEdge { +namespace Host { + +/// \brief Attach a bpf program to the specified target +class AttachBpfProgram + : public WasmEdge::Runtime::HostFunction { +public: + AttachBpfProgram(state_t state) : state(state) {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + handle_t program, uint32_t name, + uint32_t attach_target); + +private: + state_t state; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-bpf-buffer-poll.cpp b/plugins/wasm_bpf/func-bpf-buffer-poll.cpp new file mode 100644 index 00000000..7d627ea7 --- /dev/null +++ b/plugins/wasm_bpf/func-bpf-buffer-poll.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "func-bpf-buffer-poll.h" +#include "wasmedge/wasmedge.h" +#include + +namespace WasmEdge { +namespace Host { + +// Helper functions of context conversions. +inline const auto *toCallFrameCxt(const Runtime::CallingFrame *Cxt) noexcept { + return reinterpret_cast(Cxt); +} +Expect BpfBufferPoll::body(const Runtime::CallingFrame &Frame, + handle_t program, int32_t fd, + int32_t sample_func, uint32_t ctx, + uint32_t data, int32_t max_size, + int32_t timeout_ms) { + auto c_ctx = toCallFrameCxt(&Frame); + auto c_module = WasmEdge_CallingFrameGetModuleInstance(c_ctx); + auto c_executor = WasmEdge_CallingFrameGetExecutor(c_ctx); + if (unlikely(!c_ctx || !c_module || !c_executor)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + auto module_instance = Frame.getModule(); + if (unlikely(!module_instance)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + std::shared_lock lock(state->lock); + auto program_ptr = state->handles.find(program); + if (program_ptr == state->handles.end()) { + return Unexpect(ErrCode::Value::HostFuncError); + } + auto data_buf = memory->getPointer(data, max_size); + if (!data_buf) { + return Unexpect(ErrCode::Value::HostFuncError); + } + return program_ptr->second->bpf_buffer_poll(c_executor, c_module, fd, + sample_func, ctx, data_buf, + max_size, timeout_ms, data); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-bpf-buffer-poll.h b/plugins/wasm_bpf/func-bpf-buffer-poll.h new file mode 100644 index 00000000..6a5366de --- /dev/null +++ b/plugins/wasm_bpf/func-bpf-buffer-poll.h @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "bpf-api.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "state.h" + +namespace WasmEdge { +namespace Host { + +/// Perform a bpf buffer poll. If the map is not opened, it will be opened. +/// +/// \param fd the map fd for bpf buffer. +/// \param sample_func callback function. When things are polled, it will be +/// invoked. +/// \param ctx user customized variable. +/// \param data data buffer that will be used to store the polled data. +/// \param max_size How many bytes can be put at data. +/// \param timeout_ms how many milliseconds can be waited. +/// +/// \return On success, return 0. On error, return error code. +class BpfBufferPoll : public WasmEdge::Runtime::HostFunction { +public: + BpfBufferPoll(state_t state) : state(state) {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + handle_t program, int32_t fd, + int32_t sample_func, uint32_t ctx, + uint32_t data, int32_t max_size, + int32_t timeout_ms); + +private: + state_t state; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp b/plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp new file mode 100644 index 00000000..c7a98eef --- /dev/null +++ b/plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "func-bpf-map-fd-by-name.h" +#include "util.h" +#include + +namespace WasmEdge { +namespace Host { + +Expect BpfMapFdByName::body(const Runtime::CallingFrame &Frame, + handle_t program, uint32_t name) { + const char *name_str = nullptr; + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + checkAndSetCstr(memory, name, name_str); + std::shared_lock guard(this->state->lock); + auto program_ptr = state->handles.find(program); + if (program_ptr == state->handles.end()) { + return Unexpect(ErrCode::Value::HostFuncError); + } + return program_ptr->second->bpf_map_fd_by_name(name_str); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-bpf-map-fd-by-name.h b/plugins/wasm_bpf/func-bpf-map-fd-by-name.h new file mode 100644 index 00000000..d15f67db --- /dev/null +++ b/plugins/wasm_bpf/func-bpf-map-fd-by-name.h @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "bpf-api.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "state.h" + +namespace WasmEdge { +namespace Host { + +/// \brief Lookup a map fd by its name. +/// +/// Map fd is returned if succeed, others if failed. +class BpfMapFdByName : public WasmEdge::Runtime::HostFunction { +public: + BpfMapFdByName(state_t state) : state(state) {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + handle_t program, uint32_t name); + +private: + state_t state; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-bpf-map-operate.cpp b/plugins/wasm_bpf/func-bpf-map-operate.cpp new file mode 100644 index 00000000..d14f73b9 --- /dev/null +++ b/plugins/wasm_bpf/func-bpf-map-operate.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "func-bpf-map-operate.h" +#include "bpf-api.h" + +extern "C" { +#include +} + +namespace WasmEdge { +namespace Host { + +#define ensure_memory_size(var, offset, size) \ + void *var = memory->getPointer(offset, size); \ + if (!var) \ + return Unexpect(ErrCode::Value::HostFuncError); +Expect +BpfMapOperate::body(const WasmEdge::Runtime::CallingFrame &Frame, int32_t fd, + int32_t cmd, uint32_t key, uint32_t value, + uint32_t next_key, uint64_t flags) { + + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + std::shared_lock guard(this->state->lock); + bpf_map_info map_info; + memset(&map_info, 0, sizeof(map_info)); + uint32_t info_len = sizeof(map_info); + int32_t err; + if ((err = bpf_map_get_info_by_fd(fd, &map_info, &info_len)) != 0) { + spdlog::debug("[WasmEdge Wasm_bpf] Invalid map fd found: fd={},err={}", fd, + err); + // Invalid map fd + return err; + } + auto key_size = map_info.key_size; + auto value_size = map_info.value_size; + + switch ((bpf_map_cmd)cmd) { + case BPF_MAP_GET_NEXT_KEY: { + ensure_memory_size(key_ptr, key, key_size); + ensure_memory_size(next_key_ptr, next_key, key_size); + return bpf_map_get_next_key(fd, key_ptr, next_key_ptr); + } + case BPF_MAP_LOOKUP_ELEM: { + ensure_memory_size(key_ptr, key, key_size); + ensure_memory_size(value_ptr, value, value_size); + return bpf_map_lookup_elem_flags(fd, key_ptr, value_ptr, flags); + } + case BPF_MAP_UPDATE_ELEM: { + ensure_memory_size(key_ptr, key, key_size); + ensure_memory_size(value_ptr, value, value_size); + return bpf_map_update_elem(fd, key_ptr, value_ptr, flags); + } + case BPF_MAP_DELETE_ELEM: { + ensure_memory_size(key_ptr, key, key_size); + return bpf_map_delete_elem_flags(fd, key_ptr, flags); + } + default: // More syscall commands can be allowed here + spdlog::debug("[WasmEdge Wasm_bpf] Invalid map operation", cmd); + return -EINVAL; + } +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-bpf-map-operate.h b/plugins/wasm_bpf/func-bpf-map-operate.h new file mode 100644 index 00000000..259f28de --- /dev/null +++ b/plugins/wasm_bpf/func-bpf-map-operate.h @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "bpf-api.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "state.h" + +namespace WasmEdge { +namespace Host { + +/// Perform bpf map operations on a specified bpf map through map fd. +/// +/// Return zero if succeed, others if error +class BpfMapOperate : public WasmEdge::Runtime::HostFunction { +public: + BpfMapOperate(state_t state) : state(state) {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + int32_t fd, int32_t cmd, uint32_t key, + uint32_t value, uint32_t next_key, + uint64_t flags); + +private: + state_t state; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-close-bpf-object.cpp b/plugins/wasm_bpf/func-close-bpf-object.cpp new file mode 100644 index 00000000..909b4f0f --- /dev/null +++ b/plugins/wasm_bpf/func-close-bpf-object.cpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "func-close-bpf-object.h" +#include + +namespace WasmEdge { +namespace Host { + +Expect CloseBpfObject::body(const WasmEdge::Runtime::CallingFrame &, + handle_t program) { + std::shared_lock guard(this->state->lock); + auto &handles = this->state->handles; + if (!handles.count(program)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + return handles.erase(program) > 0 ? 0 : -1; +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-close-bpf-object.h b/plugins/wasm_bpf/func-close-bpf-object.h new file mode 100644 index 00000000..3c441238 --- /dev/null +++ b/plugins/wasm_bpf/func-close-bpf-object.h @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "bpf-api.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "state.h" + +namespace WasmEdge { +namespace Host { + +/// \brief Close an opened bpf object. Will remove mapfds from the cache. +/// Return 0 if success. Others represent error codes. +class CloseBpfObject : public WasmEdge::Runtime::HostFunction { +public: + CloseBpfObject(state_t state) : state(state) {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + handle_t program); + +private: + state_t state; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-load-bpf-object.cpp b/plugins/wasm_bpf/func-load-bpf-object.cpp new file mode 100644 index 00000000..fe576fd3 --- /dev/null +++ b/plugins/wasm_bpf/func-load-bpf-object.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "func-load-bpf-object.h" + +namespace WasmEdge { +namespace Host { + +Expect LoadBpfObject::body(const Runtime::CallingFrame &Frame, + uint32_t obj_buf, uint32_t obj_buf_sz) { + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + char *const object_buffer = memory->getPointer(obj_buf, obj_buf_sz); + if (!object_buffer) { + return Unexpect(ErrCode::Value::HostFuncError); + } + auto program = std::make_unique(); + int32_t res = + program->load_bpf_object(object_buffer, static_cast(obj_buf_sz)); + if (res < 0) + return 0; + auto key = reinterpret_cast(program.get()); + + std::shared_lock guard(state->lock); + state->handles.emplace(key, std::move(program)); + return key; +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/func-load-bpf-object.h b/plugins/wasm_bpf/func-load-bpf-object.h new file mode 100644 index 00000000..5a6226ad --- /dev/null +++ b/plugins/wasm_bpf/func-load-bpf-object.h @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "bpf-api.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "state.h" + +namespace WasmEdge { +namespace Host { + +/// \brief Load a bpf ELF file. +/// +/// Binary file should be provided through a Wasm Buffer. wasm_bpf will handle +/// the remaining process Call to this function will also cache bpf map fds. +/// +/// \return a handle to a bpf program, which is stored in a map in the global +/// state. Return 0 if failed. +class LoadBpfObject : public WasmEdge::Runtime::HostFunction { +public: + LoadBpfObject(state_t state) : state(state) {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + uint32_t obj_buf, uint32_t obj_buf_sz); + +private: + state_t state; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/state.h b/plugins/wasm_bpf/state.h new file mode 100644 index 00000000..7ea995b8 --- /dev/null +++ b/plugins/wasm_bpf/state.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "bpf-api.h" +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { + +struct WasmBpfState { + /// manage bpf programs + std::unordered_map> handles; + std::shared_mutex lock; + ~WasmBpfState() noexcept = default; +}; + +using state_t = std::shared_ptr; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/util.cpp b/plugins/wasm_bpf/util.cpp new file mode 100644 index 00000000..28d383d9 --- /dev/null +++ b/plugins/wasm_bpf/util.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "util.h" + +namespace WasmEdge { +namespace Host { + +Expect read_c_str(Runtime::Instance::MemoryInstance *memory, + uint32_t ptr) { + uint32_t tail = ptr; + while (true) { + auto ch = memory->getBytes(tail, 1); + if (!ch.has_value()) + return Unexpect(ch.error()); + if (ch.value()[0] == '\0') + break; + tail++; + } + uint32_t len = tail - ptr + 1; + return memory->getPointer(ptr, len); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/util.h b/plugins/wasm_bpf/util.h new file mode 100644 index 00000000..958290ff --- /dev/null +++ b/plugins/wasm_bpf/util.h @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "runtime/instance/memory.h" + +namespace WasmEdge { +namespace Host { + +/// \brief read a c string from memory and check if it is null terminated +/// \param memory memory instance from wasm runtime +/// \param ptr the wasm32 buffer pointer +/// \return +WasmEdge::Expect +read_c_str(WasmEdge::Runtime::Instance::MemoryInstance *memory, uint32_t ptr); + +/// Check exist and set cstr, or return `Unexpect`. +#define checkAndSetCstr(memory, name, name_str) \ + do { \ + if (auto res = read_c_str(memory, name); unlikely(!res)) { \ + return Unexpect(res.error()); \ + } else { \ + name_str = *res; \ + } \ + } while (0) + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/wasm-bpf-module.cpp b/plugins/wasm_bpf/wasm-bpf-module.cpp new file mode 100644 index 00000000..9ae739d4 --- /dev/null +++ b/plugins/wasm_bpf/wasm-bpf-module.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "wasm-bpf-module.h" +#include "func-attach-bpf-program.h" +#include "func-bpf-buffer-poll.h" +#include "func-bpf-map-fd-by-name.h" +#include "func-bpf-map-operate.h" +#include "func-close-bpf-object.h" +#include "func-load-bpf-object.h" +#include "plugin/plugin.h" +#include "po/helper.h" +#include "runtime/callingframe.h" +#include "state.h" +#include +#include + +namespace WasmEdge { +namespace Host { + +using namespace std::literals::string_view_literals; + +WasmBpfModule::WasmBpfModule() : ModuleInstance("wasm_bpf") { + state_t state = std::make_shared(); + addHostFunc("wasm_load_bpf_object", std::make_unique(state)); + addHostFunc("wasm_close_bpf_object", std::make_unique(state)); + addHostFunc("wasm_attach_bpf_program", + std::make_unique(state)); + addHostFunc("wasm_bpf_buffer_poll", std::make_unique(state)); + addHostFunc("wasm_bpf_map_fd_by_name", + std::make_unique(state)); + addHostFunc("wasm_bpf_map_operate", std::make_unique(state)); +} + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmBpfModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasm_bpf", + .Description = "A plugin provides API for eBPF", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 1, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasm_bpf", + .Description = "Provide functions for eBPF", + .Create = create, + }, + }, + .AddOptions = nullptr}; + +Plugin::PluginRegister Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/wasm-bpf-module.h b/plugins/wasm_bpf/wasm-bpf-module.h new file mode 100644 index 00000000..8cbdd9dd --- /dev/null +++ b/plugins/wasm_bpf/wasm-bpf-module.h @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmBpfModule : public Runtime::Instance::ModuleInstance { +public: + WasmBpfModule(); +}; +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasm_bpf/wasm-bpf.cpp b/plugins/wasm_bpf/wasm-bpf.cpp new file mode 100644 index 00000000..9adb55eb --- /dev/null +++ b/plugins/wasm_bpf/wasm-bpf.cpp @@ -0,0 +1,258 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include +#include +#include +#include +#include + +#include "bpf-api.h" +#include "common/types.h" +#include "wasmedge/wasmedge.h" + +extern "C" { +#include +#include +} + +static int32_t bpf_buffer_sample(void *ctx, void *data, size_t size); +static int32_t libbpf_print_fn(enum libbpf_print_level level, + const char *format, va_list args) { + if (level == LIBBPF_DEBUG && DEBUG_LIBBPF_RUNTIME) + return 0; + char buf[DEBUG_PRINT_BUFFER_SIZE]; + int32_t len = vsnprintf(buf, sizeof(buf), format, args); + spdlog::debug("[WasmEdge Wasm_bpf] {}", buf); + return len; +} + +/// \brief perf buffer sample callback +static void perfbuf_sample_fn(void *ctx, int32_t cpu, void *data, __u32 size) { + static_cast(cpu); + bpf_buffer_sample(ctx, data, size); +} + +/// \brief sample the perf buffer and ring buffer +static int32_t bpf_buffer_sample(void *ctx, void *data, size_t size) { + WasmEdge::Host::bpf_buffer *buffer = + static_cast(ctx); + return buffer->bpf_buffer_sample(data, size); +} + +namespace WasmEdge { +namespace Host { + +/// \brief initialize libbpf library +void init_libbpf(void) { + libbpf_set_strict_mode(LIBBPF_STRICT_ALL); + libbpf_set_print(libbpf_print_fn); +} + +class perf_buffer_wrapper : public bpf_buffer { + std::unique_ptr inner{ + nullptr, perf_buffer__free}; + +public: + perf_buffer_wrapper(bpf_map *events) { + bpf_map__set_type(events, BPF_MAP_TYPE_PERF_EVENT_ARRAY); + bpf_map__set_key_size(events, sizeof(int)); + bpf_map__set_value_size(events, sizeof(int)); + } + int32_t bpf_buffer__poll(int32_t timeout_ms) override { + return perf_buffer__poll(inner.get(), timeout_ms); + } + int32_t bpf_buffer__open(int32_t fd, bpf_buffer_sample_fn sample_cb, + void *ctx) override { + fn = sample_cb; + inner.reset(perf_buffer__new(fd, PERF_BUFFER_PAGES, perfbuf_sample_fn, + nullptr, ctx, nullptr)); + return inner ? 0 : -EINVAL; + } +}; + +struct ring_buffer_wrapper : public bpf_buffer { +public: + std::unique_ptr inner{ + nullptr, ring_buffer__free}; + ring_buffer_wrapper(bpf_map *events) { + bpf_map__set_autocreate(events, false); + } + int32_t bpf_buffer__poll(int32_t timeout_ms) override { + return ring_buffer__poll(inner.get(), timeout_ms); + } + int32_t bpf_buffer__open(int32_t fd, bpf_buffer_sample_fn sample_cb, + void *ctx) override { + inner.reset(ring_buffer__new(fd, sample_cb, ctx, nullptr)); + return inner ? 0 : -1; + } +}; + +void bpf_buffer::set_callback_params( + WasmEdge_ExecutorContext *executor, + const WasmEdge_ModuleInstanceContext *module_instance, uint32_t sample_func, + void *data, size_t max_size, uint32_t ctx, uint32_t buf_ptr) { + wasm_executor = executor; + wasm_module_instance = module_instance; + wasm_sample_function = sample_func; + poll_data = data; + max_poll_size = max_size; + wasm_ctx = ctx; + wasm_buf_ptr = buf_ptr; +} + +bool bpf_buffer::is_valid() const { + auto module_inst = wasm_module_instance; + WasmEdge_String names; + uint32_t exported_table_len = + WasmEdge_ModuleInstanceListTable(module_inst, &names, 1); + if (exported_table_len != 1) { + return false; + } + auto table_inst = WasmEdge_ModuleInstanceFindTable(module_inst, names); + if (!table_inst) { + return false; + } + WasmEdge_Value value; + auto get_data_result = + WasmEdge_TableInstanceGetData(table_inst, &value, wasm_sample_function); + return WasmEdge_ResultOK(get_data_result); +} + +int32_t bpf_buffer::bpf_buffer_sample(void *data, size_t size) { + size_t sample_size = size; + if (max_poll_size < size) { + sample_size = max_poll_size; + } + memcpy(poll_data, data, sample_size); + auto module_inst = wasm_module_instance; + WasmEdge_String names[1]; + /// a valid module instance should have only one table + uint32_t exported_table_len = + WasmEdge_ModuleInstanceListTable(module_inst, names, std::size(names)); + assuming(exported_table_len == 1); + auto table_inst = WasmEdge_ModuleInstanceFindTable(module_inst, names[0]); + assuming(table_inst); + WasmEdge_Value value; + auto get_data_result = + WasmEdge_TableInstanceGetData(table_inst, &value, wasm_sample_function); + assuming(WasmEdge_ResultOK(get_data_result)); + assert(value.Type == WasmEdge_ValType::WasmEdge_ValType_FuncRef); + auto func_ref = WasmEdge_ValueGetFuncRef(value); + + WasmEdge_Value invoke_func_params[3] = { + WasmEdge_ValueGenI32(wasm_ctx), + WasmEdge_ValueGenI32(wasm_buf_ptr), + WasmEdge_ValueGenI32(size), + }; + WasmEdge_Value invoke_func_result; + auto call_result = WasmEdge_ExecutorInvoke( + wasm_executor, func_ref, invoke_func_params, 3, &invoke_func_result, 1); + if (!WasmEdge_ResultOK(call_result)) { + return -EINVAL; + } + return WasmEdge_ValueGetI32(invoke_func_result); +} + +/// \brief create a bpf buffer based on the object map type +std::unique_ptr bpf_buffer__new(bpf_map *events) { + bpf_map_type map_type = bpf_map__type(events); + switch (map_type) { + case BPF_MAP_TYPE_PERF_EVENT_ARRAY: + return std::make_unique(events); + case BPF_MAP_TYPE_RINGBUF: + return std::make_unique(events); + default: + return nullptr; + } +} + +/// Get the file descriptor of a map by name. +int32_t wasm_bpf_program::bpf_map_fd_by_name(const char *name) { + return bpf_object__find_map_fd_by_name(obj.get(), name); +} +/// \brief load all bpf programs and maps in a object file. +int32_t wasm_bpf_program::load_bpf_object(const void *obj_buf, + size_t obj_buf_sz) { + auto object = bpf_object__open_mem(obj_buf, obj_buf_sz, nullptr); + if (!object) { + return static_cast(libbpf_get_error(object)); + } + obj.reset(object); + return bpf_object__load(object); +} + +/// \brief attach a specific bpf program by name and target. +int32_t wasm_bpf_program::attach_bpf_program(const char *name, + const char *attach_target) { + bpf_link *link; + if (!attach_target) { + // auto attach base on bpf_program__section_name. The works well for most + // bpf types, include kprobe, uprobe, fentry, lsm, etc. + link = + bpf_program__attach(bpf_object__find_program_by_name(obj.get(), name)); + } else { + bpf_object *o = obj.get(); + bpf_program *prog = bpf_object__find_program_by_name(o, name); + if (!prog) { + spdlog::error("[WasmEdge Wasm_bpf] get prog {} fail", name); + return -1; + } + // TODO: attach dynamically base on bpf_program__section_name(prog) and + // attach_target to support more attach type libbpf cannot auto attach. For + // example, if bpf_program__section_name(prog) is "xdp" and attach_target is + // "eth0", or attach sockops to a socket fd. For now, we will try auto + // attach as well. + link = + bpf_program__attach(bpf_object__find_program_by_name(obj.get(), name)); + } + if (!link) { + return static_cast(libbpf_get_error(link)); + } + links.emplace(std::unique_ptr{ + link, bpf_link__destroy}); + return 0; +} + +/// \brief get map pointer by fd through iterating over all maps +bpf_map *wasm_bpf_program::map_ptr_by_fd(int fd) { + bpf_map *curr = nullptr; + bpf_map__for_each(curr, obj.get()) { + if (bpf_map__fd(curr) == fd) { + return curr; + } + } + return nullptr; +} + +/// polling the buffer, if the buffer is not created, create it. +int32_t wasm_bpf_program::bpf_buffer_poll( + WasmEdge_ExecutorContext *executor, + const WasmEdge_ModuleInstanceContext *module_instance, int32_t fd, + int32_t sample_func, uint32_t ctx, void *data, size_t max_size, + int32_t timeout_ms, uint32_t wasm_buf_ptr) { + int32_t res; + if (!buffer.get()) { + // create buffer + auto map = map_ptr_by_fd(fd); + buffer = bpf_buffer__new(map); + if (!buffer) { + return -1; + } + res = buffer->bpf_buffer__open(fd, bpf_buffer_sample, buffer.get()); + if (res < 0) { + return res; + } + } + buffer->set_callback_params(executor, module_instance, + static_cast(sample_func), data, + max_size, ctx, wasm_buf_ptr); + if (!buffer->is_valid()) { + return -EINVAL; + } + // poll the buffer + return buffer->bpf_buffer__poll(timeout_ms); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 2ab4a70d..2a46e044 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -21,6 +21,12 @@ if(WASMEDGE_PLUGIN_WASI_CRYPTO) add_subdirectory(wasi_crypto) endif() +if(WASMEDGE_PLUGIN_WASM_BPF) + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasm_bpf) + endif() +endif() + if(CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(unittest) endif() diff --git a/test/plugins/wasm_bpf/CMakeLists.txt b/test/plugins/wasm_bpf/CMakeLists.txt new file mode 100644 index 00000000..26e9c3fe --- /dev/null +++ b/test/plugins/wasm_bpf/CMakeLists.txt @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_executable(wasmBpfTests + simple_map_test.cpp + simple_ringbuf_test.cpp + wasm_bpf.cpp +) + +target_link_libraries(wasmBpfTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} + wasmedgePlugin + wasmedgePluginWasmBpf + wasmedgeExecutor +) +add_subdirectory(assets) + +add_dependencies( + wasmBpfTests + wasmBpfTestsAssets +) + +add_test(wasmBpfTests wasmBpfTests) diff --git a/test/plugins/wasm_bpf/assets/.gitignore b/test/plugins/wasm_bpf/assets/.gitignore new file mode 100644 index 00000000..2247a974 --- /dev/null +++ b/test/plugins/wasm_bpf/assets/.gitignore @@ -0,0 +1,4 @@ +bootstrap.bpf.o +runqlat.bpf.o +simple_ringbuf.bpf.o +simple_map.bpf.o diff --git a/test/plugins/wasm_bpf/assets/CMakeLists.txt b/test/plugins/wasm_bpf/assets/CMakeLists.txt new file mode 100644 index 00000000..80749a7e --- /dev/null +++ b/test/plugins/wasm_bpf/assets/CMakeLists.txt @@ -0,0 +1,41 @@ +include(FetchContent) + +# Download wasm-bpf, copy & compile two bpf programs from that +FetchContent_Declare( + wasm_bpf + GIT_REPOSITORY https://github.com/eunomia-bpf/wasm-bpf + GIT_TAG b76be32d44c2ec1933ca28eab875b50e713855b8 +) + +message("Downloading wasm-bpf") +FetchContent_MakeAvailable(wasm_bpf) +message("Downloaded wasm-bpf") + +add_custom_command( + OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/bootstrap.bpf.o + COMMAND make -C ${wasm_bpf_SOURCE_DIR}/examples/bootstrap/ bootstrap.bpf.o + COMMAND cp ${wasm_bpf_SOURCE_DIR}/examples/bootstrap/bootstrap.bpf.o ${CMAKE_CURRENT_SOURCE_DIR} + WORKING_DIRECTORY ${wasm_bpf_SOURCE_DIR}/examples/bootstrap/ +) + +add_custom_command( + OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/runqlat.bpf.o + COMMAND make -C ${wasm_bpf_SOURCE_DIR}/examples/runqlat/ runqlat.bpf.o + COMMAND cp ${wasm_bpf_SOURCE_DIR}/examples/runqlat/runqlat.bpf.o ${CMAKE_CURRENT_SOURCE_DIR} + WORKING_DIRECTORY ${wasm_bpf_SOURCE_DIR}/examples/runqlat/ +) + +add_custom_command( + OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/simple_map.bpf.o ${CMAKE_CURRENT_SOURCE_DIR}/simple_ringbuf.bpf.o + COMMAND make -C ${CMAKE_CURRENT_SOURCE_DIR}/bpf-sources + COMMAND cp ${CMAKE_CURRENT_SOURCE_DIR}/bpf-sources/*.bpf.o ${CMAKE_CURRENT_SOURCE_DIR} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +add_custom_target( + wasmBpfTestsAssets + DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/bootstrap.bpf.o + ${CMAKE_CURRENT_SOURCE_DIR}/runqlat.bpf.o + ${CMAKE_CURRENT_SOURCE_DIR}/simple_ringbuf.bpf.o + ${CMAKE_CURRENT_SOURCE_DIR}/simple_map.bpf.o +) diff --git a/test/plugins/wasm_bpf/assets/README.md b/test/plugins/wasm_bpf/assets/README.md new file mode 100644 index 00000000..33fa78dd --- /dev/null +++ b/test/plugins/wasm_bpf/assets/README.md @@ -0,0 +1,11 @@ +This file contains bpf programs that will be used during testing. + + +- `bootstrap` and `runqlat`: examples copied from `wasm-bpf`. See [here](https://github.com/eunomia-bpf/wasm-bpf/tree/main/examples) for build instructions. + +- `simple_ringbuf`: A simple ebpf program which writes fixed data to a ring buffer +- `simple_map`: A simple ebpf program which stores fixed data to a bpf map + +The source of `simple_ringbuf` and `simple_map` are listed under `bpf-sources`. Run `make` under that directory to build them. + +`libbpf` and `clang` are required to build them. diff --git a/test/plugins/wasm_bpf/assets/bpf-sources/.gitignore b/test/plugins/wasm_bpf/assets/bpf-sources/.gitignore new file mode 100644 index 00000000..2cf24234 --- /dev/null +++ b/test/plugins/wasm_bpf/assets/bpf-sources/.gitignore @@ -0,0 +1,4 @@ +*.bpf.o +bootstrap.bpf.c +runqlat.bpf.c +bootstrap.h diff --git a/test/plugins/wasm_bpf/assets/bpf-sources/Makefile b/test/plugins/wasm_bpf/assets/bpf-sources/Makefile new file mode 100644 index 00000000..b65b866c --- /dev/null +++ b/test/plugins/wasm_bpf/assets/bpf-sources/Makefile @@ -0,0 +1,20 @@ + +DEL = rm -rf +FILES = $(shell ls *.bpf.c | awk '{split($$0,a,".");print a[1]}') + +ARCH ?= $(shell uname -m | sed 's/x86_64/x86/' | sed 's/aarch64/arm64/' | sed 's/ppc64le/powerpc/' | sed 's/mips.*/mips/') +CLANG_BPF_SYS_INCLUDES = $(shell $(CLANG) -v -E - &1 \ + | sed -n '/<...> search starts here:/,/End of search list./{ s| \(/.*\)|-idirafter \1|p }') + +all: $(FILES) + +$(FILES) : % : %.bpf.c + $(DEL) $@.bpf.o + clang -g -O2 -target bpf -D__TARGET_ARCH_$(ARCH) -c $(filter %.c,$^) -o $@.bpf.o + llvm-strip -g $@.bpf.o + cp $@.bpf.o ../$@.bpf.o + +%.bpf.o : % + +clean: + $(DEL) *.bpf.o diff --git a/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c b/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c new file mode 100644 index 00000000..a787ceaa --- /dev/null +++ b/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#define SEC(name) __attribute__((section(name), used)) +#define __uint(name, val) int(*name)[val] +#define __type(name, val) typeof(val) *name + +#define __u64 unsigned long long +#define __u32 unsigned int + +#define u32 __u32 +#define u64 __u64 + +#define BPF_MAP_TYPE_HASH ((u32)1) +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 16); + __type(key, u32); + __type(value, u64); +} test_map SEC(".maps"); +static void *(*bpf_map_lookup_elem)(void *map, const void *key) = (void *)1; +static long (*bpf_map_update_elem)(void *map, const void *key, + const void *value, __u64 flags) = (void *)2; + +static const u32 INDICATING_KEY = 0xABCD; +static const u32 ADD_VALUE_1_KEY = 0xCDEF; +static const u32 ADD_VALUE_2_KEY = 0x1234; +static const u32 RESULT_VALUE_KEY = 0x7890; +SEC("tp_btf/sched_wakeup") +int sched_wakeup(void *ctx) { + // Use an element with key `0xABCD` to indicate that the userspace program + // already set values of the add values. + if (!bpf_map_lookup_elem(&test_map, &INDICATING_KEY)) { + return 0; + } + // Read the two add values from the map + u64 *val1, *val2; + val1 = bpf_map_lookup_elem(&test_map, &ADD_VALUE_1_KEY); + val2 = bpf_map_lookup_elem(&test_map, &ADD_VALUE_2_KEY); + if (!val1 || !val2) + return 0; + u64 result = (u64)(*val1) + (u64)(*val2); + // Store the result + bpf_map_update_elem(&test_map, &RESULT_VALUE_KEY, &result, 0); + return 0; +} + +char LICENSE[] SEC("license") = "Dual BSD/GPL"; diff --git a/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c b/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c new file mode 100644 index 00000000..888c6e1a --- /dev/null +++ b/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#define SEC(name) __attribute__((section(name), used)) +#define __uint(name, val) int(*name)[val] + +#define __u64 unsigned long long +#define __u32 unsigned int + +#define u32 __u32 + +#define BPF_MAP_TYPE_RINGBUF ((u32)27) + +char LICENSE[] SEC("license") = "Dual BSD/GPL"; +static void *(*bpf_ringbuf_reserve)(void *ringbuf, __u64 size, + __u64 flags) = (void *)131; + +static void (*bpf_ringbuf_submit)(void *data, __u64 flags) = (void *)132; + +struct { + __uint(type, BPF_MAP_TYPE_RINGBUF); + __uint(max_entries, 256 * 1024); +} rb SEC(".maps"); + +SEC("tp/sched/sched_process_exec") +int handle_exec(void *ctx) { + u32 send_data; + send_data = 0xABCD1234; + + /* reserve sample from BPF ringbuf */ + u32 *e = bpf_ringbuf_reserve(&rb, sizeof(send_data), 0); + if (!e) + return 0; + *e = send_data; + /* successfully submit it to user-space for post-processing */ + bpf_ringbuf_submit(e, 0); + return 0; +} diff --git a/test/plugins/wasm_bpf/simple_map_test.cpp b/test/plugins/wasm_bpf/simple_map_test.cpp new file mode 100644 index 00000000..e902bbc2 --- /dev/null +++ b/test/plugins/wasm_bpf/simple_map_test.cpp @@ -0,0 +1,291 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include +#include +#include +#include +#include +#include +#include +#include "common/defines.h" +#include "executor/executor.h" +#include "func-attach-bpf-program.h" +#include "func-bpf-map-fd-by-name.h" +#include "func-bpf-map-operate.h" +#include "func-close-bpf-object.h" +#include "func-load-bpf-object.h" +#include "plugin/plugin.h" +#include "runtime/instance/module.h" +#include "wasm-bpf-module.h" +namespace { +WasmEdge::Runtime::Instance::ModuleInstance* createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasm_bpf/" + "libwasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); + if (const auto* Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { + if (const auto* Module = Plugin->findModule("wasm_bpf"sv)) { + return Module->create().release(); + } + } + return nullptr; +} + +std::filesystem::path getAssertsPath() { + std::filesystem::path thisFile(__FILE__); + return thisFile.parent_path() / "assets"; +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance& memInst, + uint32_t offset, + const std::vector& data) noexcept { + char* buf = memInst.getPointer(offset); + std::copy(data.begin(), data.end(), buf); +} +} // namespace + +static const uint32_t INDICATING_KEY = 0xABCD; +static const uint32_t ADD_VALUE_1_KEY = 0xCDEF; +static const uint32_t ADD_VALUE_2_KEY = 0x1234; +static const uint32_t RESULT_VALUE_KEY = 0x7890; + +TEST(WasmBpfTest, SimpleMapTest) { + using namespace std::string_view_literals; + // Test loading and attaching a bpf program, and some operations of maps + auto module = dynamic_cast(createModule()); + ASSERT_NE(module, nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); + // moduleInst.addHostFunc() + moduleInst.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto* memoryInst = moduleInst.findMemoryExports("memory"); + ASSERT_NE(memoryInst, nullptr); + auto& memoryInstRef = *memoryInst; + WasmEdge::Executor::Executor executor((WasmEdge::Configure())); + WasmEdge::Runtime::CallingFrame CallFrame(&executor, &moduleInst); + + namespace fs = std::filesystem; + auto bpfObject = getAssertsPath() / "simple_map.bpf.o"; + + // Ensure the bpf object we need exists + ASSERT_TRUE(fs::exists(bpfObject)); + + // Read the bpf object into wasm memory + std::ifstream bpfObjStream(bpfObject); + ASSERT_TRUE(bpfObjStream.is_open()); + ASSERT_TRUE(bpfObjStream.good()); + std::vector bpfObjectBytes( + (std::istreambuf_iterator(bpfObjStream)), + std::istreambuf_iterator()); + ASSERT_FALSE(bpfObjectBytes.empty()); + // Offset to put things into memory + uint32_t nextOffset = 1; + + // Put the bpf object into memory + const uint32_t bpfObjectMemoryOffset = nextOffset; + fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); + nextOffset += static_cast(bpfObjectBytes.size()); + + // Fill strings that will be used into memory + std::array strings = { + "test_map", // Map name + "sched_wakeup", // Program names + "" // An empty string + }; + std::array stringOffsets; + + for (size_t i = 0; i < strings.size(); i++) { + std::string currString(strings[i]); + std::vector bytes(currString.begin(), currString.end()); + // Ensure that strings are zero-terminated + bytes.push_back('\0'); + fillMemContent(memoryInstRef, nextOffset, bytes); + stringOffsets[i] = nextOffset; + nextOffset += static_cast(bytes.size()); + } + + // Get function "wasm_load_bpf_object" + auto* loadFunc = module->findFuncExports("wasm_load_bpf_object"); + ASSERT_NE(loadFunc, nullptr); + ASSERT_TRUE(loadFunc->isHostFunction()); + auto& loadFuncHost = + dynamic_cast(loadFunc->getHostFunc()); + + // call "wasm_load_bpf_object" to Load `bootstrap.bpf.o`, and check the + // result + std::array loadResult; + ASSERT_TRUE(loadFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(bpfObjectMemoryOffset), + WasmEdge::ValVariant(static_cast(bpfObjectBytes.size()))}, + loadResult)); + auto handle = loadResult[0].get(); + ASSERT_NE(handle, 0); + + // Get function `wasm_attach_bpf_program` + auto* attachFunc = module->findFuncExports("wasm_attach_bpf_program"); + ASSERT_NE(attachFunc, nullptr); + ASSERT_TRUE(attachFunc->isHostFunction()); + auto& attachFuncHost = dynamic_cast( + attachFunc->getHostFunc()); + + // Call "wasm_attach_bpf_program" to attach, and check the result + std::array attachResult; + ASSERT_TRUE(attachFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(stringOffsets[1]), + // There should be '\0' + WasmEdge::ValVariant(stringOffsets[2]), + }, + attachResult)); + ASSERT_GE(attachResult[0].get(), 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto* mapFdFunc = module->findFuncExports("wasm_bpf_map_fd_by_name"); + ASSERT_NE(mapFdFunc, nullptr); + ASSERT_TRUE(mapFdFunc->isHostFunction()); + auto& mapFdFuncHost = + dynamic_cast(mapFdFunc->getHostFunc()); + + // Call "wasm_bpf_map_fd_by_name" to get the map fd, and check the result + std::array mapFdResult; + ASSERT_TRUE(mapFdFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(stringOffsets[0])}, + mapFdResult)); + auto mapFd = mapFdResult[0].get(); + ASSERT_GE(mapFd, 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto* mapOptFunc = module->findFuncExports("wasm_bpf_map_operate"); + EXPECT_NE(mapOptFunc, nullptr); + EXPECT_TRUE(mapOptFunc->isHostFunction()); + auto& mapOptFuncHost = + dynamic_cast(mapOptFunc->getHostFunc()); + + // A wrapper to call wasm_bpf_map_operate + auto callMapOperate = [&](int32_t fd, int32_t cmd, uint32_t key, + uint32_t value, uint32_t nextKey, + uint64_t flags) -> int32_t { + std::array callResult; + EXPECT_TRUE(mapOptFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(fd), WasmEdge::ValVariant(cmd), + WasmEdge::ValVariant(key), WasmEdge::ValVariant(value), + WasmEdge::ValVariant(nextKey), WasmEdge::ValVariant(flags)}, + callResult)); + return callResult[0].get(); + }; + + auto mapLookupElem = [&](int32_t fd, uint32_t key, + uint32_t valueOut) -> int32_t { + // key found -> returns 0 + // key not found -> returns -1 + return callMapOperate(fd, + 1, // BPF_MAP_LOOKUP_ELEM + key, valueOut, 0, 0); + }; + + auto mapUpdateElem = [&](int32_t fd, uint32_t key, + uint32_t value) -> int32_t { + return callMapOperate(fd, + 2, // BPF_MAP_UPDATE_ELEM + key, value, 0, 0); + }; + + // Helper functions to make read & write more convenient + auto readU64 = [&](uint32_t offset) -> uint64_t { + const auto* ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + return *ptr; + }; + auto writeU64 = [&](uint32_t offset, uint64_t val) { + auto* ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + *ptr = val; + }; + + auto writeU32 = [&](uint32_t offset, uint32_t val) { + auto* ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + *ptr = val; + }; + + // Generate two numbers, which will be stored in the map and calculated the + // summation by the ebpf program + std::mt19937 randGen; + randGen.seed(std::random_device()()); + std::uniform_int_distribution intDist( + 0, static_cast(1e16)); + uint64_t num1 = intDist(randGen); + uint64_t num2 = intDist(randGen); + + // Prepare for wasm memory which is used to store numbers + const uint32_t numOffset1 = nextOffset; + nextOffset += 8; + const uint32_t numOffset2 = nextOffset; + nextOffset += 8; + const uint32_t resultOffset = nextOffset; + nextOffset += 8; + const uint32_t indicatingKeyOffset = nextOffset; + nextOffset += 4; + const uint32_t num1KeyOffset = nextOffset; + nextOffset += 4; + const uint32_t num2KeyOffset = nextOffset; + nextOffset += 4; + const uint32_t resultKeyOffset = nextOffset; + nextOffset += 4; + + writeU32(num1KeyOffset, ADD_VALUE_1_KEY); + writeU32(num2KeyOffset, ADD_VALUE_2_KEY); + writeU32(resultKeyOffset, RESULT_VALUE_KEY); + writeU32(indicatingKeyOffset, INDICATING_KEY); + + writeU32(INDICATING_KEY, indicatingKeyOffset); + + writeU64(numOffset1, num1); + writeU64(numOffset2, num2); + + writeU64(resultOffset, 0); + + // Write the add values into the map + ASSERT_EQ(mapUpdateElem(mapFd, num1KeyOffset, numOffset1), 0); + ASSERT_EQ(mapUpdateElem(mapFd, num2KeyOffset, numOffset2), 0); + + // Write the indicating key + // Arbitrary values are correct. We only care the existence of the + // indicating key + ASSERT_EQ(mapUpdateElem(mapFd, indicatingKeyOffset, numOffset1), 0); + + // Sleep for 1s and wait for the ebpf program to process.. + std::this_thread::sleep_for(std::chrono::seconds(2)); + + // Read the result and check it + ASSERT_EQ(mapLookupElem(mapFd, resultKeyOffset, resultOffset), 0); + uint64_t addResult = readU64(resultOffset); + ASSERT_EQ(addResult, num1 + num2); + + // Get function `wasm_close_bpf_object` + auto* closeFunc = module->findFuncExports("wasm_close_bpf_object"); + ASSERT_NE(closeFunc, nullptr); + ASSERT_TRUE(closeFunc->isHostFunction()); + auto& closeFuncHost = + dynamic_cast(closeFunc->getHostFunc()); + + // Call "wasm_close_bpf_object" to attach, and check the result + std::array closeResult; + ASSERT_TRUE(closeFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + }, + closeResult)); + ASSERT_EQ(closeResult[0].get(), 0); +} diff --git a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp new file mode 100644 index 00000000..cda1865b --- /dev/null +++ b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include +#include +#include +#include "common/defines.h" +#include "executor/executor.h" +#include "func-attach-bpf-program.h" +#include "func-bpf-buffer-poll.h" +#include "func-bpf-map-fd-by-name.h" +#include "func-close-bpf-object.h" +#include "func-load-bpf-object.h" +#include "plugin/plugin.h" +#include "runtime/instance/module.h" +#include "wasm-bpf-module.h" +namespace { +WasmEdge::Runtime::Instance::ModuleInstance* createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasm_bpf/" + "libwasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); + if (const auto* Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { + if (const auto* Module = Plugin->findModule("wasm_bpf"sv)) { + return Module->create().release(); + } + } + return nullptr; +} + +std::filesystem::path getAssertsPath() { + std::filesystem::path thisFile(__FILE__); + return thisFile.parent_path() / "assets"; +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance& memInst, + uint32_t offset, + const std::vector& data) noexcept { + char* buf = memInst.getPointer(offset); + std::copy(data.begin(), data.end(), buf); +} +class PollCallbackFunction + : public WasmEdge::Runtime::HostFunction { + public: + PollCallbackFunction() {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame& Frame, + uint32_t __attribute__((unused)) ctx, + uint32_t data, + uint32_t data_sz) { + using namespace std; + using WasmEdge::unlikely; + auto* memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + if (data_sz < static_cast(sizeof(uint32_t))) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + const uint32_t* dataPtr = memory->getPointer(data, 1); + if (unlikely(!dataPtr)) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + EXPECT_EQ(*dataPtr, UINT32_C(0xABCD1234)); + return 0; + } +}; + +} // namespace + +TEST(WasmBpfTest, SimpleRingbuf) { + using namespace std::string_view_literals; + // Test loading and attaching a bpf program, and polling buffer + auto module = dynamic_cast(createModule()); + ASSERT_NE(module, nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); + // moduleInst.addHostFunc() + moduleInst.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto* memoryInst = moduleInst.findMemoryExports("memory"); + ASSERT_NE(memoryInst, nullptr); + auto& memoryInstRef = *memoryInst; + WasmEdge::Executor::Executor executor((WasmEdge::Configure())); + WasmEdge::Runtime::CallingFrame CallFrame(&executor, &moduleInst); + + namespace fs = std::filesystem; + auto bpfObject = getAssertsPath() / "simple_ringbuf.bpf.o"; + + // Ensure the bpf object we need exists + ASSERT_TRUE(fs::exists(bpfObject)); + + // Read the bpf object into wasm memory + std::ifstream bpfObjStream(bpfObject); + ASSERT_TRUE(bpfObjStream.is_open()); + ASSERT_TRUE(bpfObjStream.good()); + std::vector bpfObjectBytes( + (std::istreambuf_iterator(bpfObjStream)), + std::istreambuf_iterator()); + ASSERT_FALSE(bpfObjectBytes.empty()); + // Offset to put things into memory + uint32_t nextOffset = 1; + + // Put the bpf object into memory + const uint32_t bpfObjectMemoryOffset = nextOffset; + fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); + nextOffset += static_cast(bpfObjectBytes.size()); + + // Fill strings that will be used into memory + std::array strings = { + "rb", // Map name + "handle_exec", // Program names + "" // An empty string + }; + std::array stringOffsets; + + for (size_t i = 0; i < strings.size(); i++) { + std::string currString(strings[i]); + std::vector bytes(currString.begin(), currString.end()); + // Ensure that strings are zero-terminated + bytes.push_back('\0'); + fillMemContent(memoryInstRef, nextOffset, bytes); + stringOffsets[i] = nextOffset; + nextOffset += static_cast(bytes.size()); + } + + const uint32_t bufferPollMemoryOffset = nextOffset; + const uint32_t bufferPollSize = 256; + nextOffset += bufferPollSize; + + // Get function "wasm_load_bpf_object" + auto* loadFunc = module->findFuncExports("wasm_load_bpf_object"); + ASSERT_NE(loadFunc, nullptr); + ASSERT_TRUE(loadFunc->isHostFunction()); + auto& loadFuncHost = + dynamic_cast(loadFunc->getHostFunc()); + + // call "wasm_load_bpf_object" to Load `bootstrap.bpf.o`, and check the + // result + std::array loadResult; + ASSERT_TRUE(loadFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(bpfObjectMemoryOffset), + WasmEdge::ValVariant(static_cast(bpfObjectBytes.size()))}, + loadResult)); + auto handle = loadResult[0].get(); + ASSERT_NE(handle, 0); + + // Get function `wasm_attach_bpf_program` + auto* attachFunc = module->findFuncExports("wasm_attach_bpf_program"); + ASSERT_NE(attachFunc, nullptr); + ASSERT_TRUE(attachFunc->isHostFunction()); + auto& attachFuncHost = dynamic_cast( + attachFunc->getHostFunc()); + + // Call "wasm_attach_bpf_program" to attach, and check the result + std::array attachResult; + ASSERT_TRUE(attachFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(stringOffsets[1]), + // There should be '\0' + WasmEdge::ValVariant(stringOffsets[2]), + }, + attachResult)); + ASSERT_GE(attachResult[0].get(), 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto* mapFdFunc = module->findFuncExports("wasm_bpf_map_fd_by_name"); + ASSERT_NE(mapFdFunc, nullptr); + ASSERT_TRUE(mapFdFunc->isHostFunction()); + auto& mapFdFuncHost = + dynamic_cast(mapFdFunc->getHostFunc()); + + // Call "wasm_bpf_map_fd_by_name" to get the map fd, and check the result + std::array mapFdResult; + ASSERT_TRUE(mapFdFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(stringOffsets[0])}, + mapFdResult)); + auto mapFd = mapFdResult[0].get(); + ASSERT_GE(mapFd, 0); + + // In the following several steps we will prepare for polling + // Create an instance of the polling callback function + auto callbackFuncInst = + std::make_unique( + &moduleInst, std::make_unique()); + // Create a function table, and fill the callback function into it + auto funcTableInst = + std::make_unique( + WasmEdge::AST::TableType(WasmEdge::RefType::FuncRef, 1)); + ASSERT_TRUE(funcTableInst->setRefs( + std::initializer_list{ + WasmEdge::FuncRef(callbackFuncInst.get())}, + 0, 0, 1)); + // Add the table to the main module + moduleInst.addHostTable("__indirect_function_table"sv, + std::move(funcTableInst)); + + // Get the "wasm_bpf_buffer_poll" function + auto* bufferPollFunc = module->findFuncExports("wasm_bpf_buffer_poll"); + ASSERT_NE(bufferPollFunc, nullptr); + ASSERT_TRUE(bufferPollFunc->isHostFunction()); + auto& bufferPollFuncHost = dynamic_cast( + bufferPollFunc->getHostFunc()); + + // Call the polling function + std::array pollResult; + for (size_t i = 1; i <= 50; i++) { + using namespace std; + ASSERT_TRUE(bufferPollFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), // object handle + WasmEdge::ValVariant(mapFd), // map fd + UINT32_C(0), // callback function index + UINT32_C(0), // Custom context pointer + WasmEdge::ValVariant(bufferPollMemoryOffset), // buffer offset + WasmEdge::ValVariant(bufferPollSize), // buffer size + UINT32_C(100) // timeout (ms) + }, + pollResult)); + ASSERT_GE(pollResult[0].get(), 0); + } + + // Get function `wasm_close_bpf_object` + auto* closeFunc = module->findFuncExports("wasm_close_bpf_object"); + ASSERT_NE(closeFunc, nullptr); + ASSERT_TRUE(closeFunc->isHostFunction()); + auto& closeFuncHost = + dynamic_cast(closeFunc->getHostFunc()); + + // Call "wasm_close_bpf_object" to attach, and check the result + std::array closeResult; + ASSERT_TRUE(closeFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + }, + closeResult)); + ASSERT_EQ(closeResult[0].get(), 0); +} diff --git a/test/plugins/wasm_bpf/wasm_bpf.cpp b/test/plugins/wasm_bpf/wasm_bpf.cpp new file mode 100644 index 00000000..2ac2078b --- /dev/null +++ b/test/plugins/wasm_bpf/wasm_bpf.cpp @@ -0,0 +1,580 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "ast/type.h" +#include "common/defines.h" +#include "executor/executor.h" +#include "func-attach-bpf-program.h" +#include "func-bpf-buffer-poll.h" +#include "func-bpf-map-fd-by-name.h" +#include "func-bpf-map-operate.h" +#include "func-close-bpf-object.h" +#include "func-load-bpf-object.h" +#include "plugin/plugin.h" +#include "runtime/instance/module.h" +#include "wasm-bpf-module.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasm_bpf/" + "libwasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { + if (const auto *Module = Plugin->findModule("wasm_bpf"sv)) { + return Module->create().release(); + } + } + return nullptr; +} + +std::filesystem::path getAssertsPath() { + std::filesystem::path thisFile(__FILE__); + return thisFile.parent_path() / "assets"; +} +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &memInst, + uint32_t offset, uint32_t count, char chr = 0) noexcept { + std::fill_n(memInst.getPointer(offset), count, chr); +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &memInst, + uint32_t offset, const std::vector &data) noexcept { + char *buf = memInst.getPointer(offset); + std::copy(data.begin(), data.end(), buf); +} + +} // namespace + +TEST(WasmBpfTest, Module) { + auto module = dynamic_cast(createModule()); + EXPECT_NE(module, nullptr); + // Test whether functions are exported + EXPECT_EQ(module->getFuncExportNum(), 6U); + EXPECT_NE(module->findFuncExports("wasm_load_bpf_object"), nullptr); + EXPECT_NE(module->findFuncExports("wasm_close_bpf_object"), nullptr); + EXPECT_NE(module->findFuncExports("wasm_attach_bpf_program"), nullptr); + EXPECT_NE(module->findFuncExports("wasm_bpf_buffer_poll"), nullptr); + EXPECT_NE(module->findFuncExports("wasm_bpf_map_fd_by_name"), nullptr); + EXPECT_NE(module->findFuncExports("wasm_bpf_map_operate"), nullptr); + + delete module; +} + +static const size_t TASK_COMM_LEN = 16; +static const size_t MAX_FILENAME_LEN = 127; +struct event { + int pid; + int ppid; + unsigned exit_code; + unsigned long long duration_ns; + char comm[TASK_COMM_LEN]; + char filename[MAX_FILENAME_LEN]; + char exit_event; +}; + +class PollCallbackFunction + : public WasmEdge::Runtime::HostFunction { +public: + PollCallbackFunction() {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + uint32_t __attribute__((unused)) ctx, + uint32_t data, uint32_t data_sz) { + using namespace std; + using WasmEdge::unlikely; + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + if (data_sz < static_cast(sizeof(event))) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + const event *dataPtr = memory->getPointer(data, 1); + if (unlikely(!dataPtr)) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + auto nowTime = chrono::system_clock::to_time_t(chrono::system_clock::now()); + tm nowTimeRepr; + localtime_r(&nowTime, &nowTimeRepr); + if (dataPtr->exit_event == 1) { + cout.setf(ios::left); + cout << std::put_time(&nowTimeRepr, "%H:%M:%S") << " EXIT " << setw(16) + << setfill(' ') << dataPtr->comm << " " << setw(7) << setfill(' ') + << dataPtr->pid << " " << setw(7) << setfill(' ') << dataPtr->ppid + << " [" << dataPtr->exit_code << "]"; + if (dataPtr->duration_ns != 0) { + cout << " (" << dataPtr->duration_ns / 1000000 << ")" << endl; + } + } else { + cout.setf(ios::left); + cout << std::put_time(&nowTimeRepr, "%H:%M:%S") << " EXEC " << setw(16) + << setfill(' ') << dataPtr->comm << " " << setw(7) << setfill(' ') + << dataPtr->pid << " " << setw(7) << setfill(' ') << dataPtr->ppid + << " " << dataPtr->filename << endl; + } + return 0; + } +}; + +TEST(WasmBpfTest, RunBpfProgramWithPolling) { + using namespace std::literals::string_view_literals; + // Test loading and attaching a bpf program, and polling buffer + auto module = dynamic_cast(createModule()); + EXPECT_NE(module, nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); + // moduleInst.addHostFunc() + moduleInst.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *memoryInst = moduleInst.findMemoryExports("memory"); + EXPECT_NE(memoryInst, nullptr); + auto &memoryInstRef = *memoryInst; + WasmEdge::Executor::Executor executor((WasmEdge::Configure())); + WasmEdge::Runtime::CallingFrame CallFrame(&executor, &moduleInst); + + namespace fs = std::filesystem; + auto bpfObject = getAssertsPath() / "bootstrap.bpf.o"; + + // Ensure the bpf object we need exists + EXPECT_TRUE(fs::exists(bpfObject)); + + // Read the bpf object into wasm memory + std::ifstream bpfObjStream(bpfObject); + EXPECT_TRUE(bpfObjStream.is_open()); + EXPECT_TRUE(bpfObjStream.good()); + std::vector bpfObjectBytes( + (std::istreambuf_iterator(bpfObjStream)), + std::istreambuf_iterator()); + EXPECT_FALSE(bpfObjectBytes.empty()); + + // Fill bpf object into memory + const uint32_t bpfObjectMemoryOffset = 1; + fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); + + // Fill `handle_exec`, the bpf function name, into memory + const uint32_t targetHandleExecNameMemoryOffset = + bpfObjectMemoryOffset + static_cast(bpfObjectBytes.size()); + const std::string targetHandleExecName("handle_exec"); + // Zero terminated.. + std::vector targetHandleExecNameBytes(targetHandleExecName.size() + 1, + 0); + std::copy(targetHandleExecName.begin(), targetHandleExecName.end(), + targetHandleExecNameBytes.begin()); + fillMemContent(memoryInstRef, targetHandleExecNameMemoryOffset, + targetHandleExecNameBytes); + + // Fill `handle_exit`, the bpf function name, into memory + const uint32_t targetHandleExitNameMemoryOffset = + targetHandleExecNameMemoryOffset + + static_cast(targetHandleExecNameBytes.size()); + const std::string targetHandleExitName("handle_exit"); + // Zero terminated.. + std::vector targetHandleExitNameBytes(targetHandleExitName.size() + 1, + 0); + std::copy(targetHandleExitName.begin(), targetHandleExitName.end(), + targetHandleExitNameBytes.begin()); + fillMemContent(memoryInstRef, targetHandleExitNameMemoryOffset, + targetHandleExitNameBytes); + + // Fill the map name `rb` + const uint32_t mapNameMemoryOffset = + targetHandleExitNameMemoryOffset + + static_cast(targetHandleExitNameBytes.size()); + const std::string mapName("rb"); + // Zero terminated.. + std::vector mapNameBytes(mapName.size() + 1, 0); + std::copy(mapName.begin(), mapName.end(), mapNameBytes.begin()); + fillMemContent(memoryInstRef, mapNameMemoryOffset, mapNameBytes); + + // Prepare a memory area for storing polled things + const uint32_t bufferPollMemoryOffset = + mapNameMemoryOffset + static_cast(mapNameBytes.size()); + const uint32_t bufferPollSize = 1024; + fillMemContent(memoryInstRef, bufferPollMemoryOffset, bufferPollSize, 0); + + // Get function "wasm_load_bpf_object" + auto *loadFunc = module->findFuncExports("wasm_load_bpf_object"); + EXPECT_NE(loadFunc, nullptr); + EXPECT_TRUE(loadFunc->isHostFunction()); + auto &loadFuncHost = + dynamic_cast(loadFunc->getHostFunc()); + + // call "wasm_load_bpf_object" to Load `bootstrap.bpf.o`, and check the + // result + std::array loadResult; + EXPECT_TRUE(loadFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(bpfObjectMemoryOffset), + WasmEdge::ValVariant(static_cast(bpfObjectBytes.size()))}, + loadResult)); + auto handle = loadResult[0].get(); + EXPECT_NE(handle, 0); + + // Get function `wasm_attach_bpf_program` + auto *attachFunc = module->findFuncExports("wasm_attach_bpf_program"); + EXPECT_NE(attachFunc, nullptr); + EXPECT_TRUE(attachFunc->isHostFunction()); + auto &attachFuncHost = dynamic_cast( + attachFunc->getHostFunc()); + + // Call "wasm_attach_bpf_program" to attach, and check the result + std::array attachResult; + EXPECT_TRUE(attachFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(targetHandleExecNameMemoryOffset), + // There should be '\0' + WasmEdge::ValVariant( + targetHandleExecNameMemoryOffset + + static_cast(targetHandleExecName.size())), + }, + attachResult)); + EXPECT_GE(attachResult[0].get(), 0); + EXPECT_TRUE(attachFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(targetHandleExitNameMemoryOffset), + // There should be '\0' + WasmEdge::ValVariant( + targetHandleExitNameMemoryOffset + + static_cast(targetHandleExitName.size())), + }, + attachResult)); + EXPECT_GE(attachResult[0].get(), 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto *mapFdFunc = module->findFuncExports("wasm_bpf_map_fd_by_name"); + EXPECT_NE(mapFdFunc, nullptr); + EXPECT_TRUE(mapFdFunc->isHostFunction()); + auto &mapFdFuncHost = + dynamic_cast(mapFdFunc->getHostFunc()); + + // Call "wasm_bpf_map_fd_by_name" to get the map fd, and check the result + std::array mapFdResult; + EXPECT_TRUE(mapFdFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(mapNameMemoryOffset)}, + mapFdResult)); + auto mapFd = mapFdResult[0].get(); + EXPECT_GE(mapFd, 0); + + // In the following several steps we will prepare for polling + // Create an instance of the polling callback function + auto callbackFuncInst = + std::make_unique( + &moduleInst, std::make_unique()); + // Create a function table, and fill the callback function into it + auto funcTableInst = + std::make_unique( + WasmEdge::AST::TableType(WasmEdge::RefType::FuncRef, 1)); + EXPECT_TRUE(funcTableInst->setRefs( + std::initializer_list{ + WasmEdge::FuncRef(callbackFuncInst.get())}, + 0, 0, 1)); + // Add the table to the main module + moduleInst.addHostTable("__indirect_function_table"sv, + std::move(funcTableInst)); + + // Get the "wasm_bpf_buffer_poll" function + auto *bufferPollFunc = module->findFuncExports("wasm_bpf_buffer_poll"); + EXPECT_NE(bufferPollFunc, nullptr); + EXPECT_TRUE(bufferPollFunc->isHostFunction()); + auto &bufferPollFuncHost = dynamic_cast( + bufferPollFunc->getHostFunc()); + + // Call the polling function + std::array pollResult; + for (size_t i = 1; i <= 50; i++) { + using namespace std; + EXPECT_TRUE(bufferPollFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), // object handle + WasmEdge::ValVariant(mapFd), // map fd + UINT32_C(0), // callback function index + UINT32_C(0), // Custom context pointer + WasmEdge::ValVariant(bufferPollMemoryOffset), // buffer offset + WasmEdge::ValVariant(bufferPollSize), // buffer size + UINT32_C(100) // timeout (ms) + }, + pollResult)); + EXPECT_GE(pollResult[0].get(), 0); + } + + // Get function `wasm_close_bpf_object` + auto *closeFunc = module->findFuncExports("wasm_close_bpf_object"); + EXPECT_NE(closeFunc, nullptr); + EXPECT_TRUE(closeFunc->isHostFunction()); + auto &closeFuncHost = + dynamic_cast(closeFunc->getHostFunc()); + + // Call "wasm_close_bpf_object" to attach, and check the result + std::array closeResult; + EXPECT_TRUE(closeFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + }, + closeResult)); + EXPECT_EQ(closeResult[0].get(), 0); +} + +static const size_t MAX_SLOTS = 26; + +struct hist { + unsigned int slots[MAX_SLOTS]; + char comm[TASK_COMM_LEN]; +} __attribute__((packed)); + +TEST(WasmBpfTest, RunBpfProgramWithMapOperation) { + // Test loading and attaching a bpf program, and polling buffer + auto module = dynamic_cast(createModule()); + EXPECT_NE(module, nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); + // moduleInst.addHostFunc() + moduleInst.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *memoryInst = moduleInst.findMemoryExports("memory"); + EXPECT_NE(memoryInst, nullptr); + auto &memoryInstRef = *memoryInst; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &moduleInst); + + namespace fs = std::filesystem; + auto bpfObject = getAssertsPath() / "runqlat.bpf.o"; + + // Ensure the bpf object we need exists + EXPECT_TRUE(fs::exists(bpfObject)); + + // Read the bpf object into wasm memory + std::ifstream bpfObjStream(bpfObject); + EXPECT_TRUE(bpfObjStream.is_open()); + EXPECT_TRUE(bpfObjStream.good()); + std::vector bpfObjectBytes( + (std::istreambuf_iterator(bpfObjStream)), + std::istreambuf_iterator()); + EXPECT_FALSE(bpfObjectBytes.empty()); + // Offset to put things into memory + uint32_t nextOffset = 1; + + // Put the bpf object into memory + const uint32_t bpfObjectMemoryOffset = nextOffset; + fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); + nextOffset += static_cast(bpfObjectBytes.size()); + + // Fill strings that will be used into memory + std::array strings = { + "hists", // Map name + "sched_wakeup", "sched_wakeup_new", "sched_switch", // Program names + "" // An empty string + }; + std::array stringOffsets; + + for (size_t i = 0; i < strings.size(); i++) { + std::string currString(strings[i]); + std::vector bytes(currString.begin(), currString.end()); + // Ensure that strings are zero-terminated + bytes.push_back('\0'); + fillMemContent(memoryInstRef, nextOffset, bytes); + stringOffsets[i] = nextOffset; + nextOffset += static_cast(bytes.size()); + } + + // Get function "wasm_load_bpf_object" + auto *loadFunc = module->findFuncExports("wasm_load_bpf_object"); + EXPECT_NE(loadFunc, nullptr); + EXPECT_TRUE(loadFunc->isHostFunction()); + auto &loadFuncHost = + dynamic_cast(loadFunc->getHostFunc()); + + // call "wasm_load_bpf_object" to Load `bootstrap.bpf.o`, and check the + // result + std::array loadResult; + EXPECT_TRUE(loadFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(bpfObjectMemoryOffset), + WasmEdge::ValVariant(static_cast(bpfObjectBytes.size()))}, + loadResult)); + auto handle = loadResult[0].get(); + EXPECT_NE(handle, 0); + + // Get function `wasm_attach_bpf_program` + auto *attachFunc = module->findFuncExports("wasm_attach_bpf_program"); + EXPECT_NE(attachFunc, nullptr); + EXPECT_TRUE(attachFunc->isHostFunction()); + auto &attachFuncHost = dynamic_cast( + attachFunc->getHostFunc()); + std::array programNameIndexes = {1, 2, 3}; + + // Attach the programs + for (size_t index : programNameIndexes) { + std::array attachResult; + EXPECT_TRUE( + attachFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(stringOffsets[index]), + // There should be '\0' + WasmEdge::ValVariant(stringOffsets[4]), + }, + attachResult)); + EXPECT_GE(attachResult[0].get(), 0); + } + + // Get function `wasm_bpf_map_fd_by_name` + auto *mapFdFunc = module->findFuncExports("wasm_bpf_map_fd_by_name"); + EXPECT_NE(mapFdFunc, nullptr); + EXPECT_TRUE(mapFdFunc->isHostFunction()); + auto &mapFdFuncHost = + dynamic_cast(mapFdFunc->getHostFunc()); + + // Call "wasm_bpf_map_fd_by_name" to get the map fd, and check the result + std::array mapFdResult; + EXPECT_TRUE(mapFdFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), WasmEdge::ValVariant(stringOffsets[0])}, + mapFdResult)); + auto histsFd = mapFdResult[0].get(); + EXPECT_GE(histsFd, 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto *mapOptFunc = module->findFuncExports("wasm_bpf_map_operate"); + EXPECT_NE(mapOptFunc, nullptr); + EXPECT_TRUE(mapOptFunc->isHostFunction()); + auto &mapOptFuncHost = + dynamic_cast(mapOptFunc->getHostFunc()); + // A wrapper to call wasm_bpf_map_operate + auto callMapOperate = [&](int32_t fd, int32_t cmd, uint32_t key, + uint32_t value, uint32_t nextKey, + uint64_t flags) -> int32_t { + std::array callResult; + EXPECT_TRUE(mapOptFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(fd), WasmEdge::ValVariant(cmd), + WasmEdge::ValVariant(key), WasmEdge::ValVariant(value), + WasmEdge::ValVariant(nextKey), WasmEdge::ValVariant(flags)}, + callResult)); + return callResult[0].get(); + }; + // Three helper functions that will be used + auto mapGetNextKey = [&](int32_t fd, uint32_t lookupKey, + uint32_t nextKey) -> int32_t { + // lookupKey is the last element -> returns -1 + // lookupKey found -> returns 0, set nextKey + // lookupKey not found -> returns 0, set nextKey to the first key + + return callMapOperate(fd, + 4, // BPF_MAP_GET_NEXT_KEY + lookupKey, 0, nextKey, 0); + }; + auto mapLookupElem = [&](int32_t fd, uint32_t key, + uint32_t valueOut) -> int32_t { + // key found -> returns 0 + // key not found -> returns -1 + return callMapOperate(fd, + 1, // BPF_MAP_LOOKUP_ELEM + key, valueOut, 0, 0); + }; + auto mapDeleteElem = [&](int32_t fd, uint32_t key) -> int32_t { + // key found -> return 0 + // key not found -> returns -1 + return callMapOperate(fd, + 3, // BPF_MAP_DELETE_ELEM + key, 0, 0, 0); + }; + // Three helper functions to make read & write more convenient + auto readU32 = [&](uint32_t offset) -> uint32_t { + const auto *ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + return *ptr; + }; + auto writeU32 = [&](uint32_t offset, uint32_t val) { + auto *ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + *ptr = val; + }; + auto readHistRef = [&](uint32_t offset) -> const hist & { + const auto *ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + return *ptr; + }; + const uint32_t lookUpKeyOffset = nextOffset; + nextOffset += sizeof(uint32_t); + const uint32_t nextKeyOffset = nextOffset; + nextOffset += sizeof(uint32_t); + const uint32_t histOffset = nextOffset; + nextOffset += sizeof(hist); + + // Poll 10 times, with interval 1s + for (size_t i = 1; i <= 10; i++) { + using namespace std; + std::this_thread::sleep_for(std::chrono::seconds(1)); + writeU32(lookUpKeyOffset, static_cast(-2)); + while (mapGetNextKey(histsFd, lookUpKeyOffset, nextKeyOffset) == 0) { + EXPECT_GE(mapLookupElem(histsFd, nextKeyOffset, histOffset), 0); + const auto &histRef = readHistRef(histOffset); + size_t maxIdx = 0; + for (size_t i = 0; i < std::size(histRef.slots); i++) + if (histRef.slots[i] > 0) + maxIdx = i; + for (size_t i = 0; i < maxIdx; i++) { + auto low = UINT64_C(1) << (i); + auto high = (UINT64_C(1) << (i + 1)) - 1; + cout.setf(ios::left); + cout << setw(6) << low << "..." << setw(6) << high << " " << setw(6) + << histRef.slots[i] << endl; + } + writeU32(lookUpKeyOffset, readU32(nextKeyOffset)); + } + writeU32(lookUpKeyOffset, static_cast(-2)); + while (mapGetNextKey(histsFd, lookUpKeyOffset, nextKeyOffset) == 0) { + EXPECT_GE(mapDeleteElem(histsFd, nextKeyOffset), 0); + writeU32(lookUpKeyOffset, readU32(nextKeyOffset)); + } + cout << endl; + } + + // Get function `wasm_close_bpf_object` + auto *closeFunc = module->findFuncExports("wasm_close_bpf_object"); + EXPECT_NE(closeFunc, nullptr); + EXPECT_TRUE(closeFunc->isHostFunction()); + auto &closeFuncHost = + dynamic_cast(closeFunc->getHostFunc()); + + // Call "wasm_close_bpf_object" to attach, and check the result + std::array closeResult; + EXPECT_TRUE(closeFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + }, + closeResult)); + EXPECT_EQ(closeResult[0].get(), 0); +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 23b72baa3ef1b1e05c2a6f206c6df538d9caaa86 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 17 Apr 2023 15:59:45 +0800 Subject: [PATCH 095/623] [Misc] Fix for clang-format-15. Signed-off-by: YiYing He --- plugins/wasi_crypto/utils/hostfunction.h | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_crypto/utils/hostfunction.h b/plugins/wasi_crypto/utils/hostfunction.h index 668e5c61..6d3fe175 100644 --- a/plugins/wasi_crypto/utils/hostfunction.h +++ b/plugins/wasi_crypto/utils/hostfunction.h @@ -42,10 +42,18 @@ template constexpr WasiCryptoExpect cast(uint64_t) noexcept; template struct WasiRawType { using Type = std::underlying_type_t; }; -template <> struct WasiRawType { using Type = uint8_t; }; -template <> struct WasiRawType { using Type = uint16_t; }; -template <> struct WasiRawType { using Type = uint32_t; }; -template <> struct WasiRawType { using Type = uint64_t; }; +template <> struct WasiRawType { + using Type = uint8_t; +}; +template <> struct WasiRawType { + using Type = uint16_t; +}; +template <> struct WasiRawType { + using Type = uint32_t; +}; +template <> struct WasiRawType { + using Type = uint64_t; +}; template using WasiRawTypeT = typename WasiRawType::Type; From 3b80c7431d32d5d64be4ff35a71aa9acf12d53db Mon Sep 17 00:00:00 2001 From: jinser <46820840+jetjinser@users.noreply.github.com> Date: Mon, 24 Apr 2023 20:33:55 +0800 Subject: [PATCH 096/623] [Plugin] Add hostname verification in HTTPS request plugin to fix SSL connection errors (#2425) * [Plugin] Fix wrong error code in log Signed-off-by: Jinser * [Plugin] Add TLS extension for remote hostname verification Signed-off-by: Jinser --------- Signed-off-by: Jinser --- plugins/wasmedge_httpsreq/httpsreqfunc.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/plugins/wasmedge_httpsreq/httpsreqfunc.cpp b/plugins/wasmedge_httpsreq/httpsreqfunc.cpp index aa6d32f2..9fc74643 100644 --- a/plugins/wasmedge_httpsreq/httpsreqfunc.cpp +++ b/plugins/wasmedge_httpsreq/httpsreqfunc.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -90,11 +91,13 @@ Expect WasmEdgeHttpsReqSendData::body(const Runtime::CallingFrame &Frame, SSL_set_fd(Ssl, Sfd); + SSL_set_tlsext_host_name(Ssl, Host); + const int Status = SSL_connect(Ssl); if (Status != 1) { - SSL_get_error(Ssl, Status); + const int Code = SSL_get_error(Ssl, Status); ERR_print_errors_fp(stderr); - spdlog::error("[WasmEdge Httpsreq] SSL_get_error code {}", Status); + spdlog::error("[WasmEdge Httpsreq] SSL_get_error code {}", Code); return Unexpect(ErrCode::Value::HostFuncError); } From 7708cda2a2a9ef65e19f77592643f0a819f5e670 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Tue, 9 May 2023 06:50:28 +0800 Subject: [PATCH 097/623] [CMake] Update the CMake of plug-ins and tests. Signed-off-by: YiYing He --- plugins/wasi_crypto/CMakeLists.txt | 9 -- plugins/wasi_nn/CMakeLists.txt | 130 +----------------- plugins/wasmedge_httpsreq/CMakeLists.txt | 9 -- plugins/wasmedge_process/CMakeLists.txt | 8 -- test/plugins/CMakeLists.txt | 2 +- test/plugins/unittest/CMakeLists.txt | 20 +-- test/plugins/wasi_crypto/CMakeLists.txt | 29 +++- test/plugins/wasi_nn/CMakeLists.txt | 42 ++++-- test/plugins/wasm_bpf/CMakeLists.txt | 34 +++-- test/plugins/wasmedge_httpsreq/CMakeLists.txt | 26 +++- test/plugins/wasmedge_process/CMakeLists.txt | 24 +++- 11 files changed, 140 insertions(+), 193 deletions(-) diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt index ebec934b..70ae211e 100644 --- a/plugins/wasi_crypto/CMakeLists.txt +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -56,14 +56,6 @@ target_compile_options(wasmedgePluginWasiCrypto -DOPENSSL_API_COMPAT=0x10100000L ) -if(CMAKE_SYSTEM_NAME MATCHES "Darwin") - target_link_options(wasmedgePluginWasiCrypto - PUBLIC - -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterC1EPKNS0_6Plugin16PluginDescriptorE - -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterD1Ev - ) -endif() - target_include_directories(wasmedgePluginWasiCrypto PUBLIC $ @@ -73,7 +65,6 @@ target_include_directories(wasmedgePluginWasiCrypto target_link_libraries(wasmedgePluginWasiCrypto PUBLIC - wasmedge_shared OpenSSL::Crypto ) if(WASMEDGE_LINK_PLUGINS_STATIC) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 74517d89..dd5a4a13 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -13,14 +13,6 @@ target_compile_options(wasmedgePluginWasiNN -DWASMEDGE_PLUGIN ) -if(CMAKE_SYSTEM_NAME MATCHES "Darwin") - target_link_options(wasmedgePluginWasiNN - PUBLIC - -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterC1EPKNS0_6Plugin16PluginDescriptorE - -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterD1Ev - ) -endif() - target_include_directories(wasmedgePluginWasiNN PUBLIC $ @@ -39,123 +31,7 @@ else() ) endif() -install(TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) - -# Add backends building flags. -foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) - string(TOLOWER ${BACKEND} BACKEND) - if(BACKEND STREQUAL "openvino") - message(STATUS "WASI-NN: Build OpenVINO backend for WASI-NN") - find_package(InferenceEngine REQUIRED) - add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) - target_link_libraries(wasmedgePluginWasiNN - PUBLIC - ${InferenceEngine_LIBRARIES} - ) - elseif(BACKEND STREQUAL "pytorch") - message(STATUS "WASI-NN: Build PyTorch backend for WASI-NN") - find_package(Torch REQUIRED) - add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) - target_link_libraries(wasmedgePluginWasiNN - PUBLIC - ${TORCH_LIBRARIES} - ) - elseif(BACKEND STREQUAL "tensorflowlite") - message(STATUS "WASI-NN: Build Tensorflow lite backend for WASI-NN") - # TODO: Move these complicated steps into a helper cmake. - add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) - - if(NOT WASMEDGE_DEPS_VERSION) - set(WASMEDGE_DEPS_VERSION "0.11.1") - endif() - - # Clone required shared libraries - if(ANDROID) - if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") - set(WASMEDGE_TENSORFLOW_SYSTEM_NAME "android_aarch64") - set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_HASH "a25dafad049cbc998c1f9682c57aec22b2fe5799eeffdd4ed19793a734cde8a4") - elseif() - message(FATAL_ERROR "Unsupported architecture: ${CMAKE_SYSTEM_PROCESSOR}") - endif() - elseif(APPLE) - if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64") - set(WASMEDGE_TENSORFLOW_SYSTEM_NAME "darwin_x86_64") - set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_HASH "2593772df440a768e79d87e74a860378f46fb0b7d1e7805879ab2ec26a093b57") - else() - message(FATAL_ERROR "Unsupported architecture: ${CMAKE_SYSTEM_PROCESSOR}") - endif() - elseif(UNIX) - if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64") - set(WASMEDGE_TENSORFLOW_SYSTEM_NAME "manylinux2014_x86_64") - set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_HASH "43b2a782efb58b047c6d33f64d7ac711b24426959f91287d910edb8937c11dea") - elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") - set(WASMEDGE_TENSORFLOW_SYSTEM_NAME "manylinux2014_aarch64") - set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_HASH "1f47dcd05f32907848253e0f4b0eb3a6276802dae41d2b7de61214b75ba02395") - else() - message(FATAL_ERROR "Unsupported architecture: ${CMAKE_SYSTEM_PROCESSOR}") - endif() - else() - message(FATAL_ERROR "Unsupported system: ${CMAKE_SYSTEM_NAME}") - endif() - - include(FetchContent) +include(WASINNDeps) +wasmedge_setup_wasinn_target(wasmedgePluginWasiNN) - # Fetch Tensorflow-lite library. - FetchContent_Declare( - wasmedgetensorflowdepslite - URL "https://github.com/second-state/WasmEdge-tensorflow-deps/releases/download/${WASMEDGE_DEPS_VERSION}/WasmEdge-tensorflow-deps-TFLite-${WASMEDGE_DEPS_VERSION}-${WASMEDGE_TENSORFLOW_SYSTEM_NAME}.tar.gz" - URL_HASH "SHA256=${WASMEDGE_TENSORFLOW_DEPS_TFLITE_HASH}" - ) - FetchContent_GetProperties(wasmedgetensorflowdepslite) - - if(NOT wasmedgetensorflowdepslite_POPULATED) - message(STATUS "Downloading dependency: libtensorflowlite") - FetchContent_Populate(wasmedgetensorflowdepslite) - message(STATUS "Downloading dependency: libtensorflowlite - done") - endif() - - # Setup Tensorflow-lite library. - if(APPLE) - set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_LIB - "${wasmedgetensorflowdepslite_SOURCE_DIR}/libtensorflowlite_c.dylib" - ) - elseif(UNIX) - set(WASMEDGE_TENSORFLOW_DEPS_TFLITE_LIB - "${wasmedgetensorflowdepslite_SOURCE_DIR}/libtensorflowlite_c.so" - ) - endif() - - include(FetchContent) - FetchContent_Declare( - wasmedge_tensorflow_deps - GIT_REPOSITORY https://github.com/second-state/WasmEdge-tensorflow-deps.git - GIT_TAG ${WASMEDGE_DEPS_VERSION} - ) - FetchContent_GetProperties(wasmedge_tensorflow_deps) - - if(NOT wasmedge_tensorflow_deps_POPULATED) - message(STATUS "Fetching WasmEdge-tensorflow-dep repository") - FetchContent_Populate(wasmedge_tensorflow_deps) - message(STATUS "Fetching WasmEdge-tensorflow-dep repository - done") - endif() - - set(WASMEDGE_TENSORFLOW_DEPS_PATH ${wasmedge_tensorflow_deps_SOURCE_DIR}) - set(WASMEDGE_TENSORFLOW_DEPS_BIN_PATH ${CMAKE_CURRENT_BINARY_DIR}/utils/WasmEdge-tensorflow-deps) - - message(STATUS "WASI-NN: Set WasmEdge-tensorflow deps source path: ${WASMEDGE_TENSORFLOW_DEPS_PATH}") - message(STATUS "WASI-NN: Set WasmEdge-tensorflow deps binary path: ${WASMEDGE_TENSORFLOW_DEPS_BIN_PATH}") - message(STATUS "WASI-NN: Set WasmEdge-tensorflowlite share path: ${WASMEDGE_TENSORFLOW_DEPS_TFLITE_LIB}") - add_subdirectory(${WASMEDGE_TENSORFLOW_DEPS_PATH} ${WASMEDGE_TENSORFLOW_DEPS_BIN_PATH}) - target_include_directories(wasmedgePluginWasiNN - PUBLIC - ${TENSORFLOW_INCLUDE} - ) - target_link_libraries(wasmedgePluginWasiNN - PUBLIC - ${WASMEDGE_TENSORFLOW_DEPS_TFLITE_LIB} - ) - else() - # Add the other backends here. - message(FATAL_ERROR "WASI-NN: backend ${BACKEND} not found or unimplemented.") - endif() -endforeach() +install(TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasmedge_httpsreq/CMakeLists.txt b/plugins/wasmedge_httpsreq/CMakeLists.txt index 755769df..64affc60 100644 --- a/plugins/wasmedge_httpsreq/CMakeLists.txt +++ b/plugins/wasmedge_httpsreq/CMakeLists.txt @@ -16,15 +16,6 @@ target_compile_options(wasmedgePluginHttpsReq -DWASMEDGE_PLUGIN ) - -if(CMAKE_SYSTEM_NAME MATCHES "Darwin") - target_link_options(wasmedgePluginHttpsReq - PUBLIC - -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterC1EPKNS0_6Plugin16PluginDescriptorE - -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterD1Ev - ) -endif() - target_include_directories(wasmedgePluginHttpsReq PUBLIC $ diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt index 98693eba..0355b905 100644 --- a/plugins/wasmedge_process/CMakeLists.txt +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -13,14 +13,6 @@ target_compile_options(wasmedgePluginWasmEdgeProcess -DWASMEDGE_PLUGIN ) -if(CMAKE_SYSTEM_NAME MATCHES "Darwin") - target_link_options(wasmedgePluginWasmEdgeProcess - PUBLIC - -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterC1EPKNS0_6Plugin16PluginDescriptorE - -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterD1Ev - ) -endif() - target_include_directories(wasmedgePluginWasmEdgeProcess PUBLIC $ diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 2a46e044..c7bbae53 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -27,6 +27,6 @@ if(WASMEDGE_PLUGIN_WASM_BPF) endif() endif() -if(CMAKE_SYSTEM_NAME MATCHES "Linux") +if(CMAKE_SYSTEM_NAME MATCHES "Linux" OR CMAKE_SYSTEM_NAME MATCHES "Darwin") add_subdirectory(unittest) endif() diff --git a/test/plugins/unittest/CMakeLists.txt b/test/plugins/unittest/CMakeLists.txt index 2de4833c..4631c0b4 100644 --- a/test/plugins/unittest/CMakeLists.txt +++ b/test/plugins/unittest/CMakeLists.txt @@ -27,14 +27,6 @@ target_compile_options(wasmedgePluginTestModuleCPP -DWASMEDGE_PLUGIN ) -if(CMAKE_SYSTEM_NAME MATCHES "Darwin") - target_link_options(wasmedgePluginTestModuleCPP - PUBLIC - -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterC1EPKNS0_6Plugin16PluginDescriptorE - -Wl,-U,__ZN8WasmEdge6Plugin14PluginRegisterD1Ev - ) -endif() - target_include_directories(wasmedgePluginTestModuleCPP PUBLIC $ @@ -49,12 +41,12 @@ wasmedge_add_executable(wasmedgePluginUnittestsC target_include_directories(wasmedgePluginUnittestsC PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} + $ ) target_link_libraries(wasmedgePluginUnittestsC PRIVATE ${GTEST_BOTH_LIBRARIES} - wasmedgePlugin ) # The test executable for C++ API @@ -65,12 +57,12 @@ wasmedge_add_executable(wasmedgePluginUnittestsCPP target_include_directories(wasmedgePluginUnittestsCPP PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} + $ ) target_link_libraries(wasmedgePluginUnittestsCPP PRIVATE ${GTEST_BOTH_LIBRARIES} - wasmedgePlugin ) # Link to the WasmEdge library @@ -87,6 +79,10 @@ if(WASMEDGE_LINK_PLUGINS_STATIC) PRIVATE wasmedgeCAPI ) + target_link_libraries(wasmedgePluginUnittestsCPP + PRIVATE + wasmedgeCAPI + ) else() target_link_libraries(wasmedgePluginTestModuleC PRIVATE @@ -100,6 +96,10 @@ else() PRIVATE wasmedge_shared ) + target_link_libraries(wasmedgePluginUnittestsCPP + PRIVATE + wasmedge_shared + ) endif() add_test(wasmedgePluginUnittestsC wasmedgePluginUnittestsC) diff --git a/test/plugins/wasi_crypto/CMakeLists.txt b/test/plugins/wasi_crypto/CMakeLists.txt index acbed2b8..5e874abd 100644 --- a/test/plugins/wasi_crypto/CMakeLists.txt +++ b/test/plugins/wasi_crypto/CMakeLists.txt @@ -14,11 +14,36 @@ wasmedge_add_executable(wasiCryptoTests signatures.cpp ) +add_dependencies(wasiCryptoTests + wasmedgePluginWasiCrypto +) + +target_compile_options(wasiCryptoTests + PUBLIC + -DOPENSSL_API_COMPAT=0x10100000L +) + +target_include_directories(wasiCryptoTests + PUBLIC + $ + $ +) + target_link_libraries(wasiCryptoTests PRIVATE ${GTEST_BOTH_LIBRARIES} - wasmedgePlugin - wasmedgePluginWasiCrypto ) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasiCryptoTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasiCryptoTests + PRIVATE + wasmedge_shared + ) +endif() add_test(wasiCryptoTests wasiCryptoTests) diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index e3d55a9a..5f050951 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -7,7 +7,8 @@ wasmedge_add_executable(wasiNNTests # Prepare the testing data for each backends. foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) - if(BACKEND MATCHES "OpenVINO") + string(TOLOWER ${BACKEND} BACKEND) + if(BACKEND MATCHES "openvino") message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures") execute_process( COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-openvino-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures @@ -25,8 +26,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) if(NOT CHECKSUM_TENSOR STREQUAL "bfca546f4a3b5e6da49b7bd728e2799a") message(FATAL_ERROR "tensor-1x224x224x3-f32.bgr downloaded with wrong md5") endif() - add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) - elseif(BACKEND MATCHES "PyTorch") + elseif(BACKEND MATCHES "pytorch") message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures") execute_process( COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-pytorch-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures @@ -40,9 +40,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) if(NOT CHECKSUM_IMAGE STREQUAL "551caa6f3b66c1d953655228462570a1") message(FATAL_ERROR "image-1x3x224x224.rgb downloaded with wrong md5") endif() - add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) - find_package(Torch REQUIRED) - elseif(BACKEND STREQUAL "Tensorflowlite") + elseif(BACKEND STREQUAL "tensorflowlite") message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures") execute_process( COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-tflite-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures @@ -56,21 +54,39 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) if(NOT CHECKSUM_IMAGE STREQUAL "ad51c39cfe35d2ef35c4052b78cb3c55") message(FATAL_ERROR "downloaded bird.jpg fixture with wrong md5") endif() - add_definitions(-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) - target_include_directories(wasiNNTests - PUBLIC - ${TENSORFLOW_INCLUDE} - ) else() # Add the other backend test files fetching here. endif() endforeach() +add_dependencies(wasiNNTests + wasmedgePluginWasiNN +) + +include(WASINNDeps) +wasmedge_setup_wasinn_target(wasiNNTests) + +target_include_directories(wasiNNTests + PUBLIC + $ + $ +) + target_link_libraries(wasiNNTests PRIVATE ${GTEST_BOTH_LIBRARIES} - wasmedgePlugin - wasmedgePluginWasiNN ) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasiNNTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasiNNTests + PRIVATE + wasmedge_shared + ) +endif() add_test(wasiNNTests wasiNNTests) diff --git a/test/plugins/wasm_bpf/CMakeLists.txt b/test/plugins/wasm_bpf/CMakeLists.txt index 26e9c3fe..735764be 100644 --- a/test/plugins/wasm_bpf/CMakeLists.txt +++ b/test/plugins/wasm_bpf/CMakeLists.txt @@ -7,18 +7,34 @@ wasmedge_add_executable(wasmBpfTests wasm_bpf.cpp ) -target_link_libraries(wasmBpfTests - PRIVATE - ${GTEST_BOTH_LIBRARIES} - wasmedgePlugin - wasmedgePluginWasmBpf - wasmedgeExecutor -) add_subdirectory(assets) -add_dependencies( - wasmBpfTests +add_dependencies(wasmBpfTests + wasmedgePluginWasmBpf wasmBpfTestsAssets ) +target_include_directories(wasmBpfTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmBpfTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmBpfTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmBpfTests + PRIVATE + wasmedge_shared + ) +endif() + add_test(wasmBpfTests wasmBpfTests) diff --git a/test/plugins/wasmedge_httpsreq/CMakeLists.txt b/test/plugins/wasmedge_httpsreq/CMakeLists.txt index 54f6279b..921ee619 100644 --- a/test/plugins/wasmedge_httpsreq/CMakeLists.txt +++ b/test/plugins/wasmedge_httpsreq/CMakeLists.txt @@ -5,11 +5,31 @@ wasmedge_add_executable(wasmedgeHttpsReqTests httpsreq.cpp ) +add_dependencies(wasmedgeHttpsReqTests + wasmedgePluginHttpsReq +) + +target_include_directories(wasmedgeHttpsReqTests + PUBLIC + $ + $ +) + target_link_libraries(wasmedgeHttpsReqTests PRIVATE ${GTEST_BOTH_LIBRARIES} - wasmedgePlugin - wasmedgePluginHttpsReq ) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeHttpsReqTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeHttpsReqTests + PRIVATE + wasmedge_shared + ) +endif() -add_test(wasmedgeHttpsReqTests wasmedgeHttpsReqTests) \ No newline at end of file +add_test(wasmedgeHttpsReqTests wasmedgeHttpsReqTests) diff --git a/test/plugins/wasmedge_process/CMakeLists.txt b/test/plugins/wasmedge_process/CMakeLists.txt index 9f7d1309..ee34f5c1 100644 --- a/test/plugins/wasmedge_process/CMakeLists.txt +++ b/test/plugins/wasmedge_process/CMakeLists.txt @@ -5,11 +5,31 @@ wasmedge_add_executable(wasmedgeProcessTests wasmedge_process.cpp ) +add_dependencies(wasmedgeProcessTests + wasmedgePluginWasmEdgeProcess +) + +target_include_directories(wasmedgeProcessTests + PUBLIC + $ + $ +) + target_link_libraries(wasmedgeProcessTests PRIVATE ${GTEST_BOTH_LIBRARIES} - wasmedgePlugin - wasmedgePluginWasmEdgeProcess ) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeProcessTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeProcessTests + PRIVATE + wasmedge_shared + ) +endif() add_test(wasmedgeProcessTests wasmedgeProcessTests) From a1b9547ac70c7cd361f6da54122f4cb8096aa2ec Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 17 May 2023 16:25:35 +0800 Subject: [PATCH 098/623] [Misc] Bump wasmedge docker image from ubuntu 20.04 to ubuntu 22.04 Signed-off-by: hydai --- utils/docker/Dockerfile.base | 6 +++--- utils/docker/Dockerfile.build-clang | 6 +++--- utils/docker/Dockerfile.ci-image-base | 2 +- utils/docker/Dockerfile.release | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/utils/docker/Dockerfile.base b/utils/docker/Dockerfile.base index fb56f33c..08d1e58a 100644 --- a/utils/docker/Dockerfile.base +++ b/utils/docker/Dockerfile.base @@ -1,4 +1,4 @@ -FROM ubuntu:20.04 +FROM ubuntu:22.04 MAINTAINER hydai hydai@secondstate.io ENV DEBIAN_FRONTEND=noninteractive @@ -13,7 +13,7 @@ RUN apt update && apt upgrade -y \ curl \ git \ libboost-all-dev \ - llvm-12-dev \ - liblld-12-dev + llvm-15-dev \ + liblld-15-dev RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/Dockerfile.build-clang b/utils/docker/Dockerfile.build-clang index eaebfab6..bed2409b 100644 --- a/utils/docker/Dockerfile.build-clang +++ b/utils/docker/Dockerfile.build-clang @@ -2,9 +2,9 @@ ARG BASE=wasmedge/wasmedge:ubuntu-base FROM ${BASE} RUN apt update && apt install -y \ - clang-12 + clang-15 RUN rm -rf /var/lib/apt/lists/* -ENV CC=/usr/bin/clang-12 -ENV CXX=/usr/bin/clang++-12 +ENV CC=/usr/bin/clang-15 +ENV CXX=/usr/bin/clang++-15 diff --git a/utils/docker/Dockerfile.ci-image-base b/utils/docker/Dockerfile.ci-image-base index d381de58..4c7fa368 100644 --- a/utils/docker/Dockerfile.ci-image-base +++ b/utils/docker/Dockerfile.ci-image-base @@ -1,4 +1,4 @@ -FROM ubuntu:20.04 +FROM ubuntu:22.04 MAINTAINER hydai hydai@secondstate.io ENV DEBIAN_FRONTEND=noninteractive diff --git a/utils/docker/Dockerfile.release b/utils/docker/Dockerfile.release index 4b4f51e2..ab05caf5 100644 --- a/utils/docker/Dockerfile.release +++ b/utils/docker/Dockerfile.release @@ -1,4 +1,4 @@ -FROM ubuntu:20.04 +FROM ubuntu:22.04 ARG VERSION RUN apt-get update && apt-get install -y netbase From e711dc2f135133e223f65fd17ce248ff5ad97f66 Mon Sep 17 00:00:00 2001 From: yanghaku <36074633+yanghaku@users.noreply.github.com> Date: Wed, 17 May 2023 18:11:21 +0800 Subject: [PATCH 099/623] [WASI-NN] Add GPU target support for pytorch (#2457) Signed-off-by: yanghaku <1961882079@qq.com> --- plugins/wasi_nn/wasinnenv.h | 3 +++ plugins/wasi_nn/wasinnfunc.cpp | 45 ++++++++++++++++++++++------------ 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 241f0d5f..fb8509d8 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -32,6 +32,8 @@ enum class ErrNo : uint32_t { enum class TensorType : uint8_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; +enum class Device : uint32_t { CPU = 0, GPU = 1, TPU = 2 }; + enum class Backend : uint8_t { OpenVINO = 0, ONNX = 1, @@ -89,6 +91,7 @@ class Graph { #endif #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH torch::jit::Module TorchModel; + torch::DeviceType TorchDevice = at::kCPU; #endif #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE TfLiteModel *TFLiteMod = nullptr; diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 9087cfeb..a3673496 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -26,18 +26,18 @@ namespace WasmEdge { namespace Host { namespace { -[[maybe_unused]] std::string FindDevice(const uint32_t Target) { +[[maybe_unused]] std::string findDevice(const WASINN::Device Target) { std::string DeviceName; switch (Target) { - case 0: + case WASINN::Device::CPU: DeviceName = "CPU"; break; - // case 1: - // DeviceName = "GPU"; - // break; - // case 2: - // DeviceName = "TPU"; - // break; + case WASINN::Device::GPU: + DeviceName = "GPU"; + break; + case WASINN::Device::TPU: + DeviceName = "TPU"; + break; default: DeviceName = ""; } @@ -63,10 +63,12 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, return static_cast(WASINN::ErrNo::InvalidArgument); } // Get and check the device name string. - std::string DeviceName; - DeviceName = FindDevice(Target); - if (unlikely(DeviceName.length() == 0)) { - spdlog::error("[WASI-NN] Only support CPU target"); + const auto Device = static_cast(Target); + const std::string DeviceName = findDevice(Device); + if (unlikely(DeviceName.length() == 0 && + (Encoding != static_cast(WASINN::Backend::PyTorch) || + Device != WASINN::Device::GPU))) { + spdlog::error("[WASI-NN] Only support CPU target and Pytorch GPU target."); return static_cast(WASINN::ErrNo::InvalidArgument); } spdlog::debug("[WASI-NN] Using device: {:s}", DeviceName); @@ -274,12 +276,23 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, // Add a new graph. Env.NNGraph.emplace_back(static_cast(Encoding)); auto &Graph = Env.NNGraph.back(); + // Setup Graph Device + if (Device == WASINN::Device::GPU) { + if (torch::cuda::is_available()) { + Graph.TorchDevice = at::kCUDA; + } else { + spdlog::error("[WASI-NN] Platform Cannot support GPU target."); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + } + std::string BinString((char *)BinPtr, BinLen); std::stringstream BinRead; BinRead.str(BinString); try { Graph.TorchModel = torch::jit::load(BinRead); + Graph.TorchModel.to(Graph.TorchDevice); } catch (const c10::Error &e) { spdlog::error("[WASI-NN] Failed when load the TorchScript model."); Env.NNGraph.pop_back(); @@ -631,8 +644,10 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, for (size_t I = 0; I < DimensionLen; I++) { Dims.push_back(static_cast(DimensionBuf[I])); } - torch::Tensor InTensor = torch::from_blob( - reinterpret_cast(TensorDataBuf), Dims, Options); + torch::Tensor InTensor = + torch::from_blob(reinterpret_cast(TensorDataBuf), Dims, + Options) + .to(CxtRef.GraphRef.TorchDevice); CxtRef.TorchInputs[Index] = InTensor.clone(); return static_cast(WASINN::ErrNo::Success); @@ -821,7 +836,7 @@ WasiNNGetOuput::body(const Runtime::CallingFrame &Frame, uint32_t Context, return static_cast(WASINN::ErrNo::InvalidArgument); } torch::Tensor OutTensor = - CxtRef.TorchOutputs[Index].toType(torch::kFloat32); + CxtRef.TorchOutputs[Index].to(at::kCPU).toType(torch::kFloat32); float *TensorBuffer = OutTensor.data_ptr(); size_t BlobSize = 1; From f222e6237ce37b443f78983708135cce69a2a863 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Thu, 18 May 2023 01:50:38 +0800 Subject: [PATCH 100/623] [WASI-NN] Use right TensorType in error message Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_nn/wasinnfunc.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index a3673496..eef488fc 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -13,13 +13,14 @@ #endif #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH -#include +#include #include #endif #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE #include "tensorflow/lite/c/c_api.h" +#include "tensorflow/lite/c/common.h" #endif namespace WasmEdge { @@ -692,7 +693,7 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, auto *HoldTensor = TfLiteInterpreterGetInputTensor(CxtRef.TFLiteInterp, Index); WASINN::TensorType LiteType; - switch (TfLiteTensorType(HoldTensor)) { + switch (const auto Type = TfLiteTensorType(HoldTensor)) { case TfLiteType::kTfLiteUInt8: LiteType = WASINN::TensorType::U8; break; @@ -706,7 +707,8 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, LiteType = WASINN::TensorType::I32; break; default: - spdlog::error("[WASI-NN] Unsupported TFLite type: {}", LiteType); + spdlog::error("[WASI-NN] Unsupported TFLite type: {}", + TfLiteTypeGetName(Type)); return static_cast(WASINN::ErrNo::InvalidArgument); } From e466c0dc0cdd1e5b3d7690b2fb006943eb353ffe Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Sat, 20 May 2023 04:58:50 +0800 Subject: [PATCH 101/623] [WASI-CRYPTO] Fix CI Fail * OpenSSL 3.0 didn't implement context duplication for aes-gcm and chacha20. https://github.com/openssl/openssl/issues/20978 * Change CRLF line ending to unix LF * Set buffer size for correct API usaging Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_crypto/symmetric/aeads.cpp | 9 +- plugins/wasi_crypto/symmetric/hash.cpp | 188 ++++++++-------- plugins/wasi_crypto/symmetric/hash.h | 278 ++++++++++++------------ plugins/wasi_crypto/symmetric/kdf.cpp | 21 +- plugins/wasi_crypto/symmetric/mac.cpp | 7 +- plugins/wasi_crypto/utils/evp_wrapper.h | 28 ++- test/plugins/wasi_crypto/aeads.cpp | 14 +- 7 files changed, 285 insertions(+), 260 deletions(-) diff --git a/plugins/wasi_crypto/symmetric/aeads.cpp b/plugins/wasi_crypto/symmetric/aeads.cpp index 99a09d06..08ba7658 100644 --- a/plugins/wasi_crypto/symmetric/aeads.cpp +++ b/plugins/wasi_crypto/symmetric/aeads.cpp @@ -211,8 +211,15 @@ Cipher::State::decryptImpl(Span Out, template WasiCryptoExpect::State> Cipher::State::clone() const noexcept { - EvpCipherCtxPtr CloneCtx{EVP_CIPHER_CTX_new()}; + // XXX: These cipher didn't implement context duplication from OpenSSL 3.0.0 + // https://github.com/openssl/openssl/issues/20978 + if (0x30000000 <= OPENSSL_VERSION_NUMBER && + (CipherNid == NID_aes_128_gcm || CipherNid == NID_aes_256_gcm || + CipherNid == NID_chacha20_poly1305)) { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + } + EvpCipherCtxPtr CloneCtx{EVP_CIPHER_CTX_new()}; { std::scoped_lock Lock{Ctx->Mutex}; opensslCheck(EVP_CIPHER_CTX_copy(CloneCtx.get(), Ctx->RawCtx.get())); diff --git a/plugins/wasi_crypto/symmetric/hash.cpp b/plugins/wasi_crypto/symmetric/hash.cpp index dd283f34..ea2129b5 100644 --- a/plugins/wasi_crypto/symmetric/hash.cpp +++ b/plugins/wasi_crypto/symmetric/hash.cpp @@ -1,94 +1,94 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#include "symmetric/hash.h" -#include "utils/evp_wrapper.h" - -#include - -namespace WasmEdge { -namespace Host { -namespace WasiCrypto { -namespace Symmetric { - -template constexpr size_t Sha2::getDigestSize() noexcept { - static_assert(ShaNid == NID_sha256 || ShaNid == NID_sha512 || - ShaNid == NID_sha512_256); - if constexpr (ShaNid == NID_sha256) - return 32; - if constexpr (ShaNid == NID_sha512) - return 64; - if constexpr (ShaNid == NID_sha512_256) - return 32; -} - -template -WasiCryptoExpect::State> -Sha2::State::open(OptionalRef) noexcept { - EvpMdCtxPtr Ctx{EVP_MD_CTX_new()}; - opensslCheck(EVP_DigestInit(Ctx.get(), EVP_get_digestbynid(ShaNid))); - return Ctx; -} - -template -WasiCryptoExpect -Sha2::State::absorb(Span Data) noexcept { - std::unique_lock Lock{Ctx->Mutex}; - opensslCheck(EVP_DigestUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); - return {}; -} - -template -WasiCryptoExpect -Sha2::State::squeeze(Span Out) noexcept { - ensureOrReturn(getDigestSize() >= Out.size(), - __WASI_CRYPTO_ERRNO_INVALID_LENGTH); - - EvpMdCtxPtr CopyCtx{EVP_MD_CTX_new()}; - - { - std::shared_lock Lock{Ctx->Mutex}; - opensslCheck(EVP_MD_CTX_copy_ex(CopyCtx.get(), Ctx->RawCtx.get())); - } - - if (getDigestSize() == Out.size()) { - unsigned int Size; - opensslCheck(EVP_DigestFinal_ex(CopyCtx.get(), Out.data(), &Size)); - ensureOrReturn(Size == getDigestSize(), - __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); - } - - if (getDigestSize() > Out.size()) { - unsigned int Size; - std::array Cache; - opensslCheck(EVP_DigestFinal_ex(CopyCtx.get(), Cache.data(), &Size)); - ensureOrReturn(Size == getDigestSize(), - __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); - std::copy(Cache.begin(), Cache.begin() + static_cast(Out.size()), - Out.data()); - } - - return {}; -} - -template -WasiCryptoExpect::State> -Sha2::State::clone() const noexcept { - EvpMdCtxPtr CloneCtx{EVP_MD_CTX_new()}; - - { - std::shared_lock Lock{Ctx->Mutex}; - opensslCheck(EVP_MD_CTX_copy_ex(CloneCtx.get(), Ctx->RawCtx.get())); - } - - return CloneCtx; -} - -template class Sha2; -template class Sha2; -template class Sha2; - -} // namespace Symmetric -} // namespace WasiCrypto -} // namespace Host -} // namespace WasmEdge +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "symmetric/hash.h" +#include "utils/evp_wrapper.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +template constexpr size_t Sha2::getDigestSize() noexcept { + static_assert(ShaNid == NID_sha256 || ShaNid == NID_sha512 || + ShaNid == NID_sha512_256); + if constexpr (ShaNid == NID_sha256) + return 32; + if constexpr (ShaNid == NID_sha512) + return 64; + if constexpr (ShaNid == NID_sha512_256) + return 32; +} + +template +WasiCryptoExpect::State> +Sha2::State::open(OptionalRef) noexcept { + EvpMdCtxPtr Ctx{EVP_MD_CTX_new()}; + opensslCheck(EVP_DigestInit(Ctx.get(), EVP_get_digestbynid(ShaNid))); + return Ctx; +} + +template +WasiCryptoExpect +Sha2::State::absorb(Span Data) noexcept { + std::unique_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_DigestUpdate(Ctx->RawCtx.get(), Data.data(), Data.size())); + return {}; +} + +template +WasiCryptoExpect +Sha2::State::squeeze(Span Out) noexcept { + ensureOrReturn(getDigestSize() >= Out.size(), + __WASI_CRYPTO_ERRNO_INVALID_LENGTH); + + EvpMdCtxPtr CopyCtx{EVP_MD_CTX_new()}; + + { + std::shared_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_MD_CTX_copy_ex(CopyCtx.get(), Ctx->RawCtx.get())); + } + + if (getDigestSize() == Out.size()) { + unsigned int Size; + opensslCheck(EVP_DigestFinal_ex(CopyCtx.get(), Out.data(), &Size)); + ensureOrReturn(Size == getDigestSize(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } + + if (getDigestSize() > Out.size()) { + unsigned int Size; + std::array Cache; + opensslCheck(EVP_DigestFinal_ex(CopyCtx.get(), Cache.data(), &Size)); + ensureOrReturn(Size == getDigestSize(), + __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + std::copy(Cache.begin(), Cache.begin() + static_cast(Out.size()), + Out.data()); + } + + return {}; +} + +template +WasiCryptoExpect::State> +Sha2::State::clone() const noexcept { + EvpMdCtxPtr CloneCtx{EVP_MD_CTX_new()}; + + { + std::shared_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_MD_CTX_copy_ex(CloneCtx.get(), Ctx->RawCtx.get())); + } + + return CloneCtx; +} + +template class Sha2; +template class Sha2; +template class Sha2; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/hash.h b/plugins/wasi_crypto/symmetric/hash.h index 49288f22..66934da5 100644 --- a/plugins/wasi_crypto/symmetric/hash.h +++ b/plugins/wasi_crypto/symmetric/hash.h @@ -1,139 +1,139 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -//===-- wasmedge/plugins/wasi_crypto/symmetric/hash.h - Hash related ------===// -// -// Part of the WasmEdge Project. -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file contains the declaration of Hash and related classes. -/// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "symmetric/options.h" -#include "symmetric/tag.h" -#include "utils/evp_wrapper.h" -#include "utils/optional.h" - -#include - -namespace WasmEdge { -namespace Host { -namespace WasiCrypto { -namespace Symmetric { - -/// Hash never have key, just a placement, every hash key should inherent from -/// this class. -/// -/// More detailed: -/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#hash-functions -template class HashKey { -public: - static WasiCryptoExpect import(Span) noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED); - } - - static WasiCryptoExpect generate(OptionalRef) noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED); - } - - SecretVec exportData() const noexcept { assumingUnreachable(); } -}; - -/// Hash invalid operations, every hash state should inherent from this class. -template class HashState { -public: - /// Current hash not support any options. - WasiCryptoExpect optionsGet(std::string_view, - Span) const noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); - } - - /// Current hash not support any options. - WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); - } - - WasiCryptoExpect ratchet() noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); - } - - WasiCryptoExpect encrypt(Span, - Span) noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); - } - - WasiCryptoExpect encryptDetached(Span, - Span) noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); - } - - WasiCryptoExpect decrypt(Span, - Span) noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); - } - - WasiCryptoExpect decryptDetached(Span, Span, - Span) noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); - } - - WasiCryptoExpect squeezeKey() noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); - } - - WasiCryptoExpect squeezeTag() noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); - } - - WasiCryptoExpect maxTagLen() const noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); - } -}; - -template class Sha2 { -public: - /// In fact, sha2 key will never produce. This design is for removing the - /// forwarding declaration. - class Key : public HashKey {}; - - class State : public HashState { - public: - State(EvpMdCtxPtr Ctx) noexcept - : Ctx(std::make_shared(std::move(Ctx))) {} - - static WasiCryptoExpect - open(OptionalRef OptOption) noexcept; - - WasiCryptoExpect absorb(Span Data) noexcept; - - WasiCryptoExpect squeeze(Span Out) noexcept; - - WasiCryptoExpect clone() const noexcept; - - private: - struct Inner { - Inner(EvpMdCtxPtr Ctx) noexcept : RawCtx(std::move(Ctx)) {} - EvpMdCtxPtr RawCtx; - std::shared_mutex Mutex; - }; - std::shared_ptr Ctx; - }; - -private: - /// Return the sha digest size. - constexpr static size_t getDigestSize() noexcept; -}; - -using Sha256 = Sha2; -using Sha512 = Sha2; -using Sha512_256 = Sha2; - -} // namespace Symmetric -} // namespace WasiCrypto -} // namespace Host -} // namespace WasmEdge +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +//===-- wasmedge/plugins/wasi_crypto/symmetric/hash.h - Hash related ------===// +// +// Part of the WasmEdge Project. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declaration of Hash and related classes. +/// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "symmetric/options.h" +#include "symmetric/tag.h" +#include "utils/evp_wrapper.h" +#include "utils/optional.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasiCrypto { +namespace Symmetric { + +/// Hash never have key, just a placement, every hash key should inherent from +/// this class. +/// +/// More detailed: +/// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#hash-functions +template class HashKey { +public: + static WasiCryptoExpect import(Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED); + } + + static WasiCryptoExpect generate(OptionalRef) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED); + } + + SecretVec exportData() const noexcept { assumingUnreachable(); } +}; + +/// Hash invalid operations, every hash state should inherent from this class. +template class HashState { +public: + /// Current hash not support any options. + WasiCryptoExpect optionsGet(std::string_view, + Span) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + /// Current hash not support any options. + WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); + } + + WasiCryptoExpect ratchet() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect encryptDetached(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decrypt(Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect decryptDetached(Span, Span, + Span) noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeKey() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect squeezeTag() noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } + + WasiCryptoExpect maxTagLen() const noexcept { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_OPERATION); + } +}; + +template class Sha2 { +public: + /// In fact, sha2 key will never produce. This design is for removing the + /// forwarding declaration. + class Key : public HashKey {}; + + class State : public HashState { + public: + State(EvpMdCtxPtr Ctx) noexcept + : Ctx(std::make_shared(std::move(Ctx))) {} + + static WasiCryptoExpect + open(OptionalRef OptOption) noexcept; + + WasiCryptoExpect absorb(Span Data) noexcept; + + WasiCryptoExpect squeeze(Span Out) noexcept; + + WasiCryptoExpect clone() const noexcept; + + private: + struct Inner { + Inner(EvpMdCtxPtr Ctx) noexcept : RawCtx(std::move(Ctx)) {} + EvpMdCtxPtr RawCtx; + std::shared_mutex Mutex; + }; + std::shared_ptr Ctx; + }; + +private: + /// Return the sha digest size. + constexpr static size_t getDigestSize() noexcept; +}; + +using Sha256 = Sha2; +using Sha512 = Sha2; +using Sha512_256 = Sha2; + +} // namespace Symmetric +} // namespace WasiCrypto +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/symmetric/kdf.cpp b/plugins/wasi_crypto/symmetric/kdf.cpp index f9c09d6e..62f43c2f 100644 --- a/plugins/wasi_crypto/symmetric/kdf.cpp +++ b/plugins/wasi_crypto/symmetric/kdf.cpp @@ -110,15 +110,18 @@ Hkdf::Extract::State::absorb(Span Data) noexcept { template WasiCryptoExpect::Expand::Key> Hkdf::Extract::State::squeezeKey() noexcept { - std::scoped_lock Lock{Ctx->Mutex}; - - opensslCheck(EVP_PKEY_CTX_set1_hkdf_salt(Ctx->RawCtx.get(), Ctx->Salt.data(), - Ctx->Salt.size())); - - SecretVec Data(getKeySize()); - - size_t ActualOutSize; - opensslCheck(EVP_PKEY_derive(Ctx->RawCtx.get(), Data.data(), &ActualOutSize)); + { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck(EVP_PKEY_CTX_set1_hkdf_salt( + Ctx->RawCtx.get(), Ctx->Salt.data(), Ctx->Salt.size())); + } + size_t ActualOutSize = getKeySize(); + SecretVec Data(ActualOutSize); + { + std::scoped_lock Lock{Ctx->Mutex}; + opensslCheck( + EVP_PKEY_derive(Ctx->RawCtx.get(), Data.data(), &ActualOutSize)); + } ensureOrReturn(ActualOutSize == getKeySize(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); diff --git a/plugins/wasi_crypto/symmetric/mac.cpp b/plugins/wasi_crypto/symmetric/mac.cpp index d6349cf2..53c7e2ee 100644 --- a/plugins/wasi_crypto/symmetric/mac.cpp +++ b/plugins/wasi_crypto/symmetric/mac.cpp @@ -62,9 +62,8 @@ Hmac::State::absorb(Span Data) noexcept { template WasiCryptoExpect Hmac::State::squeezeTag() noexcept { - SecretVec Res(getKeySize()); - - size_t ActualOutSize; + size_t ActualOutSize = getKeySize(); + SecretVec Res(ActualOutSize); { std::scoped_lock Lock{Ctx->Mutex}; opensslCheck( @@ -95,4 +94,4 @@ template class Hmac; } // namespace Symmetric } // namespace WasiCrypto } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasi_crypto/utils/evp_wrapper.h b/plugins/wasi_crypto/utils/evp_wrapper.h index af4d0994..67b4ecfb 100644 --- a/plugins/wasi_crypto/utils/evp_wrapper.h +++ b/plugins/wasi_crypto/utils/evp_wrapper.h @@ -60,26 +60,30 @@ using RsaPtr = OpenSSLUniquePtr; #define opensslCheck(Cond) \ do { \ if (!(Cond)) { \ + using namespace std::literals; \ ERR_print_errors_cb( \ - [](const char *_Str, size_t, void *) { \ - spdlog::error(_Str); \ + [](const char *ErrStr, size_t ErrLen, void *) { \ + spdlog::error("{}"sv, std::string_view(ErrStr, ErrLen)); \ return 1; \ }, \ nullptr); \ return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); \ } \ - } while (0) + } while (false) #else #define opensslCheck(Cond) \ - (static_cast(Cond) \ - ? static_cast(0) \ - : (ERR_print_errors_cb( \ - [](const char *_Str, size_t, void *) { \ - spdlog::error(_Str); \ - return 1; \ - }, \ - nullptr), \ - OPENSSL_die("assertion failed: " #Cond, __FILE__, __LINE__))) + do { \ + if (!(Cond)) { \ + using namespace std::literals; \ + ERR_print_errors_cb( \ + [](const char *ErrStr, size_t ErrLen, void *) { \ + spdlog::error("{}"sv, std::string_view(ErrStr, ErrLen)); \ + return 1; \ + }, \ + nullptr); \ + OPENSSL_die("assertion failed: " #Cond, __FILE__, __LINE__); \ + } \ + } while (false) #endif /// OpenSSL encoding parse api is too confusing, simplify them. diff --git a/test/plugins/wasi_crypto/aeads.cpp b/test/plugins/wasi_crypto/aeads.cpp index 8d449adc..eefc21c7 100644 --- a/test/plugins/wasi_crypto/aeads.cpp +++ b/test/plugins/wasi_crypto/aeads.cpp @@ -95,12 +95,24 @@ TEST_F(WasiCryptoTest, Aeads) { __WASI_CRYPTO_ERRNO_INVALID_OPERATION); } - { + bool IsSymmetricStateCloneImplemented = true; + // XXX: These cipher didn't implement context duplication from OpenSSL 3.0.0 + // https://github.com/openssl/openssl/issues/20978 + if (0x30000000 <= OPENSSL_VERSION_NUMBER && + (Name == "AES-128-GCM"sv || Name == "AES-256-GCM"sv || + Name == "CHACHA20-POLY1305"sv)) { + IsSymmetricStateCloneImplemented = false; + } + + if (IsSymmetricStateCloneImplemented) { // Clone checking. WASI_CRYPTO_EXPECT_SUCCESS(NewStateHandle, symmetricStateClone(State4Handle)); EXPECT_NE(State4Handle, NewStateHandle); WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(NewStateHandle)); + } else { + WASI_CRYPTO_EXPECT_FAILURE(symmetricStateClone(State4Handle), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); } WASI_CRYPTO_EXPECT_TRUE(symmetricStateClose(State4Handle)); From 6828d5c30ecc5b626af4a26b0b1dba57a4869d15 Mon Sep 17 00:00:00 2001 From: yanghaku <36074633+yanghaku@users.noreply.github.com> Date: Tue, 23 May 2023 20:53:24 +0800 Subject: [PATCH 102/623] [WASI-NN] tensorflow lite backend - fix set_input and get_output bugs when checking (#2360) Signed-off-by: yanghaku <1961882079@qq.com> --- plugins/wasi_nn/wasinnenv.h | 6 +++-- plugins/wasi_nn/wasinnfunc.cpp | 44 +++++++++++++++++++++++++------- test/plugins/wasi_nn/wasi_nn.cpp | 2 +- 3 files changed, 40 insertions(+), 12 deletions(-) diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index fb8509d8..a03b591c 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -26,8 +26,10 @@ namespace WASINN { enum class ErrNo : uint32_t { Success = 0, // No error occurred. InvalidArgument = 1, // Caller module passed an invalid argument. - MissingMemory = 2, // Caller module is missing a memory export. - Busy = 3 // Device or resource busy. + InvalidEncoding = 2, // Invalid encoding. + MissingMemory = 3, // Caller module is missing a memory export. + Busy = 4, // Device or resource busy. + RuntimeError = 5, // Runtime Error. }; enum class TensorType : uint8_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index eef488fc..b744087f 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -349,6 +349,7 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, #endif } else { spdlog::error("[WASI-NN] Current backend is not supported."); + return static_cast(WASINN::ErrNo::InvalidEncoding); } return static_cast(WASINN::ErrNo::InvalidArgument); } @@ -416,7 +417,7 @@ Expect WasiNNInitExecCtx::body(const Runtime::CallingFrame &Frame, } Env.NNContext.emplace_back(Env.NNGraph[GraphId]); - const auto Graph = Env.NNGraph[GraphId]; + const auto &Graph = Env.NNGraph[GraphId]; auto &NewContext = Env.NNContext.back(); auto *TFLiteOps = TfLiteInterpreterOptionsCreate(); TfLiteInterpreterOptionsSetNumThreads(TFLiteOps, 2); @@ -671,12 +672,8 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, return static_cast(WASINN::ErrNo::InvalidArgument); } uint32_t DimensionLen = Tensor[1]; - std::vector TFDimension(DimensionLen); uint32_t *DimensionBuf = MemInst->getPointer(Tensor[0], DimensionLen); - for (uint32_t I = 0; I < DimensionLen; I++) { - TFDimension.push_back(static_cast(DimensionBuf[I])); - } if (unlikely(DimensionBuf == nullptr)) { spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."); return static_cast(WASINN::ErrNo::InvalidArgument); @@ -689,9 +686,33 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, return static_cast(WASINN::ErrNo::InvalidArgument); } - WASINN::TensorType RType = static_cast(Tensor[2]); auto *HoldTensor = TfLiteInterpreterGetInputTensor(CxtRef.TFLiteInterp, Index); + // Check the input data size. + const auto HoldTensorByteSize = TfLiteTensorByteSize(HoldTensor); + if (HoldTensorByteSize != static_cast(TensorDataLen)) { + spdlog::error("[WASI-NN] Expect tensor byte size {}, but got {}", + HoldTensorByteSize, TensorDataLen); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + // Check the input tensor dimensions. + const auto HoldTensorNumDims = TfLiteTensorNumDims(HoldTensor); + if (static_cast(HoldTensorNumDims) != DimensionLen) { + spdlog::error( + "[WASI-NN] Expect tensor number of dimensions {}, but got {}", + HoldTensorNumDims, DimensionLen); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + for (uint32_t I = 0; I < DimensionLen; I++) { + const auto HoldTensorDim = TfLiteTensorDim(HoldTensor, I); + if (static_cast(HoldTensorDim) != DimensionBuf[I]) { + spdlog::error("[WASI-NN] Expect tensor dimension[{}] = {}, but got {}", + I, HoldTensorDim, DimensionBuf[I]); + return static_cast(WASINN::ErrNo::InvalidArgument); + } + } + // Check the input tensor type. + WASINN::TensorType RType = static_cast(Tensor[2]); WASINN::TensorType LiteType; switch (const auto Type = TfLiteTensorType(HoldTensor)) { case TfLiteType::kTfLiteUInt8: @@ -709,7 +730,7 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, default: spdlog::error("[WASI-NN] Unsupported TFLite type: {}", TfLiteTypeGetName(Type)); - return static_cast(WASINN::ErrNo::InvalidArgument); + return static_cast(WASINN::ErrNo::RuntimeError); } if (unlikely(LiteType != RType)) { @@ -880,8 +901,13 @@ WasiNNGetOuput::body(const Runtime::CallingFrame &Frame, uint32_t Context, } const TfLiteTensor *HoldTensor = TfLiteInterpreterGetOutputTensor(CxtRef.TFLiteInterp, Index); - const uint32_t BlobSize = TfLiteTensorByteSize(HoldTensor); - uint32_t BytesToWrite = std::min(BlobSize, OutBufferMaxSize); + const uint32_t BytesToWrite = TfLiteTensorByteSize(HoldTensor); + // Check out buffer max size. + if (OutBufferMaxSize < BytesToWrite) { + spdlog::error("[WASI-NN] Expect out buffer max size {}, but got {}", + BytesToWrite, OutBufferMaxSize); + return static_cast(WASINN::ErrNo::InvalidArgument); + } uint8_t *OutBuffer = MemInst->getPointer(OutBufferPtr, BytesToWrite); if (unlikely(OutBuffer == nullptr)) { diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index f5a481ca..39d2484a 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -873,7 +873,7 @@ TEST(WasiNNTest, TFLiteBackend) { readEntireFile("./wasinn_tflite_fixtures/" "lite-model_aiy_vision_classifier_birds_V1_3.tflite"); spdlog::info("Read {}", TensorData.size()); - std::vector TensorDim{1, 3, 224, 224}; + std::vector TensorDim{1, 224, 224, 3}; uint32_t BuilderPtr = UINT32_C(0); uint32_t LoadEntryPtr = UINT32_C(0); uint32_t SetInputEntryPtr = UINT32_C(0); From 3776e237a92a5314c8681d155d7668188bfc32fd Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 19 May 2023 05:49:38 +0000 Subject: [PATCH 103/623] [Plugin] Deprecate wasmedge-httpsreq Signed-off-by: hydai --- plugins/CMakeLists.txt | 11 +- plugins/wasmedge_httpsreq/CMakeLists.txt | 42 ----- plugins/wasmedge_httpsreq/httpsreqbase.h | 25 --- plugins/wasmedge_httpsreq/httpsreqenv.cpp | 39 ----- plugins/wasmedge_httpsreq/httpsreqenv.h | 23 --- plugins/wasmedge_httpsreq/httpsreqfunc.cpp | 143 ---------------- plugins/wasmedge_httpsreq/httpsreqfunc.h | 37 ---- plugins/wasmedge_httpsreq/httpsreqmodule.cpp | 22 --- plugins/wasmedge_httpsreq/httpsreqmodule.h | 24 --- test/plugins/CMakeLists.txt | 6 - test/plugins/wasmedge_httpsreq/CMakeLists.txt | 35 ---- test/plugins/wasmedge_httpsreq/httpsreq.cpp | 159 ------------------ 12 files changed, 1 insertion(+), 565 deletions(-) delete mode 100644 plugins/wasmedge_httpsreq/CMakeLists.txt delete mode 100644 plugins/wasmedge_httpsreq/httpsreqbase.h delete mode 100644 plugins/wasmedge_httpsreq/httpsreqenv.cpp delete mode 100644 plugins/wasmedge_httpsreq/httpsreqenv.h delete mode 100644 plugins/wasmedge_httpsreq/httpsreqfunc.cpp delete mode 100644 plugins/wasmedge_httpsreq/httpsreqfunc.h delete mode 100644 plugins/wasmedge_httpsreq/httpsreqmodule.cpp delete mode 100644 plugins/wasmedge_httpsreq/httpsreqmodule.h delete mode 100644 test/plugins/wasmedge_httpsreq/CMakeLists.txt delete mode 100644 test/plugins/wasmedge_httpsreq/httpsreq.cpp diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 83c6d793..015ea804 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -9,7 +9,7 @@ if(WASMEDGE_PLUGIN_WASI_CRYPTO) add_subdirectory(wasi_crypto) endif() -if(WASMEDGE_PLUGIN_PROCESS) +if(WASMEDGE_PLUGIN_PROCESS) # Only Linux systems support wasmedge_process now. if(CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasmedge_process) @@ -18,15 +18,6 @@ if(WASMEDGE_PLUGIN_PROCESS) endif() endif() -if(WASMEDGE_PLUGIN_HTTPSREQ) - # Only Linux systems support wasmedge_httpsreq now. - if(CMAKE_SYSTEM_NAME MATCHES "Linux") - add_subdirectory(wasmedge_httpsreq) - else() - message(WARNING "Only Linux platforms support WasmEdge_HttpsReq plug-in now.") - endif() -endif() - if(WASMEDGE_PLUGIN_WASM_BPF) # Only Linux systems support wasm_bpf now. if(CMAKE_SYSTEM_NAME MATCHES "Linux") diff --git a/plugins/wasmedge_httpsreq/CMakeLists.txt b/plugins/wasmedge_httpsreq/CMakeLists.txt deleted file mode 100644 index 64affc60..00000000 --- a/plugins/wasmedge_httpsreq/CMakeLists.txt +++ /dev/null @@ -1,42 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC - -set(OPENSSL_USE_STATIC_LIBS ON) -find_package(OpenSSL REQUIRED) - -wasmedge_add_library(wasmedgePluginHttpsReq - SHARED - httpsreqenv.cpp - httpsreqfunc.cpp - httpsreqmodule.cpp -) - -target_compile_options(wasmedgePluginHttpsReq - PUBLIC - -DWASMEDGE_PLUGIN -) - -target_include_directories(wasmedgePluginHttpsReq - PUBLIC - $ - ${CMAKE_CURRENT_SOURCE_DIR} -) - -target_link_libraries(wasmedgePluginHttpsReq - PUBLIC - OpenSSL::Crypto - OpenSSL::SSL -) -if(WASMEDGE_LINK_PLUGINS_STATIC) - target_link_libraries(wasmedgePluginHttpsReq - PRIVATE - wasmedgeCAPI - ) -else() - target_link_libraries(wasmedgePluginHttpsReq - PRIVATE - wasmedge_shared - ) -endif() - -install(TARGETS wasmedgePluginHttpsReq DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasmedge_httpsreq/httpsreqbase.h b/plugins/wasmedge_httpsreq/httpsreqbase.h deleted file mode 100644 index d11b96c2..00000000 --- a/plugins/wasmedge_httpsreq/httpsreqbase.h +++ /dev/null @@ -1,25 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#pragma once - -#include "httpsreqenv.h" - -#include "common/errcode.h" -#include "runtime/callingframe.h" -#include "runtime/hostfunc.h" - -namespace WasmEdge { -namespace Host { - -template class WasmEdgeHttpsReq : public Runtime::HostFunction { -public: - WasmEdgeHttpsReq(WasmEdgeHttpsReqEnvironment &HostEnv) - : Runtime::HostFunction(0), Env(HostEnv) {} - -protected: - WasmEdgeHttpsReqEnvironment &Env; -}; - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasmedge_httpsreq/httpsreqenv.cpp b/plugins/wasmedge_httpsreq/httpsreqenv.cpp deleted file mode 100644 index 857d93a8..00000000 --- a/plugins/wasmedge_httpsreq/httpsreqenv.cpp +++ /dev/null @@ -1,39 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#include "httpsreqenv.h" -#include "httpsreqmodule.h" - -namespace WasmEdge { -namespace Host { - -namespace { - -Runtime::Instance::ModuleInstance * -create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { - return new WasmEdgeHttpsReqModule; -} - -Plugin::Plugin::PluginDescriptor Descriptor{ - .Name = "wasmedge_httpsreq", - .Description = "", - .APIVersion = Plugin::Plugin::CurrentAPIVersion, - .Version = {0, 10, 1, 0}, - .ModuleCount = 1, - .ModuleDescriptions = - (Plugin::PluginModule::ModuleDescriptor[]){ - { - .Name = "wasmedge_httpsreq", - .Description = "", - .Create = create, - }, - }, - .AddOptions = nullptr, -}; - -} // namespace - -Plugin::PluginRegister WasmEdgeHttpsReqEnvironment::Register(&Descriptor); - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasmedge_httpsreq/httpsreqenv.h b/plugins/wasmedge_httpsreq/httpsreqenv.h deleted file mode 100644 index 7b1781b1..00000000 --- a/plugins/wasmedge_httpsreq/httpsreqenv.h +++ /dev/null @@ -1,23 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#pragma once - -#include "plugin/plugin.h" - -#include -#include - -namespace WasmEdge { -namespace Host { - -class WasmEdgeHttpsReqEnvironment { -public: - std::string Rcv; - - /// Initial Configurations - static Plugin::PluginRegister Register; -}; - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasmedge_httpsreq/httpsreqfunc.cpp b/plugins/wasmedge_httpsreq/httpsreqfunc.cpp deleted file mode 100644 index 9fc74643..00000000 --- a/plugins/wasmedge_httpsreq/httpsreqfunc.cpp +++ /dev/null @@ -1,143 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#include "httpsreqfunc.h" - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace WasmEdge { -namespace Host { - -Expect WasmEdgeHttpsReqSendData::body(const Runtime::CallingFrame &Frame, - uint32_t HostPtr, uint32_t HostLen, - uint32_t Port, uint32_t BodyPtr, - uint32_t BodyLen) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - return Unexpect(ErrCode::Value::HostFuncError); - } - - const char *Host = MemInst->getPointer(HostPtr); - const char *Body = MemInst->getPointer(BodyPtr); - if (Host == nullptr) { - spdlog::error("[WasmEdge Httpsreq] Fail to get Host"); - return Unexpect(ErrCode::Value::HostFuncError); - } - if (Body == nullptr) { - spdlog::error("[WasmEdge Httpsreq] Fail to get Body"); - return Unexpect(ErrCode::Value::HostFuncError); - } - std::string HostStr, BodyStr, PortStr = std::to_string(Port); - std::copy_n(Host, HostLen, std::back_inserter(HostStr)); - std::copy_n(Body, BodyLen, std::back_inserter(BodyStr)); - - const SSL_METHOD *Method = TLS_client_method(); - SSL_CTX *Ctx = SSL_CTX_new(Method); - if (Ctx == nullptr) { - ERR_print_errors_fp(stderr); - spdlog::error("[WasmEdge Httpsreq] SSL_CTX_new() failed"); - return Unexpect(ErrCode::Value::HostFuncError); - } - SSL *Ssl = SSL_new(Ctx); - if (Ssl == nullptr) { - spdlog::error("[WasmEdge Httpsreq] SSL_new() failed"); - return Unexpect(ErrCode::Value::HostFuncError); - } - - // open connection - int Sfd, Err; - struct addrinfo Hints = {}, *Addrs; - - Hints.ai_family = AF_INET; - Hints.ai_socktype = SOCK_STREAM; - Hints.ai_protocol = IPPROTO_TCP; - - Err = getaddrinfo(HostStr.c_str(), PortStr.c_str(), &Hints, &Addrs); - if (Err != 0) { - spdlog::error("[WasmEdge Httpsreq] {}", gai_strerror(Err)); - return Unexpect(ErrCode::Value::HostFuncError); - } - - for (struct addrinfo *Addr = Addrs; Addr != NULL; Addr = Addr->ai_next) { - Sfd = socket(Addr->ai_family, Addr->ai_socktype, Addr->ai_protocol); - if (Sfd == -1) { - Err = errno; - break; - } - if (connect(Sfd, Addr->ai_addr, Addr->ai_addrlen) == 0) - break; - Err = errno; - close(Sfd); - Sfd = -1; - } - - freeaddrinfo(Addrs); - - if (Sfd == -1) { - spdlog::error("[WasmEdge Httpsreq] {}", strerror(Err)); - return Unexpect(ErrCode::Value::HostFuncError); - } - - SSL_set_fd(Ssl, Sfd); - - SSL_set_tlsext_host_name(Ssl, Host); - - const int Status = SSL_connect(Ssl); - if (Status != 1) { - const int Code = SSL_get_error(Ssl, Status); - ERR_print_errors_fp(stderr); - spdlog::error("[WasmEdge Httpsreq] SSL_get_error code {}", Code); - return Unexpect(ErrCode::Value::HostFuncError); - } - - SSL_write(Ssl, BodyStr.c_str(), BodyLen); - - // Receive - char Buffer[1024]; - int Nbytes = 0; - Env.Rcv = ""; - while (true) { - Nbytes = SSL_read(Ssl, Buffer, 1024); - if (Nbytes <= 0) { - break; - } - std::string Buf(Buffer, Nbytes); - Env.Rcv = Env.Rcv + Buf; - } - - SSL_free(Ssl); - close(Sfd); - SSL_CTX_free(Ctx); - - return {}; -} - -Expect WasmEdgeHttpsReqGetRcv::body(const Runtime::CallingFrame &Frame, - uint32_t BufPtr) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - return Unexpect(ErrCode::Value::HostFuncError); - } - char *Buf = MemInst->getPointer(BufPtr); - std::copy_n(Env.Rcv.begin(), Env.Rcv.size(), Buf); - return {}; -} - -Expect -WasmEdgeHttpsReqGetRcvLen::body(const Runtime::CallingFrame &) { - return static_cast(Env.Rcv.size()); -} - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasmedge_httpsreq/httpsreqfunc.h b/plugins/wasmedge_httpsreq/httpsreqfunc.h deleted file mode 100644 index 566a1a7a..00000000 --- a/plugins/wasmedge_httpsreq/httpsreqfunc.h +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#pragma once - -#include "httpsreqbase.h" - -namespace WasmEdge { -namespace Host { - -class WasmEdgeHttpsReqSendData - : public WasmEdgeHttpsReq { -public: - WasmEdgeHttpsReqSendData(WasmEdgeHttpsReqEnvironment &HostEnv) - : WasmEdgeHttpsReq(HostEnv) {} - Expect body(const Runtime::CallingFrame &Frame, uint32_t HostPtr, - uint32_t HostLen, uint32_t Port, uint32_t BodyPtr, - uint32_t BodyLen); -}; - -class WasmEdgeHttpsReqGetRcv : public WasmEdgeHttpsReq { -public: - WasmEdgeHttpsReqGetRcv(WasmEdgeHttpsReqEnvironment &HostEnv) - : WasmEdgeHttpsReq(HostEnv) {} - Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr); -}; - -class WasmEdgeHttpsReqGetRcvLen - : public WasmEdgeHttpsReq { -public: - WasmEdgeHttpsReqGetRcvLen(WasmEdgeHttpsReqEnvironment &HostEnv) - : WasmEdgeHttpsReq(HostEnv) {} - Expect body(const Runtime::CallingFrame &Frame); -}; - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasmedge_httpsreq/httpsreqmodule.cpp b/plugins/wasmedge_httpsreq/httpsreqmodule.cpp deleted file mode 100644 index d8e43208..00000000 --- a/plugins/wasmedge_httpsreq/httpsreqmodule.cpp +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#include "httpsreqmodule.h" -#include "httpsreqfunc.h" - -namespace WasmEdge { -namespace Host { - -/// Register your functions in module. -WasmEdgeHttpsReqModule::WasmEdgeHttpsReqModule() - : ModuleInstance("wasmedge_httpsreq") { - addHostFunc("wasmedge_httpsreq_send_data", - std::make_unique(Env)); - addHostFunc("wasmedge_httpsreq_get_rcv", - std::make_unique(Env)); - addHostFunc("wasmedge_httpsreq_get_rcv_len", - std::make_unique(Env)); -} - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasmedge_httpsreq/httpsreqmodule.h b/plugins/wasmedge_httpsreq/httpsreqmodule.h deleted file mode 100644 index 4931fc54..00000000 --- a/plugins/wasmedge_httpsreq/httpsreqmodule.h +++ /dev/null @@ -1,24 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#pragma once - -#include "httpsreqenv.h" -#include "runtime/instance/module.h" -#include - -namespace WasmEdge { -namespace Host { - -class WasmEdgeHttpsReqModule : public Runtime::Instance::ModuleInstance { -public: - WasmEdgeHttpsReqModule(); - - WasmEdgeHttpsReqEnvironment &getEnv() { return Env; } - -private: - WasmEdgeHttpsReqEnvironment Env; -}; - -} // namespace Host -} // namespace WasmEdge diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index c7bbae53..5a5f7c35 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -11,12 +11,6 @@ if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) add_subdirectory(wasi_nn) endif() -if(WASMEDGE_PLUGIN_HTTPSREQ) - if(CMAKE_SYSTEM_NAME MATCHES "Linux") - add_subdirectory(wasmedge_httpsreq) - endif() -endif() - if(WASMEDGE_PLUGIN_WASI_CRYPTO) add_subdirectory(wasi_crypto) endif() diff --git a/test/plugins/wasmedge_httpsreq/CMakeLists.txt b/test/plugins/wasmedge_httpsreq/CMakeLists.txt deleted file mode 100644 index 921ee619..00000000 --- a/test/plugins/wasmedge_httpsreq/CMakeLists.txt +++ /dev/null @@ -1,35 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC - -wasmedge_add_executable(wasmedgeHttpsReqTests - httpsreq.cpp -) - -add_dependencies(wasmedgeHttpsReqTests - wasmedgePluginHttpsReq -) - -target_include_directories(wasmedgeHttpsReqTests - PUBLIC - $ - $ -) - -target_link_libraries(wasmedgeHttpsReqTests - PRIVATE - ${GTEST_BOTH_LIBRARIES} -) -# Link to the WasmEdge library -if(WASMEDGE_LINK_PLUGINS_STATIC) - target_link_libraries(wasmedgeHttpsReqTests - PRIVATE - wasmedgeCAPI - ) -else() - target_link_libraries(wasmedgeHttpsReqTests - PRIVATE - wasmedge_shared - ) -endif() - -add_test(wasmedgeHttpsReqTests wasmedgeHttpsReqTests) diff --git a/test/plugins/wasmedge_httpsreq/httpsreq.cpp b/test/plugins/wasmedge_httpsreq/httpsreq.cpp deleted file mode 100644 index 6306211a..00000000 --- a/test/plugins/wasmedge_httpsreq/httpsreq.cpp +++ /dev/null @@ -1,159 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#include "common/defines.h" -#include "httpsreqfunc.h" -#include "httpsreqmodule.h" -#include "runtime/instance/module.h" - -#include -#include -#include -#include -#include -#include - -namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { - using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasmedge_httpsreq/" - "libwasmedgePluginHttpsReq" WASMEDGE_LIB_EXTENSION)); - if (const auto *Plugin = - WasmEdge::Plugin::Plugin::find("wasmedge_httpsreq"sv)) { - if (const auto *Module = Plugin->findModule("wasmedge_httpsreq"sv)) { - return Module->create().release(); - } - } - return nullptr; -} - -void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, - uint32_t Offset, uint32_t Cnt, uint8_t C = 0) noexcept { - std::fill_n(MemInst.getPointer(Offset), Cnt, C); -} - -void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, - uint32_t Offset, const std::string &Str) noexcept { - char *Buf = MemInst.getPointer(Offset); - std::copy_n(Str.c_str(), Str.length(), Buf); -} - -} // namespace - -TEST(wasmedgeHttpsReqTests, SendData) { - // Create the wasmedge httpsreq module instance. - auto *HttpMod = - dynamic_cast(createModule()); - EXPECT_FALSE(HttpMod == nullptr); - - // Create the calling frame with memory instance. - WasmEdge::Runtime::Instance::ModuleInstance Mod(""); - Mod.addHostMemory( - "memory", std::make_unique( - WasmEdge::AST::MemoryType(1))); - auto *MemInstPtr = Mod.findMemoryExports("memory"); - ASSERT_TRUE(MemInstPtr != nullptr); - auto &MemInst = *MemInstPtr; - WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); - - // Clear the memory[0, 256]. - fillMemContent(MemInst, 0, 256); - // Set the memory[0, 11] as string "httpbin.org". - fillMemContent(MemInst, 0, std::string("httpbin.org")); - // Set the memory[30, 116] as string "GET / HTTP/1.1\nHost: - // httpbin.org\r\nConnection: Close\r\nReferer: https://httpbin.org/\r\n\r\n". - fillMemContent(MemInst, 30, - std::string("GET / HTTP/1.1\nHost: httpbin.org\r\nConnection: " - "Close\r\nReferer: https://httpbin.org/\r\n\r\n")); - - // Get the function "send_data" - auto *FuncInst = HttpMod->findFuncExports("wasmedge_httpsreq_send_data"); - EXPECT_NE(FuncInst, nullptr); - EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncInst = dynamic_cast( - FuncInst->getHostFunc()); - - // Test: Run function successfully for get requests - EXPECT_TRUE(HostFuncInst.run( - CallFrame, - std::initializer_list{ - UINT32_C(0), UINT32_C(11), UINT32_C(443), UINT32_C(30), UINT32_C(86)}, - {})); - delete HttpMod; -} - -TEST(wasmedgeHttpsReqTests, GetRcv) { - // Create the httpsreq module instance. - auto *HttpMod = - dynamic_cast(createModule()); - EXPECT_FALSE(HttpMod == nullptr); - - // Create the calling frame with memory instance. - WasmEdge::Runtime::Instance::ModuleInstance Mod(""); - Mod.addHostMemory( - "memory", std::make_unique( - WasmEdge::AST::MemoryType(1))); - auto *MemInstPtr = Mod.findMemoryExports("memory"); - ASSERT_TRUE(MemInstPtr != nullptr); - auto &MemInst = *MemInstPtr; - WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); - - fillMemContent(MemInst, 0, 256); - - // Set the memory[0, 11] as string "httpbin.org". - fillMemContent(MemInst, 0, std::string("httpbin.org")); - // Set the memory[30, 116] as string "GET / HTTP/1.1\nHost: - // httpbin.org\r\nConnection: Close\r\nReferer: https://httpbin.org/\r\n\r\n". - fillMemContent(MemInst, 30, - std::string("GET / HTTP/1.1\nHost: httpbin.org\r\nConnection: " - "Close\r\nReferer: https://httpbin.org/\r\n\r\n")); - - // Get the function "send_data" - auto *FuncInst = HttpMod->findFuncExports("wasmedge_httpsreq_send_data"); - EXPECT_NE(FuncInst, nullptr); - EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncSendData = - dynamic_cast( - FuncInst->getHostFunc()); - - // Get the function "get_rcv_len" - FuncInst = HttpMod->findFuncExports("wasmedge_httpsreq_get_rcv_len"); - EXPECT_NE(FuncInst, nullptr); - EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncGetRcvLen = - dynamic_cast( - FuncInst->getHostFunc()); - - // Get the function "get_rcv" - FuncInst = HttpMod->findFuncExports("wasmedge_httpsreq_get_rcv"); - EXPECT_NE(FuncInst, nullptr); - EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncGetRcv = dynamic_cast( - FuncInst->getHostFunc()); - - // Test: Run function successfully for get requests - EXPECT_TRUE(HostFuncSendData.run( - CallFrame, - std::initializer_list{ - UINT32_C(0), UINT32_C(11), UINT32_C(443), UINT32_C(30), UINT32_C(86)}, - {})); - - // Test: Run function successfully for getrcvlen - std::array RetVal; - EXPECT_TRUE(HostFuncGetRcvLen.run(CallFrame, {}, RetVal)); - uint32_t Len = RetVal[0].get(); - EXPECT_TRUE(Len > 0U); - - // Test: Run function with nullptr memory instance -- fail - EXPECT_FALSE(HostFuncGetRcv.run( - WasmEdge::Runtime::CallingFrame(nullptr, nullptr), - std::initializer_list{UINT32_C(0)}, {})); - - delete HttpMod; -} - -GTEST_API_ int main(int argc, char **argv) { - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} From fc9f5a2d8fdb540d7c7819925a923f957729a804 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Thu, 18 May 2023 19:44:20 +0800 Subject: [PATCH 104/623] [WASI-NN] Refactor and reorganize codes * Separate different backand codes to different file * Check all parameters before process * Fix typo, `Ouput` -> `Output` Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_nn/CMakeLists.txt | 5 + plugins/wasi_nn/onnx.cpp | 35 + plugins/wasi_nn/onnx.h | 36 + plugins/wasi_nn/openvino.cpp | 418 +++++++++++ plugins/wasi_nn/openvino.h | 161 +++++ plugins/wasi_nn/tf.cpp | 35 + plugins/wasi_nn/tf.h | 36 + plugins/wasi_nn/tfl.cpp | 216 ++++++ plugins/wasi_nn/tfl.h | 65 ++ plugins/wasi_nn/torch.cpp | 177 +++++ plugins/wasi_nn/torch.h | 57 ++ plugins/wasi_nn/types.h | 92 +++ plugins/wasi_nn/wasinnbase.h | 4 + plugins/wasi_nn/wasinnenv.h | 235 +++---- plugins/wasi_nn/wasinnfunc.cpp | 1104 +++++------------------------- plugins/wasi_nn/wasinnfunc.h | 61 +- plugins/wasi_nn/wasinnmodule.cpp | 2 +- test/plugins/wasi_nn/wasi_nn.cpp | 12 +- 18 files changed, 1669 insertions(+), 1082 deletions(-) create mode 100644 plugins/wasi_nn/onnx.cpp create mode 100644 plugins/wasi_nn/onnx.h create mode 100644 plugins/wasi_nn/openvino.cpp create mode 100644 plugins/wasi_nn/openvino.h create mode 100644 plugins/wasi_nn/tf.cpp create mode 100644 plugins/wasi_nn/tf.h create mode 100644 plugins/wasi_nn/tfl.cpp create mode 100644 plugins/wasi_nn/tfl.h create mode 100644 plugins/wasi_nn/torch.cpp create mode 100644 plugins/wasi_nn/torch.h create mode 100644 plugins/wasi_nn/types.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index dd5a4a13..7c9e9412 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -6,6 +6,11 @@ wasmedge_add_library(wasmedgePluginWasiNN wasinnenv.cpp wasinnfunc.cpp wasinnmodule.cpp + openvino.cpp + onnx.cpp + tf.cpp + torch.cpp + tfl.cpp ) target_compile_options(wasmedgePluginWasiNN diff --git a/plugins/wasi_nn/onnx.cpp b/plugins/wasi_nn/onnx.cpp new file mode 100644 index 00000000..3fa4b326 --- /dev/null +++ b/plugins/wasi_nn/onnx.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "onnx.h" +#include "wasinnenv.h" + +namespace WasmEdge::Host::WASINN::ONNX { +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] ONNX backend is not supported."); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +} // namespace WasmEdge::Host::WASINN::ONNX diff --git a/plugins/wasi_nn/onnx.h b/plugins/wasi_nn/onnx.h new file mode 100644 index 00000000..f27c7157 --- /dev/null +++ b/plugins/wasi_nn/onnx.h @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "types.h" + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::ONNX { +struct Graph {}; +struct Context { + Context(Graph &) noexcept {} +}; + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::ONNX diff --git a/plugins/wasi_nn/openvino.cpp b/plugins/wasi_nn/openvino.cpp new file mode 100644 index 00000000..afb29632 --- /dev/null +++ b/plugins/wasi_nn/openvino.cpp @@ -0,0 +1,418 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "openvino.h" +#include "wasinnenv.h" +#include + +namespace WasmEdge::Host::WASINN::OpenVINO { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept { + // The OpenVINO core must be initialized in constructor. + if (unlikely(Env.OpenVINOCore == nullptr)) { + spdlog::error("[WASI-NN] OpenVINO core not initialized."); + return WASINN::ErrNo::MissingMemory; + } + + // The graph builder length must be 2. + if (Builders.size() != 2) { + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 2", + Builders.size()); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the XML and Weight raw buffer. + // Builder-0: the XML string + // Builder-1: the Weight binary + auto XML = Builders[0]; + auto Weight = Builders[1]; + + // Add a new graph. + Env.NNGraph.emplace_back(Backend::OpenVINO); + auto &GraphRef = Env.NNGraph.back().get(); + + // Create the weights blob memory. + tensor_desc_t WeightsDesc{ + layout_e::ANY, {1, {Weight.size()}}, precision_e::U8}; + IEStatusCode Status = + ie_blob_make_memory(&WeightsDesc, &(GraphRef.OpenVINOWeightBlob)); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to create the model's weight blob, error code: {}", + Status); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::Busy; + } + + // Copy the weights buffer to the blob. + ie_blob_buffer_t BlobBuffer; + Status = ie_blob_get_buffer(GraphRef.OpenVINOWeightBlob, &BlobBuffer); + if (unlikely(Status != IEStatusCode::OK)) { + spdlog::error( + "[WASI-NN] Unable to find the weight blob's buffer, error code: {}", + Status); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::MissingMemory; + } + std::copy_n(Weight.data(), Weight.size(), + static_cast(BlobBuffer.buffer)); + + // Read network from memory. + Status = ie_core_read_network_from_memory( + Env.OpenVINOCore, XML.data(), XML.size(), GraphRef.OpenVINOWeightBlob, + &(GraphRef.OpenVINONetwork)); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to read network from the XML and " + "Weights, error code: {}", + Status); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::Busy; + } + + // Get the network input and output size. + size_t NetworkInputSize = 0; + Status = + ie_network_get_inputs_number(GraphRef.OpenVINONetwork, &NetworkInputSize); + if (unlikely(Status != IEStatusCode::OK)) { + spdlog::error("[WASI-NN] Unable to get the inputs number from the " + "network, error code: {}", + Status); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::MissingMemory; + } + spdlog::debug("[WASI-NN] Got input size: {}", NetworkInputSize); + size_t NetworkOutputSize = 0; + Status = ie_network_get_outputs_number(GraphRef.OpenVINONetwork, + &NetworkOutputSize); + if (unlikely(Status != IEStatusCode::OK)) { + spdlog::error("[WASI-NN] Unable to get the outputs number from the " + "network, error code: {}", + Status); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::MissingMemory; + } + spdlog::debug("[WASI-NN] Got output size: {}", NetworkOutputSize); + + // Get and store the input and output names. + GraphRef.OpenVINOInputNames.resize(NetworkInputSize, nullptr); + for (size_t I = 0; I < NetworkInputSize; I++) { + Status = ie_network_get_input_name(GraphRef.OpenVINONetwork, I, + &(GraphRef.OpenVINOInputNames[I])); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to find input name correctly with " + "Index {}, error code: {}", + I, Status); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::MissingMemory; + } + spdlog::debug("[WASI-NN] Got input name: {}", + GraphRef.OpenVINOInputNames[I]); + } + GraphRef.OpenVINOOutputNames.resize(NetworkOutputSize, nullptr); + for (size_t I = 0; I < NetworkOutputSize; I++) { + Status = ie_network_get_output_name(GraphRef.OpenVINONetwork, I, + &(GraphRef.OpenVINOOutputNames[I])); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to find output name correctly with " + "Index {}, error code: {}", + I, Status); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::MissingMemory; + } + spdlog::debug("[WASI-NN] Got output name: {}", + GraphRef.OpenVINOOutputNames[I]); + } + + // Set the input layout. + // FIXME: this is a temporary workaround. We need a more eligant way to + // specify the layout in the long run. However, without this newer versions + // of OpenVINO will fail due to parameter mismatch. + for (size_t I = 0; I < NetworkInputSize; I++) { + // More layouts should be supported. + Status = ie_network_set_input_layout(GraphRef.OpenVINONetwork, + GraphRef.OpenVINOInputNames[I], + layout_e::NHWC); + spdlog::debug("[WASI-NN] Setting [{}] to NHWC", + GraphRef.OpenVINOInputNames[I]); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to set input layout with the input " + "name {}, error code: {}", + GraphRef.OpenVINOInputNames[I], Status); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::MissingMemory; + } + } + + // Load network. + ie_config_t Config = {nullptr, nullptr, nullptr}; + Status = ie_core_load_network(Env.OpenVINOCore, GraphRef.OpenVINONetwork, + fmt::format("{}"sv, Device).c_str(), &Config, + &GraphRef.OpenVINOExecNetwork); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to create executable Network, error code: {}", + Status); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::Busy; + } + + // Store the loaded graph. + GraphId = Env.NNGraph.size() - 1; + + return WASINN::ErrNo::Success; +} + +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept { + // Check the network and the execution network with the graph ID. + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.OpenVINONetwork == nullptr || + GraphRef.OpenVINOExecNetwork == nullptr) { + spdlog::error("[WASI-NN] Model for Graph:{} is empty!", GraphId); + return WASINN::ErrNo::MissingMemory; + } + + // Create context. + Env.NNContext.emplace_back(Env.NNGraph[GraphId]); + auto &CtxRef = Env.NNContext.back().get(); + if (CtxRef.OpenVINOInferRequest == nullptr) { + spdlog::error("[WASI-NN] Unable to create openvino context"); + Env.NNContext.pop_back(); + return WASINN::ErrNo::Busy; + } + + ContextId = Env.NNContext.size() - 1; + return WASINN::ErrNo::Success; +} + +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + // Check the infer request and the network. + auto *Network = CxtRef.GraphRef.OpenVINONetwork; + if (Network == nullptr || CxtRef.OpenVINOInferRequest == nullptr) { + spdlog::error("[WASI-NN] The founded openvino session is empty"); + return WASINN::ErrNo::MissingMemory; + } + + // Check the input index. + if (CxtRef.GraphRef.OpenVINOInputNames.size() <= Index) { + spdlog::error("[WASI-NN] The input index {} exceeds the inputs number {}.", + Index, CxtRef.GraphRef.OpenVINOInputNames.size()); + return WASINN::ErrNo::InvalidArgument; + } + char *InputName = CxtRef.GraphRef.OpenVINOInputNames[Index]; + + if (Tensor.Dimension.size() > 8) { + spdlog::error( + "[WASI-NN] Tensor dimension is out of range, expect it under 8-dim, " + "but got {}-dim.", + Tensor.Dimension.size()); + return WASINN::ErrNo::InvalidArgument; + } + if (Tensor.RType != WASINN::TensorType::F32) { + spdlog::error( + "[WASI-NN] Only F32 inputs and outputs are supported for now."); + return WASINN::ErrNo::InvalidArgument; + } + + // Set the input resize algorithm. + // Mark the input as resizable by setting a resize algorithm. + // In this case we will be able to set an input blob of any shape to an + // infer request. Resizing and layout conversions are executed automatically + // when inferring. + IEStatusCode Status = ie_network_set_input_resize_algorithm( + Network, InputName, RESIZE_BILINEAR); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to set input resize correctly, error code: {}", + Status); + return WASINN::ErrNo::InvalidArgument; + } + + // Set the input layout. + // More layouts should be supported. + Status = ie_network_set_input_layout(Network, InputName, layout_e::NHWC); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to set input layout correctly, error code: {}", + Status); + return WASINN::ErrNo::InvalidArgument; + } + + // Set the input precision. + // More types should be supported. + Status = + ie_network_set_input_precision(Network, InputName, precision_e::FP32); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to set input precision correctly, error code: {}", + Status); + return WASINN::ErrNo::InvalidArgument; + } + + // Set the dimensions and the tensor description. + dimensions_t Dimens; + Dimens.ranks = Tensor.Dimension.size(); + for (size_t I = 0; I < Dimens.ranks; I++) { + Dimens.dims[I] = static_cast(Tensor.Dimension[I]); + } + tensor_desc_t TensorDesc = {layout_e::NHWC, Dimens, precision_e::FP32}; + + // Create the input blob memory. + ie_blob_t *InputBlob = nullptr; + Status = ie_blob_make_memory(&TensorDesc, &InputBlob); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to allocated input tensor correctly, " + "error code: {}", + Status); + return WASINN::ErrNo::Busy; + } + + // Get the blob buffer size and compare with the tensor size. + int BlobSize; + Status = ie_blob_size(InputBlob, &BlobSize); + if (unlikely(Status != IEStatusCode::OK)) { + spdlog::error("[WASI-NN] Unable to get the input blob size, error code: {}", + Status); + return WASINN::ErrNo::Busy; + } + if (unlikely(static_cast(BlobSize * 4) != Tensor.Tensor.size())) { + spdlog::error("[WASI-NN] Blob size {} and the Tensor size {} not matched.", + BlobSize * 4, Tensor.Tensor.size()); + } + + // Copy the data into the input blob buffer. + ie_blob_buffer_t BlobBuffer; + Status = ie_blob_get_buffer(InputBlob, &BlobBuffer); + if (unlikely(Status != IEStatusCode::OK)) { + spdlog::error("[WASI-NN] Unable to find input tensor buffer"); + ie_blob_free(&InputBlob); + return WASINN::ErrNo::MissingMemory; + } + std::copy_n(Tensor.Tensor.data(), Tensor.Tensor.size(), + static_cast(BlobBuffer.buffer)); + + // Set input blob. + Status = ie_infer_request_set_blob(CxtRef.OpenVINOInferRequest, InputName, + InputBlob); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to set input tensor to model correctly, " + "error code: {}", + Status); + ie_blob_free(&InputBlob); + return WASINN::ErrNo::Busy; + } + + ie_blob_free(&InputBlob); + + return WASINN::ErrNo::Success; +} + +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto *Network = CxtRef.GraphRef.OpenVINONetwork; + + // Check the output index. + if (CxtRef.GraphRef.OpenVINOOutputNames.size() <= Index) { + spdlog::error( + "[WASI-NN] The output index {} exceeds the outputs number {}.", Index, + CxtRef.GraphRef.OpenVINOOutputNames.size()); + return WASINN::ErrNo::InvalidArgument; + } + char *OutputName = CxtRef.GraphRef.OpenVINOOutputNames[Index]; + + // Set output precision. + IEStatusCode Status = + ie_network_set_output_precision(Network, OutputName, precision_e::FP32); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to set output precision correctly with Index:{}", + Index); + return WASINN::ErrNo::InvalidArgument; + } + + // Get output blob buffer. + ie_blob_t *OutputBlob = nullptr; + Status = ie_infer_request_get_blob(CxtRef.OpenVINOInferRequest, OutputName, + &OutputBlob); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to retrieve output tensor correctly", + Index); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the blob size and copy the output buffer. + int BlobSize; + Status = ie_blob_size(OutputBlob, &BlobSize); + ie_blob_buffer_t BlobCBuffer; + Status = ie_blob_get_cbuffer(OutputBlob, &BlobCBuffer); + if (Status != IEStatusCode::OK) { + spdlog::error("[WASI-NN] Unable to retrieve output tensor correctly", + Index); + ie_blob_free(&OutputBlob); + return WASINN::ErrNo::MissingMemory; + } + uint32_t BytesToWrite = + std::min(static_cast(BlobSize * 4), OutBuffer.size()); + std::copy_n(static_cast(BlobCBuffer.cbuffer), BytesToWrite, + OutBuffer.data()); + + // Write the bytes written result. + BytesWritten = BytesToWrite; + + ie_blob_free(&OutputBlob); + + return WASINN::ErrNo::Success; +} + +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + IEStatusCode Status = ie_infer_request_infer(CxtRef.OpenVINOInferRequest); + if (Status != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Unable to perform computation correctly, error code: {}", + Status); + return WASINN::ErrNo::Busy; + } + return WASINN::ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] OpenVINO backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif +} // namespace WasmEdge::Host::WASINN::OpenVINO diff --git a/plugins/wasi_nn/openvino.h b/plugins/wasi_nn/openvino.h new file mode 100644 index 00000000..83821a7e --- /dev/null +++ b/plugins/wasi_nn/openvino.h @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "types.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO +#include "common/log.h" +#include +#include + +template <> +struct fmt::formatter : fmt::formatter { + fmt::format_context::iterator format(IEStatusCode Code, + fmt::format_context &Ctx) const { + using namespace std::literals; + std::string_view Name; + switch (Code) { + case OK: + Name = "OK"sv; + break; + case GENERAL_ERROR: + Name = "GENERAL_ERROR"sv; + break; + case NOT_IMPLEMENTED: + Name = "NOT_IMPLEMENTED"sv; + break; + case NETWORK_NOT_LOADED: + Name = "NETWORK_NOT_LOADED"sv; + break; + case PARAMETER_MISMATCH: + Name = "PARAMETER_MISMATCH"sv; + break; + case NOT_FOUND: + Name = "NOT_FOUND"sv; + break; + case OUT_OF_BOUNDS: + Name = "OUT_OF_BOUNDS"sv; + break; + case UNEXPECTED: + Name = "UNEXPECTED"sv; + break; + case REQUEST_BUSY: + Name = "REQUEST_BUSY"sv; + break; + case RESULT_NOT_READY: + Name = "RESULT_NOT_READY"sv; + break; + case NOT_ALLOCATED: + Name = "NOT_ALLOCATED"sv; + break; + case INFER_NOT_STARTED: + Name = "INFER_NOT_STARTED"sv; + break; + case NETWORK_NOT_READ: + Name = "NETWORK_NOT_READ"sv; + break; + case INFER_CANCELLED: + Name = "INFER_CANCELLED"sv; + break; + default: + Name = "Unknown"sv; + } + return fmt::formatter::format(Name, Ctx); + } +}; +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::OpenVINO { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO +struct Graph { + ~Graph() noexcept { + if (OpenVINONetwork) { + ie_network_free(&OpenVINONetwork); + } + if (OpenVINOExecNetwork) { + ie_exec_network_free(&OpenVINOExecNetwork); + } + if (OpenVINOWeightBlob) { + ie_blob_free(&OpenVINOWeightBlob); + } + for (auto &I : OpenVINOInputNames) { + if (I) { + ie_network_name_free(&I); + } + } + for (auto &I : OpenVINOOutputNames) { + if (I) { + ie_network_name_free(&I); + } + } + } + ie_network_t *OpenVINONetwork = nullptr; + ie_executable_network_t *OpenVINOExecNetwork = nullptr; + ie_blob_t *OpenVINOWeightBlob = nullptr; + std::vector OpenVINOInputNames; + std::vector OpenVINOOutputNames; +}; + +struct Context { + Context(Graph &G) noexcept : GraphRef(G) { + IEStatusCode Status = ie_exec_network_create_infer_request( + G.OpenVINOExecNetwork, &OpenVINOInferRequest); + if (Status != IEStatusCode::OK) { + OpenVINOInferRequest = nullptr; + spdlog::error("[WASI-NN] Unable to create infer request for OpenVINO"); + } + } + ~Context() noexcept { + if (OpenVINOInferRequest) { + ie_infer_request_free(&OpenVINOInferRequest); + } + } + Graph &GraphRef; + ie_infer_request_t *OpenVINOInferRequest = nullptr; +}; + +struct Environ { + Environ() noexcept { + if (ie_core_create("", &OpenVINOCore) != IEStatusCode::OK) { + spdlog::error( + "[WASI-NN] Error happened when initializing OpenVINO core."); + } + } + ~Environ() noexcept { + if (OpenVINOCore) { + ie_core_free(&OpenVINOCore); + } + } + ie_core_t *OpenVINOCore = nullptr; +}; +#else +struct Graph {}; +struct Context { + Context(Graph &) noexcept {} +}; +struct Environ {}; +#endif + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::OpenVINO diff --git a/plugins/wasi_nn/tf.cpp b/plugins/wasi_nn/tf.cpp new file mode 100644 index 00000000..353d1cf7 --- /dev/null +++ b/plugins/wasi_nn/tf.cpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "tf.h" +#include "wasinnenv.h" + +namespace WasmEdge::Host::WASINN::Tensorflow { +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] Tensorflow backend is not supported."); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +} // namespace WasmEdge::Host::WASINN::Tensorflow diff --git a/plugins/wasi_nn/tf.h b/plugins/wasi_nn/tf.h new file mode 100644 index 00000000..b35c904b --- /dev/null +++ b/plugins/wasi_nn/tf.h @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "types.h" + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::Tensorflow { +struct Graph {}; +struct Context { + Context(Graph &) noexcept {} +}; + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::Tensorflow diff --git a/plugins/wasi_nn/tfl.cpp b/plugins/wasi_nn/tfl.cpp new file mode 100644 index 00000000..73450c57 --- /dev/null +++ b/plugins/wasi_nn/tfl.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "tfl.h" +#include "wasinnenv.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE +#include "tensorflow/lite/c/common.h" +#endif + +namespace WasmEdge::Host::WASINN::TensorflowLite { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept { + if ((Device != WASINN::Device::CPU)) { + spdlog::error("[WASI-NN] TensorflowLite Only support CPU target."); + return WASINN::ErrNo::InvalidArgument; + } + // The graph builder length must be 1. + if (Builders.size() != 1) { + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1", + Builders.size()); + return WASINN::ErrNo::InvalidArgument; + } + auto Weight = Builders[0]; + // Add a new graph. + Env.NNGraph.emplace_back(WASINN::Backend::TensorflowLite); + auto &GraphRef = Env.NNGraph.back().get(); + + GraphRef.TFLiteMod = TfLiteModelCreate(Weight.data(), Weight.size()); + if (unlikely(GraphRef.TFLiteMod == nullptr)) { + spdlog::error("[WASI-NN] Cannot import TFLite model"); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::InvalidArgument; + } + + // Store the loaded graph. + GraphId = Env.NNGraph.size() - 1; + return WASINN::ErrNo::Success; +} + +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept { + // Check the network and the execution network with the graph ID. + if (Env.NNGraph[GraphId].get().TFLiteMod == nullptr) { + spdlog::error("[WASI-NN] Model for Graph:{} is missing!", GraphId); + return WASINN::ErrNo::MissingMemory; + } + + // Create context. + Env.NNContext.emplace_back(Env.NNGraph[GraphId]); + auto &CxtRef = Env.NNContext.back().get(); + auto *TFLiteOps = TfLiteInterpreterOptionsCreate(); + TfLiteInterpreterOptionsSetNumThreads(TFLiteOps, 2); + CxtRef.TFLiteInterp = + TfLiteInterpreterCreate(CxtRef.GraphRef.TFLiteMod, TFLiteOps); + TfLiteInterpreterOptionsDelete(TFLiteOps); + if (unlikely(CxtRef.TFLiteInterp == nullptr)) { + spdlog::error("[WASI-NN] Cannot create TFLite interpreter."); + Env.NNContext.pop_back(); + return WASINN::ErrNo::Busy; + } + TfLiteInterpreterAllocateTensors(CxtRef.TFLiteInterp); + + ContextId = Env.NNContext.size() - 1; + return WASINN::ErrNo::Success; +} + +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const WASINN::TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + uint32_t InCnt = TfLiteInterpreterGetInputTensorCount(CxtRef.TFLiteInterp); + if (Index >= InCnt) { + spdlog::error("[WASI-NN] Invalid index id {} for the input, only {} " + "inputs are allowed", + Index, InCnt); + return WASINN::ErrNo::InvalidArgument; + } + + auto *HoldTensor = + TfLiteInterpreterGetInputTensor(CxtRef.TFLiteInterp, Index); + // Check the input data size. + const auto HoldTensorByteSize = TfLiteTensorByteSize(HoldTensor); + if (HoldTensorByteSize != Tensor.Tensor.size()) { + spdlog::error("[WASI-NN] Expect tensor byte size {}, but got {}", + HoldTensorByteSize, Tensor.Tensor.size()); + return WASINN::ErrNo::InvalidArgument; + } + // Check the input tensor dimensions. + const auto HoldTensorNumDims = TfLiteTensorNumDims(HoldTensor); + if (static_cast(HoldTensorNumDims) != Tensor.Dimension.size()) { + spdlog::error("[WASI-NN] Expect tensor number of dimensions {}, but got {}", + HoldTensorNumDims, Tensor.Dimension.size()); + return WASINN::ErrNo::InvalidArgument; + } + for (uint32_t I = 0; I < Tensor.Dimension.size(); I++) { + const auto HoldTensorDim = TfLiteTensorDim(HoldTensor, I); + if (static_cast(HoldTensorDim) != Tensor.Dimension[I]) { + spdlog::error("[WASI-NN] Expect tensor dimension[{}] = {}, but got {}", I, + HoldTensorDim, Tensor.Dimension[I]); + return WASINN::ErrNo::InvalidArgument; + } + } + // Check the input tensor type. + WASINN::TensorType LiteType; + switch (const auto Type = TfLiteTensorType(HoldTensor)) { + case TfLiteType::kTfLiteUInt8: + LiteType = WASINN::TensorType::U8; + break; + case TfLiteType::kTfLiteFloat16: + LiteType = WASINN::TensorType::F16; + break; + case TfLiteType::kTfLiteFloat32: + LiteType = WASINN::TensorType::F32; + break; + case TfLiteType::kTfLiteInt32: + LiteType = WASINN::TensorType::I32; + break; + default: + spdlog::error("[WASI-NN] Unsupported TFLite type: {}", + TfLiteTypeGetName(Type)); + return WASINN::ErrNo::InvalidArgument; + } + + if (unlikely(LiteType != Tensor.RType)) { + spdlog::error("[WASI-NN] Expect tensor type {}, but got {}", LiteType, + Tensor.RType); + return WASINN::ErrNo::InvalidArgument; + } + TfLiteStatus Stat = TfLiteTensorCopyFromBuffer( + HoldTensor, Tensor.Tensor.data(), Tensor.Tensor.size()); + if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { + spdlog::error("[WASI-NN] Copy tensor memory failed"); + return WASINN::ErrNo::Busy; + } + + return WASINN::ErrNo::Success; +} + +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + uint32_t OutCnt = TfLiteInterpreterGetOutputTensorCount(CxtRef.TFLiteInterp); + if (Index >= OutCnt) { + spdlog::error("[WASI-NN] Invalid index id {} for the input, only {} " + "outputs are allowed", + Index, OutCnt); + return WASINN::ErrNo::InvalidArgument; + } + const TfLiteTensor *HoldTensor = + TfLiteInterpreterGetOutputTensor(CxtRef.TFLiteInterp, Index); + const uint32_t BytesToWrite = TfLiteTensorByteSize(HoldTensor); + // Check out buffer max size. + if (OutBuffer.size() < BytesToWrite) { + spdlog::error("[WASI-NN] Expect out buffer max size {}, but got {}", + BytesToWrite, OutBuffer.size()); + return WASINN::ErrNo::InvalidArgument; + } + TfLiteTensorCopyToBuffer(HoldTensor, OutBuffer.data(), BytesToWrite); + BytesWritten = BytesToWrite; + return WASINN::ErrNo::Success; +} + +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + // Run session + if (unlikely(CxtRef.TFLiteInterp == nullptr)) { + spdlog::error("[WASI-NN] Tensorflow Lite context empty"); + return WASINN::ErrNo::MissingMemory; + } + TfLiteStatus Stat = TfLiteInterpreterInvoke(CxtRef.TFLiteInterp); + if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { + spdlog::error("[WASI-NN] Invocation failed."); + return WASINN::ErrNo::Busy; + } + return WASINN::ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error( + "[WASI-NN] TensorflowLite backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} + +#endif +} // namespace WasmEdge::Host::WASINN::TensorflowLite diff --git a/plugins/wasi_nn/tfl.h b/plugins/wasi_nn/tfl.h new file mode 100644 index 00000000..21d637f0 --- /dev/null +++ b/plugins/wasi_nn/tfl.h @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "types.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE +#include "tensorflow/lite/c/c_api.h" +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::TensorflowLite { + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE +struct Graph { + ~Graph() noexcept { + if (TFLiteMod) { + TfLiteModelDelete(TFLiteMod); + } + } + TfLiteModel *TFLiteMod = nullptr; +}; + +struct Context { +public: + Context(Graph &G) noexcept : GraphRef(G) {} + ~Context() noexcept { + if (TFLiteInterp) { + TfLiteInterpreterDelete(TFLiteInterp); + } + } + Graph &GraphRef; + TfLiteInterpreter *TFLiteInterp = nullptr; +}; +#else +struct Graph {}; +struct Context { + Context(Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::TensorflowLite diff --git a/plugins/wasi_nn/torch.cpp b/plugins/wasi_nn/torch.cpp new file mode 100644 index 00000000..9ada3547 --- /dev/null +++ b/plugins/wasi_nn/torch.cpp @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "torch.h" +#include "wasinnenv.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH +#include +#endif + +namespace WasmEdge::Host::WASINN::PyTorch { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH +Expect load(WasiNNEnvironment &Env, Span> Builders, + Device Device, uint32_t &GraphId) noexcept { + // The graph builder length must be 1. + if (Builders.size() != 1) { + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1", + Builders.size()); + return ErrNo::InvalidArgument; + } + + auto Weight = Builders[0]; + // Add a new graph. + Env.NNGraph.emplace_back(Backend::PyTorch); + auto &GraphRef = Env.NNGraph.back().get(); + // Setup Graph Device + if (Device == Device::CPU) { + GraphRef.TorchDevice = at::kCPU; + } else if (Device == Device::GPU) { + if (!torch::cuda::is_available()) { + spdlog::error( + "[WASI-NN] CUDA Unavailable, platform Cannot support GPU target."); + return ErrNo::InvalidArgument; + } + GraphRef.TorchDevice = at::kCUDA; + } else { + spdlog::error("[WASI-NN] PyTorch Only support CPU and GPU target."); + return ErrNo::InvalidArgument; + } + + std::istringstream BinRead( + std::string(reinterpret_cast(Weight.data()), Weight.size())); + + try { + GraphRef.TorchModel = torch::jit::load(BinRead); + GraphRef.TorchModel.to(GraphRef.TorchDevice); + } catch (const c10::Error &e) { + spdlog::error("[WASI-NN] Failed when load the TorchScript model."); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + // Store the loaded graph. + GraphId = Env.NNGraph.size() - 1; + return ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + Env.NNContext.emplace_back(Env.NNGraph[GraphId]); + + ContextId = Env.NNContext.size() - 1; + return ErrNo::Success; +} + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + if (Index >= CxtRef.TorchInputs.size()) { + CxtRef.TorchInputs.resize(Index + 1); + } + if (Tensor.RType != TensorType::F32) { + spdlog::error( + "[WASI-NN] Only F32 inputs and outputs are supported for now."); + return ErrNo::InvalidArgument; + } + auto Options = + torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); + std::vector Dims; + for (size_t I = 0; I < Tensor.Dimension.size(); I++) { + Dims.push_back(static_cast(Tensor.Dimension[I])); + } + torch::Tensor InTensor = + torch::from_blob(reinterpret_cast(Tensor.Tensor.data()), Dims, + Options) + .to(CxtRef.GraphRef.TorchDevice); + + CxtRef.TorchInputs[Index] = InTensor.clone(); + return ErrNo::Success; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + if (CxtRef.TorchOutputs.size() <= Index) { + spdlog::error( + "[WASI-NN] The output index {} exceeds the outputs number {}.", Index, + CxtRef.TorchOutputs.size()); + return ErrNo::InvalidArgument; + } + torch::Tensor OutTensor = + CxtRef.TorchOutputs[Index].to(at::kCPU).toType(torch::kFloat32); + float *TensorBuffer = OutTensor.data_ptr(); + + size_t BlobSize = 1; + for (auto I : OutTensor.sizes()) { + BlobSize *= I; + } + uint32_t BytesToWrite = + std::min(static_cast(BlobSize * 4), OutBuffer.size()); + std::copy_n(reinterpret_cast(TensorBuffer), BytesToWrite, + OutBuffer.data()); + BytesWritten = BytesToWrite; + return ErrNo::Success; +} + +Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + if (CxtRef.TorchInputs.size() == 0) { + spdlog::error("[WASI-NN] Input is not set!"); + return ErrNo::InvalidArgument; + } + for (size_t I = 0; I < CxtRef.TorchInputs.size(); I++) { + torch::jit::IValue InTensor = CxtRef.TorchInputs[I]; + if (InTensor.isNone()) { + spdlog::error("[WASI-NN] Input [{}] is not set!", I); + return ErrNo::InvalidArgument; + } + } + torch::jit::IValue RawOutput = + CxtRef.GraphRef.TorchModel.forward(CxtRef.TorchInputs); + // TODO: more output type should be supported here + if (RawOutput.isTensorList()) { + auto OutTensors = RawOutput.toTensorVector(); + for (auto &OneOf : OutTensors) { + CxtRef.TorchOutputs.push_back(OneOf.clone()); + } + } else if (RawOutput.isTensor()) { + auto OutTensor = RawOutput.toTensor(); + CxtRef.TorchOutputs.push_back(OutTensor.clone()); + } else { + spdlog::error("[WASI-NN] PyTorch backend only supports output a tensor " + "or a list of tensor"); + return ErrNo::InvalidArgument; + } + return ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] PyTorch backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); + return ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WasiNNEnvironment &, Span>, Device, + uint32_t *) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WasiNNEnvironment &, uint32_t, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} + +#endif +} // namespace WasmEdge::Host::WASINN::PyTorch diff --git a/plugins/wasi_nn/torch.h b/plugins/wasi_nn/torch.h new file mode 100644 index 00000000..ec3ae71e --- /dev/null +++ b/plugins/wasi_nn/torch.h @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "types.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH +#include +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::PyTorch { + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH +struct Graph { + torch::jit::Module TorchModel; + torch::DeviceType TorchDevice = at::kCPU; +}; + +struct Context { +public: + Context(Graph &G) noexcept : GraphRef(G) {} + Graph &GraphRef; + std::vector TorchInputs; + std::vector TorchOutputs; +}; +#else +struct Graph {}; +struct Context { + Context(Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::PyTorch diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h new file mode 100644 index 00000000..54073a3f --- /dev/null +++ b/plugins/wasi_nn/types.h @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once +#include "common/log.h" +#include "common/span.h" +#include + +namespace WasmEdge::Host::WASINN { + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + InvalidEncoding = 2, // Invalid encoding. + MissingMemory = 3, // Caller module is missing a memory export. + Busy = 4, // Device or resource busy. + RuntimeError = 5, // Runtime Error. +}; + +enum class TensorType : uint8_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; + +enum class Device : uint32_t { CPU = 0, GPU = 1, TPU = 2 }; + +enum class Backend : uint8_t { + OpenVINO = 0, + ONNX = 1, + Tensorflow = 2, + PyTorch = 3, + TensorflowLite = 4 +}; + +#define FOR_EACH_BACKEND(F) \ + F(OpenVINO) F(ONNX) F(Tensorflow) F(PyTorch) F(TensorflowLite) + +struct TensorData { + Span Dimension; + WASINN::TensorType RType; + Span Tensor; +}; + +} // namespace WasmEdge::Host::WASINN + +template <> +struct fmt::formatter + : fmt::formatter { + fmt::format_context::iterator format(WasmEdge::Host::WASINN::TensorType RType, + fmt::format_context &Ctx) const { + using namespace std::literals; + std::string_view Name; + switch (RType) { + case WasmEdge::Host::WASINN::TensorType::F16: + Name = "F16"sv; + break; + case WasmEdge::Host::WASINN::TensorType::F32: + Name = "F32"sv; + break; + case WasmEdge::Host::WASINN::TensorType::U8: + Name = "U8"sv; + break; + case WasmEdge::Host::WASINN::TensorType::I32: + Name = "I32"sv; + break; + default: + Name = "Unknown"sv; + } + return fmt::formatter::format(Name, Ctx); + } +}; + +template <> +struct fmt::formatter + : fmt::formatter { + fmt::format_context::iterator format(WasmEdge::Host::WASINN::Device Target, + fmt::format_context &Ctx) const { + using namespace std::literals; + std::string_view Name; + switch (Target) { + case WasmEdge::Host::WASINN::Device::CPU: + Name = "CPU"sv; + break; + case WasmEdge::Host::WASINN::Device::GPU: + Name = "GPU"sv; + break; + case WasmEdge::Host::WASINN::Device::TPU: + Name = "TPU"sv; + break; + default: + Name = "Unknown"sv; + } + return fmt::formatter::format(Name, Ctx); + } +}; diff --git a/plugins/wasi_nn/wasinnbase.h b/plugins/wasi_nn/wasinnbase.h index b5daaf84..4568dc76 100644 --- a/plugins/wasi_nn/wasinnbase.h +++ b/plugins/wasi_nn/wasinnbase.h @@ -16,6 +16,10 @@ template class WasiNN : public Runtime::HostFunction { : Runtime::HostFunction(0), Env(HostEnv) {} protected: + static constexpr uint32_t castErrNo(WASINN::ErrNo E) noexcept { + return static_cast(E); + } + WASINN::WasiNNEnvironment &Env; }; diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index a03b591c..d89b8dda 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -8,168 +8,149 @@ #include #include -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO -#include -#endif -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH -#include -#endif - -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE -#include "tensorflow/lite/c/c_api.h" -#endif +#include "onnx.h" +#include "openvino.h" +#include "tf.h" +#include "tfl.h" +#include "torch.h" +#include "types.h" namespace WasmEdge { namespace Host { namespace WASINN { -enum class ErrNo : uint32_t { - Success = 0, // No error occurred. - InvalidArgument = 1, // Caller module passed an invalid argument. - InvalidEncoding = 2, // Invalid encoding. - MissingMemory = 3, // Caller module is missing a memory export. - Busy = 4, // Device or resource busy. - RuntimeError = 5, // Runtime Error. -}; +namespace detail { +template struct VariantIndex; -enum class TensorType : uint8_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; +template +struct VariantIndex> + : std::integral_constant {}; -enum class Device : uint32_t { CPU = 0, GPU = 1, TPU = 2 }; +template +struct VariantIndex> + : std::integral_constant< + std::size_t, VariantIndex>::value + 1> {}; -enum class Backend : uint8_t { - OpenVINO = 0, - ONNX = 1, - Tensorflow = 2, - PyTorch = 3, - TensorflowLite = 4 -}; +template +inline constexpr std::size_t VariantIndexV = VariantIndex::value; + +template struct BackendTrait; +#define EACH(B) \ + template <> struct BackendTrait { \ + using Graph = B::Graph; \ + using Context = B::Context; \ + }; +FOR_EACH_BACKEND(EACH) +#undef EACH + +template using BackendGraphT = typename BackendTrait::Graph; +template using BackendContextT = typename BackendTrait::Context; +} // namespace detail class Graph { public: Graph() = delete; - Graph(Backend BE) noexcept : GraphBackend(BE) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - OpenVINONetwork = nullptr; - OpenVINOExecNetwork = nullptr; - OpenVINOWeightBlob = nullptr; -#endif - } - ~Graph() noexcept { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - if (OpenVINONetwork) { - ie_network_free(&OpenVINONetwork); + Graph(Backend BE) noexcept : Impl(std::in_place_type_t()) { + switch (BE) { +#define EACH(B) \ + case Backend::B: \ + Impl.emplace(); \ + break; + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + __builtin_unreachable(); } - if (OpenVINOExecNetwork) { - ie_exec_network_free(&OpenVINOExecNetwork); - } - if (OpenVINOWeightBlob) { - ie_blob_free(&OpenVINOWeightBlob); - } - for (auto &I : OpenVINOInputNames) { - if (I) { - ie_network_name_free(&I); - } - } - for (auto &I : OpenVINOOutputNames) { - if (I) { - ie_network_name_free(&I); - } - } -#endif -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE - if (TFLiteMod) { - TfLiteModelDelete(TFLiteMod); + } + Backend getBackend() const noexcept { + using V = std::decay_t; + switch (Impl.index()) { +#define EACH(B) \ + case detail::VariantIndexV: \ + return Backend::B; + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + __builtin_unreachable(); } -#endif } - - Backend GraphBackend; -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - ie_network_t *OpenVINONetwork; - ie_executable_network_t *OpenVINOExecNetwork; - ie_blob_t *OpenVINOWeightBlob; - std::vector OpenVINOInputNames; - std::vector OpenVINOOutputNames; -#endif -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH - torch::jit::Module TorchModel; - torch::DeviceType TorchDevice = at::kCPU; -#endif -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE - TfLiteModel *TFLiteMod = nullptr; -#endif + template auto &get() noexcept { + return *std::get_if>(&Impl); + } + template const auto &get() const noexcept { + return *std::get_if>(&Impl); + } + template auto &get() noexcept { return *std::get_if(&Impl); } + template const auto &get() const noexcept { + return *std::get_if(&Impl); + } + std::variant< +#define EACH(B) B::Graph, + FOR_EACH_BACKEND(EACH) +#undef EACH + std::monostate> + Impl; }; class Context { public: Context() = delete; - - Context(Graph &G) noexcept : GraphRef(G) { - if (G.GraphBackend == Backend::OpenVINO) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - IEStatusCode Status = ie_exec_network_create_infer_request( - G.OpenVINOExecNetwork, &OpenVINOInferRequest); - if (Status != IEStatusCode::OK) { - OpenVINOInferRequest = nullptr; - spdlog::error("[WASI-NN] Unable to create infer request for OpenVINO"); - } -#endif + Context(Graph &G) noexcept : Impl(std::in_place_type_t()) { + switch (G.getBackend()) { +#define EACH(B) \ + case Backend::B: \ + Impl.emplace(G.get()); \ + break; + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + __builtin_unreachable(); } } - ~Context() noexcept { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - if (OpenVINOInferRequest) { - ie_infer_request_free(&OpenVINOInferRequest); - } -#endif -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE - if (TFLiteInterp) { - TfLiteInterpreterDelete(TFLiteInterp); + Backend getBackend() const noexcept { + using V = std::decay_t; + switch (Impl.index()) { +#define EACH(B) \ + case detail::VariantIndexV: \ + return Backend::B; + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + __builtin_unreachable(); } -#endif } - Graph &GraphRef; -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - ie_infer_request_t *OpenVINOInferRequest = nullptr; -#endif -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH - std::vector TorchInputs; - std::vector TorchOutputs; -#endif -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE - TfLiteInterpreter *TFLiteInterp = nullptr; -#endif + template auto &get() noexcept { + return *std::get_if>(&Impl); + } + template const auto &get() const noexcept { + return *std::get_if>(&Impl); + } + template auto &get() noexcept { return *std::get_if(&Impl); } + template const auto &get() const noexcept { + return *std::get_if(&Impl); + } + std::variant< +#define EACH(B) B::Context, + FOR_EACH_BACKEND(EACH) +#undef EACH + std::monostate> + Impl; }; -class WasiNNEnvironment { -public: +struct WasiNNEnvironment : +#define EACH(B) B::Environ, + FOR_EACH_BACKEND(EACH) +#undef EACH + std::monostate { WasiNNEnvironment() noexcept { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - if (ie_core_create("", &OpenVINOCore) != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Error happened when initializing OpenVINO core."); - } -#endif NNGraph.reserve(16U); NNContext.reserve(16U); } - ~WasiNNEnvironment() noexcept { - NNContext.clear(); - NNGraph.clear(); -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - if (OpenVINOCore) { - ie_core_free(&OpenVINOCore); - } -#endif - } std::vector NNGraph; std::vector NNContext; -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - ie_core_t *OpenVINOCore = nullptr; -#endif static Plugin::PluginRegister Register; }; diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index b744087f..c7f59968 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -3,55 +3,24 @@ #include "wasinnfunc.h" #include "common/log.h" +#include "wasinnenv.h" #include - -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO -#include - -#include -#endif - -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH -#include - -#include -#endif - -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE -#include "tensorflow/lite/c/c_api.h" -#include "tensorflow/lite/c/common.h" -#endif +#include namespace WasmEdge { namespace Host { namespace { -[[maybe_unused]] std::string findDevice(const WASINN::Device Target) { - std::string DeviceName; - switch (Target) { - case WASINN::Device::CPU: - DeviceName = "CPU"; - break; - case WASINN::Device::GPU: - DeviceName = "GPU"; - break; - case WASINN::Device::TPU: - DeviceName = "TPU"; - break; - default: - DeviceName = ""; - } - return DeviceName; +inline void reportUnknownBackend(WASINN::Backend B) noexcept { + spdlog::error("[WASI-NN] Unknown backend {}.", static_cast(B)); } - } // namespace -Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, - uint32_t BuilderPtr [[maybe_unused]], - uint32_t BuilderLen [[maybe_unused]], - uint32_t Encoding, uint32_t Target, - uint32_t GraphIdPtr [[maybe_unused]]) { +Expect +WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, + uint32_t BuilderLen, uint32_t RawEncoding, uint32_t Target, + uint32_t GraphIdPtr) { // Check memory instance from module. auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { @@ -61,302 +30,65 @@ Expect WasiNNLoad::body(const Runtime::CallingFrame &Frame, uint32_t *GraphId = MemInst->getPointer(GraphIdPtr, 1); if (unlikely(GraphId == nullptr)) { spdlog::error("[WASI-NN] Failed when accessing the return GraphID memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); + return WASINN::ErrNo::InvalidArgument; } - // Get and check the device name string. + // Get and check the device. const auto Device = static_cast(Target); - const std::string DeviceName = findDevice(Device); - if (unlikely(DeviceName.length() == 0 && - (Encoding != static_cast(WASINN::Backend::PyTorch) || - Device != WASINN::Device::GPU))) { - spdlog::error("[WASI-NN] Only support CPU target and Pytorch GPU target."); - return static_cast(WASINN::ErrNo::InvalidArgument); + switch (Device) { + case WASINN::Device::CPU: + case WASINN::Device::GPU: + case WASINN::Device::TPU: + break; + default: + spdlog::error("[WASI-NN] Unknown device {};", Target); + return WASINN::ErrNo::InvalidArgument; + } + spdlog::debug("[WASI-NN] Using device: {}", Device); + + // Builders' Layout: + // | builder-0 | builder-0 len | builder-1 | builder-1 len | ... + struct WasiBuilderPair { + uint32_t Ptr; + uint32_t Len; + }; + + auto WasiBuilders = Span( + MemInst->getPointer(BuilderPtr, BuilderLen), + BuilderLen); + if (unlikely(WasiBuilders.data() == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the GraphBuilder memory."); + return WASINN::ErrNo::InvalidArgument; } - spdlog::debug("[WASI-NN] Using device: {:s}", DeviceName); - - if (Encoding == static_cast(WASINN::Backend::OpenVINO)) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - // The OpenVINO core must be initialized in constructor. - if (unlikely(Env.OpenVINOCore == nullptr)) { - spdlog::error("[WASI-NN] OpenVINO core not initialized."); - return static_cast(WASINN::ErrNo::MissingMemory); - } - - // The graph builder length must be 2. - if (BuilderLen != 2) { - spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 2", - BuilderLen); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - - // Get the graph builders. - // GraphBuilders' Layout: - // | builder-0 | builder-0 len | builder-1 | builder-1 len | ... - uint32_t *GraphBuilders = - MemInst->getPointer(BuilderPtr, BuilderLen * 2); - if (unlikely(GraphBuilders == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the GraphBuilder memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - - // Get the XML and Weight raw buffer from memory instance. - // Builder-0: the XML string - // Builder-1: the Weight binary - uint32_t XMLStringLen = GraphBuilders[1]; - uint32_t WeightsBinLen = GraphBuilders[3]; - uint8_t *XMLPtr = - MemInst->getPointer(GraphBuilders[0], XMLStringLen); - uint8_t *BinPtr = - MemInst->getPointer(GraphBuilders[2], WeightsBinLen); - if (unlikely(XMLPtr == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the XML memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - if (unlikely(BinPtr == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Weight memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - - // Add a new graph. - Env.NNGraph.emplace_back(static_cast(Encoding)); - auto &Graph = Env.NNGraph.back(); - - // Create the weights blob memory. - tensor_desc_t WeightsDesc{ - layout_e::ANY, {1, {WeightsBinLen}}, precision_e::U8}; - IEStatusCode Status = - ie_blob_make_memory(&WeightsDesc, &(Graph.OpenVINOWeightBlob)); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to create the model's weight blob, error code: {}", - Status); - Env.NNGraph.pop_back(); - return static_cast(WASINN::ErrNo::Busy); - } - - // Copy the weights buffer to the blob. - ie_blob_buffer_t BlobBuffer; - Status = ie_blob_get_buffer(Graph.OpenVINOWeightBlob, &BlobBuffer); - if (unlikely(Status != IEStatusCode::OK)) { - spdlog::error( - "[WASI-NN] Unable to find the weight blob's buffer, error code: {}", - Status); - Env.NNGraph.pop_back(); - return static_cast(WASINN::ErrNo::MissingMemory); - } - std::copy_n(BinPtr, WeightsBinLen, - static_cast(BlobBuffer.buffer)); - - // Read network from memory. - Status = ie_core_read_network_from_memory( - Env.OpenVINOCore, XMLPtr, XMLStringLen, Graph.OpenVINOWeightBlob, - &(Graph.OpenVINONetwork)); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to read network from the XML and " - "Weights, error code: {}", - Status); - Env.NNGraph.pop_back(); - return static_cast(WASINN::ErrNo::Busy); - } - - // Get the network input and output size. - size_t NetworkInputSize = 0; - Status = - ie_network_get_inputs_number(Graph.OpenVINONetwork, &NetworkInputSize); - if (unlikely(Status != IEStatusCode::OK)) { - spdlog::error("[WASI-NN] Unable to get the inputs number from the " - "network, error code: {}", - Status); - Env.NNGraph.pop_back(); - return static_cast(WASINN::ErrNo::MissingMemory); - } - spdlog::debug("[WASI-NN] Got input size: {}", NetworkInputSize); - size_t NetworkOutputSize = 0; - Status = ie_network_get_outputs_number(Graph.OpenVINONetwork, - &NetworkOutputSize); - if (unlikely(Status != IEStatusCode::OK)) { - spdlog::error("[WASI-NN] Unable to get the outputs number from the " - "network, error code: {}", - Status); - Env.NNGraph.pop_back(); - return static_cast(WASINN::ErrNo::MissingMemory); - } - spdlog::debug("[WASI-NN] Got output size: {}", NetworkOutputSize); - - // Get and store the input and output names. - Graph.OpenVINOInputNames.resize(NetworkInputSize, nullptr); - for (size_t I = 0; I < NetworkInputSize; I++) { - Status = ie_network_get_input_name(Graph.OpenVINONetwork, I, - &(Graph.OpenVINOInputNames[I])); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to find input name correctly with " - "Index {}, error code: {}", - I, Status); - Env.NNGraph.pop_back(); - return static_cast(WASINN::ErrNo::MissingMemory); - } - spdlog::debug("[WASI-NN] Got input name: {}", - Graph.OpenVINOInputNames[I]); - } - Graph.OpenVINOOutputNames.resize(NetworkOutputSize, nullptr); - for (size_t I = 0; I < NetworkOutputSize; I++) { - Status = ie_network_get_output_name(Graph.OpenVINONetwork, I, - &(Graph.OpenVINOOutputNames[I])); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to find output name correctly with " - "Index {}, error code: {}", - I, Status); - Env.NNGraph.pop_back(); - return static_cast(WASINN::ErrNo::MissingMemory); - } - spdlog::debug("[WASI-NN] Got output name: {}", - Graph.OpenVINOOutputNames[I]); - } - - // Set the input layout. - // FIXME: this is a temporary workaround. We need a more eligant way to - // specify the layout in the long run. However, without this newer versions - // of OpenVINO will fail due to parameter mismatch. - for (size_t I = 0; I < NetworkInputSize; I++) { - // More layouts should be supported. - Status = ie_network_set_input_layout( - Graph.OpenVINONetwork, Graph.OpenVINOInputNames[I], layout_e::NHWC); - spdlog::debug("[WASI-NN] Setting [{}] to NHWC", - Graph.OpenVINOInputNames[I]); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to set input layout with the input " - "name {}, error code: {}", - Graph.OpenVINOInputNames[I], Status); - Env.NNGraph.pop_back(); - return static_cast(WASINN::ErrNo::MissingMemory); - } - } - - // Load network. - ie_config_t Config = {nullptr, nullptr, nullptr}; - Status = ie_core_load_network(Env.OpenVINOCore, Graph.OpenVINONetwork, - DeviceName.c_str(), &Config, - &(Graph.OpenVINOExecNetwork)); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to create executable Network, error code: {}", - Status); - Env.NNGraph.pop_back(); - return static_cast(WASINN::ErrNo::Busy); - } - - // Store the loaded graph. - *GraphId = Env.NNGraph.size() - 1; - - return static_cast(WASINN::ErrNo::Success); -#else - spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."); -#endif - } else if (Encoding == static_cast(WASINN::Backend::PyTorch)) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH - // The graph builder length must be 2. - if (BuilderLen != 1) { - spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1", - BuilderLen); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - uint32_t *GraphBuilders = - MemInst->getPointer(BuilderPtr, BuilderLen * 2); - if (unlikely(GraphBuilders == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the GraphBuilder memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - - uint32_t BinLen = GraphBuilders[1]; - uint8_t *BinPtr = MemInst->getPointer(GraphBuilders[0], BinLen); - if (unlikely(BinPtr == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Weight memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - // Add a new graph. - Env.NNGraph.emplace_back(static_cast(Encoding)); - auto &Graph = Env.NNGraph.back(); - // Setup Graph Device - if (Device == WASINN::Device::GPU) { - if (torch::cuda::is_available()) { - Graph.TorchDevice = at::kCUDA; - } else { - spdlog::error("[WASI-NN] Platform Cannot support GPU target."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - } - - std::string BinString((char *)BinPtr, BinLen); - std::stringstream BinRead; - BinRead.str(BinString); - - try { - Graph.TorchModel = torch::jit::load(BinRead); - Graph.TorchModel.to(Graph.TorchDevice); - } catch (const c10::Error &e) { - spdlog::error("[WASI-NN] Failed when load the TorchScript model."); - Env.NNGraph.pop_back(); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - // Store the loaded graph. - *GraphId = Env.NNGraph.size() - 1; - return static_cast(WASINN::ErrNo::Success); - -#else - spdlog::error("[WASI-NN] PyTorch backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); -#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH - } else if (Encoding == - static_cast(WASINN::Backend::TensorflowLite)) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE - // The graph builder length must be 1. - if (BuilderLen != 1) { - spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1", - BuilderLen); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - uint32_t *GraphBuilders = - MemInst->getPointer(BuilderPtr, BuilderLen * 2); - if (unlikely(GraphBuilders == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the GraphBuilder memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - uint32_t BinLen = GraphBuilders[1]; - char *BinPtr = MemInst->getPointer(GraphBuilders[0], BinLen); - if (unlikely(BinPtr == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Weight memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - // Add a new graph. - Env.NNGraph.emplace_back(static_cast(Encoding)); - auto &Graph = Env.NNGraph.back(); - Graph.TFLiteMod = TfLiteModelCreate(BinPtr, BinLen); - if (unlikely(Graph.TFLiteMod == nullptr)) { - spdlog::error("[WASI-NN] Cannot import TFLite model"); - Env.NNGraph.pop_back(); - return static_cast(WASINN::ErrNo::InvalidArgument); - } + std::vector> Builders; + Builders.reserve(BuilderLen); + for (size_t I = 0; I < WasiBuilders.size(); ++I) { + const auto &WasiBuilder = WasiBuilders[I]; + auto Builder = + MemInst->getPointer(WasiBuilder.Ptr, WasiBuilder.Len); + if (unlikely(Builder == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Builder[{}] memory.", + I); + return WASINN::ErrNo::InvalidArgument; + } + Builders.emplace_back(Builder, WasiBuilder.Len); + } - // Store the loaded graph. - *GraphId = Env.NNGraph.size() - 1; - return static_cast(WASINN::ErrNo::Success); -#else - spdlog::error( - "[WASI-NN] TensorflowLite backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."); -#endif - } else { - spdlog::error("[WASI-NN] Current backend is not supported."); - return static_cast(WASINN::ErrNo::InvalidEncoding); + switch (const auto Backend = static_cast(RawEncoding)) { +#define EACH(B) \ + case WASINN::Backend::B: \ + return WASINN::B::load(Env, Builders, Device, *GraphId); + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + reportUnknownBackend(Backend); + return WASINN::ErrNo::InvalidEncoding; } - return static_cast(WASINN::ErrNo::InvalidArgument); } -Expect WasiNNInitExecCtx::body(const Runtime::CallingFrame &Frame, - uint32_t GraphId, - uint32_t ContextPtr [[maybe_unused]]) { +Expect +WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t GraphId, uint32_t ContextPtr) { auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -364,90 +96,31 @@ Expect WasiNNInitExecCtx::body(const Runtime::CallingFrame &Frame, if (Env.NNGraph.size() <= GraphId) { spdlog::error("[WASI-NN] init_execution_context: Graph Id does not exist."); - return static_cast(WASINN::ErrNo::InvalidArgument); + return WASINN::ErrNo::InvalidArgument; } + // Check the return value: Context should be valid. uint32_t *Context = MemInst->getPointer(ContextPtr, 1); if (unlikely(Context == nullptr)) { spdlog::error("[WASI-NN] Failed when accessing the Context memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); + return WASINN::ErrNo::InvalidArgument; } - if (Env.NNGraph[GraphId].GraphBackend == WASINN::Backend::OpenVINO) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - // Check the network and the execution network with the graph ID. - if (Env.NNGraph[GraphId].OpenVINONetwork == nullptr || - Env.NNGraph[GraphId].OpenVINOExecNetwork == nullptr) { - spdlog::error("[WASI-NN] Model for Graph:{} is empty!", GraphId); - return static_cast(WASINN::ErrNo::MissingMemory); - } - - // Create context. - Env.NNContext.emplace_back(Env.NNGraph[GraphId]); - auto &NewContext = Env.NNContext.back(); - if (NewContext.OpenVINOInferRequest == nullptr) { - spdlog::error("[WASI-NN] Unable to create openvino context"); - Env.NNContext.pop_back(); - return static_cast(WASINN::ErrNo::Busy); - } - - *Context = Env.NNContext.size() - 1; - return static_cast(WASINN::ErrNo::Success); -#else - spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."); -#endif - } else if (Env.NNGraph[GraphId].GraphBackend == WASINN::Backend::PyTorch) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH - Env.NNContext.emplace_back(Env.NNGraph[GraphId]); - - *Context = Env.NNContext.size() - 1; - return static_cast(WASINN::ErrNo::Success); - -#else - spdlog::error("[WASI-NN] PyTorch backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); -#endif - } else if (Env.NNGraph[GraphId].GraphBackend == - WASINN::Backend::TensorflowLite) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE - // Check the network and the execution network with the graph ID. - if (Env.NNGraph[GraphId].TFLiteMod == nullptr) { - spdlog::error("[WASI-NN] Model for Graph:{} is missing!", GraphId); - return static_cast(WASINN::ErrNo::MissingMemory); - } - Env.NNContext.emplace_back(Env.NNGraph[GraphId]); - const auto &Graph = Env.NNGraph[GraphId]; - auto &NewContext = Env.NNContext.back(); - auto *TFLiteOps = TfLiteInterpreterOptionsCreate(); - TfLiteInterpreterOptionsSetNumThreads(TFLiteOps, 2); - NewContext.TFLiteInterp = - TfLiteInterpreterCreate(Graph.TFLiteMod, TFLiteOps); - TfLiteInterpreterOptionsDelete(TFLiteOps); - if (unlikely(NewContext.TFLiteInterp == nullptr)) { - spdlog::error("[WASI-NN] Cannot create TFLite interpreter."); - Env.NNContext.pop_back(); - return static_cast(WASINN::ErrNo::Busy); - } - TfLiteInterpreterAllocateTensors(NewContext.TFLiteInterp); - - *Context = Env.NNContext.size() - 1; - return static_cast(WASINN::ErrNo::Success); -#else - spdlog::error( - "[WASI-NN] TensorflowLite backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."); -#endif - } else { - spdlog::error("[WASI-NN] Current backend is not supported."); + switch (const auto Backend = Env.NNGraph[GraphId].getBackend()) { +#define EACH(B) \ + case WASINN::Backend::B: \ + return WASINN::B::initExecCtx(Env, GraphId, *Context); + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + reportUnknownBackend(Backend); + return WASINN::ErrNo::InvalidEncoding; } - return static_cast(WASINN::ErrNo::InvalidArgument); } -Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, - uint32_t Context, - uint32_t Index [[maybe_unused]], - uint32_t TensorPtr [[maybe_unused]]) { +Expect +WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, + uint32_t Index, uint32_t TensorPtr) { auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -455,316 +128,71 @@ Expect WasiNNSetInput::body(const Runtime::CallingFrame &Frame, if (Env.NNContext.size() <= Context) { spdlog::error("[WASI-NN] set_input: Execution Context does not exist."); - return static_cast(WASINN::ErrNo::InvalidArgument); + return WASINN::ErrNo::InvalidArgument; } - auto &CxtRef = Env.NNContext[Context]; - if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::OpenVINO) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - // Check the infer request and the network. - auto *Network = CxtRef.GraphRef.OpenVINONetwork; - if (Network == nullptr || CxtRef.OpenVINOInferRequest == nullptr) { - spdlog::error("[WASI-NN] The founded openvino session is empty"); - return static_cast(WASINN::ErrNo::MissingMemory); - } - - // Check the input index. - if (CxtRef.GraphRef.OpenVINOInputNames.size() <= Index) { - spdlog::error( - "[WASI-NN] The input index {} exceeds the inputs number {}.", Index, - CxtRef.GraphRef.OpenVINOInputNames.size()); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - char *InputName = CxtRef.GraphRef.OpenVINOInputNames[Index]; - - // Get the tensor. - // Tensor's Layout: - // | dim buf | dim buf len | rtype | data buf | data buf len | - uint32_t *Tensor = MemInst->getPointer(TensorPtr, 5); - if (unlikely(Tensor == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Tensor memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - uint32_t DimensionLen = Tensor[1]; - if (DimensionLen > 8) { - spdlog::error( - "[WASI-NN] Tensor dimension is out of range, expect it under 8-dim, " - "but got {}-dim.", - DimensionLen); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - uint32_t *DimensionBuf = - MemInst->getPointer(Tensor[0], DimensionLen); - if (unlikely(DimensionBuf == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - uint32_t TensorDataLen = Tensor[4]; - uint8_t *TensorDataBuf = - MemInst->getPointer(Tensor[3], TensorDataLen); - if (unlikely(TensorDataBuf == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - WASINN::TensorType RType = static_cast(Tensor[2]); - if (RType != WASINN::TensorType::F32) { - spdlog::error( - "[WASI-NN] Only F32 inputs and outputs are supported for now."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - - // Set the input resize algorithm. - // Mark the input as resizable by setting a resize algorithm. - // In this case we will be able to set an input blob of any shape to an - // infer request. Resizing and layout conversions are executed automatically - // when inferring. - IEStatusCode Status = ie_network_set_input_resize_algorithm( - Network, InputName, RESIZE_BILINEAR); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to set input resize correctly, error code: {}", - Status); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - - // Set the input layout. - // More layouts should be supported. - Status = ie_network_set_input_layout(Network, InputName, layout_e::NHWC); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to set input layout correctly, error code: {}", - Status); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - - // Set the input precision. - // More types should be supported. - Status = - ie_network_set_input_precision(Network, InputName, precision_e::FP32); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to set input precision correctly, error code: {}", - Status); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - - // Set the dimensions and the tensor description. - dimensions_t Dimens; - Dimens.ranks = DimensionLen; - for (size_t I = 0; I < Dimens.ranks; I++) { - Dimens.dims[I] = static_cast(DimensionBuf[I]); - } - tensor_desc_t TensorDesc = {layout_e::NHWC, Dimens, precision_e::FP32}; - - // Create the input blob memory. - ie_blob_t *InputBlob = nullptr; - Status = ie_blob_make_memory(&TensorDesc, &InputBlob); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to allocated input tensor correctly, " - "error code: {}", - Status); - return static_cast(WASINN::ErrNo::Busy); - } - - // Get the blob buffer size and compare with the tensor size. - int BlobSize; - Status = ie_blob_size(InputBlob, &BlobSize); - if (unlikely(Status != IEStatusCode::OK)) { - spdlog::error( - "[WASI-NN] Unable to get the input blob size, error code: {}", - Status); - return static_cast(WASINN::ErrNo::Busy); - } - if (unlikely(static_cast(BlobSize * 4) != TensorDataLen)) { - spdlog::error( - "[WASI-NN] Blob size {} and the Tensor size {} not matched.", - BlobSize * 4, TensorDataLen); - } - - // Copy the data into the input blob buffer. - ie_blob_buffer_t BlobBuffer; - Status = ie_blob_get_buffer(InputBlob, &BlobBuffer); - if (unlikely(Status != IEStatusCode::OK)) { - spdlog::error("[WASI-NN] Unable to find input tensor buffer"); - ie_blob_free(&InputBlob); - return static_cast(WASINN::ErrNo::MissingMemory); - } - std::copy_n(TensorDataBuf, TensorDataLen, - static_cast(BlobBuffer.buffer)); - - // Set input blob. - Status = ie_infer_request_set_blob(CxtRef.OpenVINOInferRequest, InputName, - InputBlob); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to set input tensor to model correctly, " - "error code: {}", - Status); - ie_blob_free(&InputBlob); - return static_cast(WASINN::ErrNo::Busy); - } - - ie_blob_free(&InputBlob); - - return static_cast(WASINN::ErrNo::Success); -#else - spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."); -#endif - } else if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::PyTorch) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH - if (Index >= CxtRef.TorchInputs.size()) { - CxtRef.TorchInputs.resize(Index + 1); - } - uint32_t *Tensor = MemInst->getPointer(TensorPtr, 5); - if (unlikely(Tensor == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Tensor memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - uint32_t DimensionLen = Tensor[1]; - uint32_t *DimensionBuf = - MemInst->getPointer(Tensor[0], DimensionLen); - if (unlikely(DimensionBuf == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - uint32_t TensorDataLen = Tensor[4]; - uint8_t *TensorDataBuf = - MemInst->getPointer(Tensor[3], TensorDataLen); - if (unlikely(TensorDataBuf == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - WASINN::TensorType RType = static_cast(Tensor[2]); - if (RType != WASINN::TensorType::F32) { - spdlog::error( - "[WASI-NN] Only F32 inputs and outputs are supported for now."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - auto Options = - torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); - std::vector Dims; - for (size_t I = 0; I < DimensionLen; I++) { - Dims.push_back(static_cast(DimensionBuf[I])); - } - torch::Tensor InTensor = - torch::from_blob(reinterpret_cast(TensorDataBuf), Dims, - Options) - .to(CxtRef.GraphRef.TorchDevice); - - CxtRef.TorchInputs[Index] = InTensor.clone(); - return static_cast(WASINN::ErrNo::Success); -#else - spdlog::error("[WASI-NN] PyTorch backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); -#endif - } else if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::TensorflowLite) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE - uint32_t InCnt = TfLiteInterpreterGetInputTensorCount(CxtRef.TFLiteInterp); - if (Index >= InCnt) { - spdlog::error("[WASI-NN] Invalid index id {} for the input, only {} " - "inputs are allowed", - Index, InCnt); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - uint32_t *Tensor = MemInst->getPointer(TensorPtr, 5); - if (unlikely(Tensor == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Tensor memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - uint32_t DimensionLen = Tensor[1]; - uint32_t *DimensionBuf = - MemInst->getPointer(Tensor[0], DimensionLen); - if (unlikely(DimensionBuf == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - uint32_t TensorDataLen = Tensor[4]; - uint8_t *TensorDataBuf = - MemInst->getPointer(Tensor[3], TensorDataLen); - if (unlikely(TensorDataBuf == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - - auto *HoldTensor = - TfLiteInterpreterGetInputTensor(CxtRef.TFLiteInterp, Index); - // Check the input data size. - const auto HoldTensorByteSize = TfLiteTensorByteSize(HoldTensor); - if (HoldTensorByteSize != static_cast(TensorDataLen)) { - spdlog::error("[WASI-NN] Expect tensor byte size {}, but got {}", - HoldTensorByteSize, TensorDataLen); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - // Check the input tensor dimensions. - const auto HoldTensorNumDims = TfLiteTensorNumDims(HoldTensor); - if (static_cast(HoldTensorNumDims) != DimensionLen) { - spdlog::error( - "[WASI-NN] Expect tensor number of dimensions {}, but got {}", - HoldTensorNumDims, DimensionLen); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - for (uint32_t I = 0; I < DimensionLen; I++) { - const auto HoldTensorDim = TfLiteTensorDim(HoldTensor, I); - if (static_cast(HoldTensorDim) != DimensionBuf[I]) { - spdlog::error("[WASI-NN] Expect tensor dimension[{}] = {}, but got {}", - I, HoldTensorDim, DimensionBuf[I]); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - } - // Check the input tensor type. - WASINN::TensorType RType = static_cast(Tensor[2]); - WASINN::TensorType LiteType; - switch (const auto Type = TfLiteTensorType(HoldTensor)) { - case TfLiteType::kTfLiteUInt8: - LiteType = WASINN::TensorType::U8; - break; - case TfLiteType::kTfLiteFloat16: - LiteType = WASINN::TensorType::F16; - break; - case TfLiteType::kTfLiteFloat32: - LiteType = WASINN::TensorType::F32; - break; - case TfLiteType::kTfLiteInt32: - LiteType = WASINN::TensorType::I32; - break; - default: - spdlog::error("[WASI-NN] Unsupported TFLite type: {}", - TfLiteTypeGetName(Type)); - return static_cast(WASINN::ErrNo::RuntimeError); - } - - if (unlikely(LiteType != RType)) { - spdlog::error("[WASI-NN] Expect tensor type {}, but got {}", - static_cast(LiteType), - static_cast(RType)); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - TfLiteStatus Stat = - TfLiteTensorCopyFromBuffer(HoldTensor, TensorDataBuf, TensorDataLen); - if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { - spdlog::error("[WASI-NN] Copy tensor memory failed"); - return static_cast(WASINN::ErrNo::Busy); - } - - return static_cast(WASINN::ErrNo::Success); + // Tensor's Layout: + // | dim buf | dim buf len | rtype | data buf | data buf len | + struct WasiTensorData { + uint32_t DimensionPtr; + uint32_t DimensionLen; + uint32_t RType; + uint32_t TensorPtr; + uint32_t TensorLen; + }; + // Get the tensor. + auto *WasiTensor = MemInst->getPointer(TensorPtr); + if (unlikely(WasiTensor == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Tensor memory."); + return WASINN::ErrNo::InvalidArgument; + } + WASINN::TensorData Tensor; + Tensor.Dimension = + Span(MemInst->getPointer(WasiTensor->DimensionPtr, + WasiTensor->DimensionLen), + WasiTensor->DimensionLen); + if (unlikely(Tensor.Dimension.data() == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."); + return WASINN::ErrNo::InvalidArgument; + } + Tensor.Tensor = + Span(MemInst->getPointer(WasiTensor->TensorPtr, + WasiTensor->TensorLen), + WasiTensor->TensorLen); + if (unlikely(Tensor.Tensor.data() == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."); + return WASINN::ErrNo::InvalidArgument; + } + switch (const auto RType = + static_cast(WasiTensor->RType)) { + case WASINN::TensorType::F16: + case WASINN::TensorType::F32: + case WASINN::TensorType::U8: + case WASINN::TensorType::I32: + Tensor.RType = RType; + break; + default: + spdlog::error("[WASI-NN] Unknown tensor type {}.", + static_cast(RType)); + return WASINN::ErrNo::InvalidArgument; + } -#else - spdlog::error( - "[WASI-NN] TensorflowLite backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."); -#endif - } else { - spdlog::error("[WASI-NN] Current backend is not supported."); + switch (const auto Backend = Env.NNContext[Context].getBackend()) { +#define EACH(B) \ + case WASINN::Backend::B: \ + return WASINN::B::setInput(Env, Context, Index, Tensor); + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + reportUnknownBackend(Backend); + return WASINN::ErrNo::InvalidEncoding; } - return static_cast(WASINN::ErrNo::InvalidArgument); } -Expect -WasiNNGetOuput::body(const Runtime::CallingFrame &Frame, uint32_t Context, - uint32_t Index [[maybe_unused]], - uint32_t OutBufferPtr [[maybe_unused]], - uint32_t OutBufferMaxSize [[maybe_unused]], - uint32_t BytesWrittenPtr [[maybe_unused]]) { +Expect +WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, + uint32_t Index, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -772,172 +200,36 @@ WasiNNGetOuput::body(const Runtime::CallingFrame &Frame, uint32_t Context, if (Env.NNContext.size() <= Context) { spdlog::error("[WASI-NN] get_output: Execution Context does not exist"); - return static_cast(WASINN::ErrNo::InvalidArgument); + return WASINN::ErrNo::InvalidArgument; } - auto &CxtRef = Env.NNContext[Context]; - if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::OpenVINO) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - auto *Network = CxtRef.GraphRef.OpenVINONetwork; - - // Check the output index. - if (CxtRef.GraphRef.OpenVINOOutputNames.size() <= Index) { - spdlog::error( - "[WASI-NN] The output index {} exceeds the outputs number {}.", Index, - CxtRef.GraphRef.OpenVINOOutputNames.size()); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - char *OutputName = CxtRef.GraphRef.OpenVINOOutputNames[Index]; - - // Set output precision. - IEStatusCode Status = - ie_network_set_output_precision(Network, OutputName, precision_e::FP32); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to set output precision correctly with Index:{}", - Index); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - - // Get output blob buffer. - ie_blob_t *OutputBlob = nullptr; - Status = ie_infer_request_get_blob(CxtRef.OpenVINOInferRequest, OutputName, - &OutputBlob); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to retrieve output tensor correctly", - Index); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - - // Get the blob size and copy the output buffer. - int BlobSize; - Status = ie_blob_size(OutputBlob, &BlobSize); - ie_blob_buffer_t BlobCBuffer; - Status = ie_blob_get_cbuffer(OutputBlob, &BlobCBuffer); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to retrieve output tensor correctly", - Index); - ie_blob_free(&OutputBlob); - return static_cast(WASINN::ErrNo::MissingMemory); - } - uint32_t BytesToWrite = - std::min(static_cast(BlobSize * 4), OutBufferMaxSize); - uint8_t *OutBuffer = - MemInst->getPointer(OutBufferPtr, BytesToWrite); - if (unlikely(OutBuffer == nullptr)) { - spdlog::error( - "[WASI-NN] Failed when accessing the Output Buffer memory."); - ie_blob_free(&OutputBlob); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - std::copy_n(static_cast(BlobCBuffer.cbuffer), BytesToWrite, - OutBuffer); - - // Write the bytes written result. - uint32_t *BytesWritten = - MemInst->getPointer(BytesWrittenPtr, 1); - if (unlikely(BytesWritten == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the BytesWritten memory."); - ie_blob_free(&OutputBlob); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - *BytesWritten = BytesToWrite; - - ie_blob_free(&OutputBlob); - - return static_cast(WASINN::ErrNo::Success); -#else - spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."); -#endif - } else if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::PyTorch) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH - if (CxtRef.TorchOutputs.size() <= Index) { - spdlog::error( - "[WASI-NN] The output index {} exceeds the outputs number {}.", Index, - CxtRef.TorchOutputs.size()); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - torch::Tensor OutTensor = - CxtRef.TorchOutputs[Index].to(at::kCPU).toType(torch::kFloat32); - float *TensorBuffer = OutTensor.data_ptr(); - - size_t BlobSize = 1; - for (auto I : OutTensor.sizes()) { - BlobSize *= I; - } - uint32_t BytesToWrite = - std::min(static_cast(BlobSize * 4), OutBufferMaxSize); - uint8_t *OutBuffer = - MemInst->getPointer(OutBufferPtr, BytesToWrite); - if (unlikely(OutBuffer == nullptr)) { - spdlog::error( - "[WASI-NN] Failed when accessing the Output Buffer memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - std::copy_n(reinterpret_cast(TensorBuffer), BytesToWrite, - OutBuffer); - uint32_t *BytesWritten = - MemInst->getPointer(BytesWrittenPtr, 1); - if (unlikely(BytesWritten == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the BytesWritten memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - *BytesWritten = BytesToWrite; - return static_cast(WASINN::ErrNo::Success); -#else - spdlog::error("[WASI-NN] PyTorch backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); -#endif - } else if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::TensorflowLite) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE - uint32_t OutCnt = - TfLiteInterpreterGetOutputTensorCount(CxtRef.TFLiteInterp); - if (Index >= OutCnt) { - spdlog::error("[WASI-NN] Invalid index id {} for the input, only {} " - "outputs are allowed", - Index, OutCnt); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - const TfLiteTensor *HoldTensor = - TfLiteInterpreterGetOutputTensor(CxtRef.TFLiteInterp, Index); - const uint32_t BytesToWrite = TfLiteTensorByteSize(HoldTensor); - // Check out buffer max size. - if (OutBufferMaxSize < BytesToWrite) { - spdlog::error("[WASI-NN] Expect out buffer max size {}, but got {}", - BytesToWrite, OutBufferMaxSize); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - uint8_t *OutBuffer = - MemInst->getPointer(OutBufferPtr, BytesToWrite); - if (unlikely(OutBuffer == nullptr)) { - spdlog::error( - "[WASI-NN] Failed when accessing the Output Buffer memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - TfLiteTensorCopyToBuffer(HoldTensor, OutBuffer, BytesToWrite); - uint32_t *BytesWritten = - MemInst->getPointer(BytesWrittenPtr, 1); - if (unlikely(BytesWritten == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the BytesWritten memory."); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - *BytesWritten = BytesToWrite; - return static_cast(WASINN::ErrNo::Success); + Span OutBuffer( + MemInst->getPointer(OutBufferPtr, OutBufferMaxSize), + OutBufferMaxSize); + if (unlikely(OutBuffer.data() == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the Output Buffer memory."); + return WASINN::ErrNo::InvalidArgument; + } + uint32_t *BytesWritten = MemInst->getPointer(BytesWrittenPtr); + if (unlikely(BytesWritten == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the BytesWritten memory."); + return WASINN::ErrNo::InvalidArgument; + } -#else - spdlog::error( - "[WASI-NN] Tensorflowlite backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."); -#endif - } else { - spdlog::error("[WASI-NN] Current backend is not supported."); + switch (const auto Backend = Env.NNContext[Context].getBackend()) { +#define EACH(B) \ + case WASINN::Backend::B: \ + return WASINN::B::getOutput(Env, Context, Index, OutBuffer, *BytesWritten); + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + reportUnknownBackend(Backend); + return WASINN::ErrNo::InvalidEncoding; } - return static_cast(WASINN::ErrNo::InvalidArgument); } -Expect WasiNNCompute::body(const Runtime::CallingFrame &Frame, - uint32_t Context) { +Expect +WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -945,81 +237,19 @@ Expect WasiNNCompute::body(const Runtime::CallingFrame &Frame, if (Env.NNContext.size() <= Context) { spdlog::error("[WASI-NN] compute: Execution Context does not exist."); - return static_cast(WASINN::ErrNo::InvalidArgument); + return WASINN::ErrNo::InvalidArgument; } - auto &CxtRef = Env.NNContext[Context]; - if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::OpenVINO) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO - IEStatusCode Status = ie_infer_request_infer(CxtRef.OpenVINOInferRequest); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to perform computation correctly, error code: {}", - Status); - return static_cast(WASINN::ErrNo::Busy); - } - return static_cast(WASINN::ErrNo::Success); -#else - spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."); -#endif - } else if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::PyTorch) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH - if (CxtRef.TorchInputs.size() == 0) { - spdlog::error("[WASI-NN] Input is not set!"); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - for (size_t I = 0; I < CxtRef.TorchInputs.size(); I++) { - torch::jit::IValue InTensor = CxtRef.TorchInputs[I]; - if (InTensor.isNone()) { - spdlog::error("[WASI-NN] Input [{}] is not set!", I); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - } - torch::jit::IValue RawOutput = - CxtRef.GraphRef.TorchModel.forward(CxtRef.TorchInputs); - // TODO: more output type should be supported here - if (RawOutput.isTensorList()) { - auto OutTensors = RawOutput.toTensorVector(); - for (auto &OneOf : OutTensors) { - CxtRef.TorchOutputs.push_back(OneOf.clone()); - } - } else if (RawOutput.isTensor()) { - auto OutTensor = RawOutput.toTensor(); - CxtRef.TorchOutputs.push_back(OutTensor.clone()); - } else { - spdlog::error("[WASI-NN] PyTorch backend only supports output a tensor " - "or a list of tensor"); - return static_cast(WASINN::ErrNo::InvalidArgument); - } - return static_cast(WASINN::ErrNo::Success); -#else - spdlog::error("[WASI-NN] PyTorch backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"PyTorch\" to build it."); -#endif - } else if (CxtRef.GraphRef.GraphBackend == WASINN::Backend::TensorflowLite) { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE - // Run session - if (unlikely(CxtRef.TFLiteInterp == nullptr)) { - spdlog::error("[WASI-NN] Tensorflow Lite context empty"); - return static_cast(WASINN::ErrNo::MissingMemory); - } - TfLiteStatus Stat = TfLiteInterpreterInvoke(CxtRef.TFLiteInterp); - if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { - spdlog::error("[WASI-NN] Invocation failed."); - return static_cast(WASINN::ErrNo::Busy); - } - return static_cast(WASINN::ErrNo::Success); -#else - spdlog::error( - "[WASI-NN] Tensorflowlite backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."); -#endif - } else { - spdlog::error("[WASI-NN] Current backend is not supported."); + switch (const auto Backend = Env.NNContext[Context].getBackend()) { +#define EACH(B) \ + case WASINN::Backend::B: \ + return WASINN::B::compute(Env, Context); + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + reportUnknownBackend(Backend); + return WASINN::ErrNo::InvalidEncoding; } - - return static_cast(WASINN::ErrNo::InvalidArgument); } } // namespace Host diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h index 2a1e7462..cd76d987 100644 --- a/plugins/wasi_nn/wasinnfunc.h +++ b/plugins/wasi_nn/wasinnfunc.h @@ -14,37 +14,76 @@ namespace Host { class WasiNNLoad : public WasiNN { public: WasiNNLoad(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} - Expect body(const Runtime::CallingFrame &, uint32_t BuilderPtr, + Expect body(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, uint32_t BuilderLen, uint32_t Encoding, uint32_t Target, - uint32_t GraphIdPtr); + uint32_t GraphIdPtr) { + return bodyImpl(Frame, BuilderPtr, BuilderLen, Encoding, Target, GraphIdPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &, + uint32_t BuilderPtr, uint32_t BuilderLen, + uint32_t Encoding, uint32_t Target, + uint32_t GraphIdPtr); }; class WasiNNInitExecCtx : public WasiNN { public: WasiNNInitExecCtx(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} - Expect body(const Runtime::CallingFrame &, uint32_t GraphId, - uint32_t ContextPtr); + Expect body(const Runtime::CallingFrame &Frame, uint32_t GraphId, + uint32_t ContextPtr) { + return bodyImpl(Frame, GraphId, ContextPtr).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t GraphId, uint32_t ContextPtr); }; class WasiNNSetInput : public WasiNN { public: WasiNNSetInput(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} - Expect body(const Runtime::CallingFrame &, uint32_t Context, - uint32_t Index, uint32_t TensorPtr); + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context, + uint32_t Index, uint32_t TensorPtr) { + return bodyImpl(Frame, Context, Index, TensorPtr).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context, uint32_t Index, + uint32_t TensorPtr); }; -class WasiNNGetOuput : public WasiNN { +class WasiNNGetOutput : public WasiNN { public: - WasiNNGetOuput(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} - Expect body(const Runtime::CallingFrame &, uint32_t Context, + WasiNNGetOutput(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context, uint32_t Index, uint32_t OutBufferPtr, - uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { + return bodyImpl(Frame, Context, Index, OutBufferPtr, OutBufferMaxSize, + BytesWrittenPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context, uint32_t Index, + uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr); }; class WasiNNCompute : public WasiNN { public: WasiNNCompute(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} - Expect body(const Runtime::CallingFrame &, uint32_t Context); + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context) { + return bodyImpl(Frame, Context).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context); }; } // namespace Host diff --git a/plugins/wasi_nn/wasinnmodule.cpp b/plugins/wasi_nn/wasinnmodule.cpp index 56cf5314..0a805d87 100644 --- a/plugins/wasi_nn/wasinnmodule.cpp +++ b/plugins/wasi_nn/wasinnmodule.cpp @@ -12,7 +12,7 @@ WasiNNModule::WasiNNModule() : ModuleInstance("wasi_ephemeral_nn") { addHostFunc("init_execution_context", std::make_unique(Env)); addHostFunc("set_input", std::make_unique(Env)); - addHostFunc("get_output", std::make_unique(Env)); + addHostFunc("get_output", std::make_unique(Env)); addHostFunc("compute", std::make_unique(Env)); } diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 39d2484a..a8308512 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -39,10 +39,10 @@ inline std::vector readEntireFile(const std::string &Path) { return {}; } Fin.seekg(0, std::ios::end); - std::vector Buf(static_cast(Fin.tellg())); + std::vector Buf(static_cast(Fin.tellg())); Fin.seekg(0, std::ios::beg); if (!Fin.read(reinterpret_cast(Buf.data()), - static_cast(Buf.size()))) { + static_cast(Buf.size()))) { return {}; } Fin.close(); @@ -143,7 +143,7 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); auto &HostFuncGetOutput = - dynamic_cast(FuncInst->getHostFunc()); + dynamic_cast(FuncInst->getHostFunc()); // Get the function "compute". FuncInst = NNMod->findFuncExports("compute"); EXPECT_NE(FuncInst, nullptr); @@ -544,7 +544,7 @@ TEST(WasiNNTest, PyTorchBackend) { EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); auto &HostFuncGetOutput = - dynamic_cast(FuncInst->getHostFunc()); + dynamic_cast(FuncInst->getHostFunc()); // Get the function "compute". FuncInst = NNMod->findFuncExports("compute"); EXPECT_NE(FuncInst, nullptr); @@ -910,7 +910,7 @@ TEST(WasiNNTest, TFLiteBackend) { EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); auto &HostFuncGetOutput = - dynamic_cast(FuncInst->getHostFunc()); + dynamic_cast(FuncInst->getHostFunc()); // Get the function "compute". FuncInst = NNMod->findFuncExports("compute"); EXPECT_NE(FuncInst, nullptr); @@ -1208,4 +1208,4 @@ TEST(WasiNNTest, TFLiteBackend) { } } } -#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE \ No newline at end of file +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE From 6ea6b68a88d4f12b41d72d2748cad046110534a6 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Fri, 19 May 2023 16:23:20 +0800 Subject: [PATCH 105/623] [WASI-NN] Make `Context`store `graph` id, instead of reference. Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_nn/onnx.h | 2 +- plugins/wasi_nn/openvino.cpp | 20 +++++++++++--------- plugins/wasi_nn/openvino.h | 6 +++--- plugins/wasi_nn/tf.h | 2 +- plugins/wasi_nn/tfl.cpp | 6 +++--- plugins/wasi_nn/tfl.h | 6 +++--- plugins/wasi_nn/torch.cpp | 8 +++++--- plugins/wasi_nn/torch.h | 6 +++--- plugins/wasi_nn/wasinnenv.h | 5 +++-- test/plugins/wasi_nn/wasi_nn.cpp | 6 +++--- 10 files changed, 36 insertions(+), 31 deletions(-) diff --git a/plugins/wasi_nn/onnx.h b/plugins/wasi_nn/onnx.h index f27c7157..e5aaff1a 100644 --- a/plugins/wasi_nn/onnx.h +++ b/plugins/wasi_nn/onnx.h @@ -13,7 +13,7 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::ONNX { struct Graph {}; struct Context { - Context(Graph &) noexcept {} + Context(size_t, Graph &) noexcept {} }; struct Environ {}; diff --git a/plugins/wasi_nn/openvino.cpp b/plugins/wasi_nn/openvino.cpp index afb29632..dcdf8eda 100644 --- a/plugins/wasi_nn/openvino.cpp +++ b/plugins/wasi_nn/openvino.cpp @@ -176,7 +176,7 @@ Expect initExecCtx(WASINN::WasiNNEnvironment &Env, } // Create context. - Env.NNContext.emplace_back(Env.NNGraph[GraphId]); + Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); auto &CtxRef = Env.NNContext.back().get(); if (CtxRef.OpenVINOInferRequest == nullptr) { spdlog::error("[WASI-NN] Unable to create openvino context"); @@ -192,20 +192,21 @@ Expect setInput(WASINN::WasiNNEnvironment &Env, uint32_t ContextId, uint32_t Index, const TensorData &Tensor) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); // Check the infer request and the network. - auto *Network = CxtRef.GraphRef.OpenVINONetwork; + auto *Network = GraphRef.OpenVINONetwork; if (Network == nullptr || CxtRef.OpenVINOInferRequest == nullptr) { spdlog::error("[WASI-NN] The founded openvino session is empty"); return WASINN::ErrNo::MissingMemory; } // Check the input index. - if (CxtRef.GraphRef.OpenVINOInputNames.size() <= Index) { + if (GraphRef.OpenVINOInputNames.size() <= Index) { spdlog::error("[WASI-NN] The input index {} exceeds the inputs number {}.", - Index, CxtRef.GraphRef.OpenVINOInputNames.size()); + Index, GraphRef.OpenVINOInputNames.size()); return WASINN::ErrNo::InvalidArgument; } - char *InputName = CxtRef.GraphRef.OpenVINOInputNames[Index]; + char *InputName = GraphRef.OpenVINOInputNames[Index]; if (Tensor.Dimension.size() > 8) { spdlog::error( @@ -318,16 +319,17 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, Span OutBuffer, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); - auto *Network = CxtRef.GraphRef.OpenVINONetwork; + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + auto *Network = GraphRef.OpenVINONetwork; // Check the output index. - if (CxtRef.GraphRef.OpenVINOOutputNames.size() <= Index) { + if (GraphRef.OpenVINOOutputNames.size() <= Index) { spdlog::error( "[WASI-NN] The output index {} exceeds the outputs number {}.", Index, - CxtRef.GraphRef.OpenVINOOutputNames.size()); + GraphRef.OpenVINOOutputNames.size()); return WASINN::ErrNo::InvalidArgument; } - char *OutputName = CxtRef.GraphRef.OpenVINOOutputNames[Index]; + char *OutputName = GraphRef.OpenVINOOutputNames[Index]; // Set output precision. IEStatusCode Status = diff --git a/plugins/wasi_nn/openvino.h b/plugins/wasi_nn/openvino.h index 83821a7e..9a1417a4 100644 --- a/plugins/wasi_nn/openvino.h +++ b/plugins/wasi_nn/openvino.h @@ -104,7 +104,7 @@ struct Graph { }; struct Context { - Context(Graph &G) noexcept : GraphRef(G) { + Context(size_t GId, Graph &G) noexcept : GraphId(GId) { IEStatusCode Status = ie_exec_network_create_infer_request( G.OpenVINOExecNetwork, &OpenVINOInferRequest); if (Status != IEStatusCode::OK) { @@ -117,7 +117,7 @@ struct Context { ie_infer_request_free(&OpenVINOInferRequest); } } - Graph &GraphRef; + size_t GraphId; ie_infer_request_t *OpenVINOInferRequest = nullptr; }; @@ -138,7 +138,7 @@ struct Environ { #else struct Graph {}; struct Context { - Context(Graph &) noexcept {} + Context(size_t, Graph &) noexcept {} }; struct Environ {}; #endif diff --git a/plugins/wasi_nn/tf.h b/plugins/wasi_nn/tf.h index b35c904b..20e38b59 100644 --- a/plugins/wasi_nn/tf.h +++ b/plugins/wasi_nn/tf.h @@ -13,7 +13,7 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::Tensorflow { struct Graph {}; struct Context { - Context(Graph &) noexcept {} + Context(size_t, Graph &) noexcept {} }; struct Environ {}; diff --git a/plugins/wasi_nn/tfl.cpp b/plugins/wasi_nn/tfl.cpp index 73450c57..acbf4f96 100644 --- a/plugins/wasi_nn/tfl.cpp +++ b/plugins/wasi_nn/tfl.cpp @@ -50,12 +50,12 @@ Expect initExecCtx(WASINN::WasiNNEnvironment &Env, } // Create context. - Env.NNContext.emplace_back(Env.NNGraph[GraphId]); + Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); auto &CxtRef = Env.NNContext.back().get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); auto *TFLiteOps = TfLiteInterpreterOptionsCreate(); TfLiteInterpreterOptionsSetNumThreads(TFLiteOps, 2); - CxtRef.TFLiteInterp = - TfLiteInterpreterCreate(CxtRef.GraphRef.TFLiteMod, TFLiteOps); + CxtRef.TFLiteInterp = TfLiteInterpreterCreate(GraphRef.TFLiteMod, TFLiteOps); TfLiteInterpreterOptionsDelete(TFLiteOps); if (unlikely(CxtRef.TFLiteInterp == nullptr)) { spdlog::error("[WASI-NN] Cannot create TFLite interpreter."); diff --git a/plugins/wasi_nn/tfl.h b/plugins/wasi_nn/tfl.h index 21d637f0..7e2bff4e 100644 --- a/plugins/wasi_nn/tfl.h +++ b/plugins/wasi_nn/tfl.h @@ -29,19 +29,19 @@ struct Graph { struct Context { public: - Context(Graph &G) noexcept : GraphRef(G) {} + Context(size_t GId, Graph &) noexcept : GraphId(GId) {} ~Context() noexcept { if (TFLiteInterp) { TfLiteInterpreterDelete(TFLiteInterp); } } - Graph &GraphRef; + size_t GraphId; TfLiteInterpreter *TFLiteInterp = nullptr; }; #else struct Graph {}; struct Context { - Context(Graph &) noexcept {} + Context(size_t, Graph &) noexcept {} }; #endif diff --git a/plugins/wasi_nn/torch.cpp b/plugins/wasi_nn/torch.cpp index 9ada3547..fba28c81 100644 --- a/plugins/wasi_nn/torch.cpp +++ b/plugins/wasi_nn/torch.cpp @@ -56,7 +56,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { - Env.NNContext.emplace_back(Env.NNGraph[GraphId]); + Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); ContextId = Env.NNContext.size() - 1; return ErrNo::Success; @@ -79,10 +79,11 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, for (size_t I = 0; I < Tensor.Dimension.size(); I++) { Dims.push_back(static_cast(Tensor.Dimension[I])); } + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); torch::Tensor InTensor = torch::from_blob(reinterpret_cast(Tensor.Tensor.data()), Dims, Options) - .to(CxtRef.GraphRef.TorchDevice); + .to(GraphRef.TorchDevice); CxtRef.TorchInputs[Index] = InTensor.clone(); return ErrNo::Success; @@ -127,8 +128,9 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { return ErrNo::InvalidArgument; } } + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); torch::jit::IValue RawOutput = - CxtRef.GraphRef.TorchModel.forward(CxtRef.TorchInputs); + GraphRef.TorchModel.forward(CxtRef.TorchInputs); // TODO: more output type should be supported here if (RawOutput.isTensorList()) { auto OutTensors = RawOutput.toTensorVector(); diff --git a/plugins/wasi_nn/torch.h b/plugins/wasi_nn/torch.h index ec3ae71e..8961efc8 100644 --- a/plugins/wasi_nn/torch.h +++ b/plugins/wasi_nn/torch.h @@ -25,15 +25,15 @@ struct Graph { struct Context { public: - Context(Graph &G) noexcept : GraphRef(G) {} - Graph &GraphRef; + Context(size_t GId, Graph &) noexcept : GraphId(GId) {} + size_t GraphId; std::vector TorchInputs; std::vector TorchOutputs; }; #else struct Graph {}; struct Context { - Context(Graph &) noexcept {} + Context(size_t, Graph &) noexcept {} }; #endif diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index d89b8dda..bdc792f7 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -95,11 +95,12 @@ class Graph { class Context { public: Context() = delete; - Context(Graph &G) noexcept : Impl(std::in_place_type_t()) { + Context(size_t GId, Graph &G) noexcept + : Impl(std::in_place_type_t()) { switch (G.getBackend()) { #define EACH(B) \ case Backend::B: \ - Impl.emplace(G.get()); \ + Impl.emplace(GId, G.get()); \ break; FOR_EACH_BACKEND(EACH) #undef EACH diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index a8308512..506c219a 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -326,7 +326,7 @@ TEST(WasiNNTest, OpenVINOBackend) { writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); // Swap to the tmp. env. - NNContextTmp.emplace_back(NNGraphTmp[0]); + NNContextTmp.emplace_back(0, NNGraphTmp[0]); NNGraphTmp.swap(NNMod->getEnv().NNGraph); NNContextTmp.swap(NNMod->getEnv().NNContext); // Test: set_input -- context id exceeds. @@ -727,7 +727,7 @@ TEST(WasiNNTest, PyTorchBackend) { static_cast(ErrNo::InvalidArgument)); } - NNContextTmp.emplace_back(NNGraphTmp[0]); + NNContextTmp.emplace_back(0, NNGraphTmp[0]); // Test: set_input -- tensor type not FP32. BuilderPtr = SetInputEntryPtr; @@ -1100,7 +1100,7 @@ TEST(WasiNNTest, TFLiteBackend) { static_cast(ErrNo::InvalidArgument)); } - NNContextTmp.emplace_back(NNGraphTmp[0]); + NNContextTmp.emplace_back(0, NNGraphTmp[0]); // Test: set_input -- set input successfully. BuilderPtr = SetInputEntryPtr; From 69fe04141760a064ac94c7b826a54a527506e8cf Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Fri, 26 May 2023 18:10:11 +0800 Subject: [PATCH 106/623] [Misc] Add `getSpan` `getStringView` method in `MemoryInstance` * Remove `getPointer` with length argument. * Change all relative calls to `getPointer` to `getSpan` or `getStringView` * Change WASI `path_symlink` error reporting behaviors on invalid string pointers Signed-off-by: Shen-Ta Hsieh --- .../wasi_crypto/asymmetric_common/func.cpp | 88 ++++------ plugins/wasi_crypto/common/func.cpp | 47 +++-- plugins/wasi_crypto/kx/func.cpp | 7 +- plugins/wasi_crypto/signatures/func.cpp | 30 ++-- plugins/wasi_crypto/symmetric/func.cpp | 160 ++++++++---------- plugins/wasi_crypto/utils/hostfunction.h | 8 + plugins/wasi_nn/wasinnfunc.cpp | 37 ++-- plugins/wasm_bpf/func-bpf-buffer-poll.cpp | 8 +- plugins/wasm_bpf/func-bpf-buffer-poll.h | 2 +- plugins/wasm_bpf/func-bpf-map-operate.cpp | 7 +- plugins/wasm_bpf/func-load-bpf-object.cpp | 6 +- plugins/wasm_bpf/util.cpp | 2 +- plugins/wasmedge_process/processfunc.cpp | 30 ++-- test/plugins/wasi_nn/wasi_nn.cpp | 17 +- 14 files changed, 207 insertions(+), 242 deletions(-) diff --git a/plugins/wasi_crypto/asymmetric_common/func.cpp b/plugins/wasi_crypto/asymmetric_common/func.cpp index a45443d0..6f70f521 100644 --- a/plugins/wasi_crypto/asymmetric_common/func.cpp +++ b/plugins/wasi_crypto/asymmetric_common/func.cpp @@ -19,14 +19,12 @@ Expect KeypairGenerate::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; - auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); - checkExist(Alg); + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); AsymmetricCommon::Algorithm WasiAlg; if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( - [Alg, WasiAlgLen](auto WasiAlgType) { - return tryFrom(WasiAlgType, {Alg, WasiAlgLen}); - }); + [Alg](auto WasiAlgType) { return tryFrom(WasiAlgType, Alg); }); unlikely(!Res)) { return Res.error(); } else { @@ -59,14 +57,12 @@ Expect KeypairImport::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; - auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); - checkExist(Alg); + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); AsymmetricCommon::Algorithm WasiAlg; if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( - [Alg, WasiAlgLen](auto WasiAlgType) { - return tryFrom(WasiAlgType, {Alg, WasiAlgLen}); - }); + [Alg](auto WasiAlgType) { return tryFrom(WasiAlgType, Alg); }); unlikely(!Res)) { return Res.error(); } else { @@ -74,9 +70,9 @@ Expect KeypairImport::body(const Runtime::CallingFrame &Frame, } const __wasi_size_t WasiEncodedLen = EncodedLen; - auto *const Encoded = - MemInst->getPointer(EncodedPtr, WasiEncodedLen); - checkExist(Encoded); + const auto Encoded = + MemInst->getSpan(EncodedPtr, WasiEncodedLen); + checkRangeExist(Encoded, WasiEncodedLen); const auto WasiEncoding = cast<__wasi_keypair_encoding_e_t>(Encoding); checkExist(WasiEncoding); @@ -84,8 +80,7 @@ Expect KeypairImport::body(const Runtime::CallingFrame &Frame, auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); checkExist(KpHandle); - if (auto Res = - Ctx.keypairImport(WasiAlg, {Encoded, WasiEncodedLen}, *WasiEncoding); + if (auto Res = Ctx.keypairImport(WasiAlg, Encoded, *WasiEncoding); unlikely(!Res)) { return Res.error(); } else { @@ -103,14 +98,12 @@ Expect KeypairGenerateManaged::body( checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; - auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); - checkExist(Alg); + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); AsymmetricCommon::Algorithm WasiAlg; if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( - [Alg, WasiAlgLen](auto WasiAlgType) { - return tryFrom(WasiAlgType, {Alg, WasiAlgLen}); - }); + [Alg](auto WasiAlgType) { return tryFrom(WasiAlgType, Alg); }); unlikely(!Res)) { return Res.error(); } else { @@ -143,11 +136,10 @@ Expect KeypairStoreManaged::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiKpIdMaxLen = KpIdMaxLen; - auto *const KpId = MemInst->getPointer(KpIdPtr, WasiKpIdMaxLen); - checkExist(KpId); + const auto KpId = MemInst->getSpan(KpIdPtr, WasiKpIdMaxLen); + checkRangeExist(KpId, WasiKpIdMaxLen); - if (auto Res = Ctx.keypairStoreManaged(SecretsManagerHandle, KpHandle, - {KpId, WasiKpIdMaxLen}); + if (auto Res = Ctx.keypairStoreManaged(SecretsManagerHandle, KpHandle, KpId); unlikely(!Res)) { return Res.error(); } @@ -186,8 +178,8 @@ Expect KeypairId::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiKpIdMaxLen = KpIdMaxLen; - auto *const KpId = MemInst->getPointer(KpIdPtr, WasiKpIdMaxLen); - checkExist(KpId); + const auto KpId = MemInst->getSpan(KpIdPtr, WasiKpIdMaxLen); + checkRangeExist(KpId, WasiKpIdMaxLen); auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); checkExist(Size); @@ -195,8 +187,7 @@ Expect KeypairId::body(const Runtime::CallingFrame &Frame, auto *const Version = MemInst->getPointer<__wasi_version_t *>(KpVersionPtr); checkExist(Version); - if (auto Res = Ctx.keypairId(KpHandle, {KpId, WasiKpIdMaxLen}); - unlikely(!Res)) { + if (auto Res = Ctx.keypairId(KpHandle, KpId); unlikely(!Res)) { return Res.error(); } else { auto [ResSize, ResVersion] = *Res; @@ -223,14 +214,13 @@ Expect KeypairFromId::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiKpIdLen = KpIdLen; - auto *const KpId = MemInst->getPointer(KpIdPtr, WasiKpIdLen); - checkExist(KpId); + const auto KpId = MemInst->getSpan(KpIdPtr, WasiKpIdLen); + checkRangeExist(KpId, WasiKpIdLen); auto *const KpHandle = MemInst->getPointer<__wasi_keypair_t *>(KpHandlePtr); checkExist(KpHandle); - if (auto Res = Ctx.keypairFromId(SecretsManagerHandle, {KpId, WasiKpIdLen}, - KpVersion); + if (auto Res = Ctx.keypairFromId(SecretsManagerHandle, KpId, KpVersion); unlikely(!Res)) { return Res.error(); } else { @@ -339,14 +329,12 @@ Expect PublickeyImport::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; - auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); - checkExist(Alg); + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); AsymmetricCommon::Algorithm WasiAlg; if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( - [Alg, WasiAlgLen](auto WasiAlgType) { - return tryFrom(WasiAlgType, {Alg, WasiAlgLen}); - }); + [Alg](auto WasiAlgType) { return tryFrom(WasiAlgType, Alg); }); unlikely(!Res)) { return Res.error(); } else { @@ -354,9 +342,9 @@ Expect PublickeyImport::body(const Runtime::CallingFrame &Frame, } const __wasi_size_t WasiEncodedLen = EncodedLen; - auto *const Encoded = - MemInst->getPointer(EncodedPtr, WasiEncodedLen); - checkExist(Encoded); + const auto Encoded = + MemInst->getSpan(EncodedPtr, WasiEncodedLen); + checkRangeExist(Encoded, WasiEncodedLen); __wasi_publickey_encoding_e_t WasiPkEncoding; if (auto Res = cast<__wasi_publickey_encoding_e_t>(Encoding); !Res) { @@ -368,8 +356,7 @@ Expect PublickeyImport::body(const Runtime::CallingFrame &Frame, auto *const PkHandle = MemInst->getPointer<__wasi_publickey_t *>(PkHandlePtr); checkExist(PkHandle); - if (auto Res = Ctx.publickeyImport(WasiAlg, {Encoded, WasiEncodedLen}, - WasiPkEncoding); + if (auto Res = Ctx.publickeyImport(WasiAlg, Encoded, WasiPkEncoding); unlikely(!Res)) { return Res.error(); } else { @@ -452,14 +439,12 @@ Expect SecretkeyImport::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; - auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); - checkExist(Alg); + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); AsymmetricCommon::Algorithm WasiAlg; if (auto Res = cast<__wasi_algorithm_type_e_t>(AlgType).and_then( - [Alg, WasiAlgLen](auto WasiAlgType) { - return tryFrom(WasiAlgType, {Alg, WasiAlgLen}); - }); + [Alg](auto WasiAlgType) { return tryFrom(WasiAlgType, Alg); }); unlikely(!Res)) { return Res.error(); } else { @@ -467,9 +452,9 @@ Expect SecretkeyImport::body(const Runtime::CallingFrame &Frame, } const __wasi_size_t WasiEncodedLen = EncodedLen; - auto *const Encoded = - MemInst->getPointer(EncodedPtr, WasiEncodedLen); - checkExist(Encoded); + const auto Encoded = + MemInst->getSpan(EncodedPtr, WasiEncodedLen); + checkRangeExist(Encoded, WasiEncodedLen); auto WasiEncoding = cast<__wasi_secretkey_encoding_e_t>(Encoding); if (!WasiEncoding) { @@ -479,8 +464,7 @@ Expect SecretkeyImport::body(const Runtime::CallingFrame &Frame, auto *const SkHandle = MemInst->getPointer<__wasi_secretkey_t *>(SkHandlePtr); checkExist(SkHandle); - if (auto Res = Ctx.secretkeyImport(WasiAlg, {Encoded, WasiEncodedLen}, - *WasiEncoding); + if (auto Res = Ctx.secretkeyImport(WasiAlg, Encoded, *WasiEncoding); unlikely(!Res)) { return Res.error(); } else { diff --git a/plugins/wasi_crypto/common/func.cpp b/plugins/wasi_crypto/common/func.cpp index 73676491..e8dcc981 100644 --- a/plugins/wasi_crypto/common/func.cpp +++ b/plugins/wasi_crypto/common/func.cpp @@ -37,14 +37,14 @@ Expect ArrayOutputPull::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiBufLen = BufLen; - auto *const Buf = MemInst->getPointer(BufPtr, WasiBufLen); - checkExist(Buf); + const auto Buf = MemInst->getSpan(BufPtr, WasiBufLen); + checkRangeExist(Buf, WasiBufLen); auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); checkExist(Size); - if (auto Res = Ctx.arrayOutputPull(ArrayOutputHandle, {Buf, WasiBufLen}) - .and_then(toWasiSize); + if (auto Res = + Ctx.arrayOutputPull(ArrayOutputHandle, Buf).and_then(toWasiSize); unlikely(!Res)) { return Res.error(); } else { @@ -99,17 +99,14 @@ Expect OptionsSet::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiNameLen = NameLen; - auto *const Name = MemInst->getPointer(NamePtr, WasiNameLen); - checkExist(Name); + const auto Name = MemInst->getStringView(NamePtr, WasiNameLen); + checkRangeExist(Name, WasiNameLen); const __wasi_size_t WasiValueLen = ValueLen; - auto *const Value = - MemInst->getPointer(ValuePtr, WasiValueLen); - checkExist(Value); + const auto Value = MemInst->getSpan(ValuePtr, WasiValueLen); + checkRangeExist(Value, WasiValueLen); - if (auto Res = Ctx.optionsSet(OptionsHandle, {Name, WasiNameLen}, - {Value, WasiValueLen}); - unlikely(!Res)) { + if (auto Res = Ctx.optionsSet(OptionsHandle, Name, Value); unlikely(!Res)) { return Res.error(); } @@ -124,10 +121,10 @@ Expect OptionsSetU64::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiNameLen = NameLen; - auto *const Name = MemInst->getPointer(NamePtr, WasiNameLen); - checkExist(Name); + const auto Name = MemInst->getStringView(NamePtr, WasiNameLen); + checkRangeExist(Name, WasiNameLen); - if (auto Res = Ctx.optionsSetU64(OptionsHandle, {Name, WasiNameLen}, Value); + if (auto Res = Ctx.optionsSetU64(OptionsHandle, Name, Value); unlikely(!Res)) { return Res.error(); } @@ -144,15 +141,14 @@ Expect OptionsSetGuestBuffer::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiNameLen = NameLen; - auto *const Name = MemInst->getPointer(NamePtr, WasiNameLen); - checkExist(Name); + const auto Name = MemInst->getStringView(NamePtr, WasiNameLen); + checkRangeExist(Name, WasiNameLen); const __wasi_size_t WasiBufLen = BufLen; - auto *const Buf = MemInst->getPointer(BufPtr, WasiBufLen); - checkExist(Buf); + const auto Buf = MemInst->getSpan(BufPtr, WasiBufLen); + checkRangeExist(Buf, WasiBufLen); - if (auto Res = Ctx.optionsSetGuestBuffer(OptionsHandle, {Name, WasiNameLen}, - {Buf, WasiBufLen}); + if (auto Res = Ctx.optionsSetGuestBuffer(OptionsHandle, Name, Buf); unlikely(!Res)) { return Res.error(); } @@ -204,12 +200,11 @@ SecretsManagerInvalidate::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiKeyIdLen = KeyIdLen; - auto *const KeyId = - MemInst->getPointer(KeyIdPtr, WasiKeyIdLen); - checkExist(KeyId); + const auto KeyId = MemInst->getSpan(KeyIdPtr, WasiKeyIdLen); + checkRangeExist(KeyId, WasiKeyIdLen); - if (auto Res = Ctx.secretsManagerInvalidate(SecretsManagerHandle, - {KeyId, WasiKeyIdLen}, Version); + if (auto Res = + Ctx.secretsManagerInvalidate(SecretsManagerHandle, KeyId, Version); unlikely(!Res)) { return Res.error(); } diff --git a/plugins/wasi_crypto/kx/func.cpp b/plugins/wasi_crypto/kx/func.cpp index d0533546..0cdd305d 100644 --- a/plugins/wasi_crypto/kx/func.cpp +++ b/plugins/wasi_crypto/kx/func.cpp @@ -59,16 +59,15 @@ Expect Decapsulate::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiEncapsulatedSecretLen = EncapsulatedSecretLen; - auto *const EncapsulatedSecret = MemInst->getPointer( + const auto EncapsulatedSecret = MemInst->getSpan( EncapsulatedSecretPtr, WasiEncapsulatedSecretLen); - checkExist(EncapsulatedSecret); + checkRangeExist(EncapsulatedSecret, WasiEncapsulatedSecretLen); auto *const Secret = MemInst->getPointer<__wasi_array_output_t *>(SecretPtr); checkExist(Secret); - if (auto Res = Ctx.kxDecapsulate( - SkHandle, {EncapsulatedSecret, WasiEncapsulatedSecretLen}); + if (auto Res = Ctx.kxDecapsulate(SkHandle, EncapsulatedSecret); unlikely(!Res)) { return Res.error(); } else { diff --git a/plugins/wasi_crypto/signatures/func.cpp b/plugins/wasi_crypto/signatures/func.cpp index f5e3069a..8a5b6510 100644 --- a/plugins/wasi_crypto/signatures/func.cpp +++ b/plugins/wasi_crypto/signatures/func.cpp @@ -44,20 +44,20 @@ Expect Import::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; - auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); - checkExist(Alg); + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); Algorithm WasiAlg; - if (auto Res = tryFrom({Alg, AlgLen}); unlikely(!Res)) { + if (auto Res = tryFrom(Alg); unlikely(!Res)) { return Res.error(); } else { WasiAlg = *Res; } const __wasi_size_t WasiEncodedLen = EncodedLen; - auto *const Encoded = - MemInst->getPointer(EncodedPtr, WasiEncodedLen); - checkExist(Encoded); + const auto Encoded = + MemInst->getSpan(EncodedPtr, WasiEncodedLen); + checkRangeExist(Encoded, WasiEncodedLen); __wasi_signature_encoding_e_t WasiEncoding; if (auto Res = cast<__wasi_signature_encoding_e_t>(Encoding); @@ -71,8 +71,7 @@ Expect Import::body(const Runtime::CallingFrame &Frame, MemInst->getPointer<__wasi_signature_t *>(SigHandlePtr); checkExist(SigHandle); - if (auto Res = - Ctx.signatureImport(WasiAlg, {Encoded, WasiEncodedLen}, WasiEncoding); + if (auto Res = Ctx.signatureImport(WasiAlg, Encoded, WasiEncoding); unlikely(!Res)) { return Res.error(); } else { @@ -108,12 +107,10 @@ Expect StateUpdate::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiInputSize = InputSize; - auto *const Input = - MemInst->getPointer(InputPtr, WasiInputSize); - checkExist(Input); + const auto Input = MemInst->getSpan(InputPtr, WasiInputSize); + checkRangeExist(Input, WasiInputSize); - if (auto Res = - Ctx.signatureStateUpdate(SigStateHandle, {Input, WasiInputSize}); + if (auto Res = Ctx.signatureStateUpdate(SigStateHandle, Input); unlikely(!Res)) { return Res.error(); } @@ -182,11 +179,10 @@ VerificationStateUpdate::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiInputSize = InputSize; - auto *const Input = MemInst->getPointer(InputPtr, InputSize); - checkExist(Input); + const auto Input = MemInst->getSpan(InputPtr, WasiInputSize); + checkRangeExist(Input, WasiInputSize); - if (auto Res = Ctx.signatureVerificationStateUpdate(SigStateHandle, - {Input, WasiInputSize}); + if (auto Res = Ctx.signatureVerificationStateUpdate(SigStateHandle, Input); unlikely(!Res)) { return Res.error(); } diff --git a/plugins/wasi_crypto/symmetric/func.cpp b/plugins/wasi_crypto/symmetric/func.cpp index 315b7c4a..47abba79 100644 --- a/plugins/wasi_crypto/symmetric/func.cpp +++ b/plugins/wasi_crypto/symmetric/func.cpp @@ -16,10 +16,11 @@ Expect KeyGenerate::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; - auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); - checkExist(Alg); + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + Algorithm WasiAlg; - if (auto Res = tryFrom({Alg, WasiAlgLen}); !Res) { + if (auto Res = tryFrom(Alg); !Res) { return Res.error(); } else { WasiAlg = *Res; @@ -51,11 +52,11 @@ Expect KeyImport::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; - auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); - checkExist(Alg); + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); Algorithm WasiAlg; - if (auto Res = tryFrom({Alg, WasiAlgLen}); !Res) { + if (auto Res = tryFrom(Alg); !Res) { return Res.error(); } else { WasiAlg = *Res; @@ -65,11 +66,10 @@ Expect KeyImport::body(const Runtime::CallingFrame &Frame, checkExist(Key); const __wasi_size_t WasiRawLen = RawLen; - auto *Raw = MemInst->getPointer(RawPtr, WasiRawLen); - checkExist(Raw); + const auto Raw = MemInst->getSpan(RawPtr, WasiRawLen); + checkRangeExist(Raw, WasiRawLen); - if (auto Res = Ctx.symmetricKeyImport(WasiAlg, {Raw, WasiRawLen}); - unlikely(!Res)) { + if (auto Res = Ctx.symmetricKeyImport(WasiAlg, Raw); unlikely(!Res)) { return Res.error(); } else { *Key = *Res; @@ -115,10 +115,11 @@ Expect KeyGenerateManaged::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; - auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); - checkExist(Alg); + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); + Algorithm WasiAlg; - if (auto Res = tryFrom({Alg, WasiAlgLen}); !Res) { + if (auto Res = tryFrom(Alg); !Res) { return Res.error(); } else { WasiAlg = *Res; @@ -151,11 +152,11 @@ Expect KeyStoreManaged::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiKeyIdMaxLen = KeyIdMaxLen; - auto *const KeyId = MemInst->getPointer(KeyIdPtr, WasiKeyIdMaxLen); - checkExist(KeyId); + const auto KeyId = MemInst->getSpan(KeyIdPtr, WasiKeyIdMaxLen); + checkRangeExist(KeyId, WasiKeyIdMaxLen); - if (auto Res = Ctx.symmetricKeyStoreManaged(SecretsManagerHandle, KeyHandle, - {KeyId, WasiKeyIdMaxLen}); + if (auto Res = + Ctx.symmetricKeyStoreManaged(SecretsManagerHandle, KeyHandle, KeyId); unlikely(!Res)) { return Res.error(); } @@ -194,8 +195,8 @@ Expect KeyId::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiKeyIdMaxLen = KeyIdMaxLen; - auto *const KeyId = MemInst->getPointer(KeyIdPtr, WasiKeyIdMaxLen); - checkExist(KeyId); + const auto KeyId = MemInst->getSpan(KeyIdPtr, WasiKeyIdMaxLen); + checkRangeExist(KeyId, WasiKeyIdMaxLen); auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); checkExist(Size); @@ -206,8 +207,7 @@ Expect KeyId::body(const Runtime::CallingFrame &Frame, return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; } - if (auto Res = Ctx.symmetricKeyId(KeyHandle, {KeyId, WasiKeyIdMaxLen}); - unlikely(!Res)) { + if (auto Res = Ctx.symmetricKeyId(KeyHandle, KeyId); unlikely(!Res)) { return Res.error(); } else { auto [SizeRes, VersionRes] = *Res; @@ -232,15 +232,15 @@ Expect KeyFromId::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiKeyIdLen = KeyIdLen; - auto *const KeyId = MemInst->getPointer(KeyIdPtr, WasiKeyIdLen); - checkExist(KeyId); + const auto KeyId = MemInst->getSpan(KeyIdPtr, WasiKeyIdLen); + checkRangeExist(KeyId, WasiKeyIdLen); auto *const KeyHandle = MemInst->getPointer<__wasi_symmetric_key_t *>(KeyHandlePtr); checkExist(KeyHandle); - if (auto Res = Ctx.symmetricKeyFromId(SecretsManagerHandle, - {KeyId, WasiKeyIdLen}, KeyVersion); + if (auto Res = + Ctx.symmetricKeyFromId(SecretsManagerHandle, KeyId, KeyVersion); unlikely(!Res)) { return Res.error(); } else { @@ -259,10 +259,10 @@ Expect StateOpen::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; - auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); - checkExist(Alg); + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); Algorithm WasiAlg; - if (auto Res = tryFrom({Alg, WasiAlgLen}); !Res) { + if (auto Res = tryFrom(Alg); !Res) { return Res.error(); } else { WasiAlg = *Res; @@ -320,18 +320,17 @@ Expect StateOptionsGet::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiNameLen = NameLen; - auto *const Name = MemInst->getPointer(NamePtr, WasiNameLen); - checkExist(Name); + const auto Name = MemInst->getStringView(NamePtr, WasiNameLen); + checkRangeExist(Name, WasiNameLen); const __wasi_size_t WasiValueLen = ValueLen; - auto *const Value = MemInst->getPointer(ValuePtr, ValueLen); - checkExist(Value); + const auto Value = MemInst->getSpan(ValuePtr, WasiValueLen); + checkRangeExist(Value, WasiValueLen); auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); checkExist(Size); - if (auto Res = Ctx.symmetricStateOptionsGet(StateHandle, {Name, WasiNameLen}, - {Value, WasiValueLen}) + if (auto Res = Ctx.symmetricStateOptionsGet(StateHandle, Name, Value) .and_then(toWasiSize); !Res) { return Res.error(); @@ -350,14 +349,13 @@ Expect StateOptionsGetU64::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiNameLen = NameLen; - auto *const Name = MemInst->getPointer(NamePtr, WasiNameLen); - checkExist(Name); + const auto Name = MemInst->getStringView(NamePtr, WasiNameLen); + checkRangeExist(Name, WasiNameLen); auto *const U64 = MemInst->getPointer(U64Ptr); checkExist(U64); - if (auto Res = - Ctx.symmetricStateOptionsGetU64(StateHandle, {Name, WasiNameLen}); + if (auto Res = Ctx.symmetricStateOptionsGetU64(StateHandle, Name); unlikely(!Res)) { return Res.error(); } else { @@ -383,11 +381,10 @@ Expect StateAbsorb::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiDataLen = DataLen; - auto *const Data = MemInst->getPointer(DataPtr, WasiDataLen); - checkExist(Data); + const auto Data = MemInst->getSpan(DataPtr, WasiDataLen); + checkRangeExist(Data, WasiDataLen); - if (auto Res = Ctx.symmetricStateAbsorb(StateHandle, {Data, WasiDataLen}); - unlikely(!Res)) { + if (auto Res = Ctx.symmetricStateAbsorb(StateHandle, Data); unlikely(!Res)) { return Res.error(); } @@ -401,11 +398,10 @@ Expect StateSqueeze::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiOutLen = OutLen; - auto *const Out = MemInst->getPointer(OutPtr, WasiOutLen); - checkExist(Out); + const auto Out = MemInst->getSpan(OutPtr, WasiOutLen); + checkRangeExist(Out, WasiOutLen); - if (auto Res = Ctx.symmetricStateSqueeze(StateHandle, {Out, WasiOutLen}); - unlikely(!Res)) { + if (auto Res = Ctx.symmetricStateSqueeze(StateHandle, Out); unlikely(!Res)) { return Res.error(); } @@ -439,10 +435,10 @@ Expect StateSqueezeKey::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiAlgLen = AlgLen; - auto *const Alg = MemInst->getPointer(AlgPtr, WasiAlgLen); - checkExist(Alg); + const auto Alg = MemInst->getStringView(AlgPtr, WasiAlgLen); + checkRangeExist(Alg, WasiAlgLen); Algorithm WasiAlg; - if (auto Res = tryFrom({Alg, WasiAlgLen}); !Res) { + if (auto Res = tryFrom(Alg); !Res) { return Res.error(); } else { WasiAlg = *Res; @@ -490,18 +486,17 @@ Expect StateEncrypt::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiOutLen = OutLen; - auto *const Out = MemInst->getPointer(OutPtr, WasiOutLen); - checkExist(Out); + const auto Out = MemInst->getSpan(OutPtr, WasiOutLen); + checkRangeExist(Out, WasiOutLen); const __wasi_size_t WasiDataLen = DataLen; - auto *Data = MemInst->getPointer(DataPtr, WasiDataLen); - checkExist(Data); + const auto Data = MemInst->getSpan(DataPtr, WasiDataLen); + checkRangeExist(Data, WasiDataLen); auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); checkExist(Size); - if (auto Res = Ctx.symmetricStateEncrypt(StateHandle, {Out, WasiOutLen}, - {Data, WasiDataLen}) + if (auto Res = Ctx.symmetricStateEncrypt(StateHandle, Out, Data) .and_then(toWasiSize); unlikely(!Res)) { return Res.error(); @@ -521,19 +516,18 @@ Expect StateEncryptDetached::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiOutLen = OutLen; - auto *const Out = MemInst->getPointer(OutPtr, WasiOutLen); - checkExist(Out); + const auto Out = MemInst->getSpan(OutPtr, WasiOutLen); + checkRangeExist(Out, WasiOutLen); const __wasi_size_t WasiDataLen = DataLen; - auto *const Data = MemInst->getPointer(DataPtr, WasiDataLen); - checkExist(Data); + const auto Data = MemInst->getSpan(DataPtr, WasiDataLen); + checkRangeExist(Data, WasiDataLen); auto *const TagHandle = MemInst->getPointer<__wasi_symmetric_tag_t *>(TagHandlePtr); checkExist(TagHandle); - if (auto Res = Ctx.symmetricStateEncryptDetached( - StateHandle, {Out, WasiOutLen}, {Data, WasiDataLen}); + if (auto Res = Ctx.symmetricStateEncryptDetached(StateHandle, Out, Data); unlikely(!Res)) { return Res.error(); } else { @@ -552,20 +546,19 @@ Expect StateDecrypt::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiOutLen = OutLen; - auto *const Out = MemInst->getPointer(OutPtr, WasiOutLen); - checkExist(Out); + const auto Out = MemInst->getSpan(OutPtr, WasiOutLen); + checkRangeExist(Out, WasiOutLen); const __wasi_size_t WasiDataLen = DataLen; - auto *const Data = MemInst->getPointer(DataPtr, WasiDataLen); - checkExist(Data); + const auto Data = MemInst->getSpan(DataPtr, WasiDataLen); + checkRangeExist(Data, WasiDataLen); auto *const Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); if (unlikely(Size == nullptr)) { return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; } - if (auto Res = Ctx.symmetricStateDecrypt(StateHandle, {Out, WasiOutLen}, - {Data, WasiDataLen}) + if (auto Res = Ctx.symmetricStateDecrypt(StateHandle, Out, Data) .and_then(toWasiSize); unlikely(!Res)) { return Res.error(); @@ -584,24 +577,23 @@ Expect StateDecryptDetached::body( checkExist(MemInst); const __wasi_size_t WasiOutLen = OutLen; - auto *const Out = MemInst->getPointer(OutPtr, WasiOutLen); - checkExist(Out); + const auto Out = MemInst->getSpan(OutPtr, WasiOutLen); + checkRangeExist(Out, WasiOutLen); const __wasi_size_t WasiDataLen = DataLen; - auto *const Data = MemInst->getPointer(DataPtr, WasiDataLen); - checkExist(Data); + const auto Data = MemInst->getSpan(DataPtr, WasiDataLen); + checkRangeExist(Data, WasiDataLen); const __wasi_size_t WasiRawTagLen = RawTagLen; - auto *RawTag = MemInst->getPointer(RawTagPtr, WasiRawTagLen); - checkExist(RawTag); + const auto RawTag = MemInst->getSpan(RawTagPtr, WasiRawTagLen); + checkRangeExist(RawTag, WasiRawTagLen); auto *Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); checkExist(Size); - if (auto Res = Ctx.symmetricStateDecryptDetached( - StateHandle, {Out, WasiOutLen}, {Data, WasiDataLen}, - {RawTag, WasiRawTagLen}) - .and_then(toWasiSize); + if (auto Res = + Ctx.symmetricStateDecryptDetached(StateHandle, Out, Data, RawTag) + .and_then(toWasiSize); unlikely(!Res)) { return Res.error(); } else { @@ -647,14 +639,13 @@ Expect TagPull::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiBufLen = BufLen; - auto *Buf = MemInst->getPointer(BufPtr, WasiBufLen); - checkExist(Buf); + const auto Buf = MemInst->getSpan(BufPtr, WasiBufLen); + checkRangeExist(Buf, WasiBufLen); auto *Size = MemInst->getPointer<__wasi_size_t *>(SizePtr); checkExist(Size); - if (auto Res = Ctx.symmetricTagPull(TagHandle, {Buf, WasiBufLen}) - .and_then(toWasiSize); + if (auto Res = Ctx.symmetricTagPull(TagHandle, Buf).and_then(toWasiSize); unlikely(!Res)) { return Res.error(); } else { @@ -671,11 +662,10 @@ Expect TagVerify::body(const Runtime::CallingFrame &Frame, checkExist(MemInst); const __wasi_size_t WasiRawTagLen = RawTagLen; - auto *RawTag = MemInst->getPointer(RawTagPtr, WasiRawTagLen); - checkExist(RawTag); + const auto RawTag = MemInst->getSpan(RawTagPtr, WasiRawTagLen); + checkRangeExist(RawTag, WasiRawTagLen); - if (auto Res = Ctx.symmetricTagVerify(TagHandle, {RawTag, RawTagLen}); - unlikely(!Res)) { + if (auto Res = Ctx.symmetricTagVerify(TagHandle, RawTag); unlikely(!Res)) { return Res.error(); } diff --git a/plugins/wasi_crypto/utils/hostfunction.h b/plugins/wasi_crypto/utils/hostfunction.h index 6d3fe175..bbe58859 100644 --- a/plugins/wasi_crypto/utils/hostfunction.h +++ b/plugins/wasi_crypto/utils/hostfunction.h @@ -158,6 +158,14 @@ tryFrom(std::string_view RawAlgStr) noexcept; } \ } while (0) +/// Check Span exist or return `_algorithm_failure`. +#define checkRangeExist(Expr, Size) \ + do { \ + if (unlikely((Expr).size() != (Size))) { \ + return __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE; \ + } \ + } while (0) + } // namespace WasiCrypto } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index c7f59968..c9e38340 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -27,7 +27,7 @@ WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, return Unexpect(ErrCode::Value::HostFuncError); } // Check the return value: GraphIdPtr should be valid. - uint32_t *GraphId = MemInst->getPointer(GraphIdPtr, 1); + uint32_t *GraphId = MemInst->getPointer(GraphIdPtr); if (unlikely(GraphId == nullptr)) { spdlog::error("[WASI-NN] Failed when accessing the return GraphID memory."); return WASINN::ErrNo::InvalidArgument; @@ -52,10 +52,9 @@ WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, uint32_t Len; }; - auto WasiBuilders = Span( - MemInst->getPointer(BuilderPtr, BuilderLen), - BuilderLen); - if (unlikely(WasiBuilders.data() == nullptr)) { + const auto WasiBuilders = + MemInst->getSpan(BuilderPtr, BuilderLen); + if (unlikely(WasiBuilders.size() != BuilderLen)) { spdlog::error("[WASI-NN] Failed when accessing the GraphBuilder memory."); return WASINN::ErrNo::InvalidArgument; } @@ -64,14 +63,13 @@ WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, Builders.reserve(BuilderLen); for (size_t I = 0; I < WasiBuilders.size(); ++I) { const auto &WasiBuilder = WasiBuilders[I]; - auto Builder = - MemInst->getPointer(WasiBuilder.Ptr, WasiBuilder.Len); - if (unlikely(Builder == nullptr)) { + auto Builder = MemInst->getSpan(WasiBuilder.Ptr, WasiBuilder.Len); + if (unlikely(Builder.size() != WasiBuilder.Len)) { spdlog::error("[WASI-NN] Failed when accessing the Builder[{}] memory.", I); return WASINN::ErrNo::InvalidArgument; } - Builders.emplace_back(Builder, WasiBuilder.Len); + Builders.emplace_back(Builder); } switch (const auto Backend = static_cast(RawEncoding)) { @@ -100,7 +98,7 @@ WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, } // Check the return value: Context should be valid. - uint32_t *Context = MemInst->getPointer(ContextPtr, 1); + uint32_t *Context = MemInst->getPointer(ContextPtr); if (unlikely(Context == nullptr)) { spdlog::error("[WASI-NN] Failed when accessing the Context memory."); return WASINN::ErrNo::InvalidArgument; @@ -147,19 +145,15 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, return WASINN::ErrNo::InvalidArgument; } WASINN::TensorData Tensor; - Tensor.Dimension = - Span(MemInst->getPointer(WasiTensor->DimensionPtr, - WasiTensor->DimensionLen), - WasiTensor->DimensionLen); - if (unlikely(Tensor.Dimension.data() == nullptr)) { + Tensor.Dimension = MemInst->getSpan(WasiTensor->DimensionPtr, + WasiTensor->DimensionLen); + if (unlikely(Tensor.Dimension.size() != WasiTensor->DimensionLen)) { spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."); return WASINN::ErrNo::InvalidArgument; } Tensor.Tensor = - Span(MemInst->getPointer(WasiTensor->TensorPtr, - WasiTensor->TensorLen), - WasiTensor->TensorLen); - if (unlikely(Tensor.Tensor.data() == nullptr)) { + MemInst->getSpan(WasiTensor->TensorPtr, WasiTensor->TensorLen); + if (unlikely(Tensor.Tensor.size() != WasiTensor->TensorLen)) { spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."); return WASINN::ErrNo::InvalidArgument; } @@ -203,9 +197,8 @@ WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, return WASINN::ErrNo::InvalidArgument; } - Span OutBuffer( - MemInst->getPointer(OutBufferPtr, OutBufferMaxSize), - OutBufferMaxSize); + const auto OutBuffer = + MemInst->getSpan(OutBufferPtr, OutBufferMaxSize); if (unlikely(OutBuffer.data() == nullptr)) { spdlog::error("[WASI-NN] Failed when accessing the Output Buffer memory."); return WASINN::ErrNo::InvalidArgument; diff --git a/plugins/wasm_bpf/func-bpf-buffer-poll.cpp b/plugins/wasm_bpf/func-bpf-buffer-poll.cpp index 7d627ea7..45b6e5ca 100644 --- a/plugins/wasm_bpf/func-bpf-buffer-poll.cpp +++ b/plugins/wasm_bpf/func-bpf-buffer-poll.cpp @@ -15,7 +15,7 @@ inline const auto *toCallFrameCxt(const Runtime::CallingFrame *Cxt) noexcept { Expect BpfBufferPoll::body(const Runtime::CallingFrame &Frame, handle_t program, int32_t fd, int32_t sample_func, uint32_t ctx, - uint32_t data, int32_t max_size, + uint32_t data, uint32_t max_size, int32_t timeout_ms) { auto c_ctx = toCallFrameCxt(&Frame); auto c_module = WasmEdge_CallingFrameGetModuleInstance(c_ctx); @@ -36,12 +36,12 @@ Expect BpfBufferPoll::body(const Runtime::CallingFrame &Frame, if (program_ptr == state->handles.end()) { return Unexpect(ErrCode::Value::HostFuncError); } - auto data_buf = memory->getPointer(data, max_size); - if (!data_buf) { + auto data_buf = memory->getSpan(data, max_size); + if (data_buf.size() != max_size) { return Unexpect(ErrCode::Value::HostFuncError); } return program_ptr->second->bpf_buffer_poll(c_executor, c_module, fd, - sample_func, ctx, data_buf, + sample_func, ctx, data_buf.data(), max_size, timeout_ms, data); } diff --git a/plugins/wasm_bpf/func-bpf-buffer-poll.h b/plugins/wasm_bpf/func-bpf-buffer-poll.h index 6a5366de..2ece1588 100644 --- a/plugins/wasm_bpf/func-bpf-buffer-poll.h +++ b/plugins/wasm_bpf/func-bpf-buffer-poll.h @@ -30,7 +30,7 @@ class BpfBufferPoll : public WasmEdge::Runtime::HostFunction { WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, handle_t program, int32_t fd, int32_t sample_func, uint32_t ctx, - uint32_t data, int32_t max_size, + uint32_t data, uint32_t max_size, int32_t timeout_ms); private: diff --git a/plugins/wasm_bpf/func-bpf-map-operate.cpp b/plugins/wasm_bpf/func-bpf-map-operate.cpp index d14f73b9..31ff90ee 100644 --- a/plugins/wasm_bpf/func-bpf-map-operate.cpp +++ b/plugins/wasm_bpf/func-bpf-map-operate.cpp @@ -12,9 +12,10 @@ namespace WasmEdge { namespace Host { #define ensure_memory_size(var, offset, size) \ - void *var = memory->getPointer(offset, size); \ - if (!var) \ - return Unexpect(ErrCode::Value::HostFuncError); + const auto var##_span = memory->getSpan(offset, size); \ + if (var##_span.size() != size) \ + return Unexpect(ErrCode::Value::HostFuncError); \ + const auto var = var##_span.data(); Expect BpfMapOperate::body(const WasmEdge::Runtime::CallingFrame &Frame, int32_t fd, int32_t cmd, uint32_t key, uint32_t value, diff --git a/plugins/wasm_bpf/func-load-bpf-object.cpp b/plugins/wasm_bpf/func-load-bpf-object.cpp index fe576fd3..99e5f72e 100644 --- a/plugins/wasm_bpf/func-load-bpf-object.cpp +++ b/plugins/wasm_bpf/func-load-bpf-object.cpp @@ -12,13 +12,13 @@ Expect LoadBpfObject::body(const Runtime::CallingFrame &Frame, if (unlikely(!memory)) { return Unexpect(ErrCode::Value::HostFuncError); } - char *const object_buffer = memory->getPointer(obj_buf, obj_buf_sz); - if (!object_buffer) { + const auto object_buffer = memory->getSpan(obj_buf, obj_buf_sz); + if (object_buffer.size() != obj_buf_sz) { return Unexpect(ErrCode::Value::HostFuncError); } auto program = std::make_unique(); int32_t res = - program->load_bpf_object(object_buffer, static_cast(obj_buf_sz)); + program->load_bpf_object(object_buffer.data(), object_buffer.size()); if (res < 0) return 0; auto key = reinterpret_cast(program.get()); diff --git a/plugins/wasm_bpf/util.cpp b/plugins/wasm_bpf/util.cpp index 28d383d9..9851c2ac 100644 --- a/plugins/wasm_bpf/util.cpp +++ b/plugins/wasm_bpf/util.cpp @@ -18,7 +18,7 @@ Expect read_c_str(Runtime::Instance::MemoryInstance *memory, tail++; } uint32_t len = tail - ptr + 1; - return memory->getPointer(ptr, len); + return memory->getSpan(ptr, len).data(); } } // namespace Host diff --git a/plugins/wasmedge_process/processfunc.cpp b/plugins/wasmedge_process/processfunc.cpp index ed69a03d..fe4e85b7 100644 --- a/plugins/wasmedge_process/processfunc.cpp +++ b/plugins/wasmedge_process/processfunc.cpp @@ -31,8 +31,8 @@ WasmEdgeProcessSetProgName::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - char *Buf = MemInst->getPointer(NamePtr); - std::copy_n(Buf, NameLen, std::back_inserter(Env.Name)); + const auto Buf = MemInst->getSpan(NamePtr, NameLen); + std::copy(Buf.begin(), Buf.end(), std::back_inserter(Env.Name)); return {}; } @@ -44,9 +44,9 @@ Expect WasmEdgeProcessAddArg::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - char *Buf = MemInst->getPointer(ArgPtr); + const auto Buf = MemInst->getSpan(ArgPtr, ArgLen); std::string NewArg; - std::copy_n(Buf, ArgLen, std::back_inserter(NewArg)); + std::copy(Buf.begin(), Buf.end(), std::back_inserter(NewArg)); Env.Args.push_back(std::move(NewArg)); return {}; } @@ -62,11 +62,11 @@ Expect WasmEdgeProcessAddEnv::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - char *EnvBuf = MemInst->getPointer(EnvNamePtr); - char *ValBuf = MemInst->getPointer(EnvValPtr); + const auto EnvBuf = MemInst->getSpan(EnvNamePtr, EnvNameLen); + const auto ValBuf = MemInst->getSpan(EnvValPtr, EnvValLen); std::string NewEnv, NewVal; - std::copy_n(EnvBuf, EnvNameLen, std::back_inserter(NewEnv)); - std::copy_n(ValBuf, EnvValLen, std::back_inserter(NewVal)); + std::copy(EnvBuf.begin(), EnvBuf.end(), std::back_inserter(NewEnv)); + std::copy(ValBuf.begin(), ValBuf.end(), std::back_inserter(NewVal)); Env.Envs.emplace(std::move(NewEnv), std::move(NewVal)); return {}; } @@ -79,9 +79,9 @@ Expect WasmEdgeProcessAddStdIn::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - uint8_t *Buf = MemInst->getPointer(BufPtr); + auto const Buf = MemInst->getSpan(BufPtr, BufLen); Env.StdIn.reserve(Env.StdIn.size() + BufLen); - std::copy_n(Buf, BufLen, std::back_inserter(Env.StdIn)); + std::copy(Buf.begin(), Buf.end(), std::back_inserter(Env.StdIn)); return {}; } @@ -320,8 +320,9 @@ Expect WasmEdgeProcessGetStdOut::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - char *Buf = MemInst->getPointer(BufPtr); - std::copy_n(Env.StdOut.begin(), Env.StdOut.size(), Buf); + const auto Buf = MemInst->getSpan(BufPtr, Env.StdOut.size()); + std::copy_n(Env.StdOut.begin(), std::min(Env.StdOut.size(), Buf.size()), + Buf.begin()); return {}; } @@ -338,8 +339,9 @@ Expect WasmEdgeProcessGetStdErr::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - char *Buf = MemInst->getPointer(BufPtr); - std::copy_n(Env.StdErr.begin(), Env.StdErr.size(), Buf); + const auto Buf = MemInst->getSpan(BufPtr, Env.StdErr.size()); + std::copy_n(Env.StdErr.begin(), std::min(Env.StdErr.size(), Buf.size()), + Buf.begin()); return {}; } diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 506c219a..f05ebb69 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -69,7 +69,7 @@ void writeFatPointer(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, } template -std::vector classSort(const std::vector &Array) { +std::vector classSort(WasmEdge::Span Array) { std::vector Indices(Array.size()); std::iota(Indices.begin(), Indices.end(), 0); std::sort(Indices.begin(), Indices.end(), @@ -472,9 +472,8 @@ TEST(WasiNNTest, OpenVINOBackend) { Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_EQ(*MemInst.getPointer(BuilderPtr), UINT32_C(4004)); - std::vector OutputClassification( - MemInst.getPointer(StorePtr, 1001) + 1, - MemInst.getPointer(StorePtr, 1001) + 1001); + const auto OutputClassification = + MemInst.getSpan(StorePtr, 1001).subspan(1); std::vector SortedIndex, CorrectClasses{963, 762, 909, 926, 567}; SortedIndex = classSort(OutputClassification); // The probability of class i is placed at buffer[i]. @@ -837,9 +836,8 @@ TEST(WasiNNTest, PyTorchBackend) { Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_EQ(*MemInst.getPointer(BuilderPtr), UINT32_C(4000)); - std::vector OutputClassification( - MemInst.getPointer(StorePtr, 1000), - MemInst.getPointer(StorePtr, 1000) + 1000); + const auto OutputClassification = + MemInst.getSpan(StorePtr, 1000); std::vector SortedIndex, CorrectClasses{954, 940, 951, 950, 953}; SortedIndex = classSort(OutputClassification); // The probability of class i is placed at buffer[i]. @@ -1194,9 +1192,8 @@ TEST(WasiNNTest, TFLiteBackend) { Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_EQ(*MemInst.getPointer(BuilderPtr), UINT32_C(965)); - std::vector OutputClassification( - MemInst.getPointer(StorePtr, 965), - MemInst.getPointer(StorePtr, 965) + 965); + const auto OutputClassification = + MemInst.getSpan(StorePtr, 965); std::vector SortedIndex, CorrectClasses{166, 158, 34, 778, 819}; // FIXME: classSort causing segmentation fault SortedIndex = classSort(OutputClassification); From 500811e3a8199330f65cabbbbd1ec667c51c2bb6 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 1 Jun 2023 21:25:19 +0800 Subject: [PATCH 107/623] [Plugin] Fix the incorrect function signature of internal WASI-NN implementation. Signed-off-by: YiYing He --- plugins/wasi_nn/torch.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/torch.cpp b/plugins/wasi_nn/torch.cpp index fba28c81..948ace8d 100644 --- a/plugins/wasi_nn/torch.cpp +++ b/plugins/wasi_nn/torch.cpp @@ -157,7 +157,7 @@ Expect reportBackendNotSupported() noexcept { } // namespace Expect load(WasiNNEnvironment &, Span>, Device, - uint32_t *) noexcept { + uint32_t &) noexcept { return reportBackendNotSupported(); } Expect initExecCtx(WasiNNEnvironment &, uint32_t, uint32_t &) noexcept { From fc71ef3389c6f2c31d822681f60d37cd876ac5ad Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Mon, 5 Jun 2023 10:57:32 +0800 Subject: [PATCH 108/623] [Docker] Update manylinux2014 to use devtoolset-11 * Remove self-build gcc * Change `PATH` to devtoolset-11 * Update llvm, cmake, ninja, zstd, boost Signed-off-by: Shen-Ta Hsieh --- utils/docker/Dockerfile.manylinux2014_aarch64 | 103 +++++++----------- utils/docker/Dockerfile.manylinux2014_x86_64 | 103 +++++++----------- utils/docker/SHA256SUM | 18 ++- utils/docker/build-manylinux.sh | 15 +-- 4 files changed, 92 insertions(+), 147 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index a2b01c38..6f30df72 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -7,76 +7,53 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM /root/ -ENV PATH /toolchain/bin:$PATH -ENV CC gcc -ENV CXX g++ -ENV CPPFLAGS -I/toolchain/include -ENV LDFLAGS -L/toolchain/lib64 -ENV PKG_CONFIG_PATH /toolchain/lib64/pkgconfig +ENV PATH /opt/rh/devtoolset-11/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH /opt/rh/devtoolset-11/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH /opt/rh/devtoolset-11/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV PKG_CONFIG_PATH /opt/rh/devtoolset-11/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build && \ - export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build dpkg centos-release-scl && \ + yum install -y devtoolset-11 && \ + export CPU=$(/opt/python/cp311-cp311/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ - export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ + export CFGFLAGS="--prefix=/opt/rh/devtoolset-11/root/usr --disable-shared --libdir=/opt/rh/devtoolset-11/root/usr/lib64" && \ curl -s -L -O --remote-name-all \ - https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ - https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ - https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ - https://libisl.sourceforge.io/isl-0.24.tar.xz \ - https://github.com/facebook/zstd/releases/download/v1.5.0/zstd-1.5.0.tar.gz \ - https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ - https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ - https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/llvm-12.0.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/lld-12.0.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/libunwind-12.0.0.src.tar.xz && \ + https://github.com/facebook/zstd/releases/download/v1.5.5/zstd-1.5.5.tar.gz \ + https://github.com/Kitware/CMake/releases/download/v3.26.4/cmake-3.26.4.tar.gz \ + https://github.com/ninja-build/ninja/archive/refs/tags/v1.11.1.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/llvm-16.0.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/lld-16.0.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/libunwind-16.0.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/cmake-16.0.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/third-party-16.0.5.src.tar.xz && \ sha256sum -c SHA256SUM && \ - xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ - xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ - gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ - xz -dc isl-0.24.tar.xz | tar -xf - && \ - gzip -dc zstd-1.5.0.tar.gz | tar -xf - && \ - xz -dc gcc-11.1.0.tar.xz | tar -xf - && \ - gzip -dc cmake-3.20.2.tar.gz | tar -xf - && \ - gzip -dc v1.10.2.tar.gz | tar -xf - && \ - xz -dc llvm-12.0.0.src.tar.xz | tar -xf - && \ - xz -dc lld-12.0.0.src.tar.xz | tar -xf - && \ - xz -dc libunwind-12.0.0.src.tar.xz | tar -xf - && \ - mkdir build && cd build && ../gmp-6.2.1/configure --build=aarch64-redhat-linux-gnu $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ - cd zstd-1.5.0 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ - mkdir build && cd build && ../gcc-11.1.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ - --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ - --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ - --disable-libmpx --disable-libsanitizer --disable-libunwind-exceptions \ - --disable-multilib --enable-__cxa_atexit --enable-gnu-indirect-function \ - --enable-gnu-unique-object --enable-initfini-array --enable-languages="c,c++,lto" \ - --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ - --with-default-libstdcxx-abi="gcc4-compatible" \ - --with-gcc-major-version-only --with-linker-hash-style="gnu" \ - --with-arch="armv8-a" && \ - make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ - echo -e "OUTPUT_FORMAT(elf64-aarch64)\nINPUT ( libstdc++.so.6.0.19 libstdc++.a )" \ - > /toolchain/lib64/libstdc++.so.6.0.29 && \ - export PATH="/toolchain/bin:$PATH" && \ - mkdir build && cd build && /opt/python/cp39-cp39/bin/python \ - ../ninja-1.10.2/configure.py --bootstrap \ - --with-python=/opt/python/cp39-cp39/bin/python && \ - cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.20.2/configure --prefix=/toolchain \ + gzip -dc zstd-1.5.5.tar.gz | tar -xf - && \ + gzip -dc cmake-3.26.4.tar.gz | tar -xf - && \ + gzip -dc v1.11.1.tar.gz | tar -xf - && \ + xz -dc llvm-16.0.5.src.tar.xz | tar -xf - && \ + xz -dc lld-16.0.5.src.tar.xz | tar -xf - && \ + xz -dc libunwind-16.0.5.src.tar.xz | tar -xf - && \ + xz -dc cmake-16.0.5.src.tar.xz | tar -xf - && \ + xz -dc third-party-16.0.5.src.tar.xz | tar -xf - && \ + export ZSTDFLAGS=(PREFIX=/opt/rh/devtoolset-11/root/usr LIBDIR=/opt/rh/devtoolset-11/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ + cd zstd-1.5.5 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-11/root/usr/lib64/libzstd.so* && cd - && \ + mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ + ../ninja-1.11.1/configure.py --bootstrap \ + --with-python=/opt/python/cp311-cp311/bin/python && \ + cp -v ninja /opt/rh/devtoolset-11/root/usr/bin/ninja && cd - && rm -rf build && \ + mkdir build && cd build && ../cmake-3.26.4/configure --prefix=/opt/rh/devtoolset-11/root/usr \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v llvm-12.0.0.src llvm && \ - mv -v lld-12.0.0.src lld && \ - mv -v libunwind-12.0.0.src libunwind && \ + mv -v llvm-16.0.5.src llvm && \ + mv -v lld-16.0.5.src lld && \ + mv -v libunwind-16.0.5.src libunwind && \ + mv -v cmake-16.0.5.src cmake && \ + mv -v third-party-16.0.5.src third-party && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_INSTALL_PREFIX=/toolchain \ - -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ + -DCMAKE_INSTALL_PREFIX=/opt/rh/devtoolset-11/root/usr \ + -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ -DLLVM_TARGETS_TO_BUILD="AArch64" -DLLVM_ENABLE_PROJECTS=lld \ - -DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-redhat-linux-gnu" llvm && \ + -DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-redhat-linux-gnu" \ + -DBUILD_SHARED_LIBS=OFF llvm && \ cmake --build build --target install && \ rm -rf build && rm -rf * diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index d968a789..73013544 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -7,76 +7,53 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM /root/ -ENV PATH /toolchain/bin:$PATH -ENV CC gcc -ENV CXX g++ -ENV CPPFLAGS -I/toolchain/include -ENV LDFLAGS -L/toolchain/lib64 -ENV PKG_CONFIG_PATH /toolchain/lib64/pkgconfig +ENV PATH /opt/rh/devtoolset-11/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH /opt/rh/devtoolset-11/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH /opt/rh/devtoolset-11/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV PKG_CONFIG_PATH /opt/rh/devtoolset-11/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build dpkg && \ - export CPU=$(/opt/python/cp39-cp39/bin/python3 -c \ +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build dpkg centos-release-scl && \ + yum install -y devtoolset-11 && \ + export CPU=$(/opt/python/cp311-cp311/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ - export CFGFLAGS="--prefix=/toolchain --disable-shared --libdir=/toolchain/lib64" && \ + export CFGFLAGS="--prefix=/opt/rh/devtoolset-11/root/usr --disable-shared --libdir=/opt/rh/devtoolset-11/root/usr/lib64" && \ curl -s -L -O --remote-name-all \ - https://ftp.gnu.org/gnu/gmp/gmp-6.2.1.tar.xz \ - https://ftp.gnu.org/gnu/mpfr/mpfr-4.1.0.tar.xz \ - https://ftp.gnu.org/gnu/mpc/mpc-1.2.1.tar.gz \ - https://libisl.sourceforge.io/isl-0.24.tar.xz \ - https://github.com/facebook/zstd/releases/download/v1.5.0/zstd-1.5.0.tar.gz \ - https://ftp.gnu.org/gnu/gcc/gcc-11.1.0/gcc-11.1.0.tar.xz \ - https://github.com/Kitware/CMake/releases/download/v3.20.2/cmake-3.20.2.tar.gz \ - https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/llvm-12.0.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/lld-12.0.0.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/libunwind-12.0.0.src.tar.xz && \ + https://github.com/facebook/zstd/releases/download/v1.5.5/zstd-1.5.5.tar.gz \ + https://github.com/Kitware/CMake/releases/download/v3.26.4/cmake-3.26.4.tar.gz \ + https://github.com/ninja-build/ninja/archive/refs/tags/v1.11.1.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/llvm-16.0.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/lld-16.0.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/libunwind-16.0.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/cmake-16.0.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/third-party-16.0.5.src.tar.xz && \ sha256sum -c SHA256SUM && \ - xz -dc gmp-6.2.1.tar.xz | tar -xf - && \ - xz -dc mpfr-4.1.0.tar.xz | tar -xf - && \ - gzip -dc mpc-1.2.1.tar.gz | tar -xf - && \ - xz -dc isl-0.24.tar.xz | tar -xf - && \ - gzip -dc zstd-1.5.0.tar.gz | tar -xf - && \ - xz -dc gcc-11.1.0.tar.xz | tar -xf - && \ - gzip -dc cmake-3.20.2.tar.gz | tar -xf - && \ - gzip -dc v1.10.2.tar.gz | tar -xf - && \ - xz -dc llvm-12.0.0.src.tar.xz | tar -xf - && \ - xz -dc lld-12.0.0.src.tar.xz | tar -xf - && \ - xz -dc libunwind-12.0.0.src.tar.xz | tar -xf - && \ - mkdir build && cd build && ../gmp-6.2.1/configure --build=x86_64-pc-linux-gnu $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../mpfr-4.1.0/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../mpc-1.2.1/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mkdir build && cd build && ../isl-0.24/configure $CFGFLAGS && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - export ZSTDFLAGS="PREFIX=/toolchain LIBDIR=/toolchain/lib64 SED_ERE_OPT=--regexp-extended" && \ - cd zstd-1.5.0 && make -s $ZSTDFLAGS -j $CPU && make -s $ZSTDFLAGS install && rm -f /toolchain/lib64/libzstd.so* && cd - && \ - mkdir build && cd build && ../gcc-11.1.0/configure --prefix=/toolchain --libdir=/toolchain/lib64 \ - --with-gmp=/toolchain --with-gmp-lib=/toolchain/lib64 \ - --with-zstd=/toolchain --with-zstd-lib=/toolchain/lib64 \ - --disable-libmpx --disable-libsanitizer --disable-libunwind-exceptions \ - --disable-multilib --enable-__cxa_atexit --enable-gnu-indirect-function \ - --enable-gnu-unique-object --enable-initfini-array --enable-languages="c,c++,lto" \ - --enable-linker-build-id --enable-lto --enable-plugin --enable-threads=posix \ - --with-default-libstdcxx-abi="gcc4-compatible" \ - --with-gcc-major-version-only --with-linker-hash-style="gnu" \ - --with-arch="x86-64" --with-tune="generic" && \ - make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v /toolchain/lib64/libstdc++.so.6.0.29 /toolchain/lib64/libstdc++.so.6.0.29.backup && \ - echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.19 libstdc++.a )" \ - > /toolchain/lib64/libstdc++.so.6.0.29 && \ - export PATH="/toolchain/bin:$PATH" && \ - mkdir build && cd build && /opt/python/cp39-cp39/bin/python \ - ../ninja-1.10.2/configure.py --bootstrap \ - --with-python=/opt/python/cp39-cp39/bin/python && \ - cp -v ninja /toolchain/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.20.2/configure --prefix=/toolchain \ + gzip -dc zstd-1.5.5.tar.gz | tar -xf - && \ + gzip -dc cmake-3.26.4.tar.gz | tar -xf - && \ + gzip -dc v1.11.1.tar.gz | tar -xf - && \ + xz -dc llvm-16.0.5.src.tar.xz | tar -xf - && \ + xz -dc lld-16.0.5.src.tar.xz | tar -xf - && \ + xz -dc libunwind-16.0.5.src.tar.xz | tar -xf - && \ + xz -dc cmake-16.0.5.src.tar.xz | tar -xf - && \ + xz -dc third-party-16.0.5.src.tar.xz | tar -xf - && \ + export ZSTDFLAGS=(PREFIX=/opt/rh/devtoolset-11/root/usr LIBDIR=/opt/rh/devtoolset-11/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ + cd zstd-1.5.5 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-11/root/usr/lib64/libzstd.so* && cd - && \ + mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ + ../ninja-1.11.1/configure.py --bootstrap \ + --with-python=/opt/python/cp311-cp311/bin/python && \ + cp -v ninja /opt/rh/devtoolset-11/root/usr/bin/ninja && cd - && rm -rf build && \ + mkdir build && cd build && ../cmake-3.26.4/configure --prefix=/opt/rh/devtoolset-11/root/usr \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v llvm-12.0.0.src llvm && \ - mv -v lld-12.0.0.src lld && \ - mv -v libunwind-12.0.0.src libunwind && \ + mv -v llvm-16.0.5.src llvm && \ + mv -v lld-16.0.5.src lld && \ + mv -v libunwind-16.0.5.src libunwind && \ + mv -v cmake-16.0.5.src cmake && \ + mv -v third-party-16.0.5.src third-party && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_INSTALL_PREFIX=/toolchain \ - -DPython3_ROOT_DIR=/opt/python/cp39-cp39 -DLLVM_LIBDIR_SUFFIX=64 \ + -DCMAKE_INSTALL_PREFIX=/opt/rh/devtoolset-11/root/usr \ + -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ - -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" llvm && \ + -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ + -DBUILD_SHARED_LIBS=OFF llvm && \ cmake --build build --target install && \ rm -rf build && rm -rf * diff --git a/utils/docker/SHA256SUM b/utils/docker/SHA256SUM index 66d8051f..e176fbb4 100644 --- a/utils/docker/SHA256SUM +++ b/utils/docker/SHA256SUM @@ -1,11 +1,7 @@ -aecf6ecb975179eb3bb6a4a50cae192d41e92b9372b02300f9e8f1d5f559544e cmake-3.20.2.tar.gz -4c4a6fb8a8396059241c2e674b85b351c26a5d678274007f076957afa1cc9ddf gcc-11.1.0.tar.xz -fd4829912cddd12f84181c3451cc752be224643e87fac497b69edddadc49b4f2 gmp-6.2.1.tar.xz -043105cc544f416b48736fff8caf077fb0663a717d06b1113f16e391ac99ebad isl-0.24.tar.xz -9ed2a5b28853f7f58be9d04836ff43d6e4132df5a2c058b690dc3e9d75bd1cf5 libunwind-12.0.0.src.tar.xz -2cb7d497f3ce33ce8a2c50ad26ec93a8c45f57268d4d96953cd0f25566f753fd lld-12.0.0.src.tar.xz -49dc47c8697a1a0abd4ee51629a696d7bfe803662f2a7252a3b16fc75f3a8b50 llvm-12.0.0.src.tar.xz -17503d2c395dfcf106b622dc142683c1199431d095367c6aacba6eec30340459 mpc-1.2.1.tar.gz -0c98a3f1732ff6ca4ea690552079da9c597872d30e96ec28414ee23c95558a7f mpfr-4.1.0.tar.xz -ce35865411f0490368a8fc383f29071de6690cbadc27704734978221f25e2bed v1.10.2.tar.gz -5194fbfa781fcf45b98c5e849651aa7b3b0a008c6b72d4a0db760f3002291e94 zstd-1.5.0.tar.gz +9400d49acd53a4b8f310de60554a891436db5a19f6f227f99f0de13e4afaaaff cmake-16.0.5.src.tar.xz +e7f65970298a60e9608a9fc55ea9af5e9c8e1bc0dc0067f3e9f10eb3fe3e8986 libunwind-16.0.5.src.tar.xz +0c593d1c23f626dc33caa8bf112868f77126e018b58dd1641f5ae6aa1c2a0ce3 lld-16.0.5.src.tar.xz +701b764a182d8ea8fb017b6b5f7f5f1272a29f17c339b838f48de894ffdd4f91 llvm-16.0.5.src.tar.xz +0a4bbb8505e95570e529d6b3d5176e93beb3260f061de9001e320d57b59aed59 third-party-16.0.5.src.tar.xz +31747ae633213f1eda3842686f83c2aa1412e0f5691d1c14dbbcc67fe7400cea v1.11.1.tar.gz +9c4396cc829cfae319a6e2615202e82aad41372073482fce286fac78646d3ee4 zstd-1.5.5.tar.gz diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index 673104bf..41b1d13e 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -2,15 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2022 Second State INC -export PATH="/toolchain/bin:$PATH" -export CC=gcc -export CXX=g++ -export CPPFLAGS=-I/toolchain/include -export LDFLAGS=-L/toolchain/lib64 -curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.79.0/source/boost_1_79_0.tar.bz2 -echo "475d589d51a7f8b3ba2ba4eda022b170e562ca3b760ee922c146b6c65856ef39 boost_1_79_0.tar.bz2" | sha256sum -c +curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.82.0/source/boost_1_82_0.tar.bz2 +echo "a6e1ab9b0860e6a2881dd7b21fe9f737a095e5f33a3a874afc6a345228597ee6 boost_1_82_0.tar.bz2" | sha256sum -c git config --global --add safe.directory $(pwd) -bzip2 -dc boost_1_79_0.tar.bz2 | tar -xf - +bzip2 -dc boost_1_82_0.tar.bz2 | tar -xf - CMAKE_BUILD_TYPE="Release" IS_BUILD_TARGET=true @@ -43,7 +38,7 @@ for i in "$@"; do done if $IS_NINJA; then - if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_79_0/ ${CMAKE_OPTS} .; then + if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_82_0/ ${CMAKE_OPTS} .; then echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log echo === CMakeError.log === @@ -54,7 +49,7 @@ else rm -rf build mkdir build cd build - if ! cmake -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/../boost_1_79_0/ ${CMAKE_OPTS} ..; then + if ! cmake -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/../boost_1_82_0/ ${CMAKE_OPTS} ..; then cd .. echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log From 2b763702e9cafe7ab2ef4986b28a7b880365043d Mon Sep 17 00:00:00 2001 From: michael1017 <40065278+michael1017@users.noreply.github.com> Date: Sun, 11 Jun 2023 01:52:54 +0800 Subject: [PATCH 109/623] [Plugin] Add WASI-Logging Plugin (#2535) * Add Wasi Logging Plugin Signed-off-by: michael1017 * [Plugin] add tests for wasi_logging Signed-off-by: michael1017 * Add wasi-logging rust example Signed-off-by: michael1017 * Add wasi-logging Mock and Doc Signed-off-by: michael1017 * Example use wit-bindgen Signed-off-by: michael1017 * Rename module name from 'logging' to 'wasi:logging/logging' Signed-off-by: michael1017 * Update CI for #2563 Signed-off-by: michael1017 * Remove submodule and add wit dir Signed-off-by: michael1017 --------- Signed-off-by: michael1017 --- plugins/CMakeLists.txt | 4 + plugins/wasi_logging/CMakeLists.txt | 31 +++++ plugins/wasi_logging/env.cpp | 36 ++++++ plugins/wasi_logging/func.cpp | 85 ++++++++++++++ plugins/wasi_logging/module.cpp | 16 +++ plugins/wasi_logging/wasi_logging/base.h | 22 ++++ plugins/wasi_logging/wasi_logging/enum.h | 13 +++ plugins/wasi_logging/wasi_logging/env.h | 23 ++++ plugins/wasi_logging/wasi_logging/func.h | 17 +++ plugins/wasi_logging/wasi_logging/module.h | 20 ++++ test/plugins/CMakeLists.txt | 4 + test/plugins/wasi_logging/CMakeLists.txt | 35 ++++++ test/plugins/wasi_logging/wasi_logging.cpp | 126 +++++++++++++++++++++ 13 files changed, 432 insertions(+) create mode 100644 plugins/wasi_logging/CMakeLists.txt create mode 100644 plugins/wasi_logging/env.cpp create mode 100644 plugins/wasi_logging/func.cpp create mode 100644 plugins/wasi_logging/module.cpp create mode 100644 plugins/wasi_logging/wasi_logging/base.h create mode 100644 plugins/wasi_logging/wasi_logging/enum.h create mode 100644 plugins/wasi_logging/wasi_logging/env.h create mode 100644 plugins/wasi_logging/wasi_logging/func.h create mode 100644 plugins/wasi_logging/wasi_logging/module.h create mode 100644 test/plugins/wasi_logging/CMakeLists.txt create mode 100644 test/plugins/wasi_logging/wasi_logging.cpp diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 015ea804..67d7d288 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -26,3 +26,7 @@ if(WASMEDGE_PLUGIN_WASM_BPF) message(WARNING "Only Linux platforms support wasm_bpf plug-in now.") endif() endif() + +if(WASMEDGE_PLUGIN_WASI_LOGGING) + add_subdirectory(wasi_logging) +endif() diff --git a/plugins/wasi_logging/CMakeLists.txt b/plugins/wasi_logging/CMakeLists.txt new file mode 100644 index 00000000..0d0b6469 --- /dev/null +++ b/plugins/wasi_logging/CMakeLists.txt @@ -0,0 +1,31 @@ +wasmedge_add_library(wasmedgePluginWasiLogging + SHARED + env.cpp + func.cpp + module.cpp +) + +target_compile_options(wasmedgePluginWasiLogging + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasiLogging + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasiLogging + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasiLogging + PRIVATE + wasmedge_shared + ) +endif() + +install(TARGETS wasmedgePluginWasiLogging DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasi_logging/env.cpp b/plugins/wasi_logging/env.cpp new file mode 100644 index 00000000..0f55170f --- /dev/null +++ b/plugins/wasi_logging/env.cpp @@ -0,0 +1,36 @@ +#include "wasi_logging/env.h" +#include "wasi_logging/module.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasiLoggingModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasi_logging", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 1, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasi:logging/logging", + .Description = "", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +} // namespace + +Plugin::PluginRegister WasiLoggingEnvironment::Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasi_logging/func.cpp b/plugins/wasi_logging/func.cpp new file mode 100644 index 00000000..5f0a2c01 --- /dev/null +++ b/plugins/wasi_logging/func.cpp @@ -0,0 +1,85 @@ +#include "wasi_logging/func.h" +#include "wasi_logging/enum.h" +#include + +namespace WasmEdge { +namespace Host { + +using namespace std::literals; + +Expect WasiLoggingLog::body(const Runtime::CallingFrame &Frame, + uint32_t Level, uint32_t CxtPtr, + uint32_t CxtLen, uint32_t MsgPtr, + uint32_t MsgLen) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + // Get Buffer Pointer + char *CxtBuf = MemInst->getPointer(CxtPtr); + char *MsgBuf = MemInst->getPointer(MsgPtr); + if (CxtBuf == nullptr || MsgBuf == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + // Copy Context String and Message String + std::string CxtStr, MsgStr; + std::copy_n(CxtBuf, CxtLen, std::back_inserter(CxtStr)); + std::copy_n(MsgBuf, MsgLen, std::back_inserter(MsgStr)); + + // Setup Logger for Stdout or Stderr + CxtStr == "stderr"sv ? Env.isCxtStrStderr = true : Env.isCxtStrStderr = false; + auto logger = Env.isCxtStrStderr ? Env.StderrLogger : Env.StdoutLogger; + + // Construct Spdlog Message + std::string SpdlogMsg; + if (!CxtStr.empty()) { + SpdlogMsg = CxtStr + ": " + MsgStr; + } else { + SpdlogMsg = MsgStr; + } + + // Print Message by Logging Level + switch (Level) { + case WASILOGGING::WasiLoggingLevel::Trace: + logger->trace(SpdlogMsg); + break; + case WASILOGGING::WasiLoggingLevel::Debug: + logger->debug(SpdlogMsg); + break; + case WASILOGGING::WasiLoggingLevel::Info: + logger->info(SpdlogMsg); + break; + case WASILOGGING::WasiLoggingLevel::Warn: + logger->warn(SpdlogMsg); + break; + case WASILOGGING::WasiLoggingLevel::Error: + logger->error(SpdlogMsg); + break; + case WASILOGGING::WasiLoggingLevel::Critical: + logger->critical(SpdlogMsg); + break; + default: + spdlog::error("[WasiLogging] Unrecognized Logging Level: {}"sv, Level); + spdlog::error("[WasiLogging] Trace Level = {}"sv, + static_cast(WASILOGGING::WasiLoggingLevel::Trace)); + spdlog::error("[WasiLogging] Debug Level = {}"sv, + static_cast(WASILOGGING::WasiLoggingLevel::Debug)); + spdlog::error("[WasiLogging] Info Level = {}"sv, + static_cast(WASILOGGING::WasiLoggingLevel::Info)); + spdlog::error("[WasiLogging] Warn Level = {}"sv, + static_cast(WASILOGGING::WasiLoggingLevel::Warn)); + spdlog::error("[WasiLogging] Error Level = {}"sv, + static_cast(WASILOGGING::WasiLoggingLevel::Error)); + spdlog::error( + "[WasiLogging] Critical Level = {}"sv, + static_cast(WASILOGGING::WasiLoggingLevel::Critical)); + return Unexpect(ErrCode::Value::HostFuncError); + } + return {}; +} + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasi_logging/module.cpp b/plugins/wasi_logging/module.cpp new file mode 100644 index 00000000..6ebd7879 --- /dev/null +++ b/plugins/wasi_logging/module.cpp @@ -0,0 +1,16 @@ +#include "wasi_logging/module.h" +#include "wasi_logging/func.h" +#include + +namespace WasmEdge { +namespace Host { + +using namespace std::literals; + +WasiLoggingModule::WasiLoggingModule() + : ModuleInstance("wasi:logging/logging"sv) { + addHostFunc("log"sv, std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasi_logging/wasi_logging/base.h b/plugins/wasi_logging/wasi_logging/base.h new file mode 100644 index 00000000..11b5504a --- /dev/null +++ b/plugins/wasi_logging/wasi_logging/base.h @@ -0,0 +1,22 @@ +#pragma once + +#include "wasi_logging/env.h" + +#include "common/errcode.h" +#include "runtime/callingframe.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template class WasiLogging : public Runtime::HostFunction { +public: + WasiLogging(WasiLoggingEnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WasiLoggingEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasi_logging/wasi_logging/enum.h b/plugins/wasi_logging/wasi_logging/enum.h new file mode 100644 index 00000000..5c4e2e25 --- /dev/null +++ b/plugins/wasi_logging/wasi_logging/enum.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace WasmEdge { +namespace Host { +namespace WASILOGGING { + +enum WasiLoggingLevel : uint32_t { Trace, Debug, Info, Warn, Error, Critical }; + +} // namespace WASILOGGING +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasi_logging/wasi_logging/env.h b/plugins/wasi_logging/wasi_logging/env.h new file mode 100644 index 00000000..8a2864d2 --- /dev/null +++ b/plugins/wasi_logging/wasi_logging/env.h @@ -0,0 +1,23 @@ +#pragma once + +#include "plugin/plugin.h" +#include +namespace WasmEdge { +namespace Host { + +class WasiLoggingEnvironment { +public: + WasiLoggingEnvironment() noexcept { + StdoutLogger->set_level(spdlog::level::trace); + StderrLogger->set_level(spdlog::level::trace); + } + bool isCxtStrStderr = false; + inline const static std::shared_ptr StdoutLogger = + spdlog::stdout_color_mt("wasi_logging_stdout"); + inline const static std::shared_ptr StderrLogger = + spdlog::stderr_color_mt("wasi_logging_stderr"); + static Plugin::PluginRegister Register; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_logging/wasi_logging/func.h b/plugins/wasi_logging/wasi_logging/func.h new file mode 100644 index 00000000..f7631e98 --- /dev/null +++ b/plugins/wasi_logging/wasi_logging/func.h @@ -0,0 +1,17 @@ +#pragma once + +#include "wasi_logging/base.h" + +namespace WasmEdge { +namespace Host { + +class WasiLoggingLog : public WasiLogging { +public: + WasiLoggingLog(WasiLoggingEnvironment &HostEnv) : WasiLogging(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Level, + uint32_t CxtPtr, uint32_t CxtLen, uint32_t MsgPtr, + uint32_t MsgLen); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_logging/wasi_logging/module.h b/plugins/wasi_logging/wasi_logging/module.h new file mode 100644 index 00000000..12504522 --- /dev/null +++ b/plugins/wasi_logging/wasi_logging/module.h @@ -0,0 +1,20 @@ +#pragma once + +#include "runtime/instance/module.h" +#include "wasi_logging/env.h" + +namespace WasmEdge { +namespace Host { + +class WasiLoggingModule : public Runtime::Instance::ModuleInstance { +public: + WasiLoggingModule(); + + WasiLoggingEnvironment &getEnv() { return Env; } + +private: + WasiLoggingEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 5a5f7c35..ccc4a365 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -21,6 +21,10 @@ if(WASMEDGE_PLUGIN_WASM_BPF) endif() endif() +if(WASMEDGE_PLUGIN_WASI_LOGGING) + add_subdirectory(wasi_logging) +endif() + if(CMAKE_SYSTEM_NAME MATCHES "Linux" OR CMAKE_SYSTEM_NAME MATCHES "Darwin") add_subdirectory(unittest) endif() diff --git a/test/plugins/wasi_logging/CMakeLists.txt b/test/plugins/wasi_logging/CMakeLists.txt new file mode 100644 index 00000000..d3373e80 --- /dev/null +++ b/test/plugins/wasi_logging/CMakeLists.txt @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_executable(wasiLoggingTests + wasi_logging.cpp +) + +add_dependencies(wasiLoggingTests + wasmedgePluginWasiLogging +) + +target_include_directories(wasiLoggingTests + PUBLIC + $ + $ +) + +target_link_libraries(wasiLoggingTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasiLoggingTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasiLoggingTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasiLoggingTests wasiLoggingTests) diff --git a/test/plugins/wasi_logging/wasi_logging.cpp b/test/plugins/wasi_logging/wasi_logging.cpp new file mode 100644 index 00000000..d1c36c02 --- /dev/null +++ b/test/plugins/wasi_logging/wasi_logging.cpp @@ -0,0 +1,126 @@ +#include "common/defines.h" +#include "runtime/instance/module.h" +#include "wasi_logging/func.h" +#include "wasi_logging/module.h" + +#include +#include +#include +#include +#include + +namespace { + +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasi_logging/" + "libwasmedgePluginWasiLogging" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_logging"sv)) { + if (const auto *Module = Plugin->findModule("wasi:logging/logging"sv)) { + return Module->create().release(); + } + } + return nullptr; +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, uint32_t Cnt, uint8_t C = 0) noexcept { + std::fill_n(MemInst.getPointer(Offset), Cnt, C); +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, const std::string &Str) noexcept { + char *Buf = MemInst.getPointer(Offset); + std::copy_n(Str.c_str(), Str.length(), Buf); +} + +} // namespace + +TEST(WasiLoggingTests, func_log) { + // Create the wasi-logging module instance. + auto WasiLoggingMod = + dynamic_cast(createModule()); + EXPECT_NE(WasiLoggingMod, nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + EXPECT_NE(MemInstPtr, nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Clear the memory[0, 32]. + fillMemContent(MemInst, 0, 32); + // Set strings in memory + fillMemContent(MemInst, 0, std::string("CxtStr")); + fillMemContent(MemInst, 8, std::string("stderr")); + fillMemContent(MemInst, 16, std::string("MsgStr")); + + // Get the function "log" + auto *FuncInst = WasiLoggingMod->findFuncExports("log"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = + dynamic_cast(FuncInst->getHostFunc()); + + // Show All Level + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + {})); + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + {})); + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + UINT32_C(2), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + {})); + + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + {})); + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + UINT32_C(4), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + {})); + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + UINT32_C(5), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + {})); + EXPECT_FALSE(WasiLoggingMod->getEnv().isCxtStrStderr); + + // Stderr Context + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(8), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + {})); + EXPECT_TRUE(WasiLoggingMod->getEnv().isCxtStrStderr); + + // UnKnown Level + EXPECT_FALSE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + UINT32_C(6), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + {})); + EXPECT_FALSE(WasiLoggingMod->getEnv().isCxtStrStderr); + + delete WasiLoggingMod; +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From e7b58d3f05a6f25325cff3233c674f81f5d802b7 Mon Sep 17 00:00:00 2001 From: hydai Date: Sun, 11 Jun 2023 03:16:37 +0800 Subject: [PATCH 110/623] [Misc] use std::boyer_moore_horspool_searcher instead of boost one Since we removed all of the boost related part, we can also drop the boost dependency. Signed-off-by: hydai --- utils/docker/Dockerfile.base | 1 - utils/docker/Dockerfile.ubuntu2004_x86_64 | 1 - utils/docker/Dockerfile.ubuntu2104_armv7l | 1 - utils/docker/build-manylinux.sh | 7 ++----- 4 files changed, 2 insertions(+), 8 deletions(-) diff --git a/utils/docker/Dockerfile.base b/utils/docker/Dockerfile.base index 08d1e58a..ffd2739b 100644 --- a/utils/docker/Dockerfile.base +++ b/utils/docker/Dockerfile.base @@ -12,7 +12,6 @@ RUN apt update && apt upgrade -y \ ninja-build \ curl \ git \ - libboost-all-dev \ llvm-15-dev \ liblld-15-dev diff --git a/utils/docker/Dockerfile.ubuntu2004_x86_64 b/utils/docker/Dockerfile.ubuntu2004_x86_64 index dda1cb9c..db96296b 100644 --- a/utils/docker/Dockerfile.ubuntu2004_x86_64 +++ b/utils/docker/Dockerfile.ubuntu2004_x86_64 @@ -12,7 +12,6 @@ RUN apt update && apt upgrade -y \ curl \ git \ dpkg-dev \ - libboost-all-dev \ llvm-12-dev \ liblld-12-dev \ gcc \ diff --git a/utils/docker/Dockerfile.ubuntu2104_armv7l b/utils/docker/Dockerfile.ubuntu2104_armv7l index 265f0fd9..28e7c94c 100644 --- a/utils/docker/Dockerfile.ubuntu2104_armv7l +++ b/utils/docker/Dockerfile.ubuntu2104_armv7l @@ -18,7 +18,6 @@ RUN apt update && apt upgrade -y \ g++-multilib \ git \ llvm-12-dev \ - libboost-all-dev \ liblld-12-dev \ libssl-dev \ ninja-build \ diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index 41b1d13e..6570fbe7 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -2,10 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2022 Second State INC -curl -s -L -O --remote-name-all https://boostorg.jfrog.io/artifactory/main/release/1.82.0/source/boost_1_82_0.tar.bz2 -echo "a6e1ab9b0860e6a2881dd7b21fe9f737a095e5f33a3a874afc6a345228597ee6 boost_1_82_0.tar.bz2" | sha256sum -c git config --global --add safe.directory $(pwd) -bzip2 -dc boost_1_82_0.tar.bz2 | tar -xf - CMAKE_BUILD_TYPE="Release" IS_BUILD_TARGET=true @@ -38,7 +35,7 @@ for i in "$@"; do done if $IS_NINJA; then - if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/boost_1_82_0/ ${CMAKE_OPTS} .; then + if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" ${CMAKE_OPTS} .; then echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log echo === CMakeError.log === @@ -49,7 +46,7 @@ else rm -rf build mkdir build cd build - if ! cmake -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" -DBoost_NO_SYSTEM_PATHS=TRUE -DBOOST_INCLUDEDIR=$(pwd)/../boost_1_82_0/ ${CMAKE_OPTS} ..; then + if ! cmake -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" ${CMAKE_OPTS} ..; then cd .. echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log From e18c162ca89d50fdee5bb74c8d0a299f4c28066d Mon Sep 17 00:00:00 2001 From: yanghaku <36074633+yanghaku@users.noreply.github.com> Date: Wed, 14 Jun 2023 20:14:20 +0800 Subject: [PATCH 111/623] [WASI-NN] Fix TfLite model data ownership (#2589) Signed-off-by: yanghaku <1961882079@qq.com> --- plugins/wasi_nn/tfl.cpp | 6 ++++-- plugins/wasi_nn/tfl.h | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/tfl.cpp b/plugins/wasi_nn/tfl.cpp index acbf4f96..32d3c5f3 100644 --- a/plugins/wasi_nn/tfl.cpp +++ b/plugins/wasi_nn/tfl.cpp @@ -23,12 +23,14 @@ Expect load(WASINN::WasiNNEnvironment &Env, Builders.size()); return WASINN::ErrNo::InvalidArgument; } - auto Weight = Builders[0]; // Add a new graph. Env.NNGraph.emplace_back(WASINN::Backend::TensorflowLite); auto &GraphRef = Env.NNGraph.back().get(); - GraphRef.TFLiteMod = TfLiteModelCreate(Weight.data(), Weight.size()); + // Copy graph builder data to TfLiteModData and create a new TfLiteModel. + GraphRef.TfLiteModData.assign(Builders[0].begin(), Builders[0].end()); + GraphRef.TFLiteMod = TfLiteModelCreate(GraphRef.TfLiteModData.data(), + GraphRef.TfLiteModData.size()); if (unlikely(GraphRef.TFLiteMod == nullptr)) { spdlog::error("[WASI-NN] Cannot import TFLite model"); Env.NNGraph.pop_back(); diff --git a/plugins/wasi_nn/tfl.h b/plugins/wasi_nn/tfl.h index 7e2bff4e..84d8227a 100644 --- a/plugins/wasi_nn/tfl.h +++ b/plugins/wasi_nn/tfl.h @@ -24,6 +24,7 @@ struct Graph { TfLiteModelDelete(TFLiteMod); } } + std::vector TfLiteModData; TfLiteModel *TFLiteMod = nullptr; }; From 6cabbe11207954f8894aefa2fcf81d7ccfc3c57c Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 13 Jun 2023 17:10:00 +0800 Subject: [PATCH 112/623] [CI] Add zlib1g-dev in Ubuntu env to fix zlib not found issue Signed-off-by: hydai --- utils/docker/Dockerfile.base | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/docker/Dockerfile.base b/utils/docker/Dockerfile.base index ffd2739b..db431fb9 100644 --- a/utils/docker/Dockerfile.base +++ b/utils/docker/Dockerfile.base @@ -12,6 +12,7 @@ RUN apt update && apt upgrade -y \ ninja-build \ curl \ git \ + zlib1g-dev \ llvm-15-dev \ liblld-15-dev From 24ff073320673a6eed2eea4fa333faa8691ee39b Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 1 Jun 2023 05:02:34 +0800 Subject: [PATCH 113/623] [Plugin] Add WasmEdge-tensorflow, tensorflowlite, and image plugins. Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 27 ++ plugins/wasmedge_image/CMakeLists.txt | 142 ++++++ plugins/wasmedge_image/image_base.h | 25 ++ plugins/wasmedge_image/image_env.cpp | 40 ++ plugins/wasmedge_image/image_env.h | 32 ++ plugins/wasmedge_image/image_func.cpp | 170 +++++++ plugins/wasmedge_image/image_func.h | 34 ++ plugins/wasmedge_image/image_module.cpp | 19 + plugins/wasmedge_image/image_module.h | 25 ++ plugins/wasmedge_tensorflow/CMakeLists.txt | 37 ++ plugins/wasmedge_tensorflow/tensorflow_base.h | 25 ++ .../wasmedge_tensorflow/tensorflow_env.cpp | 40 ++ plugins/wasmedge_tensorflow/tensorflow_env.h | 129 ++++++ .../wasmedge_tensorflow/tensorflow_func.cpp | 421 ++++++++++++++++++ plugins/wasmedge_tensorflow/tensorflow_func.h | 94 ++++ .../wasmedge_tensorflow/tensorflow_module.cpp | 40 ++ .../wasmedge_tensorflow/tensorflow_module.h | 25 ++ .../wasmedge_tensorflowlite/CMakeLists.txt | 37 ++ .../tensorflowlite_base.h | 25 ++ .../tensorflowlite_env.cpp | 40 ++ .../tensorflowlite_env.h | 75 ++++ .../tensorflowlite_func.cpp | 296 ++++++++++++ .../tensorflowlite_func.h | 66 +++ .../tensorflowlite_module.cpp | 31 ++ .../tensorflowlite_module.h | 25 ++ 25 files changed, 1920 insertions(+) create mode 100644 plugins/wasmedge_image/CMakeLists.txt create mode 100644 plugins/wasmedge_image/image_base.h create mode 100644 plugins/wasmedge_image/image_env.cpp create mode 100644 plugins/wasmedge_image/image_env.h create mode 100644 plugins/wasmedge_image/image_func.cpp create mode 100644 plugins/wasmedge_image/image_func.h create mode 100644 plugins/wasmedge_image/image_module.cpp create mode 100644 plugins/wasmedge_image/image_module.h create mode 100644 plugins/wasmedge_tensorflow/CMakeLists.txt create mode 100644 plugins/wasmedge_tensorflow/tensorflow_base.h create mode 100644 plugins/wasmedge_tensorflow/tensorflow_env.cpp create mode 100644 plugins/wasmedge_tensorflow/tensorflow_env.h create mode 100644 plugins/wasmedge_tensorflow/tensorflow_func.cpp create mode 100644 plugins/wasmedge_tensorflow/tensorflow_func.h create mode 100644 plugins/wasmedge_tensorflow/tensorflow_module.cpp create mode 100644 plugins/wasmedge_tensorflow/tensorflow_module.h create mode 100644 plugins/wasmedge_tensorflowlite/CMakeLists.txt create mode 100644 plugins/wasmedge_tensorflowlite/tensorflowlite_base.h create mode 100644 plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp create mode 100644 plugins/wasmedge_tensorflowlite/tensorflowlite_env.h create mode 100644 plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp create mode 100644 plugins/wasmedge_tensorflowlite/tensorflowlite_func.h create mode 100644 plugins/wasmedge_tensorflowlite/tensorflowlite_module.cpp create mode 100644 plugins/wasmedge_tensorflowlite/tensorflowlite_module.h diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 67d7d288..80e50955 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -9,6 +9,33 @@ if(WASMEDGE_PLUGIN_WASI_CRYPTO) add_subdirectory(wasi_crypto) endif() +if(WASMEDGE_PLUGIN_IMAGE) + # Only Linux and MacOS support wasmedge_image now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_image) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_Image plug-in now.") + endif() +endif() + +if(WASMEDGE_PLUGIN_TENSORFLOW) + # Only Linux and MacOS support wasmedge_tensorflow now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_tensorflow) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_Tensorflow plug-in now.") + endif() +endif() + +if(WASMEDGE_PLUGIN_TENSORFLOWLITE) + # Only Linux and MacOS support wasmedge_tensorflowlite now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_tensorflowlite) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_TensorflowLite plug-in now.") + endif() +endif() + if(WASMEDGE_PLUGIN_PROCESS) # Only Linux systems support wasmedge_process now. if(CMAKE_SYSTEM_NAME MATCHES "Linux") diff --git a/plugins/wasmedge_image/CMakeLists.txt b/plugins/wasmedge_image/CMakeLists.txt new file mode 100644 index 00000000..5cd03215 --- /dev/null +++ b/plugins/wasmedge_image/CMakeLists.txt @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_library(wasmedgePluginWasmEdgeImage + SHARED + image_env.cpp + image_func.cpp + image_module.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeImage + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeImage + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +# Find the libjpeg and libpng. +add_library(wasmedgePluginWasmEdgeImageJPEG STATIC IMPORTED GLOBAL) +add_library(wasmedgePluginWasmEdgeImagePNG STATIC IMPORTED GLOBAL) +if(APPLE) + # For MacOS, use the installed libjpeg, libpng, and zlib static library. + find_package(JPEG REQUIRED) + find_package(PNG REQUIRED) + # The find_package will get the shared library. Therefore find the static one. + find_library(JPEG_STATIC NAMES libjpeg.a) + find_library(PNG_STATIC NAMES libpng16.a) + set_target_properties(wasmedgePluginWasmEdgeImageJPEG + PROPERTIES + IMPORTED_LOCATION ${JPEG_STATIC} + INTERFACE_INCLUDE_DIRECTORIES "${JPEG_INCLUDE_DIR}" + ) + set_target_properties(wasmedgePluginWasmEdgeImagePNG + PROPERTIES + IMPORTED_LOCATION ${PNG_STATIC} + INTERFACE_INCLUDE_DIRECTORIES "${PNG_INCLUDE_DIR}" + ) +elseif(UNIX) + # Fetch and build libjpeg and libpng. + include(FetchContent) + FetchContent_Declare( + wasmedge_image_libpng + URL "https://downloads.sourceforge.net/libpng/libpng-1.6.39.tar.gz" + URL_HASH "SHA256=af4fb7f260f839919e5958e5ab01a275d4fe436d45442a36ee62f73e5beb75ba" + ) + FetchContent_GetProperties(wasmedge_image_libpng) + if(NOT wasmedge_image_libpng_POPULATED) + message(STATUS "Downloading libpng source") + FetchContent_Populate(wasmedge_image_libpng) + message(STATUS "Downloading libpng source - done") + endif() + + FetchContent_Declare( + wasmedge_image_libjpeg + URL "http://ijg.org/files/jpegsrc.v9e.tar.gz" + URL_HASH "SHA256=4077d6a6a75aeb01884f708919d25934c93305e49f7e3f36db9129320e6f4f3d" + ) + FetchContent_GetProperties(wasmedge_image_libjpeg) + if(NOT wasmedge_image_libjpeg_POPULATED) + message(STATUS "Downloading libjpeg source") + FetchContent_Populate(wasmedge_image_libjpeg) + message(STATUS "Downloading libjpeg source - done") + endif() + + add_custom_command( + OUTPUT ${wasmedge_image_libjpeg_SOURCE_DIR}/.libs/libjpeg.a + COMMAND ${CMAKE_COMMAND} -E env CFLAGS=-fPIC ./configure --enable-shared=off + COMMAND make + WORKING_DIRECTORY ${wasmedge_image_libjpeg_SOURCE_DIR} + ) + add_custom_command( + OUTPUT ${wasmedge_image_libpng_SOURCE_DIR}/.libs/libpng16.a + COMMAND ${CMAKE_COMMAND} -E env CFLAGS=-fPIC ./configure --enable-shared=off + COMMAND make + WORKING_DIRECTORY ${wasmedge_image_libpng_SOURCE_DIR} + ) + add_custom_target(wasmedgePluginWasmEdgeImageJPEG_target + ALL DEPENDS + ${wasmedge_image_libjpeg_SOURCE_DIR}/.libs/libjpeg.a + ) + add_custom_target(wasmedgePluginWasmEdgeImagePNG_target + ALL DEPENDS + ${wasmedge_image_libpng_SOURCE_DIR}/.libs/libpng16.a + ) + add_dependencies(wasmedgePluginWasmEdgeImageJPEG wasmedgePluginWasmEdgeImageJPEG_target) + add_dependencies(wasmedgePluginWasmEdgeImagePNG wasmedgePluginWasmEdgeImagePNG_target) + set_target_properties(wasmedgePluginWasmEdgeImageJPEG + PROPERTIES + IMPORTED_LOCATION ${wasmedge_image_libjpeg_SOURCE_DIR}/.libs/libjpeg.a + INTERFACE_INCLUDE_DIRECTORIES ${wasmedge_image_libjpeg_SOURCE_DIR} + ) + set_target_properties(wasmedgePluginWasmEdgeImagePNG + PROPERTIES + IMPORTED_LOCATION ${wasmedge_image_libpng_SOURCE_DIR}/.libs/libpng16.a + INTERFACE_INCLUDE_DIRECTORIES ${wasmedge_image_libpng_SOURCE_DIR} + ) +endif() + +# Need zlib and boost. +find_package(ZLIB REQUIRED) +find_package(Boost 1.74.0) +if(${Boost_FOUND}) +else() + FetchContent_Declare( + Boost + URL https://boostorg.jfrog.io/artifactory/main/release/1.82.0/source/boost_1_82_0.tar.bz2 + URL_HASH SHA256=a6e1ab9b0860e6a2881dd7b21fe9f737a095e5f33a3a874afc6a345228597ee6 + ) + set(BOOST_ENABLE_CMAKE ON) + set(BOOST_RUNTIME_LINK static) + message(STATUS "Downloading boost 1.82.0 source") + FetchContent_MakeAvailable(Boost) + message(STATUS "Downloading boost 1.82.0 source - done") + add_library(Boost_boost INTERFACE) + add_library(Boost::boost ALIAS Boost_boost) + target_include_directories(Boost_boost SYSTEM INTERFACE ${boost_SOURCE_DIR}) +endif() + +target_link_libraries(wasmedgePluginWasmEdgeImage + PUBLIC + Boost::boost + wasmedgePluginWasmEdgeImageJPEG + wasmedgePluginWasmEdgeImagePNG + z +) +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeImage + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeImage + PRIVATE + wasmedge_shared + ) +endif() + +install(TARGETS wasmedgePluginWasmEdgeImage DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasmedge_image/image_base.h b/plugins/wasmedge_image/image_base.h new file mode 100644 index 00000000..733cdf18 --- /dev/null +++ b/plugins/wasmedge_image/image_base.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "image_env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeImage { + +template class Func : public Runtime::HostFunction { +public: + Func(ImgEnv &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + ImgEnv &Env; +}; + +} // namespace WasmEdgeImage +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_env.cpp b/plugins/wasmedge_image/image_env.cpp new file mode 100644 index 00000000..022787ca --- /dev/null +++ b/plugins/wasmedge_image/image_env.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "image_env.h" +#include "image_module.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeImageModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_image", + .Description = "Image loading plug-in for WasmEdge.", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 13, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_image", + .Description = + "This module contains WasmEdge-Image host functions.", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +} // namespace + +Plugin::PluginRegister WasmEdgeImage::ImgEnv::Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_env.h b/plugins/wasmedge_image/image_env.h new file mode 100644 index 00000000..0eea7ba5 --- /dev/null +++ b/plugins/wasmedge_image/image_env.h @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeImage { + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + Fail = 1, // Runtime Error. +}; + +enum class DataType : uint32_t { + RGB8 = 0, + BGR8 = 1, + RGB32F = 2, + BGR32F = 3, +}; + +struct ImgEnv { + static Plugin::PluginRegister Register; +}; + +} // namespace WasmEdgeImage +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_func.cpp b/plugins/wasmedge_image/image_func.cpp new file mode 100644 index 00000000..7710e53c --- /dev/null +++ b/plugins/wasmedge_image/image_func.cpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "image_func.h" + +#include "common/log.h" +#include "common/span.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeImage { + +namespace { + +// Helper function to decode and resize image. +template +bool decodeImgToSize(Span Buf, uint32_t W, uint32_t H, + Span DstBuf) { + std::stringstream ImgStream; + ImgStream.write(Buf.data(), Buf.size()); + Image Img; + try { + boost::gil::read_image(ImgStream, Img, FormatTag()); + } catch (std::exception const &e) { + spdlog::error("[WasmEdge-Image] Decode image fail: {}"sv, e.what()); + return false; + } + + uint32_t C = boost::gil::num_channels::value; + typename Image::view_t ImgView = boost::gil::interleaved_view( + W, H, reinterpret_cast(DstBuf.data()), + W * C * sizeof(char)); + boost::gil::resize_view(boost::gil::const_view(Img), ImgView, + boost::gil::bilinear_sampler()); + return true; +} + +// Helper function to normalize image. +void normalizeImg(Span SrcBuf, Span DstBuf) { + for (uint32_t I = 0; I < DstBuf.size(); I++) { + DstBuf[I] = static_cast(SrcBuf[I]) / 255.0; + } +} + +// Template to decode and resize image to the target format. +template +uint32_t readBufToImg(Span InBuf, uint32_t W, uint32_t H, + Span OutBuf) { + if (unlikely(!decodeImgToSize(InBuf, W, H, OutBuf))) { + return static_cast(ErrNo::Fail); + } + return static_cast(ErrNo::Success); +} + +// Template to decode and resize image to the target format. +template +uint32_t readBufToFlattenImg(Span InBuf, uint32_t W, uint32_t H, + Span OutBuf) { + std::vector ImgData(3 * W * H); + if (unlikely(!decodeImgToSize( + InBuf, W, H, Span(ImgData.data(), ImgData.size())))) { + return static_cast(ErrNo::Fail); + } + normalizeImg(ImgData, Span(reinterpret_cast(OutBuf.data()), + OutBuf.size() / sizeof(float))); + return static_cast(ErrNo::Success); +} + +#define MEMINST_CHECK(Out, CallFrame, Index) \ + auto *Out = CallFrame.getMemoryByIndex(Index); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-Image] Memory instance not found."sv); \ + return static_cast(ErrNo::Fail); \ + } + +#define MEM_SPAN_CHECK(OutSpan, MemInst, Type, BufPtr, BufLen, Message) \ + auto OutSpan = MemInst->getSpan(BufPtr, BufLen); \ + if (unlikely(OutSpan.size() != BufLen)) { \ + spdlog::error("[WasmEdge-Image] "sv Message); \ + return static_cast(ErrNo::Fail); \ + } + +} // namespace + +Expect LoadJPG::body(const Runtime::CallingFrame &Frame, + uint32_t InImgBufPtr, uint32_t InImgBufLen, + uint32_t OutImgW, uint32_t OutImgH, + uint32_t OutType, uint32_t OutBufPtr, + uint32_t OutBufLen) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input image buffer. + MEM_SPAN_CHECK(ImgBufSpan, MemInst, char, InImgBufPtr, InImgBufLen, + "Failed when accessing the input image buffer memory."sv) + + // Check the output decoded image buffer. + MEM_SPAN_CHECK(OutBufSpan, MemInst, char, OutBufPtr, OutBufLen, + "Failed when accessing the output image data buffer memory."sv) + + switch (static_cast(OutType)) { + case DataType::RGB8: + return readBufToImg( + ImgBufSpan, OutImgW, OutImgH, OutBufSpan); + case DataType::BGR8: + return readBufToImg( + ImgBufSpan, OutImgW, OutImgH, OutBufSpan); + case DataType::RGB32F: + return readBufToFlattenImg( + ImgBufSpan, OutImgW, OutImgH, OutBufSpan); + case DataType::BGR32F: + return readBufToFlattenImg( + ImgBufSpan, OutImgW, OutImgH, OutBufSpan); + break; + default: + spdlog::error("[WasmEdge-Image] Invalid output data format."sv); + return static_cast(ErrNo::Fail); + } +} + +Expect LoadPNG::body(const Runtime::CallingFrame &Frame, + uint32_t InImgBufPtr, uint32_t InImgBufLen, + uint32_t OutImgW, uint32_t OutImgH, + uint32_t OutType, uint32_t OutBufPtr, + uint32_t OutBufLen) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input image buffer. + MEM_SPAN_CHECK(ImgBufSpan, MemInst, char, InImgBufPtr, InImgBufLen, + "Failed when accessing the input image buffer memory."sv) + + // Check the output decoded image buffer. + MEM_SPAN_CHECK(OutBufSpan, MemInst, char, OutBufPtr, OutBufLen, + "Failed when accessing the output image data buffer memory."sv) + + switch (static_cast(OutType)) { + case DataType::RGB8: + return readBufToImg( + ImgBufSpan, OutImgW, OutImgH, OutBufSpan); + case DataType::BGR8: + return readBufToImg( + ImgBufSpan, OutImgW, OutImgH, OutBufSpan); + case DataType::RGB32F: + return readBufToFlattenImg( + ImgBufSpan, OutImgW, OutImgH, OutBufSpan); + case DataType::BGR32F: + return readBufToFlattenImg( + ImgBufSpan, OutImgW, OutImgH, OutBufSpan); + break; + default: + spdlog::error("[WasmEdge-Image] Invalid output data format."sv); + return static_cast(ErrNo::Fail); + } +} + +} // namespace WasmEdgeImage +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_func.h b/plugins/wasmedge_image/image_func.h new file mode 100644 index 00000000..099aa1a4 --- /dev/null +++ b/plugins/wasmedge_image/image_func.h @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "image_base.h" + +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeImage { + +class LoadJPG : public Func { +public: + LoadJPG(ImgEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t InImgBufPtr, uint32_t InImgBufLen, + uint32_t OutImgW, uint32_t OutImgH, uint32_t OutType, + uint32_t OutBufPtr, uint32_t OutBufLen); +}; + +class LoadPNG : public Func { +public: + LoadPNG(ImgEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t InImgBufPtr, uint32_t InImgBufLen, + uint32_t OutImgW, uint32_t OutImgH, uint32_t OutType, + uint32_t OutBufPtr, uint32_t OutBufLen); +}; + +} // namespace WasmEdgeImage +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_module.cpp b/plugins/wasmedge_image/image_module.cpp new file mode 100644 index 00000000..89ad74c5 --- /dev/null +++ b/plugins/wasmedge_image/image_module.cpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "image_module.h" +#include "image_func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasmEdgeImageModule::WasmEdgeImageModule() + : Runtime::Instance::ModuleInstance("wasmedge_image") { + addHostFunc("load_jpg", std::make_unique(Env)); + addHostFunc("load_png", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_module.h b/plugins/wasmedge_image/image_module.h new file mode 100644 index 00000000..a23b05f0 --- /dev/null +++ b/plugins/wasmedge_image/image_module.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "image_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeImageModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeImageModule(); + ~WasmEdgeImageModule() = default; + + WasmEdgeImage::ImgEnv &getEnv() { return Env; } + +private: + WasmEdgeImage::ImgEnv Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/CMakeLists.txt b/plugins/wasmedge_tensorflow/CMakeLists.txt new file mode 100644 index 00000000..ba93b141 --- /dev/null +++ b/plugins/wasmedge_tensorflow/CMakeLists.txt @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_library(wasmedgePluginWasmEdgeTensorflow + SHARED + tensorflow_env.cpp + tensorflow_func.cpp + tensorflow_module.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeTensorflow + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeTensorflow + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeTensorflow + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeTensorflow + PRIVATE + wasmedge_shared + ) +endif() + +include(WASINNDeps) +wasmedge_setup_tf_target(wasmedgePluginWasmEdgeTensorflow) + +install(TARGETS wasmedgePluginWasmEdgeTensorflow DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasmedge_tensorflow/tensorflow_base.h b/plugins/wasmedge_tensorflow/tensorflow_base.h new file mode 100644 index 00000000..56686110 --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_base.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "tensorflow_env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflow { + +template class Func : public Runtime::HostFunction { +public: + Func(TFEnv &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + TFEnv &Env; +}; + +} // namespace WasmEdgeTensorflow +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_env.cpp b/plugins/wasmedge_tensorflow/tensorflow_env.cpp new file mode 100644 index 00000000..9a672863 --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_env.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "tensorflow_env.h" +#include "tensorflow_module.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeTensorflowModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_tensorflow", + .Description = "Tensorflow plug-in for WasmEdge.", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 13, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_tensorflow", + .Description = + "This module contains WasmEdge-Tensorflow host functions.", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +} // namespace + +Plugin::PluginRegister WasmEdgeTensorflow::TFEnv::Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_env.h b/plugins/wasmedge_tensorflow/tensorflow_env.h new file mode 100644 index 00000000..8cfcd645 --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_env.h @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include "tensorflow/c/c_api.h" + +#include +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflow { + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + InvalidEncoding = 2, // Invalid encoding. + MissingMemory = 3, // Caller module is missing a memory export. + Busy = 4, // Device or resource busy. + RuntimeError = 5, // Runtime Error. +}; + +struct TensorList { + void reset() noexcept { + for (uint32_t I = 0; I < DataList.size(); ++I) { + if (DataList[I]) { + TF_DeleteTensor(DataList[I]); + } + } + NameMap.clear(); + OperList.clear(); + DataList.clear(); + } + + std::unordered_map NameMap; + std::vector OperList; + std::vector DataList; +}; + +struct Context { + Context() noexcept { Stat = TF_NewStatus(); } + ~Context() noexcept { + reset(); + TF_DeleteStatus(Stat); + } + + void clearInputs() noexcept { Inputs.reset(); } + + void clearOutputs() noexcept { Outputs.reset(); } + + void reset() noexcept { + if (GraphOpts) { + TF_DeleteImportGraphDefOptions(GraphOpts); + GraphOpts = nullptr; + } + if (Buffer) { + TF_DeleteBuffer(Buffer); + Buffer = nullptr; + } + if (Graph) { + TF_DeleteGraph(Graph); + Graph = nullptr; + } + if (SessionOpts) { + TF_DeleteSessionOptions(SessionOpts); + SessionOpts = nullptr; + } + if (Session) { + TF_CloseSession(Session, Stat); + TF_DeleteSession(Session, Stat); + Session = nullptr; + } + clearInputs(); + clearOutputs(); + } + + TF_Status *Stat; + TF_ImportGraphDefOptions *GraphOpts = nullptr; + TF_Buffer *Buffer = nullptr; + TF_Graph *Graph = nullptr; + TF_SessionOptions *SessionOpts = nullptr; + TF_Session *Session = nullptr; + struct TensorList Inputs; + struct TensorList Outputs; +}; + +struct TFEnv { + TFEnv() noexcept { TFContext.reserve(16U); } + + Context *getContext(const uint32_t ID) noexcept { + auto It = RecycledIdx.find(ID); + if (ID < TFContext.size() && It == RecycledIdx.end()) { + return &TFContext[ID]; + } + return nullptr; + } + uint32_t newContext() noexcept { + uint32_t NewIdx = TFContext.size(); + if (RecycledIdx.empty()) { + TFContext.emplace_back(); + } else { + NewIdx = *RecycledIdx.begin(); + RecycledIdx.erase(NewIdx); + } + return NewIdx; + } + void deleteContext(const uint32_t ID) noexcept { + if (ID < TFContext.size()) { + TFContext[ID].reset(); + RecycledIdx.insert(ID); + } + } + + static Plugin::PluginRegister Register; + +private: + std::unordered_set RecycledIdx; + std::vector TFContext; +}; + +} // namespace WasmEdgeTensorflow +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_func.cpp b/plugins/wasmedge_tensorflow/tensorflow_func.cpp new file mode 100644 index 00000000..a3ed8bcc --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_func.cpp @@ -0,0 +1,421 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "tensorflow_func.h" + +#include "common/log.h" +#include "common/span.h" + +#include "tensorflow/c/c_api.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflow { + +using namespace std::literals::string_view_literals; + +namespace { + +#define MEMINST_CHECK(Out, CallFrame, Index) \ + auto *Out = CallFrame.getMemoryByIndex(Index); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-Tensorflow] Memory instance not found."sv); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define SESSION_CHECK(Out, SessionID, Message, ErrNo) \ + auto *Out = Env.getContext(SessionID); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-Tensorflow] "sv Message); \ + return static_cast(ErrNo); \ + } + +#define MEM_SPAN_CHECK(OutSpan, MemInst, Type, BufPtr, BufLen, Message) \ + auto OutSpan = MemInst->getSpan(BufPtr, BufLen); \ + if (unlikely(OutSpan.size() != BufLen)) { \ + spdlog::error("[WasmEdge-Tensorflow] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define MEM_SV_CHECK(OutSV, MemInst, BufPtr, BufLen, Message) \ + auto OutSV = MemInst->getStringView(BufPtr, BufLen); \ + if (unlikely(OutSV.size() != BufLen)) { \ + spdlog::error("[WasmEdge-Tensorflow] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define MEM_PTR_CHECK(OutPtr, MemInst, Type, Offset, Message) \ + Type *OutPtr = MemInst->getPointer(Offset); \ + if (unlikely(OutPtr == nullptr)) { \ + spdlog::error("[WasmEdge-Tensorflow] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +std::pair parseIndex(std::string_view Name) { + // Check if there's index in the string key. + size_t Pos = Name.find(":"); + int Idx = 0; + std::string NameStr; + if (Pos != std::string::npos) { + Idx = std::strtol(Name.data() + Pos + 1, nullptr, 10); + NameStr = Name.substr(0, Pos); + } else { + NameStr = Name; + } + return std::make_pair(NameStr, Idx); +} + +} // namespace + +Expect CreateSession::body(const Runtime::CallingFrame &Frame, + uint32_t ModBufPtr, uint32_t ModBufLen, + uint32_t SessionIdPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input model buffer. + MEM_SPAN_CHECK(ModBufSpan, MemInst, char, ModBufPtr, ModBufLen, + "Failed when accessing the input model buffer memory."sv) + + // Check the return value: SessionIdPtr should be valid. + MEM_PTR_CHECK(SessionId, MemInst, uint32_t, SessionIdPtr, + "Failed when accessing the return SessionID memory."sv) + + // Create context and import graph. + uint32_t NewID = Env.newContext(); + SESSION_CHECK(Cxt, NewID, "Failed when allocating resources."sv, + ErrNo::MissingMemory) + + Cxt->Graph = TF_NewGraph(); + Cxt->Buffer = TF_NewBufferFromString(ModBufSpan.data(), ModBufLen); + Cxt->GraphOpts = TF_NewImportGraphDefOptions(); + TF_GraphImportGraphDef(Cxt->Graph, Cxt->Buffer, Cxt->GraphOpts, Cxt->Stat); + if (unlikely(TF_GetCode(Cxt->Stat) != TF_OK)) { + spdlog::error("[WasmEdge-Tensorflow] Cannot import graph from buffer: {}"sv, + TF_Message(Cxt->Stat)); + Env.deleteContext(NewID); + return static_cast(ErrNo::InvalidArgument); + } + + // Create session. + Cxt->SessionOpts = TF_NewSessionOptions(); + Cxt->Session = TF_NewSession(Cxt->Graph, Cxt->SessionOpts, Cxt->Stat); + if (unlikely(TF_GetCode(Cxt->Stat) != TF_OK)) { + spdlog::error("[WasmEdge-Tensorflow] Unable to create session: {}"sv, + TF_Message(Cxt->Stat)); + Env.deleteContext(NewID); + return static_cast(ErrNo::InvalidArgument); + } + + *SessionId = NewID; + return static_cast(ErrNo::Success); +} + +Expect CreateSessionSavedModel::body( + const Runtime::CallingFrame &Frame, uint32_t PathPtr, uint32_t PathLen, + uint32_t TagsBufPtr, uint32_t TagsBufLen, uint32_t SessionIdPtr) { + // TODO: Implement this function. + spdlog::error("[WasmEdge-Tensorflow] Saved model is not supported yet."sv); + return static_cast(ErrNo::InvalidArgument); + + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the model path buffer. + MEM_SV_CHECK(PathSV, MemInst, PathPtr, PathLen, + "Failed when accessing the model path buffer memory."sv) + + // Check the tags buffer. + struct MetaGraphDefTag { + uint32_t Ptr; + uint32_t Len; + }; + MEM_SPAN_CHECK(TagSpan, MemInst, MetaGraphDefTag, TagsBufPtr, TagsBufLen, + "Failed when accessing the tags memory."sv) + + // Check the elements of tags. + std::vector Tags; + Tags.reserve(TagsBufLen); + for (size_t I = 0; I < TagSpan.size(); ++I) { + const auto &Tag = TagSpan[I]; + MEM_SV_CHECK(TagNameSV, MemInst, Tag.Ptr, Tag.Len, + "Failed when accessing the tag name memory."sv) + Tags.emplace_back(TagNameSV); + } + + // Check the return value: SessionIdPtr should be valid. + MEM_PTR_CHECK(SessionId, MemInst, uint32_t, SessionIdPtr, + "Failed when accessing the return SessionID memory."sv) + + // TODO: Implement here. + + return static_cast(ErrNo::Success); +} + +Expect DeleteSession::body(const Runtime::CallingFrame &, + uint32_t SessionId) { + Env.deleteContext(SessionId); + return static_cast(ErrNo::Success); +} + +Expect RunSession::body(const Runtime::CallingFrame &, + uint32_t SessionId) { + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Delete old output tensors + for (auto T : Cxt->Outputs.DataList) { + if (T) { + TF_DeleteTensor(T); + } + } + + // Run session + TF_SessionRun(Cxt->Session, + // RunOptions + nullptr, + // Input tensors + Cxt->Inputs.OperList.data(), Cxt->Inputs.DataList.data(), + Cxt->Inputs.DataList.size(), + // Output tensors + Cxt->Outputs.OperList.data(), Cxt->Outputs.DataList.data(), + Cxt->Outputs.DataList.size(), + // Target operations + nullptr, 0, + // RunMetadata + nullptr, + // Output status + Cxt->Stat); + + if (unlikely(TF_GetCode(Cxt->Stat) != TF_OK)) { + spdlog::error("[WasmEdge-Tensorflow] Run session failed: {}"sv, + TF_Message(Cxt->Stat)); + return static_cast(ErrNo::Busy); + } + return static_cast(ErrNo::Success); +} + +Expect GetOutputTensor::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t NamePtr, + uint32_t NameLen, uint32_t TensorIdPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the input tensor operation name buffer. + MEM_SV_CHECK(NameSV, MemInst, NamePtr, NameLen, + "Failed when accessing the output name buffer memory."sv) + + // Check the return value: TensorIdPtr should be valid. + MEM_PTR_CHECK(TensorId, MemInst, uint32_t, TensorIdPtr, + "Failed when accessing the return TensorID memory."sv) + + // Find the output tensor ID. + auto It = Cxt->Outputs.NameMap.find(std::string(NameSV)); + if (unlikely(It == Cxt->Outputs.NameMap.end())) { + spdlog::error("[WasmEdge-Tensorflow] Output tensor {} not found."sv, + NameSV); + return static_cast(ErrNo::InvalidArgument); + } + *TensorId = It->second; + return static_cast(ErrNo::Success); +} + +Expect GetTensorLen::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t TensorId, + uint32_t LenPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the return value: LenPtr should be valid. + MEM_PTR_CHECK(Len, MemInst, uint32_t, LenPtr, + "Failed when accessing the return Length memory."sv) + + // Get output tensor from ID. + if (unlikely(TensorId >= Cxt->Outputs.DataList.size())) { + spdlog::error("[WasmEdge-Tensorflow] Invalid tensor ID."sv); + return static_cast(ErrNo::InvalidArgument); + } + + // Return tensor data length. + auto *Tensor = Cxt->Outputs.DataList[TensorId]; + if (likely(Tensor != nullptr)) { + *Len = TF_TensorByteSize(Tensor); + } else { + *Len = 0U; + } + return static_cast(ErrNo::Success); +} + +Expect GetTensorData::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t TensorId, + uint32_t BufPtr, uint32_t BufLen, + uint32_t WrittenBytesPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the output tensor buffer. + MEM_SPAN_CHECK( + BufSpan, MemInst, char, BufPtr, BufLen, + "Failed when accessing the output tensor write buffer memory."sv) + + // Check the return value: WrittenBytesPtr should be valid. + MEM_PTR_CHECK(WrittenBytes, MemInst, uint32_t, WrittenBytesPtr, + "Failed when accessing the return WrittenBytes memory."sv) + + // Get output tensor from ID. + if (unlikely(TensorId >= Cxt->Outputs.DataList.size())) { + spdlog::error("[WasmEdge-Tensorflow] Invalid tensor ID."sv); + return static_cast(ErrNo::InvalidArgument); + } + + // Copy tensor data to buffer. + auto *Tensor = Cxt->Outputs.DataList[TensorId]; + size_t RealSize = TF_TensorByteSize(Tensor); + *WrittenBytes = 0U; + if (Tensor != nullptr && RealSize > 0 && BufLen > 0) { + *WrittenBytes = std::min(static_cast(RealSize), BufLen); + char *Data = static_cast(TF_TensorData(Tensor)); + std::copy_n(Data, *WrittenBytes, BufSpan.data()); + } + return static_cast(ErrNo::Success); +} + +Expect AppendInput::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t NamePtr, + uint32_t NameLen, uint32_t DimPtr, + uint32_t DimCnt, uint32_t DataType, + uint32_t TensorBufPtr, + uint32_t TensorBufLen) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the input tensor buffer. + MEM_SPAN_CHECK(TensorBufSpan, MemInst, uint8_t, TensorBufPtr, TensorBufLen, + "Failed when accessing the input tensor buffer memory."sv) + + // Check the input tensor dimension buffer. + MEM_SPAN_CHECK(DimBufSpan, MemInst, int64_t, DimPtr, DimCnt, + "Failed when accessing the input dimension buffer memory."sv) + + // Check the input tensor operation name buffer. + MEM_SV_CHECK(NameSV, MemInst, NamePtr, NameLen, + "Failed when accessing the input name buffer memory."sv) + + // Check the input operation. + auto OperKeyPair = parseIndex(NameSV); + TF_Operation *Operation = + TF_GraphOperationByName(Cxt->Graph, OperKeyPair.first.c_str()); + if (unlikely(Operation == nullptr)) { + spdlog::error("[WasmEdge-Tensorflow] Input operation {} not found."sv, + NameSV); + return static_cast(ErrNo::InvalidArgument); + } + + // Check if the input tensor by name exists. + uint32_t TensorId = Cxt->Inputs.DataList.size(); + auto It = Cxt->Inputs.NameMap.find(std::string(NameSV)); + if (It != Cxt->Inputs.NameMap.end()) { + TensorId = It->second; + } + + // Create the tensor and copy data from buffer. + TF_Tensor *Tensor = nullptr; + if (DimCnt > 0) { + Tensor = TF_AllocateTensor(static_cast(DataType), + DimBufSpan.data(), DimCnt, TensorBufLen); + } else { + Tensor = TF_AllocateTensor(static_cast(DataType), nullptr, 0, + TensorBufLen); + } + if (unlikely(Tensor == nullptr)) { + spdlog::error("[WasmEdge-Tensorflow] Allocate input tensor failed."sv); + return static_cast(ErrNo::Busy); + } + std::copy_n(TensorBufSpan.begin(), TensorBufLen, + static_cast(TF_TensorData(Tensor))); + + // If the old input tensor exists, delete the old one. + if (It != Cxt->Inputs.NameMap.end()) { + TF_DeleteTensor(Cxt->Inputs.DataList[TensorId]); + Cxt->Inputs.DataList[TensorId] = Tensor; + } else { + Cxt->Inputs.OperList.emplace_back(TF_Output{Operation, OperKeyPair.second}); + Cxt->Inputs.DataList.push_back(Tensor); + Cxt->Inputs.NameMap.insert({std::string(NameSV), TensorId}); + } + return static_cast(ErrNo::Success); +} + +Expect AppendOutput::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t NamePtr, + uint32_t NameLen) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the output tensor operation name buffer. + MEM_SV_CHECK(NameSV, MemInst, NamePtr, NameLen, + "Failed when accessing the output name buffer memory."sv) + + // Check the output operation. + auto OperKeyPair = parseIndex(NameSV); + TF_Operation *Operation = + TF_GraphOperationByName(Cxt->Graph, OperKeyPair.first.c_str()); + if (unlikely(Operation == nullptr)) { + spdlog::error("[WasmEdge-Tensorflow] Output operation {} not found."sv, + NameSV); + return static_cast(ErrNo::InvalidArgument); + } + + // Store names and operations if the output tensor key not exists. + auto It = Cxt->Outputs.NameMap.find(std::string(NameSV)); + if (It == Cxt->Outputs.NameMap.end()) { + uint32_t TensorId = Cxt->Outputs.DataList.size(); + Cxt->Outputs.OperList.emplace_back( + TF_Output{Operation, OperKeyPair.second}); + Cxt->Outputs.DataList.push_back(nullptr); + Cxt->Outputs.NameMap.insert({std::string(NameSV), TensorId}); + } + return static_cast(ErrNo::Success); +} + +Expect ClearInput::body(const Runtime::CallingFrame &, + uint32_t SessionId) { + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Clear the inputs. + Cxt->clearInputs(); + return static_cast(ErrNo::Success); +} + +Expect ClearOutput::body(const Runtime::CallingFrame &, + uint32_t SessionId) { + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Clear the outputs. + Cxt->clearOutputs(); + return static_cast(ErrNo::Success); +} + +} // namespace WasmEdgeTensorflow +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_func.h b/plugins/wasmedge_tensorflow/tensorflow_func.h new file mode 100644 index 00000000..6a2786dc --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_func.h @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "tensorflow_base.h" + +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflow { + +class CreateSession : public Func { +public: + CreateSession(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ModBufPtr, + uint32_t ModBufLen, uint32_t SessionIdPtr); +}; + +class CreateSessionSavedModel : public Func { +public: + CreateSessionSavedModel(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t PathPtr, + uint32_t PathLen, uint32_t TagsBufPtr, + uint32_t TagsBufLen, uint32_t SessionIdPtr); +}; + +class DeleteSession : public Func { +public: + DeleteSession(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId); +}; + +class RunSession : public Func { +public: + RunSession(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId); +}; + +class GetOutputTensor : public Func { +public: + GetOutputTensor(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t NamePtr, uint32_t NameLen, + uint32_t TensorIdPtr); +}; + +class GetTensorLen : public Func { +public: + GetTensorLen(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t TensorId, uint32_t LenPtr); +}; + +class GetTensorData : public Func { +public: + GetTensorData(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t TensorId, uint32_t BufPtr, uint32_t BufLen, + uint32_t WrittenBytesPtr); +}; + +class AppendInput : public Func { +public: + AppendInput(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t NamePtr, uint32_t NameLen, uint32_t DimPtr, + uint32_t DimCnt, uint32_t DataType, + uint32_t TensorBufPtr, uint32_t TensorBufLen); +}; + +class AppendOutput : public Func { +public: + AppendOutput(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t NamePtr, uint32_t NameLen); +}; + +class ClearInput : public Func { +public: + ClearInput(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId); +}; + +class ClearOutput : public Func { +public: + ClearOutput(TFEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId); +}; + +} // namespace WasmEdgeTensorflow +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_module.cpp b/plugins/wasmedge_tensorflow/tensorflow_module.cpp new file mode 100644 index 00000000..73b079c4 --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_module.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "tensorflow_module.h" +#include "tensorflow_func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasmEdgeTensorflowModule::WasmEdgeTensorflowModule() + : Runtime::Instance::ModuleInstance("wasmedge_tensorflow") { + addHostFunc("create_session", + std::make_unique(Env)); + addHostFunc( + "create_session_saved_model", + std::make_unique(Env)); + addHostFunc("delete_session", + std::make_unique(Env)); + addHostFunc("run_session", + std::make_unique(Env)); + addHostFunc("get_output_tensor", + std::make_unique(Env)); + addHostFunc("get_tensor_len", + std::make_unique(Env)); + addHostFunc("get_tensor_data", + std::make_unique(Env)); + addHostFunc("append_input", + std::make_unique(Env)); + addHostFunc("append_output", + std::make_unique(Env)); + addHostFunc("clear_input", + std::make_unique(Env)); + addHostFunc("clear_output", + std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_module.h b/plugins/wasmedge_tensorflow/tensorflow_module.h new file mode 100644 index 00000000..ae60330e --- /dev/null +++ b/plugins/wasmedge_tensorflow/tensorflow_module.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "tensorflow_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeTensorflowModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeTensorflowModule(); + ~WasmEdgeTensorflowModule() = default; + + WasmEdgeTensorflow::TFEnv &getEnv() { return Env; } + +private: + WasmEdgeTensorflow::TFEnv Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/CMakeLists.txt b/plugins/wasmedge_tensorflowlite/CMakeLists.txt new file mode 100644 index 00000000..62e4707e --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/CMakeLists.txt @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_library(wasmedgePluginWasmEdgeTensorflowLite + SHARED + tensorflowlite_env.cpp + tensorflowlite_func.cpp + tensorflowlite_module.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeTensorflowLite + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeTensorflowLite + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeTensorflowLite + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeTensorflowLite + PRIVATE + wasmedge_shared + ) +endif() + +include(WASINNDeps) +wasmedge_setup_tflite_target(wasmedgePluginWasmEdgeTensorflowLite) + +install(TARGETS wasmedgePluginWasmEdgeTensorflowLite DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_base.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_base.h new file mode 100644 index 00000000..6b0765ae --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_base.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "tensorflowlite_env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflowLite { + +template class Func : public Runtime::HostFunction { +public: + Func(TFLiteEnv &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + TFLiteEnv &Env; +}; + +} // namespace WasmEdgeTensorflowLite +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp new file mode 100644 index 00000000..47a3e037 --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "tensorflowlite_env.h" +#include "tensorflowlite_module.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeTensorflowLiteModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_tensorflowlite", + .Description = "Tensorflow-Lite plug-in for WasmEdge.", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 13, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_tensorflowlite", + .Description = "This module contains WasmEdge-TensorflowLite " + "host functions.", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +} // namespace + +Plugin::PluginRegister WasmEdgeTensorflowLite::TFLiteEnv::Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h new file mode 100644 index 00000000..f4f606f4 --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include "tensorflow/lite/c/c_api.h" + +#include +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflowLite { + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + InvalidEncoding = 2, // Invalid encoding. + MissingMemory = 3, // Caller module is missing a memory export. + Busy = 4, // Device or resource busy. + RuntimeError = 5, // Runtime Error. +}; + +struct Context { + Context() = default; + ~Context() { reset(); } + void reset() noexcept { + if (Interp) { + TfLiteInterpreterDelete(Interp); + } + Interp = nullptr; + } + TfLiteInterpreter *Interp = nullptr; +}; + +struct TFLiteEnv { + TFLiteEnv() noexcept { TFLiteContext.reserve(16U); } + + Context *getContext(const uint32_t ID) noexcept { + auto It = RecycledIdx.find(ID); + if (ID < TFLiteContext.size() && It == RecycledIdx.end()) { + return &TFLiteContext[ID]; + } + return nullptr; + } + uint32_t newContext() noexcept { + uint32_t NewIdx = TFLiteContext.size(); + if (RecycledIdx.empty()) { + TFLiteContext.emplace_back(); + } else { + NewIdx = *RecycledIdx.begin(); + RecycledIdx.erase(NewIdx); + } + return NewIdx; + } + void deleteContext(const uint32_t ID) noexcept { + if (ID < TFLiteContext.size()) { + TFLiteContext[ID].reset(); + RecycledIdx.insert(ID); + } + } + + static Plugin::PluginRegister Register; + +private: + std::unordered_set RecycledIdx; + std::vector TFLiteContext; +}; + +} // namespace WasmEdgeTensorflowLite +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp new file mode 100644 index 00000000..18af8a1b --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "tensorflowlite_func.h" + +#include "common/log.h" +#include "common/span.h" + +#include "tensorflow/lite/c/c_api.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflowLite { + +namespace { + +#define MEMINST_CHECK(Out, CallFrame, Index) \ + auto *Out = CallFrame.getMemoryByIndex(Index); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-Tensorflow-Lite] Memory instance not found."sv); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define SESSION_CHECK(Out, SessionID, Message, ErrNo) \ + auto *Out = Env.getContext(SessionID); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-Tensorflow-Lite] "sv Message); \ + return static_cast(ErrNo); \ + } + +#define MEM_SPAN_CHECK(OutSpan, MemInst, Type, BufPtr, BufLen, Message) \ + auto OutSpan = MemInst->getSpan(BufPtr, BufLen); \ + if (unlikely(OutSpan.size() != BufLen)) { \ + spdlog::error("[WasmEdge-Tensorflow-Lite] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define MEM_SV_CHECK(OutSV, MemInst, BufPtr, BufLen, Message) \ + auto OutSV = MemInst->getStringView(BufPtr, BufLen); \ + if (unlikely(OutSV.size() != BufLen)) { \ + spdlog::error("[WasmEdge-Tensorflow-Lite] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define MEM_PTR_CHECK(OutPtr, MemInst, Type, Offset, Message) \ + Type *OutPtr = MemInst->getPointer(Offset); \ + if (unlikely(OutPtr == nullptr)) { \ + spdlog::error("[WasmEdge-Tensorflow-Lite] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +} // namespace + +Expect CreateSession::body(const Runtime::CallingFrame &Frame, + uint32_t ModBufPtr, uint32_t ModBufLen, + uint32_t SessionIdPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input model buffer. + MEM_SPAN_CHECK(ModBufSpan, MemInst, char, ModBufPtr, ModBufLen, + "Failed when accessing the input model buffer memory."sv) + + // Check the return value: SessionIdPtr should be valid. + MEM_PTR_CHECK(SessionId, MemInst, uint32_t, SessionIdPtr, + "Failed when accessing the return SessionID memory."sv) + + // Create context and import graph. + uint32_t NewID = Env.newContext(); + SESSION_CHECK(Cxt, NewID, "Failed when allocating resources."sv, + ErrNo::MissingMemory) + + auto *Model = TfLiteModelCreate(ModBufSpan.data(), ModBufLen); + if (unlikely(Model == nullptr)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Cannot import TFLite model."sv); + Env.deleteContext(NewID); + return static_cast(ErrNo::InvalidArgument); + } + auto *Ops = TfLiteInterpreterOptionsCreate(); + if (unlikely(Ops == nullptr)) { + spdlog::error( + "[WasmEdge-Tensorflow-Lite] Cannot create TFLite interpreter options."sv); + Env.deleteContext(NewID); + TfLiteModelDelete(Model); + return static_cast(ErrNo::Busy); + } + TfLiteInterpreterOptionsSetNumThreads(Ops, 2); + Cxt->Interp = TfLiteInterpreterCreate(Model, Ops); + TfLiteInterpreterOptionsDelete(Ops); + TfLiteModelDelete(Model); + if (unlikely(Cxt->Interp == nullptr)) { + spdlog::error( + "[WasmEdge-Tensorflow-Lite] Cannot create TFLite interpreter."sv); + Env.deleteContext(NewID); + return static_cast(ErrNo::Busy); + } + TfLiteStatus Status = TfLiteInterpreterAllocateTensors(Cxt->Interp); + if (unlikely(Status != TfLiteStatus::kTfLiteOk)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Cannot allocate tensors."sv); + Env.deleteContext(NewID); + return static_cast(ErrNo::Busy); + } + + *SessionId = NewID; + return static_cast(ErrNo::Success); +} + +Expect DeleteSession::body(const Runtime::CallingFrame &, + uint32_t SessionId) { + Env.deleteContext(SessionId); + return static_cast(ErrNo::Success); +} + +Expect RunSession::body(const Runtime::CallingFrame &, + uint32_t SessionId) { + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Run session + TfLiteStatus Stat = TfLiteInterpreterInvoke(Cxt->Interp); + if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Invocation failed."sv); + return static_cast(ErrNo::Busy); + } + return static_cast(ErrNo::Success); +} + +Expect GetOutputTensor::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t NamePtr, + uint32_t NameLen, uint32_t TensorIdPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the input tensor operation name buffer. + MEM_SV_CHECK(NameSV, MemInst, NamePtr, NameLen, + "Failed when accessing the output name buffer memory."sv) + + // Check the return value: TensorIdPtr should be valid. + MEM_PTR_CHECK(TensorId, MemInst, uint32_t, TensorIdPtr, + "Failed when accessing the return TensorID memory."sv) + + // Find the output tensor. + bool IsFound = false; + uint32_t OutCnt = TfLiteInterpreterGetOutputTensorCount(Cxt->Interp); + for (uint32_t I = 0; I < OutCnt; ++I) { + const TfLiteTensor *T = TfLiteInterpreterGetOutputTensor(Cxt->Interp, I); + if (NameSV == std::string(TfLiteTensorName(T))) { + *TensorId = I; + IsFound = true; + break; + } + } + if (unlikely(!IsFound)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Output tensor {} not found."sv, + NameSV); + return static_cast(ErrNo::InvalidArgument); + } + return static_cast(ErrNo::Success); +} + +Expect GetTensorLen::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t TensorId, + uint32_t LenPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the return value: LenPtr should be valid. + MEM_PTR_CHECK(Len, MemInst, uint32_t, LenPtr, + "Failed when accessing the return Length memory."sv) + + // Get output tensor from ID. + uint32_t OutCnt = TfLiteInterpreterGetOutputTensorCount(Cxt->Interp); + if (unlikely(TensorId >= OutCnt)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Invalid tensor ID."sv); + return static_cast(ErrNo::InvalidArgument); + } + + // Return tensor data length. + const TfLiteTensor *Tensor = + TfLiteInterpreterGetOutputTensor(Cxt->Interp, TensorId); + if (likely(Tensor != nullptr)) { + *Len = TfLiteTensorByteSize(Tensor); + } else { + *Len = 0U; + } + return static_cast(ErrNo::Success); +} + +Expect GetTensorData::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t TensorId, + uint32_t BufPtr, uint32_t BufLen, + uint32_t WrittenBytesPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the output tensor buffer. + MEM_SPAN_CHECK( + BufSpan, MemInst, char, BufPtr, BufLen, + "Failed when accessing the output tensor write buffer memory."sv) + + // Check the return value: WrittenBytesPtr should be valid. + MEM_PTR_CHECK(WrittenBytes, MemInst, uint32_t, WrittenBytesPtr, + "Failed when accessing the return WrittenBytes memory."sv) + + // Get output tensor from ID. + uint32_t OutCnt = TfLiteInterpreterGetOutputTensorCount(Cxt->Interp); + if (unlikely(TensorId >= OutCnt)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Invalid tensor ID."sv); + return static_cast(ErrNo::InvalidArgument); + } + + // Copy tensor data to buffer. + const TfLiteTensor *Tensor = + TfLiteInterpreterGetOutputTensor(Cxt->Interp, TensorId); + size_t RealSize = TfLiteTensorByteSize(Tensor); + *WrittenBytes = 0U; + if (unlikely(RealSize != BufLen)) { + spdlog::error( + "[WasmEdge-Tensorflow-Lite] Unexpected buffer length: {}, output tensor size: {}."sv, + BufLen, RealSize); + return static_cast(ErrNo::InvalidArgument); + } + if (likely(Tensor != nullptr)) { + TfLiteTensorCopyToBuffer(Tensor, BufSpan.data(), RealSize); + } + return static_cast(ErrNo::Success); +} + +Expect AppendInput::body(const Runtime::CallingFrame &Frame, + uint32_t SessionId, uint32_t NamePtr, + uint32_t NameLen, uint32_t TensorBufPtr, + uint32_t TensorBufLen) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Get context from ID. + SESSION_CHECK(Cxt, SessionId, "Invalid session ID."sv, ErrNo::InvalidArgument) + + // Check the input tensor buffer. + MEM_SPAN_CHECK(TensorBufSpan, MemInst, uint8_t, TensorBufPtr, TensorBufLen, + "Failed when accessing the input tensor buffer memory."sv) + + // Check the input tensor operation name buffer. + MEM_SV_CHECK(NameSV, MemInst, NamePtr, NameLen, + "Failed when accessing the input name buffer memory."sv) + + // Find the input tensor. + bool IsFound = false; + uint32_t InCnt = TfLiteInterpreterGetInputTensorCount(Cxt->Interp); + for (uint32_t I = 0; I < InCnt; ++I) { + TfLiteTensor *Tensor = TfLiteInterpreterGetInputTensor(Cxt->Interp, I); + if (NameSV == std::string(TfLiteTensorName(Tensor))) { + size_t RealSize = TfLiteTensorByteSize(Tensor); + if (unlikely(RealSize != TensorBufLen)) { + spdlog::error( + "[WasmEdge-Tensorflow-Lite] Unexpected buffer length: {}, " + "input tensor size: {}."sv, + TensorBufLen, RealSize); + return static_cast(ErrNo::InvalidArgument); + } + TfLiteStatus Stat = TfLiteTensorCopyFromBuffer( + Tensor, TensorBufSpan.data(), TensorBufLen); + if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { + spdlog::error( + "[WasmEdge-Tensorflow-Lite] Copy data from tensor {} failed."sv, + NameSV); + return static_cast(ErrNo::Busy); + } + IsFound = true; + break; + } + } + if (unlikely(!IsFound)) { + spdlog::error("[WasmEdge-Tensorflow-Lite] Input tensor {} not found."sv, + NameSV); + return static_cast(ErrNo::InvalidArgument); + } + + return static_cast(ErrNo::Success); +} + +} // namespace WasmEdgeTensorflowLite +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_func.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.h new file mode 100644 index 00000000..d8e9d5b5 --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.h @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "tensorflowlite_base.h" + +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeTensorflowLite { + +class CreateSession : public Func { +public: + CreateSession(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ModBufPtr, + uint32_t ModBufLen, uint32_t SessionIdPtr); +}; + +class DeleteSession : public Func { +public: + DeleteSession(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId); +}; + +class RunSession : public Func { +public: + RunSession(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId); +}; + +class GetOutputTensor : public Func { +public: + GetOutputTensor(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t NamePtr, uint32_t NameLen, + uint32_t TensorIdPtr); +}; + +class GetTensorLen : public Func { +public: + GetTensorLen(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t TensorId, uint32_t LenPtr); +}; + +class GetTensorData : public Func { +public: + GetTensorData(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t TensorId, uint32_t BufPtr, uint32_t BufLen, + uint32_t WrittenBytesPtr); +}; + +class AppendInput : public Func { +public: + AppendInput(TFLiteEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SessionId, + uint32_t NamePtr, uint32_t NameLen, + uint32_t TensorBufPtr, uint32_t TensorBufLen); +}; + +} // namespace WasmEdgeTensorflowLite +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_module.cpp b/plugins/wasmedge_tensorflowlite/tensorflowlite_module.cpp new file mode 100644 index 00000000..5f880229 --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_module.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "tensorflowlite_module.h" +#include "tensorflowlite_func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasmEdgeTensorflowLiteModule::WasmEdgeTensorflowLiteModule() + : Runtime::Instance::ModuleInstance("wasmedge_tensorflowlite") { + addHostFunc("create_session", + std::make_unique(Env)); + addHostFunc("delete_session", + std::make_unique(Env)); + addHostFunc("run_session", + std::make_unique(Env)); + addHostFunc("get_output_tensor", + std::make_unique(Env)); + addHostFunc("get_tensor_len", + std::make_unique(Env)); + addHostFunc("get_tensor_data", + std::make_unique(Env)); + addHostFunc("append_input", + std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_module.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_module.h new file mode 100644 index 00000000..a93545bb --- /dev/null +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_module.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "tensorflowlite_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeTensorflowLiteModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeTensorflowLiteModule(); + ~WasmEdgeTensorflowLiteModule() = default; + + WasmEdgeTensorflowLite::TFLiteEnv &getEnv() { return Env; } + +private: + WasmEdgeTensorflowLite::TFLiteEnv Env; +}; + +} // namespace Host +} // namespace WasmEdge From 72b61712dc2bc89b93b2260d4305ec77e221d3e6 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 7 Jun 2023 08:55:42 +0800 Subject: [PATCH 114/623] [Test] Add unit tests for WasmEdge-tensorflow, tensorflowlite, and image plugins. Signed-off-by: YiYing He --- test/plugins/CMakeLists.txt | 18 ++++++ test/plugins/wasmedge_image/CMakeLists.txt | 35 ++++++++++++ .../plugins/wasmedge_image/wasmedge_image.cpp | 47 +++++++++++++++ .../wasmedge_tensorflow/CMakeLists.txt | 38 +++++++++++++ .../wasmedge_tensorflow.cpp | 57 +++++++++++++++++++ .../wasmedge_tensorflowlite/CMakeLists.txt | 38 +++++++++++++ .../wasmedge_tensorflowlite.cpp | 54 ++++++++++++++++++ 7 files changed, 287 insertions(+) create mode 100644 test/plugins/wasmedge_image/CMakeLists.txt create mode 100644 test/plugins/wasmedge_image/wasmedge_image.cpp create mode 100644 test/plugins/wasmedge_tensorflow/CMakeLists.txt create mode 100644 test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp create mode 100644 test/plugins/wasmedge_tensorflowlite/CMakeLists.txt create mode 100644 test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index ccc4a365..d8dfb3c5 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -15,6 +15,24 @@ if(WASMEDGE_PLUGIN_WASI_CRYPTO) add_subdirectory(wasi_crypto) endif() +if(WASMEDGE_PLUGIN_TENSORFLOW) + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_tensorflow) + endif() +endif() + +if(WASMEDGE_PLUGIN_TENSORFLOWLITE) + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_tensorflowlite) + endif() +endif() + +if(WASMEDGE_PLUGIN_IMAGE) + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_image) + endif() +endif() + if(WASMEDGE_PLUGIN_WASM_BPF) if(CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasm_bpf) diff --git a/test/plugins/wasmedge_image/CMakeLists.txt b/test/plugins/wasmedge_image/CMakeLists.txt new file mode 100644 index 00000000..83a80949 --- /dev/null +++ b/test/plugins/wasmedge_image/CMakeLists.txt @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_executable(wasmedgeImageTests + wasmedge_image.cpp +) + +add_dependencies(wasmedgeImageTests + wasmedgePluginWasmEdgeImage +) + +target_include_directories(wasmedgeImageTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeImageTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeImageTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeImageTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeImageTests wasmedgeImageTests) diff --git a/test/plugins/wasmedge_image/wasmedge_image.cpp b/test/plugins/wasmedge_image/wasmedge_image.cpp new file mode 100644 index 00000000..0583c1c7 --- /dev/null +++ b/test/plugins/wasmedge_image/wasmedge_image.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/defines.h" +#include "image_func.h" +#include "image_module.h" +#include "runtime/instance/module.h" + +#include +#include +#include +#include +#include +#include + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_image/" + "libwasmedgePluginWasmEdgeImage" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_image"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_image"sv)) { + return Module->create().release(); + } + } + return nullptr; +} +} // namespace + +// TODO: unit tests for every functions. + +TEST(WasmEdgeImageTest, Module) { + // Create the wasmedge_image module instance. + auto *ImgMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ImgMod == nullptr); + EXPECT_EQ(ImgMod->getFuncExportNum(), 2U); + EXPECT_NE(ImgMod->findFuncExports("load_jpg"), nullptr); + EXPECT_NE(ImgMod->findFuncExports("load_png"), nullptr); + delete ImgMod; +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasmedge_tensorflow/CMakeLists.txt b/test/plugins/wasmedge_tensorflow/CMakeLists.txt new file mode 100644 index 00000000..aba4ee6e --- /dev/null +++ b/test/plugins/wasmedge_tensorflow/CMakeLists.txt @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_executable(wasmedgeTensorflowTests + wasmedge_tensorflow.cpp +) + +add_dependencies(wasmedgeTensorflowTests + wasmedgePluginWasmEdgeTensorflow +) + +include(WASINNDeps) +wasmedge_setup_tf_target(wasmedgeTensorflowTests) + +target_include_directories(wasmedgeTensorflowTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeTensorflowTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeTensorflowTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeTensorflowTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeTensorflowTests wasmedgeTensorflowTests) diff --git a/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp b/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp new file mode 100644 index 00000000..253ab828 --- /dev/null +++ b/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/defines.h" +#include "runtime/instance/module.h" +#include "tensorflow_func.h" +#include "tensorflow_module.h" + +#include +#include +#include +#include +#include +#include + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_tensorflow/" + "libwasmedgePluginWasmEdgeTensorflow" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_tensorflow"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_tensorflow"sv)) { + return Module->create().release(); + } + } + return nullptr; +} +} // namespace + +// TODO: unit tests for every functions. + +TEST(WasmEdgeTensorflowTest, Module) { + // Create the wasmedge_tensorflow module instance. + auto *TFMod = + dynamic_cast(createModule()); + EXPECT_FALSE(TFMod == nullptr); + EXPECT_EQ(TFMod->getFuncExportNum(), 11U); + EXPECT_NE(TFMod->findFuncExports("create_session"), nullptr); + EXPECT_NE(TFMod->findFuncExports("create_session_saved_model"), nullptr); + EXPECT_NE(TFMod->findFuncExports("delete_session"), nullptr); + EXPECT_NE(TFMod->findFuncExports("run_session"), nullptr); + EXPECT_NE(TFMod->findFuncExports("get_output_tensor"), nullptr); + EXPECT_NE(TFMod->findFuncExports("get_tensor_len"), nullptr); + EXPECT_NE(TFMod->findFuncExports("get_tensor_data"), nullptr); + EXPECT_NE(TFMod->findFuncExports("append_input"), nullptr); + EXPECT_NE(TFMod->findFuncExports("append_output"), nullptr); + EXPECT_NE(TFMod->findFuncExports("clear_input"), nullptr); + EXPECT_NE(TFMod->findFuncExports("clear_output"), nullptr); + delete TFMod; +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasmedge_tensorflowlite/CMakeLists.txt b/test/plugins/wasmedge_tensorflowlite/CMakeLists.txt new file mode 100644 index 00000000..04eee39a --- /dev/null +++ b/test/plugins/wasmedge_tensorflowlite/CMakeLists.txt @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_executable(wasmedgeTensorflowLiteTests + wasmedge_tensorflowlite.cpp +) + +add_dependencies(wasmedgeTensorflowLiteTests + wasmedgePluginWasmEdgeTensorflowLite +) + +include(WASINNDeps) +wasmedge_setup_tflite_target(wasmedgeTensorflowLiteTests) + +target_include_directories(wasmedgeTensorflowLiteTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeTensorflowLiteTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeTensorflowLiteTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeTensorflowLiteTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeTensorflowLiteTests wasmedgeTensorflowLiteTests) diff --git a/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp b/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp new file mode 100644 index 00000000..b0cfa5c3 --- /dev/null +++ b/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/defines.h" +#include "runtime/instance/module.h" +#include "tensorflowlite_func.h" +#include "tensorflowlite_module.h" + +#include +#include +#include +#include +#include +#include + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_tensorflowlite/" + "libwasmedgePluginWasmEdgeTensorflowLite" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_tensorflowlite"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_tensorflowlite"sv)) { + return Module->create().release(); + } + } + return nullptr; +} +} // namespace + +// TODO: unit tests for every functions. + +TEST(WasmEdgeTensorflowLiteTest, Module) { + // Create the wasmedge_tensorflowlite module instance. + auto *TFLiteMod = + dynamic_cast( + createModule()); + EXPECT_FALSE(TFLiteMod == nullptr); + EXPECT_EQ(TFLiteMod->getFuncExportNum(), 7U); + EXPECT_NE(TFLiteMod->findFuncExports("create_session"), nullptr); + EXPECT_NE(TFLiteMod->findFuncExports("delete_session"), nullptr); + EXPECT_NE(TFLiteMod->findFuncExports("run_session"), nullptr); + EXPECT_NE(TFLiteMod->findFuncExports("get_output_tensor"), nullptr); + EXPECT_NE(TFLiteMod->findFuncExports("get_tensor_len"), nullptr); + EXPECT_NE(TFLiteMod->findFuncExports("get_tensor_data"), nullptr); + EXPECT_NE(TFLiteMod->findFuncExports("append_input"), nullptr); + delete TFLiteMod; +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 943b243aa10d748b9551865fe0e078b103048822 Mon Sep 17 00:00:00 2001 From: Officeyutong Date: Tue, 20 Jun 2023 01:30:39 +0800 Subject: [PATCH 115/623] [Utils] Add clang tool in manylinux2014_x86_64 docker image (#2606) * update manylinux dockerfile Signed-off-by: officeyutong * Add checksum for clang*.src.tar.xz Signed-off-by: officeyutong --------- Signed-off-by: officeyutong --- utils/docker/Dockerfile.manylinux2014_x86_64 | 7 +++++-- utils/docker/SHA256SUM | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index 73013544..cc6a942a 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -25,7 +25,8 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/lld-16.0.5.src.tar.xz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/libunwind-16.0.5.src.tar.xz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/cmake-16.0.5.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/third-party-16.0.5.src.tar.xz && \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/third-party-16.0.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/clang-16.0.5.src.tar.xz && \ sha256sum -c SHA256SUM && \ gzip -dc zstd-1.5.5.tar.gz | tar -xf - && \ gzip -dc cmake-3.26.4.tar.gz | tar -xf - && \ @@ -35,6 +36,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil xz -dc libunwind-16.0.5.src.tar.xz | tar -xf - && \ xz -dc cmake-16.0.5.src.tar.xz | tar -xf - && \ xz -dc third-party-16.0.5.src.tar.xz | tar -xf - && \ + xz -dc clang-16.0.5.src.tar.xz | tar -xf - && \ export ZSTDFLAGS=(PREFIX=/opt/rh/devtoolset-11/root/usr LIBDIR=/opt/rh/devtoolset-11/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ cd zstd-1.5.5 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-11/root/usr/lib64/libzstd.so* && cd - && \ mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ @@ -48,10 +50,11 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil mv -v libunwind-16.0.5.src libunwind && \ mv -v cmake-16.0.5.src cmake && \ mv -v third-party-16.0.5.src third-party && \ + mv -v clang-16.0.5.src clang && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/opt/rh/devtoolset-11/root/usr \ -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ - -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_ENABLE_PROJECTS=lld \ + -DLLVM_TARGETS_TO_BUILD="X86;BPF" -DLLVM_ENABLE_PROJECTS="lld;clang" \ -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ -DBUILD_SHARED_LIBS=OFF llvm && \ cmake --build build --target install && \ diff --git a/utils/docker/SHA256SUM b/utils/docker/SHA256SUM index e176fbb4..903268be 100644 --- a/utils/docker/SHA256SUM +++ b/utils/docker/SHA256SUM @@ -5,3 +5,4 @@ e7f65970298a60e9608a9fc55ea9af5e9c8e1bc0dc0067f3e9f10eb3fe3e8986 libunwind-16.0 0a4bbb8505e95570e529d6b3d5176e93beb3260f061de9001e320d57b59aed59 third-party-16.0.5.src.tar.xz 31747ae633213f1eda3842686f83c2aa1412e0f5691d1c14dbbcc67fe7400cea v1.11.1.tar.gz 9c4396cc829cfae319a6e2615202e82aad41372073482fce286fac78646d3ee4 zstd-1.5.5.tar.gz +f4bb3456c415f01e929d96983b851c49d02b595bf4f99edbbfc55626437775a7 clang-16.0.5.src.tar.xz From 4decc4633d60245b35b18480fe87185d31f36de8 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Tue, 20 Jun 2023 05:17:39 +0800 Subject: [PATCH 116/623] [API] Support setting host data and its finalizer into module instances. Signed-off-by: YiYing He --- test/plugins/unittest/testplugin.c | 41 +++++++++++++++++++++++----- test/plugins/unittest/unittest_c.cpp | 5 ++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/test/plugins/unittest/testplugin.c b/test/plugins/unittest/testplugin.c index 7b58cfaf..8239f92b 100644 --- a/test/plugins/unittest/testplugin.c +++ b/test/plugins/unittest/testplugin.c @@ -5,38 +5,63 @@ #include "wasmedge/wasmedge.h" #include +#include +#include static WasmEdge_String NameString; static const char NameCString[] = "name"; static const WasmEdge_String NameStringDefaultValue = {.Buf = NameCString, .Length = 4}; +void Finalizer(void *Data) { + printf("Deallocate host data\n"); + free((int32_t *)Data); +} -WasmEdge_Result HostFuncAdd(void *Data __attribute__((unused)), +WasmEdge_Result HostFuncAdd(void *Data, const WasmEdge_CallingFrameContext *CallFrameCxt __attribute__((unused)), const WasmEdge_Value *In, WasmEdge_Value *Out) { + /* + * Host function to calculate A + B, + * and accumulate (A + B) to the host data. + */ int32_t Val1 = WasmEdge_ValueGetI32(In[0]); int32_t Val2 = WasmEdge_ValueGetI32(In[1]); Out[0] = WasmEdge_ValueGenI32(Val1 + Val2); + int32_t *Accum = (int32_t *)Data; + *Accum += WasmEdge_ValueGetI32(Out[0]); + printf("Current accumulate: %d\n", *Accum); return WasmEdge_Result_Success; } -WasmEdge_Result HostFuncSub(void *Data __attribute__((unused)), +WasmEdge_Result HostFuncSub(void *Data, const WasmEdge_CallingFrameContext *CallFrameCxt __attribute__((unused)), const WasmEdge_Value *In, WasmEdge_Value *Out) { + /* + * Host function to calculate A - B, + * and accumulate (A - B) to the host data. + */ int32_t Val1 = WasmEdge_ValueGetI32(In[0]); int32_t Val2 = WasmEdge_ValueGetI32(In[1]); Out[0] = WasmEdge_ValueGenI32(Val1 - Val2); + int32_t *Accum = (int32_t *)Data; + *Accum += WasmEdge_ValueGetI32(Out[0]); + printf("Current accumulate: %d\n", *Accum); return WasmEdge_Result_Success; } WasmEdge_ModuleInstanceContext * CreateTestModule(const struct WasmEdge_ModuleDescriptor *Desc) { - WasmEdge_String ModuleName = - WasmEdge_StringCreateByCString(Desc->Name); + /* Allocate and initialize a host data. */ + printf("Allocate host data\n"); + int32_t *Accumulate = (int32_t *)malloc(sizeof(int32_t)); + *Accumulate = 0; + + /* Create the module instance. */ + WasmEdge_String ModuleName = WasmEdge_StringCreateByCString(Desc->Name); WasmEdge_ModuleInstanceContext *Mod = - WasmEdge_ModuleInstanceCreate(ModuleName); + WasmEdge_ModuleInstanceCreateWithData(ModuleName, Accumulate, Finalizer); WasmEdge_StringDelete(ModuleName); WasmEdge_String FuncName; @@ -47,13 +72,15 @@ CreateTestModule(const struct WasmEdge_ModuleDescriptor *Desc) { ParamTypes[1] = WasmEdge_ValType_I32; ReturnTypes[0] = WasmEdge_ValType_I32; + /* Create the "add" function and add into the module instance. */ FType = WasmEdge_FunctionTypeCreate(ParamTypes, 2, ReturnTypes, 1); FuncName = WasmEdge_StringCreateByCString("add"); - FuncCxt = WasmEdge_FunctionInstanceCreate(FType, HostFuncAdd, NULL, 0); + FuncCxt = WasmEdge_FunctionInstanceCreate(FType, HostFuncAdd, Accumulate, 0); WasmEdge_ModuleInstanceAddFunction(Mod, FuncName, FuncCxt); WasmEdge_StringDelete(FuncName); + /* Create the "sub" function and add into the module instance. */ FuncName = WasmEdge_StringCreateByCString("sub"); - FuncCxt = WasmEdge_FunctionInstanceCreate(FType, HostFuncSub, NULL, 0); + FuncCxt = WasmEdge_FunctionInstanceCreate(FType, HostFuncSub, Accumulate, 0); WasmEdge_ModuleInstanceAddFunction(Mod, FuncName, FuncCxt); WasmEdge_StringDelete(FuncName); WasmEdge_FunctionTypeDelete(FType); diff --git a/test/plugins/unittest/unittest_c.cpp b/test/plugins/unittest/unittest_c.cpp index cc36965a..19f22993 100644 --- a/test/plugins/unittest/unittest_c.cpp +++ b/test/plugins/unittest/unittest_c.cpp @@ -56,6 +56,9 @@ TEST(wasmedgePluginTests, C_Run) { // Create the wasmedge_plugintest_c_module module instance. auto *ModInstC = createModuleC(); EXPECT_FALSE(ModInstC == nullptr); + int32_t *HostData = + static_cast(WasmEdge_ModuleInstanceGetHostData(ModInstC)); + EXPECT_FALSE(HostData == nullptr); // Create the wasmedge_plugintest_cpp_module module instance. auto *ModInstCPP = createModuleCPP(); @@ -71,6 +74,7 @@ TEST(wasmedgePluginTests, C_Run) { Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, Params, 2, Returns, 1); EXPECT_TRUE(WasmEdge_ResultOK(Res)); EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 444); + EXPECT_EQ(*HostData, 444); // Test: Run the function "wasmedge_plugintest_c.sub". FuncName = WasmEdge_StringCreateByCString("sub"); @@ -82,6 +86,7 @@ TEST(wasmedgePluginTests, C_Run) { Res = WasmEdge_ExecutorInvoke(ExecCxt, FuncCxt, Params, 2, Returns, 1); EXPECT_TRUE(WasmEdge_ResultOK(Res)); EXPECT_EQ(WasmEdge_ValueGetI32(Returns[0]), 111); + EXPECT_EQ(*HostData, 555); // Test: Run the function "wasmedge_plugintest_cpp.add". FuncName = WasmEdge_StringCreateByCString("add"); From 9c43af3c9863058a9e990434b8d41a01c7be2072 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 26 Jun 2023 13:31:46 +0800 Subject: [PATCH 117/623] [Plugin] Support saved-model in WasmEdge-Tensorflow plugin. Signed-off-by: YiYing He --- plugins/wasmedge_image/image_func.cpp | 2 +- .../wasmedge_tensorflow/tensorflow_func.cpp | 31 +++++++++++++++---- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/plugins/wasmedge_image/image_func.cpp b/plugins/wasmedge_image/image_func.cpp index 7710e53c..58d62649 100644 --- a/plugins/wasmedge_image/image_func.cpp +++ b/plugins/wasmedge_image/image_func.cpp @@ -31,7 +31,7 @@ bool decodeImgToSize(Span Buf, uint32_t W, uint32_t H, ImgStream.write(Buf.data(), Buf.size()); Image Img; try { - boost::gil::read_image(ImgStream, Img, FormatTag()); + boost::gil::read_and_convert_image(ImgStream, Img, FormatTag()); } catch (std::exception const &e) { spdlog::error("[WasmEdge-Image] Decode image fail: {}"sv, e.what()); return false; diff --git a/plugins/wasmedge_tensorflow/tensorflow_func.cpp b/plugins/wasmedge_tensorflow/tensorflow_func.cpp index a3ed8bcc..eed2b4fc 100644 --- a/plugins/wasmedge_tensorflow/tensorflow_func.cpp +++ b/plugins/wasmedge_tensorflow/tensorflow_func.cpp @@ -117,10 +117,6 @@ Expect CreateSession::body(const Runtime::CallingFrame &Frame, Expect CreateSessionSavedModel::body( const Runtime::CallingFrame &Frame, uint32_t PathPtr, uint32_t PathLen, uint32_t TagsBufPtr, uint32_t TagsBufLen, uint32_t SessionIdPtr) { - // TODO: Implement this function. - spdlog::error("[WasmEdge-Tensorflow] Saved model is not supported yet."sv); - return static_cast(ErrNo::InvalidArgument); - // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) @@ -137,21 +133,44 @@ Expect CreateSessionSavedModel::body( "Failed when accessing the tags memory."sv) // Check the elements of tags. - std::vector Tags; + std::vector Tags; + std::vector TagsArgv; Tags.reserve(TagsBufLen); + TagsArgv.reserve(TagsBufLen); for (size_t I = 0; I < TagSpan.size(); ++I) { + // Should use std::string to copy the tag name here to prevent from no + // null-termination of the tag strings here. const auto &Tag = TagSpan[I]; MEM_SV_CHECK(TagNameSV, MemInst, Tag.Ptr, Tag.Len, "Failed when accessing the tag name memory."sv) Tags.emplace_back(TagNameSV); + TagsArgv.emplace_back(Tags.back().c_str()); } // Check the return value: SessionIdPtr should be valid. MEM_PTR_CHECK(SessionId, MemInst, uint32_t, SessionIdPtr, "Failed when accessing the return SessionID memory."sv) - // TODO: Implement here. + // Create context and import graph. + uint32_t NewID = Env.newContext(); + SESSION_CHECK(Cxt, NewID, "Failed when allocating resources."sv, + ErrNo::MissingMemory) + // Create session. + Cxt->Graph = TF_NewGraph(); + Cxt->GraphOpts = TF_NewImportGraphDefOptions(); + Cxt->SessionOpts = TF_NewSessionOptions(); + Cxt->Session = TF_LoadSessionFromSavedModel( + Cxt->SessionOpts, nullptr, std::string(PathSV).c_str(), TagsArgv.data(), + TagsArgv.size(), Cxt->Graph, nullptr, Cxt->Stat); + if (unlikely(TF_GetCode(Cxt->Stat) != TF_OK)) { + spdlog::error("[WasmEdge-Tensorflow] Unable to create session: {}"sv, + TF_Message(Cxt->Stat)); + Env.deleteContext(NewID); + return static_cast(ErrNo::InvalidArgument); + } + + *SessionId = NewID; return static_cast(ErrNo::Success); } From 5604f1aaefdecb6ee04fb321f7d0710ab8fb0c07 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 28 Jun 2023 06:02:59 +0800 Subject: [PATCH 118/623] [CMake] Add the option to build WasmEdge-Image plugin. Signed-off-by: YiYing He --- plugins/wasmedge_image/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/wasmedge_image/CMakeLists.txt b/plugins/wasmedge_image/CMakeLists.txt index 5cd03215..2c2d8c5d 100644 --- a/plugins/wasmedge_image/CMakeLists.txt +++ b/plugins/wasmedge_image/CMakeLists.txt @@ -105,6 +105,7 @@ find_package(ZLIB REQUIRED) find_package(Boost 1.74.0) if(${Boost_FOUND}) else() + include(FetchContent) FetchContent_Declare( Boost URL https://boostorg.jfrog.io/artifactory/main/release/1.82.0/source/boost_1_82_0.tar.bz2 From 8658fc715e6187f2b1e62719cda637fc5c077a55 Mon Sep 17 00:00:00 2001 From: Divyanshu Gupta <114072061+Mash707@users.noreply.github.com> Date: Tue, 18 Jul 2023 15:36:05 +0530 Subject: [PATCH 119/623] [WASI-NN] Updating install script for OpenVino 2023.0.0 version (#2636) * Changed Openvino version to 2023.0.0 * Changed InferenceEngine to OpenVINO Signed-off-by: Divyanshu GUpta --------- Signed-off-by: Divyanshu GUpta --- utils/wasi-nn/build-wasinn-ubuntu-openvino.sh | 1 - utils/wasi-nn/install-openvino.sh | 17 +++++------------ utils/wasi-nn/test-wasinn-ubuntu-openvino.sh | 1 - 3 files changed, 5 insertions(+), 14 deletions(-) diff --git a/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh index eac9e848..5e3be658 100755 --- a/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh +++ b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh @@ -6,7 +6,6 @@ if [[ ! -v "${CMAKE_BUILD_TYPE}" ]]; then CMAKE_BUILD_TYPE=Release fi -source /opt/intel/openvino_2021/bin/setupvars.sh ldconfig git config --global --add safe.directory $(pwd) if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=$CMAKE_BUILD_TYPE -DWASMEDGE_BUILD_TESTS=ON -DWASMEDGE_PLUGIN_WASI_NN_BACKEND="OpenVINO" .; then diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh index 00d3c13b..c15b6e93 100755 --- a/utils/wasi-nn/install-openvino.sh +++ b/utils/wasi-nn/install-openvino.sh @@ -2,18 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2022 Second State INC -if [[ ! -v "${OPENVINO_VERSION}" ]]; then - OPENVINO_VERSION="2021.4.582" -fi -if [[ ! -v "${OPENVINO_YEAR}" ]]; then - OPENVINO_YEAR="2021" -fi - set -e -echo "Installing OpenVINO with version ${OPENVINO_VERSION}" -curl -sSL https://apt.repos.intel.com/openvino/$OPENVINO_YEAR/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR | gpg --dearmor > /usr/share/keyrings/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR.gpg -echo "deb [signed-by=/usr/share/keyrings/GPG-PUB-KEY-INTEL-OPENVINO-$OPENVINO_YEAR.gpg] https://apt.repos.intel.com/openvino/$OPENVINO_YEAR all main" | tee /etc/apt/sources.list.d/intel-openvino-$OPENVINO_YEAR.list +echo "Installing OpenVINO with version 2023.0.0" +wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB +apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB +echo "deb https://apt.repos.intel.com/openvino/2023 ubuntu20 main" | tee /etc/apt/sources.list.d/intel-openvino-2023.list apt update -apt install -y intel-openvino-runtime-ubuntu20-$OPENVINO_VERSION -source /opt/intel/openvino_2021/bin/setupvars.sh +apt-get -y install openvino ldconfig diff --git a/utils/wasi-nn/test-wasinn-ubuntu-openvino.sh b/utils/wasi-nn/test-wasinn-ubuntu-openvino.sh index 6500a296..49198538 100755 --- a/utils/wasi-nn/test-wasinn-ubuntu-openvino.sh +++ b/utils/wasi-nn/test-wasinn-ubuntu-openvino.sh @@ -1,4 +1,3 @@ -source /opt/intel/openvino_2021/bin/setupvars.sh ldconfig export LD_LIBRARY_PATH="$(pwd)/build/lib/api:$LD_LIBRARY_PATH" From bc690ab1fa25a1983d2cd7f7073cf4d3853af131 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 18 Jul 2023 17:13:19 +0800 Subject: [PATCH 120/623] [Utils] devtoolset-11 is not available on manylinux2014 aarch64, downgrade to devtoolset-10 Signed-off-by: hydai --- utils/docker/Dockerfile.manylinux2014_aarch64 | 31 ++++++++++--------- utils/docker/build-manylinux.sh | 2 +- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index 6f30df72..59ca239e 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -7,16 +7,16 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM /root/ -ENV PATH /opt/rh/devtoolset-11/root/usr/bin${PATH:+:${PATH}} -ENV MANPATH /opt/rh/devtoolset-11/root/usr/share/man${MANPATH:+:${MANPATH}} -ENV INFOPATH /opt/rh/devtoolset-11/root/usr/share/info${INFOPATH:+:${INFOPATH}} -ENV PKG_CONFIG_PATH /opt/rh/devtoolset-11/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV PATH /opt/rh/devtoolset-10/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH /opt/rh/devtoolset-10/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH /opt/rh/devtoolset-10/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV PKG_CONFIG_PATH /opt/rh/devtoolset-10/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build dpkg centos-release-scl && \ - yum install -y devtoolset-11 && \ +RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-build centos-release-scl && \ + yum install -y devtoolset-10 && \ export CPU=$(/opt/python/cp311-cp311/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ - export CFGFLAGS="--prefix=/opt/rh/devtoolset-11/root/usr --disable-shared --libdir=/opt/rh/devtoolset-11/root/usr/lib64" && \ + export CFGFLAGS="--prefix=/opt/rh/devtoolset-10/root/usr --disable-shared --libdir=/opt/rh/devtoolset-10/root/usr/lib64" && \ curl -s -L -O --remote-name-all \ https://github.com/facebook/zstd/releases/download/v1.5.5/zstd-1.5.5.tar.gz \ https://github.com/Kitware/CMake/releases/download/v3.26.4/cmake-3.26.4.tar.gz \ @@ -25,7 +25,8 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/lld-16.0.5.src.tar.xz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/libunwind-16.0.5.src.tar.xz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/cmake-16.0.5.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/third-party-16.0.5.src.tar.xz && \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/third-party-16.0.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/clang-16.0.5.src.tar.xz && \ sha256sum -c SHA256SUM && \ gzip -dc zstd-1.5.5.tar.gz | tar -xf - && \ gzip -dc cmake-3.26.4.tar.gz | tar -xf - && \ @@ -35,23 +36,25 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil xz -dc libunwind-16.0.5.src.tar.xz | tar -xf - && \ xz -dc cmake-16.0.5.src.tar.xz | tar -xf - && \ xz -dc third-party-16.0.5.src.tar.xz | tar -xf - && \ - export ZSTDFLAGS=(PREFIX=/opt/rh/devtoolset-11/root/usr LIBDIR=/opt/rh/devtoolset-11/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ - cd zstd-1.5.5 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-11/root/usr/lib64/libzstd.so* && cd - && \ + xz -dc clang-16.0.5.src.tar.xz | tar -xf - && \ + export ZSTDFLAGS=(PREFIX=/opt/rh/devtoolset-10/root/usr LIBDIR=/opt/rh/devtoolset-10/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ + cd zstd-1.5.5 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-10/root/usr/lib64/libzstd.so* && cd - && \ mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ ../ninja-1.11.1/configure.py --bootstrap \ --with-python=/opt/python/cp311-cp311/bin/python && \ - cp -v ninja /opt/rh/devtoolset-11/root/usr/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.26.4/configure --prefix=/opt/rh/devtoolset-11/root/usr \ + cp -v ninja /opt/rh/devtoolset-10/root/usr/bin/ninja && cd - && rm -rf build && \ + mkdir build && cd build && ../cmake-3.26.4/configure --prefix=/opt/rh/devtoolset-10/root/usr \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ mv -v llvm-16.0.5.src llvm && \ mv -v lld-16.0.5.src lld && \ mv -v libunwind-16.0.5.src libunwind && \ mv -v cmake-16.0.5.src cmake && \ mv -v third-party-16.0.5.src third-party && \ + mv -v clang-16.0.5.src clang && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_INSTALL_PREFIX=/opt/rh/devtoolset-11/root/usr \ + -DCMAKE_INSTALL_PREFIX=/opt/rh/devtoolset-10/root/usr \ -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ - -DLLVM_TARGETS_TO_BUILD="AArch64" -DLLVM_ENABLE_PROJECTS=lld \ + -DLLVM_TARGETS_TO_BUILD="AArch64;BPF" -DLLVM_ENABLE_PROJECTS="lld;clang" \ -DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-redhat-linux-gnu" \ -DBUILD_SHARED_LIBS=OFF llvm && \ cmake --build build --target install && \ diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index 6570fbe7..2c4ead7b 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -35,7 +35,7 @@ for i in "$@"; do done if $IS_NINJA; then - if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM;DEB" ${CMAKE_OPTS} .; then + if ! cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DWASMEDGE_BUILD_PACKAGE="TGZ;TBZ2;TXZ;TZST;RPM" ${CMAKE_OPTS} .; then echo === CMakeOutput.log === cat build/CMakeFiles/CMakeOutput.log echo === CMakeError.log === From 15054d6ee9f77d19d41898ebc3bc69f9b91f7431 Mon Sep 17 00:00:00 2001 From: Officeyutong Date: Wed, 19 Jul 2023 16:10:36 +0700 Subject: [PATCH 121/623] [Misc] Add release workflows for wasm_bpf plugin and the related installing script (#2610) * Install necessary dependencies for testing wasm_bpf * Install dependencies on the manylinux platform * Add test file for wasm_bpf * Enable workflow for testing wasm_bpf * Update `build-extensions.yml` * Make release.yml able to publish wasm_bpf * Update installation script * Update version to 0.13.2 Signed-off-by: officeyutong --- plugins/wasm_bpf/CMakeLists.txt | 91 ++++++++++++++----- plugins/wasm_bpf/bpf-api.h | 8 ++ plugins/wasm_bpf/func-bpf-map-operate.cpp | 8 +- test/plugins/wasm_bpf/CMakeLists.txt | 3 + test/plugins/wasm_bpf/simple_ringbuf_test.cpp | 2 +- test/plugins/wasm_bpf/wasm_bpf.cpp | 2 +- 6 files changed, 84 insertions(+), 30 deletions(-) diff --git a/plugins/wasm_bpf/CMakeLists.txt b/plugins/wasm_bpf/CMakeLists.txt index e64aa599..0e4e8120 100644 --- a/plugins/wasm_bpf/CMakeLists.txt +++ b/plugins/wasm_bpf/CMakeLists.txt @@ -6,28 +6,40 @@ # - ${LIBBPF_SOURCE_DIR} # - FetchContent +option(WASMEDGE_PLUGIN_WASM_BPF_BUILD_LIBBPF_WITH_PKG_CONF "Configure libbpf to use pkg-config for the build process. If enabled, the libbpf build script will utilize pkg-config to search for dependencies such as libz and libelf. If this feature is disabled, the headers and binaries for libz and libelf need to be correctly positioned." YES) + message(STATUS "Trying to get libbpf..") +message(STATUS "Build libbpf with pkg-config: ${WASMEDGE_PLUGIN_WASM_BPF_BUILD_LIBBPF_WITH_PKG_CONF}") set(LIBBPF_FOUND FALSE) # A wrapper function to add libbpf located at a local path as a dependency -function(AddLibbpfAsExternal SOURCE_ROOT) +function(AddLibbpfAsExternal SOURCE_ROOT WITH_PKG_CONF) include(ExternalProject) + set(LIBBPF_SO_PATH ${SOURCE_ROOT}/src/build/libbpf.so) + set(LIBBPF_INCLUDE_DIRS_LOCAL "${SOURCE_ROOT}/src/root/usr/include" "${SOURCE_ROOT}/include/uapi" "${SOURCE_ROOT}/include") + set(LIBBPF_INCLUDE_DIRS ${LIBBPF_INCLUDE_DIRS_LOCAL} PARENT_SCOPE) + + set(LIBBPF_LIBRARIES ${LIBBPF_SO_PATH} PARENT_SCOPE) + set(LIBBPF_LIBRARIES_STATIC ${SOURCE_ROOT}/src/build/libbpf.a PARENT_SCOPE) + + if(${WITH_PKG_CONF}) + set(PKGCONF_PREFIX "") + else() + set(PKGCONF_PREFIX "NO_PKG_CONFIG=1") + set(LIBBPF_DEP_LIBRARIES "elf" "z" PARENT_SCOPE) + endif() + message(STATUS "SOURCE_ROOT=${SOURCE_ROOT}") ExternalProject_Add(libbpf PREFIX libbpf SOURCE_DIR ${SOURCE_ROOT} - CONFIGURE_COMMAND "" - BUILD_COMMAND "make" "-C" "${SOURCE_ROOT}/src" - INSTALL_COMMAND "" + CONFIGURE_COMMAND "mkdir" "build" "root" + BUILD_COMMAND "${PKGCONF_PREFIX}" "OBJDIR=${SOURCE_ROOT}/src/build" "DESTDIR=${SOURCE_ROOT}/src/root" "CFLAGS=-fPIC" "make" "-C" "${SOURCE_ROOT}/src" "install" + INSTALL_COMMAND "cp" "${LIBBPF_SO_PATH}" "${CMAKE_CURRENT_BINARY_DIR}/libbpf.so" BUILD_IN_SOURCE TRUE + BUILD_BYPRODUCTS ${LIBBPF_SO_PATH} ${SOURCE_ROOT}/src/build/libbpf.a ) - set(LIBBPF_SO_PATH ${SOURCE_ROOT}/src/libbpf.so) - set(LIBBPF_INCLUDE_DIRS ${SOURCE_ROOT}/src PARENT_SCOPE) - set(LIBBPF_LIBRARIES ${LIBBPF_SO_PATH} PARENT_SCOPE) - set(LIBBPF_TARGET_NAME libbpf PARENT_SCOPE) - file(COPY_FILE ${LIBBPF_SO_PATH} ${CMAKE_CURRENT_BINARY_DIR}/libbpf.so) - # Copy libbpf.so to the place where libwasmedgePluginWasmBpf.so exists - message(STATUS "Copied libbpf.so from ${LIBBPF_SO_PATH} to ${CMAKE_CURRENT_BINARY_DIR}/libbpf.so") + set(LIBBPF_TARGET_NAME libbpf PARENT_SCOPE) endfunction() # Try PkgConfig @@ -38,11 +50,19 @@ if(NOT ${LIBBPF_FOUND}) message(STATUS "Try to get libbpf through PkgConfig") # It will set LIBBPF_FOUND for us - pkg_check_modules(LIBBPF libbpf>=1.2 IMPORTED_TARGET) + pkg_check_modules(LIBBPF libbpf>=1.2.0 IMPORTED_TARGET) set(LIBBPF_TARGET_NAME "PkgConfig::LIBBPF") + message(STATUS "LIBBPF_FOUND=${LIBBPF_FOUND}") + + if(${LIBBPF_FOUND}) + SET(LIBBPF_FOUND TRUE) + else() + SET(LIBBPF_FOUND FALSE) + endif() if(${LIBBPF_FOUND}) message(STATUS "libbpf found using PkgConfig") + set(LIBBPF_SOURCE "pkgconf") else() message(STATUS "libbpf not found using pkgconfig") endif() @@ -56,9 +76,10 @@ if(NOT ${LIBBPF_FOUND}) message(STATUS "Try to get libbpf through the pre-defined LIBBPF_SOURCE_DIR") if(DEFINED LIBBPF_SOURCE_DIR) - AddLibbpfAsExternal(${LIBBPF_SOURCE_DIR}) + AddLibbpfAsExternal(${LIBBPF_SOURCE_DIR} ${WASMEDGE_PLUGIN_WASM_BPF_BUILD_LIBBPF_WITH_PKG_CONF}) set(LIBBPF_FOUND TRUE) message(STATUS "libbpf found using LIBBPF_SOURCE_DIR") + set(LIBBPF_SOURCE "sourcedir") else() message(STATUS "LIBBPF_SOURCE_DIR not defined") endif() @@ -83,8 +104,9 @@ if(NOT ${LIBBPF_FOUND}) set(LIBBPF_DOWNLOAD_SOURCE_DIR "${libbpf_SOURCE_DIR}") message(DEBUG "libbpf saved at: ${LIBBPF_DOWNLOAD_SOURCE_DIR}") - AddLibbpfAsExternal(${LIBBPF_DOWNLOAD_SOURCE_DIR}) + AddLibbpfAsExternal(${LIBBPF_DOWNLOAD_SOURCE_DIR} ${WASMEDGE_PLUGIN_WASM_BPF_BUILD_LIBBPF_WITH_PKG_CONF}) set(LIBBPF_FOUND TRUE) + set(LIBBPF_SOURCE "fetch-content") endif() # If we cannot find libbpf.. @@ -92,16 +114,22 @@ if(NOT ${LIBBPF_FOUND}) message(FATAL_ERROR "Could not find libbpf") endif() -message(DEBUG "LIBBPF_INCLUDE_DIRS=${LIBBPF_INCLUDE_DIRS}") -message(DEBUG "LIBBPF_LIBRARIES=${LIBBPF_LIBRARIES}") -message(DEBUG "LIBBPF_TARGET_NAME=${LIBBPF_TARGET_NAME}") +if(${WASMEDGE_PLUGIN_WASM_BPF_BUILD_LIBBPF_WITH_PKG_CONF}) + # Find the dependencies `libelf` and `libz` of libbpf + find_package(PkgConfig) -# Find the dependencies `libelf` and `libz` of libbpf -find_package(PkgConfig) + pkg_check_modules(LIBBPF_DEP REQUIRED libelf zlib) -pkg_check_modules(LIBBPF_DEP REQUIRED libelf zlib) + message(STATUS "(From PKGCONF) LIBBPF_DEP_LIBRARIES=${LIBBPF_DEP_LIBRARIES}") +endif() -message(DEBUG "LIBBPF_DEP_LIBRARIES=${LIBBPF_DEP_LIBRARIES}") +message(STATUS "LIBBPF_INCLUDE_DIRS=${LIBBPF_INCLUDE_DIRS}") +message(STATUS "LIBBPF_LIBRARIES=${LIBBPF_LIBRARIES}") +message(STATUS "LIBBPF_TARGET_NAME=${LIBBPF_TARGET_NAME}") +message(STATUS "LIBBPF_LIBRARIES_STATIC=${LIBBPF_LIBRARIES_STATIC}") +message(STATUS "LIBBPF_SOURCE=${LIBBPF_SOURCE}") +message(STATUS "LIBBPF_DEP_LIBRARIES=${LIBBPF_DEP_LIBRARIES}") +message(STATUS "LIBBPF_DEP_LIBRARIES_STATIC=${LIBBPF_DEP_LIBRARIES_STATIC}") wasmedge_add_library(wasmedgePluginWasmBpf SHARED @@ -117,11 +145,28 @@ wasmedge_add_library(wasmedgePluginWasmBpf ) add_dependencies(wasmedgePluginWasmBpf ${LIBBPF_TARGET_NAME}) -target_link_libraries(wasmedgePluginWasmBpf PUBLIC ${LIBBPF_LIBRARIES} ${LIBBPF_DEP_LIBRARIES}) + +if("${LIBBPF_SOURCE}" STREQUAL "pkgconf") + message(STATUS "Link libbpf dynamically") + target_link_libraries(wasmedgePluginWasmBpf PUBLIC ${LIBBPF_LIBRARIES} ${LIBBPF_DEP_LIBRARIES}) +else() + # Link libbpf statically if we don't use pkgconf, because under this case libbpf is not installed systemwide + message(STATUS "Link libbpf statically") + target_link_libraries(wasmedgePluginWasmBpf PUBLIC ${LIBBPF_LIBRARIES_STATIC} ${LIBBPF_DEP_LIBRARIES}) +endif() + target_include_directories(wasmedgePluginWasmBpf PUBLIC ${LIBBPF_INCLUDE_DIRS}) + set_target_properties(wasmedgePluginWasmBpf PROPERTIES CXX_STANDARD 17 + # Allow tests accessing plugin class functions + CXX_VISIBILITY_PRESET default + VISIBILITY_INLINES_HIDDEN OFF +) +# Fix undefined reference issue of `fmt::v9::formatter::format(WasmEdge::ErrInfo::InfoBoundary const&, fmt::v9::basic_format_context&) const` +target_link_libraries(wasmedgePluginWasmBpf + PUBLIC wasmedgeCommon ) target_compile_options(wasmedgePluginWasmBpf @@ -141,12 +186,10 @@ if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmedgePluginWasmBpf PRIVATE wasmedgeCAPI - ${LIBBPF_LIBRARIES} ) else() target_link_libraries(wasmedgePluginWasmBpf PRIVATE wasmedge_shared - ${LIBBPF_LIBRARIES} ) endif() diff --git a/plugins/wasm_bpf/bpf-api.h b/plugins/wasm_bpf/bpf-api.h index 9039cf69..3b2018fc 100644 --- a/plugins/wasm_bpf/bpf-api.h +++ b/plugins/wasm_bpf/bpf-api.h @@ -7,11 +7,19 @@ #include "runtime/instance/module.h" #include "wasmedge/wasmedge.h" +#pragma GCC diagnostic push +#ifdef __clang__ +// Allow compilation using clang +#pragma GCC diagnostic warning "-Wextern-c-compat" + +#endif extern "C" { #include #include } +#pragma GCC diagnostic pop + #define POLL_TIMEOUT_MS 100 #define PERF_BUFFER_PAGES 64 #define DEBUG_LIBBPF_RUNTIME 0 diff --git a/plugins/wasm_bpf/func-bpf-map-operate.cpp b/plugins/wasm_bpf/func-bpf-map-operate.cpp index 31ff90ee..e1dbc242 100644 --- a/plugins/wasm_bpf/func-bpf-map-operate.cpp +++ b/plugins/wasm_bpf/func-bpf-map-operate.cpp @@ -11,9 +11,9 @@ extern "C" { namespace WasmEdge { namespace Host { -#define ensure_memory_size(var, offset, size) \ - const auto var##_span = memory->getSpan(offset, size); \ - if (var##_span.size() != size) \ +#define ensure_memory_size(var, offset, expected_size) \ + const auto var##_span = memory->getSpan(offset, expected_size); \ + if (var##_span.size() != expected_size) \ return Unexpect(ErrCode::Value::HostFuncError); \ const auto var = var##_span.data(); Expect @@ -39,7 +39,7 @@ BpfMapOperate::body(const WasmEdge::Runtime::CallingFrame &Frame, int32_t fd, auto key_size = map_info.key_size; auto value_size = map_info.value_size; - switch ((bpf_map_cmd)cmd) { + switch ((bpf_cmd)cmd) { case BPF_MAP_GET_NEXT_KEY: { ensure_memory_size(key_ptr, key, key_size); ensure_memory_size(next_key_ptr, next_key, key_size); diff --git a/test/plugins/wasm_bpf/CMakeLists.txt b/test/plugins/wasm_bpf/CMakeLists.txt index 735764be..c91608c5 100644 --- a/test/plugins/wasm_bpf/CMakeLists.txt +++ b/test/plugins/wasm_bpf/CMakeLists.txt @@ -23,17 +23,20 @@ target_include_directories(wasmBpfTests target_link_libraries(wasmBpfTests PRIVATE ${GTEST_BOTH_LIBRARIES} + wasmedgePluginWasmBpf ) # Link to the WasmEdge library if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmBpfTests PRIVATE wasmedgeCAPI + wasmedgeExecutor ) else() target_link_libraries(wasmBpfTests PRIVATE wasmedge_shared + wasmedgeExecutor ) endif() diff --git a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp index cda1865b..9202d8f8 100644 --- a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp +++ b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp @@ -56,7 +56,7 @@ class PollCallbackFunction if (data_sz < static_cast(sizeof(uint32_t))) { return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); } - const uint32_t* dataPtr = memory->getPointer(data, 1); + const uint32_t* dataPtr = memory->getSpan(data, 1).data(); if (unlikely(!dataPtr)) { return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); } diff --git a/test/plugins/wasm_bpf/wasm_bpf.cpp b/test/plugins/wasm_bpf/wasm_bpf.cpp index 2ac2078b..11adc89a 100644 --- a/test/plugins/wasm_bpf/wasm_bpf.cpp +++ b/test/plugins/wasm_bpf/wasm_bpf.cpp @@ -102,7 +102,7 @@ class PollCallbackFunction if (data_sz < static_cast(sizeof(event))) { return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); } - const event *dataPtr = memory->getPointer(data, 1); + const event *dataPtr = memory->getSpan(data, 1).data(); if (unlikely(!dataPtr)) { return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); } From 83912076a97dfa86edc79b46ebc7e7fef6dec557 Mon Sep 17 00:00:00 2001 From: Jorge Prendes Date: Thu, 20 Jul 2023 14:39:51 +0100 Subject: [PATCH 122/623] [CI] build static library in github actions (#2666) Signed-off-by: Jorge Prendes --- utils/docker/Dockerfile.debian-static | 92 ++++++++++++++++++++++ utils/docker/docker-bake.debian-static.hcl | 17 ++++ 2 files changed, 109 insertions(+) create mode 100644 utils/docker/Dockerfile.debian-static create mode 100644 utils/docker/docker-bake.debian-static.hcl diff --git a/utils/docker/Dockerfile.debian-static b/utils/docker/Dockerfile.debian-static new file mode 100644 index 00000000..7c43fc78 --- /dev/null +++ b/utils/docker/Dockerfile.debian-static @@ -0,0 +1,92 @@ +# syntax=docker/dockerfile:1.5-labs + +ARG XX_VERSION=1.2.1 +ARG DEBIAN_VERSION=bullseye +ARG LLVM_VERSION=16 + +FROM --platform=$BUILDPLATFORM tonistiigi/xx:${XX_VERSION} AS xx +FROM --platform=$BUILDPLATFORM debian:${DEBIAN_VERSION} AS base +COPY --from=xx / / + +# Install host dependencies +RUN apt-get update -y && apt-get install --no-install-recommends -y \ + lsb-release software-properties-common curl wget gnupg \ + cmake ninja-build git clang xz-utils + +# Set up llvm's apt repo +ARG LLVM_VERSION +RUN /bin/bash < /etc/apt/trusted.gpg.d/apt.llvm.org.asc + add-apt-repository "deb http://apt.llvm.org/$(lsb_release -sc)/ llvm-toolchain-$(lsb_release -sc)-${LLVM_VERSION} main" + apt-get update -y +EOT + +FROM base as deps +ARG TARGETPLATFORM LLVM_VERSION + +# Install llvm-*-dev: +RUN /bin/bash < package depends on llvm-*: and llvm-*-tools: but + # should depend on llvm-*-tools: and llvm-*:. + # Patch llvm-*-dev: to replace those dependencies. + # See: https://groups.google.com/g/linux.debian.bugs.dist/c/OrHgd5vY278 + cd $(mktemp -d) + xx-apt show llvm-${LLVM_VERSION}-dev + apt-get download llvm-${LLVM_VERSION}-dev:$(xx-info debian-arch) + ar x llvm-${LLVM_VERSION}-dev_*.deb + tar xJf control.tar.* + sed -Ei 's|(llvm-[0-9]*(-tools)?) |\1:'$(TARGETPLATFORM='' TARGETPAIR='' TARGETOS='' TARGETARCH='' TARGETVARIANT='' xx-info debian-arch)' |g' control + tar --ignore-failed-read -czf control.tar.gz {post,pre}{inst,rm} md5sums control + ar rcs llvm-dev-patched.deb debian-binary control.tar.gz data.tar.* + apt-get install --no-install-recommends -y ./llvm-dev-patched.deb +EOT + +# Install other *: dev dependencies +RUN xx-apt-get install --no-install-recommends -y \ + xx-cxx-essentials \ + liblld-${LLVM_VERSION}-dev libpolly-${LLVM_VERSION}-dev \ + libncurses5-dev zlib1g-dev + +# Make a cmake toolchain file +RUN cat < /toolchain.cmake + set(CMAKE_SYSTEM_NAME "Linux") + set(CMAKE_SYSTEM_VERSION 1) + set(CMAKE_SYSTEM_PROCESSOR "$(xx-info march)") + set(CMAKE_C_COMPILER "xx-clang") + set(CMAKE_CXX_COMPILER "xx-clang++") + set(CMAKE_ASM_COMPILER "xx-clang") + set(CMAKE_AR "ar") + set(DPKG_CONFIG_EXECUTABLE $(xx-clang --print-prog-name=pkg-config)) + set(DCMAKE_C_COMPILER_TARGET $(xx-clang --print-target-triple)) + set(DCMAKE_CXX_COMPILER_TARGET $(xx-clang --print-target-triple)) + set(DCMAKE_ASM_COMPILER_TARGET $(xx-clang --print-target-triple)) +EOT + +FROM deps as src +ADD . /src + +FROM src as build +RUN cmake -S /src -B /build -G Ninja \ + -DCMAKE_TOOLCHAIN_FILE=/toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_INSTALL_PREFIX=/install \ + -DWASMEDGE_BUILD_PACKAGE="TGZ" \ + -DWASMEDGE_BUILD_TESTS=OFF \ + -DWASMEDGE_BUILD_AOT_RUNTIME=ON \ + -DWASMEDGE_BUILD_SHARED_LIB=OFF \ + -DWASMEDGE_BUILD_STATIC_LIB=ON \ + -DWASMEDGE_BUILD_TOOLS=OFF \ + -DWASMEDGE_BUILD_PLUGINS=OFF \ + -DWASMEDGE_BUILD_EXAMPLE=OFF \ + -DWASMEDGE_LINK_LLVM_STATIC=ON \ + -DWASMEDGE_LINK_TOOLS_STATIC=ON + +RUN cmake --build /build -- install +RUN cmake --build /build -- package +RUN mv /build/WasmEdge-*-Linux.tar.gz $(echo build/WasmEdge-*-Linux.tar.gz | sed 's|WasmEdge-\(.*\)-Linux.tar.gz|WasmEdge-\1-debian'$(lsb_release -sr)'_'$(xx-info march)'_static.tar.gz|') + +FROM scratch as tar +COPY --from=build /build/WasmEdge-*.tar.gz / diff --git a/utils/docker/docker-bake.debian-static.hcl b/utils/docker/docker-bake.debian-static.hcl new file mode 100644 index 00000000..5558e95c --- /dev/null +++ b/utils/docker/docker-bake.debian-static.hcl @@ -0,0 +1,17 @@ +group "default" { + targets = ["cross"] +} + +target "base" { + dockerfile = "./utils/docker/Dockerfile.debian-static" + context = "." + output = ["build"] +} + +target "cross" { + inherits = ["base"] + platforms = [ + "linux/amd64", + "linux/arm64" + ] +} From 77bc162f60941e5f8a20ed6b390d4beb47620ad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=AEm=20Ts=C3=BA-thu=C3=A0n?= Date: Tue, 25 Jul 2023 12:32:36 +0800 Subject: [PATCH 123/623] [Plugin] Implement OpenCV-mini (#2648) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * first draft to setup opencvmini plugin Signed-off-by: Lîm Tsú-thuàn * fix porting name Signed-off-by: Lîm Tsú-thuàn * missing include Signed-off-by: Lîm Tsú-thuàn * update to 2023 Signed-off-by: Lîm Tsú-thuàn * link opencv into plugin Signed-off-by: Lîm Tsú-thuàn * implements three functions 1. imdecode 2. imshow 3. waitkey The `cv::Mat` will be cached in a map & wasm module will only get handle instead of passing huge multi-dim array (`cv::Mat`) around the host & wasm. Signed-off-by: Lîm Tsú-thuàn * opencvmini has no need add options Signed-off-by: Lîm Tsú-thuàn * fix to be workable Signed-off-by: Lîm Tsú-thuàn * add opencvmini option Signed-off-by: Lîm Tsú-thuàn * formatting Signed-off-by: Lîm Tsú-thuàn * fixing linter problem Signed-off-by: Lîm Tsú-thuàn * add `blur` function & `imwrite` `imwrite` will put matrix into a file, and hence. we can check the transformer result `blur` do simple blurring, this is the first transformer we introduce Signed-off-by: Lîm Tsú-thuàn * implements `imencode` The function will write compressed image back into instance's buffer Signed-off-by: Lîm Tsú-thuàn * normalize function NOTE, this is a normalize function from https://github.com/WasmEdge/WasmEdge/commit/f3c1b911a674fefcff02785e6bf7812ca98d9be0#diff-3333d926ca87cf4285bfcd6deae45ee310307be66fca8a4ca6f0f8a946743fccR50-R54 not `cv::normalize` Signed-off-by: Lîm Tsú-thuàn * add function bilinear sampling Signed-off-by: Lîm Tsú-thuàn * release adding opencvmini Signed-off-by: Lîm Tsú-thuàn * build extensions Signed-off-by: Lîm Tsú-thuàn * add test Signed-off-by: Lîm Tsú-thuàn * cmake track testing Signed-off-by: Lîm Tsú-thuàn * metas Signed-off-by: Lîm Tsú-thuàn * install opencv Signed-off-by: Lîm Tsú-thuàn * fix release name Signed-off-by: Lîm Tsú-thuàn * fix dependencies Signed-off-by: Lîm Tsú-thuàn * fix typo in CI configuration Signed-off-by: Lîm Tsú-thuàn * fix name Signed-off-by: Lîm Tsú-thuàn * fix export functions number test Signed-off-by: Lîm Tsú-thuàn * try newer g++ Signed-off-by: Lîm Tsú-thuàn * install libopencv-dev on ubuntu Signed-off-by: Lîm Tsú-thuàn * yum can install opencv Signed-off-by: Lîm Tsú-thuàn * build and limit build thread Signed-off-by: Lîm Tsú-thuàn * retry with no sudo Signed-off-by: Lîm Tsú-thuàn * independent Signed-off-by: Lîm Tsú-thuàn * ninja Signed-off-by: Lîm Tsú-thuàn * ubuntu also use build Signed-off-by: Lîm Tsú-thuàn * extra module Signed-off-by: Lîm Tsú-thuàn * add platform exclusive, and fix naming Signed-off-by: Lîm Tsú-thuàn * remove contrib Signed-off-by: Lîm Tsú-thuàn * fix binary name Signed-off-by: Lîm Tsú-thuàn * macos missing opencv installation Signed-off-by: Lîm Tsú-thuàn * install certain version Signed-off-by: Lîm Tsú-thuàn * fix install script Signed-off-by: Lîm Tsú-thuàn * fix upload name Signed-off-by: Lîm Tsú-thuàn * fix wrong format Signed-off-by: Lîm Tsú-thuàn * outdated trick Signed-off-by: Lîm Tsú-thuàn * test if we remove devtoolset-8 Signed-off-by: Lîm Tsú-thuàn * there has no http plugin now Signed-off-by: Lîm Tsú-thuàn * remove do nothing configuration Signed-off-by: Lîm Tsú-thuàn * blur export kernel parameters Signed-off-by: Lîm Tsú-thuàn * let user can assign extension Signed-off-by: Lîm Tsú-thuàn --------- Signed-off-by: Lîm Tsú-thuàn --- plugins/CMakeLists.txt | 9 ++ plugins/wasmedge_opencvmini/CMakeLists.txt | 39 +++++ plugins/wasmedge_opencvmini/opencvmini_base.h | 25 +++ .../wasmedge_opencvmini/opencvmini_env.cpp | 41 +++++ plugins/wasmedge_opencvmini/opencvmini_env.h | 40 +++++ .../wasmedge_opencvmini/opencvmini_func.cpp | 151 ++++++++++++++++++ plugins/wasmedge_opencvmini/opencvmini_func.h | 95 +++++++++++ .../wasmedge_opencvmini/opencvmini_module.cpp | 36 +++++ .../wasmedge_opencvmini/opencvmini_module.h | 24 +++ test/plugins/CMakeLists.txt | 6 + .../wasmedge_opencvmini/CMakeLists.txt | 35 ++++ .../wasmedge_opencvmini.cpp | 47 ++++++ utils/opencvmini/install-opencvmini.sh | 16 ++ 13 files changed, 564 insertions(+) create mode 100644 plugins/wasmedge_opencvmini/CMakeLists.txt create mode 100644 plugins/wasmedge_opencvmini/opencvmini_base.h create mode 100644 plugins/wasmedge_opencvmini/opencvmini_env.cpp create mode 100644 plugins/wasmedge_opencvmini/opencvmini_env.h create mode 100644 plugins/wasmedge_opencvmini/opencvmini_func.cpp create mode 100644 plugins/wasmedge_opencvmini/opencvmini_func.h create mode 100644 plugins/wasmedge_opencvmini/opencvmini_module.cpp create mode 100644 plugins/wasmedge_opencvmini/opencvmini_module.h create mode 100644 test/plugins/wasmedge_opencvmini/CMakeLists.txt create mode 100644 test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp create mode 100644 utils/opencvmini/install-opencvmini.sh diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 80e50955..b9e52060 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -54,6 +54,15 @@ if(WASMEDGE_PLUGIN_WASM_BPF) endif() endif() +if(WASMEDGE_PLUGIN_OPENCVMINI) + # Only Linux and MacOS support wasmedge_opencvmini now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_opencvmini) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_OpenCVMini plug-in now.") + endif() +endif() + if(WASMEDGE_PLUGIN_WASI_LOGGING) add_subdirectory(wasi_logging) endif() diff --git a/plugins/wasmedge_opencvmini/CMakeLists.txt b/plugins/wasmedge_opencvmini/CMakeLists.txt new file mode 100644 index 00000000..6949132d --- /dev/null +++ b/plugins/wasmedge_opencvmini/CMakeLists.txt @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2023 Second State INC + +find_package(OpenCV 4 REQUIRED) + +wasmedge_add_library(wasmedgePluginWasmEdgeOpenCVMini + SHARED + opencvmini_env.cpp + opencvmini_func.cpp + opencvmini_module.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeOpenCVMini + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeOpenCVMini + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} + ${OpenCV_INClUDE_DIRS} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeOpenCVMini + PRIVATE + wasmedgeCAPI + ${OpenCV_LIBS} + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeOpenCVMini + PRIVATE + wasmedge_shared + ${OpenCV_LIBS} + ) +endif() + +install(TARGETS wasmedgePluginWasmEdgeOpenCVMini DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasmedge_opencvmini/opencvmini_base.h b/plugins/wasmedge_opencvmini/opencvmini_base.h new file mode 100644 index 00000000..1a4dba74 --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_base.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2023 Second State INC + +#pragma once + +#include "opencvmini_env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template +class WasmEdgeOpenCVMini : public Runtime::HostFunction { +public: + WasmEdgeOpenCVMini(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WasmEdgeOpenCVMiniEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_env.cpp b/plugins/wasmedge_opencvmini/opencvmini_env.cpp new file mode 100644 index 00000000..c205d7f4 --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_env.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2023 Second State INC + +#include "opencvmini_env.h" +#include "opencvmini_module.h" + +namespace WasmEdge { +namespace Host { + +WasmEdgeOpenCVMiniEnvironment::WasmEdgeOpenCVMiniEnvironment() noexcept {} + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeOpenCVMiniModule(); +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_opencvmini", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 1, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_opencvmini", + .Description = "", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +} // namespace + +Plugin::PluginRegister WasmEdgeOpenCVMiniEnvironment::Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_env.h b/plugins/wasmedge_opencvmini/opencvmini_env.h new file mode 100644 index 00000000..9ab004ea --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_env.h @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2023 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include +#include +#include + +namespace WasmEdge { +namespace Host { + +class WasmEdgeOpenCVMiniEnvironment { +public: + WasmEdgeOpenCVMiniEnvironment() noexcept; + + static Plugin::PluginRegister Register; + + std::map MatPool; + + Expect getMat(uint32_t MatKey) { + if (auto V = this->MatPool.find(MatKey); V != this->MatPool.end()) { + return V->second; + } else { + return Unexpect(ErrCode::Value::HostFuncError); + } + } + + Expect insertMat(const cv::Mat &Img) { + // cv::Mat::flags contains magic signature & I believe it's a good enough + // key for this purpose. + this->MatPool[static_cast(Img.flags)] = Img; + return static_cast(Img.flags); + } +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_func.cpp b/plugins/wasmedge_opencvmini/opencvmini_func.cpp new file mode 100644 index 00000000..ab5da5d9 --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_func.cpp @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2023 Second State INC + +#include "opencvmini_func.h" +#include "common/defines.h" + +#include +#include +#include + +namespace WasmEdge { +namespace Host { + +Expect +WasmEdgeOpenCVMiniImdecode::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr, uint32_t BufLen) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + char *Buf = MemInst->getPointer(BufPtr); + + std::vector Content(Buf, Buf + BufLen); + cv::Mat Img = cv::imdecode(cv::InputArray(Content), cv::IMREAD_COLOR); + + return Env.insertMat(Img); +} + +Expect WasmEdgeOpenCVMiniImshow::body(const Runtime::CallingFrame &Frame, + uint32_t WindowNamePtr, + uint32_t WindowNameLen, + uint32_t MatKey) { + std::string WindowName; + + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + char *Buf = MemInst->getPointer(WindowNamePtr); + std::copy_n(Buf, WindowNameLen, std::back_inserter(WindowName)); + + if (auto Img = Env.getMat(MatKey); Img) { + cv::imshow(WindowName.c_str(), *Img); + } + + return {}; +} + +Expect WasmEdgeOpenCVMiniWaitKey::body(const Runtime::CallingFrame &, + uint32_t Delay) { + cv::waitKey(static_cast(Delay)); + return {}; +} + +Expect WasmEdgeOpenCVMiniBlur::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, + uint32_t KernelWidth, + uint32_t KernelHeight) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::blur(*Src, Dst, cv::Size(KernelWidth, KernelHeight)); + } + return Env.insertMat(Dst); +} + +Expect WasmEdgeOpenCVMiniImwrite::body(const Runtime::CallingFrame &Frame, + uint32_t TargetFileNamePtr, + uint32_t TargetFileNameLen, + uint32_t MatKey) { + std::string TargetFileName; + + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + char *Buf = MemInst->getPointer(TargetFileNamePtr); + std::copy_n(Buf, TargetFileNameLen, std::back_inserter(TargetFileName)); + + if (auto Img = Env.getMat(MatKey); Img) { + cv::imwrite(TargetFileName.c_str(), *Img); + } + + return {}; +} + +Expect WasmEdgeOpenCVMiniImencode::body( + const Runtime::CallingFrame &Frame, uint32_t ExtPtr, uint32_t ExtLen, + uint32_t MatKey, uint32_t BufPtr, uint32_t BufLen) { + std::string Ext; + + auto *MemInst = Frame.getMemoryByIndex(0); + + char *Buf = MemInst->getPointer(ExtPtr); + std::copy_n(Buf, ExtLen, std::back_inserter(Ext)); + + auto Img = Env.getMat(MatKey); + if (!Img) { + spdlog::error("[WasmEdge-OpenCVMini] "sv + "Failed to get matrix by key."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto OutSpan = MemInst->getSpan(BufPtr, BufLen); + if (unlikely(OutSpan.size() != BufLen)) { + spdlog::error("[WasmEdge-OpenCVMini] "sv + "Failed when accessing the image target buffer memory."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + std::vector WriteTo; + cv::imencode(Ext, *Img, WriteTo); + + std::copy_n(WriteTo.begin(), WriteTo.size(), OutSpan.begin()); + + return {}; +} + +Expect +WasmEdgeOpenCVMiniNormalize::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey) { + auto Src = Env.getMat(SrcMatKey); + if (!Src) { + return Unexpect(ErrCode::Value::HostFuncError); + } + cv::Mat Dst; + // convert each elements `v` of `Src` to `(1/255) * v + 0` + Src->convertTo(Dst, CV_32F, 1. / 255., 0.); + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniBilinearSampling::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, uint32_t OutImgW, + uint32_t OutImgH) { + auto Src = Env.getMat(SrcMatKey); + if (!Src) { + return Unexpect(ErrCode::Value::HostFuncError); + } + cv::Mat Dst; + cv::resize(*Src, Dst, cv::Size(OutImgW, OutImgH), 0, 0, cv::INTER_LINEAR); + return Env.insertMat(Dst); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_func.h b/plugins/wasmedge_opencvmini/opencvmini_func.h new file mode 100644 index 00000000..3743f416 --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_func.h @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2023 Second State INC + +#pragma once + +#include "opencvmini_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { + +/// Read image from buffer +class WasmEdgeOpenCVMiniImdecode + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniImdecode(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr, + uint32_t BufLen); +}; + +/// Write image into buffer +class WasmEdgeOpenCVMiniImencode + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniImencode(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t ExtPtr, + uint32_t ExtLen, uint32_t MatKey, uint32_t BufPtr, + uint32_t BufLen); +}; + +class WasmEdgeOpenCVMiniImshow + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniImshow(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t WindowNamePtr, + uint32_t WindowNameLen, uint32_t MatKey); +}; + +class WasmEdgeOpenCVMiniWaitKey + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniWaitKey(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t Delay); +}; + +class WasmEdgeOpenCVMiniBlur + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniBlur(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t SrcMatKey, + uint32_t KernelWidth, uint32_t KernelHeight); +}; + +class WasmEdgeOpenCVMiniImwrite + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniImwrite(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, + uint32_t TargetFileNamePtr, uint32_t TargetFileNameLen, + uint32_t SrcMatKey); +}; + +/// This is not `cv::normalize`, refers to: +/// https://github.com/WasmEdge/WasmEdge/commit/77051da4995d7318d91a82102a72ce2557151764#diff-3333d926ca87cf4285bfcd6deae45ee310307be66fca8a4ca6f0f8a946743fccR50-R54 +class WasmEdgeOpenCVMiniNormalize + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniNormalize(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t SrcMatKey); +}; + +class WasmEdgeOpenCVMiniBilinearSampling + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniBilinearSampling(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t SrcMatKey, + uint32_t OutImgW, uint32_t OutImgH); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_module.cpp b/plugins/wasmedge_opencvmini/opencvmini_module.cpp new file mode 100644 index 00000000..ab289d36 --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_module.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2023 Second State INC + +#include "opencvmini_module.h" +#include "opencvmini_func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasmEdgeOpenCVMiniModule::WasmEdgeOpenCVMiniModule() + : ModuleInstance("wasmedge_opencvmini") { + addHostFunc("wasmedge_opencvmini_imdecode", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_imencode", + std::make_unique(Env)); + + addHostFunc("wasmedge_opencvmini_imwrite", + std::make_unique(Env)); + + addHostFunc("wasmedge_opencvmini_blur", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_normalize", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_bilinear_sampling", + std::make_unique(Env)); + + addHostFunc("wasmedge_opencvmini_imshow", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_waitkey", + std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_module.h b/plugins/wasmedge_opencvmini/opencvmini_module.h new file mode 100644 index 00000000..0175110d --- /dev/null +++ b/plugins/wasmedge_opencvmini/opencvmini_module.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2023 Second State INC + +#pragma once + +#include "opencvmini_env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeOpenCVMiniModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeOpenCVMiniModule(); + + WasmEdgeOpenCVMiniEnvironment &getEnv() { return Env; } + +private: + WasmEdgeOpenCVMiniEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index d8dfb3c5..ac61f40a 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -33,6 +33,12 @@ if(WASMEDGE_PLUGIN_IMAGE) endif() endif() +if(WASMEDGE_PLUGIN_OPENCVMINI) + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_opencvmini) + endif() +endif() + if(WASMEDGE_PLUGIN_WASM_BPF) if(CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasm_bpf) diff --git a/test/plugins/wasmedge_opencvmini/CMakeLists.txt b/test/plugins/wasmedge_opencvmini/CMakeLists.txt new file mode 100644 index 00000000..33fc1425 --- /dev/null +++ b/test/plugins/wasmedge_opencvmini/CMakeLists.txt @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2023 Second State INC + +wasmedge_add_executable(wasmedgeOpencvminiTests + wasmedge_opencvmini.cpp +) + +add_dependencies(wasmedgeOpencvminiTests + wasmedgePluginWasmEdgeOpenCVMini +) + +target_include_directories(wasmedgeOpencvminiTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeOpencvminiTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeOpencvminiTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeOpencvminiTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeOpencvminiTests wasmedgeOpencvminiTests) diff --git a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp new file mode 100644 index 00000000..7e28e9cd --- /dev/null +++ b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/defines.h" +#include "opencvmini_func.h" +#include "opencvmini_module.h" +#include "runtime/instance/module.h" + +#include +#include +#include +#include +#include +#include + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_opencvmini/" + "libwasmedgePluginWasmEdgeOpenCVMini" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_opencvmini"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_opencvmini"sv)) { + return Module->create().release(); + } + } + return nullptr; +} +} // namespace + +// TODO: unit tests for every functions. + +TEST(WasmEdgeOpecvminiTest, Module) { + // Create the wasmedge_opencvmini module instance. + auto *ImgMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ImgMod == nullptr); + EXPECT_EQ(ImgMod->getFuncExportNum(), 8U); + EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_imdecode"), nullptr); + delete ImgMod; +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/utils/opencvmini/install-opencvmini.sh b/utils/opencvmini/install-opencvmini.sh new file mode 100644 index 00000000..ac5128ef --- /dev/null +++ b/utils/opencvmini/install-opencvmini.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2023 Second State INC + +wget -O opencv.zip https://github.com/opencv/opencv/archive/refs/tags/4.8.0.zip + +unzip opencv.zip +mv opencv-4.8.0 opencv + +mkdir -p opencv/build && cd opencv/build +# Configure +cmake -GNinja .. +# Build +cmake --build . +# Install to system +cmake --install . From d085ab6d94e2d899ed4254a62cb3344f4e188afb Mon Sep 17 00:00:00 2001 From: Jorge Prendes Date: Mon, 31 Jul 2023 08:10:56 +0100 Subject: [PATCH 124/623] [CI] build alpine static libraries (#2699) Signed-off-by: Jorge Prendes --- utils/docker/Dockerfile.alpine-static | 96 ++++++++++++++++++++++ utils/docker/docker-bake.alpine-static.hcl | 17 ++++ 2 files changed, 113 insertions(+) create mode 100644 utils/docker/Dockerfile.alpine-static create mode 100644 utils/docker/docker-bake.alpine-static.hcl diff --git a/utils/docker/Dockerfile.alpine-static b/utils/docker/Dockerfile.alpine-static new file mode 100644 index 00000000..d92dd8da --- /dev/null +++ b/utils/docker/Dockerfile.alpine-static @@ -0,0 +1,96 @@ +# syntax=docker/dockerfile:1.5-labs + +ARG XX_VERSION=1.2.1 +ARG ALPINE_VERSION=3.16 +# alpine 3.16 ships with llvm 13. +# alpine 3.17 and 3.18 do not ship lld-static. + +FROM --platform=$BUILDPLATFORM tonistiigi/xx:${XX_VERSION} AS xx +FROM --platform=$BUILDPLATFORM alpine:${ALPINE_VERSION} AS base +COPY --from=xx / / + +# Install host dependencies +RUN apk add bash cmake samurai g++ clang +SHELL [ "bash", "-c" ] + +# Make a cmake toolchain file +RUN cat <<'EOT' > /usr/bin/xx-toolchain && chmod a+x /usr/bin/xx-toolchain +#!/bin/bash +mkdir -p /etc/xx-toolchains/ +TOOLCHAIN_FILE="/etc/xx-toolchains/$(xx-clang --print-target-triple).cmake" +[ -f "$TOOLCHAIN_FILE" ] || cat < "$TOOLCHAIN_FILE" +set(CMAKE_CROSSCOMPILING ON) +set(CMAKE_SYSROOT "$(xx-info sysroot)") +set(CMAKE_SYSTEM_NAME "Linux") +set(CMAKE_SYSTEM_VERSION 1) +set(CMAKE_SYSTEM_PROCESSOR "$(xx-info march)") +set(CMAKE_C_COMPILER "$(which clang)") +set(CMAKE_CXX_COMPILER "$(which clang++)") +set(CMAKE_ASM_COMPILER "$(which clang)") +set(CMAKE_AR "$(which ar)") +set(CMAKE_RANLIB "$(which ranlib)") +set(PKG_CONFIG_EXECUTABLE "$(xx-clang --print-prog-name=pkg-config)") +set(CMAKE_C_COMPILER_TARGET "$(xx-clang --print-target-triple)") +set(CMAKE_CXX_COMPILER_TARGET "$(xx-clang --print-target-triple)") +set(CMAKE_ASM_COMPILER_TARGET "$(xx-clang --print-target-triple)") +EOF +echo "$TOOLCHAIN_FILE" +EOT + +FROM base AS config +ARG TARGETPLATFORM + +RUN apk add git llvm-dev +RUN xx-apk add \ + g++ \ + llvm-dev llvm-static \ + lld lld-dev lld-static \ + zlib-dev zlib-static + +# Fix LLVMConfig.cmake to account for the sysroot +RUN sed -i 's|/usr/lib/llvm|${CMAKE_SYSROOT}usr/lib/llvm|' $(xx-info sysroot)usr/lib/cmake/llvm*/LLVMConfig.cmake +# In cmake/Helper.cmake we assume that lld is installed alongside llvm, so copy files over +RUN cp $(xx-info sysroot)usr/lib/liblld*.a $(xx-info sysroot)usr/lib/llvm*/lib/ +# Use the native version of llvm-config +RUN ! xx-info is-cross || cp -f /usr/lib/llvm*/bin/llvm-config $(xx-info sysroot)usr/lib/llvm*/bin/llvm-config + +RUN --mount=type=bind,target=/src,source=. \ + cmake -S /src -B /build -G Ninja \ + -DCMAKE_BUILD_TYPE=MinSizeRel \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_INSTALL_PREFIX="/install" \ + # For cross compiling + -DCMAKE_TOOLCHAIN_FILE="$(xx-toolchain)" \ + -DWASMEDGE_BUILD_PACKAGE="TGZ" \ + -DWASMEDGE_BUILD_AOT_RUNTIME=ON \ + # Build just what we need + -DWASMEDGE_BUILD_STATIC_LIB=ON \ + -DWASMEDGE_BUILD_TESTS=OFF \ + -DWASMEDGE_BUILD_SHARED_LIB=OFF \ + -DWASMEDGE_BUILD_TOOLS=OFF \ + -DWASMEDGE_BUILD_PLUGINS=OFF \ + -DWASMEDGE_BUILD_EXAMPLE=OFF \ + # Link llvm statically + -DWASMEDGE_LINK_LLVM_STATIC=ON \ + -DWASMEDGE_LINK_TOOLS_STATIC=ON \ + # Disable extra dependencies + -DWASMEDGE_DISABLE_LIBTINFO=ON + +FROM config AS build +RUN --mount=type=bind,target=/src,source=. \ + cmake --build /build -- install package + +ARG ALPINE_VERSION +RUN --mount=type=bind,target=/src,source=. < Date: Wed, 2 Aug 2023 22:14:59 +0530 Subject: [PATCH 125/623] [WASI-NN] Added support for Tuple Type Output Tensors in Pytorch Backend (#2564) Signed-off-by: Sarrah Bastawala --- plugins/wasi_nn/torch.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/plugins/wasi_nn/torch.cpp b/plugins/wasi_nn/torch.cpp index 948ace8d..ba1a26f9 100644 --- a/plugins/wasi_nn/torch.cpp +++ b/plugins/wasi_nn/torch.cpp @@ -137,6 +137,11 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { for (auto &OneOf : OutTensors) { CxtRef.TorchOutputs.push_back(OneOf.clone()); } + } else if (RawOutput.isTuple()) { + auto OutTensorsTuple = RawOutput.toTuple()->elements(); + for (auto &OneOf : OutTensorsTuple) { + CxtRef.TorchOutputs.push_back(OneOf.toTensor().clone()); + } } else if (RawOutput.isTensor()) { auto OutTensor = RawOutput.toTensor(); CxtRef.TorchOutputs.push_back(OutTensor.clone()); From 2796d5048a1815914596867f129e10898460dfe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=AEm=20Ts=C3=BA-thu=C3=A0n?= Date: Tue, 8 Aug 2023 01:08:52 +0800 Subject: [PATCH 126/623] [Plugin] opecvmini `rectangle` and `cvtColor` (#2705) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Plugin] opecvmini `rectangle` This function can draw rectangle on picture buffer. Signed-off-by: Lîm Tsú-thuàn * rectangle will modify the given matrix, no need to return Signed-off-by: Lîm Tsú-thuàn * update tests Signed-off-by: Lîm Tsú-thuàn * add `cvtColor` function to change color space Signed-off-by: Lîm Tsú-thuàn * provide all options for `rectangle` Signed-off-by: Lîm Tsú-thuàn * allows all parameter, fix review comment Signed-off-by: Lîm Tsú-thuàn --------- Signed-off-by: Lîm Tsú-thuàn --- .../wasmedge_opencvmini/opencvmini_func.cpp | 33 +++++++++++++++++++ plugins/wasmedge_opencvmini/opencvmini_func.h | 22 +++++++++++++ .../wasmedge_opencvmini/opencvmini_module.cpp | 5 +++ .../wasmedge_opencvmini.cpp | 5 ++- 4 files changed, 64 insertions(+), 1 deletion(-) diff --git a/plugins/wasmedge_opencvmini/opencvmini_func.cpp b/plugins/wasmedge_opencvmini/opencvmini_func.cpp index ab5da5d9..7a926946 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_func.cpp +++ b/plugins/wasmedge_opencvmini/opencvmini_func.cpp @@ -142,10 +142,43 @@ WasmEdgeOpenCVMiniBilinearSampling::body(const Runtime::CallingFrame &, if (!Src) { return Unexpect(ErrCode::Value::HostFuncError); } + cv::Mat Dst; cv::resize(*Src, Dst, cv::Size(OutImgW, OutImgH), 0, 0, cv::INTER_LINEAR); return Env.insertMat(Dst); } +Expect WasmEdgeOpenCVMiniRectangle::body( + const Runtime::CallingFrame &, uint32_t SrcMatKey, uint32_t Top, + uint32_t Left, uint32_t Bot, uint32_t Right, double R, double G, double B, + int32_t Thickness, int32_t LineType, int32_t Shift) { + auto Src = Env.getMat(SrcMatKey); + if (!Src) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + cv::Point TopLeft(Top, Left); + cv::Point BottomRight(Bot, Right); + + cv::rectangle(*Src, TopLeft, BottomRight, cv::Scalar(B, G, R), Thickness, + LineType, Shift); + return {}; +} + +Expect WasmEdgeOpenCVMiniCvtColor::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, + int32_t Code, + int32_t DestChannelN) { + auto Src = Env.getMat(SrcMatKey); + if (!Src) { + return Unexpect(ErrCode::Value::HostFuncError); + } + auto Img = *Src; + + cv::Mat Dst; + cvtColor(Img, Dst, Code, DestChannelN); + return Env.insertMat(Dst); +} + } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_func.h b/plugins/wasmedge_opencvmini/opencvmini_func.h index 3743f416..f86f07c9 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_func.h +++ b/plugins/wasmedge_opencvmini/opencvmini_func.h @@ -91,5 +91,27 @@ class WasmEdgeOpenCVMiniBilinearSampling uint32_t OutImgW, uint32_t OutImgH); }; +class WasmEdgeOpenCVMiniRectangle + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniRectangle(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t Top, uint32_t Left, uint32_t Bot, uint32_t Right, + double R, double G, double B, int32_t Thickness, + int32_t LineType, int32_t Shift); +}; + +class WasmEdgeOpenCVMiniCvtColor + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniCvtColor(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + int32_t Code, int32_t DestChannelN); +}; + } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_module.cpp b/plugins/wasmedge_opencvmini/opencvmini_module.cpp index ab289d36..c515889f 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_module.cpp +++ b/plugins/wasmedge_opencvmini/opencvmini_module.cpp @@ -25,6 +25,11 @@ WasmEdgeOpenCVMiniModule::WasmEdgeOpenCVMiniModule() std::make_unique(Env)); addHostFunc("wasmedge_opencvmini_bilinear_sampling", std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_cvt_color", + std::make_unique(Env)); + + addHostFunc("wasmedge_opencvmini_rectangle", + std::make_unique(Env)); addHostFunc("wasmedge_opencvmini_imshow", std::make_unique(Env)); diff --git a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp index 7e28e9cd..0440a680 100644 --- a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp +++ b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp @@ -36,8 +36,11 @@ TEST(WasmEdgeOpecvminiTest, Module) { auto *ImgMod = dynamic_cast(createModule()); EXPECT_FALSE(ImgMod == nullptr); - EXPECT_EQ(ImgMod->getFuncExportNum(), 8U); + EXPECT_EQ(ImgMod->getFuncExportNum(), 10U); EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_imdecode"), nullptr); + EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_imencode"), nullptr); + EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_rectangle"), nullptr); + EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_cvt_color"), nullptr); delete ImgMod; } From 443426595e380bf3912788640e8502f258ed83f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=AEm=20Ts=C3=BA-thu=C3=A0n?= Date: Thu, 10 Aug 2023 05:58:29 +0800 Subject: [PATCH 127/623] [CI] pre-build image with necessary dependency for extension (#2731) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [CI] try to pre-build image with necessary dependency for extension related with #2729 Signed-off-by: Lîm Tsú-thuàn * revert ci changes Signed-off-by: Lîm Tsú-thuàn * rename dockerfile and tag Signed-off-by: Lîm Tsú-thuàn * revert ci changes Signed-off-by: Lîm Tsú-thuàn * share same installation script Signed-off-by: Lîm Tsú-thuàn * share the same dockerfile for plugin-deps Since the base image can be assigned by --build-arg Signed-off-by: Lîm Tsú-thuàn * Assign opencv version in dockerfile via environment variable Signed-off-by: Lîm Tsú-thuàn * Fix, must use version variable Signed-off-by: Lîm Tsú-thuàn * build image plugins-deps in actions Signed-off-by: Lîm Tsú-thuàn --------- Signed-off-by: Lîm Tsú-thuàn --- utils/docker/Dockerfile.build-plugins-deps | 12 ++++++++++++ utils/docker/build.sh | 14 +++++++++----- utils/docker/install-opencvmini.sh | 16 ++++++++++++++++ 3 files changed, 37 insertions(+), 5 deletions(-) create mode 100644 utils/docker/Dockerfile.build-plugins-deps create mode 100644 utils/docker/install-opencvmini.sh diff --git a/utils/docker/Dockerfile.build-plugins-deps b/utils/docker/Dockerfile.build-plugins-deps new file mode 100644 index 00000000..21c73d97 --- /dev/null +++ b/utils/docker/Dockerfile.build-plugins-deps @@ -0,0 +1,12 @@ +ARG BASE=wasmedge/wasmedge:ubuntu-build-clang +FROM ${BASE} + +RUN apt update && apt install -y \ + wget \ + unzip + +RUN rm -rf /var/lib/apt/lists/* + +COPY install-opencvmini.sh . +ENV OPENCV_VERSION=4.8.0 +RUN [ "/bin/bash", "install-opencvmini.sh" ] diff --git a/utils/docker/build.sh b/utils/docker/build.sh index 5c4ed150..3d99a86f 100755 --- a/utils/docker/build.sh +++ b/utils/docker/build.sh @@ -25,14 +25,18 @@ function docker_build } # Build all images. -docker_build Dockerfile.base ubuntu-base -docker_build Dockerfile.ci-image-base ci-image-base -docker_build Dockerfile.build-clang ubuntu-build-clang \ +docker_build Dockerfile.base ubuntu-base +docker_build Dockerfile.ci-image-base ci-image-base +docker_build Dockerfile.build-clang ubuntu-build-clang \ --build-arg "BASE=${NAME}:ubuntu-base" -docker_build Dockerfile.build-clang latest \ +docker_build Dockerfile.build-clang latest \ --build-arg "BASE=${NAME}:ubuntu-base" -docker_build Dockerfile.build-gcc ubuntu-build-gcc \ +docker_build Dockerfile.build-gcc ubuntu-build-gcc \ --build-arg "BASE=${NAME}:ubuntu-base" +docker_build Dockerfile.build-plugins-deps ubuntu-build-clang-plugins-deps \ + --build-arg "BASE=${NAME}:ubuntu-build-clang" +docker_build Dockerfile.build-plugins-deps ubuntu-build-gcc-plugins-deps \ + --build-arg "BASE=${NAME}:ubuntu-build-gcc" # Remove intermediate images. for NAME_TAG in "${INTERMEDIATES[@]}"; do diff --git a/utils/docker/install-opencvmini.sh b/utils/docker/install-opencvmini.sh new file mode 100644 index 00000000..f66ef586 --- /dev/null +++ b/utils/docker/install-opencvmini.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2023 Second State INC + +wget -O opencv.zip https://github.com/opencv/opencv/archive/refs/tags/${OPENCV_VERSION}.zip + +unzip opencv.zip +mv opencv-${OPENCV_VERSION} opencv + +mkdir -p opencv/build && cd opencv/build +# Configure +cmake -GNinja .. +# Build +cmake --build . +# Install to system +cmake --install . From 2a6b2c6ff09f0c54ccbca89626d53ea6d9d8d03b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=AEm=20Ts=C3=BA-thu=C3=A0n?= Date: Tue, 15 Aug 2023 16:16:44 +0800 Subject: [PATCH 128/623] [Docker] manylinux plugins deps pre-built image (#2737) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Lîm Tsú-thuàn --- .../Dockerfile.manylinux2014-build-plugins-deps | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 utils/docker/Dockerfile.manylinux2014-build-plugins-deps diff --git a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps new file mode 100644 index 00000000..45bcd857 --- /dev/null +++ b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps @@ -0,0 +1,15 @@ +ARG BASE=wasmedge/wasmedge:manylinux2014_x86_64 +FROM ${BASE} + +ENV PATH /opt/rh/devtoolset-11/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH /opt/rh/devtoolset-11/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH /opt/rh/devtoolset-11/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV PKG_CONFIG_PATH /opt/rh/devtoolset-11/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} + +RUN cd && (yum check-update || true) && yum install -y cmake wget unzip + +COPY install-opencvmini.sh . +ENV OPENCV_VERSION=4.8.0 +RUN [ "/bin/bash", "install-opencvmini.sh" ] + +RUN yum clean all From 991ac9cf711685b77f40c4939721295b74447650 Mon Sep 17 00:00:00 2001 From: Sarrah Bastawala <84874044+sarrah-basta@users.noreply.github.com> Date: Mon, 21 Aug 2023 21:24:27 +0530 Subject: [PATCH 129/623] [WASI-NN] Correct fallback error message after adding support for tuple-type output tensors in PyTorch for Wasi-NN (#2747) Signed-off-by: Sarrah Bastawala --- plugins/wasi_nn/torch.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/torch.cpp b/plugins/wasi_nn/torch.cpp index ba1a26f9..153ed32f 100644 --- a/plugins/wasi_nn/torch.cpp +++ b/plugins/wasi_nn/torch.cpp @@ -146,8 +146,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto OutTensor = RawOutput.toTensor(); CxtRef.TorchOutputs.push_back(OutTensor.clone()); } else { - spdlog::error("[WASI-NN] PyTorch backend only supports output a tensor " - "or a list of tensor"); + spdlog::error("[WASI-NN] PyTorch backend only supports output a tensor, " + "a list of tensor or a tuple of tensor"); return ErrNo::InvalidArgument; } return ErrNo::Success; From 10a45eddf7975ca6ed1b30f27968402447c33eec Mon Sep 17 00:00:00 2001 From: vincent Date: Tue, 22 Aug 2023 14:05:16 +0800 Subject: [PATCH 130/623] [WASI-NN] Add load_by_name implementation into wasi-nn plugin (#2742) Signed-off-by: vincent --- plugins/wasi_nn/types.h | 18 ++++--- plugins/wasi_nn/wasinnenv.cpp | 80 ++++++++++++++++++++++++++-- plugins/wasi_nn/wasinnenv.h | 45 ++++++++++++++-- plugins/wasi_nn/wasinnfunc.cpp | 91 +++++++++++++++++++++++--------- plugins/wasi_nn/wasinnfunc.h | 14 +++++ plugins/wasi_nn/wasinnmodule.cpp | 1 + 6 files changed, 211 insertions(+), 38 deletions(-) diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index 54073a3f..1dc0f8e6 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -9,12 +9,15 @@ namespace WasmEdge::Host::WASINN { enum class ErrNo : uint32_t { - Success = 0, // No error occurred. - InvalidArgument = 1, // Caller module passed an invalid argument. - InvalidEncoding = 2, // Invalid encoding. - MissingMemory = 3, // Caller module is missing a memory export. - Busy = 4, // Device or resource busy. - RuntimeError = 5, // Runtime Error. + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + InvalidEncoding = 2, // Invalid encoding. + MissingMemory = 3, // Caller module is missing a memory export. + Busy = 4, // Device or resource busy. + RuntimeError = 5, // Runtime Error. + UnsupportedOperation = 6, // Unsupported Operation. + TooLarge = 7, // Too Large. + NotFound = 8, // Not Found. }; enum class TensorType : uint8_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; @@ -26,7 +29,8 @@ enum class Backend : uint8_t { ONNX = 1, Tensorflow = 2, PyTorch = 3, - TensorflowLite = 4 + TensorflowLite = 4, + Autodetect = 5, }; #define FOR_EACH_BACKEND(F) \ diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 2ad93957..73f959d3 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -7,13 +7,87 @@ namespace WasmEdge { namespace Host { -namespace { +namespace WASINN { Runtime::Instance::ModuleInstance * create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { return new WasiNNModule; } +std::map BackendMap = { + {"OpenVINO"sv, Backend::OpenVINO}, + {"ONNX"sv, Backend::ONNX}, + {"Tensorflow"sv, Backend::Tensorflow}, + {"PyTorch"sv, Backend::PyTorch}, + {"TensorflowLite"sv, Backend::TensorflowLite}, + {"Autodetect"sv, Backend::Autodetect}}; + +std::map DeviceMap = { + {"CPU"sv, Device::CPU}, {"GPU"sv, Device::GPU}, {"TPU"sv, Device::TPU}}; + +bool load(const std::filesystem::path &Path, std::vector &Data) { + std::ifstream File(Path, std::ios::binary); + if (!File.is_open()) { + spdlog::error("[WASI-NN] Preload model fail."sv); + return false; + } + File.seekg(0, std::ios::end); + std::streampos FileSize = File.tellg(); + File.seekg(0, std::ios::beg); + Data.resize(FileSize); + File.read(reinterpret_cast(Data.data()), FileSize); + File.close(); + return true; +} + +WasiNNEnvironment::WasiNNEnvironment() noexcept { + // Preload NN Models + for (const auto &M : NNModels.value()) { + std::istringstream ISS(M); + const char Delimiter = ':'; + std::string Name; + std::string Encode; + std::string Target; + std::vector Paths; + std::getline(ISS, Name, Delimiter); + std::getline(ISS, Encode, Delimiter); + std::getline(ISS, Target, Delimiter); + std::string Path; + while (std::getline(ISS, Path, Delimiter)) { + Paths.push_back(Path); + } + std::vector> Models; + Models.reserve(Paths.size()); + auto Backend = BackendMap.find(Encode); + auto Device = DeviceMap.find(Target); + if (Backend != BackendMap.end() && Device != DeviceMap.end()) { + for (const std::string &Path : Paths) { + std::vector Model; + if (load(std::filesystem::u8path(Path), Model)) { + Models.push_back(std::move(Model)); + } + } + RawMdMap.emplace(Name, std::make_tuple(std::move(Models), Backend->second, + Device->second)); + } else { + spdlog::error( + "[WASI-NN] Preload Model's Backend or Device is Not Support."sv); + } + } + NNGraph.reserve(16U); + NNContext.reserve(16U); +} + +PO::List WasiNNEnvironment::NNModels( + PO::Description( + "Allow preload models from wasinn plugin. Each NN model can be specified as --nn-preload `COMMAND`."sv), + PO::MetaVar("COMMANDS"sv)); + +void addOptions(const Plugin::Plugin::PluginDescriptor *, + PO::ArgumentParser &Parser) noexcept { + Parser.add_option("nn-preload"sv, WasiNNEnvironment::NNModels); +} + Plugin::Plugin::PluginDescriptor Descriptor{ .Name = "wasi_nn", .Description = "", @@ -28,10 +102,10 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .Create = create, }, }, - .AddOptions = nullptr, + .AddOptions = addOptions, }; -} // namespace +} // namespace WASINN Plugin::PluginRegister WASINN::WasiNNEnvironment::Register(&Descriptor); diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index bdc792f7..30bceb1e 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -6,6 +6,7 @@ #include "common/log.h" #include "plugin/plugin.h" #include +#include #include #include "onnx.h" @@ -145,13 +146,51 @@ struct WasiNNEnvironment : FOR_EACH_BACKEND(EACH) #undef EACH std::monostate { - WasiNNEnvironment() noexcept { - NNGraph.reserve(16U); - NNContext.reserve(16U); + + using Callback = std::function( + WASINN::WasiNNEnvironment &, Span>, WASINN::Backend, + WASINN::Device, uint32_t &)>; + + WasiNNEnvironment() noexcept; + + bool mdGet(std::string Name, uint32_t &GraphId) noexcept { + std::shared_lock Lock(MdMutex); + if (auto It = MdMap.find(Name); It != MdMap.end()) { + GraphId = static_cast(It->second); + return true; + } + return false; + } + + Expect mdBuild(std::string Name, uint32_t &GraphId, + Callback Load) noexcept { + std::unique_lock Lock(MdMutex); + auto It = RawMdMap.find(Name); + if (It != RawMdMap.end()) { + auto RawMd = std::get<0>(It->second); + std::vector> Builders; + Builders.reserve(RawMd.size()); + for (auto &Builder : RawMd) { + Builders.emplace_back(Builder); + } + auto Result = Load(*this, Builders, std::get<1>(It->second), + std::get<2>(It->second), GraphId); + if (Result.has_value()) { + MdMap[Name] = GraphId; + } + return Result; + } + return WASINN::ErrNo::NotFound; } + mutable std::shared_mutex MdMutex; ///< Protect MdMap + std::unordered_map>, + Backend, Device>> + RawMdMap; + std::unordered_map MdMap; std::vector NNGraph; std::vector NNContext; + static PO::List NNModels; static Plugin::PluginRegister Register; }; diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index c9e38340..09f877bb 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -13,7 +13,22 @@ namespace Host { namespace { inline void reportUnknownBackend(WASINN::Backend B) noexcept { - spdlog::error("[WASI-NN] Unknown backend {}.", static_cast(B)); + spdlog::error("[WASI-NN] Unknown backend {}."sv, static_cast(B)); +} +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Backend Backend, WASINN::Device Device, + uint32_t &GraphId) { + switch (Backend) { +#define EACH(B) \ + case WASINN::Backend::B: \ + return WASINN::B::load(Env, Builders, Device, GraphId); + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + reportUnknownBackend(Backend); + return WASINN::ErrNo::InvalidEncoding; + } } } // namespace @@ -29,7 +44,8 @@ WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, // Check the return value: GraphIdPtr should be valid. uint32_t *GraphId = MemInst->getPointer(GraphIdPtr); if (unlikely(GraphId == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the return GraphID memory."); + spdlog::error( + "[WASI-NN] Failed when accessing the return GraphID memory."sv); return WASINN::ErrNo::InvalidArgument; } // Get and check the device. @@ -40,7 +56,7 @@ WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, case WASINN::Device::TPU: break; default: - spdlog::error("[WASI-NN] Unknown device {};", Target); + spdlog::error("[WASI-NN] Unknown device {};"sv, Target); return WASINN::ErrNo::InvalidArgument; } spdlog::debug("[WASI-NN] Using device: {}", Device); @@ -55,7 +71,7 @@ WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, const auto WasiBuilders = MemInst->getSpan(BuilderPtr, BuilderLen); if (unlikely(WasiBuilders.size() != BuilderLen)) { - spdlog::error("[WASI-NN] Failed when accessing the GraphBuilder memory."); + spdlog::error("[WASI-NN] Failed when accessing the GraphBuilder memory."sv); return WASINN::ErrNo::InvalidArgument; } @@ -65,22 +81,45 @@ WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, const auto &WasiBuilder = WasiBuilders[I]; auto Builder = MemInst->getSpan(WasiBuilder.Ptr, WasiBuilder.Len); if (unlikely(Builder.size() != WasiBuilder.Len)) { - spdlog::error("[WASI-NN] Failed when accessing the Builder[{}] memory.", + spdlog::error("[WASI-NN] Failed when accessing the Builder[{}] memory."sv, I); return WASINN::ErrNo::InvalidArgument; } Builders.emplace_back(Builder); } + auto Backend = static_cast(RawEncoding); + return load(Env, Builders, Backend, Device, *GraphId); +} - switch (const auto Backend = static_cast(RawEncoding)) { -#define EACH(B) \ - case WASINN::Backend::B: \ - return WASINN::B::load(Env, Builders, Device, *GraphId); - FOR_EACH_BACKEND(EACH) -#undef EACH - default: - reportUnknownBackend(Backend); - return WASINN::ErrNo::InvalidEncoding; +Expect +WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, + uint32_t NameLen, uint32_t GraphIdPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + // Check the return value: GraphIdPtr should be valid. + uint32_t *GraphId = MemInst->getPointer(GraphIdPtr); + if (unlikely(GraphId == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the return GraphID memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the name of model + std::vector> Builders; + uint32_t *Name = MemInst->getPointer(NamePtr); + if (unlikely(Name == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the return Name memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the model + std::string ModelName(reinterpret_cast(Name), NameLen); + if (Env.mdGet(ModelName, *GraphId)) { + return WASINN::ErrNo::Success; + } else { + return Env.mdBuild(ModelName, *GraphId, load); } } @@ -93,14 +132,15 @@ WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, } if (Env.NNGraph.size() <= GraphId) { - spdlog::error("[WASI-NN] init_execution_context: Graph Id does not exist."); + spdlog::error( + "[WASI-NN] init_execution_context: Graph Id does not exist."sv); return WASINN::ErrNo::InvalidArgument; } // Check the return value: Context should be valid. uint32_t *Context = MemInst->getPointer(ContextPtr); if (unlikely(Context == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Context memory."); + spdlog::error("[WASI-NN] Failed when accessing the Context memory."sv); return WASINN::ErrNo::InvalidArgument; } @@ -125,7 +165,7 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, } if (Env.NNContext.size() <= Context) { - spdlog::error("[WASI-NN] set_input: Execution Context does not exist."); + spdlog::error("[WASI-NN] set_input: Execution Context does not exist."sv); return WASINN::ErrNo::InvalidArgument; } @@ -141,20 +181,20 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, // Get the tensor. auto *WasiTensor = MemInst->getPointer(TensorPtr); if (unlikely(WasiTensor == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Tensor memory."); + spdlog::error("[WASI-NN] Failed when accessing the Tensor memory."sv); return WASINN::ErrNo::InvalidArgument; } WASINN::TensorData Tensor; Tensor.Dimension = MemInst->getSpan(WasiTensor->DimensionPtr, WasiTensor->DimensionLen); if (unlikely(Tensor.Dimension.size() != WasiTensor->DimensionLen)) { - spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."); + spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."sv); return WASINN::ErrNo::InvalidArgument; } Tensor.Tensor = MemInst->getSpan(WasiTensor->TensorPtr, WasiTensor->TensorLen); if (unlikely(Tensor.Tensor.size() != WasiTensor->TensorLen)) { - spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."); + spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."sv); return WASINN::ErrNo::InvalidArgument; } switch (const auto RType = @@ -166,7 +206,7 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, Tensor.RType = RType; break; default: - spdlog::error("[WASI-NN] Unknown tensor type {}.", + spdlog::error("[WASI-NN] Unknown tensor type {}."sv, static_cast(RType)); return WASINN::ErrNo::InvalidArgument; } @@ -193,19 +233,20 @@ WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, } if (Env.NNContext.size() <= Context) { - spdlog::error("[WASI-NN] get_output: Execution Context does not exist"); + spdlog::error("[WASI-NN] get_output: Execution Context does not exist"sv); return WASINN::ErrNo::InvalidArgument; } const auto OutBuffer = MemInst->getSpan(OutBufferPtr, OutBufferMaxSize); if (unlikely(OutBuffer.data() == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the Output Buffer memory."); + spdlog::error( + "[WASI-NN] Failed when accessing the Output Buffer memory."sv); return WASINN::ErrNo::InvalidArgument; } uint32_t *BytesWritten = MemInst->getPointer(BytesWrittenPtr); if (unlikely(BytesWritten == nullptr)) { - spdlog::error("[WASI-NN] Failed when accessing the BytesWritten memory."); + spdlog::error("[WASI-NN] Failed when accessing the BytesWritten memory."sv); return WASINN::ErrNo::InvalidArgument; } @@ -229,7 +270,7 @@ WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { } if (Env.NNContext.size() <= Context) { - spdlog::error("[WASI-NN] compute: Execution Context does not exist."); + spdlog::error("[WASI-NN] compute: Execution Context does not exist."sv); return WASINN::ErrNo::InvalidArgument; } diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h index cd76d987..1bec0760 100644 --- a/plugins/wasi_nn/wasinnfunc.h +++ b/plugins/wasi_nn/wasinnfunc.h @@ -28,6 +28,20 @@ class WasiNNLoad : public WasiNN { uint32_t GraphIdPtr); }; +class WasiNNLoadByName : public WasiNN { +public: + WasiNNLoadByName(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t NamePtr, + uint32_t NameLen, uint32_t GraphIdPtr) { + return bodyImpl(Frame, NamePtr, NameLen, GraphIdPtr).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &, + uint32_t NamePtr, uint32_t NameLen, + uint32_t GraphIdPtr); +}; + class WasiNNInitExecCtx : public WasiNN { public: WasiNNInitExecCtx(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} diff --git a/plugins/wasi_nn/wasinnmodule.cpp b/plugins/wasi_nn/wasinnmodule.cpp index 0a805d87..eec2d282 100644 --- a/plugins/wasi_nn/wasinnmodule.cpp +++ b/plugins/wasi_nn/wasinnmodule.cpp @@ -9,6 +9,7 @@ namespace Host { WasiNNModule::WasiNNModule() : ModuleInstance("wasi_ephemeral_nn") { addHostFunc("load", std::make_unique(Env)); + addHostFunc("load_by_name", std::make_unique(Env)); addHostFunc("init_execution_context", std::make_unique(Env)); addHostFunc("set_input", std::make_unique(Env)); From 1178768580b0c70b149a893909cb97ef554df9cf Mon Sep 17 00:00:00 2001 From: dm4 Date: Sun, 3 Sep 2023 21:49:59 +0800 Subject: [PATCH 131/623] [WASI-NN] Add ggml backend for llama (#2763) * [WASI-NN] Add ggml backend Signed-off-by: dm4 * [WASI-NN] Add TODOs Signed-off-by: dm4 * [WASI-NN] Format the ggml code to pass linter Signed-off-by: dm4 * [WASI-NN] Move ggml/ to thirdparty/ Signed-off-by: dm4 * [WASI-NN] Fix ggml warning Signed-off-by: dm4 * [WASI-NN] Remove ggml llama log Signed-off-by: dm4 * [WASI-NN] Correct naming style and error message for ggml backend Signed-off-by: dm4 * [GGML] Fix CI build Signed-off-by: dm4 * [GGML] Fix CI aarch64 build Signed-off-by: dm4 * [GGML] Fix CI Windows build Signed-off-by: dm4 * [GGML] Add license file of llama.cpp Signed-off-by: dm4 * [WASI-NN] Use string_literal and correct naming style Signed-off-by: dm4 * [WASI-NN] Remove the unused local variable Signed-off-by: dm4 * [CI] Build WASI-NN ggml backend Signed-off-by: dm4 * [CI] Release WASI-NN ggml backend Signed-off-by: dm4 * [GGML] Add thirdparty/ggml/README.md Signed-off-by: dm4 --------- Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 5 + plugins/wasi_nn/ggml.cpp | 190 +++++++++++++++++++++++++++++++++ plugins/wasi_nn/ggml.h | 56 ++++++++++ plugins/wasi_nn/types.h | 3 +- plugins/wasi_nn/wasinnenv.cpp | 3 +- plugins/wasi_nn/wasinnenv.h | 1 + plugins/wasi_nn/wasinnfunc.cpp | 1 - 7 files changed, 256 insertions(+), 3 deletions(-) create mode 100644 plugins/wasi_nn/ggml.cpp create mode 100644 plugins/wasi_nn/ggml.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 7c9e9412..320c860e 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -11,28 +11,33 @@ wasmedge_add_library(wasmedgePluginWasiNN tf.cpp torch.cpp tfl.cpp + ggml.cpp ) target_compile_options(wasmedgePluginWasiNN PUBLIC -DWASMEDGE_PLUGIN + -DGGML_USE_K_QUANTS ) target_include_directories(wasmedgePluginWasiNN PUBLIC $ ${CMAKE_CURRENT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/thirdparty/ggml ) if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmedgePluginWasiNN PRIVATE wasmedgeCAPI + utilGgml ) else() target_link_libraries(wasmedgePluginWasiNN PRIVATE wasmedge_shared + utilGgml ) endif() diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp new file mode 100644 index 00000000..adfce21b --- /dev/null +++ b/plugins/wasi_nn/ggml.cpp @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "ggml.h" +#include "wasinnenv.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +Expect load(WasiNNEnvironment &Env, Span> Builders, + Device Device, uint32_t &GraphId) noexcept { + // The graph builder length must be 1. + if (Builders.size() != 1) { + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1"sv, + Builders.size()); + return ErrNo::InvalidArgument; + } + + // Add a new graph. + Env.NNGraph.emplace_back(Backend::GGML); + auto &GraphRef = Env.NNGraph.back().get(); + + // Setup Graph Device + if (Device != Device::CPU) { + spdlog::error( + "[WASI-NN] ggml backend only support CPU target currently."sv); + return ErrNo::InvalidArgument; + } + + auto Weight = Builders[0]; + std::string BinModel(reinterpret_cast(Weight.data()), Weight.size()); + std::istringstream BinRead(BinModel); + + // TODO: pass the model directly to ggml + // Write ggml model to file. + std::string ModelFilePath("ggml-model.bin"sv); + std::ofstream TempFile(ModelFilePath); + if (!TempFile) { + spdlog::error("[WASI-NN] Failed to create the temporary file. Currently, " + "our workaround involves creating a temporary model file " + "named \"ggml-model.bin\" and passing this filename as a " + "parameter to the ggml llama library."sv); + return ErrNo::InvalidArgument; + } + TempFile << BinModel; + TempFile.close(); + + // Initialize ggml model. + gpt_params Params; + Params.model = ModelFilePath; + llama_backend_init(Params.numa); + std::tie(GraphRef.LlamaModel, GraphRef.LlamaContext) = + llama_init_from_gpt_params(Params); + if (GraphRef.LlamaModel == nullptr) { + spdlog::error("[WASI-NN] Error: unable to init model."sv); + return ErrNo::InvalidArgument; + } + + // Store the loaded graph. + GraphId = Env.NNGraph.size() - 1; + return ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); + + ContextId = Env.NNContext.size() - 1; + return ErrNo::Success; +} + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + [[maybe_unused]] uint32_t Index, + const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + CxtRef.LlamaInputs = llama_tokenize(GraphRef.LlamaContext, Prompt, true); + const uint32_t MaxContextSize = llama_n_ctx(GraphRef.LlamaContext); + const uint32_t MaxTokensListSize = MaxContextSize - 4; + if (CxtRef.LlamaInputs.size() > MaxTokensListSize) { + spdlog::error("[WASI-NN]: Error: prompt too long ({} tokens, max %{})"sv, + CxtRef.LlamaInputs.size(), MaxTokensListSize); + return ErrNo::InvalidArgument; + } + return ErrNo::Success; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + [[maybe_unused]] uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + std::copy_n(CxtRef.LlamaOutputs.data(), CxtRef.LlamaOutputs.length(), + OutBuffer.data()); + BytesWritten = CxtRef.LlamaOutputs.length(); + return ErrNo::Success; +} + +Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (CxtRef.LlamaInputs.size() == 0) { + spdlog::error("[WASI-NN] Llama input is not set!"sv); + return ErrNo::InvalidArgument; + } + + // Output start from prompt. + for (auto Id : CxtRef.LlamaInputs) { + CxtRef.LlamaOutputs += llama_token_to_str(GraphRef.LlamaContext, Id); + } + + // Main predict loop. + // TODO: recompute a compressed context based on previous tokens once the + // cache is full. + const int MaxContextSize = llama_n_ctx(GraphRef.LlamaContext); + while (llama_get_kv_cache_token_count(GraphRef.LlamaContext) < + MaxContextSize) { + if (llama_eval(GraphRef.LlamaContext, CxtRef.LlamaInputs.data(), + int(CxtRef.LlamaInputs.size()), + llama_get_kv_cache_token_count(GraphRef.LlamaContext), + get_num_physical_cores())) { + spdlog::error("[WASI-NN] Llama failed to eval."sv); + return ErrNo::InvalidArgument; + } + CxtRef.LlamaInputs.clear(); + + // Select the best prediction. + llama_token NewTokenId = 0; + auto Logits = llama_get_logits(GraphRef.LlamaContext); + auto NVocab = llama_n_vocab(GraphRef.LlamaContext); + std::vector Candidates; + Candidates.reserve(NVocab); + for (llama_token TokenId = 0; TokenId < NVocab; TokenId++) { + Candidates.emplace_back(llama_token_data{TokenId, Logits[TokenId], 0.0f}); + } + llama_token_data_array CandidatesP = {Candidates.data(), Candidates.size(), + false}; + NewTokenId = llama_sample_token_greedy(GraphRef.LlamaContext, &CandidatesP); + + if (NewTokenId == llama_token_eos()) { + CxtRef.LlamaOutputs += "[end of text]"sv; + break; + } + + // Append the new token. + CxtRef.LlamaOutputs += + llama_token_to_str(GraphRef.LlamaContext, NewTokenId); + + // Push this new token for next evaluation. + CxtRef.LlamaInputs.push_back(NewTokenId); + } + + return ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] ggml backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"ggml\" to build it."sv); + return ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WasiNNEnvironment &, Span>, Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WasiNNEnvironment &, uint32_t, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h new file mode 100644 index 00000000..7d82cc86 --- /dev/null +++ b/plugins/wasi_nn/ggml.h @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "types.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::GGML { + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +struct Graph { + llama_model *LlamaModel; + llama_context *LlamaContext; +}; + +struct Context { +public: + Context(size_t GId, Graph &) noexcept : GraphId(GId) {} + size_t GraphId; + std::vector LlamaInputs; + std::string LlamaOutputs; +}; +#else +struct Graph {}; +struct Context { + Context(size_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index 1dc0f8e6..6e2a8128 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -31,10 +31,11 @@ enum class Backend : uint8_t { PyTorch = 3, TensorflowLite = 4, Autodetect = 5, + GGML = 6, }; #define FOR_EACH_BACKEND(F) \ - F(OpenVINO) F(ONNX) F(Tensorflow) F(PyTorch) F(TensorflowLite) + F(OpenVINO) F(ONNX) F(Tensorflow) F(PyTorch) F(TensorflowLite) F(GGML) struct TensorData { Span Dimension; diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 73f959d3..8bc2b60b 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -20,7 +20,8 @@ std::map BackendMap = { {"Tensorflow"sv, Backend::Tensorflow}, {"PyTorch"sv, Backend::PyTorch}, {"TensorflowLite"sv, Backend::TensorflowLite}, - {"Autodetect"sv, Backend::Autodetect}}; + {"Autodetect"sv, Backend::Autodetect}, + {"GGML"sv, Backend::GGML}}; std::map DeviceMap = { {"CPU"sv, Device::CPU}, {"GPU"sv, Device::GPU}, {"TPU"sv, Device::TPU}}; diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 30bceb1e..5803e61a 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -9,6 +9,7 @@ #include #include +#include "ggml.h" #include "onnx.h" #include "openvino.h" #include "tf.h" diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 09f877bb..13170c31 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -107,7 +107,6 @@ WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, } // Get the name of model - std::vector> Builders; uint32_t *Name = MemInst->getPointer(NamePtr); if (unlikely(Name == nullptr)) { spdlog::error("[WASI-NN] Failed when accessing the return Name memory."sv); From 665612cf6c9ffa96afd7ada92ac17e495e532da7 Mon Sep 17 00:00:00 2001 From: zzz <458761603@qq.com> Date: Mon, 4 Sep 2023 08:48:14 +0800 Subject: [PATCH 132/623] [Plugin] add wasmedge_rustls_plugin (#2762) * [Plugin] add wasmedge_rustls_plugin Signed-off-by: csh <458761603@qq.com> * [Plugin] add WASMEDGE_PLUGIN_RUSTLS option Signed-off-by: csh <458761603@qq.com> * [Plugin] update workflow Signed-off-by: csh <458761603@qq.com> --------- Signed-off-by: csh <458761603@qq.com> --- plugins/CMakeLists.txt | 4 + plugins/wasmedge_rustls/.gitignore | 1 + plugins/wasmedge_rustls/CMakeLists.txt | 19 + plugins/wasmedge_rustls/Cargo.toml | 18 + plugins/wasmedge_rustls/src/lib.rs | 703 ++++++++++++++++++ test/plugins/CMakeLists.txt | 4 + test/plugins/wasmedge_rustls/CMakeLists.txt | 35 + .../wasmedge_rustls/wasmedge_rustls.cpp | 54 ++ 8 files changed, 838 insertions(+) create mode 100644 plugins/wasmedge_rustls/.gitignore create mode 100644 plugins/wasmedge_rustls/CMakeLists.txt create mode 100644 plugins/wasmedge_rustls/Cargo.toml create mode 100644 plugins/wasmedge_rustls/src/lib.rs create mode 100644 test/plugins/wasmedge_rustls/CMakeLists.txt create mode 100644 test/plugins/wasmedge_rustls/wasmedge_rustls.cpp diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index b9e52060..793e3ec1 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -66,3 +66,7 @@ endif() if(WASMEDGE_PLUGIN_WASI_LOGGING) add_subdirectory(wasi_logging) endif() + +if(WASMEDGE_PLUGIN_RUSTLS) + add_subdirectory(wasmedge_rustls) +endif() diff --git a/plugins/wasmedge_rustls/.gitignore b/plugins/wasmedge_rustls/.gitignore new file mode 100644 index 00000000..eb5a316c --- /dev/null +++ b/plugins/wasmedge_rustls/.gitignore @@ -0,0 +1 @@ +target diff --git a/plugins/wasmedge_rustls/CMakeLists.txt b/plugins/wasmedge_rustls/CMakeLists.txt new file mode 100644 index 00000000..0edaf1f8 --- /dev/null +++ b/plugins/wasmedge_rustls/CMakeLists.txt @@ -0,0 +1,19 @@ +if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(CARGO_CMD cargo build) + set(TARGET_DIR "debug") +else() + set(CARGO_CMD cargo build --release) + set(TARGET_DIR "release") +endif() + +set(RS_SO ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR}/libwasmedge_rustls${CMAKE_SHARED_LIBRARY_SUFFIX}) + +set(WASMEDGE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/api) + +add_custom_target(wasmedge_rustls_plugin ALL + COMMAND WASMEDGE_LIB_DIR=${WASMEDGE_LIB_DIR} LD_LIBARAY_PATH=${WASMEDGE_LIB_DIR} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD} + COMMAND cp ${RS_SO} ${CMAKE_CURRENT_BINARY_DIR} + COMMAND rm -rf ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + DEPENDS wasmedge_shared +) \ No newline at end of file diff --git a/plugins/wasmedge_rustls/Cargo.toml b/plugins/wasmedge_rustls/Cargo.toml new file mode 100644 index 00000000..1e97c526 --- /dev/null +++ b/plugins/wasmedge_rustls/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "wasmedge_rustls_plugin" +version = "0.2.0" +edition = "2021" + +[lib] +name = "wasmedge_rustls" +path = "src/lib.rs" +crate-type = ["cdylib"] + +[dependencies] +libc = "0.2" +rustls = "0.20" +bytes = "1" +webpki-roots = "0.22" +wasmedge_plugin_sdk = "0.2.0" +log = "0.4" +thiserror = "1" diff --git a/plugins/wasmedge_rustls/src/lib.rs b/plugins/wasmedge_rustls/src/lib.rs new file mode 100644 index 00000000..d32e84c2 --- /dev/null +++ b/plugins/wasmedge_rustls/src/lib.rs @@ -0,0 +1,703 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum TlsError { + #[error("{0}")] + Tls(#[from] rustls::Error), + #[error("{0}")] + IO(#[from] std::io::Error), + #[error("ParamError")] + ParamError, +} + +impl TlsError { + pub fn error_code(&self) -> i32 { + match self { + TlsError::ParamError => -1, + TlsError::Tls(tls_err) => match tls_err { + rustls::Error::InappropriateMessage { .. } => -2, + rustls::Error::InappropriateHandshakeMessage { .. } => -3, + rustls::Error::CorruptMessage => -4, + rustls::Error::CorruptMessagePayload(_) => -5, + rustls::Error::NoCertificatesPresented => -6, + rustls::Error::UnsupportedNameType => -7, + rustls::Error::DecryptError => -8, + rustls::Error::EncryptError => -9, + rustls::Error::PeerIncompatibleError(_) => -10, + rustls::Error::PeerMisbehavedError(_) => -11, + rustls::Error::AlertReceived(_) => -12, + rustls::Error::InvalidCertificateEncoding => -13, + rustls::Error::InvalidCertificateSignatureType => -14, + rustls::Error::InvalidCertificateSignature => -15, + rustls::Error::InvalidCertificateData(_) => -16, + rustls::Error::InvalidSct(_) => -17, + rustls::Error::General(_) => -18, + rustls::Error::FailedToGetCurrentTime => -19, + rustls::Error::FailedToGetRandomBytes => -20, + rustls::Error::HandshakeNotComplete => -21, + rustls::Error::PeerSentOversizedRecord => -22, + rustls::Error::NoApplicationProtocol => -23, + rustls::Error::BadMaxFragmentSize => -24, + }, + TlsError::IO(io_err) if io_err.kind() == std::io::ErrorKind::WouldBlock => -25, + TlsError::IO(_) => -26, + } + } +} + +#[repr(C)] +pub struct TlsIoState { + tls_bytes_to_write: u32, + plaintext_bytes_to_read: u32, + peer_has_closed: bool, +} + +impl From for TlsIoState { + fn from(value: rustls::IoState) -> Self { + TlsIoState { + tls_bytes_to_write: value.tls_bytes_to_write() as u32, + plaintext_bytes_to_read: value.plaintext_bytes_to_read() as u32, + peer_has_closed: value.peer_has_closed(), + } + } +} + +mod tls_client { + use std::{ + io::{Read, Write}, + sync::Arc, + }; + + use bytes::{Buf, BufMut}; + use rustls::{OwnedTrustAnchor, RootCertStore}; + + use crate::TlsError; + use crate::TlsIoState; + + pub struct Ctx { + pub client_configs: Vec>>, + pub client_codec: Vec>, + } + + impl Ctx { + pub fn new() -> Ctx { + let mut root_store = RootCertStore::empty(); + root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map( + |ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }, + )); + let config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); + + Ctx { + client_configs: vec![Some(Arc::new(config))], + client_codec: Vec::with_capacity(1024), + } + } + + pub fn default_client_config(&mut self) -> usize { + 0 + } + + pub fn new_codec( + &mut self, + server_name: &str, + config_id: usize, + ) -> Result { + let config = self + .client_configs + .get(config_id) + .ok_or(TlsError::ParamError)? + .clone() + .ok_or(TlsError::ParamError)?; + + let name = server_name.try_into().map_err(|_| TlsError::ParamError)?; + let new_codec = rustls::ClientConnection::new(config, name)?; + let new_codec = ClientCodec(new_codec); + + if let Some((id, item)) = self + .client_codec + .iter_mut() + .enumerate() + .find(|(_, item)| item.is_none()) + { + debug_assert!(item.is_none()); + let _ = item.insert(new_codec); + Ok(id) + } else { + let id = self.client_codec.len(); + self.client_codec.push(Some(new_codec)); + Ok(id) + } + } + + pub fn delete_codec(&mut self, codec_id: usize) { + if let Some(codec) = self.client_codec.get_mut(codec_id) { + let _ = codec.take(); + } + } + } + + #[derive(Debug)] + pub struct ClientCodec(pub rustls::ClientConnection); + + impl ClientCodec { + pub fn is_handshaking(&self) -> bool { + self.0.is_handshaking() + } + + pub fn process_new_packets(&mut self) -> Result { + Ok(self.0.process_new_packets()?.into()) + } + + pub fn send_close_notify(&mut self) { + self.0.send_close_notify(); + } + + pub fn write_raw(&mut self, raw_buf: &[u8]) -> Result { + let conn = &mut self.0; + Ok(conn.writer().write(raw_buf)?) + } + + pub fn write_tls(&mut self, tls_buf: &mut [u8]) -> Result { + let conn = &mut self.0; + Ok(conn.write_tls(&mut tls_buf.writer())?) + } + + pub fn read_raw(&mut self, raw_buf: &mut [u8]) -> Result { + let conn = &mut self.0; + Ok(conn.reader().read(raw_buf)?) + } + + pub fn read_tls(&mut self, tls_buf: &[u8]) -> Result { + let conn = &mut self.0; + Ok(conn.read_tls(&mut tls_buf.reader())?) + } + } + + #[cfg(test)] + mod tls_client_test { + use super::*; + #[test] + fn test_ctx() { + let mut ctx = Ctx::new(); + let config_id = ctx.default_client_config(); + assert_eq!(config_id, 0); + + let codec_id_0 = ctx.new_codec("httpbin.org", config_id).unwrap(); + assert_eq!(codec_id_0, 0); + let codec_id_1 = ctx.new_codec("httpbin.org", config_id).unwrap(); + assert_eq!(codec_id_1, 1); + ctx.delete_codec(codec_id_0); + println!("{:?}", ctx.client_codec); + let codec_id_0 = ctx.new_codec("httpbin.org", config_id).unwrap(); + assert_eq!(codec_id_0, 0); + } + } +} + +mod wasmedge_client_plugin { + + use wasmedge_plugin_sdk::{ + error::CoreError, + memory::Memory, + module::{PluginModule, SyncInstanceRef}, + types::{ValType, WasmVal}, + }; + + use crate::{tls_client::*, TlsError}; + + macro_rules! match_value { + ($expression:expr, $t:path, $error:expr) => { + match $expression { + $t(v) => v, + _ => return Err($error), + } + }; + } + + fn default_config( + _inst: &mut SyncInstanceRef, + _memory: &mut Memory, + ctx: &mut Ctx, + _args: Vec, + ) -> Result, CoreError> { + let config_id = ctx.default_client_config(); + Ok(vec![WasmVal::I32(config_id as i32)]) + } + + fn new_client_codec( + _inst: &mut SyncInstanceRef, + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result, CoreError> { + #[inline] + fn new_client_codec_inner( + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result { + let config_id = args[0].clone(); + let server_ptr = args[1].clone(); + let server_len = args[2].clone(); + + if let (WasmVal::I32(config_id), WasmVal::I32(server_ptr), WasmVal::I32(server_len)) = + (config_id, server_ptr, server_len) + { + let server_name = memory.data_pointer(server_ptr as usize, server_len as usize); + let server_name = server_name + .and_then(|bs| std::str::from_utf8(bs).ok()) + .ok_or(TlsError::ParamError)?; + let r = ctx.new_codec(server_name, config_id as usize)?; + Ok(WasmVal::I32(r as i32)) + } else { + Err(TlsError::ParamError) + } + } + match new_client_codec_inner(memory, ctx, args) { + Ok(ok) => Ok(vec![ok]), + Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), + } + } + + fn is_handshaking( + _inst: &mut SyncInstanceRef, + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result, CoreError> { + #[inline] + fn is_handshaking_inner( + _memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result { + let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); + let codec = ctx + .client_codec + .get(codec_id as usize) + .ok_or(TlsError::ParamError)? + .as_ref() + .ok_or(TlsError::ParamError)?; + + if codec.is_handshaking() { + Ok(WasmVal::I32(1)) + } else { + Ok(WasmVal::I32(0)) + } + } + + match is_handshaking_inner(memory, ctx, args) { + Ok(ok) => Ok(vec![ok]), + Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), + } + } + + fn wants( + _inst: &mut SyncInstanceRef, + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result, CoreError> { + #[inline] + fn wants_inner( + _memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result { + let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); + let codec = ctx + .client_codec + .get(codec_id as usize) + .ok_or(TlsError::ParamError)? + .as_ref() + .ok_or(TlsError::ParamError)?; + match (codec.0.wants_write(), codec.0.wants_read()) { + (true, true) => Ok(WasmVal::I32(0b11)), + (true, false) => Ok(WasmVal::I32(0b10)), + (false, true) => Ok(WasmVal::I32(0b01)), + (false, false) => Ok(WasmVal::I32(0)), + } + } + + match wants_inner(memory, ctx, args) { + Ok(ok) => Ok(vec![ok]), + Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), + } + } + + fn delete_codec( + _inst: &mut SyncInstanceRef, + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result, CoreError> { + #[inline] + fn delete_codec_inner( + _memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result { + let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); + ctx.delete_codec(codec_id as usize); + Ok(WasmVal::I32(0)) + } + + match delete_codec_inner(memory, ctx, args) { + Ok(ok) => Ok(vec![ok]), + Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), + } + } + + fn process_new_packets( + _inst: &mut SyncInstanceRef, + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result, CoreError> { + #[inline] + fn process_new_packets_inner( + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result { + let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); + let result_ptr = match_value!(args[1].clone(), WasmVal::I32, TlsError::ParamError); + + let codec = ctx + .client_codec + .get_mut(codec_id as usize) + .ok_or(TlsError::ParamError)? + .as_mut() + .ok_or(TlsError::ParamError)?; + let io_state = codec.process_new_packets()?; + + memory + .write_data((result_ptr as usize).into(), io_state) + .ok_or(TlsError::ParamError)?; + + Ok(WasmVal::I32(0 as i32)) + } + match process_new_packets_inner(memory, ctx, args) { + Ok(ok) => Ok(vec![ok]), + Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), + } + } + + fn send_close_notify( + _inst: &mut SyncInstanceRef, + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result, CoreError> { + #[inline] + fn send_close_notify_inner( + _memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result { + let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); + let codec = ctx + .client_codec + .get_mut(codec_id as usize) + .ok_or(TlsError::ParamError)? + .as_mut() + .ok_or(TlsError::ParamError)?; + codec.send_close_notify(); + Ok(WasmVal::I32(0)) + } + + match send_close_notify_inner(memory, ctx, args) { + Ok(ok) => Ok(vec![ok]), + Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), + } + } + + fn write_raw( + _inst: &mut SyncInstanceRef, + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result, CoreError> { + #[inline] + fn write_raw_inner( + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result { + let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); + let raw_buf_ptr = match_value!(args[1].clone(), WasmVal::I32, TlsError::ParamError); + let raw_len = match_value!(args[2].clone(), WasmVal::I32, TlsError::ParamError); + + let codec = ctx + .client_codec + .get_mut(codec_id as usize) + .ok_or(TlsError::ParamError)? + .as_mut() + .ok_or(TlsError::ParamError)?; + + let raw_buf = memory + .data_pointer(raw_buf_ptr as usize, raw_len as usize) + .ok_or(TlsError::ParamError)?; + + let n = codec.write_raw(raw_buf)?; + Ok(WasmVal::I32(n as i32)) + } + match write_raw_inner(memory, ctx, args) { + Ok(ok) => Ok(vec![ok]), + Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), + } + } + + fn write_tls( + _inst: &mut SyncInstanceRef, + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result, CoreError> { + #[inline] + fn write_tls_inner( + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result { + let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); + let tls_buf_ptr = match_value!(args[1].clone(), WasmVal::I32, TlsError::ParamError); + let tls_len = match_value!(args[2].clone(), WasmVal::I32, TlsError::ParamError); + + let codec = ctx + .client_codec + .get_mut(codec_id as usize) + .ok_or(TlsError::ParamError)? + .as_mut() + .ok_or(TlsError::ParamError)?; + + let raw_buf = memory + .data_pointer_mut(tls_buf_ptr as usize, tls_len as usize) + .ok_or(TlsError::ParamError)?; + + let n = codec.write_tls(raw_buf)?; + Ok(WasmVal::I32(n as i32)) + } + match write_tls_inner(memory, ctx, args) { + Ok(ok) => Ok(vec![ok]), + Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), + } + } + + fn read_raw( + _inst: &mut SyncInstanceRef, + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result, CoreError> { + #[inline] + fn read_raw_inner( + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result { + let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); + let raw_buf_ptr = match_value!(args[1].clone(), WasmVal::I32, TlsError::ParamError); + let raw_len = match_value!(args[2].clone(), WasmVal::I32, TlsError::ParamError); + + let codec = ctx + .client_codec + .get_mut(codec_id as usize) + .ok_or(TlsError::ParamError)? + .as_mut() + .ok_or(TlsError::ParamError)?; + + let raw_buf = memory + .data_pointer_mut(raw_buf_ptr as usize, raw_len as usize) + .ok_or(TlsError::ParamError)?; + + let n = codec.read_raw(raw_buf); + let n = n?; + Ok(WasmVal::I32(n as i32)) + } + match read_raw_inner(memory, ctx, args) { + Ok(ok) => Ok(vec![ok]), + Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), + } + } + + fn read_tls( + _inst: &mut SyncInstanceRef, + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result, CoreError> { + #[inline] + fn read_tls_inner( + memory: &mut Memory, + ctx: &mut Ctx, + args: Vec, + ) -> Result { + let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); + let tls_buf_ptr = match_value!(args[1].clone(), WasmVal::I32, TlsError::ParamError); + let tls_len = match_value!(args[2].clone(), WasmVal::I32, TlsError::ParamError); + + let codec = ctx + .client_codec + .get_mut(codec_id as usize) + .ok_or(TlsError::ParamError)? + .as_mut() + .ok_or(TlsError::ParamError)?; + + let raw_buf = memory + .data_pointer(tls_buf_ptr as usize, tls_len as usize) + .ok_or(TlsError::ParamError)?; + + let n = codec.read_tls(raw_buf)?; + Ok(WasmVal::I32(n as i32)) + } + match read_tls_inner(memory, ctx, args) { + Ok(ok) => Ok(vec![ok]), + Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), + } + } + + pub fn create_module() -> PluginModule { + let mut module = PluginModule::create("rustls_client", Ctx::new()).unwrap(); + module + .add_func( + "default_config", + (vec![], vec![ValType::I32]), + default_config, + ) + .unwrap(); + + module + .add_func( + "new_codec", + ( + vec![ValType::I32, ValType::I32, ValType::I32], + vec![ValType::I32], + ), + new_client_codec, + ) + .unwrap(); + + module + .add_func( + "codec_is_handshaking", + (vec![ValType::I32], vec![ValType::I32]), + is_handshaking, + ) + .unwrap(); + + module + .add_func( + "codec_wants", + (vec![ValType::I32], vec![ValType::I32]), + wants, + ) + .unwrap(); + + module + .add_func( + "delete_codec", + (vec![ValType::I32], vec![ValType::I32]), + delete_codec, + ) + .unwrap(); + + module + .add_func( + "send_close_notify", + (vec![ValType::I32], vec![ValType::I32]), + send_close_notify, + ) + .unwrap(); + + module + .add_func( + "process_new_packets", + (vec![ValType::I32, ValType::I32], vec![ValType::I32]), + process_new_packets, + ) + .unwrap(); + + module + .add_func( + "write_raw", + ( + vec![ + ValType::I32, //codec_id + ValType::I32, // buf + ValType::I32, // buf_len + ], + vec![ValType::I32], + ), + write_raw, + ) + .unwrap(); + + module + .add_func( + "write_tls", + ( + vec![ + ValType::I32, //codec_id + ValType::I32, // buf + ValType::I32, // buf_len + ], + vec![ValType::I32], + ), + write_tls, + ) + .unwrap(); + + module + .add_func( + "read_raw", + ( + vec![ + ValType::I32, //codec_id + ValType::I32, // buf + ValType::I32, // buf_len + ], + vec![ValType::I32], + ), + read_raw, + ) + .unwrap(); + + module + .add_func( + "read_tls", + ( + vec![ + ValType::I32, //codec_id + ValType::I32, // buf + ValType::I32, // buf_len + ], + vec![ValType::I32], + ), + read_tls, + ) + .unwrap(); + + module + } +} + +use wasmedge_client_plugin::create_module; + +wasmedge_plugin_sdk::plugin::register_plugin!( + plugin_name="rustls", + plugin_description="rustls plugin", + version=(0,0,1,0), + modules=[ + {"rustls_client","rustls client module",create_module} + ] +); diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index ac61f40a..a30ccef2 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -49,6 +49,10 @@ if(WASMEDGE_PLUGIN_WASI_LOGGING) add_subdirectory(wasi_logging) endif() +if(WASMEDGE_PLUGIN_RUSTLS) + add_subdirectory(wasmedge_rustls) +endif() + if(CMAKE_SYSTEM_NAME MATCHES "Linux" OR CMAKE_SYSTEM_NAME MATCHES "Darwin") add_subdirectory(unittest) endif() diff --git a/test/plugins/wasmedge_rustls/CMakeLists.txt b/test/plugins/wasmedge_rustls/CMakeLists.txt new file mode 100644 index 00000000..620f85ed --- /dev/null +++ b/test/plugins/wasmedge_rustls/CMakeLists.txt @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_executable(wasmEdgeRUSTLSTests + wasmedge_rustls.cpp +) + +add_dependencies(wasmEdgeRUSTLSTests + wasmedge_rustls_plugin +) + +target_include_directories(wasmEdgeRUSTLSTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmEdgeRUSTLSTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmEdgeRUSTLSTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmEdgeRUSTLSTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmEdgeRUSTLSTests wasmEdgeRUSTLSTests) diff --git a/test/plugins/wasmedge_rustls/wasmedge_rustls.cpp b/test/plugins/wasmedge_rustls/wasmedge_rustls.cpp new file mode 100644 index 00000000..1e758cf0 --- /dev/null +++ b/test/plugins/wasmedge_rustls/wasmedge_rustls.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/defines.h" +#include "plugin/plugin.h" +#include "runtime/instance/module.h" + +#include +#include +#include +#include +#include +#include + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load( + std::filesystem::u8path("../../../plugins/wasmedge_rustls/" + "libwasmedge_rustls" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("rustls"sv)) { + if (const auto *Module = Plugin->findModule("rustls_client"sv)) { + return Module->create().release(); + } + } + return nullptr; +} +} // namespace + +// TODO: unit tests for every functions. + +TEST(WasmEdgeRUSTLSTest, Module) { + // Create the wasmedge_rustls module instance. + auto *TLSMod = createModule(); + EXPECT_FALSE(TLSMod == nullptr); + EXPECT_EQ(TLSMod->getFuncExportNum(), 11U); + EXPECT_NE(TLSMod->findFuncExports("default_config"), nullptr); + EXPECT_NE(TLSMod->findFuncExports("new_codec"), nullptr); + EXPECT_NE(TLSMod->findFuncExports("codec_is_handshaking"), nullptr); + EXPECT_NE(TLSMod->findFuncExports("codec_wants"), nullptr); + EXPECT_NE(TLSMod->findFuncExports("delete_codec"), nullptr); + EXPECT_NE(TLSMod->findFuncExports("send_close_notify"), nullptr); + EXPECT_NE(TLSMod->findFuncExports("process_new_packets"), nullptr); + EXPECT_NE(TLSMod->findFuncExports("write_raw"), nullptr); + EXPECT_NE(TLSMod->findFuncExports("write_tls"), nullptr); + EXPECT_NE(TLSMod->findFuncExports("read_raw"), nullptr); + EXPECT_NE(TLSMod->findFuncExports("read_tls"), nullptr); + delete TLSMod; +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 227fa9d91b5fd50e2ef85d0909c3a2793545a65f Mon Sep 17 00:00:00 2001 From: YiYing He Date: Tue, 5 Sep 2023 13:25:47 +0800 Subject: [PATCH 133/623] [CMake] Fix the target name of rustls plugin. Signed-off-by: YiYing He --- plugins/wasmedge_rustls/CMakeLists.txt | 4 ++-- test/plugins/wasmedge_rustls/CMakeLists.txt | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/plugins/wasmedge_rustls/CMakeLists.txt b/plugins/wasmedge_rustls/CMakeLists.txt index 0edaf1f8..56069740 100644 --- a/plugins/wasmedge_rustls/CMakeLists.txt +++ b/plugins/wasmedge_rustls/CMakeLists.txt @@ -10,10 +10,10 @@ set(RS_SO ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR}/libwasmedge_rustls${CMAKE_SH set(WASMEDGE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/api) -add_custom_target(wasmedge_rustls_plugin ALL +add_custom_target(wasmedge_rustls ALL COMMAND WASMEDGE_LIB_DIR=${WASMEDGE_LIB_DIR} LD_LIBARAY_PATH=${WASMEDGE_LIB_DIR} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD} COMMAND cp ${RS_SO} ${CMAKE_CURRENT_BINARY_DIR} COMMAND rm -rf ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS wasmedge_shared -) \ No newline at end of file +) diff --git a/test/plugins/wasmedge_rustls/CMakeLists.txt b/test/plugins/wasmedge_rustls/CMakeLists.txt index 620f85ed..3cfc874b 100644 --- a/test/plugins/wasmedge_rustls/CMakeLists.txt +++ b/test/plugins/wasmedge_rustls/CMakeLists.txt @@ -6,13 +6,13 @@ wasmedge_add_executable(wasmEdgeRUSTLSTests ) add_dependencies(wasmEdgeRUSTLSTests - wasmedge_rustls_plugin + wasmedge_rustls ) target_include_directories(wasmEdgeRUSTLSTests PUBLIC $ - $ + $ ) target_link_libraries(wasmEdgeRUSTLSTests From 0040cb3693c7b9f666375bbf7e4377f5b26e22fe Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 8 Sep 2023 17:44:41 +0800 Subject: [PATCH 134/623] [Test] add tests for wasi-nn ggml backend (#2796) Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 10 +- test/plugins/wasi_nn/CMakeLists.txt | 10 + test/plugins/wasi_nn/wasi_nn.cpp | 280 +++++++++++++++++++++++- utils/wasi-nn/download-ggml-fixtures.sh | 17 ++ 4 files changed, 312 insertions(+), 5 deletions(-) create mode 100755 utils/wasi-nn/download-ggml-fixtures.sh diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index adfce21b..8a08ff95 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -20,10 +20,6 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, return ErrNo::InvalidArgument; } - // Add a new graph. - Env.NNGraph.emplace_back(Backend::GGML); - auto &GraphRef = Env.NNGraph.back().get(); - // Setup Graph Device if (Device != Device::CPU) { spdlog::error( @@ -49,6 +45,10 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, TempFile << BinModel; TempFile.close(); + // Add a new graph. + Env.NNGraph.emplace_back(Backend::GGML); + auto &GraphRef = Env.NNGraph.back().get(); + // Initialize ggml model. gpt_params Params; Params.model = ModelFilePath; @@ -57,6 +57,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, llama_init_from_gpt_params(Params); if (GraphRef.LlamaModel == nullptr) { spdlog::error("[WASI-NN] Error: unable to init model."sv); + Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } @@ -82,6 +83,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, Tensor.Tensor.size()); CxtRef.LlamaInputs = llama_tokenize(GraphRef.LlamaContext, Prompt, true); const uint32_t MaxContextSize = llama_n_ctx(GraphRef.LlamaContext); + // Minus 4 for the special tokens. const uint32_t MaxTokensListSize = MaxContextSize - 4; if (CxtRef.LlamaInputs.size() > MaxTokensListSize) { spdlog::error("[WASI-NN]: Error: prompt too long ({} tokens, max %{})"sv, diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 5f050951..8b84ba10 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -54,6 +54,16 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) if(NOT CHECKSUM_IMAGE STREQUAL "ad51c39cfe35d2ef35c4052b78cb3c55") message(FATAL_ERROR "downloaded bird.jpg fixture with wrong md5") endif() + elseif(BACKEND STREQUAL "ggml") + message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures") + execute_process( + COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-ggml-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures + RESULT_VARIABLE DOWNLOAD_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE) + file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/orca-mini-3b.ggmlv3.q4_0.bin CHECKSUM_MODEL) + if(NOT CHECKSUM_MODEL STREQUAL "6a087f7f4598fad0bb70e6cb4023645e") + message(FATAL_ERROR "orca-mini-3b.ggmlv3.q4_0.bin downloaded with wrong md5") + endif() else() # Add the other backend test files fetching here. endif() diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index f05ebb69..2ff69579 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -18,7 +18,8 @@ using WasmEdge::Host::WASINN::ErrNo; #if defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) || \ - defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML) namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; @@ -1206,3 +1207,280 @@ TEST(WasiNNTest, TFLiteBackend) { } } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +TEST(WasiNNTest, GGMLBackend) { + // Create the wasmedge_process module instance. + auto *NNMod = dynamic_cast(createModule()); + EXPECT_FALSE(NNMod == nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(40000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::string Prompt = "Once upon a time, "; + std::vector TensorData(Prompt.begin(), Prompt.end()); + std::vector WeightRead = + readEntireFile("./wasinn_ggml_fixtures/orca-mini-3b.ggmlv3.q4_0.bin"); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(41000 * 65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // GGML WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::GGML), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::GGML), + UINT32_C(0), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), + static_cast(Backend::GGML), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- GGML model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, WeightRead.size(), BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::GGML), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- wrong builders' length. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); + writeBinaries(MemInst, WeightRead, StorePtr); + StorePtr += WeightRead.size(); + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), + static_cast(Backend::GGML), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- the GGML backend currently only supports the CPU target. + // (device: CPU 0, GPU 1, TPU 2) + { + for (uint32_t I = 1; I <= 3; I++) { + + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::GGML), I, + BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + } + + // Test: load -- load successfully. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::GGML), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // GGML WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: init_execution_context -- init context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // GGML WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: set_input -- set input successfully. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // GGML WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // GGML WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 100 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 100); + // Output should begin with the prompt. + const auto Output = MemInst.getSpan(StorePtr, 100); + EXPECT_EQ(std::string(Output.begin(), Output.begin() + Prompt.size()), + Prompt); + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML diff --git a/utils/wasi-nn/download-ggml-fixtures.sh b/utils/wasi-nn/download-ggml-fixtures.sh new file mode 100755 index 00000000..0c77b184 --- /dev/null +++ b/utils/wasi-nn/download-ggml-fixtures.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2023 Second State INC + +TODIR=$1 +if [[ $# -eq 0 ]]; then + TODIR=. +fi +MODEL=orca-mini-3b.ggmlv3.q4_0.bin +FIXTURE=https://huggingface.co/TheBloke/orca_mini_3B-GGML/resolve/main/$MODEL +if [ ! -d $TODIR ]; then + mkdir $TODIR +fi + +if [ ! -f $TODIR/$MODEL ]; then + curl -sL $FIXTURE -o $TODIR/$MODEL +fi From 830c93682059650c1b65032b5535b9565bee846a Mon Sep 17 00:00:00 2001 From: Wang Jikai Date: Sun, 2 Jul 2023 13:23:33 +0800 Subject: [PATCH 135/623] [MSVC] Avoid designated initializers. MSVC requires /std:c++20 to to work with designated initializers. To stay in C++17, avoid using designated initializers. Signed-off-by: Wang Jikai --- test/plugins/unittest/testplugin.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/plugins/unittest/testplugin.c b/test/plugins/unittest/testplugin.c index 8239f92b..cbd2f7f5 100644 --- a/test/plugins/unittest/testplugin.c +++ b/test/plugins/unittest/testplugin.c @@ -10,8 +10,8 @@ static WasmEdge_String NameString; static const char NameCString[] = "name"; -static const WasmEdge_String NameStringDefaultValue = {.Buf = NameCString, - .Length = 4}; +static const WasmEdge_String NameStringDefaultValue = {/*.Length =*/ 4, + /*.Buf =*/ NameCString}; void Finalizer(void *Data) { printf("Deallocate host data\n"); free((int32_t *)Data); From 1d5d0c939164eb3cf056a1e698f937f75905a21d Mon Sep 17 00:00:00 2001 From: Wang Jikai Date: Thu, 31 Aug 2023 15:23:53 +0800 Subject: [PATCH 136/623] [MSVC] Fix designated initializer style, keep designated initializer in C file. Signed-off-by: Wang Jikai --- test/plugins/unittest/testplugin.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/plugins/unittest/testplugin.c b/test/plugins/unittest/testplugin.c index cbd2f7f5..8239f92b 100644 --- a/test/plugins/unittest/testplugin.c +++ b/test/plugins/unittest/testplugin.c @@ -10,8 +10,8 @@ static WasmEdge_String NameString; static const char NameCString[] = "name"; -static const WasmEdge_String NameStringDefaultValue = {/*.Length =*/ 4, - /*.Buf =*/ NameCString}; +static const WasmEdge_String NameStringDefaultValue = {.Buf = NameCString, + .Length = 4}; void Finalizer(void *Data) { printf("Deallocate host data\n"); free((int32_t *)Data); From 4c9d4da91abde4c605c864e0c95cdc79b3a54169 Mon Sep 17 00:00:00 2001 From: Saikat Dey <57017288+notfathomless@users.noreply.github.com> Date: Fri, 15 Sep 2023 05:46:54 +0530 Subject: [PATCH 137/623] [Plugin] initial support of the zlib plugin (#2562) * Include wasmedge_zlib in plugins/CMakeLists.txt . Signed-off-by: Saikat Dey * Add Zlib Plugin as an option in root CMakeLists.txt . Signed-off-by: Saikat Dey * Add zlib-ng dependency through FetchContent. Signed-off-by: Saikat Dey * Added zlib build target for job 'build_ubuntu' in build-extensions.yml . Signed-off-by: Saikat Dey * Added zlib build target for job 'build_manylinux' in build-extensions.yml . Signed-off-by: Saikat Dey * Init env & base headers, struct Wasm_z_stream & ZStreamMap Signed-off-by: Saikat Dey * Added basic zlib function declarations. Signed-off-by: Saikat Dey * Added zlib env implementation & PluginDescriptor. Signed-off-by: Saikat Dey * Added zlib module implementation, and added Host Functions. Signed-off-by: Saikat Dey * Added zlib plugin WasmEdgeZlibDeflateInit_ implementation. Signed-off-by: Saikat Dey * Added zlib plugin WasmEdgeZlibInflateInit_ implementation. Signed-off-by: Saikat Dey * Added zlib plugin WasmEdgeZlibDeflate implementation. Signed-off-by: Saikat Dey * Added zlib plugin WasmEdgeZlibInflate implementation. Signed-off-by: Saikat Dey * Added zlib plugin WasmEdgeZlibDeflateEnd implementation. Signed-off-by: Saikat Dey * Added zlib plugin WasmEdgeZlibInflateEnd implementation. Signed-off-by: Saikat Dey * Populated plugins/wasmedge_zlib/CMakeLists.txt to build & link the zlib plugin source. Signed-off-by: Saikat Dey * Fix: Added ZLIB as link dep to wasmedge_zlib plugin. Signed-off-by: Saikat Dey * Added basic test for zlib plugin; WIP/TODO. Signed-off-by: Saikat Dey * Added test/plugins/wasmedge_zlib/CMakeLists.txt to build and create zlib test executable. Signed-off-by: Saikat Dey * Include zlib plugin test sub directory in test/plugins/CMakeLists.txt. Signed-off-by: Saikat Dey * Fixed naming convention by omitting namespace prefix from function name. Signed-off-by: Saikat Dey * Fix case sensitivity in CMake of zlib library. Signed-off-by: Saikat Dey * Refactor build-extensions.yml to be in-line with upstream master. Signed-off-by: Saikat Dey * Fix typo in build-extensions.yml on zlib cmake option Signed-off-by: Saikat Dey * Minor type changes to better reflect the Zlib API. Signed-off-by: Saikat Dey * Added Deflate Inflate Integration Test Signed-off-by: Saikat Dey * Removed unnecessary nulling of wasm module's z_stream allocators | The host ignores them. Signed-off-by: Saikat Dey * Added static_assert on Wasm_z_stream struct size. Signed-off-by: Saikat Dey * Start writing to the wasm heap from address 1 onwards, to nto collide with the semantic meaning of address/offset 0 as nullptr. Signed-off-by: Saikat Dey * Fix: Correct Version Check and Stream Size Check Error code, and API behaviour, along with relevant test case changes. Signed-off-by: Saikat Dey * Added in-code documentation/pointer/reason on why we just compare only the first character of zlib version strings. Signed-off-by: Saikat Dey * Remove unsafe usage of wasm MemStart(pointer from base{0}). Converted to offset based pointer calculation, based on pointer increment after zlib operation.. Signed-off-by: Saikat Dey * Added in-code documentation to every field in Wasm_z_stream. Signed-off-by: Saikat Dey * Honor LLVM Naming Convention: Change 'Wasm_z_stream' struct name to 'WasmZStream'. Signed-off-by: Saikat Dey * Honor LLVM Naming Convention: Change 'Wasm_z_stream' struct name to 'WasmZStream'. Signed-off-by: Saikat Dey * Changed c style cast to static_cast. Signed-off-by: Saikat Dey * Updated a few variable name to reflect LLVM standards. Signed-off-by: Saikat Dey * Updated a few more variable name to reflect LLVM standards. Signed-off-by: Saikat Dey * Added a more comprehensive Host and Wasm ZStream Syncronization Func, and wrapped all zlib calls with it. Signed-off-by: Saikat Dey * Change few comments. Signed-off-by: Saikat Dey * Add Condition to only add ZStream to internal registry, if init succeeded. Signed-off-by: Saikat Dey * Remove unnecessary zlib version check in wasmedge, delegated it to the zlib implementation. Signed-off-by: Saikat Dey * Added WasmEdgeZlibDeflateSetDictionary and WasmEdgeZlibDeflateGetDictionary. Signed-off-by: Saikat Dey * Added WasmEdgeZlibDeflateCopy with a comment on impl. ref. Signed-off-by: Saikat Dey * Added deflateReset Signed-off-by: Saikat Dey * Added deflateParams Signed-off-by: Saikat Dey * Added deflateTune & deflateBound. Signed-off-by: Saikat Dey * Added deflatePending & deflatePrime. Signed-off-by: Saikat Dey * Defined the Wasm gz_header struct. Signed-off-by: Saikat Dey * Added few function impl of zlib. Signed-off-by: Saikat Dey * Added most of non gz functions. Signed-off-by: Saikat Dey * Added adler32, adler32_z, crc32, crc32_z. Signed-off-by: Saikat Dey * Remove _v2 postfix Signed-off-by: Saikat Dey * Remove duplicate impl. of WasmEdgeZlibInflateSetDictionary. Signed-off-by: Saikat Dey * Remove unused parameter 'Frame' from WasmEdgeZlibZlibCompilerFlags & WasmEdgeZlibCompressBound. Signed-off-by: Saikat Dey * Correct few pointer value usage. Signed-off-by: Saikat Dey * Added GZFile support Env & implemented WasmEdgeZlibGZDOpen. Signed-off-by: Saikat Dey * Implemented gzbuffer, gzsetparams & gzread. Signed-off-by: Saikat Dey * Added gzfread, gzwrite & gzfwrite. Signed-off-by: Saikat Dey * Added gzputs, gzgets, gzputc, gzgetc. Signed-off-by: Saikat Dey * Added gzflush & gzrewind Signed-off-by: Saikat Dey * Remove unused parameter 'Frame'. Signed-off-by: Saikat Dey * Added gzeof, gzdirect, gzclose, gzclose_r, gzclose_w. Signed-off-by: Saikat Dey * Added gzclearerr. Signed-off-by: Saikat Dey * Change void* to unsigned char* Signed-off-by: Saikat Dey * Added gzgetc & gzungetc. Signed-off-by: Saikat Dey * Fix name Spell mistake. Signed-off-by: Saikat Dey * Fixed no-return on a Expect. Signed-off-by: Saikat Dey * Fix move semantics related to unique_ptr. Signed-off-by: Saikat Dey * Added deflateInit2 & inflateInit2 & inflateBackInit2 & Refactor Part 1. Signed-off-by: Saikat Dey * Added gzopen | gzseek | gztell | gzoffset | adler32_combine | crc32_combine & Refactor Part 2 Signed-off-by: Saikat Dey * Added deflateInit2_ | inflateInit2_ | inflateBackInit_. Signed-off-by: Saikat Dey * Added gzgetc_ Signed-off-by: Saikat Dey * Added inflateSyncPoint | inflateUndermine | inflateValidate | inflateCodesUsed | inflateResetKeep | deflateResetKeep. Signed-off-by: Saikat Dey * Added draft WasmEdgeZlibGZVPrintf. Signed-off-by: Saikat Dey * Removed unused Frame parameter. Signed-off-by: Saikat Dey * Change all remaining naming convention to LLVM style. Signed-off-by: Saikat Dey * Added Host Func registration & Refactor Part 3. Signed-off-by: Saikat Dey * Update function presence check to validate al 74 functions. Signed-off-by: Saikat Dey * Update release.yml to include zlib. Signed-off-by: Saikat Dey * Added deflatesetheader & inflategetheader. Signed-off-by: Saikat Dey * Change return type of SyncRun. Signed-off-by: Saikat Dey * SyncRun Func Design & Params Overhaul. Signed-off-by: Saikat Dey * Remove access to moved unique_ptr. Signed-off-by: Saikat Dey * Added GZHeader Sync steps in SyncRun. Signed-off-by: Saikat Dey * Registered deflateSetHeader | inflateGetHeader. Signed-off-by: Saikat Dey * Updated Test to check for deflateSetHeader & inflateGetHeader. Signed-off-by: Saikat Dey * Updated function count test to check for 76 functions. Signed-off-by: Saikat Dey * Check if SyncRun fails even with no call to zlib API. Signed-off-by: Saikat Dey * Removed old unused comments, revert usage of named fields & naming convention fixes. Signed-off-by: Saikat Dey * Added extensive error logging to Zlib Plugin. Signed-off-by: Saikat Dey * Added Function Instance Name, to SyncRun Error messages for better debug experience. Signed-off-by: Saikat Dey * Added a space between error msg tag & message. Signed-off-by: Saikat Dey * Added info to error msg if the error is caught inside SyncRun. Signed-off-by: Saikat Dey * Removed usage of a temporary ZRes variable wherever not absolutely necessary. Signed-off-by: Saikat Dey * Fix Bug: Properly return a placeholder number to Module to act as a pointer to Host gzFile_s. We don't need to pass the raw pointer to Module, since gzFile_s is an opaque structure & the zlib app shouldn't try to access any of it's fields. Signed-off-by: Saikat Dey * [Style Change] Remove trailing _s from GZFile_s. Signed-off-by: Saikat Dey * Removed usage of decltype, due to gcc [error: type qualifiers ignored on cast result type]. Using remove_cv to supress, would create code bloat. Signed-off-by: Saikat Dey --------- Signed-off-by: Saikat Dey --- plugins/CMakeLists.txt | 4 + plugins/wasmedge_zlib/CMakeLists.txt | 50 + plugins/wasmedge_zlib/zlibbase.h | 25 + plugins/wasmedge_zlib/zlibenv.cpp | 39 + plugins/wasmedge_zlib/zlibenv.h | 100 ++ plugins/wasmedge_zlib/zlibfunc.cpp | 1494 ++++++++++++++++++ plugins/wasmedge_zlib/zlibfunc.h | 631 ++++++++ plugins/wasmedge_zlib/zlibmodule.cpp | 114 ++ plugins/wasmedge_zlib/zlibmodule.h | 23 + test/plugins/CMakeLists.txt | 4 + test/plugins/wasmedge_zlib/CMakeLists.txt | 35 + test/plugins/wasmedge_zlib/wasmedge_zlib.cpp | 342 ++++ 12 files changed, 2861 insertions(+) create mode 100644 plugins/wasmedge_zlib/CMakeLists.txt create mode 100644 plugins/wasmedge_zlib/zlibbase.h create mode 100644 plugins/wasmedge_zlib/zlibenv.cpp create mode 100644 plugins/wasmedge_zlib/zlibenv.h create mode 100644 plugins/wasmedge_zlib/zlibfunc.cpp create mode 100644 plugins/wasmedge_zlib/zlibfunc.h create mode 100644 plugins/wasmedge_zlib/zlibmodule.cpp create mode 100644 plugins/wasmedge_zlib/zlibmodule.h create mode 100644 test/plugins/wasmedge_zlib/CMakeLists.txt create mode 100644 test/plugins/wasmedge_zlib/wasmedge_zlib.cpp diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 793e3ec1..abe3e8e6 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -70,3 +70,7 @@ endif() if(WASMEDGE_PLUGIN_RUSTLS) add_subdirectory(wasmedge_rustls) endif() + +if(WASMEDGE_PLUGIN_ZLIB) + add_subdirectory(wasmedge_zlib) +endif() diff --git a/plugins/wasmedge_zlib/CMakeLists.txt b/plugins/wasmedge_zlib/CMakeLists.txt new file mode 100644 index 00000000..30fa7d73 --- /dev/null +++ b/plugins/wasmedge_zlib/CMakeLists.txt @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +# Don't reply on System zlib +# find_package(ZLIB REQUIRED) + +set(ZLIB_COMPAT ON) +set(ZLIBNG_ENABLE_TESTS OFF) + +FetchContent_Declare( + zlib + GIT_REPOSITORY "https://github.com/zlib-ng/zlib-ng.git" + GIT_TAG 2.0.7 + GIT_PROGRESS TRUE +) +FetchContent_MakeAvailable(zlib) + +wasmedge_add_library(wasmedgePluginWasmEdgeZlib + SHARED + zlibenv.cpp + zlibfunc.cpp + zlibmodule.cpp +) + +target_compile_options(wasmedgePluginWasmEdgeZlib + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasmEdgeZlib + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeZlib + PRIVATE + wasmedgeCAPI + zlib + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeZlib + PRIVATE + wasmedge_shared + zlib + ) +endif() + +install(TARGETS wasmedgePluginWasmEdgeZlib DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasmedge_zlib/zlibbase.h b/plugins/wasmedge_zlib/zlibbase.h new file mode 100644 index 00000000..11640a72 --- /dev/null +++ b/plugins/wasmedge_zlib/zlibbase.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "zlibenv.h" + +#include "common/errcode.h" +#include "runtime/callingframe.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template class WasmEdgeZlib : public Runtime::HostFunction { +public: + WasmEdgeZlib(WasmEdgeZlibEnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WasmEdgeZlibEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibenv.cpp b/plugins/wasmedge_zlib/zlibenv.cpp new file mode 100644 index 00000000..45654a28 --- /dev/null +++ b/plugins/wasmedge_zlib/zlibenv.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "zlibenv.h" +#include "zlibmodule.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeZlibModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_zlib", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 10, 1, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_zlib", + .Description = "", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +} // namespace + +Plugin::PluginRegister WasmEdgeZlibEnvironment::Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibenv.h b/plugins/wasmedge_zlib/zlibenv.h new file mode 100644 index 00000000..97781a57 --- /dev/null +++ b/plugins/wasmedge_zlib/zlibenv.h @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include +#include +#include + +#include + +/** + * @brief A struct which maps perfectly to a wasm 32bit z_stream object + * + */ +struct WasmZStream { + /* [Wasm Offset] next input byte */ + uint32_t NextIn; + /* number of bytes available at next_in */ + uint32_t AvailIn; + /* total number of input bytes read so far */ + uint32_t TotalIn; + + /* [Wasm Offset] next output byte will go here */ + uint32_t NextOut; + /* remaining free space at next_out */ + uint32_t AvailOut; + /* total number of bytes output so far */ + uint32_t TotalOut; + + /* [Wasm Offset] last error message, NULL if no error */ + uint32_t Msg; + /* [Wasm Offset] not visible by applications */ + uint32_t State; + + /* used to allocate the internal state */ + uint32_t Zalloc; + /* used to free the internal state */ + uint32_t Zfree; + /* [Wasm Offset] private data object passed to zalloc and zfree */ + uint32_t Opaque; + + /* best guess about the data type: binary or text + for deflate, or the decoding state + for inflate */ + int32_t DataType; + + /* Adler-32 or CRC-32 value of the uncompressed data */ + uint32_t Adler; + /* reserved for future use */ + uint32_t Reserved; +}; +static_assert(sizeof(WasmZStream) == 56, "WasmZStream should be 56 bytes"); + +/* + gzip header information passed to and from zlib routines. See RFC 1952 + for more details on the meanings of these fields. +*/ +struct WasmGZHeader { + int32_t Text; /* true if compressed data believed to be text */ + uint32_t Time; /* modification time */ + int32_t XFlags; /* extra flags (not used when writing a gzip file) */ + int32_t OS; /* operating system */ + uint32_t Extra; /* pointer to extra field or Z_NULL if none */ + uint32_t ExtraLen; /* extra field length (valid if extra != Z_NULL) */ + uint32_t ExtraMax; /* space at extra (only when reading header) */ + uint32_t Name; /* pointer to zero-terminated file name or Z_NULL */ + uint32_t NameMax; /* space at name (only when reading header) */ + uint32_t Comment; /* pointer to zero-terminated comment or Z_NULL */ + uint32_t CommMax; /* space at comment (only when reading header) */ + int32_t HCRC; /* true if there was or will be a header crc */ + int32_t Done; /* true when done reading gzip header (not used + when writing a gzip file) */ +}; +static_assert(sizeof(WasmGZHeader) == 52, "WasmGZHeader should be 52 bytes"); + +namespace WasmEdge { +namespace Host { + +class WasmEdgeZlibEnvironment { +public: + using GZFile = std::remove_pointer_t; + + struct GZStore { + uint32_t WasmGZHeaderOffset; + std::unique_ptr HostGZHeader; + }; + + std::unordered_map> ZStreamMap; + std::map, std::greater> GZFileMap; + std::unordered_map GZHeaderMap; + + /// Initial Configurations + static Plugin::PluginRegister Register; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibfunc.cpp b/plugins/wasmedge_zlib/zlibfunc.cpp new file mode 100644 index 00000000..04aeac05 --- /dev/null +++ b/plugins/wasmedge_zlib/zlibfunc.cpp @@ -0,0 +1,1494 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "zlibfunc.h" + +#include +#include + +namespace WasmEdge { +namespace Host { + +constexpr bool CheckSize(int32_t StreamSize) { + + return (StreamSize == static_cast(sizeof(WasmZStream))); +} + +static constexpr uint32_t WasmGZFileStart = sizeof(gzFile); + +template +auto SyncRun(const std::string_view &Msg, WasmEdgeZlibEnvironment &Env, + uint32_t ZStreamPtr, const Runtime::CallingFrame &Frame, + T Callback) -> Expect { + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [{}-SyncRun] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv, + Msg); + return Unexpect(ErrCode::Value::HostFuncError); + } + WasmZStream *ModuleZStream = MemInst->getPointer(ZStreamPtr); + + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [{}-SyncRun] "sv + "Invalid ZStreamPtr received."sv, + Msg); + return Unexpect(ErrCode::Value::HostFuncError); + } + auto HostZStream = HostZStreamIt->second.get(); + const auto GZHeaderStoreIt = Env.GZHeaderMap.find(ZStreamPtr); + + HostZStream->next_in = + MemInst->getPointer(ModuleZStream->NextIn); + HostZStream->avail_in = ModuleZStream->AvailIn; + HostZStream->total_in = ModuleZStream->TotalIn; + + HostZStream->next_out = + MemInst->getPointer(ModuleZStream->NextOut); + HostZStream->avail_out = ModuleZStream->AvailOut; + HostZStream->total_out = ModuleZStream->TotalOut; + + // TODO: ignore msg for now + // ignore state + // ignore zalloc, zfree, opaque + + HostZStream->data_type = ModuleZStream->DataType; + HostZStream->adler = ModuleZStream->Adler; + HostZStream->reserved = ModuleZStream->Reserved; + + const auto PreComputeNextIn = HostZStream->next_in; + const auto PreComputeNextOut = HostZStream->next_out; + + unsigned char *PreComputeExtra{}; + unsigned char *PreComputeName{}; + unsigned char *PreComputeComment{}; + + if (GZHeaderStoreIt != Env.GZHeaderMap.end()) { + // Sync GZ Header + + auto *ModuleGZHeader = MemInst->getPointer( + GZHeaderStoreIt->second.WasmGZHeaderOffset); + auto *HostGZHeader = GZHeaderStoreIt->second.HostGZHeader.get(); + + HostGZHeader->text = ModuleGZHeader->Text; + HostGZHeader->time = ModuleGZHeader->Time; + HostGZHeader->xflags = ModuleGZHeader->XFlags; + HostGZHeader->os = ModuleGZHeader->OS; + + HostGZHeader->extra = + MemInst->getPointer(ModuleGZHeader->Extra); + HostGZHeader->extra_len = ModuleGZHeader->ExtraLen; + HostGZHeader->extra_max = ModuleGZHeader->ExtraMax; + + HostGZHeader->name = + MemInst->getPointer(ModuleGZHeader->Name); + HostGZHeader->name_max = ModuleGZHeader->NameMax; + + HostGZHeader->comment = + MemInst->getPointer(ModuleGZHeader->Comment); + HostGZHeader->comm_max = ModuleGZHeader->CommMax; + + HostGZHeader->hcrc = ModuleGZHeader->HCRC; + HostGZHeader->done = ModuleGZHeader->Done; + + PreComputeExtra = HostGZHeader->extra; + PreComputeName = HostGZHeader->name; + PreComputeComment = HostGZHeader->comment; + } + + const auto ZRes = Callback(HostZStream); + + ModuleZStream->NextIn += HostZStream->next_in - PreComputeNextIn; + ModuleZStream->AvailIn = HostZStream->avail_in; + ModuleZStream->TotalIn = HostZStream->total_in; + + ModuleZStream->NextOut += HostZStream->next_out - PreComputeNextOut; + ModuleZStream->AvailOut = HostZStream->avail_out; + ModuleZStream->TotalOut = HostZStream->total_out; + + // TODO: ignore msg for now + // ignore state + // ignore zalloc, zfree, opaque + + ModuleZStream->DataType = HostZStream->data_type; + ModuleZStream->Adler = HostZStream->adler; + ModuleZStream->Reserved = HostZStream->reserved; + + if (GZHeaderStoreIt != Env.GZHeaderMap.end()) { + // Sync GZ Header + + auto *ModuleGZHeader = MemInst->getPointer( + GZHeaderStoreIt->second.WasmGZHeaderOffset); + auto *HostGZHeader = GZHeaderStoreIt->second.HostGZHeader.get(); + + ModuleGZHeader->Text = HostGZHeader->text; + ModuleGZHeader->Time = HostGZHeader->time; + ModuleGZHeader->XFlags = HostGZHeader->xflags; + ModuleGZHeader->OS = HostGZHeader->os; + + ModuleGZHeader->Extra += HostGZHeader->extra - PreComputeExtra; + ModuleGZHeader->ExtraLen = HostGZHeader->extra_len; + ModuleGZHeader->ExtraMax = HostGZHeader->extra_max; + + ModuleGZHeader->Name += HostGZHeader->name - PreComputeName; + ModuleGZHeader->NameMax = HostGZHeader->name_max; + + ModuleGZHeader->Comment += HostGZHeader->comment - PreComputeComment; + ModuleGZHeader->CommMax = HostGZHeader->comm_max; + + ModuleGZHeader->HCRC = HostGZHeader->hcrc; + ModuleGZHeader->Done = HostGZHeader->done; + } + + return ZRes; +} + +Expect +WasmEdgeZlibDeflateInit::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t Level) { + + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = SyncRun( + "WasmEdgeZlibDeflateInit", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return deflateInit(HostZStream, Level); }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibDeflate::WasmEdgeZlibDeflate::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, int32_t Flush) { + + return SyncRun( + "WasmEdgeZlibDeflate", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return deflate(HostZStream, Flush); }); +} + +Expect WasmEdgeZlibDeflateEnd::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + const auto ZRes = + SyncRun("WasmEdgeZlibDeflateEnd", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return deflateEnd(HostZStream); }); + + if (ZRes == Z_OK) + Env.ZStreamMap.erase(ZStreamPtr); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateInit::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateInit", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflateInit(HostZStream); }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibInflate::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t Flush) { + + return SyncRun( + "WasmEdgeZlibInflate", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflate(HostZStream, Flush); }); +} + +Expect WasmEdgeZlibInflateEnd::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateEnd", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflateEnd(HostZStream); }); + + Env.ZStreamMap.erase(ZStreamPtr); + + return ZRes; +} + +Expect WasmEdgeZlibDeflateInit2::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, int32_t Level, + int32_t Method, int32_t WindowBits, int32_t MemLevel, int32_t Strategy) { + + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibDeflateInit2", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateInit2(HostZStream, Level, Method, WindowBits, + MemLevel, Strategy); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibDeflateSetDictionary::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLength) { + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflateSetDictionary] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto *Dictionary = MemInst->getPointer(DictionaryPtr); + + return SyncRun("WasmEdgeZlibDeflateSetDictionary", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateSetDictionary(HostZStream, Dictionary, + DictLength); + }); +} + +Expect WasmEdgeZlibDeflateGetDictionary::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLengthPtr) { + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflateGetDictionary] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Dictionary = MemInst->getPointer(DictionaryPtr); + auto *DictLength = MemInst->getPointer(DictLengthPtr); + + return SyncRun("WasmEdgeZlibDeflateGetDictionary", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateGetDictionary(HostZStream, Dictionary, + DictLength); + }); +} + +/* +"The deflateCopy() function shall copy the compression state information in +source to the uninitialized z_stream structure referenced by dest." + +https://refspecs.linuxbase.org/LSB_3.0.0/LSB-Core-generic/LSB-Core-generic/zlib-deflatecopy-1.html +*/ +Expect +WasmEdgeZlibDeflateCopy::body(const Runtime::CallingFrame &Frame, + uint32_t DestPtr, uint32_t SourcePtr) { + const auto SourceZStreamIt = Env.ZStreamMap.find(SourcePtr); + if (SourceZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflateCopy] "sv + "Invalid SourcePtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto DestZStream = std::make_unique(); + + auto It = + Env.ZStreamMap.emplace(std::make_pair(DestPtr, std::move(DestZStream))) + .second; + + const auto Res = SyncRun("WasmEdgeZlibDeflateCopy", Env, DestPtr, Frame, + [&](z_stream *) { return 0; }); + if (!Res.has_value()) + return Res; + + const auto ZRes = + SyncRun("WasmEdgeZlibDeflateCopy", Env, DestPtr, Frame, + [&](z_stream *DestZStream) { + return deflateCopy(DestZStream, SourceZStreamIt->second.get()); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibDeflateReset::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + return SyncRun( + "WasmEdgeZlibDeflateReset", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return deflateReset(HostZStream); }); +} + +Expect +WasmEdgeZlibDeflateParams::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t Level, + int32_t Strategy) { + + return SyncRun("WasmEdgeZlibDeflateParams", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateParams(HostZStream, Level, Strategy); + }); +} + +Expect WasmEdgeZlibDeflateTune::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, int32_t GoodLength, + int32_t MaxLazy, int32_t NiceLength, int32_t MaxChain) { + + return SyncRun("WasmEdgeZlibDeflateTune", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateTune(HostZStream, GoodLength, MaxLazy, + NiceLength, MaxChain); + }); +} + +Expect +WasmEdgeZlibDeflateBound::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, uint32_t SourceLen) { + + return SyncRun("WasmEdgeZlibDeflateBound", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateBound(HostZStream, SourceLen); + }); +} + +Expect +WasmEdgeZlibDeflatePending::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, uint32_t PendingPtr, + uint32_t BitsPtr) { + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflatePending] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Pending = MemInst->getPointer(PendingPtr); + auto *Bits = MemInst->getPointer(BitsPtr); + + return SyncRun("WasmEdgeZlibDeflatePending", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflatePending(HostZStream, Pending, Bits); + }); +} + +Expect +WasmEdgeZlibDeflatePrime::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t Bits, + int32_t Value) { + + return SyncRun("WasmEdgeZlibDeflatePrime", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflatePrime(HostZStream, Bits, Value); + }); +} + +Expect +WasmEdgeZlibDeflateSetHeader::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, uint32_t HeadPtr) { + + auto HostGZHeader = std::make_unique(); + auto HostGZHeaderPtr = HostGZHeader.get(); + + auto It = Env.GZHeaderMap + .emplace(std::pair{ + ZStreamPtr, + WasmEdgeZlibEnvironment::GZStore{ + .WasmGZHeaderOffset = HeadPtr, + .HostGZHeader = std::move(HostGZHeader)}}) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibDeflateSetHeader", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateSetHeader(HostZStream, HostGZHeaderPtr); + }); + + if (ZRes != Z_OK) + Env.GZHeaderMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateInit2::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t WindowBits) { + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = SyncRun("WasmEdgeZlibInflateInit2", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateInit2(HostZStream, WindowBits); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibInflateSetDictionary::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLength) { + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateSetDictionary] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Dictionary = MemInst->getPointer(DictionaryPtr); + + return SyncRun("WasmEdgeZlibInflateSetDictionary", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateSetDictionary(HostZStream, Dictionary, + DictLength); + }); +} + +Expect WasmEdgeZlibInflateGetDictionary::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLengthPtr) { + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateGetDictionary] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Dictionary = MemInst->getPointer(DictionaryPtr); + auto *DictLength = MemInst->getPointer(DictLengthPtr); + + return SyncRun("WasmEdgeZlibInflateGetDictionary", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateGetDictionary(HostZStream, Dictionary, + DictLength); + }); +} + +Expect +WasmEdgeZlibInflateSync::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + return SyncRun( + "WasmEdgeZlibInflateSync", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflateSync(HostZStream); }); +} + +Expect +WasmEdgeZlibInflateCopy::body(const Runtime::CallingFrame &Frame, + uint32_t DestPtr, uint32_t SourcePtr) { + const auto SourceZStreamIt = Env.ZStreamMap.find(SourcePtr); + if (SourceZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateCopy] "sv + "Invalid SourcePtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto DestZStream = std::make_unique(); + + auto It = + Env.ZStreamMap.emplace(std::make_pair(DestPtr, std::move(DestZStream))) + .second; + + const auto Res = SyncRun("WasmEdgeZlibInflateCopy", Env, DestPtr, Frame, + [&](z_stream *) { return 0; }); + if (!Res.has_value()) + return Res; + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateCopy", Env, DestPtr, Frame, + [&](z_stream *DestZStream) { + return inflateCopy(DestZStream, SourceZStreamIt->second.get()); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateReset::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + return SyncRun( + "WasmEdgeZlibInflateReset", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflateReset(HostZStream); }); +} + +Expect +WasmEdgeZlibInflateReset2::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t WindowBits) { + + return SyncRun("WasmEdgeZlibInflateReset2", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateReset2(HostZStream, WindowBits); + }); +} + +Expect +WasmEdgeZlibInflatePrime::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t Bits, + int32_t Value) { + + return SyncRun("WasmEdgeZlibInflatePrime", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflatePrime(HostZStream, Bits, Value); + }); +} + +Expect +WasmEdgeZlibInflateMark::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + return SyncRun( + "WasmEdgeZlibInflateMark", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflateMark(HostZStream); }); +} + +Expect +WasmEdgeZlibInflateGetHeader::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, uint32_t HeadPtr) { + + auto HostGZHeader = std::make_unique(); + auto HostGZHeaderPtr = HostGZHeader.get(); + + auto It = Env.GZHeaderMap + .emplace(std::pair{ + ZStreamPtr, + WasmEdgeZlibEnvironment::GZStore{ + .WasmGZHeaderOffset = HeadPtr, + .HostGZHeader = std::move(HostGZHeader)}}) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateGetHeader", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateGetHeader(HostZStream, HostGZHeaderPtr); + }); + + if (ZRes != Z_OK) + Env.GZHeaderMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateBackInit::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t WindowBits, + uint32_t WindowPtr) { + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateBackInit] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Window = MemInst->getPointer(WindowPtr); + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateBackInit", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateBackInit(HostZStream, WindowBits, Window); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateBackEnd::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr) { + + const auto ZRes = SyncRun( + "WasmEdgeZlibInflateBackEnd", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { return inflateBackEnd(HostZStream); }); + + Env.ZStreamMap.erase(ZStreamPtr); + + return ZRes; +} + +Expect +WasmEdgeZlibZlibCompilerFlags::body(const Runtime::CallingFrame &) { + return zlibCompileFlags(); +} + +Expect WasmEdgeZlibCompress::body(const Runtime::CallingFrame &Frame, + uint32_t DestPtr, + uint32_t DestLenPtr, + uint32_t SourcePtr, + uint32_t SourceLen) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibCompress] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Dest = MemInst->getPointer(DestPtr); + auto *DestLen = MemInst->getPointer(DestLenPtr); + auto *Source = MemInst->getPointer(SourcePtr); + + unsigned long HostDestLen; + HostDestLen = *DestLen; + const auto ZRes = compress(Dest, &HostDestLen, Source, SourceLen); + *DestLen = HostDestLen; + + return ZRes; +} + +Expect WasmEdgeZlibCompress2::body(const Runtime::CallingFrame &Frame, + uint32_t DestPtr, + uint32_t DestLenPtr, + uint32_t SourcePtr, + uint32_t SourceLen, int32_t Level) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibCompress2] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Dest = MemInst->getPointer(DestPtr); + auto *DestLen = MemInst->getPointer(DestLenPtr); + auto *Source = MemInst->getPointer(SourcePtr); + + unsigned long HostDestLen; + HostDestLen = *DestLen; + const auto ZRes = compress2(Dest, &HostDestLen, Source, SourceLen, Level); + *DestLen = HostDestLen; + + return ZRes; +} + +Expect WasmEdgeZlibCompressBound::body(const Runtime::CallingFrame &, + uint32_t SourceLen) { + return compressBound(SourceLen); +} + +Expect WasmEdgeZlibUncompress::body(const Runtime::CallingFrame &Frame, + uint32_t DestPtr, + uint32_t DestLenPtr, + uint32_t SourcePtr, + uint32_t SourceLen) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibUncompress] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Dest = MemInst->getPointer(DestPtr); + auto *DestLen = MemInst->getPointer(DestLenPtr); + auto *Source = MemInst->getPointer(SourcePtr); + + unsigned long HostDestLen; + HostDestLen = *DestLen; + const auto ZRes = uncompress(Dest, &HostDestLen, Source, SourceLen); + *DestLen = HostDestLen; + + return ZRes; +} + +Expect +WasmEdgeZlibUncompress2::body(const Runtime::CallingFrame &Frame, + uint32_t DestPtr, uint32_t DestLenPtr, + uint32_t SourcePtr, uint32_t SourceLenPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibUncompress2] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Dest = MemInst->getPointer(DestPtr); + auto *DestLen = MemInst->getPointer(DestLenPtr); + auto *Source = MemInst->getPointer(SourcePtr); + auto *SourceLen = MemInst->getPointer(SourceLenPtr); + + unsigned long HostDestLen, HostSourceLen; + HostDestLen = *DestLen; + HostSourceLen = *SourceLen; + const auto ZRes = uncompress2(Dest, &HostDestLen, Source, &HostSourceLen); + *DestLen = HostDestLen; + *SourceLen = HostSourceLen; + + return ZRes; +} + +Expect WasmEdgeZlibGZOpen::body(const Runtime::CallingFrame &Frame, + uint32_t PathPtr, uint32_t ModePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZOpen] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Path = MemInst->getPointer(PathPtr); + auto *Mode = MemInst->getPointer(ModePtr); + + auto ZRes = gzopen(Path, Mode); + + const auto NewWasmGZFile = WasmGZFileStart + Env.GZFileMap.size(); + auto El = + std::pair>( + NewWasmGZFile, ZRes); + + Env.GZFileMap.emplace(std::move(El)); + + return NewWasmGZFile; +} + +Expect WasmEdgeZlibGZDOpen::body(const Runtime::CallingFrame &Frame, + int32_t FD, uint32_t ModePtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZDOpen] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Mode = MemInst->getPointer(ModePtr); + + auto ZRes = gzdopen(FD, Mode); + + const auto NewWasmGZFile = WasmGZFileStart + Env.GZFileMap.size(); + auto El = + std::pair>( + NewWasmGZFile, ZRes); + + Env.GZFileMap.emplace(std::move(El)); + + return NewWasmGZFile; +} + +Expect WasmEdgeZlibGZBuffer::body(const Runtime::CallingFrame &, + uint32_t GZFile, uint32_t Size) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZBuffer] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzbuffer(GZFileIt->second.get(), Size); +} + +Expect WasmEdgeZlibGZSetParams::body(const Runtime::CallingFrame &, + uint32_t GZFile, int32_t Level, + int32_t Strategy) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZSetParams] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzsetparams(GZFileIt->second.get(), Level, Strategy); +} + +Expect WasmEdgeZlibGZRead::body(const Runtime::CallingFrame &Frame, + uint32_t GZFile, uint32_t BufPtr, + uint32_t Len) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZRead] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZRead] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Buf = MemInst->getPointer(BufPtr); + + return gzread(GZFileIt->second.get(), Buf, Len); +} + +Expect WasmEdgeZlibGZFread::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr, uint32_t Size, + uint32_t NItems, uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZFread] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZFread] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Buf = MemInst->getPointer(BufPtr); + + return gzfread(Buf, Size, NItems, GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZWrite::body(const Runtime::CallingFrame &Frame, + uint32_t GZFile, uint32_t BufPtr, + uint32_t Len) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZWrite] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZWrite] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Buf = MemInst->getPointer(BufPtr); + + return gzwrite(GZFileIt->second.get(), Buf, Len); +} + +Expect WasmEdgeZlibGZFwrite::body(const Runtime::CallingFrame &Frame, + uint32_t BufPtr, uint32_t Size, + uint32_t NItems, uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZFwrite] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZFwrite] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Buf = MemInst->getPointer(BufPtr); + + return gzfwrite(Buf, Size, NItems, GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZPuts::body(const Runtime::CallingFrame &Frame, + uint32_t GZFile, uint32_t StringPtr) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZPuts] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZPuts] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *String = MemInst->getPointer(StringPtr); + + return gzputs(GZFileIt->second.get(), String); +} + +Expect WasmEdgeZlibGZPutc::body(const Runtime::CallingFrame &, + uint32_t GZFile, int32_t C) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZPutc] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzputc(GZFileIt->second.get(), C); +} + +Expect WasmEdgeZlibGZGetc::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZGetc] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzgetc(GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZUngetc::body(const Runtime::CallingFrame &, + int32_t C, uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZUngetc] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzungetc(C, GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZFlush::body(const Runtime::CallingFrame &, + uint32_t GZFile, int32_t Flush) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZFlush] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzflush(GZFileIt->second.get(), Flush); +} + +Expect WasmEdgeZlibGZSeek::body(const Runtime::CallingFrame &, + uint32_t GZFile, int32_t Offset, + int32_t Whence) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZSeek] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzseek(GZFileIt->second.get(), Offset, Whence); +} + +Expect WasmEdgeZlibGZRewind::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZRewind] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzrewind(GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZTell::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZTell] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gztell(GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZOffset::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZOffset] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzoffset(GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZEof::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZEof] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzeof(GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZDirect::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZDirect] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzdirect(GZFileIt->second.get()); +} + +Expect WasmEdgeZlibGZClose::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZClose] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto ZRes = gzclose(GZFileIt->second.get()); + + Env.GZFileMap.erase(GZFileIt); + + return ZRes; +} + +Expect WasmEdgeZlibGZClose_r::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZClose_r] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto ZRes = gzclose_r(GZFileIt->second.get()); + + Env.GZFileMap.erase(GZFileIt); + + return ZRes; +} + +Expect WasmEdgeZlibGZClose_w::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZClose_w] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto ZRes = gzclose_w(GZFileIt->second.get()); + + Env.GZFileMap.erase(GZFileIt); + + return ZRes; +} + +Expect WasmEdgeZlibGZClearerr::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZClearerr] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + gzclearerr(GZFileIt->second.get()); + + return Expect{}; +} + +Expect WasmEdgeZlibAdler32::body(const Runtime::CallingFrame &Frame, + uint32_t Adler, uint32_t BufPtr, + uint32_t Len) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibAdler32] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Buf = MemInst->getPointer(BufPtr); + + return adler32(Adler, Buf, Len); +} + +Expect WasmEdgeZlibAdler32_z::body(const Runtime::CallingFrame &Frame, + uint32_t Adler, uint32_t BufPtr, + uint32_t Len) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibAdler32_z] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Buf = MemInst->getPointer(BufPtr); + + return adler32_z(Adler, Buf, Len); +} + +Expect WasmEdgeZlibAdler32Combine::body(const Runtime::CallingFrame &, + uint32_t Adler1, + uint32_t Adler2, + int32_t Len2) { + return adler32_combine(Adler1, Adler2, Len2); +} + +Expect WasmEdgeZlibCRC32::body(const Runtime::CallingFrame &Frame, + uint32_t CRC, uint32_t BufPtr, + uint32_t Len) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibCRC32] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Buf = MemInst->getPointer(BufPtr); + + return crc32(CRC, Buf, Len); +} + +Expect WasmEdgeZlibCRC32_z::body(const Runtime::CallingFrame &Frame, + uint32_t CRC, uint32_t BufPtr, + uint32_t Len) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibCRC32_z] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + auto *Buf = MemInst->getPointer(BufPtr); + + return crc32_z(CRC, Buf, Len); +} + +Expect WasmEdgeZlibCRC32Combine::body(const Runtime::CallingFrame &, + uint32_t CRC1, uint32_t CRC2, + int32_t Len2) { + return crc32_combine(CRC1, CRC2, Len2); +} + +Expect +WasmEdgeZlibDeflateInit_::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t Level, + uint32_t VersionPtr, int32_t StreamSize) { + if (!CheckSize(StreamSize)) + return static_cast(Z_VERSION_ERROR); + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflateInit_] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); + auto HostZStream = std::make_unique(); + + // ignore wasm custom allocators + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + // ignore opaque since zmalloc and zfree was ignored + HostZStream->opaque = Z_NULL; + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibDeflateInit_", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateInit_(HostZStream, Level, WasmZlibVersion, + sizeof(z_stream)); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateInit_::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, uint32_t VersionPtr, + int32_t StreamSize) { + if (!CheckSize(StreamSize)) + return static_cast(Z_VERSION_ERROR); + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateInit_] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); + auto HostZStream = std::make_unique(); + + // ignore wasm custom allocators + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + // ignore opaque since zmalloc and zfree was ignored + HostZStream->opaque = Z_NULL; + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = SyncRun("WasmEdgeZlibInflateInit_", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateInit_(HostZStream, WasmZlibVersion, + sizeof(z_stream)); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibDeflateInit2_::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, int32_t Level, + int32_t Method, int32_t WindowBits, int32_t MemLevel, int32_t Strategy, + uint32_t VersionPtr, int32_t StreamSize) { + if (!CheckSize(StreamSize)) + return static_cast(Z_VERSION_ERROR); + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflateInit2_] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = SyncRun( + "WasmEdgeZlibDeflateInit2_", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return deflateInit2_(HostZStream, Level, Method, WindowBits, MemLevel, + Strategy, WasmZlibVersion, sizeof(z_stream)); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect +WasmEdgeZlibInflateInit2_::body(const Runtime::CallingFrame &Frame, + uint32_t ZStreamPtr, int32_t WindowBits, + uint32_t VersionPtr, int32_t StreamSize) { + if (!CheckSize(StreamSize)) + return static_cast(Z_VERSION_ERROR); + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateInit2_] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateInit2_", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateInit2_(HostZStream, WindowBits, WasmZlibVersion, + sizeof(z_stream)); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibInflateBackInit_::body( + const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, int32_t WindowBits, + uint32_t WindowPtr, uint32_t VersionPtr, int32_t StreamSize) { + if (!CheckSize(StreamSize)) + return static_cast(Z_VERSION_ERROR); + + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateBackInit_] "sv + "Frame.getMemoryByIndex(0) returned nullptr."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); + auto *Window = MemInst->getPointer(WindowPtr); + auto HostZStream = std::make_unique(); + HostZStream->zalloc = Z_NULL; + HostZStream->zfree = Z_NULL; + HostZStream->opaque = + Z_NULL; // ignore opaque since zmalloc and zfree was ignored + + auto It = + Env.ZStreamMap.emplace(std::make_pair(ZStreamPtr, std::move(HostZStream))) + .second; + + const auto ZRes = + SyncRun("WasmEdgeZlibInflateBackInit_", Env, ZStreamPtr, Frame, + [&](z_stream *HostZStream) { + return inflateBackInit_(HostZStream, WindowBits, Window, + WasmZlibVersion, sizeof(z_stream)); + }); + + if (ZRes != Z_OK) + Env.ZStreamMap.erase(It); + + return ZRes; +} + +Expect WasmEdgeZlibGZGetc_::body(const Runtime::CallingFrame &, + uint32_t GZFile) { + const auto GZFileIt = Env.GZFileMap.find(GZFile); + if (GZFileIt == Env.GZFileMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZGetc_] "sv + "Invalid GZFile received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return gzgetc_(GZFileIt->second.get()); +} + +Expect +WasmEdgeZlibInflateSyncPoint::body(const Runtime::CallingFrame &, + uint32_t ZStreamPtr) { + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateSyncPoint] "sv + "Invalid ZStreamPtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return inflateSyncPoint(HostZStreamIt->second.get()); +} + +Expect +WasmEdgeZlibInflateUndermine::body(const Runtime::CallingFrame &, + uint32_t ZStreamPtr, int32_t Subvert) { + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateUndermine] "sv + "Invalid ZStreamPtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return inflateUndermine(HostZStreamIt->second.get(), Subvert); +} + +Expect WasmEdgeZlibInflateValidate::body(const Runtime::CallingFrame &, + uint32_t ZStreamPtr, + int32_t Check) { + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateValidate] "sv + "Invalid ZStreamPtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return inflateValidate(HostZStreamIt->second.get(), Check); +} + +Expect +WasmEdgeZlibInflateCodesUsed::body(const Runtime::CallingFrame &, + uint32_t ZStreamPtr) { + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateCodesUsed] "sv + "Invalid ZStreamPtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return inflateCodesUsed(HostZStreamIt->second.get()); +} + +Expect +WasmEdgeZlibInflateResetKeep::body(const Runtime::CallingFrame &, + uint32_t ZStreamPtr) { + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateResetKeep] "sv + "Invalid ZStreamPtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return inflateResetKeep(HostZStreamIt->second.get()); +} + +Expect +WasmEdgeZlibDeflateResetKeep::body(const Runtime::CallingFrame &, + uint32_t ZStreamPtr) { + const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); + if (HostZStreamIt == Env.ZStreamMap.end()) { + spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflateResetKeep] "sv + "Invalid ZStreamPtr received."sv); + return Unexpect(ErrCode::Value::HostFuncError); + } + + return deflateResetKeep(HostZStreamIt->second.get()); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibfunc.h b/plugins/wasmedge_zlib/zlibfunc.h new file mode 100644 index 00000000..d276462d --- /dev/null +++ b/plugins/wasmedge_zlib/zlibfunc.h @@ -0,0 +1,631 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "zlibbase.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeZlibDeflateInit : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateInit(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Level); +}; + +class WasmEdgeZlibDeflate : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflate(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Flush); +}; + +class WasmEdgeZlibDeflateEnd : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateEnd(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflateInit : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateInit(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflate : public WasmEdgeZlib { +public: + WasmEdgeZlibInflate(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Flush); +}; + +class WasmEdgeZlibInflateEnd : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateEnd(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibDeflateInit2 : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateInit2(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Level, int32_t Method, int32_t WindowBits, + int32_t MemLevel, int32_t Strategy); +}; + +class WasmEdgeZlibDeflateSetDictionary + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateSetDictionary(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLength); +}; + +class WasmEdgeZlibDeflateGetDictionary + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateGetDictionary(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLengthPtr); +}; + +class WasmEdgeZlibDeflateCopy : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateCopy(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, + uint32_t SourcePtr); +}; + +class WasmEdgeZlibDeflateReset : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateReset(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibDeflateParams + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateParams(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Level, int32_t Strategy); +}; + +class WasmEdgeZlibDeflateTune : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateTune(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t GoodLength, int32_t MaxLazy, int32_t NiceLength, + int32_t MaxChain); +}; + +// https://github.com/emscripten-core/emscripten/issues/17009 +// Using 32bit, because on wasm-side it will be 32bit long +class WasmEdgeZlibDeflateBound : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateBound(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t SourceLen); +}; + +class WasmEdgeZlibDeflatePending + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflatePending(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t PendingPtr, uint32_t BitsPtr); +}; + +class WasmEdgeZlibDeflatePrime : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflatePrime(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Bits, int32_t Value); +}; + +class WasmEdgeZlibDeflateSetHeader + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateSetHeader(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t HeadPtr); +}; + +class WasmEdgeZlibInflateInit2 : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateInit2(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t WindowBits); +}; + +class WasmEdgeZlibInflateSetDictionary + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateSetDictionary(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLength); +}; + +class WasmEdgeZlibInflateGetDictionary + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateGetDictionary(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t DictionaryPtr, uint32_t DictLengthPtr); +}; + +class WasmEdgeZlibInflateSync : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateSync(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflateCopy : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateCopy(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, + uint32_t SourcePtr); +}; + +class WasmEdgeZlibInflateReset : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateReset(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflateReset2 + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateReset2(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t WindowBits); +}; + +class WasmEdgeZlibInflatePrime : public WasmEdgeZlib { +public: + WasmEdgeZlibInflatePrime(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Bits, int32_t Value); +}; + +class WasmEdgeZlibInflateMark : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateMark(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflateGetHeader + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateGetHeader(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t HeadPtr); +}; + +class WasmEdgeZlibInflateBackInit + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateBackInit(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t WindowBits, uint32_t WindowPtr); +}; + +class WasmEdgeZlibInflateBackEnd + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateBackEnd(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibZlibCompilerFlags + : public WasmEdgeZlib { +public: + WasmEdgeZlibZlibCompilerFlags(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class WasmEdgeZlibCompress : public WasmEdgeZlib { +public: + WasmEdgeZlibCompress(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, + uint32_t DestLenPtr, uint32_t SourcePtr, + uint32_t SourceLen); +}; + +class WasmEdgeZlibCompress2 : public WasmEdgeZlib { +public: + WasmEdgeZlibCompress2(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, + uint32_t DestLenPtr, uint32_t SourcePtr, + uint32_t SourceLen, int32_t Level); +}; + +class WasmEdgeZlibCompressBound + : public WasmEdgeZlib { +public: + WasmEdgeZlibCompressBound(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SourceLen); +}; + +class WasmEdgeZlibUncompress : public WasmEdgeZlib { +public: + WasmEdgeZlibUncompress(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, + uint32_t DestLenPtr, uint32_t SourcePtr, + uint32_t SourceLen); +}; + +class WasmEdgeZlibUncompress2 : public WasmEdgeZlib { +public: + WasmEdgeZlibUncompress2(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, + uint32_t DestLenPtr, uint32_t SourcePtr, + uint32_t SourceLenPtr); +}; + +class WasmEdgeZlibGZOpen : public WasmEdgeZlib { +public: + WasmEdgeZlibGZOpen(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t PathPtr, + uint32_t ModePtr); +}; + +class WasmEdgeZlibGZDOpen : public WasmEdgeZlib { +public: + WasmEdgeZlibGZDOpen(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t FD, + uint32_t ModePtr); +}; + +class WasmEdgeZlibGZBuffer : public WasmEdgeZlib { +public: + WasmEdgeZlibGZBuffer(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + uint32_t Size); +}; + +class WasmEdgeZlibGZSetParams : public WasmEdgeZlib { +public: + WasmEdgeZlibGZSetParams(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + int32_t Level, int32_t Strategy); +}; + +class WasmEdgeZlibGZRead : public WasmEdgeZlib { +public: + WasmEdgeZlibGZRead(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + uint32_t BufPtr, uint32_t Len); +}; + +class WasmEdgeZlibGZFread : public WasmEdgeZlib { +public: + WasmEdgeZlibGZFread(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr, + uint32_t Size, uint32_t NItems, uint32_t GZFile); +}; + +class WasmEdgeZlibGZWrite : public WasmEdgeZlib { +public: + WasmEdgeZlibGZWrite(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + uint32_t BufPtr, uint32_t Len); +}; + +class WasmEdgeZlibGZFwrite : public WasmEdgeZlib { +public: + WasmEdgeZlibGZFwrite(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufPtr, + uint32_t Size, uint32_t NItems, uint32_t GZFile); +}; + +class WasmEdgeZlibGZPuts : public WasmEdgeZlib { +public: + WasmEdgeZlibGZPuts(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + uint32_t StringPtr); +}; + +class WasmEdgeZlibGZPutc : public WasmEdgeZlib { +public: + WasmEdgeZlibGZPutc(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + int32_t C); +}; + +class WasmEdgeZlibGZGetc : public WasmEdgeZlib { +public: + WasmEdgeZlibGZGetc(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZUngetc : public WasmEdgeZlib { +public: + WasmEdgeZlibGZUngetc(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t C, + uint32_t GZFile); +}; + +class WasmEdgeZlibGZFlush : public WasmEdgeZlib { +public: + WasmEdgeZlibGZFlush(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + int32_t Flush); +}; + +// z_off_t --> long +class WasmEdgeZlibGZSeek : public WasmEdgeZlib { +public: + WasmEdgeZlibGZSeek(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile, + int32_t Offset, int32_t Whence); +}; + +class WasmEdgeZlibGZRewind : public WasmEdgeZlib { +public: + WasmEdgeZlibGZRewind(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZTell : public WasmEdgeZlib { +public: + WasmEdgeZlibGZTell(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZOffset : public WasmEdgeZlib { +public: + WasmEdgeZlibGZOffset(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZEof : public WasmEdgeZlib { +public: + WasmEdgeZlibGZEof(WasmEdgeZlibEnvironment &HostEnv) : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZDirect : public WasmEdgeZlib { +public: + WasmEdgeZlibGZDirect(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZClose : public WasmEdgeZlib { +public: + WasmEdgeZlibGZClose(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZClose_r : public WasmEdgeZlib { +public: + WasmEdgeZlibGZClose_r(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZClose_w : public WasmEdgeZlib { +public: + WasmEdgeZlibGZClose_w(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibGZClearerr : public WasmEdgeZlib { +public: + WasmEdgeZlibGZClearerr(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibAdler32 : public WasmEdgeZlib { +public: + WasmEdgeZlibAdler32(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Adler, + uint32_t BufPtr, uint32_t Len); +}; + +class WasmEdgeZlibAdler32_z : public WasmEdgeZlib { +public: + WasmEdgeZlibAdler32_z(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Adler, + uint32_t BufPtr, uint32_t Len); +}; + +// z_off_t --> long +class WasmEdgeZlibAdler32Combine + : public WasmEdgeZlib { +public: + WasmEdgeZlibAdler32Combine(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Adler1, + uint32_t Adler2, int32_t Len2); +}; + +class WasmEdgeZlibCRC32 : public WasmEdgeZlib { +public: + WasmEdgeZlibCRC32(WasmEdgeZlibEnvironment &HostEnv) : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t CRC, + uint32_t BufPtr, uint32_t Len); +}; + +class WasmEdgeZlibCRC32_z : public WasmEdgeZlib { +public: + WasmEdgeZlibCRC32_z(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t CRC, + uint32_t BufPtr, uint32_t Len); +}; + +// z_off_t --> long +class WasmEdgeZlibCRC32Combine : public WasmEdgeZlib { +public: + WasmEdgeZlibCRC32Combine(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t CRC1, + uint32_t CRC2, int32_t Len2); +}; + +class WasmEdgeZlibDeflateInit_ : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateInit_(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Level, uint32_t VersionPtr, int32_t StreamSize); +}; + +class WasmEdgeZlibInflateInit_ : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateInit_(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + uint32_t VersionPtr, int32_t StreamSize); +}; + +class WasmEdgeZlibDeflateInit2_ + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateInit2_(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Level, int32_t Method, int32_t WindowBits, + int32_t MemLevel, int32_t Strategy, uint32_t VersionPtr, + int32_t StreamSize); +}; + +class WasmEdgeZlibInflateInit2_ + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateInit2_(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t WindowBits, uint32_t VersionPtr, + int32_t StreamSize); +}; + +class WasmEdgeZlibInflateBackInit_ + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateBackInit_(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t WindowBits, uint32_t WindowPtr, + uint32_t VersionPtr, int32_t StreamSize); +}; + +class WasmEdgeZlibGZGetc_ : public WasmEdgeZlib { +public: + WasmEdgeZlibGZGetc_(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GZFile); +}; + +class WasmEdgeZlibInflateSyncPoint + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateSyncPoint(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflateUndermine + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateUndermine(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Subvert); +}; + +class WasmEdgeZlibInflateValidate + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateValidate(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, + int32_t Check); +}; + +class WasmEdgeZlibInflateCodesUsed + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateCodesUsed(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibInflateResetKeep + : public WasmEdgeZlib { +public: + WasmEdgeZlibInflateResetKeep(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +class WasmEdgeZlibDeflateResetKeep + : public WasmEdgeZlib { +public: + WasmEdgeZlibDeflateResetKeep(WasmEdgeZlibEnvironment &HostEnv) + : WasmEdgeZlib(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibmodule.cpp b/plugins/wasmedge_zlib/zlibmodule.cpp new file mode 100644 index 00000000..ecd2fd6e --- /dev/null +++ b/plugins/wasmedge_zlib/zlibmodule.cpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "zlibmodule.h" +#include "zlibfunc.h" + +namespace WasmEdge { +namespace Host { + +/// Register your functions in module. +WasmEdgeZlibModule::WasmEdgeZlibModule() : ModuleInstance("wasmedge_zlib") { + addHostFunc("deflateInit", std::make_unique(Env)); + addHostFunc("deflate", std::make_unique(Env)); + addHostFunc("deflateEnd", std::make_unique(Env)); + addHostFunc("inflateInit", std::make_unique(Env)); + addHostFunc("inflate", std::make_unique(Env)); + addHostFunc("inflateEnd", std::make_unique(Env)); + addHostFunc("deflateInit2", std::make_unique(Env)); + addHostFunc("deflateSetDictionary", + std::make_unique(Env)); + addHostFunc("deflateGetDictionary", + std::make_unique(Env)); + addHostFunc("deflateCopy", std::make_unique(Env)); + addHostFunc("deflateReset", std::make_unique(Env)); + addHostFunc("deflateParams", + std::make_unique(Env)); + addHostFunc("deflateTune", std::make_unique(Env)); + addHostFunc("deflateBound", std::make_unique(Env)); + addHostFunc("deflatePending", + std::make_unique(Env)); + addHostFunc("deflatePrime", std::make_unique(Env)); + addHostFunc("deflateSetHeader", + std::make_unique(Env)); + addHostFunc("inflateInit2", std::make_unique(Env)); + addHostFunc("inflateSetDictionary", + std::make_unique(Env)); + addHostFunc("inflateGetDictionary", + std::make_unique(Env)); + addHostFunc("inflateSync", std::make_unique(Env)); + addHostFunc("inflateCopy", std::make_unique(Env)); + addHostFunc("inflateReset", std::make_unique(Env)); + addHostFunc("inflateReset2", + std::make_unique(Env)); + addHostFunc("inflatePrime", std::make_unique(Env)); + addHostFunc("inflateMark", std::make_unique(Env)); + addHostFunc("inflateGetHeader", + std::make_unique(Env)); + addHostFunc("inflateBackInit", + std::make_unique(Env)); + addHostFunc("inflateBackEnd", + std::make_unique(Env)); + addHostFunc("zlibCompileFlags", + std::make_unique(Env)); + addHostFunc("compress", std::make_unique(Env)); + addHostFunc("compress2", std::make_unique(Env)); + addHostFunc("compressBound", + std::make_unique(Env)); + addHostFunc("uncompress", std::make_unique(Env)); + addHostFunc("uncompress2", std::make_unique(Env)); + addHostFunc("gzopen", std::make_unique(Env)); + addHostFunc("gzdopen", std::make_unique(Env)); + addHostFunc("gzbuffer", std::make_unique(Env)); + addHostFunc("gzsetparams", std::make_unique(Env)); + addHostFunc("gzread", std::make_unique(Env)); + addHostFunc("gzfread", std::make_unique(Env)); + addHostFunc("gzwrite", std::make_unique(Env)); + addHostFunc("gzfwrite", std::make_unique(Env)); + addHostFunc("gzputs", std::make_unique(Env)); + addHostFunc("gzputc", std::make_unique(Env)); + addHostFunc("gzgetc", std::make_unique(Env)); + addHostFunc("gzungetc", std::make_unique(Env)); + addHostFunc("gzflush", std::make_unique(Env)); + addHostFunc("gzseek", std::make_unique(Env)); + addHostFunc("gzrewind", std::make_unique(Env)); + addHostFunc("gztell", std::make_unique(Env)); + addHostFunc("gzoffset", std::make_unique(Env)); + addHostFunc("gzeof", std::make_unique(Env)); + addHostFunc("gzdirect", std::make_unique(Env)); + addHostFunc("gzclose", std::make_unique(Env)); + addHostFunc("gzclose_r", std::make_unique(Env)); + addHostFunc("gzclose_w", std::make_unique(Env)); + addHostFunc("gzclearerr", std::make_unique(Env)); + addHostFunc("adler32", std::make_unique(Env)); + addHostFunc("adler32_z", std::make_unique(Env)); + addHostFunc("adler32_combine", + std::make_unique(Env)); + addHostFunc("crc32", std::make_unique(Env)); + addHostFunc("crc32_z", std::make_unique(Env)); + addHostFunc("crc32_combine", std::make_unique(Env)); + addHostFunc("deflateInit_", std::make_unique(Env)); + addHostFunc("inflateInit_", std::make_unique(Env)); + addHostFunc("deflateInit2_", + std::make_unique(Env)); + addHostFunc("inflateInit2_", + std::make_unique(Env)); + addHostFunc("inflateBackInit_", + std::make_unique(Env)); + addHostFunc("gzgetc_", std::make_unique(Env)); + addHostFunc("inflateSyncPoint", + std::make_unique(Env)); + addHostFunc("inflateUndermine", + std::make_unique(Env)); + addHostFunc("inflateValidate", + std::make_unique(Env)); + addHostFunc("inflateCodesUsed", + std::make_unique(Env)); + addHostFunc("inflateResetKeep", + std::make_unique(Env)); + addHostFunc("deflateResetKeep", + std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibmodule.h b/plugins/wasmedge_zlib/zlibmodule.h new file mode 100644 index 00000000..e502993c --- /dev/null +++ b/plugins/wasmedge_zlib/zlibmodule.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "runtime/instance/module.h" +#include "zlibenv.h" + +namespace WasmEdge { +namespace Host { + +class WasmEdgeZlibModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeZlibModule(); + + WasmEdgeZlibEnvironment &getEnv() { return Env; } + +private: + WasmEdgeZlibEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index a30ccef2..808507fd 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -7,6 +7,10 @@ if(WASMEDGE_PLUGIN_PROCESS) endif() endif() +if(WASMEDGE_PLUGIN_ZLIB) + add_subdirectory(wasmedge_zlib) +endif() + if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) add_subdirectory(wasi_nn) endif() diff --git a/test/plugins/wasmedge_zlib/CMakeLists.txt b/test/plugins/wasmedge_zlib/CMakeLists.txt new file mode 100644 index 00000000..7159ab83 --- /dev/null +++ b/test/plugins/wasmedge_zlib/CMakeLists.txt @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_executable(wasmedgeZlibTests + wasmedge_zlib.cpp +) + +add_dependencies(wasmedgeZlibTests + wasmedgePluginWasmEdgeZlib +) + +target_include_directories(wasmedgeZlibTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeZlibTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeZlibTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeZlibTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeZlibTests wasmedgeZlibTests) diff --git a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp new file mode 100644 index 00000000..d47474d8 --- /dev/null +++ b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp @@ -0,0 +1,342 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "common/defines.h" +#include "runtime/instance/module.h" +#include "zlibfunc.h" +#include "zlibmodule.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace { +WasmEdge::Runtime::CallingFrame DummyCallFrame(nullptr, nullptr); + +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_zlib/" + "libwasmedgePluginWasmEdgeZlib" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_zlib"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_zlib"sv)) { + return Module->create().release(); + } + } + return nullptr; +} + +} // namespace + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, uint32_t Cnt, uint8_t C = 0) noexcept { + std::fill_n(MemInst.getPointer(Offset), Cnt, C); +} + +static constexpr size_t DATA_SIZE = 1 * 1024 * 1024ULL; +static constexpr size_t OUTPUT_BUFFER_SIZE = 64 * 1024ULL; + +constexpr auto RandChar = []() -> char { + constexpr char Charset[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + constexpr size_t MaxIndex = (sizeof(Charset) - 1); + return Charset[std::rand() % MaxIndex]; +}; + +TEST(WasmEdgeZlibTest, DeflateInflateCycle) { + auto *ZlibMod = + dynamic_cast(createModule()); + ASSERT_TRUE(ZlibMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(16 * 64, 16 * 64))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + uint32_t + // WASM Memory Heap Pointer + WasmHP = 1, + WasmData, WasmZlibVersion, ModuleZStream, WasmCompressedData, + WasmDecompressedData; + uint32_t WasmCompressedData_size = 0, WasmDecompressedData_size = 0; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + auto *FuncInst = ZlibMod->findFuncExports("deflateInit_"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &DeflateInit_ = dynamic_cast( + FuncInst->getHostFunc()); + + FuncInst = ZlibMod->findFuncExports("deflate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &Deflate = dynamic_cast( + FuncInst->getHostFunc()); + + FuncInst = ZlibMod->findFuncExports("deflateEnd"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &DeflateEnd = dynamic_cast( + FuncInst->getHostFunc()); + + FuncInst = ZlibMod->findFuncExports("inflateInit_"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &InflateInit_ = dynamic_cast( + FuncInst->getHostFunc()); + + FuncInst = ZlibMod->findFuncExports("inflate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &Inflate = dynamic_cast( + FuncInst->getHostFunc()); + + FuncInst = ZlibMod->findFuncExports("inflateEnd"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &InflateEnd = dynamic_cast( + FuncInst->getHostFunc()); + + std::array RetVal; + + WasmZlibVersion = WasmHP; + std::snprintf(MemInst.getPointer(WasmHP), std::strlen(ZLIB_VERSION), + ZLIB_VERSION); + WasmHP += std::strlen(ZLIB_VERSION); + + WasmData = WasmHP; + std::generate_n(MemInst.getPointer(WasmHP), DATA_SIZE, RandChar); + WasmHP += DATA_SIZE; + + ModuleZStream = WasmHP; + WasmZStream *strm = MemInst.getPointer(ModuleZStream); + WasmHP += sizeof(WasmZStream); + + // ----- Deflate Routine START------ + fillMemContent(MemInst, ModuleZStream, sizeof(WasmZStream), 0U); + + // deflateInit_ Test + // WASM z_stream size Mismatch + EXPECT_TRUE(DeflateInit_.run(CallFrame, + std::initializer_list{ + ModuleZStream, INT32_C(-1), WasmZlibVersion, + sizeof(WasmZStream) + 16}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_VERSION_ERROR); + + // Version Mismatch + EXPECT_TRUE(DeflateInit_.run( + CallFrame, + std::initializer_list{ + ModuleZStream, INT32_C(-1), WasmZlibVersion + 2, sizeof(WasmZStream)}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_VERSION_ERROR); + + EXPECT_TRUE(DeflateInit_.run( + CallFrame, + std::initializer_list{ + ModuleZStream, INT32_C(-1), WasmZlibVersion, sizeof(WasmZStream)}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_OK); + + WasmCompressedData = WasmHP; + + strm->AvailIn = DATA_SIZE; + strm->NextIn = WasmData; + strm->AvailOut = OUTPUT_BUFFER_SIZE; + strm->NextOut = WasmCompressedData; + + // deflate Test + do { + if (strm->AvailOut == 0) { + WasmHP += OUTPUT_BUFFER_SIZE; + strm->AvailOut = OUTPUT_BUFFER_SIZE; + strm->NextOut = WasmHP; + } + + EXPECT_TRUE(Deflate.run(CallFrame, + std::initializer_list{ + ModuleZStream, + INT32_C(Z_FINISH), + }, + RetVal)); + EXPECT_NE(RetVal[0].get(), Z_STREAM_ERROR); + } while (RetVal[0].get() != Z_STREAM_END); + + // deflateEnd Test + EXPECT_TRUE(DeflateEnd.run( + CallFrame, std::initializer_list{ModuleZStream}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_OK); + WasmHP += OUTPUT_BUFFER_SIZE - strm->AvailOut; + WasmCompressedData_size = WasmHP - WasmCompressedData; + // ----- Deflate Routine END------ + + // ----- Inflate Routine START------ + fillMemContent(MemInst, ModuleZStream, sizeof(WasmZStream), 0U); + + // inflateInit_ Test + // WASM z_stream size Mismatch + EXPECT_TRUE(InflateInit_.run( + CallFrame, + std::initializer_list{ + ModuleZStream, WasmZlibVersion, sizeof(WasmZStream) + 16}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_VERSION_ERROR); + + // Version Mismatch + EXPECT_TRUE(InflateInit_.run( + CallFrame, + std::initializer_list{ + ModuleZStream, WasmZlibVersion + 2, sizeof(WasmZStream)}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_VERSION_ERROR); + + EXPECT_TRUE( + InflateInit_.run(CallFrame, + std::initializer_list{ + ModuleZStream, WasmZlibVersion, sizeof(WasmZStream)}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_OK); + + WasmDecompressedData = WasmHP; + + strm->AvailIn = WasmCompressedData_size; + strm->NextIn = WasmCompressedData; + strm->AvailOut = OUTPUT_BUFFER_SIZE; + strm->NextOut = WasmDecompressedData; + + // inflate test + do { + if (strm->AvailOut == 0) { + WasmHP += OUTPUT_BUFFER_SIZE; + strm->AvailOut = OUTPUT_BUFFER_SIZE; + strm->NextOut = WasmHP; + } + + EXPECT_TRUE(Inflate.run(CallFrame, + std::initializer_list{ + ModuleZStream, + INT32_C(Z_FINISH), + }, + RetVal)); + EXPECT_NE(RetVal[0].get(), Z_STREAM_ERROR); + } while (RetVal[0].get() != Z_STREAM_END); + + EXPECT_TRUE(InflateEnd.run( + CallFrame, std::initializer_list{ModuleZStream}, + RetVal)); + EXPECT_EQ(RetVal[0].get(), Z_OK); + WasmHP += OUTPUT_BUFFER_SIZE - strm->AvailOut; + WasmDecompressedData_size = WasmHP - WasmDecompressedData; + // ----- Inflate Routine END------ + + // Test Decompressed Buffer size against source Data size. + EXPECT_EQ(WasmDecompressedData_size, DATA_SIZE); + // Test Decompressed Buffer content against source Data. + EXPECT_TRUE(std::equal(MemInst.getPointer(WasmDecompressedData), + MemInst.getPointer( + WasmDecompressedData + WasmDecompressedData_size), + MemInst.getPointer(WasmData))); +} + +TEST(WasmEdgeZlibTest, Module) { + // Create the wasmedge_process module instance. + auto *ZlibMod = + dynamic_cast(createModule()); + EXPECT_FALSE(ZlibMod == nullptr); + EXPECT_TRUE(ZlibMod->getEnv().ZStreamMap.empty()); + EXPECT_EQ(ZlibMod->getFuncExportNum(), 76U); + + EXPECT_NE(ZlibMod->findFuncExports("deflateInit"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflate"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateEnd"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateInit"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflate"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateEnd"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateInit2"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateSetDictionary"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateGetDictionary"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateCopy"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateReset"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateParams"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateTune"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateBound"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflatePending"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflatePrime"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateSetHeader"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateInit2"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateSetDictionary"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateGetDictionary"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateSync"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateCopy"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateReset"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateReset2"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflatePrime"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateMark"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateGetHeader"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateBackInit"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateBackEnd"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("zlibCompileFlags"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("compress"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("compress2"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("compressBound"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("uncompress"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("uncompress2"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzopen"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzdopen"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzbuffer"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzsetparams"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzread"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzfread"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzwrite"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzfwrite"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzputs"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzputc"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzgetc"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzungetc"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzflush"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzseek"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzrewind"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gztell"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzoffset"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzeof"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzdirect"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzclose"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzclose_r"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzclose_w"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzclearerr"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("adler32"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("adler32_z"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("adler32_combine"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("crc32"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("crc32_z"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("crc32_combine"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateInit_"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateInit_"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateInit2_"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateInit2_"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateBackInit_"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("gzgetc_"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateSyncPoint"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateUndermine"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateValidate"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateCodesUsed"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("inflateResetKeep"), nullptr); + EXPECT_NE(ZlibMod->findFuncExports("deflateResetKeep"), nullptr); + + delete ZlibMod; +} + +GTEST_API_ int main(int ArgC, char **ArgV) { + testing::InitGoogleTest(&ArgC, ArgV); + return RUN_ALL_TESTS(); +} From c3bac99c824751b8fcc238be9c0d501403f3e269 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 13 Sep 2023 12:53:07 +0800 Subject: [PATCH 138/623] [WASI-NN] Add openblas support for ggml Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 320c860e..1e015baa 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -17,7 +17,6 @@ wasmedge_add_library(wasmedgePluginWasiNN target_compile_options(wasmedgePluginWasiNN PUBLIC -DWASMEDGE_PLUGIN - -DGGML_USE_K_QUANTS ) target_include_directories(wasmedgePluginWasiNN From de719bf8b0ad5b5a76582912e955b72d40d7ce20 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 13 Sep 2023 13:30:29 +0800 Subject: [PATCH 139/623] [WASI-NN] Update thirdparty/ggml Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 4 ++-- plugins/wasi_nn/ggml.cpp | 6 +++--- test/plugins/wasi_nn/CMakeLists.txt | 6 +++--- test/plugins/wasi_nn/wasi_nn.cpp | 9 +++++---- utils/wasi-nn/download-ggml-fixtures.sh | 4 ++-- 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 1e015baa..ce414c92 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -30,13 +30,13 @@ if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmedgePluginWasiNN PRIVATE wasmedgeCAPI - utilGgml + llama ) else() target_link_libraries(wasmedgePluginWasiNN PRIVATE wasmedge_shared - utilGgml + llama ) endif() diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 8a08ff95..0b483c16 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -114,7 +114,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Output start from prompt. for (auto Id : CxtRef.LlamaInputs) { - CxtRef.LlamaOutputs += llama_token_to_str(GraphRef.LlamaContext, Id); + CxtRef.LlamaOutputs += llama_token_to_piece(GraphRef.LlamaContext, Id); } // Main predict loop. @@ -145,14 +145,14 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { false}; NewTokenId = llama_sample_token_greedy(GraphRef.LlamaContext, &CandidatesP); - if (NewTokenId == llama_token_eos()) { + if (NewTokenId == llama_token_eos(GraphRef.LlamaContext)) { CxtRef.LlamaOutputs += "[end of text]"sv; break; } // Append the new token. CxtRef.LlamaOutputs += - llama_token_to_str(GraphRef.LlamaContext, NewTokenId); + llama_token_to_piece(GraphRef.LlamaContext, NewTokenId); // Push this new token for next evaluation. CxtRef.LlamaInputs.push_back(NewTokenId); diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 8b84ba10..257691e6 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -60,9 +60,9 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-ggml-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures RESULT_VARIABLE DOWNLOAD_ERROR OUTPUT_STRIP_TRAILING_WHITESPACE) - file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/orca-mini-3b.ggmlv3.q4_0.bin CHECKSUM_MODEL) - if(NOT CHECKSUM_MODEL STREQUAL "6a087f7f4598fad0bb70e6cb4023645e") - message(FATAL_ERROR "orca-mini-3b.ggmlv3.q4_0.bin downloaded with wrong md5") + file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/orca-mini-3b.q4_0.gguf CHECKSUM_MODEL) + if(NOT CHECKSUM_MODEL STREQUAL "516027963397e180d7a92aded43d6b3d") + message(FATAL_ERROR "orca-mini-3b.q4_0.gguf downloaded with wrong md5") endif() else() # Add the other backend test files fetching here. diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 2ff69579..3213d35d 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1228,7 +1228,7 @@ TEST(WasiNNTest, GGMLBackend) { std::string Prompt = "Once upon a time, "; std::vector TensorData(Prompt.begin(), Prompt.end()); std::vector WeightRead = - readEntireFile("./wasinn_ggml_fixtures/orca-mini-3b.ggmlv3.q4_0.bin"); + readEntireFile("./wasinn_ggml_fixtures/orca-mini-3b.q4_0.gguf"); std::vector TensorDim{1}; uint32_t BuilderPtr = UINT32_C(0); @@ -1477,10 +1477,11 @@ TEST(WasiNNTest, GGMLBackend) { // Should output more than 100 bytes. auto BytesWritten = *MemInst.getPointer(BuilderPtr); EXPECT_GE(BytesWritten, 100); - // Output should begin with the prompt. + // Output should begin with the prompt. (+1 to skip bos token) const auto Output = MemInst.getSpan(StorePtr, 100); - EXPECT_EQ(std::string(Output.begin(), Output.begin() + Prompt.size()), - Prompt); + EXPECT_EQ( + std::string(Output.begin() + 1, Output.begin() + 1 + Prompt.size()), + Prompt); } } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML diff --git a/utils/wasi-nn/download-ggml-fixtures.sh b/utils/wasi-nn/download-ggml-fixtures.sh index 0c77b184..53f8ea15 100755 --- a/utils/wasi-nn/download-ggml-fixtures.sh +++ b/utils/wasi-nn/download-ggml-fixtures.sh @@ -6,8 +6,8 @@ TODIR=$1 if [[ $# -eq 0 ]]; then TODIR=. fi -MODEL=orca-mini-3b.ggmlv3.q4_0.bin -FIXTURE=https://huggingface.co/TheBloke/orca_mini_3B-GGML/resolve/main/$MODEL +MODEL=orca-mini-3b.q4_0.gguf +FIXTURE=https://huggingface.co/juanjgit/orca_mini_3B-GGUF/resolve/main/$MODEL if [ ! -d $TODIR ]; then mkdir $TODIR fi From def411df1e71cae8ce3adb9c6684a80de8064cdd Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 13 Sep 2023 16:37:33 +0800 Subject: [PATCH 140/623] [WASI-NN] Only build ggml with ggml backend Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index ce414c92..5cba657a 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -30,16 +30,19 @@ if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmedgePluginWasiNN PRIVATE wasmedgeCAPI - llama ) else() target_link_libraries(wasmedgePluginWasiNN PRIVATE wasmedge_shared - llama ) endif() +string(TOLOWER ${WASMEDGE_PLUGIN_WASI_NN_BACKEND} BACKEND) +if(BACKEND STREQUAL "ggml") + target_link_libraries(wasmedgePluginWasiNN PRIVATE llama) +endif() + include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN) From da69fd1a4df935f82f56acb98e1856fa20695f15 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 14 Sep 2023 15:04:45 +0800 Subject: [PATCH 141/623] [WASI-NN] Move ggml to wasi_nn/thirdparty/ Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 7 + plugins/wasi_nn/thirdparty/CMakeLists.txt | 10 + .../wasi_nn/thirdparty/ggml/CMakeLists.txt | 644 + plugins/wasi_nn/thirdparty/ggml/LICENSE | 21 + plugins/wasi_nn/thirdparty/ggml/README.md | 10 + plugins/wasi_nn/thirdparty/ggml/common.cpp | 1258 + plugins/wasi_nn/thirdparty/ggml/common.h | 206 + plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c | 633 + plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h | 26 + plugins/wasi_nn/thirdparty/ggml/ggml.c | 20812 ++++++++++++++++ plugins/wasi_nn/thirdparty/ggml/ggml.h | 2005 ++ plugins/wasi_nn/thirdparty/ggml/k_quants.c | 4318 ++++ plugins/wasi_nn/thirdparty/ggml/k_quants.h | 165 + plugins/wasi_nn/thirdparty/ggml/llama.cpp | 6398 +++++ plugins/wasi_nn/thirdparty/ggml/llama.h | 547 + plugins/wasi_nn/thirdparty/ggml/log.h | 643 + 16 files changed, 37703 insertions(+) create mode 100644 plugins/wasi_nn/thirdparty/CMakeLists.txt create mode 100644 plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt create mode 100644 plugins/wasi_nn/thirdparty/ggml/LICENSE create mode 100644 plugins/wasi_nn/thirdparty/ggml/README.md create mode 100644 plugins/wasi_nn/thirdparty/ggml/common.cpp create mode 100644 plugins/wasi_nn/thirdparty/ggml/common.h create mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c create mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h create mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml.c create mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml.h create mode 100644 plugins/wasi_nn/thirdparty/ggml/k_quants.c create mode 100644 plugins/wasi_nn/thirdparty/ggml/k_quants.h create mode 100644 plugins/wasi_nn/thirdparty/ggml/llama.cpp create mode 100644 plugins/wasi_nn/thirdparty/ggml/llama.h create mode 100644 plugins/wasi_nn/thirdparty/ggml/log.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 5cba657a..f3443323 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -1,6 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2022 Second State INC +# llama.cpp options +set(LLAMA_ALL_WARNINGS OFF) +set(LLAMA_BLAS ON) +set(LLAMA_BLAS_VENDOR "OpenBLAS") + +add_subdirectory(thirdparty) + wasmedge_add_library(wasmedgePluginWasiNN SHARED wasinnenv.cpp diff --git a/plugins/wasi_nn/thirdparty/CMakeLists.txt b/plugins/wasi_nn/thirdparty/CMakeLists.txt new file mode 100644 index 00000000..e71284dd --- /dev/null +++ b/plugins/wasi_nn/thirdparty/CMakeLists.txt @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2023 Second State INC + +if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) + string(TOLOWER ${WASMEDGE_PLUGIN_WASI_NN_BACKEND} BACKEND) + if(BACKEND STREQUAL "ggml") + add_compile_options(-DGGML_BACKEND) + add_subdirectory(ggml) + endif() +endif() diff --git a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt new file mode 100644 index 00000000..ca91cf5d --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt @@ -0,0 +1,644 @@ +# +# Option list +# + +if (APPLE) + set(LLAMA_METAL_DEFAULT ON) +else() + set(LLAMA_METAL_DEFAULT OFF) +endif() + +# general +option(LLAMA_STATIC "llama: static link libraries" OFF) +option(LLAMA_NATIVE "llama: enable -march=native flag" OFF) +option(LLAMA_LTO "llama: enable link time optimization" OFF) + +# debug +option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON) +option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF) +option(LLAMA_GPROF "llama: enable gprof" OFF) + +# sanitizers +option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) +option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) +option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) + +# instruction set specific +option(LLAMA_AVX "llama: enable AVX" ON) +option(LLAMA_AVX2 "llama: enable AVX2" ON) +option(LLAMA_AVX512 "llama: enable AVX512" OFF) +option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) +option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) +option(LLAMA_FMA "llama: enable FMA" ON) +# in MSVC F16C is implied with AVX2/AVX512 +if (NOT MSVC) + option(LLAMA_F16C "llama: enable F16C" ON) +endif() + +# 3rd party libs +option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) +option(LLAMA_BLAS "llama: use BLAS" OFF) +set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") +option(LLAMA_CUBLAS "llama: use CUDA" OFF) +#option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF) +option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) +set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") +set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") +option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) +set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") +option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) +option(LLAMA_CLBLAST "llama: use CLBlast" OFF) +option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT}) +option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF) +option(LLAMA_MPI "llama: use MPI" OFF) +option(LLAMA_K_QUANTS "llama: use k-quants" ON) +option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) + +option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) +option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) +option(LLAMA_BUILD_SERVER "llama: build server example" ON) + +# +# Compile flags +# + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED true) +set(CMAKE_C_STANDARD 11) +set(CMAKE_C_STANDARD_REQUIRED true) +set(THREADS_PREFER_PTHREAD_FLAG ON) +find_package(Threads REQUIRED) + +if (NOT MSVC) + if (LLAMA_SANITIZE_THREAD) + add_compile_options(-fsanitize=thread) + link_libraries(-fsanitize=thread) + endif() + + if (LLAMA_SANITIZE_ADDRESS) + add_compile_options(-fsanitize=address -fno-omit-frame-pointer) + link_libraries(-fsanitize=address) + endif() + + if (LLAMA_SANITIZE_UNDEFINED) + add_compile_options(-fsanitize=undefined) + link_libraries(-fsanitize=undefined) + endif() +endif() + +if (APPLE AND LLAMA_ACCELERATE) + find_library(ACCELERATE_FRAMEWORK Accelerate) + if (ACCELERATE_FRAMEWORK) + message(STATUS "Accelerate framework found") + + add_compile_definitions(GGML_USE_ACCELERATE) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) + else() + message(WARNING "Accelerate framework not found") + endif() +endif() + +if (LLAMA_METAL) + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) + + message(STATUS "Metal framework found") + + set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h) + + add_compile_definitions(GGML_USE_METAL) + if (LLAMA_METAL_NDEBUG) + add_compile_definitions(GGML_METAL_NDEBUG) + endif() + + # get full path to the file + #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/") + + # copy ggml-metal.metal to bin directory + configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY) + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} + ${FOUNDATION_LIBRARY} + ${METAL_FRAMEWORK} + ${METALKIT_FRAMEWORK} + ) +endif() + +if (LLAMA_BLAS) + if (LLAMA_STATIC) + set(BLA_STATIC ON) + endif() + if ($(CMAKE_VERSION) VERSION_GREATER_EQUAL 3.22) + set(BLA_SIZEOF_INTEGER 8) + endif() + + set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) + find_package(BLAS) + + if (BLAS_FOUND) + message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") + + if ("${BLAS_INCLUDE_DIRS}" STREQUAL "") + # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. + # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 + find_package(PkgConfig REQUIRED) + if (${LLAMA_BLAS_VENDOR} MATCHES "Generic") + pkg_check_modules(DepBLAS REQUIRED blas) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS") + pkg_check_modules(DepBLAS REQUIRED openblas) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME") + pkg_check_modules(DepBLAS REQUIRED blis) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS") + pkg_check_modules(DepBLAS REQUIRED blas-atlas) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS") + pkg_check_modules(DepBLAS REQUIRED flexiblas_api) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel") + # all Intel* libraries share the same include path + pkg_check_modules(DepBLAS REQUIRED mkl-sdl) + elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC") + # this doesn't provide pkg-config + # suggest to assign BLAS_INCLUDE_DIRS on your own + if ("${NVHPC_VERSION}" STREQUAL "") + message(WARNING "Better to set NVHPC_VERSION") + else() + set(DepBLAS_FOUND ON) + set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") + endif() + endif() + if (DepBLAS_FOUND) + set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) + else() + message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" + " detected by pkgconfig, trying to find cblas.h from possible paths...") + find_path(BLAS_INCLUDE_DIRS + NAMES cblas.h + HINTS + /usr/include + /usr/local/include + /usr/include/openblas + /opt/homebrew/opt/openblas/include + /usr/local/opt/openblas/include + /usr/include/x86_64-linux-gnu/openblas/include + ) + endif() + endif() + + message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") + add_compile_options(${BLAS_LINKER_FLAGS}) + add_compile_definitions(GGML_USE_OPENBLAS) + if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel")) + add_compile_definitions(GGML_BLAS_USE_MKL) + endif() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) + + else() + message(WARNING "BLAS not found, please refer to " + "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" + " to set correct LLAMA_BLAS_VENDOR") + endif() +endif() + +if (LLAMA_K_QUANTS) + set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h) + add_compile_definitions(GGML_USE_K_QUANTS) + if (LLAMA_QKK_64) + add_compile_definitions(GGML_QKK_64) + endif() +endif() + +if (LLAMA_CUBLAS) + cmake_minimum_required(VERSION 3.17) + + find_package(CUDAToolkit) + if (CUDAToolkit_FOUND) + message(STATUS "cuBLAS found") + + enable_language(CUDA) + + set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h) + + add_compile_definitions(GGML_USE_CUBLAS) +# if (LLAMA_CUDA_CUBLAS) +# add_compile_definitions(GGML_CUDA_CUBLAS) +# endif() + if (LLAMA_CUDA_FORCE_DMMV) + add_compile_definitions(GGML_CUDA_FORCE_DMMV) + endif() + add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) + add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) + if (DEFINED LLAMA_CUDA_DMMV_Y) + add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_DMMV_Y}) # for backwards compatibility + endif() + if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) + add_compile_definitions(GGML_CUDA_F16) + endif() + add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) + + if (LLAMA_STATIC) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + else() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) + endif() + + if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + # 52 == lowest CUDA 12 standard + # 60 == f16 CUDA intrinsics + # 61 == integer CUDA intrinsics + # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster + if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) + set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics + else() + set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics + endif() + endif() + message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") + + else() + message(WARNING "cuBLAS not found") + endif() +endif() + +if (LLAMA_MPI) + cmake_minimum_required(VERSION 3.10) + find_package(MPI) + if (MPI_C_FOUND) + message(STATUS "MPI found") + set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h) + add_compile_definitions(GGML_USE_MPI) + add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS}) + set(cxx_flags ${cxx_flags} -Wno-cast-qual) + set(c_flags ${c_flags} -Wno-cast-qual) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES}) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS}) + # Even if you're only using the C header, C++ programs may bring in MPI + # C++ functions, so more linkage is needed + if (MPI_CXX_FOUND) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_CXX_LIBRARIES}) + endif() + else() + message(WARNING "MPI not found") + endif() +endif() + +if (LLAMA_CLBLAST) + find_package(CLBlast) + if (CLBlast_FOUND) + message(STATUS "CLBlast found") + + set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h) + + add_compile_definitions(GGML_USE_CLBLAST) + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) + else() + message(WARNING "CLBlast not found") + endif() +endif() + +if (LLAMA_HIPBLAS) + list(APPEND CMAKE_PREFIX_PATH /opt/rocm) + + if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang") + endif() + if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++") + endif() + + find_package(hip) + find_package(hipblas) + find_package(rocblas) + + if (${hipblas_FOUND} AND ${hip_FOUND}) + message(STATUS "HIP and hipBLAS found") + add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS) + add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) + if (LLAMA_CUDA_FORCE_DMMV) + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV) + endif() + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) + target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) + target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) + target_compile_definitions(ggml-rocm PRIVATE CC_TURING=1000000000) + set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX) + target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas) + + if (LLAMA_STATIC) + message(FATAL_ERROR "Static linking not supported for HIP/ROCm") + endif() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm) + else() + message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm") + endif() +endif() + +if (LLAMA_ALL_WARNINGS) + if (NOT MSVC) + set(c_flags + -Wall + -Wextra + -Wpedantic + -Wcast-qual + -Wdouble-promotion + -Wshadow + -Wstrict-prototypes + -Wpointer-arith + -Wmissing-prototypes + -Werror=implicit-int + -Wno-unused-function + ) + set(cxx_flags + -Wall + -Wextra + -Wpedantic + -Wcast-qual + -Wno-unused-function + -Wno-multichar + ) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + # g++ only + set(cxx_flags ${cxx_flags} -Wno-format-truncation -Wno-array-bounds) + endif() + else() + # todo : msvc + endif() + + add_compile_options( + "$<$:${c_flags}>" + "$<$:${cxx_flags}>" + ) + +endif() + +if (MSVC) + add_compile_definitions(_CRT_SECURE_NO_WARNINGS) + + if (BUILD_SHARED_LIBS) + set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) + endif() +endif() + +if (LLAMA_LTO) + include(CheckIPOSupported) + check_ipo_supported(RESULT result OUTPUT output) + if (result) + set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) + else() + message(WARNING "IPO is not supported: ${output}") + endif() +endif() + +# Architecture specific +# TODO: probably these flags need to be tweaked on some architectures +# feel free to update the Makefile for your architecture and send a pull request or issue +message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") +if (MSVC) + string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR) + message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}") +else () + set(CMAKE_GENERATOR_PLATFORM_LWR "") +endif () + +if (NOT MSVC) + if (LLAMA_STATIC) + add_link_options(-static) + if (MINGW) + add_link_options(-static-libgcc -static-libstdc++) + endif() + endif() + if (LLAMA_GPROF) + add_compile_options(-pg) + endif() + if (LLAMA_NATIVE) + add_compile_options(-march=native) + endif() +endif() + +if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") OR ("${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "arm64")) + message(STATUS "ARM detected") + if (MSVC) + add_compile_definitions(__ARM_NEON) + add_compile_definitions(__ARM_FEATURE_FMA) + add_compile_definitions(__ARM_FEATURE_DOTPROD) + # add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) # MSVC doesn't support vdupq_n_f16, vld1q_f16, vst1q_f16 + add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead + else() + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") + # Raspberry Pi 1, Zero + add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access) + endif() + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7") + # Raspberry Pi 2 + add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations) + endif() + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") + # Raspberry Pi 3, 4, Zero 2 (32-bit) + add_compile_options(-mfp16-format=ieee -mno-unaligned-access) + endif() + endif() +elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "^(x86_64|i686|amd64|x64)$" ) + message(STATUS "x86 detected") + if (MSVC) + if (LLAMA_AVX512) + add_compile_options($<$:/arch:AVX512>) + add_compile_options($<$:/arch:AVX512>) + # MSVC has no compile-time flags enabling specific + # AVX512 extensions, neither it defines the + # macros corresponding to the extensions. + # Do it manually. + if (LLAMA_AVX512_VBMI) + add_compile_definitions($<$:__AVX512VBMI__>) + add_compile_definitions($<$:__AVX512VBMI__>) + endif() + if (LLAMA_AVX512_VNNI) + add_compile_definitions($<$:__AVX512VNNI__>) + add_compile_definitions($<$:__AVX512VNNI__>) + endif() + elseif (LLAMA_AVX2) + add_compile_options($<$:/arch:AVX2>) + add_compile_options($<$:/arch:AVX2>) + elseif (LLAMA_AVX) + add_compile_options($<$:/arch:AVX>) + add_compile_options($<$:/arch:AVX>) + endif() + else() + if (LLAMA_F16C) + add_compile_options(-mf16c) + endif() + if (LLAMA_FMA) + add_compile_options(-mfma) + endif() + if (LLAMA_AVX) + add_compile_options(-mavx) + endif() + if (LLAMA_AVX2) + add_compile_options(-mavx2) + endif() + if (LLAMA_AVX512) + add_compile_options(-mavx512f) + add_compile_options(-mavx512bw) + endif() + if (LLAMA_AVX512_VBMI) + add_compile_options(-mavx512vbmi) + endif() + if (LLAMA_AVX512_VNNI) + add_compile_options(-mavx512vnni) + endif() + endif() +elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") + message(STATUS "PowerPC detected") + add_compile_options(-mcpu=native -mtune=native) + #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) +else() + message(STATUS "Unknown architecture") +endif() + +# +# POSIX conformance +# + +# clock_gettime came in POSIX.1b (1993) +# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional +# posix_memalign came in POSIX.1-2001 / SUSv3 +# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985) +add_compile_definitions(_XOPEN_SOURCE=600) + +# Somehow in OpenBSD whenever POSIX conformance is specified +# some string functions rely on locale_t availability, +# which was introduced in POSIX.1-2008, forcing us to go higher +if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") + remove_definitions(-D_XOPEN_SOURCE=600) + add_compile_definitions(_XOPEN_SOURCE=700) +endif() + +# Data types, macros and functions related to controlling CPU affinity and +# some memory allocation are available on Linux through GNU extensions in libc +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + add_compile_definitions(_GNU_SOURCE) +endif() + +# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1, +# and on macOS its availability depends on enabling Darwin extensions +# similarly on DragonFly, enabling BSD extensions is necessary +if ( + CMAKE_SYSTEM_NAME MATCHES "Darwin" OR + CMAKE_SYSTEM_NAME MATCHES "iOS" OR + CMAKE_SYSTEM_NAME MATCHES "tvOS" OR + CMAKE_SYSTEM_NAME MATCHES "DragonFly" +) + add_compile_definitions(_DARWIN_C_SOURCE) +endif() + +# alloca is a non-standard interface that is not visible on BSDs when +# POSIX conformance is specified, but not all of them provide a clean way +# to enable it in such cases +if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD") + add_compile_definitions(__BSD_VISIBLE) +endif() +if (CMAKE_SYSTEM_NAME MATCHES "NetBSD") + add_compile_definitions(_NETBSD_SOURCE) +endif() +if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") + add_compile_definitions(_BSD_SOURCE) +endif() + +# +# libraries +# + +# ggml + +if (GGML_USE_CPU_HBM) + add_definitions(-DGGML_USE_CPU_HBM) + find_library(memkind memkind REQUIRED) +endif() + +wasmedge_add_library(ggml OBJECT + ggml.c + ggml.h + ggml-alloc.c + ggml-alloc.h + common.cpp + common.h + ${GGML_SOURCES_CUDA} + ${GGML_SOURCES_OPENCL} + ${GGML_SOURCES_METAL} + ${GGML_SOURCES_MPI} + ${GGML_SOURCES_EXTRA} + ) + +target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES}) +target_compile_features(ggml PUBLIC c_std_11) # don't bump +target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS}) +if (GGML_USE_CPU_HBM) + target_link_libraries(ggml PUBLIC memkind) +endif() + +wasmedge_add_library(ggml_static STATIC $) +if (BUILD_SHARED_LIBS) + set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON) + wasmedge_add_library(ggml_shared SHARED $) + target_link_libraries(ggml_shared PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS}) + install(TARGETS ggml_shared LIBRARY) +endif() + +# llama + +wasmedge_add_library(llama + llama.cpp + llama.h + ) + +target_include_directories(llama PUBLIC .) +target_compile_features(llama PUBLIC cxx_std_11) # don't bump +target_link_libraries(llama PRIVATE + ggml + ${LLAMA_EXTRA_LIBS} + ) + +if (BUILD_SHARED_LIBS) + set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD) + if (LLAMA_METAL) + set_target_properties(llama PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") + endif() +endif() + +# disable warnings +if (NOT WIN32) + target_compile_options(ggml + PRIVATE + -Wno-unused-parameter + -Wno-unused-variable + -Wno-unused-but-set-variable + -Wno-unused-function + -Wno-missing-braces + -DGGML_USE_K_QUANTS + -DGGML_USE_OPENBLAS + ) +else() + target_compile_options(ggml + PRIVATE + -Wno-string-conversion + -Wno-sign-conversion + -Wno-macro-redefined + -Wno-missing-prototypes + -Wno-unreachable-code-return + -Wno-shorten-64-to-32 + -Wno-implicit-int-conversion + -Wno-implicit-float-conversion + -Wno-float-conversion + -Wno-unused-macros + -Wno-unreachable-code-break + -Wno-cast-align + -Wno-undef + -Wno-shadow-uncaptured-local + -Wno-unreachable-code + -Wno-cast-function-type + -Wno-format-nonliteral + -Wno-extra-semi-stmt + -Wno-bad-function-cast + ) +endif() \ No newline at end of file diff --git a/plugins/wasi_nn/thirdparty/ggml/LICENSE b/plugins/wasi_nn/thirdparty/ggml/LICENSE new file mode 100644 index 00000000..8c955688 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Georgi Gerganov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/plugins/wasi_nn/thirdparty/ggml/README.md b/plugins/wasi_nn/thirdparty/ggml/README.md new file mode 100644 index 00000000..594704db --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/README.md @@ -0,0 +1,10 @@ +# GGML and llama.cpp + +[GGML][] and [llama.cpp][] are open-source projects in the machine learning domain. GGML is a tensor library for machine learning, developed in C. On the other hand, llama.cpp serves as a LLaMA model inference engine and is implemented in C/C++. + +This directory contains the source code from both llama.cpp and GGML. The code in this directory is licensed under the MIT License. For more details, please refer to the [LICENSE](./LICENSE) file. + +WasmEdge includes support for GGML and llama.cpp through its WASI-NN plugin, enabling the execution of machine learning models in WebAssembly. Within the WasmEdge WASI-NN plugin, we have added functionality for GGML model loading and LLaMA model inference. + +[GGML]: http://ggml.ai +[llama.cpp]: https://github.com/ggerganov/ggml diff --git a/plugins/wasi_nn/thirdparty/ggml/common.cpp b/plugins/wasi_nn/thirdparty/ggml/common.cpp new file mode 100644 index 00000000..382f0058 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/common.cpp @@ -0,0 +1,1258 @@ +#include "common.h" +#include "llama.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) && defined(__MACH__) +#include +#include +#endif + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#include +#include +#include +#include +#else +#include +#include +#include +#endif + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +int32_t get_num_physical_cores() { +#ifdef __linux__ + // enumerate the set of thread siblings, num entries is num cores + std::unordered_set siblings; + for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) { + std::ifstream thread_siblings("/sys/devices/system/cpu" + + std::to_string(cpu) + "/topology/thread_siblings"); + if (!thread_siblings.is_open()) { + break; // no more cpus + } + std::string line; + if (std::getline(thread_siblings, line)) { + siblings.insert(line); + } + } + if (!siblings.empty()) { + return static_cast(siblings.size()); + } +#elif defined(__APPLE__) && defined(__MACH__) + int32_t num_physical_cores; + size_t len = sizeof(num_physical_cores); + int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0); + if (result == 0) { + return num_physical_cores; + } + result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0); + if (result == 0) { + return num_physical_cores; + } +#elif defined(_WIN32) + //TODO: Implement +#endif + unsigned int n_threads = std::thread::hardware_concurrency(); + return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; +} + +void process_escapes(std::string& input) { + std::size_t input_len = input.length(); + std::size_t output_idx = 0; + + for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { + if (input[input_idx] == '\\' && input_idx + 1 < input_len) { + switch (input[++input_idx]) { + case 'n': input[output_idx++] = '\n'; break; + case 'r': input[output_idx++] = '\r'; break; + case 't': input[output_idx++] = '\t'; break; + case '\'': input[output_idx++] = '\''; break; + case '\"': input[output_idx++] = '\"'; break; + case '\\': input[output_idx++] = '\\'; break; + default: input[output_idx++] = '\\'; + input[output_idx++] = input[input_idx]; break; + } + } else { + input[output_idx++] = input[input_idx]; + } + } + + input.resize(output_idx); +} + +bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { + bool invalid_param = false; + std::string arg; + gpt_params default_params; + const std::string arg_prefix = "--"; + + for (int i = 1; i < argc; i++) { + arg = argv[i]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + + if (arg == "-s" || arg == "--seed") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.seed = std::stoul(argv[i]); + } else if (arg == "-t" || arg == "--threads") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads = std::stoi(argv[i]); + if (params.n_threads <= 0) { + params.n_threads = std::thread::hardware_concurrency(); + } + } else if (arg == "-p" || arg == "--prompt") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.prompt = argv[i]; + } else if (arg == "-e" || arg == "--escape") { + params.escape = true; + } else if (arg == "--prompt-cache") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.path_prompt_cache = argv[i]; + } else if (arg == "--prompt-cache-all") { + params.prompt_cache_all = true; + } else if (arg == "--prompt-cache-ro") { + params.prompt_cache_ro = true; + } else if (arg == "-f" || arg == "--file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); + if (params.prompt.back() == '\n') { + params.prompt.pop_back(); + } + } else if (arg == "-n" || arg == "--n-predict") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_predict = std::stoi(argv[i]); + } else if (arg == "--top-k") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.top_k = std::stoi(argv[i]); + } else if (arg == "-c" || arg == "--ctx-size") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_ctx = std::stoi(argv[i]); + } else if (arg == "--rope-freq-base") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_base = std::stof(argv[i]); + } else if (arg == "--rope-freq-scale") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_scale = std::stof(argv[i]); + } else if (arg == "--rope-scale") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_scale = 1.0f/std::stof(argv[i]); + } else if (arg == "--memory-f32") { + params.memory_f16 = false; + } else if (arg == "--top-p") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.top_p = std::stof(argv[i]); + } else if (arg == "--temp") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.temp = std::stof(argv[i]); + } else if (arg == "--tfs") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.tfs_z = std::stof(argv[i]); + } else if (arg == "--typical") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.typical_p = std::stof(argv[i]); + } else if (arg == "--repeat-last-n") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.repeat_last_n = std::stoi(argv[i]); + } else if (arg == "--repeat-penalty") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.repeat_penalty = std::stof(argv[i]); + } else if (arg == "--frequency-penalty") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.frequency_penalty = std::stof(argv[i]); + } else if (arg == "--presence-penalty") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.presence_penalty = std::stof(argv[i]); + } else if (arg == "--mirostat") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.mirostat = std::stoi(argv[i]); + } else if (arg == "--mirostat-lr") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.mirostat_eta = std::stof(argv[i]); + } else if (arg == "--mirostat-ent") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.mirostat_tau = std::stof(argv[i]); + } else if (arg == "--cfg-negative-prompt") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.cfg_negative_prompt = argv[i]; + } else if (arg == "--cfg-negative-prompt-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.cfg_negative_prompt)); + if (params.cfg_negative_prompt.back() == '\n') { + params.cfg_negative_prompt.pop_back(); + } + } else if (arg == "--cfg-scale") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.cfg_scale = std::stof(argv[i]); + } else if (arg == "-b" || arg == "--batch-size") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_batch = std::stoi(argv[i]); + } else if (arg == "--keep") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_keep = std::stoi(argv[i]); + } else if (arg == "--draft") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_draft = std::stoi(argv[i]); + } else if (arg == "--chunks") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_chunks = std::stoi(argv[i]); + } else if (arg == "-m" || arg == "--model") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.model = argv[i]; + } else if (arg == "-md" || arg == "--model-draft") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.model_draft = argv[i]; + } else if (arg == "-a" || arg == "--alias") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.model_alias = argv[i]; + } else if (arg == "--lora") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.lora_adapter = argv[i]; + params.use_mmap = false; + } else if (arg == "--lora-base") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.lora_base = argv[i]; + } else if (arg == "-i" || arg == "--interactive") { + params.interactive = true; + } else if (arg == "--embedding") { + params.embedding = true; + } else if (arg == "--interactive-first") { + params.interactive_first = true; + } else if (arg == "-ins" || arg == "--instruct") { + params.instruct = true; + } else if (arg == "--multiline-input") { + params.multiline_input = true; + } else if (arg == "--simple-io") { + params.simple_io = true; + } else if (arg == "--color") { + params.use_color = true; + } else if (arg == "--mlock") { + params.use_mlock = true; + } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { + if (++i >= argc) { + invalid_param = true; + break; + } +#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD + params.n_gpu_layers = std::stoi(argv[i]); +#else + fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); + fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); +#endif + } else if (arg == "--main-gpu" || arg == "-mg") { + if (++i >= argc) { + invalid_param = true; + break; + } +#ifdef GGML_USE_CUBLAS + params.main_gpu = std::stoi(argv[i]); +#else + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n"); +#endif + } else if (arg == "--tensor-split" || arg == "-ts") { + if (++i >= argc) { + invalid_param = true; + break; + } +#ifdef GGML_USE_CUBLAS + std::string arg_next = argv[i]; + + // split string by , and / + const std::regex regex{R"([,/]+)"}; + std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; + std::vector split_arg{it, {}}; + GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES); + + for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) { + if (i < split_arg.size()) { + params.tensor_split[i] = std::stof(split_arg[i]); + } else { + params.tensor_split[i] = 0.0f; + } + } +#else + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n"); +#endif // GGML_USE_CUBLAS + } else if (arg == "--no-mul-mat-q" || arg == "-nommq") { +#ifdef GGML_USE_CUBLAS + params.mul_mat_q = false; +#else + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n"); +#endif // GGML_USE_CUBLAS + } else if (arg == "--low-vram" || arg == "-lv") { +#ifdef GGML_USE_CUBLAS + params.low_vram = true; +#else + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n"); +#endif // GGML_USE_CUBLAS + } else if (arg == "--no-mmap") { + params.use_mmap = false; + } else if (arg == "--mtest") { + params.mem_test = true; + } else if (arg == "--numa") { + params.numa = true; + } else if (arg == "--export") { + params.export_cgraph = true; + } else if (arg == "--verbose-prompt") { + params.verbose_prompt = true; + } else if (arg == "-r" || arg == "--reverse-prompt") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.antiprompt.push_back(argv[i]); + } else if (arg == "-ld" || arg == "--logdir") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.logdir = argv[i]; + + if (params.logdir.back() != DIRECTORY_SEPARATOR) { + params.logdir += DIRECTORY_SEPARATOR; + } + } else if (arg == "--perplexity") { + params.perplexity = true; + } else if (arg == "--ppl-stride") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.ppl_stride = std::stoi(argv[i]); + } else if (arg == "--ppl-output-type") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.ppl_output_type = std::stoi(argv[i]); + } else if (arg == "--hellaswag") { + params.hellaswag = true; + } else if (arg == "--hellaswag-tasks") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.hellaswag_tasks = std::stoi(argv[i]); + } else if (arg == "--ignore-eos") { + params.ignore_eos = true; + } else if (arg == "--no-penalize-nl") { + params.penalize_nl = false; + } else if (arg == "-l" || arg == "--logit-bias") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::stringstream ss(argv[i]); + llama_token key; + char sign; + std::string value_str; + try { + if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { + params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); + } else { + throw std::exception(); + } + } catch (const std::exception&) { + invalid_param = true; + break; + } + } else if (arg == "-h" || arg == "--help") { + gpt_print_usage(argc, argv, default_params); +#ifndef LOG_DISABLE_LOGS + log_print_usage(); +#endif // LOG_DISABLE_LOGS + exit(0); + } else if (arg == "--random-prompt") { + params.random_prompt = true; + } else if (arg == "--in-prefix-bos") { + params.input_prefix_bos = true; + } else if (arg == "--in-prefix") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.input_prefix = argv[i]; + } else if (arg == "--in-suffix") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.input_suffix = argv[i]; + } else if (arg == "--grammar") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.grammar = argv[i]; + } else if (arg == "--grammar-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(params.grammar) + ); +#ifndef LOG_DISABLE_LOGS + // Parse args for logging parameters + } else if ( log_param_single_parse( argv[i] ) ) { + // Do nothing, log_param_single_parse automatically does it's thing + // and returns if a match was found and parsed. + } else if ( log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i] ) ) { + // We have a matching known parameter requiring an argument, + // now we need to check if there is anything after this argv + // and flag invalid_param or parse it. + if (++i >= argc) { + invalid_param = true; + break; + } + if( !log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i-1], argv[i]) ) { + invalid_param = true; + break; + } + // End of Parse args for logging parameters +#endif // LOG_DISABLE_LOGS + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + gpt_print_usage(argc, argv, default_params); + exit(1); + } + } + if (invalid_param) { + fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); + gpt_print_usage(argc, argv, default_params); + exit(1); + } + if (params.prompt_cache_all && + (params.interactive || params.interactive_first || + params.instruct)) { + fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n"); + gpt_print_usage(argc, argv, default_params); + exit(1); + } + + if (params.escape) { + process_escapes(params.prompt); + process_escapes(params.input_prefix); + process_escapes(params.input_suffix); + } + + return true; +} + +void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { + printf("usage: %s [options]\n", argv[0]); + printf("\n"); + printf("options:\n"); + printf(" -h, --help show this help message and exit\n"); + printf(" -i, --interactive run in interactive mode\n"); + printf(" --interactive-first run in interactive mode and wait for input right away\n"); + printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); + printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); + printf(" -r PROMPT, --reverse-prompt PROMPT\n"); + printf(" halt generation at PROMPT, return control in interactive mode\n"); + printf(" (can be specified more than once for multiple prompts).\n"); + printf(" --color colorise output to distinguish prompt and user input from generations\n"); + printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); + printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + printf(" -p PROMPT, --prompt PROMPT\n"); + printf(" prompt to start generation with (default: empty)\n"); + printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); + printf(" --prompt-cache FNAME file to cache prompt state for faster startup (default: none)\n"); + printf(" --prompt-cache-all if specified, saves user input and generations to cache as well.\n"); + printf(" not supported with --interactive or other interactive options\n"); + printf(" --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n"); + printf(" --random-prompt start with a randomized prompt.\n"); + printf(" --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n"); + printf(" --in-prefix STRING string to prefix user inputs with (default: empty)\n"); + printf(" --in-suffix STRING string to suffix after user inputs with (default: empty)\n"); + printf(" -f FNAME, --file FNAME\n"); + printf(" prompt file to start generation.\n"); + printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); + printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); + printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); + printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); + printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); + printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z); + printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p); + printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n); + printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty); + printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty); + printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty); + printf(" --mirostat N use Mirostat sampling.\n"); + printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); + printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat); + printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta); + printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau); + printf(" -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n"); + printf(" modifies the likelihood of token appearing in the completion,\n"); + printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); + printf(" or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); + printf(" --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); + printf(" --grammar-file FNAME file to read grammar from\n"); + printf(" --cfg-negative-prompt PROMPT\n"); + printf(" negative prompt to use for guidance. (default: empty)\n"); + printf(" --cfg-negative-prompt-file FNAME\n"); + printf(" negative prompt file to use for guidance. (default: empty)\n"); + printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); + printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale); + printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base); + printf(" --rope-freq-scale N RoPE frequency linear scaling factor, inverse of --rope-scale (default: %g)\n", params.rope_freq_scale); + printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); + printf(" --no-penalize-nl do not penalize newline token\n"); + printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); + printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); + printf(" --temp N temperature (default: %.1f)\n", (double)params.temp); + printf(" --perplexity compute perplexity over each ctx window of the prompt\n"); + printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); + printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); + printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); + printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft); + printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); + if (llama_mlock_supported()) { + printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); + } + if (llama_mmap_supported()) { + printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); + } + printf(" --numa attempt optimizations that help on some NUMA systems\n"); + printf(" if run without this previously, it is recommended to drop the system page cache before using this\n"); + printf(" see https://github.com/ggerganov/llama.cpp/issues/1437\n"); +#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD + printf(" -ngl N, --n-gpu-layers N\n"); + printf(" number of layers to store in VRAM\n"); + printf(" -ts SPLIT --tensor-split SPLIT\n"); + printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); + printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n"); + printf(" -lv, --low-vram don't allocate VRAM scratch buffer\n"); +#ifdef GGML_USE_CUBLAS + printf(" -nommq, --no-mul-mat-q\n"); + printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n"); + printf(" Not recommended since this is both slower and uses more VRAM.\n"); +#endif // GGML_USE_CUBLAS +#endif + printf(" --mtest compute maximum memory usage\n"); + printf(" --export export the computation graph to 'llama.ggml'\n"); + printf(" --verbose-prompt print prompt before generation\n"); + fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); + printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); + printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); + printf(" -m FNAME, --model FNAME\n"); + printf(" model path (default: %s)\n", params.model.c_str()); + printf(" -md FNAME, --model-draft FNAME\n"); + printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str()); + printf(" -ld LOGDIR, --logdir LOGDIR\n"); + printf(" path under which to save YAML logs (no logging if unset)\n"); + printf("\n"); +} + +std::string gpt_random_prompt(std::mt19937 & rng) { + const int r = rng() % 10; + switch (r) { + case 0: return "So"; + case 1: return "Once upon a time"; + case 2: return "When"; + case 3: return "The"; + case 4: return "After"; + case 5: return "If"; + case 6: return "import"; + case 7: return "He"; + case 8: return "She"; + case 9: return "They"; + default: return "To"; + } + + return "The"; +} + +// +// Model utils +// + +struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { + auto lparams = llama_context_default_params(); + + lparams.n_ctx = params.n_ctx; + lparams.n_batch = params.n_batch; + if (params.n_gpu_layers != -1) { + lparams.n_gpu_layers = params.n_gpu_layers; + } + lparams.main_gpu = params.main_gpu; + lparams.tensor_split = params.tensor_split; + lparams.low_vram = params.low_vram; + lparams.mul_mat_q = params.mul_mat_q; + lparams.seed = params.seed; + lparams.f16_kv = params.memory_f16; + lparams.use_mmap = params.use_mmap; + lparams.use_mlock = params.use_mlock; + lparams.logits_all = params.perplexity; + lparams.embedding = params.embedding; + lparams.rope_freq_base = params.rope_freq_base; + lparams.rope_freq_scale = params.rope_freq_scale; + + return lparams; +} + +std::tuple llama_init_from_gpt_params(gpt_params & params) { + auto lparams = llama_context_params_from_gpt_params(params); + + llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams); + if (model == NULL) { + fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); + return std::make_tuple(nullptr, nullptr); + } + + llama_context * lctx = llama_new_context_with_model(model, lparams); + if (lctx == NULL) { + fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + + if (!params.lora_adapter.empty()) { + int err = llama_model_apply_lora_from_file(model, + params.lora_adapter.c_str(), + params.lora_base.empty() ? NULL : params.lora_base.c_str(), + params.n_threads); + if (err != 0) { + fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + } + + if (params.ignore_eos) { + params.logit_bias[llama_token_eos(lctx)] = -INFINITY; + } + + { + LOG("warming up the model with an empty run\n"); + + const std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; + llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads); + llama_reset_timings(lctx); + } + + return std::make_tuple(model, lctx); +} + +// +// Vocab utils +// + +std::vector llama_tokenize( + struct llama_context * ctx, + const std::string & text, + bool add_bos) { + // upper limit for the number of tokens + int n_tokens = text.length() + add_bos; + std::vector result(n_tokens); + n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + return result; +} + +std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) { + std::vector result(8, 0); + const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size()); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_token_to_piece(ctx, token, result.data(), result.size()); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + + return std::string(result.data(), result.size()); +} + +std::string llama_detokenize_spm(llama_context * ctx, const std::vector & tokens) { + const llama_token bos_id = llama_token_bos(ctx); + + std::string piece; + std::string result; + + for (size_t i = 0; i < tokens.size(); ++i) { + piece = llama_token_to_piece(ctx, tokens[i]); + + // remove the leading space of the first non-BOS token + if (((tokens[0] == bos_id && i == 1) || (tokens[0] != bos_id && i == 0)) && piece[0] == ' ') { + piece = piece.substr(1); + } + + result += piece; + } + + return result; +} + +std::string llama_detokenize_bpe(llama_context * ctx, const std::vector & tokens) { + std::string piece; + std::string result; + + for (size_t i = 0; i < tokens.size(); ++i) { + piece = llama_token_to_piece(ctx, tokens[i]); + + result += piece; + } + + return result; +} + +// +// Sampling utils +// + +llama_token llama_sample_token( + struct llama_context * ctx, + struct llama_context * ctx_guidance, + struct llama_grammar * grammar, + const struct gpt_params & params, + const std::vector & last_tokens, + std::vector & candidates, + int idx) { + const int n_ctx = llama_n_ctx(ctx); + const int n_vocab = llama_n_vocab(ctx); + + const float temp = params.temp; + const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; + const float top_p = params.top_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; + const float repeat_penalty = params.repeat_penalty; + const float alpha_presence = params.presence_penalty; + const float alpha_frequency = params.frequency_penalty; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; + const bool penalize_nl = params.penalize_nl; + + llama_token id = 0; + + float * logits = llama_get_logits(ctx) + idx * n_vocab; + + // Apply params.logit_bias map + for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { + logits[it->first] += it->second; + } + + candidates.clear(); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; + + if (ctx_guidance) { + llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale); + } + + // apply penalties + if (!last_tokens.empty()) { + const float nl_logit = logits[llama_token_nl(ctx)]; + const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx); + + llama_sample_repetition_penalty(ctx, &cur_p, + last_tokens.data() + last_tokens.size() - last_n_repeat, + last_n_repeat, repeat_penalty); + llama_sample_frequency_and_presence_penalties(ctx, &cur_p, + last_tokens.data() + last_tokens.size() - last_n_repeat, + last_n_repeat, alpha_frequency, alpha_presence); + + if (!penalize_nl) { + for (size_t idx = 0; idx < cur_p.size; idx++) { + if (cur_p.data[idx].id == llama_token_nl(ctx)) { + cur_p.data[idx].logit = nl_logit; + break; + } + } + } + } + + if (grammar != NULL) { + llama_sample_grammar(ctx, &cur_p, grammar); + } + + if (temp <= 0) { + // Greedy sampling + id = llama_sample_token_greedy(ctx, &cur_p); + } else { + if (mirostat == 1) { + static float mirostat_mu = 2.0f * mirostat_tau; + const int mirostat_m = 100; + llama_sample_temperature(ctx, &cur_p, temp); + id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); + } else if (mirostat == 2) { + static float mirostat_mu = 2.0f * mirostat_tau; + llama_sample_temperature(ctx, &cur_p, temp); + id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); + } else { + // Temperature sampling + llama_sample_top_k (ctx, &cur_p, top_k, 1); + llama_sample_tail_free (ctx, &cur_p, tfs_z, 1); + llama_sample_typical (ctx, &cur_p, typical_p, 1); + llama_sample_top_p (ctx, &cur_p, top_p, 1); + llama_sample_temperature(ctx, &cur_p, temp); + + { + const int n_top = 10; + LOG("top %d candidates:\n", n_top); + + for (int i = 0; i < n_top; i++) { + const llama_token id = cur_p.data[i].id; + LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p); + } + } + + id = llama_sample_token(ctx, &cur_p); + + LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str()); + } + } + // printf("`%d`", candidates_p.size); + + if (grammar != NULL) { + llama_grammar_accept_token(ctx, grammar, id); + } + + return id; +} + +// +// YAML utils +// + +// returns true if successful, false otherwise +bool create_directory_with_parents(const std::string & path) { +#ifdef _WIN32 + std::wstring_convert> converter; + std::wstring wpath = converter.from_bytes(path); + + // if the path already exists, check whether it's a directory + const DWORD attributes = GetFileAttributesW(wpath.c_str()); + if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return true; + } + + size_t pos_slash = 0; + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { + const std::wstring subpath = wpath.substr(0, pos_slash); + const wchar_t * test = subpath.c_str(); + + const bool success = CreateDirectoryW(test, NULL); + if (!success) { + const DWORD error = GetLastError(); + + // if the path already exists, ensure that it's a directory + if (error == ERROR_ALREADY_EXISTS) { + const DWORD attributes = GetFileAttributesW(subpath.c_str()); + if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return false; + } + } else { + return false; + } + } + + pos_slash += 1; + } + + return true; +#else + // if the path already exists, check whether it's a directory + struct stat info; + if (stat(path.c_str(), &info) == 0) { + return S_ISDIR(info.st_mode); + } + + size_t pos_slash = 1; // skip leading slashes for directory creation + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { + const std::string subpath = path.substr(0, pos_slash); + struct stat info; + + // if the path already exists, ensure that it's a directory + if (stat(subpath.c_str(), &info) == 0) { + if (!S_ISDIR(info.st_mode)) { + return false; + } + } else { + // create parent directories + const int ret = mkdir(subpath.c_str(), 0755); + if (ret != 0) { + return false; + } + } + + pos_slash += 1; + } + + return true; +#endif // _WIN32 +} + +void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector & data) { + if (data.empty()) { + fprintf(stream, "%s:\n", prop_name); + return; + } + + fprintf(stream, "%s: [", prop_name); + for (size_t i = 0; i < data.size() - 1; ++i) { + fprintf(stream, "%e, ", data[i]); + } + fprintf(stream, "%e]\n", data.back()); +} + +void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector & data) { + if (data.empty()) { + fprintf(stream, "%s:\n", prop_name); + return; + } + + fprintf(stream, "%s: [", prop_name); + for (size_t i = 0; i < data.size() - 1; ++i) { + fprintf(stream, "%d, ", data[i]); + } + fprintf(stream, "%d]\n", data.back()); +} + +void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data) { + std::string data_str(data == NULL ? "" : data); + + if (data_str.empty()) { + fprintf(stream, "%s:\n", prop_name); + return; + } + + size_t pos_start = 0; + size_t pos_found = 0; + + if (!data_str.empty() && (std::isspace(data_str[0]) || std::isspace(data_str.back()))) { + data_str = std::regex_replace(data_str, std::regex("\n"), "\\n"); + data_str = std::regex_replace(data_str, std::regex("\""), "\\\""); + data_str = "\"" + data_str + "\""; + fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); + return; + } + + if (data_str.find('\n') == std::string::npos) { + fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); + return; + } + + fprintf(stream, "%s: |\n", prop_name); + while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) { + fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str()); + pos_start = pos_found + 1; + } +} + +std::string get_sortable_timestamp() { + using clock = std::chrono::system_clock; + + const clock::time_point current_time = clock::now(); + const time_t as_time_t = clock::to_time_t(current_time); + char timestamp_no_ns[100]; + std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t)); + + const int64_t ns = std::chrono::duration_cast( + current_time.time_since_epoch() % 1000000000).count(); + char timestamp_ns[11]; + snprintf(timestamp_ns, 11, "%09" PRId64, ns); + + return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns); +} + +void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx, + const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { + fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); + fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false"); + fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false"); + fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); + fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false"); + fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false"); + fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false"); + fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false"); + fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false"); + fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false"); + fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false"); + fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false"); + fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); + fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false"); + fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false"); + +#ifdef NDEBUG + fprintf(stream, "debug: false\n"); +#else + fprintf(stream, "debug: true\n"); +#endif // NDEBUG + + fprintf(stream, "model_desc: %s\n", model_desc); + fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(lctx)); + +#ifdef __OPTIMIZE__ + fprintf(stream, "optimize: true\n"); +#else + fprintf(stream, "optimize: false\n"); +#endif // __OPTIMIZE__ + + fprintf(stream, "time: %s\n", timestamp.c_str()); + + fprintf(stream, "\n"); + fprintf(stream, "###############\n"); + fprintf(stream, "# User Inputs #\n"); + fprintf(stream, "###############\n"); + fprintf(stream, "\n"); + + fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); + fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); + dump_string_yaml_multiline(stream, "cfg_negative_prompt", params.cfg_negative_prompt.c_str()); + fprintf(stream, "cfg_scale: %f # default: 1.0\n", params.cfg_scale); + fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); + fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); + fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); + fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); + fprintf(stream, "export: %s # default: false\n", params.export_cgraph ? "true" : "false"); + fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); + fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty); + dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str()); + fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); + fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); + fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); + + const auto logit_bias_eos = params.logit_bias.find(llama_token_eos(lctx)); + const bool ignore_eos = logit_bias_eos != params.logit_bias.end() && logit_bias_eos->second == -INFINITY; + fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); + + dump_string_yaml_multiline(stream, "in_prefix", params.input_prefix.c_str()); + fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); + dump_string_yaml_multiline(stream, "in_suffix", params.input_prefix.c_str()); + fprintf(stream, "instruct: %s # default: false\n", params.instruct ? "true" : "false"); + fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false"); + fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false"); + fprintf(stream, "keep: %d # default: 0\n", params.n_keep); + fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); + + fprintf(stream, "logit_bias:\n"); + for (std::pair lb : params.logit_bias) { + if (ignore_eos && lb.first == logit_bias_eos->first) { + continue; + } + fprintf(stream, " %d: %f", lb.first, lb.second); + } + + fprintf(stream, "lora: %s\n", params.lora_adapter.c_str()); + fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); + fprintf(stream, "low_vram: %s # default: false\n", params.low_vram ? "true" : "false"); + fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); + fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false"); + fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat); + fprintf(stream, "mirostat_ent: %f # default: 5.0\n", params.mirostat_tau); + fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta); + fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); + fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str()); + fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); + fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false"); + fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); + fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); + fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); + fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", params.n_probs); + fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); + fprintf(stream, "no_mul_mat_q: %s # default: false\n", !params.mul_mat_q ? "true" : "false"); + fprintf(stream, "no_penalize_nl: %s # default: false\n", !params.penalize_nl ? "true" : "false"); + fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false"); + fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); + fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); + fprintf(stream, "presence_penalty: %f # default: 0.0\n", params.presence_penalty); + dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str()); + fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); + fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); + fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); + dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens); + fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false"); + fprintf(stream, "repeat_penalty: %f # default: 1.1\n", params.repeat_penalty); + + fprintf(stream, "reverse_prompt:\n"); + for (std::string ap : params.antiprompt) { + size_t pos = 0; + while ((pos = ap.find('\n', pos)) != std::string::npos) { + ap.replace(pos, 1, "\\n"); + pos += 1; + } + + fprintf(stream, " - %s\n", ap.c_str()); + } + + fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); + fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); + fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed); + fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); + fprintf(stream, "temp: %f # default: 0.8\n", params.temp); + + const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES); + dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector); + + fprintf(stream, "tfs: %f # default: 1.0\n", params.tfs_z); + fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); + fprintf(stream, "top_k: %d # default: 40\n", params.top_k); + fprintf(stream, "top_p: %f # default: 0.95\n", params.top_p); + fprintf(stream, "typical_p: %f # default: 1.0\n", params.typical_p); + fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); +} diff --git a/plugins/wasi_nn/thirdparty/ggml/common.h b/plugins/wasi_nn/thirdparty/ggml/common.h new file mode 100644 index 00000000..012bf5e1 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/common.h @@ -0,0 +1,206 @@ +// Various helper functions and utilities + +#pragma once + +#include "llama.h" + +#define LOG_NO_FILE_LINE_FUNCTION +#include "log.h" + +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#define DIRECTORY_SEPARATOR '\\' +#else +#define DIRECTORY_SEPARATOR '/' +#endif // _WIN32 + +#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0) +#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", ##__VA_ARGS__); exit(1); } while (0) + +// +// CLI argument parsing +// +int32_t get_num_physical_cores(); + +struct gpt_params { + uint32_t seed = -1; // RNG seed + int32_t n_threads = get_num_physical_cores(); + int32_t n_predict = -1; // new tokens to predict + int32_t n_ctx = 512; // context size + int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_draft = 16; // number of tokens to draft during speculative decoding + int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t n_beams = 0; // if non-zero then use beam search of given width. + float rope_freq_base = 10000.0f; // RoPE base frequency + float rope_freq_scale = 1.0f; // RoPE frequency scaling factor + + // sampling parameters + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typical_p = 1.00f; // 1.0 = disabled + float temp = 0.80f; // 1.0 = disabled + float repeat_penalty = 1.10f; // 1.0 = disabled + int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float frequency_penalty = 0.00f; // 0.0 = disabled + float presence_penalty = 0.00f; // 0.0 = disabled + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + + std::unordered_map logit_bias; // logit bias for specific tokens + + // Classifier-Free Guidance + // https://arxiv.org/abs/2306.17806 + std::string cfg_negative_prompt; // string to help guidance + float cfg_scale = 1.f; // How strong is guidance + + std::string model = "models/7B/ggml-model-f16.gguf"; // model path + std::string model_draft = ""; // draft model for speculative decoding + std::string model_alias = "unknown"; // model alias + std::string prompt = ""; + std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state + std::string input_prefix = ""; // string to prefix user inputs with + std::string input_suffix = ""; // string to suffix user inputs with + std::string grammar = ""; // optional BNF-like grammar to constrain sampling + std::vector antiprompt; // string upon seeing which more user input is prompted + std::string logdir = ""; // directory in which to save YAML log files + + std::string lora_adapter = ""; // lora adapter path + std::string lora_base = ""; // base model path for the lora adapter + + int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. + int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line + // (which is more convenient to use for plotting) + // + bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt + size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score + + bool low_vram = false; // if true, reduce VRAM usage at the cost of performance + bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS + bool memory_f16 = true; // use f16 instead of f32 for memory kv + bool random_prompt = false; // do not randomize prompt if none provided + bool use_color = false; // use color to distinguish generations and inputs + bool interactive = false; // interactive mode + bool prompt_cache_all = false; // save user input and generations to prompt cache + bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it + + bool embedding = false; // get only sentence embedding + bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\" + bool interactive_first = false; // wait for user input immediately + bool multiline_input = false; // reverse the usage of `\` + bool simple_io = false; // improves compatibility with subprocesses and limited consoles + + bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix + bool ignore_eos = false; // ignore generated EOS tokens + bool instruct = false; // instruction mode (used for Alpaca models) + bool penalize_nl = true; // consider newlines as a repeatable token + bool perplexity = false; // compute perplexity over the prompt + bool use_mmap = true; // use mmap for faster loads + bool use_mlock = false; // use mlock to keep model in memory + bool mem_test = false; // compute maximum memory usage + bool numa = false; // attempt optimizations that help on some NUMA systems + bool export_cgraph = false; // export the computation graph + bool verbose_prompt = false; // print prompt tokens before generation +}; + +bool gpt_params_parse(int argc, char ** argv, gpt_params & params); + +void gpt_print_usage(int argc, char ** argv, const gpt_params & params); + +std::string gpt_random_prompt(std::mt19937 & rng); + +// +// Model utils +// + +std::tuple llama_init_from_gpt_params(gpt_params & params); +struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); + +// +// Vocab utils +// + +// tokenizes a string into a vector of tokens +// should work similar to Python's `tokenizer.encode` +std::vector llama_tokenize( + struct llama_context * ctx, + const std::string & text, + bool add_bos); + +// tokenizes a token into a piece +// should work similar to Python's `tokenizer.id_to_piece` +std::string llama_token_to_piece( + const struct llama_context * ctx, + llama_token token); + +// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function +// that takes into account the tokenizer type and decides how to handle the leading space +// +// detokenizes a vector of tokens into a string +// should work similar to Python's `tokenizer.decode` +// removes the leading space from the first non-BOS token +std::string llama_detokenize_spm( + llama_context * ctx, + const std::vector & tokens); + +// detokenizes a vector of tokens into a string +// should work similar to Python's `tokenizer.decode` +std::string llama_detokenize_bpe( + llama_context * ctx, + const std::vector & tokens); + +// +// Sampling utils +// + +// this is a common sampling function used across the examples for convenience +// it can serve as a starting point for implementing your own sampling function +// +// required: +// - ctx: context to use for sampling +// - params: sampling parameters +// +// optional: +// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL +// - grammar: grammar to use for sampling, ignore if NULL +// - last_tokens: needed for repetition penalty, ignore if empty +// - idx: sample from llama_get_logits(ctx) + idx * n_vocab +// +// returns: +// - token: sampled token +// - candidates: vector of candidate tokens +// +llama_token llama_sample_token( + struct llama_context * ctx, + struct llama_context * ctx_guidance, + struct llama_grammar * grammar, + const struct gpt_params & params, + const std::vector & last_tokens, + std::vector & candidates, + int idx = 0); + +// +// YAML utils +// + +bool create_directory_with_parents(const std::string & path); +void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector & data); +void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector & data); +void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data); +std::string get_sortable_timestamp(); + +void dump_non_result_info_yaml( + FILE * stream, const gpt_params & params, const llama_context * lctx, + const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c new file mode 100644 index 00000000..a1f6e7bf --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c @@ -0,0 +1,633 @@ +#include "ggml-alloc.h" +#include "ggml.h" +#include +#include +#include +#include +#include + +#ifdef __has_include + #if __has_include() + #include + #if defined(_POSIX_MAPPED_FILES) + #include + #include + #endif + #endif +#endif + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #ifndef NOMINMAX + #define NOMINMAX + #endif + #include + #include +#endif + + +#define UNUSED(x) (void)(x) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define GGML_MAX_CONCUR (2*GGML_MAX_NODES) + +//#define GGML_ALLOCATOR_DEBUG + +//#define AT_PRINTF printf +#define AT_PRINTF(...) ((void)0) + +struct hash_node { + struct ggml_tensor * t; + int n_children; + int n_views; +}; + +static size_t hash(void * p) { + return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE; +} + +static struct hash_node * hash_get(struct hash_node hash_table[], struct ggml_tensor * t) { + size_t h = hash(t); + + // linear probing + size_t i = h; + while (hash_table[i].t != NULL) { + if (hash_table[i].t == t) { + return &hash_table[i]; + } + i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE; + if (i == h) { + // hash table is full + GGML_ASSERT(false); + } + } + + hash_table[i].t = t; + return &hash_table[i]; +} + +// TODO: GGML_PAD ? +static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) { + assert(alignment && !(alignment & (alignment - 1))); // power of 2 + size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment; + return offset + align; +} + +struct free_block { + void * addr; + size_t size; +}; + +#define MAX_FREE_BLOCKS 128 + +struct ggml_allocr { + void * data; + size_t size; + size_t alignment; + int n_free_blocks; + struct free_block free_blocks[MAX_FREE_BLOCKS]; + struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE]; + size_t max_size; + bool measure; + int parse_seq[GGML_MAX_CONCUR]; + int parse_seq_len; + +#ifdef GGML_ALLOCATOR_DEBUG + struct ggml_tensor * allocated_tensors[1024]; +#endif +}; + +#ifdef GGML_ALLOCATOR_DEBUG +static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { + for (int i = 0; i < 1024; i++) { + if (alloc->allocated_tensors[i] == NULL) { + alloc->allocated_tensors[i] = tensor; + return; + } + } + GGML_ASSERT(!"out of allocated_tensors"); +} +static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { + for (int i = 0; i < 1024; i++) { + if (alloc->allocated_tensors[i] == tensor || + (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) { + alloc->allocated_tensors[i] = NULL; + return; + } + } + printf("tried to free tensor %s not found\n", tensor->name); + GGML_ASSERT(!"tensor not found"); +} +#endif + +static size_t ggml_allocr_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { + return ggml_nbytes(tensor); + + UNUSED(alloc); +} + +// check if a tensor is allocated by this buffer +static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) { + void * ptr = tensor->data; + return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size; +} + +void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { +#ifdef GGML_ALLOCATOR_DEBUG + GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources + GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated +#endif + size_t size = ggml_allocr_get_alloc_size(alloc, tensor); + size = aligned_offset(NULL, size, alloc->alignment); + + AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size); + + size_t max_avail = 0; + + // find the best fitting free block besides the last block + int best_fit_block = -1; + size_t best_fit_size = SIZE_MAX; + for (int i = 0; i < alloc->n_free_blocks - 1; i++) { + struct free_block * block = &alloc->free_blocks[i]; + max_avail = MAX(max_avail, block->size); + if (block->size >= size && block->size <= best_fit_size) { + best_fit_block = i; + best_fit_size = block->size; + } + } + + AT_PRINTF("block %d\n", best_fit_block); + + if (best_fit_block == -1) { + // the last block is our last resort + struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1]; + max_avail = MAX(max_avail, block->size); + if (block->size >= size) { + best_fit_block = alloc->n_free_blocks - 1; + } else { + fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n", + __func__, size, max_avail); + GGML_ASSERT(!"not enough space in the buffer"); + return; + } + } + struct free_block * block = &alloc->free_blocks[best_fit_block]; + void * addr = block->addr; + block->addr = (char*)block->addr + size; + block->size -= size; + if (block->size == 0) { + // remove block if empty + alloc->n_free_blocks--; + for (int j = best_fit_block; j < alloc->n_free_blocks; j++) { + alloc->free_blocks[j] = alloc->free_blocks[j+1]; + } + } + + tensor->data = addr; + +#ifdef GGML_ALLOCATOR_DEBUG + add_allocated_tensor(alloc, tensor); + size_t cur_max = (char*)addr - (char*)alloc->data + size; + if (cur_max > alloc->max_size) { + printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0); + for (int i = 0; i < 1024; i++) { + if (alloc->allocated_tensors[i]) { + printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, ggml_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0); + } + } + printf("\n"); + } +#endif + + alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size); +} + +// this is a very naive implementation, but for our case the number of free blocks should be very small +static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { + void * ptr = tensor->data; + + if (ggml_allocr_is_own(alloc, tensor) == false) { + // the tensor was not allocated in this buffer + // this can happen because the graph allocator will try to free weights and other tensors from different buffers + // the easiest way to deal with this is just to ignore it + return; + } + + size_t size = ggml_allocr_get_alloc_size(alloc, tensor); + size = aligned_offset(NULL, size, alloc->alignment); + AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks); + +#ifdef GGML_ALLOCATOR_DEBUG + remove_allocated_tensor(alloc, tensor); +#endif + + // see if we can merge with an existing block + for (int i = 0; i < alloc->n_free_blocks; i++) { + struct free_block * block = &alloc->free_blocks[i]; + // check if ptr is at the end of the block + if ((char*)block->addr + block->size == ptr) { + block->size += size; + // check if we can merge with the next block + if (i < alloc->n_free_blocks - 1 && (char*)block->addr + block->size == alloc->free_blocks[i+1].addr) { + block->size += alloc->free_blocks[i+1].size; + alloc->n_free_blocks--; + for (int j = i+1; j < alloc->n_free_blocks; j++) { + alloc->free_blocks[j] = alloc->free_blocks[j+1]; + } + } + return; + } + // check if ptr is at the beginning of the block + if ((char*)ptr + size == block->addr) { + block->addr = ptr; + block->size += size; + // check if we can merge with the previous block + if (i > 0 && (char*)alloc->free_blocks[i-1].addr + alloc->free_blocks[i-1].size == block->addr) { + alloc->free_blocks[i-1].size += block->size; + alloc->n_free_blocks--; + for (int j = i; j < alloc->n_free_blocks; j++) { + alloc->free_blocks[j] = alloc->free_blocks[j+1]; + } + } + return; + } + } + // otherwise, add a new block + GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks"); + // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster) + int insert_pos = 0; + while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].addr < ptr) { + insert_pos++; + } + // shift all blocks from insert_pos onward to make room for the new block + for (int i = alloc->n_free_blocks; i > insert_pos; i--) { + alloc->free_blocks[i] = alloc->free_blocks[i-1]; + } + // insert the new block + alloc->free_blocks[insert_pos].addr = ptr; + alloc->free_blocks[insert_pos].size = size; + alloc->n_free_blocks++; +} + +void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) { + for (int i = 0; i < n; i++) { + alloc->parse_seq[i] = list[i]; + } + alloc->parse_seq_len = n; +} + +void ggml_allocr_reset(struct ggml_allocr * alloc) { + alloc->n_free_blocks = 1; + size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment); + alloc->free_blocks[0].addr = (char *)alloc->data + align_offset; + alloc->free_blocks[0].size = alloc->size - align_offset; +} + +struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) { + struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */); + + *alloc = (struct ggml_allocr){ + /*.data = */ data, + /*.size = */ size, + /*.alignment = */ alignment, + /*.n_free_blocks = */ 0, + /*.free_blocks = */ {{0}}, + /*.hash_table = */ {{0}}, + /*.max_size = */ 0, + /*.measure = */ false, + /*.parse_seq = */ {0}, + /*.parse_seq_len = */ 0, +#ifdef GGML_ALLOCATOR_DEBUG + /*.allocated_tensors = */ {0}, +#endif + }; + + ggml_allocr_reset(alloc); + + return alloc; +} + +// OS specific functions to allocate and free uncommitted virtual memory +static void * alloc_vmem(size_t size) { +#if defined(_WIN32) + return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS); +#elif defined(_POSIX_MAPPED_FILES) + void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0); + if (ptr == MAP_FAILED) { + return NULL; + } + return ptr; +#else + // use a fixed address for other platforms + uintptr_t base_addr = (uintptr_t)-size - 0x100; + return (void *)base_addr; +#endif +} + +static void free_vmem(void * base_addr, size_t size) { +#if defined(_WIN32) + VirtualFree(base_addr, 0, MEM_RELEASE); + UNUSED(size); +#elif defined(_POSIX_MAPPED_FILES) + munmap(base_addr, size); +#else + // nothing to do + UNUSED(base_addr); + UNUSED(size); +#endif +} + +// allocate uncommitted virtual memory to measure the size of the graph +static void alloc_measure_vmem(void ** base_addr, size_t * size) { + // 1TB for 64-bit, 1GB for 32-bit + *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<40; + do { + *base_addr = alloc_vmem(*size); + if (*base_addr != NULL) { + AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr); + return; + } + // try again with half the size + *size /= 2; + } while (*size > 0); + + GGML_ASSERT(!"failed to allocate virtual memory for measure buffer"); +} + +static void free_measure_vmem(void * base_addr, size_t size) { + free_vmem(base_addr, size); +} + +struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) { + struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */); + + void * base_addr; + size_t size; + + alloc_measure_vmem(&base_addr, &size); + + *alloc = (struct ggml_allocr){ + /*.data = */ base_addr, + /*.size = */ size, + /*.alignment = */ alignment, + /*.n_free_blocks = */ 0, + /*.free_blocks = */ {{0}}, + /*.hash_table = */ {{0}}, + /*.max_size = */ 0, + /*.measure = */ true, + /*.parse_seq = */ {0}, + /*.parse_seq_len = */ 0, +#ifdef GGML_ALLOCATOR_DEBUG + /*.allocated_tensors = */ {0}, +#endif + }; + + ggml_allocr_reset(alloc); + + return alloc; +} + +void ggml_allocr_free(struct ggml_allocr * alloc) { + if (alloc->measure) { + free_measure_vmem(alloc->data, alloc->size); + } + free(alloc); +} + +bool ggml_allocr_is_measure(struct ggml_allocr * alloc) { + return alloc->measure; +} + +//////////// compute graph allocator + +static bool ggml_is_view(struct ggml_tensor * t) { + return t->view_src != NULL; +} + +static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { + if (a->type != b->type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (a->ne[i] != b->ne[i]) { + return false; + } + if (a->nb[i] != b->nb[i]) { + return false; + } + } + return true; +} + +static bool ggml_op_can_inplace(enum ggml_op op) { + switch (op) { + case GGML_OP_SCALE: + case GGML_OP_DIAG_MASK_ZERO: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_ADD: + case GGML_OP_ADD1: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_UNARY: + case GGML_OP_ROPE: + case GGML_OP_RMS_NORM: + case GGML_OP_SOFT_MAX: + case GGML_OP_CONT: + return true; + + default: + return false; + } +} + +static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) { + struct hash_node * ht = alloc->hash_table; + if (node->data == NULL) { + if (ggml_is_view(node)) { + assert(node->view_src->data != NULL); + node->data = (char *)node->view_src->data + node->view_offs; + } else { + // see if we can reuse a parent's buffer (inplace) + if (ggml_op_can_inplace(node->op)) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + struct ggml_tensor * parent = node->src[i]; + if (parent == NULL) { + break; + } + + // if the node's data is external, then we cannot re-use it + if (ggml_allocr_is_own(alloc, parent) == false) { + AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data); + continue; + } + + struct hash_node * p_hn = hash_get(ht, parent); + if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) { + if (ggml_is_view(parent)) { + struct ggml_tensor * view_src = parent->view_src; + struct hash_node * view_src_hn = hash_get(ht, view_src); + if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) { + // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite + // the parent's data that it will need later (same layout requirement). the problem is that then + // we cannot free the tensor because the original address of the allocation is lost. + // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views + // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data) + AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name); + node->data = parent->data; + return; + } + } + else { + AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name); + node->data = parent->data; + return; + } + } + } + } + ggml_allocr_alloc(alloc, node); + } + } +} + +static size_t ggml_allocr_alloc_graph_tensors_n( + struct ggml_allocr * alloc, + struct ggml_cgraph ** graphs, int n_graphs, + struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) { + + // reset hash table + struct hash_node * ht = alloc->hash_table; + memset(ht, 0, sizeof(struct hash_node) * GGML_GRAPH_HASHTABLE_SIZE); + + // count number of children and views + for (int g = 0; g < n_graphs; g++) { + struct ggml_cgraph * gf = graphs[g]; + for (int i = 0; i < gf->n_nodes; i++) { + struct ggml_tensor * node = gf->nodes[i]; + + if (ggml_is_view(node)) { + struct ggml_tensor * view_src = node->view_src; + hash_get(ht, view_src)->n_views += 1; + } + + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + break; + } + hash_get(ht, parent)->n_children += 1; + } + } + } + + // allocate tensors + for (int g = 0; g < n_graphs; g++) { + struct ggml_cgraph * gf = graphs[g]; + AT_PRINTF("####### graph %d/%d\n", g, n_graphs); + // graph inputs are allocated first to ensure that they are not overwritten by each other + if (inputs != NULL && inputs[g] != NULL) { + for (int i = 0; inputs[g][i] != NULL; i++) { + struct ggml_tensor * input = inputs[g][i]; + AT_PRINTF("input: %s\n", input->name); + allocate_node(alloc, input); + } + } + // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers + int last_barrier_pos = 0; + int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes; + + for (int ind = 0; ind < n_nodes; ind++) { + // allocate a node if there is no parse_seq or this is not a barrier + if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) { + int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind; + struct ggml_tensor * node = gf->nodes[i]; + + // allocate parents (leafs) + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + break; + } + allocate_node(alloc, parent); + } + + // allocate node + allocate_node(alloc, node); + + AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name); + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + break; + } + AT_PRINTF("%s", parent->name); + if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) { + AT_PRINTF(", "); + } + } + AT_PRINTF("\n"); + } + + // update parents + // update immediately if there is no parse_seq + // update only at barriers if there is parse_seq + if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) { + int update_start = alloc->parse_seq_len ? last_barrier_pos : ind; + int update_end = alloc->parse_seq_len ? ind : ind + 1; + for (int i = update_start; i < update_end; i++) { + int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i; + struct ggml_tensor * node = gf->nodes[node_i]; + + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + break; + } + struct hash_node * p_hn = hash_get(ht, parent); + p_hn->n_children -= 1; + + //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views); + + if (p_hn->n_children == 0 && p_hn->n_views == 0) { + if (ggml_is_view(parent)) { + struct ggml_tensor * view_src = parent->view_src; + struct hash_node * view_src_hn = hash_get(ht, view_src); + view_src_hn->n_views -= 1; + AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views); + if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) { + ggml_allocr_free_tensor(alloc, view_src); + } + } + else { + if (parent->data != node->data) { + ggml_allocr_free_tensor(alloc, parent); + } + } + } + } + } + AT_PRINTF("\n"); + if (alloc->parse_seq_len) { + last_barrier_pos = ind + 1; + } + } + } + // free graph outputs here that wouldn't be freed otherwise because they have no children + if (outputs != NULL && outputs[g] != NULL) { + for (int i = 0; outputs[g][i] != NULL; i++) { + struct ggml_tensor * output = outputs[g][i]; + AT_PRINTF("output: %s\n", output->name); + ggml_allocr_free_tensor(alloc, output); + } + } + } + + return alloc->max_size; +} + +size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) { + return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL); +} diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h new file mode 100644 index 00000000..9559da75 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h @@ -0,0 +1,26 @@ +#pragma once + +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + + +GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment); +GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment); + +// tell the allocator to parse nodes following the order described in the list +// you should call this if your graph are optimized to execute out-of-order +GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n); + +GGML_API void ggml_allocr_free(struct ggml_allocr * alloc); +GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc); +GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc); +GGML_API void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor); +GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph); + + +#ifdef __cplusplus +} +#endif diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml.c b/plugins/wasi_nn/thirdparty/ggml/ggml.c new file mode 100644 index 00000000..a9cffb43 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/ggml.c @@ -0,0 +1,20812 @@ +#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows + +#include "ggml.h" + +#ifdef GGML_USE_K_QUANTS +#include "k_quants.h" +#endif + +#if defined(_MSC_VER) || defined(__MINGW32__) +#include // using malloc.h with MSC/MINGW +#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef GGML_USE_METAL +#include +#endif + +// static_assert should be a #define, but if it's not, +// fall back to the _Static_assert C11 keyword. +// if C99 - static_assert is noop +// ref: https://stackoverflow.com/a/53923785/4039976 +#ifndef static_assert +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) +#define static_assert(cond, msg) _Static_assert(cond, msg) +#else +#define static_assert(cond, msg) struct global_scope_noop_trick +#endif +#endif + +#if defined(_MSC_VER) +// disable "possible loss of data" to avoid hundreds of casts +// we should just be careful :) +#pragma warning(disable: 4244 4267) + +// disable POSIX deprecation warnigns +// these functions are never going away, anyway +#pragma warning(disable: 4996) +#endif + +#if defined(_WIN32) + +#include + +typedef volatile LONG atomic_int; +typedef atomic_int atomic_bool; + +static void atomic_store(atomic_int * ptr, LONG val) { + InterlockedExchange(ptr, val); +} +static LONG atomic_load(atomic_int * ptr) { + return InterlockedCompareExchange(ptr, 0, 0); +} +static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) { + return InterlockedExchangeAdd(ptr, inc); +} +static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) { + return atomic_fetch_add(ptr, -(dec)); +} + +typedef HANDLE pthread_t; + +typedef DWORD thread_ret_t; +static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) { + (void) unused; + HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); + if (handle == NULL) + { + return EAGAIN; + } + + *out = handle; + return 0; +} + +static int pthread_join(pthread_t thread, void * unused) { + (void) unused; + return (int) WaitForSingleObject(thread, INFINITE); +} + +static int sched_yield (void) { + Sleep (0); + return 0; +} +#else +#include +#include + +typedef void * thread_ret_t; + +#include +#include +#include + +#endif +#ifdef GGML_USE_CPU_HBM +#include +#endif + +// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 +#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) +#ifndef __FMA__ +#define __FMA__ +#endif +#ifndef __F16C__ +#define __F16C__ +#endif +#ifndef __SSE3__ +#define __SSE3__ +#endif +#endif + +/*#define GGML_PERF*/ +#define GGML_DEBUG 0 +#define GGML_GELU_FP16 +#define GGML_GELU_QUICK_FP16 +#define GGML_SILU_FP16 +// #define GGML_CROSS_ENTROPY_EXP_FP16 +// #define GGML_FLASH_ATTN_EXP_FP16 + +#define GGML_SOFT_MAX_UNROLL 4 +#define GGML_VEC_DOT_UNROLL 2 + +// +// logging +// + +#if (GGML_DEBUG >= 1) +#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG(...) +#endif + +#if (GGML_DEBUG >= 5) +#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_5(...) +#endif + +#if (GGML_DEBUG >= 10) +#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_10(...) +#endif + +#define GGML_PRINT(...) printf(__VA_ARGS__) + +#ifdef GGML_USE_ACCELERATE +// uncomment to use vDSP for soft max computation +// note: not sure if it is actually faster +//#define GGML_SOFT_MAX_ACCELERATE +#endif + +// +// logging +// + +#if (GGML_DEBUG >= 1) +#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG(...) +#endif + +#if (GGML_DEBUG >= 5) +#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_5(...) +#endif + +#if (GGML_DEBUG >= 10) +#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_10(...) +#endif + +#define GGML_PRINT(...) printf(__VA_ARGS__) + +// +// end of logging block +// + +#if defined(_MSC_VER) || defined(__MINGW32__) +#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN) +#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr) +#else +inline static void * ggml_aligned_malloc(size_t size) { + if (size == 0) { + GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n"); + return NULL; + } + void * aligned_memory = NULL; +#ifdef GGML_USE_CPU_HBM + int result = hbw_posix_memalign(&aligned_memory, 16, size); +#elif GGML_USE_METAL + int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size); +#else + int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size); +#endif + if (result != 0) { + // Handle allocation failure + const char *error_desc = "unknown allocation error"; + switch (result) { + case EINVAL: + error_desc = "invalid alignment value"; + break; + case ENOMEM: + error_desc = "insufficient memory"; + break; + } + GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0)); + return NULL; + } + return aligned_memory; +} +#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size) +#ifdef GGML_USE_CPU_HBM +#define GGML_ALIGNED_FREE(ptr) if(NULL != ptr) hbw_free(ptr) +#else +#define GGML_ALIGNED_FREE(ptr) free(ptr) +#endif +#endif + +#define UNUSED GGML_UNUSED +#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) + +// +// tensor access macros +// + +#define GGML_TENSOR_UNARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + +#define GGML_TENSOR_BINARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + +#if defined(GGML_USE_ACCELERATE) +#include +#if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions +#include "ggml-opencl.h" +#endif +#elif defined(GGML_USE_OPENBLAS) +#if defined(GGML_BLAS_USE_MKL) +#include +#else +#include +#endif +#elif defined(GGML_USE_CUBLAS) +#include "ggml-cuda.h" +#elif defined(GGML_USE_CLBLAST) +#include "ggml-opencl.h" +#endif + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// floating point type used to accumulate sums +typedef double ggml_float; + +// 16-bit float +// on Arm, we use __fp16 +// on x86, we use uint16_t +#if defined(__ARM_NEON) && !defined(_MSC_VER) + +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include + +#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) +#define GGML_COMPUTE_FP32_TO_FP16(x) (x) + +#define GGML_FP16_TO_FP32(x) ((float) (x)) +#define GGML_FP32_TO_FP16(x) (x) + +#else + +#ifdef __wasm_simd128__ +#include +#else +#ifdef __POWER9_VECTOR__ +#include +#undef bool +#define bool _Bool +#else +#if defined(_MSC_VER) || defined(__MINGW32__) +#include +#else +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) +#if !defined(__riscv) +#include +#endif +#endif +#endif +#endif +#endif + +#ifdef __riscv_v_intrinsic +#include +#endif + +#ifdef __F16C__ + +#ifdef _MSC_VER +#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) +#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) +#else +#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) +#endif + +#elif defined(__POWER9_VECTOR__) + +#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) +/* the inline asm below is about 12% faster than the lookup method */ +#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) +#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) + +static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + register float f; + register double d; + __asm__( + "mtfprd %0,%2\n" + "xscvhpdp %0,%0\n" + "frsp %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=f"(f): + /* in */ "r"(h)); + return f; +} + +static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + register double d; + register ggml_fp16_t r; + __asm__( /* xscvdphp can work on double or single precision */ + "xscvdphp %0,%2\n" + "mffprd %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=r"(r): + /* in */ "f"(f)); + return r; +} + +#else + +// FP16 <-> FP32 +// ref: https://github.com/Maratyszcza/FP16 + +static inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32; + fp32.as_bits = w; + return fp32.as_value; +} + +static inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32; + fp32.as_value = f; + return fp32.as_bits; +} + +static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + const uint32_t w = (uint32_t) h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + const uint32_t exp_offset = UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + +#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) +#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + +#endif // __F16C__ + +#endif // __ARM_NEON + +// +// global data +// + +// precomputed gelu table for f16 (128 KB) +static ggml_fp16_t table_gelu_f16[1 << 16]; + +// precomputed quick gelu table for f16 (128 KB) +static ggml_fp16_t table_gelu_quick_f16[1 << 16]; + +// precomputed silu table for f16 (128 KB) +static ggml_fp16_t table_silu_f16[1 << 16]; + +// precomputed exp table for f16 (128 KB) +static ggml_fp16_t table_exp_f16[1 << 16]; + +// precomputed f32 table for f16 (256 KB) +static float table_f32_f16[1 << 16]; + +#if defined(__ARM_NEON) || defined(__wasm_simd128__) +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes: +static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 +static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 +#endif + +// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, +// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. +// This is also true for POWER9. +#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) + +inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { + uint16_t s; + memcpy(&s, &f, sizeof(uint16_t)); + return table_f32_f16[s]; +} + +#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) +#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) + +#endif + +// note: do not use these inside ggml.c +// these are meant to be used via the ggml.h API +float ggml_fp16_to_fp32(ggml_fp16_t x) { + return (float) GGML_FP16_TO_FP32(x); +} + +ggml_fp16_t ggml_fp32_to_fp16(float x) { + return GGML_FP32_TO_FP16(x); +} + +void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n) { + for (int i = 0; i < n; i++) { + y[i] = GGML_FP16_TO_FP32(x[i]); + } +} + +void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n) { + int i = 0; +#if defined(__F16C__) + for (; i + 7 < n; i += 8) { + __m256 x_vec = _mm256_loadu_ps(x + i); + __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128((__m128i *)(y + i), y_vec); + } + for(; i + 3 < n; i += 4) { + __m128 x_vec = _mm_loadu_ps(x + i); + __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storel_epi64((__m128i *)(y + i), y_vec); + } +#endif + for (; i < n; i++) { + y[i] = GGML_FP32_TO_FP16(x[i]); + } +} + +// +// timing +// + +#if defined(_MSC_VER) || defined(__MINGW32__) +static int64_t timer_freq, timer_start; +void ggml_time_init(void) { + LARGE_INTEGER t; + QueryPerformanceFrequency(&t); + timer_freq = t.QuadPart; + + // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq + // and the uptime is high enough. + // We subtract the program start time to reduce the likelihood of that happening. + QueryPerformanceCounter(&t); + timer_start = t.QuadPart; +} +int64_t ggml_time_ms(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return ((t.QuadPart-timer_start) * 1000) / timer_freq; +} +int64_t ggml_time_us(void) { + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return ((t.QuadPart-timer_start) * 1000000) / timer_freq; +} +#else +void ggml_time_init(void) {} +int64_t ggml_time_ms(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000; +} + +int64_t ggml_time_us(void) { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000; +} +#endif + +int64_t ggml_cycles(void) { + return clock(); +} + +int64_t ggml_cycles_per_ms(void) { + return CLOCKS_PER_SEC/1000; +} + +#ifdef GGML_PERF +#define ggml_perf_time_ms() ggml_time_ms() +#define ggml_perf_time_us() ggml_time_us() +#define ggml_perf_cycles() ggml_cycles() +#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms() +#else +#define ggml_perf_time_ms() 0 +#define ggml_perf_time_us() 0 +#define ggml_perf_cycles() 0 +#define ggml_perf_cycles_per_ms() 0 +#endif + + +// +// cache line +// + +#if defined(__cpp_lib_hardware_interference_size) +#define CACHE_LINE_SIZE hardware_destructive_interference_size +#else +#if defined(__POWER9_VECTOR__) +#define CACHE_LINE_SIZE 128 +#else +#define CACHE_LINE_SIZE 64 +#endif +#endif + +static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); + +// +// quantization +// + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = _mm_sign_epi8(x, x); + // Sign the values of the y vectors + const __m128i sy = _mm_sign_epi8(y, x); + // Perform multiplication and create 16-bit values + const __m128i dot = _mm_maddubs_epi16(ax, sy); + const __m128i ones = _mm_set1_epi16(1); + return _mm_madd_epi16(ones, dot); +} + +#if __AVX__ || __AVX2__ || __AVX512F__ +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + const __m128i hi64 = _mm_unpackhi_epi64(a, a); + const __m128i sum64 = _mm_add_epi32(hi64, a); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +#if defined(__AVX2__) || defined(__AVX512F__) +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = _mm256_set_epi64x( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); + const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytes = _mm256_or_si256(bytes, bit_mask); + return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8( 0xF ); + return _mm256_and_si256(lowMask, bytes); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + const __m256i summed_pairs = _mm256_madd_epi16(ones, x); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { +#if __AVXVNNI__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_float(dot); +#endif +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_float(ax, sy); +#endif +} + +static inline __m128i packNibbles( __m256i bytes ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh +#if __AVX512F__ + const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 + bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh + return _mm256_cvtepi16_epi8(bytes); // abcd_efgh +#else + const __m256i lowByte = _mm256_set1_epi16( 0xFF ); + __m256i high = _mm256_andnot_si256( lowByte, bytes ); + __m256i low = _mm256_and_si256( lowByte, bytes ); + high = _mm256_srli_epi16( high, 4 ); + bytes = _mm256_or_si256( low, high ); + + // Compress uint16_t lanes into bytes + __m128i r0 = _mm256_castsi256_si128( bytes ); + __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); + return _mm_packus_epi16( r0, r1 ); +#endif +} +#elif defined(__AVX__) +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); + __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); + __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); + const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytesl = _mm_or_si128(bytesl, bit_mask); + bytesh = _mm_or_si128(bytesh, bit_mask); + bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); + bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); + return MM256_SET_M128I(bytesh, bytesl); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + // Load 16 bytes from memory + __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); + __m128i tmph = _mm_srli_epi16(tmpl, 4); + const __m128i lowMask = _mm_set1_epi8(0xF); + tmpl = _mm_and_si128(lowMask, tmpl); + tmph = _mm_and_si128(lowMask, tmph); + return MM256_SET_M128I(tmph, tmpl); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { + const __m128i ones = _mm_set1_epi16(1); + const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); + const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); + const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + const __m128i axl = _mm256_castsi256_si128(ax); + const __m128i axh = _mm256_extractf128_si256(ax, 1); + const __m128i syl = _mm256_castsi256_si128(sy); + const __m128i syh = _mm256_extractf128_si256(sy, 1); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + const __m128i xl = _mm256_castsi256_si128(x); + const __m128i xh = _mm256_extractf128_si256(x, 1); + const __m128i yl = _mm256_castsi256_si128(y); + const __m128i yh = _mm256_extractf128_si256(y, 1); + // Get absolute values of x vectors + const __m128i axl = _mm_sign_epi8(xl, xl); + const __m128i axh = _mm_sign_epi8(xh, xh); + // Sign the values of the y vectors + const __m128i syl = _mm_sign_epi8(yl, xl); + const __m128i syh = _mm_sign_epi8(yh, xh); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m128i lowByte = _mm_set1_epi16( 0xFF ); + __m128i high = _mm_andnot_si128( lowByte, bytes1 ); + __m128i low = _mm_and_si128( lowByte, bytes1 ); + high = _mm_srli_epi16( high, 4 ); + bytes1 = _mm_or_si128( low, high ); + high = _mm_andnot_si128( lowByte, bytes2 ); + low = _mm_and_si128( lowByte, bytes2 ); + high = _mm_srli_epi16( high, 4 ); + bytes2 = _mm_or_si128( low, high ); + + return _mm_packus_epi16( bytes1, bytes2); +} +#endif +#elif defined(__SSSE3__) +// horizontally add 4x4 floats +static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { + __m128 res_0 =_mm_hadd_ps(a, b); + __m128 res_1 =_mm_hadd_ps(c, d); + __m128 res =_mm_hadd_ps(res_0, res_1); + res =_mm_hadd_ps(res, res); + res =_mm_hadd_ps(res, res); + + return _mm_cvtss_f32(res); +} +#endif // __AVX__ || __AVX2__ || __AVX512F__ +#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) + +#if defined(__ARM_NEON) + +#if !defined(__aarch64__) + +inline static int32_t vaddvq_s32(int32x4_t v) { + return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); +} + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +inline static float vmaxvq_f32(float32x4_t v) { + return + MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), + MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); +} + +inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { + int32x4_t res; + + res[0] = roundf(vgetq_lane_f32(v, 0)); + res[1] = roundf(vgetq_lane_f32(v, 1)); + res[2] = roundf(vgetq_lane_f32(v, 2)); + res[3] = roundf(vgetq_lane_f32(v, 3)); + + return res; +} + +#endif +#endif + +#define QK4_0 32 +typedef struct { + ggml_fp16_t d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +typedef struct { + ggml_fp16_t d; // delta + ggml_fp16_t m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +typedef struct { + ggml_fp16_t d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +typedef struct { + ggml_fp16_t d; // delta + ggml_fp16_t m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +typedef struct { + ggml_fp16_t d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +typedef struct { + float d; // delta + float s; // d * sum(qs[i]) + int8_t qs[QK8_1]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); + +// reference implementation for deterministic creation of model files +static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { + static const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + 0 + j]*id; + const float x1 = x[i*qk + qk/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + y[i].qs[j] = xi0; + y[i].qs[j] |= xi1 << 4; + } + } +} + +static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { + quantize_row_q4_0_reference(x, y, k); +} + +static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) { + const int qk = QK4_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + y[i].m = GGML_FP32_TO_FP16(min); + + for (int j = 0; j < qk/2; ++j) { + const float x0 = (x[i*qk + 0 + j] - min)*id; + const float x1 = (x[i*qk + qk/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + y[i].qs[j] = xi0; + y[i].qs[j] |= xi1 << 4; + } + } +} + +static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { + quantize_row_q4_1_reference(x, y, k); +} + +static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { + static const int qk = QK5_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + uint32_t qh = 0; + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + 0 + j]*id; + const float x1 = x[i*qk + qk/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10) >> 4) << (j + 0); + qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); + } + + memcpy(&y[i].qh, &qh, sizeof(qh)); + } +} + +static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { + quantize_row_q5_0_reference(x, y, k); +} + +static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { + const int qk = QK5_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 5) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + y[i].m = GGML_FP32_TO_FP16(min); + + uint32_t qh = 0; + + for (int j = 0; j < qk/2; ++j) { + const float x0 = (x[i*qk + 0 + j] - min)*id; + const float x1 = (x[i*qk + qk/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10) >> 4) << (j + 0); + qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); + } + + memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); + } +} + +static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { + quantize_row_q5_1_reference(x, y, k); +} + +// reference implementation for deterministic creation of model files +static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = x[i*QK8_0 + j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = x[i*QK8_0 + j]*id; + + y[i].qs[j] = roundf(x0); + } + } +} + +static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + } + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = GGML_FP32_TO_FP16(d); + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#else + // scalar + quantize_row_q8_0_reference(x, y, k); +#endif +} + +// reference implementation for deterministic creation of model files +static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { + assert(QK8_1 == 32); + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_1; j++) { + const float v = x[i*QK8_1 + j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + int sum = 0; + + for (int j = 0; j < QK8_1/2; ++j) { + const float v0 = x[i*QK8_1 + j]*id; + const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id; + + y[i].qs[ j] = roundf(v0); + y[i].qs[QK8_1/2 + j] = roundf(v1); + + sum += y[i].qs[ j]; + sum += y[i].qs[QK8_1/2 + j]; + } + + y[i].s = sum*d; + } +} + +static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + int32x4_t accv = vdupq_n_s32(0); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + + accv = vaddq_s32(accv, vi); + } + + y[i].s = d * vaddvq_s32(accv); + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + v128_t accv = wasm_i32x4_splat(0); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + + accv = wasm_i32x4_add(accv, vi); + } + + y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) + + wasm_i32x4_extract_lane(accv, 1) + + wasm_i32x4_extract_lane(accv, 2) + + wasm_i32x4_extract_lane(accv, 3)); + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = d; + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Compute the sum of the quants and set y[i].s + y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Compute the sum of the quants and set y[i].s + const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); + const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); + y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1)); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#else + // scalar + quantize_row_q8_1_reference(x, y, k); +#endif +} + +static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) { + static const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F) - 8; + const int x1 = (x[i].qs[j] >> 4) - 8; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + +static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) { + static const int qk = QK4_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F); + const int x1 = (x[i].qs[j] >> 4); + + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; + } + } +} + +static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) { + static const int qk = QK5_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + +static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) { + static const int qk = QK5_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int x0 = (x[i].qs[j] & 0x0F) | xh_0; + const int x1 = (x[i].qs[j] >> 4) | xh_1; + + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; + } + } +} + +static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) { + static const int qk = QK8_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + const block_q8_0 * restrict x = vx; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int j = 0; j < qk; ++j) { + y[i*qk + j] = x[i].qs[j]*d; + } + } +} + +static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y); +static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y); +static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); + +static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { + [GGML_TYPE_I8] = { + .type_name = "i8", + .blck_size = 1, + .type_size = sizeof(int8_t), + .is_quantized = false, + }, + [GGML_TYPE_I16] = { + .type_name = "i16", + .blck_size = 1, + .type_size = sizeof(int16_t), + .is_quantized = false, + }, + [GGML_TYPE_I32] = { + .type_name = "i32", + .blck_size = 1, + .type_size = sizeof(int32_t), + .is_quantized = false, + }, + [GGML_TYPE_F32] = { + .type_name = "f32", + .blck_size = 1, + .type_size = sizeof(float), + .is_quantized = false, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + }, + [GGML_TYPE_F16] = { + .type_name = "f16", + .blck_size = 1, + .type_size = sizeof(ggml_fp16_t), + .is_quantized = false, + .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, + .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row, + .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, + .vec_dot_type = GGML_TYPE_F16, + }, + [GGML_TYPE_Q4_0] = { + .type_name = "q4_0", + .blck_size = QK4_0, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_0, + .from_float = quantize_row_q4_0, + .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference, + .vec_dot = ggml_vec_dot_q4_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, + [GGML_TYPE_Q4_1] = { + .type_name = "q4_1", + .blck_size = QK4_1, + .type_size = sizeof(block_q4_1), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_1, + .from_float = quantize_row_q4_1, + .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference, + .vec_dot = ggml_vec_dot_q4_1_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, + }, + [GGML_TYPE_Q5_0] = { + .type_name = "q5_0", + .blck_size = QK5_0, + .type_size = sizeof(block_q5_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q5_0, + .from_float = quantize_row_q5_0, + .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference, + .vec_dot = ggml_vec_dot_q5_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, + [GGML_TYPE_Q5_1] = { + .type_name = "q5_1", + .blck_size = QK5_1, + .type_size = sizeof(block_q5_1), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q5_1, + .from_float = quantize_row_q5_1, + .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference, + .vec_dot = ggml_vec_dot_q5_1_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, + }, + [GGML_TYPE_Q8_0] = { + .type_name = "q8_0", + .blck_size = QK8_0, + .type_size = sizeof(block_q8_0), + .is_quantized = true, + .to_float = dequantize_row_q8_0, + .from_float = quantize_row_q8_0, + .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference, + .vec_dot = ggml_vec_dot_q8_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, + [GGML_TYPE_Q8_1] = { + .type_name = "q8_1", + .blck_size = QK8_1, + .type_size = sizeof(block_q8_1), + .is_quantized = true, + .from_float = quantize_row_q8_1, + .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference, + .vec_dot_type = GGML_TYPE_Q8_1, + }, +#ifdef GGML_USE_K_QUANTS + [GGML_TYPE_Q2_K] = { + .type_name = "q2_K", + .blck_size = QK_K, + .type_size = sizeof(block_q2_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q2_K, + .from_float = quantize_row_q2_K, + .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference, + .vec_dot = ggml_vec_dot_q2_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q3_K] = { + .type_name = "q3_K", + .blck_size = QK_K, + .type_size = sizeof(block_q3_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q3_K, + .from_float = quantize_row_q3_K, + .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference, + .vec_dot = ggml_vec_dot_q3_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q4_K] = { + .type_name = "q4_K", + .blck_size = QK_K, + .type_size = sizeof(block_q4_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_K, + .from_float = quantize_row_q4_K, + .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference, + .vec_dot = ggml_vec_dot_q4_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q5_K] = { + .type_name = "q5_K", + .blck_size = QK_K, + .type_size = sizeof(block_q5_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q5_K, + .from_float = quantize_row_q5_K, + .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference, + .vec_dot = ggml_vec_dot_q5_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q6_K] = { + .type_name = "q6_K", + .blck_size = QK_K, + .type_size = sizeof(block_q6_K), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q6_K, + .from_float = quantize_row_q6_K, + .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference, + .vec_dot = ggml_vec_dot_q6_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + }, + [GGML_TYPE_Q8_K] = { + .type_name = "q8_K", + .blck_size = QK_K, + .type_size = sizeof(block_q8_K), + .is_quantized = true, + .from_float = quantize_row_q8_K, + } +#endif +}; + +// For internal test use +ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { + GGML_ASSERT(type < GGML_TYPE_COUNT); + return type_traits[type]; +} + + +// +// simd mappings +// + +// we define a common set of C macros which map to specific intrinsics based on the current architecture +// we then implement the fundamental computation operations below using only these macros +// adding support for new architectures requires to define the corresponding SIMD macros +// +// GGML_F32_STEP / GGML_F16_STEP +// number of elements to process in a single step +// +// GGML_F32_EPR / GGML_F16_EPR +// number of elements to fit in a single register +// + +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) + +#define GGML_SIMD + +// F32 NEON + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 float32x4_t +#define GGML_F32x4_ZERO vdupq_n_f32(0.0f) +#define GGML_F32x4_SET1(x) vdupq_n_f32(x) +#define GGML_F32x4_LOAD vld1q_f32 +#define GGML_F32x4_STORE vst1q_f32 +#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) +#define GGML_F32x4_ADD vaddq_f32 +#define GGML_F32x4_MUL vmulq_f32 +#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f32(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f32(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f32(x[i], x[offset+i]); \ + } \ + res = GGML_F32x4_REDUCE_ONE(x[0]); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define GGML_F16_STEP 32 + #define GGML_F16_EPR 8 + + #define GGML_F16x8 float16x8_t + #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define GGML_F16x8_LOAD vld1q_f16 + #define GGML_F16x8_STORE vst1q_f16 + #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define GGML_F16x8_ADD vaddq_f16 + #define GGML_F16x8_MUL vmulq_f16 + #define GGML_F16x8_REDUCE(res, x) \ + { \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f16(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f16(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f16(x[i], x[offset+i]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ + res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ + } + + #define GGML_F16_VEC GGML_F16x8 + #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO + #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_FMA GGML_F16x8_FMA + #define GGML_F16_VEC_ADD GGML_F16x8_ADD + #define GGML_F16_VEC_MUL GGML_F16x8_MUL + #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define GGML_F16_STEP 16 + #define GGML_F16_EPR 4 + + #define GGML_F32Cx4 float32x4_t + #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x)) + #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define GGML_F32Cx4_ADD vaddq_f32 + #define GGML_F32Cx4_MUL vmulq_f32 + #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + + #define GGML_F16_VEC GGML_F32Cx4 + #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO + #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA + #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD + #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL + #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#endif + +#elif defined(__AVX__) + +#define GGML_SIMD + +// F32 AVX + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 + +#define GGML_F32x8 __m256 +#define GGML_F32x8_ZERO _mm256_setzero_ps() +#define GGML_F32x8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32x8_LOAD _mm256_loadu_ps +#define GGML_F32x8_STORE _mm256_storeu_ps +#if defined(__FMA__) + #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) +#else + #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) +#endif +#define GGML_F32x8_ADD _mm256_add_ps +#define GGML_F32x8_MUL _mm256_mul_ps +#define GGML_F32x8_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ + _mm256_extractf128_ps(x[0], 1)); \ + const __m128 t1 = _mm_hadd_ps(t0, t0); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ +} +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x8 +#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD +#define GGML_F32_VEC_STORE GGML_F32x8_STORE +#define GGML_F32_VEC_FMA GGML_F32x8_FMA +#define GGML_F32_VEC_ADD GGML_F32x8_ADD +#define GGML_F32_VEC_MUL GGML_F32x8_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE + +// F16 AVX + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead + +#define GGML_F32Cx8 __m256 +#define GGML_F32Cx8_ZERO _mm256_setzero_ps() +#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) + +#if defined(__F16C__) +// the _mm256_cvt intrinsics require F16C +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) +#else +static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) { + tmp[i] = GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { + float arr[8]; + + _mm256_storeu_ps(arr, y); + + for (int i = 0; i < 8; i++) + x[i] = GGML_FP32_TO_FP16(arr[i]); +} +#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) +#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y) +#endif + +#define GGML_F32Cx8_FMA GGML_F32x8_FMA +#define GGML_F32Cx8_ADD _mm256_add_ps +#define GGML_F32Cx8_MUL _mm256_mul_ps +#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE + +#define GGML_F16_VEC GGML_F32Cx8 +#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE + +#elif defined(__POWER9_VECTOR__) + +#define GGML_SIMD + +// F32 POWER9 + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 vector float +#define GGML_F32x4_ZERO 0.0f +#define GGML_F32x4_SET1 vec_splats +#define GGML_F32x4_LOAD(p) vec_xl(0, p) +#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) +#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) +#define GGML_F32x4_ADD vec_add +#define GGML_F32x4_MUL vec_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + res = vec_extract(x[0], 0) + \ + vec_extract(x[0], 1) + \ + vec_extract(x[0], 2) + \ + vec_extract(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 POWER9 +#define GGML_F16_STEP GGML_F32_STEP +#define GGML_F16_EPR GGML_F32_EPR +#define GGML_F16_VEC GGML_F32x4 +#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F16_VEC_FMA GGML_F32x4_FMA +#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +// Use vec_xl, not vec_ld, in case the load address is not aligned. +#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ + vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \ + vec_extract_fp32_from_shortl(vec_xl(0, p)) +#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] +#define GGML_F16_VEC_STORE(p, r, i) \ + if (i & 0x1) \ + vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \ + r[i - GGML_ENDIAN_BYTE(0)]), \ + 0, p - GGML_F16_EPR) + +#elif defined(__wasm_simd128__) + +#define GGML_SIMD + +// F32 WASM + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 v128_t +#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F32x4_LOAD wasm_v128_load +#define GGML_F32x4_STORE wasm_v128_store +#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) +#define GGML_F32x4_ADD wasm_f32x4_add +#define GGML_F32x4_MUL wasm_f32x4_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 WASM + +#define GGML_F16_STEP 16 +#define GGML_F16_EPR 4 + +inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(p[0]); + tmp[1] = GGML_FP16_TO_FP32(p[1]); + tmp[2] = GGML_FP16_TO_FP32(p[2]); + tmp[3] = GGML_FP16_TO_FP32(p[3]); + + return wasm_v128_load(tmp); +} + +inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { + float tmp[4]; + + wasm_v128_store(tmp, x); + + p[0] = GGML_FP32_TO_FP16(tmp[0]); + p[1] = GGML_FP32_TO_FP16(tmp[1]); + p[2] = GGML_FP32_TO_FP16(tmp[2]); + p[3] = GGML_FP32_TO_FP16(tmp[3]); +} + +#define GGML_F16x4 v128_t +#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) +#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) +#define GGML_F16x4_FMA GGML_F32x4_FMA +#define GGML_F16x4_ADD wasm_f32x4_add +#define GGML_F16x4_MUL wasm_f32x4_mul +#define GGML_F16x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F16_VEC GGML_F16x4 +#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F16x4_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F16x4_FMA +#define GGML_F16_VEC_ADD GGML_F16x4_ADD +#define GGML_F16_VEC_MUL GGML_F16x4_MUL +#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE + +#elif defined(__SSE3__) + +#define GGML_SIMD + +// F32 SSE + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 __m128 +#define GGML_F32x4_ZERO _mm_setzero_ps() +#define GGML_F32x4_SET1(x) _mm_set1_ps(x) +#define GGML_F32x4_LOAD _mm_loadu_ps +#define GGML_F32x4_STORE _mm_storeu_ps +#if defined(__FMA__) + // TODO: Does this work? + #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) +#else + #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) +#endif +#define GGML_F32x4_ADD _mm_add_ps +#define GGML_F32x4_MUL _mm_mul_ps +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ + res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ +} +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 SSE + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 4 + +static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(x[0]); + tmp[1] = GGML_FP16_TO_FP32(x[1]); + tmp[2] = GGML_FP16_TO_FP32(x[2]); + tmp[3] = GGML_FP16_TO_FP32(x[3]); + + return _mm_loadu_ps(tmp); +} + +static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) { + float arr[4]; + + _mm_storeu_ps(arr, y); + + x[0] = GGML_FP32_TO_FP16(arr[0]); + x[1] = GGML_FP32_TO_FP16(arr[1]); + x[2] = GGML_FP32_TO_FP16(arr[2]); + x[3] = GGML_FP32_TO_FP16(arr[3]); +} + +#define GGML_F32Cx4 __m128 +#define GGML_F32Cx4_ZERO _mm_setzero_ps() +#define GGML_F32Cx4_SET1(x) _mm_set1_ps(x) +#define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x) +#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) +#define GGML_F32Cx4_FMA GGML_F32x4_FMA +#define GGML_F32Cx4_ADD _mm_add_ps +#define GGML_F32Cx4_MUL _mm_mul_ps +#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + +#define GGML_F16_VEC GGML_F32Cx4 +#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE + +#endif + +// GGML_F32_ARR / GGML_F16_ARR +// number of registers to use per step +#ifdef GGML_SIMD +#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) +#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) +#endif + +// +// fundamental operations +// + +inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } +inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } +inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } + +static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { +#ifdef GGML_SIMD + float sumf = 0.0f; + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; + } +#else + // scalar + ggml_float sumf = 0.0; + for (int i = 0; i < n; ++i) { + sumf += (ggml_float)(x[i]*y[i]); + } +#endif + + *s = sumf; +} + +static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { + ggml_float sumf = 0.0; + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F16_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); + } +#else + for (int i = 0; i < n; ++i) { + sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); + } +#endif + + *s = sumf; +} + +static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q4_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb + for (int i = 0; i < nb; i += 2) { + const block_q4_0 * restrict x0 = &x[i + 0]; + const block_q4_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int32x4_t + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + bx = _mm256_sub_epi8( bx, off ); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( d, q, acc ); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); + + __m128i bx = _mm_and_si128(lowMask, tmp); + __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); + bx = _mm_sub_epi8(bx, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx, by); + + bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); + by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx = _mm_sub_epi8(bx, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx, by); + + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); + + // Apply the scale, and accumulate + acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); + } + + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + // set constants + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + // Initialize accumulator with zeros + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + // First round without accumulation + { + _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + acc_0 = _mm_mul_ps( d_0_1, p0 ); + acc_1 = _mm_mul_ps( d_0_1, p1 ); + acc_2 = _mm_mul_ps( d_2_3, p2 ); + acc_3 = _mm_mul_ps( d_2_3, p3 ); + } + + // Main loop + GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb + for (int i = 2; i < nb; i+=2) { + _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); + __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); + __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); + __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); + + // Acummulate + acc_0 = _mm_add_ps(p0_d, acc_0); + acc_1 = _mm_add_ps(p1_d, acc_1); + acc_2 = _mm_add_ps(p2_d, acc_2); + acc_3 = _mm_add_ps(p3_d, acc_3); + } + + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); + + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); + + vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + + vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + + vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl); + vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl); + + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); + sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0x0F) - 8; + const int v1 = (x[i].qs[j] >> 4) - 8; + + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + } + + sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); + } + + *s = sumf; +#endif +} + +static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q4_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + + // TODO: add WASM SIMD +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs = 0; + + GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb + for (int i = 0; i < nb; i += 2) { + const block_q4_1 * restrict x0 = &x[i + 0]; + const block_q4_1 * restrict x1 = &x[i + 1]; + const block_q8_1 * restrict y0 = &y[i + 0]; + const block_q8_1 * restrict y1 = &y[i + 1]; + + summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int32x4_t + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + // Main loop + for (int i = 0; i < nb; ++i) { + const float d0 = GGML_FP16_TO_FP32(x[i].d); + const float d1 = y[i].d; + + summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + + const __m256 d0v = _mm256_set1_ps( d0 ); + const __m256 d1v = _mm256_set1_ps( d1 ); + + // Compute combined scales + const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); + + const __m256 xy = mul_sum_us8_pairs_float(bx, by); + + // Accumulate d0*d1*x*y +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d0d1, xy, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); +#endif + } + + *s = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); + + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); + + vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + + vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); + sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0x0F); + const int v1 = (x[i].qs[j] >> 4); + + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + } + + sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#endif +} + +static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(qk == QK5_0); + + const block_q5_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb + for (int i = 0; i < nb; i += 2) { + const block_q5_0 * restrict x0 = &x[i]; + const block_q5_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + // extract the 5th bit via lookup table ((!b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_1[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_1[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); + + uint32_t qh; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (int i = 0; i < nb; ++i) { + const block_q5_0 * restrict x0 = &x[i]; + const block_q8_0 * restrict y0 = &y[i]; + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); + const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( + wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); + } + + *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); + bx = _mm256_or_si256(bx, bxhi); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8((char)0xF0); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_andnot_si128(bxhil, mask); + bxhih = _mm_andnot_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx); + __m128i bxh = _mm256_extractf128_si256(bx, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx = MM256_SET_M128I(bxh, bxl); + + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); + } + + *s = hsum_float_8(acc); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + uint32_t qh; + + // These temp values are for masking and shift operations + uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, + 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + memcpy(&qh, x[i].qh, sizeof(uint32_t)); + + // temporary registers + vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl); + vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl); + vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl); + vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl); + + // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl); + vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl); + vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl); + + // ((qh & (1u << (j + 16))) >> (j + 12)); + vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl); + vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl); + + // narrowing + vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl); + vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl); + + vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl); + vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl); + + // load + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); + + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); + + vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + + vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl); + vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl); + + vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + + vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl); + vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl); + + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); + sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + + sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + } + + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; + } + + *s = sumf; +#endif +} + +static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(qk == QK5_1); + + const block_q5_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs0 = 0.0f; + float summs1 = 0.0f; + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb + for (int i = 0; i < nb; i += 2) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q5_1 * restrict x1 = &x[i + 1]; + const block_q8_1 * restrict y0 = &y[i]; + const block_q8_1 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s; + summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s; + + // extract the 5th bit via lookup table ((b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_0[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_0[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit + const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); + + float summs = 0.0f; + + uint32_t qh; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (int i = 0; i < nb; ++i) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q8_1 * restrict y0 = &y[i]; + + summs += GGML_FP16_TO_FP32(x0->m) * y0->s; + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit + const v128_t v0lf = wasm_v128_or(v0l, qhl); + const v128_t v0hf = wasm_v128_or(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d))); + } + + *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.0f; + + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); + + summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); + bx = _mm256_or_si256(bx, bxhi); + + const __m256 dy = _mm256_set1_ps(y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_us8_pairs_float(bx, by); + + acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + } + + *s = hsum_float_8(acc) + summs; +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8(0x10); + + float summs = 0.0f; + + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); + + summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_and_si128(bxhil, mask); + bxhih = _mm_and_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx); + __m128i bxh = _mm256_extractf128_si256(bx, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx = MM256_SET_M128I(bxh, bxl); + + const __m256 dy = _mm256_set1_ps(y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_us8_pairs_float(bx, by); + + acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); + } + + *s = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + uint32_t qh; + + // These temp values are for shift operations + uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + memcpy(&qh, x[i].qh, sizeof(uint32_t)); + + // temporary registers + vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl); + vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl); + + // load qh + vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl); + + // ((qh >> (j + 0)) << 4) & 0x10; + vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl); + vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl); + vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl); + + // ((qh >> (j + 12)) ) & 0x10; + vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl); + vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl); + + // narrowing + vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl); + vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl); + + vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl); + vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl); + + // load + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); + + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); + + vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + + vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl); + vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl); + + vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); + sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; + + sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + } + + sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#endif +} + +static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q8_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb + for (int i = 0; i < nb; i += 2) { + const block_q8_0 * restrict x0 = &x[i + 0]; + const block_q8_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const int8x16_t x0_0 = vld1q_s8(x0->qs); + const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); + const int8x16_t x1_0 = vld1q_s8(x1->qs); + const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); + + // load y + const int8x16_t y0_0 = vld1q_s8(y0->qs); + const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); + const int8x16_t y1_0 = vld1q_s8(y1->qs); + const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), + vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), + vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + +#else + const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); + const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); + const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1)); + const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); + + const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0)); + const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); + const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1)); + const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); + + const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); + const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); + const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); + const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); + __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs); + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + // Multiply q with scale and accumulate +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d, q, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); +#endif + } + + *s = hsum_float_8(acc); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + size_t vl = __riscv_vsetvl_e8m1(qk); + + for (int i = 0; i < nb; i++) { + // load elements + vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl); + vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl); + + vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); + + sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[i].qs[j]*y[i].qs[j]; + } + + sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); + } + + *s = sumf; +#endif +} + +// compute GGML_VEC_DOT_UNROLL dot products at once +// xs - x row stride in bytes +inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { + ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; + + ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + x[i] = (ggml_fp16_t *) ((char *) xv + i*xs); + } + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); + + sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + } + } + } + + // reduce sum0..sum3 to sum0 + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + GGML_F16_VEC_REDUCE(sumf[k], sum[k]); + } + + // leftovers + for (int i = np; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); + } + } +#else + for (int i = 0; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); + } + } +#endif + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + s[i] = sumf[i]; + } +} + +inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += x[i]*v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += x[i]*v; + } +#endif +} + +//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } +inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { +#if defined(GGML_USE_ACCELERATE) + vDSP_vsmul(y, 1, &v, y, 1, n); +#elif defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_MUL(ay[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] *= v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] *= v; + } +#endif +} + +inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } +inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } +inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } +inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } +inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } +inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } +inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } +inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } +inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } +inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } + +static const float GELU_COEF_A = 0.044715f; +static const float GELU_QUICK_COEF = -1.702f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +inline static float ggml_gelu_f32(float x) { + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = table_gelu_f16[i16[i]]; + } +} + +#ifdef GGML_GELU_FP16 +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]); + } +} +#else +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_f32(x[i]); + } +} +#endif + +inline static float ggml_gelu_quick_f32(float x) { + return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); +} + +//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { +// const uint16_t * i16 = (const uint16_t *) x; +// for (int i = 0; i < n; ++i) { +// y[i] = table_gelu_quick_f16[i16[i]]; +// } +//} + +#ifdef GGML_GELU_QUICK_FP16 +inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]); + } +} +#else +inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_quick_f32(x[i]); + } +} +#endif + +// Sigmoid Linear Unit (SiLU) function +inline static float ggml_silu_f32(float x) { + return x/(1.0f + expf(-x)); +} + +//inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { +// const uint16_t * i16 = (const uint16_t *) x; +// for (int i = 0; i < n; ++i) { +// y[i] = table_silu_f16[i16[i]]; +// } +//} + +#ifdef GGML_SILU_FP16 +inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(table_silu_f16[t]); + } +} +#else +inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_silu_f32(x[i]); + } +} +#endif + +inline static float ggml_silu_backward_f32(float x, float dy) { + const float s = 1.0f/(1.0f + expf(-x)); + return dy*s*(1.0f + x*(1.0f - s)); +} + +#ifdef GGML_SILU_FP16 +inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { + for (int i = 0; i < n; ++i) { + // we did not use x[i] to compute forward silu but its f16 equivalent + // take derivative at f16 of x[i]: + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + float usedx = GGML_FP16_TO_FP32(fp16); + dx[i] = ggml_silu_backward_f32(usedx, dy[i]); + } +} +#else +inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { + for (int i = 0; i < n; ++i) { + dx[i] = ggml_silu_backward_f32(x[i], dy[i]); + } +} +#endif + +inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { +#ifndef GGML_USE_ACCELERATE + ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += (ggml_float)x[i]; + } + *s = sum; +#else + vDSP_sve(x, 1, s, n); +#endif +} + +inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) { + ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += (ggml_float)x[i]; + } + *s = sum; +} + +inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) { + float sum = 0.0f; + for (int i = 0; i < n; ++i) { + sum += GGML_FP16_TO_FP32(x[i]); + } + *s = sum; +} + +inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { +#ifndef GGML_USE_ACCELERATE + float max = -INFINITY; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + } + *s = max; +#else + vDSP_maxv(x, 1, s, n); +#endif +} + +inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { + ggml_vec_norm_f32(n, s, x); + *s = 1.f/(*s); +} + +inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { + float max = -INFINITY; + int idx = 0; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + if (max == x[i]) { idx = i; } + } + *s = idx; +} + +// +// data types +// + +static const char * GGML_OP_NAME[GGML_OP_COUNT] = { + "NONE", + + "DUP", + "ADD", + "ADD1", + "ACC", + "SUB", + "MUL", + "DIV", + "SQR", + "SQRT", + "LOG", + "SUM", + "SUM_ROWS", + "MEAN", + "ARGMAX", + "REPEAT", + "REPEAT_BACK", + "CONCAT", + "SILU_BACK", + "NORM", + "RMS_NORM", + "RMS_NORM_BACK", + "GROUP_NORM", + + "MUL_MAT", + "OUT_PROD", + + "SCALE", + "SET", + "CPY", + "CONT", + "RESHAPE", + "VIEW", + "PERMUTE", + "TRANSPOSE", + "GET_ROWS", + "GET_ROWS_BACK", + "DIAG", + "DIAG_MASK_INF", + "DIAG_MASK_ZERO", + "SOFT_MAX", + "SOFT_MAX_BACK", + "ROPE", + "ROPE_BACK", + "ALIBI", + "CLAMP", + "CONV_1D", + "CONV_2D", + "CONV_TRANSPOSE_2D", + "POOL_1D", + "POOL_2D", + "UPSCALE", + + "FLASH_ATTN", + "FLASH_FF", + "FLASH_ATTN_BACK", + "WIN_PART", + "WIN_UNPART", + "GET_REL_POS", + "ADD_REL_POS", + + "UNARY", + + "MAP_UNARY", + "MAP_BINARY", + + "MAP_CUSTOM1_F32", + "MAP_CUSTOM2_F32", + "MAP_CUSTOM3_F32", + + "MAP_CUSTOM1", + "MAP_CUSTOM2", + "MAP_CUSTOM3", + + "CROSS_ENTROPY_LOSS", + "CROSS_ENTROPY_LOSS_BACK", +}; + +static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); + +static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { + "none", + + "x", + "x+y", + "x+y", + "view(x,nb,offset)+=y->x", + "x-y", + "x*y", + "x/y", + "x^2", + "√x", + "log(x)", + "Σx", + "Σx_k", + "Σx/n", + "argmax(x)", + "repeat(x)", + "repeat_back(x)", + "concat(x, y)", + "silu_back(x)", + "norm(x)", + "rms_norm(x)", + "rms_norm_back(x)", + "group_norm(x)", + + "X*Y", + "X*Y", + + "x*v", + "y-\\>view(x)", + "x-\\>y", + "cont(x)", + "reshape(x)", + "view(x)", + "permute(x)", + "transpose(x)", + "get_rows(x)", + "get_rows_back(x)", + "diag(x)", + "diag_mask_inf(x)", + "diag_mask_zero(x)", + "soft_max(x)", + "soft_max_back(x)", + "rope(x)", + "rope_back(x)", + "alibi(x)", + "clamp(x)", + "conv_1d(x)", + "conv_2d(x)", + "conv_transpose_2d(x)", + "pool_1d(x)", + "pool_2d(x)", + "upscale(x)", + + "flash_attn(x)", + "flash_ff(x)", + "flash_attn_back(x)", + "win_part(x)", + "win_unpart(x)", + "get_rel_pos(x)", + "add_rel_pos(x)", + + "unary(x)", + + "f(x)", + "f(x,y)", + + "custom_f32(x)", + "custom_f32(x,y)", + "custom_f32(x,y,z)", + + "custom(x)", + "custom(x,y)", + "custom(x,y,z)", + + "cross_entropy_loss(x,y)", + "cross_entropy_loss_back(x,y)", +}; + +static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); + +static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); + +static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); +static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); + +// WARN: +// Mis-confguration can lead to problem that's hard to reason about: +// * At best it crash or talks nosense. +// * At worst it talks slightly difference but hard to perceive. +// +// An op has to enable INIT or FINALIZE when any of it's branch needs that pass. +// Take care about compile options (e.g., GGML_USE_xxx). +static bool GGML_OP_HAS_INIT [GGML_OP_COUNT] = { 0 }; +static bool GGML_OP_HAS_FINALIZE[GGML_OP_COUNT] = { 0 }; + +static void ggml_setup_op_has_task_pass(void) { + { // INIT + bool * p = GGML_OP_HAS_INIT; + + p[GGML_OP_ACC ] = true; + p[GGML_OP_MUL_MAT ] = true; + p[GGML_OP_OUT_PROD ] = true; + p[GGML_OP_SET ] = true; + p[GGML_OP_GET_ROWS_BACK ] = true; + p[GGML_OP_DIAG_MASK_INF ] = true; + p[GGML_OP_DIAG_MASK_ZERO ] = true; + p[GGML_OP_CONV_1D ] = true; + p[GGML_OP_CONV_2D ] = true; + p[GGML_OP_CONV_TRANSPOSE_2D ] = true; + p[GGML_OP_FLASH_ATTN_BACK ] = true; + p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; + p[GGML_OP_ADD_REL_POS ] = true; + } + + { // FINALIZE + bool * p = GGML_OP_HAS_FINALIZE; + + p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; + } +} + +// +// ggml context +// + +struct ggml_context { + size_t mem_size; + void * mem_buffer; + bool mem_buffer_owned; + bool no_alloc; + bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers + + int n_objects; + + struct ggml_object * objects_begin; + struct ggml_object * objects_end; + + struct ggml_scratch scratch; + struct ggml_scratch scratch_save; +}; + +struct ggml_context_container { + bool used; + + struct ggml_context context; +}; + +// +// NUMA support +// + +#define GGML_NUMA_MAX_NODES 8 +#define GGML_NUMA_MAX_CPUS 512 + +struct ggml_numa_node { + uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node + uint32_t n_cpus; +}; + +struct ggml_numa_nodes { + struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES]; + uint32_t n_nodes; + uint32_t total_cpus; // hardware threads on system +}; + +// +// ggml state +// + +struct ggml_state { + struct ggml_context_container contexts[GGML_MAX_CONTEXTS]; + struct ggml_numa_nodes numa; +}; + +// global state +static struct ggml_state g_state; +static atomic_int g_state_barrier = 0; + +// barrier via spin lock +inline static void ggml_critical_section_start(void) { + int processing = atomic_fetch_add(&g_state_barrier, 1); + + while (processing > 0) { + // wait for other threads to finish + atomic_fetch_sub(&g_state_barrier, 1); + sched_yield(); // TODO: reconsider this + processing = atomic_fetch_add(&g_state_barrier, 1); + } +} + +// TODO: make this somehow automatically executed +// some sort of "sentry" mechanism +inline static void ggml_critical_section_end(void) { + atomic_fetch_sub(&g_state_barrier, 1); +} + +void ggml_numa_init(void) { + if (g_state.numa.n_nodes > 0) { + fprintf(stderr, "ggml_numa_init: NUMA already initialized\n"); + + return; + } + +#ifdef __linux__ + struct stat st; + char path[256]; + int rv; + + // enumerate nodes + while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes); + GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) != 0) { break; } + ++g_state.numa.n_nodes; + } + + // enumerate CPUs + while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus); + GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) != 0) { break; } + ++g_state.numa.total_cpus; + } + + GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus); + + if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1) { + g_state.numa.n_nodes = 0; + return; + } + + for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) { + struct ggml_numa_node * node = &g_state.numa.nodes[n]; + GGML_PRINT_DEBUG("CPUs on node %u:", n); + node->n_cpus = 0; + for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c); + GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) == 0) { + node->cpus[node->n_cpus++] = c; + GGML_PRINT_DEBUG(" %u", c); + } + } + GGML_PRINT_DEBUG("\n"); + } + + if (ggml_is_numa()) { + FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r"); + if (fptr != NULL) { + char buf[42]; + if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) { + GGML_PRINT("WARNING: /proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n"); + } + fclose(fptr); + } + } +#else + // TODO +#endif +} + +bool ggml_is_numa(void) { + return g_state.numa.n_nodes > 1; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_print_object(const struct ggml_object * obj) { + GGML_PRINT(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n", + obj->type, obj->offs, obj->size, (const void *) obj->next); +} + +void ggml_print_objects(const struct ggml_context * ctx) { + struct ggml_object * obj = ctx->objects_begin; + + GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx); + + while (obj != NULL) { + ggml_print_object(obj); + obj = obj->next; + } + + GGML_PRINT("%s: --- end ---\n", __func__); +} + +int64_t ggml_nelements(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +int64_t ggml_nrows(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; +} + +size_t ggml_nbytes(const struct ggml_tensor * tensor) { + size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type); + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + } + return nbytes; +} + +size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { + return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN); +} + +size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (nrows_split*tensor->ne[0]*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type); +} + +int ggml_blck_size(enum ggml_type type) { + return type_traits[type].blck_size; +} + +size_t ggml_type_size(enum ggml_type type) { + return type_traits[type].type_size; +} + +float ggml_type_sizef(enum ggml_type type) { + return ((float)(type_traits[type].type_size))/type_traits[type].blck_size; +} + +const char * ggml_type_name(enum ggml_type type) { + return type_traits[type].type_name; +} + +bool ggml_is_quantized(enum ggml_type type) { + return type_traits[type].is_quantized; +} + +const char * ggml_op_name(enum ggml_op op) { + return GGML_OP_NAME[op]; +} + +const char * ggml_op_symbol(enum ggml_op op) { + return GGML_OP_SYMBOL[op]; +} + +size_t ggml_element_size(const struct ggml_tensor * tensor) { + return ggml_type_size(tensor->type); +} + +static inline bool ggml_is_scalar(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool ggml_is_vector(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); +} + +static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[1] == t1->ne[1]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); +} + +enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { + enum ggml_type wtype = GGML_TYPE_COUNT; + + switch (ftype) { + case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; + case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; + case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; + case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; + case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; + case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; + case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; + case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; + case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; + case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; + case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; + case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; + case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; + } + + GGML_ASSERT(wtype != GGML_TYPE_COUNT); + + return wtype; +} + +size_t ggml_tensor_overhead(void) { + return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE; +} + +bool ggml_is_transposed(const struct ggml_tensor * tensor) { + return tensor->nb[0] > tensor->nb[1]; +} + +bool ggml_is_contiguous(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +bool ggml_is_permuted(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; +} + +static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + +bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t0->ne[0] == t1->ne[0] ) && + (t0->ne[1] == t1->ne[1] ) && + (t0->ne[2] == t1->ne[2] ) && + (t0->ne[3] == t1->ne[3] ); +} + +// check if t1 can be represented as a repeatition of t0 +static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + (t1->ne[0]%t0->ne[0] == 0) && + (t1->ne[1]%t0->ne[1] == 0) && + (t1->ne[2]%t0->ne[2] == 0) && + (t1->ne[3]%t0->ne[3] == 0); +} + +static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1); +} + +static inline int ggml_up32(int n) { + return (n + 31) & ~31; +} + +//static inline int ggml_up64(int n) { +// return (n + 63) & ~63; +//} + +static inline int ggml_up(int n, int m) { + // assert m is a power of 2 + GGML_ASSERT((m & (m - 1)) == 0); + return (n + m - 1) & ~(m - 1); +} + +// assert that pointer is aligned to GGML_MEM_ALIGN +#define ggml_assert_aligned(ptr) \ + GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0) + +//////////////////////////////////////////////////////////////////////////////// + +struct ggml_context * ggml_init(struct ggml_init_params params) { + // make this function thread safe + ggml_critical_section_start(); + + static bool is_first_call = true; + + if (is_first_call) { + // initialize time system (required on Windows) + ggml_time_init(); + + // initialize GELU, Quick GELU, SILU and EXP F32 tables + { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + + ggml_fp16_t ii; + for (int i = 0; i < (1 << 16); ++i) { + uint16_t ui = i; + memcpy(&ii, &ui, sizeof(ii)); + const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); + table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); + table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); + table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f)); + table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } + + // initialize g_state + { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + + g_state = (struct ggml_state) { + /*.contexts =*/ { { 0 } }, + /*.numa =*/ { + .n_nodes = 0, + .total_cpus = 0, + }, + }; + + for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) { + g_state.contexts[i].used = false; + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + } + +#if defined(GGML_USE_CUBLAS) + ggml_init_cublas(); +#elif defined(GGML_USE_CLBLAST) + ggml_cl_init(); +#endif + + ggml_setup_op_has_task_pass(); + + is_first_call = false; + } + + // find non-used context in g_state + struct ggml_context * ctx = NULL; + + for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { + if (!g_state.contexts[i].used) { + g_state.contexts[i].used = true; + ctx = &g_state.contexts[i].context; + + GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i); + break; + } + } + + if (ctx == NULL) { + GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); + + ggml_critical_section_end(); + + return NULL; + } + + // allow to call ggml_init with 0 size + if (params.mem_size == 0) { + params.mem_size = GGML_MEM_ALIGN; + } + + const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN); + + *ctx = (struct ggml_context) { + /*.mem_size =*/ mem_size, + /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size), + /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, + /*.no_alloc =*/ params.no_alloc, + /*.no_alloc_save =*/ params.no_alloc, + /*.n_objects =*/ 0, + /*.objects_begin =*/ NULL, + /*.objects_end =*/ NULL, + /*.scratch =*/ { 0, 0, NULL, }, + /*.scratch_save =*/ { 0, 0, NULL, }, + }; + + GGML_ASSERT(ctx->mem_buffer != NULL); + + ggml_assert_aligned(ctx->mem_buffer); + + GGML_PRINT_DEBUG("%s: context initialized\n", __func__); + + ggml_critical_section_end(); + + return ctx; +} + +void ggml_free(struct ggml_context * ctx) { + // make this function thread safe + ggml_critical_section_start(); + + bool found = false; + + for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { + if (&g_state.contexts[i].context == ctx) { + g_state.contexts[i].used = false; + + GGML_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n", + __func__, i, ggml_used_mem(ctx)); + + if (ctx->mem_buffer_owned) { + GGML_ALIGNED_FREE(ctx->mem_buffer); + } + + found = true; + break; + } + } + + if (!found) { + GGML_PRINT_DEBUG("%s: context not found\n", __func__); + } + + ggml_critical_section_end(); +} + +size_t ggml_used_mem(const struct ggml_context * ctx) { + return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size; +} + +size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) { + const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0; + + ctx->scratch = scratch; + + return result; +} + +bool ggml_get_no_alloc(struct ggml_context * ctx) { + return ctx->no_alloc; +} + +void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) { + ctx->no_alloc = no_alloc; +} + +void * ggml_get_mem_buffer(const struct ggml_context * ctx) { + return ctx->mem_buffer; +} + +size_t ggml_get_mem_size(const struct ggml_context * ctx) { + return ctx->mem_size; +} + +size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) { + size_t max_size = 0; + + struct ggml_object * obj = ctx->objects_begin; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs); + + const size_t size = ggml_nbytes(tensor); + + if (max_size < size) { + max_size = size; + } + } + + obj = obj->next; + } + + return max_size; +} + +// IMPORTANT: +// when creating "opt" tensors, always save and load the scratch buffer +// this is an error prone process, but it is necessary to support inplace +// operators when using scratch buffers +// TODO: implement a better way +static void ggml_scratch_save(struct ggml_context * ctx) { + // this is needed to allow opt tensors to store their data + // TODO: again, need to find a better way + ctx->no_alloc_save = ctx->no_alloc; + ctx->no_alloc = false; + + ctx->scratch_save = ctx->scratch; + ctx->scratch.data = NULL; +} + +static void ggml_scratch_load(struct ggml_context * ctx) { + ctx->no_alloc = ctx->no_alloc_save; + + ctx->scratch = ctx->scratch_save; +} + +//////////////////////////////////////////////////////////////////////////////// + +static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) { + // always insert objects at the end of the context's memory pool + struct ggml_object * obj_cur = ctx->objects_end; + + const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; + const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; + const size_t cur_end = cur_offs + cur_size; + + // align to GGML_MEM_ALIGN + size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN); + + char * const mem_buffer = ctx->mem_buffer; + struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); + + if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { + GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + size_needed, ctx->mem_size); + assert(false); + return NULL; + } + + *obj_new = (struct ggml_object) { + .offs = cur_end + GGML_OBJECT_SIZE, + .size = size_needed, + .next = NULL, + .type = type, + }; + + ggml_assert_aligned(mem_buffer + obj_new->offs); + + if (obj_cur != NULL) { + obj_cur->next = obj_new; + } else { + // this is the first object in this context + ctx->objects_begin = obj_new; + } + + ctx->objects_end = obj_new; + + //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); + + return obj_new; +} + +static struct ggml_tensor * ggml_new_tensor_impl( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int64_t * ne, + struct ggml_tensor * view_src, + size_t view_offs) { + + assert(n_dims >= 1 && n_dims <= GGML_MAX_DIMS); + + // find the base tensor and absolute offset + if (view_src != NULL && view_src->view_src != NULL) { + view_offs += view_src->view_offs; + view_src = view_src->view_src; + } + + size_t data_size = ggml_type_size(type)*(ne[0]/ggml_blck_size(type)); + for (int i = 1; i < n_dims; i++) { + data_size *= ne[i]; + } + + GGML_ASSERT(view_src == NULL || data_size + view_offs <= ggml_nbytes(view_src)); + + void * data = view_src != NULL ? view_src->data : NULL; + if (data != NULL) { + data = (char *) data + view_offs; + } + + size_t obj_alloc_size = 0; + + if (view_src == NULL && !ctx->no_alloc) { + if (ctx->scratch.data != NULL) { + // allocate tensor data in the scratch buffer + if (ctx->scratch.offs + data_size > ctx->scratch.size) { + GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n", + __func__, ctx->scratch.offs + data_size, ctx->scratch.size); + assert(false); + return NULL; + } + + data = (char * const) ctx->scratch.data + ctx->scratch.offs; + + ctx->scratch.offs += data_size; + } else { + // allocate tensor data in the context's memory pool + obj_alloc_size = data_size; + } + } + + struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size); + + // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here + + struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs); + + *result = (struct ggml_tensor) { + /*.type =*/ type, + /*.backend =*/ GGML_BACKEND_CPU, + /*.n_dims =*/ n_dims, + /*.ne =*/ { 1, 1, 1, 1 }, + /*.nb =*/ { 0, 0, 0, 0 }, + /*.op =*/ GGML_OP_NONE, + /*.op_params =*/ { 0 }, + /*.is_param =*/ false, + /*.grad =*/ NULL, + /*.src =*/ { NULL }, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + /*.view_src =*/ view_src, + /*.view_offs =*/ view_offs, + /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data, + /*.name =*/ { 0 }, + /*.extra =*/ NULL, + /*.padding =*/ { 0 }, + }; + + // TODO: this should not be needed as long as we don't rely on aligned SIMD loads + //ggml_assert_aligned(result->data); + + for (int i = 0; i < n_dims; i++) { + result->ne[i] = ne[i]; + } + + result->nb[0] = ggml_type_size(type); + result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type)); + for (int i = 2; i < GGML_MAX_DIMS; i++) { + result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; + } + + ctx->n_objects++; + + return result; +} + +struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int64_t * ne) { + return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0); +} + +struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0) { + return ggml_new_tensor(ctx, type, 1, &ne0); +} + +struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1) { + const int64_t ne[2] = { ne0, ne1 }; + return ggml_new_tensor(ctx, type, 2, ne); +} + +struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + const int64_t ne[3] = { ne0, ne1, ne2 }; + return ggml_new_tensor(ctx, type, 3, ne); +} + +struct ggml_tensor * ggml_new_tensor_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + return ggml_new_tensor(ctx, type, 4, ne); +} + +struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { + ggml_scratch_save(ctx); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + + ggml_scratch_load(ctx); + + ggml_set_i32(result, value); + + return result; +} + +struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) { + ggml_scratch_save(ctx); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + + ggml_scratch_load(ctx); + + ggml_set_f32(result, value); + + return result; +} + +struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) { + return ggml_new_tensor(ctx, src->type, src->n_dims, src->ne); +} + +static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) { + GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings + assert(params_size <= GGML_MAX_OP_PARAMS); + memcpy(tensor->op_params, params, params_size); +} + +static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) { + assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); + return ((const int32_t *)(tensor->op_params))[i]; +} + +static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) { + assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); + ((int32_t *)(tensor->op_params))[i] = value; +} + +struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { + memset(tensor->data, 0, ggml_nbytes(tensor)); + return tensor; +} + +struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + default: + { + GGML_ASSERT(false); + } break; + } + + return tensor; +} + +struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + default: + { + GGML_ASSERT(false); + } break; + } + + return tensor; +} + +int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float *)(tensor->data))[i]; + } break; + default: + { + GGML_ASSERT(false); + } break; + } + + return 0.0f; +} + +void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float *)(tensor->data))[i]; + } break; + default: + { + GGML_ASSERT(false); + } break; + } + + return 0.0f; +} + +void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +void * ggml_get_data(const struct ggml_tensor * tensor) { + return tensor->data; +} + +float * ggml_get_data_f32(const struct ggml_tensor * tensor) { + assert(tensor->type == GGML_TYPE_F32); + return (float *)(tensor->data); +} + +enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_UNARY); + return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0); +} + +const char * ggml_get_name(const struct ggml_tensor * tensor) { + return tensor->name; +} + +struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) { + strncpy(tensor->name, name, sizeof(tensor->name)); + tensor->name[sizeof(tensor->name) - 1] = '\0'; + return tensor; +} + +struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...) { + va_list args; + va_start(args, fmt); + vsnprintf(tensor->name, sizeof(tensor->name), fmt, args); + va_end(args); + return tensor; +} + +struct ggml_tensor * ggml_view_tensor( + struct ggml_context * ctx, + struct ggml_tensor * src) { + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src, 0); + ggml_format_name(result, "%s (view)", src->name); + + for (int i = 0; i < GGML_MAX_DIMS; i++) { + result->nb[i] = src->nb[i]; + } + + return result; +} + +struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) { + struct ggml_object * obj = ctx->objects_begin; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs); + if (strcmp(cur->name, name) == 0) { + return cur; + } + } + + obj = obj->next; + } + + return NULL; +} + +//////////////////////////////////////////////////////////////////////////////// + +// ggml_dup + +static struct ggml_tensor * ggml_dup_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_DUP; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_dup_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_dup_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_dup_impl(ctx, a, true); +} + +// ggml_add + +static struct ggml_tensor * ggml_add_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + // TODO: support less-strict constraint + // GGML_ASSERT(ggml_can_repeat(b, a)); + GGML_ASSERT(ggml_can_repeat_rows(b, a)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + // TODO: support backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ADD; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_add_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add_impl(ctx, a, b, true); +} + +// ggml_add1 + +static struct ggml_tensor * ggml_add1_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_is_scalar(b)); + GGML_ASSERT(ggml_is_padded_1d(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ADD1; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_add1( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add1_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_add1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_add1_impl(ctx, a, b, true); +} + +// ggml_acc + +static struct ggml_tensor * ggml_acc_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, + bool inplace) { + GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(b->type == GGML_TYPE_F32); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_ACC; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_acc( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); +} + +struct ggml_tensor * ggml_acc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); +} + +// ggml_sub + +static struct ggml_tensor * ggml_sub_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SUB; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_sub_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_sub_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_sub_impl(ctx, a, b, true); +} + +// ggml_mul + +static struct ggml_tensor * ggml_mul_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + // TODO: support less-strict constraint + // GGML_ASSERT(ggml_can_repeat(b, a)); + GGML_ASSERT(ggml_can_repeat_rows(b, a)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + // TODO: support backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); + is_node = true; + } + + if (inplace) { + GGML_ASSERT(!is_node); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_MUL; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_mul_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_mul_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_mul_impl(ctx, a, b, true); +} + +// ggml_div + +static struct ggml_tensor * ggml_div_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + if (inplace) { + GGML_ASSERT(!is_node); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_DIV; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_div_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_div_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_div_impl(ctx, a, b, true); +} + +// ggml_sqr + +static struct ggml_tensor * ggml_sqr_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SQR; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqr_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqr_impl(ctx, a, true); +} + +// ggml_sqrt + +static struct ggml_tensor * ggml_sqrt_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SQRT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqrt_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_sqrt_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_sqrt_impl(ctx, a, true); +} + + +// ggml_log + +static struct ggml_tensor * ggml_log_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_LOG; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_log( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_log_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_log_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_log_impl(ctx, a, true); +} + +// ggml_sum + +struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = GGML_OP_SUM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + + +// ggml_sum_rows + +struct ggml_tensor * ggml_sum_rows( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + int64_t ne[4] = {1,1,1,1}; + for (int i=1; in_dims; ++i) { + ne[i] = a->ne[i]; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, a->n_dims, ne); + + result->op = GGML_OP_SUM_ROWS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_mean + +struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement + is_node = true; + } + + int64_t ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne); + + result->op = GGML_OP_MEAN; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_argmax + +struct ggml_tensor * ggml_argmax( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(ggml_is_matrix(a)); + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); + is_node = true; + } + + int64_t ne[GGML_MAX_DIMS] = { a->ne[1], 1, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, a->n_dims, ne); + + result->op = GGML_OP_ARGMAX; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_repeat + +struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_repeat(a, b)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); + + result->op = GGML_OP_REPEAT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_repeat_back + +struct ggml_tensor * ggml_repeat_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_repeat(b, a)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + if (ggml_are_same_shape(a, b) && !is_node) { + return a; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); + + result->op = GGML_OP_REPEAT_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_concat + +struct ggml_tensor * ggml_concat( + struct ggml_context* ctx, + struct ggml_tensor* a, + struct ggml_tensor* b) { + GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]); + + result->op = GGML_OP_CONCAT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_abs + +struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_ABS); +} + +struct ggml_tensor * ggml_abs_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS); +} + +// ggml_sgn + +struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SGN); +} + +struct ggml_tensor * ggml_sgn_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN); +} + +// ggml_neg + +struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_NEG); +} + +struct ggml_tensor * ggml_neg_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG); +} + +// ggml_step + +struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_STEP); +} + +struct ggml_tensor * ggml_step_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP); +} + +// ggml_tanh + +struct ggml_tensor * ggml_tanh( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_TANH); +} + +struct ggml_tensor * ggml_tanh_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH); +} + +// ggml_elu + +struct ggml_tensor * ggml_elu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_ELU); +} + +struct ggml_tensor * ggml_elu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU); +} + +// ggml_relu + +struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_RELU); +} + +struct ggml_tensor * ggml_relu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU); +} + +// ggml_gelu + +struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_GELU); +} + +struct ggml_tensor * ggml_gelu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU); +} + +// ggml_gelu_quick + +struct ggml_tensor * ggml_gelu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK); +} + +struct ggml_tensor * ggml_gelu_quick_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK); +} + +// ggml_silu + +struct ggml_tensor * ggml_silu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_SILU); +} + +struct ggml_tensor * ggml_silu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU); +} + +// ggml_silu_back + +struct ggml_tensor * ggml_silu_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + bool is_node = false; + + if (a->grad || b->grad) { + // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SILU_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_norm + +static struct ggml_tensor * ggml_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, &eps, sizeof(eps)); + + result->op = GGML_OP_NORM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_norm_impl(ctx, a, eps, false); +} + +struct ggml_tensor * ggml_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_norm_impl(ctx, a, eps, true); +} + +// ggml_rms_norm + +static struct ggml_tensor * ggml_rms_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, &eps, sizeof(eps)); + + result->op = GGML_OP_RMS_NORM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_rms_norm_impl(ctx, a, eps, false); +} + +struct ggml_tensor * ggml_rms_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_rms_norm_impl(ctx, a, eps, true); +} + +// ggml_rms_norm_back + +struct ggml_tensor * ggml_rms_norm_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float eps) { + bool is_node = false; + + if (a->grad) { + // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, &eps, sizeof(eps)); + + result->op = GGML_OP_RMS_NORM_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_group_norm + +static struct ggml_tensor * ggml_group_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + bool inplace) { + + bool is_node = false; + if (!inplace && (a->grad)) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_GROUP_NORM; + result->op_params[0] = n_groups; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = NULL; // TODO: maybe store epsilon here? + + return result; +} + +struct ggml_tensor * ggml_group_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups) { + return ggml_group_norm_impl(ctx, a, n_groups, false); +} + +struct ggml_tensor * ggml_group_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups) { + return ggml_group_norm_impl(ctx, a, n_groups, true); +} + +// ggml_mul_mat + +struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_mul_mat(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne); + + result->op = GGML_OP_MUL_MAT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_out_prod + +struct ggml_tensor * ggml_out_prod( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_can_out_prod(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + + result->op = GGML_OP_OUT_PROD; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_scale + +static struct ggml_tensor * ggml_scale_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_is_scalar(b)); + GGML_ASSERT(ggml_is_padded_1d(a)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SCALE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_scale_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_scale_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_scale_impl(ctx, a, b, true); +} + +// ggml_set + +static struct ggml_tensor * ggml_set_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, + bool inplace) { + GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b)); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + // make a view of the destination + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_SET; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_set( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false); +} + +struct ggml_tensor * ggml_set_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true); +} + +struct ggml_tensor * ggml_set_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset) { + return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false); +} + +struct ggml_tensor * ggml_set_1d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset) { + return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true); +} + +struct ggml_tensor * ggml_set_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset) { + return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); +} + +struct ggml_tensor * ggml_set_2d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset) { + return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); +} + + +// ggml_cpy + +static struct ggml_tensor * ggml_cpy_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + // make a view of the destination + struct ggml_tensor * result = ggml_view_tensor(ctx, b); + if (strlen(b->name) > 0) { + ggml_format_name(result, "%s (copy of %s)", b->name, a->name); + } else { + ggml_format_name(result, "%s (copy)", a->name); + } + + result->op = GGML_OP_CPY; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_cpy_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_cpy_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_cpy_impl(ctx, a, b, true); +} + +// ggml_cont + +static struct ggml_tensor * ggml_cont_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + ggml_format_name(result, "%s (cont)", a->name); + + result->op = GGML_OP_CONT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_cont( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cont_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_cont_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cont_impl(ctx, a, true); +} + +// ggml_reshape + +struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_contiguous(b)); + GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + if (b->grad) { + // gradient propagation is not supported + //GGML_ASSERT(false); + } + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_reshape_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[1] = { ne0 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[2] = { ne0, ne1 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[3] = { ne0, ne1, ne2 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_reshape_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0); + ggml_format_name(result, "%s (reshaped)", a->name); + + result->op = GGML_OP_RESHAPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +static struct ggml_tensor * ggml_view_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_dims, + const int64_t * ne, + size_t offset) { + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset); + ggml_format_name(result, "%s (view)", a->name); + + ggml_set_op_params(result, &offset, sizeof(offset)); + + result->op = GGML_OP_VIEW; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_view_1d + +struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + size_t offset) { + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset); + + return result; +} + +// ggml_view_2d + +struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + size_t nb1, + size_t offset) { + + const int64_t ne[2] = { ne0, ne1 }; + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset); + + result->nb[1] = nb1; + result->nb[2] = result->nb[1]*ne1; + result->nb[3] = result->nb[2]; + + return result; +} + +// ggml_view_3d + +struct ggml_tensor * ggml_view_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + size_t nb1, + size_t nb2, + size_t offset) { + + const int64_t ne[3] = { ne0, ne1, ne2 }; + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset); + + result->nb[1] = nb1; + result->nb[2] = nb2; + result->nb[3] = result->nb[2]*ne2; + + return result; +} + +// ggml_view_4d + +struct ggml_tensor * ggml_view_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { + + const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset); + + result->nb[1] = nb1; + result->nb[2] = nb2; + result->nb[3] = nb3; + + return result; +} + +// ggml_permute + +struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3) { + GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS); + GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS); + GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS); + GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS); + + GGML_ASSERT(axis0 != axis1); + GGML_ASSERT(axis0 != axis2); + GGML_ASSERT(axis0 != axis3); + GGML_ASSERT(axis1 != axis2); + GGML_ASSERT(axis1 != axis3); + GGML_ASSERT(axis2 != axis3); + + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + ggml_format_name(result, "%s (permuted)", a->name); + + int ne[GGML_MAX_DIMS]; + int nb[GGML_MAX_DIMS]; + + ne[axis0] = a->ne[0]; + ne[axis1] = a->ne[1]; + ne[axis2] = a->ne[2]; + ne[axis3] = a->ne[3]; + + nb[axis0] = a->nb[0]; + nb[axis1] = a->nb[1]; + nb[axis2] = a->nb[2]; + nb[axis3] = a->nb[3]; + + result->ne[0] = ne[0]; + result->ne[1] = ne[1]; + result->ne[2] = ne[2]; + result->ne[3] = ne[3]; + + result->nb[0] = nb[0]; + result->nb[1] = nb[1]; + result->nb[2] = nb[2]; + result->nb[3] = nb[3]; + + result->op = GGML_OP_PERMUTE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + int32_t params[] = { axis0, axis1, axis2, axis3 }; + ggml_set_op_params(result, params, sizeof(params)); + + return result; +} + +// ggml_transpose + +struct ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct ggml_tensor * a) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + ggml_format_name(result, "%s (transposed)", a->name); + + result->ne[0] = a->ne[1]; + result->ne[1] = a->ne[0]; + + result->nb[0] = a->nb[1]; + result->nb[1] = a->nb[0]; + + result->op = GGML_OP_TRANSPOSE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_get_rows + +struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + // TODO: implement non F32 return + //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]); + + result->op = GGML_OP_GET_ROWS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_get_rows_back + +struct ggml_tensor * ggml_get_rows_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0])); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + // TODO: implement non F32 return + //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]); + + result->op = GGML_OP_GET_ROWS_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +// ggml_diag + +struct ggml_tensor * ggml_diag( + struct ggml_context * ctx, + struct ggml_tensor * a) { + GGML_ASSERT(a->ne[1] == 1); + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, MAX(a->n_dims, 2), ne); + + result->op = GGML_OP_DIAG; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + + +// ggml_diag_mask_inf + +static struct ggml_tensor * ggml_diag_mask_inf_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + bool inplace) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[] = { n_past }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_DIAG_MASK_INF; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + return ggml_diag_mask_inf_impl(ctx, a, n_past, false); +} + +struct ggml_tensor * ggml_diag_mask_inf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + return ggml_diag_mask_inf_impl(ctx, a, n_past, true); +} + +// ggml_diag_mask_zero + +static struct ggml_tensor * ggml_diag_mask_zero_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + bool inplace) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[] = { n_past }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_DIAG_MASK_ZERO; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_diag_mask_zero( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + return ggml_diag_mask_zero_impl(ctx, a, n_past, false); +} + +struct ggml_tensor * ggml_diag_mask_zero_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past) { + return ggml_diag_mask_zero_impl(ctx, a, n_past, true); +} + +// ggml_soft_max + +static struct ggml_tensor * ggml_soft_max_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SOFT_MAX; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_soft_max( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_soft_max_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_soft_max_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_soft_max_impl(ctx, a, true); +} + + +// ggml_soft_max_back + +static struct ggml_tensor * ggml_soft_max_back_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; // TODO : implement backward pass + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_SOFT_MAX_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_soft_max_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_soft_max_back_impl(ctx, a, b, false); +} + +struct ggml_tensor * ggml_soft_max_back_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_soft_max_back_impl(ctx, a, b, true); +} + +// ggml_rope + +static struct ggml_tensor * ggml_rope_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx, + float freq_base, + float freq_scale, + float xpos_base, + bool xpos_down, + bool inplace) { + GGML_ASSERT(n_past >= 0); + bool is_node = false; + + if (a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + int32_t params[8] = { n_past, n_dims, mode, n_ctx }; + memcpy(params + 4, &freq_base, sizeof(float)); + memcpy(params + 5, &freq_scale, sizeof(float)); + memcpy(params + 6, &xpos_base, sizeof(float)); + memcpy(params + 7, &xpos_down, sizeof(bool)); + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_ROPE; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx) { + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); +} + +struct ggml_tensor * ggml_rope_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx) { + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); +} + +struct ggml_tensor * ggml_rope_custom( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx, + float freq_base, + float freq_scale) { + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); +} + +struct ggml_tensor * ggml_rope_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx, + float freq_base, + float freq_scale) { + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); +} + +struct ggml_tensor * ggml_rope_xpos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + float base, + bool down) { + return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); +} + +// ggml_rope_back + +struct ggml_tensor * ggml_rope_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx, + float freq_base, + float freq_scale, + float xpos_base, + bool xpos_down) { + GGML_ASSERT(n_past >= 0); + GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); + + bool is_node = false; + + if (a->grad) { + is_node = false; // TODO: implement backward + } + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + int32_t params[8] = { n_past, n_dims, mode, n_ctx }; + memcpy(params + 4, &freq_base, sizeof(float)); + memcpy(params + 5, &freq_scale, sizeof(float)); + memcpy(params + 6, &xpos_base, sizeof(float)); + memcpy(params + 7, &xpos_down, sizeof(bool)); + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_ROPE_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_alibi + +struct ggml_tensor * ggml_alibi( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_head, + float bias_max) { + GGML_ASSERT(n_past >= 0); + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + int32_t op_params[3] = { n_past, n_head }; + memcpy(op_params + 2, &bias_max, sizeof(float)); + ggml_set_op_params(result, op_params, sizeof(op_params)); + + result->op = GGML_OP_ALIBI; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_clamp + +struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // TODO: when implement backward, fix this: + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + float params[] = { min, max }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CLAMP; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_conv_1d + +static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; +} + +GGML_API struct ggml_tensor * ggml_conv_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { + GGML_ASSERT(ggml_is_matrix(b)); + GGML_ASSERT(a->ne[1] == b->ne[1]); + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), + a->ne[2], 1, 1, + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + + int32_t params[] = { s0, p0, d0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CONV_1D; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_conv_1d_ph + +struct ggml_tensor* ggml_conv_1d_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s, + int d) { + return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); +} + +// ggml_conv_2d + +struct ggml_tensor * ggml_conv_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + + GGML_ASSERT(a->ne[2] == b->ne[2]); + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), + ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1), + a->ne[3], b->ne[3], + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + int32_t params[] = { s0, s1, p0, p1, d0, d1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CONV_2D; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; + +} + +// ggml_conv_2d_sk_p0 + +struct ggml_tensor * ggml_conv_2d_sk_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1); +} + +// ggml_conv_2d_s1_ph + +struct ggml_tensor * ggml_conv_2d_s1_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1); +} + +// ggml_conv_transpose_2d_p0 + +static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { + return (ins - 1) * s - 2 * p + ks; +} + +struct ggml_tensor * ggml_conv_transpose_2d_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride) { + GGML_ASSERT(a->ne[3] == b->ne[2]); + + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/), + ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/), + a->ne[2], b->ne[3], + }; + + struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_set_op_params_i32(result, 0, stride); + + result->op = GGML_OP_CONV_TRANSPOSE_2D; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_pool_* + +static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) { + return (ins + 2 * p - ks) / s + 1; +} + +// ggml_pool_1d + +struct ggml_tensor * ggml_pool_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, + int s0, + int p0) { + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[3] = { + ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), + a->ne[1], + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + + int32_t params[] = { op, k0, s0, p0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_POOL_1D; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_pool_2d + +struct ggml_tensor * ggml_pool_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + int p0, + int p1) { + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[3] = { + ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), + ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), + a->ne[2], + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); + + int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_POOL_2D; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_upscale + +static struct ggml_tensor * ggml_upscale_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, + a->ne[0] * scale_factor, + a->ne[1] * scale_factor, + a->ne[2], a->ne[3]); + + result->op = GGML_OP_UPSCALE; + result->op_params[0] = scale_factor; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = NULL; + + return result; +} + +struct ggml_tensor * ggml_upscale( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor) { + return ggml_upscale_impl(ctx, a, scale_factor); +} + +// ggml_flash_attn + +struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, q->n_dims, q->ne); + + int32_t t = masked ? 1 : 0; + ggml_set_op_params(result, &t, sizeof(t)); + + result->op = GGML_OP_FLASH_ATTN; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + + return result; +} + +// ggml_flash_ff + +struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1) { + GGML_ASSERT(ggml_can_mul_mat(b0, a)); + // TODO: more checks + + bool is_node = false; + + if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { + is_node = true; + } + + //struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne); + + result->op = GGML_OP_FLASH_FF; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b0; + result->src[2] = b1; + result->src[3] = c0; + result->src[4] = c1; + + return result; +} + +// ggml_flash_attn_back + +struct ggml_tensor * ggml_flash_attn_back( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * d, + bool masked) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + + // d shape [D,N,ne2,ne3] + // q shape [D,N,ne2,ne3] + // k shape [D,M,ne2,ne3] + // v shape [M,D,ne2,ne3] + + const int64_t D = q->ne[0]; + const int64_t N = q->ne[1]; + const int64_t M = k->ne[1]; + const int64_t ne2 = q->ne[2]; + const int64_t ne3 = q->ne[3]; + + GGML_ASSERT(k->ne[0] == D); + GGML_ASSERT(v->ne[0] == M); + GGML_ASSERT(v->ne[1] == D); + GGML_ASSERT(d->ne[0] == D); + GGML_ASSERT(d->ne[1] == N); + GGML_ASSERT(k->ne[2] == ne2); + GGML_ASSERT(k->ne[3] == ne3); + GGML_ASSERT(v->ne[2] == ne2); + GGML_ASSERT(v->ne[3] == ne3); + GGML_ASSERT(d->ne[2] == ne2); + GGML_ASSERT(d->ne[3] == ne3); + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + // when using this operation (in backwards pass) these grads are set. + // we don't want to create (big) grad of our result, so is_node is false. + is_node = false; + } + + // store gradients of q, k and v as continuous tensors concatenated in result. + // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3] + // gradq->data = result->data + // gradk->data = result->data + nb0*D*N*ne2*ne3 + // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3 + // note: v and gradv are actually transposed, i.e. v->ne[0] != D. + int64_t ne[4] = {D,M+N+M,ne2,ne3}; + + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + int32_t masked_i = masked ? 1 : 0; + ggml_set_op_params(result, &masked_i, sizeof(masked_i)); + + result->op = GGML_OP_FLASH_ATTN_BACK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = d; + + return result; +} + +// ggml_win_part + +struct ggml_tensor * ggml_win_part( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w) { + GGML_ASSERT(a->ne[3] == 1); + GGML_ASSERT(a->type == GGML_TYPE_F32); + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // padding + const int px = (w - a->ne[1]%w)%w; + const int py = (w - a->ne[2]%w)%w; + + const int npx = (px + a->ne[1])/w; + const int npy = (py + a->ne[2])/w; + const int np = npx*npy; + + const int64_t ne[4] = { a->ne[0], w, w, np, }; + + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + int32_t params[] = { npx, npy, w }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_WIN_PART; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_win_unpart + +struct ggml_tensor * ggml_win_unpart( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w0, + int h0, + int w) { + GGML_ASSERT(a->type == GGML_TYPE_F32); + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); + + int32_t params[] = { w }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_WIN_UNPART; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_get_rel_pos + +struct ggml_tensor * ggml_get_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + int qh, + int kh) { + GGML_ASSERT(qh == kh); + GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]); + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne); + + result->op = GGML_OP_GET_REL_POS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = NULL; + + return result; +} + +// ggml_add_rel_pos + +static struct ggml_tensor * ggml_add_rel_pos_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph, + bool inplace) { + GGML_ASSERT(ggml_are_same_shape(pw, ph)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_contiguous(pw)); + GGML_ASSERT(ggml_is_contiguous(ph)); + GGML_ASSERT(ph->type == GGML_TYPE_F32); + GGML_ASSERT(pw->type == GGML_TYPE_F32); + GGML_ASSERT(pw->ne[3] == a->ne[2]); + GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]); + GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]); + + bool is_node = false; + + if (!inplace && (a->grad || pw->grad || ph->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + ggml_set_op_params_i32(result, 0, inplace ? 1 : 0); + + result->op = GGML_OP_ADD_REL_POS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = pw; + result->src[2] = ph; + + return result; +} + + +struct ggml_tensor * ggml_add_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph) { + return ggml_add_rel_pos_impl(ctx, a, pw, ph, false); +} + +struct ggml_tensor * ggml_add_rel_pos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph) { + return ggml_add_rel_pos_impl(ctx, a, pw, ph, true); +} + +// gmml_unary + +static struct ggml_tensor * ggml_unary_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + result->op = GGML_OP_UNARY; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op) { + return ggml_unary_impl(ctx, a, op, false); +} + +struct ggml_tensor * ggml_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op) { + return ggml_unary_impl(ctx, a, op, true); +} + +// ggml_map_unary + +static struct ggml_tensor * ggml_map_unary_impl_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_unary_op_f32_t fun, + bool inplace) { + bool is_node = false; + + if (!inplace && a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); + + result->op = GGML_OP_MAP_UNARY; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_map_unary_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_unary_op_f32_t fun) { + return ggml_map_unary_impl_f32(ctx, a, fun, false); +} + +struct ggml_tensor * ggml_map_unary_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_unary_op_f32_t fun) { + return ggml_map_unary_impl_f32(ctx, a, fun, true); +} + +// ggml_map_binary + +static struct ggml_tensor * ggml_map_binary_impl_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_binary_op_f32_t fun, + bool inplace) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); + + result->op = GGML_OP_MAP_BINARY; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_map_binary_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_binary_op_f32_t fun) { + return ggml_map_binary_impl_f32(ctx, a, b, fun, false); +} + +struct ggml_tensor * ggml_map_binary_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_binary_op_f32_t fun) { + return ggml_map_binary_impl_f32(ctx, a, b, fun, true); +} + +// ggml_map_custom1_f32 + +static struct ggml_tensor * ggml_map_custom1_impl_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_f32_t fun, + bool inplace) { + bool is_node = false; + + if (!inplace && a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); + + result->op = GGML_OP_MAP_CUSTOM1_F32; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_map_custom1_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_f32_t fun) { + return ggml_map_custom1_impl_f32(ctx, a, fun, false); +} + +struct ggml_tensor * ggml_map_custom1_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_f32_t fun) { + return ggml_map_custom1_impl_f32(ctx, a, fun, true); +} + +// ggml_map_custom2_f32 + +static struct ggml_tensor * ggml_map_custom2_impl_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_f32_t fun, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); + + result->op = GGML_OP_MAP_CUSTOM2_F32; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_map_custom2_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_f32_t fun) { + return ggml_map_custom2_impl_f32(ctx, a, b, fun, false); +} + +struct ggml_tensor * ggml_map_custom2_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_f32_t fun) { + return ggml_map_custom2_impl_f32(ctx, a, b, fun, true); +} + +// ggml_map_custom3_f32 + +static struct ggml_tensor * ggml_map_custom3_impl_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_f32_t fun, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad || b->grad || c->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); + + result->op = GGML_OP_MAP_CUSTOM3_F32; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +struct ggml_tensor * ggml_map_custom3_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_f32_t fun) { + return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false); +} + +struct ggml_tensor * ggml_map_custom3_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_f32_t fun) { + return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true); +} + +// ggml_map_custom1 +struct ggml_map_custom1_op_params { + ggml_custom1_op_t fun; + int n_tasks; + void * userdata; +}; + +static struct ggml_tensor * ggml_map_custom1_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { + GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); + + bool is_node = false; + + if (!inplace && a->grad) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + struct ggml_map_custom1_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + + result->op = GGML_OP_MAP_CUSTOM1; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_map_custom1( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false); +} + +struct ggml_tensor * ggml_map_custom1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true); +} + +// ggml_map_custom2 + +struct ggml_map_custom2_op_params { + ggml_custom2_op_t fun; + int n_tasks; + void * userdata; +}; + +static struct ggml_tensor * ggml_map_custom2_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { + GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + struct ggml_map_custom2_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + + result->op = GGML_OP_MAP_CUSTOM2; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_map_custom2( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false); +} + +struct ggml_tensor * ggml_map_custom2_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true); +} + +// ggml_map_custom3 + +struct ggml_map_custom3_op_params { + ggml_custom3_op_t fun; + int n_tasks; + void * userdata; +}; + +static struct ggml_tensor * ggml_map_custom3_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { + GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad || c->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + struct ggml_map_custom3_op_params params = { + /*.fun =*/ fun, + /*.n_tasks =*/ n_tasks, + /*.userdata =*/ userdata + }; + ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); + + result->op = GGML_OP_MAP_CUSTOM3; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +struct ggml_tensor * ggml_map_custom3( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false); +} + +struct ggml_tensor * ggml_map_custom3_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { + return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); +} + + + +// ggml_cross_entropy_loss + +struct ggml_tensor * ggml_cross_entropy_loss( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); + + result->op = GGML_OP_CROSS_ENTROPY_LOSS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_cross_entropy_loss_back + +struct ggml_tensor * ggml_cross_entropy_loss_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_is_scalar(c)); + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; + result->grad = NULL; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_set_param( + struct ggml_context * ctx, + struct ggml_tensor * tensor) { + tensor->is_param = true; + + GGML_ASSERT(tensor->grad == NULL); + tensor->grad = ggml_dup_tensor(ctx, tensor); +} + +// ggml_compute_forward_dup + +static void ggml_compute_forward_dup_same_cont( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == dst->type); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const size_t nb00 = src0->nb[0]; + const size_t nb0 = dst->nb[0]; + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by elements + const int ne = ggml_nelements(dst); + const int dr = (ne + nth - 1) / nth; + const int ie0 = dr * ith; + const int ie1 = MIN(ie0 + dr, ne); + + if (ie0 < ie1) { + memcpy( + ((char *) dst->data + ie0*nb0), + ((char *) src0->data + ie0*nb00), + (ie1 - ie0) * ggml_type_size(src0->type)); + } + +} +static void ggml_compute_forward_dup_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS; + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { + ggml_compute_forward_dup_same_cont(params, src0, dst); + return; + } + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy + + if (ggml_is_contiguous(dst)) { + if (nb00 == sizeof(ggml_fp16_t)) { + if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (type_traits[dst->type].from_float) { + ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; + float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + size_t id = 0; + size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]); + } + + quantize_row_q(src0_f32, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } + return; + } + + // dst counters + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t)); + + if (++i10 == ne00) { + i10 = 0; + if (++i11 == ne01) { + i11 = 0; + if (++i12 == ne02) { + i12 = 0; + if (++i13 == ne03) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } +} + +static void ggml_compute_forward_dup_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS; + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { + ggml_compute_forward_dup_same_cont(params, src0, dst); + return; + } + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + if (ggml_is_contiguous(dst)) { + // TODO: simplify + if (nb00 == sizeof(float)) { + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (type_traits[dst->type].from_float) { + ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; + + size_t id = 0; + size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + quantize_row_q(src0_ptr, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } + + return; + } + + // dst counters + + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(float)); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } +} + +static void ggml_compute_forward_dup( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { + ggml_compute_forward_dup_same_cont(params, src0, dst); + return; + } + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_dup_f16(params, src0, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_dup_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_add + +static void ggml_compute_forward_add_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + +#ifdef GGML_USE_ACCELERATE + vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00); +#else + ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr); +#endif + // } + // } + } + } else { + // src1 is not contiguous + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int i0 = 0; i0 < ne0; i0++) { + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); + + dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; + } + } + } +} + +static void ggml_compute_forward_add_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + } + } + } + else { + // src1 is not contiguous + GGML_ASSERT(false); + } +} + +static void ggml_compute_forward_add_f16_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(ggml_fp16_t)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i])); + } + } + } + else { + // src1 is not contiguous + GGML_ASSERT(false); + } +} + +static void ggml_compute_forward_add_q_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; + ggml_from_float_t const quantize_row_q = type_traits[type].from_float; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ggml_is_quantized(src0->type)); + GGML_ASSERT(dst->type == src0->type); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + // src1 and dst are same shape as src0 => same indices + const int i13 = i03; + const int i12 = i02; + const int i11 = i01; + + const int i3 = i03; + const int i2 = i02; + const int i1 = i01; + + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)); + void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + // unquantize row from src0 to temp buffer + dequantize_row_q(src0_row, wdata, ne00); + // add src1 + ggml_vec_acc_f32(ne00, wdata, src1_row); + // quantize row to dst + quantize_row_q(wdata, dst_row, ne00); + } +} + +static void ggml_compute_forward_add( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + { + if (src1->type == GGML_TYPE_F16) { + ggml_compute_forward_add_f16_f16(params, src0, src1, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add_f16_f32(params, src0, src1, dst); + } + else { + GGML_ASSERT(false); + } + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + { + ggml_compute_forward_add_q_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_add1 + +static void ggml_compute_forward_add1_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + +#ifdef GGML_USE_ACCELERATE + UNUSED(ggml_vec_add1_f32); + + vDSP_vadd( + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, + (float *) ((char *) src1->data), 0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, + ne0); +#else + ggml_vec_add1_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + *(float *) src1->data); +#endif + } +} + +static void ggml_compute_forward_add1_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void ggml_compute_forward_add1_f16_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // scalar to add + const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void ggml_compute_forward_add1_q_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS; + + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; + ggml_from_float_t const quantize_row_q = type_traits[type].from_float; + + // we don't support permuted src0 + GGML_ASSERT(nb00 == ggml_type_size(type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ggml_is_quantized(src0->type)); + GGML_ASSERT(dst->type == src0->type); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03)); + void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 )); + + assert(ne0 % 32 == 0); + + // unquantize row from src0 to temp buffer + dequantize_row_q(src0_row, wdata, ne0); + // add src1 + ggml_vec_acc1_f32(ne0, wdata, v); + // quantize row to dst + quantize_row_q(wdata, dst_row, ne0); + } +} + +static void ggml_compute_forward_add1( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add1_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + { + if (src1->type == GGML_TYPE_F16) { + ggml_compute_forward_add1_f16_f16(params, src0, src1, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add1_f16_f32(params, src0, src1, dst); + } + else { + GGML_ASSERT(false); + } + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + { + ggml_compute_forward_add1_q_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + + +// ggml_compute_forward_acc + +static void ggml_compute_forward_acc_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + + // view src0 and dst with these strides and data offset inbytes during acc + // nb0 is implicitely element_size because src0 and dst are contiguous + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace && (params->type == GGML_TASK_INIT)) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); + } + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src1); + const int nc = src1->ne[0]; + + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); + + // src0 and dst as viewed during acc + const size_t nb0 = ggml_element_size(src0); + + const size_t nb00 = nb0; + const size_t nb01 = nb1; + const size_t nb02 = nb2; + const size_t nb03 = nb3; + + GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst)); + GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0)); + + GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are viewed with shape of src1 and offset + // => same indices + const int i3 = ir/(ne12*ne11); + const int i2 = (ir - i3*ne12*ne11)/ne11; + const int i1 = (ir - i3*ne12*ne11 - i2*ne11); + +#ifdef GGML_USE_ACCELERATE + vDSP_vadd( + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1, + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc); +#else + ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); +#endif + } +} + +static void ggml_compute_forward_acc( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_acc_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_sub + +static void ggml_compute_forward_sub_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + for (int ir = 0; ir < nr; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + +#ifdef GGML_USE_ACCELERATE + vDSP_vsub( + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, + ne0); +#else + ggml_vec_sub_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); +#endif + // } + // } + } + } else { + // src1 is not contiguous + for (int ir = 0; ir < nr; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i0 = 0; i0 < ne0; i0++) { + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + + dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; + } + } + } +} + +static void ggml_compute_forward_sub( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sub_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_mul + +static void ggml_compute_forward_mul_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + const int ith = params->ith; + const int nth = params->nth; + +#ifdef GGML_USE_CLBLAST + if (src1->backend == GGML_BACKEND_GPU) { + if (ith == 0) { + ggml_cl_mul(src0, src1, dst); + } + return; + } +#endif + + const int64_t nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(ne00 == ne10); + + if (nb10 == sizeof(float)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + +#ifdef GGML_USE_ACCELERATE + UNUSED(ggml_vec_mul_f32); + + vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00); +#else + ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr); +#endif + // } + // } + } + } else { + // src1 is not contiguous + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne00; i0++) { + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); + + dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); + } + } + } +} + +static void ggml_compute_forward_mul( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now"); + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_mul_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_div + +static void ggml_compute_forward_div_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + for (int ir = 0; ir < nr; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + +#ifdef GGML_USE_ACCELERATE + UNUSED(ggml_vec_div_f32); + + vDSP_vdiv( + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, + ne0); +#else + ggml_vec_div_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); +#endif + // } + // } + } + } else { + // src1 is not contiguous + for (int ir = 0; ir < nr; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i0 = 0; i0 < ne0; i0++) { + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + + dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr); + } + } + } +} + +static void ggml_compute_forward_div( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_div_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_sqr + +static void ggml_compute_forward_sqr_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqr_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sqr( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sqr_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_sqrt + +static void ggml_compute_forward_sqrt_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqrt_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sqrt( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sqrt_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + + +// ggml_compute_forward_log + +static void ggml_compute_forward_log_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_log_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_log( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_log_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_sum + +static void ggml_compute_forward_sum_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_is_scalar(dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(float)); + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); + + ggml_float sum = 0; + ggml_float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f32_ggf(ne00, + &row_sum, + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + sum += row_sum; + } + } + } + ((float *) dst->data)[0] = sum; +} + +static void ggml_compute_forward_sum_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_is_scalar(dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(ggml_fp16_t)); + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); + + float sum = 0; + float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f16_ggf(ne00, + &row_sum, + (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); + sum += row_sum; + } + } + } + ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum); +} + +static void ggml_compute_forward_sum( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sum_f32(params, src0, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_sum_f16(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_sum_rows + +static void ggml_compute_forward_sum_rows_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(dst->nb[0] == sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS; + + GGML_ASSERT(ne0 == 1); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + float row_sum = 0; + ggml_vec_sum_f32(ne00, &row_sum, src_row); + dst_row[0] = row_sum; + } + } + } +} + +static void ggml_compute_forward_sum_rows( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sum_rows_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_mean + +static void ggml_compute_forward_mean_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS; + + assert(ne0 == 1); + assert(ne1 == ne01); + assert(ne2 == ne02); + assert(ne3 == ne03); + + UNUSED(ne0); + UNUSED(ne1); + UNUSED(ne2); + UNUSED(ne3); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f32(ne00, + (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + + *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; + } + } + } +} + +static void ggml_compute_forward_mean( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_mean_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_argmax + +static void ggml_compute_forward_argmax_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + assert(dst->nb[0] == sizeof(float)); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + + const size_t nb01 = src0->nb[1]; + const size_t nb0 = dst->nb[0]; + + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src = (float *) ((char *) src0->data + i1*nb01); + int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0); + int v = 0; + ggml_vec_argmax_f32(ne00, &v, src); + dst_[0] = v; + } +} + +static void ggml_compute_forward_argmax( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_argmax_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_repeat + +static void ggml_compute_forward_repeat_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_can_repeat(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS; + + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_vec_cpy_f32(ne00, + (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0), + (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01)); + } + } + } + } + } + } + } +} + +static void ggml_compute_forward_repeat( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_repeat_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_repeat_back + +static void ggml_compute_forward_repeat_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_can_repeat(dst, src0)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS; + + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne00/ne0); + const int nr1 = (int)(ne01/ne1); + const int nr2 = (int)(ne02/ne2); + const int nr3 = (int)(ne03/ne3); + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (ggml_is_contiguous(dst)) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + } else { + for (int k3 = 0; k3 < ne3; k3++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int k1 = 0; k1 < ne1; k1++) { + ggml_vec_set_f32(ne0, + (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3), + 0); + } + } + } + } + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne3; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne1; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_vec_acc_f32(ne0, + (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1), + (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00)); + } + } + } + } + } + } + } +} + +static void ggml_compute_forward_repeat_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_repeat_back_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_concat + +static void ggml_compute_forward_concat_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + + GGML_TENSOR_BINARY_OP_LOCALS; + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2++) { + if (i2 < ne02) { // src0 + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03); + + float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + *y = *x; + } + } + } // src1 + else { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13); + + float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + *y = *x; + } + } + } + } + } +} + +static void ggml_compute_forward_concat( + const struct ggml_compute_params* params, + const struct ggml_tensor* src0, + const struct ggml_tensor* src1, + struct ggml_tensor* dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_concat_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_abs + +static void ggml_compute_forward_abs_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_abs_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_abs( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_abs_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_sgn + +static void ggml_compute_forward_sgn_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sgn_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sgn( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sgn_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_neg + +static void ggml_compute_forward_neg_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_neg_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_neg( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_neg_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_step + +static void ggml_compute_forward_step_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_step_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_step( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_step_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_tanh + +static void ggml_compute_forward_tanh_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_tanh_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_tanh( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_tanh_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_elu + +static void ggml_compute_forward_elu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_elu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_elu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_elu_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_relu + +static void ggml_compute_forward_relu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_relu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_relu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_relu_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_gelu + +static void ggml_compute_forward_gelu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_gelu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gelu_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_gelu_quick + +static void ggml_compute_forward_gelu_quick_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_quick_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_gelu_quick( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gelu_quick_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_silu + +static void ggml_compute_forward_silu_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_silu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_silu( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_silu_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_silu_back + +static void ggml_compute_forward_silu_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * grad, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); + GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0, grad)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_silu_backward_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1])), + (float *) ((char *) grad->data + i1*(grad->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_silu_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * grad, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_silu_back_f32(params, src0, grad, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_norm + +static void ggml_compute_forward_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)x[i00]; + } + + float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + ggml_float sum2 = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + float v = x[i00] - mean; + y[i00] = v; + sum2 += (ggml_float)(v*v); + } + + float variance = sum2/ne00; + const float scale = 1.0f/sqrtf(variance + eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_norm_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_group_rms_norm + +static void ggml_compute_forward_rms_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)(x[i00] * x[i00]); + } + + const float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + memcpy(y, x, ne00 * sizeof(float)); + // for (int i00 = 0; i00 < ne00; i00++) { + // y[i00] = x[i00]; + // } + + const float scale = 1.0f/sqrtf(mean + eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_rms_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rms_norm_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_compute_forward_rms_norm_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + // src1 is same shape as src0 => same indices + const int64_t i11 = i01; + const int64_t i12 = i02; + const int64_t i13 = i03; + + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); + + ggml_float sum_xx = 0.0; + ggml_float sum_xdz = 0.0; + + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum_xx += (ggml_float)(x[i00] * x[i00]); + sum_xdz += (ggml_float)(x[i00] * dz[i00]); + } + + //const float mean = (float)(sum_xx)/ne00; + const float mean_eps = (float)(sum_xx)/ne00 + eps; + const float sum_eps = (float)(sum_xx) + eps*ne00; + //const float mean_xdz = (float)(sum_xdz)/ne00; + // we could cache rms from forward pass to improve performance. + // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms. + //const float rms = sqrtf(mean_eps); + const float rrms = 1.0f / sqrtf(mean_eps); + //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3) + + { + // z = rms_norm(x) + // + // rms_norm(src0) = + // scale( + // src0, + // div( + // 1, + // sqrt( + // add( + // scale( + // sum( + // sqr( + // src0)), + // (1.0/N)), + // eps)))); + + // postorder: + // ## op args grad + // 00 param src0 grad[#00] + // 01 const 1 + // 02 sqr (#00) grad[#02] + // 03 sum (#02) grad[#03] + // 04 const 1/N + // 05 scale (#03, #04) grad[#05] + // 06 const eps + // 07 add (#05, #06) grad[#07] + // 08 sqrt (#07) grad[#08] + // 09 div (#01,#08) grad[#09] + // 10 scale (#00,#09) grad[#10] + // + // backward pass, given grad[#10] + // #10: scale + // grad[#00] += scale(grad[#10],#09) + // grad[#09] += sum(mul(grad[#10],#00)) + // #09: div + // grad[#08] += neg(mul(grad[#09], div(#09,#08))) + // #08: sqrt + // grad[#07] += mul(grad[#08], div(0.5, #08)) + // #07: add + // grad[#05] += grad[#07] + // #05: scale + // grad[#03] += scale(grad[#05],#04) + // #03: sum + // grad[#02] += repeat(grad[#03], #02) + // #02: + // grad[#00] += scale(mul(#00, grad[#02]), 2.0) + // + // substitute and simplify: + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#02] = repeat(grad[#03], #02) + // grad[#02] = repeat(scale(grad[#05],#04), #02) + // grad[#02] = repeat(scale(grad[#07],#04), #02) + // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N))) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps))) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps)) + // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps)) + // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps)) + // a = b*c + d*e + // a = b*c*f/f + d*e*f/f + // a = (b*c*f + d*e*f)*(1/f) + // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c)) + // a = (b + d*e/c)*c + // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps) + // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms + // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms + // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms + // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms + // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms + // a = (dz + x*div(-mean_xdz,mean_eps))*rrms + // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms) + // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + } + // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + // post-order: + // dx := x + // dx := scale(dx,-mean_xdz/mean_eps) + // dx := add(dx, dz) + // dx := scale(dx, rrms) + float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + ggml_vec_cpy_f32 (ne00, dx, x); + // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); + ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); + ggml_vec_acc_f32 (ne00, dx, dz); + ggml_vec_scale_f32(ne00, dx, rrms); + } + } + } +} + +static void ggml_compute_forward_rms_norm_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rms_norm_back_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_group_norm + +static void ggml_compute_forward_group_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS; + + const float eps = 1e-6f; // TODO: make this a parameter + + // TODO: optimize + + int n_channels = src0->ne[2]; + int n_groups = dst->op_params[0]; + int n_channels_per_group = (n_channels + n_groups - 1) / n_groups; + for (int i = ith; i < n_groups; i+=nth) { + int start = i * n_channels_per_group; + int end = start + n_channels_per_group; + if (end > n_channels) { + end = n_channels; + } + int step = end - start; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + ggml_float sum = 0.0; + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); + + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)x[i00]; + } + } + } + float mean = sum / (ne00 * ne01 * step); + ggml_float sum2 = 0.0; + + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); + + float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); + + for (int64_t i00 = 0; i00 < ne00; i00++) { + float v = x[i00] - mean; + y[i00] = v; + sum2 += (ggml_float)(v * v); + } + } + } + float variance = sum2 / (ne00 * ne01 * step); + const float scale = 1.0f / sqrtf(variance + eps); + + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); + ggml_vec_scale_f32(ne00, y, scale); + } + } + } + } +} + +static void ggml_compute_forward_group_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_group_norm_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_mul_mat + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) +// helper function to determine if it is better to use BLAS or not +// for large matrices, BLAS is faster +static bool ggml_compute_forward_mul_mat_use_blas( + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + //const int64_t ne00 = src0->ne[0]; + //const int64_t ne01 = src0->ne[1]; + + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + if (ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { + + /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ + return true; + } + + return false; +} +#endif + +static void ggml_compute_forward_mul_mat( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + + const bool src1_cont = ggml_is_contiguous(src1); + + ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + +#if defined(GGML_USE_CLBLAST) + if (ggml_cl_can_mul_mat(src0, src1, dst)) { + // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension + // ref: https://github.com/ggerganov/ggml/pull/224 + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + + if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { + ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); + } + return; + } +#endif + +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { + if (params->ith != 0) { + return; + } + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + // broadcast src0 into src1 across 2nd,3rd dimension + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; + + const void * x = (char *) src0->data + i02*nb02 + i03*nb03; + const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); + + float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); + + if (type != GGML_TYPE_F32) { + float * const wdata = params->wdata; + ggml_to_float_t const to_float = type_traits[type].to_float; + + size_t id = 0; + for (int64_t i01 = 0; i01 < ne01; ++i01) { + to_float((const char *) x + i01*nb01, wdata + id, ne00); + id += ne00; + } + + assert(id*sizeof(float) <= params->wsize); + x = wdata; + } + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne00, + 0.0f, d, ne01); + } + } + + //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); + + return; + } +#endif + + if (params->type == GGML_TASK_INIT) { + if (src1->type != vec_dot_type) { + char * wdata = params->wdata; + const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += row_size; + } + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type); + + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = ne11*ne12*ne13; // src1 rows + + //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); + + // distribute the thread work across the inner or outer loop based on which one is larger + + const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + + const int64_t ith0 = ith % nth0; + const int64_t ith1 = ith / nth0; + + const int64_t dr0 = (nr0 + nth0 - 1)/nth0; + const int64_t dr1 = (nr1 + nth1 - 1)/nth1; + + const int64_t ir010 = dr0*ith0; + const int64_t ir011 = MIN(ir010 + dr0, nr0); + + const int64_t ir110 = dr1*ith1; + const int64_t ir111 = MIN(ir110 + dr1, nr1); + + //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111); + + // threads with no work simply yield (not sure if it helps) + if (ir010 >= ir011 || ir110 >= ir111) { + sched_yield(); + return; + } + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // block-tiling attempt + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + // attempt to reduce false-sharing (does not seem to make a difference) + float tmp[16]; + + for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { + for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { + const int64_t i13 = (ir1/(ne12*ne11)); + const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11; + const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11); + + // broadcast src0 into src1 + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03); + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size + : (i11*nb11 + i12*nb12 + i13*nb13)); + + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); + + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col); + } + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + } + } + } +} + +// ggml_compute_forward_out_prod + +static void ggml_compute_forward_out_prod_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod + // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) + + if (params->type == GGML_TASK_INIT) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + for (int64_t ir = ir0; ir < ir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2; + const int64_t i03 = i3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + + for (int64_t i01 = 0; i01 < ne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32(ne0, d, s0, *s1); + // for (int64_t i0 = 0; i0 < ne0; ++i0) { + // d[i0] += s0[i0] * s1[i1]; + // } + } + } + + //int64_t t1 = ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_out_prod( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + { + GGML_ASSERT(false); // todo + // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(false); // todo + // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_out_prod_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_scale + +static void ggml_compute_forward_scale_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // scale factor + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const size_t nb01 = src0->nb[1]; + + const size_t nb1 = dst->nb[1]; + + + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + } + ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); + } +} + +static void ggml_compute_forward_scale( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_scale_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_set + +static void ggml_compute_forward_set_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + + // view src0 and dst with these strides and data offset inbytes during set + // nb0 is implicitely element_size because src0 and dst are contiguous + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace && (params->type == GGML_TASK_INIT)) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); + } + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src1); + const int nc = src1->ne[0]; + + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); + + // src0 and dst as viewed during set + const size_t nb0 = ggml_element_size(src0); + + const int im0 = (ne10 == 0 ? 0 : ne10-1); + const int im1 = (ne11 == 0 ? 0 : ne11-1); + const int im2 = (ne12 == 0 ? 0 : ne12-1); + const int im3 = (ne13 == 0 ? 0 : ne13-1); + + GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst)); + + GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are viewed with shape of src1 and offset + // => same indices + const int i3 = ir/(ne12*ne11); + const int i2 = (ir - i3*ne12*ne11)/ne11; + const int i1 = (ir - i3*ne12*ne11 - i2*ne11); + + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + } +} + +static void ggml_compute_forward_set( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_set_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_cpy + +static void ggml_compute_forward_cpy( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + ggml_compute_forward_dup(params, src0, dst); +} + +// ggml_compute_forward_cont + +static void ggml_compute_forward_cont( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + ggml_compute_forward_dup(params, src0, dst); +} + +// ggml_compute_forward_reshape + +static void ggml_compute_forward_reshape( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(src0); + UNUSED(dst); +} + +// ggml_compute_forward_view + +static void ggml_compute_forward_view( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_permute + +static void ggml_compute_forward_permute( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_transpose + +static void ggml_compute_forward_transpose( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0) { + // NOP + UNUSED(params); + UNUSED(src0); +} + +// ggml_compute_forward_get_rows + +static void ggml_compute_forward_get_rows_q( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == ggml_type_size(type)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + dequantize_row_q( + (const void *) ((char *) src0->data + r*src0->nb[1]), + (float *) ((char *) dst->data + i*dst->nb[1]), nc); + } +} + +static void ggml_compute_forward_get_rows_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + for (int j = 0; j < nc; ++j) { + ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; + ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v); + } + } +} + +static void ggml_compute_forward_get_rows_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + assert( dst->ne[0] == nc); + assert( dst->ne[1] == nr); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i*dst->nb[1]), + (float *) ((char *) src0->data + r*src0->nb[1])); + } +} + +static void ggml_compute_forward_get_rows( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + { + ggml_compute_forward_get_rows_q(params, src0, src1, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_get_rows_f16(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_get_rows_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} +} + +// ggml_compute_forward_get_rows_back + +static void ggml_compute_forward_get_rows_back_f32_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_are_same_shape(opt0, dst)); + GGML_ASSERT(ggml_is_contiguous(opt0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + ggml_compute_forward_dup_same_cont(params, opt0, dst); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + GGML_ASSERT( dst->ne[0] == nc); + GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + for (int j = 0; j < nc; ++j) { + ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j]; + ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v); + } + } +} + +static void ggml_compute_forward_get_rows_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_are_same_shape(opt0, dst)); + GGML_ASSERT(ggml_is_contiguous(opt0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + // ggml_compute_forward_dup_same_cont(params, opt0, dst); + + if (params->type == GGML_TASK_INIT) { + memset(dst->data, 0, ggml_nbytes(dst)); + } + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + GGML_ASSERT( dst->ne[0] == nc); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + r*dst->nb[1]), + (float *) ((char *) dst->data + r*dst->nb[1]), + (float *) ((char *) src0->data + i*src0->nb[1])); + } +} + + +static void ggml_compute_forward_get_rows_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} +} + +// ggml_compute_forward_diag + +static void ggml_compute_forward_diag_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + GGML_TENSOR_UNARY_OP_LOCALS; + + GGML_ASSERT(ne00 == ne0); + GGML_ASSERT(ne00 == ne1); + GGML_ASSERT(ne01 == 1); + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne3); + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + for (int i1 = 0; i1 < ne1; i1++) { + float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02); + for (int i0 = 0; i0 < i1; i0++) { + d[i0] = 0; + } + d[i1] = s[i1]; + for (int i0 = i1+1; i0 < ne0; i0++) { + d[i0] = 0; + } + } + } + } +} + +static void ggml_compute_forward_diag( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_diag_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_diag_mask_inf + +static void ggml_compute_forward_diag_mask_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst, + const float value) { + + const int ith = params->ith; + const int nth = params->nth; + + const int n_past = ((int32_t *) dst->op_params)[0]; + const bool inplace = src0->data == dst->data; + + GGML_ASSERT(n_past >= 0); + + if (!inplace && (params->type == GGML_TASK_INIT)) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); + } + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + const int nr = src0->ne[1]; + const int nz = n/nr; + + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int k = 0; k < nz; k++) { + for (int j = ith; j < nr; j += nth) { + for (int i = n_past; i < nc; i++) { + if (i > n_past + j) { + *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value; + } + } + } + } +} + +static void ggml_compute_forward_diag_mask_inf( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_diag_mask_f32(params, src0, dst, -INFINITY); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_compute_forward_diag_mask_zero( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_diag_mask_f32(params, src0, dst, 0); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_soft_max + +static void ggml_compute_forward_soft_max_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float *sp = (float *)((char *) src0->data + i1*src0->nb[1]); + float *dp = (float *)((char *) dst->data + i1*dst->nb[1]); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(sp[i])); + } +#endif + + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, sp); + + ggml_float sum = 0.0; + + uint16_t scvt; + for (int i = 0; i < nc; i++) { + if (sp[i] == -INFINITY) { + dp[i] = 0.0f; + } else { + // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max); + ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); + sum += (ggml_float)val; + dp[i] = val; + } + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(nc, dp, sum); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(dp[i])); + assert(!isinf(dp[i])); + } +#endif + } +} + +static void ggml_compute_forward_soft_max( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_soft_max_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_soft_max_back + +static void ggml_compute_forward_soft_max_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src1, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float *dy = (float *)((char *) src0->data + i1*src0->nb[1]); + float *y = (float *)((char *) src1->data + i1*src1->nb[1]); + float *dx = (float *)((char *) dst->data + i1*dst->nb[1]); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(dy[i])); + assert(!isnan(y[i])); + } +#endif + // Jii = yi - yi*yi + // Jij = -yi*yj + // J = diag(y)-y.T*y + // dx = J * dy + // dxk = sum_i(Jki * dyi) + // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk + // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk + // dxk = sum_i(-yk*yi * dyi) + yk*dyk + // dxk = -yk * sum_i(yi * dyi) + yk*dyk + // dxk = -yk * dot(y, dy) + yk*dyk + // dxk = yk * (- dot(y, dy) + dyk) + // dxk = yk * (dyk - dot(y, dy)) + // + // post-order: + // dot_y_dy := dot(y, dy) + // dx := dy + // dx := dx - dot_y_dy + // dx := dx * y + + // linear runtime, no additional memory + float dot_y_dy = 0; + ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy); + ggml_vec_cpy_f32 (nc, dx, dy); + ggml_vec_acc1_f32(nc, dx, -dot_y_dy); + ggml_vec_mul_f32 (nc, dx, dx, y); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(dx[i])); + assert(!isinf(dx[i])); + } +#endif + } +} + +static void ggml_compute_forward_soft_max_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_alibi + +static void ggml_compute_forward_alibi_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); + + assert(n_past >= 0); + + const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 + const int ne1 = src0->ne[1]; // seq_len_without_past + const int ne2 = src0->ne[2]; // n_head -> this is k + //const int ne3 = src0->ne[3]; // 1 -> bsz + + const int n = ggml_nrows(src0); + const int ne2_ne3 = n/ne1; // ne2*ne3 + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + //const int nb3 = src0->nb[3]; + + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(ne1 + n_past == ne0); + GGML_ASSERT(n_head == ne2); + + // add alibi to src0 (KQ_scaled) + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + + for (int i = 0; i < ne0; i++) { + for (int j = 0; j < ne1; j++) { + for (int k = 0; k < ne2_ne3; k++) { + float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); + float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); + + // TODO: k*nb2 or k*nb3 + + float m_k; + + if (k < n_heads_log2_floor) { + m_k = powf(m0, k + 1); + } else { + m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); + } + + pdst[0] = i * m_k + src[0]; + + } + } + } +} + +static void ggml_compute_forward_alibi_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); + + assert(n_past >= 0); + + const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 + const int ne1 = src0->ne[1]; // seq_len_without_past + const int ne2 = src0->ne[2]; // n_head -> this is k + //const int ne3 = src0->ne[3]; // 1 -> bsz + + const int n = ggml_nrows(src0); + const int ne2_ne3 = n/ne1; // ne2*ne3 + + const int nb0 = src0->nb[0]; + const int nb1 = src0->nb[1]; + const int nb2 = src0->nb[2]; + //const int nb3 = src0->nb[3]; + + GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; + GGML_ASSERT(n_head == ne2); + + // add alibi to src0 (KQ_scaled) + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + + for (int i = 0; i < ne0; i++) { + for (int j = 0; j < ne1; j++) { + for (int k = 0; k < ne2_ne3; k++) { + ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); + float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); + + // TODO: k*nb2 or k*nb3 + + float m_k; + + if (k < n_heads_log2_floor) { + m_k = powf(m0, k + 1); + } else { + m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); + } + + // we return F32 + pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]); + } + } + } +} + +static void ggml_compute_forward_alibi( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_alibi_f16(params, src0, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_alibi_f32(params, src0, dst); + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_K: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_clamp + +static void ggml_compute_forward_clamp_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + float min; + float max; + memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + for (int j = ith; j < n; j += nth) { + float * dst_ptr = (float *) ((char *) dst->data + j*nb1); + float * src0_ptr = (float *) ((char *) src0->data + j*nb01); + + for (int i = 0; i < nc; i++) { + dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min); + } + } +} + +static void ggml_compute_forward_clamp( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_clamp_f32(params, src0, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_K: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_rope + +static void ggml_compute_forward_rope_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + float freq_base; + float freq_scale; + + // these two only relevant for xPos RoPE: + float xpos_base; + bool xpos_down; + + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); + + assert(n_past >= 0); + + GGML_TENSOR_UNARY_OP_LOCALS; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + GGML_ASSERT(nb00 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(dst); + + GGML_ASSERT(n_dims <= ne0); + GGML_ASSERT(n_dims % 2 == 0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + float theta = freq_scale * (float)p; + + if (is_glm) { + theta = MIN(p, n_ctx - 2); + float block_theta = MAX(p - (n_ctx - 2), 0); + for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + const float cos_block_theta = cosf(block_theta); + const float sin_block_theta = sinf(block_theta); + + theta *= theta_scale; + block_theta *= theta_scale; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + const float x2 = src[n_dims]; + const float x3 = src[n_dims/2*3]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta; + dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta; + } + } else if (!is_neox) { + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + // zeta scaling for xPos only: + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + if (xpos_down) zeta = 1.0f / zeta; + + theta *= theta_scale; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta; + dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; + } + } else { + // TODO: this might be wrong for ne0 != n_dims - need double check + // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } + } + } + } + } +} + +static void ggml_compute_forward_rope_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + float freq_base; + float freq_scale; + + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + + assert(n_past >= 0); + + GGML_TENSOR_UNARY_OP_LOCALS; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(dst); + + GGML_ASSERT(n_dims <= ne0); + GGML_ASSERT(n_dims % 2 == 0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + float theta = freq_scale * (float)p; + + if (is_glm) { + theta = MIN(p, n_ctx - 2); + float block_theta = MAX(p - (n_ctx - 2), 0); + for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + const float cos_block_theta = cosf(block_theta); + const float sin_block_theta = sinf(block_theta); + + theta *= theta_scale; + block_theta *= theta_scale; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); + const float x2 = GGML_FP16_TO_FP32(src[n_dims]); + const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); + dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); + } + } if (!is_neox) { + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[1]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } else { + // TODO: this might be wrong for ne0 != n_dims - need double check + // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } + } + } + } + } +} + +static void ggml_compute_forward_rope( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_rope_f16(params, src0, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_rope_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_rope_back + +static void ggml_compute_forward_rope_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // y = rope(x, src1) + // dx = rope_back(dy, src1) + // src0 is dy, src1 contains options + + float freq_base; + float freq_scale; + + // these two only relevant for xPos RoPE: + float xpos_base; + bool xpos_down; + + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx); + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); + + assert(n_past >= 0); + + GGML_TENSOR_UNARY_OP_LOCALS; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + assert(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(dst); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + const bool is_neox = mode & 2; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + float theta = freq_scale * (float)p; + + if (!is_neox) { + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + // zeta scaling for xPos only: + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + if (xpos_down) zeta = 1.0f / zeta; + + theta *= theta_scale; + + const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float dy0 = dy[0]; + const float dy1 = dy[1]; + + dx[0] = dy0*cos_theta*zeta + dy1*sin_theta*zeta; + dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta; + } + } else { + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float dy0 = dy[0]; + const float dy1 = dy[n_dims/2]; + + dx[0] = dy0*cos_theta + dy1*sin_theta; + dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta; + } + } + } + } + } + } +} + +static void ggml_compute_forward_rope_back_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // y = rope(x, src1) + // dx = rope_back(dy, src1) + // src0 is dy, src1 contains options + + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + + assert(n_past >= 0); + + GGML_TENSOR_UNARY_OP_LOCALS; + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + assert(nb0 == sizeof(ggml_fp16_t)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(dst); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(10000.0, -2.0f/n_dims); + + const bool is_neox = mode & 2; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { + const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + float theta = (float)p; + + if (!is_neox) { + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float dy0 = GGML_FP16_TO_FP32(dy[0]); + const float dy1 = GGML_FP16_TO_FP32(dy[1]); + + dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta); + dx[1] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta); + } + } else { + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + const float cos_theta = cosf(theta); + const float sin_theta = sinf(theta); + + theta *= theta_scale; + + const int64_t i0 = ib*n_dims + ic/2; + + const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float dy0 = GGML_FP16_TO_FP32(dy[0]); + const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]); + + dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta); + dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta); + } + } + } + } + } + } +} + +static void ggml_compute_forward_rope_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_rope_back_f16(params, src0, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_rope_back_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_conv_1d + +static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + ggml_fp16_t * dst_data = wdata; + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int64_t i0 = 0; i0 < ne10; ++i0) { + dst_data[i0] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f16(ew0, &v, + (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_s1_ph_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + float * const wdata = (float *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i02*ew0*ne00; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int64_t i0 = 0; i0 < ne10; ++i0) { + dst_data[i0] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f32(ew0, &v, + (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_s1_ph( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + ggml_fp16_t * dst_data = wdata; + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int64_t i0 = 0; i0 < ne10; i0 += 2) { + dst_data[i0/2] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f16(ew0, &v, + (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0/2] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_s2_ph_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00; + const int nh = nk/2; + + const int ew0 = ggml_up32(ne01); + + GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) + { + float * const wdata = (float *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i02*ew0*ne00; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ew0 + i01] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + for (int64_t i0 = 0; i0 < ne10; i0 += 2) { + dst_data[i0/2] = 0; + for (int k = -nh; k <= nh; k++) { + float v = 0.0f; + ggml_vec_dot_f32(ew0, &v, + (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, + (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); + + dst_data[i0/2] += v; + } + } + } +} + +static void ggml_compute_forward_conv_1d_s2_ph( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_conv_1d + +static void ggml_compute_forward_conv_1d( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; + GGML_ASSERT(d0 == 1); // dilation not supported + GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported + if (s0 == 1) { + ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst); + } else if (s0 == 2) { + ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst); + } else { + GGML_ASSERT(false); // only stride 1 and 2 supported + }; +} + +// ggml_compute_forward_conv_2d + +static void ggml_compute_forward_conv_2d_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk0 = ne00; + const int nk1 = ne01; + + // size of the convolution row - the kernel size unrolled across all channels + const int ew0 = nk0*nk1*ne02; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + memset(params->wdata, 0, params->wsize); + + // prepare source data (src1) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int i12 = 0; i12 < ne12; i12++) { + const float * const src = (float *)((char *) src1->data + i12*nb12); + ggml_fp16_t * dst_data = wdata; + + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + for (int ik1 = 0; ik1 < nk1; ik1++) { + for (int ik0 = 0; ik0 < nk0; ik0++) { + const int idx0 = i0*s0 + ik0*d0 - p0; + const int idx1 = i1*s1 + ik1*d1 - p1; + + if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) { + dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] = + GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]); + } + } + } + } + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total patches in dst + const int np = ne2; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ip0; i2 < ip1; i2++) { + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2); + + for (int i1 = 0; i1 < ne1; ++i1) { + for (int i0 = 0; i0 < ne0; ++i0) { + ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0, + (ggml_fp16_t *) ((char *) src0->data + i2*nb03), + (ggml_fp16_t *) wdata + i3*nb3 + (i1*ne0 + i0)*ew0); + } + } + } + } +} + +static void ggml_compute_forward_conv_2d( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + //ggml_compute_forward_conv_2d_f32(params, src0, src1, dst); + GGML_ASSERT(false); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_conv_transpose_2d + +static void ggml_compute_forward_conv_transpose_2d( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00*ne01*ne02*ne03; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + memset(params->wdata, 0, params->wsize); + + // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02); + ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03; + for (int64_t i01 = 0; i01 < ne01; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00]; + } + } + } + } + } + + // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + for (int i12 = 0; i12 < ne12; i12++) { + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11); + ggml_fp16_t * dst_data = wdata + i11*ne10*ne12; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + const int32_t stride = ggml_get_op_params_i32(dst, 0); + + // total patches in dst + const int np = ne2; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + ggml_fp16_t * const wdata_src = wdata + nk; + + for (int i2 = ip0; i2 < ip1; i2++) { // Cout + float * dst_data = (float *)((char *) dst->data + i2*nb2); + ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; + for (int i11 = 0; i11 < ne11; i11++) { + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i11*ne10*ne12 + i10*ne12; + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f16(ne03, &v, + wdata_src + i1n, + wdata_kernel + i01*ne00*ne03 + i00*ne03); + dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; + } + } + } + } + } +} + +// ggml_compute_forward_pool_1d_sk_p0 + +static void ggml_compute_forward_pool_1d_sk_p0( + const struct ggml_compute_params * params, + const enum ggml_op_pool op, + const struct ggml_tensor * src, + const int k, + struct ggml_tensor * dst) { + assert(src->type == GGML_TYPE_F32); + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const char * cdata = (const char *)src->data; + const char * const data_end = cdata + ggml_nbytes(src); + float * drow = (float *)dst->data; + + const int64_t rs = dst->ne[0]; + + while (cdata < data_end) { + const float * const srow = (const float *)cdata; + + int j = 0; + + for (int64_t i = 0; i < rs; ++i) { + switch (op) { + case GGML_OP_POOL_AVG: drow[i] = 0; break; + case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break; + case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + } + for (int ki = 0; ki < k; ++ki) { + switch (op) { + case GGML_OP_POOL_AVG: drow[i] += srow[j]; break; + case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break; + case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + } + ++j; + } + switch (op) { + case GGML_OP_POOL_AVG: drow[i] /= k; break; + case GGML_OP_POOL_MAX: break; + case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + } + } + + cdata += src->nb[1]; + drow += rs; + } +} + +// ggml_compute_forward_pool_1d + +static void ggml_compute_forward_pool_1d( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int s0 = opts[2]; + const int p0 = opts[3]; + GGML_ASSERT(p0 == 0); // padding not supported + GGML_ASSERT(k0 == s0); // only s = k supported + + ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst); +} + +// ggml_compute_forward_pool_2d_sk_p0 + +static void ggml_compute_forward_pool_2d_sk_p0( + const struct ggml_compute_params * params, + const enum ggml_op_pool op, + const struct ggml_tensor * src, + const int k0, + const int k1, + struct ggml_tensor * dst) { + assert(src->type == GGML_TYPE_F32); + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const char * cdata = (const char*)src->data; + const char * const data_end = cdata + ggml_nbytes(src); + + const int64_t px = dst->ne[0]; + const int64_t py = dst->ne[1]; + const int64_t pa = px * py; + + float * dplane = (float *)dst->data; + + const int ka = k0 * k1; + + while (cdata < data_end) { + for (int oy = 0; oy < py; ++oy) { + float * const drow = dplane + oy * px; + for (int ox = 0; ox < px; ++ox) { + float * const out = drow + ox; + switch (op) { + case GGML_OP_POOL_AVG: *out = 0; break; + case GGML_OP_POOL_MAX: *out = -FLT_MAX; break; + case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + } + + const int ix = ox * k0; + const int iy = oy * k1; + + for (int ky = 0; ky < k1; ++ky) { + const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + switch (op) { + case GGML_OP_POOL_AVG: *out += srow[j]; break; + case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break; + case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + } + } + } + switch (op) { + case GGML_OP_POOL_AVG: *out /= ka; break; + case GGML_OP_POOL_MAX: break; + case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + } + } + } + + cdata += src->nb[2]; + dplane += pa; + } +} + +// ggml_compute_forward_pool_2d + +static void ggml_compute_forward_pool_2d( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + GGML_ASSERT(p0 == 0); + GGML_ASSERT(p1 == 0); // padding not supported + GGML_ASSERT(k0 == s0); + GGML_ASSERT(k1 == s1); // only s = k supported + + ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst); +} + +// ggml_compute_forward_upscale + +static void ggml_compute_forward_upscale_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + + GGML_TENSOR_UNARY_OP_LOCALS; + + const int scale_factor = dst->op_params[0]; + + // TODO: optimize + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = ith; i02 < ne02; i02++) { + for (int m = 0; m < dst->ne[1]; m++) { + int i01 = m / scale_factor; + for (int n = 0; n < dst->ne[0]; n++) { + int i00 = n / scale_factor; + + const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03); + + float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]); + + *y = *x; + } + } + } + } +} + +static void ggml_compute_forward_upscale( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_upscale_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_flash_attn + +static void ggml_compute_forward_flash_attn_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, k, ne); + GGML_TENSOR_LOCALS(size_t, nbk, k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, v, nb); + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(float)); + GGML_ASSERT(nbv0 == sizeof(float)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0f/sqrtf(D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int64_t i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { +#ifndef GGML_FLASH_ATTN_EXP_FP16 + const float val = expf(SS[j] - max); +#else + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); +#endif + sump[j] += (ggml_float)val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + for (int64_t ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f32(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S); + } + } +} + +static void ggml_compute_forward_flash_attn_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, k, ne); + GGML_TENSOR_LOCALS(size_t, nbk, k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, v, nb); + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0f/sqrtf(D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16(neq0, + S + i1, + (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } else { + for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f16_unroll(neq0, nbk1, + S + i1, + ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int64_t i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(S, 1, &max, S, 1, Mup); + vvexpf(S, S, &Mup); + ggml_vec_sum_f32(Mup, &sum, S); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SS = S + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SS[j] == -INFINITY) { + SS[j] = 0.0f; + } else { + ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + sump[j] += (ggml_float)val; + SS[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, S, sum); + +#ifndef NDEBUG + for (int i = 0; i < M; ++i) { + assert(!isnan(S[i])); + assert(!isinf(S[i])); + } +#endif + } + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); + + for (int64_t i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { + for (int64_t ic = 0; ic < nev1; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f16(nek1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S16); + } + } else { + for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_dot_f16_unroll(nek1, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + S16); + } + } + } +} + +static void ggml_compute_forward_flash_attn( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const bool masked, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_flash_ff + +static void ggml_compute_forward_flash_ff_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, // F16 + const struct ggml_tensor * b0, // F16 fc_w + const struct ggml_tensor * b1, // F32 fc_b + const struct ggml_tensor * c0, // F16 proj_w + const struct ggml_tensor * c1, // F32 proj_b + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, nea, a, ne); + GGML_TENSOR_LOCALS(size_t, nba, a, nb); + GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne); + GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb); + GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne); + GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb); + GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne); + GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb); + GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne); + GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb); + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = nea0; + //const int64_t N = nea1; + const int64_t M = neb01; + + GGML_ASSERT(ne0 == nea0); + GGML_ASSERT(ne1 == nea1); + GGML_ASSERT(ne2 == nea2); + + GGML_ASSERT(nba0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbb10 == sizeof(float)); + GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbc10 == sizeof(float)); + + GGML_ASSERT(neb00 == D); + GGML_ASSERT(neb01 == M); + GGML_ASSERT(neb10 == M); + GGML_ASSERT(neb11 == 1); + + GGML_ASSERT(nec00 == M); + GGML_ASSERT(nec01 == D); + GGML_ASSERT(nec10 == D); + GGML_ASSERT(nec11 == 1); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by a rows using ggml_vec_dot_f32 + + // total rows in a + const int nr = nea1*nea2*nea3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // a indices + const int ia3 = ir/(nea2*nea1); + const int ia2 = (ir - ia3*nea2*nea1)/nea1; + const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1); + + float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); + + for (int64_t ic = 0; ic < neb01; ++ic) { + // b0 indices + const int ib03 = ia3; + const int ib02 = ia2; + const int ib01 = ic; + + // S indices + const int i1 = ib01; + + ggml_vec_dot_f16(nea0, + S + i1, + (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), + (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3))); + } + + ggml_vec_add_f32(neb01, S, S, (float *) b1->data); + //ggml_vec_gelu_f32(neb01, S, S); + + ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); + + for (int64_t i = 0; i < M; i++) { + S16[i] = GGML_FP32_TO_FP16(S[i]); + } + + ggml_vec_gelu_f16(neb01, S16, S16); + + { + // dst indices + const int i1 = ia1; + const int i2 = ia2; + const int i3 = ia3; + + for (int64_t ic = 0; ic < nec01; ++ic) { + + ggml_vec_dot_f16(neb01, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), + S16); + } + + ggml_vec_add_f32(nec01, + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), + (float *) c1->data); + } + } +} + +static void ggml_compute_forward_flash_ff( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b0, + const struct ggml_tensor * b1, + const struct ggml_tensor * c0, + const struct ggml_tensor * c1, + struct ggml_tensor * dst) { + switch (b0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(false); // TODO + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_flash_attn_back + +static void ggml_compute_forward_flash_attn_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * d, + const bool masked, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, k, ne); + GGML_TENSOR_LOCALS(size_t, nbk, k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, v, nb); + GGML_TENSOR_LOCALS(int64_t, ned, d, ne); + GGML_TENSOR_LOCALS(size_t, nbd, d, nb); + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + const int mxDM = MAX(D, Mup); + + // GGML_ASSERT(ne0 == D); + // GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(float)); + GGML_ASSERT(nbv0 == sizeof(float)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + GGML_ASSERT(ned0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + GGML_ASSERT(ned1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (params->type == GGML_TASK_INIT) { + if (ith == 0) { + memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); + } + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0f/sqrtf(D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2); + const int iq2 = ir - iq3*neq2; + for ( int iq1 = 0; iq1 < neq1; ++iq1) { + + + // not sure about CACHE_LINE_SIZE_F32.. + // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? + float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); + float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + for (int64_t ic = 0; ic < nek1; ++ic) { + // k indices + const int ik3 = iq3; + const int ik2 = iq2; + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); + } + + // scale + ggml_vec_scale_f32(nek1, S, scale); + + if (masked) { + for (int64_t i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = -INFINITY; + } + } + } + + // softmax + { + float max = -INFINITY; + ggml_vec_max_f32(M, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(SM, 1, &max, SM, 1, Mup); + vvexpf(SM, SM, &Mup); + ggml_vec_sum_f32(Mup, &sum, SM); +#else + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + float * SR = S + i; + float * SW = SM + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (SR[j] == -INFINITY) { + SW[j] = 0.0f; + } else { +#ifndef GGML_FLASH_ATTN_EXP_FP16 + const float val = expf(SR[j] - max); +#else + ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); +#endif + sump[j] += (ggml_float)val; + SW[j] = val; + } + } + } + + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(M, SM, sum); + + } + + // step-by-step explanation + { + // forward-process shape grads from backward process + // parallel_for iq2,iq3: + // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur] + // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] + // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur] + // for iq1: + // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur + // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur + // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 + // S0 = -Inf [D,1,1,1] + // ~S1[i] = dot(kcur[:D,i], qcur) + // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale + // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) + // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur + // ~S5[i] = dot(vcur[:,i], S4) + // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3] + // ~dst[i,iq1,iq2,iq3] = S5[i] ^ + // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3] + // dst backward-/ grad[dst] = d + // + // output gradients with their dependencies: + // + // grad[kcur] = grad[S1].T @ qcur + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S4] = grad[S5] @ vcur + // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur + // grad[qcur] = grad[S1] @ kcur + // grad[vcur] = grad[S5].T @ S4 + // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4 + // + // in post-order: + // + // S1 = qcur @ kcur.T + // S2 = S1 * scale + // S3 = diag_mask_inf(S2, P) + // S4 = softmax(S3) + // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[qcur] = grad[S1] @ kcur + // grad[kcur] = grad[S1].T @ qcur + // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4 + // + // using less variables (SM=S4): + // + // S = diag_mask_inf(qcur @ kcur.T * scale, P) + // SM = softmax(S) + // S = d[:D,iq1,iq2,iq3] @ vcur + // dot_SM_gradSM = dot(SM, S) + // S = SM * (S - dot(SM, S)) + // S = diag_mask_zero(S, P) * scale + // + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[k][:D,:M,iq2,iq3] += S.T @ qcur + // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM + } + + // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur + // S = d[:D,iq1,iq2,iq3] @ vcur + // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3] + ggml_vec_set_f32(M, S, 0); + for (int64_t ic = 0; ic < D; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_mad_f32(M, + S, + (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3))); + } + + // S = SM * (S - dot(SM, S)) + float dot_SM_gradSM = 0; + ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S); + ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); + ggml_vec_mul_f32 (M, S, S, SM); + + // S = diag_mask_zero(S, P) * scale + if (masked) { + // for (int64_t i = P + iq1 + 1; i < M; i++) { + // S[i] = 0; + // } + for (int64_t i = P; i < M; i++) { + if (i > P + iq1) { + S[i] = 0; + } + } + } + ggml_vec_scale_f32(M, S, scale); + + void * grad_q = (char *) dst->data; + void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3; + void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3; + + const size_t nbgq1 = nb0*neq0; + const size_t nbgq2 = nb0*neq0*neq1; + const size_t nbgq3 = nb0*neq0*neq1*neq2; + + const size_t nbgk1 = nb0*nek0; + const size_t nbgk2 = nb0*nek0*nek1; + const size_t nbgk3 = nb0*nek0*nek1*neq2; + + const size_t nbgv1 = nb0*nev0; + const size_t nbgv2 = nb0*nev0*nev1; + const size_t nbgv3 = nb0*nev0*nev1*neq2; + + // S shape [M,1] + // SM shape [M,1] + // kcur shape [D,M] + // qcur shape [D,1] + // vcur shape [M,D] + // + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] + // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic] + // + //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T) + //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T) + for (int64_t ic = 0; ic < M; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + ggml_vec_mad_f32(D, + (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)), + (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)), + S[ic]); + } + + // grad[k][:D,:M,iq2,iq3] += S.T @ qcur + // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] + // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] + for (int64_t ic = 0; ic < M; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // ggml_vec_set_f32(D, + // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)), + // 0); + ggml_vec_mad_f32(D, + (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)), + (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)), + S[ic]); + } + + // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM + // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M] + // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M] + for (int64_t ic = 0; ic < D; ++ic) { + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // ggml_vec_set_f32(M, + // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)), + // 0); + ggml_vec_mad_f32(M, + (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)), + SM, + *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3))); + } + } + } +} + +static void ggml_compute_forward_flash_attn_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * d, + const bool masked, + struct ggml_tensor * dst) { + switch (q->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_win_part + +static void ggml_compute_forward_win_part_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + + const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t w = ((const int32_t *)(dst->op_params))[2]; + + assert(ne00 == ne0); + assert(ne3 == nep0*nep1); + + // TODO: optimize / multi-thread + for (int py = 0; py < nep1; ++py) { + for (int px = 0; px < nep0; ++px) { + const int64_t i3 = py*nep0 + px; + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i02 = py*w + i2; + const int64_t i01 = px*w + i1; + const int64_t i00 = i0; + + const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0; + const int64_t j = i02*ne01*ne00 + i01*ne00 + i00; + + if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { + ((float *) dst->data)[i] = 0.0f; + } else { + ((float *) dst->data)[i] = ((float *) src0->data)[j]; + } + } + } + } + } + } +} + +static void ggml_compute_forward_win_part( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_win_part_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_win_unpart + +static void ggml_compute_forward_win_unpart_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + + const int32_t w = ((const int32_t *)(dst->op_params))[0]; + + // padding + const int px = (w - ne1%w)%w; + //const int py = (w - ne2%w)%w; + + const int npx = (px + ne1)/w; + //const int npy = (py + ne2)/w; + + assert(ne0 == ne00); + + // TODO: optimize / multi-thread + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int ip2 = i2/w; + const int ip1 = i1/w; + + const int64_t i02 = i2%w; + const int64_t i01 = i1%w; + const int64_t i00 = i0; + + const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00; + const int64_t j = i2*ne1*ne0 + i1*ne0 + i0; + + ((float *) dst->data)[j] = ((float *) src0->data)[i]; + } + } + } +} + +static void ggml_compute_forward_win_unpart( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_win_unpart_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +//gmml_compute_forward_unary + +static void ggml_compute_forward_unary( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + const enum ggml_unary_op op = ggml_get_unary_op(dst); + + switch (op) { + case GGML_UNARY_OP_ABS: + { + ggml_compute_forward_abs(params, src0, dst); + } break; + case GGML_UNARY_OP_SGN: + { + ggml_compute_forward_sgn(params, src0, dst); + } break; + case GGML_UNARY_OP_NEG: + { + ggml_compute_forward_neg(params, src0, dst); + } break; + case GGML_UNARY_OP_STEP: + { + ggml_compute_forward_step(params, src0, dst); + } break; + case GGML_UNARY_OP_TANH: + { + ggml_compute_forward_tanh(params, src0, dst); + } break; + case GGML_UNARY_OP_ELU: + { + ggml_compute_forward_elu(params, src0, dst); + } break; + case GGML_UNARY_OP_RELU: + { + ggml_compute_forward_relu(params, src0, dst); + } break; + case GGML_UNARY_OP_GELU: + { + ggml_compute_forward_gelu(params, src0, dst); + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + ggml_compute_forward_gelu_quick(params, src0, dst); + } break; + case GGML_UNARY_OP_SILU: + { + ggml_compute_forward_silu(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_get_rel_pos + +static void ggml_compute_forward_get_rel_pos_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 + + GGML_TENSOR_UNARY_OP_LOCALS; + + const int64_t w = ne1; + + ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data; + ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data; + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + const int64_t pos = (w - i1 - 1) + i2; + for (int64_t i0 = 0; i0 < ne0; ++i0) { + dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; + } + } + } +} + +static void ggml_compute_forward_get_rel_pos( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_get_rel_pos_f16(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_add_rel_pos + +static void ggml_compute_forward_add_rel_pos_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * src2, + struct ggml_tensor * dst) { + + const bool inplace = (bool) ((int32_t *) dst->op_params)[0]; + if (!inplace && params->type == GGML_TASK_INIT) { + memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst)); + return; + } + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359 + + float * src1_data = (float *) src1->data; + float * src2_data = (float *) src2->data; + float * dst_data = (float *) dst->data; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int ith = params->ith; + const int nth = params->nth; + + // total patches in dst + const int np = ne13; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + + for (int64_t i13 = ip0; i13 < ip1; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10; + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t jp0 = jp1 + i10; + const float src1_e = src1_data[jp0]; + const float src2_e = src2_data[jp0]; + + const int64_t jdh = jp0 * ne10; + const int64_t jdw = jdh - (ne10 - 1) * i10; + + for (int64_t j = 0; j < ne10; ++j) { + dst_data[jdh + j ] += src2_e; + dst_data[jdw + j*ne10] += src1_e; + } + } + } + } + } +} + +static void ggml_compute_forward_add_rel_pos( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * src2, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add_rel_pos_f32(params, src0, src1, src2, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_map_unary + +static void ggml_compute_forward_map_unary_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst, + const ggml_unary_op_f32_t fun) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + fun(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + + +static void ggml_compute_forward_map_unary( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst, + const ggml_unary_op_f32_t fun) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_map_unary_f32(params, src0, dst, fun); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_map_binary + +static void ggml_compute_forward_map_binary_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst, + const ggml_binary_op_f32_t fun) { + assert(params->ith == 0); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + assert(src1->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + fun(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + + +static void ggml_compute_forward_map_binary( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst, + const ggml_binary_op_f32_t fun) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_map_custom1 + +static void ggml_compute_forward_map_custom1_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + struct ggml_tensor * dst, + const ggml_custom1_op_f32_t fun) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + fun(dst, a); +} + +// ggml_compute_forward_map_custom2 + +static void ggml_compute_forward_map_custom2_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b, + struct ggml_tensor * dst, + const ggml_custom2_op_f32_t fun) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + fun(dst, a, b); +} + + +// ggml_compute_forward_map_custom3 + +static void ggml_compute_forward_map_custom3_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b, + const struct ggml_tensor * c, + struct ggml_tensor * dst, + const ggml_custom3_op_f32_t fun) { + assert(params->ith == 0); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + fun(dst, a, b, c); +} + +// ggml_compute_forward_map_custom1 + +static void ggml_compute_forward_map_custom1( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) dst->op_params; + + p->fun(dst, a, params->ith, params->nth, p->userdata); +} + +// ggml_compute_forward_map_custom2 + +static void ggml_compute_forward_map_custom2( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) dst->op_params; + + p->fun(dst, a, b, params->ith, params->nth, p->userdata); +} + +// ggml_compute_forward_map_custom3 + +static void ggml_compute_forward_map_custom3( + const struct ggml_compute_params * params, + const struct ggml_tensor * a, + const struct ggml_tensor * b, + const struct ggml_tensor * c, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) dst->op_params; + + p->fun(dst, a, b, c, params->ith, params->nth, p->userdata); +} + +// ggml_compute_forward_cross_entropy_loss + +static void ggml_compute_forward_cross_entropy_loss_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_scalar(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + + const int ith = params->ith; + const int nth = params->nth; + + float * sums = (float *) params->wdata; + + // TODO: handle transposed/permuted matrices + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc)); + + if (params->type == GGML_TASK_INIT) { + if (ith == 0) { + memset(sums, 0, sizeof(float) * (nth + nth * nc)); + } + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + if (ith == 0) { + float * dp = (float *) dst->data; + ggml_vec_sum_f32(nth, dp, sums); + dp[0] *= -1.0f / (float) nr; + } + return; + } + + const double eps = 1e-9; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); + float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); + float * st = ((float *) params->wdata) + nth + ith*nc; + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + // soft_max + ggml_float sum = 0.0; + { + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, s0); + + uint16_t scvt; UNUSED(scvt); + for (int i = 0; i < nc; i++) { + if (s0[i] == -INFINITY) { + st[i] = 0.0f; + } else { +#ifndef GGML_CROSS_ENTROPY_EXP_FP16 + const float s = s0[i] - max; + const float val = expf(s); +#else + ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); +#endif + sum += (ggml_float)val; + st[i] = val; + } + } + + assert(sum > 0.0); + // sum = 1.0/sum; + } + // avoid log(0) by rescaling from [0..1] to [eps..1] + sum = (1.0 - eps) / sum; + ggml_vec_scale_f32(nc, st, sum); + ggml_vec_add1_f32(nc, st, st, eps); + ggml_vec_log_f32(nc, st, st); + ggml_vec_mul_f32(nc, st, st, s1); + + float st_sum = 0; + ggml_vec_sum_f32(nc, &st_sum, st); + sums[ith] += st_sum; + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(st[i])); + assert(!isinf(st[i])); + } +#endif + } + +} + +static void ggml_compute_forward_cross_entropy_loss( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_cross_entropy_loss_back + +static void ggml_compute_forward_cross_entropy_loss_back_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(opt0)); + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + const int64_t ith = params->ith; + const int64_t nth = params->nth; + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const double eps = 1e-9; + + // TODO: handle transposed/permuted matrices + const int64_t nc = src0->ne[0]; + const int64_t nr = ggml_nrows(src0); + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + float * d = (float *) opt0->data; + + for (int64_t i1 = ir0; i1 < ir1; i1++) { + float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); + float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); + float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + + // soft_max + ggml_float sum = 0.0; + { + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, s0); + + uint16_t scvt; UNUSED(scvt); + for (int i = 0; i < nc; i++) { + if (s0[i] == -INFINITY) { + ds0[i] = 0.0f; + } else { +#ifndef GGML_CROSS_ENTROPY_EXP_FP16 + const float s = s0[i] - max; + const float val = expf(s); +#else + ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); + memcpy(&scvt, &s, sizeof(scvt)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); +#endif + sum += (ggml_float)val; + ds0[i] = val; + } + } + + assert(sum > 0.0); + sum = (1.0 - eps)/sum; + } + + // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr + ggml_vec_scale_f32(nc, ds0, sum); + ggml_vec_add1_f32(nc, ds0, ds0, eps); + ggml_vec_sub_f32(nc, ds0, ds0, s1); + ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); + + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(ds0[i])); + assert(!isinf(ds0[i])); + } +#endif + } +} + +static void ggml_compute_forward_cross_entropy_loss_back( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + + +///////////////////////////////// + +static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { + GGML_ASSERT(params); + +#ifdef GGML_USE_CUBLAS + bool skip_cpu = ggml_cuda_compute_forward(params, tensor); + if (skip_cpu) { + return; + } + GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU); + GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_CPU); +#endif // GGML_USE_CUBLAS + + switch (tensor->op) { + case GGML_OP_DUP: + { + ggml_compute_forward_dup(params, tensor->src[0], tensor); + } break; + case GGML_OP_ADD: + { + ggml_compute_forward_add(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_ADD1: + { + ggml_compute_forward_add1(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_ACC: + { + ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_SUB: + { + ggml_compute_forward_sub(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_MUL: + { + ggml_compute_forward_mul(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_DIV: + { + ggml_compute_forward_div(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_SQR: + { + ggml_compute_forward_sqr(params, tensor->src[0], tensor); + } break; + case GGML_OP_SQRT: + { + ggml_compute_forward_sqrt(params, tensor->src[0], tensor); + } break; + case GGML_OP_LOG: + { + ggml_compute_forward_log(params, tensor->src[0], tensor); + } break; + case GGML_OP_SUM: + { + ggml_compute_forward_sum(params, tensor->src[0], tensor); + } break; + case GGML_OP_SUM_ROWS: + { + ggml_compute_forward_sum_rows(params, tensor->src[0], tensor); + } break; + case GGML_OP_MEAN: + { + ggml_compute_forward_mean(params, tensor->src[0], tensor); + } break; + case GGML_OP_ARGMAX: + { + ggml_compute_forward_argmax(params, tensor->src[0], tensor); + } break; + case GGML_OP_REPEAT: + { + ggml_compute_forward_repeat(params, tensor->src[0], tensor); + } break; + case GGML_OP_REPEAT_BACK: + { + ggml_compute_forward_repeat_back(params, tensor->src[0], tensor); + } break; + case GGML_OP_CONCAT: + { + ggml_compute_forward_concat(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_SILU_BACK: + { + ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_NORM: + { + ggml_compute_forward_norm(params, tensor->src[0], tensor); + } break; + case GGML_OP_RMS_NORM: + { + ggml_compute_forward_rms_norm(params, tensor->src[0], tensor); + } break; + case GGML_OP_RMS_NORM_BACK: + { + ggml_compute_forward_rms_norm_back(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_GROUP_NORM: + { + ggml_compute_forward_group_norm(params, tensor->src[0], tensor); + } break; + case GGML_OP_MUL_MAT: + { + ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_OUT_PROD: + { + ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_SCALE: + { + ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_SET: + { + ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_CPY: + { + ggml_compute_forward_cpy(params, tensor->src[0], tensor); + } break; + case GGML_OP_CONT: + { + ggml_compute_forward_cont(params, tensor->src[0], tensor); + } break; + case GGML_OP_RESHAPE: + { + ggml_compute_forward_reshape(params, tensor->src[0], tensor); + } break; + case GGML_OP_VIEW: + { + ggml_compute_forward_view(params, tensor->src[0]); + } break; + case GGML_OP_PERMUTE: + { + ggml_compute_forward_permute(params, tensor->src[0]); + } break; + case GGML_OP_TRANSPOSE: + { + ggml_compute_forward_transpose(params, tensor->src[0]); + } break; + case GGML_OP_GET_ROWS: + { + ggml_compute_forward_get_rows(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_GET_ROWS_BACK: + { + ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } break; + case GGML_OP_DIAG: + { + ggml_compute_forward_diag(params, tensor->src[0], tensor); + } break; + case GGML_OP_DIAG_MASK_INF: + { + ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor); + } break; + case GGML_OP_DIAG_MASK_ZERO: + { + ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor); + } break; + case GGML_OP_SOFT_MAX: + { + ggml_compute_forward_soft_max(params, tensor->src[0], tensor); + } break; + case GGML_OP_SOFT_MAX_BACK: + { + ggml_compute_forward_soft_max_back(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_ROPE: + { + ggml_compute_forward_rope(params, tensor->src[0], tensor); + } break; + case GGML_OP_ROPE_BACK: + { + ggml_compute_forward_rope_back(params, tensor->src[0], tensor); + } break; + case GGML_OP_ALIBI: + { + ggml_compute_forward_alibi(params, tensor->src[0], tensor); + } break; + case GGML_OP_CLAMP: + { + ggml_compute_forward_clamp(params, tensor->src[0], tensor); + } break; + case GGML_OP_CONV_1D: + { + ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_CONV_2D: + { + ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_POOL_1D: + { + ggml_compute_forward_pool_1d(params, tensor->src[0], tensor); + } break; + case GGML_OP_POOL_2D: + { + ggml_compute_forward_pool_2d(params, tensor->src[0], tensor); + } break; + case GGML_OP_UPSCALE: + { + ggml_compute_forward_upscale(params, tensor->src[0], tensor); + } break; + case GGML_OP_FLASH_ATTN: + { + const int32_t t = ggml_get_op_params_i32(tensor, 0); + GGML_ASSERT(t == 0 || t == 1); + const bool masked = t != 0; + ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); + } break; + case GGML_OP_FLASH_FF: + { + ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor); + } break; + case GGML_OP_FLASH_ATTN_BACK: + { + int32_t t = ggml_get_op_params_i32(tensor, 0); + GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor); + } break; + case GGML_OP_WIN_PART: + { + ggml_compute_forward_win_part(params, tensor->src[0], tensor); + } break; + case GGML_OP_WIN_UNPART: + { + ggml_compute_forward_win_unpart(params, tensor->src[0], tensor); + } break; + case GGML_OP_UNARY: + { + ggml_compute_forward_unary(params, tensor->src[0], tensor); + } break; + case GGML_OP_GET_REL_POS: + { + ggml_compute_forward_get_rel_pos(params, tensor->src[0], tensor); + } break; + case GGML_OP_ADD_REL_POS: + { + ggml_compute_forward_add_rel_pos(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } break; + case GGML_OP_MAP_UNARY: + { + ggml_unary_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_unary(params, tensor->src[0], tensor, fun); + } + break; + case GGML_OP_MAP_BINARY: + { + ggml_binary_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun); + } + break; + case GGML_OP_MAP_CUSTOM1_F32: + { + ggml_custom1_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_custom1_f32(params, tensor->src[0], tensor, fun); + } + break; + case GGML_OP_MAP_CUSTOM2_F32: + { + ggml_custom2_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_custom2_f32(params, tensor->src[0], tensor->src[1], tensor, fun); + } + break; + case GGML_OP_MAP_CUSTOM3_F32: + { + ggml_custom3_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_custom3_f32(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor, fun); + } + break; + case GGML_OP_MAP_CUSTOM1: + { + ggml_compute_forward_map_custom1(params, tensor->src[0], tensor); + } + break; + case GGML_OP_MAP_CUSTOM2: + { + ggml_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor); + } + break; + case GGML_OP_MAP_CUSTOM3: + { + ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } + break; + case GGML_OP_CROSS_ENTROPY_LOSS: + { + ggml_compute_forward_cross_entropy_loss(params, tensor->src[0], tensor->src[1], tensor); + } + break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + ggml_compute_forward_cross_entropy_loss_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } + break; + case GGML_OP_NONE: + { + // nop + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +//////////////////////////////////////////////////////////////////////////////// + +static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) { + struct ggml_tensor * src0 = tensor->src[0]; + struct ggml_tensor * src1 = tensor->src[1]; + + switch (tensor->op) { + case GGML_OP_DUP: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_ADD: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_ADD1: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + src1->grad = ggml_add_impl(ctx, + src1->grad, + ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean + inplace); + } + } break; + case GGML_OP_ACC: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; + + struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx, + tensor->grad, + src1->grad->ne[0], + src1->grad->ne[1], + src1->grad->ne[2], + src1->grad->ne[3], + nb1, nb2, nb3, offset); + + src1->grad = + ggml_add_impl(ctx, + src1->grad, + ggml_reshape(ctx, + ggml_cont(ctx, tensor_grad_view), + src1->grad), + inplace); + } + } break; + case GGML_OP_SUB: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_MUL: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_mul(ctx, src1, tensor->grad), + inplace); + } + if (src1->grad) { + src1->grad = + ggml_add_impl(ctx, + src1->grad, + ggml_mul(ctx, src0, tensor->grad), + inplace); + } + } break; + case GGML_OP_DIV: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_div(ctx, tensor->grad, src1), + inplace); + } + if (src1->grad) { + src1->grad = + ggml_sub_impl(ctx, + src1->grad, + ggml_mul(ctx, + tensor->grad, + ggml_div(ctx, tensor, src1)), + inplace); + } + } break; + case GGML_OP_SQR: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_scale(ctx, + ggml_mul(ctx, src0, tensor->grad), + ggml_new_f32(ctx, 2.0f)), + inplace); + } + } break; + case GGML_OP_SQRT: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_scale(ctx, + ggml_div(ctx, + tensor->grad, + tensor), + ggml_new_f32(ctx, 0.5f)), + inplace); + } + } break; + case GGML_OP_LOG: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_div(ctx, + tensor->grad, + src0), + inplace); + } + } break; + case GGML_OP_SUM: + { + if (src0->grad) { + src0->grad = + ggml_add1_impl(ctx, + src0->grad, + tensor->grad, + inplace); + } + } break; + case GGML_OP_SUM_ROWS: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_repeat(ctx, + tensor->grad, + src0->grad), + inplace); + } + } break; + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + { + GGML_ASSERT(false); // TODO: implement + } break; + case GGML_OP_REPEAT: + { + // necessary for llama + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_repeat_back(ctx, tensor->grad, src0->grad), + inplace); + } + } break; + case GGML_OP_REPEAT_BACK: + { + if (src0->grad) { + // TODO: test this + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_repeat(ctx, tensor->grad, src0->grad), + inplace); + } + } break; + case GGML_OP_CONCAT: + { + GGML_ASSERT(false); // TODO: implement + } break; + case GGML_OP_SILU_BACK: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_NORM: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_RMS_NORM: + { + // necessary for llama + if (src0->grad) { + float eps; + memcpy(&eps, tensor->op_params, sizeof(float)); + + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_rms_norm_back(ctx, src0, tensor->grad, eps), + inplace); + } + } break; + case GGML_OP_RMS_NORM_BACK: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_GROUP_NORM: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_MUL_MAT: + { + // https://cs231n.github.io/optimization-2/#staged + // # forward pass + // s0 = np.random.randn(5, 10) + // s1 = np.random.randn(10, 3) + // t = s0.dot(s1) + + // # now suppose we had the gradient on t from above in the circuit + // dt = np.random.randn(*t.shape) # same shape as t + // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix + // ds1 = t.T.dot(dt) + + // tensor.shape [m,p] + // src0.shape [n,m] + // src1.shape [n,p] + + // necessary for llama + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_out_prod(ctx, // [n,m] + src1, // [n,p] + tensor->grad), // [m,p] + inplace); + } + if (src1->grad) { + src1->grad = + ggml_add_impl(ctx, + src1->grad, + // ggml_mul_mat(ctx, // [n,p] + // ggml_cont(ctx, // [m,n] + // ggml_transpose(ctx, src0)), // [m,n] + // tensor->grad), // [m,p] + + // // when src0 is bigger than tensor->grad (this is mostly the case in llama), + // // avoid transpose of src0, rather transpose smaller tensor->grad + // // and then use ggml_out_prod + ggml_out_prod(ctx, // [n,p] + src0, // [n,m] + ggml_transpose(ctx, // [p,m] + tensor->grad)), // [m,p] + inplace); + } + } break; + case GGML_OP_OUT_PROD: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_SCALE: + { + // necessary for llama + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_scale_impl(ctx, tensor->grad, src1, false), + inplace); + } + if (src1->grad) { + src1->grad = + ggml_add_impl(ctx, + src1->grad, + ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)), + inplace); + } + } break; + case GGML_OP_SET: + { + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; + + struct ggml_tensor * tensor_grad_view = NULL; + + if (src0->grad || src1->grad) { + GGML_ASSERT(src0->type == tensor->type); + GGML_ASSERT(tensor->grad->type == tensor->type); + GGML_ASSERT(tensor->grad->type == src1->grad->type); + + tensor_grad_view = ggml_view_4d(ctx, + tensor->grad, + src1->grad->ne[0], + src1->grad->ne[1], + src1->grad->ne[2], + src1->grad->ne[3], + nb1, nb2, nb3, offset); + } + + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_acc_impl(ctx, + tensor->grad, + ggml_neg(ctx, tensor_grad_view), + nb1, nb2, nb3, offset, false), + inplace); + } + + if (src1->grad) { + src1->grad = + ggml_add_impl(ctx, + src1->grad, + ggml_reshape(ctx, + ggml_cont(ctx, tensor_grad_view), + src1->grad), + inplace); + } + } break; + case GGML_OP_CPY: + { + // necessary for llama + // cpy overwrites value of src1 by src0 and returns view(src1) + // the overwriting is mathematically equivalent to: + // tensor = src0 * 1 + src1 * 0 + if (src0->grad) { + // dsrc0 = dtensor * 1 + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + if (src1->grad) { + // dsrc1 = dtensor * 0 -> noop + } + } break; + case GGML_OP_CONT: + { + // same as cpy + if (src0->grad) { + GGML_ASSERT(ggml_is_contiguous(src0->grad)); + GGML_ASSERT(ggml_is_contiguous(tensor->grad)); + src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + } + } break; + case GGML_OP_RESHAPE: + { + // necessary for llama + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, src0->grad, + ggml_reshape(ctx, tensor->grad, src0->grad), + inplace); + } + } break; + case GGML_OP_VIEW: + { + // necessary for llama + if (src0->grad) { + size_t offset; + + memcpy(&offset, tensor->op_params, sizeof(offset)); + + size_t nb1 = tensor->nb[1]; + size_t nb2 = tensor->nb[2]; + size_t nb3 = tensor->nb[3]; + + if (src0->type != src0->grad->type) { + // gradient is typically F32, but src0 could be other type + size_t ng = ggml_element_size(src0->grad); + size_t n0 = ggml_element_size(src0); + GGML_ASSERT(offset % n0 == 0); + GGML_ASSERT(nb1 % n0 == 0); + GGML_ASSERT(nb2 % n0 == 0); + GGML_ASSERT(nb3 % n0 == 0); + offset = (offset / n0) * ng; + nb1 = (nb1 / n0) * ng; + nb2 = (nb2 / n0) * ng; + nb3 = (nb3 / n0) * ng; + } + + src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace); + } + } break; + case GGML_OP_PERMUTE: + { + // necessary for llama + if (src0->grad) { + int32_t * axes = (int32_t *) tensor->op_params; + int axis0 = axes[0] & 0x3; + int axis1 = axes[1] & 0x3; + int axis2 = axes[2] & 0x3; + int axis3 = axes[3] & 0x3; + int axes_backward[4] = {0,0,0,0}; + axes_backward[axis0] = 0; + axes_backward[axis1] = 1; + axes_backward[axis2] = 2; + axes_backward[axis3] = 3; + src0->grad = + ggml_add_impl(ctx, src0->grad, + ggml_permute(ctx, + tensor->grad, + axes_backward[0], + axes_backward[1], + axes_backward[2], + axes_backward[3]), + inplace); + } + } break; + case GGML_OP_TRANSPOSE: + { + // necessary for llama + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, src0->grad, + ggml_transpose(ctx, tensor->grad), + inplace); + } + } break; + case GGML_OP_GET_ROWS: + { + // necessary for llama (only for tokenizer) + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, src0->grad, + ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), + inplace); + } + if (src1->grad) { + // noop + } + } break; + case GGML_OP_GET_ROWS_BACK: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_DIAG: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_DIAG_MASK_INF: + { + // necessary for llama + if (src0->grad) { + const int n_past = ((int32_t *) tensor->op_params)[0]; + src0->grad = + ggml_add_impl(ctx, src0->grad, + ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), + inplace); + } + } break; + case GGML_OP_DIAG_MASK_ZERO: + { + // necessary for llama + if (src0->grad) { + const int n_past = ((int32_t *) tensor->op_params)[0]; + src0->grad = + ggml_add_impl(ctx, src0->grad, + ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), + inplace); + } + } break; + case GGML_OP_SOFT_MAX: + { + // necessary for llama + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, src0->grad, + ggml_soft_max_back(ctx, tensor->grad, tensor), + inplace); + } + + } break; + case GGML_OP_SOFT_MAX_BACK: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_ROPE: + { + // necessary for llama + if (src0->grad) { + const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + const int n_ctx = ((int32_t *) tensor->op_params)[3]; + float freq_base; + float freq_scale; + float xpos_base; + bool xpos_down; + memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float)); + memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); + memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); + + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_rope_back(ctx, + tensor->grad, + n_past, + n_dims, + mode, + n_ctx, + freq_base, + freq_scale, + xpos_base, + xpos_down), + inplace); + } + } break; + case GGML_OP_ROPE_BACK: + { + if (src0->grad) { + const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + const int n_ctx = ((int32_t *) tensor->op_params)[3]; + float freq_base; + float freq_scale; + float xpos_base; + bool xpos_down; + memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float)); + memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); + memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); + + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_rope_impl(ctx, + tensor->grad, + n_past, + n_dims, + mode, + n_ctx, + freq_base, + freq_scale, + xpos_base, + xpos_down, + false), + inplace); + } + } break; + case GGML_OP_ALIBI: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CLAMP: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_1D: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_2D: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_POOL_1D: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_POOL_2D: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_UPSCALE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_FLASH_ATTN: + { + struct ggml_tensor * flash_grad = NULL; + if (src0->grad || src1->grad || tensor->src[2]->grad) { + int32_t t = ggml_get_op_params_i32(tensor, 0); + GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + flash_grad = + ggml_flash_attn_back(ctx, + src0, + src1, + tensor->src[2], + tensor->grad, + masked); + } + + if (src0->grad) { + struct ggml_tensor * grad_q = NULL; + const size_t nb0 = flash_grad->nb[0]; + const size_t offset = 0; + switch(src0->n_dims) { + case 2: + { + grad_q = ggml_view_2d(ctx, + flash_grad, + src0->ne[0], + src0->ne[1], + nb0*src0->ne[0], + offset); + } break; + case 3: + { + grad_q = ggml_view_3d(ctx, + flash_grad, + src0->ne[0], + src0->ne[1], + src0->ne[2], + nb0*src0->ne[0], + nb0*src0->ne[0]*src0->ne[1], + offset); + } break; + case 4: + { + grad_q = ggml_view_4d(ctx, + flash_grad, + src0->ne[0], + src0->ne[1], + src0->ne[2], + src0->ne[3], + nb0*src0->ne[0], + nb0*src0->ne[0]*src0->ne[1], + nb0*src0->ne[0]*src0->ne[1]*src0->ne[2], + offset); + } break; + } + + src0->grad = ggml_add_impl(ctx, + src0->grad, + grad_q, + inplace); + } + + if (src1->grad) { + struct ggml_tensor * grad_k = NULL; + const size_t nb0 = flash_grad->nb[0]; + const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]; + switch(src1->n_dims) { + case 2: + { + grad_k = ggml_view_2d(ctx, + flash_grad, + src1->ne[0], + src1->ne[1], + nb0*src1->ne[0], + offset); + } break; + case 3: + { + grad_k = ggml_view_3d(ctx, + flash_grad, + src1->ne[0], + src1->ne[1], + src1->ne[2], + nb0*src1->ne[0], + nb0*src1->ne[0]*src1->ne[1], + offset); + } break; + case 4: + { + grad_k = ggml_view_4d(ctx, + flash_grad, + src1->ne[0], + src1->ne[1], + src1->ne[2], + src1->ne[3], + nb0*src1->ne[0], + nb0*src1->ne[0]*src1->ne[1], + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2], + offset); + } break; + } + + src1->grad = ggml_add_impl(ctx, + src1->grad, + grad_k, + inplace); + } + + struct ggml_tensor * opt0 = tensor->src[2]; + + if (opt0->grad) { + struct ggml_tensor * grad_v = NULL; + const size_t nb0 = flash_grad->nb[0]; + const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3] + + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3]; + switch(opt0->n_dims) { + case 2: + { + grad_v = ggml_view_2d(ctx, + flash_grad, + opt0->ne[0], + opt0->ne[1], + nb0*opt0->ne[0], + offset); + } break; + case 3: + { + grad_v = ggml_view_3d(ctx, + flash_grad, + opt0->ne[0], + opt0->ne[1], + opt0->ne[2], + nb0*opt0->ne[0], + nb0*opt0->ne[0]*opt0->ne[1], + offset); + } break; + case 4: + { + grad_v = ggml_view_4d(ctx, + flash_grad, + opt0->ne[0], + opt0->ne[1], + opt0->ne[2], + opt0->ne[3], + nb0*opt0->ne[0], + nb0*opt0->ne[0]*opt0->ne[1], + nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2], + offset); + } break; + } + + opt0->grad = ggml_add_impl(ctx, + opt0->grad, + grad_v, + inplace); + } + } break; + case GGML_OP_FLASH_FF: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_FLASH_ATTN_BACK: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_UNARY: + { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_ABS: + { + if (src0->grad) { + src0->grad = + ggml_add_impl(ctx, + src0->grad, + ggml_mul(ctx, + ggml_sgn(ctx, src0), + tensor->grad), + inplace); + } + } break; + case GGML_UNARY_OP_SGN: + { + if (src0->grad) { + // noop + } + } break; + case GGML_UNARY_OP_NEG: + { + if (src0->grad) { + src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace); + } + } break; + case GGML_UNARY_OP_STEP: + { + if (src0->grad) { + // noop + } + } break; + case GGML_UNARY_OP_TANH: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_UNARY_OP_ELU: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_UNARY_OP_RELU: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_mul(ctx, + ggml_step(ctx, src0), + tensor->grad), + inplace); + } + } break; + case GGML_UNARY_OP_GELU: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_UNARY_OP_SILU: + { + // necessary for llama + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_silu_back(ctx, src0, tensor->grad), + inplace); + } + } break; + default: + GGML_ASSERT(false); + } + } break; + case GGML_OP_GET_REL_POS: + case GGML_OP_ADD_REL_POS: + case GGML_OP_MAP_UNARY: + case GGML_OP_MAP_BINARY: + case GGML_OP_MAP_CUSTOM1_F32: + case GGML_OP_MAP_CUSTOM2_F32: + case GGML_OP_MAP_CUSTOM3_F32: + case GGML_OP_MAP_CUSTOM1: + case GGML_OP_MAP_CUSTOM2: + case GGML_OP_MAP_CUSTOM3: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + { + if (src0->grad) { + src0->grad = ggml_add_impl(ctx, + src0->grad, + ggml_cross_entropy_loss_back(ctx, + src0, + src1, + tensor->grad), + inplace); + } + } break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + GGML_ASSERT(false); // not supported + } break; + case GGML_OP_NONE: + { + // nop + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + +static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small"); + +static size_t hash(void * p) { + return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE; +} + +static bool hash_insert(void * hash_table[], void * p) { + size_t h = hash(p); + + // linear probing + size_t i = h; + while (hash_table[i] != NULL && hash_table[i] != p) { + i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE; + if (i == h) { + // hash table is full + GGML_ASSERT(false); + } + } + + if (hash_table[i] == p) { + return true; + } + + // insert + hash_table[i] = p; + return false; +} + +static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { + if (node->grad == NULL) { + // this usually happens when we generate intermediate nodes from constants in the backward pass + // it can also happen during forward pass, if the user performs computations with constants + if (node->op != GGML_OP_NONE) { + //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); + } + } + + // check if already visited + if (hash_insert(cgraph->visited_hash_table, node)) { + return; + } + + for (int i = 0; i < GGML_MAX_SRC; ++i) { + if (node->src[i]) { + ggml_visit_parents(cgraph, node->src[i]); + } + } + + if (node->op == GGML_OP_NONE && node->grad == NULL) { + // reached a leaf node, not part of the gradient graph (e.g. a constant) + GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES); + + if (strlen(node->name) == 0) { + ggml_format_name(node, "leaf_%d", cgraph->n_leafs); + } + + cgraph->leafs[cgraph->n_leafs] = node; + cgraph->n_leafs++; + } else { + GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES); + + if (strlen(node->name) == 0) { + ggml_format_name(node, "node_%d", cgraph->n_nodes); + } + + cgraph->nodes[cgraph->n_nodes] = node; + cgraph->grads[cgraph->n_nodes] = node->grad; + cgraph->n_nodes++; + } +} + +static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { + if (!expand) { + cgraph->n_nodes = 0; + cgraph->n_leafs = 0; + } + + const int n0 = cgraph->n_nodes; + UNUSED(n0); + + ggml_visit_parents(cgraph, tensor); + + const int n_new = cgraph->n_nodes - n0; + GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); + + if (n_new > 0) { + // the last added node should always be starting point + GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor); + } +} + +void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { + ggml_build_forward_impl(cgraph, tensor, true); +} + +struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { + struct ggml_cgraph result = { + /*.n_nodes =*/ 0, + /*.n_leafs =*/ 0, + /*.nodes =*/ { NULL }, + /*.grads =*/ { NULL }, + /*.leafs =*/ { NULL }, + /*.hash_table =*/ { NULL }, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + }; + + ggml_build_forward_impl(&result, tensor, false); + + return result; +} + +void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) { + GGML_ASSERT(gf->n_nodes > 0); + + // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph + if (keep) { + for (int i = 0; i < gf->n_nodes; i++) { + struct ggml_tensor * node = gf->nodes[i]; + + if (node->grad) { + node->grad = ggml_dup_tensor(ctx, node); + gf->grads[i] = node->grad; + } + } + } + + for (int i = gf->n_nodes - 1; i >= 0; i--) { + struct ggml_tensor * node = gf->nodes[i]; + + // because we detached the grad nodes from the original graph, we can afford inplace operations + if (node->grad) { + ggml_compute_backward(ctx, node, keep); + } + } + + for (int i = 0; i < gf->n_nodes; i++) { + struct ggml_tensor * node = gf->nodes[i]; + + if (node->is_param) { + GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); + ggml_build_forward_expand(gb, node->grad); + } + } +} + +struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) { + struct ggml_cgraph result = *gf; + ggml_build_backward_expand(ctx, gf, &result, keep); + return result; +} + +struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) { + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, GGML_GRAPH_SIZE); + struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); + + *cgraph = (struct ggml_cgraph) { + /*.n_nodes =*/ 0, + /*.n_leafs =*/ 0, + /*.nodes =*/ { NULL }, + /*.grads =*/ { NULL }, + /*.leafs =*/ { NULL }, + /*.hash_table =*/ { NULL }, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + }; + + return cgraph; +} + +struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor) { + struct ggml_cgraph * cgraph = ggml_new_graph(ctx); + ggml_build_forward_impl(cgraph, tensor, false); + return cgraph; +} + +size_t ggml_graph_overhead(void) { + return GGML_OBJECT_SIZE + GGML_PAD(GGML_GRAPH_SIZE, GGML_MEM_ALIGN); +} + +// +// thread data +// +// synchronization is done via busy loops +// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops +// + +#ifdef __APPLE__ + +//#include +// +//typedef os_unfair_lock ggml_lock_t; +// +//#define ggml_lock_init(x) UNUSED(x) +//#define ggml_lock_destroy(x) UNUSED(x) +//#define ggml_lock_lock os_unfair_lock_lock +//#define ggml_lock_unlock os_unfair_lock_unlock +// +//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT + +typedef int ggml_lock_t; + +#define ggml_lock_init(x) UNUSED(x) +#define ggml_lock_destroy(x) UNUSED(x) +#define ggml_lock_lock(x) UNUSED(x) +#define ggml_lock_unlock(x) UNUSED(x) + +#define GGML_LOCK_INITIALIZER 0 + +typedef pthread_t ggml_thread_t; + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#else + +//typedef pthread_spinlock_t ggml_lock_t; + +//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE) +//#define ggml_lock_destroy pthread_spin_destroy +//#define ggml_lock_lock pthread_spin_lock +//#define ggml_lock_unlock pthread_spin_unlock + +typedef int ggml_lock_t; + +#define ggml_lock_init(x) UNUSED(x) +#define ggml_lock_destroy(x) UNUSED(x) +#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) +#define ggml_lock_lock(x) _mm_pause() +#else +#define ggml_lock_lock(x) UNUSED(x) +#endif +#define ggml_lock_unlock(x) UNUSED(x) + +#define GGML_LOCK_INITIALIZER 0 + +typedef pthread_t ggml_thread_t; + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#endif + +// Android's libc implementation "bionic" does not support setting affinity +#if defined(__linux__) && !defined(__BIONIC__) +static void set_numa_thread_affinity(int thread_n, int n_threads) { + if (!ggml_is_numa()) { + return; + } + + // run thread on node_num thread_n / (threads per node) + const int node_num = thread_n / ((n_threads + g_state.numa.n_nodes - 1) / g_state.numa.n_nodes); + struct ggml_numa_node * node = &g_state.numa.nodes[node_num]; + size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); + + cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); + CPU_ZERO_S(setsize, cpus); + for (size_t i = 0; i < node->n_cpus; ++i) { + CPU_SET_S(node->cpus[i], setsize, cpus); + } + + int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", + strerror(rv)); + } + + CPU_FREE(cpus); +} + +static void clear_numa_thread_affinity(void) { + if (!ggml_is_numa()) { + return; + } + + size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); + + cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); + CPU_ZERO_S(setsize, cpus); + for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) { + CPU_SET_S(i, setsize, cpus); + } + + int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", + strerror(rv)); + } + + CPU_FREE(cpus); +} +#else +// TODO: Windows etc. +// (the linux implementation may also work on BSD, someone should test) +static void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads); } +static void clear_numa_thread_affinity(void) {} +#endif + +struct ggml_compute_state_shared { + const struct ggml_cgraph * cgraph; + const struct ggml_cplan * cplan; + + int64_t perf_node_start_cycles; + int64_t perf_node_start_time_us; + + const int n_threads; + + // synchronization primitives + atomic_int n_active; // num active threads + atomic_int node_n; // active graph node + + bool (*abort_callback)(void * data); // abort ggml_graph_compute when true + void * abort_callback_data; +}; + +struct ggml_compute_state { + ggml_thread_t thrd; + int ith; + struct ggml_compute_state_shared * shared; +}; + +static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) { + int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles; + int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us; + + node->perf_runs++; + node->perf_cycles += cycles_cur; + node->perf_time_us += time_us_cur; +} + +static thread_ret_t ggml_graph_compute_thread(void * data) { + struct ggml_compute_state * state = (struct ggml_compute_state *) data; + + const struct ggml_cgraph * cgraph = state->shared->cgraph; + const struct ggml_cplan * cplan = state->shared->cplan; + + const int * n_tasks_arr = cplan->n_tasks; + const int n_threads = state->shared->n_threads; + + set_numa_thread_affinity(state->ith, n_threads); + + int node_n = -1; + + while (true) { + if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { + state->shared->node_n += 1; + return (thread_ret_t) GGML_EXIT_ABORTED; + } + if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { + // all other threads are finished and spinning + // do finalize and init here so we don't have synchronize again + struct ggml_compute_params params = { + /*.type =*/ GGML_TASK_FINALIZE, + /*.ith =*/ 0, + /*.nth =*/ 0, + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + }; + + if (node_n != -1) { + /* FINALIZE */ + struct ggml_tensor * node = state->shared->cgraph->nodes[node_n]; + if (GGML_OP_HAS_FINALIZE[node->op]) { + params.nth = n_tasks_arr[node_n]; + ggml_compute_forward(¶ms, node); + } + ggml_graph_compute_perf_stats_node(node, state->shared); + } + + // distribute new work or execute it direct if 1T + while (++node_n < cgraph->n_nodes) { + GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes); + + struct ggml_tensor * node = cgraph->nodes[node_n]; + const int n_tasks = n_tasks_arr[node_n]; + + state->shared->perf_node_start_cycles = ggml_perf_cycles(); + state->shared->perf_node_start_time_us = ggml_perf_time_us(); + + params.nth = n_tasks; + + /* INIT */ + if (GGML_OP_HAS_INIT[node->op]) { + params.type = GGML_TASK_INIT; + ggml_compute_forward(¶ms, node); + } + + if (n_tasks == 1) { + // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1, + // they do something more efficient than spinning (?) + params.type = GGML_TASK_COMPUTE; + ggml_compute_forward(¶ms, node); + + if (GGML_OP_HAS_FINALIZE[node->op]) { + params.type = GGML_TASK_FINALIZE; + ggml_compute_forward(¶ms, node); + } + + ggml_graph_compute_perf_stats_node(node, state->shared); + } else { + break; + } + + if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { + break; + } + } + + atomic_store(&state->shared->n_active, n_threads); + atomic_store(&state->shared->node_n, node_n); + } else { + // wait for other threads to finish + const int last = node_n; + do { + //sched_yield(); + node_n = atomic_load(&state->shared->node_n); + } while (node_n == last); + } + + // check if we should stop + if (node_n >= cgraph->n_nodes) break; + + /* COMPUTE */ + struct ggml_tensor * node = cgraph->nodes[node_n]; + const int n_tasks = n_tasks_arr[node_n]; + + struct ggml_compute_params params = { + /*.type =*/ GGML_TASK_COMPUTE, + /*.ith =*/ state->ith, + /*.nth =*/ n_tasks, + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + }; + + if (state->ith < n_tasks) { + ggml_compute_forward(¶ms, node); + } + } + + return GGML_EXIT_SUCCESS; +} + +struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { + if (n_threads <= 0) { + n_threads = GGML_DEFAULT_N_THREADS; + } + + size_t work_size = 0; + + struct ggml_cplan cplan; + memset(&cplan, 0, sizeof(struct ggml_cplan)); + + // thread scheduling for the different operations + work buffer size estimation + for (int i = 0; i < cgraph->n_nodes; i++) { + int n_tasks = 1; + + struct ggml_tensor * node = cgraph->nodes[i]; + + switch (node->op) { + case GGML_OP_CPY: + case GGML_OP_DUP: + { + n_tasks = n_threads; + + size_t cur = 0; + if (ggml_is_quantized(node->type)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_ADD: + case GGML_OP_ADD1: + { + n_tasks = n_threads; + + size_t cur = 0; + + if (ggml_is_quantized(node->src[0]->type)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_ACC: + { + n_tasks = n_threads; + + size_t cur = 0; + + if (ggml_is_quantized(node->src[0]->type)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_SUB: + case GGML_OP_DIV: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + { + n_tasks = 1; + } break; + + case GGML_OP_UNARY: + { + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_RELU: + { + n_tasks = 1; + } break; + + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + { + n_tasks = n_threads; + } break; + } + } break; + case GGML_OP_SILU_BACK: + case GGML_OP_MUL: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_GROUP_NORM: + { + n_tasks = n_threads; + } break; + case GGML_OP_CONCAT: + case GGML_OP_MUL_MAT: + case GGML_OP_OUT_PROD: + { + n_tasks = n_threads; + + // TODO: use different scheduling for different matrix sizes + //const int nr0 = ggml_nrows(node->src[0]); + //const int nr1 = ggml_nrows(node->src[1]); + + //n_tasks = MIN(n_threads, MAX(1, nr0/128)); + //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks); + + size_t cur = 0; + const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; + +#if defined(GGML_USE_CUBLAS) + if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } else +#elif defined(GGML_USE_CLBLAST) + if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node); + } else +#endif +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + if (node->src[0]->type != GGML_TYPE_F32) { + // here we need memory just for single 2D matrix from src0 + cur = ggml_type_size(GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]); + } + } else +#endif + if (node->src[1]->type != vec_dot_type) { + cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type); + } else { + cur = 0; + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_SCALE: + { + n_tasks = 1; + } break; + case GGML_OP_SET: + case GGML_OP_CONT: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_GET_ROWS: + case GGML_OP_GET_ROWS_BACK: + case GGML_OP_DIAG: + { + n_tasks = 1; + } break; + case GGML_OP_DIAG_MASK_ZERO: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_ADD_REL_POS: + { + n_tasks = n_threads; + } break; + case GGML_OP_ALIBI: + { + n_tasks = 1; //TODO + } break; + case GGML_OP_CLAMP: + { + n_tasks = 1; //TODO + } break; + case GGML_OP_CONV_1D: + { + n_tasks = n_threads; + + GGML_ASSERT(node->src[0]->ne[3] == 1); + GGML_ASSERT(node->src[1]->ne[2] == 1); + GGML_ASSERT(node->src[1]->ne[3] == 1); + + size_t cur = 0; + const int nk = node->src[0]->ne[0]; + + if (node->src[0]->type == GGML_TYPE_F16 && + node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(ggml_fp16_t)*( + nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] + + ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1] + ); + } else if (node->src[0]->type == GGML_TYPE_F32 && + node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(float)*( + nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] + + ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1] + ); + } else { + GGML_ASSERT(false); + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_CONV_2D: + { + n_tasks = n_threads; + + const int64_t ne00 = node->src[0]->ne[0]; // W + const int64_t ne01 = node->src[0]->ne[1]; // H + const int64_t ne02 = node->src[0]->ne[2]; // C + const int64_t ne03 = node->src[0]->ne[3]; // N + + const int64_t ne10 = node->src[1]->ne[0]; // W + const int64_t ne11 = node->src[1]->ne[1]; // H + const int64_t ne12 = node->src[1]->ne[2]; // C + + const int64_t ne0 = node->ne[0]; + const int64_t ne1 = node->ne[1]; + const int64_t ne2 = node->ne[2]; + const int64_t nk = ne00*ne01; + const int64_t ew0 = nk * ne02; + + UNUSED(ne03); + UNUSED(ne2); + + size_t cur = 0; + + if (node->src[0]->type == GGML_TYPE_F16 && + node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0); + } else if (node->src[0]->type == GGML_TYPE_F32 && + node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(float)* (ne10*ne11*ne12); + } else { + GGML_ASSERT(false); + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + n_tasks = n_threads; + + const int64_t ne00 = node->src[0]->ne[0]; // W + const int64_t ne01 = node->src[0]->ne[1]; // H + const int64_t ne02 = node->src[0]->ne[2]; // Channels Out + const int64_t ne03 = node->src[0]->ne[3]; // Channels In + + const int64_t ne10 = node->src[1]->ne[0]; // W + const int64_t ne11 = node->src[1]->ne[1]; // H + const int64_t ne12 = node->src[1]->ne[2]; // Channels In + + size_t cur = 0; + cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; + cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_POOL_1D: + case GGML_OP_POOL_2D: + { + n_tasks = 1; + } break; + case GGML_OP_UPSCALE: + { + n_tasks = n_threads; + } break; + case GGML_OP_FLASH_ATTN: + { + n_tasks = n_threads; + + size_t cur = 0; + + const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); + + if (node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 + } + + if (node->src[1]->type == GGML_TYPE_F16) { + cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_FF: + { + n_tasks = n_threads; + + size_t cur = 0; + + if (node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 + } + + if (node->src[1]->type == GGML_TYPE_F16) { + cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_ATTN_BACK: + { + n_tasks = n_threads; + + size_t cur = 0; + + const int64_t D = node->src[0]->ne[0]; + const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); + const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back + if (node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } + + if (node->src[1]->type == GGML_TYPE_F16) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_GET_REL_POS: + case GGML_OP_MAP_UNARY: + case GGML_OP_MAP_BINARY: + case GGML_OP_MAP_CUSTOM1_F32: + case GGML_OP_MAP_CUSTOM2_F32: + case GGML_OP_MAP_CUSTOM3_F32: + { + n_tasks = 1; + } break; + case GGML_OP_MAP_CUSTOM1: + { + struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) node->op_params; + if (p->n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p->n_tasks, n_threads); + } + } break; + case GGML_OP_MAP_CUSTOM2: + { + struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) node->op_params; + if (p->n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p->n_tasks, n_threads); + } + } break; + case GGML_OP_MAP_CUSTOM3: + { + struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) node->op_params; + if (p->n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p->n_tasks, n_threads); + } + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + { + n_tasks = n_threads; + + size_t cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + n_tasks = n_threads; + } break; + case GGML_OP_NONE: + { + n_tasks = 1; + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; + } + + cplan.n_tasks[i] = n_tasks; + } + + if (work_size > 0) { + work_size += CACHE_LINE_SIZE*(n_threads - 1); + } + + cplan.n_threads = n_threads; + cplan.work_size = work_size; + cplan.work_data = NULL; + + return cplan; +} + +int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { + { + GGML_ASSERT(cplan); + GGML_ASSERT(cplan->n_threads > 0); + + if (cplan->work_size > 0) { + GGML_ASSERT(cplan->work_data); + } + + for (int i = 0; i < cgraph->n_nodes; ++i) { + if (cgraph->nodes[i]->op != GGML_OP_NONE) { + GGML_ASSERT(cplan->n_tasks[i] > 0); + } + } + } + + const int n_threads = cplan->n_threads; + + struct ggml_compute_state_shared state_shared = { + /*.cgraph =*/ cgraph, + /*.cgraph_plan =*/ cplan, + /*.perf_node_start_cycles =*/ 0, + /*.perf_node_start_time_us =*/ 0, + /*.n_threads =*/ n_threads, + /*.n_active =*/ n_threads, + /*.node_n =*/ -1, + /*.abort_callback =*/ NULL, + /*.abort_callback_data =*/ NULL, + }; + struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads); + + // create thread pool + if (n_threads > 1) { + for (int j = 1; j < n_threads; ++j) { + workers[j] = (struct ggml_compute_state) { + .thrd = 0, + .ith = j, + .shared = &state_shared, + }; + + const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); + GGML_ASSERT(rc == 0); + UNUSED(rc); + } + } + + workers[0].ith = 0; + workers[0].shared = &state_shared; + + const int64_t perf_start_cycles = ggml_perf_cycles(); + const int64_t perf_start_time_us = ggml_perf_time_us(); + + // this is a work thread too + int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]); + + // don't leave affinity set on the main thread + clear_numa_thread_affinity(); + + // join or kill thread pool + if (n_threads > 1) { + for (int j = 1; j < n_threads; j++) { + const int rc = ggml_thread_join(workers[j].thrd, NULL); + GGML_ASSERT(rc == 0); + } + } + + // performance stats (graph) + { + int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles; + int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us; + + cgraph->perf_runs++; + cgraph->perf_cycles += perf_cycles_cur; + cgraph->perf_time_us += perf_time_us_cur; + + GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n", + __func__, cgraph->perf_runs, + (double) perf_cycles_cur / (double) ggml_cycles_per_ms(), + (double) cgraph->perf_cycles / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs, + (double) perf_time_us_cur / 1000.0, + (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); + } + + return compute_status; +} + +void ggml_graph_reset(struct ggml_cgraph * cgraph) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * grad = cgraph->grads[i]; + + if (grad) { + ggml_set_zero(grad); + } + } +} + +void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) { + struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads); + + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); + + cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; + + ggml_graph_compute(cgraph, &cplan); +} + +struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) { + for (int i = 0; i < cgraph->n_leafs; i++) { + struct ggml_tensor * leaf = cgraph->leafs[i]; + + if (strcmp(leaf->name, name) == 0) { + return leaf; + } + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + if (strcmp(node->name, name) == 0) { + return node; + } + } + + return NULL; +} + +static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fout) { + const int64_t * ne = tensor->ne; + const size_t * nb = tensor->nb; + + fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", + ggml_type_name(tensor->type), + ggml_op_name (tensor->op), + tensor->n_dims, + ne[0], ne[1], ne[2], ne[3], + nb[0], nb[1], nb[2], nb[3], + tensor->data, + tensor->name); +} + +static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char * arg, FILE * fout) { + const int64_t * ne = tensor->ne; + const size_t * nb = tensor->nb; + + fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", + arg, + ggml_type_name(tensor->type), + ggml_op_name (tensor->op), + tensor->n_dims, + ne[0], ne[1], ne[2], ne[3], + nb[0], nb[1], nb[2], nb[3], + tensor->data, + tensor->name); +} + +void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { + uint64_t size_eval = 0; + + // compute size of intermediate results + // TODO: does not take into account scratch buffers !!!! + for (int i = 0; i < cgraph->n_nodes; ++i) { + size_eval += ggml_nbytes_pad(cgraph->nodes[i]); + } + + // print + { + FILE * fout = stdout; + + fprintf(fout, "\n"); + fprintf(fout, "%-16s %8x\n", "magic", GGML_FILE_MAGIC); + fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION); + fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs); + fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes); + fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval); + + // header + fprintf(fout, "\n"); + fprintf(fout, "%-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %16s %16s\n", + "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "DATA", "NAME"); + + for (int i = 0; i < cgraph->n_leafs; ++i) { + ggml_graph_export_leaf(cgraph->leafs[i], fout); + + GGML_ASSERT(cgraph->leafs[i]->op == GGML_OP_NONE); + GGML_ASSERT(cgraph->leafs[i]->src[0] == NULL); + GGML_ASSERT(cgraph->leafs[i]->src[1] == NULL); + } + + // header + fprintf(fout, "\n"); + fprintf(fout, "%-6s %-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %8s %16s %16s\n", + "ARG", "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "NTASKS", "DATA", "NAME"); + + for (int i = 0; i < cgraph->n_nodes; ++i) { + ggml_graph_export_node(cgraph->nodes[i], "DST", fout); + + for (int j = 0; j < GGML_MAX_SRC; ++j) { + if (cgraph->nodes[i]->src[j]) { + ggml_graph_export_node(cgraph->nodes[i]->src[j], "SRC", fout); + } + } + + fprintf(fout, "\n"); + } + + fprintf(fout, "\n"); + } + + // write binary data + { + FILE * fout = fopen(fname, "wb"); + + if (!fout) { + fprintf(stderr, "%s: failed to open %s\n", __func__, fname); + return; + } + + // header + { + const uint32_t magic = GGML_FILE_MAGIC; + const uint32_t version = GGML_FILE_VERSION; + const uint32_t n_leafs = cgraph->n_leafs; + const uint32_t nodes = cgraph->n_nodes; + + fwrite(&magic, sizeof(uint32_t), 1, fout); + fwrite(&version, sizeof(uint32_t), 1, fout); + fwrite(&n_leafs, sizeof(uint32_t), 1, fout); + fwrite(&nodes, sizeof(uint32_t), 1, fout); + fwrite(&size_eval, sizeof(uint64_t), 1, fout); + } + + // leafs + { + for (int i = 0; i < cgraph->n_leafs; ++i) { + const struct ggml_tensor * tensor = cgraph->leafs[i]; + + const uint32_t type = tensor->type; + const uint32_t op = tensor->op; + const uint32_t n_dims = tensor->n_dims; + + fwrite(&type, sizeof(uint32_t), 1, fout); + fwrite(&op, sizeof(uint32_t), 1, fout); + fwrite(&n_dims, sizeof(uint32_t), 1, fout); + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + const uint64_t ne = tensor->ne[j]; + const uint64_t nb = tensor->nb[j]; + + fwrite(&ne, sizeof(uint64_t), 1, fout); + fwrite(&nb, sizeof(uint64_t), 1, fout); + } + + fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); + fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout); + + // dump the data + // TODO: pad this to 32 byte boundary + { + const size_t size = ggml_nbytes(tensor); + + fwrite(tensor->data, sizeof(char), size, fout); + } + } + } + + // nodes + { + for (int i = 0; i < cgraph->n_nodes; ++i) { + const struct ggml_tensor * tensor = cgraph->nodes[i]; + + const uint32_t type = tensor->type; + const uint32_t op = tensor->op; + const uint32_t n_dims = tensor->n_dims; + + fwrite(&type, sizeof(uint32_t), 1, fout); + fwrite(&op, sizeof(uint32_t), 1, fout); + fwrite(&n_dims, sizeof(uint32_t), 1, fout); + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + const uint64_t ne = tensor->ne[j]; + const uint64_t nb = tensor->nb[j]; + + fwrite(&ne, sizeof(uint64_t), 1, fout); + fwrite(&nb, sizeof(uint64_t), 1, fout); + } + + fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); + fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout); + + // output the op arguments + { + struct ggml_tensor * args[GGML_MAX_SRC] = { NULL }; + + for (int j = 0; j < GGML_MAX_SRC; ++j) { + args[j] = tensor->src[j]; + } + + for (int j = 0; j < GGML_MAX_SRC; ++j) { + if (args[j]) { + int32_t idx = -1; + + // check if leaf + { + for (int k = 0; k < cgraph->n_leafs; ++k) { + if (args[j] == cgraph->leafs[k]) { + idx = k; + break; + } + } + } + + // check if node + if (idx == -1) { + for (int k = 0; k < cgraph->n_nodes; ++k) { + if (args[j] == cgraph->nodes[k]) { + idx = GGML_MAX_NODES + k; + break; + } + } + } + + if (idx == -1) { + fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i); + return; + } + + fwrite(&idx, sizeof(int32_t), 1, fout); + } else { + const int32_t nul = -1; + + fwrite(&nul, sizeof(int32_t), 1, fout); + } + } + } + } + } + + fclose(fout); + } +} + +struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) { + assert(*ctx_data == NULL); + assert(*ctx_eval == NULL); + + struct ggml_cgraph result = { 0 }; + + struct ggml_tensor * data = NULL; + + // read file into data + { + FILE * fin = fopen(fname, "rb"); + if (!fin) { + fprintf(stderr, "%s: failed to open %s\n", __func__, fname); + return result; + } + + size_t fsize = 0; + + fseek(fin, 0, SEEK_END); + fsize = ftell(fin); + fseek(fin, 0, SEEK_SET); + + // create the data context + { + const size_t overhead = 1*ggml_tensor_overhead(); + + struct ggml_init_params params = { + .mem_size = fsize + overhead, + .mem_buffer = NULL, + .no_alloc = false, + }; + + *ctx_data = ggml_init(params); + + if (!*ctx_data) { + fprintf(stderr, "%s: failed to create ggml context\n", __func__); + fclose(fin); + return result; + } + } + + data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize); + + { + const size_t ret = fread(data->data, sizeof(char), fsize, fin); + if (ret != fsize) { + fprintf(stderr, "%s: failed to read %s\n", __func__, fname); + fclose(fin); + return result; + } + } + + fclose(fin); + } + + // populate result + { + char * ptr = (char *) data->data; + + const uint32_t magic = *(const uint32_t *) ptr; ptr += sizeof(magic); + + if (magic != GGML_FILE_MAGIC) { + fprintf(stderr, "%s: invalid magic number, got %08x\n", __func__, magic); + return result; + } + + const uint32_t version = *(const uint32_t *) ptr; ptr += sizeof(version); + + if (version != GGML_FILE_VERSION) { + fprintf(stderr, "%s: invalid version number\n", __func__); + return result; + } + + const uint32_t n_leafs = *(const uint32_t *) ptr; ptr += sizeof(n_leafs); + const uint32_t n_nodes = *(const uint32_t *) ptr; ptr += sizeof(n_nodes); + const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval); + + result.n_leafs = n_leafs; + result.n_nodes = n_nodes; + + // create the data context + { + const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead(); + + struct ggml_init_params params = { + .mem_size = size_eval + overhead, + .mem_buffer = NULL, + .no_alloc = true, + }; + + *ctx_eval = ggml_init(params); + + if (!*ctx_eval) { + fprintf(stderr, "%s: failed to create ggml context\n", __func__); + return result; + } + } + + // leafs + { + uint32_t type; + uint32_t op; + uint32_t n_dims; + + for (uint32_t i = 0; i < n_leafs; ++i) { + type = *(const uint32_t *) ptr; ptr += sizeof(type); + op = *(const uint32_t *) ptr; ptr += sizeof(op); + n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims); + + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + uint64_t ne_cur; + uint64_t nb_cur; + + ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur); + nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur); + + ne[j] = ne_cur; + nb[j] = nb_cur; + } + + struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne); + + tensor->op = (enum ggml_op) op; + + memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME; + memcpy(tensor->op_params, ptr, GGML_MAX_OP_PARAMS); ptr += GGML_MAX_OP_PARAMS; + + tensor->data = (void *) ptr; + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + tensor->nb[j] = nb[j]; + } + + result.leafs[i] = tensor; + + ptr += ggml_nbytes(tensor); + + fprintf(stderr, "%s: loaded leaf %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, ggml_nbytes(tensor)); + } + } + + ggml_set_no_alloc(*ctx_eval, false); + + // nodes + { + uint32_t type; + uint32_t op; + uint32_t n_dims; + + for (uint32_t i = 0; i < n_nodes; ++i) { + type = *(const uint32_t *) ptr; ptr += sizeof(type); + op = *(const uint32_t *) ptr; ptr += sizeof(op); + n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims); + + enum ggml_op eop = (enum ggml_op) op; + + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + uint64_t ne_cur; + uint64_t nb_cur; + + ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur); + nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur); + + ne[j] = ne_cur; + nb[j] = nb_cur; + } + + const char * ptr_name = ptr; ptr += GGML_MAX_NAME; + const char * ptr_op_params = ptr; ptr += GGML_MAX_OP_PARAMS; + + const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t); + + struct ggml_tensor * args[GGML_MAX_SRC] = { NULL }; + + // parse args + for (int j = 0; j < GGML_MAX_SRC; ++j) { + const int32_t arg_idx = ptr_arg_idx[j]; + + if (arg_idx == -1) { + continue; + } + + if (arg_idx < GGML_MAX_NODES) { + args[j] = result.leafs[arg_idx]; + } else { + args[j] = result.nodes[arg_idx - GGML_MAX_NODES]; + } + } + + // create the tensor + // "view" operations are handled differently + // TODO: handle inplace ops - currently a copy is always made + + struct ggml_tensor * tensor = NULL; + + switch (eop) { + // TODO: implement other view ops + case GGML_OP_RESHAPE: + { + tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]); + } break; + case GGML_OP_VIEW: + { + tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); + + size_t offs; + memcpy(&offs, ptr_op_params, sizeof(offs)); + + tensor->data = ((char *) tensor->data) + offs; + } break; + case GGML_OP_TRANSPOSE: + { + tensor = ggml_transpose(*ctx_eval, args[0]); + } break; + case GGML_OP_PERMUTE: + { + tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); + } break; + default: + { + tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne); + + tensor->op = eop; + } break; + } + + memcpy(tensor->name, ptr_name, GGML_MAX_NAME); + memcpy(tensor->op_params, ptr_op_params, GGML_MAX_OP_PARAMS); + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + tensor->nb[j] = nb[j]; + } + + for (int j = 0; j < GGML_MAX_SRC; ++j) { + tensor->src[j] = args[j]; + } + + result.nodes[i] = tensor; + + fprintf(stderr, "%s: loaded node %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, ggml_nbytes(tensor)); + } + } + } + + return result; +} + +void ggml_graph_print(const struct ggml_cgraph * cgraph) { + int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0}; + + GGML_PRINT("=== GRAPH ===\n"); + + GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us); + + GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", + i, + node->ne[0], node->ne[1], node->ne[2], + ggml_op_name(node->op), node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs, + (double) node->perf_cycles / (double) ggml_cycles_per_ms(), + (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs, + (double) node->perf_time_us / 1000.0, + (double) node->perf_time_us / 1000.0 / node->perf_runs); + } + + GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs); + for (int i = 0; i < cgraph->n_leafs; i++) { + struct ggml_tensor * node = cgraph->leafs[i]; + + GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n", + i, + node->ne[0], node->ne[1], + ggml_op_name(node->op)); + } + + for (int i = 0; i < GGML_OP_COUNT; i++) { + if (perf_total_per_op_us[i] == 0) { + continue; + } + + GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", ggml_op_name(i), (double) perf_total_per_op_us[i] / 1000.0); + } + + GGML_PRINT("========================================\n"); +} + +// check if node is part of the graph +static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + if (cgraph == NULL) { + return true; + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i] == node) { + return true; + } + } + + return false; +} + +static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * parent = cgraph->nodes[i]; + + if (parent->grad == node) { + return parent; + } + } + + return NULL; +} + +static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) { + struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node); + struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent); + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n", + gparent0 ? (void *) gparent0 : (void *) parent, + gparent0 ? "g" : "x", + gparent ? (void *) gparent : (void *) node, + gparent ? "g" : "x", + gparent ? "empty" : "vee", + gparent ? "dashed" : "solid", + label); +} + +static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) { + fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n", + (void *) parent, "x", + (void *) node, "x", + label); +} + +void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { + char color[16]; + + FILE * fp = fopen(filename, "w"); + GGML_ASSERT(fp); + + fprintf(fp, "digraph G {\n"); + fprintf(fp, " newrank = true;\n"); + fprintf(fp, " rankdir = LR;\n"); + + for (int i = 0; i < gb->n_nodes; i++) { + struct ggml_tensor * node = gb->nodes[i]; + + if (ggml_graph_get_parent(gb, node) != NULL) { + continue; + } + + if (node->is_param) { + snprintf(color, sizeof(color), "yellow"); + } else if (node->grad) { + if (ggml_graph_find(gf, node)) { + snprintf(color, sizeof(color), "green"); + } else { + snprintf(color, sizeof(color), "lightblue"); + } + } else { + snprintf(color, sizeof(color), "white"); + } + + fprintf(fp, " \"%p\" [ " + "style = filled; fillcolor = %s; shape = record; " + "label=\"", + (void *) node, color); + + if (strlen(node->name) > 0) { + fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type)); + } else { + fprintf(fp, "(%s)|", ggml_type_name(node->type)); + } + + if (node->n_dims == 2) { + fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op)); + } else { + fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op)); + } + + if (node->grad) { + fprintf(fp, " | %s\"; ]\n", ggml_op_symbol(node->grad->op)); + } else { + fprintf(fp, "\"; ]\n"); + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct ggml_tensor * node = gb->leafs[i]; + + snprintf(color, sizeof(color), "pink"); + + fprintf(fp, " \"%p\" [ " + "style = filled; fillcolor = %s; shape = record; " + "label=\"", + (void *) node, color); + + if (strlen(node->name) > 0) { + fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type)); + } else { + fprintf(fp, "(%s)|", ggml_type_name(node->type)); + } + + fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); + if (ggml_nelements(node) < 5) { + fprintf(fp, " | ("); + for (int j = 0; j < ggml_nelements(node); j++) { + if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { + fprintf(fp, "%d", ggml_get_i32_1d(node, j)); + } + else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) { + fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j)); + } + else { + fprintf(fp, "#"); + } + if (j < ggml_nelements(node) - 1) { + fprintf(fp, ", "); + } + } + fprintf(fp, ")"); + } + fprintf(fp, "\"; ]\n"); + } + + for (int i = 0; i < gb->n_nodes; i++) { + struct ggml_tensor * node = gb->nodes[i]; + + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j]) { + char label[16]; + snprintf(label, sizeof(label), "src %d", j); + ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label); + } + } + } + + for (int i = 0; i < gb->n_leafs; i++) { + struct ggml_tensor * node = gb->leafs[i]; + + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j]) { + char label[16]; + snprintf(label, sizeof(label), "src %d", j); + ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label); + } + } + } + + fprintf(fp, "}\n"); + + fclose(fp); + + GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); +} + +//////////////////////////////////////////////////////////////////////////////// + +static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]) ; + // TODO: add function to set tensor from array + for (int64_t j = 0; j < ne; ++j) { + ggml_set_f32_1d(ps[p], j, x[i++]); + } + } +} + +static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int64_t j = 0; j < ne; ++j) { + x[i++] = ggml_get_f32_1d(ps[p], j); + } + } +} + +static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) { + int i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int64_t j = 0; j < ne; ++j) { + g[i++] = ggml_get_f32_1d(ps[p]->grad, j); + } + } +} + +// +// ADAM +// +// ref: https://arxiv.org/pdf/1412.6980.pdf +// + +static enum ggml_opt_result ggml_opt_adam( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_opt_params params, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data) { + GGML_ASSERT(ggml_is_scalar(f)); + + // these will store the parameters we want to optimize + struct ggml_tensor * ps[GGML_MAX_PARAMS]; + + int np = 0; + int64_t nx = 0; + for (int i = 0; i < gf->n_nodes; ++i) { + if (gf->nodes[i]->is_param) { + GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); + + GGML_ASSERT(np < GGML_MAX_PARAMS); + + ps[np++] = gf->nodes[i]; + nx += ggml_nelements(gf->nodes[i]); + } + } + + if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) { + int iter = opt->iter; + ggml_opt_init(opt->ctx, opt, params, nx); + opt->iter = iter; + } + + // constants + float sched = params.adam.sched; + const float alpha = params.adam.alpha; + const float decay = params.adam.decay * alpha; + const float beta1 = params.adam.beta1; + const float beta2 = params.adam.beta2; + const float eps = params.adam.eps; + const float gclip = params.adam.gclip; + const int decay_min_ndim = params.adam.decay_min_ndim; + + float * m = opt->adam.m->data; // first moment + float * v = opt->adam.v->data; // second moment + + float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values + + if (callback) { + callback(callback_data, &sched); + } + + // compute the function value + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + + struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); + cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; + ggml_graph_compute(gb, &cplan); + + opt->adam.fx_prev = ggml_get_f32_1d(f, 0); + opt->adam.fx_best = opt->adam.fx_prev; + if (pf) { + pf[opt->iter % params.past] = opt->adam.fx_prev; + } + + opt->loss_before = opt->adam.fx_prev; + opt->loss_after = opt->adam.fx_prev; + + // initialize + if (opt->just_initialized) { + opt->adam.n_no_improvement = 0; + opt->just_initialized = false; + } + + float * fx_best = &opt->adam.fx_best; + float * fx_prev = &opt->adam.fx_prev; + int * n_no_improvement = &opt->adam.n_no_improvement; + + int iter0 = opt->iter; + + // run the optimizer + for (int t = 0; t < params.adam.n_iter; ++t) { + opt->iter = iter0 + t + 1; + GGML_PRINT_DEBUG ("=== iter %d ===\n", t); + + GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0)); + GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0)); + GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0)); + + for (int i = 0; i < np; ++i) { + GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i, + ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0)); + } + + const int64_t t_start_wall = ggml_time_us(); + const int64_t t_start_cpu = ggml_cycles(); + UNUSED(t_start_wall); + UNUSED(t_start_cpu); + + { + float gnorm = 1.0f; + if (gclip > 0.0f) { + // gradient clipping + ggml_float sum = 0.0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]); + for (int64_t j = 0; j < ne; ++j) { + float g = ggml_get_f32_1d(ps[p]->grad, j); + sum += (ggml_float)(g*g); + } + } + ggml_float norm = sqrt(sum); + if (norm > (ggml_float) gclip) { + gnorm = (float) ((ggml_float) gclip / norm); + } + } + const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter)); + const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter)); + int64_t i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]); + const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched; + for (int64_t j = 0; j < ne; ++j) { + float x = ggml_get_f32_1d(ps[p], j); + float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm; + m[i] = m[i]*beta1 + g*(1.0f - beta1); + v[i] = v[i]*beta2 + g*g*(1.0f - beta2); + float mh = m[i]*beta1h; + float vh = v[i]*beta2h; + vh = sqrtf(vh) + eps; + x = x*(1.0f - p_decay) - mh/vh; + ggml_set_f32_1d(ps[p], j, x); + ++i; + } + } + } + + if (callback) { + callback(callback_data, &sched); + } + + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + + ggml_graph_compute(gb, &cplan); + + const float fx = ggml_get_f32_1d(f, 0); + opt->loss_after = fx; + + + // check convergence + if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) { + GGML_PRINT_DEBUG("converged\n"); + + return GGML_OPT_OK; + } + + // delta-based convergence test + if (pf != NULL) { + // need at least params.past iterations to start checking for convergence + if (params.past <= iter0 + t) { + const float rate = (pf[(iter0 + t)%params.past] - fx)/fx; + + if (fabsf(rate) < params.delta) { + return GGML_OPT_OK; + } + } + + pf[(iter0 + t)%params.past] = fx; + } + + // check for improvement + if (params.max_no_improvement > 0) { + if (fx_best[0] > fx) { + fx_best[0] = fx; + n_no_improvement[0] = 0; + } else { + ++n_no_improvement[0]; + + if (n_no_improvement[0] >= params.max_no_improvement) { + return GGML_OPT_OK; + } + } + } + + fx_prev[0] = fx; + + { + const int64_t t_end_cpu = ggml_cycles(); + GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC); + UNUSED(t_end_cpu); + + const int64_t t_end_wall = ggml_time_us(); + GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6); + UNUSED(t_end_wall); + } + } + + return GGML_OPT_DID_NOT_CONVERGE; +} + +// +// L-BFGS +// +// the L-BFGS implementation below is based on the following implementation: +// +// https://github.com/chokkan/liblbfgs +// + +struct ggml_lbfgs_iteration_data { + float alpha; + float ys; + float * s; + float * y; +}; + +static enum ggml_opt_result linesearch_backtracking( + const struct ggml_opt_params * params, + int nx, + float * x, + float * fx, + float * g, + float * d, + float * step, + const float * xp, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + struct ggml_cplan * cplan, + const int np, + struct ggml_tensor * ps[], + ggml_opt_callback callback, + void * callback_data) { + int count = 0; + + float width = 0.0f; + float dg = 0.0f; + float finit = 0.0f; + float dginit = 0.0f; + float dgtest = 0.0f; + + const float dec = 0.5f; + const float inc = 2.1f; + + if (*step <= 0.f) { + return GGML_LINESEARCH_INVALID_PARAMETERS; + } + + // compute the initial gradient in the search direction + ggml_vec_dot_f32(nx, &dginit, g, d); + + // make sure that d points to a descent direction + if (0 < dginit) { + return GGML_LINESEARCH_FAIL; + } + + // initialize local variables + finit = *fx; + dgtest = params->lbfgs.ftol*dginit; + + while (true) { + if (callback) { + // LBFG-S does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, &sched); + } + + ggml_vec_cpy_f32(nx, x, xp); + ggml_vec_mad_f32(nx, x, d, *step); + + // evaluate the function and gradient values + { + ggml_opt_set_params(np, ps, x); + + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + + ggml_graph_compute(gb, cplan); + + ggml_opt_get_grad(np, ps, g); + + *fx = ggml_get_f32_1d(f, 0); + } + + ++count; + + if (*fx > finit + (*step)*dgtest) { + width = dec; + } else { + // Armijo condition is satisfied + if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) { + return count; + } + + ggml_vec_dot_f32(nx, &dg, g, d); + + // check the Wolfe condition + if (dg < params->lbfgs.wolfe * dginit) { + width = inc; + } else { + if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) { + // regular Wolfe conditions + return count; + } + + if(dg > -params->lbfgs.wolfe*dginit) { + width = dec; + } else { + // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) + return count; + } + } + } + + if (*step < params->lbfgs.min_step) { + return GGML_LINESEARCH_MINIMUM_STEP; + } + if (*step > params->lbfgs.max_step) { + return GGML_LINESEARCH_MAXIMUM_STEP; + } + if (params->lbfgs.max_linesearch <= count) { + return GGML_LINESEARCH_MAXIMUM_ITERATIONS; + } + + (*step) *= width; + } + + return GGML_LINESEARCH_FAIL; +} + +static enum ggml_opt_result ggml_opt_lbfgs( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_opt_params params, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data) { + if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || + params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { + if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) { + return GGML_OPT_INVALID_WOLFE; + } + } + + const int m = params.lbfgs.m; + + // these will store the parameters we want to optimize + struct ggml_tensor * ps[GGML_MAX_PARAMS]; + + int np = 0; + int nx = 0; + for (int i = 0; i < gf->n_nodes; ++i) { + if (gf->nodes[i]->is_param) { + GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); + + GGML_ASSERT(np < GGML_MAX_PARAMS); + + ps[np++] = gf->nodes[i]; + nx += ggml_nelements(gf->nodes[i]); + } + } + + if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) { + int iter = opt->iter; + ggml_opt_init(ctx, opt, params, nx); + opt->iter = iter; + } + + struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); + cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; + + float * x = opt->lbfgs.x->data; // current parameters + float * xp = opt->lbfgs.xp->data; // previous parameters + float * g = opt->lbfgs.g->data; // current gradient + float * gp = opt->lbfgs.gp->data; // previous gradient + float * d = opt->lbfgs.d->data; // search direction + + float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values + + float fx = 0.0f; // cost function value + float xnorm = 0.0f; // ||x|| + float gnorm = 0.0f; // ||g|| + + // initialize x from the graph nodes + ggml_opt_get_params(np, ps, x); + + // the L-BFGS memory + float * lm_alpha = opt->lbfgs.lmal->data; + float * lm_ys = opt->lbfgs.lmys->data; + float * lm_s = opt->lbfgs.lms->data; + float * lm_y = opt->lbfgs.lmy->data; + + if (callback) { + // LBFG-S does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, &sched); + } + + // evaluate the function value and its gradient + { + ggml_opt_set_params(np, ps, x); + + ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + + ggml_graph_compute(gb, &cplan); + + ggml_opt_get_grad(np, ps, g); + + fx = ggml_get_f32_1d(f, 0); + + opt->loss_before = fx; + opt->loss_after = fx; + } + + // search direction = -gradient + ggml_vec_neg_f32(nx, d, g); + + // ||x||, ||g|| + ggml_vec_norm_f32(nx, &xnorm, x); + ggml_vec_norm_f32(nx, &gnorm, g); + + if (xnorm < 1.0f) { + xnorm = 1.0f; + } + + // already optimized + if (gnorm/xnorm <= params.lbfgs.eps) { + return GGML_OPT_OK; + } + + if (opt->just_initialized) { + if (pf) { + pf[0] = fx; + } + opt->lbfgs.fx_best = fx; + + // initial step + ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d); + opt->lbfgs.j = 0; + opt->lbfgs.k = 1; + opt->lbfgs.end = 0; + opt->lbfgs.n_no_improvement = 0; + opt->just_initialized = false; + } + + float * fx_best = &opt->lbfgs.fx_best; + float * step = &opt->lbfgs.step; + int * j = &opt->lbfgs.j; + int * k = &opt->lbfgs.k; + int * end = &opt->lbfgs.end; + int * n_no_improvement = &opt->lbfgs.n_no_improvement; + + int ls = 0; + int bound = 0; + + float ys = 0.0f; + float yy = 0.0f; + float beta = 0.0f; + + int it = 0; + + while (true) { + // store the current position and gradient vectors + ggml_vec_cpy_f32(nx, xp, x); + ggml_vec_cpy_f32(nx, gp, g); + + ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data); + + if (ls < 0) { + // linesearch failed - go back to the previous point and return + ggml_vec_cpy_f32(nx, x, xp); + ggml_vec_cpy_f32(nx, g, gp); + + return ls; + } + + opt->loss_after = fx; + + ggml_vec_norm_f32(nx, &xnorm, x); + ggml_vec_norm_f32(nx, &gnorm, g); + + GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0)); + + if (xnorm < 1.0f) { + xnorm = 1.0f; + } + if (gnorm/xnorm <= params.lbfgs.eps) { + // converged + return GGML_OPT_OK; + } + + // delta-based convergence test + if (pf != NULL) { + // need at least params.past iterations to start checking for convergence + if (params.past <= k[0]) { + const float rate = (pf[k[0]%params.past] - fx)/fx; + + if (fabsf(rate) < params.delta) { + return GGML_OPT_OK; + } + } + + pf[k[0]%params.past] = fx; + } + + // check for improvement + if (params.max_no_improvement > 0) { + if (fx < fx_best[0]) { + fx_best[0] = fx; + n_no_improvement[0] = 0; + } else { + n_no_improvement[0]++; + + if (n_no_improvement[0] >= params.max_no_improvement) { + return GGML_OPT_OK; + } + } + } + + if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) { + // reached the maximum number of iterations + return GGML_OPT_DID_NOT_CONVERGE; + } + + // update vectors s and y: + // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}. + // y_{k+1} = g_{k+1} - g_{k}. + // + ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp); + ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp); + + // compute scalars ys and yy: + // ys = y^t \cdot s -> 1 / \rho. + // yy = y^t \cdot y. + // + ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]); + ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]); + + lm_ys[end[0]] = ys; + + // find new search direction + // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS + + bound = (m <= k[0]) ? m : k[0]; + k[0]++; + it++; + end[0] = (end[0] + 1)%m; + + // initialize search direction with -g + ggml_vec_neg_f32(nx, d, g); + + j[0] = end[0]; + for (int i = 0; i < bound; ++i) { + j[0] = (j[0] + m - 1) % m; + // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} + ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d); + lm_alpha[j[0]] /= lm_ys[j[0]]; + // q_{i} = q_{i+1} - \alpha_{i} y_{i} + ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]); + } + + ggml_vec_scale_f32(nx, d, ys/yy); + + for (int i = 0; i < bound; ++i) { + // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} + ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d); + beta /= lm_ys[j[0]]; + // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} + ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta); + j[0] = (j[0] + 1)%m; + } + + step[0] = 1.0; + } + + return GGML_OPT_DID_NOT_CONVERGE; +} + +struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { + struct ggml_opt_params result; + + switch (type) { + case GGML_OPT_ADAM: + { + result = (struct ggml_opt_params) { + .type = GGML_OPT_ADAM, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, + + .max_no_improvement = 100, + + .print_forward_graph = true, + .print_backward_graph = true, + + .adam = { + .n_iter = 10000, + .sched = 1.000f, + .decay = 0.0f, + .decay_min_ndim = 2, + .alpha = 0.001f, + .beta1 = 0.9f, + .beta2 = 0.999f, + .eps = 1e-8f, + .eps_f = 1e-5f, + .eps_g = 1e-3f, + .gclip = 0.0f, + }, + }; + } break; + case GGML_OPT_LBFGS: + { + result = (struct ggml_opt_params) { + .type = GGML_OPT_LBFGS, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, + + .max_no_improvement = 0, + + .print_forward_graph = true, + .print_backward_graph = true, + + .lbfgs = { + .m = 6, + .n_iter = 100, + .max_linesearch = 20, + + .eps = 1e-5f, + .ftol = 1e-4f, + .wolfe = 0.9f, + .min_step = 1e-20f, + .max_step = 1e+20f, + + .linesearch = GGML_LINESEARCH_DEFAULT, + }, + }; + } break; + } + + return result; +} + +GGML_API void ggml_opt_init( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_opt_params params, + int64_t nx) { + opt->ctx = ctx; + opt->params = params; + opt->iter = 0; + opt->nx = nx; + opt->just_initialized = true; + switch (opt->params.type) { + case GGML_OPT_ADAM: + { + opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->adam.pf = params.past > 0 + ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past) + : NULL; + ggml_set_zero(opt->adam.m); + ggml_set_zero(opt->adam.v); + if (opt->adam.pf) { + ggml_set_zero(opt->adam.pf); + } + } break; + case GGML_OPT_LBFGS: + { + opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->lbfgs.pf = params.past > 0 + ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past) + : NULL; + opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m); + opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m); + ggml_set_zero(opt->lbfgs.x); + ggml_set_zero(opt->lbfgs.xp); + ggml_set_zero(opt->lbfgs.g); + ggml_set_zero(opt->lbfgs.gp); + ggml_set_zero(opt->lbfgs.d); + if (opt->lbfgs.pf) { + ggml_set_zero(opt->lbfgs.pf); + } + ggml_set_zero(opt->lbfgs.lmal); + ggml_set_zero(opt->lbfgs.lmys); + ggml_set_zero(opt->lbfgs.lms); + ggml_set_zero(opt->lbfgs.lmy); + } break; + } +} + +enum ggml_opt_result ggml_opt( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f) { + bool free_ctx = false; + if (ctx == NULL) { + struct ggml_init_params params_ctx = { + .mem_size = 16*1024*1024, + .mem_buffer = NULL, + .no_alloc = false, + }; + + ctx = ggml_init(params_ctx); + if (ctx == NULL) { + return GGML_OPT_NO_CONTEXT; + } + + free_ctx = true; + } + + enum ggml_opt_result result = GGML_OPT_OK; + + struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context)); + + ggml_opt_init(ctx, opt, params, 0); + result = ggml_opt_resume(ctx, opt, f); + + if (free_ctx) { + ggml_free(ctx); + } + + return result; +} + +enum ggml_opt_result ggml_opt_resume( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f) { + + // build forward + backward compute graphs + struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0)); + struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0)); + + struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data; + struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data; + + *gf = ggml_build_forward (f); + *gb = ggml_build_backward(ctx, gf, true); + + return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL); +} + +enum ggml_opt_result ggml_opt_resume_g( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data) { + + // build forward + backward compute graphs + enum ggml_opt_result result = GGML_OPT_OK; + + switch (opt->params.type) { + case GGML_OPT_ADAM: + { + result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data); + } break; + case GGML_OPT_LBFGS: + { + result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data); + } break; + } + + if (opt->params.print_forward_graph) { + ggml_graph_print (gf); + ggml_graph_dump_dot(gf, NULL, "opt-forward.dot"); + } + + if (opt->params.print_backward_graph) { + ggml_graph_print (gb); + ggml_graph_dump_dot(gb, gf, "opt-backward.dot"); + } + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK4_0 == 0); + const int nb = k / QK4_0; + + for (int b = 0; b < n; b += k) { + block_q4_0 * restrict y = (block_q4_0 *) dst + b/QK4_0; + + quantize_row_q4_0_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK4_0; j += 2) { + const uint8_t vi0 = y[i].qs[j/2] & 0x0F; + const uint8_t vi1 = y[i].qs[j/2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK4_0*sizeof(block_q4_0)); +} + +size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK4_1 == 0); + const int nb = k / QK4_1; + + for (int b = 0; b < n; b += k) { + block_q4_1 * restrict y = (block_q4_1 *) dst + b/QK4_1; + + quantize_row_q4_1_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK4_1; j += 2) { + const uint8_t vi0 = y[i].qs[j/2] & 0x0F; + const uint8_t vi1 = y[i].qs[j/2] >> 4; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK4_1*sizeof(block_q4_1)); +} + +size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_0 == 0); + const int nb = k / QK5_0; + + for (int b = 0; b < n; b += k) { + block_q5_0 * restrict y = (block_q5_0 *)dst + b/QK5_0; + + quantize_row_q5_0_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + + for (int j = 0; j < QK5_0; j += 2) { + const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[j/2] >> 4) | vh1) / 2; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_0*sizeof(block_q5_0)); +} + +size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK5_1 == 0); + const int nb = k / QK5_1; + + for (int b = 0; b < n; b += k) { + block_q5_1 * restrict y = (block_q5_1 *)dst + b/QK5_1; + + quantize_row_q5_1_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, &y[i].qh, sizeof(qh)); + + for (int j = 0; j < QK5_1; j += 2) { + const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + // cast to 16 bins + const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2; + const uint8_t vi1 = ((y[i].qs[j/2] >> 4) | vh1) / 2; + + hist[vi0]++; + hist[vi1]++; + } + } + } + + return (n/QK5_1*sizeof(block_q5_1)); +} + +size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + for (int b = 0; b < n; b += k) { + block_q8_0 * restrict y = (block_q8_0 *)dst + b/QK8_0; + + quantize_row_q8_0_reference(src + b, y, k); + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK8_0; ++j) { + const int8_t vi = y[i].qs[j]; + + hist[vi/16 + 8]++; + } + } + } + + return (n/QK8_0*sizeof(block_q8_0)); +} + +size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) { + size_t result = 0; + switch (type) { + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(start % QK4_0 == 0); + block_q4_0 * block = (block_q4_0*)dst + start / QK4_0; + result = ggml_quantize_q4_0(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(start % QK4_1 == 0); + block_q4_1 * block = (block_q4_1*)dst + start / QK4_1; + result = ggml_quantize_q4_1(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q5_0: + { + GGML_ASSERT(start % QK5_0 == 0); + block_q5_0 * block = (block_q5_0*)dst + start / QK5_0; + result = ggml_quantize_q5_0(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q5_1: + { + GGML_ASSERT(start % QK5_1 == 0); + block_q5_1 * block = (block_q5_1*)dst + start / QK5_1; + result = ggml_quantize_q5_1(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q8_0: + { + GGML_ASSERT(start % QK8_0 == 0); + block_q8_0 * block = (block_q8_0*)dst + start / QK8_0; + result = ggml_quantize_q8_0(src + start, block, n, n, hist); + } break; +#ifdef GGML_USE_K_QUANTS + case GGML_TYPE_Q2_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q2_K * block = (block_q2_K*)dst + start / QK_K; + result = ggml_quantize_q2_K(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q3_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q3_K * block = (block_q3_K*)dst + start / QK_K; + result = ggml_quantize_q3_K(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q4_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q4_K * block = (block_q4_K*)dst + start / QK_K; + result = ggml_quantize_q4_K(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q5_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q5_K * block = (block_q5_K*)dst + start / QK_K; + result = ggml_quantize_q5_K(src + start, block, n, n, hist); + } break; + case GGML_TYPE_Q6_K: + { + GGML_ASSERT(start % QK_K == 0); + block_q6_K * block = (block_q6_K*)dst + start / QK_K; + result = ggml_quantize_q6_K(src + start, block, n, n, hist); + } break; +#endif + case GGML_TYPE_F16: + { + int elemsize = sizeof(ggml_fp16_t); + ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n); + result = n * elemsize; + } break; + case GGML_TYPE_F32: + { + int elemsize = sizeof(float); + result = n * elemsize; + memcpy((uint8_t *)dst + start * elemsize, src + start, result); + } break; + default: + assert(false); + } + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +struct gguf_str { + uint64_t n; // GGUFv2 + char * data; +}; + +static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = { + [GGUF_TYPE_UINT8] = sizeof(uint8_t), + [GGUF_TYPE_INT8] = sizeof(int8_t), + [GGUF_TYPE_UINT16] = sizeof(uint16_t), + [GGUF_TYPE_INT16] = sizeof(int16_t), + [GGUF_TYPE_UINT32] = sizeof(uint32_t), + [GGUF_TYPE_INT32] = sizeof(int32_t), + [GGUF_TYPE_FLOAT32] = sizeof(float), + [GGUF_TYPE_BOOL] = sizeof(bool), + [GGUF_TYPE_STRING] = sizeof(struct gguf_str), + [GGUF_TYPE_UINT64] = sizeof(uint64_t), + [GGUF_TYPE_INT64] = sizeof(int64_t), + [GGUF_TYPE_FLOAT64] = sizeof(double), + [GGUF_TYPE_ARRAY] = 0, // undefined +}; +static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); + +static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = { + [GGUF_TYPE_UINT8] = "u8", + [GGUF_TYPE_INT8] = "i8", + [GGUF_TYPE_UINT16] = "u16", + [GGUF_TYPE_INT16] = "i16", + [GGUF_TYPE_UINT32] = "u32", + [GGUF_TYPE_INT32] = "i32", + [GGUF_TYPE_FLOAT32] = "f32", + [GGUF_TYPE_BOOL] = "bool", + [GGUF_TYPE_STRING] = "str", + [GGUF_TYPE_ARRAY] = "arr", + [GGUF_TYPE_UINT64] = "u64", + [GGUF_TYPE_INT64] = "i64", + [GGUF_TYPE_FLOAT64] = "f64", +}; +static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); + +union gguf_value { + uint8_t uint8; + int8_t int8; + uint16_t uint16; + int16_t int16; + uint32_t uint32; + int32_t int32; + float float32; + uint64_t uint64; + int64_t int64; + double float64; + bool bool_; + + struct gguf_str str; + + struct { + enum gguf_type type; + + uint64_t n; // GGUFv2 + void * data; + } arr; +}; + +struct gguf_kv { + struct gguf_str key; + + enum gguf_type type; + union gguf_value value; +}; + +struct gguf_header { + uint32_t magic; + uint32_t version; + uint64_t n_tensors; // GGUFv2 + uint64_t n_kv; // GGUFv2 +}; + +struct gguf_tensor_info { + struct gguf_str name; + + uint32_t n_dims; + uint64_t ne[GGML_MAX_DIMS]; + + enum ggml_type type; + + uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT` + + // for writing API + const void * data; + size_t size; +}; + +struct gguf_context { + struct gguf_header header; + + struct gguf_kv * kv; + struct gguf_tensor_info * infos; + + size_t alignment; + size_t offset; // offset of `data` from beginning of file + size_t size; // size of `data` in bytes + + //uint8_t * padding; + void * data; +}; + +static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) { + const size_t n = fread(dst, 1, size, file); + *offset += n; + return n == size; +} + +// NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 +static bool gguf_fread_str_cur(FILE * file, struct gguf_str * p, size_t * offset) { + p->n = 0; + p->data = NULL; + + bool ok = true; + + ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset); p->data = calloc(p->n + 1, 1); + ok = ok && gguf_fread_el(file, p->data, p->n, offset); + + return ok; +} + +static bool gguf_fread_str_v1(FILE * file, struct gguf_str * p, size_t * offset) { + p->n = 0; + p->data = NULL; + + bool ok = true; + + uint32_t n = 0; + ok = ok && gguf_fread_el(file, &n, sizeof(n), offset); p->data = calloc(n + 1, 1); p->n = n; + ok = ok && gguf_fread_el(file, p->data, p->n, offset); + + return ok; +} + +struct gguf_context * gguf_init_empty(void) { + struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); + + ctx->header.magic = GGUF_MAGIC; + ctx->header.version = GGUF_VERSION; + ctx->header.n_tensors = 0; + ctx->header.n_kv = 0; + + ctx->kv = NULL; + ctx->infos = NULL; + + ctx->alignment = GGUF_DEFAULT_ALIGNMENT; + ctx->offset = 0; + ctx->size = 0; + + ctx->data = NULL; + + return ctx; +} + +struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { + FILE * file = fopen(fname, "rb"); + if (!file) { + return NULL; + } + + // offset from start of file + size_t offset = 0; + + uint32_t magic = 0; + + // check the magic before making allocations + { + gguf_fread_el(file, &magic, sizeof(magic), &offset); + + if (magic != GGUF_MAGIC) { + fprintf(stderr, "%s: invalid magic number %08x\n", __func__, magic); + fclose(file); + return NULL; + } + } + + bool ok = true; + + struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); + + // read the header + { + ctx->header.magic = magic; + + ctx->kv = NULL; + ctx->infos = NULL; + ctx->data = NULL; + + ok = ok && gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset); + + if (ctx->header.version == 1) { + // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 + uint32_t n_tensors = 0; + uint32_t n_kv = 0; + + ok = ok && gguf_fread_el(file, &n_tensors, sizeof(n_tensors), &offset); + ok = ok && gguf_fread_el(file, &n_kv, sizeof(n_kv), &offset); + + ctx->header.n_tensors = n_tensors; + ctx->header.n_kv = n_kv; + } else { + ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset); + ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset); + } + + if (!ok) { + fprintf(stderr, "%s: failed to read header\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; + } + } + + // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 + bool (* gguf_fread_str)(FILE *, struct gguf_str *, size_t *) = gguf_fread_str_cur; + if (ctx->header.version == 1) { + gguf_fread_str = gguf_fread_str_v1; + } + + // read the kv pairs + { + ctx->kv = malloc(ctx->header.n_kv * sizeof(struct gguf_kv)); + + for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { + struct gguf_kv * kv = &ctx->kv[i]; + + //fprintf(stderr, "%s: reading kv %d\n", __func__, i); + + ok = ok && gguf_fread_str(file, &kv->key, &offset); + ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset); + + //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data); + + switch (kv->type) { + case GGUF_TYPE_UINT8: ok = ok && gguf_fread_el (file, &kv->value.uint8, sizeof(kv->value.uint8), &offset); break; + case GGUF_TYPE_INT8: ok = ok && gguf_fread_el (file, &kv->value.int8, sizeof(kv->value.int8), &offset); break; + case GGUF_TYPE_UINT16: ok = ok && gguf_fread_el (file, &kv->value.uint16, sizeof(kv->value.uint16), &offset); break; + case GGUF_TYPE_INT16: ok = ok && gguf_fread_el (file, &kv->value.int16, sizeof(kv->value.int16), &offset); break; + case GGUF_TYPE_UINT32: ok = ok && gguf_fread_el (file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); break; + case GGUF_TYPE_INT32: ok = ok && gguf_fread_el (file, &kv->value.int32, sizeof(kv->value.int32), &offset); break; + case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break; + case GGUF_TYPE_UINT64: ok = ok && gguf_fread_el (file, &kv->value.uint64, sizeof(kv->value.uint64), &offset); break; + case GGUF_TYPE_INT64: ok = ok && gguf_fread_el (file, &kv->value.int64, sizeof(kv->value.int64), &offset); break; + case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break; + case GGUF_TYPE_BOOL: ok = ok && gguf_fread_el (file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); break; + case GGUF_TYPE_STRING: ok = ok && gguf_fread_str(file, &kv->value.str, &offset); break; + case GGUF_TYPE_ARRAY: + { + ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset); + + if (ctx->header.version == 1) { + // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 + uint32_t n = 0; + ok = ok && gguf_fread_el(file, &n, sizeof(n), &offset); + kv->value.arr.n = n; + } else { + ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset); + } + + switch (kv->value.arr.type) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_UINT64: + case GGUF_TYPE_INT64: + case GGUF_TYPE_FLOAT64: + case GGUF_TYPE_BOOL: + { + kv->value.arr.data = malloc(kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]); + ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type], &offset); + } break; + case GGUF_TYPE_STRING: + { + kv->value.arr.data = malloc(kv->value.arr.n * sizeof(struct gguf_str)); + for (uint32_t j = 0; j < kv->value.arr.n; ++j) { + ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset); + } + } break; + case GGUF_TYPE_ARRAY: + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; + }; + } break; + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); + }; + + if (!ok) { + break; + } + } + + if (!ok) { + fprintf(stderr, "%s: failed to read key-value pairs\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; + } + } + + // read the tensor infos + { + ctx->infos = malloc(ctx->header.n_tensors * sizeof(struct gguf_tensor_info)); + + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + info->ne[j] = 1; + } + + ok = ok && gguf_fread_str(file, &info->name, &offset); + ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset); + for (uint32_t j = 0; j < info->n_dims; ++j) { + if (ctx->header.version == 1) { + // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 + uint32_t t = 0; + ok = ok && gguf_fread_el(file, &t, sizeof(t), &offset); + info->ne[j] = t; + } else { + ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset); + } + } + ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset); + ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset); + + if (!ok) { + fprintf(stderr, "%s: failed to read tensor info\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; + } + } + } + + ctx->alignment = GGUF_DEFAULT_ALIGNMENT; + + int alignment_idx = gguf_find_key(ctx, "general.alignment"); + if (alignment_idx != -1) { + ctx->alignment = gguf_get_val_u32(ctx, alignment_idx); + } + + // we require the data section to be aligned, so take into account any padding + { + const size_t offset_pad = offset % ctx->alignment; + + if (offset_pad != 0) { + offset += ctx->alignment - offset_pad; + fseek(file, offset, SEEK_SET); + } + } + + // store the current file offset - this is where the data section starts + ctx->offset = offset; + + // compute the total size of the data section, taking into account the alignment + { + ctx->size = 0; + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + const int64_t ne = + (int64_t) info->ne[0] * + (int64_t) info->ne[1] * + (int64_t) info->ne[2] * + (int64_t) info->ne[3]; + + if (ne % ggml_blck_size(info->type) != 0) { + fprintf(stderr, "%s: tensor '%s' number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", + __func__, info->name.data, ne, ggml_blck_size(info->type)); + fclose(file); + gguf_free(ctx); + return NULL; + } + + const size_t size_cur = (ne*ggml_type_size(info->type))/ggml_blck_size(info->type); + + ctx->size += GGML_PAD(size_cur, ctx->alignment); + } + } + + // load the tensor data only if requested + if (params.ctx != NULL) { + // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob + // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of + // the ggml_tensor structs to the appropriate locations in the binary blob + + // compute the exact size needed for the new ggml_context + const size_t mem_size = + params.no_alloc ? + (ctx->header.n_tensors )*ggml_tensor_overhead() : + (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size; + + struct ggml_init_params pdata = { + .mem_size = mem_size, + .mem_buffer = NULL, + .no_alloc = params.no_alloc, + }; + + *params.ctx = ggml_init(pdata); + + struct ggml_context * ctx_data = *params.ctx; + + struct ggml_tensor * data = NULL; + + if (!params.no_alloc) { + data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size); + + ok = ok && data != NULL; + + // read the binary blob with the tensor data + ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset); + + if (!ok) { + fprintf(stderr, "%s: failed to read tensor data\n", __func__); + fclose(file); + ggml_free(ctx_data); + gguf_free(ctx); + return NULL; + } + + ctx->data = data->data; + } + + ggml_set_no_alloc(ctx_data, true); + + // create the tensors + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + const int64_t ne[GGML_MAX_DIMS] = { + ctx->infos[i].ne[0], + ctx->infos[i].ne[1], + ctx->infos[i].ne[2], + ctx->infos[i].ne[3], + }; + + struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne); + + ok = ok && cur != NULL; + + ggml_set_name(cur, ctx->infos[i].name.data); + + if (!ok) { + break; + } + + // point the data member to the appropriate location in the binary blob using the tensor infos + if (!params.no_alloc) { + //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file + cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data + } + } + + if (!ok) { + fprintf(stderr, "%s: failed to read the tensor data\n", __func__); + fclose(file); + ggml_free(ctx_data); + gguf_free(ctx); + return NULL; + } + + ggml_set_no_alloc(ctx_data, params.no_alloc); + } + + fclose(file); + + return ctx; +} + +void gguf_free(struct gguf_context * ctx) { + if (ctx == NULL) { + return; + } + + if (ctx->kv) { + // free string memory - not great.. + for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { + struct gguf_kv * kv = &ctx->kv[i]; + + if (kv->key.data) { + free(kv->key.data); + } + + if (kv->type == GGUF_TYPE_STRING) { + if (kv->value.str.data) { + free(kv->value.str.data); + } + } + + if (kv->type == GGUF_TYPE_ARRAY) { + if (kv->value.arr.data) { + if (kv->value.arr.type == GGUF_TYPE_STRING) { + for (uint32_t j = 0; j < kv->value.arr.n; ++j) { + struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j]; + if (str->data) { + free(str->data); + } + } + } + free(kv->value.arr.data); + } + } + } + + free(ctx->kv); + } + + if (ctx->infos) { + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + if (info->name.data) { + free(info->name.data); + } + } + + free(ctx->infos); + } + + GGML_ALIGNED_FREE(ctx); +} + +const char * gguf_type_name(enum gguf_type type) { + return GGUF_TYPE_NAME[type]; +} + +int gguf_get_version(struct gguf_context * ctx) { + return ctx->header.version; +} + +size_t gguf_get_alignment(struct gguf_context * ctx) { + return ctx->alignment; +} + +size_t gguf_get_data_offset(struct gguf_context * ctx) { + return ctx->offset; +} + +void * gguf_get_data(struct gguf_context * ctx) { + return ctx->data; +} + +int gguf_get_n_kv(struct gguf_context * ctx) { + return ctx->header.n_kv; +} + +int gguf_find_key(struct gguf_context * ctx, const char * key) { + // return -1 if key not found + int keyfound = -1; + + const int n_kv = gguf_get_n_kv(ctx); + + for (int i = 0; i < n_kv; ++i) { + if (strcmp(key, gguf_get_key(ctx, i)) == 0) { + keyfound = i; + break; + } + } + + return keyfound; +} + +const char * gguf_get_key(struct gguf_context * ctx, int i) { + return ctx->kv[i].key.data; +} + +enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) { + return ctx->kv[i].type; +} + +enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.arr.type; +} + +const void * gguf_get_arr_data(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.arr.data; +} + +const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) { + struct gguf_kv * kv = &ctx->kv[key_id]; + struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; + return str->data; +} + +int gguf_get_arr_n(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.arr.n; +} + +uint8_t gguf_get_val_u8(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.uint8; +} + +int8_t gguf_get_val_i8(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.int8; +} + +uint16_t gguf_get_val_u16(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.uint16; +} + +int16_t gguf_get_val_i16(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.int16; +} + +uint32_t gguf_get_val_u32(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.uint32; +} + +int32_t gguf_get_val_i32(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.int32; +} + +float gguf_get_val_f32(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.float32; +} + +uint64_t gguf_get_val_u64(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.uint64; +} + +int64_t gguf_get_val_i64(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.int64; +} + +double gguf_get_val_f64(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.float64; +} + +bool gguf_get_val_bool(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.bool_; +} + +const char * gguf_get_val_str (struct gguf_context * ctx, int i) { + return ctx->kv[i].value.str.data; +} + +int gguf_get_n_tensors(struct gguf_context * ctx) { + return ctx->header.n_tensors; +} + +int gguf_find_tensor(struct gguf_context * ctx, const char * name) { + // return -1 if tensor not found + int tensorfound = -1; + + const int n_tensors = gguf_get_n_tensors(ctx); + + for (int i = 0; i < n_tensors; ++i) { + if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) { + tensorfound = i; + break; + } + } + + return tensorfound; +} + +size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i) { + return ctx->infos[i].offset; +} + +char * gguf_get_tensor_name(struct gguf_context * ctx, int i) { + return ctx->infos[i].name.data; +} + +// returns the index +static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) { + const int idx = gguf_find_key(ctx, key); + if (idx >= 0) { + return idx; + } + + const int n_kv = gguf_get_n_kv(ctx); + + ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv)); + ctx->kv[n_kv].key.n = strlen(key); + ctx->kv[n_kv].key.data = strdup(key); + ctx->header.n_kv++; + + return n_kv; +} + +void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT8; + ctx->kv[idx].value.uint8 = val; +} + +void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT8; + ctx->kv[idx].value.int8 = val; +} + +void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT16; + ctx->kv[idx].value.uint16 = val; +} + +void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT16; + ctx->kv[idx].value.int16 = val; +} + +void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT32; + ctx->kv[idx].value.uint32 = val; +} + +void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT32; + ctx->kv[idx].value.int32 = val; +} + +void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_FLOAT32; + ctx->kv[idx].value.float32 = val; +} + +void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT64; + ctx->kv[idx].value.uint64 = val; +} + +void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT64; + ctx->kv[idx].value.int64 = val; +} + +void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_FLOAT64; + ctx->kv[idx].value.float64 = val; +} + +void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_BOOL; + ctx->kv[idx].value.bool_ = val; +} + +void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_STRING; + ctx->kv[idx].value.str.n = strlen(val); + ctx->kv[idx].value.str.data = strdup(val); +} + +void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_ARRAY; + ctx->kv[idx].value.arr.type = type; + ctx->kv[idx].value.arr.n = n; + ctx->kv[idx].value.arr.data = malloc(n*GGUF_TYPE_SIZE[type]); + memcpy(ctx->kv[idx].value.arr.data, data, n*GGUF_TYPE_SIZE[type]); +} + +void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_ARRAY; + ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING; + ctx->kv[idx].value.arr.n = n; + ctx->kv[idx].value.arr.data = malloc(n*sizeof(struct gguf_str)); + for (int i = 0; i < n; i++) { + struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i]; + str->n = strlen(data[i]); + str->data = strdup(data[i]); + } +} + +// set or add KV pairs from another context +void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { + for (uint32_t i = 0; i < src->header.n_kv; i++) { + switch (src->kv[i].type) { + case GGUF_TYPE_UINT8: gguf_set_val_u8 (ctx, src->kv[i].key.data, src->kv[i].value.uint8); break; + case GGUF_TYPE_INT8: gguf_set_val_i8 (ctx, src->kv[i].key.data, src->kv[i].value.int8); break; + case GGUF_TYPE_UINT16: gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16); break; + case GGUF_TYPE_INT16: gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16); break; + case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32); break; + case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32); break; + case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32); break; + case GGUF_TYPE_UINT64: gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64); break; + case GGUF_TYPE_INT64: gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64); break; + case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64); break; + case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_); break; + case GGUF_TYPE_STRING: gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break; + case GGUF_TYPE_ARRAY: + { + if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) { + const char ** data = malloc(src->kv[i].value.arr.n*sizeof(char *)); + for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) { + data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data; + } + gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n); + free(data); + } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) { + GGML_ASSERT(false && "nested arrays not supported"); + } else { + gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n); + } + } break; + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; + } + } +} + +void gguf_add_tensor( + struct gguf_context * ctx, + const struct ggml_tensor * tensor) { + const int idx = ctx->header.n_tensors; + ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info)); + + ctx->infos[idx].name.n = strlen(tensor->name); + ctx->infos[idx].name.data = strdup(tensor->name); + + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + ctx->infos[idx].ne[i] = 1; + } + + ctx->infos[idx].n_dims = tensor->n_dims; + for (int i = 0; i < tensor->n_dims; i++) { + ctx->infos[idx].ne[i] = tensor->ne[i]; + } + + ctx->infos[idx].type = tensor->type; + ctx->infos[idx].offset = 0; + ctx->infos[idx].data = tensor->data; + ctx->infos[idx].size = ggml_nbytes(tensor); + + if (ctx->header.n_tensors > 0) { + ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment); + } + + ctx->header.n_tensors++; +} + +void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) { + const int idx = gguf_find_tensor(ctx, name); + if (idx < 0) { + GGML_ASSERT(false && "tensor not found"); + } + + ctx->infos[idx].type = type; +} + +void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) { + const int idx = gguf_find_tensor(ctx, name); + if (idx < 0) { + GGML_ASSERT(false && "tensor not found"); + } + + ctx->infos[idx].data = data; + ctx->infos[idx].size = size; + + // update offsets + for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) { + ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment); + } +} + +//static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) { +// fwrite(&val->n, sizeof(val->n), 1, file); +// fwrite(val->data, sizeof(char), val->n, file); +//} +// +//static void gguf_fwrite_el(FILE * file, const void * val, size_t size) { +// fwrite(val, sizeof(char), size, file); +//} + +struct gguf_buf { + void * data; + size_t size; + size_t offset; +}; + +static struct gguf_buf gguf_buf_init(size_t size) { + struct gguf_buf buf = { + /*buf.data =*/ size == 0 ? NULL : malloc(size), + /*buf.size =*/ size, + /*buf.offset =*/ 0, + }; + + return buf; +} + +static void gguf_buf_free(struct gguf_buf buf) { + if (buf.data) { + free(buf.data); + } +} + +static void gguf_buf_grow(struct gguf_buf * buf, size_t size) { + if (buf->offset + size > buf->size) { + buf->size = 1.5*(buf->offset + size); + if (buf->data) { + buf->data = realloc(buf->data, buf->size); + } + } +} + +static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) { + gguf_buf_grow(buf, sizeof(val->n) + val->n); + + if (buf->data) { + memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n)); + } + buf->offset += sizeof(val->n); + + if (buf->data) { + memcpy((char *) buf->data + buf->offset, val->data, val->n); + } + buf->offset += val->n; +} + +static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) { + gguf_buf_grow(buf, el_size); + + if (buf->data) { + memcpy((char *) buf->data + buf->offset, val, el_size); + } + buf->offset += el_size; +} + +static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) { + // write header + gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic)); + gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version)); + gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors)); + gguf_bwrite_el(buf, &ctx->header.n_kv, sizeof(ctx->header.n_kv)); + + // write key-value pairs + for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { + struct gguf_kv * kv = &ctx->kv[i]; + + gguf_bwrite_str(buf, &kv->key); + gguf_bwrite_el (buf, &kv->type, sizeof(kv->type)); + + switch (kv->type) { + case GGUF_TYPE_UINT8: gguf_bwrite_el( buf, &kv->value.uint8, sizeof(kv->value.uint8) ); break; + case GGUF_TYPE_INT8: gguf_bwrite_el (buf, &kv->value.int8, sizeof(kv->value.int8) ); break; + case GGUF_TYPE_UINT16: gguf_bwrite_el (buf, &kv->value.uint16, sizeof(kv->value.uint16) ); break; + case GGUF_TYPE_INT16: gguf_bwrite_el (buf, &kv->value.int16, sizeof(kv->value.int16) ); break; + case GGUF_TYPE_UINT32: gguf_bwrite_el (buf, &kv->value.uint32, sizeof(kv->value.uint32) ); break; + case GGUF_TYPE_INT32: gguf_bwrite_el (buf, &kv->value.int32, sizeof(kv->value.int32) ); break; + case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break; + case GGUF_TYPE_UINT64: gguf_bwrite_el (buf, &kv->value.uint64, sizeof(kv->value.uint64) ); break; + case GGUF_TYPE_INT64: gguf_bwrite_el (buf, &kv->value.int64, sizeof(kv->value.int64) ); break; + case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break; + case GGUF_TYPE_BOOL: gguf_bwrite_el (buf, &kv->value.bool_, sizeof(kv->value.bool_) ); break; + case GGUF_TYPE_STRING: gguf_bwrite_str(buf, &kv->value.str ); break; + case GGUF_TYPE_ARRAY: + { + gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type)); + gguf_bwrite_el(buf, &kv->value.arr.n, sizeof(kv->value.arr.n) ); + + switch (kv->value.arr.type) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_UINT64: + case GGUF_TYPE_INT64: + case GGUF_TYPE_FLOAT64: + case GGUF_TYPE_BOOL: + { + gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]); + } break; + case GGUF_TYPE_STRING: + { + for (uint32_t j = 0; j < kv->value.arr.n; ++j) { + gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]); + } + } break; + case GGUF_TYPE_ARRAY: + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; + }; + } break; + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); + }; + } + + // write tensor infos + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + gguf_bwrite_str(buf, &info->name); + gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims)); + for (uint32_t j = 0; j < info->n_dims; ++j) { + gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j])); + } + gguf_bwrite_el(buf, &info->type, sizeof(info->type)); + gguf_bwrite_el(buf, &info->offset, sizeof(info->offset)); + } + + // we require the data section to be aligned, so take into account any padding + { + const size_t offset = buf->offset; + const size_t offset_pad = GGML_PAD(offset, ctx->alignment); + + if (offset_pad != offset) { + uint8_t pad = 0; + for (size_t i = 0; i < offset_pad - offset; ++i) { + gguf_bwrite_el(buf, &pad, sizeof(pad)); + } + } + } + + if (only_meta) { + return; + } + + size_t offset = 0; + + // write tensor data + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + const size_t size = info->size; + const size_t size_pad = GGML_PAD(size, ctx->alignment); + + gguf_bwrite_el(buf, info->data, size); + + if (size_pad != size) { + uint8_t pad = 0; + for (size_t j = 0; j < size_pad - size; ++j) { + gguf_bwrite_el(buf, &pad, sizeof(pad)); + } + } + + GGML_ASSERT(offset == info->offset); + + offset += size_pad; + } +} + +void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta) { + FILE * file = fopen(fname, "wb"); + if (!file) { + GGML_ASSERT(false && "failed to open file for writing"); + } + + struct gguf_buf buf = gguf_buf_init(16*1024); + + gguf_write_to_buf(ctx, &buf, only_meta); + + fwrite(buf.data, 1, buf.offset, file); + + gguf_buf_free(buf); + + fclose(file); +} + +size_t gguf_get_meta_size(struct gguf_context * ctx) { + // no allocs - only compute size + struct gguf_buf buf = gguf_buf_init(0); + + gguf_write_to_buf(ctx, &buf, true); + + return buf.offset; +} + +void gguf_get_meta_data(struct gguf_context * ctx, void * data) { + struct gguf_buf buf = gguf_buf_init(16*1024); + + gguf_write_to_buf(ctx, &buf, true); + + memcpy(data, buf.data, buf.offset); + + gguf_buf_free(buf); +} + +//////////////////////////////////////////////////////////////////////////////// + +int ggml_cpu_has_avx(void) { +#if defined(__AVX__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx2(void) { +#if defined(__AVX2__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512(void) { +#if defined(__AVX512F__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512_vbmi(void) { +#if defined(__AVX512VBMI__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512_vnni(void) { +#if defined(__AVX512VNNI__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fma(void) { +#if defined(__FMA__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_neon(void) { +#if defined(__ARM_NEON) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_arm_fma(void) { +#if defined(__ARM_FEATURE_FMA) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_f16c(void) { +#if defined(__F16C__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fp16_va(void) { +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_wasm_simd(void) { +#if defined(__wasm_simd128__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_blas(void) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_cublas(void) { +#if defined(GGML_USE_CUBLAS) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_clblast(void) { +#if defined(GGML_USE_CLBLAST) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_gpublas(void) { + return ggml_cpu_has_cublas() || ggml_cpu_has_clblast(); +} + +int ggml_cpu_has_sse3(void) { +#if defined(__SSE3__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_ssse3(void) { +#if defined(__SSSE3__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_vsx(void) { +#if defined(__POWER9_VECTOR__) + return 1; +#else + return 0; +#endif +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml.h b/plugins/wasi_nn/thirdparty/ggml/ggml.h new file mode 100644 index 00000000..6d4cf465 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/ggml.h @@ -0,0 +1,2005 @@ +#pragma once + +// +// GGML Tensor Library +// +// This documentation is still a work in progress. +// If you wish some specific topics to be covered, feel free to drop a comment: +// +// https://github.com/ggerganov/whisper.cpp/issues/40 +// +// ## Overview +// +// This library implements: +// +// - a set of tensor operations +// - automatic differentiation +// - basic optimization algorithms +// +// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes, +// but is not limited to, the following: +// +// - linear regression +// - support vector machines +// - neural networks +// +// The library allows the user to define a certain function using the available tensor operations. This function +// definition is represented internally via a computation graph. Each tensor operation in the function definition +// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the +// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized +// using one of the available optimization algorithms. +// +// For example, here we define the function: f(x) = a*x^2 + b +// +// { +// struct ggml_init_params params = { +// .mem_size = 16*1024*1024, +// .mem_buffer = NULL, +// }; +// +// // memory allocation happens here +// struct ggml_context * ctx = ggml_init(params); +// +// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// +// ggml_set_param(ctx, x); // x is an input variable +// +// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * x2 = ggml_mul(ctx, x, x); +// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b); +// +// ... +// } +// +// Notice that the function definition above does not involve any actual computation. The computation is performed only +// when the user explicitly requests it. For example, to compute the function's value at x = 2.0: +// +// { +// ... +// +// struct ggml_cgraph gf = ggml_build_forward(f); +// +// // set the input variable and parameter values +// ggml_set_f32(x, 2.0f); +// ggml_set_f32(a, 3.0f); +// ggml_set_f32(b, 4.0f); +// +// ggml_graph_compute_with_ctx(ctx, &gf, n_threads); +// +// printf("f = %f\n", ggml_get_f32_1d(f, 0)); +// +// ... +// } +// +// The actual computation is performed in the ggml_graph_compute() function. +// +// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the +// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know +// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory +// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was +// actually needed. +// +// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic +// differentiation and optimization algorithms. +// +// The described approach allows to define the function graph once and then compute its forward or backward graphs +// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way +// the user can avoid the memory allocation overhead at runtime. +// +// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class +// citizens, but in theory the library can be extended to support FP8 and integer data types. +// +// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary +// and binary operations. Most of the available operations fall into one of these two categories. With time, it became +// clear that the library needs to support more complex operations. The way to support these operations is not clear +// yet, but a few examples are demonstrated in the following operations: +// +// - ggml_permute() +// - ggml_conv_1d_1s() +// - ggml_conv_1d_2s() +// +// For each tensor operator, the library implements a forward and backward computation function. The forward function +// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the +// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a +// calculus class, or watch the following video: +// +// What is Automatic Differentiation? +// https://www.youtube.com/watch?v=wG_nF1awSSY +// +// +// ## Tensor data (struct ggml_tensor) +// +// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of +// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains +// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example: +// +// { +// struct ggml_tensor * c = ggml_add(ctx, a, b); +// +// assert(c->src[0] == a); +// assert(c->src[1] == b); +// } +// +// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the +// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows +// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and +// permutation. All tensor operations have to take the stride into account and not assume that the tensor is +// contiguous in memory. +// +// The data of the tensor is accessed via the "data" pointer. For example: +// +// { +// const int nx = 2; +// const int ny = 3; +// +// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny); +// +// for (int y = 0; y < ny; y++) { +// for (int x = 0; x < nx; x++) { +// *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y; +// } +// } +// +// ... +// } +// +// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used. +// +// ## The matrix multiplication operator (ggml_mul_mat) +// +// TODO +// +// +// ## Multi-threading +// +// TODO +// +// +// ## Overview of ggml.c +// +// TODO +// +// +// ## SIMD optimizations +// +// TODO +// +// +// ## Debugging ggml +// +// TODO +// +// + +#ifdef GGML_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef GGML_BUILD +# define GGML_API __declspec(dllexport) +# else +# define GGML_API __declspec(dllimport) +# endif +# else +# define GGML_API __attribute__ ((visibility ("default"))) +# endif +#else +# define GGML_API +#endif + +// TODO: support for clang +#ifdef __GNUC__ +# define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define GGML_DEPRECATED(func, hint) func +#endif + +#include +#include +#include + +#define GGML_FILE_MAGIC 0x67676d6c // "ggml" +#define GGML_FILE_VERSION 1 + +#define GGML_QNT_VERSION 2 // bump this on quantization format changes +#define GGML_QNT_VERSION_FACTOR 1000 // do not change this + +#define GGML_MAX_DIMS 4 +#define GGML_MAX_NODES 4096 +#define GGML_MAX_PARAMS 256 +#define GGML_MAX_CONTEXTS 64 +#define GGML_MAX_SRC 6 +#define GGML_MAX_NAME 64 +#define GGML_MAX_OP_PARAMS 32 +#define GGML_DEFAULT_N_THREADS 4 + +#if UINTPTR_MAX == 0xFFFFFFFF + #define GGML_MEM_ALIGN 4 +#else + #define GGML_MEM_ALIGN 16 +#endif + +#define GGML_EXIT_SUCCESS 0 +#define GGML_EXIT_ABORTED 1 + +#define GGUF_MAGIC 0x46554747 // "GGUF" +#define GGUF_VERSION 2 + +#define GGUF_DEFAULT_ALIGNMENT 32 + +#define GGML_UNUSED(x) (void)(x) + +#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) + +#define GGML_ASSERT(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +// used to copy the number of elements and stride in bytes of tensors into local variables. +// main purpose is to reduce code duplication and improve readability. +// +// example: +// +// GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); +// GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); +// +#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \ + const type prefix##0 = (pointer)->array[0]; \ + GGML_UNUSED(prefix##0); +#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \ + const type prefix##1 = (pointer)->array[1]; \ + GGML_UNUSED(prefix##1); +#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \ + const type prefix##2 = (pointer)->array[2]; \ + GGML_UNUSED(prefix##2); +#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \ + const type prefix##3 = (pointer)->array[3]; \ + GGML_UNUSED(prefix##3); + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(__ARM_NEON) && defined(__CUDACC__) + typedef half ggml_fp16_t; +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + typedef __fp16 ggml_fp16_t; +#else + typedef uint16_t ggml_fp16_t; +#endif + + // convert FP16 <-> FP32 + GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x); + GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x); + + GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n); + GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n); + + struct ggml_object; + struct ggml_context; + + enum ggml_type { + GGML_TYPE_F32 = 0, + GGML_TYPE_F16 = 1, + GGML_TYPE_Q4_0 = 2, + GGML_TYPE_Q4_1 = 3, + // GGML_TYPE_Q4_2 = 4, support has been removed + // GGML_TYPE_Q4_3 (5) support has been removed + GGML_TYPE_Q5_0 = 6, + GGML_TYPE_Q5_1 = 7, + GGML_TYPE_Q8_0 = 8, + GGML_TYPE_Q8_1 = 9, + // k-quantizations + GGML_TYPE_Q2_K = 10, + GGML_TYPE_Q3_K = 11, + GGML_TYPE_Q4_K = 12, + GGML_TYPE_Q5_K = 13, + GGML_TYPE_Q6_K = 14, + GGML_TYPE_Q8_K = 15, + GGML_TYPE_I8, + GGML_TYPE_I16, + GGML_TYPE_I32, + GGML_TYPE_COUNT, + }; + + enum ggml_backend { + GGML_BACKEND_CPU = 0, + GGML_BACKEND_GPU = 10, + GGML_BACKEND_GPU_SPLIT = 20, + }; + + // model file types + enum ggml_ftype { + GGML_FTYPE_UNKNOWN = -1, + GGML_FTYPE_ALL_F32 = 0, + GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors + GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors + GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors + }; + + // available tensor operations: + enum ggml_op { + GGML_OP_NONE = 0, + + GGML_OP_DUP, + GGML_OP_ADD, + GGML_OP_ADD1, + GGML_OP_ACC, + GGML_OP_SUB, + GGML_OP_MUL, + GGML_OP_DIV, + GGML_OP_SQR, + GGML_OP_SQRT, + GGML_OP_LOG, + GGML_OP_SUM, + GGML_OP_SUM_ROWS, + GGML_OP_MEAN, + GGML_OP_ARGMAX, + GGML_OP_REPEAT, + GGML_OP_REPEAT_BACK, + GGML_OP_CONCAT, + GGML_OP_SILU_BACK, + GGML_OP_NORM, // normalize + GGML_OP_RMS_NORM, + GGML_OP_RMS_NORM_BACK, + GGML_OP_GROUP_NORM, + + GGML_OP_MUL_MAT, + GGML_OP_OUT_PROD, + + GGML_OP_SCALE, + GGML_OP_SET, + GGML_OP_CPY, + GGML_OP_CONT, + GGML_OP_RESHAPE, + GGML_OP_VIEW, + GGML_OP_PERMUTE, + GGML_OP_TRANSPOSE, + GGML_OP_GET_ROWS, + GGML_OP_GET_ROWS_BACK, + GGML_OP_DIAG, + GGML_OP_DIAG_MASK_INF, + GGML_OP_DIAG_MASK_ZERO, + GGML_OP_SOFT_MAX, + GGML_OP_SOFT_MAX_BACK, + GGML_OP_ROPE, + GGML_OP_ROPE_BACK, + GGML_OP_ALIBI, + GGML_OP_CLAMP, + GGML_OP_CONV_1D, + GGML_OP_CONV_2D, + GGML_OP_CONV_TRANSPOSE_2D, + GGML_OP_POOL_1D, + GGML_OP_POOL_2D, + + GGML_OP_UPSCALE, // nearest interpolate + + GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_FF, + GGML_OP_FLASH_ATTN_BACK, + GGML_OP_WIN_PART, + GGML_OP_WIN_UNPART, + GGML_OP_GET_REL_POS, + GGML_OP_ADD_REL_POS, + + GGML_OP_UNARY, + + GGML_OP_MAP_UNARY, + GGML_OP_MAP_BINARY, + + GGML_OP_MAP_CUSTOM1_F32, + GGML_OP_MAP_CUSTOM2_F32, + GGML_OP_MAP_CUSTOM3_F32, + + GGML_OP_MAP_CUSTOM1, + GGML_OP_MAP_CUSTOM2, + GGML_OP_MAP_CUSTOM3, + + GGML_OP_CROSS_ENTROPY_LOSS, + GGML_OP_CROSS_ENTROPY_LOSS_BACK, + + GGML_OP_COUNT, + }; + + enum ggml_unary_op { + GGML_UNARY_OP_ABS, + GGML_UNARY_OP_SGN, + GGML_UNARY_OP_NEG, + GGML_UNARY_OP_STEP, + GGML_UNARY_OP_TANH, + GGML_UNARY_OP_ELU, + GGML_UNARY_OP_RELU, + GGML_UNARY_OP_GELU, + GGML_UNARY_OP_GELU_QUICK, + GGML_UNARY_OP_SILU, + }; + + enum ggml_object_type { + GGML_OBJECT_TENSOR, + GGML_OBJECT_GRAPH, + GGML_OBJECT_WORK_BUFFER + }; + + // ggml object + struct ggml_object { + size_t offs; + size_t size; + + struct ggml_object * next; + + enum ggml_object_type type; + + char padding[4]; + }; + + static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); + + // n-dimensional tensor + struct ggml_tensor { + enum ggml_type type; + enum ggml_backend backend; + + int n_dims; + int64_t ne[GGML_MAX_DIMS]; // number of elements + size_t nb[GGML_MAX_DIMS]; // stride in bytes: + // nb[0] = sizeof(type) + // nb[1] = nb[0] * ne[0] + padding + // nb[i] = nb[i-1] * ne[i-1] + + // compute data + enum ggml_op op; + + // op params - allocated as int32_t for alignment + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + + bool is_param; + + struct ggml_tensor * grad; + struct ggml_tensor * src[GGML_MAX_SRC]; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; + + struct ggml_tensor * view_src; + size_t view_offs; + + void * data; + + char name[GGML_MAX_NAME]; + + void * extra; // extra things e.g. for ggml-cuda.cu + + char padding[4]; + }; + + static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); + + // the compute plan that needs to be prepared for ggml_graph_compute() + // since https://github.com/ggerganov/ggml/issues/287 + struct ggml_cplan { + size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()` + uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` + + int n_threads; + + // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes + int n_tasks[GGML_MAX_NODES]; + + // abort ggml_graph_compute when true + bool (*abort_callback)(void * data); + void * abort_callback_data; + }; + + // next prime after GGML_MAX_NODES + // #define GGML_GRAPH_HASHTABLE_SIZE 4099 + // next prime after GGML_MAX_NODES * 2 (nodes + leafs) + #define GGML_GRAPH_HASHTABLE_SIZE 8273 + + // computation graph + struct ggml_cgraph { + int n_nodes; + int n_leafs; + + struct ggml_tensor * nodes[GGML_MAX_NODES]; + struct ggml_tensor * grads[GGML_MAX_NODES]; + struct ggml_tensor * leafs[GGML_MAX_NODES]; + + void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE]; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; + }; + + static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph); + + // scratch buffer + struct ggml_scratch { + size_t offs; + size_t size; + void * data; + }; + + struct ggml_init_params { + // memory pool + size_t mem_size; // bytes + void * mem_buffer; // if NULL, memory will be allocated internally + bool no_alloc; // don't allocate memory for the tensor data + }; + + + // compute types + + // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled. + // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995. + enum ggml_task_type { + GGML_TASK_INIT = 0, + GGML_TASK_COMPUTE, + GGML_TASK_FINALIZE, + }; + + struct ggml_compute_params { + enum ggml_task_type type; + + // ith = thread index, nth = number of threads + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; + }; + + // misc + + GGML_API void ggml_time_init(void); // call this once at the beginning of the program + GGML_API int64_t ggml_time_ms(void); + GGML_API int64_t ggml_time_us(void); + GGML_API int64_t ggml_cycles(void); + GGML_API int64_t ggml_cycles_per_ms(void); + + GGML_API void ggml_numa_init(void); // call once for better performance on NUMA systems + GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node + + GGML_API void ggml_print_object (const struct ggml_object * obj); + GGML_API void ggml_print_objects(const struct ggml_context * ctx); + + GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor); + GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor); + GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor); + GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN + GGML_API size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split); + + GGML_API int ggml_blck_size (enum ggml_type type); + GGML_API size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block + GGML_API float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float + + GGML_API const char * ggml_type_name(enum ggml_type type); + GGML_API const char * ggml_op_name (enum ggml_op op); + GGML_API const char * ggml_op_symbol(enum ggml_op op); + + GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); + + GGML_API bool ggml_is_quantized(enum ggml_type type); + + // TODO: temporary until model loading of ggml examples is refactored + GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype); + + GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); + GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor); + GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor); + + GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1); + + // use this to compute the memory overhead of a tensor + GGML_API size_t ggml_tensor_overhead(void); + + // main + + GGML_API struct ggml_context * ggml_init(struct ggml_init_params params); + GGML_API void ggml_free(struct ggml_context * ctx); + + GGML_API size_t ggml_used_mem(const struct ggml_context * ctx); + + GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch); + GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx); + GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc); + + GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx); + GGML_API size_t ggml_get_mem_size (const struct ggml_context * ctx); + GGML_API size_t ggml_get_max_tensor_size(const struct ggml_context * ctx); + + GGML_API struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int64_t *ne); + + GGML_API struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1); + + GGML_API struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_new_tensor_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); + GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); + + GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); + GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src); + + GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); + + GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); + GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); + + GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); + GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); + + GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); + GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); + + GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); + GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); + + GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor); + + GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name); + GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...); + + // + // operations on tensors with backpropagation + // + + GGML_API struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_dup_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_add_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_add1( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_add1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_acc( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_acc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_sub_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_mul_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_div_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sqrt_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_log( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_log_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // return scalar + GGML_API struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d] + GGML_API struct ggml_tensor * ggml_sum_rows( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // mean along rows + GGML_API struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // argmax along rows + GGML_API struct ggml_tensor * ggml_argmax( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // if a is the same shape as b, and a is not parameter, return a + // otherwise, return a new tensor: repeat(a) to fit in b + GGML_API struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_repeat_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // concat a and b on dim 2 + // used in stable-diffusion + GGML_API struct ggml_tensor * ggml_concat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_abs_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sgn_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_neg_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_step_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_tanh( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_tanh_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_elu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_elu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_relu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // TODO: double-check this computation is correct + GGML_API struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_quick_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_silu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_silu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // a - x + // b - dy + GGML_API struct ggml_tensor * ggml_silu_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // normalize along rows + GGML_API struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_rms_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + // group normalize along ne0*ne1*n_groups + // used in stable-diffusion + // TODO: eps is hardcoded to 1e-6 for now + GGML_API struct ggml_tensor * ggml_group_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups); + + GGML_API struct ggml_tensor * ggml_group_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups); + + // a - x + // b - dy + GGML_API struct ggml_tensor * ggml_rms_norm_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float eps); + + // A: n columns, m rows + // B: n columns, p rows (i.e. we transpose it internally) + // result is m columns, p rows + GGML_API struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // A: m columns, n rows, + // B: p columns, n rows, + // result is m columns, p rows + GGML_API struct ggml_tensor * ggml_out_prod( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // + // operations on tensors without backpropagation + // + + GGML_API struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_scale_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // b -> view(a,offset,nb1,nb2,3), return modified a + GGML_API struct ggml_tensor * ggml_set( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + // b -> view(a,offset,nb1,nb2,3), return view(a) + GGML_API struct ggml_tensor * ggml_set_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_set_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset); + + GGML_API struct ggml_tensor * ggml_set_1d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset); + + // b -> view(a,offset,nb1,nb2,3), return modified a + GGML_API struct ggml_tensor * ggml_set_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset); + + // b -> view(a,offset,nb1,nb2,3), return view(a) + GGML_API struct ggml_tensor * ggml_set_2d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset); + + + // a -> b, return view(b) + GGML_API struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // a -> b, in-place, return view(b) + GGML_API struct ggml_tensor * ggml_cpy_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // make contiguous + GGML_API struct ggml_tensor * ggml_cont( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // make contiguous, in-place + GGML_API struct ggml_tensor * ggml_cont_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // return view(a), b specifies the new shape + // TODO: when we start computing gradient, make a copy instead of view + GGML_API struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // return view(a) + // TODO: when we start computing gradient, make a copy instead of view + GGML_API struct ggml_tensor * ggml_reshape_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1); + + // return view(a) + // TODO: when we start computing gradient, make a copy instead of view + GGML_API struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_reshape_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + // offset in bytes + GGML_API struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + size_t offset); + + GGML_API struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + size_t nb1, // row stride in bytes + size_t offset); + + GGML_API struct ggml_tensor * ggml_view_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + size_t nb1, // row stride in bytes + size_t nb2, // slice stride in bytes + size_t offset); + + GGML_API struct ggml_tensor * ggml_view_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + size_t nb1, // row stride in bytes + size_t nb2, // slice stride in bytes + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3); + + // alias for ggml_permute(ctx, a, 1, 0, 2, 3) + GGML_API struct ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_get_rows_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c); + + GGML_API struct ggml_tensor * ggml_diag( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // set elements above the diagonal to -INF + GGML_API struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + // set elements above the diagonal to 0 + GGML_API struct ggml_tensor * ggml_diag_mask_zero( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + GGML_API struct ggml_tensor * ggml_soft_max( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_soft_max_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_soft_max_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_soft_max_back_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // rotary position embedding + // if mode & 1 == 1, skip n_past elements + // if mode & 2 == 1, GPT-NeoX style + // if mode & 4 == 1, ChatGLM style + // TODO: avoid creating a new tensor every time + GGML_API struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx); + + // custom RoPE + GGML_API struct ggml_tensor * ggml_rope_custom( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx, + float freq_base, + float freq_scale); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx, + float freq_base, + float freq_scale); + + // xPos RoPE, in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + float base, + bool down); + + // rotary position embedding backward, i.e compute dx from dy + // a - dy + GGML_API struct ggml_tensor * ggml_rope_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + int n_ctx, + float freq_base, + float freq_scale, + float xpos_base, + bool xpos_down); + + // alibi position embedding + // in-place, returns view(a) + struct ggml_tensor * ggml_alibi( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_head, + float bias_max); + + // clamp + // in-place, returns view(a) + struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max); + + GGML_API struct ggml_tensor * ggml_conv_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, // stride + int p0, // padding + int d0); // dilation + + // conv_1d with padding = half + // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d) + GGML_API struct ggml_tensor* ggml_conv_1d_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s, + int d); + + GGML_API struct ggml_tensor * ggml_conv_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1); + + + // kernel size is a->ne[0] x a->ne[1] + // stride is equal to kernel size + // padding is zero + // example: + // a: 16 16 3 768 + // b: 1024 1024 3 1 + // res: 64 64 768 1 + // used in sam + GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // kernel size is a->ne[0] x a->ne[1] + // stride is 1 + // padding is half + // example: + // a: 3 3 256 256 + // b: 64 64 256 1 + // res: 64 64 256 1 + // used in sam + GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride); + + enum ggml_op_pool { + GGML_OP_POOL_MAX, + GGML_OP_POOL_AVG, + GGML_OP_POOL_COUNT, + }; + + GGML_API struct ggml_tensor * ggml_pool_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, // kernel size + int s0, // stride + int p0); // padding + + GGML_API struct ggml_tensor * ggml_pool_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + int p0, + int p1); + + // nearest interpolate + // used in stable-diffusion + GGML_API struct ggml_tensor * ggml_upscale( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor); + + GGML_API struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked); + + GGML_API struct ggml_tensor * ggml_flash_attn_back( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * d, + bool masked); + + GGML_API struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1); + + // partition into non-overlapping windows with padding if needed + // example: + // a: 768 64 64 1 + // w: 14 + // res: 768 14 14 25 + // used in sam + GGML_API struct ggml_tensor * ggml_win_part( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w); + + // reverse of ggml_win_part + // used in sam + GGML_API struct ggml_tensor * ggml_win_unpart( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w0, + int h0, + int w); + + GGML_API struct ggml_tensor * ggml_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op); + + GGML_API struct ggml_tensor * ggml_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op); + + // used in sam + GGML_API struct ggml_tensor * ggml_get_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + int qh, + int kh); + + // used in sam + + GGML_API struct ggml_tensor * ggml_add_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph); + + GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph); + + // custom operators + + typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); + typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *); + + typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *); + typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); + typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_unary_op_f32_t fun), + "use ggml_map_custom1 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_unary_op_f32_t fun), + "use ggml_map_custom1_inplace instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_binary_op_f32_t fun), + "use ggml_map_custom2 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_binary_op_f32_t fun), + "use ggml_map_custom2_inplace instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_f32_t fun), + "use ggml_map_custom1 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_f32_t fun), + "use ggml_map_custom1_inplace instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_f32_t fun), + "use ggml_map_custom2 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_f32_t fun), + "use ggml_map_custom2_inplace instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_f32_t fun), + "use ggml_map_custom3 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_f32_t fun), + "use ggml_map_custom3_inplace instead"); + + // custom operators v2 + + typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); + typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata); + typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata); + + #define GGML_N_TASKS_MAX -1 + + GGML_API struct ggml_tensor * ggml_map_custom1( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom2( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom2_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom3( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom3_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_t fun, + int n_tasks, + void * userdata); + + // loss function + + GGML_API struct ggml_tensor * ggml_cross_entropy_loss( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c); + + // + // automatic differentiation + // + + GGML_API void ggml_set_param( + struct ggml_context * ctx, + struct ggml_tensor * tensor); + + + GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep); + + GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor); + GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep); + + // graph allocation in a context + GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); + GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API size_t ggml_graph_overhead(void); + + // ggml_graph_plan() has to be called before ggml_graph_compute() + // when plan.work_size > 0, caller must allocate memory for plan.work_data + GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/); + GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); + GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); + + // same as ggml_graph_compute() but the work data is allocated as a part of the context + // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data + GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads); + + GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name); + + GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname); + GGML_API struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval); + + // print info and performance information for the graph + GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); + + // dump the graph into a file using the dot format + GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); + + // + // optimization + // + + // optimization methods + enum ggml_opt_type { + GGML_OPT_ADAM, + GGML_OPT_LBFGS, + }; + + // linesearch methods + enum ggml_linesearch { + GGML_LINESEARCH_DEFAULT = 1, + + GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, + GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, + GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, + }; + + // optimization return values + enum ggml_opt_result { + GGML_OPT_OK = 0, + GGML_OPT_DID_NOT_CONVERGE, + GGML_OPT_NO_CONTEXT, + GGML_OPT_INVALID_WOLFE, + GGML_OPT_FAIL, + + GGML_LINESEARCH_FAIL = -128, + GGML_LINESEARCH_MINIMUM_STEP, + GGML_LINESEARCH_MAXIMUM_STEP, + GGML_LINESEARCH_MAXIMUM_ITERATIONS, + GGML_LINESEARCH_INVALID_PARAMETERS, + }; + + typedef void (*ggml_opt_callback)(void * data, float * sched); + + // optimization parameters + // + // see ggml.c (ggml_opt_default_params) for default values + // + struct ggml_opt_params { + enum ggml_opt_type type; + + int n_threads; + + // delta-based convergence test + // + // if past == 0 - disabled + // if past > 0: + // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) + // + int past; + float delta; + + // maximum number of iterations without improvement + // + // if 0 - disabled + // if > 0: + // assume convergence if no cost improvement in this number of iterations + // + int max_no_improvement; + + bool print_forward_graph; + bool print_backward_graph; + + // ADAM parameters + struct { + int n_iter; + + float sched; // schedule multiplier (fixed, decay or warmup) + float decay; // weight decay for AdamW, use 0.0f to disable + int decay_min_ndim; // minimum number of tensor dimension to apply weight decay + float alpha; // learning rate + float beta1; + float beta2; + float eps; // epsilon for numerical stability + float eps_f; // epsilon for convergence test + float eps_g; // epsilon for convergence test + float gclip; // gradient clipping + } adam; + + // LBFGS parameters + struct { + int m; // number of corrections to approximate the inv. Hessian + int n_iter; + int max_linesearch; + + float eps; // convergence tolerance + float ftol; // line search tolerance + float wolfe; + float min_step; + float max_step; + + enum ggml_linesearch linesearch; + } lbfgs; + }; + + struct ggml_opt_context { + struct ggml_context * ctx; + struct ggml_opt_params params; + + int iter; + int64_t nx; // number of parameter elements + + bool just_initialized; + + float loss_before; + float loss_after; + + struct { + struct ggml_tensor * m; // first moment + struct ggml_tensor * v; // second moment + struct ggml_tensor * pf; // past function values + float fx_best; + float fx_prev; + int n_no_improvement; + } adam; + + struct { + struct ggml_tensor * x; // current parameters + struct ggml_tensor * xp; // previous parameters + struct ggml_tensor * g; // current gradient + struct ggml_tensor * gp; // previous gradient + struct ggml_tensor * d; // search direction + struct ggml_tensor * pf; // past function values + struct ggml_tensor * lmal; // the L-BFGS memory alpha + struct ggml_tensor * lmys; // the L-BFGS memory ys + struct ggml_tensor * lms; // the L-BFGS memory s + struct ggml_tensor * lmy; // the L-BFGS memory y + float fx_best; + float step; + int j; + int k; + int end; + int n_no_improvement; + } lbfgs; + }; + + GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type); + + // optimize the function defined by the tensor f + GGML_API enum ggml_opt_result ggml_opt( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f); + + // initialize optimizer context + GGML_API void ggml_opt_init( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_opt_params params, + int64_t nx); + + // continue optimizing the function defined by the tensor f + GGML_API enum ggml_opt_result ggml_opt_resume( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f); + + // continue optimizing the function defined by the tensor f + GGML_API enum ggml_opt_result ggml_opt_resume_g( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data); + + // + // quantization + // + + GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); + + GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); + + // + // gguf + // + + enum gguf_type { + GGUF_TYPE_UINT8 = 0, + GGUF_TYPE_INT8 = 1, + GGUF_TYPE_UINT16 = 2, + GGUF_TYPE_INT16 = 3, + GGUF_TYPE_UINT32 = 4, + GGUF_TYPE_INT32 = 5, + GGUF_TYPE_FLOAT32 = 6, + GGUF_TYPE_BOOL = 7, + GGUF_TYPE_STRING = 8, + GGUF_TYPE_ARRAY = 9, + GGUF_TYPE_UINT64 = 10, + GGUF_TYPE_INT64 = 11, + GGUF_TYPE_FLOAT64 = 12, + GGUF_TYPE_COUNT, // marks the end of the enum + }; + + struct gguf_context; + + struct gguf_init_params { + bool no_alloc; + + // if not NULL, create a ggml_context and allocate the tensor data in it + struct ggml_context ** ctx; + }; + + GGML_API struct gguf_context * gguf_init_empty(void); + GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); + //GGML_API struct gguf_context * gguf_init_from_buffer(..); + + GGML_API void gguf_free(struct gguf_context * ctx); + + GGML_API const char * gguf_type_name(enum gguf_type type); + + GGML_API int gguf_get_version (struct gguf_context * ctx); + GGML_API size_t gguf_get_alignment (struct gguf_context * ctx); + GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx); + GGML_API void * gguf_get_data (struct gguf_context * ctx); + + GGML_API int gguf_get_n_kv(struct gguf_context * ctx); + GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key); + GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i); + + GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i); + GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i); + + // results are undefined if the wrong type is used for the key + GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i); + GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i); + GGML_API uint16_t gguf_get_val_u16 (struct gguf_context * ctx, int i); + GGML_API int16_t gguf_get_val_i16 (struct gguf_context * ctx, int i); + GGML_API uint32_t gguf_get_val_u32 (struct gguf_context * ctx, int i); + GGML_API int32_t gguf_get_val_i32 (struct gguf_context * ctx, int i); + GGML_API float gguf_get_val_f32 (struct gguf_context * ctx, int i); + GGML_API uint64_t gguf_get_val_u64 (struct gguf_context * ctx, int i); + GGML_API int64_t gguf_get_val_i64 (struct gguf_context * ctx, int i); + GGML_API double gguf_get_val_f64 (struct gguf_context * ctx, int i); + GGML_API bool gguf_get_val_bool(struct gguf_context * ctx, int i); + GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i); + GGML_API int gguf_get_arr_n (struct gguf_context * ctx, int i); + GGML_API const void * gguf_get_arr_data(struct gguf_context * ctx, int i); + GGML_API const char * gguf_get_arr_str (struct gguf_context * ctx, int key_id, int i); + + GGML_API int gguf_get_n_tensors (struct gguf_context * ctx); + GGML_API int gguf_find_tensor (struct gguf_context * ctx, const char * name); + GGML_API size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i); + GGML_API char * gguf_get_tensor_name (struct gguf_context * ctx, int i); + + // overrides existing values or adds a new one + GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); + GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val); + GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val); + GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val); + GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val); + GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val); + GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val); + GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val); + GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val); + GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val); + GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val); + GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val); + GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n); + GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n); + + // set or add KV pairs from another context + GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src); + + // manage tensor info + GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor); + GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type); + GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size); + + // writing gguf files can be done in 2 ways: + // + // - write the entire gguf_context to a binary file in a single pass: + // + // gguf_write_to_file(ctx, fname); + // + // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data: + // + // FILE * f = fopen(fname, "wb"); + // fseek(f, gguf_get_meta_size(ctx), SEEK_SET); + // fwrite(f, ...); + // void * data = gguf_meta_get_meta_data(ctx); + // fseek(f, 0, SEEK_SET); + // fwrite(f, data, gguf_get_meta_size(ctx)); + // free(data); + // fclose(f); + // + + // write the entire context to a binary file + GGML_API void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta); + + // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding + GGML_API size_t gguf_get_meta_size(struct gguf_context * ctx); + GGML_API void gguf_get_meta_data(struct gguf_context * ctx, void * data); + + // + // system info + // + + GGML_API int ggml_cpu_has_avx (void); + GGML_API int ggml_cpu_has_avx2 (void); + GGML_API int ggml_cpu_has_avx512 (void); + GGML_API int ggml_cpu_has_avx512_vbmi(void); + GGML_API int ggml_cpu_has_avx512_vnni(void); + GGML_API int ggml_cpu_has_fma (void); + GGML_API int ggml_cpu_has_neon (void); + GGML_API int ggml_cpu_has_arm_fma (void); + GGML_API int ggml_cpu_has_f16c (void); + GGML_API int ggml_cpu_has_fp16_va (void); + GGML_API int ggml_cpu_has_wasm_simd (void); + GGML_API int ggml_cpu_has_blas (void); + GGML_API int ggml_cpu_has_cublas (void); + GGML_API int ggml_cpu_has_clblast (void); + GGML_API int ggml_cpu_has_gpublas (void); + GGML_API int ggml_cpu_has_sse3 (void); + GGML_API int ggml_cpu_has_ssse3 (void); + GGML_API int ggml_cpu_has_vsx (void); + + // + // Internal types and functions exposed for tests and benchmarks + // + +#ifdef __cplusplus +// restrict not standard in C++ +#define GGML_RESTRICT +#else +#define GGML_RESTRICT restrict +#endif + typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); + typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); + typedef void (*ggml_vec_dot_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y); + + typedef struct { + const char * type_name; + int blck_size; + size_t type_size; + bool is_quantized; + ggml_to_float_t to_float; + ggml_from_float_t from_float; + ggml_from_float_t from_float_reference; + ggml_vec_dot_t vec_dot; + enum ggml_type vec_dot_type; + } ggml_type_traits_t; + + ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); + +#ifdef __cplusplus +} +#endif diff --git a/plugins/wasi_nn/thirdparty/ggml/k_quants.c b/plugins/wasi_nn/thirdparty/ggml/k_quants.c new file mode 100644 index 00000000..62085882 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/k_quants.c @@ -0,0 +1,4318 @@ +#include "k_quants.h" +#include "ggml.h" + +#include +#include +#include + +#ifdef __ARM_NEON + +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include + +#if !defined(__aarch64__) +inline static int32_t vaddvq_s16(int16x8_t v) { + return + (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + + (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + + (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + + (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); +} + +inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(a0, b0); +} + +inline static int32_t vaddvq_s32(int32x4_t v) { + return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); +} +#endif + +#else + +#ifdef __wasm_simd128__ +#include +#else +#ifdef __POWER9_VECTOR__ +#include +#undef bool +#define bool _Bool +#else +#if defined(_MSC_VER) || defined(__MINGW32__) +#include +#else +#if !defined(__riscv) +#include +#endif +#endif +#endif +#endif +#endif + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +// +// 2-6 bit quantization in super-blocks +// + + +// +// ===================== Helper functions +// +static inline int nearest_int(float fval) { + assert(fval <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (amax < 1e-30f) { // all zero + for (int i = 0; i < n; ++i) { + L[i] = 0; + } + return 0.f; + } + float iscale = -nmax / max; + if (rmse_type == 0) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + return 1/iscale; + } + bool return_early = false; + if (rmse_type < 0) { + rmse_type = -rmse_type; + return_early = true; + } + int weight_type = rmse_type%2; + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + float w = weight_type == 1 ? x[i] * x[i] : 1; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + float scale = sumlx/suml2; + if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale; + float best = scale * sumlx; + for (int is = -9; is <= 9; ++is) { + if (is == 0) { + continue; + } + iscale = -(nmax + 0.1f*is) / max; + sumlx = suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + float w = weight_type == 1 ? x[i] * x[i] : 1; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + if (suml2 > 0 && sumlx*sumlx > best*suml2) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + scale = sumlx/suml2; best = scale*sumlx; + } + } + return scale; +} + +static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (!amax) { // all zero + for (int i = 0; i < n; ++i) { L[i] = 0; } + return 0.f; + } + float iscale = -nmax / max; + if (do_rmse) { + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l; + float w = x[i]*x[i]; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + for (int itry = 0; itry < 5; ++itry) { + int n_changed = 0; + for (int i = 0; i < n; ++i) { + float w = x[i]*x[i]; + float slx = sumlx - w*x[i]*L[i]; + if (slx > 0) { + float sl2 = suml2 - w*L[i]*L[i]; + int new_l = nearest_int(x[i] * sl2 / slx); + new_l = MAX(-nmax, MIN(nmax-1, new_l)); + if (new_l != L[i]) { + slx += w*x[i]*new_l; + sl2 += w*new_l*new_l; + if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) { + L[i] = new_l; sumlx = slx; suml2 = sl2; + ++n_changed; + } + } + } + } + if (!n_changed) { + break; + } + } + for (int i = 0; i < n; ++i) { + L[i] += nmax; + } + return sumlx / suml2; + } + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + } + return 1/iscale; +} + +static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, + int ntry, float alpha) { + float min = x[0]; + float max = x[0]; + for (int i = 1; i < n; ++i) { + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + } + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = 0; + return 0.f; + } + if (min > 0) min = 0; + float iscale = nmax/(max - min); + float scale = 1/iscale; + for (int itry = 0; itry < ntry; ++itry) { + float sumlx = 0; int suml2 = 0; + bool did_change = false; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + if (l != L[i]) { + L[i] = l; + did_change = true; + } + sumlx += (x[i] - min)*l; + suml2 += l*l; + } + scale = sumlx/suml2; + float sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] - scale*L[i]; + } + min = alpha*min + (1 - alpha)*sum/n; + if (min > 0) min = 0; + iscale = 1/scale; + if (!did_change) break; + } + *the_min = -min; + return scale; +} + +static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights, + uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux, + float rmin, float rdelta, int nstep, bool use_mad) { + float min = x[0]; + float max = x[0]; + float sum_w = weights[0]; + float sum_x = sum_w * x[0]; + for (int i = 1; i < n; ++i) { + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + float w = weights[i]; + sum_w += w; + sum_x += w * x[i]; + } + if (min > 0) min = 0; + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = -min; + return 0.f; + } + float iscale = nmax/(max - min); + float scale = 1/iscale; + float best_mad = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + L[i] = MAX(0, MIN(nmax, l)); + float diff = scale * L[i] + min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + best_mad += w * diff; + } + if (nstep < 1) { + *the_min = -min; + return scale; + } + for (int is = 0; is <= nstep; ++is) { + iscale = (rmin + rdelta*is + nmax)/(max - min); + float sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + Laux[i] = l; + float w = weights[i]; + sum_l += w*l; + sum_l2 += w*l*l; + sum_xl += w*l*x[i]; + } + float D = sum_w * sum_l2 - sum_l * sum_l; + if (D > 0) { + float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; + float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + float mad = 0; + for (int i = 0; i < n; ++i) { + float diff = this_scale * Laux[i] + this_min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + mad += w * diff; + } + if (mad < best_mad) { + for (int i = 0; i < n; ++i) { + L[i] = Laux[i]; + } + best_mad = mad; + scale = this_scale; + min = this_min; + } + } + } + *the_min = -min; + return scale; +} + +#if QK_K == 256 +static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { + if (j < 4) { + *d = q[j] & 63; *m = q[j + 4] & 63; + } else { + *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} +#endif + +//========================- 2-bit (de)-quantization + +void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[16]; + float weights[16]; + float mins[QK_K/16]; + float scales[QK_K/16]; + + const float q4scale = 15.f; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]); + scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + if (max_scale > 0) { + float iscale = q4scale/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*scales[j]); + y[i].scales[j] = l; + } + y[i].d = ggml_fp32_to_fp16(max_scale/q4scale); + } else { + for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0; + y[i].d = ggml_fp32_to_fp16(0.f); + } + if (max_min > 0) { + float iscale = q4scale/max_min; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*mins[j]); + y[i].scales[j] |= (l << 4); + } + y[i].dmin = ggml_fp32_to_fp16(max_min/q4scale); + } else { + y[i].dmin = ggml_fp32_to_fp16(0.f); + } + for (int j = 0; j < QK_K/16; ++j) { + const float d = ggml_fp16_to_fp32(y[i].d) * (y[i].scales[j] & 0xF); + if (!d) continue; + const float dm = ggml_fp16_to_fp32(y[i].dmin) * (y[i].scales[j] >> 4); + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int((x[16*j + ii] + dm)/d); + l = MAX(0, MIN(3, l)); + L[16*j + ii] = l; + } + } + +#if QK_K == 256 + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } +#else + for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif + + x += QK_K; + + } +} + +void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = ggml_fp16_to_fp32(x[i].d); + const float min = ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * q = x[i].qs; + +#if QK_K == 256 + int is = 0; + float dl, ml; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + uint8_t sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml; + + sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml; + + shift += 2; + } + q += 32; + } +#else + float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4); + float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4); + float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4); + float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4); + for (int l = 0; l < 16; ++l) { + y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1; + y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2; + y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3; + y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4; + } + y += QK_K; +#endif + } +} + +void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) { + quantize_row_q2_K_reference(x, vy, k); +} + +size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + const int nb = k / QK_K; + + // TODO - collect histograms - although, at a second thought, I don't really care about them + (void)hist; + + for (int j = 0; j < nb; j += k) { + block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K; + quantize_row_q2_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q2_K)); +} + +//========================= 3-bit (de)-quantization + +void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K / 16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float amax = 0; + for (int j = 0; j < QK_K/16; ++j) { + scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true); + float scale = fabsf(scales[j]); + if (scale > amax) { + amax = scale; max_scale = scales[j]; + } + } + +#if QK_K == 256 + memset(y[i].scales, 0, 12); + if (max_scale) { + float iscale = -32.f/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int8_t l = nearest_int(iscale*scales[j]); + l = MAX(-32, MIN(31, l)) + 32; + if (j < 8) { + y[i].scales[j] = l & 0xF; + } else { + y[i].scales[j-8] |= ((l & 0xF) << 4); + } + l >>= 4; + y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + } + y[i].d = ggml_fp32_to_fp16(1/iscale); + } else { + y[i].d = ggml_fp32_to_fp16(0.f); + } + + int8_t sc; + for (int j = 0; j < QK_K/16; ++j) { + sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; + sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32; + float d = ggml_fp16_to_fp32(y[i].d) * sc; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } +#else + if (max_scale) { + float iscale = -8.f/max_scale; + for (int j = 0; j < QK_K/16; j+=2) { + int l1 = nearest_int(iscale*scales[j]); + l1 = 8 + MAX(-8, MIN(7, l1)); + int l2 = nearest_int(iscale*scales[j+1]); + l2 = 8 + MAX(-8, MIN(7, l2)); + y[i].scales[j/2] = l1 | (l2 << 4); + } + y[i].d = ggml_fp32_to_fp16(1/iscale); + } else { + for (int j = 0; j < QK_K/16; j+=2) { + y[i].scales[j/2] = 0; + } + y[i].d = ggml_fp32_to_fp16(0.f); + } + for (int j = 0; j < QK_K/16; ++j) { + int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4; + float d = ggml_fp16_to_fp32(y[i].d) * (s - 8); + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } +#endif + + memset(y[i].hmask, 0, QK_K/8); + // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc. + int m = 0; + uint8_t hm = 1; + for (int j = 0; j < QK_K; ++j) { + if (L[j] > 3) { + y[i].hmask[m] |= hm; + L[j] -= 4; + } + if (++m == QK_K/8) { + m = 0; hm <<= 1; + } + } +#if QK_K == 256 + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } +#else + for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif + + x += QK_K; + } +} + +#if QK_K == 256 +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + uint32_t aux[4]; + const int8_t * scales = (const int8_t*)aux; + + for (int i = 0; i < nb; i++) { + + const float d_all = ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + uint8_t m = 1; + + memcpy(aux, x[i].scales, 12); + uint32_t tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + int is = 0; + float dl; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)); + } + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)); + } + + shift += 2; + m <<= 1; + } + q += 32; + } + + } +} +#else +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + assert(QK_K == 64); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d_all = ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + + const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); + const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); + const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); + const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); + + for (int l=0; l<8; ++l) { + uint8_t h = hm[l]; + y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4)); + y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4)); + y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4)); + y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4)); + y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4)); + y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4)); + y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4)); + y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4)); + } + y += QK_K; + } +} +#endif + +void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) { + quantize_row_q3_K_reference(x, vy, k); +} + +size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + const int nb = k / QK_K; + + // TODO - collect histograms - although, at a second thought, I don't really care about them + (void)hist; + + for (int j = 0; j < nb; j += k) { + block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K; + quantize_row_q3_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q3_K)); +} + +// ====================== 4-bit (de)-quantization + +void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[32]; + float weights[32]; + float mins[QK_K/32]; + float scales[QK_K/32]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/32; ++j) { + //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); + float sum_x2 = 0; + for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l]; + float av_x = sqrtf(sum_x2/32); + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + +#if QK_K == 256 + float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; + float inv_min = max_min > 0 ? 63.f/max_min : 0.f; + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = nearest_int(inv_scale*scales[j]); + uint8_t lm = nearest_int(inv_min*mins[j]); + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = ggml_fp32_to_fp16(max_scale/63.f); + y[i].dmin = ggml_fp32_to_fp16(max_min/63.f); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = ggml_fp16_to_fp32(y[i].d) * sc; + if (!d) continue; + const float dm = ggml_fp16_to_fp32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + } + } +#else + const float s_factor = 15.f; + float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f; + float inv_min = max_min > 0 ? s_factor/max_min : 0.f; + int d1 = nearest_int(inv_scale*scales[0]); + int m1 = nearest_int(inv_min*mins[0]); + int d2 = nearest_int(inv_scale*scales[1]); + int m2 = nearest_int(inv_min*mins[1]); + y[i].scales[0] = d1 | (m1 << 4); + y[i].scales[1] = d2 | (m2 << 4); + y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor); + y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor); + + float sumlx = 0; + int suml2 = 0; + for (int j = 0; j < QK_K/32; ++j) { + const uint8_t sd = y[i].scales[j] & 0xF; + const uint8_t sm = y[i].scales[j] >> 4; + const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd; + if (!d) continue; + const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + m)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + sumlx += (x[32*j + ii] + m)*l*sd; + suml2 += l*l*sd*sd; + } + } + if (suml2) { + y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2); + } +#endif + uint8_t * q = y[i].qs; + for (int j = 0; j < QK_K; j += 64) { + for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4); + q += 32; + } + + x += QK_K; + + } +} + +void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const uint8_t * q = x[i].qs; + +#if QK_K == 256 + + const float d = ggml_fp16_to_fp32(x[i].d); + const float min = ggml_fp16_to_fp32(x[i].dmin); + + int is = 0; + uint8_t sc, m; + for (int j = 0; j < QK_K; j += 64) { + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; + q += 32; is += 2; + } +#else + const float dall = ggml_fp16_to_fp32(x[i].d[0]); + const float mall = ggml_fp16_to_fp32(x[i].d[1]); + const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4); + const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4); + for (int l = 0; l < 32; ++l) { + y[l+ 0] = d1 * (q[l] & 0xF) - m1; + y[l+32] = d2 * (q[l] >> 4) - m2; + } + y += QK_K; +#endif + + } +} + +void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q4_K * restrict y = vy; + quantize_row_q4_K_reference(x, y, k); +} + +size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + (void)hist; // TODO: collect histograms + for (int j = 0; j < nb; j += k) { + block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K; + quantize_row_q4_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q4_K)); +} + +// ====================== 5-bit (de)-quantization + +void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + +#if QK_K == 256 + uint8_t L[QK_K]; + float mins[QK_K/32]; + float scales[QK_K/32]; + float weights[32]; + uint8_t Laux[32]; +#else + int8_t L[QK_K]; + float scales[QK_K/16]; +#endif + + for (int i = 0; i < nb; i++) { + +#if QK_K == 256 + + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/32; ++j) { + //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); + float sum_x2 = 0; + for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l]; + float av_x = sqrtf(sum_x2/32); + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; + float inv_min = max_min > 0 ? 63.f/max_min : 0.f; + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = nearest_int(inv_scale*scales[j]); + uint8_t lm = nearest_int(inv_min*mins[j]); + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = ggml_fp32_to_fp16(max_scale/63.f); + y[i].dmin = ggml_fp32_to_fp16(max_min/63.f); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = ggml_fp16_to_fp32(y[i].d) * sc; + if (!d) continue; + const float dm = ggml_fp16_to_fp32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(31, l)); + L[32*j + ii] = l; + } + } + + uint8_t * restrict qh = y[i].qh; + uint8_t * restrict ql = y[i].qs; + memset(qh, 0, QK_K/8); + + uint8_t m1 = 1, m2 = 2; + for (int n = 0; n < QK_K; n += 64) { + for (int j = 0; j < 32; ++j) { + int l1 = L[n + j]; + if (l1 > 15) { + l1 -= 16; qh[j] |= m1; + } + int l2 = L[n + j + 32]; + if (l2 > 15) { + l2 -= 16; qh[j] |= m2; + } + ql[j] = l1 | (l2 << 4); + } + m1 <<= 2; m2 <<= 2; + ql += 32; + } +#else + float max_scale = 0, amax = 0; + for (int j = 0; j < QK_K/16; ++j) { + scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1); + float abs_scale = fabsf(scales[j]); + if (abs_scale > amax) { + amax = abs_scale; + max_scale = scales[j]; + } + } + + float iscale = -128.f/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*scales[j]); + y[i].scales[j] = MAX(-128, MIN(127, l)); + } + y[i].d = ggml_fp32_to_fp16(1/iscale); + + for (int j = 0; j < QK_K/16; ++j) { + const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j]; + if (!d) continue; + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-16, MIN(15, l)); + L[16*j + ii] = l + 16; + } + } + + uint8_t * restrict qh = y[i].qh; + uint8_t * restrict ql = y[i].qs; + memset(qh, 0, QK_K/8); + + for (int j = 0; j < 32; ++j) { + int jm = j%8; + int is = j/8; + int l1 = L[j]; + if (l1 > 15) { + l1 -= 16; qh[jm] |= (1 << is); + } + int l2 = L[j + 32]; + if (l2 > 15) { + l2 -= 16; qh[jm] |= (1 << (4 + is)); + } + ql[j] = l1 | (l2 << 4); + } +#endif + + x += QK_K; + + } +} + +void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const uint8_t * ql = x[i].qs; + const uint8_t * qh = x[i].qh; + +#if QK_K == 256 + + const float d = ggml_fp16_to_fp32(x[i].d); + const float min = ggml_fp16_to_fp32(x[i].dmin); + + int is = 0; + uint8_t sc, m; + uint8_t u1 = 1, u2 = 2; + for (int j = 0; j < QK_K; j += 64) { + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2; + ql += 32; is += 2; + u1 <<= 2; u2 <<= 2; + } +#else + float d = ggml_fp16_to_fp32(x[i].d); + const int8_t * restrict s = x[i].scales; + for (int l = 0; l < 8; ++l) { + y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16)); + y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16)); + y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16)); + y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16)); + y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16)); + y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16)); + y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16)); + y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16)); + } + y += QK_K; +#endif + } +} + +void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q5_K * restrict y = vy; + quantize_row_q5_K_reference(x, y, k); +} + +size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + (void)hist; + for (int j = 0; j < nb; j += k) { + block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K; + quantize_row_q5_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q5_K)); +} + +// ====================== 6-bit (de)-quantization + +void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K/16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + + const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1); + scales[ib] = scale; + + const float abs_scale = fabsf(scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scale; + } + + } + + if (!max_abs_scale) { + memset(&y[i], 0, sizeof(block_q6_K)); + y[i].d = ggml_fp32_to_fp16(0.f); + x += QK_K; + continue; + } + + float iscale = -128.f/max_scale; + y[i].d = ggml_fp32_to_fp16(1/iscale); + for (int ib = 0; ib < QK_K/16; ++ib) { + y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); + } + + for (int j = 0; j < QK_K/16; ++j) { + float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j]; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-32, MIN(31, l)); + L[16*j + ii] = l + 32; + } + } + + uint8_t * restrict ql = y[i].ql; + uint8_t * restrict qh = y[i].qh; +#if QK_K == 256 + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[j + l + 0] & 0xF; + const uint8_t q2 = L[j + l + 32] & 0xF; + const uint8_t q3 = L[j + l + 64] & 0xF; + const uint8_t q4 = L[j + l + 96] & 0xF; + ql[l+ 0] = q1 | (q3 << 4); + ql[l+32] = q2 | (q4 << 4); + qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6); + } + ql += 64; + qh += 32; + } +#else + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[l + 0] & 0xF; + const uint8_t q2 = L[l + 32] & 0xF; + ql[l] = q1 | (q2 << 4); + } + for (int l = 0; l < 16; ++l) { + qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6); + } +#endif + + x += QK_K; + + } +} + +void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict ql = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict sc = x[i].scales; + +#if QK_K == 256 + for (int n = 0; n < QK_K; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l + 0] = d * sc[is + 0] * q1; + y[l + 32] = d * sc[is + 2] * q2; + y[l + 64] = d * sc[is + 4] * q3; + y[l + 96] = d * sc[is + 6] * q4; + } + y += 128; + ql += 64; + qh += 32; + sc += 8; + } +#else + for (int l = 0; l < 16; ++l) { + const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l+ 0] = d * sc[0] * q1; + y[l+16] = d * sc[1] * q2; + y[l+32] = d * sc[2] * q3; + y[l+48] = d * sc[3] * q4; + } + y += 64; +#endif + + } +} + +void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q6_K * restrict y = vy; + quantize_row_q6_K_reference(x, y, k); +} + +size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + (void)hist; // TODO + + for (int j = 0; j < nb; j += k) { + block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K; + quantize_row_q6_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q6_K)); +} + +//===================================== Q8_K ============================================== + +void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + float max = 0; + float amax = 0; + for (int j = 0; j < QK_K; ++j) { + float ax = fabsf(x[j]); + if (ax > amax) { + amax = ax; max = x[j]; + } + } + if (!amax) { + y[i].d = 0; + memset(y[i].qs, 0, QK_K); + x += QK_K; + continue; + } + const float iscale = -128.f/max; + for (int j = 0; j < QK_K; ++j) { + int v = nearest_int(iscale*x[j]); + y[i].qs[j] = MIN(127, v); + } + for (int j = 0; j < QK_K/16; ++j) { + int sum = 0; + for (int ii = 0; ii < 16; ++ii) { + sum += y[i].qs[j*16 + ii]; + } + y[i].bsums[j] = sum; + } + y[i].d = 1/iscale; + x += QK_K; + } +} + +void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK_K; ++j) { + *y++ = x[i].d * x[i].qs[j]; + } + } +} + +void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) { + quantize_row_q8_K_reference(x, y, k); +} + +//===================================== Dot ptoducts ================================= + +// +// Helper functions +// +#if __AVX__ || __AVX2__ || __AVX512F__ + +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return _mm_loadu_si128((const __m128i*)k_shuffle + i); +} +#endif + +#if QK_K == 256 +void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + + const block_q2_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m3 = vdupq_n_u8(0x3); + const uint8x16_t m4 = vdupq_n_u8(0xF); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + int8x16x2_t q2bytes; + uint8_t aux[16]; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8_t * restrict sc = x[i].scales; + + const uint8x16_t mins_and_scales = vld1q_u8(sc); + const uint8x16_t scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux, scales); + + const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); + const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); + const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; + const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), + vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); + const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), + vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); + sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); + + int isum = 0; + int is = 0; + +// We use this macro instead of a function call because for some reason +// the code runs 2-3% slower, even if the function is declared inline +#if defined(__ARM_FEATURE_DOTPROD) +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ + isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; +#else +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + {\ + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\ + vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\ + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\ + vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\ + isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\ + } +#endif + +#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ + q8bytes = vld1q_s8_x2(q8); q8 += 32;\ + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ + MULTIPLY_ACCUM_WITH_SCALE((index)); + + + for (int j = 0; j < QK_K/128; ++j) { + + const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32; + + int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32; + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); + MULTIPLY_ACCUM_WITH_SCALE(0); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + + is += 8; + } + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m256i mins = _mm256_cvtepi8_epi16(mins8); + const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); + + const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i q2_0 = _mm256_and_si256(q2bits, m3); + const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); + const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); + const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); + + __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); + __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); + + p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); + p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); + p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); + p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); + + p0 = _mm256_add_epi32(p0, p1); + p2 = _mm256_add_epi32(p2, p3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(0x3); + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // load mins and scales from block_q2_K.scales[QK_K/16] + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); + const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); + + // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 + const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); + const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); + + // sumf += -dmin * summs in 32bits*8 + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); + + const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); + const __m128i scales[2] = { scales_0, scales_1 }; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + + // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // load 2bits*16*8 from block_q2_K.qs[QK_K/4] + __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_1 = _mm_and_si128(q2bits, m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + + // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 + __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); + __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); + __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); + __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); + __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); + __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); + __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); + __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); + + // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 + __m128i shuffle = _mm_set1_epi16(0x0100); + p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); + shuffle = _mm_add_epi16(shuffle, m2); + p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); + shuffle = _mm_add_epi16(shuffle, m2); + p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); + shuffle = _mm_add_epi16(shuffle, m2); + p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); + shuffle = _mm_add_epi16(shuffle, m2); + p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); + shuffle = _mm_add_epi16(shuffle, m2); + p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); + shuffle = _mm_add_epi16(shuffle, m2); + p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); + shuffle = _mm_add_epi16(shuffle, m2); + p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); + + p0 = _mm_add_epi32(p0, p1); + p2 = _mm_add_epi32(p2, p3); + p4 = _mm_add_epi32(p4, p5); + p6 = _mm_add_epi32(p6, p7); + + // isum in 32bits*4*2 + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); + } + + // sumf += dall * isum - dmin * summs in 32bits + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); + } + + *s = hsum_float_8(acc); + +#else + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +#endif +} + +#else + +void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + + const block_q2_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m3 = vdupq_n_u8(0x3); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + int8x16x4_t q2bytes; + + uint32_t aux32[2]; + const uint8_t * scales = (const uint8_t *)aux32; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const float dmin = -y[i].d * (float)x[i].dmin; + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + + aux32[0] = sc[0] & 0x0f0f0f0f; + aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; + + sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); + + int isum1 = 0, isum2 = 0; + + const uint8x16_t q2bits = vld1q_u8(q2); + + const int8x16x4_t q8bytes = vld1q_s8_x4(q8); + + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); + q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3)); + q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3)); + +#if defined(__ARM_FEATURE_DOTPROD) + isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; + isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; + isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; + isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; +#else + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum1 += vaddvq_s16(p1) * scales[0]; + isum2 += vaddvq_s16(p2) * scales[1]; + + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum1 += vaddvq_s16(p3) * scales[2]; + isum2 += vaddvq_s16(p4) * scales[3]; +#endif + sum += d * (isum1 + isum2); + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; + + float summs = 0; + + // TODO: optimize this + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; + + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; + + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3); + const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + + const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0)); + const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1)); + const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0)); + const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1)); + + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc); + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; + + float summs = 0; + + // TODO: optimize this + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; + + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; + + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1)); + + const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0)); + const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1)); + const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2)); + const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc); + } + + *s = hsum_float_8(acc) + summs; + +#else + + float sumf = 0; + + int isum[4]; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < QK_K/16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + isum[0] = isum[1] = isum[2] = isum[3] = 0; + for (int l = 0; l < 16; ++l) { + isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3); + isum[1] += q8[l+16] * ((q2[l] >> 2) & 3); + isum[2] += q8[l+32] * ((q2[l] >> 4) & 3); + isum[3] += q8[l+48] * ((q2[l] >> 6) & 3); + } + for (int l = 0; l < 4; ++l) { + isum[l] *= (sc[l] & 0xF); + } + sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs; + } + *s = sumf; +#endif +} +#endif + +#if QK_K == 256 +void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + uint32_t aux[3]; + uint32_t utmp[4]; + + const uint8x16_t m3b = vdupq_n_u8(0x3); +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + const uint8x16_t m0 = vdupq_n_u8(1); + const uint8x16_t m1 = vshlq_n_u8(m0, 1); + const uint8x16_t m2 = vshlq_n_u8(m0, 2); + const uint8x16_t m3 = vshlq_n_u8(m0, 3); + const int8_t m32 = 32; + + int8x16x4_t q3bytes; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + uint8x16x2_t qhbits = vld1q_u8_x2(qh); + + uint8x16x4_t q3h; + + int32_t isum = 0; + + // Set up scales + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= m32; + + for (int j = 0; j < QK_K/128; ++j) { + + const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32; + const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64; + const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64; + + q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); + q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); + q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); + q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; +#else + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1]))); + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + scale += 4; + + q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); + q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); + q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); + q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; +#else + p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); + p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1]))); + p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); + p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + scale += 4; + + if (j == 0) { + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); + } + + } + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i mone = _mm256_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + // high bit + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); + + // integer accumulator + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; + + // prepare low and high bits + const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); + const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); + const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); + const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); + const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = _mm256_add_epi32(p16_0, p16_1); + p16_2 = _mm256_add_epi32(p16_2, p16_3); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + + } + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + const uint32_t *aux; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + aux = (const uint32_t *)x[i].scales; + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); + const __m128i scales[2] = { scales_0, scales_1 }; + + // high bit *128*2 from block_q3_K.hmask[QK_K/8] + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); + + // integer accumulator + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] + const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + + // prepare low and high bits + const int bit = j << 2; + + const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); + const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); + const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); + const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); + + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); + const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); + const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + + const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); + const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); + const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + + const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); + const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); + const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + + // load Q8 quants from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + // multiply with scales + __m128i shuffle = _mm_set1_epi16(0x0100); + p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); + shuffle = _mm_add_epi16(shuffle, m2); + p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); + shuffle = _mm_add_epi16(shuffle, m2); + p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); + shuffle = _mm_add_epi16(shuffle, m2); + p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); + shuffle = _mm_add_epi16(shuffle, m2); + p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); + shuffle = _mm_add_epi16(shuffle, m2); + p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); + shuffle = _mm_add_epi16(shuffle, m2); + p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); + shuffle = _mm_add_epi16(shuffle, m2); + p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); + + // accumulate + p16_0 = _mm_add_epi32(p16_0, p16_1); + p16_2 = _mm_add_epi32(p16_2, p16_3); + p16_4 = _mm_add_epi32(p16_4, p16_5); + p16_6 = _mm_add_epi32(p16_6, p16_7); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); + + } + + // multiply with block scale and accumulate + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc); + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +#else + +void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + const uint8x16_t m3b = vdupq_n_u8(0x3); + const uint8x16_t mh = vdupq_n_u8(4); + + int8x16x4_t q3bytes; + + uint16_t aux16[2]; + int8_t * scales = (int8_t *)aux16; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + uint8x16x4_t q3h; + + const uint8x8_t hbits = vld1_u8(x[i].hmask); + const uint8x16_t q3bits = vld1q_u8(x[i].qs); + const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs); + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + for (int j = 0; j < 4; ++j) scales[j] -= 8; + + int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + + const float d = y[i].d * (float)x[i].d; + + const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); + q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); + q3h.val[1] = vandq_u8(mh, htmp); + q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); + q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4)); + + q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); + q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); + q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); + q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; +#else + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3]; +#endif + + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i m1 = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + uint16_t aux16[2]; + const int8_t * aux8 = (const int8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8)); + const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8)); + + memcpy(&aux64, x[i].hmask, 8); + + const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); + __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux); + __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4); + q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2); + q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits); + const __m256i q3l_0 = _mm256_and_si256(q3aux, m3); + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + + // multiply with scales + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + p16_0 = _mm256_add_epi32(p16_0, p16_1); + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m1 = _mm_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + uint16_t aux16[2]; + const int8_t * aux8 = (const int8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8); + const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8); + const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8); + const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8); + + memcpy(&aux64, x[i].hmask, 8); + + __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); + __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2); + __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4); + __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6); + q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2); + q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2); + q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2); + q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m128i q3l_0 = _mm_and_si128(q3bits, m3); + const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3); + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3); + const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1)); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0)); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1)); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0)); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1)); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_1, p16_1); + p16_2 = _mm_madd_epi16(scale_2, p16_2); + p16_3 = _mm_madd_epi16(scale_3, p16_3); + + p16_0 = _mm_add_epi32(p16_0, p16_2); + p16_1 = _mm_add_epi32(p16_1, p16_3); + __m256i p16 = MM256_SET_M128I(p16_1, p16_0); + + // multiply with block scale and accumulate + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc); + + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + int32_t scales[4]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + int8_t * restrict a = aux8; + for (int l = 0; l < 8; ++l) { + a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4); + a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4); + a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4); + a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4); + a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4); + a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4); + a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4); + a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4); + } + + scales[0] = (x[i].scales[0] & 0xF) - 8; + scales[1] = (x[i].scales[0] >> 4) - 8; + scales[2] = (x[i].scales[1] & 0xF) - 8; + scales[3] = (x[i].scales[1] >> 4) - 8; + + memset(aux32, 0, 8*sizeof(int32_t)); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l]; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} +#endif + +#if QK_K == 256 +void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + int8x16x2_t q4bytes; + int8x16x2_t q8bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32; + +#ifdef __ARM_FEATURE_DOTPROD + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + + sumi2 += vaddvq_s32(p2) * scales[2*j+1]; +#else + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; + + q8bytes = vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1]; + +#endif + } + + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + p16l = _mm256_madd_epi16(scale_l, p16l); + + const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + p16h = _mm256_madd_epi16(scale_h, p16h); + const __m256i sumj = _mm256_add_epi32(p16l, p16h); + + sumi = _mm256_add_epi32(sumi, sumj); + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_0 = _mm_and_si128(q4bits, m4); + const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_1 = _mm_and_si128(q4bits, m4); + const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + + const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_0 = _mm_add_epi32(sumi_0, p16l); + const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16l = _mm_maddubs_epi16(q4l_1, q8l_1); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_1 = _mm_add_epi32(sumi_1, p16l); + + const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_0 = _mm_add_epi32(sumi_0, p16h); + const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16h = _mm_maddubs_epi16(q4h_1, q8h_1); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_1 = _mm_add_epi32(sumi_1, p16h); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#else + + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#else +void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + float sumf = 0; + + int8x16x2_t q4bytes; + int8x16x4_t q8bytes; + + float sum_mins = 0.f; + + uint16_t aux16[2]; + const uint8_t * restrict scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t * restrict a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]); + sum_mins += y[i].d * (float)x[i].d[1] * summi; + + const float d = y[i].d * (float)x[i].d[0]; + + const uint8x16x2_t q4bits = vld1q_u8_x2(q4); + +#ifdef __ARM_FEATURE_DOTPROD + q8bytes = vld1q_s8_x4(q8); + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + const int32_t sumi1 = vaddvq_s32(p1) * scales[0]; + + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); + const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; + +#else + q8bytes = vld1q_s8_x4(q8); + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0]; + + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3]))); + int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1]; + +#endif + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf - sum_mins; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d; + const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + + const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc); + + const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc); + + } + + *s = hsum_float_8(acc) - summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d; + const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0); + const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1); + const __m128i q4_0 = _mm_and_si128(q4bits_0, m4); + const __m128i q4_1 = _mm_and_si128(q4bits_1, m4); + const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4); + const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); + + const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0); + const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc); + + const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2); + const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc); + + } + + *s = hsum_float_8(acc) - summs; + +#else + + uint8_t aux8[QK_K]; + int16_t aux16[16]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + uint16_t s16[2]; + const uint8_t * restrict scales = (const uint8_t *)s16; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + uint8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF; + for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4; + + const uint16_t * restrict b = (const uint16_t *)x[i].scales; + s16[0] = b[0] & 0x0f0f; + s16[1] = (b[0] >> 4) & 0x0f0f; + + sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]); + + for (int j = 0; j < QK_K/32; ++j) { + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; + q8 += 16; a += 16; + const float dl = d * scales[j]; + for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]); + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif + +#if QK_K == 256 +void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + int8x16x4_t q5bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + int32_t sumi_mins = vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + uint8x16x2_t qhbits = vld1q_u8_x2(qh); + + uint8x16x4_t q5h; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32; + const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + + q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); + q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; + sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; +#else + + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++; + + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++; +#endif + } + + sumf += d * sumi - dmin * sumi_mins; + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + +#if QK_K == 256 + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; +#else + // TODO + const float d = 0, dmin = 0; +#endif + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); + __m256i hmask = mone; + + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); + __m128i hmask = mone; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + int bit = 0; + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + + __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); + __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); + __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); + __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_0, p16_1); + + q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); + q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); + q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + q5_0 = _mm_add_epi8(q5l_0, q5h_0); + q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); + p16_2 = _mm_madd_epi16(scale_1, p16_2); + p16_3 = _mm_madd_epi16(scale_1, p16_3); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#else + +void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mh = vdupq_n_u8(16); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + int8x16x4_t q5bytes; + uint8x16x4_t q5h; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const int8_t * sc = x[i].scales; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const uint8x8_t qhbits = vld1_u8(qh); + + const uint8x16x2_t q5bits = vld1q_u8_x2(q5); + const int8x16x4_t q8bytes = vld1q_s8_x4(q8); + + const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); + q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); + q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2)); + q5h.val[2] = vbicq_u8(mh, htmp); + q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2)); + + q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0])); + q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1])); + q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); + q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0])); + int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1])); + int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2])); + int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3])); + + sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); + +#else + + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1); + + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3); + + sumf += d*sumi; +#endif + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); + + const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0])); + const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2])); + + int64_t aux64; + memcpy(&aux64, x[i].qh, 8); + const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64); + const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128); + + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4); + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0)); + const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1)); + const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0)); + const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1)); + + const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1)); + + acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mone = _mm_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); + + const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]); + const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]); + const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]); + const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]); + + int64_t aux64; + memcpy(&aux64, x[i].qh, 8); + const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64); + const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2); + + const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4); + const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4); + const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4); + const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4); + + const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4); + const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4); + const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4); + const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0))); + const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1))); + const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0))); + const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1))); + const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0))); + const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1))); + const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0))); + const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1))); + + const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2)); + const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc); + + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[16]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + int8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) { + a[l+ 0] = q4[l] & 0xF; + a[l+32] = q4[l] >> 4; + } + for (int is = 0; is < 8; ++is) { + uint8_t m = 1 << is; + for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16); + } + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const int8_t * restrict sc = x[i].scales; + + for (int j = 0; j < QK_K/16; ++j) { + const float dl = d * sc[j]; + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]); + q8 += 16; a += 16; + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif + + +#if QK_K == 256 +void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q6_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + float sum = 0; + + const uint8x16_t m4b = vdupq_n_u8(0xF); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif + //const int8x16_t m32s = vdupq_n_s8(32); + + const uint8x16_t mone = vdupq_n_u8(3); + + int8x16x4_t q6bytes; + uint8x16x4_t q6h; + + for (int i = 0; i < nb; ++i) { + + const float d_all = ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); + const int8x16_t scales = vld1q_s8(scale); + const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; + + const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), + vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), + vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); + int32_t isum_mins = vaddvq_s32(prod); + + int32_t isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32; + uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64; + int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; + + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 2); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + +#else + + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + scale += 2; + + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; + scale += 2; +#endif + + q8bytes = vld1q_s8_x4(q8); q8 += 64; + + shifted = vshrq_n_u8(qhbits.val[0], 4); + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[0], 6); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + + //for (int l = 0; l < 4; ++l) { + // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]); + // isum += vaddvq_s32(p) * *scale++; + //} +#else + p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + scale += 2; + + p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; + scale += 2; +#endif + + } + //sum += isum * d_all * y[i].d; + sum += d_all * y[i].d * (isum - 32 * isum_mins); + + } + *s = sum; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m256i sumi = _mm256_setzero_si256(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); + + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m32s = _mm_set1_epi8(32); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4); + const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4); + const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4); + const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4); + const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4); + + const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3); + const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4); + const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5); + const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6); + const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7); + + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); + p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5); + p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); + p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); + + } + + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#else + +void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q6_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + float sum = 0; + + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int8x16_t m32s = vdupq_n_s8(32); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + const uint8x16_t mone = vdupq_n_u8(3); + + int8x16x4_t q6bytes; + uint8x16x4_t q6h; + + for (int i = 0; i < nb; ++i) { + + const float d_all = (float)x[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int32_t isum = 0; + + uint8x16_t qhbits = vld1q_u8(qh); + uint8x16x2_t q6bits = vld1q_u8_x2(q6); + int8x16x4_t q8bytes = vld1q_s8_x4(q8); + + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits, 2); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 4); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s); + q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; +#else + + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + + sum += isum * d_all * y[i].d; + + } + *s = sum; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + + __m256i sumi = _mm256_setzero_si256(); + + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(3); + const __m128i m32s = _mm_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4); + + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0)); + __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1)); + __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0)); + __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1)); + + __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc); + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int l = 0; l < 16; ++l) { + a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#endif diff --git a/plugins/wasi_nn/thirdparty/ggml/k_quants.h b/plugins/wasi_nn/thirdparty/ggml/k_quants.h new file mode 100644 index 00000000..adc6a391 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/k_quants.h @@ -0,0 +1,165 @@ +#pragma once + +#include "ggml.h" + +#include +#include +#include + +// Super-block size +#ifdef GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else +#define QK_K 256 +#define K_SCALE_SIZE 12 +#endif + +#ifndef static_assert +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) +#define static_assert(cond, msg) _Static_assert(cond, msg) +#else +#define static_assert(cond, msg) struct global_scope_noop_trick +#endif +#endif + +// +// Super-block quantization structures +// + +// 2-bit quantization +// weight is represented as x = a * q + b +// 16 blocks of 16 elemenets each +// Effectively 2.5625 bits per weight +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + ggml_fp16_t d; // super-block scale for quantized scales + ggml_fp16_t dmin; // super-block scale for quantized mins +} block_q2_K; +static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); + +// 3-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elemenets each +// Effectively 3.4375 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[2]; + ggml_fp16_t d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding"); +#else +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[12]; // scales, quantized with 6 bits + ggml_fp16_t d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); +#endif + +// 4-bit quantization +// 16 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 4.5 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + ggml_fp16_t d[2]; // super-block scales/mins + uint8_t scales[2]; // 4-bit block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); +#else +typedef struct { + ggml_fp16_t d; // super-block scale for quantized scales + ggml_fp16_t dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); +#endif + +// 5-bit quantization +// 16 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 5.5 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + ggml_fp16_t d; // super-block scale + int8_t scales[QK_K/16]; // 8-bit block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); +#else +typedef struct { + ggml_fp16_t d; // super-block scale for quantized scales + ggml_fp16_t dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif + +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elemenets each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + ggml_fp16_t d; // super-block scale +} block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); + +// This is only used for intermediate quantization and dot products +typedef struct { + float d; // delta + int8_t qs[QK_K]; // quants + int16_t bsums[QK_K/16]; // sum of quants in groups of 16 +} block_q8_K; +static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); + + +// Quantization +void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k); +void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k); +void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k); +void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k); +void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k); +void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k); + +void quantize_row_q2_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q3_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q4_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q5_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q6_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q8_K(const float * restrict x, void * restrict y, int k); + +// Dequantization +void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k); +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k); +void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k); +void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k); +void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k); +void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k); + +// Dot product +void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); + +// Quantization with histogram collection +size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist); +size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist); +size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist); +size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist); +size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist); + diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.cpp b/plugins/wasi_nn/thirdparty/ggml/llama.cpp new file mode 100644 index 00000000..2a2a0c9c --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/llama.cpp @@ -0,0 +1,6398 @@ +#include "llama.h" + +#include "ggml.h" + +#include "ggml-alloc.h" + +#ifdef GGML_USE_CUBLAS +# include "ggml-cuda.h" +#elif defined(GGML_USE_CLBLAST) +# include "ggml-opencl.h" +#endif + +#ifdef GGML_USE_METAL +# include "ggml-metal.h" +#endif +#ifdef GGML_USE_MPI +# include "ggml-mpi.h" +#endif +#ifdef GGML_USE_K_QUANTS +# ifndef QK_K +# ifdef GGML_QKK_64 +# define QK_K 64 +# else +# define QK_K 256 +# endif +# endif +#endif + +#ifdef __has_include + #if __has_include() + #include + #if defined(_POSIX_MAPPED_FILES) + #include + #endif + #if defined(_POSIX_MEMLOCK_RANGE) + #include + #endif + #endif +#endif + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #ifndef NOMINMAX + #define NOMINMAX + #endif + #include + #include + #include // for _fseeki64 +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define LLAMA_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +LLAMA_ATTRIBUTE_FORMAT(2, 3) +static void llama_log_internal (llama_log_level level, const char* format, ...); +static void llama_log_callback_default(llama_log_level level, const char * text, void * user_data); + +#define LLAMA_LOG_INFO(...) llama_log_internal(LLAMA_LOG_LEVEL_INFO , __VA_ARGS__) +#define LLAMA_LOG_WARN(...) llama_log_internal(LLAMA_LOG_LEVEL_WARN , __VA_ARGS__) +#define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__) + +// +// helpers +// + +static size_t utf8_len(char src) { + const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t highbits = static_cast(src) >> 4; + return lookup[highbits]; +} + +void replace_all(std::string & s, const std::string & search, const std::string & replace) { + std::string result; + for (size_t pos = 0; ; pos += search.length()) { + auto new_pos = s.find(search, pos); + if (new_pos == std::string::npos) { + result += s.substr(pos, s.size() - pos); + break; + } + result += s.substr(pos, new_pos - pos) + replace; + pos = new_pos; + } + s = std::move(result); +} +#ifdef GGML_USE_CPU_HBM +#include +#endif + +static void zeros(std::ofstream & file, size_t n) { + char zero = 0; + for (size_t i = 0; i < n; ++i) { + file.write(&zero, 1); + } +} + +LLAMA_ATTRIBUTE_FORMAT(1, 2) +static std::string format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +// +// gguf constants (sync with gguf.py) +// + +enum llm_arch { + LLM_ARCH_LLAMA, + LLM_ARCH_FALCON, + LLM_ARCH_GPT2, + LLM_ARCH_GPTJ, + LLM_ARCH_GPTNEOX, + LLM_ARCH_MPT, + LLM_ARCH_UNKNOWN, +}; + +static std::map LLM_ARCH_NAMES = { + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, +}; + +enum llm_kv { + LLM_KV_GENERAL_ARCHITECTURE, + LLM_KV_GENERAL_QUANTIZATION_VERSION, + LLM_KV_GENERAL_ALIGNMENT, + LLM_KV_GENERAL_NAME, + LLM_KV_GENERAL_AUTHOR, + LLM_KV_GENERAL_URL, + LLM_KV_GENERAL_DESCRIPTION, + LLM_KV_GENERAL_LICENSE, + LLM_KV_GENERAL_SOURCE_URL, + LLM_KV_GENERAL_SOURCE_HF_REPO, + + LLM_KV_CONTEXT_LENGTH, + LLM_KV_EMBEDDING_LENGTH, + LLM_KV_BLOCK_COUNT, + LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_USE_PARALLEL_RESIDUAL, + LLM_KV_TENSOR_DATA_LAYOUT, + + LLM_KV_ATTENTION_HEAD_COUNT, + LLM_KV_ATTENTION_HEAD_COUNT_KV, + LLM_KV_ATTENTION_MAX_ALIBI_BIAS, + LLM_KV_ATTENTION_CLAMP_KQV, + LLM_KV_ATTENTION_LAYERNORM_EPS, + LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + + LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_SCALE_LINEAR, + + LLM_KV_TOKENIZER_MODEL, + LLM_KV_TOKENIZER_LIST, + LLM_KV_TOKENIZER_TOKEN_TYPE, + LLM_KV_TOKENIZER_SCORES, + LLM_KV_TOKENIZER_MERGES, + LLM_KV_TOKENIZER_BOS_ID, + LLM_KV_TOKENIZER_EOS_ID, + LLM_KV_TOKENIZER_UNK_ID, + LLM_KV_TOKENIZER_SEP_ID, + LLM_KV_TOKENIZER_PAD_ID, + LLM_KV_TOKENIZER_HF_JSON, + LLM_KV_TOKENIZER_RWKV, +}; + +static std::map LLM_KV_NAMES = { + { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, + { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, + { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_NAME, "general.name" }, + { LLM_KV_GENERAL_AUTHOR, "general.author" }, + { LLM_KV_GENERAL_URL, "general.url" }, + { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, + { LLM_KV_GENERAL_LICENSE, "general.license" }, + { LLM_KV_GENERAL_SOURCE_URL, "general.source_url" }, + { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source_hf_repo" }, + + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + + { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, + { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, + { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, + { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, + { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, + { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, +}; + +struct LLM_KV { + LLM_KV(llm_arch arch) : arch(arch) {} + + llm_arch arch; + + std::string operator()(llm_kv kv) const { + return ::format(LLM_KV_NAMES[kv].c_str(), LLM_ARCH_NAMES[arch].c_str()); + } +}; + +enum llm_tensor { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_POS_EMBD, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_NORM, +}; + +static std::map> LLM_TENSOR_NAMES = { + { + LLM_ARCH_LLAMA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_FALCON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GPT2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, + { + LLM_ARCH_GPTJ, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, + { + LLM_ARCH_GPTNEOX, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_MPT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, + { + LLM_ARCH_UNKNOWN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, +}; + +static llm_arch llm_arch_from_string(const std::string & name) { + for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT + if (kv.second == name) { + return kv.first; + } + } + + return LLM_ARCH_UNKNOWN; +} + +// helper to handle gguf constants +// usage: +// +// const auto tn = LLM_TN(LLM_ARCH_LLAMA); +// +// std::string name = tn(LLM_TENSOR_OUTPUT); -> "output" +// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" +// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" +// +struct LLM_TN { + LLM_TN(llm_arch arch) : arch(arch) {} + + llm_arch arch; + + std::string operator()(llm_tensor tensor) const { + return LLM_TENSOR_NAMES[arch].at(tensor); + } + + std::string operator()(llm_tensor tensor, const std::string & suffix) const { + return LLM_TENSOR_NAMES[arch].at(tensor) + "." + suffix; + } + + std::string operator()(llm_tensor tensor, int bid) const { + return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid); + } + + std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { + return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix; + } +}; + +// +// gguf helpers +// + +#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \ +{ \ + const std::string skey(key); \ + const int kid = gguf_find_key(ctx, skey.c_str()); \ + if (kid >= 0) { \ + enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \ + if (ktype != (type)) { \ + throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \ + } \ + (dst) = func(ctx, kid); \ + } else if (req) { \ + throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \ + } \ +} + +// +// ggml helpers +// + +static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { + struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + + if (plan.work_size > 0) { + buf.resize(plan.work_size); + plan.work_data = buf.data(); + } + + ggml_graph_compute(graph, &plan); +} + +// +// llama helpers +// + +#ifdef GGML_USE_CUBLAS +# define llama_host_malloc(n) ggml_cuda_host_malloc(n) +# define llama_host_free(data) ggml_cuda_host_free(data) +#elif GGML_USE_METAL +# define llama_host_malloc(n) ggml_metal_host_malloc(n) +# define llama_host_free(data) ggml_metal_host_free(data) +#elif GGML_USE_CPU_HBM +# define llama_host_malloc(n) hbw_malloc(n) +# define llama_host_free(data) if (data != NULL) hbw_free(data) +#else +# define llama_host_malloc(n) malloc(n) +# define llama_host_free(data) free(data) +#endif + +#if defined(_WIN32) +static std::string llama_format_win_err(DWORD err) { + LPSTR buf; + size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); + if (!size) { + return "FormatMessageA failed"; + } + std::string ret(buf, size); + LocalFree(buf); + return ret; +} +#endif + +struct llama_buffer { + void * data = NULL; + size_t size = 0; + + // fallback to malloc / free + // useful in cases where CUDA can try to allocate PINNED memory + bool fallback = false; + + void resize(size_t n) { + llama_host_free(data); + + data = llama_host_malloc(n); + if (!data) { + fallback = true; + data = malloc(n); + } else { + fallback = false; + } + + GGML_ASSERT(data); + size = n; + } + + ~llama_buffer() { + if (data) { + if (fallback) { // NOLINT + free(data); + } else { + llama_host_free(data); + } + } + + data = NULL; + } +}; + +struct llama_file { + // use FILE * so we don't have to re-open the file to mmap + FILE * fp; + size_t size; + + llama_file(const char * fname, const char * mode) { + fp = std::fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); +#else + long ret = std::ftell(fp); +#endif + GGML_ASSERT(ret != -1); // this really shouldn't fail + return (size_t) ret; + } + + void seek(size_t offset, int whence) const { +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, whence); +#else + int ret = std::fseek(fp, (long) offset, whence); +#endif + GGML_ASSERT(ret == 0); // same + } + + void read_raw(void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + std::size_t ret = std::fread(ptr, len, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret != 1) { + throw std::runtime_error(std::string("unexpectedly reached end of file")); + } + } + + uint32_t read_u32() const { + uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + + void write_raw(const void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + size_t ret = std::fwrite(ptr, len, 1, fp); + if (ret != 1) { + throw std::runtime_error(format("write error: %s", strerror(errno))); + } + } + + void write_u32(std::uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~llama_file() { + if (fp) { + std::fclose(fp); + } + } +}; + +struct llama_mmap { + void * addr; + size_t size; + + llama_mmap(const llama_mmap &) = delete; + +#ifdef _POSIX_MAPPED_FILES + static constexpr bool SUPPORTED = true; + + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false) { + size = file->size; + int fd = fileno(file->fp); + int flags = MAP_SHARED; + // prefetch/readahead impairs performance on NUMA systems + if (numa) { prefetch = 0; } +#ifdef __linux__ + if (prefetch) { flags |= MAP_POPULATE; } +#endif + addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0); + if (addr == MAP_FAILED) { + throw std::runtime_error(format("mmap failed: %s", strerror(errno))); + } + + if (prefetch > 0) { + // Advise the kernel to preload the mapped memory + if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) { + fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n", + strerror(errno)); + } + } + if (numa) { + // advise the kernel not to use readahead + // (because the next page might not belong on the same node) + if (posix_madvise(addr, file->size, POSIX_MADV_RANDOM)) { + fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n", + strerror(errno)); + } + } + } + + ~llama_mmap() { + munmap(addr, size); + } +#elif defined(_WIN32) + static constexpr bool SUPPORTED = true; + + llama_mmap(struct llama_file * file, bool prefetch = true, bool numa = false) { + (void) numa; + + size = file->size; + + HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp)); + + HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + DWORD error = GetLastError(); + + if (hMapping == NULL) { + throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str())); + } + + addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); + error = GetLastError(); + CloseHandle(hMapping); + + if (addr == NULL) { + throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); + } + + if (prefetch) { + // PrefetchVirtualMemory is only present on Windows 8 and above, so we dynamically load it + BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG); + HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll"); + + // may fail on pre-Windows 8 systems + pPrefetchVirtualMemory = reinterpret_cast (GetProcAddress(hKernel32, "PrefetchVirtualMemory")); + + if (pPrefetchVirtualMemory) { + // advise the kernel to preload the mapped memory + WIN32_MEMORY_RANGE_ENTRY range; + range.VirtualAddress = addr; + range.NumberOfBytes = (SIZE_T)size; + if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { + fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } + } + } + + ~llama_mmap() { + if (!UnmapViewOfFile(addr)) { + fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + static constexpr bool SUPPORTED = false; + + llama_mmap(struct llama_file * file, bool prefetch = true, bool numa = false) { + (void) file; + (void) prefetch; + (void) numa; + + throw std::runtime_error(std::string("mmap not supported")); + } +#endif +}; + +// Represents some region of memory being locked using mlock or VirtualLock; +// will automatically unlock on destruction. +struct llama_mlock { + void * addr = NULL; + size_t size = 0; + + bool failed_already = false; + + llama_mlock() {} + llama_mlock(const llama_mlock &) = delete; + + ~llama_mlock() { + if (size) { + raw_unlock(addr, size); + } + } + + void init(void * ptr) { + GGML_ASSERT(addr == NULL && size == 0); // NOLINT + addr = ptr; + } + + void grow_to(size_t target_size) { + GGML_ASSERT(addr); + if (failed_already) { + return; + } + size_t granularity = lock_granularity(); + target_size = (target_size + granularity - 1) & ~(granularity - 1); + if (target_size > size) { + if (raw_lock((uint8_t *) addr + size, target_size - size)) { + size = target_size; + } else { + failed_already = true; + } + } + } + +#ifdef _POSIX_MEMLOCK_RANGE + static constexpr bool SUPPORTED = true; + + static size_t lock_granularity() { + return (size_t) sysconf(_SC_PAGESIZE); + } + + #ifdef __APPLE__ + #define MLOCK_SUGGESTION \ + "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \ + "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n" + #else + #define MLOCK_SUGGESTION \ + "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n" + #endif + + bool raw_lock(const void * addr, size_t size) const { + if (!mlock(addr, size)) { + return true; + } + + char* errmsg = std::strerror(errno); + bool suggest = (errno == ENOMEM); + + // Check if the resource limit is fine after all + struct rlimit lock_limit; + if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) { + suggest = false; + } + if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) { + suggest = false; + } + + fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s", + size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : ""); + return false; + } + + #undef MLOCK_SUGGESTION + + static void raw_unlock(void * addr, size_t size) { + if (munlock(addr, size)) { + fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno)); + } + } +#elif defined(_WIN32) + static constexpr bool SUPPORTED = true; + + static size_t lock_granularity() { + SYSTEM_INFO si; + GetSystemInfo(&si); + return (size_t) si.dwPageSize; + } + + bool raw_lock(void * ptr, size_t len) const { + for (int tries = 1; ; tries++) { + if (VirtualLock(ptr, len)) { + return true; + } + if (tries == 2) { + fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n", + len, size, llama_format_win_err(GetLastError()).c_str()); + return false; + } + + // It failed but this was only the first try; increase the working + // set size and try again. + SIZE_T min_ws_size, max_ws_size; + if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) { + fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + // Per MSDN: "The maximum number of pages that a process can lock + // is equal to the number of pages in its minimum working set minus + // a small overhead." + // Hopefully a megabyte is enough overhead: + size_t increment = len + 1048576; + // The minimum must be <= the maximum, so we need to increase both: + min_ws_size += increment; + max_ws_size += increment; + if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) { + fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + } + } + + static void raw_unlock(void * ptr, size_t len) { + if (!VirtualUnlock(ptr, len)) { + fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + static constexpr bool SUPPORTED = false; + + static size_t lock_granularity() { + return (size_t) 65536; + } + + bool raw_lock(const void * addr, size_t len) const { + fprintf(stderr, "warning: mlock not supported on this system\n"); + return false; + } + + static void raw_unlock(const void * addr, size_t len) {} +#endif +}; + +typedef void (*offload_func_t)(struct ggml_tensor * tensor); + +static void llama_nop(struct ggml_tensor * tensor) { // don't offload by default + (void) tensor; +} + +static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) { + std::vector result(8, 0); + const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size()); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_token_to_piece(ctx, token, result.data(), result.size()); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + + return std::string(result.data(), result.size()); +} + +// +// globals +// + +struct llama_state { + // We save the log callback globally + llama_log_callback log_callback = llama_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static llama_state g_state; + +// available llama models +enum e_model { + MODEL_UNKNOWN, + MODEL_3B, + MODEL_7B, + MODEL_13B, + MODEL_30B, + MODEL_34B, + MODEL_40B, + MODEL_65B, + MODEL_70B, +}; + +static const size_t kB = 1024; +static const size_t MB = kB*kB; + +// default hparams (LLaMA 7B) +struct llama_hparams { + uint32_t n_vocab = 32000; + uint32_t n_ctx_train = 2048; // the context size used during training + uint32_t n_ctx = 512; // the context size used during inference + uint32_t n_embd = 4096; + uint32_t n_head = 32; + uint32_t n_head_kv = 32; + uint32_t n_layer = 32; + uint32_t n_rot = 64; + uint32_t n_ff = 11008; + + float f_norm_eps = 1e-5; + float f_norm_rms_eps = 1e-5; + + float rope_freq_base = 10000.0f; + float rope_freq_scale = 1.0f; + + bool operator!=(const llama_hparams & other) const { + return static_cast(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT + } + + uint32_t n_gqa() const { + return n_head/n_head_kv; + } + + uint32_t n_embd_head() const { + return n_embd/n_head; + } + + uint32_t n_embd_gqa() const { + return n_embd/n_gqa(); + } + + size_t kv_size() const { + size_t result = 2ull; + result *= (size_t) n_embd_gqa(); + result *= (size_t) n_ctx; + result *= (size_t) n_layer; + result *= sizeof(ggml_fp16_t); + return result; + } +}; + +struct llama_layer { + // normalization + struct ggml_tensor * attn_norm; + struct ggml_tensor * attn_norm_b; + struct ggml_tensor * attn_norm_2; + struct ggml_tensor * attn_norm_2_b; + + // attention + struct ggml_tensor * wq; + struct ggml_tensor * wk; + struct ggml_tensor * wv; + struct ggml_tensor * wo; + struct ggml_tensor * wqkv; + + // normalization + struct ggml_tensor * ffn_norm; + + // ff + struct ggml_tensor * w1; // ffn_gate + struct ggml_tensor * w2; // ffn_down + struct ggml_tensor * w3; // ffn_up +}; + +struct llama_kv_cache { + struct ggml_tensor * k = NULL; + struct ggml_tensor * v = NULL; + + struct ggml_context * ctx = NULL; + + llama_buffer buf; + + int n; // number of tokens currently in the cache + + ~llama_kv_cache() { + if (ctx) { + ggml_free(ctx); + } + +#ifdef GGML_USE_CUBLAS + ggml_cuda_free_data(k); + ggml_cuda_free_data(v); +#endif // GGML_USE_CUBLAS + } +}; + +struct llama_vocab { + using id = int32_t; + using token = std::string; + using ttype = llama_token_type; + + struct token_data { + token text; + float score; + ttype type; + }; + + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + + std::unordered_map token_to_id; + std::vector id_to_token; + + std::map, int> bpe_ranks; + + // default LLaMA special tokens + id special_bos_id = 1; + id special_eos_id = 2; + id special_unk_id = 0; + id special_sep_id = -1; + id special_pad_id = -1; + + id linefeed_id = 13; + + int find_bpe_rank(std::string token_left, std::string token_right) const { + replace_all(token_left, " ", "\u0120"); + replace_all(token_left, "\n", "\u010A"); + replace_all(token_right, " ", "\u0120"); + replace_all(token_right, "\n", "\u010A"); + + auto it = bpe_ranks.find(std::make_pair(token_left, token_right)); + if (it == bpe_ranks.end()) { + return -1; + } + + return it->second; + } +}; + +struct llama_model { + e_model type = MODEL_UNKNOWN; + llm_arch arch = LLM_ARCH_UNKNOWN; + llama_ftype ftype = LLAMA_FTYPE_ALL_F32; + + std::string name = "n/a"; + + llama_hparams hparams; + llama_vocab vocab; + + struct ggml_tensor * tok_embeddings; + + struct ggml_tensor * output_norm; + struct ggml_tensor * output_norm_b; + struct ggml_tensor * output; + + std::vector layers; + + int n_gpu_layers; + + // context + struct ggml_context * ctx = NULL; + + // the model memory buffer + llama_buffer buf; + + // model memory mapped file + std::unique_ptr mapping; + + // objects representing data potentially being locked in memory + llama_mlock mlock_buf; + llama_mlock mlock_mmap; + + // for quantize-stats only + std::vector> tensors_by_name; + + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + ~llama_model() { + if (ctx) { + ggml_free(ctx); + } + +#ifdef GGML_USE_CUBLAS + for (size_t i = 0; i < tensors_by_name.size(); ++i) { + ggml_cuda_free_data(tensors_by_name[i].second); + } + ggml_cuda_free_scratch(); +#elif defined(GGML_USE_CLBLAST) + for (size_t i = 0; i < tensors_by_name.size(); ++i) { + ggml_cl_free_data(tensors_by_name[i].second); + } +#endif + } +}; + +struct llama_context { + llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {} + ~llama_context() { + if (model_owner) { + delete &model; + } +#ifdef GGML_USE_METAL + if (ctx_metal) { + ggml_metal_free(ctx_metal); + } +#endif + if (alloc) { + ggml_allocr_free(alloc); + } + } + + std::mt19937 rng; + + bool has_evaluated_once = false; + + int64_t t_sample_us = 0; + int64_t t_eval_us = 0; + int64_t t_p_eval_us = 0; + + int32_t n_sample = 0; // number of tokens sampled + int32_t n_eval = 0; // number of eval calls + int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) + + const llama_model & model; + + bool model_owner = false; + + int64_t t_load_us; + int64_t t_start_us; + + // key + value cache for the self attention + struct llama_kv_cache kv_self; + + // decode output (2-dimensional array: [n_tokens][n_vocab]) + std::vector logits; + bool logits_all = false; + + // input embedding (1-dimensional array: [n_embd]) + std::vector embedding; + + // reusable buffer for `struct ggml_graph_plan.work_data` + std::vector work_buffer; + + // memory buffers used to evaluate the model + llama_buffer buf_compute; + + llama_buffer buf_alloc; + ggml_allocr * alloc = NULL; + +#ifdef GGML_USE_METAL + ggml_metal_context * ctx_metal = NULL; +#endif + +#ifdef GGML_USE_MPI + ggml_mpi_context * ctx_mpi = NULL; +#endif +}; + +// +// kv cache helpers +// + +static bool llama_kv_cache_init( + const struct llama_hparams & hparams, + struct llama_kv_cache & cache, + ggml_type wtype, + int n_ctx, + int n_gpu_layers) { + const int n_embd = hparams.n_embd_gqa(); + const int n_layer = hparams.n_layer; + + const int64_t n_mem = n_layer*n_ctx; + const int64_t n_elements = n_embd*n_mem; + + cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); + cache.n = 0; + + struct ggml_init_params params; + params.mem_size = cache.buf.size; + params.mem_buffer = cache.buf.data; + params.no_alloc = false; + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + LLAMA_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + ggml_set_name(cache.k, "cache_k"); + ggml_set_name(cache.v, "cache_v"); + + (void) n_gpu_layers; +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer + 1) { + ggml_cuda_assign_buffers_no_scratch(cache.v); + } + if (n_gpu_layers > n_layer + 2) { + ggml_cuda_assign_buffers_no_scratch(cache.k); + } +#endif // GGML_USE_CUBLAS + + return true; +} + +// +// model loading and saving +// + +enum llama_fver { + GGUF_FILE_VERSION_V1 = 1, + GGUF_FILE_VERSION_V2 = 2, +}; + +static const char * llama_file_version_name(llama_fver version) { + switch (version) { + case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)"; + case GGUF_FILE_VERSION_V2: return "GGUF V2 (latest)"; + } + + return "unknown"; +} + +static std::string llama_format_tensor_shape(const std::vector & ne) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0)); + for (size_t i = 1; i < ne.size(); i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i)); + } + return buf; +} + +static std::string llama_format_tensor_shape(const struct ggml_tensor * t) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); + } + return buf; +} + +struct llama_model_loader { + int n_kv = 0; + int n_tensors = 0; + int n_created = 0; + + int64_t n_elements = 0; + + bool use_mmap = false; + + llama_file file; + llama_ftype ftype; + llama_fver fver; + + std::unique_ptr mapping; + + struct gguf_context * ctx_gguf = NULL; + struct ggml_context * ctx_meta = NULL; + + llama_model_loader(const std::string & fname, bool use_mmap) : file(fname.c_str(), "rb") { + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx_meta, + }; + + ctx_gguf = gguf_init_from_file(fname.c_str(), params); + if (!ctx_gguf) { + throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str())); + } + + n_kv = gguf_get_n_kv(ctx_gguf); + n_tensors = gguf_get_n_tensors(ctx_gguf); + + fver = (enum llama_fver ) gguf_get_version(ctx_gguf); + + for (int i = 0; i < n_tensors; i++) { + const char * name = gguf_get_tensor_name(ctx_gguf, i); + struct ggml_tensor * t = ggml_get_tensor(ctx_meta, name); + n_elements += ggml_nelements(t); + } + + LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", + __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); + + // determine file type based on the number of tensors for each quantization and print meta data + // TODO: make optional + { + std::map n_type; + + uint32_t n_type_max = 0; + enum ggml_type type_max = GGML_TYPE_F32; + + for (int i = 0; i < n_tensors; i++) { + const char * name = gguf_get_tensor_name(ctx_gguf, i); + struct ggml_tensor * meta = ggml_get_tensor(ctx_meta, name); + + n_type[meta->type]++; + + if (n_type_max < n_type[meta->type]) { + n_type_max = n_type[meta->type]; + type_max = meta->type; + } + + LLAMA_LOG_INFO("%s: - tensor %4d: %32s %-8s [ %s ]\n", __func__, i, name, ggml_type_name(meta->type), llama_format_tensor_shape(meta).c_str()); + } + + switch (type_max) { + case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; + case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; + case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; + case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; + case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; + case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break; + case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break; + case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break; + case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break; + case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break; + case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; + case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; + default: + { + LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); + ftype = LLAMA_FTYPE_ALL_F32; + } break; + } + + // this is a way to mark that we have "guessed" the file type + ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED); + + { + const int kid = gguf_find_key(ctx_gguf, "general.file_type"); + if (kid >= 0) { + ftype = (llama_ftype) gguf_get_val_u32(ctx_gguf, kid); + } + } + + for (int i = 0; i < n_kv; i++) { + const char * name = gguf_get_key(ctx_gguf, i); + const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); + + LLAMA_LOG_INFO("%s: - kv %3d: %42s %-8s\n", __func__, i, name, gguf_type_name(type)); + } + + // print type counts + for (auto & kv : n_type) { + if (kv.second == 0) { + continue; + } + + LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second); + } + } + + if (!llama_mmap::SUPPORTED) { + LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__); + use_mmap = false; + } + + this->use_mmap = use_mmap; + } + + ~llama_model_loader() { + if (ctx_gguf) { + gguf_free(ctx_gguf); + } + if (ctx_meta) { + ggml_free(ctx_meta); + } + } + + std::string get_arch_name() const { + const auto kv = LLM_KV(LLM_ARCH_UNKNOWN); + + std::string arch_name; + GGUF_GET_KEY(ctx_gguf, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE)); + + return arch_name; + } + + enum llm_arch get_arch() const { + const std::string arch_name = get_arch_name(); + + return llm_arch_from_string(arch_name); + } + + const char * get_tensor_name(int i) const { + return gguf_get_tensor_name(ctx_gguf, i); + } + + struct ggml_tensor * get_tensor_meta(int i) const { + return ggml_get_tensor(ctx_meta, get_tensor_name(i)); + } + + void calc_sizes(size_t & ctx_size_p, size_t & mmapped_size_p) const { + ctx_size_p = 0; + mmapped_size_p = 0; + + for (int i = 0; i < n_tensors; i++) { + struct ggml_tensor * meta = get_tensor_meta(i); + ctx_size_p += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE; + (use_mmap ? mmapped_size_p : ctx_size_p) += ggml_nbytes_pad(meta); + } + } + + struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, struct ggml_tensor * meta, ggml_backend backend) { + if (backend != GGML_BACKEND_CPU) { + ggml_set_no_alloc(ctx, true); + } + + struct ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); + tensor->backend = backend; // TODO: ggml_set_backend + ggml_set_name(tensor, ggml_get_name(meta)); + + if (backend != GGML_BACKEND_CPU) { + ggml_set_no_alloc(ctx, use_mmap); + } + + n_created++; + + return tensor; + } + + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, ggml_backend backend) { + struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str()); + + if (cur == NULL) { + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str())); + } + + { + bool is_ok = true; + for (size_t i = 0; i < ne.size(); ++i) { + if (ne[i] != cur->ne[i]) { + is_ok = false; + break; + } + } + if (!is_ok) { + throw std::runtime_error( + format("%s: tensor '%s' has wrong shape; expected %s, got %s", + __func__, name.c_str(), + llama_format_tensor_shape(ne).c_str(), + llama_format_tensor_shape(cur).c_str())); + } + } + + return create_tensor_for(ctx, cur, backend); + } + + void done_getting_tensors() const { + if (n_created != n_tensors) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } + } + + size_t file_offset(const char * name) const { + const int idx = gguf_find_tensor(ctx_gguf, name); + + if (idx < 0) { + throw std::runtime_error(format("%s: tensor '%s' not found in the file", __func__, name)); + } + + return gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, idx); + } + + void load_data_for(struct ggml_tensor * cur) const { + const size_t offs = file_offset(ggml_get_name(cur)); + + if (use_mmap) { + cur->data = (uint8_t *) mapping->addr + offs; + } else { + file.seek(offs, SEEK_SET); + file.read_raw(cur->data, ggml_nbytes(cur)); + } + } + + void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) { + size_t size_data = 0; + size_t size_lock = 0; + size_t size_pref = 0; // prefetch + + for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); i++) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, gguf_get_tensor_name(ctx_gguf, i)); + size_data += ggml_nbytes(cur); + if (cur->backend == GGML_BACKEND_CPU) { + size_pref += ggml_nbytes(cur); + } + } + + if (use_mmap) { + mapping.reset(new llama_mmap(&file, size_pref, ggml_is_numa())); + if (lmlock) { + lmlock->init(mapping->addr); + } + } + + size_t done_size = 0; + for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); i++) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, gguf_get_tensor_name(ctx_gguf, i)); + GGML_ASSERT(cur); // unused tensors should have been caught by load_data already + + if (progress_callback) { + progress_callback((float) done_size / size_data, progress_callback_user_data); + } + + // allocate temp buffer if not using mmap + if (!use_mmap && cur->data == NULL) { + GGML_ASSERT(cur->backend != GGML_BACKEND_CPU); + #ifdef GGML_USE_CPU_HBM + cur->data = (uint8_t*)hbw_malloc(ggml_nbytes(cur)); + #else + cur->data = (uint8_t*)malloc(ggml_nbytes(cur)); + #endif + } + + load_data_for(cur); + + switch (cur->backend) { + case GGML_BACKEND_CPU: + if (use_mmap && lmlock) { + size_lock += ggml_nbytes(cur); + lmlock->grow_to(size_lock); + } + break; +#if defined(GGML_USE_CUBLAS) + case GGML_BACKEND_GPU: + case GGML_BACKEND_GPU_SPLIT: + // old code: + //ggml_cuda_transform_tensor(lt.data, lt.ggml_tensor); + + // TODO: test if this works !! + ggml_cuda_transform_tensor(cur->data, cur); + if (!use_mmap) { + free(cur->data); + } + break; +#elif defined(GGML_USE_CLBLAST) + case GGML_BACKEND_GPU: + ggml_cl_transform_tensor(cur->data, cur); + if (!use_mmap) { + free(cur->data); + } + break; +#endif + default: + continue; + } + + done_size += ggml_nbytes(cur); + } + } +}; + +// +// load LLaMA models +// + +std::string llama_model_ftype_name(enum llama_ftype ftype) { + if (ftype & LLAMA_FTYPE_GUESSED) { + return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; + } + + switch (ftype) { + case LLAMA_FTYPE_ALL_F32: return "all F32"; + case LLAMA_FTYPE_MOSTLY_F16: return "mostly F16"; + case LLAMA_FTYPE_MOSTLY_Q4_0: return "mostly Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_1: return "mostly Q4_1"; + case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: + return "mostly Q4_1, some F16"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; + + // K-quants + case LLAMA_FTYPE_MOSTLY_Q2_K: return "mostly Q2_K"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "mostly Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "mostly Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "mostly Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "mostly Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "mostly Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "mostly Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "mostly Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "mostly Q6_K"; + + default: return "unknown, may not work"; + } +} + +static const char * llama_model_type_name(e_model type) { + switch (type) { + case MODEL_3B: return "3B"; + case MODEL_7B: return "7B"; + case MODEL_13B: return "13B"; + case MODEL_30B: return "30B"; + case MODEL_34B: return "34B"; + case MODEL_40B: return "40B"; + case MODEL_65B: return "65B"; + case MODEL_70B: return "70B"; + default: return "?B"; + } +} + +static void llm_load_arch(llama_model_loader & ml, llama_model & model) { + model.arch = ml.get_arch(); + if (model.arch == LLM_ARCH_UNKNOWN) { + throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); + } +} + +static void llm_load_hparams( + llama_model_loader & ml, + llama_model & model, + int n_ctx, + float rope_freq_base, + float rope_freq_scale) { + struct gguf_context * ctx = ml.ctx_gguf; + + const auto kv = LLM_KV(model.arch); + + auto & hparams = model.hparams; + + // get general kv + GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME)); + + // get hparams kv + GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST)); + GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT)); + GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT)); + + // n_head_kv is optional, default to n_head + hparams.n_head_kv = hparams.n_head; + GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV)); + + // TODO: manually setting rope freq base and scale should override this + // FIXME: partial fix when the param specified is not the default value, but + // will not work for overriding the model value to the params default + + llama_context_params defaults = llama_context_default_params(); + + // rope_freq_base + { + float ropebase = 10000.0f; + GGUF_GET_KEY(ctx, ropebase, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); + if (ropebase != 10000.0f && rope_freq_base == defaults.rope_freq_base) { + rope_freq_base = ropebase; + } + } + + // rope_freq_scale (inverse of the kv) is optional + { + float ropescale = 1.0f; + GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); + if (ropescale != 1.0f && rope_freq_scale == defaults.rope_freq_scale) { + rope_freq_scale = 1.0f/ropescale; + } + } + + // sanity check for n_rot (optional) + { + hparams.n_rot = hparams.n_embd / hparams.n_head; + + GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT)); + + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { + if (hparams.n_rot != hparams.n_embd / hparams.n_head) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head)); + } + } + // gpt-neox n_rot = rotary_pct * (n_embd / n_head) + // gpt-j n_rot = rotary_dim + } + + // arch-specific KVs + switch (model.arch) { + case LLM_ARCH_LLAMA: + { + GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); + + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_3B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + case 48: model.type = e_model::MODEL_34B; break; + case 60: model.type = e_model::MODEL_30B; break; + case 80: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_65B : e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_FALCON: + { + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 60: model.type = e_model::MODEL_40B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + default: (void)0; + }; + + model.ftype = ml.ftype; + + hparams.n_ctx = n_ctx; + hparams.rope_freq_base = rope_freq_base; + hparams.rope_freq_scale = rope_freq_scale; +} + +// TODO: This should probably be in llama.h +static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos); +static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch); + +static void llm_load_vocab( + llama_model_loader & ml, + llama_model & model) { + auto & vocab = model.vocab; + + struct gguf_context * ctx = ml.ctx_gguf; + + const auto kv = LLM_KV(model.arch); + + const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); + if (token_idx == -1) { + throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + } + + const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); + if (score_idx == -1) { + throw std::runtime_error("cannot find tokenizer scores in model file\n"); + } + + const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + + const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); + if (toktype_idx == -1) { + throw std::runtime_error("cannot find token type list in GGUF file\n"); + } + + const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + + // determine vocab type + { + std::string tokenizer_name; + + GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL)); + + if (tokenizer_name == "llama") { + vocab.type = LLAMA_VOCAB_TYPE_SPM; + + // default special tokens + vocab.special_bos_id = 1; + vocab.special_eos_id = 2; + vocab.special_unk_id = 0; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + } else if (tokenizer_name == "gpt2") { + vocab.type = LLAMA_VOCAB_TYPE_BPE; + + // read bpe merges and populate bpe ranks + const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + vocab.bpe_ranks.emplace(std::make_pair(first, second), i); + } + + // default special tokens + vocab.special_bos_id = 11; + vocab.special_eos_id = 11; + vocab.special_unk_id = -1; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + } else { + LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); + LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); + + vocab.type = LLAMA_VOCAB_TYPE_SPM; + } + } + + const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); + + vocab.id_to_token.resize(n_vocab); + + for (uint32_t i = 0; i < n_vocab; i++) { + std::string word = gguf_get_arr_str(ctx, token_idx, i); + + vocab.token_to_id[word] = i; + + auto & token_data = vocab.id_to_token[i]; + token_data.text = std::move(word); + token_data.score = scores[i]; + token_data.type = (llama_token_type) toktypes[i]; + } + + // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' + if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { + vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); + } else { + vocab.linefeed_id = llama_tokenize_internal(vocab, "\n", false)[0]; + } + + // special tokens + GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID)); + GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID)); + GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID)); + GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID)); + GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID)); +} + +static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { + const auto & hparams = model.hparams; + const auto & vocab = model.vocab; + + // hparams + LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix + LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); + LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size()); + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); + LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim + LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); + LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); + LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); + LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); + LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml.n_elements*1e-9); + + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); + + // special tokens + if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } + if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } + if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } + if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } + if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } + if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } +} + +static void llm_load_tensors( + llama_model_loader & ml, + llama_model & model, + int n_batch, + int n_gpu_layers, + int main_gpu, + const float * tensor_split, + const bool mul_mat_q, + bool low_vram, + ggml_type memory_type, + bool use_mlock, + llama_progress_callback progress_callback, + void * progress_callback_user_data) { + model.t_start_us = ggml_time_us(); + + auto & ctx = model.ctx; + auto & hparams = model.hparams; + + model.n_gpu_layers = n_gpu_layers; + + size_t ctx_size; + size_t mmapped_size; + + ml.calc_sizes(ctx_size, mmapped_size); + + LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0); + + // create the ggml context + { + model.buf.resize(ctx_size); + if (use_mlock) { + model.mlock_buf.init (model.buf.data); + model.mlock_buf.grow_to(model.buf.size); + } + + struct ggml_init_params params = { + /*.mem_size =*/ model.buf.size, + /*.mem_buffer =*/ model.buf.data, + /*.no_alloc =*/ ml.use_mmap, + }; + + model.ctx = ggml_init(params); + if (!model.ctx) { + throw std::runtime_error(format("ggml_init() failed")); + } + } + + (void) main_gpu; + (void) mul_mat_q; +#if defined(GGML_USE_CUBLAS) + LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__); + ggml_cuda_set_main_device(main_gpu); + ggml_cuda_set_mul_mat_q(mul_mat_q); +#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU +#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT +#elif defined(GGML_USE_CLBLAST) + LLAMA_LOG_INFO("%s: using OpenCL for GPU acceleration\n", __func__); +#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU +#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU +#else +#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CPU +#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU +#endif + + // prepare memory for the weights + size_t vram_weights = 0; + { + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + const int64_t n_layer = hparams.n_layer; + const int64_t n_vocab = hparams.n_vocab; + + const auto tn = LLM_TN(model.arch); + + switch (model.arch) { + case LLM_ARCH_LLAMA: + { + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + ggml_backend backend_norm; + ggml_backend backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 + + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + + layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + + layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + } + } + } break; + case LLM_ARCH_FALCON: + { + // TODO: CPU-only for now + + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + ggml_backend backend_norm; + ggml_backend backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 + + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + + if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) { + layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend); + layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, backend); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(layer.attn_norm_2); + vram_weights += ggml_nbytes(layer.attn_norm_2_b); + } + } + + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.wo) + + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + } + } + } break; + default: + throw std::runtime_error("unknown architecture"); + }; + } + + ml.done_getting_tensors(); + + // print memory requirements + { + const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1; + + // this is the total memory required to run the inference + size_t mem_required = + ctx_size + + mmapped_size - vram_weights; // weights in VRAM not in memory + + // this is the memory required by one llama_state + const size_t mem_required_state = scale*hparams.kv_size(); + + LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, + mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); + + (void) n_batch; + +#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) + const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + + LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); + if (n_gpu_layers > (int) hparams.n_layer) { + LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__); + } + size_t vram_kv_cache = 0; + +#ifdef GGML_USE_CUBLAS + const int max_backend_supported_layers = hparams.n_layer + 3; + const int max_offloadable_layers = low_vram ? hparams.n_layer + 1 : hparams.n_layer + 3; + if (n_gpu_layers > (int) hparams.n_layer + 1) { + if (low_vram) { + LLAMA_LOG_INFO("%s: cannot offload v cache to GPU due to low VRAM option\n", __func__); + } else { + LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__); + vram_kv_cache += hparams.kv_size() / 2; + } + } + if (n_gpu_layers > (int) hparams.n_layer + 2) { + if (low_vram) { + LLAMA_LOG_WARN("%s: cannot offload k cache to GPU due to low VRAM option\n", __func__); + } else { + LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__); + vram_kv_cache += hparams.kv_size() / 2; + } + } +#elif defined(GGML_USE_CLBLAST) + const int max_backend_supported_layers = hparams.n_layer + 1; + const int max_offloadable_layers = hparams.n_layer + 1; +#endif // GGML_USE_CUBLAS + + LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", + __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n", + __func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up +#else + (void) n_gpu_layers; +#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) + } + + // populate `tensors_by_name` + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, ml.get_tensor_name(i)); + model.tensors_by_name.emplace_back(ggml_get_name(cur), cur); + } + + (void) tensor_split; +#if defined(GGML_USE_CUBLAS) + { + ggml_cuda_set_tensor_split(tensor_split); + } +#endif + + ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL); + + if (progress_callback) { + progress_callback(1.0f, progress_callback_user_data); + } + + model.mapping = std::move(ml.mapping); + + // loading time will be recalculate after the first eval, so + // we take page faults deferred by mmap() into consideration + model.t_load_us = ggml_time_us() - model.t_start_us; +} + +static bool llama_model_load( + const std::string & fname, + llama_model & model, + int n_ctx, + int n_batch, + int n_gpu_layers, + int main_gpu, + const float * tensor_split, + const bool mul_mat_q, + float rope_freq_base, + float rope_freq_scale, + bool low_vram, + ggml_type memory_type, + bool use_mmap, + bool use_mlock, + bool vocab_only, + llama_progress_callback progress_callback, + void *progress_callback_user_data) { + try { + std::unique_ptr ml(new llama_model_loader(fname, use_mmap)); + + llm_load_arch (*ml, model); + llm_load_hparams(*ml, model, n_ctx, rope_freq_base, rope_freq_scale); + llm_load_vocab (*ml, model); + + llm_load_print_meta(*ml, model); + + if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { + throw std::runtime_error("vocab size mismatch"); + } + + if (vocab_only) { + LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); + return true; + } + + llm_load_tensors( + *ml, model, n_batch, n_gpu_layers, + main_gpu, tensor_split, mul_mat_q, low_vram, memory_type, + use_mlock, progress_callback, progress_callback_user_data); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("error loading model: %s\n", err.what()); + return false; + } + + return true; +} + +static struct ggml_cgraph * llm_build_llama( + llama_context & lctx, + const llama_token * tokens, + const float * embd, + int n_tokens, + int n_past) { + + GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT + + const int N = n_tokens; + + const auto & model = lctx.model; + const auto & hparams = model.hparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = hparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_rot); + + const float freq_base = hparams.rope_freq_base; + const float freq_scale = hparams.rope_freq_scale; + const float norm_rms_eps = hparams.f_norm_rms_eps; + + const int n_gpu_layers = model.n_gpu_layers; + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ false, + }; + + params.no_alloc = true; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + if (tokens) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + + ggml_allocr_alloc(lctx.alloc, inpL); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + } + } + + const int i_gpu_start = n_layer - n_gpu_layers; + (void) i_gpu_start; + + // offload functions set the tensor output backend to GPU + // tensors are GPU-accelerated if any input or the output has been offloaded + // + // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal + // in that case ggml_cuda_assign_buffers has no effect + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + } + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + + for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); + + offload_func_t offload_func = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (il >= i_gpu_start) { + offload_func = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + struct ggml_tensor * inpSA = inpL; + + // norm + { + cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps); + offload_func(cur); + ggml_set_name(cur, "rms_norm_0"); + + // cur = cur*attn_norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); + offload_func(cur); + ggml_set_name(cur, "attention_norm_0"); + } + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + offload_func_kq(tmpk); + ggml_set_name(tmpk, "tmpk"); + + struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + offload_func_kq(tmpq); + ggml_set_name(tmpq, "tmpq"); + + struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + offload_func_kq(Kcur); + ggml_set_name(Kcur, "Kcur"); + + struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + offload_func_kq(Qcur); + ggml_set_name(Qcur, "Qcur"); + + // store key and value to memory + { + // compute the transposed [N, n_embd] V matrix + + struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + offload_func_v(tmpv); + ggml_set_name(tmpv, "tmpv"); + + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N)); + offload_func_v(Vcur); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + offload_func_kq(k); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + offload_func_v(v); + ggml_set_name(v, "v"); + + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + offload_func_kq(Q); + ggml_set_name(Q, "Q"); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_past + N, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + offload_func_kq(K); + ggml_set_name(K, "K"); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + offload_func_kq(KQ); + ggml_set_name(KQ, "KQ"); + + // KQ_scaled = KQ / sqrt(n_embd_head) + // KQ_scaled shape [n_past + N, N, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + offload_func_kq(KQ_scaled); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + offload_func_kq(KQ_masked); + ggml_set_name(KQ_masked, "KQ_masked"); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + offload_func_v(KQ_soft_max); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + // split cached V into n_head heads + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_past + N, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + offload_func_v(V); + ggml_set_name(V, "V"); + +#if 1 + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + offload_func_v(KQV); + ggml_set_name(KQV, "KQV"); +#else + // make V contiguous in memory to speed up the matmul, however we waste time on the copy + // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation + // is there a better way? + struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head)); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); +#endif + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + offload_func_v(KQV_merged); + ggml_set_name(KQV_merged, "KQV_merged"); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + offload_func_v(cur); + ggml_set_name(cur, "KQV_merged_contiguous"); + + // projection (no bias) + cur = ggml_mul_mat(ctx0, + model.layers[il].wo, + cur); + offload_func(cur); + ggml_set_name(cur, "result_wo"); + } + + struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); + offload_func(inpFF); + ggml_set_name(inpFF, "inpFF"); + + // feed-forward network + { + // norm + { + cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps); + offload_func(cur); + ggml_set_name(cur, "rms_norm_1"); + + // cur = cur*ffn_norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); + offload_func(cur); + ggml_set_name(cur, "ffn_norm"); + } + + struct ggml_tensor * tmp = ggml_mul_mat(ctx0, + model.layers[il].w3, + cur); + offload_func(tmp); + ggml_set_name(tmp, "result_w3"); + + cur = ggml_mul_mat(ctx0, + model.layers[il].w1, + cur); + offload_func(cur); + ggml_set_name(cur, "result_w1"); + + // SILU activation + cur = ggml_silu(ctx0, cur); + offload_func(cur); + ggml_set_name(cur, "silu"); + + cur = ggml_mul(ctx0, cur, tmp); + offload_func(cur); + ggml_set_name(cur, "silu_x_result_w3"); + + cur = ggml_mul_mat(ctx0, + model.layers[il].w2, + cur); + offload_func(cur); + ggml_set_name(cur, "result_w2"); + } + + cur = ggml_add(ctx0, cur, inpFF); + offload_func(cur); + ggml_set_name(cur, "inpFF_+_result_w2"); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + // norm + { + cur = ggml_rms_norm(ctx0, cur, norm_rms_eps); + offload_func_nr(cur); + ggml_set_name(cur, "rms_norm_2"); + + // cur = cur*norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.output_norm); + // offload_func_nr(cur); // TODO CPU + GPU mirrored backend + ggml_set_name(cur, "result_norm"); + } + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * llm_build_falcon( + llama_context & lctx, + const llama_token * tokens, + const float * embd, + int n_tokens, + int n_past) { + + GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT + + const int N = n_tokens; + + const auto & model = lctx.model; + const auto & hparams = model.hparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = hparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_rot); + + const float freq_base = hparams.rope_freq_base; + const float freq_scale = hparams.rope_freq_scale; + const float norm_eps = hparams.f_norm_eps; + + const int n_gpu_layers = model.n_gpu_layers; + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ false, + }; + + params.no_alloc = true; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + if (tokens) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + + ggml_allocr_alloc(lctx.alloc, inpL); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + } + } + + const int i_gpu_start = n_layer - n_gpu_layers; + (void) i_gpu_start; + + // offload functions set the tensor output backend to GPU + // tensors are GPU-accelerated if any input or the output has been offloaded + // + // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal + // in that case ggml_cuda_assign_buffers has no effect + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + } + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * attn_norm; + + offload_func_t offload_func = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (il >= i_gpu_start) { + offload_func = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + // self-attention + // TODO: refactor into common function (shared with LLaMA) + { + attn_norm = ggml_norm(ctx0, inpL, norm_eps); + offload_func(attn_norm); + + attn_norm = ggml_add(ctx0, + ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm), + model.layers[il].attn_norm_b); + offload_func(attn_norm->src[0]); + offload_func(attn_norm); + + if (model.layers[il].attn_norm_2) { // Falcon-40B + cur = ggml_norm(ctx0, inpL, norm_eps); + offload_func(cur); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.layers[il].attn_norm_2), + model.layers[il].attn_norm_2_b); + offload_func(cur->src[0]); + offload_func(cur); + } else { // Falcon 7B + cur = attn_norm; + } + + // compute QKV + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + offload_func_kq(cur); + + // Note that the strides for Kcur, Vcur are set up so that the + // resulting views are misaligned with the tensor's storage + // (by applying the K/V offset we shift the tensor's original + // view to stick out behind the viewed QKV tensor's allocated + // memory, so to say). This is ok because no actual accesses + // happen to that out-of-range memory, but it can require some + // trickery when trying to accurately dump these views for + // debugging. + + const size_t wsize = ggml_type_size(cur->type); + + // TODO: these 2 ggml_conts are technically not needed, but we add them until CUDA support for + // non-contiguous views is added for the rope operator + struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d( + ctx0, cur, n_embd_head, n_head, N, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + 0)); + offload_func_kq(tmpq); + + struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, N, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + wsize * n_embd_head * n_head)); + offload_func_kq(tmpk); + + struct ggml_tensor * tmpv = ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, N, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + wsize * n_embd_head * (n_head + n_head_kv)); + offload_func_v(tmpv); + + // using mode = 2 for neox mode + struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + offload_func_kq(Qcur); + struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + offload_func_kq(Kcur); + + { + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N)); + offload_func_v(Vcur); + offload_func_v(Vcur->src[0]->src[0]); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + offload_func_kq(k); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + offload_func_v(v); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + offload_func_kq(Q); + ggml_set_name(Q, "Q"); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_past + N, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + offload_func_kq(K); + ggml_set_name(K, "K"); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + offload_func_kq(KQ); + ggml_set_name(KQ, "KQ"); + + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + offload_func_kq(KQ_scaled); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + offload_func_kq(KQ_masked); + ggml_set_name(KQ_masked, "KQ_masked"); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + offload_func_v(KQ_soft_max); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_past + N, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + offload_func_v(V); + ggml_set_name(V, "V"); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + offload_func_v(KQV); + ggml_set_name(KQV, "KQV"); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + offload_func_v(KQV_merged); + ggml_set_name(KQV_merged, "KQV_merged"); + + cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + offload_func_v(cur); + ggml_set_name(cur, "KQV_merged_contiguous"); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + offload_func(cur); + ggml_set_name(cur, "result_wo"); + } + + struct ggml_tensor * attn_out = cur; + + // feed forward + { + struct ggml_tensor * inpFF = attn_norm; + + cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); + offload_func(cur); + + cur = ggml_gelu(ctx0, cur); + offload_func(cur); + cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + offload_func(cur); + } + + cur = ggml_add(ctx0, cur, attn_out); + offload_func(cur); + cur = ggml_add(ctx0, cur, inpL); + offload_func(cur); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + // norm + { + cur = ggml_norm(ctx0, cur, norm_eps); + offload_func_nr(cur); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.output_norm), + model.output_norm_b); + ggml_set_name(cur, "result_norm"); + } + + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * llama_build_graph( + llama_context & lctx, + const llama_token * tokens, + const float * embd, + int n_tokens, + int n_past) { + const auto & model = lctx.model; + + struct ggml_cgraph * result = NULL; + + switch (model.arch) { + case LLM_ARCH_LLAMA: + { + result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past); + } break; + case LLM_ARCH_FALCON: + { + result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past); + } break; + default: + GGML_ASSERT(false); + }; + + return result; +} + +// evaluate the transformer +// +// - lctx: llama context +// - tokens: new batch of tokens to process +// - embd embeddings input +// - n_tokens number of tokens +// - n_past: the context size so far +// - n_threads: number of threads to use +// +static bool llama_eval_internal( + llama_context & lctx, + const llama_token * tokens, + const float * embd, + int n_tokens, + int n_past, + int n_threads, + const char * cgraph_fname) { + + GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT + + GGML_ASSERT(n_tokens > 0); + GGML_ASSERT(n_past >= 0); + // TODO: keep the values of n_batch and n_ctx + // GGML_ASSERT(n_tokens <= n_batch); + // GGML_ASSERT(n_past + n_tokens <= n_ctx); + + const int64_t t_start_us = ggml_time_us(); + +#ifdef GGML_USE_MPI + ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); +#endif + + GGML_ASSERT(n_threads > 0); + + const int N = n_tokens; + + const auto & model = lctx.model; + const auto & hparams = model.hparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_vocab = hparams.n_vocab; + + ggml_allocr_reset(lctx.alloc); + + ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past); + + ggml_allocr_alloc_graph(lctx.alloc, gf); + +#ifdef GGML_USE_CUBLAS + for (int i = 0; i < gf->n_leafs; i++) { + ggml_tensor * node = gf->leafs[i]; + if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { + ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); + } + } + + for (int i = 0; i < gf->n_nodes; i++) { + ggml_tensor * node = gf->nodes[i]; + if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { + ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); + } + } +#endif + + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + + // for big prompts, if BLAS is enabled, it is better to use only one thread + // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance + // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well + // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering + // with the BLAS calls. need a better solution + if (N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { + n_threads = std::min(4, n_threads); + } + + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; + + GGML_ASSERT(strcmp(res->name, "result_output") == 0); + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + +#if GGML_USE_MPI + const int64_t n_layer = hparams.n_layer; + ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); +#endif + +#ifdef GGML_USE_METAL + if (lctx.ctx_metal) { + ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); + ggml_metal_graph_compute(lctx.ctx_metal, gf); + ggml_metal_get_tensor (lctx.ctx_metal, res); + if (!lctx.embedding.empty()) { + ggml_metal_get_tensor(lctx.ctx_metal, embeddings); + } + } else { + ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); + } +#else + ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); +#endif + +#if GGML_USE_MPI + ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); +#endif + + // update kv token count + lctx.kv_self.n = n_past + N; + + if (cgraph_fname) { + ggml_graph_export(gf, cgraph_fname); + } + +#ifdef GGML_PERF + // print timing information per ggml operation (for debugging purposes) + // requires GGML_PERF to be defined + ggml_graph_print(gf); +#endif + + // plot the computation graph in dot format (for debugging purposes) + //if (n_past%100 == 0) { + // ggml_graph_dump_dot(gf, NULL, "llama.dot"); + //} + + // extract logits + { + auto & logits_out = lctx.logits; + + if (lctx.logits_all) { + logits_out.resize(n_vocab * N); + memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N); + } else { + // return result for just the last token + logits_out.resize(n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + } + } + + // extract embeddings + if (!lctx.embedding.empty()) { + auto & embedding_out = lctx.embedding; + + embedding_out.resize(n_embd); + memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); + } + + // measure the performance only for the single-token evals + if (N == 1) { + lctx.t_eval_us += ggml_time_us() - t_start_us; + lctx.n_eval++; + } + else if (N > 1) { + lctx.t_p_eval_us += ggml_time_us() - t_start_us; + lctx.n_p_eval += N; + } + + return true; +} + +// +// tokenizer +// + +static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { + return vocab.type; +} + +static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL; +} + +static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN; +} + +static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL; +} + +static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE; +} + +static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(llama_is_byte_token(vocab, id)); + const auto& token_data = vocab.id_to_token.at(id); + auto buf = token_data.text.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); +} + +static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { + char buf[7]; + int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); + GGML_ASSERT(0 <= result && result < 7); + return vocab.token_to_id.at(buf); +} + +static void llama_escape_whitespace(std::string & text) { + replace_all(text, " ", "\xe2\x96\x81"); +} + +static void llama_unescape_whitespace(std::string & word) { + replace_all(word, "\xe2\x96\x81", " "); +} + +struct llm_symbol { + using index = int; + index prev; + index next; + const char * text; + size_t n; +}; + +static_assert(std::is_trivially_copyable::value, "llm_symbol is not trivially copyable"); + +// SPM tokenizer +// original implementation: +// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 + +struct llm_bigram_spm { + struct comparator { + bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) { + return (l.score < r.score) || (l.score == r.score && l.left > r.left); + } + }; + using queue_storage = std::vector; + using queue = std::priority_queue; + llm_symbol::index left; + llm_symbol::index right; + float score; + size_t size; +}; + +struct llm_tokenizer_spm { + llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) { + // split string into utf8 chars + int index = 0; + size_t offs = 0; + while (offs < text.size()) { + llm_symbol sym; + size_t len = utf8_len(text[offs]); + GGML_ASSERT(offs + len <= text.size()); + sym.text = text.c_str() + offs; + sym.n = len; + offs += len; + sym.prev = index - 1; + sym.next = offs == text.size() ? -1 : index + 1; + index++; + symbols.emplace_back(sym); + } + + // seed the work queue with all possible 2-character tokens. + for (size_t i = 1; i < symbols.size(); ++i) { + try_add_bigram(i - 1, i); + } + + // keep substituting the highest frequency pairs for as long as we can. + while (!work_queue.empty()) { + auto bigram = work_queue.top(); + work_queue.pop(); + + auto & left_sym = symbols[bigram.left]; + auto & right_sym = symbols[bigram.right]; + + // if one of the symbols already got merged, skip it. + if (left_sym.n == 0 || right_sym.n == 0 || + left_sym.n + right_sym.n != bigram.size) { + continue; + } + + // merge the right sym into the left one + left_sym.n += right_sym.n; + right_sym.n = 0; + + //LLAMA_LOG_INFO("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size); + + // remove the right sym from the chain + left_sym.next = right_sym.next; + if (right_sym.next >= 0) { + symbols[right_sym.next].prev = bigram.left; + } + + // find more substitutions + try_add_bigram(left_sym.prev, bigram.left); + try_add_bigram(bigram.left, left_sym.next); + } + + for (int i = 0; i != -1; i = symbols[i].next) { + auto & symbol = symbols[i]; + resegment(symbol, output); + } + } + +private: + void resegment(llm_symbol & symbol, std::vector & output) { + auto text = std::string(symbol.text, symbol.n); + auto token = vocab.token_to_id.find(text); + + // Do we need to support is_unused? + if (token != vocab.token_to_id.end()) { + output.push_back((*token).second); + return; + } + + const auto p = rev_merge.find(text); + + if (p == rev_merge.end()) { + // output any symbols that did not form tokens as bytes. + for (int j = 0; j < (int)symbol.n; ++j) { + llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]); + output.push_back(token_id); + } + return; + } + + resegment(symbols[p->second.first], output); + resegment(symbols[p->second.second], output); + } + + void try_add_bigram(int left, int right) { + if (left == -1 || right == -1) { + return; + } + + const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); + auto token = vocab.token_to_id.find(text); + + if (token == vocab.token_to_id.end()) { + return; + } + + if (static_cast((*token).second) >= vocab.id_to_token.size()) { + return; + } + + const auto & tok_data = vocab.id_to_token[(*token).second]; + + llm_bigram_spm bigram; + bigram.left = left; + bigram.right = right; + bigram.score = tok_data.score; + bigram.size = text.size(); + + work_queue.push(bigram); + + // Do we need to support is_unused? + rev_merge[text] = std::make_pair(left, right); + } + + const llama_vocab & vocab; + + std::vector symbols; + llm_bigram_spm::queue work_queue; + + std::map> rev_merge; +}; + +// BPE tokenizer +// adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License] +// tried to simplify unicode stuff, so most likely does not work 100% correctly! + +// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused + +struct llm_bigram_bpe { + struct comparator { + bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const { + return l.rank > r.rank || (l.rank == r.rank && l.left > r.left); + } + }; + + using queue_storage = std::vector; + using queue = std::priority_queue; + llm_symbol::index left; + llm_symbol::index right; + std::string text; + int rank; + size_t size; +}; + +struct llm_tokenizer_bpe { + llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) { + int final_prev_index = -1; + auto word_collection = bpe_gpt2_preprocess(text); + + symbols_final.clear(); + + for (auto & word : word_collection) { + work_queue = llm_bigram_bpe::queue(); + symbols.clear(); + + int index = 0; + size_t offset = 0; + + while (offset < word.size()) { + llm_symbol sym; + size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset])); + sym.text = word.c_str() + offset; + sym.n = 1; + sym.n = char_len; + offset += sym.n; + sym.prev = index - 1; + sym.next = offset == word.size() ? -1 : index + 1; + index++; + symbols.emplace_back(sym); + } + for (size_t i = 1; i < symbols.size(); ++i) { + add_new_bigram(i - 1, i); + } + + // build token(s) + while (!work_queue.empty()) { + auto bigram = work_queue.top(); + work_queue.pop(); + + auto & left_symbol = symbols[bigram.left]; + auto & right_symbol = symbols[bigram.right]; + + if (left_symbol.n == 0 || right_symbol.n == 0) { + continue; + } + std::string left_token = std::string(left_symbol.text, left_symbol.n); + std::string right_token = std::string(right_symbol.text, right_symbol.n); + if (left_token + right_token != bigram.text) { + continue; // Skip this bigram if it's outdated + } + + // merge the right sym into the left one + left_symbol.n += right_symbol.n; + right_symbol.n = 0; + + // remove the right sym from the chain + left_symbol.next = right_symbol.next; + if (right_symbol.next >= 0) { + symbols[right_symbol.next].prev = bigram.left; + } + + add_new_bigram(left_symbol.prev, bigram.left); // left side of current symbol + add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol + } + + // add the fnished tokens to the final list keeping correct order for next and prev + for (auto & sym : symbols) { + if (sym.n > 0) { + sym.prev = final_prev_index; + sym.next = -1; + if (final_prev_index != -1) { + symbols_final[final_prev_index].next = symbols_final.size(); + } + symbols_final.emplace_back(sym); + final_prev_index = symbols_final.size() - 1; + } + } + } + + symbols = symbols_final; + + if (!symbols.empty()) { + for (int i = 0; i != -1; i = symbols[i].next) { + auto & symbol = symbols[i]; + if (symbol.n == 0) { + continue; + } + + const std::string str = std::string(symbol.text, symbol.n); + const auto token = vocab.token_to_id.find(str); + + if (token == vocab.token_to_id.end()) { + for (auto j = str.begin(); j != str.end(); ++j) { + std::string byte_str(1, *j); + auto token_multibyte = vocab.token_to_id.find(byte_str); + if (token_multibyte == vocab.token_to_id.end()) { + try { + llama_token token_byte = llama_byte_to_token(vocab, *j); + output.push_back(token_byte); + } catch (const std::out_of_range & err) { + fprintf(stderr,"ERROR: byte not found in vocab: '%s'\n", byte_str.c_str()); + } + } else { + output.push_back((*token_multibyte).second); + } + } + } else { + output.push_back((*token).second); + } + } + } + } + +private: + void add_new_bigram(int left, int right) { + if (left == -1 || right == -1) { + return; + } + + std::string left_token = std::string(symbols[left].text, symbols[left].n); + std::string right_token = std::string(symbols[right].text, symbols[right].n); + + int rank_found = -1; + + rank_found = vocab.find_bpe_rank(left_token, right_token); + + if (rank_found < 0) { + return; + } + + llm_bigram_bpe bigram; + + bigram.left = left; + bigram.right = right; + bigram.text = left_token + right_token; + bigram.size = left_token.size() + right_token.size(); + bigram.rank = rank_found; + + work_queue.push(bigram); + } + + // probably not 100% correct + static std::vector bpe_gpt2_preprocess(const std::string & text) { + std::vector words; + + // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53 + const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; + const std::regex re(pattern); + + auto words_begin = std::sregex_iterator(text.begin(), text.end(), re); + auto words_end = std::sregex_iterator(); + auto n_words = std::distance(words_begin, words_end); + words.reserve(n_words); + for (auto it = words_begin; it != words_end; ++it) { + words.push_back(it->str()); + } + return words; + + } + + const llama_vocab & vocab; + + std::vector symbols; + std::vector symbols_final; + + llm_bigram_bpe::queue work_queue; +}; + +static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) { + std::vector output; + + // OG tokenizer behavior: + // + // tokenizer.encode('', add_bos=True) returns [1] + // tokenizer.encode('', add_bos=False) returns [] + + if (bos && vocab.special_bos_id != -1) { + output.push_back(vocab.special_bos_id); + } + + if (raw_text.empty()) { + return output; + } + + switch (vocab.type) { + case LLAMA_VOCAB_TYPE_SPM: + { + // without adding this leading whitespace, we do not get the same results as the original tokenizer + raw_text = " " + raw_text; + + llm_tokenizer_spm tokenizer(vocab); + llama_escape_whitespace(raw_text); + tokenizer.tokenize(raw_text, output); + } break; + case LLAMA_VOCAB_TYPE_BPE: + { + llm_tokenizer_bpe tokenizer(vocab); + tokenizer.tokenize(raw_text, output); + } break; + }; + + return output; +} + +// +// grammar - internal +// + +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct llama_grammar { + const std::vector> rules; + std::vector> stacks; + + // buffer for partially generated UTF-8 sequence from accepted tokens + llama_partial_utf8 partial_utf8; +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; + llama_partial_utf8 partial_utf8; +}; + +// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as +// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. +std::pair, llama_partial_utf8> decode_utf8( + const char * src, + llama_partial_utf8 partial_start) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; + const char * pos = src; + std::vector code_points; + uint32_t value = partial_start.value; + int n_remain = partial_start.n_remain; + + // continue previous decode, if applicable + while (*pos != 0 && n_remain > 0) { + uint8_t next_byte = static_cast(*pos); + if ((next_byte >> 6) != 2) { + // invalid sequence, abort + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 }); + } + value = (value << 6) + (next_byte & 0x3F); + ++pos; + --n_remain; + } + + if (partial_start.n_remain > 0 && n_remain == 0) { + code_points.push_back(value); + } + + // decode any subsequent utf-8 sequences, which may end in an incomplete one + while (*pos != 0) { + uint8_t first_byte = static_cast(*pos); + uint8_t highbits = first_byte >> 4; + n_remain = lookup[highbits] - 1; + + if (n_remain < 0) { + // invalid sequence, abort + code_points.clear(); + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain }); + } + + uint8_t mask = (1 << (7 - n_remain)) - 1; + value = first_byte & mask; + ++pos; + while (*pos != 0 && n_remain > 0) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + ++pos; + --n_remain; + } + if (n_remain == 0) { + code_points.push_back(value); + } + } + code_points.push_back(0); + + return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); +} + +// returns true iff pos points to the end of one of the definitions of a rule +static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { + switch (pos->type) { + case LLAMA_GRETYPE_END: return true; // NOLINT + case LLAMA_GRETYPE_ALT: return true; // NOLINT + default: return false; + } +} + +// returns true iff chr satisfies the char range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static std::pair llama_grammar_match_char( + const llama_grammar_element * pos, + const uint32_t chr) { + + bool found = false; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + + GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + found = found || (pos->value <= chr && chr <= pos[1].value); + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + +// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char +// range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static bool llama_grammar_match_partial_char( + const llama_grammar_element * pos, + const llama_partial_utf8 partial_utf8) { + + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); + + uint32_t partial_value = partial_utf8.value; + int n_remain = partial_utf8.n_remain; + + // invalid sequence or 7-bit char split across 2 bytes (overlong) + if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) { + return false; + } + + // range of possible code points this partial UTF-8 sequence could complete to + uint32_t low = partial_value << (n_remain * 6); + uint32_t high = low | ((1 << (n_remain * 6)) - 1); + + if (low == 0) { + if (n_remain == 2) { + low = 1 << 11; + } else if (n_remain == 3) { + low = 1 << 16; + } + } + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + if (pos->value <= high && low <= pos[1].value) { + return is_positive_char; + } + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + if (low <= pos->value && pos->value <= high) { + return is_positive_char; + } + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return !is_positive_char; +} + + +// transforms a grammar pushdown stack into N possible stacks, all ending +// at a character range (terminal element) +static void llama_grammar_advance_stack( + const std::vector> & rules, + const std::vector & stack, + std::vector> & new_stacks) { + + if (stack.empty()) { + new_stacks.emplace_back(stack); + return; + } + + const llama_grammar_element * pos = stack.back(); + + switch (pos->type) { + case LLAMA_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const llama_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!llama_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + while (!llama_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; + } + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + new_stacks.emplace_back(stack); + break; + default: + // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range + // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + GGML_ASSERT(false); + } +} + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +static std::vector> llama_grammar_accept( + const std::vector> & rules, + const std::vector> & stacks, + const uint32_t chr) { + + std::vector> new_stacks; + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + auto match = llama_grammar_match_char(stack.back(), chr); + if (match.first) { + const llama_grammar_element * pos = match.second; + + // update top of stack to next element, if any + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + } + } + + return new_stacks; +} + +static std::vector llama_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates); + +static std::vector llama_grammar_reject_candidates_for_stack( + const std::vector> & rules, + const std::vector & stack, + const std::vector & candidates) { + + std::vector rejects; + + if (stack.empty()) { + for (auto tok : candidates) { + if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } + return rejects; + } + + const llama_grammar_element * stack_pos = stack.back(); + + std::vector next_candidates; + for (auto tok : candidates) { + if (*tok.code_points == 0) { + // reached end of full codepoints in token, reject iff it ended in a partial sequence + // that cannot satisfy this position in grammar + if (tok.partial_utf8.n_remain != 0 && + !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { + rejects.push_back(tok); + } + } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { + next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); + } else { + rejects.push_back(tok); + } + } + + const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + std::vector stack_after(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + std::vector> next_stacks; + llama_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (auto tok : next_rejects) { + rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); + } + + return rejects; +} + +static std::vector llama_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates) { + GGML_ASSERT(!stacks.empty()); // REVIEW + + if (candidates.empty()) { + return std::vector(); + } + + auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + return rejects; +} + +// +// grammar - external +// + +struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + const llama_grammar_element * pos; + + // copy rule definitions into vectors + std::vector> vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } + + // loop over alternates of start rule to build initial stacks + std::vector> stacks; + pos = rules[start_rule_index]; + do { + std::vector stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; +} + +void llama_grammar_free(struct llama_grammar * grammar) { + delete grammar; +} + +struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { + llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 }; + + // redirect elements in stacks to point to new rules + for (size_t is = 0; is < result->stacks.size(); is++) { + for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { + for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { + for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { + if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { + result->stacks[is][ie] = &result->rules[ir0][ir1]; + } + } + } + } + } + + return result; +} + +// +// sampling +// + +void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { + GGML_ASSERT(candidates->size > 0); + + const int64_t t_start_sample_us = ggml_time_us(); + + // Sort the logits in descending order + if (!candidates->sorted) { + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + candidates->sorted = true; + } + + float max_l = candidates->data[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { + float p = expf(candidates->data[i].logit - max_l); + candidates->data[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < candidates->size; ++i) { + candidates->data[i].p /= cum_sum; + } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep) { + const int64_t t_start_sample_us = ggml_time_us(); + + k = std::max(k, (int) min_keep); + k = std::min(k, (int) candidates->size); + + // Sort scores in descending order + if (!candidates->sorted) { + auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + if (k == (int) candidates->size) { + std::sort(candidates->data, candidates->data + candidates->size, comp); + } else { + std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); + } + candidates->sorted = true; + } + candidates->size = k; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + if (p >= 1.0f) { + return; + } + + llama_sample_softmax(ctx, candidates); + + const int64_t t_start_sample_us = ggml_time_us(); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = candidates->size; + + for (size_t i = 0; i < candidates->size; ++i) { + cum_sum += candidates->data[i].p; + + // Check if the running sum is at least p or if we have kept at least min_keep tokens + // we set the last index to i+1 to indicate that the current iterate should be included in the set + if (cum_sum >= p && i + 1 >= min_keep) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the top-p tokens + candidates->size = last_idx; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { + if (z >= 1.0f || candidates->size <= 2) { + return; + } + + llama_sample_softmax(nullptr, candidates); + const int64_t t_start_sample_us = ggml_time_us(); + + // Compute the first and second derivatives + std::vector first_derivatives(candidates->size - 1); + std::vector second_derivatives(candidates->size - 2); + + for (size_t i = 0; i < first_derivatives.size(); ++i) { + first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p; + } + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; + } + + // Calculate absolute value of second derivatives + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = std::abs(second_derivatives[i]); + } + + // Normalize the second derivatives + { + const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); + + if (second_derivatives_sum > 1e-6f) { + for (float & value : second_derivatives) { + value /= second_derivatives_sum; + } + } else { + for (float & value : second_derivatives) { + value = 1.0f / second_derivatives.size(); + } + } + } + + float cum_sum = 0.0f; + size_t last_idx = candidates->size; + for (size_t i = 0; i < second_derivatives.size(); ++i) { + cum_sum += second_derivatives[i]; + + // Check if the running sum is greater than z or if we have kept at least min_keep tokens + if (cum_sum > z && i >= min_keep) { + last_idx = i; + break; + } + } + + // Resize the output vector to keep only the tokens above the tail location + candidates->size = last_idx; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { + // Reference implementation: + // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr + if (p >= 1.0f) { + return; + } + + // Compute the softmax of logits and calculate entropy + llama_sample_softmax(nullptr, candidates); + + const int64_t t_start_sample_us = ggml_time_us(); + + float entropy = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { + entropy += -candidates->data[i].p * logf(candidates->data[i].p); + } + + // Compute the absolute difference between negative log probability and entropy for each candidate + std::vector shifted_scores; + for (size_t i = 0; i < candidates->size; ++i) { + float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy); + shifted_scores.push_back(shifted_score); + } + + // Sort tokens based on the shifted_scores and their corresponding indices + std::vector indices(candidates->size); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { + return shifted_scores[a] < shifted_scores[b]; + }); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = indices.size(); + + for (size_t i = 0; i < indices.size(); ++i) { + size_t idx = indices[i]; + cum_sum += candidates->data[idx].p; + + // Check if the running sum is greater than typical or if we have kept at least min_keep tokens + if (cum_sum > p && i >= min_keep - 1) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the locally typical tokens + std::vector new_candidates; + for (size_t i = 0; i < last_idx; ++i) { + size_t idx = indices[i]; + new_candidates.push_back(candidates->data[idx]); + } + + // Replace the data in candidates with the new_candidates data + std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); + candidates->size = new_candidates.size(); + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { + const int64_t t_start_sample_us = ggml_time_us(); + + for (size_t i = 0; i < candidates_p->size; ++i) { + candidates_p->data[i].logit /= temp; + } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) { + if (last_tokens_size == 0 || penalty == 1.0f) { + return; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + for (size_t i = 0; i < candidates->size; ++i) { + const auto * token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id); + if (token_iter == last_tokens + last_tokens_size) { + continue; + } + + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. + // This is common fix for this problem, which is to multiply by the penalty instead of dividing. + if (candidates->data[i].logit <= 0) { + candidates->data[i].logit *= penalty; + } else { + candidates->data[i].logit /= penalty; + } + } + + candidates->sorted = false; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { + if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { + return; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + // Create a frequency map to count occurrences of each token in last_tokens + std::unordered_map token_count; + for (size_t i = 0; i < last_tokens_size; ++i) { + token_count[last_tokens_p[i]]++; + } + + // Apply frequency and presence penalties to the candidates + for (size_t i = 0; i < candidates->size; ++i) { + auto token_iter = token_count.find(candidates->data[i].id); + if (token_iter == token_count.end()) { + continue; + } + + int count = token_iter->second; + candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence; + } + + candidates->sorted = false; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { + GGML_ASSERT(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + + bool allow_eos = false; + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + allow_eos = true; + break; + } + } + + const llama_token eos = llama_token_eos(ctx); + + std::vector, llama_partial_utf8>> candidates_decoded; + std::vector candidates_grammar; + + for (size_t i = 0; i < candidates->size; ++i) { + const llama_token id = candidates->data[i].id; + const std::string piece = llama_token_to_str(ctx, id); + if (id == eos) { + if (!allow_eos) { + candidates->data[i].logit = -INFINITY; + } + } else if (piece.empty() || piece[0] == 0) { + candidates->data[i].logit = -INFINITY; + } else { + candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8)); + candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); + } + } + + const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + for (const auto & reject : rejects) { + candidates->data[reject.index].logit = -INFINITY; + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; +} + +static void llama_log_softmax(float * array, size_t size) { + float max_l = *std::max_element(array, array + size); + float sum = 0.f; + for (size_t i = 0; i < size; ++i) { + float p = expf(array[i] - max_l); + sum += p; + array[i] = p; + } + + for (size_t i = 0; i < size; ++i) { + array[i] = logf(array[i] / sum); + } +} + +void llama_sample_classifier_free_guidance( + struct llama_context * ctx, + llama_token_data_array * candidates, + struct llama_context * guidance_ctx, + float scale) { + int64_t t_start_sample_us = ggml_time_us(); + + GGML_ASSERT(ctx); + + auto n_vocab = llama_n_vocab(ctx); + + GGML_ASSERT(n_vocab == (int)candidates->size); + GGML_ASSERT(!candidates->sorted); + + std::vector logits_base; + logits_base.reserve(candidates->size); + for (size_t i = 0; i < candidates->size; ++i) { + logits_base.push_back(candidates->data[i].logit); + } + llama_log_softmax(logits_base.data(), candidates->size); + + float* logits_guidance = llama_get_logits(guidance_ctx); + llama_log_softmax(logits_guidance, n_vocab); + + for (int i = 0; i < n_vocab; ++i) { + float logit_guidance = logits_guidance[i]; + float logit_base = logits_base[i]; + candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance; + } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { + GGML_ASSERT(ctx); + + auto N = float(llama_n_vocab(ctx)); + int64_t t_start_sample_us; + t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(nullptr, candidates); + + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float sum_ti_sq = 0.0; + for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); + float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; + + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); + + // Sample the next word X using top-k sampling + llama_sample_top_k(nullptr, candidates, int(k), 1); + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + llama_token X = llama_sample_token(ctx, candidates); + t_start_sample_us = ggml_time_us(); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + return X; +} + +llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { + int64_t t_start_sample_us; + t_start_sample_us = ggml_time_us(); + + llama_sample_softmax(ctx, candidates); + + // Truncate the words with surprise values greater than mu + candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > *mu; + })); + + if (candidates->size == 0) { + candidates->size = 1; + } + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + + // Normalize the probabilities of the remaining words + llama_sample_softmax(ctx, candidates); + + // Sample the next word X from the remaining words + llama_token X = llama_sample_token(ctx, candidates); + t_start_sample_us = ggml_time_us(); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } + return X; +} + +llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { + const int64_t t_start_sample_us = ggml_time_us(); + + // Find max element + auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit < b.logit; + }); + + llama_token result = max_iter->id; + if (ctx) { + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + } + return result; +} + +llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { + GGML_ASSERT(ctx); + + const int64_t t_start_sample_us = ggml_time_us(); + llama_sample_softmax(nullptr, candidates); + + std::vector probs; + probs.reserve(candidates->size); + for (size_t i = 0; i < candidates->size; ++i) { + probs.push_back(candidates->data[i].p); + } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + auto & rng = ctx->rng; + int idx = dist(rng); + + llama_token result = candidates->data[idx].id; + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; + return result; +} + +void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { + const int64_t t_start_sample_us = ggml_time_us(); + + if (token == llama_token_eos(ctx)) { + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + return; + } + } + GGML_ASSERT(false); + } + + const std::string piece = llama_token_to_str(ctx, token); + + // Note terminating 0 in decoded string + const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8); + const auto & code_points = decoded.first; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); + } + grammar->partial_utf8 = decoded.second; + GGML_ASSERT(!grammar->stacks.empty()); + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; +} + +// +// Beam search +// + +struct llama_beam { + std::vector tokens; + float p; // Cumulative beam probability (renormalized relative to all beams) + bool eob; // Initialize end-of-beam to false. Callback sets this to true. + // Sort beams by probability. In case of ties, prefer beams at eob. + bool operator<(const llama_beam & rhs) const { + return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob); + } + // Shift off first n tokens and discard them. + void shift_tokens(const size_t n) { + if (n) { + std::copy(tokens.begin() + n, tokens.end(), tokens.begin()); + tokens.resize(tokens.size() - n); + } + } + llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; } +}; + +// A struct for calculating logit-related info. +struct llama_logit_info { + const float * const logits; + const int n_vocab; + const float max_l; + const float normalizer; + struct sum_exp { + float max_l; + float operator()(float sum, float l) const { return sum + std::exp(l - max_l); } + }; + llama_logit_info(llama_context * ctx) + : logits(llama_get_logits(ctx)) + , n_vocab(llama_n_vocab(ctx)) + , max_l(*std::max_element(logits, logits + n_vocab)) + , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l})) + { } + llama_token_data get_token_data(const llama_token token_id) const { + constexpr auto p = std::numeric_limits::quiet_NaN(); // never used + return {token_id, logits[token_id], p}; + } + // Return top k token_data by logit. + std::vector top_k(size_t k) { + std::vector min_heap; // min-heap by logit + const llama_token k_min = std::min(static_cast(k), n_vocab); + min_heap.reserve(k_min); + for (llama_token token_id = 0 ; token_id < k_min ; ++token_id) { + min_heap.push_back(get_token_data(token_id)); + } + auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; + std::make_heap(min_heap.begin(), min_heap.end(), comp); + for (llama_token token_id = k_min ; token_id < n_vocab ; ++token_id) { + if (min_heap.front().logit < logits[token_id]) { + std::pop_heap(min_heap.begin(), min_heap.end(), comp); + min_heap.back().id = token_id; + min_heap.back().logit = logits[token_id]; + std::push_heap(min_heap.begin(), min_heap.end(), comp); + } + } + return min_heap; + } + float probability_from_logit(float logit) const { + return normalizer * std::exp(logit - max_l); + } +}; + +struct llama_beam_search_data { + llama_context * ctx; + size_t n_beams; + int n_past; + int n_predict; + int n_threads; + std::vector beams; + std::vector next_beams; + + // Re-calculated on each loop iteration + size_t common_prefix_length; + + // Used to communicate to/from callback on beams state. + std::vector beam_views; + + llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads) + : ctx(ctx) + , n_beams(n_beams) + , n_past(n_past) + , n_predict(n_predict) + , n_threads(n_threads) + , beam_views(n_beams) { + beams.reserve(n_beams); + next_beams.reserve(n_beams); + } + + // Collapse beams to a single beam given by index. + void collapse_beams(const size_t beam_idx) { + if (0u < beam_idx) { + std::swap(beams[0], beams[beam_idx]); + } + beams.resize(1); + } + + // Min-heaps are used to efficiently collect the top-k elements (k=n_beams). + // The repetative patterns below reflect the 2 stages of heaps: + // * Gather elements until the vector is full, then call std::make_heap() on it. + // * If the heap is full and a new element is found that should be included, pop the + // least element to the back(), replace it with the new, then push it into the heap. + void fill_next_beams_by_top_probabilities(llama_beam & beam) { + // Min-heaps use a greater-than comparator. + const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; }; + if (beam.eob) { + // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough. + if (next_beams.size() < n_beams) { + next_beams.push_back(std::move(beam)); + if (next_beams.size() == n_beams) { + std::make_heap(next_beams.begin(), next_beams.end(), comp); + } + } else if (next_beams.front().p < beam.p) { + std::pop_heap(next_beams.begin(), next_beams.end(), comp); + next_beams.back() = std::move(beam); + std::push_heap(next_beams.begin(), next_beams.end(), comp); + } + } else { + // beam is not at end-of-sentence, so branch with next top_k tokens. + if (!beam.tokens.empty()) { + llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads); + } + llama_logit_info logit_info(ctx); + std::vector next_tokens = logit_info.top_k(n_beams); + size_t i=0; + if (next_beams.size() < n_beams) { + for (; next_beams.size() < n_beams ; ++i) { + llama_beam next_beam = beam; + next_beam.tokens.push_back(next_tokens[i].id); + next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit); + next_beams.push_back(std::move(next_beam)); + } + std::make_heap(next_beams.begin(), next_beams.end(), comp); + } else { + for (; next_beams.front().p == 0.0f ; ++i) { + std::pop_heap(next_beams.begin(), next_beams.end(), comp); + next_beams.back() = beam; + next_beams.back().tokens.push_back(next_tokens[i].id); + next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit); + std::push_heap(next_beams.begin(), next_beams.end(), comp); + } + } + for (; i < n_beams ; ++i) { + const float next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit); + if (next_beams.front().p < next_p) { + std::pop_heap(next_beams.begin(), next_beams.end(), comp); + next_beams.back() = beam; + next_beams.back().tokens.push_back(next_tokens[i].id); + next_beams.back().p = next_p; + std::push_heap(next_beams.begin(), next_beams.end(), comp); + } + } + } + } + + // Find common_prefix_length based on beams. + // Requires beams is not empty. + size_t find_common_prefix_length() { + size_t common_prefix_length = beams[0].tokens.size(); + for (size_t i = 1 ; i < beams.size() ; ++i) { + common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size()); + for (size_t j = 0 ; j < common_prefix_length ; ++j) { + if (beams[0].tokens[j] != beams[i].tokens[j]) { + common_prefix_length = j; + break; + } + } + } + return common_prefix_length; + } + + // Construct beams_state to send back to caller via the callback function. + // Side effect: set common_prefix_length = find_common_prefix_length(); + llama_beams_state get_beams_state(const bool last_call) { + for (size_t i = 0 ; i < beams.size() ; ++i) { + beam_views[i] = beams[i].view(); + } + common_prefix_length = find_common_prefix_length(); + return {beam_views.data(), beams.size(), common_prefix_length, last_call}; + } + + // Loop: + // * while i < n_predict, AND + // * any of the beams have not yet reached end-of-beam (eob), AND + // * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence + // (since all other beam probabilities can only decrease) + void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) { + beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eob. + const auto not_eob = [](const llama_beam & beam) { return !beam.eob; }; + for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) && + !beams[top_beam_index()].eob ; ++i) { + callback(callback_data, get_beams_state(false)); // Sets common_prefix_length + update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed. + if (common_prefix_length) { + llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads); + n_past += common_prefix_length; + } + // Zero-out next_beam probabilities to place them last in following min-heap. + std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam & beam) { beam.p = 0.0f; }); + for (llama_beam & beam : beams) { + beam.shift_tokens(common_prefix_length); + fill_next_beams_by_top_probabilities(beam); + } + // next_beams become the beams of next/final iteration. Swap them to re-use memory. + beams.swap(next_beams); + renormalize_beam_probabilities(beams); + } + collapse_beams(top_beam_index()); + callback(callback_data, get_beams_state(true)); + } + + // As beams grow, the cumulative probabilities decrease. + // Renormalize them to avoid floating point underflow. + static void renormalize_beam_probabilities(std::vector & beams) { + const auto sum_p = [](float sum, llama_beam & beam) { return sum + beam.p; }; + const float inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p); + std::for_each(beams.begin(), beams.end(), [=](llama_beam & beam) { beam.p *= inv_sum; }); + } + + // Assumes beams is non-empty. Uses llama_beam::operator<() for ordering. + size_t top_beam_index() { + return std::max_element(beams.begin(), beams.end()) - beams.begin(); + } + + // Copy (p,eob) for each beam which may have been changed by the callback. + void update_beams_from_beam_views() { + for (size_t i = 0 ; i < beams.size() ; ++i) { + beams[i].p = beam_views[i].p; + beams[i].eob = beam_views[i].eob; + } + } +}; + +void llama_beam_search(llama_context * ctx, + llama_beam_search_callback_fn_t callback, void * callback_data, + size_t n_beams, int n_past, int n_predict, int n_threads) { + assert(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + + llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads); + + beam_search_data.loop(callback, callback_data); + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + ctx->n_sample++; +} + +// +// quantization +// + +static void llama_convert_tensor_internal(struct ggml_tensor * tensor, std::vector & output, const size_t nelements, const int nthread) { + if (output.size() < nelements) { + output.resize(nelements); + } + float * f32_output = (float *) output.data(); + + ggml_type_traits_t qtype; + if (ggml_is_quantized(tensor->type)) { + qtype = ggml_internal_get_type_traits(tensor->type); + if (qtype.to_float == NULL) { + throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type))); + } + } else if (tensor->type != GGML_TYPE_F16) { + throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type))); + } + + if (nthread < 2) { + if (tensor->type == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements); + } else if (ggml_is_quantized(tensor->type)) { + qtype.to_float(tensor->data, f32_output, nelements); + } else { + GGML_ASSERT(false); // unreachable + } + return; + } + + auto block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type); + auto block_size_bytes = ggml_type_size(tensor->type); + + GGML_ASSERT(nelements % block_size == 0); + auto nblocks = nelements / block_size; + auto blocks_per_thread = nblocks / nthread; + auto spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count + + std::vector workers; + for (auto tnum = 0, in_buff_offs = 0, out_buff_offs = 0; tnum < nthread; tnum++) { + auto thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread + auto thr_elems = thr_blocks * block_size; // number of elements for this thread + auto thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread + + auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) { + if (typ == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels); + } else { + qtype.to_float(inbuf, outbuf, nels); + } + }; + workers.push_back(std::thread(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems)); + in_buff_offs += thr_block_bytes; + out_buff_offs += thr_elems; + } + for (auto & worker : workers) { + worker.join(); + } +} + +static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { + ggml_type quantized_type; + llama_ftype ftype = params->ftype; + + switch (params->ftype) { + case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break; + case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break; + case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; + case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; + case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; + case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break; + case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break; + +#ifdef GGML_USE_K_QUANTS + // K-quants + case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: + case LLAMA_FTYPE_MOSTLY_Q3_K_M: + case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: + case LLAMA_FTYPE_MOSTLY_Q4_K_M: quantized_type = GGML_TYPE_Q4_K; break; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: + case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break; + case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break; +#endif + default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); + } + + int nthread = params->nthread; + + if (nthread <= 0) { + nthread = std::thread::hardware_concurrency(); + } + + std::unique_ptr ml(new llama_model_loader(fname_inp, /*use_mmap*/ false)); + + llama_model model; + llm_load_arch(*ml, model); + llm_load_hparams(*ml, model, 0, 0, 0); + + if (params->only_copy) { + ftype = model.ftype; + } + + const size_t align = GGUF_DEFAULT_ALIGNMENT; + struct gguf_context * ctx_out = gguf_init_empty(); + + // copy the KV pairs from the input file + gguf_set_kv (ctx_out, ml->ctx_gguf); + gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); + gguf_set_val_u32(ctx_out, "general.file_type", ftype); + +#ifdef GGML_USE_K_QUANTS + int n_attention_wv = 0; + int n_feed_forward_w2 = 0; + + for (int i = 0; i < ml->n_tensors; ++i) { + struct ggml_tensor * meta = ml->get_tensor_meta(i); + + const std::string name = ggml_get_name(meta); + + // TODO: avoid hardcoded tensor names - use the TN_* constants + if (name.find("attn_v.weight") != std::string::npos) { + ++n_attention_wv; + } + else if (name.find("ffn_down.weight") != std::string::npos) { + ++n_feed_forward_w2; + } + } + if (n_attention_wv != n_feed_forward_w2 || (uint32_t)n_attention_wv != model.hparams.n_layer) { + LLAMA_LOG_WARN("%s ============ Strange model: n_attention_wv = %d, n_feed_forward_w2 = %d, hparams.n_layer = %d\n", + __func__, n_attention_wv, n_feed_forward_w2, model.hparams.n_layer); + } + + int i_attention_wv = 0; + int i_feed_forward_w2 = 0; +#endif + + size_t total_size_org = 0; + size_t total_size_new = 0; + std::vector hist_all(1 << 4, 0); + + std::vector workers; + std::mutex mutex; + +#ifdef GGML_USE_K_QUANTS + auto use_more_bits = [] (int i_layer, int num_layers) -> bool { + return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2; + }; +#endif + + int idx = 0; + + std::vector read_data; + std::vector work; + + // populate the original tensors so we get an initial meta data + for (int i = 0; i < ml->n_tensors; ++i) { + struct ggml_tensor * meta = ml->get_tensor_meta(i); + gguf_add_tensor(ctx_out, meta); + } + + std::ofstream fout(fname_out, std::ios::binary); + + const size_t meta_size = gguf_get_meta_size(ctx_out); + + LLAMA_LOG_INFO("%s: meta size = %zu bytes\n", __func__, meta_size); + + // placeholder for the meta data + ::zeros(fout, meta_size); + + for (int i = 0; i < ml->n_tensors; ++i) { + struct ggml_tensor * tensor = ml->get_tensor_meta(i); + + const std::string name = ggml_get_name(tensor); + + read_data.resize(ggml_nbytes(tensor)); + tensor->data = read_data.data(); + ml->load_data_for(tensor); + + LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", + ++idx, ml->n_tensors, + ggml_get_name(tensor), + llama_format_tensor_shape(tensor).c_str(), + ggml_type_name(tensor->type)); + + // This used to be a regex, but has an extreme cost to compile times. + bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + + // quantize only 2D tensors + quantize &= (tensor->n_dims == 2); + quantize &= params->quantize_output_tensor || name != "output.weight"; + quantize &= !params->only_copy; + + enum ggml_type new_type; + void * new_data; + size_t new_size; + + if (quantize) { + new_type = quantized_type; +#ifdef GGML_USE_K_QUANTS + // TODO: avoid hardcoded tensor names - use the TN_* constants + const auto tn = LLM_TN(ml->get_arch()); + + if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { + int nx = tensor->ne[0]; + if (model.arch == LLM_ARCH_FALCON || nx % QK_K != 0) { + new_type = GGML_TYPE_Q8_0; + } + else if (new_type != GGML_TYPE_Q8_0) { + new_type = GGML_TYPE_Q6_K; + } + } else if (name.find("attn_v.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { + new_type = i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; + else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && + use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_attention_wv < 4) new_type = GGML_TYPE_Q5_K; + else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) && + (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K; + if (model.type == MODEL_70B) { + // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is + // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with + // nearly negligible increase in model size by quantizing this tensor with more bits: + if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K; + } + ++i_attention_wv; + } else if (name.find("ffn_down.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { + new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K + : model.arch != LLM_ARCH_FALCON || use_more_bits(i_feed_forward_w2, n_feed_forward_w2) ? GGML_TYPE_Q4_K + : GGML_TYPE_Q3_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) { + new_type = model.arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) { + if (model.arch == LLM_ARCH_FALCON) { + new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q6_K : + use_more_bits(i_feed_forward_w2, n_feed_forward_w2) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; + } else { + if (use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; + } + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && model.arch != LLM_ARCH_FALCON && i_feed_forward_w2 < 4) { + new_type = GGML_TYPE_Q5_K; + } + ++i_feed_forward_w2; + } else if (name.find("attn_output.weight") != std::string::npos) { + if (model.arch != LLM_ARCH_FALCON) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; + } else { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; + } + } + else if (name.find("attn_qkv.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K; + } + else if (name.find("ffn_gate.weight") != std::string::npos || name.find("ffn_up.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + } + // This can be used to reduce the size of the Q5_K_S model. + // The associated PPL increase is fully in line with the size reduction + //else { + // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; + //} + bool convert_incompatible_tensor = false; + if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || + new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) { + int nx = tensor->ne[0]; + int ny = tensor->ne[1]; + if (nx % QK_K != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for k-quants\n", __func__, nx, ny, QK_K); + convert_incompatible_tensor = true; + } + } + if (convert_incompatible_tensor) { + if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { + new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing. + LLAMA_LOG_WARN("F16 will be used for this tensor instead.\n"); + } else if (name == tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { + new_type = GGML_TYPE_Q4_0; //fall back to Q4_0 instead of just failing. + LLAMA_LOG_WARN("Q4_0 will be used for this tensor instead.\n"); + } else { + throw std::runtime_error("Unsupported tensor size encountered\n"); + } + } +#endif + // If we've decided to quantize to the same type the tensor is already + // in then there's nothing to do. + quantize = tensor->type != new_type; + } + if (!quantize) { + new_type = tensor->type; + new_data = tensor->data; + new_size = ggml_nbytes(tensor); + LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0); + } else { + const size_t nelements = ggml_nelements(tensor); + + float * f32_data; + std::vector f32_conv_buf; + + if (tensor->type == GGML_TYPE_F32) { + f32_data = (float *) tensor->data; + } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { + throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); + } else { + llama_convert_tensor_internal(tensor, f32_conv_buf, nelements, nthread); + f32_data = (float *) f32_conv_buf.data(); + } + + LLAMA_LOG_INFO("quantizing to %s .. ", ggml_type_name(new_type)); + fflush(stdout); + + work.resize(nelements * 4); // upper bound on size + new_data = work.data(); + std::vector hist_cur(1 << 4, 0); + + static const int chunk_size = 32 * 512; + const int nchunk = (nelements + chunk_size - 1)/chunk_size; + const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1; + if (nthread_use < 2) { + new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data()); + } else { + size_t counter = 0; + new_size = 0; + auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, nelements]() { + std::vector local_hist; + size_t local_size = 0; + while (true) { + std::unique_lock lock(mutex); + size_t first = counter; counter += chunk_size; + if (first >= nelements) { + if (!local_hist.empty()) { + for (int j=0; j %8.2f MB | hist: ", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); + int64_t tot_count = 0; + for (size_t i = 0; i < hist_cur.size(); i++) { + hist_all[i] += hist_cur[i]; + tot_count += hist_cur[i]; + } + + if (tot_count > 0) { + for (size_t i = 0; i < hist_cur.size(); i++) { + LLAMA_LOG_INFO("%5.3f ", hist_cur[i] / float(nelements)); + } + } + LLAMA_LOG_INFO("\n"); + } + total_size_org += ggml_nbytes(tensor); + total_size_new += new_size; + + // update the gguf meta data as we go + gguf_set_tensor_type(ctx_out, name.c_str(), new_type); + gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size); + + // write tensor data + padding + fout.write((const char *) new_data, new_size); + zeros(fout, GGML_PAD(new_size, align) - new_size); + } + + // go back to beginning of file and write the updated meta data + { + fout.seekp(0); + std::vector data(gguf_get_meta_size(ctx_out)); + gguf_get_meta_data(ctx_out, data.data()); + fout.write((const char *) data.data(), data.size()); + } + + fout.close(); + + gguf_free(ctx_out); + + LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); + LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); + + // print histogram for all tensors + { + int64_t sum_all = 0; + for (size_t i = 0; i < hist_all.size(); i++) { + sum_all += hist_all[i]; + } + + if (sum_all > 0) { + LLAMA_LOG_INFO("%s: hist: ", __func__); + for (size_t i = 0; i < hist_all.size(); i++) { + LLAMA_LOG_INFO("%5.3f ", hist_all[i] / float(sum_all)); + } + LLAMA_LOG_INFO("\n"); + } + } +} + +// TODO: after the GGUF PR, this likely won't work and needs to be updated +int llama_apply_lora_from_file_internal(const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads) { + LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); + + const int64_t t_start_lora_us = ggml_time_us(); + + auto fin = std::ifstream(path_lora, std::ios::binary); + if (!fin) { + LLAMA_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_lora); + return 1; + } + + // verify magic and version + { + uint32_t magic; + fin.read((char *) &magic, sizeof(magic)); + uint32_t format_version; + fin.read((char *) &format_version, sizeof(format_version)); + + if (format_version != 1) { + LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ ); + return 1; + } + } + + int32_t lora_r; + int32_t lora_alpha; + fin.read((char *) &lora_r, sizeof(lora_r)); + fin.read((char *) &lora_alpha, sizeof(lora_alpha)); + float scaling = (float)lora_alpha / (float)lora_r; + + LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling); + + // create a temporary ggml context to store the lora tensors + // todo: calculate size from biggest possible tensor + std::vector lora_buf(1024ull * 1024ull * 1024ull); + struct ggml_init_params params; + params.mem_size = lora_buf.size(); + params.mem_buffer = lora_buf.data(); + params.no_alloc = false; + + ggml_context * lora_ctx = ggml_init(params); + std::unordered_map lora_tensors; + + // create a name -> tensor map of the model to accelerate lookups + std::unordered_map model_tensors; + for (const auto & kv : model.tensors_by_name) { + model_tensors.insert(kv); + } + + // load base model + std::unique_ptr ml; + ggml_context * base_ctx = NULL; + std::vector base_buf; + if (path_base_model) { + LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); + ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true)); + + size_t ctx_size; + size_t mmapped_size; + ml->calc_sizes(ctx_size, mmapped_size); + base_buf.resize(ctx_size); + + ggml_init_params base_params; + base_params.mem_size = base_buf.size(); + base_params.mem_buffer = base_buf.data(); + base_params.no_alloc = ml->use_mmap; + + base_ctx = ggml_init(base_params); + + // maybe this should in llama_model_loader + if (ml->use_mmap) { + ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, ggml_is_numa())); + } + } + + // read tensors and apply + bool warned = false; + int n_tensors = 0; + + std::vector work_buffer; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ftype; + + fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); + fin.read(reinterpret_cast(&length), sizeof(length)); + fin.read(reinterpret_cast(&ftype), sizeof(ftype)); + if (fin.eof()) { + break; + } + + int32_t ne[2] = { 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); + } + + std::string name; + { + char buf[1024]; + fin.read(buf, length); + name = std::string(buf, length); + } + + // check for lora suffix and get the type of tensor + const std::string lora_suffix = ".lora"; + size_t pos = name.rfind(lora_suffix); + if (pos == std::string::npos) { + LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str()); + return 1; + } + + std::string lora_type = name.substr(pos + lora_suffix.length()); + std::string base_name = name; + base_name.erase(pos); + // LLAMA_LOG_INFO("%s: %s => %s (lora type %s) \n", __func__, name.c_str(),base_name.c_str(), lora_type.c_str()); + + if (model_tensors.find(base_name) == model_tensors.end()) { + LLAMA_LOG_ERROR("%s: unknown tensor '%s' in lora adapter\n", __func__, name.data()); + return 1; + } + + // create ggml tensor + ggml_type wtype; + switch (ftype) { + case 0: wtype = GGML_TYPE_F32; break; + case 1: wtype = GGML_TYPE_F16; break; + default: + { + LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n", + __func__, ftype); + return false; + } + } + ggml_tensor * lora_tensor; + if (n_dims == 2) { + lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]); + } + else { + LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims); + return 1; + } + ggml_set_name(lora_tensor, "lora_tensor"); + + // load tensor data + size_t offset = fin.tellg(); + size_t tensor_data_size = ggml_nbytes(lora_tensor); + offset = (offset + 31) & -32; + fin.seekg(offset); + fin.read((char*)lora_tensor->data, tensor_data_size); + + lora_tensors[name] = lora_tensor; + + // check if we have both A and B tensors and apply + if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() && + lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) { + + ggml_tensor * dest_t = model_tensors[base_name]; + + offload_func_t offload_func = llama_nop; + offload_func_t offload_func_force_inplace = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) { + if (dest_t->type != GGML_TYPE_F16) { + throw std::runtime_error(format( + "%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models", __func__)); + } + offload_func = ggml_cuda_assign_buffers; + offload_func_force_inplace = ggml_cuda_assign_buffers_force_inplace; + } +#endif // GGML_USE_CUBLAS + + ggml_tensor * base_t; + if (ml) { + struct gguf_context * ctx_gguf = ml->ctx_gguf; + + // load from base model + if (gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) { + // TODO: throw + LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str()); + return 1; + } + + // TODO: not tested!! maybe not working! + base_t = ml->create_tensor(base_ctx, base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU); + ml->load_data_for(base_t); + } else { + base_t = dest_t; + } + + if (ggml_is_quantized(base_t->type)) { + if (!warned) { + LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, " + "use a f16 or f32 base model with --lora-base\n", __func__); + warned = true; + } + } + + ggml_tensor * loraA = lora_tensors[base_name + ".loraA"]; + GGML_ASSERT(loraA->type == GGML_TYPE_F32); + ggml_set_name(loraA, "loraA"); + + ggml_tensor * loraB = lora_tensors[base_name + ".loraB"]; + GGML_ASSERT(loraB->type == GGML_TYPE_F32); + ggml_set_name(loraB, "loraB"); + + if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) { + LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" + " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]); + return 1; + } + + // w = w + BA*s + ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB); + offload_func(BA); + ggml_set_name(BA, "BA"); + + if (scaling != 1.0f) { + ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling); + ggml_set_name(scale_tensor, "scale_tensor"); + + BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor); + offload_func(BA); + ggml_set_name(BA, "BA_scaled"); + } + + ggml_tensor * r; + if (base_t == dest_t) { + r = ggml_add_inplace(lora_ctx, dest_t, BA); + offload_func_force_inplace(r); + ggml_set_name(r, "r_add_inplace"); + } + else { + r = ggml_add(lora_ctx, base_t, BA); + offload_func(r); + ggml_set_name(r, "r_add"); + + r = ggml_cpy(lora_ctx, r, dest_t); + offload_func(r); + ggml_set_name(r, "r_cpy"); + } + + struct ggml_cgraph gf = ggml_build_forward(r); + + ggml_graph_compute_helper(work_buffer, &gf, n_threads); + + // we won't need these tensors again, reset the context to save memory + ggml_free(lora_ctx); + lora_ctx = ggml_init(params); + lora_tensors.clear(); + + n_tensors++; + if (n_tensors % 4 == 0) { + LLAMA_LOG_INFO("."); + } + } + } + + // TODO: this should be in a destructor, it will leak on failure + ggml_free(lora_ctx); + if (base_ctx) { + ggml_free(base_ctx); + } + + const int64_t t_lora_us = ggml_time_us() - t_start_lora_us; + LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0); + + return 0; +} + +// +// interface implementation +// + +struct llama_context_params llama_context_default_params() { + struct llama_context_params result = { + /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.n_ctx =*/ 512, + /*.n_batch =*/ 512, + /*.n_gpu_layers =*/ 0, + /*.main_gpu =*/ 0, + /*.tensor_split =*/ nullptr, + /*.rope_freq_base =*/ 10000.0f, + /*.rope_freq_scale =*/ 1.0f, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + /*.low_vram =*/ false, + /*.mul_mat_q =*/ true, + /*.f16_kv =*/ true, + /*.logits_all =*/ false, + /*.vocab_only =*/ false, + /*.use_mmap =*/ true, + /*.use_mlock =*/ false, + /*.embedding =*/ false, + }; + +#ifdef GGML_USE_METAL + result.n_gpu_layers = 1; +#endif + + return result; +} + +struct llama_model_quantize_params llama_model_quantize_default_params() { + struct llama_model_quantize_params result = { + /*.nthread =*/ 0, + /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, + /*.allow_requantize =*/ false, + /*.quantize_output_tensor =*/ true, + /*.only_copy =*/ false, + }; + + return result; +} + +int llama_max_devices(void) { + return LLAMA_MAX_DEVICES; +} + +bool llama_mmap_supported(void) { + return llama_mmap::SUPPORTED; +} + +bool llama_mlock_supported(void) { + return llama_mlock::SUPPORTED; +} + +void llama_backend_init(bool numa) { + ggml_time_init(); + + // needed to initialize f16 tables + { + struct ggml_init_params params = { 0, NULL, false }; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } + + if (numa) { + ggml_numa_init(); + } + +#ifdef GGML_USE_MPI + ggml_mpi_backend_init(); +#endif +} + +void llama_backend_free(void) { +#ifdef GGML_USE_MPI + ggml_mpi_backend_free(); +#endif +} + +int64_t llama_time_us(void) { + return ggml_time_us(); +} + +struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_context_params params) { + ggml_time_init(); + + llama_model * model = new llama_model; + + ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; + + unsigned cur_percentage = 0; + if (params.progress_callback == NULL) { + params.progress_callback_user_data = &cur_percentage; + params.progress_callback = [](float progress, void * ctx) { + unsigned * cur_percentage_p = (unsigned *) ctx; + unsigned percentage = (unsigned) (100 * progress); + while (percentage > *cur_percentage_p) { + *cur_percentage_p = percentage; + LLAMA_LOG_INFO("."); + if (percentage >= 100) { + LLAMA_LOG_INFO("\n"); + } + } + }; + } + + if (!llama_model_load(path_model, *model, params.n_ctx, params.n_batch, params.n_gpu_layers, + params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale, + params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, + params.progress_callback, params.progress_callback_user_data)) { + LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); + delete model; + return nullptr; + } + + return model; +} + +void llama_free_model(struct llama_model * model) { + delete model; +} + +struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params) { + + if (!model) { + return nullptr; + } + + llama_context * ctx = new llama_context(*model); + + if (params.seed == LLAMA_DEFAULT_SEED) { + params.seed = time(NULL); + } + + ctx->rng = std::mt19937(params.seed); + ctx->logits_all = params.logits_all; + + ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; + + // reserve memory for context buffers + if (!params.vocab_only) { + if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) { + LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); + llama_free(ctx); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v); + LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + + const auto & hparams = ctx->model.hparams; + + // resized during inference + if (params.logits_all) { + ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); + } else { + ctx->logits.reserve(hparams.n_vocab); + } + + if (params.embedding){ + ctx->embedding.resize(hparams.n_embd); + } + + { + static const size_t tensor_alignment = 32; + // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data + ctx->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); + + // create measure allocator + ctx->alloc = ggml_allocr_new_measure(tensor_alignment); + + // build worst-case graph + int n_tokens = std::min((int)hparams.n_ctx, params.n_batch); + int n_past = hparams.n_ctx - n_tokens; + llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past); +#ifdef GGML_USE_METAL + if (params.n_gpu_layers > 0) { + ctx->ctx_metal = ggml_metal_init(1); + if (!ctx->ctx_metal) { + LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); + llama_free(ctx); + return NULL; + } + ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); + ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); + } +#endif + // measure memory requirements for the graph + size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; + + LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); + + // recreate allocator with exact memory requirements + ggml_allocr_free(ctx->alloc); + + ctx->buf_alloc.resize(alloc_size); + ctx->alloc = ggml_allocr_new(ctx->buf_alloc.data, ctx->buf_alloc.size, tensor_alignment); +#ifdef GGML_USE_METAL + if (ctx->ctx_metal) { + ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); + } +#endif +#ifdef GGML_USE_CUBLAS + if (params.low_vram) { + LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__); + ggml_cuda_set_scratch_size(0); // disable scratch + } else { + ggml_cuda_set_scratch_size(alloc_size); + LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0); + } +#endif + } + +#ifdef GGML_USE_METAL + if (params.n_gpu_layers > 0) { + // this allocates all Metal resources and memory buffers + + void * data_ptr = NULL; + size_t data_size = 0; + + if (params.use_mmap) { + data_ptr = ctx->model.mapping->addr; + data_size = ctx->model.mapping->size; + } else { + data_ptr = ggml_get_mem_buffer(ctx->model.ctx); + data_size = ggml_get_mem_size (ctx->model.ctx); + } + + const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); + + LLAMA_LOG_INFO("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); + +#define LLAMA_METAL_CHECK_BUF(result) \ + if (!(result)) { \ + LLAMA_LOG_ERROR("%s: failed to add buffer\n", __func__); \ + llama_free(ctx); \ + return NULL; \ + } + + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); + + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.data, ctx->buf_compute.size, 0)); + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0)); + + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.data, ctx->buf_alloc.size, 0)); +#undef LLAMA_METAL_CHECK_BUF + } +#endif + } + +#ifdef GGML_USE_MPI + ctx->ctx_mpi = ggml_mpi_init(); + + if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { + // Enter a blocking eval loop with dummy input, letting rank=0 drive the process + const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); + while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; + llama_backend_free(); + exit(1); + } +#endif + + return ctx; +} + +struct llama_context * llama_init_from_file( + const char * path_model, + struct llama_context_params params) { + struct llama_model * model = llama_load_model_from_file(path_model, params); + if (!model) { + return nullptr; + } + + struct llama_context * ctx = llama_new_context_with_model(model, params); + ctx->model_owner = true; + + return ctx; +} + +void llama_free(struct llama_context * ctx) { + delete ctx; +} + +int llama_n_vocab(const struct llama_context * ctx) { + return llama_model_n_vocab(&ctx->model); +} + +int llama_n_ctx(const struct llama_context * ctx) { + return llama_model_n_ctx(&ctx->model); +} + +int llama_n_ctx_train(const struct llama_context * ctx) { + return llama_model_n_ctx_train(&ctx->model); +} + +int llama_n_embd(const struct llama_context * ctx) { + return llama_model_n_embd(&ctx->model); +} + +enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) { + return ctx->model.vocab.type; +} + +int llama_model_n_vocab(const struct llama_model * model) { + return model->vocab.id_to_token.size(); +} + +int llama_model_n_ctx(const struct llama_model * model) { + return model->hparams.n_ctx; +} + +int llama_model_n_ctx_train(const struct llama_model * model) { + return model->hparams.n_ctx_train; +} + +int llama_model_n_embd(const struct llama_model * model) { + return model->hparams.n_embd; +} + +int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { + return snprintf(buf, buf_size, "%s %s %s", + model->name.c_str(), + llama_model_type_name(model->type), + llama_model_ftype_name(model->ftype).c_str()); +} + +uint64_t llama_model_size(const struct llama_model * model) { + uint64_t size = 0; + for (const auto & it : model->tensors_by_name) { + size += ggml_nbytes(it.second); + } + return size; +} + +uint64_t llama_model_n_params(const struct llama_model * model) { + uint64_t nparams = 0; + for (const auto & it : model->tensors_by_name) { + nparams += ggml_nelements(it.second); + } + return nparams; +} + +int llama_model_quantize( + const char * fname_inp, + const char * fname_out, + const llama_model_quantize_params * params) { + try { + llama_model_quantize_internal(fname_inp, fname_out, params); + return 0; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what()); + return 1; + } +} + +int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) { + try { + return llama_apply_lora_from_file_internal(ctx->model, path_lora, path_base_model, n_threads); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); + return 1; + } +} + +int llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, const char * path_base_model, int n_threads) { + try { + return llama_apply_lora_from_file_internal(*model, path_lora, path_base_model, n_threads); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); + return 1; + } +} + +int llama_get_kv_cache_token_count(const struct llama_context * ctx) { + return ctx->kv_self.n; +} + +#define LLAMA_MAX_RNG_STATE (64*1024) + +void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + seed = time(NULL); + } + ctx->rng.seed(seed); +} + +// Returns the *maximum* size of the state +size_t llama_get_state_size(const struct llama_context * ctx) { + // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. + // for reference, std::mt19937(1337) serializes to 6701 bytes. + const size_t s_rng_size = sizeof(size_t); + const size_t s_rng = LLAMA_MAX_RNG_STATE; + const size_t s_logits_capacity = sizeof(size_t); + const size_t s_logits_size = sizeof(size_t); + const size_t s_logits = ctx->logits.capacity() * sizeof(float); + const size_t s_embedding_size = sizeof(size_t); + const size_t s_embedding = ctx->embedding.size() * sizeof(float); + const size_t s_kv_size = sizeof(size_t); + const size_t s_kv_ntok = sizeof(int); + const size_t s_kv = ctx->kv_self.buf.size; + + const size_t s_total = ( + + s_rng_size + + s_rng + + s_logits_capacity + + s_logits_size + + s_logits + + s_embedding_size + + s_embedding + + s_kv_size + + s_kv_ntok + + s_kv + ); + + return s_total; +} + +// llama_context_data +struct llama_data_context { + virtual void write(const void * src, size_t size) = 0; + virtual size_t get_size_written() = 0; + virtual ~llama_data_context() = default; +}; + +struct llama_data_buffer_context : llama_data_context { + uint8_t * ptr; + size_t size_written = 0; + + llama_data_buffer_context(uint8_t * p) : ptr(p) {} + + void write(const void * src, size_t size) override { + memcpy(ptr, src, size); + ptr += size; + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } +}; + +struct llama_data_file_context : llama_data_context { + llama_file * file; + size_t size_written = 0; + + llama_data_file_context(llama_file * f) : file(f) {} + + void write(const void * src, size_t size) override { + file->write_raw(src, size); + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } +}; + +/** copy state data into either a buffer or file depending on the passed in context + * + * file context: + * llama_file file("/path", "wb"); + * llama_data_file_context data_ctx(&file); + * llama_copy_state_data(ctx, &data_ctx); + * + * buffer context: + * std::vector buf(max_size, 0); + * llama_data_buffer_context data_ctx(&buf.data()); + * llama_copy_state_data(ctx, &data_ctx); + * +*/ +void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { + // copy rng + { + std::stringstream rng_ss; + rng_ss << ctx->rng; + + const size_t rng_size = rng_ss.str().size(); + char rng_buf[LLAMA_MAX_RNG_STATE]; + + memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE); + memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); + + data_ctx->write(&rng_size, sizeof(rng_size)); + data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE); + } + + // copy logits + { + const size_t logits_cap = ctx->logits.capacity(); + const size_t logits_size = ctx->logits.size(); + + data_ctx->write(&logits_cap, sizeof(logits_cap)); + data_ctx->write(&logits_size, sizeof(logits_size)); + + if (logits_size) { + data_ctx->write(ctx->logits.data(), logits_size * sizeof(float)); + } + + // If there is a gap between the size and the capacity, write padding + size_t padding_size = (logits_cap - logits_size) * sizeof(float); + if (padding_size > 0) { + std::vector padding(padding_size, 0); // Create a buffer filled with zeros + data_ctx->write(padding.data(), padding_size); + } + } + + // copy embeddings + { + const size_t embedding_size = ctx->embedding.size(); + + data_ctx->write(&embedding_size, sizeof(embedding_size)); + + if (embedding_size) { + data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float)); + } + } + + // copy kv cache + { + const auto & kv_self = ctx->kv_self; + const auto & hparams = ctx->model.hparams; + const int n_layer = hparams.n_layer; + const int n_embd = hparams.n_embd_gqa(); + const int n_ctx = hparams.n_ctx; + + const size_t kv_size = kv_self.buf.size; + const int kv_ntok = llama_get_kv_cache_token_count(ctx); + + data_ctx->write(&kv_size, sizeof(kv_size)); + data_ctx->write(&kv_ntok, sizeof(kv_ntok)); + + if (kv_size) { + const size_t elt_size = ggml_element_size(kv_self.k); + + ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); + ggml_cgraph gf{}; + + ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); + std::vector kout3d_data(ggml_nbytes(kout3d), 0); + kout3d->data = kout3d_data.data(); + + ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); + std::vector vout3d_data(ggml_nbytes(vout3d), 0); + vout3d->data = vout3d_data.data(); + + ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, + n_embd, kv_ntok, n_layer, + elt_size*n_embd, elt_size*n_embd*n_ctx, 0); + + ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, + kv_ntok, n_embd, n_layer, + elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); + + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); + ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); + + ggml_free(cpy_ctx); + + // our data is now in the kout3d_data and vout3d_data buffers + // write them to file + data_ctx->write(kout3d_data.data(), kout3d_data.size()); + data_ctx->write(vout3d_data.data(), vout3d_data.size()); + } + } +} + +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { + llama_data_buffer_context data_ctx(dst); + llama_copy_state_data_internal(ctx, &data_ctx); + + return data_ctx.get_size_written(); +} + +// Sets the state reading from the specified source address +size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { + uint8_t * inp = src; + + // set rng + { + size_t rng_size; + char rng_buf[LLAMA_MAX_RNG_STATE]; + + memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size); + memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE; + + std::stringstream rng_ss; + rng_ss.str(std::string(&rng_buf[0], rng_size)); + rng_ss >> ctx->rng; + + GGML_ASSERT(!rng_ss.fail()); + } + + // set logits + { + size_t logits_cap; + size_t logits_size; + + memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap); + memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size); + + GGML_ASSERT(ctx->logits.capacity() == logits_cap); + + if (logits_size) { + ctx->logits.resize(logits_size); + memcpy(ctx->logits.data(), inp, logits_size * sizeof(float)); + } + + inp += logits_cap * sizeof(float); + } + + // set embeddings + { + size_t embedding_size; + + memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size); + + GGML_ASSERT(ctx->embedding.capacity() == embedding_size); + + if (embedding_size) { + memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float)); + inp += embedding_size * sizeof(float); + } + } + + // set kv cache + { + const auto & kv_self = ctx->kv_self; + const auto & hparams = ctx->model.hparams; + const int n_layer = hparams.n_layer; + const int n_embd = hparams.n_embd_gqa(); + const int n_ctx = hparams.n_ctx; + + size_t kv_size; + int kv_ntok; + + memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); + memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok); + + if (kv_size) { + GGML_ASSERT(kv_self.buf.size == kv_size); + + const size_t elt_size = ggml_element_size(kv_self.k); + + ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); + ggml_cgraph gf{}; + + ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); + kin3d->data = (void *) inp; + inp += ggml_nbytes(kin3d); + + ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); + vin3d->data = (void *) inp; + inp += ggml_nbytes(vin3d); + + ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, + n_embd, kv_ntok, n_layer, + elt_size*n_embd, elt_size*n_embd*n_ctx, 0); + + ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, + kv_ntok, n_embd, n_layer, + elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); + + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); + ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); + ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); + + ggml_free(cpy_ctx); + } + + ctx->kv_self.n = kv_ntok; + } + + const size_t nread = inp - src; + const size_t max_size = llama_get_state_size(ctx); + + GGML_ASSERT(nread <= max_size); + + return nread; +} + +static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(path_session, "rb"); + + // sanity checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { + LLAMA_LOG_ERROR("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); + return false; + } + + llama_hparams session_hparams; + file.read_raw(&session_hparams, sizeof(llama_hparams)); + + if (session_hparams != ctx->model.hparams) { + LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__); + return false; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + return false; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t n_state_size_cur = file.size - file.tell(); + const size_t n_state_size_max = llama_get_state_size(ctx); + + if (n_state_size_cur > n_state_size_max) { + LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur); + return false; + } + + std::vector state_data(n_state_size_max); + file.read_raw(state_data.data(), n_state_size_cur); + + llama_set_state_data(ctx, state_data.data()); + } + + return true; +} + +bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + try { + return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("error loading session file: %s\n", err.what()); + return false; + } +} + +bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + llama_file file(path_session, "wb"); + + file.write_u32(LLAMA_SESSION_MAGIC); + file.write_u32(LLAMA_SESSION_VERSION); + + file.write_raw(&ctx->model.hparams, sizeof(llama_hparams)); + + // save the prompt + file.write_u32((uint32_t) n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_data_file_context data_ctx(&file); + llama_copy_state_data_internal(ctx, &data_ctx); + + return true; +} + +int llama_eval( + struct llama_context * ctx, + const llama_token * tokens, + int n_tokens, + int n_past, + int n_threads) { + if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { + LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); + return 1; + } + + // get a more accurate load time, upon first eval + // TODO: fix this + if (!ctx->has_evaluated_once) { + ctx->t_load_us = ggml_time_us() - ctx->t_start_us; + ctx->has_evaluated_once = true; + } + + return 0; +} + +int llama_eval_embd( + struct llama_context * ctx, + const float * embd, + int n_tokens, + int n_past, + int n_threads) { + if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) { + LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); + return 1; + } + + // get a more accurate load time, upon first eval + // TODO: fix this + if (!ctx->has_evaluated_once) { + ctx->t_load_us = ggml_time_us() - ctx->t_start_us; + ctx->has_evaluated_once = true; + } + + return 0; +} + +int llama_eval_export(struct llama_context * ctx, const char * fname) { + const int n_batch = 1; + const int n_ctx = 512 - n_batch; + + const std::vector tmp(n_batch, llama_token_bos(ctx)); + + if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) { + LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); + return 1; + } + + return 0; +} + +float * llama_get_logits(struct llama_context * ctx) { + return ctx->logits.data(); +} + +float * llama_get_embeddings(struct llama_context * ctx) { + return ctx->embedding.data(); +} + +const char * llama_token_get_text(const struct llama_context * ctx, llama_token token) { + return ctx->model.vocab.id_to_token[token].text.c_str(); +} + +float llama_token_get_score(const struct llama_context * ctx, llama_token token) { + return ctx->model.vocab.id_to_token[token].score; +} + +llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token) { + return ctx->model.vocab.id_to_token[token].type; +} + +llama_token llama_token_bos(const struct llama_context * ctx) { + return ctx->model.vocab.special_bos_id; +} + +llama_token llama_token_eos(const struct llama_context * ctx) { + return ctx->model.vocab.special_eos_id; +} + +llama_token llama_token_nl(const struct llama_context * ctx) { + return ctx->model.vocab.linefeed_id; +} + +int llama_tokenize( + struct llama_context * ctx, + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos) { + return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos); +} + +int llama_tokenize_with_model( + const struct llama_model * model, + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos) { + auto res = llama_tokenize_internal(model->vocab, text, add_bos); + + if (n_max_tokens < (int) res.size()) { + LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); + return -((int) res.size()); + } + + for (size_t i = 0; i < res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +int llama_token_to_piece(const struct llama_context * ctx, llama_token token, char * buf, int length) { + return llama_token_to_piece_with_model(&ctx->model, token, buf, length); +} + +// does not write null-terminator to buf +int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) { + if (0 <= token && token < llama_model_n_vocab(model)) { + if (llama_is_normal_token(model->vocab, token)) { + std::string result = model->vocab.id_to_token[token].text; + if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) { + llama_unescape_whitespace(result); + } + if (length < (int) result.length()) { + return -result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT + if (length < 3) { + return -3; + } + buf[0] = '\xe2'; + buf[1] = '\x96'; + buf[2] = '\x85'; + return 3; + } else if (llama_is_control_token(model->vocab, token)) { + ; + } else if (llama_is_byte_token(model->vocab, token)) { + if (length < 1) { + return -1; + } + buf[0] = llama_token_to_byte(model->vocab, token); + return 1; + } + } + return 0; +} + +struct llama_timings llama_get_timings(struct llama_context * ctx) { + struct llama_timings result = { + /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, + /*.t_end_ms =*/ 1.00 * ggml_time_ms(), + /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, + /*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us, + /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, + /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, + + /*.n_sample =*/ std::max(1, ctx->n_sample), + /*.n_p_eval =*/ std::max(1, ctx->n_p_eval), + /*.n_eval =*/ std::max(1, ctx->n_eval), + }; + + return result; +} + +void llama_print_timings(struct llama_context * ctx) { + const llama_timings timings = llama_get_timings(ctx); + + LLAMA_LOG_INFO("\n"); + LLAMA_LOG_INFO("%s: load time = %8.2f ms\n", __func__, timings.t_load_ms); + LLAMA_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample); + LLAMA_LOG_INFO("%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); + LLAMA_LOG_INFO("%s: eval time = %8.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval); + LLAMA_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (timings.t_end_ms - timings.t_start_ms)); +} + +void llama_reset_timings(struct llama_context * ctx) { + ctx->t_start_us = ggml_time_us(); + ctx->t_sample_us = ctx->n_sample = 0; + ctx->t_eval_us = ctx->n_eval = 0; + ctx->t_p_eval_us = ctx->n_p_eval = 0; +} + +const char * llama_print_system_info(void) { + static std::string s; + + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | "; + s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + + return s.c_str(); +} + +void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { + fprintf(stream, "\n"); + fprintf(stream, "###########\n"); + fprintf(stream, "# Timings #\n"); + fprintf(stream, "###########\n"); + fprintf(stream, "\n"); + + fprintf(stream, "mst_eval: %.2f # ms / token during generation\n", + 1.0e-3 * ctx->t_eval_us / ctx->n_eval); + fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", + 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); + fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n", + 1.0e-3 * ctx->t_sample_us / ctx->n_sample); + fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); + fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); + fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample); + fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); + fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); + fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); + fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->t_sample_us); + fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", + 1.0e6 * ctx->n_eval / ctx->t_eval_us); + fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", + 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); + fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n", + 1.0e6 * ctx->n_sample / ctx->t_sample_us); +} + +// For internal test use +const std::vector>& llama_internal_get_tensor_map(struct llama_context * ctx) { + return ctx->model.tensors_by_name; +} + +void llama_log_set(llama_log_callback log_callback, void * user_data) { + g_state.log_callback = log_callback ? log_callback : llama_log_callback_default; + g_state.log_callback_user_data = user_data; +} + +static void llama_log_internal_v(llama_log_level level, const char * format, va_list args) { + va_list args_copy; + va_copy(args_copy, args); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + g_state.log_callback(level, buffer, g_state.log_callback_user_data); + } else { + char* buffer2 = new char[len+1]; + vsnprintf(buffer2, len+1, format, args_copy); + buffer2[len] = 0; + g_state.log_callback(level, buffer2, g_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args_copy); +} + +static void llama_log_internal(llama_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + llama_log_internal_v(level, format, args); + va_end(args); +} + +static void llama_log_callback_default(llama_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.h b/plugins/wasi_nn/thirdparty/ggml/llama.h new file mode 100644 index 00000000..37975beb --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/llama.h @@ -0,0 +1,547 @@ +#ifndef LLAMA_H +#define LLAMA_H + +#include "ggml.h" +#ifdef GGML_USE_CUBLAS +#include "ggml-cuda.h" +#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES +#else +#define LLAMA_MAX_DEVICES 1 +#endif // GGML_USE_CUBLAS +#include +#include +#include +#include + +#ifdef LLAMA_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef LLAMA_BUILD +# define LLAMA_API __declspec(dllexport) +# else +# define LLAMA_API __declspec(dllimport) +# endif +# else +# define LLAMA_API __attribute__ ((visibility ("default"))) +# endif +#else +# define LLAMA_API +#endif + +#ifdef __GNUC__ +# define DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define DEPRECATED(func, hint) func +#endif + +#define LLAMA_DEFAULT_SEED 0xFFFFFFFF + +#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' + +#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN +#define LLAMA_SESSION_VERSION 1 + +#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) +// Defined when llama.cpp is compiled with support for offloading model layers to GPU. +#define LLAMA_SUPPORTS_GPU_OFFLOAD +#endif + +#ifdef __cplusplus +extern "C" { +#endif + + // + // C interface + // + // TODO: show sample usage + // + + struct llama_model; + struct llama_context; + + typedef int llama_token; + + enum llama_log_level { + LLAMA_LOG_LEVEL_ERROR = 2, + LLAMA_LOG_LEVEL_WARN = 3, + LLAMA_LOG_LEVEL_INFO = 4 + }; + + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece + LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding + }; + + enum llama_token_type { + LLAMA_TOKEN_TYPE_UNDEFINED = 0, + LLAMA_TOKEN_TYPE_NORMAL = 1, + LLAMA_TOKEN_TYPE_UNKNOWN = 2, + LLAMA_TOKEN_TYPE_CONTROL = 3, + LLAMA_TOKEN_TYPE_USER_DEFINED = 4, + LLAMA_TOKEN_TYPE_UNUSED = 5, + LLAMA_TOKEN_TYPE_BYTE = 6, + }; + + // model file types + enum llama_ftype { + LLAMA_FTYPE_ALL_F32 = 0, + LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed + // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed + LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors + + LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file + }; + + typedef struct llama_token_data { + llama_token id; // token id + float logit; // log-odds of the token + float p; // probability of the token + } llama_token_data; + + typedef struct llama_token_data_array { + llama_token_data * data; + size_t size; + bool sorted; + } llama_token_data_array; + + typedef void (*llama_progress_callback)(float progress, void *ctx); + + struct llama_context_params { + uint32_t seed; // RNG seed, -1 for random + int32_t n_ctx; // text context + int32_t n_batch; // prompt processing batch size + int32_t n_gpu_layers; // number of layers to store in VRAM + int32_t main_gpu; // the GPU that is used for scratch and small tensors + + const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) + + // ref: https://github.com/ggerganov/llama.cpp/pull/2054 + float rope_freq_base; // RoPE base frequency + float rope_freq_scale; // RoPE frequency scaling factor + + // called with a progress value between 0 and 1, pass NULL to disable + llama_progress_callback progress_callback; + // context pointer passed to the progress callback + void * progress_callback_user_data; + + // Keep the booleans together to avoid misalignment during copy-by-value. + bool low_vram; // if true, reduce VRAM usage at the cost of performance + bool mul_mat_q; // if true, use experimental mul_mat_q kernels + bool f16_kv; // use fp16 for KV cache + bool logits_all; // the llama_eval() call computes all logits, not just the last one + bool vocab_only; // only load the vocabulary, no weights + bool use_mmap; // use mmap if possible + bool use_mlock; // force system to keep model in RAM + bool embedding; // embedding mode only + }; + + // Signature for logging events + // Note that text includes the new line character at the end for most events. + // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it + // if it exists. + // It might not exist for progress report where '.' is output repeatedly. + typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data); + + // model quantization parameters + typedef struct llama_model_quantize_params { + int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + enum llama_ftype ftype; // quantize to this llama_ftype + bool allow_requantize; // allow quantizing non-f32/f16 tensors + bool quantize_output_tensor; // quantize output.weight + bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored + } llama_model_quantize_params; + + // grammar types + struct llama_grammar; + + // grammar element type + enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + }; + + typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID + } llama_grammar_element; + + // performance timing information + struct llama_timings { + double t_start_ms; + double t_end_ms; + double t_load_ms; + double t_sample_ms; + double t_p_eval_ms; + double t_eval_ms; + + int32_t n_sample; + int32_t n_p_eval; + int32_t n_eval; + }; + + LLAMA_API struct llama_context_params llama_context_default_params(void); + LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); + + // Initialize the llama + ggml backend + // If numa is true, use NUMA optimizations + // Call once at the start of the program + LLAMA_API void llama_backend_init(bool numa); + + // Call once at the end of the program - currently only used for MPI + LLAMA_API void llama_backend_free(void); + + LLAMA_API struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_context_params params); + + LLAMA_API void llama_free_model(struct llama_model * model); + + LLAMA_API struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params); + + // Frees all allocated memory + LLAMA_API void llama_free(struct llama_context * ctx); + + LLAMA_API int64_t llama_time_us(void); + + LLAMA_API int llama_max_devices (void); + LLAMA_API bool llama_mmap_supported (void); + LLAMA_API bool llama_mlock_supported(void); + + LLAMA_API int llama_n_vocab (const struct llama_context * ctx); + LLAMA_API int llama_n_ctx (const struct llama_context * ctx); + LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx); + LLAMA_API int llama_n_embd (const struct llama_context * ctx); + + LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx); + + LLAMA_API int llama_model_n_vocab (const struct llama_model * model); + LLAMA_API int llama_model_n_ctx (const struct llama_model * model); + LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model); + LLAMA_API int llama_model_n_embd (const struct llama_model * model); + + // Get a string describing the model type + LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); + // Returns the total size of all the tensors in the model in bytes + LLAMA_API uint64_t llama_model_size(const struct llama_model * model); + // Returns the total number of parameters in the model + LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); + + // Returns 0 on success + LLAMA_API int llama_model_quantize( + const char * fname_inp, + const char * fname_out, + const llama_model_quantize_params * params); + + // Apply a LoRA adapter to a loaded model + // path_base_model is the path to a higher quality model to use as a base for + // the layers modified by the adapter. Can be NULL to use the current loaded model. + // The model needs to be reloaded before applying a new adapter, otherwise the adapter + // will be applied on top of the previous one + // Returns 0 on success + LLAMA_API DEPRECATED(int llama_apply_lora_from_file( + struct llama_context * ctx, + const char * path_lora, + const char * path_base_model, + int n_threads), + "please use llama_model_apply_lora_from_file instead"); + + LLAMA_API int llama_model_apply_lora_from_file( + const struct llama_model * model, + const char * path_lora, + const char * path_base_model, + int n_threads); + + // Returns the number of tokens in the KV cache + LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx); + + // Sets the current rng seed. + LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); + + // Returns the maximum size in bytes of the state (rng, logits, embedding + // and kv_cache) - will often be smaller after compacting tokens + LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); + + // Copies the state to the specified destination address. + // Destination needs to have allocated enough memory. + // Returns the number of bytes copied + LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst); + + // Set the state reading from the specified address + // Returns the number of bytes read + LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src); + + // Save/load session file + LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out); + LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count); + + // Run the llama inference to obtain the logits and probabilities for the next token. + // tokens + n_tokens is the provided batch of new tokens to process + // n_past is the number of tokens to use from previous eval calls + // Returns 0 on success + LLAMA_API int llama_eval( + struct llama_context * ctx, + const llama_token * tokens, + int n_tokens, + int n_past, + int n_threads); + + // Same as llama_eval, but use float matrix input directly. + LLAMA_API int llama_eval_embd( + struct llama_context * ctx, + const float * embd, + int n_tokens, + int n_past, + int n_threads); + + // Export a static computation graph for context of 511 and batch size of 1 + // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these + // parameters here to keep things simple + // IMPORTANT: do not use for anything else other than debugging and testing! + LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname); + + // Token logits obtained from the last call to llama_eval() + // The logits for the last token are stored in the last row + // Can be mutated in order to change the probabilities of the next token + // Rows: n_tokens + // Cols: n_vocab + LLAMA_API float * llama_get_logits(struct llama_context * ctx); + + // Get the embeddings for the input + // shape: [n_embd] (1-dimensional) + LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); + + // + // Vocab + // + + LLAMA_API const char * llama_token_get_text(const struct llama_context * ctx, llama_token token); + + LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token); + + LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token); + + // Special tokens + LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence + LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence + LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line + + // + // Tokenization + // + + // Convert the provided text into tokens. + // The tokens pointer must be large enough to hold the resulting tokens. + // Returns the number of tokens on success, no more than n_max_tokens + // Returns a negative number on failure - the number of tokens that would have been returned + LLAMA_API int llama_tokenize( + struct llama_context * ctx, + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos); + + LLAMA_API int llama_tokenize_with_model( + const struct llama_model * model, + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos); + + // Token Id -> Piece. + // Uses the vocabulary in the provided context. + // Does not write null terminator to the buffer. + // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. + LLAMA_API int llama_token_to_piece( + const struct llama_context * ctx, + llama_token token, + char * buf, + int length); + + LLAMA_API int llama_token_to_piece_with_model( + const struct llama_model * model, + llama_token token, + char * buf, + int length); + + // + // Grammar + // + + LLAMA_API struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + + LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); + + LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar); + + // + // Sampling functions + // + + /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. + LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty); + + /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + + /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. + /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. + /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. + LLAMA_API void llama_sample_classifier_free_guidance( + struct llama_context * ctx, + llama_token_data_array * candidates, + struct llama_context * guidance_ctx, + float scale); + + /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. + LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); + + /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep); + + /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); + + /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. + LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep); + + /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. + LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); + LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); + + /// @details Apply constraints from grammar + LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); + + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu); + + /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); + + /// @details Selects the token with the highest probability. + LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates); + + /// @details Randomly selects a token from the candidates based on their probabilities. + LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); + + /// @details Accepts the sampled token into the grammar + LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); + + // + // Beam search + // + + struct llama_beam_view { + const llama_token * tokens; + size_t n_tokens; + float p; // Cumulative beam probability (renormalized relative to all beams) + bool eob; // Callback should set this to true when a beam is at end-of-beam. + }; + + // Passed to beam_search_callback function. + // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams + // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks. + // These pointers are valid only during the synchronous callback, so should not be saved. + struct llama_beams_state { + struct llama_beam_view * beam_views; + size_t n_beams; // Number of elements in beam_views[]. + size_t common_prefix_length; // Current max length of prefix tokens shared by all beams. + bool last_call; // True iff this is the last callback invocation. + }; + + // Type of pointer to the beam_search_callback function. + // void* callback_data is any custom data passed to llama_beam_search, that is subsequently + // passed back to beam_search_callback. This avoids having to use global variables in the callback. + typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state); + + /// @details Deterministically returns entire sentence constructed by a beam search. + /// @param ctx Pointer to the llama_context. + /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state. + /// @param callback_data A pointer that is simply passed back to callback. + /// @param n_beams Number of beams to use. + /// @param n_past Number of tokens already evaluated. + /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier. + /// @param n_threads Number of threads as passed to llama_eval(). + LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads); + + // Performance information + LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); + LLAMA_API void llama_print_timings(struct llama_context * ctx); + LLAMA_API void llama_reset_timings(struct llama_context * ctx); + + // Print system information + LLAMA_API const char * llama_print_system_info(void); + + // Set callback for all future logging events. + // If this is not called, or NULL is supplied, everything is output on stderr. + LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data); + + LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx); + +#ifdef __cplusplus +} +#endif + +// Internal API to be implemented by llama.cpp and used by tests/benchmarks only +#ifdef LLAMA_API_INTERNAL + +#include +#include + +struct ggml_tensor; + +const std::vector>& llama_internal_get_tensor_map(struct llama_context * ctx); + +#endif // LLAMA_API_INTERNAL + +#endif // LLAMA_H diff --git a/plugins/wasi_nn/thirdparty/ggml/log.h b/plugins/wasi_nn/thirdparty/ggml/log.h new file mode 100644 index 00000000..18f3b976 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/log.h @@ -0,0 +1,643 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +// -------------------------------- +// +// Basic usage: +// +// -------- +// +// The LOG() and LOG_TEE() macros are ready to go by default +// they do not require any initialization. +// +// LOGLN() and LOG_TEELN() are variants which automatically +// include \n character at the end of the log string. +// +// LOG() behaves exactly like printf, by default writing to a logfile. +// LOG_TEE() additionally, prints to the screen too ( mimics Unix tee command ). +// +// Default logfile is named +// "llama..log" +// Default LOG_TEE() secondary output target is +// stderr +// +// Logs can be dynamically disabled or enabled using functions: +// log_disable() +// and +// log_enable() +// +// A log target can be changed with: +// log_set_target( string ) +// creating and opening, or re-opening a file by string filename +// or +// log_set_target( FILE* ) +// allowing to point at stderr, stdout, or any valid FILE* file handler. +// +// -------- +// +// End of Basic usage. +// +// -------------------------------- + +// Specifies a log target. +// default uses log_handler() with "llama.log" log file +// this can be changed, by defining LOG_TARGET +// like so: +// +// #define LOG_TARGET (a valid FILE*) +// #include "log.h" +// +// or it can be simply redirected to stdout or stderr +// like so: +// +// #define LOG_TARGET stderr +// #include "log.h" +// +// The log target can also be redirected to a diffrent function +// like so: +// +// #define LOG_TARGET log_handler_diffrent() +// #include "log.h" +// +// FILE* log_handler_diffrent() +// { +// return stderr; +// } +// +// or: +// +// #define LOG_TARGET log_handler_another_one("somelog.log") +// #include "log.h" +// +// FILE* log_handler_another_one(char*filename) +// { +// static FILE* logfile = nullptr; +// (...) +// if( !logfile ) +// { +// fopen(...) +// } +// (...) +// return logfile +// } +// +#ifndef LOG_TARGET + #define LOG_TARGET log_handler() +#endif + +#ifndef LOG_TEE_TARGET + #define LOG_TEE_TARGET stderr +#endif + +// Utility to obtain "pid" like unique process id and use it when creating log files. +inline std::string log_get_pid() +{ + static std::string pid; + if (pid.empty()) + { + // std::this_thread::get_id() is the most portable way of obtaining a "process id" + // it's not the same as "pid" but is unique enough to solve multiple instances + // trying to write to the same log. + std::stringstream ss; + ss << std::this_thread::get_id(); + pid = ss.str(); + } + + return pid; +} + +// Utility function for generating log file names with unique id based on thread id. +// invocation with log_filename_generator( "llama", "log" ) creates a string "llama..log" +// where the number is a runtime id of the current thread. + +#define log_filename_generator(log_file_basename, log_file_extension) log_filename_generator_impl(log_file_basename, log_file_extension) + +// INTERNAL, DO NOT USE +inline std::string log_filename_generator_impl(const std::string & log_file_basename, const std::string & log_file_extension) +{ + std::stringstream buf; + + buf << log_file_basename; + buf << "."; + buf << log_get_pid(); + buf << "."; + buf << log_file_extension; + + return buf.str(); +} + +#ifndef LOG_DEFAULT_FILE_NAME + #define LOG_DEFAULT_FILE_NAME log_filename_generator("llama", "log") +#endif + +// Utility for turning #define values into string literals +// so we can have a define for stderr and +// we can print "stderr" instead of literal stderr, etc. +#define LOG_STRINGIZE1(s) #s +#define LOG_STRINGIZE(s) LOG_STRINGIZE1(s) + +#define LOG_TEE_TARGET_STRING LOG_STRINGIZE(LOG_TEE_TARGET) + +// Allows disabling timestamps. +// in order to disable, define LOG_NO_TIMESTAMPS +// like so: +// +// #define LOG_NO_TIMESTAMPS +// #include "log.h" +// +#ifndef LOG_NO_TIMESTAMPS + #ifndef _MSC_VER + #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] " + #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() + #else + #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] " + #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() + #endif +#else + #define LOG_TIMESTAMP_FMT "%s" + #define LOG_TIMESTAMP_VAL ,"" +#endif + +#ifdef LOG_TEE_TIMESTAMPS + #ifndef _MSC_VER + #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] " + #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() + #else + #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] " + #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() + #endif +#else + #define LOG_TEE_TIMESTAMP_FMT "%s" + #define LOG_TEE_TIMESTAMP_VAL ,"" +#endif + +// Allows disabling file/line/function prefix +// in order to disable, define LOG_NO_FILE_LINE_FUNCTION +// like so: +// +// #define LOG_NO_FILE_LINE_FUNCTION +// #include "log.h" +// +#ifndef LOG_NO_FILE_LINE_FUNCTION + #ifndef _MSC_VER + #define LOG_FLF_FMT "[%24s:%5d][%24s] " + #define LOG_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ + #else + #define LOG_FLF_FMT "[%24s:%5ld][%24s] " + #define LOG_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ + #endif +#else + #define LOG_FLF_FMT "%s" + #define LOG_FLF_VAL ,"" +#endif + +#ifdef LOG_TEE_FILE_LINE_FUNCTION + #ifndef _MSC_VER + #define LOG_TEE_FLF_FMT "[%24s:%5d][%24s] " + #define LOG_TEE_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ + #else + #define LOG_TEE_FLF_FMT "[%24s:%5ld][%24s] " + #define LOG_TEE_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ + #endif +#else + #define LOG_TEE_FLF_FMT "%s" + #define LOG_TEE_FLF_VAL ,"" +#endif + +// Utility for synchronizing log configuration state +// since std::optional was introduced only in c++17 +enum LogTriState +{ + LogTriStateSame, + LogTriStateFalse, + LogTriStateTrue +}; + +// INTERNAL, DO NOT USE +// USE LOG() INSTEAD +// +#ifndef _MSC_VER + #define LOG_IMPL(str, ...) \ + { \ + if (LOG_TARGET != nullptr) \ + { \ + fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \ + fflush(LOG_TARGET); \ + } \ + } +#else + #define LOG_IMPL(str, ...) \ + { \ + if (LOG_TARGET != nullptr) \ + { \ + fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \ + fflush(LOG_TARGET); \ + } \ + } +#endif + +// INTERNAL, DO NOT USE +// USE LOG_TEE() INSTEAD +// +#ifndef _MSC_VER + #define LOG_TEE_IMPL(str, ...) \ + { \ + if (LOG_TARGET != nullptr) \ + { \ + fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \ + fflush(LOG_TARGET); \ + } \ + if (LOG_TARGET != nullptr && LOG_TARGET != stdout && LOG_TARGET != stderr && LOG_TEE_TARGET != nullptr) \ + { \ + fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL, __VA_ARGS__); \ + fflush(LOG_TEE_TARGET); \ + } \ + } +#else + #define LOG_TEE_IMPL(str, ...) \ + { \ + if (LOG_TARGET != nullptr) \ + { \ + fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \ + fflush(LOG_TARGET); \ + } \ + if (LOG_TARGET != nullptr && LOG_TARGET != stdout && LOG_TARGET != stderr && LOG_TEE_TARGET != nullptr) \ + { \ + fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL "", ##__VA_ARGS__); \ + fflush(LOG_TEE_TARGET); \ + } \ + } +#endif + +// The '\0' as a last argument, is a trick to bypass the silly +// "warning: ISO C++11 requires at least one argument for the "..." in a variadic macro" +// so we can have a single macro which can be called just like printf. + +// Main LOG macro. +// behaves like printf, and supports arguments the exact same way. +// +#ifndef _MSC_VER + #define LOG(...) LOG_IMPL(__VA_ARGS__, "") +#else + #define LOG(str, ...) LOG_IMPL("%s" str, "", __VA_ARGS__, "") +#endif + +// Main TEE macro. +// does the same as LOG +// and +// simultaneously writes stderr. +// +// Secondary target can be changed just like LOG_TARGET +// by defining LOG_TEE_TARGET +// +#ifndef _MSC_VER + #define LOG_TEE(...) LOG_TEE_IMPL(__VA_ARGS__, "") +#else + #define LOG_TEE(str, ...) LOG_TEE_IMPL("%s" str, "", __VA_ARGS__, "") +#endif + +// LOG macro variants with auto endline. +#ifndef _MSC_VER + #define LOGLN(...) LOG_IMPL(__VA_ARGS__, "\n") + #define LOG_TEELN(...) LOG_TEE_IMPL(__VA_ARGS__, "\n") +#else + #define LOGLN(str, ...) LOG_IMPL("%s" str, "", __VA_ARGS__, "\n") + #define LOG_TEELN(str, ...) LOG_TEE_IMPL("%s" str, "", __VA_ARGS__, "\n") +#endif + +// INTERNAL, DO NOT USE +inline FILE *log_handler1_impl(bool change = false, LogTriState disable = LogTriStateSame, const std::string & filename = LOG_DEFAULT_FILE_NAME, FILE *target = nullptr) +{ + static bool _initialized{false}; + static bool _disabled{(filename.empty() && target == nullptr)}; + static std::string log_current_filename{filename}; + static FILE *log_current_target{target}; + static FILE *logfile = nullptr; + + if (change) + { + if (disable == LogTriStateTrue) + { + // Disable primary target + _disabled = true; + } + // If previously disabled, only enable, and keep previous target + else if (disable == LogTriStateFalse) + { + _disabled = false; + } + // Otherwise, process the arguments + else if (log_current_filename != filename || log_current_target != target) + { + _initialized = false; + } + } + + if (_disabled) + { + // Log is disabled + return nullptr; + } + + if (_initialized) + { + // with fallback in case something went wrong + return logfile ? logfile : stderr; + } + + // do the (re)initialization + if (target != nullptr) + { + if (logfile != nullptr && logfile != stdout && logfile != stderr) + { + fclose(logfile); + } + + log_current_filename = LOG_DEFAULT_FILE_NAME; + log_current_target = target; + + logfile = target; + } + else + { + if (log_current_filename != filename) + { + if (logfile != nullptr && logfile != stdout && logfile != stderr) + { + fclose(logfile); + } + } + + logfile = fopen(filename.c_str(), "w"); + } + + if (!logfile) + { + // Verify whether the file was opened, otherwise fallback to stderr + logfile = stderr; + + fprintf(stderr, "Failed to open logfile '%s' with error '%s'\n", filename.c_str(), std::strerror(errno)); + fflush(stderr); + + // At this point we let the init flag be to true below, and let the target fallback to stderr + // otherwise we would repeatedly fopen() which was already unsuccessful + } + + _initialized = true; + + return logfile ? logfile : stderr; +} + +// INTERNAL, DO NOT USE +inline FILE *log_handler2_impl(bool change = false, LogTriState disable = LogTriStateSame, FILE *target = nullptr, const std::string & filename = LOG_DEFAULT_FILE_NAME) +{ + return log_handler1_impl(change, disable, filename, target); +} + +// Disables logs entirely at runtime. +// Makes LOG() and LOG_TEE() produce no output, +// untill enabled back. +#define log_disable() log_disable_impl() + +// INTERNAL, DO NOT USE +inline FILE *log_disable_impl() +{ + return log_handler1_impl(true, LogTriStateTrue); +} + +// Enables logs at runtime. +#define log_enable() log_enable_impl() + +// INTERNAL, DO NOT USE +inline FILE *log_enable_impl() +{ + return log_handler1_impl(true, LogTriStateFalse); +} + +// Sets target fir logs, either by a file name or FILE* pointer (stdout, stderr, or any valid FILE*) +#define log_set_target(target) log_set_target_impl(target) + +// INTERNAL, DO NOT USE +inline FILE *log_set_target_impl(const std::string & filename) { return log_handler1_impl(true, LogTriStateSame, filename); } +inline FILE *log_set_target_impl(FILE *target) { return log_handler2_impl(true, LogTriStateSame, target); } + +// INTERNAL, DO NOT USE +inline FILE *log_handler() { return log_handler1_impl(); } + +inline void log_test() +{ + log_disable(); + LOG("01 Hello World to nobody, because logs are disabled!\n") + log_enable(); + LOG("02 Hello World to default output, which is \"%s\" ( Yaaay, arguments! )!\n", LOG_STRINGIZE(LOG_TARGET)) + LOG_TEE("03 Hello World to **both** default output and " LOG_TEE_TARGET_STRING "!\n") + log_set_target(stderr); + LOG("04 Hello World to stderr!\n") + LOG_TEE("05 Hello World TEE with double printing to stderr prevented!\n") + log_set_target(LOG_DEFAULT_FILE_NAME); + LOG("06 Hello World to default log file!\n") + log_set_target(stdout); + LOG("07 Hello World to stdout!\n") + log_set_target(LOG_DEFAULT_FILE_NAME); + LOG("08 Hello World to default log file again!\n") + log_disable(); + LOG("09 Hello World _1_ into the void!\n") + log_enable(); + LOG("10 Hello World back from the void ( you should not see _1_ in the log or the output )!\n") + log_disable(); + log_set_target("llama.anotherlog.log"); + LOG("11 Hello World _2_ to nobody, new target was selected but logs are still disabled!\n") + log_enable(); + LOG("12 Hello World this time in a new file ( you should not see _2_ in the log or the output )?\n") + log_set_target("llama.yetanotherlog.log"); + LOG("13 Hello World this time in yet new file?\n") + log_set_target(log_filename_generator("llama_autonamed", "log")); + LOG("14 Hello World in log with generated filename!\n") +#ifdef _MSC_VER + LOG_TEE("15 Hello msvc TEE without arguments\n") + LOG_TEE("16 Hello msvc TEE with (%d)(%s) arguments\n", 1, "test") + LOG_TEELN("17 Hello msvc TEELN without arguments\n") + LOG_TEELN("18 Hello msvc TEELN with (%d)(%s) arguments\n", 1, "test") + LOG("19 Hello msvc LOG without arguments\n") + LOG("20 Hello msvc LOG with (%d)(%s) arguments\n", 1, "test") + LOGLN("21 Hello msvc LOGLN without arguments\n") + LOGLN("22 Hello msvc LOGLN with (%d)(%s) arguments\n", 1, "test") +#endif +} + +inline bool log_param_single_parse(const std::string & param) +{ + if ( param == "--log-test") + { + log_test(); + return true; + } + + if ( param == "--log-disable") + { + log_disable(); + return true; + } + + if ( param == "--log-enable") + { + log_enable(); + return true; + } + + return false; +} + +inline bool log_param_pair_parse(bool check_but_dont_parse, const std::string & param, const std::string & next = std::string()) +{ + if ( param == "--log-file") + { + if (!check_but_dont_parse) + { + log_set_target(log_filename_generator(next.empty() ? "unnamed" : next, "log")); + } + + return true; + } + + return false; +} + +inline void log_print_usage() +{ + printf("log options:\n"); + /* format + printf(" -h, --help show this help message and exit\n");*/ + /* spacing + printf("__-param----------------Description\n");*/ + printf(" --log-test Run simple logging test\n"); + printf(" --log-disable Disable trace logs\n"); + printf(" --log-enable Enable trace logs\n"); + printf(" --log-file Specify a log filename (without extension)\n"); + printf(" Log file will be tagged with unique ID and written as \"..log\"\n"); /* */ +} + +#define log_dump_cmdline(argc, argv) log_dump_cmdline_impl(argc, argv) + +// INTERNAL, DO NOT USE +inline void log_dump_cmdline_impl(int argc, char **argv) +{ + std::stringstream buf; + for (int i = 0; i < argc; ++i) + { + if (std::string(argv[i]).find(' ') != std::string::npos) + { + buf << " \"" << argv[i] <<"\""; + } + else + { + buf << " " << argv[i]; + } + } + LOGLN("Cmd:%s", buf.str().c_str()) +} + +#define log_tostr(var) log_var_to_string_impl(var).c_str() + +inline std::string log_var_to_string_impl(bool var) +{ + return var ? "true" : "false"; +} + +inline std::string log_var_to_string_impl(std::string var) +{ + return var; +} + +inline std::string log_var_to_string_impl(const std::vector & var) +{ + std::stringstream buf; + buf << "[ "; + bool first = true; + for (auto e : var) + { + if (first) + { + first = false; + } + else + { + buf << ", "; + } + buf << std::to_string(e); + } + buf << " ]"; + + return buf.str(); +} + +#define LOG_TOKENS_TOSTR_PRETTY(ctx, tokens) \ + [&tokens, &ctx]() \ + { \ + std::stringstream buf; \ + buf << "[ "; \ + \ + bool first = true; \ + for (const auto &token : tokens) \ + { \ + if (!first) \ + buf << ", "; \ + else \ + first = false; \ + \ + auto detokenized = llama_token_to_piece(ctx, token); \ + \ + detokenized.erase( \ + std::remove_if( \ + detokenized.begin(), \ + detokenized.end(), \ + [](const unsigned char c) { return !std::isprint(c); }), \ + detokenized.end()); \ + \ + buf \ + << "'" << detokenized << "'" \ + << ":" << std::to_string(token); \ + } \ + buf << " ]"; \ + \ + return buf.str(); \ + }() \ + .c_str() + +#ifdef LOG_DISABLE_LOGS + +#undef LOG +#define LOG(...) // dummy stub +#undef LOGLN +#define LOGLN(...) // dummy stub + +#undef LOG_TEE +#define LOG_TEE(...) fprintf(stderr, __VA_ARGS__); // convert to normal fprintf + +#undef LOG_TEELN +#define LOG_TEELN(...) fprintf(stderr, __VA_ARGS__); // convert to normal fprintf + +#undef LOG_DISABLE +#define LOG_DISABLE() // dummy stub + +#undef LOG_ENABLE +#define LOG_ENABLE() // dummy stub + +#undef LOG_ENABLE +#define LOG_ENABLE() // dummy stub + +#undef LOG_SET_TARGET +#define LOG_SET_TARGET(...) // dummy stub + +#undef LOG_DUMP_CMDLINE +#define LOG_DUMP_CMDLINE(...) // dummy stub + +#endif // LOG_DISABLE_LOGS From 78a8ee0a9a3796303b0c5152c3a1f839d4cd536e Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 14 Sep 2023 17:52:37 +0800 Subject: [PATCH 142/623] [WASI-NN] Remove unused ggml options Signed-off-by: dm4 --- plugins/wasi_nn/thirdparty/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/wasi_nn/thirdparty/CMakeLists.txt b/plugins/wasi_nn/thirdparty/CMakeLists.txt index e71284dd..94db3597 100644 --- a/plugins/wasi_nn/thirdparty/CMakeLists.txt +++ b/plugins/wasi_nn/thirdparty/CMakeLists.txt @@ -4,7 +4,6 @@ if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) string(TOLOWER ${WASMEDGE_PLUGIN_WASI_NN_BACKEND} BACKEND) if(BACKEND STREQUAL "ggml") - add_compile_options(-DGGML_BACKEND) add_subdirectory(ggml) endif() endif() From b11d6aae18057e0cf4cea1c9ca15af661aee5067 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 19 Sep 2023 15:24:59 +0800 Subject: [PATCH 143/623] [WASI-NN] Check LLAMA_LOG to print logs Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 15 +++++++++++++++ plugins/wasi_nn/thirdparty/ggml/common.cpp | 2 +- plugins/wasi_nn/thirdparty/ggml/llama.cpp | 5 ++++- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 0b483c16..21363965 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -6,6 +6,7 @@ #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML #include +#include #include #endif @@ -63,6 +64,10 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Store the loaded graph. GraphId = Env.NNGraph.size() - 1; + + // Disable llama log by default. + log_disable(); + return ErrNo::Success; } @@ -112,6 +117,12 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { return ErrNo::InvalidArgument; } + // Use env LLAMA_LOG=1 to enable llama log. + const char *LlamaLogEnv = std::getenv("LLAMA_LOG"); + if (LlamaLogEnv != nullptr) { + spdlog::info("llama_system_info: {}"sv, llama_print_system_info()); + } + // Output start from prompt. for (auto Id : CxtRef.LlamaInputs) { CxtRef.LlamaOutputs += llama_token_to_piece(GraphRef.LlamaContext, Id); @@ -158,6 +169,10 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { CxtRef.LlamaInputs.push_back(NewTokenId); } + if (LlamaLogEnv != nullptr) { + llama_print_timings(GraphRef.LlamaContext); + } + return ErrNo::Success; } #else diff --git a/plugins/wasi_nn/thirdparty/ggml/common.cpp b/plugins/wasi_nn/thirdparty/ggml/common.cpp index 382f0058..b78ad002 100644 --- a/plugins/wasi_nn/thirdparty/ggml/common.cpp +++ b/plugins/wasi_nn/thirdparty/ggml/common.cpp @@ -769,7 +769,7 @@ std::tuple llama_init_from_gpt_par } { - LOG("warming up the model with an empty run\n"); + // LOG("warming up the model with an empty run\n"); const std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads); diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.cpp b/plugins/wasi_nn/thirdparty/ggml/llama.cpp index 2a2a0c9c..c183ddb3 100644 --- a/plugins/wasi_nn/thirdparty/ggml/llama.cpp +++ b/plugins/wasi_nn/thirdparty/ggml/llama.cpp @@ -57,6 +57,7 @@ #include #include #include +#include #include #include #include @@ -6393,6 +6394,8 @@ static void llama_log_internal(llama_log_level level, const char * format, ...) static void llama_log_callback_default(llama_log_level level, const char * text, void * user_data) { (void) level; (void) user_data; - fputs(text, stderr); + if (std::getenv("LLAMA_LOG") != nullptr) { + fputs(text, stderr); + } fflush(stderr); } From 7c46ad5aa7b5c6ab500153c60335b977d5d4c260 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 19 Sep 2023 21:44:41 +0800 Subject: [PATCH 144/623] [WASI-NN] Clear LlamaOutputs before prediction Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 21363965..be6c41c1 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -123,10 +123,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { spdlog::info("llama_system_info: {}"sv, llama_print_system_info()); } - // Output start from prompt. - for (auto Id : CxtRef.LlamaInputs) { - CxtRef.LlamaOutputs += llama_token_to_piece(GraphRef.LlamaContext, Id); - } + // Clear the outputs. + CxtRef.LlamaOutputs = ""sv; // Main predict loop. // TODO: recompute a compressed context based on previous tokens once the From f86b128577bf1fd675f24cd80b3dd566ea39b89c Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 20 Sep 2023 16:31:06 +0800 Subject: [PATCH 145/623] [WASI-NN] Initialize llama context when setting inputs. Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index be6c41c1..bb78710f 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -52,10 +52,10 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize ggml model. gpt_params Params; - Params.model = ModelFilePath; llama_backend_init(Params.numa); - std::tie(GraphRef.LlamaModel, GraphRef.LlamaContext) = - llama_init_from_gpt_params(Params); + llama_context_params ContextParams = llama_context_default_params(); + GraphRef.LlamaModel = + llama_load_model_from_file(ModelFilePath.c_str(), ContextParams); if (GraphRef.LlamaModel == nullptr) { spdlog::error("[WASI-NN] Error: unable to init model."sv); Env.NNGraph.pop_back(); @@ -84,6 +84,13 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, const TensorData &Tensor) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + // Initialize the llama context. + llama_context_params ContextParams = llama_context_default_params(); + GraphRef.LlamaContext = + llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); + + // Set the input. std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); CxtRef.LlamaInputs = llama_tokenize(GraphRef.LlamaContext, Prompt, true); @@ -168,6 +175,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } if (LlamaLogEnv != nullptr) { + spdlog::info("llama_get_kv_cache_token_count {}"sv, + llama_get_kv_cache_token_count(GraphRef.LlamaContext)); llama_print_timings(GraphRef.LlamaContext); } From b7ab2228504f342d8f70bb885e552e5b2c4a917d Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 20 Sep 2023 16:32:25 +0800 Subject: [PATCH 146/623] [WASI-NN] Do not append [end of text] for ggml outputs Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index bb78710f..16a5e02e 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -162,7 +162,6 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { NewTokenId = llama_sample_token_greedy(GraphRef.LlamaContext, &CandidatesP); if (NewTokenId == llama_token_eos(GraphRef.LlamaContext)) { - CxtRef.LlamaOutputs += "[end of text]"sv; break; } From a1fed2d995dfa1831120494de420e36f91b99413 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 21 Sep 2023 14:28:48 +0800 Subject: [PATCH 147/623] [WASI-NN] Update ggml tests Signed-off-by: dm4 --- test/plugins/wasi_nn/wasi_nn.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 3213d35d..e5bec1ff 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1477,11 +1477,6 @@ TEST(WasiNNTest, GGMLBackend) { // Should output more than 100 bytes. auto BytesWritten = *MemInst.getPointer(BuilderPtr); EXPECT_GE(BytesWritten, 100); - // Output should begin with the prompt. (+1 to skip bos token) - const auto Output = MemInst.getSpan(StorePtr, 100); - EXPECT_EQ( - std::string(Output.begin() + 1, Output.begin() + 1 + Prompt.size()), - Prompt); } } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML From f0b4392e553673d3961d160e3a40e7ab921f86d2 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 21 Sep 2023 15:43:35 +0800 Subject: [PATCH 148/623] [WASI-NN] Add ggml backend log prefix Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 16a5e02e..94aa8dc5 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -16,15 +16,16 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, Device Device, uint32_t &GraphId) noexcept { // The graph builder length must be 1. if (Builders.size() != 1) { - spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1"sv, - Builders.size()); + spdlog::error( + "[WASI-NN] GGML backend: Wrong GraphBuilder Length {:d}, expect 1"sv, + Builders.size()); return ErrNo::InvalidArgument; } // Setup Graph Device if (Device != Device::CPU) { spdlog::error( - "[WASI-NN] ggml backend only support CPU target currently."sv); + "[WASI-NN] GGML backend: Only support CPU target currently."sv); return ErrNo::InvalidArgument; } @@ -37,10 +38,11 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, std::string ModelFilePath("ggml-model.bin"sv); std::ofstream TempFile(ModelFilePath); if (!TempFile) { - spdlog::error("[WASI-NN] Failed to create the temporary file. Currently, " - "our workaround involves creating a temporary model file " - "named \"ggml-model.bin\" and passing this filename as a " - "parameter to the ggml llama library."sv); + spdlog::error( + "[WASI-NN] GGML backend: Failed to create the temporary file. " + "Currently, our workaround involves creating a temporary model " + "file named \"ggml-model.bin\" and passing this filename as a " + "parameter to the ggml llama library."sv); return ErrNo::InvalidArgument; } TempFile << BinModel; @@ -57,7 +59,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.LlamaModel = llama_load_model_from_file(ModelFilePath.c_str(), ContextParams); if (GraphRef.LlamaModel == nullptr) { - spdlog::error("[WASI-NN] Error: unable to init model."sv); + spdlog::error("[WASI-NN] GGML backend: Error: unable to init model."sv); Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } @@ -98,8 +100,9 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Minus 4 for the special tokens. const uint32_t MaxTokensListSize = MaxContextSize - 4; if (CxtRef.LlamaInputs.size() > MaxTokensListSize) { - spdlog::error("[WASI-NN]: Error: prompt too long ({} tokens, max %{})"sv, - CxtRef.LlamaInputs.size(), MaxTokensListSize); + spdlog::error( + "[WASI-NN] GGML backend: Error: prompt too long ({} tokens, max %{})"sv, + CxtRef.LlamaInputs.size(), MaxTokensListSize); return ErrNo::InvalidArgument; } return ErrNo::Success; @@ -120,14 +123,15 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); if (CxtRef.LlamaInputs.size() == 0) { - spdlog::error("[WASI-NN] Llama input is not set!"sv); + spdlog::error("[WASI-NN] GGML backend: Llama input is not set!"sv); return ErrNo::InvalidArgument; } // Use env LLAMA_LOG=1 to enable llama log. const char *LlamaLogEnv = std::getenv("LLAMA_LOG"); if (LlamaLogEnv != nullptr) { - spdlog::info("llama_system_info: {}"sv, llama_print_system_info()); + spdlog::info("[WASI-NN] GGML backend: llama_system_info: {}"sv, + llama_print_system_info()); } // Clear the outputs. @@ -143,7 +147,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { int(CxtRef.LlamaInputs.size()), llama_get_kv_cache_token_count(GraphRef.LlamaContext), get_num_physical_cores())) { - spdlog::error("[WASI-NN] Llama failed to eval."sv); + spdlog::error("[WASI-NN] GGML backend: Llama failed to eval."sv); return ErrNo::InvalidArgument; } CxtRef.LlamaInputs.clear(); @@ -174,7 +178,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } if (LlamaLogEnv != nullptr) { - spdlog::info("llama_get_kv_cache_token_count {}"sv, + spdlog::info("[WASI-NN] GGML backend: llama_get_kv_cache_token_count {}"sv, llama_get_kv_cache_token_count(GraphRef.LlamaContext)); llama_print_timings(GraphRef.LlamaContext); } From ca5a086bced6ae9f73044261462ab848e14a4ca3 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 21 Sep 2023 17:44:16 +0800 Subject: [PATCH 149/623] [CI] Fix CI build Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 12 ++++++++++-- utils/wasi-nn/install-openvino.sh | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index f3443323..a947b64e 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -3,8 +3,16 @@ # llama.cpp options set(LLAMA_ALL_WARNINGS OFF) -set(LLAMA_BLAS ON) -set(LLAMA_BLAS_VENDOR "OpenBLAS") +option(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS "Enable LLAMA_BLAS in the WASI-NN GGML backend" ON) +if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS) + message(INFO "WASI-NN GGML LLAMA backend: Enable LLAMA_BLAS") + # Default use OpenBLAS + set(LLAMA_BLAS ON) + set(LLAMA_BLAS_VENDOR "OpenBLAS") +else() + message(INFO "WASI-NN GGML LLAMA backend: Disable LLAMA_BLAS") + set(LLAMA_BLAS OFF) +endif() add_subdirectory(thirdparty) diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh index c15b6e93..a6624372 100755 --- a/utils/wasi-nn/install-openvino.sh +++ b/utils/wasi-nn/install-openvino.sh @@ -8,5 +8,5 @@ wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PU apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB echo "deb https://apt.repos.intel.com/openvino/2023 ubuntu20 main" | tee /etc/apt/sources.list.d/intel-openvino-2023.list apt update -apt-get -y install openvino +apt-get -y install openvino-2023.0.2 ldconfig From da9335309b53c913602c4164e21939650e3cdd47 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 21 Sep 2023 18:03:44 +0800 Subject: [PATCH 150/623] [WASI-NN] Pass preloaded model path for ggml backend Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 43 ++++++++++++++++------------------- plugins/wasi_nn/wasinnenv.cpp | 14 +++++++++--- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 94aa8dc5..8b6f02c2 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -13,7 +13,7 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML Expect load(WasiNNEnvironment &Env, Span> Builders, - Device Device, uint32_t &GraphId) noexcept { + [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { // The graph builder length must be 1. if (Builders.size() != 1) { spdlog::error( @@ -22,31 +22,28 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, return ErrNo::InvalidArgument; } - // Setup Graph Device - if (Device != Device::CPU) { - spdlog::error( - "[WASI-NN] GGML backend: Only support CPU target currently."sv); - return ErrNo::InvalidArgument; - } - auto Weight = Builders[0]; std::string BinModel(reinterpret_cast(Weight.data()), Weight.size()); - std::istringstream BinRead(BinModel); - - // TODO: pass the model directly to ggml - // Write ggml model to file. - std::string ModelFilePath("ggml-model.bin"sv); - std::ofstream TempFile(ModelFilePath); - if (!TempFile) { - spdlog::error( - "[WASI-NN] GGML backend: Failed to create the temporary file. " - "Currently, our workaround involves creating a temporary model " - "file named \"ggml-model.bin\" and passing this filename as a " - "parameter to the ggml llama library."sv); - return ErrNo::InvalidArgument; + std::string ModelFilePath; + if (BinModel.substr(0, 8) == "preload:") { + ModelFilePath = BinModel.substr(8); + } else { + // TODO: pass the model directly to ggml + // Write ggml model to file. + std::istringstream BinRead(BinModel); + ModelFilePath = "ggml-model.bin"sv; + std::ofstream TempFile(ModelFilePath); + if (!TempFile) { + spdlog::error( + "[WASI-NN] GGML backend: Failed to create the temporary file. " + "Currently, our workaround involves creating a temporary model " + "file named \"ggml-model.bin\" and passing this filename as a " + "parameter to the ggml llama library."sv); + return ErrNo::InvalidArgument; + } + TempFile << BinModel; + TempFile.close(); } - TempFile << BinModel; - TempFile.close(); // Add a new graph. Env.NNGraph.emplace_back(Backend::GGML); diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 8bc2b60b..45748cca 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -63,9 +63,17 @@ WasiNNEnvironment::WasiNNEnvironment() noexcept { auto Device = DeviceMap.find(Target); if (Backend != BackendMap.end() && Device != DeviceMap.end()) { for (const std::string &Path : Paths) { - std::vector Model; - if (load(std::filesystem::u8path(Path), Model)) { - Models.push_back(std::move(Model)); + if (Backend->second == Backend::GGML) { + // We write model path to model data to avoid file IO in llama.cpp. + std::string ModelPath = "preload:" + Path; + std::vector ModelPathData(ModelPath.begin(), + ModelPath.end()); + Models.push_back(std::move(ModelPathData)); + } else { + std::vector Model; + if (load(std::filesystem::u8path(Path), Model)) { + Models.push_back(std::move(Model)); + } } } RawMdMap.emplace(Name, std::make_tuple(std::move(Models), Backend->second, From 58c0779d799277295dd16a70be8ac2f9e1b660aa Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 21 Sep 2023 18:06:10 +0800 Subject: [PATCH 151/623] [CI] Fix CI build Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 1 - plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt | 10 ++++++++-- test/plugins/wasi_nn/wasi_nn.cpp | 16 ---------------- 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index a947b64e..8bb5e389 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -3,7 +3,6 @@ # llama.cpp options set(LLAMA_ALL_WARNINGS OFF) -option(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS "Enable LLAMA_BLAS in the WASI-NN GGML backend" ON) if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS) message(INFO "WASI-NN GGML LLAMA backend: Enable LLAMA_BLAS") # Default use OpenBLAS diff --git a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt index ca91cf5d..22ec934d 100644 --- a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt +++ b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt @@ -606,6 +606,14 @@ if (BUILD_SHARED_LIBS) endif() endif() +# global flags for ggml +if (NOT WIN32) + target_compile_options(ggml + PRIVATE + -DGGML_USE_K_QUANTS + ) +endif() + # disable warnings if (NOT WIN32) target_compile_options(ggml @@ -615,8 +623,6 @@ if (NOT WIN32) -Wno-unused-but-set-variable -Wno-unused-function -Wno-missing-braces - -DGGML_USE_K_QUANTS - -DGGML_USE_OPENBLAS ) else() target_compile_options(ggml diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index e5bec1ff..f03c4bbb 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1338,22 +1338,6 @@ TEST(WasiNNTest, GGMLBackend) { static_cast(ErrNo::InvalidArgument)); } - // Test: load -- the GGML backend currently only supports the CPU target. - // (device: CPU 0, GPU 1, TPU 2) - { - for (uint32_t I = 1; I <= 3; I++) { - - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::GGML), I, - BuilderPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), - static_cast(ErrNo::InvalidArgument)); - } - } - // Test: load -- load successfully. { EXPECT_TRUE(HostFuncLoad.run(CallFrame, From d47a36e0a04354bc0d416cada435804d569cc811 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 26 Sep 2023 14:04:14 +0800 Subject: [PATCH 152/623] [WASI-NN] Set n_ctx from LLAMA_N_CTX env Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 8b6f02c2..116b5f5c 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -12,6 +12,20 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +llama_context_params wasmedge_llama_context_params() noexcept { + llama_context_params Params = llama_context_default_params(); + const char *LlamaNContextEnv = std::getenv("LLAMA_N_CTX"); + const char *LlamaLogEnv = std::getenv("LLAMA_LOG"); + if (LlamaNContextEnv != nullptr) { + Params.n_ctx = std::stoi(LlamaNContextEnv); + if (LlamaLogEnv != nullptr) { + spdlog::info("[WASI-NN] GGML backend: set n_ctx to {}"sv, Params.n_ctx); + } + } + + return Params; +} + Expect load(WasiNNEnvironment &Env, Span> Builders, [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { // The graph builder length must be 1. @@ -52,7 +66,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize ggml model. gpt_params Params; llama_backend_init(Params.numa); - llama_context_params ContextParams = llama_context_default_params(); + llama_context_params ContextParams = wasmedge_llama_context_params(); GraphRef.LlamaModel = llama_load_model_from_file(ModelFilePath.c_str(), ContextParams); if (GraphRef.LlamaModel == nullptr) { @@ -85,7 +99,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); // Initialize the llama context. - llama_context_params ContextParams = llama_context_default_params(); + llama_context_params ContextParams = wasmedge_llama_context_params(); GraphRef.LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); @@ -98,7 +112,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, const uint32_t MaxTokensListSize = MaxContextSize - 4; if (CxtRef.LlamaInputs.size() > MaxTokensListSize) { spdlog::error( - "[WASI-NN] GGML backend: Error: prompt too long ({} tokens, max %{})"sv, + "[WASI-NN] GGML backend: Error: prompt too long ({} tokens, max {})"sv, CxtRef.LlamaInputs.size(), MaxTokensListSize); return ErrNo::InvalidArgument; } From 55c1f4eca4764bbf3aaa85b99a7268e8fd2824cd Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 26 Sep 2023 14:51:19 +0800 Subject: [PATCH 153/623] [WASI-NN] Set max tokens predicted by LLAMA_N_PREDICT Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 116b5f5c..791ab5e3 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -152,8 +152,17 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // TODO: recompute a compressed context based on previous tokens once the // cache is full. const int MaxContextSize = llama_n_ctx(GraphRef.LlamaContext); + int NPredict = std::numeric_limits::max(); + const char *LlamaNPredictEnv = std::getenv("LLAMA_N_PREDICT"); + if (LlamaNPredictEnv != nullptr) { + NPredict = std::stoi(LlamaNPredictEnv); + if (LlamaLogEnv != nullptr) { + spdlog::info("[WASI-NN] GGML backend: set n_predict to {}"sv, NPredict); + } + } while (llama_get_kv_cache_token_count(GraphRef.LlamaContext) < - MaxContextSize) { + MaxContextSize && + llama_get_kv_cache_token_count(GraphRef.LlamaContext) < NPredict) { if (llama_eval(GraphRef.LlamaContext, CxtRef.LlamaInputs.data(), int(CxtRef.LlamaInputs.size()), llama_get_kv_cache_token_count(GraphRef.LlamaContext), From f99faea0587e967dd002aef0748405db21ad11a4 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 26 Sep 2023 16:22:52 +0800 Subject: [PATCH 154/623] [WASI-NN] Update md5sum of testing model Signed-off-by: dm4 --- test/plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 257691e6..9546c667 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -61,7 +61,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) RESULT_VARIABLE DOWNLOAD_ERROR OUTPUT_STRIP_TRAILING_WHITESPACE) file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/orca-mini-3b.q4_0.gguf CHECKSUM_MODEL) - if(NOT CHECKSUM_MODEL STREQUAL "516027963397e180d7a92aded43d6b3d") + if(NOT CHECKSUM_MODEL STREQUAL "aae346fe095e60139ca39b3fda4ac7ae") message(FATAL_ERROR "orca-mini-3b.q4_0.gguf downloaded with wrong md5") endif() else() From f04fdc57aa3d68ba9b22c2a0bb4dfa2142dbd4b7 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 27 Sep 2023 12:10:15 +0800 Subject: [PATCH 155/623] [WASI-NN] Update ggml backend to b1273 Signed-off-by: dm4 --- .../wasi_nn/thirdparty/ggml/CMakeLists.txt | 78 +- plugins/wasi_nn/thirdparty/ggml/common.cpp | 31 +- plugins/wasi_nn/thirdparty/ggml/common.h | 15 +- plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c | 12 +- plugins/wasi_nn/thirdparty/ggml/ggml.c | 108 +- plugins/wasi_nn/thirdparty/ggml/ggml.h | 76 +- plugins/wasi_nn/thirdparty/ggml/llama.cpp | 1221 ++++++++++++++--- plugins/wasi_nn/thirdparty/ggml/llama.h | 6 +- 8 files changed, 1242 insertions(+), 305 deletions(-) diff --git a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt index 22ec934d..6e97d00d 100644 --- a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt +++ b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt @@ -46,6 +46,8 @@ set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kern set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") +set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING + "llama: max. batch size for using peer access") option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT}) @@ -68,6 +70,7 @@ set(CMAKE_C_STANDARD 11) set(CMAKE_C_STANDARD_REQUIRED true) set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) +include(CheckCXXCompilerFlag) if (NOT MSVC) if (LLAMA_SANITIZE_THREAD) @@ -104,8 +107,8 @@ if (LLAMA_METAL) find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) message(STATUS "Metal framework found") - - set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h) + set(GGML_HEADERS_METAL ggml-metal.h) + set(GGML_SOURCES_METAL ggml-metal.m) add_compile_definitions(GGML_USE_METAL) if (LLAMA_METAL_NDEBUG) @@ -124,7 +127,6 @@ if (LLAMA_METAL) ${METALKIT_FRAMEWORK} ) endif() - if (LLAMA_BLAS) if (LLAMA_STATIC) set(BLA_STATIC ON) @@ -201,7 +203,8 @@ if (LLAMA_BLAS) endif() if (LLAMA_K_QUANTS) - set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h) + set(GGML_HEADERS_EXTRA k_quants.h) + set(GGML_SOURCES_EXTRA k_quants.c) add_compile_definitions(GGML_USE_K_QUANTS) if (LLAMA_QKK_64) add_compile_definitions(GGML_QKK_64) @@ -217,7 +220,8 @@ if (LLAMA_CUBLAS) enable_language(CUDA) - set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h) + set(GGML_HEADERS_CUDA ggml-cuda.h) + set(GGML_SOURCES_CUDA ggml-cuda.cu) add_compile_definitions(GGML_USE_CUBLAS) # if (LLAMA_CUDA_CUBLAS) @@ -235,6 +239,7 @@ if (LLAMA_CUBLAS) add_compile_definitions(GGML_CUDA_F16) endif() add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) + add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE}) if (LLAMA_STATIC) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) @@ -265,6 +270,7 @@ if (LLAMA_MPI) find_package(MPI) if (MPI_C_FOUND) message(STATUS "MPI found") + set(GGML_HEADERS_MPI ggml-mpi.h) set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h) add_compile_definitions(GGML_USE_MPI) add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS}) @@ -287,7 +293,8 @@ if (LLAMA_CLBLAST) if (CLBlast_FOUND) message(STATUS "CLBlast found") - set(GGML_SOURCES_OPENCL ggml-opencl.cpp ggml-opencl.h) + set(GGML_HEADERS_OPENCL ggml-opencl.h) + set(GGML_SOURCES_OPENCL ggml-opencl.cpp) add_compile_definitions(GGML_USE_CLBLAST) @@ -315,13 +322,15 @@ if (LLAMA_HIPBLAS) message(STATUS "HIP and hipBLAS found") add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS) add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) + if (BUILD_SHARED_LIBS) + set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() if (LLAMA_CUDA_FORCE_DMMV) target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV) endif() target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) - target_compile_definitions(ggml-rocm PRIVATE CC_TURING=1000000000) set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX) target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas) @@ -354,6 +363,7 @@ if (LLAMA_ALL_WARNINGS) -Wextra -Wpedantic -Wcast-qual + -Wmissing-declarations -Wno-unused-function -Wno-multichar ) @@ -372,7 +382,7 @@ if (LLAMA_ALL_WARNINGS) endif() -if (MSVC) +if (WIN32) add_compile_definitions(_CRT_SECURE_NO_WARNINGS) if (BUILD_SHARED_LIBS) @@ -425,17 +435,21 @@ if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATC # add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) # MSVC doesn't support vdupq_n_f16, vld1q_f16, vst1q_f16 add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead else() + check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E) + if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "") + add_compile_options(-mfp16-format=ieee) + endif() if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") # Raspberry Pi 1, Zero - add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access) + add_compile_options(-mfpu=neon-fp-armv8 -mno-unaligned-access) endif() if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7") # Raspberry Pi 2 - add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations) + add_compile_options(-mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations) endif() if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") # Raspberry Pi 3, 4, Zero 2 (32-bit) - add_compile_options(-mfp16-format=ieee -mno-unaligned-access) + add_compile_options(-mno-unaligned-access) endif() endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "^(x86_64|i686|amd64|x64)$" ) @@ -562,11 +576,11 @@ wasmedge_add_library(ggml OBJECT ggml-alloc.h common.cpp common.h - ${GGML_SOURCES_CUDA} - ${GGML_SOURCES_OPENCL} - ${GGML_SOURCES_METAL} - ${GGML_SOURCES_MPI} - ${GGML_SOURCES_EXTRA} + ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA} + ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL} + ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} + ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI} + ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA} ) target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES}) @@ -624,6 +638,14 @@ if (NOT WIN32) -Wno-unused-function -Wno-missing-braces ) + target_compile_options(llama + PRIVATE + -Wno-unused-parameter + -Wno-unused-variable + -Wno-unused-but-set-variable + -Wno-unused-function + -Wno-missing-braces + ) else() target_compile_options(ggml PRIVATE @@ -647,4 +669,26 @@ else() -Wno-extra-semi-stmt -Wno-bad-function-cast ) -endif() \ No newline at end of file + target_compile_options(llama + PRIVATE + -Wno-string-conversion + -Wno-sign-conversion + -Wno-macro-redefined + -Wno-missing-prototypes + -Wno-unreachable-code-return + -Wno-shorten-64-to-32 + -Wno-implicit-int-conversion + -Wno-implicit-float-conversion + -Wno-float-conversion + -Wno-unused-macros + -Wno-unreachable-code-break + -Wno-cast-align + -Wno-undef + -Wno-shadow-uncaptured-local + -Wno-unreachable-code + -Wno-cast-function-type + -Wno-format-nonliteral + -Wno-extra-semi-stmt + -Wno-bad-function-cast + ) +endif() diff --git a/plugins/wasi_nn/thirdparty/ggml/common.cpp b/plugins/wasi_nn/thirdparty/ggml/common.cpp index b78ad002..275c038d 100644 --- a/plugins/wasi_nn/thirdparty/ggml/common.cpp +++ b/plugins/wasi_nn/thirdparty/ggml/common.cpp @@ -77,7 +77,7 @@ int32_t get_num_physical_cores() { return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; } -void process_escapes(std::string& input) { +static void process_escapes(std::string& input) { std::size_t input_len = input.length(); std::size_t output_idx = 0; @@ -373,6 +373,17 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { #else fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); +#endif + } else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") { + if (++i >= argc) { + invalid_param = true; + break; + } +#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD + params.n_gpu_layers_draft = std::stoi(argv[i]); +#else + fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n"); + fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); #endif } else if (arg == "--main-gpu" || arg == "-mg") { if (++i >= argc) { @@ -422,8 +433,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { #endif // GGML_USE_CUBLAS } else if (arg == "--no-mmap") { params.use_mmap = false; - } else if (arg == "--mtest") { - params.mem_test = true; } else if (arg == "--numa") { params.numa = true; } else if (arg == "--export") { @@ -637,9 +646,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --cfg-negative-prompt-file FNAME\n"); printf(" negative prompt file to use for guidance. (default: empty)\n"); printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); - printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale); - printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base); - printf(" --rope-freq-scale N RoPE frequency linear scaling factor, inverse of --rope-scale (default: %g)\n", params.rope_freq_scale); + printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n"); + printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n"); + printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n"); printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); printf(" --no-penalize-nl do not penalize newline token\n"); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); @@ -663,6 +672,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { #ifdef LLAMA_SUPPORTS_GPU_OFFLOAD printf(" -ngl N, --n-gpu-layers N\n"); printf(" number of layers to store in VRAM\n"); + printf(" -ngld N, --n-gpu-layers-draft N\n"); + printf(" number of layers to store in VRAM for the draft model\n"); printf(" -ts SPLIT --tensor-split SPLIT\n"); printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n"); @@ -673,7 +684,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" Not recommended since this is both slower and uses more VRAM.\n"); #endif // GGML_USE_CUBLAS #endif - printf(" --mtest compute maximum memory usage\n"); printf(" --export export the computation graph to 'llama.ggml'\n"); printf(" --verbose-prompt print prompt before generation\n"); fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); @@ -769,7 +779,7 @@ std::tuple llama_init_from_gpt_par } { - // LOG("warming up the model with an empty run\n"); + LOG("warming up the model with an empty run\n"); const std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads); @@ -790,10 +800,10 @@ std::vector llama_tokenize( // upper limit for the number of tokens int n_tokens = text.length() + add_bos; std::vector result(n_tokens); - n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); + n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); + int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); @@ -1209,7 +1219,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str()); fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); - fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false"); fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); diff --git a/plugins/wasi_nn/thirdparty/ggml/common.h b/plugins/wasi_nn/thirdparty/ggml/common.h index 012bf5e1..2761503b 100644 --- a/plugins/wasi_nn/thirdparty/ggml/common.h +++ b/plugins/wasi_nn/thirdparty/ggml/common.h @@ -20,8 +20,13 @@ #define DIRECTORY_SEPARATOR '/' #endif // _WIN32 -#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0) -#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", ##__VA_ARGS__); exit(1); } while (0) +#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0) +#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0) + +#define print_build_info() do { \ + fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); \ + fprintf(stderr, "%s: built with %s for %s\n", __func__, BUILD_COMPILER, BUILD_TARGET); \ +} while(0) // // CLI argument parsing @@ -38,12 +43,13 @@ struct gpt_params { int32_t n_draft = 16; // number of tokens to draft during speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) + int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t n_beams = 0; // if non-zero then use beam search of given width. - float rope_freq_base = 10000.0f; // RoPE base frequency - float rope_freq_scale = 1.0f; // RoPE frequency scaling factor + float rope_freq_base = 0.0f; // RoPE base frequency + float rope_freq_scale = 0.0f; // RoPE frequency scaling factor // sampling parameters int32_t top_k = 40; // <= 0 to use vocab size @@ -109,7 +115,6 @@ struct gpt_params { bool perplexity = false; // compute perplexity over the prompt bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory - bool mem_test = false; // compute maximum memory usage bool numa = false; // attempt optimizations that help on some NUMA systems bool export_cgraph = false; // export the computation graph bool verbose_prompt = false; // print prompt tokens before generation diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c index a1f6e7bf..304964be 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c @@ -131,6 +131,10 @@ static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_ten return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size; } +static bool ggml_is_view(struct ggml_tensor * t) { + return t->view_src != NULL; +} + void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { #ifdef GGML_ALLOCATOR_DEBUG GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources @@ -338,8 +342,8 @@ static void free_vmem(void * base_addr, size_t size) { // allocate uncommitted virtual memory to measure the size of the graph static void alloc_measure_vmem(void ** base_addr, size_t * size) { - // 1TB for 64-bit, 1GB for 32-bit - *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<40; + // 128GB for 64-bit, 1GB for 32-bit + *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37; do { *base_addr = alloc_vmem(*size); if (*base_addr != NULL) { @@ -399,10 +403,6 @@ bool ggml_allocr_is_measure(struct ggml_allocr * alloc) { //////////// compute graph allocator -static bool ggml_is_view(struct ggml_tensor * t) { - return t->view_src != NULL; -} - static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { if (a->type != b->type) { return false; diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml.c b/plugins/wasi_nn/thirdparty/ggml/ggml.c index a9cffb43..a0be068d 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml.c +++ b/plugins/wasi_nn/thirdparty/ggml/ggml.c @@ -4303,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) { } size_t ggml_nbytes(const struct ggml_tensor * tensor) { - size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type); - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + size_t nbytes; + size_t blck_size = ggml_blck_size(tensor->type); + if (blck_size == 1) { + nbytes = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + } + } + else { + nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + } } + return nbytes; } @@ -17283,10 +17294,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } else { // wait for other threads to finish const int last = node_n; - do { - //sched_yield(); + while (true) { + // TODO: this sched_yield can have significant impact on the performance - either positive or negative + // depending on the workload and the operating system. + // since it is not clear what is the best approach, it should potentially become user-configurable + // ref: https://github.com/ggerganov/ggml/issues/291 +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + sched_yield(); +#endif + node_n = atomic_load(&state->shared->node_n); - } while (node_n == last); + if (node_n != last) break; + }; } // check if we should stop @@ -18337,10 +18356,11 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { for (int i = 0; i < cgraph->n_leafs; i++) { struct ggml_tensor * node = cgraph->leafs[i]; - GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n", + GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n", i, node->ne[0], node->ne[1], - ggml_op_name(node->op)); + ggml_op_name(node->op), + ggml_get_name(node)); } for (int i = 0; i < GGML_OP_COUNT; i++) { @@ -20099,27 +20119,27 @@ const char * gguf_type_name(enum gguf_type type) { return GGUF_TYPE_NAME[type]; } -int gguf_get_version(struct gguf_context * ctx) { +int gguf_get_version(const struct gguf_context * ctx) { return ctx->header.version; } -size_t gguf_get_alignment(struct gguf_context * ctx) { +size_t gguf_get_alignment(const struct gguf_context * ctx) { return ctx->alignment; } -size_t gguf_get_data_offset(struct gguf_context * ctx) { +size_t gguf_get_data_offset(const struct gguf_context * ctx) { return ctx->offset; } -void * gguf_get_data(struct gguf_context * ctx) { +void * gguf_get_data(const struct gguf_context * ctx) { return ctx->data; } -int gguf_get_n_kv(struct gguf_context * ctx) { +int gguf_get_n_kv(const struct gguf_context * ctx) { return ctx->header.n_kv; } -int gguf_find_key(struct gguf_context * ctx, const char * key) { +int gguf_find_key(const struct gguf_context * ctx, const char * key) { // return -1 if key not found int keyfound = -1; @@ -20135,85 +20155,85 @@ int gguf_find_key(struct gguf_context * ctx, const char * key) { return keyfound; } -const char * gguf_get_key(struct gguf_context * ctx, int i) { +const char * gguf_get_key(const struct gguf_context * ctx, int i) { return ctx->kv[i].key.data; } -enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) { +enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int i) { return ctx->kv[i].type; } -enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) { +enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.arr.type; } -const void * gguf_get_arr_data(struct gguf_context * ctx, int i) { +const void * gguf_get_arr_data(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.arr.data; } -const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) { +const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) { struct gguf_kv * kv = &ctx->kv[key_id]; struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; return str->data; } -int gguf_get_arr_n(struct gguf_context * ctx, int i) { +int gguf_get_arr_n(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.arr.n; } -uint8_t gguf_get_val_u8(struct gguf_context * ctx, int i) { +uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.uint8; } -int8_t gguf_get_val_i8(struct gguf_context * ctx, int i) { +int8_t gguf_get_val_i8(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.int8; } -uint16_t gguf_get_val_u16(struct gguf_context * ctx, int i) { +uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.uint16; } -int16_t gguf_get_val_i16(struct gguf_context * ctx, int i) { +int16_t gguf_get_val_i16(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.int16; } -uint32_t gguf_get_val_u32(struct gguf_context * ctx, int i) { +uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.uint32; } -int32_t gguf_get_val_i32(struct gguf_context * ctx, int i) { +int32_t gguf_get_val_i32(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.int32; } -float gguf_get_val_f32(struct gguf_context * ctx, int i) { +float gguf_get_val_f32(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.float32; } -uint64_t gguf_get_val_u64(struct gguf_context * ctx, int i) { +uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.uint64; } -int64_t gguf_get_val_i64(struct gguf_context * ctx, int i) { +int64_t gguf_get_val_i64(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.int64; } -double gguf_get_val_f64(struct gguf_context * ctx, int i) { +double gguf_get_val_f64(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.float64; } -bool gguf_get_val_bool(struct gguf_context * ctx, int i) { +bool gguf_get_val_bool(const struct gguf_context * ctx, int i) { return ctx->kv[i].value.bool_; } -const char * gguf_get_val_str (struct gguf_context * ctx, int i) { +const char * gguf_get_val_str (const struct gguf_context * ctx, int i) { return ctx->kv[i].value.str.data; } -int gguf_get_n_tensors(struct gguf_context * ctx) { +int gguf_get_n_tensors(const struct gguf_context * ctx) { return ctx->header.n_tensors; } -int gguf_find_tensor(struct gguf_context * ctx, const char * name) { +int gguf_find_tensor(const struct gguf_context * ctx, const char * name) { // return -1 if tensor not found int tensorfound = -1; @@ -20229,11 +20249,11 @@ int gguf_find_tensor(struct gguf_context * ctx, const char * name) { return tensorfound; } -size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i) { +size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) { return ctx->infos[i].offset; } -char * gguf_get_tensor_name(struct gguf_context * ctx, int i) { +char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) { return ctx->infos[i].name.data; } @@ -20516,7 +20536,7 @@ static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_si buf->offset += el_size; } -static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) { +static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) { // write header gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic)); gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version)); @@ -20631,7 +20651,7 @@ static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf, } } -void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta) { +void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { FILE * file = fopen(fname, "wb"); if (!file) { GGML_ASSERT(false && "failed to open file for writing"); @@ -20648,7 +20668,7 @@ void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only fclose(file); } -size_t gguf_get_meta_size(struct gguf_context * ctx) { +size_t gguf_get_meta_size(const struct gguf_context * ctx) { // no allocs - only compute size struct gguf_buf buf = gguf_buf_init(0); @@ -20657,7 +20677,7 @@ size_t gguf_get_meta_size(struct gguf_context * ctx) { return buf.offset; } -void gguf_get_meta_data(struct gguf_context * ctx, void * data) { +void gguf_get_meta_data(const struct gguf_context * ctx, void * data) { struct gguf_buf buf = gguf_buf_init(16*1024); gguf_write_to_buf(ctx, &buf, true); @@ -20733,6 +20753,14 @@ int ggml_cpu_has_arm_fma(void) { #endif } +int ggml_cpu_has_metal(void) { +#if defined(GGML_USE_METAL) + return 1; +#else + return 0; +#endif +} + int ggml_cpu_has_f16c(void) { #if defined(__F16C__) return 1; diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml.h b/plugins/wasi_nn/thirdparty/ggml/ggml.h index 6d4cf465..f4545687 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml.h +++ b/plugins/wasi_nn/thirdparty/ggml/ggml.h @@ -195,6 +195,14 @@ # define GGML_DEPRECATED(func, hint) func #endif +#ifndef __GNUC__ +# define GGML_ATTRIBUTE_FORMAT(...) +#elif defined(__MINGW32__) +# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif + #include #include #include @@ -270,7 +278,7 @@ extern "C" { #if defined(__ARM_NEON) && defined(__CUDACC__) typedef half ggml_fp16_t; -#elif defined(__ARM_NEON) && !defined(_MSC_VER) +#elif defined(__ARM_NEON) typedef __fp16 ggml_fp16_t; #else typedef uint16_t ggml_fp16_t; @@ -685,6 +693,7 @@ extern "C" { GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor); GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name); + GGML_ATTRIBUTE_FORMAT(2, 3) GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...); // @@ -1866,39 +1875,39 @@ extern "C" { GGML_API const char * gguf_type_name(enum gguf_type type); - GGML_API int gguf_get_version (struct gguf_context * ctx); - GGML_API size_t gguf_get_alignment (struct gguf_context * ctx); - GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx); - GGML_API void * gguf_get_data (struct gguf_context * ctx); + GGML_API int gguf_get_version (const struct gguf_context * ctx); + GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); + GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); + GGML_API void * gguf_get_data (const struct gguf_context * ctx); - GGML_API int gguf_get_n_kv(struct gguf_context * ctx); - GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key); - GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i); + GGML_API int gguf_get_n_kv(const struct gguf_context * ctx); + GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key); + GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i); - GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i); - GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i); + GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i); + GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i); // results are undefined if the wrong type is used for the key - GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i); - GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i); - GGML_API uint16_t gguf_get_val_u16 (struct gguf_context * ctx, int i); - GGML_API int16_t gguf_get_val_i16 (struct gguf_context * ctx, int i); - GGML_API uint32_t gguf_get_val_u32 (struct gguf_context * ctx, int i); - GGML_API int32_t gguf_get_val_i32 (struct gguf_context * ctx, int i); - GGML_API float gguf_get_val_f32 (struct gguf_context * ctx, int i); - GGML_API uint64_t gguf_get_val_u64 (struct gguf_context * ctx, int i); - GGML_API int64_t gguf_get_val_i64 (struct gguf_context * ctx, int i); - GGML_API double gguf_get_val_f64 (struct gguf_context * ctx, int i); - GGML_API bool gguf_get_val_bool(struct gguf_context * ctx, int i); - GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i); - GGML_API int gguf_get_arr_n (struct gguf_context * ctx, int i); - GGML_API const void * gguf_get_arr_data(struct gguf_context * ctx, int i); - GGML_API const char * gguf_get_arr_str (struct gguf_context * ctx, int key_id, int i); - - GGML_API int gguf_get_n_tensors (struct gguf_context * ctx); - GGML_API int gguf_find_tensor (struct gguf_context * ctx, const char * name); - GGML_API size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i); - GGML_API char * gguf_get_tensor_name (struct gguf_context * ctx, int i); + GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int i); + GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int i); + GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int i); + GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int i); + GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int i); + GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int i); + GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int i); + GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int i); + GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int i); + GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int i); + GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int i); + GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i); + GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int i); + GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i); + GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i); + + GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx); + GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name); + GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i); + GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i); // overrides existing values or adds a new one GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); @@ -1943,11 +1952,11 @@ extern "C" { // // write the entire context to a binary file - GGML_API void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta); + GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding - GGML_API size_t gguf_get_meta_size(struct gguf_context * ctx); - GGML_API void gguf_get_meta_data(struct gguf_context * ctx, void * data); + GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx); + GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data); // // system info @@ -1961,6 +1970,7 @@ extern "C" { GGML_API int ggml_cpu_has_fma (void); GGML_API int ggml_cpu_has_neon (void); GGML_API int ggml_cpu_has_arm_fma (void); + GGML_API int ggml_cpu_has_metal (void); GGML_API int ggml_cpu_has_f16c (void); GGML_API int ggml_cpu_has_fp16_va (void); GGML_API int ggml_cpu_has_wasm_simd (void); diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.cpp b/plugins/wasi_nn/thirdparty/ggml/llama.cpp index c183ddb3..758a1c12 100644 --- a/plugins/wasi_nn/thirdparty/ggml/llama.cpp +++ b/plugins/wasi_nn/thirdparty/ggml/llama.cpp @@ -1,3 +1,4 @@ +#define LLAMA_API_INTERNAL #include "llama.h" #include "ggml.h" @@ -57,7 +58,6 @@ #include #include #include -#include #include #include #include @@ -109,7 +109,7 @@ static size_t utf8_len(char src) { return lookup[highbits]; } -void replace_all(std::string & s, const std::string & search, const std::string & replace) { +static void replace_all(std::string & s, const std::string & search, const std::string & replace) { std::string result; for (size_t pos = 0; ; pos += search.length()) { auto new_pos = s.find(search, pos); @@ -156,20 +156,24 @@ static std::string format(const char * fmt, ...) { enum llm_arch { LLM_ARCH_LLAMA, LLM_ARCH_FALCON, + LLM_ARCH_BAICHUAN, LLM_ARCH_GPT2, LLM_ARCH_GPTJ, LLM_ARCH_GPTNEOX, LLM_ARCH_MPT, + LLM_ARCH_STARCODER, LLM_ARCH_UNKNOWN, }; static std::map LLM_ARCH_NAMES = { - { LLM_ARCH_LLAMA, "llama" }, - { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_GPT2, "gpt2" }, - { LLM_ARCH_GPTJ, "gptj" }, - { LLM_ARCH_GPTNEOX, "gptneox" }, - { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_STARCODER, "starcoder" }, }; enum llm_kv { @@ -310,6 +314,25 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_BAICHUAN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_FALCON, { @@ -356,6 +379,21 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, }, }, + { + LLM_ARCH_STARCODER, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -875,9 +913,11 @@ static llama_state g_state; // available llama models enum e_model { MODEL_UNKNOWN, + MODEL_1B, MODEL_3B, MODEL_7B, MODEL_13B, + MODEL_15B, MODEL_30B, MODEL_34B, MODEL_40B, @@ -887,24 +927,24 @@ enum e_model { static const size_t kB = 1024; static const size_t MB = kB*kB; +static const size_t GB = kB*kB*kB; -// default hparams (LLaMA 7B) struct llama_hparams { - uint32_t n_vocab = 32000; - uint32_t n_ctx_train = 2048; // the context size used during training - uint32_t n_ctx = 512; // the context size used during inference - uint32_t n_embd = 4096; - uint32_t n_head = 32; - uint32_t n_head_kv = 32; - uint32_t n_layer = 32; - uint32_t n_rot = 64; - uint32_t n_ff = 11008; - - float f_norm_eps = 1e-5; - float f_norm_rms_eps = 1e-5; - - float rope_freq_base = 10000.0f; - float rope_freq_scale = 1.0f; + uint32_t n_vocab; + uint32_t n_ctx_train; // context size the model was trained on + uint32_t n_ctx; // context size used during inference + uint32_t n_embd; + uint32_t n_head; + uint32_t n_head_kv; + uint32_t n_layer; + uint32_t n_rot; + uint32_t n_ff; + + float f_norm_eps; + float f_norm_rms_eps; + + float rope_freq_base; + float rope_freq_scale; bool operator!=(const llama_hparams & other) const { return static_cast(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT @@ -946,13 +986,22 @@ struct llama_layer { struct ggml_tensor * wo; struct ggml_tensor * wqkv; + // attention bias + struct ggml_tensor * bo; + struct ggml_tensor * bqkv; + // normalization struct ggml_tensor * ffn_norm; + struct ggml_tensor * ffn_norm_b; // ff struct ggml_tensor * w1; // ffn_gate struct ggml_tensor * w2; // ffn_down struct ggml_tensor * w3; // ffn_up + + // ff bias + struct ggml_tensor * b2; // ffn_down + struct ggml_tensor * b3; // ffn_up }; struct llama_kv_cache { @@ -1026,10 +1075,11 @@ struct llama_model { std::string name = "n/a"; - llama_hparams hparams; + llama_hparams hparams = {}; llama_vocab vocab; struct ggml_tensor * tok_embeddings; + struct ggml_tensor * pos_embeddings; struct ggml_tensor * output_norm; struct ggml_tensor * output_norm_b; @@ -1230,6 +1280,7 @@ struct llama_model_loader { int n_created = 0; int64_t n_elements = 0; + size_t n_bytes = 0; bool use_mmap = false; @@ -1262,6 +1313,7 @@ struct llama_model_loader { const char * name = gguf_get_tensor_name(ctx_gguf, i); struct ggml_tensor * t = ggml_get_tensor(ctx_meta, name); n_elements += ggml_nelements(t); + n_bytes += ggml_nbytes(t); } LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", @@ -1540,7 +1592,7 @@ struct llama_model_loader { // load LLaMA models // -std::string llama_model_ftype_name(enum llama_ftype ftype) { +static std::string llama_model_ftype_name(enum llama_ftype ftype) { if (ftype & LLAMA_FTYPE_GUESSED) { return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; } @@ -1573,9 +1625,11 @@ std::string llama_model_ftype_name(enum llama_ftype ftype) { static const char * llama_model_type_name(e_model type) { switch (type) { + case MODEL_1B: return "1B"; case MODEL_3B: return "3B"; case MODEL_7B: return "7B"; case MODEL_13B: return "13B"; + case MODEL_15B: return "15B"; case MODEL_30B: return "30B"; case MODEL_34B: return "34B"; case MODEL_40B: return "40B"; @@ -1619,28 +1673,17 @@ static void llm_load_hparams( hparams.n_head_kv = hparams.n_head; GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV)); - // TODO: manually setting rope freq base and scale should override this - // FIXME: partial fix when the param specified is not the default value, but - // will not work for overriding the model value to the params default - - llama_context_params defaults = llama_context_default_params(); - - // rope_freq_base - { - float ropebase = 10000.0f; - GGUF_GET_KEY(ctx, ropebase, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); - if (ropebase != 10000.0f && rope_freq_base == defaults.rope_freq_base) { - rope_freq_base = ropebase; - } + // rope_freq_base (optional) + if (rope_freq_base == 0.0f) { + rope_freq_base = 10000.0f; + GGUF_GET_KEY(ctx, rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); } // rope_freq_scale (inverse of the kv) is optional - { + if (rope_freq_scale == 0.0f) { float ropescale = 1.0f; GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); - if (ropescale != 1.0f && rope_freq_scale == defaults.rope_freq_scale) { - rope_freq_scale = 1.0f/ropescale; - } + rope_freq_scale = 1.0f/ropescale; } // sanity check for n_rot (optional) @@ -1684,6 +1727,26 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BAICHUAN: + { + GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_STARCODER: + { + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 36: model.type = e_model::MODEL_3B; break; + case 42: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_15B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; }; @@ -1837,7 +1900,12 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); - LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml.n_elements*1e-9); + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); + if (ml.n_bytes < GB) { + LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); + } else { + LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); + } // general kv LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); @@ -1924,7 +1992,6 @@ static void llm_load_tensors( const int64_t n_vocab = hparams.n_vocab; const auto tn = LLM_TN(model.arch); - switch (model.arch) { case LLM_ARCH_LLAMA: { @@ -1967,6 +2034,72 @@ static void llm_load_tensors( model.layers.resize(n_layer); + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + + layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + + layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + } + } + } break; + case LLM_ARCH_BAICHUAN: + { + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + { + ggml_backend backend_norm; + ggml_backend backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 + + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + for (uint32_t i = 0; i < n_layer; ++i) { const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT @@ -2072,6 +2205,85 @@ static void llm_load_tensors( } } } break; + case LLM_ARCH_STARCODER: + { + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU); + + // output + { + ggml_backend backend_norm; + ggml_backend backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 + + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split); + + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); + + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); + + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + + ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + + ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) + + ggml_nbytes(layer.w2) + ggml_nbytes(layer.b2) + + ggml_nbytes(layer.w3) + ggml_nbytes(layer.b3); + } + } + } break; default: throw std::runtime_error("unknown architecture"); }; @@ -2353,11 +2565,356 @@ static struct ggml_cgraph * llm_build_llama( offload_func_kq(tmpq); ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + offload_func_kq(Kcur); + ggml_set_name(Kcur, "Kcur"); + + struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + offload_func_kq(Qcur); + ggml_set_name(Qcur, "Qcur"); + + // store key and value to memory + { + // compute the transposed [N, n_embd] V matrix + + struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + offload_func_v(tmpv); + ggml_set_name(tmpv, "tmpv"); + + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N)); + offload_func_v(Vcur); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + offload_func_kq(k); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + offload_func_v(v); + ggml_set_name(v, "v"); + + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + offload_func_kq(Q); + ggml_set_name(Q, "Q"); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_past + N, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + offload_func_kq(K); + ggml_set_name(K, "K"); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + offload_func_kq(KQ); + ggml_set_name(KQ, "KQ"); + + // KQ_scaled = KQ / sqrt(n_embd_head) + // KQ_scaled shape [n_past + N, N, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + offload_func_kq(KQ_scaled); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + offload_func_kq(KQ_masked); + ggml_set_name(KQ_masked, "KQ_masked"); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + offload_func_v(KQ_soft_max); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + // split cached V into n_head heads + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_past + N, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + offload_func_v(V); + ggml_set_name(V, "V"); + +#if 1 + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + offload_func_v(KQV); + ggml_set_name(KQV, "KQV"); +#else + // make V contiguous in memory to speed up the matmul, however we waste time on the copy + // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation + // is there a better way? + struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head)); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); +#endif + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + offload_func_v(KQV_merged); + ggml_set_name(KQV_merged, "KQV_merged"); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + offload_func_v(cur); + ggml_set_name(cur, "KQV_merged_contiguous"); + + // projection (no bias) + cur = ggml_mul_mat(ctx0, + model.layers[il].wo, + cur); + offload_func(cur); + ggml_set_name(cur, "result_wo"); + } + + struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); + offload_func(inpFF); + ggml_set_name(inpFF, "inpFF"); + + // feed-forward network + { + // norm + { + cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps); + offload_func(cur); + ggml_set_name(cur, "rms_norm_1"); + + // cur = cur*ffn_norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); + offload_func(cur); + ggml_set_name(cur, "ffn_norm"); + } + + struct ggml_tensor * tmp = ggml_mul_mat(ctx0, + model.layers[il].w3, + cur); + offload_func(tmp); + ggml_set_name(tmp, "result_w3"); + + cur = ggml_mul_mat(ctx0, + model.layers[il].w1, + cur); + offload_func(cur); + ggml_set_name(cur, "result_w1"); + + // SILU activation + cur = ggml_silu(ctx0, cur); + offload_func(cur); + ggml_set_name(cur, "silu"); + + cur = ggml_mul(ctx0, cur, tmp); + offload_func(cur); + ggml_set_name(cur, "silu_x_result_w3"); + + cur = ggml_mul_mat(ctx0, + model.layers[il].w2, + cur); + offload_func(cur); + ggml_set_name(cur, "result_w2"); + } + + cur = ggml_add(ctx0, cur, inpFF); + offload_func(cur); + ggml_set_name(cur, "inpFF_+_result_w2"); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + // norm + { + cur = ggml_rms_norm(ctx0, cur, norm_rms_eps); + offload_func_nr(cur); + ggml_set_name(cur, "rms_norm_2"); + + // cur = cur*norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.output_norm); + // offload_func_nr(cur); // TODO CPU + GPU mirrored backend + ggml_set_name(cur, "result_norm"); + } + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + + +static struct ggml_cgraph * llm_build_baichaun( + llama_context & lctx, + const llama_token * tokens, + const float * embd, + int n_tokens, + int n_past) { + + GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT + + const int N = n_tokens; + + const auto & model = lctx.model; + const auto & hparams = model.hparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = hparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_rot); + + const float freq_base = hparams.rope_freq_base; + const float freq_scale = hparams.rope_freq_scale; + const float norm_rms_eps = hparams.f_norm_rms_eps; + + const int n_gpu_layers = model.n_gpu_layers; + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ false, + }; + + params.no_alloc = true; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + if (tokens) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + + ggml_allocr_alloc(lctx.alloc, inpL); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + } + } + + const int i_gpu_start = n_layer - n_gpu_layers; + (void) i_gpu_start; + + // offload functions set the tensor output backend to GPU + // tensors are GPU-accelerated if any input or the output has been offloaded + // + // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal + // in that case ggml_cuda_assign_buffers has no effect + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + } + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + + for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); + + offload_func_t offload_func = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (il >= i_gpu_start) { + offload_func = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + struct ggml_tensor * inpSA = inpL; + + // norm + { + cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps); + offload_func(cur); + ggml_set_name(cur, "rms_norm_0"); + + // cur = cur*attn_norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); + offload_func(cur); + ggml_set_name(cur, "attention_norm_0"); + } + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + offload_func_kq(tmpk); + ggml_set_name(tmpk, "tmpk"); + + struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + offload_func_kq(tmpq); + ggml_set_name(tmpq, "tmpq"); + + struct ggml_tensor * Kcur; + struct ggml_tensor * Qcur; + switch (model.type) { + case MODEL_7B: + Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + break; + case MODEL_13B: + Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N); + Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N); + break; + default: + GGML_ASSERT(false); + } + offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); @@ -2412,10 +2969,26 @@ static struct ggml_cgraph * llm_build_llama( offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); + struct ggml_tensor * KQ_masked; + struct ggml_tensor * KQ_scaled_alibi; + + switch (model.type) { + case MODEL_7B: + KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + break; + case MODEL_13B: + KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8); + ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); + KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); + break; + default: + GGML_ASSERT(false); + } // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); - offload_func_kq(KQ_masked); - ggml_set_name(KQ_masked, "KQ_masked"); + // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); + // offload_func_kq(KQ_masked); + // ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); @@ -2850,6 +3423,235 @@ static struct ggml_cgraph * llm_build_falcon( return gf; } +static struct ggml_cgraph * llm_build_starcoder( + llama_context & lctx, + const llama_token * tokens, + const float * embd, + int n_tokens, + int n_past) { + + GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT + + const int N = n_tokens; + + const auto & model = lctx.model; + const auto & hparams = model.hparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = hparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_rot); + + const float norm_eps = hparams.f_norm_eps; + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ false, + }; + + params.no_alloc = true; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * token; + struct ggml_tensor * position; + struct ggml_tensor * inpL; + + if (tokens) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + + token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + + ggml_allocr_alloc(lctx.alloc, token); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(token->data, embd, N * n_embd * ggml_element_size(token)); + } + } + + { + // Compute position embeddings. + struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + ggml_allocr_alloc(lctx.alloc, inp_positions); + if (!ggml_allocr_is_measure(lctx.alloc)) { + for (int i = 0; i < N; ++i) { + ((int32_t *) inp_positions->data)[i] = n_past + i; + } + } + ggml_set_name(inp_positions, "inp_positions"); + + position = ggml_get_rows(ctx0, model.pos_embeddings, inp_positions); + } + + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + } + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + + inpL = ggml_add(ctx0, token, position); + ggml_set_name(inpL, "inpL"); + + for (int il = 0; il < n_layer; ++il) { + { + // Norm + cur = ggml_norm(ctx0, inpL, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b); + } + + { + // Self Attention + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv); + + struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd); + struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, N, cur->nb[1], sizeof(float)*n_embd); + struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, N, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa)); + + struct ggml_tensor * Qcur = tmpq; + struct ggml_tensor * Kcur = tmpk; + + { + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N)); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, N)), + 0, 2, 1, 3); + ggml_set_name(Q, "Q"); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_past + N, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + ggml_set_name(K, "K"); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + ggml_set_name(KQ, "KQ"); + + // KQ_scaled = KQ / sqrt(n_embd_head) + // KQ_scaled shape [n_past + N, N, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + ggml_set_name(KQ_masked, "KQ_masked"); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + // split cached V into n_head heads + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_past + N, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + ggml_set_name(V, "V"); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + ggml_set_name(KQV, "KQV"); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + ggml_set_name(KQV_merged, "KQV_merged"); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + ggml_set_name(cur, "KQV_merged_contiguous"); + } + + // Projection + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo); + + // Add the input + cur = ggml_add(ctx0, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + // FF + { + // Norm + { + cur = ggml_norm(ctx0, inpFF, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b); + } + + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // Projection + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2); + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + // Output Norm + { + cur = ggml_norm(ctx0, inpL, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b); + } + ggml_set_name(cur, "result_norm"); + + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + ggml_free(ctx0); + + return gf; +} + static struct ggml_cgraph * llama_build_graph( llama_context & lctx, const llama_token * tokens, @@ -2865,10 +3667,18 @@ static struct ggml_cgraph * llama_build_graph( { result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past); } break; + case LLM_ARCH_BAICHUAN: + { + result = llm_build_baichaun(lctx, tokens, embd, n_tokens, n_past); + } break; case LLM_ARCH_FALCON: { result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past); } break; + case LLM_ARCH_STARCODER: + { + result = llm_build_starcoder(lctx, tokens, embd, n_tokens, n_past); + } break; default: GGML_ASSERT(false); }; @@ -2955,6 +3765,15 @@ static bool llama_eval_internal( n_threads = std::min(4, n_threads); } + // If all tensors can be run on the GPU then using more than 1 thread is detrimental. + const bool full_offload_supported = model.arch == LLM_ARCH_LLAMA || + model.arch == LLM_ARCH_BAICHUAN || + model.arch == LLM_ARCH_FALCON; + const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3; + if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) { + n_threads = 1; + } + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; @@ -2970,10 +3789,6 @@ static bool llama_eval_internal( if (lctx.ctx_metal) { ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); ggml_metal_graph_compute(lctx.ctx_metal, gf); - ggml_metal_get_tensor (lctx.ctx_metal, res); - if (!lctx.embedding.empty()) { - ggml_metal_get_tensor(lctx.ctx_metal, embeddings); - } } else { ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); } @@ -3122,10 +3937,9 @@ struct llm_tokenizer_spm { while (offs < text.size()) { llm_symbol sym; size_t len = utf8_len(text[offs]); - GGML_ASSERT(offs + len <= text.size()); sym.text = text.c_str() + offs; - sym.n = len; - offs += len; + sym.n = std::min(len, text.size() - offs); + offs += sym.n; sym.prev = index - 1; sym.next = offs == text.size() ? -1 : index + 1; index++; @@ -3487,7 +4301,7 @@ struct llama_grammar_candidate { // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. -std::pair, llama_partial_utf8> decode_utf8( +static std::pair, llama_partial_utf8> decode_utf8( const char * src, llama_partial_utf8 partial_start) { static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; @@ -4641,7 +5455,16 @@ void llama_beam_search(llama_context * ctx, // quantization // -static void llama_convert_tensor_internal(struct ggml_tensor * tensor, std::vector & output, const size_t nelements, const int nthread) { +template +struct no_init { + T value; + no_init() { /* do nothing */ } +}; + +static void llama_convert_tensor_internal( + struct ggml_tensor * tensor, std::vector> & output, std::vector & workers, + const size_t nelements, const int nthread +) { if (output.size() < nelements) { output.resize(nelements); } @@ -4676,7 +5499,6 @@ static void llama_convert_tensor_internal(struct ggml_tensor * tensor, std::vect auto blocks_per_thread = nblocks / nthread; auto spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count - std::vector workers; for (auto tnum = 0, in_buff_offs = 0, out_buff_offs = 0; tnum < nthread; tnum++) { auto thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread auto thr_elems = thr_blocks * block_size; // number of elements for this thread @@ -4689,14 +5511,123 @@ static void llama_convert_tensor_internal(struct ggml_tensor * tensor, std::vect qtype.to_float(inbuf, outbuf, nels); } }; - workers.push_back(std::thread(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems)); + workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems); in_buff_offs += thr_block_bytes; out_buff_offs += thr_elems; } - for (auto & worker : workers) { - worker.join(); + for (auto & w : workers) { w.join(); } + workers.clear(); +} + +#ifdef GGML_USE_K_QUANTS +static ggml_type get_k_quant_type( + ggml_type new_type, const ggml_tensor * tensor, const llama_model & model, llama_ftype ftype, int * i_attention_wv, + int n_attention_wv, int * i_feed_forward_w2, int n_feed_forward_w2 +) { + const std::string name = ggml_get_name(tensor); + // TODO: avoid hardcoded tensor names - use the TN_* constants + const auto tn = LLM_TN(model.arch); + + auto use_more_bits = [](int i_layer, int num_layers) -> bool { + return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2; + }; + + if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { + int nx = tensor->ne[0]; + if (model.arch == LLM_ARCH_FALCON || nx % QK_K != 0) { + new_type = GGML_TYPE_Q8_0; + } + else if (new_type != GGML_TYPE_Q8_0) { + new_type = GGML_TYPE_Q6_K; + } + } else if (name.find("attn_v.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { + new_type = *i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; + else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && + use_more_bits(*i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && *i_attention_wv < 4) new_type = GGML_TYPE_Q5_K; + else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) && + (*i_attention_wv < n_attention_wv/8 || *i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K; + if (model.type == MODEL_70B) { + // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is + // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with + // nearly negligible increase in model size by quantizing this tensor with more bits: + if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K; + } + ++*i_attention_wv; + } else if (name.find("ffn_down.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { + new_type = *i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K + : model.arch != LLM_ARCH_FALCON || use_more_bits(*i_feed_forward_w2, n_feed_forward_w2) ? GGML_TYPE_Q4_K + : GGML_TYPE_Q3_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) { + new_type = model.arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) { + if (model.arch == LLM_ARCH_FALCON) { + new_type = *i_feed_forward_w2 < 2 ? GGML_TYPE_Q6_K : + use_more_bits(*i_feed_forward_w2, n_feed_forward_w2) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; + } else { + if (use_more_bits(*i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; + } + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(*i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && model.arch != LLM_ARCH_FALCON && *i_feed_forward_w2 < 4) { + new_type = GGML_TYPE_Q5_K; + } + ++*i_feed_forward_w2; + } else if (name.find("attn_output.weight") != std::string::npos) { + if (model.arch != LLM_ARCH_FALCON) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; + } else { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; + } + } + else if (name.find("attn_qkv.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K; + } + else if (name.find("ffn_gate.weight") != std::string::npos || name.find("ffn_up.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + } + // This can be used to reduce the size of the Q5_K_S model. + // The associated PPL increase is fully in line with the size reduction + //else { + // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; + //} + bool convert_incompatible_tensor = false; + if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || + new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) { + int nx = tensor->ne[0]; + int ny = tensor->ne[1]; + if (nx % QK_K != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for k-quants\n", __func__, nx, ny, QK_K); + convert_incompatible_tensor = true; + } + } + if (convert_incompatible_tensor) { + if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { + new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing. + LLAMA_LOG_WARN("F16 will be used for this tensor instead.\n"); + } else if (name == tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { + new_type = GGML_TYPE_Q4_0; //fall back to Q4_0 instead of just failing. + LLAMA_LOG_WARN("Q4_0 will be used for this tensor instead.\n"); + } else { + throw std::runtime_error("Unsupported tensor size encountered\n"); + } } + + return new_type; } +#endif static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { ggml_type quantized_type; @@ -4781,18 +5712,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s std::vector hist_all(1 << 4, 0); std::vector workers; + workers.reserve(nthread); std::mutex mutex; -#ifdef GGML_USE_K_QUANTS - auto use_more_bits = [] (int i_layer, int num_layers) -> bool { - return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2; - }; -#endif - int idx = 0; - std::vector read_data; - std::vector work; + std::vector> read_data; + std::vector> work; + std::vector> f32_conv_buf; // populate the original tensors so we get an initial meta data for (int i = 0; i < ml->n_tensors; ++i) { @@ -4814,7 +5741,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const std::string name = ggml_get_name(tensor); - read_data.resize(ggml_nbytes(tensor)); + if (read_data.size() < ggml_nbytes(tensor)) { + read_data.resize(ggml_nbytes(tensor)); + } tensor->data = read_data.data(); ml->load_data_for(tensor); @@ -4839,101 +5768,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (quantize) { new_type = quantized_type; #ifdef GGML_USE_K_QUANTS - // TODO: avoid hardcoded tensor names - use the TN_* constants - const auto tn = LLM_TN(ml->get_arch()); - - if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { - int nx = tensor->ne[0]; - if (model.arch == LLM_ARCH_FALCON || nx % QK_K != 0) { - new_type = GGML_TYPE_Q8_0; - } - else if (new_type != GGML_TYPE_Q8_0) { - new_type = GGML_TYPE_Q6_K; - } - } else if (name.find("attn_v.weight") != std::string::npos) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { - new_type = i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; - } - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; - else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && - use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_attention_wv < 4) new_type = GGML_TYPE_Q5_K; - else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) && - (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K; - if (model.type == MODEL_70B) { - // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is - // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with - // nearly negligible increase in model size by quantizing this tensor with more bits: - if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K; - } - ++i_attention_wv; - } else if (name.find("ffn_down.weight") != std::string::npos) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { - new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K - : model.arch != LLM_ARCH_FALCON || use_more_bits(i_feed_forward_w2, n_feed_forward_w2) ? GGML_TYPE_Q4_K - : GGML_TYPE_Q3_K; - } - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) { - new_type = model.arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K; - } - else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) { - if (model.arch == LLM_ARCH_FALCON) { - new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q6_K : - use_more_bits(i_feed_forward_w2, n_feed_forward_w2) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; - } else { - if (use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; - } - } - else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && model.arch != LLM_ARCH_FALCON && i_feed_forward_w2 < 4) { - new_type = GGML_TYPE_Q5_K; - } - ++i_feed_forward_w2; - } else if (name.find("attn_output.weight") != std::string::npos) { - if (model.arch != LLM_ARCH_FALCON) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) new_type = GGML_TYPE_Q4_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; - } else { - if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; - } - } - else if (name.find("attn_qkv.weight") != std::string::npos) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K; - } - else if (name.find("ffn_gate.weight") != std::string::npos || name.find("ffn_up.weight") != std::string::npos) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - } - // This can be used to reduce the size of the Q5_K_S model. - // The associated PPL increase is fully in line with the size reduction - //else { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; - //} - bool convert_incompatible_tensor = false; - if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || - new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) { - int nx = tensor->ne[0]; - int ny = tensor->ne[1]; - if (nx % QK_K != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for k-quants\n", __func__, nx, ny, QK_K); - convert_incompatible_tensor = true; - } - } - if (convert_incompatible_tensor) { - if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { - new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing. - LLAMA_LOG_WARN("F16 will be used for this tensor instead.\n"); - } else if (name == tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { - new_type = GGML_TYPE_Q4_0; //fall back to Q4_0 instead of just failing. - LLAMA_LOG_WARN("Q4_0 will be used for this tensor instead.\n"); - } else { - throw std::runtime_error("Unsupported tensor size encountered\n"); - } - } + new_type = get_k_quant_type( + new_type, tensor, model, ftype, &i_attention_wv, n_attention_wv, &i_feed_forward_w2, n_feed_forward_w2 + ); #endif // If we've decided to quantize to the same type the tensor is already // in then there's nothing to do. @@ -4948,23 +5785,24 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const size_t nelements = ggml_nelements(tensor); float * f32_data; - std::vector f32_conv_buf; if (tensor->type == GGML_TYPE_F32) { f32_data = (float *) tensor->data; } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); } else { - llama_convert_tensor_internal(tensor, f32_conv_buf, nelements, nthread); + llama_convert_tensor_internal(tensor, f32_conv_buf, workers, nelements, nthread); f32_data = (float *) f32_conv_buf.data(); } LLAMA_LOG_INFO("quantizing to %s .. ", ggml_type_name(new_type)); fflush(stdout); - work.resize(nelements * 4); // upper bound on size + if (work.size() < nelements * 4) { + work.resize(nelements * 4); // upper bound on size + } new_data = work.data(); - std::vector hist_cur(1 << 4, 0); + std::array hist_cur = {}; static const int chunk_size = 32 * 512; const int nchunk = (nelements + chunk_size - 1)/chunk_size; @@ -4975,13 +5813,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s size_t counter = 0; new_size = 0; auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, nelements]() { - std::vector local_hist; + std::array local_hist = {}; size_t local_size = 0; while (true) { std::unique_lock lock(mutex); size_t first = counter; counter += chunk_size; if (first >= nelements) { - if (!local_hist.empty()) { + if (local_size > 0) { for (int j=0; j %8.2f MB | hist: ", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); @@ -5068,7 +5899,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } // TODO: after the GGUF PR, this likely won't work and needs to be updated -int llama_apply_lora_from_file_internal(const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads) { +static int llama_apply_lora_from_file_internal( + const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads +) { LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); const int64_t t_start_lora_us = ggml_time_us(); @@ -5352,8 +6185,8 @@ struct llama_context_params llama_context_default_params() { /*.n_gpu_layers =*/ 0, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, - /*.rope_freq_base =*/ 10000.0f, - /*.rope_freq_scale =*/ 1.0f, + /*.rope_freq_base =*/ 0.0f, + /*.rope_freq_scale =*/ 0.0f, /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.low_vram =*/ false, @@ -5615,7 +6448,7 @@ struct llama_context * llama_new_context_with_model( return ctx; } -struct llama_context * llama_init_from_file( +static struct llama_context * llama_init_from_file( const char * path_model, struct llama_context_params params) { struct llama_model * model = llama_load_model_from_file(path_model, params); @@ -5820,7 +6653,7 @@ struct llama_data_file_context : llama_data_context { * llama_copy_state_data(ctx, &data_ctx); * */ -void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { +static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { // copy rng { std::stringstream rng_ss; @@ -6204,22 +7037,24 @@ llama_token llama_token_nl(const struct llama_context * ctx) { int llama_tokenize( struct llama_context * ctx, const char * text, + int text_len, llama_token * tokens, int n_max_tokens, bool add_bos) { - return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos); + return llama_tokenize_with_model(&ctx->model, text, text_len, tokens, n_max_tokens, add_bos); } int llama_tokenize_with_model( const struct llama_model * model, const char * text, + int text_len, llama_token * tokens, int n_max_tokens, bool add_bos) { - auto res = llama_tokenize_internal(model->vocab, text, add_bos); + auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos); if (n_max_tokens < (int) res.size()) { - LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); + // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); return -((int) res.size()); } @@ -6358,7 +7193,9 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { } // For internal test use -const std::vector>& llama_internal_get_tensor_map(struct llama_context * ctx) { +const std::vector> & llama_internal_get_tensor_map( + struct llama_context * ctx +) { return ctx->model.tensors_by_name; } diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.h b/plugins/wasi_nn/thirdparty/ggml/llama.h index 37975beb..369be048 100644 --- a/plugins/wasi_nn/thirdparty/ggml/llama.h +++ b/plugins/wasi_nn/thirdparty/ggml/llama.h @@ -374,6 +374,7 @@ extern "C" { LLAMA_API int llama_tokenize( struct llama_context * ctx, const char * text, + int text_len, llama_token * tokens, int n_max_tokens, bool add_bos); @@ -381,6 +382,7 @@ extern "C" { LLAMA_API int llama_tokenize_with_model( const struct llama_model * model, const char * text, + int text_len, llama_token * tokens, int n_max_tokens, bool add_bos); @@ -540,7 +542,9 @@ extern "C" { struct ggml_tensor; -const std::vector>& llama_internal_get_tensor_map(struct llama_context * ctx); +const std::vector> & llama_internal_get_tensor_map( + struct llama_context * ctx +); #endif // LLAMA_API_INTERNAL From d29b7fead132d5d0a056ac4541c21c0ca1451c19 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 28 Sep 2023 00:31:42 +0800 Subject: [PATCH 156/623] [WASI-NN] ggml backend: handle exceptions & add more comments Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 49 ++++++++++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 7 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 791ab5e3..03387726 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -12,18 +12,29 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML -llama_context_params wasmedge_llama_context_params() noexcept { - llama_context_params Params = llama_context_default_params(); +ErrNo wasmedge_llama_context_params(llama_context_params &Params) noexcept { const char *LlamaNContextEnv = std::getenv("LLAMA_N_CTX"); const char *LlamaLogEnv = std::getenv("LLAMA_LOG"); if (LlamaNContextEnv != nullptr) { - Params.n_ctx = std::stoi(LlamaNContextEnv); + try { + Params.n_ctx = std::stoi(LlamaNContextEnv); + } catch (const std::out_of_range &e) { + spdlog::error( + "[WASI-NN] GGML backend: set n_ctx failed: out_of_range {}"sv, + e.what()); + return ErrNo::InvalidArgument; + } catch (const std::invalid_argument &e) { + spdlog::error( + "[WASI-NN] GGML backend: set n_ctx failed: invalid_argument {}"sv, + e.what()); + return ErrNo::InvalidArgument; + } if (LlamaLogEnv != nullptr) { spdlog::info("[WASI-NN] GGML backend: set n_ctx to {}"sv, Params.n_ctx); } } - return Params; + return ErrNo::Success; } Expect load(WasiNNEnvironment &Env, Span> Builders, @@ -66,7 +77,12 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize ggml model. gpt_params Params; llama_backend_init(Params.numa); - llama_context_params ContextParams = wasmedge_llama_context_params(); + llama_context_params ContextParams = llama_context_default_params(); + ErrNo Err = wasmedge_llama_context_params(ContextParams); + if (Err != ErrNo::Success) { + spdlog::error("[WASI-NN] GGML backend: Error: unable to init context."sv); + return ErrNo::InvalidArgument; + } GraphRef.LlamaModel = llama_load_model_from_file(ModelFilePath.c_str(), ContextParams); if (GraphRef.LlamaModel == nullptr) { @@ -99,7 +115,12 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); // Initialize the llama context. - llama_context_params ContextParams = wasmedge_llama_context_params(); + llama_context_params ContextParams = llama_context_default_params(); + ErrNo Err = wasmedge_llama_context_params(ContextParams); + if (Err != ErrNo::Success) { + spdlog::error("[WASI-NN] GGML backend: Error: unable to init context."sv); + return ErrNo::InvalidArgument; + } GraphRef.LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); @@ -152,10 +173,24 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // TODO: recompute a compressed context based on previous tokens once the // cache is full. const int MaxContextSize = llama_n_ctx(GraphRef.LlamaContext); + // NPredict is the number of tokens to predict. Same as -n, --n-predict in + // llama.cpp. int NPredict = std::numeric_limits::max(); const char *LlamaNPredictEnv = std::getenv("LLAMA_N_PREDICT"); if (LlamaNPredictEnv != nullptr) { - NPredict = std::stoi(LlamaNPredictEnv); + try { + NPredict = std::stoi(LlamaNPredictEnv); + } catch (const std::out_of_range &e) { + spdlog::error( + "[WASI-NN] GGML backend: set n_predict failed: out_of_range {}"sv, + e.what()); + return ErrNo::InvalidArgument; + } catch (const std::invalid_argument &e) { + spdlog::error( + "[WASI-NN] GGML backend: set n_predict failed: invalid_argument {}"sv, + e.what()); + return ErrNo::InvalidArgument; + } if (LlamaLogEnv != nullptr) { spdlog::info("[WASI-NN] GGML backend: set n_predict to {}"sv, NPredict); } From e9cf23285f6ac9b80e9e7b99e986dc294990dd86 Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 22 Sep 2023 00:42:39 +0800 Subject: [PATCH 157/623] [WASI-NN] Enable ggml llama.cpp plugin on macOS Signed-off-by: hydai --- .../wasi_nn/thirdparty/ggml/CMakeLists.txt | 20 +- plugins/wasi_nn/thirdparty/ggml/ggml-metal.h | 85 + plugins/wasi_nn/thirdparty/ggml/ggml-metal.m | 1268 ++++++++++ .../wasi_nn/thirdparty/ggml/ggml-metal.metal | 2254 +++++++++++++++++ plugins/wasi_nn/wasinnenv.cpp | 2 + 5 files changed, 3612 insertions(+), 17 deletions(-) create mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-metal.h create mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-metal.m create mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal diff --git a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt index 6e97d00d..f18b264a 100644 --- a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt +++ b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt @@ -2,11 +2,9 @@ # Option list # -if (APPLE) - set(LLAMA_METAL_DEFAULT ON) -else() - set(LLAMA_METAL_DEFAULT OFF) -endif() +# Get errors when enabling METAL API +# Disable it currently +set(LLAMA_METAL_DEFAULT OFF) # general option(LLAMA_STATIC "llama: static link libraries" OFF) @@ -89,18 +87,6 @@ if (NOT MSVC) endif() endif() -if (APPLE AND LLAMA_ACCELERATE) - find_library(ACCELERATE_FRAMEWORK Accelerate) - if (ACCELERATE_FRAMEWORK) - message(STATUS "Accelerate framework found") - - add_compile_definitions(GGML_USE_ACCELERATE) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) - else() - message(WARNING "Accelerate framework not found") - endif() -endif() - if (LLAMA_METAL) find_library(FOUNDATION_LIBRARY Foundation REQUIRED) find_library(METAL_FRAMEWORK Metal REQUIRED) diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h new file mode 100644 index 00000000..fca28d37 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h @@ -0,0 +1,85 @@ +// An interface allowing to compute ggml_cgraph with Metal +// +// This is a fully functional interface that extends ggml with GPU support for Apple devices. +// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.) +// +// How it works? +// +// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this +// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you +// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.) +// +// You only need to make sure that all memory buffers that you used during the graph creation +// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is +// used during the graph evaluation to determine the arguments of the compute kernels. +// +// Synchronization between device and host memory (for example for input and output tensors) +// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions. +// + +#pragma once + +#include +#include + +// max memory buffers that can be mapped to the device +#define GGML_METAL_MAX_BUFFERS 16 +#define GGML_METAL_MAX_COMMAND_BUFFERS 32 + +struct ggml_tensor; +struct ggml_cgraph; + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_metal_context; + +// number of command buffers to use +struct ggml_metal_context * ggml_metal_init(int n_cb); +void ggml_metal_free(struct ggml_metal_context * ctx); + +void * ggml_metal_host_malloc(size_t n); +void ggml_metal_host_free (void * data); + +// set the number of command buffers to use +void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb); + +// creates a mapping between a host memory buffer and a device memory buffer +// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute +// - the mapping is used during computation to determine the arguments of the compute kernels +// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal +// - max_size specifies the maximum size of a tensor and is used to create shared views such +// that it is guaranteed that the tensor will fit in at least one of the views +// +bool ggml_metal_add_buffer( + struct ggml_metal_context * ctx, + const char * name, + void * data, + size_t size, + size_t max_size); + +// set data from host memory into the device +void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t); + +// get data from the device into host memory +void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t); + +// try to find operations that can be run concurrently in the graph +// you should run it again if the topology of your graph changes +void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem); + +// if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized +int ggml_metal_if_optimized(struct ggml_metal_context * ctx); + +// output the concur_list for ggml_alloc +int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx); + +// same as ggml_graph_compute but uses Metal +// creates gf->n_threads command buffers in parallel +void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); + +#ifdef __cplusplus +} +#endif + diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m new file mode 100644 index 00000000..4f3f14e2 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m @@ -0,0 +1,1268 @@ +#import "ggml-metal.h" + +#import "ggml.h" + +#import + +#import + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// TODO: temporary - reuse llama.cpp logging +#ifdef GGML_METAL_NDEBUG +#define metal_printf(...) +#else +#define metal_printf(...) fprintf(stderr, __VA_ARGS__) +#endif + +#define UNUSED(x) (void)(x) + +#define GGML_MAX_CONCUR (2*GGML_MAX_NODES) + +struct ggml_metal_buffer { + const char * name; + + void * data; + size_t size; + + id metal; +}; + +struct ggml_metal_context { + int n_cb; + + id device; + id queue; + id library; + + id command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS]; + id command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS]; + + dispatch_queue_t d_queue; + + int n_buffers; + struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; + + int concur_list[GGML_MAX_CONCUR]; + int concur_list_len; + + // custom kernels +#define GGML_METAL_DECL_KERNEL(name) \ + id function_##name; \ + id pipeline_##name + + GGML_METAL_DECL_KERNEL(add); + GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast + GGML_METAL_DECL_KERNEL(mul); + GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast + GGML_METAL_DECL_KERNEL(scale); + GGML_METAL_DECL_KERNEL(silu); + GGML_METAL_DECL_KERNEL(relu); + GGML_METAL_DECL_KERNEL(gelu); + GGML_METAL_DECL_KERNEL(soft_max); + GGML_METAL_DECL_KERNEL(soft_max_4); + GGML_METAL_DECL_KERNEL(diag_mask_inf); + GGML_METAL_DECL_KERNEL(diag_mask_inf_8); + GGML_METAL_DECL_KERNEL(get_rows_f16); + GGML_METAL_DECL_KERNEL(get_rows_q4_0); + GGML_METAL_DECL_KERNEL(get_rows_q4_1); + GGML_METAL_DECL_KERNEL(get_rows_q8_0); + GGML_METAL_DECL_KERNEL(get_rows_q2_K); + GGML_METAL_DECL_KERNEL(get_rows_q3_K); + GGML_METAL_DECL_KERNEL(get_rows_q4_K); + GGML_METAL_DECL_KERNEL(get_rows_q5_K); + GGML_METAL_DECL_KERNEL(get_rows_q6_K); + GGML_METAL_DECL_KERNEL(rms_norm); + GGML_METAL_DECL_KERNEL(norm); + GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); + GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row); + GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4); + GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_DECL_KERNEL(rope); + GGML_METAL_DECL_KERNEL(alibi_f32); + GGML_METAL_DECL_KERNEL(cpy_f32_f16); + GGML_METAL_DECL_KERNEL(cpy_f32_f32); + GGML_METAL_DECL_KERNEL(cpy_f16_f16); + +#undef GGML_METAL_DECL_KERNEL +}; + +// MSL code +// TODO: move the contents here when ready +// for now it is easier to work in a separate file +static NSString * const msl_library_source = @"see metal.metal"; + +// Here to assist with NSBundle Path Hack +@interface GGMLMetalClass : NSObject +@end +@implementation GGMLMetalClass +@end + +struct ggml_metal_context * ggml_metal_init(int n_cb) { + metal_printf("%s: allocating\n", __func__); + + id device; + NSString * s; + +#if TARGET_OS_OSX + // Show all the Metal device instances in the system + NSArray * devices = MTLCopyAllDevices(); + for (device in devices) { + s = [device name]; + metal_printf("%s: found device: %s\n", __func__, [s UTF8String]); + } +#endif + + // Pick and show default Metal device + device = MTLCreateSystemDefaultDevice(); + s = [device name]; + metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]); + + // Configure context + struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context)); + ctx->device = device; + ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); + ctx->queue = [ctx->device newCommandQueue]; + ctx->n_buffers = 0; + ctx->concur_list_len = 0; + + ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); + +#ifdef GGML_SWIFT + // load the default.metallib file + { + NSError * error = nil; + + NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; + NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"]; + NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath]; + NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"]; + NSURL * libURL = [NSURL fileURLWithPath:libPath]; + + // Load the metallib file into a Metal library + ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + + if (error) { + metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } + } +#else + UNUSED(msl_library_source); + + // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource + { + NSError * error = nil; + + //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; + NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; + NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]); + + NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; + if (error) { + metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } + +#ifdef GGML_QKK_64 + MTLCompileOptions* options = [MTLCompileOptions new]; + options.preprocessorMacros = @{ @"QK_K" : @(64) }; + ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; +#else + ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error]; +#endif + if (error) { + metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } + } +#endif + + // load kernels + { + NSError * error = nil; +#define GGML_METAL_ADD_KERNEL(name) \ + ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ + ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \ + metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \ + (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \ + (int) ctx->pipeline_##name.threadExecutionWidth); \ + if (error) { \ + metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ + return NULL; \ + } + + GGML_METAL_ADD_KERNEL(add); + GGML_METAL_ADD_KERNEL(add_row); + GGML_METAL_ADD_KERNEL(mul); + GGML_METAL_ADD_KERNEL(mul_row); + GGML_METAL_ADD_KERNEL(scale); + GGML_METAL_ADD_KERNEL(silu); + GGML_METAL_ADD_KERNEL(relu); + GGML_METAL_ADD_KERNEL(gelu); + GGML_METAL_ADD_KERNEL(soft_max); + GGML_METAL_ADD_KERNEL(soft_max_4); + GGML_METAL_ADD_KERNEL(diag_mask_inf); + GGML_METAL_ADD_KERNEL(diag_mask_inf_8); + GGML_METAL_ADD_KERNEL(get_rows_f16); + GGML_METAL_ADD_KERNEL(get_rows_q4_0); + GGML_METAL_ADD_KERNEL(get_rows_q4_1); + GGML_METAL_ADD_KERNEL(get_rows_q8_0); + GGML_METAL_ADD_KERNEL(get_rows_q2_K); + GGML_METAL_ADD_KERNEL(get_rows_q3_K); + GGML_METAL_ADD_KERNEL(get_rows_q4_K); + GGML_METAL_ADD_KERNEL(get_rows_q5_K); + GGML_METAL_ADD_KERNEL(get_rows_q6_K); + GGML_METAL_ADD_KERNEL(rms_norm); + GGML_METAL_ADD_KERNEL(norm); + GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row); + GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4); + GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_ADD_KERNEL(rope); + GGML_METAL_ADD_KERNEL(alibi_f32); + GGML_METAL_ADD_KERNEL(cpy_f32_f16); + GGML_METAL_ADD_KERNEL(cpy_f32_f32); + GGML_METAL_ADD_KERNEL(cpy_f16_f16); + +#undef GGML_METAL_ADD_KERNEL + } + + metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); +#if TARGET_OS_OSX + metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + if (ctx->device.maxTransferRate != 0) { + metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); + } else { + metal_printf("%s: maxTransferRate = built-in GPU\n", __func__); + } +#endif + + return ctx; +} + +void ggml_metal_free(struct ggml_metal_context * ctx) { + metal_printf("%s: deallocating\n", __func__); +#define GGML_METAL_DEL_KERNEL(name) \ + [ctx->function_##name release]; \ + [ctx->pipeline_##name release]; + + GGML_METAL_DEL_KERNEL(add); + GGML_METAL_DEL_KERNEL(add_row); + GGML_METAL_DEL_KERNEL(mul); + GGML_METAL_DEL_KERNEL(mul_row); + GGML_METAL_DEL_KERNEL(scale); + GGML_METAL_DEL_KERNEL(silu); + GGML_METAL_DEL_KERNEL(relu); + GGML_METAL_DEL_KERNEL(gelu); + GGML_METAL_DEL_KERNEL(soft_max); + GGML_METAL_DEL_KERNEL(soft_max_4); + GGML_METAL_DEL_KERNEL(diag_mask_inf_8); + GGML_METAL_DEL_KERNEL(get_rows_f16); + GGML_METAL_DEL_KERNEL(get_rows_q4_0); + GGML_METAL_DEL_KERNEL(get_rows_q4_1); + GGML_METAL_DEL_KERNEL(get_rows_q8_0); + GGML_METAL_DEL_KERNEL(get_rows_q2_K); + GGML_METAL_DEL_KERNEL(get_rows_q3_K); + GGML_METAL_DEL_KERNEL(get_rows_q4_K); + GGML_METAL_DEL_KERNEL(get_rows_q5_K); + GGML_METAL_DEL_KERNEL(get_rows_q6_K); + GGML_METAL_DEL_KERNEL(rms_norm); + GGML_METAL_DEL_KERNEL(norm); + GGML_METAL_DEL_KERNEL(mul_mat_f16_f32); + GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row); + GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4); + GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32); + GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32); + GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32); + GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32); + GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32); + GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_DEL_KERNEL(rope); + GGML_METAL_DEL_KERNEL(alibi_f32); + GGML_METAL_DEL_KERNEL(cpy_f32_f16); + GGML_METAL_DEL_KERNEL(cpy_f32_f32); + GGML_METAL_DEL_KERNEL(cpy_f16_f16); + +#undef GGML_METAL_DEL_KERNEL + + for (int i = 0; i < ctx->n_buffers; ++i) { + [ctx->buffers[i].metal release]; + } + + [ctx->library release]; + [ctx->queue release]; + [ctx->device release]; + + dispatch_release(ctx->d_queue); + + free(ctx); +} + +void * ggml_metal_host_malloc(size_t n) { + void * data = NULL; + const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); + if (result != 0) { + metal_printf("%s: error: posix_memalign failed\n", __func__); + return NULL; + } + + return data; +} + +void ggml_metal_host_free(void * data) { + free(data); +} + +void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) { + ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); +} + +int ggml_metal_if_optimized(struct ggml_metal_context * ctx) { + return ctx->concur_list_len; +} + +int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) { + return ctx->concur_list; +} + +// finds the Metal buffer that contains the tensor data on the GPU device +// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the +// Metal buffer based on the host memory pointer +// +static id ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) { + //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); + + const int64_t tsize = ggml_nbytes(t); + + // find the view that contains the tensor fully + for (int i = 0; i < ctx->n_buffers; ++i) { + const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; + + if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) { + *offs = (size_t) ioffs; + + //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs); + + return ctx->buffers[i].metal; + } + } + + metal_printf("%s: error: buffer is nil\n", __func__); + + return nil; +} + +bool ggml_metal_add_buffer( + struct ggml_metal_context * ctx, + const char * name, + void * data, + size_t size, + size_t max_size) { + if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) { + metal_printf("%s: too many buffers\n", __func__); + return false; + } + + if (data) { + // verify that the buffer does not overlap with any of the existing buffers + for (int i = 0; i < ctx->n_buffers; ++i) { + const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data; + + if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) { + metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name); + return false; + } + } + + const size_t size_page = sysconf(_SC_PAGESIZE); + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + // the buffer fits into the max buffer size allowed by the device + if (size_aligned <= ctx->device.maxBufferLength) { + ctx->buffers[ctx->n_buffers].name = name; + ctx->buffers[ctx->n_buffers].data = data; + ctx->buffers[ctx->n_buffers].size = size; + + ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0); + return false; + } + + metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0); + + ++ctx->n_buffers; + } else { + // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into + // one of the views + const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case + const size_t size_step = ctx->device.maxBufferLength - size_ovlp; + const size_t size_view = ctx->device.maxBufferLength; + + for (size_t i = 0; i < size; i += size_step) { + const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); + + ctx->buffers[ctx->n_buffers].name = name; + ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); + ctx->buffers[ctx->n_buffers].size = size_step_aligned; + + ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0); + return false; + } + + metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i); + if (i + size_step < size) { + metal_printf("\n"); + } + + ++ctx->n_buffers; + } + } + +#if TARGET_OS_OSX + metal_printf(", (%8.2f / %8.2f)", + ctx->device.currentAllocatedSize / 1024.0 / 1024.0, + ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + + if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) { + metal_printf(", warning: current allocated size is greater than the recommended max working set size\n"); + } else { + metal_printf("\n"); + } +#else + metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0); +#endif + } + + return true; +} + +void ggml_metal_set_tensor( + struct ggml_metal_context * ctx, + struct ggml_tensor * t) { + size_t offs; + id id_dst = ggml_metal_get_buffer(ctx, t, &offs); + + memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t)); +} + +void ggml_metal_get_tensor( + struct ggml_metal_context * ctx, + struct ggml_tensor * t) { + size_t offs; + id id_src = ggml_metal_get_buffer(ctx, t, &offs); + + memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t)); +} + +void ggml_metal_graph_find_concurrency( + struct ggml_metal_context * ctx, + struct ggml_cgraph * gf, bool check_mem) { + int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time + int nodes_unused[GGML_MAX_CONCUR]; + + for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; } + for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; } + ctx->concur_list_len = 0; + + int n_left = gf->n_nodes; + int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list + int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos + + while (n_left > 0) { + // number of nodes at a layer (that can be issued concurrently) + int concurrency = 0; + for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) { + if (nodes_unused[i]) { + // if the requirements for gf->nodes[i] are satisfied + int exe_flag = 1; + + // scan all srcs + for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) { + struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind]; + if (src_cur) { + // if is leaf nodes it's satisfied. + // TODO: ggml_is_leaf() + if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) { + continue; + } + + // otherwise this src should be the output from previous nodes. + int is_found = 0; + + // scan 2*search_depth back because we inserted barrier. + //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) { + for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) { + if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) { + is_found = 1; + break; + } + } + if (is_found == 0) { + exe_flag = 0; + break; + } + } + } + if (exe_flag && check_mem) { + // check if nodes[i]'s data will be overwritten by a node before nodes[i]. + // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3] + int64_t data_start = (int64_t) gf->nodes[i]->data; + int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]); + for (int j = n_start; j < i; j++) { + if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \ + && gf->nodes[j]->op != GGML_OP_VIEW \ + && gf->nodes[j]->op != GGML_OP_TRANSPOSE \ + && gf->nodes[j]->op != GGML_OP_PERMUTE) { + if (((int64_t)gf->nodes[j]->data) >= data_start + length || \ + ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) { + continue; + } + + exe_flag = 0; + } + } + } + if (exe_flag) { + ctx->concur_list[level_pos + concurrency] = i; + nodes_unused[i] = 0; + concurrency++; + ctx->concur_list_len++; + } + } + } + n_left -= concurrency; + // adding a barrier different layer + ctx->concur_list[level_pos + concurrency] = -1; + ctx->concur_list_len++; + // jump all sorted nodes at nodes_bak + while (!nodes_unused[n_start]) { + n_start++; + } + level_pos += concurrency + 1; + } + + if (ctx->concur_list_len > GGML_MAX_CONCUR) { + metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__); + } +} + +void ggml_metal_graph_compute( + struct ggml_metal_context * ctx, + struct ggml_cgraph * gf) { + @autoreleasepool { + + // if there is ctx->concur_list, dispatch concurrently + // else fallback to serial dispatch + MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; + + const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR; + + const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; + edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial; + + // create multiple command buffers and enqueue them + // then, we encode the graph into the command buffers in parallel + + const int n_cb = ctx->n_cb; + + for (int i = 0; i < n_cb; ++i) { + ctx->command_buffers[i] = [ctx->queue commandBuffer]; + + // enqueue the command buffers in order to specify their execution order + [ctx->command_buffers[i] enqueue]; + + ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc]; + } + + for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { + const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; + + dispatch_async(ctx->d_queue, ^{ + size_t offs_src0 = 0; + size_t offs_src1 = 0; + size_t offs_dst = 0; + + id command_buffer = ctx->command_buffers[cb_idx]; + id encoder = ctx->command_encoders[cb_idx]; + + const int node_start = (cb_idx + 0) * n_nodes_per_cb; + const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); + + for (int ind = node_start; ind < node_end; ++ind) { + const int i = has_concur ? ctx->concur_list[ind] : ind; + + if (i == -1) { + [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; + continue; + } + + //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); + + struct ggml_tensor * src0 = gf->nodes[i]->src[0]; + struct ggml_tensor * src1 = gf->nodes[i]->src[1]; + struct ggml_tensor * dst = gf->nodes[i]; + + const int64_t ne00 = src0 ? src0->ne[0] : 0; + const int64_t ne01 = src0 ? src0->ne[1] : 0; + const int64_t ne02 = src0 ? src0->ne[2] : 0; + const int64_t ne03 = src0 ? src0->ne[3] : 0; + + const uint64_t nb00 = src0 ? src0->nb[0] : 0; + const uint64_t nb01 = src0 ? src0->nb[1] : 0; + const uint64_t nb02 = src0 ? src0->nb[2] : 0; + const uint64_t nb03 = src0 ? src0->nb[3] : 0; + + const int64_t ne10 = src1 ? src1->ne[0] : 0; + const int64_t ne11 = src1 ? src1->ne[1] : 0; + const int64_t ne12 = src1 ? src1->ne[2] : 0; + const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const uint64_t nb10 = src1 ? src1->nb[0] : 0; + const uint64_t nb11 = src1 ? src1->nb[1] : 0; + const uint64_t nb12 = src1 ? src1->nb[2] : 0; + const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; + + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; + + id id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil; + id id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil; + id id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil; + + //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + //if (src0) { + // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, + // ggml_is_contiguous(src0), src0->name); + //} + //if (src1) { + // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, + // ggml_is_contiguous(src1), src1->name); + //} + //if (dst) { + // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, + // dst->name); + //} + + switch (dst->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + { + // noop + } break; + case GGML_OP_ADD: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // utilize float4 + GGML_ASSERT(ne00 % 4 == 0); + const int64_t nb = ne00/4; + + if (ggml_nelements(src1) == ne10) { + // src1 is a row + [encoder setComputePipelineState:ctx->pipeline_add_row]; + } else { + [encoder setComputePipelineState:ctx->pipeline_add]; + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:3]; + + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_MUL: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // utilize float4 + GGML_ASSERT(ne00 % 4 == 0); + const int64_t nb = ne00/4; + + if (ggml_nelements(src1) == ne10) { + // src1 is a row + [encoder setComputePipelineState:ctx->pipeline_mul_row]; + } else { + [encoder setComputePipelineState:ctx->pipeline_mul]; + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:3]; + + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SCALE: + { + const float scale = *(const float *) src1->data; + + [encoder setComputePipelineState:ctx->pipeline_scale]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(gf->nodes[i])) { + case GGML_UNARY_OP_SILU: + { + [encoder setComputePipelineState:ctx->pipeline_silu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_RELU: + { + [encoder setComputePipelineState:ctx->pipeline_relu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU: + { + [encoder setComputePipelineState:ctx->pipeline_gelu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + default: + { + metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); + GGML_ASSERT(false); + } + } break; + case GGML_OP_SOFT_MAX: + { + const int nth = 32; + + if (ne00%4 == 0) { + [encoder setComputePipelineState:ctx->pipeline_soft_max_4]; + } else { + [encoder setComputePipelineState:ctx->pipeline_soft_max]; + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_DIAG_MASK_INF: + { + const int n_past = ((int32_t *)(dst->op_params))[0]; + + if (ne00%8 == 0) { + [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8]; + } else { + [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + + if (ne00%8 == 0) { + [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + else { + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + } break; + case GGML_OP_MUL_MAT: + { + // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224 + + GGML_ASSERT(ne00 == ne10); + // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere + uint gqa = ne12/ne02; + GGML_ASSERT(ne03 == ne13); + + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + if (ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && + src1t == GGML_TYPE_F32 && + [ctx->device supportsFamily:MTLGPUFamilyApple7] && + ne00%32 == 0 && + ne11 > 1) { + switch (src0->type) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; + case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; + case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; + case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; + case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; + default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F16: + { + nth0 = 32; + nth1 = 1; + if (ne11 * ne12 < 4) { + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4]; + nrows = ne11; + } else { + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + nrows = 4; + } + } break; + case GGML_TYPE_Q4_0: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; + } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32]; + } break; + case GGML_TYPE_Q8_0: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32]; + } break; + case GGML_TYPE_Q2_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; + } break; + case GGML_TYPE_Q3_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; + } break; + case GGML_TYPE_Q4_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 4; //1; + nth1 = 8; //32; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; + } break; + case GGML_TYPE_Q5_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32]; + } break; + case GGML_TYPE_Q6_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 2; + nth1 = 32; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32]; + } break; + default: + { + metal_printf("Asserting on type %d\n",(int)src0t); + GGML_ASSERT(false && "not implemented"); + } + }; + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; + + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q3_K) { +#ifdef GGML_QKK_64 + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#else + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#endif + } + else if (src0t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + int64_t ny = (ne11 + nrows - 1)/nrows; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + } + } break; + case GGML_OP_GET_ROWS: + { + switch (src0->type) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; + case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; + case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; + case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; + case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; + default: GGML_ASSERT(false && "not implemented"); + } + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5]; + + const int64_t n = ggml_nelements(src1); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_RMS_NORM: + { + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int nth = 512; + + [encoder setComputePipelineState:ctx->pipeline_rms_norm]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_NORM: + { + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int nth = 256; + + [encoder setComputePipelineState:ctx->pipeline_norm]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ALIBI: + { + GGML_ASSERT((src0t == GGML_TYPE_F32)); + + const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); + + if (__builtin_popcount(n_head) != 1) { + GGML_ASSERT(false && "only power-of-two n_head implemented"); + } + + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + + [encoder setComputePipelineState:ctx->pipeline_alibi_f32]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; + + const int nth = 32; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ROPE: + { + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + + float freq_base; + float freq_scale; + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + + [encoder setComputePipelineState:ctx->pipeline_rope]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; + [encoder setBytes:&mode length:sizeof( int) atIndex:20]; + [encoder setBytes:&freq_base length:sizeof(float) atIndex:21]; + [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + } break; + case GGML_OP_DUP: + case GGML_OP_CPY: + case GGML_OP_CONT: + { + const int nth = 32; + + switch (src0t) { + case GGML_TYPE_F32: + { + switch (dstt) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break; + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break; + default: GGML_ASSERT(false && "not implemented"); + }; + } break; + case GGML_TYPE_F16: + { + switch (dstt) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break; + case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break; + default: GGML_ASSERT(false && "not implemented"); + }; + } break; + default: GGML_ASSERT(false && "not implemented"); + } + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + default: + { + metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); + GGML_ASSERT(false); + } + } + } + + if (encoder != nil) { + [encoder endEncoding]; + encoder = nil; + } + + [command_buffer commit]; + }); + } + + // wait for all threads to finish + dispatch_barrier_sync(ctx->d_queue, ^{}); + + // check status of command buffers + // needed to detect if the device ran out-of-memory for example (#1881) + for (int i = 0; i < n_cb; i++) { + [ctx->command_buffers[i] waitUntilCompleted]; + + MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status]; + if (status != MTLCommandBufferStatusCompleted) { + metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status); + GGML_ASSERT(false); + } + } + + } +} diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal new file mode 100644 index 00000000..f45b1490 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal @@ -0,0 +1,2254 @@ +#include + +using namespace metal; + +#define MAX(x, y) ((x) > (y) ? (x) : (y)) + +#define QK4_0 32 +#define QR4_0 2 +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; + +#define QK4_1 32 +typedef struct { + half d; // delta + half m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; + +#define QK8_0 32 +typedef struct { + half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; + +kernel void kernel_add( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] + src1[tpig]; +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant int64_t & nb, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] + src1[tpig % nb]; +} + +kernel void kernel_mul( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src1[tpig]; +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_mul_row( + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant int64_t & nb, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src1[tpig % nb]; +} + +kernel void kernel_scale( + device const float4 * src0, + device float4 * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_silu( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_relu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +constant float GELU_COEF_A = 0.044715f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +kernel void kernel_gelu( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! + // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_soft_max( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + // parallel max + float lmax = psrc0[tpitg[0]]; + for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) { + lmax = MAX(lmax, psrc0[i00]); + } + const float max = simd_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { + const float exp_psrc0 = exp(psrc0[i00] - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // whish to compute it twice. + pdst[i00] = exp_psrc0; + } + + const float sum = simd_sum(lsum); + + for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { + pdst[i00] /= sum; + } +} + +kernel void kernel_soft_max_4( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + // parallel max + float4 lmax4 = psrc4[tpitg[0]]; + for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) { + lmax4 = fmax(lmax4, psrc4[i00]); + } + float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + const float max = simd_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { + const float4 exp_psrc4 = exp(psrc4[i00] - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + const float sum = simd_sum(lsum); + + for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { + pdst4[i00] /= sum; + } +} + +kernel void kernel_diag_mask_inf( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + const int64_t i02 = tpig[2]; + const int64_t i01 = tpig[1]; + const int64_t i00 = tpig[0]; + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + device const float4 * src0, + device float4 * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + + const int64_t i = 2*tpig[0]; + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + const int64_t i01 = i4/(ne00); i4 -= i01*ne00; + const int64_t i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > n_past + i01) { + dst[i][k] = -INFINITY; + } + } +} + +kernel void kernel_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * sum [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); + // MEAN + // parallel sum + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + sum[tpitg] += x[i00]; + } + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float mean = sum[0] / ne00; + + // recenter and VARIANCE + threadgroup_barrier(mem_flags::mem_threadgroup); + device float * y = dst + tgpig*ne00; + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = x[i00] - mean; + sum[tpitg] += y[i00] * y[i00]; + } + + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float variance = sum[0] / ne00; + + const float scale = 1.0f/sqrt(variance + eps); + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = y[i00] * scale; + } +} + +kernel void kernel_rms_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * sum [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + device const float * x_scalar = (device const float *) x; + float4 sumf=0; + float all_sum=0; + + // parallel sum + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; + all_sum = simd_sum(all_sum); + if (tiisg == 0) { + sum[sgitg] = all_sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast, simd group number is ntg / 32 + for (uint i = ntg / 32 / 2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + } + if (tpitg == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];} + sum[0] /= ne00; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float mean = sum[0]; + const float scale = 1.0f/sqrt(mean + eps); + + device float4 * y = (device float4 *) (dst + tgpig*ne00); + device float * y_scalar = (device float *) y; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + y[i00] = x[i00] * scale; + } + if (tpitg == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;} + } +} + +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc[0] + acc[1]); +} + +// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float2 acc = 0.f; + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + +// putting them in the kernel cause a significant performance penalty +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 +//Note: This is a template, but strictly speaking it only applies to +// quantizations where the block size is 32. It also does not +// giard against the number of rows not being divisible by +// N_DST, so this is another explicit assumption of the implementation. +template +void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, + uint3 tgpig, uint tiisg, uint sgitg) { + const int nb = ne00/QK4_0; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * nsg + sgitg) * nr; + const uint offset0 = first_row * nb + im/gqa*(nb*ne0); + device const block_q_type * x = (device const block_q_type *) src0 + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + float yl[16]; // src1 vector cache + float sumf[nr]={0.f}; + + const int ix = tiisg/2; + const int il = 8*(tiisg%2); + + device const float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { + float sumy = 0; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; + } + + for (int row = 0; row < nr; row++) { + sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); + } + + yb += QK4_0 * 16; + } + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + } + } +} + +kernel void kernel_mul_mat_q4_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mat_q4_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + +#define NB_Q8_0 8 + +kernel void kernel_mul_mat_q8_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const int nr = N_DST; + const int nsg = N_SIMDGROUP; + const int nw = N_SIMDWIDTH; + + const int nb = ne00/QK8_0; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * nsg + sgitg) * nr; + const uint offset0 = first_row * nb + im/gqa*(nb*ne0); + device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[NB_Q8_0]; + float sumf[nr]={0.f}; + + const int ix = tiisg/4; + const int il = tiisg%4; + + device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; + + // each thread in a SIMD group deals with NB_Q8_0 quants at a time + for (int ib = ix; ib < nb; ib += nw/4) { + for (int i = 0; i < NB_Q8_0; ++i) { + yl[i] = yb[i]; + } + + for (int row = 0; row < nr; row++) { + device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + float sumq = 0.f; + for (int iq = 0; iq < NB_Q8_0; ++iq) { + sumq += qs[iq] * yl[iq]; + } + sumf[row] += sumq*x[ib+row*nb].d; + } + + yb += NB_Q8_0 * nw; + } + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + } + } +} + +kernel void kernel_mul_mat_f16_f32_1row( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + if (ne00 < 128) { + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + device const half4 * x4 = (device const half4 *) x; + device const float4 * y4 = (device const float4 *) y; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + +} + +#define N_F16_F32 4 + +kernel void kernel_mul_mat_f16_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F16_F32; + const int64_t im = tgpig.z; + + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const half4 * x4 = (device const half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +// Assumes row size (ne00) is a multiple of 4 +kernel void kernel_mul_mat_f16_f32_l4( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int nrows = ne11; + const int64_t r0 = tgpig.x; + const int64_t im = tgpig.z; + + device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + + for (int r1 = 0; r1 < nrows; ++r1) { + device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +kernel void kernel_alibi_f32( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & m0, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + float m_k = pow(m0, i2 + 1); + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1); + } +} + +kernel void kernel_rope( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant float & freq_base, + constant float & freq_scale, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int64_t i3 = tgpig[2]; + const int64_t i2 = tgpig[1]; + const int64_t i1 = tgpig[0]; + + const bool is_neox = mode & 2; + + const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + + const float theta_0 = freq_scale * (float)p; + const float inv_ndims = -1.f/n_dims; + + if (!is_neox) { + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + + const float theta = theta_0 * pow(freq_base, inv_ndims*i0); + const float cos_theta = cos(theta); + const float sin_theta = sin(theta); + + device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } + } else { + for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { + for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { + + const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib); + const float cos_theta = cos(theta); + const float sin_theta = sin(theta); + + const int64_t i0 = ib*n_dims + ic/2; + + device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } + } +} + +kernel void kernel_cpy_f16_f16( + device const half * src0, + device half * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f16( + device const float * src0, + device half * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f32( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +//============================================ k-quants ====================================================== + +#ifndef QK_K +#define QK_K 256 +#else +static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); +#endif + +#if QK_K == 256 +#define K_SCALE_SIZE 12 +#else +#define K_SCALE_SIZE 4 +#endif + +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins +} block_q2_K; +// 84 bytes / block + +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits +#if QK_K == 64 + uint8_t scales[2]; +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale +} block_q3_K; + +#if QK_K == 64 +typedef struct { + half d[2]; // super-block scales/mins + uint8_t scales[2]; + uint8_t qs[QK_K/2]; // 4-bit quants +} block_q4_K; +#else +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +#endif + +#if QK_K == 64 +typedef struct { + half d; // super-block scales/mins + int8_t scales[QK_K/16]; // 8-bit block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +#else +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +// 176 bytes / block +#endif + +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +} block_q6_K; +// 210 bytes / block + +static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { + uchar4 r; + if (j < 4) { + r[0] = q[j+0] & 63; + r[2] = q[j+1] & 63; + r[1] = q[j+4] & 63; + r[3] = q[j+5] & 63; + } else { + r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4); + r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4); + } + return r; +} + +//====================================== dot products ========================= + +kernel void kernel_mul_mat_q2_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int r2 = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q2_K) * nb; + +#if QK_K == 256 + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int im = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 + + device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir; + + for (int ib = ix; ib < nb; ib += 4) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; + } + + device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + float dmin = dh[1] * 1.f/16.f; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); + + qs += step/2; + sc += step; + dh += step/2; + } + + y4 += 4 * QK_K; + } +#else + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0...1 + + device const float * y4 = y + ix * QK_K + 8 * it; + + for (int ib = ix; ib < nb; ib += 16) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; + } + + device const uint8_t * sc = (device const uint8_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); + + qs += step/2; + sc += step; + dh += step/2; + } + + y4 += 16 * QK_K; + } +#endif + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; + } + } +} + +#if QK_K == 256 +kernel void kernel_mul_mat_q3_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t r2 = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + + float yl[32]; + + const uint16_t kmask1 = 0x3030; + const uint16_t kmask2 = 0x0f0f; + + const int tid = tiisg/4; + const int ix = tiisg%4; + const int ip = tid/4; // 0 or 1 + const int il = 2*((tid%4)/2); // 0 or 2 + const int ir = tid%2; + const int n = 8; + const int l0 = n*ir; + + // One would think that the Metal compiler would figure out that ip and il can only have + // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it + // with these two tales. + // + // Possible masks for the high bit + const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 + {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 + {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 + {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 + + // Possible masks for the low 2 bits + const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; + + const ushort4 hm = mm[2*ip + il/2]; + + const int shift = 2*il; + const float v1 = il == 0 ? 4.f : 64.f; + const float v2 = 4.f * v1; + + const uint16_t s_shift1 = 4*ip; + const uint16_t s_shift2 = s_shift1 + il; + + const int q_offset = 32*ip + l0; + const int y_offset = 128*ip + 32*il + l0; + + const int step = sizeof(block_q3_K) * nb / 2; + + device const float * y1 = yy + ix*QK_K + y_offset; + + uint32_t scales32, aux32; + thread uint16_t * scales16 = (thread uint16_t *)&scales32; + thread const int8_t * scales = (thread const int8_t *)&scales32; + + float sumf1[2] = {0.f}; + float sumf2[2] = {0.f}; + for (int i = ix; i < nb; i += 4) { + + for (int l = 0; l < 8; ++l) { + yl[l+ 0] = y1[l+ 0]; + yl[l+ 8] = y1[l+16]; + yl[l+16] = y1[l+32]; + yl[l+24] = y1[l+48]; + } + + device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); + device const uint16_t * a = (device const uint16_t *)(x[i].scales); + device const half * dh = &x[i].d; + + for (int row = 0; row < 2; ++row) { + + const float d_all = (float)dh[0]; + + scales16[0] = a[4]; + scales16[1] = a[5]; + aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; + scales16[0] = a[il+0]; + scales16[1] = a[il+1]; + scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; + + float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; + for (int l = 0; l < n; l += 2) { + const int32_t qs = q[l/2]; + s1 += yl[l+0] * (qs & qm[il/2][0]); + s2 += yl[l+1] * (qs & qm[il/2][1]); + s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); + s4 += yl[l+16] * (qs & qm[il/2][2]); + s5 += yl[l+17] * (qs & qm[il/2][3]); + s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); + } + float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[0] - 32); + sumf2[row] += d2 * (scales[2] - 32); + + s1 = s2 = s3 = s4 = s5 = s6 = 0; + for (int l = 0; l < n; l += 2) { + const int32_t qs = q[l/2+8]; + s1 += yl[l+8] * (qs & qm[il/2][0]); + s2 += yl[l+9] * (qs & qm[il/2][1]); + s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); + s4 += yl[l+24] * (qs & qm[il/2][2]); + s5 += yl[l+25] * (qs & qm[il/2][3]); + s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]); + } + d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[1] - 32); + sumf2[row] += d2 * (scales[3] - 32); + + q += step; + h += step; + a += step; + dh += step; + + } + + y1 += 4 * QK_K; + + } + + for (int row = 0; row < 2; ++row) { + const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); + sumf1[row] = simd_sum(sumf); + } + if (tiisg == 0) { + for (int row = 0; row < 2; ++row) { + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row]; + } + } + +} +#else +kernel void kernel_mul_mat_q3_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t r2 = tgpig.z; + + const int row = 2 * r0 + sgitg; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + const int ix = tiisg/4; + const int il = 4 * (tiisg%4);// 0, 4, 8, 12 + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + float2 sum = {0.f, 0.f}; + + for (int i = ix; i < nb; i += 8) { + + const float d_all = (float)(x[i].d); + + device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); + device const uint16_t * s = (device const uint16_t *)(x[i].scales); + device const float * y = yy + i * QK_K + il; + + const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); + const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; + const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; + const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; + + for (int l = 0; l < 4; l += 2) { + const uint16_t hm = h[l/2] >> im; + sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) + + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) + + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) + + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); + sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) + + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) + + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) + + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536)); + } + + } + const float sumf = sum[0] + sum[1] * 1.f/256.f; + + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + r2*ne0*ne1 + row] = tot; + } + +} +#endif + +#if QK_K == 256 +kernel void kernel_mul_mat_q4_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int im = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int r2 = tgpig.z; + //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = r0 * N_DST; + const int ib_row = first_row * nb; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + float yl[16]; + float yh[16]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q4_K) * nb / 2; + + device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + for (int ib = ix; ib < nb; ib += 4) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; + yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; + } + + device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im; + device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + device const uint16_t * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); + acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); + acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); + acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); + acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); + acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); + acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); + acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + sc += step; + dh += step; + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; + } + } +} +#else +kernel void kernel_mul_mat_q4_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int ix = tiisg/4; // 0...7 + const int it = tiisg%4; // 0...3 + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int r2 = tgpig.z; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + float yl[8]; + float yh[8]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q4_K) * nb / 2; + + device const float * y4 = y + ix * QK_K + 8 * it; + + uint16_t sc16[4]; + + for (int ib = ix; ib < nb; ib += 8) { + + float2 sumy = {0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i] = y4[i+ 0]; sumy[0] += yl[i]; + yh[i] = y4[i+32]; sumy[1] += yh[i]; + } + + device const uint16_t * sc = (device const uint16_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + sc16[0] = sc[0] & 0x000f; + sc16[1] = sc[0] & 0x0f00; + sc16[2] = sc[0] & 0x00f0; + sc16[3] = sc[0] & 0xf000; + + float2 acc1 = {0.f, 0.f}; + float2 acc2 = {0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); + acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); + acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); + acc2[1] += yh[i+1] * (qs[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + + (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - + dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); + + qs += step; + sc += step; + dh += step; + } + + y4 += 8 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum; + } + } +} +#endif + +kernel void kernel_mul_mat_q5_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int r2 = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + + float sumf[2]={0.f}; + + const int step = sizeof(block_q5_K) * nb; + +#if QK_K == 256 +# + float yl[16], yh[16]; + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = tiisg/4; + const int ix = tiisg%4; + const int im = tid/4; + const int ir = tid%4; + const int n = 8; + + const int l0 = n*ir; + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + const uint8_t hm1 = 1u << (2*im); + const uint8_t hm2 = hm1 << 1; + const uint8_t hm3 = hm1 << 4; + const uint8_t hm4 = hm2 << 4; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + device const float * y1 = yy + ix*QK_K + y_offset; + + for (int i = ix; i < nb; i += 4) { + + device const uint8_t * q1 = x[i].qs + q_offset; + device const uint8_t * qh = x[i].qh + l0; + device const half * dh = &x[i].d; + device const uint16_t * a = (device const uint16_t *)x[i].scales + im; + + device const float * y2 = y1 + 128; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; + yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; + yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; + yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; + } + + for (int row = 0; row < 2; ++row) { + + device const uint8_t * q2 = q1 + 64; + + sc16[0] = a[0] & kmask1; + sc16[1] = a[2] & kmask1; + sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); + sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); + + float4 acc1 = {0.f}; + float4 acc2 = {0.f}; + for (int l = 0; l < n; ++l) { + uint8_t h = qh[l]; + acc1[0] += yl[l+0] * (q1[l] & 0x0F); + acc1[1] += yl[l+8] * (q1[l] & 0xF0); + acc1[2] += yh[l+0] * (q2[l] & 0x0F); + acc1[3] += yh[l+8] * (q2[l] & 0xF0); + acc2[0] += h & hm1 ? yl[l+0] : 0.f; + acc2[1] += h & hm2 ? yl[l+8] : 0.f; + acc2[2] += h & hm3 ? yh[l+0] : 0.f; + acc2[3] += h & hm4 ? yh[l+8] : 0.f; + } + const float dall = dh[0]; + const float dmin = dh[1]; + sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + + sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + + sc8[4] * (acc1[2] + 16.f*acc2[2]) + + sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + qh += step; + dh += step/2; + a += step/2; + + } + + y1 += 4 * QK_K; + + } +#else + float yl[8], yh[8]; + + const int il = 4 * (tiisg/8); // 0, 4, 8, 12 + const int ix = tiisg%8; + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + device const float * y = yy + ix*QK_K + il; + + for (int i = ix; i < nb; i += 8) { + + for (int l = 0; l < 4; ++l) { + yl[l+0] = y[l+ 0]; + yl[l+4] = y[l+16]; + yh[l+0] = y[l+32]; + yh[l+4] = y[l+48]; + } + + device const half * dh = &x[i].d; + device const uint8_t * q = x[i].qs + il; + device const uint8_t * h = x[i].qh + in; + device const int8_t * s = x[i].scales; + + for (int row = 0; row < 2; ++row) { + + const float d = dh[0]; + + float2 acc = {0.f, 0.f}; + for (int l = 0; l < 4; ++l) { + const uint8_t hl = h[l] >> im; + acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) + + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); + acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) + + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256)); + } + sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); + + q += step; + h += step; + s += step; + dh += step/2; + + } + + y += 8 * QK_K; + } +#endif + + for (int row = 0; row < 2; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; + } + } + +} + +kernel void kernel_mul_mat_q6_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const uint8_t kmask1 = 0x03; + const uint8_t kmask2 = 0x0C; + const uint8_t kmask3 = 0x30; + const uint8_t kmask4 = 0xC0; + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int r2 = tgpig.z; + + const int row = 2 * r0 + sgitg; + const uint offset0 = r2/gqa*(nb*ne0); + device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; + device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + + float sumf = 0; + +#if QK_K == 256 + const int tid = tiisg/2; + const int ix = tiisg%2; + const int ip = tid/8; // 0 or 1 + const int il = tid%8; + const int n = 4; + const int l0 = n*il; + const int is = 8*ip + l0/16; + + const int y_offset = 128*ip + l0; + const int q_offset_l = 64*ip + l0; + const int q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += 2) { + + device const uint8_t * q1 = x[i].ql + q_offset_l; + device const uint8_t * q2 = q1 + 32; + device const uint8_t * qh = x[i].qh + q_offset_h; + device const int8_t * sc = x[i].scales + is; + + device const float * y = yy + i * QK_K + y_offset; + + const float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + + sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + + } + +#else + const int ix = tiisg/4; + const int il = 4*(tiisg%4); + + for (int i = ix; i < nb; i += 8) { + device const float * y = yy + i * QK_K + il; + device const uint8_t * ql = x[i].ql + il; + device const uint8_t * qh = x[i].qh + il; + device const int8_t * s = x[i].scales; + + const float d = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 4; ++l) { + sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32); + sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); + } + +#endif + + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + r2*ne0*ne1 + row] = tot; + } +} + +//============================= templates and their specializations ============================= + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + half4x4 temp = *(((device half4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + +template +void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { + + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; + } + +} + +template +void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { + + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i=0;i<8;i++) { + reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; + reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; + } +} + +template +void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const half d = xb->d; + + for (int i=0;i<16;i++) { + reg[i/4][i%4] = (qs[i + 16*il] * d); + } +} + +template +void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { + const half d = xb->d; + const half min = xb->dmin; + device const uint8_t * q = (device const uint8_t *)xb->qs; + half dl, ml; + uint8_t sc = xb->scales[il]; + +#if QK_K == 256 + q = q + 32*(il/8) + 16*(il&1); + il = (il/2)%4; +#endif + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * q = (device const uint8_t *)xb->qs; + device const uint8_t * h = (device const uint8_t *)xb->hmask; + device const int8_t * scales = (device const int8_t *)xb->scales; + +#if QK_K == 256 + q = q + 32 * (il/8) + 16 * (il&1); + h = h + 16 * (il&1); + uint8_t m = 1 << (il/2); + uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ + ((il/4)>0 ? 12 : 3); + uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; + uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; + int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) + : (scale_2&kmask2) | ((scale_1&kmask1) << 4); + half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h); + const half ml = 4.h * dl; + + il = (il/2) & 3; + const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl *= coef; + + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); + } + +#else + float kcoef = il&1 ? 1.f/16.f : 1.f; + uint16_t kmask = il&1 ? 0xF0 : 0x0F; + float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8); + float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + uint8_t m = 1<<(il*2); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef)); + } +#endif +} + +static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { + return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} + : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; +} + +template +void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { + device const uchar * q = xb->qs; + +#if QK_K == 256 + short is = (il/4) * 2; + q = q + (il/4) * 32 + 16 * (il&1); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const half d = il < 2 ? xb->d : xb->d / 16.h; + const half min = xb->dmin; + const half dl = d * sc[0]; + const half ml = min * sc[1]; +#else + q = q + 16 * (il&1); + device const uint8_t * s = xb->scales; + device const half2 * dh = (device const half2 *)xb->d; + const float2 d = (float2)dh[0]; + const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; + const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4); +#endif + const ushort mask = il<2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } + +} + +template +void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { + device const uint8_t * q = xb->qs; + device const uint8_t * qh = xb->qh; + +#if QK_K == 256 + short is = (il/4) * 2; + q = q + 32 * (il/4) + 16 * (il&1); + qh = qh + 16 * (il&1); + uint8_t ul = 1 << (il/2); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const half d = il < 2 ? xb->d : xb->d / 16.h; + const half min = xb->dmin; + const half dl = d * sc[0]; + const half ml = min * sc[1]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + const half qh_val = il<2 ? 16.h : 256.h; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; + } +#else + q = q + 16 * (il&1); + device const int8_t * s = xb->scales; + const float dl = xb->d * s[il]; + uint8_t m = 1<<(il*2); + const float coef = il<2 ? 1.f : 1.f/16.f; + const ushort mask = il<2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef)); + } +#endif +} + +template +void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * ql = (device const uint8_t *)xb->ql; + device const uint8_t * qh = (device const uint8_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + +#if QK_K == 256 + ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); + qh = qh + 32*(il/8) + 16*(il&1); + half sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; +#else + ql = ql + 16 * (il&1); + half sc = scales[il]; +#endif + const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; + const half coef = il>1 ? 1.f/16.h : 1.h; + const half ml = d_all * sc * 32.h; + const half dl = d_all * sc * coef; + for (int i = 0; i < 16; ++i) { + const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) + : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); + reg[i/4][i%4] = dl * q - ml; + } +} + +template +kernel void kernel_get_rows( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tptg[[threads_per_threadgroup]]) { + const int i = tgpig; + const int r = ((device int32_t *) src1)[i]; + + for (int ind = tiitg; ind < ne00/16; ind += tptg) { + float4x4 temp; + dequantize_func( + ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp; + } +} + +#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A +#define BLOCK_SIZE_K 32 +#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A +#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B +#define THREAD_PER_BLOCK 128 +#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers +#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers +#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 +#define SG_MAT_ROW 8 + +// each block_q contains 16*nl weights +template +kernel void kernel_mul_mm(device const uchar * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = ((threadgroup half *)shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; + const uint im = tgpig.z; + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_half8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 c_res[8]; + for (int i = 0; i < 8; i++){ + c_res[i] = make_filled_simdgroup_matrix(0.f); + } + + short il = (tiitg % THREAD_PER_ROW); + uint offset0 = im/gqa*nb02; ushort offset1 = il/nl; + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ + + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1; + + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + //load data and store to threadgroup memory + half4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); + #pragma unroll(16) + for (int i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + } + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \ + = *((device float2x4 *)y); + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + //load matrices from threadgroup memory and conduct outer products + threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + #pragma unroll(4) + for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (int i = 0; i < 4; i++) { + simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) + for (int i = 0; i < 2; i++) { + simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + } + + lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + #pragma unroll(8) + for (int i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + } + } + } + + if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { + device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \ + + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float *temp_str = ((threadgroup float *)shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; + if (sgitg==0) { + for (int i = 0; i < n_rows; i++) { + for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) { + *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + } + } + } + } +} + +#if QK_K == 256 +#define QK_NL 16 +#else +#define QK_NL 4 +#endif + +typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ + constant uint64_t &, constant uint64_t &, uint, uint, uint); + +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; + +typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\ + constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \ + constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint); + +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 45748cca..92ee74d7 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -4,6 +4,8 @@ #include "wasinnenv.h" #include "wasinnmodule.h" +#include + namespace WasmEdge { namespace Host { From 806d974af1a4d993d26089569b25df2366dac395 Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 22 Sep 2023 20:51:06 +0800 Subject: [PATCH 158/623] [CI] Enable METAL Signed-off-by: hydai --- plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt index f18b264a..7e8b9aea 100644 --- a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt +++ b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt @@ -2,9 +2,11 @@ # Option list # -# Get errors when enabling METAL API -# Disable it currently -set(LLAMA_METAL_DEFAULT OFF) +if (APPLE) + set(LLAMA_METAL_DEFAULT ON) +else() + set(LLAMA_METAL_DEFAULT OFF) +endif() # general option(LLAMA_STATIC "llama: static link libraries" OFF) @@ -49,7 +51,7 @@ set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT}) -option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF) +option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" ON) option(LLAMA_MPI "llama: use MPI" OFF) option(LLAMA_K_QUANTS "llama: use k-quants" ON) option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) From e67fba69549369fe7ad8875706df3fa8839fa2f5 Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 2 Oct 2023 15:37:57 +0800 Subject: [PATCH 159/623] [Misc] Add WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL for controlling the LLAMA_METAL flag Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 14 ++++++++++++-- plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt | 2 +- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 8bb5e389..6e246d39 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -2,16 +2,26 @@ # SPDX-FileCopyrightText: 2019-2022 Second State INC # llama.cpp options +# Disable warnings and debug messages set(LLAMA_ALL_WARNINGS OFF) +set(LLAMA_METAL_NDEBUG ON) + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS) - message(INFO "WASI-NN GGML LLAMA backend: Enable LLAMA_BLAS") + message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_BLAS") # Default use OpenBLAS set(LLAMA_BLAS ON) set(LLAMA_BLAS_VENDOR "OpenBLAS") else() - message(INFO "WASI-NN GGML LLAMA backend: Disable LLAMA_BLAS") + message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_BLAS") set(LLAMA_BLAS OFF) endif() +if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) + message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_METAL") + set(LLAMA_METAL ON) +else() + message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_METAL") + set(LLAMA_METAL OFF) +endif() add_subdirectory(thirdparty) diff --git a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt index 7e8b9aea..b21cc977 100644 --- a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt +++ b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt @@ -51,7 +51,7 @@ set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT}) -option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" ON) +option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF) option(LLAMA_MPI "llama: use MPI" OFF) option(LLAMA_K_QUANTS "llama: use k-quants" ON) option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) From 8df3086c93d9bb283dcab22c20d4a49c17b5f868 Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 2 Oct 2023 17:00:49 +0800 Subject: [PATCH 160/623] [WASI-NN] Metal related options should only apply on macOS Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 6e246d39..ef6ad97b 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -15,6 +15,11 @@ else() message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_BLAS") set(LLAMA_BLAS OFF) endif() + +if(NOT APPLE) + set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL OFF) +endif() + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_METAL") set(LLAMA_METAL ON) From 8471d7a4df8859ba216a764c4d076e354252260a Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 3 Oct 2023 17:09:24 +0800 Subject: [PATCH 161/623] [WASI-NN] Upgrade ggml from b1273 to b1309 (#2941) * [WASI-NN] Upgrade ggml from b1273 to b1309 Signed-off-by: dm4 * [WASI-NN] Update ggml backend Signed-off-by: dm4 --------- Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 73 +- .../wasi_nn/thirdparty/ggml/CMakeLists.txt | 76 +- plugins/wasi_nn/thirdparty/ggml/common.cpp | 199 +- plugins/wasi_nn/thirdparty/ggml/common.h | 27 +- plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c | 10 +- plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h | 1 + plugins/wasi_nn/thirdparty/ggml/ggml-metal.h | 4 + plugins/wasi_nn/thirdparty/ggml/ggml-metal.m | 294 +- .../wasi_nn/thirdparty/ggml/ggml-metal.metal | 346 ++- plugins/wasi_nn/thirdparty/ggml/ggml.c | 2443 +++++++++++------ plugins/wasi_nn/thirdparty/ggml/ggml.h | 151 +- plugins/wasi_nn/thirdparty/ggml/llama.cpp | 1639 +++++++---- plugins/wasi_nn/thirdparty/ggml/llama.h | 428 ++- plugins/wasi_nn/thirdparty/ggml/log.h | 74 +- 14 files changed, 3840 insertions(+), 1925 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 03387726..fa241c33 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -83,8 +83,9 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, spdlog::error("[WASI-NN] GGML backend: Error: unable to init context."sv); return ErrNo::InvalidArgument; } + llama_model_params ModelParams = llama_model_default_params(); GraphRef.LlamaModel = - llama_load_model_from_file(ModelFilePath.c_str(), ContextParams); + llama_load_model_from_file(ModelFilePath.c_str(), ModelParams); if (GraphRef.LlamaModel == nullptr) { spdlog::error("[WASI-NN] GGML backend: Error: unable to init model."sv); Env.NNGraph.pop_back(); @@ -175,7 +176,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { const int MaxContextSize = llama_n_ctx(GraphRef.LlamaContext); // NPredict is the number of tokens to predict. Same as -n, --n-predict in // llama.cpp. - int NPredict = std::numeric_limits::max(); + int NPredict = MaxContextSize; const char *LlamaNPredictEnv = std::getenv("LLAMA_N_PREDICT"); if (LlamaNPredictEnv != nullptr) { try { @@ -195,22 +196,31 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { spdlog::info("[WASI-NN] GGML backend: set n_predict to {}"sv, NPredict); } } - while (llama_get_kv_cache_token_count(GraphRef.LlamaContext) < - MaxContextSize && - llama_get_kv_cache_token_count(GraphRef.LlamaContext) < NPredict) { - if (llama_eval(GraphRef.LlamaContext, CxtRef.LlamaInputs.data(), - int(CxtRef.LlamaInputs.size()), - llama_get_kv_cache_token_count(GraphRef.LlamaContext), - get_num_physical_cores())) { - spdlog::error("[WASI-NN] GGML backend: Llama failed to eval."sv); - return ErrNo::InvalidArgument; - } - CxtRef.LlamaInputs.clear(); - // Select the best prediction. - llama_token NewTokenId = 0; - auto Logits = llama_get_logits(GraphRef.LlamaContext); - auto NVocab = llama_n_vocab(GraphRef.LlamaContext); + // Evaluate the initial prompt. + llama_batch LlamaBatch = llama_batch_init(NPredict, 0); + LlamaBatch.n_tokens = CxtRef.LlamaInputs.size(); + for (int32_t I = 0; I < LlamaBatch.n_tokens; I++) { + LlamaBatch.token[I] = CxtRef.LlamaInputs[I]; + LlamaBatch.pos[I] = I; + LlamaBatch.seq_id[I] = 0; + LlamaBatch.logits[I] = false; + } + + // llama_decode will output logits only for the last token of the prompt + LlamaBatch.logits[LlamaBatch.n_tokens - 1] = true; + if (llama_decode(GraphRef.LlamaContext, LlamaBatch) != 0) { + spdlog::info("[WASI-NN] GGML backend: llama_decode() failed"sv); + return ErrNo::RuntimeError; + } + + int NCur = LlamaBatch.n_tokens; + while (NCur < MaxContextSize && NCur < NPredict) { + // Sample the next token + auto NVocab = llama_n_vocab(GraphRef.LlamaModel); + auto *Logits = + llama_get_logits_ith(GraphRef.LlamaContext, LlamaBatch.n_tokens - 1); + std::vector Candidates; Candidates.reserve(NVocab); for (llama_token TokenId = 0; TokenId < NVocab; TokenId++) { @@ -218,9 +228,14 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } llama_token_data_array CandidatesP = {Candidates.data(), Candidates.size(), false}; - NewTokenId = llama_sample_token_greedy(GraphRef.LlamaContext, &CandidatesP); - if (NewTokenId == llama_token_eos(GraphRef.LlamaContext)) { + // Sample the most likely token + const llama_token NewTokenId = + llama_sample_token_greedy(GraphRef.LlamaContext, &CandidatesP); + + // Is it an end of stream? + if (NewTokenId == llama_token_eos(GraphRef.LlamaContext) || + NCur == MaxContextSize || NCur == NPredict) { break; } @@ -228,13 +243,25 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { CxtRef.LlamaOutputs += llama_token_to_piece(GraphRef.LlamaContext, NewTokenId); - // Push this new token for next evaluation. - CxtRef.LlamaInputs.push_back(NewTokenId); + // Prepare the next batch + LlamaBatch.n_tokens = 0; + + // Push this new token for next evaluation + LlamaBatch.token[LlamaBatch.n_tokens] = NewTokenId; + LlamaBatch.pos[LlamaBatch.n_tokens] = NCur; + LlamaBatch.seq_id[LlamaBatch.n_tokens] = 0; + LlamaBatch.logits[LlamaBatch.n_tokens] = true; + LlamaBatch.n_tokens += 1; + NCur += 1; + + // Evaluate the current batch with the transformer model + if (llama_decode(GraphRef.LlamaContext, LlamaBatch)) { + spdlog::error("[WASI-NN] GGML backend: failed to llama_decode"sv); + return ErrNo::RuntimeError; + } } if (LlamaLogEnv != nullptr) { - spdlog::info("[WASI-NN] GGML backend: llama_get_kv_cache_token_count {}"sv, - llama_get_kv_cache_token_count(GraphRef.LlamaContext)); llama_print_timings(GraphRef.LlamaContext); } diff --git a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt index b21cc977..da9542b0 100644 --- a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt +++ b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt @@ -262,8 +262,9 @@ if (LLAMA_MPI) set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h) add_compile_definitions(GGML_USE_MPI) add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS}) - set(cxx_flags ${cxx_flags} -Wno-cast-qual) - set(c_flags ${c_flags} -Wno-cast-qual) + if (NOT MSVC) + add_compile_options(-Wno-cast-qual) + endif() set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES}) set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS}) # Even if you're only using the C header, C++ programs may bring in MPI @@ -333,42 +334,55 @@ endif() if (LLAMA_ALL_WARNINGS) if (NOT MSVC) - set(c_flags - -Wall - -Wextra - -Wpedantic - -Wcast-qual - -Wdouble-promotion - -Wshadow - -Wstrict-prototypes - -Wpointer-arith - -Wmissing-prototypes - -Werror=implicit-int - -Wno-unused-function - ) - set(cxx_flags - -Wall - -Wextra - -Wpedantic - -Wcast-qual - -Wmissing-declarations - -Wno-unused-function - -Wno-multichar - ) - if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - # g++ only - set(cxx_flags ${cxx_flags} -Wno-format-truncation -Wno-array-bounds) + set(warning_flags -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function) + set(c_flags -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int + -Werror=implicit-function-declaration) + set(cxx_flags -Wmissing-declarations -Wmissing-noreturn) + set(host_cxx_flags "") + + if (CMAKE_C_COMPILER_ID MATCHES "Clang") + set(warning_flags ${warning_flags} -Wunreachable-code-break -Wunreachable-code-return) + set(host_cxx_flags ${host_cxx_flags} -Wmissing-prototypes -Wextra-semi) + + if ( + (CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 3.8.0) OR + (CMAKE_C_COMPILER_ID STREQUAL "AppleClang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 7.3.0) + ) + set(c_flags ${c_flags} -Wdouble-promotion) + endif() + elseif (CMAKE_C_COMPILER_ID STREQUAL "GNU") + set(c_flags ${c_flags} -Wdouble-promotion) + set(host_cxx_flags ${host_cxx_flags} -Wno-array-bounds) + + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 7.1.0) + set(host_cxx_flags ${host_cxx_flags} -Wno-format-truncation) + endif() + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1.0) + set(host_cxx_flags ${host_cxx_flags} -Wextra-semi) + endif() endif() else() # todo : msvc endif() - add_compile_options( - "$<$:${c_flags}>" - "$<$:${cxx_flags}>" - ) + set(c_flags ${c_flags} ${warning_flags}) + set(cxx_flags ${cxx_flags} ${warning_flags}) + add_compile_options("$<$:${c_flags}>" + "$<$:${cxx_flags} ${host_cxx_flags}>") + +endif() +if (NOT MSVC) + set(cuda_flags -Wno-pedantic) endif() +set(cuda_flags ${cxx_flags} -use_fast_math ${cuda_flags}) + +list(JOIN host_cxx_flags " " cuda_host_flags) # pass host compiler flags as a single argument +if (NOT cuda_host_flags STREQUAL "") + set(cuda_flags ${cuda_flags} -Xcompiler ${cuda_host_flags}) +endif() + +add_compile_options("$<$:${cuda_flags}>") if (WIN32) add_compile_definitions(_CRT_SECURE_NO_WARNINGS) diff --git a/plugins/wasi_nn/thirdparty/ggml/common.cpp b/plugins/wasi_nn/thirdparty/ggml/common.cpp index 275c038d..91bece39 100644 --- a/plugins/wasi_nn/thirdparty/ggml/common.cpp +++ b/plugins/wasi_nn/thirdparty/ggml/common.cpp @@ -77,7 +77,7 @@ int32_t get_num_physical_cores() { return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; } -static void process_escapes(std::string& input) { +void process_escapes(std::string& input) { std::size_t input_len = input.length(); std::size_t output_idx = 0; @@ -128,6 +128,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { if (params.n_threads <= 0) { params.n_threads = std::thread::hardware_concurrency(); } + } else if (arg == "-tb" || arg == "--threads-batch") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads_batch = std::stoi(argv[i]); + if (params.n_threads_batch <= 0) { + params.n_threads_batch = std::thread::hardware_concurrency(); + } } else if (arg == "-p" || arg == "--prompt") { if (++i >= argc) { invalid_param = true; @@ -316,6 +325,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.n_chunks = std::stoi(argv[i]); + } else if (arg == "-np" || arg == "--parallel") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_parallel = std::stoi(argv[i]); + } else if (arg == "-ns" || arg == "--sequences") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_sequences = std::stoi(argv[i]); } else if (arg == "-m" || arg == "--model") { if (++i >= argc) { invalid_param = true; @@ -339,7 +360,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - params.lora_adapter = argv[i]; + params.lora_adapter.push_back({argv[i], 1.0f}); + params.use_mmap = false; + } else if (arg == "--lora-scaled") { + if (++i >= argc) { + invalid_param = true; + break; + } + const char * lora_adapter = argv[i]; + if (++i >= argc) { + invalid_param = true; + break; + } + params.lora_adapter.push_back({lora_adapter, std::stof(argv[i])}); params.use_mmap = false; } else if (arg == "--lora-base") { if (++i >= argc) { @@ -355,10 +388,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.interactive_first = true; } else if (arg == "-ins" || arg == "--instruct") { params.instruct = true; + } else if (arg == "--infill") { + params.infill = true; } else if (arg == "--multiline-input") { params.multiline_input = true; } else if (arg == "--simple-io") { params.simple_io = true; + } else if (arg == "-cb" || arg == "--cont-batching") { + params.cont_batching = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "--mlock") { @@ -424,19 +461,11 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.mul_mat_q = false; #else fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n"); -#endif // GGML_USE_CUBLAS - } else if (arg == "--low-vram" || arg == "-lv") { -#ifdef GGML_USE_CUBLAS - params.low_vram = true; -#else - fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n"); #endif // GGML_USE_CUBLAS } else if (arg == "--no-mmap") { params.use_mmap = false; } else if (arg == "--numa") { params.numa = true; - } else if (arg == "--export") { - params.export_cgraph = true; } else if (arg == "--verbose-prompt") { params.verbose_prompt = true; } else if (arg == "-r" || arg == "--reverse-prompt") { @@ -455,8 +484,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { if (params.logdir.back() != DIRECTORY_SEPARATOR) { params.logdir += DIRECTORY_SEPARATOR; } - } else if (arg == "--perplexity") { - params.perplexity = true; + } else if (arg == "--perplexity" || arg == "--all-logits") { + params.logits_all = true; } else if (arg == "--ppl-stride") { if (++i >= argc) { invalid_param = true; @@ -605,7 +634,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" (can be specified more than once for multiple prompts).\n"); printf(" --color colorise output to distinguish prompt and user input from generations\n"); printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); - printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads); + printf(" -tb N, --threads-batch N\n"); + printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n"); printf(" -p PROMPT, --prompt PROMPT\n"); printf(" prompt to start generation with (default: empty)\n"); printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); @@ -620,7 +651,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -f FNAME, --file FNAME\n"); printf(" prompt file to start generation.\n"); printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); - printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); + printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); @@ -654,12 +685,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); printf(" --temp N temperature (default: %.1f)\n", (double)params.temp); - printf(" --perplexity compute perplexity over each ctx window of the prompt\n"); + printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n"); printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft); printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); + printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); + printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); + printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); if (llama_mlock_supported()) { printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); } @@ -677,17 +711,16 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -ts SPLIT --tensor-split SPLIT\n"); printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n"); - printf(" -lv, --low-vram don't allocate VRAM scratch buffer\n"); #ifdef GGML_USE_CUBLAS printf(" -nommq, --no-mul-mat-q\n"); printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n"); printf(" Not recommended since this is both slower and uses more VRAM.\n"); #endif // GGML_USE_CUBLAS #endif - printf(" --export export the computation graph to 'llama.ggml'\n"); printf(" --verbose-prompt print prompt before generation\n"); fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); + printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n"); printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); printf(" -m FNAME, --model FNAME\n"); printf(" model path (default: %s)\n", params.model.c_str()); @@ -698,6 +731,18 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf("\n"); } +std::string get_system_info(const gpt_params & params) { + std::ostringstream os; + + os << "system_info: n_threads = " << params.n_threads; + if (params.n_threads_batch != -1) { + os << " (n_threads_batch = " << params.n_threads_batch << ")"; + } + os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info(); + + return os.str(); +} + std::string gpt_random_prompt(std::mt19937 & rng) { const int r = rng() % 10; switch (r) { @@ -711,60 +756,74 @@ std::string gpt_random_prompt(std::mt19937 & rng) { case 7: return "He"; case 8: return "She"; case 9: return "They"; - default: return "To"; } - return "The"; + GGML_UNREACHABLE(); } // // Model utils // -struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { - auto lparams = llama_context_default_params(); +struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) { + auto mparams = llama_model_default_params(); - lparams.n_ctx = params.n_ctx; - lparams.n_batch = params.n_batch; if (params.n_gpu_layers != -1) { - lparams.n_gpu_layers = params.n_gpu_layers; + mparams.n_gpu_layers = params.n_gpu_layers; } - lparams.main_gpu = params.main_gpu; - lparams.tensor_split = params.tensor_split; - lparams.low_vram = params.low_vram; - lparams.mul_mat_q = params.mul_mat_q; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - lparams.logits_all = params.perplexity; - lparams.embedding = params.embedding; - lparams.rope_freq_base = params.rope_freq_base; - lparams.rope_freq_scale = params.rope_freq_scale; - - return lparams; + mparams.main_gpu = params.main_gpu; + mparams.tensor_split = params.tensor_split; + mparams.use_mmap = params.use_mmap; + mparams.use_mlock = params.use_mlock; + + return mparams; +} + +struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { + auto cparams = llama_context_default_params(); + + cparams.n_ctx = params.n_ctx; + cparams.n_batch = params.n_batch; + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + cparams.mul_mat_q = params.mul_mat_q; + cparams.seed = params.seed; + cparams.f16_kv = params.memory_f16; + cparams.logits_all = params.logits_all; + cparams.embedding = params.embedding; + cparams.rope_freq_base = params.rope_freq_base; + cparams.rope_freq_scale = params.rope_freq_scale; + + return cparams; } std::tuple llama_init_from_gpt_params(gpt_params & params) { - auto lparams = llama_context_params_from_gpt_params(params); + auto mparams = llama_model_params_from_gpt_params(params); - llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams); + llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); return std::make_tuple(nullptr, nullptr); } - llama_context * lctx = llama_new_context_with_model(model, lparams); + auto cparams = llama_context_params_from_gpt_params(params); + + llama_context * lctx = llama_new_context_with_model(model, cparams); if (lctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); llama_free_model(model); return std::make_tuple(nullptr, nullptr); } - if (!params.lora_adapter.empty()) { + for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { + const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]); + float lora_scale = std::get<1>(params.lora_adapter[i]); int err = llama_model_apply_lora_from_file(model, - params.lora_adapter.c_str(), - params.lora_base.empty() ? NULL : params.lora_base.c_str(), + lora_adapter.c_str(), + lora_scale, + ((i > 0) || params.lora_base.empty()) + ? NULL + : params.lora_base.c_str(), params.n_threads); if (err != 0) { fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); @@ -781,8 +840,9 @@ std::tuple llama_init_from_gpt_par { LOG("warming up the model with an empty run\n"); - const std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; - llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads); + std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); + llama_kv_cache_tokens_rm(lctx, -1, -1); llama_reset_timings(lctx); } @@ -794,16 +854,23 @@ std::tuple llama_init_from_gpt_par // std::vector llama_tokenize( - struct llama_context * ctx, + const struct llama_context * ctx, + const std::string & text, + bool add_bos) { + return llama_tokenize(llama_get_model(ctx), text, add_bos); +} + +std::vector llama_tokenize( + const struct llama_model * model, const std::string & text, bool add_bos) { // upper limit for the number of tokens int n_tokens = text.length() + add_bos; std::vector result(n_tokens); - n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos); + n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos); + int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); @@ -813,10 +880,10 @@ std::vector llama_tokenize( std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) { std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size()); + const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_token_to_piece(ctx, token, result.data(), result.size()); + int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); @@ -871,7 +938,7 @@ llama_token llama_sample_token( std::vector & candidates, int idx) { const int n_ctx = llama_n_ctx(ctx); - const int n_vocab = llama_n_vocab(ctx); + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const float temp = params.temp; const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; @@ -889,7 +956,7 @@ llama_token llama_sample_token( llama_token id = 0; - float * logits = llama_get_logits(ctx) + idx * n_vocab; + float * logits = llama_get_logits_ith(ctx, idx); // Apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { @@ -940,11 +1007,11 @@ llama_token llama_sample_token( if (mirostat == 1) { static float mirostat_mu = 2.0f * mirostat_tau; const int mirostat_m = 100; - llama_sample_temperature(ctx, &cur_p, temp); + llama_sample_temp(ctx, &cur_p, temp); id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); } else if (mirostat == 2) { static float mirostat_mu = 2.0f * mirostat_tau; - llama_sample_temperature(ctx, &cur_p, temp); + llama_sample_temp(ctx, &cur_p, temp); id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); } else { // Temperature sampling @@ -952,7 +1019,7 @@ llama_token llama_sample_token( llama_sample_tail_free (ctx, &cur_p, tfs_z, 1); llama_sample_typical (ctx, &cur_p, typical_p, 1); llama_sample_top_p (ctx, &cur_p, top_p, 1); - llama_sample_temperature(ctx, &cur_p, temp); + llama_sample_temp(ctx, &cur_p, temp); { const int n_top = 10; @@ -1155,7 +1222,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l #endif // NDEBUG fprintf(stream, "model_desc: %s\n", model_desc); - fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(lctx)); + fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx))); #ifdef __OPTIMIZE__ fprintf(stream, "optimize: true\n"); @@ -1179,7 +1246,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); - fprintf(stream, "export: %s # default: false\n", params.export_cgraph ? "true" : "false"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty); dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str()); @@ -1208,9 +1274,21 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, " %d: %f", lb.first, lb.second); } - fprintf(stream, "lora: %s\n", params.lora_adapter.c_str()); + fprintf(stream, "lora:\n"); + for (std::tuple la : params.lora_adapter) { + if (std::get<1>(la) != 1.0f) { + continue; + } + fprintf(stream, " - %s\n", std::get<0>(la).c_str()); + } + fprintf(stream, "lora_scaled:\n"); + for (std::tuple la : params.lora_adapter) { + if (std::get<1>(la) == 1.0f) { + continue; + } + fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la)); + } fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); - fprintf(stream, "low_vram: %s # default: false\n", params.low_vram ? "true" : "false"); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false"); fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat); @@ -1253,6 +1331,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); + fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", params.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES); diff --git a/plugins/wasi_nn/thirdparty/ggml/common.h b/plugins/wasi_nn/thirdparty/ggml/common.h index 2761503b..e095c56e 100644 --- a/plugins/wasi_nn/thirdparty/ggml/common.h +++ b/plugins/wasi_nn/thirdparty/ggml/common.h @@ -36,12 +36,15 @@ int32_t get_num_physical_cores(); struct gpt_params { uint32_t seed = -1; // RNG seed int32_t n_threads = get_num_physical_cores(); + int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt int32_t n_draft = 16; // number of tokens to draft during speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_parallel = 1; // number of parallel sequences to decode + int32_t n_sequences = 1; // number of sequences to decode int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors @@ -83,8 +86,8 @@ struct gpt_params { std::vector antiprompt; // string upon seeing which more user input is prompted std::string logdir = ""; // directory in which to save YAML log files - std::string lora_adapter = ""; // lora adapter path - std::string lora_base = ""; // base model path for the lora adapter + std::vector> lora_adapter; // lora adapter path with user defined scale + std::string lora_base = ""; // base model path for the lora adapter int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line @@ -93,7 +96,6 @@ struct gpt_params { bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score - bool low_vram = false; // if true, reduce VRAM usage at the cost of performance bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS bool memory_f16 = true; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided @@ -107,30 +109,36 @@ struct gpt_params { bool interactive_first = false; // wait for user input immediately bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles + bool cont_batching = false; // insert new sequences for decoding on-the-fly bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens bool instruct = false; // instruction mode (used for Alpaca models) bool penalize_nl = true; // consider newlines as a repeatable token - bool perplexity = false; // compute perplexity over the prompt + bool logits_all = false; // return logits for all tokens in the batch bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory bool numa = false; // attempt optimizations that help on some NUMA systems - bool export_cgraph = false; // export the computation graph bool verbose_prompt = false; // print prompt tokens before generation + bool infill = false; // use infill mode }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params); void gpt_print_usage(int argc, char ** argv, const gpt_params & params); +std::string get_system_info(const gpt_params & params); + std::string gpt_random_prompt(std::mt19937 & rng); +void process_escapes(std::string& input); + // // Model utils // std::tuple llama_init_from_gpt_params(gpt_params & params); +struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); // @@ -140,7 +148,12 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param // tokenizes a string into a vector of tokens // should work similar to Python's `tokenizer.encode` std::vector llama_tokenize( - struct llama_context * ctx, + const struct llama_context * ctx, + const std::string & text, + bool add_bos); + +std::vector llama_tokenize( + const struct llama_model * model, const std::string & text, bool add_bos); @@ -181,7 +194,7 @@ std::string llama_detokenize_bpe( // - ctx_guidance: context to use for classifier-free guidance, ignore if NULL // - grammar: grammar to use for sampling, ignore if NULL // - last_tokens: needed for repetition penalty, ignore if empty -// - idx: sample from llama_get_logits(ctx) + idx * n_vocab +// - idx: sample from llama_get_logits_ith(ctx, idx) // // returns: // - token: sampled token diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c index 304964be..805759db 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c @@ -77,7 +77,7 @@ struct free_block { size_t size; }; -#define MAX_FREE_BLOCKS 128 +#define MAX_FREE_BLOCKS 256 struct ggml_allocr { void * data; @@ -187,6 +187,7 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) } tensor->data = addr; + AT_PRINTF("%s: allocated data at %p\n", __func__, tensor->data); #ifdef GGML_ALLOCATOR_DEBUG add_allocated_tensor(alloc, tensor); @@ -218,7 +219,8 @@ static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tens size_t size = ggml_allocr_get_alloc_size(alloc, tensor); size = aligned_offset(NULL, size, alloc->alignment); - AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks); + AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks); + AT_PRINTF("%s: alloc->data = %p alloc->data+alloc->size = %p alloc->data+alloc->max_size = %p\n", __func__, alloc->data, (char*)alloc->data + alloc->size, (char*)alloc->data + alloc->max_size); #ifdef GGML_ALLOCATOR_DEBUG remove_allocated_tensor(alloc, tensor); @@ -631,3 +633,7 @@ static size_t ggml_allocr_alloc_graph_tensors_n( size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) { return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL); } + +size_t ggml_allocr_max_size(struct ggml_allocr * alloc) { + return alloc->max_size; +} diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h index 9559da75..0c224f17 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h @@ -19,6 +19,7 @@ GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc); GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc); GGML_API void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor); GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph); +GGML_API size_t ggml_allocr_max_size(struct ggml_allocr * alloc); #ifdef __cplusplus diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h index fca28d37..790cf0bf 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h @@ -19,6 +19,8 @@ #pragma once +#include "ggml.h" + #include #include @@ -33,6 +35,8 @@ struct ggml_cgraph; extern "C" { #endif +void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data); + struct ggml_metal_context; // number of command buffers to use diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m index 4f3f14e2..b3c463f0 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m @@ -11,11 +11,14 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) -// TODO: temporary - reuse llama.cpp logging #ifdef GGML_METAL_NDEBUG -#define metal_printf(...) +#define GGML_METAL_LOG_INFO(...) +#define GGML_METAL_LOG_WARN(...) +#define GGML_METAL_LOG_ERROR(...) #else -#define metal_printf(...) fprintf(stderr, __VA_ARGS__) +#define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__) +#define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__) +#define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) #endif #define UNUSED(x) (void)(x) @@ -66,6 +69,7 @@ GGML_METAL_DECL_KERNEL(soft_max_4); GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(diag_mask_inf_8); + GGML_METAL_DECL_KERNEL(get_rows_f32); GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); @@ -77,6 +81,7 @@ GGML_METAL_DECL_KERNEL(get_rows_q6_K); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); + GGML_METAL_DECL_KERNEL(mul_mat_f32_f32); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4); @@ -88,6 +93,7 @@ GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_f32_f32); GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); @@ -97,7 +103,8 @@ GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_DECL_KERNEL(rope); + GGML_METAL_DECL_KERNEL(rope_f32); + GGML_METAL_DECL_KERNEL(rope_f16); GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); @@ -117,8 +124,37 @@ @interface GGMLMetalClass : NSObject @implementation GGMLMetalClass @end +ggml_log_callback ggml_metal_log_callback = NULL; +void * ggml_metal_log_user_data = NULL; + +void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) { + ggml_metal_log_callback = log_callback; + ggml_metal_log_user_data = user_data; +} + +static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ + if (ggml_metal_log_callback != NULL) { + va_list args; + va_start(args, format); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data); + } else { + char* buffer2 = malloc(len+1); + vsnprintf(buffer2, len+1, format, args); + buffer2[len] = 0; + ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data); + free(buffer2); + } + va_end(args); + } +} + + + struct ggml_metal_context * ggml_metal_init(int n_cb) { - metal_printf("%s: allocating\n", __func__); + GGML_METAL_LOG_INFO("%s: allocating\n", __func__); id device; NSString * s; @@ -128,14 +164,14 @@ @implementation GGMLMetalClass NSArray * devices = MTLCopyAllDevices(); for (device in devices) { s = [device name]; - metal_printf("%s: found device: %s\n", __func__, [s UTF8String]); + GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]); } #endif // Pick and show default Metal device device = MTLCreateSystemDefaultDevice(); s = [device name]; - metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]); + GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]); // Configure context struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context)); @@ -145,7 +181,7 @@ @implementation GGMLMetalClass ctx->n_buffers = 0; ctx->concur_list_len = 0; - ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); + ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); #ifdef GGML_SWIFT // load the default.metallib file @@ -162,7 +198,7 @@ @implementation GGMLMetalClass ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; if (error) { - metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]); + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; } } @@ -175,12 +211,12 @@ @implementation GGMLMetalClass //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; - NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; - metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]); + NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]); NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; if (error) { - metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]); + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; } @@ -192,7 +228,7 @@ @implementation GGMLMetalClass ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error]; #endif if (error) { - metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]); + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; } } @@ -204,11 +240,11 @@ @implementation GGMLMetalClass #define GGML_METAL_ADD_KERNEL(name) \ ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \ - metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \ + GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \ (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \ (int) ctx->pipeline_##name.threadExecutionWidth); \ if (error) { \ - metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ + GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ return NULL; \ } @@ -224,6 +260,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(soft_max_4); GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(diag_mask_inf_8); + GGML_METAL_ADD_KERNEL(get_rows_f32); GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); @@ -235,6 +272,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(get_rows_q6_K); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); + GGML_METAL_ADD_KERNEL(mul_mat_f32_f32); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4); @@ -246,6 +284,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); @@ -255,7 +294,8 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_ADD_KERNEL(rope); + GGML_METAL_ADD_KERNEL(rope_f32); + GGML_METAL_ADD_KERNEL(rope_f16); GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); @@ -264,13 +304,13 @@ @implementation GGMLMetalClass #undef GGML_METAL_ADD_KERNEL } - metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); + GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); #if TARGET_OS_OSX - metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); if (ctx->device.maxTransferRate != 0) { - metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); } else { - metal_printf("%s: maxTransferRate = built-in GPU\n", __func__); + GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__); } #endif @@ -278,7 +318,7 @@ @implementation GGMLMetalClass } void ggml_metal_free(struct ggml_metal_context * ctx) { - metal_printf("%s: deallocating\n", __func__); + GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); #define GGML_METAL_DEL_KERNEL(name) \ [ctx->function_##name release]; \ [ctx->pipeline_##name release]; @@ -293,7 +333,9 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(gelu); GGML_METAL_DEL_KERNEL(soft_max); GGML_METAL_DEL_KERNEL(soft_max_4); + GGML_METAL_DEL_KERNEL(diag_mask_inf); GGML_METAL_DEL_KERNEL(diag_mask_inf_8); + GGML_METAL_DEL_KERNEL(get_rows_f32); GGML_METAL_DEL_KERNEL(get_rows_f16); GGML_METAL_DEL_KERNEL(get_rows_q4_0); GGML_METAL_DEL_KERNEL(get_rows_q4_1); @@ -305,6 +347,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(get_rows_q6_K); GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(norm); + GGML_METAL_DEL_KERNEL(mul_mat_f32_f32); GGML_METAL_DEL_KERNEL(mul_mat_f16_f32); GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row); GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4); @@ -316,6 +359,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); @@ -325,7 +369,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_DEL_KERNEL(rope); + GGML_METAL_DEL_KERNEL(rope_f32); + GGML_METAL_DEL_KERNEL(rope_f16); GGML_METAL_DEL_KERNEL(alibi_f32); GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f32); @@ -350,7 +395,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { void * data = NULL; const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); if (result != 0) { - metal_printf("%s: error: posix_memalign failed\n", __func__); + GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); return NULL; } @@ -378,7 +423,7 @@ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) { // Metal buffer based on the host memory pointer // static id ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) { - //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); + //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); const int64_t tsize = ggml_nbytes(t); @@ -386,16 +431,17 @@ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) { for (int i = 0; i < ctx->n_buffers; ++i) { const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; + //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name); if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) { *offs = (size_t) ioffs; - //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs); + //GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs); return ctx->buffers[i].metal; } } - metal_printf("%s: error: buffer is nil\n", __func__); + GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__); return nil; } @@ -407,7 +453,7 @@ bool ggml_metal_add_buffer( size_t size, size_t max_size) { if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) { - metal_printf("%s: too many buffers\n", __func__); + GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__); return false; } @@ -417,7 +463,7 @@ bool ggml_metal_add_buffer( const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data; if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) { - metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name); + GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name); return false; } } @@ -438,11 +484,11 @@ bool ggml_metal_add_buffer( ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; if (ctx->buffers[ctx->n_buffers].metal == nil) { - metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0); + GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0); return false; } - metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0); ++ctx->n_buffers; } else { @@ -462,13 +508,13 @@ bool ggml_metal_add_buffer( ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; if (ctx->buffers[ctx->n_buffers].metal == nil) { - metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0); + GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0); return false; } - metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i); + GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i); if (i + size_step < size) { - metal_printf("\n"); + GGML_METAL_LOG_INFO("\n"); } ++ctx->n_buffers; @@ -476,17 +522,17 @@ bool ggml_metal_add_buffer( } #if TARGET_OS_OSX - metal_printf(", (%8.2f / %8.2f)", + GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", ctx->device.currentAllocatedSize / 1024.0 / 1024.0, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) { - metal_printf(", warning: current allocated size is greater than the recommended max working set size\n"); + GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__); } else { - metal_printf("\n"); + GGML_METAL_LOG_INFO("\n"); } #else - metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0); #endif } @@ -599,7 +645,7 @@ void ggml_metal_graph_find_concurrency( } if (ctx->concur_list_len > GGML_MAX_CONCUR) { - metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__); + GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__); } } @@ -653,7 +699,7 @@ void ggml_metal_graph_compute( continue; } - //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); + //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); struct ggml_tensor * src0 = gf->nodes[i]->src[0]; struct ggml_tensor * src1 = gf->nodes[i]->src[1]; @@ -697,17 +743,17 @@ void ggml_metal_graph_compute( id id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil; id id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil; - //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); //if (src0) { - // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, + // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, // ggml_is_contiguous(src0), src0->name); //} //if (src1) { - // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, + // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, // ggml_is_contiguous(src1), src1->name); //} //if (dst) { - // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, + // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, // dst->name); //} @@ -723,29 +769,66 @@ void ggml_metal_graph_compute( case GGML_OP_ADD: { GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); - // utilize float4 - GGML_ASSERT(ne00 % 4 == 0); - const int64_t nb = ne00/4; + bool bcast_row = false; - if (ggml_nelements(src1) == ne10) { + int64_t nb = ne00; + + if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) { // src1 is a row + GGML_ASSERT(ne11 == 1); + + nb = ne00 / 4; [encoder setComputePipelineState:ctx->pipeline_add_row]; + + bcast_row = true; } else { [encoder setComputePipelineState:ctx->pipeline_add]; } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:3]; - - const int64_t n = ggml_nelements(dst)/4; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; + + if (bcast_row) { + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + const int nth = MIN(1024, ne0); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } } break; case GGML_OP_MUL: { GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); // utilize float4 GGML_ASSERT(ne00 % 4 == 0); @@ -753,6 +836,7 @@ void ggml_metal_graph_compute( if (ggml_nelements(src1) == ne10) { // src1 is a row + GGML_ASSERT(ne11 == 1); [encoder setComputePipelineState:ctx->pipeline_mul_row]; } else { [encoder setComputePipelineState:ctx->pipeline_mul]; @@ -768,6 +852,8 @@ void ggml_metal_graph_compute( } break; case GGML_OP_SCALE: { + GGML_ASSERT(ggml_is_contiguous(src0)); + const float scale = *(const float *) src1->data; [encoder setComputePipelineState:ctx->pipeline_scale]; @@ -813,13 +899,13 @@ void ggml_metal_graph_compute( } break; default: { - metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); + GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); GGML_ASSERT(false); } } break; case GGML_OP_SOFT_MAX: { - const int nth = 32; + const int nth = MIN(32, ne00); if (ne00%4 == 0) { [encoder setComputePipelineState:ctx->pipeline_soft_max_4]; @@ -867,13 +953,14 @@ void ggml_metal_graph_compute( // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if (ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && + if (!ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && src1t == GGML_TYPE_F32 && [ctx->device supportsFamily:MTLGPUFamilyApple7] && ne00%32 == 0 && - ne11 > 1) { + ne11 > 2) { switch (src0->type) { + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; @@ -893,9 +980,12 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -905,6 +995,11 @@ void ggml_metal_graph_compute( // use custom matrix x vector kernel switch (src0t) { + case GGML_TYPE_F32: + { + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32]; + nrows = 4; + } break; case GGML_TYPE_F16: { nth0 = 32; @@ -993,7 +1088,7 @@ void ggml_metal_graph_compute( } break; default: { - metal_printf("Asserting on type %d\n",(int)src0t); + GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); GGML_ASSERT(false && "not implemented"); } }; @@ -1045,6 +1140,7 @@ void ggml_metal_graph_compute( case GGML_OP_GET_ROWS: { switch (src0->type) { + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; @@ -1060,9 +1156,9 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5]; const int64_t n = ggml_nelements(src1); @@ -1073,7 +1169,7 @@ void ggml_metal_graph_compute( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = 512; + const int nth = MIN(512, ne00); [encoder setComputePipelineState:ctx->pipeline_rms_norm]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1092,7 +1188,7 @@ void ggml_metal_graph_compute( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = 256; + const int nth = MIN(256, ne00); [encoder setComputePipelineState:ctx->pipeline_norm]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1110,6 +1206,8 @@ void ggml_metal_graph_compute( { GGML_ASSERT((src0t == GGML_TYPE_F32)); + const int nth = MIN(1024, ne00); + const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; @@ -1143,12 +1241,14 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; - const int nth = 32; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_ROPE: { + GGML_ASSERT(ne10 == ne02); + + const int nth = MIN(1024, ne00); + const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -1158,38 +1258,44 @@ void ggml_metal_graph_compute( memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - [encoder setComputePipelineState:ctx->pipeline_rope]; + switch (src0->type) { + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break; + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break; + default: GGML_ASSERT(false); + }; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; - [encoder setBytes:&mode length:sizeof( int) atIndex:20]; - [encoder setBytes:&freq_base length:sizeof(float) atIndex:21]; - [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:19]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; + [encoder setBytes:&mode length:sizeof( int) atIndex:21]; + [encoder setBytes:&freq_base length:sizeof(float) atIndex:22]; + [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: { - const int nth = 32; + const int nth = MIN(1024, ne00); switch (src0t) { case GGML_TYPE_F32: @@ -1234,7 +1340,7 @@ void ggml_metal_graph_compute( } break; default: { - metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); + GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); GGML_ASSERT(false); } } @@ -1259,7 +1365,7 @@ void ggml_metal_graph_compute( MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status]; if (status != MTLCommandBufferStatusCompleted) { - metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status); + GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); GGML_ASSERT(false); } } diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal index f45b1490..5e1af6a0 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal @@ -24,12 +24,59 @@ typedef struct { int8_t qs[QK8_0]; // quants } block_q8_0; +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient kernel void kernel_add( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig]; + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant int64_t & nb00, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant int64_t & nb0, + constant int64_t & nb1, + constant int64_t & nb2, + constant int64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0]; + + src0_ptr += ntg.x*nb00; + src1_ptr += ntg.x*nb10; + dst_ptr += ntg.x*nb0; + } } // assumption: src1 is a row @@ -38,7 +85,7 @@ kernel void kernel_add_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb, + constant int64_t & nb [[buffer(27)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } @@ -118,7 +165,7 @@ kernel void kernel_soft_max( device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // parallel max - float lmax = psrc0[tpitg[0]]; + float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY; for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) { lmax = MAX(lmax, psrc0[i00]); } @@ -158,7 +205,7 @@ kernel void kernel_soft_max_4( device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max - float4 lmax4 = psrc4[tpitg[0]]; + float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY; for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) { lmax4 = fmax(lmax4, psrc4[i00]); } @@ -523,6 +570,79 @@ kernel void kernel_mul_mat_q8_0_f32( } } +#define N_F32_F32 4 + +kernel void kernel_mul_mat_f32_f32( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F32_F32; + const int64_t im = tgpig.z; + + device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + + if (ne00 < 128) { + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const float4 * x4 = (device const float4 *)x; + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + kernel void kernel_mul_mat_f16_f32_1row( device const char * src0, device const char * src1, @@ -733,30 +853,61 @@ kernel void kernel_alibi_f32( } } +typedef void (rope_t)( + device const void * src0, + device const int32_t * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant float & freq_base, + constant float & freq_scale, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]); + +template kernel void kernel_rope( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant float & freq_base, - constant float & freq_scale, + device const void * src0, + device const int32_t * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant float & freq_base, + constant float & freq_scale, uint tiitg[[thread_index_in_threadgroup]], uint3 tptg[[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { @@ -766,7 +917,9 @@ kernel void kernel_rope( const bool is_neox = mode & 2; - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + device const int32_t * pos = src1; + + const int64_t p = pos[i2]; const float theta_0 = freq_scale * (float)p; const float inv_ndims = -1.f/n_dims; @@ -778,11 +931,11 @@ kernel void kernel_rope( const float cos_theta = cos(theta); const float sin_theta = sin(theta); - device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - const float x0 = src[0]; - const float x1 = src[1]; + const T x0 = src[0]; + const T x1 = src[1]; dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; @@ -797,8 +950,8 @@ kernel void kernel_rope( const int64_t i0 = ib*n_dims + ic/2; - device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = src[0]; const float x1 = src[n_dims/2]; @@ -810,6 +963,9 @@ kernel void kernel_rope( } } +template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; +template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, @@ -1200,8 +1356,8 @@ kernel void kernel_mul_mat_q3_K_f32( float yl[32]; - const uint16_t kmask1 = 0x3030; - const uint16_t kmask2 = 0x0f0f; + //const uint16_t kmask1 = 0x3030; + //const uint16_t kmask2 = 0x0f0f; const int tid = tiisg/4; const int ix = tiisg%4; @@ -1321,7 +1477,6 @@ kernel void kernel_mul_mat_q3_K_f32( dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row]; } } - } #else kernel void kernel_mul_mat_q3_K_f32( @@ -1400,13 +1555,13 @@ kernel void kernel_mul_mat_q4_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne01 [[buffer(4)]], + constant int64_t & ne02 [[buffer(5)]], + constant int64_t & ne10 [[buffer(9)]], + constant int64_t & ne12 [[buffer(11)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & gqa [[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1865,6 +2020,15 @@ kernel void kernel_mul_mat_q6_K_f32( //============================= templates and their specializations ============================= +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + float4x4 temp = *(((device float4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + template void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { half4x4 temp = *(((device half4x4 *)src)); @@ -1875,7 +2039,6 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) template void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 1); const float d1 = il ? (xb->d / 16.h) : xb->d; const float d2 = d1 / 256.f; @@ -1887,12 +2050,10 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; } - } template void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 2); const float d1 = il ? (xb->d / 16.h) : xb->d; const float d2 = d1 / 256.f; @@ -1964,7 +2125,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg for (int i = 0; i < 16; ++i) { reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); } - #else float kcoef = il&1 ? 1.f/16.f : 1.f; uint16_t kmask = il&1 ? 0xF0 : 0x0F; @@ -2008,7 +2168,6 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg for (int i = 0; i < 16; ++i) { reg[i/4][i%4] = dl * (q[i] & mask) - ml; } - } template @@ -2110,22 +2269,25 @@ kernel void kernel_get_rows( // each block_q contains 16*nl weights template kernel void kernel_mul_mm(device const uchar * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & gqa, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup half * sa = ((threadgroup half *)shared_memory); + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = (threadgroup half *)(shared_memory); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); const uint r0 = tgpig.y; @@ -2138,7 +2300,7 @@ kernel void kernel_mul_mm(device const uchar * src0, short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_half8x8 ma[4]; + simdgroup_half8x8 ma[4]; simdgroup_float8x8 mb[2]; simdgroup_float8x8 c_res[8]; for (int i = 0; i < 8; i++){ @@ -2146,10 +2308,15 @@ kernel void kernel_mul_mm(device const uchar * src0, } short il = (tiitg % THREAD_PER_ROW); - uint offset0 = im/gqa*nb02; ushort offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; - device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ - + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1; + + uint offset0 = im/gqa*nb02; + ushort offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * im + + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { //load data and store to threadgroup memory @@ -2229,6 +2396,7 @@ kernel void kernel_mul_mm(device const uchar * src0, typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ constant uint64_t &, constant uint64_t &, uint, uint, uint); +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; @@ -2239,14 +2407,28 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; -typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\ - constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \ - constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint); - -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +typedef void (mat_mm_t)( + device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uchar *, uint3, uint, uint); + +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml.c b/plugins/wasi_nn/thirdparty/ggml/ggml.c index a0be068d..bf1426d2 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml.c +++ b/plugins/wasi_nn/thirdparty/ggml/ggml.c @@ -89,7 +89,9 @@ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(vo static int pthread_join(pthread_t thread, void * unused) { (void) unused; - return (int) WaitForSingleObject(thread, INFINITE); + int ret = (int) WaitForSingleObject(thread, INFINITE); + CloseHandle(thread); + return ret; } static int sched_yield (void) { @@ -134,6 +136,7 @@ typedef void * thread_ret_t; #define GGML_SOFT_MAX_UNROLL 4 #define GGML_VEC_DOT_UNROLL 2 +#define GGML_VEC_MAD_UNROLL 32 // // logging @@ -242,18 +245,18 @@ inline static void * ggml_aligned_malloc(size_t size) { // #define GGML_TENSOR_UNARY_OP_LOCALS \ - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \ - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \ - GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) #define GGML_TENSOR_BINARY_OP_LOCALS \ - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \ - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \ - GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); \ - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \ - GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) #if defined(GGML_USE_ACCELERATE) #include @@ -1863,7 +1866,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { #define GGML_F16x8_ADD vaddq_f16 #define GGML_F16x8_MUL vmulq_f16 #define GGML_F16x8_REDUCE(res, x) \ - { \ + do { \ int offset = GGML_F16_ARR >> 1; \ for (int i = 0; i < offset; ++i) { \ x[i] = vaddq_f16(x[i], x[offset+i]); \ @@ -1879,7 +1882,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ - } + } while (0) #define GGML_F16_VEC GGML_F16x8 #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO @@ -1940,7 +1943,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { #define GGML_F32x8_ADD _mm256_add_ps #define GGML_F32x8_MUL _mm256_mul_ps #define GGML_F32x8_REDUCE(res, x) \ -{ \ +do { \ int offset = GGML_F32_ARR >> 1; \ for (int i = 0; i < offset; ++i) { \ x[i] = _mm256_add_ps(x[i], x[offset+i]); \ @@ -1957,7 +1960,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { _mm256_extractf128_ps(x[0], 1)); \ const __m128 t1 = _mm_hadd_ps(t0, t0); \ res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ -} +} while (0) // TODO: is this optimal ? #define GGML_F32_VEC GGML_F32x8 @@ -3707,6 +3710,58 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +// xs and vs are byte strides of x and v +inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { + + const float * restrict x[GGML_VEC_MAD_UNROLL]; + const float * restrict v[GGML_VEC_MAD_UNROLL]; + + for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) { + x[i] = (const float *) ((const char *) xv + i*xs); + v[i] = (const float *) ((const char *) vv + i*vs); + } + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + vx[k] = GGML_F32_VEC_SET1(v[k][0]); + } + + GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); + } + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = np; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } +#else + // scalar + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = 0; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } +#endif +} + //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #if defined(GGML_USE_ACCELERATE) @@ -4392,10 +4447,9 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - return - (t0->ne[1] == t1->ne[1]) && - (t0->ne[2] == t1->ne[2]) && - (t0->ne[3] == t1->ne[3]); + return (t0->ne[1] == t1->ne[1]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); } enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { @@ -5065,43 +5119,78 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { return tensor; } +void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) { + const int64_t ne2 = tensor->ne[2]; + const int64_t ne1 = tensor->ne[1]; + const int64_t ne0 = tensor->ne[0]; + + const int64_t i3_ = (i/(ne2*ne1*ne0)); + const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0); + const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0; + const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0); + + if (i0) { + * i0 = i0_; + } + if (i1) { + * i1 = i1_; + } + if (i2) { + * i2 = i2_; + } + if (i3) { + * i3 = i3_; + } +} + int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]); + } switch (tensor->type) { case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); return ((int8_t *)(tensor->data))[i]; - } break; + } case GGML_TYPE_I16: { GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); return ((int16_t *)(tensor->data))[i]; - } break; + } case GGML_TYPE_I32: { GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); return ((int32_t *)(tensor->data))[i]; - } break; + } case GGML_TYPE_F16: { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); - } break; + } case GGML_TYPE_F32: { GGML_ASSERT(tensor->nb[0] == sizeof(float)); return ((float *)(tensor->data))[i]; - } break; + } default: { GGML_ASSERT(false); - } break; + } } return 0.0f; } void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value); + return; + } switch (tensor->type) { case GGML_TYPE_I8: { @@ -5135,43 +5224,104 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { } } +int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + return ((int8_t *) data)[0]; + case GGML_TYPE_I16: + return ((int16_t *) data)[0]; + case GGML_TYPE_I32: + return ((int32_t *) data)[0]; + case GGML_TYPE_F16: + return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); + case GGML_TYPE_F32: + return ((float *) data)[0]; + default: + GGML_ASSERT(false); + } + + return 0.0f; +} + +void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + { + ((int8_t *)(data))[0] = value; + } break; + case GGML_TYPE_I16: + { + ((int16_t *)(data))[0] = value; + } break; + case GGML_TYPE_I32: + { + ((int32_t *)(data))[0] = value; + } break; + case GGML_TYPE_F16: + { + ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + ((float *)(data))[0] = value; + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]); + } switch (tensor->type) { case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); return ((int8_t *)(tensor->data))[i]; - } break; + } case GGML_TYPE_I16: { GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); return ((int16_t *)(tensor->data))[i]; - } break; + } case GGML_TYPE_I32: { GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); return ((int32_t *)(tensor->data))[i]; - } break; + } case GGML_TYPE_F16: { GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); - } break; + } case GGML_TYPE_F32: { GGML_ASSERT(tensor->nb[0] == sizeof(float)); return ((float *)(tensor->data))[i]; - } break; + } default: { GGML_ASSERT(false); - } break; + } } return 0.0f; } void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value); + return; + } switch (tensor->type) { case GGML_TYPE_I8: { @@ -5205,6 +5355,56 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { } } +float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + return ((int8_t *) data)[0]; + case GGML_TYPE_I16: + return ((int16_t *) data)[0]; + case GGML_TYPE_I32: + return ((int32_t *) data)[0]; + case GGML_TYPE_F16: + return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); + case GGML_TYPE_F32: + return ((float *) data)[0]; + default: + GGML_ASSERT(false); + } + + return 0.0f; +} + +void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + { + ((int8_t *)(data))[0] = value; + } break; + case GGML_TYPE_I16: + { + ((int16_t *)(data))[0] = value; + } break; + case GGML_TYPE_I32: + { + ((int32_t *)(data))[0] = value; + } break; + case GGML_TYPE_F16: + { + ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_F32: + { + ((float *)(data))[0] = value; + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + void * ggml_get_data(const struct ggml_tensor * tensor) { return tensor->data; } @@ -5347,6 +5547,44 @@ struct ggml_tensor * ggml_add_inplace( return ggml_add_impl(ctx, a, b, true); } +// ggml_add_cast + +static struct ggml_tensor * ggml_add_cast_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type) { + // TODO: support less-strict constraint + // GGML_ASSERT(ggml_can_repeat(b, a)); + GGML_ASSERT(ggml_can_repeat_rows(b, a)); + GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input + + bool is_node = false; + + if (a->grad || b->grad) { + // TODO: support backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne); + + result->op = GGML_OP_ADD; + result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_add_cast( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type) { + return ggml_add_cast_impl(ctx, a, b, type); +} + // ggml_add1 static struct ggml_tensor * ggml_add1_impl( @@ -5783,7 +6021,6 @@ struct ggml_tensor * ggml_repeat( result->op = GGML_OP_REPEAT; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -5811,7 +6048,6 @@ struct ggml_tensor * ggml_repeat_back( result->op = GGML_OP_REPEAT_BACK; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -6186,8 +6422,9 @@ struct ggml_tensor * ggml_out_prod( is_node = true; } - const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] + const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne); result->op = GGML_OP_OUT_PROD; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -6406,6 +6643,54 @@ struct ggml_tensor * ggml_cont_inplace( return ggml_cont_impl(ctx, a, true); } + +// make contiguous, with new shape +GGML_API struct ggml_tensor * ggml_cont_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0) { + return ggml_cont_4d(ctx, a, ne0, 1, 1, 1); +} + +GGML_API struct ggml_tensor * ggml_cont_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1) { + return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1); +} + +GGML_API struct ggml_tensor * ggml_cont_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1); +} + +struct ggml_tensor * ggml_cont_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3)); + + bool is_node = false; + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + ggml_format_name(result, "%s (cont)", a->name); + + result->op = GGML_OP_CONT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + + // ggml_reshape struct ggml_tensor * ggml_reshape( @@ -6413,7 +6698,7 @@ struct ggml_tensor * ggml_reshape( struct ggml_tensor * a, struct ggml_tensor * b) { GGML_ASSERT(ggml_is_contiguous(a)); - GGML_ASSERT(ggml_is_contiguous(b)); + // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous. GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); bool is_node = false; @@ -6786,7 +7071,6 @@ struct ggml_tensor * ggml_get_rows_back( result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; } @@ -6968,7 +7252,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace( static struct ggml_tensor * ggml_rope_impl( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -6977,7 +7261,10 @@ static struct ggml_tensor * ggml_rope_impl( float xpos_base, bool xpos_down, bool inplace) { - GGML_ASSERT(n_past >= 0); + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] == b->ne[0]); + bool is_node = false; if (a->grad) { @@ -6986,7 +7273,7 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[8] = { n_past, n_dims, mode, n_ctx }; + int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 5, &freq_scale, sizeof(float)); memcpy(params + 6, &xpos_base, sizeof(float)); @@ -6996,6 +7283,7 @@ static struct ggml_tensor * ggml_rope_impl( result->op = GGML_OP_ROPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; + result->src[1] = b; return result; } @@ -7003,55 +7291,55 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); + return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); } struct ggml_tensor * ggml_rope_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); + return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); } struct ggml_tensor * ggml_rope_custom( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, float freq_base, float freq_scale) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); + return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); } struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, float freq_base, float freq_scale) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); + return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); } struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, float base, bool down) { - return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); + return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); } // ggml_rope_back @@ -7059,7 +7347,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_tensor * ggml_rope_back( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -7067,7 +7355,10 @@ struct ggml_tensor * ggml_rope_back( float freq_scale, float xpos_base, bool xpos_down) { - GGML_ASSERT(n_past >= 0); + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] == b->ne[0]); + GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); bool is_node = false; @@ -7078,7 +7369,7 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - int32_t params[8] = { n_past, n_dims, mode, n_ctx }; + int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 5, &freq_scale, sizeof(float)); memcpy(params + 6, &xpos_base, sizeof(float)); @@ -7088,6 +7379,7 @@ struct ggml_tensor * ggml_rope_back( result->op = GGML_OP_ROPE_BACK; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; + result->src[1] = b; return result; } @@ -7484,27 +7776,30 @@ struct ggml_tensor * ggml_flash_attn_back( // d shape [D,N,ne2,ne3] // q shape [D,N,ne2,ne3] - // k shape [D,M,ne2,ne3] - // v shape [M,D,ne2,ne3] + // k shape [D,M,kvne2,ne3] + // v shape [M,D,kvne2,ne3] - const int64_t D = q->ne[0]; - const int64_t N = q->ne[1]; - const int64_t M = k->ne[1]; - const int64_t ne2 = q->ne[2]; - const int64_t ne3 = q->ne[3]; + const int64_t D = q->ne[0]; + const int64_t N = q->ne[1]; + const int64_t M = k->ne[1]; + const int64_t ne2 = q->ne[2]; + const int64_t ne3 = q->ne[3]; + const int64_t kvne2 = k->ne[2]; GGML_ASSERT(k->ne[0] == D); GGML_ASSERT(v->ne[0] == M); GGML_ASSERT(v->ne[1] == D); GGML_ASSERT(d->ne[0] == D); GGML_ASSERT(d->ne[1] == N); - GGML_ASSERT(k->ne[2] == ne2); + GGML_ASSERT(k->ne[2] == kvne2); GGML_ASSERT(k->ne[3] == ne3); - GGML_ASSERT(v->ne[2] == ne2); + GGML_ASSERT(v->ne[2] == kvne2); GGML_ASSERT(v->ne[3] == ne3); GGML_ASSERT(d->ne[2] == ne2); GGML_ASSERT(d->ne[3] == ne3); + GGML_ASSERT(ne2 % kvne2 == 0); + bool is_node = false; if (q->grad || k->grad || v->grad) { @@ -7514,14 +7809,23 @@ struct ggml_tensor * ggml_flash_attn_back( } // store gradients of q, k and v as continuous tensors concatenated in result. - // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3] - // gradq->data = result->data - // gradk->data = result->data + nb0*D*N*ne2*ne3 - // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3 // note: v and gradv are actually transposed, i.e. v->ne[0] != D. - int64_t ne[4] = {D,M+N+M,ne2,ne3}; + const int64_t elem_q = ggml_nelements(q); + const int64_t elem_k = ggml_nelements(k); + const int64_t elem_v = ggml_nelements(v); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + enum ggml_type result_type = GGML_TYPE_F32; + GGML_ASSERT(ggml_blck_size(result_type) == 1); + const size_t tsize = ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); + const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); + const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN); + + const size_t nelements = (end + tsize - 1)/tsize; + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements); int32_t masked_i = masked ? 1 : 0; ggml_set_op_params(result, &masked_i, sizeof(masked_i)); @@ -8214,7 +8518,7 @@ static void ggml_compute_forward_dup_f16( return; } - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS const int ith = params->ith; // thread index const int nth = params->nth; // number of threads @@ -8485,7 +8789,7 @@ static void ggml_compute_forward_dup_f32( return; } - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS const int ith = params->ith; // thread index const int nth = params->nth; // number of threads @@ -8766,7 +9070,7 @@ static void ggml_compute_forward_add_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -8798,8 +9102,6 @@ static void ggml_compute_forward_add_f32( #else ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr); #endif - // } - // } } } else { // src1 is not contiguous @@ -8841,7 +9143,7 @@ static void ggml_compute_forward_add_f16_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -8895,7 +9197,7 @@ static void ggml_compute_forward_add_f16_f16( const int nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F16); @@ -8946,14 +9248,15 @@ static void ggml_compute_forward_add_q_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; const enum ggml_type type = src0->type; + const enum ggml_type dtype = dst->type; ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - ggml_from_float_t const quantize_row_q = type_traits[type].from_float; + ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float; // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(type)); @@ -8965,7 +9268,6 @@ static void ggml_compute_forward_add_q_f32( GGML_ASSERT(nb2 <= nb3); GGML_ASSERT(ggml_is_quantized(src0->type)); - GGML_ASSERT(dst->type == src0->type); GGML_ASSERT(src1->type == GGML_TYPE_F32); // rows per thread @@ -9003,7 +9305,11 @@ static void ggml_compute_forward_add_q_f32( // add src1 ggml_vec_acc_f32(ne00, wdata, src1_row); // quantize row to dst - quantize_row_q(wdata, dst_row, ne00); + if (quantize_row_q != NULL) { + quantize_row_q(wdata, dst_row, ne00); + } else { + memcpy(dst_row, wdata, ne0*nb0); + } } } @@ -9068,7 +9374,7 @@ static void ggml_compute_forward_add1_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -9123,7 +9429,7 @@ static void ggml_compute_forward_add1_f16_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -9173,7 +9479,7 @@ static void ggml_compute_forward_add1_f16_f16( const int nr = ggml_nrows(src0); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F16); @@ -9223,7 +9529,7 @@ static void ggml_compute_forward_add1_q_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS const enum ggml_type type = src0->type; ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; @@ -9351,8 +9657,8 @@ static void ggml_compute_forward_acc_f32( const int nr = ggml_nrows(src1); const int nc = src1->ne[0]; - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); - GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) // src0 and dst as viewed during acc const size_t nb0 = ggml_element_size(src0); @@ -9441,7 +9747,7 @@ static void ggml_compute_forward_sub_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -9531,7 +9837,7 @@ static void ggml_compute_forward_mul_f32( const int64_t nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -9622,7 +9928,7 @@ static void ggml_compute_forward_div_f32( const int nr = ggml_nrows(src0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); @@ -9831,8 +10137,8 @@ static void ggml_compute_forward_sum_f32( assert(ggml_is_scalar(dst)); assert(src0->nb[0] == sizeof(float)); - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) ggml_float sum = 0; ggml_float row_sum = 0; @@ -9863,8 +10169,8 @@ static void ggml_compute_forward_sum_f16( assert(src0->nb[0] == sizeof(ggml_fp16_t)); - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) float sum = 0; float row_sum = 0; @@ -9917,7 +10223,7 @@ static void ggml_compute_forward_sum_rows_f32( GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(dst->nb[0] == sizeof(float)); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS GGML_ASSERT(ne0 == 1); GGML_ASSERT(ne1 == ne01); @@ -9967,7 +10273,7 @@ static void ggml_compute_forward_mean_f32( assert(src0->nb[0] == sizeof(float)); - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS assert(ne0 == 1); assert(ne1 == ne01); @@ -10067,7 +10373,7 @@ static void ggml_compute_forward_repeat_f32( return; } - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in ggml_can_repeat const int nr0 = (int)(ne0/ne00); @@ -10099,11 +10405,61 @@ static void ggml_compute_forward_repeat_f32( } } +static void ggml_compute_forward_repeat_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(params->ith == 0); + GGML_ASSERT(ggml_can_repeat(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS; + + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0); + ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01); + // ggml_vec_cpy_f16(ne00, y, x) + for (int i = 0; i < ne00; ++i) { + y[i] = x[i]; + } + } + } + } + } + } + } + } +} + static void ggml_compute_forward_repeat( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_repeat_f16(params, src0, dst); + } break; case GGML_TYPE_F32: { ggml_compute_forward_repeat_f32(params, src0, dst); @@ -10128,7 +10484,7 @@ static void ggml_compute_forward_repeat_back_f32( return; } - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in ggml_can_repeat const int nr0 = (int)(ne00/ne0); @@ -10206,7 +10562,7 @@ static void ggml_compute_forward_concat_f32( const int ith = params->ith; - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS // TODO: support for transposed / permuted tensors GGML_ASSERT(nb0 == sizeof(float)); @@ -10808,7 +11164,7 @@ static void ggml_compute_forward_norm_f32( const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -10877,7 +11233,7 @@ static void ggml_compute_forward_rms_norm_f32( const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -10942,7 +11298,7 @@ static void ggml_compute_forward_rms_norm_back_f32( const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -11117,7 +11473,7 @@ static void ggml_compute_forward_group_norm_f32( const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS const float eps = 1e-6f; // TODO: make this a parameter @@ -11228,7 +11584,7 @@ static void ggml_compute_forward_mul_mat( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; @@ -11265,11 +11621,6 @@ static void ggml_compute_forward_mul_mat( #if defined(GGML_USE_CLBLAST) if (ggml_cl_can_mul_mat(src0, src1, dst)) { - // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension - // ref: https://github.com/ggerganov/ggml/pull/224 - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne03 == ne13); - if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); } @@ -11443,10 +11794,10 @@ static void ggml_compute_forward_out_prod_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); + // int64_t t0 = ggml_perf_time_us(); + // UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; @@ -11485,6 +11836,146 @@ static void ggml_compute_forward_out_prod_f32( return; } + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // block-tiling attempt + const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32); + const int64_t blck_1 = 16; + + for (int64_t bir = ir0; bir < ir1; bir += blck_1) { + const int64_t bir1 = MIN(bir + blck_1, ir1); + for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { + const int64_t bne01 = MIN(bi01 + blck_0, ne01); + for (int64_t ir = bir; ir < bir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2; + const int64_t i03 = i3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + +#if GGML_VEC_MAD_UNROLL > 2 + const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL); + for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); + } + for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32(ne0, d, s0, *s1); + } +#else + for (int64_t i01 = bi01; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32(ne0, d, s0, *s1); + } +#endif + } + } + } + + + //int64_t t1 = ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); + + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void ggml_compute_forward_out_prod_q_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + // int64_t t0 = ggml_perf_time_us(); + // UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 dim0 + GGML_ASSERT(nb00 == ggml_type_size(type)); + + // dst dim0 cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod + // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) + + if (params->type == GGML_TASK_INIT) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + // parallelize by last three dimensions // total rows in dst @@ -11504,6 +11995,8 @@ static void ggml_compute_forward_out_prod_f32( // for i0: // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; + for (int64_t ir = ir0; ir < ir1; ++ir) { // dst indices const int64_t i3 = ir/(ne2*ne1); @@ -11524,10 +12017,8 @@ static void ggml_compute_forward_out_prod_f32( float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - ggml_vec_mad_f32(ne0, d, s0, *s1); - // for (int64_t i0 = 0; i0 < ne0; ++i0) { - // d[i0] += s0[i0] * s1[i1]; - // } + dequantize_row_q(s0, wdata, ne0); + ggml_vec_mad_f32(ne0, d, wdata, *s1); } } @@ -11556,10 +12047,13 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: { - GGML_ASSERT(false); // todo - // ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); + ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); } break; case GGML_TYPE_F16: { @@ -11677,8 +12171,8 @@ static void ggml_compute_forward_set_f32( const int nr = ggml_nrows(src1); const int nc = src1->ne[0]; - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); - GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) // src0 and dst as viewed during set const size_t nb0 = ggml_element_size(src0); @@ -11947,14 +12441,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_are_same_shape(opt0, dst)); - GGML_ASSERT(ggml_is_contiguous(opt0)); GGML_ASSERT(ggml_is_contiguous(dst)); - ggml_compute_forward_dup_same_cont(params, opt0, dst); + // ggml_compute_forward_dup_same_cont(params, opt0, dst); + + if (params->type == GGML_TASK_INIT) { + memset(dst->data, 0, ggml_nbytes(dst)); + } if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -11980,11 +12475,8 @@ static void ggml_compute_forward_get_rows_back_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_are_same_shape(opt0, dst)); - GGML_ASSERT(ggml_is_contiguous(opt0)); GGML_ASSERT(ggml_is_contiguous(dst)); // ggml_compute_forward_dup_same_cont(params, opt0, dst); @@ -12018,16 +12510,15 @@ static void ggml_compute_forward_get_rows_back( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst); + ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst); + ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst); } break; default: { @@ -12068,7 +12559,7 @@ static void ggml_compute_forward_diag_f32( // TODO: handle transposed/permuted matrices - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS GGML_ASSERT(ne00 == ne0); GGML_ASSERT(ne00 == ne1); @@ -12456,13 +12947,11 @@ static void ggml_compute_forward_alibi_f16( return; } - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - assert(n_past >= 0); - const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int ne1 = src0->ne[1]; // seq_len_without_past const int ne2 = src0->ne[2]; // n_head -> this is k @@ -12477,7 +12966,7 @@ static void ggml_compute_forward_alibi_f16( //const int nb3 = src0->nb[3]; GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; + //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; GGML_ASSERT(n_head == ne2); // add alibi to src0 (KQ_scaled) @@ -12623,8 +13112,8 @@ static void ggml_compute_forward_clamp( static void ggml_compute_forward_rope_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } @@ -12634,9 +13123,9 @@ static void ggml_compute_forward_rope_f32( // these two only relevant for xPos RoPE: float xpos_base; - bool xpos_down; + bool xpos_down; - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; @@ -12645,9 +13134,7 @@ static void ggml_compute_forward_rope_f32( memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); - assert(n_past >= 0); - - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -12677,9 +13164,11 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -12716,7 +13205,7 @@ static void ggml_compute_forward_rope_f32( const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; if (xpos_down) zeta = 1.0f / zeta; theta *= theta_scale; @@ -12761,8 +13250,8 @@ static void ggml_compute_forward_rope_f32( static void ggml_compute_forward_rope_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } @@ -12770,16 +13259,14 @@ static void ggml_compute_forward_rope_f16( float freq_base; float freq_scale; - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - assert(n_past >= 0); - - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -12809,9 +13296,11 @@ static void ggml_compute_forward_rope_f16( const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -12890,15 +13379,16 @@ static void ggml_compute_forward_rope_f16( static void ggml_compute_forward_rope( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_rope_f16(params, src0, dst); + ggml_compute_forward_rope_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_rope_f32(params, src0, dst); + ggml_compute_forward_rope_f32(params, src0, src1, dst); } break; default: { @@ -12912,6 +13402,7 @@ static void ggml_compute_forward_rope( static void ggml_compute_forward_rope_back_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -12929,7 +13420,7 @@ static void ggml_compute_forward_rope_back_f32( float xpos_base; bool xpos_down; - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx); @@ -12938,9 +13429,7 @@ static void ggml_compute_forward_rope_back_f32( memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); - assert(n_past >= 0); - - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -12966,9 +13455,11 @@ static void ggml_compute_forward_rope_back_f32( const bool is_neox = mode & 2; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -12980,7 +13471,7 @@ static void ggml_compute_forward_rope_back_f32( const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; if (xpos_down) zeta = 1.0f / zeta; theta *= theta_scale; @@ -13023,6 +13514,7 @@ static void ggml_compute_forward_rope_back_f32( static void ggml_compute_forward_rope_back_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -13033,13 +13525,11 @@ static void ggml_compute_forward_rope_back_f16( // dx = rope_back(dy, src1) // src0 is dy, src1 contains options - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - assert(n_past >= 0); - - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -13065,9 +13555,11 @@ static void ggml_compute_forward_rope_back_f16( const bool is_neox = mode & 2; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -13119,15 +13611,16 @@ static void ggml_compute_forward_rope_back_f16( static void ggml_compute_forward_rope_back( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_rope_back_f16(params, src0, dst); + ggml_compute_forward_rope_back_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_rope_back_f32(params, src0, dst); + ggml_compute_forward_rope_back_f32(params, src0, src1, dst); } break; default: { @@ -13150,7 +13643,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; @@ -13241,7 +13734,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; @@ -13353,7 +13846,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; @@ -13444,7 +13937,7 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; @@ -13562,7 +14055,7 @@ static void ggml_compute_forward_conv_1d( ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst); } else { GGML_ASSERT(false); // only stride 1 and 2 supported - }; + } } // ggml_compute_forward_conv_2d @@ -13579,7 +14072,7 @@ static void ggml_compute_forward_conv_2d_f16_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; @@ -13699,7 +14192,7 @@ static void ggml_compute_forward_conv_transpose_2d( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; @@ -13958,7 +14451,7 @@ static void ggml_compute_forward_upscale_f32( const int ith = params->ith; - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS const int scale_factor = dst->op_params[0]; @@ -14010,14 +14503,14 @@ static void ggml_compute_forward_flash_attn_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_LOCALS(int64_t, neq, q, ne); - GGML_TENSOR_LOCALS(size_t, nbq, q, nb); - GGML_TENSOR_LOCALS(int64_t, nek, k, ne); - GGML_TENSOR_LOCALS(size_t, nbk, k, nb); - GGML_TENSOR_LOCALS(int64_t, nev, v, ne); - GGML_TENSOR_LOCALS(size_t, nbv, v, nb); - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const int ith = params->ith; const int nth = params->nth; @@ -14087,10 +14580,11 @@ static void ggml_compute_forward_flash_attn_f32( S[i] = -INFINITY; } - for (int64_t ic = 0; ic < nek1; ++ic) { + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t ic = 0; ic < masked_begin; ++ic) { // k indices const int ik3 = iq3; - const int ik2 = iq2; + const int ik2 = iq2 % nek2; const int ik1 = ic; // S indices @@ -14103,20 +14597,18 @@ static void ggml_compute_forward_flash_attn_f32( } // scale - ggml_vec_scale_f32(nek1, S, scale); + ggml_vec_scale_f32(masked_begin, S, scale); - if (masked) { - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } - } + for (int64_t i = masked_begin; i < M; i++) { + S[i] = -INFINITY; } // softmax + // exclude known -INF S[..] values from max and loop + // dont forget to set their SW values to zero { float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); + ggml_vec_max_f32(masked_begin, &max, S); ggml_float sum = 0.0; { @@ -14130,10 +14622,15 @@ static void ggml_compute_forward_flash_attn_f32( ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + if (i >= masked_begin) { + break; + } float * SS = S + i; for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { + if (i + j >= masked_begin) { + break; + } else if (SS[j] == -INFINITY) { SS[j] = 0.0f; } else { #ifndef GGML_FLASH_ATTN_EXP_FP16 @@ -14158,10 +14655,10 @@ static void ggml_compute_forward_flash_attn_f32( assert(sum > 0.0); sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); + ggml_vec_scale_f32(masked_begin, S, sum); #ifndef NDEBUG - for (int i = 0; i < M; ++i) { + for (int i = 0; i < masked_begin; ++i) { assert(!isnan(S[i])); assert(!isinf(S[i])); } @@ -14174,9 +14671,13 @@ static void ggml_compute_forward_flash_attn_f32( const int i2 = iq2; const int i3 = iq3; - ggml_vec_dot_f32(nek1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f32(masked_begin, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), S); } } @@ -14192,14 +14693,14 @@ static void ggml_compute_forward_flash_attn_f16( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_LOCALS(int64_t, neq, q, ne); - GGML_TENSOR_LOCALS(size_t, nbq, q, nb); - GGML_TENSOR_LOCALS(int64_t, nek, k, ne); - GGML_TENSOR_LOCALS(size_t, nbk, k, nb); - GGML_TENSOR_LOCALS(int64_t, nev, v, ne); - GGML_TENSOR_LOCALS(size_t, nbv, v, nb); - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const int ith = params->ith; const int nth = params->nth; @@ -14273,7 +14774,7 @@ static void ggml_compute_forward_flash_attn_f16( for (int64_t ic = 0; ic < nek1; ++ic) { // k indices const int ik3 = iq3; - const int ik2 = iq2; + const int ik2 = iq2 % nek2; const int ik1 = ic; // S indices @@ -14288,7 +14789,7 @@ static void ggml_compute_forward_flash_attn_f16( for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { // k indices const int ik3 = iq3; - const int ik2 = iq2; + const int ik2 = iq2 % nek2; const int ik1 = ic; // S indices @@ -14313,6 +14814,8 @@ static void ggml_compute_forward_flash_attn_f16( } // softmax + // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. + // dont forget to set their S values to zero { float max = -INFINITY; ggml_vec_max_f32(M, &max, S); @@ -14369,6 +14872,7 @@ static void ggml_compute_forward_flash_attn_f16( S16[i] = GGML_FP32_TO_FP16(S[i]); } + // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { for (int64_t ic = 0; ic < nev1; ++ic) { // dst indices @@ -14376,9 +14880,13 @@ static void ggml_compute_forward_flash_attn_f16( const int i2 = iq2; const int i3 = iq3; - ggml_vec_dot_f16(nek1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16(nev0, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), S16); } } else { @@ -14388,9 +14896,13 @@ static void ggml_compute_forward_flash_attn_f16( const int i2 = iq2; const int i3 = iq3; - ggml_vec_dot_f16_unroll(nek1, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + ggml_vec_dot_f16_unroll(nev0, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), S16); } } @@ -14433,18 +14945,18 @@ static void ggml_compute_forward_flash_ff_f16( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_LOCALS(int64_t, nea, a, ne); - GGML_TENSOR_LOCALS(size_t, nba, a, nb); - GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne); - GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb); - GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne); - GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb); - GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne); - GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb); - GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne); - GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb); - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + GGML_TENSOR_LOCALS(int64_t, nea, a, ne) + GGML_TENSOR_LOCALS(size_t, nba, a, nb) + GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne) + GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb) + GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne) + GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb) + GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne) + GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb) + GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne) + GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const int ith = params->ith; const int nth = params->nth; @@ -14592,16 +15104,16 @@ static void ggml_compute_forward_flash_attn_back_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_LOCALS(int64_t, neq, q, ne); - GGML_TENSOR_LOCALS(size_t, nbq, q, nb); - GGML_TENSOR_LOCALS(int64_t, nek, k, ne); - GGML_TENSOR_LOCALS(size_t, nbk, k, nb); - GGML_TENSOR_LOCALS(int64_t, nev, v, ne); - GGML_TENSOR_LOCALS(size_t, nbv, v, nb); - GGML_TENSOR_LOCALS(int64_t, ned, d, ne); - GGML_TENSOR_LOCALS(size_t, nbd, d, nb); - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ned, d, ne) + GGML_TENSOR_LOCALS(size_t, nbd, d, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const int ith = params->ith; const int nth = params->nth; @@ -14649,10 +15161,37 @@ static void ggml_compute_forward_flash_attn_back_f32( return; } - // parallelize by q rows using ggml_vec_dot_f32 + const int64_t elem_q = ggml_nelements(q); + const int64_t elem_k = ggml_nelements(k); - // total rows in q - const int nr = neq2*neq3; + enum ggml_type result_type = dst->type; + GGML_ASSERT(ggml_blck_size(result_type) == 1); + const size_t tsize = ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); + const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); + + void * grad_q = (char *) dst->data; + void * grad_k = (char *) dst->data + offs_k; + void * grad_v = (char *) dst->data + offs_v; + + const size_t nbgq1 = nb0*neq0; + const size_t nbgq2 = nb0*neq0*neq1; + const size_t nbgq3 = nb0*neq0*neq1*neq2; + + const size_t nbgk1 = nb0*nek0; + const size_t nbgk2 = nb0*nek0*nek1; + const size_t nbgk3 = nb0*nek0*nek1*neq2; + + const size_t nbgv1 = nb0*nev0; + const size_t nbgv2 = nb0*nev0*nev1; + const size_t nbgv3 = nb0*nev0*nev1*neq2; + + // parallelize by k rows using ggml_vec_dot_f32 + + // total rows in k + const int nr = nek2*nek3; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -14665,268 +15204,243 @@ static void ggml_compute_forward_flash_attn_back_f32( //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + // how often k2 (and v2) is repeated in q2 + int nrep = neq2/nek2; + for (int ir = ir0; ir < ir1; ++ir) { // q indices - const int iq3 = ir/(neq2); - const int iq2 = ir - iq3*neq2; - for ( int iq1 = 0; iq1 < neq1; ++iq1) { + const int ik3 = ir/(nek2); + const int ik2 = ir - ik3*nek2; + const int iq3 = ik3; + const int id3 = ik3; + const int iv3 = ik3; + const int iv2 = ik2; - // not sure about CACHE_LINE_SIZE_F32.. - // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? - float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); - float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); + for (int irep = 0; irep < nrep; ++irep) { + const int iq2 = ik2 + irep*nek2; + const int id2 = iq2; - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } + // (ik2 + irep*nek2) % nek2 == ik2 + for (int iq1 = 0; iq1 < neq1; ++iq1) { + const int id1 = iq1; - for (int64_t ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2; - const int ik1 = ic; + // not sure about CACHE_LINE_SIZE_F32.. + // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? + float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); + float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); - // S indices - const int i1 = ik1; + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } - ggml_vec_dot_f32(neq0, - S + i1, - (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t ic = 0; ic < masked_begin; ++ic) { + // k indices + const int ik1 = ic; - // scale - ggml_vec_scale_f32(nek1, S, scale); + // S indices + const int i1 = ik1; - if (masked) { - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } + ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); } - } - // softmax - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); + // scale + ggml_vec_scale_f32(masked_begin, S, scale); - ggml_float sum = 0.0; + for (int64_t i = masked_begin; i < M; i++) { + S[i] = -INFINITY; + } + + // softmax + // exclude known -INF S[..] values from max and loop + // dont forget to set their SM values to zero { + float max = -INFINITY; + ggml_vec_max_f32(masked_begin, &max, S); + + ggml_float sum = 0.0; + { #ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(SM, 1, &max, SM, 1, Mup); - vvexpf(SM, SM, &Mup); - ggml_vec_sum_f32(Mup, &sum, SM); + max = -max; + vDSP_vsadd(SM, 1, &max, SM, 1, Mup); + vvexpf(SM, SM, &Mup); + ggml_vec_sum_f32(Mup, &sum, SM); #else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; + uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); + ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - float * SR = S + i; - float * SW = SM + i; - - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SR[j] == -INFINITY) { - SW[j] = 0.0f; - } else { + for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + if (i >= masked_begin) { + break; + } + float * SR = S + i; + float * SW = SM + i; + + for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { + if (i + j >= masked_begin) { + break; + } else if (SR[j] == -INFINITY) { + SW[j] = 0.0f; + } else { #ifndef GGML_FLASH_ATTN_EXP_FP16 - const float val = expf(SR[j] - max); + const float val = expf(SR[j] - max); #else - ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); #endif - sump[j] += (ggml_float)val; - SW[j] = val; + sump[j] += (ggml_float)val; + SW[j] = val; + } } } - } - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } + for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } #endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(M, SM, sum); - - } - - // step-by-step explanation - { - // forward-process shape grads from backward process - // parallel_for iq2,iq3: - // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur] - // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] - // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur] - // for iq1: - // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur - // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur - // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 - // S0 = -Inf [D,1,1,1] - // ~S1[i] = dot(kcur[:D,i], qcur) - // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale - // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) - // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur - // ~S5[i] = dot(vcur[:,i], S4) - // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3] - // ~dst[i,iq1,iq2,iq3] = S5[i] ^ - // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3] - // dst backward-/ grad[dst] = d - // - // output gradients with their dependencies: - // - // grad[kcur] = grad[S1].T @ qcur - // grad[S1] = diag_mask_zero(grad[S3], P) * scale - // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // grad[S4] = grad[S5] @ vcur - // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur - // grad[qcur] = grad[S1] @ kcur - // grad[vcur] = grad[S5].T @ S4 - // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4 - // - // in post-order: - // - // S1 = qcur @ kcur.T - // S2 = S1 * scale - // S3 = diag_mask_inf(S2, P) - // S4 = softmax(S3) - // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur - // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // grad[S1] = diag_mask_zero(grad[S3], P) * scale - // grad[qcur] = grad[S1] @ kcur - // grad[kcur] = grad[S1].T @ qcur - // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4 - // - // using less variables (SM=S4): - // - // S = diag_mask_inf(qcur @ kcur.T * scale, P) - // SM = softmax(S) - // S = d[:D,iq1,iq2,iq3] @ vcur - // dot_SM_gradSM = dot(SM, S) - // S = SM * (S - dot(SM, S)) - // S = diag_mask_zero(S, P) * scale - // - // grad[q][:D,iq1,iq2,iq3] += S @ kcur - // grad[k][:D,:M,iq2,iq3] += S.T @ qcur - // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM - } - - // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur - // S = d[:D,iq1,iq2,iq3] @ vcur - // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3] - ggml_vec_set_f32(M, S, 0); - for (int64_t ic = 0; ic < D; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + } - ggml_vec_mad_f32(M, - S, - (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), - *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3))); - } + assert(sum > 0.0); - // S = SM * (S - dot(SM, S)) - float dot_SM_gradSM = 0; - ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S); - ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); - ggml_vec_mul_f32 (M, S, S, SM); + sum = 1.0/sum; + ggml_vec_scale_f32(masked_begin, SM, sum); - // S = diag_mask_zero(S, P) * scale - if (masked) { - // for (int64_t i = P + iq1 + 1; i < M; i++) { - // S[i] = 0; - // } - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = 0; - } } - } - ggml_vec_scale_f32(M, S, scale); - - void * grad_q = (char *) dst->data; - void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3; - void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3; - - const size_t nbgq1 = nb0*neq0; - const size_t nbgq2 = nb0*neq0*neq1; - const size_t nbgq3 = nb0*neq0*neq1*neq2; - - const size_t nbgk1 = nb0*nek0; - const size_t nbgk2 = nb0*nek0*nek1; - const size_t nbgk3 = nb0*nek0*nek1*neq2; - - const size_t nbgv1 = nb0*nev0; - const size_t nbgv2 = nb0*nev0*nev1; - const size_t nbgv3 = nb0*nev0*nev1*neq2; - - // S shape [M,1] - // SM shape [M,1] - // kcur shape [D,M] - // qcur shape [D,1] - // vcur shape [M,D] - // - // grad[q][:D,iq1,iq2,iq3] += S @ kcur - // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] - // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic] - // - //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T) - //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T) - for (int64_t ic = 0; ic < M; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - ggml_vec_mad_f32(D, - (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)), - (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)), - S[ic]); - } + // step-by-step explanation + { + // forward-process shape grads from backward process + // parallel_for ik2,ik3: + // for irep: + // iq2 = ik2 + irep*nek2 + // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur] + // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] + // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur] + // for iq1: + // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur + // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur + // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 + // S0 = -Inf [D,1,1,1] + // ~S1[i] = dot(kcur[:D,i], qcur) + // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale + // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) + // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur + // ~S5[i] = dot(vcur[:,i], S4) + // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3] + // ~dst[i,iq1,iq2,iq3] = S5[i] ^ + // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3] + // dst backward-/ grad[dst] = d + // + // output gradients with their dependencies: + // + // grad[kcur] = grad[S1].T @ qcur + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S4] = grad[S5] @ vcur + // grad[S4] = d[:D,id1,id2,id3] @ vcur + // grad[qcur] = grad[S1] @ kcur + // grad[vcur] = grad[S5].T @ S4 + // grad[vcur] = d[:D,id1,id2,id3].T @ S4 + // + // in post-order: + // + // S1 = qcur @ kcur.T + // S2 = S1 * scale + // S3 = diag_mask_inf(S2, P) + // S4 = softmax(S3) + // grad[S4] = d[:D,id1,id2,id3] @ vcur + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[qcur] = grad[S1] @ kcur + // grad[kcur] = grad[S1].T @ qcur + // grad[vcur] = d[:D,id1,id2,id3].T @ S4 + // + // using less variables (SM=S4): + // + // S = diag_mask_inf(qcur @ kcur.T * scale, P) + // SM = softmax(S) + // S = d[:D,iq1,iq2,iq3] @ vcur + // dot_SM_gradSM = dot(SM, S) + // S = SM * (S - dot(SM, S)) + // S = diag_mask_zero(S, P) * scale + // + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[k][:D,:M,ik2,ik3] += S.T @ qcur + // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM + } - // grad[k][:D,:M,iq2,iq3] += S.T @ qcur - // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] - // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] - for (int64_t ic = 0; ic < M; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] + // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] + // for ic: + // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3] + // exclude known future zero S[..] values from operation + ggml_vec_set_f32(masked_begin, S, 0); + for (int64_t ic = 0; ic < D; ++ic) { + ggml_vec_mad_f32(masked_begin, + S, + (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); + } - // ggml_vec_set_f32(D, - // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)), - // 0); - ggml_vec_mad_f32(D, - (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)), - (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)), - S[ic]); - } + // S = SM * (S - dot(SM, S)) + float dot_SM_gradSM = 0; + ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S); + ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); + ggml_vec_mul_f32 (masked_begin, S, S, SM); + + // S = diag_mask_zero(S, P) * scale + // already done by above ggml_vec_set_f32 + + // exclude known zero S[..] values from operation + ggml_vec_scale_f32(masked_begin, S, scale); + + // S shape [M,1] + // SM shape [M,1] + // kcur shape [D,M] + // qcur shape [D,1] + // vcur shape [M,D] + + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] + // for ic: + // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3] + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { + ggml_vec_mad_f32(D, + (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)), + (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)), + S[ic]); + } - // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM - // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M] - // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M] - for (int64_t ic = 0; ic < D; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // grad[k][:D,:M,iq2,iq3] += S.T @ qcur + // for ic: + // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] + // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { + ggml_vec_mad_f32(D, + (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), + S[ic]); + } - // ggml_vec_set_f32(M, - // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)), - // 0); - ggml_vec_mad_f32(M, - (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)), - SM, - *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3))); + // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM + // for ic: + // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M] + // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M] + // exclude known zero SM[..] values from mad + for (int64_t ic = 0; ic < D; ++ic) { + ggml_vec_mad_f32(masked_begin, + (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)), + SM, + *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); + } } } } @@ -14962,8 +15476,8 @@ static void ggml_compute_forward_win_part_f32( return; } - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; @@ -15024,8 +15538,8 @@ static void ggml_compute_forward_win_unpart_f32( return; } - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) const int32_t w = ((const int32_t *)(dst->op_params))[0]; @@ -15142,7 +15656,7 @@ static void ggml_compute_forward_get_rel_pos_f16( // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS const int64_t w = ne1; @@ -15840,7 +16354,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_GET_ROWS_BACK: { - ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_DIAG: { @@ -15864,11 +16378,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_ROPE: { - ggml_compute_forward_rope(params, tensor->src[0], tensor); + ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_ROPE_BACK: { - ggml_compute_forward_rope_back(params, tensor->src[0], tensor); + ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_ALIBI: { @@ -16013,7 +16527,218 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm //////////////////////////////////////////////////////////////////////////////// -static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) { +static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small"); + +static size_t hash(void * p) { + return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE; +} + +static size_t hash_find(void * hash_table[], void * p) { + size_t h = hash(p); + + // linear probing + size_t i = h; + while (hash_table[i] != NULL && hash_table[i] != p) { + i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE; + if (i == h) { + // visited all hash table entries -> not found + return GGML_GRAPH_HASHTABLE_SIZE; + } + } + return i; +} + +static bool hash_insert(void * hash_table[], void * p) { + size_t i = hash_find(hash_table, p); + + GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full + + if (hash_table[i] == p) { + return true; + } + + // insert + GGML_ASSERT(hash_table[i] == NULL); + hash_table[i] = p; + return false; +} + +static bool hash_contains(void * hash_table[], void * p) { + size_t i = hash_find(hash_table, p); + return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p); +} + +struct hash_map { + void * keys[GGML_GRAPH_HASHTABLE_SIZE]; + void * vals[GGML_GRAPH_HASHTABLE_SIZE]; +}; + +static struct hash_map * new_hash_map(void) { + struct hash_map * result = malloc(sizeof(struct hash_map)); + for (int i=0; ikeys[i] = NULL; + result->vals[i] = NULL; + } + return result; +} + +static void free_hash_map(struct hash_map * map) { + free(map); +} + +// gradient checkpointing + +static struct ggml_tensor * ggml_recompute_graph_node( + struct ggml_context * ctx, + struct ggml_cgraph * graph, + struct hash_map * replacements, + struct ggml_tensor * node) { + + if (node == NULL) { + return NULL; + } + + if (node->is_param) { + return node; + } + + if (!hash_contains(graph->visited_hash_table, node)) { + return node; + } + + int count_children = 0; + for (int k = 0; k < GGML_MAX_SRC; ++k) { + if (node->src[k]) { + ++count_children; + } + } + + if (count_children == 0) { + return node; + } + + size_t i = hash_find(replacements->keys, node); + GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full + if (replacements->keys[i] == node) { + return (struct ggml_tensor *) replacements->vals[i]; + } + + struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne); + + // insert clone into replacements + GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite + replacements->keys[i] = node; + replacements->vals[i] = clone; + + clone->op = node->op; + clone->grad = node->grad; + clone->is_param = node->is_param; + clone->extra = node->extra; + for (int k = 0; k < GGML_MAX_DIMS; ++k) { + clone->nb[k] = node->nb[k]; + } + for (int k = 0; k < GGML_MAX_SRC; ++k) { + clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); + } + if (node->view_src != NULL) { + clone->data = (node->view_src->data == NULL) + ? NULL // view_src not yet allocated + : (char *) node->view_src->data // view_src already allocated + + node->view_offs; + clone->view_src = node->view_src; + clone->view_offs = node->view_offs; + } + + GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t))); + GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME); + memcpy(clone->op_params, node->op_params, sizeof(node->op_params)); + ggml_format_name(clone, "%s (clone)", ggml_get_name(node)); + + return clone; +} + +void ggml_build_backward_gradient_checkpointing( + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + struct ggml_cgraph * gb_tmp, + struct ggml_tensor * * checkpoints, + int n_checkpoints) { + *gb_tmp = *gf; + ggml_build_backward_expand(ctx, gf, gb_tmp, true); + + if (n_checkpoints <= 0) { + *gb = *gb_tmp; + return; + } + + struct hash_map * replacements = new_hash_map(); + + // insert checkpoints in replacements + for (int i = 0; i < n_checkpoints; ++i) { + size_t k = hash_find(replacements->keys, checkpoints[i]); + GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full + GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite + replacements->keys[k] = checkpoints[i]; + replacements->vals[k] = checkpoints[i]; + } + + *gb = *gf; + // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes], + // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]), + // by recomputing them from checkpoints + for (int i = gf->n_nodes; in_nodes; ++i) { + struct ggml_tensor * node = gb_tmp->nodes[i]; + for (int k = 0; k < GGML_MAX_SRC; ++k) { + // insert new tensors recomputing src, reusing already made replacements, + // remember replacements: remember new tensors with mapping from corresponding gf nodes + // recurse for input tensors, + // unless (i.e. terminating when) input tensors are replacments (like checkpoints) + node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); + } + // insert rewritten backward node with replacements made into resulting backward graph gb + ggml_build_forward_expand(gb, node); + } + + free_hash_map(replacements); +} + +// functions to change gradients considering the case that input a might be initial gradient with zero value + +static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) { + if (hash_contains(zero_table, a)) { + return b; + } else { + return ggml_add_impl(ctx, a, b, false); + } +} + +static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) { + if (hash_contains(zero_table, a)) { + struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0)); + return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); + } else { + return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); + } +} + +static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) { + if (hash_contains(zero_table, a)) { + return ggml_repeat(ctx, b, a); + } else { + return ggml_add1_impl(ctx, a, b, false); + } +} + +static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) { + if (hash_contains(zero_table, a)) { + return ggml_neg(ctx, b); + } else { + return ggml_sub_impl(ctx, a, b, false); + } +} + +static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) { struct ggml_tensor * src0 = tensor->src[0]; struct ggml_tensor * src1 = tensor->src[1]; @@ -16021,34 +16746,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_DUP: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } } break; case GGML_OP_ADD: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { - src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace); + src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); } } break; case GGML_OP_ADD1: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { - src1->grad = ggml_add_impl(ctx, + src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean - inplace); + zero_table); } } break; case GGML_OP_ACC: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { const size_t nb1 = ((int32_t *) tensor->op_params)[0]; @@ -16065,117 +16790,117 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor nb1, nb2, nb3, offset); src1->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src1->grad, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1->grad), - inplace); + zero_table); } } break; case GGML_OP_SUB: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { - src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace); + src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table); } } break; case GGML_OP_MUL: { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_mul(ctx, src1, tensor->grad), - inplace); + zero_table); } if (src1->grad) { src1->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src1->grad, ggml_mul(ctx, src0, tensor->grad), - inplace); + zero_table); } } break; case GGML_OP_DIV: { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_div(ctx, tensor->grad, src1), - inplace); + zero_table); } if (src1->grad) { src1->grad = - ggml_sub_impl(ctx, + ggml_sub_or_set(ctx, src1->grad, ggml_mul(ctx, tensor->grad, ggml_div(ctx, tensor, src1)), - inplace); + zero_table); } } break; case GGML_OP_SQR: { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_scale(ctx, ggml_mul(ctx, src0, tensor->grad), ggml_new_f32(ctx, 2.0f)), - inplace); + zero_table); } } break; case GGML_OP_SQRT: { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_scale(ctx, ggml_div(ctx, tensor->grad, tensor), ggml_new_f32(ctx, 0.5f)), - inplace); + zero_table); } } break; case GGML_OP_LOG: { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_div(ctx, tensor->grad, src0), - inplace); + zero_table); } } break; case GGML_OP_SUM: { if (src0->grad) { src0->grad = - ggml_add1_impl(ctx, + ggml_add1_or_set(ctx, src0->grad, tensor->grad, - inplace); + zero_table); } } break; case GGML_OP_SUM_ROWS: { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_repeat(ctx, tensor->grad, src0->grad), - inplace); + zero_table); } } break; case GGML_OP_MEAN: @@ -16187,20 +16912,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_repeat_back(ctx, tensor->grad, src0->grad), - inplace); + zero_table); } } break; case GGML_OP_REPEAT_BACK: { if (src0->grad) { // TODO: test this - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_repeat(ctx, tensor->grad, src0->grad), - inplace); + zero_table); } } break; case GGML_OP_CONCAT: @@ -16222,10 +16947,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor float eps; memcpy(&eps, tensor->op_params, sizeof(float)); - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_rms_norm_back(ctx, src0, tensor->grad, eps), - inplace); + zero_table); } } break; case GGML_OP_RMS_NORM_BACK: @@ -16249,37 +16974,49 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix // ds1 = t.T.dot(dt) - // tensor.shape [m,p] - // src0.shape [n,m] - // src1.shape [n,p] + // tensor.shape [m,p,qq,rr] + // src0.shape [n,m,q1,r1] + // src1.shape [n,p,qq,rr] // necessary for llama if (src0->grad) { + struct ggml_tensor * s1_tg = + ggml_out_prod(ctx, // [n,m,qq,rr] + src1, // [n,p,qq,rr] + tensor->grad); // [m,p,qq,rr] + const int64_t qq = s1_tg->ne[2]; + const int64_t rr = s1_tg->ne[3]; + const int64_t q1 = src0->ne[2]; + const int64_t r1 = src0->ne[3]; + const bool ne2_broadcasted = qq > q1; + const bool ne3_broadcasted = rr > r1; + if (ne2_broadcasted || ne3_broadcasted) { + // sum broadcast repetitions of s1_tg into shape of src0 + s1_tg = ggml_repeat_back(ctx, s1_tg, src0); + } src0->grad = - ggml_add_impl(ctx, - src0->grad, - ggml_out_prod(ctx, // [n,m] - src1, // [n,p] - tensor->grad), // [m,p] - inplace); + ggml_add_or_set(ctx, + src0->grad, // [n,m,q1,r1] + s1_tg, // [n,m,q1,r1] + zero_table); } if (src1->grad) { src1->grad = - ggml_add_impl(ctx, - src1->grad, - // ggml_mul_mat(ctx, // [n,p] - // ggml_cont(ctx, // [m,n] - // ggml_transpose(ctx, src0)), // [m,n] - // tensor->grad), // [m,p] + ggml_add_or_set(ctx, + src1->grad, // [n,p,qq,rr] + // ggml_mul_mat(ctx, // [n,p,qq,rr] + // ggml_cont(ctx, // [m,n,q1,r1] + // ggml_transpose(ctx, src0)), // [m,n,q1,r1] + // tensor->grad), // [m,p,qq,rr] // // when src0 is bigger than tensor->grad (this is mostly the case in llama), // // avoid transpose of src0, rather transpose smaller tensor->grad // // and then use ggml_out_prod - ggml_out_prod(ctx, // [n,p] - src0, // [n,m] - ggml_transpose(ctx, // [p,m] - tensor->grad)), // [m,p] - inplace); + ggml_out_prod(ctx, // [n,p,qq,rr] + src0, // [n,m,q1,r1] + ggml_transpose(ctx, // [p,m,qq,rr] + tensor->grad)), // [m,p,qq,rr] + zero_table); } } break; case GGML_OP_OUT_PROD: @@ -16291,17 +17028,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_scale_impl(ctx, tensor->grad, src1, false), - inplace); + zero_table); } if (src1->grad) { src1->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src1->grad, ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)), - inplace); + zero_table); } } break; case GGML_OP_SET: @@ -16328,23 +17065,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } if (src0->grad) { - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_acc_impl(ctx, tensor->grad, ggml_neg(ctx, tensor_grad_view), nb1, nb2, nb3, offset, false), - inplace); + zero_table); } if (src1->grad) { src1->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src1->grad, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1->grad), - inplace); + zero_table); } } break; case GGML_OP_CPY: @@ -16355,7 +17092,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // tensor = src0 * 1 + src1 * 0 if (src0->grad) { // dsrc0 = dtensor * 1 - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { // dsrc1 = dtensor * 0 -> noop @@ -16367,7 +17104,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor if (src0->grad) { GGML_ASSERT(ggml_is_contiguous(src0->grad)); GGML_ASSERT(ggml_is_contiguous(tensor->grad)); - src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } } break; case GGML_OP_RESHAPE: @@ -16375,9 +17112,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { src0->grad = - ggml_add_impl(ctx, src0->grad, - ggml_reshape(ctx, tensor->grad, src0->grad), - inplace); + ggml_add_or_set(ctx, src0->grad, + ggml_reshape(ctx, + ggml_is_contiguous(tensor->grad) + ? tensor->grad + : ggml_cont(ctx, tensor->grad), + src0->grad), + zero_table); } } break; case GGML_OP_VIEW: @@ -16406,7 +17147,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor nb3 = (nb3 / n0) * ng; } - src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace); + src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table); } } break; case GGML_OP_PERMUTE: @@ -16424,14 +17165,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor axes_backward[axis2] = 2; axes_backward[axis3] = 3; src0->grad = - ggml_add_impl(ctx, src0->grad, + ggml_add_or_set(ctx, src0->grad, ggml_permute(ctx, tensor->grad, axes_backward[0], axes_backward[1], axes_backward[2], axes_backward[3]), - inplace); + zero_table); } } break; case GGML_OP_TRANSPOSE: @@ -16439,9 +17180,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { src0->grad = - ggml_add_impl(ctx, src0->grad, + ggml_add_or_set(ctx, src0->grad, ggml_transpose(ctx, tensor->grad), - inplace); + zero_table); } } break; case GGML_OP_GET_ROWS: @@ -16449,9 +17190,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama (only for tokenizer) if (src0->grad) { src0->grad = - ggml_add_impl(ctx, src0->grad, + ggml_add_or_set(ctx, src0->grad, + // last ggml_get_rows_back argument src0->grad is only + // necessary to setup correct output shape ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), - inplace); + zero_table); } if (src1->grad) { // noop @@ -16471,9 +17214,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor if (src0->grad) { const int n_past = ((int32_t *) tensor->op_params)[0]; src0->grad = - ggml_add_impl(ctx, src0->grad, + ggml_add_or_set(ctx, src0->grad, ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - inplace); + zero_table); } } break; case GGML_OP_DIAG_MASK_ZERO: @@ -16482,9 +17225,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor if (src0->grad) { const int n_past = ((int32_t *) tensor->op_params)[0]; src0->grad = - ggml_add_impl(ctx, src0->grad, + ggml_add_or_set(ctx, src0->grad, ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - inplace); + zero_table); } } break; case GGML_OP_SOFT_MAX: @@ -16492,9 +17235,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { src0->grad = - ggml_add_impl(ctx, src0->grad, + ggml_add_or_set(ctx, src0->grad, ggml_soft_max_back(ctx, tensor->grad, tensor), - inplace); + zero_table); } } break; @@ -16506,7 +17249,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; + //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; const int n_ctx = ((int32_t *) tensor->op_params)[3]; @@ -16519,11 +17262,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_rope_back(ctx, tensor->grad, - n_past, + src1, n_dims, mode, n_ctx, @@ -16531,13 +17274,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor freq_scale, xpos_base, xpos_down), - inplace); + zero_table); } } break; case GGML_OP_ROPE_BACK: { if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; + //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; const int n_ctx = ((int32_t *) tensor->op_params)[3]; @@ -16550,11 +17293,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_rope_impl(ctx, tensor->grad, - n_past, + src1, n_dims, mode, n_ctx, @@ -16563,7 +17306,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor xpos_base, xpos_down, false), - inplace); + zero_table); } } break; case GGML_OP_ALIBI: @@ -16614,145 +17357,42 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor masked); } - if (src0->grad) { - struct ggml_tensor * grad_q = NULL; - const size_t nb0 = flash_grad->nb[0]; - const size_t offset = 0; - switch(src0->n_dims) { - case 2: - { - grad_q = ggml_view_2d(ctx, - flash_grad, - src0->ne[0], - src0->ne[1], - nb0*src0->ne[0], - offset); - } break; - case 3: - { - grad_q = ggml_view_3d(ctx, - flash_grad, - src0->ne[0], - src0->ne[1], - src0->ne[2], - nb0*src0->ne[0], - nb0*src0->ne[0]*src0->ne[1], - offset); - } break; - case 4: - { - grad_q = ggml_view_4d(ctx, - flash_grad, - src0->ne[0], - src0->ne[1], - src0->ne[2], - src0->ne[3], - nb0*src0->ne[0], - nb0*src0->ne[0]*src0->ne[1], - nb0*src0->ne[0]*src0->ne[1]*src0->ne[2], - offset); - } break; - } + struct ggml_tensor * src2 = tensor->src[2]; + const int64_t elem_q = ggml_nelements(src0); + const int64_t elem_k = ggml_nelements(src1); + const int64_t elem_v = ggml_nelements(src2); + + enum ggml_type result_type = flash_grad->type; + GGML_ASSERT(ggml_blck_size(result_type) == 1); + const size_t tsize = ggml_type_size(result_type); - src0->grad = ggml_add_impl(ctx, + const size_t offs_q = 0; + const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); + const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); + + if (src0->grad) { + struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q); + struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0); + src0->grad = ggml_add_or_set(ctx, src0->grad, grad_q, - inplace); + zero_table); } - if (src1->grad) { - struct ggml_tensor * grad_k = NULL; - const size_t nb0 = flash_grad->nb[0]; - const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]; - switch(src1->n_dims) { - case 2: - { - grad_k = ggml_view_2d(ctx, - flash_grad, - src1->ne[0], - src1->ne[1], - nb0*src1->ne[0], - offset); - } break; - case 3: - { - grad_k = ggml_view_3d(ctx, - flash_grad, - src1->ne[0], - src1->ne[1], - src1->ne[2], - nb0*src1->ne[0], - nb0*src1->ne[0]*src1->ne[1], - offset); - } break; - case 4: - { - grad_k = ggml_view_4d(ctx, - flash_grad, - src1->ne[0], - src1->ne[1], - src1->ne[2], - src1->ne[3], - nb0*src1->ne[0], - nb0*src1->ne[0]*src1->ne[1], - nb0*src1->ne[0]*src1->ne[1]*src1->ne[2], - offset); - } break; - } - - src1->grad = ggml_add_impl(ctx, + struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k); + struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1); + src1->grad = ggml_add_or_set(ctx, src1->grad, grad_k, - inplace); + zero_table); } - - struct ggml_tensor * opt0 = tensor->src[2]; - - if (opt0->grad) { - struct ggml_tensor * grad_v = NULL; - const size_t nb0 = flash_grad->nb[0]; - const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3] - + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3]; - switch(opt0->n_dims) { - case 2: - { - grad_v = ggml_view_2d(ctx, - flash_grad, - opt0->ne[0], - opt0->ne[1], - nb0*opt0->ne[0], - offset); - } break; - case 3: - { - grad_v = ggml_view_3d(ctx, - flash_grad, - opt0->ne[0], - opt0->ne[1], - opt0->ne[2], - nb0*opt0->ne[0], - nb0*opt0->ne[0]*opt0->ne[1], - offset); - } break; - case 4: - { - grad_v = ggml_view_4d(ctx, - flash_grad, - opt0->ne[0], - opt0->ne[1], - opt0->ne[2], - opt0->ne[3], - nb0*opt0->ne[0], - nb0*opt0->ne[0]*opt0->ne[1], - nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2], - offset); - } break; - } - - opt0->grad = ggml_add_impl(ctx, - opt0->grad, + if (src2->grad) { + struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v); + struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2); + src2->grad = ggml_add_or_set(ctx, + src2->grad, grad_v, - inplace); + zero_table); } } break; case GGML_OP_FLASH_FF: @@ -16772,12 +17412,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { if (src0->grad) { src0->grad = - ggml_add_impl(ctx, + ggml_add_or_set(ctx, src0->grad, ggml_mul(ctx, ggml_sgn(ctx, src0), tensor->grad), - inplace); + zero_table); } } break; case GGML_UNARY_OP_SGN: @@ -16789,7 +17429,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_UNARY_OP_NEG: { if (src0->grad) { - src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table); } } break; case GGML_UNARY_OP_STEP: @@ -16809,12 +17449,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_UNARY_OP_RELU: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_mul(ctx, ggml_step(ctx, src0), tensor->grad), - inplace); + zero_table); } } break; case GGML_UNARY_OP_GELU: @@ -16829,10 +17469,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_silu_back(ctx, src0, tensor->grad), - inplace); + zero_table); } } break; default: @@ -16855,13 +17495,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_CROSS_ENTROPY_LOSS: { if (src0->grad) { - src0->grad = ggml_add_impl(ctx, + src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_cross_entropy_loss_back(ctx, src0, src1, tensor->grad), - inplace); + zero_table); } } break; case GGML_OP_CROSS_ENTROPY_LOSS_BACK: @@ -16877,34 +17517,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); } break; } -} - -static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small"); - -static size_t hash(void * p) { - return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE; -} -static bool hash_insert(void * hash_table[], void * p) { - size_t h = hash(p); - - // linear probing - size_t i = h; - while (hash_table[i] != NULL && hash_table[i] != p) { - i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE; - if (i == h) { - // hash table is full - GGML_ASSERT(false); + for (int i = 0; i < GGML_MAX_SRC; ++i) { + if (tensor->src[i] && tensor->src[i]->grad) { + GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad)); } } - - if (hash_table[i] == p) { - return true; - } - - // insert - hash_table[i] = p; - return false; } static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { @@ -16922,8 +17540,12 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * } for (int i = 0; i < GGML_MAX_SRC; ++i) { - if (node->src[i]) { - ggml_visit_parents(cgraph, node->src[i]); + const int k = + (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : + (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) : + /* unknown order, just fall back to using i*/ i; + if (node->src[k]) { + ggml_visit_parents(cgraph, node->src[k]); } } @@ -16982,6 +17604,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { /*.grads =*/ { NULL }, /*.leafs =*/ { NULL }, /*.hash_table =*/ { NULL }, + /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, /*.perf_runs =*/ 0, /*.perf_cycles =*/ 0, /*.perf_time_us =*/ 0, @@ -17007,12 +17630,22 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * } } + // remember original gradients which start with zero values + void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE); + memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE); + for (int i = 0; i < gf->n_nodes; i++) { + if (gf->grads[i]) { + hash_insert(zero_table, gf->grads[i]); + } + } + for (int i = gf->n_nodes - 1; i >= 0; i--) { struct ggml_tensor * node = gf->nodes[i]; - // because we detached the grad nodes from the original graph, we can afford inplace operations + // inplace operations to add gradients are not created by ggml_compute_backward + // use allocator to automatically make inplace operations if (node->grad) { - ggml_compute_backward(ctx, node, keep); + ggml_compute_backward(ctx, node, zero_table); } } @@ -17024,6 +17657,8 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * ggml_build_forward_expand(gb, node->grad); } } + + free(zero_table); } struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) { @@ -17043,6 +17678,7 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) { /*.grads =*/ { NULL }, /*.leafs =*/ { NULL }, /*.hash_table =*/ { NULL }, + /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, /*.perf_runs =*/ 0, /*.perf_cycles =*/ 0, /*.perf_time_us =*/ 0, @@ -17433,7 +18069,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { } break; case GGML_OP_CONCAT: case GGML_OP_MUL_MAT: - case GGML_OP_OUT_PROD: { n_tasks = n_threads; @@ -17475,6 +18110,18 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { cur = 0; } + work_size = MAX(work_size, cur); + } break; + case GGML_OP_OUT_PROD: + { + n_tasks = n_threads; + + size_t cur = 0; + + if (ggml_is_quantized(node->src[0]->type)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; + } + work_size = MAX(work_size, cur); } break; case GGML_OP_SCALE: @@ -18568,7 +19215,7 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * } static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) { - int i = 0; + int64_t i = 0; for (int p = 0; p < np; ++p) { const int64_t ne = ggml_nelements(ps[p]) ; // TODO: add function to get all elements at once @@ -18578,6 +19225,17 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g } } +static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) { + int64_t i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int64_t j = 0; j < ne; ++j) { + g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale; + } + } +} + // // ADAM // @@ -18626,26 +19284,43 @@ static enum ggml_opt_result ggml_opt_adam( const float eps = params.adam.eps; const float gclip = params.adam.gclip; const int decay_min_ndim = params.adam.decay_min_ndim; + const int n_accum = MAX(1, params.n_gradient_accumulation); + const float accum_norm = 1.0f / (float) n_accum; + float * g = opt->adam.g->data; // gradients float * m = opt->adam.m->data; // first moment float * v = opt->adam.v->data; // second moment float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values - if (callback) { - callback(callback_data, &sched); - } - - // compute the function value - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; - ggml_graph_compute(gb, &cplan); - opt->adam.fx_prev = ggml_get_f32_1d(f, 0); + bool cancel = false; + + // compute the function value + float fx = 0; + ggml_set_zero(opt->adam.g); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + callback(callback_data, accum_step, &sched, &cancel); + if (cancel) { + break; + } + } + // ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(gb, &cplan); + ggml_opt_acc_grad(np, ps, g, accum_norm); + fx += ggml_get_f32_1d(f, 0); + } + if (cancel) { + return GGML_OPT_DID_NOT_CONVERGE; + } + fx *= accum_norm; + + opt->adam.fx_prev = fx; opt->adam.fx_best = opt->adam.fx_prev; if (pf) { pf[opt->iter % params.past] = opt->adam.fx_prev; @@ -18668,6 +19343,9 @@ static enum ggml_opt_result ggml_opt_adam( // run the optimizer for (int t = 0; t < params.adam.n_iter; ++t) { + if (cancel) { + break; + } opt->iter = iter0 + t + 1; GGML_PRINT_DEBUG ("=== iter %d ===\n", t); @@ -18690,12 +19368,8 @@ static enum ggml_opt_result ggml_opt_adam( if (gclip > 0.0f) { // gradient clipping ggml_float sum = 0.0; - for (int p = 0; p < np; ++p) { - const int64_t ne = ggml_nelements(ps[p]); - for (int64_t j = 0; j < ne; ++j) { - float g = ggml_get_f32_1d(ps[p]->grad, j); - sum += (ggml_float)(g*g); - } + for (int64_t i = 0; i < nx; ++i) { + sum += (ggml_float)(g[i]*g[i]); } ggml_float norm = sqrt(sum); if (norm > (ggml_float) gclip) { @@ -18709,10 +19383,10 @@ static enum ggml_opt_result ggml_opt_adam( const int64_t ne = ggml_nelements(ps[p]); const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched; for (int64_t j = 0; j < ne; ++j) { - float x = ggml_get_f32_1d(ps[p], j); - float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm; - m[i] = m[i]*beta1 + g*(1.0f - beta1); - v[i] = v[i]*beta2 + g*g*(1.0f - beta2); + float x = ggml_get_f32_1d(ps[p], j); + float g_ = g[i]*gnorm; + m[i] = m[i]*beta1 + g_*(1.0f - beta1); + v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2); float mh = m[i]*beta1h; float vh = v[i]*beta2h; vh = sqrtf(vh) + eps; @@ -18723,16 +19397,26 @@ static enum ggml_opt_result ggml_opt_adam( } } - if (callback) { - callback(callback_data, &sched); + fx = 0; + ggml_set_zero(opt->adam.g); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + callback(callback_data, accum_step, &sched, &cancel); + if (cancel) { + break; + } + } + // ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(gb, &cplan); + ggml_opt_acc_grad(np, ps, g, accum_norm); + fx += ggml_get_f32_1d(f, 0); } + if (cancel) { + break; + } + fx *= accum_norm; - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - - ggml_graph_compute(gb, &cplan); - - const float fx = ggml_get_f32_1d(f, 0); opt->loss_after = fx; @@ -18812,11 +19496,11 @@ static enum ggml_opt_result linesearch_backtracking( float * step, const float * xp, struct ggml_tensor * f, - struct ggml_cgraph * gf, struct ggml_cgraph * gb, struct ggml_cplan * cplan, const int np, struct ggml_tensor * ps[], + bool * cancel, ggml_opt_callback callback, void * callback_data) { int count = 0; @@ -18830,6 +19514,9 @@ static enum ggml_opt_result linesearch_backtracking( const float dec = 0.5f; const float inc = 2.1f; + const int n_accum = MAX(1, params->n_gradient_accumulation); + const float accum_norm = 1.0f / (float) n_accum; + if (*step <= 0.f) { return GGML_LINESEARCH_INVALID_PARAMETERS; } @@ -18846,13 +19533,7 @@ static enum ggml_opt_result linesearch_backtracking( finit = *fx; dgtest = params->lbfgs.ftol*dginit; - while (true) { - if (callback) { - // LBFG-S does not support learning rate -> ignore learning schedule - float sched = 0; - callback(callback_data, &sched); - } - + while (!*cancel) { ggml_vec_cpy_f32(nx, x, xp); ggml_vec_mad_f32(nx, x, d, *step); @@ -18860,14 +19541,28 @@ static enum ggml_opt_result linesearch_backtracking( { ggml_opt_set_params(np, ps, x); - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - - ggml_graph_compute(gb, cplan); - - ggml_opt_get_grad(np, ps, g); + *fx = 0; + memset(g, 0, sizeof(float)*nx); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + // LBFG-S does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, accum_step, &sched, cancel); + if (*cancel) { + break; + } + } + // ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(gb, cplan); + ggml_opt_acc_grad(np, ps, g, accum_norm); + *fx += ggml_get_f32_1d(f, 0); + } + if (*cancel) { + break; + } + *fx *= accum_norm; - *fx = ggml_get_f32_1d(f, 0); } ++count; @@ -18913,7 +19608,7 @@ static enum ggml_opt_result linesearch_backtracking( (*step) *= width; } - return GGML_LINESEARCH_FAIL; + GGML_UNREACHABLE(); } static enum ggml_opt_result ggml_opt_lbfgs( @@ -18968,6 +19663,9 @@ static enum ggml_opt_result ggml_opt_lbfgs( float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values + const int n_accum = MAX(1, params.n_gradient_accumulation); + const float accum_norm = 1.0f / (float) n_accum; + float fx = 0.0f; // cost function value float xnorm = 0.0f; // ||x|| float gnorm = 0.0f; // ||g|| @@ -18981,24 +19679,33 @@ static enum ggml_opt_result ggml_opt_lbfgs( float * lm_s = opt->lbfgs.lms->data; float * lm_y = opt->lbfgs.lmy->data; - if (callback) { - // LBFG-S does not support learning rate -> ignore learning schedule - float sched = 0; - callback(callback_data, &sched); - } + bool cancel = false; // evaluate the function value and its gradient { ggml_opt_set_params(np, ps, x); - ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - - ggml_graph_compute(gb, &cplan); - - ggml_opt_get_grad(np, ps, g); - - fx = ggml_get_f32_1d(f, 0); + fx = 0; + memset(g, 0, sizeof(float)*nx); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + // LBFG-S does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, accum_step, &sched, &cancel); + if (cancel) { + break; + } + } + // ggml_graph_reset (gf); + ggml_set_f32 (f->grad, 1.0f); + ggml_graph_compute(gb, &cplan); + ggml_opt_acc_grad(np, ps, g, accum_norm); + fx += ggml_get_f32_1d(f, 0); + } + if (cancel) { + return GGML_OPT_DID_NOT_CONVERGE; + } + fx *= accum_norm; opt->loss_before = fx; opt->loss_after = fx; @@ -19056,7 +19763,10 @@ static enum ggml_opt_result ggml_opt_lbfgs( ggml_vec_cpy_f32(nx, xp, x); ggml_vec_cpy_f32(nx, gp, g); - ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data); + ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data); + if (!cancel) { + break; + } if (ls < 0) { // linesearch failed - go back to the previous point and return @@ -19165,7 +19875,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( step[0] = 1.0; } - return GGML_OPT_DID_NOT_CONVERGE; + GGML_UNREACHABLE(); } struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { @@ -19185,6 +19895,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { .print_forward_graph = true, .print_backward_graph = true, + .n_gradient_accumulation = 1, + .adam = { .n_iter = 10000, .sched = 1.000f, @@ -19213,6 +19925,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { .print_forward_graph = true, .print_backward_graph = true, + .n_gradient_accumulation = 1, + .lbfgs = { .m = 6, .n_iter = 100, @@ -19243,13 +19957,32 @@ GGML_API void ggml_opt_init( opt->iter = 0; opt->nx = nx; opt->just_initialized = true; + if (opt->ctx == NULL) { + struct ggml_init_params ctx_opt_params; + if (opt->params.type == GGML_OPT_ADAM) { + ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3; + if (opt->params.past > 0) { + ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past; + } + } else if (opt->params.type == GGML_OPT_LBFGS) { + ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2); + if (opt->params.past > 0) { + ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past; + } + } + ctx_opt_params.mem_buffer = NULL; + ctx_opt_params.no_alloc = false; + + opt->ctx = ggml_init(ctx_opt_params); + } switch (opt->params.type) { case GGML_OPT_ADAM: { - opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); opt->adam.pf = params.past > 0 - ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past) + ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past) : NULL; ggml_set_zero(opt->adam.m); ggml_set_zero(opt->adam.v); @@ -19259,18 +19992,18 @@ GGML_API void ggml_opt_init( } break; case GGML_OPT_LBFGS: { - opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); + opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); + opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); opt->lbfgs.pf = params.past > 0 - ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past) + ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past) : NULL; - opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m); - opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m); - opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m); - opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m); + opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m); + opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m); ggml_set_zero(opt->lbfgs.x); ggml_set_zero(opt->lbfgs.xp); ggml_set_zero(opt->lbfgs.g); @@ -19876,10 +20609,10 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p } break; case GGUF_TYPE_ARRAY: case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; - }; + } } break; case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); - }; + } if (!ok) { break; @@ -20155,78 +20888,94 @@ int gguf_find_key(const struct gguf_context * ctx, const char * key) { return keyfound; } -const char * gguf_get_key(const struct gguf_context * ctx, int i) { - return ctx->kv[i].key.data; +const char * gguf_get_key(const struct gguf_context * ctx, int key_id) { + return ctx->kv[key_id].key.data; } -enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int i) { - return ctx->kv[i].type; +enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) { + return ctx->kv[key_id].type; } -enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.arr.type; +enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.type; } -const void * gguf_get_arr_data(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.arr.data; +const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.data; } const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); struct gguf_kv * kv = &ctx->kv[key_id]; struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; return str->data; } -int gguf_get_arr_n(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.arr.n; +int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.n; } -uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.uint8; +uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8); + return ctx->kv[key_id].value.uint8; } -int8_t gguf_get_val_i8(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.int8; +int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8); + return ctx->kv[key_id].value.int8; } -uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.uint16; +uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16); + return ctx->kv[key_id].value.uint16; } -int16_t gguf_get_val_i16(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.int16; +int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16); + return ctx->kv[key_id].value.int16; } -uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.uint32; +uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32); + return ctx->kv[key_id].value.uint32; } -int32_t gguf_get_val_i32(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.int32; +int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32); + return ctx->kv[key_id].value.int32; } -float gguf_get_val_f32(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.float32; +float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32); + return ctx->kv[key_id].value.float32; } -uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.uint64; +uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64); + return ctx->kv[key_id].value.uint64; } -int64_t gguf_get_val_i64(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.int64; +int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64); + return ctx->kv[key_id].value.int64; } -double gguf_get_val_f64(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.float64; +double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64); + return ctx->kv[key_id].value.float64; } -bool gguf_get_val_bool(const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.bool_; +bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL); + return ctx->kv[key_id].value.bool_; } -const char * gguf_get_val_str (const struct gguf_context * ctx, int i) { - return ctx->kv[i].value.str.data; +const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) { + GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING); + return ctx->kv[key_id].value.str.data; } int gguf_get_n_tensors(const struct gguf_context * ctx) { @@ -20591,10 +21340,10 @@ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * } break; case GGUF_TYPE_ARRAY: case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; - }; + } } break; case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); - }; + } } // write tensor infos diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml.h b/plugins/wasi_nn/thirdparty/ggml/ggml.h index f4545687..460857fa 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml.h +++ b/plugins/wasi_nn/thirdparty/ggml/ggml.h @@ -214,8 +214,8 @@ #define GGML_QNT_VERSION_FACTOR 1000 // do not change this #define GGML_MAX_DIMS 4 -#define GGML_MAX_NODES 4096 -#define GGML_MAX_PARAMS 256 +#define GGML_MAX_NODES 16384 +#define GGML_MAX_PARAMS 1024 #define GGML_MAX_CONTEXTS 64 #define GGML_MAX_SRC 6 #define GGML_MAX_NAME 64 @@ -248,6 +248,14 @@ } \ } while (0) +#ifndef NDEBUG +#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached") +#elif defined(__GNUC__) +#define GGML_UNREACHABLE() __builtin_unreachable() +#else +#define GGML_UNREACHABLE() ((void) 0) +#endif + // used to copy the number of elements and stride in bytes of tensors into local variables. // main purpose is to reduce code duplication and improve readability. // @@ -445,6 +453,12 @@ extern "C" { GGML_OBJECT_WORK_BUFFER }; + enum ggml_log_level { + GGML_LOG_LEVEL_ERROR = 2, + GGML_LOG_LEVEL_WARN = 3, + GGML_LOG_LEVEL_INFO = 4 + }; + // ggml object struct ggml_object { size_t offs; @@ -467,8 +481,8 @@ extern "C" { int n_dims; int64_t ne[GGML_MAX_DIMS]; // number of elements size_t nb[GGML_MAX_DIMS]; // stride in bytes: - // nb[0] = sizeof(type) - // nb[1] = nb[0] * ne[0] + padding + // nb[0] = ggml_type_size(type) + // nb[1] = nb[0] * (ne[0] / ggml_blck_size(type)) + padding // nb[i] = nb[i-1] * ne[i-1] // compute data @@ -520,7 +534,15 @@ extern "C" { // next prime after GGML_MAX_NODES // #define GGML_GRAPH_HASHTABLE_SIZE 4099 // next prime after GGML_MAX_NODES * 2 (nodes + leafs) - #define GGML_GRAPH_HASHTABLE_SIZE 8273 + // #define GGML_GRAPH_HASHTABLE_SIZE 8273 + // #define GGML_GRAPH_HASHTABLE_SIZE 16411 + #define GGML_GRAPH_HASHTABLE_SIZE 32771 + + enum ggml_cgraph_eval_order { + GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0, + GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT, + GGML_CGRAPH_EVAL_ORDER_COUNT + }; // computation graph struct ggml_cgraph { @@ -533,6 +555,8 @@ extern "C" { void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE]; + enum ggml_cgraph_eval_order order; + // performance int perf_runs; int64_t perf_cycles; @@ -680,12 +704,21 @@ extern "C" { GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); + // Converts a flat index into coordinates + GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); + GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); + GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); + GGML_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value); + GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); + GGML_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); + GGML_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); + GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); @@ -719,6 +752,12 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_add_cast( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type); + GGML_API struct ggml_tensor * ggml_add1( struct ggml_context * ctx, struct ggml_tensor * a, @@ -828,6 +867,7 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // sums repetitions in a into shape of b GGML_API struct ggml_tensor * ggml_repeat_back( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1049,7 +1089,6 @@ extern "C" { size_t nb1, size_t offset); - // a -> b, return view(b) GGML_API struct ggml_tensor * ggml_cpy( struct ggml_context * ctx, @@ -1072,6 +1111,33 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // make contiguous, with new shape + GGML_API struct ggml_tensor * ggml_cont_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_cont_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1); + + GGML_API struct ggml_tensor * ggml_cont_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_cont_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + // return view(a), b specifies the new shape // TODO: when we start computing gradient, make a copy instead of view GGML_API struct ggml_tensor * ggml_reshape( @@ -1219,14 +1285,15 @@ extern "C" { struct ggml_tensor * b); // rotary position embedding - // if mode & 1 == 1, skip n_past elements + // if mode & 1 == 1, skip n_past elements (DEPRECATED) // if mode & 2 == 1, GPT-NeoX style // if mode & 4 == 1, ChatGLM style - // TODO: avoid creating a new tensor every time + // + // b is an int32 vector with size a->ne[2], it contains the positions GGML_API struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx); @@ -1235,7 +1302,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx); @@ -1244,7 +1311,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_custom( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -1255,7 +1322,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -1266,7 +1333,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, float base, bool down); @@ -1276,7 +1343,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_back( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -1656,6 +1723,16 @@ extern "C" { // dump the graph into a file using the dot format GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); + // build gradient checkpointing backward graph gb for gf using provided checkpoints + // gb_tmp will contain original backward graph with rewritten backward process nodes, + // but without the second forward pass nodes. + GGML_API void ggml_build_backward_gradient_checkpointing( + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + struct ggml_cgraph * gb_tmp, + struct ggml_tensor * * checkpoints, + int n_checkpoints); // // optimization // @@ -1690,7 +1767,8 @@ extern "C" { GGML_LINESEARCH_INVALID_PARAMETERS, }; - typedef void (*ggml_opt_callback)(void * data, float * sched); + typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel); + typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); // optimization parameters // @@ -1721,6 +1799,8 @@ extern "C" { bool print_forward_graph; bool print_backward_graph; + int n_gradient_accumulation; + // ADAM parameters struct { int n_iter; @@ -1766,6 +1846,7 @@ extern "C" { float loss_after; struct { + struct ggml_tensor * g; // current gradient struct ggml_tensor * m; // first moment struct ggml_tensor * v; // second moment struct ggml_tensor * pf; // past function values @@ -1882,26 +1963,26 @@ extern "C" { GGML_API int gguf_get_n_kv(const struct gguf_context * ctx); GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key); - GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int i); - - GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int i); - GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int i); - - // results are undefined if the wrong type is used for the key - GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int i); - GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int i); - GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int i); - GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int i); - GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int i); - GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int i); - GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int i); - GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int i); - GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int i); - GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int i); - GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int i); - GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int i); - GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int i); - GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int i); + GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id); + + GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id); + GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id); + + // will abort if the wrong type is used for the key + GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int key_id); + GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int key_id); + GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int key_id); + GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int key_id); + GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int key_id); + GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int key_id); + GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int key_id); + GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int key_id); + GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key_id); + GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id); + GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id); + GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id); + GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id); + GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id); GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i); GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx); diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.cpp b/plugins/wasi_nn/thirdparty/ggml/llama.cpp index 758a1c12..40d2246f 100644 --- a/plugins/wasi_nn/thirdparty/ggml/llama.cpp +++ b/plugins/wasi_nn/thirdparty/ggml/llama.cpp @@ -72,6 +72,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -92,12 +93,12 @@ // LLAMA_ATTRIBUTE_FORMAT(2, 3) -static void llama_log_internal (llama_log_level level, const char* format, ...); -static void llama_log_callback_default(llama_log_level level, const char * text, void * user_data); +static void llama_log_internal (ggml_log_level level, const char* format, ...); +static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data); -#define LLAMA_LOG_INFO(...) llama_log_internal(LLAMA_LOG_LEVEL_INFO , __VA_ARGS__) -#define LLAMA_LOG_WARN(...) llama_log_internal(LLAMA_LOG_LEVEL_WARN , __VA_ARGS__) -#define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__) +#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) +#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) // // helpers @@ -166,13 +167,13 @@ enum llm_arch { }; static std::map LLM_ARCH_NAMES = { - { LLM_ARCH_LLAMA, "llama" }, - { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_GPT2, "gpt2" }, - { LLM_ARCH_GPTJ, "gptj" }, - { LLM_ARCH_GPTNEOX, "gptneox" }, - { LLM_ARCH_MPT, "mpt" }, - { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, { LLM_ARCH_STARCODER, "starcoder" }, }; @@ -221,16 +222,16 @@ enum llm_kv { }; static std::map LLM_KV_NAMES = { - { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, - { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, - { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, - { LLM_KV_GENERAL_NAME, "general.name" }, - { LLM_KV_GENERAL_AUTHOR, "general.author" }, - { LLM_KV_GENERAL_URL, "general.url" }, - { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, - { LLM_KV_GENERAL_LICENSE, "general.license" }, - { LLM_KV_GENERAL_SOURCE_URL, "general.source_url" }, - { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source_hf_repo" }, + { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, + { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, + { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_NAME, "general.name" }, + { LLM_KV_GENERAL_AUTHOR, "general.author" }, + { LLM_KV_GENERAL_URL, "general.url" }, + { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, + { LLM_KV_GENERAL_LICENSE, "general.license" }, + { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, + { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, @@ -448,7 +449,7 @@ struct LLM_TN { // #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \ -{ \ +do { \ const std::string skey(key); \ const int kid = gguf_find_key(ctx, skey.c_str()); \ if (kid >= 0) { \ @@ -460,7 +461,7 @@ struct LLM_TN { } else if (req) { \ throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \ } \ -} +} while (0) // // ggml helpers @@ -886,10 +887,10 @@ static void llama_nop(struct ggml_tensor * tensor) { // don't offload by default static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) { std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size()); + const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_token_to_piece(ctx, token, result.data(), result.size()); + int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); @@ -904,7 +905,7 @@ static std::string llama_token_to_str(const struct llama_context * ctx, llama_to struct llama_state { // We save the log callback globally - llama_log_callback log_callback = llama_log_callback_default; + ggml_log_callback log_callback = llama_log_callback_default; void * log_callback_user_data = nullptr; }; @@ -930,9 +931,9 @@ static const size_t MB = kB*kB; static const size_t GB = kB*kB*kB; struct llama_hparams { + bool vocab_only; uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on - uint32_t n_ctx; // context size used during inference uint32_t n_embd; uint32_t n_head; uint32_t n_head_kv; @@ -943,8 +944,8 @@ struct llama_hparams { float f_norm_eps; float f_norm_rms_eps; - float rope_freq_base; - float rope_freq_scale; + float rope_freq_base_train; + float rope_freq_scale_train; bool operator!=(const llama_hparams & other) const { return static_cast(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT @@ -961,15 +962,18 @@ struct llama_hparams { uint32_t n_embd_gqa() const { return n_embd/n_gqa(); } +}; - size_t kv_size() const { - size_t result = 2ull; - result *= (size_t) n_embd_gqa(); - result *= (size_t) n_ctx; - result *= (size_t) n_layer; - result *= sizeof(ggml_fp16_t); - return result; - } +struct llama_cparams { + uint32_t n_ctx; // context size used during inference + uint32_t n_batch; + uint32_t n_threads; // number of threads to use for generation + uint32_t n_threads_batch; // number of threads to use for batch processing + + float rope_freq_base; + float rope_freq_scale; + + bool mul_mat_q; }; struct llama_layer { @@ -1004,7 +1008,29 @@ struct llama_layer { struct ggml_tensor * b3; // ffn_up }; +struct llama_kv_cell { + llama_pos pos = -1; + llama_pos delta = 0; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } +}; + +// ring-buffer of cached KV data struct llama_kv_cache { + bool has_shift = false; + + uint32_t head = 0; + uint32_t size = 0; + + // computed before each graph build + uint32_t n = 0; + + std::vector cells; + struct ggml_tensor * k = NULL; struct ggml_tensor * v = NULL; @@ -1012,8 +1038,6 @@ struct llama_kv_cache { llama_buffer buf; - int n; // number of tokens currently in the cache - ~llama_kv_cache() { if (ctx) { ggml_free(ctx); @@ -1052,6 +1076,10 @@ struct llama_vocab { id special_pad_id = -1; id linefeed_id = 13; + id special_prefix_id = 32007; + id special_middle_id = 32009; + id special_suffix_id = 32008; + id special_eot_id = 32010; int find_bpe_rank(std::string token_left, std::string token_right) const { replace_all(token_left, " ", "\u0120"); @@ -1127,11 +1155,8 @@ struct llama_model { }; struct llama_context { - llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {} + llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} ~llama_context() { - if (model_owner) { - delete &model; - } #ifdef GGML_USE_METAL if (ctx_metal) { ggml_metal_free(ctx_metal); @@ -1142,27 +1167,26 @@ struct llama_context { } } + llama_cparams cparams; + + const llama_model & model; + + // key + value cache for the self attention + struct llama_kv_cache kv_self; + std::mt19937 rng; bool has_evaluated_once = false; + int64_t t_start_us; + int64_t t_load_us; int64_t t_sample_us = 0; - int64_t t_eval_us = 0; int64_t t_p_eval_us = 0; + int64_t t_eval_us = 0; int32_t n_sample = 0; // number of tokens sampled - int32_t n_eval = 0; // number of eval calls int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) - - const llama_model & model; - - bool model_owner = false; - - int64_t t_load_us; - int64_t t_start_us; - - // key + value cache for the self attention - struct llama_kv_cache kv_self; + int32_t n_eval = 0; // number of eval calls // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; @@ -1197,16 +1221,23 @@ static bool llama_kv_cache_init( const struct llama_hparams & hparams, struct llama_kv_cache & cache, ggml_type wtype, - int n_ctx, + uint32_t n_ctx, int n_gpu_layers) { - const int n_embd = hparams.n_embd_gqa(); - const int n_layer = hparams.n_layer; + const uint32_t n_embd = hparams.n_embd_gqa(); + const uint32_t n_layer = hparams.n_layer; const int64_t n_mem = n_layer*n_ctx; const int64_t n_elements = n_embd*n_mem; + cache.has_shift = false; + + cache.head = 0; + cache.size = n_ctx; + + cache.cells.clear(); + cache.cells.resize(n_ctx); + cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); - cache.n = 0; struct ggml_init_params params; params.mem_size = cache.buf.size; @@ -1227,17 +1258,154 @@ static bool llama_kv_cache_init( (void) n_gpu_layers; #ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer + 1) { + size_t vram_kv_cache = 0; + + if (n_gpu_layers > (int)n_layer + 1) { ggml_cuda_assign_buffers_no_scratch(cache.v); + LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__); + vram_kv_cache += ggml_nbytes(cache.v); } - if (n_gpu_layers > n_layer + 2) { + if (n_gpu_layers > (int)n_layer + 2) { ggml_cuda_assign_buffers_no_scratch(cache.k); + LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__); + vram_kv_cache += ggml_nbytes(cache.k); + } + if (vram_kv_cache > 0) { + LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0); } #endif // GGML_USE_CUBLAS return true; } +// find an empty slot of size "n_tokens" in the cache +// updates the cache head +static bool llama_kv_cache_find_slot( + struct llama_kv_cache & cache, + const struct llama_batch & batch) { + const uint32_t n_ctx = cache.size; + const uint32_t n_tokens = batch.n_tokens; + + if (n_tokens > n_ctx) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); + return false; + } + + uint32_t n_tested = 0; + + while (true) { + if (cache.head + n_tokens > n_ctx) { + cache.head = 0; + n_tested += n_ctx - cache.head; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.cells[cache.head + i].pos >= 0) { + found = false; + cache.head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= n_ctx) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } + } + + for (uint32_t i = 0; i < n_tokens; i++) { + cache.cells[cache.head + i].pos = batch.pos[i]; + cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]); + } + + return true; +} + +// find how many cells are currently in use +static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { + for (uint32_t i = cache.size - 1; i > 0; --i) { + if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) { + return i + 1; + } + } + + return 0; +} + +static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) { + if (c0 < 0) c0 = 0; + if (c1 < 0) c1 = cache.size; + + for (int32_t i = c0; i < c1; ++i) { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + } +} + +static void llama_kv_cache_seq_rm( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].seq_id.erase(seq_id); + if (cache.cells[i].seq_id.empty()) { + cache.cells[i].pos = -1; + } + } + } +} + +static void llama_kv_cache_seq_cp( + struct llama_kv_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].seq_id.insert(seq_id_dst); + } + } +} + +static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { + for (uint32_t i = 0; i < cache.size; ++i) { + if (!cache.cells[i].has_seq_id(seq_id)) { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + } + } +} + +static void llama_kv_cache_seq_shift( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].pos += delta; + if (cache.cells[i].pos < 0) { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + } else { + cache.has_shift = true; + cache.cells[i].delta = delta; + } + } + } +} + // // model loading and saving // @@ -1559,7 +1727,7 @@ struct llama_model_loader { lmlock->grow_to(size_lock); } break; -#if defined(GGML_USE_CUBLAS) +#ifdef GGML_USE_CUBLAS case GGML_BACKEND_GPU: case GGML_BACKEND_GPU_SPLIT: // old code: @@ -1592,7 +1760,15 @@ struct llama_model_loader { // load LLaMA models // -static std::string llama_model_ftype_name(enum llama_ftype ftype) { +static std::string llama_model_arch_name(llm_arch arch) { + auto it = LLM_ARCH_NAMES.find(arch); + if (it == LLM_ARCH_NAMES.end()) { + return "unknown"; + } + return it->second; +} + +static std::string llama_model_ftype_name(llama_ftype ftype) { if (ftype & LLAMA_FTYPE_GUESSED) { return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; } @@ -1648,10 +1824,7 @@ static void llm_load_arch(llama_model_loader & ml, llama_model & model) { static void llm_load_hparams( llama_model_loader & ml, - llama_model & model, - int n_ctx, - float rope_freq_base, - float rope_freq_scale) { + llama_model & model) { struct gguf_context * ctx = ml.ctx_gguf; const auto kv = LLM_KV(model.arch); @@ -1662,29 +1835,25 @@ static void llm_load_hparams( GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME)); // get hparams kv - GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST)); - GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH)); - GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH)); - GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH)); - GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT)); - GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT)); + GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST)); + GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT)); + GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT)); // n_head_kv is optional, default to n_head hparams.n_head_kv = hparams.n_head; GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV)); // rope_freq_base (optional) - if (rope_freq_base == 0.0f) { - rope_freq_base = 10000.0f; - GGUF_GET_KEY(ctx, rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); - } + hparams.rope_freq_base_train = 10000.0f; + GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); // rope_freq_scale (inverse of the kv) is optional - if (rope_freq_scale == 0.0f) { - float ropescale = 1.0f; - GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); - rope_freq_scale = 1.0f/ropescale; - } + float ropescale = 1.0f; + GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); + hparams.rope_freq_scale_train = 1.0f/ropescale; // sanity check for n_rot (optional) { @@ -1748,13 +1917,9 @@ static void llm_load_hparams( } } break; default: (void)0; - }; + } model.ftype = ml.ftype; - - hparams.n_ctx = n_ctx; - hparams.rope_freq_base = rope_freq_base; - hparams.rope_freq_scale = rope_freq_scale; } // TODO: This should probably be in llama.h @@ -1775,20 +1940,18 @@ static void llm_load_vocab( throw std::runtime_error("cannot find tokenizer vocab in model file\n"); } + const float * scores = nullptr; const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); - if (score_idx == -1) { - throw std::runtime_error("cannot find tokenizer scores in model file\n"); + if (score_idx != -1) { + scores = (const float * ) gguf_get_arr_data(ctx, score_idx); } - const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx); - + const int * toktypes = nullptr; const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); - if (toktype_idx == -1) { - throw std::runtime_error("cannot find token type list in GGUF file\n"); + if (toktype_idx != -1) { + toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); } - const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); - // determine vocab type { std::string tokenizer_name; @@ -1856,8 +2019,8 @@ static void llm_load_vocab( auto & token_data = vocab.id_to_token[i]; token_data.text = std::move(word); - token_data.score = scores[i]; - token_data.type = (llama_token_type) toktypes[i]; + token_data.score = scores ? scores[i] : 0.0f; + token_data.type = toktypes ? (llama_token_type) toktypes[i] : LLAMA_TOKEN_TYPE_NORMAL; } // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' @@ -1880,31 +2043,30 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { const auto & vocab = model.vocab; // hparams - LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); - LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix - LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); - LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size()); - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); - LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim - LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); - LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); - LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); - LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); - LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); + LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix + LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); + LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size()); + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); + LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim + LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); + LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); if (ml.n_bytes < GB) { - LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); + LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); } else { - LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); + LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); } // general kv @@ -1922,13 +2084,9 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { static void llm_load_tensors( llama_model_loader & ml, llama_model & model, - int n_batch, int n_gpu_layers, int main_gpu, const float * tensor_split, - const bool mul_mat_q, - bool low_vram, - ggml_type memory_type, bool use_mlock, llama_progress_callback progress_callback, void * progress_callback_user_data) { @@ -1967,11 +2125,9 @@ static void llm_load_tensors( } (void) main_gpu; - (void) mul_mat_q; -#if defined(GGML_USE_CUBLAS) +#ifdef GGML_USE_CUBLAS LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__); ggml_cuda_set_main_device(main_gpu); - ggml_cuda_set_mul_mat_q(mul_mat_q); #define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT #elif defined(GGML_USE_CLBLAST) @@ -2006,9 +2162,9 @@ static void llm_load_tensors( // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; @@ -2072,9 +2228,9 @@ static void llm_load_tensors( // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; @@ -2142,9 +2298,9 @@ static void llm_load_tensors( // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; @@ -2219,9 +2375,9 @@ static void llm_load_tensors( // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; @@ -2286,27 +2442,19 @@ static void llm_load_tensors( } break; default: throw std::runtime_error("unknown architecture"); - }; + } } ml.done_getting_tensors(); // print memory requirements { - const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1; - // this is the total memory required to run the inference size_t mem_required = ctx_size + mmapped_size - vram_weights; // weights in VRAM not in memory - // this is the memory required by one llama_state - const size_t mem_required_state = scale*hparams.kv_size(); - - LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, - mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); - - (void) n_batch; + LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); @@ -2315,36 +2463,17 @@ static void llm_load_tensors( if (n_gpu_layers > (int) hparams.n_layer) { LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__); } - size_t vram_kv_cache = 0; #ifdef GGML_USE_CUBLAS const int max_backend_supported_layers = hparams.n_layer + 3; - const int max_offloadable_layers = low_vram ? hparams.n_layer + 1 : hparams.n_layer + 3; - if (n_gpu_layers > (int) hparams.n_layer + 1) { - if (low_vram) { - LLAMA_LOG_INFO("%s: cannot offload v cache to GPU due to low VRAM option\n", __func__); - } else { - LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__); - vram_kv_cache += hparams.kv_size() / 2; - } - } - if (n_gpu_layers > (int) hparams.n_layer + 2) { - if (low_vram) { - LLAMA_LOG_WARN("%s: cannot offload k cache to GPU due to low VRAM option\n", __func__); - } else { - LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__); - vram_kv_cache += hparams.kv_size() / 2; - } - } + const int max_offloadable_layers = hparams.n_layer + 3; #elif defined(GGML_USE_CLBLAST) const int max_backend_supported_layers = hparams.n_layer + 1; const int max_offloadable_layers = hparams.n_layer + 1; #endif // GGML_USE_CUBLAS - LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", - __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); - LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n", - __func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up + LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, vram_weights / 1024.0 / 1024.0); #else (void) n_gpu_layers; #endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) @@ -2357,7 +2486,7 @@ static void llm_load_tensors( } (void) tensor_split; -#if defined(GGML_USE_CUBLAS) +#ifdef GGML_USE_CUBLAS { ggml_cuda_set_tensor_split(tensor_split); } @@ -2379,29 +2508,24 @@ static void llm_load_tensors( static bool llama_model_load( const std::string & fname, llama_model & model, - int n_ctx, - int n_batch, int n_gpu_layers, int main_gpu, const float * tensor_split, - const bool mul_mat_q, - float rope_freq_base, - float rope_freq_scale, - bool low_vram, - ggml_type memory_type, bool use_mmap, bool use_mlock, bool vocab_only, llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - std::unique_ptr ml(new llama_model_loader(fname, use_mmap)); + llama_model_loader ml(fname, use_mmap); + + model.hparams.vocab_only = vocab_only; - llm_load_arch (*ml, model); - llm_load_hparams(*ml, model, n_ctx, rope_freq_base, rope_freq_scale); - llm_load_vocab (*ml, model); + llm_load_arch (ml, model); + llm_load_hparams(ml, model); + llm_load_vocab (ml, model); - llm_load_print_meta(*ml, model); + llm_load_print_meta(ml, model); if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { throw std::runtime_error("vocab size mismatch"); @@ -2413,8 +2537,8 @@ static bool llama_model_load( } llm_load_tensors( - *ml, model, n_batch, n_gpu_layers, - main_gpu, tensor_split, mul_mat_q, low_vram, memory_type, + ml, model, n_gpu_layers, + main_gpu, tensor_split, use_mlock, progress_callback, progress_callback_user_data); } catch (const std::exception & err) { LLAMA_LOG_ERROR("error loading model: %s\n", err.what()); @@ -2426,17 +2550,10 @@ static bool llama_model_load( static struct ggml_cgraph * llm_build_llama( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past) { - - GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT - - const int N = n_tokens; - + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -2444,7 +2561,7 @@ static struct ggml_cgraph * llm_build_llama( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -2452,12 +2569,20 @@ static struct ggml_cgraph * llm_build_llama( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float freq_base = hparams.rope_freq_base; - const float freq_scale = hparams.rope_freq_scale; + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; const float norm_rms_eps = hparams.f_norm_rms_eps; const int n_gpu_layers = model.n_gpu_layers; + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; + + //printf("n_kv = %d\n", n_kv); + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { @@ -2475,12 +2600,12 @@ static struct ggml_cgraph * llm_build_llama( struct ggml_tensor * cur; struct ggml_tensor * inpL; - if (tokens) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(lctx.alloc, inp_tokens); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); } ggml_set_name(inp_tokens, "inp_tokens"); @@ -2490,11 +2615,11 @@ static struct ggml_cgraph * llm_build_llama( GGML_ASSERT(false && "not implemented"); #endif - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); ggml_allocr_alloc(lctx.alloc, inpL); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); } } @@ -2503,9 +2628,6 @@ static struct ggml_cgraph * llm_build_llama( // offload functions set the tensor output backend to GPU // tensors are GPU-accelerated if any input or the output has been offloaded - // - // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal - // in that case ggml_cuda_assign_buffers has no effect offload_func_t offload_func_nr = llama_nop; // nr = non-repeating offload_func_t offload_func_kq = llama_nop; offload_func_t offload_func_v = llama_nop; @@ -2522,12 +2644,75 @@ static struct ggml_cgraph * llm_build_llama( } #endif // GGML_USE_CUBLAS + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); + } + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); + ggml_set_name(KQ_pos, "KQ_pos"); + ggml_allocr_alloc(lctx.alloc, KQ_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < n_tokens; ++i) { + data[i] = batch.pos[i]; + } + } + + // shift the entire K-cache if needed + if (do_rope_shift) { + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); + ggml_set_name(K_shift, "K_shift"); + ggml_allocr_alloc(lctx.alloc, K_shift); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * tmp = + ggml_rope_custom_inplace(ctx0, + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_head_kv, n_ctx, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), + K_shift, n_embd_head, 0, 0, freq_base, freq_scale); + offload_func_kq(tmp); + ggml_build_forward_expand(gf, tmp); + } } - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); for (int il = 0; il < n_layer; ++il) { ggml_format_name(inpL, "layer_inp_%d", il); @@ -2565,33 +2750,33 @@ static struct ggml_cgraph * llm_build_llama( offload_func_kq(tmpq); ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); // store key and value to memory { - // compute the transposed [N, n_embd] V matrix + // compute the transposed [n_tokens, n_embd] V matrix struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); offload_func_v(tmpv); ggml_set_name(tmpv, "tmpv"); - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); offload_func_v(Vcur); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); offload_func_kq(k); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); offload_func_v(v); ggml_set_name(v, "v"); @@ -2606,7 +2791,7 @@ static struct ggml_cgraph * llm_build_llama( struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_past + N, n_head_kv, + n_embd_head, n_kv, n_head_kv, ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_head, ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); @@ -2619,25 +2804,25 @@ static struct ggml_cgraph * llm_build_llama( ggml_set_name(KQ, "KQ"); // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + // KQ_scaled shape [n_kv, n_tokens, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd_head, n_head_kv, + n_kv, n_embd_head, n_head_kv, ggml_element_size(kv_self.v)*n_ctx, ggml_element_size(kv_self.v)*n_ctx*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); @@ -2652,7 +2837,7 @@ static struct ggml_cgraph * llm_build_llama( // make V contiguous in memory to speed up the matmul, however we waste time on the copy // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation // is there a better way? - struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head)); + struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head)); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); #endif @@ -2661,10 +2846,8 @@ static struct ggml_cgraph * llm_build_llama( offload_func_v(KQV_merged); ggml_set_name(KQV_merged, "KQV_merged"); - // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); @@ -2755,20 +2938,12 @@ static struct ggml_cgraph * llm_build_llama( return gf; } - static struct ggml_cgraph * llm_build_baichaun( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past) { - - GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT - - const int N = n_tokens; - + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -2776,7 +2951,7 @@ static struct ggml_cgraph * llm_build_baichaun( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -2784,12 +2959,18 @@ static struct ggml_cgraph * llm_build_baichaun( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float freq_base = hparams.rope_freq_base; - const float freq_scale = hparams.rope_freq_scale; + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; const float norm_rms_eps = hparams.f_norm_rms_eps; const int n_gpu_layers = model.n_gpu_layers; + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { @@ -2807,12 +2988,12 @@ static struct ggml_cgraph * llm_build_baichaun( struct ggml_tensor * cur; struct ggml_tensor * inpL; - if (tokens) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(lctx.alloc, inp_tokens); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); } ggml_set_name(inp_tokens, "inp_tokens"); @@ -2822,11 +3003,11 @@ static struct ggml_cgraph * llm_build_baichaun( GGML_ASSERT(false && "not implemented"); #endif - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); ggml_allocr_alloc(lctx.alloc, inpL); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); } } @@ -2835,9 +3016,6 @@ static struct ggml_cgraph * llm_build_baichaun( // offload functions set the tensor output backend to GPU // tensors are GPU-accelerated if any input or the output has been offloaded - // - // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal - // in that case ggml_cuda_assign_buffers has no effect offload_func_t offload_func_nr = llama_nop; // nr = non-repeating offload_func_t offload_func_kq = llama_nop; offload_func_t offload_func_v = llama_nop; @@ -2854,12 +3032,75 @@ static struct ggml_cgraph * llm_build_baichaun( } #endif // GGML_USE_CUBLAS + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); } - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); + ggml_set_name(KQ_pos, "KQ_pos"); + ggml_allocr_alloc(lctx.alloc, KQ_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < n_tokens; ++i) { + data[i] = batch.pos[i]; + } + } + + // shift the entire K-cache if needed + if (do_rope_shift) { + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); + ggml_set_name(K_shift, "K_shift"); + ggml_allocr_alloc(lctx.alloc, K_shift); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * tmp = + ggml_rope_custom_inplace(ctx0, + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_head_kv, n_ctx, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), + K_shift, n_embd_head, 0, 0, freq_base, freq_scale); + offload_func_kq(tmp); + ggml_build_forward_expand(gf, tmp); + } + } for (int il = 0; il < n_layer; ++il) { ggml_format_name(inpL, "layer_inp_%d", il); @@ -2901,12 +3142,12 @@ static struct ggml_cgraph * llm_build_baichaun( struct ggml_tensor * Qcur; switch (model.type) { case MODEL_7B: - Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); - Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); + Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); break; case MODEL_13B: - Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N); - Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N); + Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, n_tokens); + Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, n_tokens); break; default: GGML_ASSERT(false); @@ -2920,23 +3161,23 @@ static struct ggml_cgraph * llm_build_baichaun( // store key and value to memory { - // compute the transposed [N, n_embd] V matrix + // compute the transposed [n_tokens, n_embd] V matrix struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); offload_func_v(tmpv); ggml_set_name(tmpv, "tmpv"); - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); offload_func_v(Vcur); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); offload_func_kq(k); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); offload_func_v(v); ggml_set_name(v, "v"); @@ -2951,7 +3192,7 @@ static struct ggml_cgraph * llm_build_baichaun( struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_past + N, n_head_kv, + n_embd_head, n_kv, n_head_kv, ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_head, ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); @@ -2964,8 +3205,8 @@ static struct ggml_cgraph * llm_build_baichaun( ggml_set_name(KQ, "KQ"); // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); @@ -2974,58 +3215,44 @@ static struct ggml_cgraph * llm_build_baichaun( switch (model.type) { case MODEL_7B: - KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); break; case MODEL_13B: - KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8); + // TODO: replace with ggml_add() + KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8); ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); - KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); + KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); break; default: GGML_ASSERT(false); } - // KQ_masked = mask_past(KQ_scaled) - // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); - // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); - // offload_func_kq(KQ_masked); - // ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd_head, n_head_kv, + n_kv, n_embd_head, n_head_kv, ggml_element_size(kv_self.v)*n_ctx, ggml_element_size(kv_self.v)*n_ctx*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); offload_func_v(V); ggml_set_name(V, "V"); -#if 1 struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); offload_func_v(KQV); ggml_set_name(KQV, "KQV"); -#else - // make V contiguous in memory to speed up the matmul, however we waste time on the copy - // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation - // is there a better way? - struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head)); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); -#endif // KQV_merged = KQV.permute(0, 2, 1, 3) struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); offload_func_v(KQV_merged); ggml_set_name(KQV_merged, "KQV_merged"); - // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); @@ -3118,17 +3345,10 @@ static struct ggml_cgraph * llm_build_baichaun( static struct ggml_cgraph * llm_build_falcon( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past) { - - GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT - - const int N = n_tokens; - + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -3136,7 +3356,7 @@ static struct ggml_cgraph * llm_build_falcon( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -3144,12 +3364,21 @@ static struct ggml_cgraph * llm_build_falcon( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float freq_base = hparams.rope_freq_base; - const float freq_scale = hparams.rope_freq_scale; + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; const float norm_eps = hparams.f_norm_eps; const int n_gpu_layers = model.n_gpu_layers; + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; + + //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n", + // kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift); + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { @@ -3167,12 +3396,12 @@ static struct ggml_cgraph * llm_build_falcon( struct ggml_tensor * cur; struct ggml_tensor * inpL; - if (tokens) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(lctx.alloc, inp_tokens); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); } ggml_set_name(inp_tokens, "inp_tokens"); @@ -3182,11 +3411,11 @@ static struct ggml_cgraph * llm_build_falcon( GGML_ASSERT(false && "not implemented"); #endif - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); ggml_allocr_alloc(lctx.alloc, inpL); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); } } @@ -3195,9 +3424,6 @@ static struct ggml_cgraph * llm_build_falcon( // offload functions set the tensor output backend to GPU // tensors are GPU-accelerated if any input or the output has been offloaded - // - // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal - // in that case ggml_cuda_assign_buffers has no effect offload_func_t offload_func_nr = llama_nop; // nr = non-repeating offload_func_t offload_func_kq = llama_nop; offload_func_t offload_func_v = llama_nop; @@ -3214,12 +3440,75 @@ static struct ggml_cgraph * llm_build_falcon( } #endif // GGML_USE_CUBLAS + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); } - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); + ggml_set_name(KQ_pos, "KQ_pos"); + ggml_allocr_alloc(lctx.alloc, KQ_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < n_tokens; ++i) { + data[i] = batch.pos[i]; + } + } + + // shift the entire K-cache if needed + if (do_rope_shift) { + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); + ggml_set_name(K_shift, "K_shift"); + ggml_allocr_alloc(lctx.alloc, K_shift); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * tmp = + ggml_rope_custom_inplace(ctx0, + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_head_kv, n_ctx, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), + K_shift, n_embd_head, 2, 0, freq_base, freq_scale); + offload_func_kq(tmp); + ggml_build_forward_expand(gf, tmp); + } + } for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * attn_norm; @@ -3276,45 +3565,45 @@ static struct ggml_cgraph * llm_build_falcon( // TODO: these 2 ggml_conts are technically not needed, but we add them until CUDA support for // non-contiguous views is added for the rope operator struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d( - ctx0, cur, n_embd_head, n_head, N, + ctx0, cur, n_embd_head, n_head, n_tokens, wsize * n_embd_head, wsize * n_embd_head * (n_head + 2 * n_head_kv), 0)); offload_func_kq(tmpq); struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_3d( - ctx0, cur, n_embd_head, n_head_kv, N, + ctx0, cur, n_embd_head, n_head_kv, n_tokens, wsize * n_embd_head, wsize * n_embd_head * (n_head + 2 * n_head_kv), wsize * n_embd_head * n_head)); offload_func_kq(tmpk); struct ggml_tensor * tmpv = ggml_view_3d( - ctx0, cur, n_embd_head, n_head_kv, N, + ctx0, cur, n_embd_head, n_head_kv, n_tokens, wsize * n_embd_head, wsize * n_embd_head * (n_head + 2 * n_head_kv), wsize * n_embd_head * (n_head + n_head_kv)); offload_func_v(tmpv); // using mode = 2 for neox mode - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, tmpq, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale); offload_func_kq(Qcur); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, tmpk, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale); offload_func_kq(Kcur); { - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); offload_func_v(Vcur); offload_func_v(Vcur->src[0]->src[0]); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); offload_func_kq(k); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); offload_func_v(v); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); @@ -3327,7 +3616,7 @@ static struct ggml_cgraph * llm_build_falcon( struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_past + N, n_head_kv, + n_embd_head, n_kv, n_head_kv, ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_head, ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); @@ -3338,21 +3627,21 @@ static struct ggml_cgraph * llm_build_falcon( offload_func_kq(KQ); ggml_set_name(KQ, "KQ"); - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd_head, n_head_kv, + n_kv, n_embd_head, n_head_kv, ggml_element_size(kv_self.v)*n_ctx, ggml_element_size(kv_self.v)*n_ctx*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); @@ -3367,7 +3656,7 @@ static struct ggml_cgraph * llm_build_falcon( offload_func_v(KQV_merged); ggml_set_name(KQV_merged, "KQV_merged"); - cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); @@ -3425,17 +3714,10 @@ static struct ggml_cgraph * llm_build_falcon( static struct ggml_cgraph * llm_build_starcoder( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past) { - - GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT - - const int N = n_tokens; - + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -3443,7 +3725,7 @@ static struct ggml_cgraph * llm_build_starcoder( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -3451,7 +3733,11 @@ static struct ggml_cgraph * llm_build_starcoder( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float norm_eps = hparams.f_norm_eps; + const float norm_eps = hparams.f_norm_eps; + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; auto & buf_compute = lctx.buf_compute; @@ -3472,12 +3758,12 @@ static struct ggml_cgraph * llm_build_starcoder( struct ggml_tensor * position; struct ggml_tensor * inpL; - if (tokens) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(lctx.alloc, inp_tokens); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); } ggml_set_name(inp_tokens, "inp_tokens"); @@ -3487,21 +3773,21 @@ static struct ggml_cgraph * llm_build_starcoder( GGML_ASSERT(false && "not implemented"); #endif - token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); ggml_allocr_alloc(lctx.alloc, token); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(token->data, embd, N * n_embd * ggml_element_size(token)); + memcpy(token->data, batch.embd, n_tokens * n_embd * ggml_element_size(token)); } } { // Compute position embeddings. - struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_allocr_alloc(lctx.alloc, inp_positions); if (!ggml_allocr_is_measure(lctx.alloc)) { - for (int i = 0; i < N; ++i) { - ((int32_t *) inp_positions->data)[i] = n_past + i; + for (int i = 0; i < n_tokens; ++i) { + ((int32_t *) inp_positions->data)[i] = batch.pos[i]; } } ggml_set_name(inp_positions, "inp_positions"); @@ -3509,12 +3795,35 @@ static struct ggml_cgraph * llm_build_starcoder( position = ggml_get_rows(ctx0, model.pos_embeddings, inp_positions); } + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); } - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } inpL = ggml_add(ctx0, token, position); ggml_set_name(inpL, "inpL"); @@ -3530,23 +3839,23 @@ static struct ggml_cgraph * llm_build_starcoder( // Self Attention cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv); - struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd); - struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, N, cur->nb[1], sizeof(float)*n_embd); - struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, N, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa)); + struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd); + struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd); + struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa)); struct ggml_tensor * Qcur = tmpq; struct ggml_tensor * Kcur = tmpk; { - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); @@ -3556,13 +3865,13 @@ static struct ggml_cgraph * llm_build_starcoder( ggml_permute(ctx0, ggml_cpy(ctx0, Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, N)), + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)), 0, 2, 1, 3); ggml_set_name(Q, "Q"); struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_past + N, n_head_kv, + n_embd_head, n_kv, n_head_kv, ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_head, ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); @@ -3573,12 +3882,12 @@ static struct ggml_cgraph * llm_build_starcoder( ggml_set_name(KQ, "KQ"); // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_past + N, N, n_head, 1] + // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1] struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); ggml_set_name(KQ_scaled, "KQ_scaled"); // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) @@ -3588,7 +3897,7 @@ static struct ggml_cgraph * llm_build_starcoder( // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd_head, n_head_kv, + n_kv, n_embd_head, n_head_kv, ggml_element_size(kv_self.v)*n_ctx, ggml_element_size(kv_self.v)*n_ctx*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); @@ -3601,10 +3910,8 @@ static struct ggml_cgraph * llm_build_starcoder( struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); ggml_set_name(KQV_merged, "KQV_merged"); - // cur = KQV_merged.contiguous().view(n_embd, N) - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); ggml_set_name(cur, "KQV_merged_contiguous"); } @@ -3654,10 +3961,7 @@ static struct ggml_cgraph * llm_build_starcoder( static struct ggml_cgraph * llama_build_graph( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past) { + const llama_batch & batch) { const auto & model = lctx.model; struct ggml_cgraph * result = NULL; @@ -3665,76 +3969,117 @@ static struct ggml_cgraph * llama_build_graph( switch (model.arch) { case LLM_ARCH_LLAMA: { - result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past); + result = llm_build_llama(lctx, batch); } break; case LLM_ARCH_BAICHUAN: { - result = llm_build_baichaun(lctx, tokens, embd, n_tokens, n_past); + result = llm_build_baichaun(lctx, batch); } break; case LLM_ARCH_FALCON: { - result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past); + result = llm_build_falcon(lctx, batch); } break; case LLM_ARCH_STARCODER: { - result = llm_build_starcoder(lctx, tokens, embd, n_tokens, n_past); + result = llm_build_starcoder(lctx, batch); } break; default: GGML_ASSERT(false); - }; + } return result; } -// evaluate the transformer +// decode a batch of tokens by evaluating the transformer // // - lctx: llama context -// - tokens: new batch of tokens to process -// - embd embeddings input -// - n_tokens number of tokens -// - n_past: the context size so far +// - batch: batch to evaluate // - n_threads: number of threads to use // -static bool llama_eval_internal( +// return 0 on success +// return positive int on warning +// return negative int on error +// +static int llama_decode_internal( llama_context & lctx, - const llama_token * tokens, - const float * embd, - int n_tokens, - int n_past, - int n_threads, - const char * cgraph_fname) { + llama_batch batch) { + const uint32_t n_tokens = batch.n_tokens; + + if (n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); + return -1; + } + + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; - GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT + const auto n_batch = cparams.n_batch; - GGML_ASSERT(n_tokens > 0); - GGML_ASSERT(n_past >= 0); - // TODO: keep the values of n_batch and n_ctx - // GGML_ASSERT(n_tokens <= n_batch); - // GGML_ASSERT(n_past + n_tokens <= n_ctx); + GGML_ASSERT(n_tokens <= n_batch); + + int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT const int64_t t_start_us = ggml_time_us(); #ifdef GGML_USE_MPI - ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); + // TODO: needs fix after #3228 + GGML_ASSERT(false && "not implemented"); + //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); #endif GGML_ASSERT(n_threads > 0); - const int N = n_tokens; - - const auto & model = lctx.model; - const auto & hparams = model.hparams; - - const auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.kv_self; GGML_ASSERT(!!kv_self.ctx); const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; + // helpers for smoother batch API transistion + // after deprecating the llama_eval calls, these will be removed + std::vector pos; + std::vector seq_id; + + if (batch.pos == nullptr) { + pos.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + pos[i] = batch.all_pos_0 + i*batch.all_pos_1; + } + + batch.pos = pos.data(); + } + + if (batch.seq_id == nullptr) { + seq_id.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + seq_id[i] = batch.all_seq_id; + } + + batch.seq_id = seq_id.data(); + } + + // we always start to search for a free slot from the start of the cache + // TODO: better strategies can be implemented + kv_self.head = 0; + + if (!llama_kv_cache_find_slot(kv_self, batch)) { + return 1; + } + + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA? + kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self))); + + //printf("kv_self.n = %d\n", kv_self.n); + ggml_allocr_reset(lctx.alloc); - ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past); + ggml_cgraph * gf = llama_build_graph(lctx, batch); ggml_allocr_alloc_graph(lctx.alloc, gf); @@ -3743,6 +4088,7 @@ static bool llama_eval_internal( ggml_tensor * node = gf->leafs[i]; if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); + ggml_cuda_copy_to_device(node); } } @@ -3752,6 +4098,8 @@ static bool llama_eval_internal( ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); } } + + ggml_cuda_set_mul_mat_q(cparams.mul_mat_q); #endif // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -3761,7 +4109,7 @@ static bool llama_eval_internal( // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering // with the BLAS calls. need a better solution - if (N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { + if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { n_threads = std::min(4, n_threads); } @@ -3800,12 +4148,9 @@ static bool llama_eval_internal( ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); #endif - // update kv token count - lctx.kv_self.n = n_past + N; - - if (cgraph_fname) { - ggml_graph_export(gf, cgraph_fname); - } + // update the kv ring buffer + lctx.kv_self.head += n_tokens; + lctx.kv_self.has_shift = false; #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) @@ -3822,13 +4167,20 @@ static bool llama_eval_internal( { auto & logits_out = lctx.logits; - if (lctx.logits_all) { - logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N); + if (batch.logits) { + logits_out.resize(n_vocab * n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); + } + } else if (lctx.logits_all) { + logits_out.resize(n_vocab * n_tokens); + memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); } else { - // return result for just the last token logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); } } @@ -3837,20 +4189,27 @@ static bool llama_eval_internal( auto & embedding_out = lctx.embedding; embedding_out.resize(n_embd); - memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd); + memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(n_tokens - 1)), sizeof(float)*n_embd); } // measure the performance only for the single-token evals - if (N == 1) { + if (n_tokens == 1) { lctx.t_eval_us += ggml_time_us() - t_start_us; lctx.n_eval++; } - else if (N > 1) { + else if (n_tokens > 1) { lctx.t_p_eval_us += ggml_time_us() - t_start_us; - lctx.n_p_eval += N; + lctx.n_p_eval += n_tokens; } - return true; + // get a more accurate load time, upon first eval + // TODO: fix this + if (!lctx.has_evaluated_once) { + lctx.t_load_us = ggml_time_us() - lctx.t_start_us; + lctx.has_evaluated_once = true; + } + + return 0; } // @@ -4271,7 +4630,7 @@ static std::vector llama_tokenize_internal(const llama_vocab & llm_tokenizer_bpe tokenizer(vocab); tokenizer.tokenize(raw_text, output); } break; - }; + } return output; } @@ -4675,6 +5034,13 @@ struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) // sampling // +void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + seed = time(NULL); + } + ctx->rng.seed(seed); +} + void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { GGML_ASSERT(candidates->size > 0); @@ -4883,7 +5249,7 @@ void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * c } } -void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { +void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { const int64_t t_start_sample_us = ggml_time_us(); for (size_t i = 0; i < candidates_p->size; ++i) { @@ -4895,6 +5261,10 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array } } +void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { + llama_sample_temp(ctx, candidates_p, temp); +} + void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) { if (last_tokens_size == 0 || penalty == 1.0f) { return; @@ -5018,7 +5388,7 @@ void llama_sample_classifier_free_guidance( GGML_ASSERT(ctx); - auto n_vocab = llama_n_vocab(ctx); + auto n_vocab = llama_n_vocab(llama_get_model(ctx)); GGML_ASSERT(n_vocab == (int)candidates->size); GGML_ASSERT(!candidates->sorted); @@ -5047,7 +5417,7 @@ void llama_sample_classifier_free_guidance( llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { GGML_ASSERT(ctx); - auto N = float(llama_n_vocab(ctx)); + auto N = float(llama_n_vocab(llama_get_model(ctx))); int64_t t_start_sample_us; t_start_sample_us = ggml_time_us(); @@ -5234,7 +5604,7 @@ struct llama_logit_info { }; llama_logit_info(llama_context * ctx) : logits(llama_get_logits(ctx)) - , n_vocab(llama_n_vocab(ctx)) + , n_vocab(llama_n_vocab(llama_get_model(ctx))) , max_l(*std::max_element(logits, logits + n_vocab)) , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l})) { } @@ -5272,7 +5642,6 @@ struct llama_beam_search_data { size_t n_beams; int n_past; int n_predict; - int n_threads; std::vector beams; std::vector next_beams; @@ -5282,12 +5651,11 @@ struct llama_beam_search_data { // Used to communicate to/from callback on beams state. std::vector beam_views; - llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads) + llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict) : ctx(ctx) , n_beams(n_beams) , n_past(n_past) , n_predict(n_predict) - , n_threads(n_threads) , beam_views(n_beams) { beams.reserve(n_beams); next_beams.reserve(n_beams); @@ -5324,7 +5692,7 @@ struct llama_beam_search_data { } else { // beam is not at end-of-sentence, so branch with next top_k tokens. if (!beam.tokens.empty()) { - llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads); + llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0)); } llama_logit_info logit_info(ctx); std::vector next_tokens = logit_info.top_k(n_beams); @@ -5398,7 +5766,7 @@ struct llama_beam_search_data { callback(callback_data, get_beams_state(false)); // Sets common_prefix_length update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed. if (common_prefix_length) { - llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads); + llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0)); n_past += common_prefix_length; } // Zero-out next_beam probabilities to place them last in following min-heap. @@ -5439,11 +5807,11 @@ struct llama_beam_search_data { void llama_beam_search(llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, - size_t n_beams, int n_past, int n_predict, int n_threads) { + size_t n_beams, int n_past, int n_predict) { assert(ctx); const int64_t t_start_sample_us = ggml_time_us(); - llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads); + llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict); beam_search_data.loop(callback, callback_data); @@ -5663,11 +6031,22 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s nthread = std::thread::hardware_concurrency(); } - std::unique_ptr ml(new llama_model_loader(fname_inp, /*use_mmap*/ false)); + // mmap consistently increases speed Linux, and also increases speed on Windows with + // hot cache. It may cause a slowdown on macOS, possibly related to free memory. +#if defined(__linux__) || defined(_WIN32) + constexpr bool use_mmap = true; +#else + constexpr bool use_mmap = false; +#endif + + llama_model_loader ml(fname_inp, use_mmap); + if (ml.use_mmap) { + ml.mapping.reset(new llama_mmap(&ml.file, /* prefetch */ 0, ggml_is_numa())); + } llama_model model; - llm_load_arch(*ml, model); - llm_load_hparams(*ml, model, 0, 0, 0); + llm_load_arch(ml, model); + llm_load_hparams(ml, model); if (params->only_copy) { ftype = model.ftype; @@ -5677,7 +6056,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s struct gguf_context * ctx_out = gguf_init_empty(); // copy the KV pairs from the input file - gguf_set_kv (ctx_out, ml->ctx_gguf); + gguf_set_kv (ctx_out, ml.ctx_gguf); gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); gguf_set_val_u32(ctx_out, "general.file_type", ftype); @@ -5685,8 +6064,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s int n_attention_wv = 0; int n_feed_forward_w2 = 0; - for (int i = 0; i < ml->n_tensors; ++i) { - struct ggml_tensor * meta = ml->get_tensor_meta(i); + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * meta = ml.get_tensor_meta(i); const std::string name = ggml_get_name(meta); @@ -5722,8 +6101,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s std::vector> f32_conv_buf; // populate the original tensors so we get an initial meta data - for (int i = 0; i < ml->n_tensors; ++i) { - struct ggml_tensor * meta = ml->get_tensor_meta(i); + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * meta = ml.get_tensor_meta(i); gguf_add_tensor(ctx_out, meta); } @@ -5736,19 +6115,21 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // placeholder for the meta data ::zeros(fout, meta_size); - for (int i = 0; i < ml->n_tensors; ++i) { - struct ggml_tensor * tensor = ml->get_tensor_meta(i); + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * tensor = ml.get_tensor_meta(i); const std::string name = ggml_get_name(tensor); - if (read_data.size() < ggml_nbytes(tensor)) { - read_data.resize(ggml_nbytes(tensor)); + if (!ml.use_mmap) { + if (read_data.size() < ggml_nbytes(tensor)) { + read_data.resize(ggml_nbytes(tensor)); + } + tensor->data = read_data.data(); } - tensor->data = read_data.data(); - ml->load_data_for(tensor); + ml.load_data_for(tensor); LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", - ++idx, ml->n_tensors, + ++idx, ml.n_tensors, ggml_get_name(tensor), llama_format_tensor_shape(tensor).c_str(), ggml_type_name(tensor->type)); @@ -5898,9 +6279,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } -// TODO: after the GGUF PR, this likely won't work and needs to be updated static int llama_apply_lora_from_file_internal( - const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads + const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads ) { LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); @@ -5929,7 +6309,7 @@ static int llama_apply_lora_from_file_internal( int32_t lora_alpha; fin.read((char *) &lora_r, sizeof(lora_r)); fin.read((char *) &lora_alpha, sizeof(lora_alpha)); - float scaling = (float)lora_alpha / (float)lora_r; + float scaling = scale * (float)lora_alpha / (float)lora_r; LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling); @@ -6145,9 +6525,10 @@ static int llama_apply_lora_from_file_internal( ggml_set_name(r, "r_cpy"); } - struct ggml_cgraph gf = ggml_build_forward(r); + struct ggml_cgraph * gf = ggml_new_graph(lora_ctx); + ggml_build_forward_expand(gf, r); - ggml_graph_compute_helper(work_buffer, &gf, n_threads); + ggml_graph_compute_helper(work_buffer, gf, n_threads); // we won't need these tensors again, reset the context to save memory ggml_free(lora_ctx); @@ -6176,27 +6557,16 @@ static int llama_apply_lora_from_file_internal( // // interface implementation // - -struct llama_context_params llama_context_default_params() { - struct llama_context_params result = { - /*.seed =*/ LLAMA_DEFAULT_SEED, - /*.n_ctx =*/ 512, - /*.n_batch =*/ 512, +struct llama_model_params llama_model_default_params() { + struct llama_model_params result = { /*.n_gpu_layers =*/ 0, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, - /*.rope_freq_base =*/ 0.0f, - /*.rope_freq_scale =*/ 0.0f, /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, - /*.low_vram =*/ false, - /*.mul_mat_q =*/ true, - /*.f16_kv =*/ true, - /*.logits_all =*/ false, /*.vocab_only =*/ false, /*.use_mmap =*/ true, /*.use_mlock =*/ false, - /*.embedding =*/ false, }; #ifdef GGML_USE_METAL @@ -6206,6 +6576,24 @@ struct llama_context_params llama_context_default_params() { return result; } +struct llama_context_params llama_context_default_params() { + struct llama_context_params result = { + /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.n_ctx =*/ 512, + /*.n_batch =*/ 512, + /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default + /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.rope_freq_base =*/ 0.0f, + /*.rope_freq_scale =*/ 0.0f, + /*.mul_mat_q =*/ true, + /*.f16_kv =*/ true, + /*.logits_all =*/ false, + /*.embedding =*/ false, + }; + + return result; +} + struct llama_model_quantize_params llama_model_quantize_default_params() { struct llama_model_quantize_params result = { /*.nthread =*/ 0, @@ -6261,13 +6649,11 @@ int64_t llama_time_us(void) { struct llama_model * llama_load_model_from_file( const char * path_model, - struct llama_context_params params) { + struct llama_model_params params) { ggml_time_init(); llama_model * model = new llama_model; - ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - unsigned cur_percentage = 0; if (params.progress_callback == NULL) { params.progress_callback_user_data = &cur_percentage; @@ -6284,9 +6670,9 @@ struct llama_model * llama_load_model_from_file( }; } - if (!llama_model_load(path_model, *model, params.n_ctx, params.n_batch, params.n_gpu_layers, - params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale, - params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, + if (!llama_model_load(path_model, *model, params.n_gpu_layers, + params.main_gpu, params.tensor_split, + params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); delete model; @@ -6310,18 +6696,33 @@ struct llama_context * llama_new_context_with_model( llama_context * ctx = new llama_context(*model); + const auto & hparams = model->hparams; + auto & cparams = ctx->cparams; + + cparams.n_batch = params.n_batch; + cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; + cparams.rope_freq_base = params.rope_freq_base == 0 ? hparams.rope_freq_base_train : params.rope_freq_base; + cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale; + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch; + cparams.mul_mat_q = params.mul_mat_q; + if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); + LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; // reserve memory for context buffers - if (!params.vocab_only) { - if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) { + if (!hparams.vocab_only) { + if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -6332,11 +6733,9 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - const auto & hparams = ctx->model.hparams; - // resized during inference if (params.logits_all) { - ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); + ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab); } else { ctx->logits.reserve(hparams.n_vocab); } @@ -6354,26 +6753,29 @@ struct llama_context * llama_new_context_with_model( ctx->alloc = ggml_allocr_new_measure(tensor_alignment); // build worst-case graph - int n_tokens = std::min((int)hparams.n_ctx, params.n_batch); - int n_past = hparams.n_ctx - n_tokens; + int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch); + int n_past = cparams.n_ctx - n_tokens; llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past); + ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0)); + #ifdef GGML_USE_METAL - if (params.n_gpu_layers > 0) { + if (model->n_gpu_layers > 0) { + ggml_metal_log_set_callback(llama_log_callback_default, NULL); + ctx->ctx_metal = ggml_metal_init(1); if (!ctx->ctx_metal) { LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); llama_free(ctx); return NULL; } - ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); - ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); + //ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); + //ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); } #endif // measure memory requirements for the graph size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; - LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); + LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); // recreate allocator with exact memory requirements ggml_allocr_free(ctx->alloc); @@ -6382,28 +6784,46 @@ struct llama_context * llama_new_context_with_model( ctx->alloc = ggml_allocr_new(ctx->buf_alloc.data, ctx->buf_alloc.size, tensor_alignment); #ifdef GGML_USE_METAL if (ctx->ctx_metal) { - ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); + //ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); } #endif #ifdef GGML_USE_CUBLAS - if (params.low_vram) { - LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__); - ggml_cuda_set_scratch_size(0); // disable scratch - } else { - ggml_cuda_set_scratch_size(alloc_size); - LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0); + ggml_cuda_set_scratch_size(alloc_size); + LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0); + + // calculate total VRAM usage + auto add_tensor = [](const ggml_tensor * t, size_t & size) { + if (t->backend == GGML_BACKEND_GPU || t->backend == GGML_BACKEND_GPU_SPLIT) { + size += ggml_nbytes(t); + } + }; + size_t model_vram_size = 0; + for (const auto & kv : model->tensors_by_name) { + add_tensor(kv.second, model_vram_size); } + + size_t kv_vram_size = 0; + add_tensor(ctx->kv_self.k, kv_vram_size); + add_tensor(ctx->kv_self.v, kv_vram_size); + + size_t ctx_vram_size = alloc_size + kv_vram_size; + size_t total_vram_size = model_vram_size + ctx_vram_size; + + LLAMA_LOG_INFO("%s: total VRAM used: %.2f MB (model: %.2f MB, context: %.2f MB)\n", __func__, + total_vram_size / 1024.0 / 1024.0, + model_vram_size / 1024.0 / 1024.0, + ctx_vram_size / 1024.0 / 1024.0); #endif } #ifdef GGML_USE_METAL - if (params.n_gpu_layers > 0) { + if (model->n_gpu_layers > 0) { // this allocates all Metal resources and memory buffers void * data_ptr = NULL; size_t data_size = 0; - if (params.use_mmap) { + if (ctx->model.mapping) { data_ptr = ctx->model.mapping->addr; data_size = ctx->model.mapping->size; } else { @@ -6422,11 +6842,8 @@ struct llama_context * llama_new_context_with_model( return NULL; \ } - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); - - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.data, ctx->buf_compute.size, 0)); - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0)); - + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.data, ctx->buf_alloc.size, 0)); #undef LLAMA_METAL_CHECK_BUF } @@ -6438,8 +6855,10 @@ struct llama_context * llama_new_context_with_model( if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { // Enter a blocking eval loop with dummy input, letting rank=0 drive the process - const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); - while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; + // TODO: needs fix after #3228 + GGML_ASSERT(false && "not implemented"); + //const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); + //while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; llama_backend_free(); exit(1); } @@ -6448,63 +6867,37 @@ struct llama_context * llama_new_context_with_model( return ctx; } -static struct llama_context * llama_init_from_file( - const char * path_model, - struct llama_context_params params) { - struct llama_model * model = llama_load_model_from_file(path_model, params); - if (!model) { - return nullptr; - } - - struct llama_context * ctx = llama_new_context_with_model(model, params); - ctx->model_owner = true; - - return ctx; -} - void llama_free(struct llama_context * ctx) { delete ctx; } -int llama_n_vocab(const struct llama_context * ctx) { - return llama_model_n_vocab(&ctx->model); +const llama_model * llama_get_model(const struct llama_context * ctx) { + return &ctx->model; } int llama_n_ctx(const struct llama_context * ctx) { - return llama_model_n_ctx(&ctx->model); -} - -int llama_n_ctx_train(const struct llama_context * ctx) { - return llama_model_n_ctx_train(&ctx->model); + return ctx->cparams.n_ctx; } -int llama_n_embd(const struct llama_context * ctx) { - return llama_model_n_embd(&ctx->model); +enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { + return model->vocab.type; } -enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) { - return ctx->model.vocab.type; -} - -int llama_model_n_vocab(const struct llama_model * model) { +int llama_n_vocab(const struct llama_model * model) { return model->vocab.id_to_token.size(); } -int llama_model_n_ctx(const struct llama_model * model) { - return model->hparams.n_ctx; -} - -int llama_model_n_ctx_train(const struct llama_model * model) { +int llama_n_ctx_train(const struct llama_model * model) { return model->hparams.n_ctx_train; } -int llama_model_n_embd(const struct llama_model * model) { +int llama_n_embd(const struct llama_model * model) { return model->hparams.n_embd; } int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { return snprintf(buf, buf_size, "%s %s %s", - model->name.c_str(), + llama_model_arch_name(model->arch).c_str(), llama_model_type_name(model->type), llama_model_ftype_name(model->ftype).c_str()); } @@ -6525,6 +6918,10 @@ uint64_t llama_model_n_params(const struct llama_model * model) { return nparams; } +struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) { + return ggml_get_tensor(model->ctx, name); +} + int llama_model_quantize( const char * fname_inp, const char * fname_out, @@ -6538,18 +6935,18 @@ int llama_model_quantize( } } -int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) { +int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, float scale, const char * path_base_model, int n_threads) { try { - return llama_apply_lora_from_file_internal(ctx->model, path_lora, path_base_model, n_threads); + return llama_apply_lora_from_file_internal(ctx->model, path_lora, scale, path_base_model, n_threads); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); return 1; } } -int llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, const char * path_base_model, int n_threads) { +int llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int n_threads) { try { - return llama_apply_lora_from_file_internal(*model, path_lora, path_base_model, n_threads); + return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); return 1; @@ -6557,16 +6954,27 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha } int llama_get_kv_cache_token_count(const struct llama_context * ctx) { - return ctx->kv_self.n; + return ctx->kv_self.head; } -#define LLAMA_MAX_RNG_STATE (64*1024) +void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) { + llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1); +} -void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = time(NULL); - } - ctx->rng.seed(seed); +void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); +} + +void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { + llama_kv_cache_seq_keep(ctx->kv_self, seq_id); +} + +void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta); } // Returns the *maximum* size of the state @@ -6654,6 +7062,16 @@ struct llama_data_file_context : llama_data_context { * */ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { + // TODO: does not support multi-sequence states + { + const auto & kv_self = ctx->kv_self; + for (uint32_t i = 0; i < kv_self.head; ++i) { + GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i); + GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1); + GGML_ASSERT(kv_self.cells[i].has_seq_id(0)); + } + } + // copy rng { std::stringstream rng_ss; @@ -6704,12 +7122,14 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat { const auto & kv_self = ctx->kv_self; const auto & hparams = ctx->model.hparams; + const auto & cparams = ctx->cparams; + const int n_layer = hparams.n_layer; const int n_embd = hparams.n_embd_gqa(); - const int n_ctx = hparams.n_ctx; + const int n_ctx = cparams.n_ctx; const size_t kv_size = kv_self.buf.size; - const int kv_ntok = llama_get_kv_cache_token_count(ctx); + const int kv_ntok = kv_self.head; data_ctx->write(&kv_size, sizeof(kv_size)); data_ctx->write(&kv_ntok, sizeof(kv_ntok)); @@ -6812,9 +7232,11 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { { const auto & kv_self = ctx->kv_self; const auto & hparams = ctx->model.hparams; + const auto & cparams = ctx->cparams; + const int n_layer = hparams.n_layer; const int n_embd = hparams.n_embd_gqa(); - const int n_ctx = hparams.n_ctx; + const int n_ctx = cparams.n_ctx; size_t kv_size; int kv_ntok; @@ -6853,7 +7275,8 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { ggml_free(cpy_ctx); } - ctx->kv_self.n = kv_ntok; + ctx->kv_self.head = kv_ntok; + ctx->kv_self.size = kv_size; } const size_t nread = inp - src; @@ -6948,64 +7371,102 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi int llama_eval( struct llama_context * ctx, - const llama_token * tokens, - int n_tokens, - int n_past, - int n_threads) { - if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { - LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); - return 1; - } + llama_token * tokens, + int32_t n_tokens, + int n_past) { + llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); - // get a more accurate load time, upon first eval - // TODO: fix this - if (!ctx->has_evaluated_once) { - ctx->t_load_us = ggml_time_us() - ctx->t_start_us; - ctx->has_evaluated_once = true; + const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0)); + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } - return 0; + return ret; } int llama_eval_embd( struct llama_context * ctx, - const float * embd, - int n_tokens, - int n_past, - int n_threads) { - if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) { - LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); - return 1; - } + float * embd, + int32_t n_tokens, + int n_past) { + llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); - // get a more accurate load time, upon first eval - // TODO: fix this - if (!ctx->has_evaluated_once) { - ctx->t_load_us = ggml_time_us() - ctx->t_start_us; - ctx->has_evaluated_once = true; + llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, }; + + const int ret = llama_decode_internal(*ctx, batch); + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } - return 0; + return ret; } -int llama_eval_export(struct llama_context * ctx, const char * fname) { - const int n_batch = 1; - const int n_ctx = 512 - n_batch; +void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) { + ctx->cparams.n_threads = n_threads; + ctx->cparams.n_threads_batch = n_threads_batch; +} + +struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens, + llama_pos pos_0, + llama_seq_id seq_id) { + return { + /*n_tokens =*/ n_tokens, + /*tokens =*/ tokens, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*all_pos_0 =*/ pos_0, + /*all_pos_1 =*/ 1, + /*all_seq_id =*/ seq_id, + }; +} - const std::vector tmp(n_batch, llama_token_bos(ctx)); +struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) { + llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; - if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) { - LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); - return 1; + if (embd) { + batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); + } else { + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); } - return 0; + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); + batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens); + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +void llama_batch_free(struct llama_batch batch) { + if (batch.token) free(batch.token); + if (batch.embd) free(batch.embd); + if (batch.pos) free(batch.pos); + if (batch.seq_id) free(batch.seq_id); + if (batch.logits) free(batch.logits); +} + +int llama_decode( + struct llama_context * ctx, + struct llama_batch batch) { + const int ret = llama_decode_internal(*ctx, batch); + if (ret < 0) { + LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); + } + + return ret; } float * llama_get_logits(struct llama_context * ctx) { return ctx->logits.data(); } +float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { + return ctx->logits.data() + i*ctx->model.hparams.n_vocab; +} + float * llama_get_embeddings(struct llama_context * ctx) { return ctx->embedding.data(); } @@ -7033,18 +7494,24 @@ llama_token llama_token_eos(const struct llama_context * ctx) { llama_token llama_token_nl(const struct llama_context * ctx) { return ctx->model.vocab.linefeed_id; } +llama_token llama_token_prefix(const struct llama_context * ctx) { + return ctx->model.vocab.special_prefix_id; +} -int llama_tokenize( - struct llama_context * ctx, - const char * text, - int text_len, - llama_token * tokens, - int n_max_tokens, - bool add_bos) { - return llama_tokenize_with_model(&ctx->model, text, text_len, tokens, n_max_tokens, add_bos); +llama_token llama_token_middle(const struct llama_context * ctx) { + return ctx->model.vocab.special_middle_id; +} + +llama_token llama_token_suffix(const struct llama_context * ctx) { + return ctx->model.vocab.special_suffix_id; } -int llama_tokenize_with_model( +llama_token llama_token_eot(const struct llama_context * ctx) { + return ctx->model.vocab.special_eot_id; +} + + +int llama_tokenize( const struct llama_model * model, const char * text, int text_len, @@ -7065,13 +7532,9 @@ int llama_tokenize_with_model( return res.size(); } -int llama_token_to_piece(const struct llama_context * ctx, llama_token token, char * buf, int length) { - return llama_token_to_piece_with_model(&ctx->model, token, buf, length); -} - // does not write null-terminator to buf -int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) { - if (0 <= token && token < llama_model_n_vocab(model)) { +int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) { + if (0 <= token && token < llama_n_vocab(model)) { if (llama_is_normal_token(model->vocab, token)) { std::string result = model->vocab.id_to_token[token].text; if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) { @@ -7091,7 +7554,7 @@ int llama_token_to_piece_with_model(const struct llama_model * model, llama_toke buf[2] = '\x85'; return 3; } else if (llama_is_control_token(model->vocab, token)) { - ; + // do nothing } else if (llama_is_byte_token(model->vocab, token)) { if (length < 1) { return -1; @@ -7199,12 +7662,12 @@ const std::vector> & llama_internal return ctx->model.tensors_by_name; } -void llama_log_set(llama_log_callback log_callback, void * user_data) { +void llama_log_set(ggml_log_callback log_callback, void * user_data) { g_state.log_callback = log_callback ? log_callback : llama_log_callback_default; g_state.log_callback_user_data = user_data; } -static void llama_log_internal_v(llama_log_level level, const char * format, va_list args) { +static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) { va_list args_copy; va_copy(args_copy, args); char buffer[128]; @@ -7221,14 +7684,14 @@ static void llama_log_internal_v(llama_log_level level, const char * format, va_ va_end(args_copy); } -static void llama_log_internal(llama_log_level level, const char * format, ...) { +static void llama_log_internal(ggml_log_level level, const char * format, ...) { va_list args; va_start(args, format); llama_log_internal_v(level, format, args); va_end(args); } -static void llama_log_callback_default(llama_log_level level, const char * text, void * user_data) { +static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) { (void) level; (void) user_data; if (std::getenv("LLAMA_LOG") != nullptr) { diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.h b/plugins/wasi_nn/thirdparty/ggml/llama.h index 369be048..fd215840 100644 --- a/plugins/wasi_nn/thirdparty/ggml/llama.h +++ b/plugins/wasi_nn/thirdparty/ggml/llama.h @@ -37,6 +37,8 @@ #define LLAMA_DEFAULT_SEED 0xFFFFFFFF +#define LLAMA_MAX_RNG_STATE (64*1024) + #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN @@ -60,13 +62,9 @@ extern "C" { struct llama_model; struct llama_context; - typedef int llama_token; - - enum llama_log_level { - LLAMA_LOG_LEVEL_ERROR = 2, - LLAMA_LOG_LEVEL_WARN = 3, - LLAMA_LOG_LEVEL_INFO = 4 - }; + typedef int32_t llama_pos; + typedef int32_t llama_token; + typedef int32_t llama_seq_id; enum llama_vocab_type { LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece @@ -86,24 +84,24 @@ extern "C" { // model file types enum llama_ftype { LLAMA_FTYPE_ALL_F32 = 0, - LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 - // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed - // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed - LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors + LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed + // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed + LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; @@ -122,41 +120,68 @@ extern "C" { typedef void (*llama_progress_callback)(float progress, void *ctx); - struct llama_context_params { - uint32_t seed; // RNG seed, -1 for random - int32_t n_ctx; // text context - int32_t n_batch; // prompt processing batch size - int32_t n_gpu_layers; // number of layers to store in VRAM - int32_t main_gpu; // the GPU that is used for scratch and small tensors - + // Input data for llama_decode + // A llama_batch object can contain input about one or many sequences + // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens + // + // - token : the token ids of the input (used when embd is NULL) + // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) + // - pos : the positions of the respective token in the sequence + // - seq_id : the sequence to which the respective token belongs + // - logits : if zero, the logits for the respective token will not be output + // + typedef struct llama_batch { + int32_t n_tokens; + + llama_token * token; + float * embd; + llama_pos * pos; + llama_seq_id * seq_id; + int8_t * logits; + + // NOTE: helpers for smooth API transition - can be deprecated in the future + // for future-proof code, use the above fields instead and ignore everything below + // + // pos[i] = all_pos_0 + i*all_pos_1 + // + llama_pos all_pos_0; // used if pos == NULL + llama_pos all_pos_1; // used if pos == NULL + llama_seq_id all_seq_id; // used if seq_id == NULL + } llama_batch; + + struct llama_model_params { + int32_t n_gpu_layers; // number of layers to store in VRAM + int32_t main_gpu; // the GPU that is used for scratch and small tensors const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) - // ref: https://github.com/ggerganov/llama.cpp/pull/2054 - float rope_freq_base; // RoPE base frequency - float rope_freq_scale; // RoPE frequency scaling factor - // called with a progress value between 0 and 1, pass NULL to disable llama_progress_callback progress_callback; // context pointer passed to the progress callback void * progress_callback_user_data; // Keep the booleans together to avoid misalignment during copy-by-value. - bool low_vram; // if true, reduce VRAM usage at the cost of performance - bool mul_mat_q; // if true, use experimental mul_mat_q kernels - bool f16_kv; // use fp16 for KV cache - bool logits_all; // the llama_eval() call computes all logits, not just the last one bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible bool use_mlock; // force system to keep model in RAM - bool embedding; // embedding mode only }; - // Signature for logging events - // Note that text includes the new line character at the end for most events. - // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it - // if it exists. - // It might not exist for progress report where '.' is output repeatedly. - typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data); + struct llama_context_params { + uint32_t seed; // RNG seed, -1 for random + uint32_t n_ctx; // text context, 0 = from model + uint32_t n_batch; // prompt processing maximum batch size + uint32_t n_threads; // number of threads to use for generation + uint32_t n_threads_batch; // number of threads to use for batch processing + + // ref: https://github.com/ggerganov/llama.cpp/pull/2054 + float rope_freq_base; // RoPE base frequency, 0 = from model + float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model + + // Keep the booleans together to avoid misalignment during copy-by-value. + bool mul_mat_q; // if true, use experimental mul_mat_q kernels + bool f16_kv; // use fp16 for KV cache, fp32 otherwise + bool logits_all; // the llama_eval() call computes all logits, not just the last one + bool embedding; // embedding mode only + }; // model quantization parameters typedef struct llama_model_quantize_params { @@ -215,6 +240,8 @@ extern "C" { int32_t n_eval; }; + // Helpers for getting default parameters + LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); @@ -228,7 +255,7 @@ extern "C" { LLAMA_API struct llama_model * llama_load_model_from_file( const char * path_model, - struct llama_context_params params); + struct llama_model_params params); LLAMA_API void llama_free_model(struct llama_model * model); @@ -245,25 +272,28 @@ extern "C" { LLAMA_API bool llama_mmap_supported (void); LLAMA_API bool llama_mlock_supported(void); - LLAMA_API int llama_n_vocab (const struct llama_context * ctx); + LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); + LLAMA_API int llama_n_ctx (const struct llama_context * ctx); - LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx); - LLAMA_API int llama_n_embd (const struct llama_context * ctx); - LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx); + LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); - LLAMA_API int llama_model_n_vocab (const struct llama_model * model); - LLAMA_API int llama_model_n_ctx (const struct llama_model * model); - LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model); - LLAMA_API int llama_model_n_embd (const struct llama_model * model); + LLAMA_API int llama_n_vocab (const struct llama_model * model); + LLAMA_API int llama_n_ctx_train(const struct llama_model * model); + LLAMA_API int llama_n_embd (const struct llama_model * model); // Get a string describing the model type LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); + // Returns the total size of all the tensors in the model in bytes LLAMA_API uint64_t llama_model_size(const struct llama_model * model); + // Returns the total number of parameters in the model LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); + // Get a llama model tensor + LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name); + // Returns 0 on success LLAMA_API int llama_model_quantize( const char * fname_inp, @@ -279,21 +309,65 @@ extern "C" { LLAMA_API DEPRECATED(int llama_apply_lora_from_file( struct llama_context * ctx, const char * path_lora, + float scale, const char * path_base_model, int n_threads), - "please use llama_model_apply_lora_from_file instead"); + "use llama_model_apply_lora_from_file instead"); LLAMA_API int llama_model_apply_lora_from_file( const struct llama_model * model, - const char * path_lora, - const char * path_base_model, - int n_threads); + const char * path_lora, + float scale, + const char * path_base_model, + int n_threads); + + // + // KV cache + // // Returns the number of tokens in the KV cache - LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx); + LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx), + "avoid using this, it will be removed in the future, instead - count the tokens in user code"); - // Sets the current rng seed. - LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); + // Remove all tokens data of cells in [c0, c1) + LLAMA_API void llama_kv_cache_tokens_rm( + struct llama_context * ctx, + int32_t c0, + int32_t c1); + + // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + LLAMA_API void llama_kv_cache_seq_rm( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1); + + // Copy all tokens that belong to the specified sequence to another sequence + // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + LLAMA_API void llama_kv_cache_seq_cp( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1); + + // Removes all tokens that do not belong to the specified sequence + LLAMA_API void llama_kv_cache_seq_keep( + struct llama_context * ctx, + llama_seq_id seq_id); + + // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) + // If the KV cache is RoPEd, the KV data is updated accordingly + LLAMA_API void llama_kv_cache_seq_shift( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta); + + // + // State / sessions + // // Returns the maximum size in bytes of the state (rng, logits, embedding // and kv_cache) - will often be smaller after compacting tokens @@ -302,48 +376,102 @@ extern "C" { // Copies the state to the specified destination address. // Destination needs to have allocated enough memory. // Returns the number of bytes copied - LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst); + LLAMA_API size_t llama_copy_state_data( + struct llama_context * ctx, + uint8_t * dst); // Set the state reading from the specified address // Returns the number of bytes read - LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src); + LLAMA_API size_t llama_set_state_data( + struct llama_context * ctx, + uint8_t * src); // Save/load session file - LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out); - LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count); + LLAMA_API bool llama_load_session_file( + struct llama_context * ctx, + const char * path_session, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); - // Run the llama inference to obtain the logits and probabilities for the next token. + LLAMA_API bool llama_save_session_file( + struct llama_context * ctx, + const char * path_session, + const llama_token * tokens, + size_t n_token_count); + + // + // Decoding + // + + // Run the llama inference to obtain the logits and probabilities for the next token(s). // tokens + n_tokens is the provided batch of new tokens to process // n_past is the number of tokens to use from previous eval calls // Returns 0 on success - LLAMA_API int llama_eval( + // DEPRECATED: use llama_decode() instead + LLAMA_API DEPRECATED(int llama_eval( struct llama_context * ctx, - const llama_token * tokens, - int n_tokens, - int n_past, - int n_threads); + llama_token * tokens, + int32_t n_tokens, + int n_past), + "use llama_decode() instead"); // Same as llama_eval, but use float matrix input directly. - LLAMA_API int llama_eval_embd( + // DEPRECATED: use llama_decode() instead + LLAMA_API DEPRECATED(int llama_eval_embd( struct llama_context * ctx, - const float * embd, - int n_tokens, - int n_past, - int n_threads); + float * embd, + int32_t n_tokens, + int n_past), + "use llama_decode() instead"); + + // Return batch for single sequence of tokens starting at pos_0 + // + // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it + // + LLAMA_API struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens, + llama_pos pos_0, + llama_seq_id seq_id); + + // Allocates a batch of tokens on the heap + // The batch has to be freed with llama_batch_free() + // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) + // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token + // The rest of the llama_batch members are allocated with size n_tokens + // All members are left uninitialized + LLAMA_API struct llama_batch llama_batch_init( + int32_t n_tokens, + int32_t embd); + + // Frees a batch of tokens allocated with llama_batch_init() + LLAMA_API void llama_batch_free(struct llama_batch batch); + + // Positive return values does not mean a fatal error, but rather a warning. + // 0 - success + // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) + // < 0 - error + LLAMA_API int llama_decode( + struct llama_context * ctx, + struct llama_batch batch); - // Export a static computation graph for context of 511 and batch size of 1 - // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these - // parameters here to keep things simple - // IMPORTANT: do not use for anything else other than debugging and testing! - LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname); + // Set the number of threads used for decoding + // n_threads is the number of threads used for generation (single token) + // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) + LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); // Token logits obtained from the last call to llama_eval() // The logits for the last token are stored in the last row - // Can be mutated in order to change the probabilities of the next token - // Rows: n_tokens + // Logits for which llama_batch.logits[i] == 0 are undefined + // Rows: n_tokens provided with llama_batch // Cols: n_vocab LLAMA_API float * llama_get_logits(struct llama_context * ctx); + // Logits for the ith token. Equivalent to: + // llama_get_logits(ctx) + i*n_vocab + LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); + // Get the embeddings for the input // shape: [n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); @@ -362,6 +490,11 @@ extern "C" { LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line + // codellama infill tokens + LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of infill prefix + LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of infill middle + LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of infill suffix + LLAMA_API llama_token llama_token_eot (const struct llama_context * ctx); // End of infill middle // // Tokenization @@ -372,14 +505,6 @@ extern "C" { // Returns the number of tokens on success, no more than n_max_tokens // Returns a negative number on failure - the number of tokens that would have been returned LLAMA_API int llama_tokenize( - struct llama_context * ctx, - const char * text, - int text_len, - llama_token * tokens, - int n_max_tokens, - bool add_bos); - - LLAMA_API int llama_tokenize_with_model( const struct llama_model * model, const char * text, int text_len, @@ -392,12 +517,6 @@ extern "C" { // Does not write null terminator to the buffer. // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. LLAMA_API int llama_token_to_piece( - const struct llama_context * ctx, - llama_token token, - char * buf, - int length); - - LLAMA_API int llama_token_to_piece_with_model( const struct llama_model * model, llama_token token, char * buf, @@ -420,11 +539,25 @@ extern "C" { // Sampling functions // + // Sets the current rng seed. + LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); + /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty); + LLAMA_API void llama_sample_repetition_penalty( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t last_tokens_size, + float penalty); /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + LLAMA_API void llama_sample_frequency_and_presence_penalties( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t last_tokens_size, + float alpha_frequency, + float alpha_presence); /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. @@ -437,23 +570,54 @@ extern "C" { float scale); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API void llama_sample_softmax( + struct llama_context * ctx, + llama_token_data_array * candidates); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep); + LLAMA_API void llama_sample_top_k( + struct llama_context * ctx, + llama_token_data_array * candidates, + int k, + size_t min_keep); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); + LLAMA_API void llama_sample_top_p( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep); + LLAMA_API void llama_sample_tail_free( + struct llama_context * ctx, + llama_token_data_array * candidates, + float z, + size_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); - LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); + LLAMA_API void llama_sample_typical( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); + + LLAMA_API void llama_sample_temp( + struct llama_context * ctx, + llama_token_data_array * candidates, + float temp); + + LLAMA_API DEPRECATED(void llama_sample_temperature( + struct llama_context * ctx, + llama_token_data_array * candidates, + float temp), + "use llama_sample_temp instead"); /// @details Apply constraints from grammar - LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); + LLAMA_API void llama_sample_grammar( + struct llama_context * ctx, + llama_token_data_array * candidates, + const struct llama_grammar * grammar); /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. @@ -461,23 +625,41 @@ extern "C" { /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu); + LLAMA_API llama_token llama_sample_token_mirostat( + struct llama_context * ctx, + llama_token_data_array * candidates, + float tau, + float eta, + int m, + float * mu); /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); + LLAMA_API llama_token llama_sample_token_mirostat_v2( + struct llama_context * ctx, + llama_token_data_array * candidates, + float tau, + float eta, + float * mu); /// @details Selects the token with the highest probability. - LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sample_token_greedy( + struct llama_context * ctx, + llama_token_data_array * candidates); /// @details Randomly selects a token from the candidates based on their probabilities. - LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sample_token( + struct llama_context * ctx, + llama_token_data_array * candidates); /// @details Accepts the sampled token into the grammar - LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); + LLAMA_API void llama_grammar_accept_token( + struct llama_context * ctx, + struct llama_grammar * grammar, + llama_token token); // // Beam search @@ -485,9 +667,10 @@ extern "C" { struct llama_beam_view { const llama_token * tokens; + size_t n_tokens; - float p; // Cumulative beam probability (renormalized relative to all beams) - bool eob; // Callback should set this to true when a beam is at end-of-beam. + float p; // Cumulative beam probability (renormalized relative to all beams) + bool eob; // Callback should set this to true when a beam is at end-of-beam. }; // Passed to beam_search_callback function. @@ -496,9 +679,10 @@ extern "C" { // These pointers are valid only during the synchronous callback, so should not be saved. struct llama_beams_state { struct llama_beam_view * beam_views; + size_t n_beams; // Number of elements in beam_views[]. size_t common_prefix_length; // Current max length of prefix tokens shared by all beams. - bool last_call; // True iff this is the last callback invocation. + bool last_call; // True iff this is the last callback invocation. }; // Type of pointer to the beam_search_callback function. @@ -513,11 +697,17 @@ extern "C" { /// @param n_beams Number of beams to use. /// @param n_past Number of tokens already evaluated. /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier. - /// @param n_threads Number of threads as passed to llama_eval(). - LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads); + LLAMA_API void llama_beam_search( + struct llama_context * ctx, + llama_beam_search_callback_fn_t callback, + void * callback_data, + size_t n_beams, + int n_past, + int n_predict); // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); + LLAMA_API void llama_print_timings(struct llama_context * ctx); LLAMA_API void llama_reset_timings(struct llama_context * ctx); @@ -526,7 +716,7 @@ extern "C" { // Set callback for all future logging events. // If this is not called, or NULL is supplied, everything is output on stderr. - LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data); + LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx); diff --git a/plugins/wasi_nn/thirdparty/ggml/log.h b/plugins/wasi_nn/thirdparty/ggml/log.h index 18f3b976..b8953fdc 100644 --- a/plugins/wasi_nn/thirdparty/ggml/log.h +++ b/plugins/wasi_nn/thirdparty/ggml/log.h @@ -225,31 +225,31 @@ enum LogTriState // USE LOG() INSTEAD // #ifndef _MSC_VER - #define LOG_IMPL(str, ...) \ - { \ + #define LOG_IMPL(str, ...) \ + do { \ if (LOG_TARGET != nullptr) \ { \ fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \ fflush(LOG_TARGET); \ } \ - } + } while (0) #else - #define LOG_IMPL(str, ...) \ - { \ + #define LOG_IMPL(str, ...) \ + do { \ if (LOG_TARGET != nullptr) \ { \ fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \ fflush(LOG_TARGET); \ } \ - } + } while (0) #endif // INTERNAL, DO NOT USE // USE LOG_TEE() INSTEAD // #ifndef _MSC_VER - #define LOG_TEE_IMPL(str, ...) \ - { \ + #define LOG_TEE_IMPL(str, ...) \ + do { \ if (LOG_TARGET != nullptr) \ { \ fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \ @@ -260,10 +260,10 @@ enum LogTriState fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL, __VA_ARGS__); \ fflush(LOG_TEE_TARGET); \ } \ - } + } while (0) #else - #define LOG_TEE_IMPL(str, ...) \ - { \ + #define LOG_TEE_IMPL(str, ...) \ + do { \ if (LOG_TARGET != nullptr) \ { \ fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \ @@ -274,7 +274,7 @@ enum LogTriState fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL "", ##__VA_ARGS__); \ fflush(LOG_TEE_TARGET); \ } \ - } + } while (0) #endif // The '\0' as a last argument, is a trick to bypass the silly @@ -435,41 +435,41 @@ inline FILE *log_handler() { return log_handler1_impl(); } inline void log_test() { log_disable(); - LOG("01 Hello World to nobody, because logs are disabled!\n") + LOG("01 Hello World to nobody, because logs are disabled!\n"); log_enable(); - LOG("02 Hello World to default output, which is \"%s\" ( Yaaay, arguments! )!\n", LOG_STRINGIZE(LOG_TARGET)) - LOG_TEE("03 Hello World to **both** default output and " LOG_TEE_TARGET_STRING "!\n") + LOG("02 Hello World to default output, which is \"%s\" ( Yaaay, arguments! )!\n", LOG_STRINGIZE(LOG_TARGET)); + LOG_TEE("03 Hello World to **both** default output and " LOG_TEE_TARGET_STRING "!\n"); log_set_target(stderr); - LOG("04 Hello World to stderr!\n") - LOG_TEE("05 Hello World TEE with double printing to stderr prevented!\n") + LOG("04 Hello World to stderr!\n"); + LOG_TEE("05 Hello World TEE with double printing to stderr prevented!\n"); log_set_target(LOG_DEFAULT_FILE_NAME); - LOG("06 Hello World to default log file!\n") + LOG("06 Hello World to default log file!\n"); log_set_target(stdout); - LOG("07 Hello World to stdout!\n") + LOG("07 Hello World to stdout!\n"); log_set_target(LOG_DEFAULT_FILE_NAME); - LOG("08 Hello World to default log file again!\n") + LOG("08 Hello World to default log file again!\n"); log_disable(); - LOG("09 Hello World _1_ into the void!\n") + LOG("09 Hello World _1_ into the void!\n"); log_enable(); - LOG("10 Hello World back from the void ( you should not see _1_ in the log or the output )!\n") + LOG("10 Hello World back from the void ( you should not see _1_ in the log or the output )!\n"); log_disable(); log_set_target("llama.anotherlog.log"); - LOG("11 Hello World _2_ to nobody, new target was selected but logs are still disabled!\n") + LOG("11 Hello World _2_ to nobody, new target was selected but logs are still disabled!\n"); log_enable(); - LOG("12 Hello World this time in a new file ( you should not see _2_ in the log or the output )?\n") + LOG("12 Hello World this time in a new file ( you should not see _2_ in the log or the output )?\n"); log_set_target("llama.yetanotherlog.log"); - LOG("13 Hello World this time in yet new file?\n") + LOG("13 Hello World this time in yet new file?\n"); log_set_target(log_filename_generator("llama_autonamed", "log")); - LOG("14 Hello World in log with generated filename!\n") + LOG("14 Hello World in log with generated filename!\n"); #ifdef _MSC_VER - LOG_TEE("15 Hello msvc TEE without arguments\n") - LOG_TEE("16 Hello msvc TEE with (%d)(%s) arguments\n", 1, "test") - LOG_TEELN("17 Hello msvc TEELN without arguments\n") - LOG_TEELN("18 Hello msvc TEELN with (%d)(%s) arguments\n", 1, "test") - LOG("19 Hello msvc LOG without arguments\n") - LOG("20 Hello msvc LOG with (%d)(%s) arguments\n", 1, "test") - LOGLN("21 Hello msvc LOGLN without arguments\n") - LOGLN("22 Hello msvc LOGLN with (%d)(%s) arguments\n", 1, "test") + LOG_TEE("15 Hello msvc TEE without arguments\n"); + LOG_TEE("16 Hello msvc TEE with (%d)(%s) arguments\n", 1, "test"); + LOG_TEELN("17 Hello msvc TEELN without arguments\n"); + LOG_TEELN("18 Hello msvc TEELN with (%d)(%s) arguments\n", 1, "test"); + LOG("19 Hello msvc LOG without arguments\n"); + LOG("20 Hello msvc LOG with (%d)(%s) arguments\n", 1, "test"); + LOGLN("21 Hello msvc LOGLN without arguments\n"); + LOGLN("22 Hello msvc LOGLN with (%d)(%s) arguments\n", 1, "test"); #endif } @@ -542,7 +542,7 @@ inline void log_dump_cmdline_impl(int argc, char **argv) buf << " " << argv[i]; } } - LOGLN("Cmd:%s", buf.str().c_str()) + LOGLN("Cmd:%s", buf.str().c_str()); } #define log_tostr(var) log_var_to_string_impl(var).c_str() @@ -620,10 +620,10 @@ inline std::string log_var_to_string_impl(const std::vector & var) #define LOGLN(...) // dummy stub #undef LOG_TEE -#define LOG_TEE(...) fprintf(stderr, __VA_ARGS__); // convert to normal fprintf +#define LOG_TEE(...) fprintf(stderr, __VA_ARGS__) // convert to normal fprintf #undef LOG_TEELN -#define LOG_TEELN(...) fprintf(stderr, __VA_ARGS__); // convert to normal fprintf +#define LOG_TEELN(...) fprintf(stderr, __VA_ARGS__) // convert to normal fprintf #undef LOG_DISABLE #define LOG_DISABLE() // dummy stub From 172576685a81540ea8a41da9907a2914c9731cd4 Mon Sep 17 00:00:00 2001 From: Wck-iipi <110763795+Wck-iipi@users.noreply.github.com> Date: Fri, 6 Oct 2023 15:54:18 +0530 Subject: [PATCH 162/623] [Plugin] Added imgproc 9 functions to opencvmini plugin. (#2794) * Added imgproc's 9 functions. Added BilateralFilter, BoxFilter, MiniDilate, MiniErode, GaussianBlur, MiniLaplacian, MedianBlur, PyrDown, PyrUp Kernel is now added through function. Added WasmEdgeOpenCVMiniEmptyMat Signed-off-by: wck-iipi <21dcs006@nith.ac.in> * Changed Expect in opencvmini_func.h Signed-off-by: wck-iipi <21dcs006@nith.ac.in> * Changed OpenCVMiniEmptyMat in opencvmini_func.cpp Signed-off-by: wck-iipi <21dcs006@nith.ac.in> * Updated tests for opencvmini Signed-off-by: wck-iipi <21dcs006@nith.ac.in> --------- Signed-off-by: wck-iipi <21dcs006@nith.ac.in> --- .../wasmedge_opencvmini/opencvmini_func.cpp | 107 ++++++++++++++++++ plugins/wasmedge_opencvmini/opencvmini_func.h | 92 +++++++++++++++ .../wasmedge_opencvmini/opencvmini_module.cpp | 20 ++++ .../wasmedge_opencvmini.cpp | 2 +- 4 files changed, 220 insertions(+), 1 deletion(-) diff --git a/plugins/wasmedge_opencvmini/opencvmini_func.cpp b/plugins/wasmedge_opencvmini/opencvmini_func.cpp index 7a926946..7173a0ed 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_func.cpp +++ b/plugins/wasmedge_opencvmini/opencvmini_func.cpp @@ -3,7 +3,11 @@ #include "opencvmini_func.h" #include "common/defines.h" +#include "common/errcode.h" +#include +#include +#include #include #include #include @@ -67,6 +71,109 @@ Expect WasmEdgeOpenCVMiniBlur::body(const Runtime::CallingFrame &, return Env.insertMat(Dst); } +Expect +WasmEdgeOpenCVMiniBilateralFilter::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, uint32_t D, + double SigmaColor, double SigmaSpace) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::bilateralFilter(*Src, Dst, D, SigmaColor, SigmaSpace); + } + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniBoxFilter::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, uint32_t Ddepth, + uint32_t KernelWidth, uint32_t KernelHeight) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::boxFilter(*Src, Dst, Ddepth, cv::Size(KernelWidth, KernelHeight)); + } + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniEmptyMat::body(const Runtime::CallingFrame &) { + cv::Mat Kernel; + return Env.insertMat(Kernel); +} + +Expect WasmEdgeOpenCVMiniDilate::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, + uint32_t KernelMatKey) { + cv::Mat Dst; + auto Kernel = Env.getMat(KernelMatKey); + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::dilate(*Src, Dst, *Kernel); + } + return Env.insertMat(Dst); +} + +Expect WasmEdgeOpenCVMiniErode::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, + uint32_t KernelMatKey) { + cv::Mat Dst; + auto Kernel = Env.getMat(KernelMatKey); + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::erode(*Src, Dst, *Kernel); + } + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniGaussianBlur::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, uint32_t KernelWidth, + uint32_t KernelHeight, double SigmaX) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::GaussianBlur(*Src, Dst, cv::Size(KernelWidth, KernelHeight), SigmaX); + } + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniLaplacian::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, uint32_t Ddepth) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::Laplacian(*Src, Dst, Ddepth); + } + return Env.insertMat(Dst); +} + +Expect +WasmEdgeOpenCVMiniMedianBlur::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, uint32_t Ksize) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::medianBlur(*Src, Dst, Ksize); + } + return Env.insertMat(Dst); +} + +Expect WasmEdgeOpenCVMiniPyrDown::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, + uint32_t KernelWidth, + uint32_t KernelHeight) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::pyrDown(*Src, Dst, cv::Size(KernelWidth, KernelHeight)); + } + return Env.insertMat(Dst); +} + +Expect WasmEdgeOpenCVMiniPyrUp::body(const Runtime::CallingFrame &, + uint32_t SrcMatKey, + uint32_t KernelWidth, + uint32_t KernelHeight) { + cv::Mat Dst; + if (auto Src = Env.getMat(SrcMatKey); Src) { + cv::pyrUp(*Src, Dst, cv::Size(KernelWidth, KernelHeight)); + } + return Env.insertMat(Dst); +} + Expect WasmEdgeOpenCVMiniImwrite::body(const Runtime::CallingFrame &Frame, uint32_t TargetFileNamePtr, uint32_t TargetFileNameLen, diff --git a/plugins/wasmedge_opencvmini/opencvmini_func.h b/plugins/wasmedge_opencvmini/opencvmini_func.h index f86f07c9..80953b7f 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_func.h +++ b/plugins/wasmedge_opencvmini/opencvmini_func.h @@ -59,6 +59,16 @@ class WasmEdgeOpenCVMiniBlur uint32_t KernelWidth, uint32_t KernelHeight); }; +class WasmEdgeOpenCVMiniBilateralFilter + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniBilateralFilter(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t SrcMatKey, + uint32_t D, double SigmaColor, double SigmaSpace); +}; + class WasmEdgeOpenCVMiniImwrite : public WasmEdgeOpenCVMini { public: @@ -91,6 +101,88 @@ class WasmEdgeOpenCVMiniBilinearSampling uint32_t OutImgW, uint32_t OutImgH); }; +class WasmEdgeOpenCVMiniBoxFilter + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniBoxFilter(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t Ddepth, uint32_t KernelWidth, + uint32_t KernelHeight); +}; + +class WasmEdgeOpenCVMiniEmptyMat + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniEmptyMat(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &); +}; + +class WasmEdgeOpenCVMiniDilate + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniDilate(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t KernelMatKey); +}; + +class WasmEdgeOpenCVMiniErode + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniErode(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t Kernel); +}; + +class WasmEdgeOpenCVMiniGaussianBlur + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniGaussianBlur(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t KernelWidth, uint32_t KernelHeight, + double SigmaX); +}; + +class WasmEdgeOpenCVMiniLaplacian + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniLaplacian(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t Ddepth); +}; + +class WasmEdgeOpenCVMiniMedianBlur + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniMedianBlur(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t Ksize); +}; + +class WasmEdgeOpenCVMiniPyrDown + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniPyrDown(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t KernelWidth, uint32_t KernelHeight); +}; + +class WasmEdgeOpenCVMiniPyrUp + : public WasmEdgeOpenCVMini { +public: + WasmEdgeOpenCVMiniPyrUp(WasmEdgeOpenCVMiniEnvironment &HostEnv) + : WasmEdgeOpenCVMini(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SrcMatKey, + uint32_t KernelWidth, uint32_t KernelHeight); +}; + class WasmEdgeOpenCVMiniRectangle : public WasmEdgeOpenCVMini { public: diff --git a/plugins/wasmedge_opencvmini/opencvmini_module.cpp b/plugins/wasmedge_opencvmini/opencvmini_module.cpp index c515889f..5aac89db 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_module.cpp +++ b/plugins/wasmedge_opencvmini/opencvmini_module.cpp @@ -21,6 +21,26 @@ WasmEdgeOpenCVMiniModule::WasmEdgeOpenCVMiniModule() addHostFunc("wasmedge_opencvmini_blur", std::make_unique(Env)); + + addHostFunc("wasmedge_opencvmini_bilateral_filter", + std::make_unique(Env)); + + addHostFunc("wasmedge_opencvmini_box_filter", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmin_dilate", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_erode", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_gaussian_blur", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_Laplacian", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_median_blur", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_pyrDown", + std::make_unique(Env)); + addHostFunc("wasmedge_opencvmini_pyrUp", + std::make_unique(Env)); addHostFunc("wasmedge_opencvmini_normalize", std::make_unique(Env)); addHostFunc("wasmedge_opencvmini_bilinear_sampling", diff --git a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp index 0440a680..926e1330 100644 --- a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp +++ b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp @@ -36,7 +36,7 @@ TEST(WasmEdgeOpecvminiTest, Module) { auto *ImgMod = dynamic_cast(createModule()); EXPECT_FALSE(ImgMod == nullptr); - EXPECT_EQ(ImgMod->getFuncExportNum(), 10U); + EXPECT_EQ(ImgMod->getFuncExportNum(), 19U); EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_imdecode"), nullptr); EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_imencode"), nullptr); EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_rectangle"), nullptr); From f78ab8b8a941f047138b250ec8a29d6a1dddd427 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 4 Oct 2023 17:15:18 +0800 Subject: [PATCH 163/623] [WASI-NN] Unified preload options with case-insensitive matching Signed-off-by: dm4 --- plugins/wasi_nn/wasinnenv.cpp | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 92ee74d7..af567393 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -17,16 +17,16 @@ create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { } std::map BackendMap = { - {"OpenVINO"sv, Backend::OpenVINO}, - {"ONNX"sv, Backend::ONNX}, - {"Tensorflow"sv, Backend::Tensorflow}, - {"PyTorch"sv, Backend::PyTorch}, - {"TensorflowLite"sv, Backend::TensorflowLite}, - {"Autodetect"sv, Backend::Autodetect}, - {"GGML"sv, Backend::GGML}}; + {"openvino"sv, Backend::OpenVINO}, + {"onnx"sv, Backend::ONNX}, + {"tensorflow"sv, Backend::Tensorflow}, + {"pytorch"sv, Backend::PyTorch}, + {"tensorflowlite"sv, Backend::TensorflowLite}, + {"autodetect"sv, Backend::Autodetect}, + {"ggml"sv, Backend::GGML}}; std::map DeviceMap = { - {"CPU"sv, Device::CPU}, {"GPU"sv, Device::GPU}, {"TPU"sv, Device::TPU}}; + {"cpu"sv, Device::CPU}, {"gpu"sv, Device::GPU}, {"tpu"sv, Device::TPU}}; bool load(const std::filesystem::path &Path, std::vector &Data) { std::ifstream File(Path, std::ios::binary); @@ -61,6 +61,10 @@ WasiNNEnvironment::WasiNNEnvironment() noexcept { } std::vector> Models; Models.reserve(Paths.size()); + std::transform(Encode.begin(), Encode.end(), Encode.begin(), + [](unsigned char c) { return std::tolower(c); }); + std::transform(Target.begin(), Target.end(), Target.begin(), + [](unsigned char c) { return std::tolower(c); }); auto Backend = BackendMap.find(Encode); auto Device = DeviceMap.find(Target); if (Backend != BackendMap.end() && Device != DeviceMap.end()) { From 8aec1f92b361804edff4b88fbc94695fcd3bb43b Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 5 Oct 2023 10:43:51 +0800 Subject: [PATCH 164/623] [WASI-NN] Add AUTO device support in ggml backend Signed-off-by: dm4 --- plugins/wasi_nn/types.h | 2 +- plugins/wasi_nn/wasinnenv.cpp | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index 6e2a8128..98eca041 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -22,7 +22,7 @@ enum class ErrNo : uint32_t { enum class TensorType : uint8_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; -enum class Device : uint32_t { CPU = 0, GPU = 1, TPU = 2 }; +enum class Device : uint32_t { CPU = 0, GPU = 1, TPU = 2, AUTO = 3 }; enum class Backend : uint8_t { OpenVINO = 0, diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index af567393..6c90ae54 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -25,8 +25,10 @@ std::map BackendMap = { {"autodetect"sv, Backend::Autodetect}, {"ggml"sv, Backend::GGML}}; -std::map DeviceMap = { - {"cpu"sv, Device::CPU}, {"gpu"sv, Device::GPU}, {"tpu"sv, Device::TPU}}; +std::map DeviceMap = {{"cpu"sv, Device::CPU}, + {"gpu"sv, Device::GPU}, + {"tpu"sv, Device::TPU}, + {"auto"sv, Device::AUTO}}; bool load(const std::filesystem::path &Path, std::vector &Data) { std::ifstream File(Path, std::ios::binary); @@ -62,9 +64,9 @@ WasiNNEnvironment::WasiNNEnvironment() noexcept { std::vector> Models; Models.reserve(Paths.size()); std::transform(Encode.begin(), Encode.end(), Encode.begin(), - [](unsigned char c) { return std::tolower(c); }); + [](unsigned char C) { return std::tolower(C); }); std::transform(Target.begin(), Target.end(), Target.begin(), - [](unsigned char c) { return std::tolower(c); }); + [](unsigned char C) { return std::tolower(C); }); auto Backend = BackendMap.find(Encode); auto Device = DeviceMap.find(Target); if (Backend != BackendMap.end() && Device != DeviceMap.end()) { From 8a413d51db69b5777fa82f7a0cd8f7adcdd20166 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 12 Oct 2023 19:53:12 +0800 Subject: [PATCH 165/623] [WASI-NN] Add STREAM_TO_STDOUT support in ggml backend Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index fa241c33..96e305b1 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -239,10 +239,18 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { break; } - // Append the new token. - CxtRef.LlamaOutputs += + std::string NextToken = llama_token_to_piece(GraphRef.LlamaContext, NewTokenId); + // When setting STREAM_TO_STDOUT, we print the output to stdout. + const char *StreamOutput = std::getenv("STREAM_TO_STDOUT"); + if (StreamOutput != nullptr) { + std::cout << NextToken << std::flush; + } + + // Append the new token. + CxtRef.LlamaOutputs += NextToken; + // Prepare the next batch LlamaBatch.n_tokens = 0; From 5965257fe82bb20983df4f2d2129b3136a4c978c Mon Sep 17 00:00:00 2001 From: Saikat Dey <57017288+notfathomless@users.noreply.github.com> Date: Fri, 13 Oct 2023 20:19:59 +0530 Subject: [PATCH 166/623] [Plugin] zlib: Changed a few naming style issues for the zlib plugin tests. (#2947) Signed-off-by: Saikat Dey --- test/plugins/wasmedge_zlib/wasmedge_zlib.cpp | 36 ++++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp index d47474d8..d526178a 100644 --- a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp +++ b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp @@ -37,8 +37,8 @@ void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, std::fill_n(MemInst.getPointer(Offset), Cnt, C); } -static constexpr size_t DATA_SIZE = 1 * 1024 * 1024ULL; -static constexpr size_t OUTPUT_BUFFER_SIZE = 64 * 1024ULL; +static constexpr size_t DataSize = 1 * 1024 * 1024ULL; +static constexpr size_t OutputBufferSize = 64 * 1024ULL; constexpr auto RandChar = []() -> char { constexpr char Charset[] = "0123456789" @@ -66,7 +66,7 @@ TEST(WasmEdgeZlibTest, DeflateInflateCycle) { WasmHP = 1, WasmData, WasmZlibVersion, ModuleZStream, WasmCompressedData, WasmDecompressedData; - uint32_t WasmCompressedData_size = 0, WasmDecompressedData_size = 0; + uint32_t WasmCompressedData_size = 0, WasmDecompressedDataSize = 0; WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); auto *FuncInst = ZlibMod->findFuncExports("deflateInit_"); @@ -113,8 +113,8 @@ TEST(WasmEdgeZlibTest, DeflateInflateCycle) { WasmHP += std::strlen(ZLIB_VERSION); WasmData = WasmHP; - std::generate_n(MemInst.getPointer(WasmHP), DATA_SIZE, RandChar); - WasmHP += DATA_SIZE; + std::generate_n(MemInst.getPointer(WasmHP), DataSize, RandChar); + WasmHP += DataSize; ModuleZStream = WasmHP; WasmZStream *strm = MemInst.getPointer(ModuleZStream); @@ -149,16 +149,16 @@ TEST(WasmEdgeZlibTest, DeflateInflateCycle) { WasmCompressedData = WasmHP; - strm->AvailIn = DATA_SIZE; + strm->AvailIn = DataSize; strm->NextIn = WasmData; - strm->AvailOut = OUTPUT_BUFFER_SIZE; + strm->AvailOut = OutputBufferSize; strm->NextOut = WasmCompressedData; // deflate Test do { if (strm->AvailOut == 0) { - WasmHP += OUTPUT_BUFFER_SIZE; - strm->AvailOut = OUTPUT_BUFFER_SIZE; + WasmHP += OutputBufferSize; + strm->AvailOut = OutputBufferSize; strm->NextOut = WasmHP; } @@ -176,7 +176,7 @@ TEST(WasmEdgeZlibTest, DeflateInflateCycle) { CallFrame, std::initializer_list{ModuleZStream}, RetVal)); EXPECT_EQ(RetVal[0].get(), Z_OK); - WasmHP += OUTPUT_BUFFER_SIZE - strm->AvailOut; + WasmHP += OutputBufferSize - strm->AvailOut; WasmCompressedData_size = WasmHP - WasmCompressedData; // ----- Deflate Routine END------ @@ -211,14 +211,14 @@ TEST(WasmEdgeZlibTest, DeflateInflateCycle) { strm->AvailIn = WasmCompressedData_size; strm->NextIn = WasmCompressedData; - strm->AvailOut = OUTPUT_BUFFER_SIZE; + strm->AvailOut = OutputBufferSize; strm->NextOut = WasmDecompressedData; // inflate test do { if (strm->AvailOut == 0) { - WasmHP += OUTPUT_BUFFER_SIZE; - strm->AvailOut = OUTPUT_BUFFER_SIZE; + WasmHP += OutputBufferSize; + strm->AvailOut = OutputBufferSize; strm->NextOut = WasmHP; } @@ -235,21 +235,21 @@ TEST(WasmEdgeZlibTest, DeflateInflateCycle) { CallFrame, std::initializer_list{ModuleZStream}, RetVal)); EXPECT_EQ(RetVal[0].get(), Z_OK); - WasmHP += OUTPUT_BUFFER_SIZE - strm->AvailOut; - WasmDecompressedData_size = WasmHP - WasmDecompressedData; + WasmHP += OutputBufferSize - strm->AvailOut; + WasmDecompressedDataSize = WasmHP - WasmDecompressedData; // ----- Inflate Routine END------ // Test Decompressed Buffer size against source Data size. - EXPECT_EQ(WasmDecompressedData_size, DATA_SIZE); + EXPECT_EQ(WasmDecompressedDataSize, DataSize); // Test Decompressed Buffer content against source Data. EXPECT_TRUE(std::equal(MemInst.getPointer(WasmDecompressedData), MemInst.getPointer( - WasmDecompressedData + WasmDecompressedData_size), + WasmDecompressedData + WasmDecompressedDataSize), MemInst.getPointer(WasmData))); } TEST(WasmEdgeZlibTest, Module) { - // Create the wasmedge_process module instance. + // Create the wasmedge_zlib module instance. auto *ZlibMod = dynamic_cast(createModule()); EXPECT_FALSE(ZlibMod == nullptr); From 18184fa9627f07ac7274e82a2ed2e76bb0c21143 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 17 Oct 2023 06:46:01 +0800 Subject: [PATCH 167/623] [WASI-NN] Enable OpenVINO target again (#2940) Signed-off-by: hydai --- utils/wasi-nn/install-openvino.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh index a6624372..15a8fcc8 100755 --- a/utils/wasi-nn/install-openvino.sh +++ b/utils/wasi-nn/install-openvino.sh @@ -3,7 +3,7 @@ # SPDX-FileCopyrightText: 2019-2022 Second State INC set -e -echo "Installing OpenVINO with version 2023.0.0" +echo "Installing OpenVINO with version 2023.0.2" wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB echo "deb https://apt.repos.intel.com/openvino/2023 ubuntu20 main" | tee /etc/apt/sources.list.d/intel-openvino-2023.list From 092a39b1b1cd67283d99cd0ce3ad123d8632e26e Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 20 Oct 2023 15:03:43 +0800 Subject: [PATCH 168/623] [WASI-NN] Add metadata support (#2957) - Use `setInput()` with index = 1 for llama.cpp options - Encode options with JSON string Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 59 +++++++++++++- plugins/wasi_nn/ggml.cpp | 142 ++++++++++++++++++--------------- plugins/wasi_nn/ggml.h | 5 ++ 3 files changed, 139 insertions(+), 67 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index ef6ad97b..66eb1cdc 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -28,6 +28,63 @@ else() set(LLAMA_METAL OFF) endif() +# simdjson for ggml backend +find_package(simdjson QUIET) +if(simdjson_FOUND) + message(STATUS "SIMDJSON found") +else() + message(STATUS "Downloading SIMDJSON source") + include(FetchContent) + FetchContent_Declare( + simdjson + GIT_REPOSITORY https://github.com/simdjson/simdjson.git + GIT_TAG tags/v3.2.1 + GIT_SHALLOW TRUE) + + if(MSVC) + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + get_property( + compile_options + DIRECTORY + PROPERTY COMPILE_OPTIONS + ) + set_property( + DIRECTORY + APPEND + PROPERTY COMPILE_OPTIONS + -Wno-undef + -Wno-suggest-override + -Wno-documentation + -Wno-sign-conversion + -Wno-extra-semi-stmt + -Wno-old-style-cast + -Wno-error=unused-parameter + -Wno-error=unused-template + -Wno-conditional-uninitialized + -Wno-implicit-int-conversion + -Wno-shorten-64-to-32 + -Wno-range-loop-bind-reference + -Wno-format-nonliteral + -Wno-unused-exception-parameter + -Wno-unused-member-function + ) + unset(compile_options) + elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + set_property( + DIRECTORY + APPEND + PROPERTY COMPILE_OPTIONS + /wd4100 # unreferenced formal parameter + ) + endif() + endif() + + set_property(TARGET simdjson PROPERTY POSITION_INDEPENDENT_CODE ON) + FetchContent_MakeAvailable(simdjson) + + message(STATUS "Downloading SIMDJSON source -- done") +endif() + add_subdirectory(thirdparty) wasmedge_add_library(wasmedgePluginWasiNN @@ -69,7 +126,7 @@ endif() string(TOLOWER ${WASMEDGE_PLUGIN_WASI_NN_BACKEND} BACKEND) if(BACKEND STREQUAL "ggml") - target_link_libraries(wasmedgePluginWasiNN PRIVATE llama) + target_link_libraries(wasmedgePluginWasiNN PRIVATE llama simdjson) endif() include(WASINNDeps) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 96e305b1..b9de6c07 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -5,6 +5,7 @@ #include "wasinnenv.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include "simdjson.h" #include #include #include @@ -12,31 +13,6 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML -ErrNo wasmedge_llama_context_params(llama_context_params &Params) noexcept { - const char *LlamaNContextEnv = std::getenv("LLAMA_N_CTX"); - const char *LlamaLogEnv = std::getenv("LLAMA_LOG"); - if (LlamaNContextEnv != nullptr) { - try { - Params.n_ctx = std::stoi(LlamaNContextEnv); - } catch (const std::out_of_range &e) { - spdlog::error( - "[WASI-NN] GGML backend: set n_ctx failed: out_of_range {}"sv, - e.what()); - return ErrNo::InvalidArgument; - } catch (const std::invalid_argument &e) { - spdlog::error( - "[WASI-NN] GGML backend: set n_ctx failed: invalid_argument {}"sv, - e.what()); - return ErrNo::InvalidArgument; - } - if (LlamaLogEnv != nullptr) { - spdlog::info("[WASI-NN] GGML backend: set n_ctx to {}"sv, Params.n_ctx); - } - } - - return ErrNo::Success; -} - Expect load(WasiNNEnvironment &Env, Span> Builders, [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { // The graph builder length must be 1. @@ -77,12 +53,6 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize ggml model. gpt_params Params; llama_backend_init(Params.numa); - llama_context_params ContextParams = llama_context_default_params(); - ErrNo Err = wasmedge_llama_context_params(ContextParams); - if (Err != ErrNo::Success) { - spdlog::error("[WASI-NN] GGML backend: Error: unable to init context."sv); - return ErrNo::InvalidArgument; - } llama_model_params ModelParams = llama_model_default_params(); GraphRef.LlamaModel = llama_load_model_from_file(ModelFilePath.c_str(), ModelParams); @@ -104,24 +74,86 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); - ContextId = Env.NNContext.size() - 1; + + // Set the default context options. + auto &CxtRef = Env.NNContext[ContextId].get(); + auto ContextDefault = llama_context_default_params(); + CxtRef.EnableLog = false; + CxtRef.StreamStdout = false; + CxtRef.CtxSize = ContextDefault.n_ctx; + CxtRef.NPredict = ContextDefault.n_ctx; + CxtRef.NGPULayers = 0; + return ErrNo::Success; } Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, - [[maybe_unused]] uint32_t Index, - const TensorData &Tensor) noexcept { + uint32_t Index, const TensorData &Tensor) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + // Use index 1 for metadata. + if (Index == 1) { + // Decode metadata. + std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + spdlog::error("[WASI-NN] GGML backend: Parse metadata error"sv); + return ErrNo::InvalidEncoding; + } + + // Get metadata from the json. + if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-log"].get().get(CxtRef.EnableLog); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the enable-log option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { + auto Err = Doc["stream-stdout"].get().get(CxtRef.StreamStdout); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the stream-stdout option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { + auto Err = Doc["ctx-size"].get().get(CxtRef.CtxSize); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the ctx-size option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { + auto Err = Doc["n-predict"].get().get(CxtRef.NPredict); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the n-predict option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { + auto Err = Doc["n-gpu-layers"].get().get(CxtRef.NGPULayers); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the n-gpu-layers option."sv); + return ErrNo::InvalidArgument; + } + } + + return ErrNo::Success; + } + // Initialize the llama context. llama_context_params ContextParams = llama_context_default_params(); - ErrNo Err = wasmedge_llama_context_params(ContextParams); - if (Err != ErrNo::Success) { - spdlog::error("[WASI-NN] GGML backend: Error: unable to init context."sv); - return ErrNo::InvalidArgument; - } + ContextParams.n_ctx = CxtRef.CtxSize; GraphRef.LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); @@ -160,9 +192,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { return ErrNo::InvalidArgument; } - // Use env LLAMA_LOG=1 to enable llama log. - const char *LlamaLogEnv = std::getenv("LLAMA_LOG"); - if (LlamaLogEnv != nullptr) { + if (CxtRef.EnableLog) { spdlog::info("[WASI-NN] GGML backend: llama_system_info: {}"sv, llama_print_system_info()); } @@ -176,26 +206,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { const int MaxContextSize = llama_n_ctx(GraphRef.LlamaContext); // NPredict is the number of tokens to predict. Same as -n, --n-predict in // llama.cpp. - int NPredict = MaxContextSize; - const char *LlamaNPredictEnv = std::getenv("LLAMA_N_PREDICT"); - if (LlamaNPredictEnv != nullptr) { - try { - NPredict = std::stoi(LlamaNPredictEnv); - } catch (const std::out_of_range &e) { - spdlog::error( - "[WASI-NN] GGML backend: set n_predict failed: out_of_range {}"sv, - e.what()); - return ErrNo::InvalidArgument; - } catch (const std::invalid_argument &e) { - spdlog::error( - "[WASI-NN] GGML backend: set n_predict failed: invalid_argument {}"sv, - e.what()); - return ErrNo::InvalidArgument; - } - if (LlamaLogEnv != nullptr) { - spdlog::info("[WASI-NN] GGML backend: set n_predict to {}"sv, NPredict); - } - } + int NPredict = CxtRef.NPredict; // Evaluate the initial prompt. llama_batch LlamaBatch = llama_batch_init(NPredict, 0); @@ -242,9 +253,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { std::string NextToken = llama_token_to_piece(GraphRef.LlamaContext, NewTokenId); - // When setting STREAM_TO_STDOUT, we print the output to stdout. - const char *StreamOutput = std::getenv("STREAM_TO_STDOUT"); - if (StreamOutput != nullptr) { + // When setting StreamStdout, we print the output to stdout. + if (CxtRef.StreamStdout) { std::cout << NextToken << std::flush; } @@ -269,7 +279,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } } - if (LlamaLogEnv != nullptr) { + if (CxtRef.EnableLog) { llama_print_timings(GraphRef.LlamaContext); } diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 7d82cc86..35cc3cc5 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -28,6 +28,11 @@ struct Context { size_t GraphId; std::vector LlamaInputs; std::string LlamaOutputs; + bool EnableLog; + bool StreamStdout; + uint64_t CtxSize; + uint64_t NPredict; + uint64_t NGPULayers; }; #else struct Graph {}; From 0d10a1f6a9f0f2c004ec31b884477110f06e7c53 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 26 Oct 2023 13:15:00 +0800 Subject: [PATCH 169/623] [WASI-NN] Make simdjson available before `set_property` (#2968) Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 66eb1cdc..05f356d7 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -79,8 +79,8 @@ else() endif() endif() - set_property(TARGET simdjson PROPERTY POSITION_INDEPENDENT_CODE ON) FetchContent_MakeAvailable(simdjson) + set_property(TARGET simdjson PROPERTY POSITION_INDEPENDENT_CODE ON) message(STATUS "Downloading SIMDJSON source -- done") endif() From 1e58ec99ef411d81c7566e63aa109e6400c679f3 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 25 Oct 2023 13:54:26 +0800 Subject: [PATCH 170/623] [WASI-NN] Support batch-size Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 12 +++++++++++- plugins/wasi_nn/ggml.h | 1 + 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index b9de6c07..40e199b8 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -84,6 +84,7 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, CxtRef.CtxSize = ContextDefault.n_ctx; CxtRef.NPredict = ContextDefault.n_ctx; CxtRef.NGPULayers = 0; + CxtRef.BatchSize = ContextDefault.n_batch; return ErrNo::Success; } @@ -147,6 +148,14 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, return ErrNo::InvalidArgument; } } + if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { + auto Err = Doc["batch-size"].get().get(CxtRef.BatchSize); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the batch-size option."sv); + return ErrNo::InvalidArgument; + } + } return ErrNo::Success; } @@ -154,6 +163,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Initialize the llama context. llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = CxtRef.CtxSize; + ContextParams.n_batch = CxtRef.BatchSize; GraphRef.LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); @@ -209,7 +219,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { int NPredict = CxtRef.NPredict; // Evaluate the initial prompt. - llama_batch LlamaBatch = llama_batch_init(NPredict, 0); + llama_batch LlamaBatch = llama_batch_init(CxtRef.BatchSize, 0); LlamaBatch.n_tokens = CxtRef.LlamaInputs.size(); for (int32_t I = 0; I < LlamaBatch.n_tokens; I++) { LlamaBatch.token[I] = CxtRef.LlamaInputs[I]; diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 35cc3cc5..e9981655 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -33,6 +33,7 @@ struct Context { uint64_t CtxSize; uint64_t NPredict; uint64_t NGPULayers; + uint64_t BatchSize; }; #else struct Graph {}; From b00ebae7d48941ffc416af6ae9954a1c93b9649c Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 25 Oct 2023 16:41:06 +0800 Subject: [PATCH 171/623] [WASI-NN] Support reverse-prompt Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 21 +++++++++++++++++++++ plugins/wasi_nn/ggml.h | 1 + 2 files changed, 22 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 40e199b8..83082ea9 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -85,6 +85,7 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, CxtRef.NPredict = ContextDefault.n_ctx; CxtRef.NGPULayers = 0; CxtRef.BatchSize = ContextDefault.n_batch; + CxtRef.ReversePrompt = ""sv; return ErrNo::Success; } @@ -156,6 +157,17 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, return ErrNo::InvalidArgument; } } + if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { + std::string_view ReversePrompt; + auto Err = + Doc["reverse-prompt"].get().get(ReversePrompt); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the reverse-prompt option."sv); + return ErrNo::InvalidArgument; + } + CxtRef.ReversePrompt = ReversePrompt; + } return ErrNo::Success; } @@ -287,6 +299,15 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { spdlog::error("[WASI-NN] GGML backend: failed to llama_decode"sv); return ErrNo::RuntimeError; } + + // Break if reverse prompt is found. + if (!CxtRef.ReversePrompt.empty() && + CxtRef.LlamaOutputs.find(CxtRef.ReversePrompt) != std::string::npos) { + if (CxtRef.EnableLog) { + spdlog::info("[WASI-NN] GGML backend: reverse prompt found"sv); + } + break; + } } if (CxtRef.EnableLog) { diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index e9981655..ae165436 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -34,6 +34,7 @@ struct Context { uint64_t NPredict; uint64_t NGPULayers; uint64_t BatchSize; + std::string ReversePrompt; }; #else struct Graph {}; From ac212381e941be4248999bb6e5802c1431f13068 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 26 Oct 2023 13:10:07 +0800 Subject: [PATCH 172/623] [WASI-NN] Fix llama log print Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 1 + plugins/wasi_nn/thirdparty/ggml/llama.cpp | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 83082ea9..5fe62fc9 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -311,6 +311,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } if (CxtRef.EnableLog) { + llama_log_set(nullptr, &CxtRef.EnableLog); llama_print_timings(GraphRef.LlamaContext); } diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.cpp b/plugins/wasi_nn/thirdparty/ggml/llama.cpp index 40d2246f..d4f3b184 100644 --- a/plugins/wasi_nn/thirdparty/ggml/llama.cpp +++ b/plugins/wasi_nn/thirdparty/ggml/llama.cpp @@ -7693,9 +7693,9 @@ static void llama_log_internal(ggml_log_level level, const char * format, ...) { static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) { (void) level; - (void) user_data; - if (std::getenv("LLAMA_LOG") != nullptr) { + bool enable_log = static_cast(user_data); + if (enable_log) { fputs(text, stderr); + fflush(stderr); } - fflush(stderr); } From c58bd29714675cfdcccaa3a4b28ba1fd0e68760b Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 11 Oct 2023 09:14:02 -0500 Subject: [PATCH 173/623] [WASI-NN] ggml backend: Enable cuBLAS and use LLAMA_N_GL to set the n_gpu_layers Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 9 + plugins/wasi_nn/ggml.cpp | 18 + plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu | 7402 ++++++++++++++++++ plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h | 47 + 4 files changed, 7476 insertions(+) create mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu create mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 05f356d7..886d828a 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -6,6 +6,15 @@ set(LLAMA_ALL_WARNINGS OFF) set(LLAMA_METAL_NDEBUG ON) +if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_CUBLAS) + message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_CUBLAS") + # Default use OpenBLAS + set(LLAMA_CUBLAS ON) +else() + message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_CUBLAS") + set(LLAMA_CUBLAS OFF) +endif() + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS) message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_BLAS") # Default use OpenBLAS diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 5fe62fc9..109fa042 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -54,6 +54,24 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, gpt_params Params; llama_backend_init(Params.numa); llama_model_params ModelParams = llama_model_default_params(); + + const char *LlamaNGPULayerEnv = std::getenv("LLAMA_N_GL"); + if (LlamaNGPULayerEnv != nullptr) { + try { + ModelParams.n_gpu_layers = std::stoi(LlamaNGPULayerEnv); + } catch (const std::out_of_range &e) { + spdlog::error( + "[WASI-NN] GGML backend: set n_gpu_layers failed: out_of_range {}"sv, + e.what()); + return ErrNo::InvalidArgument; + } catch (const std::invalid_argument &e) { + spdlog::error( + "[WASI-NN] GGML backend: set n_gpu_layers failed: invalid_argument {}"sv, + e.what()); + return ErrNo::InvalidArgument; + } + } + GraphRef.LlamaModel = llama_load_model_from_file(ModelFilePath.c_str(), ModelParams); if (GraphRef.LlamaModel == nullptr) { diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu b/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu new file mode 100644 index 00000000..989c419c --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu @@ -0,0 +1,7402 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(GGML_USE_HIPBLAS) +#include +#include +#include +#ifdef __HIP_PLATFORM_AMD__ +// for rocblas_initialize() +#include "rocblas/rocblas.h" +#endif // __HIP_PLATFORM_AMD__ +#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F +#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N HIPBLAS_OP_N +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F HIPBLAS_R_16F +#define CUDA_R_32F HIPBLAS_R_32F +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define cublasCreate hipblasCreate +#define cublasGemmEx hipblasGemmEx +#define cublasHandle_t hipblasHandle_t +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetStream hipblasSetStream +#define cublasSgemm hipblasSgemm +#define cublasStatus_t hipblasStatus_t +#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess +#define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaError_t hipError_t +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDisableTiming hipEventDisableTiming +#define cudaEventRecord hipEventRecord +#define cudaEvent_t hipEvent_t +#define cudaEventDestroy hipEventDestroy +#define cudaFree hipFree +#define cudaFreeHost hipHostFree +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaMalloc hipMalloc +#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMemcpy hipMemcpy +#define cudaMemcpy2DAsync hipMemcpy2DAsync +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyKind hipMemcpyKind +#define cudaMemset hipMemset +#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize +#define cudaSetDevice hipSetDevice +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamNonBlocking hipStreamNonBlocking +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess +#else +#include +#include +#include +#endif // defined(GGML_USE_HIPBLAS) + +#include "ggml-cuda.h" +#include "ggml.h" + +#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#define CC_VOLTA 700 +#define CC_OFFSET_AMD 1000000 +#define CC_RDNA2 (CC_OFFSET_AMD + 1030) + +#if defined(GGML_USE_HIPBLAS) +#define __CUDA_ARCH__ 1300 + +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ + defined(__gfx1150__) || defined(__gfx1151__) +#define RDNA3 +#endif + +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) +#define RDNA2 +#endif + +#ifndef __has_builtin + #define __has_builtin(x) 0 +#endif + +typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); +static __device__ __forceinline__ int __vsubss4(const int a, const int b) { + const int8x4_t va = reinterpret_cast(a); + const int8x4_t vb = reinterpret_cast(b); +#if __has_builtin(__builtin_elementwise_sub_sat) + const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); + return reinterpret_cast(c); +#else + int8x4_t c; + int16_t tmp; +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp = va[i] - vb[i]; + if(tmp > std::numeric_limits::max()) tmp = std::numeric_limits::max(); + if(tmp < std::numeric_limits::min()) tmp = std::numeric_limits::min(); + c[i] = tmp; + } + return reinterpret_cast(c); +#endif // __has_builtin(__builtin_elementwise_sub_sat) +} + +static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { +#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) + c = __builtin_amdgcn_sdot4(a, b, c, false); +#elif defined(__gfx1100__) + c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); +#elif defined(__gfx1010__) || defined(__gfx900__) + int tmp1; + int tmp2; + asm("\n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + " + : "+v"(c), "=&v"(tmp1), "=&v"(tmp2) + : "v"(a), "v"(b) + ); +#else + const int8x4_t va = reinterpret_cast(a); + const int8x4_t vb = reinterpret_cast(b); + c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3]; +#endif + return c; +} +#endif // defined(GGML_USE_HIPBLAS) + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); + +#define CUDA_CHECK(err) \ + do { \ + cudaError_t err_ = (err); \ + if (err_ != cudaSuccess) { \ + int id; \ + cudaGetDevice(&id); \ + fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ + cudaGetErrorString(err_)); \ + fprintf(stderr, "current device: %d\n", id); \ + exit(1); \ + } \ + } while (0) + +#if CUDART_VERSION >= 12000 +#define CUBLAS_CHECK(err) \ + do { \ + cublasStatus_t err_ = (err); \ + if (err_ != CUBLAS_STATUS_SUCCESS) { \ + int id; \ + cudaGetDevice(&id); \ + fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \ + err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \ + fprintf(stderr, "current device: %d\n", id); \ + exit(1); \ + } \ + } while (0) +#else +#define CUBLAS_CHECK(err) \ + do { \ + cublasStatus_t err_ = (err); \ + if (err_ != CUBLAS_STATUS_SUCCESS) { \ + int id; \ + cudaGetDevice(&id); \ + fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ + fprintf(stderr, "current device: %d\n", id); \ + exit(1); \ + } \ + } while (0) +#endif // CUDART_VERSION >= 11 + +#if CUDART_VERSION >= 11100 +#define GGML_CUDA_ASSUME(x) __builtin_assume(x) +#else +#define GGML_CUDA_ASSUME(x) +#endif // CUDART_VERSION >= 11100 + +#ifdef GGML_CUDA_F16 +typedef half dfloat; // dequantize float +typedef half2 dfloat2; +#else +typedef float dfloat; // dequantize float +typedef float2 dfloat2; +#endif //GGML_CUDA_F16 + +static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { + const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment + + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) { + const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment + + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) { + return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + +static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) { + return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +} + +template +using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream); +typedef to_t_cuda_t to_fp32_cuda_t; +typedef to_t_cuda_t to_fp16_cuda_t; + +typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); +typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v); +typedef void (*cpy_kernel_t)(const char * cx, char * cdst); +typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); +typedef void (*ggml_cuda_op_mul_mat_t)( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream); +typedef void (*ggml_cuda_op_flatten_t)( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream); + +// QK = number of values after dequantization +// QR = QK / number of values before dequantization +// QI = number of 32 bit integers before dequantization + +#define QK4_0 32 +#define QR4_0 2 +#define QI4_0 (QK4_0 / (4 * QR4_0)) +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +#define QR4_1 2 +#define QI4_1 (QK4_1 / (4 * QR4_1)) +typedef struct { + half2 dm; // dm.x = delta, dm.y = min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +#define QR5_0 2 +#define QI5_0 (QK5_0 / (4 * QR5_0)) +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +#define QR5_1 2 +#define QI5_1 (QK5_1 / (4 * QR5_1)) +typedef struct { + half2 dm; // dm.x = delta, dm.y = min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +#define QR8_0 1 +#define QI8_0 (QK8_0 / (4 * QR8_0)) +typedef struct { + half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +#define QR8_1 1 +#define QI8_1 (QK8_1 / (4 * QR8_1)) +typedef struct { + half2 ds; // ds.x = delta, ds.y = sum + int8_t qs[QK8_0]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding"); + +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); +typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc); +typedef void (*load_tiles_cuda_t)( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row); +typedef float (*vec_dot_q_mul_mat_cuda_t)( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k); + +//================================= k-quants + +#ifdef GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else +#define QK_K 256 +#define K_SCALE_SIZE 12 +#endif + +#define QR2_K 4 +#define QI2_K (QK_K / (4*QR2_K)) +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + half2 dm; // super-block scale for quantized scales/mins +} block_q2_K; +static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); + +#define QR3_K 4 +#define QI3_K (QK_K / (4*QR3_K)) +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits +#ifdef GGML_QKK_64 + uint8_t scales[2]; // scales, quantized with 8 bits +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale +} block_q3_K; +//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding"); + +#define QR4_K 2 +#define QI4_K (QK_K / (4*QR4_K)) +#ifdef GGML_QKK_64 +typedef struct { + half dm[2]; // super-block scales/mins + uint8_t scales[2]; // 4-bit block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding"); +#else +typedef struct { + half2 dm; // super-block scale for quantized scales/mins + uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +#endif + +#define QR5_K 2 +#define QI5_K (QK_K / (4*QR5_K)) +#ifdef GGML_QKK_64 +typedef struct { + half d; // super-block scale + int8_t scales[QK_K/16]; // block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); +#else +typedef struct { + half2 dm; // super-block scale for quantized scales/mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif + +#define QR6_K 2 +#define QI6_K (QK_K / (4*QR6_K)) +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales + half d; // delta +} block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding"); + +#define WARP_SIZE 32 +#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses + +#define CUDA_ADD_BLOCK_SIZE 256 +#define CUDA_MUL_BLOCK_SIZE 256 +#define CUDA_GELU_BLOCK_SIZE 256 +#define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_CPY_BLOCK_SIZE 32 +#define CUDA_SCALE_BLOCK_SIZE 256 +#define CUDA_ROPE_BLOCK_SIZE 256 +#define CUDA_ALIBI_BLOCK_SIZE 32 +#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 +#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 + +// dmmv = dequantize_mul_mat_vec +#ifndef GGML_CUDA_DMMV_X +#define GGML_CUDA_DMMV_X 32 +#endif +#ifndef GGML_CUDA_MMV_Y +#define GGML_CUDA_MMV_Y 1 +#endif + +#ifndef K_QUANTS_PER_ITERATION +#define K_QUANTS_PER_ITERATION 2 +#else +static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); +#endif + +#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE +#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128 +#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE + +#define MUL_MAT_SRC1_COL_STRIDE 128 + +#define MAX_STREAMS 8 +static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr }; + +struct ggml_tensor_extra_gpu { + void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors + cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs +}; + +// this is faster on Windows +// probably because the Windows CUDA libraries forget to make this check before invoking the drivers +inline cudaError_t ggml_cuda_set_device(const int device) { + int current_device; + CUDA_CHECK(cudaGetDevice(¤t_device)); + + if (device == current_device) { + return cudaSuccess; + } + + return cudaSetDevice(device); +} + +static int g_device_count = -1; +static int g_main_device = 0; +static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; +static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; +static bool g_mul_mat_q = true; + +static void * g_scratch_buffer = nullptr; +static size_t g_scratch_size = 0; // disabled by default +static size_t g_scratch_offset = 0; + +static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; + +static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= kx) { + return; + } + dst[i] = x[i] + y[i%ky]; +} + +static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = __hadd(x[i], __float2half(y[i])); +} + +static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= kx) { + return; + } + dst[i] = x[i] * y[i%ky]; +} + +static __global__ void gelu_f32(const float * x, float * dst, const int k) { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + float xi = x[i]; + dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); +} + +static __global__ void silu_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = x[i] / (1.0f + expf(-x[i])); +} + +static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32); + a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32); + } + return a; +} + +template +static __global__ void norm_f32(const float * x, float * dst, const int ncols) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + const float eps = 1e-5f; + + float2 mean_var = make_float2(0.f, 0.f); + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row*ncols + col]; + mean_var.x += xi; + mean_var.y += xi * xi; + } + + // sum up partial sums + mean_var = warp_reduce_sum(mean_var); + if (block_size > WARP_SIZE) { + __shared__ float2 s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = mean_var; + } + __syncthreads(); + mean_var = s_sum[lane_id]; + mean_var = warp_reduce_sum(mean_var); + } + + const float mean = mean_var.x / ncols; + const float var = mean_var.y / ncols - mean * mean; + const float inv_std = rsqrtf(var + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std; + } +} + +static __device__ __forceinline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, mask, 32); + } + return x; +} + +template +static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row*ncols + col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[row*ncols + col] = scale * x[row*ncols + col]; + } +} + +static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q4_0 * x = (const block_q4_0 *) vx; + + const dfloat d = x[ib].d; + + const int vui = x[ib].qs[iqs]; + + v.x = vui & 0xF; + v.y = vui >> 4; + +#ifdef GGML_CUDA_F16 + v = __hsub2(v, {8.0f, 8.0f}); + v = __hmul2(v, {d, d}); +#else + v.x = (v.x - 8.0f) * d; + v.y = (v.y - 8.0f) * d; +#endif // GGML_CUDA_F16 +} + +static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q4_1 * x = (const block_q4_1 *) vx; + + const dfloat d = __low2half(x[ib].dm); + const dfloat m = __high2half(x[ib].dm); + + const int vui = x[ib].qs[iqs]; + + v.x = vui & 0xF; + v.y = vui >> 4; + +#ifdef GGML_CUDA_F16 + v = __hmul2(v, {d, d}); + v = __hadd2(v, {m, m}); +#else + v.x = (v.x * d) + m; + v.y = (v.y * d) + m; +#endif // GGML_CUDA_F16 +} + +static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q5_0 * x = (const block_q5_0 *) vx; + + const dfloat d = x[ib].d; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); + +#ifdef GGML_CUDA_F16 + v = __hsub2(v, {16.0f, 16.0f}); + v = __hmul2(v, {d, d}); +#else + v.x = (v.x - 16.0f) * d; + v.y = (v.y - 16.0f) * d; +#endif // GGML_CUDA_F16 +} + +static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q5_1 * x = (const block_q5_1 *) vx; + + const dfloat d = __low2half(x[ib].dm); + const dfloat m = __high2half(x[ib].dm); + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); + +#ifdef GGML_CUDA_F16 + v = __hmul2(v, {d, d}); + v = __hadd2(v, {m, m}); +#else + v.x = (v.x * d) + m; + v.y = (v.y * d) + m; +#endif // GGML_CUDA_F16 +} + +static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const block_q8_0 * x = (const block_q8_0 *) vx; + + const dfloat d = x[ib].d; + + v.x = x[ib].qs[iqs + 0]; + v.y = x[ib].qs[iqs + 1]; + +#ifdef GGML_CUDA_F16 + v = __hmul2(v, {d, d}); +#else + v.x *= d; + v.y *= d; +#endif // GGML_CUDA_F16 +} + +//================================== k-quants + +template +static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_q2_K * x = (const block_q2_K *) vx; + + const int tid = threadIdx.x; +#if QK_K == 256 + const int n = tid/32; + const int l = tid - 32*n; + const int is = 8*n + l/16; + + const uint8_t q = x[i].qs[32*n + l]; + dst_t * y = yy + i*QK_K + 128*n; + + float dall = __low2half(x[i].dm); + float dmin = __high2half(x[i].dm); + y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); + y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); + y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); +#else + const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const uint8_t q = x[i].qs[il] >> (2*is); + dst_t * y = yy + i*QK_K + 16*is + il; + float dall = __low2half(x[i].dm); + float dmin = __high2half(x[i].dm); + y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); +#endif + +} + +template +static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_q3_K * x = (const block_q3_K *) vx; + +#if QK_K == 256 + const int r = threadIdx.x/4; + const int tid = r/2; + const int is0 = r%2; + const int l0 = 16*is0 + 4*(threadIdx.x%4); + const int n = tid / 4; + const int j = tid - 4*n; + + uint8_t m = 1 << (4*n + j); + int is = 8*n + 2*j + is0; + int shift = 2*j; + + int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : + (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); + float d_all = x[i].d; + float dl = d_all * (us - 32); + + dst_t * y = yy + i*QK_K + 128*n + 32*j; + const uint8_t * q = x[i].qs + 32*n; + const uint8_t * hm = x[i].hmask; + + for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); +#else + const int tid = threadIdx.x; + const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const int im = il/8; // 0...1 + const int in = il%8; // 0...7 + + dst_t * y = yy + i*QK_K + 16*is + il; + + const uint8_t q = x[i].qs[il] >> (2*is); + const uint8_t h = x[i].hmask[in] >> (2*is + im); + const float d = (float)x[i].d; + + if (is == 0) { + y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } else { + y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } +#endif + +} + +#if QK_K == 256 +static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { + if (j < 4) { + d = q[j] & 63; m = q[j + 4] & 63; + } else { + d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} +#endif + +template +static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q4_K * x = (const block_q4_K *) vx; + + const int i = blockIdx.x; + +#if QK_K == 256 + // assume 32 threads + const int tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int is = 2*il; + const int n = 4; + + dst_t * y = yy + i*QK_K + 64*il + n*ir; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint8_t * q = x[i].qs + 32*il + n*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + for (int l = 0; l < n; ++l) { + y[l + 0] = d1 * (q[l] & 0xF) - m1; + y[l +32] = d2 * (q[l] >> 4) - m2; + } +#else + const int tid = threadIdx.x; + const uint8_t * q = x[i].qs; + dst_t * y = yy + i*QK_K; + const float d = (float)x[i].dm[0]; + const float m = (float)x[i].dm[1]; + y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); + y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4); +#endif +} + +template +static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q5_K * x = (const block_q5_K *) vx; + + const int i = blockIdx.x; + +#if QK_K == 256 + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int il = tid/16; // il is in 0...3 + const int ir = tid%16; // ir is in 0...15 + const int is = 2*il; // is is in 0...6 + + dst_t * y = yy + i*QK_K + 64*il + 2*ir; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint8_t * ql = x[i].qs + 32*il + 2*ir; + const uint8_t * qh = x[i].qh + 2*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + + uint8_t hm = 1 << (2*il); + y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1; + y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1; + hm <<= 1; + y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; + y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +#else + const int tid = threadIdx.x; + const uint8_t q = x[i].qs[tid]; + const int im = tid/8; // 0...3 + const int in = tid%8; // 0...7 + const int is = tid/16; // 0 or 1 + const uint8_t h = x[i].qh[in] >> im; + const float d = x[i].d; + dst_t * y = yy + i*QK_K + tid; + y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16)); + y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16)); +#endif +} + +template +static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q6_K * x = (const block_q6_K *) vx; + + const int i = blockIdx.x; +#if QK_K == 256 + + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int ip = tid/32; // ip is 0 or 1 + const int il = tid - 32*ip; // 0...32 + const int is = 8*ip + il/16; + + dst_t * y = yy + i*QK_K + 128*ip + il; + + const float d = x[i].d; + + const uint8_t * ql = x[i].ql + 64*ip + il; + const uint8_t qh = x[i].qh[32*ip + il]; + const int8_t * sc = x[i].scales + is; + + y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +#else + + // assume 32 threads + const int tid = threadIdx.x; + const int ip = tid/16; // 0 or 1 + const int il = tid - 16*ip; // 0...15 + + dst_t * y = yy + i*QK_K + 16*ip + il; + + const float d = x[i].d; + + const uint8_t ql = x[i].ql[16*ip + il]; + const uint8_t qh = x[i].qh[il] >> (2*ip); + const int8_t * sc = x[i].scales; + + y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32); +#endif +} + +static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q2_K * x = (const block_q2_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int step = 16/K_QUANTS_PER_ITERATION; + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int s_offset = 8*im; + const int y_offset = 128*im + l0; + + uint32_t aux[4]; + const uint8_t * d = (const uint8_t *)aux; + const uint8_t * m = (const uint8_t *)(aux + 2); + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = x[i].qs + q_offset; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); + aux[0] = a[0] & 0x0f0f0f0f; + aux[1] = a[1] & 0x0f0f0f0f; + aux[2] = (a[0] >> 4) & 0x0f0f0f0f; + aux[3] = (a[1] >> 4) & 0x0f0f0f0f; + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) + + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) + + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) + + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) + + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) + + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) + + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) + +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); + sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] + + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; + + } + tmp += dall * sum1 - dmin * sum2; + + } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; + + uint32_t uaux[2]; + const uint8_t * d = (const uint8_t *)uaux; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint32_t * s = (const uint32_t *)x[i].scales; + + uaux[0] = s[0] & 0x0f0f0f0f; + uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; + + const float2 dall = __half22float2(x[i].dm); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t ql = q[l]; + sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) + + y[l+16] * d[1] * ((ql >> 2) & 3) + + y[l+32] * d[2] * ((ql >> 4) & 3) + + y[l+48] * d[3] * ((ql >> 6) & 3); + sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7]; + } + tmp += dall.x * sum1 - dall.y * sum2; + } +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { + + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q3_K * x = (const block_q3_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop + const int step = 16/K_QUANTS_PER_ITERATION; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0....15 or 0...7 + + const uint8_t m = 1 << (4*im); + + const int l0 = n*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int y_offset = 128*im + l0; + + uint16_t utmp[4]; + const int8_t * s = (const int8_t *)utmp; + + const uint16_t s_shift = 4*im; + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = x[i].qs + q_offset; + const uint8_t * h = x[i].hmask + l0; + + const uint16_t * a = (const uint16_t *)x[i].scales; + utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); + utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); + utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); + utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); + + const float d = x[i].d; + + float sum = 0; + for (int l = 0; l < n; ++l) { + sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) + + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) + + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) + + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); + sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) + + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) + + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) + + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); + } + tmp += d * sum; + + } +#else + + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 + const int in = offset/8; // 0 or 1 + const int im = offset%8; // 0...7 + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint8_t * s = x[i].scales; + + const float dall = (float)x[i].d; + + float sum = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t hl = x[i].hmask[im+l] >> in; + const uint8_t ql = q[l]; + sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) + + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4)) + + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4)) + + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4)); + } + tmp += sum; + } +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { + + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q4_K * x = (const block_q4_K *)vx + ib0; + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + + const int il = tid/step; // 0...3 + const int ir = tid - step*il; // 0...7 or 0...3 + const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + +#if K_QUANTS_PER_ITERATION == 2 + uint32_t q32[4]; + const uint8_t * q4 = (const uint8_t *)q32; +#else + uint16_t q16[4]; + const uint8_t * q4 = (const uint8_t *)q16; +#endif + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + +#if K_QUANTS_PER_ITERATION == 2 + const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); + const uint32_t * q2 = q1 + 16; + + q32[0] = q1[0] & 0x0f0f0f0f; + q32[1] = q1[0] & 0xf0f0f0f0; + q32[2] = q2[0] & 0x0f0f0f0f; + q32[3] = q2[0] & 0xf0f0f0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 4; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4]; + s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#else + const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); + const uint16_t * q2 = q1 + 32; + + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[0] & 0xf0f0; + q16[2] = q2[0] & 0x0f0f; + q16[3] = q2[0] & 0xf0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 2; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; + s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#endif + + } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + + const int step = tid * K_QUANTS_PER_ITERATION; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + float tmp = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const float * y = yy + i*QK_K + step; + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + const float d = (float)x[i].dm[0]; + const float m = (float)x[i].dm[1]; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) + + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) + + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) + + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); + } + tmp += sum; + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) { + + const int row = blockIdx.x; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q5_K * x = (const block_q5_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = threadIdx.x/2; // 0...15 + const int ix = threadIdx.x%2; + + const int il = tid/4; // 0...3 + const int ir = tid - 4*il;// 0...3 + const int n = 2; + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + const uint8_t hm1 = 1 << (2*im); + const uint8_t hm2 = hm1 << 4; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + + uint16_t q16[8]; + const uint8_t * q4 = (const uint8_t *)q16; + + for (int i = ix; i < num_blocks_per_row; i += 2) { + + const uint8_t * ql1 = x[i].qs + q_offset; + const uint8_t * qh = x[i].qh + l0; + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + + float4 sum = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + const uint16_t * q1 = (const uint16_t *)ql1; + const uint16_t * q2 = q1 + 32; + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[8] & 0x0f0f; + q16[2] = (q1[0] >> 4) & 0x0f0f; + q16[3] = (q1[8] >> 4) & 0x0f0f; + q16[4] = q2[0] & 0x0f0f; + q16[5] = q2[8] & 0x0f0f; + q16[6] = (q2[0] >> 4) & 0x0f0f; + q16[7] = (q2[8] >> 4) & 0x0f0f; + for (int l = 0; l < n; ++l) { + sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) + + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0)); + sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) + + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0)); + sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) + + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0)); + sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) + + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0)); + smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] + + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; + } + tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; + } + +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + const int step = tid * K_QUANTS_PER_ITERATION; + const int im = step/8; + const int in = step%8; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const int8_t * s = x[i].scales; + const float * y = yy + i*QK_K + step; + const float d = x[i].d; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + const uint8_t h = x[i].qh[in+j] >> im; + sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16)) + + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16)) + + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16)) + + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16)); + } + tmp += sum; + } +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + +static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = blockIdx.y*blockDim.y + threadIdx.y; + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + const block_q6_K * x = (const block_q6_K *)vx + ib0; + +#if QK_K == 256 + + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * ql = x[i].ql + ql_offset; + const uint8_t * qh = x[i].qh + qh_offset; + const int8_t * s = x[i].scales + s_offset; + + const float d = x[i].d; + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp += sum; +#else + float sum = 0; + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp += sum; +#endif + + } + +#else + + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3 + + const int step = tid * K_QUANTS_PER_ITERATION; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + step; + const uint8_t * ql = x[i].ql + step; + const uint8_t * qh = x[i].qh + step; + const int8_t * s = x[i].scales; + + const float d = x[i+0].d; + + float sum = 0; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); + } + tmp += sum; + + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + +static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const half * x = (const half *) vx; + + // automatic half -> float type cast if dfloat == float + v.x = x[ib + iqs + 0]; + v.y = x[ib + iqs + 1]; +} + +static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const float * x = (const float *) vx; + + // automatic half -> float type cast if dfloat == float + v.x = x[ib + iqs + 0]; + v.y = x[ib + iqs + 1]; +} + +static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { + const int ix = blockDim.x*blockIdx.x + threadIdx.x; + + if (ix >= kx_padded) { + return; + } + + const int iy = blockDim.y*blockIdx.y + threadIdx.y; + + const int i_padded = iy*kx_padded + ix; + + block_q8_1 * y = (block_q8_1 *) vy; + + const int ib = i_padded / QK8_1; // block index + const int iqs = i_padded % QK8_1; // quant index + + const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; + float amax = fabsf(xi); + float sum = xi; + +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32)); + sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); + } + + const float d = amax / 127; + const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + + if (iqs > 0) { + return; + } + + reinterpret_cast(y[ib].ds.x) = d; + reinterpret_cast(y[ib].ds.y) = sum; +} + +template +static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { + const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; + + if (i >= k) { + return; + } + + const int ib = i/qk; // block index + const int iqs = (i%qk)/qr; // quant index + const int iybs = i - i%qk; // y block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel(vx, ib, iqs, v); + + y[iybs + iqs + 0] = v.x; + y[iybs + iqs + y_offset] = v.y; +} + +// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called +// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q + +#define VDR_Q4_0_Q8_1_MMVQ 2 +#define VDR_Q4_0_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( + const int * v, const int * u, const float & d4, const half2 & ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = __dp4a(vi0, u[2*i+0], sumi); + sumi = __dp4a(vi1, u[2*i+1], sumi); + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y); +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q4_1_Q8_1_MMVQ 2 +#define VDR_Q4_1_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( + const int * v, const int * u, const half2 & dm4, const half2 & ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = __dp4a(vi0, u[2*i+0], sumi); + sumi = __dp4a(vi1, u[2*i+1], sumi); + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm4, ds8)); + const float d4d8 = tmp.x; + const float m4s8 = tmp.y; +#else + const float2 dm4f = __half22float2(dm4); + const float2 ds8f = __half22float2(ds8); + const float d4d8 = dm4f.x * ds8f.x; + const float m4s8 = dm4f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it + return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q5_0_Q8_1_MMVQ 2 +#define VDR_Q5_0_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( + const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 16 from each quant value + return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y); +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q5_1_Q8_1_MMVQ 2 +#define VDR_Q5_1_Q8_1_MMQ 4 + +template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( + const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm5, ds8)); + const float d5d8 = tmp.x; + const float m5s8 = tmp.y; +#else + const float2 dm5f = __half22float2(dm5); + const float2 ds8f = __half22float2(ds8); + const float d5d8 = dm5f.x * ds8f.x; + const float m5s8 = dm5f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it + return sumi*d5d8 + m5s8 / (QI5_1 / vdr); + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q8_0_Q8_1_MMVQ 2 +#define VDR_Q8_0_Q8_1_MMQ 8 + +template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( + const int * v, const int * u, const float & d8_0, const float & d8_1) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = __dp4a(v[i], u[i], sumi); + } + + return d8_0*d8_1 * sumi; +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( + const int * v, const int * u, const half2 & dm8, const half2 & ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = __dp4a(v[i], u[i], sumi); + } + +#ifdef GGML_CUDA_F16 + const float2 tmp = __half22float2(__hmul2(dm8, ds8)); + const float d8d8 = tmp.x; + const float m8s8 = tmp.y; +#else + const float2 dm8f = __half22float2(dm8); + const float2 ds8f = __half22float2(ds8); + const float d8d8 = dm8f.x * ds8f.x; + const float m8s8 = dm8f.y * ds8f.y; +#endif // GGML_CUDA_F16 + + // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it + return sumi*d8d8 + m8s8 / (QI8_1 / vdr); +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q2_K_Q8_1_MMVQ 1 +#define VDR_Q2_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( + const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm2, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR2_K; ++i) { + const int sc = scales[2*i]; + + const int vi = (v >> (2*i)) & 0x03030303; + + sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values + } + + const float2 dm2f = __half22float2(dm2); + + return dm2f.x*sumf_d - dm2f.y*sumf_m; +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm2, const float & d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi_d = 0; + int sumi_m = 0; + +#pragma unroll + for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { + int sumi_d_sc = 0; + + const int sc = scales[i0 / (QI8_1/2)]; + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product + sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m + } + + sumi_d += sumi_d_sc * (sc & 0xF); + } + + const float2 dm2f = __half22float2(dm2); + + return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m); +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q3_K_Q8_1_MMVQ 1 +#define VDR_Q3_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const int & scale_offset, const float & d3, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + const int isc = scale_offset + 2*i; + + const int isc_low = isc % (QK_K/32); + const int sc_shift_low = 4 * (isc / (QK_K/32)); + const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF; + + const int isc_high = isc % (QK_K/64); + const int sc_shift_high = 2 * (isc / (QK_K/64)); + const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; + + const int sc = (sc_low | sc_high) - 32; + + const int vil = (vl >> (2*i)) & 0x03030303; + + const int vih = ((vh >> i) << 2) & 0x04040404; + + const int vi = __vsubss4(vil, vih); + + sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d3 * sumf; +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d3, const float & d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + int sumi = 0; + +#pragma unroll + for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { + int sumi_sc = 0; + + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product + } + + sumi += sumi_sc * scales[i0 / (QI8_1/2)]; + } + + return d3*d8 * sumi; +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q4_K_Q8_1_MMVQ 2 +#define VDR_Q4_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K; ++i) { + const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; + const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; + + const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q5_K_Q8_1_MMVQ 2 +#define VDR_Q5_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( + const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F; + const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F; + + const int vh0i = ((vh[0] >> i) << 4) & 0x10101010; + const int vh1i = ((vh[1] >> i) << 4) & 0x10101010; + + const int v0i = vl0i | vh0i; + const int v1i = vl1i | vh1i; + + const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product + const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); + + } + + const float2 dm5f = __half22float2(dm5); + + return dm5f.x*sumf_d - dm5f.y*sumf_m; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +#define VDR_Q6_K_Q8_1_MMVQ 1 +#define VDR_Q6_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4*i]; + + const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + + const int vih = ((vh >> (4*i)) << 4) & 0x30303030; + + const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 + + sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d*sumf; +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +// contiguous u/y values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, + const float & d6, const float * __restrict__ d8) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + float sumf_d = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { + int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale + +#pragma unroll + for (int i = i0; i < i0 + 2; ++i) { + sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product + sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product + + sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product + sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product + } + + sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); + } + + return d6 * sumf_d; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +static __device__ __forceinline__ float vec_dot_q4_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; + + int v[VDR_Q4_0_Q8_1_MMVQ]; + int u[2*VDR_Q4_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + } + + return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); +} + +template static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0]; + + *x_ql = tile_x_qs; + *x_dm = (half2 *) tile_x_d; +} + +template static __device__ __forceinline__ void load_tiles_q4_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_0; + const int kqsx = k % QI4_0; + + const block_q4_0 * bx0 = (block_q4_0 *) vx; + + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { + int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; + } +} + +static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const float * x_dmf = (float *) x_dm; + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + } + + return vec_dot_q4_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; + + int v[VDR_Q4_1_Q8_1_MMVQ]; + int u[2*VDR_Q4_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1); + } + + return vec_dot_q4_1_q8_1_impl(v, u, bq4_1->dm, bq8_1->ds); +} + +template static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1]; + + *x_ql = tile_x_qs; + *x_dm = tile_x_dm; +} + +template static __device__ __forceinline__ void load_tiles_q4_1( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_1; + const int kqsx = k % QI4_1; + + const block_q4_1 * bx0 = (block_q4_1 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { + int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; + } +} + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + } + + return vec_dot_q4_1_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; + + int vl[VDR_Q5_0_Q8_1_MMVQ]; + int vh[VDR_Q5_0_Q8_1_MMVQ]; + int u[2*VDR_Q5_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i); + vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0); + } + + return vec_dot_q5_0_q8_1_impl(vl, vh, u, bq5_0->d, bq8_1->ds); +} + +template static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0]; + + *x_ql = tile_x_ql; + *x_dm = (half2 *) tile_x_d; +} + +template static __device__ __forceinline__ void load_tiles_q5_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_0; + const int kqsx = k % QI5_0; + + const block_q5_0 * bx0 = (block_q5_0 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx; + + const int ql = get_int_from_uint8(bxi->qs, kqsx); + const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + const int kbxd = k % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { + int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d; + } +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + int u[2*VDR_Q5_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + } + + return vec_dot_q8_0_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; + + int vl[VDR_Q5_1_Q8_1_MMVQ]; + int vh[VDR_Q5_1_Q8_1_MMVQ]; + int u[2*VDR_Q5_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i); + vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1); + } + + return vec_dot_q5_1_q8_1_impl(vl, vh, u, bq5_1->dm, bq8_1->ds); +} + +template static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; +} + +template static __device__ __forceinline__ void load_tiles_q5_1( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_1; + const int kqsx = k % QI5_1; + + const block_q5_1 * bx0 = (block_q5_1 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { + int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; + } +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; + + int u[2*VDR_Q5_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + } + + return vec_dot_q8_1_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; + + int v[VDR_Q8_0_Q8_1_MMVQ]; + int u[VDR_Q8_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_int8(bq8_0->qs, iqs + i); + u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + } + + return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); +} + +template static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0]; + + *x_ql = tile_x_qs; + *x_dm = (half2 *) tile_x_d; +} + +template static __device__ __forceinline__ void load_tiles_q8_0( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI8_0; + const int kqsx = k % QI8_0; + float * x_dmf = (float *) x_dm; + + const block_q8_0 * bx0 = (block_q8_0 *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { + int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; + } +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + return vec_dot_q8_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], + y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q2_K * bq2_K = (const block_q2_K *) vbq; + + const int bq8_offset = QR2_K * (iqs / QI8_1); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const uint8_t * scales = bq2_K->scales + scale_offset; + + const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs); + int u[QR2_K]; + float d8[QR2_K]; + +#pragma unroll + for (int i = 0; i < QR2_K; ++ i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2half(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); +} + +template static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q2_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI2_K; + const int kqsx = k % QI2_K; + + const block_q2_K * bx0 = (block_q2_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; + const int kbxd = k % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { + int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4); + + x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); + } +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kbx = k / QI2_K; + const int ky = (k % QI2_K) * QR2_K; + const float * y_df = (const float *) y_ds; + + int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; + + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); + const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); + +#pragma unroll + for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { + v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; + } + + const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + + const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE; + return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q3_K * bq3_K = (const block_q3_K *) vbq; + + const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const float d = bq3_K->d; + + const int vl = get_int_from_uint8(bq3_K->qs, iqs); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2half(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); +} + +template static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K]; + __shared__ int tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_qh = tile_x_qh; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q3_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI3_K; + const int kqsx = k % QI3_K; + + const block_q3_K * bx0 = (block_q3_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; + const int kbxd = k % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { + int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { + int i = i0 + i_offset * 2 + k / (WARP_SIZE/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4); + + const int ksc = k % (QI3_K/4); + + const int ksc_low = ksc % (QI3_K/8); + const int shift_low = 4 * (ksc / (QI3_K/8)); + const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; + + const int ksc_high = QI3_K/8; + const int shift_high = 2 * ksc; + const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; + + const int sc = __vsubss4(sc_low | sc_high, 0x20202020); + + x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc; + } +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const int kbx = k / QI3_K; + const int ky = (k % QI3_K) * QR3_K; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * scales = ((int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + + int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); + const int shift = 2 * ((ky % 32) / 8); + const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; + + const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); + const int vlh = (vh << 2) & 0x04040404; + + v[l] = __vsubss4(vll, vlh); + } + + const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE; + return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + +#ifndef GGML_QKK_64 + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + int v[2]; + int u[2*QR4_K]; + float d8[QR4_K]; + + // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); + + // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 + // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 + // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 + // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 + + const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + v[0] = q4[0]; + v[1] = q4[4]; + + const uint16_t * scales = (const uint16_t *)bq4_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2half(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); + +#else + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + const uint16_t * a = (const uint16_t *)bq4_K->scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const float dall = bq4_K->dm[0]; + const float dmin = bq4_K->dm[1]; + + const float d8_1 = __low2float(bq8_1[0].ds); + const float d8_2 = __low2float(bq8_1[1].ds); + + const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); + const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); + const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); + const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + + const int * q4 = (const int *)bq4_K->qs + (iqs/2); + const int v1 = q4[0]; + const int v2 = q4[4]; + + const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0)); + const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); + const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); + const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0)); + + sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); + sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); + + return dall * sumf_d - dmin * sumf_m; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A + +#endif +} + +template static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q4_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_K; // == 0 if QK_K == 256 + const int kqsx = k % QI4_K; // == k if QK_K == 256 + + const block_q4_K * bx0 = (block_q4_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx; + + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { + int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd; + +#if QK_K == 256 + x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; +#else + x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]}; +#endif + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); + + const int * scales = (int *) bxi->scales; + + const int ksc = k % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); + + const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE; + return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + +#ifndef GGML_QKK_64 + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + int vl[2]; + int vh[2]; + int u[2*QR5_K]; + float d8[QR5_K]; + + const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2)); + const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4)); + + vl[0] = ql[0]; + vl[1] = ql[4]; + + vh[0] = qh[0] >> bq8_offset; + vh[1] = qh[4] >> bq8_offset; + + const uint16_t * scales = (const uint16_t *)bq5_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8); + +#else + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + const int8_t * s = bq5_K->scales; + + const float d = bq5_K->d; + + const float d8_1 = __low2half(bq8_1[0].ds); + const float d8_2 = __low2half(bq8_1[1].ds); + + const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); + const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); + const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); + const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + + const int * ql = (const int *)bq5_K->qs + (iqs/2); + const int vl1 = ql[0]; + const int vl2 = ql[4]; + + const int step = 4 * (iqs/2); // 0, 4, 8, 12 + const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6 + const int in = step%8; // 0, 4, 0, 4 + const int vh = (*((const int *)(bq5_K->qh + in))) >> im; + + const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f); + const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f); + const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); + const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); + + const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1]) + + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]); + + return d * sumf_d; + +#else + assert(false); + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A + +#endif +} + +template static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q5_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_K; // == 0 if QK_K == 256 + const int kqsx = k % QI5_K; // == k if QK_K == 256 + + const block_q5_K * bx0 = (block_q5_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx; + const int ky = QR5_K*kqsx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; + + const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; + x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { + int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd; + +#if QK_K == 256 + x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; +#endif + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); + + const int * scales = (int *) bxi->scales; + + const int ksc = k % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); + + const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k; + const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE; + return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_q6_K * bq6_K = (const block_q6_K *) vbq; + + const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); + const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); + const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); + + const int vl = get_int_from_uint8(bq6_K->ql, iqs); + const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift; + + const int8_t * scales = bq6_K->scales + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); + d8[i] = __low2half(bq8_1[bq8_offset + 2*i].ds); + } + + return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); +} + +template static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { + + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; +} + +template static __device__ __forceinline__ void load_tiles_q6_K( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { + + GGML_CUDA_ASSUME(i_offset >= 0); + GGML_CUDA_ASSUME(i_offset < nwarps); + GGML_CUDA_ASSUME(k >= 0); + GGML_CUDA_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI6_K; // == 0 if QK_K == 256 + const int kqsx = k % QI6_K; // == k if QK_K == 256 + + const block_q6_K * bx0 = (block_q6_K *) vx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx; + const int ky = QR6_K*kqsx; + + const int ql = get_int_from_uint8(bxi->ql, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); + const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; + const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; + + const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0; + const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { + int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd; + + x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4; + + x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); + } +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { + + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]); + + const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k; + const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE; + return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); +} + +template +static __device__ __forceinline__ void mul_mat_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + const int blocks_per_warp = WARP_SIZE / qi; + + const int & ncols_dst = ncols_y; + + const int row_dst_0 = blockIdx.x*mmq_y; + const int & row_x_0 = row_dst_0; + + const int col_dst_0 = blockIdx.y*mmq_x; + const int & col_y_0 = col_dst_0; + + int * tile_x_ql = nullptr; + half2 * tile_x_dm = nullptr; + int * tile_x_qh = nullptr; + int * tile_x_sc = nullptr; + + allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); + + __shared__ int tile_y_qs[mmq_x * WARP_SIZE]; + __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1]; + + float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {0.0f}; + + for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { + + load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, + threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x); + +#pragma unroll + for (int ir = 0; ir < qr; ++ir) { + const int kqs = ir*WARP_SIZE + threadIdx.x; + const int kbxd = kqs / QI8_1; + +#pragma unroll + for (int i = 0; i < mmq_x; i += nwarps) { + const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses + + const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd]; + + const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE; + tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); + } + +#pragma unroll + for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { + const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; + const int kby = threadIdx.x % (WARP_SIZE/QI8_1); + const int col_y_eff = min(col_y_0 + ids, ncols_y-1); + + // if the sum is not needed it's faster to transform the scale to f32 ahead of time + const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds; + half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; + if (need_sum) { + *dsi_dst = *dsi_src; + } else { + float * dfi_dst = (float *) dsi_dst; + *dfi_dst = __low2half(*dsi_src); + } + } + + __syncthreads(); + +// #pragma unroll // unrolling this loop causes too much register pressure + for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) { +#pragma unroll + for (int j = 0; j < mmq_x; j += nwarps) { +#pragma unroll + for (int i = 0; i < mmq_y; i += WARP_SIZE) { + sum[i/WARP_SIZE][j/nwarps] += vec_dot( + tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, + threadIdx.x + i, threadIdx.y + j, k); + } + } + } + + __syncthreads(); + } + } + +#pragma unroll + for (int j = 0; j < mmq_x; j += nwarps) { + const int col_dst = col_dst_0 + j + threadIdx.y; + + if (col_dst >= ncols_dst) { + return; + } + +#pragma unroll + for (int i = 0; i < mmq_y; i += WARP_SIZE) { + const int row_dst = row_dst_0 + threadIdx.x + i; + + if (row_dst >= nrows_dst) { + continue; + } + + dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps]; + } + } +} + +#define MMQ_X_Q4_0_RDNA2 64 +#define MMQ_Y_Q4_0_RDNA2 128 +#define NWARPS_Q4_0_RDNA2 8 +#define MMQ_X_Q4_0_RDNA1 64 +#define MMQ_Y_Q4_0_RDNA1 64 +#define NWARPS_Q4_0_RDNA1 8 +#define MMQ_X_Q4_0_AMPERE 64 +#define MMQ_Y_Q4_0_AMPERE 128 +#define NWARPS_Q4_0_AMPERE 4 +#define MMQ_X_Q4_0_PASCAL 64 +#define MMQ_Y_Q4_0_PASCAL 64 +#define NWARPS_Q4_0_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q4_0_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + mul_mat_q4_0( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q4_0_RDNA2; + const int mmq_y = MMQ_Y_Q4_0_RDNA2; + const int nwarps = NWARPS_Q4_0_RDNA2; +#else + const int mmq_x = MMQ_X_Q4_0_RDNA1; + const int mmq_y = MMQ_Y_Q4_0_RDNA1; + const int nwarps = NWARPS_Q4_0_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q4_0_AMPERE; + const int mmq_y = MMQ_Y_Q4_0_AMPERE; + const int nwarps = NWARPS_Q4_0_AMPERE; + + mul_mat_q, + load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q4_0_PASCAL; + const int mmq_y = MMQ_Y_Q4_0_PASCAL; + const int nwarps = NWARPS_Q4_0_PASCAL; + + mul_mat_q, + load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q4_0_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q4_1_RDNA2 64 +#define MMQ_Y_Q4_1_RDNA2 128 +#define NWARPS_Q4_1_RDNA2 8 +#define MMQ_X_Q4_1_RDNA1 64 +#define MMQ_Y_Q4_1_RDNA1 64 +#define NWARPS_Q4_1_RDNA1 8 +#define MMQ_X_Q4_1_AMPERE 64 +#define MMQ_Y_Q4_1_AMPERE 128 +#define NWARPS_Q4_1_AMPERE 4 +#define MMQ_X_Q4_1_PASCAL 64 +#define MMQ_Y_Q4_1_PASCAL 64 +#define NWARPS_Q4_1_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#elif __CUDA_ARCH__ < CC_VOLTA + __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2) +#endif // __CUDA_ARCH__ < CC_VOLTA + mul_mat_q4_1( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q4_1_RDNA2; + const int mmq_y = MMQ_Y_Q4_1_RDNA2; + const int nwarps = NWARPS_Q4_1_RDNA2; +#else + const int mmq_x = MMQ_X_Q4_1_RDNA1; + const int mmq_y = MMQ_Y_Q4_1_RDNA1; + const int nwarps = NWARPS_Q4_1_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q4_1_AMPERE; + const int mmq_y = MMQ_Y_Q4_1_AMPERE; + const int nwarps = NWARPS_Q4_1_AMPERE; + + mul_mat_q, + load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q4_1_PASCAL; + const int mmq_y = MMQ_Y_Q4_1_PASCAL; + const int nwarps = NWARPS_Q4_1_PASCAL; + + mul_mat_q, + load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q4_1_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q5_0_RDNA2 64 +#define MMQ_Y_Q5_0_RDNA2 128 +#define NWARPS_Q5_0_RDNA2 8 +#define MMQ_X_Q5_0_RDNA1 64 +#define MMQ_Y_Q5_0_RDNA1 64 +#define NWARPS_Q5_0_RDNA1 8 +#define MMQ_X_Q5_0_AMPERE 128 +#define MMQ_Y_Q5_0_AMPERE 64 +#define NWARPS_Q5_0_AMPERE 4 +#define MMQ_X_Q5_0_PASCAL 64 +#define MMQ_Y_Q5_0_PASCAL 64 +#define NWARPS_Q5_0_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q5_0_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + mul_mat_q5_0( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q5_0_RDNA2; + const int mmq_y = MMQ_Y_Q5_0_RDNA2; + const int nwarps = NWARPS_Q5_0_RDNA2; +#else + const int mmq_x = MMQ_X_Q5_0_RDNA1; + const int mmq_y = MMQ_Y_Q5_0_RDNA1; + const int nwarps = NWARPS_Q5_0_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q5_0_AMPERE; + const int mmq_y = MMQ_Y_Q5_0_AMPERE; + const int nwarps = NWARPS_Q5_0_AMPERE; + + mul_mat_q, + load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q5_0_PASCAL; + const int mmq_y = MMQ_Y_Q5_0_PASCAL; + const int nwarps = NWARPS_Q5_0_PASCAL; + + mul_mat_q, + load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q5_0_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q5_1_RDNA2 64 +#define MMQ_Y_Q5_1_RDNA2 128 +#define NWARPS_Q5_1_RDNA2 8 +#define MMQ_X_Q5_1_RDNA1 64 +#define MMQ_Y_Q5_1_RDNA1 64 +#define NWARPS_Q5_1_RDNA1 8 +#define MMQ_X_Q5_1_AMPERE 128 +#define MMQ_Y_Q5_1_AMPERE 64 +#define NWARPS_Q5_1_AMPERE 4 +#define MMQ_X_Q5_1_PASCAL 64 +#define MMQ_Y_Q5_1_PASCAL 64 +#define NWARPS_Q5_1_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q5_1_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +mul_mat_q5_1( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q5_1_RDNA2; + const int mmq_y = MMQ_Y_Q5_1_RDNA2; + const int nwarps = NWARPS_Q5_1_RDNA2; +#else + const int mmq_x = MMQ_X_Q5_1_RDNA1; + const int mmq_y = MMQ_Y_Q5_1_RDNA1; + const int nwarps = NWARPS_Q5_1_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q5_1_AMPERE; + const int mmq_y = MMQ_Y_Q5_1_AMPERE; + const int nwarps = NWARPS_Q5_1_AMPERE; + + mul_mat_q, + load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q5_1_PASCAL; + const int mmq_y = MMQ_Y_Q5_1_PASCAL; + const int nwarps = NWARPS_Q5_1_PASCAL; + + mul_mat_q, + load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q5_1_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q8_0_RDNA2 64 +#define MMQ_Y_Q8_0_RDNA2 128 +#define NWARPS_Q8_0_RDNA2 8 +#define MMQ_X_Q8_0_RDNA1 64 +#define MMQ_Y_Q8_0_RDNA1 64 +#define NWARPS_Q8_0_RDNA1 8 +#define MMQ_X_Q8_0_AMPERE 128 +#define MMQ_Y_Q8_0_AMPERE 64 +#define NWARPS_Q8_0_AMPERE 4 +#define MMQ_X_Q8_0_PASCAL 64 +#define MMQ_Y_Q8_0_PASCAL 64 +#define NWARPS_Q8_0_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q8_0_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + mul_mat_q8_0( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q8_0_RDNA2; + const int mmq_y = MMQ_Y_Q8_0_RDNA2; + const int nwarps = NWARPS_Q8_0_RDNA2; +#else + const int mmq_x = MMQ_X_Q8_0_RDNA1; + const int mmq_y = MMQ_Y_Q8_0_RDNA1; + const int nwarps = NWARPS_Q8_0_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q8_0_AMPERE; + const int mmq_y = MMQ_Y_Q8_0_AMPERE; + const int nwarps = NWARPS_Q8_0_AMPERE; + + mul_mat_q, + load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q8_0_PASCAL; + const int mmq_y = MMQ_Y_Q8_0_PASCAL; + const int nwarps = NWARPS_Q8_0_PASCAL; + + mul_mat_q, + load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q8_0_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q2_K_RDNA2 64 +#define MMQ_Y_Q2_K_RDNA2 128 +#define NWARPS_Q2_K_RDNA2 8 +#define MMQ_X_Q2_K_RDNA1 128 +#define MMQ_Y_Q2_K_RDNA1 32 +#define NWARPS_Q2_K_RDNA1 8 +#define MMQ_X_Q2_K_AMPERE 64 +#define MMQ_Y_Q2_K_AMPERE 128 +#define NWARPS_Q2_K_AMPERE 4 +#define MMQ_X_Q2_K_PASCAL 64 +#define MMQ_Y_Q2_K_PASCAL 64 +#define NWARPS_Q2_K_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q2_K_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +mul_mat_q2_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q2_K_RDNA2; + const int mmq_y = MMQ_Y_Q2_K_RDNA2; + const int nwarps = NWARPS_Q2_K_RDNA2; +#else + const int mmq_x = MMQ_X_Q2_K_RDNA1; + const int mmq_y = MMQ_Y_Q2_K_RDNA1; + const int nwarps = NWARPS_Q2_K_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q2_K_AMPERE; + const int mmq_y = MMQ_Y_Q2_K_AMPERE; + const int nwarps = NWARPS_Q2_K_AMPERE; + + mul_mat_q, + load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q2_K_PASCAL; + const int mmq_y = MMQ_Y_Q2_K_PASCAL; + const int nwarps = NWARPS_Q2_K_PASCAL; + + mul_mat_q, + load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q2_K_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q3_K_RDNA2 128 +#define MMQ_Y_Q3_K_RDNA2 64 +#define NWARPS_Q3_K_RDNA2 8 +#define MMQ_X_Q3_K_RDNA1 32 +#define MMQ_Y_Q3_K_RDNA1 128 +#define NWARPS_Q3_K_RDNA1 8 +#define MMQ_X_Q3_K_AMPERE 128 +#define MMQ_Y_Q3_K_AMPERE 128 +#define NWARPS_Q3_K_AMPERE 4 +#define MMQ_X_Q3_K_PASCAL 64 +#define MMQ_Y_Q3_K_PASCAL 64 +#define NWARPS_Q3_K_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#elif __CUDA_ARCH__ < CC_VOLTA + __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2) +#endif // __CUDA_ARCH__ < CC_VOLTA + mul_mat_q3_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q3_K_RDNA2; + const int mmq_y = MMQ_Y_Q3_K_RDNA2; + const int nwarps = NWARPS_Q3_K_RDNA2; +#else + const int mmq_x = MMQ_X_Q3_K_RDNA1; + const int mmq_y = MMQ_Y_Q3_K_RDNA1; + const int nwarps = NWARPS_Q3_K_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q3_K_AMPERE; + const int mmq_y = MMQ_Y_Q3_K_AMPERE; + const int nwarps = NWARPS_Q3_K_AMPERE; + + mul_mat_q, + load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q3_K_PASCAL; + const int mmq_y = MMQ_Y_Q3_K_PASCAL; + const int nwarps = NWARPS_Q3_K_PASCAL; + + mul_mat_q, + load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q3_K_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q4_K_RDNA2 64 +#define MMQ_Y_Q4_K_RDNA2 128 +#define NWARPS_Q4_K_RDNA2 8 +#define MMQ_X_Q4_K_RDNA1 32 +#define MMQ_Y_Q4_K_RDNA1 64 +#define NWARPS_Q4_K_RDNA1 8 +#define MMQ_X_Q4_K_AMPERE 64 +#define MMQ_Y_Q4_K_AMPERE 128 +#define NWARPS_Q4_K_AMPERE 4 +#define MMQ_X_Q4_K_PASCAL 64 +#define MMQ_Y_Q4_K_PASCAL 64 +#define NWARPS_Q4_K_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#elif __CUDA_ARCH__ < CC_VOLTA + __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2) +#endif // __CUDA_ARCH__ < CC_VOLTA + mul_mat_q4_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q4_K_RDNA2; + const int mmq_y = MMQ_Y_Q4_K_RDNA2; + const int nwarps = NWARPS_Q4_K_RDNA2; +#else + const int mmq_x = MMQ_X_Q4_K_RDNA1; + const int mmq_y = MMQ_Y_Q4_K_RDNA1; + const int nwarps = NWARPS_Q4_K_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q4_K_AMPERE; + const int mmq_y = MMQ_Y_Q4_K_AMPERE; + const int nwarps = NWARPS_Q4_K_AMPERE; + + mul_mat_q, + load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q4_K_PASCAL; + const int mmq_y = MMQ_Y_Q4_K_PASCAL; + const int nwarps = NWARPS_Q4_K_PASCAL; + + mul_mat_q, + load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q4_K_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q5_K_RDNA2 64 +#define MMQ_Y_Q5_K_RDNA2 128 +#define NWARPS_Q5_K_RDNA2 8 +#define MMQ_X_Q5_K_RDNA1 32 +#define MMQ_Y_Q5_K_RDNA1 64 +#define NWARPS_Q5_K_RDNA1 8 +#define MMQ_X_Q5_K_AMPERE 64 +#define MMQ_Y_Q5_K_AMPERE 128 +#define NWARPS_Q5_K_AMPERE 4 +#define MMQ_X_Q5_K_PASCAL 64 +#define MMQ_Y_Q5_K_PASCAL 64 +#define NWARPS_Q5_K_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q5_K_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +mul_mat_q5_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q5_K_RDNA2; + const int mmq_y = MMQ_Y_Q5_K_RDNA2; + const int nwarps = NWARPS_Q5_K_RDNA2; +#else + const int mmq_x = MMQ_X_Q5_K_RDNA1; + const int mmq_y = MMQ_Y_Q5_K_RDNA1; + const int nwarps = NWARPS_Q5_K_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q5_K_AMPERE; + const int mmq_y = MMQ_Y_Q5_K_AMPERE; + const int nwarps = NWARPS_Q5_K_AMPERE; + + mul_mat_q, + load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q5_K_PASCAL; + const int mmq_y = MMQ_Y_Q5_K_PASCAL; + const int nwarps = NWARPS_Q5_K_PASCAL; + + mul_mat_q, + load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q5_K_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +#define MMQ_X_Q6_K_RDNA2 64 +#define MMQ_Y_Q6_K_RDNA2 128 +#define NWARPS_Q6_K_RDNA2 8 +#define MMQ_X_Q6_K_RDNA1 32 +#define MMQ_Y_Q6_K_RDNA1 64 +#define NWARPS_Q6_K_RDNA1 8 +#define MMQ_X_Q6_K_AMPERE 64 +#define MMQ_Y_Q6_K_AMPERE 64 +#define NWARPS_Q6_K_AMPERE 4 +#define MMQ_X_Q6_K_PASCAL 64 +#define MMQ_Y_Q6_K_PASCAL 64 +#define NWARPS_Q6_K_PASCAL 8 + +template static __global__ void +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#elif __CUDA_ARCH__ < CC_VOLTA + __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2) +#endif // __CUDA_ARCH__ < CC_VOLTA + mul_mat_q6_K( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + const int mmq_x = MMQ_X_Q6_K_RDNA2; + const int mmq_y = MMQ_Y_Q6_K_RDNA2; + const int nwarps = NWARPS_Q6_K_RDNA2; +#else + const int mmq_x = MMQ_X_Q6_K_RDNA1; + const int mmq_y = MMQ_Y_Q6_K_RDNA1; + const int nwarps = NWARPS_Q6_K_RDNA1; +#endif // defined(RDNA3) || defined(RDNA2) + + mul_mat_q, + load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= CC_VOLTA + const int mmq_x = MMQ_X_Q6_K_AMPERE; + const int mmq_y = MMQ_Y_Q6_K_AMPERE; + const int nwarps = NWARPS_Q6_K_AMPERE; + + mul_mat_q, + load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + +#elif __CUDA_ARCH__ >= MIN_CC_DP4A + const int mmq_x = MMQ_X_Q6_K_PASCAL; + const int mmq_y = MMQ_Y_Q6_K_PASCAL; + const int nwarps = NWARPS_Q6_K_PASCAL; + + mul_mat_q, + load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); +#else + (void) vec_dot_q6_K_q8_1_mul_mat; + assert(false); +#endif // __CUDA_ARCH__ >= CC_VOLTA +} + +template +static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) { + const int row = blockIdx.y*blockDim.y + threadIdx.y; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = 0; i < blocks_per_row; i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index + + const int iby = (i + threadIdx.x / (qi/vdr)) * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + +template +static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { + // qk = quantized weights per x block + // qr = number of quantized weights per data value in x block + const int row = blockIdx.y*blockDim.y + threadIdx.y; + + if (row >= nrows) { + return; + } + + const int tid = threadIdx.x; + + const int iter_stride = 2*GGML_CUDA_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter + const int y_offset = qr == 1 ? 1 : qk/2; + +// partial sum for each thread +#ifdef GGML_CUDA_F16 + half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics +#else + float tmp = 0.0f; +#endif // GGML_CUDA_F16 + + for (int i = 0; i < ncols; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/qk; // x block index + const int iqs = (col%qk)/qr; // x quant index + const int iybs = col - col%qk; // y block start index + +// processing >2 values per i iter is faster for fast GPUs +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + // process 2 vals per j iter + + // dequantize + // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val + dfloat2 v; + dequantize_kernel(vx, ib, iqs + j/qr, v); + + // matrix multiplication + // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 +#ifdef GGML_CUDA_F16 + tmp += __hmul2(v, { + y[iybs + iqs + j/qr + 0], + y[iybs + iqs + j/qr + y_offset] + }); +#else + tmp += v.x * y[iybs + iqs + j/qr + 0]; + tmp += v.y * y[iybs + iqs + j/qr + y_offset]; +#endif // GGML_CUDA_F16 + } + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (tid == 0) { +#ifdef GGML_CUDA_F16 + dst[row] = tmp.x + tmp.y; +#else + dst[row] = tmp; +#endif // GGML_CUDA_F16 + } +} + +static __global__ void mul_mat_p021_f16_f32( + const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) { + + const half * x = (const half *) vx; + + const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = channel / (nchannels_y / nchannels_x); + + const int nrows_y = ncols_x; + const int nrows_dst = nrows_x; + const int row_dst = row_x; + + float tmp = 0.0f; + + for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { + const int col_x = col_x0 + threadIdx.x; + + if (col_x >= ncols_x) { + break; + } + + // x is transposed and permuted + const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; + const float xi = __half2float(x[ix]); + + const int row_y = col_x; + + + // y is not transposed but permuted + const int iy = channel*nrows_y + row_y; + + tmp += xi * y[iy]; + } + + // dst is not transposed and not permuted + const int idst = channel*nrows_dst + row_dst; + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[idst] = tmp; + } +} + +static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous + const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, + const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) { + + const half * x = (const half *) vx; + + const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = channel / channel_x_divisor; + + const int nrows_y = ncols_x; + const int nrows_dst = nrows_x; + const int row_dst = row_x; + + const int idst = channel*nrows_dst + row_dst; + + float tmp = 0.0f; + + for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { + const int col_x = col_x0 + threadIdx.x; + + if (col_x >= ncols_x) { + break; + } + + const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; + const float xi = __half2float(x[ix]); + + const int row_y = col_x; + + const int iy = channel*nrows_y + row_y; + + tmp += xi * y[iy]; + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[idst] = tmp; + } +} + +static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + float * dsti = (float *) cdsti; + + *dsti = *xi; +} + +static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { + const float * xi = (const float *) cxi; + half * dsti = (half *) cdsti; + + *dsti = __float2half(*xi); +} + +template +static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= ne) { + return; + } + + // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor + // then combine those indices with the corresponding byte offsets to get the total offsets + const int i02 = i / (ne00*ne01); + const int i01 = (i - i02*ne01*ne00) / ne00; + const int i00 = i - i02*ne01*ne00 - i01*ne00; + const int x_offset = i00*nb00 + i01*nb01 + i02*nb02; + + const int i12 = i / (ne10*ne11); + const int i11 = (i - i12*ne10*ne11) / ne10; + const int i10 = i - i12*ne10*ne11 - i11*ne10; + const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; + + cpy_1(cx + x_offset, cdst + dst_offset); +} + +// rope == RoPE == rotary positional embedding + +template +static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale) { + const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (col >= ncols) { + return; + } + + const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int i = row*ncols + col; + const int i2 = row/p_delta_rows; + + const int p = has_pos ? pos[i2] : 0; + const float p0 = p*freq_scale; + const float theta = p0*powf(theta_scale, col/2); + const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta); + + const float x0 = x[i + 0]; + const float x1 = x[i + 1]; + + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + 1] = x0*sin_theta + x1*cos_theta; +} + +template +static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale) { + const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (col >= ncols) { + return; + } + + const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int i = row*ncols + col/2; + const int i2 = row/p_delta_rows; + + const int p = has_pos ? pos[i2] : 0; + const float p0 = p*freq_scale; + const float theta = p0*powf(theta_scale, col/2); + const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta); + + const float x0 = x[i + 0]; + const float x1 = x[i + ncols/2]; + + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; +} + +static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale, const int n_ctx) { + const int col = blockDim.x*blockIdx.x + threadIdx.x; + const int half_n_dims = ncols/4; + + if (col >= half_n_dims) { + return; + } + + const int row = blockDim.y*blockIdx.y + threadIdx.y; + const int i = row*ncols + col; + const int i2 = row/p_delta_rows; + + const float col_theta_scale = powf(theta_scale, col); + // FIXME: this is likely wrong + const int p = pos != nullptr ? pos[i2] : 0; + + const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale; + const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta); + + const float x0 = x[i + 0]; + const float x1 = x[i + half_n_dims]; + + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; + + const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale; + const float sin_block_theta = sinf(block_theta); + const float cos_block_theta = cosf(block_theta); + + const float x2 = x[i + half_n_dims * 2]; + const float x3 = x[i + half_n_dims * 3]; + + dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta; + dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; +} + +static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, + const int n_heads_log2_floor, const float m0, const float m1) { + const int col = blockDim.x*blockIdx.x + threadIdx.x; + + if (col >= ncols) { + return; + } + + const int row = blockDim.y*blockIdx.y + threadIdx.y; + const int i = row*ncols + col; + + const int k = row/k_rows; + + float m_k; + if (k < n_heads_log2_floor) { + m_k = powf(m0, k + 1); + } else { + m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); + } + + dst[i] = col * m_k + x[i]; +} + +static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) { + const int col = blockDim.y*blockIdx.y + threadIdx.y; + const int row = blockDim.x*blockIdx.x + threadIdx.x; + + if (col >= ncols) { + return; + } + + const int i = row*ncols + col; + // dst[i] = col > n_past + row ? -INFINITY : x[i]; + dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU +} + +// the CUDA soft max implementation differs from the CPU implementation +// instead of doubles floats are used +static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) { + const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int block_size = blockDim.y; + const int tid = threadIdx.y; + + float max_val = -INFINITY; + + for (int col = tid; col < ncols; col += block_size) { + const int i = row*ncols + col; + max_val = max(max_val, x[i]); + } + + // find the max value in the block +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32)); + } + + float tmp = 0.f; + + for (int col = tid; col < ncols; col += block_size) { + const int i = row*ncols + col; + const float val = expf(x[i] - max_val); + tmp += val; + dst[i] = val; + } + + // sum up partial sums +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + const float inv_tmp = 1.f / tmp; + + for (int col = tid; col < ncols; col += block_size) { + const int i = row*ncols + col; + dst[i] *= inv_tmp; + } +} + +static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = scale * x[i]; +} + +static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { + const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; + add_f32<<>>(x, y, dst, kx, ky); +} + +static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; + add_f16_f32_f16<<>>(x, y, dst, k); +} + +static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { + const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; + mul_f32<<>>(x, y, dst, kx, ky); +} + +static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; + gelu_f32<<>>(x, dst, k); +} + +static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + silu_f32<<>>(x, dst, k); +} + +static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + norm_f32<<>>(x, dst, ncols); + } else { + const dim3 block_dims(1024, 1, 1); + norm_f32<1024><<>>(x, dst, ncols); + } +} + +static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_f32<<>>(x, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024><<>>(x, dst, ncols, eps); + } +} + +static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) { + const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, ky, 1); + const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1); + quantize_q8_1<<>>(x, vy, kx, kx_padded); +} + +template +static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); +} + +template +static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); +} + +template +static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); +} + +template +static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); +} + +template +static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); +} + +template +static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; +#if QK_K == 256 + dequantize_block_q2_K<<>>(vx, y); +#else + dequantize_block_q2_K<<>>(vx, y); +#endif +} + +template +static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; +#if QK_K == 256 + dequantize_block_q3_K<<>>(vx, y); +#else + dequantize_block_q3_K<<>>(vx, y); +#endif +} + +template +static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q4_K<<>>(vx, y); +} + +template +static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; +#if QK_K == 256 + dequantize_block_q5_K<<>>(vx, y); +#else + dequantize_block_q5_K<<>>(vx, y); +#endif +} + +template +static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = k / QK_K; +#if QK_K == 256 + dequantize_block_q6_K<<>>(vx, y); +#else + dequantize_block_q6_K<<>>(vx, y); +#endif +} + +static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols, nrows); +} + +static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const dim3 block_dims(32, 1, 1); + dequantize_mul_mat_vec_q5_k<<>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); +} + +static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK4_1 == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK5_0 == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK5_1 == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<1, 1, convert_f16><<>>(vx, y, k); +} + +static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + dequantize_block<1, 1, convert_f32><<>>(vx, y, k); +} + +static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec<1, 1, convert_f16> + <<>>(vx, y, dst, ncols, nrows); +} + +static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_row_q4_0_cuda; + case GGML_TYPE_Q4_1: + return dequantize_row_q4_1_cuda; + case GGML_TYPE_Q5_0: + return dequantize_row_q5_0_cuda; + case GGML_TYPE_Q5_1: + return dequantize_row_q5_1_cuda; + case GGML_TYPE_Q8_0: + return dequantize_row_q8_0_cuda; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; + case GGML_TYPE_F32: + return convert_fp32_to_fp16_cuda; + default: + return nullptr; + } +} + +static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_row_q4_0_cuda; + case GGML_TYPE_Q4_1: + return dequantize_row_q4_1_cuda; + case GGML_TYPE_Q5_0: + return dequantize_row_q5_0_cuda; + case GGML_TYPE_Q5_1: + return dequantize_row_q5_1_cuda; + case GGML_TYPE_Q8_0: + return dequantize_row_q8_0_cuda; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; + case GGML_TYPE_F16: + return convert_fp16_to_fp32_cuda; + default: + return nullptr; + } +} + +static void ggml_mul_mat_q4_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q4_0_RDNA2; + mmq_y = MMQ_Y_Q4_0_RDNA2; + nwarps = NWARPS_Q4_0_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q4_0_RDNA1; + mmq_y = MMQ_Y_Q4_0_RDNA1; + nwarps = NWARPS_Q4_0_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q4_0_AMPERE; + mmq_y = MMQ_Y_Q4_0_AMPERE; + nwarps = NWARPS_Q4_0_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q4_0_PASCAL; + mmq_y = MMQ_Y_Q4_0_PASCAL; + nwarps = NWARPS_Q4_0_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q4_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q4_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q4_1_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q4_1_RDNA2; + mmq_y = MMQ_Y_Q4_1_RDNA2; + nwarps = NWARPS_Q4_1_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q4_1_RDNA1; + mmq_y = MMQ_Y_Q4_1_RDNA1; + nwarps = NWARPS_Q4_1_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q4_1_AMPERE; + mmq_y = MMQ_Y_Q4_1_AMPERE; + nwarps = NWARPS_Q4_1_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q4_1_PASCAL; + mmq_y = MMQ_Y_Q4_1_PASCAL; + nwarps = NWARPS_Q4_1_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q4_1<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q4_1<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q5_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q5_0_RDNA2; + mmq_y = MMQ_Y_Q5_0_RDNA2; + nwarps = NWARPS_Q5_0_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q5_0_RDNA1; + mmq_y = MMQ_Y_Q5_0_RDNA1; + nwarps = NWARPS_Q5_0_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q5_0_AMPERE; + mmq_y = MMQ_Y_Q5_0_AMPERE; + nwarps = NWARPS_Q5_0_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q5_0_PASCAL; + mmq_y = MMQ_Y_Q5_0_PASCAL; + nwarps = NWARPS_Q5_0_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q5_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q5_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q5_1_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q5_1_RDNA2; + mmq_y = MMQ_Y_Q5_1_RDNA2; + nwarps = NWARPS_Q5_1_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q5_1_RDNA1; + mmq_y = MMQ_Y_Q5_1_RDNA1; + nwarps = NWARPS_Q5_1_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q5_1_AMPERE; + mmq_y = MMQ_Y_Q5_1_AMPERE; + nwarps = NWARPS_Q5_1_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q5_1_PASCAL; + mmq_y = MMQ_Y_Q5_1_PASCAL; + nwarps = NWARPS_Q5_1_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q5_1<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q5_1<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q8_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q8_0_RDNA2; + mmq_y = MMQ_Y_Q8_0_RDNA2; + nwarps = NWARPS_Q8_0_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q8_0_RDNA1; + mmq_y = MMQ_Y_Q8_0_RDNA1; + nwarps = NWARPS_Q8_0_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q8_0_AMPERE; + mmq_y = MMQ_Y_Q8_0_AMPERE; + nwarps = NWARPS_Q8_0_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q8_0_PASCAL; + mmq_y = MMQ_Y_Q8_0_PASCAL; + nwarps = NWARPS_Q8_0_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q8_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q8_0<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q2_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q2_K_RDNA2; + mmq_y = MMQ_Y_Q2_K_RDNA2; + nwarps = NWARPS_Q2_K_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q2_K_RDNA1; + mmq_y = MMQ_Y_Q2_K_RDNA1; + nwarps = NWARPS_Q2_K_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q2_K_AMPERE; + mmq_y = MMQ_Y_Q2_K_AMPERE; + nwarps = NWARPS_Q2_K_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q2_K_PASCAL; + mmq_y = MMQ_Y_Q2_K_PASCAL; + nwarps = NWARPS_Q2_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q2_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q2_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q3_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + +#if QK_K == 256 + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q3_K_RDNA2; + mmq_y = MMQ_Y_Q3_K_RDNA2; + nwarps = NWARPS_Q3_K_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q3_K_RDNA1; + mmq_y = MMQ_Y_Q3_K_RDNA1; + nwarps = NWARPS_Q3_K_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q3_K_AMPERE; + mmq_y = MMQ_Y_Q3_K_AMPERE; + nwarps = NWARPS_Q3_K_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q3_K_PASCAL; + mmq_y = MMQ_Y_Q3_K_PASCAL; + nwarps = NWARPS_Q3_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q3_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q3_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +#endif +} + +static void ggml_mul_mat_q4_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q4_K_RDNA2; + mmq_y = MMQ_Y_Q4_K_RDNA2; + nwarps = NWARPS_Q4_K_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q4_K_RDNA1; + mmq_y = MMQ_Y_Q4_K_RDNA1; + nwarps = NWARPS_Q4_K_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q4_K_AMPERE; + mmq_y = MMQ_Y_Q4_K_AMPERE; + nwarps = NWARPS_Q4_K_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q4_K_PASCAL; + mmq_y = MMQ_Y_Q4_K_PASCAL; + nwarps = NWARPS_Q4_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q4_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q4_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q5_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q5_K_RDNA2; + mmq_y = MMQ_Y_Q5_K_RDNA2; + nwarps = NWARPS_Q5_K_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q5_K_RDNA1; + mmq_y = MMQ_Y_Q5_K_RDNA1; + nwarps = NWARPS_Q5_K_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q5_K_AMPERE; + mmq_y = MMQ_Y_Q5_K_AMPERE; + nwarps = NWARPS_Q5_K_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q5_K_PASCAL; + mmq_y = MMQ_Y_Q5_K_PASCAL; + nwarps = NWARPS_Q5_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q5_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q5_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_q6_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, + const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + const int compute_capability = g_compute_capabilities[id]; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= CC_RDNA2) { + mmq_x = MMQ_X_Q6_K_RDNA2; + mmq_y = MMQ_Y_Q6_K_RDNA2; + nwarps = NWARPS_Q6_K_RDNA2; + } else if (compute_capability >= CC_OFFSET_AMD) { + mmq_x = MMQ_X_Q6_K_RDNA1; + mmq_y = MMQ_Y_Q6_K_RDNA1; + nwarps = NWARPS_Q6_K_RDNA1; + } else if (compute_capability >= CC_VOLTA) { + mmq_x = MMQ_X_Q6_K_AMPERE; + mmq_y = MMQ_Y_Q6_K_AMPERE; + nwarps = NWARPS_Q6_K_AMPERE; + } else if (compute_capability >= MIN_CC_DP4A) { + mmq_x = MMQ_X_Q6_K_PASCAL; + mmq_y = MMQ_Y_Q6_K_PASCAL; + nwarps = NWARPS_Q6_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + mul_mat_q6_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } else { + const bool need_check = true; + mul_mat_q6_K<<>> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + } +} + +static void ggml_mul_mat_p021_f16_f32_cuda( + const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, + const int nchannels_x, const int nchannels_y, cudaStream_t stream) { + + const dim3 block_nums(1, nrows_x, nchannels_y); + const dim3 block_dims(WARP_SIZE, 1, 1); + mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y); +} + +static void ggml_mul_mat_vec_nc_f16_f32_cuda( + const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, + const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) { + + const dim3 block_nums(1, nrows_x, nchannels_y); + const dim3 block_dims(WARP_SIZE, 1, 1); + mul_mat_vec_nc_f16_f32<<>> + (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x); +} + +static void ggml_cpy_f32_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_f32_f16<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + +static void ggml_cpy_f32_f16_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_f32_f16<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + +static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; + scale_f32<<>>(x, dst, scale, k); +} + +template +static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale, cudaStream_t stream) { + GGML_ASSERT(ncols % 2 == 0); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nrows, num_blocks_x, 1); + if (pos == nullptr) { + rope<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + } else { + rope<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + } +} + +template +static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale, cudaStream_t stream) { + GGML_ASSERT(ncols % 2 == 0); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nrows, num_blocks_x, 1); + if (pos == nullptr) { + rope_neox<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + } else { + rope_neox<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); + } +} + +static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, + const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { + GGML_ASSERT(ncols % 4 == 0); + const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); + const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE; + const dim3 block_nums(num_blocks_x, nrows, 1); + rope_glm_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx); +} + +static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, + const int k_rows, const int n_heads_log2_floor, const float m0, + const float m1, cudaStream_t stream) { + const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1); + const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE); + const dim3 block_nums(num_blocks_x, nrows, 1); + alibi_f32<<>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1); +} + +static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { + const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1); + const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE; + const dim3 block_nums(nrows_x, block_num_x, 1); + diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); +} + +static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) { + const dim3 block_dims(1, WARP_SIZE, 1); + const dim3 block_nums(nrows_x, 1, 1); + soft_max_f32<<>>(x, dst, ncols_x); +} + +// buffer pool for cuda +#define MAX_CUDA_BUFFERS 256 + +struct scoped_spin_lock { + std::atomic_flag& lock; + scoped_spin_lock(std::atomic_flag& lock) : lock(lock) { + while (lock.test_and_set(std::memory_order_acquire)) { + ; // spin + } + } + ~scoped_spin_lock() { + lock.clear(std::memory_order_release); + } + scoped_spin_lock(const scoped_spin_lock&) = delete; + scoped_spin_lock& operator=(const scoped_spin_lock&) = delete; +}; + +struct cuda_buffer { + void * ptr = nullptr; + size_t size = 0; +}; + +static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS]; +static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; + +static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { + scoped_spin_lock lock(g_cuda_pool_lock); + int id; + CUDA_CHECK(cudaGetDevice(&id)); +#ifdef DEBUG_CUDA_MALLOC + int nnz = 0; + size_t max_size = 0, tot_size = 0; +#endif + size_t best_diff = 1ull << 36; + int ibest = -1; + for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { + cuda_buffer& b = g_cuda_buffer_pool[id][i]; + if (b.ptr != nullptr) { +#ifdef DEBUG_CUDA_MALLOC + ++nnz; + tot_size += b.size; + if (b.size > max_size) max_size = b.size; +#endif + if (b.size >= size) { + size_t diff = b.size - size; + if (diff < best_diff) { + best_diff = diff; + ibest = i; + if (!best_diff) { + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + } + } + } + } + if (ibest >= 0) { + cuda_buffer& b = g_cuda_buffer_pool[id][ibest]; + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } +#ifdef DEBUG_CUDA_MALLOC + fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz, + (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024)); +#endif + void * ptr; + size_t look_ahead_size = (size_t) (1.05 * size); + look_ahead_size = 256 * ((look_ahead_size + 255)/256); + CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size)); + *actual_size = look_ahead_size; + return ptr; +} + +static void ggml_cuda_pool_free(void * ptr, size_t size) { + scoped_spin_lock lock(g_cuda_pool_lock); + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { + cuda_buffer& b = g_cuda_buffer_pool[id][i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n"); + CUDA_CHECK(cudaFree(ptr)); +} + + +void ggml_init_cublas() { + static bool initialized = false; + + if (!initialized) { + +#ifdef __HIP_PLATFORM_AMD__ + // Workaround for a rocBLAS bug when using multiple graphics cards: + // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346 + rocblas_initialize(); + CUDA_CHECK(cudaDeviceSynchronize()); +#endif + + CUDA_CHECK(cudaGetDeviceCount(&g_device_count)); + GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); + int64_t total_vram = 0; + fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count); + for (int64_t id = 0; id < g_device_count; ++id) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); + fprintf(stderr, " Device %ld: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor); + + g_tensor_split[id] = total_vram; + total_vram += prop.totalGlobalMem; +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + g_compute_capabilities[id] = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD; +#else + g_compute_capabilities[id] = 100*prop.major + 10*prop.minor; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + } + for (int64_t id = 0; id < g_device_count; ++id) { + g_tensor_split[id] /= total_vram; + } + + for (int64_t id = 0; id < g_device_count; ++id) { + CUDA_CHECK(ggml_cuda_set_device(id)); + + // create cuda streams + for (int64_t is = 0; is < MAX_STREAMS; ++is) { + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[id][is], cudaStreamNonBlocking)); + } + + // create cublas handle + CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); + CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH)); + } + + // configure logging to stdout + // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); + + initialized = true; + } +} + +void ggml_cuda_set_tensor_split(const float * tensor_split) { + if (tensor_split == nullptr) { + return; + } + bool all_zero = true; + for (int i = 0; i < g_device_count; ++i) { + if (tensor_split[i] != 0.0f) { + all_zero = false; + break; + } + } + if (all_zero) { + return; + } + float split_sum = 0.0f; + for (int i = 0; i < g_device_count; ++i) { + g_tensor_split[i] = split_sum; + split_sum += tensor_split[i]; + } + for (int i = 0; i < g_device_count; ++i) { + g_tensor_split[i] /= split_sum; + } +} + +void * ggml_cuda_host_malloc(size_t size) { + if (getenv("GGML_CUDA_NO_PINNED") != nullptr) { + return nullptr; + } + + void * ptr = nullptr; + cudaError_t err = cudaMallocHost((void **) &ptr, size); + if (err != cudaSuccess) { + // The allocation error can be bypassed. A null ptr will assigned out of this function. + // This can fixed the OOM error in WSL. + cudaGetLastError(); + fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", + size/1024.0/1024.0, cudaGetErrorString(err)); + return nullptr; + } + + return ptr; +} + +void ggml_cuda_host_free(void * ptr) { + CUDA_CHECK(cudaFreeHost(ptr)); +} + +static cudaError_t ggml_cuda_cpy_tensor_2d( + void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { + + cudaMemcpyKind kind; + char * src_ptr; + if (src->backend == GGML_BACKEND_CPU) { + kind = cudaMemcpyHostToDevice; + src_ptr = (char *) src->data; + } else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) { + GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1])); + kind = cudaMemcpyDeviceToDevice; + struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; + int id; + CUDA_CHECK(cudaGetDevice(&id)); + src_ptr = (char *) extra->data_device[id]; + } else { + GGML_ASSERT(false); + } + char * dst_ptr = (char *) dst; + + const int64_t ne0 = src->ne[0]; + const int64_t nb0 = src->nb[0]; + const int64_t nb1 = src->nb[1]; + const int64_t nb2 = src->nb[2]; + const int64_t nb3 = src->nb[3]; + const enum ggml_type type = src->type; + const int64_t ts = ggml_type_size(type); + const int64_t bs = ggml_blck_size(type); + int64_t i1_diff = i1_high - i1_low; + + const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; + if (nb0 == ts && nb1 == ts*ne0/bs) { + return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream); + } else if (nb0 == ts) { + return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream); + } else { + for (int64_t i1 = 0; i1 < i1_diff; i1++) { + const void * rx = (const void *) ((const char *) x + i1*nb1); + void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); + // pretend the row is a matrix with cols=1 + cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream); + if (r != cudaSuccess) return r; + } + return cudaSuccess; + } +} + +inline void ggml_cuda_op_add( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream); + } else { + GGML_ASSERT(false); + } + + (void) src1; + (void) dst; +} + +inline void ggml_cuda_op_mul( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + mul_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream); + + (void) dst; +} + +inline void ggml_cuda_op_gelu( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + gelu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_silu( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + silu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_norm( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_rms_norm( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + rms_norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_mul_mat_q( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream) { + + const int64_t ne00 = src0->ne[0]; + + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne0 = dst->ne[0]; + + const int64_t row_diff = row_high - row_low; + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into + const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; + + switch (src0->type) { + case GGML_TYPE_Q4_0: + ggml_mul_mat_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q4_1: + ggml_mul_mat_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q5_0: + ggml_mul_mat_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q5_1: + ggml_mul_mat_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q8_0: + ggml_mul_mat_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q2_K: + ggml_mul_mat_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q3_K: + ggml_mul_mat_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q4_K: + ggml_mul_mat_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q5_K: + ggml_mul_mat_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + case GGML_TYPE_Q6_K: + ggml_mul_mat_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + break; + default: + GGML_ASSERT(false); + break; + } + + (void) src1; + (void) dst; + (void) src1_ddf_i; +} + +static int64_t get_row_rounding(ggml_type type) { + int64_t min_compute_capability = INT_MAX; + int64_t max_compute_capability = INT_MIN; + for (int64_t id = 0; id < g_device_count; ++id) { + if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { + if (min_compute_capability > g_compute_capabilities[id]) { + min_compute_capability = g_compute_capabilities[id]; + } + if (max_compute_capability < g_compute_capabilities[id]) { + max_compute_capability = g_compute_capabilities[id]; + } + } + } + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + switch(type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return max_compute_capability >= CC_RDNA2 ? 128 : 64; + case GGML_TYPE_F16: + return 1; + case GGML_TYPE_Q2_K: + return max_compute_capability >= CC_RDNA2 ? 128 : 32; + case GGML_TYPE_Q3_K: + return min_compute_capability < CC_RDNA2 ? 128 : 64; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + return max_compute_capability >= CC_RDNA2 ? 128 : 64; + default: + GGML_ASSERT(false); + } +#else + switch(type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return max_compute_capability >= CC_VOLTA ? 128 : 64; + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return 64; + case GGML_TYPE_F16: + return 1; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return max_compute_capability >= CC_VOLTA ? 128 : 64; + case GGML_TYPE_Q6_K: + return 64; + default: + GGML_ASSERT(false); + } +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +} + +inline void ggml_cuda_op_mul_mat_vec_q( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + switch (src0->type) { + case GGML_TYPE_Q4_0: + mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + default: + GGML_ASSERT(false); + break; + } + + (void) src1; + (void) dst; + (void) src1_ddf_i; + (void) src1_ncols; + (void) src1_padded_row_size; +} + +inline void ggml_cuda_op_dequantize_mul_mat_vec( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics +#ifdef GGML_CUDA_F16 + size_t ash; + dfloat * src1_dfloat = nullptr; // dfloat == half + + bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || + src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + + if (src1_convert_f16) { + src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash); + ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00, + ne00, 1, sizeof(float), 0, 0, + ne00, 1, sizeof(half), 0, 0, stream); + } +#else + const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion +#endif // GGML_CUDA_F16 + + switch (src0->type) { + case GGML_TYPE_Q4_0: + dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_1: + dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_0: + dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_K: + dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_F16: + convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; + default: + GGML_ASSERT(false); + break; + } + +#ifdef GGML_CUDA_F16 + if (src1_convert_f16) { + ggml_cuda_pool_free(src1_dfloat, ash); + } +#endif // GGML_CUDA_F16 + + (void) src1; + (void) dst; + (void) src1_ddq_i; + (void) src1_ncols; + (void) src1_padded_row_size; +} + +inline void ggml_cuda_op_mul_mat_cublas( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, const cudaStream_t & stream) { + + GGML_ASSERT(src0_dd_i != nullptr); + GGML_ASSERT(src1_ddf_i != nullptr); + GGML_ASSERT(dst_dd_i != nullptr); + + + const int64_t ne00 = src0->ne[0]; + + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t row_diff = row_high - row_low; + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + // the main device has a larger memory buffer to hold the results from all GPUs + // ldc == nrows of the matrix that cuBLAS writes into + int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; + + const int compute_capability = g_compute_capabilities[id]; + + if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { + // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 + half * src0_as_f16 = nullptr; + size_t src0_as = 0; + if (src0->type != GGML_TYPE_F16) { + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + size_t ne = row_diff*ne00; + src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); + to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream); + } + const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16; + + half * src1_as_f16 = nullptr; + size_t src1_as = 0; + if (src1->type != GGML_TYPE_F16) { + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + size_t ne = src1_ncols*ne10; + src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); + to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); + } + const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16; + + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); + CUBLAS_CHECK( + cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f16, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_f16, CUDA_R_16F, ldc, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); + + ggml_cuda_pool_free(dst_f16, dst_as); + + if (src0_as != 0) { + ggml_cuda_pool_free(src0_as_f16, src0_as); + } + + if (src1_as != 0) { + ggml_cuda_pool_free(src1_as_f16, src1_as); + } + } + else { + float * src0_ddq_as_f32 = nullptr; + size_t src0_as = 0; + + if (src0->type != GGML_TYPE_F32) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); + GGML_ASSERT(to_fp32_cuda != nullptr); + src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT + to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); + } + const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32; + + const float alpha = 1.0f; + const float beta = 0.0f; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); + CUBLAS_CHECK( + cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha, src0_ddf_i, ne00, + src1_ddf_i, ne10, + &beta, dst_dd_i, ldc)); + + if (src0_as != 0) { + ggml_cuda_pool_free(src0_ddq_as_f32, src0_as); + } + } + + (void) dst; + (void) src1_ddq_i; + (void) src1_padded_row_size; +} + +inline void ggml_cuda_op_rope( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t nrows = ggml_nrows(src0); + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + // RoPE alteration for extended context + + float freq_base, freq_scale; + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + const int32_t * pos = nullptr; + if ((mode & 1) == 0) { + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(src1->ne[0] == ne2); + pos = (const int32_t *) src1_dd; + } + + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + // compute + if (is_glm) { + GGML_ASSERT(false); + rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream); + } else if (is_neox) { + GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); + if (src0->type == GGML_TYPE_F32) { + rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + } else { + GGML_ASSERT(false); + } + } else { + if (src0->type == GGML_TYPE_F32) { + rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); + } else { + GGML_ASSERT(false); + } + } + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_alibi( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t nrows = ggml_nrows(src0); + + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); + + GGML_ASSERT(ne01 + n_past == ne00); + GGML_ASSERT(n_head == ne02); + + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + + alibi_f32_cuda(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream); + + (void) src1; + (void) src1_dd; +} + +inline void ggml_cuda_op_diag_mask_inf( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int nrows0 = ggml_nrows(src0); + + const int n_past = ((int32_t *) dst->op_params)[0]; + + diag_mask_inf_f32_cuda(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_soft_max( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_scale( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const float scale = ((float *) src1->data)[0]; + + scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream); + CUDA_CHECK(cudaGetLastError()); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) { + const int64_t nrows0 = ggml_nrows(src0); + + const bool use_src1 = src1 != nullptr; + const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; + + GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT); + + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; + struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + + const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; + const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU; + const bool dst_on_device = dst->backend == GGML_BACKEND_GPU; + + const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE; + + // dd = data device + float * src0_ddf = nullptr; + float * src1_ddf = nullptr; + float * dst_ddf = nullptr; + + // as = actual size + size_t src0_asf = 0; + size_t src1_asf = 0; + size_t dst_asf = 0; + + ggml_cuda_set_device(g_main_device); + const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + if (src0_on_device) { + src0_ddf = (float *) src0_extra->data_device[g_main_device]; + } else { + src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream)); + } + + if (use_src1 && !src1_stays_on_host) { + if (src1_on_device) { + src1_ddf = (float *) src1_extra->data_device[g_main_device]; + } else { + src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream)); + } + } + if (dst_on_device) { + dst_ddf = (float *) dst_extra->data_device[g_main_device]; + } else { + dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf); + } + + // do the computation + op(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); + CUDA_CHECK(cudaGetLastError()); + + // copy dst to host if necessary + if (!dst_on_device) { + CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream)); + } + + if (src0_asf > 0) { + ggml_cuda_pool_free(src0_ddf, src0_asf); + } + if (src1_asf > 0) { + ggml_cuda_pool_free(src1_ddf, src1_asf); + } + if (dst_asf > 0) { + ggml_cuda_pool_free(dst_ddf, dst_asf); + } + + if (dst->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(cudaDeviceSynchronize()); + } +} + +static void ggml_cuda_set_peer_access(const int n_tokens) { + static bool peer_access_enabled = false; + + const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE; + + if (peer_access_enabled == enable_peer_access) { + return; + } + +#ifdef NDEBUG + for (int id = 0; id < g_device_count; ++id) { + CUDA_CHECK(ggml_cuda_set_device(id)); + + for (int id_other = 0; id_other < g_device_count; ++id_other) { + if (id == id_other) { + continue; + } + if (id != g_main_device && id_other != g_main_device) { + continue; + } + + int can_access_peer; + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); + if (can_access_peer) { + if (enable_peer_access) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); + } else { + CUDA_CHECK(cudaDeviceDisablePeerAccess(id_other)); + } + } + } + } +#endif // NDEBUG + + peer_access_enabled = enable_peer_access; +} + +static void ggml_cuda_op_mul_mat( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, + const bool convert_src1_to_q8_1) { + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + const int64_t nrows0 = ggml_nrows(src0); + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + const int64_t nrows1 = ggml_nrows(src1); + + GGML_ASSERT(ne03 == ne13); + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + ggml_cuda_set_peer_access(ne11); + + GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT); + + GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0); + + const int64_t i02_divisor = ne12 / ne02; + + const size_t src0_ts = ggml_type_size(src0->type); + const size_t src0_bs = ggml_blck_size(src0->type); + const size_t q8_1_ts = sizeof(block_q8_1); + const size_t q8_1_bs = QK8_1; + + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + + const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; + const bool src0_is_contiguous = ggml_is_contiguous(src0); + + const bool src1_is_contiguous = ggml_is_contiguous(src1); + const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ? + ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; + + const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + GGML_ASSERT(!(split && ne02 > 1)); + GGML_ASSERT(!(split && ne03 > 1)); + GGML_ASSERT(!(split && ne02 < ne12)); + + // dd = data device + char * src0_dd[GGML_CUDA_MAX_DEVICES] = {nullptr}; + float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float + char * src1_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // q8_1 + float * dst_dd[GGML_CUDA_MAX_DEVICES] = {nullptr}; + + // as = actual size + size_t src0_as[GGML_CUDA_MAX_DEVICES] = {0}; + size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0}; + size_t src1_asq[GGML_CUDA_MAX_DEVICES] = {0}; + size_t dst_as[GGML_CUDA_MAX_DEVICES] = {0}; + + int64_t row_low[GGML_CUDA_MAX_DEVICES]; + int64_t row_high[GGML_CUDA_MAX_DEVICES]; + + for (int64_t id = 0; id < g_device_count; ++id) { + // by default, use all rows + row_low[id] = 0; + row_high[id] = ne01; + + // for multi GPU, get the row boundaries from tensor split + // and round to mul_mat_q tile sizes + if (split) { + const int64_t rounding = get_row_rounding(src0->type); + + if (id != 0) { + row_low[id] = ne01*g_tensor_split[id]; + row_low[id] -= row_low[id] % rounding; + } + + if (id != g_device_count - 1) { + row_high[id] = ne01*g_tensor_split[id + 1]; + row_high[id] -= row_high[id] % rounding; + } + } + } + + for (int64_t id = 0; id < g_device_count; ++id) { + if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { + continue; + } + + const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; + const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; + + ggml_cuda_set_device(id); + const cudaStream_t stream = g_cudaStreams[id][0]; + + if (src0_on_device && src0_is_contiguous) { + src0_dd[id] = (char *) src0_extra->data_device[id]; + } else { + const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); + src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]); + } + + if (src1_on_device && src1_is_contiguous) { + src1_ddf[id] = (float *) src1_extra->data_device[id]; + } else { + src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]); + } + + if (convert_src1_to_q8_1) { + src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]); + + if (split && src1_on_device && src1_is_contiguous) { + quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + } + } + + if (dst_on_device) { + dst_dd[id] = (float *) dst_extra->data_device[id]; + } else { + const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst); + dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]); + } + } + + // if multiple devices are used they need to wait for the main device + // here an event is recorded that signals that the main device has finished calculating the input data + if (split && g_device_count > 1) { + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0])); + } + + const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; + for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { + const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0; + const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; + + for (int64_t id = 0; id < g_device_count; ++id) { + if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { + continue; + } + + const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; + const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; + const int64_t row_diff = row_high[id] - row_low[id]; + + ggml_cuda_set_device(id); + const cudaStream_t stream = g_cudaStreams[id][is]; + + // wait for main GPU data if necessary + if (split && (id != g_main_device || is != 0)) { + CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0], 0)); + } + + for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) { + const int64_t i03 = i0 / ne12; + const int64_t i02 = i0 % ne12; + + const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; + + // for split tensors the data begins at i0 == i0_offset_low + char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs; + float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10; + char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset; + float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); + + // the main device memory buffer can be on VRAM scratch, with space for all partial results + // in that case an offset on dst_ddf_i is needed + if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) { + dst_dd_i += row_low[id]; // offset is 0 if no tensor split + } + + // copy src0, src1 to device if necessary + if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) { + if (id != g_main_device) { + if (convert_src1_to_q8_1) { + char * src1_ddq_i_source = src1_ddq[g_main_device] + src1_ddq_i_offset; + CUDA_CHECK(cudaMemcpyAsync(src1_ddq_i, src1_ddq_i_source, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, + cudaMemcpyDeviceToDevice, stream)); + } else { + float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device]; + src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10; + CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_ncols*ne10*sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + } + } else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) { + CUDA_CHECK(ggml_cuda_cpy_tensor_2d( + src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); + } else { + GGML_ASSERT(false); + } + + if (convert_src1_to_q8_1 && src1->backend == GGML_BACKEND_CPU) { + quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + } + + if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) { + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, row_low[id], row_high[id], stream)); + } + + // do the computation + op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, + row_low[id], row_high[id], src1_ncols, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + // copy dst to host or other device if necessary + if (!dst_on_device) { + void * dst_off_device; + cudaMemcpyKind kind; + if (dst->backend == GGML_BACKEND_CPU) { + dst_off_device = dst->data; + kind = cudaMemcpyDeviceToHost; + } else if (dst->backend == GGML_BACKEND_GPU) { + dst_off_device = dst_extra->data_device[g_main_device]; + kind = cudaMemcpyDeviceToDevice; + } else { + GGML_ASSERT(false); + } + if (split) { + // src0 = weight matrix is saved as a transposed matrix for better memory layout. + // dst is NOT transposed. + // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU. + // Instead they need to be copied to the correct slice in ne0 = dst row index. + // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results. + float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); + GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); + dhf_dst_i += src1_col_0*ne0 + row_low[id]; + CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float), + row_diff*sizeof(float), src1_ncols, kind, stream)); + } else { + float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); + GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); + dhf_dst_i += src1_col_0*ne0; + CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), kind, stream)); + } + } + + // add event for the main device to wait on until other device is done + if (split && (id != g_main_device || is != 0)) { + CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream)); + } + } + } + } + + for (int64_t id = 0; id < g_device_count; ++id) { + CUDA_CHECK(ggml_cuda_set_device(id)); + + // free buffers again when done + if (src0_as[id] > 0) { + ggml_cuda_pool_free(src0_dd[id], src0_as[id]); + } + if (src1_asf[id] > 0) { + ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]); + } + if (src1_asq[id] > 0) { + ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]); + } + if (dst_as[id] > 0) { + ggml_cuda_pool_free(dst_dd[id], dst_as[id]); + } + } + + // main device waits for all other devices to be finished + if (split && g_device_count > 1) { + int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; + is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS; + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + for (int64_t id = 0; id < g_device_count; ++id) { + for (int64_t is = 0; is < is_max; ++is) { + CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0)); + } + } + } + + if (dst->backend == GGML_BACKEND_CPU) { + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + CUDA_CHECK(cudaDeviceSynchronize()); + } +} + +static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add); +} + +static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul); +} + +static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu); +} + +static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu); +} + +static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm); +} + +static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm); +} + +bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); +} + +static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ + GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); + GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation + GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t ne12 = src1->ne[2]; + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + void * src0_ddq = src0_extra->data_device[g_main_device]; + + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); +} + +static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ + GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)); + GGML_ASSERT(!ggml_is_permuted(src0)); + GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t ne12 = src1->ne[2]; + + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + void * src0_ddq = src0_extra->data_device[g_main_device]; + + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + const int64_t row_stride_x = nb01 / sizeof(half); + const int64_t channel_stride_x = nb02 / sizeof(half); + + ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); +} + +static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && + src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; + + int64_t min_compute_capability = INT_MAX; + for (int64_t id = 0; id < g_device_count; ++id) { + if (min_compute_capability > g_compute_capabilities[id] + && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { + min_compute_capability = g_compute_capabilities[id]; + } + } + + if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + ggml_cuda_mul_mat_vec_p021(src0, src1, dst); + } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) { + ggml_cuda_mul_mat_vec_nc(src0, src1, dst); + }else if (src0->type == GGML_TYPE_F32) { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); + } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { + if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { + +#ifdef GGML_CUDA_FORCE_DMMV + const bool use_mul_mat_vec_q = false; +#else + const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type); +#endif // GGML_CUDA_FORCE_DMMV + + if (use_mul_mat_vec_q) { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); + } else { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); + } + } else { + if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true); + } else { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); + } + } + } else { + GGML_ASSERT(false); + } +} + +static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); +} + +static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const int64_t ne = ggml_nelements(src0); + GGML_ASSERT(ne == ggml_nelements(src1)); + + GGML_ASSERT(src0->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); + + GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); + GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + GGML_ASSERT(src0->ne[3] == 1); + + const int64_t nb00 = src0->nb[0]; + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + GGML_ASSERT(src1->ne[3] == 1); + + const int64_t nb10 = src1->nb[0]; + const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + + char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; + char * src1_ddc = (char *) src1_extra->data_device[g_main_device]; + + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, + ne10, ne11, nb10, nb11, nb12, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { + ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, + ne10, ne11, nb10, nb11, nb12, main_stream); + } else { + fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ASSERT(false); + } + + (void) dst; +} + +static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_cpy(src0, dst, nullptr); + (void) src1; +} + +static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf); +} + +static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max); +} + +static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope); +} + +static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); +} + +static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + (void) src0; + (void) src1; + (void) dst; +} + +void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { + const int64_t nrows = ggml_nrows(tensor); + + const int64_t ne0 = tensor->ne[0]; + + const size_t nb1 = tensor->nb[1]; + + ggml_backend backend = tensor->backend; + struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu; + memset(extra, 0, sizeof(*extra)); + + for (int64_t id = 0; id < g_device_count; ++id) { + if (backend == GGML_BACKEND_GPU && id != g_main_device) { + continue; + } + + ggml_cuda_set_device(id); + + int64_t row_low, row_high; + if (backend == GGML_BACKEND_GPU) { + row_low = 0; + row_high = nrows; + } else if (backend == GGML_BACKEND_GPU_SPLIT) { + const int64_t rounding = get_row_rounding(tensor->type); + + row_low = id == 0 ? 0 : nrows*g_tensor_split[id]; + row_low -= row_low % rounding; + + if (id == g_device_count - 1) { + row_high = nrows; + } else { + row_high = nrows*g_tensor_split[id + 1]; + row_high -= row_high % rounding; + } + } else { + GGML_ASSERT(false); + } + if (row_low == row_high) { + continue; + } + + int64_t nrows_split = row_high - row_low; + + const size_t offset_split = row_low*nb1; + size_t size = ggml_nbytes_split(tensor, nrows_split); + const size_t original_size = size; + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING) + * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type); + } + + char * buf; + CUDA_CHECK(cudaMalloc(&buf, size)); + char * buf_host = (char*)data + offset_split; + + // set padding to 0 to avoid possible NaN values + if (size > original_size) { + CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size)); + } + + + CUDA_CHECK(cudaMemcpy(buf, buf_host, original_size, cudaMemcpyHostToDevice)); + + extra->data_device[id] = buf; + + if (backend == GGML_BACKEND_GPU_SPLIT) { + for (int64_t is = 0; is < MAX_STREAMS; ++is) { + CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming)); + } + } + } + + tensor->extra = extra; +} + +void ggml_cuda_free_data(struct ggml_tensor * tensor) { + if (!tensor || (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) ) { + return; + } + + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + + for (int64_t id = 0; id < g_device_count; ++id) { + if (extra->data_device[id] != nullptr) { + CUDA_CHECK(ggml_cuda_set_device(id)); + CUDA_CHECK(cudaFree(extra->data_device[id])); + } + + for (int64_t is = 0; is < MAX_STREAMS; ++is) { + if (extra->events[id][is] != nullptr) { + CUDA_CHECK(ggml_cuda_set_device(id)); + CUDA_CHECK(cudaEventDestroy(extra->events[id][is])); + } + } + } + + delete extra; +} + +static struct ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr; +static size_t g_temp_tensor_extra_index = 0; + +static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { + if (g_temp_tensor_extras == nullptr) { + g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES]; + } + + size_t alloc_index = g_temp_tensor_extra_index; + g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_MAX_NODES; + struct ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index]; + memset(extra, 0, sizeof(*extra)); + + return extra; +} + +static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) { + if (scratch && g_scratch_size == 0) { + return; + } + + tensor->backend = GGML_BACKEND_GPU; + + // recursively assign CUDA buffers until a compute tensor is found + if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) { + const ggml_op src0_op = tensor->src[0]->op; + if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) { + ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc); + } + } + if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) { + ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc); + } + + if (scratch && no_alloc) { + return; + } + + struct ggml_tensor_extra_gpu * extra; + + const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) || + tensor->op == GGML_OP_VIEW || + force_inplace; + const size_t size = ggml_nbytes(tensor); + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) { + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; + char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; + size_t offset = 0; + if (tensor->op == GGML_OP_VIEW) { + memcpy(&offset, tensor->op_params, sizeof(size_t)); + } + extra = ggml_cuda_alloc_temp_tensor_extra(); + extra->data_device[g_main_device] = src0_ddc + offset; + } else if (tensor->op == GGML_OP_CPY) { + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra; + void * src1_ddv = src1_extra->data_device[g_main_device]; + extra = ggml_cuda_alloc_temp_tensor_extra(); + extra->data_device[g_main_device] = src1_ddv; + } else if (scratch) { + GGML_ASSERT(size <= g_scratch_size); + if (g_scratch_offset + size > g_scratch_size) { + g_scratch_offset = 0; + } + + char * data = (char *) g_scratch_buffer; + if (data == nullptr) { + CUDA_CHECK(cudaMalloc(&data, g_scratch_size)); + g_scratch_buffer = data; + } + extra = ggml_cuda_alloc_temp_tensor_extra(); + extra->data_device[g_main_device] = data + g_scratch_offset; + + g_scratch_offset += size; + + GGML_ASSERT(g_scratch_offset <= g_scratch_size); + } else { // allocate new buffers outside of scratch + void * data; + CUDA_CHECK(cudaMalloc(&data, size)); + CUDA_CHECK(cudaMemset(data, 0, size)); + extra = new ggml_tensor_extra_gpu; + memset(extra, 0, sizeof(*extra)); + extra->data_device[g_main_device] = data; + } + + tensor->extra = extra; +} + +void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) { + if (g_scratch_size == 0) { + return; + } + if (g_scratch_buffer == nullptr) { + ggml_cuda_set_device(g_main_device); + CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size)); + } + + struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra(); + + const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) || + tensor->op == GGML_OP_VIEW; + + if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) { + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; + char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; + size_t view_offset = 0; + if (tensor->op == GGML_OP_VIEW) { + memcpy(&view_offset, tensor->op_params, sizeof(size_t)); + } + extra->data_device[g_main_device] = src0_ddc + view_offset; + } else { + extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset; + } + + tensor->extra = extra; +} + +void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice)); +} + +void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, true, false, false); +} + +void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, true, false, true); +} + +void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, false, false, false); +} + +void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, false, true, false); +} + +void ggml_cuda_set_main_device(const int main_device) { + if (main_device >= g_device_count) { + fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n", + main_device, g_device_count, g_main_device); + return; + } + g_main_device = main_device; + if (g_device_count > 1) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device)); + fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name); + } +} + +void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) { + g_mul_mat_q = mul_mat_q; +} + +void ggml_cuda_set_scratch_size(const size_t scratch_size) { + // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously + // it still won't always work as expected, but it's better than nothing + if (scratch_size > g_scratch_size) { + ggml_cuda_free_scratch(); + } + g_scratch_size = std::max(g_scratch_size, scratch_size); +} + +void ggml_cuda_free_scratch() { + if (g_scratch_buffer == nullptr) { + return; + } + + CUDA_CHECK(cudaFree(g_scratch_buffer)); + g_scratch_buffer = nullptr; +} + +bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){ + ggml_cuda_func_t func; + const bool any_on_device = tensor->backend == GGML_BACKEND_GPU + || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) + || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU); + + switch (tensor->op) { + case GGML_OP_DUP: + if (!any_on_device) { + return false; + } + func = ggml_cuda_dup; + break; + case GGML_OP_ADD: + if (!any_on_device) { + return false; + } + func = ggml_cuda_add; + break; + case GGML_OP_MUL: + if (!any_on_device) { + return false; + } + func = ggml_cuda_mul; + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_GELU: + if (!any_on_device) { + return false; + } + func = ggml_cuda_gelu; + break; + case GGML_UNARY_OP_SILU: + if (!any_on_device) { + return false; + } + func = ggml_cuda_silu; + break; + default: + return false; + } break; + case GGML_OP_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cuda_norm; + break; + case GGML_OP_RMS_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cuda_rms_norm; + break; + case GGML_OP_MUL_MAT: + if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) { + return false; + } + func = ggml_cuda_mul_mat; + break; + case GGML_OP_SCALE: + if (!any_on_device) { + return false; + } + func = ggml_cuda_scale; + break; + case GGML_OP_CPY: + if (!any_on_device) { + return false; + } + func = ggml_cuda_cpy; + break; + case GGML_OP_CONT: + if (!any_on_device) { + return false; + } + func = ggml_cuda_dup; + break; + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + if (!any_on_device) { + return false; + } + func = ggml_cuda_nop; + break; + case GGML_OP_DIAG_MASK_INF: + if (!any_on_device) { + return false; + } + func = ggml_cuda_diag_mask_inf; + break; + case GGML_OP_SOFT_MAX: + if (!any_on_device) { + return false; + } + func = ggml_cuda_soft_max; + break; + case GGML_OP_ROPE: + if (!any_on_device) { + return false; + } + func = ggml_cuda_rope; + break; + case GGML_OP_ALIBI: + if (!any_on_device) { + return false; + } + func = ggml_cuda_alibi; + break; + default: + return false; + } + + if (params->ith != 0) { + return true; + } + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return true; + } + func(tensor->src[0], tensor->src[1], tensor); + return true; +} + +int ggml_cuda_get_device_count() { + int device_count; + CUDA_CHECK(cudaGetDeviceCount(&device_count)); + return device_count; +} + +void ggml_cuda_get_device_description(int device, char * description, size_t description_size) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); + snprintf(description, description_size, "%s", prop.name); +} diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h b/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h new file mode 100644 index 00000000..fda704b6 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h @@ -0,0 +1,47 @@ +#pragma once + +#include "ggml.h" + +#ifdef GGML_USE_HIPBLAS +#define GGML_CUDA_NAME "ROCm" +#define GGML_CUBLAS_NAME "hipBLAS" +#else +#define GGML_CUDA_NAME "CUDA" +#define GGML_CUBLAS_NAME "cuBLAS" +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_CUDA_MAX_DEVICES 16 + +GGML_API void ggml_init_cublas(void); +GGML_API void * ggml_cuda_host_malloc(size_t size); +GGML_API void ggml_cuda_host_free(void * ptr); + +GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); +GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split); +GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor); +GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor); + +GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor); +GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor); +GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor); + +GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor); +GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset); +GGML_API void ggml_cuda_copy_to_device(struct ggml_tensor * tensor); + +GGML_API void ggml_cuda_set_main_device(int main_device); +GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q); +GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size); +GGML_API void ggml_cuda_free_scratch(void); +GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor); + +GGML_API int ggml_cuda_get_device_count(void); +GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size); + +#ifdef __cplusplus +} +#endif From 99008e112b165a0d7f7f3a151319d819659efcbb Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 17 Oct 2023 12:07:28 +0900 Subject: [PATCH 174/623] [WASI-NN] Bump llama.cpp from b1309 to b1383 Signed-off-by: hydai --- .../wasi_nn/thirdparty/ggml/CMakeLists.txt | 44 +- plugins/wasi_nn/thirdparty/ggml/common.cpp | 255 +- plugins/wasi_nn/thirdparty/ggml/common.h | 64 +- plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c | 169 +- plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h | 16 +- .../wasi_nn/thirdparty/ggml/ggml-backend.c | 385 +++ .../wasi_nn/thirdparty/ggml/ggml-backend.h | 143 + plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu | 578 +++- plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h | 4 + plugins/wasi_nn/thirdparty/ggml/ggml-metal.h | 19 +- plugins/wasi_nn/thirdparty/ggml/ggml-metal.m | 489 ++- .../wasi_nn/thirdparty/ggml/ggml-metal.metal | 184 +- plugins/wasi_nn/thirdparty/ggml/ggml.c | 1172 ++++--- plugins/wasi_nn/thirdparty/ggml/ggml.h | 32 +- plugins/wasi_nn/thirdparty/ggml/k_quants.c | 746 ++++- plugins/wasi_nn/thirdparty/ggml/k_quants.h | 10 +- plugins/wasi_nn/thirdparty/ggml/llama.cpp | 2764 ++++++++++++++--- plugins/wasi_nn/thirdparty/ggml/llama.h | 13 +- plugins/wasi_nn/thirdparty/ggml/sampling.cpp | 166 + plugins/wasi_nn/thirdparty/ggml/sampling.h | 108 + plugins/wasi_nn/thirdparty/ggml/unicode.h | 462 +++ 21 files changed, 6424 insertions(+), 1399 deletions(-) create mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-backend.c create mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-backend.h create mode 100644 plugins/wasi_nn/thirdparty/ggml/sampling.cpp create mode 100644 plugins/wasi_nn/thirdparty/ggml/sampling.h create mode 100644 plugins/wasi_nn/thirdparty/ggml/unicode.h diff --git a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt index da9542b0..b2fa2ce9 100644 --- a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt +++ b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt @@ -24,6 +24,12 @@ option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) # instruction set specific +if (LLAMA_NATIVE) + set(INS_ENB OFF) +else() + set(INS_ENB ON) +endif() + option(LLAMA_AVX "llama: enable AVX" ON) option(LLAMA_AVX2 "llama: enable AVX2" ON) option(LLAMA_AVX512 "llama: enable AVX512" OFF) @@ -89,6 +95,20 @@ if (NOT MSVC) endif() endif() +if (APPLE AND LLAMA_ACCELERATE) + find_library(ACCELERATE_FRAMEWORK Accelerate) + if (ACCELERATE_FRAMEWORK) + message(STATUS "Accelerate framework found") + + add_compile_definitions(GGML_USE_ACCELERATE) + add_compile_definitions(ACCELERATE_NEW_LAPACK) + add_compile_definitions(ACCELERATE_LAPACK_ILP64) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) + else() + message(WARNING "Accelerate framework not found") + endif() +endif() + if (LLAMA_METAL) find_library(FOUNDATION_LIBRARY Foundation REQUIRED) find_library(METAL_FRAMEWORK Metal REQUIRED) @@ -335,8 +355,7 @@ endif() if (LLAMA_ALL_WARNINGS) if (NOT MSVC) set(warning_flags -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function) - set(c_flags -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int - -Werror=implicit-function-declaration) + set(c_flags -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int -Werror=implicit-function-declaration) set(cxx_flags -Wmissing-declarations -Wmissing-noreturn) set(host_cxx_flags "") @@ -368,7 +387,8 @@ if (LLAMA_ALL_WARNINGS) set(c_flags ${c_flags} ${warning_flags}) set(cxx_flags ${cxx_flags} ${warning_flags}) add_compile_options("$<$:${c_flags}>" - "$<$:${cxx_flags} ${host_cxx_flags}>") + "$<$:${cxx_flags}>" + "$<$:${host_cxx_flags}>") endif() @@ -423,9 +443,6 @@ if (NOT MSVC) if (LLAMA_GPROF) add_compile_options(-pg) endif() - if (LLAMA_NATIVE) - add_compile_options(-march=native) - endif() endif() if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") OR ("${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "arm64")) @@ -480,6 +497,9 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GE add_compile_options($<$:/arch:AVX>) endif() else() + if (LLAMA_NATIVE) + add_compile_options(-march=native) + endif() if (LLAMA_F16C) add_compile_options(-mf16c) endif() @@ -576,8 +596,12 @@ wasmedge_add_library(ggml OBJECT ggml.h ggml-alloc.c ggml-alloc.h + ggml-backend.c + ggml-backend.h common.cpp common.h + sampling.cpp + sampling.h ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA} ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL} ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} @@ -622,14 +646,6 @@ if (BUILD_SHARED_LIBS) endif() endif() -# global flags for ggml -if (NOT WIN32) - target_compile_options(ggml - PRIVATE - -DGGML_USE_K_QUANTS - ) -endif() - # disable warnings if (NOT WIN32) target_compile_options(ggml diff --git a/plugins/wasi_nn/thirdparty/ggml/common.cpp b/plugins/wasi_nn/thirdparty/ggml/common.cpp index 91bece39..79e645c9 100644 --- a/plugins/wasi_nn/thirdparty/ggml/common.cpp +++ b/plugins/wasi_nn/thirdparty/ggml/common.cpp @@ -106,6 +106,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { std::string arg; gpt_params default_params; const std::string arg_prefix = "--"; + llama_sampling_params & sparams = params.sampling_params; for (int i = 1; i < argc; i++) { arg = argv[i]; @@ -166,8 +167,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } + // store the external file name in params + params.prompt_file = argv[i]; std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); - if (params.prompt.back() == '\n') { + if (!params.prompt.empty() && params.prompt.back() == '\n') { params.prompt.pop_back(); } } else if (arg == "-n" || arg == "--n-predict") { @@ -181,7 +184,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - params.top_k = std::stoi(argv[i]); + sparams.top_k = std::stoi(argv[i]); } else if (arg == "-c" || arg == "--ctx-size") { if (++i >= argc) { invalid_param = true; @@ -213,73 +216,73 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - params.top_p = std::stof(argv[i]); + sparams.top_p = std::stof(argv[i]); } else if (arg == "--temp") { if (++i >= argc) { invalid_param = true; break; } - params.temp = std::stof(argv[i]); + sparams.temp = std::stof(argv[i]); } else if (arg == "--tfs") { if (++i >= argc) { invalid_param = true; break; } - params.tfs_z = std::stof(argv[i]); + sparams.tfs_z = std::stof(argv[i]); } else if (arg == "--typical") { if (++i >= argc) { invalid_param = true; break; } - params.typical_p = std::stof(argv[i]); + sparams.typical_p = std::stof(argv[i]); } else if (arg == "--repeat-last-n") { if (++i >= argc) { invalid_param = true; break; } - params.repeat_last_n = std::stoi(argv[i]); + sparams.repeat_last_n = std::stoi(argv[i]); } else if (arg == "--repeat-penalty") { if (++i >= argc) { invalid_param = true; break; } - params.repeat_penalty = std::stof(argv[i]); + sparams.repeat_penalty = std::stof(argv[i]); } else if (arg == "--frequency-penalty") { if (++i >= argc) { invalid_param = true; break; } - params.frequency_penalty = std::stof(argv[i]); + sparams.frequency_penalty = std::stof(argv[i]); } else if (arg == "--presence-penalty") { if (++i >= argc) { invalid_param = true; break; } - params.presence_penalty = std::stof(argv[i]); + sparams.presence_penalty = std::stof(argv[i]); } else if (arg == "--mirostat") { if (++i >= argc) { invalid_param = true; break; } - params.mirostat = std::stoi(argv[i]); + sparams.mirostat = std::stoi(argv[i]); } else if (arg == "--mirostat-lr") { if (++i >= argc) { invalid_param = true; break; } - params.mirostat_eta = std::stof(argv[i]); + sparams.mirostat_eta = std::stof(argv[i]); } else if (arg == "--mirostat-ent") { if (++i >= argc) { invalid_param = true; break; } - params.mirostat_tau = std::stof(argv[i]); + sparams.mirostat_tau = std::stof(argv[i]); } else if (arg == "--cfg-negative-prompt") { if (++i >= argc) { invalid_param = true; break; } - params.cfg_negative_prompt = argv[i]; + sparams.cfg_negative_prompt = argv[i]; } else if (arg == "--cfg-negative-prompt-file") { if (++i >= argc) { invalid_param = true; @@ -291,16 +294,16 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.cfg_negative_prompt)); - if (params.cfg_negative_prompt.back() == '\n') { - params.cfg_negative_prompt.pop_back(); + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(sparams.cfg_negative_prompt)); + if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') { + sparams.cfg_negative_prompt.pop_back(); } } else if (arg == "--cfg-scale") { if (++i >= argc) { invalid_param = true; break; } - params.cfg_scale = std::stof(argv[i]); + sparams.cfg_scale = std::stof(argv[i]); } else if (arg == "-b" || arg == "--batch-size") { if (++i >= argc) { invalid_param = true; @@ -360,7 +363,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - params.lora_adapter.push_back({argv[i], 1.0f}); + params.lora_adapter.push_back(std::make_tuple(argv[i], 1.0f)); params.use_mmap = false; } else if (arg == "--lora-scaled") { if (++i >= argc) { @@ -372,7 +375,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - params.lora_adapter.push_back({lora_adapter, std::stof(argv[i])}); + params.lora_adapter.push_back(std::make_tuple(lora_adapter, std::stof(argv[i]))); params.use_mmap = false; } else if (arg == "--lora-base") { if (++i >= argc) { @@ -380,6 +383,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.lora_base = argv[i]; + } else if (arg == "--mmproj") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.mmproj = argv[i]; + } else if (arg == "--image") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.image = argv[i]; } else if (arg == "-i" || arg == "--interactive") { params.interactive = true; } else if (arg == "--embedding") { @@ -509,7 +524,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "--ignore-eos") { params.ignore_eos = true; } else if (arg == "--no-penalize-nl") { - params.penalize_nl = false; + sparams.penalize_nl = false; } else if (arg == "-l" || arg == "--logit-bias") { if (++i >= argc) { invalid_param = true; @@ -521,7 +536,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { std::string value_str; try { if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { - params.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); + sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); } else { throw std::exception(); } @@ -615,12 +630,17 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { process_escapes(params.prompt); process_escapes(params.input_prefix); process_escapes(params.input_suffix); + for (auto & antiprompt : params.antiprompt) { + process_escapes(antiprompt); + } } return true; } void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { + const llama_sampling_params & sparams = params.sampling_params; + printf("usage: %s [options]\n", argv[0]); printf("\n"); printf("options:\n"); @@ -653,19 +673,19 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); - printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); - printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); - printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z); - printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p); - printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n); - printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty); - printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty); - printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty); + printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); + printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); + printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); + printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); + printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n); + printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty); + printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty); + printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty); printf(" --mirostat N use Mirostat sampling.\n"); printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); - printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat); - printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta); - printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau); + printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat); + printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)sparams.mirostat_eta); + printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)sparams.mirostat_tau); printf(" -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n"); printf(" modifies the likelihood of token appearing in the completion,\n"); printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); @@ -676,7 +696,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" negative prompt to use for guidance. (default: empty)\n"); printf(" --cfg-negative-prompt-file FNAME\n"); printf(" negative prompt file to use for guidance. (default: empty)\n"); - printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); + printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale); printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n"); printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n"); printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n"); @@ -684,7 +704,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --no-penalize-nl do not penalize newline token\n"); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); - printf(" --temp N temperature (default: %.1f)\n", (double)params.temp); + printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp); printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n"); printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); @@ -694,6 +714,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); + printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); + printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n"); if (llama_mlock_supported()) { printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); } @@ -834,7 +856,7 @@ std::tuple llama_init_from_gpt_par } if (params.ignore_eos) { - params.logit_bias[llama_token_eos(lctx)] = -INFINITY; + params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY; } { @@ -922,129 +944,10 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector & last_tokens, - std::vector & candidates, - int idx) { - const int n_ctx = llama_n_ctx(ctx); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - - const float temp = params.temp; - const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; - const float top_p = params.top_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; - const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; - const float repeat_penalty = params.repeat_penalty; - const float alpha_presence = params.presence_penalty; - const float alpha_frequency = params.frequency_penalty; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; - const bool penalize_nl = params.penalize_nl; - - llama_token id = 0; - - float * logits = llama_get_logits_ith(ctx, idx); - - // Apply params.logit_bias map - for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { - logits[it->first] += it->second; - } - - candidates.clear(); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } - - llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; - - if (ctx_guidance) { - llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale); - } - - // apply penalties - if (!last_tokens.empty()) { - const float nl_logit = logits[llama_token_nl(ctx)]; - const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx); - - llama_sample_repetition_penalty(ctx, &cur_p, - last_tokens.data() + last_tokens.size() - last_n_repeat, - last_n_repeat, repeat_penalty); - llama_sample_frequency_and_presence_penalties(ctx, &cur_p, - last_tokens.data() + last_tokens.size() - last_n_repeat, - last_n_repeat, alpha_frequency, alpha_presence); - - if (!penalize_nl) { - for (size_t idx = 0; idx < cur_p.size; idx++) { - if (cur_p.data[idx].id == llama_token_nl(ctx)) { - cur_p.data[idx].logit = nl_logit; - break; - } - } - } - } - - if (grammar != NULL) { - llama_sample_grammar(ctx, &cur_p, grammar); - } - - if (temp <= 0) { - // Greedy sampling - id = llama_sample_token_greedy(ctx, &cur_p); - } else { - if (mirostat == 1) { - static float mirostat_mu = 2.0f * mirostat_tau; - const int mirostat_m = 100; - llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); - } else if (mirostat == 2) { - static float mirostat_mu = 2.0f * mirostat_tau; - llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); - } else { - // Temperature sampling - llama_sample_top_k (ctx, &cur_p, top_k, 1); - llama_sample_tail_free (ctx, &cur_p, tfs_z, 1); - llama_sample_typical (ctx, &cur_p, typical_p, 1); - llama_sample_top_p (ctx, &cur_p, top_p, 1); - llama_sample_temp(ctx, &cur_p, temp); - - { - const int n_top = 10; - LOG("top %d candidates:\n", n_top); - - for (int i = 0; i < n_top; i++) { - const llama_token id = cur_p.data[i].id; - LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p); - } - } - - id = llama_sample_token(ctx, &cur_p); - - LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str()); - } - } - // printf("`%d`", candidates_p.size); - - if (grammar != NULL) { - llama_grammar_accept_token(ctx, grammar, id); - } - - return id; -} - // // YAML utils // @@ -1196,6 +1099,10 @@ std::string get_sortable_timestamp() { void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx, const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { + const llama_sampling_params & sparams = params.sampling_params; + + fprintf(stream, "build_commit: %s\n", BUILD_COMMIT); + fprintf(stream, "build_number: %d\n", BUILD_NUMBER); fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false"); fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false"); @@ -1240,21 +1147,21 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); - dump_string_yaml_multiline(stream, "cfg_negative_prompt", params.cfg_negative_prompt.c_str()); - fprintf(stream, "cfg_scale: %f # default: 1.0\n", params.cfg_scale); + dump_string_yaml_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str()); + fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale); fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); - fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty); + fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty); dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str()); fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - const auto logit_bias_eos = params.logit_bias.find(llama_token_eos(lctx)); - const bool ignore_eos = logit_bias_eos != params.logit_bias.end() && logit_bias_eos->second == -INFINITY; + const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(lctx)); + const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); dump_string_yaml_multiline(stream, "in_prefix", params.input_prefix.c_str()); @@ -1267,7 +1174,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); fprintf(stream, "logit_bias:\n"); - for (std::pair lb : params.logit_bias) { + for (std::pair lb : sparams.logit_bias) { if (ignore_eos && lb.first == logit_bias_eos->first) { continue; } @@ -1291,30 +1198,30 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false"); - fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat); - fprintf(stream, "mirostat_ent: %f # default: 5.0\n", params.mirostat_tau); - fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta); + fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); + fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); + fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str()); fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); - fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", params.n_probs); + fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); fprintf(stream, "no_mul_mat_q: %s # default: false\n", !params.mul_mat_q ? "true" : "false"); - fprintf(stream, "no_penalize_nl: %s # default: false\n", !params.penalize_nl ? "true" : "false"); + fprintf(stream, "no_penalize_nl: %s # default: false\n", !sparams.penalize_nl ? "true" : "false"); fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false"); fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); - fprintf(stream, "presence_penalty: %f # default: 0.0\n", params.presence_penalty); + fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty); dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str()); fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens); fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false"); - fprintf(stream, "repeat_penalty: %f # default: 1.1\n", params.repeat_penalty); + fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty); fprintf(stream, "reverse_prompt:\n"); for (std::string ap : params.antiprompt) { @@ -1332,15 +1239,15 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); - fprintf(stream, "temp: %f # default: 0.8\n", params.temp); + fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES); dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector); - fprintf(stream, "tfs: %f # default: 1.0\n", params.tfs_z); + fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); - fprintf(stream, "top_k: %d # default: 40\n", params.top_k); - fprintf(stream, "top_p: %f # default: 0.95\n", params.top_p); - fprintf(stream, "typical_p: %f # default: 1.0\n", params.typical_p); + fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); + fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); + fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); } diff --git a/plugins/wasi_nn/thirdparty/ggml/common.h b/plugins/wasi_nn/thirdparty/ggml/common.h index e095c56e..4305047d 100644 --- a/plugins/wasi_nn/thirdparty/ggml/common.h +++ b/plugins/wasi_nn/thirdparty/ggml/common.h @@ -4,6 +4,8 @@ #include "llama.h" +#include "sampling.h" + #define LOG_NO_FILE_LINE_FUNCTION #include "log.h" @@ -20,6 +22,11 @@ #define DIRECTORY_SEPARATOR '/' #endif // _WIN32 +#define BUILD_NUMBER 1383 +#define BUILD_COMMIT "Embedded in WasmEdge" +#define BUILD_COMPILER "Embedded in WasmEdge" +#define BUILD_TARGET "Embedded in WasmEdge" + #define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0) #define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0) @@ -49,36 +56,18 @@ struct gpt_params { int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t n_beams = 0; // if non-zero then use beam search of given width. float rope_freq_base = 0.0f; // RoPE base frequency float rope_freq_scale = 0.0f; // RoPE frequency scaling factor // sampling parameters - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - float temp = 0.80f; // 1.0 = disabled - float repeat_penalty = 1.10f; // 1.0 = disabled - int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float frequency_penalty = 0.00f; // 0.0 = disabled - float presence_penalty = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - - std::unordered_map logit_bias; // logit bias for specific tokens - - // Classifier-Free Guidance - // https://arxiv.org/abs/2306.17806 - std::string cfg_negative_prompt; // string to help guidance - float cfg_scale = 1.f; // How strong is guidance + struct llama_sampling_params sampling_params; std::string model = "models/7B/ggml-model-f16.gguf"; // model path std::string model_draft = ""; // draft model for speculative decoding std::string model_alias = "unknown"; // model alias std::string prompt = ""; + std::string prompt_file = ""; // store the external prompt file name std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user inputs with std::string input_suffix = ""; // string to suffix user inputs with @@ -114,13 +103,16 @@ struct gpt_params { bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens bool instruct = false; // instruction mode (used for Alpaca models) - bool penalize_nl = true; // consider newlines as a repeatable token bool logits_all = false; // return logits for all tokens in the batch bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory bool numa = false; // attempt optimizations that help on some NUMA systems bool verbose_prompt = false; // print prompt tokens before generation bool infill = false; // use infill mode + + // multimodal models (see examples/llava) + std::string mmproj = ""; // path to multimodal projector + std::string image = ""; // path to an image file }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params); @@ -179,36 +171,6 @@ std::string llama_detokenize_bpe( llama_context * ctx, const std::vector & tokens); -// -// Sampling utils -// - -// this is a common sampling function used across the examples for convenience -// it can serve as a starting point for implementing your own sampling function -// -// required: -// - ctx: context to use for sampling -// - params: sampling parameters -// -// optional: -// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL -// - grammar: grammar to use for sampling, ignore if NULL -// - last_tokens: needed for repetition penalty, ignore if empty -// - idx: sample from llama_get_logits_ith(ctx, idx) -// -// returns: -// - token: sampled token -// - candidates: vector of candidate tokens -// -llama_token llama_sample_token( - struct llama_context * ctx, - struct llama_context * ctx_guidance, - struct llama_grammar * grammar, - const struct gpt_params & params, - const std::vector & last_tokens, - std::vector & candidates, - int idx = 0); - // // YAML utils // diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c index 805759db..34eba3f8 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c @@ -1,4 +1,5 @@ #include "ggml-alloc.h" +#include "ggml-backend.h" #include "ggml.h" #include #include @@ -6,25 +7,6 @@ #include #include -#ifdef __has_include - #if __has_include() - #include - #if defined(_POSIX_MAPPED_FILES) - #include - #include - #endif - #endif -#endif - -#if defined(_WIN32) - #define WIN32_LEAN_AND_MEAN - #ifndef NOMINMAX - #define NOMINMAX - #endif - #include - #include -#endif - #define UNUSED(x) (void)(x) #define MAX(a, b) ((a) > (b) ? (a) : (b)) @@ -80,8 +62,9 @@ struct free_block { #define MAX_FREE_BLOCKS 256 struct ggml_allocr { + struct ggml_backend_buffer * buffer; + bool buffer_owned; void * data; - size_t size; size_t alignment; int n_free_blocks; struct free_block free_blocks[MAX_FREE_BLOCKS]; @@ -119,16 +102,9 @@ static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tens } #endif -static size_t ggml_allocr_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { - return ggml_nbytes(tensor); - - UNUSED(alloc); -} - // check if a tensor is allocated by this buffer static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) { - void * ptr = tensor->data; - return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size; + return tensor->buffer == alloc->buffer; } static bool ggml_is_view(struct ggml_tensor * t) { @@ -136,11 +112,10 @@ static bool ggml_is_view(struct ggml_tensor * t) { } void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { -#ifdef GGML_ALLOCATOR_DEBUG GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated -#endif - size_t size = ggml_allocr_get_alloc_size(alloc, tensor); + + size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor); size = aligned_offset(NULL, size, alloc->alignment); AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size); @@ -188,6 +163,8 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) tensor->data = addr; AT_PRINTF("%s: allocated data at %p\n", __func__, tensor->data); + tensor->buffer = alloc->buffer; + ggml_backend_buffer_init_tensor(alloc->buffer, tensor); #ifdef GGML_ALLOCATOR_DEBUG add_allocated_tensor(alloc, tensor); @@ -208,19 +185,21 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) // this is a very naive implementation, but for our case the number of free blocks should be very small static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { - void * ptr = tensor->data; - if (ggml_allocr_is_own(alloc, tensor) == false) { // the tensor was not allocated in this buffer // this can happen because the graph allocator will try to free weights and other tensors from different buffers // the easiest way to deal with this is just to ignore it + AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer); return; } - size_t size = ggml_allocr_get_alloc_size(alloc, tensor); + void * ptr = tensor->data; + + size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor); size = aligned_offset(NULL, size, alloc->alignment); AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks); - AT_PRINTF("%s: alloc->data = %p alloc->data+alloc->size = %p alloc->data+alloc->max_size = %p\n", __func__, alloc->data, (char*)alloc->data + alloc->size, (char*)alloc->data + alloc->max_size); + + ggml_backend_buffer_free_tensor(alloc->buffer, tensor); #ifdef GGML_ALLOCATOR_DEBUG remove_allocated_tensor(alloc, tensor); @@ -285,15 +264,18 @@ void ggml_allocr_reset(struct ggml_allocr * alloc) { alloc->n_free_blocks = 1; size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment); alloc->free_blocks[0].addr = (char *)alloc->data + align_offset; - alloc->free_blocks[0].size = alloc->size - align_offset; + alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset; } struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) { - struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */); + struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size); + + struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr)); *alloc = (struct ggml_allocr){ - /*.data = */ data, - /*.size = */ size, + /*.buffer = */ buffer, + /*.buffer_owned = */ true, + /*.base = */ ggml_backend_buffer_get_base(buffer), /*.alignment = */ alignment, /*.n_free_blocks = */ 0, /*.free_blocks = */ {{0}}, @@ -312,74 +294,26 @@ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) return alloc; } -// OS specific functions to allocate and free uncommitted virtual memory -static void * alloc_vmem(size_t size) { -#if defined(_WIN32) - return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS); -#elif defined(_POSIX_MAPPED_FILES) - void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0); - if (ptr == MAP_FAILED) { - return NULL; - } - return ptr; -#else - // use a fixed address for other platforms - uintptr_t base_addr = (uintptr_t)-size - 0x100; - return (void *)base_addr; -#endif -} - -static void free_vmem(void * base_addr, size_t size) { -#if defined(_WIN32) - VirtualFree(base_addr, 0, MEM_RELEASE); - UNUSED(size); -#elif defined(_POSIX_MAPPED_FILES) - munmap(base_addr, size); -#else - // nothing to do - UNUSED(base_addr); - UNUSED(size); -#endif -} - -// allocate uncommitted virtual memory to measure the size of the graph -static void alloc_measure_vmem(void ** base_addr, size_t * size) { - // 128GB for 64-bit, 1GB for 32-bit - *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37; - do { - *base_addr = alloc_vmem(*size); - if (*base_addr != NULL) { - AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr); - return; - } - // try again with half the size - *size /= 2; - } while (*size > 0); - - GGML_ASSERT(!"failed to allocate virtual memory for measure buffer"); -} - -static void free_measure_vmem(void * base_addr, size_t size) { - free_vmem(base_addr, size); -} - struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) { - struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */); + struct ggml_allocr * alloc = ggml_allocr_new((void *)0x1000, (size_t)-0x1001, alignment); + alloc->measure = true; - void * base_addr; - size_t size; + return alloc; +} - alloc_measure_vmem(&base_addr, &size); +struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) { + struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr)); *alloc = (struct ggml_allocr){ - /*.data = */ base_addr, - /*.size = */ size, - /*.alignment = */ alignment, + /*.buffer = */ buffer, + /*.buffer_owned = */ false, + /*.base = */ ggml_backend_buffer_get_base(buffer), + /*.alignment = */ ggml_backend_buffer_get_alignment(buffer), /*.n_free_blocks = */ 0, /*.free_blocks = */ {{0}}, /*.hash_table = */ {{0}}, /*.max_size = */ 0, - /*.measure = */ true, + /*.measure = */ false, /*.parse_seq = */ {0}, /*.parse_seq_len = */ 0, #ifdef GGML_ALLOCATOR_DEBUG @@ -393,8 +327,8 @@ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) { } void ggml_allocr_free(struct ggml_allocr * alloc) { - if (alloc->measure) { - free_measure_vmem(alloc->data, alloc->size); + if (alloc->buffer_owned) { + ggml_backend_buffer_free(alloc->buffer); } free(alloc); } @@ -437,7 +371,6 @@ static bool ggml_op_can_inplace(enum ggml_op op) { case GGML_OP_ROPE: case GGML_OP_RMS_NORM: case GGML_OP_SOFT_MAX: - case GGML_OP_CONT: return true; default: @@ -445,12 +378,23 @@ static bool ggml_op_can_inplace(enum ggml_op op) { } } +static void init_view(struct ggml_allocr * alloc, struct ggml_tensor * view) { + assert(view->view_src != NULL && view->view_src->data != NULL); + view->backend = view->view_src->backend; + view->buffer = view->view_src->buffer; + view->data = (char *)view->view_src->data + view->view_offs; + + // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend + // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras + assert(ggml_allocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend); + ggml_backend_buffer_init_tensor(alloc->buffer, view); +} + static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) { struct hash_node * ht = alloc->hash_table; if (node->data == NULL) { if (ggml_is_view(node)) { - assert(node->view_src->data != NULL); - node->data = (char *)node->view_src->data + node->view_offs; + init_view(alloc, node); } else { // see if we can reuse a parent's buffer (inplace) if (ggml_op_can_inplace(node->op)) { @@ -478,13 +422,17 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data) AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name); - node->data = parent->data; + node->view_src = view_src; + view_src_hn->n_views += 1; + init_view(alloc, node); return; } } else { AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name); - node->data = parent->data; + node->view_src = parent; + p_hn->n_views += 1; + init_view(alloc, node); return; } } @@ -495,7 +443,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) } } -static size_t ggml_allocr_alloc_graph_tensors_n( +size_t ggml_allocr_alloc_graph_n( struct ggml_allocr * alloc, struct ggml_cgraph ** graphs, int n_graphs, struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) { @@ -513,6 +461,10 @@ static size_t ggml_allocr_alloc_graph_tensors_n( if (ggml_is_view(node)) { struct ggml_tensor * view_src = node->view_src; hash_get(ht, view_src)->n_views += 1; + if (node->buffer == NULL && node->data != NULL) { + // view of a pre-allocated tensor, didn't call init_view() yet + init_view(alloc, node); + } } for (int j = 0; j < GGML_MAX_SRC; j++) { @@ -521,6 +473,9 @@ static size_t ggml_allocr_alloc_graph_tensors_n( break; } hash_get(ht, parent)->n_children += 1; + if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) { + init_view(alloc, parent); + } } } } @@ -631,7 +586,7 @@ static size_t ggml_allocr_alloc_graph_tensors_n( } size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) { - return ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL); + return ggml_allocr_alloc_graph_n(alloc, &graph, 1, NULL, NULL); } size_t ggml_allocr_max_size(struct ggml_allocr * alloc) { diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h index 0c224f17..e3875887 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h @@ -6,21 +6,27 @@ extern "C" { #endif +struct ggml_backend_buffer; GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment); GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment); +GGML_API struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer); // tell the allocator to parse nodes following the order described in the list // you should call this if your graph are optimized to execute out-of-order GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n); -GGML_API void ggml_allocr_free(struct ggml_allocr * alloc); -GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc); -GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc); -GGML_API void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor); +GGML_API void ggml_allocr_free (struct ggml_allocr * alloc); +GGML_API bool ggml_allocr_is_measure (struct ggml_allocr * alloc); +GGML_API void ggml_allocr_reset (struct ggml_allocr * alloc); +GGML_API void ggml_allocr_alloc (struct ggml_allocr * alloc, struct ggml_tensor * tensor); GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph); -GGML_API size_t ggml_allocr_max_size(struct ggml_allocr * alloc); +GGML_API size_t ggml_allocr_max_size (struct ggml_allocr * alloc); +GGML_API size_t ggml_allocr_alloc_graph_n( + struct ggml_allocr * alloc, + struct ggml_cgraph ** graphs, int n_graphs, + struct ggml_tensor *** inputs, struct ggml_tensor *** outputs); #ifdef __cplusplus } diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-backend.c b/plugins/wasi_nn/thirdparty/ggml/ggml-backend.c new file mode 100644 index 00000000..ca8d83da --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-backend.c @@ -0,0 +1,385 @@ +#include "ggml-backend.h" +#include "ggml-alloc.h" + +#include +#include +#include +#include +#include + +#define UNUSED GGML_UNUSED + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// backend buffer + +ggml_backend_buffer_t ggml_backend_buffer_init( + struct ggml_backend * backend, + struct ggml_backend_buffer_i iface, + ggml_backend_buffer_context_t context, + size_t size) { + ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer)); + + GGML_ASSERT(iface.get_base != NULL); + + (*buffer) = (struct ggml_backend_buffer) { + /* .interface = */ iface, + /* .backend = */ backend, + /* .context = */ context, + /* .size = */ size, + }; + + return buffer; +} + +void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { + if (buffer->iface.free_buffer != NULL) { + buffer->iface.free_buffer(buffer); + } + free(buffer); +} + +size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) { + return ggml_backend_get_alignment(buffer->backend); +} + +void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { + return buffer->iface.get_base(buffer); +} + +size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { + return buffer->size; +} + +size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + if (buffer->iface.get_alloc_size) { + return buffer->iface.get_alloc_size(buffer, tensor); + } + return ggml_nbytes(tensor); +} + +void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + if (buffer->iface.init_tensor) { + buffer->iface.init_tensor(buffer, tensor); + } +} + +void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + if (buffer->iface.free_tensor) { + buffer->iface.free_tensor(buffer, tensor); + } +} + +// backend + +ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) { + return tensor->buffer->backend; +} + +const char * ggml_backend_name(ggml_backend_t backend) { + return backend->iface.get_name(backend); +} + +void ggml_backend_free(ggml_backend_t backend) { + backend->iface.free(backend); +} + +ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) { + return backend->iface.alloc_buffer(backend, size); +} + +size_t ggml_backend_get_alignment(ggml_backend_t backend) { + return backend->iface.get_alignment(backend); +} + +void ggml_backend_tensor_set_async(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size); +} + +void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size); +} + +void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size); + ggml_get_backend(tensor)->iface.synchronize(ggml_get_backend(tensor)); +} + +void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size); + ggml_get_backend(tensor)->iface.synchronize(ggml_get_backend(tensor)); +} + +void ggml_backend_synchronize(ggml_backend_t backend) { + backend->iface.synchronize(backend); +} + +ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + return backend->iface.graph_plan_create(backend, cgraph); +} + +void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + backend->iface.graph_plan_free(backend, plan); +} + +void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + backend->iface.graph_plan_compute(backend, plan); +} + +void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + backend->iface.graph_compute(backend, cgraph); +} + +bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { + return backend->iface.supports_op(backend, op); +} + +// backend copy + +static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { + if (a->type != b->type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (a->ne[i] != b->ne[i]) { + return false; + } + if (a->nb[i] != b->nb[i]) { + return false; + } + } + return true; +} + +void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { + //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]); + //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]); + GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); + + // printf("cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src)); + + if (src == dst) { + return; + } + + // TODO: allow backends to support copy to/from same backend + + if (ggml_get_backend(dst)->iface.cpy_tensor_from != NULL) { + ggml_get_backend(dst)->iface.cpy_tensor_from(ggml_get_backend(dst)->context, src, dst); + } else if (ggml_get_backend(src)->iface.cpy_tensor_to != NULL) { + ggml_get_backend(src)->iface.cpy_tensor_to(ggml_get_backend(src)->context, src, dst); + } else { + // shouldn't be hit when copying from/to CPU + #ifndef NDEBUG + fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", ggml_backend_name(src->buffer->backend), ggml_backend_name(dst->buffer->backend)); + #endif + size_t nbytes = ggml_nbytes(src); + void * data = malloc(nbytes); + ggml_backend_tensor_get(src, data, 0, nbytes); + ggml_backend_tensor_set(dst, data, 0, nbytes); + free(data); + } +} + +// backend CPU + +struct ggml_backend_cpu_context { + int n_threads; + void * work_data; + size_t work_size; +}; + +static const char * ggml_backend_cpu_name(ggml_backend_t backend) { + return "CPU"; + + UNUSED(backend); +} + +static void ggml_backend_cpu_free(ggml_backend_t backend) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + free(cpu_ctx->work_data); + free(cpu_ctx); + free(backend); +} + +static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { + return (void *)buffer->context; +} + +static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { + free(buffer->context); + UNUSED(buffer); +} + +static struct ggml_backend_buffer_i cpu_backend_buffer_i = { + /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .init_tensor = */ NULL, // no initialization required + /* .free_tensor = */ NULL, // no cleanup required +}; + +// for buffers from ptr, free is not called +static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { + /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .init_tensor = */ NULL, + /* .free_tensor = */ NULL, +}; + +static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512 + +static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backend, size_t size) { + size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned + void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC? + + return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size); +} + +static size_t ggml_backend_cpu_get_alignment(ggml_backend_t backend) { + return TENSOR_ALIGNMENT; + UNUSED(backend); +} + +static void ggml_backend_cpu_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + memcpy((char *)tensor->data + offset, data, size); + + UNUSED(backend); +} + +static void ggml_backend_cpu_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + memcpy(data, (const char *)tensor->data + offset, size); + + UNUSED(backend); +} + +static void ggml_backend_cpu_synchronize(ggml_backend_t backend) { + UNUSED(backend); +} + +static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src)); + + UNUSED(backend); +} + +static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { + // for a backend such as CUDA that can queue async calls, it is ok to do this asynchronously, but it may not be the case for other backends + ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src)); + + UNUSED(backend); +} + +struct ggml_backend_plan_cpu { + struct ggml_cplan cplan; + struct ggml_cgraph cgraph; +}; + +static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + + struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu)); + + cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + cpu_plan->cgraph = *cgraph; + + if (cpu_plan->cplan.work_size > 0) { + cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); + } + + return cpu_plan; +} + +static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; + + free(cpu_plan->cplan.work_data); + free(cpu_plan); + + UNUSED(backend); +} + +static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; + + ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); + + UNUSED(backend); +} + +static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + + struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); + + if (cpu_ctx->work_size < cplan.work_size) { + // TODO: may be faster to free and use malloc to avoid the copy + cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size); + cpu_ctx->work_size = cplan.work_size; + } + + cplan.work_data = cpu_ctx->work_data; + + ggml_graph_compute(cgraph, &cplan); +} + +static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { + return true; + UNUSED(backend); + UNUSED(op); +} + +static struct ggml_backend_i cpu_backend_i = { + /* .get_name = */ ggml_backend_cpu_name, + /* .free = */ ggml_backend_cpu_free, + /* .alloc_buffer = */ ggml_backend_cpu_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_get_alignment, + /* .set_tensor_async = */ ggml_backend_cpu_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_cpu_get_tensor_async, + /* .synchronize = */ ggml_backend_cpu_synchronize, + /* .cpy_tensor_from = */ ggml_backend_cpu_cpy_tensor_from, + /* .cpy_tensor_to = */ ggml_backend_cpu_cpy_tensor_to, + /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, + /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free, + /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, + /* .graph_compute = */ ggml_backend_cpu_graph_compute, + /* .supports_op = */ ggml_backend_cpu_supports_op, +}; + +ggml_backend_t ggml_backend_cpu_init(void) { + struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context)); + + ctx->n_threads = GGML_DEFAULT_N_THREADS; + ctx->work_data = NULL; + ctx->work_size = 0; + + ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend)); + + *cpu_backend = (struct ggml_backend) { + /* .interface = */ cpu_backend_i, + /* .context = */ ctx + }; + return cpu_backend; +} + +bool ggml_backend_is_cpu(ggml_backend_t backend) { + return backend->iface.get_name == ggml_backend_cpu_name; +} + +void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->n_threads = n_threads; +} + +ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) { + return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size); +} diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-backend.h b/plugins/wasi_nn/thirdparty/ggml/ggml-backend.h new file mode 100644 index 00000000..da134b0d --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-backend.h @@ -0,0 +1,143 @@ +#pragma once + +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + struct ggml_backend; + struct ggml_backend_buffer; + + // type-erased backend-specific types / wrappers + typedef void * ggml_backend_context_t; + typedef void * ggml_backend_graph_plan_t; + typedef void * ggml_backend_buffer_context_t; + + // avoid accessing internals of these types + typedef struct ggml_backend * ggml_backend_t; + typedef struct ggml_backend_buffer * ggml_backend_buffer_t; + + // + // backend buffer + // + + struct ggml_backend_buffer_i { + void (*free_buffer) (ggml_backend_buffer_t buffer); + void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer + size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback + void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback + void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback + }; + + // TODO: hide behind API + struct ggml_backend_buffer { + struct ggml_backend_buffer_i iface; + + ggml_backend_t backend; + ggml_backend_buffer_context_t context; + + size_t size; + }; + + // backend buffer functions + GGML_API ggml_backend_buffer_t ggml_backend_buffer_init( + struct ggml_backend * backend, + struct ggml_backend_buffer_i iface, + ggml_backend_buffer_context_t context, + size_t size); + + GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); + GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + + // + // backend + // + + struct ggml_backend_i { + const char * (*get_name)(ggml_backend_t backend); + + void (*free)(ggml_backend_t backend); + + // buffer allocation + ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size); + + // get buffer alignment + size_t (*get_alignment)(ggml_backend_t backend); + + // tensor data access + // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize + void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*synchronize) (ggml_backend_t backend); + + // (optional) copy tensor between different backends, allow for single-copy tranfers + void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); + void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); + + // compute graph with a plan + ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph); + void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); + void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); + + // compute graph without a plan + void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph); + + // check if the backend supports an operation + bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op); + }; + + // TODO: hide behind API + struct ggml_backend { + struct ggml_backend_i iface; + + ggml_backend_context_t context; + }; + + // backend helper functions + GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor); + + GGML_API const char * ggml_backend_name(ggml_backend_t backend); + GGML_API void ggml_backend_free(ggml_backend_t backend); + + GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size); + + GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); + + GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + + GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + + GGML_API void ggml_backend_synchronize(ggml_backend_t backend); + + GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph); + + GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan); + GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan); + GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); + GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op); + + // tensor copy between different backends + GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); + + // + // CPU backend + // + + GGML_API ggml_backend_t ggml_backend_cpu_init(void); + + GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend); + + GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads); + + GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size); + +#ifdef __cplusplus +} +#endif diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu b/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu index 989c419c..654d3632 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu @@ -62,6 +62,7 @@ #define cudaMemcpyHostToDevice hipMemcpyHostToDevice #define cudaMemcpyKind hipMemcpyKind #define cudaMemset hipMemset +#define cudaMemsetAsync hipMemsetAsync #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize #define cudaSetDevice hipSetDevice #define cudaStreamCreateWithFlags hipStreamCreateWithFlags @@ -414,11 +415,13 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_SCALE_BLOCK_SIZE 256 +#define CUDA_CLAMP_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256 #define CUDA_ALIBI_BLOCK_SIZE 32 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 #define CUDA_QUANTIZE_BLOCK_SIZE 256 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 +#define CUDA_GET_ROWS_BLOCK_SIZE 256 // dmmv = dequantize_mul_mat_vec #ifndef GGML_CUDA_DMMV_X @@ -1574,6 +1577,34 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } +template +static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) { + const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2; + const int row = blockDim.y*blockIdx.y + threadIdx.y; + + if (col >= ncols) { + return; + } + + const int r = y[row]; + + // copy x[r*ncols + col] to dst[row*ncols + col] + const int xi = r*ncols + col; + const int di = row*ncols + col; + + const int ib = xi/qk; // block index + const int iqs = (xi%qk)/qr; // quant index + const int iybs = di - di%qk; // y block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + dfloat2 v; + dequantize_kernel(x, ib, iqs, v); + + dst[iybs + iqs + 0] = v.x; + dst[iybs + iqs + y_offset] = v.y; +} + template static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; @@ -4555,6 +4586,24 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale dst[i] = scale * x[i]; } +static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); +} + +template +static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) { + const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); + const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); + const dim3 block_nums(block_num_x, nrows, 1); + k_get_rows<<>>(x, y, dst, ncols); +} + static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; add_f32<<>>(x, y, dst, kx, ky); @@ -5436,6 +5485,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons scale_f32<<>>(x, dst, scale, k); } +static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE; + clamp_f32<<>>(x, dst, min, max, k); +} + template static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { @@ -5703,7 +5757,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( } else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) { GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1])); kind = cudaMemcpyDeviceToDevice; - struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; int id; CUDA_CHECK(cudaGetDevice(&id)); src_ptr = (char *) extra->data_device[id]; @@ -5739,6 +5793,107 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( } } +static void ggml_cuda_op_repeat( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) { + // guaranteed to be an integer due to the check in ggml_can_repeat + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + // TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + CUDA_CHECK(cudaMemcpyAsync( + (char *) dst_d + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0, + (const char *) src0_d + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01, + ne00*nb0, cudaMemcpyDeviceToDevice, stream)); + } + } + } + } + } + } + } + + (void) src1; + (void) src1_d; +} + +static void ggml_cuda_op_get_rows( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) { + + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + const int ncols = src0->ne[0]; + const int nrows = ggml_nelements(src1); + + const int32_t * src1_i32 = (const int32_t *) src1_d; + + switch (src0->type) { + case GGML_TYPE_F16: + get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream); + break; + case GGML_TYPE_F32: + get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream); + break; + case GGML_TYPE_Q4_0: + get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); + break; + case GGML_TYPE_Q4_1: + get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); + break; + case GGML_TYPE_Q5_0: + get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); + break; + case GGML_TYPE_Q5_1: + get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); + break; + case GGML_TYPE_Q8_0: + get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); + break; + default: + // TODO: k-quants + GGML_ASSERT(false); + break; + } +} + inline void ggml_cuda_op_add( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { @@ -6279,12 +6434,12 @@ inline void ggml_cuda_op_alibi( const int64_t ne02 = src0->ne[2]; const int64_t nrows = ggml_nrows(src0); - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - GGML_ASSERT(ne01 + n_past == ne00); + //GGML_ASSERT(ne01 + n_past == ne00); GGML_ASSERT(n_head == ne02); const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); @@ -6343,7 +6498,14 @@ inline void ggml_cuda_op_scale( GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - const float scale = ((float *) src1->data)[0]; + float scale; + // HACK: support for ggml backend interface + if (src1->backend == GGML_BACKEND_CPU) { + scale = ((float *) src1->data)[0]; + } else { + // TODO: pass pointer to kernel instead of copying to host + CUDA_CHECK(cudaMemcpy(&scale, src1->data, sizeof(float), cudaMemcpyDeviceToHost)); + } scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream); CUDA_CHECK(cudaGetLastError()); @@ -6353,6 +6515,24 @@ inline void ggml_cuda_op_scale( (void) src1_dd; } +inline void ggml_cuda_op_clamp( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const float min = ((float *) dst->op_params)[0]; + const float max = ((float *) dst->op_params)[1]; + + clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream); + CUDA_CHECK(cudaGetLastError()); + + (void) src1; + (void) dst; + (void) src1_dd; +} + static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) { const int64_t nrows0 = ggml_nrows(src0); @@ -6362,9 +6542,9 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT); - struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; - struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU; @@ -6505,9 +6685,9 @@ static void ggml_cuda_op_mul_mat( const size_t q8_1_ts = sizeof(block_q8_1); const size_t q8_1_bs = QK8_1; - struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; const bool src0_is_contiguous = ggml_is_contiguous(src0); @@ -6585,7 +6765,7 @@ static void ggml_cuda_op_mul_mat( if (convert_src1_to_q8_1) { src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]); - if (split && src1_on_device && src1_is_contiguous) { + if (src1_on_device && src1_is_contiguous) { quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); } @@ -6667,7 +6847,7 @@ static void ggml_cuda_op_mul_mat( GGML_ASSERT(false); } - if (convert_src1_to_q8_1 && src1->backend == GGML_BACKEND_CPU) { + if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_CPU || !src1_is_contiguous)) { quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); } @@ -6758,6 +6938,14 @@ static void ggml_cuda_op_mul_mat( } } +static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_repeat); +} + +static void ggml_cuda_get_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_get_rows); +} + static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add); } @@ -6812,13 +7000,13 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens CUDA_CHECK(ggml_cuda_set_device(g_main_device)); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; - struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; void * src0_ddq = src0_extra->data_device[g_main_device]; - struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; - struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); @@ -6843,13 +7031,13 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor CUDA_CHECK(ggml_cuda_set_device(g_main_device)); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; - struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; void * src0_ddq = src0_extra->data_device[g_main_device]; - struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; - struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; const int64_t row_stride_x = nb01 / sizeof(half); @@ -6870,11 +7058,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } } - if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { ggml_cuda_mul_mat_vec_p021(src0, src1, dst); } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) { ggml_cuda_mul_mat_vec_nc(src0, src1, dst); - }else if (src0->type == GGML_TYPE_F32) { + } else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { @@ -6906,6 +7094,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } +static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp); +} + static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -6935,8 +7127,8 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg CUDA_CHECK(ggml_cuda_set_device(g_main_device)); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; - const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + const ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; char * src1_ddc = (char *) src1_extra->data_device[g_main_device]; @@ -6991,8 +7183,8 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { const size_t nb1 = tensor->nb[1]; - ggml_backend backend = tensor->backend; - struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu; + ggml_backend_type backend = tensor->backend; + ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu; memset(extra, 0, sizeof(*extra)); for (int64_t id = 0; id < g_device_count; ++id) { @@ -7046,7 +7238,6 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size)); } - CUDA_CHECK(cudaMemcpy(buf, buf_host, original_size, cudaMemcpyHostToDevice)); extra->data_device[id] = buf; @@ -7085,17 +7276,17 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) { delete extra; } -static struct ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr; +static ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr; static size_t g_temp_tensor_extra_index = 0; -static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { +static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { if (g_temp_tensor_extras == nullptr) { g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES]; } size_t alloc_index = g_temp_tensor_extra_index; g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_MAX_NODES; - struct ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index]; + ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index]; memset(extra, 0, sizeof(*extra)); return extra; @@ -7123,7 +7314,7 @@ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scra return; } - struct ggml_tensor_extra_gpu * extra; + ggml_tensor_extra_gpu * extra; const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) || tensor->op == GGML_OP_VIEW || @@ -7132,7 +7323,7 @@ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scra CUDA_CHECK(ggml_cuda_set_device(g_main_device)); if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) { - struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; size_t offset = 0; if (tensor->op == GGML_OP_VIEW) { @@ -7141,7 +7332,7 @@ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scra extra = ggml_cuda_alloc_temp_tensor_extra(); extra->data_device[g_main_device] = src0_ddc + offset; } else if (tensor->op == GGML_OP_CPY) { - struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra; + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra; void * src1_ddv = src1_extra->data_device[g_main_device]; extra = ggml_cuda_alloc_temp_tensor_extra(); extra->data_device[g_main_device] = src1_ddv; @@ -7183,13 +7374,13 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size)); } - struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra(); + ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra(); const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) || tensor->op == GGML_OP_VIEW; if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) { - struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; size_t view_offset = 0; if (tensor->op == GGML_OP_VIEW) { @@ -7207,7 +7398,7 @@ void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) { GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); GGML_ASSERT(ggml_is_contiguous(tensor)); - struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; CUDA_CHECK(ggml_cuda_set_device(g_main_device)); CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice)); } @@ -7264,58 +7455,47 @@ void ggml_cuda_free_scratch() { g_scratch_buffer = nullptr; } -bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){ +bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { ggml_cuda_func_t func; const bool any_on_device = tensor->backend == GGML_BACKEND_GPU || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU); + if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) { + return false; + } + switch (tensor->op) { + case GGML_OP_REPEAT: + func = ggml_cuda_repeat; + break; + case GGML_OP_GET_ROWS: + func = ggml_cuda_get_rows; + break; case GGML_OP_DUP: - if (!any_on_device) { - return false; - } func = ggml_cuda_dup; break; case GGML_OP_ADD: - if (!any_on_device) { - return false; - } func = ggml_cuda_add; break; case GGML_OP_MUL: - if (!any_on_device) { - return false; - } func = ggml_cuda_mul; break; case GGML_OP_UNARY: switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_GELU: - if (!any_on_device) { - return false; - } func = ggml_cuda_gelu; break; case GGML_UNARY_OP_SILU: - if (!any_on_device) { - return false; - } func = ggml_cuda_silu; break; default: return false; } break; case GGML_OP_NORM: - if (!any_on_device) { - return false; - } func = ggml_cuda_norm; break; case GGML_OP_RMS_NORM: - if (!any_on_device) { - return false; - } func = ggml_cuda_rms_norm; break; case GGML_OP_MUL_MAT: @@ -7325,54 +7505,36 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ func = ggml_cuda_mul_mat; break; case GGML_OP_SCALE: - if (!any_on_device) { - return false; - } func = ggml_cuda_scale; break; - case GGML_OP_CPY: + case GGML_OP_CLAMP: if (!any_on_device) { return false; } + func = ggml_cuda_clamp; + break; + case GGML_OP_CPY: func = ggml_cuda_cpy; break; case GGML_OP_CONT: - if (!any_on_device) { - return false; - } func = ggml_cuda_dup; break; case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - if (!any_on_device) { - return false; - } func = ggml_cuda_nop; break; case GGML_OP_DIAG_MASK_INF: - if (!any_on_device) { - return false; - } func = ggml_cuda_diag_mask_inf; break; case GGML_OP_SOFT_MAX: - if (!any_on_device) { - return false; - } func = ggml_cuda_soft_max; break; case GGML_OP_ROPE: - if (!any_on_device) { - return false; - } func = ggml_cuda_rope; break; case GGML_OP_ALIBI: - if (!any_on_device) { - return false; - } func = ggml_cuda_alibi; break; default: @@ -7400,3 +7562,263 @@ void ggml_cuda_get_device_description(int device, char * description, size_t des CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); snprintf(description, description_size, "%s", prop.name); } + +//////////////////////////////////////////////////////////////////////////////// + +// backend interface + +#define UNUSED GGML_UNUSED + +struct ggml_backend_context_cuda { +}; + +static const char * ggml_backend_cuda_name(ggml_backend_t backend) { + return GGML_CUDA_NAME; + + UNUSED(backend); +} + +static void ggml_backend_cuda_free(ggml_backend_t backend) { + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + delete cuda_ctx; + delete backend; +} + +struct ggml_backend_buffer_context_cuda { + void * device; + + ggml_tensor_extra_gpu * temp_tensor_extras = nullptr; + size_t temp_tensor_extra_index = 0; + + ~ggml_backend_buffer_context_cuda() { + delete[] temp_tensor_extras; + } + + ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { + if (temp_tensor_extras == nullptr) { + temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES]; + } + + size_t alloc_index = temp_tensor_extra_index; + temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_MAX_NODES; + ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index]; + memset(extra, 0, sizeof(*extra)); + + return extra; + } +}; + +static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; + CUDA_CHECK(cudaFree(ctx->device)); + delete ctx; +} + +static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; + return ctx->device; +} + +static size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + int64_t row_low = 0; + int64_t row_high = ggml_nrows(tensor); + int64_t nrows_split = row_high - row_low; + + size_t size = ggml_nbytes_split(tensor, nrows_split); + + int64_t ne0 = tensor->ne[0]; + + if (ggml_is_quantized(tensor->type)) { + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING) + * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type); + } + } + + return size; + + UNUSED(buffer); +} + +static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; + + if (tensor->view_src != NULL && tensor->view_offs == 0) { + assert(tensor->view_src->buffer->backend == buffer->backend); + tensor->backend = tensor->view_src->backend; + tensor->extra = tensor->view_src->extra; + return; + } + + ggml_tensor_extra_gpu * extra = ctx->ggml_cuda_alloc_temp_tensor_extra(); + + extra->data_device[g_main_device] = tensor->data; + + tensor->backend = GGML_BACKEND_GPU; + tensor->extra = extra; + + if (ggml_is_quantized(tensor->type)) { + // initialize padding to 0 to avoid possible NaN values + int64_t row_low = 0; + int64_t row_high = ggml_nrows(tensor); + int64_t nrows_split = row_high - row_low; + + size_t original_size = ggml_nbytes_split(tensor, nrows_split); + size_t padded_size = ggml_backend_cuda_buffer_get_alloc_size(tensor->buffer, tensor); + + if (padded_size > original_size && tensor->view_src == nullptr) { + CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[g_main_device][0])); + } + } + + UNUSED(buffer); +} + +static struct ggml_backend_buffer_i cuda_backend_buffer_interface = { + /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer, + /* .get_base = */ ggml_backend_cuda_buffer_get_base, + /* .get_alloc_size = */ ggml_backend_cuda_buffer_get_alloc_size, + /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor, + /* .free_tensor = */ NULL, +}; + +static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) { + ggml_cuda_set_device(g_main_device); + + ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda; + CUDA_CHECK(cudaMalloc(&ctx->device, size)); + return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size); +} + +static size_t ggml_backend_cuda_get_alignment(ggml_backend_t backend) { + return 128; + UNUSED(backend); +} + +static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + + CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[g_main_device][0])); + + UNUSED(backend); +} + +static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + + CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0])); + + UNUSED(backend); +} + +static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { + CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0])); + + UNUSED(backend); +} + +static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, ggml_cgraph * cgraph) { + GGML_ASSERT(!"not implemented"); + + return nullptr; + + UNUSED(backend); + UNUSED(cgraph); +} + +static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(!"not implemented"); + + UNUSED(backend); + UNUSED(plan); +} + +static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + GGML_ASSERT(!"not implemented"); + + UNUSED(backend); + UNUSED(plan); +} + +static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_cuda_set_device(g_main_device); + + ggml_compute_params params = {}; + params.type = GGML_TASK_COMPUTE; + params.ith = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + assert(node->backend == GGML_BACKEND_GPU); + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] != nullptr) { + assert(node->src[j]->backend == GGML_BACKEND_GPU); + } + } + + bool ok = ggml_cuda_compute_forward(¶ms, node); + if (!ok) { + fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + +#if 0 + if (node->type == GGML_TYPE_F32) { + cudaDeviceSynchronize(); + std::vector tmp(ggml_nelements(node), 0.0f); + cudaMemcpy(tmp.data(), node->data, ggml_nelements(node)*sizeof(float), cudaMemcpyDeviceToHost); + printf("\n%s (%s) (%s %s) (%s %s): ", node->name, ggml_op_name(node->op), + ggml_type_name(node->src[0]->type), + node->src[1] ? ggml_type_name(node->src[1]->type) : "none", + node->src[0]->name, + node->src[1] ? node->src[1]->name : "none"); + double sum = 0.0; + double sq_sum = 0.0; + for (int i = 0; i < ggml_nelements(node); i++) { + printf("%f ", tmp[i]); + sum += tmp[i]; + sq_sum += tmp[i]*tmp[i]; + } + printf("\n"); + printf("sum: %f, ", sum); + printf("sq_sum: %f\n", sq_sum); + } +#endif + } + + UNUSED(backend); +} + +static ggml_backend_i cuda_backend_i = { + /* .get_name = */ ggml_backend_cuda_name, + /* .free = */ ggml_backend_cuda_free, + /* .alloc_buffer = */ ggml_backend_cuda_alloc_buffer, + /* .get_alignment = */ ggml_backend_cuda_get_alignment, + /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, + /* .synchronize = */ ggml_backend_cuda_synchronize, + /* .cpy_tensor_from = */ nullptr, + /* .cpy_tensor_to = */ nullptr, + /* .graph_plan_create = */ ggml_backend_cuda_graph_plan_create, + /* .graph_plan_free = */ ggml_backend_cuda_graph_plan_free, + /* .graph_plan_compute = */ ggml_backend_cuda_graph_plan_compute, + /* .graph_compute = */ ggml_backend_cuda_graph_compute, + /* .supports_op = */ nullptr, +}; + +ggml_backend_t ggml_backend_cuda_init() { + ggml_init_cublas(); // TODO: remove from ggml.c + + ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda; + + ggml_backend_t cuda_backend = new ggml_backend { + /* .interface = */ cuda_backend_i, + /* .context = */ ctx + }; + + return cuda_backend; +} diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h b/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h index fda704b6..57adc9cf 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h @@ -1,6 +1,7 @@ #pragma once #include "ggml.h" +#include "ggml-backend.h" #ifdef GGML_USE_HIPBLAS #define GGML_CUDA_NAME "ROCm" @@ -42,6 +43,9 @@ GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, s GGML_API int ggml_cuda_get_device_count(void); GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size); +// backend API +GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use + #ifdef __cplusplus } #endif diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h index 790cf0bf..096b844e 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h @@ -20,6 +20,7 @@ #pragma once #include "ggml.h" +#include "ggml-backend.h" #include #include @@ -35,10 +36,15 @@ struct ggml_cgraph; extern "C" { #endif -void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data); +// +// internal API +// temporary exposed to user-code +// struct ggml_metal_context; +void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data); + // number of command buffers to use struct ggml_metal_context * ggml_metal_init(int n_cb); void ggml_metal_free(struct ggml_metal_context * ctx); @@ -83,6 +89,17 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx); // creates gf->n_threads command buffers in parallel void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); +// +// backend API +// user-code should use only these functions +// + +GGML_API ggml_backend_t ggml_backend_metal_init(void); + +GGML_API bool ggml_backend_is_metal(ggml_backend_t backend); + +GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb); + #ifdef __cplusplus } #endif diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m index b3c463f0..87fa1721 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m @@ -81,18 +81,18 @@ GGML_METAL_DECL_KERNEL(get_rows_q6_K); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); - GGML_METAL_DECL_KERNEL(mul_mat_f32_f32); - GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); - GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row); - GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4); - GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_f32_f32); + GGML_METAL_DECL_KERNEL(mul_mv_f16_f32); + GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row); + GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); + GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_f32_f32); GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); @@ -109,6 +109,8 @@ GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); GGML_METAL_DECL_KERNEL(cpy_f16_f16); + GGML_METAL_DECL_KERNEL(concat); + GGML_METAL_DECL_KERNEL(sqr); #undef GGML_METAL_DECL_KERNEL }; @@ -183,56 +185,44 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); -#ifdef GGML_SWIFT - // load the default.metallib file + // load library { - NSError * error = nil; - - NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; - NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"]; - NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath]; - NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"]; - NSURL * libURL = [NSURL fileURLWithPath:libPath]; - - // Load the metallib file into a Metal library - ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; - - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - } + NSBundle * bundle = nil; +#ifdef SWIFT_PACKAGE + bundle = SWIFTPM_MODULE_BUNDLE; #else - UNUSED(msl_library_source); - - // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource - { + bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; +#endif NSError * error = nil; + NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"]; + if (libPath != nil) { + NSURL * libURL = [NSURL fileURLWithPath:libPath]; + GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); + ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + } else { + GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); + + NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]); + NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error]; + if (error) { + GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } - //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; - NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; - NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; - GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]); - - NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - + MTLCompileOptions* options = nil; #ifdef GGML_QKK_64 - MTLCompileOptions* options = [MTLCompileOptions new]; - options.preprocessorMacros = @{ @"QK_K" : @(64) }; - ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; -#else - ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error]; + options = [MTLCompileOptions new]; + options.preprocessorMacros = @{ @"QK_K" : @(64) }; #endif + ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + } + if (error) { GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; } } -#endif // load kernels { @@ -272,40 +262,57 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ GGML_METAL_ADD_KERNEL(get_rows_q6_K); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); - GGML_METAL_ADD_KERNEL(mul_mat_f32_f32); - GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); - GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row); - GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4); - GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); - GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_f32_f32); + GGML_METAL_ADD_KERNEL(mul_mv_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row); + GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); + GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32); + if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { + GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); + GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); + } GGML_METAL_ADD_KERNEL(rope_f32); GGML_METAL_ADD_KERNEL(rope_f16); GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); GGML_METAL_ADD_KERNEL(cpy_f16_f16); + GGML_METAL_ADD_KERNEL(concat); + GGML_METAL_ADD_KERNEL(sqr); #undef GGML_METAL_ADD_KERNEL } - GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); #if TARGET_OS_OSX + // print MTL GPU family: + GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]); + + // determine max supported GPU family + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf + // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf + for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { + if ([ctx->device supportsFamily:i]) { + GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i); + break; + } + } + + GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); if (ctx->device.maxTransferRate != 0) { GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); @@ -347,34 +354,38 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(get_rows_q6_K); GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(norm); - GGML_METAL_DEL_KERNEL(mul_mat_f32_f32); - GGML_METAL_DEL_KERNEL(mul_mat_f16_f32); - GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row); - GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4); - GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32); - GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); - GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_f32_f32); + GGML_METAL_DEL_KERNEL(mul_mv_f16_f32); + GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row); + GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); + GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32); + if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { + GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); + GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); + } GGML_METAL_DEL_KERNEL(rope_f32); GGML_METAL_DEL_KERNEL(rope_f16); GGML_METAL_DEL_KERNEL(alibi_f32); GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f32); GGML_METAL_DEL_KERNEL(cpy_f16_f16); + GGML_METAL_DEL_KERNEL(concat); + GGML_METAL_DEL_KERNEL(sqr); #undef GGML_METAL_DEL_KERNEL @@ -431,7 +442,7 @@ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) { for (int i = 0; i < ctx->n_buffers; ++i) { const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; - //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name); + //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name); if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) { *offs = (size_t) ioffs; @@ -766,6 +777,44 @@ void ggml_metal_graph_compute( { // noop } break; + case GGML_OP_CONCAT: + { + const int64_t nb = ne00; + + [encoder setComputePipelineState:ctx->pipeline_concat]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_ADD: { GGML_ASSERT(ggml_is_contiguous(src0)); @@ -861,9 +910,10 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - const int64_t n = ggml_nelements(dst)/4; + const int64_t n = ggml_nelements(dst); + GGML_ASSERT(n % 4 == 0); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(gf->nodes[i])) { @@ -873,9 +923,10 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst)/4; + const int64_t n = ggml_nelements(dst); + GGML_ASSERT(n % 4 == 0); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_UNARY_OP_RELU: { @@ -893,9 +944,10 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst)/4; + const int64_t n = ggml_nelements(dst); + GGML_ASSERT(n % 4 == 0); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; default: { @@ -903,6 +955,17 @@ void ggml_metal_graph_compute( GGML_ASSERT(false); } } break; + case GGML_OP_SQR: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + [encoder setComputePipelineState:ctx->pipeline_sqr]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_SOFT_MAX: { const int nth = MIN(32, ne00); @@ -944,21 +1007,46 @@ void ggml_metal_graph_compute( } break; case GGML_OP_MUL_MAT: { - // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224 - GGML_ASSERT(ne00 == ne10); - // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere - uint gqa = ne12/ne02; GGML_ASSERT(ne03 == ne13); + const uint gqa = ne12/ne02; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + int ne11_mm_min = 1; + +#if 0 + // the numbers below are measured on M2 Ultra for 7B and 13B models + // these numbers do not translate to other devices or model sizes + // TODO: need to find a better approach + if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { + switch (src0t) { + case GGML_TYPE_F16: ne11_mm_min = 2; break; + case GGML_TYPE_Q8_0: ne11_mm_min = 7; break; + case GGML_TYPE_Q2_K: ne11_mm_min = 15; break; + case GGML_TYPE_Q3_K: ne11_mm_min = 7; break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: ne11_mm_min = 15; break; + case GGML_TYPE_Q4_K: ne11_mm_min = 11; break; + case GGML_TYPE_Q5_0: // not tested yet + case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet + case GGML_TYPE_Q5_K: ne11_mm_min = 7; break; + case GGML_TYPE_Q6_K: ne11_mm_min = 7; break; + default: ne11_mm_min = 1; break; + } + } +#endif + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if (!ggml_is_transposed(src0) && + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1t == GGML_TYPE_F32 && - [ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne00%32 == 0 && - ne11 > 2) { + ne00 % 32 == 0 && ne00 >= 64 && + ne11 > ne11_mm_min) { + //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); switch (src0->type) { case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; @@ -987,17 +1075,18 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { int nth0 = 32; int nth1 = 1; int nrows = 1; + //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // use custom matrix x vector kernel switch (src0t) { case GGML_TYPE_F32: { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32]; nrows = 4; } break; case GGML_TYPE_F16: @@ -1005,12 +1094,12 @@ void ggml_metal_graph_compute( nth0 = 32; nth1 = 1; if (ne11 * ne12 < 4) { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4]; nrows = ne11; } else { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; nrows = 4; } } break; @@ -1021,7 +1110,7 @@ void ggml_metal_graph_compute( nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32]; } break; case GGML_TYPE_Q4_1: { @@ -1030,7 +1119,7 @@ void ggml_metal_graph_compute( nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; } break; case GGML_TYPE_Q8_0: { @@ -1039,7 +1128,7 @@ void ggml_metal_graph_compute( nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32]; } break; case GGML_TYPE_Q2_K: { @@ -1048,7 +1137,7 @@ void ggml_metal_graph_compute( nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32]; } break; case GGML_TYPE_Q3_K: { @@ -1057,7 +1146,7 @@ void ggml_metal_graph_compute( nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32]; } break; case GGML_TYPE_Q4_K: { @@ -1066,7 +1155,7 @@ void ggml_metal_graph_compute( nth0 = 4; //1; nth1 = 8; //32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32]; } break; case GGML_TYPE_Q5_K: { @@ -1075,7 +1164,7 @@ void ggml_metal_graph_compute( nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32]; } break; case GGML_TYPE_Q6_K: { @@ -1084,7 +1173,7 @@ void ggml_metal_graph_compute( nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32]; } break; default: { @@ -1113,7 +1202,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) { + src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q4_K) { @@ -1166,6 +1255,8 @@ void ggml_metal_graph_compute( } break; case GGML_OP_RMS_NORM: { + GGML_ASSERT(ne00 % 4 == 0); + float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -1208,17 +1299,14 @@ void ggml_metal_graph_compute( const int nth = MIN(1024, ne00); - const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - if (__builtin_popcount(n_head) != 1) { - GGML_ASSERT(false && "only power-of-two n_head implemented"); - } - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); [encoder setComputePipelineState:ctx->pipeline_alibi_f32]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1239,7 +1327,9 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; + [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; + [encoder setBytes:&m1 length:sizeof( float) atIndex:19]; + [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -1372,3 +1462,140 @@ void ggml_metal_graph_compute( } } + +//////////////////////////////////////////////////////////////////////////////// + +// backend interface + +static const char * ggml_backend_metal_name(ggml_backend_t backend) { + return "Metal"; + + UNUSED(backend); +} + +static void ggml_backend_metal_free(ggml_backend_t backend) { + struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context; + ggml_metal_free(ctx); + free(backend); +} + +static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) { + return (void *)buffer->context; +} + +static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) { + free(buffer->context); + UNUSED(buffer); +} + +static struct ggml_backend_buffer_i metal_backend_buffer_i = { + /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_get_base, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .init_tensor = */ NULL, // no initialization required + /* .free_tensor = */ NULL, // no cleanup required +}; + +static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) { + struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context; + + void * data = ggml_metal_host_malloc(size); + + // TODO: set proper name of the buffers + ggml_metal_add_buffer(ctx, "backend", data, size, 0); + + return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size); +} + +static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) { + return 32; + UNUSED(backend); +} + +static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + memcpy((char *)tensor->data + offset, data, size); + + UNUSED(backend); +} + +static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + memcpy(data, (const char *)tensor->data + offset, size); + + UNUSED(backend); +} + +static void ggml_backend_metal_synchronize(ggml_backend_t backend) { + UNUSED(backend); +} + +static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src)); + + UNUSED(backend); +} + +static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src)); + + UNUSED(backend); +} + +static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context; + + ggml_metal_graph_compute(metal_ctx, cgraph); +} + +static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { + return true; + UNUSED(backend); + UNUSED(op); +} + +static struct ggml_backend_i metal_backend_i = { + /* .get_name = */ ggml_backend_metal_name, + /* .free = */ ggml_backend_metal_free, + /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_get_alignment, + /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, + /* .synchronize = */ ggml_backend_metal_synchronize, + /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from, + /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to, + /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm + /* .graph_plan_free = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_metal_graph_compute, + /* .supports_op = */ ggml_backend_metal_supports_op, +}; + +ggml_backend_t ggml_backend_metal_init(void) { + struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context)); + + ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS); + + ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend)); + + *metal_backend = (struct ggml_backend) { + /* .interface = */ metal_backend_i, + /* .context = */ ctx, + }; + + return metal_backend; +} + +bool ggml_backend_is_metal(ggml_backend_t backend) { + return backend->iface.get_name == ggml_backend_metal_name; +} + +void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { + struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context; + + ggml_metal_set_n_cb(ctx, n_cb); +} diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal index 5e1af6a0..99b9fd7a 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal +++ b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal @@ -13,8 +13,8 @@ typedef struct { #define QK4_1 32 typedef struct { - half d; // delta - half m; // min + half d; // delta + half m; // min uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; @@ -132,6 +132,13 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } +kernel void kernel_sqr( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + constant float GELU_COEF_A = 0.044715f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; @@ -338,10 +345,11 @@ kernel void kernel_rms_norm( uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); - device const float * x_scalar = (device const float *) x; - float4 sumf=0; - float all_sum=0; + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + device const float * x_scalar = (device const float *) x; + + float4 sumf = 0; + float all_sum = 0; // parallel sum for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { @@ -354,6 +362,7 @@ kernel void kernel_rms_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast, simd group number is ntg / 32 for (uint i = ntg / 32 / 2; i > 0; i /= 2) { if (tpitg < i) { @@ -361,7 +370,9 @@ kernel void kernel_rms_norm( } } if (tpitg == 0) { - for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];} + for (int i = 4 * (ne00 / 4); i < ne00; i++) { + sum[0] += x_scalar[i]; + } sum[0] /= ne00; } @@ -376,7 +387,9 @@ kernel void kernel_rms_norm( y[i00] = x[i00] * scale; } if (tpitg == 0) { - for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;} + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { + y_scalar[i00] = x_scalar[i00] * scale; + } } } @@ -416,8 +429,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre } // putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 //Note: This is a template, but strictly speaking it only applies to // quantizations where the block size is 32. It also does not @@ -428,18 +441,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; + const int first_row = (r0 * nsg + sgitg) * nr; + const uint offset0 = first_row * nb + im/gqa*(nb*ne0); + device const block_q_type * x = (device const block_q_type *) src0 + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[16]; // src1 vector cache - float sumf[nr]={0.f}; - const int ix = tiisg/2; - const int il = 8*(tiisg%2); + float yl[16]; // src1 vector cache + float sumf[nr] = {0.f}; + + const int ix = (tiisg/2); + const int il = (tiisg%2)*8; device const float * yb = y + ix * QK4_0 + il; @@ -450,6 +468,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device sumy += yb[i] + yb[i+1]; yl[i+0] = yb[i+ 0]; yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; yl[i+8] = yb[i+16]/16.f; yl[i+9] = yb[i+17]/4096.f; @@ -465,12 +484,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; } } } -kernel void kernel_mul_mat_q4_0_f32( +kernel void kernel_mul_mv_q4_0_f32( device const void * src0, device const float * src1, device float * dst, @@ -483,12 +502,12 @@ kernel void kernel_mul_mat_q4_0_f32( constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } -kernel void kernel_mul_mat_q4_1_f32( +kernel void kernel_mul_mv_q4_1_f32( device const void * src0, device const float * src1, device float * dst, @@ -508,7 +527,7 @@ kernel void kernel_mul_mat_q4_1_f32( #define NB_Q8_0 8 -kernel void kernel_mul_mat_q8_0_f32( +kernel void kernel_mul_mv_q8_0_f32( device const void * src0, device const float * src1, device float * dst, @@ -572,7 +591,7 @@ kernel void kernel_mul_mat_q8_0_f32( #define N_F32_F32 4 -kernel void kernel_mul_mat_f32_f32( +kernel void kernel_mul_mv_f32_f32( device const char * src0, device const char * src1, device float * dst, @@ -643,7 +662,7 @@ kernel void kernel_mul_mat_f32_f32( } } -kernel void kernel_mul_mat_f16_f32_1row( +kernel void kernel_mul_mv_f16_f32_1row( device const char * src0, device const char * src1, device float * dst, @@ -662,7 +681,7 @@ kernel void kernel_mul_mat_f16_f32_1row( constant int64_t & ne0, constant int64_t & ne1, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -697,7 +716,7 @@ kernel void kernel_mul_mat_f16_f32_1row( #define N_F16_F32 4 -kernel void kernel_mul_mat_f16_f32( +kernel void kernel_mul_mv_f16_f32( device const char * src0, device const char * src1, device float * dst, @@ -769,7 +788,7 @@ kernel void kernel_mul_mat_f16_f32( } // Assumes row size (ne00) is a multiple of 4 -kernel void kernel_mul_mat_f16_f32_l4( +kernel void kernel_mul_mv_f16_f32_l4( device const char * src0, device const char * src1, device float * dst, @@ -830,7 +849,9 @@ kernel void kernel_alibi_f32( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, - constant float & m0, + constant float & m0, + constant float & m1, + constant int & n_heads_log2_floor, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -846,7 +867,12 @@ kernel void kernel_alibi_f32( const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - float m_k = pow(m0, i2 + 1); + float m_k; + if (i2 < n_heads_log2_floor) { + m_k = pow(m0, i2 + 1); + } else { + m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1); + } for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1); @@ -1091,6 +1117,62 @@ kernel void kernel_cpy_f32_f32( } } +kernel void kernel_concat( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i02 < ne02) { + ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; + src0_ptr += ntg.x*nb00; + } else { + ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; + src1_ptr += ntg.x*nb10; + } + dst_ptr += ntg.x*nb0; + } +} + //============================================ k-quants ====================================================== #ifndef QK_K @@ -1183,7 +1265,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { //====================================== dot products ========================= -kernel void kernel_mul_mat_q2_K_f32( +kernel void kernel_mul_mv_q2_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1327,7 +1409,7 @@ kernel void kernel_mul_mat_q2_K_f32( } #if QK_K == 256 -kernel void kernel_mul_mat_q3_K_f32( +kernel void kernel_mul_mv_q3_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1479,7 +1561,7 @@ kernel void kernel_mul_mat_q3_K_f32( } } #else -kernel void kernel_mul_mat_q3_K_f32( +kernel void kernel_mul_mv_q3_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1550,7 +1632,7 @@ kernel void kernel_mul_mat_q3_K_f32( #endif #if QK_K == 256 -kernel void kernel_mul_mat_q4_K_f32( +kernel void kernel_mul_mv_q4_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1656,7 +1738,7 @@ kernel void kernel_mul_mat_q4_K_f32( } } #else -kernel void kernel_mul_mat_q4_K_f32( +kernel void kernel_mul_mv_q4_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1745,7 +1827,7 @@ kernel void kernel_mul_mat_q4_K_f32( } #endif -kernel void kernel_mul_mat_q5_K_f32( +kernel void kernel_mul_mv_q5_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1918,7 +2000,7 @@ kernel void kernel_mul_mat_q5_K_f32( } -kernel void kernel_mul_mat_q6_K_f32( +kernel void kernel_mul_mv_q6_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -2256,7 +2338,7 @@ kernel void kernel_get_rows( } #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A -#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B #define BLOCK_SIZE_K 32 #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B @@ -2293,9 +2375,11 @@ kernel void kernel_mul_mm(device const uchar * src0, const uint r0 = tgpig.y; const uint r1 = tgpig.x; const uint im = tgpig.z; + // if this block is of 64x32 shape or smaller short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + // a thread shouldn't load data outside of the matrix short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; @@ -2319,26 +2403,30 @@ kernel void kernel_mul_mm(device const uchar * src0, + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - //load data and store to threadgroup memory + // load data and store to threadgroup memory half4x4 temp_a; dequantize_func(x, il, temp_a); threadgroup_barrier(mem_flags::mem_threadgroup); + #pragma unroll(16) for (int i = 0; i < 16; i++) { *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; } - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \ - = *((device float2x4 *)y); + + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2+nl-1)/nl : x; y += BLOCK_SIZE_K; threadgroup_barrier(mem_flags::mem_threadgroup); - //load matrices from threadgroup memory and conduct outer products + + // load matrices from threadgroup memory and conduct outer products threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + #pragma unroll(4) for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { #pragma unroll(4) @@ -2353,6 +2441,7 @@ kernel void kernel_mul_mm(device const uchar * src0, lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + #pragma unroll(8) for (int i = 0; i < 8; i++){ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); @@ -2361,25 +2450,26 @@ kernel void kernel_mul_mm(device const uchar * src0, } if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0; + device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ + + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; for (int i = 0; i < 8; i++) { simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); } } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float *temp_str = ((threadgroup float *)shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; for (int i = 0; i < 8; i++) { simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); } threadgroup_barrier(mem_flags::mem_threadgroup); - device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; - if (sgitg==0) { + + device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; + if (sgitg == 0) { for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); } } diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml.c b/plugins/wasi_nn/thirdparty/ggml/ggml.c index bf1426d2..630deb49 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml.c +++ b/plugins/wasi_nn/thirdparty/ggml/ggml.c @@ -162,40 +162,16 @@ typedef void * thread_ret_t; #define GGML_PRINT(...) printf(__VA_ARGS__) +// +// end of logging block +// + #ifdef GGML_USE_ACCELERATE // uncomment to use vDSP for soft max computation // note: not sure if it is actually faster //#define GGML_SOFT_MAX_ACCELERATE #endif -// -// logging -// - -#if (GGML_DEBUG >= 1) -#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG(...) -#endif - -#if (GGML_DEBUG >= 5) -#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_5(...) -#endif - -#if (GGML_DEBUG >= 10) -#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_10(...) -#endif - -#define GGML_PRINT(...) printf(__VA_ARGS__) - -// -// end of logging block -// - #if defined(_MSC_VER) || defined(__MINGW32__) #define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN) #define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr) @@ -1032,8 +1008,8 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); // get the 5-th bit and store it in qh at the right position - qh |= ((xi0 & 0x10) >> 4) << (j + 0); - qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2); } memcpy(&y[i].qh, &qh, sizeof(qh)); @@ -1080,8 +1056,8 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); // get the 5-th bit and store it in qh at the right position - qh |= ((xi0 & 0x10) >> 4) << (j + 0); - qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2); } memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); @@ -1272,6 +1248,33 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); #endif } +#elif defined(__riscv_v_intrinsic) + + size_t vl = __riscv_vsetvl_e32m4(QK8_0); + + for (int i = 0; i < nb; i++) { + // load elements + vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl); + + vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + + // convert to integer + vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); + vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + + // store result + __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + } #else // scalar quantize_row_q8_0_reference(x, y, k); @@ -1490,6 +1493,41 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); #endif } +#elif defined(__riscv_v_intrinsic) + + size_t vl = __riscv_vsetvl_e32m4(QK8_1); + + for (int i = 0; i < nb; i++) { + // load elements + vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl); + + vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + + // convert to integer + vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); + vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + + // store result + __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + + // compute sum for y[i].s + vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl); + + // set y[i].s + int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); + y[i].s = sum*d; + } #else // scalar quantize_row_q8_1_reference(x, y, k); @@ -2662,30 +2700,32 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * size_t vl = __riscv_vsetvl_e8m1(qk/2); for (int i = 0; i < nb; i++) { - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); + // load elements + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl); - vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl); + // subtract offset + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); - sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); } @@ -2823,27 +2863,28 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * size_t vl = __riscv_vsetvl_e8m1(qk/2); for (int i = 0; i < nb; i++) { - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); + // load elements + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); - sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; } @@ -3088,66 +3129,61 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * uint32_t qh; - // These temp values are for masking and shift operations - uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, - 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; - size_t vl = __riscv_vsetvl_e8m1(qk/2); + // These tempory registers are for masking and shift operations + vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); + vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); + + vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); + vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + for (int i = 0; i < nb; i++) { memcpy(&qh, x[i].qh, sizeof(uint32_t)); - // temporary registers - vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl); - vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl); - vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl); - vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl); - // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl); - vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl); - vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl); + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); // ((qh & (1u << (j + 16))) >> (j + 12)); - vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl); - vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); + vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); // narrowing - vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl); - vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl); + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); + vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl); - vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl); + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); + vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); // load - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl); - vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl); + vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); + vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl); - vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl); + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); - sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; } @@ -3414,62 +3450,58 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * uint32_t qh; - // These temp values are for shift operations - uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - size_t vl = __riscv_vsetvl_e8m1(qk/2); + // temporary registers for shift operations + vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); + vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + for (int i = 0; i < nb; i++) { memcpy(&qh, x[i].qh, sizeof(uint32_t)); - // temporary registers - vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl); - vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl); - // load qh - vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl); + vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); // ((qh >> (j + 0)) << 4) & 0x10; - vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl); - vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl); - vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl); + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); // ((qh >> (j + 12)) ) & 0x10; - vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl); - vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl); + vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); // narrowing - vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl); - vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl); + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); + vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl); - vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl); + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); + vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); // load - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl); - vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl); + vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); + vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); - sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; } @@ -4025,12 +4057,16 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ALIBI", "CLAMP", "CONV_1D", + "CONV_TRANSPOSE_1D", "CONV_2D", "CONV_TRANSPOSE_2D", "POOL_1D", "POOL_2D", "UPSCALE", + "CONV_1D_STAGE_0", + "CONV_1D_STAGE_1", + "FLASH_ATTN", "FLASH_FF", "FLASH_ATTN_BACK", @@ -4056,7 +4092,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); +static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -4107,12 +4143,16 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "alibi(x)", "clamp(x)", "conv_1d(x)", + "conv_transpose_1d(x)", "conv_2d(x)", "conv_transpose_2d(x)", "pool_1d(x)", "pool_2d(x)", "upscale(x)", + "conv_1d_stage_0(x)", + "conv_1d_stage_1(x)", + "flash_attn(x)", "flash_ff(x)", "flash_attn_back(x)", @@ -4138,7 +4178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); +static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4167,7 +4207,10 @@ static void ggml_setup_op_has_task_pass(void) { p[GGML_OP_DIAG_MASK_INF ] = true; p[GGML_OP_DIAG_MASK_ZERO ] = true; p[GGML_OP_CONV_1D ] = true; + p[GGML_OP_CONV_1D_STAGE_0 ] = true; + p[GGML_OP_CONV_1D_STAGE_1 ] = true; p[GGML_OP_CONV_2D ] = true; + p[GGML_OP_CONV_TRANSPOSE_1D ] = true; p[GGML_OP_CONV_TRANSPOSE_2D ] = true; p[GGML_OP_FLASH_ATTN_BACK ] = true; p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; @@ -4884,6 +4927,7 @@ static struct ggml_tensor * ggml_new_tensor_impl( *result = (struct ggml_tensor) { /*.type =*/ type, /*.backend =*/ GGML_BACKEND_CPU, + /*.buffer =*/ NULL, /*.n_dims =*/ n_dims, /*.ne =*/ { 1, 1, 1, 1 }, /*.nb =*/ { 0, 0, 0, 0 }, @@ -5450,6 +5494,39 @@ struct ggml_tensor * ggml_view_tensor( return result; } +struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) { + struct ggml_object * obj = ctx->objects_begin; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + return (struct ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + +struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) { + struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE); + obj = obj->next; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + return (struct ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) { struct ggml_object * obj = ctx->objects_begin; @@ -6690,7 +6767,6 @@ struct ggml_tensor * ggml_cont_4d( return result; } - // ggml_reshape struct ggml_tensor * ggml_reshape( @@ -7448,14 +7524,17 @@ static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; } -GGML_API struct ggml_tensor * ggml_conv_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int p0, - int d0) { - GGML_ASSERT(ggml_is_matrix(b)); +// im2col: [N, IC, IL] => [N, OL, IC*K] +// a: [OC,IC, K] +// b: [N, IC, IL] +// result: [N, OL, IC*K] +static struct ggml_tensor * ggml_conv_1d_stage_0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { GGML_ASSERT(a->ne[1] == b->ne[1]); bool is_node = false; @@ -7464,16 +7543,54 @@ GGML_API struct ggml_tensor * ggml_conv_1d( is_node = true; } + const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + const int64_t ne[4] = { - ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), - a->ne[2], 1, 1, + a->ne[1] * a->ne[0], + OL, + b->ne[2], + 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); int32_t params[] = { s0, p0, d0 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_CONV_1D; + result->op = GGML_OP_CONV_1D_STAGE_0; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// ggml_conv_1d_stage_1 + +// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] +// a: [OC, IC, K] +// b: [N, OL, IC * K] +// result: [N, OC, OL] +static struct ggml_tensor * ggml_conv_1d_stage_1( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + b->ne[1], + a->ne[2], + b->ne[2], + 1, + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_CONV_1D_STAGE_1; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; @@ -7481,6 +7598,53 @@ GGML_API struct ggml_tensor * ggml_conv_1d( return result; } +// ggml_conv_1d + +GGML_API struct ggml_tensor * ggml_conv_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { + struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0); + result = ggml_conv_1d_stage_1(ctx, a, result); + return result; +} + +// GGML_API struct ggml_tensor * ggml_conv_1d( +// struct ggml_context * ctx, +// struct ggml_tensor * a, +// struct ggml_tensor * b, +// int s0, +// int p0, +// int d0) { +// GGML_ASSERT(ggml_is_matrix(b)); +// GGML_ASSERT(a->ne[1] == b->ne[1]); +// bool is_node = false; + +// if (a->grad || b->grad) { +// GGML_ASSERT(false); // TODO: implement backward +// is_node = true; +// } + +// const int64_t ne[4] = { +// ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), +// a->ne[2], 1, 1, +// }; +// struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); + +// int32_t params[] = { s0, p0, d0 }; +// ggml_set_op_params(result, params, sizeof(params)); + +// result->op = GGML_OP_CONV_1D; +// result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; +// result->src[0] = a; +// result->src[1] = b; + +// return result; +// } + // ggml_conv_1d_ph struct ggml_tensor* ggml_conv_1d_ph( @@ -7492,6 +7656,50 @@ struct ggml_tensor* ggml_conv_1d_ph( return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); } +// ggml_conv_transpose_1d + +static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins - 1) * s - 2 * p + d * (ks - 1) + 1; +} + +GGML_API struct ggml_tensor * ggml_conv_transpose_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { + GGML_ASSERT(ggml_is_matrix(b)); + GGML_ASSERT(a->ne[2] == b->ne[1]); + GGML_ASSERT(a->ne[3] == 1); + + GGML_ASSERT(p0 == 0); + GGML_ASSERT(d0 == 1); + + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), + a->ne[1], b->ne[2], 1, + }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + int32_t params[] = { s0, p0, d0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CONV_TRANSPOSE_1D; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // ggml_conv_2d struct ggml_tensor * ggml_conv_2d( @@ -8472,6 +8680,7 @@ void ggml_set_param( GGML_ASSERT(tensor->grad == NULL); tensor->grad = ggml_dup_tensor(ctx, tensor); + ggml_format_name(tensor->grad, "%s (grad)", tensor->name); } // ggml_compute_forward_dup @@ -11058,7 +11267,7 @@ static void ggml_compute_forward_silu_f32( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); @@ -12884,28 +13093,25 @@ static void ggml_compute_forward_alibi_f32( return; } - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - assert(n_past >= 0); - - const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 - const int ne1 = src0->ne[1]; // seq_len_without_past - const int ne2 = src0->ne[2]; // n_head -> this is k - //const int ne3 = src0->ne[3]; // 1 -> bsz + const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 + const int64_t ne1 = src0->ne[1]; // seq_len_without_past + const int64_t ne2 = src0->ne[2]; // n_head -> this is k + //const int64_t ne3 = src0->ne[3]; // 1 -> bsz - const int n = ggml_nrows(src0); - const int ne2_ne3 = n/ne1; // ne2*ne3 + const int64_t n = ggml_nrows(src0); + const int64_t ne2_ne3 = n/ne1; // ne2*ne3 - const int nb0 = src0->nb[0]; - const int nb1 = src0->nb[1]; - const int nb2 = src0->nb[2]; + const size_t nb0 = src0->nb[0]; + const size_t nb1 = src0->nb[1]; + const size_t nb2 = src0->nb[2]; //const int nb3 = src0->nb[3]; GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(ne1 + n_past == ne0); GGML_ASSERT(n_head == ne2); // add alibi to src0 (KQ_scaled) @@ -12914,9 +13120,9 @@ static void ggml_compute_forward_alibi_f32( const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - for (int i = 0; i < ne0; i++) { - for (int j = 0; j < ne1; j++) { - for (int k = 0; k < ne2_ne3; k++) { + for (int64_t i = 0; i < ne0; i++) { + for (int64_t j = 0; j < ne1; j++) { + for (int64_t k = 0; k < ne2_ne3; k++) { float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); @@ -12931,7 +13137,6 @@ static void ggml_compute_forward_alibi_f32( } pdst[0] = i * m_k + src[0]; - } } } @@ -13631,7 +13836,7 @@ static void ggml_compute_forward_rope_back( // ggml_compute_forward_conv_1d -static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( +static void ggml_compute_forward_conv_1d_f16_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -13649,42 +13854,33 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( const int nth = params->nth; const int nk = ne00; - const int nh = nk/2; - const int ew0 = ggml_up32(ne01); + // size of the convolution row - the kernel size unrolled across all input channels + const int ew0 = nk*ne01; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; - GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); if (params->type == GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); - // prepare kernel data (src0) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; - } - } - } - } + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + ggml_fp16_t * dst_data = wdata; - // prepare source data (src1) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + for (int64_t i0 = 0; i0 < ne0; i0++) { + for (int64_t ik = 0; ik < nk; ik++) { + const int idx0 = i0*s0 + ik*d0 - p0; - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - ggml_fp16_t * dst_data = wdata; - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); + if(!(idx0 < 0 || idx0 >= ne10)) { + dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]); + } } } } @@ -13697,7 +13893,7 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( } // total rows in dst - const int nr = ne02; + const int nr = ne2; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -13706,23 +13902,22 @@ static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int64_t i0 = 0; i0 < ne10; ++i0) { - dst_data[i0] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_vec_dot_f16(ew0, &v, - (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0] += v; + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int i2 = 0; i2 < ne2; i2++) { + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1); + + for (int i0 = 0; i0 < ne0; i0++) { + ggml_vec_dot_f16(ew0, dst_data + i0, + (ggml_fp16_t *) ((char *) src0->data + i1*nb02), + (ggml_fp16_t *) wdata + i2*nb2 + i0*ew0); } } } } -static void ggml_compute_forward_conv_1d_s1_ph_f32( +static void ggml_compute_forward_conv_1d_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -13740,42 +13935,32 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32( const int nth = params->nth; const int nk = ne00; - const int nh = nk/2; - const int ew0 = ggml_up32(ne01); + const int ew0 = nk*ne01; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; - GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float)); if (params->type == GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); - // prepare kernel data (src0) - { - float * const wdata = (float *) params->wdata + 0; + float * const wdata = (float *) params->wdata + 0; - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i02*ew0*ne00; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; - } - } - } - } + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; - // prepare source data (src1) - { - float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + for (int64_t i0 = 0; i0 < ne0; i0++) { + for (int64_t ik = 0; ik < nk; ik++) { + const int idx0 = i0*s0 + ik*d0 - p0; - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + if(!(idx0 < 0 || idx0 >= ne10)) { + dst_data[i0*ew0 + i11*nk + ik] = src[idx0]; + } } } } @@ -13797,35 +13982,225 @@ static void ggml_compute_forward_conv_1d_s1_ph_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int64_t i0 = 0; i0 < ne10; ++i0) { - dst_data[i0] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_vec_dot_f32(ew0, &v, - (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0] += v; + float * const wdata = (float *) params->wdata + 0; + + for (int i2 = 0; i2 < ne2; i2++) { + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1); + + for (int i0 = 0; i0 < ne0; i0++) { + ggml_vec_dot_f32(ew0, dst_data + i0, + (float *) ((char *) src0->data + i1*nb02), + (float *) wdata + i2*nb2 + i0*ew0); } } } } -static void ggml_compute_forward_conv_1d_s1_ph( +static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k, + ggml_fp16_t * A, + ggml_fp16_t * B, + float * C, + const int ith, const int nth) { + // does not seem to make a difference + int64_t m0, m1, n0, n1; + // patches per thread + if (m > n) { + n0 = 0; + n1 = n; + + // total patches in dst + const int np = m; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + m0 = dp*ith; + m1 = MIN(m0 + dp, np); + } else { + m0 = 0; + m1 = m; + + // total patches in dst + const int np = n; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + n0 = dp*ith; + n1 = MIN(n0 + dp, np); + } + + // block-tiling attempt + int64_t blck_n = 16; + int64_t blck_m = 16; + + // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB + // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K); + // if (blck_size > 0) { + // blck_0 = 4; + // blck_1 = blck_size / blck_0; + // if (blck_1 < 0) { + // blck_1 = 1; + // } + // // blck_0 = (int64_t)sqrt(blck_size); + // // blck_1 = blck_0; + // } + // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1); + + for (int j = n0; j < n1; j+=blck_n) { + for (int i = m0; i < m1; i+=blck_m) { + // printf("i j k => %d %d %d\n", i, j, K); + for (int ii = i; ii < i + blck_m && ii < m1; ii++) { + for (int jj = j; jj < j + blck_n && jj < n1; jj++) { + ggml_vec_dot_f16(k, + C + ii*n + jj, + A + ii * k, + B + jj * k); + } + } + } + } +} + +// src0: kernel [OC, IC, K] +// src1: signal [N, IC, IL] +// dst: result [N, OL, IC*K] +static void ggml_compute_forward_conv_1d_stage_0_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - switch (src0->type) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int64_t N = ne12; + const int64_t IC = ne11; + const int64_t IL = ne10; + + const int64_t K = ne00; + + const int64_t OL = ne1; + + const int ith = params->ith; + const int nth = params->nth; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + memset(dst->data, 0, ggml_nbytes(dst)); + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // im2col: [N, IC, IL] => [N, OL, IC*K] + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iol = 0; iol < OL; iol++) { + for (int64_t iic = ith; iic < IC; iic+=nth) { + + // micro kernel + ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K] + const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL] + + for (int64_t ik = 0; ik < K; ik++) { + const int64_t iil = iol*s0 + ik*d0 - p0; + + if (!(iil < 0 || iil >= IL)) { + dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]); + } + } + } + } + } + } +} + +// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] +// src0: [OC, IC, K] +// src1: [N, OL, IC * K] +// result: [N, OC, OL] +static void ggml_compute_forward_conv_1d_stage_1_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + if (params->type == GGML_TASK_INIT) { + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb0 == sizeof(float)); + + const int N = ne12; + const int OL = ne11; + + const int OC = ne02; + const int IC = ne01; + const int K = ne00; + + const int ith = params->ith; + const int nth = params->nth; + + int64_t m = OC; + int64_t n = OL; + int64_t k = IC * K; + + // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] + for (int i = 0; i < N; i++) { + ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k] + ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k] + float * C = (float *)dst->data + i * m * n; // [m, n] + + gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); + } +} + +static void ggml_compute_forward_conv_1d( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch(src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst); + ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst); + ggml_compute_forward_conv_1d_f32(params, src0, src1, dst); } break; default: { @@ -13834,7 +14209,43 @@ static void ggml_compute_forward_conv_1d_s1_ph( } } -static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( +static void ggml_compute_forward_conv_1d_stage_0( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch(src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_compute_forward_conv_1d_stage_1( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch(src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_conv_transpose_1d + +static void ggml_compute_forward_conv_transpose_1d_f16_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -13851,43 +14262,38 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( const int ith = params->ith; const int nth = params->nth; - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = ggml_up32(ne01); + const int nk = ne00*ne01*ne02; - GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); if (params->type == GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); - // prepare kernel data (src0) + // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) { ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; + dst_data[i00*ne02 + i02] = src[i00]; } } } } - // prepare source data (src1) + // permute source data (src1) from (L x Cin) to (Cin x L) { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + ggml_fp16_t * dst_data = wdata; for (int64_t i11 = 0; i11 < ne11; i11++) { const float * const src = (float *)((char *) src1->data + i11*nb11); - ggml_fp16_t * dst_data = wdata; for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = GGML_FP32_TO_FP16(src[i10]); + dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]); } } } @@ -13899,8 +14305,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( return; } + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + // total rows in dst - const int nr = ne02; + const int nr = ne1; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -13909,23 +14317,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + ggml_fp16_t * const wdata_src = wdata + nk; + for (int i1 = ir0; i1 < ir1; i1++) { float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int64_t i0 = 0; i0 < ne10; i0 += 2) { - dst_data[i0/2] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_vec_dot_f16(ew0, &v, - (ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0/2] += v; + ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f16(ne02, &v, + (ggml_fp16_t *) wdata_src + i1n, + (ggml_fp16_t *) wdata_kernel + i00*ne02); + dst_data[i10*s0 + i00] += v; } } } } -static void ggml_compute_forward_conv_1d_s2_ph_f32( +static void ggml_compute_forward_conv_transpose_1d_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -13942,29 +14353,24 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32( const int ith = params->ith; const int nth = params->nth; - const int nk = ne00; - const int nh = nk/2; + const int nk = ne00*ne01*ne02; - const int ew0 = ggml_up32(ne01); - - GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float)); if (params->type == GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); - // prepare kernel data (src0) + // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) { float * const wdata = (float *) params->wdata + 0; for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i02*ew0*ne00; + float * dst_data = wdata + i01*ne00*ne02; for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; + dst_data[i01*ne00*ne02 + i00*ne02 + i02] = src[i00]; } } } @@ -13972,13 +14378,13 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32( // prepare source data (src1) { - float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + float * const wdata = (float *) params->wdata + nk; + float * dst_data = wdata; for (int64_t i11 = 0; i11 < ne11; i11++) { const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + dst_data[i10*ne11 + i11] = src[i10]; } } } @@ -13990,8 +14396,10 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32( return; } + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + // total rows in dst - const int nr = ne02; + const int nr = ne1; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -14000,23 +14408,26 @@ static void ggml_compute_forward_conv_1d_s2_ph_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); + float * const wdata = (float *) params->wdata + 0; + float * const wdata_src = wdata + nk; + for (int i1 = ir0; i1 < ir1; i1++) { float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int64_t i0 = 0; i0 < ne10; i0 += 2) { - dst_data[i0/2] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - ggml_vec_dot_f32(ew0, &v, - (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0/2] += v; + float * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f32(ne02, &v, + wdata_src + i1n, + wdata_kernel + i00*ne02); + dst_data[i10*s0 + i00] += v; } } } } -static void ggml_compute_forward_conv_1d_s2_ph( +static void ggml_compute_forward_conv_transpose_1d( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -14024,11 +14435,11 @@ static void ggml_compute_forward_conv_1d_s2_ph( switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst); + ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst); + ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst); } break; default: { @@ -14037,27 +14448,6 @@ static void ggml_compute_forward_conv_1d_s2_ph( } } -// ggml_compute_forward_conv_1d - -static void ggml_compute_forward_conv_1d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; - GGML_ASSERT(d0 == 1); // dilation not supported - GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported - if (s0 == 1) { - ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst); - } else if (s0 == 2) { - ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst); - } else { - GGML_ASSERT(false); // only stride 1 and 2 supported - } -} - // ggml_compute_forward_conv_2d static void ggml_compute_forward_conv_2d_f16_f32( @@ -14072,7 +14462,7 @@ static void ggml_compute_forward_conv_2d_f16_f32( int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS + GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; const int nth = params->nth; @@ -14100,20 +14490,22 @@ static void ggml_compute_forward_conv_2d_f16_f32( { ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - for (int i12 = 0; i12 < ne12; i12++) { - const float * const src = (float *)((char *) src1->data + i12*nb12); - ggml_fp16_t * dst_data = wdata; - - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - for (int ik1 = 0; ik1 < nk1; ik1++) { - for (int ik0 = 0; ik0 < nk0; ik0++) { - const int idx0 = i0*s0 + ik0*d0 - p0; - const int idx1 = i1*s1 + ik1*d1 - p1; - - if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) { - dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] = - GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]); + for (int i13 = 0; i13 < ne13; i13++) { + for (int i12 = 0; i12 < ne12; i12++) { + const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12); + ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0); + + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + for (int ik1 = 0; ik1 < nk1; ik1++) { + for (int ik0 = 0; ik0 < nk0; ik0++) { + const int idx0 = i0*s0 + ik0*d0 - p0; + const int idx1 = i1*s1 + ik1*d1 - p1; + + if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) { + dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] = + GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]); + } } } } @@ -16396,6 +16788,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor); } break; + case GGML_OP_CONV_1D_STAGE_0: + { + ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_CONV_1D_STAGE_1: + { + ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor); + } break; case GGML_OP_CONV_2D: { ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor); @@ -17321,10 +17725,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_CONV_1D_STAGE_0: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_1D_STAGE_1: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_CONV_2D: { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_CONV_TRANSPOSE_2D: { GGML_ASSERT(false); // TODO: not implemented @@ -18166,21 +18582,68 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { GGML_ASSERT(node->src[1]->ne[2] == 1); GGML_ASSERT(node->src[1]->ne[3] == 1); + const int64_t ne00 = node->src[0]->ne[0]; + const int64_t ne01 = node->src[0]->ne[1]; + const int64_t ne02 = node->src[0]->ne[2]; + + const int64_t ne10 = node->src[1]->ne[0]; + const int64_t ne11 = node->src[1]->ne[1]; + + const int64_t ne0 = node->ne[0]; + const int64_t ne1 = node->ne[1]; + const int64_t nk = ne00; + const int64_t ew0 = nk * ne01; + + UNUSED(ne02); + UNUSED(ne10); + UNUSED(ne11); + size_t cur = 0; - const int nk = node->src[0]->ne[0]; if (node->src[0]->type == GGML_TYPE_F16 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(ggml_fp16_t)*( - nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] + - ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1] - ); + node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0); } else if (node->src[0]->type == GGML_TYPE_F32 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*( - nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] + - ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1] - ); + node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(float)*(ne0*ne1*ew0); + } else { + GGML_ASSERT(false); + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_CONV_1D_STAGE_0: + { + n_tasks = n_threads; + } break; + case GGML_OP_CONV_1D_STAGE_1: + { + n_tasks = n_threads; + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + n_tasks = n_threads; + + GGML_ASSERT(node->src[0]->ne[3] == 1); + GGML_ASSERT(node->src[1]->ne[2] == 1); + GGML_ASSERT(node->src[1]->ne[3] == 1); + + const int64_t ne00 = node->src[0]->ne[0]; // K + const int64_t ne01 = node->src[0]->ne[1]; // Cout + const int64_t ne02 = node->src[0]->ne[2]; // Cin + + const int64_t ne10 = node->src[1]->ne[0]; // L + const int64_t ne11 = node->src[1]->ne[1]; // Cin + + size_t cur = 0; + if (node->src[0]->type == GGML_TYPE_F16 && + node->src[1]->type == GGML_TYPE_F32) { + cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02; + cur += sizeof(ggml_fp16_t)*ne10*ne11; + } else if (node->src[0]->type == GGML_TYPE_F32 && + node->src[1]->type == GGML_TYPE_F32) { + cur += sizeof(float)*ne00*ne01*ne02; + cur += sizeof(float)*ne10*ne11; } else { GGML_ASSERT(false); } @@ -19306,7 +19769,7 @@ static enum ggml_opt_result ggml_opt_adam( if (callback) { callback(callback_data, accum_step, &sched, &cancel); if (cancel) { - break; + return GGML_OPT_CANCEL; } } // ggml_graph_reset (gf); @@ -19315,9 +19778,6 @@ static enum ggml_opt_result ggml_opt_adam( ggml_opt_acc_grad(np, ps, g, accum_norm); fx += ggml_get_f32_1d(f, 0); } - if (cancel) { - return GGML_OPT_DID_NOT_CONVERGE; - } fx *= accum_norm; opt->adam.fx_prev = fx; @@ -19343,9 +19803,6 @@ static enum ggml_opt_result ggml_opt_adam( // run the optimizer for (int t = 0; t < params.adam.n_iter; ++t) { - if (cancel) { - break; - } opt->iter = iter0 + t + 1; GGML_PRINT_DEBUG ("=== iter %d ===\n", t); @@ -19403,7 +19860,7 @@ static enum ggml_opt_result ggml_opt_adam( if (callback) { callback(callback_data, accum_step, &sched, &cancel); if (cancel) { - break; + return GGML_OPT_CANCEL;; } } // ggml_graph_reset (gf); @@ -19412,9 +19869,6 @@ static enum ggml_opt_result ggml_opt_adam( ggml_opt_acc_grad(np, ps, g, accum_norm); fx += ggml_get_f32_1d(f, 0); } - if (cancel) { - break; - } fx *= accum_norm; opt->loss_after = fx; @@ -19533,7 +19987,7 @@ static enum ggml_opt_result linesearch_backtracking( finit = *fx; dgtest = params->lbfgs.ftol*dginit; - while (!*cancel) { + while (true) { ggml_vec_cpy_f32(nx, x, xp); ggml_vec_mad_f32(nx, x, d, *step); @@ -19549,7 +20003,7 @@ static enum ggml_opt_result linesearch_backtracking( float sched = 0; callback(callback_data, accum_step, &sched, cancel); if (*cancel) { - break; + return GGML_OPT_CANCEL; } } // ggml_graph_reset (gf); @@ -19558,9 +20012,6 @@ static enum ggml_opt_result linesearch_backtracking( ggml_opt_acc_grad(np, ps, g, accum_norm); *fx += ggml_get_f32_1d(f, 0); } - if (*cancel) { - break; - } *fx *= accum_norm; } @@ -19693,7 +20144,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( float sched = 0; callback(callback_data, accum_step, &sched, &cancel); if (cancel) { - break; + return GGML_OPT_CANCEL; } } // ggml_graph_reset (gf); @@ -19702,9 +20153,6 @@ static enum ggml_opt_result ggml_opt_lbfgs( ggml_opt_acc_grad(np, ps, g, accum_norm); fx += ggml_get_f32_1d(f, 0); } - if (cancel) { - return GGML_OPT_DID_NOT_CONVERGE; - } fx *= accum_norm; opt->loss_before = fx; @@ -19763,9 +20211,13 @@ static enum ggml_opt_result ggml_opt_lbfgs( ggml_vec_cpy_f32(nx, xp, x); ggml_vec_cpy_f32(nx, gp, g); + // TODO: instead of passing &cancel here, use the return code of the linesearch + // to determine if the optimization should be cancelled + // this is a simple change, but not doing this atm, since I don't have a nice + // way to test and don't want to break something with so many changes lined up ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data); - if (!cancel) { - break; + if (cancel) { + return GGML_OPT_CANCEL; } if (ls < 0) { diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml.h b/plugins/wasi_nn/thirdparty/ggml/ggml.h index 460857fa..6e35888e 100644 --- a/plugins/wasi_nn/thirdparty/ggml/ggml.h +++ b/plugins/wasi_nn/thirdparty/ggml/ggml.h @@ -326,7 +326,7 @@ extern "C" { GGML_TYPE_COUNT, }; - enum ggml_backend { + enum ggml_backend_type { GGML_BACKEND_CPU = 0, GGML_BACKEND_GPU = 10, GGML_BACKEND_GPU_SPLIT = 20, @@ -401,10 +401,14 @@ extern "C" { GGML_OP_CLAMP, GGML_OP_CONV_1D, GGML_OP_CONV_2D, + GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, GGML_OP_POOL_2D, + GGML_OP_CONV_1D_STAGE_0, // internal + GGML_OP_CONV_1D_STAGE_1, // internal + GGML_OP_UPSCALE, // nearest interpolate GGML_OP_FLASH_ATTN, @@ -475,8 +479,10 @@ extern "C" { // n-dimensional tensor struct ggml_tensor { - enum ggml_type type; - enum ggml_backend backend; + enum ggml_type type; + enum ggml_backend_type backend; + + struct ggml_backend_buffer * buffer; int n_dims; int64_t ne[GGML_MAX_DIMS]; // number of elements @@ -510,7 +516,7 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - char padding[4]; + char padding[12]; }; static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); @@ -698,6 +704,9 @@ extern "C" { GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src); + // Context tensor enumeration and lookup + GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx); + GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor); GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); @@ -1354,7 +1363,7 @@ extern "C" { // alibi position embedding // in-place, returns view(a) - struct ggml_tensor * ggml_alibi( + GGML_API struct ggml_tensor * ggml_alibi( struct ggml_context * ctx, struct ggml_tensor * a, int n_past, @@ -1363,7 +1372,7 @@ extern "C" { // clamp // in-place, returns view(a) - struct ggml_tensor * ggml_clamp( + GGML_API struct ggml_tensor * ggml_clamp( struct ggml_context * ctx, struct ggml_tensor * a, float min, @@ -1386,6 +1395,14 @@ extern "C" { int s, int d); + GGML_API struct ggml_tensor * ggml_conv_transpose_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0); + GGML_API struct ggml_tensor * ggml_conv_2d( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1759,6 +1776,7 @@ extern "C" { GGML_OPT_NO_CONTEXT, GGML_OPT_INVALID_WOLFE, GGML_OPT_FAIL, + GGML_OPT_CANCEL, GGML_LINESEARCH_FAIL = -128, GGML_LINESEARCH_MINIMUM_STEP, @@ -2089,7 +2107,7 @@ extern "C" { enum ggml_type vec_dot_type; } ggml_type_traits_t; - ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); + GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); #ifdef __cplusplus } diff --git a/plugins/wasi_nn/thirdparty/ggml/k_quants.c b/plugins/wasi_nn/thirdparty/ggml/k_quants.c index 62085882..558f5fda 100644 --- a/plugins/wasi_nn/thirdparty/ggml/k_quants.c +++ b/plugins/wasi_nn/thirdparty/ggml/k_quants.c @@ -54,6 +54,10 @@ inline static int32_t vaddvq_s32(int32x4_t v) { #endif #endif +#ifdef __riscv_v_intrinsic +#include +#endif + #undef MIN #undef MAX #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -65,7 +69,6 @@ inline static int32_t vaddvq_s32(int32x4_t v) { // 2-6 bit quantization in super-blocks // - // // ===================== Helper functions // @@ -344,7 +347,6 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict const float q4scale = 15.f; for (int i = 0; i < nb; i++) { - float max_scale = 0; // as we are deducting the min, scales are always positive float max_min = 0; for (int j = 0; j < QK_K/16; ++j) { @@ -1582,6 +1584,90 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc); +#elif defined __riscv_v_intrinsic + + float sumf = 0; + uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + size_t vl = 16; + + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + + vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + + uint8_t is=0; + int isum=0; + + for (int j = 0; j < QK_K/128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); + + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); + + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); + + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2+=32; q8+=128; is=8; + + } + + sumf += dall * isum; + + } + + *s = sumf; + #else float sumf = 0; @@ -1807,6 +1893,64 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc) + summs; +#elif defined __riscv_v_intrinsic + + uint32_t aux32[2]; + const uint8_t * scales = (const uint8_t *)aux32; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const float dmin = -y[i].d * (float)x[i].dmin; + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + + aux32[0] = sc[0] & 0x0f0f0f0f; + aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; + + sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); + + int isum1 = 0; + int isum2 = 0; + + size_t vl = 16; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + // load Q2 + vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl); + + vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl)); + vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl)); + vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl)); + vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl)); + + // load Q8, and take product with Q2 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl); + vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl); + vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl); + + isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0]; + isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1]; + isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2]; + isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3]; + + sumf += d * (isum1 + isum2); + + } + + *s = sumf; + #else float sumf = 0; @@ -2220,6 +2364,106 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc); +#elif defined __riscv_v_intrinsic + + uint32_t aux[3]; + uint32_t utmp[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + + int sum_t = 0; + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + // retreive lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q3 += 32; q8 += 128; scale += 8; + + } + + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + + sumf += d*sum_t; + + } + + *s = sumf; + #else // scalar version // This function is written like this so the compiler can manage to vectorize most of it @@ -2523,6 +2767,79 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc); +#elif defined __riscv_v_intrinsic + + uint16_t aux16[2]; + int8_t * scales = (int8_t *)aux16; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + for (int j = 0; j < 4; ++j) scales[j] -= 8; + + int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + + const float d = y[i].d * (float)x[i].d; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + // load qh + vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8); + vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); + + size_t vl = 16; + + // extend and combine both qh_x1 and qh_x2 + vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); + + vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); + vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl); + vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); + vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl); + + // load Q3 + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); + + vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl); + vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl); + vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl); + vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl); + + vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0); + vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1); + vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2); + vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3); + + // load Q8 and take product with Q3 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3]; + + sumf += d * isum; + + } + + *s = sumf; + #else int8_t aux8[QK_K]; @@ -2823,6 +3140,78 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + size_t vl = 8; + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + vl = 32; + + int32_t sum_1 = 0; + int32_t sum_2 = 0; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; + + } + + sumf += d*(sum_1 + sum_2); + + } + + *s = sumf; + #else @@ -3064,6 +3453,50 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc) - summs; +#elif defined __riscv_v_intrinsic + + uint16_t s16[2]; + const uint8_t * restrict scales = (const uint8_t *)s16; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t * restrict b = (const uint16_t *)x[i].scales; + s16[0] = b[0] & 0x0f0f; + s16[1] = (b[0] >> 4) & 0x0f0f; + + sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]); + + size_t vl = 32; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl); + + sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1); + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl); + + sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2); + + } + + *s = sumf; + #else uint8_t aux8[QK_K]; @@ -3394,6 +3827,93 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc) + summs; +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + float sums = 0.0; + + size_t vl; + + for (int i = 0; i < nb; ++i) { + + vl = 8; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d; + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + vl = 32; + int32_t aux32 = 0; + int is = 0; + + uint8_t m = 1; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q5 and Q8 + vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); + vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); + + // compute mask for addition + vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl); + m <<= 1; + + vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl); + m <<= 1; + + vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); + vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); + + vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); + vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); + + vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); + vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); + + aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); + q5 += 32; q8 += 64; + + } + + vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); + sums += __riscv_vfmv_f_s_f32m1_f32(vaux); + + } + + *s = sumf+sums; + #else const uint8_t * scales = (const uint8_t*)&utmp[0]; @@ -3639,6 +4159,76 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc); +#elif defined __riscv_v_intrinsic + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const int8_t * sc = x[i].scales; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + // load qh + vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8); + vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); + + size_t vl = 16; + + // combine both qh_1 and qh_2 + vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); + + vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); + vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl); + vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl); + vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); + + vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0); + vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1); + vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2); + vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3); + + // load q5 + vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl); + vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5+16, vl); + + vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl)); + vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl)); + vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl)); + vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl)); + + vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl); + vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl); + vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl); + vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl); + + // load Q8 and multiply it with Q5 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0); + int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1); + int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2); + int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3); + + sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); + + } + + *s = sumf; + #else int8_t aux8[QK_K]; @@ -4023,6 +4613,91 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc); +#elif defined __riscv_v_intrinsic + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + size_t vl; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q6 += 64; qh += 32; q8 += 128; is=8; + + } + + sumf += d * sum_t; + + } + + *s = sumf; + #else int8_t aux8[QK_K]; @@ -4276,6 +4951,73 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri *s = hsum_float_8(acc); +#elif defined __riscv_v_intrinsic + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d_all = (float)x[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int32_t isum = 0; + + size_t vl = 16; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + // load Q6 + vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); + vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl); + + // load qh + vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); + + vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + + vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl); + vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl); + vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl); + vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl); + + vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl); + vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl); + vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl); + vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl); + + // load Q8 and take product + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3]; + + sumf += isum * d_all * y[i].d; + + } + + *s = sumf; + #else int8_t aux8[QK_K]; diff --git a/plugins/wasi_nn/thirdparty/ggml/k_quants.h b/plugins/wasi_nn/thirdparty/ggml/k_quants.h index adc6a391..9de089e7 100644 --- a/plugins/wasi_nn/thirdparty/ggml/k_quants.h +++ b/plugins/wasi_nn/thirdparty/ggml/k_quants.h @@ -29,7 +29,7 @@ // 2-bit quantization // weight is represented as x = a * q + b -// 16 blocks of 16 elemenets each +// 16 blocks of 16 elements each // Effectively 2.5625 bits per weight typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits @@ -41,7 +41,7 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "w // 3-bit quantization // weight is represented as x = a * q -// 16 blocks of 16 elemenets each +// 16 blocks of 16 elements each // Effectively 3.4375 bits per weight #ifdef GGML_QKK_64 typedef struct { @@ -62,7 +62,7 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + #endif // 4-bit quantization -// 16 blocks of 32 elements each +// 8 blocks of 32 elements each // weight is represented as x = a * q + b // Effectively 4.5 bits per weight #ifdef GGML_QKK_64 @@ -83,7 +83,7 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/ #endif // 5-bit quantization -// 16 blocks of 32 elements each +// 8 blocks of 32 elements each // weight is represented as x = a * q + b // Effectively 5.5 bits per weight #ifdef GGML_QKK_64 @@ -107,7 +107,7 @@ static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/ // 6-bit quantization // weight is represented as x = a * q -// 16 blocks of 16 elemenets each +// 16 blocks of 16 elements each // Effectively 6.5625 bits per weight typedef struct { uint8_t ql[QK_K/2]; // quants, lower 4 bits diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.cpp b/plugins/wasi_nn/thirdparty/ggml/llama.cpp index d4f3b184..ef5eeb99 100644 --- a/plugins/wasi_nn/thirdparty/ggml/llama.cpp +++ b/plugins/wasi_nn/thirdparty/ggml/llama.cpp @@ -1,6 +1,8 @@ #define LLAMA_API_INTERNAL #include "llama.h" +#include "unicode.h" + #include "ggml.h" #include "ggml-alloc.h" @@ -123,6 +125,27 @@ static void replace_all(std::string & s, const std::string & search, const std:: } s = std::move(result); } + +static bool is_float_close(float a, float b, float abs_tol) { + // Check for non-negative tolerance + if (abs_tol < 0.0) { + throw std::invalid_argument("Tolerance must be non-negative"); + } + + // Exact equality check + if (a == b) { + return true; + } + + // Check for infinities + if (std::isinf(a) || std::isinf(b)) { + return false; + } + + // Regular comparison using the provided absolute tolerance + return std::fabs(b - a) <= abs_tol; +} + #ifdef GGML_USE_CPU_HBM #include #endif @@ -163,6 +186,9 @@ enum llm_arch { LLM_ARCH_GPTNEOX, LLM_ARCH_MPT, LLM_ARCH_STARCODER, + LLM_ARCH_PERSIMMON, + LLM_ARCH_REFACT, + LLM_ARCH_BLOOM, LLM_ARCH_UNKNOWN, }; @@ -175,6 +201,9 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_MPT, "mpt" }, { LLM_ARCH_BAICHUAN, "baichuan" }, { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_PERSIMMON, "persimmon" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BLOOM, "bloom" }, }; enum llm_kv { @@ -277,6 +306,7 @@ struct LLM_KV { enum llm_tensor { LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, LLM_TENSOR_POS_EMBD, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, @@ -293,6 +323,8 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, LLM_TENSOR_FFN_NORM, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, }; static std::map> LLM_TENSOR_NAMES = { @@ -374,10 +406,35 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PERSIMMON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd"}, + { LLM_TENSOR_OUTPUT_NORM, "output_norm"}, + { LLM_TENSOR_OUTPUT, "output"}, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm"}, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv"}, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output"}, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"}, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"}, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm"}, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down"}, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up"}, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd"}, + }, + }, { LLM_ARCH_MPT, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, { @@ -395,6 +452,38 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, }, }, + { + LLM_ARCH_REFACT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_BLOOM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -917,6 +1006,7 @@ enum e_model { MODEL_1B, MODEL_3B, MODEL_7B, + MODEL_8B, MODEL_13B, MODEL_15B, MODEL_30B, @@ -947,8 +1037,28 @@ struct llama_hparams { float rope_freq_base_train; float rope_freq_scale_train; + float f_clamp_kqv; + float f_max_alibi_bias; + bool operator!=(const llama_hparams & other) const { - return static_cast(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT + if (this->vocab_only != other.vocab_only) return true; + if (this->n_vocab != other.n_vocab) return true; + if (this->n_ctx_train != other.n_ctx_train) return true; + if (this->n_embd != other.n_embd) return true; + if (this->n_head != other.n_head) return true; + if (this->n_head_kv != other.n_head_kv) return true; + if (this->n_layer != other.n_layer) return true; + if (this->n_rot != other.n_rot) return true; + if (this->n_ff != other.n_ff) return true; + + const float EPSILON = 1e-9; + + if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; + if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; + if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; + if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; + + return false; } uint32_t n_gqa() const { @@ -982,6 +1092,10 @@ struct llama_layer { struct ggml_tensor * attn_norm_b; struct ggml_tensor * attn_norm_2; struct ggml_tensor * attn_norm_2_b; + struct ggml_tensor * attn_q_norm; + struct ggml_tensor * attn_q_norm_b; + struct ggml_tensor * attn_k_norm; + struct ggml_tensor * attn_k_norm_b; // attention struct ggml_tensor * wq; @@ -1023,6 +1137,9 @@ struct llama_kv_cell { struct llama_kv_cache { bool has_shift = false; + // Note: The value of head isn't only used to optimize searching + // for a free KV slot. llama_decode_internal also uses it, so it + // cannot be freely changed after a slot has been allocated. uint32_t head = 0; uint32_t size = 0; @@ -1108,6 +1225,8 @@ struct llama_model { struct ggml_tensor * tok_embeddings; struct ggml_tensor * pos_embeddings; + struct ggml_tensor * tok_norm; + struct ggml_tensor * tok_norm_b; struct ggml_tensor * output_norm; struct ggml_tensor * output_norm_b; @@ -1237,7 +1356,11 @@ static bool llama_kv_cache_init( cache.cells.clear(); cache.cells.resize(n_ctx); + // TODO: this should be: + // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead()); + // change it and test that it works cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); + memset(cache.buf.data, 0, cache.buf.size); struct ggml_init_params params; params.mem_size = cache.buf.size; @@ -1280,9 +1403,11 @@ static bool llama_kv_cache_init( // find an empty slot of size "n_tokens" in the cache // updates the cache head +// Note: On success, it's important that cache.head points +// to the first cell of the slot. static bool llama_kv_cache_find_slot( - struct llama_kv_cache & cache, - const struct llama_batch & batch) { + struct llama_kv_cache & cache, + const struct llama_batch & batch) { const uint32_t n_ctx = cache.size; const uint32_t n_tokens = batch.n_tokens; @@ -1295,8 +1420,8 @@ static bool llama_kv_cache_find_slot( while (true) { if (cache.head + n_tokens > n_ctx) { + n_tested += n_ctx - cache.head; cache.head = 0; - n_tested += n_ctx - cache.head; continue; } @@ -1347,29 +1472,46 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); } + + // Searching for a free slot can start here since we know it will be empty. + cache.head = uint32_t(c0); } static void llama_kv_cache_seq_rm( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + uint32_t new_head = cache.size; + + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.cells[i].seq_id.erase(seq_id); if (cache.cells[i].seq_id.empty()) { cache.cells[i].pos = -1; + if (new_head == cache.size) new_head = i; } } } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size) cache.head = new_head; } static void llama_kv_cache_seq_cp( - struct llama_kv_cache & cache, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { + struct llama_kv_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + cache.head = 0; + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.cells[i].seq_id.insert(seq_id_dst); @@ -1378,32 +1520,48 @@ static void llama_kv_cache_seq_cp( } static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { + uint32_t new_head = cache.size; + for (uint32_t i = 0; i < cache.size; ++i) { if (!cache.cells[i].has_seq_id(seq_id)) { cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); + if (new_head == cache.size) new_head = i; } } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size) cache.head = new_head; } static void llama_kv_cache_seq_shift( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + uint32_t new_head = cache.size; + + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.cells[i].pos += delta; if (cache.cells[i].pos < 0) { cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); + if (new_head == cache.size) new_head = i; } else { cache.has_shift = true; cache.cells[i].delta = delta; } } } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.head = new_head != cache.size ? new_head : 0; } // @@ -1607,7 +1765,7 @@ struct llama_model_loader { } } - struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, struct ggml_tensor * meta, ggml_backend backend) { + struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, struct ggml_tensor * meta, ggml_backend_type backend) { if (backend != GGML_BACKEND_CPU) { ggml_set_no_alloc(ctx, true); } @@ -1625,7 +1783,7 @@ struct llama_model_loader { return tensor; } - struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, ggml_backend backend) { + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, ggml_backend_type backend) { struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str()); if (cur == NULL) { @@ -1804,6 +1962,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_1B: return "1B"; case MODEL_3B: return "3B"; case MODEL_7B: return "7B"; + case MODEL_8B: return "8B"; case MODEL_13B: return "13B"; case MODEL_15B: return "15B"; case MODEL_30B: return "30B"; @@ -1916,6 +2075,49 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_PERSIMMON: + { + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + switch (hparams.n_layer) { + case 36: model.type = e_model::MODEL_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_REFACT: + { + GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_1B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_BLOOM: + { + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 30: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + case 4096: model.type = e_model::MODEL_7B; break; + } break; + } + } break; + case LLM_ARCH_MPT: + { + hparams.f_clamp_kqv = 0.0f; + + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_CLAMP_KQV)); + GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS)); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 48: model.type = e_model::MODEL_30B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -1980,6 +2182,7 @@ static void llm_load_vocab( for (int i = 0; i < n_merges; i++) { const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + GGML_ASSERT(codepoints_from_utf8(word).size() > 0); std::string first; std::string second; @@ -2014,6 +2217,7 @@ static void llm_load_vocab( for (uint32_t i = 0; i < n_vocab; i++) { std::string word = gguf_get_arr_str(ctx, token_idx, i); + GGML_ASSERT(codepoints_from_utf8(word).size() > 0); vocab.token_to_id[word] = i; @@ -2022,12 +2226,13 @@ static void llm_load_vocab( token_data.score = scores ? scores[i] : 0.0f; token_data.type = toktypes ? (llama_token_type) toktypes[i] : LLAMA_TOKEN_TYPE_NORMAL; } + GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); } else { - vocab.linefeed_id = llama_tokenize_internal(vocab, "\n", false)[0]; + vocab.linefeed_id = llama_tokenize_internal(vocab, "\u010A", false)[0]; } // special tokens @@ -2057,6 +2262,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); @@ -2150,13 +2357,14 @@ static void llm_load_tensors( const auto tn = LLM_TN(model.arch); switch (model.arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_REFACT: { model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); // output { - ggml_backend backend_norm; - ggml_backend backend_output; + ggml_backend_type backend_norm; + ggml_backend_type backend_output; if (n_gpu_layers > int(n_layer)) { // norm is not performance relevant on its own but keeping it in VRAM reduces data copying @@ -2191,8 +2399,8 @@ static void llm_load_tensors( model.layers.resize(n_layer); for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT - const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT auto & layer = model.layers[i]; @@ -2221,8 +2429,8 @@ static void llm_load_tensors( { model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); { - ggml_backend backend_norm; - ggml_backend backend_output; + ggml_backend_type backend_norm; + ggml_backend_type backend_output; if (n_gpu_layers > int(n_layer)) { // norm is not performance relevant on its own but keeping it in VRAM reduces data copying @@ -2257,8 +2465,8 @@ static void llm_load_tensors( model.layers.resize(n_layer); for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT - const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT auto & layer = model.layers[i]; @@ -2291,8 +2499,8 @@ static void llm_load_tensors( // output { - ggml_backend backend_norm; - ggml_backend backend_output; + ggml_backend_type backend_norm; + ggml_backend_type backend_output; if (n_gpu_layers > int(n_layer)) { // norm is not performance relevant on its own but keeping it in VRAM reduces data copying @@ -2329,8 +2537,8 @@ static void llm_load_tensors( model.layers.resize(n_layer); for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT - const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT auto & layer = model.layers[i]; @@ -2368,8 +2576,8 @@ static void llm_load_tensors( // output { - ggml_backend backend_norm; - ggml_backend backend_output; + ggml_backend_type backend_norm; + ggml_backend_type backend_output; if (n_gpu_layers > int(n_layer)) { // norm is not performance relevant on its own but keeping it in VRAM reduces data copying @@ -2406,8 +2614,8 @@ static void llm_load_tensors( model.layers.resize(n_layer); for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT - const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT auto & layer = model.layers[i]; @@ -2440,103 +2648,313 @@ static void llm_load_tensors( } } } break; - default: - throw std::runtime_error("unknown architecture"); - } - } + case LLM_ARCH_PERSIMMON: + { + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); - ml.done_getting_tensors(); + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; - // print memory requirements - { - // this is the total memory required to run the inference - size_t mem_required = - ctx_size + - mmapped_size - vram_weights; // weights in VRAM not in memory + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 - LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } -#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) - const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); - if (n_gpu_layers > (int) hparams.n_layer) { - LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__); - } + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } -#ifdef GGML_USE_CUBLAS - const int max_backend_supported_layers = hparams.n_layer + 3; - const int max_offloadable_layers = hparams.n_layer + 3; -#elif defined(GGML_USE_CLBLAST) - const int max_backend_supported_layers = hparams.n_layer + 1; - const int max_offloadable_layers = hparams.n_layer + 1; -#endif // GGML_USE_CUBLAS + const uint32_t n_ff = hparams.n_ff; + const int i_gpu_start = n_layer - n_gpu_layers; + model.layers.resize(n_layer); + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; + auto & layer = model.layers[i]; + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); + layer.attn_q_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {64}, backend); + layer.attn_q_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {64}, backend); + layer.attn_k_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {64}, backend); + layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}, backend); + } + } break; + case LLM_ARCH_BLOOM: + { + // TODO: CPU-only for now - LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); - LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, vram_weights / 1024.0 / 1024.0); + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU); + model.tok_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - (void) n_gpu_layers; -#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) - } + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 - // populate `tensors_by_name` - for (int i = 0; i < ml.n_tensors; ++i) { - struct ggml_tensor * cur = ggml_get_tensor(ctx, ml.get_tensor_name(i)); - model.tensors_by_name.emplace_back(ggml_get_name(cur), cur); - } + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } - (void) tensor_split; -#ifdef GGML_USE_CUBLAS - { - ggml_cuda_set_tensor_split(tensor_split); - } -#endif + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL); + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } - if (progress_callback) { - progress_callback(1.0f, progress_callback_user_data); - } + const uint32_t n_ff = hparams.n_ff; - model.mapping = std::move(ml.mapping); + const int i_gpu_start = n_layer - n_gpu_layers; - // loading time will be recalculate after the first eval, so - // we take page faults deferred by mmap() into consideration - model.t_load_us = ggml_time_us() - model.t_start_us; -} + model.layers.resize(n_layer); -static bool llama_model_load( - const std::string & fname, - llama_model & model, - int n_gpu_layers, - int main_gpu, - const float * tensor_split, - bool use_mmap, - bool use_mlock, - bool vocab_only, - llama_progress_callback progress_callback, - void *progress_callback_user_data) { - try { - llama_model_loader ml(fname, use_mmap); + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT - model.hparams.vocab_only = vocab_only; + auto & layer = model.layers[i]; - llm_load_arch (ml, model); - llm_load_hparams(ml, model); - llm_load_vocab (ml, model); + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); - llm_load_print_meta(ml, model); + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split); - if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { - throw std::runtime_error("vocab size mismatch"); - } + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); - if (vocab_only) { - LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); - return true; - } + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); - llm_load_tensors( + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); + + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + + ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + + ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) + + ggml_nbytes(layer.w3) + ggml_nbytes(layer.b3) + + ggml_nbytes(layer.w2) + ggml_nbytes(layer.b2); + } + } + } break; + case LLM_ARCH_MPT: + { + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 + + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + + ggml_nbytes(layer.wqkv) + + ggml_nbytes(layer.wo) + + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.w2) + + ggml_nbytes(layer.w3); + } + } + } break; + default: + throw std::runtime_error("unknown architecture"); + } + } + + ml.done_getting_tensors(); + + // print memory requirements + { + // this is the total memory required to run the inference + size_t mem_required = + ctx_size + + mmapped_size - vram_weights; // weights in VRAM not in memory + + LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); + +#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) + const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + + LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); + if (n_gpu_layers > (int) hparams.n_layer) { + LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__); + } + +#ifdef GGML_USE_CUBLAS + const int max_backend_supported_layers = hparams.n_layer + 3; + const int max_offloadable_layers = hparams.n_layer + 3; +#elif defined(GGML_USE_CLBLAST) + const int max_backend_supported_layers = hparams.n_layer + 1; + const int max_offloadable_layers = hparams.n_layer + 1; +#endif // GGML_USE_CUBLAS + + LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, vram_weights / 1024.0 / 1024.0); +#else + (void) n_gpu_layers; +#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) + } + + // populate `tensors_by_name` + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, ml.get_tensor_name(i)); + model.tensors_by_name.emplace_back(ggml_get_name(cur), cur); + } + + (void) tensor_split; +#ifdef GGML_USE_CUBLAS + { + ggml_cuda_set_tensor_split(tensor_split); + } +#endif + + ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL); + + if (progress_callback) { + progress_callback(1.0f, progress_callback_user_data); + } + + model.mapping = std::move(ml.mapping); + + // loading time will be recalculate after the first eval, so + // we take page faults deferred by mmap() into consideration + model.t_load_us = ggml_time_us() - model.t_start_us; +} + +static bool llama_model_load( + const std::string & fname, + llama_model & model, + int n_gpu_layers, + int main_gpu, + const float * tensor_split, + bool use_mmap, + bool use_mlock, + bool vocab_only, + llama_progress_callback progress_callback, + void *progress_callback_user_data) { + try { + llama_model_loader ml(fname, use_mmap); + + model.hparams.vocab_only = vocab_only; + + llm_load_arch (ml, model); + llm_load_hparams(ml, model); + llm_load_vocab (ml, model); + + llm_load_print_meta(ml, model); + + if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { + throw std::runtime_error("vocab size mismatch"); + } + + if (vocab_only) { + LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); + return true; + } + + llm_load_tensors( ml, model, n_gpu_layers, main_gpu, tensor_split, use_mlock, progress_callback, progress_callback_user_data); @@ -2549,8 +2967,8 @@ static bool llama_model_load( } static struct ggml_cgraph * llm_build_llama( - llama_context & lctx, - const llama_batch & batch) { + llama_context & lctx, + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; @@ -2588,11 +3006,9 @@ static struct ggml_cgraph * llm_build_llama( struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; - params.no_alloc = true; - struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph * gf = ggml_new_graph(ctx0); @@ -2976,11 +3392,9 @@ static struct ggml_cgraph * llm_build_baichaun( struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; - params.no_alloc = true; - struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph * gf = ggml_new_graph(ctx0); @@ -3343,7 +3757,7 @@ static struct ggml_cgraph * llm_build_baichaun( return gf; } -static struct ggml_cgraph * llm_build_falcon( +static struct ggml_cgraph * llm_build_refact( llama_context & lctx, const llama_batch & batch) { const auto & model = lctx.model; @@ -3362,11 +3776,7 @@ static struct ggml_cgraph * llm_build_falcon( const int64_t n_embd_head = hparams.n_embd_head(); const int64_t n_embd_gqa = hparams.n_embd_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_rot); - - const float freq_base = cparams.rope_freq_base; - const float freq_scale = cparams.rope_freq_scale; - const float norm_eps = hparams.f_norm_eps; + const float norm_rms_eps = hparams.f_norm_rms_eps; const int n_gpu_layers = model.n_gpu_layers; @@ -3374,21 +3784,16 @@ static struct ggml_cgraph * llm_build_falcon( const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; - const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; - - //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n", - // kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift); + // printf("n_kv = %d\n", n_kv); auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; - params.no_alloc = true; - struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph * gf = ggml_new_graph(ctx0); @@ -3445,7 +3850,7 @@ static struct ggml_cgraph * llm_build_falcon( ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); } // KQ_mask (mask for 1 head, it will be broadcasted to all heads) @@ -3471,47 +3876,8 @@ static struct ggml_cgraph * llm_build_falcon( } } - // KQ_pos - contains the positions - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - offload_func_kq(KQ_pos); - ggml_set_name(KQ_pos, "KQ_pos"); - ggml_allocr_alloc(lctx.alloc, KQ_pos); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) KQ_pos->data; - for (int i = 0; i < n_tokens; ++i) { - data[i] = batch.pos[i]; - } - } - - // shift the entire K-cache if needed - if (do_rope_shift) { - struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); - offload_func_kq(K_shift); - ggml_set_name(K_shift, "K_shift"); - ggml_allocr_alloc(lctx.alloc, K_shift); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) K_shift->data; - for (int i = 0; i < n_ctx; ++i) { - data[i] = kv_self.cells[i].delta; - } - } - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * tmp = - ggml_rope_custom_inplace(ctx0, - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_head_kv, n_ctx, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), - K_shift, n_embd_head, 2, 0, freq_base, freq_scale); - offload_func_kq(tmp); - ggml_build_forward_expand(gf, tmp); - } - } - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * attn_norm; + ggml_format_name(inpL, "layer_inp_%d", il); offload_func_t offload_func = llama_nop; @@ -3521,80 +3887,49 @@ static struct ggml_cgraph * llm_build_falcon( } #endif // GGML_USE_CUBLAS - // self-attention - // TODO: refactor into common function (shared with LLaMA) + struct ggml_tensor * inpSA = inpL; + + // norm { - attn_norm = ggml_norm(ctx0, inpL, norm_eps); - offload_func(attn_norm); + cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps); + offload_func(cur); + ggml_set_name(cur, "rms_norm_0"); - attn_norm = ggml_add(ctx0, - ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm), - model.layers[il].attn_norm_b); - offload_func(attn_norm->src[0]); - offload_func(attn_norm); + // cur = cur*attn_norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); + offload_func(cur); + ggml_set_name(cur, "attention_norm_0"); + } - if (model.layers[il].attn_norm_2) { // Falcon-40B - cur = ggml_norm(ctx0, inpL, norm_eps); - offload_func(cur); - - cur = ggml_add(ctx0, - ggml_mul(ctx0, cur, model.layers[il].attn_norm_2), - model.layers[il].attn_norm_2_b); - offload_func(cur->src[0]); - offload_func(cur); - } else { // Falcon 7B - cur = attn_norm; - } - - // compute QKV - - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); - offload_func_kq(cur); - - // Note that the strides for Kcur, Vcur are set up so that the - // resulting views are misaligned with the tensor's storage - // (by applying the K/V offset we shift the tensor's original - // view to stick out behind the viewed QKV tensor's allocated - // memory, so to say). This is ok because no actual accesses - // happen to that out-of-range memory, but it can require some - // trickery when trying to accurately dump these views for - // debugging. - - const size_t wsize = ggml_type_size(cur->type); + // self-attention + { + // compute Q and K + struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + offload_func_kq(tmpk); + ggml_set_name(tmpk, "tmpk"); - // TODO: these 2 ggml_conts are technically not needed, but we add them until CUDA support for - // non-contiguous views is added for the rope operator - struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d( - ctx0, cur, n_embd_head, n_head, n_tokens, - wsize * n_embd_head, - wsize * n_embd_head * (n_head + 2 * n_head_kv), - 0)); + struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); offload_func_kq(tmpq); + ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_3d( - ctx0, cur, n_embd_head, n_head_kv, n_tokens, - wsize * n_embd_head, - wsize * n_embd_head * (n_head + 2 * n_head_kv), - wsize * n_embd_head * n_head)); - offload_func_kq(tmpk); - - struct ggml_tensor * tmpv = ggml_view_3d( - ctx0, cur, n_embd_head, n_head_kv, n_tokens, - wsize * n_embd_head, - wsize * n_embd_head * (n_head + 2 * n_head_kv), - wsize * n_embd_head * (n_head + n_head_kv)); - offload_func_v(tmpv); + struct ggml_tensor * Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens); + offload_func_kq(Kcur); + ggml_set_name(Kcur, "Kcur"); - // using mode = 2 for neox mode - struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, tmpq, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale); + struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens); offload_func_kq(Qcur); - struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, tmpk, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale); - offload_func_kq(Kcur); + ggml_set_name(Qcur, "Qcur"); + // store key and value to memory { - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); + // compute the transposed [n_tokens, n_embd] V matrix + + struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + offload_func_v(tmpv); + ggml_set_name(tmpv, "tmpv"); + + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); offload_func_v(Vcur); - offload_func_v(Vcur->src[0]->src[0]); ggml_set_name(Vcur, "Vcur"); struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); @@ -3605,6 +3940,7 @@ static struct ggml_cgraph * llm_build_falcon( ( n_ctx)*ggml_element_size(kv_self.v), (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); offload_func_v(v); + ggml_set_name(v, "v"); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); @@ -3623,22 +3959,31 @@ static struct ggml_cgraph * llm_build_falcon( offload_func_kq(K); ggml_set_name(K, "K"); + // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); offload_func_kq(KQ); ggml_set_name(KQ, "KQ"); + // KQ_scaled = KQ / sqrt(n_embd_head) + // KQ_scaled shape [n_kv, n_tokens, n_head, 1] struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8); + ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); + + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); + // KQ = soft_max(KQ_masked) struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); + // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, n_kv, n_embd_head, n_head_kv, @@ -3648,42 +3993,85 @@ static struct ggml_cgraph * llm_build_falcon( offload_func_v(V); ggml_set_name(V, "V"); +#if 1 struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); offload_func_v(KQV); ggml_set_name(KQV, "KQV"); +#else + // make V contiguous in memory to speed up the matmul, however we waste time on the copy + // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation + // is there a better way? + struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head)); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); +#endif + // KQV_merged = KQV.permute(0, 2, 1, 3) struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); offload_func_v(KQV_merged); ggml_set_name(KQV_merged, "KQV_merged"); + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); - cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + // projection (no bias) + cur = ggml_mul_mat(ctx0, + model.layers[il].wo, + cur); offload_func(cur); ggml_set_name(cur, "result_wo"); } - struct ggml_tensor * attn_out = cur; + struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); + offload_func(inpFF); + ggml_set_name(inpFF, "inpFF"); - // feed forward + // feed-forward network { - struct ggml_tensor * inpFF = attn_norm; + // norm + { + cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps); + offload_func(cur); + ggml_set_name(cur, "rms_norm_1"); - cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); + // cur = cur*ffn_norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); + offload_func(cur); + ggml_set_name(cur, "ffn_norm"); + } + + struct ggml_tensor * tmp = ggml_mul_mat(ctx0, + model.layers[il].w3, + cur); + offload_func(tmp); + ggml_set_name(tmp, "result_w3"); + + cur = ggml_mul_mat(ctx0, + model.layers[il].w1, + cur); offload_func(cur); + ggml_set_name(cur, "result_w1"); - cur = ggml_gelu(ctx0, cur); + // SILU activation + cur = ggml_silu(ctx0, cur); offload_func(cur); - cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + ggml_set_name(cur, "silu"); + + cur = ggml_mul(ctx0, cur, tmp); + offload_func(cur); + ggml_set_name(cur, "silu_x_result_w3"); + + cur = ggml_mul_mat(ctx0, + model.layers[il].w2, + cur); offload_func(cur); + ggml_set_name(cur, "result_w2"); } - cur = ggml_add(ctx0, cur, attn_out); - offload_func(cur); - cur = ggml_add(ctx0, cur, inpL); + cur = ggml_add(ctx0, cur, inpFF); offload_func(cur); + ggml_set_name(cur, "inpFF_+_result_w2"); // input for next layer inpL = cur; @@ -3693,15 +4081,17 @@ static struct ggml_cgraph * llm_build_falcon( // norm { - cur = ggml_norm(ctx0, cur, norm_eps); + cur = ggml_rms_norm(ctx0, cur, norm_rms_eps); offload_func_nr(cur); + ggml_set_name(cur, "rms_norm_2"); - cur = ggml_add(ctx0, - ggml_mul(ctx0, cur, model.output_norm), - model.output_norm_b); + // cur = cur*norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.output_norm); + // offload_func_nr(cur); // TODO CPU + GPU mirrored backend ggml_set_name(cur, "result_norm"); } + // lm_head cur = ggml_mul_mat(ctx0, model.output, cur); ggml_set_name(cur, "result_output"); @@ -3712,7 +4102,7 @@ static struct ggml_cgraph * llm_build_falcon( return gf; } -static struct ggml_cgraph * llm_build_starcoder( +static struct ggml_cgraph * llm_build_falcon( llama_context & lctx, const llama_batch & batch) { const auto & model = lctx.model; @@ -3733,29 +4123,34 @@ static struct ggml_cgraph * llm_build_starcoder( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float norm_eps = hparams.f_norm_eps; + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; + const float norm_eps = hparams.f_norm_eps; + + const int n_gpu_layers = model.n_gpu_layers; const int32_t n_tokens = batch.n_tokens; const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; + + //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n", + // kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift); + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; - params.no_alloc = true; - struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_tensor * cur; - struct ggml_tensor * token; - struct ggml_tensor * position; struct ggml_tensor * inpL; if (batch.token) { @@ -3767,33 +4162,1289 @@ static struct ggml_cgraph * llm_build_starcoder( } ggml_set_name(inp_tokens, "inp_tokens"); - token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); } else { #ifdef GGML_USE_MPI GGML_ASSERT(false && "not implemented"); #endif - token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - ggml_allocr_alloc(lctx.alloc, token); + ggml_allocr_alloc(lctx.alloc, inpL); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(token->data, batch.embd, n_tokens * n_embd * ggml_element_size(token)); + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); } } - { - // Compute position embeddings. - struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - ggml_allocr_alloc(lctx.alloc, inp_positions); + const int i_gpu_start = n_layer - n_gpu_layers; + (void) i_gpu_start; + + // offload functions set the tensor output backend to GPU + // tensors are GPU-accelerated if any input or the output has been offloaded + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + } + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); + ggml_set_name(KQ_pos, "KQ_pos"); + ggml_allocr_alloc(lctx.alloc, KQ_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < n_tokens; ++i) { + data[i] = batch.pos[i]; + } + } + + // shift the entire K-cache if needed + if (do_rope_shift) { + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); + ggml_set_name(K_shift, "K_shift"); + ggml_allocr_alloc(lctx.alloc, K_shift); if (!ggml_allocr_is_measure(lctx.alloc)) { - for (int i = 0; i < n_tokens; ++i) { - ((int32_t *) inp_positions->data)[i] = batch.pos[i]; + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * tmp = + ggml_rope_custom_inplace(ctx0, + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_head_kv, n_ctx, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), + K_shift, n_embd_head, 2, 0, freq_base, freq_scale); + offload_func_kq(tmp); + ggml_build_forward_expand(gf, tmp); + } + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * attn_norm; + + offload_func_t offload_func = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (il >= i_gpu_start) { + offload_func = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + // self-attention + // TODO: refactor into common function (shared with LLaMA) + { + attn_norm = ggml_norm(ctx0, inpL, norm_eps); + offload_func(attn_norm); + + attn_norm = ggml_add(ctx0, + ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm), + model.layers[il].attn_norm_b); + offload_func(attn_norm->src[0]); + offload_func(attn_norm); + + if (model.layers[il].attn_norm_2) { // Falcon-40B + cur = ggml_norm(ctx0, inpL, norm_eps); + offload_func(cur); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.layers[il].attn_norm_2), + model.layers[il].attn_norm_2_b); + offload_func(cur->src[0]); + offload_func(cur); + } else { // Falcon 7B + cur = attn_norm; + } + + // compute QKV + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + offload_func_kq(cur); + + // Note that the strides for Kcur, Vcur are set up so that the + // resulting views are misaligned with the tensor's storage + // (by applying the K/V offset we shift the tensor's original + // view to stick out behind the viewed QKV tensor's allocated + // memory, so to say). This is ok because no actual accesses + // happen to that out-of-range memory, but it can require some + // trickery when trying to accurately dump these views for + // debugging. + + const size_t wsize = ggml_type_size(cur->type); + + // TODO: these 2 ggml_conts are technically not needed, but we add them until CUDA support for + // non-contiguous views is added for the rope operator + struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d( + ctx0, cur, n_embd_head, n_head, n_tokens, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + 0)); + offload_func_kq(tmpq); + + struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, n_tokens, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + wsize * n_embd_head * n_head)); + offload_func_kq(tmpk); + + struct ggml_tensor * tmpv = ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, n_tokens, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + wsize * n_embd_head * (n_head + n_head_kv)); + offload_func_v(tmpv); + + // using mode = 2 for neox mode + struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, tmpq, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale); + offload_func_kq(Qcur); + struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, tmpk, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale); + offload_func_kq(Kcur); + + { + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); + offload_func_v(Vcur); + offload_func_v(Vcur->src[0]->src[0]); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); + offload_func_kq(k); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + offload_func_v(v); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + offload_func_kq(Q); + ggml_set_name(Q, "Q"); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_kv, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + offload_func_kq(K); + ggml_set_name(K, "K"); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + offload_func_kq(KQ); + ggml_set_name(KQ, "KQ"); + + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); + offload_func_kq(KQ_scaled); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); + offload_func_kq(KQ_masked); + ggml_set_name(KQ_masked, "KQ_masked"); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + offload_func_v(KQ_soft_max); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + offload_func_v(V); + ggml_set_name(V, "V"); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + offload_func_v(KQV); + ggml_set_name(KQV, "KQV"); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + offload_func_v(KQV_merged); + ggml_set_name(KQV_merged, "KQV_merged"); + + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); + offload_func_v(cur); + ggml_set_name(cur, "KQV_merged_contiguous"); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + offload_func(cur); + ggml_set_name(cur, "result_wo"); + } + + struct ggml_tensor * attn_out = cur; + + // feed forward + { + struct ggml_tensor * inpFF = attn_norm; + + cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); + offload_func(cur); + + cur = ggml_gelu(ctx0, cur); + offload_func(cur); + cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + offload_func(cur); + } + + cur = ggml_add(ctx0, cur, attn_out); + offload_func(cur); + cur = ggml_add(ctx0, cur, inpL); + offload_func(cur); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + // norm + { + cur = ggml_norm(ctx0, cur, norm_eps); + offload_func_nr(cur); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.output_norm), + model.output_norm_b); + ggml_set_name(cur, "result_norm"); + } + + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * llm_build_starcoder( + llama_context & lctx, + const llama_batch & batch) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_rot); + + const float norm_eps = hparams.f_norm_eps; + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * token; + struct ggml_tensor * position; + struct ggml_tensor * inpL; + + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + + token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + + ggml_allocr_alloc(lctx.alloc, token); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(token->data, batch.embd, n_tokens * n_embd * ggml_element_size(token)); + } + } + + { + // Compute position embeddings. + struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_allocr_alloc(lctx.alloc, inp_positions); + if (!ggml_allocr_is_measure(lctx.alloc)) { + for (int i = 0; i < n_tokens; ++i) { + ((int32_t *) inp_positions->data)[i] = batch.pos[i]; } } ggml_set_name(inp_positions, "inp_positions"); - position = ggml_get_rows(ctx0, model.pos_embeddings, inp_positions); + position = ggml_get_rows(ctx0, model.pos_embeddings, inp_positions); + } + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + } + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + inpL = ggml_add(ctx0, token, position); + ggml_set_name(inpL, "inpL"); + + for (int il = 0; il < n_layer; ++il) { + { + // Norm + cur = ggml_norm(ctx0, inpL, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b); + } + + { + // Self Attention + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv); + + struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd); + struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd); + struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa)); + + struct ggml_tensor * Qcur = tmpq; + struct ggml_tensor * Kcur = tmpk; + + { + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)), + 0, 2, 1, 3); + ggml_set_name(Q, "Q"); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_kv, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + ggml_set_name(K, "K"); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + ggml_set_name(KQ, "KQ"); + + // KQ_scaled = KQ / sqrt(n_embd_head) + // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); + ggml_set_name(KQ_masked, "KQ_masked"); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + // split cached V into n_head heads + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + ggml_set_name(V, "V"); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + ggml_set_name(KQV, "KQV"); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + ggml_set_name(KQV_merged, "KQV_merged"); + + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); + ggml_set_name(cur, "KQV_merged_contiguous"); + } + + // Projection + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo); + + // Add the input + cur = ggml_add(ctx0, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + // FF + { + // Norm + { + cur = ggml_norm(ctx0, inpFF, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b); + } + + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // Projection + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2); + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + // Output Norm + { + cur = ggml_norm(ctx0, inpL, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b); + } + ggml_set_name(cur, "result_norm"); + + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * llm_build_persimmon( + llama_context & lctx, + const llama_batch & batch) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const auto & cparams = lctx.cparams; + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_head = hparams.n_head; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + const size_t n_rot = n_embd_head / 2; + + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; + const float norm_eps = hparams.f_norm_eps; + + const int n_gpu_layers = model.n_gpu_layers; + + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; + + auto & buf_compute = lctx.buf_compute; + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_allocr_alloc(lctx.alloc, inpL); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); + } + } + const int i_gpu_start = n_layer - n_gpu_layers; + (void) i_gpu_start; + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); + } + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); + ggml_set_name(KQ_pos, "KQ_pos"); + ggml_allocr_alloc(lctx.alloc, KQ_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < n_tokens; ++i) { + data[i] = batch.pos[i]; + } + } + if (do_rope_shift) { + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); + ggml_set_name(K_shift, "K_shift"); + ggml_allocr_alloc(lctx.alloc, K_shift); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * tmp = + // we rotate only the first n_rot dimensions. + ggml_rope_custom_inplace(ctx0, + ggml_view_3d(ctx0, kv_self.k, + n_rot, n_head, n_ctx, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*(n_embd_head*n_ctx*il) + ), + K_shift, n_rot, 2, 0, freq_base, freq_scale); + offload_func_kq(tmp); + ggml_build_forward_expand(gf, tmp); + } + } + for (int il=0; il < n_layer; ++il) { + struct ggml_tensor * residual = inpL; + offload_func_t offload_func = llama_nop; + { + cur = ggml_norm(ctx0, inpL, norm_eps); + offload_func(cur); + cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); + offload_func(cur); + cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b); + offload_func(cur); + ggml_format_name(cur, "input_layernorm_%d", il); + } + // self attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + offload_func_kq(cur); + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + offload_func_kq(cur); + + // split qkv + GGML_ASSERT(n_head_kv == n_head); + ggml_set_name(cur, format("qkv_%d", il).c_str()); + struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens); + offload_func_kq(tmpqkv); + struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2)); + offload_func_kq(tmpqkv_perm); + ggml_format_name(tmpqkv_perm, "tmpqkv_perm_%d", il); + struct ggml_tensor * tmpq = ggml_view_3d( + ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + 0 + ); + offload_func_kq(tmpq); + struct ggml_tensor * tmpk = ggml_view_3d( + ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens + ); + offload_func_kq(tmpk); + // Q/K Layernorm + tmpq = ggml_norm(ctx0, tmpq, norm_eps); + offload_func_kq(tmpq); + tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm); + offload_func_kq(tmpq); + tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b); + offload_func_kq(tmpq); + + tmpk = ggml_norm(ctx0, tmpk, norm_eps); + offload_func_v(tmpk); + tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm); + offload_func_v(tmpk); + tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b); + offload_func_v(tmpk); + + // RoPE the first n_rot of q/k, pass the other half, and concat. + struct ggml_tensor * qrot = ggml_view_3d( + ctx0, tmpq, n_rot, n_head, n_tokens, + ggml_element_size(tmpq) * n_embd_head, + ggml_element_size(tmpq) * n_embd_head * n_head, + 0 + ); + offload_func_kq(qrot); + ggml_format_name(qrot, "qrot_%d", il); + struct ggml_tensor * krot = ggml_view_3d( + ctx0, tmpk, n_rot, n_head, n_tokens, + ggml_element_size(tmpk) * n_embd_head, + ggml_element_size(tmpk) * n_embd_head * n_head, + 0 + ); + offload_func_kq(krot); + ggml_format_name(krot, "krot_%d", il); + + // get the second half of tmpq, e.g tmpq[n_rot:, :, :] + struct ggml_tensor * qpass = ggml_view_3d( + ctx0, tmpq, n_rot, n_head, n_tokens, + ggml_element_size(tmpq) * n_embd_head, + ggml_element_size(tmpq) * n_embd_head * n_head, + ggml_element_size(tmpq) * n_rot + ); + offload_func_kq(qpass); + ggml_format_name(qpass, "qpass_%d", il); + struct ggml_tensor * kpass = ggml_view_3d( + ctx0, tmpk, n_rot, n_head, n_tokens, + ggml_element_size(tmpk) * n_embd_head, + ggml_element_size(tmpk) * n_embd_head * n_head, + ggml_element_size(tmpk) * n_rot + ); + offload_func_kq(kpass); + ggml_format_name(kpass, "kpass_%d", il); + + struct ggml_tensor * qrotated = ggml_rope_custom( + ctx0, qrot, KQ_pos, n_rot, 2, 0, freq_base, freq_scale + ); + offload_func_kq(qrotated); + struct ggml_tensor * krotated = ggml_rope_custom( + ctx0, krot, KQ_pos, n_rot, 2, 0, freq_base, freq_scale + ); + offload_func_kq(krotated); + // ggml currently only supports concatenation on dim=2 + // so we need to permute qrot, qpass, concat, then permute back. + qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3)); + offload_func_kq(qrotated); + krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3)); + offload_func_kq(krotated); + + qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3)); + offload_func_kq(qpass); + kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3)); + offload_func_kq(kpass); + + struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass); + offload_func_kq(Qcur); + struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass); + offload_func_kq(Kcur); + + struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3)); + offload_func_kq(Q); + + Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3)); + offload_func_kq(Kcur); + { + struct ggml_tensor * tmpv = ggml_view_3d( + ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2 + ); + offload_func_v(tmpv); + // store K, V in cache + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); + offload_func_v(Vcur); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d( + ctx0, kv_self.k, n_tokens*n_embd_gqa, + (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head) + ); + offload_func_kq(k); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + offload_func_v(v); + ggml_set_name(v, "v"); + + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_kv, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + + offload_func_kq(K); + ggml_format_name(K, "K_%d", il); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + offload_func_kq(KQ); + ggml_set_name(KQ, "KQ"); + + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); + offload_func_kq(KQ_scaled); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); + offload_func_kq(KQ_masked); + ggml_set_name(KQ_masked, "KQ_masked"); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + offload_func_kq(KQ_soft_max); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + offload_func_v(V); + ggml_set_name(V, "V"); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + offload_func_v(KQV); + ggml_set_name(KQV, "KQV"); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + offload_func_v(KQV_merged); + ggml_set_name(KQV_merged, "KQV_merged"); + + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); + offload_func_v(cur); + ggml_set_name(cur, "KQV_merged_contiguous"); + + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + offload_func(cur); + cur = ggml_add(ctx0, cur, model.layers[il].bo); + offload_func(cur); + ggml_set_name(cur, "result_wo"); + } + + struct ggml_tensor * inpFF = ggml_add(ctx0, residual, cur); + offload_func(inpFF); + ggml_set_name(inpFF, "inpFF"); + { + // MLP + { + // Norm + cur = ggml_norm(ctx0, inpFF, norm_eps); + offload_func(cur); + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.layers[il].ffn_norm), + model.layers[il].ffn_norm_b + ); + ggml_set_name(cur, "ffn_norm"); + offload_func(cur); + } + cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur); + offload_func(cur); + + cur = ggml_add(ctx0, cur, model.layers[il].b3); + offload_func(cur); + ggml_set_name(cur, "result_ffn_up"); + + cur = ggml_sqr(ctx0, ggml_relu(ctx0, cur)); + ggml_set_name(cur, "result_ffn_act"); + offload_func(cur); + offload_func(cur->src[0]); + + cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + offload_func(cur); + cur = ggml_add(ctx0, + cur, + model.layers[il].b2); + offload_func(cur); + ggml_set_name(cur, "outFF"); + } + cur = ggml_add(ctx0, cur, inpFF); + offload_func(cur); + ggml_set_name(cur, "inpFF_+_outFF"); + inpL = cur; + } + cur = inpL; + { + cur = ggml_norm(ctx0, cur, norm_eps); + offload_func_nr(cur); + cur = ggml_mul(ctx0, cur, model.output_norm); + offload_func_nr(cur); + + cur = ggml_add(ctx0, cur, model.output_norm_b); + // offload_func_nr(cur); + + ggml_set_name(cur, "result_norm"); + } + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + ggml_build_forward_expand(gf, cur); + ggml_free(ctx0); + return gf; +} + +static struct ggml_cgraph * llm_build_bloom( + llama_context & lctx, + const llama_batch & batch) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_rot); + + const float norm_eps = hparams.f_norm_eps; + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ false, + }; + + params.no_alloc = true; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * token; + struct ggml_tensor * inpL; + + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + + token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + + ggml_allocr_alloc(lctx.alloc, token); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(token->data, batch.embd, n_tokens * n_embd * ggml_element_size(token)); + } + } + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + } + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + // norm + { + inpL = ggml_norm(ctx0, token, norm_eps); + inpL = ggml_add(ctx0, ggml_mul(ctx0, inpL, model.tok_norm), model.tok_norm_b); + } + + ggml_set_name(inpL, "inpL"); + + for (int il = 0; il < n_layer; ++il) { + { + // Norm + cur = ggml_norm(ctx0, inpL, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b); + } + + { + // Self Attention + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv); + + struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd); + struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd); + struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa)); + + struct ggml_tensor * Qcur = tmpq; + struct ggml_tensor * Kcur = tmpk; + + // store key and value to memory + { + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)), + 0, 2, 1, 3); + ggml_set_name(Q, "Q"); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_kv, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + ggml_set_name(K, "K"); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + ggml_set_name(KQ, "KQ"); + + // KQ_scaled = KQ / sqrt(n_embd_head) + // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ kv_head, n_head, 8); + ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); + ggml_set_name(KQ_masked, "KQ_masked"); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + // split cached V into n_head heads + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + ggml_set_name(V, "V"); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + ggml_set_name(KQV, "KQV"); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + ggml_set_name(KQV_merged, "KQV_merged"); + + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); + ggml_set_name(cur, "KQV_merged_contiguous"); + } + + // Projection + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo); + + // Add the input + cur = ggml_add(ctx0, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + // FF + { + // Norm + { + cur = ggml_norm(ctx0, inpFF, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b); + } + + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // Projection + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2); + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + // Output Norm + { + cur = ggml_norm(ctx0, inpL, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b); + } + ggml_set_name(cur, "result_norm"); + + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * llm_build_mpt( + llama_context & lctx, + const llama_batch & batch) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + const float norm_eps = hparams.f_norm_eps; + const float clamp_kqv = hparams.f_clamp_kqv; + const float max_alibi_bias = hparams.f_max_alibi_bias; + + const int n_gpu_layers = model.n_gpu_layers; + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ false, + }; + + params.no_alloc = true; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + //int warmup = 0; + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); + //warmup = ((uint32_t*) inp_tokens->data)[0] == 0; + } + + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inpL); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); + } + } + + const int i_gpu_start = n_layer - n_gpu_layers; + (void) i_gpu_start; + + // offload functions set the tensor output backend to GPU + // tensors are GPU-accelerated if any input or the output has been offloaded + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers_no_alloc; } +#endif // GGML_USE_CUBLAS // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); @@ -3805,6 +5456,7 @@ static struct ggml_cgraph * llm_build_starcoder( // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); ggml_set_name(KQ_mask, "KQ_mask"); ggml_allocr_alloc(lctx.alloc, KQ_mask); if (!ggml_allocr_is_measure(lctx.alloc)) { @@ -3825,48 +5477,87 @@ static struct ggml_cgraph * llm_build_starcoder( } } - inpL = ggml_add(ctx0, token, position); - ggml_set_name(inpL, "inpL"); - for (int il = 0; il < n_layer; ++il) { - { - // Norm - cur = ggml_norm(ctx0, inpL, norm_eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b); + struct ggml_tensor * attn_norm; + + offload_func_t offload_func = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (il >= i_gpu_start) { + offload_func = ggml_cuda_assign_buffers_no_alloc; } +#endif // GGML_USE_CUBLAS + // self-attention + // TODO: refactor into common function (shared with LLaMA) { - // Self Attention - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv); + attn_norm = ggml_norm(ctx0, inpL, norm_eps); + offload_func(attn_norm); - struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd); - struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd); - struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa)); + attn_norm = ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm); + offload_func(attn_norm); - struct ggml_tensor * Qcur = tmpq; - struct ggml_tensor * Kcur = tmpk; + if (1) { + cur = attn_norm; + } + + // compute QKV + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + offload_func_kq(cur); + + if (clamp_kqv > 0.0f) { + cur = ggml_clamp(ctx0, cur, -clamp_kqv, clamp_kqv); + offload_func_kq(cur); + } + + const size_t wsize = ggml_type_size(cur->type); + + struct ggml_tensor * Qcur = ggml_view_3d( + ctx0, cur, n_embd_head, n_head, n_tokens, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + 0); + offload_func_kq(Qcur); + + struct ggml_tensor * Kcur = ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, n_tokens, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + wsize * n_embd_head * n_head); + offload_func_kq(Kcur); + + struct ggml_tensor * tmpv = ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, n_tokens, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + wsize * n_embd_head * (n_head + n_head_kv)); + offload_func_kq(Kcur); + + ggml_set_name(Qcur, "Qcur"); + ggml_set_name(Kcur, "Kcur"); { struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); + offload_func_v(Vcur); + offload_func_v(Vcur->src[0]->src[0]); ggml_set_name(Vcur, "Vcur"); struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); + offload_func_kq(k); ggml_set_name(k, "k"); struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + offload_func_v(v); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); } - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)), - 0, 2, 1, 3); + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + offload_func_kq(Q); ggml_set_name(Q, "Q"); struct ggml_tensor * K = @@ -3875,85 +5566,105 @@ static struct ggml_cgraph * llm_build_starcoder( ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_head, ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + offload_func_kq(K); ggml_set_name(K, "K"); - // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + offload_func_kq(KQ); ggml_set_name(KQ, "KQ"); - // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); + offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); - // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); + // TODO: replace with ggml_add() + struct ggml_tensor * KQ_scaled_alibi = + ggml_alibi(ctx0, KQ_scaled, 0, n_head, max_alibi_bias); + offload_func_kq(KQ_scaled_alibi); + ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); + + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); + offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); - // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); - // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, n_kv, n_embd_head, n_head_kv, ggml_element_size(kv_self.v)*n_ctx, ggml_element_size(kv_self.v)*n_ctx*n_embd_head, ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + offload_func_v(V); ggml_set_name(V, "V"); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + offload_func_v(KQV); ggml_set_name(KQV, "KQV"); - // KQV_merged = KQV.permute(0, 2, 1, 3) struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + offload_func_v(KQV_merged); ggml_set_name(KQV_merged, "KQV_merged"); - // cur = KQV_merged.contiguous().view(n_embd, n_tokens) cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); + offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); - } - // Projection - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo); + cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + offload_func(cur); + ggml_set_name(cur, "result_wo"); + } // Add the input cur = ggml_add(ctx0, cur, inpL); + offload_func(cur); - struct ggml_tensor * inpFF = cur; + struct ggml_tensor * attn_out = cur; - // FF + // feed forward { // Norm { - cur = ggml_norm(ctx0, inpFF, norm_eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b); + cur = ggml_norm(ctx0, attn_out, norm_eps); + offload_func(cur); + + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); + offload_func(cur); } - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3); + cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur); + offload_func(cur); - // GELU activation cur = ggml_gelu(ctx0, cur); - - // Projection - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2); + offload_func(cur); + cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + offload_func(cur); } - inpL = ggml_add(ctx0, cur, inpFF); + cur = ggml_add(ctx0, cur, attn_out); + offload_func(cur); + // input for next layer + inpL = cur; } - // Output Norm + cur = inpL; + + // norm { - cur = ggml_norm(ctx0, inpL, norm_eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b); + cur = ggml_norm(ctx0, cur, norm_eps); + offload_func_nr(cur); + + cur = ggml_mul(ctx0, cur, model.output_norm); + ggml_set_name(cur, "result_norm"); } - ggml_set_name(cur, "result_norm"); cur = ggml_mul_mat(ctx0, model.output, cur); ggml_set_name(cur, "result_output"); ggml_build_forward_expand(gf, cur); + ggml_free(ctx0); return gf; @@ -3983,6 +5694,22 @@ static struct ggml_cgraph * llama_build_graph( { result = llm_build_starcoder(lctx, batch); } break; + case LLM_ARCH_PERSIMMON: + { + result = llm_build_persimmon(lctx, batch); + } break; + case LLM_ARCH_REFACT: + { + result = llm_build_refact(lctx, batch); + } break; + case LLM_ARCH_BLOOM: + { + result = llm_build_bloom(lctx, batch); + } break; + case LLM_ARCH_MPT: + { + result = llm_build_mpt(lctx, batch); + } break; default: GGML_ASSERT(false); } @@ -3994,7 +5721,6 @@ static struct ggml_cgraph * llama_build_graph( // // - lctx: llama context // - batch: batch to evaluate -// - n_threads: number of threads to use // // return 0 on success // return positive int on warning @@ -4061,10 +5787,6 @@ static int llama_decode_internal( batch.seq_id = seq_id.data(); } - // we always start to search for a free slot from the start of the cache - // TODO: better strategies can be implemented - kv_self.head = 0; - if (!llama_kv_cache_find_slot(kv_self, batch)) { return 1; } @@ -4116,7 +5838,9 @@ static int llama_decode_internal( // If all tensors can be run on the GPU then using more than 1 thread is detrimental. const bool full_offload_supported = model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_BAICHUAN || - model.arch == LLM_ARCH_FALCON; + model.arch == LLM_ARCH_FALCON || + model.arch == LLM_ARCH_REFACT || + model.arch == LLM_ARCH_MPT; const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3; if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) { n_threads = 1; @@ -4149,8 +5873,12 @@ static int llama_decode_internal( #endif // update the kv ring buffer - lctx.kv_self.head += n_tokens; lctx.kv_self.has_shift = false; + lctx.kv_self.head += n_tokens; + // Ensure kv cache head points to a valid index. + if (lctx.kv_self.head >= lctx.kv_self.size) { + lctx.kv_self.head = 0; + } #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) @@ -4236,18 +5964,41 @@ static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE; } -static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { +static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED; +} + +static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { GGML_ASSERT(llama_is_byte_token(vocab, id)); const auto& token_data = vocab.id_to_token.at(id); - auto buf = token_data.text.substr(3, 2); - return strtol(buf.c_str(), NULL, 16); + switch (llama_vocab_get_type(vocab)) { + case LLAMA_VOCAB_TYPE_SPM: { + auto buf = token_data.text.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); + } + case LLAMA_VOCAB_TYPE_BPE: { + GGML_ASSERT(false); + return unicode_to_bytes_bpe(token_data.text); + } + default: + GGML_ASSERT(false); + } } static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { - char buf[7]; - int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); - GGML_ASSERT(0 <= result && result < 7); - return vocab.token_to_id.at(buf); + switch (llama_vocab_get_type(vocab)) { + case LLAMA_VOCAB_TYPE_SPM: { + char buf[7]; + int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); + GGML_ASSERT(0 <= result && result < 7); + return vocab.token_to_id.at(buf); + } + case LLAMA_VOCAB_TYPE_BPE: { + return vocab.token_to_id.at(bytes_to_unicode_bpe(ch)); + } + default: + GGML_ASSERT(false); + } } static void llama_escape_whitespace(std::string & text) { @@ -4527,15 +6278,9 @@ struct llm_tokenizer_bpe { std::string byte_str(1, *j); auto token_multibyte = vocab.token_to_id.find(byte_str); if (token_multibyte == vocab.token_to_id.end()) { - try { - llama_token token_byte = llama_byte_to_token(vocab, *j); - output.push_back(token_byte); - } catch (const std::out_of_range & err) { - fprintf(stderr,"ERROR: byte not found in vocab: '%s'\n", byte_str.c_str()); - } - } else { - output.push_back((*token_multibyte).second); + throw std::runtime_error("ERROR: byte not found in vocab"); } + output.push_back((*token_multibyte).second); } } else { output.push_back((*token).second); @@ -4572,23 +6317,143 @@ struct llm_tokenizer_bpe { work_queue.push(bigram); } - // probably not 100% correct - static std::vector bpe_gpt2_preprocess(const std::string & text) { - std::vector words; + std::vector bpe_gpt2_preprocess(const std::string & text) { + std::vector bpe_words; + std::vector bpe_encoded_words; + + std::string token = ""; + // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ + bool collecting_numeric = false; + bool collecting_letter = false; + bool collecting_special = false; + bool collecting_whitespace_lookahead = false; + bool collecting = false; + + std::vector text_utf; + text_utf.reserve(text.size()); + bpe_words.reserve(text.size()); + bpe_encoded_words.reserve(text.size()); + + auto cps = codepoints_from_utf8(text); + for (size_t i = 0; i < cps.size(); ++i) + text_utf.emplace_back(codepoint_to_utf8(cps[i])); + + for (int i = 0; i < (int)text_utf.size(); i++) { + const std::string & utf_char = text_utf[i]; + bool split_condition = false; + int bytes_remain = text_utf.size() - i; + // forward backward lookups + const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : ""; + const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : ""; + + // handling contractions + if (!split_condition && bytes_remain >= 2) { + // 's|'t|'m|'d + if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) { + split_condition = true; + } + if (split_condition) { + if (token.size()) { + bpe_words.emplace_back(token); // push previous content as token + } + token = utf_char + utf_char_next; + bpe_words.emplace_back(token); + token = ""; + i++; + continue; + } + } + if (!split_condition && bytes_remain >= 3) { + // 're|'ve|'ll + if (utf_char == "\'" && ( + (utf_char_next == "r" && utf_char_next_next == "e") || + (utf_char_next == "v" && utf_char_next_next == "e") || + (utf_char_next == "l" && utf_char_next_next == "l")) + ) { + split_condition = true; + } + if (split_condition) { + // current token + next token can be defined + if (token.size()) { + bpe_words.emplace_back(token); // push previous content as token + } + token = utf_char + utf_char_next + utf_char_next_next; + bpe_words.emplace_back(token); // the contraction + token = ""; + i += 2; + continue; + } + } + + if (!split_condition && !collecting) { + if (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) { + collecting_letter = true; + collecting = true; + } + else if (codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { + collecting_numeric = true; + collecting = true; + } + else if ( + ((codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (codepoint_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) || + (!token.size() && utf_char == " " && codepoint_type(utf_char_next) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) + ) { + collecting_special = true; + collecting = true; + } + else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && codepoint_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) { + collecting_whitespace_lookahead = true; + collecting = true; + } + else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) { + split_condition = true; + } + } + else if (!split_condition && collecting) { + if (collecting_letter && codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER) { + split_condition = true; + } + else if (collecting_numeric && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) { + split_condition = true; + } + else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) { + split_condition = true; + } + else if (collecting_whitespace_lookahead && (codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { + split_condition = true; + } + } + + if (utf_char_next == "") { + split_condition = true; // final + token += utf_char; + } - // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53 - const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; - const std::regex re(pattern); + if (split_condition) { + if (token.size()) { + bpe_words.emplace_back(token); + } + token = utf_char; + collecting = false; + collecting_letter = false; + collecting_numeric = false; + collecting_special = false; + collecting_whitespace_lookahead = false; + } + else { + token += utf_char; + } + } - auto words_begin = std::sregex_iterator(text.begin(), text.end(), re); - auto words_end = std::sregex_iterator(); - auto n_words = std::distance(words_begin, words_end); - words.reserve(n_words); - for (auto it = words_begin; it != words_end; ++it) { - words.push_back(it->str()); + for (std::string & word : bpe_words) { + std::string encoded_token = ""; + for (char & c : word) { + encoded_token += bytes_to_unicode_bpe(c); + } + bpe_encoded_words.emplace_back(encoded_token); } - return words; + return bpe_encoded_words; } const llama_vocab & vocab; @@ -6070,7 +7935,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const std::string name = ggml_get_name(meta); // TODO: avoid hardcoded tensor names - use the TN_* constants - if (name.find("attn_v.weight") != std::string::npos) { + if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) { ++n_attention_wv; } else if (name.find("ffn_down.weight") != std::string::npos) { @@ -6107,6 +7972,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } std::ofstream fout(fname_out, std::ios::binary); + fout.exceptions(std::ofstream::failbit); // fail fast on write errors const size_t meta_size = gguf_get_meta_size(ctx_out); @@ -6895,6 +8761,10 @@ int llama_n_embd(const struct llama_model * model) { return model->hparams.n_embd; } +float llama_rope_freq_scale_train(const struct llama_model * model) { + return model->hparams.rope_freq_scale_train; +} + int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { return snprintf(buf, buf_size, "%s %s %s", llama_model_arch_name(model->arch).c_str(), @@ -7062,16 +8932,6 @@ struct llama_data_file_context : llama_data_context { * */ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { - // TODO: does not support multi-sequence states - { - const auto & kv_self = ctx->kv_self; - for (uint32_t i = 0; i < kv_self.head; ++i) { - GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i); - GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1); - GGML_ASSERT(kv_self.cells[i].has_seq_id(0)); - } - } - // copy rng { std::stringstream rng_ss; @@ -7124,36 +8984,38 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const auto & hparams = ctx->model.hparams; const auto & cparams = ctx->cparams; - const int n_layer = hparams.n_layer; - const int n_embd = hparams.n_embd_gqa(); - const int n_ctx = cparams.n_ctx; + const auto n_layer = hparams.n_layer; + const auto n_embd = hparams.n_embd_gqa(); + const auto n_ctx = cparams.n_ctx; - const size_t kv_size = kv_self.buf.size; - const int kv_ntok = kv_self.head; + const size_t kv_buf_size = kv_self.buf.size; + const uint32_t kv_head = kv_self.head; + const uint32_t kv_size = kv_self.size; - data_ctx->write(&kv_size, sizeof(kv_size)); - data_ctx->write(&kv_ntok, sizeof(kv_ntok)); + data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); + data_ctx->write(&kv_head, sizeof(kv_head)); + data_ctx->write(&kv_size, sizeof(kv_size)); - if (kv_size) { + if (kv_buf_size) { const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); ggml_cgraph gf{}; - ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); + ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); std::vector kout3d_data(ggml_nbytes(kout3d), 0); kout3d->data = kout3d_data.data(); - ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); + ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer); std::vector vout3d_data(ggml_nbytes(vout3d), 0); vout3d->data = vout3d_data.data(); ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, - n_embd, kv_ntok, n_layer, + n_embd, kv_head, n_layer, elt_size*n_embd, elt_size*n_embd*n_ctx, 0); ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, - kv_ntok, n_embd, n_layer, + kv_head, n_embd, n_layer, elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); @@ -7167,6 +9029,20 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat data_ctx->write(kout3d_data.data(), kout3d_data.size()); data_ctx->write(vout3d_data.data(), vout3d_data.size()); } + + for (uint32_t i = 0; i < kv_size; ++i) { + const auto & cell = kv_self.cells[i]; + + const llama_pos pos = cell.pos; + const size_t seq_id_size = cell.seq_id.size(); + + data_ctx->write(&pos, sizeof(pos)); + data_ctx->write(&seq_id_size, sizeof(seq_id_size)); + + for (auto seq_id : cell.seq_id) { + data_ctx->write(&seq_id, sizeof(seq_id)); + } + } } } @@ -7238,34 +9114,36 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { const int n_embd = hparams.n_embd_gqa(); const int n_ctx = cparams.n_ctx; - size_t kv_size; - int kv_ntok; + size_t kv_buf_size; + uint32_t kv_head; + uint32_t kv_size; - memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); - memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok); + memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size); + memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head); + memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); - if (kv_size) { - GGML_ASSERT(kv_self.buf.size == kv_size); + if (kv_buf_size) { + GGML_ASSERT(kv_self.buf.size == kv_buf_size); const size_t elt_size = ggml_element_size(kv_self.k); ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); ggml_cgraph gf{}; - ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); + ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); kin3d->data = (void *) inp; inp += ggml_nbytes(kin3d); - ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); + ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer); vin3d->data = (void *) inp; inp += ggml_nbytes(vin3d); ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, - n_embd, kv_ntok, n_layer, + n_embd, kv_head, n_layer, elt_size*n_embd, elt_size*n_embd*n_ctx, 0); ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, - kv_ntok, n_embd, n_layer, + kv_head, n_embd, n_layer, elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); @@ -7275,8 +9153,27 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { ggml_free(cpy_ctx); } - ctx->kv_self.head = kv_ntok; + ctx->kv_self.head = kv_head; ctx->kv_self.size = kv_size; + + ctx->kv_self.cells.resize(kv_size); + + for (uint32_t i = 0; i < kv_size; ++i) { + llama_pos pos; + size_t seq_id_size; + + memcpy(&pos, inp, sizeof(pos)); inp += sizeof(pos); + memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size); + + ctx->kv_self.cells[i].pos = pos; + + llama_seq_id seq_id; + + for (size_t j = 0; j < seq_id_size; ++j) { + memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id); + ctx->kv_self.cells[i].seq_id.insert(seq_id); + } + } } const size_t nread = inp - src; @@ -7532,35 +9429,70 @@ int llama_tokenize( return res.size(); } +static std::string llama_decode_text(const std::string & text) { + std::string decoded_text; + auto unicode_sequences = codepoints_from_utf8(text); + for (auto& unicode_sequence : unicode_sequences) { + decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence)); + } + + return decoded_text; +} + // does not write null-terminator to buf int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) { if (0 <= token && token < llama_n_vocab(model)) { - if (llama_is_normal_token(model->vocab, token)) { - std::string result = model->vocab.id_to_token[token].text; - if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) { + switch (llama_vocab_get_type(model->vocab)) { + case LLAMA_VOCAB_TYPE_SPM: { + if (llama_is_normal_token(model->vocab, token)) { + std::string result = model->vocab.id_to_token[token].text; llama_unescape_whitespace(result); + if (length < (int) result.length()) { + return -result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT + if (length < 3) { + return -3; + } + memcpy(buf, "\xe2\x96\x85", 3); + return 3; + } else if (llama_is_control_token(model->vocab, token)) { + ; + } else if (llama_is_byte_token(model->vocab, token)) { + if (length < 1) { + return -1; + } + buf[0] = llama_token_to_byte(model->vocab, token); + return 1; + } else { + // TODO: for now we accept all unsupported token types, + // suppressing them like CONTROL tokens. + // GGML_ASSERT(false); } - if (length < (int) result.length()) { - return -result.length(); - } - memcpy(buf, result.c_str(), result.length()); - return result.length(); - } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT - if (length < 3) { - return -3; - } - buf[0] = '\xe2'; - buf[1] = '\x96'; - buf[2] = '\x85'; - return 3; - } else if (llama_is_control_token(model->vocab, token)) { - // do nothing - } else if (llama_is_byte_token(model->vocab, token)) { - if (length < 1) { - return -1; + break; + } + case LLAMA_VOCAB_TYPE_BPE: { + if (llama_is_normal_token(model->vocab, token)) { + std::string result = model->vocab.id_to_token[token].text; + result = llama_decode_text(result); + if (length < (int) result.length()) { + return -result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } else if (llama_is_control_token(model->vocab, token)) { + ; + } else { + // TODO: for now we accept all unsupported token types, + // suppressing them like CONTROL tokens. + // GGML_ASSERT(false); } - buf[0] = llama_token_to_byte(model->vocab, token); - return 1; + break; + } + default: + GGML_ASSERT(false); } } return 0; @@ -7587,14 +9519,14 @@ void llama_print_timings(struct llama_context * ctx) { const llama_timings timings = llama_get_timings(ctx); LLAMA_LOG_INFO("\n"); - LLAMA_LOG_INFO("%s: load time = %8.2f ms\n", __func__, timings.t_load_ms); - LLAMA_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); + LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample); - LLAMA_LOG_INFO("%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); - LLAMA_LOG_INFO("%s: eval time = %8.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval); - LLAMA_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (timings.t_end_ms - timings.t_start_ms)); + LLAMA_LOG_INFO("%s: total time = %10.2f ms\n", __func__, (timings.t_end_ms - timings.t_start_ms)); } void llama_reset_timings(struct llama_context * ctx) { diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.h b/plugins/wasi_nn/thirdparty/ggml/llama.h index fd215840..a78015ad 100644 --- a/plugins/wasi_nn/thirdparty/ggml/llama.h +++ b/plugins/wasi_nn/thirdparty/ggml/llama.h @@ -42,7 +42,7 @@ #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 1 +#define LLAMA_SESSION_VERSION 2 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) // Defined when llama.cpp is compiled with support for offloading model layers to GPU. @@ -282,6 +282,9 @@ extern "C" { LLAMA_API int llama_n_ctx_train(const struct llama_model * model); LLAMA_API int llama_n_embd (const struct llama_model * model); + // Get the model's RoPE frequency scaling factor + LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); + // Get a string describing the model type LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); @@ -330,12 +333,16 @@ extern "C" { "avoid using this, it will be removed in the future, instead - count the tokens in user code"); // Remove all tokens data of cells in [c0, c1) + // c0 < 0 : [0, c1] + // c1 < 0 : [c0, inf) LLAMA_API void llama_kv_cache_tokens_rm( struct llama_context * ctx, int32_t c0, int32_t c1); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) LLAMA_API void llama_kv_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, @@ -344,6 +351,8 @@ extern "C" { // Copy all tokens that belong to the specified sequence to another sequence // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) LLAMA_API void llama_kv_cache_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, @@ -358,6 +367,8 @@ extern "C" { // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) LLAMA_API void llama_kv_cache_seq_shift( struct llama_context * ctx, llama_seq_id seq_id, diff --git a/plugins/wasi_nn/thirdparty/ggml/sampling.cpp b/plugins/wasi_nn/thirdparty/ggml/sampling.cpp new file mode 100644 index 00000000..8ce41945 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/sampling.cpp @@ -0,0 +1,166 @@ +#include "sampling.h" + +llama_sampling_context::~llama_sampling_context() { + for (auto & it : sequence_contexts) { + if (it.second.grammar != NULL) { + llama_grammar_free(it.second.grammar); + it.second.grammar = NULL; + } + } +} + +llama_sampling_context llama_sampling_context_init( + const struct gpt_params & params, + llama_grammar * grammar) { + llama_sampling_context result; + + result.params = params.sampling_params; + result.grammar = grammar; + return result; +} + +// Note: Creates the context if it doesn't exist, so this always return something. +llama_sampler_sequence_context & llama_sampling_get_sequence_context( + llama_sampling_context & ctx_sampling, + const llama_seq_id seq) { + const auto it = ctx_sampling.sequence_contexts.find(seq); + if (it != ctx_sampling.sequence_contexts.end()) { + return it->second; + } + llama_sampler_sequence_context new_ctx = { + 2.0f * ctx_sampling.params.mirostat_tau, + ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL, + }; + return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second; +} + +bool llama_sampling_context_reset( + llama_sampling_context & ctx_sampling, + const llama_seq_id seq) { + const auto it = ctx_sampling.sequence_contexts.find(seq); + if (it == ctx_sampling.sequence_contexts.end()) return false; + if (it->second.grammar != NULL) { + llama_grammar_free(it->second.grammar); + it->second.grammar = NULL; + } + ctx_sampling.sequence_contexts.erase(it); + return true; +} + +llama_token llama_sampling_sample( + struct llama_context * ctx, + struct llama_context * ctx_guidance, + struct llama_sampling_context & ctx_sampling, + const std::vector & last_tokens, + std::vector & candidates, + const int idx, + llama_seq_id seq) { + const int n_ctx = llama_n_ctx(ctx); + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + + const llama_sampling_params & params = ctx_sampling.params; + const float temp = params.temp; + const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; + const float top_p = params.top_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; + const float repeat_penalty = params.repeat_penalty; + const float alpha_presence = params.presence_penalty; + const float alpha_frequency = params.frequency_penalty; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; + const bool penalize_nl = params.penalize_nl; + + llama_token id = 0; + + float * logits = llama_get_logits_ith(ctx, idx); + + // Apply params.logit_bias map + for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { + logits[it->first] += it->second; + } + + candidates.clear(); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; + + if (ctx_guidance) { + llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale); + } + + // apply penalties + if (!last_tokens.empty()) { + const float nl_logit = logits[llama_token_nl(ctx)]; + const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx); + + llama_sample_repetition_penalty(ctx, &cur_p, + last_tokens.data() + last_tokens.size() - last_n_repeat, + last_n_repeat, repeat_penalty); + llama_sample_frequency_and_presence_penalties(ctx, &cur_p, + last_tokens.data() + last_tokens.size() - last_n_repeat, + last_n_repeat, alpha_frequency, alpha_presence); + + if (!penalize_nl) { + for (size_t idx = 0; idx < cur_p.size; idx++) { + if (cur_p.data[idx].id == llama_token_nl(ctx)) { + cur_p.data[idx].logit = nl_logit; + break; + } + } + } + } + + llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq); + + if (ctx_seq.grammar != NULL) { + llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar); + } + + if (temp <= 0) { + // Greedy sampling + id = llama_sample_token_greedy(ctx, &cur_p); + } else { + if (mirostat == 1) { + const int mirostat_m = 100; + llama_sample_temp(ctx, &cur_p, temp); + id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu); + } else if (mirostat == 2) { + llama_sample_temp(ctx, &cur_p, temp); + id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu); + } else { + // Temperature sampling + size_t min_keep = std::max(1, params.n_probs); + llama_sample_top_k (ctx, &cur_p, top_k, min_keep); + llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep); + llama_sample_typical (ctx, &cur_p, typical_p, min_keep); + llama_sample_top_p (ctx, &cur_p, top_p, min_keep); + llama_sample_temp(ctx, &cur_p, temp); + + { + const int n_top = 10; + LOG("top %d candidates:\n", n_top); + + for (int i = 0; i < n_top; i++) { + const llama_token id = cur_p.data[i].id; + (void)id; // To avoid a warning that id is unused when logging is disabled. + LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p); + } + } + + id = llama_sample_token(ctx, &cur_p); + + LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str()); + } + } + + if (ctx_seq.grammar != NULL) { + llama_grammar_accept_token(ctx, ctx_seq.grammar, id); + } + + return id; +} diff --git a/plugins/wasi_nn/thirdparty/ggml/sampling.h b/plugins/wasi_nn/thirdparty/ggml/sampling.h new file mode 100644 index 00000000..0aab5d03 --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/sampling.h @@ -0,0 +1,108 @@ +#pragma once + +#include "llama.h" + +#include +#include +#include + +// sampling parameters +typedef struct llama_sampling_params { + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typical_p = 1.00f; // 1.0 = disabled + float temp = 0.80f; // 1.0 = disabled + float repeat_penalty = 1.10f; // 1.0 = disabled + int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float frequency_penalty = 0.00f; // 0.0 = disabled + float presence_penalty = 0.00f; // 0.0 = disabled + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + + bool penalize_nl = true; // consider newlines as a repeatable token + + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + + // Classifier-Free Guidance + // https://arxiv.org/abs/2306.17806 + std::string cfg_negative_prompt; // string to help guidance + float cfg_scale = 1.f; // How strong is guidance + + std::unordered_map logit_bias; // logit bias for specific tokens + +} llama_sampling_params; + +// per-sequence sampler context +typedef struct llama_sampler_sequence_context { + float mirostat_mu; // mirostat sampler state + llama_grammar * grammar; +} llama_sampler_sequence_context; + +// general sampler context +typedef struct llama_sampling_context { + ~llama_sampling_context(); + + // parameters that will be used for sampling and when creating + // new llama_sampler_sequence_context instances + llama_sampling_params params; + + // map of sequence ids to sampler contexts + std::unordered_map sequence_contexts; + + // when non-NULL, new instances of llama_sampler_sequence_context + // will get a copy of the grammar here + // note: only the pointer is stored here, it is not a copy of + // the grammar and shouldn't be freed + llama_grammar * grammar; +} llama_sampling_context; + +#include "common.h" + +// Create a new sampling context instance. +llama_sampling_context llama_sampling_context_init( + const struct gpt_params & params, + llama_grammar * grammar = NULL); + +// Fetches the sampler context for the specified sequence id (defaults to 0). +// If the context for that sequence id doesn't already exist, it will be created with +// default values based on the parameters in the ctx_sampling argument. +llama_sampler_sequence_context & llama_sampling_get_sequence_context( + llama_sampling_context & ctx_sampling, + const llama_seq_id seq = 0); + +// Reset the sampler context for the supplied sequence id (defaults to 0). +// This is necessary to reuse a sequence id or free memory used by sequences +// that are no longer required. +bool llama_sampling_context_reset( + llama_sampling_context & ctx_sampling, + const llama_seq_id seq = 0); + +// this is a common sampling function used across the examples for convenience +// it can serve as a starting point for implementing your own sampling function +// Note: When using multiple sequences, it is the caller's responsibility to call +// llama_sampling_context_reset when a sequence ends +// +// required: +// - ctx: context to use for sampling +// - ctx_sampling: sampling-specific context +// +// optional: +// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL +// - last_tokens: needed for repetition penalty, ignore if empty +// - idx: sample from llama_get_logits_ith(ctx, idx) +// - seq: sequence id to associate sampler state with +// +// returns: +// - token: sampled token +// - candidates: vector of candidate tokens +// +llama_token llama_sampling_sample( + struct llama_context * ctx, + struct llama_context * ctx_guidance, + struct llama_sampling_context & ctx_sampling, + const std::vector & last_tokens, + std::vector & candidates, + const int idx = 0, + llama_seq_id seq = 0); diff --git a/plugins/wasi_nn/thirdparty/ggml/unicode.h b/plugins/wasi_nn/thirdparty/ggml/unicode.h new file mode 100644 index 00000000..aeca879e --- /dev/null +++ b/plugins/wasi_nn/thirdparty/ggml/unicode.h @@ -0,0 +1,462 @@ +#pragma once + +#include +#include +#include +#include + +static const std::vector> digit_ranges = { +{0x30, 0x39}, {0xB2, 0xB3}, {0xB9, 0xB9}, {0x660, 0x669}, {0x6F0, 0x6F9}, {0x7C0, 0x7C9}, {0x966, 0x96F}, {0x9E6, 0x9EF}, {0xA66, 0xA6F}, {0xAE6, 0xAEF}, {0xB66, 0xB6F}, {0xBE6, 0xBEF}, {0xC66, 0xC6F}, +{0xCE6, 0xCEF}, {0xD66, 0xD6F}, {0xDE6, 0xDEF}, {0xE50, 0xE59}, {0xED0, 0xED9}, {0xF20, 0xF29}, {0x1040, 0x1049}, {0x1090, 0x1099}, {0x1369, 0x1371}, {0x17E0, 0x17E9}, {0x1810, 0x1819}, {0x1946, 0x194F}, +{0x19D0, 0x19DA}, {0x1A80, 0x1A89}, {0x1A90, 0x1A99}, {0x1B50, 0x1B59}, {0x1BB0, 0x1BB9}, {0x1C40, 0x1C49}, {0x1C50, 0x1C59}, {0x2070, 0x2070}, {0x2074, 0x2079}, {0x2080, 0x2089}, {0x2460, 0x2468}, +{0x2474, 0x247C}, {0x2488, 0x2490}, {0x24EA, 0x24EA}, {0x24F5, 0x24FD}, {0x24FF, 0x24FF}, {0x2776, 0x277E}, {0x2780, 0x2788}, {0x278A, 0x2792}, {0xA620, 0xA629}, {0xA8D0, 0xA8D9}, {0xA900, 0xA909}, +{0xA9D0, 0xA9D9}, {0xA9F0, 0xA9F9}, {0xAA50, 0xAA59}, {0xABF0, 0xABF9}, {0xFF10, 0xFF19}, {0x104A0, 0x104A9}, {0x10A40, 0x10A43}, {0x10D30, 0x10D39}, {0x10E60, 0x10E68}, {0x11052, 0x1105A}, +{0x11066, 0x1106F}, {0x110F0, 0x110F9}, {0x11136, 0x1113F}, {0x111D0, 0x111D9}, {0x112F0, 0x112F9}, {0x11450, 0x11459}, {0x114D0, 0x114D9}, {0x11650, 0x11659}, {0x116C0, 0x116C9}, {0x11730, 0x11739}, +{0x118E0, 0x118E9}, {0x11950, 0x11959}, {0x11C50, 0x11C59}, {0x11D50, 0x11D59}, {0x11DA0, 0x11DA9}, {0x16A60, 0x16A69}, {0x16B50, 0x16B59}, {0x1D7CE, 0x1D7FF}, {0x1E140, 0x1E149}, {0x1E2F0, 0x1E2F9}, +{0x1E950, 0x1E959}, {0x1F100, 0x1F10A}, {0x1FBF0, 0x1FBF9}, +}; + +static const std::vector> letter_ranges = { +{0x41, 0x5A}, {0x61, 0x7A}, {0xAA, 0xAA}, {0xB5, 0xB5}, {0xBA, 0xBA}, {0xC0, 0xD6}, {0xD8, 0xF6}, {0xF8, 0x2C1}, {0x2C6, 0x2D1}, {0x2E0, 0x2E4}, {0x2EC, 0x2EC}, {0x2EE, 0x2EE}, {0x370, 0x374}, +{0x376, 0x377}, {0x37A, 0x37D}, {0x37F, 0x37F}, {0x386, 0x386}, {0x388, 0x38A}, {0x38C, 0x38C}, {0x38E, 0x3A1}, {0x3A3, 0x3F5}, {0x3F7, 0x481}, {0x48A, 0x52F}, {0x531, 0x556}, {0x559, 0x559}, +{0x560, 0x588}, {0x5D0, 0x5EA}, {0x5EF, 0x5F2}, {0x620, 0x64A}, {0x66E, 0x66F}, {0x671, 0x6D3}, {0x6D5, 0x6D5}, {0x6E5, 0x6E6}, {0x6EE, 0x6EF}, {0x6FA, 0x6FC}, {0x6FF, 0x6FF}, {0x710, 0x710}, +{0x712, 0x72F}, {0x74D, 0x7A5}, {0x7B1, 0x7B1}, {0x7CA, 0x7EA}, {0x7F4, 0x7F5}, {0x7FA, 0x7FA}, {0x800, 0x815}, {0x81A, 0x81A}, {0x824, 0x824}, {0x828, 0x828}, {0x840, 0x858}, {0x860, 0x86A}, +{0x8A0, 0x8B4}, {0x8B6, 0x8C7}, {0x904, 0x939}, {0x93D, 0x93D}, {0x950, 0x950}, {0x958, 0x961}, {0x971, 0x980}, {0x985, 0x98C}, {0x98F, 0x990}, {0x993, 0x9A8}, {0x9AA, 0x9B0}, {0x9B2, 0x9B2}, +{0x9B6, 0x9B9}, {0x9BD, 0x9BD}, {0x9CE, 0x9CE}, {0x9DC, 0x9DD}, {0x9DF, 0x9E1}, {0x9F0, 0x9F1}, {0x9FC, 0x9FC}, {0xA05, 0xA0A}, {0xA0F, 0xA10}, {0xA13, 0xA28}, {0xA2A, 0xA30}, {0xA32, 0xA33}, +{0xA35, 0xA36}, {0xA38, 0xA39}, {0xA59, 0xA5C}, {0xA5E, 0xA5E}, {0xA72, 0xA74}, {0xA85, 0xA8D}, {0xA8F, 0xA91}, {0xA93, 0xAA8}, {0xAAA, 0xAB0}, {0xAB2, 0xAB3}, {0xAB5, 0xAB9}, {0xABD, 0xABD}, +{0xAD0, 0xAD0}, {0xAE0, 0xAE1}, {0xAF9, 0xAF9}, {0xB05, 0xB0C}, {0xB0F, 0xB10}, {0xB13, 0xB28}, {0xB2A, 0xB30}, {0xB32, 0xB33}, {0xB35, 0xB39}, {0xB3D, 0xB3D}, {0xB5C, 0xB5D}, {0xB5F, 0xB61}, +{0xB71, 0xB71}, {0xB83, 0xB83}, {0xB85, 0xB8A}, {0xB8E, 0xB90}, {0xB92, 0xB95}, {0xB99, 0xB9A}, {0xB9C, 0xB9C}, {0xB9E, 0xB9F}, {0xBA3, 0xBA4}, {0xBA8, 0xBAA}, {0xBAE, 0xBB9}, {0xBD0, 0xBD0}, +{0xC05, 0xC0C}, {0xC0E, 0xC10}, {0xC12, 0xC28}, {0xC2A, 0xC39}, {0xC3D, 0xC3D}, {0xC58, 0xC5A}, {0xC60, 0xC61}, {0xC80, 0xC80}, {0xC85, 0xC8C}, {0xC8E, 0xC90}, {0xC92, 0xCA8}, {0xCAA, 0xCB3}, +{0xCB5, 0xCB9}, {0xCBD, 0xCBD}, {0xCDE, 0xCDE}, {0xCE0, 0xCE1}, {0xCF1, 0xCF2}, {0xD04, 0xD0C}, {0xD0E, 0xD10}, {0xD12, 0xD3A}, {0xD3D, 0xD3D}, {0xD4E, 0xD4E}, {0xD54, 0xD56}, {0xD5F, 0xD61}, +{0xD7A, 0xD7F}, {0xD85, 0xD96}, {0xD9A, 0xDB1}, {0xDB3, 0xDBB}, {0xDBD, 0xDBD}, {0xDC0, 0xDC6}, {0xE01, 0xE30}, {0xE32, 0xE33}, {0xE40, 0xE46}, {0xE81, 0xE82}, {0xE84, 0xE84}, {0xE86, 0xE8A}, +{0xE8C, 0xEA3}, {0xEA5, 0xEA5}, {0xEA7, 0xEB0}, {0xEB2, 0xEB3}, {0xEBD, 0xEBD}, {0xEC0, 0xEC4}, {0xEC6, 0xEC6}, {0xEDC, 0xEDF}, {0xF00, 0xF00}, {0xF40, 0xF47}, {0xF49, 0xF6C}, {0xF88, 0xF8C}, +{0x1000, 0x102A}, {0x103F, 0x103F}, {0x1050, 0x1055}, {0x105A, 0x105D}, {0x1061, 0x1061}, {0x1065, 0x1066}, {0x106E, 0x1070}, {0x1075, 0x1081}, {0x108E, 0x108E}, {0x10A0, 0x10C5}, {0x10C7, 0x10C7}, +{0x10CD, 0x10CD}, {0x10D0, 0x10FA}, {0x10FC, 0x1248}, {0x124A, 0x124D}, {0x1250, 0x1256}, {0x1258, 0x1258}, {0x125A, 0x125D}, {0x1260, 0x1288}, {0x128A, 0x128D}, {0x1290, 0x12B0}, {0x12B2, 0x12B5}, +{0x12B8, 0x12BE}, {0x12C0, 0x12C0}, {0x12C2, 0x12C5}, {0x12C8, 0x12D6}, {0x12D8, 0x1310}, {0x1312, 0x1315}, {0x1318, 0x135A}, {0x1380, 0x138F}, {0x13A0, 0x13F5}, {0x13F8, 0x13FD}, {0x1401, 0x166C}, +{0x166F, 0x167F}, {0x1681, 0x169A}, {0x16A0, 0x16EA}, {0x16F1, 0x16F8}, {0x1700, 0x170C}, {0x170E, 0x1711}, {0x1720, 0x1731}, {0x1740, 0x1751}, {0x1760, 0x176C}, {0x176E, 0x1770}, {0x1780, 0x17B3}, +{0x17D7, 0x17D7}, {0x17DC, 0x17DC}, {0x1820, 0x1878}, {0x1880, 0x1884}, {0x1887, 0x18A8}, {0x18AA, 0x18AA}, {0x18B0, 0x18F5}, {0x1900, 0x191E}, {0x1950, 0x196D}, {0x1970, 0x1974}, {0x1980, 0x19AB}, +{0x19B0, 0x19C9}, {0x1A00, 0x1A16}, {0x1A20, 0x1A54}, {0x1AA7, 0x1AA7}, {0x1B05, 0x1B33}, {0x1B45, 0x1B4B}, {0x1B83, 0x1BA0}, {0x1BAE, 0x1BAF}, {0x1BBA, 0x1BE5}, {0x1C00, 0x1C23}, {0x1C4D, 0x1C4F}, +{0x1C5A, 0x1C7D}, {0x1C80, 0x1C88}, {0x1C90, 0x1CBA}, {0x1CBD, 0x1CBF}, {0x1CE9, 0x1CEC}, {0x1CEE, 0x1CF3}, {0x1CF5, 0x1CF6}, {0x1CFA, 0x1CFA}, {0x1D00, 0x1DBF}, {0x1E00, 0x1F15}, {0x1F18, 0x1F1D}, +{0x1F20, 0x1F45}, {0x1F48, 0x1F4D}, {0x1F50, 0x1F57}, {0x1F59, 0x1F59}, {0x1F5B, 0x1F5B}, {0x1F5D, 0x1F5D}, {0x1F5F, 0x1F7D}, {0x1F80, 0x1FB4}, {0x1FB6, 0x1FBC}, {0x1FBE, 0x1FBE}, {0x1FC2, 0x1FC4}, +{0x1FC6, 0x1FCC}, {0x1FD0, 0x1FD3}, {0x1FD6, 0x1FDB}, {0x1FE0, 0x1FEC}, {0x1FF2, 0x1FF4}, {0x1FF6, 0x1FFC}, {0x2071, 0x2071}, {0x207F, 0x207F}, {0x2090, 0x209C}, {0x2102, 0x2102}, {0x2107, 0x2107}, +{0x210A, 0x2113}, {0x2115, 0x2115}, {0x2119, 0x211D}, {0x2124, 0x2124}, {0x2126, 0x2126}, {0x2128, 0x2128}, {0x212A, 0x212D}, {0x212F, 0x2139}, {0x213C, 0x213F}, {0x2145, 0x2149}, {0x214E, 0x214E}, +{0x2183, 0x2184}, {0x2C00, 0x2C2E}, {0x2C30, 0x2C5E}, {0x2C60, 0x2CE4}, {0x2CEB, 0x2CEE}, {0x2CF2, 0x2CF3}, {0x2D00, 0x2D25}, {0x2D27, 0x2D27}, {0x2D2D, 0x2D2D}, {0x2D30, 0x2D67}, {0x2D6F, 0x2D6F}, +{0x2D80, 0x2D96}, {0x2DA0, 0x2DA6}, {0x2DA8, 0x2DAE}, {0x2DB0, 0x2DB6}, {0x2DB8, 0x2DBE}, {0x2DC0, 0x2DC6}, {0x2DC8, 0x2DCE}, {0x2DD0, 0x2DD6}, {0x2DD8, 0x2DDE}, {0x2E2F, 0x2E2F}, {0x3005, 0x3006}, +{0x3031, 0x3035}, {0x303B, 0x303C}, {0x3041, 0x3096}, {0x309D, 0x309F}, {0x30A1, 0x30FA}, {0x30FC, 0x30FF}, {0x3105, 0x312F}, {0x3131, 0x318E}, {0x31A0, 0x31BF}, {0x31F0, 0x31FF}, {0x3400, 0x4DBF}, +{0x4E00, 0x9FFC}, {0xA000, 0xA48C}, {0xA4D0, 0xA4FD}, {0xA500, 0xA60C}, {0xA610, 0xA61F}, {0xA62A, 0xA62B}, {0xA640, 0xA66E}, {0xA67F, 0xA69D}, {0xA6A0, 0xA6E5}, {0xA717, 0xA71F}, {0xA722, 0xA788}, +{0xA78B, 0xA7BF}, {0xA7C2, 0xA7CA}, {0xA7F5, 0xA801}, {0xA803, 0xA805}, {0xA807, 0xA80A}, {0xA80C, 0xA822}, {0xA840, 0xA873}, {0xA882, 0xA8B3}, {0xA8F2, 0xA8F7}, {0xA8FB, 0xA8FB}, {0xA8FD, 0xA8FE}, +{0xA90A, 0xA925}, {0xA930, 0xA946}, {0xA960, 0xA97C}, {0xA984, 0xA9B2}, {0xA9CF, 0xA9CF}, {0xA9E0, 0xA9E4}, {0xA9E6, 0xA9EF}, {0xA9FA, 0xA9FE}, {0xAA00, 0xAA28}, {0xAA40, 0xAA42}, {0xAA44, 0xAA4B}, +{0xAA60, 0xAA76}, {0xAA7A, 0xAA7A}, {0xAA7E, 0xAAAF}, {0xAAB1, 0xAAB1}, {0xAAB5, 0xAAB6}, {0xAAB9, 0xAABD}, {0xAAC0, 0xAAC0}, {0xAAC2, 0xAAC2}, {0xAADB, 0xAADD}, {0xAAE0, 0xAAEA}, {0xAAF2, 0xAAF4}, +{0xAB01, 0xAB06}, {0xAB09, 0xAB0E}, {0xAB11, 0xAB16}, {0xAB20, 0xAB26}, {0xAB28, 0xAB2E}, {0xAB30, 0xAB5A}, {0xAB5C, 0xAB69}, {0xAB70, 0xABE2}, {0xAC00, 0xD7A3}, {0xD7B0, 0xD7C6}, {0xD7CB, 0xD7FB}, +{0xF900, 0xFA6D}, {0xFA70, 0xFAD9}, {0xFB00, 0xFB06}, {0xFB13, 0xFB17}, {0xFB1D, 0xFB1D}, {0xFB1F, 0xFB28}, {0xFB2A, 0xFB36}, {0xFB38, 0xFB3C}, {0xFB3E, 0xFB3E}, {0xFB40, 0xFB41}, {0xFB43, 0xFB44}, +{0xFB46, 0xFBB1}, {0xFBD3, 0xFD3D}, {0xFD50, 0xFD8F}, {0xFD92, 0xFDC7}, {0xFDF0, 0xFDFB}, {0xFE70, 0xFE74}, {0xFE76, 0xFEFC}, {0xFF21, 0xFF3A}, {0xFF41, 0xFF5A}, {0xFF66, 0xFFBE}, {0xFFC2, 0xFFC7}, +{0xFFCA, 0xFFCF}, {0xFFD2, 0xFFD7}, {0xFFDA, 0xFFDC}, {0x10000, 0x1000B}, {0x1000D, 0x10026}, {0x10028, 0x1003A}, {0x1003C, 0x1003D}, {0x1003F, 0x1004D}, {0x10050, 0x1005D}, {0x10080, 0x100FA}, +{0x10280, 0x1029C}, {0x102A0, 0x102D0}, {0x10300, 0x1031F}, {0x1032D, 0x10340}, {0x10342, 0x10349}, {0x10350, 0x10375}, {0x10380, 0x1039D}, {0x103A0, 0x103C3}, {0x103C8, 0x103CF}, {0x10400, 0x1049D}, +{0x104B0, 0x104D3}, {0x104D8, 0x104FB}, {0x10500, 0x10527}, {0x10530, 0x10563}, {0x10600, 0x10736}, {0x10740, 0x10755}, {0x10760, 0x10767}, {0x10800, 0x10805}, {0x10808, 0x10808}, {0x1080A, 0x10835}, +{0x10837, 0x10838}, {0x1083C, 0x1083C}, {0x1083F, 0x10855}, {0x10860, 0x10876}, {0x10880, 0x1089E}, {0x108E0, 0x108F2}, {0x108F4, 0x108F5}, {0x10900, 0x10915}, {0x10920, 0x10939}, {0x10980, 0x109B7}, +{0x109BE, 0x109BF}, {0x10A00, 0x10A00}, {0x10A10, 0x10A13}, {0x10A15, 0x10A17}, {0x10A19, 0x10A35}, {0x10A60, 0x10A7C}, {0x10A80, 0x10A9C}, {0x10AC0, 0x10AC7}, {0x10AC9, 0x10AE4}, {0x10B00, 0x10B35}, +{0x10B40, 0x10B55}, {0x10B60, 0x10B72}, {0x10B80, 0x10B91}, {0x10C00, 0x10C48}, {0x10C80, 0x10CB2}, {0x10CC0, 0x10CF2}, {0x10D00, 0x10D23}, {0x10E80, 0x10EA9}, {0x10EB0, 0x10EB1}, {0x10F00, 0x10F1C}, +{0x10F27, 0x10F27}, {0x10F30, 0x10F45}, {0x10FB0, 0x10FC4}, {0x10FE0, 0x10FF6}, {0x11003, 0x11037}, {0x11083, 0x110AF}, {0x110D0, 0x110E8}, {0x11103, 0x11126}, {0x11144, 0x11144}, {0x11147, 0x11147}, +{0x11150, 0x11172}, {0x11176, 0x11176}, {0x11183, 0x111B2}, {0x111C1, 0x111C4}, {0x111DA, 0x111DA}, {0x111DC, 0x111DC}, {0x11200, 0x11211}, {0x11213, 0x1122B}, {0x11280, 0x11286}, {0x11288, 0x11288}, +{0x1128A, 0x1128D}, {0x1128F, 0x1129D}, {0x1129F, 0x112A8}, {0x112B0, 0x112DE}, {0x11305, 0x1130C}, {0x1130F, 0x11310}, {0x11313, 0x11328}, {0x1132A, 0x11330}, {0x11332, 0x11333}, {0x11335, 0x11339}, +{0x1133D, 0x1133D}, {0x11350, 0x11350}, {0x1135D, 0x11361}, {0x11400, 0x11434}, {0x11447, 0x1144A}, {0x1145F, 0x11461}, {0x11480, 0x114AF}, {0x114C4, 0x114C5}, {0x114C7, 0x114C7}, {0x11580, 0x115AE}, +{0x115D8, 0x115DB}, {0x11600, 0x1162F}, {0x11644, 0x11644}, {0x11680, 0x116AA}, {0x116B8, 0x116B8}, {0x11700, 0x1171A}, {0x11800, 0x1182B}, {0x118A0, 0x118DF}, {0x118FF, 0x11906}, {0x11909, 0x11909}, +{0x1190C, 0x11913}, {0x11915, 0x11916}, {0x11918, 0x1192F}, {0x1193F, 0x1193F}, {0x11941, 0x11941}, {0x119A0, 0x119A7}, {0x119AA, 0x119D0}, {0x119E1, 0x119E1}, {0x119E3, 0x119E3}, {0x11A00, 0x11A00}, +{0x11A0B, 0x11A32}, {0x11A3A, 0x11A3A}, {0x11A50, 0x11A50}, {0x11A5C, 0x11A89}, {0x11A9D, 0x11A9D}, {0x11AC0, 0x11AF8}, {0x11C00, 0x11C08}, {0x11C0A, 0x11C2E}, {0x11C40, 0x11C40}, {0x11C72, 0x11C8F}, +{0x11D00, 0x11D06}, {0x11D08, 0x11D09}, {0x11D0B, 0x11D30}, {0x11D46, 0x11D46}, {0x11D60, 0x11D65}, {0x11D67, 0x11D68}, {0x11D6A, 0x11D89}, {0x11D98, 0x11D98}, {0x11EE0, 0x11EF2}, {0x11FB0, 0x11FB0}, +{0x12000, 0x12399}, {0x12480, 0x12543}, {0x13000, 0x1342E}, {0x14400, 0x14646}, {0x16800, 0x16A38}, {0x16A40, 0x16A5E}, {0x16AD0, 0x16AED}, {0x16B00, 0x16B2F}, {0x16B40, 0x16B43}, {0x16B63, 0x16B77}, +{0x16B7D, 0x16B8F}, {0x16E40, 0x16E7F}, {0x16F00, 0x16F4A}, {0x16F50, 0x16F50}, {0x16F93, 0x16F9F}, {0x16FE0, 0x16FE1}, {0x16FE3, 0x16FE3}, {0x17000, 0x187F7}, {0x18800, 0x18CD5}, {0x18D00, 0x18D08}, +{0x1B000, 0x1B11E}, {0x1B150, 0x1B152}, {0x1B164, 0x1B167}, {0x1B170, 0x1B2FB}, {0x1BC00, 0x1BC6A}, {0x1BC70, 0x1BC7C}, {0x1BC80, 0x1BC88}, {0x1BC90, 0x1BC99}, {0x1D400, 0x1D454}, {0x1D456, 0x1D49C}, +{0x1D49E, 0x1D49F}, {0x1D4A2, 0x1D4A2}, {0x1D4A5, 0x1D4A6}, {0x1D4A9, 0x1D4AC}, {0x1D4AE, 0x1D4B9}, {0x1D4BB, 0x1D4BB}, {0x1D4BD, 0x1D4C3}, {0x1D4C5, 0x1D505}, {0x1D507, 0x1D50A}, {0x1D50D, 0x1D514}, +{0x1D516, 0x1D51C}, {0x1D51E, 0x1D539}, {0x1D53B, 0x1D53E}, {0x1D540, 0x1D544}, {0x1D546, 0x1D546}, {0x1D54A, 0x1D550}, {0x1D552, 0x1D6A5}, {0x1D6A8, 0x1D6C0}, {0x1D6C2, 0x1D6DA}, {0x1D6DC, 0x1D6FA}, +{0x1D6FC, 0x1D714}, {0x1D716, 0x1D734}, {0x1D736, 0x1D74E}, {0x1D750, 0x1D76E}, {0x1D770, 0x1D788}, {0x1D78A, 0x1D7A8}, {0x1D7AA, 0x1D7C2}, {0x1D7C4, 0x1D7CB}, {0x1E100, 0x1E12C}, {0x1E137, 0x1E13D}, +{0x1E14E, 0x1E14E}, {0x1E2C0, 0x1E2EB}, {0x1E800, 0x1E8C4}, {0x1E900, 0x1E943}, {0x1E94B, 0x1E94B}, {0x1EE00, 0x1EE03}, {0x1EE05, 0x1EE1F}, {0x1EE21, 0x1EE22}, {0x1EE24, 0x1EE24}, {0x1EE27, 0x1EE27}, +{0x1EE29, 0x1EE32}, {0x1EE34, 0x1EE37}, {0x1EE39, 0x1EE39}, {0x1EE3B, 0x1EE3B}, {0x1EE42, 0x1EE42}, {0x1EE47, 0x1EE47}, {0x1EE49, 0x1EE49}, {0x1EE4B, 0x1EE4B}, {0x1EE4D, 0x1EE4F}, {0x1EE51, 0x1EE52}, +{0x1EE54, 0x1EE54}, {0x1EE57, 0x1EE57}, {0x1EE59, 0x1EE59}, {0x1EE5B, 0x1EE5B}, {0x1EE5D, 0x1EE5D}, {0x1EE5F, 0x1EE5F}, {0x1EE61, 0x1EE62}, {0x1EE64, 0x1EE64}, {0x1EE67, 0x1EE6A}, {0x1EE6C, 0x1EE72}, +{0x1EE74, 0x1EE77}, {0x1EE79, 0x1EE7C}, {0x1EE7E, 0x1EE7E}, {0x1EE80, 0x1EE89}, {0x1EE8B, 0x1EE9B}, {0x1EEA1, 0x1EEA3}, {0x1EEA5, 0x1EEA9}, {0x1EEAB, 0x1EEBB}, {0x20000, 0x2A6DD}, {0x2A700, 0x2B734}, +{0x2B740, 0x2B81D}, {0x2B820, 0x2CEA1}, {0x2CEB0, 0x2EBE0}, {0x2F800, 0x2FA1D}, {0x30000, 0x3134A}, +}; + +static const std::vector> whitespace_ranges = { +{0x9, 0xD}, {0x1C, 0x20}, {0x85, 0x85}, {0xA0, 0xA0}, {0x1680, 0x1680}, {0x2000, 0x200A}, {0x2028, 0x2029}, {0x202F, 0x202F}, {0x205F, 0x205F}, {0x3000, 0x3000}, +}; + +static const std::vector> accent_mark_ranges = { +{0x300, 0x36F}, {0x483, 0x489}, {0x591, 0x5BD}, {0x5BF, 0x5BF}, {0x5C1, 0x5C2}, {0x5C4, 0x5C5}, {0x5C7, 0x5C7}, {0x610, 0x61A}, {0x64B, 0x65F}, {0x670, 0x670}, {0x6D6, 0x6DC}, {0x6DF, 0x6E4}, +{0x6E7, 0x6E8}, {0x6EA, 0x6ED}, {0x711, 0x711}, {0x730, 0x74A}, {0x7A6, 0x7B0}, {0x7EB, 0x7F3}, {0x7FD, 0x7FD}, {0x816, 0x819}, {0x81B, 0x823}, {0x825, 0x827}, {0x829, 0x82D}, {0x859, 0x85B}, +{0x8D3, 0x8E1}, {0x8E3, 0x903}, {0x93A, 0x93C}, {0x93E, 0x94F}, {0x951, 0x957}, {0x962, 0x963}, {0x981, 0x983}, {0x9BC, 0x9BC}, {0x9BE, 0x9C4}, {0x9C7, 0x9C8}, {0x9CB, 0x9CD}, {0x9D7, 0x9D7}, +{0x9E2, 0x9E3}, {0x9FE, 0x9FE}, {0xA01, 0xA03}, {0xA3C, 0xA3C}, {0xA3E, 0xA42}, {0xA47, 0xA48}, {0xA4B, 0xA4D}, {0xA51, 0xA51}, {0xA70, 0xA71}, {0xA75, 0xA75}, {0xA81, 0xA83}, {0xABC, 0xABC}, +{0xABE, 0xAC5}, {0xAC7, 0xAC9}, {0xACB, 0xACD}, {0xAE2, 0xAE3}, {0xAFA, 0xAFF}, {0xB01, 0xB03}, {0xB3C, 0xB3C}, {0xB3E, 0xB44}, {0xB47, 0xB48}, {0xB4B, 0xB4D}, {0xB55, 0xB57}, {0xB62, 0xB63}, +{0xB82, 0xB82}, {0xBBE, 0xBC2}, {0xBC6, 0xBC8}, {0xBCA, 0xBCD}, {0xBD7, 0xBD7}, {0xC00, 0xC04}, {0xC3E, 0xC44}, {0xC46, 0xC48}, {0xC4A, 0xC4D}, {0xC55, 0xC56}, {0xC62, 0xC63}, {0xC81, 0xC83}, +{0xCBC, 0xCBC}, {0xCBE, 0xCC4}, {0xCC6, 0xCC8}, {0xCCA, 0xCCD}, {0xCD5, 0xCD6}, {0xCE2, 0xCE3}, {0xD00, 0xD03}, {0xD3B, 0xD3C}, {0xD3E, 0xD44}, {0xD46, 0xD48}, {0xD4A, 0xD4D}, {0xD57, 0xD57}, +{0xD62, 0xD63}, {0xD81, 0xD83}, {0xDCA, 0xDCA}, {0xDCF, 0xDD4}, {0xDD6, 0xDD6}, {0xDD8, 0xDDF}, {0xDF2, 0xDF3}, {0xE31, 0xE31}, {0xE34, 0xE3A}, {0xE47, 0xE4E}, {0xEB1, 0xEB1}, {0xEB4, 0xEBC}, +{0xEC8, 0xECD}, {0xF18, 0xF19}, {0xF35, 0xF35}, {0xF37, 0xF37}, {0xF39, 0xF39}, {0xF3E, 0xF3F}, {0xF71, 0xF84}, {0xF86, 0xF87}, {0xF8D, 0xF97}, {0xF99, 0xFBC}, {0xFC6, 0xFC6}, {0x102B, 0x103E}, +{0x1056, 0x1059}, {0x105E, 0x1060}, {0x1062, 0x1064}, {0x1067, 0x106D}, {0x1071, 0x1074}, {0x1082, 0x108D}, {0x108F, 0x108F}, {0x109A, 0x109D}, {0x135D, 0x135F}, {0x1712, 0x1714}, {0x1732, 0x1734}, +{0x1752, 0x1753}, {0x1772, 0x1773}, {0x17B4, 0x17D3}, {0x17DD, 0x17DD}, {0x180B, 0x180D}, {0x1885, 0x1886}, {0x18A9, 0x18A9}, {0x1920, 0x192B}, {0x1930, 0x193B}, {0x1A17, 0x1A1B}, {0x1A55, 0x1A5E}, +{0x1A60, 0x1A7C}, {0x1A7F, 0x1A7F}, {0x1AB0, 0x1AC0}, {0x1B00, 0x1B04}, {0x1B34, 0x1B44}, {0x1B6B, 0x1B73}, {0x1B80, 0x1B82}, {0x1BA1, 0x1BAD}, {0x1BE6, 0x1BF3}, {0x1C24, 0x1C37}, {0x1CD0, 0x1CD2}, +{0x1CD4, 0x1CE8}, {0x1CED, 0x1CED}, {0x1CF4, 0x1CF4}, {0x1CF7, 0x1CF9}, {0x1DC0, 0x1DF9}, {0x1DFB, 0x1DFF}, {0x20D0, 0x20F0}, {0x2CEF, 0x2CF1}, {0x2D7F, 0x2D7F}, {0x2DE0, 0x2DFF}, {0x302A, 0x302F}, +{0x3099, 0x309A}, {0xA66F, 0xA672}, {0xA674, 0xA67D}, {0xA69E, 0xA69F}, {0xA6F0, 0xA6F1}, {0xA802, 0xA802}, {0xA806, 0xA806}, {0xA80B, 0xA80B}, {0xA823, 0xA827}, {0xA82C, 0xA82C}, {0xA880, 0xA881}, +{0xA8B4, 0xA8C5}, {0xA8E0, 0xA8F1}, {0xA8FF, 0xA8FF}, {0xA926, 0xA92D}, {0xA947, 0xA953}, {0xA980, 0xA983}, {0xA9B3, 0xA9C0}, {0xA9E5, 0xA9E5}, {0xAA29, 0xAA36}, {0xAA43, 0xAA43}, {0xAA4C, 0xAA4D}, +{0xAA7B, 0xAA7D}, {0xAAB0, 0xAAB0}, {0xAAB2, 0xAAB4}, {0xAAB7, 0xAAB8}, {0xAABE, 0xAABF}, {0xAAC1, 0xAAC1}, {0xAAEB, 0xAAEF}, {0xAAF5, 0xAAF6}, {0xABE3, 0xABEA}, {0xABEC, 0xABED}, {0xFB1E, 0xFB1E}, +{0xFE00, 0xFE0F}, {0xFE20, 0xFE2F}, {0x101FD, 0x101FD}, {0x102E0, 0x102E0}, {0x10376, 0x1037A}, {0x10A01, 0x10A03}, {0x10A05, 0x10A06}, {0x10A0C, 0x10A0F}, {0x10A38, 0x10A3A}, {0x10A3F, 0x10A3F}, +{0x10AE5, 0x10AE6}, {0x10D24, 0x10D27}, {0x10EAB, 0x10EAC}, {0x10F46, 0x10F50}, {0x11000, 0x11002}, {0x11038, 0x11046}, {0x1107F, 0x11082}, {0x110B0, 0x110BA}, {0x11100, 0x11102}, {0x11127, 0x11134}, +{0x11145, 0x11146}, {0x11173, 0x11173}, {0x11180, 0x11182}, {0x111B3, 0x111C0}, {0x111C9, 0x111CC}, {0x111CE, 0x111CF}, {0x1122C, 0x11237}, {0x1123E, 0x1123E}, {0x112DF, 0x112EA}, {0x11300, 0x11303}, +{0x1133B, 0x1133C}, {0x1133E, 0x11344}, {0x11347, 0x11348}, {0x1134B, 0x1134D}, {0x11357, 0x11357}, {0x11362, 0x11363}, {0x11366, 0x1136C}, {0x11370, 0x11374}, {0x11435, 0x11446}, {0x1145E, 0x1145E}, +{0x114B0, 0x114C3}, {0x115AF, 0x115B5}, {0x115B8, 0x115C0}, {0x115DC, 0x115DD}, {0x11630, 0x11640}, {0x116AB, 0x116B7}, {0x1171D, 0x1172B}, {0x1182C, 0x1183A}, {0x11930, 0x11935}, {0x11937, 0x11938}, +{0x1193B, 0x1193E}, {0x11940, 0x11940}, {0x11942, 0x11943}, {0x119D1, 0x119D7}, {0x119DA, 0x119E0}, {0x119E4, 0x119E4}, {0x11A01, 0x11A0A}, {0x11A33, 0x11A39}, {0x11A3B, 0x11A3E}, {0x11A47, 0x11A47}, +{0x11A51, 0x11A5B}, {0x11A8A, 0x11A99}, {0x11C2F, 0x11C36}, {0x11C38, 0x11C3F}, {0x11C92, 0x11CA7}, {0x11CA9, 0x11CB6}, {0x11D31, 0x11D36}, {0x11D3A, 0x11D3A}, {0x11D3C, 0x11D3D}, {0x11D3F, 0x11D45}, +{0x11D47, 0x11D47}, {0x11D8A, 0x11D8E}, {0x11D90, 0x11D91}, {0x11D93, 0x11D97}, {0x11EF3, 0x11EF6}, {0x16AF0, 0x16AF4}, {0x16B30, 0x16B36}, {0x16F4F, 0x16F4F}, {0x16F51, 0x16F87}, {0x16F8F, 0x16F92}, +{0x16FE4, 0x16FE4}, {0x16FF0, 0x16FF1}, {0x1BC9D, 0x1BC9E}, {0x1D165, 0x1D169}, {0x1D16D, 0x1D172}, {0x1D17B, 0x1D182}, {0x1D185, 0x1D18B}, {0x1D1AA, 0x1D1AD}, {0x1D242, 0x1D244}, {0x1DA00, 0x1DA36}, +{0x1DA3B, 0x1DA6C}, {0x1DA75, 0x1DA75}, {0x1DA84, 0x1DA84}, {0x1DA9B, 0x1DA9F}, {0x1DAA1, 0x1DAAF}, {0x1E000, 0x1E006}, {0x1E008, 0x1E018}, {0x1E01B, 0x1E021}, {0x1E023, 0x1E024}, {0x1E026, 0x1E02A}, +{0x1E130, 0x1E136}, {0x1E2EC, 0x1E2EF}, {0x1E8D0, 0x1E8D6}, {0x1E944, 0x1E94A}, {0xE0100, 0xE01EF}, +}; + +static const std::vector> punctuation_ranges = { +{0x21, 0x23}, {0x25, 0x2A}, {0x2C, 0x2F}, {0x3A, 0x3B}, {0x3F, 0x40}, {0x5B, 0x5D}, {0x5F, 0x5F}, {0x7B, 0x7B}, {0x7D, 0x7D}, {0xA1, 0xA1}, {0xA7, 0xA7}, {0xAB, 0xAB}, {0xB6, 0xB7}, {0xBB, 0xBB}, +{0xBF, 0xBF}, {0x37E, 0x37E}, {0x387, 0x387}, {0x55A, 0x55F}, {0x589, 0x58A}, {0x5BE, 0x5BE}, {0x5C0, 0x5C0}, {0x5C3, 0x5C3}, {0x5C6, 0x5C6}, {0x5F3, 0x5F4}, {0x609, 0x60A}, {0x60C, 0x60D}, +{0x61B, 0x61B}, {0x61E, 0x61F}, {0x66A, 0x66D}, {0x6D4, 0x6D4}, {0x700, 0x70D}, {0x7F7, 0x7F9}, {0x830, 0x83E}, {0x85E, 0x85E}, {0x964, 0x965}, {0x970, 0x970}, {0x9FD, 0x9FD}, {0xA76, 0xA76}, +{0xAF0, 0xAF0}, {0xC77, 0xC77}, {0xC84, 0xC84}, {0xDF4, 0xDF4}, {0xE4F, 0xE4F}, {0xE5A, 0xE5B}, {0xF04, 0xF12}, {0xF14, 0xF14}, {0xF3A, 0xF3D}, {0xF85, 0xF85}, {0xFD0, 0xFD4}, {0xFD9, 0xFDA}, +{0x104A, 0x104F}, {0x10FB, 0x10FB}, {0x1360, 0x1368}, {0x1400, 0x1400}, {0x166E, 0x166E}, {0x169B, 0x169C}, {0x16EB, 0x16ED}, {0x1735, 0x1736}, {0x17D4, 0x17D6}, {0x17D8, 0x17DA}, {0x1800, 0x180A}, +{0x1944, 0x1945}, {0x1A1E, 0x1A1F}, {0x1AA0, 0x1AA6}, {0x1AA8, 0x1AAD}, {0x1B5A, 0x1B60}, {0x1BFC, 0x1BFF}, {0x1C3B, 0x1C3F}, {0x1C7E, 0x1C7F}, {0x1CC0, 0x1CC7}, {0x1CD3, 0x1CD3}, {0x2010, 0x2027}, +{0x2030, 0x2043}, {0x2045, 0x2051}, {0x2053, 0x205E}, {0x207D, 0x207E}, {0x208D, 0x208E}, {0x2308, 0x230B}, {0x2329, 0x232A}, {0x2768, 0x2775}, {0x27C5, 0x27C6}, {0x27E6, 0x27EF}, {0x2983, 0x2998}, +{0x29D8, 0x29DB}, {0x29FC, 0x29FD}, {0x2CF9, 0x2CFC}, {0x2CFE, 0x2CFF}, {0x2D70, 0x2D70}, {0x2E00, 0x2E2E}, {0x2E30, 0x2E4F}, {0x2E52, 0x2E52}, {0x3001, 0x3003}, {0x3008, 0x3011}, {0x3014, 0x301F}, +{0x3030, 0x3030}, {0x303D, 0x303D}, {0x30A0, 0x30A0}, {0x30FB, 0x30FB}, {0xA4FE, 0xA4FF}, {0xA60D, 0xA60F}, {0xA673, 0xA673}, {0xA67E, 0xA67E}, {0xA6F2, 0xA6F7}, {0xA874, 0xA877}, {0xA8CE, 0xA8CF}, +{0xA8F8, 0xA8FA}, {0xA8FC, 0xA8FC}, {0xA92E, 0xA92F}, {0xA95F, 0xA95F}, {0xA9C1, 0xA9CD}, {0xA9DE, 0xA9DF}, {0xAA5C, 0xAA5F}, {0xAADE, 0xAADF}, {0xAAF0, 0xAAF1}, {0xABEB, 0xABEB}, {0xFD3E, 0xFD3F}, +{0xFE10, 0xFE19}, {0xFE30, 0xFE52}, {0xFE54, 0xFE61}, {0xFE63, 0xFE63}, {0xFE68, 0xFE68}, {0xFE6A, 0xFE6B}, {0xFF01, 0xFF03}, {0xFF05, 0xFF0A}, {0xFF0C, 0xFF0F}, {0xFF1A, 0xFF1B}, {0xFF1F, 0xFF20}, +{0xFF3B, 0xFF3D}, {0xFF3F, 0xFF3F}, {0xFF5B, 0xFF5B}, {0xFF5D, 0xFF5D}, {0xFF5F, 0xFF65}, {0x10100, 0x10102}, {0x1039F, 0x1039F}, {0x103D0, 0x103D0}, {0x1056F, 0x1056F}, {0x10857, 0x10857}, +{0x1091F, 0x1091F}, {0x1093F, 0x1093F}, {0x10A50, 0x10A58}, {0x10A7F, 0x10A7F}, {0x10AF0, 0x10AF6}, {0x10B39, 0x10B3F}, {0x10B99, 0x10B9C}, {0x10EAD, 0x10EAD}, {0x10F55, 0x10F59}, {0x11047, 0x1104D}, +{0x110BB, 0x110BC}, {0x110BE, 0x110C1}, {0x11140, 0x11143}, {0x11174, 0x11175}, {0x111C5, 0x111C8}, {0x111CD, 0x111CD}, {0x111DB, 0x111DB}, {0x111DD, 0x111DF}, {0x11238, 0x1123D}, {0x112A9, 0x112A9}, +{0x1144B, 0x1144F}, {0x1145A, 0x1145B}, {0x1145D, 0x1145D}, {0x114C6, 0x114C6}, {0x115C1, 0x115D7}, {0x11641, 0x11643}, {0x11660, 0x1166C}, {0x1173C, 0x1173E}, {0x1183B, 0x1183B}, {0x11944, 0x11946}, +{0x119E2, 0x119E2}, {0x11A3F, 0x11A46}, {0x11A9A, 0x11A9C}, {0x11A9E, 0x11AA2}, {0x11C41, 0x11C45}, {0x11C70, 0x11C71}, {0x11EF7, 0x11EF8}, {0x11FFF, 0x11FFF}, {0x12470, 0x12474}, {0x16A6E, 0x16A6F}, +{0x16AF5, 0x16AF5}, {0x16B37, 0x16B3B}, {0x16B44, 0x16B44}, {0x16E97, 0x16E9A}, {0x16FE2, 0x16FE2}, {0x1BC9F, 0x1BC9F}, {0x1DA87, 0x1DA8B}, {0x1E95E, 0x1E95F}, +}; + +static const std::vector> symbol_ranges = { +{0x24, 0x24}, {0x2B, 0x2B}, {0x3C, 0x3E}, {0x5E, 0x5E}, {0x60, 0x60}, {0x7C, 0x7C}, {0x7E, 0x7E}, {0xA2, 0xA6}, {0xA8, 0xA9}, {0xAC, 0xAC}, {0xAE, 0xB1}, {0xB4, 0xB4}, {0xB8, 0xB8}, {0xD7, 0xD7}, +{0xF7, 0xF7}, {0x2C2, 0x2C5}, {0x2D2, 0x2DF}, {0x2E5, 0x2EB}, {0x2ED, 0x2ED}, {0x2EF, 0x2FF}, {0x375, 0x375}, {0x384, 0x385}, {0x3F6, 0x3F6}, {0x482, 0x482}, {0x58D, 0x58F}, {0x606, 0x608}, +{0x60B, 0x60B}, {0x60E, 0x60F}, {0x6DE, 0x6DE}, {0x6E9, 0x6E9}, {0x6FD, 0x6FE}, {0x7F6, 0x7F6}, {0x7FE, 0x7FF}, {0x9F2, 0x9F3}, {0x9FA, 0x9FB}, {0xAF1, 0xAF1}, {0xB70, 0xB70}, {0xBF3, 0xBFA}, +{0xC7F, 0xC7F}, {0xD4F, 0xD4F}, {0xD79, 0xD79}, {0xE3F, 0xE3F}, {0xF01, 0xF03}, {0xF13, 0xF13}, {0xF15, 0xF17}, {0xF1A, 0xF1F}, {0xF34, 0xF34}, {0xF36, 0xF36}, {0xF38, 0xF38}, {0xFBE, 0xFC5}, +{0xFC7, 0xFCC}, {0xFCE, 0xFCF}, {0xFD5, 0xFD8}, {0x109E, 0x109F}, {0x1390, 0x1399}, {0x166D, 0x166D}, {0x17DB, 0x17DB}, {0x1940, 0x1940}, {0x19DE, 0x19FF}, {0x1B61, 0x1B6A}, {0x1B74, 0x1B7C}, +{0x1FBD, 0x1FBD}, {0x1FBF, 0x1FC1}, {0x1FCD, 0x1FCF}, {0x1FDD, 0x1FDF}, {0x1FED, 0x1FEF}, {0x1FFD, 0x1FFE}, {0x2044, 0x2044}, {0x2052, 0x2052}, {0x207A, 0x207C}, {0x208A, 0x208C}, {0x20A0, 0x20BF}, +{0x2100, 0x2101}, {0x2103, 0x2106}, {0x2108, 0x2109}, {0x2114, 0x2114}, {0x2116, 0x2118}, {0x211E, 0x2123}, {0x2125, 0x2125}, {0x2127, 0x2127}, {0x2129, 0x2129}, {0x212E, 0x212E}, {0x213A, 0x213B}, +{0x2140, 0x2144}, {0x214A, 0x214D}, {0x214F, 0x214F}, {0x218A, 0x218B}, {0x2190, 0x2307}, {0x230C, 0x2328}, {0x232B, 0x2426}, {0x2440, 0x244A}, {0x249C, 0x24E9}, {0x2500, 0x2767}, {0x2794, 0x27C4}, +{0x27C7, 0x27E5}, {0x27F0, 0x2982}, {0x2999, 0x29D7}, {0x29DC, 0x29FB}, {0x29FE, 0x2B73}, {0x2B76, 0x2B95}, {0x2B97, 0x2BFF}, {0x2CE5, 0x2CEA}, {0x2E50, 0x2E51}, {0x2E80, 0x2E99}, {0x2E9B, 0x2EF3}, +{0x2F00, 0x2FD5}, {0x2FF0, 0x2FFB}, {0x3004, 0x3004}, {0x3012, 0x3013}, {0x3020, 0x3020}, {0x3036, 0x3037}, {0x303E, 0x303F}, {0x309B, 0x309C}, {0x3190, 0x3191}, {0x3196, 0x319F}, {0x31C0, 0x31E3}, +{0x3200, 0x321E}, {0x322A, 0x3247}, {0x3250, 0x3250}, {0x3260, 0x327F}, {0x328A, 0x32B0}, {0x32C0, 0x33FF}, {0x4DC0, 0x4DFF}, {0xA490, 0xA4C6}, {0xA700, 0xA716}, {0xA720, 0xA721}, {0xA789, 0xA78A}, +{0xA828, 0xA82B}, {0xA836, 0xA839}, {0xAA77, 0xAA79}, {0xAB5B, 0xAB5B}, {0xAB6A, 0xAB6B}, {0xFB29, 0xFB29}, {0xFBB2, 0xFBC1}, {0xFDFC, 0xFDFD}, {0xFE62, 0xFE62}, {0xFE64, 0xFE66}, {0xFE69, 0xFE69}, +{0xFF04, 0xFF04}, {0xFF0B, 0xFF0B}, {0xFF1C, 0xFF1E}, {0xFF3E, 0xFF3E}, {0xFF40, 0xFF40}, {0xFF5C, 0xFF5C}, {0xFF5E, 0xFF5E}, {0xFFE0, 0xFFE6}, {0xFFE8, 0xFFEE}, {0xFFFC, 0xFFFD}, {0x10137, 0x1013F}, +{0x10179, 0x10189}, {0x1018C, 0x1018E}, {0x10190, 0x1019C}, {0x101A0, 0x101A0}, {0x101D0, 0x101FC}, {0x10877, 0x10878}, {0x10AC8, 0x10AC8}, {0x1173F, 0x1173F}, {0x11FD5, 0x11FF1}, {0x16B3C, 0x16B3F}, +{0x16B45, 0x16B45}, {0x1BC9C, 0x1BC9C}, {0x1D000, 0x1D0F5}, {0x1D100, 0x1D126}, {0x1D129, 0x1D164}, {0x1D16A, 0x1D16C}, {0x1D183, 0x1D184}, {0x1D18C, 0x1D1A9}, {0x1D1AE, 0x1D1E8}, {0x1D200, 0x1D241}, +{0x1D245, 0x1D245}, {0x1D300, 0x1D356}, {0x1D6C1, 0x1D6C1}, {0x1D6DB, 0x1D6DB}, {0x1D6FB, 0x1D6FB}, {0x1D715, 0x1D715}, {0x1D735, 0x1D735}, {0x1D74F, 0x1D74F}, {0x1D76F, 0x1D76F}, {0x1D789, 0x1D789}, +{0x1D7A9, 0x1D7A9}, {0x1D7C3, 0x1D7C3}, {0x1D800, 0x1D9FF}, {0x1DA37, 0x1DA3A}, {0x1DA6D, 0x1DA74}, {0x1DA76, 0x1DA83}, {0x1DA85, 0x1DA86}, {0x1E14F, 0x1E14F}, {0x1E2FF, 0x1E2FF}, {0x1ECAC, 0x1ECAC}, +{0x1ECB0, 0x1ECB0}, {0x1ED2E, 0x1ED2E}, {0x1EEF0, 0x1EEF1}, {0x1F000, 0x1F02B}, {0x1F030, 0x1F093}, {0x1F0A0, 0x1F0AE}, {0x1F0B1, 0x1F0BF}, {0x1F0C1, 0x1F0CF}, {0x1F0D1, 0x1F0F5}, {0x1F10D, 0x1F1AD}, +{0x1F1E6, 0x1F202}, {0x1F210, 0x1F23B}, {0x1F240, 0x1F248}, {0x1F250, 0x1F251}, {0x1F260, 0x1F265}, {0x1F300, 0x1F6D7}, {0x1F6E0, 0x1F6EC}, {0x1F6F0, 0x1F6FC}, {0x1F700, 0x1F773}, {0x1F780, 0x1F7D8}, +{0x1F7E0, 0x1F7EB}, {0x1F800, 0x1F80B}, {0x1F810, 0x1F847}, {0x1F850, 0x1F859}, {0x1F860, 0x1F887}, {0x1F890, 0x1F8AD}, {0x1F8B0, 0x1F8B1}, {0x1F900, 0x1F978}, {0x1F97A, 0x1F9CB}, {0x1F9CD, 0x1FA53}, +{0x1FA60, 0x1FA6D}, {0x1FA70, 0x1FA74}, {0x1FA78, 0x1FA7A}, {0x1FA80, 0x1FA86}, {0x1FA90, 0x1FAA8}, {0x1FAB0, 0x1FAB6}, {0x1FAC0, 0x1FAC2}, {0x1FAD0, 0x1FAD6}, {0x1FB00, 0x1FB92}, {0x1FB94, 0x1FBCA}, +}; + +static const std::vector> control_ranges = { +{0x0, 0x8}, {0xE, 0x1B}, {0x7F, 0x84}, {0x86, 0x9F}, {0xAD, 0xAD}, {0x378, 0x379}, {0x380, 0x383}, {0x38B, 0x38B}, {0x38D, 0x38D}, {0x3A2, 0x3A2}, {0x530, 0x530}, {0x557, 0x558}, {0x58B, 0x58C}, +{0x590, 0x590}, {0x5C8, 0x5CF}, {0x5EB, 0x5EE}, {0x5F5, 0x605}, {0x61C, 0x61D}, {0x6DD, 0x6DD}, {0x70E, 0x70F}, {0x74B, 0x74C}, {0x7B2, 0x7BF}, {0x7FB, 0x7FC}, {0x82E, 0x82F}, {0x83F, 0x83F}, +{0x85C, 0x85D}, {0x85F, 0x85F}, {0x86B, 0x89F}, {0x8B5, 0x8B5}, {0x8C8, 0x8D2}, {0x8E2, 0x8E2}, {0x984, 0x984}, {0x98D, 0x98E}, {0x991, 0x992}, {0x9A9, 0x9A9}, {0x9B1, 0x9B1}, {0x9B3, 0x9B5}, +{0x9BA, 0x9BB}, {0x9C5, 0x9C6}, {0x9C9, 0x9CA}, {0x9CF, 0x9D6}, {0x9D8, 0x9DB}, {0x9DE, 0x9DE}, {0x9E4, 0x9E5}, {0x9FF, 0xA00}, {0xA04, 0xA04}, {0xA0B, 0xA0E}, {0xA11, 0xA12}, {0xA29, 0xA29}, +{0xA31, 0xA31}, {0xA34, 0xA34}, {0xA37, 0xA37}, {0xA3A, 0xA3B}, {0xA3D, 0xA3D}, {0xA43, 0xA46}, {0xA49, 0xA4A}, {0xA4E, 0xA50}, {0xA52, 0xA58}, {0xA5D, 0xA5D}, {0xA5F, 0xA65}, {0xA77, 0xA80}, +{0xA84, 0xA84}, {0xA8E, 0xA8E}, {0xA92, 0xA92}, {0xAA9, 0xAA9}, {0xAB1, 0xAB1}, {0xAB4, 0xAB4}, {0xABA, 0xABB}, {0xAC6, 0xAC6}, {0xACA, 0xACA}, {0xACE, 0xACF}, {0xAD1, 0xADF}, {0xAE4, 0xAE5}, +{0xAF2, 0xAF8}, {0xB00, 0xB00}, {0xB04, 0xB04}, {0xB0D, 0xB0E}, {0xB11, 0xB12}, {0xB29, 0xB29}, {0xB31, 0xB31}, {0xB34, 0xB34}, {0xB3A, 0xB3B}, {0xB45, 0xB46}, {0xB49, 0xB4A}, {0xB4E, 0xB54}, +{0xB58, 0xB5B}, {0xB5E, 0xB5E}, {0xB64, 0xB65}, {0xB78, 0xB81}, {0xB84, 0xB84}, {0xB8B, 0xB8D}, {0xB91, 0xB91}, {0xB96, 0xB98}, {0xB9B, 0xB9B}, {0xB9D, 0xB9D}, {0xBA0, 0xBA2}, {0xBA5, 0xBA7}, +{0xBAB, 0xBAD}, {0xBBA, 0xBBD}, {0xBC3, 0xBC5}, {0xBC9, 0xBC9}, {0xBCE, 0xBCF}, {0xBD1, 0xBD6}, {0xBD8, 0xBE5}, {0xBFB, 0xBFF}, {0xC0D, 0xC0D}, {0xC11, 0xC11}, {0xC29, 0xC29}, {0xC3A, 0xC3C}, +{0xC45, 0xC45}, {0xC49, 0xC49}, {0xC4E, 0xC54}, {0xC57, 0xC57}, {0xC5B, 0xC5F}, {0xC64, 0xC65}, {0xC70, 0xC76}, {0xC8D, 0xC8D}, {0xC91, 0xC91}, {0xCA9, 0xCA9}, {0xCB4, 0xCB4}, {0xCBA, 0xCBB}, +{0xCC5, 0xCC5}, {0xCC9, 0xCC9}, {0xCCE, 0xCD4}, {0xCD7, 0xCDD}, {0xCDF, 0xCDF}, {0xCE4, 0xCE5}, {0xCF0, 0xCF0}, {0xCF3, 0xCFF}, {0xD0D, 0xD0D}, {0xD11, 0xD11}, {0xD45, 0xD45}, {0xD49, 0xD49}, +{0xD50, 0xD53}, {0xD64, 0xD65}, {0xD80, 0xD80}, {0xD84, 0xD84}, {0xD97, 0xD99}, {0xDB2, 0xDB2}, {0xDBC, 0xDBC}, {0xDBE, 0xDBF}, {0xDC7, 0xDC9}, {0xDCB, 0xDCE}, {0xDD5, 0xDD5}, {0xDD7, 0xDD7}, +{0xDE0, 0xDE5}, {0xDF0, 0xDF1}, {0xDF5, 0xE00}, {0xE3B, 0xE3E}, {0xE5C, 0xE80}, {0xE83, 0xE83}, {0xE85, 0xE85}, {0xE8B, 0xE8B}, {0xEA4, 0xEA4}, {0xEA6, 0xEA6}, {0xEBE, 0xEBF}, {0xEC5, 0xEC5}, +{0xEC7, 0xEC7}, {0xECE, 0xECF}, {0xEDA, 0xEDB}, {0xEE0, 0xEFF}, {0xF48, 0xF48}, {0xF6D, 0xF70}, {0xF98, 0xF98}, {0xFBD, 0xFBD}, {0xFCD, 0xFCD}, {0xFDB, 0xFFF}, {0x10C6, 0x10C6}, {0x10C8, 0x10CC}, +{0x10CE, 0x10CF}, {0x1249, 0x1249}, {0x124E, 0x124F}, {0x1257, 0x1257}, {0x1259, 0x1259}, {0x125E, 0x125F}, {0x1289, 0x1289}, {0x128E, 0x128F}, {0x12B1, 0x12B1}, {0x12B6, 0x12B7}, {0x12BF, 0x12BF}, +{0x12C1, 0x12C1}, {0x12C6, 0x12C7}, {0x12D7, 0x12D7}, {0x1311, 0x1311}, {0x1316, 0x1317}, {0x135B, 0x135C}, {0x137D, 0x137F}, {0x139A, 0x139F}, {0x13F6, 0x13F7}, {0x13FE, 0x13FF}, {0x169D, 0x169F}, +{0x16F9, 0x16FF}, {0x170D, 0x170D}, {0x1715, 0x171F}, {0x1737, 0x173F}, {0x1754, 0x175F}, {0x176D, 0x176D}, {0x1771, 0x1771}, {0x1774, 0x177F}, {0x17DE, 0x17DF}, {0x17EA, 0x17EF}, {0x17FA, 0x17FF}, +{0x180E, 0x180F}, {0x181A, 0x181F}, {0x1879, 0x187F}, {0x18AB, 0x18AF}, {0x18F6, 0x18FF}, {0x191F, 0x191F}, {0x192C, 0x192F}, {0x193C, 0x193F}, {0x1941, 0x1943}, {0x196E, 0x196F}, {0x1975, 0x197F}, +{0x19AC, 0x19AF}, {0x19CA, 0x19CF}, {0x19DB, 0x19DD}, {0x1A1C, 0x1A1D}, {0x1A5F, 0x1A5F}, {0x1A7D, 0x1A7E}, {0x1A8A, 0x1A8F}, {0x1A9A, 0x1A9F}, {0x1AAE, 0x1AAF}, {0x1AC1, 0x1AFF}, {0x1B4C, 0x1B4F}, +{0x1B7D, 0x1B7F}, {0x1BF4, 0x1BFB}, {0x1C38, 0x1C3A}, {0x1C4A, 0x1C4C}, {0x1C89, 0x1C8F}, {0x1CBB, 0x1CBC}, {0x1CC8, 0x1CCF}, {0x1CFB, 0x1CFF}, {0x1DFA, 0x1DFA}, {0x1F16, 0x1F17}, {0x1F1E, 0x1F1F}, +{0x1F46, 0x1F47}, {0x1F4E, 0x1F4F}, {0x1F58, 0x1F58}, {0x1F5A, 0x1F5A}, {0x1F5C, 0x1F5C}, {0x1F5E, 0x1F5E}, {0x1F7E, 0x1F7F}, {0x1FB5, 0x1FB5}, {0x1FC5, 0x1FC5}, {0x1FD4, 0x1FD5}, {0x1FDC, 0x1FDC}, +{0x1FF0, 0x1FF1}, {0x1FF5, 0x1FF5}, {0x1FFF, 0x1FFF}, {0x200B, 0x200F}, {0x202A, 0x202E}, {0x2060, 0x206F}, {0x2072, 0x2073}, {0x208F, 0x208F}, {0x209D, 0x209F}, {0x20C0, 0x20CF}, {0x20F1, 0x20FF}, +{0x218C, 0x218F}, {0x2427, 0x243F}, {0x244B, 0x245F}, {0x2B74, 0x2B75}, {0x2B96, 0x2B96}, {0x2C2F, 0x2C2F}, {0x2C5F, 0x2C5F}, {0x2CF4, 0x2CF8}, {0x2D26, 0x2D26}, {0x2D28, 0x2D2C}, {0x2D2E, 0x2D2F}, +{0x2D68, 0x2D6E}, {0x2D71, 0x2D7E}, {0x2D97, 0x2D9F}, {0x2DA7, 0x2DA7}, {0x2DAF, 0x2DAF}, {0x2DB7, 0x2DB7}, {0x2DBF, 0x2DBF}, {0x2DC7, 0x2DC7}, {0x2DCF, 0x2DCF}, {0x2DD7, 0x2DD7}, {0x2DDF, 0x2DDF}, +{0x2E53, 0x2E7F}, {0x2E9A, 0x2E9A}, {0x2EF4, 0x2EFF}, {0x2FD6, 0x2FEF}, {0x2FFC, 0x2FFF}, {0x3040, 0x3040}, {0x3097, 0x3098}, {0x3100, 0x3104}, {0x3130, 0x3130}, {0x318F, 0x318F}, {0x31E4, 0x31EF}, +{0x321F, 0x321F}, {0x9FFD, 0x9FFF}, {0xA48D, 0xA48F}, {0xA4C7, 0xA4CF}, {0xA62C, 0xA63F}, {0xA6F8, 0xA6FF}, {0xA7C0, 0xA7C1}, {0xA7CB, 0xA7F4}, {0xA82D, 0xA82F}, {0xA83A, 0xA83F}, {0xA878, 0xA87F}, +{0xA8C6, 0xA8CD}, {0xA8DA, 0xA8DF}, {0xA954, 0xA95E}, {0xA97D, 0xA97F}, {0xA9CE, 0xA9CE}, {0xA9DA, 0xA9DD}, {0xA9FF, 0xA9FF}, {0xAA37, 0xAA3F}, {0xAA4E, 0xAA4F}, {0xAA5A, 0xAA5B}, {0xAAC3, 0xAADA}, +{0xAAF7, 0xAB00}, {0xAB07, 0xAB08}, {0xAB0F, 0xAB10}, {0xAB17, 0xAB1F}, {0xAB27, 0xAB27}, {0xAB2F, 0xAB2F}, {0xAB6C, 0xAB6F}, {0xABEE, 0xABEF}, {0xABFA, 0xABFF}, {0xD7A4, 0xD7AF}, {0xD7C7, 0xD7CA}, +{0xD7FC, 0xF8FF}, {0xFA6E, 0xFA6F}, {0xFADA, 0xFAFF}, {0xFB07, 0xFB12}, {0xFB18, 0xFB1C}, {0xFB37, 0xFB37}, {0xFB3D, 0xFB3D}, {0xFB3F, 0xFB3F}, {0xFB42, 0xFB42}, {0xFB45, 0xFB45}, {0xFBC2, 0xFBD2}, +{0xFD40, 0xFD4F}, {0xFD90, 0xFD91}, {0xFDC8, 0xFDEF}, {0xFDFE, 0xFDFF}, {0xFE1A, 0xFE1F}, {0xFE53, 0xFE53}, {0xFE67, 0xFE67}, {0xFE6C, 0xFE6F}, {0xFE75, 0xFE75}, {0xFEFD, 0xFF00}, {0xFFBF, 0xFFC1}, +{0xFFC8, 0xFFC9}, {0xFFD0, 0xFFD1}, {0xFFD8, 0xFFD9}, {0xFFDD, 0xFFDF}, {0xFFE7, 0xFFE7}, {0xFFEF, 0xFFFB}, {0xFFFE, 0xFFFF}, {0x1000C, 0x1000C}, {0x10027, 0x10027}, {0x1003B, 0x1003B}, +{0x1003E, 0x1003E}, {0x1004E, 0x1004F}, {0x1005E, 0x1007F}, {0x100FB, 0x100FF}, {0x10103, 0x10106}, {0x10134, 0x10136}, {0x1018F, 0x1018F}, {0x1019D, 0x1019F}, {0x101A1, 0x101CF}, {0x101FE, 0x1027F}, +{0x1029D, 0x1029F}, {0x102D1, 0x102DF}, {0x102FC, 0x102FF}, {0x10324, 0x1032C}, {0x1034B, 0x1034F}, {0x1037B, 0x1037F}, {0x1039E, 0x1039E}, {0x103C4, 0x103C7}, {0x103D6, 0x103FF}, {0x1049E, 0x1049F}, +{0x104AA, 0x104AF}, {0x104D4, 0x104D7}, {0x104FC, 0x104FF}, {0x10528, 0x1052F}, {0x10564, 0x1056E}, {0x10570, 0x105FF}, {0x10737, 0x1073F}, {0x10756, 0x1075F}, {0x10768, 0x107FF}, {0x10806, 0x10807}, +{0x10809, 0x10809}, {0x10836, 0x10836}, {0x10839, 0x1083B}, {0x1083D, 0x1083E}, {0x10856, 0x10856}, {0x1089F, 0x108A6}, {0x108B0, 0x108DF}, {0x108F3, 0x108F3}, {0x108F6, 0x108FA}, {0x1091C, 0x1091E}, +{0x1093A, 0x1093E}, {0x10940, 0x1097F}, {0x109B8, 0x109BB}, {0x109D0, 0x109D1}, {0x10A04, 0x10A04}, {0x10A07, 0x10A0B}, {0x10A14, 0x10A14}, {0x10A18, 0x10A18}, {0x10A36, 0x10A37}, {0x10A3B, 0x10A3E}, +{0x10A49, 0x10A4F}, {0x10A59, 0x10A5F}, {0x10AA0, 0x10ABF}, {0x10AE7, 0x10AEA}, {0x10AF7, 0x10AFF}, {0x10B36, 0x10B38}, {0x10B56, 0x10B57}, {0x10B73, 0x10B77}, {0x10B92, 0x10B98}, {0x10B9D, 0x10BA8}, +{0x10BB0, 0x10BFF}, {0x10C49, 0x10C7F}, {0x10CB3, 0x10CBF}, {0x10CF3, 0x10CF9}, {0x10D28, 0x10D2F}, {0x10D3A, 0x10E5F}, {0x10E7F, 0x10E7F}, {0x10EAA, 0x10EAA}, {0x10EAE, 0x10EAF}, {0x10EB2, 0x10EFF}, +{0x10F28, 0x10F2F}, {0x10F5A, 0x10FAF}, {0x10FCC, 0x10FDF}, {0x10FF7, 0x10FFF}, {0x1104E, 0x11051}, {0x11070, 0x1107E}, {0x110BD, 0x110BD}, {0x110C2, 0x110CF}, {0x110E9, 0x110EF}, {0x110FA, 0x110FF}, +{0x11135, 0x11135}, {0x11148, 0x1114F}, {0x11177, 0x1117F}, {0x111E0, 0x111E0}, {0x111F5, 0x111FF}, {0x11212, 0x11212}, {0x1123F, 0x1127F}, {0x11287, 0x11287}, {0x11289, 0x11289}, {0x1128E, 0x1128E}, +{0x1129E, 0x1129E}, {0x112AA, 0x112AF}, {0x112EB, 0x112EF}, {0x112FA, 0x112FF}, {0x11304, 0x11304}, {0x1130D, 0x1130E}, {0x11311, 0x11312}, {0x11329, 0x11329}, {0x11331, 0x11331}, {0x11334, 0x11334}, +{0x1133A, 0x1133A}, {0x11345, 0x11346}, {0x11349, 0x1134A}, {0x1134E, 0x1134F}, {0x11351, 0x11356}, {0x11358, 0x1135C}, {0x11364, 0x11365}, {0x1136D, 0x1136F}, {0x11375, 0x113FF}, {0x1145C, 0x1145C}, +{0x11462, 0x1147F}, {0x114C8, 0x114CF}, {0x114DA, 0x1157F}, {0x115B6, 0x115B7}, {0x115DE, 0x115FF}, {0x11645, 0x1164F}, {0x1165A, 0x1165F}, {0x1166D, 0x1167F}, {0x116B9, 0x116BF}, {0x116CA, 0x116FF}, +{0x1171B, 0x1171C}, {0x1172C, 0x1172F}, {0x11740, 0x117FF}, {0x1183C, 0x1189F}, {0x118F3, 0x118FE}, {0x11907, 0x11908}, {0x1190A, 0x1190B}, {0x11914, 0x11914}, {0x11917, 0x11917}, {0x11936, 0x11936}, +{0x11939, 0x1193A}, {0x11947, 0x1194F}, {0x1195A, 0x1199F}, {0x119A8, 0x119A9}, {0x119D8, 0x119D9}, {0x119E5, 0x119FF}, {0x11A48, 0x11A4F}, {0x11AA3, 0x11ABF}, {0x11AF9, 0x11BFF}, {0x11C09, 0x11C09}, +{0x11C37, 0x11C37}, {0x11C46, 0x11C4F}, {0x11C6D, 0x11C6F}, {0x11C90, 0x11C91}, {0x11CA8, 0x11CA8}, {0x11CB7, 0x11CFF}, {0x11D07, 0x11D07}, {0x11D0A, 0x11D0A}, {0x11D37, 0x11D39}, {0x11D3B, 0x11D3B}, +{0x11D3E, 0x11D3E}, {0x11D48, 0x11D4F}, {0x11D5A, 0x11D5F}, {0x11D66, 0x11D66}, {0x11D69, 0x11D69}, {0x11D8F, 0x11D8F}, {0x11D92, 0x11D92}, {0x11D99, 0x11D9F}, {0x11DAA, 0x11EDF}, {0x11EF9, 0x11FAF}, +{0x11FB1, 0x11FBF}, {0x11FF2, 0x11FFE}, {0x1239A, 0x123FF}, {0x1246F, 0x1246F}, {0x12475, 0x1247F}, {0x12544, 0x12FFF}, {0x1342F, 0x143FF}, {0x14647, 0x167FF}, {0x16A39, 0x16A3F}, {0x16A5F, 0x16A5F}, +{0x16A6A, 0x16A6D}, {0x16A70, 0x16ACF}, {0x16AEE, 0x16AEF}, {0x16AF6, 0x16AFF}, {0x16B46, 0x16B4F}, {0x16B5A, 0x16B5A}, {0x16B62, 0x16B62}, {0x16B78, 0x16B7C}, {0x16B90, 0x16E3F}, {0x16E9B, 0x16EFF}, +{0x16F4B, 0x16F4E}, {0x16F88, 0x16F8E}, {0x16FA0, 0x16FDF}, {0x16FE5, 0x16FEF}, {0x16FF2, 0x16FFF}, {0x187F8, 0x187FF}, {0x18CD6, 0x18CFF}, {0x18D09, 0x1AFFF}, {0x1B11F, 0x1B14F}, {0x1B153, 0x1B163}, +{0x1B168, 0x1B16F}, {0x1B2FC, 0x1BBFF}, {0x1BC6B, 0x1BC6F}, {0x1BC7D, 0x1BC7F}, {0x1BC89, 0x1BC8F}, {0x1BC9A, 0x1BC9B}, {0x1BCA0, 0x1CFFF}, {0x1D0F6, 0x1D0FF}, {0x1D127, 0x1D128}, {0x1D173, 0x1D17A}, +{0x1D1E9, 0x1D1FF}, {0x1D246, 0x1D2DF}, {0x1D2F4, 0x1D2FF}, {0x1D357, 0x1D35F}, {0x1D379, 0x1D3FF}, {0x1D455, 0x1D455}, {0x1D49D, 0x1D49D}, {0x1D4A0, 0x1D4A1}, {0x1D4A3, 0x1D4A4}, {0x1D4A7, 0x1D4A8}, +{0x1D4AD, 0x1D4AD}, {0x1D4BA, 0x1D4BA}, {0x1D4BC, 0x1D4BC}, {0x1D4C4, 0x1D4C4}, {0x1D506, 0x1D506}, {0x1D50B, 0x1D50C}, {0x1D515, 0x1D515}, {0x1D51D, 0x1D51D}, {0x1D53A, 0x1D53A}, {0x1D53F, 0x1D53F}, +{0x1D545, 0x1D545}, {0x1D547, 0x1D549}, {0x1D551, 0x1D551}, {0x1D6A6, 0x1D6A7}, {0x1D7CC, 0x1D7CD}, {0x1DA8C, 0x1DA9A}, {0x1DAA0, 0x1DAA0}, {0x1DAB0, 0x1DFFF}, {0x1E007, 0x1E007}, {0x1E019, 0x1E01A}, +{0x1E022, 0x1E022}, {0x1E025, 0x1E025}, {0x1E02B, 0x1E0FF}, {0x1E12D, 0x1E12F}, {0x1E13E, 0x1E13F}, {0x1E14A, 0x1E14D}, {0x1E150, 0x1E2BF}, {0x1E2FA, 0x1E2FE}, {0x1E300, 0x1E7FF}, {0x1E8C5, 0x1E8C6}, +{0x1E8D7, 0x1E8FF}, {0x1E94C, 0x1E94F}, {0x1E95A, 0x1E95D}, {0x1E960, 0x1EC70}, {0x1ECB5, 0x1ED00}, {0x1ED3E, 0x1EDFF}, {0x1EE04, 0x1EE04}, {0x1EE20, 0x1EE20}, {0x1EE23, 0x1EE23}, {0x1EE25, 0x1EE26}, +{0x1EE28, 0x1EE28}, {0x1EE33, 0x1EE33}, {0x1EE38, 0x1EE38}, {0x1EE3A, 0x1EE3A}, {0x1EE3C, 0x1EE41}, {0x1EE43, 0x1EE46}, {0x1EE48, 0x1EE48}, {0x1EE4A, 0x1EE4A}, {0x1EE4C, 0x1EE4C}, {0x1EE50, 0x1EE50}, +{0x1EE53, 0x1EE53}, {0x1EE55, 0x1EE56}, {0x1EE58, 0x1EE58}, {0x1EE5A, 0x1EE5A}, {0x1EE5C, 0x1EE5C}, {0x1EE5E, 0x1EE5E}, {0x1EE60, 0x1EE60}, {0x1EE63, 0x1EE63}, {0x1EE65, 0x1EE66}, {0x1EE6B, 0x1EE6B}, +{0x1EE73, 0x1EE73}, {0x1EE78, 0x1EE78}, {0x1EE7D, 0x1EE7D}, {0x1EE7F, 0x1EE7F}, {0x1EE8A, 0x1EE8A}, {0x1EE9C, 0x1EEA0}, {0x1EEA4, 0x1EEA4}, {0x1EEAA, 0x1EEAA}, {0x1EEBC, 0x1EEEF}, {0x1EEF2, 0x1EFFF}, +{0x1F02C, 0x1F02F}, {0x1F094, 0x1F09F}, {0x1F0AF, 0x1F0B0}, {0x1F0C0, 0x1F0C0}, {0x1F0D0, 0x1F0D0}, {0x1F0F6, 0x1F0FF}, {0x1F1AE, 0x1F1E5}, {0x1F203, 0x1F20F}, {0x1F23C, 0x1F23F}, {0x1F249, 0x1F24F}, +{0x1F252, 0x1F25F}, {0x1F266, 0x1F2FF}, {0x1F6D8, 0x1F6DF}, {0x1F6ED, 0x1F6EF}, {0x1F6FD, 0x1F6FF}, {0x1F774, 0x1F77F}, {0x1F7D9, 0x1F7DF}, {0x1F7EC, 0x1F7FF}, {0x1F80C, 0x1F80F}, {0x1F848, 0x1F84F}, +{0x1F85A, 0x1F85F}, {0x1F888, 0x1F88F}, {0x1F8AE, 0x1F8AF}, {0x1F8B2, 0x1F8FF}, {0x1F979, 0x1F979}, {0x1F9CC, 0x1F9CC}, {0x1FA54, 0x1FA5F}, {0x1FA6E, 0x1FA6F}, {0x1FA75, 0x1FA77}, {0x1FA7B, 0x1FA7F}, +{0x1FA87, 0x1FA8F}, {0x1FAA9, 0x1FAAF}, {0x1FAB7, 0x1FABF}, {0x1FAC3, 0x1FACF}, {0x1FAD7, 0x1FAFF}, {0x1FB93, 0x1FB93}, {0x1FBCB, 0x1FBEF}, {0x1FBFA, 0x1FFFF}, {0x2A6DE, 0x2A6FF}, {0x2B735, 0x2B73F}, +{0x2B81E, 0x2B81F}, {0x2CEA2, 0x2CEAF}, {0x2EBE1, 0x2F7FF}, {0x2FA1E, 0x2FFFF}, {0x3134B, 0xE00FF}, {0xE01F0, 0x10FFFF}, +}; + +static std::string codepoint_to_utf8(uint32_t cp) { + std::string result; + if (/* 0x00 <= cp && */ cp <= 0x7f) { + result.push_back(cp); + } + else if (0x80 <= cp && cp <= 0x7ff) { + result.push_back(0xc0 | ((cp >> 6) & 0x1f)); + result.push_back(0x80 | (cp & 0x3f)); + } + else if (0x800 <= cp && cp <= 0xffff) { + result.push_back(0xe0 | ((cp >> 12) & 0x0f)); + result.push_back(0x80 | ((cp >> 6) & 0x3f)); + result.push_back(0x80 | (cp & 0x3f)); + } + else if (0x10000 <= cp && cp <= 0x10ffff) { + result.push_back(0xf0 | ((cp >> 18) & 0x07)); + result.push_back(0x80 | ((cp >> 12) & 0x3f)); + result.push_back(0x80 | ((cp >> 6) & 0x3f)); + result.push_back(0x80 | (cp & 0x3f)); + } + else { + throw std::invalid_argument("invalid codepoint"); + } + return result; +} + +static std::string codepoints_to_utf8(const std::vector & cps) { + std::string result; + for (size_t i = 0; i < cps.size(); ++i) { + result.append(codepoint_to_utf8(cps[i])); + } + return result; +} + +static uint32_t codepoint_from_utf8(const std::string & utf8, size_t & offset) { + assert(offset < utf8.size()); + if (!(utf8[offset + 0] & 0x80)) { + auto result = utf8[offset + 0]; + offset += 1; + return result; + } + else if (!(utf8[offset + 0] & 0x40)) { + throw std::invalid_argument("invalid character"); + } + else if (!(utf8[offset + 0] & 0x20)) { + if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) + throw std::invalid_argument("invalid character"); + auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f); + offset += 2; + return result; + } + else if (!(utf8[offset + 0] & 0x10)) { + if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) + throw std::invalid_argument("invalid character"); + auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f); + offset += 3; + return result; + } + else if (!(utf8[offset + 0] & 0x08)) { + if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) + throw std::invalid_argument("invalid character"); + auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f); + offset += 4; + return result; + } + throw std::invalid_argument("invalid string"); +} + +static std::vector codepoints_from_utf8(const std::string & utf8) { + std::vector result; + size_t offset = 0; + while (offset < utf8.size()) { + result.push_back(codepoint_from_utf8(utf8, offset)); + } + return result; +} + +static std::vector codepoint_to_utf16(uint32_t cp) { + std::vector result; + if (/* 0x0000 <= cp && */ cp <= 0xffff) { + result.emplace_back(cp); + } + else if (0x10000 <= cp && cp <= 0x10ffff) { + result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); + result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); + } + else { + throw std::invalid_argument("invalid codepoint"); + } + return result; +} + +static std::vector codepoints_to_utf16(const std::vector & cps) { + std::vector result; + for (size_t i = 0; i < cps.size(); ++i) { + auto temp = codepoint_to_utf16(cps[i]); + result.insert(result.end(), temp.begin(), temp.end()); + } + return result; +} + +static uint32_t codepoint_from_utf16(const std::vector & utf16, size_t & offset) { + assert(offset < utf16.size()); + if (((utf16[0] >> 10) << 10) != 0xd800) { + auto result = utf16[offset + 0]; + offset += 1; + return result; + } + else { + if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) + throw std::invalid_argument("invalid character"); + auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); + offset += 2; + return result; + } + throw std::invalid_argument("invalid string"); +} + +static std::vector codepoints_from_utf16(const std::vector & utf16) { + std::vector result; + size_t offset = 0; + while (offset < utf16.size()) + result.push_back(codepoint_from_utf16(utf16, offset)); + return result; +} + +#define CODEPOINT_TYPE_UNIDENTIFIED 0 +#define CODEPOINT_TYPE_DIGIT 1 +#define CODEPOINT_TYPE_LETTER 2 +#define CODEPOINT_TYPE_WHITESPACE 3 +#define CODEPOINT_TYPE_ACCENT_MARK 4 +#define CODEPOINT_TYPE_PUNCTUATION 5 +#define CODEPOINT_TYPE_SYMBOL 6 +#define CODEPOINT_TYPE_CONTROL 7 + +static std::unordered_map codepoint_type_map() { + std::unordered_map codepoint_types; + for (auto p : digit_ranges) { + for(auto i = p.first; i <= p.second; ++ i) + codepoint_types[i] = CODEPOINT_TYPE_DIGIT; + } + for(auto p : letter_ranges) { + for(auto i = p.first; i <= p.second; ++ i) + codepoint_types[i] = CODEPOINT_TYPE_LETTER; + } + for(auto p : whitespace_ranges) { + for(auto i = p.first; i <= p.second; ++ i) + codepoint_types[i] = CODEPOINT_TYPE_WHITESPACE; + } + for(auto p : accent_mark_ranges) { + for(auto i = p.first; i <= p.second; ++ i) + codepoint_types[i] = CODEPOINT_TYPE_ACCENT_MARK; + } + for(auto p : punctuation_ranges) { + for(auto i = p.first; i <= p.second; ++ i) + codepoint_types[i] = CODEPOINT_TYPE_PUNCTUATION; + } + for (auto p : symbol_ranges) { + for (auto i = p.first; i <= p.second; ++i) + codepoint_types[i] = CODEPOINT_TYPE_SYMBOL; + } + for(auto p : control_ranges) { + for(auto i = p.first; i <= p.second; ++ i) + codepoint_types[i] = CODEPOINT_TYPE_CONTROL; + } + return codepoint_types; +} + +static int codepoint_type(uint32_t cp) { + static std::unordered_map codepoint_types = codepoint_type_map(); + return codepoint_types[cp]; +} + +static int codepoint_type(const std::string & utf8) { + if (utf8.length() == 0) + return CODEPOINT_TYPE_UNIDENTIFIED; + size_t offset = 0; + return codepoint_type(codepoint_from_utf8(utf8, offset)); +} + +static std::unordered_map bytes_to_unicode_map_bpe() { + std::unordered_map map; + for (int ch = u'!'; ch <= u'~'; ++ch) { + assert(0 <= ch && ch < 256); + map[ch] = codepoint_to_utf8(ch); + } + for (int ch = u'¡'; ch <= u'¬'; ++ch) { + assert(0 <= ch && ch < 256); + map[ch] = codepoint_to_utf8(ch); + } + for (int ch = u'®'; ch <= u'ÿ'; ++ch) { + assert(0 <= ch && ch < 256); + map[ch] = codepoint_to_utf8(ch); + } + auto n = 0; + for (int ch = 0; ch < 256; ++ch) { + if (map.find(ch) == map.end()) { + map[ch] = codepoint_to_utf8(256 + n); + ++n; + } + } + return map; +} + +static std::string bytes_to_unicode_bpe(uint8_t byte) { + static std::unordered_map map = bytes_to_unicode_map_bpe(); + return map.at(byte); +} + +static std::unordered_map unicode_to_bytes_map_bpe() { + std::unordered_map map; + for (int ch = u'!'; ch <= u'~'; ++ch) { + assert(0 <= ch && ch < 256); + map[codepoint_to_utf8(ch)] = ch; + } + for (int ch = u'¡'; ch <= u'¬'; ++ch) { + assert(0 <= ch && ch < 256); + map[codepoint_to_utf8(ch)] = ch; + } + for (int ch = u'®'; ch <= u'ÿ'; ++ch) { + assert(0 <= ch && ch < 256); + map[codepoint_to_utf8(ch)] = ch; + } + auto n = 0; + for (int ch = 0; ch < 256; ++ch) { + if (map.find(codepoint_to_utf8(ch)) == map.end()) { + map[codepoint_to_utf8(256 + n)] = ch; + ++n; + } + } + return map; +} + +static uint8_t unicode_to_bytes_bpe(const std::string & utf8) { + static std::unordered_map map = unicode_to_bytes_map_bpe(); + return map.at(utf8); +} + From d9d410c192d3af8d0126c19c9d9b342c67c32d87 Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 2 Nov 2023 01:25:01 +0800 Subject: [PATCH 175/623] [WASI-NN] OpenBLAS should be OFF when cuBLAS is ON Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 886d828a..fc8935b7 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -8,8 +8,9 @@ set(LLAMA_METAL_NDEBUG ON) if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_CUBLAS) message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_CUBLAS") - # Default use OpenBLAS set(LLAMA_CUBLAS ON) + # If CUBLAS is ON, then OpenBLAS should be OFF. + set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS OFF) else() message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_CUBLAS") set(LLAMA_CUBLAS OFF) From 53ae40b9c06900a7b3059067a254fdb081812bdd Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 2 Nov 2023 01:37:59 +0800 Subject: [PATCH 176/623] [WASI-NN] Workaround: reload model again if the n_gpu_layers is changed in setInput Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 43 ++++++++++++++++++++++------------------ plugins/wasi_nn/ggml.h | 5 +++-- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 109fa042..8e95f2cf 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -54,26 +54,9 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, gpt_params Params; llama_backend_init(Params.numa); llama_model_params ModelParams = llama_model_default_params(); - - const char *LlamaNGPULayerEnv = std::getenv("LLAMA_N_GL"); - if (LlamaNGPULayerEnv != nullptr) { - try { - ModelParams.n_gpu_layers = std::stoi(LlamaNGPULayerEnv); - } catch (const std::out_of_range &e) { - spdlog::error( - "[WASI-NN] GGML backend: set n_gpu_layers failed: out_of_range {}"sv, - e.what()); - return ErrNo::InvalidArgument; - } catch (const std::invalid_argument &e) { - spdlog::error( - "[WASI-NN] GGML backend: set n_gpu_layers failed: invalid_argument {}"sv, - e.what()); - return ErrNo::InvalidArgument; - } - } - + GraphRef.ModelFilePath = ModelFilePath; GraphRef.LlamaModel = - llama_load_model_from_file(ModelFilePath.c_str(), ModelParams); + llama_load_model_from_file(GraphRef.ModelFilePath.c_str(), ModelParams); if (GraphRef.LlamaModel == nullptr) { spdlog::error("[WASI-NN] GGML backend: Error: unable to init model."sv); Env.NNGraph.pop_back(); @@ -194,6 +177,28 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = CxtRef.CtxSize; ContextParams.n_batch = CxtRef.BatchSize; + + // XXX: Due to the limitation of WASI-NN proposal, + // we have no way to pass the metadata before the setInput phase + // when we want to do some configurations in the load phase. + // That's why we have this hack. + { + llama_model_params ModelParams = llama_model_default_params(); + // If the `n_gpu_layers` in `setInput` is different from the + // `n_gpu_layers` in `llama_model_params`, we will reload + // the model with the new configuration. + if (ModelParams.n_gpu_layers != CxtRef.NGPULayers) { + ModelParams.n_gpu_layers = CxtRef.NGPULayers; + GraphRef.LlamaModel = llama_load_model_from_file( + GraphRef.ModelFilePath.c_str(), ModelParams); + if (GraphRef.LlamaModel == nullptr) { + spdlog::error("[WASI-NN] GGML backend: Error: unable to init model."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + } + } + GraphRef.LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index ae165436..fad69026 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -18,8 +18,9 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML struct Graph { - llama_model *LlamaModel; - llama_context *LlamaContext; + llama_model *LlamaModel = nullptr; + llama_context *LlamaContext = nullptr; + std::string ModelFilePath; }; struct Context { From c578eebf84fe8db6c569841308468d264b9d302e Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 1 Nov 2023 22:55:10 -0500 Subject: [PATCH 177/623] [WASI-NN] llama.cpp: Fix segfault of the wrong usage of the user_data field. Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 3 ++- plugins/wasi_nn/thirdparty/ggml/llama.cpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 8e95f2cf..4231b9d2 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -117,6 +117,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, "[WASI-NN] GGML backend: Unable to retrieve the enable-log option."sv); return ErrNo::InvalidArgument; } + llama_log_set(nullptr, &CxtRef.EnableLog); } if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { auto Err = Doc["stream-stdout"].get().get(CxtRef.StreamStdout); @@ -187,7 +188,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // If the `n_gpu_layers` in `setInput` is different from the // `n_gpu_layers` in `llama_model_params`, we will reload // the model with the new configuration. - if (ModelParams.n_gpu_layers != CxtRef.NGPULayers) { + if (ModelParams.n_gpu_layers != static_cast(CxtRef.NGPULayers)) { ModelParams.n_gpu_layers = CxtRef.NGPULayers; GraphRef.LlamaModel = llama_load_model_from_file( GraphRef.ModelFilePath.c_str(), ModelParams); diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.cpp b/plugins/wasi_nn/thirdparty/ggml/llama.cpp index ef5eeb99..fa2b3d2f 100644 --- a/plugins/wasi_nn/thirdparty/ggml/llama.cpp +++ b/plugins/wasi_nn/thirdparty/ggml/llama.cpp @@ -9625,8 +9625,8 @@ static void llama_log_internal(ggml_log_level level, const char * format, ...) { static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) { (void) level; - bool enable_log = static_cast(user_data); - if (enable_log) { + bool * enable_log = static_cast(user_data); + if (enable_log && *enable_log) { fputs(text, stderr); fflush(stderr); } From 8103fe87ca442881e5bbdf9467f08f2d0b7f18f5 Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 3 Nov 2023 15:10:30 +0800 Subject: [PATCH 178/623] [WASI-NN] ggml: Set n_gpu_layers to 0 on the macOS platform (#3004) Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 4231b9d2..deee8ce4 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -144,12 +144,18 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } } if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { + // Metal framework has the different behavior of CUDA. + // Hence, we have to set the n_gpu_layers to 0 on the macOS platform. +#ifndef __APPLE__ auto Err = Doc["n-gpu-layers"].get().get(CxtRef.NGPULayers); if (Err) { spdlog::error( "[WASI-NN] GGML backend: Unable to retrieve the n-gpu-layers option."sv); return ErrNo::InvalidArgument; } +#else + CxtRef.NGPULayers = 0; +#endif } if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { auto Err = Doc["batch-size"].get().get(CxtRef.BatchSize); From 1e88303709f4cb163de75a1bf7b32c97477ec99c Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 7 Nov 2023 16:31:56 +0800 Subject: [PATCH 179/623] [WASI-NN] ggml: Ensure the model/cxt is re-allocated when params change Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 204 ++++++++++-------- plugins/wasi_nn/ggml.h | 7 +- .../wasi_nn/thirdparty/ggml/CMakeLists.txt | 2 +- 3 files changed, 123 insertions(+), 90 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index deee8ce4..875dd9fd 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -13,6 +13,107 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML + +namespace details { +Expect getMetadata(Context &CxtRef, const TensorData &Tensor, + bool &IsUpdated) noexcept { + // Decode metadata. + const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + spdlog::error("[WASI-NN] GGML backend: Parse metadata error"sv); + return ErrNo::InvalidEncoding; + } + + // Get metadata from the json. + // Need to update Model: + // * n_gpu_layers + // Need to update Context: + // * ctx-size + // * batch-size + // Initialize the llama context. + llama_context_params CxtParams = llama_context_default_params(); + CxtParams.n_ctx = CxtRef.CtxSize; + CxtParams.n_batch = CxtRef.BatchSize; + + if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-log"].get().get(CxtRef.EnableLog); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the enable-log option."sv); + return ErrNo::InvalidArgument; + } + llama_log_set(nullptr, &CxtRef.EnableLog); + } + if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { + auto Err = Doc["stream-stdout"].get().get(CxtRef.StreamStdout); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the stream-stdout option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { + auto Err = Doc["ctx-size"].get().get(CxtRef.CtxSize); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the ctx-size option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { + auto Err = Doc["n-predict"].get().get(CxtRef.NPredict); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the n-predict option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { + // Metal framework has the different behavior of CUDA. + // Hence, we have to set the n_gpu_layers to 0 on the macOS platform. +#ifndef __APPLE__ + auto Err = Doc["n-gpu-layers"].get().get(CxtRef.NGPULayers); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the n-gpu-layers option."sv); + return ErrNo::InvalidArgument; + } +#else + CxtRef.NGPULayers = 0; +#endif + } + if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { + auto Err = Doc["batch-size"].get().get(CxtRef.BatchSize); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the batch-size option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { + std::string_view ReversePrompt; + auto Err = Doc["reverse-prompt"].get().get(ReversePrompt); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the reverse-prompt option."sv); + return ErrNo::InvalidArgument; + } + CxtRef.ReversePrompt = ReversePrompt; + } + + // Check if the context is updated. + if (CxtParams.n_ctx != CxtRef.CtxSize || + CxtParams.n_batch != CxtRef.BatchSize) { + IsUpdated = true; + } + return ErrNo::Success; +} +} // namespace details + Expect load(WasiNNEnvironment &Env, Span> Builders, [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { // The graph builder length must be 1. @@ -96,99 +197,17 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + bool IsCxtParamsUpdated = false; // Use index 1 for metadata. if (Index == 1) { - // Decode metadata. - std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), - Tensor.Tensor.size()); - simdjson::dom::parser Parser; - simdjson::dom::element Doc; - auto ParseError = Parser.parse(Metadata).get(Doc); - if (ParseError) { - spdlog::error("[WASI-NN] GGML backend: Parse metadata error"sv); - return ErrNo::InvalidEncoding; - } - - // Get metadata from the json. - if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { - auto Err = Doc["enable-log"].get().get(CxtRef.EnableLog); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the enable-log option."sv); - return ErrNo::InvalidArgument; - } - llama_log_set(nullptr, &CxtRef.EnableLog); - } - if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { - auto Err = Doc["stream-stdout"].get().get(CxtRef.StreamStdout); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the stream-stdout option."sv); - return ErrNo::InvalidArgument; - } - } - if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { - auto Err = Doc["ctx-size"].get().get(CxtRef.CtxSize); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the ctx-size option."sv); - return ErrNo::InvalidArgument; - } - } - if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { - auto Err = Doc["n-predict"].get().get(CxtRef.NPredict); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the n-predict option."sv); - return ErrNo::InvalidArgument; - } - } - if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { - // Metal framework has the different behavior of CUDA. - // Hence, we have to set the n_gpu_layers to 0 on the macOS platform. -#ifndef __APPLE__ - auto Err = Doc["n-gpu-layers"].get().get(CxtRef.NGPULayers); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the n-gpu-layers option."sv); - return ErrNo::InvalidArgument; - } -#else - CxtRef.NGPULayers = 0; -#endif - } - if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { - auto Err = Doc["batch-size"].get().get(CxtRef.BatchSize); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the batch-size option."sv); - return ErrNo::InvalidArgument; - } - } - if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { - std::string_view ReversePrompt; - auto Err = - Doc["reverse-prompt"].get().get(ReversePrompt); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the reverse-prompt option."sv); - return ErrNo::InvalidArgument; - } - CxtRef.ReversePrompt = ReversePrompt; - } - - return ErrNo::Success; + return details::getMetadata(CxtRef, Tensor, IsCxtParamsUpdated); } - // Initialize the llama context. - llama_context_params ContextParams = llama_context_default_params(); - ContextParams.n_ctx = CxtRef.CtxSize; - ContextParams.n_batch = CxtRef.BatchSize; - // XXX: Due to the limitation of WASI-NN proposal, // we have no way to pass the metadata before the setInput phase // when we want to do some configurations in the load phase. // That's why we have this hack. +#ifndef __APPLE__ { llama_model_params ModelParams = llama_model_default_params(); // If the `n_gpu_layers` in `setInput` is different from the @@ -196,6 +215,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // the model with the new configuration. if (ModelParams.n_gpu_layers != static_cast(CxtRef.NGPULayers)) { ModelParams.n_gpu_layers = CxtRef.NGPULayers; + llama_free_model(GraphRef.LlamaModel); GraphRef.LlamaModel = llama_load_model_from_file( GraphRef.ModelFilePath.c_str(), ModelParams); if (GraphRef.LlamaModel == nullptr) { @@ -205,9 +225,19 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } } } +#endif - GraphRef.LlamaContext = - llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); + // Initialize the llama context. + if (GraphRef.LlamaContext == nullptr || IsCxtParamsUpdated) { + llama_context_params ContextParams = llama_context_default_params(); + ContextParams.n_ctx = CxtRef.CtxSize; + ContextParams.n_batch = CxtRef.BatchSize; + if (GraphRef.LlamaContext != nullptr) { + llama_free(GraphRef.LlamaContext); + } + GraphRef.LlamaContext = + llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); + } // Set the input. std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index fad69026..2424d913 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -29,13 +29,16 @@ struct Context { size_t GraphId; std::vector LlamaInputs; std::string LlamaOutputs; + // Plugin parameters: bool EnableLog; bool StreamStdout; - uint64_t CtxSize; uint64_t NPredict; + std::string ReversePrompt; + // Model parameters: uint64_t NGPULayers; + // Context parameters: + uint64_t CtxSize; uint64_t BatchSize; - std::string ReversePrompt; }; #else struct Graph {}; diff --git a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt index b2fa2ce9..75867a3d 100644 --- a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt +++ b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt @@ -42,7 +42,7 @@ if (NOT MSVC) endif() # 3rd party libs -option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) +option(LLAMA_ACCELERATE "llama: enable Accelerate framework" OFF) option(LLAMA_BLAS "llama: use BLAS" OFF) set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") option(LLAMA_CUBLAS "llama: use CUDA" OFF) From d67a3b0534766278c65844e74b154ead2aa41e86 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 8 Nov 2023 17:24:45 +0800 Subject: [PATCH 180/623] [WASI-NN] Support model parameters from cli - Use `--nn-preload` with model file path and configs. - Move NGPULayers from Context to Graph. Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 108 +++++++++++++++++++++++++++++---------- plugins/wasi_nn/ggml.h | 4 +- 2 files changed, 84 insertions(+), 28 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 875dd9fd..ddb1f545 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -15,8 +15,9 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML namespace details { -Expect getMetadata(Context &CxtRef, const TensorData &Tensor, - bool &IsUpdated) noexcept { +Expect getMetadata(Context &CxtRef, Graph &GraphRef, + const TensorData &Tensor, bool &IsCxtUpdated, + bool &IsModelUpdated) noexcept { // Decode metadata. const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); @@ -34,7 +35,10 @@ Expect getMetadata(Context &CxtRef, const TensorData &Tensor, // Need to update Context: // * ctx-size // * batch-size - // Initialize the llama context. + + // Get the current llama parameters. + llama_model_params ModelParams = llama_model_default_params(); + ModelParams.n_gpu_layers = GraphRef.NGPULayers; llama_context_params CxtParams = llama_context_default_params(); CxtParams.n_ctx = CxtRef.CtxSize; CxtParams.n_batch = CxtRef.BatchSize; @@ -73,18 +77,12 @@ Expect getMetadata(Context &CxtRef, const TensorData &Tensor, } } if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { - // Metal framework has the different behavior of CUDA. - // Hence, we have to set the n_gpu_layers to 0 on the macOS platform. -#ifndef __APPLE__ - auto Err = Doc["n-gpu-layers"].get().get(CxtRef.NGPULayers); + auto Err = Doc["n-gpu-layers"].get().get(GraphRef.NGPULayers); if (Err) { spdlog::error( "[WASI-NN] GGML backend: Unable to retrieve the n-gpu-layers option."sv); return ErrNo::InvalidArgument; } -#else - CxtRef.NGPULayers = 0; -#endif } if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { auto Err = Doc["batch-size"].get().get(CxtRef.BatchSize); @@ -105,11 +103,17 @@ Expect getMetadata(Context &CxtRef, const TensorData &Tensor, CxtRef.ReversePrompt = ReversePrompt; } + // Check if the model is updated. + if (ModelParams.n_gpu_layers != GraphRef.NGPULayers) { + IsModelUpdated = true; + } + // Check if the context is updated. if (CxtParams.n_ctx != CxtRef.CtxSize || CxtParams.n_batch != CxtRef.BatchSize) { - IsUpdated = true; + IsCxtUpdated = true; } + return ErrNo::Success; } } // namespace details @@ -124,11 +128,67 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, return ErrNo::InvalidArgument; } + // Add a new graph. + Env.NNGraph.emplace_back(Backend::GGML); + auto &GraphRef = Env.NNGraph.back().get(); + + // Initialize the model parameters. + GraphRef.NGPULayers = 0; + + // Handle the model path. auto Weight = Builders[0]; std::string BinModel(reinterpret_cast(Weight.data()), Weight.size()); std::string ModelFilePath; if (BinModel.substr(0, 8) == "preload:") { - ModelFilePath = BinModel.substr(8); + // If BinModel starts with 'preload:', it means that the model name passed + // in as the --nn-preload parameter may have a config separated by ',' at + // the end. For example, "preload:./model.bin,n_gpu_layers=99" + std::vector Configs; + std::string Delimiter = ","; + std::string ModelFilePathWithConfig = BinModel.substr(8); + if (ModelFilePathWithConfig.find(Delimiter) == std::string::npos) { + ModelFilePath = ModelFilePathWithConfig; + } else { + // Handle model path with config. + size_t Pos = 0; + std::string Token; + Pos = ModelFilePathWithConfig.find(Delimiter); + ModelFilePath = ModelFilePathWithConfig.substr(0, Pos); + ModelFilePathWithConfig.erase(0, Pos + Delimiter.length()); + while ((Pos = ModelFilePathWithConfig.find(Delimiter)) != + std::string::npos) { + Token = ModelFilePathWithConfig.substr(0, Pos); + Configs.push_back(Token); + ModelFilePathWithConfig.erase(0, Pos + Delimiter.length()); + } + Configs.push_back(ModelFilePathWithConfig); + } + // Parse the configs. + for (const auto &Config : Configs) { + std::string Delimiter = "="; + size_t Pos = 0; + std::string Token; + Pos = Config.find(Delimiter); + Token = Config.substr(0, Pos); + try { + if (Token == "n_gpu_layers" || Token == "ngl") { + GraphRef.NGPULayers = + std::stoi(Config.substr(Pos + Delimiter.length())); + } + } catch (const std::invalid_argument &e) { + spdlog::error( + "[WASI-NN] GGML backend: parse model parameter failed: invalid_argument {}"sv, + e.what()); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } catch (const std::out_of_range &e) { + spdlog::error( + "[WASI-NN] GGML backend: parse parameter failed: out_of_range {}"sv, + e.what()); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + } } else { // TODO: pass the model directly to ggml // Write ggml model to file. @@ -141,21 +201,19 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, "Currently, our workaround involves creating a temporary model " "file named \"ggml-model.bin\" and passing this filename as a " "parameter to the ggml llama library."sv); + Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } TempFile << BinModel; TempFile.close(); } - // Add a new graph. - Env.NNGraph.emplace_back(Backend::GGML); - auto &GraphRef = Env.NNGraph.back().get(); - - // Initialize ggml model. + // Initialize ggml model with model parameters. gpt_params Params; llama_backend_init(Params.numa); - llama_model_params ModelParams = llama_model_default_params(); GraphRef.ModelFilePath = ModelFilePath; + llama_model_params ModelParams = llama_model_default_params(); + ModelParams.n_gpu_layers = GraphRef.NGPULayers; GraphRef.LlamaModel = llama_load_model_from_file(GraphRef.ModelFilePath.c_str(), ModelParams); if (GraphRef.LlamaModel == nullptr) { @@ -185,7 +243,6 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, CxtRef.StreamStdout = false; CxtRef.CtxSize = ContextDefault.n_ctx; CxtRef.NPredict = ContextDefault.n_ctx; - CxtRef.NGPULayers = 0; CxtRef.BatchSize = ContextDefault.n_batch; CxtRef.ReversePrompt = ""sv; @@ -198,9 +255,11 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); bool IsCxtParamsUpdated = false; + bool IsModelParamsUpdated = false; // Use index 1 for metadata. if (Index == 1) { - return details::getMetadata(CxtRef, Tensor, IsCxtParamsUpdated); + return details::getMetadata(CxtRef, GraphRef, Tensor, IsCxtParamsUpdated, + IsModelParamsUpdated); } // XXX: Due to the limitation of WASI-NN proposal, @@ -209,12 +268,9 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // That's why we have this hack. #ifndef __APPLE__ { - llama_model_params ModelParams = llama_model_default_params(); - // If the `n_gpu_layers` in `setInput` is different from the - // `n_gpu_layers` in `llama_model_params`, we will reload - // the model with the new configuration. - if (ModelParams.n_gpu_layers != static_cast(CxtRef.NGPULayers)) { - ModelParams.n_gpu_layers = CxtRef.NGPULayers; + if (IsModelParamsUpdated) { + llama_model_params ModelParams = llama_model_default_params(); + ModelParams.n_gpu_layers = GraphRef.NGPULayers; llama_free_model(GraphRef.LlamaModel); GraphRef.LlamaModel = llama_load_model_from_file( GraphRef.ModelFilePath.c_str(), ModelParams); diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 2424d913..1347ee3c 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -21,6 +21,8 @@ struct Graph { llama_model *LlamaModel = nullptr; llama_context *LlamaContext = nullptr; std::string ModelFilePath; + // Model parameters: + int64_t NGPULayers; }; struct Context { @@ -34,8 +36,6 @@ struct Context { bool StreamStdout; uint64_t NPredict; std::string ReversePrompt; - // Model parameters: - uint64_t NGPULayers; // Context parameters: uint64_t CtxSize; uint64_t BatchSize; From 7b3e6abaa2ed3be916b0297586547e3e2a156a33 Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 10 Nov 2023 17:13:37 +0800 Subject: [PATCH 181/623] [WASI-NN] ggml: support load_by_name_with_config - When the graph builder length > 1, the data of builder[1] is the metadata. - The metadata is a json string as before. - Still support set metadata from set_input with index 1 for backward compatibility. Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 168 +++++++++++++++++-------------- plugins/wasi_nn/ggml.h | 18 ++-- plugins/wasi_nn/wasinnenv.h | 9 +- plugins/wasi_nn/wasinnfunc.cpp | 46 ++++++++- plugins/wasi_nn/wasinnfunc.h | 18 ++++ plugins/wasi_nn/wasinnmodule.cpp | 2 + test/plugins/wasi_nn/wasi_nn.cpp | 4 +- 7 files changed, 174 insertions(+), 91 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index ddb1f545..ae08819d 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -15,12 +15,9 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML namespace details { -Expect getMetadata(Context &CxtRef, Graph &GraphRef, - const TensorData &Tensor, bool &IsCxtUpdated, - bool &IsModelUpdated) noexcept { - // Decode metadata. - const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), - Tensor.Tensor.size()); +Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, + bool *IsCxtUpdated = nullptr, + bool *IsModelUpdated = nullptr) noexcept { simdjson::dom::parser Parser; simdjson::dom::element Doc; auto ParseError = Parser.parse(Metadata).get(Doc); @@ -40,42 +37,47 @@ Expect getMetadata(Context &CxtRef, Graph &GraphRef, llama_model_params ModelParams = llama_model_default_params(); ModelParams.n_gpu_layers = GraphRef.NGPULayers; llama_context_params CxtParams = llama_context_default_params(); - CxtParams.n_ctx = CxtRef.CtxSize; - CxtParams.n_batch = CxtRef.BatchSize; + CxtParams.n_ctx = GraphRef.CtxSize; + CxtParams.n_batch = GraphRef.BatchSize; + // The plugin parameters. if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { - auto Err = Doc["enable-log"].get().get(CxtRef.EnableLog); + auto Err = Doc["enable-log"].get().get(GraphRef.EnableLog); if (Err) { spdlog::error( "[WASI-NN] GGML backend: Unable to retrieve the enable-log option."sv); return ErrNo::InvalidArgument; } - llama_log_set(nullptr, &CxtRef.EnableLog); + llama_log_set(nullptr, &GraphRef.EnableLog); } if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { - auto Err = Doc["stream-stdout"].get().get(CxtRef.StreamStdout); + auto Err = Doc["stream-stdout"].get().get(GraphRef.StreamStdout); if (Err) { spdlog::error( "[WASI-NN] GGML backend: Unable to retrieve the stream-stdout option."sv); return ErrNo::InvalidArgument; } } - if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { - auto Err = Doc["ctx-size"].get().get(CxtRef.CtxSize); + if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { + auto Err = Doc["n-predict"].get().get(GraphRef.NPredict); if (Err) { spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the ctx-size option."sv); + "[WASI-NN] GGML backend: Unable to retrieve the n-predict option."sv); return ErrNo::InvalidArgument; } } - if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { - auto Err = Doc["n-predict"].get().get(CxtRef.NPredict); + if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { + std::string_view ReversePrompt; + auto Err = Doc["reverse-prompt"].get().get(ReversePrompt); if (Err) { spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the n-predict option."sv); + "[WASI-NN] GGML backend: Unable to retrieve the reverse-prompt option."sv); return ErrNo::InvalidArgument; } + GraphRef.ReversePrompt = ReversePrompt; } + + // The model parameters. if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { auto Err = Doc["n-gpu-layers"].get().get(GraphRef.NGPULayers); if (Err) { @@ -84,34 +86,34 @@ Expect getMetadata(Context &CxtRef, Graph &GraphRef, return ErrNo::InvalidArgument; } } - if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { - auto Err = Doc["batch-size"].get().get(CxtRef.BatchSize); + + // The context parameters. + if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { + auto Err = Doc["ctx-size"].get().get(GraphRef.CtxSize); if (Err) { spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the batch-size option."sv); + "[WASI-NN] GGML backend: Unable to retrieve the ctx-size option."sv); return ErrNo::InvalidArgument; } } - if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { - std::string_view ReversePrompt; - auto Err = Doc["reverse-prompt"].get().get(ReversePrompt); + if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { + auto Err = Doc["batch-size"].get().get(GraphRef.BatchSize); if (Err) { spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the reverse-prompt option."sv); + "[WASI-NN] GGML backend: Unable to retrieve the batch-size option."sv); return ErrNo::InvalidArgument; } - CxtRef.ReversePrompt = ReversePrompt; } // Check if the model is updated. - if (ModelParams.n_gpu_layers != GraphRef.NGPULayers) { - IsModelUpdated = true; + if (IsModelUpdated && ModelParams.n_gpu_layers != GraphRef.NGPULayers) { + *IsModelUpdated = true; } // Check if the context is updated. - if (CxtParams.n_ctx != CxtRef.CtxSize || - CxtParams.n_batch != CxtRef.BatchSize) { - IsCxtUpdated = true; + if (IsCxtUpdated && (CxtParams.n_ctx != GraphRef.CtxSize || + CxtParams.n_batch != GraphRef.BatchSize)) { + *IsCxtUpdated = true; } return ErrNo::Success; @@ -120,20 +122,34 @@ Expect getMetadata(Context &CxtRef, Graph &GraphRef, Expect load(WasiNNEnvironment &Env, Span> Builders, [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { - // The graph builder length must be 1. - if (Builders.size() != 1) { - spdlog::error( - "[WASI-NN] GGML backend: Wrong GraphBuilder Length {:d}, expect 1"sv, - Builders.size()); - return ErrNo::InvalidArgument; - } - // Add a new graph. Env.NNGraph.emplace_back(Backend::GGML); auto &GraphRef = Env.NNGraph.back().get(); + // Initialize the plugin parameters. + auto ContextDefault = llama_context_default_params(); + GraphRef.EnableLog = false; + GraphRef.StreamStdout = false; + GraphRef.ReversePrompt = ""sv; + GraphRef.NPredict = ContextDefault.n_ctx; // Initialize the model parameters. GraphRef.NGPULayers = 0; + // Initialize the context parameters. + GraphRef.CtxSize = ContextDefault.n_ctx; + GraphRef.BatchSize = ContextDefault.n_batch; + + // If the graph builder length > 1, the data of builder[1] is the metadata. + if (Builders.size() > 1) { + std::string Metadata(reinterpret_cast(Builders[1].data()), + Builders[1].size()); + // Ignore context or model updates when initializing the graph. + auto Res = details::parseMetadata(GraphRef, Metadata); + if (Res != ErrNo::Success) { + spdlog::error("[WASI-NN] GGML backend: Failed to parse metadata."sv); + Env.NNGraph.pop_back(); + return Res; + } + } // Handle the model path. auto Weight = Builders[0]; @@ -158,10 +174,10 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, while ((Pos = ModelFilePathWithConfig.find(Delimiter)) != std::string::npos) { Token = ModelFilePathWithConfig.substr(0, Pos); - Configs.push_back(Token); + Configs.emplace_back(Token); ModelFilePathWithConfig.erase(0, Pos + Delimiter.length()); } - Configs.push_back(ModelFilePathWithConfig); + Configs.emplace_back(ModelFilePathWithConfig); } // Parse the configs. for (const auto &Config : Configs) { @@ -235,16 +251,15 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); ContextId = Env.NNContext.size() - 1; - - // Set the default context options. auto &CxtRef = Env.NNContext[ContextId].get(); - auto ContextDefault = llama_context_default_params(); - CxtRef.EnableLog = false; - CxtRef.StreamStdout = false; - CxtRef.CtxSize = ContextDefault.n_ctx; - CxtRef.NPredict = ContextDefault.n_ctx; - CxtRef.BatchSize = ContextDefault.n_batch; - CxtRef.ReversePrompt = ""sv; + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + // Initialize the llama context. + llama_context_params ContextParams = llama_context_default_params(); + ContextParams.n_ctx = GraphRef.CtxSize; + ContextParams.n_batch = GraphRef.BatchSize; + CxtRef.LlamaContext = + llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); return ErrNo::Success; } @@ -258,8 +273,10 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, bool IsModelParamsUpdated = false; // Use index 1 for metadata. if (Index == 1) { - return details::getMetadata(CxtRef, GraphRef, Tensor, IsCxtParamsUpdated, - IsModelParamsUpdated); + const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + return details::parseMetadata(GraphRef, Metadata, &IsCxtParamsUpdated, + &IsModelParamsUpdated); } // XXX: Due to the limitation of WASI-NN proposal, @@ -284,22 +301,22 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, #endif // Initialize the llama context. - if (GraphRef.LlamaContext == nullptr || IsCxtParamsUpdated) { + if (CxtRef.LlamaContext == nullptr || IsCxtParamsUpdated) { llama_context_params ContextParams = llama_context_default_params(); - ContextParams.n_ctx = CxtRef.CtxSize; - ContextParams.n_batch = CxtRef.BatchSize; - if (GraphRef.LlamaContext != nullptr) { - llama_free(GraphRef.LlamaContext); + ContextParams.n_ctx = GraphRef.CtxSize; + ContextParams.n_batch = GraphRef.BatchSize; + if (CxtRef.LlamaContext != nullptr) { + llama_free(CxtRef.LlamaContext); } - GraphRef.LlamaContext = + CxtRef.LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); } // Set the input. std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); - CxtRef.LlamaInputs = llama_tokenize(GraphRef.LlamaContext, Prompt, true); - const uint32_t MaxContextSize = llama_n_ctx(GraphRef.LlamaContext); + CxtRef.LlamaInputs = llama_tokenize(CxtRef.LlamaContext, Prompt, true); + const uint32_t MaxContextSize = llama_n_ctx(CxtRef.LlamaContext); // Minus 4 for the special tokens. const uint32_t MaxTokensListSize = MaxContextSize - 4; if (CxtRef.LlamaInputs.size() > MaxTokensListSize) { @@ -330,7 +347,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { return ErrNo::InvalidArgument; } - if (CxtRef.EnableLog) { + if (GraphRef.EnableLog) { spdlog::info("[WASI-NN] GGML backend: llama_system_info: {}"sv, llama_print_system_info()); } @@ -341,13 +358,13 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Main predict loop. // TODO: recompute a compressed context based on previous tokens once the // cache is full. - const int MaxContextSize = llama_n_ctx(GraphRef.LlamaContext); + const int MaxContextSize = llama_n_ctx(CxtRef.LlamaContext); // NPredict is the number of tokens to predict. Same as -n, --n-predict in // llama.cpp. - int NPredict = CxtRef.NPredict; + int NPredict = GraphRef.NPredict; // Evaluate the initial prompt. - llama_batch LlamaBatch = llama_batch_init(CxtRef.BatchSize, 0); + llama_batch LlamaBatch = llama_batch_init(GraphRef.BatchSize, 0); LlamaBatch.n_tokens = CxtRef.LlamaInputs.size(); for (int32_t I = 0; I < LlamaBatch.n_tokens; I++) { LlamaBatch.token[I] = CxtRef.LlamaInputs[I]; @@ -358,7 +375,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // llama_decode will output logits only for the last token of the prompt LlamaBatch.logits[LlamaBatch.n_tokens - 1] = true; - if (llama_decode(GraphRef.LlamaContext, LlamaBatch) != 0) { + if (llama_decode(CxtRef.LlamaContext, LlamaBatch) != 0) { spdlog::info("[WASI-NN] GGML backend: llama_decode() failed"sv); return ErrNo::RuntimeError; } @@ -368,7 +385,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Sample the next token auto NVocab = llama_n_vocab(GraphRef.LlamaModel); auto *Logits = - llama_get_logits_ith(GraphRef.LlamaContext, LlamaBatch.n_tokens - 1); + llama_get_logits_ith(CxtRef.LlamaContext, LlamaBatch.n_tokens - 1); std::vector Candidates; Candidates.reserve(NVocab); @@ -380,19 +397,19 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Sample the most likely token const llama_token NewTokenId = - llama_sample_token_greedy(GraphRef.LlamaContext, &CandidatesP); + llama_sample_token_greedy(CxtRef.LlamaContext, &CandidatesP); // Is it an end of stream? - if (NewTokenId == llama_token_eos(GraphRef.LlamaContext) || + if (NewTokenId == llama_token_eos(CxtRef.LlamaContext) || NCur == MaxContextSize || NCur == NPredict) { break; } std::string NextToken = - llama_token_to_piece(GraphRef.LlamaContext, NewTokenId); + llama_token_to_piece(CxtRef.LlamaContext, NewTokenId); // When setting StreamStdout, we print the output to stdout. - if (CxtRef.StreamStdout) { + if (GraphRef.StreamStdout) { std::cout << NextToken << std::flush; } @@ -411,24 +428,23 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { NCur += 1; // Evaluate the current batch with the transformer model - if (llama_decode(GraphRef.LlamaContext, LlamaBatch)) { + if (llama_decode(CxtRef.LlamaContext, LlamaBatch)) { spdlog::error("[WASI-NN] GGML backend: failed to llama_decode"sv); return ErrNo::RuntimeError; } // Break if reverse prompt is found. - if (!CxtRef.ReversePrompt.empty() && - CxtRef.LlamaOutputs.find(CxtRef.ReversePrompt) != std::string::npos) { - if (CxtRef.EnableLog) { + if (!GraphRef.ReversePrompt.empty() && + CxtRef.LlamaOutputs.find(GraphRef.ReversePrompt) != std::string::npos) { + if (GraphRef.EnableLog) { spdlog::info("[WASI-NN] GGML backend: reverse prompt found"sv); } break; } } - if (CxtRef.EnableLog) { - llama_log_set(nullptr, &CxtRef.EnableLog); - llama_print_timings(GraphRef.LlamaContext); + if (GraphRef.EnableLog) { + llama_print_timings(CxtRef.LlamaContext); } return ErrNo::Success; diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 1347ee3c..e24ae954 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -19,26 +19,26 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML struct Graph { llama_model *LlamaModel = nullptr; - llama_context *LlamaContext = nullptr; std::string ModelFilePath; + // Plugin parameters: + bool EnableLog; + bool StreamStdout; + uint64_t NPredict; + std::string ReversePrompt; // Model parameters: int64_t NGPULayers; + // Context parameters: + uint64_t CtxSize; + uint64_t BatchSize; }; struct Context { public: Context(size_t GId, Graph &) noexcept : GraphId(GId) {} size_t GraphId; + llama_context *LlamaContext = nullptr; std::vector LlamaInputs; std::string LlamaOutputs; - // Plugin parameters: - bool EnableLog; - bool StreamStdout; - uint64_t NPredict; - std::string ReversePrompt; - // Context parameters: - uint64_t CtxSize; - uint64_t BatchSize; }; #else struct Graph {}; diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 5803e61a..7d919b5c 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -163,8 +163,9 @@ struct WasiNNEnvironment : return false; } - Expect mdBuild(std::string Name, uint32_t &GraphId, - Callback Load) noexcept { + Expect + mdBuild(std::string Name, uint32_t &GraphId, Callback Load, + std::vector Config = std::vector()) noexcept { std::unique_lock Lock(MdMutex); auto It = RawMdMap.find(Name); if (It != RawMdMap.end()) { @@ -174,6 +175,10 @@ struct WasiNNEnvironment : for (auto &Builder : RawMd) { Builders.emplace_back(Builder); } + // Add config to the end of Builders if exists. + if (Config.size() > 0) { + Builders.emplace_back(Config); + } auto Result = Load(*this, Builders, std::get<1>(It->second), std::get<2>(It->second), GraphId); if (Result.has_value()) { diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 13170c31..39b51009 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -107,14 +107,14 @@ WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, } // Get the name of model - uint32_t *Name = MemInst->getPointer(NamePtr); + auto Name = MemInst->getPointer(NamePtr); if (unlikely(Name == nullptr)) { spdlog::error("[WASI-NN] Failed when accessing the return Name memory."sv); return WASINN::ErrNo::InvalidArgument; } // Get the model - std::string ModelName(reinterpret_cast(Name), NameLen); + std::string ModelName(reinterpret_cast(Name), NameLen); if (Env.mdGet(ModelName, *GraphId)) { return WASINN::ErrNo::Success; } else { @@ -122,6 +122,48 @@ WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, } } +Expect WasiNNLoadByNameWithConfig::bodyImpl( + const Runtime::CallingFrame &Frame, uint32_t NamePtr, uint32_t NameLen, + uint32_t ConfigPtr, uint32_t ConfigLen, uint32_t GraphIdPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + // Check the return value: GraphIdPtr should be valid. + auto GraphId = MemInst->getPointer(GraphIdPtr); + if (unlikely(GraphId == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the return GraphID memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the name of model + auto Name = MemInst->getPointer(NamePtr); + if (unlikely(Name == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the return Name memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the config of model + auto Config = MemInst->getPointer(ConfigPtr); + if (unlikely(Config == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the return Config memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the model + std::string ModelName(reinterpret_cast(Name), NameLen); + std::vector ModelConfig(reinterpret_cast(Config), + reinterpret_cast(Config) + + ConfigLen); + if (Env.mdGet(ModelName, *GraphId)) { + return WASINN::ErrNo::Success; + } else { + return Env.mdBuild(ModelName, *GraphId, load, ModelConfig); + } +} + Expect WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t GraphId, uint32_t ContextPtr) { diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h index 1bec0760..5db87f7a 100644 --- a/plugins/wasi_nn/wasinnfunc.h +++ b/plugins/wasi_nn/wasinnfunc.h @@ -42,6 +42,24 @@ class WasiNNLoadByName : public WasiNN { uint32_t GraphIdPtr); }; +class WasiNNLoadByNameWithConfig : public WasiNN { +public: + WasiNNLoadByNameWithConfig(WASINN::WasiNNEnvironment &HostEnv) + : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t NamePtr, + uint32_t NameLen, uint32_t ConfigPtr, + uint32_t ConfigLen, uint32_t GraphIdPtr) { + return bodyImpl(Frame, NamePtr, NameLen, ConfigPtr, ConfigLen, GraphIdPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &, + uint32_t NamePtr, uint32_t NameLen, + uint32_t ConfigPtr, uint32_t ConfigLen, + uint32_t GraphIdPtr); +}; + class WasiNNInitExecCtx : public WasiNN { public: WasiNNInitExecCtx(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} diff --git a/plugins/wasi_nn/wasinnmodule.cpp b/plugins/wasi_nn/wasinnmodule.cpp index eec2d282..cb3eea06 100644 --- a/plugins/wasi_nn/wasinnmodule.cpp +++ b/plugins/wasi_nn/wasinnmodule.cpp @@ -10,6 +10,8 @@ namespace Host { WasiNNModule::WasiNNModule() : ModuleInstance("wasi_ephemeral_nn") { addHostFunc("load", std::make_unique(Env)); addHostFunc("load_by_name", std::make_unique(Env)); + addHostFunc("load_by_name_with_config", + std::make_unique(Env)); addHostFunc("init_execution_context", std::make_unique(Env)); addHostFunc("set_input", std::make_unique(Env)); diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index f03c4bbb..b4b1ce9d 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1322,7 +1322,7 @@ TEST(WasiNNTest, GGMLBackend) { static_cast(ErrNo::InvalidArgument)); } - // Test: load -- wrong builders' length. + // Test: load -- wrong metadata encoding when builders length > 1. BuilderPtr = LoadEntryPtr; writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); writeBinaries(MemInst, WeightRead, StorePtr); @@ -1335,7 +1335,7 @@ TEST(WasiNNTest, GGMLBackend) { UINT32_C(0), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), - static_cast(ErrNo::InvalidArgument)); + static_cast(ErrNo::InvalidEncoding)); } // Test: load -- load successfully. From 9083e0819fb236eb5d2a8e3d53bf639110fc74d7 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 21 Nov 2023 18:30:01 +0800 Subject: [PATCH 182/623] [WASI-NN] ggml: add details::parseModelConfig Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 103 ++++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 45 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index ae08819d..0d892f33 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -118,6 +118,58 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::Success; } + +Expect parseModelConfig(Graph &GraphRef, + std::string ModelFilePathWithConfig, + std::string &ModelFilePath) noexcept { + std::vector Configs; + std::string Delimiter = ","; + if (ModelFilePathWithConfig.find(Delimiter) == std::string::npos) { + ModelFilePath = ModelFilePathWithConfig; + } else { + // Handle model path with config. + size_t Pos = 0; + std::string Token; + Pos = ModelFilePathWithConfig.find(Delimiter); + ModelFilePath = ModelFilePathWithConfig.substr(0, Pos); + ModelFilePathWithConfig.erase(0, Pos + Delimiter.length()); + while ((Pos = ModelFilePathWithConfig.find(Delimiter)) != + std::string::npos) { + Token = ModelFilePathWithConfig.substr(0, Pos); + Configs.emplace_back(Token); + ModelFilePathWithConfig.erase(0, Pos + Delimiter.length()); + } + Configs.emplace_back(ModelFilePathWithConfig); + } + + // Parse the configs. + for (const auto &Config : Configs) { + std::string Delimiter = "="; + size_t Pos = 0; + std::string Token; + Pos = Config.find(Delimiter); + Token = Config.substr(0, Pos); + try { + if (Token == "n_gpu_layers" || Token == "ngl") { + GraphRef.NGPULayers = + std::stoi(Config.substr(Pos + Delimiter.length())); + } + } catch (const std::invalid_argument &e) { + spdlog::error( + "[WASI-NN] GGML backend: parse model parameter failed: invalid_argument {}"sv, + e.what()); + return ErrNo::InvalidArgument; + } catch (const std::out_of_range &e) { + spdlog::error( + "[WASI-NN] GGML backend: parse parameter failed: out_of_range {}"sv, + e.what()); + return ErrNo::InvalidArgument; + } + } + + return ErrNo::Success; +} + } // namespace details Expect load(WasiNNEnvironment &Env, Span> Builders, @@ -159,51 +211,12 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // If BinModel starts with 'preload:', it means that the model name passed // in as the --nn-preload parameter may have a config separated by ',' at // the end. For example, "preload:./model.bin,n_gpu_layers=99" - std::vector Configs; - std::string Delimiter = ","; - std::string ModelFilePathWithConfig = BinModel.substr(8); - if (ModelFilePathWithConfig.find(Delimiter) == std::string::npos) { - ModelFilePath = ModelFilePathWithConfig; - } else { - // Handle model path with config. - size_t Pos = 0; - std::string Token; - Pos = ModelFilePathWithConfig.find(Delimiter); - ModelFilePath = ModelFilePathWithConfig.substr(0, Pos); - ModelFilePathWithConfig.erase(0, Pos + Delimiter.length()); - while ((Pos = ModelFilePathWithConfig.find(Delimiter)) != - std::string::npos) { - Token = ModelFilePathWithConfig.substr(0, Pos); - Configs.emplace_back(Token); - ModelFilePathWithConfig.erase(0, Pos + Delimiter.length()); - } - Configs.emplace_back(ModelFilePathWithConfig); - } - // Parse the configs. - for (const auto &Config : Configs) { - std::string Delimiter = "="; - size_t Pos = 0; - std::string Token; - Pos = Config.find(Delimiter); - Token = Config.substr(0, Pos); - try { - if (Token == "n_gpu_layers" || Token == "ngl") { - GraphRef.NGPULayers = - std::stoi(Config.substr(Pos + Delimiter.length())); - } - } catch (const std::invalid_argument &e) { - spdlog::error( - "[WASI-NN] GGML backend: parse model parameter failed: invalid_argument {}"sv, - e.what()); - Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; - } catch (const std::out_of_range &e) { - spdlog::error( - "[WASI-NN] GGML backend: parse parameter failed: out_of_range {}"sv, - e.what()); - Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; - } + auto Res = + details::parseModelConfig(GraphRef, BinModel.substr(8), ModelFilePath); + if (Res != ErrNo::Success) { + spdlog::error("[WASI-NN] GGML backend: Failed to parse model config."sv); + Env.NNGraph.pop_back(); + return Res; } } else { // TODO: pass the model directly to ggml From 0f15eb45580eb7dd8b3b646da0138fecb547bff8 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 22 Nov 2023 14:21:19 +0800 Subject: [PATCH 183/623] [WASI-NN] ggml: use FetchContent to download llama.cpp Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 198 +- plugins/wasi_nn/ggml.cpp | 2 - plugins/wasi_nn/thirdparty/CMakeLists.txt | 9 - .../wasi_nn/thirdparty/ggml/CMakeLists.txt | 712 - plugins/wasi_nn/thirdparty/ggml/LICENSE | 21 - plugins/wasi_nn/thirdparty/ggml/README.md | 10 - plugins/wasi_nn/thirdparty/ggml/common.cpp | 1253 - plugins/wasi_nn/thirdparty/ggml/common.h | 186 - plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c | 594 - plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h | 33 - .../wasi_nn/thirdparty/ggml/ggml-backend.c | 385 - .../wasi_nn/thirdparty/ggml/ggml-backend.h | 143 - plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu | 7824 ------ plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h | 51 - plugins/wasi_nn/thirdparty/ggml/ggml-metal.h | 106 - plugins/wasi_nn/thirdparty/ggml/ggml-metal.m | 1601 -- .../wasi_nn/thirdparty/ggml/ggml-metal.metal | 2526 -- plugins/wasi_nn/thirdparty/ggml/ggml.c | 22041 ---------------- plugins/wasi_nn/thirdparty/ggml/ggml.h | 2114 -- plugins/wasi_nn/thirdparty/ggml/k_quants.c | 5060 ---- plugins/wasi_nn/thirdparty/ggml/k_quants.h | 165 - plugins/wasi_nn/thirdparty/ggml/llama.cpp | 9633 ------- plugins/wasi_nn/thirdparty/ggml/llama.h | 752 - plugins/wasi_nn/thirdparty/ggml/log.h | 643 - plugins/wasi_nn/thirdparty/ggml/sampling.cpp | 166 - plugins/wasi_nn/thirdparty/ggml/sampling.h | 108 - plugins/wasi_nn/thirdparty/ggml/unicode.h | 462 - 27 files changed, 107 insertions(+), 56691 deletions(-) delete mode 100644 plugins/wasi_nn/thirdparty/CMakeLists.txt delete mode 100644 plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt delete mode 100644 plugins/wasi_nn/thirdparty/ggml/LICENSE delete mode 100644 plugins/wasi_nn/thirdparty/ggml/README.md delete mode 100644 plugins/wasi_nn/thirdparty/ggml/common.cpp delete mode 100644 plugins/wasi_nn/thirdparty/ggml/common.h delete mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c delete mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h delete mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-backend.c delete mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-backend.h delete mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu delete mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h delete mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-metal.h delete mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-metal.m delete mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal delete mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml.c delete mode 100644 plugins/wasi_nn/thirdparty/ggml/ggml.h delete mode 100644 plugins/wasi_nn/thirdparty/ggml/k_quants.c delete mode 100644 plugins/wasi_nn/thirdparty/ggml/k_quants.h delete mode 100644 plugins/wasi_nn/thirdparty/ggml/llama.cpp delete mode 100644 plugins/wasi_nn/thirdparty/ggml/llama.h delete mode 100644 plugins/wasi_nn/thirdparty/ggml/log.h delete mode 100644 plugins/wasi_nn/thirdparty/ggml/sampling.cpp delete mode 100644 plugins/wasi_nn/thirdparty/ggml/sampling.h delete mode 100644 plugins/wasi_nn/thirdparty/ggml/unicode.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index fc8935b7..390714f9 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -1,102 +1,119 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2022 Second State INC -# llama.cpp options -# Disable warnings and debug messages -set(LLAMA_ALL_WARNINGS OFF) -set(LLAMA_METAL_NDEBUG ON) - -if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_CUBLAS) - message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_CUBLAS") - set(LLAMA_CUBLAS ON) - # If CUBLAS is ON, then OpenBLAS should be OFF. - set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS OFF) -else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_CUBLAS") - set(LLAMA_CUBLAS OFF) -endif() +string(TOLOWER ${WASMEDGE_PLUGIN_WASI_NN_BACKEND} BACKEND) +if(BACKEND STREQUAL "ggml") + # llama.cpp options + # Disable warnings and debug messages + set(LLAMA_ALL_WARNINGS OFF) + set(LLAMA_METAL_NDEBUG ON) + set(LLAMA_ACCELERATE OFF) + + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_CUBLAS) + message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_CUBLAS") + set(LLAMA_CUBLAS ON) + # If CUBLAS is ON, then OpenBLAS should be OFF. + set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS OFF) + else() + message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_CUBLAS") + set(LLAMA_CUBLAS OFF) + endif() -if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS) - message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_BLAS") - # Default use OpenBLAS - set(LLAMA_BLAS ON) - set(LLAMA_BLAS_VENDOR "OpenBLAS") -else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_BLAS") - set(LLAMA_BLAS OFF) -endif() + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS) + message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_BLAS") + # Default use OpenBLAS + set(LLAMA_BLAS ON) + set(LLAMA_BLAS_VENDOR "OpenBLAS") + else() + message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_BLAS") + set(LLAMA_BLAS OFF) + endif() -if(NOT APPLE) - set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL OFF) -endif() + if(NOT APPLE) + set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL OFF) + endif() -if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) - message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_METAL") - set(LLAMA_METAL ON) -else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_METAL") - set(LLAMA_METAL OFF) -endif() + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) + message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_METAL") + set(LLAMA_METAL ON) + else() + message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_METAL") + set(LLAMA_METAL OFF) + endif() -# simdjson for ggml backend -find_package(simdjson QUIET) -if(simdjson_FOUND) - message(STATUS "SIMDJSON found") -else() - message(STATUS "Downloading SIMDJSON source") + # setup llama.cpp + message(STATUS "Downloading llama.cpp source") include(FetchContent) FetchContent_Declare( - simdjson - GIT_REPOSITORY https://github.com/simdjson/simdjson.git - GIT_TAG tags/v3.2.1 - GIT_SHALLOW TRUE) - - if(MSVC) - if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - get_property( - compile_options - DIRECTORY - PROPERTY COMPILE_OPTIONS - ) - set_property( - DIRECTORY - APPEND - PROPERTY COMPILE_OPTIONS - -Wno-undef - -Wno-suggest-override - -Wno-documentation - -Wno-sign-conversion - -Wno-extra-semi-stmt - -Wno-old-style-cast - -Wno-error=unused-parameter - -Wno-error=unused-template - -Wno-conditional-uninitialized - -Wno-implicit-int-conversion - -Wno-shorten-64-to-32 - -Wno-range-loop-bind-reference - -Wno-format-nonliteral - -Wno-unused-exception-parameter - -Wno-unused-member-function - ) - unset(compile_options) - elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") - set_property( - DIRECTORY - APPEND - PROPERTY COMPILE_OPTIONS - /wd4100 # unreferenced formal parameter - ) + llama + GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git + GIT_TAG b1383 + PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched + GIT_SHALLOW TRUE + ) + FetchContent_MakeAvailable(llama) + set_property(TARGET ggml PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET common PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET llama PROPERTY POSITION_INDEPENDENT_CODE ON) + + # setup simdjson + find_package(simdjson QUIET) + if(simdjson_FOUND) + message(STATUS "SIMDJSON found") + else() + message(STATUS "Downloading SIMDJSON source") + include(FetchContent) + FetchContent_Declare( + simdjson + GIT_REPOSITORY https://github.com/simdjson/simdjson.git + GIT_TAG tags/v3.2.1 + GIT_SHALLOW TRUE) + + if(MSVC) + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + get_property( + compile_options + DIRECTORY + PROPERTY COMPILE_OPTIONS + ) + set_property( + DIRECTORY + APPEND + PROPERTY COMPILE_OPTIONS + -Wno-undef + -Wno-suggest-override + -Wno-documentation + -Wno-sign-conversion + -Wno-extra-semi-stmt + -Wno-old-style-cast + -Wno-error=unused-parameter + -Wno-error=unused-template + -Wno-conditional-uninitialized + -Wno-implicit-int-conversion + -Wno-shorten-64-to-32 + -Wno-range-loop-bind-reference + -Wno-format-nonliteral + -Wno-unused-exception-parameter + -Wno-unused-member-function + ) + unset(compile_options) + elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + set_property( + DIRECTORY + APPEND + PROPERTY COMPILE_OPTIONS + /wd4100 # unreferenced formal parameter + ) + endif() endif() - endif() - FetchContent_MakeAvailable(simdjson) - set_property(TARGET simdjson PROPERTY POSITION_INDEPENDENT_CODE ON) + FetchContent_MakeAvailable(simdjson) + set_property(TARGET simdjson PROPERTY POSITION_INDEPENDENT_CODE ON) - message(STATUS "Downloading SIMDJSON source -- done") + message(STATUS "Downloading SIMDJSON source -- done") + endif() endif() -add_subdirectory(thirdparty) - wasmedge_add_library(wasmedgePluginWasiNN SHARED wasinnenv.cpp @@ -119,9 +136,13 @@ target_include_directories(wasmedgePluginWasiNN PUBLIC $ ${CMAKE_CURRENT_SOURCE_DIR} - ${PROJECT_SOURCE_DIR}/thirdparty/ggml ) +if(BACKEND STREQUAL "ggml") + target_include_directories(wasmedgePluginWasiNN PUBLIC ${CMAKE_BINARY_DIR}/_deps/llama-src) + target_link_libraries(wasmedgePluginWasiNN PRIVATE common simdjson) +endif() + if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmedgePluginWasiNN PRIVATE @@ -134,11 +155,6 @@ else() ) endif() -string(TOLOWER ${WASMEDGE_PLUGIN_WASI_NN_BACKEND} BACKEND) -if(BACKEND STREQUAL "ggml") - target_link_libraries(wasmedgePluginWasiNN PRIVATE llama simdjson) -endif() - include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 0d892f33..bd19cc1a 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -238,8 +238,6 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, } // Initialize ggml model with model parameters. - gpt_params Params; - llama_backend_init(Params.numa); GraphRef.ModelFilePath = ModelFilePath; llama_model_params ModelParams = llama_model_default_params(); ModelParams.n_gpu_layers = GraphRef.NGPULayers; diff --git a/plugins/wasi_nn/thirdparty/CMakeLists.txt b/plugins/wasi_nn/thirdparty/CMakeLists.txt deleted file mode 100644 index 94db3597..00000000 --- a/plugins/wasi_nn/thirdparty/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2023 Second State INC - -if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) - string(TOLOWER ${WASMEDGE_PLUGIN_WASI_NN_BACKEND} BACKEND) - if(BACKEND STREQUAL "ggml") - add_subdirectory(ggml) - endif() -endif() diff --git a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt b/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt deleted file mode 100644 index 75867a3d..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/CMakeLists.txt +++ /dev/null @@ -1,712 +0,0 @@ -# -# Option list -# - -if (APPLE) - set(LLAMA_METAL_DEFAULT ON) -else() - set(LLAMA_METAL_DEFAULT OFF) -endif() - -# general -option(LLAMA_STATIC "llama: static link libraries" OFF) -option(LLAMA_NATIVE "llama: enable -march=native flag" OFF) -option(LLAMA_LTO "llama: enable link time optimization" OFF) - -# debug -option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON) -option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF) -option(LLAMA_GPROF "llama: enable gprof" OFF) - -# sanitizers -option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) -option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) -option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) - -# instruction set specific -if (LLAMA_NATIVE) - set(INS_ENB OFF) -else() - set(INS_ENB ON) -endif() - -option(LLAMA_AVX "llama: enable AVX" ON) -option(LLAMA_AVX2 "llama: enable AVX2" ON) -option(LLAMA_AVX512 "llama: enable AVX512" OFF) -option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) -option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) -option(LLAMA_FMA "llama: enable FMA" ON) -# in MSVC F16C is implied with AVX2/AVX512 -if (NOT MSVC) - option(LLAMA_F16C "llama: enable F16C" ON) -endif() - -# 3rd party libs -option(LLAMA_ACCELERATE "llama: enable Accelerate framework" OFF) -option(LLAMA_BLAS "llama: use BLAS" OFF) -set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") -option(LLAMA_CUBLAS "llama: use CUDA" OFF) -#option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF) -option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) -set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") -set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") -option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) -set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") -set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING - "llama: max. batch size for using peer access") -option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) -option(LLAMA_CLBLAST "llama: use CLBlast" OFF) -option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT}) -option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF) -option(LLAMA_MPI "llama: use MPI" OFF) -option(LLAMA_K_QUANTS "llama: use k-quants" ON) -option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) - -option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) -option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) -option(LLAMA_BUILD_SERVER "llama: build server example" ON) - -# -# Compile flags -# - -set(CMAKE_CXX_STANDARD 11) -set(CMAKE_CXX_STANDARD_REQUIRED true) -set(CMAKE_C_STANDARD 11) -set(CMAKE_C_STANDARD_REQUIRED true) -set(THREADS_PREFER_PTHREAD_FLAG ON) -find_package(Threads REQUIRED) -include(CheckCXXCompilerFlag) - -if (NOT MSVC) - if (LLAMA_SANITIZE_THREAD) - add_compile_options(-fsanitize=thread) - link_libraries(-fsanitize=thread) - endif() - - if (LLAMA_SANITIZE_ADDRESS) - add_compile_options(-fsanitize=address -fno-omit-frame-pointer) - link_libraries(-fsanitize=address) - endif() - - if (LLAMA_SANITIZE_UNDEFINED) - add_compile_options(-fsanitize=undefined) - link_libraries(-fsanitize=undefined) - endif() -endif() - -if (APPLE AND LLAMA_ACCELERATE) - find_library(ACCELERATE_FRAMEWORK Accelerate) - if (ACCELERATE_FRAMEWORK) - message(STATUS "Accelerate framework found") - - add_compile_definitions(GGML_USE_ACCELERATE) - add_compile_definitions(ACCELERATE_NEW_LAPACK) - add_compile_definitions(ACCELERATE_LAPACK_ILP64) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) - else() - message(WARNING "Accelerate framework not found") - endif() -endif() - -if (LLAMA_METAL) - find_library(FOUNDATION_LIBRARY Foundation REQUIRED) - find_library(METAL_FRAMEWORK Metal REQUIRED) - find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) - - message(STATUS "Metal framework found") - set(GGML_HEADERS_METAL ggml-metal.h) - set(GGML_SOURCES_METAL ggml-metal.m) - - add_compile_definitions(GGML_USE_METAL) - if (LLAMA_METAL_NDEBUG) - add_compile_definitions(GGML_METAL_NDEBUG) - endif() - - # get full path to the file - #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/") - - # copy ggml-metal.metal to bin directory - configure_file(ggml-metal.metal bin/ggml-metal.metal COPYONLY) - - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} - ${FOUNDATION_LIBRARY} - ${METAL_FRAMEWORK} - ${METALKIT_FRAMEWORK} - ) -endif() -if (LLAMA_BLAS) - if (LLAMA_STATIC) - set(BLA_STATIC ON) - endif() - if ($(CMAKE_VERSION) VERSION_GREATER_EQUAL 3.22) - set(BLA_SIZEOF_INTEGER 8) - endif() - - set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) - find_package(BLAS) - - if (BLAS_FOUND) - message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") - - if ("${BLAS_INCLUDE_DIRS}" STREQUAL "") - # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. - # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 - find_package(PkgConfig REQUIRED) - if (${LLAMA_BLAS_VENDOR} MATCHES "Generic") - pkg_check_modules(DepBLAS REQUIRED blas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS") - pkg_check_modules(DepBLAS REQUIRED openblas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME") - pkg_check_modules(DepBLAS REQUIRED blis) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS") - pkg_check_modules(DepBLAS REQUIRED blas-atlas) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS") - pkg_check_modules(DepBLAS REQUIRED flexiblas_api) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel") - # all Intel* libraries share the same include path - pkg_check_modules(DepBLAS REQUIRED mkl-sdl) - elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC") - # this doesn't provide pkg-config - # suggest to assign BLAS_INCLUDE_DIRS on your own - if ("${NVHPC_VERSION}" STREQUAL "") - message(WARNING "Better to set NVHPC_VERSION") - else() - set(DepBLAS_FOUND ON) - set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") - endif() - endif() - if (DepBLAS_FOUND) - set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) - else() - message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" - " detected by pkgconfig, trying to find cblas.h from possible paths...") - find_path(BLAS_INCLUDE_DIRS - NAMES cblas.h - HINTS - /usr/include - /usr/local/include - /usr/include/openblas - /opt/homebrew/opt/openblas/include - /usr/local/opt/openblas/include - /usr/include/x86_64-linux-gnu/openblas/include - ) - endif() - endif() - - message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") - add_compile_options(${BLAS_LINKER_FLAGS}) - add_compile_definitions(GGML_USE_OPENBLAS) - if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${LLAMA_BLAS_VENDOR} MATCHES "Generic" OR ${LLAMA_BLAS_VENDOR} MATCHES "Intel")) - add_compile_definitions(GGML_BLAS_USE_MKL) - endif() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) - set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) - - else() - message(WARNING "BLAS not found, please refer to " - "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" - " to set correct LLAMA_BLAS_VENDOR") - endif() -endif() - -if (LLAMA_K_QUANTS) - set(GGML_HEADERS_EXTRA k_quants.h) - set(GGML_SOURCES_EXTRA k_quants.c) - add_compile_definitions(GGML_USE_K_QUANTS) - if (LLAMA_QKK_64) - add_compile_definitions(GGML_QKK_64) - endif() -endif() - -if (LLAMA_CUBLAS) - cmake_minimum_required(VERSION 3.17) - - find_package(CUDAToolkit) - if (CUDAToolkit_FOUND) - message(STATUS "cuBLAS found") - - enable_language(CUDA) - - set(GGML_HEADERS_CUDA ggml-cuda.h) - set(GGML_SOURCES_CUDA ggml-cuda.cu) - - add_compile_definitions(GGML_USE_CUBLAS) -# if (LLAMA_CUDA_CUBLAS) -# add_compile_definitions(GGML_CUDA_CUBLAS) -# endif() - if (LLAMA_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) - endif() - add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) - add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) - if (DEFINED LLAMA_CUDA_DMMV_Y) - add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_DMMV_Y}) # for backwards compatibility - endif() - if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) - add_compile_definitions(GGML_CUDA_F16) - endif() - add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) - add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE}) - - if (LLAMA_STATIC) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) - else() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) - endif() - - if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - # 52 == lowest CUDA 12 standard - # 60 == f16 CUDA intrinsics - # 61 == integer CUDA intrinsics - # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster - if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics - else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics - endif() - endif() - message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") - - else() - message(WARNING "cuBLAS not found") - endif() -endif() - -if (LLAMA_MPI) - cmake_minimum_required(VERSION 3.10) - find_package(MPI) - if (MPI_C_FOUND) - message(STATUS "MPI found") - set(GGML_HEADERS_MPI ggml-mpi.h) - set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h) - add_compile_definitions(GGML_USE_MPI) - add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS}) - if (NOT MSVC) - add_compile_options(-Wno-cast-qual) - endif() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES}) - set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS}) - # Even if you're only using the C header, C++ programs may bring in MPI - # C++ functions, so more linkage is needed - if (MPI_CXX_FOUND) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_CXX_LIBRARIES}) - endif() - else() - message(WARNING "MPI not found") - endif() -endif() - -if (LLAMA_CLBLAST) - find_package(CLBlast) - if (CLBlast_FOUND) - message(STATUS "CLBlast found") - - set(GGML_HEADERS_OPENCL ggml-opencl.h) - set(GGML_SOURCES_OPENCL ggml-opencl.cpp) - - add_compile_definitions(GGML_USE_CLBLAST) - - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) - else() - message(WARNING "CLBlast not found") - endif() -endif() - -if (LLAMA_HIPBLAS) - list(APPEND CMAKE_PREFIX_PATH /opt/rocm) - - if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") - message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang") - endif() - if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") - message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++") - endif() - - find_package(hip) - find_package(hipblas) - find_package(rocblas) - - if (${hipblas_FOUND} AND ${hip_FOUND}) - message(STATUS "HIP and hipBLAS found") - add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS) - add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) - if (BUILD_SHARED_LIBS) - set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON) - endif() - if (LLAMA_CUDA_FORCE_DMMV) - target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV) - endif() - target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) - target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) - target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) - set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX) - target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas) - - if (LLAMA_STATIC) - message(FATAL_ERROR "Static linking not supported for HIP/ROCm") - endif() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm) - else() - message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm") - endif() -endif() - -if (LLAMA_ALL_WARNINGS) - if (NOT MSVC) - set(warning_flags -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function) - set(c_flags -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int -Werror=implicit-function-declaration) - set(cxx_flags -Wmissing-declarations -Wmissing-noreturn) - set(host_cxx_flags "") - - if (CMAKE_C_COMPILER_ID MATCHES "Clang") - set(warning_flags ${warning_flags} -Wunreachable-code-break -Wunreachable-code-return) - set(host_cxx_flags ${host_cxx_flags} -Wmissing-prototypes -Wextra-semi) - - if ( - (CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 3.8.0) OR - (CMAKE_C_COMPILER_ID STREQUAL "AppleClang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 7.3.0) - ) - set(c_flags ${c_flags} -Wdouble-promotion) - endif() - elseif (CMAKE_C_COMPILER_ID STREQUAL "GNU") - set(c_flags ${c_flags} -Wdouble-promotion) - set(host_cxx_flags ${host_cxx_flags} -Wno-array-bounds) - - if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 7.1.0) - set(host_cxx_flags ${host_cxx_flags} -Wno-format-truncation) - endif() - if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1.0) - set(host_cxx_flags ${host_cxx_flags} -Wextra-semi) - endif() - endif() - else() - # todo : msvc - endif() - - set(c_flags ${c_flags} ${warning_flags}) - set(cxx_flags ${cxx_flags} ${warning_flags}) - add_compile_options("$<$:${c_flags}>" - "$<$:${cxx_flags}>" - "$<$:${host_cxx_flags}>") - -endif() - -if (NOT MSVC) - set(cuda_flags -Wno-pedantic) -endif() -set(cuda_flags ${cxx_flags} -use_fast_math ${cuda_flags}) - -list(JOIN host_cxx_flags " " cuda_host_flags) # pass host compiler flags as a single argument -if (NOT cuda_host_flags STREQUAL "") - set(cuda_flags ${cuda_flags} -Xcompiler ${cuda_host_flags}) -endif() - -add_compile_options("$<$:${cuda_flags}>") - -if (WIN32) - add_compile_definitions(_CRT_SECURE_NO_WARNINGS) - - if (BUILD_SHARED_LIBS) - set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) - endif() -endif() - -if (LLAMA_LTO) - include(CheckIPOSupported) - check_ipo_supported(RESULT result OUTPUT output) - if (result) - set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) - else() - message(WARNING "IPO is not supported: ${output}") - endif() -endif() - -# Architecture specific -# TODO: probably these flags need to be tweaked on some architectures -# feel free to update the Makefile for your architecture and send a pull request or issue -message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") -if (MSVC) - string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR) - message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}") -else () - set(CMAKE_GENERATOR_PLATFORM_LWR "") -endif () - -if (NOT MSVC) - if (LLAMA_STATIC) - add_link_options(-static) - if (MINGW) - add_link_options(-static-libgcc -static-libstdc++) - endif() - endif() - if (LLAMA_GPROF) - add_compile_options(-pg) - endif() -endif() - -if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") OR ("${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "arm64")) - message(STATUS "ARM detected") - if (MSVC) - add_compile_definitions(__ARM_NEON) - add_compile_definitions(__ARM_FEATURE_FMA) - add_compile_definitions(__ARM_FEATURE_DOTPROD) - # add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) # MSVC doesn't support vdupq_n_f16, vld1q_f16, vst1q_f16 - add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead - else() - check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E) - if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "") - add_compile_options(-mfp16-format=ieee) - endif() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") - # Raspberry Pi 1, Zero - add_compile_options(-mfpu=neon-fp-armv8 -mno-unaligned-access) - endif() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7") - # Raspberry Pi 2 - add_compile_options(-mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations) - endif() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") - # Raspberry Pi 3, 4, Zero 2 (32-bit) - add_compile_options(-mno-unaligned-access) - endif() - endif() -elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "^(x86_64|i686|amd64|x64)$" ) - message(STATUS "x86 detected") - if (MSVC) - if (LLAMA_AVX512) - add_compile_options($<$:/arch:AVX512>) - add_compile_options($<$:/arch:AVX512>) - # MSVC has no compile-time flags enabling specific - # AVX512 extensions, neither it defines the - # macros corresponding to the extensions. - # Do it manually. - if (LLAMA_AVX512_VBMI) - add_compile_definitions($<$:__AVX512VBMI__>) - add_compile_definitions($<$:__AVX512VBMI__>) - endif() - if (LLAMA_AVX512_VNNI) - add_compile_definitions($<$:__AVX512VNNI__>) - add_compile_definitions($<$:__AVX512VNNI__>) - endif() - elseif (LLAMA_AVX2) - add_compile_options($<$:/arch:AVX2>) - add_compile_options($<$:/arch:AVX2>) - elseif (LLAMA_AVX) - add_compile_options($<$:/arch:AVX>) - add_compile_options($<$:/arch:AVX>) - endif() - else() - if (LLAMA_NATIVE) - add_compile_options(-march=native) - endif() - if (LLAMA_F16C) - add_compile_options(-mf16c) - endif() - if (LLAMA_FMA) - add_compile_options(-mfma) - endif() - if (LLAMA_AVX) - add_compile_options(-mavx) - endif() - if (LLAMA_AVX2) - add_compile_options(-mavx2) - endif() - if (LLAMA_AVX512) - add_compile_options(-mavx512f) - add_compile_options(-mavx512bw) - endif() - if (LLAMA_AVX512_VBMI) - add_compile_options(-mavx512vbmi) - endif() - if (LLAMA_AVX512_VNNI) - add_compile_options(-mavx512vnni) - endif() - endif() -elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") - message(STATUS "PowerPC detected") - add_compile_options(-mcpu=native -mtune=native) - #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) -else() - message(STATUS "Unknown architecture") -endif() - -# -# POSIX conformance -# - -# clock_gettime came in POSIX.1b (1993) -# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional -# posix_memalign came in POSIX.1-2001 / SUSv3 -# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985) -add_compile_definitions(_XOPEN_SOURCE=600) - -# Somehow in OpenBSD whenever POSIX conformance is specified -# some string functions rely on locale_t availability, -# which was introduced in POSIX.1-2008, forcing us to go higher -if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") - remove_definitions(-D_XOPEN_SOURCE=600) - add_compile_definitions(_XOPEN_SOURCE=700) -endif() - -# Data types, macros and functions related to controlling CPU affinity and -# some memory allocation are available on Linux through GNU extensions in libc -if (CMAKE_SYSTEM_NAME MATCHES "Linux") - add_compile_definitions(_GNU_SOURCE) -endif() - -# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1, -# and on macOS its availability depends on enabling Darwin extensions -# similarly on DragonFly, enabling BSD extensions is necessary -if ( - CMAKE_SYSTEM_NAME MATCHES "Darwin" OR - CMAKE_SYSTEM_NAME MATCHES "iOS" OR - CMAKE_SYSTEM_NAME MATCHES "tvOS" OR - CMAKE_SYSTEM_NAME MATCHES "DragonFly" -) - add_compile_definitions(_DARWIN_C_SOURCE) -endif() - -# alloca is a non-standard interface that is not visible on BSDs when -# POSIX conformance is specified, but not all of them provide a clean way -# to enable it in such cases -if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD") - add_compile_definitions(__BSD_VISIBLE) -endif() -if (CMAKE_SYSTEM_NAME MATCHES "NetBSD") - add_compile_definitions(_NETBSD_SOURCE) -endif() -if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") - add_compile_definitions(_BSD_SOURCE) -endif() - -# -# libraries -# - -# ggml - -if (GGML_USE_CPU_HBM) - add_definitions(-DGGML_USE_CPU_HBM) - find_library(memkind memkind REQUIRED) -endif() - -wasmedge_add_library(ggml OBJECT - ggml.c - ggml.h - ggml-alloc.c - ggml-alloc.h - ggml-backend.c - ggml-backend.h - common.cpp - common.h - sampling.cpp - sampling.h - ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA} - ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL} - ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} - ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI} - ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA} - ) - -target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES}) -target_compile_features(ggml PUBLIC c_std_11) # don't bump -target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS}) -if (GGML_USE_CPU_HBM) - target_link_libraries(ggml PUBLIC memkind) -endif() - -wasmedge_add_library(ggml_static STATIC $) -if (BUILD_SHARED_LIBS) - set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON) - wasmedge_add_library(ggml_shared SHARED $) - target_link_libraries(ggml_shared PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS}) - install(TARGETS ggml_shared LIBRARY) -endif() - -# llama - -wasmedge_add_library(llama - llama.cpp - llama.h - ) - -target_include_directories(llama PUBLIC .) -target_compile_features(llama PUBLIC cxx_std_11) # don't bump -target_link_libraries(llama PRIVATE - ggml - ${LLAMA_EXTRA_LIBS} - ) - -if (BUILD_SHARED_LIBS) - set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD) - if (LLAMA_METAL) - set_target_properties(llama PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") - endif() -endif() - -# disable warnings -if (NOT WIN32) - target_compile_options(ggml - PRIVATE - -Wno-unused-parameter - -Wno-unused-variable - -Wno-unused-but-set-variable - -Wno-unused-function - -Wno-missing-braces - ) - target_compile_options(llama - PRIVATE - -Wno-unused-parameter - -Wno-unused-variable - -Wno-unused-but-set-variable - -Wno-unused-function - -Wno-missing-braces - ) -else() - target_compile_options(ggml - PRIVATE - -Wno-string-conversion - -Wno-sign-conversion - -Wno-macro-redefined - -Wno-missing-prototypes - -Wno-unreachable-code-return - -Wno-shorten-64-to-32 - -Wno-implicit-int-conversion - -Wno-implicit-float-conversion - -Wno-float-conversion - -Wno-unused-macros - -Wno-unreachable-code-break - -Wno-cast-align - -Wno-undef - -Wno-shadow-uncaptured-local - -Wno-unreachable-code - -Wno-cast-function-type - -Wno-format-nonliteral - -Wno-extra-semi-stmt - -Wno-bad-function-cast - ) - target_compile_options(llama - PRIVATE - -Wno-string-conversion - -Wno-sign-conversion - -Wno-macro-redefined - -Wno-missing-prototypes - -Wno-unreachable-code-return - -Wno-shorten-64-to-32 - -Wno-implicit-int-conversion - -Wno-implicit-float-conversion - -Wno-float-conversion - -Wno-unused-macros - -Wno-unreachable-code-break - -Wno-cast-align - -Wno-undef - -Wno-shadow-uncaptured-local - -Wno-unreachable-code - -Wno-cast-function-type - -Wno-format-nonliteral - -Wno-extra-semi-stmt - -Wno-bad-function-cast - ) -endif() diff --git a/plugins/wasi_nn/thirdparty/ggml/LICENSE b/plugins/wasi_nn/thirdparty/ggml/LICENSE deleted file mode 100644 index 8c955688..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2023 Georgi Gerganov - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file diff --git a/plugins/wasi_nn/thirdparty/ggml/README.md b/plugins/wasi_nn/thirdparty/ggml/README.md deleted file mode 100644 index 594704db..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# GGML and llama.cpp - -[GGML][] and [llama.cpp][] are open-source projects in the machine learning domain. GGML is a tensor library for machine learning, developed in C. On the other hand, llama.cpp serves as a LLaMA model inference engine and is implemented in C/C++. - -This directory contains the source code from both llama.cpp and GGML. The code in this directory is licensed under the MIT License. For more details, please refer to the [LICENSE](./LICENSE) file. - -WasmEdge includes support for GGML and llama.cpp through its WASI-NN plugin, enabling the execution of machine learning models in WebAssembly. Within the WasmEdge WASI-NN plugin, we have added functionality for GGML model loading and LLaMA model inference. - -[GGML]: http://ggml.ai -[llama.cpp]: https://github.com/ggerganov/ggml diff --git a/plugins/wasi_nn/thirdparty/ggml/common.cpp b/plugins/wasi_nn/thirdparty/ggml/common.cpp deleted file mode 100644 index 79e645c9..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/common.cpp +++ /dev/null @@ -1,1253 +0,0 @@ -#include "common.h" -#include "llama.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(__APPLE__) && defined(__MACH__) -#include -#include -#endif - -#if defined(_WIN32) -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -# define NOMINMAX -#endif -#include -#include -#include -#include -#include -#else -#include -#include -#include -#endif - -#if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif - -int32_t get_num_physical_cores() { -#ifdef __linux__ - // enumerate the set of thread siblings, num entries is num cores - std::unordered_set siblings; - for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) { - std::ifstream thread_siblings("/sys/devices/system/cpu" - + std::to_string(cpu) + "/topology/thread_siblings"); - if (!thread_siblings.is_open()) { - break; // no more cpus - } - std::string line; - if (std::getline(thread_siblings, line)) { - siblings.insert(line); - } - } - if (!siblings.empty()) { - return static_cast(siblings.size()); - } -#elif defined(__APPLE__) && defined(__MACH__) - int32_t num_physical_cores; - size_t len = sizeof(num_physical_cores); - int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0); - if (result == 0) { - return num_physical_cores; - } - result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0); - if (result == 0) { - return num_physical_cores; - } -#elif defined(_WIN32) - //TODO: Implement -#endif - unsigned int n_threads = std::thread::hardware_concurrency(); - return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; -} - -void process_escapes(std::string& input) { - std::size_t input_len = input.length(); - std::size_t output_idx = 0; - - for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { - if (input[input_idx] == '\\' && input_idx + 1 < input_len) { - switch (input[++input_idx]) { - case 'n': input[output_idx++] = '\n'; break; - case 'r': input[output_idx++] = '\r'; break; - case 't': input[output_idx++] = '\t'; break; - case '\'': input[output_idx++] = '\''; break; - case '\"': input[output_idx++] = '\"'; break; - case '\\': input[output_idx++] = '\\'; break; - default: input[output_idx++] = '\\'; - input[output_idx++] = input[input_idx]; break; - } - } else { - input[output_idx++] = input[input_idx]; - } - } - - input.resize(output_idx); -} - -bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { - bool invalid_param = false; - std::string arg; - gpt_params default_params; - const std::string arg_prefix = "--"; - llama_sampling_params & sparams = params.sampling_params; - - for (int i = 1; i < argc; i++) { - arg = argv[i]; - if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { - std::replace(arg.begin(), arg.end(), '_', '-'); - } - - if (arg == "-s" || arg == "--seed") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.seed = std::stoul(argv[i]); - } else if (arg == "-t" || arg == "--threads") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_threads = std::stoi(argv[i]); - if (params.n_threads <= 0) { - params.n_threads = std::thread::hardware_concurrency(); - } - } else if (arg == "-tb" || arg == "--threads-batch") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_threads_batch = std::stoi(argv[i]); - if (params.n_threads_batch <= 0) { - params.n_threads_batch = std::thread::hardware_concurrency(); - } - } else if (arg == "-p" || arg == "--prompt") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.prompt = argv[i]; - } else if (arg == "-e" || arg == "--escape") { - params.escape = true; - } else if (arg == "--prompt-cache") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.path_prompt_cache = argv[i]; - } else if (arg == "--prompt-cache-all") { - params.prompt_cache_all = true; - } else if (arg == "--prompt-cache-ro") { - params.prompt_cache_ro = true; - } else if (arg == "-f" || arg == "--file") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - break; - } - // store the external file name in params - params.prompt_file = argv[i]; - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); - if (!params.prompt.empty() && params.prompt.back() == '\n') { - params.prompt.pop_back(); - } - } else if (arg == "-n" || arg == "--n-predict") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_predict = std::stoi(argv[i]); - } else if (arg == "--top-k") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.top_k = std::stoi(argv[i]); - } else if (arg == "-c" || arg == "--ctx-size") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_ctx = std::stoi(argv[i]); - } else if (arg == "--rope-freq-base") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.rope_freq_base = std::stof(argv[i]); - } else if (arg == "--rope-freq-scale") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.rope_freq_scale = std::stof(argv[i]); - } else if (arg == "--rope-scale") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.rope_freq_scale = 1.0f/std::stof(argv[i]); - } else if (arg == "--memory-f32") { - params.memory_f16 = false; - } else if (arg == "--top-p") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.top_p = std::stof(argv[i]); - } else if (arg == "--temp") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.temp = std::stof(argv[i]); - } else if (arg == "--tfs") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.tfs_z = std::stof(argv[i]); - } else if (arg == "--typical") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.typical_p = std::stof(argv[i]); - } else if (arg == "--repeat-last-n") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.repeat_last_n = std::stoi(argv[i]); - } else if (arg == "--repeat-penalty") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.repeat_penalty = std::stof(argv[i]); - } else if (arg == "--frequency-penalty") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.frequency_penalty = std::stof(argv[i]); - } else if (arg == "--presence-penalty") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.presence_penalty = std::stof(argv[i]); - } else if (arg == "--mirostat") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.mirostat = std::stoi(argv[i]); - } else if (arg == "--mirostat-lr") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.mirostat_eta = std::stof(argv[i]); - } else if (arg == "--mirostat-ent") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.mirostat_tau = std::stof(argv[i]); - } else if (arg == "--cfg-negative-prompt") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.cfg_negative_prompt = argv[i]; - } else if (arg == "--cfg-negative-prompt-file") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - break; - } - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(sparams.cfg_negative_prompt)); - if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') { - sparams.cfg_negative_prompt.pop_back(); - } - } else if (arg == "--cfg-scale") { - if (++i >= argc) { - invalid_param = true; - break; - } - sparams.cfg_scale = std::stof(argv[i]); - } else if (arg == "-b" || arg == "--batch-size") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_batch = std::stoi(argv[i]); - } else if (arg == "--keep") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_keep = std::stoi(argv[i]); - } else if (arg == "--draft") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_draft = std::stoi(argv[i]); - } else if (arg == "--chunks") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_chunks = std::stoi(argv[i]); - } else if (arg == "-np" || arg == "--parallel") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_parallel = std::stoi(argv[i]); - } else if (arg == "-ns" || arg == "--sequences") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_sequences = std::stoi(argv[i]); - } else if (arg == "-m" || arg == "--model") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.model = argv[i]; - } else if (arg == "-md" || arg == "--model-draft") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.model_draft = argv[i]; - } else if (arg == "-a" || arg == "--alias") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.model_alias = argv[i]; - } else if (arg == "--lora") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.lora_adapter.push_back(std::make_tuple(argv[i], 1.0f)); - params.use_mmap = false; - } else if (arg == "--lora-scaled") { - if (++i >= argc) { - invalid_param = true; - break; - } - const char * lora_adapter = argv[i]; - if (++i >= argc) { - invalid_param = true; - break; - } - params.lora_adapter.push_back(std::make_tuple(lora_adapter, std::stof(argv[i]))); - params.use_mmap = false; - } else if (arg == "--lora-base") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.lora_base = argv[i]; - } else if (arg == "--mmproj") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.mmproj = argv[i]; - } else if (arg == "--image") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.image = argv[i]; - } else if (arg == "-i" || arg == "--interactive") { - params.interactive = true; - } else if (arg == "--embedding") { - params.embedding = true; - } else if (arg == "--interactive-first") { - params.interactive_first = true; - } else if (arg == "-ins" || arg == "--instruct") { - params.instruct = true; - } else if (arg == "--infill") { - params.infill = true; - } else if (arg == "--multiline-input") { - params.multiline_input = true; - } else if (arg == "--simple-io") { - params.simple_io = true; - } else if (arg == "-cb" || arg == "--cont-batching") { - params.cont_batching = true; - } else if (arg == "--color") { - params.use_color = true; - } else if (arg == "--mlock") { - params.use_mlock = true; - } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { - if (++i >= argc) { - invalid_param = true; - break; - } -#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD - params.n_gpu_layers = std::stoi(argv[i]); -#else - fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); -#endif - } else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") { - if (++i >= argc) { - invalid_param = true; - break; - } -#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD - params.n_gpu_layers_draft = std::stoi(argv[i]); -#else - fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); -#endif - } else if (arg == "--main-gpu" || arg == "-mg") { - if (++i >= argc) { - invalid_param = true; - break; - } -#ifdef GGML_USE_CUBLAS - params.main_gpu = std::stoi(argv[i]); -#else - fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n"); -#endif - } else if (arg == "--tensor-split" || arg == "-ts") { - if (++i >= argc) { - invalid_param = true; - break; - } -#ifdef GGML_USE_CUBLAS - std::string arg_next = argv[i]; - - // split string by , and / - const std::regex regex{R"([,/]+)"}; - std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; - std::vector split_arg{it, {}}; - GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES); - - for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) { - if (i < split_arg.size()) { - params.tensor_split[i] = std::stof(split_arg[i]); - } else { - params.tensor_split[i] = 0.0f; - } - } -#else - fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n"); -#endif // GGML_USE_CUBLAS - } else if (arg == "--no-mul-mat-q" || arg == "-nommq") { -#ifdef GGML_USE_CUBLAS - params.mul_mat_q = false; -#else - fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n"); -#endif // GGML_USE_CUBLAS - } else if (arg == "--no-mmap") { - params.use_mmap = false; - } else if (arg == "--numa") { - params.numa = true; - } else if (arg == "--verbose-prompt") { - params.verbose_prompt = true; - } else if (arg == "-r" || arg == "--reverse-prompt") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.antiprompt.push_back(argv[i]); - } else if (arg == "-ld" || arg == "--logdir") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.logdir = argv[i]; - - if (params.logdir.back() != DIRECTORY_SEPARATOR) { - params.logdir += DIRECTORY_SEPARATOR; - } - } else if (arg == "--perplexity" || arg == "--all-logits") { - params.logits_all = true; - } else if (arg == "--ppl-stride") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.ppl_stride = std::stoi(argv[i]); - } else if (arg == "--ppl-output-type") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.ppl_output_type = std::stoi(argv[i]); - } else if (arg == "--hellaswag") { - params.hellaswag = true; - } else if (arg == "--hellaswag-tasks") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.hellaswag_tasks = std::stoi(argv[i]); - } else if (arg == "--ignore-eos") { - params.ignore_eos = true; - } else if (arg == "--no-penalize-nl") { - sparams.penalize_nl = false; - } else if (arg == "-l" || arg == "--logit-bias") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::stringstream ss(argv[i]); - llama_token key; - char sign; - std::string value_str; - try { - if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { - sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); - } else { - throw std::exception(); - } - } catch (const std::exception&) { - invalid_param = true; - break; - } - } else if (arg == "-h" || arg == "--help") { - gpt_print_usage(argc, argv, default_params); -#ifndef LOG_DISABLE_LOGS - log_print_usage(); -#endif // LOG_DISABLE_LOGS - exit(0); - } else if (arg == "--random-prompt") { - params.random_prompt = true; - } else if (arg == "--in-prefix-bos") { - params.input_prefix_bos = true; - } else if (arg == "--in-prefix") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.input_prefix = argv[i]; - } else if (arg == "--in-suffix") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.input_suffix = argv[i]; - } else if (arg == "--grammar") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.grammar = argv[i]; - } else if (arg == "--grammar-file") { - if (++i >= argc) { - invalid_param = true; - break; - } - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - break; - } - std::copy( - std::istreambuf_iterator(file), - std::istreambuf_iterator(), - std::back_inserter(params.grammar) - ); -#ifndef LOG_DISABLE_LOGS - // Parse args for logging parameters - } else if ( log_param_single_parse( argv[i] ) ) { - // Do nothing, log_param_single_parse automatically does it's thing - // and returns if a match was found and parsed. - } else if ( log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i] ) ) { - // We have a matching known parameter requiring an argument, - // now we need to check if there is anything after this argv - // and flag invalid_param or parse it. - if (++i >= argc) { - invalid_param = true; - break; - } - if( !log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i-1], argv[i]) ) { - invalid_param = true; - break; - } - // End of Parse args for logging parameters -#endif // LOG_DISABLE_LOGS - } else { - fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); - gpt_print_usage(argc, argv, default_params); - exit(1); - } - } - if (invalid_param) { - fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); - gpt_print_usage(argc, argv, default_params); - exit(1); - } - if (params.prompt_cache_all && - (params.interactive || params.interactive_first || - params.instruct)) { - fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n"); - gpt_print_usage(argc, argv, default_params); - exit(1); - } - - if (params.escape) { - process_escapes(params.prompt); - process_escapes(params.input_prefix); - process_escapes(params.input_suffix); - for (auto & antiprompt : params.antiprompt) { - process_escapes(antiprompt); - } - } - - return true; -} - -void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { - const llama_sampling_params & sparams = params.sampling_params; - - printf("usage: %s [options]\n", argv[0]); - printf("\n"); - printf("options:\n"); - printf(" -h, --help show this help message and exit\n"); - printf(" -i, --interactive run in interactive mode\n"); - printf(" --interactive-first run in interactive mode and wait for input right away\n"); - printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); - printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); - printf(" -r PROMPT, --reverse-prompt PROMPT\n"); - printf(" halt generation at PROMPT, return control in interactive mode\n"); - printf(" (can be specified more than once for multiple prompts).\n"); - printf(" --color colorise output to distinguish prompt and user input from generations\n"); - printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); - printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads); - printf(" -tb N, --threads-batch N\n"); - printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n"); - printf(" -p PROMPT, --prompt PROMPT\n"); - printf(" prompt to start generation with (default: empty)\n"); - printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); - printf(" --prompt-cache FNAME file to cache prompt state for faster startup (default: none)\n"); - printf(" --prompt-cache-all if specified, saves user input and generations to cache as well.\n"); - printf(" not supported with --interactive or other interactive options\n"); - printf(" --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n"); - printf(" --random-prompt start with a randomized prompt.\n"); - printf(" --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n"); - printf(" --in-prefix STRING string to prefix user inputs with (default: empty)\n"); - printf(" --in-suffix STRING string to suffix after user inputs with (default: empty)\n"); - printf(" -f FNAME, --file FNAME\n"); - printf(" prompt file to start generation.\n"); - printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); - printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); - printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); - printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); - printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); - printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); - printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); - printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n); - printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty); - printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty); - printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty); - printf(" --mirostat N use Mirostat sampling.\n"); - printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); - printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat); - printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)sparams.mirostat_eta); - printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)sparams.mirostat_tau); - printf(" -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n"); - printf(" modifies the likelihood of token appearing in the completion,\n"); - printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); - printf(" or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); - printf(" --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); - printf(" --grammar-file FNAME file to read grammar from\n"); - printf(" --cfg-negative-prompt PROMPT\n"); - printf(" negative prompt to use for guidance. (default: empty)\n"); - printf(" --cfg-negative-prompt-file FNAME\n"); - printf(" negative prompt file to use for guidance. (default: empty)\n"); - printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale); - printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n"); - printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n"); - printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n"); - printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); - printf(" --no-penalize-nl do not penalize newline token\n"); - printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); - printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); - printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp); - printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n"); - printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); - printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); - printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); - printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft); - printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); - printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); - printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); - printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); - printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); - printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n"); - if (llama_mlock_supported()) { - printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); - } - if (llama_mmap_supported()) { - printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); - } - printf(" --numa attempt optimizations that help on some NUMA systems\n"); - printf(" if run without this previously, it is recommended to drop the system page cache before using this\n"); - printf(" see https://github.com/ggerganov/llama.cpp/issues/1437\n"); -#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD - printf(" -ngl N, --n-gpu-layers N\n"); - printf(" number of layers to store in VRAM\n"); - printf(" -ngld N, --n-gpu-layers-draft N\n"); - printf(" number of layers to store in VRAM for the draft model\n"); - printf(" -ts SPLIT --tensor-split SPLIT\n"); - printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); - printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n"); -#ifdef GGML_USE_CUBLAS - printf(" -nommq, --no-mul-mat-q\n"); - printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n"); - printf(" Not recommended since this is both slower and uses more VRAM.\n"); -#endif // GGML_USE_CUBLAS -#endif - printf(" --verbose-prompt print prompt before generation\n"); - fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); - printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); - printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n"); - printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); - printf(" -m FNAME, --model FNAME\n"); - printf(" model path (default: %s)\n", params.model.c_str()); - printf(" -md FNAME, --model-draft FNAME\n"); - printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str()); - printf(" -ld LOGDIR, --logdir LOGDIR\n"); - printf(" path under which to save YAML logs (no logging if unset)\n"); - printf("\n"); -} - -std::string get_system_info(const gpt_params & params) { - std::ostringstream os; - - os << "system_info: n_threads = " << params.n_threads; - if (params.n_threads_batch != -1) { - os << " (n_threads_batch = " << params.n_threads_batch << ")"; - } - os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info(); - - return os.str(); -} - -std::string gpt_random_prompt(std::mt19937 & rng) { - const int r = rng() % 10; - switch (r) { - case 0: return "So"; - case 1: return "Once upon a time"; - case 2: return "When"; - case 3: return "The"; - case 4: return "After"; - case 5: return "If"; - case 6: return "import"; - case 7: return "He"; - case 8: return "She"; - case 9: return "They"; - } - - GGML_UNREACHABLE(); -} - -// -// Model utils -// - -struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) { - auto mparams = llama_model_default_params(); - - if (params.n_gpu_layers != -1) { - mparams.n_gpu_layers = params.n_gpu_layers; - } - mparams.main_gpu = params.main_gpu; - mparams.tensor_split = params.tensor_split; - mparams.use_mmap = params.use_mmap; - mparams.use_mlock = params.use_mlock; - - return mparams; -} - -struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { - auto cparams = llama_context_default_params(); - - cparams.n_ctx = params.n_ctx; - cparams.n_batch = params.n_batch; - cparams.n_threads = params.n_threads; - cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; - cparams.mul_mat_q = params.mul_mat_q; - cparams.seed = params.seed; - cparams.f16_kv = params.memory_f16; - cparams.logits_all = params.logits_all; - cparams.embedding = params.embedding; - cparams.rope_freq_base = params.rope_freq_base; - cparams.rope_freq_scale = params.rope_freq_scale; - - return cparams; -} - -std::tuple llama_init_from_gpt_params(gpt_params & params) { - auto mparams = llama_model_params_from_gpt_params(params); - - llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); - if (model == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); - return std::make_tuple(nullptr, nullptr); - } - - auto cparams = llama_context_params_from_gpt_params(params); - - llama_context * lctx = llama_new_context_with_model(model, cparams); - if (lctx == NULL) { - fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); - llama_free_model(model); - return std::make_tuple(nullptr, nullptr); - } - - for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { - const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]); - float lora_scale = std::get<1>(params.lora_adapter[i]); - int err = llama_model_apply_lora_from_file(model, - lora_adapter.c_str(), - lora_scale, - ((i > 0) || params.lora_base.empty()) - ? NULL - : params.lora_base.c_str(), - params.n_threads); - if (err != 0) { - fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); - llama_free(lctx); - llama_free_model(model); - return std::make_tuple(nullptr, nullptr); - } - } - - if (params.ignore_eos) { - params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY; - } - - { - LOG("warming up the model with an empty run\n"); - - std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; - llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); - llama_kv_cache_tokens_rm(lctx, -1, -1); - llama_reset_timings(lctx); - } - - return std::make_tuple(model, lctx); -} - -// -// Vocab utils -// - -std::vector llama_tokenize( - const struct llama_context * ctx, - const std::string & text, - bool add_bos) { - return llama_tokenize(llama_get_model(ctx), text, add_bos); -} - -std::vector llama_tokenize( - const struct llama_model * model, - const std::string & text, - bool add_bos) { - // upper limit for the number of tokens - int n_tokens = text.length() + add_bos; - std::vector result(n_tokens); - n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos); - if (n_tokens < 0) { - result.resize(-n_tokens); - int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos); - GGML_ASSERT(check == -n_tokens); - } else { - result.resize(n_tokens); - } - return result; -} - -std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) { - std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); - if (n_tokens < 0) { - result.resize(-n_tokens); - int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); - GGML_ASSERT(check == -n_tokens); - } else { - result.resize(n_tokens); - } - - return std::string(result.data(), result.size()); -} - -std::string llama_detokenize_spm(llama_context * ctx, const std::vector & tokens) { - const llama_token bos_id = llama_token_bos(ctx); - - std::string piece; - std::string result; - - for (size_t i = 0; i < tokens.size(); ++i) { - piece = llama_token_to_piece(ctx, tokens[i]); - - // remove the leading space of the first non-BOS token - if (((tokens[0] == bos_id && i == 1) || (tokens[0] != bos_id && i == 0)) && piece[0] == ' ') { - piece = piece.substr(1); - } - - result += piece; - } - - return result; -} - -std::string llama_detokenize_bpe(llama_context * ctx, const std::vector & tokens) { - std::string piece; - std::string result; - - for (size_t i = 0; i < tokens.size(); ++i) { - piece = llama_token_to_piece(ctx, tokens[i]); - - result += piece; - } - - // NOTE: the original tokenizer decodes bytes after collecting the pieces. - return result; -} - -// -// YAML utils -// - -// returns true if successful, false otherwise -bool create_directory_with_parents(const std::string & path) { -#ifdef _WIN32 - std::wstring_convert> converter; - std::wstring wpath = converter.from_bytes(path); - - // if the path already exists, check whether it's a directory - const DWORD attributes = GetFileAttributesW(wpath.c_str()); - if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { - return true; - } - - size_t pos_slash = 0; - - // process path from front to back, procedurally creating directories - while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { - const std::wstring subpath = wpath.substr(0, pos_slash); - const wchar_t * test = subpath.c_str(); - - const bool success = CreateDirectoryW(test, NULL); - if (!success) { - const DWORD error = GetLastError(); - - // if the path already exists, ensure that it's a directory - if (error == ERROR_ALREADY_EXISTS) { - const DWORD attributes = GetFileAttributesW(subpath.c_str()); - if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { - return false; - } - } else { - return false; - } - } - - pos_slash += 1; - } - - return true; -#else - // if the path already exists, check whether it's a directory - struct stat info; - if (stat(path.c_str(), &info) == 0) { - return S_ISDIR(info.st_mode); - } - - size_t pos_slash = 1; // skip leading slashes for directory creation - - // process path from front to back, procedurally creating directories - while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { - const std::string subpath = path.substr(0, pos_slash); - struct stat info; - - // if the path already exists, ensure that it's a directory - if (stat(subpath.c_str(), &info) == 0) { - if (!S_ISDIR(info.st_mode)) { - return false; - } - } else { - // create parent directories - const int ret = mkdir(subpath.c_str(), 0755); - if (ret != 0) { - return false; - } - } - - pos_slash += 1; - } - - return true; -#endif // _WIN32 -} - -void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector & data) { - if (data.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - fprintf(stream, "%s: [", prop_name); - for (size_t i = 0; i < data.size() - 1; ++i) { - fprintf(stream, "%e, ", data[i]); - } - fprintf(stream, "%e]\n", data.back()); -} - -void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector & data) { - if (data.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - fprintf(stream, "%s: [", prop_name); - for (size_t i = 0; i < data.size() - 1; ++i) { - fprintf(stream, "%d, ", data[i]); - } - fprintf(stream, "%d]\n", data.back()); -} - -void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data) { - std::string data_str(data == NULL ? "" : data); - - if (data_str.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - size_t pos_start = 0; - size_t pos_found = 0; - - if (!data_str.empty() && (std::isspace(data_str[0]) || std::isspace(data_str.back()))) { - data_str = std::regex_replace(data_str, std::regex("\n"), "\\n"); - data_str = std::regex_replace(data_str, std::regex("\""), "\\\""); - data_str = "\"" + data_str + "\""; - fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); - return; - } - - if (data_str.find('\n') == std::string::npos) { - fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); - return; - } - - fprintf(stream, "%s: |\n", prop_name); - while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) { - fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str()); - pos_start = pos_found + 1; - } -} - -std::string get_sortable_timestamp() { - using clock = std::chrono::system_clock; - - const clock::time_point current_time = clock::now(); - const time_t as_time_t = clock::to_time_t(current_time); - char timestamp_no_ns[100]; - std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t)); - - const int64_t ns = std::chrono::duration_cast( - current_time.time_since_epoch() % 1000000000).count(); - char timestamp_ns[11]; - snprintf(timestamp_ns, 11, "%09" PRId64, ns); - - return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns); -} - -void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { - const llama_sampling_params & sparams = params.sampling_params; - - fprintf(stream, "build_commit: %s\n", BUILD_COMMIT); - fprintf(stream, "build_number: %d\n", BUILD_NUMBER); - fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false"); - fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false"); - fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); - fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false"); - fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false"); - fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false"); - fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false"); - fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false"); - fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false"); - fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false"); - fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); - fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false"); - fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false"); - -#ifdef NDEBUG - fprintf(stream, "debug: false\n"); -#else - fprintf(stream, "debug: true\n"); -#endif // NDEBUG - - fprintf(stream, "model_desc: %s\n", model_desc); - fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx))); - -#ifdef __OPTIMIZE__ - fprintf(stream, "optimize: true\n"); -#else - fprintf(stream, "optimize: false\n"); -#endif // __OPTIMIZE__ - - fprintf(stream, "time: %s\n", timestamp.c_str()); - - fprintf(stream, "\n"); - fprintf(stream, "###############\n"); - fprintf(stream, "# User Inputs #\n"); - fprintf(stream, "###############\n"); - fprintf(stream, "\n"); - - fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); - fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); - dump_string_yaml_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str()); - fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale); - fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); - fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); - fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); - fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); - fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); - fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty); - dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str()); - fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); - fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); - fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - - const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(lctx)); - const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; - fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); - - dump_string_yaml_multiline(stream, "in_prefix", params.input_prefix.c_str()); - fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); - dump_string_yaml_multiline(stream, "in_suffix", params.input_prefix.c_str()); - fprintf(stream, "instruct: %s # default: false\n", params.instruct ? "true" : "false"); - fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false"); - fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false"); - fprintf(stream, "keep: %d # default: 0\n", params.n_keep); - fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); - - fprintf(stream, "logit_bias:\n"); - for (std::pair lb : sparams.logit_bias) { - if (ignore_eos && lb.first == logit_bias_eos->first) { - continue; - } - fprintf(stream, " %d: %f", lb.first, lb.second); - } - - fprintf(stream, "lora:\n"); - for (std::tuple la : params.lora_adapter) { - if (std::get<1>(la) != 1.0f) { - continue; - } - fprintf(stream, " - %s\n", std::get<0>(la).c_str()); - } - fprintf(stream, "lora_scaled:\n"); - for (std::tuple la : params.lora_adapter) { - if (std::get<1>(la) == 1.0f) { - continue; - } - fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la)); - } - fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); - fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); - fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false"); - fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); - fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); - fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); - fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); - fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str()); - fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); - fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); - fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); - fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); - fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); - fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); - fprintf(stream, "no_mul_mat_q: %s # default: false\n", !params.mul_mat_q ? "true" : "false"); - fprintf(stream, "no_penalize_nl: %s # default: false\n", !sparams.penalize_nl ? "true" : "false"); - fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false"); - fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); - fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); - fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty); - dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str()); - fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); - fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); - fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); - dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens); - fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false"); - fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty); - - fprintf(stream, "reverse_prompt:\n"); - for (std::string ap : params.antiprompt) { - size_t pos = 0; - while ((pos = ap.find('\n', pos)) != std::string::npos) { - ap.replace(pos, 1, "\\n"); - pos += 1; - } - - fprintf(stream, " - %s\n", ap.c_str()); - } - - fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); - fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); - fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed); - fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); - fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); - fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); - - const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES); - dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector); - - fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); - fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); - fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); - fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); - fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); - fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); -} diff --git a/plugins/wasi_nn/thirdparty/ggml/common.h b/plugins/wasi_nn/thirdparty/ggml/common.h deleted file mode 100644 index 4305047d..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/common.h +++ /dev/null @@ -1,186 +0,0 @@ -// Various helper functions and utilities - -#pragma once - -#include "llama.h" - -#include "sampling.h" - -#define LOG_NO_FILE_LINE_FUNCTION -#include "log.h" - -#include -#include -#include -#include -#include -#include - -#ifdef _WIN32 -#define DIRECTORY_SEPARATOR '\\' -#else -#define DIRECTORY_SEPARATOR '/' -#endif // _WIN32 - -#define BUILD_NUMBER 1383 -#define BUILD_COMMIT "Embedded in WasmEdge" -#define BUILD_COMPILER "Embedded in WasmEdge" -#define BUILD_TARGET "Embedded in WasmEdge" - -#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0) -#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0) - -#define print_build_info() do { \ - fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); \ - fprintf(stderr, "%s: built with %s for %s\n", __func__, BUILD_COMPILER, BUILD_TARGET); \ -} while(0) - -// -// CLI argument parsing -// -int32_t get_num_physical_cores(); - -struct gpt_params { - uint32_t seed = -1; // RNG seed - int32_t n_threads = get_num_physical_cores(); - int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) - int32_t n_predict = -1; // new tokens to predict - int32_t n_ctx = 512; // context size - int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 16; // number of tokens to draft during speculative decoding - int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) - int32_t n_parallel = 1; // number of parallel sequences to decode - int32_t n_sequences = 1; // number of sequences to decode - int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) - int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs - int32_t n_beams = 0; // if non-zero then use beam search of given width. - float rope_freq_base = 0.0f; // RoPE base frequency - float rope_freq_scale = 0.0f; // RoPE frequency scaling factor - - // sampling parameters - struct llama_sampling_params sampling_params; - - std::string model = "models/7B/ggml-model-f16.gguf"; // model path - std::string model_draft = ""; // draft model for speculative decoding - std::string model_alias = "unknown"; // model alias - std::string prompt = ""; - std::string prompt_file = ""; // store the external prompt file name - std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state - std::string input_prefix = ""; // string to prefix user inputs with - std::string input_suffix = ""; // string to suffix user inputs with - std::string grammar = ""; // optional BNF-like grammar to constrain sampling - std::vector antiprompt; // string upon seeing which more user input is prompted - std::string logdir = ""; // directory in which to save YAML log files - - std::vector> lora_adapter; // lora adapter path with user defined scale - std::string lora_base = ""; // base model path for the lora adapter - - int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. - int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line - // (which is more convenient to use for plotting) - // - bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt - size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score - - bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS - bool memory_f16 = true; // use f16 instead of f32 for memory kv - bool random_prompt = false; // do not randomize prompt if none provided - bool use_color = false; // use color to distinguish generations and inputs - bool interactive = false; // interactive mode - bool prompt_cache_all = false; // save user input and generations to prompt cache - bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it - - bool embedding = false; // get only sentence embedding - bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\" - bool interactive_first = false; // wait for user input immediately - bool multiline_input = false; // reverse the usage of `\` - bool simple_io = false; // improves compatibility with subprocesses and limited consoles - bool cont_batching = false; // insert new sequences for decoding on-the-fly - - bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix - bool ignore_eos = false; // ignore generated EOS tokens - bool instruct = false; // instruction mode (used for Alpaca models) - bool logits_all = false; // return logits for all tokens in the batch - bool use_mmap = true; // use mmap for faster loads - bool use_mlock = false; // use mlock to keep model in memory - bool numa = false; // attempt optimizations that help on some NUMA systems - bool verbose_prompt = false; // print prompt tokens before generation - bool infill = false; // use infill mode - - // multimodal models (see examples/llava) - std::string mmproj = ""; // path to multimodal projector - std::string image = ""; // path to an image file -}; - -bool gpt_params_parse(int argc, char ** argv, gpt_params & params); - -void gpt_print_usage(int argc, char ** argv, const gpt_params & params); - -std::string get_system_info(const gpt_params & params); - -std::string gpt_random_prompt(std::mt19937 & rng); - -void process_escapes(std::string& input); - -// -// Model utils -// - -std::tuple llama_init_from_gpt_params(gpt_params & params); -struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params); -struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); - -// -// Vocab utils -// - -// tokenizes a string into a vector of tokens -// should work similar to Python's `tokenizer.encode` -std::vector llama_tokenize( - const struct llama_context * ctx, - const std::string & text, - bool add_bos); - -std::vector llama_tokenize( - const struct llama_model * model, - const std::string & text, - bool add_bos); - -// tokenizes a token into a piece -// should work similar to Python's `tokenizer.id_to_piece` -std::string llama_token_to_piece( - const struct llama_context * ctx, - llama_token token); - -// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function -// that takes into account the tokenizer type and decides how to handle the leading space -// -// detokenizes a vector of tokens into a string -// should work similar to Python's `tokenizer.decode` -// removes the leading space from the first non-BOS token -std::string llama_detokenize_spm( - llama_context * ctx, - const std::vector & tokens); - -// detokenizes a vector of tokens into a string -// should work similar to Python's `tokenizer.decode` -std::string llama_detokenize_bpe( - llama_context * ctx, - const std::vector & tokens); - -// -// YAML utils -// - -bool create_directory_with_parents(const std::string & path); -void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector & data); -void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector & data); -void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data); -std::string get_sortable_timestamp(); - -void dump_non_result_info_yaml( - FILE * stream, const gpt_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c deleted file mode 100644 index 34eba3f8..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.c +++ /dev/null @@ -1,594 +0,0 @@ -#include "ggml-alloc.h" -#include "ggml-backend.h" -#include "ggml.h" -#include -#include -#include -#include -#include - - -#define UNUSED(x) (void)(x) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) -#define GGML_MAX_CONCUR (2*GGML_MAX_NODES) - -//#define GGML_ALLOCATOR_DEBUG - -//#define AT_PRINTF printf -#define AT_PRINTF(...) ((void)0) - -struct hash_node { - struct ggml_tensor * t; - int n_children; - int n_views; -}; - -static size_t hash(void * p) { - return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE; -} - -static struct hash_node * hash_get(struct hash_node hash_table[], struct ggml_tensor * t) { - size_t h = hash(t); - - // linear probing - size_t i = h; - while (hash_table[i].t != NULL) { - if (hash_table[i].t == t) { - return &hash_table[i]; - } - i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE; - if (i == h) { - // hash table is full - GGML_ASSERT(false); - } - } - - hash_table[i].t = t; - return &hash_table[i]; -} - -// TODO: GGML_PAD ? -static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) { - assert(alignment && !(alignment & (alignment - 1))); // power of 2 - size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment; - return offset + align; -} - -struct free_block { - void * addr; - size_t size; -}; - -#define MAX_FREE_BLOCKS 256 - -struct ggml_allocr { - struct ggml_backend_buffer * buffer; - bool buffer_owned; - void * data; - size_t alignment; - int n_free_blocks; - struct free_block free_blocks[MAX_FREE_BLOCKS]; - struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE]; - size_t max_size; - bool measure; - int parse_seq[GGML_MAX_CONCUR]; - int parse_seq_len; - -#ifdef GGML_ALLOCATOR_DEBUG - struct ggml_tensor * allocated_tensors[1024]; -#endif -}; - -#ifdef GGML_ALLOCATOR_DEBUG -static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { - for (int i = 0; i < 1024; i++) { - if (alloc->allocated_tensors[i] == NULL) { - alloc->allocated_tensors[i] = tensor; - return; - } - } - GGML_ASSERT(!"out of allocated_tensors"); -} -static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { - for (int i = 0; i < 1024; i++) { - if (alloc->allocated_tensors[i] == tensor || - (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) { - alloc->allocated_tensors[i] = NULL; - return; - } - } - printf("tried to free tensor %s not found\n", tensor->name); - GGML_ASSERT(!"tensor not found"); -} -#endif - -// check if a tensor is allocated by this buffer -static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) { - return tensor->buffer == alloc->buffer; -} - -static bool ggml_is_view(struct ggml_tensor * t) { - return t->view_src != NULL; -} - -void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { - GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources - GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated - - size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor); - size = aligned_offset(NULL, size, alloc->alignment); - - AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size); - - size_t max_avail = 0; - - // find the best fitting free block besides the last block - int best_fit_block = -1; - size_t best_fit_size = SIZE_MAX; - for (int i = 0; i < alloc->n_free_blocks - 1; i++) { - struct free_block * block = &alloc->free_blocks[i]; - max_avail = MAX(max_avail, block->size); - if (block->size >= size && block->size <= best_fit_size) { - best_fit_block = i; - best_fit_size = block->size; - } - } - - AT_PRINTF("block %d\n", best_fit_block); - - if (best_fit_block == -1) { - // the last block is our last resort - struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1]; - max_avail = MAX(max_avail, block->size); - if (block->size >= size) { - best_fit_block = alloc->n_free_blocks - 1; - } else { - fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n", - __func__, size, max_avail); - GGML_ASSERT(!"not enough space in the buffer"); - return; - } - } - struct free_block * block = &alloc->free_blocks[best_fit_block]; - void * addr = block->addr; - block->addr = (char*)block->addr + size; - block->size -= size; - if (block->size == 0) { - // remove block if empty - alloc->n_free_blocks--; - for (int j = best_fit_block; j < alloc->n_free_blocks; j++) { - alloc->free_blocks[j] = alloc->free_blocks[j+1]; - } - } - - tensor->data = addr; - AT_PRINTF("%s: allocated data at %p\n", __func__, tensor->data); - tensor->buffer = alloc->buffer; - ggml_backend_buffer_init_tensor(alloc->buffer, tensor); - -#ifdef GGML_ALLOCATOR_DEBUG - add_allocated_tensor(alloc, tensor); - size_t cur_max = (char*)addr - (char*)alloc->data + size; - if (cur_max > alloc->max_size) { - printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0); - for (int i = 0; i < 1024; i++) { - if (alloc->allocated_tensors[i]) { - printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, ggml_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0); - } - } - printf("\n"); - } -#endif - - alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size); -} - -// this is a very naive implementation, but for our case the number of free blocks should be very small -static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { - if (ggml_allocr_is_own(alloc, tensor) == false) { - // the tensor was not allocated in this buffer - // this can happen because the graph allocator will try to free weights and other tensors from different buffers - // the easiest way to deal with this is just to ignore it - AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer); - return; - } - - void * ptr = tensor->data; - - size_t size = ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor); - size = aligned_offset(NULL, size, alloc->alignment); - AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks); - - ggml_backend_buffer_free_tensor(alloc->buffer, tensor); - -#ifdef GGML_ALLOCATOR_DEBUG - remove_allocated_tensor(alloc, tensor); -#endif - - // see if we can merge with an existing block - for (int i = 0; i < alloc->n_free_blocks; i++) { - struct free_block * block = &alloc->free_blocks[i]; - // check if ptr is at the end of the block - if ((char*)block->addr + block->size == ptr) { - block->size += size; - // check if we can merge with the next block - if (i < alloc->n_free_blocks - 1 && (char*)block->addr + block->size == alloc->free_blocks[i+1].addr) { - block->size += alloc->free_blocks[i+1].size; - alloc->n_free_blocks--; - for (int j = i+1; j < alloc->n_free_blocks; j++) { - alloc->free_blocks[j] = alloc->free_blocks[j+1]; - } - } - return; - } - // check if ptr is at the beginning of the block - if ((char*)ptr + size == block->addr) { - block->addr = ptr; - block->size += size; - // check if we can merge with the previous block - if (i > 0 && (char*)alloc->free_blocks[i-1].addr + alloc->free_blocks[i-1].size == block->addr) { - alloc->free_blocks[i-1].size += block->size; - alloc->n_free_blocks--; - for (int j = i; j < alloc->n_free_blocks; j++) { - alloc->free_blocks[j] = alloc->free_blocks[j+1]; - } - } - return; - } - } - // otherwise, add a new block - GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks"); - // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster) - int insert_pos = 0; - while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].addr < ptr) { - insert_pos++; - } - // shift all blocks from insert_pos onward to make room for the new block - for (int i = alloc->n_free_blocks; i > insert_pos; i--) { - alloc->free_blocks[i] = alloc->free_blocks[i-1]; - } - // insert the new block - alloc->free_blocks[insert_pos].addr = ptr; - alloc->free_blocks[insert_pos].size = size; - alloc->n_free_blocks++; -} - -void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) { - for (int i = 0; i < n; i++) { - alloc->parse_seq[i] = list[i]; - } - alloc->parse_seq_len = n; -} - -void ggml_allocr_reset(struct ggml_allocr * alloc) { - alloc->n_free_blocks = 1; - size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment); - alloc->free_blocks[0].addr = (char *)alloc->data + align_offset; - alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset; -} - -struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) { - struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size); - - struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr)); - - *alloc = (struct ggml_allocr){ - /*.buffer = */ buffer, - /*.buffer_owned = */ true, - /*.base = */ ggml_backend_buffer_get_base(buffer), - /*.alignment = */ alignment, - /*.n_free_blocks = */ 0, - /*.free_blocks = */ {{0}}, - /*.hash_table = */ {{0}}, - /*.max_size = */ 0, - /*.measure = */ false, - /*.parse_seq = */ {0}, - /*.parse_seq_len = */ 0, -#ifdef GGML_ALLOCATOR_DEBUG - /*.allocated_tensors = */ {0}, -#endif - }; - - ggml_allocr_reset(alloc); - - return alloc; -} - -struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) { - struct ggml_allocr * alloc = ggml_allocr_new((void *)0x1000, (size_t)-0x1001, alignment); - alloc->measure = true; - - return alloc; -} - -struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) { - struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr)); - - *alloc = (struct ggml_allocr){ - /*.buffer = */ buffer, - /*.buffer_owned = */ false, - /*.base = */ ggml_backend_buffer_get_base(buffer), - /*.alignment = */ ggml_backend_buffer_get_alignment(buffer), - /*.n_free_blocks = */ 0, - /*.free_blocks = */ {{0}}, - /*.hash_table = */ {{0}}, - /*.max_size = */ 0, - /*.measure = */ false, - /*.parse_seq = */ {0}, - /*.parse_seq_len = */ 0, -#ifdef GGML_ALLOCATOR_DEBUG - /*.allocated_tensors = */ {0}, -#endif - }; - - ggml_allocr_reset(alloc); - - return alloc; -} - -void ggml_allocr_free(struct ggml_allocr * alloc) { - if (alloc->buffer_owned) { - ggml_backend_buffer_free(alloc->buffer); - } - free(alloc); -} - -bool ggml_allocr_is_measure(struct ggml_allocr * alloc) { - return alloc->measure; -} - -//////////// compute graph allocator - -static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { - if (a->type != b->type) { - return false; - } - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (a->ne[i] != b->ne[i]) { - return false; - } - if (a->nb[i] != b->nb[i]) { - return false; - } - } - return true; -} - -static bool ggml_op_can_inplace(enum ggml_op op) { - switch (op) { - case GGML_OP_SCALE: - case GGML_OP_DIAG_MASK_ZERO: - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_ADD: - case GGML_OP_ADD1: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_LOG: - case GGML_OP_UNARY: - case GGML_OP_ROPE: - case GGML_OP_RMS_NORM: - case GGML_OP_SOFT_MAX: - return true; - - default: - return false; - } -} - -static void init_view(struct ggml_allocr * alloc, struct ggml_tensor * view) { - assert(view->view_src != NULL && view->view_src->data != NULL); - view->backend = view->view_src->backend; - view->buffer = view->view_src->buffer; - view->data = (char *)view->view_src->data + view->view_offs; - - // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend - // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras - assert(ggml_allocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend); - ggml_backend_buffer_init_tensor(alloc->buffer, view); -} - -static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) { - struct hash_node * ht = alloc->hash_table; - if (node->data == NULL) { - if (ggml_is_view(node)) { - init_view(alloc, node); - } else { - // see if we can reuse a parent's buffer (inplace) - if (ggml_op_can_inplace(node->op)) { - for (int i = 0; i < GGML_MAX_SRC; i++) { - struct ggml_tensor * parent = node->src[i]; - if (parent == NULL) { - break; - } - - // if the node's data is external, then we cannot re-use it - if (ggml_allocr_is_own(alloc, parent) == false) { - AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data); - continue; - } - - struct hash_node * p_hn = hash_get(ht, parent); - if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) { - if (ggml_is_view(parent)) { - struct ggml_tensor * view_src = parent->view_src; - struct hash_node * view_src_hn = hash_get(ht, view_src); - if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) { - // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite - // the parent's data that it will need later (same layout requirement). the problem is that then - // we cannot free the tensor because the original address of the allocation is lost. - // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views - // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data) - AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name); - node->view_src = view_src; - view_src_hn->n_views += 1; - init_view(alloc, node); - return; - } - } - else { - AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name); - node->view_src = parent; - p_hn->n_views += 1; - init_view(alloc, node); - return; - } - } - } - } - ggml_allocr_alloc(alloc, node); - } - } -} - -size_t ggml_allocr_alloc_graph_n( - struct ggml_allocr * alloc, - struct ggml_cgraph ** graphs, int n_graphs, - struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) { - - // reset hash table - struct hash_node * ht = alloc->hash_table; - memset(ht, 0, sizeof(struct hash_node) * GGML_GRAPH_HASHTABLE_SIZE); - - // count number of children and views - for (int g = 0; g < n_graphs; g++) { - struct ggml_cgraph * gf = graphs[g]; - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; - - if (ggml_is_view(node)) { - struct ggml_tensor * view_src = node->view_src; - hash_get(ht, view_src)->n_views += 1; - if (node->buffer == NULL && node->data != NULL) { - // view of a pre-allocated tensor, didn't call init_view() yet - init_view(alloc, node); - } - } - - for (int j = 0; j < GGML_MAX_SRC; j++) { - struct ggml_tensor * parent = node->src[j]; - if (parent == NULL) { - break; - } - hash_get(ht, parent)->n_children += 1; - if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) { - init_view(alloc, parent); - } - } - } - } - - // allocate tensors - for (int g = 0; g < n_graphs; g++) { - struct ggml_cgraph * gf = graphs[g]; - AT_PRINTF("####### graph %d/%d\n", g, n_graphs); - // graph inputs are allocated first to ensure that they are not overwritten by each other - if (inputs != NULL && inputs[g] != NULL) { - for (int i = 0; inputs[g][i] != NULL; i++) { - struct ggml_tensor * input = inputs[g][i]; - AT_PRINTF("input: %s\n", input->name); - allocate_node(alloc, input); - } - } - // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers - int last_barrier_pos = 0; - int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes; - - for (int ind = 0; ind < n_nodes; ind++) { - // allocate a node if there is no parse_seq or this is not a barrier - if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) { - int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind; - struct ggml_tensor * node = gf->nodes[i]; - - // allocate parents (leafs) - for (int j = 0; j < GGML_MAX_SRC; j++) { - struct ggml_tensor * parent = node->src[j]; - if (parent == NULL) { - break; - } - allocate_node(alloc, parent); - } - - // allocate node - allocate_node(alloc, node); - - AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name); - for (int j = 0; j < GGML_MAX_SRC; j++) { - struct ggml_tensor * parent = node->src[j]; - if (parent == NULL) { - break; - } - AT_PRINTF("%s", parent->name); - if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) { - AT_PRINTF(", "); - } - } - AT_PRINTF("\n"); - } - - // update parents - // update immediately if there is no parse_seq - // update only at barriers if there is parse_seq - if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) { - int update_start = alloc->parse_seq_len ? last_barrier_pos : ind; - int update_end = alloc->parse_seq_len ? ind : ind + 1; - for (int i = update_start; i < update_end; i++) { - int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i; - struct ggml_tensor * node = gf->nodes[node_i]; - - for (int j = 0; j < GGML_MAX_SRC; j++) { - struct ggml_tensor * parent = node->src[j]; - if (parent == NULL) { - break; - } - struct hash_node * p_hn = hash_get(ht, parent); - p_hn->n_children -= 1; - - //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views); - - if (p_hn->n_children == 0 && p_hn->n_views == 0) { - if (ggml_is_view(parent)) { - struct ggml_tensor * view_src = parent->view_src; - struct hash_node * view_src_hn = hash_get(ht, view_src); - view_src_hn->n_views -= 1; - AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views); - if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) { - ggml_allocr_free_tensor(alloc, view_src); - } - } - else { - if (parent->data != node->data) { - ggml_allocr_free_tensor(alloc, parent); - } - } - } - } - } - AT_PRINTF("\n"); - if (alloc->parse_seq_len) { - last_barrier_pos = ind + 1; - } - } - } - // free graph outputs here that wouldn't be freed otherwise because they have no children - if (outputs != NULL && outputs[g] != NULL) { - for (int i = 0; outputs[g][i] != NULL; i++) { - struct ggml_tensor * output = outputs[g][i]; - AT_PRINTF("output: %s\n", output->name); - ggml_allocr_free_tensor(alloc, output); - } - } - } - - return alloc->max_size; -} - -size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) { - return ggml_allocr_alloc_graph_n(alloc, &graph, 1, NULL, NULL); -} - -size_t ggml_allocr_max_size(struct ggml_allocr * alloc) { - return alloc->max_size; -} diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h b/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h deleted file mode 100644 index e3875887..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-alloc.h +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once - -#include "ggml.h" - -#ifdef __cplusplus -extern "C" { -#endif - -struct ggml_backend_buffer; - -GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment); -GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment); -GGML_API struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer); - -// tell the allocator to parse nodes following the order described in the list -// you should call this if your graph are optimized to execute out-of-order -GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n); - -GGML_API void ggml_allocr_free (struct ggml_allocr * alloc); -GGML_API bool ggml_allocr_is_measure (struct ggml_allocr * alloc); -GGML_API void ggml_allocr_reset (struct ggml_allocr * alloc); -GGML_API void ggml_allocr_alloc (struct ggml_allocr * alloc, struct ggml_tensor * tensor); -GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph); -GGML_API size_t ggml_allocr_max_size (struct ggml_allocr * alloc); - -GGML_API size_t ggml_allocr_alloc_graph_n( - struct ggml_allocr * alloc, - struct ggml_cgraph ** graphs, int n_graphs, - struct ggml_tensor *** inputs, struct ggml_tensor *** outputs); - -#ifdef __cplusplus -} -#endif diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-backend.c b/plugins/wasi_nn/thirdparty/ggml/ggml-backend.c deleted file mode 100644 index ca8d83da..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-backend.c +++ /dev/null @@ -1,385 +0,0 @@ -#include "ggml-backend.h" -#include "ggml-alloc.h" - -#include -#include -#include -#include -#include - -#define UNUSED GGML_UNUSED - -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -// backend buffer - -ggml_backend_buffer_t ggml_backend_buffer_init( - struct ggml_backend * backend, - struct ggml_backend_buffer_i iface, - ggml_backend_buffer_context_t context, - size_t size) { - ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer)); - - GGML_ASSERT(iface.get_base != NULL); - - (*buffer) = (struct ggml_backend_buffer) { - /* .interface = */ iface, - /* .backend = */ backend, - /* .context = */ context, - /* .size = */ size, - }; - - return buffer; -} - -void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { - if (buffer->iface.free_buffer != NULL) { - buffer->iface.free_buffer(buffer); - } - free(buffer); -} - -size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) { - return ggml_backend_get_alignment(buffer->backend); -} - -void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { - return buffer->iface.get_base(buffer); -} - -size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { - return buffer->size; -} - -size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { - if (buffer->iface.get_alloc_size) { - return buffer->iface.get_alloc_size(buffer, tensor); - } - return ggml_nbytes(tensor); -} - -void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { - if (buffer->iface.init_tensor) { - buffer->iface.init_tensor(buffer, tensor); - } -} - -void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { - if (buffer->iface.free_tensor) { - buffer->iface.free_tensor(buffer, tensor); - } -} - -// backend - -ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) { - return tensor->buffer->backend; -} - -const char * ggml_backend_name(ggml_backend_t backend) { - return backend->iface.get_name(backend); -} - -void ggml_backend_free(ggml_backend_t backend) { - backend->iface.free(backend); -} - -ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) { - return backend->iface.alloc_buffer(backend, size); -} - -size_t ggml_backend_get_alignment(ggml_backend_t backend) { - return backend->iface.get_alignment(backend); -} - -void ggml_backend_tensor_set_async(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size); -} - -void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size); -} - -void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size); - ggml_get_backend(tensor)->iface.synchronize(ggml_get_backend(tensor)); -} - -void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size); - ggml_get_backend(tensor)->iface.synchronize(ggml_get_backend(tensor)); -} - -void ggml_backend_synchronize(ggml_backend_t backend) { - backend->iface.synchronize(backend); -} - -ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - return backend->iface.graph_plan_create(backend, cgraph); -} - -void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - backend->iface.graph_plan_free(backend, plan); -} - -void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - backend->iface.graph_plan_compute(backend, plan); -} - -void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - backend->iface.graph_compute(backend, cgraph); -} - -bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - return backend->iface.supports_op(backend, op); -} - -// backend copy - -static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) { - if (a->type != b->type) { - return false; - } - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (a->ne[i] != b->ne[i]) { - return false; - } - if (a->nb[i] != b->nb[i]) { - return false; - } - } - return true; -} - -void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { - //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]); - //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]); - GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); - - // printf("cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src)); - - if (src == dst) { - return; - } - - // TODO: allow backends to support copy to/from same backend - - if (ggml_get_backend(dst)->iface.cpy_tensor_from != NULL) { - ggml_get_backend(dst)->iface.cpy_tensor_from(ggml_get_backend(dst)->context, src, dst); - } else if (ggml_get_backend(src)->iface.cpy_tensor_to != NULL) { - ggml_get_backend(src)->iface.cpy_tensor_to(ggml_get_backend(src)->context, src, dst); - } else { - // shouldn't be hit when copying from/to CPU - #ifndef NDEBUG - fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", ggml_backend_name(src->buffer->backend), ggml_backend_name(dst->buffer->backend)); - #endif - size_t nbytes = ggml_nbytes(src); - void * data = malloc(nbytes); - ggml_backend_tensor_get(src, data, 0, nbytes); - ggml_backend_tensor_set(dst, data, 0, nbytes); - free(data); - } -} - -// backend CPU - -struct ggml_backend_cpu_context { - int n_threads; - void * work_data; - size_t work_size; -}; - -static const char * ggml_backend_cpu_name(ggml_backend_t backend) { - return "CPU"; - - UNUSED(backend); -} - -static void ggml_backend_cpu_free(ggml_backend_t backend) { - struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; - free(cpu_ctx->work_data); - free(cpu_ctx); - free(backend); -} - -static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { - return (void *)buffer->context; -} - -static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { - free(buffer->context); - UNUSED(buffer); -} - -static struct ggml_backend_buffer_i cpu_backend_buffer_i = { - /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, - /* .get_base = */ ggml_backend_cpu_buffer_get_base, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .init_tensor = */ NULL, // no initialization required - /* .free_tensor = */ NULL, // no cleanup required -}; - -// for buffers from ptr, free is not called -static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { - /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed - /* .get_base = */ ggml_backend_cpu_buffer_get_base, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .init_tensor = */ NULL, - /* .free_tensor = */ NULL, -}; - -static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512 - -static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backend, size_t size) { - size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned - void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC? - - return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size); -} - -static size_t ggml_backend_cpu_get_alignment(ggml_backend_t backend) { - return TENSOR_ALIGNMENT; - UNUSED(backend); -} - -static void ggml_backend_cpu_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - - memcpy((char *)tensor->data + offset, data, size); - - UNUSED(backend); -} - -static void ggml_backend_cpu_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - - memcpy(data, (const char *)tensor->data + offset, size); - - UNUSED(backend); -} - -static void ggml_backend_cpu_synchronize(ggml_backend_t backend) { - UNUSED(backend); -} - -static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { - ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src)); - - UNUSED(backend); -} - -static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { - // for a backend such as CUDA that can queue async calls, it is ok to do this asynchronously, but it may not be the case for other backends - ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src)); - - UNUSED(backend); -} - -struct ggml_backend_plan_cpu { - struct ggml_cplan cplan; - struct ggml_cgraph cgraph; -}; - -static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; - - struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu)); - - cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); - cpu_plan->cgraph = *cgraph; - - if (cpu_plan->cplan.work_size > 0) { - cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); - } - - return cpu_plan; -} - -static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; - - free(cpu_plan->cplan.work_data); - free(cpu_plan); - - UNUSED(backend); -} - -static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; - - ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); - - UNUSED(backend); -} - -static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; - - struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); - - if (cpu_ctx->work_size < cplan.work_size) { - // TODO: may be faster to free and use malloc to avoid the copy - cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size); - cpu_ctx->work_size = cplan.work_size; - } - - cplan.work_data = cpu_ctx->work_data; - - ggml_graph_compute(cgraph, &cplan); -} - -static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - return true; - UNUSED(backend); - UNUSED(op); -} - -static struct ggml_backend_i cpu_backend_i = { - /* .get_name = */ ggml_backend_cpu_name, - /* .free = */ ggml_backend_cpu_free, - /* .alloc_buffer = */ ggml_backend_cpu_alloc_buffer, - /* .get_alignment = */ ggml_backend_cpu_get_alignment, - /* .set_tensor_async = */ ggml_backend_cpu_set_tensor_async, - /* .get_tensor_async = */ ggml_backend_cpu_get_tensor_async, - /* .synchronize = */ ggml_backend_cpu_synchronize, - /* .cpy_tensor_from = */ ggml_backend_cpu_cpy_tensor_from, - /* .cpy_tensor_to = */ ggml_backend_cpu_cpy_tensor_to, - /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, - /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free, - /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, - /* .graph_compute = */ ggml_backend_cpu_graph_compute, - /* .supports_op = */ ggml_backend_cpu_supports_op, -}; - -ggml_backend_t ggml_backend_cpu_init(void) { - struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context)); - - ctx->n_threads = GGML_DEFAULT_N_THREADS; - ctx->work_data = NULL; - ctx->work_size = 0; - - ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend)); - - *cpu_backend = (struct ggml_backend) { - /* .interface = */ cpu_backend_i, - /* .context = */ ctx - }; - return cpu_backend; -} - -bool ggml_backend_is_cpu(ggml_backend_t backend) { - return backend->iface.get_name == ggml_backend_cpu_name; -} - -void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { - GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); - - struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; - ctx->n_threads = n_threads; -} - -ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) { - return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size); -} diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-backend.h b/plugins/wasi_nn/thirdparty/ggml/ggml-backend.h deleted file mode 100644 index da134b0d..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-backend.h +++ /dev/null @@ -1,143 +0,0 @@ -#pragma once - -#include "ggml.h" - -#ifdef __cplusplus -extern "C" { -#endif - struct ggml_backend; - struct ggml_backend_buffer; - - // type-erased backend-specific types / wrappers - typedef void * ggml_backend_context_t; - typedef void * ggml_backend_graph_plan_t; - typedef void * ggml_backend_buffer_context_t; - - // avoid accessing internals of these types - typedef struct ggml_backend * ggml_backend_t; - typedef struct ggml_backend_buffer * ggml_backend_buffer_t; - - // - // backend buffer - // - - struct ggml_backend_buffer_i { - void (*free_buffer) (ggml_backend_buffer_t buffer); - void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer - size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback - void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback - void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback - }; - - // TODO: hide behind API - struct ggml_backend_buffer { - struct ggml_backend_buffer_i iface; - - ggml_backend_t backend; - ggml_backend_buffer_context_t context; - - size_t size; - }; - - // backend buffer functions - GGML_API ggml_backend_buffer_t ggml_backend_buffer_init( - struct ggml_backend * backend, - struct ggml_backend_buffer_i iface, - ggml_backend_buffer_context_t context, - size_t size); - - GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); - GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - - // - // backend - // - - struct ggml_backend_i { - const char * (*get_name)(ggml_backend_t backend); - - void (*free)(ggml_backend_t backend); - - // buffer allocation - ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size); - - // get buffer alignment - size_t (*get_alignment)(ggml_backend_t backend); - - // tensor data access - // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize - void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - void (*synchronize) (ggml_backend_t backend); - - // (optional) copy tensor between different backends, allow for single-copy tranfers - void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); - void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst); - - // compute graph with a plan - ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph); - void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); - void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); - - // compute graph without a plan - void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph); - - // check if the backend supports an operation - bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op); - }; - - // TODO: hide behind API - struct ggml_backend { - struct ggml_backend_i iface; - - ggml_backend_context_t context; - }; - - // backend helper functions - GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor); - - GGML_API const char * ggml_backend_name(ggml_backend_t backend); - GGML_API void ggml_backend_free(ggml_backend_t backend); - - GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size); - - GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); - - GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - - GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - - GGML_API void ggml_backend_synchronize(ggml_backend_t backend); - - GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph); - - GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan); - GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan); - GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); - GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op); - - // tensor copy between different backends - GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); - - // - // CPU backend - // - - GGML_API ggml_backend_t ggml_backend_cpu_init(void); - - GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend); - - GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads); - - GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size); - -#ifdef __cplusplus -} -#endif diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu b/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu deleted file mode 100644 index 654d3632..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.cu +++ /dev/null @@ -1,7824 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(GGML_USE_HIPBLAS) -#include -#include -#include -#ifdef __HIP_PLATFORM_AMD__ -// for rocblas_initialize() -#include "rocblas/rocblas.h" -#endif // __HIP_PLATFORM_AMD__ -#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F -#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F -#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F -#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT -#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT -#define CUBLAS_OP_N HIPBLAS_OP_N -#define CUBLAS_OP_T HIPBLAS_OP_T -#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS -#define CUBLAS_TF32_TENSOR_OP_MATH 0 -#define CUDA_R_16F HIPBLAS_R_16F -#define CUDA_R_32F HIPBLAS_R_32F -#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) -#define cublasCreate hipblasCreate -#define cublasGemmEx hipblasGemmEx -#define cublasHandle_t hipblasHandle_t -#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS -#define cublasSetStream hipblasSetStream -#define cublasSgemm hipblasSgemm -#define cublasStatus_t hipblasStatus_t -#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer -#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess -#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess -#define cudaDeviceProp hipDeviceProp_t -#define cudaDeviceSynchronize hipDeviceSynchronize -#define cudaError_t hipError_t -#define cudaEventCreateWithFlags hipEventCreateWithFlags -#define cudaEventDisableTiming hipEventDisableTiming -#define cudaEventRecord hipEventRecord -#define cudaEvent_t hipEvent_t -#define cudaEventDestroy hipEventDestroy -#define cudaFree hipFree -#define cudaFreeHost hipHostFree -#define cudaGetDevice hipGetDevice -#define cudaGetDeviceCount hipGetDeviceCount -#define cudaGetDeviceProperties hipGetDeviceProperties -#define cudaGetErrorString hipGetErrorString -#define cudaGetLastError hipGetLastError -#define cudaMalloc hipMalloc -#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) -#define cudaMemcpy hipMemcpy -#define cudaMemcpy2DAsync hipMemcpy2DAsync -#define cudaMemcpyAsync hipMemcpyAsync -#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice -#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost -#define cudaMemcpyHostToDevice hipMemcpyHostToDevice -#define cudaMemcpyKind hipMemcpyKind -#define cudaMemset hipMemset -#define cudaMemsetAsync hipMemsetAsync -#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize -#define cudaSetDevice hipSetDevice -#define cudaStreamCreateWithFlags hipStreamCreateWithFlags -#define cudaStreamNonBlocking hipStreamNonBlocking -#define cudaStreamSynchronize hipStreamSynchronize -#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) -#define cudaStream_t hipStream_t -#define cudaSuccess hipSuccess -#else -#include -#include -#include -#endif // defined(GGML_USE_HIPBLAS) - -#include "ggml-cuda.h" -#include "ggml.h" - -#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products -#define CC_VOLTA 700 -#define CC_OFFSET_AMD 1000000 -#define CC_RDNA2 (CC_OFFSET_AMD + 1030) - -#if defined(GGML_USE_HIPBLAS) -#define __CUDA_ARCH__ 1300 - -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ - defined(__gfx1150__) || defined(__gfx1151__) -#define RDNA3 -#endif - -#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ - defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) -#define RDNA2 -#endif - -#ifndef __has_builtin - #define __has_builtin(x) 0 -#endif - -typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); -static __device__ __forceinline__ int __vsubss4(const int a, const int b) { - const int8x4_t va = reinterpret_cast(a); - const int8x4_t vb = reinterpret_cast(b); -#if __has_builtin(__builtin_elementwise_sub_sat) - const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); - return reinterpret_cast(c); -#else - int8x4_t c; - int16_t tmp; -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp = va[i] - vb[i]; - if(tmp > std::numeric_limits::max()) tmp = std::numeric_limits::max(); - if(tmp < std::numeric_limits::min()) tmp = std::numeric_limits::min(); - c[i] = tmp; - } - return reinterpret_cast(c); -#endif // __has_builtin(__builtin_elementwise_sub_sat) -} - -static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { -#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) - c = __builtin_amdgcn_sdot4(a, b, c, false); -#elif defined(__gfx1100__) - c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); -#elif defined(__gfx1010__) || defined(__gfx900__) - int tmp1; - int tmp2; - asm("\n \ - v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \ - v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \ - v_add3_u32 %0, %1, %2, %0 \n \ - v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \ - v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \ - v_add3_u32 %0, %1, %2, %0 \n \ - " - : "+v"(c), "=&v"(tmp1), "=&v"(tmp2) - : "v"(a), "v"(b) - ); -#else - const int8x4_t va = reinterpret_cast(a); - const int8x4_t vb = reinterpret_cast(b); - c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3]; -#endif - return c; -} -#endif // defined(GGML_USE_HIPBLAS) - -#if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif - -static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); - -#define CUDA_CHECK(err) \ - do { \ - cudaError_t err_ = (err); \ - if (err_ != cudaSuccess) { \ - int id; \ - cudaGetDevice(&id); \ - fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ - cudaGetErrorString(err_)); \ - fprintf(stderr, "current device: %d\n", id); \ - exit(1); \ - } \ - } while (0) - -#if CUDART_VERSION >= 12000 -#define CUBLAS_CHECK(err) \ - do { \ - cublasStatus_t err_ = (err); \ - if (err_ != CUBLAS_STATUS_SUCCESS) { \ - int id; \ - cudaGetDevice(&id); \ - fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \ - err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \ - fprintf(stderr, "current device: %d\n", id); \ - exit(1); \ - } \ - } while (0) -#else -#define CUBLAS_CHECK(err) \ - do { \ - cublasStatus_t err_ = (err); \ - if (err_ != CUBLAS_STATUS_SUCCESS) { \ - int id; \ - cudaGetDevice(&id); \ - fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ - fprintf(stderr, "current device: %d\n", id); \ - exit(1); \ - } \ - } while (0) -#endif // CUDART_VERSION >= 11 - -#if CUDART_VERSION >= 11100 -#define GGML_CUDA_ASSUME(x) __builtin_assume(x) -#else -#define GGML_CUDA_ASSUME(x) -#endif // CUDART_VERSION >= 11100 - -#ifdef GGML_CUDA_F16 -typedef half dfloat; // dequantize float -typedef half2 dfloat2; -#else -typedef float dfloat; // dequantize float -typedef float2 dfloat2; -#endif //GGML_CUDA_F16 - -static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { - const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment - - int x32 = 0; - x32 |= x16[0] << 0; - x32 |= x16[1] << 16; - - return x32; -} - -static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) { - const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment - - int x32 = 0; - x32 |= x16[0] << 0; - x32 |= x16[1] << 16; - - return x32; -} - -static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) { - return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment -} - -static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) { - return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment -} - -template -using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream); -typedef to_t_cuda_t to_fp32_cuda_t; -typedef to_t_cuda_t to_fp16_cuda_t; - -typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); -typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v); -typedef void (*cpy_kernel_t)(const char * cx, char * cdst); -typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); -typedef void (*ggml_cuda_op_mul_mat_t)( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, const cudaStream_t & stream); -typedef void (*ggml_cuda_op_flatten_t)( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream); - -// QK = number of values after dequantization -// QR = QK / number of values before dequantization -// QI = number of 32 bit integers before dequantization - -#define QK4_0 32 -#define QR4_0 2 -#define QI4_0 (QK4_0 / (4 * QR4_0)) -typedef struct { - half d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); - -#define QK4_1 32 -#define QR4_1 2 -#define QI4_1 (QK4_1 / (4 * QR4_1)) -typedef struct { - half2 dm; // dm.x = delta, dm.y = min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; -static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); - -#define QK5_0 32 -#define QR5_0 2 -#define QI5_0 (QK5_0 / (4 * QR5_0)) -typedef struct { - half d; // delta - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; -static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); - -#define QK5_1 32 -#define QR5_1 2 -#define QI5_1 (QK5_1 / (4 * QR5_1)) -typedef struct { - half2 dm; // dm.x = delta, dm.y = min - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_1 / 2]; // nibbles / quants -} block_q5_1; -static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); - -#define QK8_0 32 -#define QR8_0 1 -#define QI8_0 (QK8_0 / (4 * QR8_0)) -typedef struct { - half d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); - -#define QK8_1 32 -#define QR8_1 1 -#define QI8_1 (QK8_1 / (4 * QR8_1)) -typedef struct { - half2 ds; // ds.x = delta, ds.y = sum - int8_t qs[QK8_0]; // quants -} block_q8_1; -static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding"); - -typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); -typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc); -typedef void (*load_tiles_cuda_t)( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row); -typedef float (*vec_dot_q_mul_mat_cuda_t)( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k); - -//================================= k-quants - -#ifdef GGML_QKK_64 -#define QK_K 64 -#define K_SCALE_SIZE 4 -#else -#define QK_K 256 -#define K_SCALE_SIZE 12 -#endif - -#define QR2_K 4 -#define QI2_K (QK_K / (4*QR2_K)) -typedef struct { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - half2 dm; // super-block scale for quantized scales/mins -} block_q2_K; -static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); - -#define QR3_K 4 -#define QI3_K (QK_K / (4*QR3_K)) -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits -#ifdef GGML_QKK_64 - uint8_t scales[2]; // scales, quantized with 8 bits -#else - uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits -#endif - half d; // super-block scale -} block_q3_K; -//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding"); - -#define QR4_K 2 -#define QI4_K (QK_K / (4*QR4_K)) -#ifdef GGML_QKK_64 -typedef struct { - half dm[2]; // super-block scales/mins - uint8_t scales[2]; // 4-bit block scales/mins - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding"); -#else -typedef struct { - half2 dm; // super-block scale for quantized scales/mins - uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); -#endif - -#define QR5_K 2 -#define QI5_K (QK_K / (4*QR5_K)) -#ifdef GGML_QKK_64 -typedef struct { - half d; // super-block scale - int8_t scales[QK_K/16]; // block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); -#else -typedef struct { - half2 dm; // super-block scale for quantized scales/mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); -#endif - -#define QR6_K 2 -#define QI6_K (QK_K / (4*QR6_K)) -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales - half d; // delta -} block_q6_K; -static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding"); - -#define WARP_SIZE 32 -#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses - -#define CUDA_ADD_BLOCK_SIZE 256 -#define CUDA_MUL_BLOCK_SIZE 256 -#define CUDA_GELU_BLOCK_SIZE 256 -#define CUDA_SILU_BLOCK_SIZE 256 -#define CUDA_CPY_BLOCK_SIZE 32 -#define CUDA_SCALE_BLOCK_SIZE 256 -#define CUDA_CLAMP_BLOCK_SIZE 256 -#define CUDA_ROPE_BLOCK_SIZE 256 -#define CUDA_ALIBI_BLOCK_SIZE 32 -#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 -#define CUDA_QUANTIZE_BLOCK_SIZE 256 -#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 -#define CUDA_GET_ROWS_BLOCK_SIZE 256 - -// dmmv = dequantize_mul_mat_vec -#ifndef GGML_CUDA_DMMV_X -#define GGML_CUDA_DMMV_X 32 -#endif -#ifndef GGML_CUDA_MMV_Y -#define GGML_CUDA_MMV_Y 1 -#endif - -#ifndef K_QUANTS_PER_ITERATION -#define K_QUANTS_PER_ITERATION 2 -#else -static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); -#endif - -#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE -#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128 -#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE - -#define MUL_MAT_SRC1_COL_STRIDE 128 - -#define MAX_STREAMS 8 -static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr }; - -struct ggml_tensor_extra_gpu { - void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors - cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs -}; - -// this is faster on Windows -// probably because the Windows CUDA libraries forget to make this check before invoking the drivers -inline cudaError_t ggml_cuda_set_device(const int device) { - int current_device; - CUDA_CHECK(cudaGetDevice(¤t_device)); - - if (device == current_device) { - return cudaSuccess; - } - - return cudaSetDevice(device); -} - -static int g_device_count = -1; -static int g_main_device = 0; -static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; -static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; -static bool g_mul_mat_q = true; - -static void * g_scratch_buffer = nullptr; -static size_t g_scratch_size = 0; // disabled by default -static size_t g_scratch_offset = 0; - -static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; - -static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= kx) { - return; - } - dst[i] = x[i] + y[i%ky]; -} - -static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = __hadd(x[i], __float2half(y[i])); -} - -static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= kx) { - return; - } - dst[i] = x[i] * y[i%ky]; -} - -static __global__ void gelu_f32(const float * x, float * dst, const int k) { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - - float xi = x[i]; - dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); -} - -static __global__ void silu_f32(const float * x, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = x[i] / (1.0f + expf(-x[i])); -} - -static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32); - a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32); - } - return a; -} - -template -static __global__ void norm_f32(const float * x, float * dst, const int ncols) { - const int row = blockIdx.x*blockDim.y + threadIdx.y; - const int tid = threadIdx.x; - - const float eps = 1e-5f; - - float2 mean_var = make_float2(0.f, 0.f); - - for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row*ncols + col]; - mean_var.x += xi; - mean_var.y += xi * xi; - } - - // sum up partial sums - mean_var = warp_reduce_sum(mean_var); - if (block_size > WARP_SIZE) { - __shared__ float2 s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = mean_var; - } - __syncthreads(); - mean_var = s_sum[lane_id]; - mean_var = warp_reduce_sum(mean_var); - } - - const float mean = mean_var.x / ncols; - const float var = mean_var.y / ncols - mean * mean; - const float inv_std = rsqrtf(var + eps); - - for (int col = tid; col < ncols; col += block_size) { - dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std; - } -} - -static __device__ __forceinline__ float warp_reduce_sum(float x) { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, mask, 32); - } - return x; -} - -template -static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { - const int row = blockIdx.x*blockDim.y + threadIdx.y; - const int tid = threadIdx.x; - - float tmp = 0.0f; // partial sum for thread in warp - - for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row*ncols + col]; - tmp += xi * xi; - } - - // sum up partial sums - tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { - __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = s_sum[lane_id]; - tmp = warp_reduce_sum(tmp); - } - - const float mean = tmp / ncols; - const float scale = rsqrtf(mean + eps); - - for (int col = tid; col < ncols; col += block_size) { - dst[row*ncols + col] = scale * x[row*ncols + col]; - } -} - -static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const block_q4_0 * x = (const block_q4_0 *) vx; - - const dfloat d = x[ib].d; - - const int vui = x[ib].qs[iqs]; - - v.x = vui & 0xF; - v.y = vui >> 4; - -#ifdef GGML_CUDA_F16 - v = __hsub2(v, {8.0f, 8.0f}); - v = __hmul2(v, {d, d}); -#else - v.x = (v.x - 8.0f) * d; - v.y = (v.y - 8.0f) * d; -#endif // GGML_CUDA_F16 -} - -static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const block_q4_1 * x = (const block_q4_1 *) vx; - - const dfloat d = __low2half(x[ib].dm); - const dfloat m = __high2half(x[ib].dm); - - const int vui = x[ib].qs[iqs]; - - v.x = vui & 0xF; - v.y = vui >> 4; - -#ifdef GGML_CUDA_F16 - v = __hmul2(v, {d, d}); - v = __hadd2(v, {m, m}); -#else - v.x = (v.x * d) + m; - v.y = (v.y * d) + m; -#endif // GGML_CUDA_F16 -} - -static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const block_q5_0 * x = (const block_q5_0 *) vx; - - const dfloat d = x[ib].d; - - uint32_t qh; - memcpy(&qh, x[ib].qh, sizeof(qh)); - - const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; - const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; - - v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); - v.y = ((x[ib].qs[iqs] >> 4) | xh_1); - -#ifdef GGML_CUDA_F16 - v = __hsub2(v, {16.0f, 16.0f}); - v = __hmul2(v, {d, d}); -#else - v.x = (v.x - 16.0f) * d; - v.y = (v.y - 16.0f) * d; -#endif // GGML_CUDA_F16 -} - -static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const block_q5_1 * x = (const block_q5_1 *) vx; - - const dfloat d = __low2half(x[ib].dm); - const dfloat m = __high2half(x[ib].dm); - - uint32_t qh; - memcpy(&qh, x[ib].qh, sizeof(qh)); - - const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; - const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; - - v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); - v.y = ((x[ib].qs[iqs] >> 4) | xh_1); - -#ifdef GGML_CUDA_F16 - v = __hmul2(v, {d, d}); - v = __hadd2(v, {m, m}); -#else - v.x = (v.x * d) + m; - v.y = (v.y * d) + m; -#endif // GGML_CUDA_F16 -} - -static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const block_q8_0 * x = (const block_q8_0 *) vx; - - const dfloat d = x[ib].d; - - v.x = x[ib].qs[iqs + 0]; - v.y = x[ib].qs[iqs + 1]; - -#ifdef GGML_CUDA_F16 - v = __hmul2(v, {d, d}); -#else - v.x *= d; - v.y *= d; -#endif // GGML_CUDA_F16 -} - -//================================== k-quants - -template -static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { - - const int i = blockIdx.x; - const block_q2_K * x = (const block_q2_K *) vx; - - const int tid = threadIdx.x; -#if QK_K == 256 - const int n = tid/32; - const int l = tid - 32*n; - const int is = 8*n + l/16; - - const uint8_t q = x[i].qs[32*n + l]; - dst_t * y = yy + i*QK_K + 128*n; - - float dall = __low2half(x[i].dm); - float dmin = __high2half(x[i].dm); - y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); - y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); - y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); - y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); -#else - const int is = tid/16; // 0 or 1 - const int il = tid%16; // 0...15 - const uint8_t q = x[i].qs[il] >> (2*is); - dst_t * y = yy + i*QK_K + 16*is + il; - float dall = __low2half(x[i].dm); - float dmin = __high2half(x[i].dm); - y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); - y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); -#endif - -} - -template -static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { - - const int i = blockIdx.x; - const block_q3_K * x = (const block_q3_K *) vx; - -#if QK_K == 256 - const int r = threadIdx.x/4; - const int tid = r/2; - const int is0 = r%2; - const int l0 = 16*is0 + 4*(threadIdx.x%4); - const int n = tid / 4; - const int j = tid - 4*n; - - uint8_t m = 1 << (4*n + j); - int is = 8*n + 2*j + is0; - int shift = 2*j; - - int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : - is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : - is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : - (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); - float d_all = x[i].d; - float dl = d_all * (us - 32); - - dst_t * y = yy + i*QK_K + 128*n + 32*j; - const uint8_t * q = x[i].qs + 32*n; - const uint8_t * hm = x[i].hmask; - - for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); -#else - const int tid = threadIdx.x; - const int is = tid/16; // 0 or 1 - const int il = tid%16; // 0...15 - const int im = il/8; // 0...1 - const int in = il%8; // 0...7 - - dst_t * y = yy + i*QK_K + 16*is + il; - - const uint8_t q = x[i].qs[il] >> (2*is); - const uint8_t h = x[i].hmask[in] >> (2*is + im); - const float d = (float)x[i].d; - - if (is == 0) { - y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); - y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); - } else { - y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); - y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); - } -#endif - -} - -#if QK_K == 256 -static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { - if (j < 4) { - d = q[j] & 63; m = q[j + 4] & 63; - } else { - d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); - m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); - } -} -#endif - -template -static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const block_q4_K * x = (const block_q4_K *) vx; - - const int i = blockIdx.x; - -#if QK_K == 256 - // assume 32 threads - const int tid = threadIdx.x; - const int il = tid/8; - const int ir = tid%8; - const int is = 2*il; - const int n = 4; - - dst_t * y = yy + i*QK_K + 64*il + n*ir; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint8_t * q = x[i].qs + 32*il + n*ir; - - uint8_t sc, m; - get_scale_min_k4(is + 0, x[i].scales, sc, m); - const float d1 = dall * sc; const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[i].scales, sc, m); - const float d2 = dall * sc; const float m2 = dmin * m; - for (int l = 0; l < n; ++l) { - y[l + 0] = d1 * (q[l] & 0xF) - m1; - y[l +32] = d2 * (q[l] >> 4) - m2; - } -#else - const int tid = threadIdx.x; - const uint8_t * q = x[i].qs; - dst_t * y = yy + i*QK_K; - const float d = (float)x[i].dm[0]; - const float m = (float)x[i].dm[1]; - y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); - y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4); -#endif -} - -template -static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const block_q5_K * x = (const block_q5_K *) vx; - - const int i = blockIdx.x; - -#if QK_K == 256 - // assume 64 threads - this is very slightly better than the one below - const int tid = threadIdx.x; - const int il = tid/16; // il is in 0...3 - const int ir = tid%16; // ir is in 0...15 - const int is = 2*il; // is is in 0...6 - - dst_t * y = yy + i*QK_K + 64*il + 2*ir; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint8_t * ql = x[i].qs + 32*il + 2*ir; - const uint8_t * qh = x[i].qh + 2*ir; - - uint8_t sc, m; - get_scale_min_k4(is + 0, x[i].scales, sc, m); - const float d1 = dall * sc; const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[i].scales, sc, m); - const float d2 = dall * sc; const float m2 = dmin * m; - - uint8_t hm = 1 << (2*il); - y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1; - y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1; - hm <<= 1; - y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; - y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; -#else - const int tid = threadIdx.x; - const uint8_t q = x[i].qs[tid]; - const int im = tid/8; // 0...3 - const int in = tid%8; // 0...7 - const int is = tid/16; // 0 or 1 - const uint8_t h = x[i].qh[in] >> im; - const float d = x[i].d; - dst_t * y = yy + i*QK_K + tid; - y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16)); - y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16)); -#endif -} - -template -static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { - const block_q6_K * x = (const block_q6_K *) vx; - - const int i = blockIdx.x; -#if QK_K == 256 - - // assume 64 threads - this is very slightly better than the one below - const int tid = threadIdx.x; - const int ip = tid/32; // ip is 0 or 1 - const int il = tid - 32*ip; // 0...32 - const int is = 8*ip + il/16; - - dst_t * y = yy + i*QK_K + 128*ip + il; - - const float d = x[i].d; - - const uint8_t * ql = x[i].ql + 64*ip + il; - const uint8_t qh = x[i].qh[32*ip + il]; - const int8_t * sc = x[i].scales + is; - - y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); - y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); - y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); - y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); -#else - - // assume 32 threads - const int tid = threadIdx.x; - const int ip = tid/16; // 0 or 1 - const int il = tid - 16*ip; // 0...15 - - dst_t * y = yy + i*QK_K + 16*ip + il; - - const float d = x[i].d; - - const uint8_t ql = x[i].ql[16*ip + il]; - const uint8_t qh = x[i].qh[il] >> (2*ip); - const int8_t * sc = x[i].scales; - - y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32); - y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32); -#endif -} - -static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - - const int row = blockIdx.y*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q2_K * x = (const block_q2_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - -#if QK_K == 256 - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int step = 16/K_QUANTS_PER_ITERATION; - - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0...15 or 0...7 - - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 - const int q_offset = 32*im + l0; - const int s_offset = 8*im; - const int y_offset = 128*im + l0; - - uint32_t aux[4]; - const uint8_t * d = (const uint8_t *)aux; - const uint8_t * m = (const uint8_t *)(aux + 2); - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * q = x[i].qs + q_offset; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); - aux[0] = a[0] & 0x0f0f0f0f; - aux[1] = a[1] & 0x0f0f0f0f; - aux[2] = (a[0] >> 4) & 0x0f0f0f0f; - aux[3] = (a[1] >> 4) & 0x0f0f0f0f; - - float sum1 = 0, sum2 = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) - + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) - + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) - + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) - + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) - + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) - + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) - +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); - sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] - + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; - - } - tmp += dall * sum1 - dmin * sum2; - - } -#else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 - const int offset = tid * K_QUANTS_PER_ITERATION; - - uint32_t uaux[2]; - const uint8_t * d = (const uint8_t *)uaux; - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + offset; - const uint8_t * q = x[i].qs + offset; - const uint32_t * s = (const uint32_t *)x[i].scales; - - uaux[0] = s[0] & 0x0f0f0f0f; - uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; - - const float2 dall = __half22float2(x[i].dm); - - float sum1 = 0, sum2 = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - const uint8_t ql = q[l]; - sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) - + y[l+16] * d[1] * ((ql >> 2) & 3) - + y[l+32] * d[2] * ((ql >> 4) & 3) - + y[l+48] * d[3] * ((ql >> 6) & 3); - sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7]; - } - tmp += dall.x * sum1 - dall.y * sum2; - } -#endif - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - const int row = blockIdx.y*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q3_K * x = (const block_q3_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - -#if QK_K == 256 - - const uint16_t kmask1 = 0x0303; - const uint16_t kmask2 = 0x0f0f; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop - const int step = 16/K_QUANTS_PER_ITERATION; - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0....15 or 0...7 - - const uint8_t m = 1 << (4*im); - - const int l0 = n*in; // 0...15 or 0...14 in steps of 2 - const int q_offset = 32*im + l0; - const int y_offset = 128*im + l0; - - uint16_t utmp[4]; - const int8_t * s = (const int8_t *)utmp; - - const uint16_t s_shift = 4*im; - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * q = x[i].qs + q_offset; - const uint8_t * h = x[i].hmask + l0; - - const uint16_t * a = (const uint16_t *)x[i].scales; - utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); - utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); - utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); - utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); - - const float d = x[i].d; - - float sum = 0; - for (int l = 0; l < n; ++l) { - sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) - + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) - + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) - + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); - sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) - + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) - + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) - + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); - } - tmp += d * sum; - - } -#else - - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 - const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 - const int in = offset/8; // 0 or 1 - const int im = offset%8; // 0...7 - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + offset; - const uint8_t * q = x[i].qs + offset; - const uint8_t * s = x[i].scales; - - const float dall = (float)x[i].d; - - float sum = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - const uint8_t hl = x[i].hmask[im+l] >> in; - const uint8_t ql = q[l]; - sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) - + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4)) - + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4)) - + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4)); - } - tmp += sum; - } -#endif - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - const int row = blockIdx.y*blockDim.y + threadIdx.y; - if (row > nrows) return; - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q4_K * x = (const block_q4_K *)vx + ib0; - -#if QK_K == 256 - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 - - const int il = tid/step; // 0...3 - const int ir = tid - step*il; // 0...7 or 0...3 - const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 - - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; - - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; - - uint16_t aux[4]; - const uint8_t * sc = (const uint8_t *)aux; - -#if K_QUANTS_PER_ITERATION == 2 - uint32_t q32[4]; - const uint8_t * q4 = (const uint8_t *)q32; -#else - uint16_t q16[4]; - const uint8_t * q4 = (const uint8_t *)q16; -#endif - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y1 = yy + i*QK_K + y_offset; - const float * y2 = y1 + 128; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux[0] = a[im+0] & kmask1; - aux[1] = a[im+2] & kmask1; - aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); - aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); - -#if K_QUANTS_PER_ITERATION == 2 - const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); - const uint32_t * q2 = q1 + 16; - - q32[0] = q1[0] & 0x0f0f0f0f; - q32[1] = q1[0] & 0xf0f0f0f0; - q32[2] = q2[0] & 0x0f0f0f0f; - q32[3] = q2[0] & 0xf0f0f0f0; - - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < 4; ++l) { - s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4]; - s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12]; - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; - } - tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; -#else - const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); - const uint16_t * q2 = q1 + 32; - - q16[0] = q1[0] & 0x0f0f; - q16[1] = q1[0] & 0xf0f0; - q16[2] = q2[0] & 0x0f0f; - q16[3] = q2[0] & 0xf0f0; - - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < 2; ++l) { - s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; - s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; - } - tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; -#endif - - } -#else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); - - const int step = tid * K_QUANTS_PER_ITERATION; - - uint16_t aux16[2]; - const uint8_t * s = (const uint8_t *)aux16; - - float tmp = 0; - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - const uint8_t * q = x[i].qs + step; - const float * y = yy + i*QK_K + step; - const uint16_t * a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - const float d = (float)x[i].dm[0]; - const float m = (float)x[i].dm[1]; - float sum = 0.f; - for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { - sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) - + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) - + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) - + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); - } - tmp += sum; - } - -#endif - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (tid == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) { - - const int row = blockIdx.x; - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q5_K * x = (const block_q5_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - -#if QK_K == 256 - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = threadIdx.x/2; // 0...15 - const int ix = threadIdx.x%2; - - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 2; - - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; - - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; - - const uint8_t hm1 = 1 << (2*im); - const uint8_t hm2 = hm1 << 4; - - uint16_t aux[4]; - const uint8_t * sc = (const uint8_t *)aux; - - uint16_t q16[8]; - const uint8_t * q4 = (const uint8_t *)q16; - - for (int i = ix; i < num_blocks_per_row; i += 2) { - - const uint8_t * ql1 = x[i].qs + q_offset; - const uint8_t * qh = x[i].qh + l0; - const float * y1 = yy + i*QK_K + y_offset; - const float * y2 = y1 + 128; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux[0] = a[im+0] & kmask1; - aux[1] = a[im+2] & kmask1; - aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); - aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); - - float4 sum = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - const uint16_t * q1 = (const uint16_t *)ql1; - const uint16_t * q2 = q1 + 32; - q16[0] = q1[0] & 0x0f0f; - q16[1] = q1[8] & 0x0f0f; - q16[2] = (q1[0] >> 4) & 0x0f0f; - q16[3] = (q1[8] >> 4) & 0x0f0f; - q16[4] = q2[0] & 0x0f0f; - q16[5] = q2[8] & 0x0f0f; - q16[6] = (q2[0] >> 4) & 0x0f0f; - q16[7] = (q2[8] >> 4) & 0x0f0f; - for (int l = 0; l < n; ++l) { - sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) - + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0)); - sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) - + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0)); - sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) - + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0)); - sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) - + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0)); - smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] - + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; - } - tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; - } - -#else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); - const int step = tid * K_QUANTS_PER_ITERATION; - const int im = step/8; - const int in = step%8; - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - const uint8_t * q = x[i].qs + step; - const int8_t * s = x[i].scales; - const float * y = yy + i*QK_K + step; - const float d = x[i].d; - float sum = 0.f; - for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { - const uint8_t h = x[i].qh[in+j] >> im; - sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16)) - + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16)) - + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16)) - + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16)); - } - tmp += sum; - } -#endif - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - - const int row = blockIdx.y*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q6_K * x = (const block_q6_K *)vx + ib0; - -#if QK_K == 256 - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0...15 or 0...7 - -#if K_QUANTS_PER_ITERATION == 1 - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 - const int is = 0; -#else - const int l0 = 4 * in; // 0, 4, 8, ..., 28 - const int is = in / 4; -#endif - const int ql_offset = 64*im + l0; - const int qh_offset = 32*im + l0; - const int s_offset = 8*im + is; - const int y_offset = 128*im + l0; - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * ql = x[i].ql + ql_offset; - const uint8_t * qh = x[i].qh + qh_offset; - const int8_t * s = x[i].scales + s_offset; - - const float d = x[i].d; - -#if K_QUANTS_PER_ITERATION == 1 - float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) - + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) - + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) - + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) - + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) - + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) - + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) - +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); - tmp += sum; -#else - float sum = 0; - for (int l = 0; l < 4; ++l) { - sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) - + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) - + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) - + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); - } - tmp += sum; -#endif - - } - -#else - - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3 - - const int step = tid * K_QUANTS_PER_ITERATION; - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + step; - const uint8_t * ql = x[i].ql + step; - const uint8_t * qh = x[i].qh + step; - const int8_t * s = x[i].scales; - - const float d = x[i+0].d; - - float sum = 0; - for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { - sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) - + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) - + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) - + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); - } - tmp += sum; - - } - -#endif - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (tid == 0) { - dst[row] = tmp; - } -} - -static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const half * x = (const half *) vx; - - // automatic half -> float type cast if dfloat == float - v.x = x[ib + iqs + 0]; - v.y = x[ib + iqs + 1]; -} - -static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){ - const float * x = (const float *) vx; - - // automatic half -> float type cast if dfloat == float - v.x = x[ib + iqs + 0]; - v.y = x[ib + iqs + 1]; -} - -static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { - const int ix = blockDim.x*blockIdx.x + threadIdx.x; - - if (ix >= kx_padded) { - return; - } - - const int iy = blockDim.y*blockIdx.y + threadIdx.y; - - const int i_padded = iy*kx_padded + ix; - - block_q8_1 * y = (block_q8_1 *) vy; - - const int ib = i_padded / QK8_1; // block index - const int iqs = i_padded % QK8_1; // quant index - - const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; - float amax = fabsf(xi); - float sum = xi; - -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32)); - sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); - } - - const float d = amax / 127; - const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); - - y[ib].qs[iqs] = q; - - if (iqs > 0) { - return; - } - - reinterpret_cast(y[ib].ds.x) = d; - reinterpret_cast(y[ib].ds.y) = sum; -} - -template -static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) { - const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2; - const int row = blockDim.y*blockIdx.y + threadIdx.y; - - if (col >= ncols) { - return; - } - - const int r = y[row]; - - // copy x[r*ncols + col] to dst[row*ncols + col] - const int xi = r*ncols + col; - const int di = row*ncols + col; - - const int ib = xi/qk; // block index - const int iqs = (xi%qk)/qr; // quant index - const int iybs = di - di%qk; // y block start index - const int y_offset = qr == 1 ? 1 : qk/2; - - // dequantize - dfloat2 v; - dequantize_kernel(x, ib, iqs, v); - - dst[iybs + iqs + 0] = v.x; - dst[iybs + iqs + y_offset] = v.y; -} - -template -static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { - const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; - - if (i >= k) { - return; - } - - const int ib = i/qk; // block index - const int iqs = (i%qk)/qr; // quant index - const int iybs = i - i%qk; // y block start index - const int y_offset = qr == 1 ? 1 : qk/2; - - // dequantize - dfloat2 v; - dequantize_kernel(vx, ib, iqs, v); - - y[iybs + iqs + 0] = v.x; - y[iybs + iqs + y_offset] = v.y; -} - -// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called -// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q - -#define VDR_Q4_0_Q8_1_MMVQ 2 -#define VDR_Q4_0_Q8_1_MMQ 4 - -template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( - const int * v, const int * u, const float & d4, const half2 & ds8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - int sumi = 0; - -#pragma unroll - for (int i = 0; i < vdr; ++i) { - const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; - const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; - - // SIMD dot product of quantized values - sumi = __dp4a(vi0, u[2*i+0], sumi); - sumi = __dp4a(vi1, u[2*i+1], sumi); - } - - const float2 ds8f = __half22float2(ds8); - - // second part effectively subtracts 8 from each quant value - return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y); -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -#define VDR_Q4_1_Q8_1_MMVQ 2 -#define VDR_Q4_1_Q8_1_MMQ 4 - -template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( - const int * v, const int * u, const half2 & dm4, const half2 & ds8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - int sumi = 0; - -#pragma unroll - for (int i = 0; i < vdr; ++i) { - const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; - const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; - - // SIMD dot product of quantized values - sumi = __dp4a(vi0, u[2*i+0], sumi); - sumi = __dp4a(vi1, u[2*i+1], sumi); - } - -#ifdef GGML_CUDA_F16 - const float2 tmp = __half22float2(__hmul2(dm4, ds8)); - const float d4d8 = tmp.x; - const float m4s8 = tmp.y; -#else - const float2 dm4f = __half22float2(dm4); - const float2 ds8f = __half22float2(ds8); - const float d4d8 = dm4f.x * ds8f.x; - const float m4s8 = dm4f.y * ds8f.y; -#endif // GGML_CUDA_F16 - - // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it - return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -#define VDR_Q5_0_Q8_1_MMVQ 2 -#define VDR_Q5_0_Q8_1_MMQ 4 - -template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( - const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - int sumi = 0; - -#pragma unroll - for (int i = 0; i < vdr; ++i) { - int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits - vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 - vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 - vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 - vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 - sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values - - int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits - vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 - vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 - vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 - vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 - sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values - } - - const float2 ds8f = __half22float2(ds8); - - // second part effectively subtracts 16 from each quant value - return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y); -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -#define VDR_Q5_1_Q8_1_MMVQ 2 -#define VDR_Q5_1_Q8_1_MMQ 4 - -template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( - const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - int sumi = 0; - -#pragma unroll - for (int i = 0; i < vdr; ++i) { - int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits - vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 - vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 - vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 - vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 - sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values - - int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits - vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 - vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 - vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 - vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 - sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values - } - -#ifdef GGML_CUDA_F16 - const float2 tmp = __half22float2(__hmul2(dm5, ds8)); - const float d5d8 = tmp.x; - const float m5s8 = tmp.y; -#else - const float2 dm5f = __half22float2(dm5); - const float2 ds8f = __half22float2(ds8); - const float d5d8 = dm5f.x * ds8f.x; - const float m5s8 = dm5f.y * ds8f.y; -#endif // GGML_CUDA_F16 - - // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it - return sumi*d5d8 + m5s8 / (QI5_1 / vdr); - -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -#define VDR_Q8_0_Q8_1_MMVQ 2 -#define VDR_Q8_0_Q8_1_MMQ 8 - -template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( - const int * v, const int * u, const float & d8_0, const float & d8_1) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - int sumi = 0; - -#pragma unroll - for (int i = 0; i < vdr; ++i) { - // SIMD dot product of quantized values - sumi = __dp4a(v[i], u[i], sumi); - } - - return d8_0*d8_1 * sumi; -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( - const int * v, const int * u, const half2 & dm8, const half2 & ds8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - int sumi = 0; - -#pragma unroll - for (int i = 0; i < vdr; ++i) { - // SIMD dot product of quantized values - sumi = __dp4a(v[i], u[i], sumi); - } - -#ifdef GGML_CUDA_F16 - const float2 tmp = __half22float2(__hmul2(dm8, ds8)); - const float d8d8 = tmp.x; - const float m8s8 = tmp.y; -#else - const float2 dm8f = __half22float2(dm8); - const float2 ds8f = __half22float2(ds8); - const float d8d8 = dm8f.x * ds8f.x; - const float m8s8 = dm8f.y * ds8f.y; -#endif // GGML_CUDA_F16 - - // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it - return sumi*d8d8 + m8s8 / (QI8_1 / vdr); -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -#define VDR_Q2_K_Q8_1_MMVQ 1 -#define VDR_Q2_K_Q8_1_MMQ 2 - -// contiguous v/x values -static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( - const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, - const half2 & dm2, const float * __restrict__ d8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - float sumf_d = 0.0f; - float sumf_m = 0.0f; - -#pragma unroll - for (int i = 0; i < QR2_K; ++i) { - const int sc = scales[2*i]; - - const int vi = (v >> (2*i)) & 0x03030303; - - sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product - - // fill int with 4x m - int m = sc >> 4; - m |= m << 8; - m |= m << 16; - sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values - } - - const float2 dm2f = __half22float2(dm2); - - return dm2f.x*sumf_d - dm2f.y*sumf_m; -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -// contiguous u/y values -static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( - const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, - const half2 & dm2, const float & d8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - int sumi_d = 0; - int sumi_m = 0; - -#pragma unroll - for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { - int sumi_d_sc = 0; - - const int sc = scales[i0 / (QI8_1/2)]; - - // fill int with 4x m - int m = sc >> 4; - m |= m << 8; - m |= m << 16; - -#pragma unroll - for (int i = i0; i < i0 + QI8_1/2; ++i) { - sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product - sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m - } - - sumi_d += sumi_d_sc * (sc & 0xF); - } - - const float2 dm2f = __half22float2(dm2); - - return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m); -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -#define VDR_Q3_K_Q8_1_MMVQ 1 -#define VDR_Q3_K_Q8_1_MMQ 2 - -// contiguous v/x values -static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( - const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, - const int & scale_offset, const float & d3, const float * __restrict__ d8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - float sumf = 0.0f; - -#pragma unroll - for (int i = 0; i < QR3_K; ++i) { - const int isc = scale_offset + 2*i; - - const int isc_low = isc % (QK_K/32); - const int sc_shift_low = 4 * (isc / (QK_K/32)); - const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF; - - const int isc_high = isc % (QK_K/64); - const int sc_shift_high = 2 * (isc / (QK_K/64)); - const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; - - const int sc = (sc_low | sc_high) - 32; - - const int vil = (vl >> (2*i)) & 0x03030303; - - const int vih = ((vh >> i) << 2) & 0x04040404; - - const int vi = __vsubss4(vil, vih); - - sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product - } - - return d3 * sumf; -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -// contiguous u/y values -static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( - const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, - const float & d3, const float & d8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - int sumi = 0; - -#pragma unroll - for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { - int sumi_sc = 0; - - for (int i = i0; i < i0 + QI8_1/2; ++i) { - sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product - } - - sumi += sumi_sc * scales[i0 / (QI8_1/2)]; - } - - return d3*d8 * sumi; -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -#define VDR_Q4_K_Q8_1_MMVQ 2 -#define VDR_Q4_K_Q8_1_MMQ 8 - -// contiguous v/x values -static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( - const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, - const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - float sumf_d = 0.0f; - float sumf_m = 0.0f; - -#pragma unroll - for (int i = 0; i < QR4_K; ++i) { - const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; - const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; - - const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product - const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u - - sumf_d += d8[i] * (dot1 * sc[i]); - sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values - } - - const float2 dm4f = __half22float2(dm4); - - return dm4f.x*sumf_d - dm4f.y*sumf_m; - -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -// contiguous u/y values -static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( - const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, - const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - float sumf_d = 0.0f; - float sumf_m = 0.0f; - -#pragma unroll - for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) { - int sumi_d = 0; - -#pragma unroll - for (int j = 0; j < QI8_1; ++j) { - sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product - } - - const float2 ds8f = __half22float2(ds8[i]); - - sumf_d += ds8f.x * (sc[i] * sumi_d); - sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val - } - - const float2 dm4f = __half22float2(dm4); - - return dm4f.x*sumf_d - dm4f.y*sumf_m; - -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -#define VDR_Q5_K_Q8_1_MMVQ 2 -#define VDR_Q5_K_Q8_1_MMQ 8 - -// contiguous v/x values -static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( - const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, - const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - float sumf_d = 0.0f; - float sumf_m = 0.0f; - -#pragma unroll - for (int i = 0; i < QR5_K; ++i) { - const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F; - const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F; - - const int vh0i = ((vh[0] >> i) << 4) & 0x10101010; - const int vh1i = ((vh[1] >> i) << 4) & 0x10101010; - - const int v0i = vl0i | vh0i; - const int v1i = vl1i | vh1i; - - const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product - const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u - - sumf_d += d8[i] * (dot1 * sc[i]); - sumf_m += d8[i] * (dot2 * m[i]); - - } - - const float2 dm5f = __half22float2(dm5); - - return dm5f.x*sumf_d - dm5f.y*sumf_m; - -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -// contiguous u/y values -static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( - const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, - const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - float sumf_d = 0.0f; - float sumf_m = 0.0f; - -#pragma unroll - for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) { - int sumi_d = 0; - -#pragma unroll - for (int j = 0; j < QI8_1; ++j) { - sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product - } - - const float2 ds8f = __half22float2(ds8[i]); - - sumf_d += ds8f.x * (sc[i] * sumi_d); - sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val - } - - const float2 dm4f = __half22float2(dm4); - - return dm4f.x*sumf_d - dm4f.y*sumf_m; - -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -#define VDR_Q6_K_Q8_1_MMVQ 1 -#define VDR_Q6_K_Q8_1_MMQ 8 - -// contiguous v/x values -static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( - const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, - const float & d, const float * __restrict__ d8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - float sumf = 0.0f; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - const int sc = scales[4*i]; - - const int vil = (vl >> (4*i)) & 0x0F0F0F0F; - - const int vih = ((vh >> (4*i)) << 4) & 0x30303030; - - const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 - - sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product - } - - return d*sumf; -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -// contiguous u/y values -static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( - const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, - const float & d6, const float * __restrict__ d8) { - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - float sumf_d = 0.0f; - -#pragma unroll - for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { - int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale - -#pragma unroll - for (int i = i0; i < i0 + 2; ++i) { - sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product - sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product - - sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product - sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product - } - - sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); - } - - return d6 * sumf_d; - -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A -} - -static __device__ __forceinline__ float vec_dot_q4_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - - const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; - - int v[VDR_Q4_0_Q8_1_MMVQ]; - int u[2*VDR_Q4_0_Q8_1_MMVQ]; - -#pragma unroll - for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); - } - - return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); -} - -template static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0]; - - *x_ql = tile_x_qs; - *x_dm = (half2 *) tile_x_d; -} - -template static __device__ __forceinline__ void load_tiles_q4_0( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI4_0; - const int kqsx = k % QI4_0; - - const block_q4_0 * bx0 = (block_q4_0 *) vx; - - float * x_dmf = (float *) x_dm; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); - // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { - int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; - } -} - -static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const float * x_dmf = (float *) x_dm; - - int u[2*VDR_Q4_0_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; - } - - return vec_dot_q4_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - -static __device__ __forceinline__ float vec_dot_q4_1_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - - const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; - - int v[VDR_Q4_1_Q8_1_MMVQ]; - int u[2*VDR_Q4_1_Q8_1_MMVQ]; - -#pragma unroll - for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1); - } - - return vec_dot_q4_1_q8_1_impl(v, u, bq4_1->dm, bq8_1->ds); -} - -template static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1]; - - *x_ql = tile_x_qs; - *x_dm = tile_x_dm; -} - -template static __device__ __forceinline__ void load_tiles_q4_1( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI4_1; - const int kqsx = k % QI4_1; - - const block_q4_1 * bx0 = (block_q4_1 *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { - int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; - } -} - -static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - - int u[2*VDR_Q4_1_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; - } - - return vec_dot_q4_1_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - -static __device__ __forceinline__ float vec_dot_q5_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - - const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; - - int vl[VDR_Q5_0_Q8_1_MMVQ]; - int vh[VDR_Q5_0_Q8_1_MMVQ]; - int u[2*VDR_Q5_0_Q8_1_MMVQ]; - -#pragma unroll - for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) { - vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i); - vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i)); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0); - } - - return vec_dot_q5_0_q8_1_impl(vl, vh, u, bq5_0->d, bq8_1->ds); -} - -template static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0]; - - *x_ql = tile_x_ql; - *x_dm = (half2 *) tile_x_d; -} - -template static __device__ __forceinline__ void load_tiles_q5_0( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI5_0; - const int kqsx = k % QI5_0; - - const block_q5_0 * bx0 = (block_q5_0 *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx; - - const int ql = get_int_from_uint8(bxi->qs, kqsx); - const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0)); - - int qs0 = (ql >> 0) & 0x0F0F0F0F; - qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 - qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 - qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 - qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 - - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; - - int qs1 = (ql >> 4) & 0x0F0F0F0F; - qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 - qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 - qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 - qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 - - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; - const int kbxd = k % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { - int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d; - } -} - -static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - int u[2*VDR_Q5_0_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; - } - - return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - -static __device__ __forceinline__ float vec_dot_q5_1_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - - const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; - - int vl[VDR_Q5_1_Q8_1_MMVQ]; - int vh[VDR_Q5_1_Q8_1_MMVQ]; - int u[2*VDR_Q5_1_Q8_1_MMVQ]; - -#pragma unroll - for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) { - vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i); - vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i)); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1); - } - - return vec_dot_q5_1_q8_1_impl(vl, vh, u, bq5_1->dm, bq8_1->ds); -} - -template static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; -} - -template static __device__ __forceinline__ void load_tiles_q5_1( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI5_1; - const int kqsx = k % QI5_1; - - const block_q5_1 * bx0 = (block_q5_1 *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx; - - const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); - const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1)); - - int qs0 = (ql >> 0) & 0x0F0F0F0F; - qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 - qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 - qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 - qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; - - int qs1 = (ql >> 4) & 0x0F0F0F0F; - qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 - qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 - qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 - qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { - int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; - } -} - -static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; - - int u[2*VDR_Q5_1_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; - } - - return vec_dot_q8_1_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - -static __device__ __forceinline__ float vec_dot_q8_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - - const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; - - int v[VDR_Q8_0_Q8_1_MMVQ]; - int u[VDR_Q8_0_Q8_1_MMVQ]; - -#pragma unroll - for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_int8(bq8_0->qs, iqs + i); - u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - } - - return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); -} - -template static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0]; - - *x_ql = tile_x_qs; - *x_dm = (half2 *) tile_x_d; -} - -template static __device__ __forceinline__ void load_tiles_q8_0( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI8_0; - const int kqsx = k % QI8_0; - float * x_dmf = (float *) x_dm; - - const block_q8_0 * bx0 = (block_q8_0 *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { - int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; - } -} - -static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], - y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); -} - -static __device__ __forceinline__ float vec_dot_q2_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - - const block_q2_K * bq2_K = (const block_q2_K *) vbq; - - const int bq8_offset = QR2_K * (iqs / QI8_1); - const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); - - const uint8_t * scales = bq2_K->scales + scale_offset; - - const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs); - int u[QR2_K]; - float d8[QR2_K]; - -#pragma unroll - for (int i = 0; i < QR2_K; ++ i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); - d8[i] = __low2half(bq8_1[bq8_offset + i].ds); - } - - return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); -} - -template static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q2_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI2_K; - const int kqsx = k % QI2_K; - - const block_q2_K * bx0 = (block_q2_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { - int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); - - if (need_check) { - i = min(i, i_max); - } - - const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4); - - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); - } -} - -static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - - const int kbx = k / QI2_K; - const int ky = (k % QI2_K) * QR2_K; - const float * y_df = (const float *) y_ds; - - int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; - - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); - const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); - -#pragma unroll - for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { - v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; - } - - const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; - - const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE; - return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); -} - -static __device__ __forceinline__ float vec_dot_q3_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - - const block_q3_K * bq3_K = (const block_q3_K *) vbq; - - const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); - const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); - - const float d = bq3_K->d; - - const int vl = get_int_from_uint8(bq3_K->qs, iqs); - - // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted - const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; - - int u[QR3_K]; - float d8[QR3_K]; - -#pragma unroll - for (int i = 0; i < QR3_K; ++i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); - d8[i] = __low2half(bq8_1[bq8_offset + i].ds); - } - - return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); -} - -template static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K]; - __shared__ int tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_qh = tile_x_qh; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q3_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI3_K; - const int kqsx = k % QI3_K; - - const block_q3_K * bx0 = (block_q3_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; - const int kbxd = k % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { - int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { - int i = i0 + i_offset * 2 + k / (WARP_SIZE/2); - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2); - - // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted - x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4); - - const int ksc = k % (QI3_K/4); - - const int ksc_low = ksc % (QI3_K/8); - const int shift_low = 4 * (ksc / (QI3_K/8)); - const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; - - const int ksc_high = QI3_K/8; - const int shift_high = 2 * ksc; - const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; - - const int sc = __vsubss4(sc_low | sc_high, 0x20202020); - - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc; - } -} - -static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - - const int kbx = k / QI3_K; - const int ky = (k % QI3_K) * QR3_K; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - const int8_t * scales = ((int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; - - int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); - const int shift = 2 * ((ky % 32) / 8); - const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; - - const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); - const int vlh = (vh << 2) & 0x04040404; - - v[l] = __vsubss4(vll, vlh); - } - - const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE; - return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); -} - -static __device__ __forceinline__ float vec_dot_q4_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - -#ifndef GGML_QKK_64 - const block_q4_K * bq4_K = (const block_q4_K *) vbq; - - int v[2]; - int u[2*QR4_K]; - float d8[QR4_K]; - - // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 - const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); - - // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 - // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 - // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 - // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 - - const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); - v[0] = q4[0]; - v[1] = q4[4]; - - const uint16_t * scales = (const uint16_t *)bq4_K->scales; - uint16_t aux[2]; - const int j = bq8_offset/2; - if (j < 2) { - aux[0] = scales[j+0] & 0x3f3f; - aux[1] = scales[j+2] & 0x3f3f; - } else { - aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); - aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); - } - const uint8_t * sc = (const uint8_t *)aux; - const uint8_t * m = sc + 2; - - for (int i = 0; i < QR4_K; ++i) { - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - d8[i] = __low2half(bq8i->ds); - - const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); - u[2*i+0] = q8[0]; - u[2*i+1] = q8[4]; - } - - return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); - -#else - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q4_K * bq4_K = (const block_q4_K *) vbq; - - float sumf_d = 0.0f; - float sumf_m = 0.0f; - - uint16_t aux16[2]; - const uint8_t * s = (const uint8_t *)aux16; - - const uint16_t * a = (const uint16_t *)bq4_K->scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - const float dall = bq4_K->dm[0]; - const float dmin = bq4_K->dm[1]; - - const float d8_1 = __low2float(bq8_1[0].ds); - const float d8_2 = __low2float(bq8_1[1].ds); - - const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); - const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); - const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); - const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); - - const int * q4 = (const int *)bq4_K->qs + (iqs/2); - const int v1 = q4[0]; - const int v2 = q4[4]; - - const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0)); - const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); - const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); - const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0)); - - sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); - sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); - - return dall * sumf_d - dmin * sumf_m; - -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A - -#endif -} - -template static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q4_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI4_K; // == 0 if QK_K == 256 - const int kqsx = k % QI4_K; // == k if QK_K == 256 - - const block_q4_K * bx0 = (block_q4_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 - const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { - int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd; - -#if QK_K == 256 - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; -#else - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]}; -#endif - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); - - const int * scales = (int *) bxi->scales; - - const int ksc = k % (WARP_SIZE/8); - - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; - } -} - -static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); - - const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE; - return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); -} - -static __device__ __forceinline__ float vec_dot_q5_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - -#ifndef GGML_QKK_64 - const block_q5_K * bq5_K = (const block_q5_K *) vbq; - - int vl[2]; - int vh[2]; - int u[2*QR5_K]; - float d8[QR5_K]; - - const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2)); - const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); - const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4)); - - vl[0] = ql[0]; - vl[1] = ql[4]; - - vh[0] = qh[0] >> bq8_offset; - vh[1] = qh[4] >> bq8_offset; - - const uint16_t * scales = (const uint16_t *)bq5_K->scales; - uint16_t aux[2]; - const int j = bq8_offset/2; - if (j < 2) { - aux[0] = scales[j+0] & 0x3f3f; - aux[1] = scales[j+2] & 0x3f3f; - } else { - aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); - aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); - } - const uint8_t * sc = (const uint8_t *)aux; - const uint8_t * m = sc + 2; - -#pragma unroll - for (int i = 0; i < QR5_K; ++i) { - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - d8[i] = __low2float(bq8i->ds); - - const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); - u[2*i+0] = q8[0]; - u[2*i+1] = q8[4]; - } - - return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8); - -#else - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q5_K * bq5_K = (const block_q5_K *) vbq; - - const int8_t * s = bq5_K->scales; - - const float d = bq5_K->d; - - const float d8_1 = __low2half(bq8_1[0].ds); - const float d8_2 = __low2half(bq8_1[1].ds); - - const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); - const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); - const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); - const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); - - const int * ql = (const int *)bq5_K->qs + (iqs/2); - const int vl1 = ql[0]; - const int vl2 = ql[4]; - - const int step = 4 * (iqs/2); // 0, 4, 8, 12 - const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6 - const int in = step%8; // 0, 4, 0, 4 - const int vh = (*((const int *)(bq5_K->qh + in))) >> im; - - const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f); - const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f); - const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); - const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); - - const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1]) - + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]); - - return d * sumf_d; - -#else - assert(false); - return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A - -#endif -} - -template static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q5_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI5_K; // == 0 if QK_K == 256 - const int kqsx = k % QI5_K; // == k if QK_K == 256 - - const block_q5_K * bx0 = (block_q5_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx; - const int ky = QR5_K*kqsx; - - const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); - const int ql0 = (ql >> 0) & 0x0F0F0F0F; - const int ql1 = (ql >> 4) & 0x0F0F0F0F; - - const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4)); - const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; - const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; - - const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0; - const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4); - - x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 - const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { - int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd; - -#if QK_K == 256 - x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; -#endif - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); - - const int * scales = (int *) bxi->scales; - - const int ksc = k % (WARP_SIZE/8); - - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; - } -} - -static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); - - const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k; - const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE; - return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); -} - -static __device__ __forceinline__ float vec_dot_q6_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { - - const block_q6_K * bq6_K = (const block_q6_K *) vbq; - - const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); - const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); - const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); - - const int vl = get_int_from_uint8(bq6_K->ql, iqs); - const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift; - - const int8_t * scales = bq6_K->scales + scale_offset; - - int u[QR6_K]; - float d8[QR6_K]; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); - d8[i] = __low2half(bq8_1[bq8_offset + 2*i].ds); - } - - return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); -} - -template static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q6_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI6_K; // == 0 if QK_K == 256 - const int kqsx = k % QI6_K; // == k if QK_K == 256 - - const block_q6_K * bx0 = (block_q6_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx; - const int ky = QR6_K*kqsx; - - const int ql = get_int_from_uint8(bxi->ql, kqsx); - const int ql0 = (ql >> 0) & 0x0F0F0F0F; - const int ql1 = (ql >> 4) & 0x0F0F0F0F; - - const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); - const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; - const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; - - const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0; - const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2); - - x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 - const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 - float * x_dmf = (float *) x_dm; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { - int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4; - - x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); - } -} - -static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]); - - const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k; - const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE; - return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); -} - -template -static __device__ __forceinline__ void mul_mat_q( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; - - const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_col_y = nrows_y / QK8_1; - const int blocks_per_warp = WARP_SIZE / qi; - - const int & ncols_dst = ncols_y; - - const int row_dst_0 = blockIdx.x*mmq_y; - const int & row_x_0 = row_dst_0; - - const int col_dst_0 = blockIdx.y*mmq_x; - const int & col_y_0 = col_dst_0; - - int * tile_x_ql = nullptr; - half2 * tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - int * tile_x_sc = nullptr; - - allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); - - __shared__ int tile_y_qs[mmq_x * WARP_SIZE]; - __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1]; - - float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {0.0f}; - - for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { - - load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, - threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x); - -#pragma unroll - for (int ir = 0; ir < qr; ++ir) { - const int kqs = ir*WARP_SIZE + threadIdx.x; - const int kbxd = kqs / QI8_1; - -#pragma unroll - for (int i = 0; i < mmq_x; i += nwarps) { - const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses - - const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd]; - - const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE; - tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); - } - -#pragma unroll - for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { - const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; - const int kby = threadIdx.x % (WARP_SIZE/QI8_1); - const int col_y_eff = min(col_y_0 + ids, ncols_y-1); - - // if the sum is not needed it's faster to transform the scale to f32 ahead of time - const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds; - half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; - if (need_sum) { - *dsi_dst = *dsi_src; - } else { - float * dfi_dst = (float *) dsi_dst; - *dfi_dst = __low2half(*dsi_src); - } - } - - __syncthreads(); - -// #pragma unroll // unrolling this loop causes too much register pressure - for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) { -#pragma unroll - for (int j = 0; j < mmq_x; j += nwarps) { -#pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { - sum[i/WARP_SIZE][j/nwarps] += vec_dot( - tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, - threadIdx.x + i, threadIdx.y + j, k); - } - } - } - - __syncthreads(); - } - } - -#pragma unroll - for (int j = 0; j < mmq_x; j += nwarps) { - const int col_dst = col_dst_0 + j + threadIdx.y; - - if (col_dst >= ncols_dst) { - return; - } - -#pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { - const int row_dst = row_dst_0 + threadIdx.x + i; - - if (row_dst >= nrows_dst) { - continue; - } - - dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps]; - } - } -} - -#define MMQ_X_Q4_0_RDNA2 64 -#define MMQ_Y_Q4_0_RDNA2 128 -#define NWARPS_Q4_0_RDNA2 8 -#define MMQ_X_Q4_0_RDNA1 64 -#define MMQ_Y_Q4_0_RDNA1 64 -#define NWARPS_Q4_0_RDNA1 8 -#define MMQ_X_Q4_0_AMPERE 64 -#define MMQ_Y_Q4_0_AMPERE 128 -#define NWARPS_Q4_0_AMPERE 4 -#define MMQ_X_Q4_0_PASCAL 64 -#define MMQ_Y_Q4_0_PASCAL 64 -#define NWARPS_Q4_0_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q4_0_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - mul_mat_q4_0( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q4_0_RDNA2; - const int mmq_y = MMQ_Y_Q4_0_RDNA2; - const int nwarps = NWARPS_Q4_0_RDNA2; -#else - const int mmq_x = MMQ_X_Q4_0_RDNA1; - const int mmq_y = MMQ_Y_Q4_0_RDNA1; - const int nwarps = NWARPS_Q4_0_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q4_0_AMPERE; - const int mmq_y = MMQ_Y_Q4_0_AMPERE; - const int nwarps = NWARPS_Q4_0_AMPERE; - - mul_mat_q, - load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q4_0_PASCAL; - const int mmq_y = MMQ_Y_Q4_0_PASCAL; - const int nwarps = NWARPS_Q4_0_PASCAL; - - mul_mat_q, - load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - (void) vec_dot_q4_0_q8_1_mul_mat; - assert(false); -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q4_1_RDNA2 64 -#define MMQ_Y_Q4_1_RDNA2 128 -#define NWARPS_Q4_1_RDNA2 8 -#define MMQ_X_Q4_1_RDNA1 64 -#define MMQ_Y_Q4_1_RDNA1 64 -#define NWARPS_Q4_1_RDNA1 8 -#define MMQ_X_Q4_1_AMPERE 64 -#define MMQ_Y_Q4_1_AMPERE 128 -#define NWARPS_Q4_1_AMPERE 4 -#define MMQ_X_Q4_1_PASCAL 64 -#define MMQ_Y_Q4_1_PASCAL 64 -#define NWARPS_Q4_1_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_VOLTA - __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_VOLTA - mul_mat_q4_1( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q4_1_RDNA2; - const int mmq_y = MMQ_Y_Q4_1_RDNA2; - const int nwarps = NWARPS_Q4_1_RDNA2; -#else - const int mmq_x = MMQ_X_Q4_1_RDNA1; - const int mmq_y = MMQ_Y_Q4_1_RDNA1; - const int nwarps = NWARPS_Q4_1_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q4_1_AMPERE; - const int mmq_y = MMQ_Y_Q4_1_AMPERE; - const int nwarps = NWARPS_Q4_1_AMPERE; - - mul_mat_q, - load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q4_1_PASCAL; - const int mmq_y = MMQ_Y_Q4_1_PASCAL; - const int nwarps = NWARPS_Q4_1_PASCAL; - - mul_mat_q, - load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - (void) vec_dot_q4_1_q8_1_mul_mat; - assert(false); -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q5_0_RDNA2 64 -#define MMQ_Y_Q5_0_RDNA2 128 -#define NWARPS_Q5_0_RDNA2 8 -#define MMQ_X_Q5_0_RDNA1 64 -#define MMQ_Y_Q5_0_RDNA1 64 -#define NWARPS_Q5_0_RDNA1 8 -#define MMQ_X_Q5_0_AMPERE 128 -#define MMQ_Y_Q5_0_AMPERE 64 -#define NWARPS_Q5_0_AMPERE 4 -#define MMQ_X_Q5_0_PASCAL 64 -#define MMQ_Y_Q5_0_PASCAL 64 -#define NWARPS_Q5_0_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q5_0_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - mul_mat_q5_0( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q5_0_RDNA2; - const int mmq_y = MMQ_Y_Q5_0_RDNA2; - const int nwarps = NWARPS_Q5_0_RDNA2; -#else - const int mmq_x = MMQ_X_Q5_0_RDNA1; - const int mmq_y = MMQ_Y_Q5_0_RDNA1; - const int nwarps = NWARPS_Q5_0_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q5_0_AMPERE; - const int mmq_y = MMQ_Y_Q5_0_AMPERE; - const int nwarps = NWARPS_Q5_0_AMPERE; - - mul_mat_q, - load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q5_0_PASCAL; - const int mmq_y = MMQ_Y_Q5_0_PASCAL; - const int nwarps = NWARPS_Q5_0_PASCAL; - - mul_mat_q, - load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - (void) vec_dot_q5_0_q8_1_mul_mat; - assert(false); -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q5_1_RDNA2 64 -#define MMQ_Y_Q5_1_RDNA2 128 -#define NWARPS_Q5_1_RDNA2 8 -#define MMQ_X_Q5_1_RDNA1 64 -#define MMQ_Y_Q5_1_RDNA1 64 -#define NWARPS_Q5_1_RDNA1 8 -#define MMQ_X_Q5_1_AMPERE 128 -#define MMQ_Y_Q5_1_AMPERE 64 -#define NWARPS_Q5_1_AMPERE 4 -#define MMQ_X_Q5_1_PASCAL 64 -#define MMQ_Y_Q5_1_PASCAL 64 -#define NWARPS_Q5_1_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q5_1_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -mul_mat_q5_1( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q5_1_RDNA2; - const int mmq_y = MMQ_Y_Q5_1_RDNA2; - const int nwarps = NWARPS_Q5_1_RDNA2; -#else - const int mmq_x = MMQ_X_Q5_1_RDNA1; - const int mmq_y = MMQ_Y_Q5_1_RDNA1; - const int nwarps = NWARPS_Q5_1_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q5_1_AMPERE; - const int mmq_y = MMQ_Y_Q5_1_AMPERE; - const int nwarps = NWARPS_Q5_1_AMPERE; - - mul_mat_q, - load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q5_1_PASCAL; - const int mmq_y = MMQ_Y_Q5_1_PASCAL; - const int nwarps = NWARPS_Q5_1_PASCAL; - - mul_mat_q, - load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - (void) vec_dot_q5_1_q8_1_mul_mat; - assert(false); -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q8_0_RDNA2 64 -#define MMQ_Y_Q8_0_RDNA2 128 -#define NWARPS_Q8_0_RDNA2 8 -#define MMQ_X_Q8_0_RDNA1 64 -#define MMQ_Y_Q8_0_RDNA1 64 -#define NWARPS_Q8_0_RDNA1 8 -#define MMQ_X_Q8_0_AMPERE 128 -#define MMQ_Y_Q8_0_AMPERE 64 -#define NWARPS_Q8_0_AMPERE 4 -#define MMQ_X_Q8_0_PASCAL 64 -#define MMQ_Y_Q8_0_PASCAL 64 -#define NWARPS_Q8_0_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q8_0_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - mul_mat_q8_0( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q8_0_RDNA2; - const int mmq_y = MMQ_Y_Q8_0_RDNA2; - const int nwarps = NWARPS_Q8_0_RDNA2; -#else - const int mmq_x = MMQ_X_Q8_0_RDNA1; - const int mmq_y = MMQ_Y_Q8_0_RDNA1; - const int nwarps = NWARPS_Q8_0_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q8_0_AMPERE; - const int mmq_y = MMQ_Y_Q8_0_AMPERE; - const int nwarps = NWARPS_Q8_0_AMPERE; - - mul_mat_q, - load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q8_0_PASCAL; - const int mmq_y = MMQ_Y_Q8_0_PASCAL; - const int nwarps = NWARPS_Q8_0_PASCAL; - - mul_mat_q, - load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - (void) vec_dot_q8_0_q8_1_mul_mat; - assert(false); -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q2_K_RDNA2 64 -#define MMQ_Y_Q2_K_RDNA2 128 -#define NWARPS_Q2_K_RDNA2 8 -#define MMQ_X_Q2_K_RDNA1 128 -#define MMQ_Y_Q2_K_RDNA1 32 -#define NWARPS_Q2_K_RDNA1 8 -#define MMQ_X_Q2_K_AMPERE 64 -#define MMQ_Y_Q2_K_AMPERE 128 -#define NWARPS_Q2_K_AMPERE 4 -#define MMQ_X_Q2_K_PASCAL 64 -#define MMQ_Y_Q2_K_PASCAL 64 -#define NWARPS_Q2_K_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q2_K_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -mul_mat_q2_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q2_K_RDNA2; - const int mmq_y = MMQ_Y_Q2_K_RDNA2; - const int nwarps = NWARPS_Q2_K_RDNA2; -#else - const int mmq_x = MMQ_X_Q2_K_RDNA1; - const int mmq_y = MMQ_Y_Q2_K_RDNA1; - const int nwarps = NWARPS_Q2_K_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q2_K_AMPERE; - const int mmq_y = MMQ_Y_Q2_K_AMPERE; - const int nwarps = NWARPS_Q2_K_AMPERE; - - mul_mat_q, - load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q2_K_PASCAL; - const int mmq_y = MMQ_Y_Q2_K_PASCAL; - const int nwarps = NWARPS_Q2_K_PASCAL; - - mul_mat_q, - load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - (void) vec_dot_q2_K_q8_1_mul_mat; - assert(false); -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q3_K_RDNA2 128 -#define MMQ_Y_Q3_K_RDNA2 64 -#define NWARPS_Q3_K_RDNA2 8 -#define MMQ_X_Q3_K_RDNA1 32 -#define MMQ_Y_Q3_K_RDNA1 128 -#define NWARPS_Q3_K_RDNA1 8 -#define MMQ_X_Q3_K_AMPERE 128 -#define MMQ_Y_Q3_K_AMPERE 128 -#define NWARPS_Q3_K_AMPERE 4 -#define MMQ_X_Q3_K_PASCAL 64 -#define MMQ_Y_Q3_K_PASCAL 64 -#define NWARPS_Q3_K_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_VOLTA - __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_VOLTA - mul_mat_q3_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q3_K_RDNA2; - const int mmq_y = MMQ_Y_Q3_K_RDNA2; - const int nwarps = NWARPS_Q3_K_RDNA2; -#else - const int mmq_x = MMQ_X_Q3_K_RDNA1; - const int mmq_y = MMQ_Y_Q3_K_RDNA1; - const int nwarps = NWARPS_Q3_K_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q3_K_AMPERE; - const int mmq_y = MMQ_Y_Q3_K_AMPERE; - const int nwarps = NWARPS_Q3_K_AMPERE; - - mul_mat_q, - load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q3_K_PASCAL; - const int mmq_y = MMQ_Y_Q3_K_PASCAL; - const int nwarps = NWARPS_Q3_K_PASCAL; - - mul_mat_q, - load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - (void) vec_dot_q3_K_q8_1_mul_mat; - assert(false); -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q4_K_RDNA2 64 -#define MMQ_Y_Q4_K_RDNA2 128 -#define NWARPS_Q4_K_RDNA2 8 -#define MMQ_X_Q4_K_RDNA1 32 -#define MMQ_Y_Q4_K_RDNA1 64 -#define NWARPS_Q4_K_RDNA1 8 -#define MMQ_X_Q4_K_AMPERE 64 -#define MMQ_Y_Q4_K_AMPERE 128 -#define NWARPS_Q4_K_AMPERE 4 -#define MMQ_X_Q4_K_PASCAL 64 -#define MMQ_Y_Q4_K_PASCAL 64 -#define NWARPS_Q4_K_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_VOLTA - __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_VOLTA - mul_mat_q4_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q4_K_RDNA2; - const int mmq_y = MMQ_Y_Q4_K_RDNA2; - const int nwarps = NWARPS_Q4_K_RDNA2; -#else - const int mmq_x = MMQ_X_Q4_K_RDNA1; - const int mmq_y = MMQ_Y_Q4_K_RDNA1; - const int nwarps = NWARPS_Q4_K_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q4_K_AMPERE; - const int mmq_y = MMQ_Y_Q4_K_AMPERE; - const int nwarps = NWARPS_Q4_K_AMPERE; - - mul_mat_q, - load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q4_K_PASCAL; - const int mmq_y = MMQ_Y_Q4_K_PASCAL; - const int nwarps = NWARPS_Q4_K_PASCAL; - - mul_mat_q, - load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - (void) vec_dot_q4_K_q8_1_mul_mat; - assert(false); -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q5_K_RDNA2 64 -#define MMQ_Y_Q5_K_RDNA2 128 -#define NWARPS_Q5_K_RDNA2 8 -#define MMQ_X_Q5_K_RDNA1 32 -#define MMQ_Y_Q5_K_RDNA1 64 -#define NWARPS_Q5_K_RDNA1 8 -#define MMQ_X_Q5_K_AMPERE 64 -#define MMQ_Y_Q5_K_AMPERE 128 -#define NWARPS_Q5_K_AMPERE 4 -#define MMQ_X_Q5_K_PASCAL 64 -#define MMQ_Y_Q5_K_PASCAL 64 -#define NWARPS_Q5_K_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q5_K_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -mul_mat_q5_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q5_K_RDNA2; - const int mmq_y = MMQ_Y_Q5_K_RDNA2; - const int nwarps = NWARPS_Q5_K_RDNA2; -#else - const int mmq_x = MMQ_X_Q5_K_RDNA1; - const int mmq_y = MMQ_Y_Q5_K_RDNA1; - const int nwarps = NWARPS_Q5_K_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q5_K_AMPERE; - const int mmq_y = MMQ_Y_Q5_K_AMPERE; - const int nwarps = NWARPS_Q5_K_AMPERE; - - mul_mat_q, - load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q5_K_PASCAL; - const int mmq_y = MMQ_Y_Q5_K_PASCAL; - const int nwarps = NWARPS_Q5_K_PASCAL; - - mul_mat_q, - load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - (void) vec_dot_q5_K_q8_1_mul_mat; - assert(false); -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q6_K_RDNA2 64 -#define MMQ_Y_Q6_K_RDNA2 128 -#define NWARPS_Q6_K_RDNA2 8 -#define MMQ_X_Q6_K_RDNA1 32 -#define MMQ_Y_Q6_K_RDNA1 64 -#define NWARPS_Q6_K_RDNA1 8 -#define MMQ_X_Q6_K_AMPERE 64 -#define MMQ_Y_Q6_K_AMPERE 64 -#define NWARPS_Q6_K_AMPERE 4 -#define MMQ_X_Q6_K_PASCAL 64 -#define MMQ_Y_Q6_K_PASCAL 64 -#define NWARPS_Q6_K_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_VOLTA - __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_VOLTA - mul_mat_q6_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q6_K_RDNA2; - const int mmq_y = MMQ_Y_Q6_K_RDNA2; - const int nwarps = NWARPS_Q6_K_RDNA2; -#else - const int mmq_x = MMQ_X_Q6_K_RDNA1; - const int mmq_y = MMQ_Y_Q6_K_RDNA1; - const int nwarps = NWARPS_Q6_K_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q6_K_AMPERE; - const int mmq_y = MMQ_Y_Q6_K_AMPERE; - const int nwarps = NWARPS_Q6_K_AMPERE; - - mul_mat_q, - load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q6_K_PASCAL; - const int mmq_y = MMQ_Y_Q6_K_PASCAL; - const int nwarps = NWARPS_Q6_K_PASCAL; - - mul_mat_q, - load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - (void) vec_dot_q6_K_q8_1_mul_mat; - assert(false); -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -template -static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) { - const int row = blockIdx.y*blockDim.y + threadIdx.y; - - if (row >= nrows) { - return; - } - - const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - -// partial sum for each thread - float tmp = 0.0f; - - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; - - for (int i = 0; i < blocks_per_row; i += blocks_per_warp) { - const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index - - const int iby = (i + threadIdx.x / (qi/vdr)) * (qk/QK8_1); // y block index that aligns with ibx - - const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int - - tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs); - } - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -template -static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { - // qk = quantized weights per x block - // qr = number of quantized weights per data value in x block - const int row = blockIdx.y*blockDim.y + threadIdx.y; - - if (row >= nrows) { - return; - } - - const int tid = threadIdx.x; - - const int iter_stride = 2*GGML_CUDA_DMMV_X; - const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter - const int y_offset = qr == 1 ? 1 : qk/2; - -// partial sum for each thread -#ifdef GGML_CUDA_F16 - half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics -#else - float tmp = 0.0f; -#endif // GGML_CUDA_F16 - - for (int i = 0; i < ncols; i += iter_stride) { - const int col = i + vals_per_iter*tid; - const int ib = (row*ncols + col)/qk; // x block index - const int iqs = (col%qk)/qr; // x quant index - const int iybs = col - col%qk; // y block start index - -// processing >2 values per i iter is faster for fast GPUs -#pragma unroll - for (int j = 0; j < vals_per_iter; j += 2) { - // process 2 vals per j iter - - // dequantize - // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val - dfloat2 v; - dequantize_kernel(vx, ib, iqs + j/qr, v); - - // matrix multiplication - // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 -#ifdef GGML_CUDA_F16 - tmp += __hmul2(v, { - y[iybs + iqs + j/qr + 0], - y[iybs + iqs + j/qr + y_offset] - }); -#else - tmp += v.x * y[iybs + iqs + j/qr + 0]; - tmp += v.y * y[iybs + iqs + j/qr + y_offset]; -#endif // GGML_CUDA_F16 - } - } - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (tid == 0) { -#ifdef GGML_CUDA_F16 - dst[row] = tmp.x + tmp.y; -#else - dst[row] = tmp; -#endif // GGML_CUDA_F16 - } -} - -static __global__ void mul_mat_p021_f16_f32( - const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) { - - const half * x = (const half *) vx; - - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; - const int channel_x = channel / (nchannels_y / nchannels_x); - - const int nrows_y = ncols_x; - const int nrows_dst = nrows_x; - const int row_dst = row_x; - - float tmp = 0.0f; - - for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { - const int col_x = col_x0 + threadIdx.x; - - if (col_x >= ncols_x) { - break; - } - - // x is transposed and permuted - const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; - const float xi = __half2float(x[ix]); - - const int row_y = col_x; - - - // y is not transposed but permuted - const int iy = channel*nrows_y + row_y; - - tmp += xi * y[iy]; - } - - // dst is not transposed and not permuted - const int idst = channel*nrows_dst + row_dst; - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (threadIdx.x == 0) { - dst[idst] = tmp; - } -} - -static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous - const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) { - - const half * x = (const half *) vx; - - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; - const int channel_x = channel / channel_x_divisor; - - const int nrows_y = ncols_x; - const int nrows_dst = nrows_x; - const int row_dst = row_x; - - const int idst = channel*nrows_dst + row_dst; - - float tmp = 0.0f; - - for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { - const int col_x = col_x0 + threadIdx.x; - - if (col_x >= ncols_x) { - break; - } - - const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; - const float xi = __half2float(x[ix]); - - const int row_y = col_x; - - const int iy = channel*nrows_y + row_y; - - tmp += xi * y[iy]; - } - - // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - if (threadIdx.x == 0) { - dst[idst] = tmp; - } -} - -static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - float * dsti = (float *) cdsti; - - *dsti = *xi; -} - -static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { - const float * xi = (const float *) cxi; - half * dsti = (half *) cdsti; - - *dsti = __float2half(*xi); -} - -template -static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, - const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= ne) { - return; - } - - // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor - // then combine those indices with the corresponding byte offsets to get the total offsets - const int i02 = i / (ne00*ne01); - const int i01 = (i - i02*ne01*ne00) / ne00; - const int i00 = i - i02*ne01*ne00 - i01*ne00; - const int x_offset = i00*nb00 + i01*nb01 + i02*nb02; - - const int i12 = i / (ne10*ne11); - const int i11 = (i - i12*ne10*ne11) / ne10; - const int i10 = i - i12*ne10*ne11 - i11*ne10; - const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; - - cpy_1(cx + x_offset, cdst + dst_offset); -} - -// rope == RoPE == rotary positional embedding - -template -static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, - const int p_delta_rows, const float theta_scale) { - const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); - - if (col >= ncols) { - return; - } - - const int row = blockDim.x*blockIdx.x + threadIdx.x; - const int i = row*ncols + col; - const int i2 = row/p_delta_rows; - - const int p = has_pos ? pos[i2] : 0; - const float p0 = p*freq_scale; - const float theta = p0*powf(theta_scale, col/2); - const float sin_theta = sinf(theta); - const float cos_theta = cosf(theta); - - const float x0 = x[i + 0]; - const float x1 = x[i + 1]; - - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + 1] = x0*sin_theta + x1*cos_theta; -} - -template -static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale, - const int p_delta_rows, const float theta_scale) { - const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); - - if (col >= ncols) { - return; - } - - const int row = blockDim.x*blockIdx.x + threadIdx.x; - const int i = row*ncols + col/2; - const int i2 = row/p_delta_rows; - - const int p = has_pos ? pos[i2] : 0; - const float p0 = p*freq_scale; - const float theta = p0*powf(theta_scale, col/2); - const float sin_theta = sinf(theta); - const float cos_theta = cosf(theta); - - const float x0 = x[i + 0]; - const float x1 = x[i + ncols/2]; - - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; -} - -static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale, - const int p_delta_rows, const float theta_scale, const int n_ctx) { - const int col = blockDim.x*blockIdx.x + threadIdx.x; - const int half_n_dims = ncols/4; - - if (col >= half_n_dims) { - return; - } - - const int row = blockDim.y*blockIdx.y + threadIdx.y; - const int i = row*ncols + col; - const int i2 = row/p_delta_rows; - - const float col_theta_scale = powf(theta_scale, col); - // FIXME: this is likely wrong - const int p = pos != nullptr ? pos[i2] : 0; - - const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale; - const float sin_theta = sinf(theta); - const float cos_theta = cosf(theta); - - const float x0 = x[i + 0]; - const float x1 = x[i + half_n_dims]; - - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; - - const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale; - const float sin_block_theta = sinf(block_theta); - const float cos_block_theta = cosf(block_theta); - - const float x2 = x[i + half_n_dims * 2]; - const float x3 = x[i + half_n_dims * 3]; - - dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta; - dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; -} - -static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, - const int n_heads_log2_floor, const float m0, const float m1) { - const int col = blockDim.x*blockIdx.x + threadIdx.x; - - if (col >= ncols) { - return; - } - - const int row = blockDim.y*blockIdx.y + threadIdx.y; - const int i = row*ncols + col; - - const int k = row/k_rows; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - dst[i] = col * m_k + x[i]; -} - -static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) { - const int col = blockDim.y*blockIdx.y + threadIdx.y; - const int row = blockDim.x*blockIdx.x + threadIdx.x; - - if (col >= ncols) { - return; - } - - const int i = row*ncols + col; - // dst[i] = col > n_past + row ? -INFINITY : x[i]; - dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU -} - -// the CUDA soft max implementation differs from the CPU implementation -// instead of doubles floats are used -static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) { - const int row = blockDim.x*blockIdx.x + threadIdx.x; - const int block_size = blockDim.y; - const int tid = threadIdx.y; - - float max_val = -INFINITY; - - for (int col = tid; col < ncols; col += block_size) { - const int i = row*ncols + col; - max_val = max(max_val, x[i]); - } - - // find the max value in the block -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32)); - } - - float tmp = 0.f; - - for (int col = tid; col < ncols; col += block_size) { - const int i = row*ncols + col; - const float val = expf(x[i] - max_val); - tmp += val; - dst[i] = val; - } - - // sum up partial sums -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } - - const float inv_tmp = 1.f / tmp; - - for (int col = tid; col < ncols; col += block_size) { - const int i = row*ncols + col; - dst[i] *= inv_tmp; - } -} - -static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - - dst[i] = scale * x[i]; -} - -static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - - dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); -} - -template -static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) { - const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); - const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); - const dim3 block_nums(block_num_x, nrows, 1); - k_get_rows<<>>(x, y, dst, ncols); -} - -static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { - const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; - add_f32<<>>(x, y, dst, kx, ky); -} - -static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; - add_f16_f32_f16<<>>(x, y, dst, k); -} - -static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { - const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; - mul_f32<<>>(x, y, dst, kx, ky); -} - -static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; - gelu_f32<<>>(x, dst, k); -} - -static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; - silu_f32<<>>(x, dst, k); -} - -static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); - if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - norm_f32<<>>(x, dst, ncols); - } else { - const dim3 block_dims(1024, 1, 1); - norm_f32<1024><<>>(x, dst, ncols); - } -} - -static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); - if (ncols < 1024) { - const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, eps); - } else { - const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024><<>>(x, dst, ncols, eps); - } -} - -static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) { - const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - const dim3 num_blocks(block_num_x, ky, 1); - const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, kx, kx_padded); -} - -template -static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int nb = k / QK_K; -#if QK_K == 256 - dequantize_block_q2_K<<>>(vx, y); -#else - dequantize_block_q2_K<<>>(vx, y); -#endif -} - -template -static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int nb = k / QK_K; -#if QK_K == 256 - dequantize_block_q3_K<<>>(vx, y); -#else - dequantize_block_q3_K<<>>(vx, y); -#endif -} - -template -static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int nb = k / QK_K; - dequantize_block_q4_K<<>>(vx, y); -} - -template -static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int nb = k / QK_K; -#if QK_K == 256 - dequantize_block_q5_K<<>>(vx, y); -#else - dequantize_block_q5_K<<>>(vx, y); -#endif -} - -template -static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int nb = k / QK_K; -#if QK_K == 256 - dequantize_block_q6_K<<>>(vx, y); -#else - dequantize_block_q6_K<<>>(vx, y); -#endif -} - -static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 1, 1); - dequantize_mul_mat_vec_q5_k<<>>(vx, y, dst, ncols); -} - -static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); -} - -static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK4_0 == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK4_1 == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK5_0 == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK5_1 == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK8_0 == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); -} - -static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<1, 1, convert_f16><<>>(vx, y, k); -} - -static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - dequantize_block<1, 1, convert_f32><<>>(vx, y, k); -} - -static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec<1, 1, convert_f16> - <<>>(vx, y, dst, ncols, nrows); -} - -static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { - switch (type) { - case GGML_TYPE_Q4_0: - return dequantize_row_q4_0_cuda; - case GGML_TYPE_Q4_1: - return dequantize_row_q4_1_cuda; - case GGML_TYPE_Q5_0: - return dequantize_row_q5_0_cuda; - case GGML_TYPE_Q5_1: - return dequantize_row_q5_1_cuda; - case GGML_TYPE_Q8_0: - return dequantize_row_q8_0_cuda; - case GGML_TYPE_Q2_K: - return dequantize_row_q2_K_cuda; - case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_cuda; - case GGML_TYPE_Q4_K: - return dequantize_row_q4_K_cuda; - case GGML_TYPE_Q5_K: - return dequantize_row_q5_K_cuda; - case GGML_TYPE_Q6_K: - return dequantize_row_q6_K_cuda; - case GGML_TYPE_F32: - return convert_fp32_to_fp16_cuda; - default: - return nullptr; - } -} - -static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { - switch (type) { - case GGML_TYPE_Q4_0: - return dequantize_row_q4_0_cuda; - case GGML_TYPE_Q4_1: - return dequantize_row_q4_1_cuda; - case GGML_TYPE_Q5_0: - return dequantize_row_q5_0_cuda; - case GGML_TYPE_Q5_1: - return dequantize_row_q5_1_cuda; - case GGML_TYPE_Q8_0: - return dequantize_row_q8_0_cuda; - case GGML_TYPE_Q2_K: - return dequantize_row_q2_K_cuda; - case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_cuda; - case GGML_TYPE_Q4_K: - return dequantize_row_q4_K_cuda; - case GGML_TYPE_Q5_K: - return dequantize_row_q5_K_cuda; - case GGML_TYPE_Q6_K: - return dequantize_row_q6_K_cuda; - case GGML_TYPE_F16: - return convert_fp16_to_fp32_cuda; - default: - return nullptr; - } -} - -static void ggml_mul_mat_q4_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - const int compute_capability = g_compute_capabilities[id]; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q4_0_RDNA2; - mmq_y = MMQ_Y_Q4_0_RDNA2; - nwarps = NWARPS_Q4_0_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q4_0_RDNA1; - mmq_y = MMQ_Y_Q4_0_RDNA1; - nwarps = NWARPS_Q4_0_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q4_0_AMPERE; - mmq_y = MMQ_Y_Q4_0_AMPERE; - nwarps = NWARPS_Q4_0_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q4_0_PASCAL; - mmq_y = MMQ_Y_Q4_0_PASCAL; - nwarps = NWARPS_Q4_0_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q4_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q4_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q4_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - const int compute_capability = g_compute_capabilities[id]; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q4_1_RDNA2; - mmq_y = MMQ_Y_Q4_1_RDNA2; - nwarps = NWARPS_Q4_1_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q4_1_RDNA1; - mmq_y = MMQ_Y_Q4_1_RDNA1; - nwarps = NWARPS_Q4_1_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q4_1_AMPERE; - mmq_y = MMQ_Y_Q4_1_AMPERE; - nwarps = NWARPS_Q4_1_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q4_1_PASCAL; - mmq_y = MMQ_Y_Q4_1_PASCAL; - nwarps = NWARPS_Q4_1_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q4_1<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q4_1<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q5_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - const int compute_capability = g_compute_capabilities[id]; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q5_0_RDNA2; - mmq_y = MMQ_Y_Q5_0_RDNA2; - nwarps = NWARPS_Q5_0_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q5_0_RDNA1; - mmq_y = MMQ_Y_Q5_0_RDNA1; - nwarps = NWARPS_Q5_0_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q5_0_AMPERE; - mmq_y = MMQ_Y_Q5_0_AMPERE; - nwarps = NWARPS_Q5_0_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q5_0_PASCAL; - mmq_y = MMQ_Y_Q5_0_PASCAL; - nwarps = NWARPS_Q5_0_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q5_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q5_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q5_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - const int compute_capability = g_compute_capabilities[id]; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q5_1_RDNA2; - mmq_y = MMQ_Y_Q5_1_RDNA2; - nwarps = NWARPS_Q5_1_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q5_1_RDNA1; - mmq_y = MMQ_Y_Q5_1_RDNA1; - nwarps = NWARPS_Q5_1_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q5_1_AMPERE; - mmq_y = MMQ_Y_Q5_1_AMPERE; - nwarps = NWARPS_Q5_1_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q5_1_PASCAL; - mmq_y = MMQ_Y_Q5_1_PASCAL; - nwarps = NWARPS_Q5_1_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q5_1<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q5_1<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q8_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - const int compute_capability = g_compute_capabilities[id]; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q8_0_RDNA2; - mmq_y = MMQ_Y_Q8_0_RDNA2; - nwarps = NWARPS_Q8_0_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q8_0_RDNA1; - mmq_y = MMQ_Y_Q8_0_RDNA1; - nwarps = NWARPS_Q8_0_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q8_0_AMPERE; - mmq_y = MMQ_Y_Q8_0_AMPERE; - nwarps = NWARPS_Q8_0_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q8_0_PASCAL; - mmq_y = MMQ_Y_Q8_0_PASCAL; - nwarps = NWARPS_Q8_0_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q8_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q8_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q2_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - const int compute_capability = g_compute_capabilities[id]; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q2_K_RDNA2; - mmq_y = MMQ_Y_Q2_K_RDNA2; - nwarps = NWARPS_Q2_K_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q2_K_RDNA1; - mmq_y = MMQ_Y_Q2_K_RDNA1; - nwarps = NWARPS_Q2_K_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q2_K_AMPERE; - mmq_y = MMQ_Y_Q2_K_AMPERE; - nwarps = NWARPS_Q2_K_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q2_K_PASCAL; - mmq_y = MMQ_Y_Q2_K_PASCAL; - nwarps = NWARPS_Q2_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q2_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q2_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q3_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - -#if QK_K == 256 - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - const int compute_capability = g_compute_capabilities[id]; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q3_K_RDNA2; - mmq_y = MMQ_Y_Q3_K_RDNA2; - nwarps = NWARPS_Q3_K_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q3_K_RDNA1; - mmq_y = MMQ_Y_Q3_K_RDNA1; - nwarps = NWARPS_Q3_K_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q3_K_AMPERE; - mmq_y = MMQ_Y_Q3_K_AMPERE; - nwarps = NWARPS_Q3_K_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q3_K_PASCAL; - mmq_y = MMQ_Y_Q3_K_PASCAL; - nwarps = NWARPS_Q3_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q3_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q3_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -#endif -} - -static void ggml_mul_mat_q4_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - const int compute_capability = g_compute_capabilities[id]; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q4_K_RDNA2; - mmq_y = MMQ_Y_Q4_K_RDNA2; - nwarps = NWARPS_Q4_K_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q4_K_RDNA1; - mmq_y = MMQ_Y_Q4_K_RDNA1; - nwarps = NWARPS_Q4_K_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q4_K_AMPERE; - mmq_y = MMQ_Y_Q4_K_AMPERE; - nwarps = NWARPS_Q4_K_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q4_K_PASCAL; - mmq_y = MMQ_Y_Q4_K_PASCAL; - nwarps = NWARPS_Q4_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q4_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q4_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q5_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - const int compute_capability = g_compute_capabilities[id]; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q5_K_RDNA2; - mmq_y = MMQ_Y_Q5_K_RDNA2; - nwarps = NWARPS_Q5_K_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q5_K_RDNA1; - mmq_y = MMQ_Y_Q5_K_RDNA1; - nwarps = NWARPS_Q5_K_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q5_K_AMPERE; - mmq_y = MMQ_Y_Q5_K_AMPERE; - nwarps = NWARPS_Q5_K_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q5_K_PASCAL; - mmq_y = MMQ_Y_Q5_K_PASCAL; - nwarps = NWARPS_Q5_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q5_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q5_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q6_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - const int compute_capability = g_compute_capabilities[id]; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q6_K_RDNA2; - mmq_y = MMQ_Y_Q6_K_RDNA2; - nwarps = NWARPS_Q6_K_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q6_K_RDNA1; - mmq_y = MMQ_Y_Q6_K_RDNA1; - nwarps = NWARPS_Q6_K_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q6_K_AMPERE; - mmq_y = MMQ_Y_Q6_K_AMPERE; - nwarps = NWARPS_Q6_K_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q6_K_PASCAL; - mmq_y = MMQ_Y_Q6_K_PASCAL; - nwarps = NWARPS_Q6_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q6_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q6_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_p021_f16_f32_cuda( - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, - const int nchannels_x, const int nchannels_y, cudaStream_t stream) { - - const dim3 block_nums(1, nrows_x, nchannels_y); - const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y); -} - -static void ggml_mul_mat_vec_nc_f16_f32_cuda( - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, - const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) { - - const dim3 block_nums(1, nrows_x, nchannels_y); - const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_vec_nc_f16_f32<<>> - (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x); -} - -static void ggml_cpy_f32_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, - const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { - - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); -} - -static void ggml_cpy_f32_f16_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, - const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { - - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); -} - -static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, k); -} - -static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE; - clamp_f32<<>>(x, dst, min, max, k); -} - -template -static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, - const int p_delta_rows, const float theta_scale, cudaStream_t stream) { - GGML_ASSERT(ncols % 2 == 0); - const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); - const dim3 block_nums(nrows, num_blocks_x, 1); - if (pos == nullptr) { - rope<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); - } else { - rope<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); - } -} - -template -static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, - const int p_delta_rows, const float theta_scale, cudaStream_t stream) { - GGML_ASSERT(ncols % 2 == 0); - const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); - const dim3 block_nums(nrows, num_blocks_x, 1); - if (pos == nullptr) { - rope_neox<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); - } else { - rope_neox<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale); - } -} - -static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, - const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) { - GGML_ASSERT(ncols % 4 == 0); - const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); - const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE; - const dim3 block_nums(num_blocks_x, nrows, 1); - rope_glm_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx); -} - -static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, - const int k_rows, const int n_heads_log2_floor, const float m0, - const float m1, cudaStream_t stream) { - const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1); - const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE); - const dim3 block_nums(num_blocks_x, nrows, 1); - alibi_f32<<>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1); -} - -static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { - const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1); - const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE; - const dim3 block_nums(nrows_x, block_num_x, 1); - diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); -} - -static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) { - const dim3 block_dims(1, WARP_SIZE, 1); - const dim3 block_nums(nrows_x, 1, 1); - soft_max_f32<<>>(x, dst, ncols_x); -} - -// buffer pool for cuda -#define MAX_CUDA_BUFFERS 256 - -struct scoped_spin_lock { - std::atomic_flag& lock; - scoped_spin_lock(std::atomic_flag& lock) : lock(lock) { - while (lock.test_and_set(std::memory_order_acquire)) { - ; // spin - } - } - ~scoped_spin_lock() { - lock.clear(std::memory_order_release); - } - scoped_spin_lock(const scoped_spin_lock&) = delete; - scoped_spin_lock& operator=(const scoped_spin_lock&) = delete; -}; - -struct cuda_buffer { - void * ptr = nullptr; - size_t size = 0; -}; - -static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS]; -static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; - -static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { - scoped_spin_lock lock(g_cuda_pool_lock); - int id; - CUDA_CHECK(cudaGetDevice(&id)); -#ifdef DEBUG_CUDA_MALLOC - int nnz = 0; - size_t max_size = 0, tot_size = 0; -#endif - size_t best_diff = 1ull << 36; - int ibest = -1; - for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { - cuda_buffer& b = g_cuda_buffer_pool[id][i]; - if (b.ptr != nullptr) { -#ifdef DEBUG_CUDA_MALLOC - ++nnz; - tot_size += b.size; - if (b.size > max_size) max_size = b.size; -#endif - if (b.size >= size) { - size_t diff = b.size - size; - if (diff < best_diff) { - best_diff = diff; - ibest = i; - if (!best_diff) { - void * ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - return ptr; - } - } - } - } - } - if (ibest >= 0) { - cuda_buffer& b = g_cuda_buffer_pool[id][ibest]; - void * ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - return ptr; - } -#ifdef DEBUG_CUDA_MALLOC - fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz, - (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024)); -#endif - void * ptr; - size_t look_ahead_size = (size_t) (1.05 * size); - look_ahead_size = 256 * ((look_ahead_size + 255)/256); - CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size)); - *actual_size = look_ahead_size; - return ptr; -} - -static void ggml_cuda_pool_free(void * ptr, size_t size) { - scoped_spin_lock lock(g_cuda_pool_lock); - int id; - CUDA_CHECK(cudaGetDevice(&id)); - - for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { - cuda_buffer& b = g_cuda_buffer_pool[id][i]; - if (b.ptr == nullptr) { - b.ptr = ptr; - b.size = size; - return; - } - } - fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n"); - CUDA_CHECK(cudaFree(ptr)); -} - - -void ggml_init_cublas() { - static bool initialized = false; - - if (!initialized) { - -#ifdef __HIP_PLATFORM_AMD__ - // Workaround for a rocBLAS bug when using multiple graphics cards: - // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346 - rocblas_initialize(); - CUDA_CHECK(cudaDeviceSynchronize()); -#endif - - CUDA_CHECK(cudaGetDeviceCount(&g_device_count)); - GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); - int64_t total_vram = 0; - fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count); - for (int64_t id = 0; id < g_device_count; ++id) { - cudaDeviceProp prop; - CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); - fprintf(stderr, " Device %ld: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor); - - g_tensor_split[id] = total_vram; - total_vram += prop.totalGlobalMem; -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - g_compute_capabilities[id] = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD; -#else - g_compute_capabilities[id] = 100*prop.major + 10*prop.minor; -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - } - for (int64_t id = 0; id < g_device_count; ++id) { - g_tensor_split[id] /= total_vram; - } - - for (int64_t id = 0; id < g_device_count; ++id) { - CUDA_CHECK(ggml_cuda_set_device(id)); - - // create cuda streams - for (int64_t is = 0; is < MAX_STREAMS; ++is) { - CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[id][is], cudaStreamNonBlocking)); - } - - // create cublas handle - CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id])); - CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH)); - } - - // configure logging to stdout - // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); - - initialized = true; - } -} - -void ggml_cuda_set_tensor_split(const float * tensor_split) { - if (tensor_split == nullptr) { - return; - } - bool all_zero = true; - for (int i = 0; i < g_device_count; ++i) { - if (tensor_split[i] != 0.0f) { - all_zero = false; - break; - } - } - if (all_zero) { - return; - } - float split_sum = 0.0f; - for (int i = 0; i < g_device_count; ++i) { - g_tensor_split[i] = split_sum; - split_sum += tensor_split[i]; - } - for (int i = 0; i < g_device_count; ++i) { - g_tensor_split[i] /= split_sum; - } -} - -void * ggml_cuda_host_malloc(size_t size) { - if (getenv("GGML_CUDA_NO_PINNED") != nullptr) { - return nullptr; - } - - void * ptr = nullptr; - cudaError_t err = cudaMallocHost((void **) &ptr, size); - if (err != cudaSuccess) { - // The allocation error can be bypassed. A null ptr will assigned out of this function. - // This can fixed the OOM error in WSL. - cudaGetLastError(); - fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", - size/1024.0/1024.0, cudaGetErrorString(err)); - return nullptr; - } - - return ptr; -} - -void ggml_cuda_host_free(void * ptr) { - CUDA_CHECK(cudaFreeHost(ptr)); -} - -static cudaError_t ggml_cuda_cpy_tensor_2d( - void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { - - cudaMemcpyKind kind; - char * src_ptr; - if (src->backend == GGML_BACKEND_CPU) { - kind = cudaMemcpyHostToDevice; - src_ptr = (char *) src->data; - } else if (src->backend == GGML_BACKEND_GPU || src->backend == GGML_BACKEND_GPU_SPLIT) { - GGML_ASSERT(src->backend != GGML_BACKEND_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1])); - kind = cudaMemcpyDeviceToDevice; - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; - int id; - CUDA_CHECK(cudaGetDevice(&id)); - src_ptr = (char *) extra->data_device[id]; - } else { - GGML_ASSERT(false); - } - char * dst_ptr = (char *) dst; - - const int64_t ne0 = src->ne[0]; - const int64_t nb0 = src->nb[0]; - const int64_t nb1 = src->nb[1]; - const int64_t nb2 = src->nb[2]; - const int64_t nb3 = src->nb[3]; - const enum ggml_type type = src->type; - const int64_t ts = ggml_type_size(type); - const int64_t bs = ggml_blck_size(type); - int64_t i1_diff = i1_high - i1_low; - - const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; - if (nb0 == ts && nb1 == ts*ne0/bs) { - return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream); - } else if (nb0 == ts) { - return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream); - } else { - for (int64_t i1 = 0; i1 < i1_diff; i1++) { - const void * rx = (const void *) ((const char *) x + i1*nb1); - void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); - // pretend the row is a matrix with cols=1 - cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream); - if (r != cudaSuccess) return r; - } - return cudaSuccess; - } -} - -static void ggml_cuda_op_repeat( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) { - // guaranteed to be an integer due to the check in ggml_can_repeat - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const int nr0 = (int)(ne0/ne00); - const int nr1 = (int)(ne1/ne01); - const int nr2 = (int)(ne2/ne02); - const int nr3 = (int)(ne3/ne03); - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - // TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne03; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne02; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne01; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - CUDA_CHECK(cudaMemcpyAsync( - (char *) dst_d + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0, - (const char *) src0_d + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01, - ne00*nb0, cudaMemcpyDeviceToDevice, stream)); - } - } - } - } - } - } - } - - (void) src1; - (void) src1_d; -} - -static void ggml_cuda_op_get_rows( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) { - - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(dst)); - - const int ncols = src0->ne[0]; - const int nrows = ggml_nelements(src1); - - const int32_t * src1_i32 = (const int32_t *) src1_d; - - switch (src0->type) { - case GGML_TYPE_F16: - get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream); - break; - case GGML_TYPE_F32: - get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream); - break; - case GGML_TYPE_Q4_0: - get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); - break; - case GGML_TYPE_Q4_1: - get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); - break; - case GGML_TYPE_Q5_0: - get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); - break; - case GGML_TYPE_Q5_1: - get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); - break; - case GGML_TYPE_Q8_0: - get_rows_cuda(src0_d, src1_i32, dst_d, nrows, ncols, stream); - break; - default: - // TODO: k-quants - GGML_ASSERT(false); - break; - } -} - -inline void ggml_cuda_op_add( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream); - } else { - GGML_ASSERT(false); - } - - (void) src1; - (void) dst; -} - -inline void ggml_cuda_op_mul( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - - mul_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream); - - (void) dst; -} - -inline void ggml_cuda_op_gelu( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - gelu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_cuda_op_silu( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - silu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_cuda_op_norm( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); - - norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_cuda_op_rms_norm( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - rms_norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_cuda_op_mul_mat_q( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, const cudaStream_t & stream) { - - const int64_t ne00 = src0->ne[0]; - - const int64_t ne10 = src1->ne[0]; - GGML_ASSERT(ne10 % QK8_1 == 0); - - const int64_t ne0 = dst->ne[0]; - - const int64_t row_diff = row_high - row_low; - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - - // the main device has a larger memory buffer to hold the results from all GPUs - // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into - const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; - - switch (src0->type) { - case GGML_TYPE_Q4_0: - ggml_mul_mat_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q4_1: - ggml_mul_mat_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q5_0: - ggml_mul_mat_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q5_1: - ggml_mul_mat_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q8_0: - ggml_mul_mat_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q2_K: - ggml_mul_mat_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q3_K: - ggml_mul_mat_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q4_K: - ggml_mul_mat_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q5_K: - ggml_mul_mat_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q6_K: - ggml_mul_mat_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - default: - GGML_ASSERT(false); - break; - } - - (void) src1; - (void) dst; - (void) src1_ddf_i; -} - -static int64_t get_row_rounding(ggml_type type) { - int64_t min_compute_capability = INT_MAX; - int64_t max_compute_capability = INT_MIN; - for (int64_t id = 0; id < g_device_count; ++id) { - if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { - if (min_compute_capability > g_compute_capabilities[id]) { - min_compute_capability = g_compute_capabilities[id]; - } - if (max_compute_capability < g_compute_capabilities[id]) { - max_compute_capability = g_compute_capabilities[id]; - } - } - } - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - switch(type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - return max_compute_capability >= CC_RDNA2 ? 128 : 64; - case GGML_TYPE_F16: - return 1; - case GGML_TYPE_Q2_K: - return max_compute_capability >= CC_RDNA2 ? 128 : 32; - case GGML_TYPE_Q3_K: - return min_compute_capability < CC_RDNA2 ? 128 : 64; - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - return max_compute_capability >= CC_RDNA2 ? 128 : 64; - default: - GGML_ASSERT(false); - } -#else - switch(type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - return max_compute_capability >= CC_VOLTA ? 128 : 64; - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - return 64; - case GGML_TYPE_F16: - return 1; - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - return max_compute_capability >= CC_VOLTA ? 128 : 64; - case GGML_TYPE_Q6_K: - return 64; - default: - GGML_ASSERT(false); - } -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -} - -inline void ggml_cuda_op_mul_mat_vec_q( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, const cudaStream_t & stream) { - - const int64_t ne00 = src0->ne[0]; - const int64_t row_diff = row_high - row_low; - - switch (src0->type) { - case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - default: - GGML_ASSERT(false); - break; - } - - (void) src1; - (void) dst; - (void) src1_ddf_i; - (void) src1_ncols; - (void) src1_padded_row_size; -} - -inline void ggml_cuda_op_dequantize_mul_mat_vec( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, const cudaStream_t & stream) { - - const int64_t ne00 = src0->ne[0]; - const int64_t row_diff = row_high - row_low; - - // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics -#ifdef GGML_CUDA_F16 - size_t ash; - dfloat * src1_dfloat = nullptr; // dfloat == half - - bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || - src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; - - if (src1_convert_f16) { - src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash); - ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00, - ne00, 1, sizeof(float), 0, 0, - ne00, 1, sizeof(half), 0, 0, stream); - } -#else - const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion -#endif // GGML_CUDA_F16 - - switch (src0->type) { - case GGML_TYPE_Q4_0: - dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_1: - dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_0: - dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_1: - dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q8_0: - dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q2_K: - dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q3_K: - dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_K: - dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_K: - dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q6_K: - dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_F16: - convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - default: - GGML_ASSERT(false); - break; - } - -#ifdef GGML_CUDA_F16 - if (src1_convert_f16) { - ggml_cuda_pool_free(src1_dfloat, ash); - } -#endif // GGML_CUDA_F16 - - (void) src1; - (void) dst; - (void) src1_ddq_i; - (void) src1_ncols; - (void) src1_padded_row_size; -} - -inline void ggml_cuda_op_mul_mat_cublas( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, const cudaStream_t & stream) { - - GGML_ASSERT(src0_dd_i != nullptr); - GGML_ASSERT(src1_ddf_i != nullptr); - GGML_ASSERT(dst_dd_i != nullptr); - - - const int64_t ne00 = src0->ne[0]; - - const int64_t ne10 = src1->ne[0]; - - const int64_t ne0 = dst->ne[0]; - const int64_t row_diff = row_high - row_low; - - int id; - CUDA_CHECK(cudaGetDevice(&id)); - - // the main device has a larger memory buffer to hold the results from all GPUs - // ldc == nrows of the matrix that cuBLAS writes into - int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; - - const int compute_capability = g_compute_capabilities[id]; - - if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { - // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 - half * src0_as_f16 = nullptr; - size_t src0_as = 0; - if (src0->type != GGML_TYPE_F16) { - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); - GGML_ASSERT(to_fp16_cuda != nullptr); - size_t ne = row_diff*ne00; - src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); - to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream); - } - const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16; - - half * src1_as_f16 = nullptr; - size_t src1_as = 0; - if (src1->type != GGML_TYPE_F16) { - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); - GGML_ASSERT(to_fp16_cuda != nullptr); - size_t ne = src1_ncols*ne10; - src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); - to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); - } - const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16; - - size_t dst_as = 0; - half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); - - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; - - CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); - CUBLAS_CHECK( - cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, - row_diff, src1_ncols, ne10, - &alpha_f16, src0_ptr, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, - &beta_f16, dst_f16, CUDA_R_16F, ldc, - CUBLAS_COMPUTE_16F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); - - ggml_cuda_pool_free(dst_f16, dst_as); - - if (src0_as != 0) { - ggml_cuda_pool_free(src0_as_f16, src0_as); - } - - if (src1_as != 0) { - ggml_cuda_pool_free(src1_as_f16, src1_as); - } - } - else { - float * src0_ddq_as_f32 = nullptr; - size_t src0_as = 0; - - if (src0->type != GGML_TYPE_F32) { - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); - GGML_ASSERT(to_fp32_cuda != nullptr); - src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT - to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); - } - const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32; - - const float alpha = 1.0f; - const float beta = 0.0f; - - CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); - CUBLAS_CHECK( - cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, - row_diff, src1_ncols, ne10, - &alpha, src0_ddf_i, ne00, - src1_ddf_i, ne10, - &beta, dst_dd_i, ldc)); - - if (src0_as != 0) { - ggml_cuda_pool_free(src0_ddq_as_f32, src0_as); - } - } - - (void) dst; - (void) src1_ddq_i; - (void) src1_padded_row_size; -} - -inline void ggml_cuda_op_rope( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); - GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - GGML_ASSERT(src0->type == dst->type); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t nrows = ggml_nrows(src0); - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - // RoPE alteration for extended context - - float freq_base, freq_scale; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - - const float theta_scale = powf(freq_base, -2.0f/n_dims); - - const int32_t * pos = nullptr; - if ((mode & 1) == 0) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(src1->ne[0] == ne2); - pos = (const int32_t *) src1_dd; - } - - const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - - // compute - if (is_glm) { - GGML_ASSERT(false); - rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream); - } else if (is_neox) { - GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); - if (src0->type == GGML_TYPE_F32) { - rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); - } else if (src0->type == GGML_TYPE_F16) { - rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); - } else { - GGML_ASSERT(false); - } - } else { - if (src0->type == GGML_TYPE_F32) { - rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); - } else if (src0->type == GGML_TYPE_F16) { - rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream); - } else { - GGML_ASSERT(false); - } - } - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_cuda_op_alibi( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t nrows = ggml_nrows(src0); - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - //GGML_ASSERT(ne01 + n_past == ne00); - GGML_ASSERT(n_head == ne02); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - alibi_f32_cuda(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream); - - (void) src1; - (void) src1_dd; -} - -inline void ggml_cuda_op_diag_mask_inf( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int nrows0 = ggml_nrows(src0); - - const int n_past = ((int32_t *) dst->op_params)[0]; - - diag_mask_inf_f32_cuda(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_cuda_op_soft_max( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); - - soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_cuda_op_scale( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - float scale; - // HACK: support for ggml backend interface - if (src1->backend == GGML_BACKEND_CPU) { - scale = ((float *) src1->data)[0]; - } else { - // TODO: pass pointer to kernel instead of copying to host - CUDA_CHECK(cudaMemcpy(&scale, src1->data, sizeof(float), cudaMemcpyDeviceToHost)); - } - - scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream); - CUDA_CHECK(cudaGetLastError()); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_cuda_op_clamp( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const float min = ((float *) dst->op_params)[0]; - const float max = ((float *) dst->op_params)[1]; - - clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream); - CUDA_CHECK(cudaGetLastError()); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) { - const int64_t nrows0 = ggml_nrows(src0); - - const bool use_src1 = src1 != nullptr; - const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; - - GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); - GGML_ASSERT( dst->backend != GGML_BACKEND_GPU_SPLIT); - - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - - const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; - const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU; - const bool dst_on_device = dst->backend == GGML_BACKEND_GPU; - - const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE; - - // dd = data device - float * src0_ddf = nullptr; - float * src1_ddf = nullptr; - float * dst_ddf = nullptr; - - // as = actual size - size_t src0_asf = 0; - size_t src1_asf = 0; - size_t dst_asf = 0; - - ggml_cuda_set_device(g_main_device); - const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; - - if (src0_on_device) { - src0_ddf = (float *) src0_extra->data_device[g_main_device]; - } else { - src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf); - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream)); - } - - if (use_src1 && !src1_stays_on_host) { - if (src1_on_device) { - src1_ddf = (float *) src1_extra->data_device[g_main_device]; - } else { - src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf); - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream)); - } - } - if (dst_on_device) { - dst_ddf = (float *) dst_extra->data_device[g_main_device]; - } else { - dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf); - } - - // do the computation - op(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); - CUDA_CHECK(cudaGetLastError()); - - // copy dst to host if necessary - if (!dst_on_device) { - CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream)); - } - - if (src0_asf > 0) { - ggml_cuda_pool_free(src0_ddf, src0_asf); - } - if (src1_asf > 0) { - ggml_cuda_pool_free(src1_ddf, src1_asf); - } - if (dst_asf > 0) { - ggml_cuda_pool_free(dst_ddf, dst_asf); - } - - if (dst->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(cudaDeviceSynchronize()); - } -} - -static void ggml_cuda_set_peer_access(const int n_tokens) { - static bool peer_access_enabled = false; - - const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE; - - if (peer_access_enabled == enable_peer_access) { - return; - } - -#ifdef NDEBUG - for (int id = 0; id < g_device_count; ++id) { - CUDA_CHECK(ggml_cuda_set_device(id)); - - for (int id_other = 0; id_other < g_device_count; ++id_other) { - if (id == id_other) { - continue; - } - if (id != g_main_device && id_other != g_main_device) { - continue; - } - - int can_access_peer; - CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); - if (can_access_peer) { - if (enable_peer_access) { - CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); - } else { - CUDA_CHECK(cudaDeviceDisablePeerAccess(id_other)); - } - } - } - } -#endif // NDEBUG - - peer_access_enabled = enable_peer_access; -} - -static void ggml_cuda_op_mul_mat( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, - const bool convert_src1_to_q8_1) { - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - const int64_t nrows0 = ggml_nrows(src0); - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - const int64_t nrows1 = ggml_nrows(src1); - - GGML_ASSERT(ne03 == ne13); - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - - const int nb2 = dst->nb[2]; - const int nb3 = dst->nb[3]; - - ggml_cuda_set_peer_access(ne11); - - GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT); - GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT); - - GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0); - - const int64_t i02_divisor = ne12 / ne02; - - const size_t src0_ts = ggml_type_size(src0->type); - const size_t src0_bs = ggml_blck_size(src0->type); - const size_t q8_1_ts = sizeof(block_q8_1); - const size_t q8_1_bs = QK8_1; - - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - - const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; - const bool src0_is_contiguous = ggml_is_contiguous(src0); - - const bool src1_is_contiguous = ggml_is_contiguous(src1); - const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ? - ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; - - const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; - GGML_ASSERT(!(split && ne02 > 1)); - GGML_ASSERT(!(split && ne03 > 1)); - GGML_ASSERT(!(split && ne02 < ne12)); - - // dd = data device - char * src0_dd[GGML_CUDA_MAX_DEVICES] = {nullptr}; - float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float - char * src1_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // q8_1 - float * dst_dd[GGML_CUDA_MAX_DEVICES] = {nullptr}; - - // as = actual size - size_t src0_as[GGML_CUDA_MAX_DEVICES] = {0}; - size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0}; - size_t src1_asq[GGML_CUDA_MAX_DEVICES] = {0}; - size_t dst_as[GGML_CUDA_MAX_DEVICES] = {0}; - - int64_t row_low[GGML_CUDA_MAX_DEVICES]; - int64_t row_high[GGML_CUDA_MAX_DEVICES]; - - for (int64_t id = 0; id < g_device_count; ++id) { - // by default, use all rows - row_low[id] = 0; - row_high[id] = ne01; - - // for multi GPU, get the row boundaries from tensor split - // and round to mul_mat_q tile sizes - if (split) { - const int64_t rounding = get_row_rounding(src0->type); - - if (id != 0) { - row_low[id] = ne01*g_tensor_split[id]; - row_low[id] -= row_low[id] % rounding; - } - - if (id != g_device_count - 1) { - row_high[id] = ne01*g_tensor_split[id + 1]; - row_high[id] -= row_high[id] % rounding; - } - } - } - - for (int64_t id = 0; id < g_device_count; ++id) { - if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { - continue; - } - - const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; - const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; - - ggml_cuda_set_device(id); - const cudaStream_t stream = g_cudaStreams[id][0]; - - if (src0_on_device && src0_is_contiguous) { - src0_dd[id] = (char *) src0_extra->data_device[id]; - } else { - const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0); - src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]); - } - - if (src1_on_device && src1_is_contiguous) { - src1_ddf[id] = (float *) src1_extra->data_device[id]; - } else { - src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]); - } - - if (convert_src1_to_q8_1) { - src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]); - - if (src1_on_device && src1_is_contiguous) { - quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream); - CUDA_CHECK(cudaGetLastError()); - } - } - - if (dst_on_device) { - dst_dd[id] = (float *) dst_extra->data_device[id]; - } else { - const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst); - dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]); - } - } - - // if multiple devices are used they need to wait for the main device - // here an event is recorded that signals that the main device has finished calculating the input data - if (split && g_device_count > 1) { - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); - CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device][0], g_cudaStreams[g_main_device][0])); - } - - const int64_t src1_col_stride = split && g_device_count > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; - for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { - const int64_t is = split ? (src1_col_0/src1_col_stride) % MAX_STREAMS : 0; - const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; - - for (int64_t id = 0; id < g_device_count; ++id) { - if ((!split && id != g_main_device) || row_low[id] == row_high[id]) { - continue; - } - - const bool src1_on_device = src1->backend == GGML_BACKEND_GPU && id == g_main_device; - const bool dst_on_device = dst->backend == GGML_BACKEND_GPU && id == g_main_device; - const int64_t row_diff = row_high[id] - row_low[id]; - - ggml_cuda_set_device(id); - const cudaStream_t stream = g_cudaStreams[id][is]; - - // wait for main GPU data if necessary - if (split && (id != g_main_device || is != 0)) { - CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0], 0)); - } - - for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) { - const int64_t i03 = i0 / ne12; - const int64_t i02 = i0 % ne12; - - const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; - - // for split tensors the data begins at i0 == i0_offset_low - char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs; - float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10; - char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset; - float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); - - // the main device memory buffer can be on VRAM scratch, with space for all partial results - // in that case an offset on dst_ddf_i is needed - if (dst->backend == GGML_BACKEND_GPU && id == g_main_device) { - dst_dd_i += row_low[id]; // offset is 0 if no tensor split - } - - // copy src0, src1 to device if necessary - if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) { - if (id != g_main_device) { - if (convert_src1_to_q8_1) { - char * src1_ddq_i_source = src1_ddq[g_main_device] + src1_ddq_i_offset; - CUDA_CHECK(cudaMemcpyAsync(src1_ddq_i, src1_ddq_i_source, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, - cudaMemcpyDeviceToDevice, stream)); - } else { - float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device]; - src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10; - CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_ncols*ne10*sizeof(float), - cudaMemcpyDeviceToDevice, stream)); - } - } - } else if (src1->backend == GGML_BACKEND_CPU || (src1_on_device && !src1_is_contiguous)) { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d( - src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); - } else { - GGML_ASSERT(false); - } - - if (convert_src1_to_q8_1 && (src1->backend == GGML_BACKEND_CPU || !src1_is_contiguous)) { - quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); - CUDA_CHECK(cudaGetLastError()); - } - - if (src1_col_0 == 0 && (!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, row_low[id], row_high[id], stream)); - } - - // do the computation - op(src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, - row_low[id], row_high[id], src1_ncols, src1_padded_col_size, stream); - CUDA_CHECK(cudaGetLastError()); - - // copy dst to host or other device if necessary - if (!dst_on_device) { - void * dst_off_device; - cudaMemcpyKind kind; - if (dst->backend == GGML_BACKEND_CPU) { - dst_off_device = dst->data; - kind = cudaMemcpyDeviceToHost; - } else if (dst->backend == GGML_BACKEND_GPU) { - dst_off_device = dst_extra->data_device[g_main_device]; - kind = cudaMemcpyDeviceToDevice; - } else { - GGML_ASSERT(false); - } - if (split) { - // src0 = weight matrix is saved as a transposed matrix for better memory layout. - // dst is NOT transposed. - // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU. - // Instead they need to be copied to the correct slice in ne0 = dst row index. - // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results. - float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); - GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); - dhf_dst_i += src1_col_0*ne0 + row_low[id]; - CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), dst_dd_i, row_diff*sizeof(float), - row_diff*sizeof(float), src1_ncols, kind, stream)); - } else { - float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); - GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); - dhf_dst_i += src1_col_0*ne0; - CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), kind, stream)); - } - } - - // add event for the main device to wait on until other device is done - if (split && (id != g_main_device || is != 0)) { - CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream)); - } - } - } - } - - for (int64_t id = 0; id < g_device_count; ++id) { - CUDA_CHECK(ggml_cuda_set_device(id)); - - // free buffers again when done - if (src0_as[id] > 0) { - ggml_cuda_pool_free(src0_dd[id], src0_as[id]); - } - if (src1_asf[id] > 0) { - ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]); - } - if (src1_asq[id] > 0) { - ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]); - } - if (dst_as[id] > 0) { - ggml_cuda_pool_free(dst_dd[id], dst_as[id]); - } - } - - // main device waits for all other devices to be finished - if (split && g_device_count > 1) { - int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; - is_max = is_max <= MAX_STREAMS ? is_max : MAX_STREAMS; - - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); - for (int64_t id = 0; id < g_device_count; ++id) { - for (int64_t is = 0; is < is_max; ++is) { - CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0)); - } - } - } - - if (dst->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); - CUDA_CHECK(cudaDeviceSynchronize()); - } -} - -static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_repeat); -} - -static void ggml_cuda_get_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_get_rows); -} - -static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add); -} - -static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul); -} - -static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu); -} - -static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu); -} - -static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm); -} - -static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm); -} - -bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - const int64_t ne10 = src1->ne[0]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - - // TODO: find the optimal values for these - return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && - src1->type == GGML_TYPE_F32 && - dst->type == GGML_TYPE_F32 && - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); -} - -static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ - GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); - GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); - GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation - GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - - const int64_t ne12 = src1->ne[2]; - - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); - cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; - - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - void * src0_ddq = src0_extra->data_device[g_main_device]; - - ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; - - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - - ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); -} - -static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ - GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)); - GGML_ASSERT(!ggml_is_permuted(src0)); - GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - - const int64_t ne12 = src1->ne[2]; - - const int64_t nb01 = src0->nb[1]; - const int64_t nb02 = src0->nb[2]; - - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); - cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; - - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - void * src0_ddq = src0_extra->data_device[g_main_device]; - - ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; - - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - - const int64_t row_stride_x = nb01 / sizeof(half); - const int64_t channel_stride_x = nb02 / sizeof(half); - - ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); -} - -static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && - src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; - - int64_t min_compute_capability = INT_MAX; - for (int64_t id = 0; id < g_device_count; ++id) { - if (min_compute_capability > g_compute_capabilities[id] - && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { - min_compute_capability = g_compute_capabilities[id]; - } - } - - if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { - ggml_cuda_mul_mat_vec_p021(src0, src1, dst); - } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) { - ggml_cuda_mul_mat_vec_nc(src0, src1, dst); - } else if (src0->type == GGML_TYPE_F32) { - ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); - } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { - if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { - -#ifdef GGML_CUDA_FORCE_DMMV - const bool use_mul_mat_vec_q = false; -#else - const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type); -#endif // GGML_CUDA_FORCE_DMMV - - if (use_mul_mat_vec_q) { - ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); - } else { - ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); - } - } else { - if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) { - ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true); - } else { - ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); - } - } - } else { - GGML_ASSERT(false); - } -} - -static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); -} - -static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp); -} - -static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - const int64_t ne = ggml_nelements(src0); - GGML_ASSERT(ne == ggml_nelements(src1)); - - GGML_ASSERT(src0->backend == GGML_BACKEND_GPU); - GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); - - GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); - GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - GGML_ASSERT(src0->ne[3] == 1); - - const int64_t nb00 = src0->nb[0]; - const int64_t nb01 = src0->nb[1]; - const int64_t nb02 = src0->nb[2]; - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - GGML_ASSERT(src1->ne[3] == 1); - - const int64_t nb10 = src1->nb[0]; - const int64_t nb11 = src1->nb[1]; - const int64_t nb12 = src1->nb[2]; - - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); - cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; - - const ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - const ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - - char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; - char * src1_ddc = (char *) src1_extra->data_device[g_main_device]; - - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, main_stream); - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, - ne10, ne11, nb10, nb11, nb12, main_stream); - } else { - fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, - ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ASSERT(false); - } - - (void) dst; -} - -static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_cpy(src0, dst, nullptr); - (void) src1; -} - -static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf); -} - -static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max); -} - -static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope); -} - -static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); -} - -static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - (void) src0; - (void) src1; - (void) dst; -} - -void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { - const int64_t nrows = ggml_nrows(tensor); - - const int64_t ne0 = tensor->ne[0]; - - const size_t nb1 = tensor->nb[1]; - - ggml_backend_type backend = tensor->backend; - ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu; - memset(extra, 0, sizeof(*extra)); - - for (int64_t id = 0; id < g_device_count; ++id) { - if (backend == GGML_BACKEND_GPU && id != g_main_device) { - continue; - } - - ggml_cuda_set_device(id); - - int64_t row_low, row_high; - if (backend == GGML_BACKEND_GPU) { - row_low = 0; - row_high = nrows; - } else if (backend == GGML_BACKEND_GPU_SPLIT) { - const int64_t rounding = get_row_rounding(tensor->type); - - row_low = id == 0 ? 0 : nrows*g_tensor_split[id]; - row_low -= row_low % rounding; - - if (id == g_device_count - 1) { - row_high = nrows; - } else { - row_high = nrows*g_tensor_split[id + 1]; - row_high -= row_high % rounding; - } - } else { - GGML_ASSERT(false); - } - if (row_low == row_high) { - continue; - } - - int64_t nrows_split = row_high - row_low; - - const size_t offset_split = row_low*nb1; - size_t size = ggml_nbytes_split(tensor, nrows_split); - const size_t original_size = size; - - // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses - if (ne0 % MATRIX_ROW_PADDING != 0) { - size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING) - * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type); - } - - char * buf; - CUDA_CHECK(cudaMalloc(&buf, size)); - char * buf_host = (char*)data + offset_split; - - // set padding to 0 to avoid possible NaN values - if (size > original_size) { - CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size)); - } - - CUDA_CHECK(cudaMemcpy(buf, buf_host, original_size, cudaMemcpyHostToDevice)); - - extra->data_device[id] = buf; - - if (backend == GGML_BACKEND_GPU_SPLIT) { - for (int64_t is = 0; is < MAX_STREAMS; ++is) { - CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming)); - } - } - } - - tensor->extra = extra; -} - -void ggml_cuda_free_data(struct ggml_tensor * tensor) { - if (!tensor || (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) ) { - return; - } - - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; - - for (int64_t id = 0; id < g_device_count; ++id) { - if (extra->data_device[id] != nullptr) { - CUDA_CHECK(ggml_cuda_set_device(id)); - CUDA_CHECK(cudaFree(extra->data_device[id])); - } - - for (int64_t is = 0; is < MAX_STREAMS; ++is) { - if (extra->events[id][is] != nullptr) { - CUDA_CHECK(ggml_cuda_set_device(id)); - CUDA_CHECK(cudaEventDestroy(extra->events[id][is])); - } - } - } - - delete extra; -} - -static ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr; -static size_t g_temp_tensor_extra_index = 0; - -static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { - if (g_temp_tensor_extras == nullptr) { - g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES]; - } - - size_t alloc_index = g_temp_tensor_extra_index; - g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_MAX_NODES; - ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index]; - memset(extra, 0, sizeof(*extra)); - - return extra; -} - -static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) { - if (scratch && g_scratch_size == 0) { - return; - } - - tensor->backend = GGML_BACKEND_GPU; - - // recursively assign CUDA buffers until a compute tensor is found - if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) { - const ggml_op src0_op = tensor->src[0]->op; - if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) { - ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc); - } - } - if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) { - ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc); - } - - if (scratch && no_alloc) { - return; - } - - ggml_tensor_extra_gpu * extra; - - const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) || - tensor->op == GGML_OP_VIEW || - force_inplace; - const size_t size = ggml_nbytes(tensor); - - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); - if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) { - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; - char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; - size_t offset = 0; - if (tensor->op == GGML_OP_VIEW) { - memcpy(&offset, tensor->op_params, sizeof(size_t)); - } - extra = ggml_cuda_alloc_temp_tensor_extra(); - extra->data_device[g_main_device] = src0_ddc + offset; - } else if (tensor->op == GGML_OP_CPY) { - ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra; - void * src1_ddv = src1_extra->data_device[g_main_device]; - extra = ggml_cuda_alloc_temp_tensor_extra(); - extra->data_device[g_main_device] = src1_ddv; - } else if (scratch) { - GGML_ASSERT(size <= g_scratch_size); - if (g_scratch_offset + size > g_scratch_size) { - g_scratch_offset = 0; - } - - char * data = (char *) g_scratch_buffer; - if (data == nullptr) { - CUDA_CHECK(cudaMalloc(&data, g_scratch_size)); - g_scratch_buffer = data; - } - extra = ggml_cuda_alloc_temp_tensor_extra(); - extra->data_device[g_main_device] = data + g_scratch_offset; - - g_scratch_offset += size; - - GGML_ASSERT(g_scratch_offset <= g_scratch_size); - } else { // allocate new buffers outside of scratch - void * data; - CUDA_CHECK(cudaMalloc(&data, size)); - CUDA_CHECK(cudaMemset(data, 0, size)); - extra = new ggml_tensor_extra_gpu; - memset(extra, 0, sizeof(*extra)); - extra->data_device[g_main_device] = data; - } - - tensor->extra = extra; -} - -void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) { - if (g_scratch_size == 0) { - return; - } - if (g_scratch_buffer == nullptr) { - ggml_cuda_set_device(g_main_device); - CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size)); - } - - ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra(); - - const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) || - tensor->op == GGML_OP_VIEW; - - if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) { - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra; - char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; - size_t view_offset = 0; - if (tensor->op == GGML_OP_VIEW) { - memcpy(&view_offset, tensor->op_params, sizeof(size_t)); - } - extra->data_device[g_main_device] = src0_ddc + view_offset; - } else { - extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset; - } - - tensor->extra = extra; -} - -void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) { - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); - GGML_ASSERT(ggml_is_contiguous(tensor)); - - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; - CUDA_CHECK(ggml_cuda_set_device(g_main_device)); - CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice)); -} - -void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { - ggml_cuda_assign_buffers_impl(tensor, true, false, false); -} - -void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) { - ggml_cuda_assign_buffers_impl(tensor, true, false, true); -} - -void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) { - ggml_cuda_assign_buffers_impl(tensor, false, false, false); -} - -void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) { - ggml_cuda_assign_buffers_impl(tensor, false, true, false); -} - -void ggml_cuda_set_main_device(const int main_device) { - if (main_device >= g_device_count) { - fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n", - main_device, g_device_count, g_main_device); - return; - } - g_main_device = main_device; - if (g_device_count > 1) { - cudaDeviceProp prop; - CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device)); - fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name); - } -} - -void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) { - g_mul_mat_q = mul_mat_q; -} - -void ggml_cuda_set_scratch_size(const size_t scratch_size) { - // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously - // it still won't always work as expected, but it's better than nothing - if (scratch_size > g_scratch_size) { - ggml_cuda_free_scratch(); - } - g_scratch_size = std::max(g_scratch_size, scratch_size); -} - -void ggml_cuda_free_scratch() { - if (g_scratch_buffer == nullptr) { - return; - } - - CUDA_CHECK(cudaFree(g_scratch_buffer)); - g_scratch_buffer = nullptr; -} - -bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { - ggml_cuda_func_t func; - const bool any_on_device = tensor->backend == GGML_BACKEND_GPU - || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) - || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU); - - if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) { - return false; - } - - switch (tensor->op) { - case GGML_OP_REPEAT: - func = ggml_cuda_repeat; - break; - case GGML_OP_GET_ROWS: - func = ggml_cuda_get_rows; - break; - case GGML_OP_DUP: - func = ggml_cuda_dup; - break; - case GGML_OP_ADD: - func = ggml_cuda_add; - break; - case GGML_OP_MUL: - func = ggml_cuda_mul; - break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(tensor)) { - case GGML_UNARY_OP_GELU: - func = ggml_cuda_gelu; - break; - case GGML_UNARY_OP_SILU: - func = ggml_cuda_silu; - break; - default: - return false; - } break; - case GGML_OP_NORM: - func = ggml_cuda_norm; - break; - case GGML_OP_RMS_NORM: - func = ggml_cuda_rms_norm; - break; - case GGML_OP_MUL_MAT: - if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) { - return false; - } - func = ggml_cuda_mul_mat; - break; - case GGML_OP_SCALE: - func = ggml_cuda_scale; - break; - case GGML_OP_CLAMP: - if (!any_on_device) { - return false; - } - func = ggml_cuda_clamp; - break; - case GGML_OP_CPY: - func = ggml_cuda_cpy; - break; - case GGML_OP_CONT: - func = ggml_cuda_dup; - break; - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - func = ggml_cuda_nop; - break; - case GGML_OP_DIAG_MASK_INF: - func = ggml_cuda_diag_mask_inf; - break; - case GGML_OP_SOFT_MAX: - func = ggml_cuda_soft_max; - break; - case GGML_OP_ROPE: - func = ggml_cuda_rope; - break; - case GGML_OP_ALIBI: - func = ggml_cuda_alibi; - break; - default: - return false; - } - - if (params->ith != 0) { - return true; - } - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return true; - } - func(tensor->src[0], tensor->src[1], tensor); - return true; -} - -int ggml_cuda_get_device_count() { - int device_count; - CUDA_CHECK(cudaGetDeviceCount(&device_count)); - return device_count; -} - -void ggml_cuda_get_device_description(int device, char * description, size_t description_size) { - cudaDeviceProp prop; - CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); - snprintf(description, description_size, "%s", prop.name); -} - -//////////////////////////////////////////////////////////////////////////////// - -// backend interface - -#define UNUSED GGML_UNUSED - -struct ggml_backend_context_cuda { -}; - -static const char * ggml_backend_cuda_name(ggml_backend_t backend) { - return GGML_CUDA_NAME; - - UNUSED(backend); -} - -static void ggml_backend_cuda_free(ggml_backend_t backend) { - ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; - delete cuda_ctx; - delete backend; -} - -struct ggml_backend_buffer_context_cuda { - void * device; - - ggml_tensor_extra_gpu * temp_tensor_extras = nullptr; - size_t temp_tensor_extra_index = 0; - - ~ggml_backend_buffer_context_cuda() { - delete[] temp_tensor_extras; - } - - ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { - if (temp_tensor_extras == nullptr) { - temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES]; - } - - size_t alloc_index = temp_tensor_extra_index; - temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_MAX_NODES; - ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index]; - memset(extra, 0, sizeof(*extra)); - - return extra; - } -}; - -static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { - ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; - CUDA_CHECK(cudaFree(ctx->device)); - delete ctx; -} - -static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) { - ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; - return ctx->device; -} - -static size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - int64_t row_low = 0; - int64_t row_high = ggml_nrows(tensor); - int64_t nrows_split = row_high - row_low; - - size_t size = ggml_nbytes_split(tensor, nrows_split); - - int64_t ne0 = tensor->ne[0]; - - if (ggml_is_quantized(tensor->type)) { - if (ne0 % MATRIX_ROW_PADDING != 0) { - size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING) - * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type); - } - } - - return size; - - UNUSED(buffer); -} - -static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; - - if (tensor->view_src != NULL && tensor->view_offs == 0) { - assert(tensor->view_src->buffer->backend == buffer->backend); - tensor->backend = tensor->view_src->backend; - tensor->extra = tensor->view_src->extra; - return; - } - - ggml_tensor_extra_gpu * extra = ctx->ggml_cuda_alloc_temp_tensor_extra(); - - extra->data_device[g_main_device] = tensor->data; - - tensor->backend = GGML_BACKEND_GPU; - tensor->extra = extra; - - if (ggml_is_quantized(tensor->type)) { - // initialize padding to 0 to avoid possible NaN values - int64_t row_low = 0; - int64_t row_high = ggml_nrows(tensor); - int64_t nrows_split = row_high - row_low; - - size_t original_size = ggml_nbytes_split(tensor, nrows_split); - size_t padded_size = ggml_backend_cuda_buffer_get_alloc_size(tensor->buffer, tensor); - - if (padded_size > original_size && tensor->view_src == nullptr) { - CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[g_main_device][0])); - } - } - - UNUSED(buffer); -} - -static struct ggml_backend_buffer_i cuda_backend_buffer_interface = { - /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer, - /* .get_base = */ ggml_backend_cuda_buffer_get_base, - /* .get_alloc_size = */ ggml_backend_cuda_buffer_get_alloc_size, - /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor, - /* .free_tensor = */ NULL, -}; - -static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) { - ggml_cuda_set_device(g_main_device); - - ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda; - CUDA_CHECK(cudaMalloc(&ctx->device, size)); - return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size); -} - -static size_t ggml_backend_cuda_get_alignment(ggml_backend_t backend) { - return 128; - UNUSED(backend); -} - -static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); - - CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[g_main_device][0])); - - UNUSED(backend); -} - -static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); - - CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0])); - - UNUSED(backend); -} - -static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { - CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0])); - - UNUSED(backend); -} - -static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, ggml_cgraph * cgraph) { - GGML_ASSERT(!"not implemented"); - - return nullptr; - - UNUSED(backend); - UNUSED(cgraph); -} - -static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - GGML_ASSERT(!"not implemented"); - - UNUSED(backend); - UNUSED(plan); -} - -static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - GGML_ASSERT(!"not implemented"); - - UNUSED(backend); - UNUSED(plan); -} - -static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { - ggml_cuda_set_device(g_main_device); - - ggml_compute_params params = {}; - params.type = GGML_TASK_COMPUTE; - params.ith = 0; - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - - assert(node->backend == GGML_BACKEND_GPU); - for (int j = 0; j < GGML_MAX_SRC; j++) { - if (node->src[j] != nullptr) { - assert(node->src[j]->backend == GGML_BACKEND_GPU); - } - } - - bool ok = ggml_cuda_compute_forward(¶ms, node); - if (!ok) { - fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); - } - GGML_ASSERT(ok); - -#if 0 - if (node->type == GGML_TYPE_F32) { - cudaDeviceSynchronize(); - std::vector tmp(ggml_nelements(node), 0.0f); - cudaMemcpy(tmp.data(), node->data, ggml_nelements(node)*sizeof(float), cudaMemcpyDeviceToHost); - printf("\n%s (%s) (%s %s) (%s %s): ", node->name, ggml_op_name(node->op), - ggml_type_name(node->src[0]->type), - node->src[1] ? ggml_type_name(node->src[1]->type) : "none", - node->src[0]->name, - node->src[1] ? node->src[1]->name : "none"); - double sum = 0.0; - double sq_sum = 0.0; - for (int i = 0; i < ggml_nelements(node); i++) { - printf("%f ", tmp[i]); - sum += tmp[i]; - sq_sum += tmp[i]*tmp[i]; - } - printf("\n"); - printf("sum: %f, ", sum); - printf("sq_sum: %f\n", sq_sum); - } -#endif - } - - UNUSED(backend); -} - -static ggml_backend_i cuda_backend_i = { - /* .get_name = */ ggml_backend_cuda_name, - /* .free = */ ggml_backend_cuda_free, - /* .alloc_buffer = */ ggml_backend_cuda_alloc_buffer, - /* .get_alignment = */ ggml_backend_cuda_get_alignment, - /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, - /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, - /* .synchronize = */ ggml_backend_cuda_synchronize, - /* .cpy_tensor_from = */ nullptr, - /* .cpy_tensor_to = */ nullptr, - /* .graph_plan_create = */ ggml_backend_cuda_graph_plan_create, - /* .graph_plan_free = */ ggml_backend_cuda_graph_plan_free, - /* .graph_plan_compute = */ ggml_backend_cuda_graph_plan_compute, - /* .graph_compute = */ ggml_backend_cuda_graph_compute, - /* .supports_op = */ nullptr, -}; - -ggml_backend_t ggml_backend_cuda_init() { - ggml_init_cublas(); // TODO: remove from ggml.c - - ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda; - - ggml_backend_t cuda_backend = new ggml_backend { - /* .interface = */ cuda_backend_i, - /* .context = */ ctx - }; - - return cuda_backend; -} diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h b/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h deleted file mode 100644 index 57adc9cf..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-cuda.h +++ /dev/null @@ -1,51 +0,0 @@ -#pragma once - -#include "ggml.h" -#include "ggml-backend.h" - -#ifdef GGML_USE_HIPBLAS -#define GGML_CUDA_NAME "ROCm" -#define GGML_CUBLAS_NAME "hipBLAS" -#else -#define GGML_CUDA_NAME "CUDA" -#define GGML_CUBLAS_NAME "cuBLAS" -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -#define GGML_CUDA_MAX_DEVICES 16 - -GGML_API void ggml_init_cublas(void); -GGML_API void * ggml_cuda_host_malloc(size_t size); -GGML_API void ggml_cuda_host_free(void * ptr); - -GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); -GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split); -GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor); -GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor); - -GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor); -GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor); -GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor); - -GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor); -GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset); -GGML_API void ggml_cuda_copy_to_device(struct ggml_tensor * tensor); - -GGML_API void ggml_cuda_set_main_device(int main_device); -GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q); -GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size); -GGML_API void ggml_cuda_free_scratch(void); -GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor); - -GGML_API int ggml_cuda_get_device_count(void); -GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size); - -// backend API -GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use - -#ifdef __cplusplus -} -#endif diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h deleted file mode 100644 index 096b844e..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.h +++ /dev/null @@ -1,106 +0,0 @@ -// An interface allowing to compute ggml_cgraph with Metal -// -// This is a fully functional interface that extends ggml with GPU support for Apple devices. -// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.) -// -// How it works? -// -// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this -// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you -// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.) -// -// You only need to make sure that all memory buffers that you used during the graph creation -// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is -// used during the graph evaluation to determine the arguments of the compute kernels. -// -// Synchronization between device and host memory (for example for input and output tensors) -// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions. -// - -#pragma once - -#include "ggml.h" -#include "ggml-backend.h" - -#include -#include - -// max memory buffers that can be mapped to the device -#define GGML_METAL_MAX_BUFFERS 16 -#define GGML_METAL_MAX_COMMAND_BUFFERS 32 - -struct ggml_tensor; -struct ggml_cgraph; - -#ifdef __cplusplus -extern "C" { -#endif - -// -// internal API -// temporary exposed to user-code -// - -struct ggml_metal_context; - -void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data); - -// number of command buffers to use -struct ggml_metal_context * ggml_metal_init(int n_cb); -void ggml_metal_free(struct ggml_metal_context * ctx); - -void * ggml_metal_host_malloc(size_t n); -void ggml_metal_host_free (void * data); - -// set the number of command buffers to use -void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb); - -// creates a mapping between a host memory buffer and a device memory buffer -// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute -// - the mapping is used during computation to determine the arguments of the compute kernels -// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal -// - max_size specifies the maximum size of a tensor and is used to create shared views such -// that it is guaranteed that the tensor will fit in at least one of the views -// -bool ggml_metal_add_buffer( - struct ggml_metal_context * ctx, - const char * name, - void * data, - size_t size, - size_t max_size); - -// set data from host memory into the device -void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t); - -// get data from the device into host memory -void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t); - -// try to find operations that can be run concurrently in the graph -// you should run it again if the topology of your graph changes -void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem); - -// if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized -int ggml_metal_if_optimized(struct ggml_metal_context * ctx); - -// output the concur_list for ggml_alloc -int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx); - -// same as ggml_graph_compute but uses Metal -// creates gf->n_threads command buffers in parallel -void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); - -// -// backend API -// user-code should use only these functions -// - -GGML_API ggml_backend_t ggml_backend_metal_init(void); - -GGML_API bool ggml_backend_is_metal(ggml_backend_t backend); - -GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb); - -#ifdef __cplusplus -} -#endif - diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m deleted file mode 100644 index 87fa1721..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.m +++ /dev/null @@ -1,1601 +0,0 @@ -#import "ggml-metal.h" - -#import "ggml.h" - -#import - -#import - -#undef MIN -#undef MAX -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -#ifdef GGML_METAL_NDEBUG -#define GGML_METAL_LOG_INFO(...) -#define GGML_METAL_LOG_WARN(...) -#define GGML_METAL_LOG_ERROR(...) -#else -#define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__) -#define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__) -#define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) -#endif - -#define UNUSED(x) (void)(x) - -#define GGML_MAX_CONCUR (2*GGML_MAX_NODES) - -struct ggml_metal_buffer { - const char * name; - - void * data; - size_t size; - - id metal; -}; - -struct ggml_metal_context { - int n_cb; - - id device; - id queue; - id library; - - id command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS]; - id command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS]; - - dispatch_queue_t d_queue; - - int n_buffers; - struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; - - int concur_list[GGML_MAX_CONCUR]; - int concur_list_len; - - // custom kernels -#define GGML_METAL_DECL_KERNEL(name) \ - id function_##name; \ - id pipeline_##name - - GGML_METAL_DECL_KERNEL(add); - GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast - GGML_METAL_DECL_KERNEL(mul); - GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast - GGML_METAL_DECL_KERNEL(scale); - GGML_METAL_DECL_KERNEL(silu); - GGML_METAL_DECL_KERNEL(relu); - GGML_METAL_DECL_KERNEL(gelu); - GGML_METAL_DECL_KERNEL(soft_max); - GGML_METAL_DECL_KERNEL(soft_max_4); - GGML_METAL_DECL_KERNEL(diag_mask_inf); - GGML_METAL_DECL_KERNEL(diag_mask_inf_8); - GGML_METAL_DECL_KERNEL(get_rows_f32); - GGML_METAL_DECL_KERNEL(get_rows_f16); - GGML_METAL_DECL_KERNEL(get_rows_q4_0); - GGML_METAL_DECL_KERNEL(get_rows_q4_1); - GGML_METAL_DECL_KERNEL(get_rows_q8_0); - GGML_METAL_DECL_KERNEL(get_rows_q2_K); - GGML_METAL_DECL_KERNEL(get_rows_q3_K); - GGML_METAL_DECL_KERNEL(get_rows_q4_K); - GGML_METAL_DECL_KERNEL(get_rows_q5_K); - GGML_METAL_DECL_KERNEL(get_rows_q6_K); - GGML_METAL_DECL_KERNEL(rms_norm); - GGML_METAL_DECL_KERNEL(norm); - GGML_METAL_DECL_KERNEL(mul_mv_f32_f32); - GGML_METAL_DECL_KERNEL(mul_mv_f16_f32); - GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row); - GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); - GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_f32_f32); - GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_DECL_KERNEL(rope_f32); - GGML_METAL_DECL_KERNEL(rope_f16); - GGML_METAL_DECL_KERNEL(alibi_f32); - GGML_METAL_DECL_KERNEL(cpy_f32_f16); - GGML_METAL_DECL_KERNEL(cpy_f32_f32); - GGML_METAL_DECL_KERNEL(cpy_f16_f16); - GGML_METAL_DECL_KERNEL(concat); - GGML_METAL_DECL_KERNEL(sqr); - -#undef GGML_METAL_DECL_KERNEL -}; - -// MSL code -// TODO: move the contents here when ready -// for now it is easier to work in a separate file -static NSString * const msl_library_source = @"see metal.metal"; - -// Here to assist with NSBundle Path Hack -@interface GGMLMetalClass : NSObject -@end -@implementation GGMLMetalClass -@end - -ggml_log_callback ggml_metal_log_callback = NULL; -void * ggml_metal_log_user_data = NULL; - -void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) { - ggml_metal_log_callback = log_callback; - ggml_metal_log_user_data = user_data; -} - -static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ - if (ggml_metal_log_callback != NULL) { - va_list args; - va_start(args, format); - char buffer[128]; - int len = vsnprintf(buffer, 128, format, args); - if (len < 128) { - ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data); - } else { - char* buffer2 = malloc(len+1); - vsnprintf(buffer2, len+1, format, args); - buffer2[len] = 0; - ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data); - free(buffer2); - } - va_end(args); - } -} - - - -struct ggml_metal_context * ggml_metal_init(int n_cb) { - GGML_METAL_LOG_INFO("%s: allocating\n", __func__); - - id device; - NSString * s; - -#if TARGET_OS_OSX - // Show all the Metal device instances in the system - NSArray * devices = MTLCopyAllDevices(); - for (device in devices) { - s = [device name]; - GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]); - } -#endif - - // Pick and show default Metal device - device = MTLCreateSystemDefaultDevice(); - s = [device name]; - GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]); - - // Configure context - struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context)); - ctx->device = device; - ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); - ctx->queue = [ctx->device newCommandQueue]; - ctx->n_buffers = 0; - ctx->concur_list_len = 0; - - ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); - - // load library - { - NSBundle * bundle = nil; -#ifdef SWIFT_PACKAGE - bundle = SWIFTPM_MODULE_BUNDLE; -#else - bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; -#endif - NSError * error = nil; - NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"]; - if (libPath != nil) { - NSURL * libURL = [NSURL fileURLWithPath:libPath]; - GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); - ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; - } else { - GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); - - NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; - GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]); - NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error]; - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - - MTLCompileOptions* options = nil; -#ifdef GGML_QKK_64 - options = [MTLCompileOptions new]; - options.preprocessorMacros = @{ @"QK_K" : @(64) }; -#endif - ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; - } - - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - } - - // load kernels - { - NSError * error = nil; -#define GGML_METAL_ADD_KERNEL(name) \ - ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ - ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \ - GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \ - (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \ - (int) ctx->pipeline_##name.threadExecutionWidth); \ - if (error) { \ - GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ - return NULL; \ - } - - GGML_METAL_ADD_KERNEL(add); - GGML_METAL_ADD_KERNEL(add_row); - GGML_METAL_ADD_KERNEL(mul); - GGML_METAL_ADD_KERNEL(mul_row); - GGML_METAL_ADD_KERNEL(scale); - GGML_METAL_ADD_KERNEL(silu); - GGML_METAL_ADD_KERNEL(relu); - GGML_METAL_ADD_KERNEL(gelu); - GGML_METAL_ADD_KERNEL(soft_max); - GGML_METAL_ADD_KERNEL(soft_max_4); - GGML_METAL_ADD_KERNEL(diag_mask_inf); - GGML_METAL_ADD_KERNEL(diag_mask_inf_8); - GGML_METAL_ADD_KERNEL(get_rows_f32); - GGML_METAL_ADD_KERNEL(get_rows_f16); - GGML_METAL_ADD_KERNEL(get_rows_q4_0); - GGML_METAL_ADD_KERNEL(get_rows_q4_1); - GGML_METAL_ADD_KERNEL(get_rows_q8_0); - GGML_METAL_ADD_KERNEL(get_rows_q2_K); - GGML_METAL_ADD_KERNEL(get_rows_q3_K); - GGML_METAL_ADD_KERNEL(get_rows_q4_K); - GGML_METAL_ADD_KERNEL(get_rows_q5_K); - GGML_METAL_ADD_KERNEL(get_rows_q6_K); - GGML_METAL_ADD_KERNEL(rms_norm); - GGML_METAL_ADD_KERNEL(norm); - GGML_METAL_ADD_KERNEL(mul_mv_f32_f32); - GGML_METAL_ADD_KERNEL(mul_mv_f16_f32); - GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row); - GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); - GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32); - if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { - GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); - GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); - } - GGML_METAL_ADD_KERNEL(rope_f32); - GGML_METAL_ADD_KERNEL(rope_f16); - GGML_METAL_ADD_KERNEL(alibi_f32); - GGML_METAL_ADD_KERNEL(cpy_f32_f16); - GGML_METAL_ADD_KERNEL(cpy_f32_f32); - GGML_METAL_ADD_KERNEL(cpy_f16_f16); - GGML_METAL_ADD_KERNEL(concat); - GGML_METAL_ADD_KERNEL(sqr); - -#undef GGML_METAL_ADD_KERNEL - } - -#if TARGET_OS_OSX - // print MTL GPU family: - GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]); - - // determine max supported GPU family - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf - // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf - for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { - if ([ctx->device supportsFamily:i]) { - GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i); - break; - } - } - - GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); - GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); - if (ctx->device.maxTransferRate != 0) { - GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); - } else { - GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__); - } -#endif - - return ctx; -} - -void ggml_metal_free(struct ggml_metal_context * ctx) { - GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); -#define GGML_METAL_DEL_KERNEL(name) \ - [ctx->function_##name release]; \ - [ctx->pipeline_##name release]; - - GGML_METAL_DEL_KERNEL(add); - GGML_METAL_DEL_KERNEL(add_row); - GGML_METAL_DEL_KERNEL(mul); - GGML_METAL_DEL_KERNEL(mul_row); - GGML_METAL_DEL_KERNEL(scale); - GGML_METAL_DEL_KERNEL(silu); - GGML_METAL_DEL_KERNEL(relu); - GGML_METAL_DEL_KERNEL(gelu); - GGML_METAL_DEL_KERNEL(soft_max); - GGML_METAL_DEL_KERNEL(soft_max_4); - GGML_METAL_DEL_KERNEL(diag_mask_inf); - GGML_METAL_DEL_KERNEL(diag_mask_inf_8); - GGML_METAL_DEL_KERNEL(get_rows_f32); - GGML_METAL_DEL_KERNEL(get_rows_f16); - GGML_METAL_DEL_KERNEL(get_rows_q4_0); - GGML_METAL_DEL_KERNEL(get_rows_q4_1); - GGML_METAL_DEL_KERNEL(get_rows_q8_0); - GGML_METAL_DEL_KERNEL(get_rows_q2_K); - GGML_METAL_DEL_KERNEL(get_rows_q3_K); - GGML_METAL_DEL_KERNEL(get_rows_q4_K); - GGML_METAL_DEL_KERNEL(get_rows_q5_K); - GGML_METAL_DEL_KERNEL(get_rows_q6_K); - GGML_METAL_DEL_KERNEL(rms_norm); - GGML_METAL_DEL_KERNEL(norm); - GGML_METAL_DEL_KERNEL(mul_mv_f32_f32); - GGML_METAL_DEL_KERNEL(mul_mv_f16_f32); - GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row); - GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); - GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32); - if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { - GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); - GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); - } - GGML_METAL_DEL_KERNEL(rope_f32); - GGML_METAL_DEL_KERNEL(rope_f16); - GGML_METAL_DEL_KERNEL(alibi_f32); - GGML_METAL_DEL_KERNEL(cpy_f32_f16); - GGML_METAL_DEL_KERNEL(cpy_f32_f32); - GGML_METAL_DEL_KERNEL(cpy_f16_f16); - GGML_METAL_DEL_KERNEL(concat); - GGML_METAL_DEL_KERNEL(sqr); - -#undef GGML_METAL_DEL_KERNEL - - for (int i = 0; i < ctx->n_buffers; ++i) { - [ctx->buffers[i].metal release]; - } - - [ctx->library release]; - [ctx->queue release]; - [ctx->device release]; - - dispatch_release(ctx->d_queue); - - free(ctx); -} - -void * ggml_metal_host_malloc(size_t n) { - void * data = NULL; - const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); - if (result != 0) { - GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); - return NULL; - } - - return data; -} - -void ggml_metal_host_free(void * data) { - free(data); -} - -void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) { - ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); -} - -int ggml_metal_if_optimized(struct ggml_metal_context * ctx) { - return ctx->concur_list_len; -} - -int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) { - return ctx->concur_list; -} - -// finds the Metal buffer that contains the tensor data on the GPU device -// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the -// Metal buffer based on the host memory pointer -// -static id ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) { - //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); - - const int64_t tsize = ggml_nbytes(t); - - // find the view that contains the tensor fully - for (int i = 0; i < ctx->n_buffers; ++i) { - const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; - - //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name); - if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) { - *offs = (size_t) ioffs; - - //GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs); - - return ctx->buffers[i].metal; - } - } - - GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__); - - return nil; -} - -bool ggml_metal_add_buffer( - struct ggml_metal_context * ctx, - const char * name, - void * data, - size_t size, - size_t max_size) { - if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) { - GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__); - return false; - } - - if (data) { - // verify that the buffer does not overlap with any of the existing buffers - for (int i = 0; i < ctx->n_buffers; ++i) { - const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data; - - if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) { - GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name); - return false; - } - } - - const size_t size_page = sysconf(_SC_PAGESIZE); - - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += (size_page - (size_aligned % size_page)); - } - - // the buffer fits into the max buffer size allowed by the device - if (size_aligned <= ctx->device.maxBufferLength) { - ctx->buffers[ctx->n_buffers].name = name; - ctx->buffers[ctx->n_buffers].data = data; - ctx->buffers[ctx->n_buffers].size = size; - - ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0); - return false; - } - - GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0); - - ++ctx->n_buffers; - } else { - // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into - // one of the views - const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case - const size_t size_step = ctx->device.maxBufferLength - size_ovlp; - const size_t size_view = ctx->device.maxBufferLength; - - for (size_t i = 0; i < size; i += size_step) { - const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); - - ctx->buffers[ctx->n_buffers].name = name; - ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); - ctx->buffers[ctx->n_buffers].size = size_step_aligned; - - ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0); - return false; - } - - GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i); - if (i + size_step < size) { - GGML_METAL_LOG_INFO("\n"); - } - - ++ctx->n_buffers; - } - } - -#if TARGET_OS_OSX - GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", - ctx->device.currentAllocatedSize / 1024.0 / 1024.0, - ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); - - if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) { - GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__); - } else { - GGML_METAL_LOG_INFO("\n"); - } -#else - GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0); -#endif - } - - return true; -} - -void ggml_metal_set_tensor( - struct ggml_metal_context * ctx, - struct ggml_tensor * t) { - size_t offs; - id id_dst = ggml_metal_get_buffer(ctx, t, &offs); - - memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t)); -} - -void ggml_metal_get_tensor( - struct ggml_metal_context * ctx, - struct ggml_tensor * t) { - size_t offs; - id id_src = ggml_metal_get_buffer(ctx, t, &offs); - - memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t)); -} - -void ggml_metal_graph_find_concurrency( - struct ggml_metal_context * ctx, - struct ggml_cgraph * gf, bool check_mem) { - int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time - int nodes_unused[GGML_MAX_CONCUR]; - - for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; } - for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; } - ctx->concur_list_len = 0; - - int n_left = gf->n_nodes; - int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list - int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos - - while (n_left > 0) { - // number of nodes at a layer (that can be issued concurrently) - int concurrency = 0; - for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) { - if (nodes_unused[i]) { - // if the requirements for gf->nodes[i] are satisfied - int exe_flag = 1; - - // scan all srcs - for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) { - struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind]; - if (src_cur) { - // if is leaf nodes it's satisfied. - // TODO: ggml_is_leaf() - if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) { - continue; - } - - // otherwise this src should be the output from previous nodes. - int is_found = 0; - - // scan 2*search_depth back because we inserted barrier. - //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) { - for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) { - if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) { - is_found = 1; - break; - } - } - if (is_found == 0) { - exe_flag = 0; - break; - } - } - } - if (exe_flag && check_mem) { - // check if nodes[i]'s data will be overwritten by a node before nodes[i]. - // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3] - int64_t data_start = (int64_t) gf->nodes[i]->data; - int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]); - for (int j = n_start; j < i; j++) { - if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \ - && gf->nodes[j]->op != GGML_OP_VIEW \ - && gf->nodes[j]->op != GGML_OP_TRANSPOSE \ - && gf->nodes[j]->op != GGML_OP_PERMUTE) { - if (((int64_t)gf->nodes[j]->data) >= data_start + length || \ - ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) { - continue; - } - - exe_flag = 0; - } - } - } - if (exe_flag) { - ctx->concur_list[level_pos + concurrency] = i; - nodes_unused[i] = 0; - concurrency++; - ctx->concur_list_len++; - } - } - } - n_left -= concurrency; - // adding a barrier different layer - ctx->concur_list[level_pos + concurrency] = -1; - ctx->concur_list_len++; - // jump all sorted nodes at nodes_bak - while (!nodes_unused[n_start]) { - n_start++; - } - level_pos += concurrency + 1; - } - - if (ctx->concur_list_len > GGML_MAX_CONCUR) { - GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__); - } -} - -void ggml_metal_graph_compute( - struct ggml_metal_context * ctx, - struct ggml_cgraph * gf) { - @autoreleasepool { - - // if there is ctx->concur_list, dispatch concurrently - // else fallback to serial dispatch - MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - - const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR; - - const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; - edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial; - - // create multiple command buffers and enqueue them - // then, we encode the graph into the command buffers in parallel - - const int n_cb = ctx->n_cb; - - for (int i = 0; i < n_cb; ++i) { - ctx->command_buffers[i] = [ctx->queue commandBuffer]; - - // enqueue the command buffers in order to specify their execution order - [ctx->command_buffers[i] enqueue]; - - ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc]; - } - - for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; - - dispatch_async(ctx->d_queue, ^{ - size_t offs_src0 = 0; - size_t offs_src1 = 0; - size_t offs_dst = 0; - - id command_buffer = ctx->command_buffers[cb_idx]; - id encoder = ctx->command_encoders[cb_idx]; - - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); - - for (int ind = node_start; ind < node_end; ++ind) { - const int i = has_concur ? ctx->concur_list[ind] : ind; - - if (i == -1) { - [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; - continue; - } - - //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); - - struct ggml_tensor * src0 = gf->nodes[i]->src[0]; - struct ggml_tensor * src1 = gf->nodes[i]->src[1]; - struct ggml_tensor * dst = gf->nodes[i]; - - const int64_t ne00 = src0 ? src0->ne[0] : 0; - const int64_t ne01 = src0 ? src0->ne[1] : 0; - const int64_t ne02 = src0 ? src0->ne[2] : 0; - const int64_t ne03 = src0 ? src0->ne[3] : 0; - - const uint64_t nb00 = src0 ? src0->nb[0] : 0; - const uint64_t nb01 = src0 ? src0->nb[1] : 0; - const uint64_t nb02 = src0 ? src0->nb[2] : 0; - const uint64_t nb03 = src0 ? src0->nb[3] : 0; - - const int64_t ne10 = src1 ? src1->ne[0] : 0; - const int64_t ne11 = src1 ? src1->ne[1] : 0; - const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); - - const uint64_t nb10 = src1 ? src1->nb[0] : 0; - const uint64_t nb11 = src1 ? src1->nb[1] : 0; - const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); - - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; - - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; - - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; - const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; - - id id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil; - id id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil; - id id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil; - - //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - //if (src0) { - // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, - // ggml_is_contiguous(src0), src0->name); - //} - //if (src1) { - // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, - // ggml_is_contiguous(src1), src1->name); - //} - //if (dst) { - // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, - // dst->name); - //} - - switch (dst->op) { - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - { - // noop - } break; - case GGML_OP_CONCAT: - { - const int64_t nb = ne00; - - [encoder setComputePipelineState:ctx->pipeline_concat]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ADD: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - bool bcast_row = false; - - int64_t nb = ne00; - - if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) { - // src1 is a row - GGML_ASSERT(ne11 == 1); - - nb = ne00 / 4; - [encoder setComputePipelineState:ctx->pipeline_add_row]; - - bcast_row = true; - } else { - [encoder setComputePipelineState:ctx->pipeline_add]; - } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; - - if (bcast_row) { - const int64_t n = ggml_nelements(dst)/4; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } else { - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - } break; - case GGML_OP_MUL: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - // utilize float4 - GGML_ASSERT(ne00 % 4 == 0); - const int64_t nb = ne00/4; - - if (ggml_nelements(src1) == ne10) { - // src1 is a row - GGML_ASSERT(ne11 == 1); - [encoder setComputePipelineState:ctx->pipeline_mul_row]; - } else { - [encoder setComputePipelineState:ctx->pipeline_mul]; - } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:3]; - - const int64_t n = ggml_nelements(dst)/4; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SCALE: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - const float scale = *(const float *) src1->data; - - [encoder setComputePipelineState:ctx->pipeline_scale]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(gf->nodes[i])) { - case GGML_UNARY_OP_SILU: - { - [encoder setComputePipelineState:ctx->pipeline_silu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_RELU: - { - [encoder setComputePipelineState:ctx->pipeline_relu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU: - { - [encoder setComputePipelineState:ctx->pipeline_gelu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - default: - { - GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ASSERT(false); - } - } break; - case GGML_OP_SQR: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - [encoder setComputePipelineState:ctx->pipeline_sqr]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SOFT_MAX: - { - const int nth = MIN(32, ne00); - - if (ne00%4 == 0) { - [encoder setComputePipelineState:ctx->pipeline_soft_max_4]; - } else { - [encoder setComputePipelineState:ctx->pipeline_soft_max]; - } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_DIAG_MASK_INF: - { - const int n_past = ((int32_t *)(dst->op_params))[0]; - - if (ne00%8 == 0) { - [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8]; - } else { - [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; - } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; - - if (ne00%8 == 0) { - [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - else { - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - } break; - case GGML_OP_MUL_MAT: - { - GGML_ASSERT(ne00 == ne10); - GGML_ASSERT(ne03 == ne13); - - const uint gqa = ne12/ne02; - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - int ne11_mm_min = 1; - -#if 0 - // the numbers below are measured on M2 Ultra for 7B and 13B models - // these numbers do not translate to other devices or model sizes - // TODO: need to find a better approach - if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { - switch (src0t) { - case GGML_TYPE_F16: ne11_mm_min = 2; break; - case GGML_TYPE_Q8_0: ne11_mm_min = 7; break; - case GGML_TYPE_Q2_K: ne11_mm_min = 15; break; - case GGML_TYPE_Q3_K: ne11_mm_min = 7; break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: ne11_mm_min = 15; break; - case GGML_TYPE_Q4_K: ne11_mm_min = 11; break; - case GGML_TYPE_Q5_0: // not tested yet - case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet - case GGML_TYPE_Q5_K: ne11_mm_min = 7; break; - case GGML_TYPE_Q6_K: ne11_mm_min = 7; break; - default: ne11_mm_min = 1; break; - } - } -#endif - - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - !ggml_is_transposed(src0) && - !ggml_is_transposed(src1) && - src1t == GGML_TYPE_F32 && - ne00 % 32 == 0 && ne00 >= 64 && - ne11 > ne11_mm_min) { - //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - switch (src0->type) { - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break; - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; - case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; - case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; - case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; - default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); - } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13]; - [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - // use custom matrix x vector kernel - switch (src0t) { - case GGML_TYPE_F32: - { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32]; - nrows = 4; - } break; - case GGML_TYPE_F16: - { - nth0 = 32; - nth1 = 1; - if (ne11 * ne12 < 4) { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4]; - nrows = ne11; - } else { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; - nrows = 4; - } - } break; - case GGML_TYPE_Q4_0: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32]; - } break; - case GGML_TYPE_Q4_1: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; - } break; - case GGML_TYPE_Q8_0: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 8; - nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32]; - } break; - case GGML_TYPE_Q2_K: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32]; - } break; - case GGML_TYPE_Q3_K: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32]; - } break; - case GGML_TYPE_Q4_K: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 4; //1; - nth1 = 8; //32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32]; - } break; - case GGML_TYPE_Q5_K: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32]; - } break; - case GGML_TYPE_Q6_K: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 2; - nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32]; - } break; - default: - { - GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); - GGML_ASSERT(false && "not implemented"); - } - }; - - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; - - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { -#ifdef GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; -#else - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; -#endif - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - int64_t ny = (ne11 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - } - } break; - case GGML_OP_GET_ROWS: - { - switch (src0->type) { - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break; - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; - case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; - case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; - case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; - default: GGML_ASSERT(false && "not implemented"); - } - - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5]; - - const int64_t n = ggml_nelements(src1); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_RMS_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - const int nth = MIN(512, ne00); - - [encoder setComputePipelineState:ctx->pipeline_rms_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_NORM: - { - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - const int nth = MIN(256, ne00); - - [encoder setComputePipelineState:ctx->pipeline_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ALIBI: - { - GGML_ASSERT((src0t == GGML_TYPE_F32)); - - const int nth = MIN(1024, ne00); - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - [encoder setComputePipelineState:ctx->pipeline_alibi_f32]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; - [encoder setBytes:&m1 length:sizeof( float) atIndex:19]; - [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ROPE: - { - GGML_ASSERT(ne10 == ne02); - - const int nth = MIN(1024, ne00); - - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - - float freq_base; - float freq_scale; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - - switch (src0->type) { - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break; - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break; - default: GGML_ASSERT(false); - }; - - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:19]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; - [encoder setBytes:&mode length:sizeof( int) atIndex:21]; - [encoder setBytes:&freq_base length:sizeof(float) atIndex:22]; - [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_DUP: - case GGML_OP_CPY: - case GGML_OP_CONT: - { - const int nth = MIN(1024, ne00); - - switch (src0t) { - case GGML_TYPE_F32: - { - switch (dstt) { - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break; - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break; - default: GGML_ASSERT(false && "not implemented"); - }; - } break; - case GGML_TYPE_F16: - { - switch (dstt) { - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break; - case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break; - default: GGML_ASSERT(false && "not implemented"); - }; - } break; - default: GGML_ASSERT(false && "not implemented"); - } - - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - default: - { - GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ASSERT(false); - } - } - } - - if (encoder != nil) { - [encoder endEncoding]; - encoder = nil; - } - - [command_buffer commit]; - }); - } - - // wait for all threads to finish - dispatch_barrier_sync(ctx->d_queue, ^{}); - - // check status of command buffers - // needed to detect if the device ran out-of-memory for example (#1881) - for (int i = 0; i < n_cb; i++) { - [ctx->command_buffers[i] waitUntilCompleted]; - - MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); - GGML_ASSERT(false); - } - } - - } -} - -//////////////////////////////////////////////////////////////////////////////// - -// backend interface - -static const char * ggml_backend_metal_name(ggml_backend_t backend) { - return "Metal"; - - UNUSED(backend); -} - -static void ggml_backend_metal_free(ggml_backend_t backend) { - struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context; - ggml_metal_free(ctx); - free(backend); -} - -static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) { - return (void *)buffer->context; -} - -static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) { - free(buffer->context); - UNUSED(buffer); -} - -static struct ggml_backend_buffer_i metal_backend_buffer_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_get_base, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .init_tensor = */ NULL, // no initialization required - /* .free_tensor = */ NULL, // no cleanup required -}; - -static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) { - struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context; - - void * data = ggml_metal_host_malloc(size); - - // TODO: set proper name of the buffers - ggml_metal_add_buffer(ctx, "backend", data, size, 0); - - return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size); -} - -static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) { - return 32; - UNUSED(backend); -} - -static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - - memcpy((char *)tensor->data + offset, data, size); - - UNUSED(backend); -} - -static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); - GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - - memcpy(data, (const char *)tensor->data + offset, size); - - UNUSED(backend); -} - -static void ggml_backend_metal_synchronize(ggml_backend_t backend) { - UNUSED(backend); -} - -static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { - ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src)); - - UNUSED(backend); -} - -static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) { - ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src)); - - UNUSED(backend); -} - -static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context; - - ggml_metal_graph_compute(metal_ctx, cgraph); -} - -static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - return true; - UNUSED(backend); - UNUSED(op); -} - -static struct ggml_backend_i metal_backend_i = { - /* .get_name = */ ggml_backend_metal_name, - /* .free = */ ggml_backend_metal_free, - /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_get_alignment, - /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, - /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, - /* .synchronize = */ ggml_backend_metal_synchronize, - /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from, - /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to, - /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm - /* .graph_plan_free = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_metal_graph_compute, - /* .supports_op = */ ggml_backend_metal_supports_op, -}; - -ggml_backend_t ggml_backend_metal_init(void) { - struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context)); - - ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS); - - ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend)); - - *metal_backend = (struct ggml_backend) { - /* .interface = */ metal_backend_i, - /* .context = */ ctx, - }; - - return metal_backend; -} - -bool ggml_backend_is_metal(ggml_backend_t backend) { - return backend->iface.get_name == ggml_backend_metal_name; -} - -void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { - struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context; - - ggml_metal_set_n_cb(ctx, n_cb); -} diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal b/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal deleted file mode 100644 index 99b9fd7a..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/ggml-metal.metal +++ /dev/null @@ -1,2526 +0,0 @@ -#include - -using namespace metal; - -#define MAX(x, y) ((x) > (y) ? (x) : (y)) - -#define QK4_0 32 -#define QR4_0 2 -typedef struct { - half d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; - -#define QK4_1 32 -typedef struct { - half d; // delta - half m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; - -#define QK8_0 32 -typedef struct { - half d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; - -// general-purpose kernel for addition of two tensors -// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 -// cons: not very efficient -kernel void kernel_add( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant int64_t & nb00, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant int64_t & nb0, - constant int64_t & nb1, - constant int64_t & nb2, - constant int64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0]; - - src0_ptr += ntg.x*nb00; - src1_ptr += ntg.x*nb10; - dst_ptr += ntg.x*nb0; - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_add_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant int64_t & nb [[buffer(27)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig % nb]; -} - -kernel void kernel_mul( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig]; -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_mul_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant int64_t & nb, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig % nb]; -} - -kernel void kernel_scale( - device const float4 * src0, - device float4 * dst, - constant float & scale, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; -} - -kernel void kernel_silu( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_relu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_sqr( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} - -constant float GELU_COEF_A = 0.044715f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -kernel void kernel_gelu( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - // BEWARE !!! - // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! - // This was observed with Falcon 7B and 40B models - // - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_soft_max( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - // parallel max - float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY; - for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) { - lmax = MAX(lmax, psrc0[i00]); - } - const float max = simd_max(lmax); - - // parallel sum - float lsum = 0.0f; - for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { - const float exp_psrc0 = exp(psrc0[i00] - max); - lsum += exp_psrc0; - // Remember the result of exp here. exp is expensive, so we really do not - // whish to compute it twice. - pdst[i00] = exp_psrc0; - } - - const float sum = simd_sum(lsum); - - for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { - pdst[i00] /= sum; - } -} - -kernel void kernel_soft_max_4( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - // parallel max - float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY; - for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) { - lmax4 = fmax(lmax4, psrc4[i00]); - } - float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); - - const float max = simd_max(lmax); - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { - const float4 exp_psrc4 = exp(psrc4[i00] - max); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; - - const float sum = simd_sum(lsum); - - for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { - pdst4[i00] /= sum; - } -} - -kernel void kernel_diag_mask_inf( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, - uint3 tpig[[thread_position_in_grid]]) { - const int64_t i02 = tpig[2]; - const int64_t i01 = tpig[1]; - const int64_t i00 = tpig[0]; - - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; - } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - device const float4 * src0, - device float4 * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, - uint3 tpig[[thread_position_in_grid]]) { - - const int64_t i = 2*tpig[0]; - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int64_t i4 = 4*i; - const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; - const int64_t i01 = i4/(ne00); i4 -= i01*ne00; - const int64_t i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= n_past + i01) { - break; - } - dst[i+1][k] = -INFINITY; - if (i00 + k > n_past + i01) { - dst[i][k] = -INFINITY; - } - } -} - -kernel void kernel_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * sum [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); - // MEAN - // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00]; - } - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float mean = sum[0] / ne00; - - // recenter and VARIANCE - threadgroup_barrier(mem_flags::mem_threadgroup); - device float * y = dst + tgpig*ne00; - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = x[i00] - mean; - sum[tpitg] += y[i00] * y[i00]; - } - - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float variance = sum[0] / ne00; - - const float scale = 1.0f/sqrt(variance + eps); - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = y[i00] * scale; - } -} - -kernel void kernel_rms_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * sum [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); - device const float * x_scalar = (device const float *) x; - - float4 sumf = 0; - float all_sum = 0; - - // parallel sum - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - sumf += x[i00] * x[i00]; - } - all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; - all_sum = simd_sum(all_sum); - if (tiisg == 0) { - sum[sgitg] = all_sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // broadcast, simd group number is ntg / 32 - for (uint i = ntg / 32 / 2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - } - if (tpitg == 0) { - for (int i = 4 * (ne00 / 4); i < ne00; i++) { - sum[0] += x_scalar[i]; - } - sum[0] /= ne00; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - const float mean = sum[0]; - const float scale = 1.0f/sqrt(mean + eps); - - device float4 * y = (device float4 *) (dst + tgpig*ne00); - device float * y_scalar = (device float *) y; - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - y[i00] = x[i00] * scale; - } - if (tpitg == 0) { - for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { - y_scalar[i00] = x_scalar[i00] * scale; - } - } -} - -// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float2 acc = 0.f; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (sumy * -8.f + acc[0] + acc[1]); -} - -// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); - float2 acc = 0.f; - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (acc[0] + acc[1]) + sumy * m; -} - -// putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 -//Note: This is a template, but strictly speaking it only applies to -// quantizations where the block size is 32. It also does not -// giard against the number of rows not being divisible by -// N_DST, so this is another explicit assumption of the implementation. -template -void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, - int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, - uint3 tgpig, uint tiisg, uint sgitg) { - const int nb = ne00/QK4_0; - - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr; - - const uint offset0 = first_row * nb + im/gqa*(nb*ne0); - - device const block_q_type * x = (device const block_q_type *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; // src1 vector cache - float sumf[nr] = {0.f}; - - const int ix = (tiisg/2); - const int il = (tiisg%2)*8; - - device const float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { - float sumy = 0; - for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; - - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; - } - - for (int row = 0; row < nr; row++) { - sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); - } - - yb += QK4_0 * 16; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; - } - } -} - -kernel void kernel_mul_mv_q4_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q4_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); -} - -#define NB_Q8_0 8 - -kernel void kernel_mul_mv_q8_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int nr = N_DST; - const int nsg = N_SIMDGROUP; - const int nw = N_SIMDWIDTH; - - const int nb = ne00/QK8_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr; - const uint offset0 = first_row * nb + im/gqa*(nb*ne0); - device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[NB_Q8_0]; - float sumf[nr]={0.f}; - - const int ix = tiisg/4; - const int il = tiisg%4; - - device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; - - // each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (int ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { - yl[i] = yb[i]; - } - - for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; - float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { - sumq += qs[iq] * yl[iq]; - } - sumf[row] += sumq*x[ib+row*nb].d; - } - - yb += NB_Q8_0 * nw; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; - } - } -} - -#define N_F32_F32 4 - -kernel void kernel_mul_mv_f32_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F32_F32; - const int64_t im = tgpig.z; - - device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - - if (ne00 < 128) { - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const float4 * x4 = (device const float4 *)x; - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -kernel void kernel_mul_mv_f16_f32_1row( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - if (ne00 < 128) { - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } else { - device const half4 * x4 = (device const half4 *) x; - device const float4 * y4 = (device const float4 *) y; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; - } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - -} - -#define N_F16_F32 4 - -kernel void kernel_mul_mv_f16_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F32; - const int64_t im = tgpig.z; - - device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -// Assumes row size (ne00) is a multiple of 4 -kernel void kernel_mul_mv_f16_f32_l4( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int nrows = ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; - - device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - - for (int r1 = 0; r1 < nrows; ++r1) { - device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -kernel void kernel_alibi_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant float & m0, - constant float & m1, - constant int & n_heads_log2_floor, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - float m_k; - if (i2 < n_heads_log2_floor) { - m_k = pow(m0, i2 + 1); - } else { - m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1); - } - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1); - } -} - -typedef void (rope_t)( - device const void * src0, - device const int32_t * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant float & freq_base, - constant float & freq_scale, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]); - -template -kernel void kernel_rope( - device const void * src0, - device const int32_t * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant float & freq_base, - constant float & freq_scale, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; - - const bool is_neox = mode & 2; - - device const int32_t * pos = src1; - - const int64_t p = pos[i2]; - - const float theta_0 = freq_scale * (float)p; - const float inv_ndims = -1.f/n_dims; - - if (!is_neox) { - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - - const float theta = theta_0 * pow(freq_base, inv_ndims*i0); - const float cos_theta = cos(theta); - const float sin_theta = sin(theta); - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const T x0 = src[0]; - const T x1 = src[1]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } else { - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { - - const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib); - const float cos_theta = cos(theta); - const float sin_theta = sin(theta); - - const int64_t i0 = ib*n_dims + ic/2; - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } - } - } -} - -template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; -template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; - -kernel void kernel_cpy_f16_f16( - device const half * src0, - device half * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f16( - device const float * src0, - device half * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_concat( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i02 < ne02) { - ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; - src0_ptr += ntg.x*nb00; - } else { - ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; - src1_ptr += ntg.x*nb10; - } - dst_ptr += ntg.x*nb0; - } -} - -//============================================ k-quants ====================================================== - -#ifndef QK_K -#define QK_K 256 -#else -static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); -#endif - -#if QK_K == 256 -#define K_SCALE_SIZE 12 -#else -#define K_SCALE_SIZE 4 -#endif - -typedef struct { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins -} block_q2_K; -// 84 bytes / block - -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits -#if QK_K == 64 - uint8_t scales[2]; -#else - uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits -#endif - half d; // super-block scale -} block_q3_K; - -#if QK_K == 64 -typedef struct { - half d[2]; // super-block scales/mins - uint8_t scales[2]; - uint8_t qs[QK_K/2]; // 4-bit quants -} block_q4_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -#endif - -#if QK_K == 64 -typedef struct { - half d; // super-block scales/mins - int8_t scales[QK_K/16]; // 8-bit block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -#else -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -// 176 bytes / block -#endif - -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - half d; // super-block scale -} block_q6_K; -// 210 bytes / block - -static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { - uchar4 r; - if (j < 4) { - r[0] = q[j+0] & 63; - r[2] = q[j+1] & 63; - r[1] = q[j+4] & 63; - r[3] = q[j+5] & 63; - } else { - r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); - r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4); - r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); - r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4); - } - return r; -} - -//====================================== dot products ========================= - -kernel void kernel_mul_mv_q2_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int r2 = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q2_K) * nb; - -#if QK_K == 256 - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int im = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 - - device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - float dall = dh[0]; - float dmin = dh[1] * 1.f/16.f; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - - qs += step/2; - sc += step; - dh += step/2; - } - - y4 += 4 * QK_K; - } -#else - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0...1 - - device const float * y4 = y + ix * QK_K + 8 * it; - - for (int ib = ix; ib < nb; ib += 16) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); - - qs += step/2; - sc += step; - dh += step/2; - } - - y4 += 16 * QK_K; - } -#endif - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -#if QK_K == 256 -kernel void kernel_mul_mv_q3_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t r2 = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - - float yl[32]; - - //const uint16_t kmask1 = 0x3030; - //const uint16_t kmask2 = 0x0f0f; - - const int tid = tiisg/4; - const int ix = tiisg%4; - const int ip = tid/4; // 0 or 1 - const int il = 2*((tid%4)/2); // 0 or 2 - const int ir = tid%2; - const int n = 8; - const int l0 = n*ir; - - // One would think that the Metal compiler would figure out that ip and il can only have - // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it - // with these two tales. - // - // Possible masks for the high bit - const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 - {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 - {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 - {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 - - // Possible masks for the low 2 bits - const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; - - const ushort4 hm = mm[2*ip + il/2]; - - const int shift = 2*il; - const float v1 = il == 0 ? 4.f : 64.f; - const float v2 = 4.f * v1; - - const uint16_t s_shift1 = 4*ip; - const uint16_t s_shift2 = s_shift1 + il; - - const int q_offset = 32*ip + l0; - const int y_offset = 128*ip + 32*il + l0; - - const int step = sizeof(block_q3_K) * nb / 2; - - device const float * y1 = yy + ix*QK_K + y_offset; - - uint32_t scales32, aux32; - thread uint16_t * scales16 = (thread uint16_t *)&scales32; - thread const int8_t * scales = (thread const int8_t *)&scales32; - - float sumf1[2] = {0.f}; - float sumf2[2] = {0.f}; - for (int i = ix; i < nb; i += 4) { - - for (int l = 0; l < 8; ++l) { - yl[l+ 0] = y1[l+ 0]; - yl[l+ 8] = y1[l+16]; - yl[l+16] = y1[l+32]; - yl[l+24] = y1[l+48]; - } - - device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); - device const uint16_t * a = (device const uint16_t *)(x[i].scales); - device const half * dh = &x[i].d; - - for (int row = 0; row < 2; ++row) { - - const float d_all = (float)dh[0]; - - scales16[0] = a[4]; - scales16[1] = a[5]; - aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; - scales16[0] = a[il+0]; - scales16[1] = a[il+1]; - scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; - - float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; - for (int l = 0; l < n; l += 2) { - const int32_t qs = q[l/2]; - s1 += yl[l+0] * (qs & qm[il/2][0]); - s2 += yl[l+1] * (qs & qm[il/2][1]); - s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); - s4 += yl[l+16] * (qs & qm[il/2][2]); - s5 += yl[l+17] * (qs & qm[il/2][3]); - s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); - } - float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); - float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); - sumf1[row] += d1 * (scales[0] - 32); - sumf2[row] += d2 * (scales[2] - 32); - - s1 = s2 = s3 = s4 = s5 = s6 = 0; - for (int l = 0; l < n; l += 2) { - const int32_t qs = q[l/2+8]; - s1 += yl[l+8] * (qs & qm[il/2][0]); - s2 += yl[l+9] * (qs & qm[il/2][1]); - s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); - s4 += yl[l+24] * (qs & qm[il/2][2]); - s5 += yl[l+25] * (qs & qm[il/2][3]); - s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]); - } - d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); - d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); - sumf1[row] += d1 * (scales[1] - 32); - sumf2[row] += d2 * (scales[3] - 32); - - q += step; - h += step; - a += step; - dh += step; - - } - - y1 += 4 * QK_K; - - } - - for (int row = 0; row < 2; ++row) { - const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); - sumf1[row] = simd_sum(sumf); - } - if (tiisg == 0) { - for (int row = 0; row < 2; ++row) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row]; - } - } -} -#else -kernel void kernel_mul_mv_q3_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t r2 = tgpig.z; - - const int row = 2 * r0 + sgitg; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - const int ix = tiisg/4; - const int il = 4 * (tiisg%4);// 0, 4, 8, 12 - const int im = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - float2 sum = {0.f, 0.f}; - - for (int i = ix; i < nb; i += 8) { - - const float d_all = (float)(x[i].d); - - device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); - device const uint16_t * s = (device const uint16_t *)(x[i].scales); - device const float * y = yy + i * QK_K + il; - - const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); - const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; - const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; - const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; - - for (int l = 0; l < 4; l += 2) { - const uint16_t hm = h[l/2] >> im; - sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) - + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) - + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) - + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); - sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) - + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) - + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) - + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536)); - } - - } - const float sumf = sum[0] + sum[1] * 1.f/256.f; - - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + row] = tot; - } - -} -#endif - -#if QK_K == 256 -kernel void kernel_mul_mv_q4_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01 [[buffer(4)]], - constant int64_t & ne02 [[buffer(5)]], - constant int64_t & ne10 [[buffer(9)]], - constant int64_t & ne12 [[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & gqa [[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int im = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int r2 = tgpig.z; - //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - float yl[16]; - float yh[16]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir; - - uint16_t sc16[4]; - thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; - yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; - yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; - yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; - } - - device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im; - device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & kmask1; - sc16[1] = sc[2] & kmask1; - sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); - sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); - - device const uint16_t * q2 = q1 + 32; - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); - acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); - acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); - acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); - acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); - acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); - acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); - acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + - (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + - (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + - (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - - q1 += step; - sc += step; - dh += step; - } - - y4 += 4 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; - } - } -} -#else -kernel void kernel_mul_mv_q4_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int ix = tiisg/4; // 0...7 - const int it = tiisg%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int r2 = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - float yl[8]; - float yh[8]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 8 * it; - - uint16_t sc16[4]; - - for (int ib = ix; ib < nb; ib += 8) { - - float2 sumy = {0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i] = y4[i+ 0]; sumy[0] += yl[i]; - yh[i] = y4[i+32]; sumy[1] += yh[i]; - } - - device const uint16_t * sc = (device const uint16_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & 0x000f; - sc16[1] = sc[0] & 0x0f00; - sc16[2] = sc[0] & 0x00f0; - sc16[3] = sc[0] & 0xf000; - - float2 acc1 = {0.f, 0.f}; - float2 acc2 = {0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); - acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); - acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); - acc2[1] += yh[i+1] * (qs[i/2] & 0xF000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + - (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - - dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); - - qs += step; - sc += step; - dh += step; - } - - y4 += 8 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum; - } - } -} -#endif - -kernel void kernel_mul_mv_q5_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int r2 = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - - float sumf[2]={0.f}; - - const int step = sizeof(block_q5_K) * nb; - -#if QK_K == 256 -# - float yl[16], yh[16]; - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = tiisg/4; - const int ix = tiisg%4; - const int im = tid/4; - const int ir = tid%4; - const int n = 8; - - const int l0 = n*ir; - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; - - const uint8_t hm1 = 1u << (2*im); - const uint8_t hm2 = hm1 << 1; - const uint8_t hm3 = hm1 << 4; - const uint8_t hm4 = hm2 << 4; - - uint16_t sc16[4]; - thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - - device const float * y1 = yy + ix*QK_K + y_offset; - - for (int i = ix; i < nb; i += 4) { - - device const uint8_t * q1 = x[i].qs + q_offset; - device const uint8_t * qh = x[i].qh + l0; - device const half * dh = &x[i].d; - device const uint16_t * a = (device const uint16_t *)x[i].scales + im; - - device const float * y2 = y1 + 128; - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 8; ++l) { - yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; - yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; - yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; - yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; - } - - for (int row = 0; row < 2; ++row) { - - device const uint8_t * q2 = q1 + 64; - - sc16[0] = a[0] & kmask1; - sc16[1] = a[2] & kmask1; - sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); - sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); - - float4 acc1 = {0.f}; - float4 acc2 = {0.f}; - for (int l = 0; l < n; ++l) { - uint8_t h = qh[l]; - acc1[0] += yl[l+0] * (q1[l] & 0x0F); - acc1[1] += yl[l+8] * (q1[l] & 0xF0); - acc1[2] += yh[l+0] * (q2[l] & 0x0F); - acc1[3] += yh[l+8] * (q2[l] & 0xF0); - acc2[0] += h & hm1 ? yl[l+0] : 0.f; - acc2[1] += h & hm2 ? yl[l+8] : 0.f; - acc2[2] += h & hm3 ? yh[l+0] : 0.f; - acc2[3] += h & hm4 ? yh[l+8] : 0.f; - } - const float dall = dh[0]; - const float dmin = dh[1]; - sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + - sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + - sc8[4] * (acc1[2] + 16.f*acc2[2]) + - sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - - q1 += step; - qh += step; - dh += step/2; - a += step/2; - - } - - y1 += 4 * QK_K; - - } -#else - float yl[8], yh[8]; - - const int il = 4 * (tiisg/8); // 0, 4, 8, 12 - const int ix = tiisg%8; - const int im = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - device const float * y = yy + ix*QK_K + il; - - for (int i = ix; i < nb; i += 8) { - - for (int l = 0; l < 4; ++l) { - yl[l+0] = y[l+ 0]; - yl[l+4] = y[l+16]; - yh[l+0] = y[l+32]; - yh[l+4] = y[l+48]; - } - - device const half * dh = &x[i].d; - device const uint8_t * q = x[i].qs + il; - device const uint8_t * h = x[i].qh + in; - device const int8_t * s = x[i].scales; - - for (int row = 0; row < 2; ++row) { - - const float d = dh[0]; - - float2 acc = {0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - const uint8_t hl = h[l] >> im; - acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) - + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); - acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) - + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256)); - } - sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); - - q += step; - h += step; - s += step; - dh += step/2; - - } - - y += 8 * QK_K; - } -#endif - - for (int row = 0; row < 2; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; - } - } - -} - -kernel void kernel_mul_mv_q6_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const uint8_t kmask1 = 0x03; - const uint8_t kmask2 = 0x0C; - const uint8_t kmask3 = 0x30; - const uint8_t kmask4 = 0xC0; - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int r2 = tgpig.z; - - const int row = 2 * r0 + sgitg; - const uint offset0 = r2/gqa*(nb*ne0); - device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; - - float sumf = 0; - -#if QK_K == 256 - const int tid = tiisg/2; - const int ix = tiisg%2; - const int ip = tid/8; // 0 or 1 - const int il = tid%8; - const int n = 4; - const int l0 = n*il; - const int is = 8*ip + l0/16; - - const int y_offset = 128*ip + l0; - const int q_offset_l = 64*ip + l0; - const int q_offset_h = 32*ip + l0; - - for (int i = ix; i < nb; i += 2) { - - device const uint8_t * q1 = x[i].ql + q_offset_l; - device const uint8_t * q2 = q1 + 32; - device const uint8_t * qh = x[i].qh + q_offset_h; - device const int8_t * sc = x[i].scales + is; - - device const float * y = yy + i * QK_K + y_offset; - - const float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - - sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); - - } - -#else - const int ix = tiisg/4; - const int il = 4*(tiisg%4); - - for (int i = ix; i < nb; i += 8) { - device const float * y = yy + i * QK_K + il; - device const uint8_t * ql = x[i].ql + il; - device const uint8_t * qh = x[i].qh + il; - device const int8_t * s = x[i].scales; - - const float d = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32); - sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); - } - -#endif - - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + row] = tot; - } -} - -//============================= templates and their specializations ============================= - -// NOTE: this is not dequantizing - we are simply fitting the template -template -void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { - float4x4 temp = *(((device float4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { - half4x4 temp = *(((device half4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float md = -8.h * xb->d; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; - - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; - reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; - } -} - -template -void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 2); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float m = xb->m; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; - - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; - reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; - } -} - -template -void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { - device const int8_t * qs = ((device const int8_t *)xb->qs); - const half d = xb->d; - - for (int i=0;i<16;i++) { - reg[i/4][i%4] = (qs[i + 16*il] * d); - } -} - -template -void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { - const half d = xb->d; - const half min = xb->dmin; - device const uint8_t * q = (device const uint8_t *)xb->qs; - half dl, ml; - uint8_t sc = xb->scales[il]; - -#if QK_K == 256 - q = q + 32*(il/8) + 16*(il&1); - il = (il/2)%4; -#endif - half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } -} - -template -void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { - const half d_all = xb->d; - device const uint8_t * q = (device const uint8_t *)xb->qs; - device const uint8_t * h = (device const uint8_t *)xb->hmask; - device const int8_t * scales = (device const int8_t *)xb->scales; - -#if QK_K == 256 - q = q + 32 * (il/8) + 16 * (il&1); - h = h + 16 * (il&1); - uint8_t m = 1 << (il/2); - uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ - ((il/4)>0 ? 12 : 3); - uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; - uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; - int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) - : (scale_2&kmask2) | ((scale_1&kmask1) << 4); - half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h); - const half ml = 4.h * dl; - - il = (il/2) & 3; - const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl *= coef; - - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); - } -#else - float kcoef = il&1 ? 1.f/16.f : 1.f; - uint16_t kmask = il&1 ? 0xF0 : 0x0F; - float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8); - float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - uint8_t m = 1<<(il*2); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef)); - } -#endif -} - -static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { - return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} - : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; -} - -template -void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { - device const uchar * q = xb->qs; - -#if QK_K == 256 - short is = (il/4) * 2; - q = q + (il/4) * 32 + 16 * (il&1); - il = il & 3; - const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const half d = il < 2 ? xb->d : xb->d / 16.h; - const half min = xb->dmin; - const half dl = d * sc[0]; - const half ml = min * sc[1]; -#else - q = q + 16 * (il&1); - device const uint8_t * s = xb->scales; - device const half2 * dh = (device const half2 *)xb->d; - const float2 d = (float2)dh[0]; - const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; - const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4); -#endif - const ushort mask = il<2 ? 0x0F : 0xF0; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } -} - -template -void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { - device const uint8_t * q = xb->qs; - device const uint8_t * qh = xb->qh; - -#if QK_K == 256 - short is = (il/4) * 2; - q = q + 32 * (il/4) + 16 * (il&1); - qh = qh + 16 * (il&1); - uint8_t ul = 1 << (il/2); - il = il & 3; - const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const half d = il < 2 ? xb->d : xb->d / 16.h; - const half min = xb->dmin; - const half dl = d * sc[0]; - const half ml = min * sc[1]; - - const ushort mask = il<2 ? 0x0F : 0xF0; - const half qh_val = il<2 ? 16.h : 256.h; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; - } -#else - q = q + 16 * (il&1); - device const int8_t * s = xb->scales; - const float dl = xb->d * s[il]; - uint8_t m = 1<<(il*2); - const float coef = il<2 ? 1.f : 1.f/16.f; - const ushort mask = il<2 ? 0x0F : 0xF0; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef)); - } -#endif -} - -template -void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { - const half d_all = xb->d; - device const uint8_t * ql = (device const uint8_t *)xb->ql; - device const uint8_t * qh = (device const uint8_t *)xb->qh; - device const int8_t * scales = (device const int8_t *)xb->scales; - -#if QK_K == 256 - ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); - qh = qh + 32*(il/8) + 16*(il&1); - half sc = scales[(il%2) + 2 * ((il/2))]; - il = (il/2) & 3; -#else - ql = ql + 16 * (il&1); - half sc = scales[il]; -#endif - const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; - const half coef = il>1 ? 1.f/16.h : 1.h; - const half ml = d_all * sc * 32.h; - const half dl = d_all * sc * coef; - for (int i = 0; i < 16; ++i) { - const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) - : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); - reg[i/4][i%4] = dl * q - ml; - } -} - -template -kernel void kernel_get_rows( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tptg[[threads_per_threadgroup]]) { - const int i = tgpig; - const int r = ((device int32_t *) src1)[i]; - - for (int ind = tiitg; ind < ne00/16; ind += tptg) { - float4x4 temp; - dequantize_func( - ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp; - } -} - -#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A -#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B -#define BLOCK_SIZE_K 32 -#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A -#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B -#define THREAD_PER_BLOCK 128 -#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers -#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers -#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 -#define SG_MAT_ROW 8 - -// each block_q contains 16*nl weights -template -kernel void kernel_mul_mm(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & ne12, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & gqa, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - const uint im = tgpig.z; - - // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; - - // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; - for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); - } - - short il = (tiitg % THREAD_PER_ROW); - - uint offset0 = im/gqa*nb02; - ushort offset1 = il/nl; - - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * im - + nb11 * (r1 * BLOCK_SIZE_N + thread_col) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - // load data and store to threadgroup memory - half4x4 temp_a; - dequantize_func(x, il, temp_a); - threadgroup_barrier(mem_flags::mem_threadgroup); - - #pragma unroll(16) - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; - } - - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); - - il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; - y += BLOCK_SIZE_K; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); - - #pragma unroll(4) - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - #pragma unroll(4) - for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); - } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) - for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); - } - - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - - #pragma unroll(8) - for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); - } - } - } - - if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); - } - } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; - if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); - } - } - } - } -} - -#if QK_K == 256 -#define QK_NL 16 -#else -#define QK_NL 4 -#endif - -typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ - constant uint64_t &, constant uint64_t &, uint, uint, uint); - -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; - -typedef void (mat_mm_t)( - device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & ne12, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & gqa, - threadgroup uchar *, uint3, uint, uint); - -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml.c b/plugins/wasi_nn/thirdparty/ggml/ggml.c deleted file mode 100644 index 630deb49..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/ggml.c +++ /dev/null @@ -1,22041 +0,0 @@ -#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows - -#include "ggml.h" - -#ifdef GGML_USE_K_QUANTS -#include "k_quants.h" -#endif - -#if defined(_MSC_VER) || defined(__MINGW32__) -#include // using malloc.h with MSC/MINGW -#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) -#include -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef GGML_USE_METAL -#include -#endif - -// static_assert should be a #define, but if it's not, -// fall back to the _Static_assert C11 keyword. -// if C99 - static_assert is noop -// ref: https://stackoverflow.com/a/53923785/4039976 -#ifndef static_assert -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) -#define static_assert(cond, msg) _Static_assert(cond, msg) -#else -#define static_assert(cond, msg) struct global_scope_noop_trick -#endif -#endif - -#if defined(_MSC_VER) -// disable "possible loss of data" to avoid hundreds of casts -// we should just be careful :) -#pragma warning(disable: 4244 4267) - -// disable POSIX deprecation warnigns -// these functions are never going away, anyway -#pragma warning(disable: 4996) -#endif - -#if defined(_WIN32) - -#include - -typedef volatile LONG atomic_int; -typedef atomic_int atomic_bool; - -static void atomic_store(atomic_int * ptr, LONG val) { - InterlockedExchange(ptr, val); -} -static LONG atomic_load(atomic_int * ptr) { - return InterlockedCompareExchange(ptr, 0, 0); -} -static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) { - return InterlockedExchangeAdd(ptr, inc); -} -static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) { - return atomic_fetch_add(ptr, -(dec)); -} - -typedef HANDLE pthread_t; - -typedef DWORD thread_ret_t; -static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) { - (void) unused; - HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); - if (handle == NULL) - { - return EAGAIN; - } - - *out = handle; - return 0; -} - -static int pthread_join(pthread_t thread, void * unused) { - (void) unused; - int ret = (int) WaitForSingleObject(thread, INFINITE); - CloseHandle(thread); - return ret; -} - -static int sched_yield (void) { - Sleep (0); - return 0; -} -#else -#include -#include - -typedef void * thread_ret_t; - -#include -#include -#include - -#endif -#ifdef GGML_USE_CPU_HBM -#include -#endif - -// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 -#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) -#ifndef __FMA__ -#define __FMA__ -#endif -#ifndef __F16C__ -#define __F16C__ -#endif -#ifndef __SSE3__ -#define __SSE3__ -#endif -#endif - -/*#define GGML_PERF*/ -#define GGML_DEBUG 0 -#define GGML_GELU_FP16 -#define GGML_GELU_QUICK_FP16 -#define GGML_SILU_FP16 -// #define GGML_CROSS_ENTROPY_EXP_FP16 -// #define GGML_FLASH_ATTN_EXP_FP16 - -#define GGML_SOFT_MAX_UNROLL 4 -#define GGML_VEC_DOT_UNROLL 2 -#define GGML_VEC_MAD_UNROLL 32 - -// -// logging -// - -#if (GGML_DEBUG >= 1) -#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG(...) -#endif - -#if (GGML_DEBUG >= 5) -#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_5(...) -#endif - -#if (GGML_DEBUG >= 10) -#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_10(...) -#endif - -#define GGML_PRINT(...) printf(__VA_ARGS__) - -// -// end of logging block -// - -#ifdef GGML_USE_ACCELERATE -// uncomment to use vDSP for soft max computation -// note: not sure if it is actually faster -//#define GGML_SOFT_MAX_ACCELERATE -#endif - -#if defined(_MSC_VER) || defined(__MINGW32__) -#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN) -#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr) -#else -inline static void * ggml_aligned_malloc(size_t size) { - if (size == 0) { - GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n"); - return NULL; - } - void * aligned_memory = NULL; -#ifdef GGML_USE_CPU_HBM - int result = hbw_posix_memalign(&aligned_memory, 16, size); -#elif GGML_USE_METAL - int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size); -#else - int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size); -#endif - if (result != 0) { - // Handle allocation failure - const char *error_desc = "unknown allocation error"; - switch (result) { - case EINVAL: - error_desc = "invalid alignment value"; - break; - case ENOMEM: - error_desc = "insufficient memory"; - break; - } - GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0)); - return NULL; - } - return aligned_memory; -} -#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size) -#ifdef GGML_USE_CPU_HBM -#define GGML_ALIGNED_FREE(ptr) if(NULL != ptr) hbw_free(ptr) -#else -#define GGML_ALIGNED_FREE(ptr) free(ptr) -#endif -#endif - -#define UNUSED GGML_UNUSED -#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) - -// -// tensor access macros -// - -#define GGML_TENSOR_UNARY_OP_LOCALS \ - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - -#define GGML_TENSOR_BINARY_OP_LOCALS \ - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ - GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - -#if defined(GGML_USE_ACCELERATE) -#include -#if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions -#include "ggml-opencl.h" -#endif -#elif defined(GGML_USE_OPENBLAS) -#if defined(GGML_BLAS_USE_MKL) -#include -#else -#include -#endif -#elif defined(GGML_USE_CUBLAS) -#include "ggml-cuda.h" -#elif defined(GGML_USE_CLBLAST) -#include "ggml-opencl.h" -#endif - -#undef MIN -#undef MAX -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -// floating point type used to accumulate sums -typedef double ggml_float; - -// 16-bit float -// on Arm, we use __fp16 -// on x86, we use uint16_t -#if defined(__ARM_NEON) && !defined(_MSC_VER) - -// if YCM cannot find , make a symbolic link to it, for example: -// -// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ -// -#include - -#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) -#define GGML_COMPUTE_FP32_TO_FP16(x) (x) - -#define GGML_FP16_TO_FP32(x) ((float) (x)) -#define GGML_FP32_TO_FP16(x) (x) - -#else - -#ifdef __wasm_simd128__ -#include -#else -#ifdef __POWER9_VECTOR__ -#include -#undef bool -#define bool _Bool -#else -#if defined(_MSC_VER) || defined(__MINGW32__) -#include -#else -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) -#if !defined(__riscv) -#include -#endif -#endif -#endif -#endif -#endif - -#ifdef __riscv_v_intrinsic -#include -#endif - -#ifdef __F16C__ - -#ifdef _MSC_VER -#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) -#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) -#else -#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) -#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) -#endif - -#elif defined(__POWER9_VECTOR__) - -#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) -#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) -/* the inline asm below is about 12% faster than the lookup method */ -#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) -#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) - -static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { - register float f; - register double d; - __asm__( - "mtfprd %0,%2\n" - "xscvhpdp %0,%0\n" - "frsp %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=f"(f): - /* in */ "r"(h)); - return f; -} - -static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { - register double d; - register ggml_fp16_t r; - __asm__( /* xscvdphp can work on double or single precision */ - "xscvdphp %0,%2\n" - "mffprd %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=r"(r): - /* in */ "f"(f)); - return r; -} - -#else - -// FP16 <-> FP32 -// ref: https://github.com/Maratyszcza/FP16 - -static inline float fp32_from_bits(uint32_t w) { - union { - uint32_t as_bits; - float as_value; - } fp32; - fp32.as_bits = w; - return fp32.as_value; -} - -static inline uint32_t fp32_to_bits(float f) { - union { - float as_value; - uint32_t as_bits; - } fp32; - fp32.as_value = f; - return fp32.as_bits; -} - -static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { - const uint32_t w = (uint32_t) h << 16; - const uint32_t sign = w & UINT32_C(0x80000000); - const uint32_t two_w = w + w; - - const uint32_t exp_offset = UINT32_C(0xE0) << 23; -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float exp_scale = 0x1.0p-112f; -#else - const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); -#endif - const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; - - const uint32_t magic_mask = UINT32_C(126) << 23; - const float magic_bias = 0.5f; - const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; - - const uint32_t denormalized_cutoff = UINT32_C(1) << 27; - const uint32_t result = sign | - (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); - return fp32_from_bits(result); -} - -static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float scale_to_inf = 0x1.0p+112f; - const float scale_to_zero = 0x1.0p-110f; -#else - const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); - const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); -#endif - float base = (fabsf(f) * scale_to_inf) * scale_to_zero; - - const uint32_t w = fp32_to_bits(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } - - base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = fp32_to_bits(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); -} - -#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) -#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) - -#endif // __F16C__ - -#endif // __ARM_NEON - -// -// global data -// - -// precomputed gelu table for f16 (128 KB) -static ggml_fp16_t table_gelu_f16[1 << 16]; - -// precomputed quick gelu table for f16 (128 KB) -static ggml_fp16_t table_gelu_quick_f16[1 << 16]; - -// precomputed silu table for f16 (128 KB) -static ggml_fp16_t table_silu_f16[1 << 16]; - -// precomputed exp table for f16 (128 KB) -static ggml_fp16_t table_exp_f16[1 << 16]; - -// precomputed f32 table for f16 (256 KB) -static float table_f32_f16[1 << 16]; - -#if defined(__ARM_NEON) || defined(__wasm_simd128__) -#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s -#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) -#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) -#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) -#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) -#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) -#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) -#define B8(c,s ) B7(c,s, c), B7(c,s, s) - -// precomputed tables for expanding 8bits to 8 bytes: -static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 -static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 -#endif - -// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, -// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. -// This is also true for POWER9. -#if !defined(GGML_FP16_TO_FP32) || !defined(GGML_FP32_TO_FP16) - -inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { - uint16_t s; - memcpy(&s, &f, sizeof(uint16_t)); - return table_f32_f16[s]; -} - -#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) -#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) - -#endif - -// note: do not use these inside ggml.c -// these are meant to be used via the ggml.h API -float ggml_fp16_to_fp32(ggml_fp16_t x) { - return (float) GGML_FP16_TO_FP32(x); -} - -ggml_fp16_t ggml_fp32_to_fp16(float x) { - return GGML_FP32_TO_FP16(x); -} - -void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n) { - for (int i = 0; i < n; i++) { - y[i] = GGML_FP16_TO_FP32(x[i]); - } -} - -void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n) { - int i = 0; -#if defined(__F16C__) - for (; i + 7 < n; i += 8) { - __m256 x_vec = _mm256_loadu_ps(x + i); - __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); - _mm_storeu_si128((__m128i *)(y + i), y_vec); - } - for(; i + 3 < n; i += 4) { - __m128 x_vec = _mm_loadu_ps(x + i); - __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); - _mm_storel_epi64((__m128i *)(y + i), y_vec); - } -#endif - for (; i < n; i++) { - y[i] = GGML_FP32_TO_FP16(x[i]); - } -} - -// -// timing -// - -#if defined(_MSC_VER) || defined(__MINGW32__) -static int64_t timer_freq, timer_start; -void ggml_time_init(void) { - LARGE_INTEGER t; - QueryPerformanceFrequency(&t); - timer_freq = t.QuadPart; - - // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq - // and the uptime is high enough. - // We subtract the program start time to reduce the likelihood of that happening. - QueryPerformanceCounter(&t); - timer_start = t.QuadPart; -} -int64_t ggml_time_ms(void) { - LARGE_INTEGER t; - QueryPerformanceCounter(&t); - return ((t.QuadPart-timer_start) * 1000) / timer_freq; -} -int64_t ggml_time_us(void) { - LARGE_INTEGER t; - QueryPerformanceCounter(&t); - return ((t.QuadPart-timer_start) * 1000000) / timer_freq; -} -#else -void ggml_time_init(void) {} -int64_t ggml_time_ms(void) { - struct timespec ts; - clock_gettime(CLOCK_MONOTONIC, &ts); - return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000; -} - -int64_t ggml_time_us(void) { - struct timespec ts; - clock_gettime(CLOCK_MONOTONIC, &ts); - return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000; -} -#endif - -int64_t ggml_cycles(void) { - return clock(); -} - -int64_t ggml_cycles_per_ms(void) { - return CLOCKS_PER_SEC/1000; -} - -#ifdef GGML_PERF -#define ggml_perf_time_ms() ggml_time_ms() -#define ggml_perf_time_us() ggml_time_us() -#define ggml_perf_cycles() ggml_cycles() -#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms() -#else -#define ggml_perf_time_ms() 0 -#define ggml_perf_time_us() 0 -#define ggml_perf_cycles() 0 -#define ggml_perf_cycles_per_ms() 0 -#endif - - -// -// cache line -// - -#if defined(__cpp_lib_hardware_interference_size) -#define CACHE_LINE_SIZE hardware_destructive_interference_size -#else -#if defined(__POWER9_VECTOR__) -#define CACHE_LINE_SIZE 128 -#else -#define CACHE_LINE_SIZE 64 -#endif -#endif - -static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); - -// -// quantization -// - -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) -// multiply int8_t, add results pairwise twice -static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { - // Get absolute values of x vectors - const __m128i ax = _mm_sign_epi8(x, x); - // Sign the values of the y vectors - const __m128i sy = _mm_sign_epi8(y, x); - // Perform multiplication and create 16-bit values - const __m128i dot = _mm_maddubs_epi16(ax, sy); - const __m128i ones = _mm_set1_epi16(1); - return _mm_madd_epi16(ones, dot); -} - -#if __AVX__ || __AVX2__ || __AVX512F__ -// horizontally add 8 floats -static inline float hsum_float_8(const __m256 x) { - __m128 res = _mm256_extractf128_ps(x, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(x)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - return _mm_cvtss_f32(res); -} - -// horizontally add 8 int32_t -static inline int hsum_i32_8(const __m256i a) { - const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); - const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); - const __m128i sum64 = _mm_add_epi32(hi64, sum128); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - -// horizontally add 4 int32_t -static inline int hsum_i32_4(const __m128i a) { - const __m128i hi64 = _mm_unpackhi_epi64(a, a); - const __m128i sum64 = _mm_add_epi32(hi64, a); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - -#if defined(__AVX2__) || defined(__AVX512F__) -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m256i shuf_mask = _mm256_set_epi64x( - 0x0303030303030303, 0x0202020202020202, - 0x0101010101010101, 0x0000000000000000); - __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); - const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); - bytes = _mm256_or_si256(bytes, bit_mask); - return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) -{ - const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); - const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); - const __m256i lowMask = _mm256_set1_epi8( 0xF ); - return _mm256_and_si256(lowMask, bytes); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m256i x) { - const __m256i ones = _mm256_set1_epi16(1); - const __m256i summed_pairs = _mm256_madd_epi16(ones, x); - return _mm256_cvtepi32_ps(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if __AVXVNNI__ - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); - return _mm256_cvtepi32_ps(summed_pairs); -#else - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - return sum_i16_pairs_float(dot); -#endif -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { -#if __AVXVNNIINT8__ - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); - return _mm256_cvtepi32_ps(summed_pairs); -#else - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(x, x); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(y, x); - return mul_sum_us8_pairs_float(ax, sy); -#endif -} - -static inline __m128i packNibbles( __m256i bytes ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh -#if __AVX512F__ - const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 - bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh - return _mm256_cvtepi16_epi8(bytes); // abcd_efgh -#else - const __m256i lowByte = _mm256_set1_epi16( 0xFF ); - __m256i high = _mm256_andnot_si256( lowByte, bytes ); - __m256i low = _mm256_and_si256( lowByte, bytes ); - high = _mm256_srli_epi16( high, 4 ); - bytes = _mm256_or_si256( low, high ); - - // Compress uint16_t lanes into bytes - __m128i r0 = _mm256_castsi256_si128( bytes ); - __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); - return _mm_packus_epi16( r0, r1 ); -#endif -} -#elif defined(__AVX__) -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); - const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); - __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); - __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); - const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); - bytesl = _mm_or_si128(bytesl, bit_mask); - bytesh = _mm_or_si128(bytesh, bit_mask); - bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); - bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); - return MM256_SET_M128I(bytesh, bytesl); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) -{ - // Load 16 bytes from memory - __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); - __m128i tmph = _mm_srli_epi16(tmpl, 4); - const __m128i lowMask = _mm_set1_epi8(0xF); - tmpl = _mm_and_si128(lowMask, tmpl); - tmph = _mm_and_si128(lowMask, tmph); - return MM256_SET_M128I(tmph, tmpl); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { - const __m128i ones = _mm_set1_epi16(1); - const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); - const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); - const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); - return _mm256_cvtepi32_ps(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { - const __m128i axl = _mm256_castsi256_si128(ax); - const __m128i axh = _mm256_extractf128_si256(ax, 1); - const __m128i syl = _mm256_castsi256_si128(sy); - const __m128i syh = _mm256_extractf128_si256(sy, 1); - // Perform multiplication and create 16-bit values - const __m128i dotl = _mm_maddubs_epi16(axl, syl); - const __m128i doth = _mm_maddubs_epi16(axh, syh); - return sum_i16_pairs_float(doth, dotl); -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - const __m128i xl = _mm256_castsi256_si128(x); - const __m128i xh = _mm256_extractf128_si256(x, 1); - const __m128i yl = _mm256_castsi256_si128(y); - const __m128i yh = _mm256_extractf128_si256(y, 1); - // Get absolute values of x vectors - const __m128i axl = _mm_sign_epi8(xl, xl); - const __m128i axh = _mm_sign_epi8(xh, xh); - // Sign the values of the y vectors - const __m128i syl = _mm_sign_epi8(yl, xl); - const __m128i syh = _mm_sign_epi8(yh, xh); - // Perform multiplication and create 16-bit values - const __m128i dotl = _mm_maddubs_epi16(axl, syl); - const __m128i doth = _mm_maddubs_epi16(axh, syh); - return sum_i16_pairs_float(doth, dotl); -} - -static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh - const __m128i lowByte = _mm_set1_epi16( 0xFF ); - __m128i high = _mm_andnot_si128( lowByte, bytes1 ); - __m128i low = _mm_and_si128( lowByte, bytes1 ); - high = _mm_srli_epi16( high, 4 ); - bytes1 = _mm_or_si128( low, high ); - high = _mm_andnot_si128( lowByte, bytes2 ); - low = _mm_and_si128( lowByte, bytes2 ); - high = _mm_srli_epi16( high, 4 ); - bytes2 = _mm_or_si128( low, high ); - - return _mm_packus_epi16( bytes1, bytes2); -} -#endif -#elif defined(__SSSE3__) -// horizontally add 4x4 floats -static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { - __m128 res_0 =_mm_hadd_ps(a, b); - __m128 res_1 =_mm_hadd_ps(c, d); - __m128 res =_mm_hadd_ps(res_0, res_1); - res =_mm_hadd_ps(res, res); - res =_mm_hadd_ps(res, res); - - return _mm_cvtss_f32(res); -} -#endif // __AVX__ || __AVX2__ || __AVX512F__ -#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) - -#if defined(__ARM_NEON) - -#if !defined(__aarch64__) - -inline static int32_t vaddvq_s32(int32x4_t v) { - return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); -} - -inline static float vaddvq_f32(float32x4_t v) { - return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); -} - -inline static float vmaxvq_f32(float32x4_t v) { - return - MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), - MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); -} - -inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { - int32x4_t res; - - res[0] = roundf(vgetq_lane_f32(v, 0)); - res[1] = roundf(vgetq_lane_f32(v, 1)); - res[2] = roundf(vgetq_lane_f32(v, 2)); - res[3] = roundf(vgetq_lane_f32(v, 3)); - - return res; -} - -#endif -#endif - -#define QK4_0 32 -typedef struct { - ggml_fp16_t d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); - -#define QK4_1 32 -typedef struct { - ggml_fp16_t d; // delta - ggml_fp16_t m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; -static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); - -#define QK5_0 32 -typedef struct { - ggml_fp16_t d; // delta - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; -static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); - -#define QK5_1 32 -typedef struct { - ggml_fp16_t d; // delta - ggml_fp16_t m; // min - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_1 / 2]; // nibbles / quants -} block_q5_1; -static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); - -#define QK8_0 32 -typedef struct { - ggml_fp16_t d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); - -#define QK8_1 32 -typedef struct { - float d; // delta - float s; // d * sum(qs[i]) - int8_t qs[QK8_1]; // quants -} block_q8_1; -static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); - -// reference implementation for deterministic creation of model files -static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { - static const int qk = QK4_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; - } - } - - const float d = max / -8; - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < qk/2; ++j) { - const float x0 = x[i*qk + 0 + j]*id; - const float x1 = x[i*qk + qk/2 + j]*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); - - y[i].qs[j] = xi0; - y[i].qs[j] |= xi1 << 4; - } - } -} - -static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { - quantize_row_q4_0_reference(x, y, k); -} - -static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) { - const int qk = QK4_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + j]; - - if (v < min) min = v; - if (v > max) max = v; - } - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - y[i].m = GGML_FP32_TO_FP16(min); - - for (int j = 0; j < qk/2; ++j) { - const float x0 = (x[i*qk + 0 + j] - min)*id; - const float x1 = (x[i*qk + qk/2 + j] - min)*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); - - y[i].qs[j] = xi0; - y[i].qs[j] |= xi1 << 4; - } - } -} - -static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { - quantize_row_q4_1_reference(x, y, k); -} - -static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { - static const int qk = QK5_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; - } - } - - const float d = max / -16; - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - uint32_t qh = 0; - - for (int j = 0; j < qk/2; ++j) { - const float x0 = x[i*qk + 0 + j]*id; - const float x1 = x[i*qk + qk/2 + j]*id; - - const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); - const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); - - y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); - - // get the 5-th bit and store it in qh at the right position - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2); - } - - memcpy(&y[i].qh, &qh, sizeof(qh)); - } -} - -static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { - quantize_row_q5_0_reference(x, y, k); -} - -static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { - const int qk = QK5_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + j]; - - if (v < min) min = v; - if (v > max) max = v; - } - - const float d = (max - min) / ((1 << 5) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - y[i].m = GGML_FP32_TO_FP16(min); - - uint32_t qh = 0; - - for (int j = 0; j < qk/2; ++j) { - const float x0 = (x[i*qk + 0 + j] - min)*id; - const float x1 = (x[i*qk + qk/2 + j] - min)*id; - - const uint8_t xi0 = (uint8_t)(x0 + 0.5f); - const uint8_t xi1 = (uint8_t)(x1 + 0.5f); - - y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); - - // get the 5-th bit and store it in qh at the right position - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2); - } - - memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); - } -} - -static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { - quantize_row_q5_1_reference(x, y, k); -} - -// reference implementation for deterministic creation of model files -static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = x[i*QK8_0 + j]; - amax = MAX(amax, fabsf(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = x[i*QK8_0 + j]*id; - - y[i].qs[j] = roundf(x0); - } - } -} - -static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } - } -#elif defined(__wasm_simd128__) - for (int i = 0; i < nb; i++) { - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), - wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), - wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); - - y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); - y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); - y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); - y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); - } - } -#elif defined(__AVX2__) || defined(__AVX__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = maxScalar / 127.f; - y[i].d = GGML_FP32_TO_FP16(d); - const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)y[i].qs, i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); -#endif - } -#elif defined(__riscv_v_intrinsic) - - size_t vl = __riscv_vsetvl_e32m4(QK8_0); - - for (int i = 0; i < nb; i++) { - // load elements - vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl); - - vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); - vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); - float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); - - // convert to integer - vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); - vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); - - // store result - __riscv_vse8_v_i8m1(y[i].qs , vs, vl); - } -#else - // scalar - quantize_row_q8_0_reference(x, y, k); -#endif -} - -// reference implementation for deterministic creation of model files -static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { - assert(QK8_1 == 32); - assert(k % QK8_1 == 0); - const int nb = k / QK8_1; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_1; j++) { - const float v = x[i*QK8_1 + j]; - amax = MAX(amax, fabsf(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = d; - - int sum = 0; - - for (int j = 0; j < QK8_1/2; ++j) { - const float v0 = x[i*QK8_1 + j]*id; - const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id; - - y[i].qs[ j] = roundf(v0); - y[i].qs[QK8_1/2 + j] = roundf(v1); - - sum += y[i].qs[ j]; - sum += y[i].qs[QK8_1/2 + j]; - } - - y[i].s = sum*d; - } -} - -static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { - assert(k % QK8_1 == 0); - const int nb = k / QK8_1; - - block_q8_1 * restrict y = vy; - -#if defined(__ARM_NEON) - for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = d; - - int32x4_t accv = vdupq_n_s32(0); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - - accv = vaddq_s32(accv, vi); - } - - y[i].s = d * vaddvq_s32(accv); - } -#elif defined(__wasm_simd128__) - for (int i = 0; i < nb; i++) { - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), - wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), - wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = d; - - v128_t accv = wasm_i32x4_splat(0); - - for (int j = 0; j < 8; j++) { - const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); - - y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); - y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); - y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); - y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); - - accv = wasm_i32x4_add(accv, vi); - } - - y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) + - wasm_i32x4_extract_lane(accv, 1) + - wasm_i32x4_extract_lane(accv, 2) + - wasm_i32x4_extract_lane(accv, 3)); - } -#elif defined(__AVX2__) || defined(__AVX__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = maxScalar / 127.f; - y[i].d = d; - const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Compute the sum of the quants and set y[i].s - y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)y[i].qs, i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Compute the sum of the quants and set y[i].s - const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); - const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); - y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1)); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); -#endif - } -#elif defined(__riscv_v_intrinsic) - - size_t vl = __riscv_vsetvl_e32m4(QK8_1); - - for (int i = 0; i < nb; i++) { - // load elements - vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl); - - vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); - vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); - float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = d; - - vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); - - // convert to integer - vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); - vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); - - // store result - __riscv_vse8_v_i8m1(y[i].qs , vs, vl); - - // compute sum for y[i].s - vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); - vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl); - - // set y[i].s - int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); - y[i].s = sum*d; - } -#else - // scalar - quantize_row_q8_1_reference(x, y, k); -#endif -} - -static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) { - static const int qk = QK4_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0x0F) - 8; - const int x1 = (x[i].qs[j] >> 4) - 8; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } - } -} - -static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) { - static const int qk = QK4_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - const float m = GGML_FP16_TO_FP32(x[i].m); - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0x0F); - const int x1 = (x[i].qs[j] >> 4); - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; - } - } -} - -static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) { - static const int qk = QK5_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; - const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } - } -} - -static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) { - static const int qk = QK5_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - const float m = GGML_FP16_TO_FP32(x[i].m); - - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int x0 = (x[i].qs[j] & 0x0F) | xh_0; - const int x1 = (x[i].qs[j] >> 4) | xh_1; - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; - } - } -} - -static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) { - static const int qk = QK8_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - const block_q8_0 * restrict x = vx; - - for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - - for (int j = 0; j < qk; ++j) { - y[i*qk + j] = x[i].qs[j]*d; - } - } -} - -static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y); -static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y); -static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); - -static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { - [GGML_TYPE_I8] = { - .type_name = "i8", - .blck_size = 1, - .type_size = sizeof(int8_t), - .is_quantized = false, - }, - [GGML_TYPE_I16] = { - .type_name = "i16", - .blck_size = 1, - .type_size = sizeof(int16_t), - .is_quantized = false, - }, - [GGML_TYPE_I32] = { - .type_name = "i32", - .blck_size = 1, - .type_size = sizeof(int32_t), - .is_quantized = false, - }, - [GGML_TYPE_F32] = { - .type_name = "f32", - .blck_size = 1, - .type_size = sizeof(float), - .is_quantized = false, - .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, - .vec_dot_type = GGML_TYPE_F32, - }, - [GGML_TYPE_F16] = { - .type_name = "f16", - .blck_size = 1, - .type_size = sizeof(ggml_fp16_t), - .is_quantized = false, - .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, - .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row, - .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row, - .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, - .vec_dot_type = GGML_TYPE_F16, - }, - [GGML_TYPE_Q4_0] = { - .type_name = "q4_0", - .blck_size = QK4_0, - .type_size = sizeof(block_q4_0), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q4_0, - .from_float = quantize_row_q4_0, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference, - .vec_dot = ggml_vec_dot_q4_0_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, - }, - [GGML_TYPE_Q4_1] = { - .type_name = "q4_1", - .blck_size = QK4_1, - .type_size = sizeof(block_q4_1), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q4_1, - .from_float = quantize_row_q4_1, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference, - .vec_dot = ggml_vec_dot_q4_1_q8_1, - .vec_dot_type = GGML_TYPE_Q8_1, - }, - [GGML_TYPE_Q5_0] = { - .type_name = "q5_0", - .blck_size = QK5_0, - .type_size = sizeof(block_q5_0), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q5_0, - .from_float = quantize_row_q5_0, - .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference, - .vec_dot = ggml_vec_dot_q5_0_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, - }, - [GGML_TYPE_Q5_1] = { - .type_name = "q5_1", - .blck_size = QK5_1, - .type_size = sizeof(block_q5_1), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q5_1, - .from_float = quantize_row_q5_1, - .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference, - .vec_dot = ggml_vec_dot_q5_1_q8_1, - .vec_dot_type = GGML_TYPE_Q8_1, - }, - [GGML_TYPE_Q8_0] = { - .type_name = "q8_0", - .blck_size = QK8_0, - .type_size = sizeof(block_q8_0), - .is_quantized = true, - .to_float = dequantize_row_q8_0, - .from_float = quantize_row_q8_0, - .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference, - .vec_dot = ggml_vec_dot_q8_0_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, - }, - [GGML_TYPE_Q8_1] = { - .type_name = "q8_1", - .blck_size = QK8_1, - .type_size = sizeof(block_q8_1), - .is_quantized = true, - .from_float = quantize_row_q8_1, - .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference, - .vec_dot_type = GGML_TYPE_Q8_1, - }, -#ifdef GGML_USE_K_QUANTS - [GGML_TYPE_Q2_K] = { - .type_name = "q2_K", - .blck_size = QK_K, - .type_size = sizeof(block_q2_K), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q2_K, - .from_float = quantize_row_q2_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference, - .vec_dot = ggml_vec_dot_q2_K_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - }, - [GGML_TYPE_Q3_K] = { - .type_name = "q3_K", - .blck_size = QK_K, - .type_size = sizeof(block_q3_K), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q3_K, - .from_float = quantize_row_q3_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference, - .vec_dot = ggml_vec_dot_q3_K_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - }, - [GGML_TYPE_Q4_K] = { - .type_name = "q4_K", - .blck_size = QK_K, - .type_size = sizeof(block_q4_K), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q4_K, - .from_float = quantize_row_q4_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference, - .vec_dot = ggml_vec_dot_q4_K_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - }, - [GGML_TYPE_Q5_K] = { - .type_name = "q5_K", - .blck_size = QK_K, - .type_size = sizeof(block_q5_K), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q5_K, - .from_float = quantize_row_q5_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference, - .vec_dot = ggml_vec_dot_q5_K_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - }, - [GGML_TYPE_Q6_K] = { - .type_name = "q6_K", - .blck_size = QK_K, - .type_size = sizeof(block_q6_K), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q6_K, - .from_float = quantize_row_q6_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference, - .vec_dot = ggml_vec_dot_q6_K_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - }, - [GGML_TYPE_Q8_K] = { - .type_name = "q8_K", - .blck_size = QK_K, - .type_size = sizeof(block_q8_K), - .is_quantized = true, - .from_float = quantize_row_q8_K, - } -#endif -}; - -// For internal test use -ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { - GGML_ASSERT(type < GGML_TYPE_COUNT); - return type_traits[type]; -} - - -// -// simd mappings -// - -// we define a common set of C macros which map to specific intrinsics based on the current architecture -// we then implement the fundamental computation operations below using only these macros -// adding support for new architectures requires to define the corresponding SIMD macros -// -// GGML_F32_STEP / GGML_F16_STEP -// number of elements to process in a single step -// -// GGML_F32_EPR / GGML_F16_EPR -// number of elements to fit in a single register -// - -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) - -#define GGML_SIMD - -// F32 NEON - -#define GGML_F32_STEP 16 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 float32x4_t -#define GGML_F32x4_ZERO vdupq_n_f32(0.0f) -#define GGML_F32x4_SET1(x) vdupq_n_f32(x) -#define GGML_F32x4_LOAD vld1q_f32 -#define GGML_F32x4_STORE vst1q_f32 -#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) -#define GGML_F32x4_ADD vaddq_f32 -#define GGML_F32x4_MUL vmulq_f32 -#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f32(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f32(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f32(x[i], x[offset+i]); \ - } \ - res = GGML_F32x4_REDUCE_ONE(x[0]); \ -} - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 NEON - -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - #define GGML_F16_STEP 32 - #define GGML_F16_EPR 8 - - #define GGML_F16x8 float16x8_t - #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) - #define GGML_F16x8_SET1(x) vdupq_n_f16(x) - #define GGML_F16x8_LOAD vld1q_f16 - #define GGML_F16x8_STORE vst1q_f16 - #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) - #define GGML_F16x8_ADD vaddq_f16 - #define GGML_F16x8_MUL vmulq_f16 - #define GGML_F16x8_REDUCE(res, x) \ - do { \ - int offset = GGML_F16_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f16(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f16(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f16(x[i], x[offset+i]); \ - } \ - const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ - const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ - res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ - } while (0) - - #define GGML_F16_VEC GGML_F16x8 - #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO - #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 - #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) - #define GGML_F16_VEC_FMA GGML_F16x8_FMA - #define GGML_F16_VEC_ADD GGML_F16x8_ADD - #define GGML_F16_VEC_MUL GGML_F16x8_MUL - #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE -#else - // if FP16 vector arithmetic is not supported, we use FP32 instead - // and take advantage of the vcvt_ functions to convert to/from FP16 - - #define GGML_F16_STEP 16 - #define GGML_F16_EPR 4 - - #define GGML_F32Cx4 float32x4_t - #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) - #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) - #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16(x)) - #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) - #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) - #define GGML_F32Cx4_ADD vaddq_f32 - #define GGML_F32Cx4_MUL vmulq_f32 - #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE - - #define GGML_F16_VEC GGML_F32Cx4 - #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO - #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 - #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) - #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA - #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD - #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL - #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE -#endif - -#elif defined(__AVX__) - -#define GGML_SIMD - -// F32 AVX - -#define GGML_F32_STEP 32 -#define GGML_F32_EPR 8 - -#define GGML_F32x8 __m256 -#define GGML_F32x8_ZERO _mm256_setzero_ps() -#define GGML_F32x8_SET1(x) _mm256_set1_ps(x) -#define GGML_F32x8_LOAD _mm256_loadu_ps -#define GGML_F32x8_STORE _mm256_storeu_ps -#if defined(__FMA__) - #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) -#else - #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) -#endif -#define GGML_F32x8_ADD _mm256_add_ps -#define GGML_F32x8_MUL _mm256_mul_ps -#define GGML_F32x8_REDUCE(res, x) \ -do { \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm256_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm256_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm256_add_ps(x[i], x[offset+i]); \ - } \ - const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ - _mm256_extractf128_ps(x[0], 1)); \ - const __m128 t1 = _mm_hadd_ps(t0, t0); \ - res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ -} while (0) -// TODO: is this optimal ? - -#define GGML_F32_VEC GGML_F32x8 -#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD -#define GGML_F32_VEC_STORE GGML_F32x8_STORE -#define GGML_F32_VEC_FMA GGML_F32x8_FMA -#define GGML_F32_VEC_ADD GGML_F32x8_ADD -#define GGML_F32_VEC_MUL GGML_F32x8_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE - -// F16 AVX - -#define GGML_F16_STEP 32 -#define GGML_F16_EPR 8 - -// F16 arithmetic is not supported by AVX, so we use F32 instead - -#define GGML_F32Cx8 __m256 -#define GGML_F32Cx8_ZERO _mm256_setzero_ps() -#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) - -#if defined(__F16C__) -// the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) -#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) -#else -static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { - float tmp[8]; - - for (int i = 0; i < 8; i++) { - tmp[i] = GGML_FP16_TO_FP32(x[i]); - } - - return _mm256_loadu_ps(tmp); -} -static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { - float arr[8]; - - _mm256_storeu_ps(arr, y); - - for (int i = 0; i < 8; i++) - x[i] = GGML_FP32_TO_FP16(arr[i]); -} -#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) -#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y) -#endif - -#define GGML_F32Cx8_FMA GGML_F32x8_FMA -#define GGML_F32Cx8_ADD _mm256_add_ps -#define GGML_F32Cx8_MUL _mm256_mul_ps -#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE - -#define GGML_F16_VEC GGML_F32Cx8 -#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO -#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 -#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p) -#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA -#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD -#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL -#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE - -#elif defined(__POWER9_VECTOR__) - -#define GGML_SIMD - -// F32 POWER9 - -#define GGML_F32_STEP 32 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 vector float -#define GGML_F32x4_ZERO 0.0f -#define GGML_F32x4_SET1 vec_splats -#define GGML_F32x4_LOAD(p) vec_xl(0, p) -#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) -#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) -#define GGML_F32x4_ADD vec_add -#define GGML_F32x4_MUL vec_mul -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vec_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vec_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vec_add(x[i], x[offset+i]); \ - } \ - res = vec_extract(x[0], 0) + \ - vec_extract(x[0], 1) + \ - vec_extract(x[0], 2) + \ - vec_extract(x[0], 3); \ -} - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 POWER9 -#define GGML_F16_STEP GGML_F32_STEP -#define GGML_F16_EPR GGML_F32_EPR -#define GGML_F16_VEC GGML_F32x4 -#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F16_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F16_VEC_FMA GGML_F32x4_FMA -#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE -// Use vec_xl, not vec_ld, in case the load address is not aligned. -#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ - vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \ - vec_extract_fp32_from_shortl(vec_xl(0, p)) -#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] -#define GGML_F16_VEC_STORE(p, r, i) \ - if (i & 0x1) \ - vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \ - r[i - GGML_ENDIAN_BYTE(0)]), \ - 0, p - GGML_F16_EPR) - -#elif defined(__wasm_simd128__) - -#define GGML_SIMD - -// F32 WASM - -#define GGML_F32_STEP 16 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 v128_t -#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) -#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x) -#define GGML_F32x4_LOAD wasm_v128_load -#define GGML_F32x4_STORE wasm_v128_store -#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) -#define GGML_F32x4_ADD wasm_f32x4_add -#define GGML_F32x4_MUL wasm_f32x4_mul -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - res = wasm_f32x4_extract_lane(x[0], 0) + \ - wasm_f32x4_extract_lane(x[0], 1) + \ - wasm_f32x4_extract_lane(x[0], 2) + \ - wasm_f32x4_extract_lane(x[0], 3); \ -} - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 WASM - -#define GGML_F16_STEP 16 -#define GGML_F16_EPR 4 - -inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) { - float tmp[4]; - - tmp[0] = GGML_FP16_TO_FP32(p[0]); - tmp[1] = GGML_FP16_TO_FP32(p[1]); - tmp[2] = GGML_FP16_TO_FP32(p[2]); - tmp[3] = GGML_FP16_TO_FP32(p[3]); - - return wasm_v128_load(tmp); -} - -inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { - float tmp[4]; - - wasm_v128_store(tmp, x); - - p[0] = GGML_FP32_TO_FP16(tmp[0]); - p[1] = GGML_FP32_TO_FP16(tmp[1]); - p[2] = GGML_FP32_TO_FP16(tmp[2]); - p[3] = GGML_FP32_TO_FP16(tmp[3]); -} - -#define GGML_F16x4 v128_t -#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) -#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x) -#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) -#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) -#define GGML_F16x4_FMA GGML_F32x4_FMA -#define GGML_F16x4_ADD wasm_f32x4_add -#define GGML_F16x4_MUL wasm_f32x4_mul -#define GGML_F16x4_REDUCE(res, x) \ -{ \ - int offset = GGML_F16_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - res = wasm_f32x4_extract_lane(x[0], 0) + \ - wasm_f32x4_extract_lane(x[0], 1) + \ - wasm_f32x4_extract_lane(x[0], 2) + \ - wasm_f32x4_extract_lane(x[0], 3); \ -} - -#define GGML_F16_VEC GGML_F16x4 -#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO -#define GGML_F16_VEC_SET1 GGML_F16x4_SET1 -#define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p) -#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F16x4_FMA -#define GGML_F16_VEC_ADD GGML_F16x4_ADD -#define GGML_F16_VEC_MUL GGML_F16x4_MUL -#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE - -#elif defined(__SSE3__) - -#define GGML_SIMD - -// F32 SSE - -#define GGML_F32_STEP 32 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 __m128 -#define GGML_F32x4_ZERO _mm_setzero_ps() -#define GGML_F32x4_SET1(x) _mm_set1_ps(x) -#define GGML_F32x4_LOAD _mm_loadu_ps -#define GGML_F32x4_STORE _mm_storeu_ps -#if defined(__FMA__) - // TODO: Does this work? - #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) -#else - #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) -#endif -#define GGML_F32x4_ADD _mm_add_ps -#define GGML_F32x4_MUL _mm_mul_ps -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm_add_ps(x[i], x[offset+i]); \ - } \ - const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ - res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ -} -// TODO: is this optimal ? - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 SSE - -#define GGML_F16_STEP 32 -#define GGML_F16_EPR 4 - -static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) { - float tmp[4]; - - tmp[0] = GGML_FP16_TO_FP32(x[0]); - tmp[1] = GGML_FP16_TO_FP32(x[1]); - tmp[2] = GGML_FP16_TO_FP32(x[2]); - tmp[3] = GGML_FP16_TO_FP32(x[3]); - - return _mm_loadu_ps(tmp); -} - -static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) { - float arr[4]; - - _mm_storeu_ps(arr, y); - - x[0] = GGML_FP32_TO_FP16(arr[0]); - x[1] = GGML_FP32_TO_FP16(arr[1]); - x[2] = GGML_FP32_TO_FP16(arr[2]); - x[3] = GGML_FP32_TO_FP16(arr[3]); -} - -#define GGML_F32Cx4 __m128 -#define GGML_F32Cx4_ZERO _mm_setzero_ps() -#define GGML_F32Cx4_SET1(x) _mm_set1_ps(x) -#define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x) -#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) -#define GGML_F32Cx4_FMA GGML_F32x4_FMA -#define GGML_F32Cx4_ADD _mm_add_ps -#define GGML_F32Cx4_MUL _mm_mul_ps -#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE - -#define GGML_F16_VEC GGML_F32Cx4 -#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO -#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 -#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) -#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA -#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD -#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL -#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE - -#endif - -// GGML_F32_ARR / GGML_F16_ARR -// number of registers to use per step -#ifdef GGML_SIMD -#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) -#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) -#endif - -// -// fundamental operations -// - -inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } -inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } -inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } -inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } -inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } -inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } -inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } -inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } -inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } - -static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { -#ifdef GGML_SIMD - float sumf = 0.0f; - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - - sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); - } - } - - // reduce sum0..sum3 to sum0 - GGML_F32_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += x[i]*y[i]; - } -#else - // scalar - ggml_float sumf = 0.0; - for (int i = 0; i < n; ++i) { - sumf += (ggml_float)(x[i]*y[i]); - } -#endif - - *s = sumf; -} - -static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) { - ggml_float sumf = 0.0; - -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - - sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); - } - } - - // reduce sum0..sum3 to sum0 - GGML_F16_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); - } -#else - for (int i = 0; i < n; ++i) { - sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); - } -#endif - - *s = sumf; -} - -static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); - - const block_q4_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; ++i) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = _mm256_set1_epi8( 8 ); - bx = _mm256_sub_epi8( bx, off ); - - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx, by); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps( d, q, acc ); - } - - *s = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); - - __m128i bx = _mm_and_si128(lowMask, tmp); - __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); - bx = _mm_sub_epi8(bx, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx, by); - - bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); - by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); - bx = _mm_sub_epi8(bx, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx, by); - - // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); - - // Apply the scale, and accumulate - acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); - } - - *s = hsum_float_8(acc); -#elif defined(__SSSE3__) - // set constants - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - // Initialize accumulator with zeros - __m128 acc_0 = _mm_setzero_ps(); - __m128 acc_1 = _mm_setzero_ps(); - __m128 acc_2 = _mm_setzero_ps(); - __m128 acc_3 = _mm_setzero_ps(); - - // First round without accumulation - { - _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); - - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); - - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); - - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); - - // Apply the scale - acc_0 = _mm_mul_ps( d_0_1, p0 ); - acc_1 = _mm_mul_ps( d_0_1, p1 ); - acc_2 = _mm_mul_ps( d_2_3, p2 ); - acc_3 = _mm_mul_ps( d_2_3, p3 ); - } - - // Main loop - GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 2; i < nb; i+=2) { - _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); - - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); - - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); - - // Apply the scale - __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); - __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); - __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); - __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); - - // Acummulate - acc_0 = _mm_add_ps(p0_d, acc_0); - acc_1 = _mm_add_ps(p1_d, acc_1); - acc_2 = _mm_add_ps(p2_d, acc_2); - acc_3 = _mm_add_ps(p3_d, acc_3); - } - - *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (int i = 0; i < nb; i++) { - // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - - // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - // subtract offset - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); - } - - *s = sumf; -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - int sumi = 0; - - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F) - 8; - const int v1 = (x[i].qs[j] >> 4) - 8; - - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); - } - - sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); - } - - *s = sumf; -#endif -} - -static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_1; - const int nb = n / qk; - - assert(n % qk == 0); - - const block_q4_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; - - // TODO: add WASM SIMD -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs = 0; - - GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q4_1 * restrict x0 = &x[i + 0]; - const block_q4_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i + 0]; - const block_q8_1 * restrict y1 = &y[i + 1]; - - summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - // Main loop - for (int i = 0; i < nb; ++i) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); - const float d1 = y[i].d; - - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; - - const __m256 d0v = _mm256_set1_ps( d0 ); - const __m256 d1v = _mm256_set1_ps( d1 ); - - // Compute combined scales - const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); - - const __m256 xy = mul_sum_us8_pairs_float(bx, by); - - // Accumulate d0*d1*x*y -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d0d1, xy, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); -#endif - } - - *s = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (int i = 0; i < nb; i++) { - // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - - // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; - } - - *s = sumf; -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - int sumi = 0; - - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F); - const int v1 = (x[i].qs[j] >> 4); - - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); - } - - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; - } - - *s = sumf; -#endif -} - -static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); - assert(qk == QK5_0); - - const block_q5_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - uint32_t qh0; - uint32_t qh1; - - uint64_t tmp0[4]; - uint64_t tmp1[4]; - - GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q5_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i]; - const block_q8_0 * restrict y1 = &y[i + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - // extract the 5th bit via lookup table ((!b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); - - tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_1[(qh0 >> 24) ]; - - tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_1[(qh1 >> 24) ]; - - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__wasm_simd128__) - v128_t sumv = wasm_f32x4_splat(0.0f); - - uint32_t qh; - uint64_t tmp[4]; - - // TODO: check if unrolling this is better - for (int i = 0; i < nb; ++i) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q8_0 * restrict y0 = &y[i]; - - const v128_t m4b = wasm_i8x16_splat(0x0F); - - // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); - - tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_1[(qh >> 24) ]; - - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); - - const v128_t v0 = wasm_v128_load(x0->qs); - - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); - - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); - const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); - - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); - - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - - // dot product - sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( - wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); - } - - *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; i++) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); - bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); - bx = _mm256_or_si256(bx, bxhi); - - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx, by); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps(d, q, acc); - } - - *s = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8((char)0xF0); - - // Main loop - for (int i = 0; i < nb; i++) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = bytes_from_bits_32(x[i].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_andnot_si128(bxhil, mask); - bxhih = _mm_andnot_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx); - __m128i bxh = _mm256_extractf128_si256(bx, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx = MM256_SET_M128I(bxh, bxl); - - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx, by); - - /* Multiply q with scale and accumulate */ - acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); - } - - *s = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - uint32_t qh; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - // These tempory registers are for masking and shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); - - vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); - vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); - - for (int i = 0; i < nb; i++) { - memcpy(&qh, x[i].qh, sizeof(uint32_t)); - - // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - - // ((qh & (1u << (j + 16))) >> (j + 12)); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); - vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); - - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - - // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; - } - - *s = sumf; -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - int sumi = 0; - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - - const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; - const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; - - sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); - } - - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; - } - - *s = sumf; -#endif -} - -static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_1; - const int nb = n / qk; - - assert(n % qk == 0); - assert(qk == QK5_1); - - const block_q5_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs0 = 0.0f; - float summs1 = 0.0f; - - uint32_t qh0; - uint32_t qh1; - - uint64_t tmp0[4]; - uint64_t tmp1[4]; - - GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q5_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i]; - const block_q8_1 * restrict y1 = &y[i + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s; - summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s; - - // extract the 5th bit via lookup table ((b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); - - tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_0[(qh0 >> 24) ]; - - tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_0[(qh1 >> 24) ]; - - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // add high bit - const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; -#elif defined(__wasm_simd128__) - v128_t sumv = wasm_f32x4_splat(0.0f); - - float summs = 0.0f; - - uint32_t qh; - uint64_t tmp[4]; - - // TODO: check if unrolling this is better - for (int i = 0; i < nb; ++i) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q8_1 * restrict y0 = &y[i]; - - summs += GGML_FP16_TO_FP32(x0->m) * y0->s; - - const v128_t m4b = wasm_i8x16_splat(0x0F); - - // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); - - tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_0[(qh >> 24) ]; - - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); - - const v128_t v0 = wasm_v128_load(x0->qs); - - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); - - // add high bit - const v128_t v0lf = wasm_v128_or(v0l, qhl); - const v128_t v0hf = wasm_v128_or(v0h, qhh); - - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); - - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - - // dot product - sumv = wasm_f32x4_add(sumv, - wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d))); - } - - *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.0f; - - // Main loop - for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); - - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); - bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); - bx = _mm256_or_si256(bx, bxhi); - - const __m256 dy = _mm256_set1_ps(y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_us8_pairs_float(bx, by); - - acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); - } - - *s = hsum_float_8(acc) + summs; -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8(0x10); - - float summs = 0.0f; - - // Main loop - for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); - - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = bytes_from_bits_32(x[i].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_and_si128(bxhil, mask); - bxhih = _mm_and_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx); - __m128i bxh = _mm256_extractf128_si256(bx, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx = MM256_SET_M128I(bxh, bxl); - - const __m256 dy = _mm256_set1_ps(y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_us8_pairs_float(bx, by); - - acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); - } - - *s = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - uint32_t qh; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - // temporary registers for shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); - - for (int i = 0; i < nb; i++) { - memcpy(&qh, x[i].qh, sizeof(uint32_t)); - - // load qh - vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); - - // ((qh >> (j + 0)) << 4) & 0x10; - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); - - // ((qh >> (j + 12)) ) & 0x10; - vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); - - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - - // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); - - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; - } - - *s = sumf; -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - int sumi = 0; - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; - const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; - - sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); - } - - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; - } - - *s = sumf; -#endif -} - -static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); - - const block_q8_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q8_0 * restrict x0 = &x[i + 0]; - const block_q8_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; - - const int8x16_t x0_0 = vld1q_s8(x0->qs); - const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); - const int8x16_t x1_0 = vld1q_s8(x1->qs); - const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); - - // load y - const int8x16_t y0_0 = vld1q_s8(y0->qs); - const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); - const int8x16_t y1_0 = vld1q_s8(y1->qs); - const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), - vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), - vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - -#else - const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); - const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); - const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1)); - const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); - - const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0)); - const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); - const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1)); - const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); - - const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); - const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); - const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); - const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs); - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx, by); - - // Multiply q with scale and accumulate -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d, q, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); -#endif - } - - *s = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - size_t vl = __riscv_vsetvl_e8m1(qk); - - for (int i = 0; i < nb; i++) { - // load elements - vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl); - vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl); - - vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl); - - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); - - sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); - } - - *s = sumf; -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - int sumi = 0; - - for (int j = 0; j < qk; j++) { - sumi += x[i].qs[j]*y[i].qs[j]; - } - - sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); - } - - *s = sumf; -#endif -} - -// compute GGML_VEC_DOT_UNROLL dot products at once -// xs - x row stride in bytes -inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { - ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; - - ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; - - for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { - x[i] = (ggml_fp16_t *) ((char *) xv + i*xs); - } - -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - - for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { - ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); - - sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); - } - } - } - - // reduce sum0..sum3 to sum0 - for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { - GGML_F16_VEC_REDUCE(sumf[k], sum[k]); - } - - // leftovers - for (int i = np; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); - } - } -#else - for (int i = 0; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); - } - } -#endif - - for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { - s[i] = sumf[i]; - } -} - -inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] += x[i]*v; - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] += x[i]*v; - } -#endif -} - -// xs and vs are byte strides of x and v -inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { - - const float * restrict x[GGML_VEC_MAD_UNROLL]; - const float * restrict v[GGML_VEC_MAD_UNROLL]; - - for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) { - x[i] = (const float *) ((const char *) xv + i*xs); - v[i] = (const float *) ((const char *) vv + i*vs); - } - -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; - - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - vx[k] = GGML_F32_VEC_SET1(v[k][0]); - } - - GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); - } - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); - } - } - - // leftovers - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - for (int i = np; i < n; ++i) { - y[i] += x[k][i]*v[k][0]; - } - } -#else - // scalar - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - for (int i = 0; i < n; ++i) { - y[i] += x[k][i]*v[k][0]; - } - } -#endif -} - -//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } -inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { -#if defined(GGML_USE_ACCELERATE) - vDSP_vsmul(y, 1, &v, y, 1, n); -#elif defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_MUL(ay[j], vx); - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] *= v; - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] *= v; - } -#endif -} - -inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); } -inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } -inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } -inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } -inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } -inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } -inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } -inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } -inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } -inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } - -static const float GELU_COEF_A = 0.044715f; -static const float GELU_QUICK_COEF = -1.702f; -static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -inline static float ggml_gelu_f32(float x) { - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { - const uint16_t * i16 = (const uint16_t *) x; - for (int i = 0; i < n; ++i) { - y[i] = table_gelu_f16[i16[i]]; - } -} - -#ifdef GGML_GELU_FP16 -inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { - uint16_t t; - for (int i = 0; i < n; ++i) { - ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(table_gelu_f16[t]); - } -} -#else -inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { - y[i] = ggml_gelu_f32(x[i]); - } -} -#endif - -inline static float ggml_gelu_quick_f32(float x) { - return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); -} - -//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { -// const uint16_t * i16 = (const uint16_t *) x; -// for (int i = 0; i < n; ++i) { -// y[i] = table_gelu_quick_f16[i16[i]]; -// } -//} - -#ifdef GGML_GELU_QUICK_FP16 -inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { - uint16_t t; - for (int i = 0; i < n; ++i) { - ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]); - } -} -#else -inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { - y[i] = ggml_gelu_quick_f32(x[i]); - } -} -#endif - -// Sigmoid Linear Unit (SiLU) function -inline static float ggml_silu_f32(float x) { - return x/(1.0f + expf(-x)); -} - -//inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { -// const uint16_t * i16 = (const uint16_t *) x; -// for (int i = 0; i < n; ++i) { -// y[i] = table_silu_f16[i16[i]]; -// } -//} - -#ifdef GGML_SILU_FP16 -inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) { - uint16_t t; - for (int i = 0; i < n; ++i) { - ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(table_silu_f16[t]); - } -} -#else -inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { - y[i] = ggml_silu_f32(x[i]); - } -} -#endif - -inline static float ggml_silu_backward_f32(float x, float dy) { - const float s = 1.0f/(1.0f + expf(-x)); - return dy*s*(1.0f + x*(1.0f - s)); -} - -#ifdef GGML_SILU_FP16 -inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { - for (int i = 0; i < n; ++i) { - // we did not use x[i] to compute forward silu but its f16 equivalent - // take derivative at f16 of x[i]: - ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); - float usedx = GGML_FP16_TO_FP32(fp16); - dx[i] = ggml_silu_backward_f32(usedx, dy[i]); - } -} -#else -inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { - for (int i = 0; i < n; ++i) { - dx[i] = ggml_silu_backward_f32(x[i], dy[i]); - } -} -#endif - -inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { -#ifndef GGML_USE_ACCELERATE - ggml_float sum = 0.0; - for (int i = 0; i < n; ++i) { - sum += (ggml_float)x[i]; - } - *s = sum; -#else - vDSP_sve(x, 1, s, n); -#endif -} - -inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) { - ggml_float sum = 0.0; - for (int i = 0; i < n; ++i) { - sum += (ggml_float)x[i]; - } - *s = sum; -} - -inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) { - float sum = 0.0f; - for (int i = 0; i < n; ++i) { - sum += GGML_FP16_TO_FP32(x[i]); - } - *s = sum; -} - -inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { -#ifndef GGML_USE_ACCELERATE - float max = -INFINITY; - for (int i = 0; i < n; ++i) { - max = MAX(max, x[i]); - } - *s = max; -#else - vDSP_maxv(x, 1, s, n); -#endif -} - -inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { - ggml_vec_norm_f32(n, s, x); - *s = 1.f/(*s); -} - -inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { - float max = -INFINITY; - int idx = 0; - for (int i = 0; i < n; ++i) { - max = MAX(max, x[i]); - if (max == x[i]) { idx = i; } - } - *s = idx; -} - -// -// data types -// - -static const char * GGML_OP_NAME[GGML_OP_COUNT] = { - "NONE", - - "DUP", - "ADD", - "ADD1", - "ACC", - "SUB", - "MUL", - "DIV", - "SQR", - "SQRT", - "LOG", - "SUM", - "SUM_ROWS", - "MEAN", - "ARGMAX", - "REPEAT", - "REPEAT_BACK", - "CONCAT", - "SILU_BACK", - "NORM", - "RMS_NORM", - "RMS_NORM_BACK", - "GROUP_NORM", - - "MUL_MAT", - "OUT_PROD", - - "SCALE", - "SET", - "CPY", - "CONT", - "RESHAPE", - "VIEW", - "PERMUTE", - "TRANSPOSE", - "GET_ROWS", - "GET_ROWS_BACK", - "DIAG", - "DIAG_MASK_INF", - "DIAG_MASK_ZERO", - "SOFT_MAX", - "SOFT_MAX_BACK", - "ROPE", - "ROPE_BACK", - "ALIBI", - "CLAMP", - "CONV_1D", - "CONV_TRANSPOSE_1D", - "CONV_2D", - "CONV_TRANSPOSE_2D", - "POOL_1D", - "POOL_2D", - "UPSCALE", - - "CONV_1D_STAGE_0", - "CONV_1D_STAGE_1", - - "FLASH_ATTN", - "FLASH_FF", - "FLASH_ATTN_BACK", - "WIN_PART", - "WIN_UNPART", - "GET_REL_POS", - "ADD_REL_POS", - - "UNARY", - - "MAP_UNARY", - "MAP_BINARY", - - "MAP_CUSTOM1_F32", - "MAP_CUSTOM2_F32", - "MAP_CUSTOM3_F32", - - "MAP_CUSTOM1", - "MAP_CUSTOM2", - "MAP_CUSTOM3", - - "CROSS_ENTROPY_LOSS", - "CROSS_ENTROPY_LOSS_BACK", -}; - -static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71"); - -static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { - "none", - - "x", - "x+y", - "x+y", - "view(x,nb,offset)+=y->x", - "x-y", - "x*y", - "x/y", - "x^2", - "√x", - "log(x)", - "Σx", - "Σx_k", - "Σx/n", - "argmax(x)", - "repeat(x)", - "repeat_back(x)", - "concat(x, y)", - "silu_back(x)", - "norm(x)", - "rms_norm(x)", - "rms_norm_back(x)", - "group_norm(x)", - - "X*Y", - "X*Y", - - "x*v", - "y-\\>view(x)", - "x-\\>y", - "cont(x)", - "reshape(x)", - "view(x)", - "permute(x)", - "transpose(x)", - "get_rows(x)", - "get_rows_back(x)", - "diag(x)", - "diag_mask_inf(x)", - "diag_mask_zero(x)", - "soft_max(x)", - "soft_max_back(x)", - "rope(x)", - "rope_back(x)", - "alibi(x)", - "clamp(x)", - "conv_1d(x)", - "conv_transpose_1d(x)", - "conv_2d(x)", - "conv_transpose_2d(x)", - "pool_1d(x)", - "pool_2d(x)", - "upscale(x)", - - "conv_1d_stage_0(x)", - "conv_1d_stage_1(x)", - - "flash_attn(x)", - "flash_ff(x)", - "flash_attn_back(x)", - "win_part(x)", - "win_unpart(x)", - "get_rel_pos(x)", - "add_rel_pos(x)", - - "unary(x)", - - "f(x)", - "f(x,y)", - - "custom_f32(x)", - "custom_f32(x,y)", - "custom_f32(x,y,z)", - - "custom(x)", - "custom(x,y)", - "custom(x,y,z)", - - "cross_entropy_loss(x,y)", - "cross_entropy_loss_back(x,y)", -}; - -static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71"); - -static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); - -static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); -static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); - -// WARN: -// Mis-confguration can lead to problem that's hard to reason about: -// * At best it crash or talks nosense. -// * At worst it talks slightly difference but hard to perceive. -// -// An op has to enable INIT or FINALIZE when any of it's branch needs that pass. -// Take care about compile options (e.g., GGML_USE_xxx). -static bool GGML_OP_HAS_INIT [GGML_OP_COUNT] = { 0 }; -static bool GGML_OP_HAS_FINALIZE[GGML_OP_COUNT] = { 0 }; - -static void ggml_setup_op_has_task_pass(void) { - { // INIT - bool * p = GGML_OP_HAS_INIT; - - p[GGML_OP_ACC ] = true; - p[GGML_OP_MUL_MAT ] = true; - p[GGML_OP_OUT_PROD ] = true; - p[GGML_OP_SET ] = true; - p[GGML_OP_GET_ROWS_BACK ] = true; - p[GGML_OP_DIAG_MASK_INF ] = true; - p[GGML_OP_DIAG_MASK_ZERO ] = true; - p[GGML_OP_CONV_1D ] = true; - p[GGML_OP_CONV_1D_STAGE_0 ] = true; - p[GGML_OP_CONV_1D_STAGE_1 ] = true; - p[GGML_OP_CONV_2D ] = true; - p[GGML_OP_CONV_TRANSPOSE_1D ] = true; - p[GGML_OP_CONV_TRANSPOSE_2D ] = true; - p[GGML_OP_FLASH_ATTN_BACK ] = true; - p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; - p[GGML_OP_ADD_REL_POS ] = true; - } - - { // FINALIZE - bool * p = GGML_OP_HAS_FINALIZE; - - p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; - } -} - -// -// ggml context -// - -struct ggml_context { - size_t mem_size; - void * mem_buffer; - bool mem_buffer_owned; - bool no_alloc; - bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers - - int n_objects; - - struct ggml_object * objects_begin; - struct ggml_object * objects_end; - - struct ggml_scratch scratch; - struct ggml_scratch scratch_save; -}; - -struct ggml_context_container { - bool used; - - struct ggml_context context; -}; - -// -// NUMA support -// - -#define GGML_NUMA_MAX_NODES 8 -#define GGML_NUMA_MAX_CPUS 512 - -struct ggml_numa_node { - uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node - uint32_t n_cpus; -}; - -struct ggml_numa_nodes { - struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES]; - uint32_t n_nodes; - uint32_t total_cpus; // hardware threads on system -}; - -// -// ggml state -// - -struct ggml_state { - struct ggml_context_container contexts[GGML_MAX_CONTEXTS]; - struct ggml_numa_nodes numa; -}; - -// global state -static struct ggml_state g_state; -static atomic_int g_state_barrier = 0; - -// barrier via spin lock -inline static void ggml_critical_section_start(void) { - int processing = atomic_fetch_add(&g_state_barrier, 1); - - while (processing > 0) { - // wait for other threads to finish - atomic_fetch_sub(&g_state_barrier, 1); - sched_yield(); // TODO: reconsider this - processing = atomic_fetch_add(&g_state_barrier, 1); - } -} - -// TODO: make this somehow automatically executed -// some sort of "sentry" mechanism -inline static void ggml_critical_section_end(void) { - atomic_fetch_sub(&g_state_barrier, 1); -} - -void ggml_numa_init(void) { - if (g_state.numa.n_nodes > 0) { - fprintf(stderr, "ggml_numa_init: NUMA already initialized\n"); - - return; - } - -#ifdef __linux__ - struct stat st; - char path[256]; - int rv; - - // enumerate nodes - while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) { - rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes); - GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); - if (stat(path, &st) != 0) { break; } - ++g_state.numa.n_nodes; - } - - // enumerate CPUs - while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) { - rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus); - GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); - if (stat(path, &st) != 0) { break; } - ++g_state.numa.total_cpus; - } - - GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus); - - if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1) { - g_state.numa.n_nodes = 0; - return; - } - - for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) { - struct ggml_numa_node * node = &g_state.numa.nodes[n]; - GGML_PRINT_DEBUG("CPUs on node %u:", n); - node->n_cpus = 0; - for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) { - rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c); - GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); - if (stat(path, &st) == 0) { - node->cpus[node->n_cpus++] = c; - GGML_PRINT_DEBUG(" %u", c); - } - } - GGML_PRINT_DEBUG("\n"); - } - - if (ggml_is_numa()) { - FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r"); - if (fptr != NULL) { - char buf[42]; - if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) { - GGML_PRINT("WARNING: /proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n"); - } - fclose(fptr); - } - } -#else - // TODO -#endif -} - -bool ggml_is_numa(void) { - return g_state.numa.n_nodes > 1; -} - -//////////////////////////////////////////////////////////////////////////////// - -void ggml_print_object(const struct ggml_object * obj) { - GGML_PRINT(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n", - obj->type, obj->offs, obj->size, (const void *) obj->next); -} - -void ggml_print_objects(const struct ggml_context * ctx) { - struct ggml_object * obj = ctx->objects_begin; - - GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx); - - while (obj != NULL) { - ggml_print_object(obj); - obj = obj->next; - } - - GGML_PRINT("%s: --- end ---\n", __func__); -} - -int64_t ggml_nelements(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; -} - -int64_t ggml_nrows(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; -} - -size_t ggml_nbytes(const struct ggml_tensor * tensor) { - size_t nbytes; - size_t blck_size = ggml_blck_size(tensor->type); - if (blck_size == 1) { - nbytes = ggml_type_size(tensor->type); - for (int i = 0; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; - } - } - else { - nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; - } - } - - return nbytes; -} - -size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { - return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN); -} - -size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return (nrows_split*tensor->ne[0]*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type); -} - -int ggml_blck_size(enum ggml_type type) { - return type_traits[type].blck_size; -} - -size_t ggml_type_size(enum ggml_type type) { - return type_traits[type].type_size; -} - -float ggml_type_sizef(enum ggml_type type) { - return ((float)(type_traits[type].type_size))/type_traits[type].blck_size; -} - -const char * ggml_type_name(enum ggml_type type) { - return type_traits[type].type_name; -} - -bool ggml_is_quantized(enum ggml_type type) { - return type_traits[type].is_quantized; -} - -const char * ggml_op_name(enum ggml_op op) { - return GGML_OP_NAME[op]; -} - -const char * ggml_op_symbol(enum ggml_op op) { - return GGML_OP_SYMBOL[op]; -} - -size_t ggml_element_size(const struct ggml_tensor * tensor) { - return ggml_type_size(tensor->type); -} - -static inline bool ggml_is_scalar(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -static inline bool ggml_is_vector(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return (t0->ne[0] == t1->ne[0]) && - (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable - (t1->ne[3]%t0->ne[3] == 0); -} - -static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return (t0->ne[1] == t1->ne[1]) && - (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable - (t1->ne[3]%t0->ne[3] == 0); -} - -enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { - enum ggml_type wtype = GGML_TYPE_COUNT; - - switch (ftype) { - case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; - case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; - case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; - case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; - case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; - case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; - case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; - case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; - case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; - case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; - case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; - case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; - case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; - case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; - } - - GGML_ASSERT(wtype != GGML_TYPE_COUNT); - - return wtype; -} - -size_t ggml_tensor_overhead(void) { - return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE; -} - -bool ggml_is_transposed(const struct ggml_tensor * tensor) { - return tensor->nb[0] > tensor->nb[1]; -} - -bool ggml_is_contiguous(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - tensor->nb[0] == ggml_type_size(tensor->type) && - tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && - tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; -} - -static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - tensor->nb[0] == ggml_type_size(tensor->type) && - tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; -} - -bool ggml_is_permuted(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; -} - -static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - tensor->nb[0] == ggml_type_size(tensor->type) && - tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; -} - -bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - (t0->ne[0] == t1->ne[0] ) && - (t0->ne[1] == t1->ne[1] ) && - (t0->ne[2] == t1->ne[2] ) && - (t0->ne[3] == t1->ne[3] ); -} - -// check if t1 can be represented as a repeatition of t0 -static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - (t1->ne[0]%t0->ne[0] == 0) && - (t1->ne[1]%t0->ne[1] == 0) && - (t1->ne[2]%t0->ne[2] == 0) && - (t1->ne[3]%t0->ne[3] == 0); -} - -static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1); -} - -static inline int ggml_up32(int n) { - return (n + 31) & ~31; -} - -//static inline int ggml_up64(int n) { -// return (n + 63) & ~63; -//} - -static inline int ggml_up(int n, int m) { - // assert m is a power of 2 - GGML_ASSERT((m & (m - 1)) == 0); - return (n + m - 1) & ~(m - 1); -} - -// assert that pointer is aligned to GGML_MEM_ALIGN -#define ggml_assert_aligned(ptr) \ - GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0) - -//////////////////////////////////////////////////////////////////////////////// - -struct ggml_context * ggml_init(struct ggml_init_params params) { - // make this function thread safe - ggml_critical_section_start(); - - static bool is_first_call = true; - - if (is_first_call) { - // initialize time system (required on Windows) - ggml_time_init(); - - // initialize GELU, Quick GELU, SILU and EXP F32 tables - { - const uint64_t t_start = ggml_time_us(); UNUSED(t_start); - - ggml_fp16_t ii; - for (int i = 0; i < (1 << 16); ++i) { - uint16_t ui = i; - memcpy(&ii, &ui, sizeof(ii)); - const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); - table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); - table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); - table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f)); - table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); - } - - const uint64_t t_end = ggml_time_us(); UNUSED(t_end); - - GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); - } - - // initialize g_state - { - const uint64_t t_start = ggml_time_us(); UNUSED(t_start); - - g_state = (struct ggml_state) { - /*.contexts =*/ { { 0 } }, - /*.numa =*/ { - .n_nodes = 0, - .total_cpus = 0, - }, - }; - - for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) { - g_state.contexts[i].used = false; - } - - const uint64_t t_end = ggml_time_us(); UNUSED(t_end); - - GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); - } - -#if defined(GGML_USE_CUBLAS) - ggml_init_cublas(); -#elif defined(GGML_USE_CLBLAST) - ggml_cl_init(); -#endif - - ggml_setup_op_has_task_pass(); - - is_first_call = false; - } - - // find non-used context in g_state - struct ggml_context * ctx = NULL; - - for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { - if (!g_state.contexts[i].used) { - g_state.contexts[i].used = true; - ctx = &g_state.contexts[i].context; - - GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i); - break; - } - } - - if (ctx == NULL) { - GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); - - ggml_critical_section_end(); - - return NULL; - } - - // allow to call ggml_init with 0 size - if (params.mem_size == 0) { - params.mem_size = GGML_MEM_ALIGN; - } - - const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN); - - *ctx = (struct ggml_context) { - /*.mem_size =*/ mem_size, - /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size), - /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, - /*.no_alloc =*/ params.no_alloc, - /*.no_alloc_save =*/ params.no_alloc, - /*.n_objects =*/ 0, - /*.objects_begin =*/ NULL, - /*.objects_end =*/ NULL, - /*.scratch =*/ { 0, 0, NULL, }, - /*.scratch_save =*/ { 0, 0, NULL, }, - }; - - GGML_ASSERT(ctx->mem_buffer != NULL); - - ggml_assert_aligned(ctx->mem_buffer); - - GGML_PRINT_DEBUG("%s: context initialized\n", __func__); - - ggml_critical_section_end(); - - return ctx; -} - -void ggml_free(struct ggml_context * ctx) { - // make this function thread safe - ggml_critical_section_start(); - - bool found = false; - - for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { - if (&g_state.contexts[i].context == ctx) { - g_state.contexts[i].used = false; - - GGML_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n", - __func__, i, ggml_used_mem(ctx)); - - if (ctx->mem_buffer_owned) { - GGML_ALIGNED_FREE(ctx->mem_buffer); - } - - found = true; - break; - } - } - - if (!found) { - GGML_PRINT_DEBUG("%s: context not found\n", __func__); - } - - ggml_critical_section_end(); -} - -size_t ggml_used_mem(const struct ggml_context * ctx) { - return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size; -} - -size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) { - const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0; - - ctx->scratch = scratch; - - return result; -} - -bool ggml_get_no_alloc(struct ggml_context * ctx) { - return ctx->no_alloc; -} - -void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) { - ctx->no_alloc = no_alloc; -} - -void * ggml_get_mem_buffer(const struct ggml_context * ctx) { - return ctx->mem_buffer; -} - -size_t ggml_get_mem_size(const struct ggml_context * ctx) { - return ctx->mem_size; -} - -size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) { - size_t max_size = 0; - - struct ggml_object * obj = ctx->objects_begin; - - while (obj != NULL) { - if (obj->type == GGML_OBJECT_TENSOR) { - struct ggml_tensor * tensor = (struct ggml_tensor *) ((char *) ctx->mem_buffer + obj->offs); - - const size_t size = ggml_nbytes(tensor); - - if (max_size < size) { - max_size = size; - } - } - - obj = obj->next; - } - - return max_size; -} - -// IMPORTANT: -// when creating "opt" tensors, always save and load the scratch buffer -// this is an error prone process, but it is necessary to support inplace -// operators when using scratch buffers -// TODO: implement a better way -static void ggml_scratch_save(struct ggml_context * ctx) { - // this is needed to allow opt tensors to store their data - // TODO: again, need to find a better way - ctx->no_alloc_save = ctx->no_alloc; - ctx->no_alloc = false; - - ctx->scratch_save = ctx->scratch; - ctx->scratch.data = NULL; -} - -static void ggml_scratch_load(struct ggml_context * ctx) { - ctx->no_alloc = ctx->no_alloc_save; - - ctx->scratch = ctx->scratch_save; -} - -//////////////////////////////////////////////////////////////////////////////// - -static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) { - // always insert objects at the end of the context's memory pool - struct ggml_object * obj_cur = ctx->objects_end; - - const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs; - const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size; - const size_t cur_end = cur_offs + cur_size; - - // align to GGML_MEM_ALIGN - size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN); - - char * const mem_buffer = ctx->mem_buffer; - struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); - - if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { - GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", - __func__, cur_end + size_needed, ctx->mem_size); - assert(false); - return NULL; - } - - *obj_new = (struct ggml_object) { - .offs = cur_end + GGML_OBJECT_SIZE, - .size = size_needed, - .next = NULL, - .type = type, - }; - - ggml_assert_aligned(mem_buffer + obj_new->offs); - - if (obj_cur != NULL) { - obj_cur->next = obj_new; - } else { - // this is the first object in this context - ctx->objects_begin = obj_new; - } - - ctx->objects_end = obj_new; - - //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size); - - return obj_new; -} - -static struct ggml_tensor * ggml_new_tensor_impl( - struct ggml_context * ctx, - enum ggml_type type, - int n_dims, - const int64_t * ne, - struct ggml_tensor * view_src, - size_t view_offs) { - - assert(n_dims >= 1 && n_dims <= GGML_MAX_DIMS); - - // find the base tensor and absolute offset - if (view_src != NULL && view_src->view_src != NULL) { - view_offs += view_src->view_offs; - view_src = view_src->view_src; - } - - size_t data_size = ggml_type_size(type)*(ne[0]/ggml_blck_size(type)); - for (int i = 1; i < n_dims; i++) { - data_size *= ne[i]; - } - - GGML_ASSERT(view_src == NULL || data_size + view_offs <= ggml_nbytes(view_src)); - - void * data = view_src != NULL ? view_src->data : NULL; - if (data != NULL) { - data = (char *) data + view_offs; - } - - size_t obj_alloc_size = 0; - - if (view_src == NULL && !ctx->no_alloc) { - if (ctx->scratch.data != NULL) { - // allocate tensor data in the scratch buffer - if (ctx->scratch.offs + data_size > ctx->scratch.size) { - GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n", - __func__, ctx->scratch.offs + data_size, ctx->scratch.size); - assert(false); - return NULL; - } - - data = (char * const) ctx->scratch.data + ctx->scratch.offs; - - ctx->scratch.offs += data_size; - } else { - // allocate tensor data in the context's memory pool - obj_alloc_size = data_size; - } - } - - struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size); - - // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here - - struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs); - - *result = (struct ggml_tensor) { - /*.type =*/ type, - /*.backend =*/ GGML_BACKEND_CPU, - /*.buffer =*/ NULL, - /*.n_dims =*/ n_dims, - /*.ne =*/ { 1, 1, 1, 1 }, - /*.nb =*/ { 0, 0, 0, 0 }, - /*.op =*/ GGML_OP_NONE, - /*.op_params =*/ { 0 }, - /*.is_param =*/ false, - /*.grad =*/ NULL, - /*.src =*/ { NULL }, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, - /*.view_src =*/ view_src, - /*.view_offs =*/ view_offs, - /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data, - /*.name =*/ { 0 }, - /*.extra =*/ NULL, - /*.padding =*/ { 0 }, - }; - - // TODO: this should not be needed as long as we don't rely on aligned SIMD loads - //ggml_assert_aligned(result->data); - - for (int i = 0; i < n_dims; i++) { - result->ne[i] = ne[i]; - } - - result->nb[0] = ggml_type_size(type); - result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type)); - for (int i = 2; i < GGML_MAX_DIMS; i++) { - result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; - } - - ctx->n_objects++; - - return result; -} - -struct ggml_tensor * ggml_new_tensor( - struct ggml_context * ctx, - enum ggml_type type, - int n_dims, - const int64_t * ne) { - return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0); -} - -struct ggml_tensor * ggml_new_tensor_1d( - struct ggml_context * ctx, - enum ggml_type type, - int64_t ne0) { - return ggml_new_tensor(ctx, type, 1, &ne0); -} - -struct ggml_tensor * ggml_new_tensor_2d( - struct ggml_context * ctx, - enum ggml_type type, - int64_t ne0, - int64_t ne1) { - const int64_t ne[2] = { ne0, ne1 }; - return ggml_new_tensor(ctx, type, 2, ne); -} - -struct ggml_tensor * ggml_new_tensor_3d( - struct ggml_context * ctx, - enum ggml_type type, - int64_t ne0, - int64_t ne1, - int64_t ne2) { - const int64_t ne[3] = { ne0, ne1, ne2 }; - return ggml_new_tensor(ctx, type, 3, ne); -} - -struct ggml_tensor * ggml_new_tensor_4d( - struct ggml_context * ctx, - enum ggml_type type, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int64_t ne3) { - const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; - return ggml_new_tensor(ctx, type, 4, ne); -} - -struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { - ggml_scratch_save(ctx); - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); - - ggml_scratch_load(ctx); - - ggml_set_i32(result, value); - - return result; -} - -struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) { - ggml_scratch_save(ctx); - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); - - ggml_scratch_load(ctx); - - ggml_set_f32(result, value); - - return result; -} - -struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) { - return ggml_new_tensor(ctx, src->type, src->n_dims, src->ne); -} - -static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) { - GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings - assert(params_size <= GGML_MAX_OP_PARAMS); - memcpy(tensor->op_params, params, params_size); -} - -static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) { - assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); - return ((const int32_t *)(tensor->op_params))[i]; -} - -static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) { - assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); - ((int32_t *)(tensor->op_params))[i] = value; -} - -struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { - memset(tensor->data, 0, ggml_nbytes(tensor)); - return tensor; -} - -struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { - const int n = ggml_nrows(tensor); - const int nc = tensor->ne[0]; - const size_t n1 = tensor->nb[1]; - - char * const data = tensor->data; - - switch (tensor->type) { - case GGML_TYPE_I8: - { - assert(tensor->nb[0] == sizeof(int8_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_I16: - { - assert(tensor->nb[0] == sizeof(int16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_I32: - { - assert(tensor->nb[0] == sizeof(int32_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_F16: - { - assert(tensor->nb[0] == sizeof(ggml_fp16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); - } - } break; - case GGML_TYPE_F32: - { - assert(tensor->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { - ggml_vec_set_f32(nc, (float *)(data + i*n1), value); - } - } break; - default: - { - GGML_ASSERT(false); - } break; - } - - return tensor; -} - -struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { - const int n = ggml_nrows(tensor); - const int nc = tensor->ne[0]; - const size_t n1 = tensor->nb[1]; - - char * const data = tensor->data; - - switch (tensor->type) { - case GGML_TYPE_I8: - { - assert(tensor->nb[0] == sizeof(int8_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_I16: - { - assert(tensor->nb[0] == sizeof(int16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_I32: - { - assert(tensor->nb[0] == sizeof(int32_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_F16: - { - assert(tensor->nb[0] == sizeof(ggml_fp16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); - } - } break; - case GGML_TYPE_F32: - { - assert(tensor->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { - ggml_vec_set_f32(nc, (float *)(data + i*n1), value); - } - } break; - default: - { - GGML_ASSERT(false); - } break; - } - - return tensor; -} - -void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) { - const int64_t ne2 = tensor->ne[2]; - const int64_t ne1 = tensor->ne[1]; - const int64_t ne0 = tensor->ne[0]; - - const int64_t i3_ = (i/(ne2*ne1*ne0)); - const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0); - const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0; - const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0); - - if (i0) { - * i0 = i0_; - } - if (i1) { - * i1 = i1_; - } - if (i2) { - * i2 = i2_; - } - if (i3) { - * i3 = i3_; - } -} - -int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { - if (!ggml_is_contiguous(tensor)) { - int64_t id[4] = { 0, 0, 0, 0 }; - ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); - return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]); - } - switch (tensor->type) { - case GGML_TYPE_I8: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - return ((int8_t *)(tensor->data))[i]; - } - case GGML_TYPE_I16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - return ((int16_t *)(tensor->data))[i]; - } - case GGML_TYPE_I32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - return ((int32_t *)(tensor->data))[i]; - } - case GGML_TYPE_F16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); - } - case GGML_TYPE_F32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); - return ((float *)(tensor->data))[i]; - } - default: - { - GGML_ASSERT(false); - } - } - - return 0.0f; -} - -void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { - if (!ggml_is_contiguous(tensor)) { - int64_t id[4] = { 0, 0, 0, 0 }; - ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); - ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value); - return; - } - switch (tensor->type) { - case GGML_TYPE_I8: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - ((int8_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_I16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - ((int16_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_I32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - ((int32_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); - } break; - case GGML_TYPE_F32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); - ((float *)(tensor->data))[i] = value; - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) { - void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; - switch (tensor->type) { - case GGML_TYPE_I8: - return ((int8_t *) data)[0]; - case GGML_TYPE_I16: - return ((int16_t *) data)[0]; - case GGML_TYPE_I32: - return ((int32_t *) data)[0]; - case GGML_TYPE_F16: - return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); - case GGML_TYPE_F32: - return ((float *) data)[0]; - default: - GGML_ASSERT(false); - } - - return 0.0f; -} - -void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) { - void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; - switch (tensor->type) { - case GGML_TYPE_I8: - { - ((int8_t *)(data))[0] = value; - } break; - case GGML_TYPE_I16: - { - ((int16_t *)(data))[0] = value; - } break; - case GGML_TYPE_I32: - { - ((int32_t *)(data))[0] = value; - } break; - case GGML_TYPE_F16: - { - ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); - } break; - case GGML_TYPE_F32: - { - ((float *)(data))[0] = value; - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { - if (!ggml_is_contiguous(tensor)) { - int64_t id[4] = { 0, 0, 0, 0 }; - ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); - return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]); - } - switch (tensor->type) { - case GGML_TYPE_I8: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - return ((int8_t *)(tensor->data))[i]; - } - case GGML_TYPE_I16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - return ((int16_t *)(tensor->data))[i]; - } - case GGML_TYPE_I32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - return ((int32_t *)(tensor->data))[i]; - } - case GGML_TYPE_F16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); - } - case GGML_TYPE_F32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); - return ((float *)(tensor->data))[i]; - } - default: - { - GGML_ASSERT(false); - } - } - - return 0.0f; -} - -void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { - if (!ggml_is_contiguous(tensor)) { - int64_t id[4] = { 0, 0, 0, 0 }; - ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); - ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value); - return; - } - switch (tensor->type) { - case GGML_TYPE_I8: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - ((int8_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_I16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - ((int16_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_I32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - ((int32_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); - } break; - case GGML_TYPE_F32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); - ((float *)(tensor->data))[i] = value; - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) { - void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; - switch (tensor->type) { - case GGML_TYPE_I8: - return ((int8_t *) data)[0]; - case GGML_TYPE_I16: - return ((int16_t *) data)[0]; - case GGML_TYPE_I32: - return ((int32_t *) data)[0]; - case GGML_TYPE_F16: - return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); - case GGML_TYPE_F32: - return ((float *) data)[0]; - default: - GGML_ASSERT(false); - } - - return 0.0f; -} - -void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) { - void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; - switch (tensor->type) { - case GGML_TYPE_I8: - { - ((int8_t *)(data))[0] = value; - } break; - case GGML_TYPE_I16: - { - ((int16_t *)(data))[0] = value; - } break; - case GGML_TYPE_I32: - { - ((int32_t *)(data))[0] = value; - } break; - case GGML_TYPE_F16: - { - ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); - } break; - case GGML_TYPE_F32: - { - ((float *)(data))[0] = value; - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -void * ggml_get_data(const struct ggml_tensor * tensor) { - return tensor->data; -} - -float * ggml_get_data_f32(const struct ggml_tensor * tensor) { - assert(tensor->type == GGML_TYPE_F32); - return (float *)(tensor->data); -} - -enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) { - GGML_ASSERT(tensor->op == GGML_OP_UNARY); - return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0); -} - -const char * ggml_get_name(const struct ggml_tensor * tensor) { - return tensor->name; -} - -struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) { - strncpy(tensor->name, name, sizeof(tensor->name)); - tensor->name[sizeof(tensor->name) - 1] = '\0'; - return tensor; -} - -struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...) { - va_list args; - va_start(args, fmt); - vsnprintf(tensor->name, sizeof(tensor->name), fmt, args); - va_end(args); - return tensor; -} - -struct ggml_tensor * ggml_view_tensor( - struct ggml_context * ctx, - struct ggml_tensor * src) { - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src, 0); - ggml_format_name(result, "%s (view)", src->name); - - for (int i = 0; i < GGML_MAX_DIMS; i++) { - result->nb[i] = src->nb[i]; - } - - return result; -} - -struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) { - struct ggml_object * obj = ctx->objects_begin; - - char * const mem_buffer = ctx->mem_buffer; - - while (obj != NULL) { - if (obj->type == GGML_OBJECT_TENSOR) { - return (struct ggml_tensor *)(mem_buffer + obj->offs); - } - - obj = obj->next; - } - - return NULL; -} - -struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) { - struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE); - obj = obj->next; - - char * const mem_buffer = ctx->mem_buffer; - - while (obj != NULL) { - if (obj->type == GGML_OBJECT_TENSOR) { - return (struct ggml_tensor *)(mem_buffer + obj->offs); - } - - obj = obj->next; - } - - return NULL; -} - -struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) { - struct ggml_object * obj = ctx->objects_begin; - - char * const mem_buffer = ctx->mem_buffer; - - while (obj != NULL) { - if (obj->type == GGML_OBJECT_TENSOR) { - struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs); - if (strcmp(cur->name, name) == 0) { - return cur; - } - } - - obj = obj->next; - } - - return NULL; -} - -//////////////////////////////////////////////////////////////////////////////// - -// ggml_dup - -static struct ggml_tensor * ggml_dup_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_DUP; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_dup( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_dup_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_dup_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_dup_impl(ctx, a, true); -} - -// ggml_add - -static struct ggml_tensor * ggml_add_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - // TODO: support less-strict constraint - // GGML_ASSERT(ggml_can_repeat(b, a)); - GGML_ASSERT(ggml_can_repeat_rows(b, a)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - // TODO: support backward pass for broadcasting - GGML_ASSERT(ggml_are_same_shape(a, b)); - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_ADD; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_add( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_add_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_add_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_add_impl(ctx, a, b, true); -} - -// ggml_add_cast - -static struct ggml_tensor * ggml_add_cast_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - enum ggml_type type) { - // TODO: support less-strict constraint - // GGML_ASSERT(ggml_can_repeat(b, a)); - GGML_ASSERT(ggml_can_repeat_rows(b, a)); - GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input - - bool is_node = false; - - if (a->grad || b->grad) { - // TODO: support backward pass for broadcasting - GGML_ASSERT(ggml_are_same_shape(a, b)); - is_node = true; - } - - struct ggml_tensor * result = ggml_new_tensor(ctx, type, a->n_dims, a->ne); - - result->op = GGML_OP_ADD; - result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_add_cast( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - enum ggml_type type) { - return ggml_add_cast_impl(ctx, a, b, type); -} - -// ggml_add1 - -static struct ggml_tensor * ggml_add1_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - GGML_ASSERT(ggml_is_scalar(b)); - GGML_ASSERT(ggml_is_padded_1d(a)); - - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_ADD1; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_add1( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_add1_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_add1_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_add1_impl(ctx, a, b, true); -} - -// ggml_acc - -static struct ggml_tensor * ggml_acc_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset, - bool inplace) { - GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a)); - GGML_ASSERT(ggml_is_contiguous(a)); - GGML_ASSERT(a->type == GGML_TYPE_F32); - GGML_ASSERT(b->type == GGML_TYPE_F32); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_ACC; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_acc( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { - return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); -} - -struct ggml_tensor * ggml_acc_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { - return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); -} - -// ggml_sub - -static struct ggml_tensor * ggml_sub_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SUB; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_sub( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_sub_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_sub_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_sub_impl(ctx, a, b, true); -} - -// ggml_mul - -static struct ggml_tensor * ggml_mul_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - // TODO: support less-strict constraint - // GGML_ASSERT(ggml_can_repeat(b, a)); - GGML_ASSERT(ggml_can_repeat_rows(b, a)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - // TODO: support backward pass for broadcasting - GGML_ASSERT(ggml_are_same_shape(a, b)); - is_node = true; - } - - if (inplace) { - GGML_ASSERT(!is_node); - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_MUL; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_mul( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_mul_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_mul_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_mul_impl(ctx, a, b, true); -} - -// ggml_div - -static struct ggml_tensor * ggml_div_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - if (inplace) { - GGML_ASSERT(!is_node); - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_DIV; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_div( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_div_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_div_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_div_impl(ctx, a, b, true); -} - -// ggml_sqr - -static struct ggml_tensor * ggml_sqr_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SQR; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_sqr( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_sqr_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_sqr_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_sqr_impl(ctx, a, true); -} - -// ggml_sqrt - -static struct ggml_tensor * ggml_sqrt_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SQRT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_sqrt( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_sqrt_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_sqrt_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_sqrt_impl(ctx, a, true); -} - - -// ggml_log - -static struct ggml_tensor * ggml_log_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_LOG; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_log( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_log_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_log_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_log_impl(ctx, a, true); -} - -// ggml_sum - -struct ggml_tensor * ggml_sum( - struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); - - result->op = GGML_OP_SUM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - - -// ggml_sum_rows - -struct ggml_tensor * ggml_sum_rows( - struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - int64_t ne[4] = {1,1,1,1}; - for (int i=1; in_dims; ++i) { - ne[i] = a->ne[i]; - } - - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, a->n_dims, ne); - - result->op = GGML_OP_SUM_ROWS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_mean - -struct ggml_tensor * ggml_mean( - struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement - is_node = true; - } - - int64_t ne[GGML_MAX_DIMS] = { 1, a->ne[1], a->ne[2], a->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, ne); - - result->op = GGML_OP_MEAN; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_argmax - -struct ggml_tensor * ggml_argmax( - struct ggml_context * ctx, - struct ggml_tensor * a) { - GGML_ASSERT(ggml_is_matrix(a)); - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); - is_node = true; - } - - int64_t ne[GGML_MAX_DIMS] = { a->ne[1], 1, 1, 1 }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, a->n_dims, ne); - - result->op = GGML_OP_ARGMAX; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_repeat - -struct ggml_tensor * ggml_repeat( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - GGML_ASSERT(ggml_can_repeat(a, b)); - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); - - result->op = GGML_OP_REPEAT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_repeat_back - -struct ggml_tensor * ggml_repeat_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - GGML_ASSERT(ggml_can_repeat(b, a)); - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - if (ggml_are_same_shape(a, b) && !is_node) { - return a; - } - - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); - - result->op = GGML_OP_REPEAT_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_concat - -struct ggml_tensor * ggml_concat( - struct ggml_context* ctx, - struct ggml_tensor* a, - struct ggml_tensor* b) { - GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]); - - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]); - - result->op = GGML_OP_CONCAT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_abs - -struct ggml_tensor * ggml_abs( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary(ctx, a, GGML_UNARY_OP_ABS); -} - -struct ggml_tensor * ggml_abs_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS); -} - -// ggml_sgn - -struct ggml_tensor * ggml_sgn( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary(ctx, a, GGML_UNARY_OP_SGN); -} - -struct ggml_tensor * ggml_sgn_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN); -} - -// ggml_neg - -struct ggml_tensor * ggml_neg( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary(ctx, a, GGML_UNARY_OP_NEG); -} - -struct ggml_tensor * ggml_neg_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG); -} - -// ggml_step - -struct ggml_tensor * ggml_step( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary(ctx, a, GGML_UNARY_OP_STEP); -} - -struct ggml_tensor * ggml_step_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP); -} - -// ggml_tanh - -struct ggml_tensor * ggml_tanh( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary(ctx, a, GGML_UNARY_OP_TANH); -} - -struct ggml_tensor * ggml_tanh_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH); -} - -// ggml_elu - -struct ggml_tensor * ggml_elu( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary(ctx, a, GGML_UNARY_OP_ELU); -} - -struct ggml_tensor * ggml_elu_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU); -} - -// ggml_relu - -struct ggml_tensor * ggml_relu( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary(ctx, a, GGML_UNARY_OP_RELU); -} - -struct ggml_tensor * ggml_relu_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU); -} - -// ggml_gelu - -struct ggml_tensor * ggml_gelu( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary(ctx, a, GGML_UNARY_OP_GELU); -} - -struct ggml_tensor * ggml_gelu_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU); -} - -// ggml_gelu_quick - -struct ggml_tensor * ggml_gelu_quick( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK); -} - -struct ggml_tensor * ggml_gelu_quick_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK); -} - -// ggml_silu - -struct ggml_tensor * ggml_silu( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary(ctx, a, GGML_UNARY_OP_SILU); -} - -struct ggml_tensor * ggml_silu_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU); -} - -// ggml_silu_back - -struct ggml_tensor * ggml_silu_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - bool is_node = false; - - if (a->grad || b->grad) { - // TODO: implement backward - is_node = true; - } - - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SILU_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_norm - -static struct ggml_tensor * ggml_norm_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - float eps, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, &eps, sizeof(eps)); - - result->op = GGML_OP_NORM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_norm( - struct ggml_context * ctx, - struct ggml_tensor * a, - float eps) { - return ggml_norm_impl(ctx, a, eps, false); -} - -struct ggml_tensor * ggml_norm_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - float eps) { - return ggml_norm_impl(ctx, a, eps, true); -} - -// ggml_rms_norm - -static struct ggml_tensor * ggml_rms_norm_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - float eps, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, &eps, sizeof(eps)); - - result->op = GGML_OP_RMS_NORM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_rms_norm( - struct ggml_context * ctx, - struct ggml_tensor * a, - float eps) { - return ggml_rms_norm_impl(ctx, a, eps, false); -} - -struct ggml_tensor * ggml_rms_norm_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - float eps) { - return ggml_rms_norm_impl(ctx, a, eps, true); -} - -// ggml_rms_norm_back - -struct ggml_tensor * ggml_rms_norm_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - float eps) { - bool is_node = false; - - if (a->grad) { - // TODO: implement backward - is_node = true; - } - - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, &eps, sizeof(eps)); - - result->op = GGML_OP_RMS_NORM_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_group_norm - -static struct ggml_tensor * ggml_group_norm_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_groups, - bool inplace) { - - bool is_node = false; - if (!inplace && (a->grad)) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_GROUP_NORM; - result->op_params[0] = n_groups; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; // TODO: maybe store epsilon here? - - return result; -} - -struct ggml_tensor * ggml_group_norm( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_groups) { - return ggml_group_norm_impl(ctx, a, n_groups, false); -} - -struct ggml_tensor * ggml_group_norm_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_groups) { - return ggml_group_norm_impl(ctx, a, n_groups, true); -} - -// ggml_mul_mat - -struct ggml_tensor * ggml_mul_mat( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - GGML_ASSERT(ggml_can_mul_mat(a, b)); - GGML_ASSERT(!ggml_is_transposed(a)); - - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - - const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne); - - result->op = GGML_OP_MUL_MAT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_out_prod - -struct ggml_tensor * ggml_out_prod( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - GGML_ASSERT(ggml_can_out_prod(a, b)); - GGML_ASSERT(!ggml_is_transposed(a)); - - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - - // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] - const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne); - - result->op = GGML_OP_OUT_PROD; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_scale - -static struct ggml_tensor * ggml_scale_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - GGML_ASSERT(ggml_is_scalar(b)); - GGML_ASSERT(ggml_is_padded_1d(a)); - - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SCALE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_scale( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_scale_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_scale_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_scale_impl(ctx, a, b, true); -} - -// ggml_set - -static struct ggml_tensor * ggml_set_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset, - bool inplace) { - GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b)); - - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - - // make a view of the destination - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_SET; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_set( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { - return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false); -} - -struct ggml_tensor * ggml_set_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { - return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true); -} - -struct ggml_tensor * ggml_set_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t offset) { - return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false); -} - -struct ggml_tensor * ggml_set_1d_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t offset) { - return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true); -} - -struct ggml_tensor * ggml_set_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t offset) { - return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); -} - -struct ggml_tensor * ggml_set_2d_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t offset) { - return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); -} - - -// ggml_cpy - -static struct ggml_tensor * ggml_cpy_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - // make a view of the destination - struct ggml_tensor * result = ggml_view_tensor(ctx, b); - if (strlen(b->name) > 0) { - ggml_format_name(result, "%s (copy of %s)", b->name, a->name); - } else { - ggml_format_name(result, "%s (copy)", a->name); - } - - result->op = GGML_OP_CPY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_cpy( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_cpy_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_cpy_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_cpy_impl(ctx, a, b, true); -} - -// ggml_cont - -static struct ggml_tensor * ggml_cont_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_format_name(result, "%s (cont)", a->name); - - result->op = GGML_OP_CONT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_cont( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_cont_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_cont_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_cont_impl(ctx, a, true); -} - - -// make contiguous, with new shape -GGML_API struct ggml_tensor * ggml_cont_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0) { - return ggml_cont_4d(ctx, a, ne0, 1, 1, 1); -} - -GGML_API struct ggml_tensor * ggml_cont_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1) { - return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1); -} - -GGML_API struct ggml_tensor * ggml_cont_3d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2) { - return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1); -} - -struct ggml_tensor * ggml_cont_4d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int64_t ne3) { - GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3)); - - bool is_node = false; - - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); - ggml_format_name(result, "%s (cont)", a->name); - - result->op = GGML_OP_CONT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_reshape - -struct ggml_tensor * ggml_reshape( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - GGML_ASSERT(ggml_is_contiguous(a)); - // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous. - GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - if (b->grad) { - // gradient propagation is not supported - //GGML_ASSERT(false); - } - - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a, 0); - ggml_format_name(result, "%s (reshaped)", a->name); - - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_reshape_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0) { - GGML_ASSERT(ggml_is_contiguous(a)); - GGML_ASSERT(ggml_nelements(a) == ne0); - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - const int64_t ne[1] = { ne0 }; - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0); - ggml_format_name(result, "%s (reshaped)", a->name); - - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_reshape_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1) { - GGML_ASSERT(ggml_is_contiguous(a)); - GGML_ASSERT(ggml_nelements(a) == ne0*ne1); - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - const int64_t ne[2] = { ne0, ne1 }; - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0); - ggml_format_name(result, "%s (reshaped)", a->name); - - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_reshape_3d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2) { - GGML_ASSERT(ggml_is_contiguous(a)); - GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2); - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - const int64_t ne[3] = { ne0, ne1, ne2 }; - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0); - ggml_format_name(result, "%s (reshaped)", a->name); - - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_reshape_4d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int64_t ne3) { - GGML_ASSERT(ggml_is_contiguous(a)); - GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0); - ggml_format_name(result, "%s (reshaped)", a->name); - - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -static struct ggml_tensor * ggml_view_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_dims, - const int64_t * ne, - size_t offset) { - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset); - ggml_format_name(result, "%s (view)", a->name); - - ggml_set_op_params(result, &offset, sizeof(offset)); - - result->op = GGML_OP_VIEW; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_view_1d - -struct ggml_tensor * ggml_view_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - size_t offset) { - - struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset); - - return result; -} - -// ggml_view_2d - -struct ggml_tensor * ggml_view_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - size_t nb1, - size_t offset) { - - const int64_t ne[2] = { ne0, ne1 }; - - struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset); - - result->nb[1] = nb1; - result->nb[2] = result->nb[1]*ne1; - result->nb[3] = result->nb[2]; - - return result; -} - -// ggml_view_3d - -struct ggml_tensor * ggml_view_3d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2, - size_t nb1, - size_t nb2, - size_t offset) { - - const int64_t ne[3] = { ne0, ne1, ne2 }; - - struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset); - - result->nb[1] = nb1; - result->nb[2] = nb2; - result->nb[3] = result->nb[2]*ne2; - - return result; -} - -// ggml_view_4d - -struct ggml_tensor * ggml_view_4d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int64_t ne3, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { - - const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; - - struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset); - - result->nb[1] = nb1; - result->nb[2] = nb2; - result->nb[3] = nb3; - - return result; -} - -// ggml_permute - -struct ggml_tensor * ggml_permute( - struct ggml_context * ctx, - struct ggml_tensor * a, - int axis0, - int axis1, - int axis2, - int axis3) { - GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS); - GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS); - GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS); - GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS); - - GGML_ASSERT(axis0 != axis1); - GGML_ASSERT(axis0 != axis2); - GGML_ASSERT(axis0 != axis3); - GGML_ASSERT(axis1 != axis2); - GGML_ASSERT(axis1 != axis3); - GGML_ASSERT(axis2 != axis3); - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - ggml_format_name(result, "%s (permuted)", a->name); - - int ne[GGML_MAX_DIMS]; - int nb[GGML_MAX_DIMS]; - - ne[axis0] = a->ne[0]; - ne[axis1] = a->ne[1]; - ne[axis2] = a->ne[2]; - ne[axis3] = a->ne[3]; - - nb[axis0] = a->nb[0]; - nb[axis1] = a->nb[1]; - nb[axis2] = a->nb[2]; - nb[axis3] = a->nb[3]; - - result->ne[0] = ne[0]; - result->ne[1] = ne[1]; - result->ne[2] = ne[2]; - result->ne[3] = ne[3]; - - result->nb[0] = nb[0]; - result->nb[1] = nb[1]; - result->nb[2] = nb[2]; - result->nb[3] = nb[3]; - - result->op = GGML_OP_PERMUTE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - int32_t params[] = { axis0, axis1, axis2, axis3 }; - ggml_set_op_params(result, params, sizeof(params)); - - return result; -} - -// ggml_transpose - -struct ggml_tensor * ggml_transpose( - struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - ggml_format_name(result, "%s (transposed)", a->name); - - result->ne[0] = a->ne[1]; - result->ne[1] = a->ne[0]; - - result->nb[0] = a->nb[1]; - result->nb[1] = a->nb[0]; - - result->op = GGML_OP_TRANSPOSE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_get_rows - -struct ggml_tensor * ggml_get_rows( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); - - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - - // TODO: implement non F32 return - //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]); - - result->op = GGML_OP_GET_ROWS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_get_rows_back - -struct ggml_tensor * ggml_get_rows_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c) { - GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0])); - - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - - // TODO: implement non F32 return - //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); - struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]); - - result->op = GGML_OP_GET_ROWS_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_diag - -struct ggml_tensor * ggml_diag( - struct ggml_context * ctx, - struct ggml_tensor * a) { - GGML_ASSERT(a->ne[1] == 1); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, MAX(a->n_dims, 2), ne); - - result->op = GGML_OP_DIAG; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - - -// ggml_diag_mask_inf - -static struct ggml_tensor * ggml_diag_mask_inf_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past, - bool inplace) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - int32_t params[] = { n_past }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_DIAG_MASK_INF; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_diag_mask_inf( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past) { - return ggml_diag_mask_inf_impl(ctx, a, n_past, false); -} - -struct ggml_tensor * ggml_diag_mask_inf_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past) { - return ggml_diag_mask_inf_impl(ctx, a, n_past, true); -} - -// ggml_diag_mask_zero - -static struct ggml_tensor * ggml_diag_mask_zero_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past, - bool inplace) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - int32_t params[] = { n_past }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_DIAG_MASK_ZERO; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_diag_mask_zero( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past) { - return ggml_diag_mask_zero_impl(ctx, a, n_past, false); -} - -struct ggml_tensor * ggml_diag_mask_zero_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past) { - return ggml_diag_mask_zero_impl(ctx, a, n_past, true); -} - -// ggml_soft_max - -static struct ggml_tensor * ggml_soft_max_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SOFT_MAX; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_soft_max( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, false); -} - -struct ggml_tensor * ggml_soft_max_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, true); -} - - -// ggml_soft_max_back - -static struct ggml_tensor * ggml_soft_max_back_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; // TODO : implement backward pass - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_SOFT_MAX_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_soft_max_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_soft_max_back_impl(ctx, a, b, false); -} - -struct ggml_tensor * ggml_soft_max_back_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_soft_max_back_impl(ctx, a, b, true); -} - -// ggml_rope - -static struct ggml_tensor * ggml_rope_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - int mode, - int n_ctx, - float freq_base, - float freq_scale, - float xpos_base, - bool xpos_down, - bool inplace) { - GGML_ASSERT(ggml_is_vector(b)); - GGML_ASSERT(b->type == GGML_TYPE_I32); - GGML_ASSERT(a->ne[2] == b->ne[0]); - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; - memcpy(params + 4, &freq_base, sizeof(float)); - memcpy(params + 5, &freq_scale, sizeof(float)); - memcpy(params + 6, &xpos_base, sizeof(float)); - memcpy(params + 7, &xpos_down, sizeof(bool)); - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_ROPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_rope( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - int mode, - int n_ctx) { - return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); -} - -struct ggml_tensor * ggml_rope_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - int mode, - int n_ctx) { - return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); -} - -struct ggml_tensor * ggml_rope_custom( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - int mode, - int n_ctx, - float freq_base, - float freq_scale) { - return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); -} - -struct ggml_tensor * ggml_rope_custom_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - int mode, - int n_ctx, - float freq_base, - float freq_scale) { - return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); -} - -struct ggml_tensor * ggml_rope_xpos_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - float base, - bool down) { - return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); -} - -// ggml_rope_back - -struct ggml_tensor * ggml_rope_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - int mode, - int n_ctx, - float freq_base, - float freq_scale, - float xpos_base, - bool xpos_down) { - GGML_ASSERT(ggml_is_vector(b)); - GGML_ASSERT(b->type == GGML_TYPE_I32); - GGML_ASSERT(a->ne[2] == b->ne[0]); - - GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); - - bool is_node = false; - - if (a->grad) { - is_node = false; // TODO: implement backward - } - - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - - int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; - memcpy(params + 4, &freq_base, sizeof(float)); - memcpy(params + 5, &freq_scale, sizeof(float)); - memcpy(params + 6, &xpos_base, sizeof(float)); - memcpy(params + 7, &xpos_down, sizeof(bool)); - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_ROPE_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_alibi - -struct ggml_tensor * ggml_alibi( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past, - int n_head, - float bias_max) { - GGML_ASSERT(n_past >= 0); - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - - int32_t op_params[3] = { n_past, n_head }; - memcpy(op_params + 2, &bias_max, sizeof(float)); - ggml_set_op_params(result, op_params, sizeof(op_params)); - - result->op = GGML_OP_ALIBI; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_clamp - -struct ggml_tensor * ggml_clamp( - struct ggml_context * ctx, - struct ggml_tensor * a, - float min, - float max) { - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - - float params[] = { min, max }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_CLAMP; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_conv_1d - -static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) { - return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; -} - -// im2col: [N, IC, IL] => [N, OL, IC*K] -// a: [OC,IC, K] -// b: [N, IC, IL] -// result: [N, OL, IC*K] -static struct ggml_tensor * ggml_conv_1d_stage_0( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int p0, - int d0) { - GGML_ASSERT(a->ne[1] == b->ne[1]); - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t OL = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); - - const int64_t ne[4] = { - a->ne[1] * a->ne[0], - OL, - b->ne[2], - 1, - }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne); - - int32_t params[] = { s0, p0, d0 }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_CONV_1D_STAGE_0; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_conv_1d_stage_1 - -// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] -// a: [OC, IC, K] -// b: [N, OL, IC * K] -// result: [N, OC, OL] -static struct ggml_tensor * ggml_conv_1d_stage_1( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[4] = { - b->ne[1], - a->ne[2], - b->ne[2], - 1, - }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - - result->op = GGML_OP_CONV_1D_STAGE_1; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_conv_1d - -GGML_API struct ggml_tensor * ggml_conv_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int p0, - int d0) { - struct ggml_tensor * result = ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0); - result = ggml_conv_1d_stage_1(ctx, a, result); - return result; -} - -// GGML_API struct ggml_tensor * ggml_conv_1d( -// struct ggml_context * ctx, -// struct ggml_tensor * a, -// struct ggml_tensor * b, -// int s0, -// int p0, -// int d0) { -// GGML_ASSERT(ggml_is_matrix(b)); -// GGML_ASSERT(a->ne[1] == b->ne[1]); -// bool is_node = false; - -// if (a->grad || b->grad) { -// GGML_ASSERT(false); // TODO: implement backward -// is_node = true; -// } - -// const int64_t ne[4] = { -// ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), -// a->ne[2], 1, 1, -// }; -// struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); - -// int32_t params[] = { s0, p0, d0 }; -// ggml_set_op_params(result, params, sizeof(params)); - -// result->op = GGML_OP_CONV_1D; -// result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; -// result->src[0] = a; -// result->src[1] = b; - -// return result; -// } - -// ggml_conv_1d_ph - -struct ggml_tensor* ggml_conv_1d_ph( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s, - int d) { - return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); -} - -// ggml_conv_transpose_1d - -static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { - return (ins - 1) * s - 2 * p + d * (ks - 1) + 1; -} - -GGML_API struct ggml_tensor * ggml_conv_transpose_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int p0, - int d0) { - GGML_ASSERT(ggml_is_matrix(b)); - GGML_ASSERT(a->ne[2] == b->ne[1]); - GGML_ASSERT(a->ne[3] == 1); - - GGML_ASSERT(p0 == 0); - GGML_ASSERT(d0 == 1); - - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[4] = { - ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), - a->ne[1], b->ne[2], 1, - }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - - int32_t params[] = { s0, p0, d0 }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_CONV_TRANSPOSE_1D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_conv_2d - -struct ggml_tensor * ggml_conv_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1) { - - GGML_ASSERT(a->ne[2] == b->ne[2]); - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[4] = { - ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), - ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1), - a->ne[3], b->ne[3], - }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - - int32_t params[] = { s0, s1, p0, p1, d0, d1 }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_CONV_2D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; - -} - -// ggml_conv_2d_sk_p0 - -struct ggml_tensor * ggml_conv_2d_sk_p0( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1); -} - -// ggml_conv_2d_s1_ph - -struct ggml_tensor * ggml_conv_2d_s1_ph( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1); -} - -// ggml_conv_transpose_2d_p0 - -static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { - return (ins - 1) * s - 2 * p + ks; -} - -struct ggml_tensor * ggml_conv_transpose_2d_p0( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int stride) { - GGML_ASSERT(a->ne[3] == b->ne[2]); - - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[4] = { - ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/), - ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/), - a->ne[2], b->ne[3], - }; - - struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - - ggml_set_op_params_i32(result, 0, stride); - - result->op = GGML_OP_CONV_TRANSPOSE_2D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_pool_* - -static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) { - return (ins + 2 * p - ks) / s + 1; -} - -// ggml_pool_1d - -struct ggml_tensor * ggml_pool_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_op_pool op, - int k0, - int s0, - int p0) { - - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[3] = { - ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), - a->ne[1], - }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); - - int32_t params[] = { op, k0, s0, p0 }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_POOL_1D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_pool_2d - -struct ggml_tensor * ggml_pool_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_op_pool op, - int k0, - int k1, - int s0, - int s1, - int p0, - int p1) { - - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[3] = { - ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), - ggml_calc_pool_output_size(a->ne[1], k1, s1, p1), - a->ne[2], - }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); - - int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_POOL_2D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_upscale - -static struct ggml_tensor * ggml_upscale_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - int scale_factor) { - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, - a->ne[0] * scale_factor, - a->ne[1] * scale_factor, - a->ne[2], a->ne[3]); - - result->op = GGML_OP_UPSCALE; - result->op_params[0] = scale_factor; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; - - return result; -} - -struct ggml_tensor * ggml_upscale( - struct ggml_context * ctx, - struct ggml_tensor * a, - int scale_factor) { - return ggml_upscale_impl(ctx, a, scale_factor); -} - -// ggml_flash_attn - -struct ggml_tensor * ggml_flash_attn( - struct ggml_context * ctx, - struct ggml_tensor * q, - struct ggml_tensor * k, - struct ggml_tensor * v, - bool masked) { - GGML_ASSERT(ggml_can_mul_mat(k, q)); - // TODO: check if vT can be multiplied by (k*qT) - - bool is_node = false; - - if (q->grad || k->grad || v->grad) { - is_node = true; - } - - //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, q->n_dims, q->ne); - - int32_t t = masked ? 1 : 0; - ggml_set_op_params(result, &t, sizeof(t)); - - result->op = GGML_OP_FLASH_ATTN; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = q; - result->src[1] = k; - result->src[2] = v; - - return result; -} - -// ggml_flash_ff - -struct ggml_tensor * ggml_flash_ff( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b0, - struct ggml_tensor * b1, - struct ggml_tensor * c0, - struct ggml_tensor * c1) { - GGML_ASSERT(ggml_can_mul_mat(b0, a)); - // TODO: more checks - - bool is_node = false; - - if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { - is_node = true; - } - - //struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne); - - result->op = GGML_OP_FLASH_FF; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b0; - result->src[2] = b1; - result->src[3] = c0; - result->src[4] = c1; - - return result; -} - -// ggml_flash_attn_back - -struct ggml_tensor * ggml_flash_attn_back( - struct ggml_context * ctx, - struct ggml_tensor * q, - struct ggml_tensor * k, - struct ggml_tensor * v, - struct ggml_tensor * d, - bool masked) { - GGML_ASSERT(ggml_can_mul_mat(k, q)); - // TODO: check if vT can be multiplied by (k*qT) - - // d shape [D,N,ne2,ne3] - // q shape [D,N,ne2,ne3] - // k shape [D,M,kvne2,ne3] - // v shape [M,D,kvne2,ne3] - - const int64_t D = q->ne[0]; - const int64_t N = q->ne[1]; - const int64_t M = k->ne[1]; - const int64_t ne2 = q->ne[2]; - const int64_t ne3 = q->ne[3]; - const int64_t kvne2 = k->ne[2]; - - GGML_ASSERT(k->ne[0] == D); - GGML_ASSERT(v->ne[0] == M); - GGML_ASSERT(v->ne[1] == D); - GGML_ASSERT(d->ne[0] == D); - GGML_ASSERT(d->ne[1] == N); - GGML_ASSERT(k->ne[2] == kvne2); - GGML_ASSERT(k->ne[3] == ne3); - GGML_ASSERT(v->ne[2] == kvne2); - GGML_ASSERT(v->ne[3] == ne3); - GGML_ASSERT(d->ne[2] == ne2); - GGML_ASSERT(d->ne[3] == ne3); - - GGML_ASSERT(ne2 % kvne2 == 0); - - bool is_node = false; - - if (q->grad || k->grad || v->grad) { - // when using this operation (in backwards pass) these grads are set. - // we don't want to create (big) grad of our result, so is_node is false. - is_node = false; - } - - // store gradients of q, k and v as continuous tensors concatenated in result. - // note: v and gradv are actually transposed, i.e. v->ne[0] != D. - const int64_t elem_q = ggml_nelements(q); - const int64_t elem_k = ggml_nelements(k); - const int64_t elem_v = ggml_nelements(v); - - enum ggml_type result_type = GGML_TYPE_F32; - GGML_ASSERT(ggml_blck_size(result_type) == 1); - const size_t tsize = ggml_type_size(result_type); - - const size_t offs_q = 0; - const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); - const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); - const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN); - - const size_t nelements = (end + tsize - 1)/tsize; - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements); - - int32_t masked_i = masked ? 1 : 0; - ggml_set_op_params(result, &masked_i, sizeof(masked_i)); - - result->op = GGML_OP_FLASH_ATTN_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = q; - result->src[1] = k; - result->src[2] = v; - result->src[3] = d; - - return result; -} - -// ggml_win_part - -struct ggml_tensor * ggml_win_part( - struct ggml_context * ctx, - struct ggml_tensor * a, - int w) { - GGML_ASSERT(a->ne[3] == 1); - GGML_ASSERT(a->type == GGML_TYPE_F32); - - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - // padding - const int px = (w - a->ne[1]%w)%w; - const int py = (w - a->ne[2]%w)%w; - - const int npx = (px + a->ne[1])/w; - const int npy = (py + a->ne[2])/w; - const int np = npx*npy; - - const int64_t ne[4] = { a->ne[0], w, w, np, }; - - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - - int32_t params[] = { npx, npy, w }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_WIN_PART; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_win_unpart - -struct ggml_tensor * ggml_win_unpart( - struct ggml_context * ctx, - struct ggml_tensor * a, - int w0, - int h0, - int w) { - GGML_ASSERT(a->type == GGML_TYPE_F32); - - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); - - int32_t params[] = { w }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_WIN_UNPART; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -// ggml_get_rel_pos - -struct ggml_tensor * ggml_get_rel_pos( - struct ggml_context * ctx, - struct ggml_tensor * a, - int qh, - int kh) { - GGML_ASSERT(qh == kh); - GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]); - - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne); - - result->op = GGML_OP_GET_REL_POS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = NULL; - - return result; -} - -// ggml_add_rel_pos - -static struct ggml_tensor * ggml_add_rel_pos_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * pw, - struct ggml_tensor * ph, - bool inplace) { - GGML_ASSERT(ggml_are_same_shape(pw, ph)); - GGML_ASSERT(ggml_is_contiguous(a)); - GGML_ASSERT(ggml_is_contiguous(pw)); - GGML_ASSERT(ggml_is_contiguous(ph)); - GGML_ASSERT(ph->type == GGML_TYPE_F32); - GGML_ASSERT(pw->type == GGML_TYPE_F32); - GGML_ASSERT(pw->ne[3] == a->ne[2]); - GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]); - GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]); - - bool is_node = false; - - if (!inplace && (a->grad || pw->grad || ph->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_set_op_params_i32(result, 0, inplace ? 1 : 0); - - result->op = GGML_OP_ADD_REL_POS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = pw; - result->src[2] = ph; - - return result; -} - - -struct ggml_tensor * ggml_add_rel_pos( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * pw, - struct ggml_tensor * ph) { - return ggml_add_rel_pos_impl(ctx, a, pw, ph, false); -} - -struct ggml_tensor * ggml_add_rel_pos_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * pw, - struct ggml_tensor * ph) { - return ggml_add_rel_pos_impl(ctx, a, pw, ph, true); -} - -// gmml_unary - -static struct ggml_tensor * ggml_unary_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_unary_op op, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params_i32(result, 0, (int32_t) op); - - result->op = GGML_OP_UNARY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_unary( - struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_unary_op op) { - return ggml_unary_impl(ctx, a, op, false); -} - -struct ggml_tensor * ggml_unary_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_unary_op op) { - return ggml_unary_impl(ctx, a, op, true); -} - -// ggml_map_unary - -static struct ggml_tensor * ggml_map_unary_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_unary_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_UNARY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_map_unary_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_unary_op_f32_t fun) { - return ggml_map_unary_impl_f32(ctx, a, fun, false); -} - -struct ggml_tensor * ggml_map_unary_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_unary_op_f32_t fun) { - return ggml_map_unary_impl_f32(ctx, a, fun, true); -} - -// ggml_map_binary - -static struct ggml_tensor * ggml_map_binary_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_binary_op_f32_t fun, - bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_BINARY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_map_binary_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_binary_op_f32_t fun) { - return ggml_map_binary_impl_f32(ctx, a, b, fun, false); -} - -struct ggml_tensor * ggml_map_binary_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_binary_op_f32_t fun) { - return ggml_map_binary_impl_f32(ctx, a, b, fun, true); -} - -// ggml_map_custom1_f32 - -static struct ggml_tensor * ggml_map_custom1_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_CUSTOM1_F32; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_map_custom1_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_f32_t fun) { - return ggml_map_custom1_impl_f32(ctx, a, fun, false); -} - -struct ggml_tensor * ggml_map_custom1_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_f32_t fun) { - return ggml_map_custom1_impl_f32(ctx, a, fun, true); -} - -// ggml_map_custom2_f32 - -static struct ggml_tensor * ggml_map_custom2_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_CUSTOM2_F32; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_map_custom2_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_f32_t fun) { - return ggml_map_custom2_impl_f32(ctx, a, b, fun, false); -} - -struct ggml_tensor * ggml_map_custom2_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_f32_t fun) { - return ggml_map_custom2_impl_f32(ctx, a, b, fun, true); -} - -// ggml_map_custom3_f32 - -static struct ggml_tensor * ggml_map_custom3_impl_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad || b->grad || c->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - - result->op = GGML_OP_MAP_CUSTOM3_F32; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - result->src[2] = c; - - return result; -} - -struct ggml_tensor * ggml_map_custom3_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_f32_t fun) { - return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false); -} - -struct ggml_tensor * ggml_map_custom3_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_f32_t fun) { - return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true); -} - -// ggml_map_custom1 -struct ggml_map_custom1_op_params { - ggml_custom1_op_t fun; - int n_tasks; - void * userdata; -}; - -static struct ggml_tensor * ggml_map_custom1_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { - GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); - - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - struct ggml_map_custom1_op_params params = { - /*.fun =*/ fun, - /*.n_tasks =*/ n_tasks, - /*.userdata =*/ userdata - }; - ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - - result->op = GGML_OP_MAP_CUSTOM1; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - -struct ggml_tensor * ggml_map_custom1( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_t fun, - int n_tasks, - void * userdata) { - return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false); -} - -struct ggml_tensor * ggml_map_custom1_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_t fun, - int n_tasks, - void * userdata) { - return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true); -} - -// ggml_map_custom2 - -struct ggml_map_custom2_op_params { - ggml_custom2_op_t fun; - int n_tasks; - void * userdata; -}; - -static struct ggml_tensor * ggml_map_custom2_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { - GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - struct ggml_map_custom2_op_params params = { - /*.fun =*/ fun, - /*.n_tasks =*/ n_tasks, - /*.userdata =*/ userdata - }; - ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - - result->op = GGML_OP_MAP_CUSTOM2; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_map_custom2( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_t fun, - int n_tasks, - void * userdata) { - return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false); -} - -struct ggml_tensor * ggml_map_custom2_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_t fun, - int n_tasks, - void * userdata) { - return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true); -} - -// ggml_map_custom3 - -struct ggml_map_custom3_op_params { - ggml_custom3_op_t fun; - int n_tasks; - void * userdata; -}; - -static struct ggml_tensor * ggml_map_custom3_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { - GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); - - bool is_node = false; - - if (!inplace && (a->grad || b->grad || c->grad)) { - is_node = true; - } - - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - struct ggml_map_custom3_op_params params = { - /*.fun =*/ fun, - /*.n_tasks =*/ n_tasks, - /*.userdata =*/ userdata - }; - ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - - result->op = GGML_OP_MAP_CUSTOM3; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - result->src[2] = c; - - return result; -} - -struct ggml_tensor * ggml_map_custom3( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_t fun, - int n_tasks, - void * userdata) { - return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false); -} - -struct ggml_tensor * ggml_map_custom3_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_t fun, - int n_tasks, - void * userdata) { - return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); -} - - - -// ggml_cross_entropy_loss - -struct ggml_tensor * ggml_cross_entropy_loss( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { - GGML_ASSERT(ggml_are_same_shape(a, b)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); - - result->op = GGML_OP_CROSS_ENTROPY_LOSS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -// ggml_cross_entropy_loss_back - -struct ggml_tensor * ggml_cross_entropy_loss_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c) { - GGML_ASSERT(ggml_are_same_shape(a, b)); - GGML_ASSERT(ggml_is_scalar(c)); - - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - - result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; - result->grad = NULL; - result->src[0] = a; - result->src[1] = b; - result->src[2] = c; - - return result; -} - -//////////////////////////////////////////////////////////////////////////////// - -void ggml_set_param( - struct ggml_context * ctx, - struct ggml_tensor * tensor) { - tensor->is_param = true; - - GGML_ASSERT(tensor->grad == NULL); - tensor->grad = ggml_dup_tensor(ctx, tensor); - ggml_format_name(tensor->grad, "%s (grad)", tensor->name); -} - -// ggml_compute_forward_dup - -static void ggml_compute_forward_dup_same_cont( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == dst->type); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const size_t nb00 = src0->nb[0]; - const size_t nb0 = dst->nb[0]; - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - // parallelize by elements - const int ne = ggml_nelements(dst); - const int dr = (ne + nth - 1) / nth; - const int ie0 = dr * ith; - const int ie1 = MIN(ie0 + dr, ne); - - if (ie0 < ie1) { - memcpy( - ((char *) dst->data + ie0*nb0), - ((char *) src0->data + ie0*nb00), - (ie1 - ie0) * ggml_type_size(src0->type)); - } - -} -static void ggml_compute_forward_dup_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_TENSOR_UNARY_OP_LOCALS - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { - ggml_compute_forward_dup_same_cont(params, src0, dst); - return; - } - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { - // copy by rows - const size_t rs = ne00*nb00; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy - - if (ggml_is_contiguous(dst)) { - if (nb00 == sizeof(ggml_fp16_t)) { - if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - const size_t rs = ne00 * nb00; - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (type_traits[dst->type].from_float) { - ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; - float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - - size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - for (int i00 = 0; i00 < ne00; i00++) { - src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]); - } - - quantize_row_q(src0_f32, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - GGML_ASSERT(false); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - GGML_ASSERT(false); // TODO: implement - } - } - return; - } - - // dst counters - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == GGML_TYPE_F16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t)); - - if (++i10 == ne00) { - i10 = 0; - if (++i11 == ne01) { - i11 = 0; - if (++i12 == ne02) { - i12 = 0; - if (++i13 == ne03) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else { - GGML_ASSERT(false); // TODO: implement - } -} - -static void ggml_compute_forward_dup_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_TENSOR_UNARY_OP_LOCALS - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { - ggml_compute_forward_dup_same_cont(params, src0, dst); - return; - } - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { - // copy by rows - const size_t rs = ne00*nb00; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - if (ggml_is_contiguous(dst)) { - // TODO: simplify - if (nb00 == sizeof(float)) { - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - const size_t rs = ne00 * nb00; - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else if (type_traits[dst->type].from_float) { - ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; - - size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - quantize_row_q(src0_ptr, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - GGML_ASSERT(false); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - GGML_ASSERT(false); // TODO: implement - } - } - - return; - } - - // dst counters - - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(float)); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_F16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else { - GGML_ASSERT(false); // TODO: implement - } -} - -static void ggml_compute_forward_dup( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { - ggml_compute_forward_dup_same_cont(params, src0, dst); - return; - } - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_dup_f16(params, src0, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_dup_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_add - -static void ggml_compute_forward_add_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - -#ifdef GGML_USE_ACCELERATE - vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00); -#else - ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr); -#endif - } - } else { - // src1 is not contiguous - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); - - dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; - } - } - } -} - -static void ggml_compute_forward_add_f16_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F16); - - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); - } - } - } - else { - // src1 is not contiguous - GGML_ASSERT(false); - } -} - -static void ggml_compute_forward_add_f16_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F16); - - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(ggml_fp16_t)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i])); - } - } - } - else { - // src1 is not contiguous - GGML_ASSERT(false); - } -} - -static void ggml_compute_forward_add_q_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const enum ggml_type type = src0->type; - const enum ggml_type dtype = dst->type; - ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float; - - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == ggml_type_size(type)); - GGML_ASSERT(nb10 == sizeof(float)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - GGML_ASSERT(ggml_is_quantized(src0->type)); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - // src1 and dst are same shape as src0 => same indices - const int i13 = i03; - const int i12 = i02; - const int i11 = i01; - - const int i3 = i03; - const int i2 = i02; - const int i1 = i01; - - void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)); - void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - assert(ne00 % 32 == 0); - - // unquantize row from src0 to temp buffer - dequantize_row_q(src0_row, wdata, ne00); - // add src1 - ggml_vec_acc_f32(ne00, wdata, src1_row); - // quantize row to dst - if (quantize_row_q != NULL) { - quantize_row_q(wdata, dst_row, ne00); - } else { - memcpy(dst_row, wdata, ne0*nb0); - } - } -} - -static void ggml_compute_forward_add( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_add_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F16: - { - if (src1->type == GGML_TYPE_F16) { - ggml_compute_forward_add_f16_f16(params, src0, src1, dst); - } - else if (src1->type == GGML_TYPE_F32) { - ggml_compute_forward_add_f16_f32(params, src0, src1, dst); - } - else { - GGML_ASSERT(false); - } - } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - { - ggml_compute_forward_add_q_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_add1 - -static void ggml_compute_forward_add1_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - -#ifdef GGML_USE_ACCELERATE - UNUSED(ggml_vec_add1_f32); - - vDSP_vadd( - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) src1->data), 0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); -#else - ggml_vec_add1_f32(ne0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - *(float *) src1->data); -#endif - } -} - -static void ggml_compute_forward_add1_f16_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // scalar to add - const float v = *(float *) src1->data; - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F16); - - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v); - } - } -} - -static void ggml_compute_forward_add1_f16_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // scalar to add - const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F16); - - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v); - } - } -} - -static void ggml_compute_forward_add1_q_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // scalar to add - const float v = *(float *) src1->data; - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_UNARY_OP_LOCALS - - const enum ggml_type type = src0->type; - ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - ggml_from_float_t const quantize_row_q = type_traits[type].from_float; - - // we don't support permuted src0 - GGML_ASSERT(nb00 == ggml_type_size(type)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - GGML_ASSERT(ggml_is_quantized(src0->type)); - GGML_ASSERT(dst->type == src0->type); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03)); - void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 )); - - assert(ne0 % 32 == 0); - - // unquantize row from src0 to temp buffer - dequantize_row_q(src0_row, wdata, ne0); - // add src1 - ggml_vec_acc1_f32(ne0, wdata, v); - // quantize row to dst - quantize_row_q(wdata, dst_row, ne0); - } -} - -static void ggml_compute_forward_add1( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_add1_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F16: - { - if (src1->type == GGML_TYPE_F16) { - ggml_compute_forward_add1_f16_f16(params, src0, src1, dst); - } - else if (src1->type == GGML_TYPE_F32) { - ggml_compute_forward_add1_f16_f32(params, src0, src1, dst); - } - else { - GGML_ASSERT(false); - } - } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - { - ggml_compute_forward_add1_q_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - - -// ggml_compute_forward_acc - -static void ggml_compute_forward_acc_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - - // view src0 and dst with these strides and data offset inbytes during acc - // nb0 is implicitely element_size because src0 and dst are contiguous - size_t nb1 = ((int32_t *) dst->op_params)[0]; - size_t nb2 = ((int32_t *) dst->op_params)[1]; - size_t nb3 = ((int32_t *) dst->op_params)[2]; - size_t offset = ((int32_t *) dst->op_params)[3]; - bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace && (params->type == GGML_TASK_INIT)) { - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - memcpy( - ((char *) dst->data), - ((char *) src0->data), - ggml_nbytes(dst)); - } - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src1); - const int nc = src1->ne[0]; - - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) - GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) - - // src0 and dst as viewed during acc - const size_t nb0 = ggml_element_size(src0); - - const size_t nb00 = nb0; - const size_t nb01 = nb1; - const size_t nb02 = nb2; - const size_t nb03 = nb3; - - GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst)); - GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0)); - - GGML_ASSERT(nb10 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are viewed with shape of src1 and offset - // => same indices - const int i3 = ir/(ne12*ne11); - const int i2 = (ir - i3*ne12*ne11)/ne11; - const int i1 = (ir - i3*ne12*ne11 - i2*ne11); - -#ifdef GGML_USE_ACCELERATE - vDSP_vadd( - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1, - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc); -#else - ggml_vec_add_f32(nc, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); -#endif - } -} - -static void ggml_compute_forward_acc( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_acc_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_sub - -static void ggml_compute_forward_sub_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - if (nb10 == sizeof(float)) { - for (int ir = 0; ir < nr; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - -#ifdef GGML_USE_ACCELERATE - vDSP_vsub( - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); -#else - ggml_vec_sub_f32(ne0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); -#endif - // } - // } - } - } else { - // src1 is not contiguous - for (int ir = 0; ir < nr; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); - - dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; - } - } - } -} - -static void ggml_compute_forward_sub( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sub_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_mul - -static void ggml_compute_forward_mul_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - const int ith = params->ith; - const int nth = params->nth; - -#ifdef GGML_USE_CLBLAST - if (src1->backend == GGML_BACKEND_GPU) { - if (ith == 0) { - ggml_cl_mul(src0, src1, dst); - } - return; - } -#endif - - const int64_t nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(ne00 == ne10); - - if (nb10 == sizeof(float)) { - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - -#ifdef GGML_USE_ACCELERATE - UNUSED(ggml_vec_mul_f32); - - vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00); -#else - ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr); -#endif - // } - // } - } - } else { - // src1 is not contiguous - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int64_t i0 = 0; i0 < ne00; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); - - dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); - } - } - } -} - -static void ggml_compute_forward_mul( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now"); - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_mul_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_div - -static void ggml_compute_forward_div_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - if (nb10 == sizeof(float)) { - for (int ir = 0; ir < nr; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - -#ifdef GGML_USE_ACCELERATE - UNUSED(ggml_vec_div_f32); - - vDSP_vdiv( - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); -#else - ggml_vec_div_f32(ne0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); -#endif - // } - // } - } - } else { - // src1 is not contiguous - for (int ir = 0; ir < nr; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); - - dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr); - } - } - } -} - -static void ggml_compute_forward_div( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_div_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_sqr - -static void ggml_compute_forward_sqr_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_sqr_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sqr( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sqr_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_sqrt - -static void ggml_compute_forward_sqrt_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_sqrt_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sqrt( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sqrt_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - - -// ggml_compute_forward_log - -static void ggml_compute_forward_log_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - GGML_ASSERT( dst->nb[0] == sizeof(float)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_log_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_log( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_log_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_sum - -static void ggml_compute_forward_sum_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_is_scalar(dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - assert(ggml_is_scalar(dst)); - assert(src0->nb[0] == sizeof(float)); - - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - - ggml_float sum = 0; - ggml_float row_sum = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f32_ggf(ne00, - &row_sum, - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - sum += row_sum; - } - } - } - ((float *) dst->data)[0] = sum; -} - -static void ggml_compute_forward_sum_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_is_scalar(dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - assert(src0->nb[0] == sizeof(ggml_fp16_t)); - - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - - float sum = 0; - float row_sum = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f16_ggf(ne00, - &row_sum, - (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); - sum += row_sum; - } - } - } - ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum); -} - -static void ggml_compute_forward_sum( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sum_f32(params, src0, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_sum_f16(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_sum_rows - -static void ggml_compute_forward_sum_rows_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT(dst->nb[0] == sizeof(float)); - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT(ne0 == 1); - GGML_ASSERT(ne1 == ne01); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne3 == ne03); - - for (int64_t i3 = 0; i3 < ne03; i3++) { - for (int64_t i2 = 0; i2 < ne02; i2++) { - for (int64_t i1 = 0; i1 < ne01; i1++) { - float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); - float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); - float row_sum = 0; - ggml_vec_sum_f32(ne00, &row_sum, src_row); - dst_row[0] = row_sum; - } - } - } -} - -static void ggml_compute_forward_sum_rows( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sum_rows_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_mean - -static void ggml_compute_forward_mean_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - assert(src0->nb[0] == sizeof(float)); - - GGML_TENSOR_UNARY_OP_LOCALS - - assert(ne0 == 1); - assert(ne1 == ne01); - assert(ne2 == ne02); - assert(ne3 == ne03); - - UNUSED(ne0); - UNUSED(ne1); - UNUSED(ne2); - UNUSED(ne3); - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f32(ne00, - (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - - *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; - } - } - } -} - -static void ggml_compute_forward_mean( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_mean_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_argmax - -static void ggml_compute_forward_argmax_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - assert(src0->nb[0] == sizeof(float)); - assert(dst->nb[0] == sizeof(float)); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - - const size_t nb01 = src0->nb[1]; - const size_t nb0 = dst->nb[0]; - - for (int64_t i1 = 0; i1 < ne01; i1++) { - float * src = (float *) ((char *) src0->data + i1*nb01); - int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0); - int v = 0; - ggml_vec_argmax_f32(ne00, &v, src); - dst_[0] = v; - } -} - -static void ggml_compute_forward_argmax( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_argmax_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_repeat - -static void ggml_compute_forward_repeat_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_can_repeat(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_TENSOR_UNARY_OP_LOCALS - - // guaranteed to be an integer due to the check in ggml_can_repeat - const int nr0 = (int)(ne0/ne00); - const int nr1 = (int)(ne1/ne01); - const int nr2 = (int)(ne2/ne02); - const int nr3 = (int)(ne3/ne03); - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - // TODO: maybe this is not optimal? - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne03; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne02; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne01; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - ggml_vec_cpy_f32(ne00, - (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0), - (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01)); - } - } - } - } - } - } - } -} - -static void ggml_compute_forward_repeat_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_can_repeat(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_TENSOR_UNARY_OP_LOCALS; - - // guaranteed to be an integer due to the check in ggml_can_repeat - const int nr0 = (int)(ne0/ne00); - const int nr1 = (int)(ne1/ne01); - const int nr2 = (int)(ne2/ne02); - const int nr3 = (int)(ne3/ne03); - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // TODO: maybe this is not optimal? - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne03; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne02; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne01; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0); - ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01); - // ggml_vec_cpy_f16(ne00, y, x) - for (int i = 0; i < ne00; ++i) { - y[i] = x[i]; - } - } - } - } - } - } - } - } -} - -static void ggml_compute_forward_repeat( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_repeat_f16(params, src0, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_repeat_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_repeat_back - -static void ggml_compute_forward_repeat_back_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_can_repeat(dst, src0)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_TENSOR_UNARY_OP_LOCALS - - // guaranteed to be an integer due to the check in ggml_can_repeat - const int nr0 = (int)(ne00/ne0); - const int nr1 = (int)(ne01/ne1); - const int nr2 = (int)(ne02/ne2); - const int nr3 = (int)(ne03/ne3); - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - if (ggml_is_contiguous(dst)) { - ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - } else { - for (int k3 = 0; k3 < ne3; k3++) { - for (int k2 = 0; k2 < ne2; k2++) { - for (int k1 = 0; k1 < ne1; k1++) { - ggml_vec_set_f32(ne0, - (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3), - 0); - } - } - } - } - - // TODO: maybe this is not optimal? - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne3; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne2; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne1; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - ggml_vec_acc_f32(ne0, - (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1), - (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00)); - } - } - } - } - } - } - } -} - -static void ggml_compute_forward_repeat_back( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_repeat_back_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_concat - -static void ggml_compute_forward_concat_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - - GGML_TENSOR_BINARY_OP_LOCALS - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = ith; i2 < ne2; i2++) { - if (i2 < ne02) { // src0 - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03); - - float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); - *y = *x; - } - } - } // src1 - else { - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13); - - float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); - *y = *x; - } - } - } - } - } -} - -static void ggml_compute_forward_concat( - const struct ggml_compute_params* params, - const struct ggml_tensor* src0, - const struct ggml_tensor* src1, - struct ggml_tensor* dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_concat_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_abs - -static void ggml_compute_forward_abs_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_abs_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_abs( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_abs_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_sgn - -static void ggml_compute_forward_sgn_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_sgn_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sgn( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sgn_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_neg - -static void ggml_compute_forward_neg_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_neg_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_neg( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_neg_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_step - -static void ggml_compute_forward_step_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_step_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_step( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_step_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_tanh - -static void ggml_compute_forward_tanh_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_tanh_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_tanh( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_tanh_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_elu - -static void ggml_compute_forward_elu_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_elu_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_elu( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_elu_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_relu - -static void ggml_compute_forward_relu_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_relu_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_relu( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_relu_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_gelu - -static void ggml_compute_forward_gelu_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_gelu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void ggml_compute_forward_gelu( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_gelu_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_gelu_quick - -static void ggml_compute_forward_gelu_quick_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_gelu_quick_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void ggml_compute_forward_gelu_quick( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_gelu_quick_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_silu - -static void ggml_compute_forward_silu_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_silu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void ggml_compute_forward_silu( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_silu_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_silu_back - -static void ggml_compute_forward_silu_back_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * grad, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_are_same_shape(src0, grad)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_silu_backward_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1])), - (float *) ((char *) grad->data + i1*(grad->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void ggml_compute_forward_silu_back( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * grad, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_silu_back_f32(params, src0, grad, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_norm - -static void ggml_compute_forward_norm_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - // TODO: optimize - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - ggml_float sum = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)x[i00]; - } - - float mean = sum/ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - ggml_float sum2 = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sum2 += (ggml_float)(v*v); - } - - float variance = sum2/ne00; - const float scale = 1.0f/sqrtf(variance + eps); - - ggml_vec_scale_f32(ne00, y, scale); - } - } - } -} - -static void ggml_compute_forward_norm( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_norm_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_group_rms_norm - -static void ggml_compute_forward_rms_norm_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - // TODO: optimize - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - ggml_float sum = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)(x[i00] * x[i00]); - } - - const float mean = sum/ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - memcpy(y, x, ne00 * sizeof(float)); - // for (int i00 = 0; i00 < ne00; i00++) { - // y[i00] = x[i00]; - // } - - const float scale = 1.0f/sqrtf(mean + eps); - - ggml_vec_scale_f32(ne00, y, scale); - } - } - } -} - -static void ggml_compute_forward_rms_norm( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_rms_norm_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -static void ggml_compute_forward_rms_norm_back_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_BINARY_OP_LOCALS - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - // TODO: optimize - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - // src1 is same shape as src0 => same indices - const int64_t i11 = i01; - const int64_t i12 = i02; - const int64_t i13 = i03; - - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); - - ggml_float sum_xx = 0.0; - ggml_float sum_xdz = 0.0; - - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum_xx += (ggml_float)(x[i00] * x[i00]); - sum_xdz += (ggml_float)(x[i00] * dz[i00]); - } - - //const float mean = (float)(sum_xx)/ne00; - const float mean_eps = (float)(sum_xx)/ne00 + eps; - const float sum_eps = (float)(sum_xx) + eps*ne00; - //const float mean_xdz = (float)(sum_xdz)/ne00; - // we could cache rms from forward pass to improve performance. - // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms. - //const float rms = sqrtf(mean_eps); - const float rrms = 1.0f / sqrtf(mean_eps); - //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3) - - { - // z = rms_norm(x) - // - // rms_norm(src0) = - // scale( - // src0, - // div( - // 1, - // sqrt( - // add( - // scale( - // sum( - // sqr( - // src0)), - // (1.0/N)), - // eps)))); - - // postorder: - // ## op args grad - // 00 param src0 grad[#00] - // 01 const 1 - // 02 sqr (#00) grad[#02] - // 03 sum (#02) grad[#03] - // 04 const 1/N - // 05 scale (#03, #04) grad[#05] - // 06 const eps - // 07 add (#05, #06) grad[#07] - // 08 sqrt (#07) grad[#08] - // 09 div (#01,#08) grad[#09] - // 10 scale (#00,#09) grad[#10] - // - // backward pass, given grad[#10] - // #10: scale - // grad[#00] += scale(grad[#10],#09) - // grad[#09] += sum(mul(grad[#10],#00)) - // #09: div - // grad[#08] += neg(mul(grad[#09], div(#09,#08))) - // #08: sqrt - // grad[#07] += mul(grad[#08], div(0.5, #08)) - // #07: add - // grad[#05] += grad[#07] - // #05: scale - // grad[#03] += scale(grad[#05],#04) - // #03: sum - // grad[#02] += repeat(grad[#03], #02) - // #02: - // grad[#00] += scale(mul(#00, grad[#02]), 2.0) - // - // substitute and simplify: - // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) - // grad[#02] = repeat(grad[#03], #02) - // grad[#02] = repeat(scale(grad[#05],#04), #02) - // grad[#02] = repeat(scale(grad[#07],#04), #02) - // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02) - // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02) - // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02) - // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02) - // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02) - // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02) - // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02) - // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) - // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0) - // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0) - // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N))) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N)) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps)) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps))) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps)) - // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps)) - // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps)) - // a = b*c + d*e - // a = b*c*f/f + d*e*f/f - // a = (b*c*f + d*e*f)*(1/f) - // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c)) - // a = (b + d*e/c)*c - // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps) - // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms - // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms - // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms - // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms - // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms - // a = (dz + x*div(-mean_xdz,mean_eps))*rrms - // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms) - // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) - // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) - } - // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) - // post-order: - // dx := x - // dx := scale(dx,-mean_xdz/mean_eps) - // dx := add(dx, dz) - // dx := scale(dx, rrms) - float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - ggml_vec_cpy_f32 (ne00, dx, x); - // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); - ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); - ggml_vec_acc_f32 (ne00, dx, dz); - ggml_vec_scale_f32(ne00, dx, rrms); - } - } - } -} - -static void ggml_compute_forward_rms_norm_back( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_rms_norm_back_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_group_norm - -static void ggml_compute_forward_group_norm_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - const float eps = 1e-6f; // TODO: make this a parameter - - // TODO: optimize - - int n_channels = src0->ne[2]; - int n_groups = dst->op_params[0]; - int n_channels_per_group = (n_channels + n_groups - 1) / n_groups; - for (int i = ith; i < n_groups; i+=nth) { - int start = i * n_channels_per_group; - int end = start + n_channels_per_group; - if (end > n_channels) { - end = n_channels; - } - int step = end - start; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - ggml_float sum = 0.0; - for (int64_t i02 = start; i02 < end; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); - - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)x[i00]; - } - } - } - float mean = sum / (ne00 * ne01 * step); - ggml_float sum2 = 0.0; - - for (int64_t i02 = start; i02 < end; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); - - float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); - - for (int64_t i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sum2 += (ggml_float)(v * v); - } - } - } - float variance = sum2 / (ne00 * ne01 * step); - const float scale = 1.0f / sqrtf(variance + eps); - - for (int64_t i02 = start; i02 < end; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); - ggml_vec_scale_f32(ne00, y, scale); - } - } - } - } -} - -static void ggml_compute_forward_group_norm( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_group_norm_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_mul_mat - -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) -// helper function to determine if it is better to use BLAS or not -// for large matrices, BLAS is faster -static bool ggml_compute_forward_mul_mat_use_blas( - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - //const int64_t ne00 = src0->ne[0]; - //const int64_t ne01 = src0->ne[1]; - - const int64_t ne10 = src1->ne[0]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - - // TODO: find the optimal values for these - if (ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { - - /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ - return true; - } - - return false; -} -#endif - -static void ggml_compute_forward_mul_mat( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const enum ggml_type type = src0->type; - - const bool src1_cont = ggml_is_contiguous(src1); - - ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; - enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; - - GGML_ASSERT(ne0 == ne01); - GGML_ASSERT(ne1 == ne11); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); - - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == ggml_type_size(type)); - GGML_ASSERT(nb10 == sizeof(float)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - // broadcast factors - const int64_t r2 = ne12/ne02; - const int64_t r3 = ne13/ne03; - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - -#if defined(GGML_USE_CLBLAST) - if (ggml_cl_can_mul_mat(src0, src1, dst)) { - if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { - ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); - } - return; - } -#endif - -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { - if (params->ith != 0) { - return; - } - - if (params->type == GGML_TASK_INIT) { - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - // broadcast src0 into src1 across 2nd,3rd dimension - const int64_t i03 = i13/r3; - const int64_t i02 = i12/r2; - - const void * x = (char *) src0->data + i02*nb02 + i03*nb03; - const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); - - float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); - - if (type != GGML_TYPE_F32) { - float * const wdata = params->wdata; - ggml_to_float_t const to_float = type_traits[type].to_float; - - size_t id = 0; - for (int64_t i01 = 0; i01 < ne01; ++i01) { - to_float((const char *) x + i01*nb01, wdata + id, ne00); - id += ne00; - } - - assert(id*sizeof(float) <= params->wsize); - x = wdata; - } - - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne11, ne01, ne10, - 1.0f, y, ne10, - x, ne00, - 0.0f, d, ne01); - } - } - - //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); - - return; - } -#endif - - if (params->type == GGML_TASK_INIT) { - if (src1->type != vec_dot_type) { - char * wdata = params->wdata; - const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type); - - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += row_size; - } - } - } - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type); - - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = ne11*ne12*ne13; // src1 rows - - //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); - - // distribute the thread work across the inner or outer loop based on which one is larger - - const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - - const int64_t ith0 = ith % nth0; - const int64_t ith1 = ith / nth0; - - const int64_t dr0 = (nr0 + nth0 - 1)/nth0; - const int64_t dr1 = (nr1 + nth1 - 1)/nth1; - - const int64_t ir010 = dr0*ith0; - const int64_t ir011 = MIN(ir010 + dr0, nr0); - - const int64_t ir110 = dr1*ith1; - const int64_t ir111 = MIN(ir110 + dr1, nr1); - - //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111); - - // threads with no work simply yield (not sure if it helps) - if (ir010 >= ir011 || ir110 >= ir111) { - sched_yield(); - return; - } - - assert(ne12 % ne02 == 0); - assert(ne13 % ne03 == 0); - - // block-tiling attempt - const int64_t blck_0 = 16; - const int64_t blck_1 = 16; - - // attempt to reduce false-sharing (does not seem to make a difference) - float tmp[16]; - - for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { - for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { - for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { - const int64_t i13 = (ir1/(ne12*ne11)); - const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11; - const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11); - - // broadcast src0 into src1 - const int64_t i03 = i13/r3; - const int64_t i02 = i12/r2; - - const int64_t i1 = i11; - const int64_t i2 = i12; - const int64_t i3 = i13; - - const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03); - - // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides - // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using - // the original src1 data pointer, so we should index using the indices directly - // TODO: this is a bit of a hack, we should probably have a better way to handle this - const char * src1_col = (const char *) wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size - : (i11*nb11 + i12*nb12 + i13*nb13)); - - float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); - - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); - //} - - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col); - } - memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); - } - } - } -} - -// ggml_compute_forward_out_prod - -static void ggml_compute_forward_out_prod_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - // int64_t t0 = ggml_perf_time_us(); - // UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne03 == ne13); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); - - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == sizeof(float)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - // GGML_ASSERT(nb0 <= nb1); - // GGML_ASSERT(nb1 <= nb2); - // GGML_ASSERT(nb2 <= nb3); - - GGML_ASSERT(ne0 == ne00); - GGML_ASSERT(ne1 == ne10); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne3 == ne03); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - - // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod - // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) - - if (params->type == GGML_TASK_INIT) { - ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // dst[:,:,:,:] = 0 - // for i2,i3: - // for i1: - // for i01: - // for i0: - // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] - - // parallelize by last three dimensions - - // total rows in dst - const int64_t nr = ne1*ne2*ne3; - - // rows per thread - const int64_t dr = (nr + nth - 1)/nth; - - // row range for this thread - const int64_t ir0 = dr*ith; - const int64_t ir1 = MIN(ir0 + dr, nr); - - // block-tiling attempt - const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32); - const int64_t blck_1 = 16; - - for (int64_t bir = ir0; bir < ir1; bir += blck_1) { - const int64_t bir1 = MIN(bir + blck_1, ir1); - for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { - const int64_t bne01 = MIN(bi01 + blck_0, ne01); - for (int64_t ir = bir; ir < bir1; ++ir) { - // dst indices - const int64_t i3 = ir/(ne2*ne1); - const int64_t i2 = (ir - i3*ne2*ne1)/ne1; - const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); - - const int64_t i02 = i2; - const int64_t i03 = i3; - - //const int64_t i10 = i1; - const int64_t i12 = i2; - const int64_t i13 = i3; - -#if GGML_VEC_MAD_UNROLL > 2 - const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL); - for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) { - const int64_t i11 = i01; - - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); - } - for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) { - const int64_t i11 = i01; - - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - ggml_vec_mad_f32(ne0, d, s0, *s1); - } -#else - for (int64_t i01 = bi01; i01 < bne01; ++i01) { - const int64_t i11 = i01; - - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - ggml_vec_mad_f32(ne0, d, s0, *s1); - } -#endif - } - } - } - - - //int64_t t1 = ggml_perf_time_us(); - //static int64_t acc = 0; - //acc += t1 - t0; - //if (t1 - t0 > 10) { - // printf("\n"); - // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); - // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); - // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); - - // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); - //} -} - -static void ggml_compute_forward_out_prod_q_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - // int64_t t0 = ggml_perf_time_us(); - // UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS; - - const int ith = params->ith; - const int nth = params->nth; - - const enum ggml_type type = src0->type; - ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne03 == ne13); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); - - // we don't support permuted src0 dim0 - GGML_ASSERT(nb00 == ggml_type_size(type)); - - // dst dim0 cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - // GGML_ASSERT(nb0 <= nb1); - // GGML_ASSERT(nb1 <= nb2); - // GGML_ASSERT(nb2 <= nb3); - - GGML_ASSERT(ne0 == ne00); - GGML_ASSERT(ne1 == ne10); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne3 == ne03); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - - // TODO: #if defined(GGML_USE_CUBLAS) ggml_cuda_out_prod - // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) - - if (params->type == GGML_TASK_INIT) { - ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // parallelize by last three dimensions - - // total rows in dst - const int64_t nr = ne1*ne2*ne3; - - // rows per thread - const int64_t dr = (nr + nth - 1)/nth; - - // row range for this thread - const int64_t ir0 = dr*ith; - const int64_t ir1 = MIN(ir0 + dr, nr); - - // dst[:,:,:,:] = 0 - // for i2,i3: - // for i1: - // for i01: - // for i0: - // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] - - float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; - - for (int64_t ir = ir0; ir < ir1; ++ir) { - // dst indices - const int64_t i3 = ir/(ne2*ne1); - const int64_t i2 = (ir - i3*ne2*ne1)/ne1; - const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); - - const int64_t i02 = i2; - const int64_t i03 = i3; - - //const int64_t i10 = i1; - const int64_t i12 = i2; - const int64_t i13 = i3; - - for (int64_t i01 = 0; i01 < ne01; ++i01) { - const int64_t i11 = i01; - - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - dequantize_row_q(s0, wdata, ne0); - ggml_vec_mad_f32(ne0, d, wdata, *s1); - } - } - - //int64_t t1 = ggml_perf_time_us(); - //static int64_t acc = 0; - //acc += t1 - t0; - //if (t1 - t0 > 10) { - // printf("\n"); - // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); - // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); - // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); - - // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); - //} -} - -static void ggml_compute_forward_out_prod( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - { - ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(false); // todo - // ggml_compute_forward_out_prod_f16_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_out_prod_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_scale - -static void ggml_compute_forward_scale_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // scale factor - const float v = *(float *) src1->data; - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const size_t nb01 = src0->nb[1]; - - const size_t nb1 = dst->nb[1]; - - - for (int i1 = ir0; i1 < ir1; i1++) { - if (dst->data != src0->data) { - // src0 is same shape as dst => same indices - memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); - } - ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); - } -} - -static void ggml_compute_forward_scale( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_scale_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_set - -static void ggml_compute_forward_set_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - - // view src0 and dst with these strides and data offset inbytes during set - // nb0 is implicitely element_size because src0 and dst are contiguous - size_t nb1 = ((int32_t *) dst->op_params)[0]; - size_t nb2 = ((int32_t *) dst->op_params)[1]; - size_t nb3 = ((int32_t *) dst->op_params)[2]; - size_t offset = ((int32_t *) dst->op_params)[3]; - bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace && (params->type == GGML_TASK_INIT)) { - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - memcpy( - ((char *) dst->data), - ((char *) src0->data), - ggml_nbytes(dst)); - } - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src1); - const int nc = src1->ne[0]; - - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) - GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) - - // src0 and dst as viewed during set - const size_t nb0 = ggml_element_size(src0); - - const int im0 = (ne10 == 0 ? 0 : ne10-1); - const int im1 = (ne11 == 0 ? 0 : ne11-1); - const int im2 = (ne12 == 0 ? 0 : ne12-1); - const int im3 = (ne13 == 0 ? 0 : ne13-1); - - GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst)); - - GGML_ASSERT(nb10 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are viewed with shape of src1 and offset - // => same indices - const int i3 = ir/(ne12*ne11); - const int i2 = (ir - i3*ne12*ne11)/ne11; - const int i1 = (ir - i3*ne12*ne11 - i2*ne11); - - ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); - } -} - -static void ggml_compute_forward_set( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_set_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_cpy - -static void ggml_compute_forward_cpy( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - ggml_compute_forward_dup(params, src0, dst); -} - -// ggml_compute_forward_cont - -static void ggml_compute_forward_cont( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - ggml_compute_forward_dup(params, src0, dst); -} - -// ggml_compute_forward_reshape - -static void ggml_compute_forward_reshape( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - // NOP - UNUSED(params); - UNUSED(src0); - UNUSED(dst); -} - -// ggml_compute_forward_view - -static void ggml_compute_forward_view( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0) { - // NOP - UNUSED(params); - UNUSED(src0); -} - -// ggml_compute_forward_permute - -static void ggml_compute_forward_permute( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0) { - // NOP - UNUSED(params); - UNUSED(src0); -} - -// ggml_compute_forward_transpose - -static void ggml_compute_forward_transpose( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0) { - // NOP - UNUSED(params); - UNUSED(src0); -} - -// ggml_compute_forward_get_rows - -static void ggml_compute_forward_get_rows_q( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int nc = src0->ne[0]; - const int nr = ggml_nelements(src1); - const enum ggml_type type = src0->type; - ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == ggml_type_size(type)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - dequantize_row_q( - (const void *) ((char *) src0->data + r*src0->nb[1]), - (float *) ((char *) dst->data + i*dst->nb[1]), nc); - } -} - -static void ggml_compute_forward_get_rows_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int nc = src0->ne[0]; - const int nr = ggml_nelements(src1); - - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == sizeof(ggml_fp16_t)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - for (int j = 0; j < nc; ++j) { - ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j]; - ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v); - } - } -} - -static void ggml_compute_forward_get_rows_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int nc = src0->ne[0]; - const int nr = ggml_nelements(src1); - - assert( dst->ne[0] == nc); - assert( dst->ne[1] == nr); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i*dst->nb[1]), - (float *) ((char *) src0->data + r*src0->nb[1])); - } -} - -static void ggml_compute_forward_get_rows( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - { - ggml_compute_forward_get_rows_q(params, src0, src1, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_get_rows_f16(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_get_rows_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } - - //static bool first = true; - //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - //if (first) { - // first = false; - //} else { - // for (int k = 0; k < dst->ne[1]; ++k) { - // for (int j = 0; j < dst->ne[0]/16; ++j) { - // for (int i = 0; i < 16; ++i) { - // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // exit(0); - //} -} - -// ggml_compute_forward_get_rows_back - -static void ggml_compute_forward_get_rows_back_f32_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_is_contiguous(dst)); - - // ggml_compute_forward_dup_same_cont(params, opt0, dst); - - if (params->type == GGML_TASK_INIT) { - memset(dst->data, 0, ggml_nbytes(dst)); - } - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int nc = src0->ne[0]; - const int nr = ggml_nelements(src1); - - GGML_ASSERT( dst->ne[0] == nc); - GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - for (int j = 0; j < nc; ++j) { - ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j]; - ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v); - } - } -} - -static void ggml_compute_forward_get_rows_back_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_is_contiguous(dst)); - - // ggml_compute_forward_dup_same_cont(params, opt0, dst); - - if (params->type == GGML_TASK_INIT) { - memset(dst->data, 0, ggml_nbytes(dst)); - } - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int nc = src0->ne[0]; - const int nr = ggml_nelements(src1); - - GGML_ASSERT( dst->ne[0] == nc); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - ggml_vec_add_f32(nc, - (float *) ((char *) dst->data + r*dst->nb[1]), - (float *) ((char *) dst->data + r*dst->nb[1]), - (float *) ((char *) src0->data + i*src0->nb[1])); - } -} - - -static void ggml_compute_forward_get_rows_back( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } - - //static bool first = true; - //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - //if (first) { - // first = false; - //} else { - // for (int k = 0; k < dst->ne[1]; ++k) { - // for (int j = 0; j < dst->ne[0]/16; ++j) { - // for (int i = 0; i < 16; ++i) { - // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // exit(0); - //} -} - -// ggml_compute_forward_diag - -static void ggml_compute_forward_diag_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // TODO: handle transposed/permuted matrices - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT(ne00 == ne0); - GGML_ASSERT(ne00 == ne1); - GGML_ASSERT(ne01 == 1); - GGML_ASSERT(ne02 == ne2); - GGML_ASSERT(ne03 == ne3); - - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb0 == sizeof(float)); - - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = 0; i2 < ne2; i2++) { - for (int i1 = 0; i1 < ne1; i1++) { - float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02); - for (int i0 = 0; i0 < i1; i0++) { - d[i0] = 0; - } - d[i1] = s[i1]; - for (int i0 = i1+1; i0 < ne0; i0++) { - d[i0] = 0; - } - } - } - } -} - -static void ggml_compute_forward_diag( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_diag_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_diag_mask_inf - -static void ggml_compute_forward_diag_mask_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst, - const float value) { - - const int ith = params->ith; - const int nth = params->nth; - - const int n_past = ((int32_t *) dst->op_params)[0]; - const bool inplace = src0->data == dst->data; - - GGML_ASSERT(n_past >= 0); - - if (!inplace && (params->type == GGML_TASK_INIT)) { - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - memcpy( - ((char *) dst->data), - ((char *) src0->data), - ggml_nbytes(dst)); - } - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // TODO: handle transposed/permuted matrices - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - const int nr = src0->ne[1]; - const int nz = n/nr; - - GGML_ASSERT( dst->nb[0] == sizeof(float)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int k = 0; k < nz; k++) { - for (int j = ith; j < nr; j += nth) { - for (int i = n_past; i < nc; i++) { - if (i > n_past + j) { - *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value; - } - } - } - } -} - -static void ggml_compute_forward_diag_mask_inf( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_diag_mask_f32(params, src0, dst, -INFINITY); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -static void ggml_compute_forward_diag_mask_zero( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_diag_mask_f32(params, src0, dst, 0); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_soft_max - -static void ggml_compute_forward_soft_max_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // TODO: handle transposed/permuted matrices - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float *sp = (float *)((char *) src0->data + i1*src0->nb[1]); - float *dp = (float *)((char *) dst->data + i1*dst->nb[1]); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(sp[i])); - } -#endif - - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, sp); - - ggml_float sum = 0.0; - - uint16_t scvt; - for (int i = 0; i < nc; i++) { - if (sp[i] == -INFINITY) { - dp[i] = 0.0f; - } else { - // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max); - ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max); - memcpy(&scvt, &s, sizeof(scvt)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); - sum += (ggml_float)val; - dp[i] = val; - } - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(nc, dp, sum); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(dp[i])); - assert(!isinf(dp[i])); - } -#endif - } -} - -static void ggml_compute_forward_soft_max( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_soft_max_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_soft_max_back - -static void ggml_compute_forward_soft_max_back_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_are_same_shape(src1, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // TODO: handle transposed/permuted matrices - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float *dy = (float *)((char *) src0->data + i1*src0->nb[1]); - float *y = (float *)((char *) src1->data + i1*src1->nb[1]); - float *dx = (float *)((char *) dst->data + i1*dst->nb[1]); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(dy[i])); - assert(!isnan(y[i])); - } -#endif - // Jii = yi - yi*yi - // Jij = -yi*yj - // J = diag(y)-y.T*y - // dx = J * dy - // dxk = sum_i(Jki * dyi) - // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk - // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk - // dxk = sum_i(-yk*yi * dyi) + yk*dyk - // dxk = -yk * sum_i(yi * dyi) + yk*dyk - // dxk = -yk * dot(y, dy) + yk*dyk - // dxk = yk * (- dot(y, dy) + dyk) - // dxk = yk * (dyk - dot(y, dy)) - // - // post-order: - // dot_y_dy := dot(y, dy) - // dx := dy - // dx := dx - dot_y_dy - // dx := dx * y - - // linear runtime, no additional memory - float dot_y_dy = 0; - ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy); - ggml_vec_cpy_f32 (nc, dx, dy); - ggml_vec_acc1_f32(nc, dx, -dot_y_dy); - ggml_vec_mul_f32 (nc, dx, dx, y); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(dx[i])); - assert(!isinf(dx[i])); - } -#endif - } -} - -static void ggml_compute_forward_soft_max_back( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_soft_max_back_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_alibi - -static void ggml_compute_forward_alibi_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 - const int64_t ne1 = src0->ne[1]; // seq_len_without_past - const int64_t ne2 = src0->ne[2]; // n_head -> this is k - //const int64_t ne3 = src0->ne[3]; // 1 -> bsz - - const int64_t n = ggml_nrows(src0); - const int64_t ne2_ne3 = n/ne1; // ne2*ne3 - - const size_t nb0 = src0->nb[0]; - const size_t nb1 = src0->nb[1]; - const size_t nb2 = src0->nb[2]; - //const int nb3 = src0->nb[3]; - - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(n_head == ne2); - - // add alibi to src0 (KQ_scaled) - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - for (int64_t i = 0; i < ne0; i++) { - for (int64_t j = 0; j < ne1; j++) { - for (int64_t k = 0; k < ne2_ne3; k++) { - float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); - float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); - - // TODO: k*nb2 or k*nb3 - - float m_k; - - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - pdst[0] = i * m_k + src[0]; - } - } - } -} - -static void ggml_compute_forward_alibi_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 - const int ne1 = src0->ne[1]; // seq_len_without_past - const int ne2 = src0->ne[2]; // n_head -> this is k - //const int ne3 = src0->ne[3]; // 1 -> bsz - - const int n = ggml_nrows(src0); - const int ne2_ne3 = n/ne1; // ne2*ne3 - - const int nb0 = src0->nb[0]; - const int nb1 = src0->nb[1]; - const int nb2 = src0->nb[2]; - //const int nb3 = src0->nb[3]; - - GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); - //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; - GGML_ASSERT(n_head == ne2); - - // add alibi to src0 (KQ_scaled) - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - for (int i = 0; i < ne0; i++) { - for (int j = 0; j < ne1; j++) { - for (int k = 0; k < ne2_ne3; k++) { - ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); - float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); - - // TODO: k*nb2 or k*nb3 - - float m_k; - - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - // we return F32 - pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]); - } - } - } -} - -static void ggml_compute_forward_alibi( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_alibi_f16(params, src0, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_alibi_f32(params, src0, dst); - } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_Q8_K: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_clamp - -static void ggml_compute_forward_clamp_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - float min; - float max; - memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - for (int j = ith; j < n; j += nth) { - float * dst_ptr = (float *) ((char *) dst->data + j*nb1); - float * src0_ptr = (float *) ((char *) src0->data + j*nb01); - - for (int i = 0; i < nc; i++) { - dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min); - } - } -} - -static void ggml_compute_forward_clamp( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_clamp_f32(params, src0, dst); - } break; - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_Q8_K: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_rope - -static void ggml_compute_forward_rope_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - float freq_base; - float freq_scale; - - // these two only relevant for xPos RoPE: - float xpos_base; - bool xpos_down; - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); - - GGML_TENSOR_UNARY_OP_LOCALS - - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - - GGML_ASSERT(nb00 == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(dst); - - GGML_ASSERT(n_dims <= ne0); - GGML_ASSERT(n_dims % 2 == 0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - // row index used to determine which thread to use - int ir = 0; - - const float theta_scale = powf(freq_base, -2.0f/n_dims); - - const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - - const int32_t * pos = (const int32_t *) src1->data; - - for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = 0; i2 < ne2; i2++) { - const int64_t p = pos[i2]; - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (ir++ < ir0) continue; - if (ir > ir1) break; - - float theta = freq_scale * (float)p; - - if (is_glm) { - theta = MIN(p, n_ctx - 2); - float block_theta = MAX(p - (n_ctx - 2), 0); - for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); - const float cos_block_theta = cosf(block_theta); - const float sin_block_theta = sinf(block_theta); - - theta *= theta_scale; - block_theta *= theta_scale; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - const float x2 = src[n_dims]; - const float x3 = src[n_dims/2*3]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta; - dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta; - } - } else if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); - // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; - if (xpos_down) zeta = 1.0f / zeta; - - theta *= theta_scale; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[1]; - - dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta; - dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; - } - } else { - // TODO: this might be wrong for ne0 != n_dims - need double check - // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); - - theta *= theta_scale; - - const int64_t i0 = ib*n_dims + ic/2; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } - } - } - } - } - } -} - -static void ggml_compute_forward_rope_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - float freq_base; - float freq_scale; - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - - GGML_TENSOR_UNARY_OP_LOCALS - - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - - GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(dst); - - GGML_ASSERT(n_dims <= ne0); - GGML_ASSERT(n_dims % 2 == 0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - // row index used to determine which thread to use - int ir = 0; - - const float theta_scale = powf(freq_base, -2.0f/n_dims); - - const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - - const int32_t * pos = (const int32_t *) src1->data; - - for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = 0; i2 < ne2; i2++) { - const int64_t p = pos[i2]; - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (ir++ < ir0) continue; - if (ir > ir1) break; - - float theta = freq_scale * (float)p; - - if (is_glm) { - theta = MIN(p, n_ctx - 2); - float block_theta = MAX(p - (n_ctx - 2), 0); - for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); - const float cos_block_theta = cosf(block_theta); - const float sin_block_theta = sinf(block_theta); - - theta *= theta_scale; - block_theta *= theta_scale; - - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - const float x2 = GGML_FP16_TO_FP32(src[n_dims]); - const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]); - - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); - dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); - } - } if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); - - theta *= theta_scale; - - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[1]); - - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - } - } else { - // TODO: this might be wrong for ne0 != n_dims - need double check - // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); - - theta *= theta_scale; - - const int64_t i0 = ib*n_dims + ic/2; - - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - } - } - } - } - } - } -} - -static void ggml_compute_forward_rope( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_rope_f16(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_rope_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_rope_back - -static void ggml_compute_forward_rope_back_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // y = rope(x, src1) - // dx = rope_back(dy, src1) - // src0 is dy, src1 contains options - - float freq_base; - float freq_scale; - - // these two only relevant for xPos RoPE: - float xpos_base; - bool xpos_down; - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx); - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); - - GGML_TENSOR_UNARY_OP_LOCALS - - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - - assert(nb0 == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(dst); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - // row index used to determine which thread to use - int ir = 0; - - const float theta_scale = powf(freq_base, -2.0f/n_dims); - - const bool is_neox = mode & 2; - - const int32_t * pos = (const int32_t *) src1->data; - - for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = 0; i2 < ne2; i2++) { - const int64_t p = pos[i2]; - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (ir++ < ir0) continue; - if (ir > ir1) break; - - float theta = freq_scale * (float)p; - - if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); - // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; - if (xpos_down) zeta = 1.0f / zeta; - - theta *= theta_scale; - - const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float dy0 = dy[0]; - const float dy1 = dy[1]; - - dx[0] = dy0*cos_theta*zeta + dy1*sin_theta*zeta; - dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta; - } - } else { - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); - - theta *= theta_scale; - - const int64_t i0 = ib*n_dims + ic/2; - - const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float dy0 = dy[0]; - const float dy1 = dy[n_dims/2]; - - dx[0] = dy0*cos_theta + dy1*sin_theta; - dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta; - } - } - } - } - } - } -} - -static void ggml_compute_forward_rope_back_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // y = rope(x, src1) - // dx = rope_back(dy, src1) - // src0 is dy, src1 contains options - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - - GGML_TENSOR_UNARY_OP_LOCALS - - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - - assert(nb0 == sizeof(ggml_fp16_t)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(dst); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - // row index used to determine which thread to use - int ir = 0; - - const float theta_scale = powf(10000.0, -2.0f/n_dims); - - const bool is_neox = mode & 2; - - const int32_t * pos = (const int32_t *) src1->data; - - for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = 0; i2 < ne2; i2++) { - const int64_t p = pos[i2]; - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (ir++ < ir0) continue; - if (ir > ir1) break; - - float theta = (float)p; - - if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); - - theta *= theta_scale; - - const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float dy0 = GGML_FP16_TO_FP32(dy[0]); - const float dy1 = GGML_FP16_TO_FP32(dy[1]); - - dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta); - dx[1] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta); - } - } else { - for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); - - theta *= theta_scale; - - const int64_t i0 = ib*n_dims + ic/2; - - const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float dy0 = GGML_FP16_TO_FP32(dy[0]); - const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]); - - dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta); - dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta); - } - } - } - } - } - } -} - -static void ggml_compute_forward_rope_back( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_rope_back_f16(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_rope_back_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_conv_1d - -static void ggml_compute_forward_conv_1d_f16_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00; - - // size of the convolution row - the kernel size unrolled across all input channels - const int ew0 = nk*ne01; - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - memset(params->wdata, 0, params->wsize); - - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - ggml_fp16_t * dst_data = wdata; - - for (int64_t i0 = 0; i0 < ne0; i0++) { - for (int64_t ik = 0; ik < nk; ik++) { - const int idx0 = i0*s0 + ik*d0 - p0; - - if(!(idx0 < 0 || idx0 >= ne10)) { - dst_data[i0*ew0 + i11*nk + ik] = GGML_FP32_TO_FP16(src[idx0]); - } - } - } - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // total rows in dst - const int nr = ne2; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int i2 = 0; i2 < ne2; i2++) { - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1); - - for (int i0 = 0; i0 < ne0; i0++) { - ggml_vec_dot_f16(ew0, dst_data + i0, - (ggml_fp16_t *) ((char *) src0->data + i1*nb02), - (ggml_fp16_t *) wdata + i2*nb2 + i0*ew0); - } - } - } -} - -static void ggml_compute_forward_conv_1d_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00; - - const int ew0 = nk*ne01; - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; - - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - memset(params->wdata, 0, params->wsize); - - float * const wdata = (float *) params->wdata + 0; - - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; - - for (int64_t i0 = 0; i0 < ne0; i0++) { - for (int64_t ik = 0; ik < nk; ik++) { - const int idx0 = i0*s0 + ik*d0 - p0; - - if(!(idx0 < 0 || idx0 >= ne10)) { - dst_data[i0*ew0 + i11*nk + ik] = src[idx0]; - } - } - } - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // total rows in dst - const int nr = ne02; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * const wdata = (float *) params->wdata + 0; - - for (int i2 = 0; i2 < ne2; i2++) { - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1); - - for (int i0 = 0; i0 < ne0; i0++) { - ggml_vec_dot_f32(ew0, dst_data + i0, - (float *) ((char *) src0->data + i1*nb02), - (float *) wdata + i2*nb2 + i0*ew0); - } - } - } -} - -static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k, - ggml_fp16_t * A, - ggml_fp16_t * B, - float * C, - const int ith, const int nth) { - // does not seem to make a difference - int64_t m0, m1, n0, n1; - // patches per thread - if (m > n) { - n0 = 0; - n1 = n; - - // total patches in dst - const int np = m; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - m0 = dp*ith; - m1 = MIN(m0 + dp, np); - } else { - m0 = 0; - m1 = m; - - // total patches in dst - const int np = n; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - n0 = dp*ith; - n1 = MIN(n0 + dp, np); - } - - // block-tiling attempt - int64_t blck_n = 16; - int64_t blck_m = 16; - - // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB - // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(ggml_fp16_t) * K); - // if (blck_size > 0) { - // blck_0 = 4; - // blck_1 = blck_size / blck_0; - // if (blck_1 < 0) { - // blck_1 = 1; - // } - // // blck_0 = (int64_t)sqrt(blck_size); - // // blck_1 = blck_0; - // } - // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1); - - for (int j = n0; j < n1; j+=blck_n) { - for (int i = m0; i < m1; i+=blck_m) { - // printf("i j k => %d %d %d\n", i, j, K); - for (int ii = i; ii < i + blck_m && ii < m1; ii++) { - for (int jj = j; jj < j + blck_n && jj < n1; jj++) { - ggml_vec_dot_f16(k, - C + ii*n + jj, - A + ii * k, - B + jj * k); - } - } - } - } -} - -// src0: kernel [OC, IC, K] -// src1: signal [N, IC, IL] -// dst: result [N, OL, IC*K] -static void ggml_compute_forward_conv_1d_stage_0_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS; - - const int64_t N = ne12; - const int64_t IC = ne11; - const int64_t IL = ne10; - - const int64_t K = ne00; - - const int64_t OL = ne1; - - const int ith = params->ith; - const int nth = params->nth; - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - memset(dst->data, 0, ggml_nbytes(dst)); - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // im2col: [N, IC, IL] => [N, OL, IC*K] - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; - - for (int64_t in = 0; in < N; in++) { - for (int64_t iol = 0; iol < OL; iol++) { - for (int64_t iic = ith; iic < IC; iic+=nth) { - - // micro kernel - ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K] - const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL] - - for (int64_t ik = 0; ik < K; ik++) { - const int64_t iil = iol*s0 + ik*d0 - p0; - - if (!(iil < 0 || iil >= IL)) { - dst_data[iic*K + ik] = GGML_FP32_TO_FP16(src_data[iil]); - } - } - } - } - } - } -} - -// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] -// src0: [OC, IC, K] -// src1: [N, OL, IC * K] -// result: [N, OC, OL] -static void ggml_compute_forward_conv_1d_stage_1_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F16); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - if (params->type == GGML_TASK_INIT) { - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_TENSOR_BINARY_OP_LOCALS; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb0 == sizeof(float)); - - const int N = ne12; - const int OL = ne11; - - const int OC = ne02; - const int IC = ne01; - const int K = ne00; - - const int ith = params->ith; - const int nth = params->nth; - - int64_t m = OC; - int64_t n = OL; - int64_t k = IC * K; - - // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] - for (int i = 0; i < N; i++) { - ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k] - ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k] - float * C = (float *)dst->data + i * m * n; // [m, n] - - gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); - } -} - -static void ggml_compute_forward_conv_1d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch(src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_conv_1d_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -static void ggml_compute_forward_conv_1d_stage_0( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch(src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -static void ggml_compute_forward_conv_1d_stage_1( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch(src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_conv_transpose_1d - -static void ggml_compute_forward_conv_transpose_1d_f16_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - memset(params->wdata, 0, params->wsize); - - // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ne02 + i02] = src[i00]; - } - } - } - } - - // permute source data (src1) from (L x Cin) to (Cin x L) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; - ggml_fp16_t * dst_data = wdata; - - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]); - } - } - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - - // total rows in dst - const int nr = ne1; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - ggml_fp16_t * const wdata_src = wdata + nk; - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i10*ne11; - for (int i00 = 0; i00 < ne00; i00++) { - float v = 0; - ggml_vec_dot_f16(ne02, &v, - (ggml_fp16_t *) wdata_src + i1n, - (ggml_fp16_t *) wdata_kernel + i00*ne02); - dst_data[i10*s0 + i00] += v; - } - } - } -} - -static void ggml_compute_forward_conv_transpose_1d_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02; - - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) - { - float * const wdata = (float *) params->wdata + 0; - - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i01*ne00*ne02; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i01*ne00*ne02 + i00*ne02 + i02] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - float * const wdata = (float *) params->wdata + nk; - float * dst_data = wdata; - - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne11 + i11] = src[i10]; - } - } - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - - // total rows in dst - const int nr = ne1; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * const wdata = (float *) params->wdata + 0; - float * const wdata_src = wdata + nk; - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - float * wdata_kernel = wdata + i1*ne02*ne00; - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i10*ne11; - for (int i00 = 0; i00 < ne00; i00++) { - float v = 0; - ggml_vec_dot_f32(ne02, &v, - wdata_src + i1n, - wdata_kernel + i00*ne02); - dst_data[i10*s0 + i00] += v; - } - } - } -} - -static void ggml_compute_forward_conv_transpose_1d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_conv_2d - -static void ggml_compute_forward_conv_2d_f16_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS; - - const int ith = params->ith; - const int nth = params->nth; - - const int nk0 = ne00; - const int nk1 = ne01; - - // size of the convolution row - the kernel size unrolled across all channels - const int ew0 = nk0*nk1*ne02; - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - memset(params->wdata, 0, params->wsize); - - // prepare source data (src1) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int i13 = 0; i13 < ne13; i13++) { - for (int i12 = 0; i12 < ne12; i12++) { - const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12); - ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0); - - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - for (int ik1 = 0; ik1 < nk1; ik1++) { - for (int ik0 = 0; ik0 < nk0; ik0++) { - const int idx0 = i0*s0 + ik0*d0 - p0; - const int idx1 = i1*s1 + ik1*d1 - p1; - - if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) { - dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] = - GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]); - } - } - } - } - } - } - } - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // total patches in dst - const int np = ne2; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - const int ip0 = dp*ith; - const int ip1 = MIN(ip0 + dp, np); - - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = ip0; i2 < ip1; i2++) { - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2); - - for (int i1 = 0; i1 < ne1; ++i1) { - for (int i0 = 0; i0 < ne0; ++i0) { - ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0, - (ggml_fp16_t *) ((char *) src0->data + i2*nb03), - (ggml_fp16_t *) wdata + i3*nb3 + (i1*ne0 + i0)*ew0); - } - } - } - } -} - -static void ggml_compute_forward_conv_2d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst); - } break; - case GGML_TYPE_F32: - { - //ggml_compute_forward_conv_2d_f32(params, src0, src1, dst); - GGML_ASSERT(false); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_conv_transpose_2d - -static void ggml_compute_forward_conv_transpose_2d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02*ne03; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (params->type == GGML_TASK_INIT) { - memset(params->wdata, 0, params->wsize); - - // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02); - ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03; - for (int64_t i01 = 0; i01 < ne01; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00]; - } - } - } - } - } - - // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; - for (int i12 = 0; i12 < ne12; i12++) { - for (int i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11); - ggml_fp16_t * dst_data = wdata + i11*ne10*ne12; - for (int i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]); - } - } - } - } - - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - const int32_t stride = ggml_get_op_params_i32(dst, 0); - - // total patches in dst - const int np = ne2; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - const int ip0 = dp*ith; - const int ip1 = MIN(ip0 + dp, np); - - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - ggml_fp16_t * const wdata_src = wdata + nk; - - for (int i2 = ip0; i2 < ip1; i2++) { // Cout - float * dst_data = (float *)((char *) dst->data + i2*nb2); - ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; - for (int i11 = 0; i11 < ne11; i11++) { - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i11*ne10*ne12 + i10*ne12; - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - float v = 0; - ggml_vec_dot_f16(ne03, &v, - wdata_src + i1n, - wdata_kernel + i01*ne00*ne03 + i00*ne03); - dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; - } - } - } - } - } -} - -// ggml_compute_forward_pool_1d_sk_p0 - -static void ggml_compute_forward_pool_1d_sk_p0( - const struct ggml_compute_params * params, - const enum ggml_op_pool op, - const struct ggml_tensor * src, - const int k, - struct ggml_tensor * dst) { - assert(src->type == GGML_TYPE_F32); - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const char * cdata = (const char *)src->data; - const char * const data_end = cdata + ggml_nbytes(src); - float * drow = (float *)dst->data; - - const int64_t rs = dst->ne[0]; - - while (cdata < data_end) { - const float * const srow = (const float *)cdata; - - int j = 0; - - for (int64_t i = 0; i < rs; ++i) { - switch (op) { - case GGML_OP_POOL_AVG: drow[i] = 0; break; - case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break; - case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; - } - for (int ki = 0; ki < k; ++ki) { - switch (op) { - case GGML_OP_POOL_AVG: drow[i] += srow[j]; break; - case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break; - case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; - } - ++j; - } - switch (op) { - case GGML_OP_POOL_AVG: drow[i] /= k; break; - case GGML_OP_POOL_MAX: break; - case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; - } - } - - cdata += src->nb[1]; - drow += rs; - } -} - -// ggml_compute_forward_pool_1d - -static void ggml_compute_forward_pool_1d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - - const int32_t * opts = (const int32_t *)dst->op_params; - enum ggml_op_pool op = opts[0]; - const int k0 = opts[1]; - const int s0 = opts[2]; - const int p0 = opts[3]; - GGML_ASSERT(p0 == 0); // padding not supported - GGML_ASSERT(k0 == s0); // only s = k supported - - ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst); -} - -// ggml_compute_forward_pool_2d_sk_p0 - -static void ggml_compute_forward_pool_2d_sk_p0( - const struct ggml_compute_params * params, - const enum ggml_op_pool op, - const struct ggml_tensor * src, - const int k0, - const int k1, - struct ggml_tensor * dst) { - assert(src->type == GGML_TYPE_F32); - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const char * cdata = (const char*)src->data; - const char * const data_end = cdata + ggml_nbytes(src); - - const int64_t px = dst->ne[0]; - const int64_t py = dst->ne[1]; - const int64_t pa = px * py; - - float * dplane = (float *)dst->data; - - const int ka = k0 * k1; - - while (cdata < data_end) { - for (int oy = 0; oy < py; ++oy) { - float * const drow = dplane + oy * px; - for (int ox = 0; ox < px; ++ox) { - float * const out = drow + ox; - switch (op) { - case GGML_OP_POOL_AVG: *out = 0; break; - case GGML_OP_POOL_MAX: *out = -FLT_MAX; break; - case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; - } - - const int ix = ox * k0; - const int iy = oy * k1; - - for (int ky = 0; ky < k1; ++ky) { - const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky)); - for (int kx = 0; kx < k0; ++kx) { - int j = ix + kx; - switch (op) { - case GGML_OP_POOL_AVG: *out += srow[j]; break; - case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break; - case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; - } - } - } - switch (op) { - case GGML_OP_POOL_AVG: *out /= ka; break; - case GGML_OP_POOL_MAX: break; - case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; - } - } - } - - cdata += src->nb[2]; - dplane += pa; - } -} - -// ggml_compute_forward_pool_2d - -static void ggml_compute_forward_pool_2d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - - const int32_t * opts = (const int32_t *)dst->op_params; - enum ggml_op_pool op = opts[0]; - const int k0 = opts[1]; - const int k1 = opts[2]; - const int s0 = opts[3]; - const int s1 = opts[4]; - const int p0 = opts[5]; - const int p1 = opts[6]; - GGML_ASSERT(p0 == 0); - GGML_ASSERT(p1 == 0); // padding not supported - GGML_ASSERT(k0 == s0); - GGML_ASSERT(k1 == s1); // only s = k supported - - ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst); -} - -// ggml_compute_forward_upscale - -static void ggml_compute_forward_upscale_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - - GGML_TENSOR_UNARY_OP_LOCALS - - const int scale_factor = dst->op_params[0]; - - // TODO: optimize - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = ith; i02 < ne02; i02++) { - for (int m = 0; m < dst->ne[1]; m++) { - int i01 = m / scale_factor; - for (int n = 0; n < dst->ne[0]; n++) { - int i00 = n / scale_factor; - - const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03); - - float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]); - - *y = *x; - } - } - } - } -} - -static void ggml_compute_forward_upscale( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_upscale_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_flash_attn - -static void ggml_compute_forward_flash_attn_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const bool masked, - struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_LOCALS(int64_t, neq, q, ne) - GGML_TENSOR_LOCALS(size_t, nbq, q, nb) - GGML_TENSOR_LOCALS(int64_t, nek, k, ne) - GGML_TENSOR_LOCALS(size_t, nbk, k, nb) - GGML_TENSOR_LOCALS(int64_t, nev, v, ne) - GGML_TENSOR_LOCALS(size_t, nbv, v, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = neq0; - const int64_t N = neq1; - const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); - - GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); - GGML_ASSERT(P >= 0); - - GGML_ASSERT(nbq0 == sizeof(float)); - GGML_ASSERT(nbk0 == sizeof(float)); - GGML_ASSERT(nbv0 == sizeof(float)); - - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); - - GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_INIT) { - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int nr = neq1*neq2*neq3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0f/sqrtf(D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - const int64_t masked_begin = masked ? (P + iq1 + 1) : M; - for (int64_t ic = 0; ic < masked_begin; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f32(neq0, - S + i1, - (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - - // scale - ggml_vec_scale_f32(masked_begin, S, scale); - - for (int64_t i = masked_begin; i < M; i++) { - S[i] = -INFINITY; - } - - // softmax - // exclude known -INF S[..] values from max and loop - // dont forget to set their SW values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(masked_begin, &max, S); - - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; - - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - if (i >= masked_begin) { - break; - } - float * SS = S + i; - - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (i + j >= masked_begin) { - break; - } else if (SS[j] == -INFINITY) { - SS[j] = 0.0f; - } else { -#ifndef GGML_FLASH_ATTN_EXP_FP16 - const float val = expf(SS[j] - max); -#else - ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); -#endif - sump[j] += (ggml_float)val; - SS[j] = val; - } - } - } - - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(masked_begin, S, sum); - -#ifndef NDEBUG - for (int i = 0; i < masked_begin; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif - } - - for (int64_t ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; - - ggml_vec_dot_f32(masked_begin, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S); - } - } -} - -static void ggml_compute_forward_flash_attn_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const bool masked, - struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_LOCALS(int64_t, neq, q, ne) - GGML_TENSOR_LOCALS(size_t, nbq, q, nb) - GGML_TENSOR_LOCALS(int64_t, nek, k, ne) - GGML_TENSOR_LOCALS(size_t, nbk, k, nb) - GGML_TENSOR_LOCALS(int64_t, nev, v, ne) - GGML_TENSOR_LOCALS(size_t, nbv, v, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = neq0; - const int64_t N = neq1; - const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); - - GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); - GGML_ASSERT(P >= 0); - - GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); - - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); - - GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_INIT) { - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int nr = neq1*neq2*neq3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0f/sqrtf(D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { - for (int64_t ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f16(neq0, - S + i1, - (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } else { - for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f16_unroll(neq0, nbk1, - S + i1, - ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } - - // scale - ggml_vec_scale_f32(nek1, S, scale); - - if (masked) { - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } - } - } - - // softmax - // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. - // dont forget to set their S values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); - - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; - - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - float * SS = S + i; - - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { - SS[j] = 0.0f; - } else { - ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); - sump[j] += (ggml_float)val; - SS[j] = val; - } - } - } - - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); - -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif - } - - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); - - for (int64_t i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); - } - - // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). - if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { - for (int64_t ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; - - ggml_vec_dot_f16(nev0, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } else { - for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; - - ggml_vec_dot_f16_unroll(nev0, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } - } -} - -static void ggml_compute_forward_flash_attn( - const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const bool masked, - struct ggml_tensor * dst) { - switch (q->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_flash_attn_f16(params, q, k, v, masked, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_flash_ff - -static void ggml_compute_forward_flash_ff_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * a, // F16 - const struct ggml_tensor * b0, // F16 fc_w - const struct ggml_tensor * b1, // F32 fc_b - const struct ggml_tensor * c0, // F16 proj_w - const struct ggml_tensor * c1, // F32 proj_b - struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_LOCALS(int64_t, nea, a, ne) - GGML_TENSOR_LOCALS(size_t, nba, a, nb) - GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne) - GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb) - GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne) - GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb) - GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne) - GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb) - GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne) - GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = nea0; - //const int64_t N = nea1; - const int64_t M = neb01; - - GGML_ASSERT(ne0 == nea0); - GGML_ASSERT(ne1 == nea1); - GGML_ASSERT(ne2 == nea2); - - GGML_ASSERT(nba0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbb10 == sizeof(float)); - GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbc10 == sizeof(float)); - - GGML_ASSERT(neb00 == D); - GGML_ASSERT(neb01 == M); - GGML_ASSERT(neb10 == M); - GGML_ASSERT(neb11 == 1); - - GGML_ASSERT(nec00 == M); - GGML_ASSERT(nec01 == D); - GGML_ASSERT(nec10 == D); - GGML_ASSERT(nec11 == 1); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_INIT) { - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - // parallelize by a rows using ggml_vec_dot_f32 - - // total rows in a - const int nr = nea1*nea2*nea3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // a indices - const int ia3 = ir/(nea2*nea1); - const int ia2 = (ir - ia3*nea2*nea1)/nea1; - const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1); - - float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); - - for (int64_t ic = 0; ic < neb01; ++ic) { - // b0 indices - const int ib03 = ia3; - const int ib02 = ia2; - const int ib01 = ic; - - // S indices - const int i1 = ib01; - - ggml_vec_dot_f16(nea0, - S + i1, - (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), - (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3))); - } - - ggml_vec_add_f32(neb01, S, S, (float *) b1->data); - //ggml_vec_gelu_f32(neb01, S, S); - - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); - - for (int64_t i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); - } - - ggml_vec_gelu_f16(neb01, S16, S16); - - { - // dst indices - const int i1 = ia1; - const int i2 = ia2; - const int i3 = ia3; - - for (int64_t ic = 0; ic < nec01; ++ic) { - - ggml_vec_dot_f16(neb01, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), - S16); - } - - ggml_vec_add_f32(nec01, - (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), - (float *) c1->data); - } - } -} - -static void ggml_compute_forward_flash_ff( - const struct ggml_compute_params * params, - const struct ggml_tensor * a, - const struct ggml_tensor * b0, - const struct ggml_tensor * b1, - const struct ggml_tensor * c0, - const struct ggml_tensor * c1, - struct ggml_tensor * dst) { - switch (b0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_flash_ff_f16(params, a, b0, b1, c0, c1, dst); - } break; - case GGML_TYPE_F32: - { - GGML_ASSERT(false); // TODO - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_flash_attn_back - -static void ggml_compute_forward_flash_attn_back_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const struct ggml_tensor * d, - const bool masked, - struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_LOCALS(int64_t, neq, q, ne) - GGML_TENSOR_LOCALS(size_t, nbq, q, nb) - GGML_TENSOR_LOCALS(int64_t, nek, k, ne) - GGML_TENSOR_LOCALS(size_t, nbk, k, nb) - GGML_TENSOR_LOCALS(int64_t, nev, v, ne) - GGML_TENSOR_LOCALS(size_t, nbv, v, nb) - GGML_TENSOR_LOCALS(int64_t, ned, d, ne) - GGML_TENSOR_LOCALS(size_t, nbd, d, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = neq0; - const int64_t N = neq1; - const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); - const int mxDM = MAX(D, Mup); - - // GGML_ASSERT(ne0 == D); - // GGML_ASSERT(ne1 == N); - GGML_ASSERT(P >= 0); - - GGML_ASSERT(nbq0 == sizeof(float)); - GGML_ASSERT(nbk0 == sizeof(float)); - GGML_ASSERT(nbv0 == sizeof(float)); - - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); - GGML_ASSERT(ned0 == D); - - GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); - GGML_ASSERT(ned1 == N); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_INIT) { - if (ith == 0) { - memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); - } - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - return; - } - - const int64_t elem_q = ggml_nelements(q); - const int64_t elem_k = ggml_nelements(k); - - enum ggml_type result_type = dst->type; - GGML_ASSERT(ggml_blck_size(result_type) == 1); - const size_t tsize = ggml_type_size(result_type); - - const size_t offs_q = 0; - const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); - const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); - - void * grad_q = (char *) dst->data; - void * grad_k = (char *) dst->data + offs_k; - void * grad_v = (char *) dst->data + offs_v; - - const size_t nbgq1 = nb0*neq0; - const size_t nbgq2 = nb0*neq0*neq1; - const size_t nbgq3 = nb0*neq0*neq1*neq2; - - const size_t nbgk1 = nb0*nek0; - const size_t nbgk2 = nb0*nek0*nek1; - const size_t nbgk3 = nb0*nek0*nek1*neq2; - - const size_t nbgv1 = nb0*nev0; - const size_t nbgv2 = nb0*nev0*nev1; - const size_t nbgv3 = nb0*nev0*nev1*neq2; - - // parallelize by k rows using ggml_vec_dot_f32 - - // total rows in k - const int nr = nek2*nek3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0f/sqrtf(D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - // how often k2 (and v2) is repeated in q2 - int nrep = neq2/nek2; - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int ik3 = ir/(nek2); - const int ik2 = ir - ik3*nek2; - - const int iq3 = ik3; - const int id3 = ik3; - const int iv3 = ik3; - const int iv2 = ik2; - - for (int irep = 0; irep < nrep; ++irep) { - const int iq2 = ik2 + irep*nek2; - const int id2 = iq2; - - // (ik2 + irep*nek2) % nek2 == ik2 - for (int iq1 = 0; iq1 < neq1; ++iq1) { - const int id1 = iq1; - - // not sure about CACHE_LINE_SIZE_F32.. - // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? - float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); - float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - const int64_t masked_begin = masked ? (P + iq1 + 1) : M; - for (int64_t ic = 0; ic < masked_begin; ++ic) { - // k indices - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f32(neq0, - S + i1, - (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - - // scale - ggml_vec_scale_f32(masked_begin, S, scale); - - for (int64_t i = masked_begin; i < M; i++) { - S[i] = -INFINITY; - } - - // softmax - // exclude known -INF S[..] values from max and loop - // dont forget to set their SM values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(masked_begin, &max, S); - - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(SM, 1, &max, SM, 1, Mup); - vvexpf(SM, SM, &Mup); - ggml_vec_sum_f32(Mup, &sum, SM); -#else - uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); - ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; - - for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { - if (i >= masked_begin) { - break; - } - float * SR = S + i; - float * SW = SM + i; - - for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (i + j >= masked_begin) { - break; - } else if (SR[j] == -INFINITY) { - SW[j] = 0.0f; - } else { -#ifndef GGML_FLASH_ATTN_EXP_FP16 - const float val = expf(SR[j] - max); -#else - ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); -#endif - sump[j] += (ggml_float)val; - SW[j] = val; - } - } - } - - for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } -#endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(masked_begin, SM, sum); - - } - - // step-by-step explanation - { - // forward-process shape grads from backward process - // parallel_for ik2,ik3: - // for irep: - // iq2 = ik2 + irep*nek2 - // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur] - // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] - // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur] - // for iq1: - // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur - // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur - // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 - // S0 = -Inf [D,1,1,1] - // ~S1[i] = dot(kcur[:D,i], qcur) - // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale - // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) - // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur - // ~S5[i] = dot(vcur[:,i], S4) - // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3] - // ~dst[i,iq1,iq2,iq3] = S5[i] ^ - // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3] - // dst backward-/ grad[dst] = d - // - // output gradients with their dependencies: - // - // grad[kcur] = grad[S1].T @ qcur - // grad[S1] = diag_mask_zero(grad[S3], P) * scale - // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // grad[S4] = grad[S5] @ vcur - // grad[S4] = d[:D,id1,id2,id3] @ vcur - // grad[qcur] = grad[S1] @ kcur - // grad[vcur] = grad[S5].T @ S4 - // grad[vcur] = d[:D,id1,id2,id3].T @ S4 - // - // in post-order: - // - // S1 = qcur @ kcur.T - // S2 = S1 * scale - // S3 = diag_mask_inf(S2, P) - // S4 = softmax(S3) - // grad[S4] = d[:D,id1,id2,id3] @ vcur - // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // grad[S1] = diag_mask_zero(grad[S3], P) * scale - // grad[qcur] = grad[S1] @ kcur - // grad[kcur] = grad[S1].T @ qcur - // grad[vcur] = d[:D,id1,id2,id3].T @ S4 - // - // using less variables (SM=S4): - // - // S = diag_mask_inf(qcur @ kcur.T * scale, P) - // SM = softmax(S) - // S = d[:D,iq1,iq2,iq3] @ vcur - // dot_SM_gradSM = dot(SM, S) - // S = SM * (S - dot(SM, S)) - // S = diag_mask_zero(S, P) * scale - // - // grad[q][:D,iq1,iq2,iq3] += S @ kcur - // grad[k][:D,:M,ik2,ik3] += S.T @ qcur - // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM - } - - // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] - // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] - // for ic: - // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3] - // exclude known future zero S[..] values from operation - ggml_vec_set_f32(masked_begin, S, 0); - for (int64_t ic = 0; ic < D; ++ic) { - ggml_vec_mad_f32(masked_begin, - S, - (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); - } - - // S = SM * (S - dot(SM, S)) - float dot_SM_gradSM = 0; - ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S); - ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); - ggml_vec_mul_f32 (masked_begin, S, S, SM); - - // S = diag_mask_zero(S, P) * scale - // already done by above ggml_vec_set_f32 - - // exclude known zero S[..] values from operation - ggml_vec_scale_f32(masked_begin, S, scale); - - // S shape [M,1] - // SM shape [M,1] - // kcur shape [D,M] - // qcur shape [D,1] - // vcur shape [M,D] - - // grad[q][:D,iq1,iq2,iq3] += S @ kcur - // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] - // for ic: - // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3] - // exclude known zero S[..] values from loop - for (int64_t ic = 0; ic < masked_begin; ++ic) { - ggml_vec_mad_f32(D, - (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)), - (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)), - S[ic]); - } - - // grad[k][:D,:M,iq2,iq3] += S.T @ qcur - // for ic: - // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] - // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] - // exclude known zero S[..] values from loop - for (int64_t ic = 0; ic < masked_begin; ++ic) { - ggml_vec_mad_f32(D, - (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)), - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), - S[ic]); - } - - // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM - // for ic: - // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M] - // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M] - // exclude known zero SM[..] values from mad - for (int64_t ic = 0; ic < D; ++ic) { - ggml_vec_mad_f32(masked_begin, - (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)), - SM, - *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); - } - } - } - } -} - -static void ggml_compute_forward_flash_attn_back( - const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const struct ggml_tensor * d, - const bool masked, - struct ggml_tensor * dst) { - switch (q->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_win_part - -static void ggml_compute_forward_win_part_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - - const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t w = ((const int32_t *)(dst->op_params))[2]; - - assert(ne00 == ne0); - assert(ne3 == nep0*nep1); - - // TODO: optimize / multi-thread - for (int py = 0; py < nep1; ++py) { - for (int px = 0; px < nep0; ++px) { - const int64_t i3 = py*nep0 + px; - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = 0; i1 < ne1; ++i1) { - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int64_t i02 = py*w + i2; - const int64_t i01 = px*w + i1; - const int64_t i00 = i0; - - const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0; - const int64_t j = i02*ne01*ne00 + i01*ne00 + i00; - - if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { - ((float *) dst->data)[i] = 0.0f; - } else { - ((float *) dst->data)[i] = ((float *) src0->data)[j]; - } - } - } - } - } - } -} - -static void ggml_compute_forward_win_part( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_win_part_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_win_unpart - -static void ggml_compute_forward_win_unpart_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - - const int32_t w = ((const int32_t *)(dst->op_params))[0]; - - // padding - const int px = (w - ne1%w)%w; - //const int py = (w - ne2%w)%w; - - const int npx = (px + ne1)/w; - //const int npy = (py + ne2)/w; - - assert(ne0 == ne00); - - // TODO: optimize / multi-thread - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = 0; i1 < ne1; ++i1) { - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int ip2 = i2/w; - const int ip1 = i1/w; - - const int64_t i02 = i2%w; - const int64_t i01 = i1%w; - const int64_t i00 = i0; - - const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00; - const int64_t j = i2*ne1*ne0 + i1*ne0 + i0; - - ((float *) dst->data)[j] = ((float *) src0->data)[i]; - } - } - } -} - -static void ggml_compute_forward_win_unpart( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_win_unpart_f32(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -//gmml_compute_forward_unary - -static void ggml_compute_forward_unary( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - const enum ggml_unary_op op = ggml_get_unary_op(dst); - - switch (op) { - case GGML_UNARY_OP_ABS: - { - ggml_compute_forward_abs(params, src0, dst); - } break; - case GGML_UNARY_OP_SGN: - { - ggml_compute_forward_sgn(params, src0, dst); - } break; - case GGML_UNARY_OP_NEG: - { - ggml_compute_forward_neg(params, src0, dst); - } break; - case GGML_UNARY_OP_STEP: - { - ggml_compute_forward_step(params, src0, dst); - } break; - case GGML_UNARY_OP_TANH: - { - ggml_compute_forward_tanh(params, src0, dst); - } break; - case GGML_UNARY_OP_ELU: - { - ggml_compute_forward_elu(params, src0, dst); - } break; - case GGML_UNARY_OP_RELU: - { - ggml_compute_forward_relu(params, src0, dst); - } break; - case GGML_UNARY_OP_GELU: - { - ggml_compute_forward_gelu(params, src0, dst); - } break; - case GGML_UNARY_OP_GELU_QUICK: - { - ggml_compute_forward_gelu_quick(params, src0, dst); - } break; - case GGML_UNARY_OP_SILU: - { - ggml_compute_forward_silu(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_get_rel_pos - -static void ggml_compute_forward_get_rel_pos_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 - - GGML_TENSOR_UNARY_OP_LOCALS - - const int64_t w = ne1; - - ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data; - ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data; - - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = 0; i1 < ne1; ++i1) { - const int64_t pos = (w - i1 - 1) + i2; - for (int64_t i0 = 0; i0 < ne0; ++i0) { - dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; - } - } - } -} - -static void ggml_compute_forward_get_rel_pos( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_get_rel_pos_f16(params, src0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_add_rel_pos - -static void ggml_compute_forward_add_rel_pos_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - const struct ggml_tensor * src2, - struct ggml_tensor * dst) { - - const bool inplace = (bool) ((int32_t *) dst->op_params)[0]; - if (!inplace && params->type == GGML_TASK_INIT) { - memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst)); - return; - } - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359 - - float * src1_data = (float *) src1->data; - float * src2_data = (float *) src2->data; - float * dst_data = (float *) dst->data; - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - - const int ith = params->ith; - const int nth = params->nth; - - // total patches in dst - const int np = ne13; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - const int ip0 = dp*ith; - const int ip1 = MIN(ip0 + dp, np); - - - for (int64_t i13 = ip0; i13 < ip1; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10; - for (int64_t i10 = 0; i10 < ne10; ++i10) { - const int64_t jp0 = jp1 + i10; - const float src1_e = src1_data[jp0]; - const float src2_e = src2_data[jp0]; - - const int64_t jdh = jp0 * ne10; - const int64_t jdw = jdh - (ne10 - 1) * i10; - - for (int64_t j = 0; j < ne10; ++j) { - dst_data[jdh + j ] += src2_e; - dst_data[jdw + j*ne10] += src1_e; - } - } - } - } - } -} - -static void ggml_compute_forward_add_rel_pos( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - const struct ggml_tensor * src2, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_add_rel_pos_f32(params, src0, src1, src2, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_map_unary - -static void ggml_compute_forward_map_unary_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst, - const ggml_unary_op_f32_t fun) { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - fun(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - - -static void ggml_compute_forward_map_unary( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - struct ggml_tensor * dst, - const ggml_unary_op_f32_t fun) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_map_unary_f32(params, src0, dst, fun); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_map_binary - -static void ggml_compute_forward_map_binary_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst, - const ggml_binary_op_f32_t fun) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - assert(src1->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - fun(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), - (float *) ((char *) src1->data + i*(src1->nb[1]))); - } -} - - -static void ggml_compute_forward_map_binary( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst, - const ggml_binary_op_f32_t fun) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_map_custom1 - -static void ggml_compute_forward_map_custom1_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * a, - struct ggml_tensor * dst, - const ggml_custom1_op_f32_t fun) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - fun(dst, a); -} - -// ggml_compute_forward_map_custom2 - -static void ggml_compute_forward_map_custom2_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * a, - const struct ggml_tensor * b, - struct ggml_tensor * dst, - const ggml_custom2_op_f32_t fun) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - fun(dst, a, b); -} - - -// ggml_compute_forward_map_custom3 - -static void ggml_compute_forward_map_custom3_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * a, - const struct ggml_tensor * b, - const struct ggml_tensor * c, - struct ggml_tensor * dst, - const ggml_custom3_op_f32_t fun) { - assert(params->ith == 0); - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - fun(dst, a, b, c); -} - -// ggml_compute_forward_map_custom1 - -static void ggml_compute_forward_map_custom1( - const struct ggml_compute_params * params, - const struct ggml_tensor * a, - struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) dst->op_params; - - p->fun(dst, a, params->ith, params->nth, p->userdata); -} - -// ggml_compute_forward_map_custom2 - -static void ggml_compute_forward_map_custom2( - const struct ggml_compute_params * params, - const struct ggml_tensor * a, - const struct ggml_tensor * b, - struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) dst->op_params; - - p->fun(dst, a, b, params->ith, params->nth, p->userdata); -} - -// ggml_compute_forward_map_custom3 - -static void ggml_compute_forward_map_custom3( - const struct ggml_compute_params * params, - const struct ggml_tensor * a, - const struct ggml_tensor * b, - const struct ggml_tensor * c, - struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) dst->op_params; - - p->fun(dst, a, b, c, params->ith, params->nth, p->userdata); -} - -// ggml_compute_forward_cross_entropy_loss - -static void ggml_compute_forward_cross_entropy_loss_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_scalar(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - - const int ith = params->ith; - const int nth = params->nth; - - float * sums = (float *) params->wdata; - - // TODO: handle transposed/permuted matrices - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc)); - - if (params->type == GGML_TASK_INIT) { - if (ith == 0) { - memset(sums, 0, sizeof(float) * (nth + nth * nc)); - } - return; - } - - if (params->type == GGML_TASK_FINALIZE) { - if (ith == 0) { - float * dp = (float *) dst->data; - ggml_vec_sum_f32(nth, dp, sums); - dp[0] *= -1.0f / (float) nr; - } - return; - } - - const double eps = 1e-9; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); - float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); - float * st = ((float *) params->wdata) + nth + ith*nc; - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(s0[i])); - assert(!isnan(s1[i])); - } -#endif - // soft_max - ggml_float sum = 0.0; - { - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, s0); - - uint16_t scvt; UNUSED(scvt); - for (int i = 0; i < nc; i++) { - if (s0[i] == -INFINITY) { - st[i] = 0.0f; - } else { -#ifndef GGML_CROSS_ENTROPY_EXP_FP16 - const float s = s0[i] - max; - const float val = expf(s); -#else - ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); - memcpy(&scvt, &s, sizeof(scvt)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); -#endif - sum += (ggml_float)val; - st[i] = val; - } - } - - assert(sum > 0.0); - // sum = 1.0/sum; - } - // avoid log(0) by rescaling from [0..1] to [eps..1] - sum = (1.0 - eps) / sum; - ggml_vec_scale_f32(nc, st, sum); - ggml_vec_add1_f32(nc, st, st, eps); - ggml_vec_log_f32(nc, st, st); - ggml_vec_mul_f32(nc, st, st, s1); - - float st_sum = 0; - ggml_vec_sum_f32(nc, &st_sum, st); - sums[ith] += st_sum; - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(st[i])); - assert(!isinf(st[i])); - } -#endif - } - -} - -static void ggml_compute_forward_cross_entropy_loss( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_cross_entropy_loss_f32(params, src0, src1, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - -// ggml_compute_forward_cross_entropy_loss_back - -static void ggml_compute_forward_cross_entropy_loss_back_f32( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, - struct ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(opt0)); - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int64_t ith = params->ith; - const int64_t nth = params->nth; - - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { - return; - } - - const double eps = 1e-9; - - // TODO: handle transposed/permuted matrices - const int64_t nc = src0->ne[0]; - const int64_t nr = ggml_nrows(src0); - - // rows per thread - const int64_t dr = (nr + nth - 1)/nth; - - // row range for this thread - const int64_t ir0 = dr*ith; - const int64_t ir1 = MIN(ir0 + dr, nr); - - float * d = (float *) opt0->data; - - for (int64_t i1 = ir0; i1 < ir1; i1++) { - float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); - float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); - float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(s0[i])); - assert(!isnan(s1[i])); - } -#endif - - // soft_max - ggml_float sum = 0.0; - { - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, s0); - - uint16_t scvt; UNUSED(scvt); - for (int i = 0; i < nc; i++) { - if (s0[i] == -INFINITY) { - ds0[i] = 0.0f; - } else { -#ifndef GGML_CROSS_ENTROPY_EXP_FP16 - const float s = s0[i] - max; - const float val = expf(s); -#else - ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max); - memcpy(&scvt, &s, sizeof(scvt)); - const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]); -#endif - sum += (ggml_float)val; - ds0[i] = val; - } - } - - assert(sum > 0.0); - sum = (1.0 - eps)/sum; - } - - // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr - ggml_vec_scale_f32(nc, ds0, sum); - ggml_vec_add1_f32(nc, ds0, ds0, eps); - ggml_vec_sub_f32(nc, ds0, ds0, s1); - ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); - - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(ds0[i])); - assert(!isinf(ds0[i])); - } -#endif - } -} - -static void ggml_compute_forward_cross_entropy_loss_back( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, - struct ggml_tensor * dst) { - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_cross_entropy_loss_back_f32(params, src0, src1, opt0, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - - -///////////////////////////////// - -static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { - GGML_ASSERT(params); - -#ifdef GGML_USE_CUBLAS - bool skip_cpu = ggml_cuda_compute_forward(params, tensor); - if (skip_cpu) { - return; - } - GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU); - GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_CPU); -#endif // GGML_USE_CUBLAS - - switch (tensor->op) { - case GGML_OP_DUP: - { - ggml_compute_forward_dup(params, tensor->src[0], tensor); - } break; - case GGML_OP_ADD: - { - ggml_compute_forward_add(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_ADD1: - { - ggml_compute_forward_add1(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_ACC: - { - ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_SUB: - { - ggml_compute_forward_sub(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_MUL: - { - ggml_compute_forward_mul(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_DIV: - { - ggml_compute_forward_div(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_SQR: - { - ggml_compute_forward_sqr(params, tensor->src[0], tensor); - } break; - case GGML_OP_SQRT: - { - ggml_compute_forward_sqrt(params, tensor->src[0], tensor); - } break; - case GGML_OP_LOG: - { - ggml_compute_forward_log(params, tensor->src[0], tensor); - } break; - case GGML_OP_SUM: - { - ggml_compute_forward_sum(params, tensor->src[0], tensor); - } break; - case GGML_OP_SUM_ROWS: - { - ggml_compute_forward_sum_rows(params, tensor->src[0], tensor); - } break; - case GGML_OP_MEAN: - { - ggml_compute_forward_mean(params, tensor->src[0], tensor); - } break; - case GGML_OP_ARGMAX: - { - ggml_compute_forward_argmax(params, tensor->src[0], tensor); - } break; - case GGML_OP_REPEAT: - { - ggml_compute_forward_repeat(params, tensor->src[0], tensor); - } break; - case GGML_OP_REPEAT_BACK: - { - ggml_compute_forward_repeat_back(params, tensor->src[0], tensor); - } break; - case GGML_OP_CONCAT: - { - ggml_compute_forward_concat(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_SILU_BACK: - { - ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_NORM: - { - ggml_compute_forward_norm(params, tensor->src[0], tensor); - } break; - case GGML_OP_RMS_NORM: - { - ggml_compute_forward_rms_norm(params, tensor->src[0], tensor); - } break; - case GGML_OP_RMS_NORM_BACK: - { - ggml_compute_forward_rms_norm_back(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_GROUP_NORM: - { - ggml_compute_forward_group_norm(params, tensor->src[0], tensor); - } break; - case GGML_OP_MUL_MAT: - { - ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_OUT_PROD: - { - ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_SCALE: - { - ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_SET: - { - ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_CPY: - { - ggml_compute_forward_cpy(params, tensor->src[0], tensor); - } break; - case GGML_OP_CONT: - { - ggml_compute_forward_cont(params, tensor->src[0], tensor); - } break; - case GGML_OP_RESHAPE: - { - ggml_compute_forward_reshape(params, tensor->src[0], tensor); - } break; - case GGML_OP_VIEW: - { - ggml_compute_forward_view(params, tensor->src[0]); - } break; - case GGML_OP_PERMUTE: - { - ggml_compute_forward_permute(params, tensor->src[0]); - } break; - case GGML_OP_TRANSPOSE: - { - ggml_compute_forward_transpose(params, tensor->src[0]); - } break; - case GGML_OP_GET_ROWS: - { - ggml_compute_forward_get_rows(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_GET_ROWS_BACK: - { - ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_DIAG: - { - ggml_compute_forward_diag(params, tensor->src[0], tensor); - } break; - case GGML_OP_DIAG_MASK_INF: - { - ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor); - } break; - case GGML_OP_DIAG_MASK_ZERO: - { - ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor); - } break; - case GGML_OP_SOFT_MAX: - { - ggml_compute_forward_soft_max(params, tensor->src[0], tensor); - } break; - case GGML_OP_SOFT_MAX_BACK: - { - ggml_compute_forward_soft_max_back(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_ROPE: - { - ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_ROPE_BACK: - { - ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_ALIBI: - { - ggml_compute_forward_alibi(params, tensor->src[0], tensor); - } break; - case GGML_OP_CLAMP: - { - ggml_compute_forward_clamp(params, tensor->src[0], tensor); - } break; - case GGML_OP_CONV_1D: - { - ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_CONV_1D_STAGE_0: - { - ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_CONV_1D_STAGE_1: - { - ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_CONV_TRANSPOSE_1D: - { - ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_CONV_2D: - { - ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_CONV_TRANSPOSE_2D: - { - ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor); - } break; - case GGML_OP_POOL_1D: - { - ggml_compute_forward_pool_1d(params, tensor->src[0], tensor); - } break; - case GGML_OP_POOL_2D: - { - ggml_compute_forward_pool_2d(params, tensor->src[0], tensor); - } break; - case GGML_OP_UPSCALE: - { - ggml_compute_forward_upscale(params, tensor->src[0], tensor); - } break; - case GGML_OP_FLASH_ATTN: - { - const int32_t t = ggml_get_op_params_i32(tensor, 0); - GGML_ASSERT(t == 0 || t == 1); - const bool masked = t != 0; - ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); - } break; - case GGML_OP_FLASH_FF: - { - ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor); - } break; - case GGML_OP_FLASH_ATTN_BACK: - { - int32_t t = ggml_get_op_params_i32(tensor, 0); - GGML_ASSERT(t == 0 || t == 1); - bool masked = t != 0; - ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor); - } break; - case GGML_OP_WIN_PART: - { - ggml_compute_forward_win_part(params, tensor->src[0], tensor); - } break; - case GGML_OP_WIN_UNPART: - { - ggml_compute_forward_win_unpart(params, tensor->src[0], tensor); - } break; - case GGML_OP_UNARY: - { - ggml_compute_forward_unary(params, tensor->src[0], tensor); - } break; - case GGML_OP_GET_REL_POS: - { - ggml_compute_forward_get_rel_pos(params, tensor->src[0], tensor); - } break; - case GGML_OP_ADD_REL_POS: - { - ggml_compute_forward_add_rel_pos(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); - } break; - case GGML_OP_MAP_UNARY: - { - ggml_unary_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_unary(params, tensor->src[0], tensor, fun); - } - break; - case GGML_OP_MAP_BINARY: - { - ggml_binary_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM1_F32: - { - ggml_custom1_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_custom1_f32(params, tensor->src[0], tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM2_F32: - { - ggml_custom2_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_custom2_f32(params, tensor->src[0], tensor->src[1], tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM3_F32: - { - ggml_custom3_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_custom3_f32(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM1: - { - ggml_compute_forward_map_custom1(params, tensor->src[0], tensor); - } - break; - case GGML_OP_MAP_CUSTOM2: - { - ggml_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor); - } - break; - case GGML_OP_MAP_CUSTOM3: - { - ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); - } - break; - case GGML_OP_CROSS_ENTROPY_LOSS: - { - ggml_compute_forward_cross_entropy_loss(params, tensor->src[0], tensor->src[1], tensor); - } - break; - case GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - ggml_compute_forward_cross_entropy_loss_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); - } - break; - case GGML_OP_NONE: - { - // nop - } break; - case GGML_OP_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - -//////////////////////////////////////////////////////////////////////////////// - -static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small"); - -static size_t hash(void * p) { - return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE; -} - -static size_t hash_find(void * hash_table[], void * p) { - size_t h = hash(p); - - // linear probing - size_t i = h; - while (hash_table[i] != NULL && hash_table[i] != p) { - i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE; - if (i == h) { - // visited all hash table entries -> not found - return GGML_GRAPH_HASHTABLE_SIZE; - } - } - return i; -} - -static bool hash_insert(void * hash_table[], void * p) { - size_t i = hash_find(hash_table, p); - - GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full - - if (hash_table[i] == p) { - return true; - } - - // insert - GGML_ASSERT(hash_table[i] == NULL); - hash_table[i] = p; - return false; -} - -static bool hash_contains(void * hash_table[], void * p) { - size_t i = hash_find(hash_table, p); - return (i < GGML_GRAPH_HASHTABLE_SIZE) && (hash_table[i] == p); -} - -struct hash_map { - void * keys[GGML_GRAPH_HASHTABLE_SIZE]; - void * vals[GGML_GRAPH_HASHTABLE_SIZE]; -}; - -static struct hash_map * new_hash_map(void) { - struct hash_map * result = malloc(sizeof(struct hash_map)); - for (int i=0; ikeys[i] = NULL; - result->vals[i] = NULL; - } - return result; -} - -static void free_hash_map(struct hash_map * map) { - free(map); -} - -// gradient checkpointing - -static struct ggml_tensor * ggml_recompute_graph_node( - struct ggml_context * ctx, - struct ggml_cgraph * graph, - struct hash_map * replacements, - struct ggml_tensor * node) { - - if (node == NULL) { - return NULL; - } - - if (node->is_param) { - return node; - } - - if (!hash_contains(graph->visited_hash_table, node)) { - return node; - } - - int count_children = 0; - for (int k = 0; k < GGML_MAX_SRC; ++k) { - if (node->src[k]) { - ++count_children; - } - } - - if (count_children == 0) { - return node; - } - - size_t i = hash_find(replacements->keys, node); - GGML_ASSERT(i < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full - if (replacements->keys[i] == node) { - return (struct ggml_tensor *) replacements->vals[i]; - } - - struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, node->n_dims, node->ne); - - // insert clone into replacements - GGML_ASSERT(replacements->keys[i] == NULL); // assert that we don't overwrite - replacements->keys[i] = node; - replacements->vals[i] = clone; - - clone->op = node->op; - clone->grad = node->grad; - clone->is_param = node->is_param; - clone->extra = node->extra; - for (int k = 0; k < GGML_MAX_DIMS; ++k) { - clone->nb[k] = node->nb[k]; - } - for (int k = 0; k < GGML_MAX_SRC; ++k) { - clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); - } - if (node->view_src != NULL) { - clone->data = (node->view_src->data == NULL) - ? NULL // view_src not yet allocated - : (char *) node->view_src->data // view_src already allocated - + node->view_offs; - clone->view_src = node->view_src; - clone->view_offs = node->view_offs; - } - - GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t))); - GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME); - memcpy(clone->op_params, node->op_params, sizeof(node->op_params)); - ggml_format_name(clone, "%s (clone)", ggml_get_name(node)); - - return clone; -} - -void ggml_build_backward_gradient_checkpointing( - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - struct ggml_cgraph * gb_tmp, - struct ggml_tensor * * checkpoints, - int n_checkpoints) { - *gb_tmp = *gf; - ggml_build_backward_expand(ctx, gf, gb_tmp, true); - - if (n_checkpoints <= 0) { - *gb = *gb_tmp; - return; - } - - struct hash_map * replacements = new_hash_map(); - - // insert checkpoints in replacements - for (int i = 0; i < n_checkpoints; ++i) { - size_t k = hash_find(replacements->keys, checkpoints[i]); - GGML_ASSERT(k < GGML_GRAPH_HASHTABLE_SIZE); // assert that not full - GGML_ASSERT(replacements->keys[k] == NULL); // assert that we don't overwrite - replacements->keys[k] = checkpoints[i]; - replacements->vals[k] = checkpoints[i]; - } - - *gb = *gf; - // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes], - // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]), - // by recomputing them from checkpoints - for (int i = gf->n_nodes; in_nodes; ++i) { - struct ggml_tensor * node = gb_tmp->nodes[i]; - for (int k = 0; k < GGML_MAX_SRC; ++k) { - // insert new tensors recomputing src, reusing already made replacements, - // remember replacements: remember new tensors with mapping from corresponding gf nodes - // recurse for input tensors, - // unless (i.e. terminating when) input tensors are replacments (like checkpoints) - node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); - } - // insert rewritten backward node with replacements made into resulting backward graph gb - ggml_build_forward_expand(gb, node); - } - - free_hash_map(replacements); -} - -// functions to change gradients considering the case that input a might be initial gradient with zero value - -static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) { - if (hash_contains(zero_table, a)) { - return b; - } else { - return ggml_add_impl(ctx, a, b, false); - } -} - -static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) { - if (hash_contains(zero_table, a)) { - struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0)); - return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); - } else { - return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); - } -} - -static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) { - if (hash_contains(zero_table, a)) { - return ggml_repeat(ctx, b, a); - } else { - return ggml_add1_impl(ctx, a, b, false); - } -} - -static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) { - if (hash_contains(zero_table, a)) { - return ggml_neg(ctx, b); - } else { - return ggml_sub_impl(ctx, a, b, false); - } -} - -static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, void * zero_table[]) { - struct ggml_tensor * src0 = tensor->src[0]; - struct ggml_tensor * src1 = tensor->src[1]; - - switch (tensor->op) { - case GGML_OP_DUP: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - } break; - case GGML_OP_ADD: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - if (src1->grad) { - src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); - } - } break; - case GGML_OP_ADD1: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - if (src1->grad) { - src1->grad = ggml_add_or_set(ctx, - src1->grad, - ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean - zero_table); - } - } break; - case GGML_OP_ACC: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - if (src1->grad) { - const size_t nb1 = ((int32_t *) tensor->op_params)[0]; - const size_t nb2 = ((int32_t *) tensor->op_params)[1]; - const size_t nb3 = ((int32_t *) tensor->op_params)[2]; - const size_t offset = ((int32_t *) tensor->op_params)[3]; - - struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx, - tensor->grad, - src1->grad->ne[0], - src1->grad->ne[1], - src1->grad->ne[2], - src1->grad->ne[3], - nb1, nb2, nb3, offset); - - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_reshape(ctx, - ggml_cont(ctx, tensor_grad_view), - src1->grad), - zero_table); - } - } break; - case GGML_OP_SUB: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - if (src1->grad) { - src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table); - } - } break; - case GGML_OP_MUL: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, src1, tensor->grad), - zero_table); - } - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_mul(ctx, src0, tensor->grad), - zero_table); - } - } break; - case GGML_OP_DIV: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_div(ctx, tensor->grad, src1), - zero_table); - } - if (src1->grad) { - src1->grad = - ggml_sub_or_set(ctx, - src1->grad, - ggml_mul(ctx, - tensor->grad, - ggml_div(ctx, tensor, src1)), - zero_table); - } - } break; - case GGML_OP_SQR: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_scale(ctx, - ggml_mul(ctx, src0, tensor->grad), - ggml_new_f32(ctx, 2.0f)), - zero_table); - } - } break; - case GGML_OP_SQRT: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_scale(ctx, - ggml_div(ctx, - tensor->grad, - tensor), - ggml_new_f32(ctx, 0.5f)), - zero_table); - } - } break; - case GGML_OP_LOG: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_div(ctx, - tensor->grad, - src0), - zero_table); - } - } break; - case GGML_OP_SUM: - { - if (src0->grad) { - src0->grad = - ggml_add1_or_set(ctx, - src0->grad, - tensor->grad, - zero_table); - } - } break; - case GGML_OP_SUM_ROWS: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_repeat(ctx, - tensor->grad, - src0->grad), - zero_table); - } - } break; - case GGML_OP_MEAN: - case GGML_OP_ARGMAX: - { - GGML_ASSERT(false); // TODO: implement - } break; - case GGML_OP_REPEAT: - { - // necessary for llama - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_repeat_back(ctx, tensor->grad, src0->grad), - zero_table); - } - } break; - case GGML_OP_REPEAT_BACK: - { - if (src0->grad) { - // TODO: test this - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_repeat(ctx, tensor->grad, src0->grad), - zero_table); - } - } break; - case GGML_OP_CONCAT: - { - GGML_ASSERT(false); // TODO: implement - } break; - case GGML_OP_SILU_BACK: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_NORM: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_RMS_NORM: - { - // necessary for llama - if (src0->grad) { - float eps; - memcpy(&eps, tensor->op_params, sizeof(float)); - - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_rms_norm_back(ctx, src0, tensor->grad, eps), - zero_table); - } - } break; - case GGML_OP_RMS_NORM_BACK: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_GROUP_NORM: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_MUL_MAT: - { - // https://cs231n.github.io/optimization-2/#staged - // # forward pass - // s0 = np.random.randn(5, 10) - // s1 = np.random.randn(10, 3) - // t = s0.dot(s1) - - // # now suppose we had the gradient on t from above in the circuit - // dt = np.random.randn(*t.shape) # same shape as t - // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix - // ds1 = t.T.dot(dt) - - // tensor.shape [m,p,qq,rr] - // src0.shape [n,m,q1,r1] - // src1.shape [n,p,qq,rr] - - // necessary for llama - if (src0->grad) { - struct ggml_tensor * s1_tg = - ggml_out_prod(ctx, // [n,m,qq,rr] - src1, // [n,p,qq,rr] - tensor->grad); // [m,p,qq,rr] - const int64_t qq = s1_tg->ne[2]; - const int64_t rr = s1_tg->ne[3]; - const int64_t q1 = src0->ne[2]; - const int64_t r1 = src0->ne[3]; - const bool ne2_broadcasted = qq > q1; - const bool ne3_broadcasted = rr > r1; - if (ne2_broadcasted || ne3_broadcasted) { - // sum broadcast repetitions of s1_tg into shape of src0 - s1_tg = ggml_repeat_back(ctx, s1_tg, src0); - } - src0->grad = - ggml_add_or_set(ctx, - src0->grad, // [n,m,q1,r1] - s1_tg, // [n,m,q1,r1] - zero_table); - } - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, // [n,p,qq,rr] - // ggml_mul_mat(ctx, // [n,p,qq,rr] - // ggml_cont(ctx, // [m,n,q1,r1] - // ggml_transpose(ctx, src0)), // [m,n,q1,r1] - // tensor->grad), // [m,p,qq,rr] - - // // when src0 is bigger than tensor->grad (this is mostly the case in llama), - // // avoid transpose of src0, rather transpose smaller tensor->grad - // // and then use ggml_out_prod - ggml_out_prod(ctx, // [n,p,qq,rr] - src0, // [n,m,q1,r1] - ggml_transpose(ctx, // [p,m,qq,rr] - tensor->grad)), // [m,p,qq,rr] - zero_table); - } - } break; - case GGML_OP_OUT_PROD: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_SCALE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_scale_impl(ctx, tensor->grad, src1, false), - zero_table); - } - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)), - zero_table); - } - } break; - case GGML_OP_SET: - { - const size_t nb1 = ((int32_t *) tensor->op_params)[0]; - const size_t nb2 = ((int32_t *) tensor->op_params)[1]; - const size_t nb3 = ((int32_t *) tensor->op_params)[2]; - const size_t offset = ((int32_t *) tensor->op_params)[3]; - - struct ggml_tensor * tensor_grad_view = NULL; - - if (src0->grad || src1->grad) { - GGML_ASSERT(src0->type == tensor->type); - GGML_ASSERT(tensor->grad->type == tensor->type); - GGML_ASSERT(tensor->grad->type == src1->grad->type); - - tensor_grad_view = ggml_view_4d(ctx, - tensor->grad, - src1->grad->ne[0], - src1->grad->ne[1], - src1->grad->ne[2], - src1->grad->ne[3], - nb1, nb2, nb3, offset); - } - - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_acc_impl(ctx, - tensor->grad, - ggml_neg(ctx, tensor_grad_view), - nb1, nb2, nb3, offset, false), - zero_table); - } - - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_reshape(ctx, - ggml_cont(ctx, tensor_grad_view), - src1->grad), - zero_table); - } - } break; - case GGML_OP_CPY: - { - // necessary for llama - // cpy overwrites value of src1 by src0 and returns view(src1) - // the overwriting is mathematically equivalent to: - // tensor = src0 * 1 + src1 * 0 - if (src0->grad) { - // dsrc0 = dtensor * 1 - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - if (src1->grad) { - // dsrc1 = dtensor * 0 -> noop - } - } break; - case GGML_OP_CONT: - { - // same as cpy - if (src0->grad) { - GGML_ASSERT(ggml_is_contiguous(src0->grad)); - GGML_ASSERT(ggml_is_contiguous(tensor->grad)); - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - } break; - case GGML_OP_RESHAPE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_reshape(ctx, - ggml_is_contiguous(tensor->grad) - ? tensor->grad - : ggml_cont(ctx, tensor->grad), - src0->grad), - zero_table); - } - } break; - case GGML_OP_VIEW: - { - // necessary for llama - if (src0->grad) { - size_t offset; - - memcpy(&offset, tensor->op_params, sizeof(offset)); - - size_t nb1 = tensor->nb[1]; - size_t nb2 = tensor->nb[2]; - size_t nb3 = tensor->nb[3]; - - if (src0->type != src0->grad->type) { - // gradient is typically F32, but src0 could be other type - size_t ng = ggml_element_size(src0->grad); - size_t n0 = ggml_element_size(src0); - GGML_ASSERT(offset % n0 == 0); - GGML_ASSERT(nb1 % n0 == 0); - GGML_ASSERT(nb2 % n0 == 0); - GGML_ASSERT(nb3 % n0 == 0); - offset = (offset / n0) * ng; - nb1 = (nb1 / n0) * ng; - nb2 = (nb2 / n0) * ng; - nb3 = (nb3 / n0) * ng; - } - - src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table); - } - } break; - case GGML_OP_PERMUTE: - { - // necessary for llama - if (src0->grad) { - int32_t * axes = (int32_t *) tensor->op_params; - int axis0 = axes[0] & 0x3; - int axis1 = axes[1] & 0x3; - int axis2 = axes[2] & 0x3; - int axis3 = axes[3] & 0x3; - int axes_backward[4] = {0,0,0,0}; - axes_backward[axis0] = 0; - axes_backward[axis1] = 1; - axes_backward[axis2] = 2; - axes_backward[axis3] = 3; - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_permute(ctx, - tensor->grad, - axes_backward[0], - axes_backward[1], - axes_backward[2], - axes_backward[3]), - zero_table); - } - } break; - case GGML_OP_TRANSPOSE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_transpose(ctx, tensor->grad), - zero_table); - } - } break; - case GGML_OP_GET_ROWS: - { - // necessary for llama (only for tokenizer) - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - // last ggml_get_rows_back argument src0->grad is only - // necessary to setup correct output shape - ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), - zero_table); - } - if (src1->grad) { - // noop - } - } break; - case GGML_OP_GET_ROWS_BACK: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_DIAG: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_DIAG_MASK_INF: - { - // necessary for llama - if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table); - } - } break; - case GGML_OP_DIAG_MASK_ZERO: - { - // necessary for llama - if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table); - } - } break; - case GGML_OP_SOFT_MAX: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_soft_max_back(ctx, tensor->grad, tensor), - zero_table); - } - - } break; - case GGML_OP_SOFT_MAX_BACK: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_ROPE: - { - // necessary for llama - if (src0->grad) { - //const int n_past = ((int32_t *) tensor->op_params)[0]; - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - const int n_ctx = ((int32_t *) tensor->op_params)[3]; - float freq_base; - float freq_scale; - float xpos_base; - bool xpos_down; - memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float)); - memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); - memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); - - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_rope_back(ctx, - tensor->grad, - src1, - n_dims, - mode, - n_ctx, - freq_base, - freq_scale, - xpos_base, - xpos_down), - zero_table); - } - } break; - case GGML_OP_ROPE_BACK: - { - if (src0->grad) { - //const int n_past = ((int32_t *) tensor->op_params)[0]; - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - const int n_ctx = ((int32_t *) tensor->op_params)[3]; - float freq_base; - float freq_scale; - float xpos_base; - bool xpos_down; - memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float)); - memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); - memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); - - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_rope_impl(ctx, - tensor->grad, - src1, - n_dims, - mode, - n_ctx, - freq_base, - freq_scale, - xpos_base, - xpos_down, - false), - zero_table); - } - } break; - case GGML_OP_ALIBI: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CLAMP: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CONV_1D: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CONV_1D_STAGE_0: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CONV_1D_STAGE_1: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CONV_2D: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CONV_TRANSPOSE_1D: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_CONV_TRANSPOSE_2D: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_POOL_1D: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_POOL_2D: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_UPSCALE: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_OP_FLASH_ATTN: - { - struct ggml_tensor * flash_grad = NULL; - if (src0->grad || src1->grad || tensor->src[2]->grad) { - int32_t t = ggml_get_op_params_i32(tensor, 0); - GGML_ASSERT(t == 0 || t == 1); - bool masked = t != 0; - flash_grad = - ggml_flash_attn_back(ctx, - src0, - src1, - tensor->src[2], - tensor->grad, - masked); - } - - struct ggml_tensor * src2 = tensor->src[2]; - const int64_t elem_q = ggml_nelements(src0); - const int64_t elem_k = ggml_nelements(src1); - const int64_t elem_v = ggml_nelements(src2); - - enum ggml_type result_type = flash_grad->type; - GGML_ASSERT(ggml_blck_size(result_type) == 1); - const size_t tsize = ggml_type_size(result_type); - - const size_t offs_q = 0; - const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); - const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); - - if (src0->grad) { - struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q); - struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0); - src0->grad = ggml_add_or_set(ctx, - src0->grad, - grad_q, - zero_table); - } - if (src1->grad) { - struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k); - struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1); - src1->grad = ggml_add_or_set(ctx, - src1->grad, - grad_k, - zero_table); - } - if (src2->grad) { - struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v); - struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2); - src2->grad = ggml_add_or_set(ctx, - src2->grad, - grad_v, - zero_table); - } - } break; - case GGML_OP_FLASH_FF: - { - GGML_ASSERT(false); // not supported - } break; - case GGML_OP_FLASH_ATTN_BACK: - { - GGML_ASSERT(false); // not supported - } break; - case GGML_OP_WIN_PART: - case GGML_OP_WIN_UNPART: - case GGML_OP_UNARY: - { - switch (ggml_get_unary_op(tensor)) { - case GGML_UNARY_OP_ABS: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, - ggml_sgn(ctx, src0), - tensor->grad), - zero_table); - } - } break; - case GGML_UNARY_OP_SGN: - { - if (src0->grad) { - // noop - } - } break; - case GGML_UNARY_OP_NEG: - { - if (src0->grad) { - src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - } break; - case GGML_UNARY_OP_STEP: - { - if (src0->grad) { - // noop - } - } break; - case GGML_UNARY_OP_TANH: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_UNARY_OP_ELU: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_UNARY_OP_RELU: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, - ggml_step(ctx, src0), - tensor->grad), - zero_table); - } - } break; - case GGML_UNARY_OP_GELU: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_UNARY_OP_GELU_QUICK: - { - GGML_ASSERT(false); // TODO: not implemented - } break; - case GGML_UNARY_OP_SILU: - { - // necessary for llama - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_silu_back(ctx, src0, tensor->grad), - zero_table); - } - } break; - default: - GGML_ASSERT(false); - } - } break; - case GGML_OP_GET_REL_POS: - case GGML_OP_ADD_REL_POS: - case GGML_OP_MAP_UNARY: - case GGML_OP_MAP_BINARY: - case GGML_OP_MAP_CUSTOM1_F32: - case GGML_OP_MAP_CUSTOM2_F32: - case GGML_OP_MAP_CUSTOM3_F32: - case GGML_OP_MAP_CUSTOM1: - case GGML_OP_MAP_CUSTOM2: - case GGML_OP_MAP_CUSTOM3: - { - GGML_ASSERT(false); // not supported - } break; - case GGML_OP_CROSS_ENTROPY_LOSS: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_cross_entropy_loss_back(ctx, - src0, - src1, - tensor->grad), - zero_table); - } - } break; - case GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - GGML_ASSERT(false); // not supported - } break; - case GGML_OP_NONE: - { - // nop - } break; - case GGML_OP_COUNT: - { - GGML_ASSERT(false); - } break; - } - - for (int i = 0; i < GGML_MAX_SRC; ++i) { - if (tensor->src[i] && tensor->src[i]->grad) { - GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad)); - } - } -} - -static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { - if (node->grad == NULL) { - // this usually happens when we generate intermediate nodes from constants in the backward pass - // it can also happen during forward pass, if the user performs computations with constants - if (node->op != GGML_OP_NONE) { - //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); - } - } - - // check if already visited - if (hash_insert(cgraph->visited_hash_table, node)) { - return; - } - - for (int i = 0; i < GGML_MAX_SRC; ++i) { - const int k = - (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : - (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) : - /* unknown order, just fall back to using i*/ i; - if (node->src[k]) { - ggml_visit_parents(cgraph, node->src[k]); - } - } - - if (node->op == GGML_OP_NONE && node->grad == NULL) { - // reached a leaf node, not part of the gradient graph (e.g. a constant) - GGML_ASSERT(cgraph->n_leafs < GGML_MAX_NODES); - - if (strlen(node->name) == 0) { - ggml_format_name(node, "leaf_%d", cgraph->n_leafs); - } - - cgraph->leafs[cgraph->n_leafs] = node; - cgraph->n_leafs++; - } else { - GGML_ASSERT(cgraph->n_nodes < GGML_MAX_NODES); - - if (strlen(node->name) == 0) { - ggml_format_name(node, "node_%d", cgraph->n_nodes); - } - - cgraph->nodes[cgraph->n_nodes] = node; - cgraph->grads[cgraph->n_nodes] = node->grad; - cgraph->n_nodes++; - } -} - -static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { - if (!expand) { - cgraph->n_nodes = 0; - cgraph->n_leafs = 0; - } - - const int n0 = cgraph->n_nodes; - UNUSED(n0); - - ggml_visit_parents(cgraph, tensor); - - const int n_new = cgraph->n_nodes - n0; - GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); - - if (n_new > 0) { - // the last added node should always be starting point - GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor); - } -} - -void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { - ggml_build_forward_impl(cgraph, tensor, true); -} - -struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { - struct ggml_cgraph result = { - /*.n_nodes =*/ 0, - /*.n_leafs =*/ 0, - /*.nodes =*/ { NULL }, - /*.grads =*/ { NULL }, - /*.leafs =*/ { NULL }, - /*.hash_table =*/ { NULL }, - /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, - }; - - ggml_build_forward_impl(&result, tensor, false); - - return result; -} - -void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) { - GGML_ASSERT(gf->n_nodes > 0); - - // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph - if (keep) { - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; - - if (node->grad) { - node->grad = ggml_dup_tensor(ctx, node); - gf->grads[i] = node->grad; - } - } - } - - // remember original gradients which start with zero values - void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE); - memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE); - for (int i = 0; i < gf->n_nodes; i++) { - if (gf->grads[i]) { - hash_insert(zero_table, gf->grads[i]); - } - } - - for (int i = gf->n_nodes - 1; i >= 0; i--) { - struct ggml_tensor * node = gf->nodes[i]; - - // inplace operations to add gradients are not created by ggml_compute_backward - // use allocator to automatically make inplace operations - if (node->grad) { - ggml_compute_backward(ctx, node, zero_table); - } - } - - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; - - if (node->is_param) { - GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - ggml_build_forward_expand(gb, node->grad); - } - } - - free(zero_table); -} - -struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) { - struct ggml_cgraph result = *gf; - ggml_build_backward_expand(ctx, gf, &result, keep); - return result; -} - -struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) { - struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_GRAPH, GGML_GRAPH_SIZE); - struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); - - *cgraph = (struct ggml_cgraph) { - /*.n_nodes =*/ 0, - /*.n_leafs =*/ 0, - /*.nodes =*/ { NULL }, - /*.grads =*/ { NULL }, - /*.leafs =*/ { NULL }, - /*.hash_table =*/ { NULL }, - /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, - }; - - return cgraph; -} - -struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor) { - struct ggml_cgraph * cgraph = ggml_new_graph(ctx); - ggml_build_forward_impl(cgraph, tensor, false); - return cgraph; -} - -size_t ggml_graph_overhead(void) { - return GGML_OBJECT_SIZE + GGML_PAD(GGML_GRAPH_SIZE, GGML_MEM_ALIGN); -} - -// -// thread data -// -// synchronization is done via busy loops -// I tried using spin locks, but not sure how to use them correctly - the things I tried were slower than busy loops -// - -#ifdef __APPLE__ - -//#include -// -//typedef os_unfair_lock ggml_lock_t; -// -//#define ggml_lock_init(x) UNUSED(x) -//#define ggml_lock_destroy(x) UNUSED(x) -//#define ggml_lock_lock os_unfair_lock_lock -//#define ggml_lock_unlock os_unfair_lock_unlock -// -//#define GGML_LOCK_INITIALIZER OS_UNFAIR_LOCK_INIT - -typedef int ggml_lock_t; - -#define ggml_lock_init(x) UNUSED(x) -#define ggml_lock_destroy(x) UNUSED(x) -#define ggml_lock_lock(x) UNUSED(x) -#define ggml_lock_unlock(x) UNUSED(x) - -#define GGML_LOCK_INITIALIZER 0 - -typedef pthread_t ggml_thread_t; - -#define ggml_thread_create pthread_create -#define ggml_thread_join pthread_join - -#else - -//typedef pthread_spinlock_t ggml_lock_t; - -//#define ggml_lock_init(x) pthread_spin_init(x, PTHREAD_PROCESS_PRIVATE) -//#define ggml_lock_destroy pthread_spin_destroy -//#define ggml_lock_lock pthread_spin_lock -//#define ggml_lock_unlock pthread_spin_unlock - -typedef int ggml_lock_t; - -#define ggml_lock_init(x) UNUSED(x) -#define ggml_lock_destroy(x) UNUSED(x) -#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) -#define ggml_lock_lock(x) _mm_pause() -#else -#define ggml_lock_lock(x) UNUSED(x) -#endif -#define ggml_lock_unlock(x) UNUSED(x) - -#define GGML_LOCK_INITIALIZER 0 - -typedef pthread_t ggml_thread_t; - -#define ggml_thread_create pthread_create -#define ggml_thread_join pthread_join - -#endif - -// Android's libc implementation "bionic" does not support setting affinity -#if defined(__linux__) && !defined(__BIONIC__) -static void set_numa_thread_affinity(int thread_n, int n_threads) { - if (!ggml_is_numa()) { - return; - } - - // run thread on node_num thread_n / (threads per node) - const int node_num = thread_n / ((n_threads + g_state.numa.n_nodes - 1) / g_state.numa.n_nodes); - struct ggml_numa_node * node = &g_state.numa.nodes[node_num]; - size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); - - cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); - CPU_ZERO_S(setsize, cpus); - for (size_t i = 0; i < node->n_cpus; ++i) { - CPU_SET_S(node->cpus[i], setsize, cpus); - } - - int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); - if (rv) { - fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", - strerror(rv)); - } - - CPU_FREE(cpus); -} - -static void clear_numa_thread_affinity(void) { - if (!ggml_is_numa()) { - return; - } - - size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); - - cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); - CPU_ZERO_S(setsize, cpus); - for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) { - CPU_SET_S(i, setsize, cpus); - } - - int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); - if (rv) { - fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", - strerror(rv)); - } - - CPU_FREE(cpus); -} -#else -// TODO: Windows etc. -// (the linux implementation may also work on BSD, someone should test) -static void set_numa_thread_affinity(int thread_n, int n_threads) { UNUSED(thread_n); UNUSED(n_threads); } -static void clear_numa_thread_affinity(void) {} -#endif - -struct ggml_compute_state_shared { - const struct ggml_cgraph * cgraph; - const struct ggml_cplan * cplan; - - int64_t perf_node_start_cycles; - int64_t perf_node_start_time_us; - - const int n_threads; - - // synchronization primitives - atomic_int n_active; // num active threads - atomic_int node_n; // active graph node - - bool (*abort_callback)(void * data); // abort ggml_graph_compute when true - void * abort_callback_data; -}; - -struct ggml_compute_state { - ggml_thread_t thrd; - int ith; - struct ggml_compute_state_shared * shared; -}; - -static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) { - int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles; - int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us; - - node->perf_runs++; - node->perf_cycles += cycles_cur; - node->perf_time_us += time_us_cur; -} - -static thread_ret_t ggml_graph_compute_thread(void * data) { - struct ggml_compute_state * state = (struct ggml_compute_state *) data; - - const struct ggml_cgraph * cgraph = state->shared->cgraph; - const struct ggml_cplan * cplan = state->shared->cplan; - - const int * n_tasks_arr = cplan->n_tasks; - const int n_threads = state->shared->n_threads; - - set_numa_thread_affinity(state->ith, n_threads); - - int node_n = -1; - - while (true) { - if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { - state->shared->node_n += 1; - return (thread_ret_t) GGML_EXIT_ABORTED; - } - if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { - // all other threads are finished and spinning - // do finalize and init here so we don't have synchronize again - struct ggml_compute_params params = { - /*.type =*/ GGML_TASK_FINALIZE, - /*.ith =*/ 0, - /*.nth =*/ 0, - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - }; - - if (node_n != -1) { - /* FINALIZE */ - struct ggml_tensor * node = state->shared->cgraph->nodes[node_n]; - if (GGML_OP_HAS_FINALIZE[node->op]) { - params.nth = n_tasks_arr[node_n]; - ggml_compute_forward(¶ms, node); - } - ggml_graph_compute_perf_stats_node(node, state->shared); - } - - // distribute new work or execute it direct if 1T - while (++node_n < cgraph->n_nodes) { - GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes); - - struct ggml_tensor * node = cgraph->nodes[node_n]; - const int n_tasks = n_tasks_arr[node_n]; - - state->shared->perf_node_start_cycles = ggml_perf_cycles(); - state->shared->perf_node_start_time_us = ggml_perf_time_us(); - - params.nth = n_tasks; - - /* INIT */ - if (GGML_OP_HAS_INIT[node->op]) { - params.type = GGML_TASK_INIT; - ggml_compute_forward(¶ms, node); - } - - if (n_tasks == 1) { - // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1, - // they do something more efficient than spinning (?) - params.type = GGML_TASK_COMPUTE; - ggml_compute_forward(¶ms, node); - - if (GGML_OP_HAS_FINALIZE[node->op]) { - params.type = GGML_TASK_FINALIZE; - ggml_compute_forward(¶ms, node); - } - - ggml_graph_compute_perf_stats_node(node, state->shared); - } else { - break; - } - - if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { - break; - } - } - - atomic_store(&state->shared->n_active, n_threads); - atomic_store(&state->shared->node_n, node_n); - } else { - // wait for other threads to finish - const int last = node_n; - while (true) { - // TODO: this sched_yield can have significant impact on the performance - either positive or negative - // depending on the workload and the operating system. - // since it is not clear what is the best approach, it should potentially become user-configurable - // ref: https://github.com/ggerganov/ggml/issues/291 -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - sched_yield(); -#endif - - node_n = atomic_load(&state->shared->node_n); - if (node_n != last) break; - }; - } - - // check if we should stop - if (node_n >= cgraph->n_nodes) break; - - /* COMPUTE */ - struct ggml_tensor * node = cgraph->nodes[node_n]; - const int n_tasks = n_tasks_arr[node_n]; - - struct ggml_compute_params params = { - /*.type =*/ GGML_TASK_COMPUTE, - /*.ith =*/ state->ith, - /*.nth =*/ n_tasks, - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - }; - - if (state->ith < n_tasks) { - ggml_compute_forward(¶ms, node); - } - } - - return GGML_EXIT_SUCCESS; -} - -struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { - if (n_threads <= 0) { - n_threads = GGML_DEFAULT_N_THREADS; - } - - size_t work_size = 0; - - struct ggml_cplan cplan; - memset(&cplan, 0, sizeof(struct ggml_cplan)); - - // thread scheduling for the different operations + work buffer size estimation - for (int i = 0; i < cgraph->n_nodes; i++) { - int n_tasks = 1; - - struct ggml_tensor * node = cgraph->nodes[i]; - - switch (node->op) { - case GGML_OP_CPY: - case GGML_OP_DUP: - { - n_tasks = n_threads; - - size_t cur = 0; - if (ggml_is_quantized(node->type)) { - cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_ADD: - case GGML_OP_ADD1: - { - n_tasks = n_threads; - - size_t cur = 0; - - if (ggml_is_quantized(node->src[0]->type)) { - cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_ACC: - { - n_tasks = n_threads; - - size_t cur = 0; - - if (ggml_is_quantized(node->src[0]->type)) { - cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_SUB: - case GGML_OP_DIV: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_LOG: - case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - case GGML_OP_ARGMAX: - case GGML_OP_REPEAT: - case GGML_OP_REPEAT_BACK: - { - n_tasks = 1; - } break; - - case GGML_OP_UNARY: - { - switch (ggml_get_unary_op(node)) { - case GGML_UNARY_OP_ABS: - case GGML_UNARY_OP_SGN: - case GGML_UNARY_OP_NEG: - case GGML_UNARY_OP_STEP: - case GGML_UNARY_OP_TANH: - case GGML_UNARY_OP_ELU: - case GGML_UNARY_OP_RELU: - { - n_tasks = 1; - } break; - - case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_GELU_QUICK: - case GGML_UNARY_OP_SILU: - { - n_tasks = n_threads; - } break; - } - } break; - case GGML_OP_SILU_BACK: - case GGML_OP_MUL: - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - case GGML_OP_RMS_NORM_BACK: - case GGML_OP_GROUP_NORM: - { - n_tasks = n_threads; - } break; - case GGML_OP_CONCAT: - case GGML_OP_MUL_MAT: - { - n_tasks = n_threads; - - // TODO: use different scheduling for different matrix sizes - //const int nr0 = ggml_nrows(node->src[0]); - //const int nr1 = ggml_nrows(node->src[1]); - - //n_tasks = MIN(n_threads, MAX(1, nr0/128)); - //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks); - - size_t cur = 0; - const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; - -#if defined(GGML_USE_CUBLAS) - if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) { - n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning - } else -#elif defined(GGML_USE_CLBLAST) - if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { - n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning - cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node); - } else -#endif -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) { - n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning - if (node->src[0]->type != GGML_TYPE_F32) { - // here we need memory just for single 2D matrix from src0 - cur = ggml_type_size(GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]); - } - } else -#endif - if (node->src[1]->type != vec_dot_type) { - cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type); - } else { - cur = 0; - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_OUT_PROD: - { - n_tasks = n_threads; - - size_t cur = 0; - - if (ggml_is_quantized(node->src[0]->type)) { - cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_SCALE: - { - n_tasks = 1; - } break; - case GGML_OP_SET: - case GGML_OP_CONT: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_GET_ROWS: - case GGML_OP_GET_ROWS_BACK: - case GGML_OP_DIAG: - { - n_tasks = 1; - } break; - case GGML_OP_DIAG_MASK_ZERO: - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX: - case GGML_OP_SOFT_MAX_BACK: - case GGML_OP_ROPE: - case GGML_OP_ROPE_BACK: - case GGML_OP_ADD_REL_POS: - { - n_tasks = n_threads; - } break; - case GGML_OP_ALIBI: - { - n_tasks = 1; //TODO - } break; - case GGML_OP_CLAMP: - { - n_tasks = 1; //TODO - } break; - case GGML_OP_CONV_1D: - { - n_tasks = n_threads; - - GGML_ASSERT(node->src[0]->ne[3] == 1); - GGML_ASSERT(node->src[1]->ne[2] == 1); - GGML_ASSERT(node->src[1]->ne[3] == 1); - - const int64_t ne00 = node->src[0]->ne[0]; - const int64_t ne01 = node->src[0]->ne[1]; - const int64_t ne02 = node->src[0]->ne[2]; - - const int64_t ne10 = node->src[1]->ne[0]; - const int64_t ne11 = node->src[1]->ne[1]; - - const int64_t ne0 = node->ne[0]; - const int64_t ne1 = node->ne[1]; - const int64_t nk = ne00; - const int64_t ew0 = nk * ne01; - - UNUSED(ne02); - UNUSED(ne10); - UNUSED(ne11); - - size_t cur = 0; - - if (node->src[0]->type == GGML_TYPE_F16 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0); - } else if (node->src[0]->type == GGML_TYPE_F32 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*(ne0*ne1*ew0); - } else { - GGML_ASSERT(false); - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_CONV_1D_STAGE_0: - { - n_tasks = n_threads; - } break; - case GGML_OP_CONV_1D_STAGE_1: - { - n_tasks = n_threads; - } break; - case GGML_OP_CONV_TRANSPOSE_1D: - { - n_tasks = n_threads; - - GGML_ASSERT(node->src[0]->ne[3] == 1); - GGML_ASSERT(node->src[1]->ne[2] == 1); - GGML_ASSERT(node->src[1]->ne[3] == 1); - - const int64_t ne00 = node->src[0]->ne[0]; // K - const int64_t ne01 = node->src[0]->ne[1]; // Cout - const int64_t ne02 = node->src[0]->ne[2]; // Cin - - const int64_t ne10 = node->src[1]->ne[0]; // L - const int64_t ne11 = node->src[1]->ne[1]; // Cin - - size_t cur = 0; - if (node->src[0]->type == GGML_TYPE_F16 && - node->src[1]->type == GGML_TYPE_F32) { - cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02; - cur += sizeof(ggml_fp16_t)*ne10*ne11; - } else if (node->src[0]->type == GGML_TYPE_F32 && - node->src[1]->type == GGML_TYPE_F32) { - cur += sizeof(float)*ne00*ne01*ne02; - cur += sizeof(float)*ne10*ne11; - } else { - GGML_ASSERT(false); - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_CONV_2D: - { - n_tasks = n_threads; - - const int64_t ne00 = node->src[0]->ne[0]; // W - const int64_t ne01 = node->src[0]->ne[1]; // H - const int64_t ne02 = node->src[0]->ne[2]; // C - const int64_t ne03 = node->src[0]->ne[3]; // N - - const int64_t ne10 = node->src[1]->ne[0]; // W - const int64_t ne11 = node->src[1]->ne[1]; // H - const int64_t ne12 = node->src[1]->ne[2]; // C - - const int64_t ne0 = node->ne[0]; - const int64_t ne1 = node->ne[1]; - const int64_t ne2 = node->ne[2]; - const int64_t nk = ne00*ne01; - const int64_t ew0 = nk * ne02; - - UNUSED(ne03); - UNUSED(ne2); - - size_t cur = 0; - - if (node->src[0]->type == GGML_TYPE_F16 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0); - } else if (node->src[0]->type == GGML_TYPE_F32 && - node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)* (ne10*ne11*ne12); - } else { - GGML_ASSERT(false); - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_CONV_TRANSPOSE_2D: - { - n_tasks = n_threads; - - const int64_t ne00 = node->src[0]->ne[0]; // W - const int64_t ne01 = node->src[0]->ne[1]; // H - const int64_t ne02 = node->src[0]->ne[2]; // Channels Out - const int64_t ne03 = node->src[0]->ne[3]; // Channels In - - const int64_t ne10 = node->src[1]->ne[0]; // W - const int64_t ne11 = node->src[1]->ne[1]; // H - const int64_t ne12 = node->src[1]->ne[2]; // Channels In - - size_t cur = 0; - cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; - cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_POOL_1D: - case GGML_OP_POOL_2D: - { - n_tasks = 1; - } break; - case GGML_OP_UPSCALE: - { - n_tasks = n_threads; - } break; - case GGML_OP_FLASH_ATTN: - { - n_tasks = n_threads; - - size_t cur = 0; - - const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); - - if (node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 - } - - if (node->src[1]->type == GGML_TYPE_F16) { - cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_FLASH_FF: - { - n_tasks = n_threads; - - size_t cur = 0; - - if (node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 - } - - if (node->src[1]->type == GGML_TYPE_F16) { - cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_FLASH_ATTN_BACK: - { - n_tasks = n_threads; - - size_t cur = 0; - - const int64_t D = node->src[0]->ne[0]; - const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); - const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back - if (node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } - - if (node->src[1]->type == GGML_TYPE_F16) { - cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_WIN_PART: - case GGML_OP_WIN_UNPART: - case GGML_OP_GET_REL_POS: - case GGML_OP_MAP_UNARY: - case GGML_OP_MAP_BINARY: - case GGML_OP_MAP_CUSTOM1_F32: - case GGML_OP_MAP_CUSTOM2_F32: - case GGML_OP_MAP_CUSTOM3_F32: - { - n_tasks = 1; - } break; - case GGML_OP_MAP_CUSTOM1: - { - struct ggml_map_custom1_op_params * p = (struct ggml_map_custom1_op_params *) node->op_params; - if (p->n_tasks == GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p->n_tasks, n_threads); - } - } break; - case GGML_OP_MAP_CUSTOM2: - { - struct ggml_map_custom2_op_params * p = (struct ggml_map_custom2_op_params *) node->op_params; - if (p->n_tasks == GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p->n_tasks, n_threads); - } - } break; - case GGML_OP_MAP_CUSTOM3: - { - struct ggml_map_custom3_op_params * p = (struct ggml_map_custom3_op_params *) node->op_params; - if (p->n_tasks == GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p->n_tasks, n_threads); - } - } break; - case GGML_OP_CROSS_ENTROPY_LOSS: - { - n_tasks = n_threads; - - size_t cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - n_tasks = n_threads; - } break; - case GGML_OP_NONE: - { - n_tasks = 1; - } break; - case GGML_OP_COUNT: - { - GGML_ASSERT(false); - } break; - } - - cplan.n_tasks[i] = n_tasks; - } - - if (work_size > 0) { - work_size += CACHE_LINE_SIZE*(n_threads - 1); - } - - cplan.n_threads = n_threads; - cplan.work_size = work_size; - cplan.work_data = NULL; - - return cplan; -} - -int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { - { - GGML_ASSERT(cplan); - GGML_ASSERT(cplan->n_threads > 0); - - if (cplan->work_size > 0) { - GGML_ASSERT(cplan->work_data); - } - - for (int i = 0; i < cgraph->n_nodes; ++i) { - if (cgraph->nodes[i]->op != GGML_OP_NONE) { - GGML_ASSERT(cplan->n_tasks[i] > 0); - } - } - } - - const int n_threads = cplan->n_threads; - - struct ggml_compute_state_shared state_shared = { - /*.cgraph =*/ cgraph, - /*.cgraph_plan =*/ cplan, - /*.perf_node_start_cycles =*/ 0, - /*.perf_node_start_time_us =*/ 0, - /*.n_threads =*/ n_threads, - /*.n_active =*/ n_threads, - /*.node_n =*/ -1, - /*.abort_callback =*/ NULL, - /*.abort_callback_data =*/ NULL, - }; - struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads); - - // create thread pool - if (n_threads > 1) { - for (int j = 1; j < n_threads; ++j) { - workers[j] = (struct ggml_compute_state) { - .thrd = 0, - .ith = j, - .shared = &state_shared, - }; - - const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); - GGML_ASSERT(rc == 0); - UNUSED(rc); - } - } - - workers[0].ith = 0; - workers[0].shared = &state_shared; - - const int64_t perf_start_cycles = ggml_perf_cycles(); - const int64_t perf_start_time_us = ggml_perf_time_us(); - - // this is a work thread too - int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]); - - // don't leave affinity set on the main thread - clear_numa_thread_affinity(); - - // join or kill thread pool - if (n_threads > 1) { - for (int j = 1; j < n_threads; j++) { - const int rc = ggml_thread_join(workers[j].thrd, NULL); - GGML_ASSERT(rc == 0); - } - } - - // performance stats (graph) - { - int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles; - int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us; - - cgraph->perf_runs++; - cgraph->perf_cycles += perf_cycles_cur; - cgraph->perf_time_us += perf_time_us_cur; - - GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n", - __func__, cgraph->perf_runs, - (double) perf_cycles_cur / (double) ggml_cycles_per_ms(), - (double) cgraph->perf_cycles / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs, - (double) perf_time_us_cur / 1000.0, - (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); - } - - return compute_status; -} - -void ggml_graph_reset(struct ggml_cgraph * cgraph) { - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * grad = cgraph->grads[i]; - - if (grad) { - ggml_set_zero(grad); - } - } -} - -void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) { - struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads); - - struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); - - cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; - - ggml_graph_compute(cgraph, &cplan); -} - -struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) { - for (int i = 0; i < cgraph->n_leafs; i++) { - struct ggml_tensor * leaf = cgraph->leafs[i]; - - if (strcmp(leaf->name, name) == 0) { - return leaf; - } - } - - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * node = cgraph->nodes[i]; - - if (strcmp(node->name, name) == 0) { - return node; - } - } - - return NULL; -} - -static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fout) { - const int64_t * ne = tensor->ne; - const size_t * nb = tensor->nb; - - fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", - ggml_type_name(tensor->type), - ggml_op_name (tensor->op), - tensor->n_dims, - ne[0], ne[1], ne[2], ne[3], - nb[0], nb[1], nb[2], nb[3], - tensor->data, - tensor->name); -} - -static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char * arg, FILE * fout) { - const int64_t * ne = tensor->ne; - const size_t * nb = tensor->nb; - - fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", - arg, - ggml_type_name(tensor->type), - ggml_op_name (tensor->op), - tensor->n_dims, - ne[0], ne[1], ne[2], ne[3], - nb[0], nb[1], nb[2], nb[3], - tensor->data, - tensor->name); -} - -void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { - uint64_t size_eval = 0; - - // compute size of intermediate results - // TODO: does not take into account scratch buffers !!!! - for (int i = 0; i < cgraph->n_nodes; ++i) { - size_eval += ggml_nbytes_pad(cgraph->nodes[i]); - } - - // print - { - FILE * fout = stdout; - - fprintf(fout, "\n"); - fprintf(fout, "%-16s %8x\n", "magic", GGML_FILE_MAGIC); - fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION); - fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs); - fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes); - fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval); - - // header - fprintf(fout, "\n"); - fprintf(fout, "%-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %16s %16s\n", - "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "DATA", "NAME"); - - for (int i = 0; i < cgraph->n_leafs; ++i) { - ggml_graph_export_leaf(cgraph->leafs[i], fout); - - GGML_ASSERT(cgraph->leafs[i]->op == GGML_OP_NONE); - GGML_ASSERT(cgraph->leafs[i]->src[0] == NULL); - GGML_ASSERT(cgraph->leafs[i]->src[1] == NULL); - } - - // header - fprintf(fout, "\n"); - fprintf(fout, "%-6s %-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %8s %16s %16s\n", - "ARG", "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "NTASKS", "DATA", "NAME"); - - for (int i = 0; i < cgraph->n_nodes; ++i) { - ggml_graph_export_node(cgraph->nodes[i], "DST", fout); - - for (int j = 0; j < GGML_MAX_SRC; ++j) { - if (cgraph->nodes[i]->src[j]) { - ggml_graph_export_node(cgraph->nodes[i]->src[j], "SRC", fout); - } - } - - fprintf(fout, "\n"); - } - - fprintf(fout, "\n"); - } - - // write binary data - { - FILE * fout = fopen(fname, "wb"); - - if (!fout) { - fprintf(stderr, "%s: failed to open %s\n", __func__, fname); - return; - } - - // header - { - const uint32_t magic = GGML_FILE_MAGIC; - const uint32_t version = GGML_FILE_VERSION; - const uint32_t n_leafs = cgraph->n_leafs; - const uint32_t nodes = cgraph->n_nodes; - - fwrite(&magic, sizeof(uint32_t), 1, fout); - fwrite(&version, sizeof(uint32_t), 1, fout); - fwrite(&n_leafs, sizeof(uint32_t), 1, fout); - fwrite(&nodes, sizeof(uint32_t), 1, fout); - fwrite(&size_eval, sizeof(uint64_t), 1, fout); - } - - // leafs - { - for (int i = 0; i < cgraph->n_leafs; ++i) { - const struct ggml_tensor * tensor = cgraph->leafs[i]; - - const uint32_t type = tensor->type; - const uint32_t op = tensor->op; - const uint32_t n_dims = tensor->n_dims; - - fwrite(&type, sizeof(uint32_t), 1, fout); - fwrite(&op, sizeof(uint32_t), 1, fout); - fwrite(&n_dims, sizeof(uint32_t), 1, fout); - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - const uint64_t ne = tensor->ne[j]; - const uint64_t nb = tensor->nb[j]; - - fwrite(&ne, sizeof(uint64_t), 1, fout); - fwrite(&nb, sizeof(uint64_t), 1, fout); - } - - fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); - fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout); - - // dump the data - // TODO: pad this to 32 byte boundary - { - const size_t size = ggml_nbytes(tensor); - - fwrite(tensor->data, sizeof(char), size, fout); - } - } - } - - // nodes - { - for (int i = 0; i < cgraph->n_nodes; ++i) { - const struct ggml_tensor * tensor = cgraph->nodes[i]; - - const uint32_t type = tensor->type; - const uint32_t op = tensor->op; - const uint32_t n_dims = tensor->n_dims; - - fwrite(&type, sizeof(uint32_t), 1, fout); - fwrite(&op, sizeof(uint32_t), 1, fout); - fwrite(&n_dims, sizeof(uint32_t), 1, fout); - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - const uint64_t ne = tensor->ne[j]; - const uint64_t nb = tensor->nb[j]; - - fwrite(&ne, sizeof(uint64_t), 1, fout); - fwrite(&nb, sizeof(uint64_t), 1, fout); - } - - fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); - fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout); - - // output the op arguments - { - struct ggml_tensor * args[GGML_MAX_SRC] = { NULL }; - - for (int j = 0; j < GGML_MAX_SRC; ++j) { - args[j] = tensor->src[j]; - } - - for (int j = 0; j < GGML_MAX_SRC; ++j) { - if (args[j]) { - int32_t idx = -1; - - // check if leaf - { - for (int k = 0; k < cgraph->n_leafs; ++k) { - if (args[j] == cgraph->leafs[k]) { - idx = k; - break; - } - } - } - - // check if node - if (idx == -1) { - for (int k = 0; k < cgraph->n_nodes; ++k) { - if (args[j] == cgraph->nodes[k]) { - idx = GGML_MAX_NODES + k; - break; - } - } - } - - if (idx == -1) { - fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i); - return; - } - - fwrite(&idx, sizeof(int32_t), 1, fout); - } else { - const int32_t nul = -1; - - fwrite(&nul, sizeof(int32_t), 1, fout); - } - } - } - } - } - - fclose(fout); - } -} - -struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) { - assert(*ctx_data == NULL); - assert(*ctx_eval == NULL); - - struct ggml_cgraph result = { 0 }; - - struct ggml_tensor * data = NULL; - - // read file into data - { - FILE * fin = fopen(fname, "rb"); - if (!fin) { - fprintf(stderr, "%s: failed to open %s\n", __func__, fname); - return result; - } - - size_t fsize = 0; - - fseek(fin, 0, SEEK_END); - fsize = ftell(fin); - fseek(fin, 0, SEEK_SET); - - // create the data context - { - const size_t overhead = 1*ggml_tensor_overhead(); - - struct ggml_init_params params = { - .mem_size = fsize + overhead, - .mem_buffer = NULL, - .no_alloc = false, - }; - - *ctx_data = ggml_init(params); - - if (!*ctx_data) { - fprintf(stderr, "%s: failed to create ggml context\n", __func__); - fclose(fin); - return result; - } - } - - data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize); - - { - const size_t ret = fread(data->data, sizeof(char), fsize, fin); - if (ret != fsize) { - fprintf(stderr, "%s: failed to read %s\n", __func__, fname); - fclose(fin); - return result; - } - } - - fclose(fin); - } - - // populate result - { - char * ptr = (char *) data->data; - - const uint32_t magic = *(const uint32_t *) ptr; ptr += sizeof(magic); - - if (magic != GGML_FILE_MAGIC) { - fprintf(stderr, "%s: invalid magic number, got %08x\n", __func__, magic); - return result; - } - - const uint32_t version = *(const uint32_t *) ptr; ptr += sizeof(version); - - if (version != GGML_FILE_VERSION) { - fprintf(stderr, "%s: invalid version number\n", __func__); - return result; - } - - const uint32_t n_leafs = *(const uint32_t *) ptr; ptr += sizeof(n_leafs); - const uint32_t n_nodes = *(const uint32_t *) ptr; ptr += sizeof(n_nodes); - const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval); - - result.n_leafs = n_leafs; - result.n_nodes = n_nodes; - - // create the data context - { - const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead(); - - struct ggml_init_params params = { - .mem_size = size_eval + overhead, - .mem_buffer = NULL, - .no_alloc = true, - }; - - *ctx_eval = ggml_init(params); - - if (!*ctx_eval) { - fprintf(stderr, "%s: failed to create ggml context\n", __func__); - return result; - } - } - - // leafs - { - uint32_t type; - uint32_t op; - uint32_t n_dims; - - for (uint32_t i = 0; i < n_leafs; ++i) { - type = *(const uint32_t *) ptr; ptr += sizeof(type); - op = *(const uint32_t *) ptr; ptr += sizeof(op); - n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims); - - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - uint64_t ne_cur; - uint64_t nb_cur; - - ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur); - nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur); - - ne[j] = ne_cur; - nb[j] = nb_cur; - } - - struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne); - - tensor->op = (enum ggml_op) op; - - memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME; - memcpy(tensor->op_params, ptr, GGML_MAX_OP_PARAMS); ptr += GGML_MAX_OP_PARAMS; - - tensor->data = (void *) ptr; - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - tensor->nb[j] = nb[j]; - } - - result.leafs[i] = tensor; - - ptr += ggml_nbytes(tensor); - - fprintf(stderr, "%s: loaded leaf %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, ggml_nbytes(tensor)); - } - } - - ggml_set_no_alloc(*ctx_eval, false); - - // nodes - { - uint32_t type; - uint32_t op; - uint32_t n_dims; - - for (uint32_t i = 0; i < n_nodes; ++i) { - type = *(const uint32_t *) ptr; ptr += sizeof(type); - op = *(const uint32_t *) ptr; ptr += sizeof(op); - n_dims = *(const uint32_t *) ptr; ptr += sizeof(n_dims); - - enum ggml_op eop = (enum ggml_op) op; - - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - uint64_t ne_cur; - uint64_t nb_cur; - - ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur); - nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur); - - ne[j] = ne_cur; - nb[j] = nb_cur; - } - - const char * ptr_name = ptr; ptr += GGML_MAX_NAME; - const char * ptr_op_params = ptr; ptr += GGML_MAX_OP_PARAMS; - - const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t); - - struct ggml_tensor * args[GGML_MAX_SRC] = { NULL }; - - // parse args - for (int j = 0; j < GGML_MAX_SRC; ++j) { - const int32_t arg_idx = ptr_arg_idx[j]; - - if (arg_idx == -1) { - continue; - } - - if (arg_idx < GGML_MAX_NODES) { - args[j] = result.leafs[arg_idx]; - } else { - args[j] = result.nodes[arg_idx - GGML_MAX_NODES]; - } - } - - // create the tensor - // "view" operations are handled differently - // TODO: handle inplace ops - currently a copy is always made - - struct ggml_tensor * tensor = NULL; - - switch (eop) { - // TODO: implement other view ops - case GGML_OP_RESHAPE: - { - tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]); - } break; - case GGML_OP_VIEW: - { - tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); - - size_t offs; - memcpy(&offs, ptr_op_params, sizeof(offs)); - - tensor->data = ((char *) tensor->data) + offs; - } break; - case GGML_OP_TRANSPOSE: - { - tensor = ggml_transpose(*ctx_eval, args[0]); - } break; - case GGML_OP_PERMUTE: - { - tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); - } break; - default: - { - tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, n_dims, ne); - - tensor->op = eop; - } break; - } - - memcpy(tensor->name, ptr_name, GGML_MAX_NAME); - memcpy(tensor->op_params, ptr_op_params, GGML_MAX_OP_PARAMS); - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - tensor->nb[j] = nb[j]; - } - - for (int j = 0; j < GGML_MAX_SRC; ++j) { - tensor->src[j] = args[j]; - } - - result.nodes[i] = tensor; - - fprintf(stderr, "%s: loaded node %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, ggml_nbytes(tensor)); - } - } - } - - return result; -} - -void ggml_graph_print(const struct ggml_cgraph * cgraph) { - int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0}; - - GGML_PRINT("=== GRAPH ===\n"); - - GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes); - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * node = cgraph->nodes[i]; - - perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us); - - GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", - i, - node->ne[0], node->ne[1], node->ne[2], - ggml_op_name(node->op), node->is_param ? "x" : node->grad ? "g" : " ", node->perf_runs, - (double) node->perf_cycles / (double) ggml_cycles_per_ms(), - (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs, - (double) node->perf_time_us / 1000.0, - (double) node->perf_time_us / 1000.0 / node->perf_runs); - } - - GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs); - for (int i = 0; i < cgraph->n_leafs; i++) { - struct ggml_tensor * node = cgraph->leafs[i]; - - GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n", - i, - node->ne[0], node->ne[1], - ggml_op_name(node->op), - ggml_get_name(node)); - } - - for (int i = 0; i < GGML_OP_COUNT; i++) { - if (perf_total_per_op_us[i] == 0) { - continue; - } - - GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", ggml_op_name(i), (double) perf_total_per_op_us[i] / 1000.0); - } - - GGML_PRINT("========================================\n"); -} - -// check if node is part of the graph -static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { - if (cgraph == NULL) { - return true; - } - - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i] == node) { - return true; - } - } - - return false; -} - -static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * parent = cgraph->nodes[i]; - - if (parent->grad == node) { - return parent; - } - } - - return NULL; -} - -static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) { - struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node); - struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent); - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n", - gparent0 ? (void *) gparent0 : (void *) parent, - gparent0 ? "g" : "x", - gparent ? (void *) gparent : (void *) node, - gparent ? "g" : "x", - gparent ? "empty" : "vee", - gparent ? "dashed" : "solid", - label); -} - -static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) { - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n", - (void *) parent, "x", - (void *) node, "x", - label); -} - -void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { - char color[16]; - - FILE * fp = fopen(filename, "w"); - GGML_ASSERT(fp); - - fprintf(fp, "digraph G {\n"); - fprintf(fp, " newrank = true;\n"); - fprintf(fp, " rankdir = LR;\n"); - - for (int i = 0; i < gb->n_nodes; i++) { - struct ggml_tensor * node = gb->nodes[i]; - - if (ggml_graph_get_parent(gb, node) != NULL) { - continue; - } - - if (node->is_param) { - snprintf(color, sizeof(color), "yellow"); - } else if (node->grad) { - if (ggml_graph_find(gf, node)) { - snprintf(color, sizeof(color), "green"); - } else { - snprintf(color, sizeof(color), "lightblue"); - } - } else { - snprintf(color, sizeof(color), "white"); - } - - fprintf(fp, " \"%p\" [ " - "style = filled; fillcolor = %s; shape = record; " - "label=\"", - (void *) node, color); - - if (strlen(node->name) > 0) { - fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type)); - } else { - fprintf(fp, "(%s)|", ggml_type_name(node->type)); - } - - if (node->n_dims == 2) { - fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op)); - } else { - fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op)); - } - - if (node->grad) { - fprintf(fp, " | %s\"; ]\n", ggml_op_symbol(node->grad->op)); - } else { - fprintf(fp, "\"; ]\n"); - } - } - - for (int i = 0; i < gb->n_leafs; i++) { - struct ggml_tensor * node = gb->leafs[i]; - - snprintf(color, sizeof(color), "pink"); - - fprintf(fp, " \"%p\" [ " - "style = filled; fillcolor = %s; shape = record; " - "label=\"", - (void *) node, color); - - if (strlen(node->name) > 0) { - fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type)); - } else { - fprintf(fp, "(%s)|", ggml_type_name(node->type)); - } - - fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); - if (ggml_nelements(node) < 5) { - fprintf(fp, " | ("); - for (int j = 0; j < ggml_nelements(node); j++) { - if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { - fprintf(fp, "%d", ggml_get_i32_1d(node, j)); - } - else if (node->type == GGML_TYPE_F32 || node->type == GGML_TYPE_F16) { - fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j)); - } - else { - fprintf(fp, "#"); - } - if (j < ggml_nelements(node) - 1) { - fprintf(fp, ", "); - } - } - fprintf(fp, ")"); - } - fprintf(fp, "\"; ]\n"); - } - - for (int i = 0; i < gb->n_nodes; i++) { - struct ggml_tensor * node = gb->nodes[i]; - - for (int j = 0; j < GGML_MAX_SRC; j++) { - if (node->src[j]) { - char label[16]; - snprintf(label, sizeof(label), "src %d", j); - ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label); - } - } - } - - for (int i = 0; i < gb->n_leafs; i++) { - struct ggml_tensor * node = gb->leafs[i]; - - for (int j = 0; j < GGML_MAX_SRC; j++) { - if (node->src[j]) { - char label[16]; - snprintf(label, sizeof(label), "src %d", j); - ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label); - } - } - } - - fprintf(fp, "}\n"); - - fclose(fp); - - GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); -} - -//////////////////////////////////////////////////////////////////////////////// - -static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) { - int i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = ggml_nelements(ps[p]) ; - // TODO: add function to set tensor from array - for (int64_t j = 0; j < ne; ++j) { - ggml_set_f32_1d(ps[p], j, x[i++]); - } - } -} - -static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) { - int i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = ggml_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int64_t j = 0; j < ne; ++j) { - x[i++] = ggml_get_f32_1d(ps[p], j); - } - } -} - -static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) { - int64_t i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = ggml_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int64_t j = 0; j < ne; ++j) { - g[i++] = ggml_get_f32_1d(ps[p]->grad, j); - } - } -} - -static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) { - int64_t i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = ggml_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int64_t j = 0; j < ne; ++j) { - g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale; - } - } -} - -// -// ADAM -// -// ref: https://arxiv.org/pdf/1412.6980.pdf -// - -static enum ggml_opt_result ggml_opt_adam( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_opt_params params, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - ggml_opt_callback callback, - void * callback_data) { - GGML_ASSERT(ggml_is_scalar(f)); - - // these will store the parameters we want to optimize - struct ggml_tensor * ps[GGML_MAX_PARAMS]; - - int np = 0; - int64_t nx = 0; - for (int i = 0; i < gf->n_nodes; ++i) { - if (gf->nodes[i]->is_param) { - GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - - GGML_ASSERT(np < GGML_MAX_PARAMS); - - ps[np++] = gf->nodes[i]; - nx += ggml_nelements(gf->nodes[i]); - } - } - - if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) { - int iter = opt->iter; - ggml_opt_init(opt->ctx, opt, params, nx); - opt->iter = iter; - } - - // constants - float sched = params.adam.sched; - const float alpha = params.adam.alpha; - const float decay = params.adam.decay * alpha; - const float beta1 = params.adam.beta1; - const float beta2 = params.adam.beta2; - const float eps = params.adam.eps; - const float gclip = params.adam.gclip; - const int decay_min_ndim = params.adam.decay_min_ndim; - const int n_accum = MAX(1, params.n_gradient_accumulation); - const float accum_norm = 1.0f / (float) n_accum; - - float * g = opt->adam.g->data; // gradients - float * m = opt->adam.m->data; // first moment - float * v = opt->adam.v->data; // second moment - - float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values - - struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); - struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); - cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; - - bool cancel = false; - - // compute the function value - float fx = 0; - ggml_set_zero(opt->adam.g); - for (int accum_step = 0; accum_step < n_accum; ++accum_step) { - if (callback) { - callback(callback_data, accum_step, &sched, &cancel); - if (cancel) { - return GGML_OPT_CANCEL; - } - } - // ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(gb, &cplan); - ggml_opt_acc_grad(np, ps, g, accum_norm); - fx += ggml_get_f32_1d(f, 0); - } - fx *= accum_norm; - - opt->adam.fx_prev = fx; - opt->adam.fx_best = opt->adam.fx_prev; - if (pf) { - pf[opt->iter % params.past] = opt->adam.fx_prev; - } - - opt->loss_before = opt->adam.fx_prev; - opt->loss_after = opt->adam.fx_prev; - - // initialize - if (opt->just_initialized) { - opt->adam.n_no_improvement = 0; - opt->just_initialized = false; - } - - float * fx_best = &opt->adam.fx_best; - float * fx_prev = &opt->adam.fx_prev; - int * n_no_improvement = &opt->adam.n_no_improvement; - - int iter0 = opt->iter; - - // run the optimizer - for (int t = 0; t < params.adam.n_iter; ++t) { - opt->iter = iter0 + t + 1; - GGML_PRINT_DEBUG ("=== iter %d ===\n", t); - - GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0)); - GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0)); - GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0)); - - for (int i = 0; i < np; ++i) { - GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i, - ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0)); - } - - const int64_t t_start_wall = ggml_time_us(); - const int64_t t_start_cpu = ggml_cycles(); - UNUSED(t_start_wall); - UNUSED(t_start_cpu); - - { - float gnorm = 1.0f; - if (gclip > 0.0f) { - // gradient clipping - ggml_float sum = 0.0; - for (int64_t i = 0; i < nx; ++i) { - sum += (ggml_float)(g[i]*g[i]); - } - ggml_float norm = sqrt(sum); - if (norm > (ggml_float) gclip) { - gnorm = (float) ((ggml_float) gclip / norm); - } - } - const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter)); - const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter)); - int64_t i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = ggml_nelements(ps[p]); - const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched; - for (int64_t j = 0; j < ne; ++j) { - float x = ggml_get_f32_1d(ps[p], j); - float g_ = g[i]*gnorm; - m[i] = m[i]*beta1 + g_*(1.0f - beta1); - v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2); - float mh = m[i]*beta1h; - float vh = v[i]*beta2h; - vh = sqrtf(vh) + eps; - x = x*(1.0f - p_decay) - mh/vh; - ggml_set_f32_1d(ps[p], j, x); - ++i; - } - } - } - - fx = 0; - ggml_set_zero(opt->adam.g); - for (int accum_step = 0; accum_step < n_accum; ++accum_step) { - if (callback) { - callback(callback_data, accum_step, &sched, &cancel); - if (cancel) { - return GGML_OPT_CANCEL;; - } - } - // ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(gb, &cplan); - ggml_opt_acc_grad(np, ps, g, accum_norm); - fx += ggml_get_f32_1d(f, 0); - } - fx *= accum_norm; - - opt->loss_after = fx; - - - // check convergence - if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) { - GGML_PRINT_DEBUG("converged\n"); - - return GGML_OPT_OK; - } - - // delta-based convergence test - if (pf != NULL) { - // need at least params.past iterations to start checking for convergence - if (params.past <= iter0 + t) { - const float rate = (pf[(iter0 + t)%params.past] - fx)/fx; - - if (fabsf(rate) < params.delta) { - return GGML_OPT_OK; - } - } - - pf[(iter0 + t)%params.past] = fx; - } - - // check for improvement - if (params.max_no_improvement > 0) { - if (fx_best[0] > fx) { - fx_best[0] = fx; - n_no_improvement[0] = 0; - } else { - ++n_no_improvement[0]; - - if (n_no_improvement[0] >= params.max_no_improvement) { - return GGML_OPT_OK; - } - } - } - - fx_prev[0] = fx; - - { - const int64_t t_end_cpu = ggml_cycles(); - GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC); - UNUSED(t_end_cpu); - - const int64_t t_end_wall = ggml_time_us(); - GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6); - UNUSED(t_end_wall); - } - } - - return GGML_OPT_DID_NOT_CONVERGE; -} - -// -// L-BFGS -// -// the L-BFGS implementation below is based on the following implementation: -// -// https://github.com/chokkan/liblbfgs -// - -struct ggml_lbfgs_iteration_data { - float alpha; - float ys; - float * s; - float * y; -}; - -static enum ggml_opt_result linesearch_backtracking( - const struct ggml_opt_params * params, - int nx, - float * x, - float * fx, - float * g, - float * d, - float * step, - const float * xp, - struct ggml_tensor * f, - struct ggml_cgraph * gb, - struct ggml_cplan * cplan, - const int np, - struct ggml_tensor * ps[], - bool * cancel, - ggml_opt_callback callback, - void * callback_data) { - int count = 0; - - float width = 0.0f; - float dg = 0.0f; - float finit = 0.0f; - float dginit = 0.0f; - float dgtest = 0.0f; - - const float dec = 0.5f; - const float inc = 2.1f; - - const int n_accum = MAX(1, params->n_gradient_accumulation); - const float accum_norm = 1.0f / (float) n_accum; - - if (*step <= 0.f) { - return GGML_LINESEARCH_INVALID_PARAMETERS; - } - - // compute the initial gradient in the search direction - ggml_vec_dot_f32(nx, &dginit, g, d); - - // make sure that d points to a descent direction - if (0 < dginit) { - return GGML_LINESEARCH_FAIL; - } - - // initialize local variables - finit = *fx; - dgtest = params->lbfgs.ftol*dginit; - - while (true) { - ggml_vec_cpy_f32(nx, x, xp); - ggml_vec_mad_f32(nx, x, d, *step); - - // evaluate the function and gradient values - { - ggml_opt_set_params(np, ps, x); - - *fx = 0; - memset(g, 0, sizeof(float)*nx); - for (int accum_step = 0; accum_step < n_accum; ++accum_step) { - if (callback) { - // LBFG-S does not support learning rate -> ignore learning schedule - float sched = 0; - callback(callback_data, accum_step, &sched, cancel); - if (*cancel) { - return GGML_OPT_CANCEL; - } - } - // ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(gb, cplan); - ggml_opt_acc_grad(np, ps, g, accum_norm); - *fx += ggml_get_f32_1d(f, 0); - } - *fx *= accum_norm; - - } - - ++count; - - if (*fx > finit + (*step)*dgtest) { - width = dec; - } else { - // Armijo condition is satisfied - if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) { - return count; - } - - ggml_vec_dot_f32(nx, &dg, g, d); - - // check the Wolfe condition - if (dg < params->lbfgs.wolfe * dginit) { - width = inc; - } else { - if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) { - // regular Wolfe conditions - return count; - } - - if(dg > -params->lbfgs.wolfe*dginit) { - width = dec; - } else { - // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) - return count; - } - } - } - - if (*step < params->lbfgs.min_step) { - return GGML_LINESEARCH_MINIMUM_STEP; - } - if (*step > params->lbfgs.max_step) { - return GGML_LINESEARCH_MAXIMUM_STEP; - } - if (params->lbfgs.max_linesearch <= count) { - return GGML_LINESEARCH_MAXIMUM_ITERATIONS; - } - - (*step) *= width; - } - - GGML_UNREACHABLE(); -} - -static enum ggml_opt_result ggml_opt_lbfgs( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_opt_params params, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - ggml_opt_callback callback, - void * callback_data) { - if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || - params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { - if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) { - return GGML_OPT_INVALID_WOLFE; - } - } - - const int m = params.lbfgs.m; - - // these will store the parameters we want to optimize - struct ggml_tensor * ps[GGML_MAX_PARAMS]; - - int np = 0; - int nx = 0; - for (int i = 0; i < gf->n_nodes; ++i) { - if (gf->nodes[i]->is_param) { - GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - - GGML_ASSERT(np < GGML_MAX_PARAMS); - - ps[np++] = gf->nodes[i]; - nx += ggml_nelements(gf->nodes[i]); - } - } - - if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) { - int iter = opt->iter; - ggml_opt_init(ctx, opt, params, nx); - opt->iter = iter; - } - - struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads); - struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size); - cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; - - float * x = opt->lbfgs.x->data; // current parameters - float * xp = opt->lbfgs.xp->data; // previous parameters - float * g = opt->lbfgs.g->data; // current gradient - float * gp = opt->lbfgs.gp->data; // previous gradient - float * d = opt->lbfgs.d->data; // search direction - - float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values - - const int n_accum = MAX(1, params.n_gradient_accumulation); - const float accum_norm = 1.0f / (float) n_accum; - - float fx = 0.0f; // cost function value - float xnorm = 0.0f; // ||x|| - float gnorm = 0.0f; // ||g|| - - // initialize x from the graph nodes - ggml_opt_get_params(np, ps, x); - - // the L-BFGS memory - float * lm_alpha = opt->lbfgs.lmal->data; - float * lm_ys = opt->lbfgs.lmys->data; - float * lm_s = opt->lbfgs.lms->data; - float * lm_y = opt->lbfgs.lmy->data; - - bool cancel = false; - - // evaluate the function value and its gradient - { - ggml_opt_set_params(np, ps, x); - - fx = 0; - memset(g, 0, sizeof(float)*nx); - for (int accum_step = 0; accum_step < n_accum; ++accum_step) { - if (callback) { - // LBFG-S does not support learning rate -> ignore learning schedule - float sched = 0; - callback(callback_data, accum_step, &sched, &cancel); - if (cancel) { - return GGML_OPT_CANCEL; - } - } - // ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(gb, &cplan); - ggml_opt_acc_grad(np, ps, g, accum_norm); - fx += ggml_get_f32_1d(f, 0); - } - fx *= accum_norm; - - opt->loss_before = fx; - opt->loss_after = fx; - } - - // search direction = -gradient - ggml_vec_neg_f32(nx, d, g); - - // ||x||, ||g|| - ggml_vec_norm_f32(nx, &xnorm, x); - ggml_vec_norm_f32(nx, &gnorm, g); - - if (xnorm < 1.0f) { - xnorm = 1.0f; - } - - // already optimized - if (gnorm/xnorm <= params.lbfgs.eps) { - return GGML_OPT_OK; - } - - if (opt->just_initialized) { - if (pf) { - pf[0] = fx; - } - opt->lbfgs.fx_best = fx; - - // initial step - ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d); - opt->lbfgs.j = 0; - opt->lbfgs.k = 1; - opt->lbfgs.end = 0; - opt->lbfgs.n_no_improvement = 0; - opt->just_initialized = false; - } - - float * fx_best = &opt->lbfgs.fx_best; - float * step = &opt->lbfgs.step; - int * j = &opt->lbfgs.j; - int * k = &opt->lbfgs.k; - int * end = &opt->lbfgs.end; - int * n_no_improvement = &opt->lbfgs.n_no_improvement; - - int ls = 0; - int bound = 0; - - float ys = 0.0f; - float yy = 0.0f; - float beta = 0.0f; - - int it = 0; - - while (true) { - // store the current position and gradient vectors - ggml_vec_cpy_f32(nx, xp, x); - ggml_vec_cpy_f32(nx, gp, g); - - // TODO: instead of passing &cancel here, use the return code of the linesearch - // to determine if the optimization should be cancelled - // this is a simple change, but not doing this atm, since I don't have a nice - // way to test and don't want to break something with so many changes lined up - ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data); - if (cancel) { - return GGML_OPT_CANCEL; - } - - if (ls < 0) { - // linesearch failed - go back to the previous point and return - ggml_vec_cpy_f32(nx, x, xp); - ggml_vec_cpy_f32(nx, g, gp); - - return ls; - } - - opt->loss_after = fx; - - ggml_vec_norm_f32(nx, &xnorm, x); - ggml_vec_norm_f32(nx, &gnorm, g); - - GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0)); - - if (xnorm < 1.0f) { - xnorm = 1.0f; - } - if (gnorm/xnorm <= params.lbfgs.eps) { - // converged - return GGML_OPT_OK; - } - - // delta-based convergence test - if (pf != NULL) { - // need at least params.past iterations to start checking for convergence - if (params.past <= k[0]) { - const float rate = (pf[k[0]%params.past] - fx)/fx; - - if (fabsf(rate) < params.delta) { - return GGML_OPT_OK; - } - } - - pf[k[0]%params.past] = fx; - } - - // check for improvement - if (params.max_no_improvement > 0) { - if (fx < fx_best[0]) { - fx_best[0] = fx; - n_no_improvement[0] = 0; - } else { - n_no_improvement[0]++; - - if (n_no_improvement[0] >= params.max_no_improvement) { - return GGML_OPT_OK; - } - } - } - - if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) { - // reached the maximum number of iterations - return GGML_OPT_DID_NOT_CONVERGE; - } - - // update vectors s and y: - // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}. - // y_{k+1} = g_{k+1} - g_{k}. - // - ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp); - ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp); - - // compute scalars ys and yy: - // ys = y^t \cdot s -> 1 / \rho. - // yy = y^t \cdot y. - // - ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]); - ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]); - - lm_ys[end[0]] = ys; - - // find new search direction - // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS - - bound = (m <= k[0]) ? m : k[0]; - k[0]++; - it++; - end[0] = (end[0] + 1)%m; - - // initialize search direction with -g - ggml_vec_neg_f32(nx, d, g); - - j[0] = end[0]; - for (int i = 0; i < bound; ++i) { - j[0] = (j[0] + m - 1) % m; - // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} - ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d); - lm_alpha[j[0]] /= lm_ys[j[0]]; - // q_{i} = q_{i+1} - \alpha_{i} y_{i} - ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]); - } - - ggml_vec_scale_f32(nx, d, ys/yy); - - for (int i = 0; i < bound; ++i) { - // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} - ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d); - beta /= lm_ys[j[0]]; - // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} - ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta); - j[0] = (j[0] + 1)%m; - } - - step[0] = 1.0; - } - - GGML_UNREACHABLE(); -} - -struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { - struct ggml_opt_params result; - - switch (type) { - case GGML_OPT_ADAM: - { - result = (struct ggml_opt_params) { - .type = GGML_OPT_ADAM, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, - - .max_no_improvement = 100, - - .print_forward_graph = true, - .print_backward_graph = true, - - .n_gradient_accumulation = 1, - - .adam = { - .n_iter = 10000, - .sched = 1.000f, - .decay = 0.0f, - .decay_min_ndim = 2, - .alpha = 0.001f, - .beta1 = 0.9f, - .beta2 = 0.999f, - .eps = 1e-8f, - .eps_f = 1e-5f, - .eps_g = 1e-3f, - .gclip = 0.0f, - }, - }; - } break; - case GGML_OPT_LBFGS: - { - result = (struct ggml_opt_params) { - .type = GGML_OPT_LBFGS, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, - - .max_no_improvement = 0, - - .print_forward_graph = true, - .print_backward_graph = true, - - .n_gradient_accumulation = 1, - - .lbfgs = { - .m = 6, - .n_iter = 100, - .max_linesearch = 20, - - .eps = 1e-5f, - .ftol = 1e-4f, - .wolfe = 0.9f, - .min_step = 1e-20f, - .max_step = 1e+20f, - - .linesearch = GGML_LINESEARCH_DEFAULT, - }, - }; - } break; - } - - return result; -} - -GGML_API void ggml_opt_init( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_opt_params params, - int64_t nx) { - opt->ctx = ctx; - opt->params = params; - opt->iter = 0; - opt->nx = nx; - opt->just_initialized = true; - if (opt->ctx == NULL) { - struct ggml_init_params ctx_opt_params; - if (opt->params.type == GGML_OPT_ADAM) { - ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3; - if (opt->params.past > 0) { - ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past; - } - } else if (opt->params.type == GGML_OPT_LBFGS) { - ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2); - if (opt->params.past > 0) { - ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past; - } - } - ctx_opt_params.mem_buffer = NULL; - ctx_opt_params.no_alloc = false; - - opt->ctx = ggml_init(ctx_opt_params); - } - switch (opt->params.type) { - case GGML_OPT_ADAM: - { - opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->adam.pf = params.past > 0 - ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past) - : NULL; - ggml_set_zero(opt->adam.m); - ggml_set_zero(opt->adam.v); - if (opt->adam.pf) { - ggml_set_zero(opt->adam.pf); - } - } break; - case GGML_OPT_LBFGS: - { - opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->lbfgs.pf = params.past > 0 - ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past) - : NULL; - opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m); - opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m); - opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m); - opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m); - ggml_set_zero(opt->lbfgs.x); - ggml_set_zero(opt->lbfgs.xp); - ggml_set_zero(opt->lbfgs.g); - ggml_set_zero(opt->lbfgs.gp); - ggml_set_zero(opt->lbfgs.d); - if (opt->lbfgs.pf) { - ggml_set_zero(opt->lbfgs.pf); - } - ggml_set_zero(opt->lbfgs.lmal); - ggml_set_zero(opt->lbfgs.lmys); - ggml_set_zero(opt->lbfgs.lms); - ggml_set_zero(opt->lbfgs.lmy); - } break; - } -} - -enum ggml_opt_result ggml_opt( - struct ggml_context * ctx, - struct ggml_opt_params params, - struct ggml_tensor * f) { - bool free_ctx = false; - if (ctx == NULL) { - struct ggml_init_params params_ctx = { - .mem_size = 16*1024*1024, - .mem_buffer = NULL, - .no_alloc = false, - }; - - ctx = ggml_init(params_ctx); - if (ctx == NULL) { - return GGML_OPT_NO_CONTEXT; - } - - free_ctx = true; - } - - enum ggml_opt_result result = GGML_OPT_OK; - - struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context)); - - ggml_opt_init(ctx, opt, params, 0); - result = ggml_opt_resume(ctx, opt, f); - - if (free_ctx) { - ggml_free(ctx); - } - - return result; -} - -enum ggml_opt_result ggml_opt_resume( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_tensor * f) { - - // build forward + backward compute graphs - struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0)); - struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0)); - - struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data; - struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data; - - *gf = ggml_build_forward (f); - *gb = ggml_build_backward(ctx, gf, true); - - return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL); -} - -enum ggml_opt_result ggml_opt_resume_g( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - ggml_opt_callback callback, - void * callback_data) { - - // build forward + backward compute graphs - enum ggml_opt_result result = GGML_OPT_OK; - - switch (opt->params.type) { - case GGML_OPT_ADAM: - { - result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data); - } break; - case GGML_OPT_LBFGS: - { - result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data); - } break; - } - - if (opt->params.print_forward_graph) { - ggml_graph_print (gf); - ggml_graph_dump_dot(gf, NULL, "opt-forward.dot"); - } - - if (opt->params.print_backward_graph) { - ggml_graph_print (gb); - ggml_graph_dump_dot(gb, gf, "opt-backward.dot"); - } - - return result; -} - -//////////////////////////////////////////////////////////////////////////////// - -size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK4_0 == 0); - const int nb = k / QK4_0; - - for (int b = 0; b < n; b += k) { - block_q4_0 * restrict y = (block_q4_0 *) dst + b/QK4_0; - - quantize_row_q4_0_reference(src + b, y, k); - - for (int i = 0; i < nb; i++) { - for (int j = 0; j < QK4_0; j += 2) { - const uint8_t vi0 = y[i].qs[j/2] & 0x0F; - const uint8_t vi1 = y[i].qs[j/2] >> 4; - - hist[vi0]++; - hist[vi1]++; - } - } - } - - return (n/QK4_0*sizeof(block_q4_0)); -} - -size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK4_1 == 0); - const int nb = k / QK4_1; - - for (int b = 0; b < n; b += k) { - block_q4_1 * restrict y = (block_q4_1 *) dst + b/QK4_1; - - quantize_row_q4_1_reference(src + b, y, k); - - for (int i = 0; i < nb; i++) { - for (int j = 0; j < QK4_1; j += 2) { - const uint8_t vi0 = y[i].qs[j/2] & 0x0F; - const uint8_t vi1 = y[i].qs[j/2] >> 4; - - hist[vi0]++; - hist[vi1]++; - } - } - } - - return (n/QK4_1*sizeof(block_q4_1)); -} - -size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK5_0 == 0); - const int nb = k / QK5_0; - - for (int b = 0; b < n; b += k) { - block_q5_0 * restrict y = (block_q5_0 *)dst + b/QK5_0; - - quantize_row_q5_0_reference(src + b, y, k); - - for (int i = 0; i < nb; i++) { - uint32_t qh; - memcpy(&qh, &y[i].qh, sizeof(qh)); - - for (int j = 0; j < QK5_0; j += 2) { - const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12)); - - // cast to 16 bins - const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2; - const uint8_t vi1 = ((y[i].qs[j/2] >> 4) | vh1) / 2; - - hist[vi0]++; - hist[vi1]++; - } - } - } - - return (n/QK5_0*sizeof(block_q5_0)); -} - -size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK5_1 == 0); - const int nb = k / QK5_1; - - for (int b = 0; b < n; b += k) { - block_q5_1 * restrict y = (block_q5_1 *)dst + b/QK5_1; - - quantize_row_q5_1_reference(src + b, y, k); - - for (int i = 0; i < nb; i++) { - uint32_t qh; - memcpy(&qh, &y[i].qh, sizeof(qh)); - - for (int j = 0; j < QK5_1; j += 2) { - const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12)); - - // cast to 16 bins - const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2; - const uint8_t vi1 = ((y[i].qs[j/2] >> 4) | vh1) / 2; - - hist[vi0]++; - hist[vi1]++; - } - } - } - - return (n/QK5_1*sizeof(block_q5_1)); -} - -size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - for (int b = 0; b < n; b += k) { - block_q8_0 * restrict y = (block_q8_0 *)dst + b/QK8_0; - - quantize_row_q8_0_reference(src + b, y, k); - - for (int i = 0; i < nb; i++) { - for (int j = 0; j < QK8_0; ++j) { - const int8_t vi = y[i].qs[j]; - - hist[vi/16 + 8]++; - } - } - } - - return (n/QK8_0*sizeof(block_q8_0)); -} - -size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) { - size_t result = 0; - switch (type) { - case GGML_TYPE_Q4_0: - { - GGML_ASSERT(start % QK4_0 == 0); - block_q4_0 * block = (block_q4_0*)dst + start / QK4_0; - result = ggml_quantize_q4_0(src + start, block, n, n, hist); - } break; - case GGML_TYPE_Q4_1: - { - GGML_ASSERT(start % QK4_1 == 0); - block_q4_1 * block = (block_q4_1*)dst + start / QK4_1; - result = ggml_quantize_q4_1(src + start, block, n, n, hist); - } break; - case GGML_TYPE_Q5_0: - { - GGML_ASSERT(start % QK5_0 == 0); - block_q5_0 * block = (block_q5_0*)dst + start / QK5_0; - result = ggml_quantize_q5_0(src + start, block, n, n, hist); - } break; - case GGML_TYPE_Q5_1: - { - GGML_ASSERT(start % QK5_1 == 0); - block_q5_1 * block = (block_q5_1*)dst + start / QK5_1; - result = ggml_quantize_q5_1(src + start, block, n, n, hist); - } break; - case GGML_TYPE_Q8_0: - { - GGML_ASSERT(start % QK8_0 == 0); - block_q8_0 * block = (block_q8_0*)dst + start / QK8_0; - result = ggml_quantize_q8_0(src + start, block, n, n, hist); - } break; -#ifdef GGML_USE_K_QUANTS - case GGML_TYPE_Q2_K: - { - GGML_ASSERT(start % QK_K == 0); - block_q2_K * block = (block_q2_K*)dst + start / QK_K; - result = ggml_quantize_q2_K(src + start, block, n, n, hist); - } break; - case GGML_TYPE_Q3_K: - { - GGML_ASSERT(start % QK_K == 0); - block_q3_K * block = (block_q3_K*)dst + start / QK_K; - result = ggml_quantize_q3_K(src + start, block, n, n, hist); - } break; - case GGML_TYPE_Q4_K: - { - GGML_ASSERT(start % QK_K == 0); - block_q4_K * block = (block_q4_K*)dst + start / QK_K; - result = ggml_quantize_q4_K(src + start, block, n, n, hist); - } break; - case GGML_TYPE_Q5_K: - { - GGML_ASSERT(start % QK_K == 0); - block_q5_K * block = (block_q5_K*)dst + start / QK_K; - result = ggml_quantize_q5_K(src + start, block, n, n, hist); - } break; - case GGML_TYPE_Q6_K: - { - GGML_ASSERT(start % QK_K == 0); - block_q6_K * block = (block_q6_K*)dst + start / QK_K; - result = ggml_quantize_q6_K(src + start, block, n, n, hist); - } break; -#endif - case GGML_TYPE_F16: - { - int elemsize = sizeof(ggml_fp16_t); - ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n); - result = n * elemsize; - } break; - case GGML_TYPE_F32: - { - int elemsize = sizeof(float); - result = n * elemsize; - memcpy((uint8_t *)dst + start * elemsize, src + start, result); - } break; - default: - assert(false); - } - return result; -} - -//////////////////////////////////////////////////////////////////////////////// - -struct gguf_str { - uint64_t n; // GGUFv2 - char * data; -}; - -static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = { - [GGUF_TYPE_UINT8] = sizeof(uint8_t), - [GGUF_TYPE_INT8] = sizeof(int8_t), - [GGUF_TYPE_UINT16] = sizeof(uint16_t), - [GGUF_TYPE_INT16] = sizeof(int16_t), - [GGUF_TYPE_UINT32] = sizeof(uint32_t), - [GGUF_TYPE_INT32] = sizeof(int32_t), - [GGUF_TYPE_FLOAT32] = sizeof(float), - [GGUF_TYPE_BOOL] = sizeof(bool), - [GGUF_TYPE_STRING] = sizeof(struct gguf_str), - [GGUF_TYPE_UINT64] = sizeof(uint64_t), - [GGUF_TYPE_INT64] = sizeof(int64_t), - [GGUF_TYPE_FLOAT64] = sizeof(double), - [GGUF_TYPE_ARRAY] = 0, // undefined -}; -static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); - -static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = { - [GGUF_TYPE_UINT8] = "u8", - [GGUF_TYPE_INT8] = "i8", - [GGUF_TYPE_UINT16] = "u16", - [GGUF_TYPE_INT16] = "i16", - [GGUF_TYPE_UINT32] = "u32", - [GGUF_TYPE_INT32] = "i32", - [GGUF_TYPE_FLOAT32] = "f32", - [GGUF_TYPE_BOOL] = "bool", - [GGUF_TYPE_STRING] = "str", - [GGUF_TYPE_ARRAY] = "arr", - [GGUF_TYPE_UINT64] = "u64", - [GGUF_TYPE_INT64] = "i64", - [GGUF_TYPE_FLOAT64] = "f64", -}; -static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); - -union gguf_value { - uint8_t uint8; - int8_t int8; - uint16_t uint16; - int16_t int16; - uint32_t uint32; - int32_t int32; - float float32; - uint64_t uint64; - int64_t int64; - double float64; - bool bool_; - - struct gguf_str str; - - struct { - enum gguf_type type; - - uint64_t n; // GGUFv2 - void * data; - } arr; -}; - -struct gguf_kv { - struct gguf_str key; - - enum gguf_type type; - union gguf_value value; -}; - -struct gguf_header { - uint32_t magic; - uint32_t version; - uint64_t n_tensors; // GGUFv2 - uint64_t n_kv; // GGUFv2 -}; - -struct gguf_tensor_info { - struct gguf_str name; - - uint32_t n_dims; - uint64_t ne[GGML_MAX_DIMS]; - - enum ggml_type type; - - uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT` - - // for writing API - const void * data; - size_t size; -}; - -struct gguf_context { - struct gguf_header header; - - struct gguf_kv * kv; - struct gguf_tensor_info * infos; - - size_t alignment; - size_t offset; // offset of `data` from beginning of file - size_t size; // size of `data` in bytes - - //uint8_t * padding; - void * data; -}; - -static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) { - const size_t n = fread(dst, 1, size, file); - *offset += n; - return n == size; -} - -// NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 -static bool gguf_fread_str_cur(FILE * file, struct gguf_str * p, size_t * offset) { - p->n = 0; - p->data = NULL; - - bool ok = true; - - ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset); p->data = calloc(p->n + 1, 1); - ok = ok && gguf_fread_el(file, p->data, p->n, offset); - - return ok; -} - -static bool gguf_fread_str_v1(FILE * file, struct gguf_str * p, size_t * offset) { - p->n = 0; - p->data = NULL; - - bool ok = true; - - uint32_t n = 0; - ok = ok && gguf_fread_el(file, &n, sizeof(n), offset); p->data = calloc(n + 1, 1); p->n = n; - ok = ok && gguf_fread_el(file, p->data, p->n, offset); - - return ok; -} - -struct gguf_context * gguf_init_empty(void) { - struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); - - ctx->header.magic = GGUF_MAGIC; - ctx->header.version = GGUF_VERSION; - ctx->header.n_tensors = 0; - ctx->header.n_kv = 0; - - ctx->kv = NULL; - ctx->infos = NULL; - - ctx->alignment = GGUF_DEFAULT_ALIGNMENT; - ctx->offset = 0; - ctx->size = 0; - - ctx->data = NULL; - - return ctx; -} - -struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { - FILE * file = fopen(fname, "rb"); - if (!file) { - return NULL; - } - - // offset from start of file - size_t offset = 0; - - uint32_t magic = 0; - - // check the magic before making allocations - { - gguf_fread_el(file, &magic, sizeof(magic), &offset); - - if (magic != GGUF_MAGIC) { - fprintf(stderr, "%s: invalid magic number %08x\n", __func__, magic); - fclose(file); - return NULL; - } - } - - bool ok = true; - - struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); - - // read the header - { - ctx->header.magic = magic; - - ctx->kv = NULL; - ctx->infos = NULL; - ctx->data = NULL; - - ok = ok && gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset); - - if (ctx->header.version == 1) { - // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 - uint32_t n_tensors = 0; - uint32_t n_kv = 0; - - ok = ok && gguf_fread_el(file, &n_tensors, sizeof(n_tensors), &offset); - ok = ok && gguf_fread_el(file, &n_kv, sizeof(n_kv), &offset); - - ctx->header.n_tensors = n_tensors; - ctx->header.n_kv = n_kv; - } else { - ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset); - ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset); - } - - if (!ok) { - fprintf(stderr, "%s: failed to read header\n", __func__); - fclose(file); - gguf_free(ctx); - return NULL; - } - } - - // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 - bool (* gguf_fread_str)(FILE *, struct gguf_str *, size_t *) = gguf_fread_str_cur; - if (ctx->header.version == 1) { - gguf_fread_str = gguf_fread_str_v1; - } - - // read the kv pairs - { - ctx->kv = malloc(ctx->header.n_kv * sizeof(struct gguf_kv)); - - for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { - struct gguf_kv * kv = &ctx->kv[i]; - - //fprintf(stderr, "%s: reading kv %d\n", __func__, i); - - ok = ok && gguf_fread_str(file, &kv->key, &offset); - ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset); - - //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data); - - switch (kv->type) { - case GGUF_TYPE_UINT8: ok = ok && gguf_fread_el (file, &kv->value.uint8, sizeof(kv->value.uint8), &offset); break; - case GGUF_TYPE_INT8: ok = ok && gguf_fread_el (file, &kv->value.int8, sizeof(kv->value.int8), &offset); break; - case GGUF_TYPE_UINT16: ok = ok && gguf_fread_el (file, &kv->value.uint16, sizeof(kv->value.uint16), &offset); break; - case GGUF_TYPE_INT16: ok = ok && gguf_fread_el (file, &kv->value.int16, sizeof(kv->value.int16), &offset); break; - case GGUF_TYPE_UINT32: ok = ok && gguf_fread_el (file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); break; - case GGUF_TYPE_INT32: ok = ok && gguf_fread_el (file, &kv->value.int32, sizeof(kv->value.int32), &offset); break; - case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break; - case GGUF_TYPE_UINT64: ok = ok && gguf_fread_el (file, &kv->value.uint64, sizeof(kv->value.uint64), &offset); break; - case GGUF_TYPE_INT64: ok = ok && gguf_fread_el (file, &kv->value.int64, sizeof(kv->value.int64), &offset); break; - case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break; - case GGUF_TYPE_BOOL: ok = ok && gguf_fread_el (file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); break; - case GGUF_TYPE_STRING: ok = ok && gguf_fread_str(file, &kv->value.str, &offset); break; - case GGUF_TYPE_ARRAY: - { - ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset); - - if (ctx->header.version == 1) { - // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 - uint32_t n = 0; - ok = ok && gguf_fread_el(file, &n, sizeof(n), &offset); - kv->value.arr.n = n; - } else { - ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset); - } - - switch (kv->value.arr.type) { - case GGUF_TYPE_UINT8: - case GGUF_TYPE_INT8: - case GGUF_TYPE_UINT16: - case GGUF_TYPE_INT16: - case GGUF_TYPE_UINT32: - case GGUF_TYPE_INT32: - case GGUF_TYPE_FLOAT32: - case GGUF_TYPE_UINT64: - case GGUF_TYPE_INT64: - case GGUF_TYPE_FLOAT64: - case GGUF_TYPE_BOOL: - { - kv->value.arr.data = malloc(kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]); - ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type], &offset); - } break; - case GGUF_TYPE_STRING: - { - kv->value.arr.data = malloc(kv->value.arr.n * sizeof(struct gguf_str)); - for (uint32_t j = 0; j < kv->value.arr.n; ++j) { - ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset); - } - } break; - case GGUF_TYPE_ARRAY: - case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; - } - } break; - case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); - } - - if (!ok) { - break; - } - } - - if (!ok) { - fprintf(stderr, "%s: failed to read key-value pairs\n", __func__); - fclose(file); - gguf_free(ctx); - return NULL; - } - } - - // read the tensor infos - { - ctx->infos = malloc(ctx->header.n_tensors * sizeof(struct gguf_tensor_info)); - - for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - info->ne[j] = 1; - } - - ok = ok && gguf_fread_str(file, &info->name, &offset); - ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset); - for (uint32_t j = 0; j < info->n_dims; ++j) { - if (ctx->header.version == 1) { - // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 - uint32_t t = 0; - ok = ok && gguf_fread_el(file, &t, sizeof(t), &offset); - info->ne[j] = t; - } else { - ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset); - } - } - ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset); - ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset); - - if (!ok) { - fprintf(stderr, "%s: failed to read tensor info\n", __func__); - fclose(file); - gguf_free(ctx); - return NULL; - } - } - } - - ctx->alignment = GGUF_DEFAULT_ALIGNMENT; - - int alignment_idx = gguf_find_key(ctx, "general.alignment"); - if (alignment_idx != -1) { - ctx->alignment = gguf_get_val_u32(ctx, alignment_idx); - } - - // we require the data section to be aligned, so take into account any padding - { - const size_t offset_pad = offset % ctx->alignment; - - if (offset_pad != 0) { - offset += ctx->alignment - offset_pad; - fseek(file, offset, SEEK_SET); - } - } - - // store the current file offset - this is where the data section starts - ctx->offset = offset; - - // compute the total size of the data section, taking into account the alignment - { - ctx->size = 0; - for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - const int64_t ne = - (int64_t) info->ne[0] * - (int64_t) info->ne[1] * - (int64_t) info->ne[2] * - (int64_t) info->ne[3]; - - if (ne % ggml_blck_size(info->type) != 0) { - fprintf(stderr, "%s: tensor '%s' number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", - __func__, info->name.data, ne, ggml_blck_size(info->type)); - fclose(file); - gguf_free(ctx); - return NULL; - } - - const size_t size_cur = (ne*ggml_type_size(info->type))/ggml_blck_size(info->type); - - ctx->size += GGML_PAD(size_cur, ctx->alignment); - } - } - - // load the tensor data only if requested - if (params.ctx != NULL) { - // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob - // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of - // the ggml_tensor structs to the appropriate locations in the binary blob - - // compute the exact size needed for the new ggml_context - const size_t mem_size = - params.no_alloc ? - (ctx->header.n_tensors )*ggml_tensor_overhead() : - (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size; - - struct ggml_init_params pdata = { - .mem_size = mem_size, - .mem_buffer = NULL, - .no_alloc = params.no_alloc, - }; - - *params.ctx = ggml_init(pdata); - - struct ggml_context * ctx_data = *params.ctx; - - struct ggml_tensor * data = NULL; - - if (!params.no_alloc) { - data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size); - - ok = ok && data != NULL; - - // read the binary blob with the tensor data - ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset); - - if (!ok) { - fprintf(stderr, "%s: failed to read tensor data\n", __func__); - fclose(file); - ggml_free(ctx_data); - gguf_free(ctx); - return NULL; - } - - ctx->data = data->data; - } - - ggml_set_no_alloc(ctx_data, true); - - // create the tensors - for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { - const int64_t ne[GGML_MAX_DIMS] = { - ctx->infos[i].ne[0], - ctx->infos[i].ne[1], - ctx->infos[i].ne[2], - ctx->infos[i].ne[3], - }; - - struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne); - - ok = ok && cur != NULL; - - ggml_set_name(cur, ctx->infos[i].name.data); - - if (!ok) { - break; - } - - // point the data member to the appropriate location in the binary blob using the tensor infos - if (!params.no_alloc) { - //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file - cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data - } - } - - if (!ok) { - fprintf(stderr, "%s: failed to read the tensor data\n", __func__); - fclose(file); - ggml_free(ctx_data); - gguf_free(ctx); - return NULL; - } - - ggml_set_no_alloc(ctx_data, params.no_alloc); - } - - fclose(file); - - return ctx; -} - -void gguf_free(struct gguf_context * ctx) { - if (ctx == NULL) { - return; - } - - if (ctx->kv) { - // free string memory - not great.. - for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { - struct gguf_kv * kv = &ctx->kv[i]; - - if (kv->key.data) { - free(kv->key.data); - } - - if (kv->type == GGUF_TYPE_STRING) { - if (kv->value.str.data) { - free(kv->value.str.data); - } - } - - if (kv->type == GGUF_TYPE_ARRAY) { - if (kv->value.arr.data) { - if (kv->value.arr.type == GGUF_TYPE_STRING) { - for (uint32_t j = 0; j < kv->value.arr.n; ++j) { - struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j]; - if (str->data) { - free(str->data); - } - } - } - free(kv->value.arr.data); - } - } - } - - free(ctx->kv); - } - - if (ctx->infos) { - for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - if (info->name.data) { - free(info->name.data); - } - } - - free(ctx->infos); - } - - GGML_ALIGNED_FREE(ctx); -} - -const char * gguf_type_name(enum gguf_type type) { - return GGUF_TYPE_NAME[type]; -} - -int gguf_get_version(const struct gguf_context * ctx) { - return ctx->header.version; -} - -size_t gguf_get_alignment(const struct gguf_context * ctx) { - return ctx->alignment; -} - -size_t gguf_get_data_offset(const struct gguf_context * ctx) { - return ctx->offset; -} - -void * gguf_get_data(const struct gguf_context * ctx) { - return ctx->data; -} - -int gguf_get_n_kv(const struct gguf_context * ctx) { - return ctx->header.n_kv; -} - -int gguf_find_key(const struct gguf_context * ctx, const char * key) { - // return -1 if key not found - int keyfound = -1; - - const int n_kv = gguf_get_n_kv(ctx); - - for (int i = 0; i < n_kv; ++i) { - if (strcmp(key, gguf_get_key(ctx, i)) == 0) { - keyfound = i; - break; - } - } - - return keyfound; -} - -const char * gguf_get_key(const struct gguf_context * ctx, int key_id) { - return ctx->kv[key_id].key.data; -} - -enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) { - return ctx->kv[key_id].type; -} - -enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); - return ctx->kv[key_id].value.arr.type; -} - -const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); - return ctx->kv[key_id].value.arr.data; -} - -const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); - struct gguf_kv * kv = &ctx->kv[key_id]; - struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; - return str->data; -} - -int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); - return ctx->kv[key_id].value.arr.n; -} - -uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8); - return ctx->kv[key_id].value.uint8; -} - -int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8); - return ctx->kv[key_id].value.int8; -} - -uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16); - return ctx->kv[key_id].value.uint16; -} - -int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16); - return ctx->kv[key_id].value.int16; -} - -uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32); - return ctx->kv[key_id].value.uint32; -} - -int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32); - return ctx->kv[key_id].value.int32; -} - -float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32); - return ctx->kv[key_id].value.float32; -} - -uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64); - return ctx->kv[key_id].value.uint64; -} - -int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64); - return ctx->kv[key_id].value.int64; -} - -double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64); - return ctx->kv[key_id].value.float64; -} - -bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL); - return ctx->kv[key_id].value.bool_; -} - -const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING); - return ctx->kv[key_id].value.str.data; -} - -int gguf_get_n_tensors(const struct gguf_context * ctx) { - return ctx->header.n_tensors; -} - -int gguf_find_tensor(const struct gguf_context * ctx, const char * name) { - // return -1 if tensor not found - int tensorfound = -1; - - const int n_tensors = gguf_get_n_tensors(ctx); - - for (int i = 0; i < n_tensors; ++i) { - if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) { - tensorfound = i; - break; - } - } - - return tensorfound; -} - -size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) { - return ctx->infos[i].offset; -} - -char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) { - return ctx->infos[i].name.data; -} - -// returns the index -static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) { - const int idx = gguf_find_key(ctx, key); - if (idx >= 0) { - return idx; - } - - const int n_kv = gguf_get_n_kv(ctx); - - ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv)); - ctx->kv[n_kv].key.n = strlen(key); - ctx->kv[n_kv].key.data = strdup(key); - ctx->header.n_kv++; - - return n_kv; -} - -void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_UINT8; - ctx->kv[idx].value.uint8 = val; -} - -void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_INT8; - ctx->kv[idx].value.int8 = val; -} - -void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_UINT16; - ctx->kv[idx].value.uint16 = val; -} - -void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_INT16; - ctx->kv[idx].value.int16 = val; -} - -void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_UINT32; - ctx->kv[idx].value.uint32 = val; -} - -void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_INT32; - ctx->kv[idx].value.int32 = val; -} - -void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_FLOAT32; - ctx->kv[idx].value.float32 = val; -} - -void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_UINT64; - ctx->kv[idx].value.uint64 = val; -} - -void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_INT64; - ctx->kv[idx].value.int64 = val; -} - -void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_FLOAT64; - ctx->kv[idx].value.float64 = val; -} - -void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_BOOL; - ctx->kv[idx].value.bool_ = val; -} - -void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_STRING; - ctx->kv[idx].value.str.n = strlen(val); - ctx->kv[idx].value.str.data = strdup(val); -} - -void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_ARRAY; - ctx->kv[idx].value.arr.type = type; - ctx->kv[idx].value.arr.n = n; - ctx->kv[idx].value.arr.data = malloc(n*GGUF_TYPE_SIZE[type]); - memcpy(ctx->kv[idx].value.arr.data, data, n*GGUF_TYPE_SIZE[type]); -} - -void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_ARRAY; - ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING; - ctx->kv[idx].value.arr.n = n; - ctx->kv[idx].value.arr.data = malloc(n*sizeof(struct gguf_str)); - for (int i = 0; i < n; i++) { - struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i]; - str->n = strlen(data[i]); - str->data = strdup(data[i]); - } -} - -// set or add KV pairs from another context -void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { - for (uint32_t i = 0; i < src->header.n_kv; i++) { - switch (src->kv[i].type) { - case GGUF_TYPE_UINT8: gguf_set_val_u8 (ctx, src->kv[i].key.data, src->kv[i].value.uint8); break; - case GGUF_TYPE_INT8: gguf_set_val_i8 (ctx, src->kv[i].key.data, src->kv[i].value.int8); break; - case GGUF_TYPE_UINT16: gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16); break; - case GGUF_TYPE_INT16: gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16); break; - case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32); break; - case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32); break; - case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32); break; - case GGUF_TYPE_UINT64: gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64); break; - case GGUF_TYPE_INT64: gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64); break; - case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64); break; - case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_); break; - case GGUF_TYPE_STRING: gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break; - case GGUF_TYPE_ARRAY: - { - if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) { - const char ** data = malloc(src->kv[i].value.arr.n*sizeof(char *)); - for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) { - data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data; - } - gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n); - free(data); - } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) { - GGML_ASSERT(false && "nested arrays not supported"); - } else { - gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n); - } - } break; - case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; - } - } -} - -void gguf_add_tensor( - struct gguf_context * ctx, - const struct ggml_tensor * tensor) { - const int idx = ctx->header.n_tensors; - ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info)); - - ctx->infos[idx].name.n = strlen(tensor->name); - ctx->infos[idx].name.data = strdup(tensor->name); - - for (int i = 0; i < GGML_MAX_DIMS; ++i) { - ctx->infos[idx].ne[i] = 1; - } - - ctx->infos[idx].n_dims = tensor->n_dims; - for (int i = 0; i < tensor->n_dims; i++) { - ctx->infos[idx].ne[i] = tensor->ne[i]; - } - - ctx->infos[idx].type = tensor->type; - ctx->infos[idx].offset = 0; - ctx->infos[idx].data = tensor->data; - ctx->infos[idx].size = ggml_nbytes(tensor); - - if (ctx->header.n_tensors > 0) { - ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment); - } - - ctx->header.n_tensors++; -} - -void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) { - const int idx = gguf_find_tensor(ctx, name); - if (idx < 0) { - GGML_ASSERT(false && "tensor not found"); - } - - ctx->infos[idx].type = type; -} - -void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) { - const int idx = gguf_find_tensor(ctx, name); - if (idx < 0) { - GGML_ASSERT(false && "tensor not found"); - } - - ctx->infos[idx].data = data; - ctx->infos[idx].size = size; - - // update offsets - for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) { - ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment); - } -} - -//static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) { -// fwrite(&val->n, sizeof(val->n), 1, file); -// fwrite(val->data, sizeof(char), val->n, file); -//} -// -//static void gguf_fwrite_el(FILE * file, const void * val, size_t size) { -// fwrite(val, sizeof(char), size, file); -//} - -struct gguf_buf { - void * data; - size_t size; - size_t offset; -}; - -static struct gguf_buf gguf_buf_init(size_t size) { - struct gguf_buf buf = { - /*buf.data =*/ size == 0 ? NULL : malloc(size), - /*buf.size =*/ size, - /*buf.offset =*/ 0, - }; - - return buf; -} - -static void gguf_buf_free(struct gguf_buf buf) { - if (buf.data) { - free(buf.data); - } -} - -static void gguf_buf_grow(struct gguf_buf * buf, size_t size) { - if (buf->offset + size > buf->size) { - buf->size = 1.5*(buf->offset + size); - if (buf->data) { - buf->data = realloc(buf->data, buf->size); - } - } -} - -static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) { - gguf_buf_grow(buf, sizeof(val->n) + val->n); - - if (buf->data) { - memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n)); - } - buf->offset += sizeof(val->n); - - if (buf->data) { - memcpy((char *) buf->data + buf->offset, val->data, val->n); - } - buf->offset += val->n; -} - -static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) { - gguf_buf_grow(buf, el_size); - - if (buf->data) { - memcpy((char *) buf->data + buf->offset, val, el_size); - } - buf->offset += el_size; -} - -static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) { - // write header - gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic)); - gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version)); - gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors)); - gguf_bwrite_el(buf, &ctx->header.n_kv, sizeof(ctx->header.n_kv)); - - // write key-value pairs - for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { - struct gguf_kv * kv = &ctx->kv[i]; - - gguf_bwrite_str(buf, &kv->key); - gguf_bwrite_el (buf, &kv->type, sizeof(kv->type)); - - switch (kv->type) { - case GGUF_TYPE_UINT8: gguf_bwrite_el( buf, &kv->value.uint8, sizeof(kv->value.uint8) ); break; - case GGUF_TYPE_INT8: gguf_bwrite_el (buf, &kv->value.int8, sizeof(kv->value.int8) ); break; - case GGUF_TYPE_UINT16: gguf_bwrite_el (buf, &kv->value.uint16, sizeof(kv->value.uint16) ); break; - case GGUF_TYPE_INT16: gguf_bwrite_el (buf, &kv->value.int16, sizeof(kv->value.int16) ); break; - case GGUF_TYPE_UINT32: gguf_bwrite_el (buf, &kv->value.uint32, sizeof(kv->value.uint32) ); break; - case GGUF_TYPE_INT32: gguf_bwrite_el (buf, &kv->value.int32, sizeof(kv->value.int32) ); break; - case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break; - case GGUF_TYPE_UINT64: gguf_bwrite_el (buf, &kv->value.uint64, sizeof(kv->value.uint64) ); break; - case GGUF_TYPE_INT64: gguf_bwrite_el (buf, &kv->value.int64, sizeof(kv->value.int64) ); break; - case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break; - case GGUF_TYPE_BOOL: gguf_bwrite_el (buf, &kv->value.bool_, sizeof(kv->value.bool_) ); break; - case GGUF_TYPE_STRING: gguf_bwrite_str(buf, &kv->value.str ); break; - case GGUF_TYPE_ARRAY: - { - gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type)); - gguf_bwrite_el(buf, &kv->value.arr.n, sizeof(kv->value.arr.n) ); - - switch (kv->value.arr.type) { - case GGUF_TYPE_UINT8: - case GGUF_TYPE_INT8: - case GGUF_TYPE_UINT16: - case GGUF_TYPE_INT16: - case GGUF_TYPE_UINT32: - case GGUF_TYPE_INT32: - case GGUF_TYPE_FLOAT32: - case GGUF_TYPE_UINT64: - case GGUF_TYPE_INT64: - case GGUF_TYPE_FLOAT64: - case GGUF_TYPE_BOOL: - { - gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]); - } break; - case GGUF_TYPE_STRING: - { - for (uint32_t j = 0; j < kv->value.arr.n; ++j) { - gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]); - } - } break; - case GGUF_TYPE_ARRAY: - case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; - } - } break; - case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); - } - } - - // write tensor infos - for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - gguf_bwrite_str(buf, &info->name); - gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims)); - for (uint32_t j = 0; j < info->n_dims; ++j) { - gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j])); - } - gguf_bwrite_el(buf, &info->type, sizeof(info->type)); - gguf_bwrite_el(buf, &info->offset, sizeof(info->offset)); - } - - // we require the data section to be aligned, so take into account any padding - { - const size_t offset = buf->offset; - const size_t offset_pad = GGML_PAD(offset, ctx->alignment); - - if (offset_pad != offset) { - uint8_t pad = 0; - for (size_t i = 0; i < offset_pad - offset; ++i) { - gguf_bwrite_el(buf, &pad, sizeof(pad)); - } - } - } - - if (only_meta) { - return; - } - - size_t offset = 0; - - // write tensor data - for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - const size_t size = info->size; - const size_t size_pad = GGML_PAD(size, ctx->alignment); - - gguf_bwrite_el(buf, info->data, size); - - if (size_pad != size) { - uint8_t pad = 0; - for (size_t j = 0; j < size_pad - size; ++j) { - gguf_bwrite_el(buf, &pad, sizeof(pad)); - } - } - - GGML_ASSERT(offset == info->offset); - - offset += size_pad; - } -} - -void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { - FILE * file = fopen(fname, "wb"); - if (!file) { - GGML_ASSERT(false && "failed to open file for writing"); - } - - struct gguf_buf buf = gguf_buf_init(16*1024); - - gguf_write_to_buf(ctx, &buf, only_meta); - - fwrite(buf.data, 1, buf.offset, file); - - gguf_buf_free(buf); - - fclose(file); -} - -size_t gguf_get_meta_size(const struct gguf_context * ctx) { - // no allocs - only compute size - struct gguf_buf buf = gguf_buf_init(0); - - gguf_write_to_buf(ctx, &buf, true); - - return buf.offset; -} - -void gguf_get_meta_data(const struct gguf_context * ctx, void * data) { - struct gguf_buf buf = gguf_buf_init(16*1024); - - gguf_write_to_buf(ctx, &buf, true); - - memcpy(data, buf.data, buf.offset); - - gguf_buf_free(buf); -} - -//////////////////////////////////////////////////////////////////////////////// - -int ggml_cpu_has_avx(void) { -#if defined(__AVX__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_avx2(void) { -#if defined(__AVX2__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_avx512(void) { -#if defined(__AVX512F__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_avx512_vbmi(void) { -#if defined(__AVX512VBMI__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_avx512_vnni(void) { -#if defined(__AVX512VNNI__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_fma(void) { -#if defined(__FMA__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_neon(void) { -#if defined(__ARM_NEON) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_arm_fma(void) { -#if defined(__ARM_FEATURE_FMA) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_metal(void) { -#if defined(GGML_USE_METAL) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_f16c(void) { -#if defined(__F16C__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_fp16_va(void) { -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_wasm_simd(void) { -#if defined(__wasm_simd128__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_blas(void) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_cublas(void) { -#if defined(GGML_USE_CUBLAS) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_clblast(void) { -#if defined(GGML_USE_CLBLAST) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_gpublas(void) { - return ggml_cpu_has_cublas() || ggml_cpu_has_clblast(); -} - -int ggml_cpu_has_sse3(void) { -#if defined(__SSE3__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_ssse3(void) { -#if defined(__SSSE3__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_vsx(void) { -#if defined(__POWER9_VECTOR__) - return 1; -#else - return 0; -#endif -} - -//////////////////////////////////////////////////////////////////////////////// diff --git a/plugins/wasi_nn/thirdparty/ggml/ggml.h b/plugins/wasi_nn/thirdparty/ggml/ggml.h deleted file mode 100644 index 6e35888e..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/ggml.h +++ /dev/null @@ -1,2114 +0,0 @@ -#pragma once - -// -// GGML Tensor Library -// -// This documentation is still a work in progress. -// If you wish some specific topics to be covered, feel free to drop a comment: -// -// https://github.com/ggerganov/whisper.cpp/issues/40 -// -// ## Overview -// -// This library implements: -// -// - a set of tensor operations -// - automatic differentiation -// - basic optimization algorithms -// -// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes, -// but is not limited to, the following: -// -// - linear regression -// - support vector machines -// - neural networks -// -// The library allows the user to define a certain function using the available tensor operations. This function -// definition is represented internally via a computation graph. Each tensor operation in the function definition -// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the -// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized -// using one of the available optimization algorithms. -// -// For example, here we define the function: f(x) = a*x^2 + b -// -// { -// struct ggml_init_params params = { -// .mem_size = 16*1024*1024, -// .mem_buffer = NULL, -// }; -// -// // memory allocation happens here -// struct ggml_context * ctx = ggml_init(params); -// -// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); -// -// ggml_set_param(ctx, x); // x is an input variable -// -// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); -// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); -// struct ggml_tensor * x2 = ggml_mul(ctx, x, x); -// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b); -// -// ... -// } -// -// Notice that the function definition above does not involve any actual computation. The computation is performed only -// when the user explicitly requests it. For example, to compute the function's value at x = 2.0: -// -// { -// ... -// -// struct ggml_cgraph gf = ggml_build_forward(f); -// -// // set the input variable and parameter values -// ggml_set_f32(x, 2.0f); -// ggml_set_f32(a, 3.0f); -// ggml_set_f32(b, 4.0f); -// -// ggml_graph_compute_with_ctx(ctx, &gf, n_threads); -// -// printf("f = %f\n", ggml_get_f32_1d(f, 0)); -// -// ... -// } -// -// The actual computation is performed in the ggml_graph_compute() function. -// -// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the -// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know -// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory -// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was -// actually needed. -// -// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic -// differentiation and optimization algorithms. -// -// The described approach allows to define the function graph once and then compute its forward or backward graphs -// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way -// the user can avoid the memory allocation overhead at runtime. -// -// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class -// citizens, but in theory the library can be extended to support FP8 and integer data types. -// -// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary -// and binary operations. Most of the available operations fall into one of these two categories. With time, it became -// clear that the library needs to support more complex operations. The way to support these operations is not clear -// yet, but a few examples are demonstrated in the following operations: -// -// - ggml_permute() -// - ggml_conv_1d_1s() -// - ggml_conv_1d_2s() -// -// For each tensor operator, the library implements a forward and backward computation function. The forward function -// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the -// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a -// calculus class, or watch the following video: -// -// What is Automatic Differentiation? -// https://www.youtube.com/watch?v=wG_nF1awSSY -// -// -// ## Tensor data (struct ggml_tensor) -// -// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of -// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains -// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example: -// -// { -// struct ggml_tensor * c = ggml_add(ctx, a, b); -// -// assert(c->src[0] == a); -// assert(c->src[1] == b); -// } -// -// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the -// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows -// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and -// permutation. All tensor operations have to take the stride into account and not assume that the tensor is -// contiguous in memory. -// -// The data of the tensor is accessed via the "data" pointer. For example: -// -// { -// const int nx = 2; -// const int ny = 3; -// -// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny); -// -// for (int y = 0; y < ny; y++) { -// for (int x = 0; x < nx; x++) { -// *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y; -// } -// } -// -// ... -// } -// -// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used. -// -// ## The matrix multiplication operator (ggml_mul_mat) -// -// TODO -// -// -// ## Multi-threading -// -// TODO -// -// -// ## Overview of ggml.c -// -// TODO -// -// -// ## SIMD optimizations -// -// TODO -// -// -// ## Debugging ggml -// -// TODO -// -// - -#ifdef GGML_SHARED -# if defined(_WIN32) && !defined(__MINGW32__) -# ifdef GGML_BUILD -# define GGML_API __declspec(dllexport) -# else -# define GGML_API __declspec(dllimport) -# endif -# else -# define GGML_API __attribute__ ((visibility ("default"))) -# endif -#else -# define GGML_API -#endif - -// TODO: support for clang -#ifdef __GNUC__ -# define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) -#elif defined(_MSC_VER) -# define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func -#else -# define GGML_DEPRECATED(func, hint) func -#endif - -#ifndef __GNUC__ -# define GGML_ATTRIBUTE_FORMAT(...) -#elif defined(__MINGW32__) -# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) -#else -# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif - -#include -#include -#include - -#define GGML_FILE_MAGIC 0x67676d6c // "ggml" -#define GGML_FILE_VERSION 1 - -#define GGML_QNT_VERSION 2 // bump this on quantization format changes -#define GGML_QNT_VERSION_FACTOR 1000 // do not change this - -#define GGML_MAX_DIMS 4 -#define GGML_MAX_NODES 16384 -#define GGML_MAX_PARAMS 1024 -#define GGML_MAX_CONTEXTS 64 -#define GGML_MAX_SRC 6 -#define GGML_MAX_NAME 64 -#define GGML_MAX_OP_PARAMS 32 -#define GGML_DEFAULT_N_THREADS 4 - -#if UINTPTR_MAX == 0xFFFFFFFF - #define GGML_MEM_ALIGN 4 -#else - #define GGML_MEM_ALIGN 16 -#endif - -#define GGML_EXIT_SUCCESS 0 -#define GGML_EXIT_ABORTED 1 - -#define GGUF_MAGIC 0x46554747 // "GGUF" -#define GGUF_VERSION 2 - -#define GGUF_DEFAULT_ALIGNMENT 32 - -#define GGML_UNUSED(x) (void)(x) - -#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) - -#define GGML_ASSERT(x) \ - do { \ - if (!(x)) { \ - fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ - abort(); \ - } \ - } while (0) - -#ifndef NDEBUG -#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached") -#elif defined(__GNUC__) -#define GGML_UNREACHABLE() __builtin_unreachable() -#else -#define GGML_UNREACHABLE() ((void) 0) -#endif - -// used to copy the number of elements and stride in bytes of tensors into local variables. -// main purpose is to reduce code duplication and improve readability. -// -// example: -// -// GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); -// GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); -// -#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \ - const type prefix##0 = (pointer)->array[0]; \ - GGML_UNUSED(prefix##0); -#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \ - GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \ - const type prefix##1 = (pointer)->array[1]; \ - GGML_UNUSED(prefix##1); -#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \ - GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \ - const type prefix##2 = (pointer)->array[2]; \ - GGML_UNUSED(prefix##2); -#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \ - GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \ - const type prefix##3 = (pointer)->array[3]; \ - GGML_UNUSED(prefix##3); - -#ifdef __cplusplus -extern "C" { -#endif - -#if defined(__ARM_NEON) && defined(__CUDACC__) - typedef half ggml_fp16_t; -#elif defined(__ARM_NEON) - typedef __fp16 ggml_fp16_t; -#else - typedef uint16_t ggml_fp16_t; -#endif - - // convert FP16 <-> FP32 - GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x); - GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x); - - GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n); - GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n); - - struct ggml_object; - struct ggml_context; - - enum ggml_type { - GGML_TYPE_F32 = 0, - GGML_TYPE_F16 = 1, - GGML_TYPE_Q4_0 = 2, - GGML_TYPE_Q4_1 = 3, - // GGML_TYPE_Q4_2 = 4, support has been removed - // GGML_TYPE_Q4_3 (5) support has been removed - GGML_TYPE_Q5_0 = 6, - GGML_TYPE_Q5_1 = 7, - GGML_TYPE_Q8_0 = 8, - GGML_TYPE_Q8_1 = 9, - // k-quantizations - GGML_TYPE_Q2_K = 10, - GGML_TYPE_Q3_K = 11, - GGML_TYPE_Q4_K = 12, - GGML_TYPE_Q5_K = 13, - GGML_TYPE_Q6_K = 14, - GGML_TYPE_Q8_K = 15, - GGML_TYPE_I8, - GGML_TYPE_I16, - GGML_TYPE_I32, - GGML_TYPE_COUNT, - }; - - enum ggml_backend_type { - GGML_BACKEND_CPU = 0, - GGML_BACKEND_GPU = 10, - GGML_BACKEND_GPU_SPLIT = 20, - }; - - // model file types - enum ggml_ftype { - GGML_FTYPE_UNKNOWN = -1, - GGML_FTYPE_ALL_F32 = 0, - GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors - GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors - GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors - GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 - GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors - GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors - GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors - GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors - GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors - GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors - GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors - GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors - }; - - // available tensor operations: - enum ggml_op { - GGML_OP_NONE = 0, - - GGML_OP_DUP, - GGML_OP_ADD, - GGML_OP_ADD1, - GGML_OP_ACC, - GGML_OP_SUB, - GGML_OP_MUL, - GGML_OP_DIV, - GGML_OP_SQR, - GGML_OP_SQRT, - GGML_OP_LOG, - GGML_OP_SUM, - GGML_OP_SUM_ROWS, - GGML_OP_MEAN, - GGML_OP_ARGMAX, - GGML_OP_REPEAT, - GGML_OP_REPEAT_BACK, - GGML_OP_CONCAT, - GGML_OP_SILU_BACK, - GGML_OP_NORM, // normalize - GGML_OP_RMS_NORM, - GGML_OP_RMS_NORM_BACK, - GGML_OP_GROUP_NORM, - - GGML_OP_MUL_MAT, - GGML_OP_OUT_PROD, - - GGML_OP_SCALE, - GGML_OP_SET, - GGML_OP_CPY, - GGML_OP_CONT, - GGML_OP_RESHAPE, - GGML_OP_VIEW, - GGML_OP_PERMUTE, - GGML_OP_TRANSPOSE, - GGML_OP_GET_ROWS, - GGML_OP_GET_ROWS_BACK, - GGML_OP_DIAG, - GGML_OP_DIAG_MASK_INF, - GGML_OP_DIAG_MASK_ZERO, - GGML_OP_SOFT_MAX, - GGML_OP_SOFT_MAX_BACK, - GGML_OP_ROPE, - GGML_OP_ROPE_BACK, - GGML_OP_ALIBI, - GGML_OP_CLAMP, - GGML_OP_CONV_1D, - GGML_OP_CONV_2D, - GGML_OP_CONV_TRANSPOSE_1D, - GGML_OP_CONV_TRANSPOSE_2D, - GGML_OP_POOL_1D, - GGML_OP_POOL_2D, - - GGML_OP_CONV_1D_STAGE_0, // internal - GGML_OP_CONV_1D_STAGE_1, // internal - - GGML_OP_UPSCALE, // nearest interpolate - - GGML_OP_FLASH_ATTN, - GGML_OP_FLASH_FF, - GGML_OP_FLASH_ATTN_BACK, - GGML_OP_WIN_PART, - GGML_OP_WIN_UNPART, - GGML_OP_GET_REL_POS, - GGML_OP_ADD_REL_POS, - - GGML_OP_UNARY, - - GGML_OP_MAP_UNARY, - GGML_OP_MAP_BINARY, - - GGML_OP_MAP_CUSTOM1_F32, - GGML_OP_MAP_CUSTOM2_F32, - GGML_OP_MAP_CUSTOM3_F32, - - GGML_OP_MAP_CUSTOM1, - GGML_OP_MAP_CUSTOM2, - GGML_OP_MAP_CUSTOM3, - - GGML_OP_CROSS_ENTROPY_LOSS, - GGML_OP_CROSS_ENTROPY_LOSS_BACK, - - GGML_OP_COUNT, - }; - - enum ggml_unary_op { - GGML_UNARY_OP_ABS, - GGML_UNARY_OP_SGN, - GGML_UNARY_OP_NEG, - GGML_UNARY_OP_STEP, - GGML_UNARY_OP_TANH, - GGML_UNARY_OP_ELU, - GGML_UNARY_OP_RELU, - GGML_UNARY_OP_GELU, - GGML_UNARY_OP_GELU_QUICK, - GGML_UNARY_OP_SILU, - }; - - enum ggml_object_type { - GGML_OBJECT_TENSOR, - GGML_OBJECT_GRAPH, - GGML_OBJECT_WORK_BUFFER - }; - - enum ggml_log_level { - GGML_LOG_LEVEL_ERROR = 2, - GGML_LOG_LEVEL_WARN = 3, - GGML_LOG_LEVEL_INFO = 4 - }; - - // ggml object - struct ggml_object { - size_t offs; - size_t size; - - struct ggml_object * next; - - enum ggml_object_type type; - - char padding[4]; - }; - - static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); - - // n-dimensional tensor - struct ggml_tensor { - enum ggml_type type; - enum ggml_backend_type backend; - - struct ggml_backend_buffer * buffer; - - int n_dims; - int64_t ne[GGML_MAX_DIMS]; // number of elements - size_t nb[GGML_MAX_DIMS]; // stride in bytes: - // nb[0] = ggml_type_size(type) - // nb[1] = nb[0] * (ne[0] / ggml_blck_size(type)) + padding - // nb[i] = nb[i-1] * ne[i-1] - - // compute data - enum ggml_op op; - - // op params - allocated as int32_t for alignment - int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; - - bool is_param; - - struct ggml_tensor * grad; - struct ggml_tensor * src[GGML_MAX_SRC]; - - // performance - int perf_runs; - int64_t perf_cycles; - int64_t perf_time_us; - - struct ggml_tensor * view_src; - size_t view_offs; - - void * data; - - char name[GGML_MAX_NAME]; - - void * extra; // extra things e.g. for ggml-cuda.cu - - char padding[12]; - }; - - static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); - - // the compute plan that needs to be prepared for ggml_graph_compute() - // since https://github.com/ggerganov/ggml/issues/287 - struct ggml_cplan { - size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()` - uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` - - int n_threads; - - // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes - int n_tasks[GGML_MAX_NODES]; - - // abort ggml_graph_compute when true - bool (*abort_callback)(void * data); - void * abort_callback_data; - }; - - // next prime after GGML_MAX_NODES - // #define GGML_GRAPH_HASHTABLE_SIZE 4099 - // next prime after GGML_MAX_NODES * 2 (nodes + leafs) - // #define GGML_GRAPH_HASHTABLE_SIZE 8273 - // #define GGML_GRAPH_HASHTABLE_SIZE 16411 - #define GGML_GRAPH_HASHTABLE_SIZE 32771 - - enum ggml_cgraph_eval_order { - GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0, - GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT, - GGML_CGRAPH_EVAL_ORDER_COUNT - }; - - // computation graph - struct ggml_cgraph { - int n_nodes; - int n_leafs; - - struct ggml_tensor * nodes[GGML_MAX_NODES]; - struct ggml_tensor * grads[GGML_MAX_NODES]; - struct ggml_tensor * leafs[GGML_MAX_NODES]; - - void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE]; - - enum ggml_cgraph_eval_order order; - - // performance - int perf_runs; - int64_t perf_cycles; - int64_t perf_time_us; - }; - - static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph); - - // scratch buffer - struct ggml_scratch { - size_t offs; - size_t size; - void * data; - }; - - struct ggml_init_params { - // memory pool - size_t mem_size; // bytes - void * mem_buffer; // if NULL, memory will be allocated internally - bool no_alloc; // don't allocate memory for the tensor data - }; - - - // compute types - - // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled. - // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995. - enum ggml_task_type { - GGML_TASK_INIT = 0, - GGML_TASK_COMPUTE, - GGML_TASK_FINALIZE, - }; - - struct ggml_compute_params { - enum ggml_task_type type; - - // ith = thread index, nth = number of threads - int ith, nth; - - // work buffer for all threads - size_t wsize; - void * wdata; - }; - - // misc - - GGML_API void ggml_time_init(void); // call this once at the beginning of the program - GGML_API int64_t ggml_time_ms(void); - GGML_API int64_t ggml_time_us(void); - GGML_API int64_t ggml_cycles(void); - GGML_API int64_t ggml_cycles_per_ms(void); - - GGML_API void ggml_numa_init(void); // call once for better performance on NUMA systems - GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node - - GGML_API void ggml_print_object (const struct ggml_object * obj); - GGML_API void ggml_print_objects(const struct ggml_context * ctx); - - GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor); - GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor); - GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor); - GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN - GGML_API size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split); - - GGML_API int ggml_blck_size (enum ggml_type type); - GGML_API size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block - GGML_API float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float - - GGML_API const char * ggml_type_name(enum ggml_type type); - GGML_API const char * ggml_op_name (enum ggml_op op); - GGML_API const char * ggml_op_symbol(enum ggml_op op); - - GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); - - GGML_API bool ggml_is_quantized(enum ggml_type type); - - // TODO: temporary until model loading of ggml examples is refactored - GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype); - - GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); - GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor); - GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor); - - GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1); - - // use this to compute the memory overhead of a tensor - GGML_API size_t ggml_tensor_overhead(void); - - // main - - GGML_API struct ggml_context * ggml_init(struct ggml_init_params params); - GGML_API void ggml_free(struct ggml_context * ctx); - - GGML_API size_t ggml_used_mem(const struct ggml_context * ctx); - - GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch); - GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx); - GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc); - - GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx); - GGML_API size_t ggml_get_mem_size (const struct ggml_context * ctx); - GGML_API size_t ggml_get_max_tensor_size(const struct ggml_context * ctx); - - GGML_API struct ggml_tensor * ggml_new_tensor( - struct ggml_context * ctx, - enum ggml_type type, - int n_dims, - const int64_t *ne); - - GGML_API struct ggml_tensor * ggml_new_tensor_1d( - struct ggml_context * ctx, - enum ggml_type type, - int64_t ne0); - - GGML_API struct ggml_tensor * ggml_new_tensor_2d( - struct ggml_context * ctx, - enum ggml_type type, - int64_t ne0, - int64_t ne1); - - GGML_API struct ggml_tensor * ggml_new_tensor_3d( - struct ggml_context * ctx, - enum ggml_type type, - int64_t ne0, - int64_t ne1, - int64_t ne2); - - GGML_API struct ggml_tensor * ggml_new_tensor_4d( - struct ggml_context * ctx, - enum ggml_type type, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int64_t ne3); - - GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); - GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); - - GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); - GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src); - - // Context tensor enumeration and lookup - GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx); - GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor); - GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); - - GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); - GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); - GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); - - // Converts a flat index into coordinates - GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); - - GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); - GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); - - GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); - GGML_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value); - - GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); - GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); - - GGML_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); - GGML_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); - - GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); - GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); - - GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor); - - GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor); - GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name); - GGML_ATTRIBUTE_FORMAT(2, 3) - GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...); - - // - // operations on tensors with backpropagation - // - - GGML_API struct ggml_tensor * ggml_dup( - struct ggml_context * ctx, - struct ggml_tensor * a); - - // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_dup_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_add( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_add_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_add_cast( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - enum ggml_type type); - - GGML_API struct ggml_tensor * ggml_add1( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_add1_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_acc( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset); - - GGML_API struct ggml_tensor * ggml_acc_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset); - - GGML_API struct ggml_tensor * ggml_sub( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_sub_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_mul( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_mul_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_div( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_div_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_sqr( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_sqr_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_sqrt( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_sqrt_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_log( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_log_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - // return scalar - GGML_API struct ggml_tensor * ggml_sum( - struct ggml_context * ctx, - struct ggml_tensor * a); - - // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d] - GGML_API struct ggml_tensor * ggml_sum_rows( - struct ggml_context * ctx, - struct ggml_tensor * a); - - // mean along rows - GGML_API struct ggml_tensor * ggml_mean( - struct ggml_context * ctx, - struct ggml_tensor * a); - - // argmax along rows - GGML_API struct ggml_tensor * ggml_argmax( - struct ggml_context * ctx, - struct ggml_tensor * a); - - // if a is the same shape as b, and a is not parameter, return a - // otherwise, return a new tensor: repeat(a) to fit in b - GGML_API struct ggml_tensor * ggml_repeat( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // sums repetitions in a into shape of b - GGML_API struct ggml_tensor * ggml_repeat_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // concat a and b on dim 2 - // used in stable-diffusion - GGML_API struct ggml_tensor * ggml_concat( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_abs( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_abs_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_sgn( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_sgn_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_neg( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_neg_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_step( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_step_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_tanh( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_tanh_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_elu( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_elu_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_relu( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_relu_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - // TODO: double-check this computation is correct - GGML_API struct ggml_tensor * ggml_gelu( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_gelu_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_gelu_quick( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_gelu_quick_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_silu( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_silu_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - // a - x - // b - dy - GGML_API struct ggml_tensor * ggml_silu_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // normalize along rows - GGML_API struct ggml_tensor * ggml_norm( - struct ggml_context * ctx, - struct ggml_tensor * a, - float eps); - - GGML_API struct ggml_tensor * ggml_norm_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - float eps); - - GGML_API struct ggml_tensor * ggml_rms_norm( - struct ggml_context * ctx, - struct ggml_tensor * a, - float eps); - - GGML_API struct ggml_tensor * ggml_rms_norm_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - float eps); - - // group normalize along ne0*ne1*n_groups - // used in stable-diffusion - // TODO: eps is hardcoded to 1e-6 for now - GGML_API struct ggml_tensor * ggml_group_norm( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_groups); - - GGML_API struct ggml_tensor * ggml_group_norm_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_groups); - - // a - x - // b - dy - GGML_API struct ggml_tensor * ggml_rms_norm_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - float eps); - - // A: n columns, m rows - // B: n columns, p rows (i.e. we transpose it internally) - // result is m columns, p rows - GGML_API struct ggml_tensor * ggml_mul_mat( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // A: m columns, n rows, - // B: p columns, n rows, - // result is m columns, p rows - GGML_API struct ggml_tensor * ggml_out_prod( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // - // operations on tensors without backpropagation - // - - GGML_API struct ggml_tensor * ggml_scale( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_scale_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // b -> view(a,offset,nb1,nb2,3), return modified a - GGML_API struct ggml_tensor * ggml_set( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset); - - // b -> view(a,offset,nb1,nb2,3), return view(a) - GGML_API struct ggml_tensor * ggml_set_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset); - - GGML_API struct ggml_tensor * ggml_set_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t offset); - - GGML_API struct ggml_tensor * ggml_set_1d_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t offset); - - // b -> view(a,offset,nb1,nb2,3), return modified a - GGML_API struct ggml_tensor * ggml_set_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t offset); - - // b -> view(a,offset,nb1,nb2,3), return view(a) - GGML_API struct ggml_tensor * ggml_set_2d_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t offset); - - // a -> b, return view(b) - GGML_API struct ggml_tensor * ggml_cpy( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // a -> b, in-place, return view(b) - GGML_API struct ggml_tensor * ggml_cpy_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // make contiguous - GGML_API struct ggml_tensor * ggml_cont( - struct ggml_context * ctx, - struct ggml_tensor * a); - - // make contiguous, in-place - GGML_API struct ggml_tensor * ggml_cont_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - // make contiguous, with new shape - GGML_API struct ggml_tensor * ggml_cont_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0); - - GGML_API struct ggml_tensor * ggml_cont_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1); - - GGML_API struct ggml_tensor * ggml_cont_3d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2); - - GGML_API struct ggml_tensor * ggml_cont_4d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int64_t ne3); - - // return view(a), b specifies the new shape - // TODO: when we start computing gradient, make a copy instead of view - GGML_API struct ggml_tensor * ggml_reshape( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // return view(a) - // TODO: when we start computing gradient, make a copy instead of view - GGML_API struct ggml_tensor * ggml_reshape_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0); - - GGML_API struct ggml_tensor * ggml_reshape_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1); - - // return view(a) - // TODO: when we start computing gradient, make a copy instead of view - GGML_API struct ggml_tensor * ggml_reshape_3d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2); - - GGML_API struct ggml_tensor * ggml_reshape_4d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int64_t ne3); - - // offset in bytes - GGML_API struct ggml_tensor * ggml_view_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - size_t offset); - - GGML_API struct ggml_tensor * ggml_view_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - size_t nb1, // row stride in bytes - size_t offset); - - GGML_API struct ggml_tensor * ggml_view_3d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2, - size_t nb1, // row stride in bytes - size_t nb2, // slice stride in bytes - size_t offset); - - GGML_API struct ggml_tensor * ggml_view_4d( - struct ggml_context * ctx, - struct ggml_tensor * a, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int64_t ne3, - size_t nb1, // row stride in bytes - size_t nb2, // slice stride in bytes - size_t nb3, - size_t offset); - - GGML_API struct ggml_tensor * ggml_permute( - struct ggml_context * ctx, - struct ggml_tensor * a, - int axis0, - int axis1, - int axis2, - int axis3); - - // alias for ggml_permute(ctx, a, 1, 0, 2, 3) - GGML_API struct ggml_tensor * ggml_transpose( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_get_rows( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_get_rows_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c); - - GGML_API struct ggml_tensor * ggml_diag( - struct ggml_context * ctx, - struct ggml_tensor * a); - - // set elements above the diagonal to -INF - GGML_API struct ggml_tensor * ggml_diag_mask_inf( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past); - - // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past); - - // set elements above the diagonal to 0 - GGML_API struct ggml_tensor * ggml_diag_mask_zero( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past); - - // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past); - - GGML_API struct ggml_tensor * ggml_soft_max( - struct ggml_context * ctx, - struct ggml_tensor * a); - - // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_soft_max_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a); - - GGML_API struct ggml_tensor * ggml_soft_max_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_soft_max_back_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // rotary position embedding - // if mode & 1 == 1, skip n_past elements (DEPRECATED) - // if mode & 2 == 1, GPT-NeoX style - // if mode & 4 == 1, ChatGLM style - // - // b is an int32 vector with size a->ne[2], it contains the positions - GGML_API struct ggml_tensor * ggml_rope( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - int mode, - int n_ctx); - - // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_rope_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - int mode, - int n_ctx); - - // custom RoPE - GGML_API struct ggml_tensor * ggml_rope_custom( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - int mode, - int n_ctx, - float freq_base, - float freq_scale); - - // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_rope_custom_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - int mode, - int n_ctx, - float freq_base, - float freq_scale); - - // xPos RoPE, in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - float base, - bool down); - - // rotary position embedding backward, i.e compute dx from dy - // a - dy - GGML_API struct ggml_tensor * ggml_rope_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - int mode, - int n_ctx, - float freq_base, - float freq_scale, - float xpos_base, - bool xpos_down); - - // alibi position embedding - // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_alibi( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past, - int n_head, - float bias_max); - - // clamp - // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_clamp( - struct ggml_context * ctx, - struct ggml_tensor * a, - float min, - float max); - - GGML_API struct ggml_tensor * ggml_conv_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, // stride - int p0, // padding - int d0); // dilation - - // conv_1d with padding = half - // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d) - GGML_API struct ggml_tensor* ggml_conv_1d_ph( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s, - int d); - - GGML_API struct ggml_tensor * ggml_conv_transpose_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int p0, - int d0); - - GGML_API struct ggml_tensor * ggml_conv_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1); - - - // kernel size is a->ne[0] x a->ne[1] - // stride is equal to kernel size - // padding is zero - // example: - // a: 16 16 3 768 - // b: 1024 1024 3 1 - // res: 64 64 768 1 - // used in sam - GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - // kernel size is a->ne[0] x a->ne[1] - // stride is 1 - // padding is half - // example: - // a: 3 3 256 256 - // b: 64 64 256 1 - // res: 64 64 256 1 - // used in sam - GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int stride); - - enum ggml_op_pool { - GGML_OP_POOL_MAX, - GGML_OP_POOL_AVG, - GGML_OP_POOL_COUNT, - }; - - GGML_API struct ggml_tensor * ggml_pool_1d( - struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_op_pool op, - int k0, // kernel size - int s0, // stride - int p0); // padding - - GGML_API struct ggml_tensor * ggml_pool_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_op_pool op, - int k0, - int k1, - int s0, - int s1, - int p0, - int p1); - - // nearest interpolate - // used in stable-diffusion - GGML_API struct ggml_tensor * ggml_upscale( - struct ggml_context * ctx, - struct ggml_tensor * a, - int scale_factor); - - GGML_API struct ggml_tensor * ggml_flash_attn( - struct ggml_context * ctx, - struct ggml_tensor * q, - struct ggml_tensor * k, - struct ggml_tensor * v, - bool masked); - - GGML_API struct ggml_tensor * ggml_flash_attn_back( - struct ggml_context * ctx, - struct ggml_tensor * q, - struct ggml_tensor * k, - struct ggml_tensor * v, - struct ggml_tensor * d, - bool masked); - - GGML_API struct ggml_tensor * ggml_flash_ff( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b0, - struct ggml_tensor * b1, - struct ggml_tensor * c0, - struct ggml_tensor * c1); - - // partition into non-overlapping windows with padding if needed - // example: - // a: 768 64 64 1 - // w: 14 - // res: 768 14 14 25 - // used in sam - GGML_API struct ggml_tensor * ggml_win_part( - struct ggml_context * ctx, - struct ggml_tensor * a, - int w); - - // reverse of ggml_win_part - // used in sam - GGML_API struct ggml_tensor * ggml_win_unpart( - struct ggml_context * ctx, - struct ggml_tensor * a, - int w0, - int h0, - int w); - - GGML_API struct ggml_tensor * ggml_unary( - struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_unary_op op); - - GGML_API struct ggml_tensor * ggml_unary_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_unary_op op); - - // used in sam - GGML_API struct ggml_tensor * ggml_get_rel_pos( - struct ggml_context * ctx, - struct ggml_tensor * a, - int qh, - int kh); - - // used in sam - - GGML_API struct ggml_tensor * ggml_add_rel_pos( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * pw, - struct ggml_tensor * ph); - - GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * pw, - struct ggml_tensor * ph); - - // custom operators - - typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); - typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *); - - typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *); - typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); - typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_unary_op_f32_t fun), - "use ggml_map_custom1 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_unary_op_f32_t fun), - "use ggml_map_custom1_inplace instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_binary_op_f32_t fun), - "use ggml_map_custom2 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_binary_op_f32_t fun), - "use ggml_map_custom2_inplace instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_custom1_op_f32_t fun), - "use ggml_map_custom1 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_custom1_op_f32_t fun), - "use ggml_map_custom1_inplace instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_custom2_op_f32_t fun), - "use ggml_map_custom2 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_custom2_op_f32_t fun), - "use ggml_map_custom2_inplace instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - ggml_custom3_op_f32_t fun), - "use ggml_map_custom3 instead"); - - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - ggml_custom3_op_f32_t fun), - "use ggml_map_custom3_inplace instead"); - - // custom operators v2 - - typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); - typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata); - typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata); - - #define GGML_N_TASKS_MAX -1 - - GGML_API struct ggml_tensor * ggml_map_custom1( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_custom1_op_t fun, - int n_tasks, - void * userdata); - - GGML_API struct ggml_tensor * ggml_map_custom1_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - ggml_custom1_op_t fun, - int n_tasks, - void * userdata); - - GGML_API struct ggml_tensor * ggml_map_custom2( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_custom2_op_t fun, - int n_tasks, - void * userdata); - - GGML_API struct ggml_tensor * ggml_map_custom2_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - ggml_custom2_op_t fun, - int n_tasks, - void * userdata); - - GGML_API struct ggml_tensor * ggml_map_custom3( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - ggml_custom3_op_t fun, - int n_tasks, - void * userdata); - - GGML_API struct ggml_tensor * ggml_map_custom3_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - ggml_custom3_op_t fun, - int n_tasks, - void * userdata); - - // loss function - - GGML_API struct ggml_tensor * ggml_cross_entropy_loss( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); - - GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c); - - // - // automatic differentiation - // - - GGML_API void ggml_set_param( - struct ggml_context * ctx, - struct ggml_tensor * tensor); - - - GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); - GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep); - - GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor); - GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep); - - // graph allocation in a context - GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); - GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor); - GGML_API size_t ggml_graph_overhead(void); - - // ggml_graph_plan() has to be called before ggml_graph_compute() - // when plan.work_size > 0, caller must allocate memory for plan.work_data - GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/); - GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); - GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); - - // same as ggml_graph_compute() but the work data is allocated as a part of the context - // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data - GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads); - - GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name); - - GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname); - GGML_API struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval); - - // print info and performance information for the graph - GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); - - // dump the graph into a file using the dot format - GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); - - // build gradient checkpointing backward graph gb for gf using provided checkpoints - // gb_tmp will contain original backward graph with rewritten backward process nodes, - // but without the second forward pass nodes. - GGML_API void ggml_build_backward_gradient_checkpointing( - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - struct ggml_cgraph * gb_tmp, - struct ggml_tensor * * checkpoints, - int n_checkpoints); - // - // optimization - // - - // optimization methods - enum ggml_opt_type { - GGML_OPT_ADAM, - GGML_OPT_LBFGS, - }; - - // linesearch methods - enum ggml_linesearch { - GGML_LINESEARCH_DEFAULT = 1, - - GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, - GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, - GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, - }; - - // optimization return values - enum ggml_opt_result { - GGML_OPT_OK = 0, - GGML_OPT_DID_NOT_CONVERGE, - GGML_OPT_NO_CONTEXT, - GGML_OPT_INVALID_WOLFE, - GGML_OPT_FAIL, - GGML_OPT_CANCEL, - - GGML_LINESEARCH_FAIL = -128, - GGML_LINESEARCH_MINIMUM_STEP, - GGML_LINESEARCH_MAXIMUM_STEP, - GGML_LINESEARCH_MAXIMUM_ITERATIONS, - GGML_LINESEARCH_INVALID_PARAMETERS, - }; - - typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel); - typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); - - // optimization parameters - // - // see ggml.c (ggml_opt_default_params) for default values - // - struct ggml_opt_params { - enum ggml_opt_type type; - - int n_threads; - - // delta-based convergence test - // - // if past == 0 - disabled - // if past > 0: - // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) - // - int past; - float delta; - - // maximum number of iterations without improvement - // - // if 0 - disabled - // if > 0: - // assume convergence if no cost improvement in this number of iterations - // - int max_no_improvement; - - bool print_forward_graph; - bool print_backward_graph; - - int n_gradient_accumulation; - - // ADAM parameters - struct { - int n_iter; - - float sched; // schedule multiplier (fixed, decay or warmup) - float decay; // weight decay for AdamW, use 0.0f to disable - int decay_min_ndim; // minimum number of tensor dimension to apply weight decay - float alpha; // learning rate - float beta1; - float beta2; - float eps; // epsilon for numerical stability - float eps_f; // epsilon for convergence test - float eps_g; // epsilon for convergence test - float gclip; // gradient clipping - } adam; - - // LBFGS parameters - struct { - int m; // number of corrections to approximate the inv. Hessian - int n_iter; - int max_linesearch; - - float eps; // convergence tolerance - float ftol; // line search tolerance - float wolfe; - float min_step; - float max_step; - - enum ggml_linesearch linesearch; - } lbfgs; - }; - - struct ggml_opt_context { - struct ggml_context * ctx; - struct ggml_opt_params params; - - int iter; - int64_t nx; // number of parameter elements - - bool just_initialized; - - float loss_before; - float loss_after; - - struct { - struct ggml_tensor * g; // current gradient - struct ggml_tensor * m; // first moment - struct ggml_tensor * v; // second moment - struct ggml_tensor * pf; // past function values - float fx_best; - float fx_prev; - int n_no_improvement; - } adam; - - struct { - struct ggml_tensor * x; // current parameters - struct ggml_tensor * xp; // previous parameters - struct ggml_tensor * g; // current gradient - struct ggml_tensor * gp; // previous gradient - struct ggml_tensor * d; // search direction - struct ggml_tensor * pf; // past function values - struct ggml_tensor * lmal; // the L-BFGS memory alpha - struct ggml_tensor * lmys; // the L-BFGS memory ys - struct ggml_tensor * lms; // the L-BFGS memory s - struct ggml_tensor * lmy; // the L-BFGS memory y - float fx_best; - float step; - int j; - int k; - int end; - int n_no_improvement; - } lbfgs; - }; - - GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type); - - // optimize the function defined by the tensor f - GGML_API enum ggml_opt_result ggml_opt( - struct ggml_context * ctx, - struct ggml_opt_params params, - struct ggml_tensor * f); - - // initialize optimizer context - GGML_API void ggml_opt_init( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_opt_params params, - int64_t nx); - - // continue optimizing the function defined by the tensor f - GGML_API enum ggml_opt_result ggml_opt_resume( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_tensor * f); - - // continue optimizing the function defined by the tensor f - GGML_API enum ggml_opt_result ggml_opt_resume_g( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - ggml_opt_callback callback, - void * callback_data); - - // - // quantization - // - - GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); - GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); - GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); - GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); - GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); - - GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); - - // - // gguf - // - - enum gguf_type { - GGUF_TYPE_UINT8 = 0, - GGUF_TYPE_INT8 = 1, - GGUF_TYPE_UINT16 = 2, - GGUF_TYPE_INT16 = 3, - GGUF_TYPE_UINT32 = 4, - GGUF_TYPE_INT32 = 5, - GGUF_TYPE_FLOAT32 = 6, - GGUF_TYPE_BOOL = 7, - GGUF_TYPE_STRING = 8, - GGUF_TYPE_ARRAY = 9, - GGUF_TYPE_UINT64 = 10, - GGUF_TYPE_INT64 = 11, - GGUF_TYPE_FLOAT64 = 12, - GGUF_TYPE_COUNT, // marks the end of the enum - }; - - struct gguf_context; - - struct gguf_init_params { - bool no_alloc; - - // if not NULL, create a ggml_context and allocate the tensor data in it - struct ggml_context ** ctx; - }; - - GGML_API struct gguf_context * gguf_init_empty(void); - GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); - //GGML_API struct gguf_context * gguf_init_from_buffer(..); - - GGML_API void gguf_free(struct gguf_context * ctx); - - GGML_API const char * gguf_type_name(enum gguf_type type); - - GGML_API int gguf_get_version (const struct gguf_context * ctx); - GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); - GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); - GGML_API void * gguf_get_data (const struct gguf_context * ctx); - - GGML_API int gguf_get_n_kv(const struct gguf_context * ctx); - GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key); - GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id); - - GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id); - GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id); - - // will abort if the wrong type is used for the key - GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int key_id); - GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int key_id); - GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int key_id); - GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int key_id); - GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int key_id); - GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int key_id); - GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int key_id); - GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int key_id); - GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key_id); - GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id); - GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id); - GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id); - GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id); - GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id); - GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i); - - GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx); - GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name); - GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i); - GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i); - - // overrides existing values or adds a new one - GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); - GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val); - GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val); - GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val); - GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val); - GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val); - GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val); - GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val); - GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val); - GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val); - GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val); - GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val); - GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n); - GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n); - - // set or add KV pairs from another context - GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src); - - // manage tensor info - GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor); - GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type); - GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size); - - // writing gguf files can be done in 2 ways: - // - // - write the entire gguf_context to a binary file in a single pass: - // - // gguf_write_to_file(ctx, fname); - // - // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data: - // - // FILE * f = fopen(fname, "wb"); - // fseek(f, gguf_get_meta_size(ctx), SEEK_SET); - // fwrite(f, ...); - // void * data = gguf_meta_get_meta_data(ctx); - // fseek(f, 0, SEEK_SET); - // fwrite(f, data, gguf_get_meta_size(ctx)); - // free(data); - // fclose(f); - // - - // write the entire context to a binary file - GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); - - // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding - GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx); - GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data); - - // - // system info - // - - GGML_API int ggml_cpu_has_avx (void); - GGML_API int ggml_cpu_has_avx2 (void); - GGML_API int ggml_cpu_has_avx512 (void); - GGML_API int ggml_cpu_has_avx512_vbmi(void); - GGML_API int ggml_cpu_has_avx512_vnni(void); - GGML_API int ggml_cpu_has_fma (void); - GGML_API int ggml_cpu_has_neon (void); - GGML_API int ggml_cpu_has_arm_fma (void); - GGML_API int ggml_cpu_has_metal (void); - GGML_API int ggml_cpu_has_f16c (void); - GGML_API int ggml_cpu_has_fp16_va (void); - GGML_API int ggml_cpu_has_wasm_simd (void); - GGML_API int ggml_cpu_has_blas (void); - GGML_API int ggml_cpu_has_cublas (void); - GGML_API int ggml_cpu_has_clblast (void); - GGML_API int ggml_cpu_has_gpublas (void); - GGML_API int ggml_cpu_has_sse3 (void); - GGML_API int ggml_cpu_has_ssse3 (void); - GGML_API int ggml_cpu_has_vsx (void); - - // - // Internal types and functions exposed for tests and benchmarks - // - -#ifdef __cplusplus -// restrict not standard in C++ -#define GGML_RESTRICT -#else -#define GGML_RESTRICT restrict -#endif - typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); - typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); - typedef void (*ggml_vec_dot_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y); - - typedef struct { - const char * type_name; - int blck_size; - size_t type_size; - bool is_quantized; - ggml_to_float_t to_float; - ggml_from_float_t from_float; - ggml_from_float_t from_float_reference; - ggml_vec_dot_t vec_dot; - enum ggml_type vec_dot_type; - } ggml_type_traits_t; - - GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); - -#ifdef __cplusplus -} -#endif diff --git a/plugins/wasi_nn/thirdparty/ggml/k_quants.c b/plugins/wasi_nn/thirdparty/ggml/k_quants.c deleted file mode 100644 index 558f5fda..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/k_quants.c +++ /dev/null @@ -1,5060 +0,0 @@ -#include "k_quants.h" -#include "ggml.h" - -#include -#include -#include - -#ifdef __ARM_NEON - -// if YCM cannot find , make a symbolic link to it, for example: -// -// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ -// -#include - -#if !defined(__aarch64__) -inline static int32_t vaddvq_s16(int16x8_t v) { - return - (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + - (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + - (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + - (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); -} - -inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { - int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); - int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); - return vcombine_s16(a0, b0); -} - -inline static int32_t vaddvq_s32(int32x4_t v) { - return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); -} -#endif - -#else - -#ifdef __wasm_simd128__ -#include -#else -#ifdef __POWER9_VECTOR__ -#include -#undef bool -#define bool _Bool -#else -#if defined(_MSC_VER) || defined(__MINGW32__) -#include -#else -#if !defined(__riscv) -#include -#endif -#endif -#endif -#endif -#endif - -#ifdef __riscv_v_intrinsic -#include -#endif - -#undef MIN -#undef MAX -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) - -// -// 2-6 bit quantization in super-blocks -// - -// -// ===================== Helper functions -// -static inline int nearest_int(float fval) { - assert(fval <= 4194303.f); - float val = fval + 12582912.f; - int i; memcpy(&i, &val, sizeof(int)); - return (i & 0x007fffff) - 0x00400000; -} - -static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) { - float max = 0; - float amax = 0; - for (int i = 0; i < n; ++i) { - float ax = fabsf(x[i]); - if (ax > amax) { amax = ax; max = x[i]; } - } - if (amax < 1e-30f) { // all zero - for (int i = 0; i < n; ++i) { - L[i] = 0; - } - return 0.f; - } - float iscale = -nmax / max; - if (rmse_type == 0) { - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale * x[i]); - L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); - } - return 1/iscale; - } - bool return_early = false; - if (rmse_type < 0) { - rmse_type = -rmse_type; - return_early = true; - } - int weight_type = rmse_type%2; - float sumlx = 0; - float suml2 = 0; - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale * x[i]); - l = MAX(-nmax, MIN(nmax-1, l)); - L[i] = l + nmax; - float w = weight_type == 1 ? x[i] * x[i] : 1; - sumlx += w*x[i]*l; - suml2 += w*l*l; - } - float scale = sumlx/suml2; - if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale; - float best = scale * sumlx; - for (int is = -9; is <= 9; ++is) { - if (is == 0) { - continue; - } - iscale = -(nmax + 0.1f*is) / max; - sumlx = suml2 = 0; - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale * x[i]); - l = MAX(-nmax, MIN(nmax-1, l)); - float w = weight_type == 1 ? x[i] * x[i] : 1; - sumlx += w*x[i]*l; - suml2 += w*l*l; - } - if (suml2 > 0 && sumlx*sumlx > best*suml2) { - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale * x[i]); - L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); - } - scale = sumlx/suml2; best = scale*sumlx; - } - } - return scale; -} - -static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) { - float max = 0; - float amax = 0; - for (int i = 0; i < n; ++i) { - float ax = fabsf(x[i]); - if (ax > amax) { amax = ax; max = x[i]; } - } - if (!amax) { // all zero - for (int i = 0; i < n; ++i) { L[i] = 0; } - return 0.f; - } - float iscale = -nmax / max; - if (do_rmse) { - float sumlx = 0; - float suml2 = 0; - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale * x[i]); - l = MAX(-nmax, MIN(nmax-1, l)); - L[i] = l; - float w = x[i]*x[i]; - sumlx += w*x[i]*l; - suml2 += w*l*l; - } - for (int itry = 0; itry < 5; ++itry) { - int n_changed = 0; - for (int i = 0; i < n; ++i) { - float w = x[i]*x[i]; - float slx = sumlx - w*x[i]*L[i]; - if (slx > 0) { - float sl2 = suml2 - w*L[i]*L[i]; - int new_l = nearest_int(x[i] * sl2 / slx); - new_l = MAX(-nmax, MIN(nmax-1, new_l)); - if (new_l != L[i]) { - slx += w*x[i]*new_l; - sl2 += w*new_l*new_l; - if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) { - L[i] = new_l; sumlx = slx; suml2 = sl2; - ++n_changed; - } - } - } - } - if (!n_changed) { - break; - } - } - for (int i = 0; i < n; ++i) { - L[i] += nmax; - } - return sumlx / suml2; - } - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale * x[i]); - l = MAX(-nmax, MIN(nmax-1, l)); - L[i] = l + nmax; - } - return 1/iscale; -} - -static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, - int ntry, float alpha) { - float min = x[0]; - float max = x[0]; - for (int i = 1; i < n; ++i) { - if (x[i] < min) min = x[i]; - if (x[i] > max) max = x[i]; - } - if (max == min) { - for (int i = 0; i < n; ++i) L[i] = 0; - *the_min = 0; - return 0.f; - } - if (min > 0) min = 0; - float iscale = nmax/(max - min); - float scale = 1/iscale; - for (int itry = 0; itry < ntry; ++itry) { - float sumlx = 0; int suml2 = 0; - bool did_change = false; - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale*(x[i] - min)); - l = MAX(0, MIN(nmax, l)); - if (l != L[i]) { - L[i] = l; - did_change = true; - } - sumlx += (x[i] - min)*l; - suml2 += l*l; - } - scale = sumlx/suml2; - float sum = 0; - for (int i = 0; i < n; ++i) { - sum += x[i] - scale*L[i]; - } - min = alpha*min + (1 - alpha)*sum/n; - if (min > 0) min = 0; - iscale = 1/scale; - if (!did_change) break; - } - *the_min = -min; - return scale; -} - -static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights, - uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux, - float rmin, float rdelta, int nstep, bool use_mad) { - float min = x[0]; - float max = x[0]; - float sum_w = weights[0]; - float sum_x = sum_w * x[0]; - for (int i = 1; i < n; ++i) { - if (x[i] < min) min = x[i]; - if (x[i] > max) max = x[i]; - float w = weights[i]; - sum_w += w; - sum_x += w * x[i]; - } - if (min > 0) min = 0; - if (max == min) { - for (int i = 0; i < n; ++i) L[i] = 0; - *the_min = -min; - return 0.f; - } - float iscale = nmax/(max - min); - float scale = 1/iscale; - float best_mad = 0; - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale*(x[i] - min)); - L[i] = MAX(0, MIN(nmax, l)); - float diff = scale * L[i] + min - x[i]; - diff = use_mad ? fabsf(diff) : diff * diff; - float w = weights[i]; - best_mad += w * diff; - } - if (nstep < 1) { - *the_min = -min; - return scale; - } - for (int is = 0; is <= nstep; ++is) { - iscale = (rmin + rdelta*is + nmax)/(max - min); - float sum_l = 0, sum_l2 = 0, sum_xl = 0; - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale*(x[i] - min)); - l = MAX(0, MIN(nmax, l)); - Laux[i] = l; - float w = weights[i]; - sum_l += w*l; - sum_l2 += w*l*l; - sum_xl += w*l*x[i]; - } - float D = sum_w * sum_l2 - sum_l * sum_l; - if (D > 0) { - float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; - float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; - if (this_min > 0) { - this_min = 0; - this_scale = sum_xl / sum_l2; - } - float mad = 0; - for (int i = 0; i < n; ++i) { - float diff = this_scale * Laux[i] + this_min - x[i]; - diff = use_mad ? fabsf(diff) : diff * diff; - float w = weights[i]; - mad += w * diff; - } - if (mad < best_mad) { - for (int i = 0; i < n; ++i) { - L[i] = Laux[i]; - } - best_mad = mad; - scale = this_scale; - min = this_min; - } - } - } - *the_min = -min; - return scale; -} - -#if QK_K == 256 -static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { - if (j < 4) { - *d = q[j] & 63; *m = q[j + 4] & 63; - } else { - *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); - *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); - } -} -#endif - -//========================- 2-bit (de)-quantization - -void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - uint8_t L[QK_K]; - uint8_t Laux[16]; - float weights[16]; - float mins[QK_K/16]; - float scales[QK_K/16]; - - const float q4scale = 15.f; - - for (int i = 0; i < nb; i++) { - float max_scale = 0; // as we are deducting the min, scales are always positive - float max_min = 0; - for (int j = 0; j < QK_K/16; ++j) { - for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]); - scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true); - float scale = scales[j]; - if (scale > max_scale) { - max_scale = scale; - } - float min = mins[j]; - if (min > max_min) { - max_min = min; - } - } - - if (max_scale > 0) { - float iscale = q4scale/max_scale; - for (int j = 0; j < QK_K/16; ++j) { - int l = nearest_int(iscale*scales[j]); - y[i].scales[j] = l; - } - y[i].d = ggml_fp32_to_fp16(max_scale/q4scale); - } else { - for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0; - y[i].d = ggml_fp32_to_fp16(0.f); - } - if (max_min > 0) { - float iscale = q4scale/max_min; - for (int j = 0; j < QK_K/16; ++j) { - int l = nearest_int(iscale*mins[j]); - y[i].scales[j] |= (l << 4); - } - y[i].dmin = ggml_fp32_to_fp16(max_min/q4scale); - } else { - y[i].dmin = ggml_fp32_to_fp16(0.f); - } - for (int j = 0; j < QK_K/16; ++j) { - const float d = ggml_fp16_to_fp32(y[i].d) * (y[i].scales[j] & 0xF); - if (!d) continue; - const float dm = ggml_fp16_to_fp32(y[i].dmin) * (y[i].scales[j] >> 4); - for (int ii = 0; ii < 16; ++ii) { - int l = nearest_int((x[16*j + ii] + dm)/d); - l = MAX(0, MIN(3, l)); - L[16*j + ii] = l; - } - } - -#if QK_K == 256 - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) { - y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); - } - } -#else - for (int l = 0; l < 16; ++l) { - y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); - } -#endif - - x += QK_K; - - } -} - -void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const float d = ggml_fp16_to_fp32(x[i].d); - const float min = ggml_fp16_to_fp32(x[i].dmin); - - const uint8_t * q = x[i].qs; - -#if QK_K == 256 - int is = 0; - float dl, ml; - for (int n = 0; n < QK_K; n += 128) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - - uint8_t sc = x[i].scales[is++]; - dl = d * (sc & 0xF); ml = min * (sc >> 4); - for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml; - - sc = x[i].scales[is++]; - dl = d * (sc & 0xF); ml = min * (sc >> 4); - for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml; - - shift += 2; - } - q += 32; - } -#else - float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4); - float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4); - float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4); - float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4); - for (int l = 0; l < 16; ++l) { - y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1; - y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2; - y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3; - y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4; - } - y += QK_K; -#endif - } -} - -void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) { - quantize_row_q2_K_reference(x, vy, k); -} - -size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - const int nb = k / QK_K; - - // TODO - collect histograms - although, at a second thought, I don't really care about them - (void)hist; - - for (int j = 0; j < nb; j += k) { - block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K; - quantize_row_q2_K_reference(src + j, y, k); - } - return (n/QK_K*sizeof(block_q2_K)); -} - -//========================= 3-bit (de)-quantization - -void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - int8_t L[QK_K]; - float scales[QK_K / 16]; - - for (int i = 0; i < nb; i++) { - - float max_scale = 0; - float amax = 0; - for (int j = 0; j < QK_K/16; ++j) { - scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true); - float scale = fabsf(scales[j]); - if (scale > amax) { - amax = scale; max_scale = scales[j]; - } - } - -#if QK_K == 256 - memset(y[i].scales, 0, 12); - if (max_scale) { - float iscale = -32.f/max_scale; - for (int j = 0; j < QK_K/16; ++j) { - int8_t l = nearest_int(iscale*scales[j]); - l = MAX(-32, MIN(31, l)) + 32; - if (j < 8) { - y[i].scales[j] = l & 0xF; - } else { - y[i].scales[j-8] |= ((l & 0xF) << 4); - } - l >>= 4; - y[i].scales[j%4 + 8] |= (l << (2*(j/4))); - } - y[i].d = ggml_fp32_to_fp16(1/iscale); - } else { - y[i].d = ggml_fp32_to_fp16(0.f); - } - - int8_t sc; - for (int j = 0; j < QK_K/16; ++j) { - sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; - sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32; - float d = ggml_fp16_to_fp32(y[i].d) * sc; - if (!d) { - continue; - } - for (int ii = 0; ii < 16; ++ii) { - int l = nearest_int(x[16*j + ii]/d); - l = MAX(-4, MIN(3, l)); - L[16*j + ii] = l + 4; - } - } -#else - if (max_scale) { - float iscale = -8.f/max_scale; - for (int j = 0; j < QK_K/16; j+=2) { - int l1 = nearest_int(iscale*scales[j]); - l1 = 8 + MAX(-8, MIN(7, l1)); - int l2 = nearest_int(iscale*scales[j+1]); - l2 = 8 + MAX(-8, MIN(7, l2)); - y[i].scales[j/2] = l1 | (l2 << 4); - } - y[i].d = ggml_fp32_to_fp16(1/iscale); - } else { - for (int j = 0; j < QK_K/16; j+=2) { - y[i].scales[j/2] = 0; - } - y[i].d = ggml_fp32_to_fp16(0.f); - } - for (int j = 0; j < QK_K/16; ++j) { - int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4; - float d = ggml_fp16_to_fp32(y[i].d) * (s - 8); - if (!d) { - continue; - } - for (int ii = 0; ii < 16; ++ii) { - int l = nearest_int(x[16*j + ii]/d); - l = MAX(-4, MIN(3, l)); - L[16*j + ii] = l + 4; - } - } -#endif - - memset(y[i].hmask, 0, QK_K/8); - // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc. - int m = 0; - uint8_t hm = 1; - for (int j = 0; j < QK_K; ++j) { - if (L[j] > 3) { - y[i].hmask[m] |= hm; - L[j] -= 4; - } - if (++m == QK_K/8) { - m = 0; hm <<= 1; - } - } -#if QK_K == 256 - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) { - y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); - } - } -#else - for (int l = 0; l < 16; ++l) { - y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); - } -#endif - - x += QK_K; - } -} - -#if QK_K == 256 -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - const uint32_t kmask1 = 0x03030303; - const uint32_t kmask2 = 0x0f0f0f0f; - - uint32_t aux[4]; - const int8_t * scales = (const int8_t*)aux; - - for (int i = 0; i < nb; i++) { - - const float d_all = ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q = x[i].qs; - const uint8_t * restrict hm = x[i].hmask; - uint8_t m = 1; - - memcpy(aux, x[i].scales, 12); - uint32_t tmp = aux[2]; - aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); - aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - - int is = 0; - float dl; - for (int n = 0; n < QK_K; n += 128) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - - dl = d_all * (scales[is++] - 32); - for (int l = 0; l < 16; ++l) { - *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)); - } - - dl = d_all * (scales[is++] - 32); - for (int l = 0; l < 16; ++l) { - *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)); - } - - shift += 2; - m <<= 1; - } - q += 32; - } - - } -} -#else -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { - assert(k % QK_K == 0); - assert(QK_K == 64); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const float d_all = ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q = x[i].qs; - const uint8_t * restrict hm = x[i].hmask; - - const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); - const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); - const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); - const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); - - for (int l=0; l<8; ++l) { - uint8_t h = hm[l]; - y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4)); - y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4)); - y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4)); - y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4)); - y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4)); - y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4)); - y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4)); - y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4)); - } - y += QK_K; - } -} -#endif - -void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) { - quantize_row_q3_K_reference(x, vy, k); -} - -size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - const int nb = k / QK_K; - - // TODO - collect histograms - although, at a second thought, I don't really care about them - (void)hist; - - for (int j = 0; j < nb; j += k) { - block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K; - quantize_row_q3_K_reference(src + j, y, k); - } - return (n/QK_K*sizeof(block_q3_K)); -} - -// ====================== 4-bit (de)-quantization - -void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - uint8_t L[QK_K]; - uint8_t Laux[32]; - float weights[32]; - float mins[QK_K/32]; - float scales[QK_K/32]; - - for (int i = 0; i < nb; i++) { - - float max_scale = 0; // as we are deducting the min, scales are always positive - float max_min = 0; - for (int j = 0; j < QK_K/32; ++j) { - //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); - float sum_x2 = 0; - for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l]; - float av_x = sqrtf(sum_x2/32); - for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); - scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false); - float scale = scales[j]; - if (scale > max_scale) { - max_scale = scale; - } - float min = mins[j]; - if (min > max_min) { - max_min = min; - } - } - -#if QK_K == 256 - float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; - float inv_min = max_min > 0 ? 63.f/max_min : 0.f; - for (int j = 0; j < QK_K/32; ++j) { - uint8_t ls = nearest_int(inv_scale*scales[j]); - uint8_t lm = nearest_int(inv_min*mins[j]); - ls = MIN(63, ls); - lm = MIN(63, lm); - if (j < 4) { - y[i].scales[j] = ls; - y[i].scales[j+4] = lm; - } else { - y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); - y[i].scales[j-4] |= ((ls >> 4) << 6); - y[i].scales[j-0] |= ((lm >> 4) << 6); - } - } - y[i].d = ggml_fp32_to_fp16(max_scale/63.f); - y[i].dmin = ggml_fp32_to_fp16(max_min/63.f); - - uint8_t sc, m; - for (int j = 0; j < QK_K/32; ++j) { - get_scale_min_k4(j, y[i].scales, &sc, &m); - const float d = ggml_fp16_to_fp32(y[i].d) * sc; - if (!d) continue; - const float dm = ggml_fp16_to_fp32(y[i].dmin) * m; - for (int ii = 0; ii < 32; ++ii) { - int l = nearest_int((x[32*j + ii] + dm)/d); - l = MAX(0, MIN(15, l)); - L[32*j + ii] = l; - } - } -#else - const float s_factor = 15.f; - float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f; - float inv_min = max_min > 0 ? s_factor/max_min : 0.f; - int d1 = nearest_int(inv_scale*scales[0]); - int m1 = nearest_int(inv_min*mins[0]); - int d2 = nearest_int(inv_scale*scales[1]); - int m2 = nearest_int(inv_min*mins[1]); - y[i].scales[0] = d1 | (m1 << 4); - y[i].scales[1] = d2 | (m2 << 4); - y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor); - y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor); - - float sumlx = 0; - int suml2 = 0; - for (int j = 0; j < QK_K/32; ++j) { - const uint8_t sd = y[i].scales[j] & 0xF; - const uint8_t sm = y[i].scales[j] >> 4; - const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd; - if (!d) continue; - const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm; - for (int ii = 0; ii < 32; ++ii) { - int l = nearest_int((x[32*j + ii] + m)/d); - l = MAX(0, MIN(15, l)); - L[32*j + ii] = l; - sumlx += (x[32*j + ii] + m)*l*sd; - suml2 += l*l*sd*sd; - } - } - if (suml2) { - y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2); - } -#endif - uint8_t * q = y[i].qs; - for (int j = 0; j < QK_K; j += 64) { - for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4); - q += 32; - } - - x += QK_K; - - } -} - -void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const uint8_t * q = x[i].qs; - -#if QK_K == 256 - - const float d = ggml_fp16_to_fp32(x[i].d); - const float min = ggml_fp16_to_fp32(x[i].dmin); - - int is = 0; - uint8_t sc, m; - for (int j = 0; j < QK_K; j += 64) { - get_scale_min_k4(is + 0, x[i].scales, &sc, &m); - const float d1 = d * sc; const float m1 = min * m; - get_scale_min_k4(is + 1, x[i].scales, &sc, &m); - const float d2 = d * sc; const float m2 = min * m; - for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; - for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; - q += 32; is += 2; - } -#else - const float dall = ggml_fp16_to_fp32(x[i].d[0]); - const float mall = ggml_fp16_to_fp32(x[i].d[1]); - const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4); - const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4); - for (int l = 0; l < 32; ++l) { - y[l+ 0] = d1 * (q[l] & 0xF) - m1; - y[l+32] = d2 * (q[l] >> 4) - m2; - } - y += QK_K; -#endif - - } -} - -void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) { - assert(k % QK_K == 0); - block_q4_K * restrict y = vy; - quantize_row_q4_K_reference(x, y, k); -} - -size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - (void)hist; // TODO: collect histograms - for (int j = 0; j < nb; j += k) { - block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K; - quantize_row_q4_K_reference(src + j, y, k); - } - return (n/QK_K*sizeof(block_q4_K)); -} - -// ====================== 5-bit (de)-quantization - -void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - -#if QK_K == 256 - uint8_t L[QK_K]; - float mins[QK_K/32]; - float scales[QK_K/32]; - float weights[32]; - uint8_t Laux[32]; -#else - int8_t L[QK_K]; - float scales[QK_K/16]; -#endif - - for (int i = 0; i < nb; i++) { - -#if QK_K == 256 - - float max_scale = 0; // as we are deducting the min, scales are always positive - float max_min = 0; - for (int j = 0; j < QK_K/32; ++j) { - //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); - float sum_x2 = 0; - for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l]; - float av_x = sqrtf(sum_x2/32); - for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); - scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false); - float scale = scales[j]; - if (scale > max_scale) { - max_scale = scale; - } - float min = mins[j]; - if (min > max_min) { - max_min = min; - } - } - - float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; - float inv_min = max_min > 0 ? 63.f/max_min : 0.f; - for (int j = 0; j < QK_K/32; ++j) { - uint8_t ls = nearest_int(inv_scale*scales[j]); - uint8_t lm = nearest_int(inv_min*mins[j]); - ls = MIN(63, ls); - lm = MIN(63, lm); - if (j < 4) { - y[i].scales[j] = ls; - y[i].scales[j+4] = lm; - } else { - y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); - y[i].scales[j-4] |= ((ls >> 4) << 6); - y[i].scales[j-0] |= ((lm >> 4) << 6); - } - } - y[i].d = ggml_fp32_to_fp16(max_scale/63.f); - y[i].dmin = ggml_fp32_to_fp16(max_min/63.f); - - uint8_t sc, m; - for (int j = 0; j < QK_K/32; ++j) { - get_scale_min_k4(j, y[i].scales, &sc, &m); - const float d = ggml_fp16_to_fp32(y[i].d) * sc; - if (!d) continue; - const float dm = ggml_fp16_to_fp32(y[i].dmin) * m; - for (int ii = 0; ii < 32; ++ii) { - int l = nearest_int((x[32*j + ii] + dm)/d); - l = MAX(0, MIN(31, l)); - L[32*j + ii] = l; - } - } - - uint8_t * restrict qh = y[i].qh; - uint8_t * restrict ql = y[i].qs; - memset(qh, 0, QK_K/8); - - uint8_t m1 = 1, m2 = 2; - for (int n = 0; n < QK_K; n += 64) { - for (int j = 0; j < 32; ++j) { - int l1 = L[n + j]; - if (l1 > 15) { - l1 -= 16; qh[j] |= m1; - } - int l2 = L[n + j + 32]; - if (l2 > 15) { - l2 -= 16; qh[j] |= m2; - } - ql[j] = l1 | (l2 << 4); - } - m1 <<= 2; m2 <<= 2; - ql += 32; - } -#else - float max_scale = 0, amax = 0; - for (int j = 0; j < QK_K/16; ++j) { - scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1); - float abs_scale = fabsf(scales[j]); - if (abs_scale > amax) { - amax = abs_scale; - max_scale = scales[j]; - } - } - - float iscale = -128.f/max_scale; - for (int j = 0; j < QK_K/16; ++j) { - int l = nearest_int(iscale*scales[j]); - y[i].scales[j] = MAX(-128, MIN(127, l)); - } - y[i].d = ggml_fp32_to_fp16(1/iscale); - - for (int j = 0; j < QK_K/16; ++j) { - const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j]; - if (!d) continue; - for (int ii = 0; ii < 16; ++ii) { - int l = nearest_int(x[16*j + ii]/d); - l = MAX(-16, MIN(15, l)); - L[16*j + ii] = l + 16; - } - } - - uint8_t * restrict qh = y[i].qh; - uint8_t * restrict ql = y[i].qs; - memset(qh, 0, QK_K/8); - - for (int j = 0; j < 32; ++j) { - int jm = j%8; - int is = j/8; - int l1 = L[j]; - if (l1 > 15) { - l1 -= 16; qh[jm] |= (1 << is); - } - int l2 = L[j + 32]; - if (l2 > 15) { - l2 -= 16; qh[jm] |= (1 << (4 + is)); - } - ql[j] = l1 | (l2 << 4); - } -#endif - - x += QK_K; - - } -} - -void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const uint8_t * ql = x[i].qs; - const uint8_t * qh = x[i].qh; - -#if QK_K == 256 - - const float d = ggml_fp16_to_fp32(x[i].d); - const float min = ggml_fp16_to_fp32(x[i].dmin); - - int is = 0; - uint8_t sc, m; - uint8_t u1 = 1, u2 = 2; - for (int j = 0; j < QK_K; j += 64) { - get_scale_min_k4(is + 0, x[i].scales, &sc, &m); - const float d1 = d * sc; const float m1 = min * m; - get_scale_min_k4(is + 1, x[i].scales, &sc, &m); - const float d2 = d * sc; const float m2 = min * m; - for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1; - for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2; - ql += 32; is += 2; - u1 <<= 2; u2 <<= 2; - } -#else - float d = ggml_fp16_to_fp32(x[i].d); - const int8_t * restrict s = x[i].scales; - for (int l = 0; l < 8; ++l) { - y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16)); - y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16)); - y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16)); - y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16)); - y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16)); - y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16)); - y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16)); - y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16)); - } - y += QK_K; -#endif - } -} - -void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) { - assert(k % QK_K == 0); - block_q5_K * restrict y = vy; - quantize_row_q5_K_reference(x, y, k); -} - -size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - (void)hist; - for (int j = 0; j < nb; j += k) { - block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K; - quantize_row_q5_K_reference(src + j, y, k); - } - return (n/QK_K*sizeof(block_q5_K)); -} - -// ====================== 6-bit (de)-quantization - -void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - int8_t L[QK_K]; - float scales[QK_K/16]; - - for (int i = 0; i < nb; i++) { - - float max_scale = 0; - float max_abs_scale = 0; - - for (int ib = 0; ib < QK_K/16; ++ib) { - - const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1); - scales[ib] = scale; - - const float abs_scale = fabsf(scale); - if (abs_scale > max_abs_scale) { - max_abs_scale = abs_scale; - max_scale = scale; - } - - } - - if (!max_abs_scale) { - memset(&y[i], 0, sizeof(block_q6_K)); - y[i].d = ggml_fp32_to_fp16(0.f); - x += QK_K; - continue; - } - - float iscale = -128.f/max_scale; - y[i].d = ggml_fp32_to_fp16(1/iscale); - for (int ib = 0; ib < QK_K/16; ++ib) { - y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); - } - - for (int j = 0; j < QK_K/16; ++j) { - float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j]; - if (!d) { - continue; - } - for (int ii = 0; ii < 16; ++ii) { - int l = nearest_int(x[16*j + ii]/d); - l = MAX(-32, MIN(31, l)); - L[16*j + ii] = l + 32; - } - } - - uint8_t * restrict ql = y[i].ql; - uint8_t * restrict qh = y[i].qh; -#if QK_K == 256 - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) { - const uint8_t q1 = L[j + l + 0] & 0xF; - const uint8_t q2 = L[j + l + 32] & 0xF; - const uint8_t q3 = L[j + l + 64] & 0xF; - const uint8_t q4 = L[j + l + 96] & 0xF; - ql[l+ 0] = q1 | (q3 << 4); - ql[l+32] = q2 | (q4 << 4); - qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6); - } - ql += 64; - qh += 32; - } -#else - for (int l = 0; l < 32; ++l) { - const uint8_t q1 = L[l + 0] & 0xF; - const uint8_t q2 = L[l + 32] & 0xF; - ql[l] = q1 | (q2 << 4); - } - for (int l = 0; l < 16; ++l) { - qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6); - } -#endif - - x += QK_K; - - } -} - -void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const float d = ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict ql = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict sc = x[i].scales; - -#if QK_K == 256 - for (int n = 0; n < QK_K; n += 128) { - for (int l = 0; l < 32; ++l) { - int is = l/16; - const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - y[l + 0] = d * sc[is + 0] * q1; - y[l + 32] = d * sc[is + 2] * q2; - y[l + 64] = d * sc[is + 4] * q3; - y[l + 96] = d * sc[is + 6] * q4; - } - y += 128; - ql += 64; - qh += 32; - sc += 8; - } -#else - for (int l = 0; l < 16; ++l) { - const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - y[l+ 0] = d * sc[0] * q1; - y[l+16] = d * sc[1] * q2; - y[l+32] = d * sc[2] * q3; - y[l+48] = d * sc[3] * q4; - } - y += 64; -#endif - - } -} - -void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) { - assert(k % QK_K == 0); - block_q6_K * restrict y = vy; - quantize_row_q6_K_reference(x, y, k); -} - -size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - (void)hist; // TODO - - for (int j = 0; j < nb; j += k) { - block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K; - quantize_row_q6_K_reference(src + j, y, k); - } - return (n/QK_K*sizeof(block_q6_K)); -} - -//===================================== Q8_K ============================================== - -void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - float max = 0; - float amax = 0; - for (int j = 0; j < QK_K; ++j) { - float ax = fabsf(x[j]); - if (ax > amax) { - amax = ax; max = x[j]; - } - } - if (!amax) { - y[i].d = 0; - memset(y[i].qs, 0, QK_K); - x += QK_K; - continue; - } - const float iscale = -128.f/max; - for (int j = 0; j < QK_K; ++j) { - int v = nearest_int(iscale*x[j]); - y[i].qs[j] = MIN(127, v); - } - for (int j = 0; j < QK_K/16; ++j) { - int sum = 0; - for (int ii = 0; ii < 16; ++ii) { - sum += y[i].qs[j*16 + ii]; - } - y[i].bsums[j] = sum; - } - y[i].d = 1/iscale; - x += QK_K; - } -} - -void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - for (int j = 0; j < QK_K; ++j) { - *y++ = x[i].d * x[i].qs[j]; - } - } -} - -void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) { - quantize_row_q8_K_reference(x, y, k); -} - -//===================================== Dot ptoducts ================================= - -// -// Helper functions -// -#if __AVX__ || __AVX2__ || __AVX512F__ - -// horizontally add 8 floats -static inline float hsum_float_8(const __m256 x) { - __m128 res = _mm256_extractf128_ps(x, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(x)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - return _mm_cvtss_f32(res); -} - -// shuffles to pick the required scales in dot products -static inline __m256i get_scale_shuffle_q3k(int i) { - static const uint8_t k_shuffle[128] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, - }; - return _mm256_loadu_si256((const __m256i*)k_shuffle + i); -} -static inline __m256i get_scale_shuffle_k4(int i) { - static const uint8_t k_shuffle[256] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, - 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, - 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, - 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 - }; - return _mm256_loadu_si256((const __m256i*)k_shuffle + i); -} -static inline __m128i get_scale_shuffle(int i) { - static const uint8_t k_shuffle[128] = { - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, - 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, - 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, - 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, - 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, - 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, - 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 - }; - return _mm_loadu_si128((const __m128i*)k_shuffle + i); -} -#endif - -#if QK_K == 256 -void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - - const block_q2_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - - const uint8x16_t m3 = vdupq_n_u8(0x3); - const uint8x16_t m4 = vdupq_n_u8(0xF); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); -#endif - - int8x16x2_t q2bytes; - uint8_t aux[16]; - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint8_t * restrict sc = x[i].scales; - - const uint8x16_t mins_and_scales = vld1q_u8(sc); - const uint8x16_t scales = vandq_u8(mins_and_scales, m4); - vst1q_u8(aux, scales); - - const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); - const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); - const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; - const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), - vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); - const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), - vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); - sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); - - int isum = 0; - int is = 0; - -// We use this macro instead of a function call because for some reason -// the code runs 2-3% slower, even if the function is declared inline -#if defined(__ARM_FEATURE_DOTPROD) -#define MULTIPLY_ACCUM_WITH_SCALE(index)\ - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; -#else -#define MULTIPLY_ACCUM_WITH_SCALE(index)\ - {\ - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\ - vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\ - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\ - vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\ - isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\ - } -#endif - -#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ - q8bytes = vld1q_s8_x2(q8); q8 += 32;\ - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ - MULTIPLY_ACCUM_WITH_SCALE((index)); - - - for (int j = 0; j < QK_K/128; ++j) { - - const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32; - - int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32; - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); - MULTIPLY_ACCUM_WITH_SCALE(0); - - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); - - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); - - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); - - is += 8; - } - sum += d * isum; - - } - - *s = sum; - -#elif defined __AVX2__ - - const __m256i m3 = _mm256_set1_epi8(3); - const __m128i m4 = _mm_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); - const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); - const __m256i mins = _mm256_cvtepi8_epi16(mins8); - const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); - - const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); - const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); - const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; - - __m256i sumi = _mm256_setzero_si256(); - - for (int j = 0; j < QK_K/128; ++j) { - - const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - const __m256i q2_0 = _mm256_and_si256(q2bits, m3); - const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); - const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); - const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); - - __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); - __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); - __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); - __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); - - p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); - p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); - p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); - p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); - - p0 = _mm256_add_epi32(p0, p1); - p2 = _mm256_add_epi32(p2, p3); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); - } - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(0x3); - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(0x2); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - // load mins and scales from block_q2_K.scales[QK_K/16] - const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); - const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); - const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); - const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); - - // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 - const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); - const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); - - // sumf += -dmin * summs in 32bits*8 - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); - - const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); - const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); - const __m128i scales[2] = { scales_0, scales_1 }; - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - for (int j = 0; j < QK_K/128; ++j) { - - // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - - // load 2bits*16*8 from block_q2_K.qs[QK_K/4] - __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; - const __m128i q2_0 = _mm_and_si128(q2bits, m3); - const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; - const __m128i q2_1 = _mm_and_si128(q2bits, m3); - const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - - // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 - __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); - __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); - __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); - __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); - __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); - __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); - __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); - __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); - - // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 - __m128i shuffle = _mm_set1_epi16(0x0100); - p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); - shuffle = _mm_add_epi16(shuffle, m2); - p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); - shuffle = _mm_add_epi16(shuffle, m2); - p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); - shuffle = _mm_add_epi16(shuffle, m2); - p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); - shuffle = _mm_add_epi16(shuffle, m2); - p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); - shuffle = _mm_add_epi16(shuffle, m2); - p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); - shuffle = _mm_add_epi16(shuffle, m2); - p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); - shuffle = _mm_add_epi16(shuffle, m2); - p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); - - p0 = _mm_add_epi32(p0, p1); - p2 = _mm_add_epi32(p2, p3); - p4 = _mm_add_epi32(p4, p5); - p6 = _mm_add_epi32(p6, p7); - - // isum in 32bits*4*2 - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); - } - - // sumf += dall * isum - dmin * summs in 32bits - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - float sumf = 0; - uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - size_t vl = 16; - - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - - vl = 32; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - - uint8_t is=0; - int isum=0; - - for (int j = 0; j < QK_K/128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); - - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); - - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); - - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); - - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); - - isum += __riscv_vmv_x_s_i32m1_i32(isum1); - - q2+=32; q8+=128; is=8; - - } - - sumf += dall * isum; - - } - - *s = sumf; - -#else - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - int summs = 0; - for (int j = 0; j < 16; ++j) { - summs += y[i].bsums[j] * (sc[j] >> 4); - } - - const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - int isum = 0; - int is = 0; - int d; - for (int k = 0; k < QK_K/128; ++k) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - d = sc[is++] & 0xF; - int isuml = 0; - for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); - isum += d * isuml; - d = sc[is++] & 0xF; - isuml = 0; - for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); - isum += d * isuml; - shift += 2; - q8 += 32; - } - q2 += 32; - } - sumf += dall * isum - dmin * summs; - } - *s = sumf; -#endif -} - -#else - -void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - - const block_q2_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - - const uint8x16_t m3 = vdupq_n_u8(0x3); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); -#endif - - int8x16x4_t q2bytes; - - uint32_t aux32[2]; - const uint8_t * scales = (const uint8_t *)aux32; - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * (float)x[i].d; - const float dmin = -y[i].d * (float)x[i].dmin; - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - - aux32[0] = sc[0] & 0x0f0f0f0f; - aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; - - sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); - - int isum1 = 0, isum2 = 0; - - const uint8x16_t q2bits = vld1q_u8(q2); - - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); - - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); - q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3)); - q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3)); - -#if defined(__ARM_FEATURE_DOTPROD) - isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; - isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; - isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; - isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; -#else - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum1 += vaddvq_s16(p1) * scales[0]; - isum2 += vaddvq_s16(p2) * scales[1]; - - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum1 += vaddvq_s16(p3) * scales[2]; - isum2 += vaddvq_s16(p4) * scales[3]; -#endif - sum += d * (isum1 + isum2); - - } - - *s = sum; - -#elif defined __AVX2__ - - const __m256i m3 = _mm256_set1_epi8(3); - - __m256 acc = _mm256_setzero_ps(); - - uint32_t ud, um; - const uint8_t * restrict db = (const uint8_t *)&ud; - const uint8_t * restrict mb = (const uint8_t *)&um; - - float summs = 0; - - // TODO: optimize this - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - ud = (sc[0] >> 0) & 0x0f0f0f0f; - um = (sc[0] >> 4) & 0x0f0f0f0f; - - int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; - summs += dmin * smin; - - const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); - const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3); - const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); - const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); - - const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0)); - const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1)); - const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0)); - const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1)); - - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc); - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc); - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc); - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc); - } - - *s = hsum_float_8(acc) + summs; - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(3); - - __m256 acc = _mm256_setzero_ps(); - - uint32_t ud, um; - const uint8_t * restrict db = (const uint8_t *)&ud; - const uint8_t * restrict mb = (const uint8_t *)&um; - - float summs = 0; - - // TODO: optimize this - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - ud = (sc[0] >> 0) & 0x0f0f0f0f; - um = (sc[0] >> 4) & 0x0f0f0f0f; - - int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; - summs += dmin * smin; - - const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); - const __m128i q2_0 = _mm_and_si128(q2bits, m3); - const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0)); - const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1)); - const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0)); - const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1)); - - const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0)); - const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1)); - const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2)); - const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3)); - - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc); - } - - *s = hsum_float_8(acc) + summs; - -#elif defined __riscv_v_intrinsic - - uint32_t aux32[2]; - const uint8_t * scales = (const uint8_t *)aux32; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * (float)x[i].d; - const float dmin = -y[i].d * (float)x[i].dmin; - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - - aux32[0] = sc[0] & 0x0f0f0f0f; - aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; - - sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); - - int isum1 = 0; - int isum2 = 0; - - size_t vl = 16; - - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - - // load Q2 - vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl); - - vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl)); - vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl)); - vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl)); - vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl)); - - // load Q8, and take product with Q2 - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); - - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl); - vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl); - vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl); - - isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0]; - isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1]; - isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2]; - isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3]; - - sumf += d * (isum1 + isum2); - - } - - *s = sumf; - -#else - - float sumf = 0; - - int isum[4]; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - int summs = 0; - for (int j = 0; j < QK_K/16; ++j) { - summs += y[i].bsums[j] * (sc[j] >> 4); - } - - const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - isum[0] = isum[1] = isum[2] = isum[3] = 0; - for (int l = 0; l < 16; ++l) { - isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3); - isum[1] += q8[l+16] * ((q2[l] >> 2) & 3); - isum[2] += q8[l+32] * ((q2[l] >> 4) & 3); - isum[3] += q8[l+48] * ((q2[l] >> 6) & 3); - } - for (int l = 0; l < 4; ++l) { - isum[l] *= (sc[l] & 0xF); - } - sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs; - } - *s = sumf; -#endif -} -#endif - -#if QK_K == 256 -void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); - - const uint32_t kmask1 = 0x03030303; - const uint32_t kmask2 = 0x0f0f0f0f; - - const block_q3_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - - uint32_t aux[3]; - uint32_t utmp[4]; - - const uint8x16_t m3b = vdupq_n_u8(0x3); -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t vzero = vdupq_n_s32(0); -#endif - - const uint8x16_t m0 = vdupq_n_u8(1); - const uint8x16_t m1 = vshlq_n_u8(m0, 1); - const uint8x16_t m2 = vshlq_n_u8(m0, 2); - const uint8x16_t m3 = vshlq_n_u8(m0, 3); - const int8_t m32 = 32; - - int8x16x4_t q3bytes; - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - - uint8x16x2_t qhbits = vld1q_u8_x2(qh); - - uint8x16x4_t q3h; - - int32_t isum = 0; - - // Set up scales - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= m32; - - for (int j = 0; j < QK_K/128; ++j) { - - const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32; - const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64; - const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64; - - q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); - q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); - q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); - q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); - - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; -#else - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1]))); - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif - scale += 4; - - q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); - q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); - q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); - q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); - - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; -#else - p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); - p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1]))); - p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); - p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif - scale += 4; - - if (j == 0) { - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); - } - - } - sum += d * isum; - - } - - *s = sum; - -#elif defined __AVX2__ - - const __m256i m3 = _mm256_set1_epi8(3); - const __m256i mone = _mm256_set1_epi8(1); - const __m128i m32 = _mm_set1_epi8(32); - - __m256 acc = _mm256_setzero_ps(); - - uint32_t aux[3]; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - // Set up scales - memcpy(aux, x[i].scales, 12); - __m128i scales128 = _mm_set_epi32( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = _mm_sub_epi8(scales128, m32); - const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); - const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); - const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; - - // high bit - const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); - - // integer accumulator - __m256i sumi = _mm256_setzero_si256(); - - int bit = 0; - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits - const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; - - // prepare low and high bits - const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); - const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); - const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); - const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); - const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - // load Q8 quants - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); - - __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); - __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); - __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); - - // multiply with scales - p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); - p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); - p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); - - // accumulate - p16_0 = _mm256_add_epi32(p16_0, p16_1); - p16_2 = _mm256_add_epi32(p16_2, p16_3); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); - - } - - // multiply with block scale and accumulate - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(3); - const __m128i mone = _mm_set1_epi8(1); - const __m128i m32 = _mm_set1_epi8(32); - const __m128i m2 = _mm_set1_epi8(2); - - __m256 acc = _mm256_setzero_ps(); - - const uint32_t *aux; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - // Set up scales - aux = (const uint32_t *)x[i].scales; - __m128i scales128 = _mm_set_epi32( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = _mm_sub_epi8(scales128, m32); - const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); - const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); - const __m128i scales[2] = { scales_0, scales_1 }; - - // high bit *128*2 from block_q3_K.hmask[QK_K/8] - const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); - const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); - - // integer accumulator - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] - const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; - const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; - - // prepare low and high bits - const int bit = j << 2; - - const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); - const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); - const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); - const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); - - const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); - const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); - const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); - const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); - - const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); - const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); - const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); - const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); - - const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); - const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); - const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); - const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); - - // load Q8 quants from block_q8_K.qs[QK_K] - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); - __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); - __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); - __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); - __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); - __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); - __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); - __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); - - __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); - __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); - __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); - __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); - __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); - __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); - __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - p16_4 = _mm_sub_epi16(p16_4, q8s_4); - p16_5 = _mm_sub_epi16(p16_5, q8s_5); - p16_6 = _mm_sub_epi16(p16_6, q8s_6); - p16_7 = _mm_sub_epi16(p16_7, q8s_7); - - // multiply with scales - __m128i shuffle = _mm_set1_epi16(0x0100); - p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); - shuffle = _mm_add_epi16(shuffle, m2); - p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); - shuffle = _mm_add_epi16(shuffle, m2); - p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); - shuffle = _mm_add_epi16(shuffle, m2); - p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); - shuffle = _mm_add_epi16(shuffle, m2); - p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); - shuffle = _mm_add_epi16(shuffle, m2); - p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); - shuffle = _mm_add_epi16(shuffle, m2); - p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); - shuffle = _mm_add_epi16(shuffle, m2); - p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); - - // accumulate - p16_0 = _mm_add_epi32(p16_0, p16_1); - p16_2 = _mm_add_epi32(p16_2, p16_3); - p16_4 = _mm_add_epi32(p16_4, p16_5); - p16_6 = _mm_add_epi32(p16_6, p16_7); - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); - - } - - // multiply with block scale and accumulate - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - uint32_t aux[3]; - uint32_t utmp[4]; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; - - - size_t vl = 32; - uint8_t m = 1; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); - - int sum_t = 0; - - for (int j = 0; j < QK_K; j += 128) { - - vl = 32; - - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); - - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl); - m <<= 1; - - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - - vl = 16; - - // retreive lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); - - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - - q3 += 32; q8 += 128; scale += 8; - - } - - const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; - - sumf += d*sum_t; - - } - - *s = sumf; - -#else - // scalar version - // This function is written like this so the compiler can manage to vectorize most of it - // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the - // manually vectorized version above. Every other version I tried would run at least 4 times slower. - // The ideal situation would be if we could just write the code once, and the compiler would - // automatically produce the best possible set of machine instructions, instead of us having to manually - // write vectorized versions for AVX, ARM_NEON, etc. - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - uint32_t auxs[4]; - const int8_t * scales = (const int8_t*)auxs; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict hm = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - uint8_t m = 1; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - q3 += 32; - } - a = aux8; - - memcpy(auxs, x[i].scales, 12); - uint32_t tmp = auxs[2]; - auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); - auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - for (int j = 0; j < QK_K/16; ++j) { - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; - q8 += 8; a += 8; - } - const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; - -#endif - -} - -#else - -void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); - - const block_q3_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t vzero = vdupq_n_s32(0); -#endif - - const uint8x16_t m3b = vdupq_n_u8(0x3); - const uint8x16_t mh = vdupq_n_u8(4); - - int8x16x4_t q3bytes; - - uint16_t aux16[2]; - int8_t * scales = (int8_t *)aux16; - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - - uint8x16x4_t q3h; - - const uint8x8_t hbits = vld1_u8(x[i].hmask); - const uint8x16_t q3bits = vld1q_u8(x[i].qs); - const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs); - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - for (int j = 0; j < 4; ++j) scales[j] -= 8; - - int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); - - const float d = y[i].d * (float)x[i].d; - - const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); - q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); - q3h.val[1] = vandq_u8(mh, htmp); - q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); - q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4)); - - q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); - q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); - q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); - q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3])); - -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1]; - isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; -#else - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3]; -#endif - - sum += d * isum; - - } - - *s = sum; - -#elif defined __AVX2__ - - const __m256i m3 = _mm256_set1_epi8(3); - const __m256i m1 = _mm256_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - uint64_t aux64; - - uint16_t aux16[2]; - const int8_t * aux8 = (const int8_t *)aux16; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8)); - const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8)); - - memcpy(&aux64, x[i].hmask, 8); - - const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); - __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux); - __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4); - q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2); - q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2); - - // load low 2 bits - const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); - - // prepare low and high bits - const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits); - const __m256i q3l_0 = _mm256_and_si256(q3aux, m3); - const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3); - - // load Q8 quants - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); - const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); - - __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - - // multiply with scales - p16_0 = _mm256_madd_epi16(scale_0, p16_0); - p16_1 = _mm256_madd_epi16(scale_1, p16_1); - - p16_0 = _mm256_add_epi32(p16_0, p16_1); - - // multiply with block scale and accumulate - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(3); - const __m128i m1 = _mm_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - uint64_t aux64; - - uint16_t aux16[2]; - const int8_t * aux8 = (const int8_t *)aux16; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8); - const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8); - const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8); - const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8); - - memcpy(&aux64, x[i].hmask, 8); - - __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); - __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2); - __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4); - __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6); - q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2); - q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2); - q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2); - q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2); - - // load low 2 bits - const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); - - // prepare low and high bits - const __m128i q3l_0 = _mm_and_si128(q3bits, m3); - const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3); - const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3); - const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3); - - // load Q8 quants - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0)); - const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1)); - const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0)); - const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1)); - - __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0)); - __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1)); - __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0)); - __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1)); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - - // multiply with scales - p16_0 = _mm_madd_epi16(scale_0, p16_0); - p16_1 = _mm_madd_epi16(scale_1, p16_1); - p16_2 = _mm_madd_epi16(scale_2, p16_2); - p16_3 = _mm_madd_epi16(scale_3, p16_3); - - p16_0 = _mm_add_epi32(p16_0, p16_2); - p16_1 = _mm_add_epi32(p16_1, p16_3); - __m256i p16 = MM256_SET_M128I(p16_1, p16_0); - - // multiply with block scale and accumulate - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - uint16_t aux16[2]; - int8_t * scales = (int8_t *)aux16; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - for (int j = 0; j < 4; ++j) scales[j] -= 8; - - int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); - - const float d = y[i].d * (float)x[i].d; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - - // load qh - vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8); - vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); - - size_t vl = 16; - - // extend and combine both qh_x1 and qh_x2 - vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); - - vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); - vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl); - vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); - vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl); - - // load Q3 - vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); - - vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl); - vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl); - vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl); - vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl); - - vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0); - vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1); - vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2); - vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3); - - // load Q8 and take product with Q3 - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); - - vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); - vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); - vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); - vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); - - isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3]; - - sumf += d * isum; - - } - - *s = sumf; - -#else - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - int32_t scales[4]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict hm = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - int8_t * restrict a = aux8; - for (int l = 0; l < 8; ++l) { - a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4); - a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4); - a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4); - a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4); - a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4); - a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4); - a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4); - a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4); - } - - scales[0] = (x[i].scales[0] & 0xF) - 8; - scales[1] = (x[i].scales[0] >> 4) - 8; - scales[2] = (x[i].scales[1] & 0xF) - 8; - scales[3] = (x[i].scales[1] >> 4) - 8; - - memset(aux32, 0, 8*sizeof(int32_t)); - for (int j = 0; j < QK_K/16; ++j) { - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l]; - } - const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; - -#endif - -} -#endif - -#if QK_K == 256 -void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); - - const block_q4_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - uint32_t utmp[4]; - -#ifdef __ARM_NEON - - const uint8x16_t m4b = vdupq_n_u8(0xf); -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t mzero = vdupq_n_s32(0); -#endif - - int8x16x2_t q4bytes; - int8x16x2_t q8bytes; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, 12); - - uint32x2_t mins8 = { 0 }; - mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); - mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); - - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[0] &= kmask1; - - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - sumf -= dmin * vaddvq_s32(prod); - - const uint8_t * scales = (const uint8_t *)utmp; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - int32_t sumi1 = 0; - int32_t sumi2 = 0; - - for (int j = 0; j < QK_K/64; ++j) { - - const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32; - -#ifdef __ARM_FEATURE_DOTPROD - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - - const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - sumi1 += vaddvq_s32(p1) * scales[2*j+0]; - - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - - const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - - sumi2 += vaddvq_s32(p2) * scales[2*j+1]; -#else - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; - - q8bytes = vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1]; - -#endif - } - - sumf += d * (sumi1 + sumi2); - - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - __m128 acc_m = _mm_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); - - const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); - - const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); - - __m256i sumi = _mm256_setzero_si256(); - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); - - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4l = _mm256_and_si256(q4bits, m4); - const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); - - const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); - p16l = _mm256_madd_epi16(scale_l, p16l); - - const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); - p16h = _mm256_madd_epi16(scale_h, p16h); - const __m256i sumj = _mm256_add_epi32(p16l, p16h); - - sumi = _mm256_add_epi32(sumi, sumj); - } - - __m256 vd = _mm256_set1_ps(d); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); - - } - - acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); - acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); - - *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(0x2); - - __m256 acc = _mm256_setzero_ps(); - __m128 acc_m = _mm_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i scales = _mm_cvtepu8_epi16(utmps); - const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); - - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); - const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); - const __m128i prod = _mm_madd_epi16(mins, q8s); - acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - __m128i shuffle = _mm_set1_epi16(0x0100); - for (int j = 0; j < QK_K/64; ++j) { - - const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - - __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4l_0 = _mm_and_si128(q4bits, m4); - const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); - q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4l_1 = _mm_and_si128(q4bits, m4); - const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); - - const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); - p16l = _mm_madd_epi16(scale_l, p16l); - sumi_0 = _mm_add_epi32(sumi_0, p16l); - const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - p16l = _mm_maddubs_epi16(q4l_1, q8l_1); - p16l = _mm_madd_epi16(scale_l, p16l); - sumi_1 = _mm_add_epi32(sumi_1, p16l); - - const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); - p16h = _mm_madd_epi16(scale_h, p16h); - sumi_0 = _mm_add_epi32(sumi_0, p16h); - const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - p16h = _mm_maddubs_epi16(q4h_1, q8h_1); - p16h = _mm_madd_epi16(scale_h, p16h); - sumi_1 = _mm_add_epi32(sumi_1, p16h); - - } - - __m256 vd = _mm256_set1_ps(d); - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); - - } - - acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); - acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); - - *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); - -#elif defined __riscv_v_intrinsic - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - size_t vl = 8; - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - vl = 32; - - int32_t sum_1 = 0; - int32_t sum_2 = 0; - - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); - - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; - - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); - - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; - - q4 += 32; q8 += 64; - - } - - sumf += d*(sum_1 + sum_2); - - } - - *s = sumf; - -#else - - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - for (int j = 0; j < QK_K/64; ++j) { - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); - a += 32; - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); - a += 32; q4 += 32; - } - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - int sumi = 0; - for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/32; ++j) { - int32_t scale = scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d; - sumf -= dmin * sumi; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} -#else -void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); - - const block_q4_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - - const uint8x16_t m4b = vdupq_n_u8(0xf); - -#ifdef __ARM_FEATURE_DOTPROD - const int32x4_t mzero = vdupq_n_s32(0); -#endif - - float sumf = 0; - - int8x16x2_t q4bytes; - int8x16x4_t q8bytes; - - float sum_mins = 0.f; - - uint16_t aux16[2]; - const uint8_t * restrict scales = (const uint8_t *)aux16; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t * restrict a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]); - sum_mins += y[i].d * (float)x[i].d[1] * summi; - - const float d = y[i].d * (float)x[i].d[0]; - - const uint8x16x2_t q4bits = vld1q_u8_x2(q4); - -#ifdef __ARM_FEATURE_DOTPROD - q8bytes = vld1q_s8_x4(q8); - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - - const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - const int32_t sumi1 = vaddvq_s32(p1) * scales[0]; - - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - - const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); - const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; - -#else - q8bytes = vld1q_s8_x4(q8); - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0]; - - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3]))); - int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1]; - -#endif - sumf += d * (sumi1 + sumi2); - - } - - *s = sumf - sum_mins; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - uint16_t aux16[2]; - const uint8_t * scales = (const uint8_t *)aux16; - - for (int i = 0; i < nb; ++i) { - - const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d; - const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d; - const __m256 vd = _mm256_set1_ps(d); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); - const __m256i q4l = _mm256_and_si256(q4bits, m4); - const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); - - const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); - const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); - - const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc); - - const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc); - - } - - *s = hsum_float_8(acc) - summs; - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - uint16_t aux16[2]; - const uint8_t * scales = (const uint8_t *)aux16; - - for (int i = 0; i < nb; ++i) { - - const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d; - const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d; - const __m256 vd = _mm256_set1_ps(d); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); - const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0); - const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1); - const __m128i q4_0 = _mm_and_si128(q4bits_0, m4); - const __m128i q4_1 = _mm_and_si128(q4bits_1, m4); - const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4); - const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); - const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); - const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); - const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); - - const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0); - const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc); - - const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2); - const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc); - - } - - *s = hsum_float_8(acc) - summs; - -#elif defined __riscv_v_intrinsic - - uint16_t s16[2]; - const uint8_t * restrict scales = (const uint8_t *)s16; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t * restrict b = (const uint16_t *)x[i].scales; - s16[0] = b[0] & 0x0f0f; - s16[1] = (b[0] >> 4) & 0x0f0f; - - sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]); - - size_t vl = 32; - - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl); - - sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1); - - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl); - - sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2); - - } - - *s = sumf; - -#else - - uint8_t aux8[QK_K]; - int16_t aux16[16]; - float sums [8]; - memset(sums, 0, 8*sizeof(float)); - - uint16_t s16[2]; - const uint8_t * restrict scales = (const uint8_t *)s16; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - uint8_t * restrict a = aux8; - for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF; - for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4; - - const uint16_t * restrict b = (const uint16_t *)x[i].scales; - s16[0] = b[0] & 0x0f0f; - s16[1] = (b[0] >> 4) & 0x0f0f; - - sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]); - - for (int j = 0; j < QK_K/32; ++j) { - for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; - q8 += 16; a += 16; - for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; - q8 += 16; a += 16; - const float dl = d * scales[j]; - for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]); - } - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} -#endif - -#if QK_K == 256 -void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); - - const block_q5_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - uint32_t utmp[4]; - - -#ifdef __ARM_NEON - - const uint8x16_t m4b = vdupq_n_u8(0xf); - const uint8x16_t mone = vdupq_n_u8(1); - const uint8x16_t mtwo = vdupq_n_u8(2); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t mzero = vdupq_n_s32(0); -#endif - - int8x16x4_t q5bytes; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - int32_t sumi_mins = vaddvq_s32(prod); - - const uint8_t * scales = (const uint8_t *)utmp; - - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - uint8x16x2_t qhbits = vld1q_u8_x2(qh); - - uint8x16x4_t q5h; - - int32_t sumi = 0; - - for (int j = 0; j < QK_K/64; ++j) { - - const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32; - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; - - q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); - q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); - - q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); - q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); - q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); - q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); - -#if defined(__ARM_FEATURE_DOTPROD) - - sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; - sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; -#else - - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++; - - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++; -#endif - } - - sumf += d * sumi - dmin * sumi_mins; - - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m128i mzero = _mm_setzero_si128(); - const __m256i mone = _mm256_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - -#if QK_K == 256 - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; -#else - // TODO - const float d = 0, dmin = 0; -#endif - - const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); - - const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); - - const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); - - const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); - __m256i hmask = mone; - - __m256i sumi = _mm256_setzero_si256(); - - int bit = 0; - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); - - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; - - const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); - const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); - const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); - hmask = _mm256_slli_epi16(hmask, 1); - - const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); - const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); - const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); - hmask = _mm256_slli_epi16(hmask, 1); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); - - p16_0 = _mm256_madd_epi16(scale_0, p16_0); - p16_1 = _mm256_madd_epi16(scale_1, p16_1); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - - } - - __m256 vd = _mm256_set1_ps(d); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); - - } - - *s = hsum_float_8(acc) + summs; - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i mzero = _mm_setzero_si128(); - const __m128i mone = _mm_set1_epi8(1); - const __m128i m2 = _mm_set1_epi8(2); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i scales = _mm_cvtepu8_epi16(utmps); - const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); - - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); - const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); - const __m128i prod = _mm_madd_epi16(mins, q8s); - const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); - - const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); - const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); - __m128i hmask = mone; - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - int bit = 0; - - __m128i shuffle = _mm_set1_epi16(0x0100); - for (int j = 0; j < QK_K/64; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - - const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; - const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; - - __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); - __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); - __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); - __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); - __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); - __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); - hmask = _mm_slli_epi16(hmask, 1); - - __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); - p16_0 = _mm_madd_epi16(scale_0, p16_0); - p16_1 = _mm_madd_epi16(scale_0, p16_1); - - q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); - q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); - q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); - q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); - q5_0 = _mm_add_epi8(q5l_0, q5h_0); - q5_1 = _mm_add_epi8(q5l_1, q5h_1); - hmask = _mm_slli_epi16(hmask, 1); - - q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); - __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); - p16_2 = _mm_madd_epi16(scale_1, p16_2); - p16_3 = _mm_madd_epi16(scale_1, p16_3); - - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); - - } - - __m256 vd = _mm256_set1_ps(d); - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); - - } - - *s = hsum_float_8(acc) + summs; - -#elif defined __riscv_v_intrinsic - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - float sumf = 0; - float sums = 0.0; - - size_t vl; - - for (int i = 0; i < nb; ++i) { - - vl = 8; - - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; - const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d; - - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - - vl = 32; - int32_t aux32 = 0; - int is = 0; - - uint8_t m = 1; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); - - for (int j = 0; j < QK_K/64; ++j) { - // load Q5 and Q8 - vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); - vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); - - // compute mask for addition - vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl); - m <<= 1; - - vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl); - m <<= 1; - - vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); - vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); - - vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); - vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); - - vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); - vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); - - aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); - q5 += 32; q8 += 64; - - } - - vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); - sums += __riscv_vfmv_f_s_f32m1_f32(vaux); - - } - - *s = sumf+sums; - -#else - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - uint8_t m = 1; - for (int j = 0; j < QK_K/64; ++j) { - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); - for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); - for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); - a += 32; m <<= 1; - q4 += 32; - } - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - int sumi = 0; - for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/32; ++j) { - int32_t scale = scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d; - sumf -= dmin * sumi; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -#else - -void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); - - const block_q5_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - - const uint8x16_t m4b = vdupq_n_u8(0xf); - const uint8x16_t mh = vdupq_n_u8(16); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t mzero = vdupq_n_s32(0); -#endif - - int8x16x4_t q5bytes; - uint8x16x4_t q5h; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * (float)x[i].d; - const int8_t * sc = x[i].scales; - - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const uint8x8_t qhbits = vld1_u8(qh); - - const uint8x16x2_t q5bits = vld1q_u8_x2(q5); - const int8x16x4_t q8bytes = vld1q_s8_x4(q8); - - const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); - q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); - q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2)); - q5h.val[2] = vbicq_u8(mh, htmp); - q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2)); - - q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0])); - q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1])); - q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); - q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); - -#if defined(__ARM_FEATURE_DOTPROD) - - int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0])); - int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1])); - int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2])); - int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3])); - - sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); - -#else - - const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1); - - const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3); - - sumf += d*sumi; -#endif - - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i mone = _mm256_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); - - const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0])); - const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2])); - - int64_t aux64; - memcpy(&aux64, x[i].qh, 8); - const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64); - const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128); - - const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4); - const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4); - - const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); - const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0)); - const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1)); - const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0)); - const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1)); - - const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1)); - - acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i mone = _mm_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); - - const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]); - const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]); - const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]); - const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]); - - int64_t aux64; - memcpy(&aux64, x[i].qh, 8); - const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64); - const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2); - - const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4); - const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4); - const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4); - const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4); - - const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4); - const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4); - const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4); - const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0))); - const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1))); - const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0))); - const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1))); - const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0))); - const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1))); - const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0))); - const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1))); - - const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2)); - const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3)); - - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * (float)x[i].d; - const int8_t * sc = x[i].scales; - - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - - // load qh - vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8); - vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); - - size_t vl = 16; - - // combine both qh_1 and qh_2 - vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); - - vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); - vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl); - vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl); - vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); - - vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0); - vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1); - vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2); - vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3); - - // load q5 - vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl); - vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5+16, vl); - - vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl)); - vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl)); - vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl)); - vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl)); - - vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl); - vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl); - vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl); - vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl); - - // load Q8 and multiply it with Q5 - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); - - vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); - vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); - vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); - vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); - - int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0); - int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1); - int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2); - int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3); - - sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); - - } - - *s = sumf; - -#else - - int8_t aux8[QK_K]; - int16_t aux16[16]; - float sums [8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - int8_t * restrict a = aux8; - for (int l = 0; l < 32; ++l) { - a[l+ 0] = q4[l] & 0xF; - a[l+32] = q4[l] >> 4; - } - for (int is = 0; is < 8; ++is) { - uint8_t m = 1 << is; - for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16); - } - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const int8_t * restrict sc = x[i].scales; - - for (int j = 0; j < QK_K/16; ++j) { - const float dl = d * sc[j]; - for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]); - q8 += 16; a += 16; - } - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} -#endif - - -#if QK_K == 256 -void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); - - const block_q6_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - - float sum = 0; - - const uint8x16_t m4b = vdupq_n_u8(0xF); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); -#endif - //const int8x16_t m32s = vdupq_n_s8(32); - - const uint8x16_t mone = vdupq_n_u8(3); - - int8x16x4_t q6bytes; - uint8x16x4_t q6h; - - for (int i = 0; i < nb; ++i) { - - const float d_all = ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); - const int8x16_t scales = vld1q_s8(scale); - const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; - - const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), - vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), - vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), - vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); - int32_t isum_mins = vaddvq_s32(prod); - - int32_t isum = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32; - uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64; - int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64; - - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 2); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); - -#if defined(__ARM_FEATURE_DOTPROD) - - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - scale += 4; - -#else - - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - scale += 2; - - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; - scale += 2; -#endif - - q8bytes = vld1q_s8_x4(q8); q8 += 64; - - shifted = vshrq_n_u8(qhbits.val[0], 4); - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[0], 6); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 6); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); - -#if defined(__ARM_FEATURE_DOTPROD) - - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - scale += 4; - - //for (int l = 0; l < 4; ++l) { - // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]); - // isum += vaddvq_s32(p) * *scale++; - //} -#else - p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - scale += 2; - - p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; - scale += 2; -#endif - - } - //sum += isum * d_all * y[i].d; - sum += d_all * y[i].d * (isum - 32 * isum_mins); - - } - *s = sum; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); - - __m256i sumi = _mm256_setzero_si256(); - - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); - is += 4; - - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; - - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); - - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); - const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); - - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); - __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); - __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); - - p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); - p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); - p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); - - } - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m3 = _mm_set1_epi8(3); - const __m128i m32s = _mm_set1_epi8(32); - const __m128i m2 = _mm_set1_epi8(2); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); - for (int j = 0; j < QK_K/128; ++j) { - - const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; - const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; - - const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); - const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); - const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4); - const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4); - const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4); - const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4); - const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4); - const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4); - - const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - - const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0); - const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1); - const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2); - const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3); - const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4); - const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5); - const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6); - const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7); - - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - - __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0); - __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1); - __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2); - __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3); - __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4); - __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5); - __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6); - __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7); - - __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); - __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); - __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); - __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); - __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); - __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); - __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - p16_4 = _mm_sub_epi16(p16_4, q8s_4); - p16_5 = _mm_sub_epi16(p16_5, q8s_5); - p16_6 = _mm_sub_epi16(p16_6, q8s_6); - p16_7 = _mm_sub_epi16(p16_7, q8s_7); - - const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - - p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); - p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); - p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); - p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); - p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5); - p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); - p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7); - - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); - - } - - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - size_t vl; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - - int sum_t = 0; - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - vl = 32; - - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); - - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - - vl = 16; - - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - - q6 += 64; qh += 32; q8 += 128; is=8; - - } - - sumf += d * sum_t; - - } - - *s = sumf; - -#else - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) { - a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - } - a += 128; - q4 += 64; - qh += 32; - } - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/16; ++j) { - int scale = x[i].scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -#else - -void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - assert(n % QK_K == 0); - - const block_q6_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - - float sum = 0; - - const uint8x16_t m4b = vdupq_n_u8(0xF); - const int8x16_t m32s = vdupq_n_s8(32); -#if defined(__ARM_FEATURE_DOTPROD) - const int32x4_t vzero = vdupq_n_s32(0); -#endif - - const uint8x16_t mone = vdupq_n_u8(3); - - int8x16x4_t q6bytes; - uint8x16x4_t q6h; - - for (int i = 0; i < nb; ++i) { - - const float d_all = (float)x[i].d; - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - int32_t isum = 0; - - uint8x16_t qhbits = vld1q_u8(qh); - uint8x16x2_t q6bits = vld1q_u8_x2(q6); - int8x16x4_t q8bytes = vld1q_s8_x4(q8); - - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); - uint8x16_t shifted = vshrq_n_u8(qhbits, 2); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits, 4); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits, 6); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - - q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); - q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); - q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s); - q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s); - -#if defined(__ARM_FEATURE_DOTPROD) - - isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; -#else - - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), - vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), - vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); - isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; - - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), - vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), - vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); - isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; -#endif - - sum += isum * d_all * y[i].d; - - } - *s = sum; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); - const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); - const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); - const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); - - __m256i sumi = _mm256_setzero_si256(); - - const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); - const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); - - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); - const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); - - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4); - - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - - p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(3); - const __m128i m32s = _mm_set1_epi8(32); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); - const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); - const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); - const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); - const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); - - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); - const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); - - const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4); - const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4); - const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4); - const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4); - - const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0); - const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1); - const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2); - const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0)); - __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1)); - __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0)); - __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1)); - - __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); - __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); - __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); - __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - - p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); - p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); - p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); - - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); - - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d_all = (float)x[i].d; - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - int32_t isum = 0; - - size_t vl = 16; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - - // load Q6 - vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); - vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl); - - // load qh - vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); - - vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); - vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); - vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); - vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - - vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl); - vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl); - vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl); - vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl); - - vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl); - vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl); - vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl); - vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl); - - // load Q8 and take product - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); - - vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); - vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); - vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); - vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); - - isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3]; - - sumf += isum * d_all * y[i].d; - - } - - *s = sumf; - -#else - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - for (int l = 0; l < 16; ++l) { - a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - } - int is = 0; - for (int j = 0; j < QK_K/16; ++j) { - int scale = x[i].scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -#endif diff --git a/plugins/wasi_nn/thirdparty/ggml/k_quants.h b/plugins/wasi_nn/thirdparty/ggml/k_quants.h deleted file mode 100644 index 9de089e7..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/k_quants.h +++ /dev/null @@ -1,165 +0,0 @@ -#pragma once - -#include "ggml.h" - -#include -#include -#include - -// Super-block size -#ifdef GGML_QKK_64 -#define QK_K 64 -#define K_SCALE_SIZE 4 -#else -#define QK_K 256 -#define K_SCALE_SIZE 12 -#endif - -#ifndef static_assert -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) -#define static_assert(cond, msg) _Static_assert(cond, msg) -#else -#define static_assert(cond, msg) struct global_scope_noop_trick -#endif -#endif - -// -// Super-block quantization structures -// - -// 2-bit quantization -// weight is represented as x = a * q + b -// 16 blocks of 16 elements each -// Effectively 2.5625 bits per weight -typedef struct { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - ggml_fp16_t d; // super-block scale for quantized scales - ggml_fp16_t dmin; // super-block scale for quantized mins -} block_q2_K; -static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); - -// 3-bit quantization -// weight is represented as x = a * q -// 16 blocks of 16 elements each -// Effectively 3.4375 bits per weight -#ifdef GGML_QKK_64 -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits - uint8_t scales[2]; - ggml_fp16_t d; // super-block scale -} block_q3_K; -static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding"); -#else -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits - uint8_t scales[12]; // scales, quantized with 6 bits - ggml_fp16_t d; // super-block scale -} block_q3_K; -static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); -#endif - -// 4-bit quantization -// 8 blocks of 32 elements each -// weight is represented as x = a * q + b -// Effectively 4.5 bits per weight -#ifdef GGML_QKK_64 -typedef struct { - ggml_fp16_t d[2]; // super-block scales/mins - uint8_t scales[2]; // 4-bit block scales/mins - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); -#else -typedef struct { - ggml_fp16_t d; // super-block scale for quantized scales - ggml_fp16_t dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); -#endif - -// 5-bit quantization -// 8 blocks of 32 elements each -// weight is represented as x = a * q + b -// Effectively 5.5 bits per weight -#ifdef GGML_QKK_64 -typedef struct { - ggml_fp16_t d; // super-block scale - int8_t scales[QK_K/16]; // 8-bit block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); -#else -typedef struct { - ggml_fp16_t d; // super-block scale for quantized scales - ggml_fp16_t dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); -#endif - -// 6-bit quantization -// weight is represented as x = a * q -// 16 blocks of 16 elements each -// Effectively 6.5625 bits per weight -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - ggml_fp16_t d; // super-block scale -} block_q6_K; -static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); - -// This is only used for intermediate quantization and dot products -typedef struct { - float d; // delta - int8_t qs[QK_K]; // quants - int16_t bsums[QK_K/16]; // sum of quants in groups of 16 -} block_q8_K; -static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); - - -// Quantization -void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k); -void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k); -void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k); -void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k); -void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k); -void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k); - -void quantize_row_q2_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q3_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q4_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q5_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q6_K(const float * restrict x, void * restrict y, int k); -void quantize_row_q8_K(const float * restrict x, void * restrict y, int k); - -// Dequantization -void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k); -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k); -void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k); -void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k); -void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k); -void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k); - -// Dot product -void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); -void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); - -// Quantization with histogram collection -size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist); -size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist); -size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist); -size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist); -size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist); - diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.cpp b/plugins/wasi_nn/thirdparty/ggml/llama.cpp deleted file mode 100644 index fa2b3d2f..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/llama.cpp +++ /dev/null @@ -1,9633 +0,0 @@ -#define LLAMA_API_INTERNAL -#include "llama.h" - -#include "unicode.h" - -#include "ggml.h" - -#include "ggml-alloc.h" - -#ifdef GGML_USE_CUBLAS -# include "ggml-cuda.h" -#elif defined(GGML_USE_CLBLAST) -# include "ggml-opencl.h" -#endif - -#ifdef GGML_USE_METAL -# include "ggml-metal.h" -#endif -#ifdef GGML_USE_MPI -# include "ggml-mpi.h" -#endif -#ifdef GGML_USE_K_QUANTS -# ifndef QK_K -# ifdef GGML_QKK_64 -# define QK_K 64 -# else -# define QK_K 256 -# endif -# endif -#endif - -#ifdef __has_include - #if __has_include() - #include - #if defined(_POSIX_MAPPED_FILES) - #include - #endif - #if defined(_POSIX_MEMLOCK_RANGE) - #include - #endif - #endif -#endif - -#if defined(_WIN32) - #define WIN32_LEAN_AND_MEAN - #ifndef NOMINMAX - #define NOMINMAX - #endif - #include - #include - #include // for _fseeki64 -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif - -#ifdef __GNUC__ -#ifdef __MINGW32__ -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) -#else -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif -#else -#define LLAMA_ATTRIBUTE_FORMAT(...) -#endif - -// -// logging -// - -LLAMA_ATTRIBUTE_FORMAT(2, 3) -static void llama_log_internal (ggml_log_level level, const char* format, ...); -static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data); - -#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) -#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) -#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) - -// -// helpers -// - -static size_t utf8_len(char src) { - const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - uint8_t highbits = static_cast(src) >> 4; - return lookup[highbits]; -} - -static void replace_all(std::string & s, const std::string & search, const std::string & replace) { - std::string result; - for (size_t pos = 0; ; pos += search.length()) { - auto new_pos = s.find(search, pos); - if (new_pos == std::string::npos) { - result += s.substr(pos, s.size() - pos); - break; - } - result += s.substr(pos, new_pos - pos) + replace; - pos = new_pos; - } - s = std::move(result); -} - -static bool is_float_close(float a, float b, float abs_tol) { - // Check for non-negative tolerance - if (abs_tol < 0.0) { - throw std::invalid_argument("Tolerance must be non-negative"); - } - - // Exact equality check - if (a == b) { - return true; - } - - // Check for infinities - if (std::isinf(a) || std::isinf(b)) { - return false; - } - - // Regular comparison using the provided absolute tolerance - return std::fabs(b - a) <= abs_tol; -} - -#ifdef GGML_USE_CPU_HBM -#include -#endif - -static void zeros(std::ofstream & file, size_t n) { - char zero = 0; - for (size_t i = 0; i < n; ++i) { - file.write(&zero, 1); - } -} - -LLAMA_ATTRIBUTE_FORMAT(1, 2) -static std::string format(const char * fmt, ...) { - va_list ap; - va_list ap2; - va_start(ap, fmt); - va_copy(ap2, ap); - int size = vsnprintf(NULL, 0, fmt, ap); - GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT - std::vector buf(size + 1); - int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); - GGML_ASSERT(size2 == size); - va_end(ap2); - va_end(ap); - return std::string(buf.data(), size); -} - -// -// gguf constants (sync with gguf.py) -// - -enum llm_arch { - LLM_ARCH_LLAMA, - LLM_ARCH_FALCON, - LLM_ARCH_BAICHUAN, - LLM_ARCH_GPT2, - LLM_ARCH_GPTJ, - LLM_ARCH_GPTNEOX, - LLM_ARCH_MPT, - LLM_ARCH_STARCODER, - LLM_ARCH_PERSIMMON, - LLM_ARCH_REFACT, - LLM_ARCH_BLOOM, - LLM_ARCH_UNKNOWN, -}; - -static std::map LLM_ARCH_NAMES = { - { LLM_ARCH_LLAMA, "llama" }, - { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_GPT2, "gpt2" }, - { LLM_ARCH_GPTJ, "gptj" }, - { LLM_ARCH_GPTNEOX, "gptneox" }, - { LLM_ARCH_MPT, "mpt" }, - { LLM_ARCH_BAICHUAN, "baichuan" }, - { LLM_ARCH_STARCODER, "starcoder" }, - { LLM_ARCH_PERSIMMON, "persimmon" }, - { LLM_ARCH_REFACT, "refact" }, - { LLM_ARCH_BLOOM, "bloom" }, -}; - -enum llm_kv { - LLM_KV_GENERAL_ARCHITECTURE, - LLM_KV_GENERAL_QUANTIZATION_VERSION, - LLM_KV_GENERAL_ALIGNMENT, - LLM_KV_GENERAL_NAME, - LLM_KV_GENERAL_AUTHOR, - LLM_KV_GENERAL_URL, - LLM_KV_GENERAL_DESCRIPTION, - LLM_KV_GENERAL_LICENSE, - LLM_KV_GENERAL_SOURCE_URL, - LLM_KV_GENERAL_SOURCE_HF_REPO, - - LLM_KV_CONTEXT_LENGTH, - LLM_KV_EMBEDDING_LENGTH, - LLM_KV_BLOCK_COUNT, - LLM_KV_FEED_FORWARD_LENGTH, - LLM_KV_USE_PARALLEL_RESIDUAL, - LLM_KV_TENSOR_DATA_LAYOUT, - - LLM_KV_ATTENTION_HEAD_COUNT, - LLM_KV_ATTENTION_HEAD_COUNT_KV, - LLM_KV_ATTENTION_MAX_ALIBI_BIAS, - LLM_KV_ATTENTION_CLAMP_KQV, - LLM_KV_ATTENTION_LAYERNORM_EPS, - LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, - - LLM_KV_ROPE_DIMENSION_COUNT, - LLM_KV_ROPE_FREQ_BASE, - LLM_KV_ROPE_SCALE_LINEAR, - - LLM_KV_TOKENIZER_MODEL, - LLM_KV_TOKENIZER_LIST, - LLM_KV_TOKENIZER_TOKEN_TYPE, - LLM_KV_TOKENIZER_SCORES, - LLM_KV_TOKENIZER_MERGES, - LLM_KV_TOKENIZER_BOS_ID, - LLM_KV_TOKENIZER_EOS_ID, - LLM_KV_TOKENIZER_UNK_ID, - LLM_KV_TOKENIZER_SEP_ID, - LLM_KV_TOKENIZER_PAD_ID, - LLM_KV_TOKENIZER_HF_JSON, - LLM_KV_TOKENIZER_RWKV, -}; - -static std::map LLM_KV_NAMES = { - { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, - { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, - { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, - { LLM_KV_GENERAL_NAME, "general.name" }, - { LLM_KV_GENERAL_AUTHOR, "general.author" }, - { LLM_KV_GENERAL_URL, "general.url" }, - { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, - { LLM_KV_GENERAL_LICENSE, "general.license" }, - { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, - { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, - - { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, - { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, - { LLM_KV_BLOCK_COUNT, "%s.block_count" }, - { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, - { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, - { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, - - { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, - { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, - { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, - { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, - { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, - { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, - - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, - - { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, - { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, - { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, - { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, - { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, - { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, - { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, - { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, - { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, - { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, - { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, -}; - -struct LLM_KV { - LLM_KV(llm_arch arch) : arch(arch) {} - - llm_arch arch; - - std::string operator()(llm_kv kv) const { - return ::format(LLM_KV_NAMES[kv].c_str(), LLM_ARCH_NAMES[arch].c_str()); - } -}; - -enum llm_tensor { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, -}; - -static std::map> LLM_TENSOR_NAMES = { - { - LLM_ARCH_LLAMA, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_BAICHUAN, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_FALCON, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_GPT2, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - }, - }, - { - LLM_ARCH_GPTJ, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - }, - }, - { - LLM_ARCH_GPTNEOX, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_PERSIMMON, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd"}, - { LLM_TENSOR_OUTPUT_NORM, "output_norm"}, - { LLM_TENSOR_OUTPUT, "output"}, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm"}, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv"}, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output"}, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"}, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"}, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm"}, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down"}, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up"}, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd"}, - }, - }, - { - LLM_ARCH_MPT, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_STARCODER, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_POS_EMBD, "position_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - }, - }, - { - LLM_ARCH_REFACT, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - }, - }, - { - LLM_ARCH_BLOOM, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - }, - }, - { - LLM_ARCH_UNKNOWN, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - }, - }, -}; - -static llm_arch llm_arch_from_string(const std::string & name) { - for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT - if (kv.second == name) { - return kv.first; - } - } - - return LLM_ARCH_UNKNOWN; -} - -// helper to handle gguf constants -// usage: -// -// const auto tn = LLM_TN(LLM_ARCH_LLAMA); -// -// std::string name = tn(LLM_TENSOR_OUTPUT); -> "output" -// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" -// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" -// -struct LLM_TN { - LLM_TN(llm_arch arch) : arch(arch) {} - - llm_arch arch; - - std::string operator()(llm_tensor tensor) const { - return LLM_TENSOR_NAMES[arch].at(tensor); - } - - std::string operator()(llm_tensor tensor, const std::string & suffix) const { - return LLM_TENSOR_NAMES[arch].at(tensor) + "." + suffix; - } - - std::string operator()(llm_tensor tensor, int bid) const { - return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid); - } - - std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { - return ::format(LLM_TENSOR_NAMES[arch].at(tensor).c_str(), bid) + "." + suffix; - } -}; - -// -// gguf helpers -// - -#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \ -do { \ - const std::string skey(key); \ - const int kid = gguf_find_key(ctx, skey.c_str()); \ - if (kid >= 0) { \ - enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \ - if (ktype != (type)) { \ - throw std::runtime_error(format("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype))); \ - } \ - (dst) = func(ctx, kid); \ - } else if (req) { \ - throw std::runtime_error(format("key not found in model: %s", skey.c_str())); \ - } \ -} while (0) - -// -// ggml helpers -// - -static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { - struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); - - if (plan.work_size > 0) { - buf.resize(plan.work_size); - plan.work_data = buf.data(); - } - - ggml_graph_compute(graph, &plan); -} - -// -// llama helpers -// - -#ifdef GGML_USE_CUBLAS -# define llama_host_malloc(n) ggml_cuda_host_malloc(n) -# define llama_host_free(data) ggml_cuda_host_free(data) -#elif GGML_USE_METAL -# define llama_host_malloc(n) ggml_metal_host_malloc(n) -# define llama_host_free(data) ggml_metal_host_free(data) -#elif GGML_USE_CPU_HBM -# define llama_host_malloc(n) hbw_malloc(n) -# define llama_host_free(data) if (data != NULL) hbw_free(data) -#else -# define llama_host_malloc(n) malloc(n) -# define llama_host_free(data) free(data) -#endif - -#if defined(_WIN32) -static std::string llama_format_win_err(DWORD err) { - LPSTR buf; - size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); - if (!size) { - return "FormatMessageA failed"; - } - std::string ret(buf, size); - LocalFree(buf); - return ret; -} -#endif - -struct llama_buffer { - void * data = NULL; - size_t size = 0; - - // fallback to malloc / free - // useful in cases where CUDA can try to allocate PINNED memory - bool fallback = false; - - void resize(size_t n) { - llama_host_free(data); - - data = llama_host_malloc(n); - if (!data) { - fallback = true; - data = malloc(n); - } else { - fallback = false; - } - - GGML_ASSERT(data); - size = n; - } - - ~llama_buffer() { - if (data) { - if (fallback) { // NOLINT - free(data); - } else { - llama_host_free(data); - } - } - - data = NULL; - } -}; - -struct llama_file { - // use FILE * so we don't have to re-open the file to mmap - FILE * fp; - size_t size; - - llama_file(const char * fname, const char * mode) { - fp = std::fopen(fname, mode); - if (fp == NULL) { - throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); - } - seek(0, SEEK_END); - size = tell(); - seek(0, SEEK_SET); - } - - size_t tell() const { -#ifdef _WIN32 - __int64 ret = _ftelli64(fp); -#else - long ret = std::ftell(fp); -#endif - GGML_ASSERT(ret != -1); // this really shouldn't fail - return (size_t) ret; - } - - void seek(size_t offset, int whence) const { -#ifdef _WIN32 - int ret = _fseeki64(fp, (__int64) offset, whence); -#else - int ret = std::fseek(fp, (long) offset, whence); -#endif - GGML_ASSERT(ret == 0); // same - } - - void read_raw(void * ptr, size_t len) const { - if (len == 0) { - return; - } - errno = 0; - std::size_t ret = std::fread(ptr, len, 1, fp); - if (ferror(fp)) { - throw std::runtime_error(format("read error: %s", strerror(errno))); - } - if (ret != 1) { - throw std::runtime_error(std::string("unexpectedly reached end of file")); - } - } - - uint32_t read_u32() const { - uint32_t ret; - read_raw(&ret, sizeof(ret)); - return ret; - } - - void write_raw(const void * ptr, size_t len) const { - if (len == 0) { - return; - } - errno = 0; - size_t ret = std::fwrite(ptr, len, 1, fp); - if (ret != 1) { - throw std::runtime_error(format("write error: %s", strerror(errno))); - } - } - - void write_u32(std::uint32_t val) const { - write_raw(&val, sizeof(val)); - } - - ~llama_file() { - if (fp) { - std::fclose(fp); - } - } -}; - -struct llama_mmap { - void * addr; - size_t size; - - llama_mmap(const llama_mmap &) = delete; - -#ifdef _POSIX_MAPPED_FILES - static constexpr bool SUPPORTED = true; - - llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false) { - size = file->size; - int fd = fileno(file->fp); - int flags = MAP_SHARED; - // prefetch/readahead impairs performance on NUMA systems - if (numa) { prefetch = 0; } -#ifdef __linux__ - if (prefetch) { flags |= MAP_POPULATE; } -#endif - addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0); - if (addr == MAP_FAILED) { - throw std::runtime_error(format("mmap failed: %s", strerror(errno))); - } - - if (prefetch > 0) { - // Advise the kernel to preload the mapped memory - if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) { - fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n", - strerror(errno)); - } - } - if (numa) { - // advise the kernel not to use readahead - // (because the next page might not belong on the same node) - if (posix_madvise(addr, file->size, POSIX_MADV_RANDOM)) { - fprintf(stderr, "warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n", - strerror(errno)); - } - } - } - - ~llama_mmap() { - munmap(addr, size); - } -#elif defined(_WIN32) - static constexpr bool SUPPORTED = true; - - llama_mmap(struct llama_file * file, bool prefetch = true, bool numa = false) { - (void) numa; - - size = file->size; - - HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp)); - - HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); - DWORD error = GetLastError(); - - if (hMapping == NULL) { - throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str())); - } - - addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); - error = GetLastError(); - CloseHandle(hMapping); - - if (addr == NULL) { - throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); - } - - if (prefetch) { - // PrefetchVirtualMemory is only present on Windows 8 and above, so we dynamically load it - BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG); - HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll"); - - // may fail on pre-Windows 8 systems - pPrefetchVirtualMemory = reinterpret_cast (GetProcAddress(hKernel32, "PrefetchVirtualMemory")); - - if (pPrefetchVirtualMemory) { - // advise the kernel to preload the mapped memory - WIN32_MEMORY_RANGE_ENTRY range; - range.VirtualAddress = addr; - range.NumberOfBytes = (SIZE_T)size; - if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { - fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); - } - } - } - } - - ~llama_mmap() { - if (!UnmapViewOfFile(addr)) { - fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); - } - } -#else - static constexpr bool SUPPORTED = false; - - llama_mmap(struct llama_file * file, bool prefetch = true, bool numa = false) { - (void) file; - (void) prefetch; - (void) numa; - - throw std::runtime_error(std::string("mmap not supported")); - } -#endif -}; - -// Represents some region of memory being locked using mlock or VirtualLock; -// will automatically unlock on destruction. -struct llama_mlock { - void * addr = NULL; - size_t size = 0; - - bool failed_already = false; - - llama_mlock() {} - llama_mlock(const llama_mlock &) = delete; - - ~llama_mlock() { - if (size) { - raw_unlock(addr, size); - } - } - - void init(void * ptr) { - GGML_ASSERT(addr == NULL && size == 0); // NOLINT - addr = ptr; - } - - void grow_to(size_t target_size) { - GGML_ASSERT(addr); - if (failed_already) { - return; - } - size_t granularity = lock_granularity(); - target_size = (target_size + granularity - 1) & ~(granularity - 1); - if (target_size > size) { - if (raw_lock((uint8_t *) addr + size, target_size - size)) { - size = target_size; - } else { - failed_already = true; - } - } - } - -#ifdef _POSIX_MEMLOCK_RANGE - static constexpr bool SUPPORTED = true; - - static size_t lock_granularity() { - return (size_t) sysconf(_SC_PAGESIZE); - } - - #ifdef __APPLE__ - #define MLOCK_SUGGESTION \ - "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \ - "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n" - #else - #define MLOCK_SUGGESTION \ - "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n" - #endif - - bool raw_lock(const void * addr, size_t size) const { - if (!mlock(addr, size)) { - return true; - } - - char* errmsg = std::strerror(errno); - bool suggest = (errno == ENOMEM); - - // Check if the resource limit is fine after all - struct rlimit lock_limit; - if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) { - suggest = false; - } - if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) { - suggest = false; - } - - fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s", - size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : ""); - return false; - } - - #undef MLOCK_SUGGESTION - - static void raw_unlock(void * addr, size_t size) { - if (munlock(addr, size)) { - fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno)); - } - } -#elif defined(_WIN32) - static constexpr bool SUPPORTED = true; - - static size_t lock_granularity() { - SYSTEM_INFO si; - GetSystemInfo(&si); - return (size_t) si.dwPageSize; - } - - bool raw_lock(void * ptr, size_t len) const { - for (int tries = 1; ; tries++) { - if (VirtualLock(ptr, len)) { - return true; - } - if (tries == 2) { - fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n", - len, size, llama_format_win_err(GetLastError()).c_str()); - return false; - } - - // It failed but this was only the first try; increase the working - // set size and try again. - SIZE_T min_ws_size, max_ws_size; - if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) { - fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); - return false; - } - // Per MSDN: "The maximum number of pages that a process can lock - // is equal to the number of pages in its minimum working set minus - // a small overhead." - // Hopefully a megabyte is enough overhead: - size_t increment = len + 1048576; - // The minimum must be <= the maximum, so we need to increase both: - min_ws_size += increment; - max_ws_size += increment; - if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) { - fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); - return false; - } - } - } - - static void raw_unlock(void * ptr, size_t len) { - if (!VirtualUnlock(ptr, len)) { - fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n", - llama_format_win_err(GetLastError()).c_str()); - } - } -#else - static constexpr bool SUPPORTED = false; - - static size_t lock_granularity() { - return (size_t) 65536; - } - - bool raw_lock(const void * addr, size_t len) const { - fprintf(stderr, "warning: mlock not supported on this system\n"); - return false; - } - - static void raw_unlock(const void * addr, size_t len) {} -#endif -}; - -typedef void (*offload_func_t)(struct ggml_tensor * tensor); - -static void llama_nop(struct ggml_tensor * tensor) { // don't offload by default - (void) tensor; -} - -static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) { - std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); - if (n_tokens < 0) { - result.resize(-n_tokens); - int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); - GGML_ASSERT(check == -n_tokens); - } else { - result.resize(n_tokens); - } - - return std::string(result.data(), result.size()); -} - -// -// globals -// - -struct llama_state { - // We save the log callback globally - ggml_log_callback log_callback = llama_log_callback_default; - void * log_callback_user_data = nullptr; -}; - -static llama_state g_state; - -// available llama models -enum e_model { - MODEL_UNKNOWN, - MODEL_1B, - MODEL_3B, - MODEL_7B, - MODEL_8B, - MODEL_13B, - MODEL_15B, - MODEL_30B, - MODEL_34B, - MODEL_40B, - MODEL_65B, - MODEL_70B, -}; - -static const size_t kB = 1024; -static const size_t MB = kB*kB; -static const size_t GB = kB*kB*kB; - -struct llama_hparams { - bool vocab_only; - uint32_t n_vocab; - uint32_t n_ctx_train; // context size the model was trained on - uint32_t n_embd; - uint32_t n_head; - uint32_t n_head_kv; - uint32_t n_layer; - uint32_t n_rot; - uint32_t n_ff; - - float f_norm_eps; - float f_norm_rms_eps; - - float rope_freq_base_train; - float rope_freq_scale_train; - - float f_clamp_kqv; - float f_max_alibi_bias; - - bool operator!=(const llama_hparams & other) const { - if (this->vocab_only != other.vocab_only) return true; - if (this->n_vocab != other.n_vocab) return true; - if (this->n_ctx_train != other.n_ctx_train) return true; - if (this->n_embd != other.n_embd) return true; - if (this->n_head != other.n_head) return true; - if (this->n_head_kv != other.n_head_kv) return true; - if (this->n_layer != other.n_layer) return true; - if (this->n_rot != other.n_rot) return true; - if (this->n_ff != other.n_ff) return true; - - const float EPSILON = 1e-9; - - if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; - if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; - if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; - if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; - - return false; - } - - uint32_t n_gqa() const { - return n_head/n_head_kv; - } - - uint32_t n_embd_head() const { - return n_embd/n_head; - } - - uint32_t n_embd_gqa() const { - return n_embd/n_gqa(); - } -}; - -struct llama_cparams { - uint32_t n_ctx; // context size used during inference - uint32_t n_batch; - uint32_t n_threads; // number of threads to use for generation - uint32_t n_threads_batch; // number of threads to use for batch processing - - float rope_freq_base; - float rope_freq_scale; - - bool mul_mat_q; -}; - -struct llama_layer { - // normalization - struct ggml_tensor * attn_norm; - struct ggml_tensor * attn_norm_b; - struct ggml_tensor * attn_norm_2; - struct ggml_tensor * attn_norm_2_b; - struct ggml_tensor * attn_q_norm; - struct ggml_tensor * attn_q_norm_b; - struct ggml_tensor * attn_k_norm; - struct ggml_tensor * attn_k_norm_b; - - // attention - struct ggml_tensor * wq; - struct ggml_tensor * wk; - struct ggml_tensor * wv; - struct ggml_tensor * wo; - struct ggml_tensor * wqkv; - - // attention bias - struct ggml_tensor * bo; - struct ggml_tensor * bqkv; - - // normalization - struct ggml_tensor * ffn_norm; - struct ggml_tensor * ffn_norm_b; - - // ff - struct ggml_tensor * w1; // ffn_gate - struct ggml_tensor * w2; // ffn_down - struct ggml_tensor * w3; // ffn_up - - // ff bias - struct ggml_tensor * b2; // ffn_down - struct ggml_tensor * b3; // ffn_up -}; - -struct llama_kv_cell { - llama_pos pos = -1; - llama_pos delta = 0; - - std::set seq_id; - - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } -}; - -// ring-buffer of cached KV data -struct llama_kv_cache { - bool has_shift = false; - - // Note: The value of head isn't only used to optimize searching - // for a free KV slot. llama_decode_internal also uses it, so it - // cannot be freely changed after a slot has been allocated. - uint32_t head = 0; - uint32_t size = 0; - - // computed before each graph build - uint32_t n = 0; - - std::vector cells; - - struct ggml_tensor * k = NULL; - struct ggml_tensor * v = NULL; - - struct ggml_context * ctx = NULL; - - llama_buffer buf; - - ~llama_kv_cache() { - if (ctx) { - ggml_free(ctx); - } - -#ifdef GGML_USE_CUBLAS - ggml_cuda_free_data(k); - ggml_cuda_free_data(v); -#endif // GGML_USE_CUBLAS - } -}; - -struct llama_vocab { - using id = int32_t; - using token = std::string; - using ttype = llama_token_type; - - struct token_data { - token text; - float score; - ttype type; - }; - - enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; - - std::unordered_map token_to_id; - std::vector id_to_token; - - std::map, int> bpe_ranks; - - // default LLaMA special tokens - id special_bos_id = 1; - id special_eos_id = 2; - id special_unk_id = 0; - id special_sep_id = -1; - id special_pad_id = -1; - - id linefeed_id = 13; - id special_prefix_id = 32007; - id special_middle_id = 32009; - id special_suffix_id = 32008; - id special_eot_id = 32010; - - int find_bpe_rank(std::string token_left, std::string token_right) const { - replace_all(token_left, " ", "\u0120"); - replace_all(token_left, "\n", "\u010A"); - replace_all(token_right, " ", "\u0120"); - replace_all(token_right, "\n", "\u010A"); - - auto it = bpe_ranks.find(std::make_pair(token_left, token_right)); - if (it == bpe_ranks.end()) { - return -1; - } - - return it->second; - } -}; - -struct llama_model { - e_model type = MODEL_UNKNOWN; - llm_arch arch = LLM_ARCH_UNKNOWN; - llama_ftype ftype = LLAMA_FTYPE_ALL_F32; - - std::string name = "n/a"; - - llama_hparams hparams = {}; - llama_vocab vocab; - - struct ggml_tensor * tok_embeddings; - struct ggml_tensor * pos_embeddings; - struct ggml_tensor * tok_norm; - struct ggml_tensor * tok_norm_b; - - struct ggml_tensor * output_norm; - struct ggml_tensor * output_norm_b; - struct ggml_tensor * output; - - std::vector layers; - - int n_gpu_layers; - - // context - struct ggml_context * ctx = NULL; - - // the model memory buffer - llama_buffer buf; - - // model memory mapped file - std::unique_ptr mapping; - - // objects representing data potentially being locked in memory - llama_mlock mlock_buf; - llama_mlock mlock_mmap; - - // for quantize-stats only - std::vector> tensors_by_name; - - int64_t t_load_us = 0; - int64_t t_start_us = 0; - - ~llama_model() { - if (ctx) { - ggml_free(ctx); - } - -#ifdef GGML_USE_CUBLAS - for (size_t i = 0; i < tensors_by_name.size(); ++i) { - ggml_cuda_free_data(tensors_by_name[i].second); - } - ggml_cuda_free_scratch(); -#elif defined(GGML_USE_CLBLAST) - for (size_t i = 0; i < tensors_by_name.size(); ++i) { - ggml_cl_free_data(tensors_by_name[i].second); - } -#endif - } -}; - -struct llama_context { - llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} - ~llama_context() { -#ifdef GGML_USE_METAL - if (ctx_metal) { - ggml_metal_free(ctx_metal); - } -#endif - if (alloc) { - ggml_allocr_free(alloc); - } - } - - llama_cparams cparams; - - const llama_model & model; - - // key + value cache for the self attention - struct llama_kv_cache kv_self; - - std::mt19937 rng; - - bool has_evaluated_once = false; - - int64_t t_start_us; - int64_t t_load_us; - int64_t t_sample_us = 0; - int64_t t_p_eval_us = 0; - int64_t t_eval_us = 0; - - int32_t n_sample = 0; // number of tokens sampled - int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) - int32_t n_eval = 0; // number of eval calls - - // decode output (2-dimensional array: [n_tokens][n_vocab]) - std::vector logits; - bool logits_all = false; - - // input embedding (1-dimensional array: [n_embd]) - std::vector embedding; - - // reusable buffer for `struct ggml_graph_plan.work_data` - std::vector work_buffer; - - // memory buffers used to evaluate the model - llama_buffer buf_compute; - - llama_buffer buf_alloc; - ggml_allocr * alloc = NULL; - -#ifdef GGML_USE_METAL - ggml_metal_context * ctx_metal = NULL; -#endif - -#ifdef GGML_USE_MPI - ggml_mpi_context * ctx_mpi = NULL; -#endif -}; - -// -// kv cache helpers -// - -static bool llama_kv_cache_init( - const struct llama_hparams & hparams, - struct llama_kv_cache & cache, - ggml_type wtype, - uint32_t n_ctx, - int n_gpu_layers) { - const uint32_t n_embd = hparams.n_embd_gqa(); - const uint32_t n_layer = hparams.n_layer; - - const int64_t n_mem = n_layer*n_ctx; - const int64_t n_elements = n_embd*n_mem; - - cache.has_shift = false; - - cache.head = 0; - cache.size = n_ctx; - - cache.cells.clear(); - cache.cells.resize(n_ctx); - - // TODO: this should be: - // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead()); - // change it and test that it works - cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); - memset(cache.buf.data, 0, cache.buf.size); - - struct ggml_init_params params; - params.mem_size = cache.buf.size; - params.mem_buffer = cache.buf.data; - params.no_alloc = false; - - cache.ctx = ggml_init(params); - - if (!cache.ctx) { - LLAMA_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__); - return false; - } - - cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); - cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); - ggml_set_name(cache.k, "cache_k"); - ggml_set_name(cache.v, "cache_v"); - - (void) n_gpu_layers; -#ifdef GGML_USE_CUBLAS - size_t vram_kv_cache = 0; - - if (n_gpu_layers > (int)n_layer + 1) { - ggml_cuda_assign_buffers_no_scratch(cache.v); - LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__); - vram_kv_cache += ggml_nbytes(cache.v); - } - if (n_gpu_layers > (int)n_layer + 2) { - ggml_cuda_assign_buffers_no_scratch(cache.k); - LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__); - vram_kv_cache += ggml_nbytes(cache.k); - } - if (vram_kv_cache > 0) { - LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0); - } -#endif // GGML_USE_CUBLAS - - return true; -} - -// find an empty slot of size "n_tokens" in the cache -// updates the cache head -// Note: On success, it's important that cache.head points -// to the first cell of the slot. -static bool llama_kv_cache_find_slot( - struct llama_kv_cache & cache, - const struct llama_batch & batch) { - const uint32_t n_ctx = cache.size; - const uint32_t n_tokens = batch.n_tokens; - - if (n_tokens > n_ctx) { - LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); - return false; - } - - uint32_t n_tested = 0; - - while (true) { - if (cache.head + n_tokens > n_ctx) { - n_tested += n_ctx - cache.head; - cache.head = 0; - continue; - } - - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cache.cells[cache.head + i].pos >= 0) { - found = false; - cache.head += i + 1; - n_tested += i + 1; - break; - } - } - - if (found) { - break; - } - - if (n_tested >= n_ctx) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; - } - } - - for (uint32_t i = 0; i < n_tokens; i++) { - cache.cells[cache.head + i].pos = batch.pos[i]; - cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]); - } - - return true; -} - -// find how many cells are currently in use -static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { - for (uint32_t i = cache.size - 1; i > 0; --i) { - if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) { - return i + 1; - } - } - - return 0; -} - -static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) { - if (c0 < 0) c0 = 0; - if (c1 < 0) c1 = cache.size; - - for (int32_t i = c0; i < c1; ++i) { - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - } - - // Searching for a free slot can start here since we know it will be empty. - cache.head = uint32_t(c0); -} - -static void llama_kv_cache_seq_rm( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { - uint32_t new_head = cache.size; - - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); - - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.cells[i].seq_id.erase(seq_id); - if (cache.cells[i].seq_id.empty()) { - cache.cells[i].pos = -1; - if (new_head == cache.size) new_head = i; - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.size) cache.head = new_head; -} - -static void llama_kv_cache_seq_cp( - struct llama_kv_cache & cache, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); - - cache.head = 0; - - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.cells[i].seq_id.insert(seq_id_dst); - } - } -} - -static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { - uint32_t new_head = cache.size; - - for (uint32_t i = 0; i < cache.size; ++i) { - if (!cache.cells[i].has_seq_id(seq_id)) { - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) new_head = i; - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.size) cache.head = new_head; -} - -static void llama_kv_cache_seq_shift( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - uint32_t new_head = cache.size; - - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); - - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.cells[i].pos += delta; - if (cache.cells[i].pos < 0) { - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) new_head = i; - } else { - cache.has_shift = true; - cache.cells[i].delta = delta; - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - // Otherwise we just start the next search from the beginning. - cache.head = new_head != cache.size ? new_head : 0; -} - -// -// model loading and saving -// - -enum llama_fver { - GGUF_FILE_VERSION_V1 = 1, - GGUF_FILE_VERSION_V2 = 2, -}; - -static const char * llama_file_version_name(llama_fver version) { - switch (version) { - case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)"; - case GGUF_FILE_VERSION_V2: return "GGUF V2 (latest)"; - } - - return "unknown"; -} - -static std::string llama_format_tensor_shape(const std::vector & ne) { - char buf[256]; - snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0)); - for (size_t i = 1; i < ne.size(); i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i)); - } - return buf; -} - -static std::string llama_format_tensor_shape(const struct ggml_tensor * t) { - char buf[256]; - snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); - } - return buf; -} - -struct llama_model_loader { - int n_kv = 0; - int n_tensors = 0; - int n_created = 0; - - int64_t n_elements = 0; - size_t n_bytes = 0; - - bool use_mmap = false; - - llama_file file; - llama_ftype ftype; - llama_fver fver; - - std::unique_ptr mapping; - - struct gguf_context * ctx_gguf = NULL; - struct ggml_context * ctx_meta = NULL; - - llama_model_loader(const std::string & fname, bool use_mmap) : file(fname.c_str(), "rb") { - struct gguf_init_params params = { - /*.no_alloc = */ true, - /*.ctx = */ &ctx_meta, - }; - - ctx_gguf = gguf_init_from_file(fname.c_str(), params); - if (!ctx_gguf) { - throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str())); - } - - n_kv = gguf_get_n_kv(ctx_gguf); - n_tensors = gguf_get_n_tensors(ctx_gguf); - - fver = (enum llama_fver ) gguf_get_version(ctx_gguf); - - for (int i = 0; i < n_tensors; i++) { - const char * name = gguf_get_tensor_name(ctx_gguf, i); - struct ggml_tensor * t = ggml_get_tensor(ctx_meta, name); - n_elements += ggml_nelements(t); - n_bytes += ggml_nbytes(t); - } - - LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", - __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); - - // determine file type based on the number of tensors for each quantization and print meta data - // TODO: make optional - { - std::map n_type; - - uint32_t n_type_max = 0; - enum ggml_type type_max = GGML_TYPE_F32; - - for (int i = 0; i < n_tensors; i++) { - const char * name = gguf_get_tensor_name(ctx_gguf, i); - struct ggml_tensor * meta = ggml_get_tensor(ctx_meta, name); - - n_type[meta->type]++; - - if (n_type_max < n_type[meta->type]) { - n_type_max = n_type[meta->type]; - type_max = meta->type; - } - - LLAMA_LOG_INFO("%s: - tensor %4d: %32s %-8s [ %s ]\n", __func__, i, name, ggml_type_name(meta->type), llama_format_tensor_shape(meta).c_str()); - } - - switch (type_max) { - case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; - case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; - case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; - case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; - case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; - case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break; - case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break; - case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break; - case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break; - case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break; - case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; - case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; - default: - { - LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); - ftype = LLAMA_FTYPE_ALL_F32; - } break; - } - - // this is a way to mark that we have "guessed" the file type - ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED); - - { - const int kid = gguf_find_key(ctx_gguf, "general.file_type"); - if (kid >= 0) { - ftype = (llama_ftype) gguf_get_val_u32(ctx_gguf, kid); - } - } - - for (int i = 0; i < n_kv; i++) { - const char * name = gguf_get_key(ctx_gguf, i); - const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); - - LLAMA_LOG_INFO("%s: - kv %3d: %42s %-8s\n", __func__, i, name, gguf_type_name(type)); - } - - // print type counts - for (auto & kv : n_type) { - if (kv.second == 0) { - continue; - } - - LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second); - } - } - - if (!llama_mmap::SUPPORTED) { - LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__); - use_mmap = false; - } - - this->use_mmap = use_mmap; - } - - ~llama_model_loader() { - if (ctx_gguf) { - gguf_free(ctx_gguf); - } - if (ctx_meta) { - ggml_free(ctx_meta); - } - } - - std::string get_arch_name() const { - const auto kv = LLM_KV(LLM_ARCH_UNKNOWN); - - std::string arch_name; - GGUF_GET_KEY(ctx_gguf, arch_name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_ARCHITECTURE)); - - return arch_name; - } - - enum llm_arch get_arch() const { - const std::string arch_name = get_arch_name(); - - return llm_arch_from_string(arch_name); - } - - const char * get_tensor_name(int i) const { - return gguf_get_tensor_name(ctx_gguf, i); - } - - struct ggml_tensor * get_tensor_meta(int i) const { - return ggml_get_tensor(ctx_meta, get_tensor_name(i)); - } - - void calc_sizes(size_t & ctx_size_p, size_t & mmapped_size_p) const { - ctx_size_p = 0; - mmapped_size_p = 0; - - for (int i = 0; i < n_tensors; i++) { - struct ggml_tensor * meta = get_tensor_meta(i); - ctx_size_p += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE; - (use_mmap ? mmapped_size_p : ctx_size_p) += ggml_nbytes_pad(meta); - } - } - - struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, struct ggml_tensor * meta, ggml_backend_type backend) { - if (backend != GGML_BACKEND_CPU) { - ggml_set_no_alloc(ctx, true); - } - - struct ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); - tensor->backend = backend; // TODO: ggml_set_backend - ggml_set_name(tensor, ggml_get_name(meta)); - - if (backend != GGML_BACKEND_CPU) { - ggml_set_no_alloc(ctx, use_mmap); - } - - n_created++; - - return tensor; - } - - struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, ggml_backend_type backend) { - struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str()); - - if (cur == NULL) { - throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str())); - } - - { - bool is_ok = true; - for (size_t i = 0; i < ne.size(); ++i) { - if (ne[i] != cur->ne[i]) { - is_ok = false; - break; - } - } - if (!is_ok) { - throw std::runtime_error( - format("%s: tensor '%s' has wrong shape; expected %s, got %s", - __func__, name.c_str(), - llama_format_tensor_shape(ne).c_str(), - llama_format_tensor_shape(cur).c_str())); - } - } - - return create_tensor_for(ctx, cur, backend); - } - - void done_getting_tensors() const { - if (n_created != n_tensors) { - throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); - } - } - - size_t file_offset(const char * name) const { - const int idx = gguf_find_tensor(ctx_gguf, name); - - if (idx < 0) { - throw std::runtime_error(format("%s: tensor '%s' not found in the file", __func__, name)); - } - - return gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, idx); - } - - void load_data_for(struct ggml_tensor * cur) const { - const size_t offs = file_offset(ggml_get_name(cur)); - - if (use_mmap) { - cur->data = (uint8_t *) mapping->addr + offs; - } else { - file.seek(offs, SEEK_SET); - file.read_raw(cur->data, ggml_nbytes(cur)); - } - } - - void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) { - size_t size_data = 0; - size_t size_lock = 0; - size_t size_pref = 0; // prefetch - - for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); i++) { - struct ggml_tensor * cur = ggml_get_tensor(ctx, gguf_get_tensor_name(ctx_gguf, i)); - size_data += ggml_nbytes(cur); - if (cur->backend == GGML_BACKEND_CPU) { - size_pref += ggml_nbytes(cur); - } - } - - if (use_mmap) { - mapping.reset(new llama_mmap(&file, size_pref, ggml_is_numa())); - if (lmlock) { - lmlock->init(mapping->addr); - } - } - - size_t done_size = 0; - for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); i++) { - struct ggml_tensor * cur = ggml_get_tensor(ctx, gguf_get_tensor_name(ctx_gguf, i)); - GGML_ASSERT(cur); // unused tensors should have been caught by load_data already - - if (progress_callback) { - progress_callback((float) done_size / size_data, progress_callback_user_data); - } - - // allocate temp buffer if not using mmap - if (!use_mmap && cur->data == NULL) { - GGML_ASSERT(cur->backend != GGML_BACKEND_CPU); - #ifdef GGML_USE_CPU_HBM - cur->data = (uint8_t*)hbw_malloc(ggml_nbytes(cur)); - #else - cur->data = (uint8_t*)malloc(ggml_nbytes(cur)); - #endif - } - - load_data_for(cur); - - switch (cur->backend) { - case GGML_BACKEND_CPU: - if (use_mmap && lmlock) { - size_lock += ggml_nbytes(cur); - lmlock->grow_to(size_lock); - } - break; -#ifdef GGML_USE_CUBLAS - case GGML_BACKEND_GPU: - case GGML_BACKEND_GPU_SPLIT: - // old code: - //ggml_cuda_transform_tensor(lt.data, lt.ggml_tensor); - - // TODO: test if this works !! - ggml_cuda_transform_tensor(cur->data, cur); - if (!use_mmap) { - free(cur->data); - } - break; -#elif defined(GGML_USE_CLBLAST) - case GGML_BACKEND_GPU: - ggml_cl_transform_tensor(cur->data, cur); - if (!use_mmap) { - free(cur->data); - } - break; -#endif - default: - continue; - } - - done_size += ggml_nbytes(cur); - } - } -}; - -// -// load LLaMA models -// - -static std::string llama_model_arch_name(llm_arch arch) { - auto it = LLM_ARCH_NAMES.find(arch); - if (it == LLM_ARCH_NAMES.end()) { - return "unknown"; - } - return it->second; -} - -static std::string llama_model_ftype_name(llama_ftype ftype) { - if (ftype & LLAMA_FTYPE_GUESSED) { - return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; - } - - switch (ftype) { - case LLAMA_FTYPE_ALL_F32: return "all F32"; - case LLAMA_FTYPE_MOSTLY_F16: return "mostly F16"; - case LLAMA_FTYPE_MOSTLY_Q4_0: return "mostly Q4_0"; - case LLAMA_FTYPE_MOSTLY_Q4_1: return "mostly Q4_1"; - case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: - return "mostly Q4_1, some F16"; - case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; - case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; - case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; - - // K-quants - case LLAMA_FTYPE_MOSTLY_Q2_K: return "mostly Q2_K"; - case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "mostly Q3_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "mostly Q3_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "mostly Q3_K - Large"; - case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "mostly Q4_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "mostly Q4_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "mostly Q5_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "mostly Q5_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q6_K: return "mostly Q6_K"; - - default: return "unknown, may not work"; - } -} - -static const char * llama_model_type_name(e_model type) { - switch (type) { - case MODEL_1B: return "1B"; - case MODEL_3B: return "3B"; - case MODEL_7B: return "7B"; - case MODEL_8B: return "8B"; - case MODEL_13B: return "13B"; - case MODEL_15B: return "15B"; - case MODEL_30B: return "30B"; - case MODEL_34B: return "34B"; - case MODEL_40B: return "40B"; - case MODEL_65B: return "65B"; - case MODEL_70B: return "70B"; - default: return "?B"; - } -} - -static void llm_load_arch(llama_model_loader & ml, llama_model & model) { - model.arch = ml.get_arch(); - if (model.arch == LLM_ARCH_UNKNOWN) { - throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); - } -} - -static void llm_load_hparams( - llama_model_loader & ml, - llama_model & model) { - struct gguf_context * ctx = ml.ctx_gguf; - - const auto kv = LLM_KV(model.arch); - - auto & hparams = model.hparams; - - // get general kv - GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME)); - - // get hparams kv - GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST)); - GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH)); - GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH)); - GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH)); - GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT)); - GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT)); - - // n_head_kv is optional, default to n_head - hparams.n_head_kv = hparams.n_head; - GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV)); - - // rope_freq_base (optional) - hparams.rope_freq_base_train = 10000.0f; - GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); - - // rope_freq_scale (inverse of the kv) is optional - float ropescale = 1.0f; - GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); - hparams.rope_freq_scale_train = 1.0f/ropescale; - - // sanity check for n_rot (optional) - { - hparams.n_rot = hparams.n_embd / hparams.n_head; - - GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT)); - - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { - if (hparams.n_rot != hparams.n_embd / hparams.n_head) { - throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head)); - } - } - // gpt-neox n_rot = rotary_pct * (n_embd / n_head) - // gpt-j n_rot = rotary_dim - } - - // arch-specific KVs - switch (model.arch) { - case LLM_ARCH_LLAMA: - { - GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); - - switch (hparams.n_layer) { - case 26: model.type = e_model::MODEL_3B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - case 48: model.type = e_model::MODEL_34B; break; - case 60: model.type = e_model::MODEL_30B; break; - case 80: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_65B : e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; - } - } break; - case LLM_ARCH_FALCON: - { - GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); - - switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 60: model.type = e_model::MODEL_40B; break; - default: model.type = e_model::MODEL_UNKNOWN; - } - } break; - case LLM_ARCH_BAICHUAN: - { - GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); - switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; - } - } break; - case LLM_ARCH_STARCODER: - { - GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); - switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 36: model.type = e_model::MODEL_3B; break; - case 42: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_15B; break; - default: model.type = e_model::MODEL_UNKNOWN; - } - } break; - case LLM_ARCH_PERSIMMON: - { - GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); - switch (hparams.n_layer) { - case 36: model.type = e_model::MODEL_8B; break; - default: model.type = e_model::MODEL_UNKNOWN; - } - } break; - case LLM_ARCH_REFACT: - { - GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); - switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_1B; break; - default: model.type = e_model::MODEL_UNKNOWN; - } - } break; - case LLM_ARCH_BLOOM: - { - GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); - - switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 30: - switch (hparams.n_embd) { - case 2560: model.type = e_model::MODEL_3B; break; - case 4096: model.type = e_model::MODEL_7B; break; - } break; - } - } break; - case LLM_ARCH_MPT: - { - hparams.f_clamp_kqv = 0.0f; - - GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); - GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_CLAMP_KQV)); - GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS)); - - switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 48: model.type = e_model::MODEL_30B; break; - default: model.type = e_model::MODEL_UNKNOWN; - } - } break; - default: (void)0; - } - - model.ftype = ml.ftype; -} - -// TODO: This should probably be in llama.h -static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos); -static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch); - -static void llm_load_vocab( - llama_model_loader & ml, - llama_model & model) { - auto & vocab = model.vocab; - - struct gguf_context * ctx = ml.ctx_gguf; - - const auto kv = LLM_KV(model.arch); - - const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); - if (token_idx == -1) { - throw std::runtime_error("cannot find tokenizer vocab in model file\n"); - } - - const float * scores = nullptr; - const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); - if (score_idx != -1) { - scores = (const float * ) gguf_get_arr_data(ctx, score_idx); - } - - const int * toktypes = nullptr; - const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); - if (toktype_idx != -1) { - toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); - } - - // determine vocab type - { - std::string tokenizer_name; - - GGUF_GET_KEY(ctx, tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, kv(LLM_KV_TOKENIZER_MODEL)); - - if (tokenizer_name == "llama") { - vocab.type = LLAMA_VOCAB_TYPE_SPM; - - // default special tokens - vocab.special_bos_id = 1; - vocab.special_eos_id = 2; - vocab.special_unk_id = 0; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; - } else if (tokenizer_name == "gpt2") { - vocab.type = LLAMA_VOCAB_TYPE_BPE; - - // read bpe merges and populate bpe ranks - const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); - if (merges_keyidx == -1) { - throw std::runtime_error("cannot find tokenizer merges in model file\n"); - } - - const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - - for (int i = 0; i < n_merges; i++) { - const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); - GGML_ASSERT(codepoints_from_utf8(word).size() > 0); - - std::string first; - std::string second; - - const size_t pos = word.find(' ', 1); - - if (pos != std::string::npos) { - first = word.substr(0, pos); - second = word.substr(pos + 1); - } - - vocab.bpe_ranks.emplace(std::make_pair(first, second), i); - } - - // default special tokens - vocab.special_bos_id = 11; - vocab.special_eos_id = 11; - vocab.special_unk_id = -1; - vocab.special_sep_id = -1; - vocab.special_pad_id = -1; - } else { - LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); - LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); - - vocab.type = LLAMA_VOCAB_TYPE_SPM; - } - } - - const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); - - vocab.id_to_token.resize(n_vocab); - - for (uint32_t i = 0; i < n_vocab; i++) { - std::string word = gguf_get_arr_str(ctx, token_idx, i); - GGML_ASSERT(codepoints_from_utf8(word).size() > 0); - - vocab.token_to_id[word] = i; - - auto & token_data = vocab.id_to_token[i]; - token_data.text = std::move(word); - token_data.score = scores ? scores[i] : 0.0f; - token_data.type = toktypes ? (llama_token_type) toktypes[i] : LLAMA_TOKEN_TYPE_NORMAL; - } - GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); - - // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' - if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { - vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); - } else { - vocab.linefeed_id = llama_tokenize_internal(vocab, "\u010A", false)[0]; - } - - // special tokens - GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID)); - GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID)); - GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID)); - GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID)); - GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID)); -} - -static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { - const auto & hparams = model.hparams; - const auto & vocab = model.vocab; - - // hparams - LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); - LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix - LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); - LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size()); - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); - LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim - LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); - LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); - LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); - LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); - LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); - LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); - LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); - if (ml.n_bytes < GB) { - LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); - } else { - LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); - } - - // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); - - // special tokens - if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } - if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } - if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } - if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } - if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } - if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } -} - -static void llm_load_tensors( - llama_model_loader & ml, - llama_model & model, - int n_gpu_layers, - int main_gpu, - const float * tensor_split, - bool use_mlock, - llama_progress_callback progress_callback, - void * progress_callback_user_data) { - model.t_start_us = ggml_time_us(); - - auto & ctx = model.ctx; - auto & hparams = model.hparams; - - model.n_gpu_layers = n_gpu_layers; - - size_t ctx_size; - size_t mmapped_size; - - ml.calc_sizes(ctx_size, mmapped_size); - - LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0); - - // create the ggml context - { - model.buf.resize(ctx_size); - if (use_mlock) { - model.mlock_buf.init (model.buf.data); - model.mlock_buf.grow_to(model.buf.size); - } - - struct ggml_init_params params = { - /*.mem_size =*/ model.buf.size, - /*.mem_buffer =*/ model.buf.data, - /*.no_alloc =*/ ml.use_mmap, - }; - - model.ctx = ggml_init(params); - if (!model.ctx) { - throw std::runtime_error(format("ggml_init() failed")); - } - } - - (void) main_gpu; -#ifdef GGML_USE_CUBLAS - LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__); - ggml_cuda_set_main_device(main_gpu); -#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU -#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT -#elif defined(GGML_USE_CLBLAST) - LLAMA_LOG_INFO("%s: using OpenCL for GPU acceleration\n", __func__); -#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU -#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU -#else -#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CPU -#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU -#endif - - // prepare memory for the weights - size_t vram_weights = 0; - { - const int64_t n_embd = hparams.n_embd; - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - const int64_t n_layer = hparams.n_layer; - const int64_t n_vocab = hparams.n_vocab; - - const auto tn = LLM_TN(model.arch); - switch (model.arch) { - case LLM_ARCH_LLAMA: - case LLM_ARCH_REFACT: - { - model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); - - // output - { - ggml_backend_type backend_norm; - ggml_backend_type backend_output; - - if (n_gpu_layers > int(n_layer)) { - // norm is not performance relevant on its own but keeping it in VRAM reduces data copying - // on Windows however this is detrimental unless everything is on the GPU -#ifndef _WIN32 - backend_norm = LLAMA_BACKEND_OFFLOAD; -#else - backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; -#endif // _WIN32 - - backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; - } else { - backend_norm = GGML_BACKEND_CPU; - backend_output = GGML_BACKEND_CPU; - } - - model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); - model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } - } - - const uint32_t n_ff = hparams.n_ff; - - const int i_gpu_start = n_layer - n_gpu_layers; - - model.layers.resize(n_layer); - - for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT - const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT - - auto & layer = model.layers[i]; - - layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - - layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); - layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); - layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - - layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); - } - } - } break; - case LLM_ARCH_BAICHUAN: - { - model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); - { - ggml_backend_type backend_norm; - ggml_backend_type backend_output; - - if (n_gpu_layers > int(n_layer)) { - // norm is not performance relevant on its own but keeping it in VRAM reduces data copying - // on Windows however this is detrimental unless everything is on the GPU -#ifndef _WIN32 - backend_norm = LLAMA_BACKEND_OFFLOAD; -#else - backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; -#endif // _WIN32 - - backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; - } else { - backend_norm = GGML_BACKEND_CPU; - backend_output = GGML_BACKEND_CPU; - } - - model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); - model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } - } - - const uint32_t n_ff = hparams.n_ff; - - const int i_gpu_start = n_layer - n_gpu_layers; - - model.layers.resize(n_layer); - - for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT - const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT - - auto & layer = model.layers[i]; - - layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - - layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); - layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); - layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - - layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); - } - } - } break; - case LLM_ARCH_FALCON: - { - // TODO: CPU-only for now - - model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); - - // output - { - ggml_backend_type backend_norm; - ggml_backend_type backend_output; - - if (n_gpu_layers > int(n_layer)) { - // norm is not performance relevant on its own but keeping it in VRAM reduces data copying - // on Windows however this is detrimental unless everything is on the GPU -#ifndef _WIN32 - backend_norm = LLAMA_BACKEND_OFFLOAD; -#else - backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; -#endif // _WIN32 - - backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; - } else { - backend_norm = GGML_BACKEND_CPU; - backend_output = GGML_BACKEND_CPU; - } - - model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); - model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); - model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - vram_weights += ggml_nbytes(model.output_norm_b); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } - } - - const uint32_t n_ff = hparams.n_ff; - - const int i_gpu_start = n_layer - n_gpu_layers; - - model.layers.resize(n_layer); - - for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT - const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT - - auto & layer = model.layers[i]; - - layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); - - if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) { - layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend); - layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, backend); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(layer.attn_norm_2); - vram_weights += ggml_nbytes(layer.attn_norm_2_b); - } - } - - layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + - ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.wo) + - ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); - } - } - } break; - case LLM_ARCH_STARCODER: - { - model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); - model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU); - - // output - { - ggml_backend_type backend_norm; - ggml_backend_type backend_output; - - if (n_gpu_layers > int(n_layer)) { - // norm is not performance relevant on its own but keeping it in VRAM reduces data copying - // on Windows however this is detrimental unless everything is on the GPU -#ifndef _WIN32 - backend_norm = LLAMA_BACKEND_OFFLOAD; -#else - backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; -#endif // _WIN32 - - backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; - } else { - backend_norm = GGML_BACKEND_CPU; - backend_output = GGML_BACKEND_CPU; - } - - model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); - model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); - model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - vram_weights += ggml_nbytes(model.output_norm_b); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } - } - - const uint32_t n_ff = hparams.n_ff; - - const int i_gpu_start = n_layer - n_gpu_layers; - - model.layers.resize(n_layer); - - for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT - const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT - - auto & layer = model.layers[i]; - - layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); - - layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); - layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split); - - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); - - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); - - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); - layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); - - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + - ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + - ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + - ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) + - ggml_nbytes(layer.w2) + ggml_nbytes(layer.b2) + - ggml_nbytes(layer.w3) + ggml_nbytes(layer.b3); - } - } - } break; - case LLM_ARCH_PERSIMMON: - { - model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); - - { - ggml_backend_type backend_norm; - ggml_backend_type backend_output; - - if (n_gpu_layers > int(n_layer)) { - // norm is not performance relevant on its own but keeping it in VRAM reduces data copying - // on Windows however this is detrimental unless everything is on the GPU -#ifndef _WIN32 - backend_norm = LLAMA_BACKEND_OFFLOAD; -#else - backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; -#endif // _WIN32 - - backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; - } else { - backend_norm = GGML_BACKEND_CPU; - backend_output = GGML_BACKEND_CPU; - } - - model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); - model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); - model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - vram_weights += ggml_nbytes(model.output_norm_b); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } - } - - const uint32_t n_ff = hparams.n_ff; - const int i_gpu_start = n_layer - n_gpu_layers; - model.layers.resize(n_layer); - for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; - const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; - auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); - layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); - layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split); - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); - layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); - layer.attn_q_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {64}, backend); - layer.attn_q_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {64}, backend); - layer.attn_k_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {64}, backend); - layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}, backend); - } - } break; - case LLM_ARCH_BLOOM: - { - // TODO: CPU-only for now - - model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); - model.tok_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU); - model.tok_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU); - - // output - { - ggml_backend_type backend_norm; - ggml_backend_type backend_output; - - if (n_gpu_layers > int(n_layer)) { - // norm is not performance relevant on its own but keeping it in VRAM reduces data copying - // on Windows however this is detrimental unless everything is on the GPU -#ifndef _WIN32 - backend_norm = LLAMA_BACKEND_OFFLOAD; -#else - backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; -#endif // _WIN32 - - backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; - } else { - backend_norm = GGML_BACKEND_CPU; - backend_output = GGML_BACKEND_CPU; - } - - model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); - model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); - model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - vram_weights += ggml_nbytes(model.output_norm_b); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } - } - - const uint32_t n_ff = hparams.n_ff; - - const int i_gpu_start = n_layer - n_gpu_layers; - - model.layers.resize(n_layer); - - for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT - const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT - - auto & layer = model.layers[i]; - - layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); - - layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); - layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split); - - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); - - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); - - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); - layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); - - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + - ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + - ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + - ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) + - ggml_nbytes(layer.w3) + ggml_nbytes(layer.b3) + - ggml_nbytes(layer.w2) + ggml_nbytes(layer.b2); - } - } - } break; - case LLM_ARCH_MPT: - { - model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); - - // output - { - ggml_backend_type backend_norm; - ggml_backend_type backend_output; - - if (n_gpu_layers > int(n_layer)) { - // norm is not performance relevant on its own but keeping it in VRAM reduces data copying - // on Windows however this is detrimental unless everything is on the GPU -#ifndef _WIN32 - backend_norm = LLAMA_BACKEND_OFFLOAD; -#else - backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; -#endif // _WIN32 - - backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; - } else { - backend_norm = GGML_BACKEND_CPU; - backend_output = GGML_BACKEND_CPU; - } - - model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); - model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } - } - - const uint32_t n_ff = hparams.n_ff; - - const int i_gpu_start = n_layer - n_gpu_layers; - - model.layers.resize(n_layer); - - for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT - const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT - - auto & layer = model.layers[i]; - - layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); - layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); - layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - - layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); - - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + - ggml_nbytes(layer.wqkv) + - ggml_nbytes(layer.wo) + - ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.w2) + - ggml_nbytes(layer.w3); - } - } - } break; - default: - throw std::runtime_error("unknown architecture"); - } - } - - ml.done_getting_tensors(); - - // print memory requirements - { - // this is the total memory required to run the inference - size_t mem_required = - ctx_size + - mmapped_size - vram_weights; // weights in VRAM not in memory - - LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); - -#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) - const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); - - LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); - if (n_gpu_layers > (int) hparams.n_layer) { - LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__); - } - -#ifdef GGML_USE_CUBLAS - const int max_backend_supported_layers = hparams.n_layer + 3; - const int max_offloadable_layers = hparams.n_layer + 3; -#elif defined(GGML_USE_CLBLAST) - const int max_backend_supported_layers = hparams.n_layer + 1; - const int max_offloadable_layers = hparams.n_layer + 1; -#endif // GGML_USE_CUBLAS - - LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); - LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, vram_weights / 1024.0 / 1024.0); -#else - (void) n_gpu_layers; -#endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) - } - - // populate `tensors_by_name` - for (int i = 0; i < ml.n_tensors; ++i) { - struct ggml_tensor * cur = ggml_get_tensor(ctx, ml.get_tensor_name(i)); - model.tensors_by_name.emplace_back(ggml_get_name(cur), cur); - } - - (void) tensor_split; -#ifdef GGML_USE_CUBLAS - { - ggml_cuda_set_tensor_split(tensor_split); - } -#endif - - ml.load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL); - - if (progress_callback) { - progress_callback(1.0f, progress_callback_user_data); - } - - model.mapping = std::move(ml.mapping); - - // loading time will be recalculate after the first eval, so - // we take page faults deferred by mmap() into consideration - model.t_load_us = ggml_time_us() - model.t_start_us; -} - -static bool llama_model_load( - const std::string & fname, - llama_model & model, - int n_gpu_layers, - int main_gpu, - const float * tensor_split, - bool use_mmap, - bool use_mlock, - bool vocab_only, - llama_progress_callback progress_callback, - void *progress_callback_user_data) { - try { - llama_model_loader ml(fname, use_mmap); - - model.hparams.vocab_only = vocab_only; - - llm_load_arch (ml, model); - llm_load_hparams(ml, model); - llm_load_vocab (ml, model); - - llm_load_print_meta(ml, model); - - if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { - throw std::runtime_error("vocab size mismatch"); - } - - if (vocab_only) { - LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); - return true; - } - - llm_load_tensors( - ml, model, n_gpu_layers, - main_gpu, tensor_split, - use_mlock, progress_callback, progress_callback_user_data); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("error loading model: %s\n", err.what()); - return false; - } - - return true; -} - -static struct ggml_cgraph * llm_build_llama( - llama_context & lctx, - const llama_batch & batch) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - GGML_ASSERT(n_embd_head == hparams.n_rot); - - const float freq_base = cparams.rope_freq_base; - const float freq_scale = cparams.rope_freq_scale; - const float norm_rms_eps = hparams.f_norm_rms_eps; - - const int n_gpu_layers = model.n_gpu_layers; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; - const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; - - const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; - - //printf("n_kv = %d\n", n_kv); - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - if (batch.token) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inp_tokens); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); - } - ggml_set_name(inp_tokens, "inp_tokens"); - - inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); - } else { -#ifdef GGML_USE_MPI - GGML_ASSERT(false && "not implemented"); -#endif - - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inpL); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); - } - } - - const int i_gpu_start = n_layer - n_gpu_layers; - (void) i_gpu_start; - - // offload functions set the tensor output backend to GPU - // tensors are GPU-accelerated if any input or the output has been offloaded - offload_func_t offload_func_nr = llama_nop; // nr = non-repeating - offload_func_t offload_func_kq = llama_nop; - offload_func_t offload_func_v = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) { - offload_func_nr = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 1) { - offload_func_v = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 2) { - offload_func_kq = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); - ggml_allocr_alloc(lctx.alloc, KQ_scale); - if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); - } - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - offload_func_kq(KQ_mask); - ggml_set_name(KQ_mask, "KQ_mask"); - ggml_allocr_alloc(lctx.alloc, KQ_mask); - if (!ggml_allocr_is_measure(lctx.alloc)) { - float * data = (float *) KQ_mask->data; - memset(data, 0, ggml_nbytes(KQ_mask)); - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; - - for (int i = 0; i < n_kv; ++i) { - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; - } - } - } - } - } - - // KQ_pos - contains the positions - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - offload_func_kq(KQ_pos); - ggml_set_name(KQ_pos, "KQ_pos"); - ggml_allocr_alloc(lctx.alloc, KQ_pos); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) KQ_pos->data; - for (int i = 0; i < n_tokens; ++i) { - data[i] = batch.pos[i]; - } - } - - // shift the entire K-cache if needed - if (do_rope_shift) { - struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); - offload_func_kq(K_shift); - ggml_set_name(K_shift, "K_shift"); - ggml_allocr_alloc(lctx.alloc, K_shift); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) K_shift->data; - for (int i = 0; i < n_ctx; ++i) { - data[i] = kv_self.cells[i].delta; - } - } - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * tmp = - ggml_rope_custom_inplace(ctx0, - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_head_kv, n_ctx, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), - K_shift, n_embd_head, 0, 0, freq_base, freq_scale); - offload_func_kq(tmp); - ggml_build_forward_expand(gf, tmp); - } - } - - for (int il = 0; il < n_layer; ++il) { - ggml_format_name(inpL, "layer_inp_%d", il); - - offload_func_t offload_func = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (il >= i_gpu_start) { - offload_func = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - - struct ggml_tensor * inpSA = inpL; - - // norm - { - cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps); - offload_func(cur); - ggml_set_name(cur, "rms_norm_0"); - - // cur = cur*attn_norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); - offload_func(cur); - ggml_set_name(cur, "attention_norm_0"); - } - - // self-attention - { - // compute Q and K and RoPE them - struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - offload_func_kq(tmpk); - ggml_set_name(tmpk, "tmpk"); - - struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - offload_func_kq(tmpq); - ggml_set_name(tmpq, "tmpq"); - - struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); - offload_func_kq(Kcur); - ggml_set_name(Kcur, "Kcur"); - - struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); - offload_func_kq(Qcur); - ggml_set_name(Qcur, "Qcur"); - - // store key and value to memory - { - // compute the transposed [n_tokens, n_embd] V matrix - - struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - offload_func_v(tmpv); - ggml_set_name(tmpv, "tmpv"); - - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); - offload_func_v(Vcur); - ggml_set_name(Vcur, "Vcur"); - - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); - offload_func_kq(k); - ggml_set_name(k, "k"); - - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); - offload_func_v(v); - ggml_set_name(v, "v"); - - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } - - struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - offload_func_kq(Q); - ggml_set_name(Q, "Q"); - - struct ggml_tensor * K = - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); - offload_func_kq(K); - ggml_set_name(K, "K"); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - offload_func_kq(KQ); - ggml_set_name(KQ, "KQ"); - - // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_kv, n_tokens, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); - offload_func_kq(KQ_scaled); - ggml_set_name(KQ_scaled, "KQ_scaled"); - - // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); - offload_func_kq(KQ_masked); - ggml_set_name(KQ_masked, "KQ_masked"); - - // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - offload_func_v(KQ_soft_max); - ggml_set_name(KQ_soft_max, "KQ_soft_max"); - - // split cached V into n_head heads - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v)*n_ctx, - ggml_element_size(kv_self.v)*n_ctx*n_embd_head, - ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); - offload_func_v(V); - ggml_set_name(V, "V"); - -#if 1 - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - offload_func_v(KQV); - ggml_set_name(KQV, "KQV"); -#else - // make V contiguous in memory to speed up the matmul, however we waste time on the copy - // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation - // is there a better way? - struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head)); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); -#endif - - // KQV_merged = KQV.permute(0, 2, 1, 3) - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - offload_func_v(KQV_merged); - ggml_set_name(KQV_merged, "KQV_merged"); - - // cur = KQV_merged.contiguous().view(n_embd, n_tokens) - cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); - offload_func_v(cur); - ggml_set_name(cur, "KQV_merged_contiguous"); - - // projection (no bias) - cur = ggml_mul_mat(ctx0, - model.layers[il].wo, - cur); - offload_func(cur); - ggml_set_name(cur, "result_wo"); - } - - struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); - offload_func(inpFF); - ggml_set_name(inpFF, "inpFF"); - - // feed-forward network - { - // norm - { - cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps); - offload_func(cur); - ggml_set_name(cur, "rms_norm_1"); - - // cur = cur*ffn_norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); - offload_func(cur); - ggml_set_name(cur, "ffn_norm"); - } - - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model.layers[il].w3, - cur); - offload_func(tmp); - ggml_set_name(tmp, "result_w3"); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w1, - cur); - offload_func(cur); - ggml_set_name(cur, "result_w1"); - - // SILU activation - cur = ggml_silu(ctx0, cur); - offload_func(cur); - ggml_set_name(cur, "silu"); - - cur = ggml_mul(ctx0, cur, tmp); - offload_func(cur); - ggml_set_name(cur, "silu_x_result_w3"); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w2, - cur); - offload_func(cur); - ggml_set_name(cur, "result_w2"); - } - - cur = ggml_add(ctx0, cur, inpFF); - offload_func(cur); - ggml_set_name(cur, "inpFF_+_result_w2"); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - // norm - { - cur = ggml_rms_norm(ctx0, cur, norm_rms_eps); - offload_func_nr(cur); - ggml_set_name(cur, "rms_norm_2"); - - // cur = cur*norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.output_norm); - // offload_func_nr(cur); // TODO CPU + GPU mirrored backend - ggml_set_name(cur, "result_norm"); - } - - // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llm_build_baichaun( - llama_context & lctx, - const llama_batch & batch) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - GGML_ASSERT(n_embd_head == hparams.n_rot); - - const float freq_base = cparams.rope_freq_base; - const float freq_scale = cparams.rope_freq_scale; - const float norm_rms_eps = hparams.f_norm_rms_eps; - - const int n_gpu_layers = model.n_gpu_layers; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; - const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; - - const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - if (batch.token) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inp_tokens); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); - } - ggml_set_name(inp_tokens, "inp_tokens"); - - inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); - } else { -#ifdef GGML_USE_MPI - GGML_ASSERT(false && "not implemented"); -#endif - - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inpL); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); - } - } - - const int i_gpu_start = n_layer - n_gpu_layers; - (void) i_gpu_start; - - // offload functions set the tensor output backend to GPU - // tensors are GPU-accelerated if any input or the output has been offloaded - offload_func_t offload_func_nr = llama_nop; // nr = non-repeating - offload_func_t offload_func_kq = llama_nop; - offload_func_t offload_func_v = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) { - offload_func_nr = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 1) { - offload_func_v = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 2) { - offload_func_kq = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); - ggml_allocr_alloc(lctx.alloc, KQ_scale); - if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); - } - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - offload_func_kq(KQ_mask); - ggml_set_name(KQ_mask, "KQ_mask"); - ggml_allocr_alloc(lctx.alloc, KQ_mask); - if (!ggml_allocr_is_measure(lctx.alloc)) { - float * data = (float *) KQ_mask->data; - memset(data, 0, ggml_nbytes(KQ_mask)); - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; - - for (int i = 0; i < n_kv; ++i) { - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; - } - } - } - } - } - - // KQ_pos - contains the positions - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - offload_func_kq(KQ_pos); - ggml_set_name(KQ_pos, "KQ_pos"); - ggml_allocr_alloc(lctx.alloc, KQ_pos); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) KQ_pos->data; - for (int i = 0; i < n_tokens; ++i) { - data[i] = batch.pos[i]; - } - } - - // shift the entire K-cache if needed - if (do_rope_shift) { - struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); - offload_func_kq(K_shift); - ggml_set_name(K_shift, "K_shift"); - ggml_allocr_alloc(lctx.alloc, K_shift); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) K_shift->data; - for (int i = 0; i < n_ctx; ++i) { - data[i] = kv_self.cells[i].delta; - } - } - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * tmp = - ggml_rope_custom_inplace(ctx0, - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_head_kv, n_ctx, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), - K_shift, n_embd_head, 0, 0, freq_base, freq_scale); - offload_func_kq(tmp); - ggml_build_forward_expand(gf, tmp); - } - } - - for (int il = 0; il < n_layer; ++il) { - ggml_format_name(inpL, "layer_inp_%d", il); - - offload_func_t offload_func = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (il >= i_gpu_start) { - offload_func = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - - struct ggml_tensor * inpSA = inpL; - - // norm - { - cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps); - offload_func(cur); - ggml_set_name(cur, "rms_norm_0"); - - // cur = cur*attn_norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); - offload_func(cur); - ggml_set_name(cur, "attention_norm_0"); - } - - // self-attention - { - // compute Q and K and RoPE them - struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - offload_func_kq(tmpk); - ggml_set_name(tmpk, "tmpk"); - - struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - offload_func_kq(tmpq); - ggml_set_name(tmpq, "tmpq"); - - struct ggml_tensor * Kcur; - struct ggml_tensor * Qcur; - switch (model.type) { - case MODEL_7B: - Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); - Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); - break; - case MODEL_13B: - Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, n_tokens); - Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, n_tokens); - break; - default: - GGML_ASSERT(false); - } - - offload_func_kq(Kcur); - ggml_set_name(Kcur, "Kcur"); - - offload_func_kq(Qcur); - ggml_set_name(Qcur, "Qcur"); - - // store key and value to memory - { - // compute the transposed [n_tokens, n_embd] V matrix - - struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - offload_func_v(tmpv); - ggml_set_name(tmpv, "tmpv"); - - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); - offload_func_v(Vcur); - ggml_set_name(Vcur, "Vcur"); - - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); - offload_func_kq(k); - ggml_set_name(k, "k"); - - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); - offload_func_v(v); - ggml_set_name(v, "v"); - - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } - - struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - offload_func_kq(Q); - ggml_set_name(Q, "Q"); - - struct ggml_tensor * K = - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); - offload_func_kq(K); - ggml_set_name(K, "K"); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - offload_func_kq(KQ); - ggml_set_name(KQ, "KQ"); - - // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); - offload_func_kq(KQ_scaled); - ggml_set_name(KQ_scaled, "KQ_scaled"); - - struct ggml_tensor * KQ_masked; - struct ggml_tensor * KQ_scaled_alibi; - - switch (model.type) { - case MODEL_7B: - KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); - break; - case MODEL_13B: - // TODO: replace with ggml_add() - KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8); - ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); - KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); - break; - default: - GGML_ASSERT(false); - } - - // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - offload_func_v(KQ_soft_max); - ggml_set_name(KQ_soft_max, "KQ_soft_max"); - - // split cached V into n_head heads - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v)*n_ctx, - ggml_element_size(kv_self.v)*n_ctx*n_embd_head, - ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); - offload_func_v(V); - ggml_set_name(V, "V"); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - offload_func_v(KQV); - ggml_set_name(KQV, "KQV"); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - offload_func_v(KQV_merged); - ggml_set_name(KQV_merged, "KQV_merged"); - - // cur = KQV_merged.contiguous().view(n_embd, n_tokens) - cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); - offload_func_v(cur); - ggml_set_name(cur, "KQV_merged_contiguous"); - - // projection (no bias) - cur = ggml_mul_mat(ctx0, - model.layers[il].wo, - cur); - offload_func(cur); - ggml_set_name(cur, "result_wo"); - } - - struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); - offload_func(inpFF); - ggml_set_name(inpFF, "inpFF"); - - // feed-forward network - { - // norm - { - cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps); - offload_func(cur); - ggml_set_name(cur, "rms_norm_1"); - - // cur = cur*ffn_norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); - offload_func(cur); - ggml_set_name(cur, "ffn_norm"); - } - - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model.layers[il].w3, - cur); - offload_func(tmp); - ggml_set_name(tmp, "result_w3"); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w1, - cur); - offload_func(cur); - ggml_set_name(cur, "result_w1"); - - // SILU activation - cur = ggml_silu(ctx0, cur); - offload_func(cur); - ggml_set_name(cur, "silu"); - - cur = ggml_mul(ctx0, cur, tmp); - offload_func(cur); - ggml_set_name(cur, "silu_x_result_w3"); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w2, - cur); - offload_func(cur); - ggml_set_name(cur, "result_w2"); - } - - cur = ggml_add(ctx0, cur, inpFF); - offload_func(cur); - ggml_set_name(cur, "inpFF_+_result_w2"); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - // norm - { - cur = ggml_rms_norm(ctx0, cur, norm_rms_eps); - offload_func_nr(cur); - ggml_set_name(cur, "rms_norm_2"); - - // cur = cur*norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.output_norm); - // offload_func_nr(cur); // TODO CPU + GPU mirrored backend - ggml_set_name(cur, "result_norm"); - } - - // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llm_build_refact( - llama_context & lctx, - const llama_batch & batch) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - const float norm_rms_eps = hparams.f_norm_rms_eps; - - const int n_gpu_layers = model.n_gpu_layers; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; - const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; - - // printf("n_kv = %d\n", n_kv); - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - if (batch.token) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inp_tokens); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); - } - ggml_set_name(inp_tokens, "inp_tokens"); - - inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); - } else { -#ifdef GGML_USE_MPI - GGML_ASSERT(false && "not implemented"); -#endif - - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inpL); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); - } - } - - const int i_gpu_start = n_layer - n_gpu_layers; - (void) i_gpu_start; - - // offload functions set the tensor output backend to GPU - // tensors are GPU-accelerated if any input or the output has been offloaded - offload_func_t offload_func_nr = llama_nop; // nr = non-repeating - offload_func_t offload_func_kq = llama_nop; - offload_func_t offload_func_v = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) { - offload_func_nr = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 1) { - offload_func_v = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 2) { - offload_func_kq = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); - ggml_allocr_alloc(lctx.alloc, KQ_scale); - if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); - } - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - offload_func_kq(KQ_mask); - ggml_set_name(KQ_mask, "KQ_mask"); - ggml_allocr_alloc(lctx.alloc, KQ_mask); - if (!ggml_allocr_is_measure(lctx.alloc)) { - float * data = (float *) KQ_mask->data; - memset(data, 0, ggml_nbytes(KQ_mask)); - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; - - for (int i = 0; i < n_kv; ++i) { - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; - } - } - } - } - } - - for (int il = 0; il < n_layer; ++il) { - ggml_format_name(inpL, "layer_inp_%d", il); - - offload_func_t offload_func = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (il >= i_gpu_start) { - offload_func = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - - struct ggml_tensor * inpSA = inpL; - - // norm - { - cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps); - offload_func(cur); - ggml_set_name(cur, "rms_norm_0"); - - // cur = cur*attn_norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); - offload_func(cur); - ggml_set_name(cur, "attention_norm_0"); - } - - // self-attention - { - // compute Q and K - struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - offload_func_kq(tmpk); - ggml_set_name(tmpk, "tmpk"); - - struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - offload_func_kq(tmpq); - ggml_set_name(tmpq, "tmpq"); - - struct ggml_tensor * Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens); - offload_func_kq(Kcur); - ggml_set_name(Kcur, "Kcur"); - - struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens); - offload_func_kq(Qcur); - ggml_set_name(Qcur, "Qcur"); - - // store key and value to memory - { - // compute the transposed [n_tokens, n_embd] V matrix - - struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - offload_func_v(tmpv); - ggml_set_name(tmpv, "tmpv"); - - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); - offload_func_v(Vcur); - ggml_set_name(Vcur, "Vcur"); - - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); - offload_func_kq(k); - ggml_set_name(k, "k"); - - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); - offload_func_v(v); - ggml_set_name(v, "v"); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } - - struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - offload_func_kq(Q); - ggml_set_name(Q, "Q"); - - struct ggml_tensor * K = - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); - offload_func_kq(K); - ggml_set_name(K, "K"); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - offload_func_kq(KQ); - ggml_set_name(KQ, "KQ"); - - // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_kv, n_tokens, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); - offload_func_kq(KQ_scaled); - ggml_set_name(KQ_scaled, "KQ_scaled"); - - // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8); - ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); - - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); - offload_func_kq(KQ_masked); - ggml_set_name(KQ_masked, "KQ_masked"); - - // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - offload_func_v(KQ_soft_max); - ggml_set_name(KQ_soft_max, "KQ_soft_max"); - - // split cached V into n_head heads - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v)*n_ctx, - ggml_element_size(kv_self.v)*n_ctx*n_embd_head, - ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); - offload_func_v(V); - ggml_set_name(V, "V"); - -#if 1 - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - offload_func_v(KQV); - ggml_set_name(KQV, "KQV"); -#else - // make V contiguous in memory to speed up the matmul, however we waste time on the copy - // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation - // is there a better way? - struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head)); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); -#endif - - // KQV_merged = KQV.permute(0, 2, 1, 3) - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - offload_func_v(KQV_merged); - ggml_set_name(KQV_merged, "KQV_merged"); - - // cur = KQV_merged.contiguous().view(n_embd, n_tokens) - cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); - offload_func_v(cur); - ggml_set_name(cur, "KQV_merged_contiguous"); - - // projection (no bias) - cur = ggml_mul_mat(ctx0, - model.layers[il].wo, - cur); - offload_func(cur); - ggml_set_name(cur, "result_wo"); - } - - struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); - offload_func(inpFF); - ggml_set_name(inpFF, "inpFF"); - - // feed-forward network - { - // norm - { - cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps); - offload_func(cur); - ggml_set_name(cur, "rms_norm_1"); - - // cur = cur*ffn_norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); - offload_func(cur); - ggml_set_name(cur, "ffn_norm"); - } - - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model.layers[il].w3, - cur); - offload_func(tmp); - ggml_set_name(tmp, "result_w3"); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w1, - cur); - offload_func(cur); - ggml_set_name(cur, "result_w1"); - - // SILU activation - cur = ggml_silu(ctx0, cur); - offload_func(cur); - ggml_set_name(cur, "silu"); - - cur = ggml_mul(ctx0, cur, tmp); - offload_func(cur); - ggml_set_name(cur, "silu_x_result_w3"); - - cur = ggml_mul_mat(ctx0, - model.layers[il].w2, - cur); - offload_func(cur); - ggml_set_name(cur, "result_w2"); - } - - cur = ggml_add(ctx0, cur, inpFF); - offload_func(cur); - ggml_set_name(cur, "inpFF_+_result_w2"); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - // norm - { - cur = ggml_rms_norm(ctx0, cur, norm_rms_eps); - offload_func_nr(cur); - ggml_set_name(cur, "rms_norm_2"); - - // cur = cur*norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.output_norm); - // offload_func_nr(cur); // TODO CPU + GPU mirrored backend - ggml_set_name(cur, "result_norm"); - } - - // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llm_build_falcon( - llama_context & lctx, - const llama_batch & batch) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - GGML_ASSERT(n_embd_head == hparams.n_rot); - - const float freq_base = cparams.rope_freq_base; - const float freq_scale = cparams.rope_freq_scale; - const float norm_eps = hparams.f_norm_eps; - - const int n_gpu_layers = model.n_gpu_layers; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; - const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; - - const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; - - //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n", - // kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift); - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - if (batch.token) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inp_tokens); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); - } - ggml_set_name(inp_tokens, "inp_tokens"); - - inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); - } else { -#ifdef GGML_USE_MPI - GGML_ASSERT(false && "not implemented"); -#endif - - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inpL); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); - } - } - - const int i_gpu_start = n_layer - n_gpu_layers; - (void) i_gpu_start; - - // offload functions set the tensor output backend to GPU - // tensors are GPU-accelerated if any input or the output has been offloaded - offload_func_t offload_func_nr = llama_nop; // nr = non-repeating - offload_func_t offload_func_kq = llama_nop; - offload_func_t offload_func_v = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) { - offload_func_nr = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 1) { - offload_func_v = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 2) { - offload_func_kq = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); - ggml_allocr_alloc(lctx.alloc, KQ_scale); - if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); - } - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - offload_func_kq(KQ_mask); - ggml_set_name(KQ_mask, "KQ_mask"); - ggml_allocr_alloc(lctx.alloc, KQ_mask); - if (!ggml_allocr_is_measure(lctx.alloc)) { - float * data = (float *) KQ_mask->data; - memset(data, 0, ggml_nbytes(KQ_mask)); - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; - - for (int i = 0; i < n_kv; ++i) { - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; - } - } - } - } - } - - // KQ_pos - contains the positions - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - offload_func_kq(KQ_pos); - ggml_set_name(KQ_pos, "KQ_pos"); - ggml_allocr_alloc(lctx.alloc, KQ_pos); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) KQ_pos->data; - for (int i = 0; i < n_tokens; ++i) { - data[i] = batch.pos[i]; - } - } - - // shift the entire K-cache if needed - if (do_rope_shift) { - struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); - offload_func_kq(K_shift); - ggml_set_name(K_shift, "K_shift"); - ggml_allocr_alloc(lctx.alloc, K_shift); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) K_shift->data; - for (int i = 0; i < n_ctx; ++i) { - data[i] = kv_self.cells[i].delta; - } - } - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * tmp = - ggml_rope_custom_inplace(ctx0, - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_head_kv, n_ctx, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), - K_shift, n_embd_head, 2, 0, freq_base, freq_scale); - offload_func_kq(tmp); - ggml_build_forward_expand(gf, tmp); - } - } - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * attn_norm; - - offload_func_t offload_func = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (il >= i_gpu_start) { - offload_func = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - - // self-attention - // TODO: refactor into common function (shared with LLaMA) - { - attn_norm = ggml_norm(ctx0, inpL, norm_eps); - offload_func(attn_norm); - - attn_norm = ggml_add(ctx0, - ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm), - model.layers[il].attn_norm_b); - offload_func(attn_norm->src[0]); - offload_func(attn_norm); - - if (model.layers[il].attn_norm_2) { // Falcon-40B - cur = ggml_norm(ctx0, inpL, norm_eps); - offload_func(cur); - - cur = ggml_add(ctx0, - ggml_mul(ctx0, cur, model.layers[il].attn_norm_2), - model.layers[il].attn_norm_2_b); - offload_func(cur->src[0]); - offload_func(cur); - } else { // Falcon 7B - cur = attn_norm; - } - - // compute QKV - - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); - offload_func_kq(cur); - - // Note that the strides for Kcur, Vcur are set up so that the - // resulting views are misaligned with the tensor's storage - // (by applying the K/V offset we shift the tensor's original - // view to stick out behind the viewed QKV tensor's allocated - // memory, so to say). This is ok because no actual accesses - // happen to that out-of-range memory, but it can require some - // trickery when trying to accurately dump these views for - // debugging. - - const size_t wsize = ggml_type_size(cur->type); - - // TODO: these 2 ggml_conts are technically not needed, but we add them until CUDA support for - // non-contiguous views is added for the rope operator - struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d( - ctx0, cur, n_embd_head, n_head, n_tokens, - wsize * n_embd_head, - wsize * n_embd_head * (n_head + 2 * n_head_kv), - 0)); - offload_func_kq(tmpq); - - struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_3d( - ctx0, cur, n_embd_head, n_head_kv, n_tokens, - wsize * n_embd_head, - wsize * n_embd_head * (n_head + 2 * n_head_kv), - wsize * n_embd_head * n_head)); - offload_func_kq(tmpk); - - struct ggml_tensor * tmpv = ggml_view_3d( - ctx0, cur, n_embd_head, n_head_kv, n_tokens, - wsize * n_embd_head, - wsize * n_embd_head * (n_head + 2 * n_head_kv), - wsize * n_embd_head * (n_head + n_head_kv)); - offload_func_v(tmpv); - - // using mode = 2 for neox mode - struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, tmpq, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale); - offload_func_kq(Qcur); - struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, tmpk, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale); - offload_func_kq(Kcur); - - { - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); - offload_func_v(Vcur); - offload_func_v(Vcur->src[0]->src[0]); - ggml_set_name(Vcur, "Vcur"); - - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); - offload_func_kq(k); - ggml_set_name(k, "k"); - - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); - offload_func_v(v); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } - - struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - offload_func_kq(Q); - ggml_set_name(Q, "Q"); - - struct ggml_tensor * K = - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); - offload_func_kq(K); - ggml_set_name(K, "K"); - - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - offload_func_kq(KQ); - ggml_set_name(KQ, "KQ"); - - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); - offload_func_kq(KQ_scaled); - ggml_set_name(KQ_scaled, "KQ_scaled"); - - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); - offload_func_kq(KQ_masked); - ggml_set_name(KQ_masked, "KQ_masked"); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - offload_func_v(KQ_soft_max); - ggml_set_name(KQ_soft_max, "KQ_soft_max"); - - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v)*n_ctx, - ggml_element_size(kv_self.v)*n_ctx*n_embd_head, - ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); - offload_func_v(V); - ggml_set_name(V, "V"); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - offload_func_v(KQV); - ggml_set_name(KQV, "KQV"); - - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - offload_func_v(KQV_merged); - ggml_set_name(KQV_merged, "KQV_merged"); - - cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); - offload_func_v(cur); - ggml_set_name(cur, "KQV_merged_contiguous"); - - cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); - offload_func(cur); - ggml_set_name(cur, "result_wo"); - } - - struct ggml_tensor * attn_out = cur; - - // feed forward - { - struct ggml_tensor * inpFF = attn_norm; - - cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); - offload_func(cur); - - cur = ggml_gelu(ctx0, cur); - offload_func(cur); - cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); - offload_func(cur); - } - - cur = ggml_add(ctx0, cur, attn_out); - offload_func(cur); - cur = ggml_add(ctx0, cur, inpL); - offload_func(cur); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - // norm - { - cur = ggml_norm(ctx0, cur, norm_eps); - offload_func_nr(cur); - - cur = ggml_add(ctx0, - ggml_mul(ctx0, cur, model.output_norm), - model.output_norm_b); - ggml_set_name(cur, "result_norm"); - } - - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llm_build_starcoder( - llama_context & lctx, - const llama_batch & batch) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - GGML_ASSERT(n_embd_head == hparams.n_rot); - - const float norm_eps = hparams.f_norm_eps; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; - const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * token; - struct ggml_tensor * position; - struct ggml_tensor * inpL; - - if (batch.token) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inp_tokens); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); - } - ggml_set_name(inp_tokens, "inp_tokens"); - - token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); - } else { -#ifdef GGML_USE_MPI - GGML_ASSERT(false && "not implemented"); -#endif - - token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - - ggml_allocr_alloc(lctx.alloc, token); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(token->data, batch.embd, n_tokens * n_embd * ggml_element_size(token)); - } - } - - { - // Compute position embeddings. - struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - ggml_allocr_alloc(lctx.alloc, inp_positions); - if (!ggml_allocr_is_measure(lctx.alloc)) { - for (int i = 0; i < n_tokens; ++i) { - ((int32_t *) inp_positions->data)[i] = batch.pos[i]; - } - } - ggml_set_name(inp_positions, "inp_positions"); - - position = ggml_get_rows(ctx0, model.pos_embeddings, inp_positions); - } - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); - ggml_allocr_alloc(lctx.alloc, KQ_scale); - if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); - } - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - ggml_set_name(KQ_mask, "KQ_mask"); - ggml_allocr_alloc(lctx.alloc, KQ_mask); - if (!ggml_allocr_is_measure(lctx.alloc)) { - float * data = (float *) KQ_mask->data; - memset(data, 0, ggml_nbytes(KQ_mask)); - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; - - for (int i = 0; i < n_kv; ++i) { - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; - } - } - } - } - } - - inpL = ggml_add(ctx0, token, position); - ggml_set_name(inpL, "inpL"); - - for (int il = 0; il < n_layer; ++il) { - { - // Norm - cur = ggml_norm(ctx0, inpL, norm_eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b); - } - - { - // Self Attention - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv); - - struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd); - struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd); - struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa)); - - struct ggml_tensor * Qcur = tmpq; - struct ggml_tensor * Kcur = tmpk; - - { - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); - ggml_set_name(Vcur, "Vcur"); - - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); - ggml_set_name(k, "k"); - - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } - - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)), - 0, 2, 1, 3); - ggml_set_name(Q, "Q"); - - struct ggml_tensor * K = - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); - ggml_set_name(K, "K"); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - ggml_set_name(KQ, "KQ"); - - // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); - ggml_set_name(KQ_scaled, "KQ_scaled"); - - // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); - ggml_set_name(KQ_masked, "KQ_masked"); - - // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); - ggml_set_name(KQ_soft_max, "KQ_soft_max"); - - // split cached V into n_head heads - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v)*n_ctx, - ggml_element_size(kv_self.v)*n_ctx*n_embd_head, - ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); - ggml_set_name(V, "V"); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - ggml_set_name(KQV, "KQV"); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - ggml_set_name(KQV_merged, "KQV_merged"); - - // cur = KQV_merged.contiguous().view(n_embd, n_tokens) - cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); - ggml_set_name(cur, "KQV_merged_contiguous"); - } - - // Projection - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo); - - // Add the input - cur = ggml_add(ctx0, cur, inpL); - - struct ggml_tensor * inpFF = cur; - - // FF - { - // Norm - { - cur = ggml_norm(ctx0, inpFF, norm_eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b); - } - - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3); - - // GELU activation - cur = ggml_gelu(ctx0, cur); - - // Projection - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2); - } - - inpL = ggml_add(ctx0, cur, inpFF); - } - - // Output Norm - { - cur = ggml_norm(ctx0, inpL, norm_eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b); - } - ggml_set_name(cur, "result_norm"); - - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); - - ggml_build_forward_expand(gf, cur); - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llm_build_persimmon( - llama_context & lctx, - const llama_batch & batch) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const auto & cparams = lctx.cparams; - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_head = hparams.n_head; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - const size_t n_rot = n_embd_head / 2; - - const float freq_base = cparams.rope_freq_base; - const float freq_scale = cparams.rope_freq_scale; - const float norm_eps = hparams.f_norm_eps; - - const int n_gpu_layers = model.n_gpu_layers; - - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; - const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; - - const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; - - auto & buf_compute = lctx.buf_compute; - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - if (batch.token) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inp_tokens); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); - } - ggml_set_name(inp_tokens, "inp_tokens"); - inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); - } else { - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - ggml_allocr_alloc(lctx.alloc, inpL); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); - } - } - const int i_gpu_start = n_layer - n_gpu_layers; - (void) i_gpu_start; - offload_func_t offload_func_nr = llama_nop; // nr = non-repeating - offload_func_t offload_func_kq = llama_nop; - offload_func_t offload_func_v = llama_nop; - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - ggml_allocr_alloc(lctx.alloc, KQ_scale); - if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); - } - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - offload_func_kq(KQ_mask); - ggml_set_name(KQ_mask, "KQ_mask"); - ggml_allocr_alloc(lctx.alloc, KQ_mask); - - if (!ggml_allocr_is_measure(lctx.alloc)) { - float * data = (float *) KQ_mask->data; - memset(data, 0, ggml_nbytes(KQ_mask)); - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; - for (int i = 0; i < n_kv; ++i) { - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; - } - } - } - } - } - - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - offload_func_kq(KQ_pos); - ggml_set_name(KQ_pos, "KQ_pos"); - ggml_allocr_alloc(lctx.alloc, KQ_pos); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) KQ_pos->data; - for (int i = 0; i < n_tokens; ++i) { - data[i] = batch.pos[i]; - } - } - if (do_rope_shift) { - struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); - offload_func_kq(K_shift); - ggml_set_name(K_shift, "K_shift"); - ggml_allocr_alloc(lctx.alloc, K_shift); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) K_shift->data; - for (int i = 0; i < n_ctx; ++i) { - data[i] = kv_self.cells[i].delta; - } - } - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * tmp = - // we rotate only the first n_rot dimensions. - ggml_rope_custom_inplace(ctx0, - ggml_view_3d(ctx0, kv_self.k, - n_rot, n_head, n_ctx, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*(n_embd_head*n_ctx*il) - ), - K_shift, n_rot, 2, 0, freq_base, freq_scale); - offload_func_kq(tmp); - ggml_build_forward_expand(gf, tmp); - } - } - for (int il=0; il < n_layer; ++il) { - struct ggml_tensor * residual = inpL; - offload_func_t offload_func = llama_nop; - { - cur = ggml_norm(ctx0, inpL, norm_eps); - offload_func(cur); - cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); - offload_func(cur); - cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b); - offload_func(cur); - ggml_format_name(cur, "input_layernorm_%d", il); - } - // self attention - { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); - offload_func_kq(cur); - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - offload_func_kq(cur); - - // split qkv - GGML_ASSERT(n_head_kv == n_head); - ggml_set_name(cur, format("qkv_%d", il).c_str()); - struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens); - offload_func_kq(tmpqkv); - struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2)); - offload_func_kq(tmpqkv_perm); - ggml_format_name(tmpqkv_perm, "tmpqkv_perm_%d", il); - struct ggml_tensor * tmpq = ggml_view_3d( - ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, - ggml_element_size(tmpqkv_perm) * n_embd_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, - 0 - ); - offload_func_kq(tmpq); - struct ggml_tensor * tmpk = ggml_view_3d( - ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, - ggml_element_size(tmpqkv_perm) * n_embd_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens - ); - offload_func_kq(tmpk); - // Q/K Layernorm - tmpq = ggml_norm(ctx0, tmpq, norm_eps); - offload_func_kq(tmpq); - tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm); - offload_func_kq(tmpq); - tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b); - offload_func_kq(tmpq); - - tmpk = ggml_norm(ctx0, tmpk, norm_eps); - offload_func_v(tmpk); - tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm); - offload_func_v(tmpk); - tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b); - offload_func_v(tmpk); - - // RoPE the first n_rot of q/k, pass the other half, and concat. - struct ggml_tensor * qrot = ggml_view_3d( - ctx0, tmpq, n_rot, n_head, n_tokens, - ggml_element_size(tmpq) * n_embd_head, - ggml_element_size(tmpq) * n_embd_head * n_head, - 0 - ); - offload_func_kq(qrot); - ggml_format_name(qrot, "qrot_%d", il); - struct ggml_tensor * krot = ggml_view_3d( - ctx0, tmpk, n_rot, n_head, n_tokens, - ggml_element_size(tmpk) * n_embd_head, - ggml_element_size(tmpk) * n_embd_head * n_head, - 0 - ); - offload_func_kq(krot); - ggml_format_name(krot, "krot_%d", il); - - // get the second half of tmpq, e.g tmpq[n_rot:, :, :] - struct ggml_tensor * qpass = ggml_view_3d( - ctx0, tmpq, n_rot, n_head, n_tokens, - ggml_element_size(tmpq) * n_embd_head, - ggml_element_size(tmpq) * n_embd_head * n_head, - ggml_element_size(tmpq) * n_rot - ); - offload_func_kq(qpass); - ggml_format_name(qpass, "qpass_%d", il); - struct ggml_tensor * kpass = ggml_view_3d( - ctx0, tmpk, n_rot, n_head, n_tokens, - ggml_element_size(tmpk) * n_embd_head, - ggml_element_size(tmpk) * n_embd_head * n_head, - ggml_element_size(tmpk) * n_rot - ); - offload_func_kq(kpass); - ggml_format_name(kpass, "kpass_%d", il); - - struct ggml_tensor * qrotated = ggml_rope_custom( - ctx0, qrot, KQ_pos, n_rot, 2, 0, freq_base, freq_scale - ); - offload_func_kq(qrotated); - struct ggml_tensor * krotated = ggml_rope_custom( - ctx0, krot, KQ_pos, n_rot, 2, 0, freq_base, freq_scale - ); - offload_func_kq(krotated); - // ggml currently only supports concatenation on dim=2 - // so we need to permute qrot, qpass, concat, then permute back. - qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3)); - offload_func_kq(qrotated); - krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3)); - offload_func_kq(krotated); - - qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3)); - offload_func_kq(qpass); - kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3)); - offload_func_kq(kpass); - - struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass); - offload_func_kq(Qcur); - struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass); - offload_func_kq(Kcur); - - struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3)); - offload_func_kq(Q); - - Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3)); - offload_func_kq(Kcur); - { - struct ggml_tensor * tmpv = ggml_view_3d( - ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, - ggml_element_size(tmpqkv_perm) * n_embd_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2 - ); - offload_func_v(tmpv); - // store K, V in cache - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); - offload_func_v(Vcur); - ggml_set_name(Vcur, "Vcur"); - - struct ggml_tensor * k = ggml_view_1d( - ctx0, kv_self.k, n_tokens*n_embd_gqa, - (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head) - ); - offload_func_kq(k); - ggml_set_name(k, "k"); - - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); - offload_func_v(v); - ggml_set_name(v, "v"); - - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } - struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); - - offload_func_kq(K); - ggml_format_name(K, "K_%d", il); - - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - offload_func_kq(KQ); - ggml_set_name(KQ, "KQ"); - - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); - offload_func_kq(KQ_scaled); - ggml_set_name(KQ_scaled, "KQ_scaled"); - - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); - offload_func_kq(KQ_masked); - ggml_set_name(KQ_masked, "KQ_masked"); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); - offload_func_kq(KQ_soft_max); - ggml_set_name(KQ_soft_max, "KQ_soft_max"); - - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v)*n_ctx, - ggml_element_size(kv_self.v)*n_ctx*n_embd_head, - ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); - offload_func_v(V); - ggml_set_name(V, "V"); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - offload_func_v(KQV); - ggml_set_name(KQV, "KQV"); - - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - offload_func_v(KQV_merged); - ggml_set_name(KQV_merged, "KQV_merged"); - - cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); - offload_func_v(cur); - ggml_set_name(cur, "KQV_merged_contiguous"); - - cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); - offload_func(cur); - cur = ggml_add(ctx0, cur, model.layers[il].bo); - offload_func(cur); - ggml_set_name(cur, "result_wo"); - } - - struct ggml_tensor * inpFF = ggml_add(ctx0, residual, cur); - offload_func(inpFF); - ggml_set_name(inpFF, "inpFF"); - { - // MLP - { - // Norm - cur = ggml_norm(ctx0, inpFF, norm_eps); - offload_func(cur); - cur = ggml_add(ctx0, - ggml_mul(ctx0, cur, model.layers[il].ffn_norm), - model.layers[il].ffn_norm_b - ); - ggml_set_name(cur, "ffn_norm"); - offload_func(cur); - } - cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur); - offload_func(cur); - - cur = ggml_add(ctx0, cur, model.layers[il].b3); - offload_func(cur); - ggml_set_name(cur, "result_ffn_up"); - - cur = ggml_sqr(ctx0, ggml_relu(ctx0, cur)); - ggml_set_name(cur, "result_ffn_act"); - offload_func(cur); - offload_func(cur->src[0]); - - cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); - offload_func(cur); - cur = ggml_add(ctx0, - cur, - model.layers[il].b2); - offload_func(cur); - ggml_set_name(cur, "outFF"); - } - cur = ggml_add(ctx0, cur, inpFF); - offload_func(cur); - ggml_set_name(cur, "inpFF_+_outFF"); - inpL = cur; - } - cur = inpL; - { - cur = ggml_norm(ctx0, cur, norm_eps); - offload_func_nr(cur); - cur = ggml_mul(ctx0, cur, model.output_norm); - offload_func_nr(cur); - - cur = ggml_add(ctx0, cur, model.output_norm_b); - // offload_func_nr(cur); - - ggml_set_name(cur, "result_norm"); - } - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); - ggml_build_forward_expand(gf, cur); - ggml_free(ctx0); - return gf; -} - -static struct ggml_cgraph * llm_build_bloom( - llama_context & lctx, - const llama_batch & batch) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - GGML_ASSERT(n_embd_head == hparams.n_rot); - - const float norm_eps = hparams.f_norm_eps; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; - const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, - }; - - params.no_alloc = true; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * token; - struct ggml_tensor * inpL; - - if (batch.token) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inp_tokens); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); - } - ggml_set_name(inp_tokens, "inp_tokens"); - - token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); - } else { -#ifdef GGML_USE_MPI - GGML_ASSERT(false && "not implemented"); -#endif - - token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - - ggml_allocr_alloc(lctx.alloc, token); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(token->data, batch.embd, n_tokens * n_embd * ggml_element_size(token)); - } - } - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); - ggml_allocr_alloc(lctx.alloc, KQ_scale); - if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); - } - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - ggml_set_name(KQ_mask, "KQ_mask"); - ggml_allocr_alloc(lctx.alloc, KQ_mask); - if (!ggml_allocr_is_measure(lctx.alloc)) { - float * data = (float *) KQ_mask->data; - memset(data, 0, ggml_nbytes(KQ_mask)); - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; - - for (int i = 0; i < n_kv; ++i) { - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; - } - } - } - } - } - - // norm - { - inpL = ggml_norm(ctx0, token, norm_eps); - inpL = ggml_add(ctx0, ggml_mul(ctx0, inpL, model.tok_norm), model.tok_norm_b); - } - - ggml_set_name(inpL, "inpL"); - - for (int il = 0; il < n_layer; ++il) { - { - // Norm - cur = ggml_norm(ctx0, inpL, norm_eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b); - } - - { - // Self Attention - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv); - - struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd); - struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd); - struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa)); - - struct ggml_tensor * Qcur = tmpq; - struct ggml_tensor * Kcur = tmpk; - - // store key and value to memory - { - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); - ggml_set_name(Vcur, "Vcur"); - - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); - ggml_set_name(k, "k"); - - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } - - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)), - 0, 2, 1, 3); - ggml_set_name(Q, "Q"); - - struct ggml_tensor * K = - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); - ggml_set_name(K, "K"); - - // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - ggml_set_name(KQ, "KQ"); - - // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); - ggml_set_name(KQ_scaled, "KQ_scaled"); - - struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ kv_head, n_head, 8); - ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); - - // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); - ggml_set_name(KQ_masked, "KQ_masked"); - - // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); - ggml_set_name(KQ_soft_max, "KQ_soft_max"); - - // split cached V into n_head heads - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v)*n_ctx, - ggml_element_size(kv_self.v)*n_ctx*n_embd_head, - ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); - ggml_set_name(V, "V"); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - ggml_set_name(KQV, "KQV"); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - ggml_set_name(KQV_merged, "KQV_merged"); - - // cur = KQV_merged.contiguous().view(n_embd, n_tokens) - cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); - ggml_set_name(cur, "KQV_merged_contiguous"); - } - - // Projection - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo); - - // Add the input - cur = ggml_add(ctx0, cur, inpL); - - struct ggml_tensor * inpFF = cur; - - // FF - { - // Norm - { - cur = ggml_norm(ctx0, inpFF, norm_eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b); - } - - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3); - - // GELU activation - cur = ggml_gelu(ctx0, cur); - - // Projection - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2); - } - - inpL = ggml_add(ctx0, cur, inpFF); - } - - // Output Norm - { - cur = ggml_norm(ctx0, inpL, norm_eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b); - } - ggml_set_name(cur, "result_norm"); - - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llm_build_mpt( - llama_context & lctx, - const llama_batch & batch) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - const float norm_eps = hparams.f_norm_eps; - const float clamp_kqv = hparams.f_clamp_kqv; - const float max_alibi_bias = hparams.f_max_alibi_bias; - - const int n_gpu_layers = model.n_gpu_layers; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; - const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, - }; - - params.no_alloc = true; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - //int warmup = 0; - if (batch.token) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inp_tokens); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); - //warmup = ((uint32_t*) inp_tokens->data)[0] == 0; - } - - ggml_set_name(inp_tokens, "inp_tokens"); - - inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); - } else { -#ifdef GGML_USE_MPI - GGML_ASSERT(false && "not implemented"); -#endif - - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - - ggml_allocr_alloc(lctx.alloc, inpL); - if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); - } - } - - const int i_gpu_start = n_layer - n_gpu_layers; - (void) i_gpu_start; - - // offload functions set the tensor output backend to GPU - // tensors are GPU-accelerated if any input or the output has been offloaded - offload_func_t offload_func_nr = llama_nop; // nr = non-repeating - offload_func_t offload_func_kq = llama_nop; - offload_func_t offload_func_v = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) { - offload_func_nr = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 1) { - offload_func_v = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 2) { - offload_func_kq = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); - ggml_allocr_alloc(lctx.alloc, KQ_scale); - if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); - } - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - offload_func_kq(KQ_mask); - ggml_set_name(KQ_mask, "KQ_mask"); - ggml_allocr_alloc(lctx.alloc, KQ_mask); - if (!ggml_allocr_is_measure(lctx.alloc)) { - float * data = (float *) KQ_mask->data; - memset(data, 0, ggml_nbytes(KQ_mask)); - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j]; - - for (int i = 0; i < n_kv; ++i) { - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; - } - } - } - } - } - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * attn_norm; - - offload_func_t offload_func = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (il >= i_gpu_start) { - offload_func = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - - // self-attention - // TODO: refactor into common function (shared with LLaMA) - { - attn_norm = ggml_norm(ctx0, inpL, norm_eps); - offload_func(attn_norm); - - attn_norm = ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm); - offload_func(attn_norm); - - if (1) { - cur = attn_norm; - } - - // compute QKV - - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); - offload_func_kq(cur); - - if (clamp_kqv > 0.0f) { - cur = ggml_clamp(ctx0, cur, -clamp_kqv, clamp_kqv); - offload_func_kq(cur); - } - - const size_t wsize = ggml_type_size(cur->type); - - struct ggml_tensor * Qcur = ggml_view_3d( - ctx0, cur, n_embd_head, n_head, n_tokens, - wsize * n_embd_head, - wsize * n_embd_head * (n_head + 2 * n_head_kv), - 0); - offload_func_kq(Qcur); - - struct ggml_tensor * Kcur = ggml_view_3d( - ctx0, cur, n_embd_head, n_head_kv, n_tokens, - wsize * n_embd_head, - wsize * n_embd_head * (n_head + 2 * n_head_kv), - wsize * n_embd_head * n_head); - offload_func_kq(Kcur); - - struct ggml_tensor * tmpv = ggml_view_3d( - ctx0, cur, n_embd_head, n_head_kv, n_tokens, - wsize * n_embd_head, - wsize * n_embd_head * (n_head + 2 * n_head_kv), - wsize * n_embd_head * (n_head + n_head_kv)); - offload_func_kq(Kcur); - - ggml_set_name(Qcur, "Qcur"); - ggml_set_name(Kcur, "Kcur"); - - { - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); - offload_func_v(Vcur); - offload_func_v(Vcur->src[0]->src[0]); - ggml_set_name(Vcur, "Vcur"); - - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); - offload_func_kq(k); - ggml_set_name(k, "k"); - - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); - offload_func_v(v); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } - - struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - offload_func_kq(Q); - ggml_set_name(Q, "Q"); - - struct ggml_tensor * K = - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); - offload_func_kq(K); - ggml_set_name(K, "K"); - - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - offload_func_kq(KQ); - ggml_set_name(KQ, "KQ"); - - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); - offload_func_kq(KQ_scaled); - ggml_set_name(KQ_scaled, "KQ_scaled"); - - // TODO: replace with ggml_add() - struct ggml_tensor * KQ_scaled_alibi = - ggml_alibi(ctx0, KQ_scaled, 0, n_head, max_alibi_bias); - offload_func_kq(KQ_scaled_alibi); - ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); - - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); - offload_func_kq(KQ_masked); - ggml_set_name(KQ_masked, "KQ_masked"); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - offload_func_v(KQ_soft_max); - ggml_set_name(KQ_soft_max, "KQ_soft_max"); - - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v)*n_ctx, - ggml_element_size(kv_self.v)*n_ctx*n_embd_head, - ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); - offload_func_v(V); - ggml_set_name(V, "V"); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - offload_func_v(KQV); - ggml_set_name(KQV, "KQV"); - - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - offload_func_v(KQV_merged); - ggml_set_name(KQV_merged, "KQV_merged"); - - cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); - offload_func_v(cur); - ggml_set_name(cur, "KQV_merged_contiguous"); - - cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); - offload_func(cur); - ggml_set_name(cur, "result_wo"); - } - - // Add the input - cur = ggml_add(ctx0, cur, inpL); - offload_func(cur); - - struct ggml_tensor * attn_out = cur; - - // feed forward - { - // Norm - { - cur = ggml_norm(ctx0, attn_out, norm_eps); - offload_func(cur); - - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); - offload_func(cur); - } - - cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur); - offload_func(cur); - - cur = ggml_gelu(ctx0, cur); - offload_func(cur); - cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); - offload_func(cur); - } - - cur = ggml_add(ctx0, cur, attn_out); - offload_func(cur); - // input for next layer - inpL = cur; - } - - cur = inpL; - - // norm - { - cur = ggml_norm(ctx0, cur, norm_eps); - offload_func_nr(cur); - - cur = ggml_mul(ctx0, cur, model.output_norm); - ggml_set_name(cur, "result_norm"); - } - - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llama_build_graph( - llama_context & lctx, - const llama_batch & batch) { - const auto & model = lctx.model; - - struct ggml_cgraph * result = NULL; - - switch (model.arch) { - case LLM_ARCH_LLAMA: - { - result = llm_build_llama(lctx, batch); - } break; - case LLM_ARCH_BAICHUAN: - { - result = llm_build_baichaun(lctx, batch); - } break; - case LLM_ARCH_FALCON: - { - result = llm_build_falcon(lctx, batch); - } break; - case LLM_ARCH_STARCODER: - { - result = llm_build_starcoder(lctx, batch); - } break; - case LLM_ARCH_PERSIMMON: - { - result = llm_build_persimmon(lctx, batch); - } break; - case LLM_ARCH_REFACT: - { - result = llm_build_refact(lctx, batch); - } break; - case LLM_ARCH_BLOOM: - { - result = llm_build_bloom(lctx, batch); - } break; - case LLM_ARCH_MPT: - { - result = llm_build_mpt(lctx, batch); - } break; - default: - GGML_ASSERT(false); - } - - return result; -} - -// decode a batch of tokens by evaluating the transformer -// -// - lctx: llama context -// - batch: batch to evaluate -// -// return 0 on success -// return positive int on warning -// return negative int on error -// -static int llama_decode_internal( - llama_context & lctx, - llama_batch batch) { - const uint32_t n_tokens = batch.n_tokens; - - if (n_tokens == 0) { - LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); - return -1; - } - - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto n_batch = cparams.n_batch; - - GGML_ASSERT(n_tokens <= n_batch); - - int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT - - const int64_t t_start_us = ggml_time_us(); - -#ifdef GGML_USE_MPI - // TODO: needs fix after #3228 - GGML_ASSERT(false && "not implemented"); - //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); -#endif - - GGML_ASSERT(n_threads > 0); - - auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_vocab = hparams.n_vocab; - - // helpers for smoother batch API transistion - // after deprecating the llama_eval calls, these will be removed - std::vector pos; - std::vector seq_id; - - if (batch.pos == nullptr) { - pos.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - pos[i] = batch.all_pos_0 + i*batch.all_pos_1; - } - - batch.pos = pos.data(); - } - - if (batch.seq_id == nullptr) { - seq_id.resize(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - seq_id[i] = batch.all_seq_id; - } - - batch.seq_id = seq_id.data(); - } - - if (!llama_kv_cache_find_slot(kv_self, batch)) { - return 1; - } - - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA? - kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self))); - - //printf("kv_self.n = %d\n", kv_self.n); - - ggml_allocr_reset(lctx.alloc); - - ggml_cgraph * gf = llama_build_graph(lctx, batch); - - ggml_allocr_alloc_graph(lctx.alloc, gf); - -#ifdef GGML_USE_CUBLAS - for (int i = 0; i < gf->n_leafs; i++) { - ggml_tensor * node = gf->leafs[i]; - if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { - ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); - ggml_cuda_copy_to_device(node); - } - } - - for (int i = 0; i < gf->n_nodes; i++) { - ggml_tensor * node = gf->nodes[i]; - if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { - ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); - } - } - - ggml_cuda_set_mul_mat_q(cparams.mul_mat_q); -#endif - - // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); - - // for big prompts, if BLAS is enabled, it is better to use only one thread - // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance - // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well - // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering - // with the BLAS calls. need a better solution - if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { - n_threads = std::min(4, n_threads); - } - - // If all tensors can be run on the GPU then using more than 1 thread is detrimental. - const bool full_offload_supported = model.arch == LLM_ARCH_LLAMA || - model.arch == LLM_ARCH_BAICHUAN || - model.arch == LLM_ARCH_FALCON || - model.arch == LLM_ARCH_REFACT || - model.arch == LLM_ARCH_MPT; - const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3; - if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) { - n_threads = 1; - } - - struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - - GGML_ASSERT(strcmp(res->name, "result_output") == 0); - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); - -#if GGML_USE_MPI - const int64_t n_layer = hparams.n_layer; - ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); -#endif - -#ifdef GGML_USE_METAL - if (lctx.ctx_metal) { - ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); - ggml_metal_graph_compute(lctx.ctx_metal, gf); - } else { - ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); - } -#else - ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); -#endif - -#if GGML_USE_MPI - ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); -#endif - - // update the kv ring buffer - lctx.kv_self.has_shift = false; - lctx.kv_self.head += n_tokens; - // Ensure kv cache head points to a valid index. - if (lctx.kv_self.head >= lctx.kv_self.size) { - lctx.kv_self.head = 0; - } - -#ifdef GGML_PERF - // print timing information per ggml operation (for debugging purposes) - // requires GGML_PERF to be defined - ggml_graph_print(gf); -#endif - - // plot the computation graph in dot format (for debugging purposes) - //if (n_past%100 == 0) { - // ggml_graph_dump_dot(gf, NULL, "llama.dot"); - //} - - // extract logits - { - auto & logits_out = lctx.logits; - - if (batch.logits) { - logits_out.resize(n_vocab * n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - if (batch.logits[i] == 0) { - continue; - } - memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); - } - } else if (lctx.logits_all) { - logits_out.resize(n_vocab * n_tokens); - memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); - } else { - logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); - } - } - - // extract embeddings - if (!lctx.embedding.empty()) { - auto & embedding_out = lctx.embedding; - - embedding_out.resize(n_embd); - memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(n_tokens - 1)), sizeof(float)*n_embd); - } - - // measure the performance only for the single-token evals - if (n_tokens == 1) { - lctx.t_eval_us += ggml_time_us() - t_start_us; - lctx.n_eval++; - } - else if (n_tokens > 1) { - lctx.t_p_eval_us += ggml_time_us() - t_start_us; - lctx.n_p_eval += n_tokens; - } - - // get a more accurate load time, upon first eval - // TODO: fix this - if (!lctx.has_evaluated_once) { - lctx.t_load_us = ggml_time_us() - lctx.t_start_us; - lctx.has_evaluated_once = true; - } - - return 0; -} - -// -// tokenizer -// - -static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { - return vocab.type; -} - -static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { - return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL; -} - -static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { - return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN; -} - -static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { - return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL; -} - -static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { - return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE; -} - -static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) { - return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED; -} - -static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { - GGML_ASSERT(llama_is_byte_token(vocab, id)); - const auto& token_data = vocab.id_to_token.at(id); - switch (llama_vocab_get_type(vocab)) { - case LLAMA_VOCAB_TYPE_SPM: { - auto buf = token_data.text.substr(3, 2); - return strtol(buf.c_str(), NULL, 16); - } - case LLAMA_VOCAB_TYPE_BPE: { - GGML_ASSERT(false); - return unicode_to_bytes_bpe(token_data.text); - } - default: - GGML_ASSERT(false); - } -} - -static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { - switch (llama_vocab_get_type(vocab)) { - case LLAMA_VOCAB_TYPE_SPM: { - char buf[7]; - int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); - GGML_ASSERT(0 <= result && result < 7); - return vocab.token_to_id.at(buf); - } - case LLAMA_VOCAB_TYPE_BPE: { - return vocab.token_to_id.at(bytes_to_unicode_bpe(ch)); - } - default: - GGML_ASSERT(false); - } -} - -static void llama_escape_whitespace(std::string & text) { - replace_all(text, " ", "\xe2\x96\x81"); -} - -static void llama_unescape_whitespace(std::string & word) { - replace_all(word, "\xe2\x96\x81", " "); -} - -struct llm_symbol { - using index = int; - index prev; - index next; - const char * text; - size_t n; -}; - -static_assert(std::is_trivially_copyable::value, "llm_symbol is not trivially copyable"); - -// SPM tokenizer -// original implementation: -// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 - -struct llm_bigram_spm { - struct comparator { - bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) { - return (l.score < r.score) || (l.score == r.score && l.left > r.left); - } - }; - using queue_storage = std::vector; - using queue = std::priority_queue; - llm_symbol::index left; - llm_symbol::index right; - float score; - size_t size; -}; - -struct llm_tokenizer_spm { - llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {} - - void tokenize(const std::string & text, std::vector & output) { - // split string into utf8 chars - int index = 0; - size_t offs = 0; - while (offs < text.size()) { - llm_symbol sym; - size_t len = utf8_len(text[offs]); - sym.text = text.c_str() + offs; - sym.n = std::min(len, text.size() - offs); - offs += sym.n; - sym.prev = index - 1; - sym.next = offs == text.size() ? -1 : index + 1; - index++; - symbols.emplace_back(sym); - } - - // seed the work queue with all possible 2-character tokens. - for (size_t i = 1; i < symbols.size(); ++i) { - try_add_bigram(i - 1, i); - } - - // keep substituting the highest frequency pairs for as long as we can. - while (!work_queue.empty()) { - auto bigram = work_queue.top(); - work_queue.pop(); - - auto & left_sym = symbols[bigram.left]; - auto & right_sym = symbols[bigram.right]; - - // if one of the symbols already got merged, skip it. - if (left_sym.n == 0 || right_sym.n == 0 || - left_sym.n + right_sym.n != bigram.size) { - continue; - } - - // merge the right sym into the left one - left_sym.n += right_sym.n; - right_sym.n = 0; - - //LLAMA_LOG_INFO("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size); - - // remove the right sym from the chain - left_sym.next = right_sym.next; - if (right_sym.next >= 0) { - symbols[right_sym.next].prev = bigram.left; - } - - // find more substitutions - try_add_bigram(left_sym.prev, bigram.left); - try_add_bigram(bigram.left, left_sym.next); - } - - for (int i = 0; i != -1; i = symbols[i].next) { - auto & symbol = symbols[i]; - resegment(symbol, output); - } - } - -private: - void resegment(llm_symbol & symbol, std::vector & output) { - auto text = std::string(symbol.text, symbol.n); - auto token = vocab.token_to_id.find(text); - - // Do we need to support is_unused? - if (token != vocab.token_to_id.end()) { - output.push_back((*token).second); - return; - } - - const auto p = rev_merge.find(text); - - if (p == rev_merge.end()) { - // output any symbols that did not form tokens as bytes. - for (int j = 0; j < (int)symbol.n; ++j) { - llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]); - output.push_back(token_id); - } - return; - } - - resegment(symbols[p->second.first], output); - resegment(symbols[p->second.second], output); - } - - void try_add_bigram(int left, int right) { - if (left == -1 || right == -1) { - return; - } - - const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); - auto token = vocab.token_to_id.find(text); - - if (token == vocab.token_to_id.end()) { - return; - } - - if (static_cast((*token).second) >= vocab.id_to_token.size()) { - return; - } - - const auto & tok_data = vocab.id_to_token[(*token).second]; - - llm_bigram_spm bigram; - bigram.left = left; - bigram.right = right; - bigram.score = tok_data.score; - bigram.size = text.size(); - - work_queue.push(bigram); - - // Do we need to support is_unused? - rev_merge[text] = std::make_pair(left, right); - } - - const llama_vocab & vocab; - - std::vector symbols; - llm_bigram_spm::queue work_queue; - - std::map> rev_merge; -}; - -// BPE tokenizer -// adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License] -// tried to simplify unicode stuff, so most likely does not work 100% correctly! - -// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused - -struct llm_bigram_bpe { - struct comparator { - bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const { - return l.rank > r.rank || (l.rank == r.rank && l.left > r.left); - } - }; - - using queue_storage = std::vector; - using queue = std::priority_queue; - llm_symbol::index left; - llm_symbol::index right; - std::string text; - int rank; - size_t size; -}; - -struct llm_tokenizer_bpe { - llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {} - - void tokenize(const std::string & text, std::vector & output) { - int final_prev_index = -1; - auto word_collection = bpe_gpt2_preprocess(text); - - symbols_final.clear(); - - for (auto & word : word_collection) { - work_queue = llm_bigram_bpe::queue(); - symbols.clear(); - - int index = 0; - size_t offset = 0; - - while (offset < word.size()) { - llm_symbol sym; - size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset])); - sym.text = word.c_str() + offset; - sym.n = 1; - sym.n = char_len; - offset += sym.n; - sym.prev = index - 1; - sym.next = offset == word.size() ? -1 : index + 1; - index++; - symbols.emplace_back(sym); - } - for (size_t i = 1; i < symbols.size(); ++i) { - add_new_bigram(i - 1, i); - } - - // build token(s) - while (!work_queue.empty()) { - auto bigram = work_queue.top(); - work_queue.pop(); - - auto & left_symbol = symbols[bigram.left]; - auto & right_symbol = symbols[bigram.right]; - - if (left_symbol.n == 0 || right_symbol.n == 0) { - continue; - } - std::string left_token = std::string(left_symbol.text, left_symbol.n); - std::string right_token = std::string(right_symbol.text, right_symbol.n); - if (left_token + right_token != bigram.text) { - continue; // Skip this bigram if it's outdated - } - - // merge the right sym into the left one - left_symbol.n += right_symbol.n; - right_symbol.n = 0; - - // remove the right sym from the chain - left_symbol.next = right_symbol.next; - if (right_symbol.next >= 0) { - symbols[right_symbol.next].prev = bigram.left; - } - - add_new_bigram(left_symbol.prev, bigram.left); // left side of current symbol - add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol - } - - // add the fnished tokens to the final list keeping correct order for next and prev - for (auto & sym : symbols) { - if (sym.n > 0) { - sym.prev = final_prev_index; - sym.next = -1; - if (final_prev_index != -1) { - symbols_final[final_prev_index].next = symbols_final.size(); - } - symbols_final.emplace_back(sym); - final_prev_index = symbols_final.size() - 1; - } - } - } - - symbols = symbols_final; - - if (!symbols.empty()) { - for (int i = 0; i != -1; i = symbols[i].next) { - auto & symbol = symbols[i]; - if (symbol.n == 0) { - continue; - } - - const std::string str = std::string(symbol.text, symbol.n); - const auto token = vocab.token_to_id.find(str); - - if (token == vocab.token_to_id.end()) { - for (auto j = str.begin(); j != str.end(); ++j) { - std::string byte_str(1, *j); - auto token_multibyte = vocab.token_to_id.find(byte_str); - if (token_multibyte == vocab.token_to_id.end()) { - throw std::runtime_error("ERROR: byte not found in vocab"); - } - output.push_back((*token_multibyte).second); - } - } else { - output.push_back((*token).second); - } - } - } - } - -private: - void add_new_bigram(int left, int right) { - if (left == -1 || right == -1) { - return; - } - - std::string left_token = std::string(symbols[left].text, symbols[left].n); - std::string right_token = std::string(symbols[right].text, symbols[right].n); - - int rank_found = -1; - - rank_found = vocab.find_bpe_rank(left_token, right_token); - - if (rank_found < 0) { - return; - } - - llm_bigram_bpe bigram; - - bigram.left = left; - bigram.right = right; - bigram.text = left_token + right_token; - bigram.size = left_token.size() + right_token.size(); - bigram.rank = rank_found; - - work_queue.push(bigram); - } - - std::vector bpe_gpt2_preprocess(const std::string & text) { - std::vector bpe_words; - std::vector bpe_encoded_words; - - std::string token = ""; - // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ - bool collecting_numeric = false; - bool collecting_letter = false; - bool collecting_special = false; - bool collecting_whitespace_lookahead = false; - bool collecting = false; - - std::vector text_utf; - text_utf.reserve(text.size()); - bpe_words.reserve(text.size()); - bpe_encoded_words.reserve(text.size()); - - auto cps = codepoints_from_utf8(text); - for (size_t i = 0; i < cps.size(); ++i) - text_utf.emplace_back(codepoint_to_utf8(cps[i])); - - for (int i = 0; i < (int)text_utf.size(); i++) { - const std::string & utf_char = text_utf[i]; - bool split_condition = false; - int bytes_remain = text_utf.size() - i; - // forward backward lookups - const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : ""; - const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : ""; - - // handling contractions - if (!split_condition && bytes_remain >= 2) { - // 's|'t|'m|'d - if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) { - split_condition = true; - } - if (split_condition) { - if (token.size()) { - bpe_words.emplace_back(token); // push previous content as token - } - token = utf_char + utf_char_next; - bpe_words.emplace_back(token); - token = ""; - i++; - continue; - } - } - if (!split_condition && bytes_remain >= 3) { - // 're|'ve|'ll - if (utf_char == "\'" && ( - (utf_char_next == "r" && utf_char_next_next == "e") || - (utf_char_next == "v" && utf_char_next_next == "e") || - (utf_char_next == "l" && utf_char_next_next == "l")) - ) { - split_condition = true; - } - if (split_condition) { - // current token + next token can be defined - if (token.size()) { - bpe_words.emplace_back(token); // push previous content as token - } - token = utf_char + utf_char_next + utf_char_next_next; - bpe_words.emplace_back(token); // the contraction - token = ""; - i += 2; - continue; - } - } - - if (!split_condition && !collecting) { - if (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) { - collecting_letter = true; - collecting = true; - } - else if (codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || (!token.size() && utf_char == " " && codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { - collecting_numeric = true; - collecting = true; - } - else if ( - ((codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) && (codepoint_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) || - (!token.size() && utf_char == " " && codepoint_type(utf_char_next) != CODEPOINT_TYPE_LETTER && codepoint_type(utf_char_next) != CODEPOINT_TYPE_DIGIT && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) - ) { - collecting_special = true; - collecting = true; - } - else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && codepoint_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) { - collecting_whitespace_lookahead = true; - collecting = true; - } - else if (codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) { - split_condition = true; - } - } - else if (!split_condition && collecting) { - if (collecting_letter && codepoint_type(utf_char) != CODEPOINT_TYPE_LETTER) { - split_condition = true; - } - else if (collecting_numeric && codepoint_type(utf_char) != CODEPOINT_TYPE_DIGIT) { - split_condition = true; - } - else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) { - split_condition = true; - } - else if (collecting_whitespace_lookahead && (codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { - split_condition = true; - } - } - - if (utf_char_next == "") { - split_condition = true; // final - token += utf_char; - } - - if (split_condition) { - if (token.size()) { - bpe_words.emplace_back(token); - } - token = utf_char; - collecting = false; - collecting_letter = false; - collecting_numeric = false; - collecting_special = false; - collecting_whitespace_lookahead = false; - } - else { - token += utf_char; - } - } - - for (std::string & word : bpe_words) { - std::string encoded_token = ""; - for (char & c : word) { - encoded_token += bytes_to_unicode_bpe(c); - } - bpe_encoded_words.emplace_back(encoded_token); - } - - return bpe_encoded_words; - } - - const llama_vocab & vocab; - - std::vector symbols; - std::vector symbols_final; - - llm_bigram_bpe::queue work_queue; -}; - -static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) { - std::vector output; - - // OG tokenizer behavior: - // - // tokenizer.encode('', add_bos=True) returns [1] - // tokenizer.encode('', add_bos=False) returns [] - - if (bos && vocab.special_bos_id != -1) { - output.push_back(vocab.special_bos_id); - } - - if (raw_text.empty()) { - return output; - } - - switch (vocab.type) { - case LLAMA_VOCAB_TYPE_SPM: - { - // without adding this leading whitespace, we do not get the same results as the original tokenizer - raw_text = " " + raw_text; - - llm_tokenizer_spm tokenizer(vocab); - llama_escape_whitespace(raw_text); - tokenizer.tokenize(raw_text, output); - } break; - case LLAMA_VOCAB_TYPE_BPE: - { - llm_tokenizer_bpe tokenizer(vocab); - tokenizer.tokenize(raw_text, output); - } break; - } - - return output; -} - -// -// grammar - internal -// - -struct llama_partial_utf8 { - uint32_t value; // bit value so far (unshifted) - int n_remain; // num bytes remaining; -1 indicates invalid sequence -}; - -struct llama_grammar { - const std::vector> rules; - std::vector> stacks; - - // buffer for partially generated UTF-8 sequence from accepted tokens - llama_partial_utf8 partial_utf8; -}; - -struct llama_grammar_candidate { - size_t index; - const uint32_t * code_points; - llama_partial_utf8 partial_utf8; -}; - -// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as -// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. -static std::pair, llama_partial_utf8> decode_utf8( - const char * src, - llama_partial_utf8 partial_start) { - static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; - const char * pos = src; - std::vector code_points; - uint32_t value = partial_start.value; - int n_remain = partial_start.n_remain; - - // continue previous decode, if applicable - while (*pos != 0 && n_remain > 0) { - uint8_t next_byte = static_cast(*pos); - if ((next_byte >> 6) != 2) { - // invalid sequence, abort - code_points.push_back(0); - return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 }); - } - value = (value << 6) + (next_byte & 0x3F); - ++pos; - --n_remain; - } - - if (partial_start.n_remain > 0 && n_remain == 0) { - code_points.push_back(value); - } - - // decode any subsequent utf-8 sequences, which may end in an incomplete one - while (*pos != 0) { - uint8_t first_byte = static_cast(*pos); - uint8_t highbits = first_byte >> 4; - n_remain = lookup[highbits] - 1; - - if (n_remain < 0) { - // invalid sequence, abort - code_points.clear(); - code_points.push_back(0); - return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain }); - } - - uint8_t mask = (1 << (7 - n_remain)) - 1; - value = first_byte & mask; - ++pos; - while (*pos != 0 && n_remain > 0) { - value = (value << 6) + (static_cast(*pos) & 0x3F); - ++pos; - --n_remain; - } - if (n_remain == 0) { - code_points.push_back(value); - } - } - code_points.push_back(0); - - return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); -} - -// returns true iff pos points to the end of one of the definitions of a rule -static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { - switch (pos->type) { - case LLAMA_GRETYPE_END: return true; // NOLINT - case LLAMA_GRETYPE_ALT: return true; // NOLINT - default: return false; - } -} - -// returns true iff chr satisfies the char range at pos (regular or inverse range) -// asserts that pos is pointing to a char range element -static std::pair llama_grammar_match_char( - const llama_grammar_element * pos, - const uint32_t chr) { - - bool found = false; - bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; - - GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT - - do { - if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { - // inclusive range, e.g. [a-z] - found = found || (pos->value <= chr && chr <= pos[1].value); - pos += 2; - } else { - // exact char match, e.g. [a] or "a" - found = found || pos->value == chr; - pos += 1; - } - } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); - - return std::make_pair(found == is_positive_char, pos); -} - -// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char -// range at pos (regular or inverse range) -// asserts that pos is pointing to a char range element -static bool llama_grammar_match_partial_char( - const llama_grammar_element * pos, - const llama_partial_utf8 partial_utf8) { - - bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; - GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); - - uint32_t partial_value = partial_utf8.value; - int n_remain = partial_utf8.n_remain; - - // invalid sequence or 7-bit char split across 2 bytes (overlong) - if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) { - return false; - } - - // range of possible code points this partial UTF-8 sequence could complete to - uint32_t low = partial_value << (n_remain * 6); - uint32_t high = low | ((1 << (n_remain * 6)) - 1); - - if (low == 0) { - if (n_remain == 2) { - low = 1 << 11; - } else if (n_remain == 3) { - low = 1 << 16; - } - } - - do { - if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { - // inclusive range, e.g. [a-z] - if (pos->value <= high && low <= pos[1].value) { - return is_positive_char; - } - pos += 2; - } else { - // exact char match, e.g. [a] or "a" - if (low <= pos->value && pos->value <= high) { - return is_positive_char; - } - pos += 1; - } - } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); - - return !is_positive_char; -} - - -// transforms a grammar pushdown stack into N possible stacks, all ending -// at a character range (terminal element) -static void llama_grammar_advance_stack( - const std::vector> & rules, - const std::vector & stack, - std::vector> & new_stacks) { - - if (stack.empty()) { - new_stacks.emplace_back(stack); - return; - } - - const llama_grammar_element * pos = stack.back(); - - switch (pos->type) { - case LLAMA_GRETYPE_RULE_REF: { - const size_t rule_id = static_cast(pos->value); - const llama_grammar_element * subpos = rules[rule_id].data(); - do { - // init new stack without the top (pos) - std::vector new_stack(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(pos + 1)) { - // if this rule ref is followed by another element, add that to stack - new_stack.push_back(pos + 1); - } - if (!llama_grammar_is_end_of_sequence(subpos)) { - // if alternate is nonempty, add to stack - new_stack.push_back(subpos); - } - llama_grammar_advance_stack(rules, new_stack, new_stacks); - while (!llama_grammar_is_end_of_sequence(subpos)) { - // scan to end of alternate def - subpos++; - } - if (subpos->type == LLAMA_GRETYPE_ALT) { - // there's another alternate def of this rule to process - subpos++; - } else { - break; - } - } while (true); - break; - } - case LLAMA_GRETYPE_CHAR: - case LLAMA_GRETYPE_CHAR_NOT: - new_stacks.emplace_back(stack); - break; - default: - // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range - // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on - // those - GGML_ASSERT(false); - } -} - -// takes a set of possible pushdown stacks on a grammar, which are required to -// be positioned at a character range (see `llama_grammar_advance_stack`), and -// produces the N possible stacks if the given char is accepted at those -// positions -static std::vector> llama_grammar_accept( - const std::vector> & rules, - const std::vector> & stacks, - const uint32_t chr) { - - std::vector> new_stacks; - - for (const auto & stack : stacks) { - if (stack.empty()) { - continue; - } - - auto match = llama_grammar_match_char(stack.back(), chr); - if (match.first) { - const llama_grammar_element * pos = match.second; - - // update top of stack to next element, if any - std::vector new_stack(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(pos)) { - new_stack.push_back(pos); - } - llama_grammar_advance_stack(rules, new_stack, new_stacks); - } - } - - return new_stacks; -} - -static std::vector llama_grammar_reject_candidates( - const std::vector> & rules, - const std::vector> & stacks, - const std::vector & candidates); - -static std::vector llama_grammar_reject_candidates_for_stack( - const std::vector> & rules, - const std::vector & stack, - const std::vector & candidates) { - - std::vector rejects; - - if (stack.empty()) { - for (auto tok : candidates) { - if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { - rejects.push_back(tok); - } - } - return rejects; - } - - const llama_grammar_element * stack_pos = stack.back(); - - std::vector next_candidates; - for (auto tok : candidates) { - if (*tok.code_points == 0) { - // reached end of full codepoints in token, reject iff it ended in a partial sequence - // that cannot satisfy this position in grammar - if (tok.partial_utf8.n_remain != 0 && - !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { - rejects.push_back(tok); - } - } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { - next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); - } else { - rejects.push_back(tok); - } - } - - const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; - - // update top of stack to next element, if any - std::vector stack_after(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { - stack_after.push_back(stack_pos_after); - } - std::vector> next_stacks; - llama_grammar_advance_stack(rules, stack_after, next_stacks); - - auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); - for (auto tok : next_rejects) { - rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); - } - - return rejects; -} - -static std::vector llama_grammar_reject_candidates( - const std::vector> & rules, - const std::vector> & stacks, - const std::vector & candidates) { - GGML_ASSERT(!stacks.empty()); // REVIEW - - if (candidates.empty()) { - return std::vector(); - } - - auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); - - for (size_t i = 1, size = stacks.size(); i < size; ++i) { - rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); - } - return rejects; -} - -// -// grammar - external -// - -struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index) { - const llama_grammar_element * pos; - - // copy rule definitions into vectors - std::vector> vec_rules(n_rules); - for (size_t i = 0; i < n_rules; i++) { - for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { - vec_rules[i].push_back(*pos); - } - vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); - } - - // loop over alternates of start rule to build initial stacks - std::vector> stacks; - pos = rules[start_rule_index]; - do { - std::vector stack; - if (!llama_grammar_is_end_of_sequence(pos)) { - // if alternate is nonempty, add to stack - stack.push_back(pos); - } - llama_grammar_advance_stack(vec_rules, stack, stacks); - while (!llama_grammar_is_end_of_sequence(pos)) { - // scan to end of alternate def - pos++; - } - if (pos->type == LLAMA_GRETYPE_ALT) { - // there's another alternate def of this rule to process - pos++; - } else { - break; - } - } while (true); - - return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; -} - -void llama_grammar_free(struct llama_grammar * grammar) { - delete grammar; -} - -struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { - llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 }; - - // redirect elements in stacks to point to new rules - for (size_t is = 0; is < result->stacks.size(); is++) { - for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { - for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { - for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { - if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { - result->stacks[is][ie] = &result->rules[ir0][ir1]; - } - } - } - } - } - - return result; -} - -// -// sampling -// - -void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = time(NULL); - } - ctx->rng.seed(seed); -} - -void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { - GGML_ASSERT(candidates->size > 0); - - const int64_t t_start_sample_us = ggml_time_us(); - - // Sort the logits in descending order - if (!candidates->sorted) { - std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); - candidates->sorted = true; - } - - float max_l = candidates->data[0].logit; - float cum_sum = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - float p = expf(candidates->data[i].logit - max_l); - candidates->data[i].p = p; - cum_sum += p; - } - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].p /= cum_sum; - } - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep) { - const int64_t t_start_sample_us = ggml_time_us(); - - k = std::max(k, (int) min_keep); - k = std::min(k, (int) candidates->size); - - // Sort scores in descending order - if (!candidates->sorted) { - auto comp = [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }; - if (k == (int) candidates->size) { - std::sort(candidates->data, candidates->data + candidates->size, comp); - } else { - std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); - } - candidates->sorted = true; - } - candidates->size = k; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - if (p >= 1.0f) { - return; - } - - llama_sample_softmax(ctx, candidates); - - const int64_t t_start_sample_us = ggml_time_us(); - - // Compute the cumulative probabilities - float cum_sum = 0.0f; - size_t last_idx = candidates->size; - - for (size_t i = 0; i < candidates->size; ++i) { - cum_sum += candidates->data[i].p; - - // Check if the running sum is at least p or if we have kept at least min_keep tokens - // we set the last index to i+1 to indicate that the current iterate should be included in the set - if (cum_sum >= p && i + 1 >= min_keep) { - last_idx = i + 1; - break; - } - } - - // Resize the output vector to keep only the top-p tokens - candidates->size = last_idx; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { - if (z >= 1.0f || candidates->size <= 2) { - return; - } - - llama_sample_softmax(nullptr, candidates); - const int64_t t_start_sample_us = ggml_time_us(); - - // Compute the first and second derivatives - std::vector first_derivatives(candidates->size - 1); - std::vector second_derivatives(candidates->size - 2); - - for (size_t i = 0; i < first_derivatives.size(); ++i) { - first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p; - } - for (size_t i = 0; i < second_derivatives.size(); ++i) { - second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; - } - - // Calculate absolute value of second derivatives - for (size_t i = 0; i < second_derivatives.size(); ++i) { - second_derivatives[i] = std::abs(second_derivatives[i]); - } - - // Normalize the second derivatives - { - const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); - - if (second_derivatives_sum > 1e-6f) { - for (float & value : second_derivatives) { - value /= second_derivatives_sum; - } - } else { - for (float & value : second_derivatives) { - value = 1.0f / second_derivatives.size(); - } - } - } - - float cum_sum = 0.0f; - size_t last_idx = candidates->size; - for (size_t i = 0; i < second_derivatives.size(); ++i) { - cum_sum += second_derivatives[i]; - - // Check if the running sum is greater than z or if we have kept at least min_keep tokens - if (cum_sum > z && i >= min_keep) { - last_idx = i; - break; - } - } - - // Resize the output vector to keep only the tokens above the tail location - candidates->size = last_idx; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - // Reference implementation: - // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr - if (p >= 1.0f) { - return; - } - - // Compute the softmax of logits and calculate entropy - llama_sample_softmax(nullptr, candidates); - - const int64_t t_start_sample_us = ggml_time_us(); - - float entropy = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - entropy += -candidates->data[i].p * logf(candidates->data[i].p); - } - - // Compute the absolute difference between negative log probability and entropy for each candidate - std::vector shifted_scores; - for (size_t i = 0; i < candidates->size; ++i) { - float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy); - shifted_scores.push_back(shifted_score); - } - - // Sort tokens based on the shifted_scores and their corresponding indices - std::vector indices(candidates->size); - std::iota(indices.begin(), indices.end(), 0); - - std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { - return shifted_scores[a] < shifted_scores[b]; - }); - - // Compute the cumulative probabilities - float cum_sum = 0.0f; - size_t last_idx = indices.size(); - - for (size_t i = 0; i < indices.size(); ++i) { - size_t idx = indices[i]; - cum_sum += candidates->data[idx].p; - - // Check if the running sum is greater than typical or if we have kept at least min_keep tokens - if (cum_sum > p && i >= min_keep - 1) { - last_idx = i + 1; - break; - } - } - - // Resize the output vector to keep only the locally typical tokens - std::vector new_candidates; - for (size_t i = 0; i < last_idx; ++i) { - size_t idx = indices[i]; - new_candidates.push_back(candidates->data[idx]); - } - - // Replace the data in candidates with the new_candidates data - std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); - candidates->size = new_candidates.size(); - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { - const int64_t t_start_sample_us = ggml_time_us(); - - for (size_t i = 0; i < candidates_p->size; ++i) { - candidates_p->data[i].logit /= temp; - } - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { - llama_sample_temp(ctx, candidates_p, temp); -} - -void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) { - if (last_tokens_size == 0 || penalty == 1.0f) { - return; - } - - const int64_t t_start_sample_us = ggml_time_us(); - - for (size_t i = 0; i < candidates->size; ++i) { - const auto * token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id); - if (token_iter == last_tokens + last_tokens_size) { - continue; - } - - // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. - // This is common fix for this problem, which is to multiply by the penalty instead of dividing. - if (candidates->data[i].logit <= 0) { - candidates->data[i].logit *= penalty; - } else { - candidates->data[i].logit /= penalty; - } - } - - candidates->sorted = false; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { - if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { - return; - } - - const int64_t t_start_sample_us = ggml_time_us(); - - // Create a frequency map to count occurrences of each token in last_tokens - std::unordered_map token_count; - for (size_t i = 0; i < last_tokens_size; ++i) { - token_count[last_tokens_p[i]]++; - } - - // Apply frequency and presence penalties to the candidates - for (size_t i = 0; i < candidates->size; ++i) { - auto token_iter = token_count.find(candidates->data[i].id); - if (token_iter == token_count.end()) { - continue; - } - - int count = token_iter->second; - candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence; - } - - candidates->sorted = false; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { - GGML_ASSERT(ctx); - const int64_t t_start_sample_us = ggml_time_us(); - - bool allow_eos = false; - for (const auto & stack : grammar->stacks) { - if (stack.empty()) { - allow_eos = true; - break; - } - } - - const llama_token eos = llama_token_eos(ctx); - - std::vector, llama_partial_utf8>> candidates_decoded; - std::vector candidates_grammar; - - for (size_t i = 0; i < candidates->size; ++i) { - const llama_token id = candidates->data[i].id; - const std::string piece = llama_token_to_str(ctx, id); - if (id == eos) { - if (!allow_eos) { - candidates->data[i].logit = -INFINITY; - } - } else if (piece.empty() || piece[0] == 0) { - candidates->data[i].logit = -INFINITY; - } else { - candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8)); - candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); - } - } - - const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); - for (const auto & reject : rejects) { - candidates->data[reject.index].logit = -INFINITY; - } - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; -} - -static void llama_log_softmax(float * array, size_t size) { - float max_l = *std::max_element(array, array + size); - float sum = 0.f; - for (size_t i = 0; i < size; ++i) { - float p = expf(array[i] - max_l); - sum += p; - array[i] = p; - } - - for (size_t i = 0; i < size; ++i) { - array[i] = logf(array[i] / sum); - } -} - -void llama_sample_classifier_free_guidance( - struct llama_context * ctx, - llama_token_data_array * candidates, - struct llama_context * guidance_ctx, - float scale) { - int64_t t_start_sample_us = ggml_time_us(); - - GGML_ASSERT(ctx); - - auto n_vocab = llama_n_vocab(llama_get_model(ctx)); - - GGML_ASSERT(n_vocab == (int)candidates->size); - GGML_ASSERT(!candidates->sorted); - - std::vector logits_base; - logits_base.reserve(candidates->size); - for (size_t i = 0; i < candidates->size; ++i) { - logits_base.push_back(candidates->data[i].logit); - } - llama_log_softmax(logits_base.data(), candidates->size); - - float* logits_guidance = llama_get_logits(guidance_ctx); - llama_log_softmax(logits_guidance, n_vocab); - - for (int i = 0; i < n_vocab; ++i) { - float logit_guidance = logits_guidance[i]; - float logit_base = logits_base[i]; - candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance; - } - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { - GGML_ASSERT(ctx); - - auto N = float(llama_n_vocab(llama_get_model(ctx))); - int64_t t_start_sample_us; - t_start_sample_us = ggml_time_us(); - - llama_sample_softmax(nullptr, candidates); - - // Estimate s_hat using the most probable m tokens - float s_hat = 0.0; - float sum_ti_bi = 0.0; - float sum_ti_sq = 0.0; - for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { - float t_i = logf(float(i + 2) / float(i + 1)); - float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); - sum_ti_bi += t_i * b_i; - sum_ti_sq += t_i * t_i; - } - s_hat = sum_ti_bi / sum_ti_sq; - - // Compute k from the estimated s_hat and target surprise value - float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat); - - // Sample the next word X using top-k sampling - llama_sample_top_k(nullptr, candidates, int(k), 1); - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } - llama_token X = llama_sample_token(ctx, candidates); - t_start_sample_us = ggml_time_us(); - - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - float observed_surprise = -log2f(candidates->data[X_idx].p); - float e = observed_surprise - tau; - - // Update mu using the learning rate and error - *mu = *mu - eta * e; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } - return X; -} - -llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { - int64_t t_start_sample_us; - t_start_sample_us = ggml_time_us(); - - llama_sample_softmax(ctx, candidates); - - // Truncate the words with surprise values greater than mu - candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return -log2f(candidate.p) > *mu; - })); - - if (candidates->size == 0) { - candidates->size = 1; - } - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } - - // Normalize the probabilities of the remaining words - llama_sample_softmax(ctx, candidates); - - // Sample the next word X from the remaining words - llama_token X = llama_sample_token(ctx, candidates); - t_start_sample_us = ggml_time_us(); - - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - float observed_surprise = -log2f(candidates->data[X_idx].p); - float e = observed_surprise - tau; - - // Update mu using the learning rate and error - *mu = *mu - eta * e; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } - return X; -} - -llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { - const int64_t t_start_sample_us = ggml_time_us(); - - // Find max element - auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit < b.logit; - }); - - llama_token result = max_iter->id; - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - ctx->n_sample++; - } - return result; -} - -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { - GGML_ASSERT(ctx); - - const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax(nullptr, candidates); - - std::vector probs; - probs.reserve(candidates->size); - for (size_t i = 0; i < candidates->size; ++i) { - probs.push_back(candidates->data[i].p); - } - - std::discrete_distribution<> dist(probs.begin(), probs.end()); - auto & rng = ctx->rng; - int idx = dist(rng); - - llama_token result = candidates->data[idx].id; - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - ctx->n_sample++; - return result; -} - -void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { - const int64_t t_start_sample_us = ggml_time_us(); - - if (token == llama_token_eos(ctx)) { - for (const auto & stack : grammar->stacks) { - if (stack.empty()) { - return; - } - } - GGML_ASSERT(false); - } - - const std::string piece = llama_token_to_str(ctx, token); - - // Note terminating 0 in decoded string - const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8); - const auto & code_points = decoded.first; - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); - } - grammar->partial_utf8 = decoded.second; - GGML_ASSERT(!grammar->stacks.empty()); - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; -} - -// -// Beam search -// - -struct llama_beam { - std::vector tokens; - float p; // Cumulative beam probability (renormalized relative to all beams) - bool eob; // Initialize end-of-beam to false. Callback sets this to true. - // Sort beams by probability. In case of ties, prefer beams at eob. - bool operator<(const llama_beam & rhs) const { - return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob); - } - // Shift off first n tokens and discard them. - void shift_tokens(const size_t n) { - if (n) { - std::copy(tokens.begin() + n, tokens.end(), tokens.begin()); - tokens.resize(tokens.size() - n); - } - } - llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; } -}; - -// A struct for calculating logit-related info. -struct llama_logit_info { - const float * const logits; - const int n_vocab; - const float max_l; - const float normalizer; - struct sum_exp { - float max_l; - float operator()(float sum, float l) const { return sum + std::exp(l - max_l); } - }; - llama_logit_info(llama_context * ctx) - : logits(llama_get_logits(ctx)) - , n_vocab(llama_n_vocab(llama_get_model(ctx))) - , max_l(*std::max_element(logits, logits + n_vocab)) - , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l})) - { } - llama_token_data get_token_data(const llama_token token_id) const { - constexpr auto p = std::numeric_limits::quiet_NaN(); // never used - return {token_id, logits[token_id], p}; - } - // Return top k token_data by logit. - std::vector top_k(size_t k) { - std::vector min_heap; // min-heap by logit - const llama_token k_min = std::min(static_cast(k), n_vocab); - min_heap.reserve(k_min); - for (llama_token token_id = 0 ; token_id < k_min ; ++token_id) { - min_heap.push_back(get_token_data(token_id)); - } - auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; - std::make_heap(min_heap.begin(), min_heap.end(), comp); - for (llama_token token_id = k_min ; token_id < n_vocab ; ++token_id) { - if (min_heap.front().logit < logits[token_id]) { - std::pop_heap(min_heap.begin(), min_heap.end(), comp); - min_heap.back().id = token_id; - min_heap.back().logit = logits[token_id]; - std::push_heap(min_heap.begin(), min_heap.end(), comp); - } - } - return min_heap; - } - float probability_from_logit(float logit) const { - return normalizer * std::exp(logit - max_l); - } -}; - -struct llama_beam_search_data { - llama_context * ctx; - size_t n_beams; - int n_past; - int n_predict; - std::vector beams; - std::vector next_beams; - - // Re-calculated on each loop iteration - size_t common_prefix_length; - - // Used to communicate to/from callback on beams state. - std::vector beam_views; - - llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict) - : ctx(ctx) - , n_beams(n_beams) - , n_past(n_past) - , n_predict(n_predict) - , beam_views(n_beams) { - beams.reserve(n_beams); - next_beams.reserve(n_beams); - } - - // Collapse beams to a single beam given by index. - void collapse_beams(const size_t beam_idx) { - if (0u < beam_idx) { - std::swap(beams[0], beams[beam_idx]); - } - beams.resize(1); - } - - // Min-heaps are used to efficiently collect the top-k elements (k=n_beams). - // The repetative patterns below reflect the 2 stages of heaps: - // * Gather elements until the vector is full, then call std::make_heap() on it. - // * If the heap is full and a new element is found that should be included, pop the - // least element to the back(), replace it with the new, then push it into the heap. - void fill_next_beams_by_top_probabilities(llama_beam & beam) { - // Min-heaps use a greater-than comparator. - const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; }; - if (beam.eob) { - // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough. - if (next_beams.size() < n_beams) { - next_beams.push_back(std::move(beam)); - if (next_beams.size() == n_beams) { - std::make_heap(next_beams.begin(), next_beams.end(), comp); - } - } else if (next_beams.front().p < beam.p) { - std::pop_heap(next_beams.begin(), next_beams.end(), comp); - next_beams.back() = std::move(beam); - std::push_heap(next_beams.begin(), next_beams.end(), comp); - } - } else { - // beam is not at end-of-sentence, so branch with next top_k tokens. - if (!beam.tokens.empty()) { - llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0)); - } - llama_logit_info logit_info(ctx); - std::vector next_tokens = logit_info.top_k(n_beams); - size_t i=0; - if (next_beams.size() < n_beams) { - for (; next_beams.size() < n_beams ; ++i) { - llama_beam next_beam = beam; - next_beam.tokens.push_back(next_tokens[i].id); - next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit); - next_beams.push_back(std::move(next_beam)); - } - std::make_heap(next_beams.begin(), next_beams.end(), comp); - } else { - for (; next_beams.front().p == 0.0f ; ++i) { - std::pop_heap(next_beams.begin(), next_beams.end(), comp); - next_beams.back() = beam; - next_beams.back().tokens.push_back(next_tokens[i].id); - next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit); - std::push_heap(next_beams.begin(), next_beams.end(), comp); - } - } - for (; i < n_beams ; ++i) { - const float next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit); - if (next_beams.front().p < next_p) { - std::pop_heap(next_beams.begin(), next_beams.end(), comp); - next_beams.back() = beam; - next_beams.back().tokens.push_back(next_tokens[i].id); - next_beams.back().p = next_p; - std::push_heap(next_beams.begin(), next_beams.end(), comp); - } - } - } - } - - // Find common_prefix_length based on beams. - // Requires beams is not empty. - size_t find_common_prefix_length() { - size_t common_prefix_length = beams[0].tokens.size(); - for (size_t i = 1 ; i < beams.size() ; ++i) { - common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size()); - for (size_t j = 0 ; j < common_prefix_length ; ++j) { - if (beams[0].tokens[j] != beams[i].tokens[j]) { - common_prefix_length = j; - break; - } - } - } - return common_prefix_length; - } - - // Construct beams_state to send back to caller via the callback function. - // Side effect: set common_prefix_length = find_common_prefix_length(); - llama_beams_state get_beams_state(const bool last_call) { - for (size_t i = 0 ; i < beams.size() ; ++i) { - beam_views[i] = beams[i].view(); - } - common_prefix_length = find_common_prefix_length(); - return {beam_views.data(), beams.size(), common_prefix_length, last_call}; - } - - // Loop: - // * while i < n_predict, AND - // * any of the beams have not yet reached end-of-beam (eob), AND - // * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence - // (since all other beam probabilities can only decrease) - void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) { - beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eob. - const auto not_eob = [](const llama_beam & beam) { return !beam.eob; }; - for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) && - !beams[top_beam_index()].eob ; ++i) { - callback(callback_data, get_beams_state(false)); // Sets common_prefix_length - update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed. - if (common_prefix_length) { - llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0)); - n_past += common_prefix_length; - } - // Zero-out next_beam probabilities to place them last in following min-heap. - std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam & beam) { beam.p = 0.0f; }); - for (llama_beam & beam : beams) { - beam.shift_tokens(common_prefix_length); - fill_next_beams_by_top_probabilities(beam); - } - // next_beams become the beams of next/final iteration. Swap them to re-use memory. - beams.swap(next_beams); - renormalize_beam_probabilities(beams); - } - collapse_beams(top_beam_index()); - callback(callback_data, get_beams_state(true)); - } - - // As beams grow, the cumulative probabilities decrease. - // Renormalize them to avoid floating point underflow. - static void renormalize_beam_probabilities(std::vector & beams) { - const auto sum_p = [](float sum, llama_beam & beam) { return sum + beam.p; }; - const float inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p); - std::for_each(beams.begin(), beams.end(), [=](llama_beam & beam) { beam.p *= inv_sum; }); - } - - // Assumes beams is non-empty. Uses llama_beam::operator<() for ordering. - size_t top_beam_index() { - return std::max_element(beams.begin(), beams.end()) - beams.begin(); - } - - // Copy (p,eob) for each beam which may have been changed by the callback. - void update_beams_from_beam_views() { - for (size_t i = 0 ; i < beams.size() ; ++i) { - beams[i].p = beam_views[i].p; - beams[i].eob = beam_views[i].eob; - } - } -}; - -void llama_beam_search(llama_context * ctx, - llama_beam_search_callback_fn_t callback, void * callback_data, - size_t n_beams, int n_past, int n_predict) { - assert(ctx); - const int64_t t_start_sample_us = ggml_time_us(); - - llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict); - - beam_search_data.loop(callback, callback_data); - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - ctx->n_sample++; -} - -// -// quantization -// - -template -struct no_init { - T value; - no_init() { /* do nothing */ } -}; - -static void llama_convert_tensor_internal( - struct ggml_tensor * tensor, std::vector> & output, std::vector & workers, - const size_t nelements, const int nthread -) { - if (output.size() < nelements) { - output.resize(nelements); - } - float * f32_output = (float *) output.data(); - - ggml_type_traits_t qtype; - if (ggml_is_quantized(tensor->type)) { - qtype = ggml_internal_get_type_traits(tensor->type); - if (qtype.to_float == NULL) { - throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type))); - } - } else if (tensor->type != GGML_TYPE_F16) { - throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type))); - } - - if (nthread < 2) { - if (tensor->type == GGML_TYPE_F16) { - ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements); - } else if (ggml_is_quantized(tensor->type)) { - qtype.to_float(tensor->data, f32_output, nelements); - } else { - GGML_ASSERT(false); // unreachable - } - return; - } - - auto block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type); - auto block_size_bytes = ggml_type_size(tensor->type); - - GGML_ASSERT(nelements % block_size == 0); - auto nblocks = nelements / block_size; - auto blocks_per_thread = nblocks / nthread; - auto spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count - - for (auto tnum = 0, in_buff_offs = 0, out_buff_offs = 0; tnum < nthread; tnum++) { - auto thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread - auto thr_elems = thr_blocks * block_size; // number of elements for this thread - auto thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread - - auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) { - if (typ == GGML_TYPE_F16) { - ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels); - } else { - qtype.to_float(inbuf, outbuf, nels); - } - }; - workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems); - in_buff_offs += thr_block_bytes; - out_buff_offs += thr_elems; - } - for (auto & w : workers) { w.join(); } - workers.clear(); -} - -#ifdef GGML_USE_K_QUANTS -static ggml_type get_k_quant_type( - ggml_type new_type, const ggml_tensor * tensor, const llama_model & model, llama_ftype ftype, int * i_attention_wv, - int n_attention_wv, int * i_feed_forward_w2, int n_feed_forward_w2 -) { - const std::string name = ggml_get_name(tensor); - // TODO: avoid hardcoded tensor names - use the TN_* constants - const auto tn = LLM_TN(model.arch); - - auto use_more_bits = [](int i_layer, int num_layers) -> bool { - return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2; - }; - - if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { - int nx = tensor->ne[0]; - if (model.arch == LLM_ARCH_FALCON || nx % QK_K != 0) { - new_type = GGML_TYPE_Q8_0; - } - else if (new_type != GGML_TYPE_Q8_0) { - new_type = GGML_TYPE_Q6_K; - } - } else if (name.find("attn_v.weight") != std::string::npos) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { - new_type = *i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; - } - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; - else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && - use_more_bits(*i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && *i_attention_wv < 4) new_type = GGML_TYPE_Q5_K; - else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) && - (*i_attention_wv < n_attention_wv/8 || *i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K; - if (model.type == MODEL_70B) { - // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is - // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with - // nearly negligible increase in model size by quantizing this tensor with more bits: - if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K; - } - ++*i_attention_wv; - } else if (name.find("ffn_down.weight") != std::string::npos) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { - new_type = *i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K - : model.arch != LLM_ARCH_FALCON || use_more_bits(*i_feed_forward_w2, n_feed_forward_w2) ? GGML_TYPE_Q4_K - : GGML_TYPE_Q3_K; - } - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) { - new_type = model.arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K; - } - else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) { - if (model.arch == LLM_ARCH_FALCON) { - new_type = *i_feed_forward_w2 < 2 ? GGML_TYPE_Q6_K : - use_more_bits(*i_feed_forward_w2, n_feed_forward_w2) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; - } else { - if (use_more_bits(*i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; - } - } - else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(*i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && model.arch != LLM_ARCH_FALCON && *i_feed_forward_w2 < 4) { - new_type = GGML_TYPE_Q5_K; - } - ++*i_feed_forward_w2; - } else if (name.find("attn_output.weight") != std::string::npos) { - if (model.arch != LLM_ARCH_FALCON) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) new_type = GGML_TYPE_Q4_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; - } else { - if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; - } - } - else if (name.find("attn_qkv.weight") != std::string::npos) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; - else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K; - } - else if (name.find("ffn_gate.weight") != std::string::npos || name.find("ffn_up.weight") != std::string::npos) { - if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - } - // This can be used to reduce the size of the Q5_K_S model. - // The associated PPL increase is fully in line with the size reduction - //else { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; - //} - bool convert_incompatible_tensor = false; - if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || - new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) { - int nx = tensor->ne[0]; - int ny = tensor->ne[1]; - if (nx % QK_K != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for k-quants\n", __func__, nx, ny, QK_K); - convert_incompatible_tensor = true; - } - } - if (convert_incompatible_tensor) { - if (name == tn(LLM_TENSOR_OUTPUT, "weight")) { - new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing. - LLAMA_LOG_WARN("F16 will be used for this tensor instead.\n"); - } else if (name == tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { - new_type = GGML_TYPE_Q4_0; //fall back to Q4_0 instead of just failing. - LLAMA_LOG_WARN("Q4_0 will be used for this tensor instead.\n"); - } else { - throw std::runtime_error("Unsupported tensor size encountered\n"); - } - } - - return new_type; -} -#endif - -static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { - ggml_type quantized_type; - llama_ftype ftype = params->ftype; - - switch (params->ftype) { - case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break; - case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break; - case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; - case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; - case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; - case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break; - case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break; - -#ifdef GGML_USE_K_QUANTS - // K-quants - case LLAMA_FTYPE_MOSTLY_Q2_K: quantized_type = GGML_TYPE_Q2_K; break; - case LLAMA_FTYPE_MOSTLY_Q3_K_S: - case LLAMA_FTYPE_MOSTLY_Q3_K_M: - case LLAMA_FTYPE_MOSTLY_Q3_K_L: quantized_type = GGML_TYPE_Q3_K; break; - case LLAMA_FTYPE_MOSTLY_Q4_K_S: - case LLAMA_FTYPE_MOSTLY_Q4_K_M: quantized_type = GGML_TYPE_Q4_K; break; - case LLAMA_FTYPE_MOSTLY_Q5_K_S: - case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break; - case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break; -#endif - default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); - } - - int nthread = params->nthread; - - if (nthread <= 0) { - nthread = std::thread::hardware_concurrency(); - } - - // mmap consistently increases speed Linux, and also increases speed on Windows with - // hot cache. It may cause a slowdown on macOS, possibly related to free memory. -#if defined(__linux__) || defined(_WIN32) - constexpr bool use_mmap = true; -#else - constexpr bool use_mmap = false; -#endif - - llama_model_loader ml(fname_inp, use_mmap); - if (ml.use_mmap) { - ml.mapping.reset(new llama_mmap(&ml.file, /* prefetch */ 0, ggml_is_numa())); - } - - llama_model model; - llm_load_arch(ml, model); - llm_load_hparams(ml, model); - - if (params->only_copy) { - ftype = model.ftype; - } - - const size_t align = GGUF_DEFAULT_ALIGNMENT; - struct gguf_context * ctx_out = gguf_init_empty(); - - // copy the KV pairs from the input file - gguf_set_kv (ctx_out, ml.ctx_gguf); - gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); - gguf_set_val_u32(ctx_out, "general.file_type", ftype); - -#ifdef GGML_USE_K_QUANTS - int n_attention_wv = 0; - int n_feed_forward_w2 = 0; - - for (int i = 0; i < ml.n_tensors; ++i) { - struct ggml_tensor * meta = ml.get_tensor_meta(i); - - const std::string name = ggml_get_name(meta); - - // TODO: avoid hardcoded tensor names - use the TN_* constants - if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) { - ++n_attention_wv; - } - else if (name.find("ffn_down.weight") != std::string::npos) { - ++n_feed_forward_w2; - } - } - if (n_attention_wv != n_feed_forward_w2 || (uint32_t)n_attention_wv != model.hparams.n_layer) { - LLAMA_LOG_WARN("%s ============ Strange model: n_attention_wv = %d, n_feed_forward_w2 = %d, hparams.n_layer = %d\n", - __func__, n_attention_wv, n_feed_forward_w2, model.hparams.n_layer); - } - - int i_attention_wv = 0; - int i_feed_forward_w2 = 0; -#endif - - size_t total_size_org = 0; - size_t total_size_new = 0; - std::vector hist_all(1 << 4, 0); - - std::vector workers; - workers.reserve(nthread); - std::mutex mutex; - - int idx = 0; - - std::vector> read_data; - std::vector> work; - std::vector> f32_conv_buf; - - // populate the original tensors so we get an initial meta data - for (int i = 0; i < ml.n_tensors; ++i) { - struct ggml_tensor * meta = ml.get_tensor_meta(i); - gguf_add_tensor(ctx_out, meta); - } - - std::ofstream fout(fname_out, std::ios::binary); - fout.exceptions(std::ofstream::failbit); // fail fast on write errors - - const size_t meta_size = gguf_get_meta_size(ctx_out); - - LLAMA_LOG_INFO("%s: meta size = %zu bytes\n", __func__, meta_size); - - // placeholder for the meta data - ::zeros(fout, meta_size); - - for (int i = 0; i < ml.n_tensors; ++i) { - struct ggml_tensor * tensor = ml.get_tensor_meta(i); - - const std::string name = ggml_get_name(tensor); - - if (!ml.use_mmap) { - if (read_data.size() < ggml_nbytes(tensor)) { - read_data.resize(ggml_nbytes(tensor)); - } - tensor->data = read_data.data(); - } - ml.load_data_for(tensor); - - LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", - ++idx, ml.n_tensors, - ggml_get_name(tensor), - llama_format_tensor_shape(tensor).c_str(), - ggml_type_name(tensor->type)); - - // This used to be a regex, but has an extreme cost to compile times. - bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? - - // quantize only 2D tensors - quantize &= (tensor->n_dims == 2); - quantize &= params->quantize_output_tensor || name != "output.weight"; - quantize &= !params->only_copy; - - enum ggml_type new_type; - void * new_data; - size_t new_size; - - if (quantize) { - new_type = quantized_type; -#ifdef GGML_USE_K_QUANTS - new_type = get_k_quant_type( - new_type, tensor, model, ftype, &i_attention_wv, n_attention_wv, &i_feed_forward_w2, n_feed_forward_w2 - ); -#endif - // If we've decided to quantize to the same type the tensor is already - // in then there's nothing to do. - quantize = tensor->type != new_type; - } - if (!quantize) { - new_type = tensor->type; - new_data = tensor->data; - new_size = ggml_nbytes(tensor); - LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0); - } else { - const size_t nelements = ggml_nelements(tensor); - - float * f32_data; - - if (tensor->type == GGML_TYPE_F32) { - f32_data = (float *) tensor->data; - } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { - throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); - } else { - llama_convert_tensor_internal(tensor, f32_conv_buf, workers, nelements, nthread); - f32_data = (float *) f32_conv_buf.data(); - } - - LLAMA_LOG_INFO("quantizing to %s .. ", ggml_type_name(new_type)); - fflush(stdout); - - if (work.size() < nelements * 4) { - work.resize(nelements * 4); // upper bound on size - } - new_data = work.data(); - std::array hist_cur = {}; - - static const int chunk_size = 32 * 512; - const int nchunk = (nelements + chunk_size - 1)/chunk_size; - const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1; - if (nthread_use < 2) { - new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data()); - } else { - size_t counter = 0; - new_size = 0; - auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, nelements]() { - std::array local_hist = {}; - size_t local_size = 0; - while (true) { - std::unique_lock lock(mutex); - size_t first = counter; counter += chunk_size; - if (first >= nelements) { - if (local_size > 0) { - for (int j=0; j %8.2f MB | hist: ", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); - int64_t tot_count = 0; - for (size_t i = 0; i < hist_cur.size(); i++) { - hist_all[i] += hist_cur[i]; - tot_count += hist_cur[i]; - } - - if (tot_count > 0) { - for (size_t i = 0; i < hist_cur.size(); i++) { - LLAMA_LOG_INFO("%5.3f ", hist_cur[i] / float(nelements)); - } - } - LLAMA_LOG_INFO("\n"); - } - total_size_org += ggml_nbytes(tensor); - total_size_new += new_size; - - // update the gguf meta data as we go - gguf_set_tensor_type(ctx_out, name.c_str(), new_type); - gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size); - - // write tensor data + padding - fout.write((const char *) new_data, new_size); - zeros(fout, GGML_PAD(new_size, align) - new_size); - } - - // go back to beginning of file and write the updated meta data - { - fout.seekp(0); - std::vector data(gguf_get_meta_size(ctx_out)); - gguf_get_meta_data(ctx_out, data.data()); - fout.write((const char *) data.data(), data.size()); - } - - fout.close(); - - gguf_free(ctx_out); - - LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); - LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); - - // print histogram for all tensors - { - int64_t sum_all = 0; - for (size_t i = 0; i < hist_all.size(); i++) { - sum_all += hist_all[i]; - } - - if (sum_all > 0) { - LLAMA_LOG_INFO("%s: hist: ", __func__); - for (size_t i = 0; i < hist_all.size(); i++) { - LLAMA_LOG_INFO("%5.3f ", hist_all[i] / float(sum_all)); - } - LLAMA_LOG_INFO("\n"); - } - } -} - -static int llama_apply_lora_from_file_internal( - const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads -) { - LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); - - const int64_t t_start_lora_us = ggml_time_us(); - - auto fin = std::ifstream(path_lora, std::ios::binary); - if (!fin) { - LLAMA_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_lora); - return 1; - } - - // verify magic and version - { - uint32_t magic; - fin.read((char *) &magic, sizeof(magic)); - uint32_t format_version; - fin.read((char *) &format_version, sizeof(format_version)); - - if (format_version != 1) { - LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ ); - return 1; - } - } - - int32_t lora_r; - int32_t lora_alpha; - fin.read((char *) &lora_r, sizeof(lora_r)); - fin.read((char *) &lora_alpha, sizeof(lora_alpha)); - float scaling = scale * (float)lora_alpha / (float)lora_r; - - LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling); - - // create a temporary ggml context to store the lora tensors - // todo: calculate size from biggest possible tensor - std::vector lora_buf(1024ull * 1024ull * 1024ull); - struct ggml_init_params params; - params.mem_size = lora_buf.size(); - params.mem_buffer = lora_buf.data(); - params.no_alloc = false; - - ggml_context * lora_ctx = ggml_init(params); - std::unordered_map lora_tensors; - - // create a name -> tensor map of the model to accelerate lookups - std::unordered_map model_tensors; - for (const auto & kv : model.tensors_by_name) { - model_tensors.insert(kv); - } - - // load base model - std::unique_ptr ml; - ggml_context * base_ctx = NULL; - std::vector base_buf; - if (path_base_model) { - LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); - ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true)); - - size_t ctx_size; - size_t mmapped_size; - ml->calc_sizes(ctx_size, mmapped_size); - base_buf.resize(ctx_size); - - ggml_init_params base_params; - base_params.mem_size = base_buf.size(); - base_params.mem_buffer = base_buf.data(); - base_params.no_alloc = ml->use_mmap; - - base_ctx = ggml_init(base_params); - - // maybe this should in llama_model_loader - if (ml->use_mmap) { - ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, ggml_is_numa())); - } - } - - // read tensors and apply - bool warned = false; - int n_tensors = 0; - - std::vector work_buffer; - - while (true) { - int32_t n_dims; - int32_t length; - int32_t ftype; - - fin.read(reinterpret_cast(&n_dims), sizeof(n_dims)); - fin.read(reinterpret_cast(&length), sizeof(length)); - fin.read(reinterpret_cast(&ftype), sizeof(ftype)); - if (fin.eof()) { - break; - } - - int32_t ne[2] = { 1, 1 }; - for (int i = 0; i < n_dims; ++i) { - fin.read(reinterpret_cast(&ne[i]), sizeof(ne[i])); - } - - std::string name; - { - char buf[1024]; - fin.read(buf, length); - name = std::string(buf, length); - } - - // check for lora suffix and get the type of tensor - const std::string lora_suffix = ".lora"; - size_t pos = name.rfind(lora_suffix); - if (pos == std::string::npos) { - LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str()); - return 1; - } - - std::string lora_type = name.substr(pos + lora_suffix.length()); - std::string base_name = name; - base_name.erase(pos); - // LLAMA_LOG_INFO("%s: %s => %s (lora type %s) \n", __func__, name.c_str(),base_name.c_str(), lora_type.c_str()); - - if (model_tensors.find(base_name) == model_tensors.end()) { - LLAMA_LOG_ERROR("%s: unknown tensor '%s' in lora adapter\n", __func__, name.data()); - return 1; - } - - // create ggml tensor - ggml_type wtype; - switch (ftype) { - case 0: wtype = GGML_TYPE_F32; break; - case 1: wtype = GGML_TYPE_F16; break; - default: - { - LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n", - __func__, ftype); - return false; - } - } - ggml_tensor * lora_tensor; - if (n_dims == 2) { - lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]); - } - else { - LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims); - return 1; - } - ggml_set_name(lora_tensor, "lora_tensor"); - - // load tensor data - size_t offset = fin.tellg(); - size_t tensor_data_size = ggml_nbytes(lora_tensor); - offset = (offset + 31) & -32; - fin.seekg(offset); - fin.read((char*)lora_tensor->data, tensor_data_size); - - lora_tensors[name] = lora_tensor; - - // check if we have both A and B tensors and apply - if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() && - lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) { - - ggml_tensor * dest_t = model_tensors[base_name]; - - offload_func_t offload_func = llama_nop; - offload_func_t offload_func_force_inplace = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) { - if (dest_t->type != GGML_TYPE_F16) { - throw std::runtime_error(format( - "%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models", __func__)); - } - offload_func = ggml_cuda_assign_buffers; - offload_func_force_inplace = ggml_cuda_assign_buffers_force_inplace; - } -#endif // GGML_USE_CUBLAS - - ggml_tensor * base_t; - if (ml) { - struct gguf_context * ctx_gguf = ml->ctx_gguf; - - // load from base model - if (gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) { - // TODO: throw - LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str()); - return 1; - } - - // TODO: not tested!! maybe not working! - base_t = ml->create_tensor(base_ctx, base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU); - ml->load_data_for(base_t); - } else { - base_t = dest_t; - } - - if (ggml_is_quantized(base_t->type)) { - if (!warned) { - LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, " - "use a f16 or f32 base model with --lora-base\n", __func__); - warned = true; - } - } - - ggml_tensor * loraA = lora_tensors[base_name + ".loraA"]; - GGML_ASSERT(loraA->type == GGML_TYPE_F32); - ggml_set_name(loraA, "loraA"); - - ggml_tensor * loraB = lora_tensors[base_name + ".loraB"]; - GGML_ASSERT(loraB->type == GGML_TYPE_F32); - ggml_set_name(loraB, "loraB"); - - if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) { - LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" - " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]); - return 1; - } - - // w = w + BA*s - ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB); - offload_func(BA); - ggml_set_name(BA, "BA"); - - if (scaling != 1.0f) { - ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling); - ggml_set_name(scale_tensor, "scale_tensor"); - - BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor); - offload_func(BA); - ggml_set_name(BA, "BA_scaled"); - } - - ggml_tensor * r; - if (base_t == dest_t) { - r = ggml_add_inplace(lora_ctx, dest_t, BA); - offload_func_force_inplace(r); - ggml_set_name(r, "r_add_inplace"); - } - else { - r = ggml_add(lora_ctx, base_t, BA); - offload_func(r); - ggml_set_name(r, "r_add"); - - r = ggml_cpy(lora_ctx, r, dest_t); - offload_func(r); - ggml_set_name(r, "r_cpy"); - } - - struct ggml_cgraph * gf = ggml_new_graph(lora_ctx); - ggml_build_forward_expand(gf, r); - - ggml_graph_compute_helper(work_buffer, gf, n_threads); - - // we won't need these tensors again, reset the context to save memory - ggml_free(lora_ctx); - lora_ctx = ggml_init(params); - lora_tensors.clear(); - - n_tensors++; - if (n_tensors % 4 == 0) { - LLAMA_LOG_INFO("."); - } - } - } - - // TODO: this should be in a destructor, it will leak on failure - ggml_free(lora_ctx); - if (base_ctx) { - ggml_free(base_ctx); - } - - const int64_t t_lora_us = ggml_time_us() - t_start_lora_us; - LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0); - - return 0; -} - -// -// interface implementation -// -struct llama_model_params llama_model_default_params() { - struct llama_model_params result = { - /*.n_gpu_layers =*/ 0, - /*.main_gpu =*/ 0, - /*.tensor_split =*/ nullptr, - /*.progress_callback =*/ nullptr, - /*.progress_callback_user_data =*/ nullptr, - /*.vocab_only =*/ false, - /*.use_mmap =*/ true, - /*.use_mlock =*/ false, - }; - -#ifdef GGML_USE_METAL - result.n_gpu_layers = 1; -#endif - - return result; -} - -struct llama_context_params llama_context_default_params() { - struct llama_context_params result = { - /*.seed =*/ LLAMA_DEFAULT_SEED, - /*.n_ctx =*/ 512, - /*.n_batch =*/ 512, - /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default - /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, - /*.rope_freq_base =*/ 0.0f, - /*.rope_freq_scale =*/ 0.0f, - /*.mul_mat_q =*/ true, - /*.f16_kv =*/ true, - /*.logits_all =*/ false, - /*.embedding =*/ false, - }; - - return result; -} - -struct llama_model_quantize_params llama_model_quantize_default_params() { - struct llama_model_quantize_params result = { - /*.nthread =*/ 0, - /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, - /*.allow_requantize =*/ false, - /*.quantize_output_tensor =*/ true, - /*.only_copy =*/ false, - }; - - return result; -} - -int llama_max_devices(void) { - return LLAMA_MAX_DEVICES; -} - -bool llama_mmap_supported(void) { - return llama_mmap::SUPPORTED; -} - -bool llama_mlock_supported(void) { - return llama_mlock::SUPPORTED; -} - -void llama_backend_init(bool numa) { - ggml_time_init(); - - // needed to initialize f16 tables - { - struct ggml_init_params params = { 0, NULL, false }; - struct ggml_context * ctx = ggml_init(params); - ggml_free(ctx); - } - - if (numa) { - ggml_numa_init(); - } - -#ifdef GGML_USE_MPI - ggml_mpi_backend_init(); -#endif -} - -void llama_backend_free(void) { -#ifdef GGML_USE_MPI - ggml_mpi_backend_free(); -#endif -} - -int64_t llama_time_us(void) { - return ggml_time_us(); -} - -struct llama_model * llama_load_model_from_file( - const char * path_model, - struct llama_model_params params) { - ggml_time_init(); - - llama_model * model = new llama_model; - - unsigned cur_percentage = 0; - if (params.progress_callback == NULL) { - params.progress_callback_user_data = &cur_percentage; - params.progress_callback = [](float progress, void * ctx) { - unsigned * cur_percentage_p = (unsigned *) ctx; - unsigned percentage = (unsigned) (100 * progress); - while (percentage > *cur_percentage_p) { - *cur_percentage_p = percentage; - LLAMA_LOG_INFO("."); - if (percentage >= 100) { - LLAMA_LOG_INFO("\n"); - } - } - }; - } - - if (!llama_model_load(path_model, *model, params.n_gpu_layers, - params.main_gpu, params.tensor_split, - params.use_mmap, params.use_mlock, params.vocab_only, - params.progress_callback, params.progress_callback_user_data)) { - LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); - delete model; - return nullptr; - } - - return model; -} - -void llama_free_model(struct llama_model * model) { - delete model; -} - -struct llama_context * llama_new_context_with_model( - struct llama_model * model, - struct llama_context_params params) { - - if (!model) { - return nullptr; - } - - llama_context * ctx = new llama_context(*model); - - const auto & hparams = model->hparams; - auto & cparams = ctx->cparams; - - cparams.n_batch = params.n_batch; - cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; - cparams.rope_freq_base = params.rope_freq_base == 0 ? hparams.rope_freq_base_train : params.rope_freq_base; - cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale; - cparams.n_threads = params.n_threads; - cparams.n_threads_batch = params.n_threads_batch; - cparams.mul_mat_q = params.mul_mat_q; - - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - - LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); - LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); - - ctx->rng = std::mt19937(params.seed); - ctx->logits_all = params.logits_all; - - ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - - // reserve memory for context buffers - if (!hparams.vocab_only) { - if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers)) { - LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); - llama_free(ctx); - return nullptr; - } - - { - const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v); - LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); - } - - // resized during inference - if (params.logits_all) { - ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab); - } else { - ctx->logits.reserve(hparams.n_vocab); - } - - if (params.embedding){ - ctx->embedding.resize(hparams.n_embd); - } - - { - static const size_t tensor_alignment = 32; - // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data - ctx->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); - - // create measure allocator - ctx->alloc = ggml_allocr_new_measure(tensor_alignment); - - // build worst-case graph - int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch); - int n_past = cparams.n_ctx - n_tokens; - llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0)); - -#ifdef GGML_USE_METAL - if (model->n_gpu_layers > 0) { - ggml_metal_log_set_callback(llama_log_callback_default, NULL); - - ctx->ctx_metal = ggml_metal_init(1); - if (!ctx->ctx_metal) { - LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); - llama_free(ctx); - return NULL; - } - //ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); - //ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); - } -#endif - // measure memory requirements for the graph - size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; - - LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); - - // recreate allocator with exact memory requirements - ggml_allocr_free(ctx->alloc); - - ctx->buf_alloc.resize(alloc_size); - ctx->alloc = ggml_allocr_new(ctx->buf_alloc.data, ctx->buf_alloc.size, tensor_alignment); -#ifdef GGML_USE_METAL - if (ctx->ctx_metal) { - //ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); - } -#endif -#ifdef GGML_USE_CUBLAS - ggml_cuda_set_scratch_size(alloc_size); - LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0); - - // calculate total VRAM usage - auto add_tensor = [](const ggml_tensor * t, size_t & size) { - if (t->backend == GGML_BACKEND_GPU || t->backend == GGML_BACKEND_GPU_SPLIT) { - size += ggml_nbytes(t); - } - }; - size_t model_vram_size = 0; - for (const auto & kv : model->tensors_by_name) { - add_tensor(kv.second, model_vram_size); - } - - size_t kv_vram_size = 0; - add_tensor(ctx->kv_self.k, kv_vram_size); - add_tensor(ctx->kv_self.v, kv_vram_size); - - size_t ctx_vram_size = alloc_size + kv_vram_size; - size_t total_vram_size = model_vram_size + ctx_vram_size; - - LLAMA_LOG_INFO("%s: total VRAM used: %.2f MB (model: %.2f MB, context: %.2f MB)\n", __func__, - total_vram_size / 1024.0 / 1024.0, - model_vram_size / 1024.0 / 1024.0, - ctx_vram_size / 1024.0 / 1024.0); -#endif - } - -#ifdef GGML_USE_METAL - if (model->n_gpu_layers > 0) { - // this allocates all Metal resources and memory buffers - - void * data_ptr = NULL; - size_t data_size = 0; - - if (ctx->model.mapping) { - data_ptr = ctx->model.mapping->addr; - data_size = ctx->model.mapping->size; - } else { - data_ptr = ggml_get_mem_buffer(ctx->model.ctx); - data_size = ggml_get_mem_size (ctx->model.ctx); - } - - const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); - - LLAMA_LOG_INFO("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); - -#define LLAMA_METAL_CHECK_BUF(result) \ - if (!(result)) { \ - LLAMA_LOG_ERROR("%s: failed to add buffer\n", __func__); \ - llama_free(ctx); \ - return NULL; \ - } - - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0)); - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.data, ctx->buf_alloc.size, 0)); -#undef LLAMA_METAL_CHECK_BUF - } -#endif - } - -#ifdef GGML_USE_MPI - ctx->ctx_mpi = ggml_mpi_init(); - - if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { - // Enter a blocking eval loop with dummy input, letting rank=0 drive the process - // TODO: needs fix after #3228 - GGML_ASSERT(false && "not implemented"); - //const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); - //while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; - llama_backend_free(); - exit(1); - } -#endif - - return ctx; -} - -void llama_free(struct llama_context * ctx) { - delete ctx; -} - -const llama_model * llama_get_model(const struct llama_context * ctx) { - return &ctx->model; -} - -int llama_n_ctx(const struct llama_context * ctx) { - return ctx->cparams.n_ctx; -} - -enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { - return model->vocab.type; -} - -int llama_n_vocab(const struct llama_model * model) { - return model->vocab.id_to_token.size(); -} - -int llama_n_ctx_train(const struct llama_model * model) { - return model->hparams.n_ctx_train; -} - -int llama_n_embd(const struct llama_model * model) { - return model->hparams.n_embd; -} - -float llama_rope_freq_scale_train(const struct llama_model * model) { - return model->hparams.rope_freq_scale_train; -} - -int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { - return snprintf(buf, buf_size, "%s %s %s", - llama_model_arch_name(model->arch).c_str(), - llama_model_type_name(model->type), - llama_model_ftype_name(model->ftype).c_str()); -} - -uint64_t llama_model_size(const struct llama_model * model) { - uint64_t size = 0; - for (const auto & it : model->tensors_by_name) { - size += ggml_nbytes(it.second); - } - return size; -} - -uint64_t llama_model_n_params(const struct llama_model * model) { - uint64_t nparams = 0; - for (const auto & it : model->tensors_by_name) { - nparams += ggml_nelements(it.second); - } - return nparams; -} - -struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) { - return ggml_get_tensor(model->ctx, name); -} - -int llama_model_quantize( - const char * fname_inp, - const char * fname_out, - const llama_model_quantize_params * params) { - try { - llama_model_quantize_internal(fname_inp, fname_out, params); - return 0; - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what()); - return 1; - } -} - -int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, float scale, const char * path_base_model, int n_threads) { - try { - return llama_apply_lora_from_file_internal(ctx->model, path_lora, scale, path_base_model, n_threads); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); - return 1; - } -} - -int llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int n_threads) { - try { - return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); - return 1; - } -} - -int llama_get_kv_cache_token_count(const struct llama_context * ctx) { - return ctx->kv_self.head; -} - -void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) { - llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1); -} - -void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); -} - -void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); -} - -void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { - llama_kv_cache_seq_keep(ctx->kv_self, seq_id); -} - -void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta); -} - -// Returns the *maximum* size of the state -size_t llama_get_state_size(const struct llama_context * ctx) { - // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. - // for reference, std::mt19937(1337) serializes to 6701 bytes. - const size_t s_rng_size = sizeof(size_t); - const size_t s_rng = LLAMA_MAX_RNG_STATE; - const size_t s_logits_capacity = sizeof(size_t); - const size_t s_logits_size = sizeof(size_t); - const size_t s_logits = ctx->logits.capacity() * sizeof(float); - const size_t s_embedding_size = sizeof(size_t); - const size_t s_embedding = ctx->embedding.size() * sizeof(float); - const size_t s_kv_size = sizeof(size_t); - const size_t s_kv_ntok = sizeof(int); - const size_t s_kv = ctx->kv_self.buf.size; - - const size_t s_total = ( - + s_rng_size - + s_rng - + s_logits_capacity - + s_logits_size - + s_logits - + s_embedding_size - + s_embedding - + s_kv_size - + s_kv_ntok - + s_kv - ); - - return s_total; -} - -// llama_context_data -struct llama_data_context { - virtual void write(const void * src, size_t size) = 0; - virtual size_t get_size_written() = 0; - virtual ~llama_data_context() = default; -}; - -struct llama_data_buffer_context : llama_data_context { - uint8_t * ptr; - size_t size_written = 0; - - llama_data_buffer_context(uint8_t * p) : ptr(p) {} - - void write(const void * src, size_t size) override { - memcpy(ptr, src, size); - ptr += size; - size_written += size; - } - - size_t get_size_written() override { - return size_written; - } -}; - -struct llama_data_file_context : llama_data_context { - llama_file * file; - size_t size_written = 0; - - llama_data_file_context(llama_file * f) : file(f) {} - - void write(const void * src, size_t size) override { - file->write_raw(src, size); - size_written += size; - } - - size_t get_size_written() override { - return size_written; - } -}; - -/** copy state data into either a buffer or file depending on the passed in context - * - * file context: - * llama_file file("/path", "wb"); - * llama_data_file_context data_ctx(&file); - * llama_copy_state_data(ctx, &data_ctx); - * - * buffer context: - * std::vector buf(max_size, 0); - * llama_data_buffer_context data_ctx(&buf.data()); - * llama_copy_state_data(ctx, &data_ctx); - * -*/ -static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) { - // copy rng - { - std::stringstream rng_ss; - rng_ss << ctx->rng; - - const size_t rng_size = rng_ss.str().size(); - char rng_buf[LLAMA_MAX_RNG_STATE]; - - memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE); - memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); - - data_ctx->write(&rng_size, sizeof(rng_size)); - data_ctx->write(&rng_buf[0], LLAMA_MAX_RNG_STATE); - } - - // copy logits - { - const size_t logits_cap = ctx->logits.capacity(); - const size_t logits_size = ctx->logits.size(); - - data_ctx->write(&logits_cap, sizeof(logits_cap)); - data_ctx->write(&logits_size, sizeof(logits_size)); - - if (logits_size) { - data_ctx->write(ctx->logits.data(), logits_size * sizeof(float)); - } - - // If there is a gap between the size and the capacity, write padding - size_t padding_size = (logits_cap - logits_size) * sizeof(float); - if (padding_size > 0) { - std::vector padding(padding_size, 0); // Create a buffer filled with zeros - data_ctx->write(padding.data(), padding_size); - } - } - - // copy embeddings - { - const size_t embedding_size = ctx->embedding.size(); - - data_ctx->write(&embedding_size, sizeof(embedding_size)); - - if (embedding_size) { - data_ctx->write(ctx->embedding.data(), embedding_size * sizeof(float)); - } - } - - // copy kv cache - { - const auto & kv_self = ctx->kv_self; - const auto & hparams = ctx->model.hparams; - const auto & cparams = ctx->cparams; - - const auto n_layer = hparams.n_layer; - const auto n_embd = hparams.n_embd_gqa(); - const auto n_ctx = cparams.n_ctx; - - const size_t kv_buf_size = kv_self.buf.size; - const uint32_t kv_head = kv_self.head; - const uint32_t kv_size = kv_self.size; - - data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); - data_ctx->write(&kv_head, sizeof(kv_head)); - data_ctx->write(&kv_size, sizeof(kv_size)); - - if (kv_buf_size) { - const size_t elt_size = ggml_element_size(kv_self.k); - - ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - ggml_cgraph gf{}; - - ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); - std::vector kout3d_data(ggml_nbytes(kout3d), 0); - kout3d->data = kout3d_data.data(); - - ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer); - std::vector vout3d_data(ggml_nbytes(vout3d), 0); - vout3d->data = vout3d_data.data(); - - ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, - n_embd, kv_head, n_layer, - elt_size*n_embd, elt_size*n_embd*n_ctx, 0); - - ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, - kv_head, n_embd, n_layer, - elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); - ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); - - ggml_free(cpy_ctx); - - // our data is now in the kout3d_data and vout3d_data buffers - // write them to file - data_ctx->write(kout3d_data.data(), kout3d_data.size()); - data_ctx->write(vout3d_data.data(), vout3d_data.size()); - } - - for (uint32_t i = 0; i < kv_size; ++i) { - const auto & cell = kv_self.cells[i]; - - const llama_pos pos = cell.pos; - const size_t seq_id_size = cell.seq_id.size(); - - data_ctx->write(&pos, sizeof(pos)); - data_ctx->write(&seq_id_size, sizeof(seq_id_size)); - - for (auto seq_id : cell.seq_id) { - data_ctx->write(&seq_id, sizeof(seq_id)); - } - } - } -} - -size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { - llama_data_buffer_context data_ctx(dst); - llama_copy_state_data_internal(ctx, &data_ctx); - - return data_ctx.get_size_written(); -} - -// Sets the state reading from the specified source address -size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { - uint8_t * inp = src; - - // set rng - { - size_t rng_size; - char rng_buf[LLAMA_MAX_RNG_STATE]; - - memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size); - memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE; - - std::stringstream rng_ss; - rng_ss.str(std::string(&rng_buf[0], rng_size)); - rng_ss >> ctx->rng; - - GGML_ASSERT(!rng_ss.fail()); - } - - // set logits - { - size_t logits_cap; - size_t logits_size; - - memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap); - memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size); - - GGML_ASSERT(ctx->logits.capacity() == logits_cap); - - if (logits_size) { - ctx->logits.resize(logits_size); - memcpy(ctx->logits.data(), inp, logits_size * sizeof(float)); - } - - inp += logits_cap * sizeof(float); - } - - // set embeddings - { - size_t embedding_size; - - memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size); - - GGML_ASSERT(ctx->embedding.capacity() == embedding_size); - - if (embedding_size) { - memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float)); - inp += embedding_size * sizeof(float); - } - } - - // set kv cache - { - const auto & kv_self = ctx->kv_self; - const auto & hparams = ctx->model.hparams; - const auto & cparams = ctx->cparams; - - const int n_layer = hparams.n_layer; - const int n_embd = hparams.n_embd_gqa(); - const int n_ctx = cparams.n_ctx; - - size_t kv_buf_size; - uint32_t kv_head; - uint32_t kv_size; - - memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size); - memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head); - memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); - - if (kv_buf_size) { - GGML_ASSERT(kv_self.buf.size == kv_buf_size); - - const size_t elt_size = ggml_element_size(kv_self.k); - - ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); - ggml_cgraph gf{}; - - ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer); - kin3d->data = (void *) inp; - inp += ggml_nbytes(kin3d); - - ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer); - vin3d->data = (void *) inp; - inp += ggml_nbytes(vin3d); - - ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, - n_embd, kv_head, n_layer, - elt_size*n_embd, elt_size*n_embd*n_ctx, 0); - - ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v, - kv_head, n_embd, n_layer, - elt_size*n_ctx, elt_size*n_ctx*n_embd, 0); - - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); - ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); - ggml_graph_compute_helper(ctx->work_buffer, &gf, /*n_threads*/ 1); - - ggml_free(cpy_ctx); - } - - ctx->kv_self.head = kv_head; - ctx->kv_self.size = kv_size; - - ctx->kv_self.cells.resize(kv_size); - - for (uint32_t i = 0; i < kv_size; ++i) { - llama_pos pos; - size_t seq_id_size; - - memcpy(&pos, inp, sizeof(pos)); inp += sizeof(pos); - memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size); - - ctx->kv_self.cells[i].pos = pos; - - llama_seq_id seq_id; - - for (size_t j = 0; j < seq_id_size; ++j) { - memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id); - ctx->kv_self.cells[i].seq_id.insert(seq_id); - } - } - } - - const size_t nread = inp - src; - const size_t max_size = llama_get_state_size(ctx); - - GGML_ASSERT(nread <= max_size); - - return nread; -} - -static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { - llama_file file(path_session, "rb"); - - // sanity checks - { - const uint32_t magic = file.read_u32(); - const uint32_t version = file.read_u32(); - - if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { - LLAMA_LOG_ERROR("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); - return false; - } - - llama_hparams session_hparams; - file.read_raw(&session_hparams, sizeof(llama_hparams)); - - if (session_hparams != ctx->model.hparams) { - LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__); - return false; - } - } - - // load the prompt - { - const uint32_t n_token_count = file.read_u32(); - - if (n_token_count > n_token_capacity) { - LLAMA_LOG_ERROR("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); - return false; - } - - file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); - *n_token_count_out = n_token_count; - } - - // restore the context state - { - const size_t n_state_size_cur = file.size - file.tell(); - const size_t n_state_size_max = llama_get_state_size(ctx); - - if (n_state_size_cur > n_state_size_max) { - LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur); - return false; - } - - std::vector state_data(n_state_size_max); - file.read_raw(state_data.data(), n_state_size_cur); - - llama_set_state_data(ctx, state_data.data()); - } - - return true; -} - -bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { - try { - return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("error loading session file: %s\n", err.what()); - return false; - } -} - -bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { - llama_file file(path_session, "wb"); - - file.write_u32(LLAMA_SESSION_MAGIC); - file.write_u32(LLAMA_SESSION_VERSION); - - file.write_raw(&ctx->model.hparams, sizeof(llama_hparams)); - - // save the prompt - file.write_u32((uint32_t) n_token_count); - file.write_raw(tokens, sizeof(llama_token) * n_token_count); - - // save the context state using stream saving - llama_data_file_context data_ctx(&file); - llama_copy_state_data_internal(ctx, &data_ctx); - - return true; -} - -int llama_eval( - struct llama_context * ctx, - llama_token * tokens, - int32_t n_tokens, - int n_past) { - llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); - - const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0)); - if (ret < 0) { - LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); - } - - return ret; -} - -int llama_eval_embd( - struct llama_context * ctx, - float * embd, - int32_t n_tokens, - int n_past) { - llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); - - llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, }; - - const int ret = llama_decode_internal(*ctx, batch); - if (ret < 0) { - LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); - } - - return ret; -} - -void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) { - ctx->cparams.n_threads = n_threads; - ctx->cparams.n_threads_batch = n_threads_batch; -} - -struct llama_batch llama_batch_get_one( - llama_token * tokens, - int32_t n_tokens, - llama_pos pos_0, - llama_seq_id seq_id) { - return { - /*n_tokens =*/ n_tokens, - /*tokens =*/ tokens, - /*embd =*/ nullptr, - /*pos =*/ nullptr, - /*seq_id =*/ nullptr, - /*logits =*/ nullptr, - /*all_pos_0 =*/ pos_0, - /*all_pos_1 =*/ 1, - /*all_seq_id =*/ seq_id, - }; -} - -struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd) { - llama_batch batch = { -1, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; - - if (embd) { - batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); - } else { - batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); - } - - batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); - batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens); - batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); - - return batch; -} - -void llama_batch_free(struct llama_batch batch) { - if (batch.token) free(batch.token); - if (batch.embd) free(batch.embd); - if (batch.pos) free(batch.pos); - if (batch.seq_id) free(batch.seq_id); - if (batch.logits) free(batch.logits); -} - -int llama_decode( - struct llama_context * ctx, - struct llama_batch batch) { - const int ret = llama_decode_internal(*ctx, batch); - if (ret < 0) { - LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); - } - - return ret; -} - -float * llama_get_logits(struct llama_context * ctx) { - return ctx->logits.data(); -} - -float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { - return ctx->logits.data() + i*ctx->model.hparams.n_vocab; -} - -float * llama_get_embeddings(struct llama_context * ctx) { - return ctx->embedding.data(); -} - -const char * llama_token_get_text(const struct llama_context * ctx, llama_token token) { - return ctx->model.vocab.id_to_token[token].text.c_str(); -} - -float llama_token_get_score(const struct llama_context * ctx, llama_token token) { - return ctx->model.vocab.id_to_token[token].score; -} - -llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token) { - return ctx->model.vocab.id_to_token[token].type; -} - -llama_token llama_token_bos(const struct llama_context * ctx) { - return ctx->model.vocab.special_bos_id; -} - -llama_token llama_token_eos(const struct llama_context * ctx) { - return ctx->model.vocab.special_eos_id; -} - -llama_token llama_token_nl(const struct llama_context * ctx) { - return ctx->model.vocab.linefeed_id; -} -llama_token llama_token_prefix(const struct llama_context * ctx) { - return ctx->model.vocab.special_prefix_id; -} - -llama_token llama_token_middle(const struct llama_context * ctx) { - return ctx->model.vocab.special_middle_id; -} - -llama_token llama_token_suffix(const struct llama_context * ctx) { - return ctx->model.vocab.special_suffix_id; -} - -llama_token llama_token_eot(const struct llama_context * ctx) { - return ctx->model.vocab.special_eot_id; -} - - -int llama_tokenize( - const struct llama_model * model, - const char * text, - int text_len, - llama_token * tokens, - int n_max_tokens, - bool add_bos) { - auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos); - - if (n_max_tokens < (int) res.size()) { - // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); - return -((int) res.size()); - } - - for (size_t i = 0; i < res.size(); i++) { - tokens[i] = res[i]; - } - - return res.size(); -} - -static std::string llama_decode_text(const std::string & text) { - std::string decoded_text; - auto unicode_sequences = codepoints_from_utf8(text); - for (auto& unicode_sequence : unicode_sequences) { - decoded_text += unicode_to_bytes_bpe(codepoint_to_utf8(unicode_sequence)); - } - - return decoded_text; -} - -// does not write null-terminator to buf -int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) { - if (0 <= token && token < llama_n_vocab(model)) { - switch (llama_vocab_get_type(model->vocab)) { - case LLAMA_VOCAB_TYPE_SPM: { - if (llama_is_normal_token(model->vocab, token)) { - std::string result = model->vocab.id_to_token[token].text; - llama_unescape_whitespace(result); - if (length < (int) result.length()) { - return -result.length(); - } - memcpy(buf, result.c_str(), result.length()); - return result.length(); - } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT - if (length < 3) { - return -3; - } - memcpy(buf, "\xe2\x96\x85", 3); - return 3; - } else if (llama_is_control_token(model->vocab, token)) { - ; - } else if (llama_is_byte_token(model->vocab, token)) { - if (length < 1) { - return -1; - } - buf[0] = llama_token_to_byte(model->vocab, token); - return 1; - } else { - // TODO: for now we accept all unsupported token types, - // suppressing them like CONTROL tokens. - // GGML_ASSERT(false); - } - break; - } - case LLAMA_VOCAB_TYPE_BPE: { - if (llama_is_normal_token(model->vocab, token)) { - std::string result = model->vocab.id_to_token[token].text; - result = llama_decode_text(result); - if (length < (int) result.length()) { - return -result.length(); - } - memcpy(buf, result.c_str(), result.length()); - return result.length(); - } else if (llama_is_control_token(model->vocab, token)) { - ; - } else { - // TODO: for now we accept all unsupported token types, - // suppressing them like CONTROL tokens. - // GGML_ASSERT(false); - } - break; - } - default: - GGML_ASSERT(false); - } - } - return 0; -} - -struct llama_timings llama_get_timings(struct llama_context * ctx) { - struct llama_timings result = { - /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, - /*.t_end_ms =*/ 1.00 * ggml_time_ms(), - /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, - /*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us, - /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, - /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, - - /*.n_sample =*/ std::max(1, ctx->n_sample), - /*.n_p_eval =*/ std::max(1, ctx->n_p_eval), - /*.n_eval =*/ std::max(1, ctx->n_eval), - }; - - return result; -} - -void llama_print_timings(struct llama_context * ctx) { - const llama_timings timings = llama_get_timings(ctx); - - LLAMA_LOG_INFO("\n"); - LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); - LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample); - LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); - LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval); - LLAMA_LOG_INFO("%s: total time = %10.2f ms\n", __func__, (timings.t_end_ms - timings.t_start_ms)); -} - -void llama_reset_timings(struct llama_context * ctx) { - ctx->t_start_us = ggml_time_us(); - ctx->t_sample_us = ctx->n_sample = 0; - ctx->t_eval_us = ctx->n_eval = 0; - ctx->t_p_eval_us = ctx->n_p_eval = 0; -} - -const char * llama_print_system_info(void) { - static std::string s; - - s = ""; - s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; - s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; - s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; - s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | "; - s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | "; - s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; - s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; - s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; - s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; - s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; - s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; - s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; - s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; - s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | "; - s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; - - return s.c_str(); -} - -void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { - fprintf(stream, "\n"); - fprintf(stream, "###########\n"); - fprintf(stream, "# Timings #\n"); - fprintf(stream, "###########\n"); - fprintf(stream, "\n"); - - fprintf(stream, "mst_eval: %.2f # ms / token during generation\n", - 1.0e-3 * ctx->t_eval_us / ctx->n_eval); - fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", - 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); - fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n", - 1.0e-3 * ctx->t_sample_us / ctx->n_sample); - fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); - fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); - fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->n_sample); - fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); - fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); - fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); - fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->t_sample_us); - fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", - 1.0e6 * ctx->n_eval / ctx->t_eval_us); - fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", - 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); - fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n", - 1.0e6 * ctx->n_sample / ctx->t_sample_us); -} - -// For internal test use -const std::vector> & llama_internal_get_tensor_map( - struct llama_context * ctx -) { - return ctx->model.tensors_by_name; -} - -void llama_log_set(ggml_log_callback log_callback, void * user_data) { - g_state.log_callback = log_callback ? log_callback : llama_log_callback_default; - g_state.log_callback_user_data = user_data; -} - -static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) { - va_list args_copy; - va_copy(args_copy, args); - char buffer[128]; - int len = vsnprintf(buffer, 128, format, args); - if (len < 128) { - g_state.log_callback(level, buffer, g_state.log_callback_user_data); - } else { - char* buffer2 = new char[len+1]; - vsnprintf(buffer2, len+1, format, args_copy); - buffer2[len] = 0; - g_state.log_callback(level, buffer2, g_state.log_callback_user_data); - delete[] buffer2; - } - va_end(args_copy); -} - -static void llama_log_internal(ggml_log_level level, const char * format, ...) { - va_list args; - va_start(args, format); - llama_log_internal_v(level, format, args); - va_end(args); -} - -static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) { - (void) level; - bool * enable_log = static_cast(user_data); - if (enable_log && *enable_log) { - fputs(text, stderr); - fflush(stderr); - } -} diff --git a/plugins/wasi_nn/thirdparty/ggml/llama.h b/plugins/wasi_nn/thirdparty/ggml/llama.h deleted file mode 100644 index a78015ad..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/llama.h +++ /dev/null @@ -1,752 +0,0 @@ -#ifndef LLAMA_H -#define LLAMA_H - -#include "ggml.h" -#ifdef GGML_USE_CUBLAS -#include "ggml-cuda.h" -#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES -#else -#define LLAMA_MAX_DEVICES 1 -#endif // GGML_USE_CUBLAS -#include -#include -#include -#include - -#ifdef LLAMA_SHARED -# if defined(_WIN32) && !defined(__MINGW32__) -# ifdef LLAMA_BUILD -# define LLAMA_API __declspec(dllexport) -# else -# define LLAMA_API __declspec(dllimport) -# endif -# else -# define LLAMA_API __attribute__ ((visibility ("default"))) -# endif -#else -# define LLAMA_API -#endif - -#ifdef __GNUC__ -# define DEPRECATED(func, hint) func __attribute__((deprecated(hint))) -#elif defined(_MSC_VER) -# define DEPRECATED(func, hint) __declspec(deprecated(hint)) func -#else -# define DEPRECATED(func, hint) func -#endif - -#define LLAMA_DEFAULT_SEED 0xFFFFFFFF - -#define LLAMA_MAX_RNG_STATE (64*1024) - -#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' - -#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 2 - -#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) -// Defined when llama.cpp is compiled with support for offloading model layers to GPU. -#define LLAMA_SUPPORTS_GPU_OFFLOAD -#endif - -#ifdef __cplusplus -extern "C" { -#endif - - // - // C interface - // - // TODO: show sample usage - // - - struct llama_model; - struct llama_context; - - typedef int32_t llama_pos; - typedef int32_t llama_token; - typedef int32_t llama_seq_id; - - enum llama_vocab_type { - LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece - LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding - }; - - enum llama_token_type { - LLAMA_TOKEN_TYPE_UNDEFINED = 0, - LLAMA_TOKEN_TYPE_NORMAL = 1, - LLAMA_TOKEN_TYPE_UNKNOWN = 2, - LLAMA_TOKEN_TYPE_CONTROL = 3, - LLAMA_TOKEN_TYPE_USER_DEFINED = 4, - LLAMA_TOKEN_TYPE_UNUSED = 5, - LLAMA_TOKEN_TYPE_BYTE = 6, - }; - - // model file types - enum llama_ftype { - LLAMA_FTYPE_ALL_F32 = 0, - LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 - // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed - // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed - LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors - - LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file - }; - - typedef struct llama_token_data { - llama_token id; // token id - float logit; // log-odds of the token - float p; // probability of the token - } llama_token_data; - - typedef struct llama_token_data_array { - llama_token_data * data; - size_t size; - bool sorted; - } llama_token_data_array; - - typedef void (*llama_progress_callback)(float progress, void *ctx); - - // Input data for llama_decode - // A llama_batch object can contain input about one or many sequences - // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens - // - // - token : the token ids of the input (used when embd is NULL) - // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) - // - pos : the positions of the respective token in the sequence - // - seq_id : the sequence to which the respective token belongs - // - logits : if zero, the logits for the respective token will not be output - // - typedef struct llama_batch { - int32_t n_tokens; - - llama_token * token; - float * embd; - llama_pos * pos; - llama_seq_id * seq_id; - int8_t * logits; - - // NOTE: helpers for smooth API transition - can be deprecated in the future - // for future-proof code, use the above fields instead and ignore everything below - // - // pos[i] = all_pos_0 + i*all_pos_1 - // - llama_pos all_pos_0; // used if pos == NULL - llama_pos all_pos_1; // used if pos == NULL - llama_seq_id all_seq_id; // used if seq_id == NULL - } llama_batch; - - struct llama_model_params { - int32_t n_gpu_layers; // number of layers to store in VRAM - int32_t main_gpu; // the GPU that is used for scratch and small tensors - const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) - - // called with a progress value between 0 and 1, pass NULL to disable - llama_progress_callback progress_callback; - // context pointer passed to the progress callback - void * progress_callback_user_data; - - // Keep the booleans together to avoid misalignment during copy-by-value. - bool vocab_only; // only load the vocabulary, no weights - bool use_mmap; // use mmap if possible - bool use_mlock; // force system to keep model in RAM - }; - - struct llama_context_params { - uint32_t seed; // RNG seed, -1 for random - uint32_t n_ctx; // text context, 0 = from model - uint32_t n_batch; // prompt processing maximum batch size - uint32_t n_threads; // number of threads to use for generation - uint32_t n_threads_batch; // number of threads to use for batch processing - - // ref: https://github.com/ggerganov/llama.cpp/pull/2054 - float rope_freq_base; // RoPE base frequency, 0 = from model - float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model - - // Keep the booleans together to avoid misalignment during copy-by-value. - bool mul_mat_q; // if true, use experimental mul_mat_q kernels - bool f16_kv; // use fp16 for KV cache, fp32 otherwise - bool logits_all; // the llama_eval() call computes all logits, not just the last one - bool embedding; // embedding mode only - }; - - // model quantization parameters - typedef struct llama_model_quantize_params { - int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() - enum llama_ftype ftype; // quantize to this llama_ftype - bool allow_requantize; // allow quantizing non-f32/f16 tensors - bool quantize_output_tensor; // quantize output.weight - bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored - } llama_model_quantize_params; - - // grammar types - struct llama_grammar; - - // grammar element type - enum llama_gretype { - // end of rule definition - LLAMA_GRETYPE_END = 0, - - // start of alternate definition for rule - LLAMA_GRETYPE_ALT = 1, - - // non-terminal element: reference to rule - LLAMA_GRETYPE_RULE_REF = 2, - - // terminal element: character (code point) - LLAMA_GRETYPE_CHAR = 3, - - // inverse char(s) ([^a], [^a-b] [^abc]) - LLAMA_GRETYPE_CHAR_NOT = 4, - - // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to - // be an inclusive range ([a-z]) - LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, - - // modifies a preceding LLAMA_GRETYPE_CHAR or - // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) - LLAMA_GRETYPE_CHAR_ALT = 6, - }; - - typedef struct llama_grammar_element { - enum llama_gretype type; - uint32_t value; // Unicode code point or rule ID - } llama_grammar_element; - - // performance timing information - struct llama_timings { - double t_start_ms; - double t_end_ms; - double t_load_ms; - double t_sample_ms; - double t_p_eval_ms; - double t_eval_ms; - - int32_t n_sample; - int32_t n_p_eval; - int32_t n_eval; - }; - - // Helpers for getting default parameters - LLAMA_API struct llama_model_params llama_model_default_params(void); - LLAMA_API struct llama_context_params llama_context_default_params(void); - LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); - - // Initialize the llama + ggml backend - // If numa is true, use NUMA optimizations - // Call once at the start of the program - LLAMA_API void llama_backend_init(bool numa); - - // Call once at the end of the program - currently only used for MPI - LLAMA_API void llama_backend_free(void); - - LLAMA_API struct llama_model * llama_load_model_from_file( - const char * path_model, - struct llama_model_params params); - - LLAMA_API void llama_free_model(struct llama_model * model); - - LLAMA_API struct llama_context * llama_new_context_with_model( - struct llama_model * model, - struct llama_context_params params); - - // Frees all allocated memory - LLAMA_API void llama_free(struct llama_context * ctx); - - LLAMA_API int64_t llama_time_us(void); - - LLAMA_API int llama_max_devices (void); - LLAMA_API bool llama_mmap_supported (void); - LLAMA_API bool llama_mlock_supported(void); - - LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); - - LLAMA_API int llama_n_ctx (const struct llama_context * ctx); - - LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); - - LLAMA_API int llama_n_vocab (const struct llama_model * model); - LLAMA_API int llama_n_ctx_train(const struct llama_model * model); - LLAMA_API int llama_n_embd (const struct llama_model * model); - - // Get the model's RoPE frequency scaling factor - LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); - - // Get a string describing the model type - LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); - - // Returns the total size of all the tensors in the model in bytes - LLAMA_API uint64_t llama_model_size(const struct llama_model * model); - - // Returns the total number of parameters in the model - LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); - - // Get a llama model tensor - LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name); - - // Returns 0 on success - LLAMA_API int llama_model_quantize( - const char * fname_inp, - const char * fname_out, - const llama_model_quantize_params * params); - - // Apply a LoRA adapter to a loaded model - // path_base_model is the path to a higher quality model to use as a base for - // the layers modified by the adapter. Can be NULL to use the current loaded model. - // The model needs to be reloaded before applying a new adapter, otherwise the adapter - // will be applied on top of the previous one - // Returns 0 on success - LLAMA_API DEPRECATED(int llama_apply_lora_from_file( - struct llama_context * ctx, - const char * path_lora, - float scale, - const char * path_base_model, - int n_threads), - "use llama_model_apply_lora_from_file instead"); - - LLAMA_API int llama_model_apply_lora_from_file( - const struct llama_model * model, - const char * path_lora, - float scale, - const char * path_base_model, - int n_threads); - - // - // KV cache - // - - // Returns the number of tokens in the KV cache - LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx), - "avoid using this, it will be removed in the future, instead - count the tokens in user code"); - - // Remove all tokens data of cells in [c0, c1) - // c0 < 0 : [0, c1] - // c1 < 0 : [c0, inf) - LLAMA_API void llama_kv_cache_tokens_rm( - struct llama_context * ctx, - int32_t c0, - int32_t c1); - - // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_rm( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1); - - // Copy all tokens that belong to the specified sequence to another sequence - // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_cp( - struct llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1); - - // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_cache_seq_keep( - struct llama_context * ctx, - llama_seq_id seq_id); - - // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) - // If the KV cache is RoPEd, the KV data is updated accordingly - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_shift( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta); - - // - // State / sessions - // - - // Returns the maximum size in bytes of the state (rng, logits, embedding - // and kv_cache) - will often be smaller after compacting tokens - LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); - - // Copies the state to the specified destination address. - // Destination needs to have allocated enough memory. - // Returns the number of bytes copied - LLAMA_API size_t llama_copy_state_data( - struct llama_context * ctx, - uint8_t * dst); - - // Set the state reading from the specified address - // Returns the number of bytes read - LLAMA_API size_t llama_set_state_data( - struct llama_context * ctx, - uint8_t * src); - - // Save/load session file - LLAMA_API bool llama_load_session_file( - struct llama_context * ctx, - const char * path_session, - llama_token * tokens_out, - size_t n_token_capacity, - size_t * n_token_count_out); - - LLAMA_API bool llama_save_session_file( - struct llama_context * ctx, - const char * path_session, - const llama_token * tokens, - size_t n_token_count); - - // - // Decoding - // - - // Run the llama inference to obtain the logits and probabilities for the next token(s). - // tokens + n_tokens is the provided batch of new tokens to process - // n_past is the number of tokens to use from previous eval calls - // Returns 0 on success - // DEPRECATED: use llama_decode() instead - LLAMA_API DEPRECATED(int llama_eval( - struct llama_context * ctx, - llama_token * tokens, - int32_t n_tokens, - int n_past), - "use llama_decode() instead"); - - // Same as llama_eval, but use float matrix input directly. - // DEPRECATED: use llama_decode() instead - LLAMA_API DEPRECATED(int llama_eval_embd( - struct llama_context * ctx, - float * embd, - int32_t n_tokens, - int n_past), - "use llama_decode() instead"); - - // Return batch for single sequence of tokens starting at pos_0 - // - // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it - // - LLAMA_API struct llama_batch llama_batch_get_one( - llama_token * tokens, - int32_t n_tokens, - llama_pos pos_0, - llama_seq_id seq_id); - - // Allocates a batch of tokens on the heap - // The batch has to be freed with llama_batch_free() - // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) - // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token - // The rest of the llama_batch members are allocated with size n_tokens - // All members are left uninitialized - LLAMA_API struct llama_batch llama_batch_init( - int32_t n_tokens, - int32_t embd); - - // Frees a batch of tokens allocated with llama_batch_init() - LLAMA_API void llama_batch_free(struct llama_batch batch); - - // Positive return values does not mean a fatal error, but rather a warning. - // 0 - success - // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) - // < 0 - error - LLAMA_API int llama_decode( - struct llama_context * ctx, - struct llama_batch batch); - - // Set the number of threads used for decoding - // n_threads is the number of threads used for generation (single token) - // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) - LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); - - // Token logits obtained from the last call to llama_eval() - // The logits for the last token are stored in the last row - // Logits for which llama_batch.logits[i] == 0 are undefined - // Rows: n_tokens provided with llama_batch - // Cols: n_vocab - LLAMA_API float * llama_get_logits(struct llama_context * ctx); - - // Logits for the ith token. Equivalent to: - // llama_get_logits(ctx) + i*n_vocab - LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); - - // Get the embeddings for the input - // shape: [n_embd] (1-dimensional) - LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); - - // - // Vocab - // - - LLAMA_API const char * llama_token_get_text(const struct llama_context * ctx, llama_token token); - - LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token); - - LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token); - - // Special tokens - LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence - LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence - LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line - // codellama infill tokens - LLAMA_API llama_token llama_token_prefix(const struct llama_context * ctx); // Beginning of infill prefix - LLAMA_API llama_token llama_token_middle(const struct llama_context * ctx); // Beginning of infill middle - LLAMA_API llama_token llama_token_suffix(const struct llama_context * ctx); // Beginning of infill suffix - LLAMA_API llama_token llama_token_eot (const struct llama_context * ctx); // End of infill middle - - // - // Tokenization - // - - // Convert the provided text into tokens. - // The tokens pointer must be large enough to hold the resulting tokens. - // Returns the number of tokens on success, no more than n_max_tokens - // Returns a negative number on failure - the number of tokens that would have been returned - LLAMA_API int llama_tokenize( - const struct llama_model * model, - const char * text, - int text_len, - llama_token * tokens, - int n_max_tokens, - bool add_bos); - - // Token Id -> Piece. - // Uses the vocabulary in the provided context. - // Does not write null terminator to the buffer. - // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. - LLAMA_API int llama_token_to_piece( - const struct llama_model * model, - llama_token token, - char * buf, - int length); - - // - // Grammar - // - - LLAMA_API struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index); - - LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); - - LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar); - - // - // Sampling functions - // - - // Sets the current rng seed. - LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); - - /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - LLAMA_API void llama_sample_repetition_penalty( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t last_tokens_size, - float penalty); - - /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - LLAMA_API void llama_sample_frequency_and_presence_penalties( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t last_tokens_size, - float alpha_frequency, - float alpha_presence); - - /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 - /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. - /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. - /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. - LLAMA_API void llama_sample_classifier_free_guidance( - struct llama_context * ctx, - llama_token_data_array * candidates, - struct llama_context * guidance_ctx, - float scale); - - /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - LLAMA_API void llama_sample_softmax( - struct llama_context * ctx, - llama_token_data_array * candidates); - - /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_k( - struct llama_context * ctx, - llama_token_data_array * candidates, - int k, - size_t min_keep); - - /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_p( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); - - /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API void llama_sample_tail_free( - struct llama_context * ctx, - llama_token_data_array * candidates, - float z, - size_t min_keep); - - /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API void llama_sample_typical( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); - - LLAMA_API void llama_sample_temp( - struct llama_context * ctx, - llama_token_data_array * candidates, - float temp); - - LLAMA_API DEPRECATED(void llama_sample_temperature( - struct llama_context * ctx, - llama_token_data_array * candidates, - float temp), - "use llama_sample_temp instead"); - - /// @details Apply constraints from grammar - LLAMA_API void llama_sample_grammar( - struct llama_context * ctx, - llama_token_data_array * candidates, - const struct llama_grammar * grammar); - - /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. - /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. - /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. - /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat( - struct llama_context * ctx, - llama_token_data_array * candidates, - float tau, - float eta, - int m, - float * mu); - - /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. - /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. - /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat_v2( - struct llama_context * ctx, - llama_token_data_array * candidates, - float tau, - float eta, - float * mu); - - /// @details Selects the token with the highest probability. - LLAMA_API llama_token llama_sample_token_greedy( - struct llama_context * ctx, - llama_token_data_array * candidates); - - /// @details Randomly selects a token from the candidates based on their probabilities. - LLAMA_API llama_token llama_sample_token( - struct llama_context * ctx, - llama_token_data_array * candidates); - - /// @details Accepts the sampled token into the grammar - LLAMA_API void llama_grammar_accept_token( - struct llama_context * ctx, - struct llama_grammar * grammar, - llama_token token); - - // - // Beam search - // - - struct llama_beam_view { - const llama_token * tokens; - - size_t n_tokens; - float p; // Cumulative beam probability (renormalized relative to all beams) - bool eob; // Callback should set this to true when a beam is at end-of-beam. - }; - - // Passed to beam_search_callback function. - // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams - // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks. - // These pointers are valid only during the synchronous callback, so should not be saved. - struct llama_beams_state { - struct llama_beam_view * beam_views; - - size_t n_beams; // Number of elements in beam_views[]. - size_t common_prefix_length; // Current max length of prefix tokens shared by all beams. - bool last_call; // True iff this is the last callback invocation. - }; - - // Type of pointer to the beam_search_callback function. - // void* callback_data is any custom data passed to llama_beam_search, that is subsequently - // passed back to beam_search_callback. This avoids having to use global variables in the callback. - typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state); - - /// @details Deterministically returns entire sentence constructed by a beam search. - /// @param ctx Pointer to the llama_context. - /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state. - /// @param callback_data A pointer that is simply passed back to callback. - /// @param n_beams Number of beams to use. - /// @param n_past Number of tokens already evaluated. - /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier. - LLAMA_API void llama_beam_search( - struct llama_context * ctx, - llama_beam_search_callback_fn_t callback, - void * callback_data, - size_t n_beams, - int n_past, - int n_predict); - - // Performance information - LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); - - LLAMA_API void llama_print_timings(struct llama_context * ctx); - LLAMA_API void llama_reset_timings(struct llama_context * ctx); - - // Print system information - LLAMA_API const char * llama_print_system_info(void); - - // Set callback for all future logging events. - // If this is not called, or NULL is supplied, everything is output on stderr. - LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); - - LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx); - -#ifdef __cplusplus -} -#endif - -// Internal API to be implemented by llama.cpp and used by tests/benchmarks only -#ifdef LLAMA_API_INTERNAL - -#include -#include - -struct ggml_tensor; - -const std::vector> & llama_internal_get_tensor_map( - struct llama_context * ctx -); - -#endif // LLAMA_API_INTERNAL - -#endif // LLAMA_H diff --git a/plugins/wasi_nn/thirdparty/ggml/log.h b/plugins/wasi_nn/thirdparty/ggml/log.h deleted file mode 100644 index b8953fdc..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/log.h +++ /dev/null @@ -1,643 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -// -------------------------------- -// -// Basic usage: -// -// -------- -// -// The LOG() and LOG_TEE() macros are ready to go by default -// they do not require any initialization. -// -// LOGLN() and LOG_TEELN() are variants which automatically -// include \n character at the end of the log string. -// -// LOG() behaves exactly like printf, by default writing to a logfile. -// LOG_TEE() additionally, prints to the screen too ( mimics Unix tee command ). -// -// Default logfile is named -// "llama..log" -// Default LOG_TEE() secondary output target is -// stderr -// -// Logs can be dynamically disabled or enabled using functions: -// log_disable() -// and -// log_enable() -// -// A log target can be changed with: -// log_set_target( string ) -// creating and opening, or re-opening a file by string filename -// or -// log_set_target( FILE* ) -// allowing to point at stderr, stdout, or any valid FILE* file handler. -// -// -------- -// -// End of Basic usage. -// -// -------------------------------- - -// Specifies a log target. -// default uses log_handler() with "llama.log" log file -// this can be changed, by defining LOG_TARGET -// like so: -// -// #define LOG_TARGET (a valid FILE*) -// #include "log.h" -// -// or it can be simply redirected to stdout or stderr -// like so: -// -// #define LOG_TARGET stderr -// #include "log.h" -// -// The log target can also be redirected to a diffrent function -// like so: -// -// #define LOG_TARGET log_handler_diffrent() -// #include "log.h" -// -// FILE* log_handler_diffrent() -// { -// return stderr; -// } -// -// or: -// -// #define LOG_TARGET log_handler_another_one("somelog.log") -// #include "log.h" -// -// FILE* log_handler_another_one(char*filename) -// { -// static FILE* logfile = nullptr; -// (...) -// if( !logfile ) -// { -// fopen(...) -// } -// (...) -// return logfile -// } -// -#ifndef LOG_TARGET - #define LOG_TARGET log_handler() -#endif - -#ifndef LOG_TEE_TARGET - #define LOG_TEE_TARGET stderr -#endif - -// Utility to obtain "pid" like unique process id and use it when creating log files. -inline std::string log_get_pid() -{ - static std::string pid; - if (pid.empty()) - { - // std::this_thread::get_id() is the most portable way of obtaining a "process id" - // it's not the same as "pid" but is unique enough to solve multiple instances - // trying to write to the same log. - std::stringstream ss; - ss << std::this_thread::get_id(); - pid = ss.str(); - } - - return pid; -} - -// Utility function for generating log file names with unique id based on thread id. -// invocation with log_filename_generator( "llama", "log" ) creates a string "llama..log" -// where the number is a runtime id of the current thread. - -#define log_filename_generator(log_file_basename, log_file_extension) log_filename_generator_impl(log_file_basename, log_file_extension) - -// INTERNAL, DO NOT USE -inline std::string log_filename_generator_impl(const std::string & log_file_basename, const std::string & log_file_extension) -{ - std::stringstream buf; - - buf << log_file_basename; - buf << "."; - buf << log_get_pid(); - buf << "."; - buf << log_file_extension; - - return buf.str(); -} - -#ifndef LOG_DEFAULT_FILE_NAME - #define LOG_DEFAULT_FILE_NAME log_filename_generator("llama", "log") -#endif - -// Utility for turning #define values into string literals -// so we can have a define for stderr and -// we can print "stderr" instead of literal stderr, etc. -#define LOG_STRINGIZE1(s) #s -#define LOG_STRINGIZE(s) LOG_STRINGIZE1(s) - -#define LOG_TEE_TARGET_STRING LOG_STRINGIZE(LOG_TEE_TARGET) - -// Allows disabling timestamps. -// in order to disable, define LOG_NO_TIMESTAMPS -// like so: -// -// #define LOG_NO_TIMESTAMPS -// #include "log.h" -// -#ifndef LOG_NO_TIMESTAMPS - #ifndef _MSC_VER - #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] " - #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() - #else - #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] " - #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() - #endif -#else - #define LOG_TIMESTAMP_FMT "%s" - #define LOG_TIMESTAMP_VAL ,"" -#endif - -#ifdef LOG_TEE_TIMESTAMPS - #ifndef _MSC_VER - #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] " - #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() - #else - #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] " - #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() - #endif -#else - #define LOG_TEE_TIMESTAMP_FMT "%s" - #define LOG_TEE_TIMESTAMP_VAL ,"" -#endif - -// Allows disabling file/line/function prefix -// in order to disable, define LOG_NO_FILE_LINE_FUNCTION -// like so: -// -// #define LOG_NO_FILE_LINE_FUNCTION -// #include "log.h" -// -#ifndef LOG_NO_FILE_LINE_FUNCTION - #ifndef _MSC_VER - #define LOG_FLF_FMT "[%24s:%5d][%24s] " - #define LOG_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ - #else - #define LOG_FLF_FMT "[%24s:%5ld][%24s] " - #define LOG_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ - #endif -#else - #define LOG_FLF_FMT "%s" - #define LOG_FLF_VAL ,"" -#endif - -#ifdef LOG_TEE_FILE_LINE_FUNCTION - #ifndef _MSC_VER - #define LOG_TEE_FLF_FMT "[%24s:%5d][%24s] " - #define LOG_TEE_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ - #else - #define LOG_TEE_FLF_FMT "[%24s:%5ld][%24s] " - #define LOG_TEE_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ - #endif -#else - #define LOG_TEE_FLF_FMT "%s" - #define LOG_TEE_FLF_VAL ,"" -#endif - -// Utility for synchronizing log configuration state -// since std::optional was introduced only in c++17 -enum LogTriState -{ - LogTriStateSame, - LogTriStateFalse, - LogTriStateTrue -}; - -// INTERNAL, DO NOT USE -// USE LOG() INSTEAD -// -#ifndef _MSC_VER - #define LOG_IMPL(str, ...) \ - do { \ - if (LOG_TARGET != nullptr) \ - { \ - fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \ - fflush(LOG_TARGET); \ - } \ - } while (0) -#else - #define LOG_IMPL(str, ...) \ - do { \ - if (LOG_TARGET != nullptr) \ - { \ - fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \ - fflush(LOG_TARGET); \ - } \ - } while (0) -#endif - -// INTERNAL, DO NOT USE -// USE LOG_TEE() INSTEAD -// -#ifndef _MSC_VER - #define LOG_TEE_IMPL(str, ...) \ - do { \ - if (LOG_TARGET != nullptr) \ - { \ - fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \ - fflush(LOG_TARGET); \ - } \ - if (LOG_TARGET != nullptr && LOG_TARGET != stdout && LOG_TARGET != stderr && LOG_TEE_TARGET != nullptr) \ - { \ - fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL, __VA_ARGS__); \ - fflush(LOG_TEE_TARGET); \ - } \ - } while (0) -#else - #define LOG_TEE_IMPL(str, ...) \ - do { \ - if (LOG_TARGET != nullptr) \ - { \ - fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \ - fflush(LOG_TARGET); \ - } \ - if (LOG_TARGET != nullptr && LOG_TARGET != stdout && LOG_TARGET != stderr && LOG_TEE_TARGET != nullptr) \ - { \ - fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL "", ##__VA_ARGS__); \ - fflush(LOG_TEE_TARGET); \ - } \ - } while (0) -#endif - -// The '\0' as a last argument, is a trick to bypass the silly -// "warning: ISO C++11 requires at least one argument for the "..." in a variadic macro" -// so we can have a single macro which can be called just like printf. - -// Main LOG macro. -// behaves like printf, and supports arguments the exact same way. -// -#ifndef _MSC_VER - #define LOG(...) LOG_IMPL(__VA_ARGS__, "") -#else - #define LOG(str, ...) LOG_IMPL("%s" str, "", __VA_ARGS__, "") -#endif - -// Main TEE macro. -// does the same as LOG -// and -// simultaneously writes stderr. -// -// Secondary target can be changed just like LOG_TARGET -// by defining LOG_TEE_TARGET -// -#ifndef _MSC_VER - #define LOG_TEE(...) LOG_TEE_IMPL(__VA_ARGS__, "") -#else - #define LOG_TEE(str, ...) LOG_TEE_IMPL("%s" str, "", __VA_ARGS__, "") -#endif - -// LOG macro variants with auto endline. -#ifndef _MSC_VER - #define LOGLN(...) LOG_IMPL(__VA_ARGS__, "\n") - #define LOG_TEELN(...) LOG_TEE_IMPL(__VA_ARGS__, "\n") -#else - #define LOGLN(str, ...) LOG_IMPL("%s" str, "", __VA_ARGS__, "\n") - #define LOG_TEELN(str, ...) LOG_TEE_IMPL("%s" str, "", __VA_ARGS__, "\n") -#endif - -// INTERNAL, DO NOT USE -inline FILE *log_handler1_impl(bool change = false, LogTriState disable = LogTriStateSame, const std::string & filename = LOG_DEFAULT_FILE_NAME, FILE *target = nullptr) -{ - static bool _initialized{false}; - static bool _disabled{(filename.empty() && target == nullptr)}; - static std::string log_current_filename{filename}; - static FILE *log_current_target{target}; - static FILE *logfile = nullptr; - - if (change) - { - if (disable == LogTriStateTrue) - { - // Disable primary target - _disabled = true; - } - // If previously disabled, only enable, and keep previous target - else if (disable == LogTriStateFalse) - { - _disabled = false; - } - // Otherwise, process the arguments - else if (log_current_filename != filename || log_current_target != target) - { - _initialized = false; - } - } - - if (_disabled) - { - // Log is disabled - return nullptr; - } - - if (_initialized) - { - // with fallback in case something went wrong - return logfile ? logfile : stderr; - } - - // do the (re)initialization - if (target != nullptr) - { - if (logfile != nullptr && logfile != stdout && logfile != stderr) - { - fclose(logfile); - } - - log_current_filename = LOG_DEFAULT_FILE_NAME; - log_current_target = target; - - logfile = target; - } - else - { - if (log_current_filename != filename) - { - if (logfile != nullptr && logfile != stdout && logfile != stderr) - { - fclose(logfile); - } - } - - logfile = fopen(filename.c_str(), "w"); - } - - if (!logfile) - { - // Verify whether the file was opened, otherwise fallback to stderr - logfile = stderr; - - fprintf(stderr, "Failed to open logfile '%s' with error '%s'\n", filename.c_str(), std::strerror(errno)); - fflush(stderr); - - // At this point we let the init flag be to true below, and let the target fallback to stderr - // otherwise we would repeatedly fopen() which was already unsuccessful - } - - _initialized = true; - - return logfile ? logfile : stderr; -} - -// INTERNAL, DO NOT USE -inline FILE *log_handler2_impl(bool change = false, LogTriState disable = LogTriStateSame, FILE *target = nullptr, const std::string & filename = LOG_DEFAULT_FILE_NAME) -{ - return log_handler1_impl(change, disable, filename, target); -} - -// Disables logs entirely at runtime. -// Makes LOG() and LOG_TEE() produce no output, -// untill enabled back. -#define log_disable() log_disable_impl() - -// INTERNAL, DO NOT USE -inline FILE *log_disable_impl() -{ - return log_handler1_impl(true, LogTriStateTrue); -} - -// Enables logs at runtime. -#define log_enable() log_enable_impl() - -// INTERNAL, DO NOT USE -inline FILE *log_enable_impl() -{ - return log_handler1_impl(true, LogTriStateFalse); -} - -// Sets target fir logs, either by a file name or FILE* pointer (stdout, stderr, or any valid FILE*) -#define log_set_target(target) log_set_target_impl(target) - -// INTERNAL, DO NOT USE -inline FILE *log_set_target_impl(const std::string & filename) { return log_handler1_impl(true, LogTriStateSame, filename); } -inline FILE *log_set_target_impl(FILE *target) { return log_handler2_impl(true, LogTriStateSame, target); } - -// INTERNAL, DO NOT USE -inline FILE *log_handler() { return log_handler1_impl(); } - -inline void log_test() -{ - log_disable(); - LOG("01 Hello World to nobody, because logs are disabled!\n"); - log_enable(); - LOG("02 Hello World to default output, which is \"%s\" ( Yaaay, arguments! )!\n", LOG_STRINGIZE(LOG_TARGET)); - LOG_TEE("03 Hello World to **both** default output and " LOG_TEE_TARGET_STRING "!\n"); - log_set_target(stderr); - LOG("04 Hello World to stderr!\n"); - LOG_TEE("05 Hello World TEE with double printing to stderr prevented!\n"); - log_set_target(LOG_DEFAULT_FILE_NAME); - LOG("06 Hello World to default log file!\n"); - log_set_target(stdout); - LOG("07 Hello World to stdout!\n"); - log_set_target(LOG_DEFAULT_FILE_NAME); - LOG("08 Hello World to default log file again!\n"); - log_disable(); - LOG("09 Hello World _1_ into the void!\n"); - log_enable(); - LOG("10 Hello World back from the void ( you should not see _1_ in the log or the output )!\n"); - log_disable(); - log_set_target("llama.anotherlog.log"); - LOG("11 Hello World _2_ to nobody, new target was selected but logs are still disabled!\n"); - log_enable(); - LOG("12 Hello World this time in a new file ( you should not see _2_ in the log or the output )?\n"); - log_set_target("llama.yetanotherlog.log"); - LOG("13 Hello World this time in yet new file?\n"); - log_set_target(log_filename_generator("llama_autonamed", "log")); - LOG("14 Hello World in log with generated filename!\n"); -#ifdef _MSC_VER - LOG_TEE("15 Hello msvc TEE without arguments\n"); - LOG_TEE("16 Hello msvc TEE with (%d)(%s) arguments\n", 1, "test"); - LOG_TEELN("17 Hello msvc TEELN without arguments\n"); - LOG_TEELN("18 Hello msvc TEELN with (%d)(%s) arguments\n", 1, "test"); - LOG("19 Hello msvc LOG without arguments\n"); - LOG("20 Hello msvc LOG with (%d)(%s) arguments\n", 1, "test"); - LOGLN("21 Hello msvc LOGLN without arguments\n"); - LOGLN("22 Hello msvc LOGLN with (%d)(%s) arguments\n", 1, "test"); -#endif -} - -inline bool log_param_single_parse(const std::string & param) -{ - if ( param == "--log-test") - { - log_test(); - return true; - } - - if ( param == "--log-disable") - { - log_disable(); - return true; - } - - if ( param == "--log-enable") - { - log_enable(); - return true; - } - - return false; -} - -inline bool log_param_pair_parse(bool check_but_dont_parse, const std::string & param, const std::string & next = std::string()) -{ - if ( param == "--log-file") - { - if (!check_but_dont_parse) - { - log_set_target(log_filename_generator(next.empty() ? "unnamed" : next, "log")); - } - - return true; - } - - return false; -} - -inline void log_print_usage() -{ - printf("log options:\n"); - /* format - printf(" -h, --help show this help message and exit\n");*/ - /* spacing - printf("__-param----------------Description\n");*/ - printf(" --log-test Run simple logging test\n"); - printf(" --log-disable Disable trace logs\n"); - printf(" --log-enable Enable trace logs\n"); - printf(" --log-file Specify a log filename (without extension)\n"); - printf(" Log file will be tagged with unique ID and written as \"..log\"\n"); /* */ -} - -#define log_dump_cmdline(argc, argv) log_dump_cmdline_impl(argc, argv) - -// INTERNAL, DO NOT USE -inline void log_dump_cmdline_impl(int argc, char **argv) -{ - std::stringstream buf; - for (int i = 0; i < argc; ++i) - { - if (std::string(argv[i]).find(' ') != std::string::npos) - { - buf << " \"" << argv[i] <<"\""; - } - else - { - buf << " " << argv[i]; - } - } - LOGLN("Cmd:%s", buf.str().c_str()); -} - -#define log_tostr(var) log_var_to_string_impl(var).c_str() - -inline std::string log_var_to_string_impl(bool var) -{ - return var ? "true" : "false"; -} - -inline std::string log_var_to_string_impl(std::string var) -{ - return var; -} - -inline std::string log_var_to_string_impl(const std::vector & var) -{ - std::stringstream buf; - buf << "[ "; - bool first = true; - for (auto e : var) - { - if (first) - { - first = false; - } - else - { - buf << ", "; - } - buf << std::to_string(e); - } - buf << " ]"; - - return buf.str(); -} - -#define LOG_TOKENS_TOSTR_PRETTY(ctx, tokens) \ - [&tokens, &ctx]() \ - { \ - std::stringstream buf; \ - buf << "[ "; \ - \ - bool first = true; \ - for (const auto &token : tokens) \ - { \ - if (!first) \ - buf << ", "; \ - else \ - first = false; \ - \ - auto detokenized = llama_token_to_piece(ctx, token); \ - \ - detokenized.erase( \ - std::remove_if( \ - detokenized.begin(), \ - detokenized.end(), \ - [](const unsigned char c) { return !std::isprint(c); }), \ - detokenized.end()); \ - \ - buf \ - << "'" << detokenized << "'" \ - << ":" << std::to_string(token); \ - } \ - buf << " ]"; \ - \ - return buf.str(); \ - }() \ - .c_str() - -#ifdef LOG_DISABLE_LOGS - -#undef LOG -#define LOG(...) // dummy stub -#undef LOGLN -#define LOGLN(...) // dummy stub - -#undef LOG_TEE -#define LOG_TEE(...) fprintf(stderr, __VA_ARGS__) // convert to normal fprintf - -#undef LOG_TEELN -#define LOG_TEELN(...) fprintf(stderr, __VA_ARGS__) // convert to normal fprintf - -#undef LOG_DISABLE -#define LOG_DISABLE() // dummy stub - -#undef LOG_ENABLE -#define LOG_ENABLE() // dummy stub - -#undef LOG_ENABLE -#define LOG_ENABLE() // dummy stub - -#undef LOG_SET_TARGET -#define LOG_SET_TARGET(...) // dummy stub - -#undef LOG_DUMP_CMDLINE -#define LOG_DUMP_CMDLINE(...) // dummy stub - -#endif // LOG_DISABLE_LOGS diff --git a/plugins/wasi_nn/thirdparty/ggml/sampling.cpp b/plugins/wasi_nn/thirdparty/ggml/sampling.cpp deleted file mode 100644 index 8ce41945..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/sampling.cpp +++ /dev/null @@ -1,166 +0,0 @@ -#include "sampling.h" - -llama_sampling_context::~llama_sampling_context() { - for (auto & it : sequence_contexts) { - if (it.second.grammar != NULL) { - llama_grammar_free(it.second.grammar); - it.second.grammar = NULL; - } - } -} - -llama_sampling_context llama_sampling_context_init( - const struct gpt_params & params, - llama_grammar * grammar) { - llama_sampling_context result; - - result.params = params.sampling_params; - result.grammar = grammar; - return result; -} - -// Note: Creates the context if it doesn't exist, so this always return something. -llama_sampler_sequence_context & llama_sampling_get_sequence_context( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq) { - const auto it = ctx_sampling.sequence_contexts.find(seq); - if (it != ctx_sampling.sequence_contexts.end()) { - return it->second; - } - llama_sampler_sequence_context new_ctx = { - 2.0f * ctx_sampling.params.mirostat_tau, - ctx_sampling.grammar != NULL ? llama_grammar_copy(ctx_sampling.grammar) : NULL, - }; - return ctx_sampling.sequence_contexts.insert({seq, new_ctx}).first->second; -} - -bool llama_sampling_context_reset( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq) { - const auto it = ctx_sampling.sequence_contexts.find(seq); - if (it == ctx_sampling.sequence_contexts.end()) return false; - if (it->second.grammar != NULL) { - llama_grammar_free(it->second.grammar); - it->second.grammar = NULL; - } - ctx_sampling.sequence_contexts.erase(it); - return true; -} - -llama_token llama_sampling_sample( - struct llama_context * ctx, - struct llama_context * ctx_guidance, - struct llama_sampling_context & ctx_sampling, - const std::vector & last_tokens, - std::vector & candidates, - const int idx, - llama_seq_id seq) { - const int n_ctx = llama_n_ctx(ctx); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - - const llama_sampling_params & params = ctx_sampling.params; - const float temp = params.temp; - const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; - const float top_p = params.top_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; - const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; - const float repeat_penalty = params.repeat_penalty; - const float alpha_presence = params.presence_penalty; - const float alpha_frequency = params.frequency_penalty; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; - const bool penalize_nl = params.penalize_nl; - - llama_token id = 0; - - float * logits = llama_get_logits_ith(ctx, idx); - - // Apply params.logit_bias map - for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { - logits[it->first] += it->second; - } - - candidates.clear(); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } - - llama_token_data_array cur_p = { candidates.data(), candidates.size(), false }; - - if (ctx_guidance) { - llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale); - } - - // apply penalties - if (!last_tokens.empty()) { - const float nl_logit = logits[llama_token_nl(ctx)]; - const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx); - - llama_sample_repetition_penalty(ctx, &cur_p, - last_tokens.data() + last_tokens.size() - last_n_repeat, - last_n_repeat, repeat_penalty); - llama_sample_frequency_and_presence_penalties(ctx, &cur_p, - last_tokens.data() + last_tokens.size() - last_n_repeat, - last_n_repeat, alpha_frequency, alpha_presence); - - if (!penalize_nl) { - for (size_t idx = 0; idx < cur_p.size; idx++) { - if (cur_p.data[idx].id == llama_token_nl(ctx)) { - cur_p.data[idx].logit = nl_logit; - break; - } - } - } - } - - llama_sampler_sequence_context & ctx_seq = llama_sampling_get_sequence_context(ctx_sampling, seq); - - if (ctx_seq.grammar != NULL) { - llama_sample_grammar(ctx, &cur_p, ctx_seq.grammar); - } - - if (temp <= 0) { - // Greedy sampling - id = llama_sample_token_greedy(ctx, &cur_p); - } else { - if (mirostat == 1) { - const int mirostat_m = 100; - llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_seq.mirostat_mu); - } else if (mirostat == 2) { - llama_sample_temp(ctx, &cur_p, temp); - id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_seq.mirostat_mu); - } else { - // Temperature sampling - size_t min_keep = std::max(1, params.n_probs); - llama_sample_top_k (ctx, &cur_p, top_k, min_keep); - llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep); - llama_sample_typical (ctx, &cur_p, typical_p, min_keep); - llama_sample_top_p (ctx, &cur_p, top_p, min_keep); - llama_sample_temp(ctx, &cur_p, temp); - - { - const int n_top = 10; - LOG("top %d candidates:\n", n_top); - - for (int i = 0; i < n_top; i++) { - const llama_token id = cur_p.data[i].id; - (void)id; // To avoid a warning that id is unused when logging is disabled. - LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p); - } - } - - id = llama_sample_token(ctx, &cur_p); - - LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str()); - } - } - - if (ctx_seq.grammar != NULL) { - llama_grammar_accept_token(ctx, ctx_seq.grammar, id); - } - - return id; -} diff --git a/plugins/wasi_nn/thirdparty/ggml/sampling.h b/plugins/wasi_nn/thirdparty/ggml/sampling.h deleted file mode 100644 index 0aab5d03..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/sampling.h +++ /dev/null @@ -1,108 +0,0 @@ -#pragma once - -#include "llama.h" - -#include -#include -#include - -// sampling parameters -typedef struct llama_sampling_params { - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - float temp = 0.80f; // 1.0 = disabled - float repeat_penalty = 1.10f; // 1.0 = disabled - int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float frequency_penalty = 0.00f; // 0.0 = disabled - float presence_penalty = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - - bool penalize_nl = true; // consider newlines as a repeatable token - - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - - // Classifier-Free Guidance - // https://arxiv.org/abs/2306.17806 - std::string cfg_negative_prompt; // string to help guidance - float cfg_scale = 1.f; // How strong is guidance - - std::unordered_map logit_bias; // logit bias for specific tokens - -} llama_sampling_params; - -// per-sequence sampler context -typedef struct llama_sampler_sequence_context { - float mirostat_mu; // mirostat sampler state - llama_grammar * grammar; -} llama_sampler_sequence_context; - -// general sampler context -typedef struct llama_sampling_context { - ~llama_sampling_context(); - - // parameters that will be used for sampling and when creating - // new llama_sampler_sequence_context instances - llama_sampling_params params; - - // map of sequence ids to sampler contexts - std::unordered_map sequence_contexts; - - // when non-NULL, new instances of llama_sampler_sequence_context - // will get a copy of the grammar here - // note: only the pointer is stored here, it is not a copy of - // the grammar and shouldn't be freed - llama_grammar * grammar; -} llama_sampling_context; - -#include "common.h" - -// Create a new sampling context instance. -llama_sampling_context llama_sampling_context_init( - const struct gpt_params & params, - llama_grammar * grammar = NULL); - -// Fetches the sampler context for the specified sequence id (defaults to 0). -// If the context for that sequence id doesn't already exist, it will be created with -// default values based on the parameters in the ctx_sampling argument. -llama_sampler_sequence_context & llama_sampling_get_sequence_context( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq = 0); - -// Reset the sampler context for the supplied sequence id (defaults to 0). -// This is necessary to reuse a sequence id or free memory used by sequences -// that are no longer required. -bool llama_sampling_context_reset( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq = 0); - -// this is a common sampling function used across the examples for convenience -// it can serve as a starting point for implementing your own sampling function -// Note: When using multiple sequences, it is the caller's responsibility to call -// llama_sampling_context_reset when a sequence ends -// -// required: -// - ctx: context to use for sampling -// - ctx_sampling: sampling-specific context -// -// optional: -// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL -// - last_tokens: needed for repetition penalty, ignore if empty -// - idx: sample from llama_get_logits_ith(ctx, idx) -// - seq: sequence id to associate sampler state with -// -// returns: -// - token: sampled token -// - candidates: vector of candidate tokens -// -llama_token llama_sampling_sample( - struct llama_context * ctx, - struct llama_context * ctx_guidance, - struct llama_sampling_context & ctx_sampling, - const std::vector & last_tokens, - std::vector & candidates, - const int idx = 0, - llama_seq_id seq = 0); diff --git a/plugins/wasi_nn/thirdparty/ggml/unicode.h b/plugins/wasi_nn/thirdparty/ggml/unicode.h deleted file mode 100644 index aeca879e..00000000 --- a/plugins/wasi_nn/thirdparty/ggml/unicode.h +++ /dev/null @@ -1,462 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -static const std::vector> digit_ranges = { -{0x30, 0x39}, {0xB2, 0xB3}, {0xB9, 0xB9}, {0x660, 0x669}, {0x6F0, 0x6F9}, {0x7C0, 0x7C9}, {0x966, 0x96F}, {0x9E6, 0x9EF}, {0xA66, 0xA6F}, {0xAE6, 0xAEF}, {0xB66, 0xB6F}, {0xBE6, 0xBEF}, {0xC66, 0xC6F}, -{0xCE6, 0xCEF}, {0xD66, 0xD6F}, {0xDE6, 0xDEF}, {0xE50, 0xE59}, {0xED0, 0xED9}, {0xF20, 0xF29}, {0x1040, 0x1049}, {0x1090, 0x1099}, {0x1369, 0x1371}, {0x17E0, 0x17E9}, {0x1810, 0x1819}, {0x1946, 0x194F}, -{0x19D0, 0x19DA}, {0x1A80, 0x1A89}, {0x1A90, 0x1A99}, {0x1B50, 0x1B59}, {0x1BB0, 0x1BB9}, {0x1C40, 0x1C49}, {0x1C50, 0x1C59}, {0x2070, 0x2070}, {0x2074, 0x2079}, {0x2080, 0x2089}, {0x2460, 0x2468}, -{0x2474, 0x247C}, {0x2488, 0x2490}, {0x24EA, 0x24EA}, {0x24F5, 0x24FD}, {0x24FF, 0x24FF}, {0x2776, 0x277E}, {0x2780, 0x2788}, {0x278A, 0x2792}, {0xA620, 0xA629}, {0xA8D0, 0xA8D9}, {0xA900, 0xA909}, -{0xA9D0, 0xA9D9}, {0xA9F0, 0xA9F9}, {0xAA50, 0xAA59}, {0xABF0, 0xABF9}, {0xFF10, 0xFF19}, {0x104A0, 0x104A9}, {0x10A40, 0x10A43}, {0x10D30, 0x10D39}, {0x10E60, 0x10E68}, {0x11052, 0x1105A}, -{0x11066, 0x1106F}, {0x110F0, 0x110F9}, {0x11136, 0x1113F}, {0x111D0, 0x111D9}, {0x112F0, 0x112F9}, {0x11450, 0x11459}, {0x114D0, 0x114D9}, {0x11650, 0x11659}, {0x116C0, 0x116C9}, {0x11730, 0x11739}, -{0x118E0, 0x118E9}, {0x11950, 0x11959}, {0x11C50, 0x11C59}, {0x11D50, 0x11D59}, {0x11DA0, 0x11DA9}, {0x16A60, 0x16A69}, {0x16B50, 0x16B59}, {0x1D7CE, 0x1D7FF}, {0x1E140, 0x1E149}, {0x1E2F0, 0x1E2F9}, -{0x1E950, 0x1E959}, {0x1F100, 0x1F10A}, {0x1FBF0, 0x1FBF9}, -}; - -static const std::vector> letter_ranges = { -{0x41, 0x5A}, {0x61, 0x7A}, {0xAA, 0xAA}, {0xB5, 0xB5}, {0xBA, 0xBA}, {0xC0, 0xD6}, {0xD8, 0xF6}, {0xF8, 0x2C1}, {0x2C6, 0x2D1}, {0x2E0, 0x2E4}, {0x2EC, 0x2EC}, {0x2EE, 0x2EE}, {0x370, 0x374}, -{0x376, 0x377}, {0x37A, 0x37D}, {0x37F, 0x37F}, {0x386, 0x386}, {0x388, 0x38A}, {0x38C, 0x38C}, {0x38E, 0x3A1}, {0x3A3, 0x3F5}, {0x3F7, 0x481}, {0x48A, 0x52F}, {0x531, 0x556}, {0x559, 0x559}, -{0x560, 0x588}, {0x5D0, 0x5EA}, {0x5EF, 0x5F2}, {0x620, 0x64A}, {0x66E, 0x66F}, {0x671, 0x6D3}, {0x6D5, 0x6D5}, {0x6E5, 0x6E6}, {0x6EE, 0x6EF}, {0x6FA, 0x6FC}, {0x6FF, 0x6FF}, {0x710, 0x710}, -{0x712, 0x72F}, {0x74D, 0x7A5}, {0x7B1, 0x7B1}, {0x7CA, 0x7EA}, {0x7F4, 0x7F5}, {0x7FA, 0x7FA}, {0x800, 0x815}, {0x81A, 0x81A}, {0x824, 0x824}, {0x828, 0x828}, {0x840, 0x858}, {0x860, 0x86A}, -{0x8A0, 0x8B4}, {0x8B6, 0x8C7}, {0x904, 0x939}, {0x93D, 0x93D}, {0x950, 0x950}, {0x958, 0x961}, {0x971, 0x980}, {0x985, 0x98C}, {0x98F, 0x990}, {0x993, 0x9A8}, {0x9AA, 0x9B0}, {0x9B2, 0x9B2}, -{0x9B6, 0x9B9}, {0x9BD, 0x9BD}, {0x9CE, 0x9CE}, {0x9DC, 0x9DD}, {0x9DF, 0x9E1}, {0x9F0, 0x9F1}, {0x9FC, 0x9FC}, {0xA05, 0xA0A}, {0xA0F, 0xA10}, {0xA13, 0xA28}, {0xA2A, 0xA30}, {0xA32, 0xA33}, -{0xA35, 0xA36}, {0xA38, 0xA39}, {0xA59, 0xA5C}, {0xA5E, 0xA5E}, {0xA72, 0xA74}, {0xA85, 0xA8D}, {0xA8F, 0xA91}, {0xA93, 0xAA8}, {0xAAA, 0xAB0}, {0xAB2, 0xAB3}, {0xAB5, 0xAB9}, {0xABD, 0xABD}, -{0xAD0, 0xAD0}, {0xAE0, 0xAE1}, {0xAF9, 0xAF9}, {0xB05, 0xB0C}, {0xB0F, 0xB10}, {0xB13, 0xB28}, {0xB2A, 0xB30}, {0xB32, 0xB33}, {0xB35, 0xB39}, {0xB3D, 0xB3D}, {0xB5C, 0xB5D}, {0xB5F, 0xB61}, -{0xB71, 0xB71}, {0xB83, 0xB83}, {0xB85, 0xB8A}, {0xB8E, 0xB90}, {0xB92, 0xB95}, {0xB99, 0xB9A}, {0xB9C, 0xB9C}, {0xB9E, 0xB9F}, {0xBA3, 0xBA4}, {0xBA8, 0xBAA}, {0xBAE, 0xBB9}, {0xBD0, 0xBD0}, -{0xC05, 0xC0C}, {0xC0E, 0xC10}, {0xC12, 0xC28}, {0xC2A, 0xC39}, {0xC3D, 0xC3D}, {0xC58, 0xC5A}, {0xC60, 0xC61}, {0xC80, 0xC80}, {0xC85, 0xC8C}, {0xC8E, 0xC90}, {0xC92, 0xCA8}, {0xCAA, 0xCB3}, -{0xCB5, 0xCB9}, {0xCBD, 0xCBD}, {0xCDE, 0xCDE}, {0xCE0, 0xCE1}, {0xCF1, 0xCF2}, {0xD04, 0xD0C}, {0xD0E, 0xD10}, {0xD12, 0xD3A}, {0xD3D, 0xD3D}, {0xD4E, 0xD4E}, {0xD54, 0xD56}, {0xD5F, 0xD61}, -{0xD7A, 0xD7F}, {0xD85, 0xD96}, {0xD9A, 0xDB1}, {0xDB3, 0xDBB}, {0xDBD, 0xDBD}, {0xDC0, 0xDC6}, {0xE01, 0xE30}, {0xE32, 0xE33}, {0xE40, 0xE46}, {0xE81, 0xE82}, {0xE84, 0xE84}, {0xE86, 0xE8A}, -{0xE8C, 0xEA3}, {0xEA5, 0xEA5}, {0xEA7, 0xEB0}, {0xEB2, 0xEB3}, {0xEBD, 0xEBD}, {0xEC0, 0xEC4}, {0xEC6, 0xEC6}, {0xEDC, 0xEDF}, {0xF00, 0xF00}, {0xF40, 0xF47}, {0xF49, 0xF6C}, {0xF88, 0xF8C}, -{0x1000, 0x102A}, {0x103F, 0x103F}, {0x1050, 0x1055}, {0x105A, 0x105D}, {0x1061, 0x1061}, {0x1065, 0x1066}, {0x106E, 0x1070}, {0x1075, 0x1081}, {0x108E, 0x108E}, {0x10A0, 0x10C5}, {0x10C7, 0x10C7}, -{0x10CD, 0x10CD}, {0x10D0, 0x10FA}, {0x10FC, 0x1248}, {0x124A, 0x124D}, {0x1250, 0x1256}, {0x1258, 0x1258}, {0x125A, 0x125D}, {0x1260, 0x1288}, {0x128A, 0x128D}, {0x1290, 0x12B0}, {0x12B2, 0x12B5}, -{0x12B8, 0x12BE}, {0x12C0, 0x12C0}, {0x12C2, 0x12C5}, {0x12C8, 0x12D6}, {0x12D8, 0x1310}, {0x1312, 0x1315}, {0x1318, 0x135A}, {0x1380, 0x138F}, {0x13A0, 0x13F5}, {0x13F8, 0x13FD}, {0x1401, 0x166C}, -{0x166F, 0x167F}, {0x1681, 0x169A}, {0x16A0, 0x16EA}, {0x16F1, 0x16F8}, {0x1700, 0x170C}, {0x170E, 0x1711}, {0x1720, 0x1731}, {0x1740, 0x1751}, {0x1760, 0x176C}, {0x176E, 0x1770}, {0x1780, 0x17B3}, -{0x17D7, 0x17D7}, {0x17DC, 0x17DC}, {0x1820, 0x1878}, {0x1880, 0x1884}, {0x1887, 0x18A8}, {0x18AA, 0x18AA}, {0x18B0, 0x18F5}, {0x1900, 0x191E}, {0x1950, 0x196D}, {0x1970, 0x1974}, {0x1980, 0x19AB}, -{0x19B0, 0x19C9}, {0x1A00, 0x1A16}, {0x1A20, 0x1A54}, {0x1AA7, 0x1AA7}, {0x1B05, 0x1B33}, {0x1B45, 0x1B4B}, {0x1B83, 0x1BA0}, {0x1BAE, 0x1BAF}, {0x1BBA, 0x1BE5}, {0x1C00, 0x1C23}, {0x1C4D, 0x1C4F}, -{0x1C5A, 0x1C7D}, {0x1C80, 0x1C88}, {0x1C90, 0x1CBA}, {0x1CBD, 0x1CBF}, {0x1CE9, 0x1CEC}, {0x1CEE, 0x1CF3}, {0x1CF5, 0x1CF6}, {0x1CFA, 0x1CFA}, {0x1D00, 0x1DBF}, {0x1E00, 0x1F15}, {0x1F18, 0x1F1D}, -{0x1F20, 0x1F45}, {0x1F48, 0x1F4D}, {0x1F50, 0x1F57}, {0x1F59, 0x1F59}, {0x1F5B, 0x1F5B}, {0x1F5D, 0x1F5D}, {0x1F5F, 0x1F7D}, {0x1F80, 0x1FB4}, {0x1FB6, 0x1FBC}, {0x1FBE, 0x1FBE}, {0x1FC2, 0x1FC4}, -{0x1FC6, 0x1FCC}, {0x1FD0, 0x1FD3}, {0x1FD6, 0x1FDB}, {0x1FE0, 0x1FEC}, {0x1FF2, 0x1FF4}, {0x1FF6, 0x1FFC}, {0x2071, 0x2071}, {0x207F, 0x207F}, {0x2090, 0x209C}, {0x2102, 0x2102}, {0x2107, 0x2107}, -{0x210A, 0x2113}, {0x2115, 0x2115}, {0x2119, 0x211D}, {0x2124, 0x2124}, {0x2126, 0x2126}, {0x2128, 0x2128}, {0x212A, 0x212D}, {0x212F, 0x2139}, {0x213C, 0x213F}, {0x2145, 0x2149}, {0x214E, 0x214E}, -{0x2183, 0x2184}, {0x2C00, 0x2C2E}, {0x2C30, 0x2C5E}, {0x2C60, 0x2CE4}, {0x2CEB, 0x2CEE}, {0x2CF2, 0x2CF3}, {0x2D00, 0x2D25}, {0x2D27, 0x2D27}, {0x2D2D, 0x2D2D}, {0x2D30, 0x2D67}, {0x2D6F, 0x2D6F}, -{0x2D80, 0x2D96}, {0x2DA0, 0x2DA6}, {0x2DA8, 0x2DAE}, {0x2DB0, 0x2DB6}, {0x2DB8, 0x2DBE}, {0x2DC0, 0x2DC6}, {0x2DC8, 0x2DCE}, {0x2DD0, 0x2DD6}, {0x2DD8, 0x2DDE}, {0x2E2F, 0x2E2F}, {0x3005, 0x3006}, -{0x3031, 0x3035}, {0x303B, 0x303C}, {0x3041, 0x3096}, {0x309D, 0x309F}, {0x30A1, 0x30FA}, {0x30FC, 0x30FF}, {0x3105, 0x312F}, {0x3131, 0x318E}, {0x31A0, 0x31BF}, {0x31F0, 0x31FF}, {0x3400, 0x4DBF}, -{0x4E00, 0x9FFC}, {0xA000, 0xA48C}, {0xA4D0, 0xA4FD}, {0xA500, 0xA60C}, {0xA610, 0xA61F}, {0xA62A, 0xA62B}, {0xA640, 0xA66E}, {0xA67F, 0xA69D}, {0xA6A0, 0xA6E5}, {0xA717, 0xA71F}, {0xA722, 0xA788}, -{0xA78B, 0xA7BF}, {0xA7C2, 0xA7CA}, {0xA7F5, 0xA801}, {0xA803, 0xA805}, {0xA807, 0xA80A}, {0xA80C, 0xA822}, {0xA840, 0xA873}, {0xA882, 0xA8B3}, {0xA8F2, 0xA8F7}, {0xA8FB, 0xA8FB}, {0xA8FD, 0xA8FE}, -{0xA90A, 0xA925}, {0xA930, 0xA946}, {0xA960, 0xA97C}, {0xA984, 0xA9B2}, {0xA9CF, 0xA9CF}, {0xA9E0, 0xA9E4}, {0xA9E6, 0xA9EF}, {0xA9FA, 0xA9FE}, {0xAA00, 0xAA28}, {0xAA40, 0xAA42}, {0xAA44, 0xAA4B}, -{0xAA60, 0xAA76}, {0xAA7A, 0xAA7A}, {0xAA7E, 0xAAAF}, {0xAAB1, 0xAAB1}, {0xAAB5, 0xAAB6}, {0xAAB9, 0xAABD}, {0xAAC0, 0xAAC0}, {0xAAC2, 0xAAC2}, {0xAADB, 0xAADD}, {0xAAE0, 0xAAEA}, {0xAAF2, 0xAAF4}, -{0xAB01, 0xAB06}, {0xAB09, 0xAB0E}, {0xAB11, 0xAB16}, {0xAB20, 0xAB26}, {0xAB28, 0xAB2E}, {0xAB30, 0xAB5A}, {0xAB5C, 0xAB69}, {0xAB70, 0xABE2}, {0xAC00, 0xD7A3}, {0xD7B0, 0xD7C6}, {0xD7CB, 0xD7FB}, -{0xF900, 0xFA6D}, {0xFA70, 0xFAD9}, {0xFB00, 0xFB06}, {0xFB13, 0xFB17}, {0xFB1D, 0xFB1D}, {0xFB1F, 0xFB28}, {0xFB2A, 0xFB36}, {0xFB38, 0xFB3C}, {0xFB3E, 0xFB3E}, {0xFB40, 0xFB41}, {0xFB43, 0xFB44}, -{0xFB46, 0xFBB1}, {0xFBD3, 0xFD3D}, {0xFD50, 0xFD8F}, {0xFD92, 0xFDC7}, {0xFDF0, 0xFDFB}, {0xFE70, 0xFE74}, {0xFE76, 0xFEFC}, {0xFF21, 0xFF3A}, {0xFF41, 0xFF5A}, {0xFF66, 0xFFBE}, {0xFFC2, 0xFFC7}, -{0xFFCA, 0xFFCF}, {0xFFD2, 0xFFD7}, {0xFFDA, 0xFFDC}, {0x10000, 0x1000B}, {0x1000D, 0x10026}, {0x10028, 0x1003A}, {0x1003C, 0x1003D}, {0x1003F, 0x1004D}, {0x10050, 0x1005D}, {0x10080, 0x100FA}, -{0x10280, 0x1029C}, {0x102A0, 0x102D0}, {0x10300, 0x1031F}, {0x1032D, 0x10340}, {0x10342, 0x10349}, {0x10350, 0x10375}, {0x10380, 0x1039D}, {0x103A0, 0x103C3}, {0x103C8, 0x103CF}, {0x10400, 0x1049D}, -{0x104B0, 0x104D3}, {0x104D8, 0x104FB}, {0x10500, 0x10527}, {0x10530, 0x10563}, {0x10600, 0x10736}, {0x10740, 0x10755}, {0x10760, 0x10767}, {0x10800, 0x10805}, {0x10808, 0x10808}, {0x1080A, 0x10835}, -{0x10837, 0x10838}, {0x1083C, 0x1083C}, {0x1083F, 0x10855}, {0x10860, 0x10876}, {0x10880, 0x1089E}, {0x108E0, 0x108F2}, {0x108F4, 0x108F5}, {0x10900, 0x10915}, {0x10920, 0x10939}, {0x10980, 0x109B7}, -{0x109BE, 0x109BF}, {0x10A00, 0x10A00}, {0x10A10, 0x10A13}, {0x10A15, 0x10A17}, {0x10A19, 0x10A35}, {0x10A60, 0x10A7C}, {0x10A80, 0x10A9C}, {0x10AC0, 0x10AC7}, {0x10AC9, 0x10AE4}, {0x10B00, 0x10B35}, -{0x10B40, 0x10B55}, {0x10B60, 0x10B72}, {0x10B80, 0x10B91}, {0x10C00, 0x10C48}, {0x10C80, 0x10CB2}, {0x10CC0, 0x10CF2}, {0x10D00, 0x10D23}, {0x10E80, 0x10EA9}, {0x10EB0, 0x10EB1}, {0x10F00, 0x10F1C}, -{0x10F27, 0x10F27}, {0x10F30, 0x10F45}, {0x10FB0, 0x10FC4}, {0x10FE0, 0x10FF6}, {0x11003, 0x11037}, {0x11083, 0x110AF}, {0x110D0, 0x110E8}, {0x11103, 0x11126}, {0x11144, 0x11144}, {0x11147, 0x11147}, -{0x11150, 0x11172}, {0x11176, 0x11176}, {0x11183, 0x111B2}, {0x111C1, 0x111C4}, {0x111DA, 0x111DA}, {0x111DC, 0x111DC}, {0x11200, 0x11211}, {0x11213, 0x1122B}, {0x11280, 0x11286}, {0x11288, 0x11288}, -{0x1128A, 0x1128D}, {0x1128F, 0x1129D}, {0x1129F, 0x112A8}, {0x112B0, 0x112DE}, {0x11305, 0x1130C}, {0x1130F, 0x11310}, {0x11313, 0x11328}, {0x1132A, 0x11330}, {0x11332, 0x11333}, {0x11335, 0x11339}, -{0x1133D, 0x1133D}, {0x11350, 0x11350}, {0x1135D, 0x11361}, {0x11400, 0x11434}, {0x11447, 0x1144A}, {0x1145F, 0x11461}, {0x11480, 0x114AF}, {0x114C4, 0x114C5}, {0x114C7, 0x114C7}, {0x11580, 0x115AE}, -{0x115D8, 0x115DB}, {0x11600, 0x1162F}, {0x11644, 0x11644}, {0x11680, 0x116AA}, {0x116B8, 0x116B8}, {0x11700, 0x1171A}, {0x11800, 0x1182B}, {0x118A0, 0x118DF}, {0x118FF, 0x11906}, {0x11909, 0x11909}, -{0x1190C, 0x11913}, {0x11915, 0x11916}, {0x11918, 0x1192F}, {0x1193F, 0x1193F}, {0x11941, 0x11941}, {0x119A0, 0x119A7}, {0x119AA, 0x119D0}, {0x119E1, 0x119E1}, {0x119E3, 0x119E3}, {0x11A00, 0x11A00}, -{0x11A0B, 0x11A32}, {0x11A3A, 0x11A3A}, {0x11A50, 0x11A50}, {0x11A5C, 0x11A89}, {0x11A9D, 0x11A9D}, {0x11AC0, 0x11AF8}, {0x11C00, 0x11C08}, {0x11C0A, 0x11C2E}, {0x11C40, 0x11C40}, {0x11C72, 0x11C8F}, -{0x11D00, 0x11D06}, {0x11D08, 0x11D09}, {0x11D0B, 0x11D30}, {0x11D46, 0x11D46}, {0x11D60, 0x11D65}, {0x11D67, 0x11D68}, {0x11D6A, 0x11D89}, {0x11D98, 0x11D98}, {0x11EE0, 0x11EF2}, {0x11FB0, 0x11FB0}, -{0x12000, 0x12399}, {0x12480, 0x12543}, {0x13000, 0x1342E}, {0x14400, 0x14646}, {0x16800, 0x16A38}, {0x16A40, 0x16A5E}, {0x16AD0, 0x16AED}, {0x16B00, 0x16B2F}, {0x16B40, 0x16B43}, {0x16B63, 0x16B77}, -{0x16B7D, 0x16B8F}, {0x16E40, 0x16E7F}, {0x16F00, 0x16F4A}, {0x16F50, 0x16F50}, {0x16F93, 0x16F9F}, {0x16FE0, 0x16FE1}, {0x16FE3, 0x16FE3}, {0x17000, 0x187F7}, {0x18800, 0x18CD5}, {0x18D00, 0x18D08}, -{0x1B000, 0x1B11E}, {0x1B150, 0x1B152}, {0x1B164, 0x1B167}, {0x1B170, 0x1B2FB}, {0x1BC00, 0x1BC6A}, {0x1BC70, 0x1BC7C}, {0x1BC80, 0x1BC88}, {0x1BC90, 0x1BC99}, {0x1D400, 0x1D454}, {0x1D456, 0x1D49C}, -{0x1D49E, 0x1D49F}, {0x1D4A2, 0x1D4A2}, {0x1D4A5, 0x1D4A6}, {0x1D4A9, 0x1D4AC}, {0x1D4AE, 0x1D4B9}, {0x1D4BB, 0x1D4BB}, {0x1D4BD, 0x1D4C3}, {0x1D4C5, 0x1D505}, {0x1D507, 0x1D50A}, {0x1D50D, 0x1D514}, -{0x1D516, 0x1D51C}, {0x1D51E, 0x1D539}, {0x1D53B, 0x1D53E}, {0x1D540, 0x1D544}, {0x1D546, 0x1D546}, {0x1D54A, 0x1D550}, {0x1D552, 0x1D6A5}, {0x1D6A8, 0x1D6C0}, {0x1D6C2, 0x1D6DA}, {0x1D6DC, 0x1D6FA}, -{0x1D6FC, 0x1D714}, {0x1D716, 0x1D734}, {0x1D736, 0x1D74E}, {0x1D750, 0x1D76E}, {0x1D770, 0x1D788}, {0x1D78A, 0x1D7A8}, {0x1D7AA, 0x1D7C2}, {0x1D7C4, 0x1D7CB}, {0x1E100, 0x1E12C}, {0x1E137, 0x1E13D}, -{0x1E14E, 0x1E14E}, {0x1E2C0, 0x1E2EB}, {0x1E800, 0x1E8C4}, {0x1E900, 0x1E943}, {0x1E94B, 0x1E94B}, {0x1EE00, 0x1EE03}, {0x1EE05, 0x1EE1F}, {0x1EE21, 0x1EE22}, {0x1EE24, 0x1EE24}, {0x1EE27, 0x1EE27}, -{0x1EE29, 0x1EE32}, {0x1EE34, 0x1EE37}, {0x1EE39, 0x1EE39}, {0x1EE3B, 0x1EE3B}, {0x1EE42, 0x1EE42}, {0x1EE47, 0x1EE47}, {0x1EE49, 0x1EE49}, {0x1EE4B, 0x1EE4B}, {0x1EE4D, 0x1EE4F}, {0x1EE51, 0x1EE52}, -{0x1EE54, 0x1EE54}, {0x1EE57, 0x1EE57}, {0x1EE59, 0x1EE59}, {0x1EE5B, 0x1EE5B}, {0x1EE5D, 0x1EE5D}, {0x1EE5F, 0x1EE5F}, {0x1EE61, 0x1EE62}, {0x1EE64, 0x1EE64}, {0x1EE67, 0x1EE6A}, {0x1EE6C, 0x1EE72}, -{0x1EE74, 0x1EE77}, {0x1EE79, 0x1EE7C}, {0x1EE7E, 0x1EE7E}, {0x1EE80, 0x1EE89}, {0x1EE8B, 0x1EE9B}, {0x1EEA1, 0x1EEA3}, {0x1EEA5, 0x1EEA9}, {0x1EEAB, 0x1EEBB}, {0x20000, 0x2A6DD}, {0x2A700, 0x2B734}, -{0x2B740, 0x2B81D}, {0x2B820, 0x2CEA1}, {0x2CEB0, 0x2EBE0}, {0x2F800, 0x2FA1D}, {0x30000, 0x3134A}, -}; - -static const std::vector> whitespace_ranges = { -{0x9, 0xD}, {0x1C, 0x20}, {0x85, 0x85}, {0xA0, 0xA0}, {0x1680, 0x1680}, {0x2000, 0x200A}, {0x2028, 0x2029}, {0x202F, 0x202F}, {0x205F, 0x205F}, {0x3000, 0x3000}, -}; - -static const std::vector> accent_mark_ranges = { -{0x300, 0x36F}, {0x483, 0x489}, {0x591, 0x5BD}, {0x5BF, 0x5BF}, {0x5C1, 0x5C2}, {0x5C4, 0x5C5}, {0x5C7, 0x5C7}, {0x610, 0x61A}, {0x64B, 0x65F}, {0x670, 0x670}, {0x6D6, 0x6DC}, {0x6DF, 0x6E4}, -{0x6E7, 0x6E8}, {0x6EA, 0x6ED}, {0x711, 0x711}, {0x730, 0x74A}, {0x7A6, 0x7B0}, {0x7EB, 0x7F3}, {0x7FD, 0x7FD}, {0x816, 0x819}, {0x81B, 0x823}, {0x825, 0x827}, {0x829, 0x82D}, {0x859, 0x85B}, -{0x8D3, 0x8E1}, {0x8E3, 0x903}, {0x93A, 0x93C}, {0x93E, 0x94F}, {0x951, 0x957}, {0x962, 0x963}, {0x981, 0x983}, {0x9BC, 0x9BC}, {0x9BE, 0x9C4}, {0x9C7, 0x9C8}, {0x9CB, 0x9CD}, {0x9D7, 0x9D7}, -{0x9E2, 0x9E3}, {0x9FE, 0x9FE}, {0xA01, 0xA03}, {0xA3C, 0xA3C}, {0xA3E, 0xA42}, {0xA47, 0xA48}, {0xA4B, 0xA4D}, {0xA51, 0xA51}, {0xA70, 0xA71}, {0xA75, 0xA75}, {0xA81, 0xA83}, {0xABC, 0xABC}, -{0xABE, 0xAC5}, {0xAC7, 0xAC9}, {0xACB, 0xACD}, {0xAE2, 0xAE3}, {0xAFA, 0xAFF}, {0xB01, 0xB03}, {0xB3C, 0xB3C}, {0xB3E, 0xB44}, {0xB47, 0xB48}, {0xB4B, 0xB4D}, {0xB55, 0xB57}, {0xB62, 0xB63}, -{0xB82, 0xB82}, {0xBBE, 0xBC2}, {0xBC6, 0xBC8}, {0xBCA, 0xBCD}, {0xBD7, 0xBD7}, {0xC00, 0xC04}, {0xC3E, 0xC44}, {0xC46, 0xC48}, {0xC4A, 0xC4D}, {0xC55, 0xC56}, {0xC62, 0xC63}, {0xC81, 0xC83}, -{0xCBC, 0xCBC}, {0xCBE, 0xCC4}, {0xCC6, 0xCC8}, {0xCCA, 0xCCD}, {0xCD5, 0xCD6}, {0xCE2, 0xCE3}, {0xD00, 0xD03}, {0xD3B, 0xD3C}, {0xD3E, 0xD44}, {0xD46, 0xD48}, {0xD4A, 0xD4D}, {0xD57, 0xD57}, -{0xD62, 0xD63}, {0xD81, 0xD83}, {0xDCA, 0xDCA}, {0xDCF, 0xDD4}, {0xDD6, 0xDD6}, {0xDD8, 0xDDF}, {0xDF2, 0xDF3}, {0xE31, 0xE31}, {0xE34, 0xE3A}, {0xE47, 0xE4E}, {0xEB1, 0xEB1}, {0xEB4, 0xEBC}, -{0xEC8, 0xECD}, {0xF18, 0xF19}, {0xF35, 0xF35}, {0xF37, 0xF37}, {0xF39, 0xF39}, {0xF3E, 0xF3F}, {0xF71, 0xF84}, {0xF86, 0xF87}, {0xF8D, 0xF97}, {0xF99, 0xFBC}, {0xFC6, 0xFC6}, {0x102B, 0x103E}, -{0x1056, 0x1059}, {0x105E, 0x1060}, {0x1062, 0x1064}, {0x1067, 0x106D}, {0x1071, 0x1074}, {0x1082, 0x108D}, {0x108F, 0x108F}, {0x109A, 0x109D}, {0x135D, 0x135F}, {0x1712, 0x1714}, {0x1732, 0x1734}, -{0x1752, 0x1753}, {0x1772, 0x1773}, {0x17B4, 0x17D3}, {0x17DD, 0x17DD}, {0x180B, 0x180D}, {0x1885, 0x1886}, {0x18A9, 0x18A9}, {0x1920, 0x192B}, {0x1930, 0x193B}, {0x1A17, 0x1A1B}, {0x1A55, 0x1A5E}, -{0x1A60, 0x1A7C}, {0x1A7F, 0x1A7F}, {0x1AB0, 0x1AC0}, {0x1B00, 0x1B04}, {0x1B34, 0x1B44}, {0x1B6B, 0x1B73}, {0x1B80, 0x1B82}, {0x1BA1, 0x1BAD}, {0x1BE6, 0x1BF3}, {0x1C24, 0x1C37}, {0x1CD0, 0x1CD2}, -{0x1CD4, 0x1CE8}, {0x1CED, 0x1CED}, {0x1CF4, 0x1CF4}, {0x1CF7, 0x1CF9}, {0x1DC0, 0x1DF9}, {0x1DFB, 0x1DFF}, {0x20D0, 0x20F0}, {0x2CEF, 0x2CF1}, {0x2D7F, 0x2D7F}, {0x2DE0, 0x2DFF}, {0x302A, 0x302F}, -{0x3099, 0x309A}, {0xA66F, 0xA672}, {0xA674, 0xA67D}, {0xA69E, 0xA69F}, {0xA6F0, 0xA6F1}, {0xA802, 0xA802}, {0xA806, 0xA806}, {0xA80B, 0xA80B}, {0xA823, 0xA827}, {0xA82C, 0xA82C}, {0xA880, 0xA881}, -{0xA8B4, 0xA8C5}, {0xA8E0, 0xA8F1}, {0xA8FF, 0xA8FF}, {0xA926, 0xA92D}, {0xA947, 0xA953}, {0xA980, 0xA983}, {0xA9B3, 0xA9C0}, {0xA9E5, 0xA9E5}, {0xAA29, 0xAA36}, {0xAA43, 0xAA43}, {0xAA4C, 0xAA4D}, -{0xAA7B, 0xAA7D}, {0xAAB0, 0xAAB0}, {0xAAB2, 0xAAB4}, {0xAAB7, 0xAAB8}, {0xAABE, 0xAABF}, {0xAAC1, 0xAAC1}, {0xAAEB, 0xAAEF}, {0xAAF5, 0xAAF6}, {0xABE3, 0xABEA}, {0xABEC, 0xABED}, {0xFB1E, 0xFB1E}, -{0xFE00, 0xFE0F}, {0xFE20, 0xFE2F}, {0x101FD, 0x101FD}, {0x102E0, 0x102E0}, {0x10376, 0x1037A}, {0x10A01, 0x10A03}, {0x10A05, 0x10A06}, {0x10A0C, 0x10A0F}, {0x10A38, 0x10A3A}, {0x10A3F, 0x10A3F}, -{0x10AE5, 0x10AE6}, {0x10D24, 0x10D27}, {0x10EAB, 0x10EAC}, {0x10F46, 0x10F50}, {0x11000, 0x11002}, {0x11038, 0x11046}, {0x1107F, 0x11082}, {0x110B0, 0x110BA}, {0x11100, 0x11102}, {0x11127, 0x11134}, -{0x11145, 0x11146}, {0x11173, 0x11173}, {0x11180, 0x11182}, {0x111B3, 0x111C0}, {0x111C9, 0x111CC}, {0x111CE, 0x111CF}, {0x1122C, 0x11237}, {0x1123E, 0x1123E}, {0x112DF, 0x112EA}, {0x11300, 0x11303}, -{0x1133B, 0x1133C}, {0x1133E, 0x11344}, {0x11347, 0x11348}, {0x1134B, 0x1134D}, {0x11357, 0x11357}, {0x11362, 0x11363}, {0x11366, 0x1136C}, {0x11370, 0x11374}, {0x11435, 0x11446}, {0x1145E, 0x1145E}, -{0x114B0, 0x114C3}, {0x115AF, 0x115B5}, {0x115B8, 0x115C0}, {0x115DC, 0x115DD}, {0x11630, 0x11640}, {0x116AB, 0x116B7}, {0x1171D, 0x1172B}, {0x1182C, 0x1183A}, {0x11930, 0x11935}, {0x11937, 0x11938}, -{0x1193B, 0x1193E}, {0x11940, 0x11940}, {0x11942, 0x11943}, {0x119D1, 0x119D7}, {0x119DA, 0x119E0}, {0x119E4, 0x119E4}, {0x11A01, 0x11A0A}, {0x11A33, 0x11A39}, {0x11A3B, 0x11A3E}, {0x11A47, 0x11A47}, -{0x11A51, 0x11A5B}, {0x11A8A, 0x11A99}, {0x11C2F, 0x11C36}, {0x11C38, 0x11C3F}, {0x11C92, 0x11CA7}, {0x11CA9, 0x11CB6}, {0x11D31, 0x11D36}, {0x11D3A, 0x11D3A}, {0x11D3C, 0x11D3D}, {0x11D3F, 0x11D45}, -{0x11D47, 0x11D47}, {0x11D8A, 0x11D8E}, {0x11D90, 0x11D91}, {0x11D93, 0x11D97}, {0x11EF3, 0x11EF6}, {0x16AF0, 0x16AF4}, {0x16B30, 0x16B36}, {0x16F4F, 0x16F4F}, {0x16F51, 0x16F87}, {0x16F8F, 0x16F92}, -{0x16FE4, 0x16FE4}, {0x16FF0, 0x16FF1}, {0x1BC9D, 0x1BC9E}, {0x1D165, 0x1D169}, {0x1D16D, 0x1D172}, {0x1D17B, 0x1D182}, {0x1D185, 0x1D18B}, {0x1D1AA, 0x1D1AD}, {0x1D242, 0x1D244}, {0x1DA00, 0x1DA36}, -{0x1DA3B, 0x1DA6C}, {0x1DA75, 0x1DA75}, {0x1DA84, 0x1DA84}, {0x1DA9B, 0x1DA9F}, {0x1DAA1, 0x1DAAF}, {0x1E000, 0x1E006}, {0x1E008, 0x1E018}, {0x1E01B, 0x1E021}, {0x1E023, 0x1E024}, {0x1E026, 0x1E02A}, -{0x1E130, 0x1E136}, {0x1E2EC, 0x1E2EF}, {0x1E8D0, 0x1E8D6}, {0x1E944, 0x1E94A}, {0xE0100, 0xE01EF}, -}; - -static const std::vector> punctuation_ranges = { -{0x21, 0x23}, {0x25, 0x2A}, {0x2C, 0x2F}, {0x3A, 0x3B}, {0x3F, 0x40}, {0x5B, 0x5D}, {0x5F, 0x5F}, {0x7B, 0x7B}, {0x7D, 0x7D}, {0xA1, 0xA1}, {0xA7, 0xA7}, {0xAB, 0xAB}, {0xB6, 0xB7}, {0xBB, 0xBB}, -{0xBF, 0xBF}, {0x37E, 0x37E}, {0x387, 0x387}, {0x55A, 0x55F}, {0x589, 0x58A}, {0x5BE, 0x5BE}, {0x5C0, 0x5C0}, {0x5C3, 0x5C3}, {0x5C6, 0x5C6}, {0x5F3, 0x5F4}, {0x609, 0x60A}, {0x60C, 0x60D}, -{0x61B, 0x61B}, {0x61E, 0x61F}, {0x66A, 0x66D}, {0x6D4, 0x6D4}, {0x700, 0x70D}, {0x7F7, 0x7F9}, {0x830, 0x83E}, {0x85E, 0x85E}, {0x964, 0x965}, {0x970, 0x970}, {0x9FD, 0x9FD}, {0xA76, 0xA76}, -{0xAF0, 0xAF0}, {0xC77, 0xC77}, {0xC84, 0xC84}, {0xDF4, 0xDF4}, {0xE4F, 0xE4F}, {0xE5A, 0xE5B}, {0xF04, 0xF12}, {0xF14, 0xF14}, {0xF3A, 0xF3D}, {0xF85, 0xF85}, {0xFD0, 0xFD4}, {0xFD9, 0xFDA}, -{0x104A, 0x104F}, {0x10FB, 0x10FB}, {0x1360, 0x1368}, {0x1400, 0x1400}, {0x166E, 0x166E}, {0x169B, 0x169C}, {0x16EB, 0x16ED}, {0x1735, 0x1736}, {0x17D4, 0x17D6}, {0x17D8, 0x17DA}, {0x1800, 0x180A}, -{0x1944, 0x1945}, {0x1A1E, 0x1A1F}, {0x1AA0, 0x1AA6}, {0x1AA8, 0x1AAD}, {0x1B5A, 0x1B60}, {0x1BFC, 0x1BFF}, {0x1C3B, 0x1C3F}, {0x1C7E, 0x1C7F}, {0x1CC0, 0x1CC7}, {0x1CD3, 0x1CD3}, {0x2010, 0x2027}, -{0x2030, 0x2043}, {0x2045, 0x2051}, {0x2053, 0x205E}, {0x207D, 0x207E}, {0x208D, 0x208E}, {0x2308, 0x230B}, {0x2329, 0x232A}, {0x2768, 0x2775}, {0x27C5, 0x27C6}, {0x27E6, 0x27EF}, {0x2983, 0x2998}, -{0x29D8, 0x29DB}, {0x29FC, 0x29FD}, {0x2CF9, 0x2CFC}, {0x2CFE, 0x2CFF}, {0x2D70, 0x2D70}, {0x2E00, 0x2E2E}, {0x2E30, 0x2E4F}, {0x2E52, 0x2E52}, {0x3001, 0x3003}, {0x3008, 0x3011}, {0x3014, 0x301F}, -{0x3030, 0x3030}, {0x303D, 0x303D}, {0x30A0, 0x30A0}, {0x30FB, 0x30FB}, {0xA4FE, 0xA4FF}, {0xA60D, 0xA60F}, {0xA673, 0xA673}, {0xA67E, 0xA67E}, {0xA6F2, 0xA6F7}, {0xA874, 0xA877}, {0xA8CE, 0xA8CF}, -{0xA8F8, 0xA8FA}, {0xA8FC, 0xA8FC}, {0xA92E, 0xA92F}, {0xA95F, 0xA95F}, {0xA9C1, 0xA9CD}, {0xA9DE, 0xA9DF}, {0xAA5C, 0xAA5F}, {0xAADE, 0xAADF}, {0xAAF0, 0xAAF1}, {0xABEB, 0xABEB}, {0xFD3E, 0xFD3F}, -{0xFE10, 0xFE19}, {0xFE30, 0xFE52}, {0xFE54, 0xFE61}, {0xFE63, 0xFE63}, {0xFE68, 0xFE68}, {0xFE6A, 0xFE6B}, {0xFF01, 0xFF03}, {0xFF05, 0xFF0A}, {0xFF0C, 0xFF0F}, {0xFF1A, 0xFF1B}, {0xFF1F, 0xFF20}, -{0xFF3B, 0xFF3D}, {0xFF3F, 0xFF3F}, {0xFF5B, 0xFF5B}, {0xFF5D, 0xFF5D}, {0xFF5F, 0xFF65}, {0x10100, 0x10102}, {0x1039F, 0x1039F}, {0x103D0, 0x103D0}, {0x1056F, 0x1056F}, {0x10857, 0x10857}, -{0x1091F, 0x1091F}, {0x1093F, 0x1093F}, {0x10A50, 0x10A58}, {0x10A7F, 0x10A7F}, {0x10AF0, 0x10AF6}, {0x10B39, 0x10B3F}, {0x10B99, 0x10B9C}, {0x10EAD, 0x10EAD}, {0x10F55, 0x10F59}, {0x11047, 0x1104D}, -{0x110BB, 0x110BC}, {0x110BE, 0x110C1}, {0x11140, 0x11143}, {0x11174, 0x11175}, {0x111C5, 0x111C8}, {0x111CD, 0x111CD}, {0x111DB, 0x111DB}, {0x111DD, 0x111DF}, {0x11238, 0x1123D}, {0x112A9, 0x112A9}, -{0x1144B, 0x1144F}, {0x1145A, 0x1145B}, {0x1145D, 0x1145D}, {0x114C6, 0x114C6}, {0x115C1, 0x115D7}, {0x11641, 0x11643}, {0x11660, 0x1166C}, {0x1173C, 0x1173E}, {0x1183B, 0x1183B}, {0x11944, 0x11946}, -{0x119E2, 0x119E2}, {0x11A3F, 0x11A46}, {0x11A9A, 0x11A9C}, {0x11A9E, 0x11AA2}, {0x11C41, 0x11C45}, {0x11C70, 0x11C71}, {0x11EF7, 0x11EF8}, {0x11FFF, 0x11FFF}, {0x12470, 0x12474}, {0x16A6E, 0x16A6F}, -{0x16AF5, 0x16AF5}, {0x16B37, 0x16B3B}, {0x16B44, 0x16B44}, {0x16E97, 0x16E9A}, {0x16FE2, 0x16FE2}, {0x1BC9F, 0x1BC9F}, {0x1DA87, 0x1DA8B}, {0x1E95E, 0x1E95F}, -}; - -static const std::vector> symbol_ranges = { -{0x24, 0x24}, {0x2B, 0x2B}, {0x3C, 0x3E}, {0x5E, 0x5E}, {0x60, 0x60}, {0x7C, 0x7C}, {0x7E, 0x7E}, {0xA2, 0xA6}, {0xA8, 0xA9}, {0xAC, 0xAC}, {0xAE, 0xB1}, {0xB4, 0xB4}, {0xB8, 0xB8}, {0xD7, 0xD7}, -{0xF7, 0xF7}, {0x2C2, 0x2C5}, {0x2D2, 0x2DF}, {0x2E5, 0x2EB}, {0x2ED, 0x2ED}, {0x2EF, 0x2FF}, {0x375, 0x375}, {0x384, 0x385}, {0x3F6, 0x3F6}, {0x482, 0x482}, {0x58D, 0x58F}, {0x606, 0x608}, -{0x60B, 0x60B}, {0x60E, 0x60F}, {0x6DE, 0x6DE}, {0x6E9, 0x6E9}, {0x6FD, 0x6FE}, {0x7F6, 0x7F6}, {0x7FE, 0x7FF}, {0x9F2, 0x9F3}, {0x9FA, 0x9FB}, {0xAF1, 0xAF1}, {0xB70, 0xB70}, {0xBF3, 0xBFA}, -{0xC7F, 0xC7F}, {0xD4F, 0xD4F}, {0xD79, 0xD79}, {0xE3F, 0xE3F}, {0xF01, 0xF03}, {0xF13, 0xF13}, {0xF15, 0xF17}, {0xF1A, 0xF1F}, {0xF34, 0xF34}, {0xF36, 0xF36}, {0xF38, 0xF38}, {0xFBE, 0xFC5}, -{0xFC7, 0xFCC}, {0xFCE, 0xFCF}, {0xFD5, 0xFD8}, {0x109E, 0x109F}, {0x1390, 0x1399}, {0x166D, 0x166D}, {0x17DB, 0x17DB}, {0x1940, 0x1940}, {0x19DE, 0x19FF}, {0x1B61, 0x1B6A}, {0x1B74, 0x1B7C}, -{0x1FBD, 0x1FBD}, {0x1FBF, 0x1FC1}, {0x1FCD, 0x1FCF}, {0x1FDD, 0x1FDF}, {0x1FED, 0x1FEF}, {0x1FFD, 0x1FFE}, {0x2044, 0x2044}, {0x2052, 0x2052}, {0x207A, 0x207C}, {0x208A, 0x208C}, {0x20A0, 0x20BF}, -{0x2100, 0x2101}, {0x2103, 0x2106}, {0x2108, 0x2109}, {0x2114, 0x2114}, {0x2116, 0x2118}, {0x211E, 0x2123}, {0x2125, 0x2125}, {0x2127, 0x2127}, {0x2129, 0x2129}, {0x212E, 0x212E}, {0x213A, 0x213B}, -{0x2140, 0x2144}, {0x214A, 0x214D}, {0x214F, 0x214F}, {0x218A, 0x218B}, {0x2190, 0x2307}, {0x230C, 0x2328}, {0x232B, 0x2426}, {0x2440, 0x244A}, {0x249C, 0x24E9}, {0x2500, 0x2767}, {0x2794, 0x27C4}, -{0x27C7, 0x27E5}, {0x27F0, 0x2982}, {0x2999, 0x29D7}, {0x29DC, 0x29FB}, {0x29FE, 0x2B73}, {0x2B76, 0x2B95}, {0x2B97, 0x2BFF}, {0x2CE5, 0x2CEA}, {0x2E50, 0x2E51}, {0x2E80, 0x2E99}, {0x2E9B, 0x2EF3}, -{0x2F00, 0x2FD5}, {0x2FF0, 0x2FFB}, {0x3004, 0x3004}, {0x3012, 0x3013}, {0x3020, 0x3020}, {0x3036, 0x3037}, {0x303E, 0x303F}, {0x309B, 0x309C}, {0x3190, 0x3191}, {0x3196, 0x319F}, {0x31C0, 0x31E3}, -{0x3200, 0x321E}, {0x322A, 0x3247}, {0x3250, 0x3250}, {0x3260, 0x327F}, {0x328A, 0x32B0}, {0x32C0, 0x33FF}, {0x4DC0, 0x4DFF}, {0xA490, 0xA4C6}, {0xA700, 0xA716}, {0xA720, 0xA721}, {0xA789, 0xA78A}, -{0xA828, 0xA82B}, {0xA836, 0xA839}, {0xAA77, 0xAA79}, {0xAB5B, 0xAB5B}, {0xAB6A, 0xAB6B}, {0xFB29, 0xFB29}, {0xFBB2, 0xFBC1}, {0xFDFC, 0xFDFD}, {0xFE62, 0xFE62}, {0xFE64, 0xFE66}, {0xFE69, 0xFE69}, -{0xFF04, 0xFF04}, {0xFF0B, 0xFF0B}, {0xFF1C, 0xFF1E}, {0xFF3E, 0xFF3E}, {0xFF40, 0xFF40}, {0xFF5C, 0xFF5C}, {0xFF5E, 0xFF5E}, {0xFFE0, 0xFFE6}, {0xFFE8, 0xFFEE}, {0xFFFC, 0xFFFD}, {0x10137, 0x1013F}, -{0x10179, 0x10189}, {0x1018C, 0x1018E}, {0x10190, 0x1019C}, {0x101A0, 0x101A0}, {0x101D0, 0x101FC}, {0x10877, 0x10878}, {0x10AC8, 0x10AC8}, {0x1173F, 0x1173F}, {0x11FD5, 0x11FF1}, {0x16B3C, 0x16B3F}, -{0x16B45, 0x16B45}, {0x1BC9C, 0x1BC9C}, {0x1D000, 0x1D0F5}, {0x1D100, 0x1D126}, {0x1D129, 0x1D164}, {0x1D16A, 0x1D16C}, {0x1D183, 0x1D184}, {0x1D18C, 0x1D1A9}, {0x1D1AE, 0x1D1E8}, {0x1D200, 0x1D241}, -{0x1D245, 0x1D245}, {0x1D300, 0x1D356}, {0x1D6C1, 0x1D6C1}, {0x1D6DB, 0x1D6DB}, {0x1D6FB, 0x1D6FB}, {0x1D715, 0x1D715}, {0x1D735, 0x1D735}, {0x1D74F, 0x1D74F}, {0x1D76F, 0x1D76F}, {0x1D789, 0x1D789}, -{0x1D7A9, 0x1D7A9}, {0x1D7C3, 0x1D7C3}, {0x1D800, 0x1D9FF}, {0x1DA37, 0x1DA3A}, {0x1DA6D, 0x1DA74}, {0x1DA76, 0x1DA83}, {0x1DA85, 0x1DA86}, {0x1E14F, 0x1E14F}, {0x1E2FF, 0x1E2FF}, {0x1ECAC, 0x1ECAC}, -{0x1ECB0, 0x1ECB0}, {0x1ED2E, 0x1ED2E}, {0x1EEF0, 0x1EEF1}, {0x1F000, 0x1F02B}, {0x1F030, 0x1F093}, {0x1F0A0, 0x1F0AE}, {0x1F0B1, 0x1F0BF}, {0x1F0C1, 0x1F0CF}, {0x1F0D1, 0x1F0F5}, {0x1F10D, 0x1F1AD}, -{0x1F1E6, 0x1F202}, {0x1F210, 0x1F23B}, {0x1F240, 0x1F248}, {0x1F250, 0x1F251}, {0x1F260, 0x1F265}, {0x1F300, 0x1F6D7}, {0x1F6E0, 0x1F6EC}, {0x1F6F0, 0x1F6FC}, {0x1F700, 0x1F773}, {0x1F780, 0x1F7D8}, -{0x1F7E0, 0x1F7EB}, {0x1F800, 0x1F80B}, {0x1F810, 0x1F847}, {0x1F850, 0x1F859}, {0x1F860, 0x1F887}, {0x1F890, 0x1F8AD}, {0x1F8B0, 0x1F8B1}, {0x1F900, 0x1F978}, {0x1F97A, 0x1F9CB}, {0x1F9CD, 0x1FA53}, -{0x1FA60, 0x1FA6D}, {0x1FA70, 0x1FA74}, {0x1FA78, 0x1FA7A}, {0x1FA80, 0x1FA86}, {0x1FA90, 0x1FAA8}, {0x1FAB0, 0x1FAB6}, {0x1FAC0, 0x1FAC2}, {0x1FAD0, 0x1FAD6}, {0x1FB00, 0x1FB92}, {0x1FB94, 0x1FBCA}, -}; - -static const std::vector> control_ranges = { -{0x0, 0x8}, {0xE, 0x1B}, {0x7F, 0x84}, {0x86, 0x9F}, {0xAD, 0xAD}, {0x378, 0x379}, {0x380, 0x383}, {0x38B, 0x38B}, {0x38D, 0x38D}, {0x3A2, 0x3A2}, {0x530, 0x530}, {0x557, 0x558}, {0x58B, 0x58C}, -{0x590, 0x590}, {0x5C8, 0x5CF}, {0x5EB, 0x5EE}, {0x5F5, 0x605}, {0x61C, 0x61D}, {0x6DD, 0x6DD}, {0x70E, 0x70F}, {0x74B, 0x74C}, {0x7B2, 0x7BF}, {0x7FB, 0x7FC}, {0x82E, 0x82F}, {0x83F, 0x83F}, -{0x85C, 0x85D}, {0x85F, 0x85F}, {0x86B, 0x89F}, {0x8B5, 0x8B5}, {0x8C8, 0x8D2}, {0x8E2, 0x8E2}, {0x984, 0x984}, {0x98D, 0x98E}, {0x991, 0x992}, {0x9A9, 0x9A9}, {0x9B1, 0x9B1}, {0x9B3, 0x9B5}, -{0x9BA, 0x9BB}, {0x9C5, 0x9C6}, {0x9C9, 0x9CA}, {0x9CF, 0x9D6}, {0x9D8, 0x9DB}, {0x9DE, 0x9DE}, {0x9E4, 0x9E5}, {0x9FF, 0xA00}, {0xA04, 0xA04}, {0xA0B, 0xA0E}, {0xA11, 0xA12}, {0xA29, 0xA29}, -{0xA31, 0xA31}, {0xA34, 0xA34}, {0xA37, 0xA37}, {0xA3A, 0xA3B}, {0xA3D, 0xA3D}, {0xA43, 0xA46}, {0xA49, 0xA4A}, {0xA4E, 0xA50}, {0xA52, 0xA58}, {0xA5D, 0xA5D}, {0xA5F, 0xA65}, {0xA77, 0xA80}, -{0xA84, 0xA84}, {0xA8E, 0xA8E}, {0xA92, 0xA92}, {0xAA9, 0xAA9}, {0xAB1, 0xAB1}, {0xAB4, 0xAB4}, {0xABA, 0xABB}, {0xAC6, 0xAC6}, {0xACA, 0xACA}, {0xACE, 0xACF}, {0xAD1, 0xADF}, {0xAE4, 0xAE5}, -{0xAF2, 0xAF8}, {0xB00, 0xB00}, {0xB04, 0xB04}, {0xB0D, 0xB0E}, {0xB11, 0xB12}, {0xB29, 0xB29}, {0xB31, 0xB31}, {0xB34, 0xB34}, {0xB3A, 0xB3B}, {0xB45, 0xB46}, {0xB49, 0xB4A}, {0xB4E, 0xB54}, -{0xB58, 0xB5B}, {0xB5E, 0xB5E}, {0xB64, 0xB65}, {0xB78, 0xB81}, {0xB84, 0xB84}, {0xB8B, 0xB8D}, {0xB91, 0xB91}, {0xB96, 0xB98}, {0xB9B, 0xB9B}, {0xB9D, 0xB9D}, {0xBA0, 0xBA2}, {0xBA5, 0xBA7}, -{0xBAB, 0xBAD}, {0xBBA, 0xBBD}, {0xBC3, 0xBC5}, {0xBC9, 0xBC9}, {0xBCE, 0xBCF}, {0xBD1, 0xBD6}, {0xBD8, 0xBE5}, {0xBFB, 0xBFF}, {0xC0D, 0xC0D}, {0xC11, 0xC11}, {0xC29, 0xC29}, {0xC3A, 0xC3C}, -{0xC45, 0xC45}, {0xC49, 0xC49}, {0xC4E, 0xC54}, {0xC57, 0xC57}, {0xC5B, 0xC5F}, {0xC64, 0xC65}, {0xC70, 0xC76}, {0xC8D, 0xC8D}, {0xC91, 0xC91}, {0xCA9, 0xCA9}, {0xCB4, 0xCB4}, {0xCBA, 0xCBB}, -{0xCC5, 0xCC5}, {0xCC9, 0xCC9}, {0xCCE, 0xCD4}, {0xCD7, 0xCDD}, {0xCDF, 0xCDF}, {0xCE4, 0xCE5}, {0xCF0, 0xCF0}, {0xCF3, 0xCFF}, {0xD0D, 0xD0D}, {0xD11, 0xD11}, {0xD45, 0xD45}, {0xD49, 0xD49}, -{0xD50, 0xD53}, {0xD64, 0xD65}, {0xD80, 0xD80}, {0xD84, 0xD84}, {0xD97, 0xD99}, {0xDB2, 0xDB2}, {0xDBC, 0xDBC}, {0xDBE, 0xDBF}, {0xDC7, 0xDC9}, {0xDCB, 0xDCE}, {0xDD5, 0xDD5}, {0xDD7, 0xDD7}, -{0xDE0, 0xDE5}, {0xDF0, 0xDF1}, {0xDF5, 0xE00}, {0xE3B, 0xE3E}, {0xE5C, 0xE80}, {0xE83, 0xE83}, {0xE85, 0xE85}, {0xE8B, 0xE8B}, {0xEA4, 0xEA4}, {0xEA6, 0xEA6}, {0xEBE, 0xEBF}, {0xEC5, 0xEC5}, -{0xEC7, 0xEC7}, {0xECE, 0xECF}, {0xEDA, 0xEDB}, {0xEE0, 0xEFF}, {0xF48, 0xF48}, {0xF6D, 0xF70}, {0xF98, 0xF98}, {0xFBD, 0xFBD}, {0xFCD, 0xFCD}, {0xFDB, 0xFFF}, {0x10C6, 0x10C6}, {0x10C8, 0x10CC}, -{0x10CE, 0x10CF}, {0x1249, 0x1249}, {0x124E, 0x124F}, {0x1257, 0x1257}, {0x1259, 0x1259}, {0x125E, 0x125F}, {0x1289, 0x1289}, {0x128E, 0x128F}, {0x12B1, 0x12B1}, {0x12B6, 0x12B7}, {0x12BF, 0x12BF}, -{0x12C1, 0x12C1}, {0x12C6, 0x12C7}, {0x12D7, 0x12D7}, {0x1311, 0x1311}, {0x1316, 0x1317}, {0x135B, 0x135C}, {0x137D, 0x137F}, {0x139A, 0x139F}, {0x13F6, 0x13F7}, {0x13FE, 0x13FF}, {0x169D, 0x169F}, -{0x16F9, 0x16FF}, {0x170D, 0x170D}, {0x1715, 0x171F}, {0x1737, 0x173F}, {0x1754, 0x175F}, {0x176D, 0x176D}, {0x1771, 0x1771}, {0x1774, 0x177F}, {0x17DE, 0x17DF}, {0x17EA, 0x17EF}, {0x17FA, 0x17FF}, -{0x180E, 0x180F}, {0x181A, 0x181F}, {0x1879, 0x187F}, {0x18AB, 0x18AF}, {0x18F6, 0x18FF}, {0x191F, 0x191F}, {0x192C, 0x192F}, {0x193C, 0x193F}, {0x1941, 0x1943}, {0x196E, 0x196F}, {0x1975, 0x197F}, -{0x19AC, 0x19AF}, {0x19CA, 0x19CF}, {0x19DB, 0x19DD}, {0x1A1C, 0x1A1D}, {0x1A5F, 0x1A5F}, {0x1A7D, 0x1A7E}, {0x1A8A, 0x1A8F}, {0x1A9A, 0x1A9F}, {0x1AAE, 0x1AAF}, {0x1AC1, 0x1AFF}, {0x1B4C, 0x1B4F}, -{0x1B7D, 0x1B7F}, {0x1BF4, 0x1BFB}, {0x1C38, 0x1C3A}, {0x1C4A, 0x1C4C}, {0x1C89, 0x1C8F}, {0x1CBB, 0x1CBC}, {0x1CC8, 0x1CCF}, {0x1CFB, 0x1CFF}, {0x1DFA, 0x1DFA}, {0x1F16, 0x1F17}, {0x1F1E, 0x1F1F}, -{0x1F46, 0x1F47}, {0x1F4E, 0x1F4F}, {0x1F58, 0x1F58}, {0x1F5A, 0x1F5A}, {0x1F5C, 0x1F5C}, {0x1F5E, 0x1F5E}, {0x1F7E, 0x1F7F}, {0x1FB5, 0x1FB5}, {0x1FC5, 0x1FC5}, {0x1FD4, 0x1FD5}, {0x1FDC, 0x1FDC}, -{0x1FF0, 0x1FF1}, {0x1FF5, 0x1FF5}, {0x1FFF, 0x1FFF}, {0x200B, 0x200F}, {0x202A, 0x202E}, {0x2060, 0x206F}, {0x2072, 0x2073}, {0x208F, 0x208F}, {0x209D, 0x209F}, {0x20C0, 0x20CF}, {0x20F1, 0x20FF}, -{0x218C, 0x218F}, {0x2427, 0x243F}, {0x244B, 0x245F}, {0x2B74, 0x2B75}, {0x2B96, 0x2B96}, {0x2C2F, 0x2C2F}, {0x2C5F, 0x2C5F}, {0x2CF4, 0x2CF8}, {0x2D26, 0x2D26}, {0x2D28, 0x2D2C}, {0x2D2E, 0x2D2F}, -{0x2D68, 0x2D6E}, {0x2D71, 0x2D7E}, {0x2D97, 0x2D9F}, {0x2DA7, 0x2DA7}, {0x2DAF, 0x2DAF}, {0x2DB7, 0x2DB7}, {0x2DBF, 0x2DBF}, {0x2DC7, 0x2DC7}, {0x2DCF, 0x2DCF}, {0x2DD7, 0x2DD7}, {0x2DDF, 0x2DDF}, -{0x2E53, 0x2E7F}, {0x2E9A, 0x2E9A}, {0x2EF4, 0x2EFF}, {0x2FD6, 0x2FEF}, {0x2FFC, 0x2FFF}, {0x3040, 0x3040}, {0x3097, 0x3098}, {0x3100, 0x3104}, {0x3130, 0x3130}, {0x318F, 0x318F}, {0x31E4, 0x31EF}, -{0x321F, 0x321F}, {0x9FFD, 0x9FFF}, {0xA48D, 0xA48F}, {0xA4C7, 0xA4CF}, {0xA62C, 0xA63F}, {0xA6F8, 0xA6FF}, {0xA7C0, 0xA7C1}, {0xA7CB, 0xA7F4}, {0xA82D, 0xA82F}, {0xA83A, 0xA83F}, {0xA878, 0xA87F}, -{0xA8C6, 0xA8CD}, {0xA8DA, 0xA8DF}, {0xA954, 0xA95E}, {0xA97D, 0xA97F}, {0xA9CE, 0xA9CE}, {0xA9DA, 0xA9DD}, {0xA9FF, 0xA9FF}, {0xAA37, 0xAA3F}, {0xAA4E, 0xAA4F}, {0xAA5A, 0xAA5B}, {0xAAC3, 0xAADA}, -{0xAAF7, 0xAB00}, {0xAB07, 0xAB08}, {0xAB0F, 0xAB10}, {0xAB17, 0xAB1F}, {0xAB27, 0xAB27}, {0xAB2F, 0xAB2F}, {0xAB6C, 0xAB6F}, {0xABEE, 0xABEF}, {0xABFA, 0xABFF}, {0xD7A4, 0xD7AF}, {0xD7C7, 0xD7CA}, -{0xD7FC, 0xF8FF}, {0xFA6E, 0xFA6F}, {0xFADA, 0xFAFF}, {0xFB07, 0xFB12}, {0xFB18, 0xFB1C}, {0xFB37, 0xFB37}, {0xFB3D, 0xFB3D}, {0xFB3F, 0xFB3F}, {0xFB42, 0xFB42}, {0xFB45, 0xFB45}, {0xFBC2, 0xFBD2}, -{0xFD40, 0xFD4F}, {0xFD90, 0xFD91}, {0xFDC8, 0xFDEF}, {0xFDFE, 0xFDFF}, {0xFE1A, 0xFE1F}, {0xFE53, 0xFE53}, {0xFE67, 0xFE67}, {0xFE6C, 0xFE6F}, {0xFE75, 0xFE75}, {0xFEFD, 0xFF00}, {0xFFBF, 0xFFC1}, -{0xFFC8, 0xFFC9}, {0xFFD0, 0xFFD1}, {0xFFD8, 0xFFD9}, {0xFFDD, 0xFFDF}, {0xFFE7, 0xFFE7}, {0xFFEF, 0xFFFB}, {0xFFFE, 0xFFFF}, {0x1000C, 0x1000C}, {0x10027, 0x10027}, {0x1003B, 0x1003B}, -{0x1003E, 0x1003E}, {0x1004E, 0x1004F}, {0x1005E, 0x1007F}, {0x100FB, 0x100FF}, {0x10103, 0x10106}, {0x10134, 0x10136}, {0x1018F, 0x1018F}, {0x1019D, 0x1019F}, {0x101A1, 0x101CF}, {0x101FE, 0x1027F}, -{0x1029D, 0x1029F}, {0x102D1, 0x102DF}, {0x102FC, 0x102FF}, {0x10324, 0x1032C}, {0x1034B, 0x1034F}, {0x1037B, 0x1037F}, {0x1039E, 0x1039E}, {0x103C4, 0x103C7}, {0x103D6, 0x103FF}, {0x1049E, 0x1049F}, -{0x104AA, 0x104AF}, {0x104D4, 0x104D7}, {0x104FC, 0x104FF}, {0x10528, 0x1052F}, {0x10564, 0x1056E}, {0x10570, 0x105FF}, {0x10737, 0x1073F}, {0x10756, 0x1075F}, {0x10768, 0x107FF}, {0x10806, 0x10807}, -{0x10809, 0x10809}, {0x10836, 0x10836}, {0x10839, 0x1083B}, {0x1083D, 0x1083E}, {0x10856, 0x10856}, {0x1089F, 0x108A6}, {0x108B0, 0x108DF}, {0x108F3, 0x108F3}, {0x108F6, 0x108FA}, {0x1091C, 0x1091E}, -{0x1093A, 0x1093E}, {0x10940, 0x1097F}, {0x109B8, 0x109BB}, {0x109D0, 0x109D1}, {0x10A04, 0x10A04}, {0x10A07, 0x10A0B}, {0x10A14, 0x10A14}, {0x10A18, 0x10A18}, {0x10A36, 0x10A37}, {0x10A3B, 0x10A3E}, -{0x10A49, 0x10A4F}, {0x10A59, 0x10A5F}, {0x10AA0, 0x10ABF}, {0x10AE7, 0x10AEA}, {0x10AF7, 0x10AFF}, {0x10B36, 0x10B38}, {0x10B56, 0x10B57}, {0x10B73, 0x10B77}, {0x10B92, 0x10B98}, {0x10B9D, 0x10BA8}, -{0x10BB0, 0x10BFF}, {0x10C49, 0x10C7F}, {0x10CB3, 0x10CBF}, {0x10CF3, 0x10CF9}, {0x10D28, 0x10D2F}, {0x10D3A, 0x10E5F}, {0x10E7F, 0x10E7F}, {0x10EAA, 0x10EAA}, {0x10EAE, 0x10EAF}, {0x10EB2, 0x10EFF}, -{0x10F28, 0x10F2F}, {0x10F5A, 0x10FAF}, {0x10FCC, 0x10FDF}, {0x10FF7, 0x10FFF}, {0x1104E, 0x11051}, {0x11070, 0x1107E}, {0x110BD, 0x110BD}, {0x110C2, 0x110CF}, {0x110E9, 0x110EF}, {0x110FA, 0x110FF}, -{0x11135, 0x11135}, {0x11148, 0x1114F}, {0x11177, 0x1117F}, {0x111E0, 0x111E0}, {0x111F5, 0x111FF}, {0x11212, 0x11212}, {0x1123F, 0x1127F}, {0x11287, 0x11287}, {0x11289, 0x11289}, {0x1128E, 0x1128E}, -{0x1129E, 0x1129E}, {0x112AA, 0x112AF}, {0x112EB, 0x112EF}, {0x112FA, 0x112FF}, {0x11304, 0x11304}, {0x1130D, 0x1130E}, {0x11311, 0x11312}, {0x11329, 0x11329}, {0x11331, 0x11331}, {0x11334, 0x11334}, -{0x1133A, 0x1133A}, {0x11345, 0x11346}, {0x11349, 0x1134A}, {0x1134E, 0x1134F}, {0x11351, 0x11356}, {0x11358, 0x1135C}, {0x11364, 0x11365}, {0x1136D, 0x1136F}, {0x11375, 0x113FF}, {0x1145C, 0x1145C}, -{0x11462, 0x1147F}, {0x114C8, 0x114CF}, {0x114DA, 0x1157F}, {0x115B6, 0x115B7}, {0x115DE, 0x115FF}, {0x11645, 0x1164F}, {0x1165A, 0x1165F}, {0x1166D, 0x1167F}, {0x116B9, 0x116BF}, {0x116CA, 0x116FF}, -{0x1171B, 0x1171C}, {0x1172C, 0x1172F}, {0x11740, 0x117FF}, {0x1183C, 0x1189F}, {0x118F3, 0x118FE}, {0x11907, 0x11908}, {0x1190A, 0x1190B}, {0x11914, 0x11914}, {0x11917, 0x11917}, {0x11936, 0x11936}, -{0x11939, 0x1193A}, {0x11947, 0x1194F}, {0x1195A, 0x1199F}, {0x119A8, 0x119A9}, {0x119D8, 0x119D9}, {0x119E5, 0x119FF}, {0x11A48, 0x11A4F}, {0x11AA3, 0x11ABF}, {0x11AF9, 0x11BFF}, {0x11C09, 0x11C09}, -{0x11C37, 0x11C37}, {0x11C46, 0x11C4F}, {0x11C6D, 0x11C6F}, {0x11C90, 0x11C91}, {0x11CA8, 0x11CA8}, {0x11CB7, 0x11CFF}, {0x11D07, 0x11D07}, {0x11D0A, 0x11D0A}, {0x11D37, 0x11D39}, {0x11D3B, 0x11D3B}, -{0x11D3E, 0x11D3E}, {0x11D48, 0x11D4F}, {0x11D5A, 0x11D5F}, {0x11D66, 0x11D66}, {0x11D69, 0x11D69}, {0x11D8F, 0x11D8F}, {0x11D92, 0x11D92}, {0x11D99, 0x11D9F}, {0x11DAA, 0x11EDF}, {0x11EF9, 0x11FAF}, -{0x11FB1, 0x11FBF}, {0x11FF2, 0x11FFE}, {0x1239A, 0x123FF}, {0x1246F, 0x1246F}, {0x12475, 0x1247F}, {0x12544, 0x12FFF}, {0x1342F, 0x143FF}, {0x14647, 0x167FF}, {0x16A39, 0x16A3F}, {0x16A5F, 0x16A5F}, -{0x16A6A, 0x16A6D}, {0x16A70, 0x16ACF}, {0x16AEE, 0x16AEF}, {0x16AF6, 0x16AFF}, {0x16B46, 0x16B4F}, {0x16B5A, 0x16B5A}, {0x16B62, 0x16B62}, {0x16B78, 0x16B7C}, {0x16B90, 0x16E3F}, {0x16E9B, 0x16EFF}, -{0x16F4B, 0x16F4E}, {0x16F88, 0x16F8E}, {0x16FA0, 0x16FDF}, {0x16FE5, 0x16FEF}, {0x16FF2, 0x16FFF}, {0x187F8, 0x187FF}, {0x18CD6, 0x18CFF}, {0x18D09, 0x1AFFF}, {0x1B11F, 0x1B14F}, {0x1B153, 0x1B163}, -{0x1B168, 0x1B16F}, {0x1B2FC, 0x1BBFF}, {0x1BC6B, 0x1BC6F}, {0x1BC7D, 0x1BC7F}, {0x1BC89, 0x1BC8F}, {0x1BC9A, 0x1BC9B}, {0x1BCA0, 0x1CFFF}, {0x1D0F6, 0x1D0FF}, {0x1D127, 0x1D128}, {0x1D173, 0x1D17A}, -{0x1D1E9, 0x1D1FF}, {0x1D246, 0x1D2DF}, {0x1D2F4, 0x1D2FF}, {0x1D357, 0x1D35F}, {0x1D379, 0x1D3FF}, {0x1D455, 0x1D455}, {0x1D49D, 0x1D49D}, {0x1D4A0, 0x1D4A1}, {0x1D4A3, 0x1D4A4}, {0x1D4A7, 0x1D4A8}, -{0x1D4AD, 0x1D4AD}, {0x1D4BA, 0x1D4BA}, {0x1D4BC, 0x1D4BC}, {0x1D4C4, 0x1D4C4}, {0x1D506, 0x1D506}, {0x1D50B, 0x1D50C}, {0x1D515, 0x1D515}, {0x1D51D, 0x1D51D}, {0x1D53A, 0x1D53A}, {0x1D53F, 0x1D53F}, -{0x1D545, 0x1D545}, {0x1D547, 0x1D549}, {0x1D551, 0x1D551}, {0x1D6A6, 0x1D6A7}, {0x1D7CC, 0x1D7CD}, {0x1DA8C, 0x1DA9A}, {0x1DAA0, 0x1DAA0}, {0x1DAB0, 0x1DFFF}, {0x1E007, 0x1E007}, {0x1E019, 0x1E01A}, -{0x1E022, 0x1E022}, {0x1E025, 0x1E025}, {0x1E02B, 0x1E0FF}, {0x1E12D, 0x1E12F}, {0x1E13E, 0x1E13F}, {0x1E14A, 0x1E14D}, {0x1E150, 0x1E2BF}, {0x1E2FA, 0x1E2FE}, {0x1E300, 0x1E7FF}, {0x1E8C5, 0x1E8C6}, -{0x1E8D7, 0x1E8FF}, {0x1E94C, 0x1E94F}, {0x1E95A, 0x1E95D}, {0x1E960, 0x1EC70}, {0x1ECB5, 0x1ED00}, {0x1ED3E, 0x1EDFF}, {0x1EE04, 0x1EE04}, {0x1EE20, 0x1EE20}, {0x1EE23, 0x1EE23}, {0x1EE25, 0x1EE26}, -{0x1EE28, 0x1EE28}, {0x1EE33, 0x1EE33}, {0x1EE38, 0x1EE38}, {0x1EE3A, 0x1EE3A}, {0x1EE3C, 0x1EE41}, {0x1EE43, 0x1EE46}, {0x1EE48, 0x1EE48}, {0x1EE4A, 0x1EE4A}, {0x1EE4C, 0x1EE4C}, {0x1EE50, 0x1EE50}, -{0x1EE53, 0x1EE53}, {0x1EE55, 0x1EE56}, {0x1EE58, 0x1EE58}, {0x1EE5A, 0x1EE5A}, {0x1EE5C, 0x1EE5C}, {0x1EE5E, 0x1EE5E}, {0x1EE60, 0x1EE60}, {0x1EE63, 0x1EE63}, {0x1EE65, 0x1EE66}, {0x1EE6B, 0x1EE6B}, -{0x1EE73, 0x1EE73}, {0x1EE78, 0x1EE78}, {0x1EE7D, 0x1EE7D}, {0x1EE7F, 0x1EE7F}, {0x1EE8A, 0x1EE8A}, {0x1EE9C, 0x1EEA0}, {0x1EEA4, 0x1EEA4}, {0x1EEAA, 0x1EEAA}, {0x1EEBC, 0x1EEEF}, {0x1EEF2, 0x1EFFF}, -{0x1F02C, 0x1F02F}, {0x1F094, 0x1F09F}, {0x1F0AF, 0x1F0B0}, {0x1F0C0, 0x1F0C0}, {0x1F0D0, 0x1F0D0}, {0x1F0F6, 0x1F0FF}, {0x1F1AE, 0x1F1E5}, {0x1F203, 0x1F20F}, {0x1F23C, 0x1F23F}, {0x1F249, 0x1F24F}, -{0x1F252, 0x1F25F}, {0x1F266, 0x1F2FF}, {0x1F6D8, 0x1F6DF}, {0x1F6ED, 0x1F6EF}, {0x1F6FD, 0x1F6FF}, {0x1F774, 0x1F77F}, {0x1F7D9, 0x1F7DF}, {0x1F7EC, 0x1F7FF}, {0x1F80C, 0x1F80F}, {0x1F848, 0x1F84F}, -{0x1F85A, 0x1F85F}, {0x1F888, 0x1F88F}, {0x1F8AE, 0x1F8AF}, {0x1F8B2, 0x1F8FF}, {0x1F979, 0x1F979}, {0x1F9CC, 0x1F9CC}, {0x1FA54, 0x1FA5F}, {0x1FA6E, 0x1FA6F}, {0x1FA75, 0x1FA77}, {0x1FA7B, 0x1FA7F}, -{0x1FA87, 0x1FA8F}, {0x1FAA9, 0x1FAAF}, {0x1FAB7, 0x1FABF}, {0x1FAC3, 0x1FACF}, {0x1FAD7, 0x1FAFF}, {0x1FB93, 0x1FB93}, {0x1FBCB, 0x1FBEF}, {0x1FBFA, 0x1FFFF}, {0x2A6DE, 0x2A6FF}, {0x2B735, 0x2B73F}, -{0x2B81E, 0x2B81F}, {0x2CEA2, 0x2CEAF}, {0x2EBE1, 0x2F7FF}, {0x2FA1E, 0x2FFFF}, {0x3134B, 0xE00FF}, {0xE01F0, 0x10FFFF}, -}; - -static std::string codepoint_to_utf8(uint32_t cp) { - std::string result; - if (/* 0x00 <= cp && */ cp <= 0x7f) { - result.push_back(cp); - } - else if (0x80 <= cp && cp <= 0x7ff) { - result.push_back(0xc0 | ((cp >> 6) & 0x1f)); - result.push_back(0x80 | (cp & 0x3f)); - } - else if (0x800 <= cp && cp <= 0xffff) { - result.push_back(0xe0 | ((cp >> 12) & 0x0f)); - result.push_back(0x80 | ((cp >> 6) & 0x3f)); - result.push_back(0x80 | (cp & 0x3f)); - } - else if (0x10000 <= cp && cp <= 0x10ffff) { - result.push_back(0xf0 | ((cp >> 18) & 0x07)); - result.push_back(0x80 | ((cp >> 12) & 0x3f)); - result.push_back(0x80 | ((cp >> 6) & 0x3f)); - result.push_back(0x80 | (cp & 0x3f)); - } - else { - throw std::invalid_argument("invalid codepoint"); - } - return result; -} - -static std::string codepoints_to_utf8(const std::vector & cps) { - std::string result; - for (size_t i = 0; i < cps.size(); ++i) { - result.append(codepoint_to_utf8(cps[i])); - } - return result; -} - -static uint32_t codepoint_from_utf8(const std::string & utf8, size_t & offset) { - assert(offset < utf8.size()); - if (!(utf8[offset + 0] & 0x80)) { - auto result = utf8[offset + 0]; - offset += 1; - return result; - } - else if (!(utf8[offset + 0] & 0x40)) { - throw std::invalid_argument("invalid character"); - } - else if (!(utf8[offset + 0] & 0x20)) { - if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) - throw std::invalid_argument("invalid character"); - auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f); - offset += 2; - return result; - } - else if (!(utf8[offset + 0] & 0x10)) { - if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) - throw std::invalid_argument("invalid character"); - auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f); - offset += 3; - return result; - } - else if (!(utf8[offset + 0] & 0x08)) { - if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) - throw std::invalid_argument("invalid character"); - auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f); - offset += 4; - return result; - } - throw std::invalid_argument("invalid string"); -} - -static std::vector codepoints_from_utf8(const std::string & utf8) { - std::vector result; - size_t offset = 0; - while (offset < utf8.size()) { - result.push_back(codepoint_from_utf8(utf8, offset)); - } - return result; -} - -static std::vector codepoint_to_utf16(uint32_t cp) { - std::vector result; - if (/* 0x0000 <= cp && */ cp <= 0xffff) { - result.emplace_back(cp); - } - else if (0x10000 <= cp && cp <= 0x10ffff) { - result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); - result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); - } - else { - throw std::invalid_argument("invalid codepoint"); - } - return result; -} - -static std::vector codepoints_to_utf16(const std::vector & cps) { - std::vector result; - for (size_t i = 0; i < cps.size(); ++i) { - auto temp = codepoint_to_utf16(cps[i]); - result.insert(result.end(), temp.begin(), temp.end()); - } - return result; -} - -static uint32_t codepoint_from_utf16(const std::vector & utf16, size_t & offset) { - assert(offset < utf16.size()); - if (((utf16[0] >> 10) << 10) != 0xd800) { - auto result = utf16[offset + 0]; - offset += 1; - return result; - } - else { - if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) - throw std::invalid_argument("invalid character"); - auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); - offset += 2; - return result; - } - throw std::invalid_argument("invalid string"); -} - -static std::vector codepoints_from_utf16(const std::vector & utf16) { - std::vector result; - size_t offset = 0; - while (offset < utf16.size()) - result.push_back(codepoint_from_utf16(utf16, offset)); - return result; -} - -#define CODEPOINT_TYPE_UNIDENTIFIED 0 -#define CODEPOINT_TYPE_DIGIT 1 -#define CODEPOINT_TYPE_LETTER 2 -#define CODEPOINT_TYPE_WHITESPACE 3 -#define CODEPOINT_TYPE_ACCENT_MARK 4 -#define CODEPOINT_TYPE_PUNCTUATION 5 -#define CODEPOINT_TYPE_SYMBOL 6 -#define CODEPOINT_TYPE_CONTROL 7 - -static std::unordered_map codepoint_type_map() { - std::unordered_map codepoint_types; - for (auto p : digit_ranges) { - for(auto i = p.first; i <= p.second; ++ i) - codepoint_types[i] = CODEPOINT_TYPE_DIGIT; - } - for(auto p : letter_ranges) { - for(auto i = p.first; i <= p.second; ++ i) - codepoint_types[i] = CODEPOINT_TYPE_LETTER; - } - for(auto p : whitespace_ranges) { - for(auto i = p.first; i <= p.second; ++ i) - codepoint_types[i] = CODEPOINT_TYPE_WHITESPACE; - } - for(auto p : accent_mark_ranges) { - for(auto i = p.first; i <= p.second; ++ i) - codepoint_types[i] = CODEPOINT_TYPE_ACCENT_MARK; - } - for(auto p : punctuation_ranges) { - for(auto i = p.first; i <= p.second; ++ i) - codepoint_types[i] = CODEPOINT_TYPE_PUNCTUATION; - } - for (auto p : symbol_ranges) { - for (auto i = p.first; i <= p.second; ++i) - codepoint_types[i] = CODEPOINT_TYPE_SYMBOL; - } - for(auto p : control_ranges) { - for(auto i = p.first; i <= p.second; ++ i) - codepoint_types[i] = CODEPOINT_TYPE_CONTROL; - } - return codepoint_types; -} - -static int codepoint_type(uint32_t cp) { - static std::unordered_map codepoint_types = codepoint_type_map(); - return codepoint_types[cp]; -} - -static int codepoint_type(const std::string & utf8) { - if (utf8.length() == 0) - return CODEPOINT_TYPE_UNIDENTIFIED; - size_t offset = 0; - return codepoint_type(codepoint_from_utf8(utf8, offset)); -} - -static std::unordered_map bytes_to_unicode_map_bpe() { - std::unordered_map map; - for (int ch = u'!'; ch <= u'~'; ++ch) { - assert(0 <= ch && ch < 256); - map[ch] = codepoint_to_utf8(ch); - } - for (int ch = u'¡'; ch <= u'¬'; ++ch) { - assert(0 <= ch && ch < 256); - map[ch] = codepoint_to_utf8(ch); - } - for (int ch = u'®'; ch <= u'ÿ'; ++ch) { - assert(0 <= ch && ch < 256); - map[ch] = codepoint_to_utf8(ch); - } - auto n = 0; - for (int ch = 0; ch < 256; ++ch) { - if (map.find(ch) == map.end()) { - map[ch] = codepoint_to_utf8(256 + n); - ++n; - } - } - return map; -} - -static std::string bytes_to_unicode_bpe(uint8_t byte) { - static std::unordered_map map = bytes_to_unicode_map_bpe(); - return map.at(byte); -} - -static std::unordered_map unicode_to_bytes_map_bpe() { - std::unordered_map map; - for (int ch = u'!'; ch <= u'~'; ++ch) { - assert(0 <= ch && ch < 256); - map[codepoint_to_utf8(ch)] = ch; - } - for (int ch = u'¡'; ch <= u'¬'; ++ch) { - assert(0 <= ch && ch < 256); - map[codepoint_to_utf8(ch)] = ch; - } - for (int ch = u'®'; ch <= u'ÿ'; ++ch) { - assert(0 <= ch && ch < 256); - map[codepoint_to_utf8(ch)] = ch; - } - auto n = 0; - for (int ch = 0; ch < 256; ++ch) { - if (map.find(codepoint_to_utf8(ch)) == map.end()) { - map[codepoint_to_utf8(256 + n)] = ch; - ++n; - } - } - return map; -} - -static uint8_t unicode_to_bytes_bpe(const std::string & utf8) { - static std::unordered_map map = unicode_to_bytes_map_bpe(); - return map.at(utf8); -} - From 0d64b54e016fd59d7a8c126d81a6341e5aa01b98 Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 24 Nov 2023 02:14:33 +0800 Subject: [PATCH 184/623] [WASI-NN] ggml: update llama.cpp to b1550 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/ggml.cpp | 185 +++++++++++++++++++-------------- 2 files changed, 107 insertions(+), 80 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 390714f9..c6ccb81a 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -47,7 +47,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b1383 + GIT_TAG b1550 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW TRUE ) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index bd19cc1a..d2d46f2c 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -324,9 +324,11 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } // Set the input. + const bool AddBos = llama_should_add_bos_token(GraphRef.LlamaModel); std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); - CxtRef.LlamaInputs = llama_tokenize(CxtRef.LlamaContext, Prompt, true); + CxtRef.LlamaInputs = + llama_tokenize(CxtRef.LlamaContext, Prompt, AddBos, true); const uint32_t MaxContextSize = llama_n_ctx(CxtRef.LlamaContext); // Minus 4 for the special tokens. const uint32_t MaxTokensListSize = MaxContextSize - 4; @@ -364,93 +366,118 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } // Clear the outputs. - CxtRef.LlamaOutputs = ""sv; + CxtRef.LlamaOutputs.clear(); // Main predict loop. - // TODO: recompute a compressed context based on previous tokens once the - // cache is full. - const int MaxContextSize = llama_n_ctx(CxtRef.LlamaContext); - // NPredict is the number of tokens to predict. Same as -n, --n-predict in - // llama.cpp. - int NPredict = GraphRef.NPredict; - - // Evaluate the initial prompt. - llama_batch LlamaBatch = llama_batch_init(GraphRef.BatchSize, 0); - LlamaBatch.n_tokens = CxtRef.LlamaInputs.size(); - for (int32_t I = 0; I < LlamaBatch.n_tokens; I++) { - LlamaBatch.token[I] = CxtRef.LlamaInputs[I]; - LlamaBatch.pos[I] = I; - LlamaBatch.seq_id[I] = 0; - LlamaBatch.logits[I] = false; - } - - // llama_decode will output logits only for the last token of the prompt - LlamaBatch.logits[LlamaBatch.n_tokens - 1] = true; - if (llama_decode(CxtRef.LlamaContext, LlamaBatch) != 0) { - spdlog::info("[WASI-NN] GGML backend: llama_decode() failed"sv); - return ErrNo::RuntimeError; - } - - int NCur = LlamaBatch.n_tokens; - while (NCur < MaxContextSize && NCur < NPredict) { - // Sample the next token - auto NVocab = llama_n_vocab(GraphRef.LlamaModel); - auto *Logits = - llama_get_logits_ith(CxtRef.LlamaContext, LlamaBatch.n_tokens - 1); - - std::vector Candidates; - Candidates.reserve(NVocab); - for (llama_token TokenId = 0; TokenId < NVocab; TokenId++) { - Candidates.emplace_back(llama_token_data{TokenId, Logits[TokenId], 0.0f}); - } - llama_token_data_array CandidatesP = {Candidates.data(), Candidates.size(), - false}; - - // Sample the most likely token - const llama_token NewTokenId = - llama_sample_token_greedy(CxtRef.LlamaContext, &CandidatesP); - - // Is it an end of stream? - if (NewTokenId == llama_token_eos(CxtRef.LlamaContext) || - NCur == MaxContextSize || NCur == NPredict) { - break; - } + gpt_params GPTParams; + struct llama_sampling_context *CtxSampling = + llama_sampling_init(GPTParams.sparams); + std::vector Embd; + int NPast = 0; + int NConsumed = 0; + int NRemain = GraphRef.NPredict; + int NKeep = GPTParams.n_keep; + int NCtx = llama_n_ctx(CxtRef.LlamaContext); + // Minus 4 for the special tokens. + const int MaxTokensListSize = NCtx - 4; + while (NRemain != 0) { + // Preidct + if (!Embd.empty()) { + // Truncate if necessary. + if (static_cast(Embd.size()) > MaxTokensListSize) { + auto NSkipped = Embd.size() - MaxTokensListSize; + Embd.resize(MaxTokensListSize); + if (GraphRef.EnableLog) { + spdlog::info("[WASI-NN] GGML backend: Truncated {} tokens"sv, + NSkipped); + } + } - std::string NextToken = - llama_token_to_piece(CxtRef.LlamaContext, NewTokenId); + // Infinite text generation via context swapping. + if (NPast + static_cast(Embd.size()) > NCtx) { + const int NLeft = NPast + NKeep - 1; + const int NDiscard = NLeft / 2; + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: Context full, swapping: NPast = {}, NLeft = {}, NKeep = {}, NDiscard = {}"sv, + NPast, NLeft, NKeep, NDiscard); + } + llama_kv_cache_seq_rm(CxtRef.LlamaContext, 0, NKeep + 1, + NKeep + NDiscard + 1); + llama_kv_cache_seq_shift(CxtRef.LlamaContext, 0, NKeep + 1 + NDiscard, + NPast, -NDiscard); + NPast -= NDiscard; + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: After context swapping: NPast = {}"sv, + NPast); + } + } - // When setting StreamStdout, we print the output to stdout. - if (GraphRef.StreamStdout) { - std::cout << NextToken << std::flush; + // Evaluate tokens in batches. + for (int I = 0; I < static_cast(Embd.size()); + I += GraphRef.BatchSize) { + int NEval = static_cast(Embd.size()) - I; + if (NEval > static_cast(GraphRef.BatchSize)) { + NEval = GraphRef.BatchSize; + } + if (llama_decode(CxtRef.LlamaContext, + llama_batch_get_one(&Embd[I], NEval, NPast, 0))) { + spdlog::error("[WASI-NN] GGML backend: failed to llama_decode"sv); + return ErrNo::RuntimeError; + } + + NPast += NEval; + } } - // Append the new token. - CxtRef.LlamaOutputs += NextToken; - - // Prepare the next batch - LlamaBatch.n_tokens = 0; - - // Push this new token for next evaluation - LlamaBatch.token[LlamaBatch.n_tokens] = NewTokenId; - LlamaBatch.pos[LlamaBatch.n_tokens] = NCur; - LlamaBatch.seq_id[LlamaBatch.n_tokens] = 0; - LlamaBatch.logits[LlamaBatch.n_tokens] = true; - LlamaBatch.n_tokens += 1; - NCur += 1; - - // Evaluate the current batch with the transformer model - if (llama_decode(CxtRef.LlamaContext, LlamaBatch)) { - spdlog::error("[WASI-NN] GGML backend: failed to llama_decode"sv); - return ErrNo::RuntimeError; + Embd.clear(); + + if (static_cast(CxtRef.LlamaInputs.size()) <= NConsumed) { + const llama_token Id = + llama_sampling_sample(CtxSampling, CxtRef.LlamaContext, nullptr); + llama_sampling_accept(CtxSampling, CxtRef.LlamaContext, Id, true); + Embd.emplace_back(Id); + --NRemain; + // Save the output token. + CxtRef.LlamaOutputs += llama_token_to_piece(CxtRef.LlamaContext, Id); + // When setting StreamStdout, we print the output to stdout. + if (GraphRef.StreamStdout) { + std::cout << llama_token_to_piece(CxtRef.LlamaContext, Id) + << std::flush; + } + // Break if reverse prompt is found. + if (!GraphRef.ReversePrompt.empty() && + CxtRef.LlamaOutputs.find(GraphRef.ReversePrompt) != + std::string::npos) { + if (GraphRef.EnableLog) { + spdlog::info("[WASI-NN] GGML backend: reverse prompt found"sv); + } + break; + } + } else { + while (static_cast(CxtRef.LlamaInputs.size()) > NConsumed) { + Embd.push_back(CxtRef.LlamaInputs[NConsumed]); + // Push the prompt in the sampling context. + llama_sampling_accept(CtxSampling, CxtRef.LlamaContext, + CxtRef.LlamaInputs[NConsumed], false); + ++NConsumed; + if (Embd.size() >= GraphRef.BatchSize) { + break; + } + } } - // Break if reverse prompt is found. - if (!GraphRef.ReversePrompt.empty() && - CxtRef.LlamaOutputs.find(GraphRef.ReversePrompt) != std::string::npos) { - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: reverse prompt found"sv); + // If not currently processing queued inputs. + if (static_cast(CxtRef.LlamaInputs.size()) <= NConsumed) { + // Deal with end of text token. + if (llama_sampling_last(CtxSampling) == + llama_token_eos(GraphRef.LlamaModel)) { + if (GraphRef.EnableLog) { + spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); + } + break; } - break; } } From ede4f89026d30eeb0a51df877d921855ed59340a Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 24 Nov 2023 11:33:43 +0800 Subject: [PATCH 185/623] [WASI-NN] ggml: update test model Signed-off-by: dm4 --- test/plugins/wasi_nn/CMakeLists.txt | 6 +++--- test/plugins/wasi_nn/wasi_nn.cpp | 6 +++--- utils/wasi-nn/download-ggml-fixtures.sh | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 9546c667..e7eb03c5 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -60,9 +60,9 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-ggml-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures RESULT_VARIABLE DOWNLOAD_ERROR OUTPUT_STRIP_TRAILING_WHITESPACE) - file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/orca-mini-3b.q4_0.gguf CHECKSUM_MODEL) - if(NOT CHECKSUM_MODEL STREQUAL "aae346fe095e60139ca39b3fda4ac7ae") - message(FATAL_ERROR "orca-mini-3b.q4_0.gguf downloaded with wrong md5") + file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/orca_mini.gguf CHECKSUM_MODEL) + if(NOT CHECKSUM_MODEL STREQUAL "f895f00678bfbf89f70d6d25f20a7b5f") + message(FATAL_ERROR "orca_mini.gguf downloaded with wrong md5") endif() else() # Add the other backend test files fetching here. diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index b4b1ce9d..c66849a3 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1218,7 +1218,7 @@ TEST(WasiNNTest, GGMLBackend) { WasmEdge::Runtime::Instance::ModuleInstance Mod(""); Mod.addHostMemory( "memory", std::make_unique( - WasmEdge::AST::MemoryType(40000))); + WasmEdge::AST::MemoryType(60000))); auto *MemInstPtr = Mod.findMemoryExports("memory"); ASSERT_TRUE(MemInstPtr != nullptr); auto &MemInst = *MemInstPtr; @@ -1228,13 +1228,13 @@ TEST(WasiNNTest, GGMLBackend) { std::string Prompt = "Once upon a time, "; std::vector TensorData(Prompt.begin(), Prompt.end()); std::vector WeightRead = - readEntireFile("./wasinn_ggml_fixtures/orca-mini-3b.q4_0.gguf"); + readEntireFile("./wasinn_ggml_fixtures/orca_mini.gguf"); std::vector TensorDim{1}; uint32_t BuilderPtr = UINT32_C(0); uint32_t LoadEntryPtr = UINT32_C(0); uint32_t SetInputEntryPtr = UINT32_C(0); - uint32_t OutBoundPtr = UINT32_C(41000 * 65536); + uint32_t OutBoundPtr = UINT32_C(61000 * 65536); uint32_t StorePtr = UINT32_C(65536); // Return value. diff --git a/utils/wasi-nn/download-ggml-fixtures.sh b/utils/wasi-nn/download-ggml-fixtures.sh index 53f8ea15..bb635925 100755 --- a/utils/wasi-nn/download-ggml-fixtures.sh +++ b/utils/wasi-nn/download-ggml-fixtures.sh @@ -6,8 +6,8 @@ TODIR=$1 if [[ $# -eq 0 ]]; then TODIR=. fi -MODEL=orca-mini-3b.q4_0.gguf -FIXTURE=https://huggingface.co/juanjgit/orca_mini_3B-GGUF/resolve/main/$MODEL +MODEL=orca_mini.gguf +FIXTURE=https://huggingface.co/TheBloke/orca_mini_v3_7B-GGUF/resolve/main/orca_mini_v3_7b.Q2_K.gguf if [ ! -d $TODIR ]; then mkdir $TODIR fi From f8b3517233deae65d5211e01fcef95c26fa3c8a9 Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 24 Nov 2023 14:34:13 +0800 Subject: [PATCH 186/623] [WASI-NN] ggml: free context at the end of compute Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index d2d46f2c..9faa70f9 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -262,16 +262,6 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); ContextId = Env.NNContext.size() - 1; - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - - // Initialize the llama context. - llama_context_params ContextParams = llama_context_default_params(); - ContextParams.n_ctx = GraphRef.CtxSize; - ContextParams.n_batch = GraphRef.BatchSize; - CxtRef.LlamaContext = - llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); - return ErrNo::Success; } @@ -377,6 +367,12 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { int NConsumed = 0; int NRemain = GraphRef.NPredict; int NKeep = GPTParams.n_keep; + // Initialize the llama context. + llama_context_params ContextParams = llama_context_default_params(); + ContextParams.n_ctx = GraphRef.CtxSize; + ContextParams.n_batch = GraphRef.BatchSize; + CxtRef.LlamaContext = + llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); int NCtx = llama_n_ctx(CxtRef.LlamaContext); // Minus 4 for the special tokens. const int MaxTokensListSize = NCtx - 4; @@ -485,6 +481,9 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { llama_print_timings(CxtRef.LlamaContext); } + llama_sampling_free(CtxSampling); + llama_free(CxtRef.LlamaContext); + return ErrNo::Success; } #else From 423c8f0567abeba37def8b81d8f44b5499394bb3 Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 24 Nov 2023 17:40:13 +0800 Subject: [PATCH 187/623] [WASI-NN] ggml: add more comments Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 49 ++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 9faa70f9..7133b1ea 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -374,14 +374,17 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { CxtRef.LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); int NCtx = llama_n_ctx(CxtRef.LlamaContext); - // Minus 4 for the special tokens. + // Minus 4 for the special tokens. (Such as , , ... tokens.) const int MaxTokensListSize = NCtx - 4; - while (NRemain != 0) { + // Use the const sequence id here. + const int SequenceId = 0; + while (NRemain >= 0) { // Preidct if (!Embd.empty()) { // Truncate if necessary. if (static_cast(Embd.size()) > MaxTokensListSize) { auto NSkipped = Embd.size() - MaxTokensListSize; + // We follow llama.cpp/example/main to truncate the last few tokens. Embd.resize(MaxTokensListSize); if (GraphRef.EnableLog) { spdlog::info("[WASI-NN] GGML backend: Truncated {} tokens"sv, @@ -398,10 +401,16 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { "[WASI-NN] GGML backend: Context full, swapping: NPast = {}, NLeft = {}, NKeep = {}, NDiscard = {}"sv, NPast, NLeft, NKeep, NDiscard); } - llama_kv_cache_seq_rm(CxtRef.LlamaContext, 0, NKeep + 1, + // llama_kv_cache_seq_rm(context, sequence_id, start_pos, end_pos) + // This will remove the tokens [start_pos, end_pos). + llama_kv_cache_seq_rm(CxtRef.LlamaContext, SequenceId, NKeep + 1, NKeep + NDiscard + 1); - llama_kv_cache_seq_shift(CxtRef.LlamaContext, 0, NKeep + 1 + NDiscard, - NPast, -NDiscard); + // llama_kv_cache_seq_shift(context, sequence_id, start_pos, end_pos, + // delta) + // This will shift the tokens at [start_pos, end_pos) with delta + // distance. + llama_kv_cache_seq_shift(CxtRef.LlamaContext, SequenceId, + NKeep + 1 + NDiscard, NPast, -NDiscard); NPast -= NDiscard; if (GraphRef.EnableLog) { spdlog::info( @@ -417,8 +426,12 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { if (NEval > static_cast(GraphRef.BatchSize)) { NEval = GraphRef.BatchSize; } - if (llama_decode(CxtRef.LlamaContext, - llama_batch_get_one(&Embd[I], NEval, NPast, 0))) { + // llama_batch_get_one(*token, n_tokens, position, sequence_id) + // This will return batch for single sequence of tokens starting at + // position. + if (llama_decode( + CxtRef.LlamaContext, + llama_batch_get_one(&Embd[I], NEval, NPast, SequenceId))) { spdlog::error("[WASI-NN] GGML backend: failed to llama_decode"sv); return ErrNo::RuntimeError; } @@ -451,6 +464,14 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } break; } + // Deal with end of text token. + if (llama_sampling_last(CtxSampling) == + llama_token_eos(GraphRef.LlamaModel)) { + if (GraphRef.EnableLog) { + spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); + } + break; + } } else { while (static_cast(CxtRef.LlamaInputs.size()) > NConsumed) { Embd.push_back(CxtRef.LlamaInputs[NConsumed]); @@ -463,24 +484,14 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } } } - - // If not currently processing queued inputs. - if (static_cast(CxtRef.LlamaInputs.size()) <= NConsumed) { - // Deal with end of text token. - if (llama_sampling_last(CtxSampling) == - llama_token_eos(GraphRef.LlamaModel)) { - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); - } - break; - } - } } if (GraphRef.EnableLog) { llama_print_timings(CxtRef.LlamaContext); } + // We free the contexts here to keep the ggml plugin stateless. + // Users could fully controll the contexts by themselves via their prompt. llama_sampling_free(CtxSampling); llama_free(CxtRef.LlamaContext); From b3e75fe8905e0d1561223220bdb8f14e9808417c Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 24 Nov 2023 18:24:34 +0800 Subject: [PATCH 188/623] [WASI-NN] ggml: support metadata of outputs from index 1 Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 21 +++++++++++++++++++-- plugins/wasi_nn/ggml.h | 1 + 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 7133b1ea..963df546 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -332,10 +332,25 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, - [[maybe_unused]] uint32_t Index, - Span OutBuffer, + uint32_t Index, Span OutBuffer, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); + // Index 1 is for the metadata of the outputs. + if (Index == 1) { + std::string MetadataTemplate = + R"({"input_tokens": %d, "output_tokens": %d})"; + // The 20 bytes are reserved to accommodate two %d placeholders in the + // MetadataTemplate. This allows for a decimal integer value up to a + // 12-digit number of input/output tokens. + char Buffer[MetadataTemplate.size() + 20]; + snprintf(Buffer, sizeof(Buffer), MetadataTemplate.c_str(), + CxtRef.LlamaInputs.size(), CxtRef.LlamaOutputTokens.size()); + std::string Metadata(Buffer); + std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); + BytesWritten = Metadata.length(); + return ErrNo::Success; + } + std::copy_n(CxtRef.LlamaOutputs.data(), CxtRef.LlamaOutputs.length(), OutBuffer.data()); BytesWritten = CxtRef.LlamaOutputs.length(); @@ -357,6 +372,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Clear the outputs. CxtRef.LlamaOutputs.clear(); + CxtRef.LlamaOutputTokens.clear(); // Main predict loop. gpt_params GPTParams; @@ -449,6 +465,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { Embd.emplace_back(Id); --NRemain; // Save the output token. + CxtRef.LlamaOutputTokens.emplace_back(Id); CxtRef.LlamaOutputs += llama_token_to_piece(CxtRef.LlamaContext, Id); // When setting StreamStdout, we print the output to stdout. if (GraphRef.StreamStdout) { diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index e24ae954..dfcd285b 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -39,6 +39,7 @@ struct Context { llama_context *LlamaContext = nullptr; std::vector LlamaInputs; std::string LlamaOutputs; + std::vector LlamaOutputTokens; }; #else struct Graph {}; From b6d705438604cd8ccdc6eb7686a1fcf0051a198e Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 24 Nov 2023 21:41:34 +0800 Subject: [PATCH 189/623] [WASI-NN] ggml: fix typo Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 963df546..a3687d65 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -508,7 +508,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } // We free the contexts here to keep the ggml plugin stateless. - // Users could fully controll the contexts by themselves via their prompt. + // Users could fully control the contexts by themselves via their prompt. llama_sampling_free(CtxSampling); llama_free(CxtRef.LlamaContext); From 51637f474f244c25eb3638e5579f1f5a4b503295 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 27 Nov 2023 14:25:20 +0800 Subject: [PATCH 190/623] [WASI-NN] ggml: remove context swapping Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 41 ++++++++++------------------------------ 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index a3687d65..87fa44c3 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -324,7 +324,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, const uint32_t MaxTokensListSize = MaxContextSize - 4; if (CxtRef.LlamaInputs.size() > MaxTokensListSize) { spdlog::error( - "[WASI-NN] GGML backend: Error: prompt too long ({} tokens, max {})"sv, + "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, CxtRef.LlamaInputs.size(), MaxTokensListSize); return ErrNo::InvalidArgument; } @@ -382,7 +382,6 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { int NPast = 0; int NConsumed = 0; int NRemain = GraphRef.NPredict; - int NKeep = GPTParams.n_keep; // Initialize the llama context. llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = GraphRef.CtxSize; @@ -397,42 +396,22 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { while (NRemain >= 0) { // Preidct if (!Embd.empty()) { - // Truncate if necessary. + // Input too long. if (static_cast(Embd.size()) > MaxTokensListSize) { - auto NSkipped = Embd.size() - MaxTokensListSize; - // We follow llama.cpp/example/main to truncate the last few tokens. - Embd.resize(MaxTokensListSize); - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: Truncated {} tokens"sv, - NSkipped); - } + spdlog::error( + "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, + Embd.size(), MaxTokensListSize); + return ErrNo::RuntimeError; } - // Infinite text generation via context swapping. + // We do not swap context here. End the inference if the context is full. if (NPast + static_cast(Embd.size()) > NCtx) { - const int NLeft = NPast + NKeep - 1; - const int NDiscard = NLeft / 2; - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: Context full, swapping: NPast = {}, NLeft = {}, NKeep = {}, NDiscard = {}"sv, - NPast, NLeft, NKeep, NDiscard); - } - // llama_kv_cache_seq_rm(context, sequence_id, start_pos, end_pos) - // This will remove the tokens [start_pos, end_pos). - llama_kv_cache_seq_rm(CxtRef.LlamaContext, SequenceId, NKeep + 1, - NKeep + NDiscard + 1); - // llama_kv_cache_seq_shift(context, sequence_id, start_pos, end_pos, - // delta) - // This will shift the tokens at [start_pos, end_pos) with delta - // distance. - llama_kv_cache_seq_shift(CxtRef.LlamaContext, SequenceId, - NKeep + 1 + NDiscard, NPast, -NDiscard); - NPast -= NDiscard; if (GraphRef.EnableLog) { spdlog::info( - "[WASI-NN] GGML backend: After context swapping: NPast = {}"sv, - NPast); + "[WASI-NN] GGML backend: the context if full ({} / {} tokens)"sv, + NPast + static_cast(Embd.size()), NCtx); } + break; } // Evaluate tokens in batches. From d704b5329c2fe7c87827cfd16396ca42616e9a85 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 30 Nov 2023 11:25:45 +0800 Subject: [PATCH 191/623] [WASI-NN] ggml: update ggml to b1575 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index c6ccb81a..e7a03946 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -47,7 +47,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b1550 + GIT_TAG b1575 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW TRUE ) @@ -141,6 +141,13 @@ target_include_directories(wasmedgePluginWasiNN if(BACKEND STREQUAL "ggml") target_include_directories(wasmedgePluginWasiNN PUBLIC ${CMAKE_BINARY_DIR}/_deps/llama-src) target_link_libraries(wasmedgePluginWasiNN PRIVATE common simdjson) + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) + add_custom_command( + TARGET wasmedgePluginWasiNN + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_BINARY_DIR}/_deps/llama-src/ggml-metal.metal ggml-metal.metal + ) + endif() endif() if(WASMEDGE_LINK_PLUGINS_STATIC) From 33d763b40ae60cd824633b8d0238014deb7ee8e5 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 4 Dec 2023 22:45:19 +0800 Subject: [PATCH 192/623] [WASI-NN] ggml: add repeat-penalty and temp Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 20 ++++++++++++++++++++ plugins/wasi_nn/ggml.h | 3 +++ 2 files changed, 23 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 87fa44c3..8e9f0895 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -104,6 +104,24 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } + // The sampling parameters. + if (Doc.at_key("temp").error() == simdjson::SUCCESS) { + auto Err = Doc["temp"].get().get(GraphRef.Temp); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the temp option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Temp = std::max(0.0, GraphRef.Temp); + } + if (Doc.at_key("repeat-penalty").error() == simdjson::SUCCESS) { + auto Err = Doc["repeat-penalty"].get().get(GraphRef.RepeatPenalty); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the repeat-penalty option."sv); + return ErrNo::InvalidArgument; + } + } // Check if the model is updated. if (IsModelUpdated && ModelParams.n_gpu_layers != GraphRef.NGPULayers) { @@ -376,6 +394,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Main predict loop. gpt_params GPTParams; + GPTParams.sparams.temp = GraphRef.Temp; + GPTParams.sparams.penalty_repeat = GraphRef.RepeatPenalty; struct llama_sampling_context *CtxSampling = llama_sampling_init(GPTParams.sparams); std::vector Embd; diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index dfcd285b..23139e97 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -30,6 +30,9 @@ struct Graph { // Context parameters: uint64_t CtxSize; uint64_t BatchSize; + // Sampleing parameters: + double Temp; + double RepeatPenalty; }; struct Context { From 7080c2ac94249b69d94a12b7a4661a6ec89ca700 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 6 Dec 2023 16:05:02 +0800 Subject: [PATCH 193/623] [WASI-NN] ggml: init repeat-penalty and temp Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 8e9f0895..b6d7a8b6 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -207,6 +207,10 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize the context parameters. GraphRef.CtxSize = ContextDefault.n_ctx; GraphRef.BatchSize = ContextDefault.n_batch; + // Initialize the sampling parameters. + llama_sampling_params SamplingDefault; + GraphRef.Temp = SamplingDefault.temp; + GraphRef.RepeatPenalty = SamplingDefault.penalty_repeat; // If the graph builder length > 1, the data of builder[1] is the metadata. if (Builders.size() > 1) { From 7e738e66fd51230476a19947690abd7ffe93d397 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 6 Dec 2023 16:05:56 +0800 Subject: [PATCH 194/623] [WASI-NN] ggml: do not save llama_context, reset llama_context every time Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 102 ++++++++++++++++++--------------------- plugins/wasi_nn/ggml.h | 1 - 2 files changed, 48 insertions(+), 55 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index b6d7a8b6..b32fe9d3 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -16,7 +16,6 @@ namespace WasmEdge::Host::WASINN::GGML { namespace details { Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, - bool *IsCxtUpdated = nullptr, bool *IsModelUpdated = nullptr) noexcept { simdjson::dom::parser Parser; simdjson::dom::element Doc; @@ -36,9 +35,6 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // Get the current llama parameters. llama_model_params ModelParams = llama_model_default_params(); ModelParams.n_gpu_layers = GraphRef.NGPULayers; - llama_context_params CxtParams = llama_context_default_params(); - CxtParams.n_ctx = GraphRef.CtxSize; - CxtParams.n_batch = GraphRef.BatchSize; // The plugin parameters. if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { @@ -128,12 +124,6 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, *IsModelUpdated = true; } - // Check if the context is updated. - if (IsCxtUpdated && (CxtParams.n_ctx != GraphRef.CtxSize || - CxtParams.n_batch != GraphRef.BatchSize)) { - *IsCxtUpdated = true; - } - return ErrNo::Success; } @@ -292,56 +282,57 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - bool IsCxtParamsUpdated = false; bool IsModelParamsUpdated = false; // Use index 1 for metadata. if (Index == 1) { const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); - return details::parseMetadata(GraphRef, Metadata, &IsCxtParamsUpdated, - &IsModelParamsUpdated); - } + auto Res = + details::parseMetadata(GraphRef, Metadata, &IsModelParamsUpdated); - // XXX: Due to the limitation of WASI-NN proposal, - // we have no way to pass the metadata before the setInput phase - // when we want to do some configurations in the load phase. - // That's why we have this hack. + if (Res != ErrNo::Success) { + spdlog::error("[WASI-NN] GGML backend: Failed to parse metadata."sv); + return Res; + } + + // XXX: Due to the limitation of WASI-NN proposal, + // we have no way to pass the metadata before the setInput phase + // when we want to do some configurations in the load phase. + // That's why we have this hack. #ifndef __APPLE__ - { - if (IsModelParamsUpdated) { - llama_model_params ModelParams = llama_model_default_params(); - ModelParams.n_gpu_layers = GraphRef.NGPULayers; - llama_free_model(GraphRef.LlamaModel); - GraphRef.LlamaModel = llama_load_model_from_file( - GraphRef.ModelFilePath.c_str(), ModelParams); - if (GraphRef.LlamaModel == nullptr) { - spdlog::error("[WASI-NN] GGML backend: Error: unable to init model."sv); - Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; + { + if (IsModelParamsUpdated) { + llama_model_params ModelParams = llama_model_default_params(); + ModelParams.n_gpu_layers = GraphRef.NGPULayers; + llama_free_model(GraphRef.LlamaModel); + GraphRef.LlamaModel = llama_load_model_from_file( + GraphRef.ModelFilePath.c_str(), ModelParams); + if (GraphRef.LlamaModel == nullptr) { + spdlog::error( + "[WASI-NN] GGML backend: Error: unable to init model."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } } } - } #endif - // Initialize the llama context. - if (CxtRef.LlamaContext == nullptr || IsCxtParamsUpdated) { - llama_context_params ContextParams = llama_context_default_params(); - ContextParams.n_ctx = GraphRef.CtxSize; - ContextParams.n_batch = GraphRef.BatchSize; - if (CxtRef.LlamaContext != nullptr) { - llama_free(CxtRef.LlamaContext); - } - CxtRef.LlamaContext = - llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); + return ErrNo::Success; } + // Initialize the llama context. + llama_context_params ContextParams = llama_context_default_params(); + ContextParams.n_ctx = GraphRef.CtxSize; + ContextParams.n_batch = GraphRef.BatchSize; + auto LlamaContext = + llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); + // Set the input. const bool AddBos = llama_should_add_bos_token(GraphRef.LlamaModel); std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); - CxtRef.LlamaInputs = - llama_tokenize(CxtRef.LlamaContext, Prompt, AddBos, true); - const uint32_t MaxContextSize = llama_n_ctx(CxtRef.LlamaContext); + CxtRef.LlamaInputs = llama_tokenize(LlamaContext, Prompt, AddBos, true); + const uint32_t MaxContextSize = llama_n_ctx(LlamaContext); // Minus 4 for the special tokens. const uint32_t MaxTokensListSize = MaxContextSize - 4; if (CxtRef.LlamaInputs.size() > MaxTokensListSize) { @@ -350,6 +341,10 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, CxtRef.LlamaInputs.size(), MaxTokensListSize); return ErrNo::InvalidArgument; } + + // Delete the llama context. + llama_free(LlamaContext); + return ErrNo::Success; } @@ -410,9 +405,9 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = GraphRef.CtxSize; ContextParams.n_batch = GraphRef.BatchSize; - CxtRef.LlamaContext = + auto LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); - int NCtx = llama_n_ctx(CxtRef.LlamaContext); + int NCtx = llama_n_ctx(LlamaContext); // Minus 4 for the special tokens. (Such as , , ... tokens.) const int MaxTokensListSize = NCtx - 4; // Use the const sequence id here. @@ -449,7 +444,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // This will return batch for single sequence of tokens starting at // position. if (llama_decode( - CxtRef.LlamaContext, + LlamaContext, llama_batch_get_one(&Embd[I], NEval, NPast, SequenceId))) { spdlog::error("[WASI-NN] GGML backend: failed to llama_decode"sv); return ErrNo::RuntimeError; @@ -463,17 +458,16 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { if (static_cast(CxtRef.LlamaInputs.size()) <= NConsumed) { const llama_token Id = - llama_sampling_sample(CtxSampling, CxtRef.LlamaContext, nullptr); - llama_sampling_accept(CtxSampling, CxtRef.LlamaContext, Id, true); + llama_sampling_sample(CtxSampling, LlamaContext, nullptr); + llama_sampling_accept(CtxSampling, LlamaContext, Id, true); Embd.emplace_back(Id); --NRemain; // Save the output token. CxtRef.LlamaOutputTokens.emplace_back(Id); - CxtRef.LlamaOutputs += llama_token_to_piece(CxtRef.LlamaContext, Id); + CxtRef.LlamaOutputs += llama_token_to_piece(LlamaContext, Id); // When setting StreamStdout, we print the output to stdout. if (GraphRef.StreamStdout) { - std::cout << llama_token_to_piece(CxtRef.LlamaContext, Id) - << std::flush; + std::cout << llama_token_to_piece(LlamaContext, Id) << std::flush; } // Break if reverse prompt is found. if (!GraphRef.ReversePrompt.empty() && @@ -496,7 +490,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { while (static_cast(CxtRef.LlamaInputs.size()) > NConsumed) { Embd.push_back(CxtRef.LlamaInputs[NConsumed]); // Push the prompt in the sampling context. - llama_sampling_accept(CtxSampling, CxtRef.LlamaContext, + llama_sampling_accept(CtxSampling, LlamaContext, CxtRef.LlamaInputs[NConsumed], false); ++NConsumed; if (Embd.size() >= GraphRef.BatchSize) { @@ -507,13 +501,13 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } if (GraphRef.EnableLog) { - llama_print_timings(CxtRef.LlamaContext); + llama_print_timings(LlamaContext); } // We free the contexts here to keep the ggml plugin stateless. // Users could fully control the contexts by themselves via their prompt. llama_sampling_free(CtxSampling); - llama_free(CxtRef.LlamaContext); + llama_free(LlamaContext); return ErrNo::Success; } diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 23139e97..9b55937f 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -39,7 +39,6 @@ struct Context { public: Context(size_t GId, Graph &) noexcept : GraphId(GId) {} size_t GraphId; - llama_context *LlamaContext = nullptr; std::vector LlamaInputs; std::string LlamaOutputs; std::vector LlamaOutputTokens; From 61b3ad42920548dba0a58f63f570d3bdf3f4fbd2 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 6 Dec 2023 16:06:21 +0800 Subject: [PATCH 195/623] [WASI-NN] ggml: update ggml to b1616 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index e7a03946..6db7ff75 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -47,7 +47,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b1575 + GIT_TAG b1616 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW TRUE ) From 30890ec486e5a0ed8aab025efb58250148f57e7b Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 12 Dec 2023 16:07:34 +0800 Subject: [PATCH 196/623] [WASI-NN] ggml backend: add more debug log (#3089) Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 87 ++++++++++++++++++++++++++++++++++++++++ plugins/wasi_nn/ggml.h | 1 + 2 files changed, 88 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index b32fe9d3..4eb233b7 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -46,6 +46,14 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } llama_log_set(nullptr, &GraphRef.EnableLog); } + if (Doc.at_key("enable-debug-log").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-debug-log"].get().get(GraphRef.EnableDebugLog); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the enable-debug-log option."sv); + return ErrNo::InvalidArgument; + } + } if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { auto Err = Doc["stream-stdout"].get().get(GraphRef.StreamStdout); if (Err) { @@ -215,6 +223,9 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, } } + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: Handling model path."sv); + } // Handle the model path. auto Weight = Builders[0]; std::string BinModel(reinterpret_cast(Weight.data()), Weight.size()); @@ -231,6 +242,10 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, return Res; } } else { + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: Model path not found in nn-preload, write model into a tmpfile."sv); + } // TODO: pass the model directly to ggml // Write ggml model to file. std::istringstream BinRead(BinModel); @@ -247,8 +262,20 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, } TempFile << BinModel; TempFile.close(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: Write model into a tmpfile...Done"sv); + } + } + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: Finished handling model path."sv); } + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: Initialize ggml model with given parameters"sv); + } // Initialize ggml model with model parameters. GraphRef.ModelFilePath = ModelFilePath; llama_model_params ModelParams = llama_model_default_params(); @@ -260,6 +287,10 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: Initialize ggml model with given parameters...Done"sv); + } // Store the loaded graph. GraphId = Env.NNGraph.size() - 1; @@ -281,10 +312,17 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t Index, const TensorData &Tensor) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: setInput"sv); + } bool IsModelParamsUpdated = false; // Use index 1 for metadata. if (Index == 1) { + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: found Metadata, processing"sv); + } const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); auto Res = @@ -317,17 +355,30 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } #endif + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: found Metadata, processing...Done"sv); + } return ErrNo::Success; } // Initialize the llama context. + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: init llama context"sv); + } llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = GraphRef.CtxSize; ContextParams.n_batch = GraphRef.BatchSize; auto LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: init llama context...Done"sv); + } // Set the input. + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: set the input"sv); + } const bool AddBos = llama_should_add_bos_token(GraphRef.LlamaModel); std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); @@ -341,10 +392,24 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, CxtRef.LlamaInputs.size(), MaxTokensListSize); return ErrNo::InvalidArgument; } + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: set the input...Done"sv); + } // Delete the llama context. + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: delete llama context to make it stateless"sv); + } llama_free(LlamaContext); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: delete llama context to make it stateless...Done"sv); + } + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: setInput...Done"sv); + } return ErrNo::Success; } @@ -377,6 +442,9 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: compute"sv); + } if (CxtRef.LlamaInputs.size() == 0) { spdlog::error("[WASI-NN] GGML backend: Llama input is not set!"sv); return ErrNo::InvalidArgument; @@ -388,10 +456,21 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } // Clear the outputs. + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: clear the previous output and tokens"sv); + } CxtRef.LlamaOutputs.clear(); CxtRef.LlamaOutputTokens.clear(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: clear the previous output and tokens...Done"sv); + } // Main predict loop. + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: enter main predict loop"sv); + } gpt_params GPTParams; GPTParams.sparams.temp = GraphRef.Temp; GPTParams.sparams.penalty_repeat = GraphRef.RepeatPenalty; @@ -499,6 +578,10 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } } } + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: enter main predict loop...Done"sv); + } if (GraphRef.EnableLog) { llama_print_timings(LlamaContext); @@ -509,6 +592,10 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { llama_sampling_free(CtxSampling); llama_free(LlamaContext); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: compute...Done"sv); + } + return ErrNo::Success; } #else diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 9b55937f..e5c7981c 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -22,6 +22,7 @@ struct Graph { std::string ModelFilePath; // Plugin parameters: bool EnableLog; + bool EnableDebugLog; bool StreamStdout; uint64_t NPredict; std::string ReversePrompt; From 73cb0fc81d4915033b1959da287457a8f41af464 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Fri, 10 Mar 2023 03:37:55 +0800 Subject: [PATCH 197/623] [Common] Refactor the `WasmEdge::ValType` and `WasmEdge::RefType` class. Concepts: Refine the `ValType`, `RefType`, and `BlockType` into 8-bytes length. 1. Implement the `HeapType` built-in to the `ValType` and `RefType`. 2. Rename the enum class into `WasmEdge::ValTypeCode` and `WasmEdge::RefTypeCode`. 3. Rename the enum into `WasmEdge_ValTypeCode` and `WasmEdge_RefTypeCode`. 4. The `RefType` is subset of the `ValType`. 5. Refactor the `WasmEdge_ValType` in C API, which is a struct contains the opaque data in 8-bytes. 6. The data in the `WasmEdge_ValType` is copied from the internal `WasmEdge::ValType`, and vise versa. Signed-off-by: YiYing He --- test/plugins/unittest/testplugin.c | 8 +- test/plugins/wasm_bpf/simple_map_test.cpp | 521 +++++++++--------- test/plugins/wasm_bpf/simple_ringbuf_test.cpp | 427 +++++++------- test/plugins/wasm_bpf/wasm_bpf.cpp | 5 +- 4 files changed, 480 insertions(+), 481 deletions(-) diff --git a/test/plugins/unittest/testplugin.c b/test/plugins/unittest/testplugin.c index 8239f92b..a0536101 100644 --- a/test/plugins/unittest/testplugin.c +++ b/test/plugins/unittest/testplugin.c @@ -67,10 +67,10 @@ CreateTestModule(const struct WasmEdge_ModuleDescriptor *Desc) { WasmEdge_String FuncName; WasmEdge_FunctionTypeContext *FType; WasmEdge_FunctionInstanceContext *FuncCxt; - enum WasmEdge_ValType ParamTypes[2], ReturnTypes[1]; - ParamTypes[0] = WasmEdge_ValType_I32; - ParamTypes[1] = WasmEdge_ValType_I32; - ReturnTypes[0] = WasmEdge_ValType_I32; + WasmEdge_ValType ParamTypes[2], ReturnTypes[1]; + ParamTypes[0] = WasmEdge_ValTypeGenI32(); + ParamTypes[1] = WasmEdge_ValTypeGenI32(); + ReturnTypes[0] = WasmEdge_ValTypeGenI32(); /* Create the "add" function and add into the module instance. */ FType = WasmEdge_FunctionTypeCreate(ParamTypes, 2, ReturnTypes, 1); diff --git a/test/plugins/wasm_bpf/simple_map_test.cpp b/test/plugins/wasm_bpf/simple_map_test.cpp index e902bbc2..9a86892c 100644 --- a/test/plugins/wasm_bpf/simple_map_test.cpp +++ b/test/plugins/wasm_bpf/simple_map_test.cpp @@ -1,13 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2022 Second State INC -#include -#include -#include -#include -#include -#include -#include #include "common/defines.h" #include "executor/executor.h" #include "func-attach-bpf-program.h" @@ -18,32 +11,40 @@ #include "plugin/plugin.h" #include "runtime/instance/module.h" #include "wasm-bpf-module.h" + +#include +#include +#include +#include +#include +#include +#include + namespace { -WasmEdge::Runtime::Instance::ModuleInstance* createModule() { - using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasm_bpf/" - "libwasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); - if (const auto* Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { - if (const auto* Module = Plugin->findModule("wasm_bpf"sv)) { - return Module->create().release(); - } +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasm_bpf/" + "libwasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { + if (const auto *Module = Plugin->findModule("wasm_bpf"sv)) { + return Module->create().release(); } - return nullptr; + } + return nullptr; } std::filesystem::path getAssertsPath() { - std::filesystem::path thisFile(__FILE__); - return thisFile.parent_path() / "assets"; + std::filesystem::path thisFile(__FILE__); + return thisFile.parent_path() / "assets"; } -void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance& memInst, - uint32_t offset, - const std::vector& data) noexcept { - char* buf = memInst.getPointer(offset); - std::copy(data.begin(), data.end(), buf); +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &memInst, + uint32_t offset, const std::vector &data) noexcept { + char *buf = memInst.getPointer(offset); + std::copy(data.begin(), data.end(), buf); } -} // namespace +} // namespace static const uint32_t INDICATING_KEY = 0xABCD; static const uint32_t ADD_VALUE_1_KEY = 0xCDEF; @@ -51,241 +52,241 @@ static const uint32_t ADD_VALUE_2_KEY = 0x1234; static const uint32_t RESULT_VALUE_KEY = 0x7890; TEST(WasmBpfTest, SimpleMapTest) { - using namespace std::string_view_literals; - // Test loading and attaching a bpf program, and some operations of maps - auto module = dynamic_cast(createModule()); - ASSERT_NE(module, nullptr); - - // Create the calling frame with memory instance. - WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); - // moduleInst.addHostFunc() - moduleInst.addHostMemory( - "memory", std::make_unique( - WasmEdge::AST::MemoryType(1))); - auto* memoryInst = moduleInst.findMemoryExports("memory"); - ASSERT_NE(memoryInst, nullptr); - auto& memoryInstRef = *memoryInst; - WasmEdge::Executor::Executor executor((WasmEdge::Configure())); - WasmEdge::Runtime::CallingFrame CallFrame(&executor, &moduleInst); - - namespace fs = std::filesystem; - auto bpfObject = getAssertsPath() / "simple_map.bpf.o"; - - // Ensure the bpf object we need exists - ASSERT_TRUE(fs::exists(bpfObject)); - - // Read the bpf object into wasm memory - std::ifstream bpfObjStream(bpfObject); - ASSERT_TRUE(bpfObjStream.is_open()); - ASSERT_TRUE(bpfObjStream.good()); - std::vector bpfObjectBytes( - (std::istreambuf_iterator(bpfObjStream)), - std::istreambuf_iterator()); - ASSERT_FALSE(bpfObjectBytes.empty()); - // Offset to put things into memory - uint32_t nextOffset = 1; - - // Put the bpf object into memory - const uint32_t bpfObjectMemoryOffset = nextOffset; - fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); - nextOffset += static_cast(bpfObjectBytes.size()); - - // Fill strings that will be used into memory - std::array strings = { - "test_map", // Map name - "sched_wakeup", // Program names - "" // An empty string - }; - std::array stringOffsets; - - for (size_t i = 0; i < strings.size(); i++) { - std::string currString(strings[i]); - std::vector bytes(currString.begin(), currString.end()); - // Ensure that strings are zero-terminated - bytes.push_back('\0'); - fillMemContent(memoryInstRef, nextOffset, bytes); - stringOffsets[i] = nextOffset; - nextOffset += static_cast(bytes.size()); - } - - // Get function "wasm_load_bpf_object" - auto* loadFunc = module->findFuncExports("wasm_load_bpf_object"); - ASSERT_NE(loadFunc, nullptr); - ASSERT_TRUE(loadFunc->isHostFunction()); - auto& loadFuncHost = - dynamic_cast(loadFunc->getHostFunc()); - - // call "wasm_load_bpf_object" to Load `bootstrap.bpf.o`, and check the - // result - std::array loadResult; - ASSERT_TRUE(loadFuncHost.run( + using namespace std::string_view_literals; + // Test loading and attaching a bpf program, and some operations of maps + auto module = dynamic_cast(createModule()); + ASSERT_NE(module, nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); + // moduleInst.addHostFunc() + moduleInst.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *memoryInst = moduleInst.findMemoryExports("memory"); + ASSERT_NE(memoryInst, nullptr); + auto &memoryInstRef = *memoryInst; + WasmEdge::Executor::Executor executor((WasmEdge::Configure())); + WasmEdge::Runtime::CallingFrame CallFrame(&executor, &moduleInst); + + namespace fs = std::filesystem; + auto bpfObject = getAssertsPath() / "simple_map.bpf.o"; + + // Ensure the bpf object we need exists + ASSERT_TRUE(fs::exists(bpfObject)); + + // Read the bpf object into wasm memory + std::ifstream bpfObjStream(bpfObject); + ASSERT_TRUE(bpfObjStream.is_open()); + ASSERT_TRUE(bpfObjStream.good()); + std::vector bpfObjectBytes( + (std::istreambuf_iterator(bpfObjStream)), + std::istreambuf_iterator()); + ASSERT_FALSE(bpfObjectBytes.empty()); + // Offset to put things into memory + uint32_t nextOffset = 1; + + // Put the bpf object into memory + const uint32_t bpfObjectMemoryOffset = nextOffset; + fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); + nextOffset += static_cast(bpfObjectBytes.size()); + + // Fill strings that will be used into memory + std::array strings = { + "test_map", // Map name + "sched_wakeup", // Program names + "" // An empty string + }; + std::array stringOffsets; + + for (size_t i = 0; i < strings.size(); i++) { + std::string currString(strings[i]); + std::vector bytes(currString.begin(), currString.end()); + // Ensure that strings are zero-terminated + bytes.push_back('\0'); + fillMemContent(memoryInstRef, nextOffset, bytes); + stringOffsets[i] = nextOffset; + nextOffset += static_cast(bytes.size()); + } + + // Get function "wasm_load_bpf_object" + auto *loadFunc = module->findFuncExports("wasm_load_bpf_object"); + ASSERT_NE(loadFunc, nullptr); + ASSERT_TRUE(loadFunc->isHostFunction()); + auto &loadFuncHost = + dynamic_cast(loadFunc->getHostFunc()); + + // call "wasm_load_bpf_object" to Load `bootstrap.bpf.o`, and check the + // result + std::array loadResult; + ASSERT_TRUE(loadFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(bpfObjectMemoryOffset), + WasmEdge::ValVariant(static_cast(bpfObjectBytes.size()))}, + loadResult)); + auto handle = loadResult[0].get(); + ASSERT_NE(handle, 0); + + // Get function `wasm_attach_bpf_program` + auto *attachFunc = module->findFuncExports("wasm_attach_bpf_program"); + ASSERT_NE(attachFunc, nullptr); + ASSERT_TRUE(attachFunc->isHostFunction()); + auto &attachFuncHost = dynamic_cast( + attachFunc->getHostFunc()); + + // Call "wasm_attach_bpf_program" to attach, and check the result + std::array attachResult; + ASSERT_TRUE(attachFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(stringOffsets[1]), + // There should be '\0' + WasmEdge::ValVariant(stringOffsets[2]), + }, + attachResult)); + ASSERT_GE(attachResult[0].get(), 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto *mapFdFunc = module->findFuncExports("wasm_bpf_map_fd_by_name"); + ASSERT_NE(mapFdFunc, nullptr); + ASSERT_TRUE(mapFdFunc->isHostFunction()); + auto &mapFdFuncHost = + dynamic_cast(mapFdFunc->getHostFunc()); + + // Call "wasm_bpf_map_fd_by_name" to get the map fd, and check the result + std::array mapFdResult; + ASSERT_TRUE(mapFdFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), WasmEdge::ValVariant(stringOffsets[0])}, + mapFdResult)); + auto mapFd = mapFdResult[0].get(); + ASSERT_GE(mapFd, 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto *mapOptFunc = module->findFuncExports("wasm_bpf_map_operate"); + EXPECT_NE(mapOptFunc, nullptr); + EXPECT_TRUE(mapOptFunc->isHostFunction()); + auto &mapOptFuncHost = + dynamic_cast(mapOptFunc->getHostFunc()); + + // A wrapper to call wasm_bpf_map_operate + auto callMapOperate = [&](int32_t fd, int32_t cmd, uint32_t key, + uint32_t value, uint32_t nextKey, + uint64_t flags) -> int32_t { + std::array callResult; + EXPECT_TRUE(mapOptFuncHost.run( CallFrame, std::initializer_list{ - WasmEdge::ValVariant(bpfObjectMemoryOffset), - WasmEdge::ValVariant(static_cast(bpfObjectBytes.size()))}, - loadResult)); - auto handle = loadResult[0].get(); - ASSERT_NE(handle, 0); - - // Get function `wasm_attach_bpf_program` - auto* attachFunc = module->findFuncExports("wasm_attach_bpf_program"); - ASSERT_NE(attachFunc, nullptr); - ASSERT_TRUE(attachFunc->isHostFunction()); - auto& attachFuncHost = dynamic_cast( - attachFunc->getHostFunc()); - - // Call "wasm_attach_bpf_program" to attach, and check the result - std::array attachResult; - ASSERT_TRUE(attachFuncHost.run(CallFrame, - std::initializer_list{ - WasmEdge::ValVariant(handle), - WasmEdge::ValVariant(stringOffsets[1]), - // There should be '\0' - WasmEdge::ValVariant(stringOffsets[2]), - }, - attachResult)); - ASSERT_GE(attachResult[0].get(), 0); - - // Get function `wasm_bpf_map_fd_by_name` - auto* mapFdFunc = module->findFuncExports("wasm_bpf_map_fd_by_name"); - ASSERT_NE(mapFdFunc, nullptr); - ASSERT_TRUE(mapFdFunc->isHostFunction()); - auto& mapFdFuncHost = - dynamic_cast(mapFdFunc->getHostFunc()); - - // Call "wasm_bpf_map_fd_by_name" to get the map fd, and check the result - std::array mapFdResult; - ASSERT_TRUE(mapFdFuncHost.run(CallFrame, - std::initializer_list{ - WasmEdge::ValVariant(handle), - WasmEdge::ValVariant(stringOffsets[0])}, - mapFdResult)); - auto mapFd = mapFdResult[0].get(); - ASSERT_GE(mapFd, 0); - - // Get function `wasm_bpf_map_fd_by_name` - auto* mapOptFunc = module->findFuncExports("wasm_bpf_map_operate"); - EXPECT_NE(mapOptFunc, nullptr); - EXPECT_TRUE(mapOptFunc->isHostFunction()); - auto& mapOptFuncHost = - dynamic_cast(mapOptFunc->getHostFunc()); - - // A wrapper to call wasm_bpf_map_operate - auto callMapOperate = [&](int32_t fd, int32_t cmd, uint32_t key, - uint32_t value, uint32_t nextKey, - uint64_t flags) -> int32_t { - std::array callResult; - EXPECT_TRUE(mapOptFuncHost.run( - CallFrame, - std::initializer_list{ - WasmEdge::ValVariant(fd), WasmEdge::ValVariant(cmd), - WasmEdge::ValVariant(key), WasmEdge::ValVariant(value), - WasmEdge::ValVariant(nextKey), WasmEdge::ValVariant(flags)}, - callResult)); - return callResult[0].get(); - }; - - auto mapLookupElem = [&](int32_t fd, uint32_t key, - uint32_t valueOut) -> int32_t { - // key found -> returns 0 - // key not found -> returns -1 - return callMapOperate(fd, - 1, // BPF_MAP_LOOKUP_ELEM - key, valueOut, 0, 0); - }; - - auto mapUpdateElem = [&](int32_t fd, uint32_t key, - uint32_t value) -> int32_t { - return callMapOperate(fd, - 2, // BPF_MAP_UPDATE_ELEM - key, value, 0, 0); - }; - - // Helper functions to make read & write more convenient - auto readU64 = [&](uint32_t offset) -> uint64_t { - const auto* ptr = memoryInstRef.getPointer(offset); - EXPECT_NE(ptr, nullptr); - return *ptr; - }; - auto writeU64 = [&](uint32_t offset, uint64_t val) { - auto* ptr = memoryInstRef.getPointer(offset); - EXPECT_NE(ptr, nullptr); - *ptr = val; - }; - - auto writeU32 = [&](uint32_t offset, uint32_t val) { - auto* ptr = memoryInstRef.getPointer(offset); - EXPECT_NE(ptr, nullptr); - *ptr = val; - }; - - // Generate two numbers, which will be stored in the map and calculated the - // summation by the ebpf program - std::mt19937 randGen; - randGen.seed(std::random_device()()); - std::uniform_int_distribution intDist( - 0, static_cast(1e16)); - uint64_t num1 = intDist(randGen); - uint64_t num2 = intDist(randGen); - - // Prepare for wasm memory which is used to store numbers - const uint32_t numOffset1 = nextOffset; - nextOffset += 8; - const uint32_t numOffset2 = nextOffset; - nextOffset += 8; - const uint32_t resultOffset = nextOffset; - nextOffset += 8; - const uint32_t indicatingKeyOffset = nextOffset; - nextOffset += 4; - const uint32_t num1KeyOffset = nextOffset; - nextOffset += 4; - const uint32_t num2KeyOffset = nextOffset; - nextOffset += 4; - const uint32_t resultKeyOffset = nextOffset; - nextOffset += 4; - - writeU32(num1KeyOffset, ADD_VALUE_1_KEY); - writeU32(num2KeyOffset, ADD_VALUE_2_KEY); - writeU32(resultKeyOffset, RESULT_VALUE_KEY); - writeU32(indicatingKeyOffset, INDICATING_KEY); - - writeU32(INDICATING_KEY, indicatingKeyOffset); - - writeU64(numOffset1, num1); - writeU64(numOffset2, num2); - - writeU64(resultOffset, 0); - - // Write the add values into the map - ASSERT_EQ(mapUpdateElem(mapFd, num1KeyOffset, numOffset1), 0); - ASSERT_EQ(mapUpdateElem(mapFd, num2KeyOffset, numOffset2), 0); - - // Write the indicating key - // Arbitrary values are correct. We only care the existence of the - // indicating key - ASSERT_EQ(mapUpdateElem(mapFd, indicatingKeyOffset, numOffset1), 0); - - // Sleep for 1s and wait for the ebpf program to process.. - std::this_thread::sleep_for(std::chrono::seconds(2)); - - // Read the result and check it - ASSERT_EQ(mapLookupElem(mapFd, resultKeyOffset, resultOffset), 0); - uint64_t addResult = readU64(resultOffset); - ASSERT_EQ(addResult, num1 + num2); - - // Get function `wasm_close_bpf_object` - auto* closeFunc = module->findFuncExports("wasm_close_bpf_object"); - ASSERT_NE(closeFunc, nullptr); - ASSERT_TRUE(closeFunc->isHostFunction()); - auto& closeFuncHost = - dynamic_cast(closeFunc->getHostFunc()); - - // Call "wasm_close_bpf_object" to attach, and check the result - std::array closeResult; - ASSERT_TRUE(closeFuncHost.run(CallFrame, - std::initializer_list{ - WasmEdge::ValVariant(handle), - }, - closeResult)); - ASSERT_EQ(closeResult[0].get(), 0); + WasmEdge::ValVariant(fd), WasmEdge::ValVariant(cmd), + WasmEdge::ValVariant(key), WasmEdge::ValVariant(value), + WasmEdge::ValVariant(nextKey), WasmEdge::ValVariant(flags)}, + callResult)); + return callResult[0].get(); + }; + + auto mapLookupElem = [&](int32_t fd, uint32_t key, + uint32_t valueOut) -> int32_t { + // key found -> returns 0 + // key not found -> returns -1 + return callMapOperate(fd, + 1, // BPF_MAP_LOOKUP_ELEM + key, valueOut, 0, 0); + }; + + auto mapUpdateElem = [&](int32_t fd, uint32_t key, + uint32_t value) -> int32_t { + return callMapOperate(fd, + 2, // BPF_MAP_UPDATE_ELEM + key, value, 0, 0); + }; + + // Helper functions to make read & write more convenient + auto readU64 = [&](uint32_t offset) -> uint64_t { + const auto *ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + return *ptr; + }; + auto writeU64 = [&](uint32_t offset, uint64_t val) { + auto *ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + *ptr = val; + }; + + auto writeU32 = [&](uint32_t offset, uint32_t val) { + auto *ptr = memoryInstRef.getPointer(offset); + EXPECT_NE(ptr, nullptr); + *ptr = val; + }; + + // Generate two numbers, which will be stored in the map and calculated the + // summation by the ebpf program + std::mt19937 randGen; + randGen.seed(std::random_device()()); + std::uniform_int_distribution intDist(0, + static_cast(1e16)); + uint64_t num1 = intDist(randGen); + uint64_t num2 = intDist(randGen); + + // Prepare for wasm memory which is used to store numbers + const uint32_t numOffset1 = nextOffset; + nextOffset += 8; + const uint32_t numOffset2 = nextOffset; + nextOffset += 8; + const uint32_t resultOffset = nextOffset; + nextOffset += 8; + const uint32_t indicatingKeyOffset = nextOffset; + nextOffset += 4; + const uint32_t num1KeyOffset = nextOffset; + nextOffset += 4; + const uint32_t num2KeyOffset = nextOffset; + nextOffset += 4; + const uint32_t resultKeyOffset = nextOffset; + nextOffset += 4; + + writeU32(num1KeyOffset, ADD_VALUE_1_KEY); + writeU32(num2KeyOffset, ADD_VALUE_2_KEY); + writeU32(resultKeyOffset, RESULT_VALUE_KEY); + writeU32(indicatingKeyOffset, INDICATING_KEY); + + writeU32(INDICATING_KEY, indicatingKeyOffset); + + writeU64(numOffset1, num1); + writeU64(numOffset2, num2); + + writeU64(resultOffset, 0); + + // Write the add values into the map + ASSERT_EQ(mapUpdateElem(mapFd, num1KeyOffset, numOffset1), 0); + ASSERT_EQ(mapUpdateElem(mapFd, num2KeyOffset, numOffset2), 0); + + // Write the indicating key + // Arbitrary values are correct. We only care the existence of the + // indicating key + ASSERT_EQ(mapUpdateElem(mapFd, indicatingKeyOffset, numOffset1), 0); + + // Sleep for 1s and wait for the ebpf program to process.. + std::this_thread::sleep_for(std::chrono::seconds(2)); + + // Read the result and check it + ASSERT_EQ(mapLookupElem(mapFd, resultKeyOffset, resultOffset), 0); + uint64_t addResult = readU64(resultOffset); + ASSERT_EQ(addResult, num1 + num2); + + // Get function `wasm_close_bpf_object` + auto *closeFunc = module->findFuncExports("wasm_close_bpf_object"); + ASSERT_NE(closeFunc, nullptr); + ASSERT_TRUE(closeFunc->isHostFunction()); + auto &closeFuncHost = + dynamic_cast(closeFunc->getHostFunc()); + + // Call "wasm_close_bpf_object" to attach, and check the result + std::array closeResult; + ASSERT_TRUE(closeFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + }, + closeResult)); + ASSERT_EQ(closeResult[0].get(), 0); } diff --git a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp index 9202d8f8..a46d2e73 100644 --- a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp +++ b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp @@ -1,9 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2022 Second State INC -#include -#include -#include #include "common/defines.h" #include "executor/executor.h" #include "func-attach-bpf-program.h" @@ -14,232 +11,234 @@ #include "plugin/plugin.h" #include "runtime/instance/module.h" #include "wasm-bpf-module.h" + +#include +#include +#include + namespace { -WasmEdge::Runtime::Instance::ModuleInstance* createModule() { - using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasm_bpf/" - "libwasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); - if (const auto* Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { - if (const auto* Module = Plugin->findModule("wasm_bpf"sv)) { - return Module->create().release(); - } +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasm_bpf/" + "libwasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { + if (const auto *Module = Plugin->findModule("wasm_bpf"sv)) { + return Module->create().release(); } - return nullptr; + } + return nullptr; } std::filesystem::path getAssertsPath() { - std::filesystem::path thisFile(__FILE__); - return thisFile.parent_path() / "assets"; + std::filesystem::path thisFile(__FILE__); + return thisFile.parent_path() / "assets"; } -void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance& memInst, - uint32_t offset, - const std::vector& data) noexcept { - char* buf = memInst.getPointer(offset); - std::copy(data.begin(), data.end(), buf); +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &memInst, + uint32_t offset, const std::vector &data) noexcept { + char *buf = memInst.getPointer(offset); + std::copy(data.begin(), data.end(), buf); } class PollCallbackFunction : public WasmEdge::Runtime::HostFunction { - public: - PollCallbackFunction() {} - WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame& Frame, - uint32_t __attribute__((unused)) ctx, - uint32_t data, - uint32_t data_sz) { - using namespace std; - using WasmEdge::unlikely; - auto* memory = Frame.getMemoryByIndex(0); - if (unlikely(!memory)) { - return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); - } - if (data_sz < static_cast(sizeof(uint32_t))) { - return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); - } - const uint32_t* dataPtr = memory->getSpan(data, 1).data(); - if (unlikely(!dataPtr)) { - return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); - } - EXPECT_EQ(*dataPtr, UINT32_C(0xABCD1234)); - return 0; +public: + PollCallbackFunction() {} + WasmEdge::Expect body(const WasmEdge::Runtime::CallingFrame &Frame, + uint32_t __attribute__((unused)) ctx, + uint32_t data, uint32_t data_sz) { + using namespace std; + using WasmEdge::unlikely; + auto *memory = Frame.getMemoryByIndex(0); + if (unlikely(!memory)) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + if (data_sz < static_cast(sizeof(uint32_t))) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); } + const uint32_t *dataPtr = memory->getSpan(data, 1).data(); + if (unlikely(!dataPtr)) { + return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); + } + EXPECT_EQ(*dataPtr, UINT32_C(0xABCD1234)); + return 0; + } }; -} // namespace +} // namespace TEST(WasmBpfTest, SimpleRingbuf) { - using namespace std::string_view_literals; - // Test loading and attaching a bpf program, and polling buffer - auto module = dynamic_cast(createModule()); - ASSERT_NE(module, nullptr); - - // Create the calling frame with memory instance. - WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); - // moduleInst.addHostFunc() - moduleInst.addHostMemory( - "memory", std::make_unique( - WasmEdge::AST::MemoryType(1))); - auto* memoryInst = moduleInst.findMemoryExports("memory"); - ASSERT_NE(memoryInst, nullptr); - auto& memoryInstRef = *memoryInst; - WasmEdge::Executor::Executor executor((WasmEdge::Configure())); - WasmEdge::Runtime::CallingFrame CallFrame(&executor, &moduleInst); - - namespace fs = std::filesystem; - auto bpfObject = getAssertsPath() / "simple_ringbuf.bpf.o"; - - // Ensure the bpf object we need exists - ASSERT_TRUE(fs::exists(bpfObject)); - - // Read the bpf object into wasm memory - std::ifstream bpfObjStream(bpfObject); - ASSERT_TRUE(bpfObjStream.is_open()); - ASSERT_TRUE(bpfObjStream.good()); - std::vector bpfObjectBytes( - (std::istreambuf_iterator(bpfObjStream)), - std::istreambuf_iterator()); - ASSERT_FALSE(bpfObjectBytes.empty()); - // Offset to put things into memory - uint32_t nextOffset = 1; - - // Put the bpf object into memory - const uint32_t bpfObjectMemoryOffset = nextOffset; - fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); - nextOffset += static_cast(bpfObjectBytes.size()); - - // Fill strings that will be used into memory - std::array strings = { - "rb", // Map name - "handle_exec", // Program names - "" // An empty string - }; - std::array stringOffsets; - - for (size_t i = 0; i < strings.size(); i++) { - std::string currString(strings[i]); - std::vector bytes(currString.begin(), currString.end()); - // Ensure that strings are zero-terminated - bytes.push_back('\0'); - fillMemContent(memoryInstRef, nextOffset, bytes); - stringOffsets[i] = nextOffset; - nextOffset += static_cast(bytes.size()); - } - - const uint32_t bufferPollMemoryOffset = nextOffset; - const uint32_t bufferPollSize = 256; - nextOffset += bufferPollSize; - - // Get function "wasm_load_bpf_object" - auto* loadFunc = module->findFuncExports("wasm_load_bpf_object"); - ASSERT_NE(loadFunc, nullptr); - ASSERT_TRUE(loadFunc->isHostFunction()); - auto& loadFuncHost = - dynamic_cast(loadFunc->getHostFunc()); - - // call "wasm_load_bpf_object" to Load `bootstrap.bpf.o`, and check the - // result - std::array loadResult; - ASSERT_TRUE(loadFuncHost.run( + using namespace std::string_view_literals; + // Test loading and attaching a bpf program, and polling buffer + auto module = dynamic_cast(createModule()); + ASSERT_NE(module, nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); + // moduleInst.addHostFunc() + moduleInst.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *memoryInst = moduleInst.findMemoryExports("memory"); + ASSERT_NE(memoryInst, nullptr); + auto &memoryInstRef = *memoryInst; + WasmEdge::Executor::Executor executor((WasmEdge::Configure())); + WasmEdge::Runtime::CallingFrame CallFrame(&executor, &moduleInst); + + namespace fs = std::filesystem; + auto bpfObject = getAssertsPath() / "simple_ringbuf.bpf.o"; + + // Ensure the bpf object we need exists + ASSERT_TRUE(fs::exists(bpfObject)); + + // Read the bpf object into wasm memory + std::ifstream bpfObjStream(bpfObject); + ASSERT_TRUE(bpfObjStream.is_open()); + ASSERT_TRUE(bpfObjStream.good()); + std::vector bpfObjectBytes( + (std::istreambuf_iterator(bpfObjStream)), + std::istreambuf_iterator()); + ASSERT_FALSE(bpfObjectBytes.empty()); + // Offset to put things into memory + uint32_t nextOffset = 1; + + // Put the bpf object into memory + const uint32_t bpfObjectMemoryOffset = nextOffset; + fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); + nextOffset += static_cast(bpfObjectBytes.size()); + + // Fill strings that will be used into memory + std::array strings = { + "rb", // Map name + "handle_exec", // Program names + "" // An empty string + }; + std::array stringOffsets; + + for (size_t i = 0; i < strings.size(); i++) { + std::string currString(strings[i]); + std::vector bytes(currString.begin(), currString.end()); + // Ensure that strings are zero-terminated + bytes.push_back('\0'); + fillMemContent(memoryInstRef, nextOffset, bytes); + stringOffsets[i] = nextOffset; + nextOffset += static_cast(bytes.size()); + } + + const uint32_t bufferPollMemoryOffset = nextOffset; + const uint32_t bufferPollSize = 256; + nextOffset += bufferPollSize; + + // Get function "wasm_load_bpf_object" + auto *loadFunc = module->findFuncExports("wasm_load_bpf_object"); + ASSERT_NE(loadFunc, nullptr); + ASSERT_TRUE(loadFunc->isHostFunction()); + auto &loadFuncHost = + dynamic_cast(loadFunc->getHostFunc()); + + // call "wasm_load_bpf_object" to Load `bootstrap.bpf.o`, and check the + // result + std::array loadResult; + ASSERT_TRUE(loadFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(bpfObjectMemoryOffset), + WasmEdge::ValVariant(static_cast(bpfObjectBytes.size()))}, + loadResult)); + auto handle = loadResult[0].get(); + ASSERT_NE(handle, 0); + + // Get function `wasm_attach_bpf_program` + auto *attachFunc = module->findFuncExports("wasm_attach_bpf_program"); + ASSERT_NE(attachFunc, nullptr); + ASSERT_TRUE(attachFunc->isHostFunction()); + auto &attachFuncHost = dynamic_cast( + attachFunc->getHostFunc()); + + // Call "wasm_attach_bpf_program" to attach, and check the result + std::array attachResult; + ASSERT_TRUE(attachFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + WasmEdge::ValVariant(stringOffsets[1]), + // There should be '\0' + WasmEdge::ValVariant(stringOffsets[2]), + }, + attachResult)); + ASSERT_GE(attachResult[0].get(), 0); + + // Get function `wasm_bpf_map_fd_by_name` + auto *mapFdFunc = module->findFuncExports("wasm_bpf_map_fd_by_name"); + ASSERT_NE(mapFdFunc, nullptr); + ASSERT_TRUE(mapFdFunc->isHostFunction()); + auto &mapFdFuncHost = + dynamic_cast(mapFdFunc->getHostFunc()); + + // Call "wasm_bpf_map_fd_by_name" to get the map fd, and check the result + std::array mapFdResult; + ASSERT_TRUE(mapFdFuncHost.run( + CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), WasmEdge::ValVariant(stringOffsets[0])}, + mapFdResult)); + auto mapFd = mapFdResult[0].get(); + ASSERT_GE(mapFd, 0); + + // In the following several steps we will prepare for polling + // Create an instance of the polling callback function + auto callbackFuncInst = + std::make_unique( + &moduleInst, std::make_unique()); + // Create a function table, and fill the callback function into it + auto funcTableInst = + std::make_unique( + WasmEdge::AST::TableType(WasmEdge::RefTypeCode::FuncRef, 1)); + ASSERT_TRUE(funcTableInst->setRefs( + std::initializer_list{callbackFuncInst.get()}, + 0, 0, 1)); + // Add the table to the main module + moduleInst.addHostTable("__indirect_function_table"sv, + std::move(funcTableInst)); + + // Get the "wasm_bpf_buffer_poll" function + auto *bufferPollFunc = module->findFuncExports("wasm_bpf_buffer_poll"); + ASSERT_NE(bufferPollFunc, nullptr); + ASSERT_TRUE(bufferPollFunc->isHostFunction()); + auto &bufferPollFuncHost = dynamic_cast( + bufferPollFunc->getHostFunc()); + + // Call the polling function + std::array pollResult; + for (size_t i = 1; i <= 50; i++) { + using namespace std; + ASSERT_TRUE(bufferPollFuncHost.run( CallFrame, std::initializer_list{ - WasmEdge::ValVariant(bpfObjectMemoryOffset), - WasmEdge::ValVariant(static_cast(bpfObjectBytes.size()))}, - loadResult)); - auto handle = loadResult[0].get(); - ASSERT_NE(handle, 0); - - // Get function `wasm_attach_bpf_program` - auto* attachFunc = module->findFuncExports("wasm_attach_bpf_program"); - ASSERT_NE(attachFunc, nullptr); - ASSERT_TRUE(attachFunc->isHostFunction()); - auto& attachFuncHost = dynamic_cast( - attachFunc->getHostFunc()); - - // Call "wasm_attach_bpf_program" to attach, and check the result - std::array attachResult; - ASSERT_TRUE(attachFuncHost.run(CallFrame, - std::initializer_list{ - WasmEdge::ValVariant(handle), - WasmEdge::ValVariant(stringOffsets[1]), - // There should be '\0' - WasmEdge::ValVariant(stringOffsets[2]), - }, - attachResult)); - ASSERT_GE(attachResult[0].get(), 0); - - // Get function `wasm_bpf_map_fd_by_name` - auto* mapFdFunc = module->findFuncExports("wasm_bpf_map_fd_by_name"); - ASSERT_NE(mapFdFunc, nullptr); - ASSERT_TRUE(mapFdFunc->isHostFunction()); - auto& mapFdFuncHost = - dynamic_cast(mapFdFunc->getHostFunc()); - - // Call "wasm_bpf_map_fd_by_name" to get the map fd, and check the result - std::array mapFdResult; - ASSERT_TRUE(mapFdFuncHost.run(CallFrame, - std::initializer_list{ - WasmEdge::ValVariant(handle), - WasmEdge::ValVariant(stringOffsets[0])}, - mapFdResult)); - auto mapFd = mapFdResult[0].get(); - ASSERT_GE(mapFd, 0); - - // In the following several steps we will prepare for polling - // Create an instance of the polling callback function - auto callbackFuncInst = - std::make_unique( - &moduleInst, std::make_unique()); - // Create a function table, and fill the callback function into it - auto funcTableInst = - std::make_unique( - WasmEdge::AST::TableType(WasmEdge::RefType::FuncRef, 1)); - ASSERT_TRUE(funcTableInst->setRefs( - std::initializer_list{ - WasmEdge::FuncRef(callbackFuncInst.get())}, - 0, 0, 1)); - // Add the table to the main module - moduleInst.addHostTable("__indirect_function_table"sv, - std::move(funcTableInst)); - - // Get the "wasm_bpf_buffer_poll" function - auto* bufferPollFunc = module->findFuncExports("wasm_bpf_buffer_poll"); - ASSERT_NE(bufferPollFunc, nullptr); - ASSERT_TRUE(bufferPollFunc->isHostFunction()); - auto& bufferPollFuncHost = dynamic_cast( - bufferPollFunc->getHostFunc()); - - // Call the polling function - std::array pollResult; - for (size_t i = 1; i <= 50; i++) { - using namespace std; - ASSERT_TRUE(bufferPollFuncHost.run( - CallFrame, - std::initializer_list{ - WasmEdge::ValVariant(handle), // object handle - WasmEdge::ValVariant(mapFd), // map fd - UINT32_C(0), // callback function index - UINT32_C(0), // Custom context pointer - WasmEdge::ValVariant(bufferPollMemoryOffset), // buffer offset - WasmEdge::ValVariant(bufferPollSize), // buffer size - UINT32_C(100) // timeout (ms) - }, - pollResult)); - ASSERT_GE(pollResult[0].get(), 0); - } - - // Get function `wasm_close_bpf_object` - auto* closeFunc = module->findFuncExports("wasm_close_bpf_object"); - ASSERT_NE(closeFunc, nullptr); - ASSERT_TRUE(closeFunc->isHostFunction()); - auto& closeFuncHost = - dynamic_cast(closeFunc->getHostFunc()); - - // Call "wasm_close_bpf_object" to attach, and check the result - std::array closeResult; - ASSERT_TRUE(closeFuncHost.run(CallFrame, - std::initializer_list{ - WasmEdge::ValVariant(handle), - }, - closeResult)); - ASSERT_EQ(closeResult[0].get(), 0); + WasmEdge::ValVariant(handle), // object handle + WasmEdge::ValVariant(mapFd), // map fd + UINT32_C(0), // callback function index + UINT32_C(0), // Custom context pointer + WasmEdge::ValVariant(bufferPollMemoryOffset), // buffer offset + WasmEdge::ValVariant(bufferPollSize), // buffer size + UINT32_C(100) // timeout (ms) + }, + pollResult)); + ASSERT_GE(pollResult[0].get(), 0); + } + + // Get function `wasm_close_bpf_object` + auto *closeFunc = module->findFuncExports("wasm_close_bpf_object"); + ASSERT_NE(closeFunc, nullptr); + ASSERT_TRUE(closeFunc->isHostFunction()); + auto &closeFuncHost = + dynamic_cast(closeFunc->getHostFunc()); + + // Call "wasm_close_bpf_object" to attach, and check the result + std::array closeResult; + ASSERT_TRUE(closeFuncHost.run(CallFrame, + std::initializer_list{ + WasmEdge::ValVariant(handle), + }, + closeResult)); + ASSERT_EQ(closeResult[0].get(), 0); } diff --git a/test/plugins/wasm_bpf/wasm_bpf.cpp b/test/plugins/wasm_bpf/wasm_bpf.cpp index 11adc89a..c073b72a 100644 --- a/test/plugins/wasm_bpf/wasm_bpf.cpp +++ b/test/plugins/wasm_bpf/wasm_bpf.cpp @@ -285,10 +285,9 @@ TEST(WasmBpfTest, RunBpfProgramWithPolling) { // Create a function table, and fill the callback function into it auto funcTableInst = std::make_unique( - WasmEdge::AST::TableType(WasmEdge::RefType::FuncRef, 1)); + WasmEdge::AST::TableType(WasmEdge::RefTypeCode::FuncRef, 1)); EXPECT_TRUE(funcTableInst->setRefs( - std::initializer_list{ - WasmEdge::FuncRef(callbackFuncInst.get())}, + std::initializer_list{callbackFuncInst.get()}, 0, 0, 1)); // Add the table to the main module moduleInst.addHostTable("__indirect_function_table"sv, From d9db1aed14f53ef91dc288a7db877d2852bd3b4a Mon Sep 17 00:00:00 2001 From: YiYing He Date: Fri, 15 Sep 2023 01:53:15 +0800 Subject: [PATCH 198/623] [Misc] Refactor the value types and enums. 1. Select all value types into `WasmEdge::ValType` class. 2. Select all types related enums into `WasmEdge_TypeCode` enum. Signed-off-by: YiYing He --- test/plugins/wasm_bpf/simple_ringbuf_test.cpp | 2 +- test/plugins/wasm_bpf/wasm_bpf.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp index a46d2e73..b615524e 100644 --- a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp +++ b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp @@ -192,7 +192,7 @@ TEST(WasmBpfTest, SimpleRingbuf) { // Create a function table, and fill the callback function into it auto funcTableInst = std::make_unique( - WasmEdge::AST::TableType(WasmEdge::RefTypeCode::FuncRef, 1)); + WasmEdge::AST::TableType(WasmEdge::TypeCode::FuncRef, 1)); ASSERT_TRUE(funcTableInst->setRefs( std::initializer_list{callbackFuncInst.get()}, 0, 0, 1)); diff --git a/test/plugins/wasm_bpf/wasm_bpf.cpp b/test/plugins/wasm_bpf/wasm_bpf.cpp index c073b72a..5f80042d 100644 --- a/test/plugins/wasm_bpf/wasm_bpf.cpp +++ b/test/plugins/wasm_bpf/wasm_bpf.cpp @@ -285,7 +285,7 @@ TEST(WasmBpfTest, RunBpfProgramWithPolling) { // Create a function table, and fill the callback function into it auto funcTableInst = std::make_unique( - WasmEdge::AST::TableType(WasmEdge::RefTypeCode::FuncRef, 1)); + WasmEdge::AST::TableType(WasmEdge::TypeCode::FuncRef, 1)); EXPECT_TRUE(funcTableInst->setRefs( std::initializer_list{callbackFuncInst.get()}, 0, 0, 1)); From 22788cc4db972bc6f6dc1ba04e16c54735b4a5fb Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 18 Dec 2023 13:11:48 +0900 Subject: [PATCH 199/623] [WASI-NN] ggml backend: bump to llama.cpp b1656 Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 6db7ff75..6feb7e7d 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -47,7 +47,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b1616 + GIT_TAG b1656 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW TRUE ) From 8c51e49d8c32db96f2fc221464ed7737baf4f60f Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 18 Dec 2023 13:29:44 +0900 Subject: [PATCH 200/623] [WASI-NN] ggml backend: force enable metal on macOS Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 4eb233b7..f8717c01 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -167,8 +167,12 @@ Expect parseModelConfig(Graph &GraphRef, Token = Config.substr(0, Pos); try { if (Token == "n_gpu_layers" || Token == "ngl") { +#ifndef __APPLE__ GraphRef.NGPULayers = std::stoi(Config.substr(Pos + Delimiter.length())); +#else + GraphRef.NGPULayers = 1; // Force enabled Metal on macOS +#endif } } catch (const std::invalid_argument &e) { spdlog::error( @@ -333,11 +337,12 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, return Res; } - // XXX: Due to the limitation of WASI-NN proposal, - // we have no way to pass the metadata before the setInput phase - // when we want to do some configurations in the load phase. - // That's why we have this hack. #ifndef __APPLE__ + // XXX: Due to the limitation of WASI-NN proposal, + // this is a workaround for non-macOS devices. + // However, if the model params is updated in Config stage, + // then, we doesn't encourage to use this to avoid the model + // reloading. { if (IsModelParamsUpdated) { llama_model_params ModelParams = llama_model_default_params(); From 69ec36f8b362ac066b39efadcdc261fabffbd4ae Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 20 Dec 2023 19:30:32 +0800 Subject: [PATCH 201/623] [WASI-NN] ggml: force set ngl=1 on macOS to enable Metal (#3102) Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index f8717c01..7c3544b9 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -90,6 +90,11 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } +#ifdef __APPLE__ + // Whatever the `n-gpu-layers` is given, we will always set the ngl to 1 on + // macOS to forcely enabled Metal. + GraphRef.NGPULayers = 1; // Force enabled Metal on macOS +#endif // The context parameters. if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { From a456af89bb3e16f1c1cabd5c1a2606c88bf01b2f Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 14 Dec 2023 15:17:31 +0800 Subject: [PATCH 202/623] [WASI-NN] ggml: support single token inference' Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 199 ++++++++++++++++++++++++++++--- plugins/wasi_nn/ggml.h | 13 ++ plugins/wasi_nn/types.h | 1 + plugins/wasi_nn/wasinnfunc.cpp | 63 ++++++++++ plugins/wasi_nn/wasinnfunc.h | 31 +++++ plugins/wasi_nn/wasinnmodule.cpp | 3 + 6 files changed, 293 insertions(+), 17 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 7c3544b9..2a2063a0 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -28,9 +28,6 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // Get metadata from the json. // Need to update Model: // * n_gpu_layers - // Need to update Context: - // * ctx-size - // * batch-size // Get the current llama parameters. llama_model_params ModelParams = llama_model_default_params(); @@ -195,6 +192,21 @@ Expect parseModelConfig(Graph &GraphRef, return ErrNo::Success; } +Expect buildOutputMetadata(Context &CxtRef, + std::string &Metadata) noexcept { + std::string MetadataTemplate = R"({"input_tokens": %d, "output_tokens": %d})"; + + // The 20 bytes are reserved to accommodate two %d placeholders in the + // MetadataTemplate. This allows for a decimal integer value up to a + // 12-digit number of input/output tokens. + char Buffer[MetadataTemplate.size() + 20]; + snprintf(Buffer, sizeof(Buffer), MetadataTemplate.c_str(), + CxtRef.LlamaInputs.size(), CxtRef.LlamaOutputTokens.size()); + Metadata = std::string(Buffer); + + return ErrNo::Success; +} + } // namespace details Expect load(WasiNNEnvironment &Env, Span> Builders, @@ -314,6 +326,11 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); ContextId = Env.NNContext.size() - 1; + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.EnableLog) { + spdlog::info("[WASI-NN] GGML backend: llama_system_info: {}"sv, + llama_print_system_info()); + } return ErrNo::Success; } @@ -429,15 +446,13 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, auto &CxtRef = Env.NNContext[ContextId].get(); // Index 1 is for the metadata of the outputs. if (Index == 1) { - std::string MetadataTemplate = - R"({"input_tokens": %d, "output_tokens": %d})"; - // The 20 bytes are reserved to accommodate two %d placeholders in the - // MetadataTemplate. This allows for a decimal integer value up to a - // 12-digit number of input/output tokens. - char Buffer[MetadataTemplate.size() + 20]; - snprintf(Buffer, sizeof(Buffer), MetadataTemplate.c_str(), - CxtRef.LlamaInputs.size(), CxtRef.LlamaOutputTokens.size()); - std::string Metadata(Buffer); + std::string Metadata; + auto Res = details::buildOutputMetadata(CxtRef, Metadata); + if (Res != ErrNo::Success) { + spdlog::error( + "[WASI-NN] GGML backend: Failed to build output metadata."sv); + return Res; + } std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); BytesWritten = Metadata.length(); return ErrNo::Success; @@ -460,11 +475,6 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { return ErrNo::InvalidArgument; } - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: llama_system_info: {}"sv, - llama_print_system_info()); - } - // Clear the outputs. if (GraphRef.EnableDebugLog) { spdlog::info( @@ -608,6 +618,161 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { return ErrNo::Success; } + +Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + // Index 1 is for the metadata of the outputs. + if (Index == 1) { + std::string Metadata; + auto Res = details::buildOutputMetadata(CxtRef, Metadata); + if (Res != ErrNo::Success) { + spdlog::error( + "[WASI-NN] GGML backend: Failed to build output metadata."sv); + return Res; + } + std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); + BytesWritten = Metadata.length(); + return ErrNo::Success; + } + std::string LastToken = llama_token_to_piece(CxtRef.LlamaContext, + CxtRef.LlamaOutputTokens.back()); + std::copy_n(LastToken.data(), LastToken.length(), OutBuffer.data()); + BytesWritten = LastToken.length(); + return ErrNo::Success; +} + +Expect computeSingle(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + // Logging. + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: computeSingleToken"sv); + } + if (CxtRef.LlamaInputs.size() == 0) { + spdlog::error("[WASI-NN] GGML backend: Llama input is not set!"sv); + return ErrNo::InvalidArgument; + } + + // New compute single token context. + if (CxtRef.LlamaContext == nullptr) { + // Clear the outputs. + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: clear the previous output and tokens"sv); + } + CxtRef.LlamaOutputs.clear(); + CxtRef.LlamaOutputTokens.clear(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: clear the previous output and tokens...Done"sv); + } + + // Initialize the llama context. + gpt_params GPTParams; + GPTParams.sparams.temp = GraphRef.Temp; + GPTParams.sparams.penalty_repeat = GraphRef.RepeatPenalty; + CxtRef.LlamaSampling = llama_sampling_init(GPTParams.sparams); + llama_context_params ContextParams = llama_context_default_params(); + ContextParams.n_ctx = GraphRef.CtxSize; + ContextParams.n_batch = GraphRef.BatchSize; + CxtRef.LlamaContext = + llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); + CxtRef.LlamaEmbd.clear(); + CxtRef.LlamaNPast = 0; + CxtRef.LlamaNConsumed = 0; + } + + const int NCtx = llama_n_ctx(CxtRef.LlamaContext); + // Minus 4 for the special tokens. (Such as , , ... tokens.) + const int MaxTokensListSize = NCtx - 4; + // Use the const sequence id here. + const int SequenceId = 0; + + while (true) { + if (!CxtRef.LlamaEmbd.empty()) { + // Input too long. + if (static_cast(CxtRef.LlamaEmbd.size()) > MaxTokensListSize) { + spdlog::error( + "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, + CxtRef.LlamaEmbd.size(), MaxTokensListSize); + return ErrNo::RuntimeError; + } + + // We do not swap context here. End the inference if the context is full. + if (CxtRef.LlamaNPast + static_cast(CxtRef.LlamaEmbd.size()) > + NCtx) { + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: the context if full ({} / {} tokens)"sv, + CxtRef.LlamaNPast + static_cast(CxtRef.LlamaEmbd.size()), + NCtx); + } + return ErrNo::RuntimeError; + } + + // Evaluate tokens in batches. + for (int I = 0; I < static_cast(CxtRef.LlamaEmbd.size()); + I += GraphRef.BatchSize) { + int NEval = static_cast(CxtRef.LlamaEmbd.size()) - I; + if (NEval > static_cast(GraphRef.BatchSize)) { + NEval = GraphRef.BatchSize; + } + // llama_batch_get_one(*token, n_tokens, position, sequence_id) + // This will return batch for single sequence of tokens starting at + // position. + if (llama_decode(CxtRef.LlamaContext, + llama_batch_get_one(&CxtRef.LlamaEmbd[I], NEval, + CxtRef.LlamaNPast, SequenceId))) { + spdlog::error("[WASI-NN] GGML backend: failed to llama_decode"sv); + return ErrNo::RuntimeError; + } + + CxtRef.LlamaNPast += NEval; + } + } + + CxtRef.LlamaEmbd.clear(); + + if (static_cast(CxtRef.LlamaInputs.size()) <= CxtRef.LlamaNConsumed) { + const llama_token Id = llama_sampling_sample( + CxtRef.LlamaSampling, CxtRef.LlamaContext, nullptr); + llama_sampling_accept(CxtRef.LlamaSampling, CxtRef.LlamaContext, Id, + true); + CxtRef.LlamaEmbd.emplace_back(Id); + // Save the output token. + CxtRef.LlamaOutputTokens.emplace_back(Id); + CxtRef.LlamaOutputs += llama_token_to_piece(CxtRef.LlamaContext, Id); + // Deal with end of text token. + if (llama_sampling_last(CxtRef.LlamaSampling) == + llama_token_eos(GraphRef.LlamaModel)) { + if (GraphRef.EnableLog) { + spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); + } + return ErrNo::EndOfSequence; + } + return ErrNo::Success; + } else { + while (static_cast(CxtRef.LlamaInputs.size()) > + CxtRef.LlamaNConsumed) { + CxtRef.LlamaEmbd.push_back(CxtRef.LlamaInputs[CxtRef.LlamaNConsumed]); + // Push the prompt in the sampling context. + llama_sampling_accept(CxtRef.LlamaSampling, CxtRef.LlamaContext, + CxtRef.LlamaInputs[CxtRef.LlamaNConsumed], false); + ++CxtRef.LlamaNConsumed; + if (CxtRef.LlamaEmbd.size() >= GraphRef.BatchSize) { + break; + } + } + } + } + + return ErrNo::Success; +} + #else namespace { Expect reportBackendNotSupported() noexcept { diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index e5c7981c..8a014173 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -7,6 +7,7 @@ #include "types.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include #include #endif @@ -43,6 +44,12 @@ struct Context { std::vector LlamaInputs; std::string LlamaOutputs; std::vector LlamaOutputTokens; + // Preserve for computing single token + llama_context *LlamaContext = nullptr; + struct llama_sampling_context *LlamaSampling = nullptr; + std::vector LlamaEmbd; + int LlamaNPast; + int LlamaNConsumed; }; #else struct Graph {}; @@ -66,6 +73,12 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, uint32_t ContextId, uint32_t Index, Span OutBuffer, uint32_t &BytesWritten) noexcept; +Expect getOutputSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; +Expect computeSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; } // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index 98eca041..e797f0d4 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -18,6 +18,7 @@ enum class ErrNo : uint32_t { UnsupportedOperation = 6, // Unsupported Operation. TooLarge = 7, // Too Large. NotFound = 8, // Not Found. + EndOfSequence = 9, // End of Sequence Found. }; enum class TensorType : uint8_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 39b51009..a257718e 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -303,6 +303,45 @@ WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, } } +Expect WasiNNGetOutputSingle::bodyImpl( + const Runtime::CallingFrame &Frame, uint32_t Context, uint32_t Index, + uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + if (Env.NNContext.size() <= Context) { + spdlog::error( + "[WASI-NN] get_output_single: Execution Context does not exist"sv); + return WASINN::ErrNo::InvalidArgument; + } + + const auto OutBuffer = + MemInst->getSpan(OutBufferPtr, OutBufferMaxSize); + if (unlikely(OutBuffer.data() == nullptr)) { + spdlog::error( + "[WASI-NN] Failed when accessing the Output Buffer memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + uint32_t *BytesWritten = MemInst->getPointer(BytesWrittenPtr); + if (unlikely(BytesWritten == nullptr)) { + spdlog::error("[WASI-NN] Failed when accessing the BytesWritten memory."sv); + return WASINN::ErrNo::InvalidArgument; + } + + switch (Env.NNContext[Context].getBackend()) { + case WASINN::Backend::GGML: + return WASINN::GGML::getOutputSingle(Env, Context, Index, OutBuffer, + *BytesWritten); + default: + spdlog::error( + "[WASI-NN] get_output_single: Only GGML backend supports get_output_single."sv); + return WASINN::ErrNo::InvalidArgument; + } +} + Expect WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { auto *MemInst = Frame.getMemoryByIndex(0); @@ -327,5 +366,29 @@ WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { } } +Expect +WasiNNComputeSingle::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + if (Env.NNContext.size() <= Context) { + spdlog::error( + "[WASI-NN] compute_single: Execution Context does not exist."sv); + return WASINN::ErrNo::InvalidArgument; + } + + switch (Env.NNContext[Context].getBackend()) { + case WASINN::Backend::GGML: + return WASINN::GGML::computeSingle(Env, Context); + default: + spdlog::error( + "[WASI-NN] compute_single: Only GGML backend supports compute_single."sv); + return WASINN::ErrNo::InvalidArgument; + } +} + } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h index 5db87f7a..6fd5cd18 100644 --- a/plugins/wasi_nn/wasinnfunc.h +++ b/plugins/wasi_nn/wasinnfunc.h @@ -106,6 +106,25 @@ class WasiNNGetOutput : public WasiNN { uint32_t BytesWrittenPtr); }; +class WasiNNGetOutputSingle : public WasiNN { +public: + WasiNNGetOutputSingle(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context, + uint32_t Index, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { + return bodyImpl(Frame, Context, Index, OutBufferPtr, OutBufferMaxSize, + BytesWrittenPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context, uint32_t Index, + uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr); +}; + class WasiNNCompute : public WasiNN { public: WasiNNCompute(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} @@ -118,5 +137,17 @@ class WasiNNCompute : public WasiNN { uint32_t Context); }; +class WasiNNComputeSingle : public WasiNN { +public: + WasiNNComputeSingle(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context) { + return bodyImpl(Frame, Context).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context); +}; + } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnmodule.cpp b/plugins/wasi_nn/wasinnmodule.cpp index cb3eea06..2cf3fbbe 100644 --- a/plugins/wasi_nn/wasinnmodule.cpp +++ b/plugins/wasi_nn/wasinnmodule.cpp @@ -16,7 +16,10 @@ WasiNNModule::WasiNNModule() : ModuleInstance("wasi_ephemeral_nn") { std::make_unique(Env)); addHostFunc("set_input", std::make_unique(Env)); addHostFunc("get_output", std::make_unique(Env)); + addHostFunc("get_output_single", + std::make_unique(Env)); addHostFunc("compute", std::make_unique(Env)); + addHostFunc("compute_single", std::make_unique(Env)); } } // namespace Host From c55dc8bc2ecd2d26451e7753c75cb45f3c7daac4 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 20 Dec 2023 20:57:16 +0800 Subject: [PATCH 203/623] [WASI-NN] ggml: remove configuration from the nn-prelaod cli parameter Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 70 +++------------------------------------- 1 file changed, 5 insertions(+), 65 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 2a2063a0..ff738f8e 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -137,61 +137,6 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::Success; } -Expect parseModelConfig(Graph &GraphRef, - std::string ModelFilePathWithConfig, - std::string &ModelFilePath) noexcept { - std::vector Configs; - std::string Delimiter = ","; - if (ModelFilePathWithConfig.find(Delimiter) == std::string::npos) { - ModelFilePath = ModelFilePathWithConfig; - } else { - // Handle model path with config. - size_t Pos = 0; - std::string Token; - Pos = ModelFilePathWithConfig.find(Delimiter); - ModelFilePath = ModelFilePathWithConfig.substr(0, Pos); - ModelFilePathWithConfig.erase(0, Pos + Delimiter.length()); - while ((Pos = ModelFilePathWithConfig.find(Delimiter)) != - std::string::npos) { - Token = ModelFilePathWithConfig.substr(0, Pos); - Configs.emplace_back(Token); - ModelFilePathWithConfig.erase(0, Pos + Delimiter.length()); - } - Configs.emplace_back(ModelFilePathWithConfig); - } - - // Parse the configs. - for (const auto &Config : Configs) { - std::string Delimiter = "="; - size_t Pos = 0; - std::string Token; - Pos = Config.find(Delimiter); - Token = Config.substr(0, Pos); - try { - if (Token == "n_gpu_layers" || Token == "ngl") { -#ifndef __APPLE__ - GraphRef.NGPULayers = - std::stoi(Config.substr(Pos + Delimiter.length())); -#else - GraphRef.NGPULayers = 1; // Force enabled Metal on macOS -#endif - } - } catch (const std::invalid_argument &e) { - spdlog::error( - "[WASI-NN] GGML backend: parse model parameter failed: invalid_argument {}"sv, - e.what()); - return ErrNo::InvalidArgument; - } catch (const std::out_of_range &e) { - spdlog::error( - "[WASI-NN] GGML backend: parse parameter failed: out_of_range {}"sv, - e.what()); - return ErrNo::InvalidArgument; - } - } - - return ErrNo::Success; -} - Expect buildOutputMetadata(Context &CxtRef, std::string &Metadata) noexcept { std::string MetadataTemplate = R"({"input_tokens": %d, "output_tokens": %d})"; @@ -223,6 +168,10 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.NPredict = ContextDefault.n_ctx; // Initialize the model parameters. GraphRef.NGPULayers = 0; +#ifdef __APPLE__ + // We will always set the ngl to 1 on macOS to enable Metal. + GraphRef.NGPULayers = 1; +#endif // Initialize the context parameters. GraphRef.CtxSize = ContextDefault.n_ctx; GraphRef.BatchSize = ContextDefault.n_batch; @@ -252,16 +201,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, std::string BinModel(reinterpret_cast(Weight.data()), Weight.size()); std::string ModelFilePath; if (BinModel.substr(0, 8) == "preload:") { - // If BinModel starts with 'preload:', it means that the model name passed - // in as the --nn-preload parameter may have a config separated by ',' at - // the end. For example, "preload:./model.bin,n_gpu_layers=99" - auto Res = - details::parseModelConfig(GraphRef, BinModel.substr(8), ModelFilePath); - if (Res != ErrNo::Success) { - spdlog::error("[WASI-NN] GGML backend: Failed to parse model config."sv); - Env.NNGraph.pop_back(); - return Res; - } + ModelFilePath = BinModel.substr(8); } else { if (GraphRef.EnableDebugLog) { spdlog::info( From 27a69073e828e7b7238534274b2d2557ea3e27a8 Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 22 Dec 2023 18:04:41 +0800 Subject: [PATCH 204/623] [WASI-NN] ggml: show llama commit and build number in log Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/ggml.cpp | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 6feb7e7d..8807129b 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -49,7 +49,7 @@ if(BACKEND STREQUAL "ggml") GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git GIT_TAG b1656 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched - GIT_SHALLOW TRUE + GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) set_property(TARGET ggml PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index ff738f8e..0a1c6efa 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -192,6 +192,11 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, return Res; } } + if (GraphRef.EnableLog) { + spdlog::info("[WASI-NN] GGML backend: LLAMA_COMMIT {}"sv, LLAMA_COMMIT); + spdlog::info("[WASI-NN] GGML backend: LLAMA_BUILD_NUMBER {}"sv, + LLAMA_BUILD_NUMBER); + } if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: Handling model path."sv); From 7f1751ea6465d2ed52386cd979b833c1b74eaab2 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 25 Dec 2023 02:05:36 +0800 Subject: [PATCH 205/623] [WASI-NN] ggml: update ggml to b1698 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 8807129b..54a8b722 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -47,7 +47,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b1656 + GIT_TAG b1698 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From ba9018b7c5b8f3d08f6c662053801fda693d646b Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 25 Dec 2023 11:03:59 +0800 Subject: [PATCH 206/623] [WASI-NN] ggml: fix tests Signed-off-by: dm4 --- test/plugins/wasi_nn/wasi_nn.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index c66849a3..93441808 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1458,9 +1458,9 @@ TEST(WasiNNTest, GGMLBackend) { UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); - // Should output more than 100 bytes. + // Should output more than 50 bytes. auto BytesWritten = *MemInst.getPointer(BuilderPtr); - EXPECT_GE(BytesWritten, 100); + EXPECT_GE(BytesWritten, 50); } } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML From d2942e074728419f7fa68f7bd6a9ac49f000c740 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 27 Dec 2023 14:50:04 +0800 Subject: [PATCH 207/623] [WASI-NN] ggml: add fini_single and ContextFull error Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 45 ++++++++++++++++++++++++++++++-- plugins/wasi_nn/ggml.h | 2 ++ plugins/wasi_nn/types.h | 1 + plugins/wasi_nn/wasinnfunc.cpp | 23 ++++++++++++++++ plugins/wasi_nn/wasinnfunc.h | 12 +++++++++ plugins/wasi_nn/wasinnmodule.cpp | 1 + 6 files changed, 82 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 0a1c6efa..33fff693 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -456,6 +456,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { const int MaxTokensListSize = NCtx - 4; // Use the const sequence id here. const int SequenceId = 0; + // Return value. + auto ReturnCode = ErrNo::Success; while (NRemain >= 0) { // Preidct if (!Embd.empty()) { @@ -474,6 +476,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { "[WASI-NN] GGML backend: the context if full ({} / {} tokens)"sv, NPast + static_cast(Embd.size()), NCtx); } + ReturnCode = ErrNo::ContextFull; break; } @@ -561,7 +564,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { spdlog::info("[WASI-NN][Debug] GGML backend: compute...Done"sv); } - return ErrNo::Success; + return ReturnCode; } Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, @@ -656,7 +659,7 @@ Expect computeSingle(WasiNNEnvironment &Env, CxtRef.LlamaNPast + static_cast(CxtRef.LlamaEmbd.size()), NCtx); } - return ErrNo::RuntimeError; + return ErrNo::ContextFull; } // Evaluate tokens in batches. @@ -718,6 +721,44 @@ Expect computeSingle(WasiNNEnvironment &Env, return ErrNo::Success; } +Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + // Clear the outputs. + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: finiSingle: clear the previous output and tokens"sv); + } + CxtRef.LlamaOutputs.clear(); + CxtRef.LlamaOutputTokens.clear(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: finiSingle: clear the previous output and tokens...Done"sv); + } + + // Delete the llama context. + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: finiSingle: free the llama context"sv); + } + llama_sampling_free(CxtRef.LlamaSampling); + llama_free(CxtRef.LlamaContext); + CxtRef.LlamaSampling = nullptr; + CxtRef.LlamaContext = nullptr; + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: finiSingle: free the llama context...Done"sv); + } + + // Reset the context variables. + CxtRef.LlamaEmbd.clear(); + CxtRef.LlamaNPast = 0; + CxtRef.LlamaNConsumed = 0; + + return ErrNo::Success; +} + #else namespace { Expect reportBackendNotSupported() noexcept { diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 8a014173..c94df07e 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -81,4 +81,6 @@ Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; Expect computeSingle(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; +Expect finiSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; } // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index e797f0d4..66acfee5 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -19,6 +19,7 @@ enum class ErrNo : uint32_t { TooLarge = 7, // Too Large. NotFound = 8, // Not Found. EndOfSequence = 9, // End of Sequence Found. + ContextFull = 10, // Context Full. }; enum class TensorType : uint8_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index a257718e..3be2bca8 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -390,5 +390,28 @@ WasiNNComputeSingle::bodyImpl(const Runtime::CallingFrame &Frame, } } +Expect +WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + if (Env.NNContext.size() <= Context) { + spdlog::error("[WASI-NN] fini_single: Execution Context does not exist."sv); + return WASINN::ErrNo::InvalidArgument; + } + + switch (Env.NNContext[Context].getBackend()) { + case WASINN::Backend::GGML: + return WASINN::GGML::finiSingle(Env, Context); + default: + spdlog::error( + "[WASI-NN] fini_single: Only GGML backend supports compute_single."sv); + return WASINN::ErrNo::InvalidArgument; + } +} + } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h index 6fd5cd18..7dcfbd2a 100644 --- a/plugins/wasi_nn/wasinnfunc.h +++ b/plugins/wasi_nn/wasinnfunc.h @@ -149,5 +149,17 @@ class WasiNNComputeSingle : public WasiNN { uint32_t Context); }; +class WasiNNFiniSingle : public WasiNN { +public: + WasiNNFiniSingle(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context) { + return bodyImpl(Frame, Context).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context); +}; + } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnmodule.cpp b/plugins/wasi_nn/wasinnmodule.cpp index 2cf3fbbe..c2246405 100644 --- a/plugins/wasi_nn/wasinnmodule.cpp +++ b/plugins/wasi_nn/wasinnmodule.cpp @@ -20,6 +20,7 @@ WasiNNModule::WasiNNModule() : ModuleInstance("wasi_ephemeral_nn") { std::make_unique(Env)); addHostFunc("compute", std::make_unique(Env)); addHostFunc("compute_single", std::make_unique(Env)); + addHostFunc("fini_single", std::make_unique(Env)); } } // namespace Host From 4923fcbc6a67031e7fa72b4fa39022c7d70dae34 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 27 Dec 2023 16:23:34 +0800 Subject: [PATCH 208/623] [WASI-NN] ggml: fix tests for the new ContextFull error Signed-off-by: dm4 --- test/plugins/wasi_nn/wasi_nn.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 93441808..58d6ea7c 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1419,12 +1419,14 @@ TEST(WasiNNTest, GGMLBackend) { static_cast(ErrNo::InvalidArgument)); } - // Test: compute -- compute successfully. + // Test: compute -- compute until finish or context full. { EXPECT_TRUE(HostFuncCompute.run( CallFrame, std::initializer_list{UINT32_C(0)}, Errno)); - EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE( + Errno[0].get() == static_cast(ErrNo::Success) || + Errno[0].get() == static_cast(ErrNo::ContextFull)); } // GGML WASI-NN get_output tests. From 36df7aa0433d0a49eaa315cec9294c877fedd1c1 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 27 Dec 2023 16:54:01 +0800 Subject: [PATCH 209/623] [WASI-NN] ggml: update ggml to b1703 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 54a8b722..3e31e06d 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -47,7 +47,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b1698 + GIT_TAG b1703 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From 4f2c9b12ff7e36ea6d55e21a0724b5f5b5552c49 Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 28 Dec 2023 08:41:56 +0800 Subject: [PATCH 210/623] [WASI-NN] ggml: Enhance the error msg when llama_decode failed (#3118) Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 33fff693..5b057c84 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -490,10 +490,16 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // llama_batch_get_one(*token, n_tokens, position, sequence_id) // This will return batch for single sequence of tokens starting at // position. - if (llama_decode( - LlamaContext, - llama_batch_get_one(&Embd[I], NEval, NPast, SequenceId))) { - spdlog::error("[WASI-NN] GGML backend: failed to llama_decode"sv); + auto Status = + llama_decode(LlamaContext, llama_batch_get_one(&Embd[I], NEval, + NPast, SequenceId)); + if (Status == 1) { + spdlog::error( + "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); + return ErrNo::RuntimeError; + } else if (Status < 0) { + spdlog::error( + "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); return ErrNo::RuntimeError; } @@ -672,10 +678,17 @@ Expect computeSingle(WasiNNEnvironment &Env, // llama_batch_get_one(*token, n_tokens, position, sequence_id) // This will return batch for single sequence of tokens starting at // position. - if (llama_decode(CxtRef.LlamaContext, + auto Status = + llama_decode(CxtRef.LlamaContext, llama_batch_get_one(&CxtRef.LlamaEmbd[I], NEval, - CxtRef.LlamaNPast, SequenceId))) { - spdlog::error("[WASI-NN] GGML backend: failed to llama_decode"sv); + CxtRef.LlamaNPast, SequenceId)); + if (Status == 1) { + spdlog::error( + "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); + return ErrNo::RuntimeError; + } else if (Status < 0) { + spdlog::error( + "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); return ErrNo::RuntimeError; } From 8be1d2c34e62a51c468747f8898e15a8e8af5193 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 2 Jan 2024 17:10:13 +0800 Subject: [PATCH 211/623] [WASI-NN] ggml: handle prompt too long error Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 35 ++++++++++++++++++++++++----------- plugins/wasi_nn/types.h | 5 +++-- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 5b057c84..32647c06 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -355,15 +355,6 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); CxtRef.LlamaInputs = llama_tokenize(LlamaContext, Prompt, AddBos, true); - const uint32_t MaxContextSize = llama_n_ctx(LlamaContext); - // Minus 4 for the special tokens. - const uint32_t MaxTokensListSize = MaxContextSize - 4; - if (CxtRef.LlamaInputs.size() > MaxTokensListSize) { - spdlog::error( - "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, - CxtRef.LlamaInputs.size(), MaxTokensListSize); - return ErrNo::InvalidArgument; - } if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: set the input...Done"sv); } @@ -451,6 +442,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { ContextParams.n_batch = GraphRef.BatchSize; auto LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); + + // Get the context size. int NCtx = llama_n_ctx(LlamaContext); // Minus 4 for the special tokens. (Such as , , ... tokens.) const int MaxTokensListSize = NCtx - 4; @@ -458,6 +451,16 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { const int SequenceId = 0; // Return value. auto ReturnCode = ErrNo::Success; + + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { + spdlog::error( + "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, + CxtRef.LlamaInputs.size(), MaxTokensListSize); + return ErrNo::PromptTooLong; + } + + // Main predict loop. while (NRemain >= 0) { // Preidct if (!Embd.empty()) { @@ -466,7 +469,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { spdlog::error( "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, Embd.size(), MaxTokensListSize); - return ErrNo::RuntimeError; + return ErrNo::PromptTooLong; } // We do not swap context here. End the inference if the context is full. @@ -640,12 +643,22 @@ Expect computeSingle(WasiNNEnvironment &Env, CxtRef.LlamaNConsumed = 0; } + // Get the context size. const int NCtx = llama_n_ctx(CxtRef.LlamaContext); // Minus 4 for the special tokens. (Such as , , ... tokens.) const int MaxTokensListSize = NCtx - 4; // Use the const sequence id here. const int SequenceId = 0; + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { + spdlog::error( + "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, + CxtRef.LlamaInputs.size(), MaxTokensListSize); + return ErrNo::PromptTooLong; + } + + // Main predict loop. while (true) { if (!CxtRef.LlamaEmbd.empty()) { // Input too long. @@ -653,7 +666,7 @@ Expect computeSingle(WasiNNEnvironment &Env, spdlog::error( "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, CxtRef.LlamaEmbd.size(), MaxTokensListSize); - return ErrNo::RuntimeError; + return ErrNo::PromptTooLong; } // We do not swap context here. End the inference if the context is full. diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index 66acfee5..453b320b 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -18,8 +18,9 @@ enum class ErrNo : uint32_t { UnsupportedOperation = 6, // Unsupported Operation. TooLarge = 7, // Too Large. NotFound = 8, // Not Found. - EndOfSequence = 9, // End of Sequence Found. - ContextFull = 10, // Context Full. + EndOfSequence = 100, // End of Sequence Found. + ContextFull = 101, // Context Full. + PromptTooLong = 102, // Prompt Too Long. }; enum class TensorType : uint8_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; From a1057ef79683ffb3962b48a84eeaa156e880ad2b Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 2 Jan 2024 17:26:38 +0800 Subject: [PATCH 212/623] [WASI-NN] ggml: update ggml to b1743 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 3e31e06d..a09c7bb7 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -47,7 +47,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b1703 + GIT_TAG b1743 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From f957d31b87c837e89cacb261aa36f54fd132208f Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 2 Jan 2024 19:37:18 +0800 Subject: [PATCH 213/623] [WASI-NN] ggml: use uint64_t for llama token calculation Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 52 +++++++++++++++++++++------------------- plugins/wasi_nn/ggml.h | 4 ++-- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 32647c06..d9c6562f 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -433,9 +433,9 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { struct llama_sampling_context *CtxSampling = llama_sampling_init(GPTParams.sparams); std::vector Embd; - int NPast = 0; - int NConsumed = 0; - int NRemain = GraphRef.NPredict; + uint64_t NPast = 0; + uint64_t NConsumed = 0; + int32_t NRemain = GraphRef.NPredict; // Initialize the llama context. llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = GraphRef.CtxSize; @@ -444,16 +444,16 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); // Get the context size. - int NCtx = llama_n_ctx(LlamaContext); + const uint64_t NCtx = llama_n_ctx(LlamaContext); // Minus 4 for the special tokens. (Such as , , ... tokens.) - const int MaxTokensListSize = NCtx - 4; + const uint64_t MaxTokensListSize = NCtx - 4; // Use the const sequence id here. - const int SequenceId = 0; + const llama_seq_id SequenceId = 0; // Return value. auto ReturnCode = ErrNo::Success; // Check if the input is too long. - if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { + if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { spdlog::error( "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, CxtRef.LlamaInputs.size(), MaxTokensListSize); @@ -465,7 +465,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Preidct if (!Embd.empty()) { // Input too long. - if (static_cast(Embd.size()) > MaxTokensListSize) { + if (static_cast(Embd.size()) > MaxTokensListSize) { spdlog::error( "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, Embd.size(), MaxTokensListSize); @@ -473,7 +473,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } // We do not swap context here. End the inference if the context is full. - if (NPast + static_cast(Embd.size()) > NCtx) { + if (NPast + static_cast(Embd.size()) > NCtx) { if (GraphRef.EnableLog) { spdlog::info( "[WASI-NN] GGML backend: the context if full ({} / {} tokens)"sv, @@ -486,8 +486,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Evaluate tokens in batches. for (int I = 0; I < static_cast(Embd.size()); I += GraphRef.BatchSize) { - int NEval = static_cast(Embd.size()) - I; - if (NEval > static_cast(GraphRef.BatchSize)) { + uint64_t NEval = static_cast(Embd.size()) - I; + if (NEval > static_cast(GraphRef.BatchSize)) { NEval = GraphRef.BatchSize; } // llama_batch_get_one(*token, n_tokens, position, sequence_id) @@ -512,7 +512,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { Embd.clear(); - if (static_cast(CxtRef.LlamaInputs.size()) <= NConsumed) { + if (static_cast(CxtRef.LlamaInputs.size()) <= NConsumed) { const llama_token Id = llama_sampling_sample(CtxSampling, LlamaContext, nullptr); llama_sampling_accept(CtxSampling, LlamaContext, Id, true); @@ -543,7 +543,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { break; } } else { - while (static_cast(CxtRef.LlamaInputs.size()) > NConsumed) { + while (static_cast(CxtRef.LlamaInputs.size()) > NConsumed) { Embd.push_back(CxtRef.LlamaInputs[NConsumed]); // Push the prompt in the sampling context. llama_sampling_accept(CtxSampling, LlamaContext, @@ -644,14 +644,14 @@ Expect computeSingle(WasiNNEnvironment &Env, } // Get the context size. - const int NCtx = llama_n_ctx(CxtRef.LlamaContext); + const uint64_t NCtx = llama_n_ctx(CxtRef.LlamaContext); // Minus 4 for the special tokens. (Such as , , ... tokens.) - const int MaxTokensListSize = NCtx - 4; + const uint64_t MaxTokensListSize = NCtx - 4; // Use the const sequence id here. - const int SequenceId = 0; + const llama_seq_id SequenceId = 0; // Check if the input is too long. - if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { + if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { spdlog::error( "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, CxtRef.LlamaInputs.size(), MaxTokensListSize); @@ -662,7 +662,7 @@ Expect computeSingle(WasiNNEnvironment &Env, while (true) { if (!CxtRef.LlamaEmbd.empty()) { // Input too long. - if (static_cast(CxtRef.LlamaEmbd.size()) > MaxTokensListSize) { + if (static_cast(CxtRef.LlamaEmbd.size()) > MaxTokensListSize) { spdlog::error( "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, CxtRef.LlamaEmbd.size(), MaxTokensListSize); @@ -670,22 +670,23 @@ Expect computeSingle(WasiNNEnvironment &Env, } // We do not swap context here. End the inference if the context is full. - if (CxtRef.LlamaNPast + static_cast(CxtRef.LlamaEmbd.size()) > + if (CxtRef.LlamaNPast + static_cast(CxtRef.LlamaEmbd.size()) > NCtx) { if (GraphRef.EnableLog) { spdlog::info( "[WASI-NN] GGML backend: the context if full ({} / {} tokens)"sv, - CxtRef.LlamaNPast + static_cast(CxtRef.LlamaEmbd.size()), + CxtRef.LlamaNPast + + static_cast(CxtRef.LlamaEmbd.size()), NCtx); } return ErrNo::ContextFull; } // Evaluate tokens in batches. - for (int I = 0; I < static_cast(CxtRef.LlamaEmbd.size()); + for (uint64_t I = 0; I < static_cast(CxtRef.LlamaEmbd.size()); I += GraphRef.BatchSize) { - int NEval = static_cast(CxtRef.LlamaEmbd.size()) - I; - if (NEval > static_cast(GraphRef.BatchSize)) { + uint64_t NEval = static_cast(CxtRef.LlamaEmbd.size()) - I; + if (NEval > static_cast(GraphRef.BatchSize)) { NEval = GraphRef.BatchSize; } // llama_batch_get_one(*token, n_tokens, position, sequence_id) @@ -711,7 +712,8 @@ Expect computeSingle(WasiNNEnvironment &Env, CxtRef.LlamaEmbd.clear(); - if (static_cast(CxtRef.LlamaInputs.size()) <= CxtRef.LlamaNConsumed) { + if (static_cast(CxtRef.LlamaInputs.size()) <= + CxtRef.LlamaNConsumed) { const llama_token Id = llama_sampling_sample( CxtRef.LlamaSampling, CxtRef.LlamaContext, nullptr); llama_sampling_accept(CxtRef.LlamaSampling, CxtRef.LlamaContext, Id, @@ -730,7 +732,7 @@ Expect computeSingle(WasiNNEnvironment &Env, } return ErrNo::Success; } else { - while (static_cast(CxtRef.LlamaInputs.size()) > + while (static_cast(CxtRef.LlamaInputs.size()) > CxtRef.LlamaNConsumed) { CxtRef.LlamaEmbd.push_back(CxtRef.LlamaInputs[CxtRef.LlamaNConsumed]); // Push the prompt in the sampling context. diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index c94df07e..2b6b3154 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -48,8 +48,8 @@ struct Context { llama_context *LlamaContext = nullptr; struct llama_sampling_context *LlamaSampling = nullptr; std::vector LlamaEmbd; - int LlamaNPast; - int LlamaNConsumed; + uint64_t LlamaNPast; + uint64_t LlamaNConsumed; }; #else struct Graph {}; From b30a68ba7bcf95cf37a979abf243868d5b5cea58 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 4 Jan 2024 14:00:25 +0800 Subject: [PATCH 214/623] [WASI-NN] ggml: improved logging mechanism for ContextFull and PromptTooLong Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 39 ++++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index d9c6562f..5a6bd12a 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -454,9 +454,11 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Check if the input is too long. if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { - spdlog::error( - "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, - CxtRef.LlamaInputs.size(), MaxTokensListSize); + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: the prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, + CxtRef.LlamaInputs.size(), MaxTokensListSize); + } return ErrNo::PromptTooLong; } @@ -466,17 +468,20 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { if (!Embd.empty()) { // Input too long. if (static_cast(Embd.size()) > MaxTokensListSize) { - spdlog::error( - "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, - Embd.size(), MaxTokensListSize); - return ErrNo::PromptTooLong; + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: the prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, + Embd.size(), MaxTokensListSize); + } + ReturnCode = ErrNo::PromptTooLong; + break; } // We do not swap context here. End the inference if the context is full. if (NPast + static_cast(Embd.size()) > NCtx) { if (GraphRef.EnableLog) { spdlog::info( - "[WASI-NN] GGML backend: the context if full ({} / {} tokens)"sv, + "[WASI-NN] GGML backend: the context if full ({} / {} tokens). Please increase your ctx-size."sv, NPast + static_cast(Embd.size()), NCtx); } ReturnCode = ErrNo::ContextFull; @@ -652,9 +657,11 @@ Expect computeSingle(WasiNNEnvironment &Env, // Check if the input is too long. if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { - spdlog::error( - "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, - CxtRef.LlamaInputs.size(), MaxTokensListSize); + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: the prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, + CxtRef.LlamaInputs.size(), MaxTokensListSize); + } return ErrNo::PromptTooLong; } @@ -663,9 +670,11 @@ Expect computeSingle(WasiNNEnvironment &Env, if (!CxtRef.LlamaEmbd.empty()) { // Input too long. if (static_cast(CxtRef.LlamaEmbd.size()) > MaxTokensListSize) { - spdlog::error( - "[WASI-NN] GGML backend: Error: The prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, - CxtRef.LlamaEmbd.size(), MaxTokensListSize); + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: the prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, + CxtRef.LlamaEmbd.size(), MaxTokensListSize); + } return ErrNo::PromptTooLong; } @@ -674,7 +683,7 @@ Expect computeSingle(WasiNNEnvironment &Env, NCtx) { if (GraphRef.EnableLog) { spdlog::info( - "[WASI-NN] GGML backend: the context if full ({} / {} tokens)"sv, + "[WASI-NN] GGML backend: the context if full ({} / {} tokens). Please increase your ctx-size."sv, CxtRef.LlamaNPast + static_cast(CxtRef.LlamaEmbd.size()), NCtx); From 89404cc8265b3c45c5443e582a0876957ad19b00 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 4 Jan 2024 14:05:15 +0800 Subject: [PATCH 215/623] [WASI-NN] ggml: print llama timing in fini_single Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 5a6bd12a..1ace5c75 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -762,6 +762,11 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + // Logging for the llama timings. + if (GraphRef.EnableLog) { + llama_print_timings(CxtRef.LlamaContext); + } + // Clear the outputs. if (GraphRef.EnableDebugLog) { spdlog::info( From 436c887ae0b24483614159806f851b303798da3d Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 4 Jan 2024 15:13:42 +0800 Subject: [PATCH 216/623] [WASI-NN] ggml: support setting the llama threads number Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 13 +++++++++++++ plugins/wasi_nn/ggml.h | 1 + 2 files changed, 14 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 1ace5c75..1cf8ecf6 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -110,6 +110,14 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } + if (Doc.at_key("threads").error() == simdjson::SUCCESS) { + auto Err = Doc["threads"].get().get(GraphRef.Threads); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the threads option."sv); + return ErrNo::InvalidArgument; + } + } // The sampling parameters. if (Doc.at_key("temp").error() == simdjson::SUCCESS) { auto Err = Doc["temp"].get().get(GraphRef.Temp); @@ -175,6 +183,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize the context parameters. GraphRef.CtxSize = ContextDefault.n_ctx; GraphRef.BatchSize = ContextDefault.n_batch; + GraphRef.Threads = ContextDefault.n_threads; // Initialize the sampling parameters. llama_sampling_params SamplingDefault; GraphRef.Temp = SamplingDefault.temp; @@ -341,6 +350,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = GraphRef.CtxSize; ContextParams.n_batch = GraphRef.BatchSize; + ContextParams.n_threads = GraphRef.Threads; + ContextParams.n_threads_batch = GraphRef.Threads; auto LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); if (GraphRef.EnableDebugLog) { @@ -641,6 +652,8 @@ Expect computeSingle(WasiNNEnvironment &Env, llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = GraphRef.CtxSize; ContextParams.n_batch = GraphRef.BatchSize; + ContextParams.n_threads = GraphRef.Threads; + ContextParams.n_threads_batch = GraphRef.Threads; CxtRef.LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); CxtRef.LlamaEmbd.clear(); diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 2b6b3154..bd94df5f 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -32,6 +32,7 @@ struct Graph { // Context parameters: uint64_t CtxSize; uint64_t BatchSize; + uint64_t Threads; // Sampleing parameters: double Temp; double RepeatPenalty; From c87b27f19b70dd5865ceb91b57d0395daf523ca7 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 8 Jan 2024 14:44:17 +0800 Subject: [PATCH 217/623] [Plugin] image: update boost tarball URL Signed-off-by: dm4 --- plugins/wasmedge_image/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasmedge_image/CMakeLists.txt b/plugins/wasmedge_image/CMakeLists.txt index 2c2d8c5d..01762dbb 100644 --- a/plugins/wasmedge_image/CMakeLists.txt +++ b/plugins/wasmedge_image/CMakeLists.txt @@ -108,7 +108,7 @@ else() include(FetchContent) FetchContent_Declare( Boost - URL https://boostorg.jfrog.io/artifactory/main/release/1.82.0/source/boost_1_82_0.tar.bz2 + URL http://sources.buildroot.net/boost/boost_1_82_0.tar.bz2 URL_HASH SHA256=a6e1ab9b0860e6a2881dd7b21fe9f737a095e5f33a3a874afc6a345228597ee6 ) set(BOOST_ENABLE_CMAKE ON) From 5730ec1ebbf12b8f21a39fed97851bc3671f3eff Mon Sep 17 00:00:00 2001 From: Akihiro Suda Date: Tue, 9 Jan 2024 17:25:04 +0900 Subject: [PATCH 218/623] [WASI-NN] Support RPC mode RPC mode allows using another Wasi-NN instance that is running on a remote WasmEdge instance, via `ssh -R remote.sock:local.sock`. An example usecase is to allow Linux VM (e.g., Lima) guest to use the host GPU. The gRPC proto can be repurposed for non-WASM applications as well. - - - Build ===== Set `WASMEDGE_BUILD_WASI_NN_RPC` to `ON`. Enabled by default when gRPC (libgrpc++-dev) is installed. Usage ===== Host 1 (rpc server / ssh client, e.g., Lima host with physical GPU) ----- ``` wasi_nn_rpcserver \ --nn-rpc-uri unix:///$HOME/nn_server.sock \ --nn-preload default:GGML:AUTO:llama-2-7b-chat.Q5_K_M.gguf ssh \ -R /tmp/nn_client.sock:$HOME/nn_server.sock \ host2 ``` Host 2 (rpc client / ssh server, e.g., Lima guest) ----- ``` wasmedge \ --nn-rpc-uri unix:///tmp/nn_client.sock \ wasmedge-ggml-llama-interactive.wasm \ default "1 + 1 = ?" ``` See for how to obtain `llama-2-7b-chat.Q5_K_M.gguf` and `wasmedge-ggml-llama-interactive.wasm`. Signed-off-by: Akihiro Suda --- plugins/wasi_nn/CMakeLists.txt | 13 ++ plugins/wasi_nn/wasinnenv.cpp | 35 +++++ plugins/wasi_nn/wasinnenv.h | 10 +- plugins/wasi_nn/wasinnfunc.cpp | 181 +++++++++++++++++++++--- test/plugins/wasi_nn/CMakeLists.txt | 8 ++ test/plugins/wasi_nn/wasi_nn.cpp | 212 +++++++++++++++++++++++++++- 6 files changed, 441 insertions(+), 18 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index a09c7bb7..9411b092 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -131,6 +131,9 @@ target_compile_options(wasmedgePluginWasiNN PUBLIC -DWASMEDGE_PLUGIN ) +if(WASMEDGE_BUILD_WASI_NN_RPC) + add_definitions(-DWASMEDGE_BUILD_WASI_NN_RPC) +endif() target_include_directories(wasmedgePluginWasiNN PUBLIC @@ -162,6 +165,16 @@ else() ) endif() +if(WASMEDGE_BUILD_WASI_NN_RPC) + target_include_directories(wasmedgePluginWasiNN + SYSTEM BEFORE PUBLIC ${Protobuf_INCLUDE_DIR} + ) + target_link_libraries(wasmedgePluginWasiNN + PRIVATE + wasiNNRPC + ) +endif() + include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN) diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 6c90ae54..438c7f02 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -6,6 +6,10 @@ #include +#ifdef WASMEDGE_BUILD_WASI_NN_RPC +#include +#endif + namespace WasmEdge { namespace Host { @@ -46,6 +50,25 @@ bool load(const std::filesystem::path &Path, std::vector &Data) { } WasiNNEnvironment::WasiNNEnvironment() noexcept { +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (getenv("_WASI_NN_RPCSERVER") == nullptr) { + // RPC client mode + auto URI = NNRPCURI.value(); + if (!URI.empty()) { + std::string_view UnixPrefix = "unix://"; + if (URI.substr(0, UnixPrefix.length()) != UnixPrefix) { + spdlog::warn("[WASI-NN] Expected \"unix://...\", got \"{}\""sv, URI); + } + auto Cred = grpc::InsecureChannelCredentials(); // safe for unix://... + NNRPCChannel = grpc::CreateChannel(URI, Cred); + if (NNModels.value().size() > 0) { + spdlog::warn( + "[WASI-NN] nn-preload has to be specified on the RPC server side, not on the client side"sv); + } + return; + } + } +#endif // Preload NN Models for (const auto &M : NNModels.value()) { std::istringstream ISS(M); @@ -100,9 +123,21 @@ PO::List WasiNNEnvironment::NNModels( "Allow preload models from wasinn plugin. Each NN model can be specified as --nn-preload `COMMAND`."sv), PO::MetaVar("COMMANDS"sv)); +#ifdef WASMEDGE_BUILD_WASI_NN_RPC +PO::Option WasiNNEnvironment::NNRPCURI( + PO::Description("Specify NN RPC URI to connect (\"unix://...\")"sv), + PO::MetaVar("URI"sv), PO::DefaultValue(std::string(""))); +#endif + void addOptions(const Plugin::Plugin::PluginDescriptor *, PO::ArgumentParser &Parser) noexcept { Parser.add_option("nn-preload"sv, WasiNNEnvironment::NNModels); +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (getenv("_WASI_NN_RPCSERVER") == nullptr) { + // RPC client mode + Parser.add_option("nn-rpc-uri"sv, WasiNNEnvironment::NNRPCURI); + } +#endif } Plugin::Plugin::PluginDescriptor Descriptor{ diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 7d919b5c..00c8df0f 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -17,6 +17,11 @@ #include "torch.h" #include "types.h" +#ifdef WASMEDGE_BUILD_WASI_NN_RPC +#include +#include +#endif + namespace WasmEdge { namespace Host { namespace WASINN { @@ -197,7 +202,10 @@ struct WasiNNEnvironment : std::vector NNGraph; std::vector NNContext; static PO::List NNModels; - +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + static PO::Option NNRPCURI; // For RPC client mode + std::shared_ptr NNRPCChannel; +#endif static Plugin::PluginRegister Register; }; diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 3be2bca8..6ccf2b45 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -8,6 +8,12 @@ #include #include +#ifdef WASMEDGE_BUILD_WASI_NN_RPC +#include "wasi_ephemeral_nn.grpc.pb.h" + +#include +#endif // #ifdef WASMEDGE_BUILD_WASI_NN_RPC + namespace WasmEdge { namespace Host { @@ -36,6 +42,13 @@ Expect WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, uint32_t BuilderLen, uint32_t RawEncoding, uint32_t Target, uint32_t GraphIdPtr) { +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + // TODO: implement RPC for Load + spdlog::error("[WASI-NN] RPC client is not implemented for Load"sv); + return WASINN::ErrNo::UnsupportedOperation; + } +#endif // Check memory instance from module. auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { @@ -113,6 +126,25 @@ WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, return WASINN::ErrNo::InvalidArgument; } +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::Graph::NewStub(Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::LoadByNameRequest Req; + auto NameStrView = MemInst->getStringView(NamePtr, NameLen); + Req.set_name(NameStrView.data(), NameStrView.size()); + wasi_ephemeral_nn::LoadByNameResult Res; + auto Status = Stub->LoadByName(&ClientContext, Req, &Res); + if (!Status.ok()) { + spdlog::error("[WASI-NN] Failed when calling remote LoadByName: {}"sv, + Status.error_message()); + return WASINN::ErrNo::RuntimeError; + } + *GraphId = Res.graph_handle(); + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + // Get the model std::string ModelName(reinterpret_cast(Name), NameLen); if (Env.mdGet(ModelName, *GraphId)) { @@ -125,6 +157,14 @@ WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, Expect WasiNNLoadByNameWithConfig::bodyImpl( const Runtime::CallingFrame &Frame, uint32_t NamePtr, uint32_t NameLen, uint32_t ConfigPtr, uint32_t ConfigLen, uint32_t GraphIdPtr) { +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + // TODO: implement RPC for LoadByNameWithConfig + spdlog::error( + "[WASI-NN] RPC client is not implemented for LoadByNameWithConfig"sv); + return WASINN::ErrNo::UnsupportedOperation; + } +#endif auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -172,12 +212,6 @@ WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - if (Env.NNGraph.size() <= GraphId) { - spdlog::error( - "[WASI-NN] init_execution_context: Graph Id does not exist."sv); - return WASINN::ErrNo::InvalidArgument; - } - // Check the return value: Context should be valid. uint32_t *Context = MemInst->getPointer(ContextPtr); if (unlikely(Context == nullptr)) { @@ -185,6 +219,31 @@ WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, return WASINN::ErrNo::InvalidArgument; } +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::GraphResource::NewStub(Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::InitExecutionContextRequest Req; + Req.set_resource_handle(GraphId); + wasi_ephemeral_nn::InitExecutionContextResult Res; + auto Status = Stub->InitExecutionContext(&ClientContext, Req, &Res); + if (!Status.ok()) { + spdlog::error( + "[WASI-NN] Failed when calling remote InitExecutionContext: {}"sv, + Status.error_message()); + return WASINN::ErrNo::RuntimeError; + } + *Context = Res.ctx_handle(); + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + + if (Env.NNGraph.size() <= GraphId) { + spdlog::error( + "[WASI-NN] init_execution_context: Graph Id does not exist."sv); + return WASINN::ErrNo::InvalidArgument; + } + switch (const auto Backend = Env.NNGraph[GraphId].getBackend()) { #define EACH(B) \ case WASINN::Backend::B: \ @@ -205,11 +264,6 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, return Unexpect(ErrCode::Value::HostFuncError); } - if (Env.NNContext.size() <= Context) { - spdlog::error("[WASI-NN] set_input: Execution Context does not exist."sv); - return WASINN::ErrNo::InvalidArgument; - } - // Tensor's Layout: // | dim buf | dim buf len | rtype | data buf | data buf len | struct WasiTensorData { @@ -225,6 +279,7 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, spdlog::error("[WASI-NN] Failed when accessing the Tensor memory."sv); return WASINN::ErrNo::InvalidArgument; } + WASINN::TensorData Tensor; Tensor.Dimension = MemInst->getSpan(WasiTensor->DimensionPtr, WasiTensor->DimensionLen); @@ -252,6 +307,37 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, return WASINN::ErrNo::InvalidArgument; } +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( + Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::SetInputRequest Req; + Req.set_resource_handle(Context); + Req.set_index(Index); + wasi_ephemeral_nn::Tensor RPCTensor; + RPCTensor.mutable_dimensions()->Add(Tensor.Dimension.begin(), + Tensor.Dimension.end()); + RPCTensor.set_ty(wasi_ephemeral_nn::TensorType(Tensor.RType)); + RPCTensor.set_data(MemInst->getPointer(WasiTensor->TensorPtr), + WasiTensor->TensorLen); + *Req.mutable_tensor() = RPCTensor; + google::protobuf::Empty Res; + auto Status = Stub->SetInput(&ClientContext, Req, &Res); + if (!Status.ok()) { + spdlog::error("[WASI-NN] Failed when calling remote SetInput: {}"sv, + Status.error_message()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + + if (Env.NNContext.size() <= Context) { + spdlog::error("[WASI-NN] set_input: Execution Context does not exist."sv); + return WASINN::ErrNo::InvalidArgument; + } + switch (const auto Backend = Env.NNContext[Context].getBackend()) { #define EACH(B) \ case WASINN::Backend::B: \ @@ -273,11 +359,6 @@ WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, return Unexpect(ErrCode::Value::HostFuncError); } - if (Env.NNContext.size() <= Context) { - spdlog::error("[WASI-NN] get_output: Execution Context does not exist"sv); - return WASINN::ErrNo::InvalidArgument; - } - const auto OutBuffer = MemInst->getSpan(OutBufferPtr, OutBufferMaxSize); if (unlikely(OutBuffer.data() == nullptr)) { @@ -291,6 +372,34 @@ WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, return WASINN::ErrNo::InvalidArgument; } +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( + Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::GetOutputRequest Req; + Req.set_resource_handle(Context); + Req.set_index(Index); + wasi_ephemeral_nn::GetOutputResult Res; + auto Status = Stub->GetOutput(&ClientContext, Req, &Res); + if (!Status.ok()) { + spdlog::error("[WASI-NN] Failed when calling remote GetOutput: {}"sv, + Status.error_message()); + return WASINN::ErrNo::RuntimeError; + } + uint32_t BytesWrittenVal = + std::min(static_cast(Res.data().size()), OutBufferMaxSize); + std::copy_n(Res.data().begin(), BytesWrittenVal, OutBuffer.begin()); + *BytesWritten = BytesWrittenVal; + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + + if (Env.NNContext.size() <= Context) { + spdlog::error("[WASI-NN] get_output: Execution Context does not exist"sv); + return WASINN::ErrNo::InvalidArgument; + } + switch (const auto Backend = Env.NNContext[Context].getBackend()) { #define EACH(B) \ case WASINN::Backend::B: \ @@ -307,6 +416,14 @@ Expect WasiNNGetOutputSingle::bodyImpl( const Runtime::CallingFrame &Frame, uint32_t Context, uint32_t Index, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + // TODO: implement RPC for GetOutputSingle + spdlog::error( + "[WASI-NN] RPC client is not implemented for GetOutputSingle"sv); + return WASINN::ErrNo::UnsupportedOperation; + } +#endif auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -344,6 +461,23 @@ Expect WasiNNGetOutputSingle::bodyImpl( Expect WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( + Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::ComputeRequest Req; + Req.set_resource_handle(Context); + google::protobuf::Empty Res; + auto Status = Stub->Compute(&ClientContext, Req, &Res); + if (!Status.ok()) { + spdlog::error("[WASI-NN] Failed when calling remote Compute: {}"sv, + Status.error_message()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -369,6 +503,14 @@ WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { Expect WasiNNComputeSingle::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + // TODO: implement RPC for ComputeSingle + spdlog::error( + "[WASI-NN] RPC client is not implemented for ComputeSingle"sv); + return WASINN::ErrNo::UnsupportedOperation; + } +#endif auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -393,6 +535,13 @@ WasiNNComputeSingle::bodyImpl(const Runtime::CallingFrame &Frame, Expect WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + // TODO: implement RPC for FiniSingle + spdlog::error("[WASI-NN] RPC client is not implemented for FiniSingle"sv); + return WASINN::ErrNo::UnsupportedOperation; + } +#endif auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index e7eb03c5..fe68a8fa 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -100,3 +100,11 @@ else() endif() add_test(wasiNNTests wasiNNTests) + +if(WASMEDGE_BUILD_WASI_NN_RPC) + add_definitions(-DWASMEDGE_BUILD_WASI_NN_RPC) + target_link_libraries(wasiNNTests + PRIVATE + wasiNNRPC + ) +endif() \ No newline at end of file diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 58d6ea7c..02857651 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -21,12 +21,18 @@ using WasmEdge::Host::WASINN::ErrNo; defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML) namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { +WasmEdge::Runtime::Instance::ModuleInstance * +createModule(std::string_view NNRPCURI = "") { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "../../../plugins/wasi_nn/" "libwasmedgePluginWasiNN" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_nn"sv)) { + WasmEdge::PO::ArgumentParser Parser; + Plugin->registerOptions(Parser); + if (NNRPCURI != "") { + Parser.set_raw_value("nn-rpc-uri"sv, std::string(NNRPCURI)); + } if (const auto *Module = Plugin->findModule("wasi_nn"sv)) { return Module->create().release(); } @@ -1465,4 +1471,208 @@ TEST(WasiNNTest, GGMLBackend) { EXPECT_GE(BytesWritten, 50); } } +#ifdef WASMEDGE_BUILD_WASI_NN_RPC +TEST(WasiNNTest, GGMLBackendWithRPC) { + // wasi_nn_rpcserver has to be started outside this test, + // and the URI has to be set to $WASI_NN_RPC_TEST_URI. + // nn-preload has to be specified for "default". + /* + DIR=/tmp/build + export WASI_NN_RPC_TEST_URI=unix://${DIR}/wasi_nn_rpc.sock + export WASMEDGE_PLUGIN_PATH=${DIR}/plugins/wasi_nn + ${DIR}/tools/wasmedge/wasi_nn_rpcserver \ + --nn-rpc-uri=$WASI_NN_RPC_TEST_URI \ + --nn-preload=default:GGML:AUTO:${DIR}/test/plugins/wasi_nn/wasinn_ggml_fixtures/orca_mini.gguf + */ + const auto NNRPCURI = ::getenv("WASI_NN_RPC_TEST_URI"); + if (NNRPCURI == nullptr) { + GTEST_SKIP() << "WASI_NN_RPC_TEST_URI is unset"; + } + + // Create the wasmedge_process module instance. + auto *NNMod = + dynamic_cast(createModule(NNRPCURI)); + EXPECT_FALSE(NNMod == nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + std::string Prompt = "Once upon a time, "; + std::vector TensorData(Prompt.begin(), Prompt.end()); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load_by_name". + auto FuncInst = NNMod->findFuncExports("load_by_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoadByName = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // Test: load_by_name -- load successfully. + { + std::string Name = "default"; + std::vector NameVec(Name.begin(), Name.end()); + writeBinaries(MemInst, NameVec, LoadEntryPtr); + EXPECT_TRUE(HostFuncLoadByName.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, static_cast(NameVec.size()), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // GGML WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: init_execution_context -- init context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // GGML WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: set_input -- set input successfully. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // GGML WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: compute -- compute until finish or context full. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + // FIXME: ErrNo propagation is not supported yet + // EXPECT_TRUE( + // Errno[0].get() == static_cast(ErrNo::Success) + // || Errno[0].get() == + // static_cast(ErrNo::ContextFull)); + } + + // GGML WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 50 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 50); + } +} +#endif // WASMEDGE_BUILD_WASI_NN_RPC #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML From a61467215ab346148710208e5549f057d67b6af5 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 11 Jan 2024 14:14:39 +0800 Subject: [PATCH 219/623] [WASI-NN] ggml: update ggml to b1808 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 9411b092..5d8ebf20 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -47,7 +47,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b1743 + GIT_TAG b1808 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From 97bbade06591c43b93ddefe5574ddfc2c7c30dbc Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 16 Jan 2024 15:26:49 +0800 Subject: [PATCH 220/623] [WASI-NN] ggml backend: bump to llama.cpp b1879 (#3158) Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 5d8ebf20..575e6b20 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -47,7 +47,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b1808 + GIT_TAG b1879 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From a5ff6ffc358cc23af2e2c16b9c115e094f4bfd4f Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 17 Jan 2024 15:12:45 +0800 Subject: [PATCH 221/623] [WASI-NN] ggml: return version in metadata Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 1cf8ecf6..581f253f 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -147,14 +147,21 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, Expect buildOutputMetadata(Context &CxtRef, std::string &Metadata) noexcept { - std::string MetadataTemplate = R"({"input_tokens": %d, "output_tokens": %d})"; + std::string MetadataTemplate = + R"({"input_tokens": %d, "output_tokens": %d, "llama_build_number": %d, "llama_commit": "%s"})"; // The 20 bytes are reserved to accommodate two %d placeholders in the // MetadataTemplate. This allows for a decimal integer value up to a // 12-digit number of input/output tokens. - char Buffer[MetadataTemplate.size() + 20]; + // The 3 bytes are reserved to accommodate the %d placeholder for the build + // number. Allows for a decimal integer value up to a 5-digit number. + // The 5 bytes are reserved to accommodate the %s placeholder for the commit + // hash. The commit hash is 7 bytes long by default using `git rev-parse + // --short HEAD`. + char Buffer[MetadataTemplate.size() + 20 + 3 + 5]; snprintf(Buffer, sizeof(Buffer), MetadataTemplate.c_str(), - CxtRef.LlamaInputs.size(), CxtRef.LlamaOutputTokens.size()); + CxtRef.LlamaInputs.size(), CxtRef.LlamaOutputTokens.size(), + LLAMA_BUILD_NUMBER, LLAMA_COMMIT); Metadata = std::string(Buffer); return ErrNo::Success; From f68ebe67849c2e792060e724c08ac9a563aaaf79 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 18 Jan 2024 11:27:22 +0800 Subject: [PATCH 222/623] [WASI-NN] ggml: count n-predict correctly Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 581f253f..7c99c3d4 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -481,7 +481,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } // Main predict loop. - while (NRemain >= 0) { + while (NRemain > 0) { // Preidct if (!Embd.empty()) { // Input too long. From b34aefd0fcea8608b297b5e24e010c5034e38b96 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 22 Jan 2024 20:54:20 +0800 Subject: [PATCH 223/623] [WASI-NN] ggml: remove fixed ngl when using Metal on macOS (#3165) Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 7c99c3d4..54ba2fb6 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -87,11 +87,6 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } -#ifdef __APPLE__ - // Whatever the `n-gpu-layers` is given, we will always set the ngl to 1 on - // macOS to forcely enabled Metal. - GraphRef.NGPULayers = 1; // Force enabled Metal on macOS -#endif // The context parameters. if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { From da71661c6d13307609204377dc8d7a0541724dab Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 23 Jan 2024 02:26:22 +0800 Subject: [PATCH 224/623] [WASI-NN] ggml backend: bump llama.cpp b1953 Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 575e6b20..ee1e9ad1 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -47,7 +47,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b1879 + GIT_TAG b1953 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From f858e9c89d9b4e1f8130bf0a790cda4227336eab Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 24 Jan 2024 19:11:27 +0800 Subject: [PATCH 225/623] [WASI-NN] ggml: support top_p and presence_penalty (#3174) Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 23 +++++++++++++++++++++++ plugins/wasi_nn/ggml.h | 4 +++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 54ba2fb6..dee8207c 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -123,6 +123,14 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } GraphRef.Temp = std::max(0.0, GraphRef.Temp); } + if (Doc.at_key("top-p").error() == simdjson::SUCCESS) { + auto Err = Doc["top-p"].get().get(GraphRef.TopP); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the top-p option."sv); + return ErrNo::InvalidArgument; + } + } if (Doc.at_key("repeat-penalty").error() == simdjson::SUCCESS) { auto Err = Doc["repeat-penalty"].get().get(GraphRef.RepeatPenalty); if (Err) { @@ -131,6 +139,15 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } + if (Doc.at_key("presence-penalty").error() == simdjson::SUCCESS) { + auto Err = + Doc["presence-penalty"].get().get(GraphRef.PresencePenalty); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the presence-penalty option."sv); + return ErrNo::InvalidArgument; + } + } // Check if the model is updated. if (IsModelUpdated && ModelParams.n_gpu_layers != GraphRef.NGPULayers) { @@ -189,7 +206,9 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize the sampling parameters. llama_sampling_params SamplingDefault; GraphRef.Temp = SamplingDefault.temp; + GraphRef.TopP = SamplingDefault.top_p; GraphRef.RepeatPenalty = SamplingDefault.penalty_repeat; + GraphRef.PresencePenalty = SamplingDefault.penalty_present; // If the graph builder length > 1, the data of builder[1] is the metadata. if (Builders.size() > 1) { @@ -442,7 +461,9 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } gpt_params GPTParams; GPTParams.sparams.temp = GraphRef.Temp; + GPTParams.sparams.top_p = GraphRef.TopP; GPTParams.sparams.penalty_repeat = GraphRef.RepeatPenalty; + GPTParams.sparams.penalty_present = GraphRef.PresencePenalty; struct llama_sampling_context *CtxSampling = llama_sampling_init(GPTParams.sparams); std::vector Embd; @@ -649,7 +670,9 @@ Expect computeSingle(WasiNNEnvironment &Env, // Initialize the llama context. gpt_params GPTParams; GPTParams.sparams.temp = GraphRef.Temp; + GPTParams.sparams.top_p = GraphRef.TopP; GPTParams.sparams.penalty_repeat = GraphRef.RepeatPenalty; + GPTParams.sparams.penalty_present = GraphRef.PresencePenalty; CxtRef.LlamaSampling = llama_sampling_init(GPTParams.sparams); llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = GraphRef.CtxSize; diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index bd94df5f..d00836c0 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -33,9 +33,11 @@ struct Graph { uint64_t CtxSize; uint64_t BatchSize; uint64_t Threads; - // Sampleing parameters: + // Sampling parameters: double Temp; + double TopP; double RepeatPenalty; + double PresencePenalty; }; struct Context { From 4319bbc8de176096546d4873d776b2d7670381ed Mon Sep 17 00:00:00 2001 From: vincent Date: Sun, 21 Jan 2024 19:59:24 +0800 Subject: [PATCH 226/623] [WASI-NN] openvino: upgrade to 2.0 (#3140) Signed-off-by: vincent --- plugins/wasi_nn/openvino.cpp | 357 ++++++------------------------ plugins/wasi_nn/openvino.h | 122 +--------- test/plugins/wasi_nn/wasi_nn.cpp | 6 +- utils/wasi-nn/install-openvino.sh | 2 +- 4 files changed, 82 insertions(+), 405 deletions(-) diff --git a/plugins/wasi_nn/openvino.cpp b/plugins/wasi_nn/openvino.cpp index dcdf8eda..9c0625ac 100644 --- a/plugins/wasi_nn/openvino.cpp +++ b/plugins/wasi_nn/openvino.cpp @@ -10,12 +10,6 @@ namespace WasmEdge::Host::WASINN::OpenVINO { Expect load(WASINN::WasiNNEnvironment &Env, Span> Builders, WASINN::Device Device, uint32_t &GraphId) noexcept { - // The OpenVINO core must be initialized in constructor. - if (unlikely(Env.OpenVINOCore == nullptr)) { - spdlog::error("[WASI-NN] OpenVINO core not initialized."); - return WASINN::ErrNo::MissingMemory; - } - // The graph builder length must be 2. if (Builders.size() != 2) { spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 2", @@ -33,134 +27,25 @@ Expect load(WASINN::WasiNNEnvironment &Env, Env.NNGraph.emplace_back(Backend::OpenVINO); auto &GraphRef = Env.NNGraph.back().get(); - // Create the weights blob memory. - tensor_desc_t WeightsDesc{ - layout_e::ANY, {1, {Weight.size()}}, precision_e::U8}; - IEStatusCode Status = - ie_blob_make_memory(&WeightsDesc, &(GraphRef.OpenVINOWeightBlob)); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to create the model's weight blob, error code: {}", - Status); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::Busy; - } - - // Copy the weights buffer to the blob. - ie_blob_buffer_t BlobBuffer; - Status = ie_blob_get_buffer(GraphRef.OpenVINOWeightBlob, &BlobBuffer); - if (unlikely(Status != IEStatusCode::OK)) { - spdlog::error( - "[WASI-NN] Unable to find the weight blob's buffer, error code: {}", - Status); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::MissingMemory; - } - std::copy_n(Weight.data(), Weight.size(), - static_cast(BlobBuffer.buffer)); - - // Read network from memory. - Status = ie_core_read_network_from_memory( - Env.OpenVINOCore, XML.data(), XML.size(), GraphRef.OpenVINOWeightBlob, - &(GraphRef.OpenVINONetwork)); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to read network from the XML and " - "Weights, error code: {}", - Status); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::Busy; - } - - // Get the network input and output size. - size_t NetworkInputSize = 0; - Status = - ie_network_get_inputs_number(GraphRef.OpenVINONetwork, &NetworkInputSize); - if (unlikely(Status != IEStatusCode::OK)) { - spdlog::error("[WASI-NN] Unable to get the inputs number from the " - "network, error code: {}", - Status); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::MissingMemory; - } - spdlog::debug("[WASI-NN] Got input size: {}", NetworkInputSize); - size_t NetworkOutputSize = 0; - Status = ie_network_get_outputs_number(GraphRef.OpenVINONetwork, - &NetworkOutputSize); - if (unlikely(Status != IEStatusCode::OK)) { - spdlog::error("[WASI-NN] Unable to get the outputs number from the " - "network, error code: {}", - Status); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::MissingMemory; - } - spdlog::debug("[WASI-NN] Got output size: {}", NetworkOutputSize); - - // Get and store the input and output names. - GraphRef.OpenVINOInputNames.resize(NetworkInputSize, nullptr); - for (size_t I = 0; I < NetworkInputSize; I++) { - Status = ie_network_get_input_name(GraphRef.OpenVINONetwork, I, - &(GraphRef.OpenVINOInputNames[I])); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to find input name correctly with " - "Index {}, error code: {}", - I, Status); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::MissingMemory; - } - spdlog::debug("[WASI-NN] Got input name: {}", - GraphRef.OpenVINOInputNames[I]); - } - GraphRef.OpenVINOOutputNames.resize(NetworkOutputSize, nullptr); - for (size_t I = 0; I < NetworkOutputSize; I++) { - Status = ie_network_get_output_name(GraphRef.OpenVINONetwork, I, - &(GraphRef.OpenVINOOutputNames[I])); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to find output name correctly with " - "Index {}, error code: {}", - I, Status); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::MissingMemory; - } - spdlog::debug("[WASI-NN] Got output name: {}", - GraphRef.OpenVINOOutputNames[I]); - } - - // Set the input layout. - // FIXME: this is a temporary workaround. We need a more eligant way to - // specify the layout in the long run. However, without this newer versions - // of OpenVINO will fail due to parameter mismatch. - for (size_t I = 0; I < NetworkInputSize; I++) { - // More layouts should be supported. - Status = ie_network_set_input_layout(GraphRef.OpenVINONetwork, - GraphRef.OpenVINOInputNames[I], - layout_e::NHWC); - spdlog::debug("[WASI-NN] Setting [{}] to NHWC", - GraphRef.OpenVINOInputNames[I]); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to set input layout with the input " - "name {}, error code: {}", - GraphRef.OpenVINOInputNames[I], Status); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::MissingMemory; - } - } - - // Load network. - ie_config_t Config = {nullptr, nullptr, nullptr}; - Status = ie_core_load_network(Env.OpenVINOCore, GraphRef.OpenVINONetwork, - fmt::format("{}"sv, Device).c_str(), &Config, - &GraphRef.OpenVINOExecNetwork); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to create executable Network, error code: {}", - Status); + // Store device information + GraphRef.TargetDevice = Device; + + try { + std::string ModelString(reinterpret_cast(XML.data()), + XML.size()); + GraphRef.OpenVINOIWeightTensor = + ov::Tensor(ov::element::Type_t::u8, {Weight.size()}); + std::copy_n(Weight.data(), Weight.size(), + static_cast(GraphRef.OpenVINOIWeightTensor.data())); + GraphRef.OpenVINOModel = Env.OpenVINOCore.read_model( + ModelString, GraphRef.OpenVINOIWeightTensor); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Model Load Exception: {}", EX.what()); Env.NNGraph.pop_back(); - return WASINN::ErrNo::Busy; + return WASINN::ErrNo::RuntimeError; } - // Store the loaded graph. GraphId = Env.NNGraph.size() - 1; - return WASINN::ErrNo::Success; } @@ -169,21 +54,12 @@ Expect initExecCtx(WASINN::WasiNNEnvironment &Env, uint32_t &ContextId) noexcept { // Check the network and the execution network with the graph ID. auto &GraphRef = Env.NNGraph[GraphId].get(); - if (GraphRef.OpenVINONetwork == nullptr || - GraphRef.OpenVINOExecNetwork == nullptr) { + if (GraphRef.OpenVINOModel == nullptr) { spdlog::error("[WASI-NN] Model for Graph:{} is empty!", GraphId); return WASINN::ErrNo::MissingMemory; } - // Create context. Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); - auto &CtxRef = Env.NNContext.back().get(); - if (CtxRef.OpenVINOInferRequest == nullptr) { - spdlog::error("[WASI-NN] Unable to create openvino context"); - Env.NNContext.pop_back(); - return WASINN::ErrNo::Busy; - } - ContextId = Env.NNContext.size() - 1; return WASINN::ErrNo::Success; } @@ -193,26 +69,17 @@ Expect setInput(WASINN::WasiNNEnvironment &Env, const TensorData &Tensor) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - // Check the infer request and the network. - auto *Network = GraphRef.OpenVINONetwork; - if (Network == nullptr || CxtRef.OpenVINOInferRequest == nullptr) { + + if (GraphRef.OpenVINOModel == nullptr) { spdlog::error("[WASI-NN] The founded openvino session is empty"); return WASINN::ErrNo::MissingMemory; } - // Check the input index. - if (GraphRef.OpenVINOInputNames.size() <= Index) { - spdlog::error("[WASI-NN] The input index {} exceeds the inputs number {}.", - Index, GraphRef.OpenVINOInputNames.size()); - return WASINN::ErrNo::InvalidArgument; - } - char *InputName = GraphRef.OpenVINOInputNames[Index]; - if (Tensor.Dimension.size() > 8) { - spdlog::error( - "[WASI-NN] Tensor dimension is out of range, expect it under 8-dim, " - "but got {}-dim.", - Tensor.Dimension.size()); + spdlog::error("[WASI-NN] Tensor dimension is out of range, expect " + "it under 8-dim, " + "but got {}-dim.", + Tensor.Dimension.size()); return WASINN::ErrNo::InvalidArgument; } if (Tensor.RType != WASINN::TensorType::F32) { @@ -221,96 +88,38 @@ Expect setInput(WASINN::WasiNNEnvironment &Env, return WASINN::ErrNo::InvalidArgument; } - // Set the input resize algorithm. - // Mark the input as resizable by setting a resize algorithm. - // In this case we will be able to set an input blob of any shape to an - // infer request. Resizing and layout conversions are executed automatically - // when inferring. - IEStatusCode Status = ie_network_set_input_resize_algorithm( - Network, InputName, RESIZE_BILINEAR); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to set input resize correctly, error code: {}", - Status); - return WASINN::ErrNo::InvalidArgument; - } - - // Set the input layout. - // More layouts should be supported. - Status = ie_network_set_input_layout(Network, InputName, layout_e::NHWC); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to set input layout correctly, error code: {}", - Status); - return WASINN::ErrNo::InvalidArgument; - } - - // Set the input precision. - // More types should be supported. - Status = - ie_network_set_input_precision(Network, InputName, precision_e::FP32); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to set input precision correctly, error code: {}", - Status); + // Check the input index. + if (GraphRef.OpenVINOModel->inputs().size() <= Index) { + spdlog::error("[WASI-NN] The input index {} exceeds the inputs number {}.", + Index, GraphRef.OpenVINOModel->inputs().size()); return WASINN::ErrNo::InvalidArgument; } - // Set the dimensions and the tensor description. - dimensions_t Dimens; - Dimens.ranks = Tensor.Dimension.size(); - for (size_t I = 0; I < Dimens.ranks; I++) { - Dimens.dims[I] = static_cast(Tensor.Dimension[I]); - } - tensor_desc_t TensorDesc = {layout_e::NHWC, Dimens, precision_e::FP32}; - - // Create the input blob memory. - ie_blob_t *InputBlob = nullptr; - Status = ie_blob_make_memory(&TensorDesc, &InputBlob); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to allocated input tensor correctly, " - "error code: {}", - Status); - return WASINN::ErrNo::Busy; + try { + ov::element::Type InputType = ov::element::f32; + ov::Shape InputShape = {1, 224, 224, 3}; + ov::Tensor InputTensor = + ov::Tensor(InputType, InputShape, Tensor.Tensor.data()); + const ov::Layout InputLayout{"NHWC"}; + ov::preprocess::PrePostProcessor PPP(GraphRef.OpenVINOModel); + PPP.input() + .tensor() + .set_shape(InputShape) + .set_element_type(InputType) + .set_layout(InputLayout); + PPP.input().preprocess().resize( + ov::preprocess::ResizeAlgorithm::RESIZE_LINEAR); + PPP.input().model().set_layout("NCHW"); + PPP.output().tensor().set_element_type(ov::element::f32); + auto model = PPP.build(); + ov::CompiledModel CompiledModel = + Env.OpenVINOCore.compile_model(model, "CPU"); + CxtRef.OpenVINOInferRequest = CompiledModel.create_infer_request(); + CxtRef.OpenVINOInferRequest.set_input_tensor(Index, InputTensor); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Set Input Exception: {}", EX.what()); + return WASINN::ErrNo::RuntimeError; } - - // Get the blob buffer size and compare with the tensor size. - int BlobSize; - Status = ie_blob_size(InputBlob, &BlobSize); - if (unlikely(Status != IEStatusCode::OK)) { - spdlog::error("[WASI-NN] Unable to get the input blob size, error code: {}", - Status); - return WASINN::ErrNo::Busy; - } - if (unlikely(static_cast(BlobSize * 4) != Tensor.Tensor.size())) { - spdlog::error("[WASI-NN] Blob size {} and the Tensor size {} not matched.", - BlobSize * 4, Tensor.Tensor.size()); - } - - // Copy the data into the input blob buffer. - ie_blob_buffer_t BlobBuffer; - Status = ie_blob_get_buffer(InputBlob, &BlobBuffer); - if (unlikely(Status != IEStatusCode::OK)) { - spdlog::error("[WASI-NN] Unable to find input tensor buffer"); - ie_blob_free(&InputBlob); - return WASINN::ErrNo::MissingMemory; - } - std::copy_n(Tensor.Tensor.data(), Tensor.Tensor.size(), - static_cast(BlobBuffer.buffer)); - - // Set input blob. - Status = ie_infer_request_set_blob(CxtRef.OpenVINOInferRequest, InputName, - InputBlob); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to set input tensor to model correctly, " - "error code: {}", - Status); - ie_blob_free(&InputBlob); - return WASINN::ErrNo::Busy; - } - - ie_blob_free(&InputBlob); - return WASINN::ErrNo::Success; } @@ -320,70 +129,36 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - auto *Network = GraphRef.OpenVINONetwork; // Check the output index. - if (GraphRef.OpenVINOOutputNames.size() <= Index) { + if (GraphRef.OpenVINOModel->outputs().size() <= Index) { spdlog::error( "[WASI-NN] The output index {} exceeds the outputs number {}.", Index, - GraphRef.OpenVINOOutputNames.size()); - return WASINN::ErrNo::InvalidArgument; - } - char *OutputName = GraphRef.OpenVINOOutputNames[Index]; - - // Set output precision. - IEStatusCode Status = - ie_network_set_output_precision(Network, OutputName, precision_e::FP32); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to set output precision correctly with Index:{}", - Index); + GraphRef.OpenVINOModel->outputs().size()); return WASINN::ErrNo::InvalidArgument; } - // Get output blob buffer. - ie_blob_t *OutputBlob = nullptr; - Status = ie_infer_request_get_blob(CxtRef.OpenVINOInferRequest, OutputName, - &OutputBlob); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to retrieve output tensor correctly", - Index); - return WASINN::ErrNo::InvalidArgument; - } - - // Get the blob size and copy the output buffer. - int BlobSize; - Status = ie_blob_size(OutputBlob, &BlobSize); - ie_blob_buffer_t BlobCBuffer; - Status = ie_blob_get_cbuffer(OutputBlob, &BlobCBuffer); - if (Status != IEStatusCode::OK) { - spdlog::error("[WASI-NN] Unable to retrieve output tensor correctly", - Index); - ie_blob_free(&OutputBlob); - return WASINN::ErrNo::MissingMemory; + try { + const ov::Tensor &OutputTensor = + CxtRef.OpenVINOInferRequest.get_output_tensor(Index); + BytesWritten = OutputTensor.get_byte_size(); + std::copy_n(static_cast(OutputTensor.data()), BytesWritten, + OutBuffer.data()); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Get Output Exception: {}", EX.what()); + return WASINN::ErrNo::RuntimeError; } - uint32_t BytesToWrite = - std::min(static_cast(BlobSize * 4), OutBuffer.size()); - std::copy_n(static_cast(BlobCBuffer.cbuffer), BytesToWrite, - OutBuffer.data()); - - // Write the bytes written result. - BytesWritten = BytesToWrite; - - ie_blob_free(&OutputBlob); - return WASINN::ErrNo::Success; } Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); - IEStatusCode Status = ie_infer_request_infer(CxtRef.OpenVINOInferRequest); - if (Status != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Unable to perform computation correctly, error code: {}", - Status); - return WASINN::ErrNo::Busy; + try { + CxtRef.OpenVINOInferRequest.infer(); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Infer Request Exception: {}", EX.what()); + return WASINN::ErrNo::RuntimeError; } return WASINN::ErrNo::Success; } diff --git a/plugins/wasi_nn/openvino.h b/plugins/wasi_nn/openvino.h index 9a1417a4..0ae3a2a7 100644 --- a/plugins/wasi_nn/openvino.h +++ b/plugins/wasi_nn/openvino.h @@ -7,65 +7,7 @@ #include "types.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO -#include "common/log.h" -#include -#include - -template <> -struct fmt::formatter : fmt::formatter { - fmt::format_context::iterator format(IEStatusCode Code, - fmt::format_context &Ctx) const { - using namespace std::literals; - std::string_view Name; - switch (Code) { - case OK: - Name = "OK"sv; - break; - case GENERAL_ERROR: - Name = "GENERAL_ERROR"sv; - break; - case NOT_IMPLEMENTED: - Name = "NOT_IMPLEMENTED"sv; - break; - case NETWORK_NOT_LOADED: - Name = "NETWORK_NOT_LOADED"sv; - break; - case PARAMETER_MISMATCH: - Name = "PARAMETER_MISMATCH"sv; - break; - case NOT_FOUND: - Name = "NOT_FOUND"sv; - break; - case OUT_OF_BOUNDS: - Name = "OUT_OF_BOUNDS"sv; - break; - case UNEXPECTED: - Name = "UNEXPECTED"sv; - break; - case REQUEST_BUSY: - Name = "REQUEST_BUSY"sv; - break; - case RESULT_NOT_READY: - Name = "RESULT_NOT_READY"sv; - break; - case NOT_ALLOCATED: - Name = "NOT_ALLOCATED"sv; - break; - case INFER_NOT_STARTED: - Name = "INFER_NOT_STARTED"sv; - break; - case NETWORK_NOT_READ: - Name = "NETWORK_NOT_READ"sv; - break; - case INFER_CANCELLED: - Name = "INFER_CANCELLED"sv; - break; - default: - Name = "Unknown"sv; - } - return fmt::formatter::format(Name, Ctx); - } -}; +#include "openvino/openvino.hpp" #endif namespace WasmEdge::Host::WASINN { @@ -75,65 +17,23 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::OpenVINO { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO struct Graph { - ~Graph() noexcept { - if (OpenVINONetwork) { - ie_network_free(&OpenVINONetwork); - } - if (OpenVINOExecNetwork) { - ie_exec_network_free(&OpenVINOExecNetwork); - } - if (OpenVINOWeightBlob) { - ie_blob_free(&OpenVINOWeightBlob); - } - for (auto &I : OpenVINOInputNames) { - if (I) { - ie_network_name_free(&I); - } - } - for (auto &I : OpenVINOOutputNames) { - if (I) { - ie_network_name_free(&I); - } - } - } - ie_network_t *OpenVINONetwork = nullptr; - ie_executable_network_t *OpenVINOExecNetwork = nullptr; - ie_blob_t *OpenVINOWeightBlob = nullptr; - std::vector OpenVINOInputNames; - std::vector OpenVINOOutputNames; + ~Graph() noexcept {} + ov::Tensor OpenVINOIWeightTensor; + std::shared_ptr OpenVINOModel; + Device TargetDevice = Device::AUTO; }; struct Context { - Context(size_t GId, Graph &G) noexcept : GraphId(GId) { - IEStatusCode Status = ie_exec_network_create_infer_request( - G.OpenVINOExecNetwork, &OpenVINOInferRequest); - if (Status != IEStatusCode::OK) { - OpenVINOInferRequest = nullptr; - spdlog::error("[WASI-NN] Unable to create infer request for OpenVINO"); - } - } - ~Context() noexcept { - if (OpenVINOInferRequest) { - ie_infer_request_free(&OpenVINOInferRequest); - } - } + Context(size_t GId, Graph &) noexcept : GraphId(GId) {} + ~Context() noexcept {} size_t GraphId; - ie_infer_request_t *OpenVINOInferRequest = nullptr; + ov::InferRequest OpenVINOInferRequest; }; struct Environ { - Environ() noexcept { - if (ie_core_create("", &OpenVINOCore) != IEStatusCode::OK) { - spdlog::error( - "[WASI-NN] Error happened when initializing OpenVINO core."); - } - } - ~Environ() noexcept { - if (OpenVINOCore) { - ie_core_free(&OpenVINOCore); - } - } - ie_core_t *OpenVINOCore = nullptr; + Environ() noexcept {} + ~Environ() noexcept {} + ov::Core OpenVINOCore; }; #else struct Graph {}; diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 02857651..38195202 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -166,7 +166,8 @@ TEST(WasiNNTest, OpenVINOBackend) { std::initializer_list{ LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, Errno)); - EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Busy)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::RuntimeError)); } // Test: load -- graph id ptr out of bounds. @@ -422,7 +423,8 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_TRUE(HostFuncCompute.run( CallFrame, std::initializer_list{UINT32_C(0)}, Errno)); - EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Busy)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::RuntimeError)); } // Swap back. NNGraphTmp.swap(NNMod->getEnv().NNGraph); diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh index 15a8fcc8..761b6a21 100755 --- a/utils/wasi-nn/install-openvino.sh +++ b/utils/wasi-nn/install-openvino.sh @@ -8,5 +8,5 @@ wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PU apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB echo "deb https://apt.repos.intel.com/openvino/2023 ubuntu20 main" | tee /etc/apt/sources.list.d/intel-openvino-2023.list apt update -apt-get -y install openvino-2023.0.2 +apt-get -y install openvino-2023.2.0 ldconfig From d4834a545fa390c254c9072caea824a4485bdeee Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 31 Jan 2024 15:55:02 +0800 Subject: [PATCH 227/623] [WASI-NN] ggml: bump llama.cpp b2029 Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index ee1e9ad1..6fceddf0 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -47,7 +47,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b1953 + GIT_TAG b2029 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From 78dfc46b8f233b16b47f79cfdd631c4e433d4cdd Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 31 Jan 2024 16:01:57 +0800 Subject: [PATCH 228/623] [WASI-NN] ggml: sync default values with llama.cpp Signed-off-by: hydai --- plugins/wasi_nn/ggml.h | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index d00836c0..a1ae94fe 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -22,22 +22,23 @@ struct Graph { llama_model *LlamaModel = nullptr; std::string ModelFilePath; // Plugin parameters: - bool EnableLog; - bool EnableDebugLog; - bool StreamStdout; + bool EnableLog = false; + bool EnableDebugLog = false; + bool StreamStdout = false; uint64_t NPredict; std::string ReversePrompt; // Model parameters: - int64_t NGPULayers; + int64_t NGPULayers = 0; // Context parameters: uint64_t CtxSize; uint64_t BatchSize; uint64_t Threads; // Sampling parameters: - double Temp; - double TopP; - double RepeatPenalty; - double PresencePenalty; + double Temp = 0.80; + double TopP = 0.95; + double RepeatPenalty = 1.10; + double PresencePenalty = 0.00; + double FrequencyPenalty = 0.00; }; struct Context { From cc6f15e7838957b96805206227f3aab5e6e99084 Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 31 Jan 2024 16:03:54 +0800 Subject: [PATCH 229/623] [WASI-NN] ggml: support frequency penalty (#3180) Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index dee8207c..81f2e824 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -148,6 +148,15 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } + if (Doc.at_key("frequency-penalty").error() == simdjson::SUCCESS) { + auto Err = + Doc["frequency-penalty"].get().get(GraphRef.FrequencyPenalty); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the frequency-penalty option."sv); + return ErrNo::InvalidArgument; + } + } // Check if the model is updated. if (IsModelUpdated && ModelParams.n_gpu_layers != GraphRef.NGPULayers) { @@ -209,6 +218,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.TopP = SamplingDefault.top_p; GraphRef.RepeatPenalty = SamplingDefault.penalty_repeat; GraphRef.PresencePenalty = SamplingDefault.penalty_present; + GraphRef.FrequencyPenalty = SamplingDefault.penalty_freq; // If the graph builder length > 1, the data of builder[1] is the metadata. if (Builders.size() > 1) { @@ -464,6 +474,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { GPTParams.sparams.top_p = GraphRef.TopP; GPTParams.sparams.penalty_repeat = GraphRef.RepeatPenalty; GPTParams.sparams.penalty_present = GraphRef.PresencePenalty; + GPTParams.sparams.penalty_freq = GraphRef.FrequencyPenalty; struct llama_sampling_context *CtxSampling = llama_sampling_init(GPTParams.sparams); std::vector Embd; @@ -673,6 +684,7 @@ Expect computeSingle(WasiNNEnvironment &Env, GPTParams.sparams.top_p = GraphRef.TopP; GPTParams.sparams.penalty_repeat = GraphRef.RepeatPenalty; GPTParams.sparams.penalty_present = GraphRef.PresencePenalty; + GPTParams.sparams.penalty_freq = GraphRef.FrequencyPenalty; CxtRef.LlamaSampling = llama_sampling_init(GPTParams.sparams); llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = GraphRef.CtxSize; From 80ee9cf4e030eebae064e7f3942a2e7af0553cfa Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 1 Feb 2024 06:06:42 +0800 Subject: [PATCH 230/623] [WASI-NN] ggml: add experimental embedding support Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 187 +++++++++++++++++++++++++++++++++------ plugins/wasi_nn/ggml.h | 1 + 2 files changed, 161 insertions(+), 27 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 81f2e824..2d83eafb 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #endif namespace WasmEdge::Host::WASINN::GGML { @@ -59,6 +60,14 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } + if (Doc.at_key("embedding").error() == simdjson::SUCCESS) { + auto Err = Doc["embedding"].get().get(GraphRef.Embedding); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the embedding option."sv); + return ErrNo::InvalidArgument; + } + } if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { auto Err = Doc["n-predict"].get().get(GraphRef.NPredict); if (Err) { @@ -168,22 +177,137 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, Expect buildOutputMetadata(Context &CxtRef, std::string &Metadata) noexcept { - std::string MetadataTemplate = - R"({"input_tokens": %d, "output_tokens": %d, "llama_build_number": %d, "llama_commit": "%s"})"; - - // The 20 bytes are reserved to accommodate two %d placeholders in the - // MetadataTemplate. This allows for a decimal integer value up to a - // 12-digit number of input/output tokens. - // The 3 bytes are reserved to accommodate the %d placeholder for the build - // number. Allows for a decimal integer value up to a 5-digit number. - // The 5 bytes are reserved to accommodate the %s placeholder for the commit - // hash. The commit hash is 7 bytes long by default using `git rev-parse - // --short HEAD`. - char Buffer[MetadataTemplate.size() + 20 + 3 + 5]; - snprintf(Buffer, sizeof(Buffer), MetadataTemplate.c_str(), - CxtRef.LlamaInputs.size(), CxtRef.LlamaOutputTokens.size(), - LLAMA_BUILD_NUMBER, LLAMA_COMMIT); - Metadata = std::string(Buffer); + std::ostringstream OS; + OS << R"({"input_tokens": )" << CxtRef.LlamaInputs.size() + << R"(, "output_tokens": )" << CxtRef.LlamaOutputTokens.size() + << R"(, "llama_build_number": )" << LLAMA_BUILD_NUMBER + << R"(, "llama_commit": ")" << LLAMA_COMMIT << R"("})"; + Metadata = OS.str(); + + return ErrNo::Success; +} + +void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, + const float *Embeddings) noexcept { + // Embedding vector format + // | Content | + // | ----------------------------------- | + // | '{"number_embedding": ' | + // | n_embedding | + // | ', "embedding": ' | + // | '[' | + // | n_embedding*(embedding value %.10f) | + // | (n_embedding-1)*(',') | + // | ']' | + // | '}' | + std::ostringstream OS; + OS.precision(10); + OS << R"({"n_embedding": )" << NEmbd << R"(, "embedding": [)"; + for (int32_t Idx = 0; Idx < NEmbd - 1; Idx++) { + OS << Embeddings[Idx] << ","; + } + OS << Embeddings[NEmbd - 1] << "]}"; + Embedding = OS.str(); +} + +Expect getEmbedding(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: getEmbedding"sv); + } + + if (CxtRef.LlamaInputs.size() == 0) { + spdlog::error("[WASI-NN] GGML backend: Llama input is not set!"sv); + return ErrNo::InvalidArgument; + } + + // Clear the outputs. + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: clear the previous output and tokens"sv); + } + CxtRef.LlamaOutputs.clear(); + CxtRef.LlamaOutputTokens.clear(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: clear the previous output and tokens...Done"sv); + } + + // Main predict loop. + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: handle embedding"sv); + } + // Initialize the llama context. + llama_context_params ContextParams = llama_context_default_params(); + ContextParams.n_ctx = GraphRef.CtxSize; + ContextParams.n_batch = GraphRef.BatchSize; + ContextParams.embedding = GraphRef.Embedding; + auto *LlamaContext = + llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); + + // Get the context size. + const uint64_t NCtx = llama_n_ctx(LlamaContext); + // Minus 4 for the special tokens. (Such as , , ... tokens.) + const uint64_t MaxTokensListSize = NCtx - 4; + // Use the const sequence id here. + const llama_seq_id SequenceId = 0; + + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: the prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, + CxtRef.LlamaInputs.size(), MaxTokensListSize); + } + return ErrNo::PromptTooLong; + } + + int NPast = 0; + while (!CxtRef.LlamaInputs.empty()) { + const uint64_t NTokens = (ContextParams.n_batch > CxtRef.LlamaInputs.size()) + ? CxtRef.LlamaInputs.size() + : ContextParams.n_batch; + auto Status = llama_decode(LlamaContext, + llama_batch_get_one(CxtRef.LlamaInputs.data(), + NTokens, NPast, SequenceId)); + if (Status == 1) { + spdlog::error( + "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); + return ErrNo::RuntimeError; + } + if (Status < 0) { + spdlog::error( + "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); + return ErrNo::RuntimeError; + } + + NPast += NTokens; + CxtRef.LlamaInputs.erase(CxtRef.LlamaInputs.begin(), + CxtRef.LlamaInputs.begin() + NTokens); + } + const int32_t NEmbd = llama_n_embd(GraphRef.LlamaModel); + const auto *Embeddings = llama_get_embeddings(LlamaContext); + + details::buildOutputEmbedding(CxtRef.LlamaOutputs, NEmbd, Embeddings); + + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: enter embedding loop...Done"sv); + } + + if (GraphRef.EnableLog) { + llama_print_timings(LlamaContext); + } + + // We free the contexts here to keep the ggml plugin stateless. + // Users could fully control the contexts by themselves via their prompt. + llama_free(LlamaContext); + + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: compute...Done"sv); + } return ErrNo::Success; } @@ -213,7 +337,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.BatchSize = ContextDefault.n_batch; GraphRef.Threads = ContextDefault.n_threads; // Initialize the sampling parameters. - llama_sampling_params SamplingDefault; + const llama_sampling_params SamplingDefault; GraphRef.Temp = SamplingDefault.temp; GraphRef.TopP = SamplingDefault.top_p; GraphRef.RepeatPenalty = SamplingDefault.penalty_repeat; @@ -222,8 +346,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // If the graph builder length > 1, the data of builder[1] is the metadata. if (Builders.size() > 1) { - std::string Metadata(reinterpret_cast(Builders[1].data()), - Builders[1].size()); + const std::string Metadata(reinterpret_cast(Builders[1].data()), + Builders[1].size()); // Ignore context or model updates when initializing the graph. auto Res = details::parseMetadata(GraphRef, Metadata); if (Res != ErrNo::Success) { @@ -243,7 +367,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, } // Handle the model path. auto Weight = Builders[0]; - std::string BinModel(reinterpret_cast(Weight.data()), Weight.size()); + const std::string BinModel(reinterpret_cast(Weight.data()), + Weight.size()); std::string ModelFilePath; if (BinModel.substr(0, 8) == "preload:") { ModelFilePath = BinModel.substr(8); @@ -254,7 +379,6 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, } // TODO: pass the model directly to ggml // Write ggml model to file. - std::istringstream BinRead(BinModel); ModelFilePath = "ggml-model.bin"sv; std::ofstream TempFile(ModelFilePath); if (!TempFile) { @@ -383,7 +507,9 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, ContextParams.n_batch = GraphRef.BatchSize; ContextParams.n_threads = GraphRef.Threads; ContextParams.n_threads_batch = GraphRef.Threads; - auto LlamaContext = + ContextParams.embedding = GraphRef.Embedding; + + auto *LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: init llama context...Done"sv); @@ -394,8 +520,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, spdlog::info("[WASI-NN][Debug] GGML backend: set the input"sv); } const bool AddBos = llama_should_add_bos_token(GraphRef.LlamaModel); - std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), - Tensor.Tensor.size()); + const std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); CxtRef.LlamaInputs = llama_tokenize(LlamaContext, Prompt, AddBos, true); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: set the input...Done"sv); @@ -448,6 +574,11 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: compute"sv); } + + if (GraphRef.Embedding) { + return details::getEmbedding(Env, ContextId); + } + if (CxtRef.LlamaInputs.size() == 0) { spdlog::error("[WASI-NN] GGML backend: Llama input is not set!"sv); return ErrNo::InvalidArgument; @@ -485,7 +616,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = GraphRef.CtxSize; ContextParams.n_batch = GraphRef.BatchSize; - auto LlamaContext = + auto *LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); // Get the context size. @@ -550,7 +681,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { spdlog::error( "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); return ErrNo::RuntimeError; - } else if (Status < 0) { + } + if (Status < 0) { spdlog::error( "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); return ErrNo::RuntimeError; @@ -759,7 +891,8 @@ Expect computeSingle(WasiNNEnvironment &Env, spdlog::error( "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); return ErrNo::RuntimeError; - } else if (Status < 0) { + } + if (Status < 0) { spdlog::error( "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); return ErrNo::RuntimeError; diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index a1ae94fe..ad3f28d9 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -25,6 +25,7 @@ struct Graph { bool EnableLog = false; bool EnableDebugLog = false; bool StreamStdout = false; + bool Embedding = false; uint64_t NPredict; std::string ReversePrompt; // Model parameters: From ea8c9384866dbe07cce257ea98c98e93ec3021be Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 1 Feb 2024 10:54:54 +0800 Subject: [PATCH 231/623] [WASI-NN] ggml: bump llama.cpp b2037 Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 6fceddf0..5066af48 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -47,7 +47,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2029 + GIT_TAG b2037 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From e521a8f149af7c010e466e5cff6f9f4468b2ed0b Mon Sep 17 00:00:00 2001 From: vincent Date: Sun, 4 Feb 2024 01:02:03 +0800 Subject: [PATCH 232/623] [Misc] Update the printed messages to match the actual installed version of OpenVINO. (#3193) Signed-off-by: vincent --- utils/wasi-nn/install-openvino.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh index 761b6a21..57afdee8 100755 --- a/utils/wasi-nn/install-openvino.sh +++ b/utils/wasi-nn/install-openvino.sh @@ -3,7 +3,7 @@ # SPDX-FileCopyrightText: 2019-2022 Second State INC set -e -echo "Installing OpenVINO with version 2023.0.2" +echo "Installing OpenVINO with version 2023.2.0" wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB echo "deb https://apt.repos.intel.com/openvino/2023 ubuntu20 main" | tee /etc/apt/sources.list.d/intel-openvino-2023.list From db2b6d50120ffcccdecf18abb16a65a46ca3c273 Mon Sep 17 00:00:00 2001 From: "LO, CHIN-HAO" <49036880+hankluo6@users.noreply.github.com> Date: Tue, 13 Feb 2024 00:43:50 -0600 Subject: [PATCH 233/623] [Misc] Fix the linking error with simdjson (#3206) Signed-off-by: hankluo6 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 5066af48..f3963fb3 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -143,7 +143,7 @@ target_include_directories(wasmedgePluginWasiNN if(BACKEND STREQUAL "ggml") target_include_directories(wasmedgePluginWasiNN PUBLIC ${CMAKE_BINARY_DIR}/_deps/llama-src) - target_link_libraries(wasmedgePluginWasiNN PRIVATE common simdjson) + target_link_libraries(wasmedgePluginWasiNN PRIVATE common simdjson::simdjson) if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) add_custom_command( TARGET wasmedgePluginWasiNN From 8de3d2acc70c853f4e9b4a819cf5c1551877ca5d Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 13 Feb 2024 16:04:48 +0800 Subject: [PATCH 234/623] [Zlib] Use the zlib package from system Signed-off-by: hydai --- plugins/wasmedge_zlib/CMakeLists.txt | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/plugins/wasmedge_zlib/CMakeLists.txt b/plugins/wasmedge_zlib/CMakeLists.txt index 30fa7d73..f12889bf 100644 --- a/plugins/wasmedge_zlib/CMakeLists.txt +++ b/plugins/wasmedge_zlib/CMakeLists.txt @@ -1,19 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2022 Second State INC -# Don't reply on System zlib -# find_package(ZLIB REQUIRED) +find_package(ZLIB REQUIRED) set(ZLIB_COMPAT ON) -set(ZLIBNG_ENABLE_TESTS OFF) - -FetchContent_Declare( - zlib - GIT_REPOSITORY "https://github.com/zlib-ng/zlib-ng.git" - GIT_TAG 2.0.7 - GIT_PROGRESS TRUE -) -FetchContent_MakeAvailable(zlib) wasmedge_add_library(wasmedgePluginWasmEdgeZlib SHARED @@ -37,13 +27,13 @@ if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmedgePluginWasmEdgeZlib PRIVATE wasmedgeCAPI - zlib + z ) else() target_link_libraries(wasmedgePluginWasmEdgeZlib PRIVATE wasmedge_shared - zlib + z ) endif() From 8c457797008684e4b97fa12ba38247a711cd603e Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 5 Dec 2023 17:02:45 +0800 Subject: [PATCH 235/623] [AOT] Move LLVM-related code into separate directory * Rename cmake config WASMEDGE_BUILD_AOT_RUNTIME into WASMEDGE_USE_LLVM * Add a warning if old config was supplied * Separate out `CodeGen` function from `Compiler` class * Add `LLVM::Data` class for transfer `LLVM::Context` and `LLVM::Module` * Use `fpexcept.strict` for correct NAN-related behaviors Signed-off-by: Shen-Ta Hsieh --- utils/docker/Dockerfile.alpine-static | 2 +- utils/docker/Dockerfile.debian-static | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/docker/Dockerfile.alpine-static b/utils/docker/Dockerfile.alpine-static index d92dd8da..36e936af 100644 --- a/utils/docker/Dockerfile.alpine-static +++ b/utils/docker/Dockerfile.alpine-static @@ -62,7 +62,7 @@ RUN --mount=type=bind,target=/src,source=. \ # For cross compiling -DCMAKE_TOOLCHAIN_FILE="$(xx-toolchain)" \ -DWASMEDGE_BUILD_PACKAGE="TGZ" \ - -DWASMEDGE_BUILD_AOT_RUNTIME=ON \ + -DWASMEDGE_USE_LLVM=ON \ # Build just what we need -DWASMEDGE_BUILD_STATIC_LIB=ON \ -DWASMEDGE_BUILD_TESTS=OFF \ diff --git a/utils/docker/Dockerfile.debian-static b/utils/docker/Dockerfile.debian-static index 7c43fc78..cea1563b 100644 --- a/utils/docker/Dockerfile.debian-static +++ b/utils/docker/Dockerfile.debian-static @@ -75,14 +75,14 @@ RUN cmake -S /src -B /build -G Ninja \ -DCMAKE_INSTALL_PREFIX=/install \ -DWASMEDGE_BUILD_PACKAGE="TGZ" \ -DWASMEDGE_BUILD_TESTS=OFF \ - -DWASMEDGE_BUILD_AOT_RUNTIME=ON \ -DWASMEDGE_BUILD_SHARED_LIB=OFF \ -DWASMEDGE_BUILD_STATIC_LIB=ON \ -DWASMEDGE_BUILD_TOOLS=OFF \ -DWASMEDGE_BUILD_PLUGINS=OFF \ -DWASMEDGE_BUILD_EXAMPLE=OFF \ -DWASMEDGE_LINK_LLVM_STATIC=ON \ - -DWASMEDGE_LINK_TOOLS_STATIC=ON + -DWASMEDGE_LINK_TOOLS_STATIC=ON \ + -DWASMEDGE_USE_LLVM=ON RUN cmake --build /build -- install RUN cmake --build /build -- package From 0d1557b2e850cd599e442b97249528840c95ecda Mon Sep 17 00:00:00 2001 From: Sarrah Bastawala <84874044+sarrah-basta@users.noreply.github.com> Date: Mon, 19 Feb 2024 00:55:13 +0530 Subject: [PATCH 236/623] [WASI-OCR] WASI-OCR New Plugin for using Tesseract OCR within WasmEdge (#2962) * Final commit for Wasi-OCR plugin integrating OCR capabilities within WasmEdge Introducing **Wasi-OCR**: A WasmEdge Plugin for Optical Character Recognition (OCR) powered by the Tesseract API. This plugin offers seamless integration with Tesseract, a leading open-source OCR engine, enabling the extraction of text from images. Harness the power of OCR in your Document AI applications within WasmEdge effortlessly. --------- Signed-off-by: Sarrah Bastawala <84874044+sarrah-basta@users.noreply.github.com> --- plugins/CMakeLists.txt | 7 ++- plugins/wasi_ocr/CMakeLists.txt | 51 ++++++++++++++++++++++ plugins/wasi_ocr/wasiocrbase.h | 23 ++++++++++ plugins/wasi_ocr/wasiocrenv.cpp | 42 ++++++++++++++++++ plugins/wasi_ocr/wasiocrenv.h | 46 ++++++++++++++++++++ plugins/wasi_ocr/wasiocrfunc.cpp | 69 ++++++++++++++++++++++++++++++ plugins/wasi_ocr/wasiocrfunc.h | 30 +++++++++++++ plugins/wasi_ocr/wasiocrmodule.cpp | 17 ++++++++ plugins/wasi_ocr/wasiocrmodule.h | 23 ++++++++++ 9 files changed, 307 insertions(+), 1 deletion(-) create mode 100644 plugins/wasi_ocr/CMakeLists.txt create mode 100644 plugins/wasi_ocr/wasiocrbase.h create mode 100644 plugins/wasi_ocr/wasiocrenv.cpp create mode 100644 plugins/wasi_ocr/wasiocrenv.h create mode 100644 plugins/wasi_ocr/wasiocrfunc.cpp create mode 100644 plugins/wasi_ocr/wasiocrfunc.h create mode 100644 plugins/wasi_ocr/wasiocrmodule.cpp create mode 100644 plugins/wasi_ocr/wasiocrmodule.h diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index abe3e8e6..91b64909 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -54,6 +54,10 @@ if(WASMEDGE_PLUGIN_WASM_BPF) endif() endif() +if(WASMEDGE_PLUGIN_WASI_OCR) + add_subdirectory(wasi_ocr) +endif() + if(WASMEDGE_PLUGIN_OPENCVMINI) # Only Linux and MacOS support wasmedge_opencvmini now. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") @@ -73,4 +77,5 @@ endif() if(WASMEDGE_PLUGIN_ZLIB) add_subdirectory(wasmedge_zlib) -endif() + +endif() \ No newline at end of file diff --git a/plugins/wasi_ocr/CMakeLists.txt b/plugins/wasi_ocr/CMakeLists.txt new file mode 100644 index 00000000..b2eef6d3 --- /dev/null +++ b/plugins/wasi_ocr/CMakeLists.txt @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +add_library(wasmedgePluginWasiOCR + SHARED + wasiocrenv.cpp + wasiocrfunc.cpp + wasiocrmodule.cpp +) + +target_compile_options(wasmedgePluginWasiOCR + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasiOCR + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasiOCR + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasiOCR + PRIVATE + wasmedge_shared + ) +endif() + +install(TARGETS wasmedgePluginWasiOCR DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) + +message(STATUS "WASI-OCR: Build Tesseract backend for WASI-OCR") +find_package(PkgConfig REQUIRED) +pkg_search_module(TESSERACT REQUIRED tesseract) +pkg_search_module(LEPTONICA REQUIRED lept) + +target_include_directories(wasmedgePluginWasiOCR + PUBLIC + ${TESSERACT_INCLUDE_DIRS} + ${LEPTONICA_INCLUDE_DIRS} +) + +target_link_libraries(wasmedgePluginWasiOCR + PUBLIC + ${TESSERACT_LIBRARIES} + ${LEPTONICA_LIBRARIES} +) diff --git a/plugins/wasi_ocr/wasiocrbase.h b/plugins/wasi_ocr/wasiocrbase.h new file mode 100644 index 00000000..38fd9d72 --- /dev/null +++ b/plugins/wasi_ocr/wasiocrbase.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#pragma once + +#include "common/errcode.h" +#include "runtime/hostfunc.h" +#include "wasiocrenv.h" + +namespace WasmEdge { +namespace Host { + +template class WasiOCR : public Runtime::HostFunction { +public: + WasiOCR(WASIOCR::WasiOCREnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WASIOCR::WasiOCREnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_ocr/wasiocrenv.cpp b/plugins/wasi_ocr/wasiocrenv.cpp new file mode 100644 index 00000000..3df9d69c --- /dev/null +++ b/plugins/wasi_ocr/wasiocrenv.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#include "wasiocrenv.h" +#include "wasiocrmodule.h" + +namespace WasmEdge { +namespace Host { + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasiOCRModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasi_ocr", + .Description = "A WasmEdge Plugin for Optical Character Recognition (OCR) " + "powered by the Tesseract API.", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 10, 1, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasi_ocr", + .Description = + "A WasmEdge Plugin for Optical Character Recognition (OCR) " + "powered by the Tesseract API.", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +} // namespace + +Plugin::PluginRegister WASIOCR::WasiOCREnvironment::Register(&Descriptor); + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_ocr/wasiocrenv.h b/plugins/wasi_ocr/wasiocrenv.h new file mode 100644 index 00000000..9e90a5ef --- /dev/null +++ b/plugins/wasi_ocr/wasiocrenv.h @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#pragma once + +#include "common/log.h" +#include "plugin/plugin.h" +#include + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WASIOCR { + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + MissingMemory = 2, // Caller module is missing a memory export. + Busy = 3 // Device or resource busy. +}; + +class WasiOCREnvironment { +public: + WasiOCREnvironment() noexcept { + // check Tesseract API by initializing tesseract-ocr with English, without + // specifying tessdata path + if (TesseractApi->Init(NULL, "eng")) { + spdlog::error("[WASI-OCR] Error occurred when initializing tesseract."); + } + } + ~WasiOCREnvironment() noexcept { + if (TesseractApi) { + TesseractApi->End(); + ; + } + } + tesseract::TessBaseAPI *TesseractApi = new tesseract::TessBaseAPI(); + + static Plugin::PluginRegister Register; +}; + +} // namespace WASIOCR +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_ocr/wasiocrfunc.cpp b/plugins/wasi_ocr/wasiocrfunc.cpp new file mode 100644 index 00000000..a1d17d81 --- /dev/null +++ b/plugins/wasi_ocr/wasiocrfunc.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#include "wasiocrfunc.h" +#include "common/log.h" + +#include +#include + +namespace WasmEdge { +namespace Host { + +Expect +WasiOCRNumOfExtractions::body(const Runtime::CallingFrame &Frame, + uint32_t ImagePathPtr, uint32_t ImagePathLen) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + auto ImagePtr = MemInst->getSpan(ImagePathPtr, ImagePathLen); + if (unlikely(ImagePtr.size() != ImagePathLen)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + Pix *image = pixRead(ImagePtr.data()); + + Env.TesseractApi->SetImage(image); + Env.TesseractApi->Recognize(0); + + tesseract::PageIteratorLevel level = tesseract::RIL_WORD; + const char *outText = Env.TesseractApi->GetTSVText(level); + + uint32_t length = strlen(outText); + pixDestroy(&image); + return static_cast(length); +} + +Expect WasiOCRGetOutput::body(const Runtime::CallingFrame &Frame, + uint32_t OutBufferPtr [[maybe_unused]], + uint32_t OutBufferMaxSize + [[maybe_unused]]) { + // Check memory instance from module. + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + // Check the return value: OutBufferPtr should be valid. + auto Buf = MemInst->getSpan(OutBufferPtr, OutBufferMaxSize); + if (unlikely(Buf.empty())) { + spdlog::error( + "[WASI-OCR] Failed when accessing the return OutBufferPtr memory."); + return static_cast(WASIOCR::ErrNo::InvalidArgument); + } + + tesseract::PageIteratorLevel level = tesseract::RIL_WORD; + const char *outText = Env.TesseractApi->GetTSVText(level); + std::strcpy(Buf.data(), outText); + + // remaining free and deltee memory stuff + Env.TesseractApi->End(); + delete[] outText; // USE WHEN USING TESS API + + return static_cast(WASIOCR::ErrNo::Success); + // return outText; +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_ocr/wasiocrfunc.h b/plugins/wasi_ocr/wasiocrfunc.h new file mode 100644 index 00000000..ad98639b --- /dev/null +++ b/plugins/wasi_ocr/wasiocrfunc.h @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#pragma once + +#include "runtime/callingframe.h" +#include "wasiocrbase.h" + +#include + +namespace WasmEdge { +namespace Host { + +class WasiOCRNumOfExtractions : public WasiOCR { +public: + WasiOCRNumOfExtractions(WASIOCR::WasiOCREnvironment &HostEnv) + : WasiOCR(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t ImagePathPtr, + uint32_t ImagePathLen); +}; + +class WasiOCRGetOutput : public WasiOCR { +public: + WasiOCRGetOutput(WASIOCR::WasiOCREnvironment &HostEnv) : WasiOCR(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_ocr/wasiocrmodule.cpp b/plugins/wasi_ocr/wasiocrmodule.cpp new file mode 100644 index 00000000..ea16f215 --- /dev/null +++ b/plugins/wasi_ocr/wasiocrmodule.cpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#include "wasiocrmodule.h" +#include "wasiocrfunc.h" + +namespace WasmEdge { +namespace Host { + +WasiOCRModule::WasiOCRModule() : ModuleInstance("wasi_ephemeral_ocr") { + addHostFunc("num_of_extractions", + std::make_unique(Env)); + addHostFunc("get_output", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_ocr/wasiocrmodule.h b/plugins/wasi_ocr/wasiocrmodule.h new file mode 100644 index 00000000..cbab6d5c --- /dev/null +++ b/plugins/wasi_ocr/wasiocrmodule.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#pragma once + +#include "runtime/instance/module.h" +#include "wasiocrenv.h" + +namespace WasmEdge { +namespace Host { + +class WasiOCRModule : public Runtime::Instance::ModuleInstance { +public: + WasiOCRModule(); + + WASIOCR::WasiOCREnvironment &getEnv() { return Env; } + +private: + WASIOCR::WasiOCREnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge From d7377ad75f2cf4afe86f7e11d923dcae3062d3a1 Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 21 Feb 2024 15:39:35 +0800 Subject: [PATCH 237/623] [WASI-NN] ggml: fix incorrect inputs tokens number after calling embedding (#3228) Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 3 ++- plugins/wasi_nn/ggml.h | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 2d83eafb..fd50b702 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -178,7 +178,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, Expect buildOutputMetadata(Context &CxtRef, std::string &Metadata) noexcept { std::ostringstream OS; - OS << R"({"input_tokens": )" << CxtRef.LlamaInputs.size() + OS << R"({"input_tokens": )" << CxtRef.LlamaNInputs << R"(, "output_tokens": )" << CxtRef.LlamaOutputTokens.size() << R"(, "llama_build_number": )" << LLAMA_BUILD_NUMBER << R"(, "llama_commit": ")" << LLAMA_COMMIT << R"("})"; @@ -523,6 +523,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, const std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); CxtRef.LlamaInputs = llama_tokenize(LlamaContext, Prompt, AddBos, true); + CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: set the input...Done"sv); } diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index ad3f28d9..4df0b76d 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -53,6 +53,7 @@ struct Context { llama_context *LlamaContext = nullptr; struct llama_sampling_context *LlamaSampling = nullptr; std::vector LlamaEmbd; + uint64_t LlamaNInputs; uint64_t LlamaNPast; uint64_t LlamaNConsumed; }; From e7e6d53200804ecddc723b6556c68f9a8605e3aa Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 19 Jan 2024 14:38:44 +0800 Subject: [PATCH 238/623] [WASI-NN] ggml: add llava support Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 29 +- plugins/wasi_nn/ggml.cpp | 590 ++++++++++++++++++++------------- plugins/wasi_nn/ggml.h | 12 +- 3 files changed, 386 insertions(+), 245 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index f3963fb3..3275f86c 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -12,6 +12,8 @@ if(BACKEND STREQUAL "ggml") if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_CUBLAS) message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_CUBLAS") set(LLAMA_CUBLAS ON) + # We need to set GGML_USE_CUBLAS for clip from llava. + add_compile_definitions(GGML_USE_CUBLAS) # If CUBLAS is ON, then OpenBLAS should be OFF. set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS OFF) else() @@ -142,8 +144,31 @@ target_include_directories(wasmedgePluginWasiNN ) if(BACKEND STREQUAL "ggml") - target_include_directories(wasmedgePluginWasiNN PUBLIC ${CMAKE_BINARY_DIR}/_deps/llama-src) - target_link_libraries(wasmedgePluginWasiNN PRIVATE common simdjson::simdjson) + # Setup llava from llama.cpp + wasmedge_add_library(llava OBJECT + ${CMAKE_BINARY_DIR}/_deps/llama-src/examples/llava/clip.cpp + ${CMAKE_BINARY_DIR}/_deps/llama-src/examples/llava/llava.cpp + ) + target_link_libraries(llava PRIVATE ggml llama) + target_include_directories(llava PUBLIC + ${CMAKE_BINARY_DIR}/_deps/llama-src + ${CMAKE_BINARY_DIR}/_deps/llama-src/common + ${CMAKE_BINARY_DIR}/_deps/llama-src/examples/llava + ) + target_compile_options(llava PRIVATE + -Wno-error=unused-function + -Wno-error=unused-variable + -Wno-unused-function + -Wno-unused-variable + ) + wasmedge_setup_target(llava) + + # Setup include and link from llama.cpp + target_include_directories(wasmedgePluginWasiNN PUBLIC + ${CMAKE_BINARY_DIR}/_deps/llama-src + ${CMAKE_BINARY_DIR}/_deps/llama-src/examples/llava + ) + target_link_libraries(wasmedgePluginWasiNN PRIVATE common simdjson::simdjson llava) if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) add_custom_command( TARGET wasmedgePluginWasiNN diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index fd50b702..f7b6a68f 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -6,9 +6,11 @@ #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML #include "simdjson.h" +#include #include #include #include +#include #include #endif @@ -27,8 +29,29 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } // Get metadata from the json. - // Need to update Model: - // * n_gpu_layers + + // Currently supported metadata: + // Plugin parameters (used by this plugin): + // enable-log: bool + // enable-debug-log: bool + // stream-stdout: bool + // embedding: bool + // n-predict: uint64_t + // reverse-prompt: string + // mmproj: string + // image: string + // Model parameters (need to reload the model if updated): + // n-gpu-layers: int64_t + // Context parameters (used by the llama context): + // ctx-size: uint64_t + // batch-size: uint64_t + // threads: uint64_t + // Sampling parameters (used by the llama sampling context). + // temp: double + // top-p: double + // repeat-penalty: double + // presence-penalty: double + // frequency-penalty: double // Get the current llama parameters. llama_model_params ModelParams = llama_model_default_params(); @@ -86,6 +109,26 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } GraphRef.ReversePrompt = ReversePrompt; } + if (Doc.at_key("mmproj").error() == simdjson::SUCCESS) { + std::string_view MMProjModelPath; + auto Err = Doc["mmproj"].get().get(MMProjModelPath); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the mmproj option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.MMProjModelPath = MMProjModelPath; + } + if (Doc.at_key("image").error() == simdjson::SUCCESS) { + std::string_view ImagePath; + auto Err = Doc["image"].get().get(ImagePath); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the image option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.ImagePath = ImagePath; + } // The model parameters. if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { @@ -122,6 +165,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } + // The sampling parameters. if (Doc.at_key("temp").error() == simdjson::SUCCESS) { auto Err = Doc["temp"].get().get(GraphRef.Temp); @@ -175,6 +219,23 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::Success; } +Expect setupGPTParam(Graph &GraphRef, gpt_params &GPTParams) { + GPTParams.sparams.temp = GraphRef.Temp; + GPTParams.sparams.top_p = GraphRef.TopP; + GPTParams.sparams.penalty_repeat = GraphRef.RepeatPenalty; + GPTParams.sparams.penalty_present = GraphRef.PresencePenalty; + return ErrNo::Success; +} + +Expect setupContextParam(Graph &GraphRef, + llama_context_params &ContextParams) { + ContextParams.n_ctx = GraphRef.CtxSize; + ContextParams.n_batch = GraphRef.BatchSize; + ContextParams.n_threads = GraphRef.Threads; + ContextParams.n_threads_batch = GraphRef.Threads; + return ErrNo::Success; +} + Expect buildOutputMetadata(Context &CxtRef, std::string &Metadata) noexcept { std::ostringstream OS; @@ -312,6 +373,48 @@ Expect getEmbedding(WasiNNEnvironment &Env, return ErrNo::Success; } +ErrNo EvaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, + std::vector Tokens, int &NPast) noexcept { + uint32_t NCtx = llama_n_ctx(LlamaContext); + + // End the inference if the context is full. + if (NPast + static_cast(Tokens.size()) > NCtx) { + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: the context if full ({} / {} tokens). Please increase your context size."sv, + NPast + static_cast(Tokens.size()), NCtx); + } + return ErrNo::ContextFull; + } + + for (int I = 0; I < static_cast(Tokens.size()); + I += GraphRef.BatchSize) { + int NEval = static_cast(Tokens.size()) - I; + if (NEval > static_cast(GraphRef.BatchSize)) { + NEval = GraphRef.BatchSize; + } + // llama_batch_get_one(*token, n_tokens, position, sequence_id) + // This will return batch for single sequence of tokens starting at + // position. + const llama_seq_id SequenceId = 0; + auto Status = + llama_decode(LlamaContext, + llama_batch_get_one(&Tokens[I], NEval, NPast, SequenceId)); + if (Status == 1) { + spdlog::error( + "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); + return ErrNo::RuntimeError; + } else if (Status < 0) { + spdlog::error( + "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); + return ErrNo::RuntimeError; + } + NPast += NEval; + } + + return ErrNo::Success; +} + } // namespace details Expect load(WasiNNEnvironment &Env, Span> Builders, @@ -323,9 +426,12 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize the plugin parameters. auto ContextDefault = llama_context_default_params(); GraphRef.EnableLog = false; + GraphRef.EnableDebugLog = false; GraphRef.StreamStdout = false; - GraphRef.ReversePrompt = ""sv; GraphRef.NPredict = ContextDefault.n_ctx; + GraphRef.ReversePrompt = ""sv; + GraphRef.MMProjModelPath = ""sv; + GraphRef.ImagePath = ""sv; // Initialize the model parameters. GraphRef.NGPULayers = 0; #ifdef __APPLE__ @@ -503,13 +609,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, spdlog::info("[WASI-NN][Debug] GGML backend: init llama context"sv); } llama_context_params ContextParams = llama_context_default_params(); - ContextParams.n_ctx = GraphRef.CtxSize; - ContextParams.n_batch = GraphRef.BatchSize; - ContextParams.n_threads = GraphRef.Threads; - ContextParams.n_threads_batch = GraphRef.Threads; - ContextParams.embedding = GraphRef.Embedding; - - auto *LlamaContext = + details::setupContextParam(GraphRef, ContextParams); + auto LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: init llama context...Done"sv); @@ -522,8 +623,64 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, const bool AddBos = llama_should_add_bos_token(GraphRef.LlamaModel); const std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); - CxtRef.LlamaInputs = llama_tokenize(LlamaContext, Prompt, AddBos, true); - CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); + if (GraphRef.MMProjModelPath == ""sv) { + // Text only prompt. + CxtRef.LlamaInputs = llama_tokenize(LlamaContext, Prompt, AddBos, true); + CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); + } else { + // Handle llava format prompt. + + // Show some warnings. + if (GraphRef.EnableLog) { + if (GraphRef.ImagePath == ""sv) { + spdlog::info( + "[WASI-NN] GGML backend: Image path is not set, will process as text-only prompt"sv); + } + if (GraphRef.CtxSize < 4096) { + spdlog::info( + "[WASI-NN] GGML backend: Context size is {}, " + "we recommand context size >= 2048 when using llava-v1.5 " + "and context size >= 4096 when using llava-v1.6 for better results."sv, + GraphRef.CtxSize); + } + } + + // We split prompt by as placeholder and save the position. + const std::string_view PromptPlaceholder = ""; + auto PlaceholderPosition = Prompt.find(PromptPlaceholder); + if (PlaceholderPosition == std::string::npos) { + spdlog::error( + "[WASI-NN] GGML backend: Error: unable to find the placeholder in the llava prompt."sv); + return ErrNo::InvalidArgument; + } + std::string PromptBeforeImage = Prompt.substr(0, PlaceholderPosition); + std::string PromptAfterImage = + Prompt.substr(PlaceholderPosition + PromptPlaceholder.length()); + std::vector EmbdInputBeforeImage = + llama_tokenize(LlamaContext, PromptBeforeImage, AddBos, true); + std::vector EmbdInputAfterImage = + llama_tokenize(LlamaContext, PromptAfterImage, false, true); + CxtRef.LlavaImagePosition = EmbdInputBeforeImage.size(); + CxtRef.LlamaInputs.reserve(EmbdInputBeforeImage.size() + + EmbdInputAfterImage.size()); + CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.end(), + EmbdInputBeforeImage.begin(), + EmbdInputBeforeImage.end()); + CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.end(), + EmbdInputAfterImage.begin(), + EmbdInputAfterImage.end()); + + // Load image for llava. + int LlavaVerbosity = 0; + if (GraphRef.EnableLog) { + LlavaVerbosity = 1; + } + auto ClipContext = + clip_model_load(GraphRef.MMProjModelPath.c_str(), LlavaVerbosity); + CxtRef.LlavaImageEmbd = llava_image_embed_make_with_filename( + ClipContext, GraphRef.Threads, GraphRef.ImagePath.c_str()); + clip_free(ClipContext); + } if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: set the input...Done"sv); } @@ -597,35 +754,22 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { "[WASI-NN][Debug] GGML backend: clear the previous output and tokens...Done"sv); } - // Main predict loop. - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: enter main predict loop"sv); - } + // Initialize the llama context. gpt_params GPTParams; - GPTParams.sparams.temp = GraphRef.Temp; - GPTParams.sparams.top_p = GraphRef.TopP; - GPTParams.sparams.penalty_repeat = GraphRef.RepeatPenalty; - GPTParams.sparams.penalty_present = GraphRef.PresencePenalty; - GPTParams.sparams.penalty_freq = GraphRef.FrequencyPenalty; + llama_context_params ContextParams = llama_context_default_params(); + details::setupGPTParam(GraphRef, GPTParams); + details::setupContextParam(GraphRef, ContextParams); + auto LlamaContext = + llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); struct llama_sampling_context *CtxSampling = llama_sampling_init(GPTParams.sparams); - std::vector Embd; - uint64_t NPast = 0; - uint64_t NConsumed = 0; + // Prepare variables; + int32_t NPast = 0; int32_t NRemain = GraphRef.NPredict; - // Initialize the llama context. - llama_context_params ContextParams = llama_context_default_params(); - ContextParams.n_ctx = GraphRef.CtxSize; - ContextParams.n_batch = GraphRef.BatchSize; - auto *LlamaContext = - llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); - // Get the context size. const uint64_t NCtx = llama_n_ctx(LlamaContext); // Minus 4 for the special tokens. (Such as , , ... tokens.) const uint64_t MaxTokensListSize = NCtx - 4; - // Use the const sequence id here. - const llama_seq_id SequenceId = 0; // Return value. auto ReturnCode = ErrNo::Success; @@ -639,109 +783,91 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { return ErrNo::PromptTooLong; } + // Evaluate input tokens. + if (CxtRef.LlavaImageEmbd == nullptr) { + // Text only prompt. + ReturnCode = details::EvaluateTokens(GraphRef, LlamaContext, + CxtRef.LlamaInputs, NPast); + if (ReturnCode != ErrNo::Success) { + spdlog::error( + "[WASI-NN] GGML backend: failed to evaluate input tokens."sv); + return ReturnCode; + } + } else { + // Llava format prompt with image data. + std::vector EmbdInputBeforeImage( + CxtRef.LlamaInputs.begin(), + CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition); + std::vector EmbdInputAfterImage(CxtRef.LlamaInputs.begin() + + CxtRef.LlavaImagePosition, + CxtRef.LlamaInputs.end()); + ReturnCode = details::EvaluateTokens(GraphRef, LlamaContext, + EmbdInputBeforeImage, NPast); + if (ReturnCode != ErrNo::Success) { + spdlog::error( + "[WASI-NN] GGML backend: failed to evaluate input tokens before image."sv); + return ReturnCode; + } + bool EvalImageStatus = llava_eval_image_embed( + LlamaContext, CxtRef.LlavaImageEmbd, GraphRef.BatchSize, &NPast); + if (!EvalImageStatus) { + spdlog::error( + "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); + return ErrNo::RuntimeError; + } + ReturnCode = details::EvaluateTokens(GraphRef, LlamaContext, + EmbdInputAfterImage, NPast); + if (ReturnCode != ErrNo::Success) { + spdlog::error( + "[WASI-NN] GGML backend: failed to evaluate input tokens after image."sv); + return ReturnCode; + } + } + // Main predict loop. + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: enter main predict loop"sv); + } while (NRemain > 0) { - // Preidct - if (!Embd.empty()) { - // Input too long. - if (static_cast(Embd.size()) > MaxTokensListSize) { - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: the prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, - Embd.size(), MaxTokensListSize); - } - ReturnCode = ErrNo::PromptTooLong; - break; - } - - // We do not swap context here. End the inference if the context is full. - if (NPast + static_cast(Embd.size()) > NCtx) { - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: the context if full ({} / {} tokens). Please increase your ctx-size."sv, - NPast + static_cast(Embd.size()), NCtx); - } - ReturnCode = ErrNo::ContextFull; - break; - } - - // Evaluate tokens in batches. - for (int I = 0; I < static_cast(Embd.size()); - I += GraphRef.BatchSize) { - uint64_t NEval = static_cast(Embd.size()) - I; - if (NEval > static_cast(GraphRef.BatchSize)) { - NEval = GraphRef.BatchSize; - } - // llama_batch_get_one(*token, n_tokens, position, sequence_id) - // This will return batch for single sequence of tokens starting at - // position. - auto Status = - llama_decode(LlamaContext, llama_batch_get_one(&Embd[I], NEval, - NPast, SequenceId)); - if (Status == 1) { - spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); - return ErrNo::RuntimeError; - } - if (Status < 0) { - spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); - return ErrNo::RuntimeError; - } - - NPast += NEval; - } + const llama_token Id = + llama_sampling_sample(CtxSampling, LlamaContext, nullptr); + llama_sampling_accept(CtxSampling, LlamaContext, Id, true); + --NRemain; + + // Save the output token. + CxtRef.LlamaOutputTokens.emplace_back(Id); + CxtRef.LlamaOutputs += llama_token_to_piece(LlamaContext, Id); + // When setting StreamStdout, we print the output to stdout. + if (GraphRef.StreamStdout) { + std::cout << llama_token_to_piece(LlamaContext, Id) << std::flush; } - - Embd.clear(); - - if (static_cast(CxtRef.LlamaInputs.size()) <= NConsumed) { - const llama_token Id = - llama_sampling_sample(CtxSampling, LlamaContext, nullptr); - llama_sampling_accept(CtxSampling, LlamaContext, Id, true); - Embd.emplace_back(Id); - --NRemain; - // Save the output token. - CxtRef.LlamaOutputTokens.emplace_back(Id); - CxtRef.LlamaOutputs += llama_token_to_piece(LlamaContext, Id); - // When setting StreamStdout, we print the output to stdout. - if (GraphRef.StreamStdout) { - std::cout << llama_token_to_piece(LlamaContext, Id) << std::flush; + // Break if reverse prompt is found. + if (!GraphRef.ReversePrompt.empty() && + CxtRef.LlamaOutputs.find(GraphRef.ReversePrompt) != std::string::npos) { + if (GraphRef.EnableLog) { + spdlog::info("[WASI-NN] GGML backend: reverse prompt found"sv); } - // Break if reverse prompt is found. - if (!GraphRef.ReversePrompt.empty() && - CxtRef.LlamaOutputs.find(GraphRef.ReversePrompt) != - std::string::npos) { - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: reverse prompt found"sv); - } - break; - } - // Deal with end of text token. - if (llama_sampling_last(CtxSampling) == - llama_token_eos(GraphRef.LlamaModel)) { - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); - } - break; - } - } else { - while (static_cast(CxtRef.LlamaInputs.size()) > NConsumed) { - Embd.push_back(CxtRef.LlamaInputs[NConsumed]); - // Push the prompt in the sampling context. - llama_sampling_accept(CtxSampling, LlamaContext, - CxtRef.LlamaInputs[NConsumed], false); - ++NConsumed; - if (Embd.size() >= GraphRef.BatchSize) { - break; - } + break; + } + // Deal with end of text token. + if (llama_sampling_last(CtxSampling) == + llama_token_eos(GraphRef.LlamaModel)) { + if (GraphRef.EnableLog) { + spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); } + break; + } + // Evaluate the output token. + ReturnCode = details::EvaluateTokens(GraphRef, LlamaContext, {Id}, NPast); + if (ReturnCode != ErrNo::Success) { + break; } } if (GraphRef.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] GGML backend: enter main predict loop...Done"sv); } + // End of main predict loop. if (GraphRef.EnableLog) { llama_print_timings(LlamaContext); @@ -751,6 +877,10 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Users could fully control the contexts by themselves via their prompt. llama_sampling_free(CtxSampling); llama_free(LlamaContext); + if (CxtRef.LlavaImageEmbd != nullptr) { + llava_image_embed_free(CxtRef.LlavaImageEmbd); + CxtRef.LlavaImageEmbd = nullptr; + } if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: compute...Done"sv); @@ -813,133 +943,113 @@ Expect computeSingle(WasiNNEnvironment &Env, // Initialize the llama context. gpt_params GPTParams; - GPTParams.sparams.temp = GraphRef.Temp; - GPTParams.sparams.top_p = GraphRef.TopP; - GPTParams.sparams.penalty_repeat = GraphRef.RepeatPenalty; - GPTParams.sparams.penalty_present = GraphRef.PresencePenalty; - GPTParams.sparams.penalty_freq = GraphRef.FrequencyPenalty; - CxtRef.LlamaSampling = llama_sampling_init(GPTParams.sparams); llama_context_params ContextParams = llama_context_default_params(); - ContextParams.n_ctx = GraphRef.CtxSize; - ContextParams.n_batch = GraphRef.BatchSize; - ContextParams.n_threads = GraphRef.Threads; - ContextParams.n_threads_batch = GraphRef.Threads; + details::setupGPTParam(GraphRef, GPTParams); + details::setupContextParam(GraphRef, ContextParams); CxtRef.LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); - CxtRef.LlamaEmbd.clear(); + CxtRef.LlamaSampling = llama_sampling_init(GPTParams.sparams); CxtRef.LlamaNPast = 0; - CxtRef.LlamaNConsumed = 0; - } - - // Get the context size. - const uint64_t NCtx = llama_n_ctx(CxtRef.LlamaContext); - // Minus 4 for the special tokens. (Such as , , ... tokens.) - const uint64_t MaxTokensListSize = NCtx - 4; - // Use the const sequence id here. - const llama_seq_id SequenceId = 0; - // Check if the input is too long. - if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: the prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, - CxtRef.LlamaInputs.size(), MaxTokensListSize); + // Get the context size. + const uint64_t NCtx = llama_n_ctx(CxtRef.LlamaContext); + // Minus 4 for the special tokens. (Such as , , ... tokens.) + const uint64_t MaxTokensListSize = NCtx - 4; + // Return value. + auto ReturnCode = ErrNo::Success; + + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: the prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, + CxtRef.LlamaInputs.size(), MaxTokensListSize); + } + return ErrNo::PromptTooLong; } - return ErrNo::PromptTooLong; - } - // Main predict loop. - while (true) { - if (!CxtRef.LlamaEmbd.empty()) { - // Input too long. - if (static_cast(CxtRef.LlamaEmbd.size()) > MaxTokensListSize) { - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: the prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, - CxtRef.LlamaEmbd.size(), MaxTokensListSize); - } - return ErrNo::PromptTooLong; + // Evaluate input tokens. + if (CxtRef.LlavaImageEmbd == nullptr) { + // Text only prompt. + ReturnCode = details::EvaluateTokens( + GraphRef, CxtRef.LlamaContext, CxtRef.LlamaInputs, CxtRef.LlamaNPast); + if (ReturnCode != ErrNo::Success) { + spdlog::error( + "[WASI-NN] GGML backend: failed to evaluate input tokens."sv); + return ReturnCode; } - - // We do not swap context here. End the inference if the context is full. - if (CxtRef.LlamaNPast + static_cast(CxtRef.LlamaEmbd.size()) > - NCtx) { - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: the context if full ({} / {} tokens). Please increase your ctx-size."sv, - CxtRef.LlamaNPast + - static_cast(CxtRef.LlamaEmbd.size()), - NCtx); - } - return ErrNo::ContextFull; + } else { + // Llava format prompt with image data. + std::vector EmbdInputBeforeImage( + CxtRef.LlamaInputs.begin(), + CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition); + std::vector EmbdInputAfterImage( + CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition, + CxtRef.LlamaInputs.end()); + ReturnCode = + details::EvaluateTokens(GraphRef, CxtRef.LlamaContext, + EmbdInputBeforeImage, CxtRef.LlamaNPast); + if (ReturnCode != ErrNo::Success) { + spdlog::error( + "[WASI-NN] GGML backend: failed to evaluate input tokens before image."sv); + return ReturnCode; } - - // Evaluate tokens in batches. - for (uint64_t I = 0; I < static_cast(CxtRef.LlamaEmbd.size()); - I += GraphRef.BatchSize) { - uint64_t NEval = static_cast(CxtRef.LlamaEmbd.size()) - I; - if (NEval > static_cast(GraphRef.BatchSize)) { - NEval = GraphRef.BatchSize; - } - // llama_batch_get_one(*token, n_tokens, position, sequence_id) - // This will return batch for single sequence of tokens starting at - // position. - auto Status = - llama_decode(CxtRef.LlamaContext, - llama_batch_get_one(&CxtRef.LlamaEmbd[I], NEval, - CxtRef.LlamaNPast, SequenceId)); - if (Status == 1) { - spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); - return ErrNo::RuntimeError; - } - if (Status < 0) { - spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); - return ErrNo::RuntimeError; - } - - CxtRef.LlamaNPast += NEval; + bool EvalImageStatus = + llava_eval_image_embed(CxtRef.LlamaContext, CxtRef.LlavaImageEmbd, + GraphRef.BatchSize, &CxtRef.LlamaNPast); + if (!EvalImageStatus) { + spdlog::error( + "[WASI-NN] GGML backend: failed to evaluate embed image tokens ."sv); + return ErrNo::RuntimeError; + } + ReturnCode = + details::EvaluateTokens(GraphRef, CxtRef.LlamaContext, + EmbdInputAfterImage, CxtRef.LlamaNPast); + if (ReturnCode != ErrNo::Success) { + spdlog::error( + "[WASI-NN] GGML backend: failed to evaluate input tokens after image."sv); + return ReturnCode; } } + } - CxtRef.LlamaEmbd.clear(); - - if (static_cast(CxtRef.LlamaInputs.size()) <= - CxtRef.LlamaNConsumed) { - const llama_token Id = llama_sampling_sample( - CxtRef.LlamaSampling, CxtRef.LlamaContext, nullptr); - llama_sampling_accept(CxtRef.LlamaSampling, CxtRef.LlamaContext, Id, - true); - CxtRef.LlamaEmbd.emplace_back(Id); - // Save the output token. - CxtRef.LlamaOutputTokens.emplace_back(Id); - CxtRef.LlamaOutputs += llama_token_to_piece(CxtRef.LlamaContext, Id); - // Deal with end of text token. - if (llama_sampling_last(CxtRef.LlamaSampling) == - llama_token_eos(GraphRef.LlamaModel)) { - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); - } - return ErrNo::EndOfSequence; - } - return ErrNo::Success; - } else { - while (static_cast(CxtRef.LlamaInputs.size()) > - CxtRef.LlamaNConsumed) { - CxtRef.LlamaEmbd.push_back(CxtRef.LlamaInputs[CxtRef.LlamaNConsumed]); - // Push the prompt in the sampling context. - llama_sampling_accept(CxtRef.LlamaSampling, CxtRef.LlamaContext, - CxtRef.LlamaInputs[CxtRef.LlamaNConsumed], false); - ++CxtRef.LlamaNConsumed; - if (CxtRef.LlamaEmbd.size() >= GraphRef.BatchSize) { - break; - } - } + // Main predict process. + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: enter main predict process"sv); + } + auto ReturnCode = ErrNo::Success; + const llama_token Id = + llama_sampling_sample(CxtRef.LlamaSampling, CxtRef.LlamaContext, nullptr); + llama_sampling_accept(CxtRef.LlamaSampling, CxtRef.LlamaContext, Id, true); + + // Save the output token. + // In single token mode, we do not handle StreamStdout and ReversePrompt. + CxtRef.LlamaOutputTokens.emplace_back(Id); + CxtRef.LlamaOutputs += llama_token_to_piece(CxtRef.LlamaContext, Id); + // Deal with end of text token. + if (llama_sampling_last(CxtRef.LlamaSampling) == + llama_token_eos(GraphRef.LlamaModel)) { + ReturnCode = ErrNo::EndOfSequence; + if (GraphRef.EnableLog) { + spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); } } + // Evaluate the output token if not EOS. + if (ReturnCode != ErrNo::EndOfSequence) { + ReturnCode = details::EvaluateTokens(GraphRef, CxtRef.LlamaContext, {Id}, + CxtRef.LlamaNPast); + } + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: enter main predict process...Done"sv); + } + // End of main predict process. - return ErrNo::Success; + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: computeSingleToken...Done"sv); + } + + return ReturnCode; } Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { @@ -972,15 +1082,17 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { llama_free(CxtRef.LlamaContext); CxtRef.LlamaSampling = nullptr; CxtRef.LlamaContext = nullptr; + if (CxtRef.LlavaImageEmbd != nullptr) { + llava_image_embed_free(CxtRef.LlavaImageEmbd); + CxtRef.LlavaImageEmbd = nullptr; + } if (GraphRef.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] GGML backend: finiSingle: free the llama context...Done"sv); } // Reset the context variables. - CxtRef.LlamaEmbd.clear(); CxtRef.LlamaNPast = 0; - CxtRef.LlamaNConsumed = 0; return ErrNo::Success; } diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 4df0b76d..7f26f2a7 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -9,6 +9,7 @@ #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML #include #include +#include #endif namespace WasmEdge::Host::WASINN { @@ -28,6 +29,8 @@ struct Graph { bool Embedding = false; uint64_t NPredict; std::string ReversePrompt; + std::string MMProjModelPath; + std::string ImagePath; // Model parameters: int64_t NGPULayers = 0; // Context parameters: @@ -47,15 +50,16 @@ struct Context { Context(size_t GId, Graph &) noexcept : GraphId(GId) {} size_t GraphId; std::vector LlamaInputs; + uint64_t LlamaNInputs = 0; std::string LlamaOutputs; std::vector LlamaOutputTokens; // Preserve for computing single token llama_context *LlamaContext = nullptr; struct llama_sampling_context *LlamaSampling = nullptr; - std::vector LlamaEmbd; - uint64_t LlamaNInputs; - uint64_t LlamaNPast; - uint64_t LlamaNConsumed; + int32_t LlamaNPast = 0; + // Preserve for llava + struct llava_image_embed *LlavaImageEmbd = nullptr; + size_t LlavaImagePosition = 0; }; #else struct Graph {}; From 11743fe448626ae32a0c441df9cdaf61051b6772 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 21 Feb 2024 16:41:05 +0800 Subject: [PATCH 239/623] [WASI-NN] ggml: bump llama.cpp b2220 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 3275f86c..c0277410 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -49,7 +49,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2037 + GIT_TAG b2220 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From 1091f56a333a1452419e8da9509b8daefe7a0524 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 21 Feb 2024 17:04:53 +0800 Subject: [PATCH 240/623] [WASI-NN] ggml: fix typo Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index f7b6a68f..c034344a 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -639,7 +639,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (GraphRef.CtxSize < 4096) { spdlog::info( "[WASI-NN] GGML backend: Context size is {}, " - "we recommand context size >= 2048 when using llava-v1.5 " + "we recommend context size >= 2048 when using llava-v1.5 " "and context size >= 4096 when using llava-v1.6 for better results."sv, GraphRef.CtxSize); } @@ -999,7 +999,7 @@ Expect computeSingle(WasiNNEnvironment &Env, GraphRef.BatchSize, &CxtRef.LlamaNPast); if (!EvalImageStatus) { spdlog::error( - "[WASI-NN] GGML backend: failed to evaluate embed image tokens ."sv); + "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); return ErrNo::RuntimeError; } ReturnCode = From 757b2b00b665d80810e30e5ffd86b649f764e661 Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 22 Feb 2024 00:26:37 +0800 Subject: [PATCH 241/623] [WASI-NN] ggml: support main-gpu and tensor-split (#3229) Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 67 ++++++++++++++++++++++++++++++++++------ plugins/wasi_nn/ggml.h | 2 ++ 2 files changed, 59 insertions(+), 10 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index c034344a..54209eec 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -6,6 +6,7 @@ #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML #include "simdjson.h" +#include #include #include #include @@ -42,6 +43,8 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // image: string // Model parameters (need to reload the model if updated): // n-gpu-layers: int64_t + // main-gpu: int64_t + // tensor-split: string, comma-separated floating number list // Context parameters (used by the llama context): // ctx-size: uint64_t // batch-size: uint64_t @@ -56,6 +59,8 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // Get the current llama parameters. llama_model_params ModelParams = llama_model_default_params(); ModelParams.n_gpu_layers = GraphRef.NGPULayers; + ModelParams.main_gpu = GraphRef.MainGPU; + ModelParams.tensor_split = GraphRef.TensorSplit.data(); // The plugin parameters. if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { @@ -139,6 +144,44 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } + if (Doc.at_key("main-gpu").error() == simdjson::SUCCESS) { + auto Err = Doc["main-gpu"].get().get(GraphRef.MainGPU); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the main-gpu option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("tensor-split").error() == simdjson::SUCCESS) { + // The TensorSplit is a comma-separated list of non-negative values. + // E.g., "3,2" presents 60% of the data to GPU 0 and 40% to GPU 1. + std::string_view TSV; + auto Err = Doc["tensor-split"].get().get(TSV); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the tensor-split option."sv); + return ErrNo::InvalidArgument; + } + std::string TS(TSV); + std::replace(TS.begin(), TS.end(), ',', ' '); + std::stringstream SS(TS); + GraphRef.TensorSplit.clear(); + while (SS.good()) { + float TmpTensor; + SS >> TmpTensor; + GraphRef.TensorSplit.push_back(TmpTensor); + } + uint32_t NDevices = llama_max_devices(); + if (GraphRef.TensorSplit.size() > NDevices) { + spdlog::error( + "[WASI-NN] GGML backend: Number of Tensor-Split is larger " + "than MaxDevices, please reduce the size of tensor-split."sv); + return ErrNo::InvalidArgument; + } + for (uint32_t Idx = GraphRef.TensorSplit.size(); Idx < NDevices; Idx++) { + GraphRef.TensorSplit.push_back(0.0f); + } + } // The context parameters. if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { @@ -318,9 +361,9 @@ Expect getEmbedding(WasiNNEnvironment &Env, // Check if the input is too long. if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: the prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, - CxtRef.LlamaInputs.size(), MaxTokensListSize); + spdlog::info("[WASI-NN] GGML backend: the prompt is too long. Your input " + "has {} tokens. Please reduce it to {} tokens."sv, + CxtRef.LlamaInputs.size(), MaxTokensListSize); } return ErrNo::PromptTooLong; } @@ -335,12 +378,13 @@ Expect getEmbedding(WasiNNEnvironment &Env, NTokens, NPast, SequenceId)); if (Status == 1) { spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); + "[WASI-NN] GGML backend: failed to llama_decode: try " + "reducing the size of the batch or increasing the size of context"sv); return ErrNo::RuntimeError; } if (Status < 0) { - spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); + spdlog::error("[WASI-NN] GGML backend: failed to llama_decode: internal " + "fatal error. Please open an issue on GitHub"sv); return ErrNo::RuntimeError; } @@ -481,7 +525,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, } else { if (GraphRef.EnableDebugLog) { spdlog::info( - "[WASI-NN][Debug] GGML backend: Model path not found in nn-preload, write model into a tmpfile."sv); + "[WASI-NN][Debug] GGML backend: Model path not found in nn-preload, " + "write model into a tmpfile."sv); } // TODO: pass the model directly to ggml // Write ggml model to file. @@ -516,6 +561,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.ModelFilePath = ModelFilePath; llama_model_params ModelParams = llama_model_default_params(); ModelParams.n_gpu_layers = GraphRef.NGPULayers; + ModelParams.main_gpu = GraphRef.MainGPU; + ModelParams.tensor_split = GraphRef.TensorSplit.data(); GraphRef.LlamaModel = llama_load_model_from_file(GraphRef.ModelFilePath.c_str(), ModelParams); if (GraphRef.LlamaModel == nullptr) { @@ -776,9 +823,9 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Check if the input is too long. if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: the prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, - CxtRef.LlamaInputs.size(), MaxTokensListSize); + spdlog::info("[WASI-NN] GGML backend: the prompt is too long. Your input " + "has {} tokens. Please reduce it to {} tokens."sv, + CxtRef.LlamaInputs.size(), MaxTokensListSize); } return ErrNo::PromptTooLong; } diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 7f26f2a7..67c70fcc 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -32,7 +32,9 @@ struct Graph { std::string MMProjModelPath; std::string ImagePath; // Model parameters: + int64_t MainGPU = 0; // Use GPU 0 by default int64_t NGPULayers = 0; + std::vector TensorSplit; // Context parameters: uint64_t CtxSize; uint64_t BatchSize; From 11afca81bc5379cf1a269aadb77c50e453f88173 Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 22 Feb 2024 03:41:26 +0800 Subject: [PATCH 242/623] [WASI-NN] ggml: bump to b2230 for supporting the gemma model Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index c0277410..d65d2e06 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -49,7 +49,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2220 + GIT_TAG b2230 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From 557a8cd17943d104cba05a3a492aecdbd7157c70 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 5 Feb 2024 16:09:15 +0800 Subject: [PATCH 243/623] [Test] Support GC spec test suite. Not to turn on the GC proposal testing for interpreter now. Will turn on it once passing. Signed-off-by: YiYing He --- test/plugins/wasm_bpf/simple_ringbuf_test.cpp | 11 ++++++----- test/plugins/wasm_bpf/wasm_bpf.cpp | 11 ++++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp index b615524e..f7052105 100644 --- a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp +++ b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp @@ -186,16 +186,17 @@ TEST(WasmBpfTest, SimpleRingbuf) { // In the following several steps we will prepare for polling // Create an instance of the polling callback function - auto callbackFuncInst = - std::make_unique( - &moduleInst, std::make_unique()); + moduleInst.addHostFunc("__polling_callback_hostfunc"sv, + std::make_unique()); + auto *callbackFuncInst = + moduleInst.findFuncExports("__polling_callback_hostfunc"); // Create a function table, and fill the callback function into it auto funcTableInst = std::make_unique( WasmEdge::AST::TableType(WasmEdge::TypeCode::FuncRef, 1)); ASSERT_TRUE(funcTableInst->setRefs( - std::initializer_list{callbackFuncInst.get()}, - 0, 0, 1)); + std::initializer_list{callbackFuncInst}, 0, 0, + 1)); // Add the table to the main module moduleInst.addHostTable("__indirect_function_table"sv, std::move(funcTableInst)); diff --git a/test/plugins/wasm_bpf/wasm_bpf.cpp b/test/plugins/wasm_bpf/wasm_bpf.cpp index 5f80042d..9bf84f87 100644 --- a/test/plugins/wasm_bpf/wasm_bpf.cpp +++ b/test/plugins/wasm_bpf/wasm_bpf.cpp @@ -279,16 +279,17 @@ TEST(WasmBpfTest, RunBpfProgramWithPolling) { // In the following several steps we will prepare for polling // Create an instance of the polling callback function - auto callbackFuncInst = - std::make_unique( - &moduleInst, std::make_unique()); + moduleInst.addHostFunc("__polling_callback_hostfunc"sv, + std::make_unique()); + auto *callbackFuncInst = + moduleInst.findFuncExports("__polling_callback_hostfunc"); // Create a function table, and fill the callback function into it auto funcTableInst = std::make_unique( WasmEdge::AST::TableType(WasmEdge::TypeCode::FuncRef, 1)); EXPECT_TRUE(funcTableInst->setRefs( - std::initializer_list{callbackFuncInst.get()}, - 0, 0, 1)); + std::initializer_list{callbackFuncInst}, 0, 0, + 1)); // Add the table to the main module moduleInst.addHostTable("__indirect_function_table"sv, std::move(funcTableInst)); From 30cc420309876dda48b5f12edb05e0ea03237d50 Mon Sep 17 00:00:00 2001 From: vincent Date: Sat, 23 Sep 2023 18:02:36 +0800 Subject: [PATCH 244/623] [Plugin] Decouple plugin registration from lib loading Handling special case where the main program statically links Wasmedge but still requires dynamic loading of plugins. Signed-off-by: vincent --- plugins/wasi_crypto/ctx.cpp | 3 ++- plugins/wasi_crypto/ctx.h | 1 - plugins/wasi_logging/env.cpp | 5 ++--- plugins/wasi_logging/wasi_logging/env.h | 1 - plugins/wasi_nn/wasinnenv.cpp | 4 ++-- plugins/wasi_nn/wasinnenv.h | 1 - plugins/wasm_bpf/wasm-bpf-module.cpp | 2 +- plugins/wasmedge_image/image_env.cpp | 5 ++--- plugins/wasmedge_image/image_env.h | 4 +--- plugins/wasmedge_opencvmini/opencvmini_env.cpp | 5 ++--- plugins/wasmedge_opencvmini/opencvmini_env.h | 2 -- plugins/wasmedge_process/processenv.cpp | 5 ++--- plugins/wasmedge_process/processenv.h | 1 - plugins/wasmedge_tensorflow/tensorflow_env.cpp | 5 ++--- plugins/wasmedge_tensorflow/tensorflow_env.h | 2 -- plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp | 5 ++--- plugins/wasmedge_tensorflowlite/tensorflowlite_env.h | 2 -- plugins/wasmedge_zlib/zlibenv.cpp | 5 ++--- plugins/wasmedge_zlib/zlibenv.h | 3 --- test/plugins/unittest/testplugin.cpp | 5 ++--- test/plugins/unittest/testplugin.h | 1 - 21 files changed, 22 insertions(+), 45 deletions(-) diff --git a/plugins/wasi_crypto/ctx.cpp b/plugins/wasi_crypto/ctx.cpp index 7ff586f6..db2eefe1 100644 --- a/plugins/wasi_crypto/ctx.cpp +++ b/plugins/wasi_crypto/ctx.cpp @@ -72,9 +72,10 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .AddOptions = nullptr, }; +EXPORT_GET_DESCRIPTOR(Descriptor) + } // namespace -Plugin::PluginRegister WasiCrypto::Context::Register(&Descriptor); std::shared_mutex WasiCrypto::Context::Mutex; std::weak_ptr WasiCrypto::Context::Instance; diff --git a/plugins/wasi_crypto/ctx.h b/plugins/wasi_crypto/ctx.h index f5eda8a7..46a93d84 100644 --- a/plugins/wasi_crypto/ctx.h +++ b/plugins/wasi_crypto/ctx.h @@ -358,7 +358,6 @@ class Context { static std::shared_mutex Mutex; static std::weak_ptr Instance; - static Plugin::PluginRegister Register; }; } // namespace WasiCrypto diff --git a/plugins/wasi_logging/env.cpp b/plugins/wasi_logging/env.cpp index 0f55170f..0e5160a1 100644 --- a/plugins/wasi_logging/env.cpp +++ b/plugins/wasi_logging/env.cpp @@ -28,9 +28,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .AddOptions = nullptr, }; -} // namespace - -Plugin::PluginRegister WasiLoggingEnvironment::Register(&Descriptor); +EXPORT_GET_DESCRIPTOR(Descriptor) +} // namespace } // namespace Host } // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasi_logging/wasi_logging/env.h b/plugins/wasi_logging/wasi_logging/env.h index 8a2864d2..2c199f96 100644 --- a/plugins/wasi_logging/wasi_logging/env.h +++ b/plugins/wasi_logging/wasi_logging/env.h @@ -16,7 +16,6 @@ class WasiLoggingEnvironment { spdlog::stdout_color_mt("wasi_logging_stdout"); inline const static std::shared_ptr StderrLogger = spdlog::stderr_color_mt("wasi_logging_stderr"); - static Plugin::PluginRegister Register; }; } // namespace Host diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 438c7f02..6a4038a1 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -157,9 +157,9 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .AddOptions = addOptions, }; -} // namespace WASINN +EXPORT_GET_DESCRIPTOR(Descriptor) -Plugin::PluginRegister WASINN::WasiNNEnvironment::Register(&Descriptor); +} // namespace WASINN } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 00c8df0f..bc83b739 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -206,7 +206,6 @@ struct WasiNNEnvironment : static PO::Option NNRPCURI; // For RPC client mode std::shared_ptr NNRPCChannel; #endif - static Plugin::PluginRegister Register; }; } // namespace WASINN diff --git a/plugins/wasm_bpf/wasm-bpf-module.cpp b/plugins/wasm_bpf/wasm-bpf-module.cpp index 9ae739d4..6118f6ff 100644 --- a/plugins/wasm_bpf/wasm-bpf-module.cpp +++ b/plugins/wasm_bpf/wasm-bpf-module.cpp @@ -53,7 +53,7 @@ Plugin::Plugin::PluginDescriptor Descriptor{ }, .AddOptions = nullptr}; -Plugin::PluginRegister Register(&Descriptor); +EXPORT_GET_DESCRIPTOR(Descriptor) } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_env.cpp b/plugins/wasmedge_image/image_env.cpp index 022787ca..a1f104ae 100644 --- a/plugins/wasmedge_image/image_env.cpp +++ b/plugins/wasmedge_image/image_env.cpp @@ -32,9 +32,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .AddOptions = nullptr, }; -} // namespace - -Plugin::PluginRegister WasmEdgeImage::ImgEnv::Register(&Descriptor); +EXPORT_GET_DESCRIPTOR(Descriptor) +} // namespace } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_env.h b/plugins/wasmedge_image/image_env.h index 0eea7ba5..b2b59c2c 100644 --- a/plugins/wasmedge_image/image_env.h +++ b/plugins/wasmedge_image/image_env.h @@ -23,9 +23,7 @@ enum class DataType : uint32_t { BGR32F = 3, }; -struct ImgEnv { - static Plugin::PluginRegister Register; -}; +struct ImgEnv {}; } // namespace WasmEdgeImage } // namespace Host diff --git a/plugins/wasmedge_opencvmini/opencvmini_env.cpp b/plugins/wasmedge_opencvmini/opencvmini_env.cpp index c205d7f4..d5211492 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_env.cpp +++ b/plugins/wasmedge_opencvmini/opencvmini_env.cpp @@ -33,9 +33,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .AddOptions = nullptr, }; -} // namespace - -Plugin::PluginRegister WasmEdgeOpenCVMiniEnvironment::Register(&Descriptor); +EXPORT_GET_DESCRIPTOR(Descriptor) +} // namespace } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_opencvmini/opencvmini_env.h b/plugins/wasmedge_opencvmini/opencvmini_env.h index 9ab004ea..f2e606cb 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_env.h +++ b/plugins/wasmedge_opencvmini/opencvmini_env.h @@ -16,8 +16,6 @@ class WasmEdgeOpenCVMiniEnvironment { public: WasmEdgeOpenCVMiniEnvironment() noexcept; - static Plugin::PluginRegister Register; - std::map MatPool; Expect getMat(uint32_t MatKey) { diff --git a/plugins/wasmedge_process/processenv.cpp b/plugins/wasmedge_process/processenv.cpp index 8c769a33..32a318f7 100644 --- a/plugins/wasmedge_process/processenv.cpp +++ b/plugins/wasmedge_process/processenv.cpp @@ -56,9 +56,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .AddOptions = addOptions, }; -} // namespace - -Plugin::PluginRegister WasmEdgeProcessEnvironment::Register(&Descriptor); +EXPORT_GET_DESCRIPTOR(Descriptor) +} // namespace } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_process/processenv.h b/plugins/wasmedge_process/processenv.h index f1b4ddbc..a7dc1161 100644 --- a/plugins/wasmedge_process/processenv.h +++ b/plugins/wasmedge_process/processenv.h @@ -46,7 +46,6 @@ class WasmEdgeProcessEnvironment { static PO::List AllowCmd; static PO::Option AllowCmdAll; - static Plugin::PluginRegister Register; }; } // namespace Host diff --git a/plugins/wasmedge_tensorflow/tensorflow_env.cpp b/plugins/wasmedge_tensorflow/tensorflow_env.cpp index 9a672863..04fb5507 100644 --- a/plugins/wasmedge_tensorflow/tensorflow_env.cpp +++ b/plugins/wasmedge_tensorflow/tensorflow_env.cpp @@ -32,9 +32,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .AddOptions = nullptr, }; -} // namespace - -Plugin::PluginRegister WasmEdgeTensorflow::TFEnv::Register(&Descriptor); +EXPORT_GET_DESCRIPTOR(Descriptor) +} // namespace } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflow/tensorflow_env.h b/plugins/wasmedge_tensorflow/tensorflow_env.h index 8cfcd645..46160dcb 100644 --- a/plugins/wasmedge_tensorflow/tensorflow_env.h +++ b/plugins/wasmedge_tensorflow/tensorflow_env.h @@ -117,8 +117,6 @@ struct TFEnv { } } - static Plugin::PluginRegister Register; - private: std::unordered_set RecycledIdx; std::vector TFContext; diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp index 47a3e037..a6d194ea 100644 --- a/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp @@ -32,9 +32,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .AddOptions = nullptr, }; -} // namespace - -Plugin::PluginRegister WasmEdgeTensorflowLite::TFLiteEnv::Register(&Descriptor); +EXPORT_GET_DESCRIPTOR(Descriptor) +} // namespace } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h index f4f606f4..188ab9c2 100644 --- a/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h @@ -63,8 +63,6 @@ struct TFLiteEnv { } } - static Plugin::PluginRegister Register; - private: std::unordered_set RecycledIdx; std::vector TFLiteContext; diff --git a/plugins/wasmedge_zlib/zlibenv.cpp b/plugins/wasmedge_zlib/zlibenv.cpp index 45654a28..f00a9ef7 100644 --- a/plugins/wasmedge_zlib/zlibenv.cpp +++ b/plugins/wasmedge_zlib/zlibenv.cpp @@ -31,9 +31,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .AddOptions = nullptr, }; -} // namespace - -Plugin::PluginRegister WasmEdgeZlibEnvironment::Register(&Descriptor); +EXPORT_GET_DESCRIPTOR(Descriptor) +} // namespace } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_zlib/zlibenv.h b/plugins/wasmedge_zlib/zlibenv.h index 97781a57..190025cd 100644 --- a/plugins/wasmedge_zlib/zlibenv.h +++ b/plugins/wasmedge_zlib/zlibenv.h @@ -91,9 +91,6 @@ class WasmEdgeZlibEnvironment { std::unordered_map> ZStreamMap; std::map, std::greater> GZFileMap; std::unordered_map GZHeaderMap; - - /// Initial Configurations - static Plugin::PluginRegister Register; }; } // namespace Host diff --git a/test/plugins/unittest/testplugin.cpp b/test/plugins/unittest/testplugin.cpp index 21431b14..565b77c8 100644 --- a/test/plugins/unittest/testplugin.cpp +++ b/test/plugins/unittest/testplugin.cpp @@ -49,9 +49,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .AddOptions = addOptions, }; -} // namespace - -Plugin::PluginRegister WasmEdgePluginTestEnv::Register(&Descriptor); +EXPORT_GET_DESCRIPTOR(Descriptor) +} // namespace } // namespace Host } // namespace WasmEdge diff --git a/test/plugins/unittest/testplugin.h b/test/plugins/unittest/testplugin.h index 214ab049..d7373c9c 100644 --- a/test/plugins/unittest/testplugin.h +++ b/test/plugins/unittest/testplugin.h @@ -20,7 +20,6 @@ class WasmEdgePluginTestEnv { static PO::List CmdArgs; static PO::Option CmdName; - static Plugin::PluginRegister Register; }; template From 8865ac30f3836f30e58766e828018ec8d6f3b0c6 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Wed, 21 Feb 2024 18:04:58 +0800 Subject: [PATCH 245/623] [Docker] Rename SHA256SUM file for manylinux2014 Signed-off-by: Yi Huang --- utils/docker/Dockerfile.manylinux2014_aarch64 | 4 ++-- utils/docker/Dockerfile.manylinux2014_x86_64 | 4 ++-- utils/docker/{SHA256SUM => SHA256SUM.manylinux2014} | 0 3 files changed, 4 insertions(+), 4 deletions(-) rename utils/docker/{SHA256SUM => SHA256SUM.manylinux2014} (100%) diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index 59ca239e..72c5a1a4 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -5,7 +5,7 @@ FROM quay.io/pypa/manylinux2014_aarch64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM /root/ +ADD SHA256SUM.manylinux2014 /root/ ENV PATH /opt/rh/devtoolset-10/root/usr/bin${PATH:+:${PATH}} ENV MANPATH /opt/rh/devtoolset-10/root/usr/share/man${MANPATH:+:${MANPATH}} @@ -27,7 +27,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/cmake-16.0.5.src.tar.xz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/third-party-16.0.5.src.tar.xz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/clang-16.0.5.src.tar.xz && \ - sha256sum -c SHA256SUM && \ + sha256sum -c SHA256SUM.manylinux2014 && \ gzip -dc zstd-1.5.5.tar.gz | tar -xf - && \ gzip -dc cmake-3.26.4.tar.gz | tar -xf - && \ gzip -dc v1.11.1.tar.gz | tar -xf - && \ diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index cc6a942a..9555dbbf 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -5,7 +5,7 @@ FROM quay.io/pypa/manylinux2014_x86_64 MAINTAINER hydai hydai@secondstate.io -ADD SHA256SUM /root/ +ADD SHA256SUM.manylinux2014 /root/ ENV PATH /opt/rh/devtoolset-11/root/usr/bin${PATH:+:${PATH}} ENV MANPATH /opt/rh/devtoolset-11/root/usr/share/man${MANPATH:+:${MANPATH}} @@ -27,7 +27,7 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/cmake-16.0.5.src.tar.xz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/third-party-16.0.5.src.tar.xz \ https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/clang-16.0.5.src.tar.xz && \ - sha256sum -c SHA256SUM && \ + sha256sum -c SHA256SUM.manylinux2014 && \ gzip -dc zstd-1.5.5.tar.gz | tar -xf - && \ gzip -dc cmake-3.26.4.tar.gz | tar -xf - && \ gzip -dc v1.11.1.tar.gz | tar -xf - && \ diff --git a/utils/docker/SHA256SUM b/utils/docker/SHA256SUM.manylinux2014 similarity index 100% rename from utils/docker/SHA256SUM rename to utils/docker/SHA256SUM.manylinux2014 From 26a107b9d90d0cda5cc746d2bcdd9729a82f9968 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Fri, 16 Feb 2024 18:48:54 +0800 Subject: [PATCH 246/623] [Docker] manylinux_2_28 (#3188) Signed-off-by: Yi Huang --- ...ckerfile.manylinux_2_28-build-plugins-deps | 15 +++++ .../docker/Dockerfile.manylinux_2_28_aarch64 | 56 +++++++++++++++++++ utils/docker/Dockerfile.manylinux_2_28_x86_64 | 56 +++++++++++++++++++ utils/docker/SHA256SUM.manylinux_2_28 | 7 +++ 4 files changed, 134 insertions(+) create mode 100644 utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps create mode 100644 utils/docker/Dockerfile.manylinux_2_28_aarch64 create mode 100644 utils/docker/Dockerfile.manylinux_2_28_x86_64 create mode 100644 utils/docker/SHA256SUM.manylinux_2_28 diff --git a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps new file mode 100644 index 00000000..60b8a044 --- /dev/null +++ b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps @@ -0,0 +1,15 @@ +ARG BASE=wasmedge/wasmedge:manylinux_2_28_x86_64 +FROM ${BASE} + +ENV PATH /opt/rh/gcc-toolset-13/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH /opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH /opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} + +RUN cd && (yum check-update || true) && yum install -y wget unzip + +COPY install-opencvmini.sh . +ENV OPENCV_VERSION=4.8.0 +RUN [ "/bin/bash", "install-opencvmini.sh" ] + +RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux_2_28_aarch64 b/utils/docker/Dockerfile.manylinux_2_28_aarch64 new file mode 100644 index 00000000..98dc2e36 --- /dev/null +++ b/utils/docker/Dockerfile.manylinux_2_28_aarch64 @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +FROM quay.io/pypa/manylinux_2_28_aarch64 + +MAINTAINER hydai hydai@secondstate.io + +ADD SHA256SUM.manylinux_2_28 /root/ + +ENV PATH /opt/rh/gcc-toolset-13/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH /opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH /opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} + +RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build cmake && \ + yum install -y gcc-toolset-13 && \ + export CPU=$(/opt/python/cp311-cp311/bin/python3 -c \ + 'import multiprocessing; print(multiprocessing.cpu_count())') && \ + export CFGFLAGS="--prefix=/opt/rh/gcc-toolset-13/root/usr --disable-shared --libdir=/opt/rh/gcc-toolset-13/root/usr/lib64" && \ + curl -s -L -O --remote-name-all \ + https://github.com/ninja-build/ninja/archive/refs/tags/v1.11.1.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/llvm-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/lld-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/libunwind-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/cmake-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/third-party-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/clang-17.0.6.src.tar.xz && \ + sha256sum -c SHA256SUM.manylinux_2_28 && \ + gzip -dc v1.11.1.tar.gz | tar -xf - && \ + xz -dc llvm-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc lld-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc libunwind-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc cmake-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc third-party-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc clang-17.0.6.src.tar.xz | tar -xf - && \ + export ZSTDFLAGS=(PREFIX=/opt/rh/gcc-toolset-13/root/usr LIBDIR=/opt/rh/gcc-toolset-13/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ + mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ + ../ninja-1.11.1/configure.py --bootstrap \ + --with-python=/opt/python/cp311-cp311/bin/python && \ + cp -v ninja /opt/rh/gcc-toolset-13/root/usr/bin/ninja && cd - && rm -rf build && \ + mv -v llvm-17.0.6.src llvm && \ + mv -v lld-17.0.6.src lld && \ + mv -v libunwind-17.0.6.src libunwind && \ + mv -v cmake-17.0.6.src cmake && \ + mv -v third-party-17.0.6.src third-party && \ + mv -v clang-17.0.6.src clang && \ + cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/opt/rh/gcc-toolset-13/root/usr \ + -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ + -DLLVM_TARGETS_TO_BUILD="AArch64;BPF" -DLLVM_ENABLE_PROJECTS="lld;clang" \ + -DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-redhat-linux-gnu" \ + -DBUILD_SHARED_LIBS=OFF llvm && \ + cmake --build build --target install && \ + rm -rf build && rm -rf * + +RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux_2_28_x86_64 b/utils/docker/Dockerfile.manylinux_2_28_x86_64 new file mode 100644 index 00000000..ba1fe2e9 --- /dev/null +++ b/utils/docker/Dockerfile.manylinux_2_28_x86_64 @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +FROM quay.io/pypa/manylinux_2_28_x86_64 + +MAINTAINER hydai hydai@secondstate.io + +ADD SHA256SUM.manylinux_2_28 /root/ + +ENV PATH /opt/rh/gcc-toolset-13/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH /opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH /opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} + +RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build cmake && \ + yum install -y gcc-toolset-13 && \ + export CPU=$(/opt/python/cp311-cp311/bin/python3 -c \ + 'import multiprocessing; print(multiprocessing.cpu_count())') && \ + export CFGFLAGS="--prefix=/opt/rh/gcc-toolset-13/root/usr --disable-shared --libdir=/opt/rh/gcc-toolset-13/root/usr/lib64" && \ + curl -s -L -O --remote-name-all \ + https://github.com/ninja-build/ninja/archive/refs/tags/v1.11.1.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/llvm-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/lld-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/libunwind-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/cmake-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/third-party-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/clang-17.0.6.src.tar.xz && \ + sha256sum -c SHA256SUM.manylinux_2_28 && \ + gzip -dc v1.11.1.tar.gz | tar -xf - && \ + xz -dc llvm-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc lld-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc libunwind-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc cmake-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc third-party-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc clang-17.0.6.src.tar.xz | tar -xf - && \ + export ZSTDFLAGS=(PREFIX=/opt/rh/gcc-toolset-13/root/usr LIBDIR=/opt/rh/gcc-toolset-13/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ + mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ + ../ninja-1.11.1/configure.py --bootstrap \ + --with-python=/opt/python/cp311-cp311/bin/python && \ + cp -v ninja /opt/rh/gcc-toolset-13/root/usr/bin/ninja && cd - && rm -rf build && \ + mv -v llvm-17.0.6.src llvm && \ + mv -v lld-17.0.6.src lld && \ + mv -v libunwind-17.0.6.src libunwind && \ + mv -v cmake-17.0.6.src cmake && \ + mv -v third-party-17.0.6.src third-party && \ + mv -v clang-17.0.6.src clang && \ + cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=/opt/rh/gcc-toolset-13/root/usr \ + -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ + -DLLVM_TARGETS_TO_BUILD="X86;BPF" -DLLVM_ENABLE_PROJECTS="lld;clang" \ + -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ + -DBUILD_SHARED_LIBS=OFF llvm && \ + cmake --build build --target install && \ + rm -rf build && rm -rf * + +RUN yum clean all diff --git a/utils/docker/SHA256SUM.manylinux_2_28 b/utils/docker/SHA256SUM.manylinux_2_28 new file mode 100644 index 00000000..0b732009 --- /dev/null +++ b/utils/docker/SHA256SUM.manylinux_2_28 @@ -0,0 +1,7 @@ +31747ae633213f1eda3842686f83c2aa1412e0f5691d1c14dbbcc67fe7400cea v1.11.1.tar.gz +a78f668a726ae1d3d9a7179996d97b12b90fb76ab9442a43110b972ff7ad9029 clang-17.0.6.src.tar.xz +807f069c54dc20cb47b21c1f6acafdd9c649f3ae015609040d6182cab01140f4 cmake-17.0.6.src.tar.xz +9e7535a353aa862730b4ba38df42e06f6856b40c4cc51b57f27b5046dc21d70d libunwind-17.0.6.src.tar.xz +4ac13125616dc44905b85820aa403d27ec1226329b7f674daeb5f5584c6f0b22 lld-17.0.6.src.tar.xz +b638167da139126ca11917b6880207cc6e8f9d1cbb1a48d87d017f697ef78188 llvm-17.0.6.src.tar.xz +3054d0a9c9375dab1a4539cc2cc45ab340341c5d71475f9599ba7752e222947b third-party-17.0.6.src.tar.xz From 676ea7de50633e370101b05eb148510b4a13b979 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 29 Feb 2024 16:12:21 +0800 Subject: [PATCH 247/623] [WASI-NN] ggml: add inline base64 prompt support for llava Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 174 ++++++++++++++++++++++++++++++++------- 1 file changed, 144 insertions(+), 30 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 54209eec..89e8b7af 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -7,6 +7,7 @@ #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML #include "simdjson.h" #include +#include #include #include #include @@ -417,7 +418,7 @@ Expect getEmbedding(WasiNNEnvironment &Env, return ErrNo::Success; } -ErrNo EvaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, +ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, std::vector Tokens, int &NPast) noexcept { uint32_t NCtx = llama_n_ctx(LlamaContext); @@ -459,6 +460,93 @@ ErrNo EvaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, return ErrNo::Success; } +const std::string_view Base64ImageTagPrefix = ""; +const std::string_view PromptImagePlaceholder = ""; + +bool checkBase64Image(const std::string Prompt) noexcept { + // Check if the prompt contains a base64 image. + // Follow this link for the supported image formats: + // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h + + auto Base64ImageTagBeginPos = Prompt.find(Base64ImageTagPrefix); + if (Base64ImageTagBeginPos == std::string::npos) { + return false; + } + auto Base64ImageTagEndPos = + Prompt.find(Base64ImageTagSuffix, Base64ImageTagBeginPos); + if (Base64ImageTagEndPos == std::string::npos) { + return false; + } + return true; +} + +struct llava_image_embed * +loadBase64ImageFromPrompt(Graph &GraphRef, clip_ctx *ClipContext, + std::string Prompt) noexcept { + // Load the base64 image from the prompt. + // Follow this link for the supported image formats: + // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h + + // Find `` + auto Base64ImageTagEndPos = + Prompt.find(Base64ImageTagSuffix, Base64ImageBytesBeginPos); + if (Base64ImageTagEndPos == std::string::npos) { + return nullptr; + } + + auto Base64Str = + Prompt.substr(Base64ImageBytesBeginPos + Base64ImageBytesPrefix.size(), + Base64ImageTagEndPos - Base64ImageBytesBeginPos - + Base64ImageBytesPrefix.size()); + + // Decode the base64 image. + auto RequiredBytes = base64::required_encode_size(Base64Str.size()); + auto ImageBytes = std::vector(RequiredBytes); + base64::decode(Base64Str.begin(), Base64Str.end(), ImageBytes.begin()); + + return llava_image_embed_make_with_bytes( + ClipContext, GraphRef.Threads, ImageBytes.data(), ImageBytes.size()); +} + +ErrNo replaceBase64ImagePlaceholderInPrompt(std::string &Prompt) noexcept { + // Replace the base64 image in the prompt with a placeholder. + + // Find `` + auto Base64ImageTagEndPos = + Prompt.find(Base64ImageTagSuffix, Base64ImageTagBeginPos); + if (Base64ImageTagEndPos == std::string::npos) { + return ErrNo::InvalidArgument; + } + + auto Base64ImageTagLength = Base64ImageTagEndPos - Base64ImageTagBeginPos + + Base64ImageTagSuffix.size(); + Prompt.replace(Base64ImageTagBeginPos, Base64ImageTagLength, + PromptImagePlaceholder); + + return ErrNo::Success; +} + } // namespace details Expect load(WasiNNEnvironment &Env, Span> Builders, @@ -668,8 +756,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, spdlog::info("[WASI-NN][Debug] GGML backend: set the input"sv); } const bool AddBos = llama_should_add_bos_token(GraphRef.LlamaModel); - const std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), - Tensor.Tensor.size()); + std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); if (GraphRef.MMProjModelPath == ""sv) { // Text only prompt. CxtRef.LlamaInputs = llama_tokenize(LlamaContext, Prompt, AddBos, true); @@ -677,12 +765,18 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } else { // Handle llava format prompt. + // Check if the prompt contains a base64 image. + bool ContainBase64Image = details::checkBase64Image(Prompt); + if (GraphRef.ImagePath == ""sv && ContainBase64Image == false) { + spdlog::error( + "[WASI-NN] GGML backend: Error: when using llava model, " + "you need to specify the image path or have the base64 encoded " + "image in the prompt."sv); + return ErrNo::InvalidArgument; + } + // Show some warnings. if (GraphRef.EnableLog) { - if (GraphRef.ImagePath == ""sv) { - spdlog::info( - "[WASI-NN] GGML backend: Image path is not set, will process as text-only prompt"sv); - } if (GraphRef.CtxSize < 4096) { spdlog::info( "[WASI-NN] GGML backend: Context size is {}, " @@ -692,17 +786,48 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } } + // Load image for llava. + int LlavaVerbosity = 0; + if (GraphRef.EnableLog) { + LlavaVerbosity = 1; + } + auto ClipContext = + clip_model_load(GraphRef.MMProjModelPath.c_str(), LlavaVerbosity); + if (ContainBase64Image) { + // Load the base64 image from the prompt. + CxtRef.LlavaImageEmbd = + details::loadBase64ImageFromPrompt(GraphRef, ClipContext, Prompt); + // Replace the base64 image in the prompt with a placeholder. + auto Res = details::replaceBase64ImagePlaceholderInPrompt(Prompt); + if (Res != ErrNo::Success) { + spdlog::error( + "[WASI-NN] GGML backend: Error: unable to replace the base64 image in the prompt."sv); + clip_free(ClipContext); + return Res; + } + } else { + // Load the image from the file. + CxtRef.LlavaImageEmbd = llava_image_embed_make_with_filename( + ClipContext, GraphRef.Threads, GraphRef.ImagePath.c_str()); + } + clip_free(ClipContext); + if (CxtRef.LlavaImageEmbd == nullptr) { + spdlog::error( + "[WASI-NN] GGML backend: Error: unable to load the image {}."sv, + GraphRef.ImagePath); + return ErrNo::InvalidArgument; + } + // We split prompt by as placeholder and save the position. - const std::string_view PromptPlaceholder = ""; - auto PlaceholderPosition = Prompt.find(PromptPlaceholder); + auto PlaceholderPosition = Prompt.find(details::PromptImagePlaceholder); if (PlaceholderPosition == std::string::npos) { spdlog::error( "[WASI-NN] GGML backend: Error: unable to find the placeholder in the llava prompt."sv); return ErrNo::InvalidArgument; } std::string PromptBeforeImage = Prompt.substr(0, PlaceholderPosition); - std::string PromptAfterImage = - Prompt.substr(PlaceholderPosition + PromptPlaceholder.length()); + std::string PromptAfterImage = Prompt.substr( + PlaceholderPosition + details::PromptImagePlaceholder.length()); std::vector EmbdInputBeforeImage = llama_tokenize(LlamaContext, PromptBeforeImage, AddBos, true); std::vector EmbdInputAfterImage = @@ -716,17 +841,6 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.end(), EmbdInputAfterImage.begin(), EmbdInputAfterImage.end()); - - // Load image for llava. - int LlavaVerbosity = 0; - if (GraphRef.EnableLog) { - LlavaVerbosity = 1; - } - auto ClipContext = - clip_model_load(GraphRef.MMProjModelPath.c_str(), LlavaVerbosity); - CxtRef.LlavaImageEmbd = llava_image_embed_make_with_filename( - ClipContext, GraphRef.Threads, GraphRef.ImagePath.c_str()); - clip_free(ClipContext); } if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: set the input...Done"sv); @@ -833,7 +947,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Evaluate input tokens. if (CxtRef.LlavaImageEmbd == nullptr) { // Text only prompt. - ReturnCode = details::EvaluateTokens(GraphRef, LlamaContext, + ReturnCode = details::evaluateTokens(GraphRef, LlamaContext, CxtRef.LlamaInputs, NPast); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -848,7 +962,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { std::vector EmbdInputAfterImage(CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition, CxtRef.LlamaInputs.end()); - ReturnCode = details::EvaluateTokens(GraphRef, LlamaContext, + ReturnCode = details::evaluateTokens(GraphRef, LlamaContext, EmbdInputBeforeImage, NPast); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -862,7 +976,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); return ErrNo::RuntimeError; } - ReturnCode = details::EvaluateTokens(GraphRef, LlamaContext, + ReturnCode = details::evaluateTokens(GraphRef, LlamaContext, EmbdInputAfterImage, NPast); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -905,7 +1019,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { break; } // Evaluate the output token. - ReturnCode = details::EvaluateTokens(GraphRef, LlamaContext, {Id}, NPast); + ReturnCode = details::evaluateTokens(GraphRef, LlamaContext, {Id}, NPast); if (ReturnCode != ErrNo::Success) { break; } @@ -1018,7 +1132,7 @@ Expect computeSingle(WasiNNEnvironment &Env, // Evaluate input tokens. if (CxtRef.LlavaImageEmbd == nullptr) { // Text only prompt. - ReturnCode = details::EvaluateTokens( + ReturnCode = details::evaluateTokens( GraphRef, CxtRef.LlamaContext, CxtRef.LlamaInputs, CxtRef.LlamaNPast); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -1034,7 +1148,7 @@ Expect computeSingle(WasiNNEnvironment &Env, CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition, CxtRef.LlamaInputs.end()); ReturnCode = - details::EvaluateTokens(GraphRef, CxtRef.LlamaContext, + details::evaluateTokens(GraphRef, CxtRef.LlamaContext, EmbdInputBeforeImage, CxtRef.LlamaNPast); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -1050,7 +1164,7 @@ Expect computeSingle(WasiNNEnvironment &Env, return ErrNo::RuntimeError; } ReturnCode = - details::EvaluateTokens(GraphRef, CxtRef.LlamaContext, + details::evaluateTokens(GraphRef, CxtRef.LlamaContext, EmbdInputAfterImage, CxtRef.LlamaNPast); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -1083,7 +1197,7 @@ Expect computeSingle(WasiNNEnvironment &Env, } // Evaluate the output token if not EOS. if (ReturnCode != ErrNo::EndOfSequence) { - ReturnCode = details::EvaluateTokens(GraphRef, CxtRef.LlamaContext, {Id}, + ReturnCode = details::evaluateTokens(GraphRef, CxtRef.LlamaContext, {Id}, CxtRef.LlamaNPast); } if (GraphRef.EnableDebugLog) { From 798c7bcb8c63139178e3c4e328481e13f305b567 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 29 Feb 2024 16:30:35 +0800 Subject: [PATCH 248/623] [WASI-NN] ggml: bump llama.cpp b2294 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index d65d2e06..4441271f 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -49,7 +49,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2230 + GIT_TAG b2294 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From 94a2d5e6c72e312be1d83662d89eed3c3959fed9 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 29 Feb 2024 16:31:46 +0800 Subject: [PATCH 249/623] [WASI-NN] ggml: remove special ngl default value for macOS Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 89e8b7af..adc5d5cd 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -566,10 +566,6 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.ImagePath = ""sv; // Initialize the model parameters. GraphRef.NGPULayers = 0; -#ifdef __APPLE__ - // We will always set the ngl to 1 on macOS to enable Metal. - GraphRef.NGPULayers = 1; -#endif // Initialize the context parameters. GraphRef.CtxSize = ContextDefault.n_ctx; GraphRef.BatchSize = ContextDefault.n_batch; From 940f7e9a8efa2658d5fe61b9c654881432ddc094 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 29 Feb 2024 21:14:57 +0800 Subject: [PATCH 250/623] [WASI-NN] ggml: minor fix, add more debug log Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index adc5d5cd..f1a6787c 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -460,23 +460,31 @@ ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, return ErrNo::Success; } -const std::string_view Base64ImageTagPrefix = ""; -const std::string_view PromptImagePlaceholder = ""; +const std::string_view Base64ImageTagPrefix = ""sv; +const std::string_view PromptImagePlaceholder = ""sv; -bool checkBase64Image(const std::string Prompt) noexcept { +bool containsBase64Image(Graph &GraphRef, std::string Prompt) noexcept { // Check if the prompt contains a base64 image. // Follow this link for the supported image formats: // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h auto Base64ImageTagBeginPos = Prompt.find(Base64ImageTagPrefix); if (Base64ImageTagBeginPos == std::string::npos) { + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: No base64 image tag found in the prompt."sv); + } return false; } auto Base64ImageTagEndPos = Prompt.find(Base64ImageTagSuffix, Base64ImageTagBeginPos); if (Base64ImageTagEndPos == std::string::npos) { + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: Found an unclosed base64 image tag."sv); + } return false; } return true; @@ -762,8 +770,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Handle llava format prompt. // Check if the prompt contains a base64 image. - bool ContainBase64Image = details::checkBase64Image(Prompt); - if (GraphRef.ImagePath == ""sv && ContainBase64Image == false) { + bool ContainsBase64Image = details::containsBase64Image(GraphRef, Prompt); + if (GraphRef.ImagePath == ""sv && ContainsBase64Image == false) { spdlog::error( "[WASI-NN] GGML backend: Error: when using llava model, " "you need to specify the image path or have the base64 encoded " @@ -789,7 +797,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } auto ClipContext = clip_model_load(GraphRef.MMProjModelPath.c_str(), LlavaVerbosity); - if (ContainBase64Image) { + if (ContainsBase64Image) { // Load the base64 image from the prompt. CxtRef.LlavaImageEmbd = details::loadBase64ImageFromPrompt(GraphRef, ClipContext, Prompt); From cfb38a928052543e3517aaba307c6ad4157783c2 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 29 Feb 2024 21:23:02 +0800 Subject: [PATCH 251/623] [WASI-NN] ggml: handle base64 decode error Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index f1a6787c..eee07797 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -525,7 +525,13 @@ loadBase64ImageFromPrompt(Graph &GraphRef, clip_ctx *ClipContext, // Decode the base64 image. auto RequiredBytes = base64::required_encode_size(Base64Str.size()); auto ImageBytes = std::vector(RequiredBytes); - base64::decode(Base64Str.begin(), Base64Str.end(), ImageBytes.begin()); + try { + base64::decode(Base64Str.begin(), Base64Str.end(), ImageBytes.begin()); + } catch (const base64_error &E) { + spdlog::error("[WASI-NN] GGML backend: Error when base64::decode: {}"sv, + E.what()); + return nullptr; + } return llava_image_embed_make_with_bytes( ClipContext, GraphRef.Threads, ImageBytes.data(), ImageBytes.size()); @@ -817,8 +823,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, clip_free(ClipContext); if (CxtRef.LlavaImageEmbd == nullptr) { spdlog::error( - "[WASI-NN] GGML backend: Error: unable to load the image {}."sv, - GraphRef.ImagePath); + "[WASI-NN] GGML backend: Error: unable to load the image."sv); return ErrNo::InvalidArgument; } From 406b49dbffc1030a9aa6a7527e9b370209e04749 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 5 Mar 2024 12:12:31 +0800 Subject: [PATCH 252/623] [WASI-NN] ggml: bump to b2334 for supporting the starcoder2 model (#3254) Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 4441271f..86a48cef 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -49,7 +49,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2294 + GIT_TAG b2334 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From 047e018835b614feeb91cd6e3f9800e7876de6cf Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 7 Mar 2024 11:21:04 +0800 Subject: [PATCH 253/623] [Installer] Support WasmEdge rustls plugin installation on the manylinux2014 aarch64 (#3262) Signed-off-by: hydai --- ...ckerfile.manylinux2014_plugins_deps_x86_64 | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 utils/docker/Dockerfile.manylinux2014_plugins_deps_x86_64 diff --git a/utils/docker/Dockerfile.manylinux2014_plugins_deps_x86_64 b/utils/docker/Dockerfile.manylinux2014_plugins_deps_x86_64 new file mode 100644 index 00000000..f3a89549 --- /dev/null +++ b/utils/docker/Dockerfile.manylinux2014_plugins_deps_x86_64 @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +ARG BASE=wasmedge/wasmedge:manylinux2014_x86_64 +ARG BUILDPLATFORM=x86_64 +FROM --platform=$BUILDPLATFORM ${BASE} + +MAINTAINER hydai hydai@secondstate.io + +ADD install-opencvmini.sh /root/ + +ENV PATH=/opt/rh/devtoolset-11/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH=/opt/rh/devtoolset-11/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH=/opt/rh/devtoolset-11/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV PKG_CONFIG_PATH=/opt/rh/devtoolset-11/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV OPENCV_VERSION=4.8.0 + +WORKDIR /root/ + +RUN yum update -y \ + && yum install -y zlib-devel zlib-static cmake wget unzip \ + && bash /root/install-opencvmini.sh + +RUN yum clean all From e7dafa61cf01b42551309e9b3f2af4f26abd16c3 Mon Sep 17 00:00:00 2001 From: Hrushikesh Date: Tue, 12 Mar 2024 10:00:32 +0530 Subject: [PATCH 254/623] [Plugin] Initial support for FFmpeg (#2885) * added ffmpeg to cmake Signed-off-by: Hrushi20 * error.h and avformat.h file functions Signed-off-by: Hrushi20 * fix compile error Signed-off-by: Hrushi20 * avFormatContext, avInputFormat struct fields Signed-off-by: Hrushi20 * av_best_stream, stream, mediaType Signed-off-by: Hrushi20 * avCodec, avCodecParameters, AVCodecID Signed-off-by: Hrushi20 * avRational Signed-off-by: Hrushi20 * AVPixelFormat,AVFrame Signed-off-by: Hrushi20 * clean code, refactor Signed-off-by: Hrushi20 refactor Signed-off-by: Hrushi20 * avfilter, swscale Signed-off-by: Hrushi20 * bug fix Signed-off-by: Hrushi20 * init test Signed-off-by: Hrushi20 * naming plugin Signed-off-by: Hrushi20 * AVSample Signed-off-by: Hrushi20 * AVFrame, ChannelLayout Signed-off-by: Hrushi20 * added index while fetching data AVFrame Signed-off-by: Hrushi20 * AVRational Tests Signed-off-by: Hrushi20 * swscale tests init Signed-off-by: Hrushi20 * SwsFilter, SwsVector Signed-off-by: Hrushi20 swsvector Signed-off-by: Hrushi20 * swscale tests Signed-off-by: Hrushi20 * bug fix Signed-off-by: Hrushi20 * swresample init, swresample funcs Signed-off-by: Hrushi20 added swresample funcs Signed-off-by: Hrushi20 * swresample test Signed-off-by: Hrushikesh Rao Signed-off-by: Hrushi20 * refactor based on code style, add ffmpeg to github workflow Signed-off-by: Hrushi20 added plugin to workflow Signed-off-by: Hrushi20 added option Signed-off-by: Hrushi20 * swresample util functions Signed-off-by: Hrushi20 * added avChapters, avfilter version Signed-off-by: Hrushi20 * init avDictionary Signed-off-by: Hrushi20 * build ffmpeg, liniting, add ffmpeg dependencies Signed-off-by: Hrushi20 build ffmpeg Signed-off-by: Hrushi20 added ffmpeg dependency Signed-off-by: Hrushi20 added ffmpeg dependency Signed-off-by: Hrushi20 * AvDictionary functions Signed-off-by: Hrushi20 * download sample_video_file for test Signed-off-by: Hrushi20 * AVDictionary tests Signed-off-by: Hrushi20 * init avdevice Signed-off-by: Hrushi20 * AVInputFormat, AVOutputFormat funcs Signed-off-by: Hrushi20 * avformat tests init Signed-off-by: Hrushi20 * AvChapter tests Signed-off-by: Hrushi20 * AVChapterMut functions Signed-off-by: Hrushi20 * fixed AVChapter Metatdata bug Signed-off-by: Hrushi20 * AVStream, AVStreamMut functions Signed-off-by: Hrushi20 * init/deinit network func Signed-off-by: Hrushi20 * AVStream and AVStreamMut functions Signed-off-by: Hrushi20 * AVFormatCtx Input/Output functions Signed-off-by: Hrushi20 * AvFormatCtxStruct Test funcs Signed-off-by: Hrushi20 * AVFormat Functions test Signed-off-by: Hrushi20 * Bindings for AVPictureType, AVOptionType, AVRounding, AVColor, AVChromaLocation Signed-off-by: Hrushi20 * AVTime funcs Signed-off-by: Hrushi20 * AVFrame funcs Signed-off-by: Hrushi20 * AVRounding funcs Signed-off-by: Hrushi20 * Util to initialize AVFrame in test, AVFrame tests Init Signed-off-by: Hrushi20 * AVUtil Tests added Signed-off-by: Hrushi20 * install ffmpeg dependency, nasm Signed-off-by: Hrushi20 nasm dependency Signed-off-by: Hrushi20 * update ffmpeg cmake file Signed-off-by: Hrushi20 * AVcodecCtx Video Encoder funcs Signed-off-by: Hrushi20 * AVCodecCtx Video funcs, AVSendFrame Funcs, AVReceivePkt funcs Signed-off-by: Hrushi20 * AVCodec funs Signed-off-by: Hrushi20 * AVPacket funcs, Remux funcs Signed-off-by: Hrushi20 * null checks when setting AVDict Signed-off-by: Hrushi20 * AVCodecCtx funs for Encoder Signed-off-by: Hrushi20 * AVCodecCtx decoder funcs Signed-off-by: Hrushi20 * version, config, license, fetching strings(color, pixfmt,samplefmt) Signed-off-by: Hrushi20 * AVPacket Data, AVColorPrimaries bindings Signed-off-by: Hrushi20 * AVDict Null checks, AVUtil tests Signed-off-by: Hrushi20 * Swscale tests refactor Signed-off-by: Hrushi20 * AVCodec,AVCodecParameters Test Signed-off-by: Hrushi20 * AVPacket tests Signed-off-by: Hrushi20 * AVCodec funcs test init Signed-off-by: Hrushi20 * AVCodecCtx tests Signed-off-by: Hrushi20 AVCodecCtx tests Signed-off-by: Hrushi20 * AVcodec tests Signed-off-by: Hrushi20 * AVFilter funcs Signed-off-by: Hrushi20 * AVCodecId bindings updated Signed-off-by: Hrushi20 * Enhancements/bug fixes Signed-off-by: Hrushi20 * refactor swresample tests Signed-off-by: Hrushi20 * AVFilter tests Signed-off-by: Hrushi20 Filter module tests added Signed-off-by: Hrushi20 * uncomment cmake files Signed-off-by: Hrushi20 * ffmpeg tests refactor Signed-off-by: Hrushi20 * Update install-ffmpeg script, remove relative imports, update workflow Signed-off-by: Hrushi20 import header files utils.cpp Signed-off-by: Hrushi20 brew to install ffmpeg, indentation Signed-off-by: Hrushi20 update install-ffmpeg script Signed-off-by: Hrushi20 attempt to fix PKG_CONFIG_PATH Signed-off-by: Hrushi20 attempt to fix install-ffmpeg.sh script Signed-off-by: Hrushi20 attempt to fix install-ffmpeg.sh script Signed-off-by: Hrushi20 avformat_func tests, install-ffmpeg bash script Signed-off-by: Hrushi20 update workflow file, cmake styles Signed-off-by: Hrushi20 remove relative path in tests Signed-off-by: Hrushi20 * AVFrame Audio Funcs, AVCodecFuncs Signed-off-by: Hrushi20 * Added mutex lock Signed-off-by: Hrushi20 * comment avcodec_func tests, added AVInputFormat Tests Signed-off-by: Hrushi20 * fix swscale test Signed-off-by: Hrushi20 * ChLayoutMask, SampleFmtMask, PixelMask, AVCodecFunc tests uncommented Signed-off-by: Hrushi20 * Fix CI build Signed-off-by: Hrushi20 comment avcodec_open2 test Signed-off-by: Hrushi20 comment few avcodec_func tests Signed-off-by: Hrushi20 size of malloc function, replace avdict malloc to calloc Signed-off-by: Hrushi20 attempt to fix linux build Signed-off-by: Hrushi20 fix ffmpeg lib path error Signed-off-by: Hrushi20 uncomment send packet/receive frame test, added dummy file video Signed-off-by: Hrushi20 disable gmock, fix spell-check Signed-off-by: Hrushi20 fix linter bug Signed-off-by: Hrushi20 attempt fix linux build error Signed-off-by: Hrushi20 attempt fix linux build error Signed-off-by: Hrushi20 * swscale, sampleFmt, avCodec tests Signed-off-by: Hrushi20 * use nullptr, added sv to end of string, remove unwanted spaces, update code style Signed-off-by: Hrushi20 * initialize length with 0, type casting using const_cast, uncomment avframe_data test Signed-off-by: Hrushi20 * replace memmove() with std::copy_n() Signed-off-by: Hrushi20 * replace malloc -> av_malloc, calloc -> av_mallocz Signed-off-by: Hrushi20 * fix code format Signed-off-by: Hrushi20 * added TODO to missing tests Signed-off-by: Hrushi20 * update github workflow file Signed-off-by: Hrushi20 * fix ffmpeg CI build in release mode Signed-off-by: Hrushi20 * added yasm dependency Signed-off-by: Hrushi20 * remove yasm dependency arch Signed-off-by: Hrushi20 * enable gmock Signed-off-by: Hrushi20 * comment avcodec tests Signed-off-by: Hrushi20 * uncomment avfilter tests Signed-off-by: Hrushi20 * uncomment swresample/swscale Signed-off-by: Hrushi20 * uncomment avformat tests Signed-off-by: Hrushi20 * uncomment avPacket/avCodecCtx/avCodecParameters Signed-off-by: Hrushi20 * uncomment avcodec_func Signed-off-by: Hrushi20 * attempt to fix avcodec_func tests Signed-off-by: Hrushi20 * attempt to fix avCodec tests Signed-off-by: Hrushi20 * uncomment avcodec_func tests Signed-off-by: Hrushi20 * fix coding style, decouple plugin registration Signed-off-by: Hrushi20 * build fix manylinux Signed-off-by: Hrushi20 * reasoning for av_guess_codec Signed-off-by: Hrushi20 * fix variable format, initialize variable value, append sv to string literals Signed-off-by: Hrushi20 * Trigger Build Signed-off-by: Hrushi20 * fix code style, remove redundant imports Signed-off-by: Hrushi20 * fix av_read_frame test, logging in av_format test Signed-off-by: Hrushi20 * credit for sample_video in test Signed-off-by: Hrushi20 --------- Signed-off-by: Hrushi20 Signed-off-by: Hrushikesh Rao --- plugins/CMakeLists.txt | 9 +- plugins/wasmedge_ffmpeg/CMakeLists.txt | 88 + plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp | 241 + plugins/wasmedge_ffmpeg/avcodec/avCodec.h | 163 + .../avcodec/avCodecContext.cpp | 859 ++++ .../wasmedge_ffmpeg/avcodec/avCodecContext.h | 816 +++ .../avcodec/avCodecParameters.cpp | 38 + .../avcodec/avCodecParameters.h | 39 + plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp | 189 + plugins/wasmedge_ffmpeg/avcodec/avPacket.h | 173 + .../wasmedge_ffmpeg/avcodec/avcodec_base.h | 24 + .../wasmedge_ffmpeg/avcodec/avcodec_func.cpp | 338 ++ .../wasmedge_ffmpeg/avcodec/avcodec_func.h | 245 + plugins/wasmedge_ffmpeg/avcodec/module.cpp | 361 ++ plugins/wasmedge_ffmpeg/avcodec/module.h | 19 + .../wasmedge_ffmpeg/avdevice/avDevice_base.h | 25 + .../avdevice/avDevice_func.cpp | 124 + .../wasmedge_ffmpeg/avdevice/avDevice_func.h | 127 + plugins/wasmedge_ffmpeg/avdevice/module.cpp | 38 + plugins/wasmedge_ffmpeg/avdevice/module.h | 19 + plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp | 149 + plugins/wasmedge_ffmpeg/avfilter/avFilter.h | 120 + .../wasmedge_ffmpeg/avfilter/avfilter_base.h | 25 + .../avfilter/avfilter_func.cpp | 300 ++ .../wasmedge_ffmpeg/avfilter/avfilter_func.h | 207 + .../avfilter/buffer_source_sink.cpp | 68 + .../avfilter/buffer_source_sink.h | 67 + plugins/wasmedge_ffmpeg/avfilter/module.cpp | 108 + plugins/wasmedge_ffmpeg/avfilter/module.h | 19 + .../wasmedge_ffmpeg/avformat/avChapter.cpp | 196 + plugins/wasmedge_ffmpeg/avformat/avChapter.h | 103 + .../avformat/avInputOutputFormat.cpp | 215 + .../avformat/avInputOutputFormat.h | 144 + plugins/wasmedge_ffmpeg/avformat/avStream.cpp | 285 ++ plugins/wasmedge_ffmpeg/avformat/avStream.h | 152 + .../avformat/avformatContext.cpp | 122 + .../avformat/avformatContext.h | 99 + .../wasmedge_ffmpeg/avformat/avformat_base.h | 25 + .../avformat/avformat_func.cpp | 383 ++ .../wasmedge_ffmpeg/avformat/avformat_func.h | 273 + plugins/wasmedge_ffmpeg/avformat/module.cpp | 193 + plugins/wasmedge_ffmpeg/avformat/module.h | 19 + .../wasmedge_ffmpeg/avutil/avDictionary.cpp | 159 + plugins/wasmedge_ffmpeg/avutil/avDictionary.h | 59 + plugins/wasmedge_ffmpeg/avutil/avFrame.cpp | 464 ++ plugins/wasmedge_ffmpeg/avutil/avFrame.h | 404 ++ plugins/wasmedge_ffmpeg/avutil/avRational.cpp | 169 + plugins/wasmedge_ffmpeg/avutil/avRational.h | 107 + plugins/wasmedge_ffmpeg/avutil/avTime.cpp | 32 + plugins/wasmedge_ffmpeg/avutil/avTime.h | 42 + plugins/wasmedge_ffmpeg/avutil/avutil_base.h | 25 + .../wasmedge_ffmpeg/avutil/avutil_func.cpp | 142 + plugins/wasmedge_ffmpeg/avutil/avutil_func.h | 208 + plugins/wasmedge_ffmpeg/avutil/error.cpp | 39 + plugins/wasmedge_ffmpeg/avutil/error.h | 36 + plugins/wasmedge_ffmpeg/avutil/module.cpp | 263 + plugins/wasmedge_ffmpeg/avutil/module.h | 19 + plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp | 174 + plugins/wasmedge_ffmpeg/avutil/pixfmt.h | 132 + plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp | 134 + plugins/wasmedge_ffmpeg/avutil/samplefmt.h | 105 + plugins/wasmedge_ffmpeg/bindings.h | 4425 +++++++++++++++++ plugins/wasmedge_ffmpeg/ffmpeg_env.cpp | 112 + plugins/wasmedge_ffmpeg/ffmpeg_env.h | 110 + plugins/wasmedge_ffmpeg/swresample/module.cpp | 40 + plugins/wasmedge_ffmpeg/swresample/module.h | 20 + .../swresample/swresample_base.h | 25 + .../swresample/swresample_func.cpp | 126 + .../swresample/swresample_func.h | 107 + plugins/wasmedge_ffmpeg/swscale/module.cpp | 74 + plugins/wasmedge_ffmpeg/swscale/module.h | 19 + .../wasmedge_ffmpeg/swscale/swscale_base.h | 25 + .../wasmedge_ffmpeg/swscale/swscale_func.cpp | 298 ++ .../wasmedge_ffmpeg/swscale/swscale_func.h | 226 + test/plugins/CMakeLists.txt | 8 +- test/plugins/wasmedge_ffmpeg/CMakeLists.txt | 74 + .../wasmedge_ffmpeg/avcodec/avCodec.cpp | 365 ++ .../wasmedge_ffmpeg/avcodec/avCodecCtx.cpp | 1657 ++++++ .../avcodec/avCodecParameters.cpp | 75 + .../wasmedge_ffmpeg/avcodec/avPacket.cpp | 368 ++ .../wasmedge_ffmpeg/avcodec/avcodec_func.cpp | 574 +++ .../wasmedge_ffmpeg/avfilter/avfilter.cpp | 287 ++ .../avfilter/avfilter_func.cpp | 682 +++ .../wasmedge_ffmpeg/avformat/avChapter.cpp | 220 + .../avformat/avInputOutputContext.cpp | 207 + .../wasmedge_ffmpeg/avformat/avStream.cpp | 303 ++ .../avformat/avformatContext.cpp | 184 + .../avformat/avformat_func.cpp | 588 +++ .../wasmedge_ffmpeg/avutil/avDictionary.cpp | 151 + .../wasmedge_ffmpeg/avutil/avError.cpp | 65 + .../wasmedge_ffmpeg/avutil/avFrame.cpp | 781 +++ .../wasmedge_ffmpeg/avutil/avPixfmt.cpp | 244 + .../wasmedge_ffmpeg/avutil/avRational.cpp | 313 ++ .../wasmedge_ffmpeg/avutil/avSampleFmt.cpp | 201 + .../wasmedge_ffmpeg/avutil/avutil_func.cpp | 261 + test/plugins/wasmedge_ffmpeg/main.cpp | 6 + .../swresample/swresample_func.cpp | 255 + .../wasmedge_ffmpeg/swscale/swscale_func.cpp | 540 ++ test/plugins/wasmedge_ffmpeg/utils.cpp | 259 + test/plugins/wasmedge_ffmpeg/utils.h | 164 + utils/ffmpeg/download-ffmpeg-sample-video.sh | 22 + utils/ffmpeg/install-ffmpeg-v6.0.sh | 13 + 102 files changed, 24389 insertions(+), 5 deletions(-) create mode 100644 plugins/wasmedge_ffmpeg/CMakeLists.txt create mode 100644 plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp create mode 100644 plugins/wasmedge_ffmpeg/avcodec/avCodec.h create mode 100644 plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp create mode 100644 plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h create mode 100644 plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp create mode 100644 plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h create mode 100644 plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp create mode 100644 plugins/wasmedge_ffmpeg/avcodec/avPacket.h create mode 100644 plugins/wasmedge_ffmpeg/avcodec/avcodec_base.h create mode 100644 plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp create mode 100644 plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h create mode 100644 plugins/wasmedge_ffmpeg/avcodec/module.cpp create mode 100644 plugins/wasmedge_ffmpeg/avcodec/module.h create mode 100644 plugins/wasmedge_ffmpeg/avdevice/avDevice_base.h create mode 100644 plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp create mode 100644 plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h create mode 100644 plugins/wasmedge_ffmpeg/avdevice/module.cpp create mode 100644 plugins/wasmedge_ffmpeg/avdevice/module.h create mode 100644 plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp create mode 100644 plugins/wasmedge_ffmpeg/avfilter/avFilter.h create mode 100644 plugins/wasmedge_ffmpeg/avfilter/avfilter_base.h create mode 100644 plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp create mode 100644 plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h create mode 100644 plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.cpp create mode 100644 plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h create mode 100644 plugins/wasmedge_ffmpeg/avfilter/module.cpp create mode 100644 plugins/wasmedge_ffmpeg/avfilter/module.h create mode 100644 plugins/wasmedge_ffmpeg/avformat/avChapter.cpp create mode 100644 plugins/wasmedge_ffmpeg/avformat/avChapter.h create mode 100644 plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp create mode 100644 plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h create mode 100644 plugins/wasmedge_ffmpeg/avformat/avStream.cpp create mode 100644 plugins/wasmedge_ffmpeg/avformat/avStream.h create mode 100644 plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp create mode 100644 plugins/wasmedge_ffmpeg/avformat/avformatContext.h create mode 100644 plugins/wasmedge_ffmpeg/avformat/avformat_base.h create mode 100644 plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp create mode 100644 plugins/wasmedge_ffmpeg/avformat/avformat_func.h create mode 100644 plugins/wasmedge_ffmpeg/avformat/module.cpp create mode 100644 plugins/wasmedge_ffmpeg/avformat/module.h create mode 100644 plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp create mode 100644 plugins/wasmedge_ffmpeg/avutil/avDictionary.h create mode 100644 plugins/wasmedge_ffmpeg/avutil/avFrame.cpp create mode 100644 plugins/wasmedge_ffmpeg/avutil/avFrame.h create mode 100644 plugins/wasmedge_ffmpeg/avutil/avRational.cpp create mode 100644 plugins/wasmedge_ffmpeg/avutil/avRational.h create mode 100644 plugins/wasmedge_ffmpeg/avutil/avTime.cpp create mode 100644 plugins/wasmedge_ffmpeg/avutil/avTime.h create mode 100644 plugins/wasmedge_ffmpeg/avutil/avutil_base.h create mode 100644 plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp create mode 100644 plugins/wasmedge_ffmpeg/avutil/avutil_func.h create mode 100644 plugins/wasmedge_ffmpeg/avutil/error.cpp create mode 100644 plugins/wasmedge_ffmpeg/avutil/error.h create mode 100644 plugins/wasmedge_ffmpeg/avutil/module.cpp create mode 100644 plugins/wasmedge_ffmpeg/avutil/module.h create mode 100644 plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp create mode 100644 plugins/wasmedge_ffmpeg/avutil/pixfmt.h create mode 100644 plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp create mode 100644 plugins/wasmedge_ffmpeg/avutil/samplefmt.h create mode 100644 plugins/wasmedge_ffmpeg/bindings.h create mode 100644 plugins/wasmedge_ffmpeg/ffmpeg_env.cpp create mode 100644 plugins/wasmedge_ffmpeg/ffmpeg_env.h create mode 100644 plugins/wasmedge_ffmpeg/swresample/module.cpp create mode 100644 plugins/wasmedge_ffmpeg/swresample/module.h create mode 100644 plugins/wasmedge_ffmpeg/swresample/swresample_base.h create mode 100644 plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp create mode 100644 plugins/wasmedge_ffmpeg/swresample/swresample_func.h create mode 100644 plugins/wasmedge_ffmpeg/swscale/module.cpp create mode 100644 plugins/wasmedge_ffmpeg/swscale/module.h create mode 100644 plugins/wasmedge_ffmpeg/swscale/swscale_base.h create mode 100644 plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp create mode 100644 plugins/wasmedge_ffmpeg/swscale/swscale_func.h create mode 100644 test/plugins/wasmedge_ffmpeg/CMakeLists.txt create mode 100644 test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avfilter/avfilter.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avformat/avStream.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avutil/avError.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avutil/avRational.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/main.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/utils.cpp create mode 100644 test/plugins/wasmedge_ffmpeg/utils.h create mode 100644 utils/ffmpeg/download-ffmpeg-sample-video.sh create mode 100755 utils/ffmpeg/install-ffmpeg-v6.0.sh diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 91b64909..8de19e70 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -57,7 +57,7 @@ endif() if(WASMEDGE_PLUGIN_WASI_OCR) add_subdirectory(wasi_ocr) endif() - + if(WASMEDGE_PLUGIN_OPENCVMINI) # Only Linux and MacOS support wasmedge_opencvmini now. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") @@ -77,5 +77,8 @@ endif() if(WASMEDGE_PLUGIN_ZLIB) add_subdirectory(wasmedge_zlib) - -endif() \ No newline at end of file +endif() + +if(WASMEDGE_PLUGIN_FFMPEG) + add_subdirectory(wasmedge_ffmpeg) +endif() diff --git a/plugins/wasmedge_ffmpeg/CMakeLists.txt b/plugins/wasmedge_ffmpeg/CMakeLists.txt new file mode 100644 index 00000000..e78dcaa8 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +find_package(PkgConfig REQUIRED) +pkg_check_modules(LIBAV REQUIRED IMPORTED_TARGET + libavdevice + libavfilter + libavformat + libavcodec + libswresample + libswscale + libavutil +) + +wasmedge_add_library(wasmedgePluginWasmEdgeFFmpeg + SHARED + ffmpeg_env.cpp + + avcodec/module.cpp + avcodec/avcodec_func.cpp + avcodec/avCodecContext.cpp + avcodec/avCodec.cpp + avcodec/avCodecParameters.cpp + avcodec/avPacket.cpp + + avdevice/module.cpp + avdevice/avDevice_func.cpp + + avfilter/module.cpp + avfilter/avfilter_func.cpp + avfilter/buffer_source_sink.cpp + avfilter/avFilter.cpp + + avformat/module.cpp + avformat/avformat_func.cpp + avformat/avformatContext.cpp + avformat/avInputOutputFormat.cpp + avformat/avStream.cpp + avformat/avChapter.cpp + + avutil/module.cpp + avutil/avutil_func.cpp + avutil/error.cpp + avutil/avRational.cpp + avutil/avFrame.cpp + avutil/pixfmt.cpp + avutil/samplefmt.cpp + avutil/avDictionary.cpp + avutil/avTime.cpp + + swresample/module.cpp + swresample/swresample_func.cpp + + swscale/module.cpp + swscale/swscale_func.cpp + +) + +target_compile_options(wasmedgePluginWasmEdgeFFmpeg + PUBLIC + -DWASMEDGE_PLUGIN + -Wno-deprecated-declarations +) + +target_include_directories(wasmedgePluginWasmEdgeFFmpeg + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries(wasmedgePluginWasmEdgeFFmpeg + PUBLIC + PkgConfig::LIBAV +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasmEdgeFFmpeg + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeFFmpeg + PRIVATE + wasmedge_shared + ) +endif() + +install(TARGETS wasmedgePluginWasmEdgeFFmpeg DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp b/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp new file mode 100644 index 00000000..242ccab0 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp @@ -0,0 +1,241 @@ +#include "avCodec.h" + +extern "C" { +#include "libavcodec/avcodec.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +Expect AVCodecID::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return FFmpegUtils::CodecID::fromAVCodecID(AvCodec->id); +} + +Expect AVCodecType::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return FFmpegUtils::MediaType::fromMediaType(AvCodec->type); +} + +Expect AVCodecMaxLowres::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return AvCodec->max_lowres; +} + +Expect AVCodecCapabilities::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return AvCodec->capabilities; +} + +Expect AVCodecGetNameLen::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return strlen(AvCodec->name); +} + +Expect AVCodecGetName::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecId, uint32_t NamePtr, + uint32_t NameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + + const char *Name = AvCodec->name; + std::copy_n(Name, NameLen, NameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVCodecGetLongNameLen::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return strlen(AvCodec->long_name); +} + +Expect AVCodecGetLongName::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecId, + uint32_t LongNamePtr, + uint32_t LongNameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LongNameBuf, MemInst, char, LongNamePtr, LongNameLen, ""); + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + + const char *LongName = AvCodec->long_name; + std::copy_n(LongName, LongNameLen, LongNameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVCodecProfiles::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + if (AvCodec->profiles) + return 1; + return 0; +} + +Expect AVCodecPixFmtsIsNull::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + if (AvCodec->pix_fmts == nullptr) + return 1; + return 0; +} + +Expect AVCodecPixFmtsIter::body(const Runtime::CallingFrame &, + uint32_t AvCodecId, uint32_t Idx) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + AVPixelFormat const *PixelFormat = AvCodec->pix_fmts; + if (PixelFormat == nullptr) + return 0; + + uint32_t Curr = 0; + while (Curr < Idx) { + PixelFormat++; + Curr++; + } + + return FFmpegUtils::PixFmt::fromAVPixFmt(*PixelFormat); +} + +Expect +AVCodecSupportedFrameratesIsNull::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + if (AvCodec->supported_framerates == nullptr) + return 1; + return 0; +} + +Expect +AVCodecSupportedFrameratesIter::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecId, uint32_t Idx, + uint32_t NumPtr, uint32_t DenPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(NumId, MemInst, int32_t, NumPtr, + "Failed when accessing the return NumPtr Memory"sv); + MEM_PTR_CHECK(DenId, MemInst, int32_t, DenPtr, + "Failed when accessing the return DenPtr Memory"sv); + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + AVRational const *Rational = AvCodec->supported_framerates; + + if (Rational == nullptr) { + *NumId = 0; + *DenId = 0; + return static_cast(ErrNo::Success); + } + + uint32_t Curr = 0; + while (Curr < Idx) { + Rational++; + Curr++; + } + + *NumId = Rational->num; + *DenId = Rational->den; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecSupportedSampleRatesIsNull::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + if (AvCodec->supported_samplerates == nullptr) + return 1; + return 0; +} + +Expect +AVCodecSupportedSampleRatesIter::body(const Runtime::CallingFrame &, + uint32_t AvCodecId, uint32_t Idx) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + const int32_t *SampleRates = AvCodec->supported_samplerates; + if (SampleRates == nullptr) + return 0; + + uint32_t Curr = 0; + while (Curr < Idx) { + SampleRates++; + Curr++; + } + + return *SampleRates; +} + +Expect AVCodecChannelLayoutIsNull::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + if (AvCodec->channel_layouts == nullptr) + return 1; + return 0; +} + +Expect AVCodecChannelLayoutIter::body(const Runtime::CallingFrame &, + uint32_t AvCodecId, + uint32_t Idx) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + const uint64_t *ChannelLayout = AvCodec->channel_layouts; + if (ChannelLayout == nullptr) + return 0; + + uint32_t Curr = 0; + while (Curr < Idx) { + ChannelLayout++; + Curr++; + } + + return FFmpegUtils::ChannelLayout::intoChannelLayoutID(*ChannelLayout); +} + +Expect AVCodecSampleFmtsIsNull::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + if (AvCodec->sample_fmts == nullptr) + return 1; + return 0; +} + +Expect AVCodecSampleFmtsIter::body(const Runtime::CallingFrame &, + uint32_t AvCodecId, uint32_t Idx) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + AVSampleFormat const *SampleFormat = AvCodec->sample_fmts; + if (SampleFormat == nullptr) + return 0; + + uint32_t Curr = 0; + while (Curr < Idx) { + SampleFormat++; + Curr++; + } + + return FFmpegUtils::SampleFmt::toSampleID(*SampleFormat); +} + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodec.h b/plugins/wasmedge_ffmpeg/avcodec/avCodec.h new file mode 100644 index 00000000..da1b5fab --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodec.h @@ -0,0 +1,163 @@ +#pragma once +#include "avcodec_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +class AVCodecID : public WasmEdgeFFmpegAVCodec { +public: + AVCodecID(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecType : public WasmEdgeFFmpegAVCodec { +public: + AVCodecType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecMaxLowres : public WasmEdgeFFmpegAVCodec { +public: + AVCodecMaxLowres(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecCapabilities : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCapabilities(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecGetNameLen : public WasmEdgeFFmpegAVCodec { +public: + AVCodecGetNameLen(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecGetName : public WasmEdgeFFmpegAVCodec { +public: + AVCodecGetName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t NamePtr, uint32_t NameLen); +}; + +class AVCodecGetLongNameLen + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecGetLongNameLen(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecGetLongName : public WasmEdgeFFmpegAVCodec { +public: + AVCodecGetLongName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t LongNamePtr, uint32_t LongNameLen); +}; + +class AVCodecProfiles : public WasmEdgeFFmpegAVCodec { +public: + AVCodecProfiles(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecPixFmtsIsNull + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecPixFmtsIsNull(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecPixFmtsIter : public WasmEdgeFFmpegAVCodec { +public: + AVCodecPixFmtsIter(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t Idx); +}; + +class AVCodecSupportedFrameratesIsNull + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecSupportedFrameratesIsNull(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecSupportedFrameratesIter + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecSupportedFrameratesIter(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t Idx, uint32_t NumPtr, uint32_t DenPtr); +}; + +class AVCodecSupportedSampleRatesIsNull + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecSupportedSampleRatesIsNull(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecSupportedSampleRatesIter + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecSupportedSampleRatesIter(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t Idx); +}; + +class AVCodecChannelLayoutIsNull + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecChannelLayoutIsNull(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecChannelLayoutIter + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecChannelLayoutIter(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t Idx); +}; + +class AVCodecSampleFmtsIsNull + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecSampleFmtsIsNull(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecSampleFmtsIter + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecSampleFmtsIter(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t Idx); +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp new file mode 100644 index 00000000..b766ae78 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp @@ -0,0 +1,859 @@ +#include "avCodecContext.h" + +extern "C" { +#include "libavcodec/avcodec.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +Expect AVCodecCtxCodecID::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVCodecID const AvCodecId = AvCodecCtx->codec_id; + return FFmpegUtils::CodecID::fromAVCodecID(AvCodecId); +} + +Expect AVCodecCtxCodecType::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVMediaType const AvMediaType = AvCodecCtx->codec_type; + return FFmpegUtils::MediaType::fromMediaType(AvMediaType); +} + +Expect AVCodecCtxSetCodecType::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t CodecTypeId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVMediaType const AvMediaType = + FFmpegUtils::MediaType::intoMediaType(CodecTypeId); + + AvCodecCtx->codec_type = AvMediaType; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetTimebase::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t Num, + int32_t Den) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVRational const Rational = av_make_q(Num, Den); + AvCodecCtx->time_base = Rational; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxTimeBase::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t NumPtr, + uint32_t DenPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVRational const AvRational = AvCodecCtx->time_base; + *Num = AvRational.num; + *Den = AvRational.den; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxWidth::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->width; +} + +Expect AVCodecCtxSetWidth::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t Width) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->width = Width; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxHeight::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->height; +} + +Expect AVCodecCtxSetHeight::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Height) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->height = Height; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSampleAspectRatio::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t NumPtr, + uint32_t DenPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + + const AVRational AvRational = AvCodecCtx->sample_aspect_ratio; + *Num = AvRational.num; + *Den = AvRational.den; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetSampleAspectRatio::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t Num, + int32_t Den) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + const AVRational AspectRatio = av_make_q(Num, Den); + AvCodecCtx->sample_aspect_ratio = AspectRatio; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxChannelLayout::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + // Deprecated method + uint64_t const AvChannel = AvCodecCtx->channel_layout; + return FFmpegUtils::ChannelLayout::intoChannelLayoutID(AvChannel); +} + +Expect AVCodecCtxSetChannelLayout::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint64_t ChannelLayoutId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + uint64_t const AvChannel = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); + AvCodecCtx->channel_layout = AvChannel; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxPixFormat::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVPixelFormat const PixFmt = AvCodecCtx->pix_fmt; + return FFmpegUtils::PixFmt::fromAVPixFmt(PixFmt); +} + +Expect AVCodecCtxSetPixFormat::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint32_t PixFmtId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVPixelFormat const PixFmt = FFmpegUtils::PixFmt::intoAVPixFmt(PixFmtId); + AvCodecCtx->pix_fmt = PixFmt; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSampleFormat::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVSampleFormat const AvSampleFormat = AvCodecCtx->sample_fmt; + return FFmpegUtils::SampleFmt::toSampleID(AvSampleFormat); +} + +Expect AVCodecCtxSetSampleFormat::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint32_t SampleFmtId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVSampleFormat const SampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); + AvCodecCtx->sample_fmt = SampleFormat; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSampleRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->sample_rate; +} + +Expect AVCodecCtxSetSampleRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t SampleRate) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->sample_rate = SampleRate; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetGopSize::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t GopSize) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->gop_size = GopSize; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMaxBFrames::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MaxBFrames) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->max_b_frames = MaxBFrames; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetBQuantFactor::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float BQuantFactor) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->b_quant_factor = BQuantFactor; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetBQuantOffset::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float BQuantOffset) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->b_quant_offset = BQuantOffset; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetIQuantFactor::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float IQuantFactor) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->i_quant_factor = IQuantFactor; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetIQuantOffset::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float IQuantOffset) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->i_quant_offset = IQuantOffset; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetLumiMasking::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float LumiMasking) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->lumi_masking = LumiMasking; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetTemporalCplxMasking::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float TemporalCplxMasking) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->temporal_cplx_masking = TemporalCplxMasking; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetSpatialCplxMasking::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float SpatialCplxMasking) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->spatial_cplx_masking = SpatialCplxMasking; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetPMasking::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float PMasking) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->p_masking = PMasking; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetDarkMasking::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + float DarkMasking) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->dark_masking = DarkMasking; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMeCmp::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t MeCmp) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->me_cmp = MeCmp; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMeSubCmp::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MeSubCmp) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->me_sub_cmp = MeSubCmp; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMbCmp::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t MbCmp) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->mb_cmp = MbCmp; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetIldctCmp::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t IldctCmp) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->ildct_cmp = IldctCmp; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetDiaSize::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t DiaSize) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->dia_size = DiaSize; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetLastPredictorsCount::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t LastPredictorCount) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->last_predictor_count = LastPredictorCount; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMePreCmp::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MePreCmp) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->me_pre_cmp = MePreCmp; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetPreDiaSize::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t PreDiaSize) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->pre_dia_size = PreDiaSize; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetMeSubpelQuality::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MeSubpelQuality) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->me_subpel_quality = MeSubpelQuality; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMeRange::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MeRange) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->me_range = MeRange; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMbDecision::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MbDecision) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->mb_decision = MbDecision; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMbLMin::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MbLMin) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->mb_lmin = MbLMin; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetMbLMax::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t MbLMax) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->mb_lmax = MbLMax; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxIntraDcPrecision::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->intra_dc_precision; +} + +Expect +AVCodecCtxSetIntraDcPrecision::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t IntraDcPrecision) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->intra_dc_precision = IntraDcPrecision; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetQMin::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t QMin) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->qmin = QMin; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetQMax::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t QMax) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->qmax = QMax; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetGlobalQuality::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t GlobalQuality) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->global_quality = GlobalQuality; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetColorspace::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ColorspaceId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVColorSpace const ColorSpace = + FFmpegUtils::ColorSpace::intoAVColorSpace(ColorspaceId); + AvCodecCtx->colorspace = ColorSpace; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxColorspace::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVColorSpace const Colorspace = AvCodecCtx->colorspace; + return FFmpegUtils::ColorSpace::fromAVColorSpace(Colorspace); +} + +Expect AVCodecCtxSetColorRange::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ColorRangeId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->color_range = static_cast(ColorRangeId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxColorRange::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVColorRange const ColorRange = AvCodecCtx->color_range; + return static_cast(ColorRange); +} + +Expect AVCodecCtxFrameSize::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->frame_size; +} + +Expect AVCodecCtxBitRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->bit_rate; +} + +Expect AVCodecCtxSetBitRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int64_t BitRate) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->bit_rate = BitRate; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxRcMaxRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->rc_max_rate; +} + +Expect AVCodecCtxSetRcMaxRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int64_t RcMaxRate) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->rc_max_rate = RcMaxRate; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetBitRateTolerance::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t BitRateTolerance) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->bit_rate_tolerance = BitRateTolerance; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetCompressionLevel::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t CompressionLevel) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->compression_level = CompressionLevel; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxFrameRate::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, + uint32_t NumPtr, uint32_t DenPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + + AVRational const FrameRate = AvCodecCtx->framerate; + *Num = FrameRate.num; + *Den = FrameRate.den; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetFrameRate::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t Num, + int32_t Den) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVRational const Rational = av_make_q(Num, Den); + AvCodecCtx->framerate = Rational; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetFlags::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t Flags) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->flags = Flags; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetStrictStdCompliance::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ComplianceId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->strict_std_compliance = ComplianceId; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetDebug::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, int32_t Debug) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->debug = Debug; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxCodec::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, + uint32_t AvCodecPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AvCodecPtr, + "Failed to access Ptr for AvCodecPtr"sv); + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvCodec, *AVCodecId, const AVCodec); + + AvCodec = AvCodecCtx->codec; + if (AvCodec == nullptr) + return -1; + + FFMPEG_PTR_STORE(const_cast(AvCodec), AVCodecId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxChannels::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->channels; +} + +Expect AVCodecCtxSetChannels::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Channels) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->channels = Channels; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetSkipLoopFilter::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t AVDiscardId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->skip_loop_filter = static_cast(AVDiscardId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetSkipFrame::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t AVDiscardId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->skip_frame = static_cast(AVDiscardId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetSkipIdct::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t AVDiscardId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->skip_idct = static_cast(AVDiscardId); + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetErrorConcealment::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ErrorConcealment) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->error_concealment = ErrorConcealment; + return static_cast(ErrNo::Success); +} + +Expect +AVCodecCtxSetErrorRecognition::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ErrRecognition) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->err_recognition = ErrRecognition; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxDelay::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->delay; +} + +Expect AVCodecCtxSetSkipTop::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Value) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->skip_top = Value; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetSkipBottom::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Value) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->skip_bottom = Value; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxRefs::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->refs; +} + +Expect AVCodecCtxSetSliceFlags::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Value) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->slice_flags = Value; + return static_cast(ErrNo::Success); +} +Expect AVCodecCtxSetSliceCount::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Value) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->slice_count = Value; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxSetFieldOrder::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t Value) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->field_order = static_cast(Value); + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxColorTrc::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return static_cast(AvCodecCtx->color_trc); +} + +Expect +AVCodecCtxChromaSampleLocation::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVChromaLocation const Chroma = AvCodecCtx->chroma_sample_location; + return FFmpegUtils::ChromaLocation::fromAVChromaLocation(Chroma); +} + +Expect AVCodecCtxFrameNumber::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->frame_number; +} + +Expect AVCodecCtxBlockAlign::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->block_align; +} + +Expect +AVCodecCtxSetRequestSampleFmt::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint32_t SampleFmtId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVSampleFormat const SampleFmt = + FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); + AvCodecCtx->request_sample_fmt = SampleFmt; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxAudioServiceType::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVAudioServiceType const AudioServiceType = AvCodecCtx->audio_service_type; + return static_cast(AudioServiceType); +} + +Expect AVCodecCtxHasBFrames::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->has_b_frames; +} + +Expect +AVCodecCtxSetRequestChannelLayout::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint64_t ChannelLayoutId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->request_channel_layout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxActiveThreadType::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->active_thread_type; +} + +Expect AVCodecCtxSetThreadType::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ThreadType) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->thread_type = ThreadType; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxThreadCount::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return AvCodecCtx->thread_count; +} + +Expect AVCodecCtxSetThreadCount::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + int32_t ThreadCount) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AvCodecCtx->thread_count = ThreadCount; + return static_cast(ErrNo::Success); +} + +Expect AVCodecCtxColorPrimaries::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + AVColorPrimaries const ColorPrimaries = AvCodecCtx->color_primaries; + return FFmpegUtils::ColorPrimaries::fromAVColorPrimaries(ColorPrimaries); +} + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h new file mode 100644 index 00000000..ee39aa0b --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h @@ -0,0 +1,816 @@ +#pragma once +#include "avcodec_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +class AVCodecCtxCodecID : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxCodecID(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxCodecType : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxCodecType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetCodecType + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetCodecType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t CodecTypeId); +}; + +class AVCodecCtxSetTimebase + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetTimebase(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Num, int32_t Den); +}; + +class AVCodecCtxTimeBase : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxTimeBase(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t NumPtr, uint32_t DenPtr); +}; + +class AVCodecCtxWidth : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxWidth(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetWidth : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetWidth(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Width); +}; + +class AVCodecCtxHeight : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxHeight(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetHeight : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetHeight(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Height); +}; + +class AVCodecCtxSampleAspectRatio + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSampleAspectRatio(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t NumPtr, uint32_t DenPtr); +}; + +class AVCodecCtxSetSampleAspectRatio + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetSampleAspectRatio(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Num, int32_t Den); +}; + +class AVCodecCtxChannelLayout + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxChannelLayout(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetChannelLayout + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetChannelLayout(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint64_t ChannelLayoutId); +}; + +class AVCodecCtxPixFormat : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxPixFormat(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetPixFormat + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetPixFormat(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t PixFmtId); +}; + +class AVCodecCtxSampleFormat + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSampleFormat(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetSampleFormat + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetSampleFormat(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t SampleFmtId); +}; + +class AVCodecCtxSampleRate + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSampleRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetSampleRate + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetSampleRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t SampleRate); +}; + +class AVCodecCtxSetGopSize + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetGopSize(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t GopSize); +}; + +class AVCodecCtxSetMaxBFrames + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetMaxBFrames(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MaxBFrames); +}; + +class AVCodecCtxSetBQuantFactor + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetBQuantFactor(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float BQuantFactor); +}; + +class AVCodecCtxSetBQuantOffset + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetBQuantOffset(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float BQuantOffset); +}; + +class AVCodecCtxSetIQuantFactor + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetIQuantFactor(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float IQuantFactor); +}; + +class AVCodecCtxSetIQuantOffset + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetIQuantOffset(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float IQuantOffset); +}; + +class AVCodecCtxSetLumiMasking + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetLumiMasking(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float LumiMasking); +}; + +class AVCodecCtxSetTemporalCplxMasking + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetTemporalCplxMasking(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float TemporalCplxMasking); +}; + +class AVCodecCtxSetSpatialCplxMasking + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetSpatialCplxMasking(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float SpatialCplxMasking); +}; + +class AVCodecCtxSetPMasking + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetPMasking(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float PMasking); +}; + +class AVCodecCtxSetDarkMasking + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetDarkMasking(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, float DarkMasking); +}; + +class AVCodecCtxSetMeCmp : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetMeCmp(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MeCmp); +}; + +class AVCodecCtxSetMeSubCmp + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetMeSubCmp(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MeSubCmp); +}; + +class AVCodecCtxSetMbCmp : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetMbCmp(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MbCmp); +}; + +class AVCodecCtxSetIldctCmp + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetIldctCmp(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t IldctCmp); +}; + +class AVCodecCtxSetDiaSize + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetDiaSize(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t DiaSize); +}; + +class AVCodecCtxSetLastPredictorsCount + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetLastPredictorsCount(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t LastPredictorCount); +}; + +class AVCodecCtxSetMePreCmp + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetMePreCmp(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MePreCmp); +}; + +class AVCodecCtxSetPreDiaSize + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetPreDiaSize(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t PreDiaSize); +}; + +class AVCodecCtxSetMeSubpelQuality + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetMeSubpelQuality(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MeSubpelQuality); +}; + +class AVCodecCtxSetMeRange + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetMeRange(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MeRange); +}; + +class AVCodecCtxSetMbDecision + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetMbDecision(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MbDecision); +}; + +class AVCodecCtxSetMbLMin : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetMbLMin(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MbLMin); +}; + +class AVCodecCtxSetMbLMax : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetMbLMax(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t MbLMax); +}; + +class AVCodecCtxIntraDcPrecision + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxIntraDcPrecision(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetIntraDcPrecision + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetIntraDcPrecision(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t IntraDcPrecision); +}; + +class AVCodecCtxSetQMin : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetQMin(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t QMin); +}; + +class AVCodecCtxSetQMax : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetQMax(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t QMax); +}; + +class AVCodecCtxSetGlobalQuality + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetGlobalQuality(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t GlobalQuality); +}; + +class AVCodecCtxSetColorspace + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetColorspace(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ColorspaceId); +}; + +class AVCodecCtxColorspace + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxColorspace(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetColorRange + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetColorRange(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ColorRange); +}; + +class AVCodecCtxColorRange + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxColorRange(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxFrameSize : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxFrameSize(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxBitRate : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxBitRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetBitRate + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetBitRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int64_t BitRate); +}; + +class AVCodecCtxRcMaxRate : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxRcMaxRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetRcMaxRate + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetRcMaxRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int64_t RcMaxRate); +}; + +class AVCodecCtxSetBitRateTolerance + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetBitRateTolerance(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t BitRateTolerance); +}; + +class AVCodecCtxSetCompressionLevel + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetCompressionLevel(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t CompressionLevel); +}; + +class AVCodecCtxFrameRate : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxFrameRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t NumPtr, uint32_t DenPtr); +}; + +class AVCodecCtxSetFrameRate + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetFrameRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Num, int32_t Den); +}; + +class AVCodecCtxSetFlags : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetFlags(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Flags); +}; + +class AVCodecCtxSetStrictStdCompliance + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetStrictStdCompliance(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ComplianceId); +}; + +class AVCodecCtxSetDebug : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetDebug(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Debug); +}; + +class AVCodecCtxCodec : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxCodec(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t AvCodecPtr); +}; + +class AVCodecCtxChannels : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxChannels(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetChannels + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetChannels(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Channels); +}; + +class AVCodecCtxSetSkipLoopFilter + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetSkipLoopFilter(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t AVDicardId); +}; + +class AVCodecCtxSetSkipFrame + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetSkipFrame(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t AVDiscardId); +}; + +class AVCodecCtxSetSkipIdct + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetSkipIdct(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t AVDicardId); +}; + +class AVCodecCtxSetErrorConcealment + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetErrorConcealment(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ErrorConcealment); +}; + +class AVCodecCtxSetErrorRecognition + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetErrorRecognition(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ErrorRecognition); +}; + +class AVCodecCtxDelay : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxDelay(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetSkipTop + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetSkipTop(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Value); +}; + +class AVCodecCtxSetSkipBottom + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetSkipBottom(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Value); +}; + +class AVCodecCtxRefs : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxRefs(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetSliceFlags + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetSliceFlags(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Flags); +}; + +class AVCodecCtxSetSliceCount + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetSliceCount(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Value); +}; + +class AVCodecCtxSetFieldOrder + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetFieldOrder(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t Value); +}; + +class AVCodecCtxColorTrc : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxColorTrc(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxChromaSampleLocation + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxChromaSampleLocation(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxFrameNumber + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxFrameNumber(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxBlockAlign + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxBlockAlign(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetRequestSampleFmt + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetRequestSampleFmt(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t SampleFmtId); +}; + +class AVCodecCtxAudioServiceType + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxAudioServiceType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxHasBFrames + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxHasBFrames(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetRequestChannelLayout + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetRequestChannelLayout(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint64_t ChannelLayoutId); +}; + +class AVCodecCtxActiveThreadType + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxActiveThreadType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetThreadType + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetThreadType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ThreadType); +}; + +class AVCodecCtxThreadCount + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxThreadCount(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecCtxSetThreadCount + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxSetThreadCount(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, int32_t ThreadCount); +}; + +class AVCodecCtxColorPrimaries + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecCtxColorPrimaries(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp new file mode 100644 index 00000000..9ffbd830 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp @@ -0,0 +1,38 @@ +#include "avCodecParameters.h" + +extern "C" { +#include "libavcodec/avcodec.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +Expect AVCodecParamCodecId::body(const Runtime::CallingFrame &, + uint32_t AvCodecParamId) { + + FFMPEG_PTR_FETCH(AvCodecParams, AvCodecParamId, AVCodecParameters); + return FFmpegUtils::CodecID::fromAVCodecID(AvCodecParams->codec_id); +} + +Expect AVCodecParamCodecType::body(const Runtime::CallingFrame &, + uint32_t AvCodecParamId) { + + FFMPEG_PTR_FETCH(AvCodecParams, AvCodecParamId, AVCodecParameters); + return FFmpegUtils::MediaType::fromMediaType(AvCodecParams->codec_type); +} + +Expect AVCodecParamSetCodecTag::body(const Runtime::CallingFrame &, + uint32_t AvCodecParamId, + uint32_t CodecTag) { + + FFMPEG_PTR_FETCH(AvCodecParams, AvCodecParamId, AVCodecParameters); + AvCodecParams->codec_tag = CodecTag; + return static_cast(ErrNo::Success); +} + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h new file mode 100644 index 00000000..9e37b7fa --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h @@ -0,0 +1,39 @@ +#pragma once +#include "avcodec_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +class AVCodecParamCodecId : public WasmEdgeFFmpegAVCodec { +public: + AVCodecParamCodecId(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamId); +}; + +class AVCodecParamCodecType + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecParamCodecType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamId); +}; + +class AVCodecParamSetCodecTag + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecParamSetCodecTag(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamId, uint32_t CodecTag); +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp b/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp new file mode 100644 index 00000000..892a0395 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp @@ -0,0 +1,189 @@ +#include "avPacket.h" + +extern "C" { +#include "libavcodec/packet.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +Expect AVPacketAlloc::body(const Runtime::CallingFrame &Frame, + uint32_t AvPacketPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AvPacketId, MemInst, uint32_t, AvPacketPtr, + "Failed when accessing the return AVCodecContext Memory"sv); + + FFMPEG_PTR_FETCH(AvPacket, *AvPacketId, AVPacket); // Initialize the packet. + AvPacket = av_packet_alloc(); + FFMPEG_PTR_STORE(AvPacket, AvPacketId); + return static_cast(ErrNo::Success); +} + +Expect AVNewPacket::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int32_t Size) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return av_new_packet(AvPacket, Size); +} + +Expect AVPacketRef::body(const Runtime::CallingFrame &, + uint32_t DestPacketId, uint32_t SrcPacketId) { + + FFMPEG_PTR_FETCH(DestAvPacket, DestPacketId, AVPacket); + FFMPEG_PTR_FETCH(SrcAvPacket, SrcPacketId, AVPacket); + + return av_packet_ref(DestAvPacket, SrcAvPacket); +} + +Expect AVPacketUnref::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); // Free packet. + av_packet_unref(AvPacket); + FFMPEG_PTR_DELETE(AvPacketId); + return static_cast(ErrNo::Success); +} + +Expect AVGrowPacket::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int32_t Size) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return av_grow_packet(AvPacket, Size); +} + +Expect AVShrinkPacket::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int32_t Size) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + av_shrink_packet(AvPacket, Size); + return static_cast(ErrNo::Success); +} + +Expect AVPacketStreamIndex::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->stream_index; +} + +Expect AVPacketSetStreamIndex::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, + int32_t StreamIdx) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AvPacket->stream_index = StreamIdx; + return static_cast(ErrNo::Success); +} + +Expect AVPacketSize::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->size; +} + +Expect AVPacketFlags::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->flags; +} + +Expect AVPacketSetFlags::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int32_t Flags) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AvPacket->flags = Flags; + return static_cast(ErrNo::Success); +} + +Expect AVPacketPos::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->pos; +} + +Expect AVPacketSetPos::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int64_t Pos) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AvPacket->pos = Pos; + return static_cast(ErrNo::Success); +} + +Expect AVPacketDuration::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->duration; +} + +Expect AVPacketSetDuration::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, + int64_t Duration) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AvPacket->duration = Duration; + return static_cast(ErrNo::Success); +} + +Expect AVPacketDts::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->dts; +} + +Expect AVPacketSetDts::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int64_t Dts) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AvPacket->dts = Dts; + return static_cast(ErrNo::Success); +} + +Expect AVPacketPts::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + return AvPacket->pts; +} + +Expect AVPacketSetPts::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int64_t Pts) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AvPacket->pts = Pts; + return static_cast(ErrNo::Success); +} + +Expect AVPacketIsDataNull::body(const Runtime::CallingFrame &, + uint32_t AvPacketId) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + if (AvPacket->data == nullptr) + return 1; + return 0; +} + +Expect AVPacketData::body(const Runtime::CallingFrame &Frame, + uint32_t AvPacketId, uint32_t DataPtr, + uint32_t DataLen) { + + MEMINST_CHECK(MemInst, Frame, 0) + MEM_SPAN_CHECK(Buffer, MemInst, uint8_t, DataPtr, DataLen, ""); + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + uint8_t *Data = AvPacket->data; + std::copy_n(Data, DataLen, Buffer.data()); + return static_cast(ErrNo::Success); +} + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avPacket.h b/plugins/wasmedge_ffmpeg/avcodec/avPacket.h new file mode 100644 index 00000000..5ed513a8 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avPacket.h @@ -0,0 +1,173 @@ +#pragma once +#include "avcodec_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +class AVPacketAlloc : public WasmEdgeFFmpegAVCodec { +public: + AVPacketAlloc(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvPacketPtr); +}; + +class AVNewPacket : public WasmEdgeFFmpegAVCodec { +public: + AVNewPacket(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int32_t Size); +}; + +class AVPacketRef : public WasmEdgeFFmpegAVCodec { +public: + AVPacketRef(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t DestPacketId, uint32_t SrcPacketId); +}; + +class AVPacketUnref : public WasmEdgeFFmpegAVCodec { +public: + AVPacketUnref(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVGrowPacket : public WasmEdgeFFmpegAVCodec { +public: + AVGrowPacket(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int32_t Size); +}; + +class AVShrinkPacket : public WasmEdgeFFmpegAVCodec { +public: + AVShrinkPacket(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int32_t Size); +}; + +class AVPacketStreamIndex : public WasmEdgeFFmpegAVCodec { +public: + AVPacketStreamIndex(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketSetStreamIndex + : public WasmEdgeFFmpegAVCodec { +public: + AVPacketSetStreamIndex(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int32_t StreamIdx); +}; + +class AVPacketSize : public WasmEdgeFFmpegAVCodec { +public: + AVPacketSize(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketFlags : public WasmEdgeFFmpegAVCodec { +public: + AVPacketFlags(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketSetFlags : public WasmEdgeFFmpegAVCodec { +public: + AVPacketSetFlags(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int32_t Flags); +}; + +class AVPacketPos : public WasmEdgeFFmpegAVCodec { +public: + AVPacketPos(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketSetPos : public WasmEdgeFFmpegAVCodec { +public: + AVPacketSetPos(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int64_t Pos); +}; + +class AVPacketDuration : public WasmEdgeFFmpegAVCodec { +public: + AVPacketDuration(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketSetDuration : public WasmEdgeFFmpegAVCodec { +public: + AVPacketSetDuration(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int64_t Duration); +}; + +class AVPacketDts : public WasmEdgeFFmpegAVCodec { +public: + AVPacketDts(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketSetDts : public WasmEdgeFFmpegAVCodec { +public: + AVPacketSetDts(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int64_t Dts); +}; + +class AVPacketPts : public WasmEdgeFFmpegAVCodec { +public: + AVPacketPts(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketSetPts : public WasmEdgeFFmpegAVCodec { +public: + AVPacketSetPts(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + int64_t Pts); +}; + +class AVPacketIsDataNull : public WasmEdgeFFmpegAVCodec { +public: + AVPacketIsDataNull(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); +}; + +class AVPacketData : public WasmEdgeFFmpegAVCodec { +public: + AVPacketData(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, + uint32_t DataPtr, uint32_t DataLen); +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avcodec_base.h b/plugins/wasmedge_ffmpeg/avcodec/avcodec_base.h new file mode 100644 index 00000000..e3365858 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avcodec_base.h @@ -0,0 +1,24 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +template +class WasmEdgeFFmpegAVCodec : public Runtime::HostFunction { +public: + WasmEdgeFFmpegAVCodec(std::shared_ptr HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + std::shared_ptr Env; +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp new file mode 100644 index 00000000..91c0b9de --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp @@ -0,0 +1,338 @@ +#include "avcodec_func.h" + +extern "C" { +#include "libavcodec/avcodec.h" +#include "libavformat/avformat.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +Expect AVCodecAllocContext3::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecId, + uint32_t AvCodecCtxPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AvCodecCtxId, MemInst, uint32_t, AvCodecCtxPtr, + "Failed when accessing the return AVCodecContext Memory"sv); + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, AVCodec); + + AVCodecContext *AvCodecCtx = avcodec_alloc_context3(AvCodec); + FFMPEG_PTR_STORE(AvCodecCtx, AvCodecCtxId); + return static_cast(ErrNo::Success); +} + +Expect +AVCodecParametersFromContext::body(const Runtime::CallingFrame &, + uint32_t AvCodecParamId, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecParam, AvCodecParamId, AVCodecParameters); + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + return avcodec_parameters_from_context(AvCodecParam, AvCodecCtx); +} + +Expect AVCodecParametersFree::body(const Runtime::CallingFrame &, + uint32_t AvCodecParamId) { + + FFMPEG_PTR_FETCH(AvCodecParam, AvCodecParamId, AVCodecParameters); + + avcodec_parameters_free(&AvCodecParam); + FFMPEG_PTR_DELETE(AvCodecParamId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecFreeContext::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + + avcodec_free_context(&AvCodecCtx); + FFMPEG_PTR_DELETE(AvCodecCtxId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecParametersAlloc::body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AvCodecParamId, MemInst, uint32_t, AvCodecParamPtr, + "Failed when accessing the return AVCodecParameters Memory"sv); + + FFMPEG_PTR_FETCH(AvCodecParam, *AvCodecParamId, AVCodecParameters); + AvCodecParam = avcodec_parameters_alloc(); + FFMPEG_PTR_STORE(AvCodecParam, AvCodecParamId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecGetType::body(const Runtime::CallingFrame &, + uint32_t AvCodecIdIndex) { + + AVCodecID const AvCodecId = + FFmpegUtils::CodecID::intoAVCodecID(AvCodecIdIndex); + AVMediaType const MediaType = avcodec_get_type(AvCodecId); + return FFmpegUtils::MediaType::fromMediaType(MediaType); +} + +Expect AVCodecOpen2::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, uint32_t AvCodecId, + uint32_t AvDictionaryId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvDictionary, AvDictionaryId, AVDictionary *); + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, AVCodec); + return avcodec_open2(AvCodecCtx, AvCodec, AvDictionary); +} + +Expect AVCodecFindDecoder::body(const Runtime::CallingFrame &Frame, + uint32_t ID, uint32_t AvCodecPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AvCodecPtr, + "Failed when accessing the return AVCodec Memory"sv); + + AVCodecID const Id = FFmpegUtils::CodecID::intoAVCodecID(ID); + + const AVCodec *AvCodec = avcodec_find_decoder(Id); + + // Setting AvCodec value as NULL. + if (AvCodec == nullptr) { + *AVCodecId = 0; + return static_cast(ErrNo::Success); + } + + FFMPEG_PTR_STORE(const_cast(AvCodec), AVCodecId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecIsEncoder::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return av_codec_is_encoder(AvCodec); +} + +Expect AVCodecIsDecoder::body(const Runtime::CallingFrame &, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + return av_codec_is_decoder(AvCodec); +} + +Expect AVCodecClose::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + int Res = avcodec_close(AvCodecCtx); + FFMPEG_PTR_DELETE(AvCodecCtxId); + return Res; +} + +Expect AVCodecParametersToContext::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint32_t AvCodecParamId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvCodecParam, AvCodecParamId, AVCodecParameters); + + return avcodec_parameters_to_context(AvCodecCtx, AvCodecParam); +} + +Expect AVCodecReceiveFrame::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return avcodec_receive_frame(AvCodecCtx, AvFrame); +} + +Expect AVCodecSendPacket::body(const Runtime::CallingFrame &, + uint32_t AvCodecCtxId, + uint32_t PacketId) { + + FFMPEG_PTR_FETCH(AVCodecCtx, AvCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvPacket, PacketId, + AVPacket); // Can send Null AVPacket, to close the stream. + return avcodec_send_packet(AVCodecCtx, AvPacket); +} + +Expect AVCodecFindEncoder::body(const Runtime::CallingFrame &Frame, + uint32_t ID, uint32_t AVCodecPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AVCodecPtr, + "Failed when accessing the return AVCodec Memory"sv); + + AVCodecID const Id = FFmpegUtils::CodecID::intoAVCodecID(ID); + + const AVCodec *AvCodec = avcodec_find_encoder(Id); + + // Setting AvCodec value as NULL. + if (AvCodec == nullptr) { + *AVCodecId = 0; + return static_cast(ErrNo::Success); + } + + FFMPEG_PTR_STORE(const_cast(AvCodec), AVCodecId); + return static_cast(ErrNo::Success); +} + +Expect AVCodecReceivePacket::body(const Runtime::CallingFrame &, + uint32_t AVCodecCtxId, + uint32_t PacketId) { + + FFMPEG_PTR_FETCH(AVCodecCtx, AVCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvPacket, PacketId, AVPacket); + return avcodec_receive_packet(AVCodecCtx, AvPacket); +} + +Expect AVCodecSendFrame::body(const Runtime::CallingFrame &, + uint32_t AVCodecCtxId, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AVCodecCtx, AVCodecCtxId, AVCodecContext); + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return avcodec_send_frame(AVCodecCtx, AvFrame); +} + +Expect +AVCodecFindDecoderByName::body(const Runtime::CallingFrame &Frame, + uint32_t AVCodecPtr, uint32_t NamePtr, + uint32_t NameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AVCodecPtr, + "Failed when accessing the return AVCodec Memory"sv); + MEM_PTR_CHECK(NameId, MemInst, char, NamePtr, + "Failed when accessing the return URL memory"sv); + + std::string Name; + std::copy_n(NameId, NameLen, std::back_inserter(Name)); + + AVCodec const *AvCodec = avcodec_find_decoder_by_name(Name.c_str()); + + if (AvCodec == nullptr) { + *AVCodecId = 0; + return static_cast(ErrNo::Success); + } + + FFMPEG_PTR_STORE(const_cast(AvCodec), AVCodecId); + return static_cast(ErrNo::Success); +} + +Expect +AVCodecFindEncoderByName::body(const Runtime::CallingFrame &Frame, + uint32_t AVCodecPtr, uint32_t NamePtr, + uint32_t NameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AVCodecPtr, + "Failed when accessing the return AVCodec Memory"sv); + MEM_PTR_CHECK(NameId, MemInst, char, NamePtr, + "Failed when accessing the return URL memory"sv); + + std::string Name; + std::copy_n(NameId, NameLen, std::back_inserter(Name)); + + AVCodec const *AvCodec = avcodec_find_encoder_by_name(Name.c_str()); + + if (AvCodec == nullptr) { + *AVCodecId = 0; + return static_cast(ErrNo::Success); + } + + FFMPEG_PTR_STORE(const_cast(AvCodec), AVCodecId); + return static_cast(ErrNo::Success); +} + +Expect AVPacketRescaleTs::body(const Runtime::CallingFrame &, + uint32_t AvPacketId, int32_t SrcNum, + int32_t SrcDen, int32_t DestNum, + int32_t DestDen) { + + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + AVRational const Src = av_make_q(SrcNum, SrcDen); + AVRational const Dest = av_make_q(DestNum, DestDen); + + av_packet_rescale_ts(AvPacket, Src, Dest); + return static_cast(ErrNo::Success); +} + +Expect AVPacketMakeWritable::body(const Runtime::CallingFrame &, + uint32_t AVPacketId) { + + FFMPEG_PTR_FETCH(AvPacket, AVPacketId, AVPacket); + return av_packet_make_writable(AvPacket); +} + +Expect AVCodecParametersCopy::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t AVCodecParamId, + uint32_t StreamIdx) { + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvCodecParam, AVCodecParamId, AVCodecParameters); + + AVStream **AvStream = AvFormatCtx->streams; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + return avcodec_parameters_copy((*AvStream)->codecpar, AvCodecParam); +} + +Expect AVCodecVersion::body(const Runtime::CallingFrame &) { + return avcodec_version(); +} + +Expect AVCodecFlushBuffers::body(const Runtime::CallingFrame &, + uint32_t AVCodecCtxId) { + + FFMPEG_PTR_FETCH(AvCodecCtx, AVCodecCtxId, AVCodecContext); + avcodec_flush_buffers(AvCodecCtx); + return static_cast(ErrNo::Success); +} + +Expect +AVCodecConfigurationLength::body(const Runtime::CallingFrame &) { + const char *Config = avcodec_configuration(); + return strlen(Config); +} + +Expect AVCodecConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, + uint32_t ConfigLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = avcodec_configuration(); + std::copy_n(Config, ConfigLen, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVCodecLicenseLength::body(const Runtime::CallingFrame &) { + + const char *License = avcodec_license(); + return strlen(License); +} + +Expect AVCodecLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, uint32_t LicenseLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = avcodec_license(); + std::copy_n(License, LicenseLen, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h new file mode 100644 index 00000000..3aabf3e1 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h @@ -0,0 +1,245 @@ +#pragma once +#include "avcodec_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +class AVCodecAllocContext3 + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecAllocContext3(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t AvCodecCtxPtr); +}; + +class AVCodecParametersFromContext + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecParametersFromContext(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamId, uint32_t AvCodecCtxId); +}; + +class AVCodecParametersFree + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecParametersFree(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamId); +}; + +class AVCodecFreeContext : public WasmEdgeFFmpegAVCodec { +public: + AVCodecFreeContext(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId); +}; + +class AVCodecParametersAlloc + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecParametersAlloc(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecParamPtr); +}; + +class AVCodecGetType : public WasmEdgeFFmpegAVCodec { +public: + AVCodecGetType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecOpen2 : public WasmEdgeFFmpegAVCodec { +public: + AVCodecOpen2(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t AvCodecId, + uint32_t AvDictionaryId); +}; + +class AVCodecFindDecoder : public WasmEdgeFFmpegAVCodec { +public: + AVCodecFindDecoder(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ID, + uint32_t AvCodecId); +}; + +class AVCodecIsEncoder : public WasmEdgeFFmpegAVCodec { +public: + AVCodecIsEncoder(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecIsDecoder : public WasmEdgeFFmpegAVCodec { +public: + AVCodecIsDecoder(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecClose : public WasmEdgeFFmpegAVCodec { +public: + AVCodecClose(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); +}; + +class AVCodecParametersToContext + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecParametersToContext(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, + uint32_t AvCodecParamId); +}; + +class AVCodecReceiveFrame : public WasmEdgeFFmpegAVCodec { +public: + AVCodecReceiveFrame(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t FrameId); +}; + +class AVCodecSendPacket : public WasmEdgeFFmpegAVCodec { +public: + AVCodecSendPacket(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvCodecCtxId, uint32_t PacketId); +}; + +class AVCodecFindEncoder : public WasmEdgeFFmpegAVCodec { +public: + AVCodecFindEncoder(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ID, + uint32_t AVCodecPtr); +}; + +class AVCodecReceivePacket + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecReceivePacket(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVCodecCtxId, uint32_t PacketId); +}; + +class AVCodecSendFrame : public WasmEdgeFFmpegAVCodec { +public: + AVCodecSendFrame(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVCodecCtxId, uint32_t FrameId); +}; + +class AVCodecFindDecoderByName + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecFindDecoderByName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AVCodecPtr, + uint32_t NamePtr, uint32_t NameLen); +}; + +class AVCodecFindEncoderByName + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecFindEncoderByName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AVCodecPtr, + uint32_t NamePtr, uint32_t NameLen); +}; + +class AVPacketRescaleTs : public WasmEdgeFFmpegAVCodec { +public: + AVPacketRescaleTs(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AVPacketId, + int32_t SrcNum, int32_t SrcDen, int32_t DestNum, + int32_t DestDen); +}; + +class AVPacketMakeWritable + : public WasmEdgeFFmpegAVCodec { +public: + AVPacketMakeWritable(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t AVPacketId); +}; + +class AVCodecParametersCopy + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecParametersCopy(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVFormatCtxId, uint32_t AVCodecParamId, + uint32_t StreamIdx); +}; + +class AVCodecVersion : public WasmEdgeFFmpegAVCodec { +public: + AVCodecVersion(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVCodecFlushBuffers : public WasmEdgeFFmpegAVCodec { +public: + AVCodecFlushBuffers(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVCodecCtxId); +}; + +class AVCodecConfigurationLength + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecConfigurationLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVCodecConfiguration + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecConfiguration(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class AVCodecLicenseLength + : public WasmEdgeFFmpegAVCodec { +public: + AVCodecLicenseLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVCodecLicense : public WasmEdgeFFmpegAVCodec { +public: + AVCodecLicense(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVCodec(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/module.cpp b/plugins/wasmedge_ffmpeg/avcodec/module.cpp new file mode 100644 index 00000000..1b904755 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/module.cpp @@ -0,0 +1,361 @@ +#include "module.h" +#include "avCodec.h" +#include "avCodecContext.h" +#include "avCodecParameters.h" +#include "avPacket.h" +#include "avcodec_func.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +WasmEdgeFFmpegAVCodecModule::WasmEdgeFFmpegAVCodecModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_avcodec") { + + // avcodec_func.h + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_alloc_context3", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_parameters_from_context", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_parameters_free", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_free_context", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_parameters_alloc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_get_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_open2", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_find_decoder", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_codec_is_encoder", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_codec_is_decoder", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_close", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_parameters_to_context", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_receive_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_send_packet", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_find_encoder", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_receive_packet", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_send_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_find_decoder_by_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_find_encoder_by_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_rescale_ts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_make_writable", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avcodec_avcodec_parameters_copy", + std::make_unique(Env)); // TODO: Write Test. + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_flush_buffers", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_license", + std::make_unique(Env)); + + // avCodecContext Struct fields access + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_codec_id", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_codec_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_codec_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_time_base", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_time_base", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_width", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_width", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_height", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_height", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_sample_aspect_ratio", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_sample_aspect_ratio", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_channel_layout", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_channel_layout", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_pix_fmt", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_pix_fmt", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_sample_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_sample_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_sample_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_sample_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_gop_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_max_b_frames", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_b_quant_factor", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_b_quant_offset", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_i_quant_factor", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_i_quant_offset", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_lumi_masking", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_temporal_cplx_masking", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_spatial_cplx_masking", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_p_masking", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_dark_masking", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_cmp", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_sub_cmp", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_cmp", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_ildct_cmp", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_dia_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_last_predictor_count", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_pre_cmp", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_pre_dia_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_subpel_quality", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_range", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_decision", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_lmin", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_lmax", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_intra_dc_precision", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_intra_dc_precision", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_qmin", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_qmax", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_global_quality", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_colorspace", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_colorspace", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_color_range", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_color_range", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_frame_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_bit_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_bit_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_rc_max_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_rc_max_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_bit_rate_tolerance", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_compression_level", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_framerate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_framerate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_flags", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_strict_std_compliance", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_debug", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_codec", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_channels", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_channels", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_loop_filter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_idct", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_error_concealment", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_err_recognition", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_delay", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_top", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_bottom", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_refs", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_slice_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_slice_count", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_field_order", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_color_trc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_chroma_sample_location", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_frame_number", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_block_align", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_request_sample_fmt", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_audio_service_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_has_b_frames", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_request_channel_layout", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_active_thread_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_thread_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_thread_count", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_thread_count", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_color_primaries", + std::make_unique(Env)); + + // avCodec Struct fields access + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_id", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_max_lowres", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_capabilities", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_get_name_len", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_get_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_get_long_name_len", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_get_long_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_profiles", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_pix_fmts_is_null", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_pix_fmts_iter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_supported_framerate_is_null", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_supported_framerate_iter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_supported_samplerates_is_null", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_supported_samplerates_iter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_channel_layouts_is_null", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_channel_layouts_iter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_sample_fmts_is_null", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_sample_fmts_iter", + std::make_unique(Env)); + + // AVCodecParam Struct fields access. + addHostFunc("wasmedge_ffmpeg_avcodec_avcodecparam_codec_id", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodecparam_codec_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_avcodecparam_set_codec_tag", + std::make_unique(Env)); + + // AVPacket functions. + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_alloc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_new_packet", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_ref", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_unref", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_grow_packet", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_shrink_packet", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_stream_index", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_set_stream_index", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_set_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_pos", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_set_pos", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_duration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_set_duration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_dts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_set_dts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_pts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_set_pts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_is_data_null", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avcodec_av_packet_data", + std::make_unique(Env)); +} + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avcodec/module.h b/plugins/wasmedge_ffmpeg/avcodec/module.h new file mode 100644 index 00000000..7d3776d6 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avcodec/module.h @@ -0,0 +1,19 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVcodec { + +class WasmEdgeFFmpegAVCodecModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegAVCodecModule(std::shared_ptr Env); +}; + +} // namespace AVcodec +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avdevice/avDevice_base.h b/plugins/wasmedge_ffmpeg/avdevice/avDevice_base.h new file mode 100644 index 00000000..06096dd4 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avdevice/avDevice_base.h @@ -0,0 +1,25 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVDevice { + +template +class WasmEdgeFFmpegAVDevice : public Runtime::HostFunction { +public: + WasmEdgeFFmpegAVDevice( + std::shared_ptr HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + std::shared_ptr Env; +}; + +} // namespace AVDevice +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp new file mode 100644 index 00000000..b5276302 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp @@ -0,0 +1,124 @@ +#include "avDevice_func.h" + +extern "C" { +#include "libavdevice/avdevice.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVDevice { + +Expect AVDeviceRegisterAll::body(const Runtime::CallingFrame &) { + avdevice_register_all(); + return {}; +} + +Expect AVDeviceVersion::body(const Runtime::CallingFrame &) { + return avdevice_version(); +} + +Expect AVDeviceListDevices::body(const Runtime::CallingFrame &Frame, + uint32_t AVFormatCtxId, + uint32_t AVDeviceInfoListPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AVDeviceInfoListId, MemInst, uint32_t, AVDeviceInfoListPtr, "") + + FFMPEG_PTR_FETCH(AvFormatCtx, AVFormatCtxId, AVFormatContext); + + AVDeviceInfoList **AvDeviceInfoList = + static_cast(av_malloc(sizeof(AVDeviceInfoList *))); + + int Res = avdevice_list_devices(AvFormatCtx, AvDeviceInfoList); + FFMPEG_PTR_STORE(AvDeviceInfoList, AVDeviceInfoListId); + return Res; +} + +Expect AVInputAudioDeviceNext::body(const Runtime::CallingFrame &) { + spdlog::error("[WasmEdge-FFmpeg] AVInputAudioDeviceNext unimplemented"sv); + // av_input_audio_device_next(); + return static_cast(ErrNo::UnImplemented); +} + +Expect AVInputVideoDeviceNext::body(const Runtime::CallingFrame &) { + spdlog::error("[WasmEdge-FFmpeg] AVInputVideoDeviceNext unimplemented"sv); + // av_input_video_device_next(); + return static_cast(ErrNo::UnImplemented); +} + +Expect AVOutputAudioDeviceNext::body(const Runtime::CallingFrame &) { + spdlog::error("[WasmEdge-FFmpeg] AVOutputAudioDeviceNext unimplemented"sv); + // av_output_audio_device_next(); + return static_cast(ErrNo::UnImplemented); +} + +Expect AVOutputVideoDeviceNext::body(const Runtime::CallingFrame &) { + spdlog::error("[WasmEdge-FFmpeg] AVOutputVideoDeviceNext unimplemented"sv); + // av_output_video_device_next(); + return static_cast(ErrNo::UnImplemented); +} + +Expect AVDeviceFreeListDevices::body(const Runtime::CallingFrame &, + uint32_t AVDeviceInfoListId) { + + FFMPEG_PTR_FETCH(AvDeviceInfoList, AVDeviceInfoListId, AVDeviceInfoList *); + + avdevice_free_list_devices(AvDeviceInfoList); + FFMPEG_PTR_DELETE(AVDeviceInfoListId); + return static_cast(ErrNo::Success); +} + +Expect AVDeviceNbDevices::body(const Runtime::CallingFrame &, + uint32_t AVDeviceInfoListId) { + + FFMPEG_PTR_FETCH(AvDeviceInfoList, AVDeviceInfoListId, AVDeviceInfoList *); + return (*AvDeviceInfoList)->nb_devices; +} + +Expect AVDeviceDefaultDevice::body(const Runtime::CallingFrame &, + uint32_t AVDeviceInfoListId) { + + FFMPEG_PTR_FETCH(AvDeviceInfoList, AVDeviceInfoListId, AVDeviceInfoList *); + return (*AvDeviceInfoList)->default_device; +} + +Expect +AVDeviceConfigurationLength::body(const Runtime::CallingFrame &) { + const char *Config = avdevice_configuration(); + return strlen(Config); +} + +Expect AVDeviceConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, + uint32_t ConfigLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = avdevice_configuration(); + std::copy_n(Config, ConfigLen, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVDeviceLicenseLength::body(const Runtime::CallingFrame &) { + + const char *License = avdevice_license(); + return strlen(License); +} + +Expect AVDeviceLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, + uint32_t LicenseLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = avdevice_license(); + std::copy_n(License, LicenseLen, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +} // namespace AVDevice +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h new file mode 100644 index 00000000..92a1309b --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h @@ -0,0 +1,127 @@ +#pragma once + +#include "avDevice_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVDevice { + +class AVDeviceRegisterAll : public WasmEdgeFFmpegAVDevice { +public: + AVDeviceRegisterAll(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVDeviceVersion : public WasmEdgeFFmpegAVDevice { +public: + AVDeviceVersion(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVDeviceListDevices : public WasmEdgeFFmpegAVDevice { +public: + AVDeviceListDevices(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVFormatCtxId, uint32_t AVDeviceInfoListPtr); +}; + +class AVInputAudioDeviceNext + : public WasmEdgeFFmpegAVDevice { +public: + AVInputAudioDeviceNext(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &); +}; + +class AVInputVideoDeviceNext + : public WasmEdgeFFmpegAVDevice { +public: + AVInputVideoDeviceNext(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &); +}; + +class AVOutputAudioDeviceNext + : public WasmEdgeFFmpegAVDevice { +public: + AVOutputAudioDeviceNext(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &); +}; + +class AVOutputVideoDeviceNext + : public WasmEdgeFFmpegAVDevice { +public: + AVOutputVideoDeviceNext(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &); +}; + +class AVDeviceFreeListDevices + : public WasmEdgeFFmpegAVDevice { +public: + AVDeviceFreeListDevices(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVDeviceInfoListId); +}; + +class AVDeviceNbDevices : public WasmEdgeFFmpegAVDevice { +public: + AVDeviceNbDevices(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVDeviceInfoListId); +}; + +class AVDeviceDefaultDevice + : public WasmEdgeFFmpegAVDevice { +public: + AVDeviceDefaultDevice(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVDeviceInfoListId); +}; + +class AVDeviceConfigurationLength + : public WasmEdgeFFmpegAVDevice { +public: + AVDeviceConfigurationLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVDeviceConfiguration + : public WasmEdgeFFmpegAVDevice { +public: + AVDeviceConfiguration(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class AVDeviceLicenseLength + : public WasmEdgeFFmpegAVDevice { +public: + AVDeviceLicenseLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVDeviceLicense : public WasmEdgeFFmpegAVDevice { +public: + AVDeviceLicense(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVDevice(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +} // namespace AVDevice +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avdevice/module.cpp b/plugins/wasmedge_ffmpeg/avdevice/module.cpp new file mode 100644 index 00000000..a47312d5 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avdevice/module.cpp @@ -0,0 +1,38 @@ +#include "module.h" +#include "avDevice_func.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVDevice { + +WasmEdgeFFmpegAVDeviceModule::WasmEdgeFFmpegAVDeviceModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_avdevice") { + + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_register_all", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_list_devices", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_free_list_devices", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_nb_devices", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_default_device", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_license", + std::make_unique(Env)); +} + +} // namespace AVDevice +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avdevice/module.h b/plugins/wasmedge_ffmpeg/avdevice/module.h new file mode 100644 index 00000000..57e291eb --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avdevice/module.h @@ -0,0 +1,19 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVDevice { + +class WasmEdgeFFmpegAVDeviceModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegAVDeviceModule(std::shared_ptr Env); +}; + +} // namespace AVDevice +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp b/plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp new file mode 100644 index 00000000..2d7fcf64 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp @@ -0,0 +1,149 @@ +#include "avFilter.h" + +extern "C" { +#include "libavfilter/avfilter.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +Expect AVFilterNameLength::body(const Runtime::CallingFrame &, + uint32_t FilterId) { + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + return strlen(Filter->name); +} + +Expect AVFilterName::body(const Runtime::CallingFrame &Frame, + uint32_t FilterId, uint32_t NamePtr, + uint32_t NameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + const char *Name = Filter->name; + std::copy_n(Name, NameLen, NameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFilterDescriptionLength::body(const Runtime::CallingFrame &, + uint32_t FilterId) { + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + return strlen(Filter->description); +} + +Expect AVFilterDescription::body(const Runtime::CallingFrame &Frame, + uint32_t FilterId, uint32_t DescPtr, + uint32_t DescLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(DescBuf, MemInst, char, DescPtr, DescLen, ""); + + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + const char *Desc = Filter->description; + std::copy_n(Desc, DescLen, DescBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFilterNbInputs::body(const Runtime::CallingFrame &, + uint32_t FilterId) { + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + return Filter->nb_inputs; +} + +Expect AVFilterNbOutputs::body(const Runtime::CallingFrame &, + uint32_t FilterId) { + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + return Filter->nb_outputs; +} + +Expect AVFilterFlags::body(const Runtime::CallingFrame &, + uint32_t FilterId) { + + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + return Filter->flags; +} + +Expect AVFilterInOutSetName::body(const Runtime::CallingFrame &Frame, + uint32_t InOutId, uint32_t NamePtr, + uint32_t NameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + + FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); + + std::string Name; + std::copy_n(NameBuf.data(), NameLen, std::back_inserter(Name)); + char *CName = av_strdup(Name.c_str()); + if (CName == nullptr) + return static_cast(ErrNo::Success); + InOut->name = CName; + return static_cast(ErrNo::Success); +} + +Expect AVFilterInOutSetFilterCtx::body(const Runtime::CallingFrame &, + uint32_t InOutId, + uint32_t FilterCtxId) { + + FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); + FFMPEG_PTR_FETCH(FilterCtx, FilterCtxId, AVFilterContext); + + InOut->filter_ctx = FilterCtx; + return static_cast(ErrNo::Success); +} + +Expect AVFilterInOutSetPadIdx::body(const Runtime::CallingFrame &, + uint32_t InOutId, int32_t PadIdx) { + + FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); + InOut->pad_idx = PadIdx; + return static_cast(ErrNo::Success); +} + +Expect AVFilterInOutSetNext::body(const Runtime::CallingFrame &, + uint32_t InOutId, + uint32_t NextInOutId) { + + FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); + FFMPEG_PTR_FETCH(NextInOut, NextInOutId, AVFilterInOut); + InOut->next = NextInOut; + return static_cast(ErrNo::Success); +} + +Expect +AVFilterGetInputsFilterPad::body(const Runtime::CallingFrame &Frame, + uint32_t FilterId, uint32_t FilterPadPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(FilterPadId, MemInst, uint32_t, FilterPadPtr, "") + + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + const AVFilterPad *FilterPad = Filter->inputs; + if (FilterPad == nullptr) + return static_cast(ErrNo::Success); + FFMPEG_PTR_STORE(const_cast(FilterPad), FilterPadId); + return static_cast(ErrNo::Success); +} + +Expect +AVFilterGetOutputsFilterPad::body(const Runtime::CallingFrame &Frame, + uint32_t FilterId, uint32_t FilterPadPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(FilterPadId, MemInst, uint32_t, FilterPadPtr, "") + + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + const AVFilterPad *FilterPad = Filter->outputs; + if (FilterPad == nullptr) + return static_cast(ErrNo::Success); + FFMPEG_PTR_STORE(const_cast(FilterPad), FilterPadId); + return static_cast(ErrNo::Success); +} + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/avFilter.h b/plugins/wasmedge_ffmpeg/avfilter/avFilter.h new file mode 100644 index 00000000..1dbf650f --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/avFilter.h @@ -0,0 +1,120 @@ +#pragma once + +#include "avfilter_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +class AVFilterNameLength : public WasmEdgeFFmpegAVFilter { +public: + AVFilterNameLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); +}; + +class AVFilterName : public WasmEdgeFFmpegAVFilter { +public: + AVFilterName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId, + uint32_t NamePtr, uint32_t NameLen); +}; + +class AVFilterDescriptionLength + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterDescriptionLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); +}; + +class AVFilterDescription : public WasmEdgeFFmpegAVFilter { +public: + AVFilterDescription(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId, + uint32_t DescPtr, uint32_t DescLen); +}; + +class AVFilterNbInputs : public WasmEdgeFFmpegAVFilter { +public: + AVFilterNbInputs(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); +}; + +class AVFilterNbOutputs : public WasmEdgeFFmpegAVFilter { +public: + AVFilterNbOutputs(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); +}; + +class AVFilterFlags : public WasmEdgeFFmpegAVFilter { +public: + AVFilterFlags(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); +}; + +class AVFilterInOutSetName + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterInOutSetName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId, + uint32_t NamePtr, uint32_t NameLen); +}; + +class AVFilterInOutSetFilterCtx + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterInOutSetFilterCtx(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId, + uint32_t FilterCtxId); +}; + +class AVFilterInOutSetPadIdx + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterInOutSetPadIdx(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId, + int32_t PadIdx); +}; + +class AVFilterInOutSetNext + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterInOutSetNext(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId, + uint32_t NextInOutId); +}; + +class AVFilterGetInputsFilterPad + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterGetInputsFilterPad(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId, + uint32_t FilterPadPtr); +}; + +class AVFilterGetOutputsFilterPad + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterGetOutputsFilterPad(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId, + uint32_t FilterPadPtr); +}; + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/avfilter_base.h b/plugins/wasmedge_ffmpeg/avfilter/avfilter_base.h new file mode 100644 index 00000000..7ed333d9 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/avfilter_base.h @@ -0,0 +1,25 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +template +class WasmEdgeFFmpegAVFilter : public Runtime::HostFunction { +public: + WasmEdgeFFmpegAVFilter( + std::shared_ptr HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + std::shared_ptr Env; +}; + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp new file mode 100644 index 00000000..974a1b37 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp @@ -0,0 +1,300 @@ +#include "avfilter_func.h" + +extern "C" { +#include "libavfilter/avfilter.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +Expect AVFilterGraphAlloc::body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(FilterGraphId, MemInst, uint32_t, FilterGraphPtr, "") + + FFMPEG_PTR_FETCH(FilterGraph, *FilterGraphId, AVFilterGraph); + + FilterGraph = avfilter_graph_alloc(); + if (FilterGraph == nullptr) + return static_cast(ErrNo::Success); + FFMPEG_PTR_STORE(FilterGraph, FilterGraphId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterGraphConfig::body(const Runtime::CallingFrame &, + uint32_t FilterGraphId) { + + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + return avfilter_graph_config(FilterGraph, + nullptr); // log_ctx always NULL on Rust SDK. +} + +Expect AVFilterGraphFree::body(const Runtime::CallingFrame &, + uint32_t FilterGraphId) { + + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + avfilter_graph_free(&FilterGraph); + FFMPEG_PTR_DELETE(FilterGraphId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterGraphGetFilter::body(const Runtime::CallingFrame &Frame, + uint32_t FilterCtxPtr, + uint32_t FilterGraphId, + uint32_t NamePtr, + uint32_t NameSize) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(NameId, MemInst, char, NamePtr, + "Failed when accessing the return Name memory"sv); + MEM_PTR_CHECK(FilterCtxId, MemInst, uint32_t, FilterCtxPtr, ""); + + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + FFMPEG_PTR_FETCH(FilterCtx, *FilterCtxId, AVFilterContext); + + std::string Name; + std::copy_n(NameId, NameSize, std::back_inserter(Name)); + + FilterCtx = avfilter_graph_get_filter(FilterGraph, Name.c_str()); + if (FilterCtx == nullptr) + return static_cast(ErrNo::Success); + FFMPEG_PTR_STORE(FilterCtx, FilterCtxId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterGraphParsePtr::body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId, + uint32_t FiltersString, + uint32_t FiltersSize, + uint32_t InputsId, + uint32_t OutputsId) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(FiltersId, MemInst, char, FiltersString, ""); + + FFMPEG_PTR_FETCH(Inputs, InputsId, AVFilterInOut); + FFMPEG_PTR_FETCH(Outputs, OutputsId, AVFilterInOut); + FFMPEG_PTR_FETCH(FiltersGraph, FilterGraphId, AVFilterGraph); + + std::string Filters; + std::copy_n(FiltersId, FiltersSize, std::back_inserter(Filters)); + return avfilter_graph_parse_ptr(FiltersGraph, Filters.c_str(), &Inputs, + &Outputs, nullptr); +} + +Expect AVFilterInOutFree::body(const Runtime::CallingFrame &, + uint32_t InOutId) { + + FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); + avfilter_inout_free(&InOut); + FFMPEG_PTR_DELETE(InOutId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterVersion::body(const Runtime::CallingFrame &) { + return avfilter_version(); +} + +Expect AVFilterGetByName::body(const Runtime::CallingFrame &Frame, + uint32_t FilterPtr, uint32_t StrPtr, + uint32_t StrLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(StrId, MemInst, char, StrPtr, + "Failed when accessing the return Str memory"sv); + MEM_PTR_CHECK(FilterId, MemInst, uint32_t, FilterPtr, + "Failed when accessing the return Filter memory"sv); + + FFMPEG_PTR_FETCH(Filter, *FilterId, const struct AVFilter); + std::string Name; + std::copy_n(StrId, StrLen, std::back_inserter(Name)); + + Filter = avfilter_get_by_name(Name.c_str()); + if (Filter == nullptr) + return static_cast(ErrNo::Success); + + FFMPEG_PTR_STORE(const_cast(Filter), FilterId); + return static_cast(ErrNo::Success); +} + +Expect +AVFilterConfigurationLength::body(const Runtime::CallingFrame &) { + + const char *Config = avfilter_configuration(); + return strlen(Config); +} + +Expect AVFilterConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, + uint32_t ConfigLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = avfilter_configuration(); + std::copy_n(Config, ConfigLen, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFilterLicenseLength::body(const Runtime::CallingFrame &) { + + const char *License = avfilter_license(); + return strlen(License); +} + +Expect AVFilterLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, + uint32_t LicenseLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = avfilter_license(); + std::copy_n(License, LicenseLen, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFilterGraphCreateFilter::body( + const Runtime::CallingFrame &Frame, uint32_t FilterCtxPtr, + uint32_t FilterId, uint32_t NamePtr, uint32_t NameLen, uint32_t ArgsPtr, + uint32_t ArgsLen, uint32_t FilterGraphId) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + MEM_SPAN_CHECK(ArgsBuf, MemInst, char, ArgsPtr, ArgsLen, ""); + MEM_PTR_CHECK(FilterCtxId, MemInst, uint32_t, FilterCtxPtr, "") + + FFMPEG_PTR_FETCH(FilterCtx, *FilterCtxId, AVFilterContext); + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + + std::string Name; + std::string Args; + std::copy_n(NameBuf.data(), NameLen, std::back_inserter(Name)); + std::copy_n(ArgsBuf.data(), ArgsLen, std::back_inserter(Args)); + + int Res = avfilter_graph_create_filter(&FilterCtx, Filter, Name.c_str(), + Args.c_str(), nullptr, FilterGraph); + if (Res < 0) + return Res; + + FFMPEG_PTR_STORE(FilterCtx, FilterCtxId); + return Res; +} + +Expect AVFilterInOutAlloc::body(const Runtime::CallingFrame &Frame, + uint32_t InOutPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(InOutId, MemInst, uint32_t, InOutPtr, "") + + FFMPEG_PTR_FETCH(InOut, *InOutId, AVFilterInOut); + InOut = avfilter_inout_alloc(); + if (InOut == nullptr) + return static_cast(ErrNo::Success); + FFMPEG_PTR_STORE(InOut, InOutId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterPadGetNameLength::body(const Runtime::CallingFrame &, + uint32_t FilterPadId, + int32_t Idx) { + + FFMPEG_PTR_FETCH(FilterPad, FilterPadId, AVFilterPad); + + const char *Name = avfilter_pad_get_name(FilterPad, Idx); + return strlen(Name); +} + +Expect AVFilterPadGetName::body(const Runtime::CallingFrame &Frame, + uint32_t FilterPadId, int32_t Idx, + uint32_t NamePtr, uint32_t NameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + + FFMPEG_PTR_FETCH(FilterPad, FilterPadId, AVFilterPad); + + const char *Name = avfilter_pad_get_name(FilterPad, Idx); + std::copy_n(Name, NameLen, NameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFilterPadGetType::body(const Runtime::CallingFrame &, + uint32_t FilterPadId, int32_t Idx) { + + FFMPEG_PTR_FETCH(FilterPad, FilterPadId, AVFilterPad); + AVMediaType const MediaType = avfilter_pad_get_type(FilterPad, Idx); + return FFmpegUtils::MediaType::fromMediaType(MediaType); +} + +Expect AVFilterGraphDumpLength::body(const Runtime::CallingFrame &, + uint32_t FilterGraphId) { + + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + char *Graph = avfilter_graph_dump(FilterGraph, nullptr); + return strlen(Graph); +} + +Expect AVFilterGraphDump::body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId, + uint32_t GraphStrPtr, + uint32_t GraphStrLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(GraphStr, MemInst, char, GraphStrPtr, GraphStrLen, ""); + + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + + char *Graph = avfilter_graph_dump(FilterGraph, nullptr); + std::copy_n(Graph, GraphStrLen, GraphStr.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFilterFreeGraphStr::body(const Runtime::CallingFrame &, + uint32_t FilterGraphId) { + + FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); + + char *Graph = avfilter_graph_dump(FilterGraph, nullptr); + av_free(Graph); + return static_cast(ErrNo::Success); +} + +Expect AVFilterDrop::body(const Runtime::CallingFrame &, + uint32_t FilterId) { + + FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); + if (Filter == nullptr) + return static_cast(ErrNo::Success); + FFMPEG_PTR_DELETE(FilterId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterPadDrop::body(const Runtime::CallingFrame &, + uint32_t FilterPadId) { + + FFMPEG_PTR_FETCH(FilterPad, FilterPadId, AVFilterPad); + if (FilterPad == nullptr) + return static_cast(ErrNo::Success); + FFMPEG_PTR_DELETE(FilterPadId); + return static_cast(ErrNo::Success); +} + +Expect AVFilterContextDrop::body(const Runtime::CallingFrame &, + uint32_t FilterCtxId) { + + FFMPEG_PTR_FETCH(FilterCtx, FilterCtxId, AVFilterContext); + if (FilterCtx == nullptr) + return static_cast(ErrNo::Success); + FFMPEG_PTR_DELETE(FilterCtxId); + return static_cast(ErrNo::Success); +} + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h new file mode 100644 index 00000000..b96914e2 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h @@ -0,0 +1,207 @@ +#pragma once + +#include "avfilter_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +class AVFilterGraphAlloc : public WasmEdgeFFmpegAVFilter { +public: + AVFilterGraphAlloc(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphPtr); +}; + +class AVFilterGraphConfig : public WasmEdgeFFmpegAVFilter { +public: + AVFilterGraphConfig(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId); +}; + +class AVFilterGraphFree : public WasmEdgeFFmpegAVFilter { +public: + AVFilterGraphFree(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId); +}; + +class AVFilterGraphGetFilter + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterGraphGetFilter(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterCtxPtr, uint32_t FilterGraphId, + uint32_t NamePtr, uint32_t NameSize); +}; + +class AVFilterGraphParsePtr + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterGraphParsePtr(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId, uint32_t FiltersString, + uint32_t FiltersSize, uint32_t InputsId, + uint32_t OutputsId); +}; + +class AVFilterInOutFree : public WasmEdgeFFmpegAVFilter { +public: + AVFilterInOutFree(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId); +}; + +class AVFilterVersion : public WasmEdgeFFmpegAVFilter { +public: + AVFilterVersion(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFilterGetByName : public WasmEdgeFFmpegAVFilter { +public: + AVFilterGetByName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPtr, + uint32_t StrPtr, uint32_t StrLen); +}; + +class AVFilterConfigurationLength + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterConfigurationLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFilterConfiguration + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterConfiguration(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class AVFilterLicenseLength + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterLicenseLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFilterLicense : public WasmEdgeFFmpegAVFilter { +public: + AVFilterLicense(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +class AVFilterGraphCreateFilter + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterGraphCreateFilter(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterCtxPtr, uint32_t FilterId, + uint32_t NamePtr, uint32_t NameLen, uint32_t ArgsPtr, + uint32_t ArgsLen, uint32_t FilterGraphId); +}; + +class AVFilterInOutAlloc : public WasmEdgeFFmpegAVFilter { +public: + AVFilterInOutAlloc(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutPtr); +}; + +class AVFilterPadGetNameLength + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterPadGetNameLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPadId, + int32_t Idx); +}; + +class AVFilterPadGetName : public WasmEdgeFFmpegAVFilter { +public: + AVFilterPadGetName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPadId, + int32_t Idx, uint32_t NamePtr, uint32_t NameLen); +}; + +class AVFilterPadGetType : public WasmEdgeFFmpegAVFilter { +public: + AVFilterPadGetType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPadId, + int32_t Idx); +}; + +class AVFilterGraphDumpLength + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterGraphDumpLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId); +}; + +class AVFilterGraphDump : public WasmEdgeFFmpegAVFilter { +public: + AVFilterGraphDump(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId, uint32_t GraphStrPtr, + uint32_t GraphStrLen); +}; + +class AVFilterFreeGraphStr + : public WasmEdgeFFmpegAVFilter { +public: + AVFilterFreeGraphStr(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterGraphId); +}; + +class AVFilterDrop : public WasmEdgeFFmpegAVFilter { +public: + AVFilterDrop(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); +}; + +class AVFilterPadDrop : public WasmEdgeFFmpegAVFilter { +public: + AVFilterPadDrop(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterPadId); +}; + +class AVFilterContextDrop : public WasmEdgeFFmpegAVFilter { +public: + AVFilterContextDrop(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterCtxId); +}; + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.cpp b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.cpp new file mode 100644 index 00000000..0aed3697 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.cpp @@ -0,0 +1,68 @@ +#include "buffer_source_sink.h" +extern "C" { +#include "libavfilter/buffersink.h" +#include "libavfilter/buffersrc.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +Expect AVBufferSinkGetFrame::body(const Runtime::CallingFrame &, + uint32_t FilterContextId, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); + FFMPEG_PTR_FETCH(Frame, FrameId, AVFrame); + return av_buffersink_get_frame(FilterCtx, Frame); +} + +Expect AVBufferSinkGetSamples::body(const Runtime::CallingFrame &, + uint32_t FilterContextId, + uint32_t FrameId, + int32_t Samples) { + + FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); + FFMPEG_PTR_FETCH(Frame, FrameId, AVFrame); + return av_buffersink_get_samples(FilterCtx, Frame, Samples); +} + +Expect AvBufferSinkSetFrameSize::body(const Runtime::CallingFrame &, + uint32_t FilterContextId, + int32_t Value) { + + FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); + av_buffersink_set_frame_size(FilterCtx, Value); + return static_cast(ErrNo::Success); +} + +Expect +AVBufferSrcGetNbFailedRequests::body(const Runtime::CallingFrame &, + uint32_t FilterContextId) { + + FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); + return av_buffersrc_get_nb_failed_requests(FilterCtx); +} + +Expect AVBufferSrcAddFrame::body(const Runtime::CallingFrame &, + uint32_t FilterContextId, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); + FFMPEG_PTR_FETCH(Frame, FrameId, AVFrame); + return av_buffersrc_add_frame(FilterCtx, Frame); +} + +Expect AVBufferSrcClose::body(const Runtime::CallingFrame &, + uint32_t FilterContextId, int64_t Pts, + uint32_t Flags) { + + FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); + return av_buffersrc_close(FilterCtx, Pts, Flags); +} + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h new file mode 100644 index 00000000..7119d5fa --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h @@ -0,0 +1,67 @@ +#pragma once + +#include "avfilter_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +class AVBufferSinkGetFrame + : public WasmEdgeFFmpegAVFilter { +public: + AVBufferSinkGetFrame(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterContextId, uint32_t FrameId); +}; + +class AVBufferSinkGetSamples + : public WasmEdgeFFmpegAVFilter { +public: + AVBufferSinkGetSamples(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterContextId, uint32_t FrameId, + int32_t Samples); +}; + +class AvBufferSinkSetFrameSize + : public WasmEdgeFFmpegAVFilter { +public: + AvBufferSinkSetFrameSize(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterContextId, int32_t Value); +}; + +class AVBufferSrcGetNbFailedRequests + : public WasmEdgeFFmpegAVFilter { +public: + AVBufferSrcGetNbFailedRequests(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterContextId); +}; + +class AVBufferSrcAddFrame : public WasmEdgeFFmpegAVFilter { +public: + AVBufferSrcAddFrame(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterContextId, uint32_t FrameId); +}; + +class AVBufferSrcClose : public WasmEdgeFFmpegAVFilter { +public: + AVBufferSrcClose(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFilter(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t FilterContextId, int64_t Pts, uint32_t Flags); +}; + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/module.cpp b/plugins/wasmedge_ffmpeg/avfilter/module.cpp new file mode 100644 index 00000000..2c31e44f --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/module.cpp @@ -0,0 +1,108 @@ +#include "module.h" +#include "avFilter.h" +#include "avfilter_func.h" +#include "buffer_source_sink.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +WasmEdgeFFmpegAVFilterModule::WasmEdgeFFmpegAVFilterModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_avfilter") { + + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_alloc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_config", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_free", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_get_filter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_parse_ptr", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_inout_free", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_get_by_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_license", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_create_filter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_inout_alloc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_pad_get_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_pad_get_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_pad_get_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_dump_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_dump", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_free_graph_str", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_drop", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_pad_drop", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_context_drop", + std::make_unique(Env)); + + // buffersrc.h && buffersink.h + addHostFunc("wasmedge_ffmpeg_avfilter_av_buffersink_get_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_av_buffersink_get_samples", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_av_buffersink_set_frame_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_av_buffersrc_get_nb_failed_requests", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_av_buffersrc_add_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_av_buffersrc_close", + std::make_unique(Env)); + + // avfilter.h + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_description_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_description", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_nb_inputs", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_nb_outputs", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_inout_set_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_inout_set_filter_ctx", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_inout_set_pad_idx", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_inout_set_next", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_get_inputs_filter_pad", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_get_outputs_filter_pad", + std::make_unique(Env)); +} + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/module.h b/plugins/wasmedge_ffmpeg/avfilter/module.h new file mode 100644 index 00000000..2515e6ae --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avfilter/module.h @@ -0,0 +1,19 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFilter { + +class WasmEdgeFFmpegAVFilterModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegAVFilterModule(std::shared_ptr Env); +}; + +} // namespace AVFilter +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp b/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp new file mode 100644 index 00000000..1d5f5c41 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp @@ -0,0 +1,196 @@ +#include "avChapter.h" + +extern "C" { +#include "libavformat/avformat.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +Expect AVChapterId::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, uint32_t ChapterIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) + AvChapter++; + + return static_cast(*AvChapter)->id; +} + +Expect AVChapterSetId::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx, int64_t ChapterId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) + AvChapter++; + + (*AvChapter)->id = ChapterId; + return static_cast(ErrNo::Success); +} + +Expect AVChapterTimebase::body(const Runtime::CallingFrame &Frame, + uint32_t NumPtr, uint32_t DenPtr, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, ""); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, ""); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) + AvChapter++; + + AVRational const AvRational = static_cast(*AvChapter)->time_base; + *Num = AvRational.num; + *Den = AvRational.den; + return static_cast(ErrNo::Success); +} + +Expect AVChapterSetTimebase::body(const Runtime::CallingFrame &, + int32_t Num, int32_t Den, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVRational const Timebase = av_make_q(Num, Den); + + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) + AvChapter++; + + (*AvChapter)->time_base = Timebase; + return static_cast(ErrNo::Success); +} + +Expect AVChapterStart::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) + AvChapter++; + + return static_cast(*AvChapter)->start; +} + +Expect AVChapterSetStart::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx, + int64_t StartValue) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) + AvChapter++; + + (*AvChapter)->start = StartValue; + return static_cast(ErrNo::Success); +} + +Expect AVChapterEnd::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) + AvChapter++; + + return static_cast(*AvChapter)->end; +} + +Expect AVChapterSetEnd::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx, int64_t EndValue) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVChapter **AvChapter = AvFormatContext->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) + AvChapter++; + + (*AvChapter)->end = EndValue; + return static_cast(ErrNo::Success); +} + +Expect AVChapterMetadata::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx, uint32_t DictPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, + "Failed when accessing the return AVDictionary memory"sv); + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + + AVDictionary **AvDictionary = + static_cast(av_malloc(sizeof(AVDictionary *))); + AVChapter **AvChapter = AvFormatCtx->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) + AvChapter++; + + *AvDictionary = (*AvChapter)->metadata; + FFMPEG_PTR_STORE(AvDictionary, DictId); + return static_cast(ErrNo::Success); +} + +Expect AVChapterSetMetadata::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t ChapterIdx, + uint32_t DictId) { + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvDictionary, DictId, AVDictionary *); + + AVChapter **AvChapter = AvFormatCtx->chapters; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= ChapterIdx; I++) + AvChapter++; + + if (AvDictionary == nullptr) + (*AvChapter)->metadata = nullptr; + else + (*AvChapter)->metadata = *AvDictionary; + return static_cast(ErrNo::Success); +} + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avChapter.h b/plugins/wasmedge_ffmpeg/avformat/avChapter.h new file mode 100644 index 00000000..bb40c088 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avChapter.h @@ -0,0 +1,103 @@ +#pragma once + +#include "avformat_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +class AVChapterId : public WasmEdgeFFmpegAVFormat { +public: + AVChapterId(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx); +}; + +class AVChapterSetId : public WasmEdgeFFmpegAVFormat { +public: + AVChapterSetId(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx, + int64_t ChapterId); +}; + +class AVChapterTimebase : public WasmEdgeFFmpegAVFormat { +public: + AVChapterTimebase(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, + uint32_t DenPtr, uint32_t AvFormatCtxId, + uint32_t ChapterIdx); +}; + +class AVChapterSetTimebase + : public WasmEdgeFFmpegAVFormat { +public: + AVChapterSetTimebase(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t Num, + int32_t Den, uint32_t AvFormatCtxId, + uint32_t ChapterIdx); +}; + +class AVChapterStart : public WasmEdgeFFmpegAVFormat { +public: + AVChapterStart(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx); +}; + +class AVChapterSetStart : public WasmEdgeFFmpegAVFormat { +public: + AVChapterSetStart(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx, + int64_t StartValue); +}; + +class AVChapterEnd : public WasmEdgeFFmpegAVFormat { +public: + AVChapterEnd(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx); +}; + +class AVChapterSetEnd : public WasmEdgeFFmpegAVFormat { +public: + AVChapterSetEnd(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx, + int64_t EndValue); +}; + +class AVChapterMetadata : public WasmEdgeFFmpegAVFormat { +public: + AVChapterMetadata(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx, + uint32_t DictPtr); +}; + +class AVChapterSetMetadata + : public WasmEdgeFFmpegAVFormat { +public: + AVChapterSetMetadata(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t ChapterIdx, + uint32_t DictId); +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp new file mode 100644 index 00000000..5c0628f2 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp @@ -0,0 +1,215 @@ +#include "avInputOutputFormat.h" + +extern "C" { +#include "libavformat/avformat.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +Expect AVIOFormatNameLength::body(const Runtime::CallingFrame &, + uint32_t AVIOFormatId, + uint32_t FormatType) { + + const char *Name; + + if (FormatType == 0) { + FFMPEG_PTR_FETCH(AvInputFormat, AVIOFormatId, AVInputFormat); + Name = AvInputFormat->name; + } else { + FFMPEG_PTR_FETCH(AvOutputFormat, AVIOFormatId, AVOutputFormat); + Name = AvOutputFormat->name; + } + + if (Name == nullptr) + return 0; + return strlen(Name); +} + +Expect AVInputFormatName::body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, + uint32_t NamePtr, uint32_t NameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + FFMPEG_PTR_FETCH(AvInputFormat, AVInputFormatId, AVInputFormat); + + const char *Name = AvInputFormat->name; + std::copy_n(Name, NameLen, NameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVOutputFormatName::body(const Runtime::CallingFrame &Frame, + uint32_t AVOutputFormatId, + uint32_t NamePtr, uint32_t NameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); + + const char *Name = AvOutputFormat->name; + std::copy_n(Name, NameLen, NameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVIOFormatLongNameLength::body(const Runtime::CallingFrame &, + uint32_t AVIOFormatId, + uint32_t FormatType) { + + const char *LongName; + + if (FormatType == 0) { + FFMPEG_PTR_FETCH(AvInputFormat, AVIOFormatId, AVInputFormat); + LongName = AvInputFormat->long_name; + } else { + FFMPEG_PTR_FETCH(AvOutputFormat, AVIOFormatId, AVOutputFormat); + LongName = AvOutputFormat->long_name; + } + + if (LongName == nullptr) + return 0; + return strlen(LongName); +} + +Expect AVInputFormatLongName::body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, + uint32_t LongNamePtr, + uint32_t LongNameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LongNameBuf, MemInst, char, LongNamePtr, LongNameLen, ""); + FFMPEG_PTR_FETCH(AvInputFormat, AVInputFormatId, AVInputFormat); + + const char *LongName = AvInputFormat->long_name; + std::copy_n(LongName, LongNameLen, LongNameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVOutputFormatLongName::body(const Runtime::CallingFrame &Frame, + uint32_t AVOutputFormatId, + uint32_t LongNamePtr, + uint32_t LongNameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LongNameBuf, MemInst, char, LongNamePtr, LongNameLen, ""); + FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); + + const char *LongName = AvOutputFormat->long_name; + std::copy_n(LongName, LongNameLen, LongNameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVIOFormatExtensionsLength::body(const Runtime::CallingFrame &, + uint32_t AVIOFormatId, + uint32_t FormatType) { + + const char *Extensions; + + if (FormatType == 0) { + FFMPEG_PTR_FETCH(AvInputFormat, AVIOFormatId, AVInputFormat); + Extensions = AvInputFormat->extensions; + } else { + FFMPEG_PTR_FETCH(AvOutputFormat, AVIOFormatId, AVOutputFormat); + Extensions = AvOutputFormat->extensions; + } + + if (Extensions == nullptr) + return 0; + return strlen(Extensions); +} + +Expect +AVInputFormatExtensions::body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t ExtensionsPtr, + uint32_t ExtensionsLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ExtensionsBuf, MemInst, char, ExtensionsPtr, ExtensionsLen, + ""); + FFMPEG_PTR_FETCH(AvInputFormat, AVInputFormatId, AVInputFormat); + + const char *Extensions = AvInputFormat->extensions; + std::copy_n(Extensions, ExtensionsLen, ExtensionsBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect +AVOutputFormatExtensions::body(const Runtime::CallingFrame &Frame, + uint32_t AVOutputFormatId, + uint32_t ExtensionsPtr, uint32_t ExtensionsLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ExtensionsBuf, MemInst, char, ExtensionsPtr, ExtensionsLen, + ""); + FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); + + const char *Extensions = AvOutputFormat->extensions; + std::copy_n(Extensions, ExtensionsLen, ExtensionsBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVIOFormatMimeTypeLength::body(const Runtime::CallingFrame &, + uint32_t AVIOFormatId, + uint32_t FormatType) { + const char *MimeType; + + if (FormatType == 0) { + FFMPEG_PTR_FETCH(AvInputFormat, AVIOFormatId, AVInputFormat); + MimeType = AvInputFormat->mime_type; + } else { + FFMPEG_PTR_FETCH(AvOutputFormat, AVIOFormatId, AVOutputFormat); + MimeType = AvOutputFormat->mime_type; + } + + if (MimeType == nullptr) + return 0; + return strlen(MimeType); +} + +Expect AVInputFormatMimeType::body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, + uint32_t MimeTypePtr, + uint32_t MimeTypeLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(MimeTypeBuf, MemInst, char, MimeTypePtr, MimeTypeLen, ""); + FFMPEG_PTR_FETCH(AvInputFormat, AVInputFormatId, AVInputFormat); + + const char *MimeType = AvInputFormat->mime_type; + std::copy_n(MimeType, MimeTypeLen, MimeTypeBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVOutputFormatMimeType::body(const Runtime::CallingFrame &Frame, + uint32_t AVOutputFormatId, + uint32_t MimeTypePtr, + uint32_t MimeTypeLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(MimeTypeBuf, MemInst, char, MimeTypePtr, MimeTypeLen, ""); + FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); + + const char *MimeType = AvOutputFormat->mime_type; + std::copy_n(MimeType, MimeTypeLen, MimeTypeBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVOutputFormatFlags::body(const Runtime::CallingFrame &, + uint32_t AVOutputFormatId) { + + FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); + return AvOutputFormat->flags; +} + +Expect AVInputOutputFormatFree::body(const Runtime::CallingFrame &, + uint32_t AVInputOutputId) { + FFMPEG_PTR_DELETE(AVInputOutputId); + return static_cast(ErrNo::Success); +} + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h new file mode 100644 index 00000000..9c815ff8 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h @@ -0,0 +1,144 @@ +#pragma once +#include "avformat_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +class AVIOFormatNameLength + : public WasmEdgeFFmpegAVFormat { +public: + AVIOFormatNameLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, + uint32_t FormatType); +}; + +class AVInputFormatName : public WasmEdgeFFmpegAVFormat { +public: + AVInputFormatName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t NamePtr, + uint32_t NameLen); +}; + +class AVOutputFormatName : public WasmEdgeFFmpegAVFormat { +public: + AVOutputFormatName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t NamePtr, + uint32_t NameLen); +}; + +class AVIOFormatLongNameLength + : public WasmEdgeFFmpegAVFormat { +public: + AVIOFormatLongNameLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, + uint32_t FormatType); +}; + +class AVInputFormatLongName + : public WasmEdgeFFmpegAVFormat { +public: + AVInputFormatLongName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t LongNamePtr, + uint32_t LongNameLen); +}; + +class AVOutputFormatLongName + : public WasmEdgeFFmpegAVFormat { +public: + AVOutputFormatLongName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t LongNamePtr, + uint32_t LongNameLen); +}; + +class AVIOFormatExtensionsLength + : public WasmEdgeFFmpegAVFormat { +public: + AVIOFormatExtensionsLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, + uint32_t FormatType); +}; + +class AVInputFormatExtensions + : public WasmEdgeFFmpegAVFormat { +public: + AVInputFormatExtensions(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t Extensions, + uint32_t ExtensionsLen); +}; + +class AVOutputFormatExtensions + : public WasmEdgeFFmpegAVFormat { +public: + AVOutputFormatExtensions(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t Extensions, + uint32_t ExtensionsLen); +}; + +class AVIOFormatMimeTypeLength + : public WasmEdgeFFmpegAVFormat { +public: + AVIOFormatMimeTypeLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, + uint32_t FormatType); +}; + +class AVInputFormatMimeType + : public WasmEdgeFFmpegAVFormat { +public: + AVInputFormatMimeType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t MimeTypePtr, + uint32_t MimeTypeLen); +}; + +class AVOutputFormatMimeType + : public WasmEdgeFFmpegAVFormat { +public: + AVOutputFormatMimeType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId, uint32_t MimeTypePtr, + uint32_t MimeTypeLen); +}; + +class AVOutputFormatFlags : public WasmEdgeFFmpegAVFormat { +public: + AVOutputFormatFlags(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputFormatId); +}; + +class AVInputOutputFormatFree + : public WasmEdgeFFmpegAVFormat { +public: + AVInputOutputFormatFree(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputOutputId); +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avStream.cpp b/plugins/wasmedge_ffmpeg/avformat/avStream.cpp new file mode 100644 index 00000000..1eaebef6 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avStream.cpp @@ -0,0 +1,285 @@ +#include "avStream.h" + +extern "C" { +#include "libavformat/avformat.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +Expect AVStreamId::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, uint32_t StreamIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + // No check here (Check) + // Raw Pointer Iteration. + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + return static_cast(*AvStream)->id; +} + +Expect AVStreamIndex::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + return static_cast(*AvStream)->index; +} + +Expect AVStreamCodecPar::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + uint32_t StreamIdx, + uint32_t CodecParameterPtr) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(CodecParamId, MemInst, uint32_t, CodecParameterPtr, + "Failed when accessing the return CodecParameter Memory"sv); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + AVCodecParameters *CodecParam = + (static_cast(*AvStream))->codecpar; + FFMPEG_PTR_STORE(CodecParam, CodecParamId); + return static_cast(ErrNo::Success); +} + +Expect AVStreamTimebase::body(const Runtime::CallingFrame &Frame, + uint32_t NumPtr, uint32_t DenPtr, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, ""); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, ""); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + AVRational const AvRational = static_cast(*AvStream)->time_base; + *Num = AvRational.num; + *Den = AvRational.den; + return static_cast(ErrNo::Success); +} + +Expect AVStreamSetTimebase::body(const Runtime::CallingFrame &, + uint32_t Num, uint32_t Den, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + + AVStream **AvStream = AvFormatContext->streams; + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + AVRational const Timebase = av_make_q(Num, Den); + (*AvStream)->time_base = Timebase; + return static_cast(ErrNo::Success); +} + +Expect AVStreamDuration::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + return static_cast(*AvStream)->duration; +} + +Expect AVStreamStartTime::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + return static_cast(*AvStream)->start_time; +} + +Expect AVStreamNbFrames::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + return static_cast(*AvStream)->nb_frames; +} + +Expect AVStreamDisposition::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + return static_cast(*AvStream)->disposition; +} + +Expect AVStreamRFrameRate::body(const Runtime::CallingFrame &Frame, + uint32_t NumPtr, uint32_t DenPtr, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, ""); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, ""); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + AVRational const AvRational = + static_cast(*AvStream)->r_frame_rate; + *Num = AvRational.num; + *Den = AvRational.den; + return static_cast(ErrNo::Success); +} + +Expect AVStreamSetRFrameRate::body(const Runtime::CallingFrame &, + int32_t Num, int32_t Den, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + + AVStream **AvStream = AvFormatContext->streams; + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + AVRational const RFrameRate = av_make_q(Num, Den); + (*AvStream)->r_frame_rate = RFrameRate; + return static_cast(ErrNo::Success); +} + +Expect AVStreamAvgFrameRate::body(const Runtime::CallingFrame &Frame, + uint32_t NumPtr, uint32_t DenPtr, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, ""); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, ""); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + AVRational const AvRational = + static_cast(*AvStream)->avg_frame_rate; + *Num = AvRational.num; + *Den = AvRational.den; + return static_cast(ErrNo::Success); +} + +Expect AVStreamSetAvgFrameRate::body(const Runtime::CallingFrame &, + int32_t Num, int32_t Den, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + AVRational const AvgFrameRate = av_make_q(Num, Den); + (*AvStream)->avg_frame_rate = AvgFrameRate; + return static_cast(ErrNo::Success); +} + +Expect AVStreamMetadata::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + uint32_t StreamIdx, uint32_t DictPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, + "Failed when accessing the return AVDictPtr Memory"sv); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + AVDictionary **AvDictionary = + static_cast(av_malloc(sizeof(AVDictionary *))); + + *AvDictionary = (*AvStream)->metadata; + FFMPEG_PTR_STORE(AvDictionary, DictId); + return static_cast(ErrNo::Success); +} + +Expect AVStreamSetMetadata::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx, uint32_t DictId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvDictionary, DictId, AVDictionary *); + + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + if (AvDictionary == nullptr) + (*AvStream)->metadata = nullptr; + else + (*AvStream)->metadata = *AvDictionary; + + return static_cast(ErrNo::Success); +} + +Expect AVStreamDiscard::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AVStream **AvStream = AvFormatContext->streams; + + for (unsigned int I = 1; I <= StreamIdx; I++) + AvStream++; + + return static_cast((*AvStream)->discard); +} + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avStream.h b/plugins/wasmedge_ffmpeg/avformat/avStream.h new file mode 100644 index 00000000..4a8956ce --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avStream.h @@ -0,0 +1,152 @@ +#pragma once + +#include "avformat_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +class AVStreamId : public WasmEdgeFFmpegAVFormat { +public: + AVStreamId(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamIndex : public WasmEdgeFFmpegAVFormat { +public: + AVStreamIndex(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamCodecPar : public WasmEdgeFFmpegAVFormat { +public: + AVStreamCodecPar(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx, + uint32_t CodecParameterPtr); +}; + +class AVStreamTimebase : public WasmEdgeFFmpegAVFormat { +public: + AVStreamTimebase(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, + uint32_t DenPtr, uint32_t AvFormatCtxId, + uint32_t StreamIdx); +}; + +class AVStreamSetTimebase : public WasmEdgeFFmpegAVFormat { +public: + AVStreamSetTimebase(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Num, + uint32_t Den, uint32_t AvFormatCtxId, + uint32_t StreamIdx); +}; + +class AVStreamDuration : public WasmEdgeFFmpegAVFormat { +public: + AVStreamDuration(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamStartTime : public WasmEdgeFFmpegAVFormat { +public: + AVStreamStartTime(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamNbFrames : public WasmEdgeFFmpegAVFormat { +public: + AVStreamNbFrames(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamDisposition : public WasmEdgeFFmpegAVFormat { +public: + AVStreamDisposition(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamRFrameRate : public WasmEdgeFFmpegAVFormat { +public: + AVStreamRFrameRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, + uint32_t DenPtr, uint32_t AvFormatCtxId, + uint32_t StreamIdx); +}; + +class AVStreamSetRFrameRate + : public WasmEdgeFFmpegAVFormat { +public: + AVStreamSetRFrameRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t Num, + int32_t Den, uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamAvgFrameRate + : public WasmEdgeFFmpegAVFormat { +public: + AVStreamAvgFrameRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, + uint32_t DenPtr, uint32_t AvFormatCtxId, + uint32_t StreamIdx); +}; + +class AVStreamSetAvgFrameRate + : public WasmEdgeFFmpegAVFormat { +public: + AVStreamSetAvgFrameRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t Num, + int32_t Den, uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +class AVStreamMetadata : public WasmEdgeFFmpegAVFormat { +public: + AVStreamMetadata(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx, + uint32_t DictPtr); +}; + +class AVStreamSetMetadata : public WasmEdgeFFmpegAVFormat { +public: + AVStreamSetMetadata(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx, + uint32_t DictId); +}; + +class AVStreamDiscard : public WasmEdgeFFmpegAVFormat { +public: + AVStreamDiscard(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t StreamIdx); +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp b/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp new file mode 100644 index 00000000..5c9515f8 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp @@ -0,0 +1,122 @@ +#include "avformatContext.h" + +extern "C" { +#include "libavformat/avformat.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +Expect AVFormatCtxIFormat::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + uint32_t AvInputFormatPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AvInputFormatId, MemInst, uint32_t, AvInputFormatPtr, + "Failed when accessing the return AVInputFormat Memory"sv); + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + + AVInputFormat const *AvInputFormat = AvFormatCtx->iformat; + FFMPEG_PTR_STORE(const_cast(AvInputFormat), AvInputFormatId); + return static_cast(ErrNo::Success); +} + +Expect AVFormatCtxOFormat::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + uint32_t AvOutputFormatPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AvOutputFormatId, MemInst, uint32_t, AvOutputFormatPtr, + "Failed when accessing the return AVOutputFormat Memory"sv); + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + + AVOutputFormat const *AvOutputFormat = AvFormatCtx->oformat; + FFMPEG_PTR_STORE(const_cast(AvOutputFormat), + AvOutputFormatId); + return static_cast(ErrNo::Success); +} + +Expect AVFormatCtxProbeScore::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return AvFormatContext->probe_score; +} + +Expect AVFormatCtxNbStreams::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return AvFormatContext->nb_streams; +}; + +Expect AVFormatCtxBitRate::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return AvFormatContext->bit_rate; +} + +Expect AVFormatCtxDuration::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return AvFormatContext->duration; +} + +Expect AVFormatCtxNbChapters::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return AvFormatContext->nb_chapters; +} + +Expect AVFormatCtxSetNbChapters::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t NbChapters) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + AvFormatContext->nb_chapters = NbChapters; + return static_cast(ErrNo::Success); +} + +Expect AVFormatCtxMetadata::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + uint32_t DictPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, + "Failed when accessing the return AVDictionary memory"sv); + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + + AVDictionary **AvDictionary = + static_cast(av_malloc(sizeof(AVDictionary *))); + + *AvDictionary = AvFormatCtx->metadata; + FFMPEG_PTR_STORE(AvDictionary, DictId); + return static_cast(ErrNo::Success); +} + +Expect AVFormatCtxSetMetadata::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t DictId) { + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvDictionary, DictId, AVDictionary *); + + if (AvDictionary == nullptr) + AvFormatCtx->metadata = nullptr; + else + AvFormatCtx->metadata = *AvDictionary; + return static_cast(ErrNo::Success); +} + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avformatContext.h b/plugins/wasmedge_ffmpeg/avformat/avformatContext.h new file mode 100644 index 00000000..90cd679e --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avformatContext.h @@ -0,0 +1,99 @@ +#pragma once + +#include "avformat_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +class AVFormatCtxIFormat : public WasmEdgeFFmpegAVFormat { +public: + AVFormatCtxIFormat(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t AvInputFormatPtr); +}; + +class AVFormatCtxOFormat : public WasmEdgeFFmpegAVFormat { +public: + AVFormatCtxOFormat(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t AvOutputFormatPtr); +}; + +class AVFormatCtxProbeScore + : public WasmEdgeFFmpegAVFormat { +public: + AVFormatCtxProbeScore(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatCtxNbStreams + : public WasmEdgeFFmpegAVFormat { +public: + AVFormatCtxNbStreams(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatCtxBitRate : public WasmEdgeFFmpegAVFormat { +public: + AVFormatCtxBitRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatCtxDuration : public WasmEdgeFFmpegAVFormat { +public: + AVFormatCtxDuration(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatCtxNbChapters + : public WasmEdgeFFmpegAVFormat { +public: + AVFormatCtxNbChapters(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatCtxSetNbChapters + : public WasmEdgeFFmpegAVFormat { +public: + AVFormatCtxSetNbChapters(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t NbChapters); +}; + +class AVFormatCtxMetadata : public WasmEdgeFFmpegAVFormat { +public: + AVFormatCtxMetadata(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t DictPtr); +}; + +class AVFormatCtxSetMetadata + : public WasmEdgeFFmpegAVFormat { +public: + AVFormatCtxSetMetadata(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t DictId); +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avformat_base.h b/plugins/wasmedge_ffmpeg/avformat/avformat_base.h new file mode 100644 index 00000000..28318eb2 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avformat_base.h @@ -0,0 +1,25 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +template +class WasmEdgeFFmpegAVFormat : public Runtime::HostFunction { +public: + WasmEdgeFFmpegAVFormat( + std::shared_ptr HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + std::shared_ptr Env; +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp b/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp new file mode 100644 index 00000000..545894db --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp @@ -0,0 +1,383 @@ +#include "avformat_func.h" + +extern "C" { +#include "libavcodec/packet.h" +#include "libavformat/avformat.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +Expect AVFormatOpenInput::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxPtr, + uint32_t UrlPtr, uint32_t UrlSize, + uint32_t AvInputFormatId, + uint32_t AvDictionaryId) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(urlId, MemInst, char, UrlPtr, + "Failed when accessing the return URL memory"sv); + MEM_PTR_CHECK(AvFormatCtxId, MemInst, uint32_t, AvFormatCtxPtr, + "Failed when accessing the return AVFormatContext Memory"sv); + + std::string TargetUrl; + std::copy_n(urlId, UrlSize, std::back_inserter(TargetUrl)); + + AVFormatContext *AvFormatContext = nullptr; + FFMPEG_PTR_FETCH(AvDictionary, AvDictionaryId, AVDictionary *); + FFMPEG_PTR_FETCH(AvInputFormat, AvInputFormatId, AVInputFormat); + + int const Res = avformat_open_input(&AvFormatContext, TargetUrl.c_str(), + AvInputFormat, AvDictionary); + FFMPEG_PTR_STORE(AvFormatContext, AvFormatCtxId); + return Res; +} + +Expect AVFormatFindStreamInfo::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t AvDictionaryId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvDictionary, AvDictionaryId, AVDictionary *); + return avformat_find_stream_info(AvFormatContext, AvDictionary); +} + +Expect AVFormatCloseInput::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + avformat_close_input(&AvFormatCtx); + FFMPEG_PTR_DELETE(AvFormatCtxId); + return static_cast(ErrNo::Success); +} + +Expect AVReadPause::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return av_read_pause(AvFormatContext); +} + +Expect AVReadPlay::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return av_read_play(AvFormatContext); +} + +Expect AVFormatSeekFile::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t StreamIdx, int64_t MinTs, + int64_t Ts, int64_t MaxTs, + int32_t Flags) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return avformat_seek_file(AvFormatContext, StreamIdx, MinTs, Ts, MaxTs, + Flags); +} + +Expect AVDumpFormat::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, int32_t Idx, + uint32_t UrlPtr, uint32_t UrlSize, + int32_t IsOutput) { + std::string TargetUrl; + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(UrlBuf, MemInst, char, UrlPtr, ""); + + std::copy_n(UrlBuf, UrlSize, std::back_inserter(TargetUrl)); + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + + av_dump_format(AvFormatCtx, Idx, TargetUrl.c_str(), IsOutput); + return static_cast(ErrNo::Success); +} + +Expect AVFormatFreeContext::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + avformat_free_context(AvFormatCtx); + FFMPEG_PTR_DELETE(AvFormatCtxId); + return static_cast(ErrNo::Success); +} + +Expect AVFindBestStream::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + int32_t MediaTypeId, + int32_t WantedStream, + int32_t RelatedStream, + uint32_t DecoderRetId, int32_t Flags) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(DecoderRet, DecoderRetId, const AVCodec *); + + AVMediaType const AvMediaType = + FFmpegUtils::MediaType::intoMediaType(MediaTypeId); + return av_find_best_stream(AvFormatContext, AvMediaType, WantedStream, + RelatedStream, DecoderRet, Flags); +} + +Expect AVReadFrame::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, uint32_t PacketId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvPacket, PacketId, AVPacket); + + return av_read_frame(AvFormatContext, AvPacket); +} + +Expect AVIOClose::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + avio_close(AvFormatCtx->pb); + return static_cast(ErrNo::Success); +} + +Expect AVFormatNetworkInit::body(const Runtime::CallingFrame &) { + return avformat_network_init(); +} + +Expect AVFormatNetworkDeInit::body(const Runtime::CallingFrame &) { + return avformat_network_deinit(); +} + +Expect AVFormatWriteHeader::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t DictId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); + return avformat_write_header(AvFormatContext, AvDict); +} + +Expect AVFormatWriteTrailer::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId) { + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + return av_write_trailer(AvFormatContext); +} + +Expect AVFormatAllocOutputContext2::body( + const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxPtr, + uint32_t AVOutputFormatId, uint32_t FormatNamePtr, uint32_t FormatLen, + uint32_t FileNamePtr, uint32_t FileNameLen) { + + std::string Format; + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(FileId, MemInst, char, FileNamePtr, + "Failed when accessing the return FileName memory"sv); + if (FormatLen > 0) { + MEM_PTR_CHECK(FormatId, MemInst, char, FormatNamePtr, + "Failed when accessing the return FormatName memory"sv); + + std::copy_n(FormatId, FormatLen, std::back_inserter(Format)); + } + MEM_PTR_CHECK(AvFormatCtxId, MemInst, uint32_t, AvFormatCtxPtr, + "Failed when accessing the return AVFormatContext Memory"sv); + + std::string File; + std::copy_n(FileId, FileNameLen, std::back_inserter(File)); + + AVFormatContext *AvFormatContext = nullptr; + FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); + + int Res = 0; + if (FormatLen == 0) { + Res = avformat_alloc_output_context2(&AvFormatContext, AvOutputFormat, + nullptr, File.c_str()); + } else { + Res = avformat_alloc_output_context2(&AvFormatContext, AvOutputFormat, + Format.c_str(), File.c_str()); + } + FFMPEG_PTR_STORE(AvFormatContext, AvFormatCtxId); + return Res; +} + +Expect AVIOOpen::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t FileNamePtr, + uint32_t FileNameLen, int32_t Flags) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(FileId, MemInst, char, FileNamePtr, + "Failed when accessing the return FileName memory"sv); + + std::string File; + std::copy_n(FileId, FileNameLen, std::back_inserter(File)); + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + + return avio_open(&(AvFormatContext->pb), File.c_str(), Flags); +} + +Expect AVIOOpen2::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxtId, uint32_t UrlPtr, + uint32_t UrlLen, int32_t Flags, + uint32_t AVIOInterruptCBId, + uint32_t AVDictionaryId) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(UrlId, MemInst, char, UrlPtr, + "Failed when accessing the return Url memory"sv); + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxtId, AVFormatContext); + FFMPEG_PTR_FETCH(AvDictionary, AVDictionaryId, AVDictionary *); + FFMPEG_PTR_FETCH(AvIOInterruptCB, AVIOInterruptCBId, AVIOInterruptCB); + + std::string TargetUrl; + std::copy_n(UrlId, UrlLen, std::back_inserter(TargetUrl)); + + return avio_open2(&(AvFormatCtx->pb), TargetUrl.c_str(), Flags, + AvIOInterruptCB, AvDictionary); +} + +Expect AVFormatVersion::body(const Runtime::CallingFrame &) { + return avformat_version(); +} + +Expect AVChapterMallocz::body(const Runtime::CallingFrame &Frame, + uint32_t AVChapterPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(AvChapterId, MemInst, uint32_t, AVChapterPtr, + "Failed to access Memory for AVChapterPtr"sv) + + AVChapter *AvChapter = + static_cast(av_mallocz(sizeof(AVChapter))); + FFMPEG_PTR_STORE(AvChapter, AvChapterId); + return static_cast(ErrNo::Success); +} + +Expect AVChapterDynarrayAdd::body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, + int32_t NbChaptersPtr, + uint32_t AvChapterId) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(NbChapters, MemInst, int32_t, NbChaptersPtr, + "Failed to access Memory for NbChaptersPtr"sv) + + FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvChapter, AvChapterId, AVChapter); + + av_dynarray_add(&(AvFormatContext->chapters), NbChapters, AvChapter); + if (*(AvFormatContext->chapters) == nullptr && *(NbChapters) == 0) + return static_cast(ErrNo::InternalError); + return static_cast(ErrNo::Success); +} + +Expect AVFreeP::body(const Runtime::CallingFrame &, + uint32_t AvChapterId) { + + FFMPEG_PTR_FETCH(AvChapter, AvChapterId, AVChapter); + av_freep(AvChapter); + FFMPEG_PTR_DELETE(AvChapterId); + return static_cast(ErrNo::Success); +} + +Expect AVInterleavedWriteFrame::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t AvPacketId) { + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + + return av_interleaved_write_frame(AvFormatCtx, AvPacket); +} + +Expect AVWriteFrame::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t AvPacketId) { + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); + + return av_write_frame(AvFormatCtx, AvPacket); +} + +Expect AVFormatNewStream::body(const Runtime::CallingFrame &, + uint32_t AvFormatCtxId, + uint32_t AvCodecId) { + + FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); + FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); + AVStream *Stream = avformat_new_stream(AvFormatCtx, AvCodec); + if (Stream == nullptr) + return 0; + return 1; +} + +Expect AVGuessCodec::body(const Runtime::CallingFrame &Frame, + uint32_t AVIOFormatId, + uint32_t ShortNamePtr, + uint32_t ShortNameLen, uint32_t FileNamePtr, + uint32_t FileNameLen, uint32_t MimeTypePtr, + uint32_t MimeTypeLen, int32_t MediaTypeId) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(ShortNameBuf, MemInst, char, ShortNamePtr, + "Failed when accessing the return ShortName memory"sv); + MEM_PTR_CHECK(FileNameBuf, MemInst, char, FileNamePtr, + "Failed when accessing the return FileName memory"sv); + MEM_PTR_CHECK(MimeTypeBuf, MemInst, char, MimeTypePtr, + "Failed when accessing the return MimeType memory"sv); + FFMPEG_PTR_FETCH(AvOutputFormat, AVIOFormatId, AVOutputFormat); + + std::string ShortName; + std::string FileName; + std::string MimeType; + std::copy_n(ShortNameBuf, ShortNameLen, std::back_inserter(ShortName)); + std::copy_n(FileNameBuf, FileNameLen, std::back_inserter(FileName)); + std::copy_n(MimeTypeBuf, MimeTypeLen, std::back_inserter(MimeType)); + + AVMediaType const MediaType = + FFmpegUtils::MediaType::intoMediaType(MediaTypeId); + AVCodecID const Id = + av_guess_codec(AvOutputFormat, ShortName.c_str(), FileName.c_str(), + MimeType.c_str(), MediaType); + + return FFmpegUtils::CodecID::fromAVCodecID(Id); +} + +Expect +AVFormatConfigurationLength::body(const Runtime::CallingFrame &) { + + const char *Config = avformat_configuration(); + return strlen(Config); +} + +Expect AVFormatConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, + uint32_t ConfigLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = avformat_configuration(); + std::copy_n(Config, ConfigLen, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFormatLicenseLength::body(const Runtime::CallingFrame &) { + + const char *License = avformat_license(); + return strlen(License); +} + +Expect AVFormatLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, + uint32_t LicenseLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = avformat_license(); + std::copy_n(License, LicenseLen, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avformat_func.h b/plugins/wasmedge_ffmpeg/avformat/avformat_func.h new file mode 100644 index 00000000..3cc44a33 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/avformat_func.h @@ -0,0 +1,273 @@ +#pragma once + +#include "avformat_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +class AVFormatOpenInput : public WasmEdgeFFmpegAVFormat { +public: + AVFormatOpenInput(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxPtr, uint32_t UrlPtr, + uint32_t UrlSize, uint32_t AvInputFormatId, + uint32_t AvDictionaryId); +}; + +class AVFormatFindStreamInfo + : public WasmEdgeFFmpegAVFormat { +public: + AVFormatFindStreamInfo(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + uint32_t AvDictionaryId); +}; + +class AVFormatCloseInput : public WasmEdgeFFmpegAVFormat { +public: + AVFormatCloseInput(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVReadPause : public WasmEdgeFFmpegAVFormat { +public: + AVReadPause(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId); +}; + +class AVReadPlay : public WasmEdgeFFmpegAVFormat { +public: + AVReadPlay(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId); +}; + +class AVFormatSeekFile : public WasmEdgeFFmpegAVFormat { +public: + AVFormatSeekFile(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + uint32_t StreamIdx, int64_t MinTs, int64_t Ts, + int64_t MaxTs, int32_t Flags); +}; + +class AVDumpFormat : public WasmEdgeFFmpegAVFormat { +public: + AVDumpFormat(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, int32_t Idx, uint32_t UrlPtr, + uint32_t UrlSize, int32_t IsOutput); +}; + +class AVFormatFreeContext : public WasmEdgeFFmpegAVFormat { +public: + AVFormatFreeContext(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxPtr); +}; + +class AVFindBestStream : public WasmEdgeFFmpegAVFormat { +public: + AVFindBestStream(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + int32_t MediaTypeId, int32_t WantedStream, + int32_t RelatedStream, uint32_t DecoderRetId, + int32_t Flags); +}; + +class AVReadFrame : public WasmEdgeFFmpegAVFormat { +public: + AVReadFrame(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + uint32_t PacketId); +}; + +class AVIOClose : public WasmEdgeFFmpegAVFormat { +public: + AVIOClose(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatNetworkInit : public WasmEdgeFFmpegAVFormat { +public: + AVFormatNetworkInit(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFormatNetworkDeInit + : public WasmEdgeFFmpegAVFormat { +public: + AVFormatNetworkDeInit(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFormatWriteHeader : public WasmEdgeFFmpegAVFormat { +public: + AVFormatWriteHeader(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t DictId); +}; + +class AVFormatWriteTrailer + : public WasmEdgeFFmpegAVFormat { +public: + AVFormatWriteTrailer(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId); +}; + +class AVFormatAllocOutputContext2 + : public WasmEdgeFFmpegAVFormat { +public: + AVFormatAllocOutputContext2(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxPtr, uint32_t AVOutputFormatId, + uint32_t FormatNamePtr, uint32_t FormatLen, + uint32_t FileNamePtr, uint32_t FileNameLen); +}; + +class AVIOOpen : public WasmEdgeFFmpegAVFormat { +public: + AVIOOpen(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxId, uint32_t FileNamePtr, + uint32_t FileNameLen, int32_t Flags); +}; + +class AVIOOpen2 : public WasmEdgeFFmpegAVFormat { +public: + AVIOOpen2(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvFormatCtxtId, uint32_t UrlPtr, + uint32_t UrlLen, int32_t Flags, + uint32_t AVIOInterruptCBId, uint32_t AVDictionaryId); +}; + +class AVFormatVersion : public WasmEdgeFFmpegAVFormat { +public: + AVFormatVersion(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVChapterMallocz : public WasmEdgeFFmpegAVFormat { +public: + AVChapterMallocz(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVChapterPtr); +}; + +class AVChapterDynarrayAdd + : public WasmEdgeFFmpegAVFormat { +public: + AVChapterDynarrayAdd(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + int32_t NbChaptersPtr, uint32_t AvChapterId); +}; + +class AVFreeP : public WasmEdgeFFmpegAVFormat { +public: + AVFreeP(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + + Expect body(const Runtime::CallingFrame &, uint32_t AvChapterId); +}; + +class AVInterleavedWriteFrame + : public WasmEdgeFFmpegAVFormat { +public: + AVInterleavedWriteFrame(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + uint32_t AvPacketId); +}; + +class AVWriteFrame : public WasmEdgeFFmpegAVFormat { +public: + AVWriteFrame(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + + Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, + uint32_t AvPacketId); +}; + +class AVFormatNewStream : public WasmEdgeFFmpegAVFormat { +public: + AVFormatNewStream(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVFormatCtxId, uint32_t AVCodecId); +}; + +class AVGuessCodec : public WasmEdgeFFmpegAVFormat { +public: + AVGuessCodec(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AVInputOutputId, uint32_t ShortNamePtr, + uint32_t ShortNameLen, uint32_t FileNamePtr, + uint32_t FileNameLen, uint32_t MimeTypePtr, + uint32_t MimeTypeLen, int32_t MediaTypeId); +}; + +class AVFormatConfigurationLength + : public WasmEdgeFFmpegAVFormat { +public: + AVFormatConfigurationLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFormatConfiguration + : public WasmEdgeFFmpegAVFormat { +public: + AVFormatConfiguration(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class AVFormatLicenseLength + : public WasmEdgeFFmpegAVFormat { +public: + AVFormatLicenseLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVFormatLicense : public WasmEdgeFFmpegAVFormat { +public: + AVFormatLicense(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVFormat(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/module.cpp b/plugins/wasmedge_ffmpeg/avformat/module.cpp new file mode 100644 index 00000000..0d1ec108 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/module.cpp @@ -0,0 +1,193 @@ +#include "module.h" +#include "avChapter.h" +#include "avInputOutputFormat.h" +#include "avStream.h" +#include "avformatContext.h" +#include "avformat_func.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +WasmEdgeFFmpegAVFormatModule::WasmEdgeFFmpegAVFormatModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_avformat") { + + // avformat_func.h + addHostFunc("wasmedge_ffmpeg_avformat_avformat_open_input", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_find_stream_info", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_close_input", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_read_play", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_read_pause", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_dump_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_seek_file", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_free_context", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_find_best_stream", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_read_frame", // TODO: Write Test + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avio_close", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_network_init", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_network_deinit", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_write_header", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_write_trailer", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_alloc_output_context2", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avio_open", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avio_open2", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avchapter_mallocz", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avformat_avchapter_dynarray_add", // TODO: Write Test + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_avfreep", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_write_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_interleaved_write_frame", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avformat_avformat_new_stream", // TODO: Write Test + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_av_guess_codec", // TODO: Write Test + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformat_license", + std::make_unique(Env)); + + // avformatContext Struct functions. + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_iformat", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_oformat", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_probescope", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_nb_streams", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_duration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_bit_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_nb_chapters", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_set_nb_chapters", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_metadata", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avformatContext_set_metadata", + std::make_unique(Env)); + + // avInputFormat Struct functions. + addHostFunc("wasmedge_ffmpeg_avformat_avIOFormat_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avInputFormat_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avOutputFormat_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avIOFormat_long_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avInputFormat_long_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avOutputFormat_long_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avIOFormat_extensions_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avInputFormat_extensions", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avOutputFormat_extensions", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avIOFormat_mime_type_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avInputFormat_mime_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avOutputFormat_mime_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avOutputFormat_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avInputOutputFormat_free", + std::make_unique(Env)); + + // avStream Struct Functions. + addHostFunc("wasmedge_ffmpeg_avformat_avStream_id", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_index", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_codecpar", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_timebase", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_set_timebase", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_duration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_start_time", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_nb_frames", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_disposition", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_r_frame_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_set_r_frame_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_avg_frame_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_set_avg_frame_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_metadata", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_set_metadata", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avStream_discard", + std::make_unique(Env)); + + // avChapter Struct Functions. + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_id", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_set_id", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_timebase", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_set_timebase", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_start", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_set_start", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_end", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_set_end", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_metadata", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avformat_avChapter_set_metadata", + std::make_unique(Env)); +} + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/module.h b/plugins/wasmedge_ffmpeg/avformat/module.h new file mode 100644 index 00000000..4ab491ed --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avformat/module.h @@ -0,0 +1,19 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVFormat { + +class WasmEdgeFFmpegAVFormatModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegAVFormatModule(std::shared_ptr Env); +}; + +} // namespace AVFormat +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp b/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp new file mode 100644 index 00000000..047f309c --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp @@ -0,0 +1,159 @@ +#include "avDictionary.h" + +extern "C" { +#include "libavutil/dict.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVDictSet::body(const Runtime::CallingFrame &Frame, + uint32_t DictPtr, uint32_t KeyPtr, + uint32_t KeyLen, uint32_t ValuePtr, + uint32_t ValueLen, int32_t Flags) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(KeyBuf, MemInst, char, KeyPtr, + "Failed when accessing the return Key memory"sv); + MEM_PTR_CHECK(ValueBuf, MemInst, char, ValuePtr, + "Failed when accessing the return Value memory"sv); + MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, + "Failed to access Memory for AVDict"sv) + + std::string Key; + std::string Value; + std::copy_n(KeyBuf, KeyLen, std::back_inserter(Key)); + std::copy_n(ValueBuf, ValueLen, std::back_inserter(Value)); + + int Res = 0; + + // Using Maybe::uninit(); in Rust. If Uninitialized, zero is + // passed. Else the Ptr contains a Number. + if (*DictId) { + FFMPEG_PTR_FETCH(AvDict, *DictId, AVDictionary *); + Res = av_dict_set(AvDict, Key.c_str(), Value.c_str(), Flags); + } else { + AVDictionary **AvDict = + static_cast(av_mallocz(sizeof(AVDictionary *))); + Res = av_dict_set(AvDict, Key.c_str(), Value.c_str(), Flags); + FFMPEG_PTR_STORE(AvDict, DictId); + } + + return Res; +} + +Expect AVDictCopy::body(const Runtime::CallingFrame &Frame, + uint32_t DestDictPtr, uint32_t SrcDictId, + uint32_t Flags) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(DestDictId, MemInst, uint32_t, DestDictPtr, + "Failed to access Memory for AVDict"sv) + + FFMPEG_PTR_FETCH(SrcAvDict, SrcDictId, AVDictionary *); + + int Res = 0; + + if (SrcAvDict == nullptr) + return static_cast(ErrNo::InternalError); + + if (*DestDictId) { + FFMPEG_PTR_FETCH(DestAvDict, *DestDictId, AVDictionary *); + Res = av_dict_copy(DestAvDict, *SrcAvDict, Flags); + } else { + AVDictionary **DestAvDict = + static_cast(av_mallocz(sizeof(AVDictionary *))); + av_dict_copy(DestAvDict, *SrcAvDict, Flags); + FFMPEG_PTR_STORE(DestAvDict, DestDictId); + } + + return Res; +} + +Expect AVDictGet::body(const Runtime::CallingFrame &Frame, + uint32_t DictId, uint32_t KeyPtr, + uint32_t KeyLen, uint32_t PrevDictEntryIdx, + uint32_t Flags, uint32_t KeyLenPtr, + uint32_t ValueLenPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(KeyStr, MemInst, char, KeyPtr, + "Failed when accessing the return Key memory"sv); + MEM_PTR_CHECK(KeyLenId, MemInst, uint32_t, KeyLenPtr, + "Failed when accessing the return KeyLen memory"sv); + MEM_PTR_CHECK(ValueLenId, MemInst, uint32_t, ValueLenPtr, + "Failed when accessing the return ValueLen memory"sv); + + FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); + + // If Dict Not created return (i.e. 0 is passed as AVDictId) + if (AvDict == nullptr) + return static_cast(ErrNo::InternalError); + std::string Key; + std::copy_n(KeyStr, KeyLen, std::back_inserter(Key)); + + AVDictionaryEntry *DictEntry = nullptr; + uint32_t Curr = 0; + while (Curr <= PrevDictEntryIdx) { + DictEntry = av_dict_get(*AvDict, Key.c_str(), DictEntry, Flags); + Curr++; + } + + if (DictEntry == nullptr) + return static_cast(ErrNo::InternalError); + + *KeyLenId = strlen(DictEntry->key); + *ValueLenId = strlen(DictEntry->value); + return Curr; +} + +Expect AVDictGetKeyValue::body( + const Runtime::CallingFrame &Frame, uint32_t DictId, uint32_t KeyPtr, + uint32_t KeyLen, uint32_t ValBufPtr, uint32_t ValBufLen, uint32_t KeyBufPtr, + uint32_t KeyBufLen, uint32_t PrevDictEntryIdx, uint32_t Flags) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(KeyStr, MemInst, char, KeyPtr, + "Failed when accessing the return Key memory"sv); + MEM_SPAN_CHECK(KeyBuf, MemInst, char, KeyBufPtr, KeyBufLen, ""); + MEM_SPAN_CHECK(ValBuf, MemInst, char, ValBufPtr, ValBufLen, ""); + + FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); + + // If Dict Not created return (i.e. 0 is passed as AVDictId) + if (AvDict == nullptr) + return static_cast(ErrNo::InternalError); + + std::string Key; + std::copy_n(KeyStr, KeyLen, std::back_inserter(Key)); + + AVDictionaryEntry *DictEntry = nullptr; + uint32_t Curr = 0; + while (Curr <= PrevDictEntryIdx) { + DictEntry = av_dict_get(*AvDict, Key.c_str(), DictEntry, Flags); + Curr++; + } + if (DictEntry == nullptr) + return static_cast(ErrNo::InternalError); + std::copy_n(DictEntry->value, strlen(DictEntry->value), ValBuf.data()); + std::copy_n(DictEntry->key, strlen(DictEntry->key), KeyBuf.data()); + return Curr; +} + +Expect AVDictFree::body(const Runtime::CallingFrame &, + uint32_t DictId) { + + if (DictId == 0) + return static_cast(ErrNo::Success); + FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); + av_dict_free(AvDict); + FFMPEG_PTR_DELETE(DictId); + return static_cast(ErrNo::Success); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avDictionary.h b/plugins/wasmedge_ffmpeg/avutil/avDictionary.h new file mode 100644 index 00000000..8d5a5ff1 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avDictionary.h @@ -0,0 +1,59 @@ +#pragma once + +#include "avutil_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVDictSet : public WasmEdgeFFmpegAVUtil { +public: + AVDictSet(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DictId, + uint32_t KeyPtr, uint32_t KeyLen, uint32_t ValuePtr, + uint32_t ValueLen, int32_t Flags); +}; + +class AVDictGet : public WasmEdgeFFmpegAVUtil { +public: + AVDictGet(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DictId, + uint32_t KeyPtr, uint32_t KeyLen, + uint32_t PrevDictEntryIdx, uint32_t Flags, + uint32_t KeyLenPtr, uint32_t ValueLenPtr); +}; + +class AVDictGetKeyValue : public WasmEdgeFFmpegAVUtil { +public: + AVDictGetKeyValue(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DictId, + uint32_t KeyPtr, uint32_t KeyLen, uint32_t ValBufPtr, + uint32_t ValBufLen, uint32_t KeyBufPtr, + uint32_t KeyBufLen, uint32_t PrevDictEntryIdx, + uint32_t Flags); +}; + +class AVDictCopy : public WasmEdgeFFmpegAVUtil { +public: + AVDictCopy(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestDictId, + uint32_t SrcDictId, uint32_t Flags); +}; + +class AVDictFree : public WasmEdgeFFmpegAVUtil { +public: + AVDictFree(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DictId); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp b/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp new file mode 100644 index 00000000..414031d9 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp @@ -0,0 +1,464 @@ +#include "avFrame.h" + +extern "C" { +#include "libavutil/frame.h" +#include "libavutil/pixfmt.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVFrameAlloc::body(const Runtime::CallingFrame &Frame, + uint32_t FramePtr) { + MEMINST_CHECK(MemInst, Frame, 0) + MEM_PTR_CHECK(FrameId, MemInst, uint32_t, FramePtr, + "Failed to access Memory for AVFrame"sv) + + AVFrame *AvFrame = av_frame_alloc(); + FFMPEG_PTR_STORE(AvFrame, FrameId); + return static_cast(ErrNo::Success); +} + +Expect AVFrameFree::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + av_frame_free(&AvFrame); + FFMPEG_PTR_DELETE(FrameId); + return static_cast(ErrNo::Success); +} + +Expect AVFrameWidth::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->width; +} + +Expect AVFrameHeight::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->height; +} + +Expect AVFrameSetHeight::body(const Runtime::CallingFrame &, + uint32_t FrameId, uint32_t Height) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->height = Height; + return static_cast(ErrNo::Success); +} + +Expect AVFrameSetWidth::body(const Runtime::CallingFrame &, + uint32_t FrameId, uint32_t Width) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->width = Width; + return static_cast(ErrNo::Success); +} + +Expect AVFrameVideoFormat::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + int const Format = AvFrame->format; + if (Format == -1) + return -1; + AVPixelFormat const PixelFormat = static_cast(Format); + return FFmpegUtils::PixFmt::fromAVPixFmt(PixelFormat); +} + +Expect AVFrameSetVideoFormat::body(const Runtime::CallingFrame &, + uint32_t FrameId, + uint32_t AvPixFormatId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(AvPixFormatId); + AvFrame->format = PixelFormat; + return static_cast(ErrNo::Success); +} + +Expect AVFrameIsNull::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->data[0] == nullptr; +} + +Expect AVFrameLinesize::body(const Runtime::CallingFrame &, + uint32_t FrameId, uint32_t Idx) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->linesize[Idx]; +} + +Expect AVFrameData::body(const Runtime::CallingFrame &Frame, + uint32_t FrameId, uint32_t FrameBufPtr, + uint32_t FrameBufLen, uint32_t Index) { + + MEMINST_CHECK(MemInst, Frame, 0) + MEM_SPAN_CHECK(Buffer, MemInst, uint8_t, FrameBufPtr, FrameBufLen, ""); + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + + uint8_t *Data = AvFrame->data[Index]; + std::copy_n(Data, FrameBufLen, Buffer.data()); + return static_cast(ErrNo::Success); +} + +Expect AVFrameGetBuffer::body(const Runtime::CallingFrame &, + uint32_t FrameId, int32_t Align) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return av_frame_get_buffer(AvFrame, Align); +} + +Expect AVFrameAudioFormat::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + int const Format = AvFrame->format; + if (Format == -1) + return -1; + + AVSampleFormat const SampleFormat = static_cast(Format); + return FFmpegUtils::SampleFmt::toSampleID(SampleFormat); +} + +Expect AVFrameSetAudioFormat::body(const Runtime::CallingFrame &, + uint32_t FrameId, + uint32_t SampleFormatId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVSampleFormat const SampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFormatId); + AvFrame->format = SampleFormat; + return static_cast(ErrNo::Success); +} + +Expect AVFrameSetChannelLayout::body(const Runtime::CallingFrame &, + uint32_t FrameId, + uint64_t ChannelLayoutID) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + uint64_t const ChannelLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutID); + AvFrame->channel_layout = ChannelLayout; + return static_cast(ErrNo::Success); +} + +Expect AVFrameSetNbSamples::body(const Runtime::CallingFrame &, + uint32_t FrameId, int32_t Samples) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->nb_samples = Samples; + return static_cast(ErrNo::Success); +} + +Expect AVFrameNbSamples::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->nb_samples; +} + +Expect AVFrameSampleRate::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->sample_rate; +} + +Expect AVFrameSetSampleRate::body(const Runtime::CallingFrame &, + uint32_t FrameId, + int32_t SampleRate) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->sample_rate = SampleRate; + return static_cast(ErrNo::Success); +} + +Expect AVFrameChannels::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->channels; +} + +Expect AVFrameSetChannels::body(const Runtime::CallingFrame &, + uint32_t FrameId, int32_t Channels) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->channels = Channels; + return static_cast(ErrNo::Success); +} + +Expect AVFrameChannelLayout::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + uint64_t const ChannelLayout = AvFrame->channel_layout; + return FFmpegUtils::ChannelLayout::intoChannelLayoutID(ChannelLayout); +} + +Expect AVFrameBestEffortTimestamp::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->best_effort_timestamp; +} + +Expect AVFramePictType::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVPictureType const AvPictureType = AvFrame->pict_type; + return FFmpegUtils::PictureType::fromAVPictureType(AvPictureType); +} + +Expect AVFrameSetPictType::body(const Runtime::CallingFrame &, + uint32_t FrameId, int32_t PictureId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVPictureType const AvPictureType = + FFmpegUtils::PictureType::intoAVPictureType(PictureId); + + AvFrame->pict_type = AvPictureType; + return static_cast(ErrNo::Success); +} + +Expect AVFrameInterlacedFrame::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->interlaced_frame; +} + +Expect AVFrameTopFieldFirst::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->top_field_first; +} + +Expect AVFramePaletteHasChanged::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->palette_has_changed; +} + +Expect AVFrameColorSpace::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVColorSpace const AvColorSpace = AvFrame->colorspace; + return FFmpegUtils::ColorSpace::fromAVColorSpace(AvColorSpace); +} + +Expect AVFrameSetColorSpace::body(const Runtime::CallingFrame &, + uint32_t FrameId, + int32_t ColorSpaceId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->colorspace = FFmpegUtils::ColorSpace::intoAVColorSpace(ColorSpaceId); + return static_cast(ErrNo::Success); +} + +Expect AVFrameColorRange::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVColorRange const AvColorRange = AvFrame->color_range; + + return static_cast(AvColorRange); +} + +Expect AVFrameSetColorRange::body(const Runtime::CallingFrame &, + uint32_t FrameId, + int32_t ColorRangeId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->color_range = static_cast(ColorRangeId); + return static_cast(ErrNo::Success); +} + +Expect +AVFrameColorTransferCharacteristic::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVColorTransferCharacteristic const Characteristic = AvFrame->color_trc; + + // Can use the binding as well. Currently, Commented the binding. + return static_cast(Characteristic); +} + +Expect AVFrameSetColorTransferCharacteristic::body( + const Runtime::CallingFrame &, uint32_t FrameId, + int32_t ColorTransferCharacteristicId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->color_trc = + static_cast(ColorTransferCharacteristicId); + return static_cast(ErrNo::Success); +} + +Expect AVFrameChromaLocation::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVChromaLocation const AvChromaLocation = AvFrame->chroma_location; + return FFmpegUtils::ChromaLocation::fromAVChromaLocation(AvChromaLocation); +} + +Expect AVFrameCodedPictureNumber::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->coded_picture_number; +} + +Expect AVFrameDisplayPictureNumber::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->display_picture_number; +} + +Expect AVFrameRepeatPict::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->repeat_pict; +} + +Expect AVFrameFlags::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->flags; +} + +Expect AVFrameQuality::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->quality; +} + +Expect AVFrameMetadata::body(const Runtime::CallingFrame &Frame, + uint32_t FrameId, uint32_t DictPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, + "Failed when accessing the return AVDictionary memory"sv); + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + + AVDictionary **AvDictionary = + static_cast(av_malloc(sizeof(AVDictionary *))); + + *AvDictionary = AvFrame->metadata; + FFMPEG_PTR_STORE(AvDictionary, DictId); + return static_cast(ErrNo::Success); +} + +Expect AVFrameSetMetadata::body(const Runtime::CallingFrame &, + uint32_t FrameId, uint32_t DictId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); + + if (AvDict == nullptr) + AvFrame->metadata = nullptr; + else + AvFrame->metadata = *AvDict; + return static_cast(ErrNo::Success); +} + +Expect AVFrameKeyFrame::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->key_frame; +} + +Expect AVFramePts::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + return AvFrame->pts; +} + +Expect AVFrameSetPts::body(const Runtime::CallingFrame &, + uint32_t FrameId, int64_t Pts) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AvFrame->pts = Pts; + return static_cast(ErrNo::Success); +} + +Expect AVFrameCopy::body(const Runtime::CallingFrame &, + uint32_t DestFrameId, uint32_t SrcFrameId) { + + FFMPEG_PTR_FETCH(DestAvFrame, DestFrameId, AVFrame); + FFMPEG_PTR_FETCH(SrcAvFrame, SrcFrameId, AVFrame); + + av_frame_copy(DestAvFrame, SrcAvFrame); + return static_cast(ErrNo::Success); +} + +Expect AVFrameCopyProps::body(const Runtime::CallingFrame &, + uint32_t DestFrameId, + uint32_t SrcFrameId) { + + FFMPEG_PTR_FETCH(DestAvFrame, DestFrameId, AVFrame); + FFMPEG_PTR_FETCH(SrcAvFrame, SrcFrameId, AVFrame); + + av_frame_copy_props(DestAvFrame, SrcAvFrame); + return static_cast(ErrNo::Success); +} + +Expect +AVFrameSampleAspectRatio::body(const Runtime::CallingFrame &Frame, + uint32_t FrameId, uint32_t NumPtr, + uint32_t DenPtr) { + + MEMINST_CHECK(MemInst, Frame, 0) + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + + MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const Rational = AvFrame->sample_aspect_ratio; + *Num = Rational.num; + *Den = Rational.den; + return static_cast(ErrNo::Success); +} + +Expect AVFrameColorPrimaries::body(const Runtime::CallingFrame &, + uint32_t FrameId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVColorPrimaries const ColorPrimaries = AvFrame->color_primaries; + return FFmpegUtils::ColorPrimaries::fromAVColorPrimaries(ColorPrimaries); +} + +Expect AVFrameSetColorPrimaries::body(const Runtime::CallingFrame &, + uint32_t FrameId, + int32_t ColorPrimariesId) { + + FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); + AVColorPrimaries const ColorPrimaries = + FFmpegUtils::ColorPrimaries::intoAVColorPrimaries(ColorPrimariesId); + AvFrame->color_primaries = ColorPrimaries; + return static_cast(ErrNo::Success); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avFrame.h b/plugins/wasmedge_ffmpeg/avutil/avFrame.h new file mode 100644 index 00000000..3bceaa3e --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avFrame.h @@ -0,0 +1,404 @@ +#pragma once + +#include "avutil_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVFrameAlloc : public WasmEdgeFFmpegAVUtil { +public: + AVFrameAlloc(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FramePtr); +}; + +class AVFrameFree : public WasmEdgeFFmpegAVUtil { +public: + AVFrameFree(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameWidth : public WasmEdgeFFmpegAVUtil { +public: + AVFrameWidth(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameHeight : public WasmEdgeFFmpegAVUtil { +public: + AVFrameHeight(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetWidth : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetWidth(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t Width); +}; + +class AVFrameSetHeight : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetHeight(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t Height); +}; + +class AVFrameVideoFormat : public WasmEdgeFFmpegAVUtil { +public: + AVFrameVideoFormat(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetVideoFormat + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetVideoFormat(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t AvPixFormatId); +}; + +class AVFrameIsNull : public WasmEdgeFFmpegAVUtil { +public: + AVFrameIsNull(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameLinesize : public WasmEdgeFFmpegAVUtil { +public: + AVFrameLinesize(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t Idx); +}; + +class AVFrameData : public WasmEdgeFFmpegAVUtil { +public: + AVFrameData(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t FrameBufPtr, uint32_t FrameBufLen, + uint32_t Index); +}; + +class AVFrameGetBuffer : public WasmEdgeFFmpegAVUtil { +public: + AVFrameGetBuffer(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t Align); +}; + +class AVFrameAudioFormat : public WasmEdgeFFmpegAVUtil { +public: + AVFrameAudioFormat(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetAudioFormat + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetAudioFormat(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t SampleFormatId); +}; + +class AVFrameSetChannelLayout + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetChannelLayout(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint64_t ChannelLayoutID); +}; + +class AVFrameSetNbSamples : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetNbSamples(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t Samples); +}; + +class AVFrameNbSamples : public WasmEdgeFFmpegAVUtil { +public: + AVFrameNbSamples(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSampleRate : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSampleRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetSampleRate : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetSampleRate(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t SampleRate); +}; + +class AVFrameChannels : public WasmEdgeFFmpegAVUtil { +public: + AVFrameChannels(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetChannels : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetChannels(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t Channels); +}; + +class AVFrameChannelLayout : public WasmEdgeFFmpegAVUtil { +public: + AVFrameChannelLayout(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameBestEffortTimestamp + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameBestEffortTimestamp(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFramePictType : public WasmEdgeFFmpegAVUtil { +public: + AVFramePictType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetPictType : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetPictType(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t PictureId); +}; + +class AVFrameInterlacedFrame + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameInterlacedFrame(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameTopFieldFirst : public WasmEdgeFFmpegAVUtil { +public: + AVFrameTopFieldFirst(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFramePaletteHasChanged + : public WasmEdgeFFmpegAVUtil { +public: + AVFramePaletteHasChanged(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameColorSpace : public WasmEdgeFFmpegAVUtil { +public: + AVFrameColorSpace(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetColorSpace : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetColorSpace(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t ColorSpaceId); +}; + +class AVFrameColorRange : public WasmEdgeFFmpegAVUtil { +public: + AVFrameColorRange(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetColorRange : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetColorRange(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t ColorRangeId); +}; + +// color_transfer_characteristic + +class AVFrameColorTransferCharacteristic + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameColorTransferCharacteristic(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetColorTransferCharacteristic + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetColorTransferCharacteristic( + std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t ColorTransferCharacteristicId); +}; + +class AVFrameChromaLocation + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameChromaLocation(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameCodedPictureNumber + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameCodedPictureNumber(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameDisplayPictureNumber + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameDisplayPictureNumber(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameRepeatPict : public WasmEdgeFFmpegAVUtil { +public: + AVFrameRepeatPict(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameFlags : public WasmEdgeFFmpegAVUtil { +public: + AVFrameFlags(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameQuality : public WasmEdgeFFmpegAVUtil { +public: + AVFrameQuality(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameMetadata : public WasmEdgeFFmpegAVUtil { +public: + AVFrameMetadata(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t DictPtr); +}; + +class AVFrameSetMetadata : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetMetadata(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t DictId); +}; + +class AVFrameKeyFrame : public WasmEdgeFFmpegAVUtil { +public: + AVFrameKeyFrame(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFramePts : public WasmEdgeFFmpegAVUtil { +public: + AVFramePts(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetPts : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetPts(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int64_t Pts); +}; + +class AVFrameCopy : public WasmEdgeFFmpegAVUtil { +public: + AVFrameCopy(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestFrameId, + uint32_t SrcFrameId); +}; + +class AVFrameCopyProps : public WasmEdgeFFmpegAVUtil { +public: + AVFrameCopyProps(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t DestFrameId, + uint32_t SrcFrameId); +}; + +class AVFrameSampleAspectRatio + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSampleAspectRatio(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + uint32_t NumPtr, uint32_t DenPtr); +}; + +class AVFrameColorPrimaries + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameColorPrimaries(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); +}; + +class AVFrameSetColorPrimaries + : public WasmEdgeFFmpegAVUtil { +public: + AVFrameSetColorPrimaries(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, + int32_t ColorPrimariesId); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avRational.cpp b/plugins/wasmedge_ffmpeg/avutil/avRational.cpp new file mode 100644 index 00000000..8fbe81fd --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avRational.cpp @@ -0,0 +1,169 @@ +#include "avRational.h" + +extern "C" { +#include "libavutil/rational.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVAddQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(CNum, MemInst, int32_t, CNumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(CDen, MemInst, int32_t, CDenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_make_q(BNum, BDen); + + AVRational const C = av_add_q(A, B); + *CNum = C.num; + *CDen = C.den; + + return static_cast(ErrNo::Success); +} + +Expect AVSubQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(CNum, MemInst, int32_t, CNumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(CDen, MemInst, int32_t, CDenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_make_q(BNum, BDen); + + AVRational const C = av_sub_q(A, B); + *CNum = C.num; + *CDen = C.den; + return static_cast(ErrNo::Success); +} + +Expect AVMulQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(CNum, MemInst, int32_t, CNumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(CDen, MemInst, int32_t, CDenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_make_q(BNum, BDen); + + AVRational const C = av_mul_q(A, B); + *CNum = C.num; + *CDen = C.den; + return static_cast(ErrNo::Success); +} + +Expect AVDivQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(CNum, MemInst, int32_t, CNumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(CDen, MemInst, int32_t, CDenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_make_q(BNum, BDen); + + AVRational const C = av_div_q(A, B); + *CNum = C.num; + *CDen = C.den; + return static_cast(ErrNo::Success); +} + +Expect AVCmpQ::body(const Runtime::CallingFrame &, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen) { + + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_make_q(BNum, BDen); + return av_cmp_q(A, B); +} + +Expect AVNearerQ::body(const Runtime::CallingFrame &, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + int32_t CNum, int32_t CDen) { + + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_make_q(BNum, BDen); + AVRational const C = av_make_q(CNum, CDen); + + return av_nearer_q(A, B, C); +} + +Expect AVQ2d::body(const Runtime::CallingFrame &, int32_t ANum, + int32_t ADen) { + + AVRational const A = av_make_q(ANum, ADen); + return av_q2d(A); +} + +Expect AVD2Q::body(const Runtime::CallingFrame &Frame, double_t D, + int32_t Max, uint32_t ANumPtr, uint32_t ADenPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(ANum, MemInst, int32_t, ANumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(ADen, MemInst, int32_t, ADenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const A = av_d2q(D, Max); + *ANum = A.num; + *ADen = A.den; + return static_cast(ErrNo::Success); +} + +Expect AVQ2IntFloat::body(const Runtime::CallingFrame &, int32_t ANum, + int32_t ADen) { + + AVRational const A = av_make_q(ANum, ADen); + return av_q2intfloat(A); +} + +Expect AVInvQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, uint32_t BNumPtr, uint32_t BDenPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(BNum, MemInst, int32_t, BNumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(BDen, MemInst, int32_t, BDenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + + AVRational const A = av_make_q(ANum, ADen); + AVRational const B = av_inv_q(A); + + *BNum = B.num; + *BDen = B.den; + return static_cast(ErrNo::Success); +} + +Expect AVReduce::body(const Runtime::CallingFrame &Frame, + uint32_t ANumPtr, uint32_t ADenPtr, int64_t BNum, + int64_t BDen, int64_t Max) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(ANum, MemInst, int32_t, ANumPtr, + "Failed to access Numerator Ptr for AVRational"sv); + MEM_PTR_CHECK(ADen, MemInst, int32_t, ADenPtr, + "Failed to access Denominator Ptr for AVRational"sv); + return av_reduce(ANum, ADen, BNum, BDen, Max); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avRational.h b/plugins/wasmedge_ffmpeg/avutil/avRational.h new file mode 100644 index 00000000..b158e663 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avRational.h @@ -0,0 +1,107 @@ +#pragma once +#include "avutil_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVAddQ : public WasmEdgeFFmpegAVUtil { +public: + AVAddQ(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr); +}; + +class AVSubQ : public WasmEdgeFFmpegAVUtil { +public: + AVSubQ(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr); +}; + +class AVMulQ : public WasmEdgeFFmpegAVUtil { +public: + AVMulQ(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr); +}; + +class AVDivQ : public WasmEdgeFFmpegAVUtil { +public: + AVDivQ(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, + uint32_t CNumPtr, uint32_t CDenPtr); +}; + +class AVCmpQ : public WasmEdgeFFmpegAVUtil { +public: + AVCmpQ(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen); +}; + +class AVNearerQ : public WasmEdgeFFmpegAVUtil { +public: + AVNearerQ(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, int32_t BNum, int32_t BDen, int32_t CNum, + int32_t CDen); +}; + +class AVQ2d : public WasmEdgeFFmpegAVUtil { +public: + AVQ2d(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen); +}; + +class AVD2Q : public WasmEdgeFFmpegAVUtil { +public: + AVD2Q(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, double_t D, + int32_t Max, uint32_t ANumPtr, uint32_t ADenPtr); +}; + +class AVQ2IntFloat : public WasmEdgeFFmpegAVUtil { +public: + AVQ2IntFloat(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen); +}; + +class AVInvQ : public WasmEdgeFFmpegAVUtil { +public: + AVInvQ(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, + int32_t ADen, uint32_t BNumPtr, uint32_t BDenPtr); +}; + +class AVReduce : public WasmEdgeFFmpegAVUtil { +public: + AVReduce(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ANumPtr, + uint32_t ADenPtr, int64_t BNum, int64_t BDen, + int64_t Max); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avTime.cpp b/plugins/wasmedge_ffmpeg/avutil/avTime.cpp new file mode 100644 index 00000000..1ebbb03a --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avTime.cpp @@ -0,0 +1,32 @@ +#include "avTime.h" + +extern "C" { +#include "libavutil/time.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVGetTime::body(const Runtime::CallingFrame &) { + return av_gettime(); +} + +Expect AVGetTimeRelative::body(const Runtime::CallingFrame &) { + return av_gettime_relative(); +} + +Expect +AVGetTimeRelativeIsMonotonic::body(const Runtime::CallingFrame &) { + return av_gettime_relative_is_monotonic(); +} + +Expect AVUSleep::body(const Runtime::CallingFrame &, uint32_t USec) { + return av_usleep(USec); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avTime.h b/plugins/wasmedge_ffmpeg/avutil/avTime.h new file mode 100644 index 00000000..803e404a --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avTime.h @@ -0,0 +1,42 @@ +#pragma once +#include "avutil_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVGetTime : public WasmEdgeFFmpegAVUtil { +public: + AVGetTime(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVGetTimeRelative : public WasmEdgeFFmpegAVUtil { +public: + AVGetTimeRelative(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVGetTimeRelativeIsMonotonic + : public WasmEdgeFFmpegAVUtil { +public: + AVGetTimeRelativeIsMonotonic(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVUSleep : public WasmEdgeFFmpegAVUtil { +public: + AVUSleep(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t USec); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_base.h b/plugins/wasmedge_ffmpeg/avutil/avutil_base.h new file mode 100644 index 00000000..dcf35283 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avutil_base.h @@ -0,0 +1,25 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +template +class WasmEdgeFFmpegAVUtil : public Runtime::HostFunction { +public: + WasmEdgeFFmpegAVUtil( + std::shared_ptr HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + std::shared_ptr Env; +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp b/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp new file mode 100644 index 00000000..a1f20e4f --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp @@ -0,0 +1,142 @@ +#include "avutil_func.h" + +extern "C" { +#include "libavutil/avutil.h" +#include "libavutil/time.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVLogSetLevel::body(const Runtime::CallingFrame &, + int32_t LogLevelId) { + av_log_set_level(LogLevelId); + return {}; +} + +Expect AVLogGetLevel::body(const Runtime::CallingFrame &) { + return av_log_get_level(); +} + +Expect AVLogGetFlags::body(const Runtime::CallingFrame &) { + return av_log_get_flags(); +} + +Expect AVLogSetFlags::body(const Runtime::CallingFrame &, + int32_t FlagId) { + av_log_set_flags(FlagId); + return {}; +} + +Expect AVRescaleQ::body(const Runtime::CallingFrame &, int64_t A, + int32_t BNum, int32_t BDen, int32_t CNum, + int32_t CDen) { + + AVRational const B = av_make_q(BNum, BDen); + AVRational const C = av_make_q(CNum, CDen); + return av_rescale_q(A, B, C); +} + +Expect AVRescaleQRnd::body(const Runtime::CallingFrame &, int64_t A, + int32_t BNum, int32_t BDen, int32_t CNum, + int32_t CDen, int32_t RoundingId) { + + AVRational const B = av_make_q(BNum, BDen); + AVRational const C = av_make_q(CNum, CDen); + AVRounding const Rounding = FFmpegUtils::Rounding::intoAVRounding(RoundingId); + return av_rescale_q_rnd(A, B, C, Rounding); +} + +Expect AVUtilVersion::body(const Runtime::CallingFrame &) { + return avutil_version(); +} + +Expect +AVGetChannelLayoutNbChannels::body(const Runtime::CallingFrame &, + uint64_t ChannelLayoutId) { + uint64_t const ChannelLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); + return av_get_channel_layout_nb_channels(ChannelLayout); +} + +Expect AVGetChannelLayoutNameLen::body(const Runtime::CallingFrame &, + uint64_t ChannelLayoutId) { + + uint64_t const ChannelLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); + const char *ChName = av_get_channel_name(ChannelLayout); + if (ChName == nullptr) + return 0; + return strlen(ChName); +} + +Expect AVGetChannelLayoutName::body(const Runtime::CallingFrame &Frame, + uint64_t ChannelLayoutId, + uint32_t NamePtr, + uint32_t NameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); + + uint64_t const ChannelLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); + const char *ChName = av_get_channel_name(ChannelLayout); + + std::copy_n(ChName, NameLen, NameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVGetChannelLayoutMask::body(const Runtime::CallingFrame &, + uint64_t ChannelLayoutId) { + + uint64_t const ChannelLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); + return ChannelLayout; +} + +Expect AVGetDefaultChannelLayout::body(const Runtime::CallingFrame &, + int32_t Number) { + uint64_t const ChannelLayout = av_get_default_channel_layout(Number); + return FFmpegUtils::ChannelLayout::intoChannelLayoutID(ChannelLayout); +} + +Expect AVUtilConfigurationLength::body(const Runtime::CallingFrame &) { + const char *Config = avutil_configuration(); + return strlen(Config); +} + +Expect AVUtilConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, + uint32_t ConfigLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = avutil_configuration(); + std::copy_n(Config, ConfigLen, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVUtilLicenseLength::body(const Runtime::CallingFrame &) { + + const char *License = avutil_license(); + return strlen(License); +} + +Expect AVUtilLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, uint32_t LicenseLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = avutil_license(); + std::copy_n(License, LicenseLen, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_func.h b/plugins/wasmedge_ffmpeg/avutil/avutil_func.h new file mode 100644 index 00000000..1529e94f --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/avutil_func.h @@ -0,0 +1,208 @@ +#pragma once +#include "avutil_base.h" + +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVLogSetLevel : public WasmEdgeFFmpegAVUtil { +public: + AVLogSetLevel(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t LogLevelId); +}; + +class AVLogGetLevel : public WasmEdgeFFmpegAVUtil { +public: + AVLogGetLevel(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVLogSetFlags : public WasmEdgeFFmpegAVUtil { +public: + AVLogSetFlags(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t FlagsId); +}; + +class AVLogGetFlags : public WasmEdgeFFmpegAVUtil { +public: + AVLogGetFlags(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +// Option funcs. +class AVOptSetBin : public WasmEdgeFFmpegAVUtil { +public: + AVOptSetBin(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSet : public WasmEdgeFFmpegAVUtil { +public: + AVOptSet(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetInt : public WasmEdgeFFmpegAVUtil { +public: + AVOptSetInt(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetDouble : public WasmEdgeFFmpegAVUtil { +public: + AVOptSetDouble(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetQ : public WasmEdgeFFmpegAVUtil { +public: + AVOptSetQ(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetImageSize : public WasmEdgeFFmpegAVUtil { +public: + AVOptSetImageSize(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetPixelFmt : public WasmEdgeFFmpegAVUtil { +public: + AVOptSetPixelFmt(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetSampleFmt : public WasmEdgeFFmpegAVUtil { +public: + AVOptSetSampleFmt(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVOptSetChannelLayout + : public WasmEdgeFFmpegAVUtil { +public: + AVOptSetChannelLayout(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVRescaleQ : public WasmEdgeFFmpegAVUtil { +public: + AVRescaleQ(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int64_t A, + int32_t BNum, int32_t BDen, int32_t CNum, int32_t CDen); +}; + +class AVRescaleQRnd : public WasmEdgeFFmpegAVUtil { +public: + AVRescaleQRnd(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &, int64_t A, int32_t BNum, + int32_t BDen, int32_t CNum, int32_t CDen, + int32_t RoundingId); +}; + +class AVUtilVersion : public WasmEdgeFFmpegAVUtil { +public: + AVUtilVersion(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &); +}; + +class AVGetChannelLayoutNbChannels + : public WasmEdgeFFmpegAVUtil { +public: + AVGetChannelLayoutNbChannels(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint64_t ChannelLayoutId); +}; + +class AVGetChannelLayoutNameLen + : public WasmEdgeFFmpegAVUtil { +public: + AVGetChannelLayoutNameLen(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint64_t ChannelLayoutId); +}; + +class AVGetChannelLayoutName + : public WasmEdgeFFmpegAVUtil { +public: + AVGetChannelLayoutName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint64_t ChannelLayoutId, uint32_t NamePtr, + uint32_t NameLen); +}; + +class AVGetChannelLayoutMask + : public WasmEdgeFFmpegAVUtil { +public: + AVGetChannelLayoutMask(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint64_t ChannelLayoutId); +}; + +class AVGetDefaultChannelLayout + : public WasmEdgeFFmpegAVUtil { +public: + AVGetDefaultChannelLayout(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + int32_t ChannelLayoutId); +}; + +class AVUtilConfigurationLength + : public WasmEdgeFFmpegAVUtil { +public: + AVUtilConfigurationLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVUtilConfiguration : public WasmEdgeFFmpegAVUtil { +public: + AVUtilConfiguration(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class AVUtilLicenseLength : public WasmEdgeFFmpegAVUtil { +public: + AVUtilLicenseLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class AVUtilLicense : public WasmEdgeFFmpegAVUtil { +public: + AVUtilLicense(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/error.cpp b/plugins/wasmedge_ffmpeg/avutil/error.cpp new file mode 100644 index 00000000..d918ece5 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/error.cpp @@ -0,0 +1,39 @@ +#include "error.h" + +extern "C" { +#include "libavutil/error.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVUtilAVStrError::body(const Runtime::CallingFrame &Frame, + int32_t ErrNum, uint32_t ErrBuf, + uint32_t BufLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + + MEM_PTR_CHECK(ErrId, MemInst, char, ErrBuf, + "Failed when accessing the return URL memory"sv); + + std::string Error; + std::copy_n(ErrId, BufLen, std::back_inserter(Error)); + return av_strerror(ErrNum, const_cast(Error.c_str()), BufLen); +} + +Expect AVUtilAVError::body(const Runtime::CallingFrame &, + int32_t ErrNum) { + return AVERROR(ErrNum); +} + +Expect AVUtilAVUNError::body(const Runtime::CallingFrame &, + int32_t ErrNum) { + return AVUNERROR(ErrNum); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/error.h b/plugins/wasmedge_ffmpeg/avutil/error.h new file mode 100644 index 00000000..a8137151 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/error.h @@ -0,0 +1,36 @@ +#pragma once + +#include "avutil_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVUtilAVStrError : public WasmEdgeFFmpegAVUtil { +public: + AVUtilAVStrError(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ErrNum, + uint32_t ErrBuf, uint32_t BufLen); +}; + +class AVUtilAVError : public WasmEdgeFFmpegAVUtil { +public: + AVUtilAVError(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ErrNum); +}; + +class AVUtilAVUNError : public WasmEdgeFFmpegAVUtil { +public: + AVUtilAVUNError(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ErrNum); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/module.cpp b/plugins/wasmedge_ffmpeg/avutil/module.cpp new file mode 100644 index 00000000..12588050 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/module.cpp @@ -0,0 +1,263 @@ +#include "module.h" +#include "avDictionary.h" +#include "avFrame.h" +#include "avRational.h" +#include "avTime.h" +#include "avutil_func.h" +#include "error.h" +#include "pixfmt.h" +#include "samplefmt.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +WasmEdgeFFmpegAVUtilModule::WasmEdgeFFmpegAVUtilModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_avutil") { + + // error.h + addHostFunc("wasmedge_ffmpeg_avutil_av_strerror", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_AVERROR", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_AVUNERROR", + std::make_unique(Env)); + + // rational.h + addHostFunc("wasmedge_ffmpeg_avutil_av_add_q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_sub_q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_mul_q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_div_q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_d2q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_q2d", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_inv_q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_q2intfloat", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_nearer_q", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_cmp_q", std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_reduce", + std::make_unique(Env)); + + // frame.h + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_alloc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_free", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_width", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_height", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_width", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_height", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_video_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_video_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_isnull", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_linesize", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_data", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_get_buffer", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_audio_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_audio_format", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_nb_samples", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_channel_layout", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_nb_samples", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_sample_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_sample_rate", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_channels", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_channels", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_channel_layout", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_best_effort_timestamp", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_pict_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_pict_type", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_interlaced_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_top_field_first", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_palette_has_changed", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_colorspace", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_colorspace", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_color_range", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_color_range", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_color_trc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_color_trc", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_chroma_location", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_coded_picture_number", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_display_picture_number", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_repeat_pict", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_quality", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_metadata", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_metadata", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_key_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_pts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_pts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_copy", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_copy_props", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_sample_aspect_ratio", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_set_color_primaries", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_frame_color_primaries", + std::make_unique(Env)); + + // pixfmt.h (Even AvPixFmtDesc is in this file) + addHostFunc("wasmedge_ffmpeg_avutil_avpixfmtdescriptor_nb_components", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avpixfmtdescriptor_log2_chromaw", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avpixfmtdescriptor_log2_chromah", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_transfer_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_transfer_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_range_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_range_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_space_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_space_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_primaries_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_color_primaries_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_pix_format_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_pix_format_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_pix_format_mask", + std::make_unique(Env)); + + // samplefmt.h + addHostFunc("wasmedge_ffmpeg_avutil_av_get_packed_sample_fmt", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_planar_sample_fmt", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_sample_fmt_is_planar", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_bytes_per_sample", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_sample_fmt", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_samples_get_buffer_size", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_samples_alloc_array_and_samples", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_sample_fmt_name_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_sample_fmt_name", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_sample_fmt_mask", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_freep", + std::make_unique(Env)); + + // dict.h + addHostFunc("wasmedge_ffmpeg_avutil_av_dict_set", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_dict_get", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_dict_get_key_value", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_dict_copy", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_dict_free", + std::make_unique(Env)); + + // avutil_func.h + addHostFunc("wasmedge_ffmpeg_avutil_av_log_set_level", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_log_get_level", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_log_set_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_log_get_flags", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_rescale_q", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_rescale_q_rnd", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_channel_layout_nb_channels", + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avutil_av_get_channel_layout_name_len", // TODO: Write + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avutil_av_get_channel_layout_name", // TODO: Write Test + std::make_unique(Env)); + addHostFunc( + "wasmedge_ffmpeg_avutil_av_get_channel_layout_mask", // TODO: Write Test + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_get_default_channel_layout", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avutil_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avutil_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avutil_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avutil_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_avutil_license", + std::make_unique(Env)); + + // time.h + addHostFunc("wasmedge_ffmpeg_avutil_av_gettime", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_gettime_relative", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_gettime_relative_is_monotonic", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_avutil_av_usleep", + std::make_unique(Env)); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/module.h b/plugins/wasmedge_ffmpeg/avutil/module.h new file mode 100644 index 00000000..ebd35dba --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/module.h @@ -0,0 +1,19 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class WasmEdgeFFmpegAVUtilModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegAVUtilModule(std::shared_ptr Env); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp b/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp new file mode 100644 index 00000000..9d5b55c2 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp @@ -0,0 +1,174 @@ +#include "pixfmt.h" +extern "C" { +#include "libavutil/pixdesc.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect +AvPixFmtDescriptorNbComponents::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + const AVPixFmtDescriptor *AvPixFmtDescriptor = + av_pix_fmt_desc_get(PixelFormat); + return AvPixFmtDescriptor->nb_components; +} + +Expect +AvPixFmtDescriptorLog2ChromaW::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + const AVPixFmtDescriptor *AvPixFmtDescriptor = + av_pix_fmt_desc_get(PixelFormat); + return AvPixFmtDescriptor->log2_chroma_w; +} + +Expect +AvPixFmtDescriptorLog2ChromaH::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + const AVPixFmtDescriptor *AvPixFmtDescriptor = + av_pix_fmt_desc_get(PixelFormat); + return AvPixFmtDescriptor->log2_chroma_h; +} + +Expect AVColorRangeNameLength::body(const Runtime::CallingFrame &, + int32_t RangeId) { + + AVColorRange const ColorRange = static_cast(RangeId); + const char *Name = av_color_range_name(ColorRange); + return strlen(Name); +} + +Expect AVColorRangeName::body(const Runtime::CallingFrame &Frame, + int32_t RangeId, uint32_t RangeNamePtr, + uint32_t RangeLength) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(RangeNameBuf, MemInst, char, RangeNamePtr, RangeLength, ""); + + AVColorRange const ColorRange = static_cast(RangeId); + const char *RangeName = av_color_range_name(ColorRange); + std::copy_n(RangeName, RangeLength, RangeNameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVColorTransferNameLength::body(const Runtime::CallingFrame &, + int32_t TransferId) { + + AVColorTransferCharacteristic const Characteristic = + static_cast(TransferId); + const char *Name = av_color_transfer_name(Characteristic); + return strlen(Name); +} + +Expect AVColorTransferName::body(const Runtime::CallingFrame &Frame, + int32_t TransferId, + uint32_t TransferNamePtr, + uint32_t TransferLength) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(TransferNameBuf, MemInst, char, TransferNamePtr, + TransferLength, ""); + + AVColorTransferCharacteristic const Characteristic = + static_cast(TransferId); + const char *TransferName = av_color_transfer_name(Characteristic); + std::copy_n(TransferName, TransferLength, TransferNameBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVColorSpaceNameLength::body(const Runtime::CallingFrame &, + int32_t ColorSpaceId) { + + AVColorSpace const ColorSpace = static_cast(ColorSpaceId); + const char *Name = av_color_space_name(ColorSpace); + return strlen(Name); +} + +Expect AVColorSpaceName::body(const Runtime::CallingFrame &Frame, + int32_t ColorSpaceId, + uint32_t ColorSpaceNamePtr, + uint32_t ColorSpaceLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ColorSpaceBuf, MemInst, char, ColorSpaceNamePtr, ColorSpaceLen, + ""); + + AVColorSpace const ColorSpace = static_cast(ColorSpaceId); + const char *ColorSpaceName = av_color_space_name(ColorSpace); + std::copy_n(ColorSpaceName, ColorSpaceLen, ColorSpaceBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVColorPrimariesNameLength::body(const Runtime::CallingFrame &, + int32_t ColorPrimariesId) { + + AVColorPrimaries const ColorPrimaries = + FFmpegUtils::ColorPrimaries::intoAVColorPrimaries(ColorPrimariesId); + const char *Name = av_color_primaries_name(ColorPrimaries); + return strlen(Name); +} + +Expect AVColorPrimariesName::body(const Runtime::CallingFrame &Frame, + int32_t ColorPrimariesId, + uint32_t ColorPrimariesNamePtr, + uint32_t ColorPrimariesLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ColorPrimariesBuf, MemInst, char, ColorPrimariesNamePtr, + ColorPrimariesLen, ""); + + AVColorPrimaries const ColorPrimaries = + FFmpegUtils::ColorPrimaries::intoAVColorPrimaries(ColorPrimariesId); + const char *PrimariesName = av_color_primaries_name(ColorPrimaries); + std::copy_n(PrimariesName, ColorPrimariesLen, ColorPrimariesBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVPixelFormatNameLength::body(const Runtime::CallingFrame &, + uint32_t AvPixFormatId) { + + AVPixelFormat const PixFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(AvPixFormatId); + const AVPixFmtDescriptor *PixFmtDescriptor = av_pix_fmt_desc_get(PixFormat); + + return strlen(PixFmtDescriptor->name); +} + +Expect AVPixelFormatName::body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId, + uint32_t PixFormatNamePtr, + uint32_t PixFormatNameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(PixFormatBuf, MemInst, char, PixFormatNamePtr, + PixFormatNameLen, ""); + + AVPixelFormat const PixFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + const AVPixFmtDescriptor *PixFmtDescriptor = av_pix_fmt_desc_get(PixFormat); + const char *PixFormatName = PixFmtDescriptor->name; + std::copy_n(PixFormatName, PixFormatNameLen, PixFormatBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVPixelFormatMask::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + return static_cast(PixelFormat); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/pixfmt.h b/plugins/wasmedge_ffmpeg/avutil/pixfmt.h new file mode 100644 index 00000000..51126aa6 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/pixfmt.h @@ -0,0 +1,132 @@ +#pragma once +#include "avutil_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AvPixFmtDescriptorNbComponents + : public WasmEdgeFFmpegAVUtil { +public: + AvPixFmtDescriptorNbComponents(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +class AvPixFmtDescriptorLog2ChromaW + : public WasmEdgeFFmpegAVUtil { +public: + AvPixFmtDescriptorLog2ChromaW(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +class AvPixFmtDescriptorLog2ChromaH + : public WasmEdgeFFmpegAVUtil { +public: + AvPixFmtDescriptorLog2ChromaH(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +class AVColorRangeNameLength + : public WasmEdgeFFmpegAVUtil { +public: + AVColorRangeNameLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t RangeId); +}; + +class AVColorRangeName : public WasmEdgeFFmpegAVUtil { +public: + AVColorRangeName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t RangeId, + uint32_t RangeName, uint32_t RangeLength); +}; + +class AVColorTransferNameLength + : public WasmEdgeFFmpegAVUtil { +public: + AVColorTransferNameLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t TransferId); +}; + +class AVColorTransferName : public WasmEdgeFFmpegAVUtil { +public: + AVColorTransferName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t TransferId, + uint32_t TransferNamePtr, uint32_t TransferLength); +}; + +class AVColorSpaceNameLength + : public WasmEdgeFFmpegAVUtil { +public: + AVColorSpaceNameLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + int32_t ColorSpaceId); +}; + +class AVColorSpaceName : public WasmEdgeFFmpegAVUtil { +public: + AVColorSpaceName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t ColorSpaceId, + uint32_t ColorSpaceNamePtr, uint32_t ColorSpaceLen); +}; + +class AVColorPrimariesNameLength + : public WasmEdgeFFmpegAVUtil { +public: + AVColorPrimariesNameLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + int32_t ColorPrimariesId); +}; + +class AVColorPrimariesName : public WasmEdgeFFmpegAVUtil { +public: + AVColorPrimariesName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + int32_t ColorPrimariesId, uint32_t ColorPrimariesNamePtr, + uint32_t ColorPrimariesLen); +}; + +class AVPixelFormatNameLength + : public WasmEdgeFFmpegAVUtil { +public: + AVPixelFormatNameLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t AvPixFormatId); +}; + +class AVPixelFormatName : public WasmEdgeFFmpegAVUtil { +public: + AVPixelFormatName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t PixFormatId, + uint32_t PixFormatNamePtr, uint32_t PixFormatNameLen); +}; + +class AVPixelFormatMask : public WasmEdgeFFmpegAVUtil { +public: + AVPixelFormatMask(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp b/plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp new file mode 100644 index 00000000..f38914e6 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp @@ -0,0 +1,134 @@ +#include "samplefmt.h" +extern "C" { +#include "libavutil/samplefmt.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +Expect AVGetPlanarSampleFmt::body(const Runtime::CallingFrame &, + uint32_t SampleFormatId) { + AVSampleFormat const AvSampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFormatId); + AVSampleFormat const PlanarSampleFmt = + av_get_planar_sample_fmt(AvSampleFormat); + return FFmpegUtils::SampleFmt::toSampleID(PlanarSampleFmt); +} + +Expect AVGetPackedSampleFmt::body(const Runtime::CallingFrame &, + uint32_t SampleFormatId) { + AVSampleFormat const AvSampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFormatId); + AVSampleFormat const PackedSampleFmt = + av_get_packed_sample_fmt(AvSampleFormat); + return FFmpegUtils::SampleFmt::toSampleID(PackedSampleFmt); +} + +Expect AVSampleFmtIsPlanar::body(const Runtime::CallingFrame &, + uint32_t SampleFormatId) { + AVSampleFormat const AvSampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFormatId); + return av_sample_fmt_is_planar(AvSampleFormat); +} + +Expect AVGetBytesPerSample::body(const Runtime::CallingFrame &, + uint32_t SampleFormatId) { + AVSampleFormat const AvSampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFormatId); + return av_get_bytes_per_sample(AvSampleFormat); +} + +Expect AVGetSampleFmt::body(const Runtime::CallingFrame &Frame, + uint32_t Str, uint32_t StrLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(StrId, MemInst, char, Str, ""); + + std::string TargetUrl; + std::copy_n(StrId, StrLen, std::back_inserter(TargetUrl)); + + AVSampleFormat const AvSampleFormat = av_get_sample_fmt(TargetUrl.c_str()); + return FFmpegUtils::SampleFmt::toSampleID(AvSampleFormat); +} + +Expect AVSamplesGetBufferSize::body(const Runtime::CallingFrame &, + int32_t NbChannels, + int32_t NbSamples, + uint32_t SampleFormatId, + int32_t Align) { + AVSampleFormat const AvSampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFormatId); + return av_samples_get_buffer_size(nullptr, NbChannels, NbSamples, + AvSampleFormat, + Align); // linesize is NULL in RustSDK. +} + +Expect +AVSamplesAllocArrayAndSamples::body(const Runtime::CallingFrame &Frame, + uint32_t BufferPtr, uint32_t LinesizePtr, + int32_t NbChannels, int32_t NbSamples, + uint32_t SampleFmtId, int32_t Align) { + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(BufId, MemInst, uint32_t, BufferPtr, ""); + MEM_PTR_CHECK(LineSize, MemInst, int32_t, LinesizePtr, ""); + + FFMPEG_PTR_FETCH(Buf, *BufId, uint8_t *); + int LineSizeValue = 0; + AVSampleFormat const AvSampleFormat = + FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); + int Res = av_samples_alloc_array_and_samples( + &Buf, &LineSizeValue, NbChannels, NbSamples, AvSampleFormat, Align); + + *LineSize = LineSizeValue; + FFMPEG_PTR_STORE(Buf, BufId); + return Res; +} + +Expect AVGetSampleFmtNameLength::body(const Runtime::CallingFrame &, + uint32_t SampleFmtId) { + + AVSampleFormat const SampleFmt = + FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); + + const char *Name = av_get_sample_fmt_name(SampleFmt); + return strlen(Name); +} + +Expect AVGetSampleFmtName::body(const Runtime::CallingFrame &Frame, + uint32_t SampleFmtId, + uint32_t SampleFmtNamePtr, + uint32_t SampleFmtNameLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(SampleFmtBuf, MemInst, char, SampleFmtNamePtr, + SampleFmtNameLen, ""); + + AVSampleFormat const SampleFmt = + FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); + const char *Name = av_get_sample_fmt_name(SampleFmt); + std::copy_n(Name, SampleFmtNameLen, SampleFmtBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect AVGetSampleFmtMask::body(const Runtime::CallingFrame &, + uint32_t SampleFmtId) { + + AVSampleFormat const SampleFmt = + FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); + return static_cast(SampleFmt); +} + +Expect AVFreep::body(const Runtime::CallingFrame &, + uint32_t BufferId) { + FFMPEG_PTR_FETCH(Buffer, BufferId, uint8_t *); + av_freep(Buffer); + FFMPEG_PTR_DELETE(BufferId); + return static_cast(ErrNo::Success); +} + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/samplefmt.h b/plugins/wasmedge_ffmpeg/avutil/samplefmt.h new file mode 100644 index 00000000..a6190779 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/avutil/samplefmt.h @@ -0,0 +1,105 @@ +#pragma once +#include "avutil_base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace AVUtil { + +class AVGetPlanarSampleFmt : public WasmEdgeFFmpegAVUtil { +public: + AVGetPlanarSampleFmt(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SampleFormatId); +}; + +class AVGetPackedSampleFmt : public WasmEdgeFFmpegAVUtil { +public: + AVGetPackedSampleFmt(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SampleFormatId); +}; + +class AVSampleFmtIsPlanar : public WasmEdgeFFmpegAVUtil { +public: + AVSampleFmtIsPlanar(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SampleFormatId); +}; + +class AVGetBytesPerSample : public WasmEdgeFFmpegAVUtil { +public: + AVGetBytesPerSample(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SampleFormatId); +}; + +class AVGetSampleFmt : public WasmEdgeFFmpegAVUtil { +public: + AVGetSampleFmt(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Str, + uint32_t StrLen); +}; + +class AVSamplesGetBufferSize + : public WasmEdgeFFmpegAVUtil { +public: + AVSamplesGetBufferSize(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, int32_t NbChannels, + int32_t NbSamples, uint32_t SampleFormatId, + int32_t Align); +}; + +class AVSamplesAllocArrayAndSamples + : public WasmEdgeFFmpegAVUtil { +public: + AVSamplesAllocArrayAndSamples(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufferPtr, + uint32_t LinesizePtr, int32_t NbChannels, + int32_t NbSamples, uint32_t SampleFmtId, int32_t Align); +}; + +class AVGetSampleFmtNameLength + : public WasmEdgeFFmpegAVUtil { +public: + AVGetSampleFmtNameLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SampleFmtId); +}; + +class AVGetSampleFmtName : public WasmEdgeFFmpegAVUtil { +public: + AVGetSampleFmtName(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SampleFmtId, + uint32_t SampleFmtNamePtr, uint32_t SampleFmtNameLen); +}; + +class AVGetSampleFmtMask : public WasmEdgeFFmpegAVUtil { +public: + AVGetSampleFmtMask(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SampleFmtId); +}; + +class AVFreep : public WasmEdgeFFmpegAVUtil { +public: + AVFreep(std::shared_ptr HostEnv) + : WasmEdgeFFmpegAVUtil(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t BufferId); +}; + +} // namespace AVUtil +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/bindings.h b/plugins/wasmedge_ffmpeg/bindings.h new file mode 100644 index 00000000..748f5987 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/bindings.h @@ -0,0 +1,4425 @@ +#pragma once + +extern "C" { +#include "libavcodec/avcodec.h" +#include "libavutil/avutil.h" +#include "libavutil/opt.h" +#include "libswresample/swresample.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace FFmpegUtils { +class MediaType { +public: + static AVMediaType intoMediaType(int32_t MediaTypeId) { + switch (MediaTypeId) { + case 0: + return AVMEDIA_TYPE_VIDEO; + case 1: + return AVMEDIA_TYPE_AUDIO; + case 2: + return AVMEDIA_TYPE_DATA; + case 3: + return AVMEDIA_TYPE_SUBTITLE; + case 4: + return AVMEDIA_TYPE_ATTACHMENT; + case 5: + return AVMEDIA_TYPE_NB; + default: + return AVMEDIA_TYPE_UNKNOWN; + } + } + + static int32_t fromMediaType(AVMediaType MediaType) { + switch (MediaType) { + case AVMEDIA_TYPE_VIDEO: + return 0; + case AVMEDIA_TYPE_AUDIO: + return 1; + case AVMEDIA_TYPE_DATA: + return 2; + case AVMEDIA_TYPE_SUBTITLE: + return 3; + case 4: + return AVMEDIA_TYPE_ATTACHMENT; + case 5: + return AVMEDIA_TYPE_NB; + default: + return AVMEDIA_TYPE_UNKNOWN; + } + } +}; + +class CodecID { +public: + static AVCodecID intoAVCodecID(uint32_t AvCodecIndex) { + switch (AvCodecIndex) { + case 0: + return AV_CODEC_ID_NONE; + case 1: + return AV_CODEC_ID_MPEG1VIDEO; + case 2: + return AV_CODEC_ID_MPEG2VIDEO; + case 3: + return AV_CODEC_ID_H261; + case 4: + return AV_CODEC_ID_H263; + case 5: + return AV_CODEC_ID_RV10; + case 6: + return AV_CODEC_ID_RV20; + case 7: + return AV_CODEC_ID_MJPEG; + case 8: + return AV_CODEC_ID_MJPEGB; + case 9: + return AV_CODEC_ID_LJPEG; + case 10: + return AV_CODEC_ID_SP5X; + case 11: + return AV_CODEC_ID_JPEGLS; + case 12: + return AV_CODEC_ID_MPEG4; + case 13: + return AV_CODEC_ID_RAWVIDEO; + case 14: + return AV_CODEC_ID_MSMPEG4V1; + case 15: + return AV_CODEC_ID_MSMPEG4V2; + case 16: + return AV_CODEC_ID_MSMPEG4V3; + case 17: + return AV_CODEC_ID_WMV1; + case 18: + return AV_CODEC_ID_WMV2; + case 19: + return AV_CODEC_ID_H263P; + case 20: + return AV_CODEC_ID_H263I; + case 21: + return AV_CODEC_ID_FLV1; + case 22: + return AV_CODEC_ID_SVQ1; + case 23: + return AV_CODEC_ID_SVQ3; + case 24: + return AV_CODEC_ID_DVVIDEO; + case 25: + return AV_CODEC_ID_HUFFYUV; + case 26: + return AV_CODEC_ID_CYUV; + case 27: + return AV_CODEC_ID_H264; + case 28: + return AV_CODEC_ID_INDEO3; + case 29: + return AV_CODEC_ID_VP3; + case 30: + return AV_CODEC_ID_THEORA; + case 31: + return AV_CODEC_ID_ASV1; + case 32: + return AV_CODEC_ID_ASV2; + case 33: + return AV_CODEC_ID_FFV1; + case 34: + return AV_CODEC_ID_4XM; + case 35: + return AV_CODEC_ID_VCR1; + case 36: + return AV_CODEC_ID_CLJR; + case 37: + return AV_CODEC_ID_MDEC; + case 38: + return AV_CODEC_ID_ROQ; + case 39: + return AV_CODEC_ID_INTERPLAY_VIDEO; + case 40: + return AV_CODEC_ID_XAN_WC3; + case 41: + return AV_CODEC_ID_XAN_WC4; + case 42: + return AV_CODEC_ID_RPZA; + case 43: + return AV_CODEC_ID_CINEPAK; + case 44: + return AV_CODEC_ID_WS_VQA; + case 45: + return AV_CODEC_ID_MSRLE; + case 46: + return AV_CODEC_ID_MSVIDEO1; + case 47: + return AV_CODEC_ID_IDCIN; + case 48: + return AV_CODEC_ID_8BPS; + case 49: + return AV_CODEC_ID_SMC; + case 50: + return AV_CODEC_ID_FLIC; + case 51: + return AV_CODEC_ID_TRUEMOTION1; + case 52: + return AV_CODEC_ID_VMDVIDEO; + case 53: + return AV_CODEC_ID_MSZH; + case 54: + return AV_CODEC_ID_ZLIB; + case 55: + return AV_CODEC_ID_QTRLE; + case 56: + return AV_CODEC_ID_TSCC; + case 57: + return AV_CODEC_ID_ULTI; + case 58: + return AV_CODEC_ID_QDRAW; + case 59: + return AV_CODEC_ID_VIXL; + case 60: + return AV_CODEC_ID_QPEG; + case 61: + return AV_CODEC_ID_PNG; + case 62: + return AV_CODEC_ID_PPM; + case 63: + return AV_CODEC_ID_PBM; + case 64: + return AV_CODEC_ID_PGM; + case 65: + return AV_CODEC_ID_PGMYUV; + case 66: + return AV_CODEC_ID_PAM; + case 67: + return AV_CODEC_ID_FFVHUFF; + case 68: + return AV_CODEC_ID_RV30; + case 69: + return AV_CODEC_ID_RV40; + case 70: + return AV_CODEC_ID_VC1; + case 71: + return AV_CODEC_ID_WMV3; + case 72: + return AV_CODEC_ID_LOCO; + case 73: + return AV_CODEC_ID_WNV1; + case 74: + return AV_CODEC_ID_AASC; + case 75: + return AV_CODEC_ID_INDEO2; + case 76: + return AV_CODEC_ID_FRAPS; + case 77: + return AV_CODEC_ID_TRUEMOTION2; + case 78: + return AV_CODEC_ID_BMP; + case 79: + return AV_CODEC_ID_CSCD; + case 80: + return AV_CODEC_ID_MMVIDEO; + case 81: + return AV_CODEC_ID_ZMBV; + case 82: + return AV_CODEC_ID_AVS; + case 83: + return AV_CODEC_ID_SMACKVIDEO; + case 84: + return AV_CODEC_ID_NUV; + case 85: + return AV_CODEC_ID_KMVC; + case 86: + return AV_CODEC_ID_FLASHSV; + case 87: + return AV_CODEC_ID_CAVS; + case 88: + return AV_CODEC_ID_JPEG2000; + case 89: + return AV_CODEC_ID_VMNC; + case 90: + return AV_CODEC_ID_VP5; + case 91: + return AV_CODEC_ID_VP6; + case 92: + return AV_CODEC_ID_VP6F; + case 93: + return AV_CODEC_ID_TARGA; + case 94: + return AV_CODEC_ID_DSICINVIDEO; + case 95: + return AV_CODEC_ID_TIERTEXSEQVIDEO; + case 96: + return AV_CODEC_ID_TIFF; + case 97: + return AV_CODEC_ID_GIF; + case 98: + return AV_CODEC_ID_DXA; + case 99: + return AV_CODEC_ID_DNXHD; + case 100: + return AV_CODEC_ID_THP; + case 101: + return AV_CODEC_ID_SGI; + case 102: + return AV_CODEC_ID_C93; + case 103: + return AV_CODEC_ID_BETHSOFTVID; + case 104: + return AV_CODEC_ID_PTX; + case 105: + return AV_CODEC_ID_TXD; + case 106: + return AV_CODEC_ID_VP6A; + case 107: + return AV_CODEC_ID_AMV; + case 108: + return AV_CODEC_ID_VB; + case 109: + return AV_CODEC_ID_PCX; + case 110: + return AV_CODEC_ID_SUNRAST; + case 111: + return AV_CODEC_ID_INDEO4; + case 112: + return AV_CODEC_ID_INDEO5; + case 113: + return AV_CODEC_ID_MIMIC; + case 114: + return AV_CODEC_ID_RL2; + case 115: + return AV_CODEC_ID_ESCAPE124; + case 116: + return AV_CODEC_ID_DIRAC; + case 117: + return AV_CODEC_ID_BFI; + case 118: + return AV_CODEC_ID_CMV; + case 119: + return AV_CODEC_ID_MOTIONPIXELS; + case 120: + return AV_CODEC_ID_TGV; + case 121: + return AV_CODEC_ID_TGQ; + case 122: + return AV_CODEC_ID_TQI; + case 123: + return AV_CODEC_ID_AURA; + case 124: + return AV_CODEC_ID_AURA2; + case 125: + return AV_CODEC_ID_V210X; + case 126: + return AV_CODEC_ID_TMV; + case 127: + return AV_CODEC_ID_V210; + case 128: + return AV_CODEC_ID_DPX; + case 129: + return AV_CODEC_ID_MAD; + case 130: + return AV_CODEC_ID_FRWU; + case 131: + return AV_CODEC_ID_FLASHSV2; + case 132: + return AV_CODEC_ID_CDGRAPHICS; + case 133: + return AV_CODEC_ID_R210; + case 134: + return AV_CODEC_ID_ANM; + case 135: + return AV_CODEC_ID_BINKVIDEO; + case 136: + return AV_CODEC_ID_IFF_ILBM; + case 137: + return AV_CODEC_ID_IFF_ILBM; + case 138: + return AV_CODEC_ID_KGV1; + case 139: + return AV_CODEC_ID_YOP; + case 140: + return AV_CODEC_ID_VP8; + case 141: + return AV_CODEC_ID_PICTOR; + case 142: + return AV_CODEC_ID_ANSI; + case 143: + return AV_CODEC_ID_A64_MULTI; + case 144: + return AV_CODEC_ID_A64_MULTI5; + case 145: + return AV_CODEC_ID_R10K; + case 146: + return AV_CODEC_ID_MXPEG; + case 147: + return AV_CODEC_ID_LAGARITH; + case 148: + return AV_CODEC_ID_PRORES; + case 149: + return AV_CODEC_ID_JV; + case 150: + return AV_CODEC_ID_DFA; + case 151: + return AV_CODEC_ID_WMV3IMAGE; + case 152: + return AV_CODEC_ID_VC1IMAGE; + case 153: + return AV_CODEC_ID_UTVIDEO; + case 154: + return AV_CODEC_ID_BMV_VIDEO; + case 155: + return AV_CODEC_ID_VBLE; + case 156: + return AV_CODEC_ID_DXTORY; + case 157: + return AV_CODEC_ID_V410; + case 158: + return AV_CODEC_ID_XWD; + case 159: + return AV_CODEC_ID_CDXL; + case 160: + return AV_CODEC_ID_XBM; + case 161: + return AV_CODEC_ID_ZEROCODEC; + case 162: + return AV_CODEC_ID_MSS1; + case 163: + return AV_CODEC_ID_MSA1; + case 164: + return AV_CODEC_ID_TSCC2; + case 165: + return AV_CODEC_ID_MTS2; + case 166: + return AV_CODEC_ID_CLLC; + case 167: + return AV_CODEC_ID_MSS2; + case 168: + return AV_CODEC_ID_VP9; + case 169: + return AV_CODEC_ID_AIC; + case 170: + return AV_CODEC_ID_ESCAPE130; + case 171: + return AV_CODEC_ID_G2M; + case 172: + return AV_CODEC_ID_WEBP; + case 173: + return AV_CODEC_ID_HNM4_VIDEO; + case 174: + return AV_CODEC_ID_HEVC; + case 175: + return AV_CODEC_ID_HEVC; + case 176: + return AV_CODEC_ID_FIC; + case 177: + return AV_CODEC_ID_ALIAS_PIX; + case 178: + return AV_CODEC_ID_BRENDER_PIX; + case 179: + return AV_CODEC_ID_PAF_VIDEO; + case 180: + return AV_CODEC_ID_EXR; + case 181: + return AV_CODEC_ID_VP7; + case 182: + return AV_CODEC_ID_SANM; + case 183: + return AV_CODEC_ID_SGIRLE; + case 184: + return AV_CODEC_ID_MVC1; + case 185: + return AV_CODEC_ID_MVC2; + case 186: + return AV_CODEC_ID_HQX; + case 187: + return AV_CODEC_ID_TDSC; + case 188: + return AV_CODEC_ID_HQ_HQA; + case 189: + return AV_CODEC_ID_HAP; + case 190: + return AV_CODEC_ID_DDS; + case 191: + return AV_CODEC_ID_DXV; + case 192: + return AV_CODEC_ID_SCREENPRESSO; + case 193: + return AV_CODEC_ID_RSCC; + /////////////////////////////// + // case 194: + // return AV_CODEC_ID_Y41P; + // case 194: + // return AV_CODEC_ID_AVS2; + case 194: + return AV_CODEC_ID_Y41P; + case 195: + return AV_CODEC_ID_AVRP; + case 196: + return AV_CODEC_ID_012V; + case 197: + return AV_CODEC_ID_AVUI; + case 198: + return AV_CODEC_ID_AYUV; + case 199: + return AV_CODEC_ID_TARGA_Y216; + case 200: + return AV_CODEC_ID_V308; + case 201: + return AV_CODEC_ID_V408; + case 202: + return AV_CODEC_ID_YUV4; + case 203: + return AV_CODEC_ID_AVRN; + case 204: + return AV_CODEC_ID_CPIA; + case 205: + return AV_CODEC_ID_XFACE; + case 206: + return AV_CODEC_ID_SNOW; + case 207: + return AV_CODEC_ID_SMVJPEG; + case 208: + return AV_CODEC_ID_APNG; + case 209: + return AV_CODEC_ID_DAALA; + case 210: + return AV_CODEC_ID_CFHD; + case 211: + return AV_CODEC_ID_TRUEMOTION2RT; + case 212: + return AV_CODEC_ID_M101; + case 213: + return AV_CODEC_ID_MAGICYUV; + case 214: + return AV_CODEC_ID_SHEERVIDEO; + case 215: + return AV_CODEC_ID_YLC; + case 216: + return AV_CODEC_ID_PCM_S16LE; + case 217: + return AV_CODEC_ID_PCM_S16BE; + case 218: + return AV_CODEC_ID_PCM_U16LE; + case 219: + return AV_CODEC_ID_PCM_U16BE; + case 220: + return AV_CODEC_ID_PCM_S8; + case 221: + return AV_CODEC_ID_PCM_U8; + case 222: + return AV_CODEC_ID_PCM_MULAW; + case 223: + return AV_CODEC_ID_PCM_ALAW; + case 224: + return AV_CODEC_ID_PCM_S32LE; + case 225: + return AV_CODEC_ID_PCM_S32BE; + case 226: + return AV_CODEC_ID_PCM_U32LE; + case 227: + return AV_CODEC_ID_PCM_U32BE; + case 228: + return AV_CODEC_ID_PCM_S24LE; + case 229: + return AV_CODEC_ID_PCM_S24BE; + case 230: + return AV_CODEC_ID_PCM_U24LE; + case 231: + return AV_CODEC_ID_PCM_U24BE; + case 232: + return AV_CODEC_ID_PCM_S24DAUD; + case 233: + return AV_CODEC_ID_PCM_ZORK; + case 234: + return AV_CODEC_ID_PCM_S16LE_PLANAR; + case 235: + return AV_CODEC_ID_PCM_DVD; + case 236: + return AV_CODEC_ID_PCM_F32BE; + case 237: + return AV_CODEC_ID_PCM_F32LE; + case 238: + return AV_CODEC_ID_PCM_F64BE; + case 239: + return AV_CODEC_ID_PCM_F64LE; + case 240: + return AV_CODEC_ID_PCM_BLURAY; + case 241: + return AV_CODEC_ID_PCM_LXF; + case 242: + return AV_CODEC_ID_S302M; + case 243: + return AV_CODEC_ID_PCM_S8_PLANAR; + case 244: + return AV_CODEC_ID_PCM_S24LE_PLANAR; + case 245: + return AV_CODEC_ID_PCM_S32LE_PLANAR; + case 246: + return AV_CODEC_ID_PCM_S16BE_PLANAR; + case 247: + return AV_CODEC_ID_PCM_S64LE; + case 248: + return AV_CODEC_ID_PCM_S64BE; + case 249: + return AV_CODEC_ID_ADPCM_IMA_QT; + case 250: + return AV_CODEC_ID_ADPCM_IMA_WAV; + case 251: + return AV_CODEC_ID_ADPCM_IMA_DK3; + case 252: + return AV_CODEC_ID_ADPCM_IMA_DK4; + case 253: + return AV_CODEC_ID_ADPCM_IMA_WS; + case 254: + return AV_CODEC_ID_ADPCM_IMA_SMJPEG; + case 255: + return AV_CODEC_ID_ADPCM_MS; + case 256: + return AV_CODEC_ID_ADPCM_4XM; + case 257: + return AV_CODEC_ID_ADPCM_XA; + case 258: + return AV_CODEC_ID_ADPCM_ADX; + case 259: + return AV_CODEC_ID_ADPCM_EA; + case 260: + return AV_CODEC_ID_ADPCM_G726; + case 261: + return AV_CODEC_ID_ADPCM_CT; + case 262: + return AV_CODEC_ID_ADPCM_SWF; + case 263: + return AV_CODEC_ID_ADPCM_YAMAHA; + case 264: + return AV_CODEC_ID_ADPCM_SBPRO_4; + case 265: + return AV_CODEC_ID_ADPCM_SBPRO_3; + case 266: + return AV_CODEC_ID_ADPCM_SBPRO_2; + case 267: + return AV_CODEC_ID_ADPCM_THP; + case 268: + return AV_CODEC_ID_ADPCM_IMA_AMV; + case 269: + return AV_CODEC_ID_ADPCM_EA_R1; + case 270: + return AV_CODEC_ID_ADPCM_EA_R3; + case 271: + return AV_CODEC_ID_ADPCM_EA_R2; + case 272: + return AV_CODEC_ID_ADPCM_IMA_EA_SEAD; + case 273: + return AV_CODEC_ID_ADPCM_IMA_EA_EACS; + case 274: + return AV_CODEC_ID_ADPCM_EA_XAS; + case 275: + return AV_CODEC_ID_ADPCM_EA_MAXIS_XA; + case 276: + return AV_CODEC_ID_ADPCM_IMA_ISS; + case 277: + return AV_CODEC_ID_ADPCM_G722; + case 278: + return AV_CODEC_ID_ADPCM_IMA_APC; + case 279: + return AV_CODEC_ID_ADPCM_VIMA; + case 280: + return AV_CODEC_ID_ADPCM_AFC; + case 281: + return AV_CODEC_ID_ADPCM_IMA_OKI; + case 282: + return AV_CODEC_ID_ADPCM_DTK; + case 283: + return AV_CODEC_ID_ADPCM_IMA_RAD; + case 284: + return AV_CODEC_ID_ADPCM_G726LE; + case 285: + return AV_CODEC_ID_ADPCM_THP_LE; + case 286: + return AV_CODEC_ID_ADPCM_PSX; + case 287: + return AV_CODEC_ID_ADPCM_AICA; + case 288: + return AV_CODEC_ID_ADPCM_IMA_DAT4; + case 289: + return AV_CODEC_ID_ADPCM_MTAF; + case 290: + return AV_CODEC_ID_AMR_NB; + case 291: + return AV_CODEC_ID_AMR_WB; + case 292: + return AV_CODEC_ID_RA_144; + case 293: + return AV_CODEC_ID_RA_288; + case 294: + return AV_CODEC_ID_ROQ_DPCM; + case 295: + return AV_CODEC_ID_INTERPLAY_DPCM; + case 296: + return AV_CODEC_ID_XAN_DPCM; + case 297: + return AV_CODEC_ID_SOL_DPCM; + case 298: + return AV_CODEC_ID_SDX2_DPCM; + case 299: + return AV_CODEC_ID_MP2; + case 300: + return AV_CODEC_ID_MP3; + case 301: + return AV_CODEC_ID_AAC; + case 302: + return AV_CODEC_ID_AC3; + case 303: + return AV_CODEC_ID_DTS; + case 304: + return AV_CODEC_ID_VORBIS; + case 305: + return AV_CODEC_ID_DVAUDIO; + case 306: + return AV_CODEC_ID_WMAV1; + case 307: + return AV_CODEC_ID_WMAV2; + case 308: + return AV_CODEC_ID_MACE3; + case 309: + return AV_CODEC_ID_MACE6; + case 310: + return AV_CODEC_ID_VMDAUDIO; + case 311: + return AV_CODEC_ID_FLAC; + case 312: + return AV_CODEC_ID_MP3ADU; + case 313: + return AV_CODEC_ID_MP3ON4; + case 314: + return AV_CODEC_ID_SHORTEN; + case 315: + return AV_CODEC_ID_ALAC; + case 316: + return AV_CODEC_ID_WESTWOOD_SND1; + case 317: + return AV_CODEC_ID_GSM; + case 318: + return AV_CODEC_ID_QDM2; + case 319: + return AV_CODEC_ID_COOK; + case 320: + return AV_CODEC_ID_TRUESPEECH; + case 321: + return AV_CODEC_ID_TTA; + case 322: + return AV_CODEC_ID_SMACKAUDIO; + case 323: + return AV_CODEC_ID_QCELP; + case 324: + return AV_CODEC_ID_WAVPACK; + case 325: + return AV_CODEC_ID_DSICINAUDIO; + case 326: + return AV_CODEC_ID_IMC; + case 327: + return AV_CODEC_ID_MUSEPACK7; + case 328: + return AV_CODEC_ID_MLP; + case 329: + return AV_CODEC_ID_GSM_MS; + case 330: + return AV_CODEC_ID_ATRAC3; + // #[cfg(feature = "ff_api_voxware")] + // case 331: + // return AV_CODEC_ID_VOXWARE; + case 332: + return AV_CODEC_ID_APE; + case 333: + return AV_CODEC_ID_NELLYMOSER; + case 334: + return AV_CODEC_ID_MUSEPACK8; + case 335: + return AV_CODEC_ID_SPEEX; + case 336: + return AV_CODEC_ID_WMAVOICE; + case 337: + return AV_CODEC_ID_WMAPRO; + case 338: + return AV_CODEC_ID_WMALOSSLESS; + case 339: + return AV_CODEC_ID_ATRAC3P; + case 340: + return AV_CODEC_ID_EAC3; + case 341: + return AV_CODEC_ID_SIPR; + case 342: + return AV_CODEC_ID_MP1; + case 343: + return AV_CODEC_ID_TWINVQ; + case 344: + return AV_CODEC_ID_TRUEHD; + case 345: + return AV_CODEC_ID_MP4ALS; + case 346: + return AV_CODEC_ID_ATRAC1; + case 347: + return AV_CODEC_ID_BINKAUDIO_RDFT; + case 348: + return AV_CODEC_ID_BINKAUDIO_DCT; + case 349: + return AV_CODEC_ID_AAC_LATM; + case 350: + return AV_CODEC_ID_QDMC; + case 351: + return AV_CODEC_ID_CELT; + case 352: + return AV_CODEC_ID_G723_1; + case 353: + return AV_CODEC_ID_G729; + case 354: + return AV_CODEC_ID_8SVX_EXP; + case 355: + return AV_CODEC_ID_8SVX_FIB; + case 356: + return AV_CODEC_ID_BMV_AUDIO; + case 357: + return AV_CODEC_ID_RALF; + case 358: + return AV_CODEC_ID_IAC; + case 359: + return AV_CODEC_ID_ILBC; + case 360: + return AV_CODEC_ID_OPUS; + case 361: + return AV_CODEC_ID_COMFORT_NOISE; + case 362: + return AV_CODEC_ID_TAK; + case 363: + return AV_CODEC_ID_METASOUND; + case 364: + return AV_CODEC_ID_PAF_AUDIO; + case 365: + return AV_CODEC_ID_ON2AVC; + case 366: + return AV_CODEC_ID_DSS_SP; + case 367: + return AV_CODEC_ID_CODEC2; + case 368: + return AV_CODEC_ID_FFWAVESYNTH; + case 369: + return AV_CODEC_ID_SONIC; + case 370: + return AV_CODEC_ID_SONIC_LS; + case 371: + return AV_CODEC_ID_EVRC; + case 372: + return AV_CODEC_ID_SMV; + case 373: + return AV_CODEC_ID_DSD_LSBF; + case 374: + return AV_CODEC_ID_DSD_MSBF; + case 375: + return AV_CODEC_ID_DSD_LSBF_PLANAR; + case 376: + return AV_CODEC_ID_DSD_MSBF_PLANAR; + case 377: + return AV_CODEC_ID_4GV; + case 378: + return AV_CODEC_ID_INTERPLAY_ACM; + case 379: + return AV_CODEC_ID_XMA1; + case 380: + return AV_CODEC_ID_XMA2; + case 381: + return AV_CODEC_ID_DST; + ///////////// + ///////////// + ///////////// + case 382: + return AV_CODEC_ID_DVD_SUBTITLE; + case 383: + return AV_CODEC_ID_DVB_SUBTITLE; + case 384: + return AV_CODEC_ID_TEXT; + case 385: + return AV_CODEC_ID_XSUB; + case 386: + return AV_CODEC_ID_SSA; + case 387: + return AV_CODEC_ID_MOV_TEXT; + case 388: + return AV_CODEC_ID_HDMV_PGS_SUBTITLE; + case 389: + return AV_CODEC_ID_DVB_TELETEXT; + case 390: + return AV_CODEC_ID_SRT; + case 391: + return AV_CODEC_ID_MICRODVD; + case 392: + return AV_CODEC_ID_EIA_608; + case 393: + return AV_CODEC_ID_JACOSUB; + case 394: + return AV_CODEC_ID_SAMI; + case 395: + return AV_CODEC_ID_REALTEXT; + case 396: + return AV_CODEC_ID_STL; + case 397: + return AV_CODEC_ID_SUBVIEWER1; + case 398: + return AV_CODEC_ID_SUBVIEWER; + case 399: + return AV_CODEC_ID_SUBRIP; + case 400: + return AV_CODEC_ID_WEBVTT; + case 401: + return AV_CODEC_ID_MPL2; + case 402: + return AV_CODEC_ID_VPLAYER; + case 403: + return AV_CODEC_ID_PJS; + case 404: + return AV_CODEC_ID_ASS; + case 405: + return AV_CODEC_ID_HDMV_TEXT_SUBTITLE; + case 406: + return AV_CODEC_ID_TTF; + case 407: + return AV_CODEC_ID_SCTE_35; + case 408: + return AV_CODEC_ID_BINTEXT; + case 409: + return AV_CODEC_ID_XBIN; + case 410: + return AV_CODEC_ID_IDF; + case 411: + return AV_CODEC_ID_OTF; + case 412: + return AV_CODEC_ID_SMPTE_KLV; + case 413: + return AV_CODEC_ID_DVD_NAV; + case 414: + return AV_CODEC_ID_TIMED_ID3; + case 415: + return AV_CODEC_ID_BIN_DATA; + case 416: + return AV_CODEC_ID_PROBE; + case 417: + return AV_CODEC_ID_MPEG2TS; + case 418: + return AV_CODEC_ID_MPEG4SYSTEMS; + case 419: + return AV_CODEC_ID_FFMETADATA; + case 420: + return AV_CODEC_ID_WRAPPED_AVFRAME; + case 421: + return AV_CODEC_ID_PSD; + case 422: + return AV_CODEC_ID_PIXLET; + case 423: + return AV_CODEC_ID_SPEEDHQ; + case 424: + return AV_CODEC_ID_CLEARVIDEO; + case 425: + return AV_CODEC_ID_FMVC; + case 426: + return AV_CODEC_ID_SCPR; + case 427: + return AV_CODEC_ID_XPM; + case 428: + return AV_CODEC_ID_AV1; + case 429: + return AV_CODEC_ID_PCM_F16LE; + case 430: + return AV_CODEC_ID_PCM_F24LE; + //////////// + case 431: + return AV_CODEC_ID_ATRAC3AL; + case 432: + return AV_CODEC_ID_ATRAC3PAL; + case 433: + return AV_CODEC_ID_BITPACKED; + case 434: + return AV_CODEC_ID_MSCC; + case 435: + return AV_CODEC_ID_SRGC; + case 436: + return AV_CODEC_ID_SVG; + case 437: + return AV_CODEC_ID_GDV; + case 438: + return AV_CODEC_ID_FITS; + case 439: + return AV_CODEC_ID_GREMLIN_DPCM; + case 440: + return AV_CODEC_ID_DOLBY_E; + case 441: + return AV_CODEC_ID_APTX; + case 442: + return AV_CODEC_ID_APTX_HD; + case 443: + return AV_CODEC_ID_SBC; + case 444: + return AV_CODEC_ID_AVS2; + case 445: + return AV_CODEC_ID_IMM4; + case 446: + return AV_CODEC_ID_PROSUMER; + case 447: + return AV_CODEC_ID_MWSC; + case 448: + return AV_CODEC_ID_WCMV; + case 449: + return AV_CODEC_ID_RASC; + case 450: + return AV_CODEC_ID_PCM_VIDC; + case 451: + return AV_CODEC_ID_ATRAC9; + case 452: + return AV_CODEC_ID_TTML; + case 453: + return AV_CODEC_ID_HYMT; + case 454: + return AV_CODEC_ID_ARBC; + case 455: + return AV_CODEC_ID_AGM; + case 456: + return AV_CODEC_ID_LSCR; + case 457: + return AV_CODEC_ID_VP4; + case 458: + return AV_CODEC_ID_ADPCM_AGM; + case 459: + return AV_CODEC_ID_HCOM; + case 460: + return AV_CODEC_ID_ARIB_CAPTION; + case 461: + return AV_CODEC_ID_IMM5; + case 462: + return AV_CODEC_ID_MVDV; + case 463: + return AV_CODEC_ID_MVHA; + case 464: + return AV_CODEC_ID_CDTOONS; + case 465: + return AV_CODEC_ID_MV30; + case 466: + return AV_CODEC_ID_NOTCHLC; + case 467: + return AV_CODEC_ID_PFM; + case 468: + return AV_CODEC_ID_ARGO; + case 469: + return AV_CODEC_ID_ADPCM_IMA_SSI; + case 470: + return AV_CODEC_ID_ADPCM_ZORK; + case 471: + return AV_CODEC_ID_ADPCM_IMA_APM; + case 472: + return AV_CODEC_ID_ADPCM_IMA_ALP; + case 473: + return AV_CODEC_ID_ADPCM_IMA_MTF; + case 474: + return AV_CODEC_ID_ADPCM_IMA_CUNNING; + case 475: + return AV_CODEC_ID_DERF_DPCM; + case 476: + return AV_CODEC_ID_ACELP_KELVIN; + case 477: + return AV_CODEC_ID_MPEGH_3D_AUDIO; + case 478: + return AV_CODEC_ID_SIREN; + case 479: + return AV_CODEC_ID_HCA; + case 480: + return AV_CODEC_ID_EPG; + case 481: + return AV_CODEC_ID_AVS3; + case 482: + return AV_CODEC_ID_PGX; + case 483: + return AV_CODEC_ID_MSP2; + case 484: + return AV_CODEC_ID_VVC; + case 485: + return AV_CODEC_ID_MOBICLIP; + case 486: + return AV_CODEC_ID_PHOTOCD; + case 487: + return AV_CODEC_ID_ADPCM_ARGO; + case 488: + return AV_CODEC_ID_CRI; + case 489: + return AV_CODEC_ID_IPU; + case 490: + return AV_CODEC_ID_SIMBIOSIS_IMX; + case 491: + return AV_CODEC_ID_SGA_VIDEO; + case 492: + return AV_CODEC_ID_PCM_SGA; + case 493: + return AV_CODEC_ID_ADPCM_IMA_MOFLEX; + case 494: + return AV_CODEC_ID_FASTAUDIO; + case 495: + return AV_CODEC_ID_GEM; + case 496: + return AV_CODEC_ID_ADPCM_IMA_ACORN; + case 497: + return AV_CODEC_ID_MSNSIREN; + case 498: + return AV_CODEC_ID_VBN; + case 499: + return AV_CODEC_ID_JPEGXL; + case 500: + return AV_CODEC_ID_QOI; + case 501: + return AV_CODEC_ID_PHM; + case 502: + return AV_CODEC_ID_DFPWM; + case 503: + return AV_CODEC_ID_RADIANCE_HDR; + case 504: + return AV_CODEC_ID_WBMP; + case 505: + return AV_CODEC_ID_MEDIA100; + case 506: + return AV_CODEC_ID_VQC; + case 507: + return AV_CODEC_ID_ADPCM_XMD; + case 508: + return AV_CODEC_ID_WADY_DPCM; + case 509: + return AV_CODEC_ID_CBD2_DPCM; + case 510: + return AV_CODEC_ID_BONK; + case 511: + return AV_CODEC_ID_MISC4; + case 512: + return AV_CODEC_ID_APAC; + case 513: + return AV_CODEC_ID_FTR; + case 514: + return AV_CODEC_ID_WAVARC; + case 515: + return AV_CODEC_ID_RKA; + case 516: + return AV_CODEC_ID_VNULL; + case 517: + return AV_CODEC_ID_ANULL; + // case 518: + // return AV_CODEC_ID_MPEG2VIDEO_XVMC; + default: + return AV_CODEC_ID_NONE; + }; + } + + // Convert AVCodecID to uint32_t for rust SDK. + static uint32_t fromAVCodecID(AVCodecID AvCodecId) { + switch (AvCodecId) { + case AV_CODEC_ID_NONE: + return 0; + case AV_CODEC_ID_MPEG1VIDEO: + return 1; + case AV_CODEC_ID_MPEG2VIDEO: + return 2; + case AV_CODEC_ID_H261: + return 3; + case AV_CODEC_ID_H263: + return 4; + case AV_CODEC_ID_RV10: + return 5; + case AV_CODEC_ID_RV20: + return 6; + case AV_CODEC_ID_MJPEG: + return 7; + case AV_CODEC_ID_MJPEGB: + return 8; + case AV_CODEC_ID_LJPEG: + return 9; + case AV_CODEC_ID_SP5X: + return 10; + case AV_CODEC_ID_JPEGLS: + return 11; + case AV_CODEC_ID_MPEG4: + return 12; + case AV_CODEC_ID_RAWVIDEO: + return 13; + case AV_CODEC_ID_MSMPEG4V1: + return 14; + case AV_CODEC_ID_MSMPEG4V2: + return 15; + case AV_CODEC_ID_MSMPEG4V3: + return 16; + case AV_CODEC_ID_WMV1: + return 17; + case AV_CODEC_ID_WMV2: + return 18; + case AV_CODEC_ID_H263P: + return 19; + case AV_CODEC_ID_H263I: + return 20; + case AV_CODEC_ID_FLV1: + return 21; + case AV_CODEC_ID_SVQ1: + return 22; + case AV_CODEC_ID_SVQ3: + return 23; + case AV_CODEC_ID_DVVIDEO: + return 24; + case AV_CODEC_ID_HUFFYUV: + return 25; + case AV_CODEC_ID_CYUV: + return 26; + case AV_CODEC_ID_H264: + return 27; + case AV_CODEC_ID_INDEO3: + return 28; + case AV_CODEC_ID_VP3: + return 29; + case AV_CODEC_ID_THEORA: + return 30; + case AV_CODEC_ID_ASV1: + return 31; + case AV_CODEC_ID_ASV2: + return 32; + case AV_CODEC_ID_FFV1: + return 33; + case AV_CODEC_ID_4XM: + return 34; + case AV_CODEC_ID_VCR1: + return 35; + case AV_CODEC_ID_CLJR: + return 36; + case AV_CODEC_ID_MDEC: + return 37; + case AV_CODEC_ID_ROQ: + return 38; + case AV_CODEC_ID_INTERPLAY_VIDEO: + return 39; + case AV_CODEC_ID_XAN_WC3: + return 40; + case AV_CODEC_ID_XAN_WC4: + return 41; + case AV_CODEC_ID_RPZA: + return 42; + case AV_CODEC_ID_CINEPAK: + return 43; + case AV_CODEC_ID_WS_VQA: + return 44; + case AV_CODEC_ID_MSRLE: + return 45; + case AV_CODEC_ID_MSVIDEO1: + return 46; + case AV_CODEC_ID_IDCIN: + return 47; + case AV_CODEC_ID_8BPS: + return 48; + case AV_CODEC_ID_SMC: + return 49; + case AV_CODEC_ID_FLIC: + return 50; + case AV_CODEC_ID_TRUEMOTION1: + return 51; + case AV_CODEC_ID_VMDVIDEO: + return 52; + case AV_CODEC_ID_MSZH: + return 53; + case AV_CODEC_ID_ZLIB: + return 54; + case AV_CODEC_ID_QTRLE: + return 55; + case AV_CODEC_ID_TSCC: + return 56; + case AV_CODEC_ID_ULTI: + return 57; + case AV_CODEC_ID_QDRAW: + return 58; + case AV_CODEC_ID_VIXL: + return 59; + case AV_CODEC_ID_QPEG: + return 60; + case AV_CODEC_ID_PNG: + return 61; + case AV_CODEC_ID_PPM: + return 62; + case AV_CODEC_ID_PBM: + return 63; + case AV_CODEC_ID_PGM: + return 64; + case AV_CODEC_ID_PGMYUV: + return 65; + case AV_CODEC_ID_PAM: + return 66; + case AV_CODEC_ID_FFVHUFF: + return 67; + case AV_CODEC_ID_RV30: + return 68; + case AV_CODEC_ID_RV40: + return 69; + case AV_CODEC_ID_VC1: + return 70; + case AV_CODEC_ID_WMV3: + return 71; + case AV_CODEC_ID_LOCO: + return 72; + case AV_CODEC_ID_WNV1: + return 73; + case AV_CODEC_ID_AASC: + return 74; + case AV_CODEC_ID_INDEO2: + return 75; + case AV_CODEC_ID_FRAPS: + return 76; + case AV_CODEC_ID_TRUEMOTION2: + return 77; + case AV_CODEC_ID_BMP: + return 78; + case AV_CODEC_ID_CSCD: + return 79; + case AV_CODEC_ID_MMVIDEO: + return 80; + case AV_CODEC_ID_ZMBV: + return 81; + case AV_CODEC_ID_AVS: + return 82; + case AV_CODEC_ID_SMACKVIDEO: + return 83; + case AV_CODEC_ID_NUV: + return 84; + case AV_CODEC_ID_KMVC: + return 85; + case AV_CODEC_ID_FLASHSV: + return 86; + case AV_CODEC_ID_CAVS: + return 87; + case AV_CODEC_ID_JPEG2000: + return 88; + case AV_CODEC_ID_VMNC: + return 89; + case AV_CODEC_ID_VP5: + return 90; + case AV_CODEC_ID_VP6: + return 91; + case AV_CODEC_ID_VP6F: + return 92; + case AV_CODEC_ID_TARGA: + return 93; + case AV_CODEC_ID_DSICINVIDEO: + return 94; + case AV_CODEC_ID_TIERTEXSEQVIDEO: + return 95; + case AV_CODEC_ID_TIFF: + return 96; + case AV_CODEC_ID_GIF: + return 97; + case AV_CODEC_ID_DXA: + return 98; + case AV_CODEC_ID_DNXHD: + return 99; + case AV_CODEC_ID_THP: + return 100; + case AV_CODEC_ID_SGI: + return 101; + case AV_CODEC_ID_C93: + return 102; + case AV_CODEC_ID_BETHSOFTVID: + return 103; + case AV_CODEC_ID_PTX: + return 104; + case AV_CODEC_ID_TXD: + return 105; + case AV_CODEC_ID_VP6A: + return 106; + case AV_CODEC_ID_AMV: + return 107; + case AV_CODEC_ID_VB: + return 108; + case AV_CODEC_ID_PCX: + return 109; + case AV_CODEC_ID_SUNRAST: + return 110; + case AV_CODEC_ID_INDEO4: + return 111; + case AV_CODEC_ID_INDEO5: + return 112; + case AV_CODEC_ID_MIMIC: + return 113; + case AV_CODEC_ID_RL2: + return 114; + case AV_CODEC_ID_ESCAPE124: + return 115; + case AV_CODEC_ID_DIRAC: + return 116; + case AV_CODEC_ID_BFI: + return 117; + case AV_CODEC_ID_CMV: + return 118; + case AV_CODEC_ID_MOTIONPIXELS: + return 119; + case AV_CODEC_ID_TGV: + return 120; + case AV_CODEC_ID_TGQ: + return 121; + case AV_CODEC_ID_TQI: + return 122; + case AV_CODEC_ID_AURA: + return 123; + case AV_CODEC_ID_AURA2: + return 124; + case AV_CODEC_ID_V210X: + return 125; + case AV_CODEC_ID_TMV: + return 126; + case AV_CODEC_ID_V210: + return 127; + case AV_CODEC_ID_DPX: + return 128; + case AV_CODEC_ID_MAD: + return 129; + case AV_CODEC_ID_FRWU: + return 130; + case AV_CODEC_ID_FLASHSV2: + return 131; + case AV_CODEC_ID_CDGRAPHICS: + return 132; + case AV_CODEC_ID_R210: + return 133; + case AV_CODEC_ID_ANM: + return 134; + case AV_CODEC_ID_BINKVIDEO: + return 135; + case AV_CODEC_ID_IFF_ILBM: + return 136; + // case AV_CODEC_ID_IFF_ILBM: + // return 137; + case AV_CODEC_ID_KGV1: + return 138; + case AV_CODEC_ID_YOP: + return 139; + case AV_CODEC_ID_VP8: + return 140; + case AV_CODEC_ID_PICTOR: + return 141; + case AV_CODEC_ID_ANSI: + return 142; + case AV_CODEC_ID_A64_MULTI: + return 143; + case AV_CODEC_ID_A64_MULTI5: + return 144; + case AV_CODEC_ID_R10K: + return 145; + case AV_CODEC_ID_MXPEG: + return 146; + case AV_CODEC_ID_LAGARITH: + return 147; + case AV_CODEC_ID_PRORES: + return 148; + case AV_CODEC_ID_JV: + return 149; + case AV_CODEC_ID_DFA: + return 150; + case AV_CODEC_ID_WMV3IMAGE: + return 151; + case AV_CODEC_ID_VC1IMAGE: + return 152; + case AV_CODEC_ID_UTVIDEO: + return 153; + case AV_CODEC_ID_BMV_VIDEO: + return 154; + case AV_CODEC_ID_VBLE: + return 155; + case AV_CODEC_ID_DXTORY: + return 156; + case AV_CODEC_ID_V410: + return 157; + case AV_CODEC_ID_XWD: + return 158; + case AV_CODEC_ID_CDXL: + return 159; + case AV_CODEC_ID_XBM: + return 160; + case AV_CODEC_ID_ZEROCODEC: + return 161; + case AV_CODEC_ID_MSS1: + return 162; + case AV_CODEC_ID_MSA1: + return 163; + case AV_CODEC_ID_TSCC2: + return 164; + case AV_CODEC_ID_MTS2: + return 165; + case AV_CODEC_ID_CLLC: + return 166; + case AV_CODEC_ID_MSS2: + return 167; + case AV_CODEC_ID_VP9: + return 168; + case AV_CODEC_ID_AIC: + return 169; + case AV_CODEC_ID_ESCAPE130: + return 170; + case AV_CODEC_ID_G2M: + return 171; + case AV_CODEC_ID_WEBP: + return 172; + case AV_CODEC_ID_HNM4_VIDEO: + return 173; + case AV_CODEC_ID_HEVC: + return 174; + // case AV_CODEC_ID_HEVC: + // return 175; + case AV_CODEC_ID_FIC: + return 176; + case AV_CODEC_ID_ALIAS_PIX: + return 177; + case AV_CODEC_ID_BRENDER_PIX: + return 178; + case AV_CODEC_ID_PAF_VIDEO: + return 179; + case AV_CODEC_ID_EXR: + return 180; + case AV_CODEC_ID_VP7: + return 181; + case AV_CODEC_ID_SANM: + return 182; + case AV_CODEC_ID_SGIRLE: + return 183; + case AV_CODEC_ID_MVC1: + return 184; + case AV_CODEC_ID_MVC2: + return 185; + case AV_CODEC_ID_HQX: + return 186; + case AV_CODEC_ID_TDSC: + return 187; + case AV_CODEC_ID_HQ_HQA: + return 188; + case AV_CODEC_ID_HAP: + return 189; + case AV_CODEC_ID_DDS: + return 190; + case AV_CODEC_ID_DXV: + return 191; + case AV_CODEC_ID_SCREENPRESSO: + return 192; + case AV_CODEC_ID_RSCC: + return 193; + /////////////////////////////// + // return ; + // case AV_CODEC_ID_Y41P: + // return ; + // case AV_CODEC_ID_AVS2: + case AV_CODEC_ID_Y41P: + return 194; + case AV_CODEC_ID_AVRP: + return 195; + case AV_CODEC_ID_012V: + return 196; + case AV_CODEC_ID_AVUI: + return 197; + case AV_CODEC_ID_AYUV: + return 198; + case AV_CODEC_ID_TARGA_Y216: + return 199; + case AV_CODEC_ID_V308: + return 200; + case AV_CODEC_ID_V408: + return 201; + case AV_CODEC_ID_YUV4: + return 202; + case AV_CODEC_ID_AVRN: + return 203; + case AV_CODEC_ID_CPIA: + return 204; + case AV_CODEC_ID_XFACE: + return 205; + case AV_CODEC_ID_SNOW: + return 206; + case AV_CODEC_ID_SMVJPEG: + return 207; + case AV_CODEC_ID_APNG: + return 208; + case AV_CODEC_ID_DAALA: + return 209; + case AV_CODEC_ID_CFHD: + return 210; + case AV_CODEC_ID_TRUEMOTION2RT: + return 211; + case AV_CODEC_ID_M101: + return 212; + case AV_CODEC_ID_MAGICYUV: + return 213; + case AV_CODEC_ID_SHEERVIDEO: + return 214; + case AV_CODEC_ID_YLC: + return 215; + // ================================= + // ================================= + // ================================= + case AV_CODEC_ID_PCM_S16LE: + return 216; + case AV_CODEC_ID_PCM_S16BE: + return 217; + case AV_CODEC_ID_PCM_U16LE: + return 218; + case AV_CODEC_ID_PCM_U16BE: + return 219; + case AV_CODEC_ID_PCM_S8: + return 220; + case AV_CODEC_ID_PCM_U8: + return 221; + case AV_CODEC_ID_PCM_MULAW: + return 222; + case AV_CODEC_ID_PCM_ALAW: + return 223; + case AV_CODEC_ID_PCM_S32LE: + return 224; + case AV_CODEC_ID_PCM_S32BE: + return 225; + case AV_CODEC_ID_PCM_U32LE: + return 226; + case AV_CODEC_ID_PCM_U32BE: + return 227; + case AV_CODEC_ID_PCM_S24LE: + return 228; + case AV_CODEC_ID_PCM_S24BE: + return 229; + case AV_CODEC_ID_PCM_U24LE: + return 230; + case AV_CODEC_ID_PCM_U24BE: + return 231; + case AV_CODEC_ID_PCM_S24DAUD: + return 232; + case AV_CODEC_ID_PCM_ZORK: + return 233; + case AV_CODEC_ID_PCM_S16LE_PLANAR: + return 234; + case AV_CODEC_ID_PCM_DVD: + return 235; + case AV_CODEC_ID_PCM_F32BE: + return 236; + case AV_CODEC_ID_PCM_F32LE: + return 237; + case AV_CODEC_ID_PCM_F64BE: + return 238; + case AV_CODEC_ID_PCM_F64LE: + return 239; + case AV_CODEC_ID_PCM_BLURAY: + return 240; + case AV_CODEC_ID_PCM_LXF: + return 241; + case AV_CODEC_ID_S302M: + return 242; + case AV_CODEC_ID_PCM_S8_PLANAR: + return 243; + case AV_CODEC_ID_PCM_S24LE_PLANAR: + return 244; + case AV_CODEC_ID_PCM_S32LE_PLANAR: + return 245; + case AV_CODEC_ID_PCM_S16BE_PLANAR: + return 246; + case AV_CODEC_ID_PCM_S64LE: + return 247; + case AV_CODEC_ID_PCM_S64BE: + return 248; + case AV_CODEC_ID_ADPCM_IMA_QT: + return 249; + case AV_CODEC_ID_ADPCM_IMA_WAV: + return 250; + case AV_CODEC_ID_ADPCM_IMA_DK3: + return 251; + case AV_CODEC_ID_ADPCM_IMA_DK4: + return 252; + case AV_CODEC_ID_ADPCM_IMA_WS: + return 253; + case AV_CODEC_ID_ADPCM_IMA_SMJPEG: + return 254; + case AV_CODEC_ID_ADPCM_MS: + return 255; + case AV_CODEC_ID_ADPCM_4XM: + return 256; + case AV_CODEC_ID_ADPCM_XA: + return 257; + case AV_CODEC_ID_ADPCM_ADX: + return 258; + case AV_CODEC_ID_ADPCM_EA: + return 259; + case AV_CODEC_ID_ADPCM_G726: + return 260; + case AV_CODEC_ID_ADPCM_CT: + return 261; + case AV_CODEC_ID_ADPCM_SWF: + return 262; + case AV_CODEC_ID_ADPCM_YAMAHA: + return 263; + case AV_CODEC_ID_ADPCM_SBPRO_4: + return 264; + case AV_CODEC_ID_ADPCM_SBPRO_3: + return 265; + case AV_CODEC_ID_ADPCM_SBPRO_2: + return 266; + case AV_CODEC_ID_ADPCM_THP: + return 267; + case AV_CODEC_ID_ADPCM_IMA_AMV: + return 268; + case AV_CODEC_ID_ADPCM_EA_R1: + return 269; + case AV_CODEC_ID_ADPCM_EA_R3: + return 270; + case AV_CODEC_ID_ADPCM_EA_R2: + return 271; + case AV_CODEC_ID_ADPCM_IMA_EA_SEAD: + return 272; + case AV_CODEC_ID_ADPCM_IMA_EA_EACS: + return 273; + case AV_CODEC_ID_ADPCM_EA_XAS: + return 274; + case AV_CODEC_ID_ADPCM_EA_MAXIS_XA: + return 275; + case AV_CODEC_ID_ADPCM_IMA_ISS: + return 276; + case AV_CODEC_ID_ADPCM_G722: + return 277; + case AV_CODEC_ID_ADPCM_IMA_APC: + return 278; + case AV_CODEC_ID_ADPCM_VIMA: + return 279; + case AV_CODEC_ID_ADPCM_AFC: + return 280; + case AV_CODEC_ID_ADPCM_IMA_OKI: + return 281; + case AV_CODEC_ID_ADPCM_DTK: + return 282; + case AV_CODEC_ID_ADPCM_IMA_RAD: + return 283; + case AV_CODEC_ID_ADPCM_G726LE: + return 284; + case AV_CODEC_ID_ADPCM_THP_LE: + return 285; + case AV_CODEC_ID_ADPCM_PSX: + return 286; + case AV_CODEC_ID_ADPCM_AICA: + return 287; + case AV_CODEC_ID_ADPCM_IMA_DAT4: + return 288; + case AV_CODEC_ID_ADPCM_MTAF: + return 289; + case AV_CODEC_ID_AMR_NB: + return 290; + case AV_CODEC_ID_AMR_WB: + return 291; + case AV_CODEC_ID_RA_144: + return 292; + case AV_CODEC_ID_RA_288: + return 293; + case AV_CODEC_ID_ROQ_DPCM: + return 294; + case AV_CODEC_ID_INTERPLAY_DPCM: + return 295; + case AV_CODEC_ID_XAN_DPCM: + return 296; + case AV_CODEC_ID_SOL_DPCM: + return 297; + case AV_CODEC_ID_SDX2_DPCM: + return 298; + case AV_CODEC_ID_MP2: + return 299; + case AV_CODEC_ID_MP3: + return 300; + case AV_CODEC_ID_AAC: + return 301; + case AV_CODEC_ID_AC3: + return 302; + case AV_CODEC_ID_DTS: + return 303; + case AV_CODEC_ID_VORBIS: + return 304; + case AV_CODEC_ID_DVAUDIO: + return 305; + case AV_CODEC_ID_WMAV1: + return 306; + case AV_CODEC_ID_WMAV2: + return 307; + case AV_CODEC_ID_MACE3: + return 308; + case AV_CODEC_ID_MACE6: + return 309; + case AV_CODEC_ID_VMDAUDIO: + return 310; + case AV_CODEC_ID_FLAC: + return 311; + case AV_CODEC_ID_MP3ADU: + return 312; + case AV_CODEC_ID_MP3ON4: + return 313; + case AV_CODEC_ID_SHORTEN: + return 314; + case AV_CODEC_ID_ALAC: + return 315; + case AV_CODEC_ID_WESTWOOD_SND1: + return 316; + case AV_CODEC_ID_GSM: + return 317; + case AV_CODEC_ID_QDM2: + return 318; + case AV_CODEC_ID_COOK: + return 319; + case AV_CODEC_ID_TRUESPEECH: + return 320; + case AV_CODEC_ID_TTA: + return 321; + case AV_CODEC_ID_SMACKAUDIO: + return 322; + case AV_CODEC_ID_QCELP: + return 323; + case AV_CODEC_ID_WAVPACK: + return 324; + case AV_CODEC_ID_DSICINAUDIO: + return 325; + case AV_CODEC_ID_IMC: + return 326; + case AV_CODEC_ID_MUSEPACK7: + return 327; + case AV_CODEC_ID_MLP: + return 328; + case AV_CODEC_ID_GSM_MS: + return 329; + case AV_CODEC_ID_ATRAC3: + return 330; + // #[cfg(feature = "ff_api_voxware")] + // case AV_CODEC_ID_VOXWARE: + // return 331; + case AV_CODEC_ID_APE: + return 332; + case AV_CODEC_ID_NELLYMOSER: + return 333; + case AV_CODEC_ID_MUSEPACK8: + return 334; + case AV_CODEC_ID_SPEEX: + return 335; + case AV_CODEC_ID_WMAVOICE: + return 336; + case AV_CODEC_ID_WMAPRO: + return 337; + case AV_CODEC_ID_WMALOSSLESS: + return 338; + case AV_CODEC_ID_ATRAC3P: + return 339; + case AV_CODEC_ID_EAC3: + return 340; + case AV_CODEC_ID_SIPR: + return 341; + case AV_CODEC_ID_MP1: + return 342; + case AV_CODEC_ID_TWINVQ: + return 343; + case AV_CODEC_ID_TRUEHD: + return 344; + case AV_CODEC_ID_MP4ALS: + return 345; + case AV_CODEC_ID_ATRAC1: + return 346; + case AV_CODEC_ID_BINKAUDIO_RDFT: + return 347; + case AV_CODEC_ID_BINKAUDIO_DCT: + return 348; + case AV_CODEC_ID_AAC_LATM: + return 349; + case AV_CODEC_ID_QDMC: + return 350; + case AV_CODEC_ID_CELT: + return 351; + case AV_CODEC_ID_G723_1: + return 352; + case AV_CODEC_ID_G729: + return 353; + case AV_CODEC_ID_8SVX_EXP: + return 354; + case AV_CODEC_ID_8SVX_FIB: + return 355; + case AV_CODEC_ID_BMV_AUDIO: + return 356; + case AV_CODEC_ID_RALF: + return 357; + case AV_CODEC_ID_IAC: + return 358; + case AV_CODEC_ID_ILBC: + return 359; + case AV_CODEC_ID_OPUS: + return 360; + case AV_CODEC_ID_COMFORT_NOISE: + return 361; + case AV_CODEC_ID_TAK: + return 362; + case AV_CODEC_ID_METASOUND: + return 363; + case AV_CODEC_ID_PAF_AUDIO: + return 364; + case AV_CODEC_ID_ON2AVC: + return 365; + case AV_CODEC_ID_DSS_SP: + return 366; + case AV_CODEC_ID_CODEC2: + return 367; + case AV_CODEC_ID_FFWAVESYNTH: + return 368; + case AV_CODEC_ID_SONIC: + return 369; + case AV_CODEC_ID_SONIC_LS: + return 370; + case AV_CODEC_ID_EVRC: + return 371; + case AV_CODEC_ID_SMV: + return 372; + case AV_CODEC_ID_DSD_LSBF: + return 373; + case AV_CODEC_ID_DSD_MSBF: + return 374; + case AV_CODEC_ID_DSD_LSBF_PLANAR: + return 375; + case AV_CODEC_ID_DSD_MSBF_PLANAR: + return 376; + case AV_CODEC_ID_4GV: + return 377; + case AV_CODEC_ID_INTERPLAY_ACM: + return 378; + case AV_CODEC_ID_XMA1: + return 379; + case AV_CODEC_ID_XMA2: + return 380; + case AV_CODEC_ID_DST: + return 381; + case AV_CODEC_ID_DVD_SUBTITLE: + return 382; + case AV_CODEC_ID_DVB_SUBTITLE: + return 383; + case AV_CODEC_ID_TEXT: + return 384; + case AV_CODEC_ID_XSUB: + return 385; + case AV_CODEC_ID_SSA: + return 386; + case AV_CODEC_ID_MOV_TEXT: + return 387; + case AV_CODEC_ID_HDMV_PGS_SUBTITLE: + return 388; + case AV_CODEC_ID_DVB_TELETEXT: + return 389; + case AV_CODEC_ID_SRT: + return 390; + case AV_CODEC_ID_MICRODVD: + return 391; + case AV_CODEC_ID_EIA_608: + return 392; + case AV_CODEC_ID_JACOSUB: + return 393; + case AV_CODEC_ID_SAMI: + return 394; + case AV_CODEC_ID_REALTEXT: + return 395; + case AV_CODEC_ID_STL: + return 396; + case AV_CODEC_ID_SUBVIEWER1: + return 397; + case AV_CODEC_ID_SUBVIEWER: + return 398; + case AV_CODEC_ID_SUBRIP: + return 399; + case AV_CODEC_ID_WEBVTT: + return 400; + case AV_CODEC_ID_MPL2: + return 401; + case AV_CODEC_ID_VPLAYER: + return 402; + case AV_CODEC_ID_PJS: + return 403; + case AV_CODEC_ID_ASS: + return 404; + case AV_CODEC_ID_HDMV_TEXT_SUBTITLE: + return 405; + case AV_CODEC_ID_TTF: + return 406; + case AV_CODEC_ID_SCTE_35: + return 407; + case AV_CODEC_ID_BINTEXT: + return 408; + case AV_CODEC_ID_XBIN: + return 409; + case AV_CODEC_ID_IDF: + return 410; + case AV_CODEC_ID_OTF: + return 411; + case AV_CODEC_ID_SMPTE_KLV: + return 412; + case AV_CODEC_ID_DVD_NAV: + return 413; + case AV_CODEC_ID_TIMED_ID3: + return 414; + case AV_CODEC_ID_BIN_DATA: + return 415; + case AV_CODEC_ID_PROBE: + return 416; + case AV_CODEC_ID_MPEG2TS: + return 417; + case AV_CODEC_ID_MPEG4SYSTEMS: + return 418; + case AV_CODEC_ID_FFMETADATA: + return 419; + case AV_CODEC_ID_WRAPPED_AVFRAME: + return 420; + case AV_CODEC_ID_PSD: + return 421; + case AV_CODEC_ID_PIXLET: + return 422; + case AV_CODEC_ID_SPEEDHQ: + return 423; + case AV_CODEC_ID_CLEARVIDEO: + return 424; + case AV_CODEC_ID_FMVC: + return 425; + case AV_CODEC_ID_SCPR: + return 426; + case AV_CODEC_ID_XPM: + return 427; + case AV_CODEC_ID_AV1: + return 428; + case AV_CODEC_ID_PCM_F16LE: + return 429; + case AV_CODEC_ID_PCM_F24LE: + return 430; + //////////// + case AV_CODEC_ID_ATRAC3AL: + return 431; + case AV_CODEC_ID_ATRAC3PAL: + return 432; + case AV_CODEC_ID_BITPACKED: + return 433; + case AV_CODEC_ID_MSCC: + return 434; + case AV_CODEC_ID_SRGC: + return 435; + case AV_CODEC_ID_SVG: + return 436; + case AV_CODEC_ID_GDV: + return 437; + case AV_CODEC_ID_FITS: + return 438; + case AV_CODEC_ID_GREMLIN_DPCM: + return 439; + case AV_CODEC_ID_DOLBY_E: + return 440; + case AV_CODEC_ID_APTX: + return 441; + case AV_CODEC_ID_APTX_HD: + return 442; + case AV_CODEC_ID_SBC: + return 443; + case AV_CODEC_ID_AVS2: + return 444; + case AV_CODEC_ID_IMM4: + return 445; + case AV_CODEC_ID_PROSUMER: + return 446; + case AV_CODEC_ID_MWSC: + return 447; + case AV_CODEC_ID_WCMV: + return 448; + case AV_CODEC_ID_RASC: + return 449; + case AV_CODEC_ID_PCM_VIDC: + return 450; + case AV_CODEC_ID_ATRAC9: + return 451; + case AV_CODEC_ID_TTML: + return 452; + case AV_CODEC_ID_HYMT: + return 453; + case AV_CODEC_ID_ARBC: + return 454; + case AV_CODEC_ID_AGM: + return 455; + case AV_CODEC_ID_LSCR: + return 456; + case AV_CODEC_ID_VP4: + return 457; + case AV_CODEC_ID_ADPCM_AGM: + return 458; + case AV_CODEC_ID_HCOM: + return 459; + case AV_CODEC_ID_ARIB_CAPTION: + return 460; + case AV_CODEC_ID_IMM5: + return 461; + case AV_CODEC_ID_MVDV: + return 462; + case AV_CODEC_ID_MVHA: + return 463; + case AV_CODEC_ID_CDTOONS: + return 464; + case AV_CODEC_ID_MV30: + return 465; + case AV_CODEC_ID_NOTCHLC: + return 466; + case AV_CODEC_ID_PFM: + return 467; + case AV_CODEC_ID_ARGO: + return 468; + case AV_CODEC_ID_ADPCM_IMA_SSI: + return 469; + case AV_CODEC_ID_ADPCM_ZORK: + return 470; + case AV_CODEC_ID_ADPCM_IMA_APM: + return 471; + case AV_CODEC_ID_ADPCM_IMA_ALP: + return 472; + case AV_CODEC_ID_ADPCM_IMA_MTF: + return 473; + case AV_CODEC_ID_ADPCM_IMA_CUNNING: + return 474; + case AV_CODEC_ID_DERF_DPCM: + return 475; + case AV_CODEC_ID_ACELP_KELVIN: + return 476; + case AV_CODEC_ID_MPEGH_3D_AUDIO: + return 477; + case AV_CODEC_ID_SIREN: + return 478; + case AV_CODEC_ID_HCA: + return 479; + case AV_CODEC_ID_EPG: + return 480; + case AV_CODEC_ID_AVS3: + return 481; + case AV_CODEC_ID_PGX: + return 482; + case AV_CODEC_ID_MSP2: + return 483; + case AV_CODEC_ID_VVC: + return 484; + case AV_CODEC_ID_MOBICLIP: + return 485; + case AV_CODEC_ID_PHOTOCD: + return 486; + case AV_CODEC_ID_ADPCM_ARGO: + return 487; + case AV_CODEC_ID_CRI: + return 488; + case AV_CODEC_ID_IPU: + return 489; + case AV_CODEC_ID_SIMBIOSIS_IMX: + return 490; + case AV_CODEC_ID_SGA_VIDEO: + return 491; + case AV_CODEC_ID_PCM_SGA: + return 492; + case AV_CODEC_ID_ADPCM_IMA_MOFLEX: + return 493; + case AV_CODEC_ID_FASTAUDIO: + return 494; + case AV_CODEC_ID_GEM: + return 495; + case AV_CODEC_ID_ADPCM_IMA_ACORN: + return 496; + case AV_CODEC_ID_MSNSIREN: + return 497; + case AV_CODEC_ID_VBN: + return 498; + case AV_CODEC_ID_JPEGXL: + return 499; + case AV_CODEC_ID_QOI: + return 500; + case AV_CODEC_ID_PHM: + return 501; + case AV_CODEC_ID_DFPWM: + return 502; + case AV_CODEC_ID_RADIANCE_HDR: + return 503; + case AV_CODEC_ID_WBMP: + return 504; + case AV_CODEC_ID_MEDIA100: + return 505; + case AV_CODEC_ID_VQC: + return 506; + case AV_CODEC_ID_ADPCM_XMD: + return 507; + case AV_CODEC_ID_WADY_DPCM: + return 508; + case AV_CODEC_ID_CBD2_DPCM: + return 509; + case AV_CODEC_ID_BONK: + return 510; + case AV_CODEC_ID_MISC4: + return 511; + case AV_CODEC_ID_APAC: + return 512; + case AV_CODEC_ID_FTR: + return 513; + case AV_CODEC_ID_WAVARC: + return 514; + case AV_CODEC_ID_RKA: + return 515; + case AV_CODEC_ID_VNULL: + return 516; + case AV_CODEC_ID_ANULL: + return 517; + // case AV_CODEC_ID_MPEG2VIDEO_XVMC: + // return 518; + default: + return 0; + } + } +}; + +class PixFmt { + +public: + static uint32_t fromAVPixFmt(AVPixelFormat AvPixelFormat) { + switch (AvPixelFormat) { + case AV_PIX_FMT_NONE: + return 0; + case AV_PIX_FMT_YUV420P: + return 1; + case AV_PIX_FMT_YUYV422: + return 2; + case AV_PIX_FMT_RGB24: + return 3; + case AV_PIX_FMT_BGR24: + return 4; + case AV_PIX_FMT_YUV422P: + return 5; + case AV_PIX_FMT_YUV444P: + return 7; + case AV_PIX_FMT_YUV410P: + return 8; + case AV_PIX_FMT_YUV411P: + return 9; + case AV_PIX_FMT_GRAY8: + return 10; + case AV_PIX_FMT_MONOWHITE: + return 11; + case AV_PIX_FMT_MONOBLACK: + return 12; + case AV_PIX_FMT_PAL8: + return 13; + case AV_PIX_FMT_YUVJ420P: + return 14; + case AV_PIX_FMT_YUVJ422P: + return 15; + case AV_PIX_FMT_YUVJ444P: + return 16; + // case AV_PIX_FMT_XVMC_MPEG2_MC : // Lower FFmpeg Version + // return 17; + // case AV_PIX_FMT_XVMC_MPEG2_IDCT : + // return 18; + case AV_PIX_FMT_UYVY422: + return 19; + case AV_PIX_FMT_UYYVYY411: + return 20; + case AV_PIX_FMT_BGR8: + return 21; + case AV_PIX_FMT_BGR4: + return 22; + case AV_PIX_FMT_BGR4_BYTE: + return 23; + case AV_PIX_FMT_RGB8: + return 24; + case AV_PIX_FMT_RGB4: + return 25; + case AV_PIX_FMT_RGB4_BYTE: + return 26; + case AV_PIX_FMT_NV12: + return 27; + case AV_PIX_FMT_NV21: + return 28; + case AV_PIX_FMT_ARGB: // Big Endian + return 29; + case AV_PIX_FMT_RGBA: // Big + return 30; + case AV_PIX_FMT_ABGR: // Big + return 31; + case AV_PIX_FMT_BGRA: // little + return 32; + case AV_PIX_FMT_GRAY16BE: // big + return 33; + case AV_PIX_FMT_GRAY16LE: + return 34; + case AV_PIX_FMT_YUV440P: + return 35; + case AV_PIX_FMT_YUVJ440P: + return 36; + case AV_PIX_FMT_YUVA420P: + return 37; + // case AV_PIX_FMT_VDPAU_H264 : + // return 38; + // case AV_PIX_FMT_VDPAU_MPEG1 : + // return 39; + // case AV_PIX_FMT_VDPAU_MPEG2 : + // return 40; + // case AV_PIX_FMT_VDPAU_WMV3 : // Conditional compile. + // return 41; + // case AV_PIX_FMT_VDPAU_VC1 : // ff_api_vdpau is present + // return 42; + case AV_PIX_FMT_RGB48BE: + return 43; + case AV_PIX_FMT_RGB48LE: + return 44; + case AV_PIX_FMT_RGB565BE: + return 45; + case AV_PIX_FMT_RGB565LE: + return 46; + case AV_PIX_FMT_RGB555BE: + return 47; + case AV_PIX_FMT_RGB555LE: + return 48; + case AV_PIX_FMT_BGR565BE: + return 49; + case AV_PIX_FMT_BGR565LE: + return 50; + case AV_PIX_FMT_BGR555BE: + return 51; + case AV_PIX_FMT_BGR555LE: + return 52; + // case AV_PIX_FMT_VAAPI_MOCO : + // return 53; + // case AV_PIX_FMT_VAAPI_IDCT : + // return 54; + // case AV_PIX_FMT_VAAPI_VLD : + // return 55; + // case AV_PIX_FMT_VAAPI : // ff_api_vdpau is present + // return 56; + case AV_PIX_FMT_YUV420P16LE: + return 57; + case AV_PIX_FMT_YUV420P16BE: + return 58; + case AV_PIX_FMT_YUV422P16LE: + return 59; + case AV_PIX_FMT_YUV422P16BE: + return 60; + case AV_PIX_FMT_YUV444P16LE: + return 61; + case AV_PIX_FMT_YUV444P16BE: + return 62; + // case AV_PIX_FMT_VDPAU_MPEG4 : // ff_api_vdpau is present + // return 63; + case AV_PIX_FMT_DXVA2_VLD: + return 64; + case AV_PIX_FMT_RGB444LE: + return 65; + case AV_PIX_FMT_RGB444BE: + return 66; + case AV_PIX_FMT_BGR444LE: + return 67; + case AV_PIX_FMT_BGR444BE: + return 68; + case AV_PIX_FMT_YA8: + return 69; + case AV_PIX_FMT_BGR48BE: + return 70; + case AV_PIX_FMT_BGR48LE: + return 71; + case AV_PIX_FMT_YUV420P9BE: + return 72; + case AV_PIX_FMT_YUV420P9LE: + return 73; + case AV_PIX_FMT_YUV420P10BE: + return 74; + case AV_PIX_FMT_YUV420P10LE: + return 75; + case AV_PIX_FMT_YUV422P10BE: + return 76; + case AV_PIX_FMT_YUV422P10LE: + return 77; + case AV_PIX_FMT_YUV444P9BE: + return 78; + case AV_PIX_FMT_YUV444P9LE: + return 79; + case AV_PIX_FMT_YUV444P10BE: + return 80; + case AV_PIX_FMT_YUV444P10LE: + return 81; + case AV_PIX_FMT_YUV422P9BE: + return 82; + case AV_PIX_FMT_YUV422P9LE: + return 83; + // case AV_PIX_FMT_VDA_VLD : // Lower than ffmpeg version 4 + // return 84; + case AV_PIX_FMT_GBRP: + return 85; + case AV_PIX_FMT_GBRP9BE: + return 86; + case AV_PIX_FMT_GBRP9LE: + return 87; + case AV_PIX_FMT_GBRP10BE: + return 88; + case AV_PIX_FMT_GBRP10LE: + return 89; + case AV_PIX_FMT_GBRP16BE: + return 90; + case AV_PIX_FMT_GBRP16LE: + return 91; + case AV_PIX_FMT_YUVA420P9BE: + return 92; + case AV_PIX_FMT_YUVA420P9LE: + return 93; + case AV_PIX_FMT_YUVA422P9BE: + return 94; + case AV_PIX_FMT_YUVA422P9LE: + return 95; + case AV_PIX_FMT_YUVA444P9BE: + return 96; + case AV_PIX_FMT_YUVA444P9LE: + return 97; + case AV_PIX_FMT_YUVA420P10BE: + return 98; + case AV_PIX_FMT_YUVA420P10LE: + return 99; + case AV_PIX_FMT_YUVA422P10BE: + return 100; + case AV_PIX_FMT_YUVA422P10LE: + return 101; + case AV_PIX_FMT_YUVA444P10BE: + return 102; + case AV_PIX_FMT_YUVA444P10LE: + return 103; + case AV_PIX_FMT_YUVA420P16BE: + return 104; + case AV_PIX_FMT_YUVA420P16LE: + return 105; + case AV_PIX_FMT_YUVA422P16BE: + return 106; + case AV_PIX_FMT_YUVA422P16LE: + return 107; + case AV_PIX_FMT_YUVA444P16BE: + return 108; + case AV_PIX_FMT_YUVA444P16LE: + return 109; + case AV_PIX_FMT_VDPAU: + return 110; + case AV_PIX_FMT_XYZ12LE: + return 111; + case AV_PIX_FMT_XYZ12BE: + return 112; + case AV_PIX_FMT_NV16: + return 113; + case AV_PIX_FMT_NV20LE: + return 114; + case AV_PIX_FMT_NV20BE: + return 115; + case AV_PIX_FMT_RGBA64BE: + return 116; + case AV_PIX_FMT_RGBA64LE: + return 117; + case AV_PIX_FMT_BGRA64BE: + return 118; + case AV_PIX_FMT_BGRA64LE: + return 119; + case AV_PIX_FMT_YVYU422: + return 120; + // case AV_PIX_FMT_VDA : // Lower than ffmpeg version 4. + // return 121; + case AV_PIX_FMT_YA16BE: // big + return 122; + case AV_PIX_FMT_YA16LE: + return 123; + case AV_PIX_FMT_QSV: + return 124; + case AV_PIX_FMT_MMAL: + return 125; + case AV_PIX_FMT_D3D11VA_VLD: + return 126; + case AV_PIX_FMT_CUDA: + return 127; + case AV_PIX_FMT_0RGB: // big + return 128; + case AV_PIX_FMT_RGB0: + return 129; + case AV_PIX_FMT_0BGR: // big + return 130; + case AV_PIX_FMT_BGR0: + return 131; + case AV_PIX_FMT_YUVA444P: + return 132; + case AV_PIX_FMT_YUVA422P: + return 133; + case AV_PIX_FMT_YUV420P12BE: + return 134; + case AV_PIX_FMT_YUV420P12LE: + return 135; + case AV_PIX_FMT_YUV420P14BE: + return 136; + case AV_PIX_FMT_YUV420P14LE: + return 137; + case AV_PIX_FMT_YUV422P12BE: + return 138; + case AV_PIX_FMT_YUV422P12LE: + return 139; + case AV_PIX_FMT_YUV422P14BE: + return 140; + case AV_PIX_FMT_YUV422P14LE: + return 141; + case AV_PIX_FMT_YUV444P12BE: + return 142; + case AV_PIX_FMT_YUV444P12LE: + return 143; + case AV_PIX_FMT_YUV444P14BE: + return 144; + case AV_PIX_FMT_YUV444P14LE: + return 146; + case AV_PIX_FMT_GBRP12BE: + return 147; + case AV_PIX_FMT_GBRP12LE: + return 148; + case AV_PIX_FMT_GBRP14BE: + return 149; + case AV_PIX_FMT_GBRP14LE: + return 150; + case AV_PIX_FMT_GBRAP: + return 151; + case AV_PIX_FMT_GBRAP16BE: + return 152; + case AV_PIX_FMT_GBRAP16LE: + return 153; + case AV_PIX_FMT_YUVJ411P: + return 154; + case AV_PIX_FMT_BAYER_BGGR8: + return 155; + case AV_PIX_FMT_BAYER_RGGB8: + return 156; + case AV_PIX_FMT_BAYER_GBRG8: + return 157; + case AV_PIX_FMT_BAYER_GRBG8: + return 158; + case AV_PIX_FMT_BAYER_BGGR16LE: + return 159; + case AV_PIX_FMT_BAYER_BGGR16BE: + return 160; + case AV_PIX_FMT_BAYER_RGGB16LE: + return 161; + case AV_PIX_FMT_BAYER_RGGB16BE: + return 162; + case AV_PIX_FMT_BAYER_GBRG16LE: + return 163; + case AV_PIX_FMT_BAYER_GBRG16BE: + return 164; + case AV_PIX_FMT_BAYER_GRBG16LE: + return 165; + case AV_PIX_FMT_BAYER_GRBG16BE: + return 166; + case AV_PIX_FMT_YUV440P10LE: + return 167; + case AV_PIX_FMT_YUV440P10BE: + return 168; + case AV_PIX_FMT_YUV440P12LE: + return 169; + case AV_PIX_FMT_YUV440P12BE: + return 170; + case AV_PIX_FMT_AYUV64LE: + return 171; + case AV_PIX_FMT_AYUV64BE: + return 172; + case AV_PIX_FMT_VIDEOTOOLBOX: + return 173; + case AV_PIX_FMT_XVMC: + return 174; + // case AV_PIX_FMT_RGB32: // IF format is this type, based on + // endianness, it resolves to big endian or small endian. + // return 175; // The Switch case contains both big and + // small endian, so No need to add these in switch case. + // case AV_PIX_FMT_RGB32_1: // Will Automatically resolve. + // return 176; + // case AV_PIX_FMT_BGR32: + // return 177; + // case AV_PIX_FMT_BGR32_1: + // return 178; + // case AV_PIX_FMT_0RGB32: + // return 179; + // case AV_PIX_FMT_0BGR32: + // return 180; + // case AV_PIX_FMT_GRAY16: + // return 181; + // case AV_PIX_FMT_YA16: + // return 182; + // case AV_PIX_FMT_RGB48: + // return 183; + // case AV_PIX_FMT_RGB565: + // return 184; + // case AV_PIX_FMT_RGB444: + // return 185; + // case AV_PIX_FMT_BGR48: + // return 186; + // case AV_PIX_FMT_BGR565: + // return 187; + // case AV_PIX_FMT_BGR555: + // return 188; + // case AV_PIX_FMT_BGR444: + // return 189; + // case AV_PIX_FMT_YUV420P9: + // return 190; + // case AV_PIX_FMT_YUV422P9: + // return 191; + // case AV_PIX_FMT_YUV444P9: + // return 192; + // case AV_PIX_FMT_YUV420P10: + // return 193; + // case AV_PIX_FMT_YUV422P10: + // return 194; + // case AV_PIX_FMT_YUV440P10: + // return 195; + // case AV_PIX_FMT_YUV444P10: + // return 196; + // case AV_PIX_FMT_YUV420P12: + // return 197; + // case AV_PIX_FMT_YUV422P12: + // return 198; + // case AV_PIX_FMT_YUV440P12: + // return 199; + // case AV_PIX_FMT_YUV444P12: + // return 200; + // case AV_PIX_FMT_YUV420P14: + // return 201; + // case AV_PIX_FMT_YUV422P14: + // return 202; + // case AV_PIX_FMT_YUV444P14: + // return 203; + // case AV_PIX_FMT_YUV420P16: + // return 204; + // case AV_PIX_FMT_YUV422P16: + // return 205; + // case AV_PIX_FMT_YUV444P16: + // return 206; + // case AV_PIX_FMT_GBRP9: + // return 207; + // case AV_PIX_FMT_GBRP10: + // return 208; + // case AV_PIX_FMT_GBRP12: + // return 209; + // case AV_PIX_FMT_GBRP14: + // return 210; + // case AV_PIX_FMT_GBRP16: + // return 211; + // case AV_PIX_FMT_GBRAP16: + // return 212; + // case AV_PIX_FMT_BAYER_BGGR16: + // return 213; + // case AV_PIX_FMT_BAYER_RGGB16: + // return 214; + // case AV_PIX_FMT_BAYER_GBRG16: + // return 215; + // case AV_PIX_FMT_BAYER_GRBG16: + // return 216; + // case AV_PIX_FMT_YUVA420P9: + // return 217; + // case AV_PIX_FMT_YUVA422P9: + // return 218; + // case AV_PIX_FMT_YUVA444P9: + // return 219; + // case AV_PIX_FMT_YUVA420P10: + // return 220; + // case AV_PIX_FMT_YUVA422P10: + // return 221; + // case AV_PIX_FMT_YUVA444P10: + // return 222; + // case AV_PIX_FMT_YUVA420P16: + // return 223; + // case AV_PIX_FMT_YUVA422P16: + // return 224; + // case AV_PIX_FMT_YUVA444P16: + // return 225; + // case AV_PIX_FMT_XYZ12: + // return 226; + // case AV_PIX_FMT_NV20: + // return 227; + // case AV_PIX_FMT_AYUV64: + // return 228; + case AV_PIX_FMT_P010LE: + return 229; + case AV_PIX_FMT_P010BE: + return 230; + case AV_PIX_FMT_GBRAP12BE: + return 231; + case AV_PIX_FMT_GBRAP12LE: + return 232; + case AV_PIX_FMT_GBRAP10LE: + return 233; + case AV_PIX_FMT_GBRAP10BE: + return 234; + case AV_PIX_FMT_MEDIACODEC: + return 235; + case AV_PIX_FMT_GRAY12BE: + return 236; + case AV_PIX_FMT_GRAY12LE: + return 237; + case AV_PIX_FMT_GRAY10BE: + return 238; + case AV_PIX_FMT_GRAY10LE: + return 239; + case AV_PIX_FMT_P016LE: + return 240; + case AV_PIX_FMT_P016BE: + return 241; + case AV_PIX_FMT_D3D11: + return 242; + case AV_PIX_FMT_GRAY9BE: + return 243; + case AV_PIX_FMT_GRAY9LE: + return 244; + case AV_PIX_FMT_GBRPF32BE: + return 245; + case AV_PIX_FMT_GBRPF32LE: + return 246; + case AV_PIX_FMT_GBRAPF32BE: + return 247; + case AV_PIX_FMT_GBRAPF32LE: + return 248; + case AV_PIX_FMT_DRM_PRIME: + return 249; + // Above ffmpeg 4.0 Need to add versions. + case AV_PIX_FMT_OPENCL: + return 250; + case AV_PIX_FMT_GRAY14BE: + return 251; + case AV_PIX_FMT_GRAY14LE: + return 252; + case AV_PIX_FMT_GRAYF32BE: + return 253; + case AV_PIX_FMT_GRAYF32LE: + return 254; + case AV_PIX_FMT_YUVA422P12BE: + return 255; + case AV_PIX_FMT_YUVA422P12LE: + return 256; + case AV_PIX_FMT_YUVA444P12BE: + return 257; + case AV_PIX_FMT_YUVA444P12LE: + return 258; + case AV_PIX_FMT_NV24: + return 259; + case AV_PIX_FMT_NV42: + return 260; + case AV_PIX_FMT_VULKAN: + return 261; + case AV_PIX_FMT_Y210BE: + return 262; + case AV_PIX_FMT_Y210LE: + return 263; + case AV_PIX_FMT_X2RGB10LE: + return 264; + case AV_PIX_FMT_X2RGB10BE: + return 265; + case AV_PIX_FMT_X2BGR10LE: + return 266; + case AV_PIX_FMT_X2BGR10BE: + return 267; + case AV_PIX_FMT_P210BE: + return 268; + case AV_PIX_FMT_P210LE: + return 269; + case AV_PIX_FMT_P410BE: + return 270; + case AV_PIX_FMT_P410LE: + return 271; + case AV_PIX_FMT_P216BE: + return 272; + case AV_PIX_FMT_P216LE: + return 273; + case AV_PIX_FMT_P416BE: + return 274; + case AV_PIX_FMT_P416LE: + return 275; + case AV_PIX_FMT_VUYA: + return 276; + case AV_PIX_FMT_RGBAF16BE: + return 277; + case AV_PIX_FMT_RGBAF16LE: + return 278; + case AV_PIX_FMT_VUYX: + return 279; + case AV_PIX_FMT_P012LE: + return 280; + case AV_PIX_FMT_P012BE: + return 281; + case AV_PIX_FMT_Y212BE: + return 282; + case AV_PIX_FMT_Y212LE: + return 283; + case AV_PIX_FMT_XV30BE: + return 284; + case AV_PIX_FMT_XV30LE: + return 285; + case AV_PIX_FMT_XV36BE: + return 286; + case AV_PIX_FMT_XV36LE: + return 287; + case AV_PIX_FMT_RGBF32BE: + return 288; + case AV_PIX_FMT_RGBF32LE: + return 289; + case AV_PIX_FMT_RGBAF32BE: + return 290; + case AV_PIX_FMT_RGBAF32LE: + return 291; + // case AV_PIX_FMT_RPI : + // return 292; + // case AV_PIX_FMT_SAND128 : + // return 293; + // case AV_PIX_FMT_SAND64_10 : + // return 294; + // case AV_PIX_FMT_SAND64_16 : + // return 295; + // case AV_PIX_FMT_RPI4_8 : // rpi turn on then only + // return 296; + // case AV_PIX_FMT_RPI4_10 : + // return 297; + // case AV_PIX_FMT_RGB555: // Little Endian, Big Endian WIll + // Resolve on it's own. + // return 298; + default: + return 0; + } + } + + static AVPixelFormat intoAVPixFmt(uint32_t AvPixFmtId) { + switch (AvPixFmtId) { + case 0: + return AV_PIX_FMT_NONE; + case 1: + return AV_PIX_FMT_YUV420P; + case 2: + return AV_PIX_FMT_YUYV422; + case 3: + return AV_PIX_FMT_RGB24; + case 4: + return AV_PIX_FMT_BGR24; + case 5: + return AV_PIX_FMT_YUV422P; + case 7: + return AV_PIX_FMT_YUV444P; + case 8: + return AV_PIX_FMT_YUV410P; + case 9: + return AV_PIX_FMT_YUV411P; + case 10: + return AV_PIX_FMT_GRAY8; + case 11: + return AV_PIX_FMT_MONOWHITE; + case 12: + return AV_PIX_FMT_MONOBLACK; + case 13: + return AV_PIX_FMT_PAL8; + case 14: + return AV_PIX_FMT_YUVJ420P; + case 15: + return AV_PIX_FMT_YUVJ422P; + case 16: + return AV_PIX_FMT_YUVJ444P; + // case 17: + // return AV_PIX_FMT_XVMC_MPEG2_MC ; // Lower FFmpeg Version + // case 18: + // return AV_PIX_FMT_XVMC_MPEG2_IDCT ; + case 19: + return AV_PIX_FMT_UYVY422; + case 20: + return AV_PIX_FMT_UYYVYY411; + case 21: + return AV_PIX_FMT_BGR8; + case 22: + return AV_PIX_FMT_BGR4; + case 23: + return AV_PIX_FMT_BGR4_BYTE; + case 24: + return AV_PIX_FMT_RGB8; + case 25: + return AV_PIX_FMT_RGB4; + case 26: + return AV_PIX_FMT_RGB4_BYTE; + case 27: + return AV_PIX_FMT_NV12; + case 28: + return AV_PIX_FMT_NV21; + case 29: + return AV_PIX_FMT_ARGB; + case 30: + return AV_PIX_FMT_RGBA; + case 31: + return AV_PIX_FMT_ABGR; + case 32: + return AV_PIX_FMT_BGRA; // Little + case 33: + return AV_PIX_FMT_GRAY16BE; + case 34: + return AV_PIX_FMT_GRAY16LE; + case 35: + return AV_PIX_FMT_YUV440P; + case 36: + return AV_PIX_FMT_YUVJ440P; + case 37: + return AV_PIX_FMT_YUVA420P; + // case 38: + // return AV_PIX_FMT_VDPAU_H264 ; + // case 39: + // return AV_PIX_FMT_VDPAU_MPEG1 ; + // case 40: + // return AV_PIX_FMT_VDPAU_MPEG2 ; + // case 41: + // return AV_PIX_FMT_VDPAU_WMV3 ; // Conditional compile. + // case 42: + // return AV_PIX_FMT_VDPAU_VC1 ; // ff_api_vdpau is present + case 43: + return AV_PIX_FMT_RGB48BE; + case 44: + return AV_PIX_FMT_RGB48LE; + case 45: + return AV_PIX_FMT_RGB565BE; + case 46: + return AV_PIX_FMT_RGB565LE; + case 47: + return AV_PIX_FMT_RGB555BE; + case 48: + return AV_PIX_FMT_RGB555LE; + case 49: + return AV_PIX_FMT_BGR565BE; + case 50: + return AV_PIX_FMT_BGR565LE; + case 51: + return AV_PIX_FMT_BGR555BE; + case 52: + return AV_PIX_FMT_BGR555LE; + // case 53: + // return AV_PIX_FMT_VAAPI_MOCO ; + // case 54: + // return AV_PIX_FMT_VAAPI_IDCT ; + // case 55: + // return AV_PIX_FMT_VAAPI_VLD ; + // case 56: + // return AV_PIX_FMT_VAAPI ; // ff_api_vdpau is present + case 57: + return AV_PIX_FMT_YUV420P16LE; + case 58: + return AV_PIX_FMT_YUV420P16BE; + case 59: + return AV_PIX_FMT_YUV422P16LE; + case 60: + return AV_PIX_FMT_YUV422P16BE; + case 61: + return AV_PIX_FMT_YUV444P16LE; + case 62: + return AV_PIX_FMT_YUV444P16BE; + // case 63: + // return AV_PIX_FMT_VDPAU_MPEG4 ; // ff_api_vdpau is present + case 64: + return AV_PIX_FMT_DXVA2_VLD; + case 65: + return AV_PIX_FMT_RGB444LE; + case 66: + return AV_PIX_FMT_RGB444BE; + case 67: + return AV_PIX_FMT_BGR444LE; + case 68: + return AV_PIX_FMT_BGR444BE; + case 69: + return AV_PIX_FMT_YA8; + case 70: + return AV_PIX_FMT_BGR48BE; + case 71: + return AV_PIX_FMT_BGR48LE; + case 72: + return AV_PIX_FMT_YUV420P9BE; + case 73: + return AV_PIX_FMT_YUV420P9LE; + case 74: + return AV_PIX_FMT_YUV420P10BE; + case 75: + return AV_PIX_FMT_YUV420P10LE; + case 76: + return AV_PIX_FMT_YUV422P10BE; + case 77: + return AV_PIX_FMT_YUV422P10LE; + case 78: + return AV_PIX_FMT_YUV444P9BE; + case 79: + return AV_PIX_FMT_YUV444P9LE; + case 80: + return AV_PIX_FMT_YUV444P10BE; + case 81: + return AV_PIX_FMT_YUV444P10LE; + case 82: + return AV_PIX_FMT_YUV422P9BE; + case 83: + return AV_PIX_FMT_YUV422P9LE; + // case 84: + // return AV_PIX_FMT_VDA_VLD ; // Lower than ffmpeg version 4 + case 85: + return AV_PIX_FMT_GBRP; + case 86: + return AV_PIX_FMT_GBRP9BE; + case 87: + return AV_PIX_FMT_GBRP9LE; + case 88: + return AV_PIX_FMT_GBRP10BE; + case 89: + return AV_PIX_FMT_GBRP10LE; + case 90: + return AV_PIX_FMT_GBRP16BE; + case 91: + return AV_PIX_FMT_GBRP16LE; + case 92: + return AV_PIX_FMT_YUVA420P9BE; + case 93: + return AV_PIX_FMT_YUVA420P9LE; + case 94: + return AV_PIX_FMT_YUVA422P9BE; + case 95: + return AV_PIX_FMT_YUVA422P9LE; + case 96: + return AV_PIX_FMT_YUVA444P9BE; + case 97: + return AV_PIX_FMT_YUVA444P9LE; + case 98: + return AV_PIX_FMT_YUVA420P10BE; + case 99: + return AV_PIX_FMT_YUVA420P10LE; + case 100: + return AV_PIX_FMT_YUVA422P10BE; + case 101: + return AV_PIX_FMT_YUVA422P10LE; + case 102: + return AV_PIX_FMT_YUVA444P10BE; + case 103: + return AV_PIX_FMT_YUVA444P10LE; + case 104: + return AV_PIX_FMT_YUVA420P16BE; + case 105: + return AV_PIX_FMT_YUVA420P16LE; + case 106: + return AV_PIX_FMT_YUVA422P16BE; + case 107: + return AV_PIX_FMT_YUVA422P16LE; + case 108: + return AV_PIX_FMT_YUVA444P16BE; + case 109: + return AV_PIX_FMT_YUVA444P16LE; + case 110: + return AV_PIX_FMT_VDPAU; + case 111: + return AV_PIX_FMT_XYZ12LE; + case 112: + return AV_PIX_FMT_XYZ12BE; + case 113: + return AV_PIX_FMT_NV16; + case 114: + return AV_PIX_FMT_NV20LE; + case 115: + return AV_PIX_FMT_NV20BE; + case 116: + return AV_PIX_FMT_RGBA64BE; + case 117: + return AV_PIX_FMT_RGBA64LE; + case 118: + return AV_PIX_FMT_BGRA64BE; + case 119: + return AV_PIX_FMT_BGRA64LE; + case 120: + return AV_PIX_FMT_YVYU422; + // case 121: + // return AV_PIX_FMT_VDA ; // Lower than ffmpeg version 4. + case 122: + return AV_PIX_FMT_YA16BE; + case 123: + return AV_PIX_FMT_YA16LE; + case 124: + return AV_PIX_FMT_QSV; + case 125: + return AV_PIX_FMT_MMAL; + case 126: + return AV_PIX_FMT_D3D11VA_VLD; + case 127: + return AV_PIX_FMT_CUDA; + case 128: + return AV_PIX_FMT_0RGB; + case 129: + return AV_PIX_FMT_RGB0; + case 130: + return AV_PIX_FMT_0BGR; + case 131: + return AV_PIX_FMT_BGR0; + case 132: + return AV_PIX_FMT_YUVA444P; + case 133: + return AV_PIX_FMT_YUVA422P; + case 134: + return AV_PIX_FMT_YUV420P12BE; + case 135: + return AV_PIX_FMT_YUV420P12LE; + case 136: + return AV_PIX_FMT_YUV420P14BE; + case 137: + return AV_PIX_FMT_YUV420P14LE; + case 138: + return AV_PIX_FMT_YUV422P12BE; + case 139: + return AV_PIX_FMT_YUV422P12LE; + case 140: + return AV_PIX_FMT_YUV422P14BE; + case 141: + return AV_PIX_FMT_YUV422P14LE; + case 142: + return AV_PIX_FMT_YUV444P12BE; + case 143: + return AV_PIX_FMT_YUV444P12LE; + case 144: + return AV_PIX_FMT_YUV444P14BE; + case 146: + return AV_PIX_FMT_YUV444P14LE; + case 147: + return AV_PIX_FMT_GBRP12BE; + case 148: + return AV_PIX_FMT_GBRP12LE; + case 149: + return AV_PIX_FMT_GBRP14BE; + case 150: + return AV_PIX_FMT_GBRP14LE; + case 151: + return AV_PIX_FMT_GBRAP; + case 152: + return AV_PIX_FMT_GBRAP16BE; + case 153: + return AV_PIX_FMT_GBRAP16LE; + case 154: + return AV_PIX_FMT_YUVJ411P; + case 155: + return AV_PIX_FMT_BAYER_BGGR8; + case 156: + return AV_PIX_FMT_BAYER_RGGB8; + case 157: + return AV_PIX_FMT_BAYER_GBRG8; + case 158: + return AV_PIX_FMT_BAYER_GRBG8; + case 159: + return AV_PIX_FMT_BAYER_BGGR16LE; + case 160: + return AV_PIX_FMT_BAYER_BGGR16BE; + case 161: + return AV_PIX_FMT_BAYER_RGGB16LE; + case 162: + return AV_PIX_FMT_BAYER_RGGB16BE; + case 163: + return AV_PIX_FMT_BAYER_GBRG16LE; + case 164: + return AV_PIX_FMT_BAYER_GBRG16BE; + case 165: + return AV_PIX_FMT_BAYER_GRBG16LE; + case 166: + return AV_PIX_FMT_BAYER_GRBG16BE; + case 167: + return AV_PIX_FMT_YUV440P10LE; + case 168: + return AV_PIX_FMT_YUV440P10BE; + case 169: + return AV_PIX_FMT_YUV440P12LE; + case 170: + return AV_PIX_FMT_YUV440P12BE; + case 171: + return AV_PIX_FMT_AYUV64LE; + case 172: + return AV_PIX_FMT_AYUV64BE; + case 173: + return AV_PIX_FMT_VIDEOTOOLBOX; + case 174: + return AV_PIX_FMT_XVMC; + case 175: + return AV_PIX_FMT_RGB32; + case 176: + return AV_PIX_FMT_RGB32_1; + case 177: + return AV_PIX_FMT_BGR32; + case 178: + return AV_PIX_FMT_BGR32_1; + case 179: + return AV_PIX_FMT_0RGB32; + case 180: + return AV_PIX_FMT_0BGR32; + case 181: + return AV_PIX_FMT_GRAY16; + case 182: + return AV_PIX_FMT_YA16; + case 183: + return AV_PIX_FMT_RGB48; + case 184: + return AV_PIX_FMT_RGB565; + case 185: + return AV_PIX_FMT_RGB444; + case 186: + return AV_PIX_FMT_BGR48; + case 187: + return AV_PIX_FMT_BGR565; + case 188: + return AV_PIX_FMT_BGR555; + case 189: + return AV_PIX_FMT_BGR444; + case 190: + return AV_PIX_FMT_YUV420P9; + case 191: + return AV_PIX_FMT_YUV422P9; + case 192: + return AV_PIX_FMT_YUV444P9; + case 193: + return AV_PIX_FMT_YUV420P10; + case 194: + return AV_PIX_FMT_YUV422P10; + case 195: + return AV_PIX_FMT_YUV440P10; + case 196: + return AV_PIX_FMT_YUV444P10; + case 197: + return AV_PIX_FMT_YUV420P12; + case 198: + return AV_PIX_FMT_YUV422P12; + case 199: + return AV_PIX_FMT_YUV440P12; + case 200: + return AV_PIX_FMT_YUV444P12; + case 201: + return AV_PIX_FMT_YUV420P14; + case 202: + return AV_PIX_FMT_YUV422P14; + case 203: + return AV_PIX_FMT_YUV444P14; + case 204: + return AV_PIX_FMT_YUV420P16; + case 205: + return AV_PIX_FMT_YUV422P16; + case 206: + return AV_PIX_FMT_YUV444P16; + case 207: + return AV_PIX_FMT_GBRP9; + case 208: + return AV_PIX_FMT_GBRP10; + case 209: + return AV_PIX_FMT_GBRP12; + case 210: + return AV_PIX_FMT_GBRP14; + case 211: + return AV_PIX_FMT_GBRP16; + case 212: + return AV_PIX_FMT_GBRAP16; + case 213: + return AV_PIX_FMT_BAYER_BGGR16; + case 214: + return AV_PIX_FMT_BAYER_RGGB16; + case 215: + return AV_PIX_FMT_BAYER_GBRG16; + case 216: + return AV_PIX_FMT_BAYER_GRBG16; + case 217: + return AV_PIX_FMT_YUVA420P9; + case 218: + return AV_PIX_FMT_YUVA422P9; + case 219: + return AV_PIX_FMT_YUVA444P9; + case 220: + return AV_PIX_FMT_YUVA420P10; + case 221: + return AV_PIX_FMT_YUVA422P10; + case 222: + return AV_PIX_FMT_YUVA444P10; + case 223: + return AV_PIX_FMT_YUVA420P16; + case 224: + return AV_PIX_FMT_YUVA422P16; + case 225: + return AV_PIX_FMT_YUVA444P16; + case 226: + return AV_PIX_FMT_XYZ12; + case 227: + return AV_PIX_FMT_NV20; + case 228: + return AV_PIX_FMT_AYUV64; + case 229: + return AV_PIX_FMT_P010LE; + case 230: + return AV_PIX_FMT_P010BE; + case 231: + return AV_PIX_FMT_GBRAP12BE; + case 232: + return AV_PIX_FMT_GBRAP12LE; + case 233: + return AV_PIX_FMT_GBRAP10LE; + case 234: + return AV_PIX_FMT_GBRAP10BE; + case 235: + return AV_PIX_FMT_MEDIACODEC; + case 236: + return AV_PIX_FMT_GRAY12BE; + case 237: + return AV_PIX_FMT_GRAY12LE; + case 238: + return AV_PIX_FMT_GRAY10BE; + case 239: + return AV_PIX_FMT_GRAY10LE; + case 240: + return AV_PIX_FMT_P016LE; + case 241: + return AV_PIX_FMT_P016BE; + case 242: + return AV_PIX_FMT_D3D11; + case 243: + return AV_PIX_FMT_GRAY9BE; + case 244: + return AV_PIX_FMT_GRAY9LE; + case 245: + return AV_PIX_FMT_GBRPF32BE; + case 246: + return AV_PIX_FMT_GBRPF32LE; + case 247: + return AV_PIX_FMT_GBRAPF32BE; + case 248: + return AV_PIX_FMT_GBRAPF32LE; + case 249: + return AV_PIX_FMT_DRM_PRIME; + + // Above ffmpeg 4.0 Need to add versions. + case 250: + return AV_PIX_FMT_OPENCL; + case 251: + return AV_PIX_FMT_GRAY14BE; + case 252: + return AV_PIX_FMT_GRAY14LE; + case 253: + return AV_PIX_FMT_GRAYF32BE; + case 254: + return AV_PIX_FMT_GRAYF32LE; + case 255: + return AV_PIX_FMT_YUVA422P12BE; + case 256: + return AV_PIX_FMT_YUVA422P12LE; + case 257: + return AV_PIX_FMT_YUVA444P12BE; + case 258: + return AV_PIX_FMT_YUVA444P12LE; + case 259: + return AV_PIX_FMT_NV24; + case 260: + return AV_PIX_FMT_NV42; + case 261: + return AV_PIX_FMT_VULKAN; + case 262: + return AV_PIX_FMT_Y210BE; + case 263: + return AV_PIX_FMT_Y210LE; + case 264: + return AV_PIX_FMT_X2RGB10LE; + case 265: + return AV_PIX_FMT_X2RGB10BE; + case 266: + return AV_PIX_FMT_X2BGR10LE; + case 267: + return AV_PIX_FMT_X2BGR10BE; + case 268: + return AV_PIX_FMT_P210BE; + case 269: + return AV_PIX_FMT_P210LE; + case 270: + return AV_PIX_FMT_P410BE; + case 271: + return AV_PIX_FMT_P410LE; + case 272: + return AV_PIX_FMT_P216BE; + case 273: + return AV_PIX_FMT_P216LE; + case 274: + return AV_PIX_FMT_P416BE; + case 275: + return AV_PIX_FMT_P416LE; + case 276: + return AV_PIX_FMT_VUYA; + case 277: + return AV_PIX_FMT_RGBAF16BE; + case 278: + return AV_PIX_FMT_RGBAF16LE; + case 279: + return AV_PIX_FMT_VUYX; + case 280: + return AV_PIX_FMT_P012LE; + case 281: + return AV_PIX_FMT_P012BE; + case 282: + return AV_PIX_FMT_Y212BE; + case 283: + return AV_PIX_FMT_Y212LE; + case 284: + return AV_PIX_FMT_XV30BE; + case 285: + return AV_PIX_FMT_XV30LE; + case 286: + return AV_PIX_FMT_XV36BE; + case 287: + return AV_PIX_FMT_XV36LE; + case 288: + return AV_PIX_FMT_RGBF32BE; + case 289: + return AV_PIX_FMT_RGBF32LE; + case 290: + return AV_PIX_FMT_RGBAF32BE; + case 291: + return AV_PIX_FMT_RGBAF32LE; + // case 292: + // return AV_PIX_FMT_RPI ; + // case 293: + // return AV_PIX_FMT_SAND128 ; + // case 294: + // return AV_PIX_FMT_SAND64_10 ; + // case 295: + // return AV_PIX_FMT_SAND64_16 ; + // case 296: + // return AV_PIX_FMT_RPI4_8 ; // rpi turn on then only + // case 297: + // return AV_PIX_FMT_RPI4_10 ; + case 298: + return AV_PIX_FMT_RGB555; + default: + return AV_PIX_FMT_NONE; + } + } +}; + +class SampleFmt { +public: + static AVSampleFormat fromSampleID(uint32_t SampleID) { + switch (SampleID) { + case 0: + return AV_SAMPLE_FMT_NONE; + case 1: + return AV_SAMPLE_FMT_U8; + case 2: + return AV_SAMPLE_FMT_S16; + case 3: + return AV_SAMPLE_FMT_S32; + case 4: + return AV_SAMPLE_FMT_FLT; + case 5: + return AV_SAMPLE_FMT_DBL; + case 6: + return AV_SAMPLE_FMT_U8P; + case 7: + return AV_SAMPLE_FMT_S16P; + case 8: + return AV_SAMPLE_FMT_S32P; + case 9: + return AV_SAMPLE_FMT_FLTP; + case 10: + return AV_SAMPLE_FMT_DBLP; + case 11: + return AV_SAMPLE_FMT_S64; + case 12: + return AV_SAMPLE_FMT_S64P; + case 13: + return AV_SAMPLE_FMT_NB; + default: + return AV_SAMPLE_FMT_NONE; + } + } + + static uint32_t toSampleID(AVSampleFormat AvSampleFormat) { + switch (AvSampleFormat) { + case AV_SAMPLE_FMT_NONE: + return 0; + case AV_SAMPLE_FMT_U8: + return 1; + case AV_SAMPLE_FMT_S16: + return 2; + case AV_SAMPLE_FMT_S32: + return 3; + case AV_SAMPLE_FMT_FLT: + return 4; + case AV_SAMPLE_FMT_DBL: + return 5; + case AV_SAMPLE_FMT_U8P: + return 6; + case AV_SAMPLE_FMT_S16P: + return 7; + case AV_SAMPLE_FMT_S32P: + return 8; + case AV_SAMPLE_FMT_FLTP: + return 9; + case AV_SAMPLE_FMT_DBLP: + return 10; + case AV_SAMPLE_FMT_S64: + return 11; + case AV_SAMPLE_FMT_S64P: + return 12; + case AV_SAMPLE_FMT_NB: + return 13; + default: + return 0; + } + } +}; + +// Could have avoided, but Did this to support older version of FFMPEG +// (V5,V4,V3) Version 6 FFmpeg uses AVChannelLayout Struct; +class ChannelLayout { +private: + const static uint64_t FRONT_LEFT = 1; + const static uint64_t FRONT_RIGHT = 1ULL << 1; + const static uint64_t FRONT_CENTER = 1ULL << 2; + const static uint64_t LOW_FREQUENCY = 1ULL << 3; + const static uint64_t BACK_LEFT = 1ULL << 4; + const static uint64_t BACK_RIGHT = 1ULL << 5; + const static uint64_t FRONT_LEFT_OF_CENTER = 1ULL << 6; + const static uint64_t FRONT_RIGHT_OF_CENTER = 1ULL << 7; + const static uint64_t BACK_CENTER = 1ULL << 8; + const static uint64_t SIDE_LEFT = 1ULL << 9; + const static uint64_t SIDE_RIGHT = 1ULL << 10; + const static uint64_t TOP_CENTER = 1ULL << 11; + const static uint64_t TOP_FRONT_LEFT = 1ULL << 12; + const static uint64_t TOP_FRONT_CENTER = 1ULL << 13; + const static uint64_t TOP_FRONT_RIGHT = 1ULL << 14; + const static uint64_t TOP_BACK_LEFT = 1ULL << 15; + const static uint64_t TOP_BACK_CENTER = 1ULL << 16; + const static uint64_t TOP_BACK_RIGHT = 1ULL << 17; + const static uint64_t STEREO_LEFT = 1ULL << 18; + const static uint64_t STEREO_RIGHT = 1ULL << 19; + const static uint64_t WIDE_LEFT = 1ULL << 20; + const static uint64_t WIDE_RIGHT = 1ULL << 21; + const static uint64_t SURROUND_DIRECT_LEFT = 1ULL << 22; + const static uint64_t SURROUND_DIRECT_RIGHT = 1ULL << 23; + const static uint64_t LOW_FREQUENCY_2 = 1ULL << 24; + const static uint64_t NATIVE = 1ULL << 25; + + const static uint64_t MONO = 1ULL << 26; + const static uint64_t STEREO = 1ULL << 27; + const static uint64_t _2POINT1 = 1ULL << 28; + const static uint64_t _2_1 = 1ULL << 29; + const static uint64_t SURROUND = 1ULL << 30; + const static uint64_t _3POINT1 = 1ULL << 31; + const static uint64_t _4POINT0 = 1ULL << 32; + const static uint64_t _4POINT1 = 1ULL << 33; + const static uint64_t _2_2 = 1ULL << 34; + const static uint64_t QUAD = 1ULL << 35; + const static uint64_t _5POINT0 = 1ULL << 36; + const static uint64_t _5POINT1 = 1ULL << 37; + const static uint64_t _5POINT0_BACK = 1ULL << 38; + const static uint64_t _5POINT1_BACK = 1ULL << 39; + const static uint64_t _6POINT0 = 1ULL << 40; + const static uint64_t _6POINT0_FRONT = 1ULL << 41; + const static uint64_t HEXAGONAL = 1ULL << 42; + const static uint64_t _6POINT1 = 1ULL << 43; + const static uint64_t _6POINT1_BACK = 1ULL << 44; + const static uint64_t _6POINT1_FRONT = 1ULL << 45; + const static uint64_t _7POINT0 = 1ULL << 46; + const static uint64_t _7POINT0_FRONT = 1ULL << 47; + const static uint64_t _7POINT1 = 1ULL << 48; + const static uint64_t _7POINT1_WIDE = 1ULL << 49; + const static uint64_t _7POINT1_WIDE_BACK = 1ULL << 50; + const static uint64_t OCTAGONAL = 1ULL << 51; + const static uint64_t HEXADECAGONAL = 1ULL << 52; + const static uint64_t STEREO_DOWNMIX = 1ULL << 53; + +public: + // Check This function. (Looks good, test it) + static uint64_t fromChannelLayoutID(uint64_t ChannelLayout) { + uint64_t Channel = 0UL; + if (ChannelLayout & FRONT_LEFT) + Channel |= AV_CH_FRONT_LEFT; + if (ChannelLayout & FRONT_RIGHT) + Channel |= AV_CH_FRONT_RIGHT; + if (ChannelLayout & FRONT_CENTER) + Channel |= AV_CH_FRONT_CENTER; + if (ChannelLayout & LOW_FREQUENCY) + Channel |= AV_CH_LOW_FREQUENCY; + if (ChannelLayout & BACK_LEFT) + Channel |= AV_CH_BACK_LEFT; + if (ChannelLayout & BACK_RIGHT) + Channel |= AV_CH_BACK_RIGHT; + if (ChannelLayout & FRONT_LEFT_OF_CENTER) + Channel |= AV_CH_FRONT_LEFT_OF_CENTER; + if (ChannelLayout & FRONT_RIGHT_OF_CENTER) + Channel |= AV_CH_FRONT_RIGHT_OF_CENTER; + if (ChannelLayout & BACK_CENTER) + Channel |= AV_CH_BACK_CENTER; + if (ChannelLayout & SIDE_LEFT) + Channel |= AV_CH_SIDE_LEFT; + if (ChannelLayout & SIDE_RIGHT) + Channel |= AV_CH_SIDE_RIGHT; + if (ChannelLayout & TOP_CENTER) + Channel |= AV_CH_TOP_CENTER; + if (ChannelLayout & TOP_FRONT_LEFT) + Channel |= AV_CH_TOP_FRONT_LEFT; + if (ChannelLayout & TOP_FRONT_CENTER) + Channel |= AV_CH_TOP_FRONT_CENTER; + if (ChannelLayout & TOP_FRONT_RIGHT) + Channel |= AV_CH_TOP_FRONT_RIGHT; + if (ChannelLayout & TOP_BACK_LEFT) + Channel |= AV_CH_TOP_BACK_LEFT; + if (ChannelLayout & TOP_BACK_CENTER) + Channel |= AV_CH_TOP_BACK_CENTER; + if (ChannelLayout & TOP_BACK_RIGHT) + Channel |= AV_CH_TOP_BACK_RIGHT; + if (ChannelLayout & STEREO_LEFT) + Channel |= AV_CH_STEREO_LEFT; + if (ChannelLayout & STEREO_RIGHT) + Channel |= AV_CH_STEREO_RIGHT; + if (ChannelLayout & WIDE_LEFT) + Channel |= AV_CH_WIDE_LEFT; + if (ChannelLayout & WIDE_RIGHT) + Channel |= AV_CH_WIDE_RIGHT; + if (ChannelLayout & SURROUND_DIRECT_LEFT) + Channel |= AV_CH_SURROUND_DIRECT_LEFT; + if (ChannelLayout & SURROUND_DIRECT_RIGHT) + Channel |= AV_CH_SURROUND_DIRECT_RIGHT; + if (ChannelLayout & LOW_FREQUENCY_2) + Channel |= AV_CH_LOW_FREQUENCY_2; + if (ChannelLayout & NATIVE) + Channel |= AV_CH_LAYOUT_NATIVE; + if (ChannelLayout & MONO) + Channel |= AV_CH_LAYOUT_MONO; + if (ChannelLayout & STEREO) + Channel |= AV_CH_LAYOUT_STEREO; + if (ChannelLayout & _2POINT1) + Channel |= AV_CH_LAYOUT_2POINT1; + if (ChannelLayout & _2_1) + Channel |= AV_CH_LAYOUT_2_1; + if (ChannelLayout & SURROUND) + Channel |= AV_CH_LAYOUT_SURROUND; + if (ChannelLayout & _3POINT1) + Channel |= AV_CH_LAYOUT_3POINT1; + if (ChannelLayout & _4POINT0) + Channel |= AV_CH_LAYOUT_4POINT0; + if (ChannelLayout & _4POINT1) + Channel |= AV_CH_LAYOUT_4POINT1; + if (ChannelLayout & _2_2) + Channel |= AV_CH_LAYOUT_2_2; + if (ChannelLayout & QUAD) + Channel |= AV_CH_LAYOUT_QUAD; + if (ChannelLayout & _5POINT0) + Channel |= AV_CH_LAYOUT_5POINT0; + if (ChannelLayout & _5POINT1) + Channel |= AV_CH_LAYOUT_5POINT1; + if (ChannelLayout & _5POINT0_BACK) + Channel |= AV_CH_LAYOUT_5POINT0_BACK; + if (ChannelLayout & _5POINT1_BACK) + Channel |= AV_CH_LAYOUT_5POINT1_BACK; + if (ChannelLayout & _6POINT0) + Channel |= AV_CH_LAYOUT_6POINT0; + if (ChannelLayout & _6POINT0_FRONT) + Channel |= AV_CH_LAYOUT_6POINT0_FRONT; + if (ChannelLayout & HEXAGONAL) + Channel |= AV_CH_LAYOUT_HEXAGONAL; + if (ChannelLayout & _6POINT1) + Channel |= AV_CH_LAYOUT_6POINT1; + if (ChannelLayout & _6POINT1_BACK) + Channel |= AV_CH_LAYOUT_6POINT1_BACK; + if (ChannelLayout & _6POINT1_FRONT) + Channel |= AV_CH_LAYOUT_6POINT1_FRONT; + if (ChannelLayout & _7POINT0) + Channel |= AV_CH_LAYOUT_7POINT0; + if (ChannelLayout & _7POINT0_FRONT) + Channel |= AV_CH_LAYOUT_7POINT0_FRONT; + if (ChannelLayout & _7POINT1) + Channel |= AV_CH_LAYOUT_7POINT1; + if (ChannelLayout & _7POINT1_WIDE) + Channel |= AV_CH_LAYOUT_7POINT1_WIDE; + if (ChannelLayout & _7POINT1_WIDE_BACK) + Channel |= AV_CH_LAYOUT_7POINT1_WIDE_BACK; + if (ChannelLayout & OCTAGONAL) + Channel |= AV_CH_LAYOUT_OCTAGONAL; + if (ChannelLayout & HEXADECAGONAL) + Channel |= AV_CH_LAYOUT_HEXADECAGONAL; + if (ChannelLayout & STEREO_DOWNMIX) + Channel |= AV_CH_LAYOUT_STEREO_DOWNMIX; + return Channel; + } + + // Perfect Logic :) + static uint64_t intoChannelLayoutID(uint64_t ChannelLayout) { + uint64_t Channel = 0; + if ((ChannelLayout & AV_CH_FRONT_LEFT) == AV_CH_FRONT_LEFT) + Channel |= FRONT_LEFT; + if ((ChannelLayout & AV_CH_FRONT_RIGHT) == AV_CH_FRONT_RIGHT) + Channel |= FRONT_RIGHT; + if ((ChannelLayout & AV_CH_FRONT_CENTER) == AV_CH_FRONT_CENTER) + Channel |= FRONT_CENTER; + if ((ChannelLayout & AV_CH_LOW_FREQUENCY) == AV_CH_LOW_FREQUENCY) + Channel |= LOW_FREQUENCY; + if ((ChannelLayout & AV_CH_BACK_LEFT) == AV_CH_BACK_LEFT) + Channel |= BACK_LEFT; + if ((ChannelLayout & AV_CH_BACK_RIGHT) == AV_CH_BACK_RIGHT) + Channel |= BACK_RIGHT; + if ((ChannelLayout & AV_CH_FRONT_LEFT_OF_CENTER) == + AV_CH_FRONT_LEFT_OF_CENTER) + Channel |= FRONT_LEFT_OF_CENTER; + if ((ChannelLayout & AV_CH_FRONT_RIGHT_OF_CENTER) == + AV_CH_FRONT_RIGHT_OF_CENTER) + Channel |= FRONT_RIGHT_OF_CENTER; + if ((ChannelLayout & AV_CH_BACK_CENTER) == AV_CH_BACK_CENTER) + Channel |= BACK_CENTER; + if ((ChannelLayout & AV_CH_SIDE_LEFT) == AV_CH_SIDE_LEFT) + Channel |= SIDE_LEFT; + if ((ChannelLayout & AV_CH_SIDE_RIGHT) == AV_CH_SIDE_RIGHT) + Channel |= SIDE_RIGHT; + if ((ChannelLayout & AV_CH_TOP_CENTER) == AV_CH_TOP_CENTER) + Channel |= TOP_CENTER; + if ((ChannelLayout & AV_CH_TOP_FRONT_LEFT) == AV_CH_TOP_FRONT_LEFT) + Channel |= TOP_FRONT_LEFT; + if ((ChannelLayout & AV_CH_TOP_FRONT_CENTER) == AV_CH_TOP_FRONT_CENTER) + Channel |= TOP_FRONT_CENTER; + if ((ChannelLayout & AV_CH_TOP_FRONT_RIGHT) == AV_CH_TOP_FRONT_RIGHT) + Channel |= TOP_FRONT_RIGHT; + if ((ChannelLayout & AV_CH_TOP_BACK_LEFT) == AV_CH_TOP_BACK_LEFT) + Channel |= TOP_BACK_LEFT; + if ((ChannelLayout & AV_CH_TOP_BACK_CENTER) == AV_CH_TOP_BACK_CENTER) + Channel |= TOP_BACK_CENTER; + if ((ChannelLayout & AV_CH_TOP_BACK_RIGHT) == AV_CH_TOP_BACK_RIGHT) + Channel |= TOP_BACK_RIGHT; + if ((ChannelLayout & AV_CH_STEREO_LEFT) == AV_CH_STEREO_LEFT) + Channel |= STEREO_LEFT; + if ((ChannelLayout & AV_CH_STEREO_RIGHT) == AV_CH_STEREO_RIGHT) + Channel |= STEREO_RIGHT; + if ((ChannelLayout & AV_CH_WIDE_LEFT) == AV_CH_WIDE_LEFT) + Channel |= WIDE_LEFT; + if ((ChannelLayout & AV_CH_WIDE_RIGHT) == AV_CH_WIDE_RIGHT) + Channel |= WIDE_RIGHT; + if ((ChannelLayout & AV_CH_SURROUND_DIRECT_LEFT) == + AV_CH_SURROUND_DIRECT_LEFT) + Channel |= SURROUND_DIRECT_LEFT; + if ((ChannelLayout & AV_CH_SURROUND_DIRECT_RIGHT) == + AV_CH_SURROUND_DIRECT_RIGHT) + Channel |= SURROUND_DIRECT_RIGHT; + if ((ChannelLayout & AV_CH_LOW_FREQUENCY_2) == AV_CH_LOW_FREQUENCY_2) + Channel |= LOW_FREQUENCY_2; + + // Channel Mask C; + if ((ChannelLayout & AV_CH_LAYOUT_NATIVE) == AV_CH_LAYOUT_NATIVE) + Channel |= NATIVE; + if ((ChannelLayout & AV_CH_LAYOUT_MONO) == AV_CH_LAYOUT_MONO) + Channel |= MONO; + if ((ChannelLayout & AV_CH_LAYOUT_STEREO) == AV_CH_LAYOUT_STEREO) + Channel |= STEREO; + if ((ChannelLayout & AV_CH_LAYOUT_2POINT1) == AV_CH_LAYOUT_2POINT1) + Channel |= _2POINT1; + if ((ChannelLayout & AV_CH_LAYOUT_2_1) == AV_CH_LAYOUT_2_1) + Channel |= _2_1; + if ((ChannelLayout & AV_CH_LAYOUT_SURROUND) == AV_CH_LAYOUT_SURROUND) + Channel |= SURROUND; + if ((ChannelLayout & AV_CH_LAYOUT_3POINT1) == AV_CH_LAYOUT_3POINT1) + Channel |= _3POINT1; + if ((ChannelLayout & AV_CH_LAYOUT_4POINT0) == AV_CH_LAYOUT_4POINT0) + Channel |= _4POINT0; + if ((ChannelLayout & AV_CH_LAYOUT_4POINT1) == AV_CH_LAYOUT_4POINT1) + Channel |= _4POINT1; + if ((ChannelLayout & AV_CH_LAYOUT_2_2) == AV_CH_LAYOUT_2_2) + Channel |= _2_2; + if ((ChannelLayout & AV_CH_LAYOUT_QUAD) == AV_CH_LAYOUT_QUAD) + Channel |= QUAD; + if ((ChannelLayout & AV_CH_LAYOUT_5POINT0) == AV_CH_LAYOUT_5POINT0) + Channel |= _5POINT0; + if ((ChannelLayout & AV_CH_LAYOUT_5POINT1) == AV_CH_LAYOUT_5POINT1) + Channel |= _5POINT1; + if ((ChannelLayout & AV_CH_LAYOUT_5POINT0_BACK) == + AV_CH_LAYOUT_5POINT0_BACK) + Channel |= _5POINT0_BACK; + if ((ChannelLayout & AV_CH_LAYOUT_5POINT1_BACK) == + AV_CH_LAYOUT_5POINT1_BACK) + Channel |= _5POINT1_BACK; + if ((ChannelLayout & AV_CH_LAYOUT_6POINT0) == AV_CH_LAYOUT_6POINT0) + Channel |= _6POINT0; + if ((ChannelLayout & AV_CH_LAYOUT_6POINT0_FRONT) == + AV_CH_LAYOUT_6POINT0_FRONT) + Channel |= _6POINT0_FRONT; + if ((ChannelLayout & AV_CH_LAYOUT_HEXAGONAL) == AV_CH_LAYOUT_HEXAGONAL) + Channel |= HEXAGONAL; + if ((ChannelLayout & AV_CH_LAYOUT_6POINT1) == AV_CH_LAYOUT_6POINT1) + Channel |= _6POINT1; + if ((ChannelLayout & AV_CH_LAYOUT_6POINT1_BACK) == + AV_CH_LAYOUT_6POINT1_BACK) + Channel |= _6POINT1_BACK; + if ((ChannelLayout & AV_CH_LAYOUT_6POINT1_FRONT) == + AV_CH_LAYOUT_6POINT1_FRONT) + Channel |= _6POINT1_FRONT; + if ((ChannelLayout & AV_CH_LAYOUT_7POINT0) == AV_CH_LAYOUT_7POINT0) + Channel |= _7POINT0; + if ((ChannelLayout & AV_CH_LAYOUT_7POINT0_FRONT) == + AV_CH_LAYOUT_7POINT0_FRONT) + Channel |= _7POINT0_FRONT; + if ((ChannelLayout & AV_CH_LAYOUT_7POINT1) == AV_CH_LAYOUT_7POINT1) + Channel |= _7POINT1; + if ((ChannelLayout & AV_CH_LAYOUT_7POINT1_WIDE) == + AV_CH_LAYOUT_7POINT1_WIDE) + Channel |= _7POINT1_WIDE; + if ((ChannelLayout & AV_CH_LAYOUT_7POINT1_WIDE_BACK) == + AV_CH_LAYOUT_7POINT1_WIDE_BACK) + Channel |= _7POINT1_WIDE_BACK; + if ((ChannelLayout & AV_CH_LAYOUT_OCTAGONAL) == AV_CH_LAYOUT_OCTAGONAL) + Channel |= OCTAGONAL; + if ((ChannelLayout & AV_CH_LAYOUT_HEXADECAGONAL) == + AV_CH_LAYOUT_HEXADECAGONAL) + Channel |= HEXADECAGONAL; + if ((ChannelLayout & AV_CH_LAYOUT_STEREO_DOWNMIX) == + AV_CH_LAYOUT_STEREO_DOWNMIX) + Channel |= STEREO_DOWNMIX; + return Channel; + } +}; + +class SWRFilterType { +public: + uint32_t fromSwrFilterType(SwrFilterType FilterType) { + switch (FilterType) { + case SWR_FILTER_TYPE_CUBIC: + return 1; + case SWR_FILTER_TYPE_BLACKMAN_NUTTALL: + return 2; + case SWR_FILTER_TYPE_KAISER: + return 3; + default: + return 1; + } + } + + SwrFilterType intoSwrFilterType(uint32_t FilterID) { + switch (FilterID) { + case 1: + return SWR_FILTER_TYPE_CUBIC; + case 2: + return SWR_FILTER_TYPE_BLACKMAN_NUTTALL; + case 3: + return SWR_FILTER_TYPE_KAISER; + default: + return SWR_FILTER_TYPE_CUBIC; + } + } +}; + +class SWREngine { +public: + SwrEngine intoSwrEngine(uint32_t EngineId) { + switch (EngineId) { + case 1: + return SWR_ENGINE_SWR; + case 2: + return SWR_ENGINE_SOXR; + default: + return SWR_ENGINE_SWR; + } + } + + uint32_t fromSwrEngine(SwrEngine Engine) { + switch (Engine) { + case SWR_ENGINE_SWR: + return 1; + case SWR_ENGINE_SOXR: + return 2; + case SWR_ENGINE_NB: + return 3; + default: + return SWR_ENGINE_SWR; + } + } +}; + +class SWRDitherType { +public: + SwrDitherType intoSwrDitherType(uint32_t SwrDitherId) { + switch (SwrDitherId) { + case 0: + return SWR_DITHER_NONE; + case 1: + return SWR_DITHER_RECTANGULAR; + case 2: + return SWR_DITHER_TRIANGULAR; + case 3: + return SWR_DITHER_TRIANGULAR_HIGHPASS; + case 64: + return SWR_DITHER_NS; + case 4: + return SWR_DITHER_NS_LIPSHITZ; + case 5: + return SWR_DITHER_NS_F_WEIGHTED; + case 6: + return SWR_DITHER_NS_MODIFIED_E_WEIGHTED; + case 7: + return SWR_DITHER_NS_IMPROVED_E_WEIGHTED; + case 8: + return SWR_DITHER_NS_SHIBATA; + case 9: + return SWR_DITHER_NS_LOW_SHIBATA; + case 10: + return SWR_DITHER_NS_HIGH_SHIBATA; + case 11: + return SWR_DITHER_NB; + default: + return SWR_DITHER_NONE; + } + } + + uint32_t fromSwrDitherType(SwrDitherType SwrDitherType) { + switch (SwrDitherType) { + case SWR_DITHER_NONE: + return 0; + case SWR_DITHER_RECTANGULAR: + return 1; + case SWR_DITHER_TRIANGULAR: + return 2; + case SWR_DITHER_TRIANGULAR_HIGHPASS: + return 3; + case SWR_DITHER_NS: + return 64; + case SWR_DITHER_NS_LIPSHITZ: + return 4; + case SWR_DITHER_NS_F_WEIGHTED: + return 5; + case SWR_DITHER_NS_MODIFIED_E_WEIGHTED: + return 6; + case SWR_DITHER_NS_IMPROVED_E_WEIGHTED: + return 7; + case SWR_DITHER_NS_SHIBATA: + return 8; + case SWR_DITHER_NS_LOW_SHIBATA: + return 9; + case SWR_DITHER_NS_HIGH_SHIBATA: + return 10; + case SWR_DITHER_NB: + return 11; + default: + return 0; + } + } +}; + +class ChromaLocation { +public: + static AVChromaLocation intoAVChromaLocation(int32_t ChromaLocationId) { + switch (ChromaLocationId) { + case 0: + return AVCHROMA_LOC_UNSPECIFIED; + case 1: + return AVCHROMA_LOC_LEFT; + case 2: + return AVCHROMA_LOC_CENTER; + case 3: + return AVCHROMA_LOC_TOPLEFT; + case 4: + return AVCHROMA_LOC_TOP; + case 5: + return AVCHROMA_LOC_BOTTOMLEFT; + case 6: + return AVCHROMA_LOC_BOTTOM; + default: + return AVCHROMA_LOC_UNSPECIFIED; + } + } + + static int32_t fromAVChromaLocation(AVChromaLocation ChromaLocation) { + switch (ChromaLocation) { + case AVCHROMA_LOC_UNSPECIFIED: + return 0; + case AVCHROMA_LOC_LEFT: + return 1; + case AVCHROMA_LOC_CENTER: + return 2; + case AVCHROMA_LOC_TOPLEFT: + return 3; + case AVCHROMA_LOC_TOP: + return 4; + case AVCHROMA_LOC_BOTTOMLEFT: + return 5; + case AVCHROMA_LOC_BOTTOM: + return 6; + default: + return 0; + } + } +}; + +class Rounding { + +public: + static AVRounding intoAVRounding(int32_t RoundingId) { + switch (RoundingId) { + case 0: + return AV_ROUND_ZERO; + case 1: + return AV_ROUND_INF; + case 2: + return AV_ROUND_DOWN; + case 3: + return AV_ROUND_UP; + case 4: + return AV_ROUND_NEAR_INF; + case 5: + return AV_ROUND_PASS_MINMAX; + default: + return AV_ROUND_ZERO; + } + } + + static int32_t fromAVRounding(AVRounding Rounding) { + switch (Rounding) { + case AV_ROUND_ZERO: + return 0; + case AV_ROUND_INF: + return 1; + case AV_ROUND_DOWN: + return 2; + case AV_ROUND_UP: + return 3; + case AV_ROUND_NEAR_INF: + return 4; + case AV_ROUND_PASS_MINMAX: + return 5; + default: + return 0; + } + } +}; + +class OptionType { + +public: + static AVOptionType intoAVOptionType(int32_t RoundingId) { + switch (RoundingId) { + case 0: + return AV_OPT_TYPE_FLAGS; + case 1: + return AV_OPT_TYPE_INT; + case 2: + return AV_OPT_TYPE_INT64; + case 3: + return AV_OPT_TYPE_DOUBLE; + case 4: + return AV_OPT_TYPE_FLOAT; + case 5: + return AV_OPT_TYPE_STRING; + case 6: + return AV_OPT_TYPE_RATIONAL; + case 7: + return AV_OPT_TYPE_BINARY; + case 8: + return AV_OPT_TYPE_DICT; + case 9: + return AV_OPT_TYPE_CONST; + case 10: + return AV_OPT_TYPE_IMAGE_SIZE; + case 11: + return AV_OPT_TYPE_PIXEL_FMT; + case 12: + return AV_OPT_TYPE_SAMPLE_FMT; + case 13: + return AV_OPT_TYPE_VIDEO_RATE; + case 14: + return AV_OPT_TYPE_DURATION; + case 15: + return AV_OPT_TYPE_COLOR; + case 16: + return AV_OPT_TYPE_CHANNEL_LAYOUT; + case 17: + return AV_OPT_TYPE_UINT64; + case 18: + return AV_OPT_TYPE_BOOL; + case 19: + return AV_OPT_TYPE_CHLAYOUT; + default: + return AV_OPT_TYPE_FLAGS; + } + } + + static int32_t fromAVOptionType(AVOptionType OptionType) { + switch (OptionType) { + case AV_OPT_TYPE_FLAGS: + return 0; + case AV_OPT_TYPE_INT: + return 1; + case AV_OPT_TYPE_INT64: + return 2; + case AV_OPT_TYPE_DOUBLE: + return 3; + case AV_OPT_TYPE_FLOAT: + return 4; + case AV_OPT_TYPE_STRING: + return 5; + case AV_OPT_TYPE_RATIONAL: + return 6; + case AV_OPT_TYPE_BINARY: + return 7; + case AV_OPT_TYPE_DICT: + return 8; + case AV_OPT_TYPE_CONST: + return 9; + case AV_OPT_TYPE_IMAGE_SIZE: + return 10; + case AV_OPT_TYPE_PIXEL_FMT: + return 11; + case AV_OPT_TYPE_SAMPLE_FMT: + return 12; + case AV_OPT_TYPE_VIDEO_RATE: + return 13; + case AV_OPT_TYPE_DURATION: + return 14; + case AV_OPT_TYPE_COLOR: + return 15; + case AV_OPT_TYPE_CHANNEL_LAYOUT: + return 16; + case AV_OPT_TYPE_UINT64: + return 17; + case AV_OPT_TYPE_BOOL: + return 18; + case AV_OPT_TYPE_CHLAYOUT: + return 19; + default: + return 0; + } + } +}; + +class PictureType { +public: + static AVPictureType intoAVPictureType(int32_t PictureId) { + switch (PictureId) { + case 0: + return AV_PICTURE_TYPE_NONE; + case 1: + return AV_PICTURE_TYPE_I; + case 2: + return AV_PICTURE_TYPE_P; + case 3: + return AV_PICTURE_TYPE_B; + case 4: + return AV_PICTURE_TYPE_S; + case 5: + return AV_PICTURE_TYPE_SI; + case 6: + return AV_PICTURE_TYPE_SP; + case 7: + return AV_PICTURE_TYPE_BI; + default: + return AV_PICTURE_TYPE_NONE; + } + }; + + static int32_t fromAVPictureType(AVPictureType PictureType) { + switch (PictureType) { + case AV_PICTURE_TYPE_NONE: + return 0; + case AV_PICTURE_TYPE_I: + return 1; + case AV_PICTURE_TYPE_P: + return 2; + case AV_PICTURE_TYPE_B: + return 3; + case AV_PICTURE_TYPE_S: + return 4; + case AV_PICTURE_TYPE_SI: + return 5; + case AV_PICTURE_TYPE_SP: + return 6; + case AV_PICTURE_TYPE_BI: + return 7; + default: + return 0; + } + } +}; + +// Direct mapping in rust. Not required. Can be used for decoupling (Clean +// Code). +// +// class ColorTransferCharacteristic { +// +// static AVColorTransferCharacteristic +// intoColorTransferCharacteristic(uint32_t ColorTransferCharacteristicId) { +// switch (ColorTransferCharacteristicId) { +// case 0: +// return AVCOL_TRC_RESERVED0; +// case 1: +// return AVCOL_TRC_BT709; +// case 2: +// return AVCOL_TRC_UNSPECIFIED; +// case 3: +// return AVCOL_TRC_RESERVED; +// case 4: +// return AVCOL_TRC_GAMMA22; +// case 5: +// return AVCOL_TRC_GAMMA28; +// case 6: +// return AVCOL_TRC_SMPTE170M; +// case 7: +// return AVCOL_TRC_SMPTE240M; +// case 8: +// return AVCOL_TRC_LINEAR; +// case 9: +// return AVCOL_TRC_LOG; +// case 10: +// return AVCOL_TRC_LOG_SQRT; +// case 11: +// return AVCOL_TRC_IEC61966_2_4; +// case 12: +// return AVCOL_TRC_BT1361_ECG; +// case 13: +// return AVCOL_TRC_IEC61966_2_1; +// case 14: +// return AVCOL_TRC_BT2020_10; +// case 15: +// return AVCOL_TRC_BT2020_12; +// case 16: +// return AVCOL_TRC_SMPTE2084; +// case 17: +// return AVCOL_TRC_SMPTE428; +// case 18: +// return AVCOL_TRC_ARIB_STD_B67; +// case 19: +// return AVCOL_TRC_NB; +// default: +// return AVCOL_TRC_RESERVED0; +// } +// }; +// +// static uint32_t +// fromColorTransferCharacteristic(uint32_t ColorTransferCharacteristic) { +// switch (ColorTransferCharacteristic) { +// case AVCOL_TRC_RESERVED0: +// return 0; +// case AVCOL_TRC_BT709: +// return 1; +// case AVCOL_TRC_UNSPECIFIED: +// return 2; +// case AVCOL_TRC_RESERVED: +// return 3; +// case AVCOL_TRC_GAMMA22: +// return 4; +// case AVCOL_TRC_GAMMA28: +// return 5; +// case AVCOL_TRC_SMPTE170M: +// return 6; +// case AVCOL_TRC_SMPTE240M: +// return 7; +// case AVCOL_TRC_LINEAR: +// return 8; +// case AVCOL_TRC_LOG: +// return 9; +// case AVCOL_TRC_LOG_SQRT: +// return 10; +// case AVCOL_TRC_IEC61966_2_4: +// return 11; +// case AVCOL_TRC_BT1361_ECG: +// return 12; +// case AVCOL_TRC_IEC61966_2_1: +// return 13; +// case AVCOL_TRC_BT2020_10: +// return 14; +// case AVCOL_TRC_BT2020_12: +// return 15; +// case AVCOL_TRC_SMPTE2084: +// return 16; +// case AVCOL_TRC_SMPTE428: +// return 17; +// case AVCOL_TRC_ARIB_STD_B67: +// return 18; +// case AVCOL_TRC_NB: +// return 19; +// default: +// return 0; +// } +// }; +//}; + +// We can keep or remove the binding. +class ColorSpace { + +public: + static AVColorSpace intoAVColorSpace(int32_t ColorSpaceId) { + + switch (ColorSpaceId) { + case 0: + return AVCOL_SPC_RGB; + case 1: + return AVCOL_SPC_BT709; + case 2: + return AVCOL_SPC_UNSPECIFIED; + case 3: + return AVCOL_SPC_RESERVED; + case 4: + return AVCOL_SPC_FCC; + case 5: + return AVCOL_SPC_BT470BG; + case 6: + return AVCOL_SPC_SMPTE170M; + case 7: + return AVCOL_SPC_SMPTE240M; + case 8: + return AVCOL_SPC_YCGCO; + case 9: + return AVCOL_SPC_BT2020_NCL; + case 10: + return AVCOL_SPC_BT2020_CL; + case 11: + return AVCOL_SPC_SMPTE2085; + case 12: + return AVCOL_SPC_CHROMA_DERIVED_NCL; + case 13: + return AVCOL_SPC_CHROMA_DERIVED_CL; + case 14: + return AVCOL_SPC_ICTCP; + default: + return AVCOL_SPC_RGB; + } + }; + + static int32_t fromAVColorSpace(AVColorSpace ColorSpace) { + + switch (ColorSpace) { + case AVCOL_SPC_RGB: + return 0; + case AVCOL_SPC_BT709: + return 1; + case AVCOL_SPC_UNSPECIFIED: + return 2; + case AVCOL_SPC_RESERVED: + return 3; + case AVCOL_SPC_FCC: + return 4; + case AVCOL_SPC_BT470BG: + return 5; + case AVCOL_SPC_SMPTE170M: + return 6; + case AVCOL_SPC_SMPTE240M: + return 7; + case AVCOL_SPC_YCGCO: + return 8; + case AVCOL_SPC_BT2020_NCL: + return 9; + case AVCOL_SPC_BT2020_CL: + return 10; + case AVCOL_SPC_SMPTE2085: + return 11; + case AVCOL_SPC_CHROMA_DERIVED_NCL: + return 12; + case AVCOL_SPC_CHROMA_DERIVED_CL: + return 13; + case AVCOL_SPC_ICTCP: + return 14; + default: + return 0; + } + }; +}; + +class FieldOrder { +public: + static AVFieldOrder intoAVFieldOrder(int32_t FieldOrderId) { + switch (FieldOrderId) { + case 0: + return AV_FIELD_UNKNOWN; + case 1: + return AV_FIELD_PROGRESSIVE; + case 2: + return AV_FIELD_TT; + case 3: + return AV_FIELD_BB; + case 4: + return AV_FIELD_TB; + case 5: + return AV_FIELD_BT; + default: + return AV_FIELD_UNKNOWN; + } + } + + static int32_t fromAVFieldOrder(AVFieldOrder FieldOrder) { + switch (FieldOrder) { + case AV_FIELD_UNKNOWN: + return 0; + case AV_FIELD_PROGRESSIVE: + return 1; + case AV_FIELD_TT: + return 2; + case AV_FIELD_BB: + return 3; + case AV_FIELD_TB: + return 4; + case AV_FIELD_BT: + return 5; + default: + return 0; + } + } +}; + +class ColorPrimaries { + +public: + static AVColorPrimaries intoAVColorPrimaries(int32_t ColorPrimariesId) { + switch (ColorPrimariesId) { + case 0: + return AVCOL_PRI_RESERVED0; + case 1: + return AVCOL_PRI_BT709; + case 2: + return AVCOL_PRI_UNSPECIFIED; + case 3: + return AVCOL_PRI_RESERVED; + case 4: + return AVCOL_PRI_BT470M; + case 5: + return AVCOL_PRI_BT470BG; + case 6: + return AVCOL_PRI_SMPTE170M; + case 7: + return AVCOL_PRI_SMPTE240M; + case 8: + return AVCOL_PRI_FILM; + case 9: + return AVCOL_PRI_BT2020; + case 10: + return AVCOL_PRI_SMPTE428; + case 11: + return AVCOL_PRI_SMPTE431; + case 12: + return AVCOL_PRI_SMPTE432; + case 13: + return AVCOL_PRI_JEDEC_P22; + case 14: + return AVCOL_PRI_EBU3213; + default: + return AVCOL_PRI_RESERVED0; + } + }; + + static int32_t fromAVColorPrimaries(AVColorPrimaries ColorPrimaries) { + switch (ColorPrimaries) { + case AVCOL_PRI_RESERVED0: + return 0; + case AVCOL_PRI_BT709: + return 1; + case AVCOL_PRI_UNSPECIFIED: + return 2; + case AVCOL_PRI_RESERVED: + return 3; + case AVCOL_PRI_BT470M: + return 4; + case AVCOL_PRI_BT470BG: + return 5; + case AVCOL_PRI_SMPTE170M: + return 6; + case AVCOL_PRI_SMPTE240M: + return 7; + case AVCOL_PRI_FILM: + return 8; + case AVCOL_PRI_BT2020: + return 9; + case AVCOL_PRI_SMPTE428: + return 10; + case AVCOL_PRI_SMPTE431: + return 11; + case AVCOL_PRI_SMPTE432: + return 12; + // #[cfg(not(feature = "ffmpeg_4_3"))] + // case AVCOL_PRI_JEDEC_P22: + // return 13; + case AVCOL_PRI_EBU3213: + return 14; + default: + return 0; + } + }; +}; + +} // namespace FFmpegUtils +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp b/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp new file mode 100644 index 00000000..f2705c4a --- /dev/null +++ b/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp @@ -0,0 +1,112 @@ +#include "ffmpeg_env.h" +#include "avcodec/module.h" +#include "avdevice/module.h" +#include "avfilter/module.h" +#include "avformat/module.h" +#include "avutil/module.h" +#include "swresample/module.h" +#include "swscale/module.h" + +namespace WasmEdge { +namespace Host { +namespace { + +Runtime::Instance::ModuleInstance * +createAVCodec(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::AVcodec::WasmEdgeFFmpegAVCodecModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Runtime::Instance::ModuleInstance * +createAVDevice(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::AVDevice::WasmEdgeFFmpegAVDeviceModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Runtime::Instance::ModuleInstance * +createAVFilter(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::AVFilter::WasmEdgeFFmpegAVFilterModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Runtime::Instance::ModuleInstance * +createAVFormat(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::AVFormat::WasmEdgeFFmpegAVFormatModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Runtime::Instance::ModuleInstance * +createAVUtil(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::AVUtil::WasmEdgeFFmpegAVUtilModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Runtime::Instance::ModuleInstance * +createSWScale(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::SWScale::WasmEdgeFFmpegSWScaleModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Runtime::Instance::ModuleInstance * +createSWResample(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::SWResample::WasmEdgeFFmpegSWResampleModule( + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_ffmpeg", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 0, 0, 1}, + .ModuleCount = 7, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_ffmpeg_avcodec", + .Description = "encoding/decoding library", + .Create = createAVCodec, + }, + { + .Name = "wasmedge_ffmpeg_avdevice", + .Description = "special devices muxing/demuxing library ", + .Create = createAVDevice, + }, + { + .Name = "wasmedge_ffmpeg_avfilter", + .Description = "graph-based frame editing library", + .Create = createAVFilter, + }, + { + .Name = "wasmedge_ffmpeg_avformat", + .Description = "I/O and muxing/demuxing library", + .Create = createAVFormat, + }, + { + .Name = "wasmedge_ffmpeg_avutil", + .Description = "utils utility library", + .Create = createAVUtil, + }, + { + .Name = "wasmedge_ffmpeg_swresample", + .Description = "audio resampling, format conversion and mixing", + .Create = createSWResample, + }, + { + .Name = "wasmedge_ffmpeg_swscale", + .Description = "color conversion and scaling library", + .Create = createSWScale, + }}, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace + +std::weak_ptr + WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::Instance = + std::make_shared(); + +std::shared_mutex WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::Mutex; +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_ffmpeg/ffmpeg_env.h b/plugins/wasmedge_ffmpeg/ffmpeg_env.h new file mode 100644 index 00000000..e584fc5d --- /dev/null +++ b/plugins/wasmedge_ffmpeg/ffmpeg_env.h @@ -0,0 +1,110 @@ +#pragma once + +#include "bindings.h" +#include "plugin/plugin.h" + +#include "vector" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +class WasmEdgeFFmpegEnv { +public: + // Singleton + static std::shared_ptr getInstance() noexcept { + std::unique_lock Lock(Mutex); + std::shared_ptr EnvPtr = Instance.lock(); + if (!EnvPtr) { + EnvPtr.reset(new WasmEdgeFFmpegEnv()); + Instance = EnvPtr; + } + return EnvPtr; + } + + // Avoid copy constructor and overloading functions. + WasmEdgeFFmpegEnv(const WasmEdgeFFmpegEnv &) = delete; + void operator=(const WasmEdgeFFmpegEnv &) = delete; + + void alloc(void *Data, uint32_t *DataPtr) { + FfmpegPtrMap[FfmpegPtrAllocateKey++] = Data; + *DataPtr = FfmpegPtrAllocateKey - 1; + } + + void *fetchData(const size_t Index) { + if (Index >= FfmpegPtrAllocateKey) { + return nullptr; + } + // Check this condition. + if (FfmpegPtrMap[Index] == nullptr) { + return nullptr; + } + + return FfmpegPtrMap[Index]; + } + + void dealloc(size_t Index) { + + if (Index >= FfmpegPtrAllocateKey) { + return; + } + + FfmpegPtrMap.erase(Index); + } + + WasmEdgeFFmpegEnv() noexcept {} + +private: + // Using zero as NULL Value. + uint32_t FfmpegPtrAllocateKey = 1; + // Can update this to uint64_t to get more memory. + std::map FfmpegPtrMap; + static std::weak_ptr Instance; + static std::shared_mutex Mutex; +}; + +// Utils functions. +#define MEMINST_CHECK(Out, CallFrame, Index) \ + auto *Out = CallFrame.getMemoryByIndex(Index); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-FFmpeg] Memory instance not found."sv); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define FFMPEG_PTR_FETCH(StructPtr, FFmpegStructId, Type) \ + Type *StructPtr = nullptr; \ + if (FFmpegStructId != 0) \ + StructPtr = static_cast(Env.get()->fetchData(FFmpegStructId)); + +#define MEM_SPAN_CHECK(OutSpan, MemInst, Type, BufPtr, BufLen, Message) \ + auto OutSpan = MemInst->getSpan(BufPtr, BufLen); \ + if (unlikely(OutSpan.size() != BufLen)) { \ + spdlog::error("[WasmEdge-FFmpeg] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define FFMPEG_PTR_STORE(StructPtr, FFmpegStructId) \ + Env.get()->alloc(StructPtr, FFmpegStructId); + +#define FFMPEG_PTR_DELETE(FFmpegStructId) Env.get()->dealloc(FFmpegStructId); + +#define MEM_PTR_CHECK(OutPtr, MemInst, Type, Offset, Message) \ + Type *OutPtr = MemInst->getPointerOrNull(Offset); \ + if (unlikely(OutPtr == nullptr)) { \ + spdlog::error("[WasmEdge-FFmpeg] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +// Starting from 200 because, posix codes take values till 131. +// Hence using 200. +enum class ErrNo : int32_t { + Success = 0, // No error occurred. + MissingMemory = -201, // Caller module is missing a memory export. + NullStructId = -202, // Rust Sdk Passes null id. + InternalError = -203, + UnImplemented = -204 // Unimplemented funcs. +}; + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_ffmpeg/swresample/module.cpp b/plugins/wasmedge_ffmpeg/swresample/module.cpp new file mode 100644 index 00000000..00d617db --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swresample/module.cpp @@ -0,0 +1,40 @@ +#include "module.h" +#include "swresample_func.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWResample { + +WasmEdgeFFmpegSWResampleModule::WasmEdgeFFmpegSWResampleModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_swresample") { + + addHostFunc("wasmedge_ffmpeg_swresample_swresample_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swr_get_delay", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swr_init", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swr_alloc_set_opts", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_av_opt_set_dict", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swr_convert_frame", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swr_free", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swresample_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swresample_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swresample_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swresample_swresample_license", + std::make_unique(Env)); +} + +} // namespace SWResample +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swresample/module.h b/plugins/wasmedge_ffmpeg/swresample/module.h new file mode 100644 index 00000000..a47d966b --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swresample/module.h @@ -0,0 +1,20 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWResample { + +class WasmEdgeFFmpegSWResampleModule + : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegSWResampleModule(std::shared_ptr Env); +}; + +} // namespace SWResample +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_base.h b/plugins/wasmedge_ffmpeg/swresample/swresample_base.h new file mode 100644 index 00000000..574dcd20 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swresample/swresample_base.h @@ -0,0 +1,25 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWResample { + +template +class WasmEdgeFFmpegSWResample : public Runtime::HostFunction { +public: + WasmEdgeFFmpegSWResample( + std::shared_ptr HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + std::shared_ptr Env; +}; + +} // namespace SWResample +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp b/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp new file mode 100644 index 00000000..a28e7520 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp @@ -0,0 +1,126 @@ +#include "swresample_func.h" + +extern "C" { +#include "libavutil/avutil.h" +#include "libavutil/opt.h" +#include "libswresample/swresample.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWResample { + +Expect SWResampleVersion::body(const Runtime::CallingFrame &) { + return swresample_version(); +} + +Expect SWRGetDelay::body(const Runtime::CallingFrame &, + uint32_t SWRContextId, int64_t Base) { + + FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); + return swr_get_delay(SWRContext, Base); +} + +Expect SWRInit::body(const Runtime::CallingFrame &, + uint32_t SWRContextId) { + + FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); + return swr_init(SWRContext); +} + +Expect +SWRAllocSetOpts::body(const Runtime::CallingFrame &Frame, uint32_t SwrCtxPtr, + uint32_t SWRContextId, uint64_t OutChLayoutId, + uint32_t OutSampleFmtId, int32_t OutSampleRate, + uint64_t InChLayoutId, uint32_t InSampleFmtId, + int32_t InSampleRate, int32_t LogOffset) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwrCtxId, MemInst, uint32_t, SwrCtxPtr, "") + FFMPEG_PTR_FETCH(CurrSwrCtx, *SwrCtxId, SwrContext); + FFMPEG_PTR_FETCH(ExistSWRContext, SWRContextId, SwrContext); + + uint64_t const OutChLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(OutChLayoutId); + AVSampleFormat const OutSampleFmt = + FFmpegUtils::SampleFmt::fromSampleID(OutSampleFmtId); + uint64_t const InChLayout = + FFmpegUtils::ChannelLayout::fromChannelLayoutID(InChLayoutId); + AVSampleFormat const InSampleFmt = + FFmpegUtils::SampleFmt::fromSampleID(InSampleFmtId); + CurrSwrCtx = swr_alloc_set_opts( + ExistSWRContext, OutChLayout, OutSampleFmt, OutSampleRate, InChLayout, + InSampleFmt, InSampleRate, LogOffset, + nullptr); // Always being used as null in rust sdk. + FFMPEG_PTR_STORE(CurrSwrCtx, SwrCtxId); + return static_cast(ErrNo::Success); +} + +Expect AVOptSetDict::body(const Runtime::CallingFrame &, + uint32_t SWRContextId, uint32_t DictId) { + + FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); + FFMPEG_PTR_FETCH(AvDictionary, DictId, AVDictionary *); + return av_opt_set_dict(SWRContext, AvDictionary); +} + +Expect SWRConvertFrame::body(const Runtime::CallingFrame &, + uint32_t SWRContextId, + uint32_t FrameOutputId, + uint32_t FrameInputId) { + FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); + FFMPEG_PTR_FETCH(OuputFrame, FrameOutputId, AVFrame); + FFMPEG_PTR_FETCH(InputFrame, FrameInputId, AVFrame); + + return swr_convert_frame(SWRContext, OuputFrame, InputFrame); +} + +Expect SWRFree::body(const Runtime::CallingFrame &, + uint32_t SWRContextId) { + FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); + swr_close(SWRContext); + FFMPEG_PTR_DELETE(SWRContextId); + return static_cast(ErrNo::Success); +} + +Expect +SWResampleConfigurationLength::body(const Runtime::CallingFrame &) { + const char *Config = swresample_configuration(); + return strlen(Config); +} + +Expect +SWResampleConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, uint32_t ConfigLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = swresample_configuration(); + std::copy_n(Config, ConfigLen, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect SWResampleLicenseLength::body(const Runtime::CallingFrame &) { + + const char *License = swresample_license(); + return strlen(License); +} + +Expect SWResampleLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, + uint32_t LicenseLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = swresample_license(); + std::copy_n(License, LicenseLen, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +} // namespace SWResample +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_func.h b/plugins/wasmedge_ffmpeg/swresample/swresample_func.h new file mode 100644 index 00000000..b8dd8d7f --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swresample/swresample_func.h @@ -0,0 +1,107 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/callingframe.h" +#include "swresample_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWResample { + +class SWResampleVersion : public WasmEdgeFFmpegSWResample { +public: + SWResampleVersion(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWResample(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class SWRGetDelay : public WasmEdgeFFmpegSWResample { +public: + SWRGetDelay(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWResample(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SWRContextId, int64_t Base); +}; + +class SWRInit : public WasmEdgeFFmpegSWResample { +public: + SWRInit(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWResample(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SWRContextId); +}; + +class SWRAllocSetOpts : public WasmEdgeFFmpegSWResample { +public: + SWRAllocSetOpts(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWResample(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwrCtxPtr, + uint32_t SWRContextId, uint64_t OutChLayout, + uint32_t OutSampleFmtId, int32_t OutSampleRate, + uint64_t InChLayout, uint32_t InSampleFmtId, + int32_t InSampleRate, int32_t LogOffset); +}; + +class AVOptSetDict : public WasmEdgeFFmpegSWResample { +public: + AVOptSetDict(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWResample(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SWRContextId, uint32_t DictId); +}; + +class SWRConvertFrame : public WasmEdgeFFmpegSWResample { +public: + SWRConvertFrame(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWResample(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SWRContextId, uint32_t FrameOutputId, + uint32_t FrameInputId); +}; + +class SWRFree : public WasmEdgeFFmpegSWResample { +public: + SWRFree(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWResample(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SWRContextId); +}; + +class SWResampleConfigurationLength + : public WasmEdgeFFmpegSWResample { +public: + SWResampleConfigurationLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWResample(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class SWResampleConfiguration + : public WasmEdgeFFmpegSWResample { +public: + SWResampleConfiguration(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWResample(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class SWResampleLicenseLength + : public WasmEdgeFFmpegSWResample { +public: + SWResampleLicenseLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWResample(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class SWResampleLicense : public WasmEdgeFFmpegSWResample { +public: + SWResampleLicense(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWResample(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +} // namespace SWResample +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_ffmpeg/swscale/module.cpp b/plugins/wasmedge_ffmpeg/swscale/module.cpp new file mode 100644 index 00000000..f33cadd4 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swscale/module.cpp @@ -0,0 +1,74 @@ +#include "module.h" +#include "swscale_func.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWScale { + +WasmEdgeFFmpegSWScaleModule::WasmEdgeFFmpegSWScaleModule( + std::shared_ptr Env) + : ModuleInstance("wasmedge_ffmpeg_swscale") { + + addHostFunc("wasmedge_ffmpeg_swscale_swscale_version", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_swscale_configuration_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_swscale_configuration", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_swscale_license_length", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_swscale_license", + std::make_unique(Env)); + + // SwsContext + addHostFunc("wasmedge_ffmpeg_swscale_sws_getContext", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_freeContext", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_scale", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getCachedContext", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_isSupportedInput", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_isSupportedOutput", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_isSupportedEndiannessConversion", + std::make_unique(Env)); + + // SwsFilter + addHostFunc("wasmedge_ffmpeg_swscale_sws_getDefaultFilter", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getLumaH", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getLumaV", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getChromaH", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getChromaV", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_freeFilter", + std::make_unique(Env)); + + // SwsVector + addHostFunc("wasmedge_ffmpeg_swscale_sws_allocVec", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getGaussianVec", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_scaleVec", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_normalizeVec", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getCoeffVecLength", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_getCoeff", + std::make_unique(Env)); + addHostFunc("wasmedge_ffmpeg_swscale_sws_freeVec", + std::make_unique(Env)); +} + +} // namespace SWScale +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swscale/module.h b/plugins/wasmedge_ffmpeg/swscale/module.h new file mode 100644 index 00000000..bc53ee2f --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swscale/module.h @@ -0,0 +1,19 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWScale { + +class WasmEdgeFFmpegSWScaleModule : public Runtime::Instance::ModuleInstance { +public: + WasmEdgeFFmpegSWScaleModule(std::shared_ptr Env); +}; + +} // namespace SWScale +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swscale/swscale_base.h b/plugins/wasmedge_ffmpeg/swscale/swscale_base.h new file mode 100644 index 00000000..32dc9cf1 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swscale/swscale_base.h @@ -0,0 +1,25 @@ +#pragma once + +#include "ffmpeg_env.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWScale { + +template +class WasmEdgeFFmpegSWScale : public Runtime::HostFunction { +public: + WasmEdgeFFmpegSWScale( + std::shared_ptr HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + std::shared_ptr Env; +}; + +} // namespace SWScale +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp b/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp new file mode 100644 index 00000000..7b126607 --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp @@ -0,0 +1,298 @@ +#include "swscale_func.h" + +extern "C" { +#include "libavutil/frame.h" +#include "libswscale/swscale.h" +} + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWScale { + +Expect +SwsGetContext::body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxPtr, + uint32_t SrcW, uint32_t SrcH, uint32_t SrcPixFormatId, + uint32_t DesW, uint32_t DesH, uint32_t DesPixFormatId, + int32_t Flags, uint32_t SrcFilterId, uint32_t DesFilterId) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsCtxId, MemInst, uint32_t, SwsCtxPtr, + "Failed when accessing the return SWSContext Memory"sv) + + FFMPEG_PTR_FETCH(SwsCtx, *SwsCtxId, SwsContext) + FFMPEG_PTR_FETCH(SrcSwsFilter, SrcFilterId, SwsFilter) + FFMPEG_PTR_FETCH(DesSwsFilter, DesFilterId, SwsFilter) + + AVPixelFormat const SrcPixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(SrcPixFormatId); + AVPixelFormat const DestPixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(DesPixFormatId); + SwsCtx = sws_getContext(SrcW, SrcH, SrcPixelFormat, DesW, DesH, + DestPixelFormat, Flags, SrcSwsFilter, DesSwsFilter, + nullptr); // Not using param anywhere in Rust SDK. + if (SwsCtx == nullptr) + return static_cast(ErrNo::InternalError); + FFMPEG_PTR_STORE(SwsCtx, SwsCtxId); + return static_cast(ErrNo::Success); +} + +Expect SwsFreeContext::body(const Runtime::CallingFrame &, + uint32_t SwsCtxId) { + + FFMPEG_PTR_FETCH(SwsCtx, SwsCtxId, SwsContext) + sws_freeContext(SwsCtx); + FFMPEG_PTR_DELETE(SwsCtxId); + return static_cast(ErrNo::Success); +} + +Expect SwsScale::body(const Runtime::CallingFrame &, uint32_t SwsCtxId, + uint32_t InputFrameId, int32_t SrcSliceY, + int32_t SrcSliceH, uint32_t OutputFrameId) { + + FFMPEG_PTR_FETCH(SwsCtx, SwsCtxId, SwsContext); + FFMPEG_PTR_FETCH(InputFrame, InputFrameId, AVFrame); + FFMPEG_PTR_FETCH(OutputFrame, OutputFrameId, AVFrame); + return sws_scale(SwsCtx, InputFrame->data, InputFrame->linesize, SrcSliceY, + SrcSliceH, OutputFrame->data, OutputFrame->linesize); +} + +Expect SwsGetCachedContext::body( + const Runtime::CallingFrame &Frame, uint32_t SwsCachedCtxPtr, + uint32_t SwsCtxId, uint32_t SrcW, uint32_t SrcH, uint32_t SrcPixFormatId, + uint32_t DesW, uint32_t DesH, uint32_t DesPixFormatId, int32_t Flags, + uint32_t SrcFilterId, uint32_t DesFilterId) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsCachedCtxId, MemInst, uint32_t, SwsCachedCtxPtr, "") + + FFMPEG_PTR_FETCH(SwsCachedCtx, *SwsCachedCtxId, SwsContext); + FFMPEG_PTR_FETCH(SwsCtx, SwsCtxId, SwsContext); + FFMPEG_PTR_FETCH(SrcSwsFilter, SrcFilterId, SwsFilter) + FFMPEG_PTR_FETCH(DesSwsFilter, DesFilterId, SwsFilter) + + AVPixelFormat const SrcPixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(SrcPixFormatId); + AVPixelFormat const DestPixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(DesPixFormatId); + SwsCachedCtx = sws_getCachedContext(SwsCtx, SrcW, SrcH, SrcPixelFormat, DesW, + DesH, DestPixelFormat, Flags, + SrcSwsFilter, DesSwsFilter, nullptr); + if (SwsCachedCtx == nullptr) + return static_cast(ErrNo::InternalError); + + FFMPEG_PTR_STORE(SwsCachedCtx, SwsCachedCtxId); + return static_cast(ErrNo::Success); +} + +Expect SwsIsSupportedInput::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + return sws_isSupportedInput(PixelFormat); +} + +Expect SwsIsSupportedOutput::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + return sws_isSupportedOutput(PixelFormat); +} + +Expect +SwsIsSupportedEndiannessConversion::body(const Runtime::CallingFrame &, + uint32_t PixFormatId) { + AVPixelFormat const PixelFormat = + FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); + return sws_isSupportedEndiannessConversion(PixelFormat); +} + +Expect SwsGetDefaultFilter::body( + const Runtime::CallingFrame &Frame, uint32_t SwsFilterPtr, float LumaGBlur, + float ChromaGBlur, float LumaSharpen, float ChromaSharpen, + float ChromaHShift, float ChromaVShift, int32_t Verbose) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsFilterId, MemInst, uint32_t, SwsFilterPtr, "") + + SwsFilter *Filter = + sws_getDefaultFilter(LumaGBlur, ChromaGBlur, LumaSharpen, ChromaSharpen, + ChromaHShift, ChromaVShift, Verbose); + if (Filter == nullptr) + return static_cast(ErrNo::InternalError); + FFMPEG_PTR_STORE(Filter, SwsFilterId); + return static_cast(ErrNo::Success); +} + +Expect SwsGetLumaH::body(const Runtime::CallingFrame &Frame, + uint32_t SwsFilterId, uint32_t SwsVectorPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") + FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); + + SwsVector *Vector = Filter->lumH; + FFMPEG_PTR_STORE(Vector, SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwsGetLumaV::body(const Runtime::CallingFrame &Frame, + uint32_t SwsFilterId, uint32_t SwsVectorPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") + FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); + + SwsVector *Vector = Filter->lumV; + FFMPEG_PTR_STORE(Vector, SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwsGetChromaH::body(const Runtime::CallingFrame &Frame, + uint32_t SwsFilterId, + uint32_t SwsVectorPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") + FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); + + SwsVector *Vector = Filter->chrH; + FFMPEG_PTR_STORE(Vector, SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwsGetChromaV::body(const Runtime::CallingFrame &Frame, + uint32_t SwsFilterId, + uint32_t SwsVectorPtr) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") + FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); + + SwsVector *Vector = Filter->chrV; + FFMPEG_PTR_STORE(Vector, SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwsFreeFilter::body(const Runtime::CallingFrame &, + uint32_t SwsFilterId) { + + FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); + sws_freeFilter(Filter); + FFMPEG_PTR_DELETE(SwsFilterId); + return static_cast(ErrNo::Success); +} + +Expect SwsAllocVec::body(const Runtime::CallingFrame &Frame, + uint32_t SwsVectorPtr, int32_t Length) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") + + SwsVector *Vector = sws_allocVec(Length); + FFMPEG_PTR_STORE(Vector, SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwsGetGaussianVec::body(const Runtime::CallingFrame &Frame, + uint32_t SwsVectorPtr, double Variance, + double Quality) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") + + SwsVector *Vector = sws_getGaussianVec(Variance, Quality); + FFMPEG_PTR_STORE(Vector, SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwsScaleVec::body(const Runtime::CallingFrame &, + uint32_t SwsVectorId, double Scalar) { + + FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); + sws_scaleVec(Vector, Scalar); + return static_cast(ErrNo::Success); +} + +Expect SwsNormalizeVec::body(const Runtime::CallingFrame &, + uint32_t SwsVectorId, double Height) { + + FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); + sws_normalizeVec(Vector, Height); + return static_cast(ErrNo::Success); +} + +Expect SwsGetCoeffVecLength::body(const Runtime::CallingFrame &, + uint32_t SwsVectorId) { + + FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); + return Vector->length * + sizeof(double); // Getting the size in uint_8* (Cuz Passing uint8_t* + // array from Rust SDK). +} + +Expect SwsGetCoeff::body(const Runtime::CallingFrame &Frame, + uint32_t SwsVectorId, uint32_t CoeffBufPtr, + uint32_t Len) { + + MEMINST_CHECK(MemInst, Frame, 0) + MEM_SPAN_CHECK(Buffer, MemInst, uint8_t, CoeffBufPtr, Len, ""); + FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); + + double *Coeff = Vector->coeff; + std::copy_n(Coeff, Len, Buffer.data()); + return static_cast(ErrNo::Success); +} + +Expect SwsFreeVec::body(const Runtime::CallingFrame &, + uint32_t SwsVectorId) { + + FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); + sws_freeVec(Vector); + FFMPEG_PTR_DELETE(SwsVectorId); + return static_cast(ErrNo::Success); +} + +Expect SwscaleVersion::body(const Runtime::CallingFrame &) { + return swscale_version(); +} + +Expect +SwscaleConfigurationLength::body(const Runtime::CallingFrame &) { + const char *Config = swscale_configuration(); + return strlen(Config); +} + +Expect SwscaleConfiguration::body(const Runtime::CallingFrame &Frame, + uint32_t ConfigPtr, + uint32_t ConfigLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); + + const char *Config = swscale_configuration(); + std::copy_n(Config, ConfigLen, ConfigBuf.data()); + return static_cast(ErrNo::Success); +} + +Expect SwscaleLicenseLength::body(const Runtime::CallingFrame &) { + + const char *License = swscale_license(); + return strlen(License); +} + +Expect SwscaleLicense::body(const Runtime::CallingFrame &Frame, + uint32_t LicensePtr, uint32_t LicenseLen) { + + MEMINST_CHECK(MemInst, Frame, 0); + MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); + + const char *License = swscale_license(); + std::copy_n(License, LicenseLen, LicenseBuf.data()); + return static_cast(ErrNo::Success); +} + +} // namespace SWScale +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swscale/swscale_func.h b/plugins/wasmedge_ffmpeg/swscale/swscale_func.h new file mode 100644 index 00000000..dfb7ffbf --- /dev/null +++ b/plugins/wasmedge_ffmpeg/swscale/swscale_func.h @@ -0,0 +1,226 @@ +#pragma once + +#include "runtime/callingframe.h" +#include "swscale_base.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +namespace SWScale { + +class SwsGetContext : public WasmEdgeFFmpegSWScale { +public: + SwsGetContext(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxPtr, + uint32_t SrcW, uint32_t SrcH, uint32_t SrcPixFormatId, + uint32_t DesW, uint32_t DesH, uint32_t DesPixFormatId, + int32_t Flags, uint32_t SrcFilterId, + uint32_t DesFilterId); +}; + +class SwsFreeContext : public WasmEdgeFFmpegSWScale { +public: + SwsFreeContext(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxId); +}; + +class SwsScale : public WasmEdgeFFmpegSWScale { +public: + SwsScale(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxId, + uint32_t InputFrameId, int32_t SrcSliceY, + int32_t SrcSliceH, uint32_t OutputFrameId); +}; + +class SwsGetCachedContext : public WasmEdgeFFmpegSWScale { +public: + SwsGetCachedContext(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SwsCachedCtxPtr, uint32_t SwsCtxPtr, + uint32_t SrcW, uint32_t SrcH, uint32_t SrcPixFormatId, + uint32_t DesW, uint32_t DesH, uint32_t DesPixFormatId, + int32_t Flags, uint32_t SrcFilterId, + uint32_t DesFilterId); +}; + +class SwsIsSupportedInput : public WasmEdgeFFmpegSWScale { +public: + SwsIsSupportedInput(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +class SwsIsSupportedOutput + : public WasmEdgeFFmpegSWScale { +public: + SwsIsSupportedOutput(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +class SwsIsSupportedEndiannessConversion + : public WasmEdgeFFmpegSWScale { +public: + SwsIsSupportedEndiannessConversion(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t PixFormatId); +}; + +class SwsGetDefaultFilter : public WasmEdgeFFmpegSWScale { +public: + SwsGetDefaultFilter(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SwsFilterPtr, float LumaGBlur, + float ChromaGBlur, float LumaSharpen, + float ChromaSharpen, float ChromaHShift, + float ChromaVShift, int32_t Verbose); +}; + +class SwsGetLumaH : public WasmEdgeFFmpegSWScale { +public: + SwsGetLumaH(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, + uint32_t SwsVectorPtr); +}; + +class SwsGetLumaV : public WasmEdgeFFmpegSWScale { +public: + SwsGetLumaV(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, + uint32_t SwsVectorPtr); +}; + +class SwsGetChromaH : public WasmEdgeFFmpegSWScale { +public: + SwsGetChromaH(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, + uint32_t SwsVectorPtr); +}; + +class SwsGetChromaV : public WasmEdgeFFmpegSWScale { +public: + SwsGetChromaV(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, + uint32_t SwsVectorPtr); +}; + +class SwsFreeFilter : public WasmEdgeFFmpegSWScale { +public: + SwsFreeFilter(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SwsFilterId); +}; + +class SwsAllocVec : public WasmEdgeFFmpegSWScale { +public: + SwsAllocVec(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SwsVectorPtr, int32_t Length); +}; + +class SwsGetGaussianVec : public WasmEdgeFFmpegSWScale { +public: + SwsGetGaussianVec(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SwsVectorPtr, double Variance, double Quality); +}; + +class SwsScaleVec : public WasmEdgeFFmpegSWScale { +public: + SwsScaleVec(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsVectorId, + double Scalar); +}; + +class SwsNormalizeVec : public WasmEdgeFFmpegSWScale { +public: + SwsNormalizeVec(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsVectorId, + double Height); +}; + +class SwsGetCoeffVecLength + : public WasmEdgeFFmpegSWScale { +public: + SwsGetCoeffVecLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SwsVectorId); +}; + +class SwsGetCoeff : public WasmEdgeFFmpegSWScale { +public: + SwsGetCoeff(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &, uint32_t SwsVectorId, + uint32_t CoeffBuf, uint32_t Len); +}; + +class SwsFreeVec : public WasmEdgeFFmpegSWScale { +public: + SwsFreeVec(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t SwsVectorId); +}; + +class SwscaleVersion : public WasmEdgeFFmpegSWScale { +public: + SwscaleVersion(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class SwscaleConfigurationLength + : public WasmEdgeFFmpegSWScale { +public: + SwscaleConfigurationLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class SwscaleConfiguration + : public WasmEdgeFFmpegSWScale { +public: + SwscaleConfiguration(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, + uint32_t ConfigLen); +}; + +class SwscaleLicenseLength + : public WasmEdgeFFmpegSWScale { +public: + SwscaleLicenseLength(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame); +}; + +class SwscaleLicense : public WasmEdgeFFmpegSWScale { +public: + SwscaleLicense(std::shared_ptr HostEnv) + : WasmEdgeFFmpegSWScale(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, + uint32_t LicenseLen); +}; + +} // namespace SWScale +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 808507fd..211c0bbd 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -1,14 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2022 Second State INC +if(WASMEDGE_PLUGIN_FFMPEG) + add_subdirectory(wasmedge_ffmpeg) +endif() + if(WASMEDGE_PLUGIN_PROCESS) - if (CMAKE_SYSTEM_NAME MATCHES "Linux") + if(CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasmedge_process) endif() endif() if(WASMEDGE_PLUGIN_ZLIB) - add_subdirectory(wasmedge_zlib) + add_subdirectory(wasmedge_zlib) endif() if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) diff --git a/test/plugins/wasmedge_ffmpeg/CMakeLists.txt b/test/plugins/wasmedge_ffmpeg/CMakeLists.txt new file mode 100644 index 00000000..c35e0bc4 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +wasmedge_add_executable(wasmedgeFFmpegTests + main.cpp + + avcodec/avcodec_func.cpp + avcodec/avCodec.cpp + avcodec/avCodecParameters.cpp + avcodec/avPacket.cpp + avcodec/avCodecCtx.cpp + + avfilter/avfilter_func.cpp + avfilter/avfilter.cpp + + avformat/avformat_func.cpp + avformat/avformatContext.cpp + avformat/avInputOutputContext.cpp + avformat/avStream.cpp + avformat/avChapter.cpp + + avutil/avRational.cpp + avutil/avDictionary.cpp + avutil/avFrame.cpp + avutil/avutil_func.cpp + avutil/avError.cpp + avutil/avSampleFmt.cpp + avutil/avPixfmt.cpp + + swscale/swscale_func.cpp + + swresample/swresample_func.cpp + + utils.cpp +) + +# Downloading a sample file +execute_process( + COMMAND bash ${CMAKE_SOURCE_DIR}/utils/ffmpeg/download-ffmpeg-sample-video.sh ${CMAKE_CURRENT_BINARY_DIR}/ffmpeg-assets + RESULT_VARIABLE DOWNLOAD_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +add_dependencies(wasmedgeFFmpegTests + wasmedgePluginWasmEdgeFFmpeg +) + +target_include_directories(wasmedgeFFmpegTests + PUBLIC + $ + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries(wasmedgeFFmpegTests + PRIVATE + wasmedgePluginWasmEdgeFFmpeg + ${GTEST_BOTH_LIBRARIES} +) + +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeFFmpegTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeFFmpegTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeFFmpegTests wasmedgeFFmpegTests) diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp new file mode 100644 index 00000000..6eb3a3d4 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp @@ -0,0 +1,365 @@ +#include "avcodec/avCodec.h" +#include "avcodec/module.h" +#include "utils.h" + +#include + +// Testing all AVCodecstruct + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVCodec) { + + ASSERT_TRUE(AVCodecMod != nullptr); + + uint32_t AVCodecPtr = UINT32_C(20); + uint32_t StringPtr = UINT32_C(68); + uint32_t NumeratorPtr = UINT32_C(72); + uint32_t DenominatorPtr = UINT32_C(76); + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + spdlog::info("Init FFmpeg Structs"sv); + initFFmpegStructs(AVCodecPtr, UINT32_C(24), UINT32_C(28), FileName, + UINT32_C(60), UINT32_C(64), UINT32_C(68), UINT32_C(72)); + + uint32_t AVCodecId = readUInt32(MemInst, AVCodecPtr); + auto *FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_id"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecID = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecId"sv); + { + EXPECT_TRUE(HostFuncAVCodecID.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 27); // H264 + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecType = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecType"sv); + { + EXPECT_TRUE(HostFuncAVCodecType.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), + 0); // MediaType is Video + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_max_lowres"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecMaxLowres = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecMaxLowres"sv); + { + EXPECT_TRUE(HostFuncAVCodecMaxLowres.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_capabilities"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCapabilities = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCapabilities &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecCapabilities"sv); + { + EXPECT_TRUE(HostFuncAVCodecCapabilities.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_TRUE(Result[0].get() > 0); + } + + int32_t Length = 0; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_get_name_len"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecGetNameLen = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecGetNameLen &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecGetNameLen"sv); + { + EXPECT_TRUE(HostFuncAVCodecGetNameLen.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_get_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecGetName = + dynamic_cast( + FuncInst->getHostFunc()); + + // Fill the Memory with 0. + fillMemContent(MemInst, StringPtr, Length); + spdlog::info("Testing AVCodecGetName"sv); + { + EXPECT_TRUE( + HostFuncAVCodecGetName.run(CallFrame, + std::initializer_list{ + AVCodecId, StringPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_get_long_name_len"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecGetLongNameLen = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecGetLongNameLen &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecGetLongNameLen"sv); + { + EXPECT_TRUE(HostFuncAVCodecGetLongNameLen.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_get_long_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecGetLongName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecGetLongName &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecGetLongName"sv); + { + EXPECT_TRUE(HostFuncAVCodecGetLongName.run( + CallFrame, + std::initializer_list{AVCodecId, StringPtr, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_profiles"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecProfiles = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecProfiles"sv); + { + EXPECT_TRUE(HostFuncAVCodecProfiles.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_pix_fmts_is_null"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecPixFmtIsNull = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecPixFmtsIsNull &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecPixFmtsIsNull"sv); + { + EXPECT_TRUE(HostFuncAVCodecPixFmtIsNull.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_pix_fmts_iter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecPixFmtIter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecPixFmtsIter &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecPixFmtsIter"sv); + { + uint32_t Idx = 0; + EXPECT_TRUE(HostFuncAVCodecPixFmtIter.run( + CallFrame, std::initializer_list{AVCodecId, Idx}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_supported_framerate_is_null"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSupportedFrameratesIsNull = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSupportedFrameratesIsNull + &>(FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSupportedFramratesIsNull"sv); + { + EXPECT_TRUE(HostFuncAVCodecSupportedFrameratesIsNull.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_supported_framerate_iter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSupportedFrameratesIter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSupportedFrameratesIter + &>(FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSupportedFrameratesIter"sv); + { + EXPECT_TRUE(HostFuncAVCodecSupportedFrameratesIter.run( + CallFrame, + std::initializer_list{AVCodecId, 1, NumeratorPtr, + DenominatorPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_supported_samplerates_is_null"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSupportedSampleRatesIsNull = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSupportedSampleRatesIsNull + &>(FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSupportedSampleRatesIsNull"sv); + { + EXPECT_TRUE(HostFuncAVCodecSupportedSampleRatesIsNull.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_supported_samplerates_iter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSupportedSampleRatesIter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSupportedSampleRatesIter + &>(FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSupportedSampleRatesIter"sv); + { + EXPECT_TRUE(HostFuncAVCodecSupportedSampleRatesIter.run( + CallFrame, std::initializer_list{AVCodecId, 0}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_channel_layouts_is_null"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecChannelLayoutIsNull = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecChannelLayoutIsNull &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecChannelLayoutIsNull"sv); + { + EXPECT_TRUE(HostFuncAVCodecChannelLayoutIsNull.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_channel_layouts_iter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecChannelLayoutIter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecChannelLayoutIter &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecChannelLayoutIter"sv); + { + EXPECT_TRUE(HostFuncAVCodecChannelLayoutIter.run( + CallFrame, std::initializer_list{AVCodecId, 0}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_sample_fmts_is_null"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSampleFmtsIsNull = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSampleFmtsIsNull &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSampleFmtsIsNull"sv); + { + EXPECT_TRUE(HostFuncAVCodecSampleFmtsIsNull.run( + CallFrame, std::initializer_list{AVCodecId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_sample_fmts_iter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSampleFmtsIter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSampleFmtsIter &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSampleFmtsIter"sv); + { + EXPECT_TRUE(HostFuncAVCodecSampleFmtsIter.run( + CallFrame, std::initializer_list{AVCodecId, 0}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } +} +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp new file mode 100644 index 00000000..b89128f8 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp @@ -0,0 +1,1657 @@ +#include "avcodec/avCodecContext.h" +#include "avcodec/module.h" +#include "utils.h" + +#include + +// Testing all AVCodecCtxstruct +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVCodecCtx) { + ASSERT_TRUE(AVCodecMod != nullptr); + + uint32_t AVCodecCtxPtr = UINT32_C(64); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFFmpegStructs(UINT32_C(20), UINT32_C(24), UINT32_C(28), FileName, + UINT32_C(60), AVCodecCtxPtr, UINT32_C(68), UINT32_C(72)); + uint32_t NumPtr = UINT32_C(76); + uint32_t DenPtr = UINT32_C(80); + uint32_t AVCodecPtr = UINT32_C(84); + + uint32_t AVCodecCtxId = readUInt32(MemInst, AVCodecCtxPtr); + + auto *FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_codec_id"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxCodecID = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxCodecID &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxCodecID.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 27); // H264 + } + + int32_t CodecType = 0; // MediaType Video + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_codec_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetCodecType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetCodecType &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetCodecType.run( + CallFrame, + std::initializer_list{AVCodecCtxId, CodecType}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_codec_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxCodecType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxCodecType &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxCodecType.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), CodecType); // MediaType Video + } + + int32_t Num = 5; + int32_t Den = 10; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_time_base"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetTimebase = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetTimebase &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetTimebase.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Num, Den}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_time_base"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxTimeBase = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxTimeBase &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxTimeBase.run( + CallFrame, + std::initializer_list{AVCodecCtxId, NumPtr, + DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + int32_t Numerator = readSInt32(MemInst, NumPtr); + int32_t Denominator = readSInt32(MemInst, DenPtr); + EXPECT_EQ(Numerator, Num); + EXPECT_EQ(Denominator, Den); + } + + int32_t Dimension = 200; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_width"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetWidth = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetWidth &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetWidth.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Dimension}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_width"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxWidth = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxWidth.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), Dimension); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_height"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetHeight = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetHeight &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetHeight.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Dimension}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_height"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxHeight = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxHeight.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), Dimension); + } + + Num = 10; + Den = 20; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_sample_aspect_ratio"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSampleAspectRatio = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSampleAspectRatio + &>(FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetSampleAspectRatio.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Num, Den}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_sample_aspect_ratio"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSampleAspectRatio = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSampleAspectRatio &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSampleAspectRatio.run( + CallFrame, + std::initializer_list{AVCodecCtxId, NumPtr, + DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + int32_t Numerator = readSInt32(MemInst, NumPtr); + int32_t Denominator = readSInt32(MemInst, DenPtr); + EXPECT_EQ(Numerator, Num); + EXPECT_EQ(Denominator, Den); + } + + uint64_t ChannelLayoutId = 1; // FRONT_LEFT; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_channel_layout"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetChannelLayout = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetChannelLayout &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetChannelLayout.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + ChannelLayoutId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_channel_layout"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxChannelLayout = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxChannelLayout &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxChannelLayout.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), ChannelLayoutId); + } + + uint32_t PixFormatId = 1; // YUV420P + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_pix_fmt"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetPixFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetPixFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetPixFormat.run( + CallFrame, + std::initializer_list{AVCodecCtxId, PixFormatId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_pix_fmt"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxPixFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxPixFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxPixFormat.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), PixFormatId); + } + + uint32_t SampleFmtId = 1; // SAMPLE_FMT_U8 + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_sample_format"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSampleFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSampleFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetSampleFormat.run( + CallFrame, + std::initializer_list{AVCodecCtxId, SampleFmtId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_sample_format"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSampleFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSampleFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSampleFormat.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), SampleFmtId); + } + + int32_t SampleRate = 500; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_sample_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSampleRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSampleRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetSampleRate.run( + CallFrame, + std::initializer_list{AVCodecCtxId, SampleRate}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_sample_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSampleRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSampleRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSampleRate.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), SampleRate); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_gop_size"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetGopSize = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetGopSize &>( + FuncInst->getHostFunc()); + + { + int32_t GopSize = 20; + EXPECT_TRUE(HostFuncAVCodecCtxSetGopSize.run( + CallFrame, + std::initializer_list{AVCodecCtxId, GopSize}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_max_b_frames"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMaxBFrames = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMaxBFrames &>( + FuncInst->getHostFunc()); + + { + int32_t MaxBFrames = 30; + EXPECT_TRUE(HostFuncAVCodecCtxSetMaxBFrames.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MaxBFrames}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_b_quant_factor"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetBQuantFactor = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetBQuantFactor &>( + FuncInst->getHostFunc()); + + { + float BQuantFactor = 12.32; + EXPECT_TRUE(HostFuncAVCodecCtxSetBQuantFactor.run( + CallFrame, + std::initializer_list{AVCodecCtxId, BQuantFactor}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_b_quant_offset"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetBQuantOffset = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetBQuantOffset &>( + FuncInst->getHostFunc()); + + { + float BQuantOffset = 3.53; + EXPECT_TRUE(HostFuncAVCodecCtxSetBQuantOffset.run( + CallFrame, + std::initializer_list{AVCodecCtxId, BQuantOffset}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_i_quant_factor"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetIQuantFactor = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetIQuantFactor &>( + FuncInst->getHostFunc()); + + { + float IQuantFactor = 3.435; + EXPECT_TRUE(HostFuncAVCodecCtxSetIQuantFactor.run( + CallFrame, + std::initializer_list{AVCodecCtxId, IQuantFactor}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_i_quant_offset"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetIQuantOffset = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetIQuantOffset &>( + FuncInst->getHostFunc()); + + { + float IQuantOffset = 6.322; + EXPECT_TRUE(HostFuncAVCodecCtxSetIQuantOffset.run( + CallFrame, + std::initializer_list{AVCodecCtxId, IQuantOffset}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_lumi_masking"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetLumiMasking = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetLumiMasking &>( + FuncInst->getHostFunc()); + + { + float LumiMasking = 54.32432; + EXPECT_TRUE(HostFuncAVCodecCtxSetLumiMasking.run( + CallFrame, + std::initializer_list{AVCodecCtxId, LumiMasking}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_temporal_cplx_masking"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetTemporalCplxMasking = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetTemporalCplxMasking + &>(FuncInst->getHostFunc()); + + { + float TemporialCplxMasking = 642.32; + EXPECT_TRUE(HostFuncAVCodecCtxSetTemporalCplxMasking.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + TemporialCplxMasking}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_spatial_cplx_masking"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSpatialCplxMasking = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSpatialCplxMasking + &>(FuncInst->getHostFunc()); + + { + float SpatialCplxMasking = 324.32; + EXPECT_TRUE(HostFuncAVCodecCtxSetSpatialCplxMasking.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + SpatialCplxMasking}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_p_masking"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetPMasking = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetPMasking &>( + FuncInst->getHostFunc()); + + { + float PMasking = 65.3245; + EXPECT_TRUE(HostFuncAVCodecCtxSetPMasking.run( + CallFrame, + std::initializer_list{AVCodecCtxId, PMasking}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_dark_masking"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetDarkMasking = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetDarkMasking &>( + FuncInst->getHostFunc()); + + { + float DarkMasking = 83.32; + EXPECT_TRUE(HostFuncAVCodecCtxSetDarkMasking.run( + CallFrame, + std::initializer_list{AVCodecCtxId, DarkMasking}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_cmp"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMeCmp = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMeCmp &>( + FuncInst->getHostFunc()); + + { + int32_t MeCmp = 532; + EXPECT_TRUE(HostFuncAVCodecCtxSetMeCmp.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MeCmp}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_sub_cmp"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMeSubCmp = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMeSubCmp &>( + FuncInst->getHostFunc()); + + { + int32_t MeSubCmp = 321; + EXPECT_TRUE(HostFuncAVCodecCtxSetMeSubCmp.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MeSubCmp}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_cmp"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMbCmp = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMbCmp &>( + FuncInst->getHostFunc()); + + { + int32_t MbCmp = 243; + EXPECT_TRUE(HostFuncAVCodecCtxSetMbCmp.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MbCmp}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_ildct_cmp"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetIldctCmp = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetIldctCmp &>( + FuncInst->getHostFunc()); + + { + int32_t IldctCmp = 3; + EXPECT_TRUE(HostFuncAVCodecCtxSetIldctCmp.run( + CallFrame, + std::initializer_list{AVCodecCtxId, IldctCmp}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_dia_size"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetDiaSize = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetDiaSize &>( + FuncInst->getHostFunc()); + + { + int32_t DiaSize = 9; + EXPECT_TRUE(HostFuncAVCodecCtxSetDiaSize.run( + CallFrame, + std::initializer_list{AVCodecCtxId, DiaSize}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_last_predictor_count"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetLastPredictorsCount = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetLastPredictorsCount + &>(FuncInst->getHostFunc()); + + { + int32_t LastPredictorCount = 21; + EXPECT_TRUE(HostFuncAVCodecCtxSetLastPredictorsCount.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + LastPredictorCount}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_pre_cmp"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMePreCmp = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMePreCmp &>( + FuncInst->getHostFunc()); + + { + int32_t MePreCmp = 53; + EXPECT_TRUE(HostFuncAVCodecCtxSetMePreCmp.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MePreCmp}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_pre_dia_size"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetPreDiaSize = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetPreDiaSize &>( + FuncInst->getHostFunc()); + + { + int32_t PreDiaSize = 74; + EXPECT_TRUE(HostFuncAVCodecCtxSetPreDiaSize.run( + CallFrame, + std::initializer_list{AVCodecCtxId, PreDiaSize}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_subpel_quality"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMeSubpelQuality = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMeSubpelQuality &>( + FuncInst->getHostFunc()); + + { + int32_t MeSubpelQuality = 85; + EXPECT_TRUE(HostFuncAVCodecCtxSetMeSubpelQuality.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + MeSubpelQuality}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_me_range"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMeRange = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMeRange &>( + FuncInst->getHostFunc()); + + { + int32_t SetMeRange = 31; + EXPECT_TRUE(HostFuncAVCodecCtxSetMeRange.run( + CallFrame, + std::initializer_list{AVCodecCtxId, SetMeRange}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_decision"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMbDecision = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMbDecision &>( + FuncInst->getHostFunc()); + + { + int32_t MbDecision = 78; + EXPECT_TRUE(HostFuncAVCodecCtxSetMbDecision.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MbDecision}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_lmin"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMbLMin = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMbLMin &>( + FuncInst->getHostFunc()); + + { + int32_t MbLMin = 11; + EXPECT_TRUE(HostFuncAVCodecCtxSetMbLMin.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MbLMin}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_mb_lmax"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetMbLMax = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetMbLMax &>( + FuncInst->getHostFunc()); + + { + int32_t MbLMax = 18; + EXPECT_TRUE(HostFuncAVCodecCtxSetMbLMax.run( + CallFrame, + std::initializer_list{AVCodecCtxId, MbLMax}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_intra_dc_precision"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + int32_t IntraDcPrecision = 323; + auto &HostFuncAVCodecCtxSetIntraDcPrecision = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetIntraDcPrecision &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetIntraDcPrecision.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + IntraDcPrecision}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_intra_dc_precision"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxIntraDcPrecision = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxIntraDcPrecision &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxIntraDcPrecision.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), IntraDcPrecision); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_qmin"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetQMin = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetQMin &>( + FuncInst->getHostFunc()); + + { + int32_t QMin = 10; + EXPECT_TRUE(HostFuncAVCodecCtxSetQMin.run( + CallFrame, + std::initializer_list{AVCodecCtxId, QMin}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_qmax"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetQMax = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetQMax &>( + FuncInst->getHostFunc()); + + { + int32_t QMax = 20; + EXPECT_TRUE(HostFuncAVCodecCtxSetQMax.run( + CallFrame, + std::initializer_list{AVCodecCtxId, QMax}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_global_quality"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetGlobalQuality = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetGlobalQuality &>( + FuncInst->getHostFunc()); + + { + int32_t GlobalQuality = 93; + EXPECT_TRUE(HostFuncAVCodecCtxSetGlobalQuality.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + GlobalQuality}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_colorspace"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetColorspace = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetColorspace &>( + FuncInst->getHostFunc()); + + int32_t ColorspaceId = 1; // AVCOL_SPC_BT709 + { + EXPECT_TRUE(HostFuncAVCodecCtxSetColorspace.run( + CallFrame, + std::initializer_list{AVCodecCtxId, ColorspaceId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_colorspace"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxColorspace = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxColorspace &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxColorspace.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), ColorspaceId); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_color_range"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetColorRange = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetColorRange &>( + FuncInst->getHostFunc()); + + int32_t ColorRangeId = 1; // MPEG + { + EXPECT_TRUE(HostFuncAVCodecCtxSetColorRange.run( + CallFrame, + std::initializer_list{AVCodecCtxId, ColorRangeId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_color_range"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxColorRange = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxColorRange &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxColorRange.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), ColorRangeId); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_frame_size"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxFrameSize = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxFrameSize &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxFrameSize.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_bit_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetBitRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetBitRate &>( + FuncInst->getHostFunc()); + + int64_t BitRate = 9932; + { + EXPECT_TRUE(HostFuncAVCodecCtxSetBitRate.run( + CallFrame, + std::initializer_list{AVCodecCtxId, BitRate}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_bit_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxBitRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxBitRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxBitRate.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), BitRate); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_rc_max_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + int64_t RcMaxRate = 3245; + auto &HostFuncAVCodecCtxSetRcMaxRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetRcMaxRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetRcMaxRate.run( + CallFrame, + std::initializer_list{AVCodecCtxId, RcMaxRate}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_rc_max_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxRcMaxRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxRcMaxRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxRcMaxRate.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), RcMaxRate); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_bit_rate_tolerance"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetBitRateTolerance = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetBitRateTolerance &>( + FuncInst->getHostFunc()); + + { + int32_t BitRateTolerance = 9543; + EXPECT_TRUE(HostFuncAVCodecCtxSetBitRateTolerance.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + BitRateTolerance}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_compression_level"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetCompressionLevel = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetCompressionLevel &>( + FuncInst->getHostFunc()); + + { + int32_t CompressionLevel = 934; + EXPECT_TRUE(HostFuncAVCodecCtxSetCompressionLevel.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + CompressionLevel}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_framerate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + Num = 20; + Den = 30; + auto &HostFuncAVCodecCtxSetFrameRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetFrameRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetFrameRate.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Num, Den}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_framerate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxFrameRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxFrameRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxFrameRate.run( + CallFrame, + std::initializer_list{AVCodecCtxId, NumPtr, + DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + int32_t Numerator = readSInt32(MemInst, NumPtr); + int32_t Denominator = readSInt32(MemInst, DenPtr); + EXPECT_EQ(Numerator, Num); + EXPECT_EQ(Denominator, Den); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_flags"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetFlags = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetFlags &>( + FuncInst->getHostFunc()); + + { + int32_t Flags = 3; + EXPECT_TRUE(HostFuncAVCodecCtxSetFlags.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Flags}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_strict_std_compliance"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetStrictStdCompliance = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetStrictStdCompliance + &>(FuncInst->getHostFunc()); + + { + int32_t ComplianceId = 3; + EXPECT_TRUE(HostFuncAVCodecCtxSetStrictStdCompliance.run( + CallFrame, + std::initializer_list{AVCodecCtxId, ComplianceId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_debug"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetDebug = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetDebug &>( + FuncInst->getHostFunc()); + + { + int32_t Debug = 50; + EXPECT_TRUE(HostFuncAVCodecCtxSetDebug.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Debug}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_codec"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxCodec = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxCodec.run( + CallFrame, + std::initializer_list{AVCodecCtxId, AVCodecPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, AVCodecPtr) > 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_channels"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetChannels = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetChannels &>( + FuncInst->getHostFunc()); + + int32_t Channels = 10; + { + EXPECT_TRUE(HostFuncAVCodecCtxSetChannels.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Channels}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_channels"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxChannels = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxChannels &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxChannels.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), Channels); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_loop_filter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSkipLoopFilter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSkipLoopFilter &>( + FuncInst->getHostFunc()); + + int32_t DiscardId = 16; // Bidirectional + { + EXPECT_TRUE(HostFuncAVCodecCtxSetSkipLoopFilter.run( + CallFrame, + std::initializer_list{AVCodecCtxId, DiscardId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSkipFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSkipFrame &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetSkipFrame.run( + CallFrame, + std::initializer_list{AVCodecCtxId, DiscardId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_idct"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSkipIdct = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSkipIdct &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetSkipIdct.run( + CallFrame, + std::initializer_list{AVCodecCtxId, DiscardId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_error_concealment"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetErrorConcealment = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetErrorConcealment &>( + FuncInst->getHostFunc()); + + { + int32_t ErrorConcealment = 99; + EXPECT_TRUE(HostFuncAVCodecCtxSetErrorConcealment.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + ErrorConcealment}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_err_recognition"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetErrorRecognition = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetErrorRecognition &>( + FuncInst->getHostFunc()); + + { + int32_t ErrorRecognition = 88; + EXPECT_TRUE(HostFuncAVCodecCtxSetErrorRecognition.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + ErrorRecognition}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_delay"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxDelay = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxDelay.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_top"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSkipTop = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSkipTop &>( + FuncInst->getHostFunc()); + + { + int32_t Value = 50; + EXPECT_TRUE(HostFuncAVCodecCtxSetSkipTop.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Value}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_skip_bottom"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSkipBottom = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSkipBottom &>( + FuncInst->getHostFunc()); + + { + int32_t Value = 60; + EXPECT_TRUE(HostFuncAVCodecCtxSetSkipBottom.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Value}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_refs"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxRefs = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxRefs.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 4); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_slice_flags"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSliceFlags = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSliceFlags &>( + FuncInst->getHostFunc()); + + { + int32_t Value = 70; + EXPECT_TRUE(HostFuncAVCodecCtxSetSliceFlags.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Value}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_slice_count"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetSliceCount = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetSliceCount &>( + FuncInst->getHostFunc()); + + { + int32_t Value = 100; + EXPECT_TRUE(HostFuncAVCodecCtxSetSliceCount.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Value}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_field_order"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetFieldOrder = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetFieldOrder &>( + FuncInst->getHostFunc()); + + { + int32_t Value = 200; + EXPECT_TRUE(HostFuncAVCodecCtxSetFieldOrder.run( + CallFrame, + std::initializer_list{AVCodecCtxId, Value}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_color_trc"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxColorTrc = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxColorTrc &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxColorTrc.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + ASSERT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_chroma_sample_location"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxChromaSampleLocation = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxChromaSampleLocation + &>(FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxChromaSampleLocation.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_frame_number"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxFrameNumber = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxFrameNumber &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxFrameNumber.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_block_align"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxBlockAlign = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxBlockAlign &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxBlockAlign.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_request_sample_fmt"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetRequestSampleFmt = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetRequestSampleFmt &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetRequestSampleFmt.run( + CallFrame, + std::initializer_list{AVCodecCtxId, SampleFmtId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_audio_service_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxAudioServiceType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxAudioServiceType &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxAudioServiceType.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_has_b_frames"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxHasBFrames = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxHasBFrames &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxHasBFrames.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + ASSERT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_request_channel_layout"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetRequestChannelLayout = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetRequestChannelLayout + &>(FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxSetRequestChannelLayout.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + ChannelLayoutId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_active_thread_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxActiveThreadType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxActiveThreadType &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxActiveThreadType.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_thread_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetThreadType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetThreadType &>( + FuncInst->getHostFunc()); + + { + int32_t ThreadType = 1; // Frame + EXPECT_TRUE(HostFuncAVCodecCtxSetThreadType.run( + CallFrame, + std::initializer_list{AVCodecCtxId, ThreadType}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_set_thread_count"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxSetThreadCount = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetThreadCount &>( + FuncInst->getHostFunc()); + + int32_t ThreadCount = 50; + { + EXPECT_TRUE(HostFuncAVCodecCtxSetThreadCount.run( + CallFrame, + std::initializer_list{AVCodecCtxId, ThreadCount}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_thread_count"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxThreadCount = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxThreadCount &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxThreadCount.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), ThreadCount); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_color_primaries"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecCtxColorPrimaries = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxColorPrimaries &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecCtxColorPrimaries.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp new file mode 100644 index 00000000..1217914e --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp @@ -0,0 +1,75 @@ +#include "avcodec/avCodecParameters.h" +#include "avcodec/module.h" +#include "utils.h" + +#include + +// Testing all AVCodecstruct + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVCodecParameters) { + ASSERT_TRUE(AVCodecMod != nullptr); + + uint32_t AVCodecParamPtr = UINT32_C(60); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFFmpegStructs(UINT32_C(20), UINT32_C(24), UINT32_C(28), FileName, + AVCodecParamPtr, UINT32_C(64), UINT32_C(68), UINT32_C(72)); + + uint32_t AVCodecParamId = readUInt32(MemInst, AVCodecParamPtr); + + auto *FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodecparam_codec_id"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParamCodecId = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParamCodecId &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecParamCodecId.run( + CallFrame, std::initializer_list{AVCodecParamId}, + Result)); + EXPECT_EQ(Result[0].get(), 27); // H264 + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodecparam_codec_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParamCodecType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParamCodecType &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecParamCodecType.run( + CallFrame, std::initializer_list{AVCodecParamId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); // MediaType Video + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodecparam_set_codec_tag"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParamSetCodecTag = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParamSetCodecTag &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVCodecParamSetCodecTag.run( + CallFrame, + std::initializer_list{AVCodecParamId, 20}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp new file mode 100644 index 00000000..4d7348a3 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp @@ -0,0 +1,368 @@ +#include "avcodec/avPacket.h" +#include "avcodec/module.h" +#include "utils.h" + +#include + +// Testing all AVPacket + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVPacketTest) { + + ASSERT_TRUE(AVCodecMod != nullptr); + + uint32_t PacketPtr = UINT32_C(4); + uint32_t PacketPtr2 = UINT32_C(8); + uint32_t DataPtr = UINT32_C(12); + + auto *FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_alloc"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVPacketAlloc = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketAlloc.run( + CallFrame, std::initializer_list{PacketPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVPacketAlloc.run( + CallFrame, std::initializer_list{PacketPtr2}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t PacketId = readUInt32(MemInst, PacketPtr); + uint32_t PacketId2 = readUInt32(MemInst, PacketPtr2); + ASSERT_TRUE(PacketId > 0); + ASSERT_TRUE(PacketId2 > 0); + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_new_packet"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVNewPacket = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t Size = 40; + EXPECT_TRUE(HostFuncAVNewPacket.run( + CallFrame, std::initializer_list{PacketId, Size}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_grow_packet"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVGrowPacket = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t Size = 40; + EXPECT_TRUE(HostFuncAVGrowPacket.run( + CallFrame, std::initializer_list{PacketId, Size}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_shrink_packet"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVShrinkPacket = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t Size = 40; + EXPECT_TRUE(HostFuncAVShrinkPacket.run( + CallFrame, std::initializer_list{PacketId, Size}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + uint32_t StreamIdx = 3; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_set_stream_index"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSetStreamIndex = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketSetStreamIndex &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSetStreamIndex.run( + CallFrame, + std::initializer_list{PacketId, StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_stream_index"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketStreamIndex = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketStreamIndex &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketStreamIndex.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), StreamIdx); + } + + uint32_t Size = 0; + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_size"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSize = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSize.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + Size = Result[0].get(); + EXPECT_TRUE(Size > 0); + } + + uint32_t Flags = 5; + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_set_flags"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSetFlags = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSetFlags.run( + CallFrame, std::initializer_list{PacketId, Flags}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_flags"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketFlags = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketFlags.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), Flags); + } + + int64_t Pos = 500; + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_set_pos"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSetPos = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSetPos.run( + CallFrame, std::initializer_list{PacketId, Pos}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_pos"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketPos = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketPos.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), Pos); + } + + int64_t Duration = 100; + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_set_duration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSetDuration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketSetDuration &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSetDuration.run( + CallFrame, + std::initializer_list{PacketId, Duration}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_duration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketDuration = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketDuration.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), Duration); + } + + int64_t Dts = 1000; + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_set_dts"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSetDts = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSetDts.run( + CallFrame, std::initializer_list{PacketId, Dts}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_dts"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketDts = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketDts.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), Dts); + } + + int64_t Pts = 5000; + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_set_pts"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketSetPts = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketSetPts.run( + CallFrame, std::initializer_list{PacketId, Pts}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_pts"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketPts = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketPts.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), Pts); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_is_data_null"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketIsDataNull = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketIsDataNull &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketIsDataNull.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_data"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVPacketData = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketData.run( + CallFrame, + std::initializer_list{PacketId, DataPtr, Size}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_ref"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVPacketRef = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketRef.run( + CallFrame, + std::initializer_list{PacketId2, PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_unref"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVPacketUnref = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVPacketUnref.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp new file mode 100644 index 00000000..d5fcc6b0 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp @@ -0,0 +1,574 @@ +#include "avcodec/avcodec_func.h" +#include "avcodec/module.h" + +#include "utils.h" +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +// TODO: Commented functions need to be tested. + +TEST_F(FFmpegTest, AVCodecFunc) { + ASSERT_TRUE(AVCodecMod != nullptr); + + uint32_t CodecCtxPtr = UINT32_C(4); + uint32_t CodecParamPtr = UINT32_C(8); + uint32_t CodecParamPtr2 = UINT32_C(20); + uint32_t CodecDecoderPtr = UINT32_C(12); + uint32_t CodecEncoderPtr = UINT32_C(16); + uint32_t StrPtr = UINT32_C(32); + + uint32_t CodecNamePtr = UINT32_C(150); + std::string CodecName = "mpeg1video"; + spdlog::info("Filling memory CodecName into CodecNamePtr"sv); + fillMemContent(MemInst, CodecNamePtr, CodecName); + + uint32_t ID = 1; + + auto *FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_alloc_context3"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecAllocContext3 = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecAllocContext3 &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AvCodecAllocContext3"sv); + { + EXPECT_TRUE(HostFuncAVCodecAllocContext3.run( + CallFrame, std::initializer_list{0, CodecCtxPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t AVCodecCtxId = readUInt32(MemInst, CodecCtxPtr); + ASSERT_TRUE(AVCodecCtxId > 0); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_parameters_alloc"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParametersAlloc = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParametersAlloc &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecParametersAlloc"sv); + { + EXPECT_TRUE(HostFuncAVCodecParametersAlloc.run( + CallFrame, std::initializer_list{CodecParamPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVCodecParametersAlloc.run( + CallFrame, std::initializer_list{CodecParamPtr2}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t AVCodecParamId = readUInt32(MemInst, CodecParamPtr); + ASSERT_TRUE(AVCodecParamId > 0); + + uint32_t AVCodecParamId2 = readUInt32(MemInst, CodecParamPtr2); + ASSERT_TRUE(AVCodecParamId2 > 0); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_parameters_from_context"sv); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParametersFromContext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParametersFromContext &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecParametersFromContext"sv); + { + EXPECT_TRUE(HostFuncAVCodecParametersFromContext.run( + CallFrame, + std::initializer_list{AVCodecParamId, + AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_get_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecGetType = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecGetType"sv); + { + EXPECT_TRUE(HostFuncAVCodecGetType.run( + CallFrame, std::initializer_list{ID}, Result)); + EXPECT_EQ(Result[0].get(), 0); // Video Type + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_find_decoder"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecFindDecoder = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFindDecoder &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecFindDecoder"sv); + { + EXPECT_TRUE(HostFuncAVCodecFindDecoder.run( + CallFrame, + std::initializer_list{ID, CodecDecoderPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t AVCodecDecoderId = readUInt32(MemInst, CodecDecoderPtr); + ASSERT_TRUE(AVCodecDecoderId > 0); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_find_encoder"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecFindEncoder = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFindEncoder &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecFindEncoder"sv); + { + EXPECT_TRUE(HostFuncAVCodecFindEncoder.run( + CallFrame, + std::initializer_list{ID, CodecEncoderPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t AVCodecEncoderId = readUInt32(MemInst, CodecEncoderPtr); + ASSERT_TRUE(AVCodecEncoderId > 0); + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_open2"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecOpen2 = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecOpen2"sv); + // Invalid argument passed. Return -22 Error code. Means functionality + // working. + { + EXPECT_TRUE( + HostFuncAVCodecOpen2.run(CallFrame, + std::initializer_list{ + AVCodecCtxId, AVCodecEncoderId, 0}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_codec_is_encoder"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecIsEncoder = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecIsEncoder"sv); + { + EXPECT_TRUE(HostFuncAVCodecIsEncoder.run( + CallFrame, + std::initializer_list{AVCodecEncoderId}, Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_codec_is_decoder"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecIsDecoder = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecIsDecoder"sv); + { + EXPECT_TRUE(HostFuncAVCodecIsDecoder.run( + CallFrame, + std::initializer_list{AVCodecDecoderId}, Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_find_decoder_by_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecFindDecoderByName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFindDecoderByName &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecFindDecoderByName"sv); + { + uint32_t Length = CodecName.length(); + EXPECT_TRUE(HostFuncAVCodecFindDecoderByName.run( + CallFrame, + std::initializer_list{CodecDecoderPtr, + CodecNamePtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_find_encoder_by_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecFindEncoderByName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFindEncoderByName &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecFindEncoderByName"sv); + { + uint32_t Length = CodecName.length(); + EXPECT_TRUE(HostFuncAVCodecFindEncoderByName.run( + CallFrame, + std::initializer_list{CodecEncoderPtr, + CodecNamePtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_parameters_to_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParametersToContext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParametersToContext &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecParametersToContext"sv); + { + EXPECT_TRUE(HostFuncAVCodecParametersToContext.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + AVCodecParamId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + // TODO: Need FormatCtxId To test this func. + // FuncInst = AVCodecMod->findFuncExports( + // "wasmedge_ffmpeg_avcodec_avcodec_parameters_copy"); + // EXPECT_NE(FuncInst, nullptr); + // EXPECT_TRUE(FuncInst->isHostFunction()); + // + // auto &HostFuncAVCodecParametersCopy = dynamic_cast< + // WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParametersCopy &>( + // FuncInst->getHostFunc()); + // + // { + // EXPECT_TRUE(HostFuncAVCodecParametersCopy.run( + // CallFrame, std::initializer_list{}, Result)); + // EXPECT_EQ(Result[0].get(), + // static_cast(ErrNo::Success)); + // } + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_version"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecVersion = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecVersion"sv); + { + EXPECT_TRUE(HostFuncAVCodecVersion.run( + CallFrame, std::initializer_list{}, Result)); + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_configuration_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecConfigurationLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecConfigurationLength &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecConfigurationLength"sv); + int32_t Length = 0; + { + EXPECT_TRUE(HostFuncAVCodecConfigurationLength.run( + CallFrame, std::initializer_list{}, Result)); + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_configuration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecConfiguration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecConfiguration &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecConfiguration"sv); + { + EXPECT_TRUE(HostFuncAVCodecConfiguration.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_license_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecLicenseLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecLicenseLength &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecLicenseLength"sv); + { + EXPECT_TRUE(HostFuncAVCodecLicenseLength.run( + CallFrame, std::initializer_list{}, Result)); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_license"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecLicense = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecLicense"sv); + { + EXPECT_TRUE(HostFuncAVCodecLicense.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_free_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecFreeContext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFreeContext &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecFreeContext"sv); + { + EXPECT_TRUE(HostFuncAVCodecFreeContext.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_parameters_free"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecParametersFree = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParametersFree &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecParametersFree"sv); + { + EXPECT_TRUE(HostFuncAVCodecParametersFree.run( + CallFrame, std::initializer_list{AVCodecParamId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +TEST_F(FFmpegTest, SendPacketReceiveFrame) { + + std::string FileName = "ffmpeg-assets/dummy.mp4"; // 32 chars + uint32_t CodecCtxPtr = UINT32_C(64); + uint32_t FramePtr = UINT32_C(72); + uint32_t PacketPtr = UINT32_C(68); + initFFmpegStructs(UINT32_C(20), UINT32_C(24), UINT32_C(28), FileName, + UINT32_C(60), CodecCtxPtr, PacketPtr, FramePtr); + + uint32_t FrameId = readUInt32(MemInst, FramePtr); + uint32_t PacketId = readUInt32(MemInst, PacketPtr); + uint32_t CodecCtxId = readUInt32(MemInst, CodecCtxPtr); + + auto *FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_send_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSendFrame = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSendFrame"sv); + // Invalid Argument Error. Should Use Encoder, I'm using decoder + // Aim is to test the functionality. + { + EXPECT_TRUE(HostFuncAVCodecSendFrame.run( + CallFrame, + std::initializer_list{CodecCtxId, FrameId}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } + + // Invalid Argument Error. Should Use Encoder, I'm using decoder + // Aim is to test the functionality. + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_receive_packet"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecReceivePacket = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecReceivePacket &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecReceivePacket"sv); + { + EXPECT_TRUE(HostFuncAVCodecReceivePacket.run( + CallFrame, + std::initializer_list{CodecCtxId, PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_send_packet"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecSendPacket = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSendPacket &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecSendPacket"sv); + // Send packet to Decoder. + { + EXPECT_TRUE(HostFuncAVCodecSendPacket.run( + CallFrame, + std::initializer_list{CodecCtxId, PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_receive_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + // Decoder Receives the Packet as Frame. + auto &HostFuncAVCodecReceiveFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecReceiveFrame &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecReceiveFrame"sv); + { + EXPECT_TRUE(HostFuncAVCodecReceiveFrame.run( + CallFrame, + std::initializer_list{CodecCtxId, FrameId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_rescale_ts"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVPacketRescaleTs = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketRescaleTs &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVPacketRescaleTs"sv); + { + int32_t SrcNum = 2; + int32_t SrcDen = 3; + int32_t DestNum = 5; + int32_t DestDen = 9; + EXPECT_TRUE(HostFuncAVPacketRescaleTs.run( + CallFrame, + std::initializer_list{PacketId, SrcNum, SrcDen, + DestNum, DestDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_make_writable"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVPacketMakeWritable = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketMakeWritable &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVPacketMakeWritable"sv); + { + EXPECT_TRUE(HostFuncAVPacketMakeWritable.run( + CallFrame, std::initializer_list{PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_flush_buffers"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecFlushBuffers = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFlushBuffers &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecFlushBuffers"sv); + { + EXPECT_TRUE(HostFuncAVCodecFlushBuffers.run( + CallFrame, std::initializer_list{CodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_close"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVCodecClose = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVCodecClose"sv); + { + EXPECT_TRUE(HostFuncAVCodecClose.run( + CallFrame, std::initializer_list{CodecCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avfilter/avfilter.cpp b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter.cpp new file mode 100644 index 00000000..58930164 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter.cpp @@ -0,0 +1,287 @@ +#include "avfilter/avFilter.h" +#include "avfilter//avfilter_func.h" +#include "avfilter/module.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVFilterStructs) { + ASSERT_TRUE(AVFilterMod != nullptr); + + uint32_t FilterPtr = UINT32_C(8); + uint32_t InputFilterPadPtr = UINT32_C(12); + uint32_t OutputFilterPadPtr = UINT32_C(16); + uint32_t InputNamePtr = UINT32_C(100); + uint32_t StrPtr = UINT32_C(150); + + std::string InputName = std::string("abuffer"); + fillMemContent(MemInst, InputNamePtr, InputName); + + // ================================================================== + // Start Initialize AVFilter + // ================================================================== + + auto *FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_get_by_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGetByName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGetByName &>( + FuncInst->getHostFunc()); + + { + int32_t Length = InputName.length(); + EXPECT_TRUE(HostFuncAVFilterGetByName.run( + CallFrame, + std::initializer_list{FilterPtr, InputNamePtr, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t FilterId = readUInt32(MemInst, FilterPtr); + ASSERT_TRUE(FilterId > 0); + // ================================================================== + // End Initialize AVFilter + // ================================================================== + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_name_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterNameLength &>( + FuncInst->getHostFunc()); + + int32_t Length = 0; + { + EXPECT_TRUE(HostFuncAVFilterNameLength.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = + AVFilterMod->findFuncExports("wasmedge_ffmpeg_avfilter_avfilter_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterName = + dynamic_cast( + FuncInst->getHostFunc()); + + fillMemContent(MemInst, StrPtr, Length); + { + EXPECT_TRUE(HostFuncAVFilterName.run( + CallFrame, + std::initializer_list{FilterId, StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_description_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterDescriptionLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterDescriptionLength &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterDescriptionLength.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_description"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterDescription = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterDescription &>( + FuncInst->getHostFunc()); + + fillMemContent(MemInst, StrPtr, Length); + { + EXPECT_TRUE(HostFuncAVFilterDescription.run( + CallFrame, + std::initializer_list{FilterId, StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_nb_inputs"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterNbInputs = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterNbInputs &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterNbInputs.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_nb_outputs"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterNbOutputs = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterNbOutputs &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterNbOutputs.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = + AVFilterMod->findFuncExports("wasmedge_ffmpeg_avfilter_avfilter_flags"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterFlags = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterFlags.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_get_inputs_filter_pad"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGetInputsFilterPad = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGetInputsFilterPad &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterGetInputsFilterPad.run( + CallFrame, + std::initializer_list{FilterId, + InputFilterPadPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_get_outputs_filter_pad"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGetOutputsFilterPad = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGetOutputsFilterPad &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterGetOutputsFilterPad.run( + CallFrame, + std::initializer_list{FilterId, + OutputFilterPadPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t OutputFilterPadId = readUInt32(MemInst, OutputFilterPadPtr); + ASSERT_TRUE(OutputFilterPadId > 0); + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_pad_get_name_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterPadGetNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterPadGetNameLength &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterPadGetNameLength.run( + CallFrame, + std::initializer_list{OutputFilterPadId, 0}, + Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_pad_get_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterPadGetName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterPadGetName &>( + FuncInst->getHostFunc()); + + { + int32_t Idx = 0; + EXPECT_TRUE(HostFuncAVFilterPadGetName.run( + CallFrame, + std::initializer_list{OutputFilterPadId, Idx, + StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_pad_get_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterPadGetType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterPadGetType &>( + FuncInst->getHostFunc()); + + { + int32_t Idx = 0; + EXPECT_TRUE(HostFuncAVFilterPadGetType.run( + CallFrame, + std::initializer_list{OutputFilterPadId, Idx}, + Result)); + EXPECT_EQ(Result[0].get(), 1); // Audio + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_pad_drop"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterPadDrop = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterPadDrop.run( + CallFrame, + std::initializer_list{OutputFilterPadId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp new file mode 100644 index 00000000..7a6675fe --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp @@ -0,0 +1,682 @@ +#include "avfilter//avfilter_func.h" +#include "avfilter/avFilter.h" +#include "avfilter/buffer_source_sink.h" +#include "avfilter/module.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVFilterFunc) { + + ASSERT_TRUE(AVFilterMod != nullptr); + + // Structs Ptr + uint32_t FilterGraphPtr = UINT32_C(4); + uint32_t FilterPtr = UINT32_C(8); + uint32_t Filter2Ptr = UINT32_C(12); + uint32_t InputFilterCtxPtr = UINT32_C(28); // AVFilterContext + uint32_t OutputFilterCtxPtr = UINT32_C(24); // AVFilterContext + uint32_t InputInOutPtr = UINT32_C(32); + uint32_t OutputInOutPtr = UINT32_C(36); + uint32_t FramePtr = UINT32_C(40); + + // Strings. + uint32_t InputNamePtr = UINT32_C(100); + uint32_t OutputNamePtr = UINT32_C(150); + uint32_t InputFilterNamePtr = UINT32_C(200); + uint32_t OutputFilterNamePtr = UINT32_C(250); + uint32_t ArgsPtr = UINT32_C(300); + uint32_t SpecPtr = UINT32_C(450); + uint32_t StrPtr = UINT32_C(500); + + std::string InputName = std::string("abuffer"); + fillMemContent(MemInst, InputNamePtr, InputName); + + std::string OutputName = std::string("abuffersink"); + fillMemContent(MemInst, OutputNamePtr, OutputName); + + std::string InputFilterName = std::string("in"); + fillMemContent(MemInst, InputFilterNamePtr, InputFilterName); + + std::string OutputFilterName = std::string("out"); + fillMemContent(MemInst, OutputFilterNamePtr, OutputFilterName); + + std::string Args = std::string( + "time_base=1/44100:sample_rate=44100:sample_fmt=fltp:channel_layout=0x3"); + fillMemContent(MemInst, ArgsPtr, Args); + + std::string SpecStr = std::string("anull"); + fillMemContent(MemInst, SpecPtr, SpecStr); + + initEmptyFrame(FramePtr); + uint32_t FrameId = readUInt32(MemInst, FramePtr); + + auto *FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_alloc"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphAlloc = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphAlloc &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterGraphAlloc.run( + CallFrame, std::initializer_list{FilterGraphPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t FilterGraphId = readUInt32(MemInst, FilterGraphPtr); + ASSERT_TRUE(FilterGraphId > 0); + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_get_by_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGetByName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGetByName &>( + FuncInst->getHostFunc()); + + { + int32_t Length = InputName.length(); + EXPECT_TRUE(HostFuncAVFilterGetByName.run( + CallFrame, + std::initializer_list{FilterPtr, InputNamePtr, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + Length = OutputName.length(); + EXPECT_TRUE(HostFuncAVFilterGetByName.run( + CallFrame, + std::initializer_list{Filter2Ptr, OutputNamePtr, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t FilterId = readUInt32(MemInst, FilterPtr); + uint32_t Filter2Id = readUInt32(MemInst, Filter2Ptr); + ASSERT_TRUE(FilterId > 0); + ASSERT_TRUE(Filter2Id > 0); + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_create_filter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphCreateFilter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphCreateFilter &>( + FuncInst->getHostFunc()); + + { + int32_t NameLen = InputFilterName.length(); + int32_t ArgsLen = Args.length(); + EXPECT_TRUE(HostFuncAVFilterGraphCreateFilter.run( + CallFrame, + std::initializer_list{ + InputFilterCtxPtr, FilterId, InputFilterNamePtr, NameLen, ArgsPtr, + ArgsLen, FilterGraphId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + writeUInt32(MemInst, 0, InputFilterCtxPtr); // Setting InputFilterCtx to 0 + + NameLen = OutputFilterName.length(); + EXPECT_TRUE(HostFuncAVFilterGraphCreateFilter.run( + CallFrame, + std::initializer_list{ + OutputFilterCtxPtr, Filter2Id, OutputFilterNamePtr, NameLen, 0, 0, + FilterGraphId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + writeUInt32(MemInst, 0, OutputFilterCtxPtr); // Setting OutputFilterCtx to 0 + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_inout_alloc"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterInOutAlloc = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterInOutAlloc &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterInOutAlloc.run( + CallFrame, std::initializer_list{InputInOutPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVFilterInOutAlloc.run( + CallFrame, std::initializer_list{OutputInOutPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t InputInOutId = readUInt32(MemInst, InputInOutPtr); + ASSERT_TRUE(InputInOutId > 0); + + uint32_t OutputInOutId = readUInt32(MemInst, OutputInOutPtr); + ASSERT_TRUE(OutputInOutId > 0); + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_get_filter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphGetFilter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphGetFilter &>( + FuncInst->getHostFunc()); + + { + int32_t Length = OutputFilterName.length(); + EXPECT_TRUE(HostFuncAVFilterGraphGetFilter.run( + CallFrame, + std::initializer_list{ + OutputFilterCtxPtr, FilterGraphId, OutputFilterNamePtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + Length = InputFilterName.length(); + EXPECT_TRUE(HostFuncAVFilterGraphGetFilter.run( + CallFrame, + std::initializer_list{ + InputFilterCtxPtr, FilterGraphId, InputFilterNamePtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + uint32_t OutputFilterCtxId = readUInt32(MemInst, OutputFilterCtxPtr); + ASSERT_TRUE(OutputFilterCtxId > 0); + + uint32_t InputFilterCtxId = readUInt32(MemInst, InputFilterCtxPtr); + ASSERT_TRUE(InputFilterCtxId > 0); + + // ================================================================== + // Setting InOutId Values for Filtering + // ================================================================== + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_inout_set_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterInOutSetName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterInOutSetName &>( + FuncInst->getHostFunc()); + + { + int32_t Length = InputFilterName.length(); + EXPECT_TRUE(HostFuncAVFilterInOutSetName.run( + CallFrame, + std::initializer_list{OutputInOutId, + InputFilterNamePtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + Length = OutputFilterName.length(); + EXPECT_TRUE(HostFuncAVFilterInOutSetName.run( + CallFrame, + std::initializer_list{ + InputInOutId, OutputFilterNamePtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_inout_set_filter_ctx"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterInOutSetFilterCtx = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterInOutSetFilterCtx &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterInOutSetFilterCtx.run( + CallFrame, + std::initializer_list{OutputInOutId, + InputFilterCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVFilterInOutSetFilterCtx.run( + CallFrame, + std::initializer_list{InputInOutId, + OutputFilterCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_inout_set_pad_idx"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterInOutSetPadIdx = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterInOutSetPadIdx &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterInOutSetPadIdx.run( + CallFrame, + std::initializer_list{OutputInOutId, 0}, Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVFilterInOutSetPadIdx.run( + CallFrame, std::initializer_list{InputInOutId, 0}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_inout_set_next"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterInOutSetNext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterInOutSetNext &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterInOutSetNext.run( + CallFrame, + std::initializer_list{OutputInOutId, 0}, Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVFilterInOutSetNext.run( + CallFrame, std::initializer_list{InputInOutId, 0}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + // ================================================================== + // End Setting InOutId Values for Filtering + // ================================================================== + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_parse_ptr"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphParsePtr = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphParsePtr &>( + FuncInst->getHostFunc()); + + { + int32_t Length = SpecStr.length(); + EXPECT_TRUE(HostFuncAVFilterGraphParsePtr.run( + CallFrame, + std::initializer_list{ + FilterGraphId, SpecPtr, Length, InputInOutId, OutputInOutId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_config"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphConfig = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphConfig &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterGraphConfig.run( + CallFrame, std::initializer_list{FilterGraphId}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_dump_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphDumpLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphDumpLength &>( + FuncInst->getHostFunc()); + + int32_t GraphStrLen = 0; + { + EXPECT_TRUE(HostFuncAVFilterGraphDumpLength.run( + CallFrame, std::initializer_list{FilterGraphId}, + Result)); + GraphStrLen = Result[0].get(); + ASSERT_TRUE(GraphStrLen > 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_dump"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphDump = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphDump &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterGraphDump.run( + CallFrame, + std::initializer_list{FilterGraphId, StrPtr, + GraphStrLen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + // Crashing the program. Checked even from Rust side. + + // FuncInst = AVFilterMod->findFuncExports( + // "wasmedge_ffmpeg_avfilter_avfilter_inout_free"); + // EXPECT_NE(FuncInst, nullptr); + // EXPECT_TRUE(FuncInst->isHostFunction()); + // + // auto &HostFuncAVFilterInOutFree = dynamic_cast< + // WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterInOutFree &>( + // FuncInst->getHostFunc()); + // + // { + // EXPECT_TRUE(HostFuncAVFilterInOutFree.run( + // CallFrame, + // std::initializer_list{InputInOutId}, Result)); + // EXPECT_EQ(Result[0].get(), + // static_cast(ErrNo::Success)); + // + // EXPECT_TRUE(HostFuncAVFilterInOutFree.run( + // CallFrame, + // std::initializer_list{OutputInOutId}, + // Result)); + // EXPECT_EQ(Result[0].get(), + // static_cast(ErrNo::Success)); + // } + + FuncInst = + AVFilterMod->findFuncExports("wasmedge_ffmpeg_avfilter_avfilter_version"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterVersion = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterVersion.run( + CallFrame, std::initializer_list{}, Result)); + ASSERT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_configuration_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterConfigurationLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterConfigurationLength &>( + FuncInst->getHostFunc()); + + int32_t Length = 0; + { + EXPECT_TRUE(HostFuncAVFilterConfigurationLength.run( + CallFrame, std::initializer_list{}, Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_configuration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterConfiguration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterConfiguration &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterConfiguration.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_license_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterLicenseLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterLicenseLength &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterLicenseLength.run( + CallFrame, std::initializer_list{}, Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = + AVFilterMod->findFuncExports("wasmedge_ffmpeg_avfilter_avfilter_license"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterLicense = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterLicense.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + // ================================================================== + // Start Test AVBufferSource, AVBufferSink Funcs + // ================================================================== + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_av_buffersrc_get_nb_failed_requests"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVBufferSrcGetNbFailedRequests = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVBufferSrcGetNbFailedRequests + &>(FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVBufferSrcGetNbFailedRequests.run( + CallFrame, + std::initializer_list{InputFilterCtxId}, Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_av_buffersrc_add_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVBufferSrcAddFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVBufferSrcAddFrame &>( + FuncInst->getHostFunc()); + + // Returning Error Code -22 (Invalid Argument), Due to Passing Empty Frame. + { + EXPECT_TRUE(HostFuncAVBufferSrcAddFrame.run( + CallFrame, + std::initializer_list{InputFilterCtxId, FrameId}, + Result)); + ASSERT_TRUE(Result[0].get()); + } + + // Need to send the last frame. Then only this test will pass. Else Null + // pointer exception. + // FuncInst = AVFilterMod->findFuncExports( + // "wasmedge_ffmpeg_avfilter_av_buffersrc_close"); + // EXPECT_NE(FuncInst, nullptr); + // EXPECT_TRUE(FuncInst->isHostFunction()); + // + // auto &HostFuncAVBufferSrcClose = dynamic_cast< + // WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVBufferSrcClose &>( + // FuncInst->getHostFunc()); + // + // { + // int64_t Pts = 20; + // uint32_t Flags = 30; + // EXPECT_TRUE(HostFuncAVBufferSrcClose.run( + // CallFrame, + // std::initializer_list{InputFilterCtxPtr, Pts, + // Flags}, + // Result)); + // EXPECT_EQ(Result[0].get(), + // static_cast(ErrNo::Success)); + // } + + // Passing Empty frames. Return AVERROR due to no frames presen Return AVERROR + // due to no frames present. + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_av_buffersink_get_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVBufferSinkGetFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVBufferSinkGetFrame &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVBufferSinkGetFrame.run( + CallFrame, + std::initializer_list{OutputFilterCtxId, FrameId}, + Result)); + ASSERT_TRUE(Result[0].get()); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_av_buffersink_get_samples"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVBufferSinkGetSamples = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVBufferSinkGetSamples &>( + FuncInst->getHostFunc()); + + // Passing Empty frames. Return AVERROR due to no frames presen Return AVERROR + // due to no frames present. + { + EXPECT_TRUE(HostFuncAVBufferSinkGetSamples.run( + CallFrame, + std::initializer_list{OutputFilterCtxId, FrameId, + 20}, + Result)); + ASSERT_TRUE(Result[0].get()); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_av_buffersink_set_frame_size"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAvBufferSinkSetFrameSize = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AvBufferSinkSetFrameSize &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAvBufferSinkSetFrameSize.run( + CallFrame, + std::initializer_list{OutputFilterCtxId, 30}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + // ================================================================== + // End Test AVBufferSource, AVBufferSink Funcs + // ================================================================== + + // ================================================================== + // Clean Memory + // ================================================================== + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_free_graph_str"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterFreeGraphStr = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterFreeGraphStr &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterFreeGraphStr.run( + CallFrame, std::initializer_list{FilterGraphId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVFilterMod->findFuncExports("wasmedge_ffmpeg_avfilter_avfilter_drop"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterDrop = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterDrop.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_context_drop"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterContextDrop = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterContextDrop &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterContextDrop.run( + CallFrame, + std::initializer_list{InputFilterCtxId}, Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVFilterContextDrop.run( + CallFrame, + std::initializer_list{OutputFilterCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFilterMod->findFuncExports( + "wasmedge_ffmpeg_avfilter_avfilter_graph_free"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + + auto &HostFuncAVFilterGraphFree = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::AVFilterGraphFree &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFilterGraphFree.run( + CallFrame, std::initializer_list{FilterGraphId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + // ================================================================== + // End Clean Memory + // ================================================================== +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp new file mode 100644 index 00000000..94dd8c63 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp @@ -0,0 +1,220 @@ +#include "avformat/avChapter.h" +#include "avformat/module.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +// Sample Video under test has only Single Chapter. +TEST_F(FFmpegTest, AVChapter) { + + ASSERT_TRUE(AVFormatMod != nullptr); + + uint32_t ChapterIdx = 0; + + uint32_t FormatCtxPtr = UINT32_C(4); + uint32_t NumPtr = UINT32_C(12); + uint32_t DenPtr = UINT32_C(16); + uint32_t DictionaryPtr = UINT32_C(20); + uint32_t FilePtr = UINT32_C(100); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFormatCtx(FormatCtxPtr, FilePtr, FileName); + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + + auto *FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avChapter_id"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterId = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVChapterId.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avChapter_timebase"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterTimebase = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterTimebase &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVChapterTimebase.run( + CallFrame, + std::initializer_list{NumPtr, DenPtr, FormatCtxId, + ChapterIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + EXPECT_EQ(readSInt32(MemInst, NumPtr), 1); + EXPECT_TRUE(readSInt32(MemInst, DenPtr) >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avChapter_start"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterStart = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVChapterStart.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avChapter_end"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterEnd = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVChapterEnd.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avChapter_metadata"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterMetadata = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterMetadata &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVChapterMetadata.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx, + DictionaryPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(readUInt32(MemInst, DictionaryPtr) > 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avChapter_set_id"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterSetId = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int64_t ChapterId = 10000; + EXPECT_TRUE( + HostFuncAVChapterSetId.run(CallFrame, + std::initializer_list{ + FormatCtxId, ChapterIdx, ChapterId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + // Verify Set Data + EXPECT_TRUE(HostFuncAVChapterId.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx}, + Result)); + EXPECT_EQ(Result[0].get(), ChapterId); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avChapter_set_timebase"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterSetTimebase = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterSetTimebase &>( + FuncInst->getHostFunc()); + + { + int32_t Num = 3; + int32_t Den = 4; + EXPECT_TRUE(HostFuncAVChapterSetTimebase.run( + CallFrame, + std::initializer_list{Num, Den, FormatCtxId, + ChapterIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + // Verify Set Data + EXPECT_TRUE(HostFuncAVChapterTimebase.run( + CallFrame, + std::initializer_list{NumPtr, DenPtr, FormatCtxId, + ChapterIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(readSInt32(MemInst, NumPtr), Num); + EXPECT_EQ(readSInt32(MemInst, DenPtr), Den); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avChapter_set_start"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterSetStart = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterSetStart &>( + FuncInst->getHostFunc()); + + { + int64_t StartValue = 1000; + EXPECT_TRUE(HostFuncAVChapterSetStart.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx, + StartValue}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + // Verify Set Data + EXPECT_TRUE(HostFuncAVChapterStart.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx}, + Result)); + EXPECT_EQ(Result[0].get(), StartValue); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avChapter_set_end"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVChapterSetEnd = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int64_t EndValue = 99999; + EXPECT_TRUE( + HostFuncAVChapterSetEnd.run(CallFrame, + std::initializer_list{ + FormatCtxId, ChapterIdx, EndValue}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + // Verify Set Data + EXPECT_TRUE(HostFuncAVChapterEnd.run( + CallFrame, + std::initializer_list{FormatCtxId, ChapterIdx}, + Result)); + EXPECT_EQ(Result[0].get(), EndValue); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp new file mode 100644 index 00000000..3a575551 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp @@ -0,0 +1,207 @@ +#include "avformat/avInputOutputFormat.h" +#include "avformat/avformatContext.h" +#include "avformat/module.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVInputFormat) { + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + uint32_t FormatCtxPtr = UINT32_C(24); + uint32_t InputFormatPtr = UINT32_C(28); + + uint32_t StrBuf = UINT32_C(100); + initFFmpegStructs(UINT32_C(20), FormatCtxPtr, UINT32_C(28), FileName, + UINT32_C(60), UINT32_C(64), UINT32_C(68), UINT32_C(72)); + + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + + // ==================================================================== + // Initialize AVInputFormat + // ==================================================================== + + auto *FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_iformat"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxIFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxIFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxIFormat.run( + CallFrame, + std::initializer_list{FormatCtxId, + InputFormatPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(readUInt32(MemInst, InputFormatPtr) > 0); + } + uint32_t InputFormatId = readUInt32(MemInst, InputFormatPtr); + + // ==================================================================== + // End Initialize AVInputFormat + // ==================================================================== + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avIOFormat_name_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOFormatNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVIOFormatNameLength &>( + FuncInst->getHostFunc()); + + int32_t Length = 0; + { + EXPECT_TRUE(HostFuncAVIOFormatNameLength.run( + CallFrame, + std::initializer_list{InputFormatId, 0}, Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avInputFormat_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVInputFormatName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVInputFormatName &>( + FuncInst->getHostFunc()); + + fillMemContent(MemInst, StrBuf, Length); + { + EXPECT_TRUE(HostFuncAVInputFormatName.run( + CallFrame, + std::initializer_list{InputFormatId, StrBuf, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avIOFormat_long_name_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOFormatLongNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVIOFormatLongNameLength &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVIOFormatLongNameLength.run( + CallFrame, + std::initializer_list{InputFormatId, 0}, Result)); + + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avInputFormat_long_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVInputFormatLongName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVInputFormatLongName &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVInputFormatLongName.run( + CallFrame, + std::initializer_list{InputFormatId, StrBuf, + Length}, + Result)); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avIOFormat_extensions_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOFormatExtensionsLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVIOFormatExtensionsLength &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVIOFormatExtensionsLength.run( + CallFrame, + std::initializer_list{InputFormatId, 0}, Result)); + + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avInputFormat_extensions"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVInputFormatExtensions = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVInputFormatExtensions &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVInputFormatExtensions.run( + CallFrame, + std::initializer_list{InputFormatId, StrBuf, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avIOFormat_mime_type_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOFormatMimeTypeLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVIOFormatMimeTypeLength &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVIOFormatMimeTypeLength.run( + CallFrame, + std::initializer_list{InputFormatId, 0}, Result)); + + Length = Result[0].get(); + ASSERT_TRUE(Length >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avInputFormat_mime_type"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVInputFormatMimeType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVInputFormatMimeType &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVInputFormatMimeType.run( + CallFrame, + std::initializer_list{InputFormatId, StrBuf, + Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avInputOutputFormat_free"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVInputOutputFormatFree = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVInputOutputFormatFree &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVInputOutputFormatFree.run( + CallFrame, std::initializer_list{InputFormatId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avStream.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avStream.cpp new file mode 100644 index 00000000..1fc2dadc --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avformat/avStream.cpp @@ -0,0 +1,303 @@ +#include "avformat/avStream.h" +#include "avformat/module.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +// Testing all AVFormat_funcs. +TEST_F(FFmpegTest, AVStreamStruct) { + + ASSERT_TRUE(AVFormatMod != nullptr); + + uint32_t StreamIdx = 0; + + uint32_t FormatCtxPtr = UINT32_C(4); + uint32_t CodecParameterPtr = UINT32_C(8); + uint32_t NumPtr = UINT32_C(12); + uint32_t DenPtr = UINT32_C(16); + uint32_t DictPtr = UINT32_C(20); + uint32_t FilePtr = UINT32_C(100); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFormatCtx(FormatCtxPtr, FilePtr, FileName); + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + + auto *FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avStream_id"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamId = + dynamic_cast( + FuncInst->getHostFunc()); + + uint32_t AvFormatCtxId = readUInt32(MemInst, FormatCtxPtr); + { + EXPECT_TRUE(HostFuncAVStreamId.run( + CallFrame, + std::initializer_list{AvFormatCtxId, StreamIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avStream_index"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamIndex = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVStreamIndex.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_codecpar"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamCodecPar = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamCodecPar &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVStreamCodecPar.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx, + CodecParameterPtr}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + ASSERT_TRUE(readUInt32(MemInst, CodecParameterPtr) > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_timebase"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamTimebase = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamTimebase &>( + FuncInst->getHostFunc()); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_set_timebase"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamSetTimebase = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamSetTimebase &>( + FuncInst->getHostFunc()); + + { + int32_t Num = 3; + int32_t Den = 4; + EXPECT_TRUE(HostFuncAVStreamSetTimebase.run( + CallFrame, + std::initializer_list{Num, Den, FormatCtxId, + StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVStreamTimebase.run( + CallFrame, + std::initializer_list{NumPtr, DenPtr, FormatCtxId, + StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(readUInt32(MemInst, NumPtr), Num); + EXPECT_EQ(readUInt32(MemInst, DenPtr), Den); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_duration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamDuration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamDuration &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVStreamDuration.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_start_time"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamStartTime = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamStartTime &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVStreamStartTime.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_nb_frames"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamNbFrames = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamNbFrames &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVStreamNbFrames.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_disposition"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamDisposition = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamDisposition &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVStreamDisposition.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_set_r_frame_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamSetRFrameRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamSetRFrameRate &>( + FuncInst->getHostFunc()); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_r_frame_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamRFrameRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamRFrameRate &>( + FuncInst->getHostFunc()); + + { + int32_t Num = 3; + int32_t Den = 4; + EXPECT_TRUE(HostFuncAVStreamSetRFrameRate.run( + CallFrame, + std::initializer_list{Num, Den, FormatCtxId, + StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVStreamRFrameRate.run( + CallFrame, + std::initializer_list{NumPtr, DenPtr, FormatCtxId, + StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(readUInt32(MemInst, NumPtr), Num); + EXPECT_EQ(readUInt32(MemInst, DenPtr), Den); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_set_avg_frame_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamSetAvgFrameRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamSetAvgFrameRate &>( + FuncInst->getHostFunc()); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_avg_frame_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamAvgFrameRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamAvgFrameRate &>( + FuncInst->getHostFunc()); + + { + int32_t Num = 3; + int32_t Den = 4; + + EXPECT_TRUE(HostFuncAVStreamSetAvgFrameRate.run( + CallFrame, + std::initializer_list{Num, Den, FormatCtxId, + StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVStreamAvgFrameRate.run( + CallFrame, + std::initializer_list{NumPtr, DenPtr, FormatCtxId, + StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(readUInt32(MemInst, NumPtr), Num); + EXPECT_EQ(readUInt32(MemInst, DenPtr), Den); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_metadata"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamMetadata = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamMetadata &>( + FuncInst->getHostFunc()); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_set_metadata"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamSetMetadata = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamSetMetadata &>( + FuncInst->getHostFunc()); + { + EXPECT_TRUE(HostFuncAVStreamMetadata.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx, + DictPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + uint32_t DictId = readUInt32(MemInst, DictPtr); + EXPECT_TRUE(HostFuncAVStreamSetMetadata.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx, + DictId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avStream_discard"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVStreamDiscard = + dynamic_cast( + FuncInst->getHostFunc()); + { + EXPECT_TRUE(HostFuncAVStreamDiscard.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp new file mode 100644 index 00000000..ede24ec5 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp @@ -0,0 +1,184 @@ +#include "avformat/avformatContext.h" +#include "avformat/module.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +// Testing all AVFormat_funcs. +TEST_F(FFmpegTest, AVFormatContextStruct) { + + uint32_t FormatCtxPtr = UINT32_C(4); + uint32_t InputFormatPtr = UINT32_C(8); + uint32_t OutputFormatPtr = UINT32_C(12); + uint32_t DicPtr = uint32_t(16); + uint32_t FilePtr = UINT32_C(100); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFormatCtx(FormatCtxPtr, FilePtr, FileName); + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + + auto *FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_iformat"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxIFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxIFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxIFormat.run( + CallFrame, + std::initializer_list{FormatCtxId, + InputFormatPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(readUInt32(MemInst, InputFormatPtr) > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_oformat"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxOFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxOFormat &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxOFormat.run( + CallFrame, + std::initializer_list{FormatCtxId, + OutputFormatPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(readUInt32(MemInst, InputFormatPtr) > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_probescope"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxProbeScore = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxProbeScore &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxProbeScore.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 100); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_nb_streams"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxNbStreams = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxNbStreams &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxNbStreams.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_duration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxDuration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxDuration &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxDuration.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0 || Result[0].get() < 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_bit_rate"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxBitRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxBitRate &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxBitRate.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_set_nb_chapters"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxSetNbChapters = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxSetNbChapters &>( + FuncInst->getHostFunc()); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_nb_chapters"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxNbChapters = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxNbChapters &>( + FuncInst->getHostFunc()); + { + uint32_t NbChapters = 200; + EXPECT_TRUE(HostFuncAVFormatCtxSetNbChapters.run( + CallFrame, + std::initializer_list{FormatCtxId, NbChapters}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncAVFormatCtxNbChapters.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_EQ(Result[0].get(), NbChapters); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_metadata"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxMetadata = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxMetadata &>( + FuncInst->getHostFunc()); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformatContext_set_metadata"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCtxSetMetadata = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCtxSetMetadata &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncAVFormatCtxMetadata.run( + CallFrame, + std::initializer_list{FormatCtxId, DicPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(readUInt32(MemInst, DicPtr) > 0); + + uint32_t DictId = readUInt32(MemInst, DicPtr); + EXPECT_TRUE(HostFuncAVFormatCtxSetMetadata.run( + CallFrame, + std::initializer_list{FormatCtxId, DictId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp new file mode 100644 index 00000000..c2fc4fee --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp @@ -0,0 +1,588 @@ +#include "avformat/avformat_func.h" +#include "avformat/module.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +// Testing all AVFormat_funcs. +TEST_F(FFmpegTest, AVInputFormatFunc) { + + uint32_t FormatCtxPtr = UINT32_C(4); + uint32_t DictPtr = UINT32_C(16); + uint32_t KeyPtr = UINT32_C(100); + uint32_t ValuePtr = UINT32_C(200); + uint32_t StrPtr = UINT32_C(400); + + initDict(DictPtr, KeyPtr, std::string("Key"), ValuePtr, std::string("Value")); + uint32_t DictId = readUInt32(MemInst, DictPtr); + + uint32_t UrlStart = UINT32_C(300); + uint32_t UrlSize = 30; + fillMemContent(MemInst, UrlStart, UrlSize); + fillMemContent(MemInst, UrlStart, + std::string("ffmpeg-assets/sample_video.mp4")); + + auto *FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_open_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatOpenInput = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatOpenInput &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatOpenInput"sv); + { + // AVDict only + EXPECT_TRUE(HostFuncAVFormatOpenInput.run( + CallFrame, + std::initializer_list{ + FormatCtxPtr, UrlStart, UrlSize, UINT32_C(0), DictId}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + EXPECT_TRUE(readUInt32(MemInst, FormatCtxPtr) > 0); + + // No AVDict, No AVInputFormat + EXPECT_TRUE(HostFuncAVFormatOpenInput.run( + CallFrame, + std::initializer_list{ + FormatCtxPtr, UrlStart, UrlSize, UINT32_C(0), UINT32_C(0)}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + EXPECT_TRUE(readUInt32(MemInst, FormatCtxPtr) > 0); + } + + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_find_stream_info"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatFindStreamInfo = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatFindStreamInfo &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatFindStreamInfo"sv); + { + EXPECT_TRUE(HostFuncAVFormatFindStreamInfo.run( + CallFrame, + std::initializer_list{FormatCtxId, UINT32_C(0)}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_av_dump_format"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatAVDumpFormat = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVDumpFormat"sv); + { + EXPECT_TRUE(HostFuncAVFormatAVDumpFormat.run( + CallFrame, + std::initializer_list{ + FormatCtxId, 0, UINT32_C(100), UINT32_C(30), 0}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_av_find_best_stream"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFindBestStream = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFindBestStream &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFindBestStream"sv); + { + EXPECT_TRUE(HostFuncAVFindBestStream.run( + CallFrame, + std::initializer_list{ + FormatCtxId, UINT32_C(0), INT32_C(-1), INT32_C(-1), 0, 0}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_av_read_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVReadFrame = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVReadFrame"sv); + { + uint32_t PacketPtr = UINT32_C(520); + allocPacket(PacketPtr); + uint32_t PacketId = readUInt32(MemInst, PacketPtr); + EXPECT_TRUE(HostFuncAVReadFrame.run( + CallFrame, + std::initializer_list{FormatCtxId, PacketId}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_network_init"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatNetworkInit = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatNetworkInit &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatNetworkInit"sv); + { + EXPECT_TRUE(HostFuncAVFormatNetworkInit.run( + CallFrame, std::initializer_list{}, Result)); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_seek_file"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatSeekFile = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatSeekFile &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatSeekFile"sv); + { + uint32_t StreamIdx = -1; + int64_t MinTs = -10; + int64_t Ts = 0; + int64_t MaxTs = 10; + int32_t Flags = 0; + + // Try a network Fetch. + EXPECT_TRUE(HostFuncAVFormatSeekFile.run( + CallFrame, + std::initializer_list{FormatCtxId, StreamIdx, + MinTs, Ts, MaxTs, Flags}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_av_read_play"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatAVReadPlay = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVReadPlay"sv); + { + // Try a network Fetch. + EXPECT_TRUE(HostFuncAVFormatAVReadPlay.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_TRUE(Result[0].get() < 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_av_read_pause"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatAVReadPause = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVReadPause"sv); + { + // Try a network Fetch. + EXPECT_TRUE(HostFuncAVFormatAVReadPause.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_TRUE(Result[0].get() < 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_network_deinit"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatNetworkDeInit = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatNetworkDeInit &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatNetworkDeInit"sv); + { + EXPECT_TRUE(HostFuncAVFormatNetworkDeInit.run( + CallFrame, std::initializer_list{}, Result)); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_close_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatCloseInput = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatCloseInput &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatCloseInput"sv); + { + EXPECT_TRUE(HostFuncAVFormatCloseInput.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_free_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFreeContext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatFreeContext &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatFreeContext"sv); + { + EXPECT_TRUE(HostFuncAVFreeContext.run( + CallFrame, std::initializer_list{FormatCtxId}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avformat_version"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatVersion = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatVersion"sv); + { + EXPECT_TRUE(HostFuncAVFormatVersion.run( + CallFrame, std::initializer_list{}, Result)); + + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_configuration_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatConfigurationLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatConfigurationLength &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatConfigurationLength"sv); + int32_t Length = 0; + { + EXPECT_TRUE(HostFuncAVFormatConfigurationLength.run( + CallFrame, std::initializer_list{}, Result)); + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_configuration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatConfiguration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatConfiguration &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatConfiguration"sv); + { + EXPECT_TRUE(HostFuncAVFormatConfiguration.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_license_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatLicenseLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatLicenseLength &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatLicenseLength"sv); + { + EXPECT_TRUE(HostFuncAVFormatLicenseLength.run( + CallFrame, std::initializer_list{}, Result)); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avformat_license"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatLicense = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatLicense"sv); + { + EXPECT_TRUE(HostFuncAVFormatLicense.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + + EXPECT_TRUE(Result[0].get() >= 0); + } +} + +TEST_F(FFmpegTest, AVOutputFormatFunc) { + + uint32_t FormatCtxPtr = UINT32_C(4); + uint32_t DictPtr = UINT32_C(16); + uint32_t ChapterPtr = UINT32_C(20); + uint32_t FramePtr = UINT32_C(24); + uint32_t KeyPtr = UINT32_C(100); + uint32_t ValuePtr = UINT32_C(200); + + initDict(DictPtr, KeyPtr, std::string("Key"), ValuePtr, std::string("Value")); + initEmptyFrame(FramePtr); + uint32_t DictId = readUInt32(MemInst, DictPtr); + uint32_t FrameId = readUInt32(MemInst, FramePtr); + + uint32_t FormatStart = 300; + uint32_t FormatLen = 3; + uint32_t FileStart = 350; + uint32_t FileLen = 8; + fillMemContent(MemInst, FormatStart, FormatLen + FileLen); + + fillMemContent(MemInst, FormatStart, std::string("mp4")); + fillMemContent(MemInst, FileStart, std::string("test.mp4")); + + auto *FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_alloc_output_context2"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatAllocOutputContext2 = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatAllocOutputContext2 &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatAllocOutputContext2"sv); + { + EXPECT_TRUE(HostFuncAVFormatAllocOutputContext2.run( + CallFrame, + std::initializer_list{ + FormatCtxPtr, 0, FormatStart, FormatLen, FileStart, FileLen}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + + EXPECT_TRUE(HostFuncAVFormatAllocOutputContext2.run( + CallFrame, + std::initializer_list{ + FormatCtxPtr, readUInt32(MemInst, FormatCtxPtr), FormatStart, + FormatLen, FileStart, FileLen}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + + EXPECT_TRUE(HostFuncAVFormatAllocOutputContext2.run( + CallFrame, + std::initializer_list{FormatCtxPtr, 0, 0, 0, + FileStart, FileLen}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avio_open"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOOpen = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVIOOpen"sv); + { + uint32_t AvFormatCtxId = readUInt32(MemInst, FormatCtxPtr); + EXPECT_TRUE( + HostFuncAVIOOpen.run(CallFrame, + std::initializer_list{ + AvFormatCtxId, FileStart, FileLen, 2}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avio_open2"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOOpen2 = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVIOOpen2"sv); + { + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + EXPECT_TRUE(HostFuncAVIOOpen2.run( + CallFrame, + std::initializer_list{FormatCtxId, FileStart, + FileLen, 2, 0, DictId}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + } + + // TODO: This test modifies the input file. Unable to test. + // Added test on rust side. + // spdlog::info("Testing AVGuessCodec"sv); + // uint32_t EmptyStrPtr = UINT32_C(520); + // writeUInt32(MemInst, 0, EmptyStrPtr); + // { + // uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + // int32_t MediaTypeId = 0; // Video + // EXPECT_TRUE(HostFuncAVGuessCodec.run( + // CallFrame, + // std::initializer_list{FormatCtxId, + // EmptyStrPtr, 0, + // FilePtr, 32, + // EmptyStrPtr, 0, + // MediaTypeId}, + // Result)); + // EXPECT_EQ(Result[0].get(), 1); // AV_CODEC_ID_MPEG1VIDEO: + // } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_write_header"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatWriteHeader = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatWriteHeader &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFormatWriteHeader"sv); + { + // Did not set AVParameters, etc. Hence Giving Invalid Argument Error. + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + EXPECT_TRUE(HostFuncAVFormatWriteHeader.run( + CallFrame, + std::initializer_list{FormatCtxId, DictId}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } + + // Write Header above return invalid argument due to which below test won't + // work. The OutputFormatContext should Be configured using the input format + // context. Test on the Rust side. This is working as expected. + + // FuncInst = AVFormatMod->findFuncExports( + // "wasmedge_ffmpeg_avformat_avformat_write_trailer"); + // EXPECT_NE(FuncInst, nullptr); + // EXPECT_TRUE(FuncInst->isHostFunction()); + // auto &HostFuncAVFormatTrailer = dynamic_cast< + // WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatWriteTrailer &>( + // FuncInst->getHostFunc()); + // { + // // Did not set AVParameters, etc. Hence Giving Invalid Argument Error. + // uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + // EXPECT_TRUE(HostFuncAVFormatTrailer.run( + // CallFrame, std::initializer_list{FormatCtxId}, + // Result)); + // EXPECT_EQ(Result[0].get(), -22); + // } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avchapter_mallocz"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVIOClose = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterMallocz &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVChapterMallocz"sv); + { + EXPECT_TRUE(HostFuncAVIOClose.run( + CallFrame, std::initializer_list{ChapterPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + // How to pass IntPtr + // + // FuncInst = AVFormatMod->findFuncExports( + // "wasmedge_ffmpeg_avformat_avchapter_dynarray_add"); + // EXPECT_NE(FuncInst, nullptr); + // EXPECT_TRUE(FuncInst->isHostFunction()); + // auto &HostFuncAVChapterDynarrayAdd = dynamic_cast< + // WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterDynarrayAdd &>( + // FuncInst->getHostFunc()); + // // For the give input file, nb_chapter is 0; + // { + // uint32_t AvChapterId = readUInt32(MemInst, AvFormatCtxPtr); + // uint32_t AvFormatCtxId = readUInt32(MemInst, AvFormatCtxPtr); + // EXPECT_TRUE(HostFuncAVChapterDynarrayAdd.run( + // CallFrame, + // std::initializer_list{AvFormatCtxId, + // UINT32_C(0), + // AvChapterId}, + // Result)); + // EXPECT_EQ(Result[0].get(), + // static_cast(ErrNo::Success)); + // } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_avformat_avfreep"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVFormatAVFreep = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVFreeP"sv); + { + uint32_t ChapterId = readUInt32(MemInst, ChapterPtr); + EXPECT_TRUE(HostFuncAVFormatAVFreep.run( + CallFrame, std::initializer_list{ChapterId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_av_write_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVWriteFrame = + dynamic_cast( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVWriteFrame"sv); + // Passing Empty Frame, Hence giving Invalid Argument Error. + { + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + EXPECT_TRUE(HostFuncAVWriteFrame.run( + CallFrame, + std::initializer_list{FormatCtxId, FrameId}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_av_interleaved_write_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVInterleavedWriteFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVInterleavedWriteFrame &>( + FuncInst->getHostFunc()); + + spdlog::info("Testing AVInterleavedWriteFrame"sv); + // Passing Empty Frame, Hence giving Invalid Argument Error. + { + uint32_t FormatCtxId = readUInt32(MemInst, FormatCtxPtr); + EXPECT_TRUE(HostFuncAVInterleavedWriteFrame.run( + CallFrame, + std::initializer_list{FormatCtxId, FrameId}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp new file mode 100644 index 00000000..cb4ad790 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp @@ -0,0 +1,151 @@ +#include "avutil/avDictionary.h" +#include "avutil/module.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVDictionary) { + + uint32_t KeyStart = UINT32_C(1); + uint32_t KeyLen = 3; + uint32_t ValueStart = UINT32_C(4); + uint32_t ValueLen = 5; + uint32_t PrevDictEntryIdx = 0; // The Fetch the next Key value Node using an + // index. Passing Index from Rust side. + int32_t Flags = 0; + uint32_t NullDictId = UINT32_C(0); + + uint32_t DictPtr = UINT32_C(80); + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_dict_set"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVDictSet = + dynamic_cast( + FuncInst->getHostFunc()); + + // Fill 0 in WasmMemory. + fillMemContent(MemInst, KeyStart, KeyLen + ValueLen); + fillMemContent(MemInst, KeyStart, std::string("KEY")); + fillMemContent(MemInst, ValueStart, std::string("VALUE")); + + // Storing the above Key and Value in dict and using these in below tests + // (dict_get) to fetch Key,values. + { + EXPECT_TRUE(HostFuncAVDictSet.run( + CallFrame, + std::initializer_list{ + DictPtr, KeyStart, KeyLen, ValueStart, ValueLen, Flags}, + Result)); + EXPECT_TRUE(Result[0].get() >= 0); + ASSERT_TRUE(readUInt32(MemInst, DictPtr) > 0); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_dict_copy"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVDictCopy = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t DestDictPtr = UINT32_C(80); + uint32_t SrcDictId = readUInt32(MemInst, DictPtr); + EXPECT_TRUE( + HostFuncAVDictCopy.run(CallFrame, + std::initializer_list{ + DestDictPtr, SrcDictId, Flags}, + Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_dict_get"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVDictGet = + dynamic_cast( + FuncInst->getHostFunc()); + + { + // Store the string length of Key and value in below Pointers. + uint32_t KeyLenPtr = UINT32_C(56); + uint32_t ValueLenPtr = UINT32_C(60); + uint32_t DictId = readUInt32(MemInst, DictPtr); + EXPECT_TRUE(HostFuncAVDictGet.run( + CallFrame, + std::initializer_list{DictId, KeyStart, KeyLen, + PrevDictEntryIdx, Flags, + KeyLenPtr, ValueLenPtr}, + Result)); + EXPECT_TRUE(Result[0].get() == 1); + EXPECT_EQ(readUInt32(MemInst, KeyLenPtr), KeyLen); + EXPECT_EQ(readUInt32(MemInst, ValueLenPtr), ValueLen); + + // Pass a Null Dict and testing. + EXPECT_TRUE(HostFuncAVDictGet.run( + CallFrame, + std::initializer_list{ + NullDictId, KeyStart, KeyLen, PrevDictEntryIdx, Flags, KeyLenPtr, + ValueLenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), + static_cast(ErrNo::InternalError)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_dict_get_key_value"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVDictGetKeyValue = + dynamic_cast( + FuncInst->getHostFunc()); + + { + // Store the string of Key and value in below Buffer Pointers. + uint32_t KeyBufPtr = UINT32_C(36); + uint32_t ValueBufPtr = UINT32_C(40); + uint32_t DictId = readUInt32(MemInst, DictPtr); + EXPECT_TRUE(HostFuncAVDictGetKeyValue.run( + CallFrame, + std::initializer_list{ + DictId, KeyStart, KeyLen, ValueBufPtr, ValueLen, KeyBufPtr, + UINT32_C(3), PrevDictEntryIdx, Flags}, + Result)); + EXPECT_EQ(Result[0].get(), 1); + // Verify String. Read String from MemInst + + // Pass a Null Dict and testing. + EXPECT_TRUE(HostFuncAVDictGetKeyValue.run( + CallFrame, + std::initializer_list{ + NullDictId, KeyStart, KeyLen, ValueBufPtr, ValueLen, KeyBufPtr, + UINT32_C(3), PrevDictEntryIdx, Flags}, + Result)); + EXPECT_EQ(Result[0].get(), + static_cast(ErrNo::InternalError)); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_dict_free"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVDictFree = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t DictId = readUInt32(MemInst, DictPtr); + EXPECT_TRUE(HostFuncAVDictFree.run( + CallFrame, std::initializer_list{DictId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp new file mode 100644 index 00000000..454d38a2 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp @@ -0,0 +1,65 @@ +#include "avutil/error.h" +#include "avutil/module.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVError) { + + ASSERT_TRUE(AVUtilMod != nullptr); + + int32_t ErrNum = 35; + + uint32_t ErrStartPtr = UINT32_C(100); + uint32_t ErrSize = 10; + fillMemContent(MemInst, ErrStartPtr, ErrSize); + + fillMemContent(MemInst, ErrStartPtr, std::string("Test Error")); + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_strerror"); + auto &HostFuncAVUtilAVStrError = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilAVStrError.run( + CallFrame, std::initializer_list{}, Result); + + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_AVERROR"); + auto &HostFuncAVUtilAVError = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilAVError.run( + CallFrame, std::initializer_list{ErrNum}, Result); + + EXPECT_EQ(Result[0].get(), + ErrNum * -1); // Returns Negative, convert to Positive + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_AVUNERROR"); + auto &HostFuncAVUtilAVUNError = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilAVUNError.run( + CallFrame, std::initializer_list{ErrNum}, Result); + + EXPECT_EQ(Result[0].get(), + ErrNum * -1); // Returns Negative, convert to Positive + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp new file mode 100644 index 00000000..3f4c6269 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp @@ -0,0 +1,781 @@ +#include "avutil/avFrame.h" +#include "avutil/module.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVFrame) { + + uint32_t AVFramePtr = UINT32_C(72); + uint32_t AVFrame2Ptr = UINT32_C(40); + uint32_t DictPtr = UINT32_C(36); + uint32_t NumPtr = UINT32_C(80); + uint32_t DenPtr = UINT32_C(84); + uint32_t BufPtr = UINT32_C(200); // TO store Frame Data; + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFFmpegStructs(UINT32_C(12), UINT32_C(24), UINT32_C(28), FileName, + UINT32_C(60), UINT32_C(64), UINT32_C(68), AVFramePtr); + + initFFmpegStructs(UINT32_C(100), UINT32_C(104), UINT32_C(108), FileName, + UINT32_C(112), UINT32_C(116), UINT32_C(120), AVFrame2Ptr); + + uint32_t AVFrameId = readUInt32(MemInst, AVFramePtr); + uint32_t AVFrame2Id = readUInt32(MemInst, AVFrame2Ptr); + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_alloc"); + auto &HostFuncAVFrameAlloc = + dynamic_cast( + FuncInst->getHostFunc()); + + uint32_t EmptyFramePtr = UINT32_C(64); + + { + HostFuncAVFrameAlloc.run( + CallFrame, std::initializer_list{EmptyFramePtr}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(readUInt32(MemInst, EmptyFramePtr) > 0); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_free"); + auto &HostFuncAVFrameFree = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t EmptyFrameId = readUInt32(MemInst, EmptyFramePtr); + HostFuncAVFrameFree.run( + CallFrame, std::initializer_list{EmptyFrameId}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_width"); + auto &HostFuncAVFrameWidth = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameWidth.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + + EXPECT_EQ(Result[0].get(), 1920); // Width + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_height"); + auto &HostFuncAVFrameHeight = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameHeight.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + + EXPECT_EQ(Result[0].get(), 1080); // Height + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_video_format"); + auto &HostFuncAVFrameVideoFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameVideoFormat &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameVideoFormat.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + + EXPECT_EQ(Result[0].get(), 1); // Video Format + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_isnull"); + auto &HostFuncAVFrameIsNull = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameIsNull.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_linesize"); + auto &HostFuncAVFrameLinesize = + dynamic_cast( + FuncInst->getHostFunc()); + + int32_t Stride = 0; + uint32_t Idx = 0; + { + HostFuncAVFrameLinesize.run( + CallFrame, std::initializer_list{AVFrameId, Idx}, + Result); + + Stride = Result[0].get(); + EXPECT_EQ(Stride, 1920); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_get_buffer"); + auto &HostFuncAVFrameGetBuffer = + dynamic_cast( + FuncInst->getHostFunc()); + { + // For video, it is 32. + int32_t Align = 32; + HostFuncAVFrameGetBuffer.run( + CallFrame, + std::initializer_list{AVFrameId, Align}, Result); + + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_best_effort_timestamp"); + auto &HostFuncAVFrameBestEffortTimestamp = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameBestEffortTimestamp &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameBestEffortTimestamp.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_pict_type"); + auto &HostFuncAVFramePictType = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFramePictType.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_interlaced_frame"); + auto &HostFuncAVFrameInterlacedFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameInterlacedFrame &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameInterlacedFrame.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_TRUE(Result[0].get() == 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_top_field_first"); + auto &HostFuncAVFrameTopFieldFirst = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameTopFieldFirst &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameTopFieldFirst.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_TRUE(Result[0].get() == 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_palette_has_changed"); + auto &HostFuncAVFramePaletteHasChanged = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFramePaletteHasChanged &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFramePaletteHasChanged.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_TRUE(Result[0].get() == 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_colorspace"); + auto &HostFuncAVFrameColorspace = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameColorspace.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 2); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_color_range"); + auto &HostFuncAVFrameColorRange = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameColorRange.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_color_trc"); + auto &HostAVFrameColorTransferCharacteristic = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameColorTransferCharacteristic + &>(FuncInst->getHostFunc()); + + { + HostAVFrameColorTransferCharacteristic.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 2); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_chroma_location"); + auto &HostAVFrameChromaLocation = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameChromaLocation &>( + FuncInst->getHostFunc()); + + { + HostAVFrameChromaLocation.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_coded_picture_number"); + auto &HostAVFrameCodedPictureNumber = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameCodedPictureNumber &>( + FuncInst->getHostFunc()); + + { + HostAVFrameCodedPictureNumber.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_display_picture_number"); + auto &HostAVFrameDisplayPictureNumber = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameDisplayPictureNumber &>( + FuncInst->getHostFunc()); + + { + HostAVFrameDisplayPictureNumber.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_repeat_pict"); + auto &HostAVFrameRepeatPict = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameRepeatPict.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_flags"); + auto &HostAVFrameFlags = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameFlags.run(CallFrame, + std::initializer_list{AVFrameId}, + Result); + EXPECT_TRUE(Result[0].get() != 1 << 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_quality"); + auto &HostAVFrameQuality = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameQuality.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_metadata"); + auto &HostAVFrameMetadata = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameMetadata.run( + CallFrame, + std::initializer_list{AVFrameId, DictPtr}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + uint32_t DictId = readUInt32(MemInst, DictPtr); + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_metadata"); + auto &HostAVFrameSetMetadata = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetMetadata &>( + FuncInst->getHostFunc()); + + { + HostAVFrameSetMetadata.run( + CallFrame, + std::initializer_list{AVFrameId, DictId}, Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_key_frame"); + auto &HostAVFrameKeyFrame = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameKeyFrame.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_pts"); + auto &HostAVFramePts = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFramePts.run(CallFrame, + std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_copy"); + auto &HostAVFrameCopy = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameCopy.run( + CallFrame, + std::initializer_list{AVFrame2Id, AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_copy_props"); + auto &HostAVFrameCopyProps = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostAVFrameCopyProps.run( + CallFrame, + std::initializer_list{AVFrame2Id, AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_set_width"); + auto &HostFuncAVFrameSetWidth = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t Width = 100; + HostFuncAVFrameSetWidth.run( + CallFrame, + std::initializer_list{AVFrameId, Width}, Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostFuncAVFrameWidth.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), Width); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_set_height"); + auto &HostFuncAVFrameSetHeight = + dynamic_cast( + FuncInst->getHostFunc()); + + int32_t Height = 100; + { + HostFuncAVFrameSetHeight.run( + CallFrame, + std::initializer_list{AVFrameId, Height}, Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostFuncAVFrameHeight.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), Height); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_data"); + auto &HostFuncAVFrameData = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t Size = 1; // Just reading One byte data for test. + fillMemContent(MemInst, BufPtr, Size); + HostFuncAVFrameData.run(CallFrame, + std::initializer_list{ + AVFrameId, BufPtr, Size, Idx}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_video_format"); + auto &HostFuncAVFrameSetVideoFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetVideoFormat &>( + FuncInst->getHostFunc()); + + { + uint32_t PixFormatId = 10; // GRAY8 + HostFuncAVFrameSetVideoFormat.run( + CallFrame, + std::initializer_list{AVFrameId, PixFormatId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostFuncAVFrameVideoFormat.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), PixFormatId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_pict_type"); + auto &HostFuncAVFrameSetPictType = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetPictType &>( + FuncInst->getHostFunc()); + + { + int32_t PictureId = 4; // AV_PICTURE_TYPE_S + HostFuncAVFrameSetPictType.run( + CallFrame, + std::initializer_list{AVFrameId, PictureId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostFuncAVFramePictType.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), PictureId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_colorspace"); + auto &HostFuncAVFrameSetColorSpace = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetColorSpace &>( + FuncInst->getHostFunc()); + + { + int32_t ColorSpaceId = 4; // FCC + HostFuncAVFrameSetColorSpace.run( + CallFrame, + std::initializer_list{AVFrameId, ColorSpaceId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostFuncAVFrameColorspace.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), ColorSpaceId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_color_range"); + auto &HostFuncAVFrameSetColorRange = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetColorRange &>( + FuncInst->getHostFunc()); + + { + int32_t ColorRangeId = 1; // MPEG + HostFuncAVFrameSetColorRange.run( + CallFrame, + std::initializer_list{AVFrameId, ColorRangeId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostFuncAVFrameColorRange.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), ColorRangeId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_color_trc"); + auto &HostFuncAVFrameSetColorTransferCharacteristic = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ColorTrcId = 5; // GAMMA28 + HostFuncAVFrameSetColorTransferCharacteristic.run( + CallFrame, + std::initializer_list{AVFrameId, ColorTrcId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostAVFrameColorTransferCharacteristic.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), ColorTrcId); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_set_pts"); + auto &HostFuncAVFrameSetPts = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int64_t Pts = 10; + HostFuncAVFrameSetPts.run( + CallFrame, std::initializer_list{AVFrameId, Pts}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + HostAVFramePts.run(CallFrame, + std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), Pts); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_sample_aspect_ratio"); + auto &HostFuncAVFrameSampleAspectRatio = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSampleAspectRatio &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameSampleAspectRatio.run( + CallFrame, + std::initializer_list{AVFrameId, NumPtr, DenPtr}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + int32_t ColorPrimariesId = 1; // BT709 + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_color_primaries"); + auto &HostFuncAVFrameSetColorPrimaries = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetColorPrimaries &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameSetColorPrimaries.run( + CallFrame, + std::initializer_list{AVFrameId, + ColorPrimariesId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_color_primaries"); + auto &HostFuncAVFrameColorPrimaries = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameColorPrimaries &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameColorPrimaries.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), ColorPrimariesId); + } + + // ========================================================================== + // AVFrame Audio Funcs. + // ========================================================================== + + // Setting the fields to Video Frame itself. + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_audio_format"); + auto &HostFuncAVFrameSetAudioFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetAudioFormat &>( + FuncInst->getHostFunc()); + + uint32_t SampleFormatId = 4; + { + HostFuncAVFrameSetAudioFormat.run( + CallFrame, + std::initializer_list{AVFrameId, SampleFormatId}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_audio_format"); + auto &HostFuncAVFrameAudioFormat = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameAudioFormat &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameAudioFormat.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), SampleFormatId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_nb_samples"); + auto &HostFuncAVFrameSetNbSamples = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetNbSamples &>( + FuncInst->getHostFunc()); + + int32_t NbSamples = 32; + { + HostFuncAVFrameSetNbSamples.run( + CallFrame, + std::initializer_list{AVFrameId, NbSamples}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_nb_samples"); + auto &HostFuncAVFrameNbSamples = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameNbSamples.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), NbSamples); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_sample_rate"); + auto &HostFuncAVFrameSetSampleRate = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetSampleRate &>( + FuncInst->getHostFunc()); + + int32_t SampleRate = 10; + { + HostFuncAVFrameSetSampleRate.run( + CallFrame, + std::initializer_list{AVFrameId, SampleRate}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_sample_rate"); + auto &HostFuncAVFrameSampleRate = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameSampleRate.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), SampleRate); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_channels"); + auto &HostFuncAVFrameSetChannels = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetChannels &>( + FuncInst->getHostFunc()); + + int32_t Channels = 3; + { + HostFuncAVFrameSetChannels.run( + CallFrame, + std::initializer_list{AVFrameId, Channels}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_channels"); + auto &HostFuncAVFrameChannels = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameChannels.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), Channels); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_set_channel_layout"); + auto &HostFuncAVFrameSetChannelLayout = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameSetChannelLayout &>( + FuncInst->getHostFunc()); + + uint64_t ChannelLayout = 1UL << 10; + { + HostFuncAVFrameSetChannelLayout.run( + CallFrame, + std::initializer_list{AVFrameId, ChannelLayout}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_frame_channel_layout"); + auto &HostFuncAVFrameChannelLayout = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameChannelLayout &>( + FuncInst->getHostFunc()); + + { + HostFuncAVFrameChannelLayout.run( + CallFrame, std::initializer_list{AVFrameId}, + Result); + EXPECT_EQ(Result[0].get(), ChannelLayout); + } + + // ========================================================================== + // AVFrame Audio Funcs. + // ========================================================================== +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp new file mode 100644 index 00000000..afef37c1 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp @@ -0,0 +1,244 @@ +#include "avutil/module.h" +#include "avutil/pixfmt.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVPixFmt) { + + uint32_t NamePtr = UINT32_C(4); + + auto *FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_avpixfmtdescriptor_nb_components"); + auto &HostFuncAVPixFmtDescriptorNbComponents = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AvPixFmtDescriptorNbComponents &>( + FuncInst->getHostFunc()); + + uint32_t PixFmtId = 3; // RGB24 + + { + HostFuncAVPixFmtDescriptorNbComponents.run( + CallFrame, std::initializer_list{PixFmtId}, + Result); + + EXPECT_EQ(Result[0].get(), PixFmtId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_avpixfmtdescriptor_log2_chromaw"); + auto &HostFuncAvPixFmtDescriptorLog2ChromaW = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AvPixFmtDescriptorLog2ChromaW &>( + FuncInst->getHostFunc()); + + { + HostFuncAvPixFmtDescriptorLog2ChromaW.run( + CallFrame, std::initializer_list{1}, Result); + + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_avpixfmtdescriptor_log2_chromah"); + auto &HostFuncAvPixFmtDescriptorLog2ChromaH = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AvPixFmtDescriptorLog2ChromaH &>( + FuncInst->getHostFunc()); + + { + HostFuncAvPixFmtDescriptorLog2ChromaH.run( + CallFrame, std::initializer_list{PixFmtId}, + Result); + + EXPECT_TRUE(Result[0].get() >= 0); + } + + int32_t Length = 0; + int32_t TransferCharacteristicId = 6; // (SMPTE170M) + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_color_transfer_name_length"); + auto &HostFuncAVColorTransferNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVColorTransferNameLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVColorTransferNameLength.run( + CallFrame, + std::initializer_list{TransferCharacteristicId}, + Result); + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill memory with zero. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_color_transfer_name"); + auto &HostFuncAVColorTransferName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVColorTransferName &>( + FuncInst->getHostFunc()); + + { + HostFuncAVColorTransferName.run( + CallFrame, + std::initializer_list{TransferCharacteristicId, + NamePtr, Length}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + int32_t ColorRangeId = 2; //; JPEG + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_color_range_name_length"); + auto &HostFuncAVColorRangeNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVColorRangeNameLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVColorRangeNameLength.run( + CallFrame, std::initializer_list{ColorRangeId}, + Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill memory with zero. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_color_range_name"); + auto &HostFuncAVColorRangeName = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVColorRangeName.run(CallFrame, + std::initializer_list{ + ColorRangeId, NamePtr, Length}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + int32_t ColorSpaceId = 1; // BT709 + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_color_space_name_length"); + auto &HostFuncAVColorSpaceNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVColorSpaceNameLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVColorSpaceNameLength.run( + CallFrame, std::initializer_list{ColorSpaceId}, + Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill memory with zero. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_color_space_name"); + auto &HostFuncAVColorSpaceName = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVColorSpaceName.run(CallFrame, + std::initializer_list{ + ColorSpaceId, NamePtr, Length}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + int32_t ColorPrimariesId = 1; // BT709 + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_color_primaries_name_length"); + auto &HostFuncAVColorPrimariesNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVColorPrimariesNameLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVColorPrimariesNameLength.run( + CallFrame, + std::initializer_list{ColorPrimariesId}, Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill memory with zero. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_color_primaries_name"); + auto &HostFuncAVColorPrimariesName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVColorPrimariesName &>( + FuncInst->getHostFunc()); + + { + HostFuncAVColorPrimariesName.run( + CallFrame, + std::initializer_list{ColorPrimariesId, NamePtr, + Length}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + PixFmtId = 1; // YUV420P + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_pix_format_name_length"); + auto &HostFuncAVPixFormatNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVPixelFormatNameLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVPixFormatNameLength.run( + CallFrame, std::initializer_list{PixFmtId}, + Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill memory with zero. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_pix_format_name"); + auto &HostFuncAVPixFormatName = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVPixFormatName.run( + CallFrame, + std::initializer_list{PixFmtId, NamePtr, Length}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_pix_format_mask"); + auto &HostFuncAVPixFormatMask = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t PixId = 3; // AV_PIX_FMT_RGB24: + HostFuncAVPixFormatMask.run( + CallFrame, std::initializer_list{PixId}, Result); + + EXPECT_EQ(Result[0].get(), + 2); // Verify Mask. Position of Pix in Enum. + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avRational.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avRational.cpp new file mode 100644 index 00000000..250c0fec --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avRational.cpp @@ -0,0 +1,313 @@ +#include "avutil/avRational.h" +#include "avutil/module.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVRational) { + + ASSERT_TRUE(AVUtilMod != nullptr); + + uint32_t NumPtr = UINT32_C(4); + uint32_t DenPtr = UINT32_C(8); + + // Addition Function + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_add_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVAddQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = 3; + int32_t ADen = 4; + int32_t BNum = -6; + int32_t BDen = 7; + EXPECT_TRUE(HostFuncAVAddQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, NumPtr, DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_EQ(readSInt32(MemInst, NumPtr), -3); + EXPECT_EQ(readSInt32(MemInst, DenPtr), 28); + } + + // Subtraction Function + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_sub_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVSubQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = -843; + int32_t ADen = 11; + int32_t BNum = 38; + int32_t BDen = 12; + + writeSInt32(MemInst, 0, NumPtr); // Setting value of pointer to 0. + writeSInt32(MemInst, 0, DenPtr); // Setting value of pointer to 0. + EXPECT_TRUE(HostFuncAVSubQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, NumPtr, DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_EQ(readSInt32(MemInst, NumPtr), -5267); + EXPECT_EQ(readSInt32(MemInst, DenPtr), 66); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_mul_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVMulQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = -6; + int32_t ADen = 7; + int32_t BNum = 3; + int32_t BDen = 4; + + writeSInt32(MemInst, 0, NumPtr); // Setting value of pointer to 0. + writeSInt32(MemInst, 0, DenPtr); // Setting value of pointer to 0. + EXPECT_TRUE(HostFuncAVMulQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, NumPtr, DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_EQ(readSInt32(MemInst, NumPtr), -9); + EXPECT_EQ(readSInt32(MemInst, DenPtr), 14); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_div_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVDivQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = -6; + int32_t ADen = 7; + int32_t BNum = 3; + int32_t BDen = 4; + + writeSInt32(MemInst, 0, NumPtr); // Setting value of pointer to 0. + writeSInt32(MemInst, 0, DenPtr); // Setting value of pointer to 0. + EXPECT_TRUE(HostFuncAVDivQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, NumPtr, DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_EQ(readSInt32(MemInst, NumPtr), -8); + EXPECT_EQ(readSInt32(MemInst, DenPtr), 7); + } + + // How to Pass a Double functions. + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_d2q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVD2Q = + dynamic_cast( + FuncInst->getHostFunc()); + + { + double D = 5; + int32_t Max = 10; + + writeSInt32(MemInst, 0, NumPtr); // Setting value of pointer to 0. + writeSInt32(MemInst, 0, DenPtr); // Setting value of pointer to 0. + + EXPECT_TRUE(HostFuncAVD2Q.run( + CallFrame, + std::initializer_list{D, Max, NumPtr, DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_EQ(readSInt32(MemInst, NumPtr), 5); + EXPECT_EQ(readSInt32(MemInst, DenPtr), 1); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_q2d"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVQ2d = + dynamic_cast( + FuncInst->getHostFunc()); + + { + // Convert Rational Number to Double. + int32_t ANum = 1; + int32_t ADen = 2; + + writeSInt32(MemInst, 0, NumPtr); // Setting value of pointer to 0. + writeSInt32(MemInst, 0, DenPtr); // Setting value of pointer to 0. + EXPECT_TRUE(HostFuncAVQ2d.run( + CallFrame, std::initializer_list{ANum, ADen}, + Result)); + EXPECT_EQ(Result[0].get(), 0.5); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_inv_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInvQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + // Inverse a Rational Number. + int32_t ANum = -3; + int32_t ADen = 4; + + writeSInt32(MemInst, 0, NumPtr); // Setting value of pointer to 0. + writeSInt32(MemInst, 0, DenPtr); // Setting value of pointer to 0. + EXPECT_TRUE(HostFuncInvQ.run( + CallFrame, + std::initializer_list{ANum, ADen, NumPtr, DenPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + + EXPECT_EQ(readSInt32(MemInst, NumPtr), 4); + EXPECT_EQ(readSInt32(MemInst, DenPtr), -3); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_q2intfloat"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVQ2IntFloat = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = 1; + int32_t ADen = 5; + + EXPECT_TRUE(HostFuncAVQ2IntFloat.run( + CallFrame, std::initializer_list{ANum, ADen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(1045220557)); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_nearer_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVNearerQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = 1; + int32_t ADen = 3; + int32_t BNum = 1; + int32_t BDen = 2; + int32_t CNum = -1; + int32_t CDen = 2; + + // B nearer to A + EXPECT_TRUE( + HostFuncAVNearerQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, CNum, CDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(1)); + + ANum = -1; + + // C nearer to A + EXPECT_TRUE( + HostFuncAVNearerQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, CNum, CDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(-1)); + + ANum = 0; + ADen = 0; + + // Both are at same distance + EXPECT_TRUE( + HostFuncAVNearerQ.run(CallFrame, + std::initializer_list{ + ANum, ADen, BNum, BDen, CNum, CDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(0)); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_cmp_q"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVCmpQ = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t ANum = 1; + int32_t ADen = 2; + int32_t BNum = 2; + int32_t BDen = 1; + // A < B + EXPECT_TRUE(HostFuncAVCmpQ.run( + CallFrame, + std::initializer_list{ANum, ADen, BNum, BDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(-1)); + + ANum = 2; + ADen = 1; + BNum = 1; + BDen = 2; + // A > B + EXPECT_TRUE(HostFuncAVCmpQ.run( + CallFrame, + std::initializer_list{ANum, ADen, BNum, BDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(1)); + + ANum = 2; + ADen = 1; + BNum = 2; + BDen = 1; + // A == B + EXPECT_TRUE(HostFuncAVCmpQ.run( + CallFrame, + std::initializer_list{ANum, ADen, BNum, BDen}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(0)); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_reduce"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVReduce = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int64_t ANum = 1; + int64_t ADen = 2; + int64_t Max = 3; + EXPECT_TRUE( + HostFuncAVReduce.run(CallFrame, + std::initializer_list{ + NumPtr, DenPtr, ANum, ADen, Max}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(1)); + } +} +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp new file mode 100644 index 00000000..c4f13e04 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp @@ -0,0 +1,201 @@ +#include "avutil/module.h" +#include "avutil/samplefmt.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVSampleFmt) { + + ASSERT_TRUE(AVUtilMod != nullptr); + + uint32_t BufferPtr = UINT32_C(160); + uint32_t NamePtr = UINT32_C(80); + uint32_t LinesizePtr = UINT32_C(20); + + uint32_t SampleFmtId = 1; // AV_SAMPLE_FMT_S32 + auto *FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_packed_sample_fmt"); + auto &HostFuncAVGetPackedSampleFmt = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetPackedSampleFmt &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetPackedSampleFmt.run( + CallFrame, std::initializer_list{SampleFmtId}, + Result); + + EXPECT_EQ(Result[0].get(), SampleFmtId); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_planar_sample_fmt"); + auto &HostFuncAVGetPlanarSampleFmt = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetPlanarSampleFmt &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetPlanarSampleFmt.run( + CallFrame, std::initializer_list{SampleFmtId}, + Result); + + EXPECT_EQ(Result[0].get(), 6); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_sample_fmt_is_planar"); + auto &HostFuncAVSampleFmtIsPlanar = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVSampleFmtIsPlanar &>( + FuncInst->getHostFunc()); + + { + HostFuncAVSampleFmtIsPlanar.run( + CallFrame, std::initializer_list{SampleFmtId}, + Result); + + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_bytes_per_sample"); + auto &HostFuncAVGetBytesPerSample = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetBytesPerSample &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetBytesPerSample.run( + CallFrame, std::initializer_list{SampleFmtId}, + Result); + + EXPECT_TRUE(Result[0].get() >= 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_get_sample_fmt"); + auto &HostFuncAVGetSampleFmt = + dynamic_cast( + FuncInst->getHostFunc()); + + uint32_t SampleFmtStart = 100; + uint32_t SampleFmtSize = 2; + fillMemContent(MemInst, SampleFmtSize, SampleFmtSize); + + fillMemContent(MemInst, SampleFmtStart, std::string("u8")); + { + HostFuncAVGetSampleFmt.run(CallFrame, + std::initializer_list{ + SampleFmtStart, SampleFmtSize}, + Result); + + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_samples_get_buffer_size"); + auto &HostFuncAVSamplesGetBufferSize = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVSamplesGetBufferSize &>( + FuncInst->getHostFunc()); + + int32_t NbChannels = 1; + int32_t NbSamples = 5; + int32_t Align = 1; + int32_t BufSize = 0; + { + HostFuncAVSamplesGetBufferSize.run( + CallFrame, + std::initializer_list{NbChannels, NbSamples, + SampleFmtId, Align}, + Result); + + BufSize = Result[0].get(); + EXPECT_TRUE(BufSize); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_samples_alloc_array_and_samples"); + auto &HostFuncAVSamplesAllocArrayAndSamples = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVSamplesAllocArrayAndSamples &>( + FuncInst->getHostFunc()); + + { + HostFuncAVSamplesAllocArrayAndSamples.run( + CallFrame, + std::initializer_list{ + BufferPtr, LinesizePtr, NbChannels, NbSamples, SampleFmtId, Align}, + Result); + + EXPECT_TRUE(Result[0].get() >= 0); + } + + uint32_t BufId = readUInt32(MemInst, BufferPtr); + ASSERT_TRUE(BufId > 0); + + int32_t Length = 0; + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_sample_fmt_name_length"); + auto &HostFuncAVGetSampleFmtNameLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetSampleFmtNameLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetSampleFmtNameLength.run( + CallFrame, std::initializer_list{SampleFmtId}, + Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_sample_fmt_name"); + auto &HostFuncAVGetSampleFmtName = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetSampleFmtName &>( + FuncInst->getHostFunc()); + + // Fill Memory with 0. + fillMemContent(MemInst, NamePtr, Length); + { + HostFuncAVGetSampleFmtName.run(CallFrame, + std::initializer_list{ + SampleFmtId, NamePtr, Length}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_sample_fmt_mask"); + auto &HostFuncAVGetSampleFmtMask = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetSampleFmtMask &>( + FuncInst->getHostFunc()); + + { + uint32_t SampleId = 2; // AV_SAMPLE_FMT_S16; + HostFuncAVGetSampleFmtMask.run( + CallFrame, std::initializer_list{SampleId}, + Result); + + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_freep"); + auto &HostFuncAVFreep = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t BufferId = readUInt32(MemInst, BufferPtr); + HostFuncAVFreep.run(CallFrame, + std::initializer_list{BufferId}, + Result); + + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp new file mode 100644 index 00000000..c6431ece --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp @@ -0,0 +1,261 @@ +#include "avutil/avutil_func.h" +#include "avutil/avTime.h" +#include "avutil/module.h" +#include "utils.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, AVUtilFunc) { + + ASSERT_TRUE(AVUtilMod != nullptr); + + uint32_t NamePtr = UINT32_C(4); + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_log_set_level"); + auto &HostFuncAVLogSetLevel = + dynamic_cast( + FuncInst->getHostFunc()); + + int32_t LogLvlId = 32; + { + HostFuncAVLogSetLevel.run( + CallFrame, std::initializer_list{LogLvlId}, + Result); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_log_get_level"); + auto &HostFuncAVLogGetLevel = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVLogGetLevel.run( + CallFrame, std::initializer_list{}, Result); + EXPECT_EQ(Result[0].get(), LogLvlId); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_log_set_flags"); + auto &HostFuncAVLogSetFlags = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVLogSetFlags.run( + CallFrame, std::initializer_list{1}, Result); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_log_get_flags"); + auto &HostFuncAVLogGetFlags = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVLogGetFlags.run( + CallFrame, std::initializer_list{1}, Result); + + EXPECT_EQ(Result[0].get(), 32); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_rescale_q"); + auto &HostFuncAVRescaleQ = + dynamic_cast( + FuncInst->getHostFunc()); + + int64_t A = 20; + int32_t BNum = 5; + int32_t BDen = 10; + int32_t CNum = 5; + int32_t CDen = 20; + + { + HostFuncAVRescaleQ.run( + CallFrame, + std::initializer_list{A, BNum, BDen, CNum, CDen}, + Result); + + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_rescale_q_rnd"); + auto &HostFuncAVRescaleQRnd = + dynamic_cast( + FuncInst->getHostFunc()); + + { + int32_t RoundingId = 2; + HostFuncAVRescaleQRnd.run(CallFrame, + std::initializer_list{ + A, BNum, BDen, CNum, CDen, RoundingId}, + Result); + + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_avutil_version"); + auto &HostFuncAVUtilVersion = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilVersion.run( + CallFrame, std::initializer_list{}, Result); + + EXPECT_TRUE(Result[0].get() > 0); + } + + uint64_t ChannelId = 1; // FRONT_LEFT + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_channel_layout_nb_channels"); + auto &HostFuncAVGetChannelLayoutNbChannels = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetChannelLayoutNbChannels &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetChannelLayoutNbChannels.run( + CallFrame, std::initializer_list{ChannelId}, + Result); + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_get_default_channel_layout"); + auto &HostFuncAVGetDefaultChannelLayout = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetDefaultChannelLayout &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetDefaultChannelLayout.run( + CallFrame, std::initializer_list{ChannelId}, + Result); + EXPECT_TRUE(Result[0].get() > 0); + } + + uint32_t Length = 0; + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_avutil_configuration_length"); + auto &HostFuncAVUtilConfigurationLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVUtilConfigurationLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilConfigurationLength.run( + CallFrame, std::initializer_list{}, Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill NamePtr with 0. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_avutil_configuration"); + auto &HostFuncAVUtilConfiguration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVUtilConfiguration &>( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilConfiguration.run( + CallFrame, std::initializer_list{NamePtr, Length}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_avutil_license_length"); + auto &HostFuncAVUtilLicenseLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVUtilLicenseLength &>( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilLicenseLength.run( + CallFrame, std::initializer_list{}, Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill NamePtr with 0. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_avutil_license"); + auto &HostFuncAVUtilLicense = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVUtilLicense.run( + CallFrame, std::initializer_list{NamePtr, Length}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +TEST_F(FFmpegTest, AVTime) { + + ASSERT_TRUE(AVUtilMod != nullptr); + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_gettime"); + auto &HostFuncAVGetTime = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVGetTime.run( + CallFrame, std::initializer_list{}, Result); + + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_gettime_relative"); + auto &HostFuncAVGetTimeRelative = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVGetTimeRelative.run( + CallFrame, std::initializer_list{}, Result); + + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = AVUtilMod->findFuncExports( + "wasmedge_ffmpeg_avutil_av_gettime_relative_is_monotonic"); + auto &HostFuncAVGetTimeRelativeIsMonotonic = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVGetTimeRelativeIsMonotonic &>( + FuncInst->getHostFunc()); + + { + HostFuncAVGetTimeRelativeIsMonotonic.run( + CallFrame, std::initializer_list{}, Result); + + EXPECT_EQ(Result[0].get(), 1); + } + + FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_usleep"); + auto &HostFuncAVUSleep = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncAVUSleep.run( + CallFrame, std::initializer_list{1000}, Result); + + EXPECT_EQ(Result[0].get(), 0); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/main.cpp b/test/plugins/wasmedge_ffmpeg/main.cpp new file mode 100644 index 00000000..852694a0 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/main.cpp @@ -0,0 +1,6 @@ +#include + +GTEST_API_ int main(int Argc, char **Argv) { + testing::InitGoogleTest(&Argc, Argv); + return RUN_ALL_TESTS(); +} diff --git a/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp b/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp new file mode 100644 index 00000000..c5879355 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp @@ -0,0 +1,255 @@ +#include "swresample/swresample_func.h" +#include "swresample/module.h" + +#include "utils.h" +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +TEST_F(FFmpegTest, SWResampleFunc) { + + ASSERT_TRUE(SWResampleMod != nullptr); + + uint32_t DictPtr = UINT32_C(4); + uint32_t SWResamplePtr = UINT32_C(8); + uint32_t FramePtr = UINT32_C(72); + uint32_t Frame2Ptr = UINT32_C(16); + uint32_t KeyPtr = UINT32_C(100); + uint32_t ValuePtr = UINT32_C(200); + + initDict(DictPtr, KeyPtr, std::string("Key"), ValuePtr, std::string("Value")); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFFmpegStructs(UINT32_C(20), UINT32_C(24), UINT32_C(28), FileName, + UINT32_C(60), UINT32_C(64), UINT32_C(68), FramePtr); + + uint32_t StrPtr = UINT32_C(76); + initEmptyFrame(Frame2Ptr); + + uint32_t DictId = readUInt32(MemInst, DictPtr); + uint32_t FrameId = readUInt32(MemInst, FramePtr); + uint32_t Frame2Id = readUInt32(MemInst, Frame2Ptr); + + auto *FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swresample_version"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSWResampleVersion = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWResampleVersion &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSWResampleVersion.run(CallFrame, {}, Result)); + ASSERT_TRUE(Result[0].get() > 0); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swr_alloc_set_opts"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrAllocSetOpts = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWRAllocSetOpts &>( + FuncInst->getHostFunc()); + + // Testing with Null Old SwrCtx. Hence 2nd argument is 0. + { + uint32_t SWRCtxId = 0; + uint64_t OutChLayoutId = 1 << 1; // Front Right + uint32_t OutSampleFmtId = 2; // AV_SAMPLE_FMT_S16 + int32_t OutSampleRate = 30; + uint64_t InChLayoutId = 1 << 2; // FRONT_CENTER + uint32_t InSampleFmtId = 3; // AV_SAMPLE_FMT_S32 + int32_t SampleRate = 40; + int32_t LogOffset = 1; + + EXPECT_TRUE(HostFuncSwrAllocSetOpts.run( + CallFrame, + std::initializer_list{ + SWResamplePtr, SWRCtxId, OutChLayoutId, OutSampleFmtId, + OutSampleRate, InChLayoutId, InSampleFmtId, SampleRate, LogOffset}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SWResamplePtr) > 0); + } + + // Test with Existing SwrCtx. + uint32_t SwrId = readUInt32(MemInst, SWResamplePtr); + { + uint32_t SWRCtxId = SwrId; + uint64_t OutChLayoutId = 1 << 1; // Front Right + uint32_t OutSampleFmtId = 2; // AV_SAMPLE_FMT_S16 + int32_t OutSampleRate = 30; + uint64_t InChLayoutId = 1 << 2; // FRONT_CENTER + uint32_t InSampleFmtId = 3; // AV_SAMPLE_FMT_S32 + int32_t SampleRate = 40; + int32_t LogOffset = 1; + EXPECT_TRUE(HostFuncSwrAllocSetOpts.run( + CallFrame, + std::initializer_list{ + SWResamplePtr, SWRCtxId, OutChLayoutId, OutSampleFmtId, + OutSampleRate, InChLayoutId, InSampleFmtId, SampleRate, LogOffset}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SWResamplePtr) > 0); + } + + FuncInst = + SWResampleMod->findFuncExports("wasmedge_ffmpeg_swresample_swr_free"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrFree = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSwrFree.run( + CallFrame, std::initializer_list{SwrId}, Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + SWResampleMod->findFuncExports("wasmedge_ffmpeg_swresample_swr_init"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrInit = + dynamic_cast( + FuncInst->getHostFunc()); + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrInit.run( + CallFrame, std::initializer_list{SwrId}, Result)); + ASSERT_TRUE(Result[0].get() >= 0); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_av_opt_set_dict"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncAVOptSetDict = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t EmptyDictId = 0; + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncAVOptSetDict.run( + CallFrame, + std::initializer_list{SwrId, EmptyDictId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncAVOptSetDict.run( + CallFrame, std::initializer_list{SwrId, DictId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swr_convert_frame"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrConvertFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWRConvertFrame &>( + FuncInst->getHostFunc()); + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrConvertFrame.run( + CallFrame, + std::initializer_list{SwrId, Frame2Id, FrameId}, + Result)); + ASSERT_TRUE(Result[0].get()); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swr_get_delay"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrGetDelay = + dynamic_cast( + FuncInst->getHostFunc()); + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrGetDelay.run( + CallFrame, std::initializer_list{SwrId, 1}, + Result)); + EXPECT_EQ(Result[0].get(), 0); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swresample_configuration_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrConfigLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWResampleConfigurationLength + &>(FuncInst->getHostFunc()); + + int32_t Length = 0; + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrConfigLength.run( + CallFrame, std::initializer_list{}, Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swresample_configuration"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrConfig = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWResampleConfiguration &>( + FuncInst->getHostFunc()); + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrConfig.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swresample_license_length"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrLicenseLen = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWResampleLicenseLength &>( + FuncInst->getHostFunc()); + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrLicenseLen.run( + CallFrame, std::initializer_list{}, Result)); + + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = SWResampleMod->findFuncExports( + "wasmedge_ffmpeg_swresample_swresample_license"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwrLicense = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::SWResampleLicense &>( + FuncInst->getHostFunc()); + + { + SwrId = readUInt32(MemInst, SWResamplePtr); + EXPECT_TRUE(HostFuncSwrLicense.run( + CallFrame, std::initializer_list{StrPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp b/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp new file mode 100644 index 00000000..cfe4b38c --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp @@ -0,0 +1,540 @@ +#include "swscale/swscale_func.h" +#include "swscale/module.h" + +#include "utils.h" +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +// ============================================================================ +// This test deals with funcs related to SwsContext +// ============================================================================ + +TEST_F(FFmpegTest, SwsContext) { + + ASSERT_TRUE(SWScaleMod != nullptr); + + auto *FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getContext"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetContext = + dynamic_cast( + FuncInst->getHostFunc()); + + uint32_t SWScalePtr = UINT32_C(4); + uint32_t SWCachedScalePtr = UINT32_C(8); + uint32_t FramePtr = UINT32_C(72); + uint32_t Frame2Ptr = UINT32_C(124); + + std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars + initFFmpegStructs(UINT32_C(12), UINT32_C(24), UINT32_C(28), FileName, + UINT32_C(60), UINT32_C(64), UINT32_C(68), FramePtr); + + initEmptyFrame(Frame2Ptr); + + uint32_t FrameId = readUInt32(MemInst, FramePtr); + uint32_t Frame2Id = readUInt32(MemInst, Frame2Ptr); + + uint32_t YUV420PId = 1; // YUV420P AVPixFormatId (From Bindings.h) + uint32_t RGB24Id = 3; // RGB24 AVPixFormatId (From Bindings.h) + uint32_t XVMCId = 174; // XVMC AVPixFormatId (From Bindings.h) + + uint32_t SrcWidth = 100; + uint32_t SrcHeight = 100; + uint32_t DestWidth = 200; + uint32_t DestHeight = 200; + int32_t Flags = 8; + uint32_t SrcFilterId = 0; + uint32_t DestFilterId = 0; + + // Allocating SWScale... + // Filter ID for source and destination is Null. + { + EXPECT_TRUE(HostFuncSwsGetContext.run( + CallFrame, + std::initializer_list{ + SWScalePtr, SrcWidth, SrcHeight, YUV420PId, DestWidth, DestHeight, + RGB24Id, Flags, SrcFilterId, DestFilterId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SWScalePtr) > 0); + } + + uint32_t SWSScaleId = readUInt32(MemInst, SWScalePtr); + ASSERT_TRUE(SWSScaleId > 0); + + // Checking correctness of function. Returns Invalid Argument Error. + FuncInst = SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_scale"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsScale = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE( + HostFuncSwsScale.run(CallFrame, + std::initializer_list{ + SWSScaleId, FrameId, 20, 40, Frame2Id}, + Result)); + EXPECT_EQ(Result[0].get(), -22); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_sws_getCachedContext"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetCachedContext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwsGetCachedContext &>( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSwsGetCachedContext.run( + CallFrame, + std::initializer_list{ + SWCachedScalePtr, SWSScaleId, SrcWidth, SrcHeight, YUV420PId, + DestWidth, DestHeight, RGB24Id, Flags, SrcFilterId, DestFilterId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SWCachedScalePtr) > 0); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_sws_isSupportedInput"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsIsSupportedInput = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwsIsSupportedInput &>( + FuncInst->getHostFunc()); + + { + // AV_PIX_FMT_RGB24 is supported Pixel Format + EXPECT_TRUE(HostFuncSwsIsSupportedInput.run( + CallFrame, std::initializer_list{RGB24Id}, + Result)); + ASSERT_TRUE(Result[0].get() > 0); + + // AV_PIX_FMT_XVMC is not supported Pixel Format + EXPECT_TRUE(HostFuncSwsIsSupportedInput.run( + CallFrame, std::initializer_list{XVMCId}, + Result)); + ASSERT_TRUE(Result[0].get() == 0); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_sws_isSupportedOutput"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsIsSupportedOutput = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwsIsSupportedOutput &>( + FuncInst->getHostFunc()); + + { + // AV_PIX_FMT_RGB24 is supported Pixel Format + EXPECT_TRUE(HostFuncSwsIsSupportedOutput.run( + CallFrame, std::initializer_list{RGB24Id}, + Result)); + ASSERT_TRUE(Result[0].get() > 0); + + // AV_PIX_FMT_XVMC is not supported Pixel Format + EXPECT_TRUE(HostFuncSwsIsSupportedOutput.run( + CallFrame, std::initializer_list{XVMCId}, + Result)); + ASSERT_TRUE(Result[0].get() == 0); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_sws_isSupportedEndiannessConversion"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsIsSupportedEndiannessConversion = + dynamic_cast( + FuncInst->getHostFunc()); + + { + // AV_PIX_FMT_XVMC is not supported Pixel Format for + EXPECT_TRUE(HostFuncSwsIsSupportedEndiannessConversion.run( + CallFrame, std::initializer_list{XVMCId}, + Result)); + ASSERT_TRUE(Result[0].get() == 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_freeContext"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsFreeContext = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSwsFreeContext.run( + CallFrame, std::initializer_list{SWSScaleId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + { + uint32_t InvalidDestWidth = -200; + uint32_t InvalidDestHeight = -200; + uint32_t SWScalePtrInvalid = UINT32_C(80); + EXPECT_TRUE(HostFuncSwsGetContext.run( + CallFrame, + std::initializer_list{ + SWScalePtrInvalid, SrcWidth, SrcHeight, YUV420PId, InvalidDestWidth, + InvalidDestHeight, RGB24Id, Flags, SrcFilterId, DestFilterId}, + Result)); + EXPECT_EQ(Result[0].get(), + static_cast(ErrNo::InternalError)); + ASSERT_TRUE(readUInt32(MemInst, SWScalePtrInvalid) == 0); + } +} + +// ============================================================================ +// This test deals with funcs related to SwsFilter. +// ============================================================================ + +TEST_F(FFmpegTest, SwsFilter) { + + ASSERT_TRUE(SWScaleMod != nullptr); + auto *FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_sws_getDefaultFilter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetDefaultFilter = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwsGetDefaultFilter &>( + FuncInst->getHostFunc()); + + uint32_t SwsFilterPtr = UINT32_C(40); + { + float LumaGBlur = 10.5; + float ChromaGBlur = 10.5; + float LumaSharpen = 10.5; + float ChromaSharpen = 10.5; + float ChromaHShift = 10.5; + float ChromaVShift = 10.5; + int32_t Verbose = 1; + + EXPECT_TRUE(HostFuncSwsGetDefaultFilter.run( + CallFrame, + std::initializer_list{ + SwsFilterPtr, LumaGBlur, ChromaGBlur, LumaSharpen, ChromaSharpen, + ChromaHShift, ChromaVShift, Verbose}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsFilterPtr) > 0); + } + + uint32_t FilterId = readUInt32(MemInst, SwsFilterPtr); + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getLumaH"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetLumaH = + dynamic_cast( + FuncInst->getHostFunc()); + + uint32_t SwsVectorPtr = UINT32_C(20); + { + EXPECT_TRUE(HostFuncSwsGetLumaH.run( + CallFrame, + std::initializer_list{FilterId, SwsVectorPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsVectorPtr) > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getLumaV"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetLumaV = + dynamic_cast( + FuncInst->getHostFunc()); + + { + writeUInt32(MemInst, UINT32_C(0), SwsVectorPtr); + EXPECT_TRUE(HostFuncSwsGetLumaV.run( + CallFrame, + std::initializer_list{FilterId, SwsVectorPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsVectorPtr) > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getChromaH"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetChromaH = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSwsGetChromaH.run( + CallFrame, + std::initializer_list{FilterId, SwsVectorPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsVectorPtr) > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getChromaV"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetChromaV = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSwsGetChromaV.run( + CallFrame, + std::initializer_list{FilterId, SwsVectorPtr}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsVectorPtr) > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_freeFilter"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsFreeFilter = + dynamic_cast( + FuncInst->getHostFunc()); + + { + EXPECT_TRUE(HostFuncSwsFreeFilter.run( + CallFrame, std::initializer_list{FilterId}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +// ============================================================================ +// This test deals with funcs related to SwsVector. +// ============================================================================ + +TEST_F(FFmpegTest, SwsVector) { + + ASSERT_TRUE(SWScaleMod != nullptr); + uint32_t SwsVectorPtr = UINT32_C(40); + uint32_t CoeffPtr = UINT32_C(100); + + auto *FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_allocVec"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsAllocVec = + dynamic_cast( + FuncInst->getHostFunc()); + + { + writeUInt32(MemInst, UINT32_C(0), SwsVectorPtr); + int32_t Length = 20; + EXPECT_TRUE(HostFuncSwsAllocVec.run( + CallFrame, + std::initializer_list{SwsVectorPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsVectorPtr) > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getGaussianVec"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetGaussianVec = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwsGetGaussianVec &>( + FuncInst->getHostFunc()); + + { + writeUInt32(MemInst, UINT32_C(0), SwsVectorPtr); + double Variance = 20.5; + double Quality = 4.3; + EXPECT_TRUE(HostFuncSwsGetGaussianVec.run( + CallFrame, + std::initializer_list{SwsVectorPtr, Variance, + Quality}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + ASSERT_TRUE(readUInt32(MemInst, SwsVectorPtr) > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_scaleVec"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsScaleVec = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t SwsVecId = readUInt32(MemInst, SwsVectorPtr); + double Scalar = 20.35; + EXPECT_TRUE(HostFuncSwsScaleVec.run( + CallFrame, + std::initializer_list{SwsVecId, Scalar}, Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_normalizeVec"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsNormalizeVec = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t SwsVecId = readUInt32(MemInst, SwsVectorPtr); + double Height = 4.3; + EXPECT_TRUE(HostFuncSwsNormalizeVec.run( + CallFrame, + std::initializer_list{SwsVecId, Height}, Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_sws_getCoeffVecLength"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetCoeffVecLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwsGetCoeffVecLength &>( + FuncInst->getHostFunc()); + + int Length = 0; + { + uint32_t SwsVecId = readUInt32(MemInst, SwsVectorPtr); + EXPECT_TRUE(HostFuncSwsGetCoeffVecLength.run( + CallFrame, std::initializer_list{SwsVecId}, + Result)); + Length = Result[0].get(); + ASSERT_TRUE(Length > 0); + } + + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_getCoeff"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsGetCoeff = + dynamic_cast( + FuncInst->getHostFunc()); + + fillMemContent(MemInst, CoeffPtr, Length); + { + uint32_t SwsVecId = readUInt32(MemInst, SwsVectorPtr); + EXPECT_TRUE(HostFuncSwsGetCoeff.run( + CallFrame, + std::initializer_list{SwsVecId, CoeffPtr, Length}, + Result)); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_sws_freeVec"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSwsFreeVec = + dynamic_cast( + FuncInst->getHostFunc()); + + { + uint32_t SwsVecId = readUInt32(MemInst, SwsVectorPtr); + EXPECT_TRUE(HostFuncSwsFreeVec.run( + CallFrame, std::initializer_list{SwsVecId}, + Result)); + } +} + +// ============================================================================ +// This test deals with funcs related to Version, Configuration and License +// ============================================================================ + +TEST_F(FFmpegTest, SWScaleVersion) { + + ASSERT_TRUE(SWScaleMod != nullptr); + + uint32_t Length = 0; + uint32_t NamePtr = UINT32_C(8); + + auto *FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_swscale_version"); + auto &HostFuncSwscaleVersion = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncSwscaleVersion.run( + CallFrame, std::initializer_list{}, Result); + + EXPECT_TRUE(Result[0].get() > 0); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_swscale_configuration_length"); + auto &HostFuncSwscaleConfigurationLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwscaleConfigurationLength &>( + FuncInst->getHostFunc()); + + { + HostFuncSwscaleConfigurationLength.run( + CallFrame, std::initializer_list{}, Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Testing Version, Configuration, License + // Fill NamePtr with 0. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_swscale_configuration"); + auto &HostFuncSwscaleConfiguration = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwscaleConfiguration &>( + FuncInst->getHostFunc()); + + { + HostFuncSwscaleConfiguration.run( + CallFrame, std::initializer_list{NamePtr, Length}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } + + FuncInst = SWScaleMod->findFuncExports( + "wasmedge_ffmpeg_swscale_swscale_license_length"); + auto &HostFuncSwscaleLicenseLength = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::SwscaleLicenseLength &>( + FuncInst->getHostFunc()); + + { + HostFuncSwscaleLicenseLength.run( + CallFrame, std::initializer_list{}, Result); + + Length = Result[0].get(); + EXPECT_TRUE(Length > 0); + } + + // Fill NamePtr with 0. + fillMemContent(MemInst, NamePtr, Length); + FuncInst = + SWScaleMod->findFuncExports("wasmedge_ffmpeg_swscale_swscale_license"); + auto &HostFuncSwscaleLicense = + dynamic_cast( + FuncInst->getHostFunc()); + + { + HostFuncSwscaleLicense.run( + CallFrame, std::initializer_list{NamePtr, Length}, + Result); + EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); + } +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/utils.cpp b/test/plugins/wasmedge_ffmpeg/utils.cpp new file mode 100644 index 00000000..cef6ba40 --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/utils.cpp @@ -0,0 +1,259 @@ +#include "utils.h" +#include "avcodec/avCodecContext.h" +#include "avcodec/avPacket.h" +#include "avcodec/avcodec_func.h" +#include "avformat/avStream.h" +#include "avformat/avformat_func.h" +#include "avutil/avDictionary.h" +#include "avutil/avFrame.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { + +void FFmpegTest::initEmptyFrame(uint32_t FramePtr) { + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_alloc"); + auto &HostFuncAVFrameAlloc = + dynamic_cast( + FuncInst->getHostFunc()); + HostFuncAVFrameAlloc.run( + CallFrame, std::initializer_list{FramePtr}, Result); +} + +void FFmpegTest::initFFmpegStructs(uint32_t AVCodecPtr, uint32_t AVFormatCtxPtr, + uint32_t FilePtr, std::string FileName, + uint32_t CodecParameterPtr, + uint32_t AVCodecCtxPtr, uint32_t PacketPtr, + uint32_t FramePtr) { + initFormatCtx(AVFormatCtxPtr, FilePtr, FileName); + + uint32_t AvFormatCtxId = readUInt32(MemInst, AVFormatCtxPtr); + + auto *FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_av_find_best_stream"); + auto &HostFuncAVFindBestStream = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFindBestStream &>( + FuncInst->getHostFunc()); + uint32_t MediaTypeId = 0; // Video + uint32_t WantedStream = -1; + uint32_t RelatedStream = -1; + uint32_t DecoderRetId = 0; + uint32_t Flags = 0; + HostFuncAVFindBestStream.run(CallFrame, + std::initializer_list{ + AvFormatCtxId, MediaTypeId, WantedStream, + RelatedStream, DecoderRetId, Flags}, + Result); + + uint32_t StreamIdx = Result[0].get(); + + FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avStream_codecpar"); + + auto &HostFuncAVStreamCodecpar = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVStreamCodecPar &>( + FuncInst->getHostFunc()); + + HostFuncAVStreamCodecpar.run(CallFrame, + std::initializer_list{ + AvFormatCtxId, StreamIdx, CodecParameterPtr}, + Result); + + uint32_t CodecParametersId = readUInt32(MemInst, CodecParameterPtr); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_alloc_context3"); + auto &HostFuncAVCodecAllocContext3 = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecAllocContext3 &>( + FuncInst->getHostFunc()); + + HostFuncAVCodecAllocContext3.run( + CallFrame, std::initializer_list{0, AVCodecCtxPtr}, + Result); + + uint32_t AVCodecCtxId = readUInt32(MemInst, AVCodecCtxPtr); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_parameters_to_context"); + auto &HostFuncAVCodecParametersToContext = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecParametersToContext &>( + FuncInst->getHostFunc()); + + HostFuncAVCodecParametersToContext.run( + CallFrame, + std::initializer_list{AVCodecCtxId, + CodecParametersId}, + Result); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodeccontext_codec_id"); + auto &HostFuncAVCodecContextCodecId = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxCodecID &>( + FuncInst->getHostFunc()); + + HostFuncAVCodecContextCodecId.run( + CallFrame, std::initializer_list{AVCodecCtxId}, + Result); + + uint32_t CodecId = Result[0].get(); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_find_decoder"); + auto &HostFuncAVCodecFindDecoder = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecFindDecoder &>( + FuncInst->getHostFunc()); + + HostFuncAVCodecFindDecoder.run( + CallFrame, + std::initializer_list{CodecId, AVCodecPtr}, Result); + + uint32_t AVCodecId = readUInt32(MemInst, AVCodecPtr); + + FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_avcodec_open2"); + auto &HostFuncAVCodecOpen2 = + dynamic_cast( + FuncInst->getHostFunc()); + + HostFuncAVCodecOpen2.run( + CallFrame, + std::initializer_list{AVCodecCtxId, AVCodecId, 0}, + Result); + + initEmptyFrame(FramePtr); + uint32_t FrameId = readUInt32(MemInst, FramePtr); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_receive_frame"); + auto &HostFuncAVCodecReceiveFrame = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecReceiveFrame &>( + FuncInst->getHostFunc()); + + FuncInst = + AVFormatMod->findFuncExports("wasmedge_ffmpeg_avformat_av_read_frame"); + auto &HostFuncAVReadFrame = + dynamic_cast( + FuncInst->getHostFunc()); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_avcodec_send_packet"); + auto &HostFuncAVCodecSendPacket = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecSendPacket &>( + FuncInst->getHostFunc()); + + FuncInst = AVCodecMod->findFuncExports( + "wasmedge_ffmpeg_avcodec_av_packet_stream_index"); + auto &HostFuncAVPacketStreamIndex = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVPacketStreamIndex &>( + FuncInst->getHostFunc()); + + while (true) { + + HostFuncAVCodecReceiveFrame.run( + CallFrame, + std::initializer_list{AVCodecCtxId, FrameId}, + Result); + + // Error returned by FFmpeg are negative. + int32_t Error = Result[0].get() * -1; + + if (Error == EAGAIN) { + while (true) { + + allocPacket(PacketPtr); + + uint32_t PackedId = readUInt32(MemInst, PacketPtr); + + while (true) { + HostFuncAVReadFrame.run(CallFrame, + std::initializer_list{ + AvFormatCtxId, PackedId}, + Result); + + int32_t Res = Result[0].get(); + if (Res == 0 || Res == AVERROR_EOF) { + break; + } + } + + HostFuncAVPacketStreamIndex.run( + CallFrame, + std::initializer_list{AVCodecCtxId, FrameId}, + Result); + + uint32_t PacketStreamIdx = Result[0].get(); + + if (PacketStreamIdx != StreamIdx) { + continue; + } + + HostFuncAVCodecSendPacket.run( + CallFrame, + std::initializer_list{AVCodecCtxId, PackedId}, + Result); + break; + } + } else { + break; + } + } +} + +void FFmpegTest::initFormatCtx(uint32_t AVFormatCtxPtr, uint32_t FilePtr, + std::string FileName) { + + int32_t Length = FileName.length(); + fillMemContent(MemInst, FilePtr, Length); + fillMemContent(MemInst, FilePtr, FileName); + + auto *FuncInst = AVFormatMod->findFuncExports( + "wasmedge_ffmpeg_avformat_avformat_open_input"); + auto &HostFuncAVFormatOpenInput = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVFormatOpenInput &>( + FuncInst->getHostFunc()); + HostFuncAVFormatOpenInput.run( + CallFrame, + std::initializer_list{ + AVFormatCtxPtr, FilePtr, Length, UINT32_C(0), UINT32_C(0)}, + Result); +} + +void FFmpegTest::initDict(uint32_t DictPtr, uint32_t KeyPtr, std::string Key, + uint32_t ValuePtr, std::string Value) { + + uint32_t KeyLen = Key.length(); + uint32_t ValueLen = Value.length(); + fillMemContent(MemInst, KeyPtr, KeyLen + ValueLen); + fillMemContent(MemInst, KeyPtr, Key); + fillMemContent(MemInst, ValuePtr, Value); + + auto *FuncInst = + AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_dict_set"); + auto &HostFuncAVDictSet = + dynamic_cast( + FuncInst->getHostFunc()); + + HostFuncAVDictSet.run(CallFrame, + std::initializer_list{ + DictPtr, KeyPtr, KeyLen, ValuePtr, ValueLen, 0}, + Result); +} + +void FFmpegTest::allocPacket(uint32_t PacketPtr) { + + auto *FuncInst = + AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_alloc"); + auto &HostFuncAVPacketAlloc = + dynamic_cast( + FuncInst->getHostFunc()); + + HostFuncAVPacketAlloc.run( + CallFrame, std::initializer_list{PacketPtr}, + Result); +} + +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/utils.h b/test/plugins/wasmedge_ffmpeg/utils.h new file mode 100644 index 00000000..9128bacd --- /dev/null +++ b/test/plugins/wasmedge_ffmpeg/utils.h @@ -0,0 +1,164 @@ +#pragma once +#include "avcodec/module.h" +#include "avfilter/module.h" +#include "avformat/module.h" +#include "avutil/module.h" +#include "common/types.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "swresample/module.h" +#include "swscale/module.h" +#include "gtest/gtest.h" + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeFFmpeg { +inline void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, + uint32_t Value, uint32_t &Ptr) { + uint32_t *BufPtr = MemInst->getPointer(Ptr); + *BufPtr = Value; +} + +inline void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, + uint32_t Offset, uint32_t Cnt, + uint8_t C = 0) noexcept { + std::fill_n(MemInst->getPointer(Offset), Cnt, C); +} + +inline void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, + uint32_t Offset, const std::string &Str) noexcept { + char *Buf = MemInst->getPointer(Offset); + std::copy_n(Str.c_str(), Str.length(), Buf); +} + +inline void writeSInt32(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, + int32_t Value, uint32_t &Ptr) { + int32_t *BufPtr = MemInst->getPointer(Ptr); + *BufPtr = Value; +} + +inline int32_t readSInt32(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, + uint32_t &Ptr) { + int32_t *BufPtr = MemInst->getPointer(Ptr); + return *BufPtr; +} + +inline uint32_t readUInt32(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, + uint32_t &Ptr) { + uint32_t *BufPtr = MemInst->getPointer(Ptr); + return *BufPtr; +} + +class FFmpegTest : public ::testing::Test { +public: + FFmpegTest() : Mod(""), CallFrame(nullptr, &Mod) { + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + MemInst = Mod.findMemoryExports("memory"); + + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_ffmpeg/" + "libwasmedgePluginWasmEdgeFFmpeg" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_ffmpeg"sv)) { + if (const auto *Module = + Plugin->findModule("wasmedge_ffmpeg_avformat"sv)) { + AVFormatMod = dynamic_cast( + Module->create().release()); + } + if (const auto *Module = Plugin->findModule("wasmedge_ffmpeg_avutil"sv)) { + AVUtilMod = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::WasmEdgeFFmpegAVUtilModule + *>(Module->create().release()); + } + if (const auto *Module = + Plugin->findModule("wasmedge_ffmpeg_swscale"sv)) { + SWScaleMod = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::WasmEdgeFFmpegSWScaleModule + *>(Module->create().release()); + } + if (const auto *Module = + Plugin->findModule("wasmedge_ffmpeg_avcodec"sv)) { + AVCodecMod = dynamic_cast< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::WasmEdgeFFmpegAVCodecModule + *>(Module->create().release()); + } + if (const auto *Module = + Plugin->findModule("wasmedge_ffmpeg_swresample"sv)) { + SWResampleMod = + dynamic_cast( + Module->create().release()); + } + if (const auto *Module = + Plugin->findModule("wasmedge_ffmpeg_avfilter"sv)) { + AVFilterMod = dynamic_cast( + Module->create().release()); + } + } + } + + ~FFmpegTest() override { + if (AVUtilMod) { + delete AVUtilMod; + } + if (AVCodecMod) { + delete AVCodecMod; + } + if (SWScaleMod) { + delete SWScaleMod; + } + if (SWResampleMod) { + delete SWResampleMod; + } + if (AVFormatMod) { + delete AVFormatMod; + } + if (AVFilterMod) { + delete AVFilterMod; + } + } + +protected: + void initEmptyFrame(uint32_t FramePtr); + + void initDict(uint32_t DictPtr, uint32_t KeyPtr, std::string Key, + uint32_t ValuePtr, std::string Value); + void initFFmpegStructs(uint32_t AVCodecPtr, uint32_t AVFormatCtxPtr, + uint32_t FilePtr, std::string FileName, + uint32_t CodecParameterPtr, uint32_t AVCodecCtxPtr, + uint32_t PacketPtr, uint32_t FramePtr); + + void initFormatCtx(uint32_t AVFormatCtxPtr, uint32_t FilePtr, + std::string FileName); + void allocPacket(uint32_t PacketPtr); + + // Result of Funcs to be stored here. + std::array Result = {UINT32_C(0)}; + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod; + WasmEdge::Runtime::Instance::MemoryInstance *MemInst; + WasmEdge::Runtime::CallingFrame CallFrame; + + // Wasm Modules. + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::WasmEdgeFFmpegAVFormatModule + *AVFormatMod = nullptr; + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::WasmEdgeFFmpegAVUtilModule + *AVUtilMod = nullptr; + WasmEdge::Host::WasmEdgeFFmpeg::SWResample::WasmEdgeFFmpegSWResampleModule + *SWResampleMod = nullptr; + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::WasmEdgeFFmpegSWScaleModule + *SWScaleMod = nullptr; + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::WasmEdgeFFmpegAVCodecModule + *AVCodecMod = nullptr; + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::WasmEdgeFFmpegAVFilterModule + *AVFilterMod = nullptr; +}; +} // namespace WasmEdgeFFmpeg +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/utils/ffmpeg/download-ffmpeg-sample-video.sh b/utils/ffmpeg/download-ffmpeg-sample-video.sh new file mode 100644 index 00000000..fabbf84a --- /dev/null +++ b/utils/ffmpeg/download-ffmpeg-sample-video.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +# The below video used is sourced from an ffmpeg-libav-tutorial repository. +# Source: https://github.com/leandromoreira/ffmpeg-libav-tutorial/blob/master/LICENSE. +TODIR=$1 +SAMPLE_VIDEO=https://raw.githubusercontent.com/Hrushi20/rust-ffmpeg/master/assets/bunny.mp4 +if [[ $# -eq 0 ]]; then + TODIR=. +fi +if [ ! -d $TODIR ]; then + mkdir $TODIR +fi +if [ ! -d $TODIR ]; then + mkdir $TODIR +fi + +if [ ! -f $TODIR/sample_video.mp4 ]; then + curl -sL $SAMPLE_VIDEO -o $TODIR/sample_video.mp4 + cp $TODIR/sample_video.mp4 $TODIR/dummy.mp4 # Dummy file to manipulate and run tests on file. +fi diff --git a/utils/ffmpeg/install-ffmpeg-v6.0.sh b/utils/ffmpeg/install-ffmpeg-v6.0.sh new file mode 100755 index 00000000..72ac1458 --- /dev/null +++ b/utils/ffmpeg/install-ffmpeg-v6.0.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash +rm -rf FFmpeg-n6.0 ffmpeg.zip +echo $(pwd) + +curl -sL https://github.com/FFmpeg/FFmpeg/archive/refs/tags/n6.0.zip -o ffmpeg.zip + +unzip ffmpeg.zip + +mkdir -p FFmpeg-n6.0/output +cd FFmpeg-n6.0 +./configure --prefix=$(pwd)/output --enable-gpl --enable-nonfree --enable-shared --disable-static +make && make install +cd .. \ No newline at end of file From c3bd75e214ec350d82c9181714ccd1776db13496 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Wed, 13 Mar 2024 01:01:48 +0800 Subject: [PATCH 255/623] [Docker] Add zlib dependencies Signed-off-by: Yi Huang --- utils/docker/Dockerfile.manylinux2014-build-plugins-deps | 3 ++- utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps index 45bcd857..3c7fc7aa 100644 --- a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps @@ -6,7 +6,8 @@ ENV MANPATH /opt/rh/devtoolset-11/root/usr/share/man${MANPATH:+:${MANPATH}} ENV INFOPATH /opt/rh/devtoolset-11/root/usr/share/info${INFOPATH:+:${INFOPATH}} ENV PKG_CONFIG_PATH /opt/rh/devtoolset-11/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -RUN cd && (yum check-update || true) && yum install -y cmake wget unzip +RUN cd && (yum check-update || true) && \ + yum install -y cmake wget unzip zlib-devel zlib-static COPY install-opencvmini.sh . ENV OPENCV_VERSION=4.8.0 diff --git a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps index 60b8a044..634fb8a0 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps @@ -6,7 +6,8 @@ ENV MANPATH /opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} ENV INFOPATH /opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -RUN cd && (yum check-update || true) && yum install -y wget unzip +RUN cd && (yum check-update || true) && \ + yum install -y wget unzip zlib-devel zlib-static COPY install-opencvmini.sh . ENV OPENCV_VERSION=4.8.0 From 8aa62da9e4249869f607a3eb9168ceebdfdf5b1b Mon Sep 17 00:00:00 2001 From: YiYing He Date: Tue, 12 Mar 2024 13:37:39 +0800 Subject: [PATCH 256/623] [Plugin] Fix option toggle for wasmedge_process plugin. Signed-off-by: YiYing He --- test/plugins/unittest/testplugin.cpp | 6 +++++- test/plugins/unittest/testplugin.h | 12 ++++++++++++ test/plugins/unittest/unittest_c.cpp | 8 +++++--- test/plugins/unittest/unittest_cpp.cpp | 16 +++++++++++++++- 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/test/plugins/unittest/testplugin.cpp b/test/plugins/unittest/testplugin.cpp index 565b77c8..7cc55013 100644 --- a/test/plugins/unittest/testplugin.cpp +++ b/test/plugins/unittest/testplugin.cpp @@ -19,12 +19,16 @@ PO::Option WasmEdgePluginTestEnv::CmdName(PO::Description("Test for input name."sv), PO::DefaultValue(std::string(""))); +PO::Option + WasmEdgePluginTestEnv::CmdOpt(PO::Description("Test for option."sv)); + namespace { void addOptions(const Plugin::Plugin::PluginDescriptor *, PO::ArgumentParser &Parser) noexcept { Parser.add_option("arg"sv, WasmEdgePluginTestEnv::CmdArgs) - .add_option("name"sv, WasmEdgePluginTestEnv::CmdName); + .add_option("name"sv, WasmEdgePluginTestEnv::CmdName) + .add_option("opt"sv, WasmEdgePluginTestEnv::CmdOpt); } Runtime::Instance::ModuleInstance * diff --git a/test/plugins/unittest/testplugin.h b/test/plugins/unittest/testplugin.h index d7373c9c..b6bdbd34 100644 --- a/test/plugins/unittest/testplugin.h +++ b/test/plugins/unittest/testplugin.h @@ -20,6 +20,7 @@ class WasmEdgePluginTestEnv { static PO::List CmdArgs; static PO::Option CmdName; + static PO::Option CmdOpt; }; template @@ -62,6 +63,16 @@ class WasmEdgePluginTestFuncArgLen } }; +class WasmEdgePluginTestFuncOpt + : public WasmEdgePluginTestFunc { +public: + WasmEdgePluginTestFuncOpt(WasmEdgePluginTestEnv &HostEnv) + : WasmEdgePluginTestFunc(HostEnv) {} + Expect body(const Runtime::CallingFrame &) { + return static_cast(Env.CmdOpt.value()); + } +}; + class WasmEdgePluginTestFuncNameSize : public WasmEdgePluginTestFunc { public: @@ -79,6 +90,7 @@ class WasmEdgePluginTestModule : public Runtime::Instance::ModuleInstance { addHostFunc("add", std::make_unique(Env)); addHostFunc("sub", std::make_unique(Env)); addHostFunc("arg_len", std::make_unique(Env)); + addHostFunc("opt", std::make_unique(Env)); addHostFunc("name_size", std::make_unique(Env)); } diff --git a/test/plugins/unittest/unittest_c.cpp b/test/plugins/unittest/unittest_c.cpp index 19f22993..a1ccb94d 100644 --- a/test/plugins/unittest/unittest_c.cpp +++ b/test/plugins/unittest/unittest_c.cpp @@ -152,9 +152,9 @@ TEST(wasmedgePluginTests, C_Module) { // Create the wasmedge_plugintest_cpp_module module instance. auto *ModInstCPP = createModuleCPP(); ASSERT_FALSE(ModInstCPP == nullptr); - EXPECT_EQ(WasmEdge_ModuleInstanceListFunctionLength(ModInstCPP), 4U); + EXPECT_EQ(WasmEdge_ModuleInstanceListFunctionLength(ModInstCPP), 5U); std::memset(NameBuf, 0, sizeof(WasmEdge_String) * 16); - EXPECT_EQ(WasmEdge_ModuleInstanceListFunction(ModInstCPP, NameBuf, 16), 4U); + EXPECT_EQ(WasmEdge_ModuleInstanceListFunction(ModInstCPP, NameBuf, 16), 5U); EXPECT_TRUE( WasmEdge_StringIsEqual(NameBuf[0], WasmEdge_StringWrap("add", 3U))); EXPECT_TRUE( @@ -162,7 +162,9 @@ TEST(wasmedgePluginTests, C_Module) { EXPECT_TRUE( WasmEdge_StringIsEqual(NameBuf[2], WasmEdge_StringWrap("name_size", 9U))); EXPECT_TRUE( - WasmEdge_StringIsEqual(NameBuf[3], WasmEdge_StringWrap("sub", 3U))); + WasmEdge_StringIsEqual(NameBuf[3], WasmEdge_StringWrap("opt", 3U))); + EXPECT_TRUE( + WasmEdge_StringIsEqual(NameBuf[4], WasmEdge_StringWrap("sub", 3U))); WasmEdge_ModuleInstanceDelete(ModInstCPP); } diff --git a/test/plugins/unittest/unittest_cpp.cpp b/test/plugins/unittest/unittest_cpp.cpp index 00efc45e..345bdbfd 100644 --- a/test/plugins/unittest/unittest_cpp.cpp +++ b/test/plugins/unittest/unittest_cpp.cpp @@ -40,6 +40,7 @@ WasmEdge::Runtime::Instance::ModuleInstance *createModuleCPP() { Parser.set_raw_value("name"sv, std::string("test_name")); Parser.set_raw_value>( "arg"sv, std::vector({"arg0", "arg1", "arg2", "arg3"})); + Parser.set_raw_value("opt"sv); if (const auto *Module = Plugin->findModule("wasmedge_plugintest_cpp_module"sv)) { return Module->create().release(); @@ -83,6 +84,18 @@ TEST(wasmedgePluginTests, CPP_Run) { EXPECT_TRUE(HostFuncInst2.run(CallFrame, {}, RetVal)); EXPECT_EQ(RetVal[0].get(), 9); + // Get the function "opt". + auto *FuncInst3 = TestModCPP->findFuncExports("opt"); + EXPECT_NE(FuncInst3, nullptr); + EXPECT_TRUE(FuncInst3->isHostFunction()); + auto &HostFuncInst3 = + dynamic_cast( + FuncInst3->getHostFunc()); + + // Test: Run function successfully. + EXPECT_TRUE(HostFuncInst3.run(CallFrame, {}, RetVal)); + EXPECT_EQ(RetVal[0].get(), 1); + delete TestModCPP; // Create the wasmedge_plugintest_c_module module instance. @@ -98,10 +111,11 @@ TEST(wasmedgePluginTests, CPP_Module) { auto *TestModCPP = dynamic_cast( createModuleCPP()); ASSERT_FALSE(TestModCPP == nullptr); - EXPECT_EQ(TestModCPP->getFuncExportNum(), 4U); + EXPECT_EQ(TestModCPP->getFuncExportNum(), 5U); EXPECT_NE(TestModCPP->findFuncExports("add"), nullptr); EXPECT_NE(TestModCPP->findFuncExports("sub"), nullptr); EXPECT_NE(TestModCPP->findFuncExports("arg_len"), nullptr); + EXPECT_NE(TestModCPP->findFuncExports("opt"), nullptr); EXPECT_NE(TestModCPP->findFuncExports("name_size"), nullptr); delete TestModCPP; From 1a73c120d638b529d5c0cee6988e930dfb272304 Mon Sep 17 00:00:00 2001 From: dm4 Date: Sat, 16 Mar 2024 00:28:04 +0800 Subject: [PATCH 257/623] [WASI-NN] ggml: set LlamaNInputs correctly when llava inference (#3286) Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index eee07797..f46ee6b6 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -771,7 +771,6 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (GraphRef.MMProjModelPath == ""sv) { // Text only prompt. CxtRef.LlamaInputs = llama_tokenize(LlamaContext, Prompt, AddBos, true); - CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); } else { // Handle llava format prompt. @@ -851,6 +850,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, EmbdInputAfterImage.begin(), EmbdInputAfterImage.end()); } + CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: set the input...Done"sv); } From 7ea453442089afb03d6aded253ab710fd877ad14 Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 18 Mar 2024 13:42:43 +0800 Subject: [PATCH 258/623] [WASI-NN] ggml backend: bump llama.cpp b2450 Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/ggml.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 86a48cef..c0cab2e7 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -49,7 +49,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2334 + GIT_TAG b2450 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index f46ee6b6..bf300f35 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -348,7 +348,7 @@ Expect getEmbedding(WasiNNEnvironment &Env, llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = GraphRef.CtxSize; ContextParams.n_batch = GraphRef.BatchSize; - ContextParams.embedding = GraphRef.Embedding; + ContextParams.embeddings = GraphRef.Embedding; auto *LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); From b19af402e305769be76eb1eae69c2690b63bf8ae Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 18 Mar 2024 16:48:12 +0800 Subject: [PATCH 259/623] [WASI-NN] ggml: downgrade to b2370 to fix the metal segfault issue (#3290) Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index c0cab2e7..2cf32885 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -49,7 +49,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2450 + GIT_TAG b2370 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From 313ac7264ca31aa3315b695c96e337b5df5ae6c2 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Fri, 15 Mar 2024 16:33:07 +0800 Subject: [PATCH 260/623] [Docker] Remove an accident Signed-off-by: Yi Huang --- ...ckerfile.manylinux2014_plugins_deps_x86_64 | 24 ------------------- 1 file changed, 24 deletions(-) delete mode 100644 utils/docker/Dockerfile.manylinux2014_plugins_deps_x86_64 diff --git a/utils/docker/Dockerfile.manylinux2014_plugins_deps_x86_64 b/utils/docker/Dockerfile.manylinux2014_plugins_deps_x86_64 deleted file mode 100644 index f3a89549..00000000 --- a/utils/docker/Dockerfile.manylinux2014_plugins_deps_x86_64 +++ /dev/null @@ -1,24 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC - -ARG BASE=wasmedge/wasmedge:manylinux2014_x86_64 -ARG BUILDPLATFORM=x86_64 -FROM --platform=$BUILDPLATFORM ${BASE} - -MAINTAINER hydai hydai@secondstate.io - -ADD install-opencvmini.sh /root/ - -ENV PATH=/opt/rh/devtoolset-11/root/usr/bin${PATH:+:${PATH}} -ENV MANPATH=/opt/rh/devtoolset-11/root/usr/share/man${MANPATH:+:${MANPATH}} -ENV INFOPATH=/opt/rh/devtoolset-11/root/usr/share/info${INFOPATH:+:${INFOPATH}} -ENV PKG_CONFIG_PATH=/opt/rh/devtoolset-11/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -ENV OPENCV_VERSION=4.8.0 - -WORKDIR /root/ - -RUN yum update -y \ - && yum install -y zlib-devel zlib-static cmake wget unzip \ - && bash /root/install-opencvmini.sh - -RUN yum clean all From 9e0c9dfb46afedbc0cb9849f2d38856c6f19bab3 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Fri, 15 Mar 2024 19:50:30 +0800 Subject: [PATCH 261/623] [CI] Move dependencies to docker image Signed-off-by: Yi Huang --- ...ockerfile.manylinux2014-build-plugins-deps | 26 ++++++++++++++++-- ...ckerfile.manylinux_2_28-build-plugins-deps | 27 +++++++++++++++++-- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps index 3c7fc7aa..4762b087 100644 --- a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps @@ -8,9 +8,31 @@ ENV PKG_CONFIG_PATH /opt/rh/devtoolset-11/root/usr/lib64/pkgconfig${PKG_CONFIG_P RUN cd && (yum check-update || true) && \ yum install -y cmake wget unzip zlib-devel zlib-static +RUN yum-config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo && \ + yum install -y gh -COPY install-opencvmini.sh . -ENV OPENCV_VERSION=4.8.0 +WORKDIR /root + +COPY docker/install-opencvmini.sh . +ENV OPENCV_VERSION "4.8.0" RUN [ "/bin/bash", "install-opencvmini.sh" ] +COPY wasi-nn/install-pytorch.sh . +ENV PYTORCH_VERSION "1.8.2" +ENV PYTORCH_INSTALL_TO "/root" +ENV Torch_DIR "/root/libtorch" +RUN [ "/bin/bash", "install-pytorch.sh", "--disable-cxx11-abi" ] + +COPY wasi-crypto/build-openssl.sh . +ENV OPENSSL_ROOT_DIR "/root/openssl-1.1.1n/openssl" +RUN [ "/bin/bash", "build-openssl.sh" ] + +COPY ffmpeg/install-ffmpeg-v6.0.sh . +RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] +ENV PKG_CONFIG_PATH /root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV LD_LIBRARY_PATH /root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} + +ENV OPENVINO_VERSION "2023.0.2" +ENV OPENVINO_YEAR "2023" + RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps index 634fb8a0..c8ed5dbf 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps @@ -8,9 +8,32 @@ ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_ RUN cd && (yum check-update || true) && \ yum install -y wget unzip zlib-devel zlib-static +RUN yum install -y yum-utils && \ + yum-config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo && \ + yum install -y gh -COPY install-opencvmini.sh . -ENV OPENCV_VERSION=4.8.0 +WORKDIR /root + +COPY docker/install-opencvmini.sh . +ENV OPENCV_VERSION "4.8.0" RUN [ "/bin/bash", "install-opencvmini.sh" ] +COPY wasi-nn/install-pytorch.sh . +ENV PYTORCH_VERSION "1.8.2" +ENV PYTORCH_INSTALL_TO "/root" +ENV Torch_DIR "/root/libtorch" +RUN [ "/bin/bash", "install-pytorch.sh", "--disable-cxx11-abi" ] + +COPY wasi-crypto/build-openssl.sh . +ENV OpenSSL_DIR "/root/openssl-1.1.1n/openssl" +RUN [ "/bin/bash", "build-openssl.sh" ] + +COPY ffmpeg/install-ffmpeg-v6.0.sh . +RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] +ENV PKG_CONFIG_PATH /root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV LD_LIBRARY_PATH /root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} + +ENV OPENVINO_VERSION "2023.0.2" +ENV OPENVINO_YEAR "2023" + RUN yum clean all From bba0f358ba999848c151a55f827a1c4893dabb9b Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 19 Mar 2024 08:30:44 +0800 Subject: [PATCH 262/623] [WASI-NN] ggml: downgrade to b2334 to fix the embedding issue Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/ggml.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 2cf32885..86a48cef 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -49,7 +49,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2370 + GIT_TAG b2334 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index bf300f35..f46ee6b6 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -348,7 +348,7 @@ Expect getEmbedding(WasiNNEnvironment &Env, llama_context_params ContextParams = llama_context_default_params(); ContextParams.n_ctx = GraphRef.CtxSize; ContextParams.n_batch = GraphRef.BatchSize; - ContextParams.embeddings = GraphRef.Embedding; + ContextParams.embedding = GraphRef.Embedding; auto *LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); From 7d5d2f6c146e87567267be0975490c0716891f69 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 19 Mar 2024 18:22:47 +0800 Subject: [PATCH 263/623] [WASI-NN] ggml: use new metal file Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 86a48cef..8696fd76 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -174,6 +174,7 @@ if(BACKEND STREQUAL "ggml") TARGET wasmedgePluginWasiNN POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_BINARY_DIR}/_deps/llama-src/ggml-metal.metal ggml-metal.metal + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_BINARY_DIR}/bin/default.metallib default.metallib ) endif() endif() From 8ab8b2599ea7eefbcb52600bc1b9c0c53260c6af Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 21 Mar 2024 10:44:14 +0800 Subject: [PATCH 264/623] [WASI-NN] rpc: implement `load_by_name_with_config` Signed-off-by: dm4 --- plugins/wasi_nn/wasinnfunc.cpp | 30 ++++++++++++++++++++++-------- test/plugins/wasi_nn/wasi_nn.cpp | 27 +++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 6ccf2b45..defaf6ae 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -157,14 +157,6 @@ WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, Expect WasiNNLoadByNameWithConfig::bodyImpl( const Runtime::CallingFrame &Frame, uint32_t NamePtr, uint32_t NameLen, uint32_t ConfigPtr, uint32_t ConfigLen, uint32_t GraphIdPtr) { -#ifdef WASMEDGE_BUILD_WASI_NN_RPC - if (Env.NNRPCChannel != nullptr) { - // TODO: implement RPC for LoadByNameWithConfig - spdlog::error( - "[WASI-NN] RPC client is not implemented for LoadByNameWithConfig"sv); - return WASINN::ErrNo::UnsupportedOperation; - } -#endif auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -192,6 +184,28 @@ Expect WasiNNLoadByNameWithConfig::bodyImpl( return WASINN::ErrNo::InvalidArgument; } +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::Graph::NewStub(Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::LoadByNameWithConfigRequest Req; + auto NameStrView = MemInst->getStringView(NamePtr, NameLen); + auto ConfigStrView = MemInst->getStringView(ConfigPtr, ConfigLen); + Req.set_name(NameStrView.data(), NameStrView.size()); + Req.set_config(ConfigStrView.data(), ConfigStrView.size()); + wasi_ephemeral_nn::LoadByNameWithConfigResult Res; + auto Status = Stub->LoadByNameWithConfig(&ClientContext, Req, &Res); + if (!Status.ok()) { + spdlog::error( + "[WASI-NN] Failed when calling remote LoadByNameWithConfig: {}"sv, + Status.error_message()); + return WASINN::ErrNo::RuntimeError; + } + *GraphId = Res.graph_handle(); + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + // Get the model std::string ModelName(reinterpret_cast(Name), NameLen); std::vector ModelConfig(reinterpret_cast(Config), diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 38195202..f5d5572d 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1525,6 +1525,13 @@ TEST(WasiNNTest, GGMLBackendWithRPC) { EXPECT_TRUE(FuncInst->isHostFunction()); auto &HostFuncLoadByName = dynamic_cast(FuncInst->getHostFunc()); + // Get the function "load_by_name_with_config". + FuncInst = NNMod->findFuncExports("load_by_name_with_config"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoadByNameWithConfig = + dynamic_cast( + FuncInst->getHostFunc()); // Get the function "init_execution_context". FuncInst = NNMod->findFuncExports("init_execution_context"); EXPECT_NE(FuncInst, nullptr); @@ -1565,6 +1572,26 @@ TEST(WasiNNTest, GGMLBackendWithRPC) { BuilderPtr += 4; } + // Test: load_by_name_with_config -- load successfully. + { + std::string Name = "default"; + std::string Config = "{}"; + std::vector NameVec(Name.begin(), Name.end()); + std::vector ConfigVec(Config.begin(), Config.end()); + uint32_t ConfigPtr = LoadEntryPtr + NameVec.size(); + writeBinaries(MemInst, NameVec, LoadEntryPtr); + writeBinaries(MemInst, ConfigVec, ConfigPtr); + EXPECT_TRUE(HostFuncLoadByNameWithConfig.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, static_cast(NameVec.size()), ConfigPtr, + static_cast(ConfigVec.size()), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + // GGML WASI-NN init_execution_context tests. // Test: init_execution_context -- graph id invalid. { From dd9cbde9bdf3d4f991f5d2b4292832add42a8dd4 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 21 Mar 2024 15:48:53 +0800 Subject: [PATCH 265/623] [WASI-NN] ggml: bump llama.cpp b2479 and fix embeddings Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/ggml.cpp | 136 ++++++++++++++++----------------- 2 files changed, 68 insertions(+), 70 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 8696fd76..dd80a29b 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -49,7 +49,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2334 + GIT_TAG b2479 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index f46ee6b6..5309f9aa 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -277,6 +277,7 @@ Expect setupContextParam(Graph &GraphRef, ContextParams.n_batch = GraphRef.BatchSize; ContextParams.n_threads = GraphRef.Threads; ContextParams.n_threads_batch = GraphRef.Threads; + ContextParams.embeddings = GraphRef.Embedding; return ErrNo::Success; } @@ -315,6 +316,48 @@ void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, Embedding = OS.str(); } +ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, + std::vector Tokens, int &NPast) noexcept { + uint32_t NCtx = llama_n_ctx(LlamaContext); + + // End the inference if the context is full. + if (NPast + static_cast(Tokens.size()) > NCtx) { + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: the context if full ({} / {} tokens). Please increase your context size."sv, + NPast + static_cast(Tokens.size()), NCtx); + } + return ErrNo::ContextFull; + } + + for (int I = 0; I < static_cast(Tokens.size()); + I += GraphRef.BatchSize) { + int NEval = static_cast(Tokens.size()) - I; + if (NEval > static_cast(GraphRef.BatchSize)) { + NEval = GraphRef.BatchSize; + } + // llama_batch_get_one(*token, n_tokens, position, sequence_id) + // This will return batch for single sequence of tokens starting at + // position. + const llama_seq_id SequenceId = 0; + auto Status = + llama_decode(LlamaContext, + llama_batch_get_one(&Tokens[I], NEval, NPast, SequenceId)); + if (Status == 1) { + spdlog::error( + "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); + return ErrNo::RuntimeError; + } else if (Status < 0) { + spdlog::error( + "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); + return ErrNo::RuntimeError; + } + NPast += NEval; + } + + return ErrNo::Success; +} + Expect getEmbedding(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); @@ -346,18 +389,30 @@ Expect getEmbedding(WasiNNEnvironment &Env, } // Initialize the llama context. llama_context_params ContextParams = llama_context_default_params(); - ContextParams.n_ctx = GraphRef.CtxSize; - ContextParams.n_batch = GraphRef.BatchSize; - ContextParams.embedding = GraphRef.Embedding; + setupContextParam(GraphRef, ContextParams); auto *LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); + // Prepare variables; + int32_t NPast = 0; // Get the context size. const uint64_t NCtx = llama_n_ctx(LlamaContext); // Minus 4 for the special tokens. (Such as , , ... tokens.) const uint64_t MaxTokensListSize = NCtx - 4; // Use the const sequence id here. const llama_seq_id SequenceId = 0; + // Return value. + auto ReturnCode = ErrNo::Success; + + // Add BOS if not present. + if (CxtRef.LlamaInputs.front() != llama_token_bos(GraphRef.LlamaModel)) { + CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.begin(), + llama_token_bos(GraphRef.LlamaModel)); + } + // Add EOS if not present. + if (CxtRef.LlamaInputs.back() != llama_token_eos(GraphRef.LlamaModel)) { + CxtRef.LlamaInputs.push_back(llama_token_eos(GraphRef.LlamaModel)); + } // Check if the input is too long. if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { @@ -369,32 +424,17 @@ Expect getEmbedding(WasiNNEnvironment &Env, return ErrNo::PromptTooLong; } - int NPast = 0; - while (!CxtRef.LlamaInputs.empty()) { - const uint64_t NTokens = (ContextParams.n_batch > CxtRef.LlamaInputs.size()) - ? CxtRef.LlamaInputs.size() - : ContextParams.n_batch; - auto Status = llama_decode(LlamaContext, - llama_batch_get_one(CxtRef.LlamaInputs.data(), - NTokens, NPast, SequenceId)); - if (Status == 1) { - spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: try " - "reducing the size of the batch or increasing the size of context"sv); - return ErrNo::RuntimeError; - } - if (Status < 0) { - spdlog::error("[WASI-NN] GGML backend: failed to llama_decode: internal " - "fatal error. Please open an issue on GitHub"sv); - return ErrNo::RuntimeError; - } - - NPast += NTokens; - CxtRef.LlamaInputs.erase(CxtRef.LlamaInputs.begin(), - CxtRef.LlamaInputs.begin() + NTokens); + // Evaluate input tokens. + ReturnCode = details::evaluateTokens(GraphRef, LlamaContext, + CxtRef.LlamaInputs, NPast); + if (ReturnCode != ErrNo::Success) { + spdlog::error("[WASI-NN] GGML backend: failed to evaluate input tokens."sv); + return ReturnCode; } + const int32_t NEmbd = llama_n_embd(GraphRef.LlamaModel); - const auto *Embeddings = llama_get_embeddings(LlamaContext); + auto *Embeddings = llama_get_embeddings_seq(LlamaContext, SequenceId); + llama_embd_normalize(Embeddings, Embeddings, NEmbd); details::buildOutputEmbedding(CxtRef.LlamaOutputs, NEmbd, Embeddings); @@ -418,48 +458,6 @@ Expect getEmbedding(WasiNNEnvironment &Env, return ErrNo::Success; } -ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, - std::vector Tokens, int &NPast) noexcept { - uint32_t NCtx = llama_n_ctx(LlamaContext); - - // End the inference if the context is full. - if (NPast + static_cast(Tokens.size()) > NCtx) { - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: the context if full ({} / {} tokens). Please increase your context size."sv, - NPast + static_cast(Tokens.size()), NCtx); - } - return ErrNo::ContextFull; - } - - for (int I = 0; I < static_cast(Tokens.size()); - I += GraphRef.BatchSize) { - int NEval = static_cast(Tokens.size()) - I; - if (NEval > static_cast(GraphRef.BatchSize)) { - NEval = GraphRef.BatchSize; - } - // llama_batch_get_one(*token, n_tokens, position, sequence_id) - // This will return batch for single sequence of tokens starting at - // position. - const llama_seq_id SequenceId = 0; - auto Status = - llama_decode(LlamaContext, - llama_batch_get_one(&Tokens[I], NEval, NPast, SequenceId)); - if (Status == 1) { - spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); - return ErrNo::RuntimeError; - } else if (Status < 0) { - spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); - return ErrNo::RuntimeError; - } - NPast += NEval; - } - - return ErrNo::Success; -} - const std::string_view Base64ImageTagPrefix = ""sv; From 8b07badbde3047aa44ba66e8cae453f40b45db64 Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 22 Mar 2024 10:09:27 +0800 Subject: [PATCH 266/623] [WASI-NN] ggml: clear inputs before setting it Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 5309f9aa..f796cdb2 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -766,6 +766,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, const bool AddBos = llama_should_add_bos_token(GraphRef.LlamaModel); std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); + CxtRef.LlamaInputs.clear(); if (GraphRef.MMProjModelPath == ""sv) { // Text only prompt. CxtRef.LlamaInputs = llama_tokenize(LlamaContext, Prompt, AddBos, true); From b029796c623a1e9dca11dab5bf39becd9d006b74 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 26 Mar 2024 16:17:05 +0800 Subject: [PATCH 267/623] [WASI-NN] ggml: set LLAMA_METAL_EMBED_LIBRARY=ON Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index dd80a29b..278c9153 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -38,6 +38,7 @@ if(BACKEND STREQUAL "ggml") if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_METAL") set(LLAMA_METAL ON) + set(LLAMA_METAL_EMBED_LIBRARY ON) else() message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_METAL") set(LLAMA_METAL OFF) @@ -174,7 +175,7 @@ if(BACKEND STREQUAL "ggml") TARGET wasmedgePluginWasiNN POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_BINARY_DIR}/_deps/llama-src/ggml-metal.metal ggml-metal.metal - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_BINARY_DIR}/bin/default.metallib default.metallib + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_BINARY_DIR}/_deps/llama-src/ggml-common.h ggml-common.h ) endif() endif() From ebadb0a99c3951ed62eeb1c70976ceabe2cf33e1 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 26 Mar 2024 16:36:38 +0800 Subject: [PATCH 268/623] [WASI-NN] ggml: bump llama.cpp b2534 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 278c9153..16b6e856 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -10,15 +10,15 @@ if(BACKEND STREQUAL "ggml") set(LLAMA_ACCELERATE OFF) if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_CUBLAS) - message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_CUBLAS") - set(LLAMA_CUBLAS ON) - # We need to set GGML_USE_CUBLAS for clip from llava. - add_compile_definitions(GGML_USE_CUBLAS) - # If CUBLAS is ON, then OpenBLAS should be OFF. + message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_CUDA") + set(LLAMA_CUDA ON) + # We need to set GGML_USE_CUDA for clip from llava. + add_compile_definitions(GGML_USE_CUDA) + # If CUDA is ON, then OpenBLAS should be OFF. set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS OFF) else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_CUBLAS") - set(LLAMA_CUBLAS OFF) + message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_CUDA") + set(LLAMA_CUDA OFF) endif() if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS) @@ -50,7 +50,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2479 + GIT_TAG b2534 PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched GIT_SHALLOW FALSE ) From 78ec5b73c1794fac2b77d45388657ea75552dd08 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Wed, 27 Mar 2024 12:43:22 +0800 Subject: [PATCH 269/623] [WASI-NN] Support windows build * Rewrite shellscripts to cmake * Add compile flags for llama * Fix compile warnings Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_nn/CMakeLists.txt | 15 ++- plugins/wasi_nn/ggml.cpp | 87 ++++++------ plugins/wasi_nn/wasinnenv.cpp | 4 + test/plugins/wasi_nn/CMakeLists.txt | 107 +++++++-------- test/plugins/wasi_nn/wasi_nn.cpp | 142 +++++++++++++------- utils/wasi-nn/download-ggml-fixtures.sh | 17 --- utils/wasi-nn/download-openvino-fixtures.sh | 25 ---- utils/wasi-nn/download-pytorch-fixtures.sh | 19 --- utils/wasi-nn/download-tflite-fixtures.sh | 19 --- 9 files changed, 207 insertions(+), 228 deletions(-) delete mode 100755 utils/wasi-nn/download-ggml-fixtures.sh delete mode 100755 utils/wasi-nn/download-openvino-fixtures.sh delete mode 100755 utils/wasi-nn/download-pytorch-fixtures.sh delete mode 100755 utils/wasi-nn/download-tflite-fixtures.sh diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 16b6e856..77b97789 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -46,12 +46,25 @@ if(BACKEND STREQUAL "ggml") # setup llama.cpp message(STATUS "Downloading llama.cpp source") + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options( + -Wno-cast-align + -Wno-cast-qual + -Wno-disabled-macro-expansion + -Wno-exceptions + -Wno-float-conversion + -Wno-implicit-fallthrough + -Wno-implicit-float-conversion + -Wno-unused-macros + ) + endif() include(FetchContent) FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git GIT_TAG b2534 - PATCH_COMMAND test -f ggml.patched || git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch && ${CMAKE_COMMAND} -E touch ggml.patched + PATCH_COMMAND git checkout . + COMMAND git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index f796cdb2..43afec1d 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -19,7 +19,7 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML -namespace details { +namespace { Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, bool *IsModelUpdated = nullptr) noexcept { simdjson::dom::parser Parser; @@ -264,10 +264,11 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } Expect setupGPTParam(Graph &GraphRef, gpt_params &GPTParams) { - GPTParams.sparams.temp = GraphRef.Temp; - GPTParams.sparams.top_p = GraphRef.TopP; - GPTParams.sparams.penalty_repeat = GraphRef.RepeatPenalty; - GPTParams.sparams.penalty_present = GraphRef.PresencePenalty; + GPTParams.sparams.temp = static_cast(GraphRef.Temp); + GPTParams.sparams.top_p = static_cast(GraphRef.TopP); + GPTParams.sparams.penalty_repeat = static_cast(GraphRef.RepeatPenalty); + GPTParams.sparams.penalty_present = + static_cast(GraphRef.PresencePenalty); return ErrNo::Success; } @@ -425,8 +426,8 @@ Expect getEmbedding(WasiNNEnvironment &Env, } // Evaluate input tokens. - ReturnCode = details::evaluateTokens(GraphRef, LlamaContext, - CxtRef.LlamaInputs, NPast); + ReturnCode = evaluateTokens(GraphRef, LlamaContext, + std::move(CxtRef.LlamaInputs), NPast); if (ReturnCode != ErrNo::Success) { spdlog::error("[WASI-NN] GGML backend: failed to evaluate input tokens."sv); return ReturnCode; @@ -436,7 +437,7 @@ Expect getEmbedding(WasiNNEnvironment &Env, auto *Embeddings = llama_get_embeddings_seq(LlamaContext, SequenceId); llama_embd_normalize(Embeddings, Embeddings, NEmbd); - details::buildOutputEmbedding(CxtRef.LlamaOutputs, NEmbd, Embeddings); + buildOutputEmbedding(CxtRef.LlamaOutputs, NEmbd, Embeddings); if (GraphRef.EnableDebugLog) { spdlog::info( @@ -463,7 +464,7 @@ const std::string_view Base64ImageBytesPrefix = ";base64,"sv; const std::string_view Base64ImageTagSuffix = "\">"sv; const std::string_view PromptImagePlaceholder = ""sv; -bool containsBase64Image(Graph &GraphRef, std::string Prompt) noexcept { +bool containsBase64Image(Graph &GraphRef, std::string_view Prompt) noexcept { // Check if the prompt contains a base64 image. // Follow this link for the supported image formats: // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h @@ -490,7 +491,7 @@ bool containsBase64Image(Graph &GraphRef, std::string Prompt) noexcept { struct llava_image_embed * loadBase64ImageFromPrompt(Graph &GraphRef, clip_ctx *ClipContext, - std::string Prompt) noexcept { + std::string_view Prompt) noexcept { // Load the base64 image from the prompt. // Follow this link for the supported image formats: // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h @@ -559,7 +560,7 @@ ErrNo replaceBase64ImagePlaceholderInPrompt(std::string &Prompt) noexcept { return ErrNo::Success; } -} // namespace details +} // namespace Expect load(WasiNNEnvironment &Env, Span> Builders, [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { @@ -595,7 +596,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, const std::string Metadata(reinterpret_cast(Builders[1].data()), Builders[1].size()); // Ignore context or model updates when initializing the graph. - auto Res = details::parseMetadata(GraphRef, Metadata); + auto Res = parseMetadata(GraphRef, Metadata); if (Res != ErrNo::Success) { spdlog::error("[WASI-NN] GGML backend: Failed to parse metadata."sv); Env.NNGraph.pop_back(); @@ -709,8 +710,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); - auto Res = - details::parseMetadata(GraphRef, Metadata, &IsModelParamsUpdated); + auto Res = parseMetadata(GraphRef, Metadata, &IsModelParamsUpdated); if (Res != ErrNo::Success) { spdlog::error("[WASI-NN] GGML backend: Failed to parse metadata."sv); @@ -752,7 +752,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, spdlog::info("[WASI-NN][Debug] GGML backend: init llama context"sv); } llama_context_params ContextParams = llama_context_default_params(); - details::setupContextParam(GraphRef, ContextParams); + setupContextParam(GraphRef, ContextParams); auto LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); if (GraphRef.EnableDebugLog) { @@ -774,7 +774,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Handle llava format prompt. // Check if the prompt contains a base64 image. - bool ContainsBase64Image = details::containsBase64Image(GraphRef, Prompt); + bool ContainsBase64Image = containsBase64Image(GraphRef, Prompt); if (GraphRef.ImagePath == ""sv && ContainsBase64Image == false) { spdlog::error( "[WASI-NN] GGML backend: Error: when using llava model, " @@ -804,9 +804,9 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (ContainsBase64Image) { // Load the base64 image from the prompt. CxtRef.LlavaImageEmbd = - details::loadBase64ImageFromPrompt(GraphRef, ClipContext, Prompt); + loadBase64ImageFromPrompt(GraphRef, ClipContext, Prompt); // Replace the base64 image in the prompt with a placeholder. - auto Res = details::replaceBase64ImagePlaceholderInPrompt(Prompt); + auto Res = replaceBase64ImagePlaceholderInPrompt(Prompt); if (Res != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: Error: unable to replace the base64 image in the prompt."sv); @@ -826,15 +826,15 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } // We split prompt by as placeholder and save the position. - auto PlaceholderPosition = Prompt.find(details::PromptImagePlaceholder); + auto PlaceholderPosition = Prompt.find(PromptImagePlaceholder); if (PlaceholderPosition == std::string::npos) { spdlog::error( "[WASI-NN] GGML backend: Error: unable to find the placeholder in the llava prompt."sv); return ErrNo::InvalidArgument; } std::string PromptBeforeImage = Prompt.substr(0, PlaceholderPosition); - std::string PromptAfterImage = Prompt.substr( - PlaceholderPosition + details::PromptImagePlaceholder.length()); + std::string PromptAfterImage = + Prompt.substr(PlaceholderPosition + PromptImagePlaceholder.length()); std::vector EmbdInputBeforeImage = llama_tokenize(LlamaContext, PromptBeforeImage, AddBos, true); std::vector EmbdInputAfterImage = @@ -878,7 +878,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, // Index 1 is for the metadata of the outputs. if (Index == 1) { std::string Metadata; - auto Res = details::buildOutputMetadata(CxtRef, Metadata); + auto Res = buildOutputMetadata(CxtRef, Metadata); if (Res != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: Failed to build output metadata."sv); @@ -903,7 +903,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } if (GraphRef.Embedding) { - return details::getEmbedding(Env, ContextId); + return getEmbedding(Env, ContextId); } if (CxtRef.LlamaInputs.size() == 0) { @@ -926,8 +926,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Initialize the llama context. gpt_params GPTParams; llama_context_params ContextParams = llama_context_default_params(); - details::setupGPTParam(GraphRef, GPTParams); - details::setupContextParam(GraphRef, ContextParams); + setupGPTParam(GraphRef, GPTParams); + setupContextParam(GraphRef, ContextParams); auto LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); struct llama_sampling_context *CtxSampling = @@ -955,8 +955,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Evaluate input tokens. if (CxtRef.LlavaImageEmbd == nullptr) { // Text only prompt. - ReturnCode = details::evaluateTokens(GraphRef, LlamaContext, - CxtRef.LlamaInputs, NPast); + ReturnCode = evaluateTokens(GraphRef, LlamaContext, + std::move(CxtRef.LlamaInputs), NPast); if (ReturnCode != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate input tokens."sv); @@ -970,8 +970,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { std::vector EmbdInputAfterImage(CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition, CxtRef.LlamaInputs.end()); - ReturnCode = details::evaluateTokens(GraphRef, LlamaContext, - EmbdInputBeforeImage, NPast); + ReturnCode = evaluateTokens(GraphRef, LlamaContext, + std::move(EmbdInputBeforeImage), NPast); if (ReturnCode != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate input tokens before image."sv); @@ -984,8 +984,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); return ErrNo::RuntimeError; } - ReturnCode = details::evaluateTokens(GraphRef, LlamaContext, - EmbdInputAfterImage, NPast); + ReturnCode = evaluateTokens(GraphRef, LlamaContext, + std::move(EmbdInputAfterImage), NPast); if (ReturnCode != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate input tokens after image."sv); @@ -1027,7 +1027,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { break; } // Evaluate the output token. - ReturnCode = details::evaluateTokens(GraphRef, LlamaContext, {Id}, NPast); + ReturnCode = evaluateTokens(GraphRef, LlamaContext, {Id}, NPast); if (ReturnCode != ErrNo::Success) { break; } @@ -1065,7 +1065,7 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, // Index 1 is for the metadata of the outputs. if (Index == 1) { std::string Metadata; - auto Res = details::buildOutputMetadata(CxtRef, Metadata); + auto Res = buildOutputMetadata(CxtRef, Metadata); if (Res != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: Failed to build output metadata."sv); @@ -1113,8 +1113,8 @@ Expect computeSingle(WasiNNEnvironment &Env, // Initialize the llama context. gpt_params GPTParams; llama_context_params ContextParams = llama_context_default_params(); - details::setupGPTParam(GraphRef, GPTParams); - details::setupContextParam(GraphRef, ContextParams); + setupGPTParam(GraphRef, GPTParams); + setupContextParam(GraphRef, ContextParams); CxtRef.LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); CxtRef.LlamaSampling = llama_sampling_init(GPTParams.sparams); @@ -1140,8 +1140,9 @@ Expect computeSingle(WasiNNEnvironment &Env, // Evaluate input tokens. if (CxtRef.LlavaImageEmbd == nullptr) { // Text only prompt. - ReturnCode = details::evaluateTokens( - GraphRef, CxtRef.LlamaContext, CxtRef.LlamaInputs, CxtRef.LlamaNPast); + ReturnCode = + evaluateTokens(GraphRef, CxtRef.LlamaContext, + std::move(CxtRef.LlamaInputs), CxtRef.LlamaNPast); if (ReturnCode != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate input tokens."sv); @@ -1156,8 +1157,8 @@ Expect computeSingle(WasiNNEnvironment &Env, CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition, CxtRef.LlamaInputs.end()); ReturnCode = - details::evaluateTokens(GraphRef, CxtRef.LlamaContext, - EmbdInputBeforeImage, CxtRef.LlamaNPast); + evaluateTokens(GraphRef, CxtRef.LlamaContext, + std::move(EmbdInputBeforeImage), CxtRef.LlamaNPast); if (ReturnCode != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate input tokens before image."sv); @@ -1172,8 +1173,8 @@ Expect computeSingle(WasiNNEnvironment &Env, return ErrNo::RuntimeError; } ReturnCode = - details::evaluateTokens(GraphRef, CxtRef.LlamaContext, - EmbdInputAfterImage, CxtRef.LlamaNPast); + evaluateTokens(GraphRef, CxtRef.LlamaContext, + std::move(EmbdInputAfterImage), CxtRef.LlamaNPast); if (ReturnCode != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate input tokens after image."sv); @@ -1205,8 +1206,8 @@ Expect computeSingle(WasiNNEnvironment &Env, } // Evaluate the output token if not EOS. if (ReturnCode != ErrNo::EndOfSequence) { - ReturnCode = details::evaluateTokens(GraphRef, CxtRef.LlamaContext, {Id}, - CxtRef.LlamaNPast); + ReturnCode = + evaluateTokens(GraphRef, CxtRef.LlamaContext, {Id}, CxtRef.LlamaNPast); } if (GraphRef.EnableDebugLog) { spdlog::info( diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 6a4038a1..83090a63 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -15,6 +15,7 @@ namespace Host { namespace WASINN { +namespace { Runtime::Instance::ModuleInstance * create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { return new WasiNNModule; @@ -48,6 +49,7 @@ bool load(const std::filesystem::path &Path, std::vector &Data) { File.close(); return true; } +} // namespace WasiNNEnvironment::WasiNNEnvironment() noexcept { #ifdef WASMEDGE_BUILD_WASI_NN_RPC @@ -129,6 +131,7 @@ PO::Option WasiNNEnvironment::NNRPCURI( PO::MetaVar("URI"sv), PO::DefaultValue(std::string(""))); #endif +namespace { void addOptions(const Plugin::Plugin::PluginDescriptor *, PO::ArgumentParser &Parser) noexcept { Parser.add_option("nn-preload"sv, WasiNNEnvironment::NNModels); @@ -156,6 +159,7 @@ Plugin::Plugin::PluginDescriptor Descriptor{ }, .AddOptions = addOptions, }; +} // namespace EXPORT_GET_DESCRIPTOR(Descriptor) diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index fe68a8fa..3abca194 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -5,65 +5,66 @@ wasmedge_add_executable(wasiNNTests wasi_nn.cpp ) +function(download URL OUTPUT HASH) + file(DOWNLOAD + ${URL} + ${OUTPUT} + SHOW_PROGRESS + EXPECTED_HASH ${HASH} + ) +endfunction() + # Prepare the testing data for each backends. foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) string(TOLOWER ${BACKEND} BACKEND) if(BACKEND MATCHES "openvino") - message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures") - execute_process( - COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-openvino-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures - RESULT_VARIABLE DOWNLOAD_ERROR - OUTPUT_STRIP_TRAILING_WHITESPACE) - file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/mobilenet.bin CHECKSUM_WEIGHT) - file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/mobilenet.xml CHECKSUM_DESCRIP) - file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/tensor-1x224x224x3-f32.bgr CHECKSUM_TENSOR) - if(NOT CHECKSUM_WEIGHT STREQUAL "ae096b1f735f1e8e54bac8b2a42303bd") - message(FATAL_ERROR "mobilenet.bin downloaded with wrong md5") - endif() - if(NOT CHECKSUM_DESCRIP STREQUAL "4ea3a14273587ce5c1662018878f9f90") - message(FATAL_ERROR "mobilenet.xml downloaded with wrong md5") - endif() - if(NOT CHECKSUM_TENSOR STREQUAL "bfca546f4a3b5e6da49b7bd728e2799a") - message(FATAL_ERROR "tensor-1x224x224x3-f32.bgr downloaded with wrong md5") - endif() + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures") + download( + https://github.com/intel/openvino-rs/raw/v0.3.3/crates/openvino/tests/fixtures/mobilenet/mobilenet.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/mobilenet.bin + MD5=ae096b1f735f1e8e54bac8b2a42303bd + ) + download( + https://github.com/intel/openvino-rs/raw/v0.3.3/crates/openvino/tests/fixtures/mobilenet/mobilenet.xml + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/mobilenet.xml + MD5=4ea3a14273587ce5c1662018878f9f90 + ) + download( + https://github.com/intel/openvino-rs/raw/v0.3.3/crates/openvino/tests/fixtures/mobilenet/tensor-1x224x224x3-f32.bgr + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_openvino_fixtures/tensor-1x224x224x3-f32.bgr + MD5=bfca546f4a3b5e6da49b7bd728e2799a + ) elseif(BACKEND MATCHES "pytorch") - message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures") - execute_process( - COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-pytorch-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures - RESULT_VARIABLE DOWNLOAD_ERROR - OUTPUT_STRIP_TRAILING_WHITESPACE) - file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures/mobilenet.pt CHECKSUM_WEIGHT) - file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures/image-1x3x224x224.rgb CHECKSUM_IMAGE) - if(NOT CHECKSUM_WEIGHT STREQUAL "234f446d2446e0f6fd8ed700c0b4b63b") - message(FATAL_ERROR "mobilenet.pt downloaded with wrong md5") - endif() - if(NOT CHECKSUM_IMAGE STREQUAL "551caa6f3b66c1d953655228462570a1") - message(FATAL_ERROR "image-1x3x224x224.rgb downloaded with wrong md5") - endif() + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures") + download( + https://github.com/second-state/WasmEdge-WASINN-examples/raw/master/pytorch-mobilenet-image/mobilenet.pt + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures/mobilenet.pt + MD5=234f446d2446e0f6fd8ed700c0b4b63b + ) + download( + https://github.com/second-state/WasmEdge-WASINN-examples/raw/master/pytorch-mobilenet-image/image-1x3x224x224.rgb + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_pytorch_fixtures/image-1x3x224x224.rgb + MD5=551caa6f3b66c1d953655228462570a1 + ) elseif(BACKEND STREQUAL "tensorflowlite") - message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures") - execute_process( - COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-tflite-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures - RESULT_VARIABLE DOWNLOAD_ERROR - OUTPUT_STRIP_TRAILING_WHITESPACE) - file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures/lite-model_aiy_vision_classifier_birds_V1_3.tflite CHECKSUM_WEIGHT) - file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures/birdx224x224x3.rgb CHECKSUM_IMAGE) - if(NOT CHECKSUM_WEIGHT STREQUAL "3e59cc3a99afeeb819c2c38b319a7938") - message(FATAL_ERROR "downloaded tflite model with wrong md5") - endif() - if(NOT CHECKSUM_IMAGE STREQUAL "ad51c39cfe35d2ef35c4052b78cb3c55") - message(FATAL_ERROR "downloaded bird.jpg fixture with wrong md5") - endif() + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures") + download( + https://raw.githubusercontent.com/gusye1234/WasmEdge-WASINN-examples/demo-tflite-image/tflite-birds_v1-image/lite-model_aiy_vision_classifier_birds_V1_3.tflite + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures/lite-model_aiy_vision_classifier_birds_V1_3.tflite + MD5=3e59cc3a99afeeb819c2c38b319a7938 + ) + download( + https://raw.githubusercontent.com/gusye1234/WasmEdge-WASINN-examples/demo-tflite-image/tflite-birds_v1-image/birdx224x224x3.rgb + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_tflite_fixtures/birdx224x224x3.rgb + MD5=ad51c39cfe35d2ef35c4052b78cb3c55 + ) elseif(BACKEND STREQUAL "ggml") - message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures") - execute_process( - COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-ggml-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures - RESULT_VARIABLE DOWNLOAD_ERROR - OUTPUT_STRIP_TRAILING_WHITESPACE) - file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/orca_mini.gguf CHECKSUM_MODEL) - if(NOT CHECKSUM_MODEL STREQUAL "f895f00678bfbf89f70d6d25f20a7b5f") - message(FATAL_ERROR "orca_mini.gguf downloaded with wrong md5") - endif() + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures") + download( + https://huggingface.co/TheBloke/orca_mini_v3_7B-GGUF/resolve/main/orca_mini_v3_7b.Q2_K.gguf + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/orca_mini.gguf + MD5=f895f00678bfbf89f70d6d25f20a7b5f + ) else() # Add the other backend test files fetching here. endif() @@ -107,4 +108,4 @@ if(WASMEDGE_BUILD_WASI_NN_RPC) PRIVATE wasiNNRPC ) -endif() \ No newline at end of file +endif() diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index f5d5572d..20287914 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -46,10 +46,10 @@ inline std::vector readEntireFile(const std::string &Path) { return {}; } Fin.seekg(0, std::ios::end); - std::vector Buf(static_cast(Fin.tellg())); + std::vector Buf(static_cast(Fin.tellg())); Fin.seekg(0, std::ios::beg); if (!Fin.read(reinterpret_cast(Buf.data()), - static_cast(Buf.size()))) { + static_cast(Buf.size()))) { return {}; } Fin.close(); @@ -75,6 +75,9 @@ void writeFatPointer(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, writeUInt32(MemInst, PtrSize, Ptr); } +#if defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) template std::vector classSort(WasmEdge::Span Array) { std::vector Indices(Array.size()); @@ -86,6 +89,7 @@ std::vector classSort(WasmEdge::Span Array) { }); return Indices; } +#endif } // namespace #endif @@ -194,9 +198,10 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: load -- OpenVINO model xml ptr out of bounds. BuilderPtr = LoadEntryPtr; - writeFatPointer(MemInst, OutBoundPtr, XmlRead.size(), BuilderPtr); - writeFatPointer(MemInst, StorePtr + XmlRead.size(), WeightRead.size(), + writeFatPointer(MemInst, OutBoundPtr, static_cast(XmlRead.size()), BuilderPtr); + writeFatPointer(MemInst, StorePtr + static_cast(XmlRead.size()), + static_cast(WeightRead.size()), BuilderPtr); { EXPECT_TRUE(HostFuncLoad.run( CallFrame, @@ -209,8 +214,10 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: load -- OpenVINO model bin ptr out of bounds. BuilderPtr = LoadEntryPtr; - writeFatPointer(MemInst, StorePtr, XmlRead.size(), BuilderPtr); - writeFatPointer(MemInst, OutBoundPtr, WeightRead.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr, static_cast(XmlRead.size()), + BuilderPtr); + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); { EXPECT_TRUE(HostFuncLoad.run( CallFrame, @@ -223,9 +230,10 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: load -- wrong builders' length. BuilderPtr = LoadEntryPtr; - writeFatPointer(MemInst, StorePtr, XmlRead.size(), BuilderPtr); - writeFatPointer(MemInst, StorePtr + XmlRead.size(), WeightRead.size(), + writeFatPointer(MemInst, StorePtr, static_cast(XmlRead.size()), BuilderPtr); + writeFatPointer(MemInst, StorePtr + static_cast(XmlRead.size()), + static_cast(WeightRead.size()), BuilderPtr); writeBinaries(MemInst, XmlRead, StorePtr); writeBinaries(MemInst, WeightRead, StorePtr + XmlRead.size()); StorePtr += (XmlRead.size() + WeightRead.size()); @@ -326,10 +334,12 @@ TEST(WasiNNTest, OpenVINOBackend) { // OpenVINO WASI-NN set_input tests. SetInputEntryPtr = BuilderPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(1), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); writeBinaries(MemInst, TensorDim, StorePtr); writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); @@ -375,10 +385,12 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: set_input -- tensor type not FP32. BuilderPtr = SetInputEntryPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(2), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), BuilderPtr); + writeUInt32(MemInst, UINT32_C(2), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); { EXPECT_TRUE( HostFuncSetInput.run(CallFrame, @@ -391,10 +403,12 @@ TEST(WasiNNTest, OpenVINOBackend) { // Test: set_input -- set input successfully. BuilderPtr = SetInputEntryPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(1), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); { EXPECT_TRUE( HostFuncSetInput.run(CallFrame, @@ -598,7 +612,8 @@ TEST(WasiNNTest, PyTorchBackend) { } // Test: load -- Torch model bin ptr out of bounds. BuilderPtr = LoadEntryPtr; - writeFatPointer(MemInst, OutBoundPtr, WeightRead.size(), BuilderPtr); + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); { EXPECT_TRUE(HostFuncLoad.run(CallFrame, std::initializer_list{ @@ -612,7 +627,8 @@ TEST(WasiNNTest, PyTorchBackend) { // Test: load -- wrong builders' length. BuilderPtr = LoadEntryPtr; - writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr, static_cast(WeightRead.size()), + BuilderPtr); writeBinaries(MemInst, WeightRead, StorePtr); StorePtr += WeightRead.size(); { @@ -717,10 +733,12 @@ TEST(WasiNNTest, PyTorchBackend) { // Torch WASI-NN set_input tests. SetInputEntryPtr = BuilderPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(1), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); writeBinaries(MemInst, TensorDim, StorePtr); writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); @@ -739,10 +757,12 @@ TEST(WasiNNTest, PyTorchBackend) { // Test: set_input -- tensor type not FP32. BuilderPtr = SetInputEntryPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(2), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), BuilderPtr); + writeUInt32(MemInst, UINT32_C(2), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); { EXPECT_TRUE( HostFuncSetInput.run(CallFrame, @@ -755,10 +775,12 @@ TEST(WasiNNTest, PyTorchBackend) { // Test: set_input -- set input successfully. BuilderPtr = SetInputEntryPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(1), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); { EXPECT_TRUE( HostFuncSetInput.run(CallFrame, @@ -966,7 +988,8 @@ TEST(WasiNNTest, TFLiteBackend) { } // Test: load -- model bin ptr out of bounds. BuilderPtr = LoadEntryPtr; - writeFatPointer(MemInst, OutBoundPtr, WeightRead.size(), BuilderPtr); + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); { EXPECT_TRUE( HostFuncLoad.run(CallFrame, @@ -981,7 +1004,8 @@ TEST(WasiNNTest, TFLiteBackend) { // Test: load -- wrong builders' length. BuilderPtr = LoadEntryPtr; - writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr, static_cast(WeightRead.size()), + BuilderPtr); writeBinaries(MemInst, WeightRead, StorePtr); StorePtr += WeightRead.size(); { @@ -1089,10 +1113,12 @@ TEST(WasiNNTest, TFLiteBackend) { // Torch WASI-NN set_input tests. SetInputEntryPtr = BuilderPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(1), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); writeBinaries(MemInst, TensorDim, StorePtr); writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); @@ -1111,11 +1137,13 @@ TEST(WasiNNTest, TFLiteBackend) { // Test: set_input -- set input successfully. BuilderPtr = SetInputEntryPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); // Tensor type U8 writeUInt32(MemInst, UINT32_C(2), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), - BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); { EXPECT_TRUE( HostFuncSetInput.run(CallFrame, @@ -1318,7 +1346,8 @@ TEST(WasiNNTest, GGMLBackend) { // Test: load -- GGML model bin ptr out of bounds. BuilderPtr = LoadEntryPtr; - writeFatPointer(MemInst, OutBoundPtr, WeightRead.size(), BuilderPtr); + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); { EXPECT_TRUE(HostFuncLoad.run(CallFrame, std::initializer_list{ @@ -1332,7 +1361,8 @@ TEST(WasiNNTest, GGMLBackend) { // Test: load -- wrong metadata encoding when builders length > 1. BuilderPtr = LoadEntryPtr; - writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr, static_cast(WeightRead.size()), + BuilderPtr); writeBinaries(MemInst, WeightRead, StorePtr); StorePtr += WeightRead.size(); { @@ -1383,12 +1413,16 @@ TEST(WasiNNTest, GGMLBackend) { // GGML WASI-NN set_input tests. SetInputEntryPtr = BuilderPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(1), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); writeBinaries(MemInst, TensorDim, StorePtr); - writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + writeBinaries(MemInst, TensorData, + StorePtr + + static_cast(TensorDim.size()) * 4); // Test: set_input -- context id exceeds. { @@ -1403,10 +1437,12 @@ TEST(WasiNNTest, GGMLBackend) { // Test: set_input -- set input successfully. BuilderPtr = SetInputEntryPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(1), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); { EXPECT_TRUE( HostFuncSetInput.run(CallFrame, @@ -1615,10 +1651,12 @@ TEST(WasiNNTest, GGMLBackendWithRPC) { // GGML WASI-NN set_input tests. SetInputEntryPtr = BuilderPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(1), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); writeBinaries(MemInst, TensorDim, StorePtr); writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); @@ -1634,10 +1672,12 @@ TEST(WasiNNTest, GGMLBackendWithRPC) { // Test: set_input -- set input successfully. BuilderPtr = SetInputEntryPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(1), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); { EXPECT_TRUE( HostFuncSetInput.run(CallFrame, diff --git a/utils/wasi-nn/download-ggml-fixtures.sh b/utils/wasi-nn/download-ggml-fixtures.sh deleted file mode 100755 index bb635925..00000000 --- a/utils/wasi-nn/download-ggml-fixtures.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env bash -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2023 Second State INC - -TODIR=$1 -if [[ $# -eq 0 ]]; then - TODIR=. -fi -MODEL=orca_mini.gguf -FIXTURE=https://huggingface.co/TheBloke/orca_mini_v3_7B-GGUF/resolve/main/orca_mini_v3_7b.Q2_K.gguf -if [ ! -d $TODIR ]; then - mkdir $TODIR -fi - -if [ ! -f $TODIR/$MODEL ]; then - curl -sL $FIXTURE -o $TODIR/$MODEL -fi diff --git a/utils/wasi-nn/download-openvino-fixtures.sh b/utils/wasi-nn/download-openvino-fixtures.sh deleted file mode 100755 index 02a243da..00000000 --- a/utils/wasi-nn/download-openvino-fixtures.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env bash -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC - -TODIR=$1 -if [[ $# -eq 0 ]]; then - TODIR=. -fi -FIXTURE=https://github.com/intel/openvino-rs/raw/v0.3.3/crates/openvino/tests/fixtures/mobilenet/ -if [ ! -d $TODIR ]; then - mkdir $TODIR -fi -if [ ! -d $TODIR ]; then - mkdir $TODIR -fi - -if [ ! -f $TODIR/mobilenet.bin ]; then - curl -sL $FIXTURE/mobilenet.bin -o $TODIR/mobilenet.bin -fi -if [ ! -f $TODIR/mobilenet.xml ]; then - curl -sL $FIXTURE/mobilenet.xml -o $TODIR/mobilenet.xml -fi -if [ ! -f $TODIR/tensor-1x224x224x3-f32.bgr ]; then - curl -sL $FIXTURE/tensor-1x224x224x3-f32.bgr -o $TODIR/tensor-1x224x224x3-f32.bgr -fi diff --git a/utils/wasi-nn/download-pytorch-fixtures.sh b/utils/wasi-nn/download-pytorch-fixtures.sh deleted file mode 100755 index 6a6aab91..00000000 --- a/utils/wasi-nn/download-pytorch-fixtures.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env bash -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC - -TODIR=$1 -if [[ $# -eq 0 ]]; then - TODIR=. -fi -FIXTURE=https://github.com/second-state/WasmEdge-WASINN-examples/raw/master/pytorch-mobilenet-image/ -if [ ! -d $TODIR ]; then - mkdir $TODIR -fi - -if [ ! -f $TODIR/mobilenet.pt ]; then - curl -sL $FIXTURE/mobilenet.pt -o $TODIR/mobilenet.pt -fi -if [ ! -f $TODIR/image-1x3x224x224.rgb ]; then - curl -sL $FIXTURE/image-1x3x224x224.rgb -o $TODIR/image-1x3x224x224.rgb -fi diff --git a/utils/wasi-nn/download-tflite-fixtures.sh b/utils/wasi-nn/download-tflite-fixtures.sh deleted file mode 100755 index 959d7fee..00000000 --- a/utils/wasi-nn/download-tflite-fixtures.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env bash -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC - -TODIR=$1 -if [[ $# -eq 0 ]]; then - TODIR=. -fi -FIXTURE=https://raw.githubusercontent.com/gusye1234/WasmEdge-WASINN-examples/demo-tflite-image/tflite-birds_v1-image -if [ ! -d $TODIR ]; then - mkdir $TODIR -fi - -if [ ! -f $TODIR/lite-model_aiy_vision_classifier_birds_V1_3.tflite ]; then - curl -sL $FIXTURE/lite-model_aiy_vision_classifier_birds_V1_3.tflite -o $TODIR/lite-model_aiy_vision_classifier_birds_V1_3.tflite -fi -if [ ! -f $TODIR/birdx224x224x3.rgb ]; then - curl -sL $FIXTURE/birdx224x224x3.rgb -o $TODIR/birdx224x224x3.rgb -fi From 84b1f5cd26ea83d7ac04f679351dc1bbac5d132e Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Wed, 27 Mar 2024 14:35:45 +0800 Subject: [PATCH 270/623] [Plugin] Add WASMEDGE_LIB_PREFIX for windows * Change std::string arguments to std::string_view Signed-off-by: Shen-Ta Hsieh --- test/plugins/CMakeLists.txt | 4 +-- test/plugins/unittest/CMakeLists.txt | 9 +++++- test/plugins/unittest/testplugin.c | 20 ++++++------- test/plugins/unittest/testplugin.cpp | 29 ++++++++++--------- test/plugins/unittest/unittest_c.cpp | 6 ++-- test/plugins/unittest/unittest_cpp.cpp | 6 ++-- test/plugins/wasi_crypto/helper.h | 4 +-- test/plugins/wasi_logging/wasi_logging.cpp | 4 +-- test/plugins/wasi_nn/wasi_nn.cpp | 6 ++-- test/plugins/wasm_bpf/simple_map_test.cpp | 6 ++-- test/plugins/wasm_bpf/simple_ringbuf_test.cpp | 6 ++-- test/plugins/wasm_bpf/wasm_bpf.cpp | 6 ++-- test/plugins/wasmedge_ffmpeg/utils.h | 6 ++-- .../plugins/wasmedge_image/wasmedge_image.cpp | 4 +-- .../wasmedge_opencvmini.cpp | 4 +-- .../wasmedge_process/wasmedge_process.cpp | 4 +-- .../wasmedge_rustls/wasmedge_rustls.cpp | 6 ++-- .../wasmedge_tensorflow.cpp | 4 +-- .../wasmedge_tensorflowlite.cpp | 4 +-- test/plugins/wasmedge_zlib/wasmedge_zlib.cpp | 4 +-- 20 files changed, 76 insertions(+), 66 deletions(-) diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 211c0bbd..eeb1950b 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -61,6 +61,4 @@ if(WASMEDGE_PLUGIN_RUSTLS) add_subdirectory(wasmedge_rustls) endif() -if(CMAKE_SYSTEM_NAME MATCHES "Linux" OR CMAKE_SYSTEM_NAME MATCHES "Darwin") - add_subdirectory(unittest) -endif() +add_subdirectory(unittest) diff --git a/test/plugins/unittest/CMakeLists.txt b/test/plugins/unittest/CMakeLists.txt index 4631c0b4..d1614c06 100644 --- a/test/plugins/unittest/CMakeLists.txt +++ b/test/plugins/unittest/CMakeLists.txt @@ -2,15 +2,22 @@ # SPDX-FileCopyrightText: 2019-2022 Second State INC # The test plugin module in C API +enable_language(C) + wasmedge_add_library(wasmedgePluginTestModuleC SHARED testplugin.c -) + ) set_target_properties(wasmedgePluginTestModuleC PROPERTIES C_STANDARD 11 ) +# remove cxx_standard for msvc +set_property(TARGET wasmedgePluginTestModuleC PROPERTY + CXX_STANDARD +) + target_compile_options(wasmedgePluginTestModuleC PUBLIC -DWASMEDGE_PLUGIN diff --git a/test/plugins/unittest/testplugin.c b/test/plugins/unittest/testplugin.c index a0536101..70eb27bd 100644 --- a/test/plugins/unittest/testplugin.c +++ b/test/plugins/unittest/testplugin.c @@ -12,15 +12,15 @@ static WasmEdge_String NameString; static const char NameCString[] = "name"; static const WasmEdge_String NameStringDefaultValue = {.Buf = NameCString, .Length = 4}; -void Finalizer(void *Data) { +static void Finalizer(void *Data) { printf("Deallocate host data\n"); free((int32_t *)Data); } -WasmEdge_Result HostFuncAdd(void *Data, - const WasmEdge_CallingFrameContext *CallFrameCxt - __attribute__((unused)), - const WasmEdge_Value *In, WasmEdge_Value *Out) { +static WasmEdge_Result +HostFuncAdd(void *Data, const WasmEdge_CallingFrameContext *CallFrameCxt, + const WasmEdge_Value *In, WasmEdge_Value *Out) { + (void)CallFrameCxt; /* * Host function to calculate A + B, * and accumulate (A + B) to the host data. @@ -34,10 +34,10 @@ WasmEdge_Result HostFuncAdd(void *Data, return WasmEdge_Result_Success; } -WasmEdge_Result HostFuncSub(void *Data, - const WasmEdge_CallingFrameContext *CallFrameCxt - __attribute__((unused)), - const WasmEdge_Value *In, WasmEdge_Value *Out) { +static WasmEdge_Result +HostFuncSub(void *Data, const WasmEdge_CallingFrameContext *CallFrameCxt, + const WasmEdge_Value *In, WasmEdge_Value *Out) { + (void)CallFrameCxt; /* * Host function to calculate A - B, * and accumulate (A - B) to the host data. @@ -51,7 +51,7 @@ WasmEdge_Result HostFuncSub(void *Data, return WasmEdge_Result_Success; } -WasmEdge_ModuleInstanceContext * +static WasmEdge_ModuleInstanceContext * CreateTestModule(const struct WasmEdge_ModuleDescriptor *Desc) { /* Allocate and initialize a host data. */ printf("Allocate host data\n"); diff --git a/test/plugins/unittest/testplugin.cpp b/test/plugins/unittest/testplugin.cpp index 7cc55013..581dfb5f 100644 --- a/test/plugins/unittest/testplugin.cpp +++ b/test/plugins/unittest/testplugin.cpp @@ -36,21 +36,22 @@ create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { return new WasmEdgePluginTestModule; } +static Plugin::PluginModule::ModuleDescriptor MD[]{ + { + /* Name */ "wasmedge_plugintest_cpp_module", + /* Description */ "This is for the plugin tests in WasmEdge.", + /* Create */ create, + }, +}; + Plugin::Plugin::PluginDescriptor Descriptor{ - .Name = "wasmedge_plugintest_cpp", - .Description = "", - .APIVersion = Plugin::Plugin::CurrentAPIVersion, - .Version = {0, 10, 0, 0}, - .ModuleCount = 1, - .ModuleDescriptions = - (Plugin::PluginModule::ModuleDescriptor[]){ - { - .Name = "wasmedge_plugintest_cpp_module", - .Description = "This is for the plugin tests in WasmEdge.", - .Create = create, - }, - }, - .AddOptions = addOptions, + /* Name */ "wasmedge_plugintest_cpp", + /* Description */ "", + /* APIVersion */ Plugin::Plugin::CurrentAPIVersion, + /* Version */ {0, 10, 0, 0}, + /* ModuleCount */ 1, + /* ModuleDescriptions */ MD, + /* AddOptions */ addOptions, }; EXPORT_GET_DESCRIPTOR(Descriptor) diff --git a/test/plugins/unittest/unittest_c.cpp b/test/plugins/unittest/unittest_c.cpp index a1ccb94d..5397c9ca 100644 --- a/test/plugins/unittest/unittest_c.cpp +++ b/test/plugins/unittest/unittest_c.cpp @@ -11,7 +11,8 @@ namespace { WasmEdge_ModuleInstanceContext *createModuleC() { WasmEdge_PluginLoadFromPath( - "./libwasmedgePluginTestModuleC" WASMEDGE_LIB_EXTENSION); + "./" WASMEDGE_LIB_PREFIX + "wasmedgePluginTestModuleC" WASMEDGE_LIB_EXTENSION); WasmEdge_String Str = WasmEdge_StringCreateByCString("wasmedge_plugintest_c"); const WasmEdge_PluginContext *PluginCxt = WasmEdge_PluginFind(Str); WasmEdge_StringDelete(Str); @@ -28,7 +29,8 @@ WasmEdge_ModuleInstanceContext *createModuleC() { WasmEdge_ModuleInstanceContext *createModuleCPP() { WasmEdge_PluginLoadFromPath( - "./libwasmedgePluginTestModuleCPP" WASMEDGE_LIB_EXTENSION); + "./" WASMEDGE_LIB_PREFIX + "wasmedgePluginTestModuleCPP" WASMEDGE_LIB_EXTENSION); WasmEdge_String Str = WasmEdge_StringCreateByCString("wasmedge_plugintest_cpp"); const WasmEdge_PluginContext *PluginCxt = WasmEdge_PluginFind(Str); diff --git a/test/plugins/unittest/unittest_cpp.cpp b/test/plugins/unittest/unittest_cpp.cpp index 345bdbfd..81af0d5e 100644 --- a/test/plugins/unittest/unittest_cpp.cpp +++ b/test/plugins/unittest/unittest_cpp.cpp @@ -18,7 +18,8 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModuleC() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "./libwasmedgePluginTestModuleC" WASMEDGE_LIB_EXTENSION)); + "./" WASMEDGE_LIB_PREFIX + "wasmedgePluginTestModuleC" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_plugintest_c"sv)) { if (const auto *Module = @@ -32,7 +33,8 @@ WasmEdge::Runtime::Instance::ModuleInstance *createModuleC() { WasmEdge::Runtime::Instance::ModuleInstance *createModuleCPP() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "./libwasmedgePluginTestModuleCPP" WASMEDGE_LIB_EXTENSION)); + "./" WASMEDGE_LIB_PREFIX + "wasmedgePluginTestModuleCPP" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_plugintest_cpp"sv)) { WasmEdge::PO::ArgumentParser Parser; diff --git a/test/plugins/wasi_crypto/helper.h b/test/plugins/wasi_crypto/helper.h index beb1a339..9feefc44 100644 --- a/test/plugins/wasi_crypto/helper.h +++ b/test/plugins/wasi_crypto/helper.h @@ -63,8 +63,8 @@ class WasiCryptoTest : public ::testing::Test { using namespace std::literals::string_view_literals; Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasi_crypto/" - "libwasmedgePluginWasiCrypto" WASMEDGE_LIB_EXTENSION)); + "../../../plugins/wasi_crypto/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasiCrypto" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_crypto"sv)) { if (const auto *Module = Plugin->findModule("wasi_crypto_asymmetric_common"sv)) { diff --git a/test/plugins/wasi_logging/wasi_logging.cpp b/test/plugins/wasi_logging/wasi_logging.cpp index d1c36c02..a3720f2c 100644 --- a/test/plugins/wasi_logging/wasi_logging.cpp +++ b/test/plugins/wasi_logging/wasi_logging.cpp @@ -14,8 +14,8 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasi_logging/" - "libwasmedgePluginWasiLogging" WASMEDGE_LIB_EXTENSION)); + "../../../plugins/wasi_logging/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasiLogging" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_logging"sv)) { if (const auto *Module = Plugin->findModule("wasi:logging/logging"sv)) { return Module->create().release(); diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 20287914..806c33a3 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -24,9 +24,9 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance * createModule(std::string_view NNRPCURI = "") { using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasi_nn/" - "libwasmedgePluginWasiNN" WASMEDGE_LIB_EXTENSION)); + WasmEdge::Plugin::Plugin::load( + std::filesystem::u8path("../../../plugins/wasi_nn/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasiNN" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_nn"sv)) { WasmEdge::PO::ArgumentParser Parser; Plugin->registerOptions(Parser); diff --git a/test/plugins/wasm_bpf/simple_map_test.cpp b/test/plugins/wasm_bpf/simple_map_test.cpp index 9a86892c..51ef2f5a 100644 --- a/test/plugins/wasm_bpf/simple_map_test.cpp +++ b/test/plugins/wasm_bpf/simple_map_test.cpp @@ -23,9 +23,9 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasm_bpf/" - "libwasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); + WasmEdge::Plugin::Plugin::load( + std::filesystem::u8path("../../../plugins/wasm_bpf/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { if (const auto *Module = Plugin->findModule("wasm_bpf"sv)) { return Module->create().release(); diff --git a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp index f7052105..ac389934 100644 --- a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp +++ b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp @@ -19,9 +19,9 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasm_bpf/" - "libwasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); + WasmEdge::Plugin::Plugin::load( + std::filesystem::u8path("../../../plugins/wasm_bpf/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { if (const auto *Module = Plugin->findModule("wasm_bpf"sv)) { return Module->create().release(); diff --git a/test/plugins/wasm_bpf/wasm_bpf.cpp b/test/plugins/wasm_bpf/wasm_bpf.cpp index 9bf84f87..3e5dbd0e 100644 --- a/test/plugins/wasm_bpf/wasm_bpf.cpp +++ b/test/plugins/wasm_bpf/wasm_bpf.cpp @@ -31,9 +31,9 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasm_bpf/" - "libwasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); + WasmEdge::Plugin::Plugin::load( + std::filesystem::u8path("../../../plugins/wasm_bpf/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { if (const auto *Module = Plugin->findModule("wasm_bpf"sv)) { return Module->create().release(); diff --git a/test/plugins/wasmedge_ffmpeg/utils.h b/test/plugins/wasmedge_ffmpeg/utils.h index 9128bacd..6c95138e 100644 --- a/test/plugins/wasmedge_ffmpeg/utils.h +++ b/test/plugins/wasmedge_ffmpeg/utils.h @@ -59,8 +59,8 @@ class FFmpegTest : public ::testing::Test { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasmedge_ffmpeg/" - "libwasmedgePluginWasmEdgeFFmpeg" WASMEDGE_LIB_EXTENSION)); + "../../../plugins/wasmedge_ffmpeg/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeFFmpeg" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_ffmpeg"sv)) { if (const auto *Module = @@ -161,4 +161,4 @@ class FFmpegTest : public ::testing::Test { }; } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_image/wasmedge_image.cpp b/test/plugins/wasmedge_image/wasmedge_image.cpp index 0583c1c7..a5ea3356 100644 --- a/test/plugins/wasmedge_image/wasmedge_image.cpp +++ b/test/plugins/wasmedge_image/wasmedge_image.cpp @@ -17,8 +17,8 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasmedge_image/" - "libwasmedgePluginWasmEdgeImage" WASMEDGE_LIB_EXTENSION)); + "../../../plugins/wasmedge_image/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeImage" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_image"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_image"sv)) { return Module->create().release(); diff --git a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp index 926e1330..d3b570c9 100644 --- a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp +++ b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp @@ -17,8 +17,8 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasmedge_opencvmini/" - "libwasmedgePluginWasmEdgeOpenCVMini" WASMEDGE_LIB_EXTENSION)); + "../../../plugins/wasmedge_opencvmini/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeOpenCVMini" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_opencvmini"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_opencvmini"sv)) { diff --git a/test/plugins/wasmedge_process/wasmedge_process.cpp b/test/plugins/wasmedge_process/wasmedge_process.cpp index 9c25f412..96e66b12 100644 --- a/test/plugins/wasmedge_process/wasmedge_process.cpp +++ b/test/plugins/wasmedge_process/wasmedge_process.cpp @@ -19,8 +19,8 @@ WasmEdge::Runtime::CallingFrame DummyCallFrame(nullptr, nullptr); WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasmedge_process/" - "libwasmedgePluginWasmEdgeProcess" WASMEDGE_LIB_EXTENSION)); + "../../../plugins/wasmedge_process/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeProcess" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_process"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_process"sv)) { diff --git a/test/plugins/wasmedge_rustls/wasmedge_rustls.cpp b/test/plugins/wasmedge_rustls/wasmedge_rustls.cpp index 1e758cf0..f34d3a78 100644 --- a/test/plugins/wasmedge_rustls/wasmedge_rustls.cpp +++ b/test/plugins/wasmedge_rustls/wasmedge_rustls.cpp @@ -15,9 +15,9 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load( - std::filesystem::u8path("../../../plugins/wasmedge_rustls/" - "libwasmedge_rustls" WASMEDGE_LIB_EXTENSION)); + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_rustls/" WASMEDGE_LIB_PREFIX + "wasmedge_rustls" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("rustls"sv)) { if (const auto *Module = Plugin->findModule("rustls_client"sv)) { return Module->create().release(); diff --git a/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp b/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp index 253ab828..0b0617d3 100644 --- a/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp +++ b/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp @@ -17,8 +17,8 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasmedge_tensorflow/" - "libwasmedgePluginWasmEdgeTensorflow" WASMEDGE_LIB_EXTENSION)); + "../../../plugins/wasmedge_tensorflow/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeTensorflow" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_tensorflow"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_tensorflow"sv)) { diff --git a/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp b/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp index b0cfa5c3..ebddf267 100644 --- a/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp +++ b/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp @@ -17,8 +17,8 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasmedge_tensorflowlite/" - "libwasmedgePluginWasmEdgeTensorflowLite" WASMEDGE_LIB_EXTENSION)); + "../../../plugins/wasmedge_tensorflowlite/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeTensorflowLite" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_tensorflowlite"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_tensorflowlite"sv)) { diff --git a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp index d526178a..1bf5bdc3 100644 --- a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp +++ b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp @@ -20,8 +20,8 @@ WasmEdge::Runtime::CallingFrame DummyCallFrame(nullptr, nullptr); WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasmedge_zlib/" - "libwasmedgePluginWasmEdgeZlib" WASMEDGE_LIB_EXTENSION)); + "../../../plugins/wasmedge_zlib/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeZlib" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_zlib"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_zlib"sv)) { return Module->create().release(); From 42c340167eb624e3f606718c81c22640286904b3 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 2 Apr 2024 22:44:51 +0800 Subject: [PATCH 271/623] [WASI-NN] ggml: support grammar Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 12 ++++++++++++ plugins/wasi_nn/ggml.h | 1 + 2 files changed, 13 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 43afec1d..79d823c9 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -254,6 +254,16 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } + if (Doc.at_key("grammar").error() == simdjson::SUCCESS) { + std::string_view Grammar; + auto Err = Doc["grammar"].get().get(Grammar); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the grammar option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Grammar = Grammar; + } // Check if the model is updated. if (IsModelUpdated && ModelParams.n_gpu_layers != GraphRef.NGPULayers) { @@ -269,6 +279,7 @@ Expect setupGPTParam(Graph &GraphRef, gpt_params &GPTParams) { GPTParams.sparams.penalty_repeat = static_cast(GraphRef.RepeatPenalty); GPTParams.sparams.penalty_present = static_cast(GraphRef.PresencePenalty); + GPTParams.sparams.grammar = GraphRef.Grammar; return ErrNo::Success; } @@ -590,6 +601,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.RepeatPenalty = SamplingDefault.penalty_repeat; GraphRef.PresencePenalty = SamplingDefault.penalty_present; GraphRef.FrequencyPenalty = SamplingDefault.penalty_freq; + GraphRef.Grammar = SamplingDefault.grammar; // If the graph builder length > 1, the data of builder[1] is the metadata. if (Builders.size() > 1) { diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 67c70fcc..f8978d29 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -45,6 +45,7 @@ struct Graph { double RepeatPenalty = 1.10; double PresencePenalty = 0.00; double FrequencyPenalty = 0.00; + std::string Grammar; }; struct Context { From 04475616434eb9412e94ed6d51012df83235cdd0 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 10 Apr 2024 15:11:02 +0800 Subject: [PATCH 272/623] [WASI-NN] ggml: update ggml to b2636 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 77b97789..17e99555 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -62,7 +62,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2534 + GIT_TAG b2636 PATCH_COMMAND git checkout . COMMAND git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch GIT_SHALLOW FALSE From 9c5da5b1d389bccc7f6493c026ddfbdfb4ce08ed Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 12 Apr 2024 12:44:59 +0800 Subject: [PATCH 273/623] [WASI-NN] ggml: fix embedding with batch decode Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 68 ++++++++++++++++++++++++++++++++++------ 1 file changed, 58 insertions(+), 10 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 79d823c9..c6fd6a41 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -370,6 +370,57 @@ ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, return ErrNo::Success; } +void batchAddSeq(llama_batch &Batch, const std::vector &Tokens, + llama_seq_id SequenceId) noexcept { + for (int I = 0; I < static_cast(Tokens.size()); I++) { + // llama_batch_add_seq(llama_batch, llama_token, llama_pos, + // std::vector, logits); + llama_batch_add(Batch, Tokens[I], I, {SequenceId}, + I == static_cast(Tokens.size()) - 1); + } +} + +ErrNo batchDecode(llama_context *LlamaContext, llama_batch &Batch, + float *Output, int NEmbd) noexcept { + // Clear previous kv_cache values (irrelevant for embeddings) + llama_kv_cache_clear(LlamaContext); + + // Decode the batch. + auto Status = llama_decode(LlamaContext, Batch); + if (Status == 1) { + spdlog::error( + "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); + return ErrNo::RuntimeError; + } else if (Status < 0) { + spdlog::error( + "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); + return ErrNo::RuntimeError; + } + + for (int I = 0; I < Batch.n_tokens; I++) { + if (!Batch.logits[I]) { + continue; + } + + // Try to get sequence embeddings. + auto *Embd = llama_get_embeddings_seq(LlamaContext, Batch.seq_id[I][0]); + if (Embd == nullptr) { + Embd = llama_get_embeddings_ith(LlamaContext, I); + if (Embd == nullptr) { + spdlog::error( + "[WASI-NN] GGML backend: failed to get embeddings for token {}"sv, + I); + continue; + } + } + + // Normalize the embeddings. + llama_embd_normalize(Embd, Output, NEmbd); + } + + return ErrNo::Success; +} + Expect getEmbedding(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); @@ -405,8 +456,6 @@ Expect getEmbedding(WasiNNEnvironment &Env, auto *LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); - // Prepare variables; - int32_t NPast = 0; // Get the context size. const uint64_t NCtx = llama_n_ctx(LlamaContext); // Minus 4 for the special tokens. (Such as , , ... tokens.) @@ -436,19 +485,18 @@ Expect getEmbedding(WasiNNEnvironment &Env, return ErrNo::PromptTooLong; } - // Evaluate input tokens. - ReturnCode = evaluateTokens(GraphRef, LlamaContext, - std::move(CxtRef.LlamaInputs), NPast); + const int32_t NEmbd = llama_n_embd(GraphRef.LlamaModel); + struct llama_batch Batch = + llama_batch_init(GraphRef.BatchSize, /* embd */ 0, /* n_seq_max */ 1); + std::vector Embeddings(NEmbd); + batchAddSeq(Batch, CxtRef.LlamaInputs, SequenceId); + ReturnCode = batchDecode(LlamaContext, Batch, Embeddings.data(), NEmbd); if (ReturnCode != ErrNo::Success) { spdlog::error("[WASI-NN] GGML backend: failed to evaluate input tokens."sv); return ReturnCode; } - const int32_t NEmbd = llama_n_embd(GraphRef.LlamaModel); - auto *Embeddings = llama_get_embeddings_seq(LlamaContext, SequenceId); - llama_embd_normalize(Embeddings, Embeddings, NEmbd); - - buildOutputEmbedding(CxtRef.LlamaOutputs, NEmbd, Embeddings); + buildOutputEmbedding(CxtRef.LlamaOutputs, NEmbd, Embeddings.data()); if (GraphRef.EnableDebugLog) { spdlog::info( From af836e8291e350de034700177a4c958f41252baa Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 15 Apr 2024 21:48:05 +0800 Subject: [PATCH 274/623] [WASI-NN] ggml: fix input size checking of computeSingle (#3336) --- plugins/wasi_nn/ggml.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index c6fd6a41..f85b2b68 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -1151,13 +1151,15 @@ Expect computeSingle(WasiNNEnvironment &Env, if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: computeSingleToken"sv); } - if (CxtRef.LlamaInputs.size() == 0) { - spdlog::error("[WASI-NN] GGML backend: Llama input is not set!"sv); - return ErrNo::InvalidArgument; - } // New compute single token context. if (CxtRef.LlamaContext == nullptr) { + // Check if the input is set before setting up the context. + if (CxtRef.LlamaInputs.size() == 0) { + spdlog::error("[WASI-NN] GGML backend: Llama input is not set!"sv); + return ErrNo::InvalidArgument; + } + // Clear the outputs. if (GraphRef.EnableDebugLog) { spdlog::info( From 530cfa7b7413e415abbd0ea96c3ffb28922363c5 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 16 Apr 2024 14:37:48 +0800 Subject: [PATCH 275/623] [WASI-NN] ggml: add ErrNo::ModelNotFound Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 7 +++++++ plugins/wasi_nn/types.h | 1 + 2 files changed, 8 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index f85b2b68..25da5800 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -709,6 +710,12 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, spdlog::info( "[WASI-NN][Debug] GGML backend: Finished handling model path."sv); } + // Check if the model exists. + if (!std::filesystem::exists(std::filesystem::u8path(ModelFilePath))) { + spdlog::error("[WASI-NN] GGML backend: Model file not found."sv); + Env.NNGraph.pop_back(); + return ErrNo::ModelNotFound; + } if (GraphRef.EnableDebugLog) { spdlog::info( diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index 453b320b..f14350b4 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -21,6 +21,7 @@ enum class ErrNo : uint32_t { EndOfSequence = 100, // End of Sequence Found. ContextFull = 101, // Context Full. PromptTooLong = 102, // Prompt Too Long. + ModelNotFound = 103, // Model Not Found. }; enum class TensorType : uint8_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; From 2977207070ad363f66c7bdb7a8b387c53725424c Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 19 Apr 2024 12:25:48 +0800 Subject: [PATCH 276/623] [WASI-NN] ggml: bump tp b2694 to fix the MoE segfault (#3345) Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 17e99555..5c153f07 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -62,7 +62,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2636 + GIT_TAG b2694 PATCH_COMMAND git checkout . COMMAND git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch GIT_SHALLOW FALSE From 6d2978535427a7545c3cdab09b6f3e48cfa5207e Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 22 Apr 2024 17:17:43 +0800 Subject: [PATCH 277/623] [WASI-NN] ggml: add unload function Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 11 +++++++++++ plugins/wasi_nn/ggml.h | 2 ++ plugins/wasi_nn/wasinnenv.h | 11 +++++++++++ plugins/wasi_nn/wasinnfunc.cpp | 28 ++++++++++++++++++++++++++++ plugins/wasi_nn/wasinnfunc.h | 12 ++++++++++++ plugins/wasi_nn/wasinnmodule.cpp | 1 + 6 files changed, 65 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 25da5800..8af214bd 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -1336,6 +1336,17 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { return ErrNo::Success; } +Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.LlamaModel != nullptr) { + llama_free_model(GraphRef.LlamaModel); + GraphRef.LlamaModel = nullptr; + } + Env.NNGraph.erase(Env.NNGraph.begin() + GraphId); + Env.mdRemoveById(GraphId); + return ErrNo::Success; +} + #else namespace { Expect reportBackendNotSupported() noexcept { diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index f8978d29..e334ca5d 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -96,4 +96,6 @@ Expect computeSingle(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; Expect finiSingle(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; } // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index bc83b739..e1e45b2f 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -168,6 +168,17 @@ struct WasiNNEnvironment : return false; } + void mdRemoveById(uint32_t GraphId) noexcept { + std::unique_lock Lock(MdMutex); + for (auto It = MdMap.begin(); It != MdMap.end();) { + if (It->second == static_cast(GraphId)) { + It = MdMap.erase(It); + } else { + ++It; + } + } + } + Expect mdBuild(std::string Name, uint32_t &GraphId, Callback Load, std::vector Config = std::vector()) noexcept { diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index defaf6ae..6e609d4f 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -576,5 +576,33 @@ WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, } } +Expect WasiNNUnload::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t GraphId) { +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + // TODO: implement RPC for unload + spdlog::error("[WASI-NN] RPC client is not implemented for unload"sv); + return WASINN::ErrNo::UnsupportedOperation; + } +#endif + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + if (Env.NNGraph.size() <= GraphId) { + spdlog::error("[WASI-NN] unload: GraphId {} does not exist."sv, GraphId); + return WASINN::ErrNo::InvalidArgument; + } + + switch (Env.NNGraph[GraphId].getBackend()) { + case WASINN::Backend::GGML: + return WASINN::GGML::unload(Env, GraphId); + default: + spdlog::error("[WASI-NN] unlaod: Only GGML backend supports unload."sv); + return WASINN::ErrNo::InvalidArgument; + } +} + } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h index 7dcfbd2a..dcda910c 100644 --- a/plugins/wasi_nn/wasinnfunc.h +++ b/plugins/wasi_nn/wasinnfunc.h @@ -161,5 +161,17 @@ class WasiNNFiniSingle : public WasiNN { uint32_t Context); }; +class WasiNNUnload : public WasiNN { +public: + WasiNNUnload(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t GraphId) { + return bodyImpl(Frame, GraphId).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t GraphId); +}; + } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnmodule.cpp b/plugins/wasi_nn/wasinnmodule.cpp index c2246405..b2b2c6d5 100644 --- a/plugins/wasi_nn/wasinnmodule.cpp +++ b/plugins/wasi_nn/wasinnmodule.cpp @@ -21,6 +21,7 @@ WasiNNModule::WasiNNModule() : ModuleInstance("wasi_ephemeral_nn") { addHostFunc("compute", std::make_unique(Env)); addHostFunc("compute_single", std::make_unique(Env)); addHostFunc("fini_single", std::make_unique(Env)); + addHostFunc("unload", std::make_unique(Env)); } } // namespace Host From 0dd6b3f5554bca2c91ede7e850847184ad070585 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 23 Apr 2024 11:52:06 +0800 Subject: [PATCH 278/623] [WASI-NN] ggml: bump llama.cpp b2715 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/ggml.cpp | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 5c153f07..601f735e 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -62,7 +62,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2694 + GIT_TAG b2715 PATCH_COMMAND git checkout . COMMAND git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch GIT_SHALLOW FALSE diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 8af214bd..401abbf0 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -472,7 +472,7 @@ Expect getEmbedding(WasiNNEnvironment &Env, llama_token_bos(GraphRef.LlamaModel)); } // Add EOS if not present. - if (CxtRef.LlamaInputs.back() != llama_token_eos(GraphRef.LlamaModel)) { + if (!llama_token_is_eog(GraphRef.LlamaModel, CxtRef.LlamaInputs.back())) { CxtRef.LlamaInputs.push_back(llama_token_eos(GraphRef.LlamaModel)); } @@ -1086,8 +1086,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { break; } // Deal with end of text token. - if (llama_sampling_last(CtxSampling) == - llama_token_eos(GraphRef.LlamaModel)) { + if (llama_token_is_eog(GraphRef.LlamaModel, + llama_sampling_last(CtxSampling))) { if (GraphRef.EnableLog) { spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); } @@ -1266,8 +1266,8 @@ Expect computeSingle(WasiNNEnvironment &Env, CxtRef.LlamaOutputTokens.emplace_back(Id); CxtRef.LlamaOutputs += llama_token_to_piece(CxtRef.LlamaContext, Id); // Deal with end of text token. - if (llama_sampling_last(CxtRef.LlamaSampling) == - llama_token_eos(GraphRef.LlamaModel)) { + if (llama_token_is_eog(GraphRef.LlamaModel, + llama_sampling_last(CxtRef.LlamaSampling))) { ReturnCode = ErrNo::EndOfSequence; if (GraphRef.EnableLog) { spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); From e080dff00ede84f275aba7137299463cf24fe6ce Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 23 Apr 2024 15:17:36 +0800 Subject: [PATCH 279/623] [WASI-NN] propagate LLAMA_NATIVE for setting the AVX/AVX2/FMA features Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 601f735e..c414068e 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -9,6 +9,13 @@ if(BACKEND STREQUAL "ggml") set(LLAMA_METAL_NDEBUG ON) set(LLAMA_ACCELERATE OFF) + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_NATIVE) + message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_NATIVE(AVX/AVX2/FMA)") + set(LLAMA_NATIVE ON) + else() + set(LLAMA_NATIVE OFF) + endif() + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_CUBLAS) message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_CUDA") set(LLAMA_CUDA ON) From ffe12cd37663b7350ddc234f7503a7a8b9d35000 Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Tue, 23 Apr 2024 20:04:35 +0800 Subject: [PATCH 280/623] [CMake] Bump simdjson to 3.9.1 and fix compiler flag. (#3357) Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index c414068e..58dfe39c 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -89,7 +89,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( simdjson GIT_REPOSITORY https://github.com/simdjson/simdjson.git - GIT_TAG tags/v3.2.1 + GIT_TAG tags/v3.9.1 GIT_SHALLOW TRUE) if(MSVC) From 74fe8368ce8aff703de0b96484fed2632201f90e Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 26 Apr 2024 11:06:03 +0800 Subject: [PATCH 281/623] [WASI-NN] ggml: bump tp b2734 to support phi-3-mini (#3364) Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 58dfe39c..51001ddb 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -69,7 +69,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2715 + GIT_TAG b2734 PATCH_COMMAND git checkout . COMMAND git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch GIT_SHALLOW FALSE From 94bb19d3bfedcd56cc2432c71f3f400333b20ca0 Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 26 Apr 2024 13:47:25 +0800 Subject: [PATCH 282/623] [WASI-NN] ggml: add llama log callback (#3365) Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 -- plugins/wasi_nn/ggml.cpp | 29 ++++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 51001ddb..14768036 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -70,8 +70,6 @@ if(BACKEND STREQUAL "ggml") llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git GIT_TAG b2734 - PATCH_COMMAND git checkout . - COMMAND git apply ${CMAKE_SOURCE_DIR}/thirdparty/ggml/ggml.patch GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 401abbf0..0ccbfad3 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -21,6 +21,31 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML namespace { + +void LlamaLogCallback(ggml_log_level LogLevel, const char *LogText, + void *UserData) { + Graph GraphRef = *static_cast(UserData); + if (!GraphRef.EnableLog) { + return; + } + std::string Text(LogText); + // Remove the trailing newlines. + Text = Text.erase(Text.find_last_not_of("\n") + 1); + // Skip for "." + if (Text == ".") { + return; + } + if (LogLevel == GGML_LOG_LEVEL_ERROR) { + spdlog::error("[WASI-NN] llama.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_WARN) { + spdlog::warn("[WASI-NN] llama.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_INFO) { + spdlog::info("[WASI-NN] llama.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_DEBUG) { + spdlog::debug("[WASI-NN] llama.cpp: {}"sv, Text); + } +} + Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, bool *IsModelUpdated = nullptr) noexcept { simdjson::dom::parser Parser; @@ -72,7 +97,6 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, "[WASI-NN] GGML backend: Unable to retrieve the enable-log option."sv); return ErrNo::InvalidArgument; } - llama_log_set(nullptr, &GraphRef.EnableLog); } if (Doc.at_key("enable-debug-log").error() == simdjson::SUCCESS) { auto Err = Doc["enable-debug-log"].get().get(GraphRef.EnableDebugLog); @@ -652,6 +676,9 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.FrequencyPenalty = SamplingDefault.penalty_freq; GraphRef.Grammar = SamplingDefault.grammar; + // Set llama log callback. + llama_log_set(LlamaLogCallback, &GraphRef); + // If the graph builder length > 1, the data of builder[1] is the metadata. if (Builders.size() > 1) { const std::string Metadata(reinterpret_cast(Builders[1].data()), From 7eb07c6507b399cc872fb04af366882c13dbfd52 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 23 Apr 2024 12:01:38 +0800 Subject: [PATCH 283/623] [Plugin] Use cmake commands for cross-platform support Signed-off-by: Shen-Ta Hsieh --- plugins/wasmedge_rustls/CMakeLists.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/plugins/wasmedge_rustls/CMakeLists.txt b/plugins/wasmedge_rustls/CMakeLists.txt index 56069740..739aea92 100644 --- a/plugins/wasmedge_rustls/CMakeLists.txt +++ b/plugins/wasmedge_rustls/CMakeLists.txt @@ -6,14 +6,14 @@ else() set(TARGET_DIR "release") endif() -set(RS_SO ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR}/libwasmedge_rustls${CMAKE_SHARED_LIBRARY_SUFFIX}) +set(RS_SO ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}wasmedge_rustls${CMAKE_SHARED_LIBRARY_SUFFIX}) set(WASMEDGE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/api) add_custom_target(wasmedge_rustls ALL - COMMAND WASMEDGE_LIB_DIR=${WASMEDGE_LIB_DIR} LD_LIBARAY_PATH=${WASMEDGE_LIB_DIR} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD} - COMMAND cp ${RS_SO} ${CMAKE_CURRENT_BINARY_DIR} - COMMAND rm -rf ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR} + COMMAND ${CMAKE_COMMAND} -E env WASMEDGE_LIB_DIR=${WASMEDGE_LIB_DIR} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} -- ${CARGO_CMD} + COMMAND ${CMAKE_COMMAND} -E copy ${RS_SO} ${CMAKE_CURRENT_BINARY_DIR} + COMMAND ${CMAKE_COMMAND} -E rm -rf ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS wasmedge_shared ) From a9ef5872ecf146873edf573885cd61ecd4b4e615 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 23 Apr 2024 19:25:40 +0800 Subject: [PATCH 284/623] [LOG] Rename common/log.h to common/spdlog.h to to prevent name conflict with llama.cpp Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_crypto/utils/evp_wrapper.h | 2 +- plugins/wasi_nn/types.h | 2 +- plugins/wasi_nn/wasinnenv.h | 2 +- plugins/wasi_nn/wasinnfunc.cpp | 2 +- plugins/wasi_ocr/wasiocrenv.h | 2 +- plugins/wasi_ocr/wasiocrfunc.cpp | 2 +- plugins/wasmedge_image/image_func.cpp | 2 +- plugins/wasmedge_tensorflow/tensorflow_func.cpp | 2 +- plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/plugins/wasi_crypto/utils/evp_wrapper.h b/plugins/wasi_crypto/utils/evp_wrapper.h index 67b4ecfb..d6edc89b 100644 --- a/plugins/wasi_crypto/utils/evp_wrapper.h +++ b/plugins/wasi_crypto/utils/evp_wrapper.h @@ -17,8 +17,8 @@ #include "utils/error.h" #include "utils/secret_vec.h" -#include "common/log.h" #include "common/span.h" +#include "common/spdlog.h" #include #include diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index f14350b4..81a06eee 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -2,8 +2,8 @@ // SPDX-FileCopyrightText: 2019-2022 Second State INC #pragma once -#include "common/log.h" #include "common/span.h" +#include "common/spdlog.h" #include namespace WasmEdge::Host::WASINN { diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index e1e45b2f..260a1659 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -3,7 +3,7 @@ #pragma once -#include "common/log.h" +#include "common/spdlog.h" #include "plugin/plugin.h" #include #include diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 6e609d4f..8466eaa4 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -2,7 +2,7 @@ // SPDX-FileCopyrightText: 2019-2022 Second State INC #include "wasinnfunc.h" -#include "common/log.h" +#include "common/spdlog.h" #include "wasinnenv.h" #include diff --git a/plugins/wasi_ocr/wasiocrenv.h b/plugins/wasi_ocr/wasiocrenv.h index 9e90a5ef..14ec7562 100644 --- a/plugins/wasi_ocr/wasiocrenv.h +++ b/plugins/wasi_ocr/wasiocrenv.h @@ -3,7 +3,7 @@ #pragma once -#include "common/log.h" +#include "common/spdlog.h" #include "plugin/plugin.h" #include diff --git a/plugins/wasi_ocr/wasiocrfunc.cpp b/plugins/wasi_ocr/wasiocrfunc.cpp index a1d17d81..f63bd021 100644 --- a/plugins/wasi_ocr/wasiocrfunc.cpp +++ b/plugins/wasi_ocr/wasiocrfunc.cpp @@ -2,7 +2,7 @@ // SPDX-FileCopyrightText: 2023 Second State INC #include "wasiocrfunc.h" -#include "common/log.h" +#include "common/spdlog.h" #include #include diff --git a/plugins/wasmedge_image/image_func.cpp b/plugins/wasmedge_image/image_func.cpp index 58d62649..0088ebff 100644 --- a/plugins/wasmedge_image/image_func.cpp +++ b/plugins/wasmedge_image/image_func.cpp @@ -3,8 +3,8 @@ #include "image_func.h" -#include "common/log.h" #include "common/span.h" +#include "common/spdlog.h" #include #include diff --git a/plugins/wasmedge_tensorflow/tensorflow_func.cpp b/plugins/wasmedge_tensorflow/tensorflow_func.cpp index eed2b4fc..6b4ba440 100644 --- a/plugins/wasmedge_tensorflow/tensorflow_func.cpp +++ b/plugins/wasmedge_tensorflow/tensorflow_func.cpp @@ -3,8 +3,8 @@ #include "tensorflow_func.h" -#include "common/log.h" #include "common/span.h" +#include "common/spdlog.h" #include "tensorflow/c/c_api.h" diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp index 18af8a1b..0976dd13 100644 --- a/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp @@ -3,8 +3,8 @@ #include "tensorflowlite_func.h" -#include "common/log.h" #include "common/span.h" +#include "common/spdlog.h" #include "tensorflow/lite/c/c_api.h" From 108520017a4eb4a4a7ffddcfc387978e242e79ea Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 23 Apr 2024 11:57:56 +0800 Subject: [PATCH 285/623] [Misc] Use `add_compile_options` for shorter code in setting subproject options * Add utf-8 option for MSVC Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_nn/CMakeLists.txt | 92 +++++++++++++++++------------ test/plugins/wasi_nn/CMakeLists.txt | 7 ++- 2 files changed, 61 insertions(+), 38 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 14768036..424af4ab 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -53,16 +53,42 @@ if(BACKEND STREQUAL "ggml") # setup llama.cpp message(STATUS "Downloading llama.cpp source") - if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + if(MSVC) add_compile_options( + /utf-8 + /wd4067 # unexpected tokens following preprocessor directive - expected a newline + /wd4101 # 'identifier' : unreferenced local variable + /wd4189 # 'identifier' : local variable is initialized but not referenced + /wd4244 # 'argument' : conversion from 'type1' to 'type2', possible loss of data + /wd4267 # 'var' : conversion from 'size_t' to 'type', possible loss of data + /wd4297 # 'function' : function assumed not to throw an exception but does + /wd4456 # declaration of 'identifier' hides previous local declaration + /wd4505 # 'function' : unreferenced local function has been removed + ) + endif() + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU") + add_compile_options( + $<$:-Wno-exceptions> + -Wno-cast-align + -Wno-cast-qual + -Wno-float-conversion + -Wno-implicit-fallthrough + -Wno-unused-macros + -Wno-unused-function + -Wno-unused-variable + ) + elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options( + $<$:-Wno-exceptions> -Wno-cast-align -Wno-cast-qual -Wno-disabled-macro-expansion - -Wno-exceptions -Wno-float-conversion -Wno-implicit-fallthrough -Wno-implicit-float-conversion -Wno-unused-macros + -Wno-unused-function + -Wno-unused-variable ) endif() include(FetchContent) @@ -92,15 +118,7 @@ if(BACKEND STREQUAL "ggml") if(MSVC) if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - get_property( - compile_options - DIRECTORY - PROPERTY COMPILE_OPTIONS - ) - set_property( - DIRECTORY - APPEND - PROPERTY COMPILE_OPTIONS + add_compile_options( -Wno-undef -Wno-suggest-override -Wno-documentation @@ -116,15 +134,11 @@ if(BACKEND STREQUAL "ggml") -Wno-format-nonliteral -Wno-unused-exception-parameter -Wno-unused-member-function - ) - unset(compile_options) + ) elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") - set_property( - DIRECTORY - APPEND - PROPERTY COMPILE_OPTIONS + add_compile_options( /wd4100 # unreferenced formal parameter - ) + ) endif() endif() @@ -165,35 +179,39 @@ target_include_directories(wasmedgePluginWasiNN if(BACKEND STREQUAL "ggml") # Setup llava from llama.cpp wasmedge_add_library(llava OBJECT - ${CMAKE_BINARY_DIR}/_deps/llama-src/examples/llava/clip.cpp - ${CMAKE_BINARY_DIR}/_deps/llama-src/examples/llava/llava.cpp + ${llama_SOURCE_DIR}/examples/llava/clip.cpp + ${llama_SOURCE_DIR}/examples/llava/llava.cpp ) + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") + target_compile_options(llava PRIVATE -Wno-error=unused-variable -Wno-error=unused-function) + endif() target_link_libraries(llava PRIVATE ggml llama) target_include_directories(llava PUBLIC - ${CMAKE_BINARY_DIR}/_deps/llama-src - ${CMAKE_BINARY_DIR}/_deps/llama-src/common - ${CMAKE_BINARY_DIR}/_deps/llama-src/examples/llava - ) - target_compile_options(llava PRIVATE - -Wno-error=unused-function - -Wno-error=unused-variable - -Wno-unused-function - -Wno-unused-variable + ${llama_SOURCE_DIR} + ${llama_SOURCE_DIR}/common + ${llama_SOURCE_DIR}/examples/llava ) - wasmedge_setup_target(llava) - # Setup include and link from llama.cpp - target_include_directories(wasmedgePluginWasiNN PUBLIC - ${CMAKE_BINARY_DIR}/_deps/llama-src - ${CMAKE_BINARY_DIR}/_deps/llama-src/examples/llava + target_include_directories(wasmedgePluginWasiNN PRIVATE + ${llama_SOURCE_DIR} + ${llama_SOURCE_DIR}examples/llava + ) + target_link_libraries(wasmedgePluginWasiNN PRIVATE + common + simdjson::simdjson + llava ) - target_link_libraries(wasmedgePluginWasiNN PRIVATE common simdjson::simdjson llava) + if(MSVC) + target_compile_options(wasmedgePluginWasiNN PUBLIC + /wd4067 # unexpected tokens following preprocessor directive - expected a newline + ) + endif() if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) add_custom_command( TARGET wasmedgePluginWasiNN POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_BINARY_DIR}/_deps/llama-src/ggml-metal.metal ggml-metal.metal - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_BINARY_DIR}/_deps/llama-src/ggml-common.h ggml-common.h + COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml-metal.metal ggml-metal.metal + COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml-common.h ggml-common.h ) endif() endif() diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 3abca194..8674aaf1 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -59,12 +59,17 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) MD5=ad51c39cfe35d2ef35c4052b78cb3c55 ) elseif(BACKEND STREQUAL "ggml") - message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures") + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures") download( https://huggingface.co/TheBloke/orca_mini_v3_7B-GGUF/resolve/main/orca_mini_v3_7b.Q2_K.gguf ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/orca_mini.gguf MD5=f895f00678bfbf89f70d6d25f20a7b5f ) + if(MSVC) + target_compile_options(wasiNNTests PUBLIC + /wd4067 # unexpected tokens following preprocessor directive - expected a newline + ) + endif() else() # Add the other backend test files fetching here. endif() From 56183cb19b2554f60e0c5ca4d6e714afceeaf3df Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 23 Apr 2024 12:01:00 +0800 Subject: [PATCH 286/623] [WASI-NN] Use standard c++ syntax for MSVC compiler Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_nn/wasinnenv.cpp | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 83090a63..b2228ff4 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -143,21 +143,22 @@ void addOptions(const Plugin::Plugin::PluginDescriptor *, #endif } +static Plugin::PluginModule::ModuleDescriptor MD[] = { + { + /* Name */ "wasi_nn", + /* Description */ "", + /* Create */ create, + }, +}; + Plugin::Plugin::PluginDescriptor Descriptor{ - .Name = "wasi_nn", - .Description = "", - .APIVersion = Plugin::Plugin::CurrentAPIVersion, - .Version = {0, 10, 1, 0}, - .ModuleCount = 1, - .ModuleDescriptions = - (Plugin::PluginModule::ModuleDescriptor[]){ - { - .Name = "wasi_nn", - .Description = "", - .Create = create, - }, - }, - .AddOptions = addOptions, + /* Name */ "wasi_nn", + /* Description */ "", + /* APIVersion */ Plugin::Plugin::CurrentAPIVersion, + /* Version */ {0, 10, 1, 0}, + /* ModuleCount */ 1, + /* ModuleDescriptions */ MD, + /* AddOptions */ addOptions, }; } // namespace From c64de2509b53d40a825fd3290b5dae87b02dbf5a Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 23 Apr 2024 12:34:43 +0800 Subject: [PATCH 287/623] [WASI-NN] Add windows build * Fix compiler warning for variable size Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_nn/ggml.cpp | 10 +++++----- test/plugins/wasi_nn/wasi_nn.cpp | 9 ++++----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 0ccbfad3..18057a7e 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -702,10 +702,10 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, } // Handle the model path. auto Weight = Builders[0]; - const std::string BinModel(reinterpret_cast(Weight.data()), - Weight.size()); + const std::string_view BinModel(reinterpret_cast(Weight.data()), + Weight.size()); std::string ModelFilePath; - if (BinModel.substr(0, 8) == "preload:") { + if (BinModel.substr(0, 8) == "preload:"sv) { ModelFilePath = BinModel.substr(8); } else { if (GraphRef.EnableDebugLog) { @@ -716,7 +716,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // TODO: pass the model directly to ggml // Write ggml model to file. ModelFilePath = "ggml-model.bin"sv; - std::ofstream TempFile(ModelFilePath); + std::ofstream TempFile(ModelFilePath, std::ios::out | std::ios::binary); if (!TempFile) { spdlog::error( "[WASI-NN] GGML backend: Failed to create the temporary file. " @@ -726,7 +726,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } - TempFile << BinModel; + TempFile.write(BinModel.data(), BinModel.size()); TempFile.close(); if (GraphRef.EnableDebugLog) { spdlog::info( diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 806c33a3..e8ccdeaf 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -41,11 +41,10 @@ createModule(std::string_view NNRPCURI = "") { } inline std::vector readEntireFile(const std::string &Path) { - std::ifstream Fin(Path, std::ios::binary | std::ios::ate); + std::ifstream Fin(Path, std::ios::in | std::ios::binary | std::ios::ate); if (!Fin) { return {}; } - Fin.seekg(0, std::ios::end); std::vector Buf(static_cast(Fin.tellg())); Fin.seekg(0, std::ios::beg); if (!Fin.read(reinterpret_cast(Buf.data()), @@ -58,7 +57,7 @@ inline std::vector readEntireFile(const std::string &Path) { template void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, - std::vector Binaries, uint32_t Ptr) noexcept { + WasmEdge::Span Binaries, uint32_t Ptr) noexcept { std::copy(Binaries.begin(), Binaries.end(), MemInst.getPointer(Ptr)); } @@ -1364,7 +1363,7 @@ TEST(WasiNNTest, GGMLBackend) { writeFatPointer(MemInst, StorePtr, static_cast(WeightRead.size()), BuilderPtr); writeBinaries(MemInst, WeightRead, StorePtr); - StorePtr += WeightRead.size(); + StorePtr += static_cast(WeightRead.size()); { EXPECT_TRUE(HostFuncLoad.run(CallFrame, std::initializer_list{ @@ -1451,7 +1450,7 @@ TEST(WasiNNTest, GGMLBackend) { Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); } - StorePtr += (TensorDim.size() * 4 + TensorData.size()); + StorePtr += static_cast(TensorDim.size() * 4 + TensorData.size()); // GGML WASI-NN compute tests. // Test: compute -- context id exceeds. From 8293cff66c36e555b9cc5841968b68465962faa5 Mon Sep 17 00:00:00 2001 From: hugo-syn <61210734+hugo-syn@users.noreply.github.com> Date: Tue, 30 Apr 2024 11:38:10 +0200 Subject: [PATCH 288/623] [Misc] chore: fix typos (#3378) Signed-off-by: hugo-syn --- thirdparty/wasi_crypto/api.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thirdparty/wasi_crypto/api.hpp b/thirdparty/wasi_crypto/api.hpp index 8d6f6081..306526b5 100644 --- a/thirdparty/wasi_crypto/api.hpp +++ b/thirdparty/wasi_crypto/api.hpp @@ -400,7 +400,7 @@ static_assert(alignof(__wasi_algorithm_type_e_t) == 2, "witx calculated align"); /** * Version of a managed key. * - * A version can be an arbitrary `u64` integer, with the expection of some reserved values. + * A version can be an arbitrary `u64` integer, with the exception of some reserved values. */ using __wasi_version_t = uint64_t; @@ -544,7 +544,7 @@ static_assert(alignof(__wasi_symmetric_key_t) == 4, "witx calculated align"); * * This object type can't be directly created from raw bytes. They are only returned by functions computing MACs. * - * The host is reponsible for securely wiping them from memory on close. + * The host is responsible for securely wiping them from memory on close. */ using __wasi_symmetric_tag_t = int32_t; From 708555d9ed91c63b674aa958930252787ecf75df Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 1 May 2024 17:15:18 +0800 Subject: [PATCH 289/623] [WASI-NN] ggml: fix missing implemented body (#3379) Signed-off-by: hydai --- plugins/wasi_nn/ggml.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 18057a7e..4430fa66 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -1401,6 +1401,19 @@ Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, Expect compute(WasiNNEnvironment &, uint32_t) noexcept { return reportBackendNotSupported(); } +Expect getOutputSingle(WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect computeSingle(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect finiSingle(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect unload(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} #endif } // namespace WasmEdge::Host::WASINN::GGML From dd45624ed60d16cd0945e4208cce77676256ab31 Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 3 May 2024 16:08:42 +0800 Subject: [PATCH 290/623] [WASI-NN] ggml: support ubatch-size (#3383) Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 53 ++++++++++++++++++++++++---------------- plugins/wasi_nn/ggml.h | 1 + 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 4430fa66..467bde22 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -75,6 +75,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // Context parameters (used by the llama context): // ctx-size: uint64_t // batch-size: uint64_t + // ubatch-size: uint64_t // threads: uint64_t // Sampling parameters (used by the llama sampling context). // temp: double @@ -82,6 +83,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // repeat-penalty: double // presence-penalty: double // frequency-penalty: double + // grammar: string // Get the current llama parameters. llama_model_params ModelParams = llama_model_default_params(); @@ -226,6 +228,14 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } + if (Doc.at_key("ubatch-size").error() == simdjson::SUCCESS) { + auto Err = Doc["ubatch-size"].get().get(GraphRef.UBatchSize); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the ubatch-size option."sv); + return ErrNo::InvalidArgument; + } + } if (Doc.at_key("threads").error() == simdjson::SUCCESS) { auto Err = Doc["threads"].get().get(GraphRef.Threads); if (Err) { @@ -312,6 +322,7 @@ Expect setupContextParam(Graph &GraphRef, llama_context_params &ContextParams) { ContextParams.n_ctx = GraphRef.CtxSize; ContextParams.n_batch = GraphRef.BatchSize; + ContextParams.n_ubatch = GraphRef.UBatchSize; ContextParams.n_threads = GraphRef.Threads; ContextParams.n_threads_batch = GraphRef.Threads; ContextParams.embeddings = GraphRef.Embedding; @@ -478,34 +489,30 @@ Expect getEmbedding(WasiNNEnvironment &Env, // Initialize the llama context. llama_context_params ContextParams = llama_context_default_params(); setupContextParam(GraphRef, ContextParams); + // For non-causal models, batch size must be equal to ubatch size + ContextParams.n_ubatch = ContextParams.n_batch; auto *LlamaContext = llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); - // Get the context size. - const uint64_t NCtx = llama_n_ctx(LlamaContext); - // Minus 4 for the special tokens. (Such as , , ... tokens.) - const uint64_t MaxTokensListSize = NCtx - 4; // Use the const sequence id here. const llama_seq_id SequenceId = 0; // Return value. auto ReturnCode = ErrNo::Success; - // Add BOS if not present. - if (CxtRef.LlamaInputs.front() != llama_token_bos(GraphRef.LlamaModel)) { - CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.begin(), - llama_token_bos(GraphRef.LlamaModel)); - } - // Add EOS if not present. - if (!llama_token_is_eog(GraphRef.LlamaModel, CxtRef.LlamaInputs.back())) { - CxtRef.LlamaInputs.push_back(llama_token_eos(GraphRef.LlamaModel)); + // Add SEP if not present. + if (CxtRef.LlamaInputs.back() != llama_token_sep(GraphRef.LlamaModel)) { + CxtRef.LlamaInputs.push_back(llama_token_sep(GraphRef.LlamaModel)); } // Check if the input is too long. - if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { + if (static_cast(CxtRef.LlamaInputs.size()) > + ContextParams.n_batch) { if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: the prompt is too long. Your input " - "has {} tokens. Please reduce it to {} tokens."sv, - CxtRef.LlamaInputs.size(), MaxTokensListSize); + spdlog::info( + "[WASI-NN] GGML backend: the prompt is too long. " + "Your input has {} tokens exceeds batch size {}. " + "Please reduce the input size or increase your batch-size."sv, + CxtRef.LlamaInputs.size(), ContextParams.n_batch); } return ErrNo::PromptTooLong; } @@ -857,13 +864,15 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: set the input"sv); } - const bool AddBos = llama_should_add_bos_token(GraphRef.LlamaModel); + const bool AddSpecial = true; + const bool ParseSpecial = true; std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); CxtRef.LlamaInputs.clear(); if (GraphRef.MMProjModelPath == ""sv) { // Text only prompt. - CxtRef.LlamaInputs = llama_tokenize(LlamaContext, Prompt, AddBos, true); + CxtRef.LlamaInputs = + llama_tokenize(LlamaContext, Prompt, AddSpecial, ParseSpecial); } else { // Handle llava format prompt. @@ -929,10 +938,12 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, std::string PromptBeforeImage = Prompt.substr(0, PlaceholderPosition); std::string PromptAfterImage = Prompt.substr(PlaceholderPosition + PromptImagePlaceholder.length()); - std::vector EmbdInputBeforeImage = - llama_tokenize(LlamaContext, PromptBeforeImage, AddBos, true); + std::vector EmbdInputBeforeImage = llama_tokenize( + LlamaContext, PromptBeforeImage, AddSpecial, ParseSpecial); + // Do not add special token (such as , , ... tokens.) to the + // tokens after the image. std::vector EmbdInputAfterImage = - llama_tokenize(LlamaContext, PromptAfterImage, false, true); + llama_tokenize(LlamaContext, PromptAfterImage, false, ParseSpecial); CxtRef.LlavaImagePosition = EmbdInputBeforeImage.size(); CxtRef.LlamaInputs.reserve(EmbdInputBeforeImage.size() + EmbdInputAfterImage.size()); diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index e334ca5d..25c02caa 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -38,6 +38,7 @@ struct Graph { // Context parameters: uint64_t CtxSize; uint64_t BatchSize; + uint64_t UBatchSize; uint64_t Threads; // Sampling parameters: double Temp = 0.80; From 415496769ec95727627106e6e4e83c9e1365643c Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 3 May 2024 16:11:28 +0800 Subject: [PATCH 291/623] [WASI-NN] ggml: bump llama.cpp b2781 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 424af4ab..0899b891 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -95,7 +95,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2734 + GIT_TAG b2781 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 1a03679a6cc92d87f30acf770bef453d77f5c08e Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 15 May 2024 12:03:53 +0800 Subject: [PATCH 292/623] [WASI-NN] ggml: add more debug log Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 95 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 92 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 467bde22..70a0f702 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -544,7 +544,7 @@ Expect getEmbedding(WasiNNEnvironment &Env, llama_free(LlamaContext); if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: compute...Done"sv); + spdlog::info("[WASI-NN][Debug] GGML backend: getEmbedding...Done"sv); } return ErrNo::Success; @@ -587,6 +587,10 @@ loadBase64ImageFromPrompt(Graph &GraphRef, clip_ctx *ClipContext, // Follow this link for the supported image formats: // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: loadBase64ImageFromPrompt"sv); + } + // Find ` load(WasiNNEnvironment &Env, Span> Builders, Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: initExecCtx"sv); + } Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); ContextId = Env.NNContext.size() - 1; - auto &GraphRef = Env.NNGraph[GraphId].get(); if (GraphRef.EnableLog) { spdlog::info("[WASI-NN] GGML backend: llama_system_info: {}"sv, llama_print_system_info()); } + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: initExecCtx...Done"sv); + } return ErrNo::Success; } @@ -871,11 +886,21 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, CxtRef.LlamaInputs.clear(); if (GraphRef.MMProjModelPath == ""sv) { // Text only prompt. + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: tokenize text prompt"sv); + } CxtRef.LlamaInputs = llama_tokenize(LlamaContext, Prompt, AddSpecial, ParseSpecial); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: tokenize text prompt...Done"sv); + } } else { // Handle llava format prompt. - + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: handle llava format prompt"sv); + } // Check if the prompt contains a base64 image. bool ContainsBase64Image = containsBase64Image(GraphRef, Prompt); if (GraphRef.ImagePath == ""sv && ContainsBase64Image == false) { @@ -953,6 +978,10 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.end(), EmbdInputAfterImage.begin(), EmbdInputAfterImage.end()); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: handle llava format prompt...Done"sv); + } } CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); if (GraphRef.EnableDebugLog) { @@ -980,6 +1009,11 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t Index, Span OutBuffer, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: getOutput with Index {}"sv, + Index); + } // Index 1 is for the metadata of the outputs. if (Index == 1) { std::string Metadata; @@ -991,12 +1025,22 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, } std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); BytesWritten = Metadata.length(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: getOutput with Index {}...Done"sv, + Index); + } return ErrNo::Success; } std::copy_n(CxtRef.LlamaOutputs.data(), CxtRef.LlamaOutputs.length(), OutBuffer.data()); BytesWritten = CxtRef.LlamaOutputs.length(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: getOutput with Index {}...Done"sv, + Index); + } return ErrNo::Success; } @@ -1149,12 +1193,20 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // We free the contexts here to keep the ggml plugin stateless. // Users could fully control the contexts by themselves via their prompt. + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: delete llama context to make it stateless"sv); + } llama_sampling_free(CtxSampling); llama_free(LlamaContext); if (CxtRef.LlavaImageEmbd != nullptr) { llava_image_embed_free(CxtRef.LlavaImageEmbd); CxtRef.LlavaImageEmbd = nullptr; } + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: delete llama context to make it stateless...Done"sv); + } if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: compute...Done"sv); @@ -1167,6 +1219,12 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t Index, Span OutBuffer, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: getOutputSingle with Index {}"sv, + Index); + } // Index 1 is for the metadata of the outputs. if (Index == 1) { std::string Metadata; @@ -1178,12 +1236,22 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, } std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); BytesWritten = Metadata.length(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: getOutputSingle with Index {}...Done"sv, + Index); + } return ErrNo::Success; } std::string LastToken = llama_token_to_piece(CxtRef.LlamaContext, CxtRef.LlamaOutputTokens.back()); std::copy_n(LastToken.data(), LastToken.length(), OutBuffer.data()); BytesWritten = LastToken.length(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: getOutputSingle with Index {}...Done"sv, + Index); + } return ErrNo::Success; } @@ -1333,6 +1401,10 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: finiSingle"sv); + } + // Logging for the llama timings. if (GraphRef.EnableLog) { llama_print_timings(CxtRef.LlamaContext); @@ -1371,17 +1443,34 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Reset the context variables. CxtRef.LlamaNPast = 0; + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: finiSingle...Done"sv); + } + return ErrNo::Success; } Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: unload"sv); + } if (GraphRef.LlamaModel != nullptr) { + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: unload: free llama model"sv); + } llama_free_model(GraphRef.LlamaModel); GraphRef.LlamaModel = nullptr; + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] GGML backend: unload: free llama model...Done"sv); + } } Env.NNGraph.erase(Env.NNGraph.begin() + GraphId); Env.mdRemoveById(GraphId); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] GGML backend: unload...Done"sv); + } return ErrNo::Success; } From 50cc1b81ad1287d61b2732916f9e2b3d283cd8bd Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 15 May 2024 15:22:08 +0800 Subject: [PATCH 293/623] [WASI-NN] ggml: free the llama batch Signed-off-by: dm4 --- plugins/wasi_nn/ggml.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 70a0f702..e84c0a29 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -542,6 +542,7 @@ Expect getEmbedding(WasiNNEnvironment &Env, // We free the contexts here to keep the ggml plugin stateless. // Users could fully control the contexts by themselves via their prompt. llama_free(LlamaContext); + llama_batch_free(Batch); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: getEmbedding...Done"sv); From 2c4fed02f22cf2459ec675841f64ec5cef78f6ab Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 15 May 2024 15:27:40 +0800 Subject: [PATCH 294/623] [WASI-NN] ggml: bump llama.cpp b2879 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 0899b891..28b5e01b 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -95,7 +95,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2781 + GIT_TAG b2879 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 3c72135f855825b3bb02259f37f71a8b5c12f44b Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Thu, 16 May 2024 16:35:39 +0800 Subject: [PATCH 295/623] [CI] Add aarch64 for `ci-image-base` Signed-off-by: Yi Huang --- utils/docker/Dockerfile.ci-image-base | 2 +- utils/docker/docker-bake.ci-image-base | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 utils/docker/docker-bake.ci-image-base diff --git a/utils/docker/Dockerfile.ci-image-base b/utils/docker/Dockerfile.ci-image-base index 4c7fa368..92f7bd67 100644 --- a/utils/docker/Dockerfile.ci-image-base +++ b/utils/docker/Dockerfile.ci-image-base @@ -13,7 +13,7 @@ RUN apt update && apt upgrade -y \ RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add - RUN add-apt-repository \ - "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ + "deb [arch=$(dpkg --print-architecture)] https://download.docker.com/linux/ubuntu \ $(lsb_release -cs) \ stable" diff --git a/utils/docker/docker-bake.ci-image-base b/utils/docker/docker-bake.ci-image-base new file mode 100644 index 00000000..0054032c --- /dev/null +++ b/utils/docker/docker-bake.ci-image-base @@ -0,0 +1,26 @@ +group "default" { + targets = [ + "x86_64", + "aarch64" + ] +} + +target "base" { + dockerfile = "./utils/docker/Dockerfile.ci-image-base" + context = "." +} + +target "x86_64" { + inherits = ["base"] + platforms = ["linux/amd64"] + tags = [ + "wasmedge/wasmedge:ci-image-base", + "wasmedge/wasmedge:ci-image-base_x86_64" + ] +} + +target "aarch64" { + inherits = ["base"] + platforms = ["linux/arm64"] + tags = ["wasmedge/wasmedge:ci-image-base_aarch64"] +} From 6e7996a1c47c836b866ff684a2c5e5a6eb2724eb Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Mon, 20 May 2024 16:19:55 +0800 Subject: [PATCH 296/623] [WASI-NN] Hide unsupported flags from nvcc Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_nn/CMakeLists.txt | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 28b5e01b..a35418aa 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -55,15 +55,16 @@ if(BACKEND STREQUAL "ggml") message(STATUS "Downloading llama.cpp source") if(MSVC) add_compile_options( - /utf-8 - /wd4067 # unexpected tokens following preprocessor directive - expected a newline - /wd4101 # 'identifier' : unreferenced local variable - /wd4189 # 'identifier' : local variable is initialized but not referenced - /wd4244 # 'argument' : conversion from 'type1' to 'type2', possible loss of data - /wd4267 # 'var' : conversion from 'size_t' to 'type', possible loss of data - /wd4297 # 'function' : function assumed not to throw an exception but does - /wd4456 # declaration of 'identifier' hides previous local declaration - /wd4505 # 'function' : unreferenced local function has been removed + $<$:/utf-8> + $<$:-Xcompiler=/utf-8> + $<$:/wd4067> # unexpected tokens following preprocessor directive - expected a newline + $<$:/wd4101> # 'identifier' : unreferenced local variable + $<$:/wd4189> # 'identifier' : local variable is initialized but not referenced + $<$:/wd4244> # 'argument' : conversion from 'type1' to 'type2', possible loss of data + $<$:/wd4267> # 'var' : conversion from 'size_t' to 'type', possible loss of data + $<$:/wd4297> # 'function' : function assumed not to throw an exception but does + $<$:/wd4456> # declaration of 'identifier' hides previous local declaration + $<$:/wd4505> # 'function' : unreferenced local function has been removed ) endif() if(CMAKE_CXX_COMPILER_ID MATCHES "GNU") From 9c396e7d3c87fd92d65112b255b7ed6e62b98902 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 20 May 2024 16:18:50 +0800 Subject: [PATCH 297/623] [WASI-NN] ggml: bump llama.cpp b2943 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index a35418aa..285a727f 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -96,7 +96,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2879 + GIT_TAG b2943 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From c45ae3b7a931b4ac84ec58b7628d706408ac346b Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 20 May 2024 17:48:15 +0800 Subject: [PATCH 298/623] [Plugin] Refactor the wasi-logging plugin. 1. Rename the files. 2. Add the file sink. 3. Refine the architecture. Signed-off-by: YiYing He --- plugins/wasi_logging/base.h | 28 ++++++ plugins/wasi_logging/env.cpp | 9 +- plugins/wasi_logging/env.h | 36 +++++++ plugins/wasi_logging/func.cpp | 96 +++++++++++-------- .../wasi_logging/{wasi_logging => }/func.h | 11 ++- plugins/wasi_logging/module.cpp | 12 ++- .../wasi_logging/{wasi_logging => }/module.h | 10 +- plugins/wasi_logging/wasi_logging/base.h | 22 ----- plugins/wasi_logging/wasi_logging/enum.h | 13 --- plugins/wasi_logging/wasi_logging/env.h | 22 ----- test/plugins/wasi_logging/wasi_logging.cpp | 18 ++-- 11 files changed, 158 insertions(+), 119 deletions(-) create mode 100644 plugins/wasi_logging/base.h create mode 100644 plugins/wasi_logging/env.h rename plugins/wasi_logging/{wasi_logging => }/func.h (55%) rename plugins/wasi_logging/{wasi_logging => }/module.h (57%) delete mode 100644 plugins/wasi_logging/wasi_logging/base.h delete mode 100644 plugins/wasi_logging/wasi_logging/enum.h delete mode 100644 plugins/wasi_logging/wasi_logging/env.h diff --git a/plugins/wasi_logging/base.h b/plugins/wasi_logging/base.h new file mode 100644 index 00000000..dc1c51ca --- /dev/null +++ b/plugins/wasi_logging/base.h @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "env.h" + +#include "common/errcode.h" +#include "runtime/callingframe.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { +namespace WASILogging { + +enum class LogLevel : uint32_t { Trace, Debug, Info, Warn, Error, Critical }; + +template class Func : public Runtime::HostFunction { +public: + Func(LogEnv &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + LogEnv &Env; +}; + +} // namespace WASILogging +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_logging/env.cpp b/plugins/wasi_logging/env.cpp index 0e5160a1..97f2b141 100644 --- a/plugins/wasi_logging/env.cpp +++ b/plugins/wasi_logging/env.cpp @@ -1,5 +1,8 @@ -#include "wasi_logging/env.h" -#include "wasi_logging/module.h" +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "env.h" +#include "module.h" namespace WasmEdge { namespace Host { @@ -32,4 +35,4 @@ EXPORT_GET_DESCRIPTOR(Descriptor) } // namespace } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasi_logging/env.h b/plugins/wasi_logging/env.h new file mode 100644 index 00000000..838c7850 --- /dev/null +++ b/plugins/wasi_logging/env.h @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "plugin/plugin.h" + +#include +#include + +namespace WasmEdge { +namespace Host { +namespace WASILogging { + +class LogEnv { +public: + LogEnv() noexcept { + // TODO: Use the config in WasmEdge to set the logging level. + StdoutLogger->set_level(spdlog::level::trace); + StderrLogger->set_level(spdlog::level::trace); + StdoutLogger->set_pattern(DefFormat); + StderrLogger->set_pattern(DefFormat); + } + + const std::shared_ptr StdoutLogger = + spdlog::stdout_color_mt("wasi_logging_stdout"); + const std::shared_ptr StderrLogger = + spdlog::stderr_color_mt("wasi_logging_stderr"); + const std::string DefFormat = "[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v"; + std::shared_ptr FileLogger; + std::string LogFileName; +}; + +} // namespace WASILogging +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_logging/func.cpp b/plugins/wasi_logging/func.cpp index 5f0a2c01..9d44d4c3 100644 --- a/plugins/wasi_logging/func.cpp +++ b/plugins/wasi_logging/func.cpp @@ -1,85 +1,99 @@ -#include "wasi_logging/func.h" -#include "wasi_logging/enum.h" +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "func.h" + #include namespace WasmEdge { namespace Host { +namespace WASILogging { using namespace std::literals; -Expect WasiLoggingLog::body(const Runtime::CallingFrame &Frame, - uint32_t Level, uint32_t CxtPtr, - uint32_t CxtLen, uint32_t MsgPtr, - uint32_t MsgLen) { +Expect Log::body(const Runtime::CallingFrame &Frame, uint32_t Level, + uint32_t CxtPtr, uint32_t CxtLen, uint32_t MsgPtr, + uint32_t MsgLen) { // Check memory instance from module. auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } - // Get Buffer Pointer + // Get Buffer Pointer. char *CxtBuf = MemInst->getPointer(CxtPtr); char *MsgBuf = MemInst->getPointer(MsgPtr); if (CxtBuf == nullptr || MsgBuf == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } - // Copy Context String and Message String - std::string CxtStr, MsgStr; - std::copy_n(CxtBuf, CxtLen, std::back_inserter(CxtStr)); - std::copy_n(MsgBuf, MsgLen, std::back_inserter(MsgStr)); + // Get Context and Message string_view + std::string_view CxtSV(CxtBuf, CxtLen); + std::string_view MsgSV(MsgBuf, MsgLen); // Setup Logger for Stdout or Stderr - CxtStr == "stderr"sv ? Env.isCxtStrStderr = true : Env.isCxtStrStderr = false; - auto logger = Env.isCxtStrStderr ? Env.StderrLogger : Env.StdoutLogger; - - // Construct Spdlog Message - std::string SpdlogMsg; - if (!CxtStr.empty()) { - SpdlogMsg = CxtStr + ": " + MsgStr; + std::shared_ptr Logger; + if (CxtSV == "stdout"sv || CxtSV == ""sv) { + Logger = Env.StdoutLogger; + } else if (CxtSV == "stderr"sv) { + Logger = Env.StderrLogger; } else { - SpdlogMsg = MsgStr; + if (CxtSV != Env.LogFileName) { + try { + spdlog::drop("wasi_logging_file"); + Env.FileLogger = + spdlog::basic_logger_mt("wasi_logging_file", std::string(CxtSV)); + Env.FileLogger->set_pattern(Env.DefFormat); + Env.LogFileName = CxtSV; + // TODO: Use the config in WasmEdge to set the logging level. + Env.FileLogger->set_level(spdlog::level::trace); + } catch (const spdlog::spdlog_ex &Ex) { + spdlog::error("[WasiLogging] Cannot log into file: {}"sv, Ex.what()); + return Unexpect(ErrCode::Value::HostFuncError); + } + } + Logger = Env.FileLogger; } // Print Message by Logging Level - switch (Level) { - case WASILOGGING::WasiLoggingLevel::Trace: - logger->trace(SpdlogMsg); + switch (static_cast(Level)) { + case LogLevel::Trace: + Logger->trace(MsgSV); break; - case WASILOGGING::WasiLoggingLevel::Debug: - logger->debug(SpdlogMsg); + case LogLevel::Debug: + Logger->debug(MsgSV); break; - case WASILOGGING::WasiLoggingLevel::Info: - logger->info(SpdlogMsg); + case LogLevel::Info: + Logger->info(MsgSV); break; - case WASILOGGING::WasiLoggingLevel::Warn: - logger->warn(SpdlogMsg); + case LogLevel::Warn: + Logger->warn(MsgSV); break; - case WASILOGGING::WasiLoggingLevel::Error: - logger->error(SpdlogMsg); + case LogLevel::Error: + Logger->error(MsgSV); break; - case WASILOGGING::WasiLoggingLevel::Critical: - logger->critical(SpdlogMsg); + case LogLevel::Critical: + Logger->critical(MsgSV); break; default: spdlog::error("[WasiLogging] Unrecognized Logging Level: {}"sv, Level); spdlog::error("[WasiLogging] Trace Level = {}"sv, - static_cast(WASILOGGING::WasiLoggingLevel::Trace)); + static_cast(LogLevel::Trace)); spdlog::error("[WasiLogging] Debug Level = {}"sv, - static_cast(WASILOGGING::WasiLoggingLevel::Debug)); + static_cast(LogLevel::Debug)); spdlog::error("[WasiLogging] Info Level = {}"sv, - static_cast(WASILOGGING::WasiLoggingLevel::Info)); + static_cast(LogLevel::Info)); spdlog::error("[WasiLogging] Warn Level = {}"sv, - static_cast(WASILOGGING::WasiLoggingLevel::Warn)); + static_cast(LogLevel::Warn)); spdlog::error("[WasiLogging] Error Level = {}"sv, - static_cast(WASILOGGING::WasiLoggingLevel::Error)); - spdlog::error( - "[WasiLogging] Critical Level = {}"sv, - static_cast(WASILOGGING::WasiLoggingLevel::Critical)); + static_cast(LogLevel::Error)); + spdlog::error("[WasiLogging] Critical Level = {}"sv, + static_cast(LogLevel::Critical)); return Unexpect(ErrCode::Value::HostFuncError); } return {}; } +} // namespace WASILogging } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasi_logging/wasi_logging/func.h b/plugins/wasi_logging/func.h similarity index 55% rename from plugins/wasi_logging/wasi_logging/func.h rename to plugins/wasi_logging/func.h index f7631e98..60e81c9d 100644 --- a/plugins/wasi_logging/wasi_logging/func.h +++ b/plugins/wasi_logging/func.h @@ -1,17 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once -#include "wasi_logging/base.h" +#include "base.h" namespace WasmEdge { namespace Host { +namespace WASILogging { -class WasiLoggingLog : public WasiLogging { +class Log : public Func { public: - WasiLoggingLog(WasiLoggingEnvironment &HostEnv) : WasiLogging(HostEnv) {} + Log(LogEnv &HostEnv) : Func(HostEnv) {} Expect body(const Runtime::CallingFrame &Frame, uint32_t Level, uint32_t CxtPtr, uint32_t CxtLen, uint32_t MsgPtr, uint32_t MsgLen); }; +} // namespace WASILogging } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_logging/module.cpp b/plugins/wasi_logging/module.cpp index 6ebd7879..938ecc20 100644 --- a/plugins/wasi_logging/module.cpp +++ b/plugins/wasi_logging/module.cpp @@ -1,5 +1,9 @@ -#include "wasi_logging/module.h" -#include "wasi_logging/func.h" +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "func.h" + #include namespace WasmEdge { @@ -9,8 +13,8 @@ using namespace std::literals; WasiLoggingModule::WasiLoggingModule() : ModuleInstance("wasi:logging/logging"sv) { - addHostFunc("log"sv, std::make_unique(Env)); + addHostFunc("log"sv, std::make_unique(Env)); } } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasi_logging/wasi_logging/module.h b/plugins/wasi_logging/module.h similarity index 57% rename from plugins/wasi_logging/wasi_logging/module.h rename to plugins/wasi_logging/module.h index 12504522..9cf1260e 100644 --- a/plugins/wasi_logging/wasi_logging/module.h +++ b/plugins/wasi_logging/module.h @@ -1,7 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once +#include "env.h" + #include "runtime/instance/module.h" -#include "wasi_logging/env.h" namespace WasmEdge { namespace Host { @@ -10,10 +14,10 @@ class WasiLoggingModule : public Runtime::Instance::ModuleInstance { public: WasiLoggingModule(); - WasiLoggingEnvironment &getEnv() { return Env; } + WASILogging::LogEnv &getEnv() { return Env; } private: - WasiLoggingEnvironment Env; + WASILogging::LogEnv Env; }; } // namespace Host diff --git a/plugins/wasi_logging/wasi_logging/base.h b/plugins/wasi_logging/wasi_logging/base.h deleted file mode 100644 index 11b5504a..00000000 --- a/plugins/wasi_logging/wasi_logging/base.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#include "wasi_logging/env.h" - -#include "common/errcode.h" -#include "runtime/callingframe.h" -#include "runtime/hostfunc.h" - -namespace WasmEdge { -namespace Host { - -template class WasiLogging : public Runtime::HostFunction { -public: - WasiLogging(WasiLoggingEnvironment &HostEnv) - : Runtime::HostFunction(0), Env(HostEnv) {} - -protected: - WasiLoggingEnvironment &Env; -}; - -} // namespace Host -} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasi_logging/wasi_logging/enum.h b/plugins/wasi_logging/wasi_logging/enum.h deleted file mode 100644 index 5c4e2e25..00000000 --- a/plugins/wasi_logging/wasi_logging/enum.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - -#include - -namespace WasmEdge { -namespace Host { -namespace WASILOGGING { - -enum WasiLoggingLevel : uint32_t { Trace, Debug, Info, Warn, Error, Critical }; - -} // namespace WASILOGGING -} // namespace Host -} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasi_logging/wasi_logging/env.h b/plugins/wasi_logging/wasi_logging/env.h deleted file mode 100644 index 2c199f96..00000000 --- a/plugins/wasi_logging/wasi_logging/env.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#include "plugin/plugin.h" -#include -namespace WasmEdge { -namespace Host { - -class WasiLoggingEnvironment { -public: - WasiLoggingEnvironment() noexcept { - StdoutLogger->set_level(spdlog::level::trace); - StderrLogger->set_level(spdlog::level::trace); - } - bool isCxtStrStderr = false; - inline const static std::shared_ptr StdoutLogger = - spdlog::stdout_color_mt("wasi_logging_stdout"); - inline const static std::shared_ptr StderrLogger = - spdlog::stderr_color_mt("wasi_logging_stderr"); -}; - -} // namespace Host -} // namespace WasmEdge diff --git a/test/plugins/wasi_logging/wasi_logging.cpp b/test/plugins/wasi_logging/wasi_logging.cpp index a3720f2c..471b7297 100644 --- a/test/plugins/wasi_logging/wasi_logging.cpp +++ b/test/plugins/wasi_logging/wasi_logging.cpp @@ -1,12 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "func.h" +#include "module.h" + #include "common/defines.h" #include "runtime/instance/module.h" -#include "wasi_logging/func.h" -#include "wasi_logging/module.h" #include -#include #include #include + +#include #include namespace { @@ -56,7 +61,7 @@ TEST(WasiLoggingTests, func_log) { // Clear the memory[0, 32]. fillMemContent(MemInst, 0, 32); // Set strings in memory - fillMemContent(MemInst, 0, std::string("CxtStr")); + fillMemContent(MemInst, 0, std::string("stdout")); fillMemContent(MemInst, 8, std::string("stderr")); fillMemContent(MemInst, 16, std::string("MsgStr")); @@ -65,7 +70,7 @@ TEST(WasiLoggingTests, func_log) { EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); auto &HostFuncInst = - dynamic_cast(FuncInst->getHostFunc()); + dynamic_cast(FuncInst->getHostFunc()); // Show All Level EXPECT_TRUE(HostFuncInst.run( @@ -99,7 +104,6 @@ TEST(WasiLoggingTests, func_log) { std::initializer_list{ UINT32_C(5), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, {})); - EXPECT_FALSE(WasiLoggingMod->getEnv().isCxtStrStderr); // Stderr Context EXPECT_TRUE(HostFuncInst.run( @@ -107,7 +111,6 @@ TEST(WasiLoggingTests, func_log) { std::initializer_list{ UINT32_C(0), UINT32_C(8), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, {})); - EXPECT_TRUE(WasiLoggingMod->getEnv().isCxtStrStderr); // UnKnown Level EXPECT_FALSE(HostFuncInst.run( @@ -115,7 +118,6 @@ TEST(WasiLoggingTests, func_log) { std::initializer_list{ UINT32_C(6), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, {})); - EXPECT_FALSE(WasiLoggingMod->getEnv().isCxtStrStderr); delete WasiLoggingMod; } From 9a2c58f867a94afdb7c99fe8dd30b9ad76c01829 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 22 May 2024 15:09:11 +0800 Subject: [PATCH 299/623] [WASI-NN] ggml: bump llama.cpp b2961 (#3418) Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 285a727f..e66acd56 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -96,7 +96,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2943 + GIT_TAG b2961 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 009f20b25a1a9e9543482f242c1dca3952e824fe Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 22 May 2024 17:57:39 +0800 Subject: [PATCH 300/623] [WASI-NN] ggml: bump llama.cpp b2963 (#3419) Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index e66acd56..45236ef3 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -96,7 +96,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2961 + GIT_TAG b2963 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 6e27ef61b501dd9b040311eecde3aa72bc8bf070 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Wed, 22 May 2024 02:08:41 +0800 Subject: [PATCH 301/623] [WASI-Crypto] Fix compile error on gcc 13 Signed-off-by: Yi Huang --- plugins/wasi_crypto/utils/handles_manager.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_crypto/utils/handles_manager.h b/plugins/wasi_crypto/utils/handles_manager.h index ba01f157..2282a283 100644 --- a/plugins/wasi_crypto/utils/handles_manager.h +++ b/plugins/wasi_crypto/utils/handles_manager.h @@ -48,7 +48,8 @@ template class BaseHandlesManager { BaseHandlesManager &operator=(BaseHandlesManager &&) noexcept = delete; /// @param TypeID A unique number - BaseHandlesManager(uint8_t TypeID) noexcept : LastHandle{TypeID, 0} {} + explicit BaseHandlesManager(uint8_t TypeID) noexcept + : LastHandle{TypeID, 0} {} WasiCryptoExpect close(HandleType Handle) noexcept { std::unique_lock Lock{Mutex}; @@ -140,6 +141,8 @@ class RcHandlesManager ManagerType>::HandleWrapper; public: + using detail::BaseHandlesManager::BaseHandlesManager; + /// Get the return copy. WasiCryptoExpect get(HandleType Handle) noexcept { std::shared_lock Lock{this->Mutex}; @@ -184,6 +187,8 @@ class RefHandlesManager ManagerType>::HandleWrapper; public: + using detail::BaseHandlesManager::BaseHandlesManager; + /// Get the return reference. WasiCryptoExpect> get(HandleType Handle) noexcept { From b2ef349407c27f87758857c21c412213a0e4c6f5 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Wed, 22 May 2024 18:07:55 +0800 Subject: [PATCH 302/623] [Misc] Remove dup files Signed-off-by: Yi Huang --- utils/docker/Dockerfile.build-plugins-deps | 2 +- .../Dockerfile.manylinux2014-build-plugins-deps | 2 +- .../Dockerfile.manylinux_2_28-build-plugins-deps | 2 +- utils/docker/build.sh | 2 +- utils/docker/install-opencvmini.sh | 16 ---------------- utils/opencvmini/install-opencvmini.sh | 5 +++-- 6 files changed, 7 insertions(+), 22 deletions(-) delete mode 100644 utils/docker/install-opencvmini.sh diff --git a/utils/docker/Dockerfile.build-plugins-deps b/utils/docker/Dockerfile.build-plugins-deps index 21c73d97..ae797d66 100644 --- a/utils/docker/Dockerfile.build-plugins-deps +++ b/utils/docker/Dockerfile.build-plugins-deps @@ -7,6 +7,6 @@ RUN apt update && apt install -y \ RUN rm -rf /var/lib/apt/lists/* -COPY install-opencvmini.sh . +COPY opencvmini/install-opencvmini.sh . ENV OPENCV_VERSION=4.8.0 RUN [ "/bin/bash", "install-opencvmini.sh" ] diff --git a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps index 4762b087..359f4451 100644 --- a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps @@ -13,7 +13,7 @@ RUN yum-config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.rep WORKDIR /root -COPY docker/install-opencvmini.sh . +COPY opencvmini/install-opencvmini.sh . ENV OPENCV_VERSION "4.8.0" RUN [ "/bin/bash", "install-opencvmini.sh" ] diff --git a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps index c8ed5dbf..4be4bf97 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps @@ -14,7 +14,7 @@ RUN yum install -y yum-utils && \ WORKDIR /root -COPY docker/install-opencvmini.sh . +COPY opencvmini/install-opencvmini.sh . ENV OPENCV_VERSION "4.8.0" RUN [ "/bin/bash", "install-opencvmini.sh" ] diff --git a/utils/docker/build.sh b/utils/docker/build.sh index 3d99a86f..bdcea629 100755 --- a/utils/docker/build.sh +++ b/utils/docker/build.sh @@ -15,7 +15,7 @@ function docker_build local NAME_TAG=${NAME}:${TAG} echo "Building docker image \"${NAME_TAG}\" from file \"${FILENAME}\"." - ( set -x; docker build "$@" -f "${FILENAME}" -t "${NAME_TAG}" . ) + ( set -x; docker build "$@" -f "docker/${FILENAME}" -t "${NAME_TAG}" . ) if [[ "${TAG}" == im-* ]]; then INTERMEDIATES+=( "${NAME_TAG}" ) diff --git a/utils/docker/install-opencvmini.sh b/utils/docker/install-opencvmini.sh deleted file mode 100644 index f66ef586..00000000 --- a/utils/docker/install-opencvmini.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env bash -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2023 Second State INC - -wget -O opencv.zip https://github.com/opencv/opencv/archive/refs/tags/${OPENCV_VERSION}.zip - -unzip opencv.zip -mv opencv-${OPENCV_VERSION} opencv - -mkdir -p opencv/build && cd opencv/build -# Configure -cmake -GNinja .. -# Build -cmake --build . -# Install to system -cmake --install . diff --git a/utils/opencvmini/install-opencvmini.sh b/utils/opencvmini/install-opencvmini.sh index ac5128ef..bd95a902 100644 --- a/utils/opencvmini/install-opencvmini.sh +++ b/utils/opencvmini/install-opencvmini.sh @@ -1,11 +1,12 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2023 Second State INC +OPENCV_VERSION=${OPENCV_VERSION:-4.8.0} -wget -O opencv.zip https://github.com/opencv/opencv/archive/refs/tags/4.8.0.zip +wget -O opencv.zip https://github.com/opencv/opencv/archive/refs/tags/${OPENCV_VERSION}.zip unzip opencv.zip -mv opencv-4.8.0 opencv +mv opencv-${OPENCV_VERSION} opencv mkdir -p opencv/build && cd opencv/build # Configure From ffc516145f549db7fda9990cf702a34e12a4d001 Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Mon, 27 May 2024 18:23:22 +0800 Subject: [PATCH 303/623] [CMake] Merge and move the simdjson fetching. (#3426) Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 45 +--------------------------------- 1 file changed, 1 insertion(+), 44 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 45236ef3..f39b65b3 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -104,50 +104,7 @@ if(BACKEND STREQUAL "ggml") set_property(TARGET common PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET llama PROPERTY POSITION_INDEPENDENT_CODE ON) - # setup simdjson - find_package(simdjson QUIET) - if(simdjson_FOUND) - message(STATUS "SIMDJSON found") - else() - message(STATUS "Downloading SIMDJSON source") - include(FetchContent) - FetchContent_Declare( - simdjson - GIT_REPOSITORY https://github.com/simdjson/simdjson.git - GIT_TAG tags/v3.9.1 - GIT_SHALLOW TRUE) - - if(MSVC) - if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - add_compile_options( - -Wno-undef - -Wno-suggest-override - -Wno-documentation - -Wno-sign-conversion - -Wno-extra-semi-stmt - -Wno-old-style-cast - -Wno-error=unused-parameter - -Wno-error=unused-template - -Wno-conditional-uninitialized - -Wno-implicit-int-conversion - -Wno-shorten-64-to-32 - -Wno-range-loop-bind-reference - -Wno-format-nonliteral - -Wno-unused-exception-parameter - -Wno-unused-member-function - ) - elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") - add_compile_options( - /wd4100 # unreferenced formal parameter - ) - endif() - endif() - - FetchContent_MakeAvailable(simdjson) - set_property(TARGET simdjson PROPERTY POSITION_INDEPENDENT_CODE ON) - - message(STATUS "Downloading SIMDJSON source -- done") - endif() + wasmedge_setup_simdjson() endif() wasmedge_add_library(wasmedgePluginWasiNN From 3333beeb569e5129e41732161703119198370957 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 28 May 2024 12:17:16 +0800 Subject: [PATCH 304/623] [WASI-NN] ggml: bump llama.cpp b3014 (#3434) Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index f39b65b3..092fb5a4 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -96,7 +96,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b2963 + GIT_TAG b3014 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From f3bcc790b059838118085bb136632da8699bdf78 Mon Sep 17 00:00:00 2001 From: "Shen-Ta Hsieh(BestSteve)" Date: Thu, 30 May 2024 01:09:32 +0800 Subject: [PATCH 305/623] [Misc] Upgrade LLVM to 17.0.6, cmake to 3.29.3 (#3278) Signed-off-by: Shen-Ta Hsieh --- utils/docker/Dockerfile.manylinux2014_aarch64 | 42 +++++++++---------- utils/docker/Dockerfile.manylinux2014_x86_64 | 42 +++++++++---------- utils/docker/Dockerfile.ubuntu2104_armv7l | 8 ++-- utils/docker/SHA256SUM.manylinux2014 | 13 +++--- 4 files changed, 53 insertions(+), 52 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index 72c5a1a4..f2ef726e 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -19,38 +19,38 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil export CFGFLAGS="--prefix=/opt/rh/devtoolset-10/root/usr --disable-shared --libdir=/opt/rh/devtoolset-10/root/usr/lib64" && \ curl -s -L -O --remote-name-all \ https://github.com/facebook/zstd/releases/download/v1.5.5/zstd-1.5.5.tar.gz \ - https://github.com/Kitware/CMake/releases/download/v3.26.4/cmake-3.26.4.tar.gz \ + https://github.com/Kitware/CMake/releases/download/v3.29.3/cmake-3.29.3.tar.gz \ https://github.com/ninja-build/ninja/archive/refs/tags/v1.11.1.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/llvm-16.0.5.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/lld-16.0.5.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/libunwind-16.0.5.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/cmake-16.0.5.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/third-party-16.0.5.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/clang-16.0.5.src.tar.xz && \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/llvm-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/lld-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/libunwind-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/cmake-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/third-party-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/clang-17.0.6.src.tar.xz && \ sha256sum -c SHA256SUM.manylinux2014 && \ gzip -dc zstd-1.5.5.tar.gz | tar -xf - && \ - gzip -dc cmake-3.26.4.tar.gz | tar -xf - && \ + gzip -dc cmake-3.29.3.tar.gz | tar -xf - && \ gzip -dc v1.11.1.tar.gz | tar -xf - && \ - xz -dc llvm-16.0.5.src.tar.xz | tar -xf - && \ - xz -dc lld-16.0.5.src.tar.xz | tar -xf - && \ - xz -dc libunwind-16.0.5.src.tar.xz | tar -xf - && \ - xz -dc cmake-16.0.5.src.tar.xz | tar -xf - && \ - xz -dc third-party-16.0.5.src.tar.xz | tar -xf - && \ - xz -dc clang-16.0.5.src.tar.xz | tar -xf - && \ + xz -dc llvm-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc lld-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc libunwind-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc cmake-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc third-party-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc clang-17.0.6.src.tar.xz | tar -xf - && \ export ZSTDFLAGS=(PREFIX=/opt/rh/devtoolset-10/root/usr LIBDIR=/opt/rh/devtoolset-10/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ cd zstd-1.5.5 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-10/root/usr/lib64/libzstd.so* && cd - && \ mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ ../ninja-1.11.1/configure.py --bootstrap \ --with-python=/opt/python/cp311-cp311/bin/python && \ cp -v ninja /opt/rh/devtoolset-10/root/usr/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.26.4/configure --prefix=/opt/rh/devtoolset-10/root/usr \ + mkdir build && cd build && ../cmake-3.29.3/configure --prefix=/opt/rh/devtoolset-10/root/usr \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v llvm-16.0.5.src llvm && \ - mv -v lld-16.0.5.src lld && \ - mv -v libunwind-16.0.5.src libunwind && \ - mv -v cmake-16.0.5.src cmake && \ - mv -v third-party-16.0.5.src third-party && \ - mv -v clang-16.0.5.src clang && \ + mv -v llvm-17.0.6.src llvm && \ + mv -v lld-17.0.6.src lld && \ + mv -v libunwind-17.0.6.src libunwind && \ + mv -v cmake-17.0.6.src cmake && \ + mv -v third-party-17.0.6.src third-party && \ + mv -v clang-17.0.6.src clang && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/opt/rh/devtoolset-10/root/usr \ -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index 9555dbbf..924ff57c 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -19,38 +19,38 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil export CFGFLAGS="--prefix=/opt/rh/devtoolset-11/root/usr --disable-shared --libdir=/opt/rh/devtoolset-11/root/usr/lib64" && \ curl -s -L -O --remote-name-all \ https://github.com/facebook/zstd/releases/download/v1.5.5/zstd-1.5.5.tar.gz \ - https://github.com/Kitware/CMake/releases/download/v3.26.4/cmake-3.26.4.tar.gz \ + https://github.com/Kitware/CMake/releases/download/v3.29.3/cmake-3.29.3.tar.gz \ https://github.com/ninja-build/ninja/archive/refs/tags/v1.11.1.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/llvm-16.0.5.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/lld-16.0.5.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/libunwind-16.0.5.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/cmake-16.0.5.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/third-party-16.0.5.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-16.0.5/clang-16.0.5.src.tar.xz && \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/llvm-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/lld-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/libunwind-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/cmake-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/third-party-17.0.6.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/clang-17.0.6.src.tar.xz && \ sha256sum -c SHA256SUM.manylinux2014 && \ gzip -dc zstd-1.5.5.tar.gz | tar -xf - && \ - gzip -dc cmake-3.26.4.tar.gz | tar -xf - && \ + gzip -dc cmake-3.29.3.tar.gz | tar -xf - && \ gzip -dc v1.11.1.tar.gz | tar -xf - && \ - xz -dc llvm-16.0.5.src.tar.xz | tar -xf - && \ - xz -dc lld-16.0.5.src.tar.xz | tar -xf - && \ - xz -dc libunwind-16.0.5.src.tar.xz | tar -xf - && \ - xz -dc cmake-16.0.5.src.tar.xz | tar -xf - && \ - xz -dc third-party-16.0.5.src.tar.xz | tar -xf - && \ - xz -dc clang-16.0.5.src.tar.xz | tar -xf - && \ + xz -dc llvm-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc lld-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc libunwind-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc cmake-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc third-party-17.0.6.src.tar.xz | tar -xf - && \ + xz -dc clang-17.0.6.src.tar.xz | tar -xf - && \ export ZSTDFLAGS=(PREFIX=/opt/rh/devtoolset-11/root/usr LIBDIR=/opt/rh/devtoolset-11/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ cd zstd-1.5.5 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-11/root/usr/lib64/libzstd.so* && cd - && \ mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ ../ninja-1.11.1/configure.py --bootstrap \ --with-python=/opt/python/cp311-cp311/bin/python && \ cp -v ninja /opt/rh/devtoolset-11/root/usr/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.26.4/configure --prefix=/opt/rh/devtoolset-11/root/usr \ + mkdir build && cd build && ../cmake-3.29.3/configure --prefix=/opt/rh/devtoolset-11/root/usr \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v llvm-16.0.5.src llvm && \ - mv -v lld-16.0.5.src lld && \ - mv -v libunwind-16.0.5.src libunwind && \ - mv -v cmake-16.0.5.src cmake && \ - mv -v third-party-16.0.5.src third-party && \ - mv -v clang-16.0.5.src clang && \ + mv -v llvm-17.0.6.src llvm && \ + mv -v lld-17.0.6.src lld && \ + mv -v libunwind-17.0.6.src libunwind && \ + mv -v cmake-17.0.6.src cmake && \ + mv -v third-party-17.0.6.src third-party && \ + mv -v clang-17.0.6.src clang && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/opt/rh/devtoolset-11/root/usr \ -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ diff --git a/utils/docker/Dockerfile.ubuntu2104_armv7l b/utils/docker/Dockerfile.ubuntu2104_armv7l index 28e7c94c..d4984a11 100644 --- a/utils/docker/Dockerfile.ubuntu2104_armv7l +++ b/utils/docker/Dockerfile.ubuntu2104_armv7l @@ -30,12 +30,12 @@ RUN apt update && apt upgrade -y \ # CMake build from source to avoid compiler_id_detection fails when using QEMU user-mode emulation # See: https://gitlab.kitware.com/cmake/cmake/-/issues/20568 -#RUN wget https://github.com/Kitware/CMake/releases/download/v3.21.1/cmake-3.21.1.tar.gz --no-check-certificate && \ -# tar zxvf cmake-3.21.1.tar.gz && \ -# cd cmake-3.21.1 && \ +#RUN wget https://github.com/Kitware/CMake/releases/download/v3.29.3/cmake-3.29.3.tar.gz --no-check-certificate && \ +# tar zxvf cmake-3.29.3.tar.gz && \ +# cd cmake-3.29.3 && \ # ./configure && \ # make install -j$(nproc) && \ -# cd .. && rm -rf cmake-3.21.1 +# cd .. && rm -rf cmake-3.29.3 RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/SHA256SUM.manylinux2014 b/utils/docker/SHA256SUM.manylinux2014 index 903268be..2e48695e 100644 --- a/utils/docker/SHA256SUM.manylinux2014 +++ b/utils/docker/SHA256SUM.manylinux2014 @@ -1,8 +1,9 @@ -9400d49acd53a4b8f310de60554a891436db5a19f6f227f99f0de13e4afaaaff cmake-16.0.5.src.tar.xz -e7f65970298a60e9608a9fc55ea9af5e9c8e1bc0dc0067f3e9f10eb3fe3e8986 libunwind-16.0.5.src.tar.xz -0c593d1c23f626dc33caa8bf112868f77126e018b58dd1641f5ae6aa1c2a0ce3 lld-16.0.5.src.tar.xz -701b764a182d8ea8fb017b6b5f7f5f1272a29f17c339b838f48de894ffdd4f91 llvm-16.0.5.src.tar.xz -0a4bbb8505e95570e529d6b3d5176e93beb3260f061de9001e320d57b59aed59 third-party-16.0.5.src.tar.xz +a78f668a726ae1d3d9a7179996d97b12b90fb76ab9442a43110b972ff7ad9029 clang-17.0.6.src.tar.xz +807f069c54dc20cb47b21c1f6acafdd9c649f3ae015609040d6182cab01140f4 cmake-17.0.6.src.tar.xz +252aee1448d49caa04954fd5e27d189dd51570557313e7b281636716a238bccb cmake-3.29.3.tar.gz +9e7535a353aa862730b4ba38df42e06f6856b40c4cc51b57f27b5046dc21d70d libunwind-17.0.6.src.tar.xz +4ac13125616dc44905b85820aa403d27ec1226329b7f674daeb5f5584c6f0b22 lld-17.0.6.src.tar.xz +b638167da139126ca11917b6880207cc6e8f9d1cbb1a48d87d017f697ef78188 llvm-17.0.6.src.tar.xz +3054d0a9c9375dab1a4539cc2cc45ab340341c5d71475f9599ba7752e222947b third-party-17.0.6.src.tar.xz 31747ae633213f1eda3842686f83c2aa1412e0f5691d1c14dbbcc67fe7400cea v1.11.1.tar.gz 9c4396cc829cfae319a6e2615202e82aad41372073482fce286fac78646d3ee4 zstd-1.5.5.tar.gz -f4bb3456c415f01e929d96983b851c49d02b595bf4f99edbbfc55626437775a7 clang-16.0.5.src.tar.xz From 754585afc8ae9eb1393ec3a9fe47b5f5421cc7af Mon Sep 17 00:00:00 2001 From: Fusaaaann <59491356+Fusaaaann@users.noreply.github.com> Date: Thu, 30 May 2024 22:43:43 +0800 Subject: [PATCH 306/623] [WASI-NN] ggml: add mmap feature to ggml.cpp, ggml.h (#3436) Signed-off-by: Fusaaaann --- plugins/wasi_nn/ggml.cpp | 12 ++++++++++++ plugins/wasi_nn/ggml.h | 1 + 2 files changed, 13 insertions(+) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index e84c0a29..65e57512 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -68,10 +68,12 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // reverse-prompt: string // mmproj: string // image: string + // use-mmap: bool // Model parameters (need to reload the model if updated): // n-gpu-layers: int64_t // main-gpu: int64_t // tensor-split: string, comma-separated floating number list + // use-mmap: use mmap // Context parameters (used by the llama context): // ctx-size: uint64_t // batch-size: uint64_t @@ -90,6 +92,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, ModelParams.n_gpu_layers = GraphRef.NGPULayers; ModelParams.main_gpu = GraphRef.MainGPU; ModelParams.tensor_split = GraphRef.TensorSplit.data(); + ModelParams.use_mmap = GraphRef.UseMMap; // The plugin parameters. if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { @@ -210,6 +213,14 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, GraphRef.TensorSplit.push_back(0.0f); } } + if (Doc.at_key("use-mmap").error() == simdjson::SUCCESS) { + auto Err = Doc["use-mmap"].get().get(GraphRef.UseMMap); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the use-mmap option."sv); + return ErrNo::InvalidArgument; + } + } // The context parameters. if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { @@ -771,6 +782,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, ModelParams.n_gpu_layers = GraphRef.NGPULayers; ModelParams.main_gpu = GraphRef.MainGPU; ModelParams.tensor_split = GraphRef.TensorSplit.data(); + ModelParams.use_mmap = GraphRef.UseMMap; GraphRef.LlamaModel = llama_load_model_from_file(GraphRef.ModelFilePath.c_str(), ModelParams); if (GraphRef.LlamaModel == nullptr) { diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 25c02caa..76890216 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -35,6 +35,7 @@ struct Graph { int64_t MainGPU = 0; // Use GPU 0 by default int64_t NGPULayers = 0; std::vector TensorSplit; + bool UseMMap = true; // Context parameters: uint64_t CtxSize; uint64_t BatchSize; From a061741a93b30c586427b61343f6cffd8251998e Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 31 May 2024 15:16:31 +0800 Subject: [PATCH 307/623] [Rustls] Deprecated the rustls plugin because the crate can be built with the Wasm target (#3442) Signed-off-by: hydai --- plugins/CMakeLists.txt | 4 - plugins/wasmedge_rustls/.gitignore | 1 - plugins/wasmedge_rustls/CMakeLists.txt | 19 - plugins/wasmedge_rustls/Cargo.toml | 18 - plugins/wasmedge_rustls/src/lib.rs | 703 ------------------ test/plugins/CMakeLists.txt | 4 - test/plugins/wasmedge_rustls/CMakeLists.txt | 35 - .../wasmedge_rustls/wasmedge_rustls.cpp | 54 -- 8 files changed, 838 deletions(-) delete mode 100644 plugins/wasmedge_rustls/.gitignore delete mode 100644 plugins/wasmedge_rustls/CMakeLists.txt delete mode 100644 plugins/wasmedge_rustls/Cargo.toml delete mode 100644 plugins/wasmedge_rustls/src/lib.rs delete mode 100644 test/plugins/wasmedge_rustls/CMakeLists.txt delete mode 100644 test/plugins/wasmedge_rustls/wasmedge_rustls.cpp diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 8de19e70..348c98a4 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -71,10 +71,6 @@ if(WASMEDGE_PLUGIN_WASI_LOGGING) add_subdirectory(wasi_logging) endif() -if(WASMEDGE_PLUGIN_RUSTLS) - add_subdirectory(wasmedge_rustls) -endif() - if(WASMEDGE_PLUGIN_ZLIB) add_subdirectory(wasmedge_zlib) endif() diff --git a/plugins/wasmedge_rustls/.gitignore b/plugins/wasmedge_rustls/.gitignore deleted file mode 100644 index eb5a316c..00000000 --- a/plugins/wasmedge_rustls/.gitignore +++ /dev/null @@ -1 +0,0 @@ -target diff --git a/plugins/wasmedge_rustls/CMakeLists.txt b/plugins/wasmedge_rustls/CMakeLists.txt deleted file mode 100644 index 739aea92..00000000 --- a/plugins/wasmedge_rustls/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -if(CMAKE_BUILD_TYPE STREQUAL "Debug") - set(CARGO_CMD cargo build) - set(TARGET_DIR "debug") -else() - set(CARGO_CMD cargo build --release) - set(TARGET_DIR "release") -endif() - -set(RS_SO ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}wasmedge_rustls${CMAKE_SHARED_LIBRARY_SUFFIX}) - -set(WASMEDGE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/api) - -add_custom_target(wasmedge_rustls ALL - COMMAND ${CMAKE_COMMAND} -E env WASMEDGE_LIB_DIR=${WASMEDGE_LIB_DIR} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} -- ${CARGO_CMD} - COMMAND ${CMAKE_COMMAND} -E copy ${RS_SO} ${CMAKE_CURRENT_BINARY_DIR} - COMMAND ${CMAKE_COMMAND} -E rm -rf ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR} - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - DEPENDS wasmedge_shared -) diff --git a/plugins/wasmedge_rustls/Cargo.toml b/plugins/wasmedge_rustls/Cargo.toml deleted file mode 100644 index 1e97c526..00000000 --- a/plugins/wasmedge_rustls/Cargo.toml +++ /dev/null @@ -1,18 +0,0 @@ -[package] -name = "wasmedge_rustls_plugin" -version = "0.2.0" -edition = "2021" - -[lib] -name = "wasmedge_rustls" -path = "src/lib.rs" -crate-type = ["cdylib"] - -[dependencies] -libc = "0.2" -rustls = "0.20" -bytes = "1" -webpki-roots = "0.22" -wasmedge_plugin_sdk = "0.2.0" -log = "0.4" -thiserror = "1" diff --git a/plugins/wasmedge_rustls/src/lib.rs b/plugins/wasmedge_rustls/src/lib.rs deleted file mode 100644 index d32e84c2..00000000 --- a/plugins/wasmedge_rustls/src/lib.rs +++ /dev/null @@ -1,703 +0,0 @@ -use thiserror::Error; - -#[derive(Error, Debug)] -pub enum TlsError { - #[error("{0}")] - Tls(#[from] rustls::Error), - #[error("{0}")] - IO(#[from] std::io::Error), - #[error("ParamError")] - ParamError, -} - -impl TlsError { - pub fn error_code(&self) -> i32 { - match self { - TlsError::ParamError => -1, - TlsError::Tls(tls_err) => match tls_err { - rustls::Error::InappropriateMessage { .. } => -2, - rustls::Error::InappropriateHandshakeMessage { .. } => -3, - rustls::Error::CorruptMessage => -4, - rustls::Error::CorruptMessagePayload(_) => -5, - rustls::Error::NoCertificatesPresented => -6, - rustls::Error::UnsupportedNameType => -7, - rustls::Error::DecryptError => -8, - rustls::Error::EncryptError => -9, - rustls::Error::PeerIncompatibleError(_) => -10, - rustls::Error::PeerMisbehavedError(_) => -11, - rustls::Error::AlertReceived(_) => -12, - rustls::Error::InvalidCertificateEncoding => -13, - rustls::Error::InvalidCertificateSignatureType => -14, - rustls::Error::InvalidCertificateSignature => -15, - rustls::Error::InvalidCertificateData(_) => -16, - rustls::Error::InvalidSct(_) => -17, - rustls::Error::General(_) => -18, - rustls::Error::FailedToGetCurrentTime => -19, - rustls::Error::FailedToGetRandomBytes => -20, - rustls::Error::HandshakeNotComplete => -21, - rustls::Error::PeerSentOversizedRecord => -22, - rustls::Error::NoApplicationProtocol => -23, - rustls::Error::BadMaxFragmentSize => -24, - }, - TlsError::IO(io_err) if io_err.kind() == std::io::ErrorKind::WouldBlock => -25, - TlsError::IO(_) => -26, - } - } -} - -#[repr(C)] -pub struct TlsIoState { - tls_bytes_to_write: u32, - plaintext_bytes_to_read: u32, - peer_has_closed: bool, -} - -impl From for TlsIoState { - fn from(value: rustls::IoState) -> Self { - TlsIoState { - tls_bytes_to_write: value.tls_bytes_to_write() as u32, - plaintext_bytes_to_read: value.plaintext_bytes_to_read() as u32, - peer_has_closed: value.peer_has_closed(), - } - } -} - -mod tls_client { - use std::{ - io::{Read, Write}, - sync::Arc, - }; - - use bytes::{Buf, BufMut}; - use rustls::{OwnedTrustAnchor, RootCertStore}; - - use crate::TlsError; - use crate::TlsIoState; - - pub struct Ctx { - pub client_configs: Vec>>, - pub client_codec: Vec>, - } - - impl Ctx { - pub fn new() -> Ctx { - let mut root_store = RootCertStore::empty(); - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map( - |ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }, - )); - let config = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store) - .with_no_client_auth(); - - Ctx { - client_configs: vec![Some(Arc::new(config))], - client_codec: Vec::with_capacity(1024), - } - } - - pub fn default_client_config(&mut self) -> usize { - 0 - } - - pub fn new_codec( - &mut self, - server_name: &str, - config_id: usize, - ) -> Result { - let config = self - .client_configs - .get(config_id) - .ok_or(TlsError::ParamError)? - .clone() - .ok_or(TlsError::ParamError)?; - - let name = server_name.try_into().map_err(|_| TlsError::ParamError)?; - let new_codec = rustls::ClientConnection::new(config, name)?; - let new_codec = ClientCodec(new_codec); - - if let Some((id, item)) = self - .client_codec - .iter_mut() - .enumerate() - .find(|(_, item)| item.is_none()) - { - debug_assert!(item.is_none()); - let _ = item.insert(new_codec); - Ok(id) - } else { - let id = self.client_codec.len(); - self.client_codec.push(Some(new_codec)); - Ok(id) - } - } - - pub fn delete_codec(&mut self, codec_id: usize) { - if let Some(codec) = self.client_codec.get_mut(codec_id) { - let _ = codec.take(); - } - } - } - - #[derive(Debug)] - pub struct ClientCodec(pub rustls::ClientConnection); - - impl ClientCodec { - pub fn is_handshaking(&self) -> bool { - self.0.is_handshaking() - } - - pub fn process_new_packets(&mut self) -> Result { - Ok(self.0.process_new_packets()?.into()) - } - - pub fn send_close_notify(&mut self) { - self.0.send_close_notify(); - } - - pub fn write_raw(&mut self, raw_buf: &[u8]) -> Result { - let conn = &mut self.0; - Ok(conn.writer().write(raw_buf)?) - } - - pub fn write_tls(&mut self, tls_buf: &mut [u8]) -> Result { - let conn = &mut self.0; - Ok(conn.write_tls(&mut tls_buf.writer())?) - } - - pub fn read_raw(&mut self, raw_buf: &mut [u8]) -> Result { - let conn = &mut self.0; - Ok(conn.reader().read(raw_buf)?) - } - - pub fn read_tls(&mut self, tls_buf: &[u8]) -> Result { - let conn = &mut self.0; - Ok(conn.read_tls(&mut tls_buf.reader())?) - } - } - - #[cfg(test)] - mod tls_client_test { - use super::*; - #[test] - fn test_ctx() { - let mut ctx = Ctx::new(); - let config_id = ctx.default_client_config(); - assert_eq!(config_id, 0); - - let codec_id_0 = ctx.new_codec("httpbin.org", config_id).unwrap(); - assert_eq!(codec_id_0, 0); - let codec_id_1 = ctx.new_codec("httpbin.org", config_id).unwrap(); - assert_eq!(codec_id_1, 1); - ctx.delete_codec(codec_id_0); - println!("{:?}", ctx.client_codec); - let codec_id_0 = ctx.new_codec("httpbin.org", config_id).unwrap(); - assert_eq!(codec_id_0, 0); - } - } -} - -mod wasmedge_client_plugin { - - use wasmedge_plugin_sdk::{ - error::CoreError, - memory::Memory, - module::{PluginModule, SyncInstanceRef}, - types::{ValType, WasmVal}, - }; - - use crate::{tls_client::*, TlsError}; - - macro_rules! match_value { - ($expression:expr, $t:path, $error:expr) => { - match $expression { - $t(v) => v, - _ => return Err($error), - } - }; - } - - fn default_config( - _inst: &mut SyncInstanceRef, - _memory: &mut Memory, - ctx: &mut Ctx, - _args: Vec, - ) -> Result, CoreError> { - let config_id = ctx.default_client_config(); - Ok(vec![WasmVal::I32(config_id as i32)]) - } - - fn new_client_codec( - _inst: &mut SyncInstanceRef, - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result, CoreError> { - #[inline] - fn new_client_codec_inner( - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result { - let config_id = args[0].clone(); - let server_ptr = args[1].clone(); - let server_len = args[2].clone(); - - if let (WasmVal::I32(config_id), WasmVal::I32(server_ptr), WasmVal::I32(server_len)) = - (config_id, server_ptr, server_len) - { - let server_name = memory.data_pointer(server_ptr as usize, server_len as usize); - let server_name = server_name - .and_then(|bs| std::str::from_utf8(bs).ok()) - .ok_or(TlsError::ParamError)?; - let r = ctx.new_codec(server_name, config_id as usize)?; - Ok(WasmVal::I32(r as i32)) - } else { - Err(TlsError::ParamError) - } - } - match new_client_codec_inner(memory, ctx, args) { - Ok(ok) => Ok(vec![ok]), - Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), - } - } - - fn is_handshaking( - _inst: &mut SyncInstanceRef, - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result, CoreError> { - #[inline] - fn is_handshaking_inner( - _memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result { - let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); - let codec = ctx - .client_codec - .get(codec_id as usize) - .ok_or(TlsError::ParamError)? - .as_ref() - .ok_or(TlsError::ParamError)?; - - if codec.is_handshaking() { - Ok(WasmVal::I32(1)) - } else { - Ok(WasmVal::I32(0)) - } - } - - match is_handshaking_inner(memory, ctx, args) { - Ok(ok) => Ok(vec![ok]), - Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), - } - } - - fn wants( - _inst: &mut SyncInstanceRef, - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result, CoreError> { - #[inline] - fn wants_inner( - _memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result { - let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); - let codec = ctx - .client_codec - .get(codec_id as usize) - .ok_or(TlsError::ParamError)? - .as_ref() - .ok_or(TlsError::ParamError)?; - match (codec.0.wants_write(), codec.0.wants_read()) { - (true, true) => Ok(WasmVal::I32(0b11)), - (true, false) => Ok(WasmVal::I32(0b10)), - (false, true) => Ok(WasmVal::I32(0b01)), - (false, false) => Ok(WasmVal::I32(0)), - } - } - - match wants_inner(memory, ctx, args) { - Ok(ok) => Ok(vec![ok]), - Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), - } - } - - fn delete_codec( - _inst: &mut SyncInstanceRef, - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result, CoreError> { - #[inline] - fn delete_codec_inner( - _memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result { - let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); - ctx.delete_codec(codec_id as usize); - Ok(WasmVal::I32(0)) - } - - match delete_codec_inner(memory, ctx, args) { - Ok(ok) => Ok(vec![ok]), - Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), - } - } - - fn process_new_packets( - _inst: &mut SyncInstanceRef, - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result, CoreError> { - #[inline] - fn process_new_packets_inner( - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result { - let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); - let result_ptr = match_value!(args[1].clone(), WasmVal::I32, TlsError::ParamError); - - let codec = ctx - .client_codec - .get_mut(codec_id as usize) - .ok_or(TlsError::ParamError)? - .as_mut() - .ok_or(TlsError::ParamError)?; - let io_state = codec.process_new_packets()?; - - memory - .write_data((result_ptr as usize).into(), io_state) - .ok_or(TlsError::ParamError)?; - - Ok(WasmVal::I32(0 as i32)) - } - match process_new_packets_inner(memory, ctx, args) { - Ok(ok) => Ok(vec![ok]), - Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), - } - } - - fn send_close_notify( - _inst: &mut SyncInstanceRef, - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result, CoreError> { - #[inline] - fn send_close_notify_inner( - _memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result { - let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); - let codec = ctx - .client_codec - .get_mut(codec_id as usize) - .ok_or(TlsError::ParamError)? - .as_mut() - .ok_or(TlsError::ParamError)?; - codec.send_close_notify(); - Ok(WasmVal::I32(0)) - } - - match send_close_notify_inner(memory, ctx, args) { - Ok(ok) => Ok(vec![ok]), - Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), - } - } - - fn write_raw( - _inst: &mut SyncInstanceRef, - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result, CoreError> { - #[inline] - fn write_raw_inner( - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result { - let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); - let raw_buf_ptr = match_value!(args[1].clone(), WasmVal::I32, TlsError::ParamError); - let raw_len = match_value!(args[2].clone(), WasmVal::I32, TlsError::ParamError); - - let codec = ctx - .client_codec - .get_mut(codec_id as usize) - .ok_or(TlsError::ParamError)? - .as_mut() - .ok_or(TlsError::ParamError)?; - - let raw_buf = memory - .data_pointer(raw_buf_ptr as usize, raw_len as usize) - .ok_or(TlsError::ParamError)?; - - let n = codec.write_raw(raw_buf)?; - Ok(WasmVal::I32(n as i32)) - } - match write_raw_inner(memory, ctx, args) { - Ok(ok) => Ok(vec![ok]), - Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), - } - } - - fn write_tls( - _inst: &mut SyncInstanceRef, - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result, CoreError> { - #[inline] - fn write_tls_inner( - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result { - let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); - let tls_buf_ptr = match_value!(args[1].clone(), WasmVal::I32, TlsError::ParamError); - let tls_len = match_value!(args[2].clone(), WasmVal::I32, TlsError::ParamError); - - let codec = ctx - .client_codec - .get_mut(codec_id as usize) - .ok_or(TlsError::ParamError)? - .as_mut() - .ok_or(TlsError::ParamError)?; - - let raw_buf = memory - .data_pointer_mut(tls_buf_ptr as usize, tls_len as usize) - .ok_or(TlsError::ParamError)?; - - let n = codec.write_tls(raw_buf)?; - Ok(WasmVal::I32(n as i32)) - } - match write_tls_inner(memory, ctx, args) { - Ok(ok) => Ok(vec![ok]), - Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), - } - } - - fn read_raw( - _inst: &mut SyncInstanceRef, - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result, CoreError> { - #[inline] - fn read_raw_inner( - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result { - let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); - let raw_buf_ptr = match_value!(args[1].clone(), WasmVal::I32, TlsError::ParamError); - let raw_len = match_value!(args[2].clone(), WasmVal::I32, TlsError::ParamError); - - let codec = ctx - .client_codec - .get_mut(codec_id as usize) - .ok_or(TlsError::ParamError)? - .as_mut() - .ok_or(TlsError::ParamError)?; - - let raw_buf = memory - .data_pointer_mut(raw_buf_ptr as usize, raw_len as usize) - .ok_or(TlsError::ParamError)?; - - let n = codec.read_raw(raw_buf); - let n = n?; - Ok(WasmVal::I32(n as i32)) - } - match read_raw_inner(memory, ctx, args) { - Ok(ok) => Ok(vec![ok]), - Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), - } - } - - fn read_tls( - _inst: &mut SyncInstanceRef, - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result, CoreError> { - #[inline] - fn read_tls_inner( - memory: &mut Memory, - ctx: &mut Ctx, - args: Vec, - ) -> Result { - let codec_id = match_value!(args[0].clone(), WasmVal::I32, TlsError::ParamError); - let tls_buf_ptr = match_value!(args[1].clone(), WasmVal::I32, TlsError::ParamError); - let tls_len = match_value!(args[2].clone(), WasmVal::I32, TlsError::ParamError); - - let codec = ctx - .client_codec - .get_mut(codec_id as usize) - .ok_or(TlsError::ParamError)? - .as_mut() - .ok_or(TlsError::ParamError)?; - - let raw_buf = memory - .data_pointer(tls_buf_ptr as usize, tls_len as usize) - .ok_or(TlsError::ParamError)?; - - let n = codec.read_tls(raw_buf)?; - Ok(WasmVal::I32(n as i32)) - } - match read_tls_inner(memory, ctx, args) { - Ok(ok) => Ok(vec![ok]), - Err(e) => Ok(vec![WasmVal::I32(e.error_code())]), - } - } - - pub fn create_module() -> PluginModule { - let mut module = PluginModule::create("rustls_client", Ctx::new()).unwrap(); - module - .add_func( - "default_config", - (vec![], vec![ValType::I32]), - default_config, - ) - .unwrap(); - - module - .add_func( - "new_codec", - ( - vec![ValType::I32, ValType::I32, ValType::I32], - vec![ValType::I32], - ), - new_client_codec, - ) - .unwrap(); - - module - .add_func( - "codec_is_handshaking", - (vec![ValType::I32], vec![ValType::I32]), - is_handshaking, - ) - .unwrap(); - - module - .add_func( - "codec_wants", - (vec![ValType::I32], vec![ValType::I32]), - wants, - ) - .unwrap(); - - module - .add_func( - "delete_codec", - (vec![ValType::I32], vec![ValType::I32]), - delete_codec, - ) - .unwrap(); - - module - .add_func( - "send_close_notify", - (vec![ValType::I32], vec![ValType::I32]), - send_close_notify, - ) - .unwrap(); - - module - .add_func( - "process_new_packets", - (vec![ValType::I32, ValType::I32], vec![ValType::I32]), - process_new_packets, - ) - .unwrap(); - - module - .add_func( - "write_raw", - ( - vec![ - ValType::I32, //codec_id - ValType::I32, // buf - ValType::I32, // buf_len - ], - vec![ValType::I32], - ), - write_raw, - ) - .unwrap(); - - module - .add_func( - "write_tls", - ( - vec![ - ValType::I32, //codec_id - ValType::I32, // buf - ValType::I32, // buf_len - ], - vec![ValType::I32], - ), - write_tls, - ) - .unwrap(); - - module - .add_func( - "read_raw", - ( - vec![ - ValType::I32, //codec_id - ValType::I32, // buf - ValType::I32, // buf_len - ], - vec![ValType::I32], - ), - read_raw, - ) - .unwrap(); - - module - .add_func( - "read_tls", - ( - vec![ - ValType::I32, //codec_id - ValType::I32, // buf - ValType::I32, // buf_len - ], - vec![ValType::I32], - ), - read_tls, - ) - .unwrap(); - - module - } -} - -use wasmedge_client_plugin::create_module; - -wasmedge_plugin_sdk::plugin::register_plugin!( - plugin_name="rustls", - plugin_description="rustls plugin", - version=(0,0,1,0), - modules=[ - {"rustls_client","rustls client module",create_module} - ] -); diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index eeb1950b..cdda1155 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -57,8 +57,4 @@ if(WASMEDGE_PLUGIN_WASI_LOGGING) add_subdirectory(wasi_logging) endif() -if(WASMEDGE_PLUGIN_RUSTLS) - add_subdirectory(wasmedge_rustls) -endif() - add_subdirectory(unittest) diff --git a/test/plugins/wasmedge_rustls/CMakeLists.txt b/test/plugins/wasmedge_rustls/CMakeLists.txt deleted file mode 100644 index 3cfc874b..00000000 --- a/test/plugins/wasmedge_rustls/CMakeLists.txt +++ /dev/null @@ -1,35 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC - -wasmedge_add_executable(wasmEdgeRUSTLSTests - wasmedge_rustls.cpp -) - -add_dependencies(wasmEdgeRUSTLSTests - wasmedge_rustls -) - -target_include_directories(wasmEdgeRUSTLSTests - PUBLIC - $ - $ -) - -target_link_libraries(wasmEdgeRUSTLSTests - PRIVATE - ${GTEST_BOTH_LIBRARIES} -) -# Link to the WasmEdge library -if(WASMEDGE_LINK_PLUGINS_STATIC) - target_link_libraries(wasmEdgeRUSTLSTests - PRIVATE - wasmedgeCAPI - ) -else() - target_link_libraries(wasmEdgeRUSTLSTests - PRIVATE - wasmedge_shared - ) -endif() - -add_test(wasmEdgeRUSTLSTests wasmEdgeRUSTLSTests) diff --git a/test/plugins/wasmedge_rustls/wasmedge_rustls.cpp b/test/plugins/wasmedge_rustls/wasmedge_rustls.cpp deleted file mode 100644 index f34d3a78..00000000 --- a/test/plugins/wasmedge_rustls/wasmedge_rustls.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC - -#include "common/defines.h" -#include "plugin/plugin.h" -#include "runtime/instance/module.h" - -#include -#include -#include -#include -#include -#include - -namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { - using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasmedge_rustls/" WASMEDGE_LIB_PREFIX - "wasmedge_rustls" WASMEDGE_LIB_EXTENSION)); - if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("rustls"sv)) { - if (const auto *Module = Plugin->findModule("rustls_client"sv)) { - return Module->create().release(); - } - } - return nullptr; -} -} // namespace - -// TODO: unit tests for every functions. - -TEST(WasmEdgeRUSTLSTest, Module) { - // Create the wasmedge_rustls module instance. - auto *TLSMod = createModule(); - EXPECT_FALSE(TLSMod == nullptr); - EXPECT_EQ(TLSMod->getFuncExportNum(), 11U); - EXPECT_NE(TLSMod->findFuncExports("default_config"), nullptr); - EXPECT_NE(TLSMod->findFuncExports("new_codec"), nullptr); - EXPECT_NE(TLSMod->findFuncExports("codec_is_handshaking"), nullptr); - EXPECT_NE(TLSMod->findFuncExports("codec_wants"), nullptr); - EXPECT_NE(TLSMod->findFuncExports("delete_codec"), nullptr); - EXPECT_NE(TLSMod->findFuncExports("send_close_notify"), nullptr); - EXPECT_NE(TLSMod->findFuncExports("process_new_packets"), nullptr); - EXPECT_NE(TLSMod->findFuncExports("write_raw"), nullptr); - EXPECT_NE(TLSMod->findFuncExports("write_tls"), nullptr); - EXPECT_NE(TLSMod->findFuncExports("read_raw"), nullptr); - EXPECT_NE(TLSMod->findFuncExports("read_tls"), nullptr); - delete TLSMod; -} - -GTEST_API_ int main(int argc, char **argv) { - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} From a7db8034847da4daad175ee77ea0f6175970079e Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 3 Jun 2024 17:29:00 +0800 Subject: [PATCH 308/623] [WASI-NN] ggml: bump to b3067 (#3445) Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 092fb5a4..b39cb2c4 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -96,7 +96,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3014 + GIT_TAG b3067 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From d204816fa9db58ec52dc09748d76844e05b7cad8 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 3 Jun 2024 22:59:12 +0800 Subject: [PATCH 309/623] [WASI-NN] ggml: bump llama.cpp b3075 (#3446) Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index b39cb2c4..838dfcbd 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -96,7 +96,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3067 + GIT_TAG b3075 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 313f1650ba803065da525458cd5655accc56f1fc Mon Sep 17 00:00:00 2001 From: grorge Date: Mon, 25 Mar 2024 19:54:13 +0800 Subject: [PATCH 310/623] [WASI-NN] neural speed: add backend struct Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 1 + plugins/wasi_nn/neuralspeed.cpp | 55 +++++++++++++++++++++++++++++++++ plugins/wasi_nn/neuralspeed.h | 34 ++++++++++++++++++++ plugins/wasi_nn/types.h | 4 ++- plugins/wasi_nn/wasinnenv.cpp | 3 +- plugins/wasi_nn/wasinnenv.h | 1 + 6 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 plugins/wasi_nn/neuralspeed.cpp create mode 100644 plugins/wasi_nn/neuralspeed.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 838dfcbd..cfd02f10 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -118,6 +118,7 @@ wasmedge_add_library(wasmedgePluginWasiNN torch.cpp tfl.cpp ggml.cpp + neuralspeed.cpp ) target_compile_options(wasmedgePluginWasiNN diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp new file mode 100644 index 00000000..3f857f45 --- /dev/null +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -0,0 +1,55 @@ +#include "neuralspeed.h" +#include "wasinnenv.h" + +namespace WasmEdge::Host::WASINN::NeuralSpeed { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return WASINN::ErrNo::Success; +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return WASINN::ErrNo::Success; +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return WASINN::ErrNo::Success; +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return WASINN::ErrNo::Success; +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return WASINN::ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] Neural speed backend is not supported."); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif +} // namespace WasmEdge::Host::WASINN::NeuralSpeed diff --git a/plugins/wasi_nn/neuralspeed.h b/plugins/wasi_nn/neuralspeed.h new file mode 100644 index 00000000..eddcd2c9 --- /dev/null +++ b/plugins/wasi_nn/neuralspeed.h @@ -0,0 +1,34 @@ +#pragma once + +#include "plugin/plugin.h" +#include "types.h" + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::NeuralSpeed { +struct Graph {}; +struct Context { + Context(size_t, Graph &) noexcept {} +}; + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; + +} // namespace WasmEdge::Host::WASINN::NeuralSpeed \ No newline at end of file diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index 81a06eee..d1dd7fe7 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -36,10 +36,12 @@ enum class Backend : uint8_t { TensorflowLite = 4, Autodetect = 5, GGML = 6, + NeuralSpeed = 7, }; #define FOR_EACH_BACKEND(F) \ - F(OpenVINO) F(ONNX) F(Tensorflow) F(PyTorch) F(TensorflowLite) F(GGML) + F(OpenVINO) \ + F(ONNX) F(Tensorflow) F(PyTorch) F(TensorflowLite) F(GGML) F(NeuralSpeed) struct TensorData { Span Dimension; diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index b2228ff4..fdadbb96 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -28,7 +28,8 @@ std::map BackendMap = { {"pytorch"sv, Backend::PyTorch}, {"tensorflowlite"sv, Backend::TensorflowLite}, {"autodetect"sv, Backend::Autodetect}, - {"ggml"sv, Backend::GGML}}; + {"ggml"sv, Backend::GGML}, + {"neuralspeed"sv, Backend::NeuralSpeed}}; std::map DeviceMap = {{"cpu"sv, Device::CPU}, {"gpu"sv, Device::GPU}, diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 260a1659..f2b1adb8 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -10,6 +10,7 @@ #include #include "ggml.h" +#include "neuralspeed.h" #include "onnx.h" #include "openvino.h" #include "tf.h" From 26df1447c7204affc9e1700a3eae44c6d5d063f6 Mon Sep 17 00:00:00 2001 From: grorge Date: Wed, 3 Apr 2024 22:31:42 +0800 Subject: [PATCH 311/623] [WASI-NN] neural speed: add successful test Signed-off-by: grorge --- test/plugins/wasi_nn/wasi_nn.cpp | 140 +++++++++++++++++- .../wasi-nn/download-neuralspeed-fixtures.sh | 17 +++ 2 files changed, 156 insertions(+), 1 deletion(-) create mode 100755 utils/wasi-nn/download-neuralspeed-fixtures.sh diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index e8ccdeaf..7380d63e 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -19,7 +19,8 @@ using WasmEdge::Host::WASINN::ErrNo; #if defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) || \ - defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML) + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED) namespace { WasmEdge::Runtime::Instance::ModuleInstance * createModule(std::string_view NNRPCURI = "") { @@ -1744,3 +1745,140 @@ TEST(WasiNNTest, GGMLBackendWithRPC) { } #endif // WASMEDGE_BUILD_WASI_NN_RPC #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED +TEST(WasiNNTest, NeuralSpeedBackend) { + // Create the wasmedge_process module instance. + auto *NNMod = dynamic_cast(createModule()); + ASSERT_TRUE(NNMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(400))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::string Prompt = "Once upon a time, "; + std::vector TensorData(Prompt.begin(), Prompt.end()); + std::vector WeightRead = + readEntireFile("./wasinn_neuralspeed_fixtures/llama-2-7b-chat.Q4_0.gguf"); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); +// uint32_t OutBoundPtr = UINT32_C(61000 * 65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // Neural Speed WASI-NN load tests. + // Test: load -- load successfully. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); + writeBinaries(MemInst, WeightRead, StorePtr); + StorePtr += WeightRead.size(); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::NeuralSpeed), UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Neural Speed WASI-NN init_execution_context tests. + // Test: init_execution_context -- init second context. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(1), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // Neural Speed WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // Neural Speed WASI-NN compute tests. + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Neural Speed WASI-NN get_output tests. + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 50 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 50); + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED \ No newline at end of file diff --git a/utils/wasi-nn/download-neuralspeed-fixtures.sh b/utils/wasi-nn/download-neuralspeed-fixtures.sh new file mode 100755 index 00000000..e050aa3c --- /dev/null +++ b/utils/wasi-nn/download-neuralspeed-fixtures.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2023 Second State INC + +TODIR=$1 +if [[ $# -eq 0 ]]; then + TODIR=. +fi +MODEL=orca_mini.gguf +FIXTURE=https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/blob/main/llama-2-7b-chat.Q4_0.gguf +if [ ! -d $TODIR ]; then + mkdir $TODIR +fi + +if [ ! -f $TODIR/$MODEL ]; then + curl -sL $FIXTURE -o $TODIR/$MODEL +fi From c344a3b9cf95dad54b0875ec06f51fdebf2fddc9 Mon Sep 17 00:00:00 2001 From: grorge Date: Sun, 7 Apr 2024 13:29:17 +0800 Subject: [PATCH 312/623] [WASI-NN] neural speed: implement neural network backend Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 30 +++++ plugins/wasi_nn/neuralspeed.cpp | 190 ++++++++++++++++++++++++++++++-- plugins/wasi_nn/neuralspeed.h | 42 ++++++- 3 files changed, 251 insertions(+), 11 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index cfd02f10..95f1157d 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -197,6 +197,36 @@ if(WASMEDGE_BUILD_WASI_NN_RPC) ) endif() +if(BACKEND STREQUAL "neuralspeed") + find_package(Python3 REQUIRED COMPONENTS Interpreter Development) + execute_process(COMMAND python3-config --cflags + OUTPUT_VARIABLE PYTHON_CFLAGS + OUTPUT_STRIP_TRAILING_WHITESPACE) + execute_process(COMMAND python3-config --ldflags --embed + OUTPUT_VARIABLE PYTHON_LDFLAGS + OUTPUT_STRIP_TRAILING_WHITESPACE) + separate_arguments(PYTHON_INCLUDES UNIX_COMMAND "${PYTHON_CFLAGS}") + foreach(flag ${PYTHON_INCLUDES}) + if(flag MATCHES "^-I") + string(SUBSTRING "${flag}" 2 -1 path) + # message(STATUS "Include: ${path}") + target_include_directories(wasmedgePluginWasiNN PUBLIC ${path}) + endif() + endforeach() + separate_arguments(PYTHON_LIBS UNIX_COMMAND "${PYTHON_LDFLAGS}") + foreach(flag ${PYTHON_LIBS}) + if(flag MATCHES "^-lpython") + string(SUBSTRING "${flag}" 2 -1 path) + # message(STATUS "LINK: ${path}") + target_link_libraries(wasmedgePluginWasiNN PUBLIC ${path}) + elseif(flag MATCHES "^-L") + string(SUBSTRING "${flag}" 2 -1 path) + # message(STATUS "LINK: ${path}") + target_link_directories(wasmedgePluginWasiNN PUBLIC ${path}) + endif() + endforeach() +endif() + include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index 3f857f45..4c393b43 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -3,24 +3,194 @@ namespace WasmEdge::Host::WASINN::NeuralSpeed { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED -Expect load(WASINN::WasiNNEnvironment &, - Span>, WASINN::Device, +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, WASINN::Device, uint32_t &) noexcept { + // Add a new graph. + Env.NNGraph.emplace_back(Backend::GGML); + auto &GraphRef = Env.NNGraph.back().get(); + + // Initialize the plugin parameters. + GraphRef.EnableDebugLog = true; + // Handle the model path. + auto Weight = Builders[0]; + const std::string BinModel(reinterpret_cast(Weight.data()), + Weight.size()); + std::string ModelFilePath; + if (BinModel.substr(0, 8) == "preload:") { + ModelFilePath = BinModel.substr(8); + } else { + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Neural speed: Model path not found in nn-preload, " + "write model into a tmpfile."sv); + } + // TODO: pass the model directly to ggml + // Write neural speed model to file. + ModelFilePath = "neural-speed-model.bin"sv; + std::ofstream TempFile(ModelFilePath); + if (!TempFile) { + spdlog::error( + "[WASI-NN] Neural speed: Failed to create the temporary file. " + "Currently, our workaround involves creating a temporary model " + "file named \"ggml-model.bin\" and passing this filename as a " + "parameter to the ggml llama library."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + TempFile << BinModel; + TempFile.close(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Neural speed: Write model into a tmpfile...Done"sv); + } + } + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Neural speed: Finished handling model path."sv); + } + + // Create Model class + PyObject *moduleName = PyUnicode_FromString("neural_speed"); + GraphRef.NeuralSpeedModule = PyImport_Import(moduleName); + Py_DECREF(moduleName); + if (GraphRef.NeuralSpeedModule == nullptr) { + spdlog::error( + "[WASI-NN] neural speed backend: Can not find neural speed library."sv); + return WASINN::ErrNo::RuntimeError; + } + GraphRef.ModelClass = + PyObject_GetAttrString(GraphRef.NeuralSpeedModule, "Model"); + if (GraphRef.ModelClass == nullptr || + !PyCallable_Check(GraphRef.ModelClass)) { + spdlog::error( + "[WASI-NN] neural speed backend: Can not find Model class in neural speed."sv); + return WASINN::ErrNo::RuntimeError; + } + GraphRef.Model = PyObject_CallObject(GraphRef.ModelClass, NULL); + PyObject *LoadResult = PyObject_CallMethod(GraphRef.Model, "init_from_bin", + "llama", ModelFilePath); + if (LoadResult == nullptr) { + spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); + return WASINN::ErrNo::RuntimeError; + } + Py_XDECREF(LoadResult); + return WASINN::ErrNo::Success; } -Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, - uint32_t &) noexcept { - return WASINN::ErrNo::Success; + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); + ContextId = Env.NNContext.size() - 1; + return ErrNo::Success; } -Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, - const TensorData &) noexcept { + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t, const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Neural speed backend: setInput"sv); + } + + // Set the input. + if (Tensor.Tensor.size() % sizeof(long long int) != 0) { + spdlog::error("[WASI-NN] neural speed backend: Input tensor size is not a " + "multiple of " + "4 bytes."sv); + return WASINN::ErrNo::InvalidArgument; + } + std::vector Prompt{ + reinterpret_cast(Tensor.Tensor.data()), + reinterpret_cast(Tensor.Tensor.data() + + Tensor.Tensor.size())}; + CxtRef.Inputs.clear(); + CxtRef.Inputs = Prompt; + return WASINN::ErrNo::Success; } -Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, - Span, uint32_t &) noexcept { +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Neural speed backend: getOutput"sv); + } + std::string stmp(reinterpret_cast(CxtRef.Outputs.data()), + CxtRef.Outputs.size() * sizeof(long long int)); + std::copy_n(stmp.data(), stmp.length(), OutBuffer.data()); + BytesWritten = stmp.length(); return WASINN::ErrNo::Success; } -Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { +Expect compute(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Neural speed backend: compute"sv); + } + if (CxtRef.Inputs.size() == 0) { + spdlog::error("[WASI-NN] Neural speed backend: Llama input is not set!"sv); + return ErrNo::InvalidArgument; + } + CxtRef.Outputs.clear(); + PyObject *TensorList = PyList_New(0); + for (size_t i = 0; i < CxtRef.Inputs.size(); ++i) { + PyObject *num = PyLong_FromLong(CxtRef.Inputs[i]); + PyList_Append(TensorList, num); + Py_DECREF(num); + } + PyObject *TmpArg = PyList_New(0); + PyList_Append(TmpArg, TensorList); + Py_DECREF(TensorList); + PyObject *torchModule = PyImport_ImportModule("torch"); + if (torchModule == nullptr) { + spdlog::error( + "[WASI-NN] neural speed backend: Can not find torch library."sv); + return WASINN::ErrNo::RuntimeError; + } + PyObject *LongTensorFunc = PyObject_GetAttrString(torchModule, "LongTensor"); + PyObject *LongTensorArgs = PyTuple_Pack(1, TmpArg); + PyObject *LongTensor = PyObject_CallObject(LongTensorFunc, LongTensorArgs); + Py_DECREF(torchModule); + Py_DECREF(LongTensorFunc); + Py_DECREF(TmpArg); + Py_DECREF(LongTensorArgs); + if (LongTensor == nullptr) { + spdlog::error( + "[WASI-NN] neural speed backend: Input transfer tensor failed."sv); + return WASINN::ErrNo::InvalidArgument; + } + PyObject *GenerateArgs = PyTuple_Pack(1, LongTensor); + PyObject *result = PyObject_CallMethodObjArgs( + GraphRef.Model, PyUnicode_FromString("generate"), GenerateArgs, NULL); + if (result == nullptr) { + spdlog::error( + "[WASI-NN] neural speed backend: Neural Speed runtime error."sv); + return WASINN::ErrNo::RuntimeError; + } + if (PyList_Check(result)) { + Py_ssize_t outerSize = PyList_Size(result); + for (Py_ssize_t i = 0; i < outerSize; ++i) { + PyObject *innerList = PyList_GetItem(result, i); + if (PyList_Check(innerList)) { + std::vector innerVec; + Py_ssize_t innerSize = PyList_Size(innerList); + for (Py_ssize_t j = 0; j < innerSize; ++j) { + PyObject *num = PyList_GetItem(innerList, j); + if (PyLong_Check(num)) { + innerVec.push_back(PyLong_AsLong(num)); + } + } + CxtRef.Outputs = innerVec; + } + } + } + Py_DECREF(result); + Py_DECREF(GenerateArgs); + Py_DECREF(LongTensor); return WASINN::ErrNo::Success; } #else diff --git a/plugins/wasi_nn/neuralspeed.h b/plugins/wasi_nn/neuralspeed.h index eddcd2c9..3e49e90d 100644 --- a/plugins/wasi_nn/neuralspeed.h +++ b/plugins/wasi_nn/neuralspeed.h @@ -2,16 +2,56 @@ #include "plugin/plugin.h" #include "types.h" - +#include namespace WasmEdge::Host::WASINN { struct WasiNNEnvironment; } namespace WasmEdge::Host::WASINN::NeuralSpeed { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED +#include +struct Graph { + bool EnableDebugLog = true; + static std::mutex py_mutex; + inline static int GraphNumber = 0; + Graph() noexcept { + py_mutex.lock(); + if (GraphNumber == 0) { + Py_Initialize(); + } + GraphNumber++; + py_mutex.unlock(); + } + ~Graph() noexcept { + Py_XDECREF(Model); + Py_XDECREF(ModelClass); + Py_XDECREF(NeuralSpeedModule); + py_mutex.lock(); + if (GraphNumber == 1) { + Py_Finalize(); + } + GraphNumber--; + py_mutex.unlock(); + } + + PyObject *Model; + PyObject *NeuralSpeedModule; + PyObject *ModelClass; +}; +struct Context { + Context(size_t Gid, Graph &) noexcept : GraphId(Gid) {} + size_t GraphId; + std::vector Inputs; + std::vector Outputs; +}; +#else struct Graph {}; struct Context { Context(size_t, Graph &) noexcept {} }; +#endif + + struct Environ {}; From d55c3d07c091c187427cb75e82ca221b9259bbb3 Mon Sep 17 00:00:00 2001 From: grorge Date: Mon, 8 Apr 2024 14:35:34 +0800 Subject: [PATCH 313/623] [WASI-NN] neural speed: fix neural backend and basic test Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 30 +------ plugins/wasi_nn/neuralspeed.cpp | 89 ++++++++++--------- plugins/wasi_nn/neuralspeed.h | 17 ++-- test/plugins/wasi_nn/wasi_nn.cpp | 33 ++++--- .../wasi-nn/download-neuralspeed-fixtures.sh | 4 +- 5 files changed, 82 insertions(+), 91 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 95f1157d..aaa60b55 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -198,33 +198,9 @@ if(WASMEDGE_BUILD_WASI_NN_RPC) endif() if(BACKEND STREQUAL "neuralspeed") - find_package(Python3 REQUIRED COMPONENTS Interpreter Development) - execute_process(COMMAND python3-config --cflags - OUTPUT_VARIABLE PYTHON_CFLAGS - OUTPUT_STRIP_TRAILING_WHITESPACE) - execute_process(COMMAND python3-config --ldflags --embed - OUTPUT_VARIABLE PYTHON_LDFLAGS - OUTPUT_STRIP_TRAILING_WHITESPACE) - separate_arguments(PYTHON_INCLUDES UNIX_COMMAND "${PYTHON_CFLAGS}") - foreach(flag ${PYTHON_INCLUDES}) - if(flag MATCHES "^-I") - string(SUBSTRING "${flag}" 2 -1 path) - # message(STATUS "Include: ${path}") - target_include_directories(wasmedgePluginWasiNN PUBLIC ${path}) - endif() - endforeach() - separate_arguments(PYTHON_LIBS UNIX_COMMAND "${PYTHON_LDFLAGS}") - foreach(flag ${PYTHON_LIBS}) - if(flag MATCHES "^-lpython") - string(SUBSTRING "${flag}" 2 -1 path) - # message(STATUS "LINK: ${path}") - target_link_libraries(wasmedgePluginWasiNN PUBLIC ${path}) - elseif(flag MATCHES "^-L") - string(SUBSTRING "${flag}" 2 -1 path) - # message(STATUS "LINK: ${path}") - target_link_directories(wasmedgePluginWasiNN PUBLIC ${path}) - endif() - endforeach() + find_package(PythonLibs REQUIRED) + include_directories(${PYTHON_INCLUDE_DIRS}) + target_link_libraries(wasmedgePluginWasiNN PUBLIC ${PYTHON_LIBRARIES}) endif() include(WASINNDeps) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index 4c393b43..c5b264e0 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -1,21 +1,27 @@ #include "neuralspeed.h" #include "wasinnenv.h" - +#include namespace WasmEdge::Host::WASINN::NeuralSpeed { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED Expect load(WASINN::WasiNNEnvironment &Env, Span> Builders, WASINN::Device, uint32_t &) noexcept { + spdlog::info("[WASI-NN][Debug] Neural speed: test"sv); // Add a new graph. - Env.NNGraph.emplace_back(Backend::GGML); + Env.NNGraph.emplace_back(Backend::NeuralSpeed); auto &GraphRef = Env.NNGraph.back().get(); // Initialize the plugin parameters. GraphRef.EnableDebugLog = true; + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Neural speed backend: Load."sv); + } // Handle the model path. auto Weight = Builders[0]; const std::string BinModel(reinterpret_cast(Weight.data()), Weight.size()); + spdlog::info("[WASI-NN][Debug] Neural speed: BinModel: {}"sv, + BinModel.size()); std::string ModelFilePath; if (BinModel.substr(0, 8) == "preload:") { ModelFilePath = BinModel.substr(8); @@ -28,7 +34,9 @@ Expect load(WASINN::WasiNNEnvironment &Env, // TODO: pass the model directly to ggml // Write neural speed model to file. ModelFilePath = "neural-speed-model.bin"sv; - std::ofstream TempFile(ModelFilePath); + std::ofstream TempFile(ModelFilePath, std::ios::binary); + TempFile.imbue( + std::locale(TempFile.getloc(), new std::codecvt_utf8)); if (!TempFile) { spdlog::error( "[WASI-NN] Neural speed: Failed to create the temporary file. " @@ -51,9 +59,8 @@ Expect load(WASINN::WasiNNEnvironment &Env, } // Create Model class - PyObject *moduleName = PyUnicode_FromString("neural_speed"); - GraphRef.NeuralSpeedModule = PyImport_Import(moduleName); - Py_DECREF(moduleName); + GraphRef.NeuralSpeedModule = + PyImport_Import(PyUnicode_FromString("neural_speed")); if (GraphRef.NeuralSpeedModule == nullptr) { spdlog::error( "[WASI-NN] neural speed backend: Can not find neural speed library."sv); @@ -68,9 +75,10 @@ Expect load(WASINN::WasiNNEnvironment &Env, return WASINN::ErrNo::RuntimeError; } GraphRef.Model = PyObject_CallObject(GraphRef.ModelClass, NULL); - PyObject *LoadResult = PyObject_CallMethod(GraphRef.Model, "init_from_bin", - "llama", ModelFilePath); + PyObject *LoadResult = PyObject_CallMethod( + GraphRef.Model, "init_from_bin", "(ss)", "llama", ModelFilePath.c_str()); if (LoadResult == nullptr) { + PyErr_Print(); spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); return WASINN::ErrNo::RuntimeError; } @@ -101,7 +109,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, "4 bytes."sv); return WASINN::ErrNo::InvalidArgument; } - std::vector Prompt{ + const std::vector Prompt{ reinterpret_cast(Tensor.Tensor.data()), reinterpret_cast(Tensor.Tensor.data() + Tensor.Tensor.size())}; @@ -118,10 +126,10 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Neural speed backend: getOutput"sv); } - std::string stmp(reinterpret_cast(CxtRef.Outputs.data()), - CxtRef.Outputs.size() * sizeof(long long int)); - std::copy_n(stmp.data(), stmp.length(), OutBuffer.data()); - BytesWritten = stmp.length(); + std::string StringTmp(reinterpret_cast(CxtRef.Outputs.data()), + CxtRef.Outputs.size() * sizeof(long long int)); + std::copy_n(StringTmp.data(), StringTmp.length(), OutBuffer.data()); + BytesWritten = StringTmp.length(); return WASINN::ErrNo::Success; } Expect compute(WasiNNEnvironment &Env, @@ -137,24 +145,24 @@ Expect compute(WasiNNEnvironment &Env, } CxtRef.Outputs.clear(); PyObject *TensorList = PyList_New(0); - for (size_t i = 0; i < CxtRef.Inputs.size(); ++i) { - PyObject *num = PyLong_FromLong(CxtRef.Inputs[i]); - PyList_Append(TensorList, num); - Py_DECREF(num); + for (size_t Cnt = 0; Cnt < CxtRef.Inputs.size(); ++Cnt) { + PyObject *Num = PyLong_FromLong(CxtRef.Inputs[Cnt]); + PyList_Append(TensorList, Num); + Py_DECREF(Num); } PyObject *TmpArg = PyList_New(0); PyList_Append(TmpArg, TensorList); - Py_DECREF(TensorList); - PyObject *torchModule = PyImport_ImportModule("torch"); - if (torchModule == nullptr) { + PyObject *TorchModule = PyImport_ImportModule("torch"); + if (TorchModule == nullptr) { spdlog::error( "[WASI-NN] neural speed backend: Can not find torch library."sv); return WASINN::ErrNo::RuntimeError; } - PyObject *LongTensorFunc = PyObject_GetAttrString(torchModule, "LongTensor"); + PyObject *LongTensorFunc = PyObject_GetAttrString(TorchModule, "LongTensor"); PyObject *LongTensorArgs = PyTuple_Pack(1, TmpArg); PyObject *LongTensor = PyObject_CallObject(LongTensorFunc, LongTensorArgs); - Py_DECREF(torchModule); + Py_DECREF(TensorList); + Py_DECREF(TorchModule); Py_DECREF(LongTensorFunc); Py_DECREF(TmpArg); Py_DECREF(LongTensorArgs); @@ -163,33 +171,34 @@ Expect compute(WasiNNEnvironment &Env, "[WASI-NN] neural speed backend: Input transfer tensor failed."sv); return WASINN::ErrNo::InvalidArgument; } - PyObject *GenerateArgs = PyTuple_Pack(1, LongTensor); - PyObject *result = PyObject_CallMethodObjArgs( - GraphRef.Model, PyUnicode_FromString("generate"), GenerateArgs, NULL); - if (result == nullptr) { + // PyObject *GenerateArgs = PyTuple_Pack(1, LongTensor); + PyObject *Result = PyObject_CallMethodObjArgs( + GraphRef.Model, PyUnicode_FromString("generate"), LongTensor, NULL); + if (Result == nullptr) { + PyErr_Print(); spdlog::error( "[WASI-NN] neural speed backend: Neural Speed runtime error."sv); return WASINN::ErrNo::RuntimeError; } - if (PyList_Check(result)) { - Py_ssize_t outerSize = PyList_Size(result); - for (Py_ssize_t i = 0; i < outerSize; ++i) { - PyObject *innerList = PyList_GetItem(result, i); - if (PyList_Check(innerList)) { - std::vector innerVec; - Py_ssize_t innerSize = PyList_Size(innerList); - for (Py_ssize_t j = 0; j < innerSize; ++j) { - PyObject *num = PyList_GetItem(innerList, j); - if (PyLong_Check(num)) { - innerVec.push_back(PyLong_AsLong(num)); + if (PyList_Check(Result)) { + const Py_ssize_t OuterSize = PyList_Size(Result); + for (Py_ssize_t OutterCnt = 0; OutterCnt < OuterSize; ++OutterCnt) { + PyObject *InnerList = PyList_GetItem(Result, OutterCnt); + if (PyList_Check(InnerList)) { + std::vector InnerVec; + const Py_ssize_t InnerSize = PyList_Size(InnerList); + for (Py_ssize_t InnerCnt = 0; InnerCnt < InnerSize; ++InnerCnt) { + PyObject *Num = PyList_GetItem(InnerList, InnerCnt); + if (PyLong_Check(Num)) { + InnerVec.push_back(PyLong_AsLong(Num)); } } - CxtRef.Outputs = innerVec; + CxtRef.Outputs = InnerVec; } } } - Py_DECREF(result); - Py_DECREF(GenerateArgs); + Py_DECREF(Result); + // Py_DECREF(GenerateArgs); Py_DECREF(LongTensor); return WASINN::ErrNo::Success; } diff --git a/plugins/wasi_nn/neuralspeed.h b/plugins/wasi_nn/neuralspeed.h index 3e49e90d..cf6c34dd 100644 --- a/plugins/wasi_nn/neuralspeed.h +++ b/plugins/wasi_nn/neuralspeed.h @@ -3,35 +3,38 @@ #include "plugin/plugin.h" #include "types.h" #include +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED +#include +#endif namespace WasmEdge::Host::WASINN { struct WasiNNEnvironment; } namespace WasmEdge::Host::WASINN::NeuralSpeed { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED -#include struct Graph { bool EnableDebugLog = true; - static std::mutex py_mutex; + // TODO add mutex + // static std::mutex py_mutex; inline static int GraphNumber = 0; Graph() noexcept { - py_mutex.lock(); + // py_mutex.lock(); if (GraphNumber == 0) { Py_Initialize(); } GraphNumber++; - py_mutex.unlock(); + // py_mutex.unlock(); } ~Graph() noexcept { Py_XDECREF(Model); Py_XDECREF(ModelClass); Py_XDECREF(NeuralSpeedModule); - py_mutex.lock(); + // py_mutex.lock(); if (GraphNumber == 1) { Py_Finalize(); } GraphNumber--; - py_mutex.unlock(); + // py_mutex.unlock(); } PyObject *Model; @@ -51,8 +54,6 @@ struct Context { }; #endif - - struct Environ {}; Expect load(WASINN::WasiNNEnvironment &Env, diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 7380d63e..736139de 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1756,23 +1756,26 @@ TEST(WasiNNTest, NeuralSpeedBackend) { WasmEdge::Runtime::Instance::ModuleInstance Mod(""); Mod.addHostMemory( "memory", std::make_unique( - WasmEdge::AST::MemoryType(400))); + WasmEdge::AST::MemoryType(60000))); auto *MemInstPtr = Mod.findMemoryExports("memory"); ASSERT_TRUE(MemInstPtr != nullptr); auto &MemInst = *MemInstPtr; WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Load the files. - std::string Prompt = "Once upon a time, "; - std::vector TensorData(Prompt.begin(), Prompt.end()); - std::vector WeightRead = - readEntireFile("./wasinn_neuralspeed_fixtures/llama-2-7b-chat.Q4_0.gguf"); + std::vector Prompt = {1, 9038, 2501, 263, 931, 29892, + 727, 22856, 263, 2217, 7826, 29892}; + std::string tmp(reinterpret_cast(Prompt.data()), + Prompt.size() * sizeof(long long int)); + std::vector TensorData(tmp.begin(), tmp.end()); + std::vector WeightRead = readEntireFile( + "./wasinn_neural_speed_fixtures/llama-2-7b-chat.Q4_0.gguf"); std::vector TensorDim{1}; uint32_t BuilderPtr = UINT32_C(0); uint32_t LoadEntryPtr = UINT32_C(0); uint32_t SetInputEntryPtr = UINT32_C(0); -// uint32_t OutBoundPtr = UINT32_C(61000 * 65536); + // uint32_t OutBoundPtr = UINT32_C(61000 * 65536); uint32_t StorePtr = UINT32_C(65536); // Return value. @@ -1816,11 +1819,13 @@ TEST(WasiNNTest, NeuralSpeedBackend) { writeBinaries(MemInst, WeightRead, StorePtr); StorePtr += WeightRead.size(); { - EXPECT_TRUE(HostFuncLoad.run( - CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), static_cast(Backend::NeuralSpeed), UINT32_C(0), BuilderPtr}, - Errno)); + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::NeuralSpeed), + UINT32_C(0), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); BuilderPtr += 4; @@ -1831,10 +1836,10 @@ TEST(WasiNNTest, NeuralSpeedBackend) { { EXPECT_TRUE(HostFuncInit.run( CallFrame, - std::initializer_list{UINT32_C(1), BuilderPtr}, + std::initializer_list{UINT32_C(0), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); - EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); BuilderPtr += 4; } @@ -1852,7 +1857,7 @@ TEST(WasiNNTest, NeuralSpeedBackend) { EXPECT_TRUE( HostFuncSetInput.run(CallFrame, std::initializer_list{ - UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); } diff --git a/utils/wasi-nn/download-neuralspeed-fixtures.sh b/utils/wasi-nn/download-neuralspeed-fixtures.sh index e050aa3c..222d814e 100755 --- a/utils/wasi-nn/download-neuralspeed-fixtures.sh +++ b/utils/wasi-nn/download-neuralspeed-fixtures.sh @@ -6,8 +6,8 @@ TODIR=$1 if [[ $# -eq 0 ]]; then TODIR=. fi -MODEL=orca_mini.gguf -FIXTURE=https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/blob/main/llama-2-7b-chat.Q4_0.gguf +MODEL=llama-2-7b-chat.Q4_0.gguf +FIXTURE=https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf if [ ! -d $TODIR ]; then mkdir $TODIR fi From 64fcd836047909885a33487a7303b677b7617c8c Mon Sep 17 00:00:00 2001 From: grorge Date: Sat, 13 Apr 2024 01:07:02 +0800 Subject: [PATCH 314/623] [WASI-NN] neural speed: fix update pythonlib and share library Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 15 ++++++++++++--- plugins/wasi_nn/neuralspeed.cpp | 4 +++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index aaa60b55..d6a4f951 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -198,9 +198,18 @@ if(WASMEDGE_BUILD_WASI_NN_RPC) endif() if(BACKEND STREQUAL "neuralspeed") - find_package(PythonLibs REQUIRED) - include_directories(${PYTHON_INCLUDE_DIRS}) - target_link_libraries(wasmedgePluginWasiNN PUBLIC ${PYTHON_LIBRARIES}) +find_package (Python3 COMPONENTS Interpreter Development) +if(Python3_FOUND) + target_compile_options(wasmedgePluginWasiNN PUBLIC -Xlinker -export-dynamic) + target_compile_definitions(wasmedgePluginWasiNN + PUBLIC PYTHON_LIB_PATH="${Python3_LIBRARIES}" + ) + include_directories(${Python3_INCLUDE_DIRS}) + target_link_libraries(wasmedgePluginWasiNN PUBLIC ${Python3_LIBRARIES}) + target_link_directories(wasmedgePluginWasiNN PUBLIC ${Python3_RUNTIME_LIBRARY_DIRS}) + elseif() + message(FATAL_ERROR "Can not find python3.") + endif() endif() include(WASINNDeps) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index c5b264e0..c72340ad 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -1,8 +1,10 @@ #include "neuralspeed.h" #include "wasinnenv.h" +#include #include namespace WasmEdge::Host::WASINN::NeuralSpeed { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED +void *SharedLib = dlopen(PYTHON_LIB_PATH, RTLD_GLOBAL | RTLD_NOW); Expect load(WASINN::WasiNNEnvironment &Env, Span> Builders, WASINN::Device, uint32_t &) noexcept { @@ -62,6 +64,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.NeuralSpeedModule = PyImport_Import(PyUnicode_FromString("neural_speed")); if (GraphRef.NeuralSpeedModule == nullptr) { + PyErr_Print(); spdlog::error( "[WASI-NN] neural speed backend: Can not find neural speed library."sv); return WASINN::ErrNo::RuntimeError; @@ -78,7 +81,6 @@ Expect load(WASINN::WasiNNEnvironment &Env, PyObject *LoadResult = PyObject_CallMethod( GraphRef.Model, "init_from_bin", "(ss)", "llama", ModelFilePath.c_str()); if (LoadResult == nullptr) { - PyErr_Print(); spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); return WASINN::ErrNo::RuntimeError; } From 92c973470fb86483bb67a9260dd03884975cdfaf Mon Sep 17 00:00:00 2001 From: grorge Date: Thu, 18 Apr 2024 23:38:11 +0800 Subject: [PATCH 315/623] [WASI-NN] neural speed: add setting model type Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 60 ++++++++++++++++++++++++++++++++- plugins/wasi_nn/neuralspeed.cpp | 33 +++++++++++++++--- 2 files changed, 87 insertions(+), 6 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index d6a4f951..25cfeaf8 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -107,6 +107,64 @@ if(BACKEND STREQUAL "ggml") wasmedge_setup_simdjson() endif() +if(BACKEND STREQUAL "neuralspeed") + find_package(simdjson QUIET) + if(simdjson_FOUND) + message(STATUS "SIMDJSON found") + else() + message(STATUS "Downloading SIMDJSON source") + include(FetchContent) + FetchContent_Declare( + simdjson + GIT_REPOSITORY https://github.com/simdjson/simdjson.git + GIT_TAG tags/v3.2.1 + GIT_SHALLOW TRUE) + + if(MSVC) + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + get_property( + compile_options + DIRECTORY + PROPERTY COMPILE_OPTIONS + ) + set_property( + DIRECTORY + APPEND + PROPERTY COMPILE_OPTIONS + -Wno-undef + -Wno-suggest-override + -Wno-documentation + -Wno-sign-conversion + -Wno-extra-semi-stmt + -Wno-old-style-cast + -Wno-error=unused-parameter + -Wno-error=unused-template + -Wno-conditional-uninitialized + -Wno-implicit-int-conversion + -Wno-shorten-64-to-32 + -Wno-range-loop-bind-reference + -Wno-format-nonliteral + -Wno-unused-exception-parameter + -Wno-unused-member-function + ) + unset(compile_options) + elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + set_property( + DIRECTORY + APPEND + PROPERTY COMPILE_OPTIONS + /wd4100 # unreferenced formal parameter + ) + endif() + endif() + + FetchContent_MakeAvailable(simdjson) + set_property(TARGET simdjson PROPERTY POSITION_INDEPENDENT_CODE ON) + + message(STATUS "Downloading SIMDJSON source -- done") + endif() +endif() + wasmedge_add_library(wasmedgePluginWasiNN SHARED wasinnenv.cpp @@ -200,7 +258,6 @@ endif() if(BACKEND STREQUAL "neuralspeed") find_package (Python3 COMPONENTS Interpreter Development) if(Python3_FOUND) - target_compile_options(wasmedgePluginWasiNN PUBLIC -Xlinker -export-dynamic) target_compile_definitions(wasmedgePluginWasiNN PUBLIC PYTHON_LIB_PATH="${Python3_LIBRARIES}" ) @@ -210,6 +267,7 @@ if(Python3_FOUND) elseif() message(FATAL_ERROR "Can not find python3.") endif() + target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) endif() include(WASINNDeps) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index c72340ad..300c8b6d 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -1,4 +1,5 @@ #include "neuralspeed.h" +#include "simdjson.h" #include "wasinnenv.h" #include #include @@ -18,6 +19,29 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Neural speed backend: Load."sv); } + + if (Builders.size() > 1) { + std::string Metadata = std::string( + reinterpret_cast(Builders[1].data()), Builders[1].size()); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + spdlog::error("[WASI-NN] neural speed backend: Parse metadata error"sv); + return ErrNo::InvalidEncoding; + } + if (Doc.at_key("model_type").error() == simdjson::SUCCESS) { + std::string_view model_type; + auto Err = Doc["model_type-log"].get().get(model_type); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the enable-log option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.model_type = model_type; + } + } + // Handle the model path. auto Weight = Builders[0]; const std::string BinModel(reinterpret_cast(Weight.data()), @@ -42,9 +66,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (!TempFile) { spdlog::error( "[WASI-NN] Neural speed: Failed to create the temporary file. " - "Currently, our workaround involves creating a temporary model " - "file named \"ggml-model.bin\" and passing this filename as a " - "parameter to the ggml llama library."sv); + "Currently, our workaround involves creating a temporary model."sv); Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } @@ -78,8 +100,9 @@ Expect load(WASINN::WasiNNEnvironment &Env, return WASINN::ErrNo::RuntimeError; } GraphRef.Model = PyObject_CallObject(GraphRef.ModelClass, NULL); - PyObject *LoadResult = PyObject_CallMethod( - GraphRef.Model, "init_from_bin", "(ss)", "llama", ModelFilePath.c_str()); + PyObject *LoadResult = + PyObject_CallMethod(GraphRef.Model, "init_from_bin", "(ss)", + GraphRef.model_type.c_str(), ModelFilePath.c_str()); if (LoadResult == nullptr) { spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); return WASINN::ErrNo::RuntimeError; From 9dc3541f83d5e592af198efc534566741b554b46 Mon Sep 17 00:00:00 2001 From: grorge Date: Thu, 18 Apr 2024 23:39:01 +0800 Subject: [PATCH 316/623] [WASI-NN] neural speed: fix python finalize deadlock Signed-off-by: grorge --- plugins/wasi_nn/neuralspeed.cpp | 12 ++++++++++++ plugins/wasi_nn/neuralspeed.h | 25 ++++--------------------- plugins/wasi_nn/wasinnfunc.cpp | 2 ++ 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index 300c8b6d..8aa08cd9 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -217,9 +217,11 @@ Expect compute(WasiNNEnvironment &Env, if (PyLong_Check(Num)) { InnerVec.push_back(PyLong_AsLong(Num)); } + Py_DECREF(Num); } CxtRef.Outputs = InnerVec; } + Py_DECREF(InnerList); } } Py_DECREF(Result); @@ -227,6 +229,12 @@ Expect compute(WasiNNEnvironment &Env, Py_DECREF(LongTensor); return WASINN::ErrNo::Success; } + +Expect finiSingle(WASINN::WasiNNEnvironment &, + uint32_t) noexcept { + Py_Finalize(); + return WASINN::ErrNo::Success; +} #else namespace { Expect reportBackendNotSupported() noexcept { @@ -255,5 +263,9 @@ Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { return reportBackendNotSupported(); } +Expect finiSingle(WASINN::WasiNNEnvironment &, + uint32_t) noexcept { + return reportBackendNotSupported(); +} #endif } // namespace WasmEdge::Host::WASINN::NeuralSpeed diff --git a/plugins/wasi_nn/neuralspeed.h b/plugins/wasi_nn/neuralspeed.h index cf6c34dd..60d34646 100644 --- a/plugins/wasi_nn/neuralspeed.h +++ b/plugins/wasi_nn/neuralspeed.h @@ -14,28 +14,9 @@ namespace WasmEdge::Host::WASINN::NeuralSpeed { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED struct Graph { bool EnableDebugLog = true; - // TODO add mutex - // static std::mutex py_mutex; + std::string model_type = "llama"; inline static int GraphNumber = 0; - Graph() noexcept { - // py_mutex.lock(); - if (GraphNumber == 0) { - Py_Initialize(); - } - GraphNumber++; - // py_mutex.unlock(); - } - ~Graph() noexcept { - Py_XDECREF(Model); - Py_XDECREF(ModelClass); - Py_XDECREF(NeuralSpeedModule); - // py_mutex.lock(); - if (GraphNumber == 1) { - Py_Finalize(); - } - GraphNumber--; - // py_mutex.unlock(); - } + Graph() noexcept { Py_Initialize(); } PyObject *Model; PyObject *NeuralSpeedModule; @@ -71,5 +52,7 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, uint32_t &BytesWritten) noexcept; Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; +Expect finiSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; } // namespace WasmEdge::Host::WASINN::NeuralSpeed \ No newline at end of file diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 8466eaa4..6372947d 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -569,6 +569,8 @@ WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, switch (Env.NNContext[Context].getBackend()) { case WASINN::Backend::GGML: return WASINN::GGML::finiSingle(Env, Context); + case WASINN::Backend::NeuralSpeed: + return WASINN::NeuralSpeed::finiSingle(Env, Context); default: spdlog::error( "[WASI-NN] fini_single: Only GGML backend supports compute_single."sv); From e27a78b5ce08d846c22d452f3c8572419d719772 Mon Sep 17 00:00:00 2001 From: grorge Date: Wed, 3 Apr 2024 22:31:42 +0800 Subject: [PATCH 317/623] [WASI-NN] neural speed: add successful test Signed-off-by: grorge --- test/plugins/wasi_nn/CMakeLists.txt | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 8674aaf1..a747e82d 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -70,6 +70,16 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) /wd4067 # unexpected tokens following preprocessor directive - expected a newline ) endif() + elseif(BACKEND STREQUAL "neuralspeed") + message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures") + execute_process( + COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-neuralspeed-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures + RESULT_VARIABLE DOWNLOAD_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE) + file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures/llama-2-7b-chat.Q4_0.gguf CHECKSUM_MODEL) + if(NOT CHECKSUM_MODEL STREQUAL "09ff9df730d4b5a8b60fcd89e97ee6f4") + message(FATAL_ERROR "llama-2-7b-chat.Q4_0.gguf downloaded with wrong md5") + endif() else() # Add the other backend test files fetching here. endif() From 8dbb350716b24576c5227e62942d644792464890 Mon Sep 17 00:00:00 2001 From: grorge Date: Wed, 24 Apr 2024 14:52:52 +0800 Subject: [PATCH 318/623] [WASI-NN] neural speed: update build neural speed cmake Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 6 +++--- test/plugins/wasi_nn/CMakeLists.txt | 13 +++++-------- utils/wasi-nn/download-neuralspeed-fixtures.sh | 17 ----------------- 3 files changed, 8 insertions(+), 28 deletions(-) delete mode 100755 utils/wasi-nn/download-neuralspeed-fixtures.sh diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 25cfeaf8..68756db5 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -259,11 +259,11 @@ if(BACKEND STREQUAL "neuralspeed") find_package (Python3 COMPONENTS Interpreter Development) if(Python3_FOUND) target_compile_definitions(wasmedgePluginWasiNN - PUBLIC PYTHON_LIB_PATH="${Python3_LIBRARIES}" + PRIVATE PYTHON_LIB_PATH="${Python3_LIBRARIES}" ) include_directories(${Python3_INCLUDE_DIRS}) - target_link_libraries(wasmedgePluginWasiNN PUBLIC ${Python3_LIBRARIES}) - target_link_directories(wasmedgePluginWasiNN PUBLIC ${Python3_RUNTIME_LIBRARY_DIRS}) + target_link_libraries(wasmedgePluginWasiNN PRIVATE ${Python3_LIBRARIES}) + target_link_directories(wasmedgePluginWasiNN PRIVATE ${Python3_RUNTIME_LIBRARY_DIRS}) elseif() message(FATAL_ERROR "Can not find python3.") endif() diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index a747e82d..ad6f85b8 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -72,14 +72,11 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) endif() elseif(BACKEND STREQUAL "neuralspeed") message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures") - execute_process( - COMMAND bash ${CMAKE_SOURCE_DIR}/utils/wasi-nn/download-neuralspeed-fixtures.sh ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures - RESULT_VARIABLE DOWNLOAD_ERROR - OUTPUT_STRIP_TRAILING_WHITESPACE) - file(MD5 ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures/llama-2-7b-chat.Q4_0.gguf CHECKSUM_MODEL) - if(NOT CHECKSUM_MODEL STREQUAL "09ff9df730d4b5a8b60fcd89e97ee6f4") - message(FATAL_ERROR "llama-2-7b-chat.Q4_0.gguf downloaded with wrong md5") - endif() + download( + https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures/llama-2-7b-chat.Q4_0.gguf + MD5=09ff9df730d4b5a8b60fcd89e97ee6f4 + ) else() # Add the other backend test files fetching here. endif() diff --git a/utils/wasi-nn/download-neuralspeed-fixtures.sh b/utils/wasi-nn/download-neuralspeed-fixtures.sh deleted file mode 100755 index 222d814e..00000000 --- a/utils/wasi-nn/download-neuralspeed-fixtures.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env bash -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2023 Second State INC - -TODIR=$1 -if [[ $# -eq 0 ]]; then - TODIR=. -fi -MODEL=llama-2-7b-chat.Q4_0.gguf -FIXTURE=https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf -if [ ! -d $TODIR ]; then - mkdir $TODIR -fi - -if [ ! -f $TODIR/$MODEL ]; then - curl -sL $FIXTURE -o $TODIR/$MODEL -fi From 04f578c0b2a89b344fc21bc1c07e1d08db95cce3 Mon Sep 17 00:00:00 2001 From: grorge Date: Wed, 24 Apr 2024 16:10:25 +0800 Subject: [PATCH 319/623] [WASI-NN] neural speed: add neural speed backend ad unload function Signed-off-by: grorge --- plugins/wasi_nn/neuralspeed.cpp | 49 ++++++++++++++++++++++++++++---- plugins/wasi_nn/neuralspeed.h | 13 +++++++-- plugins/wasi_nn/wasinnfunc.cpp | 7 +++-- test/plugins/wasi_nn/wasi_nn.cpp | 15 ++++++++++ 4 files changed, 73 insertions(+), 11 deletions(-) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index 8aa08cd9..5b5e1966 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -27,15 +27,15 @@ Expect load(WASINN::WasiNNEnvironment &Env, simdjson::dom::element Doc; auto ParseError = Parser.parse(Metadata).get(Doc); if (ParseError) { - spdlog::error("[WASI-NN] neural speed backend: Parse metadata error"sv); + spdlog::error("[WASI-NN] Neural speed backend: Parse metadata error"sv); return ErrNo::InvalidEncoding; } if (Doc.at_key("model_type").error() == simdjson::SUCCESS) { std::string_view model_type; - auto Err = Doc["model_type-log"].get().get(model_type); + auto Err = Doc["model_type"].get().get(model_type); if (Err) { spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the enable-log option."sv); + "[WASI-NN] Neural speed backend: Unable to retrieve the enable-log option."sv); return ErrNo::InvalidArgument; } GraphRef.model_type = model_type; @@ -83,6 +83,9 @@ Expect load(WASINN::WasiNNEnvironment &Env, } // Create Model class + if (!Py_IsInitialized()) { + Py_Initialize(); + } GraphRef.NeuralSpeedModule = PyImport_Import(PyUnicode_FromString("neural_speed")); if (GraphRef.NeuralSpeedModule == nullptr) { @@ -114,6 +117,11 @@ Expect load(WASINN::WasiNNEnvironment &Env, Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { + if (!Py_IsInitialized()) { + spdlog::info( + "[WASI-NN][Error] Neural speed backend: Model has been realse, please reload it."sv); + return WASINN::ErrNo::RuntimeError; + } Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); ContextId = Env.NNContext.size() - 1; return ErrNo::Success; @@ -123,6 +131,11 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t, const TensorData &Tensor) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (!Py_IsInitialized()) { + spdlog::info( + "[WASI-NN][Error] Neural speed backend: Model has been realse, please reload it."sv); + return WASINN::ErrNo::RuntimeError; + } if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Neural speed backend: setInput"sv); } @@ -159,6 +172,11 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, } Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + if (!Py_IsInitialized()) { + spdlog::info( + "[WASI-NN][Error] Neural speed backend: Model has been realse, please reload it."sv); + return WASINN::ErrNo::RuntimeError; + } auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); if (GraphRef.EnableDebugLog) { @@ -230,9 +248,25 @@ Expect compute(WasiNNEnvironment &Env, return WASINN::ErrNo::Success; } -Expect finiSingle(WASINN::WasiNNEnvironment &, - uint32_t) noexcept { - Py_Finalize(); +Expect finiSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + spdlog::info("[WASI-NN] neural speed backend: start finiSingle."sv); + if (NeuralSpeed::unload(Env, ContextId) == WASINN::ErrNo::Success) { + return WASINN::ErrNo::Success; + } + return WASINN::ErrNo::RuntimeError; +} +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + spdlog::info("[WASI-NN] neural speed backend: start unload."sv); + if (Py_IsInitialized()) { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + Py_XDECREF(GraphRef.Model); + Py_XDECREF(GraphRef.ModelClass); + Py_XDECREF(GraphRef.NeuralSpeedModule); + Py_Finalize(); + } return WASINN::ErrNo::Success; } #else @@ -267,5 +301,8 @@ Expect finiSingle(WASINN::WasiNNEnvironment &, uint32_t) noexcept { return reportBackendNotSupported(); } +Expect unload(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} #endif } // namespace WasmEdge::Host::WASINN::NeuralSpeed diff --git a/plugins/wasi_nn/neuralspeed.h b/plugins/wasi_nn/neuralspeed.h index 60d34646..db302ed6 100644 --- a/plugins/wasi_nn/neuralspeed.h +++ b/plugins/wasi_nn/neuralspeed.h @@ -17,7 +17,13 @@ struct Graph { std::string model_type = "llama"; inline static int GraphNumber = 0; Graph() noexcept { Py_Initialize(); } - + ~Graph() noexcept { + if (Py_IsInitialized()) { + Py_XDECREF(Model); + Py_XDECREF(ModelClass); + Py_XDECREF(NeuralSpeedModule); + } + } PyObject *Model; PyObject *NeuralSpeedModule; PyObject *ModelClass; @@ -53,6 +59,7 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; Expect finiSingle(WASINN::WasiNNEnvironment &Env, - uint32_t ContextId) noexcept; - + uint32_t GraphId) noexcept; +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; } // namespace WasmEdge::Host::WASINN::NeuralSpeed \ No newline at end of file diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 6372947d..5efbf525 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -573,7 +573,7 @@ WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, return WASINN::NeuralSpeed::finiSingle(Env, Context); default: spdlog::error( - "[WASI-NN] fini_single: Only GGML backend supports compute_single."sv); + "[WASI-NN] fini_single: Only GGML and NeuralSpeed backend supports compute_single."sv); return WASINN::ErrNo::InvalidArgument; } } @@ -600,8 +600,11 @@ Expect WasiNNUnload::bodyImpl(const Runtime::CallingFrame &Frame, switch (Env.NNGraph[GraphId].getBackend()) { case WASINN::Backend::GGML: return WASINN::GGML::unload(Env, GraphId); + case WASINN::Backend::NeuralSpeed: + return WASINN::NeuralSpeed::unload(Env, GraphId); default: - spdlog::error("[WASI-NN] unlaod: Only GGML backend supports unload."sv); + spdlog::error( + "[WASI-NN] unlaod: Only GGML and Neural speed backend supports unload."sv); return WASINN::ErrNo::InvalidArgument; } } diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 736139de..98c47fa7 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1811,6 +1811,12 @@ TEST(WasiNNTest, NeuralSpeedBackend) { EXPECT_TRUE(FuncInst->isHostFunction()); auto &HostFuncCompute = dynamic_cast(FuncInst->getHostFunc()); + // Get the function "unload". + FuncInst = NNMod->findFuncExports("unload"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncUnload = + dynamic_cast(FuncInst->getHostFunc()); // Neural Speed WASI-NN load tests. // Test: load -- load successfully. @@ -1885,5 +1891,14 @@ TEST(WasiNNTest, NeuralSpeedBackend) { auto BytesWritten = *MemInst.getPointer(BuilderPtr); EXPECT_GE(BytesWritten, 50); } + + // Neural Speed WASI-NN unload tests. + // Test: unload -- unload successfully. + { + EXPECT_TRUE(HostFuncUnload.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED \ No newline at end of file From fe323d30433b9ec4b642f0fc1222fa9ed7ccdfed Mon Sep 17 00:00:00 2001 From: grorge Date: Thu, 25 Apr 2024 12:18:42 +0800 Subject: [PATCH 320/623] [WASI-NN] neural speed: fix simdjson not fund Signed-off-by: grorge --- plugins/wasi_nn/neuralspeed.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index 5b5e1966..4cc59ebb 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -1,5 +1,7 @@ #include "neuralspeed.h" +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED #include "simdjson.h" +#endif #include "wasinnenv.h" #include #include From feefa543db7866b0cd20405edaff65edefe2b4b7 Mon Sep 17 00:00:00 2001 From: grorge Date: Fri, 26 Apr 2024 14:14:56 +0800 Subject: [PATCH 321/623] [WASI-NN] neural speed: fix remove finiSingle and set simdjson tags/v3.9.1 Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/neuralspeed.cpp | 24 ++++++------------------ plugins/wasi_nn/neuralspeed.h | 2 -- plugins/wasi_nn/wasinnfunc.cpp | 4 +--- 4 files changed, 8 insertions(+), 24 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 68756db5..99798fd5 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -117,7 +117,7 @@ if(BACKEND STREQUAL "neuralspeed") FetchContent_Declare( simdjson GIT_REPOSITORY https://github.com/simdjson/simdjson.git - GIT_TAG tags/v3.2.1 + GIT_TAG tags/v3.9.1 GIT_SHALLOW TRUE) if(MSVC) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index 4cc59ebb..4f932968 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -11,7 +11,6 @@ void *SharedLib = dlopen(PYTHON_LIB_PATH, RTLD_GLOBAL | RTLD_NOW); Expect load(WASINN::WasiNNEnvironment &Env, Span> Builders, WASINN::Device, uint32_t &) noexcept { - spdlog::info("[WASI-NN][Debug] Neural speed: test"sv); // Add a new graph. Env.NNGraph.emplace_back(Backend::NeuralSpeed); auto &GraphRef = Env.NNGraph.back().get(); @@ -37,7 +36,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, auto Err = Doc["model_type"].get().get(model_type); if (Err) { spdlog::error( - "[WASI-NN] Neural speed backend: Unable to retrieve the enable-log option."sv); + "[WASI-NN] Neural speed backend: Unable to retrieve the model_type option."sv); return ErrNo::InvalidArgument; } GraphRef.model_type = model_type; @@ -249,21 +248,14 @@ Expect compute(WasiNNEnvironment &Env, Py_DECREF(LongTensor); return WASINN::ErrNo::Success; } - -Expect finiSingle(WASINN::WasiNNEnvironment &Env, - uint32_t ContextId) noexcept { - spdlog::info("[WASI-NN] neural speed backend: start finiSingle."sv); - if (NeuralSpeed::unload(Env, ContextId) == WASINN::ErrNo::Success) { - return WASINN::ErrNo::Success; - } - return WASINN::ErrNo::RuntimeError; -} Expect unload(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept { - spdlog::info("[WASI-NN] neural speed backend: start unload."sv); + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Neural speed backend: start unload."sv); + } if (Py_IsInitialized()) { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); Py_XDECREF(GraphRef.Model); Py_XDECREF(GraphRef.ModelClass); Py_XDECREF(GraphRef.NeuralSpeedModule); @@ -299,10 +291,6 @@ Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { return reportBackendNotSupported(); } -Expect finiSingle(WASINN::WasiNNEnvironment &, - uint32_t) noexcept { - return reportBackendNotSupported(); -} Expect unload(WASINN::WasiNNEnvironment &, uint32_t) noexcept { return reportBackendNotSupported(); } diff --git a/plugins/wasi_nn/neuralspeed.h b/plugins/wasi_nn/neuralspeed.h index db302ed6..fd4ed2ec 100644 --- a/plugins/wasi_nn/neuralspeed.h +++ b/plugins/wasi_nn/neuralspeed.h @@ -58,8 +58,6 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, uint32_t &BytesWritten) noexcept; Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; -Expect finiSingle(WASINN::WasiNNEnvironment &Env, - uint32_t GraphId) noexcept; Expect unload(WASINN::WasiNNEnvironment &Env, uint32_t GraphId) noexcept; } // namespace WasmEdge::Host::WASINN::NeuralSpeed \ No newline at end of file diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 5efbf525..aa16fbc3 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -569,11 +569,9 @@ WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, switch (Env.NNContext[Context].getBackend()) { case WASINN::Backend::GGML: return WASINN::GGML::finiSingle(Env, Context); - case WASINN::Backend::NeuralSpeed: - return WASINN::NeuralSpeed::finiSingle(Env, Context); default: spdlog::error( - "[WASI-NN] fini_single: Only GGML and NeuralSpeed backend supports compute_single."sv); + "[WASI-NN] fini_single: Only GGML backend supports compute_single."sv); return WASINN::ErrNo::InvalidArgument; } } From f249443de993e5e5530339c6ccc776d225bafcf3 Mon Sep 17 00:00:00 2001 From: grorge Date: Wed, 1 May 2024 21:18:39 +0800 Subject: [PATCH 322/623] [WASI-NN] neural speed: add print run time Signed-off-by: grorge --- plugins/wasi_nn/neuralspeed.cpp | 27 +++++++++++++++++++++++++-- plugins/wasi_nn/neuralspeed.h | 2 ++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index 4f932968..7af04995 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -1,6 +1,7 @@ #include "neuralspeed.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED #include "simdjson.h" +#include #endif #include "wasinnenv.h" #include @@ -8,13 +9,30 @@ namespace WasmEdge::Host::WASINN::NeuralSpeed { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED void *SharedLib = dlopen(PYTHON_LIB_PATH, RTLD_GLOBAL | RTLD_NOW); +int64_t nowTime_ms() { + struct timespec Time; + clock_gettime(CLOCK_REALTIME, &Time); + return (int64_t)Time.tv_sec * 1000 + (int64_t)Time.tv_nsec / 1000000; +} +void printImformation(Graph &GraphRef, Context &CxtRef) { + spdlog::info( + "[WASI-NN][Info] Neural speed backend: Number of input tokens: {}"sv, + CxtRef.Inputs.size()); + spdlog::info( + "[WASI-NN][Info] Neural speed backend: Number of Output tokens: {}"sv, + CxtRef.Outputs.size()); + spdlog::info("[WASI-NN][Info] Neural speed backend: Load time: {}ms"sv, + GraphRef.LoadTime); + spdlog::info("[WASI-NN][Info] Neural speed backend: Compute time: {}ms "sv, + GraphRef.ComputeTime); +} Expect load(WASINN::WasiNNEnvironment &Env, Span> Builders, WASINN::Device, uint32_t &) noexcept { // Add a new graph. Env.NNGraph.emplace_back(Backend::NeuralSpeed); auto &GraphRef = Env.NNGraph.back().get(); - + GraphRef.LoadTime = nowTime_ms(); // Initialize the plugin parameters. GraphRef.EnableDebugLog = true; if (GraphRef.EnableDebugLog) { @@ -112,7 +130,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, return WASINN::ErrNo::RuntimeError; } Py_XDECREF(LoadResult); - + GraphRef.LoadTime = nowTime_ms() - GraphRef.LoadTime; return WASINN::ErrNo::Success; } @@ -180,6 +198,7 @@ Expect compute(WasiNNEnvironment &Env, } auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + GraphRef.ComputeTime = nowTime_ms(); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Neural speed backend: compute"sv); } @@ -246,6 +265,10 @@ Expect compute(WasiNNEnvironment &Env, Py_DECREF(Result); // Py_DECREF(GenerateArgs); Py_DECREF(LongTensor); + GraphRef.ComputeTime = nowTime_ms() - GraphRef.ComputeTime; + if (GraphRef.EnableDebugLog) { + printImformation(GraphRef, CxtRef); + } return WASINN::ErrNo::Success; } Expect unload(WASINN::WasiNNEnvironment &Env, diff --git a/plugins/wasi_nn/neuralspeed.h b/plugins/wasi_nn/neuralspeed.h index fd4ed2ec..bd06e2a4 100644 --- a/plugins/wasi_nn/neuralspeed.h +++ b/plugins/wasi_nn/neuralspeed.h @@ -27,6 +27,8 @@ struct Graph { PyObject *Model; PyObject *NeuralSpeedModule; PyObject *ModelClass; + int64_t LoadTime; + int64_t ComputeTime; }; struct Context { Context(size_t Gid, Graph &) noexcept : GraphId(Gid) {} From 0a732a3aa8aac56e76b8a10be9cb4fbd280d8fae Mon Sep 17 00:00:00 2001 From: grorge Date: Thu, 2 May 2024 14:44:20 +0800 Subject: [PATCH 323/623] [WASI-NN] neural speed: fix change time to chrono Signed-off-by: grorge --- plugins/wasi_nn/neuralspeed.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index 7af04995..f96e1fee 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -4,16 +4,12 @@ #include #endif #include "wasinnenv.h" +#include #include -#include + namespace WasmEdge::Host::WASINN::NeuralSpeed { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED void *SharedLib = dlopen(PYTHON_LIB_PATH, RTLD_GLOBAL | RTLD_NOW); -int64_t nowTime_ms() { - struct timespec Time; - clock_gettime(CLOCK_REALTIME, &Time); - return (int64_t)Time.tv_sec * 1000 + (int64_t)Time.tv_nsec / 1000000; -} void printImformation(Graph &GraphRef, Context &CxtRef) { spdlog::info( "[WASI-NN][Info] Neural speed backend: Number of input tokens: {}"sv, @@ -21,9 +17,9 @@ void printImformation(Graph &GraphRef, Context &CxtRef) { spdlog::info( "[WASI-NN][Info] Neural speed backend: Number of Output tokens: {}"sv, CxtRef.Outputs.size()); - spdlog::info("[WASI-NN][Info] Neural speed backend: Load time: {}ms"sv, + spdlog::info("[WASI-NN][Info] Neural speed backend: Load time: {} ms"sv, GraphRef.LoadTime); - spdlog::info("[WASI-NN][Info] Neural speed backend: Compute time: {}ms "sv, + spdlog::info("[WASI-NN][Info] Neural speed backend: Compute time: {} ms "sv, GraphRef.ComputeTime); } Expect load(WASINN::WasiNNEnvironment &Env, @@ -32,7 +28,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, // Add a new graph. Env.NNGraph.emplace_back(Backend::NeuralSpeed); auto &GraphRef = Env.NNGraph.back().get(); - GraphRef.LoadTime = nowTime_ms(); + const auto StartTime = std::chrono::steady_clock::now(); // Initialize the plugin parameters. GraphRef.EnableDebugLog = true; if (GraphRef.EnableDebugLog) { @@ -130,7 +126,9 @@ Expect load(WASINN::WasiNNEnvironment &Env, return WASINN::ErrNo::RuntimeError; } Py_XDECREF(LoadResult); - GraphRef.LoadTime = nowTime_ms() - GraphRef.LoadTime; + GraphRef.LoadTime = std::chrono::duration_cast( + std::chrono::steady_clock::now() - StartTime) + .count(); return WASINN::ErrNo::Success; } @@ -198,7 +196,7 @@ Expect compute(WasiNNEnvironment &Env, } auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - GraphRef.ComputeTime = nowTime_ms(); + const auto StartTime = std::chrono::steady_clock::now(); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Neural speed backend: compute"sv); } @@ -265,7 +263,9 @@ Expect compute(WasiNNEnvironment &Env, Py_DECREF(Result); // Py_DECREF(GenerateArgs); Py_DECREF(LongTensor); - GraphRef.ComputeTime = nowTime_ms() - GraphRef.ComputeTime; + GraphRef.ComputeTime = std::chrono::duration_cast( + std::chrono::steady_clock::now() - StartTime) + .count(); if (GraphRef.EnableDebugLog) { printImformation(GraphRef, CxtRef); } From 3457f9905de7ebc7aea4bfcb16c08bbffcce9725 Mon Sep 17 00:00:00 2001 From: grorge Date: Wed, 8 May 2024 16:51:20 +0800 Subject: [PATCH 324/623] [WASI-NN] neural speed: fix windows dlopen Signed-off-by: grorge --- plugins/wasi_nn/neuralspeed.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index f96e1fee..50f1a3a3 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -5,11 +5,18 @@ #endif #include "wasinnenv.h" #include +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__WIN32__) && \ + !defined(__TOS_WIN__) && !defined(__WINDOWS__) #include - +#endif namespace WasmEdge::Host::WASINN::NeuralSpeed { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED +#if defined(_WIN32) || defined(_WIN64) || defined(__WIN32__) || \ + defined(__TOS_WIN__) || defined(__WINDOWS__) +HINSTANCE SharedLib = LoadLibrary(PYTHON_LIB_PATH); +#else void *SharedLib = dlopen(PYTHON_LIB_PATH, RTLD_GLOBAL | RTLD_NOW); +#endif void printImformation(Graph &GraphRef, Context &CxtRef) { spdlog::info( "[WASI-NN][Info] Neural speed backend: Number of input tokens: {}"sv, From bf9aad1d4301dde9a1c84fae997324faa2cf1a08 Mon Sep 17 00:00:00 2001 From: grorge Date: Mon, 13 May 2024 14:21:48 +0800 Subject: [PATCH 325/623] [WASI-NN] neural speed: add more test Signed-off-by: grorge --- plugins/wasi_nn/neuralspeed.cpp | 12 +++- test/plugins/wasi_nn/wasi_nn.cpp | 104 ++++++++++++++++++++++++++++++- 2 files changed, 114 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index 50f1a3a3..ecfdbf8e 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -122,15 +122,25 @@ Expect load(WASINN::WasiNNEnvironment &Env, !PyCallable_Check(GraphRef.ModelClass)) { spdlog::error( "[WASI-NN] neural speed backend: Can not find Model class in neural speed."sv); + Py_XDECREF(GraphRef.Model); return WASINN::ErrNo::RuntimeError; } GraphRef.Model = PyObject_CallObject(GraphRef.ModelClass, NULL); + if (GraphRef.Model == nullptr) { + spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); + Py_XDECREF(GraphRef.ModelClass); + Py_XDECREF(GraphRef.NeuralSpeedModule); + return WASINN::ErrNo::InvalidArgument; + } PyObject *LoadResult = PyObject_CallMethod(GraphRef.Model, "init_from_bin", "(ss)", GraphRef.model_type.c_str(), ModelFilePath.c_str()); if (LoadResult == nullptr) { spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); - return WASINN::ErrNo::RuntimeError; + Py_XDECREF(GraphRef.Model); + Py_XDECREF(GraphRef.ModelClass); + Py_XDECREF(GraphRef.NeuralSpeedModule); + return WASINN::ErrNo::InvalidArgument; } Py_XDECREF(LoadResult); GraphRef.LoadTime = std::chrono::duration_cast( diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 98c47fa7..2df44ed1 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1775,7 +1775,7 @@ TEST(WasiNNTest, NeuralSpeedBackend) { uint32_t BuilderPtr = UINT32_C(0); uint32_t LoadEntryPtr = UINT32_C(0); uint32_t SetInputEntryPtr = UINT32_C(0); - // uint32_t OutBoundPtr = UINT32_C(61000 * 65536); + uint32_t OutBoundPtr = UINT32_C(61000 * 65536); uint32_t StorePtr = UINT32_C(65536); // Return value. @@ -1819,6 +1819,59 @@ TEST(WasiNNTest, NeuralSpeedBackend) { dynamic_cast(FuncInst->getHostFunc()); // Neural Speed WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::NeuralSpeed), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::NeuralSpeed), + UINT32_C(0), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), + static_cast(Backend::NeuralSpeed), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- Neural Speed model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); + { + EXPECT_TRUE( + HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::NeuralSpeed), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } // Test: load -- load successfully. BuilderPtr = LoadEntryPtr; writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); @@ -1838,6 +1891,16 @@ TEST(WasiNNTest, NeuralSpeedBackend) { } // Neural Speed WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: init_execution_context -- init second context. { EXPECT_TRUE(HostFuncInit.run( @@ -1858,6 +1921,16 @@ TEST(WasiNNTest, NeuralSpeedBackend) { writeBinaries(MemInst, TensorDim, StorePtr); writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } // Test: set_input -- set input successfully. { EXPECT_TRUE( @@ -1870,6 +1943,14 @@ TEST(WasiNNTest, NeuralSpeedBackend) { StorePtr += (TensorDim.size() * 4 + TensorData.size()); // Neural Speed WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } // Test: compute -- compute successfully. { EXPECT_TRUE(HostFuncCompute.run( @@ -1879,6 +1960,27 @@ TEST(WasiNNTest, NeuralSpeedBackend) { } // Neural Speed WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } // Test: get_output -- get output successfully. { EXPECT_TRUE(HostFuncGetOutput.run( From 56d6debaf4d3bcb1794b597b05017e606fae4bbf Mon Sep 17 00:00:00 2001 From: grorge Date: Sat, 1 Jun 2024 17:08:43 +0800 Subject: [PATCH 326/623] [WASI-NN] neural speed: change simdjson and small model Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 56 +------------------------ plugins/wasi_nn/neuralspeed.cpp | 63 +++++++++++++++-------------- test/plugins/wasi_nn/CMakeLists.txt | 6 +-- test/plugins/wasi_nn/wasi_nn.cpp | 17 +++++--- 4 files changed, 48 insertions(+), 94 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 99798fd5..c725bbee 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -108,61 +108,7 @@ if(BACKEND STREQUAL "ggml") endif() if(BACKEND STREQUAL "neuralspeed") - find_package(simdjson QUIET) - if(simdjson_FOUND) - message(STATUS "SIMDJSON found") - else() - message(STATUS "Downloading SIMDJSON source") - include(FetchContent) - FetchContent_Declare( - simdjson - GIT_REPOSITORY https://github.com/simdjson/simdjson.git - GIT_TAG tags/v3.9.1 - GIT_SHALLOW TRUE) - - if(MSVC) - if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - get_property( - compile_options - DIRECTORY - PROPERTY COMPILE_OPTIONS - ) - set_property( - DIRECTORY - APPEND - PROPERTY COMPILE_OPTIONS - -Wno-undef - -Wno-suggest-override - -Wno-documentation - -Wno-sign-conversion - -Wno-extra-semi-stmt - -Wno-old-style-cast - -Wno-error=unused-parameter - -Wno-error=unused-template - -Wno-conditional-uninitialized - -Wno-implicit-int-conversion - -Wno-shorten-64-to-32 - -Wno-range-loop-bind-reference - -Wno-format-nonliteral - -Wno-unused-exception-parameter - -Wno-unused-member-function - ) - unset(compile_options) - elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") - set_property( - DIRECTORY - APPEND - PROPERTY COMPILE_OPTIONS - /wd4100 # unreferenced formal parameter - ) - endif() - endif() - - FetchContent_MakeAvailable(simdjson) - set_property(TARGET simdjson PROPERTY POSITION_INDEPENDENT_CODE ON) - - message(STATUS "Downloading SIMDJSON source -- done") - endif() + wasmedge_setup_simdjson() endif() wasmedge_add_library(wasmedgePluginWasiNN diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index ecfdbf8e..30604a43 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -18,28 +18,26 @@ HINSTANCE SharedLib = LoadLibrary(PYTHON_LIB_PATH); void *SharedLib = dlopen(PYTHON_LIB_PATH, RTLD_GLOBAL | RTLD_NOW); #endif void printImformation(Graph &GraphRef, Context &CxtRef) { - spdlog::info( - "[WASI-NN][Info] Neural speed backend: Number of input tokens: {}"sv, - CxtRef.Inputs.size()); - spdlog::info( - "[WASI-NN][Info] Neural speed backend: Number of Output tokens: {}"sv, - CxtRef.Outputs.size()); - spdlog::info("[WASI-NN][Info] Neural speed backend: Load time: {} ms"sv, + spdlog::info("[WASI-NN] Neural speed backend: Number of input tokens: {}"sv, + CxtRef.Inputs.size()); + spdlog::info("[WASI-NN] Neural speed backend: Number of Output tokens: {}"sv, + CxtRef.Outputs.size()); + spdlog::info("[WASI-NN] Neural speed backend: Load time: {} ms"sv, GraphRef.LoadTime); - spdlog::info("[WASI-NN][Info] Neural speed backend: Compute time: {} ms "sv, + spdlog::info("[WASI-NN] Neural speed backend: Compute time: {} ms "sv, GraphRef.ComputeTime); } Expect load(WASINN::WasiNNEnvironment &Env, Span> Builders, WASINN::Device, - uint32_t &) noexcept { + uint32_t &GraphId) noexcept { // Add a new graph. Env.NNGraph.emplace_back(Backend::NeuralSpeed); auto &GraphRef = Env.NNGraph.back().get(); const auto StartTime = std::chrono::steady_clock::now(); // Initialize the plugin parameters. - GraphRef.EnableDebugLog = true; + GraphRef.EnableDebugLog = false; if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] Neural speed backend: Load."sv); + spdlog::info("[WASI-NN] Neural speed backend: Load."sv); } if (Builders.size() > 1) { @@ -50,6 +48,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, auto ParseError = Parser.parse(Metadata).get(Doc); if (ParseError) { spdlog::error("[WASI-NN] Neural speed backend: Parse metadata error"sv); + Env.NNGraph.pop_back(); return ErrNo::InvalidEncoding; } if (Doc.at_key("model_type").error() == simdjson::SUCCESS) { @@ -58,6 +57,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (Err) { spdlog::error( "[WASI-NN] Neural speed backend: Unable to retrieve the model_type option."sv); + Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } GraphRef.model_type = model_type; @@ -68,18 +68,16 @@ Expect load(WASINN::WasiNNEnvironment &Env, auto Weight = Builders[0]; const std::string BinModel(reinterpret_cast(Weight.data()), Weight.size()); - spdlog::info("[WASI-NN][Debug] Neural speed: BinModel: {}"sv, - BinModel.size()); + spdlog::info("[WASI-NN] Neural speed: BinModel: {}"sv, BinModel.size()); std::string ModelFilePath; if (BinModel.substr(0, 8) == "preload:") { ModelFilePath = BinModel.substr(8); } else { if (GraphRef.EnableDebugLog) { spdlog::info( - "[WASI-NN][Debug] Neural speed: Model path not found in nn-preload, " + "[WASI-NN] Neural speed: Model path not found in nn-preload, " "write model into a tmpfile."sv); } - // TODO: pass the model directly to ggml // Write neural speed model to file. ModelFilePath = "neural-speed-model.bin"sv; std::ofstream TempFile(ModelFilePath, std::ios::binary); @@ -96,12 +94,11 @@ Expect load(WASINN::WasiNNEnvironment &Env, TempFile.close(); if (GraphRef.EnableDebugLog) { spdlog::info( - "[WASI-NN][Debug] Neural speed: Write model into a tmpfile...Done"sv); + "[WASI-NN] Neural speed: Write model into a tmpfile...Done"sv); } } if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] Neural speed: Finished handling model path."sv); + spdlog::info("[WASI-NN] Neural speed: Finished handling model path."sv); } // Create Model class @@ -114,6 +111,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, PyErr_Print(); spdlog::error( "[WASI-NN] neural speed backend: Can not find neural speed library."sv); + Env.NNGraph.pop_back(); return WASINN::ErrNo::RuntimeError; } GraphRef.ModelClass = @@ -123,6 +121,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, spdlog::error( "[WASI-NN] neural speed backend: Can not find Model class in neural speed."sv); Py_XDECREF(GraphRef.Model); + Env.NNGraph.pop_back(); return WASINN::ErrNo::RuntimeError; } GraphRef.Model = PyObject_CallObject(GraphRef.ModelClass, NULL); @@ -130,6 +129,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); Py_XDECREF(GraphRef.ModelClass); Py_XDECREF(GraphRef.NeuralSpeedModule); + Env.NNGraph.pop_back(); return WASINN::ErrNo::InvalidArgument; } PyObject *LoadResult = @@ -140,20 +140,25 @@ Expect load(WASINN::WasiNNEnvironment &Env, Py_XDECREF(GraphRef.Model); Py_XDECREF(GraphRef.ModelClass); Py_XDECREF(GraphRef.NeuralSpeedModule); + Env.NNGraph.pop_back(); return WASINN::ErrNo::InvalidArgument; } Py_XDECREF(LoadResult); GraphRef.LoadTime = std::chrono::duration_cast( std::chrono::steady_clock::now() - StartTime) .count(); + + // Store the loaded graph. + GraphId = Env.NNGraph.size() - 1; + return WASINN::ErrNo::Success; } Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { if (!Py_IsInitialized()) { - spdlog::info( - "[WASI-NN][Error] Neural speed backend: Model has been realse, please reload it."sv); + spdlog::error( + "[WASI-NN] Neural speed backend: Model has been realse, please reload it."sv); return WASINN::ErrNo::RuntimeError; } Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); @@ -166,12 +171,12 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); if (!Py_IsInitialized()) { - spdlog::info( - "[WASI-NN][Error] Neural speed backend: Model has been realse, please reload it."sv); + spdlog::error( + "[WASI-NN] Neural speed backend: Model has been realse, please reload it."sv); return WASINN::ErrNo::RuntimeError; } if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] Neural speed backend: setInput"sv); + spdlog::info("[WASI-NN] Neural speed backend: setInput"sv); } // Set the input. @@ -196,7 +201,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] Neural speed backend: getOutput"sv); + spdlog::info("[WASI-NN] Neural speed backend: getOutput"sv); } std::string StringTmp(reinterpret_cast(CxtRef.Outputs.data()), CxtRef.Outputs.size() * sizeof(long long int)); @@ -207,15 +212,15 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { if (!Py_IsInitialized()) { - spdlog::info( - "[WASI-NN][Error] Neural speed backend: Model has been realse, please reload it."sv); + spdlog::error( + "[WASI-NN]Neural speed backend: Model has been realse, please reload it."sv); return WASINN::ErrNo::RuntimeError; } auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); const auto StartTime = std::chrono::steady_clock::now(); if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] Neural speed backend: compute"sv); + spdlog::info("[WASI-NN] Neural speed backend: compute"sv); } if (CxtRef.Inputs.size() == 0) { spdlog::error("[WASI-NN] Neural speed backend: Llama input is not set!"sv); @@ -249,7 +254,6 @@ Expect compute(WasiNNEnvironment &Env, "[WASI-NN] neural speed backend: Input transfer tensor failed."sv); return WASINN::ErrNo::InvalidArgument; } - // PyObject *GenerateArgs = PyTuple_Pack(1, LongTensor); PyObject *Result = PyObject_CallMethodObjArgs( GraphRef.Model, PyUnicode_FromString("generate"), LongTensor, NULL); if (Result == nullptr) { @@ -278,7 +282,6 @@ Expect compute(WasiNNEnvironment &Env, } } Py_DECREF(Result); - // Py_DECREF(GenerateArgs); Py_DECREF(LongTensor); GraphRef.ComputeTime = std::chrono::duration_cast( std::chrono::steady_clock::now() - StartTime) @@ -293,7 +296,7 @@ Expect unload(WASINN::WasiNNEnvironment &Env, auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] Neural speed backend: start unload."sv); + spdlog::info("[WASI-NN] Neural speed backend: start unload."sv); } if (Py_IsInitialized()) { Py_XDECREF(GraphRef.Model); diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index ad6f85b8..37cd0166 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -73,9 +73,9 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) elseif(BACKEND STREQUAL "neuralspeed") message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures") download( - https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf - ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures/llama-2-7b-chat.Q4_0.gguf - MD5=09ff9df730d4b5a8b60fcd89e97ee6f4 + https://huggingface.co/grorge123/phi-2-GPTQ/resolve/main/ne_phi_q_nf4_bestla_cfp32_g32.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures/ne_phi_q_nf4_bestla_cfp32_g32.bin + MD5=5e055b41f8cc1a42f26ff8742719ef1e ) else() # Add the other backend test files fetching here. diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 2df44ed1..0df306d3 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1763,19 +1763,19 @@ TEST(WasiNNTest, NeuralSpeedBackend) { WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Load the files. - std::vector Prompt = {1, 9038, 2501, 263, 931, 29892, - 727, 22856, 263, 2217, 7826, 29892}; + std::vector Prompt = {7454, 2402, 257, 640, 11, 612, + 11196, 257, 1310, 2576, 11}; std::string tmp(reinterpret_cast(Prompt.data()), Prompt.size() * sizeof(long long int)); std::vector TensorData(tmp.begin(), tmp.end()); std::vector WeightRead = readEntireFile( - "./wasinn_neural_speed_fixtures/llama-2-7b-chat.Q4_0.gguf"); + "./wasinn_neural_speed_fixtures/ne_phi_q_nf4_bestla_cfp32_g32.bin"); std::vector TensorDim{1}; uint32_t BuilderPtr = UINT32_C(0); uint32_t LoadEntryPtr = UINT32_C(0); uint32_t SetInputEntryPtr = UINT32_C(0); - uint32_t OutBoundPtr = UINT32_C(61000 * 65536); + uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); uint32_t StorePtr = UINT32_C(65536); // Return value. @@ -1873,15 +1873,20 @@ TEST(WasiNNTest, NeuralSpeedBackend) { static_cast(ErrNo::InvalidArgument)); } // Test: load -- load successfully. + std::string Config = "{\"model_type\":\"phi\"}"; + std::vector ConfigData(Config.begin(), Config.end()); BuilderPtr = LoadEntryPtr; writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr + WeightRead.size(), ConfigData.size(), + BuilderPtr); writeBinaries(MemInst, WeightRead, StorePtr); - StorePtr += WeightRead.size(); + writeBinaries(MemInst, ConfigData, StorePtr + WeightRead.size()); + StorePtr += WeightRead.size() + ConfigData.size(); { EXPECT_TRUE( HostFuncLoad.run(CallFrame, std::initializer_list{ - LoadEntryPtr, UINT32_C(1), + LoadEntryPtr, UINT32_C(2), static_cast(Backend::NeuralSpeed), UINT32_C(0), BuilderPtr}, Errno)); From b811ed3c457f499fa29f69d149e99b4cb7b8808f Mon Sep 17 00:00:00 2001 From: grorge Date: Mon, 3 Jun 2024 13:02:52 +0800 Subject: [PATCH 327/623] [WASI-NN] neural speed: Add build workflow Signed-off-by: grorge --- plugins/wasi_nn/neuralspeed.cpp | 6 +++--- utils/wasi-nn/install-neuralspeed.sh | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) create mode 100644 utils/wasi-nn/install-neuralspeed.sh diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index 30604a43..c932fe97 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -52,15 +52,15 @@ Expect load(WASINN::WasiNNEnvironment &Env, return ErrNo::InvalidEncoding; } if (Doc.at_key("model_type").error() == simdjson::SUCCESS) { - std::string_view model_type; - auto Err = Doc["model_type"].get().get(model_type); + std::string_view ModelType; + auto Err = Doc["model_type"].get().get(ModelType); if (Err) { spdlog::error( "[WASI-NN] Neural speed backend: Unable to retrieve the model_type option."sv); Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } - GraphRef.model_type = model_type; + GraphRef.model_type = ModelType; } } diff --git a/utils/wasi-nn/install-neuralspeed.sh b/utils/wasi-nn/install-neuralspeed.sh new file mode 100644 index 00000000..13d3f043 --- /dev/null +++ b/utils/wasi-nn/install-neuralspeed.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +set -e +echo "Installing Python library!" +apt update +apt install -y python3-dev python3-pip + +echo "Installing Neural Speed!" +wget https://raw.githubusercontent.com/intel/neural-speed/main/requirements.txt +pip install -r requirements.txt +pip install neural-speed==${NEURALSPEED_VERSION} \ No newline at end of file From e7cdf37a1cc0cd8d6c9917a871e87067a19aeefa Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 12 Jun 2024 19:16:10 +0800 Subject: [PATCH 328/623] [WASI-NN] ggml backend: bump to llama.cpp b3135 (#3469) Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index c725bbee..a39963f2 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -96,7 +96,7 @@ if(BACKEND STREQUAL "ggml") FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3075 + GIT_TAG b3135 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 9513255180022287bf9b39dc0b25eae3b381fb73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=AEm=20Ts=C3=BA-thu=C3=A0n?= Date: Tue, 18 Jun 2024 14:43:01 +0800 Subject: [PATCH 329/623] [component model] instantiate (#3218) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [component model] instantiate process drafts * section order will affect index space * skip validation process for now * nested core module instantiation * record all environments had been used * add type and core:type into index space (in component instance) * (FIXME) import section: the current implementation will treat a plugin as module instance instead a component instance, in the end we should have both * instantiate component sort part * export function instance out, so we can run it from CLI * canoncial lift without options * canoncial lower without options * unify component and module instance in VM * drafting inline exports concept Signed-off-by: Lîm Tsú-thuàn * [wasi] plugin wasi_http for testing * wasi-http implementation * drafting example for test * fix a bug in store manager, soft modules do it wrong * fix instance problem misc * update to follow review feedback * mark TODO * add prefix * reinterpret_cast * format stuff Signed-off-by: Lîm Tsú-thuàn * [plugin] prepare `PluginComponent` - canonical changes - extracting out the instantiate of each section to its own file - `alias` is correct, the bug is coming from the demo code Signed-off-by: Lîm Tsú-thuàn * [component model] drafting canonical ABI wrapping function 1. [component] create `Executor::lowering` function that gets a higher function instance, returns a lower function instance 2. add enum `TypeCode::String` 3. add `StrVariant` so host function can operate string 4. [plugin] update wasi-http implementation to use new stuffs Signed-off-by: Lîm Tsú-thuàn * [component model] canonical lift draft * pushing canonical lifting a bit, now call flow has no error came out * add header * cannot allocate ValVariant * `vector` of component instance use push_back * fix test build * remove const descriptor * avoid occurs casting which will fail on some platforms * remove `const` * proper initialize * anonymous namespace * lift lower function that occurs string correctly misc * mark TODO * use hard code version to allocate string in memory * abort execution if run into unimplemented option * add formatter for AST type * use formatter Signed-off-by: Lîm Tsú-thuàn * [component model] canonical: introduce comp instance for type lookup misc * complete defined value type formatter * format the Component function type Signed-off-by: Lîm Tsú-thuàn * [tool] let argument can direct pass string NOTE: some internal bugs still there, so the demo haven't fully work yet. Signed-off-by: Lîm Tsú-thuàn * [wasi-http] cleanup environment misc: remove unused code Signed-off-by: Lîm Tsú-thuàn * [component model] complete lifted function returns misc * missing `break` in cases * adjustment Signed-off-by: Lîm Tsú-thuàn * [wasi-poll] initial the plugin Signed-off-by: Lîm Tsú-thuàn * [internal] reduce the need of string instance, directly pass the string content misc * missing `break` * reduce interface of `StrVariant` Signed-off-by: Lîm Tsú-thuàn * [component model] complete core inline exports misc: format Signed-off-by: Lîm Tsú-thuàn * [misc] fix plugin Signed-off-by: Lîm Tsú-thuàn * [Misc] use visitor to avoid missing case Signed-off-by: Lîm Tsú-thuàn * [Misc] fix found copy happening Signed-off-by: Lîm Tsú-thuàn * [Misc] cleanup StrVariant Signed-off-by: Lîm Tsú-thuàn * [Plugin] ensure default option is set Signed-off-by: Lîm Tsú-thuàn * [Misc] keep unique_ptr mode Signed-off-by: Lîm Tsú-thuàn * [Fix] trivialize std::variant Signed-off-by: Lîm Tsú-thuàn * [Misc] adjust if style Signed-off-by: Lîm Tsú-thuàn * [Misc] Update for reviews Signed-off-by: Lîm Tsú-thuàn * [Misc] missing author Signed-off-by: Lîm Tsú-thuàn --------- Signed-off-by: Lîm Tsú-thuàn --- plugins/CMakeLists.txt | 8 +++++ plugins/wasi_http/CMakeLists.txt | 45 ++++++++++++++++++++++++++++ plugins/wasi_http/README.md | 3 ++ plugins/wasi_http/base.h | 24 +++++++++++++++ plugins/wasi_http/env.cpp | 42 ++++++++++++++++++++++++++ plugins/wasi_http/env.h | 19 ++++++++++++ plugins/wasi_http/func.cpp | 34 +++++++++++++++++++++ plugins/wasi_http/func.h | 25 ++++++++++++++++ plugins/wasi_http/module.cpp | 18 +++++++++++ plugins/wasi_http/module.h | 24 +++++++++++++++ plugins/wasi_nn/wasinnenv.cpp | 2 ++ plugins/wasi_poll/CMakeLists.txt | 39 ++++++++++++++++++++++++ plugins/wasi_poll/README.md | 3 ++ plugins/wasi_poll/base.h | 24 +++++++++++++++ plugins/wasi_poll/env.cpp | 42 ++++++++++++++++++++++++++ plugins/wasi_poll/env.h | 19 ++++++++++++ plugins/wasi_poll/func.cpp | 14 +++++++++ plugins/wasi_poll/func.h | 23 ++++++++++++++ plugins/wasi_poll/module.cpp | 17 +++++++++++ plugins/wasi_poll/module.h | 24 +++++++++++++++ test/plugins/unittest/testplugin.cpp | 2 ++ 21 files changed, 451 insertions(+) create mode 100644 plugins/wasi_http/CMakeLists.txt create mode 100644 plugins/wasi_http/README.md create mode 100644 plugins/wasi_http/base.h create mode 100644 plugins/wasi_http/env.cpp create mode 100644 plugins/wasi_http/env.h create mode 100644 plugins/wasi_http/func.cpp create mode 100644 plugins/wasi_http/func.h create mode 100644 plugins/wasi_http/module.cpp create mode 100644 plugins/wasi_http/module.h create mode 100644 plugins/wasi_poll/CMakeLists.txt create mode 100644 plugins/wasi_poll/README.md create mode 100644 plugins/wasi_poll/base.h create mode 100644 plugins/wasi_poll/env.cpp create mode 100644 plugins/wasi_poll/env.h create mode 100644 plugins/wasi_poll/func.cpp create mode 100644 plugins/wasi_poll/func.h create mode 100644 plugins/wasi_poll/module.cpp create mode 100644 plugins/wasi_poll/module.h diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 348c98a4..21173df2 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -1,6 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2022 Second State INC +if(WASMEDGE_PLUGIN_WASI_HTTP) + add_subdirectory(wasi_http) +endif() + +if(WASMEDGE_PLUGIN_WASI_POLL) + add_subdirectory(wasi_poll) +endif() + if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) add_subdirectory(wasi_nn) endif() diff --git a/plugins/wasi_http/CMakeLists.txt b/plugins/wasi_http/CMakeLists.txt new file mode 100644 index 00000000..d85d4e7f --- /dev/null +++ b/plugins/wasi_http/CMakeLists.txt @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +include(FetchContent) +FetchContent_Declare(cpr GIT_REPOSITORY https://github.com/libcpr/cpr.git + GIT_TAG 3b15fa82ea74739b574d705fea44959b58142eb8) +FetchContent_MakeAvailable(cpr) + +wasmedge_add_library(wasmedgePluginWasiHttp + SHARED + env.cpp + func.cpp + module.cpp +) + +target_compile_options(wasmedgePluginWasiHttp + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasiHttp + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/thirdparty +) + +target_link_libraries(wasmedgePluginWasiHttp + PUBLIC + cpr::cpr +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasiHttp + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasiHttp + PRIVATE + wasmedge_shared + ) +endif() + +install(TARGETS wasmedgePluginWasiHttp DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasi_http/README.md b/plugins/wasi_http/README.md new file mode 100644 index 00000000..0074d4b8 --- /dev/null +++ b/plugins/wasi_http/README.md @@ -0,0 +1,3 @@ +# wasi_http + +This is corresponding to [wasi-http preview2](https://github.com/WebAssembly/wasi-http), but a very beginning implementation, for now it's created to test component model. diff --git a/plugins/wasi_http/base.h b/plugins/wasi_http/base.h new file mode 100644 index 00000000..0c5fdbd2 --- /dev/null +++ b/plugins/wasi_http/base.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template class WasiHttp : public Runtime::HostFunction { +public: + WasiHttp(WasiHttpEnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WasiHttpEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/env.cpp b/plugins/wasi_http/env.cpp new file mode 100644 index 00000000..b0743d14 --- /dev/null +++ b/plugins/wasi_http/env.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "env.h" +#include "module.h" + +namespace WasmEdge { +namespace Host { + +WasiHttpEnvironment::WasiHttpEnvironment() noexcept {} + +namespace { + +Runtime::Instance::ComponentInstance * +create(const Plugin::PluginComponent::ComponentDescriptor *) noexcept { + return new WasiHttpModule(); +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasi_http", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 1, 0, 0}, + .ModuleCount = 0, + .ModuleDescriptions = {}, + .ComponentCount = 1, + .ComponentDescriptions = + (Plugin::PluginComponent::ComponentDescriptor[]){ + { + .Name = "wasi:http/test", + .Description = "", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/env.h b/plugins/wasi_http/env.h new file mode 100644 index 00000000..2df43a2b --- /dev/null +++ b/plugins/wasi_http/env.h @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC +#pragma once + +#include "plugin/plugin.h" + +#include +#include + +namespace WasmEdge { +namespace Host { + +class WasiHttpEnvironment { +public: + WasiHttpEnvironment() noexcept; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/func.cpp b/plugins/wasi_http/func.cpp new file mode 100644 index 00000000..f00b2121 --- /dev/null +++ b/plugins/wasi_http/func.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "func.h" +#include "common/defines.h" +#include "common/errcode.h" + +#include +#include +#include +#include + +namespace WasmEdge { +namespace Host { + +Expect WasiHttpPrint::body(const Runtime::CallingFrame &, + StrVariant Str) { + spdlog::info("[WASI-HTTP] print: {}", Str.getString()); + return {}; +} + +Expect WasiHttpGet::body(const Runtime::CallingFrame &, + StrVariant URI) { + const auto &S = URI.getString(); + spdlog::info("[WASI-HTTP] URI: {}", S); + cpr::Response Res = cpr::Get( + cpr::Url{S}, cpr::Authentication{"user", "pass", cpr::AuthMode::BASIC}); + spdlog::info("[WASI-HTTP] status: {}", Res.status_code); + + return StrVariant(std::move(Res.text)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/func.h b/plugins/wasi_http/func.h new file mode 100644 index 00000000..79ce1c83 --- /dev/null +++ b/plugins/wasi_http/func.h @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { + +class WasiHttpPrint : public WasiHttp { +public: + WasiHttpPrint(WasiHttpEnvironment &HostEnv) : WasiHttp(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, StrVariant Str); +}; + +class WasiHttpGet : public WasiHttp { +public: + WasiHttpGet(WasiHttpEnvironment &HostEnv) : WasiHttp(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, StrVariant URI); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/module.cpp b/plugins/wasi_http/module.cpp new file mode 100644 index 00000000..49fc4fb6 --- /dev/null +++ b/plugins/wasi_http/module.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiHttpModule::WasiHttpModule() : ComponentInstance("wasi:http/test") { + addHostFunc("http-get", std::make_unique(Env)); + addHostFunc("print", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_http/module.h b/plugins/wasi_http/module.h new file mode 100644 index 00000000..93bb8776 --- /dev/null +++ b/plugins/wasi_http/module.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiHttpModule : public Runtime::Instance::ComponentInstance { +public: + WasiHttpModule(); + + WasiHttpEnvironment &getEnv() { return Env; } + +private: + WasiHttpEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index fdadbb96..32163bed 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -159,6 +159,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ /* Version */ {0, 10, 1, 0}, /* ModuleCount */ 1, /* ModuleDescriptions */ MD, + /* ComponentCount */ 0, + /* ComponentDescriptions */ nullptr, /* AddOptions */ addOptions, }; } // namespace diff --git a/plugins/wasi_poll/CMakeLists.txt b/plugins/wasi_poll/CMakeLists.txt new file mode 100644 index 00000000..cfa00821 --- /dev/null +++ b/plugins/wasi_poll/CMakeLists.txt @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_library(wasmedgePluginWasiPoll + SHARED + env.cpp + func.cpp + module.cpp +) + +target_compile_options(wasmedgePluginWasiPoll + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasiPoll + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/thirdparty +) + +target_link_libraries(wasmedgePluginWasiPoll + PUBLIC +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasiPoll + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasiPoll + PRIVATE + wasmedge_shared + ) +endif() + +install(TARGETS wasmedgePluginWasiPoll DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasi_poll/README.md b/plugins/wasi_poll/README.md new file mode 100644 index 00000000..4eae44f4 --- /dev/null +++ b/plugins/wasi_poll/README.md @@ -0,0 +1,3 @@ +# wasi_poll + +This is corresponding to [wasi-poll preview2](https://github.com/WebAssembly/wasi-poll). diff --git a/plugins/wasi_poll/base.h b/plugins/wasi_poll/base.h new file mode 100644 index 00000000..e1c50bb7 --- /dev/null +++ b/plugins/wasi_poll/base.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "env.h" + +#include "common/errcode.h" +#include "runtime/hostfunc.h" + +namespace WasmEdge { +namespace Host { + +template class WasiPoll : public Runtime::HostFunction { +public: + WasiPoll(WasiPollEnvironment &HostEnv) + : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + WasiPollEnvironment &Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_poll/env.cpp b/plugins/wasi_poll/env.cpp new file mode 100644 index 00000000..90b5e632 --- /dev/null +++ b/plugins/wasi_poll/env.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "env.h" +#include "module.h" + +namespace WasmEdge { +namespace Host { + +WasiPollEnvironment::WasiPollEnvironment() noexcept {} + +namespace { + +Runtime::Instance::ComponentInstance * +create(const Plugin::PluginComponent::ComponentDescriptor *) noexcept { + return new WasiPollModule(); +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasi_poll", + .Description = "", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 1, 0, 0}, + .ModuleCount = 0, + .ModuleDescriptions = {}, + .ComponentCount = 1, + .ComponentDescriptions = + (Plugin::PluginComponent::ComponentDescriptor[]){ + { + .Name = "wasi:poll/poll", + .Description = "", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_poll/env.h b/plugins/wasi_poll/env.h new file mode 100644 index 00000000..e7d8ab1d --- /dev/null +++ b/plugins/wasi_poll/env.h @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC +#pragma once + +#include "plugin/plugin.h" + +#include +#include + +namespace WasmEdge { +namespace Host { + +class WasiPollEnvironment { +public: + WasiPollEnvironment() noexcept; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_poll/func.cpp b/plugins/wasi_poll/func.cpp new file mode 100644 index 00000000..732208dd --- /dev/null +++ b/plugins/wasi_poll/func.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "func.h" +#include "common/defines.h" +#include "common/errcode.h" + +namespace WasmEdge { +namespace Host { + +Expect Drop::body(const Runtime::CallingFrame &, Pollable) { return {}; } + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_poll/func.h b/plugins/wasi_poll/func.h new file mode 100644 index 00000000..a0c1c7d8 --- /dev/null +++ b/plugins/wasi_poll/func.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2023 Second State INC + +#pragma once + +#include "base.h" +#include "runtime/callingframe.h" + +namespace WasmEdge { +namespace Host { + +using Pollable = uint32_t; + +class Drop : public WasiPoll { +public: + Drop(WasiPollEnvironment &HostEnv) : WasiPoll(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, Pollable This); +}; + +// poll-oneoff: func(in: list) -> list + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_poll/module.cpp b/plugins/wasi_poll/module.cpp new file mode 100644 index 00000000..b523f867 --- /dev/null +++ b/plugins/wasi_poll/module.cpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "module.h" +#include "func.h" + +#include + +namespace WasmEdge { +namespace Host { + +WasiPollModule::WasiPollModule() : ComponentInstance("wasi:poll/poll") { + addHostFunc("drop-pollable", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_poll/module.h b/plugins/wasi_poll/module.h new file mode 100644 index 00000000..5c26930f --- /dev/null +++ b/plugins/wasi_poll/module.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "env.h" + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiPollModule : public Runtime::Instance::ComponentInstance { +public: + WasiPollModule(); + + WasiPollEnvironment &getEnv() { return Env; } + +private: + WasiPollEnvironment Env; +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/test/plugins/unittest/testplugin.cpp b/test/plugins/unittest/testplugin.cpp index 581dfb5f..2c5747b6 100644 --- a/test/plugins/unittest/testplugin.cpp +++ b/test/plugins/unittest/testplugin.cpp @@ -51,6 +51,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ /* Version */ {0, 10, 0, 0}, /* ModuleCount */ 1, /* ModuleDescriptions */ MD, + /* ComponentCount */ 0, + /* ComponentDescriptions */ nullptr, /* AddOptions */ addOptions, }; From 88fac9c43c26439749b937ebdb83de22b642a977 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 19 Jun 2024 13:25:36 +0800 Subject: [PATCH 330/623] [CMake] Refactor WASI-NN ggml for supporting multiple backends. Signed-off-by: YiYing He --- plugins/wasi_crypto/symmetric/options.cpp | 2 +- plugins/wasi_nn/CMakeLists.txt | 335 +++++++++++----------- plugins/wasi_nn/ggml.cpp | 65 +++-- plugins/wasi_nn/wasinnenv.cpp | 14 +- test/plugins/wasi_nn/CMakeLists.txt | 2 +- 5 files changed, 207 insertions(+), 211 deletions(-) diff --git a/plugins/wasi_crypto/symmetric/options.cpp b/plugins/wasi_crypto/symmetric/options.cpp index 22309e57..923d4cc8 100644 --- a/plugins/wasi_crypto/symmetric/options.cpp +++ b/plugins/wasi_crypto/symmetric/options.cpp @@ -18,7 +18,7 @@ constexpr std::array ValidNames{"context"sv, "salt"sv, std::string toLower(std::string_view Name) noexcept { std::string Ret{Name}; std::transform(Ret.begin(), Ret.end(), Ret.begin(), - [](char C) { return std::tolower(C); }); + [](char C) { return static_cast(std::tolower(C)); }); return Ret; } diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index a39963f2..35186b6e 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -1,116 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2022 Second State INC -string(TOLOWER ${WASMEDGE_PLUGIN_WASI_NN_BACKEND} BACKEND) -if(BACKEND STREQUAL "ggml") - # llama.cpp options - # Disable warnings and debug messages - set(LLAMA_ALL_WARNINGS OFF) - set(LLAMA_METAL_NDEBUG ON) - set(LLAMA_ACCELERATE OFF) - - if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_NATIVE) - message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_NATIVE(AVX/AVX2/FMA)") - set(LLAMA_NATIVE ON) - else() - set(LLAMA_NATIVE OFF) - endif() - - if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_CUBLAS) - message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_CUDA") - set(LLAMA_CUDA ON) - # We need to set GGML_USE_CUDA for clip from llava. - add_compile_definitions(GGML_USE_CUDA) - # If CUDA is ON, then OpenBLAS should be OFF. - set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS OFF) - else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_CUDA") - set(LLAMA_CUDA OFF) - endif() - - if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS) - message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_BLAS") - # Default use OpenBLAS - set(LLAMA_BLAS ON) - set(LLAMA_BLAS_VENDOR "OpenBLAS") - else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_BLAS") - set(LLAMA_BLAS OFF) - endif() - - if(NOT APPLE) - set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL OFF) - endif() - - if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) - message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_METAL") - set(LLAMA_METAL ON) - set(LLAMA_METAL_EMBED_LIBRARY ON) - else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_METAL") - set(LLAMA_METAL OFF) - endif() - - # setup llama.cpp - message(STATUS "Downloading llama.cpp source") - if(MSVC) - add_compile_options( - $<$:/utf-8> - $<$:-Xcompiler=/utf-8> - $<$:/wd4067> # unexpected tokens following preprocessor directive - expected a newline - $<$:/wd4101> # 'identifier' : unreferenced local variable - $<$:/wd4189> # 'identifier' : local variable is initialized but not referenced - $<$:/wd4244> # 'argument' : conversion from 'type1' to 'type2', possible loss of data - $<$:/wd4267> # 'var' : conversion from 'size_t' to 'type', possible loss of data - $<$:/wd4297> # 'function' : function assumed not to throw an exception but does - $<$:/wd4456> # declaration of 'identifier' hides previous local declaration - $<$:/wd4505> # 'function' : unreferenced local function has been removed - ) - endif() - if(CMAKE_CXX_COMPILER_ID MATCHES "GNU") - add_compile_options( - $<$:-Wno-exceptions> - -Wno-cast-align - -Wno-cast-qual - -Wno-float-conversion - -Wno-implicit-fallthrough - -Wno-unused-macros - -Wno-unused-function - -Wno-unused-variable - ) - elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - add_compile_options( - $<$:-Wno-exceptions> - -Wno-cast-align - -Wno-cast-qual - -Wno-disabled-macro-expansion - -Wno-float-conversion - -Wno-implicit-fallthrough - -Wno-implicit-float-conversion - -Wno-unused-macros - -Wno-unused-function - -Wno-unused-variable - ) - endif() - include(FetchContent) - FetchContent_Declare( - llama - GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3135 - GIT_SHALLOW FALSE - ) - FetchContent_MakeAvailable(llama) - set_property(TARGET ggml PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET common PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET llama PROPERTY POSITION_INDEPENDENT_CODE ON) - - wasmedge_setup_simdjson() -endif() - -if(BACKEND STREQUAL "neuralspeed") - wasmedge_setup_simdjson() -endif() - wasmedge_add_library(wasmedgePluginWasiNN SHARED wasinnenv.cpp @@ -125,13 +15,164 @@ wasmedge_add_library(wasmedgePluginWasiNN neuralspeed.cpp ) +foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) + string(TOLOWER ${BACKEND} BACKEND) + if(BACKEND STREQUAL "ggml") + wasmedge_setup_simdjson() + # llama.cpp options + # Disable warnings and debug messages + set(LLAMA_ALL_WARNINGS OFF) + set(LLAMA_METAL_NDEBUG ON) + set(LLAMA_ACCELERATE OFF) + + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_NATIVE) + message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_NATIVE(AVX/AVX2/FMA)") + set(LLAMA_NATIVE ON) + else() + set(LLAMA_NATIVE OFF) + endif() + + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_CUBLAS) + message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_CUDA") + set(LLAMA_CUDA ON) + # We need to set GGML_USE_CUDA for clip from llava. + add_compile_definitions(GGML_USE_CUDA) + # If CUDA is ON, then OpenBLAS should be OFF. + set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS OFF) + else() + message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_CUDA") + set(LLAMA_CUDA OFF) + endif() + + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS) + message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_BLAS") + # Default use OpenBLAS + set(LLAMA_BLAS ON) + set(LLAMA_BLAS_VENDOR "OpenBLAS") + else() + message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_BLAS") + set(LLAMA_BLAS OFF) + endif() + + if(NOT APPLE) + set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL OFF) + endif() + + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) + message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_METAL") + set(LLAMA_METAL ON) + set(LLAMA_METAL_EMBED_LIBRARY ON) + else() + message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_METAL") + set(LLAMA_METAL OFF) + endif() + + # setup llama.cpp + message(STATUS "Downloading llama.cpp source") + include(FetchContent) + FetchContent_Declare( + llama + GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git + GIT_TAG b3135 + GIT_SHALLOW FALSE + ) + FetchContent_MakeAvailable(llama) + set_property(TARGET ggml PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET common PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET llama PROPERTY POSITION_INDEPENDENT_CODE ON) + + # Setup llava from llama.cpp + wasmedge_add_library(llava OBJECT + ${llama_SOURCE_DIR}/examples/llava/clip.cpp + ${llama_SOURCE_DIR}/examples/llava/llava.cpp + ) + if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + target_compile_options(llava + PRIVATE + $<$:/utf-8> + $<$:-Xcompiler=/utf-8> + $<$:/wd4067> # unexpected tokens following preprocessor directive - expected a newline + $<$:/wd4101> # 'identifier' : unreferenced local variable + $<$:/wd4189> # 'identifier' : local variable is initialized but not referenced + $<$:/wd4244> # 'argument' : conversion from 'type1' to 'type2', possible loss of data + $<$:/wd4267> # 'var' : conversion from 'size_t' to 'type', possible loss of data + $<$:/wd4297> # 'function' : function assumed not to throw an exception but does + $<$:/wd4456> # declaration of 'identifier' hides previous local declaration + $<$:/wd4505> # 'function' : unreferenced local function has been removed + ) + elseif(CMAKE_CXX_COMPILER_ID MATCHES "GNU") + target_compile_options(llava + PRIVATE + $<$:-Wno-exceptions> + -Wno-cast-align + -Wno-cast-qual + -Wno-float-conversion + -Wno-implicit-fallthrough + -Wno-unused-macros + -Wno-unused-function + -Wno-unused-variable + ) + elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + target_compile_options(llava + PRIVATE + $<$:-Wno-exceptions> + -Wno-cast-align + -Wno-cast-qual + -Wno-disabled-macro-expansion + -Wno-float-conversion + -Wno-implicit-fallthrough + -Wno-implicit-float-conversion + -Wno-unused-macros + -Wno-unused-function + -Wno-unused-variable + ) + endif() + target_link_libraries(llava PRIVATE ggml llama) + target_include_directories(llava PUBLIC + ${llama_SOURCE_DIR} + ${llama_SOURCE_DIR}/common + ${llama_SOURCE_DIR}/examples/llava + ) + # Setup include and link from llama.cpp + target_include_directories(wasmedgePluginWasiNN PRIVATE + ${llama_SOURCE_DIR} + ${llama_SOURCE_DIR}examples/llava + ) + target_link_libraries(wasmedgePluginWasiNN PRIVATE + common + simdjson::simdjson + llava + ) + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) + add_custom_command( + TARGET wasmedgePluginWasiNN + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml-metal.metal ggml-metal.metal + COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml-common.h ggml-common.h + ) + endif() + elseif(BACKEND STREQUAL "neuralspeed") + wasmedge_setup_simdjson() + + find_package(Python3 COMPONENTS Interpreter Development) + if(Python3_FOUND) + target_compile_definitions(wasmedgePluginWasiNN + PRIVATE PYTHON_LIB_PATH="${Python3_LIBRARIES}" + ) + include_directories(${Python3_INCLUDE_DIRS}) + target_link_libraries(wasmedgePluginWasiNN PRIVATE ${Python3_LIBRARIES}) + target_link_directories(wasmedgePluginWasiNN PRIVATE ${Python3_RUNTIME_LIBRARY_DIRS}) + else() + message(FATAL_ERROR "Can not find python3.") + endif() + target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) + endif() +endforeach() + target_compile_options(wasmedgePluginWasiNN PUBLIC -DWASMEDGE_PLUGIN ) -if(WASMEDGE_BUILD_WASI_NN_RPC) - add_definitions(-DWASMEDGE_BUILD_WASI_NN_RPC) -endif() target_include_directories(wasmedgePluginWasiNN PUBLIC @@ -139,44 +180,15 @@ target_include_directories(wasmedgePluginWasiNN ${CMAKE_CURRENT_SOURCE_DIR} ) -if(BACKEND STREQUAL "ggml") - # Setup llava from llama.cpp - wasmedge_add_library(llava OBJECT - ${llama_SOURCE_DIR}/examples/llava/clip.cpp - ${llama_SOURCE_DIR}/examples/llava/llava.cpp - ) - if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") - target_compile_options(llava PRIVATE -Wno-error=unused-variable -Wno-error=unused-function) - endif() - target_link_libraries(llava PRIVATE ggml llama) - target_include_directories(llava PUBLIC - ${llama_SOURCE_DIR} - ${llama_SOURCE_DIR}/common - ${llama_SOURCE_DIR}/examples/llava - ) - # Setup include and link from llama.cpp - target_include_directories(wasmedgePluginWasiNN PRIVATE - ${llama_SOURCE_DIR} - ${llama_SOURCE_DIR}examples/llava +if(WASMEDGE_BUILD_WASI_NN_RPC) + add_definitions(-DWASMEDGE_BUILD_WASI_NN_RPC) + target_include_directories(wasmedgePluginWasiNN + SYSTEM BEFORE PUBLIC ${Protobuf_INCLUDE_DIR} ) - target_link_libraries(wasmedgePluginWasiNN PRIVATE - common - simdjson::simdjson - llava + target_link_libraries(wasmedgePluginWasiNN + PRIVATE + wasiNNRPC ) - if(MSVC) - target_compile_options(wasmedgePluginWasiNN PUBLIC - /wd4067 # unexpected tokens following preprocessor directive - expected a newline - ) - endif() - if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) - add_custom_command( - TARGET wasmedgePluginWasiNN - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml-metal.metal ggml-metal.metal - COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml-common.h ggml-common.h - ) - endif() endif() if(WASMEDGE_LINK_PLUGINS_STATIC) @@ -191,31 +203,6 @@ else() ) endif() -if(WASMEDGE_BUILD_WASI_NN_RPC) - target_include_directories(wasmedgePluginWasiNN - SYSTEM BEFORE PUBLIC ${Protobuf_INCLUDE_DIR} - ) - target_link_libraries(wasmedgePluginWasiNN - PRIVATE - wasiNNRPC - ) -endif() - -if(BACKEND STREQUAL "neuralspeed") -find_package (Python3 COMPONENTS Interpreter Development) -if(Python3_FOUND) - target_compile_definitions(wasmedgePluginWasiNN - PRIVATE PYTHON_LIB_PATH="${Python3_LIBRARIES}" - ) - include_directories(${Python3_INCLUDE_DIRS}) - target_link_libraries(wasmedgePluginWasiNN PRIVATE ${Python3_LIBRARIES}) - target_link_directories(wasmedgePluginWasiNN PRIVATE ${Python3_RUNTIME_LIBRARY_DIRS}) - elseif() - message(FATAL_ERROR "Can not find python3.") - endif() - target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) -endif() - include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 65e57512..c25b3b9a 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -89,8 +89,8 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // Get the current llama parameters. llama_model_params ModelParams = llama_model_default_params(); - ModelParams.n_gpu_layers = GraphRef.NGPULayers; - ModelParams.main_gpu = GraphRef.MainGPU; + ModelParams.n_gpu_layers = static_cast(GraphRef.NGPULayers); + ModelParams.main_gpu = static_cast(GraphRef.MainGPU); ModelParams.tensor_split = GraphRef.TensorSplit.data(); ModelParams.use_mmap = GraphRef.UseMMap; @@ -202,14 +202,14 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, SS >> TmpTensor; GraphRef.TensorSplit.push_back(TmpTensor); } - uint32_t NDevices = llama_max_devices(); + size_t NDevices = llama_max_devices(); if (GraphRef.TensorSplit.size() > NDevices) { spdlog::error( "[WASI-NN] GGML backend: Number of Tensor-Split is larger " "than MaxDevices, please reduce the size of tensor-split."sv); return ErrNo::InvalidArgument; } - for (uint32_t Idx = GraphRef.TensorSplit.size(); Idx < NDevices; Idx++) { + for (size_t Idx = GraphRef.TensorSplit.size(); Idx < NDevices; Idx++) { GraphRef.TensorSplit.push_back(0.0f); } } @@ -331,11 +331,11 @@ Expect setupGPTParam(Graph &GraphRef, gpt_params &GPTParams) { Expect setupContextParam(Graph &GraphRef, llama_context_params &ContextParams) { - ContextParams.n_ctx = GraphRef.CtxSize; - ContextParams.n_batch = GraphRef.BatchSize; - ContextParams.n_ubatch = GraphRef.UBatchSize; - ContextParams.n_threads = GraphRef.Threads; - ContextParams.n_threads_batch = GraphRef.Threads; + ContextParams.n_ctx = static_cast(GraphRef.CtxSize); + ContextParams.n_batch = static_cast(GraphRef.BatchSize); + ContextParams.n_ubatch = static_cast(GraphRef.UBatchSize); + ContextParams.n_threads = static_cast(GraphRef.Threads); + ContextParams.n_threads_batch = static_cast(GraphRef.Threads); ContextParams.embeddings = GraphRef.Embedding; return ErrNo::Success; } @@ -390,10 +390,10 @@ ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, } for (int I = 0; I < static_cast(Tokens.size()); - I += GraphRef.BatchSize) { + I += static_cast(GraphRef.BatchSize)) { int NEval = static_cast(Tokens.size()) - I; if (NEval > static_cast(GraphRef.BatchSize)) { - NEval = GraphRef.BatchSize; + NEval = static_cast(GraphRef.BatchSize); } // llama_batch_get_one(*token, n_tokens, position, sequence_id) // This will return batch for single sequence of tokens starting at @@ -529,8 +529,10 @@ Expect getEmbedding(WasiNNEnvironment &Env, } const int32_t NEmbd = llama_n_embd(GraphRef.LlamaModel); - struct llama_batch Batch = - llama_batch_init(GraphRef.BatchSize, /* embd */ 0, /* n_seq_max */ 1); + struct llama_batch Batch = llama_batch_init( + /* n_tokens_alloc */ static_cast(GraphRef.BatchSize), + /* embd */ 0, + /* n_seq_max */ 1); std::vector Embeddings(NEmbd); batchAddSeq(Batch, CxtRef.LlamaInputs, SequenceId); ReturnCode = batchDecode(LlamaContext, Batch, Embeddings.data(), NEmbd); @@ -645,7 +647,8 @@ loadBase64ImageFromPrompt(Graph &GraphRef, clip_ctx *ClipContext, } return llava_image_embed_make_with_bytes( - ClipContext, GraphRef.Threads, ImageBytes.data(), ImageBytes.size()); + ClipContext, static_cast(GraphRef.Threads), ImageBytes.data(), + static_cast(ImageBytes.size())); } ErrNo replaceBase64ImagePlaceholderInPrompt(std::string &Prompt) noexcept { @@ -779,8 +782,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize ggml model with model parameters. GraphRef.ModelFilePath = ModelFilePath; llama_model_params ModelParams = llama_model_default_params(); - ModelParams.n_gpu_layers = GraphRef.NGPULayers; - ModelParams.main_gpu = GraphRef.MainGPU; + ModelParams.n_gpu_layers = static_cast(GraphRef.NGPULayers); + ModelParams.main_gpu = static_cast(GraphRef.MainGPU); ModelParams.tensor_split = GraphRef.TensorSplit.data(); ModelParams.use_mmap = GraphRef.UseMMap; GraphRef.LlamaModel = @@ -796,7 +799,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, } // Store the loaded graph. - GraphId = Env.NNGraph.size() - 1; + GraphId = static_cast(Env.NNGraph.size() - 1); // Disable llama log by default. log_disable(); @@ -811,7 +814,7 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, spdlog::info("[WASI-NN][Debug] GGML backend: initExecCtx"sv); } Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); - ContextId = Env.NNContext.size() - 1; + ContextId = static_cast(Env.NNContext.size() - 1); if (GraphRef.EnableLog) { spdlog::info("[WASI-NN] GGML backend: llama_system_info: {}"sv, llama_print_system_info()); @@ -855,7 +858,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, { if (IsModelParamsUpdated) { llama_model_params ModelParams = llama_model_default_params(); - ModelParams.n_gpu_layers = GraphRef.NGPULayers; + ModelParams.n_gpu_layers = static_cast(GraphRef.NGPULayers); llama_free_model(GraphRef.LlamaModel); GraphRef.LlamaModel = llama_load_model_from_file( GraphRef.ModelFilePath.c_str(), ModelParams); @@ -957,7 +960,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } else { // Load the image from the file. CxtRef.LlavaImageEmbd = llava_image_embed_make_with_filename( - ClipContext, GraphRef.Threads, GraphRef.ImagePath.c_str()); + ClipContext, static_cast(GraphRef.Threads), + GraphRef.ImagePath.c_str()); } clip_free(ClipContext); if (CxtRef.LlavaImageEmbd == nullptr) { @@ -1037,7 +1041,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, return Res; } std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); - BytesWritten = Metadata.length(); + BytesWritten = static_cast(Metadata.length()); if (GraphRef.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] GGML backend: getOutput with Index {}...Done"sv, @@ -1048,7 +1052,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, std::copy_n(CxtRef.LlamaOutputs.data(), CxtRef.LlamaOutputs.length(), OutBuffer.data()); - BytesWritten = CxtRef.LlamaOutputs.length(); + BytesWritten = static_cast(CxtRef.LlamaOutputs.length()); if (GraphRef.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] GGML backend: getOutput with Index {}...Done"sv, @@ -1096,7 +1100,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { llama_sampling_init(GPTParams.sparams); // Prepare variables; int32_t NPast = 0; - int32_t NRemain = GraphRef.NPredict; + uint64_t NRemain = GraphRef.NPredict; // Get the context size. const uint64_t NCtx = llama_n_ctx(LlamaContext); // Minus 4 for the special tokens. (Such as , , ... tokens.) @@ -1139,8 +1143,9 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { "[WASI-NN] GGML backend: failed to evaluate input tokens before image."sv); return ReturnCode; } - bool EvalImageStatus = llava_eval_image_embed( - LlamaContext, CxtRef.LlavaImageEmbd, GraphRef.BatchSize, &NPast); + bool EvalImageStatus = + llava_eval_image_embed(LlamaContext, CxtRef.LlavaImageEmbd, + static_cast(GraphRef.BatchSize), &NPast); if (!EvalImageStatus) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); @@ -1248,7 +1253,7 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, return Res; } std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); - BytesWritten = Metadata.length(); + BytesWritten = static_cast(Metadata.length()); if (GraphRef.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] GGML backend: getOutputSingle with Index {}...Done"sv, @@ -1259,7 +1264,7 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, std::string LastToken = llama_token_to_piece(CxtRef.LlamaContext, CxtRef.LlamaOutputTokens.back()); std::copy_n(LastToken.data(), LastToken.length(), OutBuffer.data()); - BytesWritten = LastToken.length(); + BytesWritten = static_cast(LastToken.length()); if (GraphRef.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] GGML backend: getOutputSingle with Index {}...Done"sv, @@ -1352,9 +1357,9 @@ Expect computeSingle(WasiNNEnvironment &Env, "[WASI-NN] GGML backend: failed to evaluate input tokens before image."sv); return ReturnCode; } - bool EvalImageStatus = - llava_eval_image_embed(CxtRef.LlamaContext, CxtRef.LlavaImageEmbd, - GraphRef.BatchSize, &CxtRef.LlamaNPast); + bool EvalImageStatus = llava_eval_image_embed( + CxtRef.LlamaContext, CxtRef.LlavaImageEmbd, + static_cast(GraphRef.BatchSize), &CxtRef.LlamaNPast); if (!EvalImageStatus) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 32163bed..1d861320 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -90,22 +90,26 @@ WasiNNEnvironment::WasiNNEnvironment() noexcept { std::vector> Models; Models.reserve(Paths.size()); std::transform(Encode.begin(), Encode.end(), Encode.begin(), - [](unsigned char C) { return std::tolower(C); }); + [](unsigned char C) { + return static_cast(std::tolower(C)); + }); std::transform(Target.begin(), Target.end(), Target.begin(), - [](unsigned char C) { return std::tolower(C); }); + [](unsigned char C) { + return static_cast(std::tolower(C)); + }); auto Backend = BackendMap.find(Encode); auto Device = DeviceMap.find(Target); if (Backend != BackendMap.end() && Device != DeviceMap.end()) { - for (const std::string &Path : Paths) { + for (const std::string &P : Paths) { if (Backend->second == Backend::GGML) { // We write model path to model data to avoid file IO in llama.cpp. - std::string ModelPath = "preload:" + Path; + std::string ModelPath = "preload:" + P; std::vector ModelPathData(ModelPath.begin(), ModelPath.end()); Models.push_back(std::move(ModelPathData)); } else { std::vector Model; - if (load(std::filesystem::u8path(Path), Model)) { + if (load(std::filesystem::u8path(P), Model)) { Models.push_back(std::move(Model)); } } diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 37cd0166..10d45ba0 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -65,7 +65,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/orca_mini.gguf MD5=f895f00678bfbf89f70d6d25f20a7b5f ) - if(MSVC) + if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") target_compile_options(wasiNNTests PUBLIC /wd4067 # unexpected tokens following preprocessor directive - expected a newline ) From 826706ab7b8191d5cb0b3039d90cfdcd9ee70b77 Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 20 Jun 2024 11:14:46 +0800 Subject: [PATCH 331/623] [WASI-NN] ggml: bump to b3186 Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 35186b6e..dffb142e 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -73,7 +73,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3135 + GIT_TAG b3186 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 1e346176b724969af74d7d01079b59eb458f792c Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 20 Jun 2024 16:36:25 +0800 Subject: [PATCH 332/623] [WASI-NN] ggml: prevent copy in log callback. Signed-off-by: YiYing He --- plugins/wasi_nn/ggml.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index c25b3b9a..4bb467a3 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -24,7 +24,7 @@ namespace { void LlamaLogCallback(ggml_log_level LogLevel, const char *LogText, void *UserData) { - Graph GraphRef = *static_cast(UserData); + Graph &GraphRef = *reinterpret_cast(UserData); if (!GraphRef.EnableLog) { return; } From fb5a7b74b824ada4588917acf7042f81835b5d83 Mon Sep 17 00:00:00 2001 From: Elmira <95964498+PrabhuUdurg@users.noreply.github.com> Date: Mon, 24 Jun 2024 08:05:11 +0000 Subject: [PATCH 333/623] [CI] Upgraded the WASI-NN OpenVino backend to 2024.2.0 (#3494) * Upgraded the WASI-NN OpenVino backend to 2024.2.0 Signed-off-by: Yehor Pishyi * Fixes Signed-off-by: Pishyi Yehor * Fixes Signed-off-by: Pishyi Yehor --------- Signed-off-by: Yehor Pishyi Signed-off-by: Pishyi Yehor Co-authored-by: Pishyi Yehor --- utils/docker/Dockerfile.manylinux2014-build-plugins-deps | 4 ++-- utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps | 4 ++-- utils/wasi-nn/install-openvino.sh | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps index 359f4451..e2aab3ea 100644 --- a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps @@ -32,7 +32,7 @@ RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] ENV PKG_CONFIG_PATH /root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} ENV LD_LIBRARY_PATH /root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} -ENV OPENVINO_VERSION "2023.0.2" -ENV OPENVINO_YEAR "2023" +ENV OPENVINO_VERSION "2024.2.0" +ENV OPENVINO_YEAR "2024" RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps index 4be4bf97..eaa1cd42 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps @@ -33,7 +33,7 @@ RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] ENV PKG_CONFIG_PATH /root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} ENV LD_LIBRARY_PATH /root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} -ENV OPENVINO_VERSION "2023.0.2" -ENV OPENVINO_YEAR "2023" +ENV OPENVINO_VERSION "2024.2.0" +ENV OPENVINO_YEAR "2024" RUN yum clean all diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh index 57afdee8..fcc9e27b 100755 --- a/utils/wasi-nn/install-openvino.sh +++ b/utils/wasi-nn/install-openvino.sh @@ -3,10 +3,10 @@ # SPDX-FileCopyrightText: 2019-2022 Second State INC set -e -echo "Installing OpenVINO with version 2023.2.0" +echo "Installing OpenVINO with version 2024.2.0" wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB -echo "deb https://apt.repos.intel.com/openvino/2023 ubuntu20 main" | tee /etc/apt/sources.list.d/intel-openvino-2023.list +echo "deb https://apt.repos.intel.com/openvino/2024 ubuntu20 main" | tee /etc/apt/sources.list.d/intel-openvino-2024.list apt update -apt-get -y install openvino-2023.2.0 +apt-get -y install openvino-2024.2.0 ldconfig From 28dd42018ec4dfcaced333ada2f439b9479c7b7f Mon Sep 17 00:00:00 2001 From: grorge Date: Sun, 19 May 2024 23:01:21 +0800 Subject: [PATCH 334/623] [Plugin] stable diffusion: Fetch stable diffusion Signed-off-by: grorge --- plugins/CMakeLists.txt | 8 +++++ .../wasmedge_stablediffusion/CMakeLists.txt | 35 +++++++++++++++++++ .../stablediffusion.cpp | 22 ++++++++++++ .../stablediffusion.h | 0 4 files changed, 65 insertions(+) create mode 100644 plugins/wasmedge_stablediffusion/CMakeLists.txt create mode 100644 plugins/wasmedge_stablediffusion/stablediffusion.cpp create mode 100644 plugins/wasmedge_stablediffusion/stablediffusion.h diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 21173df2..47248f1e 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -66,6 +66,14 @@ if(WASMEDGE_PLUGIN_WASI_OCR) add_subdirectory(wasi_ocr) endif() +if(WASMEDGE_PLUGIN_STABLEDIFFUSION) + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasmedge_stablediffusion) + else() + message(WARNING "Only Linux platforms support wasm_bpf plug-in now.") + endif() +endif() + if(WASMEDGE_PLUGIN_OPENCVMINI) # Only Linux and MacOS support wasmedge_opencvmini now. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt new file mode 100644 index 00000000..d52f6526 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -0,0 +1,35 @@ + # setup stable diffusion +message(STATUS "Downloading stable diffusion source") +if (MSVC) + add_compile_options( + /wd4996 + /wd4456 + /wd4459 + /wd4100 + /wd4127 + /wd4701 + ) +else() + add_compile_options( + -Wno-unused-function + -Wno-unused-variable + -Wno-unused-parameter + -Wno-missing-field-initializers + ) +endif() +FetchContent_Declare( + stable-diffusion + GIT_REPOSITORY https://github.com/leejet/stable-diffusion.cpp.git + GIT_TAG master-1d2af5c + GIT_SHALLOW FALSE +) +FetchContent_MakeAvailable(stable-diffusion) +set_property(TARGET stable-diffusion PROPERTY POSITION_INDEPENDENT_CODE ON) + +wasmedge_add_library(wasmedgePluginStableDiffusion + SHARED + stablediffusion.cpp +) + +install(TARGETS wasmedgePluginStableDiffusion RUNTIME) +target_link_libraries(wasmedgePluginStableDiffusion PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/stablediffusion.cpp b/plugins/wasmedge_stablediffusion/stablediffusion.cpp new file mode 100644 index 00000000..80c1def8 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/stablediffusion.cpp @@ -0,0 +1,22 @@ +#include +#include +#include +#include +#include +#include +#include + +// #include "preprocessing.hpp" +#include "stable-diffusion.h" + +#define STB_IMAGE_IMPLEMENTATION +#define STB_IMAGE_STATIC +#include "stb_image.h" + +#define STB_IMAGE_WRITE_IMPLEMENTATION +#define STB_IMAGE_WRITE_STATIC +#include "stb_image_write.h" + +#define STB_IMAGE_RESIZE_IMPLEMENTATION +#define STB_IMAGE_RESIZE_STATIC +#include "stb_image_resize.h" \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/stablediffusion.h b/plugins/wasmedge_stablediffusion/stablediffusion.h new file mode 100644 index 00000000..e69de29b From 14b9e47fa0e0d21106c7e788cd3921bfe78fc8cb Mon Sep 17 00:00:00 2001 From: grorge Date: Mon, 20 May 2024 15:41:26 +0800 Subject: [PATCH 335/623] [Plugin] stable diffusion: Add plugin structure Signed-off-by: grorge --- .../wasmedge_stablediffusion/CMakeLists.txt | 23 +++++++++++++- plugins/wasmedge_stablediffusion/sb_base.h | 24 +++++++++++++++ plugins/wasmedge_stablediffusion/sb_env.cpp | 0 plugins/wasmedge_stablediffusion/sb_env.h | 30 +++++++++++++++++++ plugins/wasmedge_stablediffusion/sb_func.cpp | 14 +++++++++ plugins/wasmedge_stablediffusion/sb_func.h | 19 ++++++++++++ .../wasmedge_stablediffusion/sb_module.cpp | 12 ++++++++ plugins/wasmedge_stablediffusion/sb_module.h | 19 ++++++++++++ 8 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 plugins/wasmedge_stablediffusion/sb_base.h create mode 100644 plugins/wasmedge_stablediffusion/sb_env.cpp create mode 100644 plugins/wasmedge_stablediffusion/sb_env.h create mode 100644 plugins/wasmedge_stablediffusion/sb_func.cpp create mode 100644 plugins/wasmedge_stablediffusion/sb_func.h create mode 100644 plugins/wasmedge_stablediffusion/sb_module.cpp create mode 100644 plugins/wasmedge_stablediffusion/sb_module.h diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index d52f6526..d71a7547 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -28,8 +28,29 @@ set_property(TARGET stable-diffusion PROPERTY POSITION_INDEPENDENT_CODE ON) wasmedge_add_library(wasmedgePluginStableDiffusion SHARED + sb_env.cpp + sb_func.cpp + sb_module.cpp stablediffusion.cpp ) install(TARGETS wasmedgePluginStableDiffusion RUNTIME) -target_link_libraries(wasmedgePluginStableDiffusion PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) \ No newline at end of file +target_link_libraries(wasmedgePluginStableDiffusion PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) + + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginStableDiffusion + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginStableDiffusion + PRIVATE + wasmedge_shared + ) +endif() +target_include_directories(wasmedgePluginStableDiffusion + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) diff --git a/plugins/wasmedge_stablediffusion/sb_base.h b/plugins/wasmedge_stablediffusion/sb_base.h new file mode 100644 index 00000000..16fee573 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sb_base.h @@ -0,0 +1,24 @@ +#pragma once + +#include "common/errcode.h" +#include "runtime/hostfunc.h" +#include "sb_env.h" + +namespace WasmEdge { +namespace Host { +namespace StableDiffusion { + +template class Func : public Runtime::HostFunction { +public: + Func(SBEnviornment &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} + +protected: + static constexpr uint32_t castErrNo(StableDiffusion::ErrNo E) noexcept { + return static_cast(E); + } + SBEnviornment &Env; +}; + +} // namespace StableDiffusion +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sb_env.cpp b/plugins/wasmedge_stablediffusion/sb_env.cpp new file mode 100644 index 00000000..e69de29b diff --git a/plugins/wasmedge_stablediffusion/sb_env.h b/plugins/wasmedge_stablediffusion/sb_env.h new file mode 100644 index 00000000..59b91867 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sb_env.h @@ -0,0 +1,30 @@ +#pragma once + +#include "stable-diffusion.h" + +#include "plugin/plugin.h" +#include "stdint.h" + +namespace WasmEdge { +namespace Host { +namespace StableDiffusion { + +enum class ErrNo : uint32_t { + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + InvalidEncoding = 2, // Invalid encoding. + MissingMemory = 3, // Caller module is missing a memory export. + Busy = 4, // Device or resource busy. + RuntimeError = 5, // Runtime Error. +}; + +struct Context { + Context() noexcept {} + ~Context() noexcept {} +}; + +class SBEnviornment {}; + +} // namespace StableDiffusion +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sb_func.cpp b/plugins/wasmedge_stablediffusion/sb_func.cpp new file mode 100644 index 00000000..a174a2e2 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sb_func.cpp @@ -0,0 +1,14 @@ +#include "sb_func.h" +#include "common/spdlog.h" +#include "sb_env.h" + +namespace WasmEdge { +namespace Host { +namespace {} +Expect +SBCreateContext::bodyImpl(const Runtime::CallingFrame &Frame) { + return StableDiffusion::ErrNo::RuntimeError; +} + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sb_func.h b/plugins/wasmedge_stablediffusion/sb_func.h new file mode 100644 index 00000000..47cb2d4e --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sb_func.h @@ -0,0 +1,19 @@ +#pragma once + +#include "runtime/callingframe.h" +#include "sb_base.h" + +namespace WasmEdge { +namespace Host { +class SBCreateContext : public StableDiffusion::Func { +public: + SBCreateContext(StableDiffusion::SBEnviornment &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame) { + return bodyImpl(Frame).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame); +}; +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sb_module.cpp b/plugins/wasmedge_stablediffusion/sb_module.cpp new file mode 100644 index 00000000..cc598e64 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sb_module.cpp @@ -0,0 +1,12 @@ +#include "sb_module.h" +#include "sb_func.h" + +namespace WasmEdge { +namespace Host { + +SBModule::SBModule() : ModuleInstance("stable_diffusion") { + addHostFunc("create_context", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sb_module.h b/plugins/wasmedge_stablediffusion/sb_module.h new file mode 100644 index 00000000..f9cb607f --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sb_module.h @@ -0,0 +1,19 @@ +#pragma once + +#include "runtime/instance/module.h" +#include "sb_env.h" + +namespace WasmEdge { +namespace Host { + +class SBModule : public Runtime::Instance::ModuleInstance { +public: + SBModule(); + StableDiffusion::SBEnviornment &getEnv() { return Env; } + +private: + StableDiffusion::SBEnviornment Env; +}; + +} // namespace Host +} // namespace WasmEdge \ No newline at end of file From 5b39980ebaadbfdc60dad0bbff62ac24f4dcc73d Mon Sep 17 00:00:00 2001 From: grorge Date: Tue, 21 May 2024 17:04:03 +0800 Subject: [PATCH 336/623] [Plugin] stable diffusion: rename stable diffusion Signed-off-by: grorge --- .../wasmedge_stablediffusion/CMakeLists.txt | 6 +++--- plugins/wasmedge_stablediffusion/sb_func.cpp | 14 ------------- .../wasmedge_stablediffusion/sb_module.cpp | 12 ----------- plugins/wasmedge_stablediffusion/sb_module.h | 19 ------------------ .../{sb_base.h => sd_base.h} | 6 +++--- .../{sb_env.cpp => sd_env.cpp} | 0 .../{sb_env.h => sd_env.h} | 3 +-- plugins/wasmedge_stablediffusion/sd_func.cpp | 16 +++++++++++++++ .../{sb_func.h => sd_func.h} | 8 +++++--- .../wasmedge_stablediffusion/sd_module.cpp | 14 +++++++++++++ plugins/wasmedge_stablediffusion/sd_module.h | 20 +++++++++++++++++++ 11 files changed, 62 insertions(+), 56 deletions(-) delete mode 100644 plugins/wasmedge_stablediffusion/sb_func.cpp delete mode 100644 plugins/wasmedge_stablediffusion/sb_module.cpp delete mode 100644 plugins/wasmedge_stablediffusion/sb_module.h rename plugins/wasmedge_stablediffusion/{sb_base.h => sd_base.h} (80%) rename plugins/wasmedge_stablediffusion/{sb_env.cpp => sd_env.cpp} (100%) rename plugins/wasmedge_stablediffusion/{sb_env.h => sd_env.h} (93%) create mode 100644 plugins/wasmedge_stablediffusion/sd_func.cpp rename plugins/wasmedge_stablediffusion/{sb_func.h => sd_func.h} (63%) create mode 100644 plugins/wasmedge_stablediffusion/sd_module.cpp create mode 100644 plugins/wasmedge_stablediffusion/sd_module.h diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index d71a7547..219d1656 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -28,9 +28,9 @@ set_property(TARGET stable-diffusion PROPERTY POSITION_INDEPENDENT_CODE ON) wasmedge_add_library(wasmedgePluginStableDiffusion SHARED - sb_env.cpp - sb_func.cpp - sb_module.cpp + sd_env.cpp + sd_func.cpp + sd_module.cpp stablediffusion.cpp ) diff --git a/plugins/wasmedge_stablediffusion/sb_func.cpp b/plugins/wasmedge_stablediffusion/sb_func.cpp deleted file mode 100644 index a174a2e2..00000000 --- a/plugins/wasmedge_stablediffusion/sb_func.cpp +++ /dev/null @@ -1,14 +0,0 @@ -#include "sb_func.h" -#include "common/spdlog.h" -#include "sb_env.h" - -namespace WasmEdge { -namespace Host { -namespace {} -Expect -SBCreateContext::bodyImpl(const Runtime::CallingFrame &Frame) { - return StableDiffusion::ErrNo::RuntimeError; -} - -} // namespace Host -} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sb_module.cpp b/plugins/wasmedge_stablediffusion/sb_module.cpp deleted file mode 100644 index cc598e64..00000000 --- a/plugins/wasmedge_stablediffusion/sb_module.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "sb_module.h" -#include "sb_func.h" - -namespace WasmEdge { -namespace Host { - -SBModule::SBModule() : ModuleInstance("stable_diffusion") { - addHostFunc("create_context", std::make_unique(Env)); -} - -} // namespace Host -} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sb_module.h b/plugins/wasmedge_stablediffusion/sb_module.h deleted file mode 100644 index f9cb607f..00000000 --- a/plugins/wasmedge_stablediffusion/sb_module.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include "runtime/instance/module.h" -#include "sb_env.h" - -namespace WasmEdge { -namespace Host { - -class SBModule : public Runtime::Instance::ModuleInstance { -public: - SBModule(); - StableDiffusion::SBEnviornment &getEnv() { return Env; } - -private: - StableDiffusion::SBEnviornment Env; -}; - -} // namespace Host -} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sb_base.h b/plugins/wasmedge_stablediffusion/sd_base.h similarity index 80% rename from plugins/wasmedge_stablediffusion/sb_base.h rename to plugins/wasmedge_stablediffusion/sd_base.h index 16fee573..c89e6b2b 100644 --- a/plugins/wasmedge_stablediffusion/sb_base.h +++ b/plugins/wasmedge_stablediffusion/sd_base.h @@ -2,7 +2,7 @@ #include "common/errcode.h" #include "runtime/hostfunc.h" -#include "sb_env.h" +#include "sd_env.h" namespace WasmEdge { namespace Host { @@ -10,13 +10,13 @@ namespace StableDiffusion { template class Func : public Runtime::HostFunction { public: - Func(SBEnviornment &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} + Func(SDEnviornment &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} protected: static constexpr uint32_t castErrNo(StableDiffusion::ErrNo E) noexcept { return static_cast(E); } - SBEnviornment &Env; + SDEnviornment &Env; }; } // namespace StableDiffusion diff --git a/plugins/wasmedge_stablediffusion/sb_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp similarity index 100% rename from plugins/wasmedge_stablediffusion/sb_env.cpp rename to plugins/wasmedge_stablediffusion/sd_env.cpp diff --git a/plugins/wasmedge_stablediffusion/sb_env.h b/plugins/wasmedge_stablediffusion/sd_env.h similarity index 93% rename from plugins/wasmedge_stablediffusion/sb_env.h rename to plugins/wasmedge_stablediffusion/sd_env.h index 59b91867..58ceeac2 100644 --- a/plugins/wasmedge_stablediffusion/sb_env.h +++ b/plugins/wasmedge_stablediffusion/sd_env.h @@ -3,7 +3,6 @@ #include "stable-diffusion.h" #include "plugin/plugin.h" -#include "stdint.h" namespace WasmEdge { namespace Host { @@ -23,7 +22,7 @@ struct Context { ~Context() noexcept {} }; -class SBEnviornment {}; +class SDEnviornment {}; } // namespace StableDiffusion } // namespace Host diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp new file mode 100644 index 00000000..ca9a15e6 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -0,0 +1,16 @@ +#include "sd_func.h" +#include "common/spdlog.h" +#include "sd_env.h" + +namespace WasmEdge { +namespace Host { +namespace StableDiffusion { +namespace {} +Expect +SDCreateContext::bodyImpl(const Runtime::CallingFrame &) { + return ErrNo::RuntimeError; +} + +} // namespace StableDiffusion +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sb_func.h b/plugins/wasmedge_stablediffusion/sd_func.h similarity index 63% rename from plugins/wasmedge_stablediffusion/sb_func.h rename to plugins/wasmedge_stablediffusion/sd_func.h index 47cb2d4e..03664383 100644 --- a/plugins/wasmedge_stablediffusion/sb_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -1,13 +1,14 @@ #pragma once #include "runtime/callingframe.h" -#include "sb_base.h" +#include "sd_base.h" namespace WasmEdge { namespace Host { -class SBCreateContext : public StableDiffusion::Func { +namespace StableDiffusion { +class SDCreateContext : public StableDiffusion::Func { public: - SBCreateContext(StableDiffusion::SBEnviornment &HostEnv) : Func(HostEnv) {} + SDCreateContext(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} Expect body(const Runtime::CallingFrame &Frame) { return bodyImpl(Frame).map(castErrNo); } @@ -15,5 +16,6 @@ class SBCreateContext : public StableDiffusion::Func { private: Expect bodyImpl(const Runtime::CallingFrame &Frame); }; +} // namespace StableDiffusion } // namespace Host } // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sd_module.cpp b/plugins/wasmedge_stablediffusion/sd_module.cpp new file mode 100644 index 00000000..6e6cc836 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sd_module.cpp @@ -0,0 +1,14 @@ +#include "sd_module.h" +#include "sd_func.h" + +namespace WasmEdge { +namespace Host { +namespace StableDiffusion { + +SDModule::SDModule() : ModuleInstance("stable_diffusion") { + addHostFunc("create_context", std::make_unique(Env)); +} + +} // namespace StableDiffusion +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sd_module.h b/plugins/wasmedge_stablediffusion/sd_module.h new file mode 100644 index 00000000..219581e1 --- /dev/null +++ b/plugins/wasmedge_stablediffusion/sd_module.h @@ -0,0 +1,20 @@ +#pragma once + +#include "runtime/instance/module.h" +#include "sd_env.h" + +namespace WasmEdge { +namespace Host { +namespace StableDiffusion { +class SDModule : public Runtime::Instance::ModuleInstance { +public: + SDModule(); + StableDiffusion::SDEnviornment &getEnv() { return Env; } + +private: + StableDiffusion::SDEnviornment Env; +}; + +} // namespace StableDiffusion +} // namespace Host +} // namespace WasmEdge \ No newline at end of file From 29bb04d22536e399f512d57941a36b19abdf2871 Mon Sep 17 00:00:00 2001 From: grorge Date: Sun, 2 Jun 2024 16:24:32 +0800 Subject: [PATCH 337/623] [Plugin] stable diffusion: add create_context and text_to_image function Signed-off-by: grorge --- .../wasmedge_stablediffusion/CMakeLists.txt | 8 + plugins/wasmedge_stablediffusion/sd_env.cpp | 11 ++ plugins/wasmedge_stablediffusion/sd_env.h | 13 +- plugins/wasmedge_stablediffusion/sd_func.cpp | 186 +++++++++++++++++- plugins/wasmedge_stablediffusion/sd_func.h | 32 ++- .../wasmedge_stablediffusion/sd_module.cpp | 1 + .../stablediffusion.cpp | 22 --- .../stablediffusion.h | 0 8 files changed, 236 insertions(+), 37 deletions(-) delete mode 100644 plugins/wasmedge_stablediffusion/stablediffusion.cpp delete mode 100644 plugins/wasmedge_stablediffusion/stablediffusion.h diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 219d1656..4e0f06e8 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -25,6 +25,14 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(stable-diffusion) set_property(TARGET stable-diffusion PROPERTY POSITION_INDEPENDENT_CODE ON) +get_target_property(SD_DEPS stable-diffusion LINK_LIBRARIES) +foreach(dep ${SD_DEPS}) +if(TARGET ${dep}) + set_target_properties(${dep} PROPERTIES + POSITION_INDEPENDENT_CODE ON + ) + endif() +endforeach() wasmedge_add_library(wasmedgePluginStableDiffusion SHARED diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp index e69de29b..24d200b4 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.cpp +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -0,0 +1,11 @@ + +#include "sd_env.h" +namespace WasmEdge::Host::StableDiffusion { +uint32_t SDEnviornment::addContext(sd_ctx_t *sd_ctx) noexcept { + Contexts.push_back(sd_ctx); + return Contexts.size() - 1; +} +sd_ctx_t *SDEnviornment::getContext(const uint32_t Id) noexcept { + return Contexts[Id]; +} +} // namespace WasmEdge::Host::StableDiffusion \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sd_env.h b/plugins/wasmedge_stablediffusion/sd_env.h index 58ceeac2..b9908873 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.h +++ b/plugins/wasmedge_stablediffusion/sd_env.h @@ -3,6 +3,7 @@ #include "stable-diffusion.h" #include "plugin/plugin.h" +#include namespace WasmEdge { namespace Host { @@ -17,12 +18,14 @@ enum class ErrNo : uint32_t { RuntimeError = 5, // Runtime Error. }; -struct Context { - Context() noexcept {} - ~Context() noexcept {} -}; +class SDEnviornment { +public: + uint32_t addContext(sd_ctx_t *sd_ctx) noexcept; + sd_ctx_t *getContext(const uint32_t Id) noexcept; -class SDEnviornment {}; +private: + std::vector Contexts; +}; } // namespace StableDiffusion } // namespace Host diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index ca9a15e6..001b1b43 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -1,16 +1,194 @@ #include "sd_func.h" #include "common/spdlog.h" #include "sd_env.h" +#include "stable-diffusion.h" +#define STB_IMAGE_IMPLEMENTATION +#define STB_IMAGE_STATIC +#include "stb_image.h" + +#define STB_IMAGE_WRITE_IMPLEMENTATION +#define STB_IMAGE_WRITE_STATIC +#include "stb_image_write.h" namespace WasmEdge { namespace Host { namespace StableDiffusion { -namespace {} -Expect -SDCreateContext::bodyImpl(const Runtime::CallingFrame &) { - return ErrNo::RuntimeError; + +#define MEMINST_CHECK(Out, CallFrame, Index) \ + auto *Out = CallFrame.getMemoryByIndex(Index); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-StableDiffusion] Memory instance not found."sv); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define SESSION_CHECK(Out, SessionID, Message, ErrNo) \ + auto *Out = Env.getContext(SessionID); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-StableDiffusion] "sv Message); \ + return static_cast(ErrNo); \ + } + +#define MEM_SPAN_CHECK(OutSpan, MemInst, Type, BufPtr, BufLen, Message) \ + auto OutSpan = MemInst->getSpan(BufPtr, BufLen); \ + if (unlikely(OutSpan.size() != BufLen)) { \ + spdlog::error("[WasmEdge-StableDiffusion] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define MEM_SV_CHECK(OutSV, MemInst, BufPtr, BufLen, Message) \ + auto OutSV = MemInst->getStringView(BufPtr, BufLen); \ + if (unlikely(OutSV.size() != BufLen)) { \ + spdlog::error("[WasmEdge-StableDiffusion] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +#define MEM_PTR_CHECK(OutPtr, MemInst, Type, Offset, Message) \ + Type *OutPtr = MemInst->getPointer(Offset); \ + if (unlikely(OutPtr == nullptr)) { \ + spdlog::error("[WasmEdge-StableDiffusion] "sv Message); \ + return static_cast(ErrNo::MissingMemory); \ + } + +sd_image_t *ReadControlImage(Span ControlImage, + uint8_t *control_image_buffer, int Width, + int Height, bool canny_preprocess) { + sd_image_t *control_image = NULL; + int Channel = 0; + control_image_buffer = stbi_load_from_memory( + ControlImage.data(), ControlImage.size(), &Width, &Height, &Channel, 3); + if (control_image_buffer == NULL) { + spdlog::error("[StableDiffusion] Load image from control image failed."sv); + return nullptr; + } + control_image = new sd_image_t{(uint32_t)Width, (uint32_t)Height, 3, + control_image_buffer}; + if (canny_preprocess) { // apply preprocessor + control_image->data = preprocess_canny( + control_image->data, control_image->width, control_image->height, 0.08f, + 0.08f, 0.8f, 1.0f, false); + } + return control_image; +} + +Expect SDCreateContext::body( + const Runtime::CallingFrame &Frame, uint32_t ModelPathPtr, + uint32_t ModelPathLen, uint32_t VaePathPtr, uint32_t VaePathLen, + uint32_t TaesdPathPtr, uint32_t TaesdPathLen, uint32_t ControlNetPathPtr, + uint32_t ControlNetPathLen, uint32_t LoraModelDirPtr, + uint32_t LoraModelDirLen, uint32_t EmbedDirPtr, uint32_t EmbedDirLen, + uint32_t IdEmbedDirPtr, uint32_t IdEmbedDirLen, uint32_t VaeDecodeOnly, + uint32_t VaeTiling, int32_t NThreads, uint32_t Wtype, uint32_t RngType, + uint32_t Schedule, uint32_t ClipOnCpu, uint32_t ControlNetCpu, + uint32_t VaeOnCpu, uint32_t SessiontIdPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input model buffer. + MEM_SPAN_CHECK(ModelPathSpan, MemInst, char, ModelPathPtr, ModelPathLen, + "Failed when accessing the input model path memory."sv) + MEM_SPAN_CHECK(VaePathSpan, MemInst, char, VaePathPtr, VaePathLen, + "Failed when accessing the input vae path memory."sv) + MEM_SPAN_CHECK(ControlNetPathSpan, MemInst, char, ControlNetPathPtr, + ControlNetPathLen, + "Failed when accessing the input control net path memory."sv) + MEM_SPAN_CHECK(LoraModelDirSpan, MemInst, char, LoraModelDirPtr, + LoraModelDirLen, + "Failed when accessing the input lora model path memory."sv) + MEM_SPAN_CHECK(TaesdPathSpan, MemInst, char, TaesdPathPtr, TaesdPathLen, + "Failed when accessing the input taesd path memory."sv) + MEM_SPAN_CHECK(EmbedDirSpan, MemInst, char, EmbedDirPtr, EmbedDirLen, + "Failed when accessing the input embedded directory memory."sv) + MEM_SPAN_CHECK( + IdEmbedDirSpan, MemInst, char, IdEmbedDirPtr, IdEmbedDirLen, + "Failed when accessing the input id dembed directory memory."sv) + MEM_PTR_CHECK(SessionId, MemInst, uint32_t, SessiontIdPtr, + "Failed when accessing the return SessionID memory."sv) + + // Create context and import graph. + sd_ctx_t *sd_ctx = + new_sd_ctx(ModelPathSpan.data(), VaePathSpan.data(), TaesdPathSpan.data(), + ControlNetPathSpan.data(), LoraModelDirSpan.data(), + EmbedDirSpan.data(), IdEmbedDirSpan.data(), + static_cast(VaeDecodeOnly), static_cast(VaeTiling), + true, NThreads, sd_type_t(Wtype), rng_type_t(RngType), + schedule_t(Schedule), ClipOnCpu, ControlNetCpu, VaeOnCpu); + if (sd_ctx == NULL) { + spdlog::error("[WasmEdge-StableDiffusion] Failed to create context."); + return static_cast(ErrNo::InvalidArgument); + } + *SessionId = Env.addContext(sd_ctx); + + return static_cast(ErrNo::Success); } +Expect SDTextToImage::body( + const Runtime::CallingFrame &Frame, uint32_t ImagePtr, uint32_t ImageLen, + uint32_t SessionId, uint32_t Width, uint32_t Height, + uint32_t ControlImagePtr, uint32_t ControlImageLen, uint32_t PromptPtr, + uint32_t PromptLen, uint32_t NegativePromptPtr, uint32_t NegativePromptLen, + uint32_t ClipSkip, uint32_t CfgScale, uint32_t SampleMethod, + uint32_t SampleSteps, uint32_t Strength, uint32_t Seed, uint32_t BatchCount, + uint32_t ControlStrength, uint32_t StyleRatio, uint32_t NormalizeInput, + uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, + uint32_t canny_preprocess, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + // Check the input model buffer. + MEM_SPAN_CHECK(ImageSpan, MemInst, uint8_t, ImagePtr, ImageLen, + "Failed when accessing the input image memory."sv) + + MEM_SPAN_CHECK(PromptSpan, MemInst, char, PromptPtr, PromptLen, + "Failed when accessing the promp memory."sv) + MEM_SPAN_CHECK(NegativePromptSpan, MemInst, char, NegativePromptPtr, + NegativePromptLen, + "Failed when accessing the input negative prompt memory."sv) + MEM_SPAN_CHECK( + InputIdImagesPathSpan, MemInst, char, InputIdImagesPathPtr, + InputIdImagesPathLen, + "Failed when accessing the input input id images path memory."sv) + MEM_SPAN_CHECK(OutputBufferSpan, MemInst, uint8_t, OutBufferPtr, + OutBufferMaxSize, + "Failed when accessing the Output Buffer memory."sv) + MEM_PTR_CHECK(BytesWritten, MemInst, uint32_t, BytesWrittenPtr, + "Failed when accessing the return bytes written memory."sv) + sd_ctx_t *SDCtx = Env.getContext(SessionId); + + uint8_t *InputImageBuffer = nullptr; + uint8_t *control_image_buffer = nullptr; + int Channel = 0; + int ImageWidth = 0; + int ImageHeight = 0; + InputImageBuffer = + stbi_load_from_memory(ImageSpan.data(), ImageSpan.size(), &ImageWidth, + &ImageHeight, &Channel, 3); + sd_image_t InputImage = {Width, Height, 3, InputImageBuffer}; + sd_image_t *ControlImage; + if (ControlImageLen != 0) { + MEM_SPAN_CHECK(ControlImageSpan, MemInst, uint8_t, ControlImagePtr, + ControlImageLen, + "Failed when accessing the control image memory."sv) + ControlImage = ReadControlImage(ControlImageSpan, control_image_buffer, + Width, Height, canny_preprocess); + } + + sd_image_t *results; + results = img2img(SDCtx, InputImage, PromptSpan.data(), + NegativePromptSpan.data(), ClipSkip, CfgScale, Width, + Height, sample_method_t(SampleMethod), SampleSteps, + Strength, Seed, BatchCount, ControlImage, ControlStrength, + StyleRatio, NormalizeInput, InputIdImagesPathSpan.data()); + int len; + unsigned char *png = + stbi_write_png_to_mem((const unsigned char *)results, 0, results->width, + results->height, results->channel, &len, NULL); + *BytesWritten = len; + std::copy_n(png, *BytesWritten, OutputBufferSpan.data()); + free(results); + free(InputImageBuffer); + free(control_image_buffer); + return static_cast(ErrNo::RuntimeError); +} } // namespace StableDiffusion } // namespace Host } // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h index 03664383..62971202 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -9,12 +9,32 @@ namespace StableDiffusion { class SDCreateContext : public StableDiffusion::Func { public: SDCreateContext(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} - Expect body(const Runtime::CallingFrame &Frame) { - return bodyImpl(Frame).map(castErrNo); - } - -private: - Expect bodyImpl(const Runtime::CallingFrame &Frame); + Expect + body(const Runtime::CallingFrame &Frame, uint32_t ModelPathPtr, + uint32_t ModelPathLen, uint32_t VaePathPtr, uint32_t VaePathLen, + uint32_t TaesdPathPtr, uint32_t TaesdPathLen, uint32_t ControlNetPathPtr, + uint32_t ControlNetPathLen, uint32_t LoraModelDirPtr, + uint32_t LoraModelDirLen, uint32_t EmbedDirPtr, uint32_t EmbedDirLen, + uint32_t IdEmbedDirPtr, uint32_t IdEmbedDirLen, uint32_t VaeDecodeOnly, + uint32_t VaeTiling, int32_t NThreads, uint32_t Wtype, uint32_t RngType, + uint32_t Schedule, uint32_t ClipOnCpu, uint32_t ControlNetCpu, + uint32_t VaeOnCpu, uint32_t SessiontIdPtr); +}; +class SDTextToImage : public StableDiffusion::Func { +public: + SDTextToImage(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} + Expect + body(const Runtime::CallingFrame &Frame, uint32_t ImagePtr, uint32_t ImageLen, + uint32_t SessionId, uint32_t Width, uint32_t Height, + uint32_t ControlImagePtr, uint32_t ControlImageLen, uint32_t PromptPtr, + uint32_t PromptLen, uint32_t NegativePromptPtr, + uint32_t NegativePromptLen, uint32_t ClipSkip, uint32_t CfgScale, + uint32_t SampleMethod, uint32_t SampleSteps, uint32_t Strength, + uint32_t Seed, uint32_t BatchCount, uint32_t ControlStrength, + uint32_t StyleRatio, uint32_t NormalizeInput, + uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, + uint32_t canny_preprocess, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); }; } // namespace StableDiffusion } // namespace Host diff --git a/plugins/wasmedge_stablediffusion/sd_module.cpp b/plugins/wasmedge_stablediffusion/sd_module.cpp index 6e6cc836..84ebab4a 100644 --- a/plugins/wasmedge_stablediffusion/sd_module.cpp +++ b/plugins/wasmedge_stablediffusion/sd_module.cpp @@ -7,6 +7,7 @@ namespace StableDiffusion { SDModule::SDModule() : ModuleInstance("stable_diffusion") { addHostFunc("create_context", std::make_unique(Env)); + addHostFunc("text_to_image", std::make_unique(Env)); } } // namespace StableDiffusion diff --git a/plugins/wasmedge_stablediffusion/stablediffusion.cpp b/plugins/wasmedge_stablediffusion/stablediffusion.cpp deleted file mode 100644 index 80c1def8..00000000 --- a/plugins/wasmedge_stablediffusion/stablediffusion.cpp +++ /dev/null @@ -1,22 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -// #include "preprocessing.hpp" -#include "stable-diffusion.h" - -#define STB_IMAGE_IMPLEMENTATION -#define STB_IMAGE_STATIC -#include "stb_image.h" - -#define STB_IMAGE_WRITE_IMPLEMENTATION -#define STB_IMAGE_WRITE_STATIC -#include "stb_image_write.h" - -#define STB_IMAGE_RESIZE_IMPLEMENTATION -#define STB_IMAGE_RESIZE_STATIC -#include "stb_image_resize.h" \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/stablediffusion.h b/plugins/wasmedge_stablediffusion/stablediffusion.h deleted file mode 100644 index e69de29b..00000000 From 01cf813abdf07611e91e654ae7066e4e73bff674 Mon Sep 17 00:00:00 2001 From: grorge Date: Fri, 7 Jun 2024 20:02:15 +0800 Subject: [PATCH 338/623] [Plugin] stable diffusion: add basic gtest Signed-off-by: grorge --- .../wasmedge_stablediffusion/CMakeLists.txt | 26 ++++++----- plugins/wasmedge_stablediffusion/sd_env.cpp | 38 +++++++++++++++- .../wasmedge_stablediffusion/sd_module.cpp | 8 ++-- plugins/wasmedge_stablediffusion/sd_module.h | 2 - test/plugins/CMakeLists.txt | 6 +++ .../wasmedge_stablediffusion/CMakeLists.txt | 33 ++++++++++++++ .../wasmedge_stablediffusion.cpp | 44 +++++++++++++++++++ 7 files changed, 139 insertions(+), 18 deletions(-) create mode 100644 test/plugins/wasmedge_stablediffusion/CMakeLists.txt create mode 100644 test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 4e0f06e8..2544452b 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -39,26 +39,32 @@ wasmedge_add_library(wasmedgePluginStableDiffusion sd_env.cpp sd_func.cpp sd_module.cpp - stablediffusion.cpp ) -install(TARGETS wasmedgePluginStableDiffusion RUNTIME) target_link_libraries(wasmedgePluginStableDiffusion PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_options(wasmedgePluginStableDiffusion + PUBLIC + -DWASMEDGE_PLUGIN +) if(WASMEDGE_LINK_PLUGINS_STATIC) - target_link_libraries(wasmedgePluginStableDiffusion - PRIVATE - wasmedgeCAPI - ) +target_link_libraries(wasmedgePluginStableDiffusion + PRIVATE + wasmedgeCAPI +) + else() - target_link_libraries(wasmedgePluginStableDiffusion - PRIVATE - wasmedge_shared - ) +target_link_libraries(wasmedgePluginStableDiffusion + PRIVATE + wasmedge_shared +) endif() + target_include_directories(wasmedgePluginStableDiffusion PUBLIC $ ${CMAKE_CURRENT_SOURCE_DIR} ) + +install(TARGETS wasmedgePluginStableDiffusion DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp index 24d200b4..9f3d2222 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.cpp +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -1,6 +1,38 @@ #include "sd_env.h" -namespace WasmEdge::Host::StableDiffusion { +#include "sd_module.h" +namespace WasmEdge { +namespace Host { +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new SDModule; +} + +Plugin::Plugin::PluginDescriptor Descriptor{ + .Name = "wasmedge_stablediffusion", + .Description = "Stable Diffusion plug-in for WasmEdge.", + .APIVersion = Plugin::Plugin::CurrentAPIVersion, + .Version = {0, 1, 0, 0}, + .ModuleCount = 1, + .ModuleDescriptions = + (Plugin::PluginModule::ModuleDescriptor[]){ + { + .Name = "wasmedge_stablediffusion", + .Description = + "This module contains Stable Diffusion host functions.", + .Create = create, + }, + }, + .AddOptions = nullptr, +}; + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace + +namespace StableDiffusion { uint32_t SDEnviornment::addContext(sd_ctx_t *sd_ctx) noexcept { Contexts.push_back(sd_ctx); return Contexts.size() - 1; @@ -8,4 +40,6 @@ uint32_t SDEnviornment::addContext(sd_ctx_t *sd_ctx) noexcept { sd_ctx_t *SDEnviornment::getContext(const uint32_t Id) noexcept { return Contexts[Id]; } -} // namespace WasmEdge::Host::StableDiffusion \ No newline at end of file +} // namespace StableDiffusion +} // namespace Host +} // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sd_module.cpp b/plugins/wasmedge_stablediffusion/sd_module.cpp index 84ebab4a..14a8ab5b 100644 --- a/plugins/wasmedge_stablediffusion/sd_module.cpp +++ b/plugins/wasmedge_stablediffusion/sd_module.cpp @@ -3,13 +3,13 @@ namespace WasmEdge { namespace Host { -namespace StableDiffusion { SDModule::SDModule() : ModuleInstance("stable_diffusion") { - addHostFunc("create_context", std::make_unique(Env)); - addHostFunc("text_to_image", std::make_unique(Env)); + addHostFunc("create_context", + std::make_unique(Env)); + addHostFunc("text_to_image", + std::make_unique(Env)); } -} // namespace StableDiffusion } // namespace Host } // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sd_module.h b/plugins/wasmedge_stablediffusion/sd_module.h index 219581e1..8088355b 100644 --- a/plugins/wasmedge_stablediffusion/sd_module.h +++ b/plugins/wasmedge_stablediffusion/sd_module.h @@ -5,7 +5,6 @@ namespace WasmEdge { namespace Host { -namespace StableDiffusion { class SDModule : public Runtime::Instance::ModuleInstance { public: SDModule(); @@ -15,6 +14,5 @@ class SDModule : public Runtime::Instance::ModuleInstance { StableDiffusion::SDEnviornment Env; }; -} // namespace StableDiffusion } // namespace Host } // namespace WasmEdge \ No newline at end of file diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index cdda1155..30da9585 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -53,6 +53,12 @@ if(WASMEDGE_PLUGIN_WASM_BPF) endif() endif() +if(WASMEDGE_PLUGIN_STABLEDIFFUSION) + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasmedge_stablediffusion) + endif() +endif() + if(WASMEDGE_PLUGIN_WASI_LOGGING) add_subdirectory(wasi_logging) endif() diff --git a/test/plugins/wasmedge_stablediffusion/CMakeLists.txt b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt new file mode 100644 index 00000000..13ad7a74 --- /dev/null +++ b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -0,0 +1,33 @@ +wasmedge_add_executable(wasmedgeStableDiffusionTests + wasmedge_stablediffusion.cpp +) + +add_dependencies(wasmedgeStableDiffusionTests + wasmedgePluginStableDiffusion +) + +target_include_directories(wasmedgeStableDiffusionTests + PUBLIC + $ + $ +) + +target_link_libraries(wasmedgeStableDiffusionTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) + +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgeStableDiffusionTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgeStableDiffusionTests + PRIVATE + wasmedge_shared + ) +endif() + +add_test(wasmedgeStableDiffusionTests wasmedgeStableDiffusionTests) \ No newline at end of file diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp new file mode 100644 index 00000000..12f64b53 --- /dev/null +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -0,0 +1,44 @@ +#include "common/defines.h" +#include "runtime/instance/module.h" +#include "sd_func.h" +#include "sd_module.h" + +#include +#include +#include +#include +#include +#include + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + bool LoadState = WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_stablediffusion/" WASMEDGE_LIB_PREFIX + "wasmedgePluginStableDiffusion" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = + WasmEdge::Plugin::Plugin::find("wasmedge_stablediffusion"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_stablediffusion"sv)) { + return Module->create().release(); + } + } + return nullptr; +} +} // namespace + +// TODO: unit tests for every functions. + +TEST(WasmEdgeImageTest, Module) { + // Create the wasmedge_image module instance. + auto *SBMod = dynamic_cast(createModule()); + EXPECT_FALSE(SBMod == nullptr); + EXPECT_EQ(SBMod->getFuncExportNum(), 2U); + EXPECT_NE(SBMod->findFuncExports("create_context"), nullptr); + EXPECT_NE(SBMod->findFuncExports("text_to_image"), nullptr); + delete SBMod; +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 6de4d9f4b68d5f2f50c15e611c86959464ade098 Mon Sep 17 00:00:00 2001 From: grorge Date: Tue, 11 Jun 2024 13:02:48 +0800 Subject: [PATCH 339/623] [Plugin] stable diffusion: refactor variable name Signed-off-by: grorge --- plugins/wasmedge_stablediffusion/sd_env.cpp | 4 +- plugins/wasmedge_stablediffusion/sd_env.h | 2 +- plugins/wasmedge_stablediffusion/sd_func.cpp | 71 ++++++++++---------- 3 files changed, 39 insertions(+), 38 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp index 9f3d2222..8b0ab85e 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.cpp +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -33,8 +33,8 @@ EXPORT_GET_DESCRIPTOR(Descriptor) } // namespace namespace StableDiffusion { -uint32_t SDEnviornment::addContext(sd_ctx_t *sd_ctx) noexcept { - Contexts.push_back(sd_ctx); +uint32_t SDEnviornment::addContext(sd_ctx_t *Ctx) noexcept { + Contexts.push_back(Ctx); return Contexts.size() - 1; } sd_ctx_t *SDEnviornment::getContext(const uint32_t Id) noexcept { diff --git a/plugins/wasmedge_stablediffusion/sd_env.h b/plugins/wasmedge_stablediffusion/sd_env.h index b9908873..f7cc23b5 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.h +++ b/plugins/wasmedge_stablediffusion/sd_env.h @@ -20,7 +20,7 @@ enum class ErrNo : uint32_t { class SDEnviornment { public: - uint32_t addContext(sd_ctx_t *sd_ctx) noexcept; + uint32_t addContext(sd_ctx_t *Ctx) noexcept; sd_ctx_t *getContext(const uint32_t Id) noexcept; private: diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index 001b1b43..c3a5e6bd 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -50,24 +50,25 @@ namespace StableDiffusion { } sd_image_t *ReadControlImage(Span ControlImage, - uint8_t *control_image_buffer, int Width, - int Height, bool canny_preprocess) { - sd_image_t *control_image = NULL; + uint8_t *ControlImageBuf, int Width, int Height, + bool CannyPreprocess) { + sd_image_t *ControlImg = NULL; int Channel = 0; - control_image_buffer = stbi_load_from_memory( + ControlImageBuf = stbi_load_from_memory( ControlImage.data(), ControlImage.size(), &Width, &Height, &Channel, 3); - if (control_image_buffer == NULL) { + if (ControlImageBuf == NULL) { spdlog::error("[StableDiffusion] Load image from control image failed."sv); return nullptr; } - control_image = new sd_image_t{(uint32_t)Width, (uint32_t)Height, 3, - control_image_buffer}; - if (canny_preprocess) { // apply preprocessor - control_image->data = preprocess_canny( - control_image->data, control_image->width, control_image->height, 0.08f, - 0.08f, 0.8f, 1.0f, false); + ControlImg = + new sd_image_t{static_cast(Width), + static_cast(Height), 3, ControlImageBuf}; + if (CannyPreprocess) { // apply preprocessor + ControlImg->data = + preprocess_canny(ControlImg->data, ControlImg->width, + ControlImg->height, 0.08f, 0.08f, 0.8f, 1.0f, false); } - return control_image; + return ControlImg; } Expect SDCreateContext::body( @@ -105,18 +106,18 @@ Expect SDCreateContext::body( "Failed when accessing the return SessionID memory."sv) // Create context and import graph. - sd_ctx_t *sd_ctx = - new_sd_ctx(ModelPathSpan.data(), VaePathSpan.data(), TaesdPathSpan.data(), - ControlNetPathSpan.data(), LoraModelDirSpan.data(), - EmbedDirSpan.data(), IdEmbedDirSpan.data(), - static_cast(VaeDecodeOnly), static_cast(VaeTiling), - true, NThreads, sd_type_t(Wtype), rng_type_t(RngType), - schedule_t(Schedule), ClipOnCpu, ControlNetCpu, VaeOnCpu); - if (sd_ctx == NULL) { + sd_ctx_t *Ctx = new_sd_ctx( + ModelPathSpan.data(), VaePathSpan.data(), TaesdPathSpan.data(), + ControlNetPathSpan.data(), LoraModelDirSpan.data(), EmbedDirSpan.data(), + IdEmbedDirSpan.data(), static_cast(VaeDecodeOnly), + static_cast(VaeTiling), true, NThreads, + static_cast(Wtype), static_cast(RngType), + static_cast(Schedule), ClipOnCpu, ControlNetCpu, VaeOnCpu); + if (Ctx == NULL) { spdlog::error("[WasmEdge-StableDiffusion] Failed to create context."); return static_cast(ErrNo::InvalidArgument); } - *SessionId = Env.addContext(sd_ctx); + *SessionId = Env.addContext(Ctx); return static_cast(ErrNo::Success); } @@ -129,7 +130,7 @@ Expect SDTextToImage::body( uint32_t SampleSteps, uint32_t Strength, uint32_t Seed, uint32_t BatchCount, uint32_t ControlStrength, uint32_t StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, - uint32_t canny_preprocess, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t CannyPreprocess, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) @@ -155,7 +156,7 @@ Expect SDTextToImage::body( sd_ctx_t *SDCtx = Env.getContext(SessionId); uint8_t *InputImageBuffer = nullptr; - uint8_t *control_image_buffer = nullptr; + uint8_t *ControlImageBuffer = nullptr; int Channel = 0; int ImageWidth = 0; int ImageHeight = 0; @@ -168,25 +169,25 @@ Expect SDTextToImage::body( MEM_SPAN_CHECK(ControlImageSpan, MemInst, uint8_t, ControlImagePtr, ControlImageLen, "Failed when accessing the control image memory."sv) - ControlImage = ReadControlImage(ControlImageSpan, control_image_buffer, - Width, Height, canny_preprocess); + ControlImage = ReadControlImage(ControlImageSpan, ControlImageBuffer, Width, + Height, CannyPreprocess); } - sd_image_t *results; - results = img2img(SDCtx, InputImage, PromptSpan.data(), + sd_image_t *Results; + Results = img2img(SDCtx, InputImage, PromptSpan.data(), NegativePromptSpan.data(), ClipSkip, CfgScale, Width, Height, sample_method_t(SampleMethod), SampleSteps, Strength, Seed, BatchCount, ControlImage, ControlStrength, StyleRatio, NormalizeInput, InputIdImagesPathSpan.data()); - int len; - unsigned char *png = - stbi_write_png_to_mem((const unsigned char *)results, 0, results->width, - results->height, results->channel, &len, NULL); - *BytesWritten = len; - std::copy_n(png, *BytesWritten, OutputBufferSpan.data()); - free(results); + int Len; + unsigned char *Png = stbi_write_png_to_mem( + reinterpret_cast(Results), 0, Results->width, + Results->height, Results->channel, &Len, NULL); + *BytesWritten = Len; + std::copy_n(Png, *BytesWritten, OutputBufferSpan.data()); + free(Results); free(InputImageBuffer); - free(control_image_buffer); + free(ControlImageBuffer); return static_cast(ErrNo::RuntimeError); } } // namespace StableDiffusion From a5e59bd3efd5a94231cd9f422b2604b0a3491281 Mon Sep 17 00:00:00 2001 From: grorge Date: Tue, 11 Jun 2024 18:07:55 +0800 Subject: [PATCH 340/623] [Plugin] stable diffusion: add convert function Signed-off-by: grorge --- plugins/wasmedge_stablediffusion/sd_func.cpp | 56 +++++++++++-- plugins/wasmedge_stablediffusion/sd_func.h | 9 +++ .../wasmedge_stablediffusion/sd_module.cpp | 1 + .../wasmedge_stablediffusion/CMakeLists.txt | 14 ++++ .../wasmedge_stablediffusion.cpp | 80 ++++++++++++++++++- 5 files changed, 150 insertions(+), 10 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index c3a5e6bd..68ec1744 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -49,14 +49,14 @@ namespace StableDiffusion { return static_cast(ErrNo::MissingMemory); \ } -sd_image_t *ReadControlImage(Span ControlImage, +sd_image_t *readControlImage(Span ControlImage, uint8_t *ControlImageBuf, int Width, int Height, bool CannyPreprocess) { - sd_image_t *ControlImg = NULL; + sd_image_t *ControlImg = nullptr; int Channel = 0; ControlImageBuf = stbi_load_from_memory( ControlImage.data(), ControlImage.size(), &Width, &Height, &Channel, 3); - if (ControlImageBuf == NULL) { + if (ControlImageBuf == nullptr) { spdlog::error("[StableDiffusion] Load image from control image failed."sv); return nullptr; } @@ -71,6 +71,50 @@ sd_image_t *ReadControlImage(Span ControlImage, return ControlImg; } +Expect SDConvert::body(const Runtime::CallingFrame &Frame, + uint32_t ModelPathPtr, uint32_t ModelPathLen, + uint32_t VaeModelPathPtr, + uint32_t VaeModelPathLen, + uint32_t OutputPathPtr, uint32_t OutputPathLen, + uint32_t WType) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input model buffer. + MEM_SPAN_CHECK(ModelPathSpan, MemInst, char, ModelPathPtr, ModelPathLen, + "Failed when accessing the input model path memory."sv) + MEM_SPAN_CHECK(VaeModelPathSpan, MemInst, char, VaeModelPathPtr, + VaeModelPathLen, + "Failed when accessing the input vae model path memory."sv) + MEM_SPAN_CHECK(OutputPathSpan, MemInst, char, OutputPathPtr, OutputPathLen, + "Failed when accessing the output path memory."sv) + std::string ModelPath = std::string( + ModelPathSpan.begin(), ModelPathSpan.begin() + ModelPathSpan.size()); + std::string VaeModelPath = + std::string(VaeModelPathSpan.begin(), + VaeModelPathSpan.begin() + VaeModelPathSpan.size()); + std::string OutputPath = std::string( + OutputPathSpan.begin(), OutputPathSpan.begin() + OutputPathSpan.size()); + + spdlog::info("[WasmEdge-StableDiffusion] Convert model: {} to {}."sv, + ModelPath.data(), OutputPath.data()); + std::ifstream Fin(ModelPath.data(), + std::ios::in | std::ios::binary | std::ios::ate); + if (!Fin) { + spdlog::error("[WasmEdge-StableDiffusion] Model not found."); + return static_cast(ErrNo::InvalidArgument); + } + // Convert model. + bool Ret = convert(ModelPath.data(), VaeModelPath.data(), OutputPath.data(), + static_cast(WType)); + if (!Ret) { + spdlog::error("[WasmEdge-StableDiffusion] Failed to convert model."); + return static_cast(ErrNo::InvalidArgument); + } + + return static_cast(ErrNo::Success); +} + Expect SDCreateContext::body( const Runtime::CallingFrame &Frame, uint32_t ModelPathPtr, uint32_t ModelPathLen, uint32_t VaePathPtr, uint32_t VaePathLen, @@ -113,7 +157,7 @@ Expect SDCreateContext::body( static_cast(VaeTiling), true, NThreads, static_cast(Wtype), static_cast(RngType), static_cast(Schedule), ClipOnCpu, ControlNetCpu, VaeOnCpu); - if (Ctx == NULL) { + if (Ctx == nullptr) { spdlog::error("[WasmEdge-StableDiffusion] Failed to create context."); return static_cast(ErrNo::InvalidArgument); } @@ -169,7 +213,7 @@ Expect SDTextToImage::body( MEM_SPAN_CHECK(ControlImageSpan, MemInst, uint8_t, ControlImagePtr, ControlImageLen, "Failed when accessing the control image memory."sv) - ControlImage = ReadControlImage(ControlImageSpan, ControlImageBuffer, Width, + ControlImage = readControlImage(ControlImageSpan, ControlImageBuffer, Width, Height, CannyPreprocess); } @@ -182,7 +226,7 @@ Expect SDTextToImage::body( int Len; unsigned char *Png = stbi_write_png_to_mem( reinterpret_cast(Results), 0, Results->width, - Results->height, Results->channel, &Len, NULL); + Results->height, Results->channel, &Len, nullptr); *BytesWritten = Len; std::copy_n(Png, *BytesWritten, OutputBufferSpan.data()); free(Results); diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h index 62971202..1ee8b594 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -36,6 +36,15 @@ class SDTextToImage : public StableDiffusion::Func { uint32_t canny_preprocess, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); }; +class SDConvert : public StableDiffusion::Func { +public: + SDConvert(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t ModelPathPtr, uint32_t ModelPathLen, + uint32_t VaeModelPathPtr, uint32_t VaeModelPathLen, + uint32_t OutputPathPtr, uint32_t OutputPathLen, + uint32_t WType); +}; } // namespace StableDiffusion } // namespace Host } // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sd_module.cpp b/plugins/wasmedge_stablediffusion/sd_module.cpp index 14a8ab5b..2cf0c122 100644 --- a/plugins/wasmedge_stablediffusion/sd_module.cpp +++ b/plugins/wasmedge_stablediffusion/sd_module.cpp @@ -9,6 +9,7 @@ SDModule::SDModule() : ModuleInstance("stable_diffusion") { std::make_unique(Env)); addHostFunc("text_to_image", std::make_unique(Env)); + addHostFunc("convert", std::make_unique(Env)); } } // namespace Host diff --git a/test/plugins/wasmedge_stablediffusion/CMakeLists.txt b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt index 13ad7a74..fa321903 100644 --- a/test/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -29,5 +29,19 @@ else() wasmedge_shared ) endif() +function(download URL OUTPUT HASH) + file(DOWNLOAD + ${URL} + ${OUTPUT} + SHOW_PROGRESS + EXPECTED_HASH ${HASH} + ) +endfunction() +message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/sd-v1-4.ckpt") +download( + https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt + ${CMAKE_CURRENT_BINARY_DIR}/stableDiffusion/sd-v1-4.ckpt + MD5=c01059060130b8242849d86e97212c84 +) add_test(wasmedgeStableDiffusionTests wasmedgeStableDiffusionTests) \ No newline at end of file diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index 12f64b53..a7938834 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -10,10 +10,12 @@ #include #include +using WasmEdge::Host::StableDiffusion::ErrNo; + namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; - bool LoadState = WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "../../../plugins/wasmedge_stablediffusion/" WASMEDGE_LIB_PREFIX "wasmedgePluginStableDiffusion" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = @@ -26,15 +28,85 @@ WasmEdge::Runtime::Instance::ModuleInstance *createModule() { } } // namespace +template +void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + WasmEdge::Span Binaries, uint32_t Ptr) noexcept { + std::copy(Binaries.begin(), Binaries.end(), MemInst.getPointer(Ptr)); +} + +void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Value, uint32_t &Ptr) { + uint32_t *BufPtr = MemInst.getPointer(Ptr); + *BufPtr = Value; + Ptr += 4; +} + +void writeFatPointer(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t PtrVal, uint32_t PtrSize, uint32_t &Ptr) { + writeUInt32(MemInst, PtrVal, Ptr); + writeUInt32(MemInst, PtrSize, Ptr); +} + // TODO: unit tests for every functions. -TEST(WasmEdgeImageTest, Module) { - // Create the wasmedge_image module instance. +TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { + // Create the stable diffusion module instance. auto *SBMod = dynamic_cast(createModule()); EXPECT_FALSE(SBMod == nullptr); - EXPECT_EQ(SBMod->getFuncExportNum(), 2U); + EXPECT_EQ(SBMod->getFuncExportNum(), 3U); EXPECT_NE(SBMod->findFuncExports("create_context"), nullptr); EXPECT_NE(SBMod->findFuncExports("text_to_image"), nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + + // Get the function "convert". + auto *FuncInst = SBMod->findFuncExports("convert"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncConvert = + dynamic_cast( + FuncInst->getHostFunc()); + std::string Prompt = "a lovely cat"; + std::vector TensorData(Prompt.begin(), Prompt.end()); + std::string ModelPathString = "./stableDiffusion/sd-v1-4.ckpt"; + std::vector ModelPath(ModelPathString.begin(), ModelPathString.end()); + std::string QuantModelPathString = "./stableDiffusion/sd-v1-4-Q4_K.gguf"; + std::vector QuantModelPath(QuantModelPathString.begin(), + QuantModelPathString.end()); + + // Test: convert -- convert successfully. + { + uint32_t ModelPathPtr = UINT32_C(0); + writeBinaries(MemInst, ModelPath, ModelPathPtr); + uint32_t QuantModelPathPtr = ModelPathPtr + ModelPath.size(); + writeBinaries(MemInst, QuantModelPath, QuantModelPathPtr); + EXPECT_TRUE(HostFuncConvert.run( + CallFrame, + std::initializer_list{ + ModelPathPtr, static_cast(ModelPath.size()), 0, 0, + QuantModelPathPtr, static_cast(QuantModelPath.size()), + 12}, // SD_TYPE_Q4_K = 12 + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + std::ifstream Fin(QuantModelPath.data(), + std::ios::in | std::ios::binary | std::ios::ate); + EXPECT_FALSE(Fin.fail()); + Fin.close(); + } + delete SBMod; } From 3a6fbfec3b28be31eafef6533c160cceafc26d57 Mon Sep 17 00:00:00 2001 From: grorge Date: Mon, 17 Jun 2024 13:28:36 +0800 Subject: [PATCH 341/623] [Plugin] stable diffusion: add image to image Signed-off-by: grorge --- plugins/CMakeLists.txt | 4 - .../wasmedge_stablediffusion/CMakeLists.txt | 4 +- plugins/wasmedge_stablediffusion/sd_env.cpp | 25 ++ plugins/wasmedge_stablediffusion/sd_env.h | 8 +- plugins/wasmedge_stablediffusion/sd_func.cpp | 187 ++++++++++++--- plugins/wasmedge_stablediffusion/sd_func.h | 32 ++- .../wasmedge_stablediffusion/sd_module.cpp | 2 + .../wasmedge_stablediffusion.cpp | 220 +++++++++++++++++- 8 files changed, 424 insertions(+), 58 deletions(-) diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 47248f1e..f951985f 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -67,11 +67,7 @@ if(WASMEDGE_PLUGIN_WASI_OCR) endif() if(WASMEDGE_PLUGIN_STABLEDIFFUSION) - if(CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasmedge_stablediffusion) - else() - message(WARNING "Only Linux platforms support wasm_bpf plug-in now.") - endif() endif() if(WASMEDGE_PLUGIN_OPENCVMINI) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 2544452b..4a6ac7b6 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -20,8 +20,8 @@ endif() FetchContent_Declare( stable-diffusion GIT_REPOSITORY https://github.com/leejet/stable-diffusion.cpp.git - GIT_TAG master-1d2af5c - GIT_SHALLOW FALSE + GIT_TAG master-9c51d87 + GIT_SHALLOW TRUE ) FetchContent_MakeAvailable(stable-diffusion) set_property(TARGET stable-diffusion PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp index 8b0ab85e..6c80d5e6 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.cpp +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -40,6 +40,31 @@ uint32_t SDEnviornment::addContext(sd_ctx_t *Ctx) noexcept { sd_ctx_t *SDEnviornment::getContext(const uint32_t Id) noexcept { return Contexts[Id]; } +void SBLog(enum sd_log_level_t level, const char *log, void *) { + if (!log) { + return; + } + std::string levelStr; + switch (level) { + case SD_LOG_DEBUG: + levelStr = "DEBUG"; + break; + case SD_LOG_INFO: + levelStr = "INFO"; + break; + case SD_LOG_WARN: + levelStr = "WARN"; + break; + case SD_LOG_ERROR: + levelStr = "ERROR"; + break; + default: + levelStr = "?????"; + break; + } + + spdlog::info("[WasmEdge-StableDiffusion] SD-log: [{}] {}", levelStr, log); +} } // namespace StableDiffusion } // namespace Host } // namespace WasmEdge \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sd_env.h b/plugins/wasmedge_stablediffusion/sd_env.h index f7cc23b5..3685aa5f 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.h +++ b/plugins/wasmedge_stablediffusion/sd_env.h @@ -8,7 +8,7 @@ namespace WasmEdge { namespace Host { namespace StableDiffusion { - +void SBLog(enum sd_log_level_t level, const char *log, void *); enum class ErrNo : uint32_t { Success = 0, // No error occurred. InvalidArgument = 1, // Caller module passed an invalid argument. @@ -20,11 +20,17 @@ enum class ErrNo : uint32_t { class SDEnviornment { public: + SDEnviornment() noexcept { + if (EnableSDLog) { + sd_set_log_callback(SBLog, nullptr); + } + }; uint32_t addContext(sd_ctx_t *Ctx) noexcept; sd_ctx_t *getContext(const uint32_t Id) noexcept; private: std::vector Contexts; + bool EnableSDLog = false; }; } // namespace StableDiffusion diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index 68ec1744..b246065c 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -54,10 +54,20 @@ sd_image_t *readControlImage(Span ControlImage, bool CannyPreprocess) { sd_image_t *ControlImg = nullptr; int Channel = 0; - ControlImageBuf = stbi_load_from_memory( - ControlImage.data(), ControlImage.size(), &Width, &Height, &Channel, 3); + std::string ControlImagePath(ControlImage.begin(), ControlImage.end()); + + if (ControlImagePath.substr(0, 5) == "path:"sv) { + ControlImageBuf = stbi_load(ControlImagePath.substr(5).data(), &Width, + &Height, &Channel, 3); + } else { + + ControlImageBuf = stbi_load_from_memory( + ControlImage.data(), ControlImage.size(), &Width, &Height, &Channel, 3); + } + if (ControlImageBuf == nullptr) { - spdlog::error("[StableDiffusion] Load image from control image failed."sv); + spdlog::error( + "[WasmEdge-StableDiffusion] Load image from control image failed."sv); return nullptr; } ControlImg = @@ -80,7 +90,7 @@ Expect SDConvert::body(const Runtime::CallingFrame &Frame, // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) - // Check the input model buffer. + // Check the input parameter value. MEM_SPAN_CHECK(ModelPathSpan, MemInst, char, ModelPathPtr, ModelPathLen, "Failed when accessing the input model path memory."sv) MEM_SPAN_CHECK(VaeModelPathSpan, MemInst, char, VaeModelPathPtr, @@ -98,12 +108,13 @@ Expect SDConvert::body(const Runtime::CallingFrame &Frame, spdlog::info("[WasmEdge-StableDiffusion] Convert model: {} to {}."sv, ModelPath.data(), OutputPath.data()); - std::ifstream Fin(ModelPath.data(), - std::ios::in | std::ios::binary | std::ios::ate); + std::ifstream Fin(ModelPath.data(), std::ios::in | std::ios::binary); if (!Fin) { + Fin.close(); spdlog::error("[WasmEdge-StableDiffusion] Model not found."); return static_cast(ErrNo::InvalidArgument); } + Fin.close(); // Convert model. bool Ret = convert(ModelPath.data(), VaeModelPath.data(), OutputPath.data(), static_cast(WType)); @@ -149,13 +160,30 @@ Expect SDCreateContext::body( MEM_PTR_CHECK(SessionId, MemInst, uint32_t, SessiontIdPtr, "Failed when accessing the return SessionID memory."sv) + std::string ModelPath = + std::string(ModelPathSpan.begin(), ModelPathSpan.end()); + std::string VaePath = std::string(VaePathSpan.begin(), VaePathSpan.end()); + std::string TaesdPath = + std::string(TaesdPathSpan.begin(), TaesdPathSpan.end()); + std::string ControlNetPath = + std::string(ControlNetPathSpan.begin(), ControlNetPathSpan.end()); + std::string LoraModelDir = + std::string(LoraModelDirSpan.begin(), LoraModelDirSpan.end()); + std::string EmbedDir = std::string(EmbedDirSpan.begin(), EmbedDirSpan.end()); + std::string IdEmbedDir = + std::string(IdEmbedDirSpan.begin(), IdEmbedDirSpan.end()); + if (NThreads == -1) { + NThreads = get_num_physical_cores(); + } + + spdlog::info("[WasmEdge-StableDiffusion] Create context."sv); // Create context and import graph. + sd_ctx_t *Ctx = new_sd_ctx( - ModelPathSpan.data(), VaePathSpan.data(), TaesdPathSpan.data(), - ControlNetPathSpan.data(), LoraModelDirSpan.data(), EmbedDirSpan.data(), - IdEmbedDirSpan.data(), static_cast(VaeDecodeOnly), - static_cast(VaeTiling), true, NThreads, - static_cast(Wtype), static_cast(RngType), + ModelPath.data(), VaePath.data(), TaesdPath.data(), ControlNetPath.data(), + LoraModelDir.data(), EmbedDir.data(), IdEmbedDir.data(), + static_cast(VaeDecodeOnly), static_cast(VaeTiling), true, + NThreads, static_cast(Wtype), static_cast(RngType), static_cast(Schedule), ClipOnCpu, ControlNetCpu, VaeOnCpu); if (Ctx == nullptr) { spdlog::error("[WasmEdge-StableDiffusion] Failed to create context."); @@ -166,20 +194,88 @@ Expect SDCreateContext::body( return static_cast(ErrNo::Success); } Expect SDTextToImage::body( + const Runtime::CallingFrame &Frame, uint32_t PromptPtr, uint32_t PromptLen, + uint32_t SessionId, uint32_t ControlImagePtr, uint32_t ControlImageLen, + uint32_t NegativePromptPtr, uint32_t NegativePromptLen, uint32_t Width, + uint32_t Height, int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, + uint32_t SampleSteps, uint32_t Seed, uint32_t BatchCount, + float ControlStrength, float StyleRatio, uint32_t NormalizeInput, + uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, + uint32_t CannyPreprocess, uint32_t OutputPathPtr, uint32_t OutputPathLen, + uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + // Check the input model buffer. + MEM_SPAN_CHECK(PromptSpan, MemInst, char, PromptPtr, PromptLen, + "Failed when accessing the promp memory."sv) + MEM_SPAN_CHECK(NegativePromptSpan, MemInst, char, NegativePromptPtr, + NegativePromptLen, + "Failed when accessing the input negative prompt memory."sv) + MEM_SPAN_CHECK(InputIdImagesPathSpan, MemInst, char, InputIdImagesPathPtr, + InputIdImagesPathLen, + "Failed when accessing the input id images path memory."sv) + MEM_SPAN_CHECK(OutputBufferSpan, MemInst, uint8_t, OutBufferPtr, + OutBufferMaxSize, + "Failed when accessing the Output Buffer memory."sv) + MEM_PTR_CHECK(BytesWritten, MemInst, uint32_t, BytesWrittenPtr, + "Failed when accessing the return bytes written memory."sv) + MEM_SPAN_CHECK(OutputPathSpan, MemInst, char, OutputPathPtr, OutputPathLen, + "Failed when accessing the output path memory."sv) + std::string Prompt(PromptSpan.begin(), PromptSpan.end()); + std::string NegativePrompt(NegativePromptSpan.begin(), + NegativePromptSpan.end()); + std::string InputIdImagesPath(InputIdImagesPathSpan.begin(), + InputIdImagesPathSpan.end()); + std::string OutputPath(OutputPathSpan.begin(), OutputPathSpan.end()); + sd_ctx_t *SDCtx = Env.getContext(SessionId); + sd_image_t *Results = nullptr; + sd_image_t *ControlImage = nullptr; + uint8_t *ControlImageBuffer = nullptr; + if (ControlImageLen != 0) { + MEM_SPAN_CHECK(ControlImageSpan, MemInst, uint8_t, ControlImagePtr, + ControlImageLen, + "Failed when accessing the control image memory."sv) + ControlImage = readControlImage(ControlImageSpan, ControlImageBuffer, Width, + Height, CannyPreprocess); + } + spdlog::info("[WasmEdge-StableDiffusion] Start to generate image."sv); + Results = + txt2img(SDCtx, Prompt.data(), NegativePrompt.data(), ClipSkip, CfgScale, + Width, Height, sample_method_t(SampleMethod), SampleSteps, Seed, + BatchCount, ControlImage, ControlStrength, StyleRatio, + NormalizeInput, InputIdImagesPath.data()); + int Len; + unsigned char *Png = stbi_write_png_to_mem( + reinterpret_cast(Results), 0, Results->width, + Results->height, Results->channel, &Len, nullptr); + if (OutputPathLen != 0) { + stbi_write_png(OutputPath.data(), Results->width, Results->height, + Results->channel, Results->data, 0, nullptr); + } + *BytesWritten = Len; + std::copy_n(Png, *BytesWritten, OutputBufferSpan.data()); + free(Png); + free(Results); + free(ControlImageBuffer); + return static_cast(ErrNo::Success); +} +Expect SDImageToImage::body( const Runtime::CallingFrame &Frame, uint32_t ImagePtr, uint32_t ImageLen, uint32_t SessionId, uint32_t Width, uint32_t Height, uint32_t ControlImagePtr, uint32_t ControlImageLen, uint32_t PromptPtr, uint32_t PromptLen, uint32_t NegativePromptPtr, uint32_t NegativePromptLen, - uint32_t ClipSkip, uint32_t CfgScale, uint32_t SampleMethod, - uint32_t SampleSteps, uint32_t Strength, uint32_t Seed, uint32_t BatchCount, - uint32_t ControlStrength, uint32_t StyleRatio, uint32_t NormalizeInput, + int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, + uint32_t SampleSteps, float Strength, uint32_t Seed, uint32_t BatchCount, + float ControlStrength, float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, - uint32_t CannyPreprocess, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t CannyPreprocess, uint32_t OutputPathPtr, uint32_t OutputPathLen, + uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) - // Check the input model buffer. + // Check the input parameter valud. MEM_SPAN_CHECK(ImageSpan, MemInst, uint8_t, ImagePtr, ImageLen, "Failed when accessing the input image memory."sv) @@ -188,27 +284,47 @@ Expect SDTextToImage::body( MEM_SPAN_CHECK(NegativePromptSpan, MemInst, char, NegativePromptPtr, NegativePromptLen, "Failed when accessing the input negative prompt memory."sv) - MEM_SPAN_CHECK( - InputIdImagesPathSpan, MemInst, char, InputIdImagesPathPtr, - InputIdImagesPathLen, - "Failed when accessing the input input id images path memory."sv) + MEM_SPAN_CHECK(InputIdImagesPathSpan, MemInst, char, InputIdImagesPathPtr, + InputIdImagesPathLen, + "Failed when accessing the input id images path memory."sv) MEM_SPAN_CHECK(OutputBufferSpan, MemInst, uint8_t, OutBufferPtr, OutBufferMaxSize, "Failed when accessing the Output Buffer memory."sv) MEM_PTR_CHECK(BytesWritten, MemInst, uint32_t, BytesWrittenPtr, "Failed when accessing the return bytes written memory."sv) + MEM_SPAN_CHECK(OutputPathSpan, MemInst, char, OutputPathPtr, OutputPathLen, + "Failed when accessing the output path memory."sv) sd_ctx_t *SDCtx = Env.getContext(SessionId); - + std::string Prompt(PromptSpan.begin(), PromptSpan.end()); + std::string NegativePrompt(NegativePromptSpan.begin(), + NegativePromptSpan.end()); + std::string InputIdImagesPath(InputIdImagesPathSpan.begin(), + InputIdImagesPathSpan.end()); + std::string OutputPath(OutputPathSpan.begin(), OutputPathSpan.end()); uint8_t *InputImageBuffer = nullptr; uint8_t *ControlImageBuffer = nullptr; int Channel = 0; int ImageWidth = 0; int ImageHeight = 0; - InputImageBuffer = - stbi_load_from_memory(ImageSpan.data(), ImageSpan.size(), &ImageWidth, - &ImageHeight, &Channel, 3); + std::string ImagePath(ImageSpan.begin(), ImageSpan.end()); + if (ImagePath.substr(0, 5) == "path:"sv) { + InputImageBuffer = stbi_load(ImagePath.substr(5).data(), &ImageWidth, + &ImageHeight, &Channel, 3); + if (InputImageBuffer == nullptr) { + spdlog::error( + "[WasmEdge-StableDiffusion] Load image from input image failed."sv); + return static_cast(ErrNo::InvalidArgument); + } + } else { + + InputImageBuffer = + stbi_load_from_memory(ImageSpan.data(), ImageSpan.size(), &ImageWidth, + &ImageHeight, &Channel, 3); + } + + // TODO: Resize image when image size not matches weight and height sd_image_t InputImage = {Width, Height, 3, InputImageBuffer}; - sd_image_t *ControlImage; + sd_image_t *ControlImage = nullptr; if (ControlImageLen != 0) { MEM_SPAN_CHECK(ControlImageSpan, MemInst, uint8_t, ControlImagePtr, ControlImageLen, @@ -216,23 +332,28 @@ Expect SDTextToImage::body( ControlImage = readControlImage(ControlImageSpan, ControlImageBuffer, Width, Height, CannyPreprocess); } - - sd_image_t *Results; - Results = img2img(SDCtx, InputImage, PromptSpan.data(), - NegativePromptSpan.data(), ClipSkip, CfgScale, Width, - Height, sample_method_t(SampleMethod), SampleSteps, - Strength, Seed, BatchCount, ControlImage, ControlStrength, - StyleRatio, NormalizeInput, InputIdImagesPathSpan.data()); + sd_image_t *Results = nullptr; + spdlog::info("[WasmEdge-StableDiffusion] Start to generate image."sv); + Results = img2img(SDCtx, InputImage, Prompt.data(), NegativePrompt.data(), + ClipSkip, CfgScale, Width, Height, + sample_method_t(SampleMethod), SampleSteps, Strength, Seed, + BatchCount, ControlImage, ControlStrength, StyleRatio, + NormalizeInput, InputIdImagesPath.data()); int Len; unsigned char *Png = stbi_write_png_to_mem( reinterpret_cast(Results), 0, Results->width, Results->height, Results->channel, &Len, nullptr); + if (OutputPathLen != 0) { + stbi_write_png(OutputPath.data(), Results->width, Results->height, + Results->channel, Results->data, 0, nullptr); + } *BytesWritten = Len; std::copy_n(Png, *BytesWritten, OutputBufferSpan.data()); + free(Png); free(Results); free(InputImageBuffer); free(ControlImageBuffer); - return static_cast(ErrNo::RuntimeError); + return static_cast(ErrNo::Success); } } // namespace StableDiffusion } // namespace Host diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h index 1ee8b594..f735e106 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -20,22 +20,38 @@ class SDCreateContext : public StableDiffusion::Func { uint32_t Schedule, uint32_t ClipOnCpu, uint32_t ControlNetCpu, uint32_t VaeOnCpu, uint32_t SessiontIdPtr); }; -class SDTextToImage : public StableDiffusion::Func { +class SDImageToImage : public StableDiffusion::Func { public: - SDTextToImage(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} + SDImageToImage(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} Expect body(const Runtime::CallingFrame &Frame, uint32_t ImagePtr, uint32_t ImageLen, uint32_t SessionId, uint32_t Width, uint32_t Height, uint32_t ControlImagePtr, uint32_t ControlImageLen, uint32_t PromptPtr, uint32_t PromptLen, uint32_t NegativePromptPtr, - uint32_t NegativePromptLen, uint32_t ClipSkip, uint32_t CfgScale, - uint32_t SampleMethod, uint32_t SampleSteps, uint32_t Strength, - uint32_t Seed, uint32_t BatchCount, uint32_t ControlStrength, - uint32_t StyleRatio, uint32_t NormalizeInput, - uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, - uint32_t canny_preprocess, uint32_t OutBufferPtr, + uint32_t NegativePromptLen, int32_t ClipSkip, float CfgScale, + uint32_t SampleMethod, uint32_t SampleSteps, float Strength, + uint32_t Seed, uint32_t BatchCount, float ControlStrength, + float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesPathPtr, + uint32_t InputIdImagesPathLen, uint32_t canny_preprocess, + uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); }; +class SDTextToImage : public StableDiffusion::Func { +public: + SDTextToImage(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} + Expect + body(const Runtime::CallingFrame &Frame, uint32_t PromptPtr, + uint32_t PromptLen, uint32_t SessionId, uint32_t ControlImagePtr, + uint32_t ControlImageLen, uint32_t NegativePromptPtr, + uint32_t NegativePromptLen, uint32_t Width, uint32_t Height, + int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, + uint32_t SampleSteps, uint32_t Seed, uint32_t BatchCount, + float ControlStrength, float StyleRatio, uint32_t NormalizeInput, + uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, + uint32_t CannyPreprocess, uint32_t OutputPathPtr, uint32_t OutputPathLen, + uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr); +}; class SDConvert : public StableDiffusion::Func { public: SDConvert(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} diff --git a/plugins/wasmedge_stablediffusion/sd_module.cpp b/plugins/wasmedge_stablediffusion/sd_module.cpp index 2cf0c122..29057bfa 100644 --- a/plugins/wasmedge_stablediffusion/sd_module.cpp +++ b/plugins/wasmedge_stablediffusion/sd_module.cpp @@ -7,6 +7,8 @@ namespace Host { SDModule::SDModule() : ModuleInstance("stable_diffusion") { addHostFunc("create_context", std::make_unique(Env)); + addHostFunc("image_to_image", + std::make_unique(Env)); addHostFunc("text_to_image", std::make_unique(Env)); addHostFunc("convert", std::make_unique(Env)); diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index a7938834..f7005bc8 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -53,9 +53,7 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { // Create the stable diffusion module instance. auto *SBMod = dynamic_cast(createModule()); EXPECT_FALSE(SBMod == nullptr); - EXPECT_EQ(SBMod->getFuncExportNum(), 3U); - EXPECT_NE(SBMod->findFuncExports("create_context"), nullptr); - EXPECT_NE(SBMod->findFuncExports("text_to_image"), nullptr); + EXPECT_EQ(SBMod->getFuncExportNum(), 4U); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -70,6 +68,9 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { // Return value. std::array Errno = {UINT32_C(0)}; + uint32_t SessionPtr = UINT32_C(0); + uint32_t SessionId = UINT32_C(0); + uint32_t OutputPtr = UINT32_C(0); // uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); // Get the function "convert". @@ -79,26 +80,58 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { auto &HostFuncConvert = dynamic_cast( FuncInst->getHostFunc()); + // Get the function "create_context". + FuncInst = SBMod->findFuncExports("create_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCreateContext = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "text_to_image". + FuncInst = SBMod->findFuncExports("text_to_image"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncTextToImage = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "image_to_image". + FuncInst = SBMod->findFuncExports("image_to_image"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncImageToImage = + dynamic_cast( + FuncInst->getHostFunc()); + std::string Prompt = "a lovely cat"; - std::vector TensorData(Prompt.begin(), Prompt.end()); + std::string Prompt2 = "with blue eyes"; + std::string OutputPathString = "./stableDiffusion/output.png"; + std::vector OutputPath(OutputPathString.begin(), + OutputPathString.end()); + std::string InputPathString = "path:" + OutputPathString; + std::vector InputPath(InputPathString.begin(), InputPathString.end()); + std::string OutputPathString2 = "./stableDiffusion/output2.png"; + std::vector OutputPath2(OutputPathString2.begin(), + OutputPathString2.end()); + std::vector PromptData(Prompt.begin(), Prompt.end()); + std::vector PromptData2(Prompt2.begin(), Prompt2.end()); std::string ModelPathString = "./stableDiffusion/sd-v1-4.ckpt"; std::vector ModelPath(ModelPathString.begin(), ModelPathString.end()); - std::string QuantModelPathString = "./stableDiffusion/sd-v1-4-Q4_K.gguf"; + std::string QuantModelPathString = "./stableDiffusion/sd-v1-4-Q8_0.gguf"; std::vector QuantModelPath(QuantModelPathString.begin(), QuantModelPathString.end()); + uint32_t ModelPathPtr = UINT32_C(0); + uint32_t QuantModelPathPtr = ModelPathPtr + ModelPath.size(); + writeBinaries(MemInst, ModelPath, ModelPathPtr); + writeBinaries(MemInst, QuantModelPath, QuantModelPathPtr); // Test: convert -- convert successfully. { - uint32_t ModelPathPtr = UINT32_C(0); - writeBinaries(MemInst, ModelPath, ModelPathPtr); - uint32_t QuantModelPathPtr = ModelPathPtr + ModelPath.size(); - writeBinaries(MemInst, QuantModelPath, QuantModelPathPtr); EXPECT_TRUE(HostFuncConvert.run( CallFrame, std::initializer_list{ ModelPathPtr, static_cast(ModelPath.size()), 0, 0, QuantModelPathPtr, static_cast(QuantModelPath.size()), - 12}, // SD_TYPE_Q4_K = 12 + 8}, // SD_TYPE_Q8_0 = 8 Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); std::ifstream Fin(QuantModelPath.data(), @@ -106,7 +139,174 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { EXPECT_FALSE(Fin.fail()); Fin.close(); } + // Test: create_context -- create context for text to image. + { + EXPECT_TRUE(HostFuncCreateContext.run( + CallFrame, + std::initializer_list{ + QuantModelPathPtr, + static_cast(QuantModelPath.size()), + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + -1, + 31, + 1, + 0, + 0, + 0, + 0, + SessionPtr}, // vaeDecodeOnly=true, NThreads=-1, + // wtype=31(SD_TYPE_COUNT), RngType=CUDA_RNG, + // Schedule=DEFAULT, Other is false + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + SessionId = *MemInst.getPointer(SessionPtr); + EXPECT_EQ(SessionId, 0); + } + // Test: text_to_image -- generate image from text. + { + uint32_t PromptPtr = UINT32_C(0); + uint32_t OutputPathPtr = PromptPtr + PromptData.size(); + uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath.size(); + OutputPtr = BytesWrittenPtr + 4; + writeBinaries(MemInst, PromptData, PromptPtr); + writeBinaries(MemInst, OutputPath, OutputPathPtr); + EXPECT_TRUE(HostFuncTextToImage.run( + CallFrame, + std::initializer_list{PromptPtr, + PromptData.size(), + SessionId, + 0, + 0, + 0, + 0, + 512, + 512, + -1, + 7.0f, + 0, + 20, + 42, + 1, + 0.90f, + 20.0f, + 0, + 0, + 0, + 0, + OutputPathPtr, + OutputPath.size(), + OutputPtr, + 65532, + BytesWrittenPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); + EXPECT_GE(BytesWritten, 50); + std::ifstream Fin(OutputPath.data(), std::ios::in | std::ios::binary); + EXPECT_FALSE(Fin.fail()); + Fin.close(); + } + writeBinaries(MemInst, ModelPath, ModelPathPtr); + writeBinaries(MemInst, QuantModelPath, QuantModelPathPtr); + // Test: create_context -- create context for image to image. + { + EXPECT_TRUE(HostFuncCreateContext.run( + CallFrame, + std::initializer_list{ + QuantModelPathPtr, + static_cast(QuantModelPath.size()), + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -1, + 31, + 1, + 0, + 0, + 0, + 0, + SessionPtr}, // NThreads=-1, + // wtype=31(SD_TYPE_COUNT), RngType=CUDA_RNG, + // Schedule=DEFAULT, Other is false + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + SessionId = *MemInst.getPointer(SessionPtr); + EXPECT_EQ(SessionId, 1); + } + // Test: image_to_image -- generate image from image. + { + uint32_t PromptPtr = UINT32_C(0); + uint32_t InputPathPtr = PromptPtr + PromptData2.size(); + uint32_t OutputPathPtr = InputPathPtr + InputPath.size(); + uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath2.size(); + OutputPtr = BytesWrittenPtr + 4; + writeBinaries(MemInst, PromptData2, PromptPtr); + writeBinaries(MemInst, InputPath, InputPathPtr); + writeBinaries(MemInst, OutputPath2, OutputPathPtr); + EXPECT_TRUE(HostFuncImageToImage.run( + CallFrame, + std::initializer_list{InputPathPtr, + InputPath.size(), + SessionId, + 512, + 512, + 0, + 0, + PromptPtr, + PromptData2.size(), + 0, + 0, + -1, + 7.0f, + 0, + 20, + 0.75f, + 42, + 1, + 0.9f, + 20.0f, + 0, + 0, + 0, + 0, + OutputPathPtr, + OutputPath2.size(), + OutputPtr, + 65532, + BytesWrittenPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); + EXPECT_GE(BytesWritten, 50); + std::ifstream Fin(OutputPath2.data(), std::ios::in | std::ios::binary); + EXPECT_FALSE(Fin.fail()); + Fin.close(); + } delete SBMod; } From 3dfff57db5d517d889cb570b88bb686e2f2a281e Mon Sep 17 00:00:00 2001 From: grorge Date: Tue, 18 Jun 2024 17:00:49 +0800 Subject: [PATCH 342/623] [Plugin] stable diffusion: handle exceed maximum Signed-off-by: grorge --- .../wasmedge_stablediffusion/CMakeLists.txt | 37 ++++++++++--------- plugins/wasmedge_stablediffusion/sd_func.cpp | 19 +++++++++- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 4a6ac7b6..fca1bf56 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -1,22 +1,5 @@ # setup stable diffusion message(STATUS "Downloading stable diffusion source") -if (MSVC) - add_compile_options( - /wd4996 - /wd4456 - /wd4459 - /wd4100 - /wd4127 - /wd4701 - ) -else() - add_compile_options( - -Wno-unused-function - -Wno-unused-variable - -Wno-unused-parameter - -Wno-missing-field-initializers - ) -endif() FetchContent_Declare( stable-diffusion GIT_REPOSITORY https://github.com/leejet/stable-diffusion.cpp.git @@ -66,5 +49,23 @@ target_include_directories(wasmedgePluginStableDiffusion $ ${CMAKE_CURRENT_SOURCE_DIR} ) - +if (MSVC) +target_compile_options( + wasmedgePluginStableDiffusion + PRIVATE + /wd4459 + /wd4100 + /wd4127 + /wd4701 + ) +else() + target_compile_options( + wasmedgePluginStableDiffusion + PRIVATE + -Wno-unused-function + -Wno-unused-variable + -Wno-unused-parameter + -Wno-missing-field-initializers + ) +endif() install(TARGETS wasmedgePluginStableDiffusion DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index b246065c..dd33b224 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -254,6 +254,13 @@ Expect SDTextToImage::body( Results->channel, Results->data, 0, nullptr); } *BytesWritten = Len; + if (OutBufferMaxSize < *BytesWritten) { + spdlog::error("[WasmEdge-StableDiffusion] Output buffer is not enough."sv); + free(Png); + free(Results); + free(ControlImageBuffer); + return static_cast(ErrNo::RuntimeError); + } std::copy_n(Png, *BytesWritten, OutputBufferSpan.data()); free(Png); free(Results); @@ -275,7 +282,7 @@ Expect SDImageToImage::body( // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) - // Check the input parameter valud. + // Check the input parameter valid. MEM_SPAN_CHECK(ImageSpan, MemInst, uint8_t, ImagePtr, ImageLen, "Failed when accessing the input image memory."sv) @@ -322,7 +329,7 @@ Expect SDImageToImage::body( &ImageHeight, &Channel, 3); } - // TODO: Resize image when image size not matches weight and height + // TODO: Resize image when image size not matches width and height sd_image_t InputImage = {Width, Height, 3, InputImageBuffer}; sd_image_t *ControlImage = nullptr; if (ControlImageLen != 0) { @@ -348,6 +355,14 @@ Expect SDImageToImage::body( Results->channel, Results->data, 0, nullptr); } *BytesWritten = Len; + if (OutBufferMaxSize < *BytesWritten) { + spdlog::error("[WasmEdge-StableDiffusion] Output buffer is not enough."sv); + free(Png); + free(Results); + free(InputImageBuffer); + free(ControlImageBuffer); + return static_cast(ErrNo::RuntimeError); + } std::copy_n(Png, *BytesWritten, OutputBufferSpan.data()); free(Png); free(Results); From 53801bb260525badcb57275d89008c7144e61a1a Mon Sep 17 00:00:00 2001 From: grorge Date: Wed, 19 Jun 2024 17:04:03 +0800 Subject: [PATCH 343/623] [Plugin] stable diffusion: fix VM registered module number Signed-off-by: grorge --- plugins/wasmedge_stablediffusion/sd_env.h | 1 + plugins/wasmedge_stablediffusion/sd_func.cpp | 41 +++++++++++++++---- plugins/wasmedge_stablediffusion/sd_func.h | 15 ++++--- .../wasmedge_stablediffusion.cpp | 14 +++++-- 4 files changed, 54 insertions(+), 17 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/sd_env.h b/plugins/wasmedge_stablediffusion/sd_env.h index 3685aa5f..dfd6a2db 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.h +++ b/plugins/wasmedge_stablediffusion/sd_env.h @@ -27,6 +27,7 @@ class SDEnviornment { }; uint32_t addContext(sd_ctx_t *Ctx) noexcept; sd_ctx_t *getContext(const uint32_t Id) noexcept; + size_t getContextSize() noexcept { return Contexts.size(); } private: std::vector Contexts; diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index dd33b224..a0caabc0 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -48,7 +48,24 @@ namespace StableDiffusion { spdlog::error("[WasmEdge-StableDiffusion] "sv Message); \ return static_cast(ErrNo::MissingMemory); \ } - +bool parameterCheck(SDEnviornment &Env, uint32_t width, uint32_t height, + uint32_t SessionId) { + if (SessionId < 0 || SessionId >= Env.getContextSize()) { + spdlog::error("[WasmEdge-StableDiffusion] Session ID is invalid."); + return false; + } + if (width <= 0 || width % 64 != 0) { + spdlog::error("[WasmEdge-StableDiffusion] Width must be a multiple of 64 " + "and greater than 0"); + return false; + } + if (height <= 0 || height % 64 != 0) { + spdlog::error("[WasmEdge-StableDiffusion] Height must be a multiple of 64 " + "and greater than 0"); + return false; + } + return true; +} sd_image_t *readControlImage(Span ControlImage, uint8_t *ControlImageBuf, int Width, int Height, bool CannyPreprocess) { @@ -201,9 +218,10 @@ Expect SDTextToImage::body( uint32_t SampleSteps, uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, - uint32_t CannyPreprocess, uint32_t OutputPathPtr, uint32_t OutputPathLen, - uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, - uint32_t BytesWrittenPtr) { + uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, + uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, + uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) // Check the input model buffer. @@ -228,6 +246,9 @@ Expect SDTextToImage::body( std::string InputIdImagesPath(InputIdImagesPathSpan.begin(), InputIdImagesPathSpan.end()); std::string OutputPath(OutputPathSpan.begin(), OutputPathSpan.end()); + if (!parameterCheck(Env, Width, Height, SessionId)) { + return static_cast(ErrNo::InvalidArgument); + } sd_ctx_t *SDCtx = Env.getContext(SessionId); sd_image_t *Results = nullptr; sd_image_t *ControlImage = nullptr; @@ -245,6 +266,7 @@ Expect SDTextToImage::body( Width, Height, sample_method_t(SampleMethod), SampleSteps, Seed, BatchCount, ControlImage, ControlStrength, StyleRatio, NormalizeInput, InputIdImagesPath.data()); + // TODO upscale image int Len; unsigned char *Png = stbi_write_png_to_mem( reinterpret_cast(Results), 0, Results->width, @@ -276,9 +298,10 @@ Expect SDImageToImage::body( uint32_t SampleSteps, float Strength, uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, - uint32_t CannyPreprocess, uint32_t OutputPathPtr, uint32_t OutputPathLen, - uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, - uint32_t BytesWrittenPtr) { + uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, + uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, + uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) @@ -301,6 +324,9 @@ Expect SDImageToImage::body( "Failed when accessing the return bytes written memory."sv) MEM_SPAN_CHECK(OutputPathSpan, MemInst, char, OutputPathPtr, OutputPathLen, "Failed when accessing the output path memory."sv) + if (!parameterCheck(Env, Width, Height, SessionId)) { + return static_cast(ErrNo::InvalidArgument); + } sd_ctx_t *SDCtx = Env.getContext(SessionId); std::string Prompt(PromptSpan.begin(), PromptSpan.end()); std::string NegativePrompt(NegativePromptSpan.begin(), @@ -346,6 +372,7 @@ Expect SDImageToImage::body( sample_method_t(SampleMethod), SampleSteps, Strength, Seed, BatchCount, ControlImage, ControlStrength, StyleRatio, NormalizeInput, InputIdImagesPath.data()); + // TODO upscale image int Len; unsigned char *Png = stbi_write_png_to_mem( reinterpret_cast(Results), 0, Results->width, diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h index f735e106..57731c48 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -32,9 +32,11 @@ class SDImageToImage : public StableDiffusion::Func { uint32_t SampleMethod, uint32_t SampleSteps, float Strength, uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesPathPtr, - uint32_t InputIdImagesPathLen, uint32_t canny_preprocess, - uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, - uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); + uint32_t InputIdImagesPathLen, uint32_t CannyPreprocess, + uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, + uint32_t UpscaleRepeats, uint32_t OutputPathPtr, uint32_t OutputPathLen, + uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr); }; class SDTextToImage : public StableDiffusion::Func { public: @@ -48,9 +50,10 @@ class SDTextToImage : public StableDiffusion::Func { uint32_t SampleSteps, uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, - uint32_t CannyPreprocess, uint32_t OutputPathPtr, uint32_t OutputPathLen, - uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, - uint32_t BytesWrittenPtr); + uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, + uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, + uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); }; class SDConvert : public StableDiffusion::Func { public: diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index f7005bc8..6b1c71e0 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -193,8 +193,8 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 0, 0, 0, - 512, - 512, + 64, + 64, -1, 7.0f, 0, @@ -207,6 +207,9 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 0, 0, 0, + 0, + 0, + 0, OutputPathPtr, OutputPath.size(), OutputPtr, @@ -273,8 +276,8 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { std::initializer_list{InputPathPtr, InputPath.size(), SessionId, - 512, - 512, + 64, + 64, 0, 0, PromptPtr, @@ -294,6 +297,9 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 0, 0, 0, + 0, + 0, + 0, OutputPathPtr, OutputPath2.size(), OutputPtr, From ee386568eb94d9b551dce020027611e0daed5838 Mon Sep 17 00:00:00 2001 From: grorge Date: Sat, 22 Jun 2024 08:48:17 +0800 Subject: [PATCH 344/623] [Plugin] stable diffusion: allow other platform sto run test Signed-off-by: grorge --- plugins/wasmedge_stablediffusion/sd_func.cpp | 30 +-- plugins/wasmedge_stablediffusion/sd_func.h | 6 +- test/plugins/CMakeLists.txt | 4 +- .../wasmedge_stablediffusion.cpp | 236 +++++++++--------- 4 files changed, 136 insertions(+), 140 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index a0caabc0..a3063b1c 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -50,16 +50,16 @@ namespace StableDiffusion { } bool parameterCheck(SDEnviornment &Env, uint32_t width, uint32_t height, uint32_t SessionId) { - if (SessionId < 0 || SessionId >= Env.getContextSize()) { + if (SessionId >= Env.getContextSize()) { spdlog::error("[WasmEdge-StableDiffusion] Session ID is invalid."); return false; } - if (width <= 0 || width % 64 != 0) { + if (width % 64 != 0) { spdlog::error("[WasmEdge-StableDiffusion] Width must be a multiple of 64 " "and greater than 0"); return false; } - if (height <= 0 || height % 64 != 0) { + if (height % 64 != 0) { spdlog::error("[WasmEdge-StableDiffusion] Height must be a multiple of 64 " "and greater than 0"); return false; @@ -217,7 +217,7 @@ Expect SDTextToImage::body( uint32_t Height, int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, uint32_t SampleSteps, uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, uint32_t NormalizeInput, - uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, + uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, @@ -230,8 +230,8 @@ Expect SDTextToImage::body( MEM_SPAN_CHECK(NegativePromptSpan, MemInst, char, NegativePromptPtr, NegativePromptLen, "Failed when accessing the input negative prompt memory."sv) - MEM_SPAN_CHECK(InputIdImagesPathSpan, MemInst, char, InputIdImagesPathPtr, - InputIdImagesPathLen, + MEM_SPAN_CHECK(InputIdImagesDirSpan, MemInst, char, InputIdImagesDirPtr, + InputIdImagesDirLen, "Failed when accessing the input id images path memory."sv) MEM_SPAN_CHECK(OutputBufferSpan, MemInst, uint8_t, OutBufferPtr, OutBufferMaxSize, @@ -243,8 +243,8 @@ Expect SDTextToImage::body( std::string Prompt(PromptSpan.begin(), PromptSpan.end()); std::string NegativePrompt(NegativePromptSpan.begin(), NegativePromptSpan.end()); - std::string InputIdImagesPath(InputIdImagesPathSpan.begin(), - InputIdImagesPathSpan.end()); + std::string InputIdImagesDir(InputIdImagesDirSpan.begin(), + InputIdImagesDirSpan.end()); std::string OutputPath(OutputPathSpan.begin(), OutputPathSpan.end()); if (!parameterCheck(Env, Width, Height, SessionId)) { return static_cast(ErrNo::InvalidArgument); @@ -265,7 +265,7 @@ Expect SDTextToImage::body( txt2img(SDCtx, Prompt.data(), NegativePrompt.data(), ClipSkip, CfgScale, Width, Height, sample_method_t(SampleMethod), SampleSteps, Seed, BatchCount, ControlImage, ControlStrength, StyleRatio, - NormalizeInput, InputIdImagesPath.data()); + NormalizeInput, InputIdImagesDir.data()); // TODO upscale image int Len; unsigned char *Png = stbi_write_png_to_mem( @@ -297,7 +297,7 @@ Expect SDImageToImage::body( int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, uint32_t SampleSteps, float Strength, uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, uint32_t NormalizeInput, - uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, + uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, @@ -314,8 +314,8 @@ Expect SDImageToImage::body( MEM_SPAN_CHECK(NegativePromptSpan, MemInst, char, NegativePromptPtr, NegativePromptLen, "Failed when accessing the input negative prompt memory."sv) - MEM_SPAN_CHECK(InputIdImagesPathSpan, MemInst, char, InputIdImagesPathPtr, - InputIdImagesPathLen, + MEM_SPAN_CHECK(InputIdImagesDirSpan, MemInst, char, InputIdImagesDirPtr, + InputIdImagesDirLen, "Failed when accessing the input id images path memory."sv) MEM_SPAN_CHECK(OutputBufferSpan, MemInst, uint8_t, OutBufferPtr, OutBufferMaxSize, @@ -331,8 +331,8 @@ Expect SDImageToImage::body( std::string Prompt(PromptSpan.begin(), PromptSpan.end()); std::string NegativePrompt(NegativePromptSpan.begin(), NegativePromptSpan.end()); - std::string InputIdImagesPath(InputIdImagesPathSpan.begin(), - InputIdImagesPathSpan.end()); + std::string InputIdImagesDir(InputIdImagesDirSpan.begin(), + InputIdImagesDirSpan.end()); std::string OutputPath(OutputPathSpan.begin(), OutputPathSpan.end()); uint8_t *InputImageBuffer = nullptr; uint8_t *ControlImageBuffer = nullptr; @@ -371,7 +371,7 @@ Expect SDImageToImage::body( ClipSkip, CfgScale, Width, Height, sample_method_t(SampleMethod), SampleSteps, Strength, Seed, BatchCount, ControlImage, ControlStrength, StyleRatio, - NormalizeInput, InputIdImagesPath.data()); + NormalizeInput, InputIdImagesDir.data()); // TODO upscale image int Len; unsigned char *Png = stbi_write_png_to_mem( diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h index 57731c48..0dbdf0a2 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -31,8 +31,8 @@ class SDImageToImage : public StableDiffusion::Func { uint32_t NegativePromptLen, int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, uint32_t SampleSteps, float Strength, uint32_t Seed, uint32_t BatchCount, float ControlStrength, - float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesPathPtr, - uint32_t InputIdImagesPathLen, uint32_t CannyPreprocess, + float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesDirPtr, + uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, @@ -49,7 +49,7 @@ class SDTextToImage : public StableDiffusion::Func { int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, uint32_t SampleSteps, uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, uint32_t NormalizeInput, - uint32_t InputIdImagesPathPtr, uint32_t InputIdImagesPathLen, + uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 30da9585..e59b215e 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -54,9 +54,7 @@ if(WASMEDGE_PLUGIN_WASM_BPF) endif() if(WASMEDGE_PLUGIN_STABLEDIFFUSION) - if(CMAKE_SYSTEM_NAME MATCHES "Linux") - add_subdirectory(wasmedge_stablediffusion) - endif() + add_subdirectory(wasmedge_stablediffusion) endif() if(WASMEDGE_PLUGIN_WASI_LOGGING) diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index 6b1c71e0..b2533c44 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -144,32 +144,30 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { EXPECT_TRUE(HostFuncCreateContext.run( CallFrame, std::initializer_list{ - QuantModelPathPtr, - static_cast(QuantModelPath.size()), - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1, - 0, - -1, - 31, - 1, - 0, - 0, - 0, - 0, - SessionPtr}, // vaeDecodeOnly=true, NThreads=-1, - // wtype=31(SD_TYPE_COUNT), RngType=CUDA_RNG, - // Schedule=DEFAULT, Other is false + QuantModelPathPtr, // ModelPathPtr + static_cast(QuantModelPath.size()), // ModelPathLen + 0, // VaePathPtr + 0, // VaePathLen + 0, // TaesdPathPtr + 0, // TaesdPathLen + 0, // ControlNetPathPtr + 0, // ControlNetPathLen + 0, // LoraModelDirPtr + 0, // LoraModelDirLen + 0, // EmbedDirPtr + 0, // EmbedDirLen + 0, // IdEmbedDirPtr + 0, // IdEmbedDirLen + 1, // VaeDecodeOnly + 0, // VaeTiling + -1, // NThreads + 31, // Wtype + 1, // RngType + 0, // Schedule + 0, // ClipOnCpu + 0, // ControlNetCpu + 0, // VaeOnCpu + SessionPtr}, // SessiontIdPtr Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); SessionId = *MemInst.getPointer(SessionPtr); @@ -184,38 +182,39 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { OutputPtr = BytesWrittenPtr + 4; writeBinaries(MemInst, PromptData, PromptPtr); writeBinaries(MemInst, OutputPath, OutputPathPtr); - EXPECT_TRUE(HostFuncTextToImage.run( - CallFrame, - std::initializer_list{PromptPtr, - PromptData.size(), - SessionId, - 0, - 0, - 0, - 0, - 64, - 64, - -1, - 7.0f, - 0, - 20, - 42, - 1, - 0.90f, - 20.0f, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - OutputPathPtr, - OutputPath.size(), - OutputPtr, - 65532, - BytesWrittenPtr}, - Errno)); + EXPECT_TRUE( + HostFuncTextToImage.run(CallFrame, + std::initializer_list{ + PromptPtr, // PromptPtr + PromptData.size(), // PromptLen + SessionId, // SessionId + 0, // ControlImagePtr + 0, // ControlImageLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + 64, // Width + 64, // Height + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 20, // SampleSteps + 42, // Seed + 1, // BatchCount + 0.90f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + OutputPathPtr, // OutputPathPtr + OutputPath.size(), // OutputPathLen + OutputPtr, // OutBufferPtr + 65532, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); EXPECT_GE(BytesWritten, 50); @@ -230,32 +229,30 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { EXPECT_TRUE(HostFuncCreateContext.run( CallFrame, std::initializer_list{ - QuantModelPathPtr, - static_cast(QuantModelPath.size()), - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - -1, - 31, - 1, - 0, - 0, - 0, - 0, - SessionPtr}, // NThreads=-1, - // wtype=31(SD_TYPE_COUNT), RngType=CUDA_RNG, - // Schedule=DEFAULT, Other is false + QuantModelPathPtr, // ModelPathPtr + static_cast(QuantModelPath.size()), // ModelPathLen + 0, // VaePathPtr + 0, // VaePathLen + 0, // TaesdPathPtr + 0, // TaesdPathLen + 0, // ControlNetPathPtr + 0, // ControlNetPathLen + 0, // LoraModelDirPtr + 0, // LoraModelDirLen + 0, // EmbedDirPtr + 0, // EmbedDirLen + 0, // IdEmbedDirPtr + 0, // IdEmbedDirLen + 0, // VaeDecodeOnly + 0, // VaeTiling + -1, // NThreads + 31, // Wtype + 1, // RngType + 0, // Schedule + 0, // ClipOnCpu + 0, // ControlNetCpu + 0, // VaeOnCpu + SessionPtr}, // SessiontIdPtr Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); SessionId = *MemInst.getPointer(SessionPtr); @@ -271,41 +268,42 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { writeBinaries(MemInst, PromptData2, PromptPtr); writeBinaries(MemInst, InputPath, InputPathPtr); writeBinaries(MemInst, OutputPath2, OutputPathPtr); - EXPECT_TRUE(HostFuncImageToImage.run( - CallFrame, - std::initializer_list{InputPathPtr, - InputPath.size(), - SessionId, - 64, - 64, - 0, - 0, - PromptPtr, - PromptData2.size(), - 0, - 0, - -1, - 7.0f, - 0, - 20, - 0.75f, - 42, - 1, - 0.9f, - 20.0f, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - OutputPathPtr, - OutputPath2.size(), - OutputPtr, - 65532, - BytesWrittenPtr}, - Errno)); + EXPECT_TRUE( + HostFuncImageToImage.run(CallFrame, + std::initializer_list{ + InputPathPtr, // ImagePtr + InputPath.size(), // ImageLen + SessionId, // SessionId + 64, // Width + 64, // Height + 0, // ControlImagePtr + 0, // ControlImageLen + PromptPtr, // PromptPtr + PromptData2.size(), // PromptLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 20, // SampleSteps + 0.75f, // Strength + 42, // Seed + 1, // BatchCount + 0.9f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + OutputPathPtr, // OutputPathPtr + OutputPath2.size(), // OutputPathLen + OutputPtr, // OutBufferPtr + 65532, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); EXPECT_GE(BytesWritten, 50); From 0d5134802184aacb9f70d0524a516bd63523791c Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 26 Jun 2024 20:18:57 +0800 Subject: [PATCH 345/623] [WASI-NN] ggml: refine the unused target include in CMake. Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 5 ----- 1 file changed, 5 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index dffb142e..f8f5e580 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -133,11 +133,6 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) ${llama_SOURCE_DIR}/common ${llama_SOURCE_DIR}/examples/llava ) - # Setup include and link from llama.cpp - target_include_directories(wasmedgePluginWasiNN PRIVATE - ${llama_SOURCE_DIR} - ${llama_SOURCE_DIR}examples/llava - ) target_link_libraries(wasmedgePluginWasiNN PRIVATE common simdjson::simdjson From d9ef9392e840ebcc71b7c5bf51df3f158e5716a9 Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 28 Jun 2024 13:34:05 +0800 Subject: [PATCH 346/623] [WASI-NN] ggml: bump llama.cpp b3259 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index f8f5e580..7c9c2916 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -73,7 +73,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3186 + GIT_TAG b3259 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 416656372d32ef785c632dc2e8e7a0fda881cc83 Mon Sep 17 00:00:00 2001 From: hydai Date: Sat, 29 Jun 2024 17:43:49 +0800 Subject: [PATCH 347/623] [WASI-NN] ggml: fix metal related files because the folder structure changed Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 58 ++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 7c9c2916..7535e43a 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -23,35 +23,24 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) # Disable warnings and debug messages set(LLAMA_ALL_WARNINGS OFF) set(LLAMA_METAL_NDEBUG ON) - set(LLAMA_ACCELERATE OFF) + set(GGML_ACCELERATE OFF) + set(GGML_BLAS OFF) if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_NATIVE) - message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_NATIVE(AVX/AVX2/FMA)") - set(LLAMA_NATIVE ON) + message(STATUS "WASI-NN GGML LLAMA backend: Enable GGML_NATIVE(AVX/AVX2/FMA)") + set(GGML_NATIVE ON) else() - set(LLAMA_NATIVE OFF) + set(GGML_NATIVE OFF) endif() if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_CUBLAS) - message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_CUDA") - set(LLAMA_CUDA ON) + message(STATUS "WASI-NN GGML LLAMA backend: Enable GGML_CUDA") + set(GGML_CUDA ON) # We need to set GGML_USE_CUDA for clip from llava. add_compile_definitions(GGML_USE_CUDA) - # If CUDA is ON, then OpenBLAS should be OFF. - set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS OFF) else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_CUDA") - set(LLAMA_CUDA OFF) - endif() - - if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_BLAS) - message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_BLAS") - # Default use OpenBLAS - set(LLAMA_BLAS ON) - set(LLAMA_BLAS_VENDOR "OpenBLAS") - else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_BLAS") - set(LLAMA_BLAS OFF) + message(STATUS "WASI-NN GGML LLAMA backend: Disable GGML_CUDA") + set(GGML_CUDA OFF) endif() if(NOT APPLE) @@ -59,12 +48,12 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) endif() if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) - message(STATUS "WASI-NN GGML LLAMA backend: Enable LLAMA_METAL") - set(LLAMA_METAL ON) - set(LLAMA_METAL_EMBED_LIBRARY ON) + message(STATUS "WASI-NN GGML LLAMA backend: Enable GGML_METAL") + set(GGML_METAL ON) + set(GGML_METAL_EMBED_LIBRARY ON) else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable LLAMA_METAL") - set(LLAMA_METAL OFF) + message(STATUS "WASI-NN GGML LLAMA backend: Disable GGML_METAL") + set(GGML_METAL OFF) endif() # setup llama.cpp @@ -142,8 +131,23 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) add_custom_command( TARGET wasmedgePluginWasiNN POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml-metal.metal ggml-metal.metal - COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml-common.h ggml-common.h + COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml/src/ggml-metal.metal ggml-metal.metal + COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml/src/ggml-common.h ggml-common.h + ) + endif() + if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + add_custom_command( + TARGET wasmedgePluginWasiNN + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/../llama-build/src/libllama.dylib libllama.dylib + COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/../llama-build/ggml/src/libggml.dylib libggml.dylib + ) + else() + add_custom_command( + TARGET wasmedgePluginWasiNN + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/../llama-build/src/libllama.so libllama.so + COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/../llama-build/ggml/src/libggml.so libggml.so ) endif() elseif(BACKEND STREQUAL "neuralspeed") From 4bf10f074cc5506700760ff5201ae676c64c0790 Mon Sep 17 00:00:00 2001 From: dm4 Date: Sat, 29 Jun 2024 21:49:38 +0800 Subject: [PATCH 348/623] [WASI-NN] ggml: static build llama.cpp Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 7535e43a..92c031cb 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -25,6 +25,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) set(LLAMA_METAL_NDEBUG ON) set(GGML_ACCELERATE OFF) set(GGML_BLAS OFF) + set(BUILD_SHARED_LIBS OFF) if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_NATIVE) message(STATUS "WASI-NN GGML LLAMA backend: Enable GGML_NATIVE(AVX/AVX2/FMA)") From d8f9b4dbc7a0d6f7aabf560c86879b9953108099 Mon Sep 17 00:00:00 2001 From: hydai Date: Sat, 29 Jun 2024 22:04:00 +0800 Subject: [PATCH 349/623] [WASI-NN] ggml: remove libggml and libllama. use static build instead Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 92c031cb..d7d242ba 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -136,21 +136,6 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml/src/ggml-common.h ggml-common.h ) endif() - if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") - add_custom_command( - TARGET wasmedgePluginWasiNN - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/../llama-build/src/libllama.dylib libllama.dylib - COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/../llama-build/ggml/src/libggml.dylib libggml.dylib - ) - else() - add_custom_command( - TARGET wasmedgePluginWasiNN - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/../llama-build/src/libllama.so libllama.so - COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/../llama-build/ggml/src/libggml.so libggml.so - ) - endif() elseif(BACKEND STREQUAL "neuralspeed") wasmedge_setup_simdjson() From 9af970cd2ba27c7086571b87ae9ce12d1812f6f5 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 1 Jul 2024 15:17:01 +0800 Subject: [PATCH 350/623] [WASI-NN] Neural speed: Refine code and add the release CI. Signed-off-by: YiYing He --- plugins/wasi_nn/neuralspeed.cpp | 25 ++++++++++++++++++++----- plugins/wasi_nn/neuralspeed.h | 9 +++++++-- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index c932fe97..6a04da17 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -1,22 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + #include "neuralspeed.h" +#include "wasinnenv.h" + #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED #include "simdjson.h" -#include -#endif -#include "wasinnenv.h" -#include + #if !defined(_WIN32) && !defined(_WIN64) && !defined(__WIN32__) && \ !defined(__TOS_WIN__) && !defined(__WINDOWS__) #include #endif +#include +#include +#endif + namespace WasmEdge::Host::WASINN::NeuralSpeed { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED + #if defined(_WIN32) || defined(_WIN64) || defined(__WIN32__) || \ defined(__TOS_WIN__) || defined(__WINDOWS__) HINSTANCE SharedLib = LoadLibrary(PYTHON_LIB_PATH); #else void *SharedLib = dlopen(PYTHON_LIB_PATH, RTLD_GLOBAL | RTLD_NOW); #endif + void printImformation(Graph &GraphRef, Context &CxtRef) { spdlog::info("[WASI-NN] Neural speed backend: Number of input tokens: {}"sv, CxtRef.Inputs.size()); @@ -27,6 +35,7 @@ void printImformation(Graph &GraphRef, Context &CxtRef) { spdlog::info("[WASI-NN] Neural speed backend: Compute time: {} ms "sv, GraphRef.ComputeTime); } + Expect load(WASINN::WasiNNEnvironment &Env, Span> Builders, WASINN::Device, uint32_t &GraphId) noexcept { @@ -195,6 +204,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, return WASINN::ErrNo::Success; } + Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t, Span OutBuffer, uint32_t &BytesWritten) noexcept { @@ -209,6 +219,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, BytesWritten = StringTmp.length(); return WASINN::ErrNo::Success; } + Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { if (!Py_IsInitialized()) { @@ -291,6 +302,7 @@ Expect compute(WasiNNEnvironment &Env, } return WASINN::ErrNo::Success; } + Expect unload(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); @@ -306,10 +318,13 @@ Expect unload(WASINN::WasiNNEnvironment &Env, } return WASINN::ErrNo::Success; } + #else namespace { Expect reportBackendNotSupported() noexcept { - spdlog::error("[WASI-NN] Neural speed backend is not supported."); + spdlog::error( + "[WASI-NN] Neural speed backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"NeuralSpeed\" to build it."sv); return WASINN::ErrNo::InvalidArgument; } } // namespace diff --git a/plugins/wasi_nn/neuralspeed.h b/plugins/wasi_nn/neuralspeed.h index bd06e2a4..f087f151 100644 --- a/plugins/wasi_nn/neuralspeed.h +++ b/plugins/wasi_nn/neuralspeed.h @@ -1,11 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + #pragma once #include "plugin/plugin.h" #include "types.h" -#include + #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED #include #endif + namespace WasmEdge::Host::WASINN { struct WasiNNEnvironment; } @@ -62,4 +66,5 @@ Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; Expect unload(WASINN::WasiNNEnvironment &Env, uint32_t GraphId) noexcept; -} // namespace WasmEdge::Host::WASINN::NeuralSpeed \ No newline at end of file + +} // namespace WasmEdge::Host::WASINN::NeuralSpeed From f408689f7e7e7f45e737b5f5c544f09a7ee98e1e Mon Sep 17 00:00:00 2001 From: YiYing He Date: Tue, 2 Jul 2024 16:51:21 +0800 Subject: [PATCH 351/623] [Plugin] Correct the module name and refine code of wasmedge_stablediffusion plugin. Signed-off-by: YiYing He --- .../wasmedge_stablediffusion/CMakeLists.txt | 30 +++++++++++-------- plugins/wasmedge_stablediffusion/sd_base.h | 6 +++- plugins/wasmedge_stablediffusion/sd_env.cpp | 29 +++++++++++------- plugins/wasmedge_stablediffusion/sd_env.h | 9 ++++-- plugins/wasmedge_stablediffusion/sd_func.cpp | 25 ++++++++++------ plugins/wasmedge_stablediffusion/sd_func.h | 13 ++++++-- .../wasmedge_stablediffusion/sd_module.cpp | 7 +++-- plugins/wasmedge_stablediffusion/sd_module.h | 9 ++++-- 8 files changed, 86 insertions(+), 42 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index fca1bf56..5ac3abcf 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -1,4 +1,7 @@ - # setup stable diffusion +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2022 Second State INC + +# setup stable diffusion message(STATUS "Downloading stable diffusion source") FetchContent_Declare( stable-diffusion @@ -10,7 +13,7 @@ FetchContent_MakeAvailable(stable-diffusion) set_property(TARGET stable-diffusion PROPERTY POSITION_INDEPENDENT_CODE ON) get_target_property(SD_DEPS stable-diffusion LINK_LIBRARIES) foreach(dep ${SD_DEPS}) -if(TARGET ${dep}) + if(TARGET ${dep}) set_target_properties(${dep} PROPERTIES POSITION_INDEPENDENT_CODE ON ) @@ -32,16 +35,15 @@ target_compile_options(wasmedgePluginStableDiffusion ) if(WASMEDGE_LINK_PLUGINS_STATIC) -target_link_libraries(wasmedgePluginStableDiffusion - PRIVATE - wasmedgeCAPI -) - + target_link_libraries(wasmedgePluginStableDiffusion + PRIVATE + wasmedgeCAPI + ) else() -target_link_libraries(wasmedgePluginStableDiffusion - PRIVATE - wasmedge_shared -) + target_link_libraries(wasmedgePluginStableDiffusion + PRIVATE + wasmedge_shared + ) endif() target_include_directories(wasmedgePluginStableDiffusion @@ -49,8 +51,9 @@ target_include_directories(wasmedgePluginStableDiffusion $ ${CMAKE_CURRENT_SOURCE_DIR} ) + if (MSVC) -target_compile_options( + target_compile_options( wasmedgePluginStableDiffusion PRIVATE /wd4459 @@ -68,4 +71,5 @@ else() -Wno-missing-field-initializers ) endif() -install(TARGETS wasmedgePluginStableDiffusion DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) \ No newline at end of file + +install(TARGETS wasmedgePluginStableDiffusion DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasmedge_stablediffusion/sd_base.h b/plugins/wasmedge_stablediffusion/sd_base.h index c89e6b2b..c29b77aa 100644 --- a/plugins/wasmedge_stablediffusion/sd_base.h +++ b/plugins/wasmedge_stablediffusion/sd_base.h @@ -1,8 +1,12 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + #pragma once +#include "sd_env.h" + #include "common/errcode.h" #include "runtime/hostfunc.h" -#include "sd_env.h" namespace WasmEdge { namespace Host { diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp index 6c80d5e6..04753eec 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.cpp +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -1,6 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC #include "sd_env.h" #include "sd_module.h" + namespace WasmEdge { namespace Host { namespace { @@ -33,38 +36,42 @@ EXPORT_GET_DESCRIPTOR(Descriptor) } // namespace namespace StableDiffusion { + uint32_t SDEnviornment::addContext(sd_ctx_t *Ctx) noexcept { Contexts.push_back(Ctx); return Contexts.size() - 1; } + sd_ctx_t *SDEnviornment::getContext(const uint32_t Id) noexcept { return Contexts[Id]; } -void SBLog(enum sd_log_level_t level, const char *log, void *) { - if (!log) { + +void SBLog(enum sd_log_level_t Level, const char *Log, void *) { + if (!Log) { return; } - std::string levelStr; - switch (level) { + std::string LevelStr; + switch (Level) { case SD_LOG_DEBUG: - levelStr = "DEBUG"; + LevelStr = "DEBUG"; break; case SD_LOG_INFO: - levelStr = "INFO"; + LevelStr = "INFO"; break; case SD_LOG_WARN: - levelStr = "WARN"; + LevelStr = "WARN"; break; case SD_LOG_ERROR: - levelStr = "ERROR"; + LevelStr = "ERROR"; break; default: - levelStr = "?????"; + LevelStr = "?????"; break; } - spdlog::info("[WasmEdge-StableDiffusion] SD-log: [{}] {}", levelStr, log); + spdlog::info("[WasmEdge-StableDiffusion] SD-log: [{}] {}", LevelStr, Log); } + } // namespace StableDiffusion } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/sd_env.h b/plugins/wasmedge_stablediffusion/sd_env.h index dfd6a2db..6b64b28a 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.h +++ b/plugins/wasmedge_stablediffusion/sd_env.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + #pragma once #include "stable-diffusion.h" @@ -8,7 +11,9 @@ namespace WasmEdge { namespace Host { namespace StableDiffusion { -void SBLog(enum sd_log_level_t level, const char *log, void *); + +void SBLog(enum sd_log_level_t Level, const char *Log, void *); + enum class ErrNo : uint32_t { Success = 0, // No error occurred. InvalidArgument = 1, // Caller module passed an invalid argument. @@ -36,4 +41,4 @@ class SDEnviornment { } // namespace StableDiffusion } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index a3063b1c..564e8e51 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + #include "sd_func.h" #include "common/spdlog.h" #include "sd_env.h" @@ -10,6 +13,7 @@ #define STB_IMAGE_WRITE_IMPLEMENTATION #define STB_IMAGE_WRITE_STATIC #include "stb_image_write.h" + namespace WasmEdge { namespace Host { namespace StableDiffusion { @@ -48,24 +52,26 @@ namespace StableDiffusion { spdlog::error("[WasmEdge-StableDiffusion] "sv Message); \ return static_cast(ErrNo::MissingMemory); \ } -bool parameterCheck(SDEnviornment &Env, uint32_t width, uint32_t height, + +bool parameterCheck(SDEnviornment &Env, uint32_t Width, uint32_t Height, uint32_t SessionId) { if (SessionId >= Env.getContextSize()) { spdlog::error("[WasmEdge-StableDiffusion] Session ID is invalid."); return false; } - if (width % 64 != 0) { + if (Width % 64 != 0) { spdlog::error("[WasmEdge-StableDiffusion] Width must be a multiple of 64 " - "and greater than 0"); + "and greater than 0."); return false; } - if (height % 64 != 0) { + if (Height % 64 != 0) { spdlog::error("[WasmEdge-StableDiffusion] Height must be a multiple of 64 " - "and greater than 0"); + "and greater than 0."); return false; } return true; } + sd_image_t *readControlImage(Span ControlImage, uint8_t *ControlImageBuf, int Width, int Height, bool CannyPreprocess) { @@ -77,7 +83,6 @@ sd_image_t *readControlImage(Span ControlImage, ControlImageBuf = stbi_load(ControlImagePath.substr(5).data(), &Width, &Height, &Channel, 3); } else { - ControlImageBuf = stbi_load_from_memory( ControlImage.data(), ControlImage.size(), &Width, &Height, &Channel, 3); } @@ -210,6 +215,7 @@ Expect SDCreateContext::body( return static_cast(ErrNo::Success); } + Expect SDTextToImage::body( const Runtime::CallingFrame &Frame, uint32_t PromptPtr, uint32_t PromptLen, uint32_t SessionId, uint32_t ControlImagePtr, uint32_t ControlImageLen, @@ -289,6 +295,7 @@ Expect SDTextToImage::body( free(ControlImageBuffer); return static_cast(ErrNo::Success); } + Expect SDImageToImage::body( const Runtime::CallingFrame &Frame, uint32_t ImagePtr, uint32_t ImageLen, uint32_t SessionId, uint32_t Width, uint32_t Height, @@ -349,7 +356,6 @@ Expect SDImageToImage::body( return static_cast(ErrNo::InvalidArgument); } } else { - InputImageBuffer = stbi_load_from_memory(ImageSpan.data(), ImageSpan.size(), &ImageWidth, &ImageHeight, &Channel, 3); @@ -372,7 +378,7 @@ Expect SDImageToImage::body( sample_method_t(SampleMethod), SampleSteps, Strength, Seed, BatchCount, ControlImage, ControlStrength, StyleRatio, NormalizeInput, InputIdImagesDir.data()); - // TODO upscale image + // TODO: upscale image int Len; unsigned char *Png = stbi_write_png_to_mem( reinterpret_cast(Results), 0, Results->width, @@ -397,6 +403,7 @@ Expect SDImageToImage::body( free(ControlImageBuffer); return static_cast(ErrNo::Success); } + } // namespace StableDiffusion } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h index 0dbdf0a2..0e3135cc 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -1,11 +1,16 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + #pragma once -#include "runtime/callingframe.h" #include "sd_base.h" +#include "runtime/callingframe.h" + namespace WasmEdge { namespace Host { namespace StableDiffusion { + class SDCreateContext : public StableDiffusion::Func { public: SDCreateContext(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} @@ -20,6 +25,7 @@ class SDCreateContext : public StableDiffusion::Func { uint32_t Schedule, uint32_t ClipOnCpu, uint32_t ControlNetCpu, uint32_t VaeOnCpu, uint32_t SessiontIdPtr); }; + class SDImageToImage : public StableDiffusion::Func { public: SDImageToImage(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} @@ -38,6 +44,7 @@ class SDImageToImage : public StableDiffusion::Func { uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); }; + class SDTextToImage : public StableDiffusion::Func { public: SDTextToImage(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} @@ -55,6 +62,7 @@ class SDTextToImage : public StableDiffusion::Func { uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); }; + class SDConvert : public StableDiffusion::Func { public: SDConvert(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} @@ -64,6 +72,7 @@ class SDConvert : public StableDiffusion::Func { uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t WType); }; + } // namespace StableDiffusion } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/sd_module.cpp b/plugins/wasmedge_stablediffusion/sd_module.cpp index 29057bfa..5938db3a 100644 --- a/plugins/wasmedge_stablediffusion/sd_module.cpp +++ b/plugins/wasmedge_stablediffusion/sd_module.cpp @@ -1,10 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + #include "sd_module.h" #include "sd_func.h" namespace WasmEdge { namespace Host { -SDModule::SDModule() : ModuleInstance("stable_diffusion") { +SDModule::SDModule() : ModuleInstance("wasmedge_stablediffusion") { addHostFunc("create_context", std::make_unique(Env)); addHostFunc("image_to_image", @@ -15,4 +18,4 @@ SDModule::SDModule() : ModuleInstance("stable_diffusion") { } } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/sd_module.h b/plugins/wasmedge_stablediffusion/sd_module.h index 8088355b..e681ba06 100644 --- a/plugins/wasmedge_stablediffusion/sd_module.h +++ b/plugins/wasmedge_stablediffusion/sd_module.h @@ -1,10 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + #pragma once -#include "runtime/instance/module.h" #include "sd_env.h" +#include "runtime/instance/module.h" + namespace WasmEdge { namespace Host { + class SDModule : public Runtime::Instance::ModuleInstance { public: SDModule(); @@ -15,4 +20,4 @@ class SDModule : public Runtime::Instance::ModuleInstance { }; } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge From 050c0159f1db0c685cbf308cd48056271a755f56 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Tue, 2 Jul 2024 17:45:48 +0800 Subject: [PATCH 352/623] [Docker] Install `elfutils` for bpf Signed-off-by: Yi Huang --- utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps index eaa1cd42..41488891 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps @@ -7,7 +7,7 @@ ENV INFOPATH /opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} RUN cd && (yum check-update || true) && \ - yum install -y wget unzip zlib-devel zlib-static + yum install -y wget unzip zlib-devel zlib-static elfutils RUN yum install -y yum-utils && \ yum-config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo && \ yum install -y gh From 1f3a7cb45711bfdf74a14ce787e2b8fc21c85d4c Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Wed, 3 Jul 2024 17:21:29 +0800 Subject: [PATCH 353/623] [Docker] Install gh in workflows Signed-off-by: Yi Huang --- utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps | 3 --- utils/docker/Dockerfile.manylinux_2_28_aarch64 | 3 ++- utils/docker/Dockerfile.manylinux_2_28_x86_64 | 3 ++- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps index 41488891..1499f3d4 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps @@ -8,9 +8,6 @@ ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_ RUN cd && (yum check-update || true) && \ yum install -y wget unzip zlib-devel zlib-static elfutils -RUN yum install -y yum-utils && \ - yum-config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo && \ - yum install -y gh WORKDIR /root diff --git a/utils/docker/Dockerfile.manylinux_2_28_aarch64 b/utils/docker/Dockerfile.manylinux_2_28_aarch64 index 98dc2e36..f1f11b74 100644 --- a/utils/docker/Dockerfile.manylinux_2_28_aarch64 +++ b/utils/docker/Dockerfile.manylinux_2_28_aarch64 @@ -12,7 +12,8 @@ ENV MANPATH /opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} ENV INFOPATH /opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build cmake && \ +RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build cmake yum-utils && \ + yum-config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo && \ yum install -y gcc-toolset-13 && \ export CPU=$(/opt/python/cp311-cp311/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ diff --git a/utils/docker/Dockerfile.manylinux_2_28_x86_64 b/utils/docker/Dockerfile.manylinux_2_28_x86_64 index ba1fe2e9..850a7959 100644 --- a/utils/docker/Dockerfile.manylinux_2_28_x86_64 +++ b/utils/docker/Dockerfile.manylinux_2_28_x86_64 @@ -12,7 +12,8 @@ ENV MANPATH /opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} ENV INFOPATH /opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build cmake && \ +RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build cmake yum-utils && \ + yum-config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo && \ yum install -y gcc-toolset-13 && \ export CPU=$(/opt/python/cp311-cp311/bin/python3 -c \ 'import multiprocessing; print(multiprocessing.cpu_count())') && \ From 87db23af179837b8aff01c293ff17ec2aedec81d Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Thu, 4 Jul 2024 19:23:22 +0800 Subject: [PATCH 354/623] [Docker] Fix option for pytorch install Signed-off-by: Yi Huang --- utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps index 1499f3d4..72723eec 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps @@ -19,7 +19,7 @@ COPY wasi-nn/install-pytorch.sh . ENV PYTORCH_VERSION "1.8.2" ENV PYTORCH_INSTALL_TO "/root" ENV Torch_DIR "/root/libtorch" -RUN [ "/bin/bash", "install-pytorch.sh", "--disable-cxx11-abi" ] +RUN [ "/bin/bash", "install-pytorch.sh" ] COPY wasi-crypto/build-openssl.sh . ENV OpenSSL_DIR "/root/openssl-1.1.1n/openssl" From bad54c403c6a4189271e3aebfe9575d2b9358e60 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Fri, 5 Jul 2024 00:54:29 +0800 Subject: [PATCH 355/623] [Docker] Add LD_LIBRARY_PATH for gcc 13 Signed-off-by: Yi Huang --- utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps | 5 ----- utils/docker/Dockerfile.manylinux_2_28_aarch64 | 2 ++ utils/docker/Dockerfile.manylinux_2_28_x86_64 | 2 ++ 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps index 72723eec..d0aa34e5 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps @@ -1,11 +1,6 @@ ARG BASE=wasmedge/wasmedge:manylinux_2_28_x86_64 FROM ${BASE} -ENV PATH /opt/rh/gcc-toolset-13/root/usr/bin${PATH:+:${PATH}} -ENV MANPATH /opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} -ENV INFOPATH /opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} -ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} - RUN cd && (yum check-update || true) && \ yum install -y wget unzip zlib-devel zlib-static elfutils diff --git a/utils/docker/Dockerfile.manylinux_2_28_aarch64 b/utils/docker/Dockerfile.manylinux_2_28_aarch64 index f1f11b74..789c5afc 100644 --- a/utils/docker/Dockerfile.manylinux_2_28_aarch64 +++ b/utils/docker/Dockerfile.manylinux_2_28_aarch64 @@ -7,9 +7,11 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM.manylinux_2_28 /root/ +# See /opt/rh/gcc-toolset-13/enable ENV PATH /opt/rh/gcc-toolset-13/root/usr/bin${PATH:+:${PATH}} ENV MANPATH /opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} ENV INFOPATH /opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV LD_LIBRARY_PATH /opt/rh/gcc-toolset-13/root/usr/lib64:/opt/rh/gcc-toolset-13/root/usr/lib:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build cmake yum-utils && \ diff --git a/utils/docker/Dockerfile.manylinux_2_28_x86_64 b/utils/docker/Dockerfile.manylinux_2_28_x86_64 index 850a7959..832a6c3a 100644 --- a/utils/docker/Dockerfile.manylinux_2_28_x86_64 +++ b/utils/docker/Dockerfile.manylinux_2_28_x86_64 @@ -7,9 +7,11 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM.manylinux_2_28 /root/ +# See /opt/rh/gcc-toolset-13/enable ENV PATH /opt/rh/gcc-toolset-13/root/usr/bin${PATH:+:${PATH}} ENV MANPATH /opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} ENV INFOPATH /opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV LD_LIBRARY_PATH /opt/rh/gcc-toolset-13/root/usr/lib64:/opt/rh/gcc-toolset-13/root/usr/lib:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build cmake yum-utils && \ From 36c31e59f5437441ac2df4302de10e969007f8de Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Fri, 5 Jul 2024 15:38:50 +0800 Subject: [PATCH 356/623] [Docker] Fix dependency: change to dev lib Signed-off-by: Yi Huang --- utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps index d0aa34e5..e178ab91 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps @@ -2,7 +2,7 @@ ARG BASE=wasmedge/wasmedge:manylinux_2_28_x86_64 FROM ${BASE} RUN cd && (yum check-update || true) && \ - yum install -y wget unzip zlib-devel zlib-static elfutils + yum install -y wget unzip zlib-devel zlib-static elfutils-libelf-devel WORKDIR /root From bdc507f428af5cd2e70a5806ab903c06c69bdc97 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 8 Jul 2024 17:37:12 +0800 Subject: [PATCH 357/623] [WASI-NN] ggml: disable OpenMP by default Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index d7d242ba..343c955f 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -25,6 +25,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) set(LLAMA_METAL_NDEBUG ON) set(GGML_ACCELERATE OFF) set(GGML_BLAS OFF) + set(GGML_OPENMP OFF) set(BUILD_SHARED_LIBS OFF) if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_NATIVE) From 302118bb03a699e0ffb9932daec5ba0f4561b91c Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 10 Jul 2024 14:26:56 +0800 Subject: [PATCH 358/623] [WASI-NN] ggml: bump to b3358 Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 343c955f..5a1bd6f1 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -64,7 +64,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3259 + GIT_TAG b3358 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 306e8bf7d20f9dd5adb0af73af07c4001a090fc8 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 10 Jul 2024 16:00:56 +0800 Subject: [PATCH 359/623] [WASI-NN] ggml: fix model path in nn-preload string on Windows Signed-off-by: dm4 --- plugins/wasi_nn/wasinnenv.cpp | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 1d861320..47febedf 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -100,14 +100,23 @@ WasiNNEnvironment::WasiNNEnvironment() noexcept { auto Backend = BackendMap.find(Encode); auto Device = DeviceMap.find(Target); if (Backend != BackendMap.end() && Device != DeviceMap.end()) { - for (const std::string &P : Paths) { - if (Backend->second == Backend::GGML) { - // We write model path to model data to avoid file IO in llama.cpp. - std::string ModelPath = "preload:" + P; - std::vector ModelPathData(ModelPath.begin(), - ModelPath.end()); - Models.push_back(std::move(ModelPathData)); - } else { + if (Backend->second == Backend::GGML) { + // In GGML, we only support loading one model from nn-preload config. + // To handle paths on Windows that contains `:` in the path, we combine + // the Paths into a single string separated by `:`. + std::string P; + for (const std::string &PathSegment : Paths) { + P += PathSegment; + if (PathSegment != Paths.back()) { + P += ":"; + } + } + // We write model path to model data to avoid file IO in llama.cpp. + std::string ModelPath = "preload:" + P; + std::vector ModelPathData(ModelPath.begin(), ModelPath.end()); + Models.push_back(std::move(ModelPathData)); + } else { + for (const std::string &P : Paths) { std::vector Model; if (load(std::filesystem::u8path(P), Model)) { Models.push_back(std::move(Model)); From 662a9d89efca288d3267bf216c4db8d8d0f50ffa Mon Sep 17 00:00:00 2001 From: Jun Zhang Date: Wed, 10 Jul 2024 19:54:46 +0800 Subject: [PATCH 360/623] [WASI-LLM] Implement wasi_llm module (#3536) Signed-off-by: Jun Zhang --- plugins/CMakeLists.txt | 4 ++ plugins/wasi_llm/CMakeLists.txt | 35 ++++++++++ plugins/wasi_llm/types.h | 15 +++++ plugins/wasi_llm/wasillmbase.h | 24 +++++++ plugins/wasi_llm/wasillmfunc.cpp | 81 ++++++++++++++++++++++ plugins/wasi_llm/wasillmfunc.h | 105 +++++++++++++++++++++++++++++ plugins/wasi_llm/wasillmmodule.cpp | 21 ++++++ plugins/wasi_llm/wasillmmodule.h | 17 +++++ 8 files changed, 302 insertions(+) create mode 100644 plugins/wasi_llm/CMakeLists.txt create mode 100644 plugins/wasi_llm/types.h create mode 100644 plugins/wasi_llm/wasillmbase.h create mode 100644 plugins/wasi_llm/wasillmfunc.cpp create mode 100644 plugins/wasi_llm/wasillmfunc.h create mode 100644 plugins/wasi_llm/wasillmmodule.cpp create mode 100644 plugins/wasi_llm/wasillmmodule.h diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index f951985f..a479647d 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -90,3 +90,7 @@ endif() if(WASMEDGE_PLUGIN_FFMPEG) add_subdirectory(wasmedge_ffmpeg) endif() + +if (WASMEDGE_PLUGIN_LLM) + add_subdirectory(wasi_llm) +endif() diff --git a/plugins/wasi_llm/CMakeLists.txt b/plugins/wasi_llm/CMakeLists.txt new file mode 100644 index 00000000..6934fa93 --- /dev/null +++ b/plugins/wasi_llm/CMakeLists.txt @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +# TODO: Fetch llm.c source. + +wasmedge_add_library(wasmedgePluginWasiLLM + SHARED + wasillmfunc.cpp + wasillmmodule.cpp +) + +target_compile_options(wasmedgePluginWasiLLM + PUBLIC + -DWASMEDGE_PLUGIN +) + +target_include_directories(wasmedgePluginWasiLLM + PUBLIC + $ + ${CMAKE_CURRENT_SOURCE_DIR} +) + +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasmedgePluginWasiLLM + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasmedgePluginWasiLLM + PRIVATE + wasmedge_shared + ) +endif() + +install(TARGETS wasmedgePluginWasiLLM DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasi_llm/types.h b/plugins/wasi_llm/types.h new file mode 100644 index 00000000..7d858f1b --- /dev/null +++ b/plugins/wasi_llm/types.h @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include + +namespace WasmEdge::Host::WASILLM { + +enum class ErrNo : uint32_t { + Success = 0, + InvalidArgument = 1, +}; + +} // namespace WasmEdge::Host::WASILLM diff --git a/plugins/wasi_llm/wasillmbase.h b/plugins/wasi_llm/wasillmbase.h new file mode 100644 index 00000000..aace0bd4 --- /dev/null +++ b/plugins/wasi_llm/wasillmbase.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "common/errcode.h" +#include "runtime/hostfunc.h" +#include "types.h" + +namespace WasmEdge { +namespace Host { + +template class WasiLLM : public Runtime::HostFunction { +public: + WasiLLM() : Runtime::HostFunction(0) {} + +protected: + static constexpr uint32_t castErrNo(WASILLM::ErrNo E) noexcept { + return static_cast(E); + } +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_llm/wasillmfunc.cpp b/plugins/wasi_llm/wasillmfunc.cpp new file mode 100644 index 00000000..c40090d1 --- /dev/null +++ b/plugins/wasi_llm/wasillmfunc.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasillmfunc.h" +#include "common/spdlog.h" + +#include +#include + +namespace WasmEdge { +namespace Host { + +Expect +WasiLLMModelCreate::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t CheckPointPath, + uint32_t CheckPointPathLen) { + (void)Frame; + (void)CheckPointPath; + (void)CheckPointPathLen; + return WASILLM::ErrNo::InvalidArgument; +} + +Expect +WasiLLMModelFree::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ModelPtr) { + (void)Frame; + (void)ModelPtr; + return WASILLM::ErrNo::InvalidArgument; +} + +Expect +WasiLLMDataLoaderCreate::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t DataPath, uint32_t DataPathLen) { + (void)Frame; + (void)DataPath; + (void)DataPathLen; + return WASILLM::ErrNo::InvalidArgument; +} + +Expect +WasiLLMDataLoaderFree::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t DataLoaderPtr) { + (void)Frame; + (void)DataLoaderPtr; + return WASILLM::ErrNo::InvalidArgument; +} + +Expect +WasiLLMTokenizerCreate::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t FilePath, uint32_t FilePathLen) { + (void)Frame; + (void)FilePath; + (void)FilePathLen; + return WASILLM::ErrNo::InvalidArgument; +} + +Expect +WasiLLMTokenizerFree::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t TokenizerPtr) { + (void)Frame; + (void)TokenizerPtr; + return WASILLM::ErrNo::InvalidArgument; +} + +Expect +WasiLLMModelTrain::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ModelPtr, uint32_t TrainDataLoaderPtr, + uint32_t ValDataLoaderPtr, uint32_t TokenizerPtr, + uint32_t Lr, uint32_t Epoch) { + (void)Frame; + (void)ModelPtr; + (void)TrainDataLoaderPtr; + (void)ValDataLoaderPtr; + (void)TokenizerPtr; + (void)Lr; + (void)Epoch; + return WASILLM::ErrNo::InvalidArgument; +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_llm/wasillmfunc.h b/plugins/wasi_llm/wasillmfunc.h new file mode 100644 index 00000000..de5a6d21 --- /dev/null +++ b/plugins/wasi_llm/wasillmfunc.h @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "runtime/callingframe.h" +#include "types.h" +#include "wasillmbase.h" + +#include + +namespace WasmEdge { +namespace Host { + +class WasiLLMModelCreate : public WasiLLM { +public: + Expect body(const Runtime::CallingFrame &Frame, + uint32_t CheckPointPath, uint32_t CheckPointPathLen) { + return bodyImpl(Frame, CheckPointPath, CheckPointPathLen).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t CheckPointPath, + uint32_t CheckPointPathLen); +}; + +class WasiLLMModelFree : public WasiLLM { +public: + Expect body(const Runtime::CallingFrame &Frame, uint32_t ModelPtr) { + return bodyImpl(Frame, ModelPtr).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ModelPtr); +}; + +class WasiLLMDataLoaderCreate : public WasiLLM { +public: + Expect body(const Runtime::CallingFrame &Frame, uint32_t DataPath, + uint32_t DataPathLen) { + return bodyImpl(Frame, DataPath, DataPathLen).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t DataPath, uint32_t DataPathLen); +}; + +class WasiLLMDataLoaderFree : public WasiLLM { +public: + Expect body(const Runtime::CallingFrame &Frame, + uint32_t DataLoaderPtr) { + return bodyImpl(Frame, DataLoaderPtr).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t DataLoaderPtr); +}; + +class WasiLLMTokenizerCreate : public WasiLLM { +public: + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilePath, + uint32_t FilePathLen) { + return bodyImpl(Frame, FilePath, FilePathLen).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t FilePath, uint32_t FilePathLen); +}; + +class WasiLLMTokenizerFree : public WasiLLM { +public: + Expect body(const Runtime::CallingFrame &Frame, + uint32_t TokenizerPtr) { + return bodyImpl(Frame, TokenizerPtr).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t TokenizerPtr); +}; + +class WasiLLMModelTrain : public WasiLLM { +public: + Expect body(const Runtime::CallingFrame &Frame, uint32_t ModelPtr, + uint32_t TrainDataLoaderPtr, uint32_t ValDataLoaderPtr, + uint32_t TokenizerPtr, uint32_t Lr, uint32_t Epoch) { + return bodyImpl(Frame, ModelPtr, TrainDataLoaderPtr, ValDataLoaderPtr, + TokenizerPtr, Lr, Epoch) + .map(castErrNo); + } + +private: + Expect + bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ModelPtr, + uint32_t TrainDataLoaderPtr, uint32_t ValDataLoaderPtr, + uint32_t TokenizerPtr, uint32_t Lr, uint32_t Epoch); +}; + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_llm/wasillmmodule.cpp b/plugins/wasi_llm/wasillmmodule.cpp new file mode 100644 index 00000000..76aca85d --- /dev/null +++ b/plugins/wasi_llm/wasillmmodule.cpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasillmmodule.h" +#include "wasillmfunc.h" + +namespace WasmEdge { +namespace Host { + +WasiLLMModule::WasiLLMModule() : ModuleInstance("wasi_llm") { + addHostFunc("model_create", std::make_unique()); + addHostFunc("model_free", std::make_unique()); + addHostFunc("dataloader_create", std::make_unique()); + addHostFunc("dataloader_free", std::make_unique()); + addHostFunc("tokenizer_create", std::make_unique()); + addHostFunc("tokenizer_free", std::make_unique()); + addHostFunc("model_train", std::make_unique()); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_llm/wasillmmodule.h b/plugins/wasi_llm/wasillmmodule.h new file mode 100644 index 00000000..9033fe1c --- /dev/null +++ b/plugins/wasi_llm/wasillmmodule.h @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "runtime/instance/module.h" + +namespace WasmEdge { +namespace Host { + +class WasiLLMModule : public Runtime::Instance::ModuleInstance { +public: + WasiLLMModule(); +}; + +} // namespace Host +} // namespace WasmEdge From fc3b00e5099dc32edad189d67cfc40d5f3a3ac47 Mon Sep 17 00:00:00 2001 From: PeterD1524 <53310459+PeterD1524@users.noreply.github.com> Date: Thu, 11 Jul 2024 23:18:09 +0800 Subject: [PATCH 361/623] [WASI-NN] piper: integrate piper, a local neural text to speech system, as a new backend (#3499) * cmake add piper and patch Signed-off-by: PeterD1524 * piper cpp Signed-off-by: PeterD1524 * change PIPER to Piper Signed-off-by: PeterD1524 * fix typo and show more info in error message Signed-off-by: PeterD1524 * use commit hash rather than tag name for piper source Signed-off-by: PeterD1524 * [WASI-NN] piper: clean up patch command Signed-off-by: PeterD1524 * [WASI-NN] piper: use blank struct if piper backend not defined Signed-off-by: PeterD1524 * [WASI-NN] piper: add tests create new cmake function wasmedge_setup_piper and add it to WASINNDeps (to make tests build) check index and dimension for set input and index for get output update piper patch to not create test_piper Signed-off-by: PeterD1524 * [WASI-NN] piper: add build workflow Signed-off-by: PeterD1524 * [WASI-NN] piper: remove redundant empty line and use string view literals for error messages Signed-off-by: PeterD1524 * [WASI-NN] piper: suppress -Werror=unused-parameter in dependency: piper Signed-off-by: PeterD1524 * [WASI-NN] piper: update Piper patch to fix build failure with Clang and Ninja piper and piper-phonemize use `-Wl,-rpath,'$ORIGIN'` CXXFLAGS. `$ORIGIN` expands to `RIGIN` in Unix Makefiles and expands to empty string in Ninja. Clang does not pass empty ld arguments so `-rpath` will incorrectly absorb the next argument. similar issue: https://github.com/mesonbuild/meson/issues/2814 This patch uses CMAKE_BUILD_RPATH and CMAKE_INSTALL_RPATH to fix the problem. Signed-off-by: PeterD1524 * [WASI-NN] piper: quote variables for PATCH_COMMAND Signed-off-by: PeterD1524 * [WASI-NN] piper: suppress -Wunused-function for readEntireFile Signed-off-by: PeterD1524 * [WASI-NN] piper: make the parsing of run config more robust Signed-off-by: PeterD1524 * [WASI-NN] piper: check whether the espeak-ng data directory and libtashkeel ort model exist Signed-off-by: PeterD1524 * [WASI-NN] piper: clean some code: remove unused .get() Signed-off-by: PeterD1524 * [WASI-NN] piper: make the json parsing for setInput more robust Signed-off-by: PeterD1524 * [WASI-NN] piper: refine code Signed-off-by: PeterD1524 * [WASI-NN] piper: fix run config comments Signed-off-by: PeterD1524 * [WASI-NN] piper: clean includes Signed-off-by: PeterD1524 * [WASI-NN] piper: use std::filesystem for paths Signed-off-by: PeterD1524 --------- Signed-off-by: PeterD1524 Signed-off-by: PeterD1524 <53310459+PeterD1524@users.noreply.github.com> --- plugins/wasi_nn/CMakeLists.txt | 7 + plugins/wasi_nn/piper.cpp | 522 ++++++++++++++++++++++++++++ plugins/wasi_nn/piper.h | 105 ++++++ plugins/wasi_nn/piper.patch | 91 +++++ plugins/wasi_nn/types.h | 4 +- plugins/wasi_nn/wasinnenv.cpp | 4 +- plugins/wasi_nn/wasinnenv.h | 1 + test/plugins/wasi_nn/CMakeLists.txt | 22 ++ test/plugins/wasi_nn/wasi_nn.cpp | 254 +++++++++++++- 9 files changed, 1006 insertions(+), 4 deletions(-) create mode 100644 plugins/wasi_nn/piper.cpp create mode 100644 plugins/wasi_nn/piper.h create mode 100644 plugins/wasi_nn/piper.patch diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 5a1bd6f1..e5e350db 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -13,6 +13,7 @@ wasmedge_add_library(wasmedgePluginWasiNN tfl.cpp ggml.cpp neuralspeed.cpp + piper.cpp ) foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) @@ -152,6 +153,12 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) message(FATAL_ERROR "Can not find python3.") endif() target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) + elseif(BACKEND STREQUAL "piper") + wasmedge_setup_simdjson() + include(WASINNDeps) + wasmedge_setup_piper() + target_include_directories(wasmedgePluginWasiNN PRIVATE ${piper_SOURCE_DIR}/src/cpp) + target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) endif() endforeach() diff --git a/plugins/wasi_nn/piper.cpp b/plugins/wasi_nn/piper.cpp new file mode 100644 index 00000000..fc0dce77 --- /dev/null +++ b/plugins/wasi_nn/piper.cpp @@ -0,0 +1,522 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "piper.h" +#include "wasinnenv.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER +#include "simdjson.h" +#include "types.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::Piper { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER + +template +std::tuple getOption(simdjson::dom::object &Object, + std::string_view Key, T &Result) { + if (auto Error = Object[Key].get(Result)) { + if (Error == simdjson::error_code::NO_SUCH_FIELD) { + return {WASINN::ErrNo::Success, false}; + } + spdlog::error( + "[WASI-NN] Piper backend: Unable to retrieve the \"{}\" option: {}"sv, + Key, simdjson::error_message(Error)); + return {WASINN::ErrNo::InvalidArgument, false}; + } + return {WASINN::ErrNo::Success, true}; +} + +template +WASINN::ErrNo getOptionalOption(simdjson::dom::object &Object, + std::string_view Key, + std::optional &Result) { + auto Value = U{}; + auto [Err, HasValue] = getOption(Object, Key, Value); + if (HasValue) { + Result = Value; + } + return Err; +} + +Expect parseRunConfig(RunConfig &RunConfig, + const std::string &String) noexcept { + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + if (auto Error = Parser.parse(String).get(Doc)) { + spdlog::error("[WASI-NN] Piper backend: Parse run config error: {}"sv, + simdjson::error_message(Error)); + return WASINN::ErrNo::InvalidEncoding; + } + simdjson::dom::object Object; + if (auto Error = Doc.get(Object)) { + spdlog::error( + "[WASI-NN] Piper backend: The run config is not an object: {}"sv, + simdjson::error_message(Error)); + return WASINN::ErrNo::InvalidArgument; + } + + auto ModelPath = std::optional{}; + if (auto Err = getOptionalOption(Object, "model", ModelPath); + Err != WASINN::ErrNo::Success) { + return Err; + } + // Verify model file exists + if (ModelPath) { + auto Path = std::filesystem::u8path(ModelPath.value()); + if (!std::filesystem::exists(Path)) { + spdlog::error("[WASI-NN] Piper backend: Model file doesn't exist"sv); + return WASINN::ErrNo::InvalidArgument; + } + RunConfig.ModelPath = Path; + } else { + spdlog::error( + "[WASI-NN] Piper backend: The model option is required but not provided"sv); + return WASINN::ErrNo::InvalidArgument; + } + + auto ModelConfigPath = std::optional{}; + if (auto Err = getOptionalOption(Object, "config", ModelConfigPath); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (ModelConfigPath) { + RunConfig.ModelConfigPath = + std::filesystem::u8path(ModelConfigPath.value()); + } else { + RunConfig.ModelConfigPath = RunConfig.ModelPath; + RunConfig.ModelConfigPath += ".json"; + } + // Verify model config exists + if (!std::filesystem::exists(RunConfig.ModelConfigPath)) { + spdlog::error("[WASI-NN] Piper backend: Model config doesn't exist"sv); + return WASINN::ErrNo::InvalidArgument; + } + + { + auto Value = std::optional{}; + if (auto Err = getOptionalOption(Object, "output_type", Value); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (Value) { + if (Value.value() == "wav") { + RunConfig.OutputType = RunConfigOutputType::OUTPUT_WAV; + } else if (Value.value() == "raw") { + RunConfig.OutputType = RunConfigOutputType::OUTPUT_RAW; + } else { + spdlog::error( + "[WASI-NN] Piper backend: The output_type option has an unknown value {}."sv, + Value.value()); + return WASINN::ErrNo::InvalidArgument; + } + } + } + if (auto Err = getOptionalOption(Object, "speaker", RunConfig.SpeakerId); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (auto Err = getOptionalOption(Object, "noise_scale", + RunConfig.NoiseScale); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (auto Err = getOptionalOption(Object, "length_scale", + RunConfig.LengthScale); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (auto Err = + getOptionalOption(Object, "noise_w", RunConfig.NoiseW); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (auto Err = getOptionalOption( + Object, "sentence_silence", RunConfig.SentenceSilenceSeconds); + Err != WASINN::ErrNo::Success) { + return Err; + } + { + auto PhonemeSilence = std::optional{}; + if (auto Err = getOptionalOption(Object, "phoneme_silence", PhonemeSilence); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (PhonemeSilence) { + for (auto [Key, Value] : PhonemeSilence.value()) { + auto PhonemeStr = std::string{Key}; + if (!piper::isSingleCodepoint(PhonemeStr)) { + spdlog::error( + "[WASI-NN] Piper backend: Phoneme '{}' is not a single codepoint (phoneme_silence)."sv, + PhonemeStr); + return WASINN::ErrNo::InvalidArgument; + } + auto Seconds = Value.get_double(); + if (auto Error = Seconds.error()) { + spdlog::error( + "[WASI-NN] Piper backend: Failed to get silence seconds for phoneme '{}' as a double: {}"sv, + PhonemeStr, simdjson::error_message(Error)); + return WASINN::ErrNo::InvalidArgument; + } + if (!RunConfig.PhonemeSilenceSeconds) { + RunConfig.PhonemeSilenceSeconds.emplace(); + } + auto Phoneme = piper::getCodepoint(PhonemeStr); + RunConfig.PhonemeSilenceSeconds.value()[Phoneme] = Seconds.value(); + } + } + } + { + auto Path = std::optional{}; + if (auto Err = getOptionalOption(Object, "espeak_data", Path); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (Path) { + RunConfig.ESpeakDataPath = std::filesystem::u8path(Path.value()); + } + } + { + auto Path = std::optional{}; + if (auto Err = getOptionalOption(Object, "tashkeel_model", Path); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (Path) { + RunConfig.TashkeelModelPath = std::filesystem::u8path(Path.value()); + } + } + if (auto Err = + std::get<0>(getOption(Object, "json_input", RunConfig.JsonInput)); + Err != WASINN::ErrNo::Success) { + return Err; + } + return WASINN::ErrNo::Success; +} + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, WASINN::Device, + uint32_t &GraphId) noexcept { + // The graph builder length must be 1. + if (Builders.size() != 1) { + spdlog::error( + "[WASI-NN] Piper backend: Wrong GraphBuilder Length {:d}, expect 1"sv, + Builders.size()); + return WASINN::ErrNo::InvalidArgument; + } + + // Add a new graph. + auto &GraphRef = Env.NNGraph.emplace_back(Backend::Piper).get(); + GraphRef.Config = std::make_unique(); + auto String = std::string{Builders[0].begin(), Builders[0].end()}; + if (auto Res = parseRunConfig(*GraphRef.Config, String); + Res != WASINN::ErrNo::Success) { + Env.NNGraph.pop_back(); + spdlog::error("[WASI-NN] Piper backend: Failed to parse run config."sv); + return Res; + } + + GraphRef.PiperConfig = std::make_unique(); + GraphRef.Voice = std::make_unique(); + piper::loadVoice(*GraphRef.PiperConfig, GraphRef.Config->ModelPath.string(), + GraphRef.Config->ModelConfigPath.string(), *GraphRef.Voice, + GraphRef.Config->SpeakerId); + GraphRef.SpeakerId = GraphRef.Config->SpeakerId; + + if (GraphRef.Voice->phonemizeConfig.phonemeType == + piper::PhonemeType::eSpeakPhonemes) { + if (!GraphRef.Config->ESpeakDataPath) { + spdlog::error( + "[WASI-NN] Piper backend: espeak-ng data directory is required for eSpeakPhonemes"sv); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::InvalidArgument; + } + if (!std::filesystem::exists(GraphRef.Config->ESpeakDataPath.value())) { + spdlog::error( + "[WASI-NN] Piper backend: espeak-ng data directory doesn't exist"sv); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::InvalidArgument; + } + // User provided path + GraphRef.PiperConfig->eSpeakDataPath = + GraphRef.Config->ESpeakDataPath->string(); + } else { + // Not using eSpeak + GraphRef.PiperConfig->useESpeak = false; + } + + // Enable libtashkeel for Arabic + if (GraphRef.Voice->phonemizeConfig.eSpeak.voice == "ar") { + if (!GraphRef.Config->TashkeelModelPath) { + spdlog::error( + "[WASI-NN] Piper backend: libtashkeel ort model is required for Arabic"sv); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::InvalidArgument; + } + if (!std::filesystem::exists(GraphRef.Config->TashkeelModelPath.value())) { + spdlog::error( + "[WASI-NN] Piper backend: libtashkeel ort model doesn't exist"sv); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::InvalidArgument; + } + GraphRef.PiperConfig->useTashkeel = true; + // User provided path + GraphRef.PiperConfig->tashkeelModelPath = + GraphRef.Config->TashkeelModelPath->string(); + } + + piper::initialize(*GraphRef.PiperConfig); + + // Scales + if (GraphRef.Config->NoiseScale) { + GraphRef.Voice->synthesisConfig.noiseScale = + GraphRef.Config->NoiseScale.value(); + } + + if (GraphRef.Config->LengthScale) { + GraphRef.Voice->synthesisConfig.lengthScale = + GraphRef.Config->LengthScale.value(); + } + + if (GraphRef.Config->NoiseW) { + GraphRef.Voice->synthesisConfig.noiseW = GraphRef.Config->NoiseW.value(); + } + + if (GraphRef.Config->SentenceSilenceSeconds) { + GraphRef.Voice->synthesisConfig.sentenceSilenceSeconds = + GraphRef.Config->SentenceSilenceSeconds.value(); + } + + if (GraphRef.Config->PhonemeSilenceSeconds) { + if (!GraphRef.Voice->synthesisConfig.phonemeSilenceSeconds) { + // Overwrite + GraphRef.Voice->synthesisConfig.phonemeSilenceSeconds = + GraphRef.Config->PhonemeSilenceSeconds; + } else { + // Merge + for (const auto &[Phoneme, SilenceSeconds] : + *GraphRef.Config->PhonemeSilenceSeconds) { + GraphRef.Voice->synthesisConfig.phonemeSilenceSeconds->try_emplace( + Phoneme, SilenceSeconds); + } + } + } // if phonemeSilenceSeconds + + // Store the loaded graph. + GraphId = Env.NNGraph.size() - 1; + return WASINN::ErrNo::Success; +} + +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept { + // Create context. + Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); + ContextId = Env.NNContext.size() - 1; + return WASINN::ErrNo::Success; +} + +template +WASINN::ErrNo getOptionalInputOption(simdjson::dom::object &Object, + std::string_view Key, + std::optional &Result) { + auto Value = T{}; + if (auto Error = Object[Key].get(Value)) { + if (Error == simdjson::error_code::NO_SUCH_FIELD) { + return WASINN::ErrNo::Success; + } + spdlog::error( + "[WASI-NN] Piper backend: Unable to retrieve \"{}\" from json input: {}"sv, + Key, simdjson::error_message(Error)); + return WASINN::ErrNo::InvalidArgument; + } + Result = Value; + return WASINN::ErrNo::Success; +} + +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept { + if (Index != 0) { + spdlog::error("[WASI-NN] Piper backend: Input index must be 0."sv); + return WASINN::ErrNo::InvalidArgument; + } + if (!(Tensor.Dimension.size() == 1 && Tensor.Dimension[0] == 1)) { + spdlog::error( + "[WASI-NN] Piper backend: Input tensor dimension must be [1]."sv); + return WASINN::ErrNo::InvalidArgument; + } + + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + auto Line = std::string{Tensor.Tensor.begin(), Tensor.Tensor.end()}; + + if (GraphRef.Config->JsonInput) { + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + if (auto Error = Parser.parse(Line).get(Doc)) { + spdlog::error("[WASI-NN] Piper backend: Parse json input error: {}"sv, + simdjson::error_message(Error)); + return WASINN::ErrNo::InvalidEncoding; + } + simdjson::dom::object Object; + if (auto Error = Doc.get(Object)) { + spdlog::error( + "[WASI-NN] Piper backend: The json input is not an object: {}"sv, + simdjson::error_message(Error)); + return WASINN::ErrNo::InvalidArgument; + } + + // Text is required + auto Text = std::string_view{}; + if (auto Error = Object["text"].get(Text)) { + spdlog::error( + "[WASI-NN] Piper backend: Unable to retrieve required \"text\" from json input: {}"sv, + simdjson::error_message(Error)); + return WASINN::ErrNo::InvalidArgument; + } + Line = Text; + + // Override speaker id + auto SpeakerId = std::optional{}; + if (auto Err = getOptionalInputOption(Object, "speaker_id", SpeakerId); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (SpeakerId) { + GraphRef.Voice->synthesisConfig.speakerId = SpeakerId; + } else { + auto SpeakerName = std::optional{}; + if (auto Err = getOptionalInputOption(Object, "speaker", SpeakerName); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (SpeakerName) { + // Resolve to id using speaker id map + auto Name = std::string{SpeakerName.value()}; + if (GraphRef.Voice->modelConfig.speakerIdMap && + GraphRef.Voice->modelConfig.speakerIdMap->count(Name) > 0) { + GraphRef.Voice->synthesisConfig.speakerId = + GraphRef.Voice->modelConfig.speakerIdMap.value()[Name]; + } else { + spdlog::warn("[WASI-NN] Piper backend: No speaker named: {}"sv, Name); + } + } + } + } + CxtRef.Line = Line; + return WASINN::ErrNo::Success; +} + +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept { + if (Index != 0) { + spdlog::error("[WASI-NN] Piper backend: Output index must be 0."sv); + return WASINN::ErrNo::InvalidArgument; + } + + auto &CxtRef = Env.NNContext[ContextId].get(); + + if (!CxtRef.Output) { + spdlog::error("[WASI-NN] Piper backend: No output available."sv); + return WASINN::ErrNo::InvalidArgument; + } + + if (CxtRef.Output->size() >= std::numeric_limits::max()) { + spdlog::error( + "[WASI-NN] Piper backend: Output size {} is greater than std::numeric_limits::max() {}."sv, + CxtRef.Output->size(), std::numeric_limits::max()); + return WASINN::ErrNo::InvalidArgument; + } + + if (CxtRef.Output->size() > OutBuffer.size_bytes()) { + spdlog::error( + "[WASI-NN] Piper backend: Output size {} is greater than buffer size {}."sv, + CxtRef.Output->size(), OutBuffer.size_bytes()); + return WASINN::ErrNo::InvalidArgument; + } + + std::memcpy(OutBuffer.data(), CxtRef.Output->data(), CxtRef.Output->size()); + BytesWritten = CxtRef.Output->size(); + return WASINN::ErrNo::Success; +} + +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + if (!CxtRef.Line) { + spdlog::error("[WASI-NN] Piper backend: Input is not set."sv); + return WASINN::ErrNo::InvalidArgument; + } + + auto Result = piper::SynthesisResult{}; + if (GraphRef.Config->OutputType == RunConfigOutputType::OUTPUT_WAV) { + auto AudioFile = + std::stringstream{std::ios::binary | std::ios::in | std::ios::out}; + piper::textToWavFile(*GraphRef.PiperConfig, *GraphRef.Voice, + CxtRef.Line.value(), AudioFile, Result); + auto String = AudioFile.str(); + CxtRef.Output = std::vector{String.begin(), String.end()}; + } else if (GraphRef.Config->OutputType == RunConfigOutputType::OUTPUT_RAW) { + auto AudioBuffer = std::vector{}; + piper::textToAudio(*GraphRef.PiperConfig, *GraphRef.Voice, + CxtRef.Line.value(), AudioBuffer, Result, nullptr); + CxtRef.Output = std::vector( + sizeof(decltype(AudioBuffer)::value_type) * AudioBuffer.size()); + std::memcpy(CxtRef.Output->data(), AudioBuffer.data(), + CxtRef.Output->size()); + } + + // Restore config (json_input) + GraphRef.Voice->synthesisConfig.speakerId = GraphRef.SpeakerId; + return WASINN::ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] Piper backend is not supported."sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif +} // namespace WasmEdge::Host::WASINN::Piper diff --git a/plugins/wasi_nn/piper.h b/plugins/wasi_nn/piper.h new file mode 100644 index 00000000..70b4b19b --- /dev/null +++ b/plugins/wasi_nn/piper.h @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "types.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER +#include +#include +#include +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::Piper { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER +enum class RunConfigOutputType { OUTPUT_WAV, OUTPUT_RAW }; +struct RunConfig { + // Path to .onnx voice file + std::filesystem::path ModelPath; + + // Path to JSON voice config file + std::filesystem::path ModelConfigPath; + + // Type of output to produce. + // Default is a WAV file. + RunConfigOutputType OutputType = RunConfigOutputType::OUTPUT_WAV; + + // Numerical id of the default speaker (multi-speaker voices) + std::optional SpeakerId; + + // Amount of noise to add during audio generation + std::optional NoiseScale; + + // Speed of speaking (1 = normal, < 1 is faster, > 1 is slower) + std::optional LengthScale; + + // Variation in phoneme lengths + std::optional NoiseW; + + // Seconds of silence to add after each sentence + std::optional SentenceSilenceSeconds; + + // Path to espeak-ng data directory + std::optional ESpeakDataPath; + + // Path to libtashkeel ort model + // https://github.com/mush42/libtashkeel/ + std::optional TashkeelModelPath; + + // input is JSON instead of text with format: + // { + // "text": str, (required) + // "speaker_id": int, (optional) + // "speaker": str, (optional) + // } + bool JsonInput = false; + + // Seconds of extra silence to insert after a single phoneme + std::optional> PhonemeSilenceSeconds; +}; +struct Graph { + std::unique_ptr Config; + std::unique_ptr PiperConfig; + std::unique_ptr Voice; + std::optional SpeakerId; +}; +struct Context { + Context(size_t GId, Graph &) noexcept : GraphId(GId) {} + size_t GraphId; + std::optional Line; + std::optional> Output; +}; +#else +struct Graph {}; +struct Context { + Context(size_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::Piper diff --git a/plugins/wasi_nn/piper.patch b/plugins/wasi_nn/piper.patch new file mode 100644 index 00000000..56a42028 --- /dev/null +++ b/plugins/wasi_nn/piper.patch @@ -0,0 +1,91 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index f96ec44..a759c35 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -13,11 +13,13 @@ if(MSVC) + add_compile_options("$<$:/utf-8>") + elseif(NOT APPLE) + # Linux flags +- string(APPEND CMAKE_CXX_FLAGS " -Wall -Wextra -Wl,-rpath,'$ORIGIN'") ++ string(APPEND CMAKE_CXX_FLAGS " -Wall -Wextra") ++ list(APPEND CMAKE_BUILD_RPATH "$ORIGIN") ++ list(APPEND CMAKE_INSTALL_RPATH "$ORIGIN") + string(APPEND CMAKE_C_FLAGS " -Wall -Wextra") + endif() + +-add_executable(piper src/cpp/main.cpp src/cpp/piper.cpp) ++add_library(piper src/cpp/piper.cpp) + add_executable(test_piper src/cpp/test.cpp src/cpp/piper.cpp) + + # NOTE: external project prefix are shortened because of path length restrictions on Windows +@@ -60,10 +62,14 @@ endif() + + if(NOT DEFINED PIPER_PHONEMIZE_DIR) + set(PIPER_PHONEMIZE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pi") ++ find_program(GIT_CMD git REQUIRED) + ExternalProject_Add( + piper_phonemize_external + PREFIX "${CMAKE_CURRENT_BINARY_DIR}/p" +- URL "https://github.com/rhasspy/piper-phonemize/archive/refs/heads/master.zip" ++ GIT_REPOSITORY "https://github.com/rhasspy/piper-phonemize.git" ++ GIT_TAG "bfc2e7549957829b0227c66a305d11cc88167bda" # master ++ UPDATE_DISCONNECTED TRUE ++ PATCH_COMMAND "${GIT_CMD}" "apply" "${CMAKE_CURRENT_SOURCE_DIR}/piper-phonemize.patch" + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${PIPER_PHONEMIZE_DIR} + ) + add_dependencies(piper piper_phonemize_external) +@@ -74,7 +80,9 @@ endif() + + if((NOT MSVC) AND (NOT APPLE)) + # Linux flags +- string(APPEND CMAKE_CXX_FLAGS " -Wall -Wextra -Wl,-rpath,'$ORIGIN'") ++ string(APPEND CMAKE_CXX_FLAGS " -Wall -Wextra") ++ list(APPEND CMAKE_BUILD_RPATH "$ORIGIN") ++ list(APPEND CMAKE_INSTALL_RPATH "$ORIGIN") + string(APPEND CMAKE_C_FLAGS " -Wall -Wextra") + target_link_libraries(piper -static-libgcc -static-libstdc++) + +@@ -104,14 +112,6 @@ target_include_directories(piper PUBLIC + + target_compile_definitions(piper PUBLIC _PIPER_VERSION=${piper_version}) + +-# ---- Declare test ---- +-include(CTest) +-enable_testing() +-add_test( +- NAME test_piper +- COMMAND test_piper "${CMAKE_SOURCE_DIR}/etc/test_voice.onnx" "${PIPER_PHONEMIZE_DIR}/share/espeak-ng-data" "${CMAKE_CURRENT_BINARY_DIR}/test.wav" +-) +- + target_compile_features(test_piper PUBLIC cxx_std_17) + + target_include_directories( +diff --git a/VERSION b/VERSION +index 26aaba0..867e524 100644 +--- a/VERSION ++++ b/VERSION +@@ -1 +1 @@ +-1.2.0 ++1.2.0 +\ No newline at end of file +diff --git a/piper-phonemize.patch b/piper-phonemize.patch +new file mode 100644 +index 0000000..f8ca06f +--- /dev/null ++++ b/piper-phonemize.patch +@@ -0,0 +1,15 @@ ++diff --git a/CMakeLists.txt b/CMakeLists.txt ++index ec7b501..34cf7b1 100644 ++--- a/CMakeLists.txt +++++ b/CMakeLists.txt ++@@ -17,7 +17,9 @@ if(MSVC) ++ ++ elseif(NOT APPLE) ++ # Linux flags ++- string(APPEND CMAKE_CXX_FLAGS " -Wall -Wextra -Wl,-rpath,'$ORIGIN'") +++ string(APPEND CMAKE_CXX_FLAGS " -Wall -Wextra") +++ list(APPEND CMAKE_BUILD_RPATH "$ORIGIN") +++ list(APPEND CMAKE_INSTALL_RPATH "$ORIGIN") ++ string(APPEND CMAKE_C_FLAGS " -Wall -Wextra") ++ endif() ++ diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index d1dd7fe7..69c60f26 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -37,11 +37,13 @@ enum class Backend : uint8_t { Autodetect = 5, GGML = 6, NeuralSpeed = 7, + Piper = 11, }; #define FOR_EACH_BACKEND(F) \ F(OpenVINO) \ - F(ONNX) F(Tensorflow) F(PyTorch) F(TensorflowLite) F(GGML) F(NeuralSpeed) + F(ONNX) \ + F(Tensorflow) F(PyTorch) F(TensorflowLite) F(GGML) F(NeuralSpeed) F(Piper) struct TensorData { Span Dimension; diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 47febedf..4db5185e 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: 2019-2022 Second State INC #include "wasinnenv.h" +#include "types.h" #include "wasinnmodule.h" #include @@ -29,7 +30,8 @@ std::map BackendMap = { {"tensorflowlite"sv, Backend::TensorflowLite}, {"autodetect"sv, Backend::Autodetect}, {"ggml"sv, Backend::GGML}, - {"neuralspeed"sv, Backend::NeuralSpeed}}; + {"neuralspeed"sv, Backend::NeuralSpeed}, + {"piper"sv, Backend::Piper}}; std::map DeviceMap = {{"cpu"sv, Device::CPU}, {"gpu"sv, Device::GPU}, diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index f2b1adb8..1367b353 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -13,6 +13,7 @@ #include "neuralspeed.h" #include "onnx.h" #include "openvino.h" +#include "piper.h" #include "tf.h" #include "tfl.h" #include "torch.h" diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 10d45ba0..040fae7b 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -77,6 +77,28 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures/ne_phi_q_nf4_bestla_cfp32_g32.bin MD5=5e055b41f8cc1a42f26ff8742719ef1e ) + elseif(BACKEND STREQUAL "piper") + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures") + download( + https://github.com/rhasspy/piper/raw/master/etc/test_voice.onnx + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures/test_voice.onnx + SHA256=937682595755bbb3ee9f131b8a4b2b1ba2fac9b26431fcd7aa48cff0f7382838 + ) + download( + https://github.com/rhasspy/piper/raw/master/etc/test_voice.onnx.json + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures/test_voice.onnx.json + SHA256=f3e0b906861cc2fb8a50e12ceca263afe226ff9688f60e9d4ef943d4f047a513 + ) + download( + https://github.com/rhasspy/piper/releases/download/2023.11.14-2/piper_linux_x86_64.tar.gz + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures/piper_linux_x86_64.tar.gz + SHA256=a50cb45f355b7af1f6d758c1b360717877ba0a398cc8cbe6d2a7a3a26e225992 + ) + file(ARCHIVE_EXTRACT + INPUT ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures/piper_linux_x86_64.tar.gz + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures + PATTERNS piper/espeak-ng-data + ) else() # Add the other backend test files fetching here. endif() diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 0df306d3..97de3d54 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -20,7 +20,8 @@ using WasmEdge::Host::WASINN::ErrNo; defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML) || \ - defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED) + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER) namespace { WasmEdge::Runtime::Instance::ModuleInstance * createModule(std::string_view NNRPCURI = "") { @@ -2008,4 +2009,253 @@ TEST(WasiNNTest, NeuralSpeedBackend) { EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); } } -#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED \ No newline at end of file +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER +TEST(WasiNNTest, PiperBackend) { + // Create the wasmedge_process module instance. + auto *NNMod = dynamic_cast(createModule()); + ASSERT_TRUE(NNMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(400))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + (void)readEntireFile; + std::string Text = "This is a test."; + std::vector TensorData(Text.begin(), Text.end()); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(410 * 65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // Piper WASI-NN load tests. + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::Piper), + UINT32_C(0), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), + static_cast(Backend::Piper), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- Piper config ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, 1, BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::Piper), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- wrong config encoding. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, 0, BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::Piper), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidEncoding)); + } + + // Test: load -- load successfully. + std::string Config = + "{\"model\": \"./wasinn_piper_fixtures/test_voice.onnx\", " + "\"espeak_data\": \"./wasinn_piper_fixtures/piper/espeak-ng-data\"}"; + std::vector ConfigData(Config.begin(), Config.end()); + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, ConfigData.size(), BuilderPtr); + writeBinaries(MemInst, ConfigData, StorePtr); + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::Piper), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Piper WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: init_execution_context -- init context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Piper WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, 2, BuilderPtr); + writeFatPointer(MemInst, + StorePtr + TensorDim.size() * + sizeof(decltype(TensorDim)::value_type), + TensorData.size(), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries( + MemInst, TensorData, + StorePtr + TensorDim.size() * sizeof(decltype(TensorDim)::value_type)); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += TensorDim.size() * sizeof(decltype(TensorDim)::value_type) + + TensorData.size(); + + // Piper WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Piper WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 10000 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 10000); + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER \ No newline at end of file From 6d0472d8cec0d1a216d23c2bb55119ec26750e08 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Tue, 9 Jul 2024 17:38:33 +0800 Subject: [PATCH 362/623] [Misc] Fix missing .hcl file extensions Signed-off-by: Yi Huang --- ...ocker-bake.ci-image-base => docker-bake.ci-image-base.hcl} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename utils/docker/{docker-bake.ci-image-base => docker-bake.ci-image-base.hcl} (83%) diff --git a/utils/docker/docker-bake.ci-image-base b/utils/docker/docker-bake.ci-image-base.hcl similarity index 83% rename from utils/docker/docker-bake.ci-image-base rename to utils/docker/docker-bake.ci-image-base.hcl index 0054032c..b98e75b6 100644 --- a/utils/docker/docker-bake.ci-image-base +++ b/utils/docker/docker-bake.ci-image-base.hcl @@ -6,8 +6,8 @@ group "default" { } target "base" { - dockerfile = "./utils/docker/Dockerfile.ci-image-base" - context = "." + dockerfile = "Dockerfile.ci-image-base" + context = "./utils/docker" } target "x86_64" { From 7ed9c893d9fff4f6c4305f0e28a9ea07b84fef2a Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Thu, 11 Jul 2024 16:55:02 +0800 Subject: [PATCH 363/623] [Docker] Resolve warnings on ENV definition Signed-off-by: Yi Huang --- ...ockerfile.manylinux_2_28-build-plugins-deps | 18 +++++++++--------- utils/docker/Dockerfile.manylinux_2_28_aarch64 | 10 +++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps index e178ab91..46685449 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps @@ -7,25 +7,25 @@ RUN cd && (yum check-update || true) && \ WORKDIR /root COPY opencvmini/install-opencvmini.sh . -ENV OPENCV_VERSION "4.8.0" +ENV OPENCV_VERSION="4.8.0" RUN [ "/bin/bash", "install-opencvmini.sh" ] COPY wasi-nn/install-pytorch.sh . -ENV PYTORCH_VERSION "1.8.2" -ENV PYTORCH_INSTALL_TO "/root" -ENV Torch_DIR "/root/libtorch" +ENV PYTORCH_VERSION="1.8.2" +ENV PYTORCH_INSTALL_TO="/root" +ENV Torch_DIR="/root/libtorch" RUN [ "/bin/bash", "install-pytorch.sh" ] COPY wasi-crypto/build-openssl.sh . -ENV OpenSSL_DIR "/root/openssl-1.1.1n/openssl" +ENV OpenSSL_DIR="/root/openssl-1.1.1n/openssl" RUN [ "/bin/bash", "build-openssl.sh" ] COPY ffmpeg/install-ffmpeg-v6.0.sh . RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] -ENV PKG_CONFIG_PATH /root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -ENV LD_LIBRARY_PATH /root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} +ENV PKG_CONFIG_PATH=/root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV LD_LIBRARY_PATH=/root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} -ENV OPENVINO_VERSION "2024.2.0" -ENV OPENVINO_YEAR "2024" +ENV OPENVINO_VERSION="2024.2.0" +ENV OPENVINO_YEAR="2024" RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux_2_28_aarch64 b/utils/docker/Dockerfile.manylinux_2_28_aarch64 index 789c5afc..f0621e63 100644 --- a/utils/docker/Dockerfile.manylinux_2_28_aarch64 +++ b/utils/docker/Dockerfile.manylinux_2_28_aarch64 @@ -8,11 +8,11 @@ MAINTAINER hydai hydai@secondstate.io ADD SHA256SUM.manylinux_2_28 /root/ # See /opt/rh/gcc-toolset-13/enable -ENV PATH /opt/rh/gcc-toolset-13/root/usr/bin${PATH:+:${PATH}} -ENV MANPATH /opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} -ENV INFOPATH /opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} -ENV LD_LIBRARY_PATH /opt/rh/gcc-toolset-13/root/usr/lib64:/opt/rh/gcc-toolset-13/root/usr/lib:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} -ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV PATH=/opt/rh/gcc-toolset-13/root/usr/bin${PATH:+:${PATH}} +ENV MANPATH=/opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} +ENV INFOPATH=/opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} +ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-13/root/usr/lib64:/opt/rh/gcc-toolset-13/root/usr/lib:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} +ENV PKG_CONFIG_PATH=/opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build cmake yum-utils && \ yum-config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo && \ From 321214308472a239df530cd78b12791591c6a061 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Tue, 9 Jul 2024 17:58:55 +0800 Subject: [PATCH 364/623] [Docker] Refactor manylinux_2_28 with docker bake Signed-off-by: Yi Huang --- ...aarch64 => Dockerfile.manylinux_2_28-base} | 9 ++- ...=> Dockerfile.manylinux_2_28-plugins-deps} | 26 +++++--- utils/docker/Dockerfile.manylinux_2_28_x86_64 | 59 ------------------- utils/docker/docker-bake.manylinux.hcl | 55 +++++++++++++++++ 4 files changed, 79 insertions(+), 70 deletions(-) rename utils/docker/{Dockerfile.manylinux_2_28_aarch64 => Dockerfile.manylinux_2_28-base} (94%) rename utils/docker/{Dockerfile.manylinux_2_28-build-plugins-deps => Dockerfile.manylinux_2_28-plugins-deps} (75%) delete mode 100644 utils/docker/Dockerfile.manylinux_2_28_x86_64 create mode 100644 utils/docker/docker-bake.manylinux.hcl diff --git a/utils/docker/Dockerfile.manylinux_2_28_aarch64 b/utils/docker/Dockerfile.manylinux_2_28-base similarity index 94% rename from utils/docker/Dockerfile.manylinux_2_28_aarch64 rename to utils/docker/Dockerfile.manylinux_2_28-base index f0621e63..ae2ea372 100644 --- a/utils/docker/Dockerfile.manylinux_2_28_aarch64 +++ b/utils/docker/Dockerfile.manylinux_2_28-base @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2022 Second State INC -FROM quay.io/pypa/manylinux_2_28_aarch64 +ARG BASE_IMAGE +FROM ${BASE_IMAGE} MAINTAINER hydai hydai@secondstate.io @@ -14,6 +15,8 @@ ENV INFOPATH=/opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-13/root/usr/lib64:/opt/rh/gcc-toolset-13/root/usr/lib:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} ENV PKG_CONFIG_PATH=/opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ARG LLVM_TARGETS LLVM_TRIPLE + RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build cmake yum-utils && \ yum-config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo && \ yum install -y gcc-toolset-13 && \ @@ -50,8 +53,8 @@ RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build c cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/opt/rh/gcc-toolset-13/root/usr \ -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ - -DLLVM_TARGETS_TO_BUILD="AArch64;BPF" -DLLVM_ENABLE_PROJECTS="lld;clang" \ - -DLLVM_DEFAULT_TARGET_TRIPLE="aarch64-redhat-linux-gnu" \ + -DLLVM_TARGETS_TO_BUILD="${LLVM_TARGETS}" -DLLVM_ENABLE_PROJECTS="lld;clang" \ + -DLLVM_DEFAULT_TARGET_TRIPLE="${LLVM_TRIPLE}" \ -DBUILD_SHARED_LIBS=OFF llvm && \ cmake --build build --target install && \ rm -rf build && rm -rf * diff --git a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps similarity index 75% rename from utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps rename to utils/docker/Dockerfile.manylinux_2_28-plugins-deps index 46685449..066d2d95 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps @@ -1,14 +1,12 @@ -ARG BASE=wasmedge/wasmedge:manylinux_2_28_x86_64 -FROM ${BASE} - -RUN cd && (yum check-update || true) && \ - yum install -y wget unzip zlib-devel zlib-static elfutils-libelf-devel +ARG BASE_IMAGE +FROM ${BASE_IMAGE} as base WORKDIR /root -COPY opencvmini/install-opencvmini.sh . -ENV OPENCV_VERSION="4.8.0" -RUN [ "/bin/bash", "install-opencvmini.sh" ] +### deps for x86_64 ### +FROM base AS deps-amd64 +RUN cd && (yum check-update || true) && \ + yum install -y wget unzip zlib-devel zlib-static elfutils-libelf-devel COPY wasi-nn/install-pytorch.sh . ENV PYTORCH_VERSION="1.8.2" @@ -16,6 +14,18 @@ ENV PYTORCH_INSTALL_TO="/root" ENV Torch_DIR="/root/libtorch" RUN [ "/bin/bash", "install-pytorch.sh" ] +### deps for aarch64 ### +FROM base AS deps-arm64 +RUN cd && (yum check-update || true) && \ + yum install -y wget unzip zlib-devel zlib-static + +### deps for all ### +FROM deps-${TARGETARCH} as final + +COPY opencvmini/install-opencvmini.sh . +ENV OPENCV_VERSION="4.8.0" +RUN [ "/bin/bash", "install-opencvmini.sh" ] + COPY wasi-crypto/build-openssl.sh . ENV OpenSSL_DIR="/root/openssl-1.1.1n/openssl" RUN [ "/bin/bash", "build-openssl.sh" ] diff --git a/utils/docker/Dockerfile.manylinux_2_28_x86_64 b/utils/docker/Dockerfile.manylinux_2_28_x86_64 deleted file mode 100644 index 832a6c3a..00000000 --- a/utils/docker/Dockerfile.manylinux_2_28_x86_64 +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC - -FROM quay.io/pypa/manylinux_2_28_x86_64 - -MAINTAINER hydai hydai@secondstate.io - -ADD SHA256SUM.manylinux_2_28 /root/ - -# See /opt/rh/gcc-toolset-13/enable -ENV PATH /opt/rh/gcc-toolset-13/root/usr/bin${PATH:+:${PATH}} -ENV MANPATH /opt/rh/gcc-toolset-13/root/usr/share/man${MANPATH:+:${MANPATH}} -ENV INFOPATH /opt/rh/gcc-toolset-13/root/usr/share/info${INFOPATH:+:${INFOPATH}} -ENV LD_LIBRARY_PATH /opt/rh/gcc-toolset-13/root/usr/lib64:/opt/rh/gcc-toolset-13/root/usr/lib:${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} -ENV PKG_CONFIG_PATH /opt/rh/gcc-toolset-13/root/usr/lib64/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} - -RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build cmake yum-utils && \ - yum-config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo && \ - yum install -y gcc-toolset-13 && \ - export CPU=$(/opt/python/cp311-cp311/bin/python3 -c \ - 'import multiprocessing; print(multiprocessing.cpu_count())') && \ - export CFGFLAGS="--prefix=/opt/rh/gcc-toolset-13/root/usr --disable-shared --libdir=/opt/rh/gcc-toolset-13/root/usr/lib64" && \ - curl -s -L -O --remote-name-all \ - https://github.com/ninja-build/ninja/archive/refs/tags/v1.11.1.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/llvm-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/lld-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/libunwind-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/cmake-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/third-party-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/clang-17.0.6.src.tar.xz && \ - sha256sum -c SHA256SUM.manylinux_2_28 && \ - gzip -dc v1.11.1.tar.gz | tar -xf - && \ - xz -dc llvm-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc lld-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc libunwind-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc cmake-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc third-party-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc clang-17.0.6.src.tar.xz | tar -xf - && \ - export ZSTDFLAGS=(PREFIX=/opt/rh/gcc-toolset-13/root/usr LIBDIR=/opt/rh/gcc-toolset-13/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ - mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ - ../ninja-1.11.1/configure.py --bootstrap \ - --with-python=/opt/python/cp311-cp311/bin/python && \ - cp -v ninja /opt/rh/gcc-toolset-13/root/usr/bin/ninja && cd - && rm -rf build && \ - mv -v llvm-17.0.6.src llvm && \ - mv -v lld-17.0.6.src lld && \ - mv -v libunwind-17.0.6.src libunwind && \ - mv -v cmake-17.0.6.src cmake && \ - mv -v third-party-17.0.6.src third-party && \ - mv -v clang-17.0.6.src clang && \ - cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_INSTALL_PREFIX=/opt/rh/gcc-toolset-13/root/usr \ - -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ - -DLLVM_TARGETS_TO_BUILD="X86;BPF" -DLLVM_ENABLE_PROJECTS="lld;clang" \ - -DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-pc-linux-gnu" \ - -DBUILD_SHARED_LIBS=OFF llvm && \ - cmake --build build --target install && \ - rm -rf build && rm -rf * - -RUN yum clean all diff --git a/utils/docker/docker-bake.manylinux.hcl b/utils/docker/docker-bake.manylinux.hcl new file mode 100644 index 00000000..306081cb --- /dev/null +++ b/utils/docker/docker-bake.manylinux.hcl @@ -0,0 +1,55 @@ +target "base" { + dockerfile = "Dockerfile.manylinux_2_28-base" + context = "./utils/docker" +} + +target "plugins-base" { + dockerfile = "./docker/Dockerfile.manylinux_2_28-plugins-deps" + context = "./utils" +} + +target "x86_64" { + inherits = ["base"] + platforms = ["linux/amd64"] + tags = ["wasmedge/wasmedge:manylinux_2_28_x86_64"] + args = { + BASE_IMAGE = "quay.io/pypa/manylinux_2_28_x86_64", + LLVM_TARGETS = "X86;BPF", + LLVM_TRIPLE = "x86_64-pc-linux-gnu" + } +} + +target "x86_64-plugins" { + inherits = ["plugins-base"] + platforms = ["linux/amd64"] + tags = ["wasmedge/wasmedge:manylinux_2_28_x86_64-plugins-deps"] + contexts = { + "wasmedge/wasmedge:manylinux_2_28_x86_64"= "target:x86_64" + } + args = { + BASE_IMAGE = "wasmedge/wasmedge:manylinux_2_28_x86_64" + } +} + +target "aarch64" { + inherits = ["base"] + platforms = ["linux/arm64"] + tags = ["wasmedge/wasmedge:manylinux_2_28_aarch64"] + args = { + BASE_IMAGE = "quay.io/pypa/manylinux_2_28_aarch64", + LLVM_TARGETS = "AArch64;BPF", + LLVM_TRIPLE = "aarch64-redhat-linux-gnu" + } +} + +target "aarch64-plugins" { + inherits = ["plugins-base"] + platforms = ["linux/arm64"] + tags = ["wasmedge/wasmedge:manylinux_2_28_aarch64-plugins-deps"] + contexts = { + "wasmedge/wasmedge:manylinux_2_28_aarch64" = "target:aarch64" + } + args = { + BASE_IMAGE = "wasmedge/wasmedge:manylinux_2_28_aarch64" + } +} From b573f51df42b2a4bbe782100b5432e7297202664 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 17 Jul 2024 14:55:52 +0800 Subject: [PATCH 365/623] [WASI-NN] ggml: bump llama.cpp b3405 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index e5e350db..6a047558 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -65,7 +65,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3358 + GIT_TAG b3405 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 92e543d2a3f59b2b5ec94817f3860091efd386f6 Mon Sep 17 00:00:00 2001 From: grorge Date: Wed, 17 Jul 2024 15:43:44 +0800 Subject: [PATCH 366/623] [Plugin] Stable Diffusion: fix output path Signed-off-by: grorge --- .../wasmedge_stablediffusion/wasmedge_stablediffusion.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index b2533c44..57dc372c 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -218,7 +218,7 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); EXPECT_GE(BytesWritten, 50); - std::ifstream Fin(OutputPath.data(), std::ios::in | std::ios::binary); + std::ifstream Fin(OutputPathString, std::ios::in | std::ios::binary); EXPECT_FALSE(Fin.fail()); Fin.close(); } @@ -307,7 +307,7 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); EXPECT_GE(BytesWritten, 50); - std::ifstream Fin(OutputPath2.data(), std::ios::in | std::ios::binary); + std::ifstream Fin(OutputPathString2, std::ios::in | std::ios::binary); EXPECT_FALSE(Fin.fail()); Fin.close(); } From b06a184f1607ec4baebcd8da159105eb0289cbc9 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 17 Jun 2024 15:54:29 +0800 Subject: [PATCH 367/623] [WASI-NN] Basic support for whisper backend. Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 16 ++ plugins/wasi_nn/types.h | 9 +- plugins/wasi_nn/wasinnenv.cpp | 1 + plugins/wasi_nn/wasinnenv.h | 1 + plugins/wasi_nn/whispercpp.cpp | 353 +++++++++++++++++++++++++++++++++ plugins/wasi_nn/whispercpp.h | 74 +++++++ 6 files changed, 453 insertions(+), 1 deletion(-) create mode 100644 plugins/wasi_nn/whispercpp.cpp create mode 100644 plugins/wasi_nn/whispercpp.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 6a047558..eb14e3cc 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -14,6 +14,7 @@ wasmedge_add_library(wasmedgePluginWasiNN ggml.cpp neuralspeed.cpp piper.cpp + whispercpp.cpp ) foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) @@ -159,6 +160,21 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) wasmedge_setup_piper() target_include_directories(wasmedgePluginWasiNN PRIVATE ${piper_SOURCE_DIR}/src/cpp) target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) + elseif(BACKEND STREQUAL "whisper") + set(WHISPER_NO_ACCELERATE ON CACHE INTERNAL "Whisper turn off accelerate") + set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "Whisper not build shared") + include(FetchContent) + FetchContent_Declare( + whisper + GIT_REPOSITORY https://github.com/ggerganov/whisper.cpp.git + GIT_TAG v1.6.2 + GIT_SHALLOW FALSE + ) + FetchContent_MakeAvailable(whisper) + set_property(TARGET whisper PROPERTY POSITION_INDEPENDENT_CODE ON) + target_link_libraries(wasmedgePluginWasiNN PRIVATE + whisper + ) endif() endforeach() diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index 69c60f26..271af483 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -37,13 +37,20 @@ enum class Backend : uint8_t { Autodetect = 5, GGML = 6, NeuralSpeed = 7, + Whisper = 9, Piper = 11, }; #define FOR_EACH_BACKEND(F) \ F(OpenVINO) \ F(ONNX) \ - F(Tensorflow) F(PyTorch) F(TensorflowLite) F(GGML) F(NeuralSpeed) F(Piper) + F(Tensorflow) \ + F(PyTorch) \ + F(TensorflowLite) \ + F(GGML) \ + F(NeuralSpeed) \ + F(Whisper) \ + F(Piper) struct TensorData { Span Dimension; diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 4db5185e..89994c5f 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -31,6 +31,7 @@ std::map BackendMap = { {"autodetect"sv, Backend::Autodetect}, {"ggml"sv, Backend::GGML}, {"neuralspeed"sv, Backend::NeuralSpeed}, + {"whisper"sv, Backend::Whisper}, {"piper"sv, Backend::Piper}}; std::map DeviceMap = {{"cpu"sv, Device::CPU}, diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 1367b353..7e38cb80 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -18,6 +18,7 @@ #include "tfl.h" #include "torch.h" #include "types.h" +#include "whispercpp.h" #ifdef WASMEDGE_BUILD_WASI_NN_RPC #include diff --git a/plugins/wasi_nn/whispercpp.cpp b/plugins/wasi_nn/whispercpp.cpp new file mode 100644 index 00000000..21632e30 --- /dev/null +++ b/plugins/wasi_nn/whispercpp.cpp @@ -0,0 +1,353 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "whispercpp.h" +#include "wasinnenv.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER +#define DR_WAV_IMPLEMENTATION +#include + +#include +#endif + +namespace WasmEdge::Host::WASINN::Whisper { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER + +namespace { + +bool checkAudioRIFF(const std::string_view Buf, const std::string_view Format) { + if (Buf.size() < 12 || Buf.substr(0, 4) != "RIFF"sv) { + return false; + } + if (Buf.substr(8, 4) != Format) { + return false; + } + uint32_t ChunkSize = *reinterpret_cast(Buf.data() + 4); + if (ChunkSize + 8 != Buf.size()) { + return false; + } + return true; +} + +bool loadWAV(Span Buf, std::vector &PCMF32) { + // Not to use the helper function in examples of whisper.cpp to prevent from + // copy. + drwav WAV; + const uint32_t ConstSampleRate = 16000; + + if (!drwav_init_memory(&WAV, Buf.data(), Buf.size(), nullptr)) { + spdlog::error("[WASI-NN] Whisper backend: load WAV failed."sv); + return false; + } + + if (WAV.channels != 1 && WAV.channels != 2) { + spdlog::error("[WASI-NN] Whisper backend: WAV must be mono or stereo."sv); + drwav_uninit(&WAV); + return false; + } + + if (WAV.sampleRate != ConstSampleRate) { + spdlog::error("[WASI-NN] Whisper backend: WAV must be {} kHz."sv, + ConstSampleRate / 1000); + drwav_uninit(&WAV); + return false; + } + + if (WAV.bitsPerSample != 16) { + spdlog::error("[WASI-NN] Whisper backend: WAV must be 16-bit."sv); + drwav_uninit(&WAV); + return false; + } + + const uint32_t N = WAV.totalPCMFrameCount; + std::vector PCM16(N * WAV.channels); + drwav_read_pcm_frames_s16(&WAV, N, PCM16.data()); + drwav_uninit(&WAV); + + PCMF32.resize(N); + if (WAV.channels == 1) { + for (uint64_t I = 0; I < N; I++) { + PCMF32[I] = static_cast(PCM16[I]) / 32768.0f; + } + } else { + for (uint64_t I = 0; I < N; I++) { + PCMF32[I] = + static_cast(PCM16[2 * I] + PCM16[2 * I + 1]) / 65536.0f; + } + } + return true; +} + +void WhisperLogCallback(ggml_log_level LogLevel, const char *LogText, + void *UserData) { + const Graph &GraphRef = *reinterpret_cast(UserData); + if (!GraphRef.EnableLog) { + return; + } + std::string Text(LogText); + // Remove the trailing newlines. + Text = Text.erase(Text.find_last_not_of("\n") + 1); + // Skip for "." + if (Text == ".") { + return; + } + if (LogLevel == GGML_LOG_LEVEL_ERROR) { + spdlog::error("[WASI-NN] whisper.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_WARN) { + spdlog::warn("[WASI-NN] whisper.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_INFO) { + spdlog::info("[WASI-NN] whisper.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_DEBUG) { + spdlog::debug("[WASI-NN] whisper.cpp: {}"sv, Text); + } +} + +void WhisperOutputSegmentCallback(struct whisper_context *WhisperCtx, + struct whisper_state * /* state */, int NewN, + void *UserData) { + auto &CxtRef = *reinterpret_cast(UserData); + const int SegN = whisper_full_n_segments(WhisperCtx); + + auto ToTimeStr = [](int64_t T) -> std::string { + T *= 10; + uint32_t HR = static_cast(T / (1000 * 60 * 60)); + T %= 1000 * 60 * 60; + uint32_t M = static_cast(T / (1000 * 60)); + T %= 1000 * 60; + uint32_t S = static_cast(T / 1000); + uint32_t MS = static_cast(T % 1000); + char Buf[32]; + snprintf(Buf, sizeof(Buf), "%02d:%02d:%02d.%03d", HR, M, S, MS); + return std::string(Buf); + }; + + // Output the last new N segments. + for (int I = SegN - NewN; I < SegN; I++) { + int64_t T0 = whisper_full_get_segment_t0(WhisperCtx, I); + int64_t T1 = whisper_full_get_segment_t1(WhisperCtx, I); + // TODO: Add the print timestamp config. + CxtRef.Outputs += "["; + CxtRef.Outputs += ToTimeStr(T0); + CxtRef.Outputs += " --> "; + CxtRef.Outputs += ToTimeStr(T1); + CxtRef.Outputs += "] "; + CxtRef.Outputs += whisper_full_get_segment_text(WhisperCtx, I); + CxtRef.Outputs += "\n"; + } +} + +} // namespace + +Expect load(WasiNNEnvironment &Env, Span> Builders, + [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { + // Add a new graph. + Env.NNGraph.emplace_back(Backend::Whisper); + auto &GraphRef = Env.NNGraph.back().get(); + + // Initialize the parameters. + auto CParam = whisper_context_default_params(); + GraphRef.EnableLog = false; + GraphRef.EnableDebugLog = false; + GraphRef.UseGPU = CParam.use_gpu; + GraphRef.MainGPU = CParam.gpu_device; + GraphRef.ModelFilePath = ""sv; + GraphRef.ModelLanguage = "en"sv; + + // Set whisper log callback. + whisper_log_set(WhisperLogCallback, &GraphRef); + + // TODO: Use the metadata to pass data. + + // Handle the model path. + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: Handling model path."sv); + } + auto Weight = Builders[0]; + const std::string_view BinModel(reinterpret_cast(Weight.data()), + Weight.size()); + if (BinModel.substr(0, 8) == "preload:"sv) { + GraphRef.ModelFilePath = BinModel.substr(8); + } + + // Initialize whisper context from model file with parameters. + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: Initialize whisper context with " + "given parameters"sv); + } + if (GraphRef.ModelFilePath == ""sv) { + GraphRef.WhisperCtx = whisper_init_from_buffer_with_params( + Weight.data(), Weight.size(), CParam); + } else { + GraphRef.WhisperCtx = whisper_init_from_file_with_params( + GraphRef.ModelFilePath.c_str(), CParam); + } + if (GraphRef.WhisperCtx == nullptr) { + spdlog::error( + "[WASI-NN] Whisper backend: Error: unable to init whisper context from " + "model."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: Initialize whisper context with " + "given parameters...Done"sv); + } + + // Store the loaded graph. + GraphId = Env.NNGraph.size() - 1; + + return ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: initExecCtx"sv); + } + Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); + ContextId = Env.NNContext.size() - 1; + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &WParam = CxtRef.WhisperParams; + WParam.print_progress = false; + WParam.thold_pt = GraphRef.WordThreshold; + WParam.language = GraphRef.ModelLanguage.c_str(); + WParam.temperature_inc = GraphRef.TemperatureInc; + WParam.temperature = GraphRef.Temperature; + WParam.entropy_thold = GraphRef.EntropyThreshold; + WParam.logprob_thold = GraphRef.LogprobThreshold; + WParam.grammar_penalty = GraphRef.GrammarPenalty; + WParam.new_segment_callback = WhisperOutputSegmentCallback; + WParam.new_segment_callback_user_data = &CxtRef; + if (GraphRef.EnableLog) { + spdlog::info("[WASI-NN] Whisper backend: whisper_system_info: {}"sv, + whisper_print_system_info()); + } + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: initExecCtx...Done"sv); + } + return ErrNo::Success; +} + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index [[maybe_unused]], + const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: setInput"sv); + } + + if (Tensor.Dimension.size() != 2) { + spdlog::error("[WASI-NN] Tensor dimension is out of range, expect 2-dim, " + "but got {}-dim.", + Tensor.Dimension.size()); + return WASINN::ErrNo::InvalidArgument; + } + if (Tensor.Dimension[0] != 1) { + spdlog::error("[WASI-NN] Only 1 channel supported for now."); + return WASINN::ErrNo::InvalidArgument; + } + + // Tensor type not used here. Not to check this. + + // Check the input audio file format and load. Currently WAV supported. + if (!checkAudioRIFF( + std::string_view(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()), + "WAVE"sv)) { + spdlog::error("[WASI-NN] Only WAV format supported now."sv); + return WASINN::ErrNo::InvalidArgument; + } + if (!loadWAV(Tensor.Tensor, CxtRef.InputPCM)) { + return WASINN::ErrNo::InvalidArgument; + } + + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: setInput...Done"sv); + } + return ErrNo::Success; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: getOutput with Index {}"sv, + Index); + } + + // Check out buffer max size. + if (OutBuffer.size() < CxtRef.Outputs.length()) { + spdlog::error("[WASI-NN] Expect out buffer max size {}, but got {}", + CxtRef.Outputs.length(), OutBuffer.size()); + return WASINN::ErrNo::InvalidArgument; + } + + std::copy_n(CxtRef.Outputs.data(), CxtRef.Outputs.length(), OutBuffer.data()); + BytesWritten = CxtRef.Outputs.length(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: getOutput with Index {}...Done"sv, + Index); + } + return ErrNo::Success; +} + +Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: compute"sv); + } + + CxtRef.Outputs.clear(); + if (whisper_full(GraphRef.WhisperCtx, CxtRef.WhisperParams, + CxtRef.InputPCM.data(), CxtRef.InputPCM.size()) != 0) { + spdlog::error( + "[WASI-NN] Whisper backend: Error: failed to process audio."sv); + return ErrNo::RuntimeError; + } + + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: compute...Done"sv); + } + return ErrNo::Success; +} + +#else + +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] Whisper backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"whisper\" to build it."sv); + return ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WasiNNEnvironment &, Span>, Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WasiNNEnvironment &, uint32_t, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} + +#endif +} // namespace WasmEdge::Host::WASINN::Whisper diff --git a/plugins/wasi_nn/whispercpp.h b/plugins/wasi_nn/whispercpp.h new file mode 100644 index 00000000..69b8c038 --- /dev/null +++ b/plugins/wasi_nn/whispercpp.h @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include "types.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER +#include + +#include +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::Whisper { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER +struct Graph { + whisper_context *WhisperCtx = nullptr; + std::string ModelFilePath; + std::string ModelLanguage; + // Whisper parameters: + bool EnableLog = false; + bool EnableDebugLog = false; + // Context parameters: + bool UseGPU = true; + int64_t MainGPU = 0; // Use GPU 0 by default + // Sampling parameters: + float WordThreshold = 0.01f; + float EntropyThreshold = 2.40f; + float LogprobThreshold = -1.00f; + float Temperature = 0.0f; + float TemperatureInc = 0.2f; + float GrammarPenalty = 100.0f; +}; + +struct Context { +public: + Context(size_t GId, Graph &) noexcept : GraphId(GId) {} + size_t GraphId; + std::vector InputPCM; // mono-channel F32 PCM input. + whisper_full_params WhisperParams = whisper_full_default_params( + whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH); + std::string Outputs; +}; +#else +struct Graph {}; +struct Context { + Context(size_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::Whisper From 14b6e4357ff6c3eacc636286bd61d462cd741f83 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 4 Jul 2024 06:25:13 +0800 Subject: [PATCH 368/623] [WASI-NN] Tests for whisper backend. Signed-off-by: YiYing He --- test/plugins/wasi_nn/CMakeLists.txt | 12 + test/plugins/wasi_nn/wasi_nn.cpp | 490 +++++++++++++++++++++------- 2 files changed, 385 insertions(+), 117 deletions(-) diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 040fae7b..e0f5f9ed 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -99,6 +99,18 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures PATTERNS piper/espeak-ng-data ) + elseif(BACKEND STREQUAL "whisper") + message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_whisper_fixtures") + download( + https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_whisper_fixtures/ggml-base.bin + MD5=4279db3d7b18d9f6e4d5817a16af4f09 + ) + download( + https://github.com/second-state/WasmEdge-WASINN-examples/raw/master/whisper-basic/test.wav + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_whisper_fixtures/test.wav + MD5=6cf3f7af1ebbd6b29c373e526b548dba + ) else() # Add the other backend test files fetching here. endif() diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 97de3d54..9cb61300 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1,19 +1,22 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2022 Second State INC -#include "common/types.h" -#include "runtime/instance/module.h" #include "wasinnfunc.h" #include "wasinnmodule.h" +#include "common/types.h" +#include "runtime/instance/module.h" + +#include + #include #include #include -#include #include #include using WasmEdge::Host::WASINN::Backend; +using WasmEdge::Host::WASINN::Device; using WasmEdge::Host::WASINN::ErrNo; #if defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) || \ @@ -21,7 +24,8 @@ using WasmEdge::Host::WASINN::ErrNo; defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED) || \ - defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER) + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER) namespace { WasmEdge::Runtime::Instance::ModuleInstance * createModule(std::string_view NNRPCURI = "") { @@ -96,7 +100,7 @@ std::vector classSort(WasmEdge::Span Array) { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO TEST(WasiNNTest, OpenVINOBackend) { - // Create the wasmedge_process module instance. + // Create the wasi_nn module instance. auto *NNMod = dynamic_cast(createModule()); ASSERT_TRUE(NNMod != nullptr); @@ -169,7 +173,8 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_TRUE(HostFuncLoad.run( CallFrame, std::initializer_list{ - LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::RuntimeError)); @@ -180,7 +185,8 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_TRUE(HostFuncLoad.run( CallFrame, std::initializer_list{ - LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), OutBoundPtr}, + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), OutBoundPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -191,7 +197,8 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_TRUE(HostFuncLoad.run( CallFrame, std::initializer_list{ - OutBoundPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, + OutBoundPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -207,7 +214,8 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_TRUE(HostFuncLoad.run( CallFrame, std::initializer_list{ - LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -223,7 +231,8 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_TRUE(HostFuncLoad.run( CallFrame, std::initializer_list{ - LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -242,7 +251,8 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_TRUE(HostFuncLoad.run( CallFrame, std::initializer_list{ - LoadEntryPtr, UINT32_C(4), UINT32_C(0), UINT32_C(0), BuilderPtr}, + LoadEntryPtr, UINT32_C(4), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -253,7 +263,8 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_TRUE(HostFuncLoad.run( CallFrame, std::initializer_list{ - LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(3), BuilderPtr}, + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::AUTO), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -264,7 +275,8 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_TRUE(HostFuncLoad.run( CallFrame, std::initializer_list{ - LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); @@ -276,7 +288,8 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_TRUE(HostFuncLoad.run( CallFrame, std::initializer_list{ - LoadEntryPtr, UINT32_C(2), UINT32_C(0), UINT32_C(0), BuilderPtr}, + LoadEntryPtr, UINT32_C(2), static_cast(Backend::OpenVINO), + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); @@ -505,12 +518,14 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_EQ(SortedIndex[I], CorrectClasses[I]); } } + + delete NNMod; } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH TEST(WasiNNTest, PyTorchBackend) { - // Create the wasmedge_process module instance. + // Create the wasi_nn module instance. auto *NNMod = dynamic_cast(createModule()); EXPECT_FALSE(NNMod == nullptr); @@ -578,36 +593,36 @@ TEST(WasiNNTest, PyTorchBackend) { // Torch WASI-NN load tests. // Test: load -- meaningless binaries. { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::PyTorch), - UINT32_C(0), BuilderPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::CPU), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } // Test: load -- graph id ptr out of bounds. { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::PyTorch), - UINT32_C(0), OutBoundPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::CPU), OutBoundPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } // Test: load -- graph builder ptr out of bounds. { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - OutBoundPtr, UINT32_C(1), - static_cast(Backend::PyTorch), - UINT32_C(0), BuilderPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::CPU), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } @@ -616,12 +631,12 @@ TEST(WasiNNTest, PyTorchBackend) { writeFatPointer(MemInst, OutBoundPtr, static_cast(WeightRead.size()), BuilderPtr); { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::PyTorch), - UINT32_C(0), BuilderPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::CPU), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } @@ -633,36 +648,36 @@ TEST(WasiNNTest, PyTorchBackend) { writeBinaries(MemInst, WeightRead, StorePtr); StorePtr += WeightRead.size(); { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(2), - static_cast(Backend::PyTorch), - UINT32_C(0), BuilderPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::PyTorch), + static_cast(Device::CPU), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } // Test: load -- unsupported device. CPU 0, GPU 1, TPU 2 { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::PyTorch), - UINT32_C(3), BuilderPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::AUTO), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } // Test: load -- load successfully. { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::PyTorch), - UINT32_C(0), BuilderPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::CPU), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); BuilderPtr += 4; @@ -670,12 +685,12 @@ TEST(WasiNNTest, PyTorchBackend) { // Test: load -- load second graph. { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::PyTorch), - UINT32_C(0), BuilderPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::PyTorch), + static_cast(Device::CPU), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); BuilderPtr += 4; @@ -877,12 +892,14 @@ TEST(WasiNNTest, PyTorchBackend) { EXPECT_EQ(SortedIndex[I], CorrectClasses[I]); } } + + delete NNMod; } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE TEST(WasiNNTest, TFLiteBackend) { - // Create the wasmedge_process module instance. + // Create the wasi_nn module instance. auto *NNMod = dynamic_cast(createModule()); EXPECT_FALSE(NNMod == nullptr); @@ -956,7 +973,7 @@ TEST(WasiNNTest, TFLiteBackend) { std::initializer_list{ LoadEntryPtr, UINT32_C(1), static_cast(Backend::TensorflowLite), - UINT32_C(0), BuilderPtr}, + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -969,7 +986,7 @@ TEST(WasiNNTest, TFLiteBackend) { std::initializer_list{ LoadEntryPtr, UINT32_C(1), static_cast(Backend::TensorflowLite), - UINT32_C(0), OutBoundPtr}, + static_cast(Device::CPU), OutBoundPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -982,7 +999,7 @@ TEST(WasiNNTest, TFLiteBackend) { std::initializer_list{ OutBoundPtr, UINT32_C(1), static_cast(Backend::TensorflowLite), - UINT32_C(0), BuilderPtr}, + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -997,7 +1014,7 @@ TEST(WasiNNTest, TFLiteBackend) { std::initializer_list{ LoadEntryPtr, UINT32_C(1), static_cast(Backend::TensorflowLite), - UINT32_C(0), BuilderPtr}, + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -1015,7 +1032,7 @@ TEST(WasiNNTest, TFLiteBackend) { std::initializer_list{ LoadEntryPtr, UINT32_C(2), static_cast(Backend::TensorflowLite), - UINT32_C(0), BuilderPtr}, + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -1028,7 +1045,7 @@ TEST(WasiNNTest, TFLiteBackend) { std::initializer_list{ LoadEntryPtr, UINT32_C(1), static_cast(Backend::TensorflowLite), - UINT32_C(3), BuilderPtr}, + static_cast(Device::AUTO), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -1041,7 +1058,7 @@ TEST(WasiNNTest, TFLiteBackend) { std::initializer_list{ LoadEntryPtr, UINT32_C(1), static_cast(Backend::TensorflowLite), - UINT32_C(0), BuilderPtr}, + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); @@ -1055,7 +1072,7 @@ TEST(WasiNNTest, TFLiteBackend) { std::initializer_list{ LoadEntryPtr, UINT32_C(1), static_cast(Backend::TensorflowLite), - UINT32_C(0), BuilderPtr}, + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); @@ -1242,12 +1259,14 @@ TEST(WasiNNTest, TFLiteBackend) { OutputClassification[CorrectClasses[I]]); } } + + delete NNMod; } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML TEST(WasiNNTest, GGMLBackend) { - // Create the wasmedge_process module instance. + // Create the wasi_nn module instance. auto *NNMod = dynamic_cast(createModule()); EXPECT_FALSE(NNMod == nullptr); @@ -1311,36 +1330,36 @@ TEST(WasiNNTest, GGMLBackend) { // GGML WASI-NN load tests. // Test: load -- meaningless binaries. { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::GGML), - UINT32_C(0), BuilderPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::GGML), + static_cast(Device::CPU), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } // Test: load -- graph id ptr out of bounds. { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::GGML), - UINT32_C(0), OutBoundPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::GGML), + static_cast(Device::CPU), OutBoundPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } // Test: load -- graph builder ptr out of bounds. { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - OutBoundPtr, UINT32_C(1), - static_cast(Backend::GGML), - UINT32_C(0), BuilderPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), static_cast(Backend::GGML), + static_cast(Device::CPU), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } @@ -1350,12 +1369,12 @@ TEST(WasiNNTest, GGMLBackend) { writeFatPointer(MemInst, OutBoundPtr, static_cast(WeightRead.size()), BuilderPtr); { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::GGML), - UINT32_C(0), BuilderPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::GGML), + static_cast(Device::CPU), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } @@ -1367,24 +1386,24 @@ TEST(WasiNNTest, GGMLBackend) { writeBinaries(MemInst, WeightRead, StorePtr); StorePtr += static_cast(WeightRead.size()); { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(2), - static_cast(Backend::GGML), - UINT32_C(0), BuilderPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::GGML), + static_cast(Device::CPU), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidEncoding)); } // Test: load -- load successfully. { - EXPECT_TRUE(HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::GGML), - UINT32_C(0), BuilderPtr}, - Errno)); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::GGML), + static_cast(Device::CPU), BuilderPtr}, + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); BuilderPtr += 4; @@ -1528,7 +1547,7 @@ TEST(WasiNNTest, GGMLBackendWithRPC) { GTEST_SKIP() << "WASI_NN_RPC_TEST_URI is unset"; } - // Create the wasmedge_process module instance. + // Create the wasi_nn module instance. auto *NNMod = dynamic_cast(createModule(NNRPCURI)); EXPECT_FALSE(NNMod == nullptr); @@ -1743,13 +1762,15 @@ TEST(WasiNNTest, GGMLBackendWithRPC) { auto BytesWritten = *MemInst.getPointer(BuilderPtr); EXPECT_GE(BytesWritten, 50); } + + delete NNMod; } #endif // WASMEDGE_BUILD_WASI_NN_RPC #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED TEST(WasiNNTest, NeuralSpeedBackend) { - // Create the wasmedge_process module instance. + // Create the wasi_nn module instance. auto *NNMod = dynamic_cast(createModule()); ASSERT_TRUE(NNMod != nullptr); @@ -1827,7 +1848,7 @@ TEST(WasiNNTest, NeuralSpeedBackend) { std::initializer_list{ LoadEntryPtr, UINT32_C(1), static_cast(Backend::NeuralSpeed), - UINT32_C(0), BuilderPtr}, + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -1839,7 +1860,7 @@ TEST(WasiNNTest, NeuralSpeedBackend) { std::initializer_list{ LoadEntryPtr, UINT32_C(1), static_cast(Backend::NeuralSpeed), - UINT32_C(0), OutBoundPtr}, + static_cast(Device::CPU), OutBoundPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -1852,7 +1873,7 @@ TEST(WasiNNTest, NeuralSpeedBackend) { std::initializer_list{ OutBoundPtr, UINT32_C(1), static_cast(Backend::NeuralSpeed), - UINT32_C(0), BuilderPtr}, + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -1868,7 +1889,7 @@ TEST(WasiNNTest, NeuralSpeedBackend) { std::initializer_list{ LoadEntryPtr, UINT32_C(1), static_cast(Backend::NeuralSpeed), - UINT32_C(0), BuilderPtr}, + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); @@ -1889,7 +1910,7 @@ TEST(WasiNNTest, NeuralSpeedBackend) { std::initializer_list{ LoadEntryPtr, UINT32_C(2), static_cast(Backend::NeuralSpeed), - UINT32_C(0), BuilderPtr}, + static_cast(Device::CPU), BuilderPtr}, Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); @@ -2008,9 +2029,244 @@ TEST(WasiNNTest, NeuralSpeedBackend) { Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); } + + delete NNMod; } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER +TEST(WasiNNTest, WhisperBackend) { + // Create the wasi_nn module instance. + auto *NNMod = dynamic_cast(createModule()); + ASSERT_TRUE(NNMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + std::vector TensorData = + readEntireFile("./wasinn_whisper_fixtures/test.wav"); + std::vector WeightRead = + readEntireFile("./wasinn_whisper_fixtures/ggml-base.bin"); + std::vector TensorDim{1, static_cast(TensorData.size())}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // Whisper WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::Whisper), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::Whisper), + static_cast(Device::CPU), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), static_cast(Backend::Whisper), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- Whisper model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::Whisper), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- load successfully. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); + writeBinaries(MemInst, WeightRead, StorePtr); + StorePtr += WeightRead.size(); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::Whisper), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Whisper WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: init_execution_context -- init second context. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Whisper WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // Whisper WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Whisper WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 50 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 50); + } + + delete NNMod; +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER + #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER TEST(WasiNNTest, PiperBackend) { // Create the wasmedge_process module instance. @@ -2258,4 +2514,4 @@ TEST(WasiNNTest, PiperBackend) { EXPECT_GE(BytesWritten, 10000); } } -#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER \ No newline at end of file +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER From 3506a6287d75c60c0b3c832be56ce60051b935c2 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 11 Jul 2024 18:11:21 +0800 Subject: [PATCH 369/623] [CI] Add the build and release CI of stable diffusion plugin. Signed-off-by: YiYing He --- .../wasmedge_stablediffusion/CMakeLists.txt | 18 +++++++++--------- .../wasmedge_stablediffusion/CMakeLists.txt | 4 ++-- .../wasmedge_stablediffusion.cpp | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 5ac3abcf..53409616 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -20,33 +20,33 @@ foreach(dep ${SD_DEPS}) endif() endforeach() -wasmedge_add_library(wasmedgePluginStableDiffusion +wasmedge_add_library(wasmedgePluginWasmEdgeStableDiffusion SHARED sd_env.cpp sd_func.cpp sd_module.cpp ) -target_link_libraries(wasmedgePluginStableDiffusion PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(wasmedgePluginWasmEdgeStableDiffusion PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) -target_compile_options(wasmedgePluginStableDiffusion +target_compile_options(wasmedgePluginWasmEdgeStableDiffusion PUBLIC -DWASMEDGE_PLUGIN ) if(WASMEDGE_LINK_PLUGINS_STATIC) - target_link_libraries(wasmedgePluginStableDiffusion + target_link_libraries(wasmedgePluginWasmEdgeStableDiffusion PRIVATE wasmedgeCAPI ) else() - target_link_libraries(wasmedgePluginStableDiffusion + target_link_libraries(wasmedgePluginWasmEdgeStableDiffusion PRIVATE wasmedge_shared ) endif() -target_include_directories(wasmedgePluginStableDiffusion +target_include_directories(wasmedgePluginWasmEdgeStableDiffusion PUBLIC $ ${CMAKE_CURRENT_SOURCE_DIR} @@ -54,7 +54,7 @@ target_include_directories(wasmedgePluginStableDiffusion if (MSVC) target_compile_options( - wasmedgePluginStableDiffusion + wasmedgePluginWasmEdgeStableDiffusion PRIVATE /wd4459 /wd4100 @@ -63,7 +63,7 @@ if (MSVC) ) else() target_compile_options( - wasmedgePluginStableDiffusion + wasmedgePluginWasmEdgeStableDiffusion PRIVATE -Wno-unused-function -Wno-unused-variable @@ -72,4 +72,4 @@ else() ) endif() -install(TARGETS wasmedgePluginStableDiffusion DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install(TARGETS wasmedgePluginWasmEdgeStableDiffusion DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/test/plugins/wasmedge_stablediffusion/CMakeLists.txt b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt index fa321903..e2aaf279 100644 --- a/test/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -3,13 +3,13 @@ wasmedge_add_executable(wasmedgeStableDiffusionTests ) add_dependencies(wasmedgeStableDiffusionTests - wasmedgePluginStableDiffusion + wasmedgePluginWasmEdgeStableDiffusion ) target_include_directories(wasmedgeStableDiffusionTests PUBLIC $ - $ + $ ) target_link_libraries(wasmedgeStableDiffusionTests diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index 57dc372c..b623ce47 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -17,7 +17,7 @@ WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "../../../plugins/wasmedge_stablediffusion/" WASMEDGE_LIB_PREFIX - "wasmedgePluginStableDiffusion" WASMEDGE_LIB_EXTENSION)); + "wasmedgePluginWasmEdgeStableDiffusion" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_stablediffusion"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_stablediffusion"sv)) { From ac16b0d51aa9180ab25ffd278b5ab2c908de2e3e Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 23 Jul 2024 21:56:43 +0800 Subject: [PATCH 370/623] [WASI-NN] ggml: add json-schema suuport and bump llama.cpp b3445 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/ggml.cpp | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index eb14e3cc..89174086 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -66,7 +66,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3405 + GIT_TAG b3445 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 4bb467a3..c91ea2b7 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include #include #include #include @@ -310,6 +312,17 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } GraphRef.Grammar = Grammar; } + if (Doc.at_key("json-schema").error() == simdjson::SUCCESS) { + std::string_view JsonSchema; + auto Err = Doc["json-schema"].get().get(JsonSchema); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the json-schema option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Grammar = + json_schema_to_grammar(nlohmann::ordered_json::parse(JsonSchema)); + } // Check if the model is updated. if (IsModelUpdated && ModelParams.n_gpu_layers != GraphRef.NGPULayers) { From 463ffc63d8a5d1da3400975fe96d36ed08f103c8 Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 26 Jul 2024 03:56:41 +0800 Subject: [PATCH 371/623] [WASI-NN] ggml: bump to b3463 and allow forcely disable the native isa extensions Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 89174086..34a16032 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -31,10 +31,15 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) set(BUILD_SHARED_LIBS OFF) if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_NATIVE) - message(STATUS "WASI-NN GGML LLAMA backend: Enable GGML_NATIVE(AVX/AVX2/FMA)") + message(STATUS "WASI-NN GGML LLAMA backend: Enable GGML_NATIVE(AVX/AVX2/FMA/F16C)") set(GGML_NATIVE ON) else() + message(STATUS "WASI-NN GGML LLAMA backend: Disable GGML_NATIVE(AVX/AVX2/FMA/F16C)") set(GGML_NATIVE OFF) + set(GGML_AVX OFF) + set(GGML_AVX2 OFF) + set(GGML_FMA OFF) + set(GGML_F16C OFF) endif() if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_CUBLAS) @@ -66,7 +71,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3445 + GIT_TAG b3463 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From f8f1ec0b6d44f76a968f1919301f4a85a6618ee9 Mon Sep 17 00:00:00 2001 From: grorge Date: Fri, 26 Jul 2024 21:17:25 +0800 Subject: [PATCH 372/623] [WASI-NN] neural speed: fix weak pointer of import module Signed-off-by: grorge --- plugins/wasi_nn/neuralspeed.cpp | 13 +++++++------ plugins/wasi_nn/neuralspeed.h | 12 ++++++------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index 6a04da17..f3fc2c9f 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -69,7 +69,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } - GraphRef.model_type = ModelType; + GraphRef.ModelType = ModelType; } } @@ -114,8 +114,10 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (!Py_IsInitialized()) { Py_Initialize(); } - GraphRef.NeuralSpeedModule = - PyImport_Import(PyUnicode_FromString("neural_speed")); + if (GraphRef.NeuralSpeedModule == nullptr) { + GraphRef.NeuralSpeedModule = + PyImport_Import(PyUnicode_FromString("neural_speed")); + } if (GraphRef.NeuralSpeedModule == nullptr) { PyErr_Print(); spdlog::error( @@ -137,18 +139,16 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (GraphRef.Model == nullptr) { spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); Py_XDECREF(GraphRef.ModelClass); - Py_XDECREF(GraphRef.NeuralSpeedModule); Env.NNGraph.pop_back(); return WASINN::ErrNo::InvalidArgument; } PyObject *LoadResult = PyObject_CallMethod(GraphRef.Model, "init_from_bin", "(ss)", - GraphRef.model_type.c_str(), ModelFilePath.c_str()); + GraphRef.ModelType.c_str(), ModelFilePath.c_str()); if (LoadResult == nullptr) { spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); Py_XDECREF(GraphRef.Model); Py_XDECREF(GraphRef.ModelClass); - Py_XDECREF(GraphRef.NeuralSpeedModule); Env.NNGraph.pop_back(); return WASINN::ErrNo::InvalidArgument; } @@ -314,6 +314,7 @@ Expect unload(WASINN::WasiNNEnvironment &Env, Py_XDECREF(GraphRef.Model); Py_XDECREF(GraphRef.ModelClass); Py_XDECREF(GraphRef.NeuralSpeedModule); + GraphRef.NeuralSpeedModule = nullptr; Py_Finalize(); } return WASINN::ErrNo::Success; diff --git a/plugins/wasi_nn/neuralspeed.h b/plugins/wasi_nn/neuralspeed.h index f087f151..fcdf6a5a 100644 --- a/plugins/wasi_nn/neuralspeed.h +++ b/plugins/wasi_nn/neuralspeed.h @@ -18,7 +18,7 @@ namespace WasmEdge::Host::WASINN::NeuralSpeed { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED struct Graph { bool EnableDebugLog = true; - std::string model_type = "llama"; + std::string ModelType = "llama"; inline static int GraphNumber = 0; Graph() noexcept { Py_Initialize(); } ~Graph() noexcept { @@ -28,11 +28,11 @@ struct Graph { Py_XDECREF(NeuralSpeedModule); } } - PyObject *Model; - PyObject *NeuralSpeedModule; - PyObject *ModelClass; - int64_t LoadTime; - int64_t ComputeTime; + PyObject *Model = nullptr; + PyObject *NeuralSpeedModule = nullptr; + PyObject *ModelClass = nullptr; + int64_t LoadTime = 0; + int64_t ComputeTime = 0; }; struct Context { Context(size_t Gid, Graph &) noexcept : GraphId(Gid) {} From ad70577460b1f3ab9f8316ab3f800b035604950b Mon Sep 17 00:00:00 2001 From: grorge Date: Fri, 26 Jul 2024 22:42:40 +0800 Subject: [PATCH 373/623] [WASI-NN] neural speed: fix release reference Signed-off-by: grorge --- plugins/wasi_nn/neuralspeed.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index f3fc2c9f..cdc86f23 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -125,20 +125,20 @@ Expect load(WASINN::WasiNNEnvironment &Env, Env.NNGraph.pop_back(); return WASINN::ErrNo::RuntimeError; } - GraphRef.ModelClass = - PyObject_GetAttrString(GraphRef.NeuralSpeedModule, "Model"); + if (GraphRef.ModelClass == nullptr) { + GraphRef.ModelClass = + PyObject_GetAttrString(GraphRef.NeuralSpeedModule, "Model"); + } if (GraphRef.ModelClass == nullptr || !PyCallable_Check(GraphRef.ModelClass)) { spdlog::error( "[WASI-NN] neural speed backend: Can not find Model class in neural speed."sv); - Py_XDECREF(GraphRef.Model); Env.NNGraph.pop_back(); return WASINN::ErrNo::RuntimeError; } GraphRef.Model = PyObject_CallObject(GraphRef.ModelClass, NULL); if (GraphRef.Model == nullptr) { spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); - Py_XDECREF(GraphRef.ModelClass); Env.NNGraph.pop_back(); return WASINN::ErrNo::InvalidArgument; } @@ -148,7 +148,6 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (LoadResult == nullptr) { spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); Py_XDECREF(GraphRef.Model); - Py_XDECREF(GraphRef.ModelClass); Env.NNGraph.pop_back(); return WASINN::ErrNo::InvalidArgument; } @@ -285,11 +284,9 @@ Expect compute(WasiNNEnvironment &Env, if (PyLong_Check(Num)) { InnerVec.push_back(PyLong_AsLong(Num)); } - Py_DECREF(Num); } CxtRef.Outputs = InnerVec; } - Py_DECREF(InnerList); } } Py_DECREF(Result); @@ -314,6 +311,10 @@ Expect unload(WASINN::WasiNNEnvironment &Env, Py_XDECREF(GraphRef.Model); Py_XDECREF(GraphRef.ModelClass); Py_XDECREF(GraphRef.NeuralSpeedModule); + spdlog::info("[WASI-NN] Neural speed backend: Finish unload. {} {}"sv, + GraphRef.ModelClass == nullptr, + GraphRef.NeuralSpeedModule == nullptr); + GraphRef.ModelClass = nullptr; GraphRef.NeuralSpeedModule = nullptr; Py_Finalize(); } From 0269fbc11fd6ad18c5720d79751901650882abbb Mon Sep 17 00:00:00 2001 From: grorge Date: Sat, 27 Jul 2024 13:16:12 +0800 Subject: [PATCH 374/623] [WASI-NN] neural speed: release string reference Signed-off-by: grorge --- plugins/wasi_nn/neuralspeed.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index cdc86f23..e670b342 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -115,8 +115,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, Py_Initialize(); } if (GraphRef.NeuralSpeedModule == nullptr) { - GraphRef.NeuralSpeedModule = - PyImport_Import(PyUnicode_FromString("neural_speed")); + GraphRef.NeuralSpeedModule = PyImport_ImportModule("neural_speed"); } if (GraphRef.NeuralSpeedModule == nullptr) { PyErr_Print(); @@ -264,8 +263,10 @@ Expect compute(WasiNNEnvironment &Env, "[WASI-NN] neural speed backend: Input transfer tensor failed."sv); return WASINN::ErrNo::InvalidArgument; } - PyObject *Result = PyObject_CallMethodObjArgs( - GraphRef.Model, PyUnicode_FromString("generate"), LongTensor, NULL); + PyObject *GenerateString = PyUnicode_FromString("generate"); + PyObject *Result = PyObject_CallMethodObjArgs(GraphRef.Model, GenerateString, + LongTensor, NULL); + Py_DECREF(GenerateString); if (Result == nullptr) { PyErr_Print(); spdlog::error( @@ -311,9 +312,6 @@ Expect unload(WASINN::WasiNNEnvironment &Env, Py_XDECREF(GraphRef.Model); Py_XDECREF(GraphRef.ModelClass); Py_XDECREF(GraphRef.NeuralSpeedModule); - spdlog::info("[WASI-NN] Neural speed backend: Finish unload. {} {}"sv, - GraphRef.ModelClass == nullptr, - GraphRef.NeuralSpeedModule == nullptr); GraphRef.ModelClass = nullptr; GraphRef.NeuralSpeedModule = nullptr; Py_Finalize(); From d23c34033606e847f54e0b71b089bc3d1f92f345 Mon Sep 17 00:00:00 2001 From: hydai Date: Sat, 27 Jul 2024 06:14:27 +0800 Subject: [PATCH 375/623] [Misc] Append EOF via linelint Signed-off-by: hydai --- plugins/wasm_bpf/README.md | 2 +- plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h | 2 +- plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp | 2 +- plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp | 2 +- plugins/wasmedge_ffmpeg/avformat/module.h | 2 +- plugins/wasmedge_ffmpeg/bindings.h | 2 +- plugins/wasmedge_ffmpeg/ffmpeg_env.cpp | 2 +- plugins/wasmedge_ffmpeg/ffmpeg_env.h | 2 +- plugins/wasmedge_ffmpeg/swresample/swresample_func.h | 2 +- plugins/wasmedge_stablediffusion/sd_base.h | 2 +- test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp | 2 +- test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp | 2 +- test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp | 2 +- test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp | 2 +- test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp | 2 +- test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp | 2 +- test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp | 2 +- test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp | 2 +- test/plugins/wasmedge_ffmpeg/avutil/avError.cpp | 2 +- test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp | 2 +- test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp | 2 +- test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp | 2 +- test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp | 2 +- test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp | 2 +- test/plugins/wasmedge_stablediffusion/CMakeLists.txt | 2 +- thirdparty/wasi_crypto/api.hpp | 1 - utils/ffmpeg/install-ffmpeg-v6.0.sh | 2 +- utils/wasi-nn/install-neuralspeed.sh | 2 +- .../0001-PATCH-Disable-other-tests-except-wasmedge.patch | 1 - 29 files changed, 27 insertions(+), 29 deletions(-) diff --git a/plugins/wasm_bpf/README.md b/plugins/wasm_bpf/README.md index 05a2755f..824bd062 100644 --- a/plugins/wasm_bpf/README.md +++ b/plugins/wasm_bpf/README.md @@ -106,4 +106,4 @@ Set `WASMEDGE_PLUGIN_PATH=./build/plugins/wasm_bpf/` and run wasmedge: [289159] cpuUsage.sh -> cat /proc/289148/stat [289160] cpuUsage.sh -> sleep 1 ^C -``` \ No newline at end of file +``` diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h index ee39aa0b..356924d8 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h @@ -813,4 +813,4 @@ class AVCodecCtxColorPrimaries } // namespace AVcodec } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp index b5276302..9352680c 100644 --- a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp +++ b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp @@ -121,4 +121,4 @@ Expect AVDeviceLicense::body(const Runtime::CallingFrame &Frame, } // namespace AVDevice } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp index 5c0628f2..072104c5 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp +++ b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp @@ -212,4 +212,4 @@ Expect AVInputOutputFormatFree::body(const Runtime::CallingFrame &, } // namespace AVFormat } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/module.h b/plugins/wasmedge_ffmpeg/avformat/module.h index 4ab491ed..8e5d9740 100644 --- a/plugins/wasmedge_ffmpeg/avformat/module.h +++ b/plugins/wasmedge_ffmpeg/avformat/module.h @@ -16,4 +16,4 @@ class WasmEdgeFFmpegAVFormatModule : public Runtime::Instance::ModuleInstance { } // namespace AVFormat } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/bindings.h b/plugins/wasmedge_ffmpeg/bindings.h index 748f5987..05858cd4 100644 --- a/plugins/wasmedge_ffmpeg/bindings.h +++ b/plugins/wasmedge_ffmpeg/bindings.h @@ -4422,4 +4422,4 @@ class ColorPrimaries { } // namespace FFmpegUtils } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp b/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp index f2705c4a..8ac02fc1 100644 --- a/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp +++ b/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp @@ -109,4 +109,4 @@ std::weak_ptr std::shared_mutex WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::Mutex; } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/ffmpeg_env.h b/plugins/wasmedge_ffmpeg/ffmpeg_env.h index e584fc5d..ebb8ba98 100644 --- a/plugins/wasmedge_ffmpeg/ffmpeg_env.h +++ b/plugins/wasmedge_ffmpeg/ffmpeg_env.h @@ -107,4 +107,4 @@ enum class ErrNo : int32_t { } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_func.h b/plugins/wasmedge_ffmpeg/swresample/swresample_func.h index b8dd8d7f..a79ae568 100644 --- a/plugins/wasmedge_ffmpeg/swresample/swresample_func.h +++ b/plugins/wasmedge_ffmpeg/swresample/swresample_func.h @@ -104,4 +104,4 @@ class SWResampleLicense : public WasmEdgeFFmpegSWResample { } // namespace SWResample } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/plugins/wasmedge_stablediffusion/sd_base.h b/plugins/wasmedge_stablediffusion/sd_base.h index c29b77aa..c8ae12f0 100644 --- a/plugins/wasmedge_stablediffusion/sd_base.h +++ b/plugins/wasmedge_stablediffusion/sd_base.h @@ -25,4 +25,4 @@ template class Func : public Runtime::HostFunction { } // namespace StableDiffusion } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp index 6eb3a3d4..eb2e8a95 100644 --- a/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp @@ -362,4 +362,4 @@ TEST_F(FFmpegTest, AVCodec) { } } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp index 1217914e..9fc5029c 100644 --- a/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp @@ -72,4 +72,4 @@ TEST_F(FFmpegTest, AVCodecParameters) { } } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp index 4d7348a3..88b5a226 100644 --- a/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp @@ -365,4 +365,4 @@ TEST_F(FFmpegTest, AVPacketTest) { } } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp index 7a6675fe..2a936fbd 100644 --- a/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp @@ -679,4 +679,4 @@ TEST_F(FFmpegTest, AVFilterFunc) { } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp index 94dd8c63..3dbef770 100644 --- a/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp +++ b/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp @@ -217,4 +217,4 @@ TEST_F(FFmpegTest, AVChapter) { } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp index 3a575551..c78db2f6 100644 --- a/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp +++ b/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp @@ -204,4 +204,4 @@ TEST_F(FFmpegTest, AVInputFormat) { } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp index ede24ec5..27c7f511 100644 --- a/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp +++ b/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp @@ -181,4 +181,4 @@ TEST_F(FFmpegTest, AVFormatContextStruct) { } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp index c2fc4fee..0b7c2cb8 100644 --- a/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp @@ -585,4 +585,4 @@ TEST_F(FFmpegTest, AVOutputFormatFunc) { } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp index 454d38a2..bc5e6dab 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp @@ -62,4 +62,4 @@ TEST_F(FFmpegTest, AVError) { } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp index 3f4c6269..5ffa1d9d 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp @@ -778,4 +778,4 @@ TEST_F(FFmpegTest, AVFrame) { } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp index afef37c1..91b3399a 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp @@ -241,4 +241,4 @@ TEST_F(FFmpegTest, AVPixFmt) { } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp index c4f13e04..93593079 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp @@ -198,4 +198,4 @@ TEST_F(FFmpegTest, AVSampleFmt) { } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp index c6431ece..43c75cee 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp @@ -258,4 +258,4 @@ TEST_F(FFmpegTest, AVTime) { } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp b/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp index c5879355..1fc4f905 100644 --- a/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp @@ -252,4 +252,4 @@ TEST_F(FFmpegTest, SWResampleFunc) { } // namespace WasmEdgeFFmpeg } // namespace Host -} // namespace WasmEdge \ No newline at end of file +} // namespace WasmEdge diff --git a/test/plugins/wasmedge_stablediffusion/CMakeLists.txt b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt index e2aaf279..dac8abd7 100644 --- a/test/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -44,4 +44,4 @@ download( MD5=c01059060130b8242849d86e97212c84 ) -add_test(wasmedgeStableDiffusionTests wasmedgeStableDiffusionTests) \ No newline at end of file +add_test(wasmedgeStableDiffusionTests wasmedgeStableDiffusionTests) diff --git a/thirdparty/wasi_crypto/api.hpp b/thirdparty/wasi_crypto/api.hpp index 306526b5..7ee8b90a 100644 --- a/thirdparty/wasi_crypto/api.hpp +++ b/thirdparty/wasi_crypto/api.hpp @@ -673,4 +673,3 @@ using __wasi_signature_secretkey_t = __wasi_secretkey_t; static_assert(sizeof(__wasi_signature_secretkey_t) == 4, "witx calculated size"); static_assert(alignof(__wasi_signature_secretkey_t) == 4, "witx calculated align"); - diff --git a/utils/ffmpeg/install-ffmpeg-v6.0.sh b/utils/ffmpeg/install-ffmpeg-v6.0.sh index 72ac1458..479908ce 100755 --- a/utils/ffmpeg/install-ffmpeg-v6.0.sh +++ b/utils/ffmpeg/install-ffmpeg-v6.0.sh @@ -10,4 +10,4 @@ mkdir -p FFmpeg-n6.0/output cd FFmpeg-n6.0 ./configure --prefix=$(pwd)/output --enable-gpl --enable-nonfree --enable-shared --disable-static make && make install -cd .. \ No newline at end of file +cd .. diff --git a/utils/wasi-nn/install-neuralspeed.sh b/utils/wasi-nn/install-neuralspeed.sh index 13d3f043..23838c02 100644 --- a/utils/wasi-nn/install-neuralspeed.sh +++ b/utils/wasi-nn/install-neuralspeed.sh @@ -8,4 +8,4 @@ apt install -y python3-dev python3-pip echo "Installing Neural Speed!" wget https://raw.githubusercontent.com/intel/neural-speed/main/requirements.txt pip install -r requirements.txt -pip install neural-speed==${NEURALSPEED_VERSION} \ No newline at end of file +pip install neural-speed==${NEURALSPEED_VERSION} diff --git a/utils/wasi-test/0001-PATCH-Disable-other-tests-except-wasmedge.patch b/utils/wasi-test/0001-PATCH-Disable-other-tests-except-wasmedge.patch index e6abfe00..937bbf0f 100644 --- a/utils/wasi-test/0001-PATCH-Disable-other-tests-except-wasmedge.patch +++ b/utils/wasi-test/0001-PATCH-Disable-other-tests-except-wasmedge.patch @@ -40,4 +40,3 @@ index 9341520..da54714 100755 -- 2.31.1 - From be6e6e3f547cc39d4c1f0c4d485b4d7e5ce5f2e7 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 16 Jul 2024 14:30:00 +0800 Subject: [PATCH 376/623] [WASINN] Turn off warning for llama.cpp Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_nn/CMakeLists.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 34a16032..3e8ba6c6 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -123,6 +123,14 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) -Wno-unused-macros -Wno-unused-function -Wno-unused-variable + -Wno-sign-conversion + -Wno-shorten-64-to-32 + -Wno-implicit-int-conversion + -Wno-old-style-cast + -Wno-extra-semi-stmt + -Wno-format-nonliteral + -Wno-documentation + -Wno-unused-template ) endif() target_link_libraries(llava PRIVATE ggml llama) From 6ce5b9edd73a7edeff0b0581862354d16fd7f15b Mon Sep 17 00:00:00 2001 From: YiYing He Date: Tue, 23 Jul 2024 12:51:12 +0800 Subject: [PATCH 377/623] [Plugin] Fix the Darwin build of WasmEdge-StableDiffusion plug-in. Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 5 + .../wasmedge_stablediffusion/CMakeLists.txt | 1 + plugins/wasmedge_stablediffusion/sd_func.cpp | 4 +- .../wasmedge_stablediffusion.cpp | 151 +++++++++--------- 4 files changed, 80 insertions(+), 81 deletions(-) diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index a479647d..2061e166 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -67,7 +67,12 @@ if(WASMEDGE_PLUGIN_WASI_OCR) endif() if(WASMEDGE_PLUGIN_STABLEDIFFUSION) + # Only Linux and MacOS support wasmedge_stablediffusion now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_stablediffusion) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_StableDiffusion plug-in now.") + endif() endif() if(WASMEDGE_PLUGIN_OPENCVMINI) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 53409616..7dd00371 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -69,6 +69,7 @@ else() -Wno-unused-variable -Wno-unused-parameter -Wno-missing-field-initializers + -Wno-deprecated-declarations ) endif() diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index 564e8e51..8a74cb2f 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -138,8 +138,8 @@ Expect SDConvert::body(const Runtime::CallingFrame &Frame, } Fin.close(); // Convert model. - bool Ret = convert(ModelPath.data(), VaeModelPath.data(), OutputPath.data(), - static_cast(WType)); + bool Ret = ::convert(ModelPath.data(), VaeModelPath.data(), OutputPath.data(), + static_cast(WType)); if (!Ret) { spdlog::error("[WasmEdge-StableDiffusion] Failed to convert model."); return static_cast(ErrNo::InvalidArgument); diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index b623ce47..cbed4455 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -134,10 +134,7 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 8}, // SD_TYPE_Q8_0 = 8 Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); - std::ifstream Fin(QuantModelPath.data(), - std::ios::in | std::ios::binary | std::ios::ate); - EXPECT_FALSE(Fin.fail()); - Fin.close(); + EXPECT_TRUE(std::filesystem::exists(QuantModelPathString)); } // Test: create_context -- create context for text to image. { @@ -182,45 +179,43 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { OutputPtr = BytesWrittenPtr + 4; writeBinaries(MemInst, PromptData, PromptPtr); writeBinaries(MemInst, OutputPath, OutputPathPtr); - EXPECT_TRUE( - HostFuncTextToImage.run(CallFrame, - std::initializer_list{ - PromptPtr, // PromptPtr - PromptData.size(), // PromptLen - SessionId, // SessionId - 0, // ControlImagePtr - 0, // ControlImageLen - 0, // NegativePromptPtr - 0, // NegativePromptLen - 64, // Width - 64, // Height - -1, // ClipSkip - 7.0f, // CfgScale - 0, // SampleMethod - 20, // SampleSteps - 42, // Seed - 1, // BatchCount - 0.90f, // ControlStrength - 20.0f, // StyleRatio - 0, // NormalizeInput - 0, // InputIdImagesDirPtr - 0, // InputIdImagesDirLen - 0, // CannyPreprocess - 0, // UpscaleModelPathPtr - 0, // UpscaleModelPathLen - 1, // UpscaleRepeats - OutputPathPtr, // OutputPathPtr - OutputPath.size(), // OutputPathLen - OutputPtr, // OutBufferPtr - 65532, // OutBufferMaxSize - BytesWrittenPtr}, // BytesWrittenPtr - Errno)); + EXPECT_TRUE(HostFuncTextToImage.run( + CallFrame, + std::initializer_list{ + PromptPtr, // PromptPtr + static_cast(PromptData.size()), // PromptLen + SessionId, // SessionId + 0, // ControlImagePtr + 0, // ControlImageLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + 64, // Width + 64, // Height + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 20, // SampleSteps + 42, // Seed + 1, // BatchCount + 0.90f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + OutputPathPtr, // OutputPathPtr + static_cast(OutputPath.size()), // OutputPathLen + OutputPtr, // OutBufferPtr + 65532, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); EXPECT_GE(BytesWritten, 50); - std::ifstream Fin(OutputPathString, std::ios::in | std::ios::binary); - EXPECT_FALSE(Fin.fail()); - Fin.close(); + EXPECT_TRUE(std::filesystem::exists(OutputPathString)); } writeBinaries(MemInst, ModelPath, ModelPathPtr); writeBinaries(MemInst, QuantModelPath, QuantModelPathPtr); @@ -268,48 +263,46 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { writeBinaries(MemInst, PromptData2, PromptPtr); writeBinaries(MemInst, InputPath, InputPathPtr); writeBinaries(MemInst, OutputPath2, OutputPathPtr); - EXPECT_TRUE( - HostFuncImageToImage.run(CallFrame, - std::initializer_list{ - InputPathPtr, // ImagePtr - InputPath.size(), // ImageLen - SessionId, // SessionId - 64, // Width - 64, // Height - 0, // ControlImagePtr - 0, // ControlImageLen - PromptPtr, // PromptPtr - PromptData2.size(), // PromptLen - 0, // NegativePromptPtr - 0, // NegativePromptLen - -1, // ClipSkip - 7.0f, // CfgScale - 0, // SampleMethod - 20, // SampleSteps - 0.75f, // Strength - 42, // Seed - 1, // BatchCount - 0.9f, // ControlStrength - 20.0f, // StyleRatio - 0, // NormalizeInput - 0, // InputIdImagesDirPtr - 0, // InputIdImagesDirLen - 0, // CannyPreprocess - 0, // UpscaleModelPathPtr - 0, // UpscaleModelPathLen - 1, // UpscaleRepeats - OutputPathPtr, // OutputPathPtr - OutputPath2.size(), // OutputPathLen - OutputPtr, // OutBufferPtr - 65532, // OutBufferMaxSize - BytesWrittenPtr}, // BytesWrittenPtr - Errno)); + EXPECT_TRUE(HostFuncImageToImage.run( + CallFrame, + std::initializer_list{ + InputPathPtr, // ImagePtr + static_cast(InputPath.size()), // ImageLen + SessionId, // SessionId + 64, // Width + 64, // Height + 0, // ControlImagePtr + 0, // ControlImageLen + PromptPtr, // PromptPtr + static_cast(PromptData2.size()), // PromptLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 20, // SampleSteps + 0.75f, // Strength + 42, // Seed + 1, // BatchCount + 0.9f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + OutputPathPtr, // OutputPathPtr + static_cast(OutputPath2.size()), // OutputPathLen + OutputPtr, // OutBufferPtr + 65532, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); EXPECT_GE(BytesWritten, 50); - std::ifstream Fin(OutputPathString2, std::ios::in | std::ios::binary); - EXPECT_FALSE(Fin.fail()); - Fin.close(); + EXPECT_TRUE(std::filesystem::exists(OutputPathString2)); } delete SBMod; } From ebfe1774af8038d050b8c973c06dea14afebc86d Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 25 Jul 2024 21:27:36 +0800 Subject: [PATCH 378/623] [Plugin] Turn off accelerate of GGML in MacOS 13.x or lower. Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 5 ++++- plugins/wasmedge_stablediffusion/CMakeLists.txt | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 3e8ba6c6..e5799f20 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -174,7 +174,10 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) target_include_directories(wasmedgePluginWasiNN PRIVATE ${piper_SOURCE_DIR}/src/cpp) target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) elseif(BACKEND STREQUAL "whisper") - set(WHISPER_NO_ACCELERATE ON CACHE INTERNAL "Whisper turn off accelerate") + if(APPLE AND CMAKE_SYSTEM_VERSION VERSION_LESS 23) + # `cblas_sgemm()` introduced in macOS 13.3. + set(WHISPER_NO_ACCELERATE ON CACHE INTERNAL "Stable diffusion turn off accelerate") + endif() set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "Whisper not build shared") include(FetchContent) FetchContent_Declare( diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 7dd00371..cc94f40a 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -11,6 +11,10 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(stable-diffusion) set_property(TARGET stable-diffusion PROPERTY POSITION_INDEPENDENT_CODE ON) +if(APPLE AND CMAKE_SYSTEM_VERSION VERSION_LESS 23) + # `cblas_sgemm()` introduced in macOS 13.3. + set(GGML_NO_ACCELERATE ON CACHE INTERNAL "Stable diffusion turn off accelerate") +endif() get_target_property(SD_DEPS stable-diffusion LINK_LIBRARIES) foreach(dep ${SD_DEPS}) if(TARGET ${dep}) From f33889ac1f389bee0cae80afdadf5846d8c3ccd7 Mon Sep 17 00:00:00 2001 From: PeterD1524 <53310459+PeterD1524@users.noreply.github.com> Date: Tue, 30 Jul 2024 15:55:21 +0800 Subject: [PATCH 379/623] [WASI-NN] piper: fix piper dependencies (#3583) * update piper patch to support static linking piper-phonemize and espeak-ng and support using system onnxruntime Signed-off-by: PeterD1524 * force BUILD_SHARED_LIBS OFF and CMAKE_POSITION_INDEPENDENT_CODE ON for piper Signed-off-by: PeterD1524 * avoid unwanted targets from piper Signed-off-by: PeterD1524 * add onnxruntime install script Signed-off-by: PeterD1524 * update piper patch to use PRIVATE scope for target-specific commands to avoid include and link pollution Signed-off-by: PeterD1524 * add clean include directory for piper disable tests for piper_phonemize use normal variable instead of CACHE for piper FetchContent Signed-off-by: PeterD1524 * remove unnecessary target_include_directories Signed-off-by: PeterD1524 * install onnxruntime in workflow Signed-off-by: PeterD1524 * remove the `-ubuntu` suffix because this should also be used in the manylinux(CentOS) distributions Signed-off-by: PeterD1524 * remove redundant lines will be done later by `wasmedge_setup_wasinn_target(wasmedgePluginWasiNN)` Signed-off-by: PeterD1524 * find onnxruntime early and fail if not found Signed-off-by: PeterD1524 --------- Signed-off-by: PeterD1524 --- plugins/wasi_nn/CMakeLists.txt | 3 - plugins/wasi_nn/piper.patch | 350 ++++++++++++++++++++++++--- utils/wasi-nn/install-onnxruntime.sh | 14 ++ 3 files changed, 337 insertions(+), 30 deletions(-) create mode 100644 utils/wasi-nn/install-onnxruntime.sh diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index e5799f20..d89583e7 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -169,9 +169,6 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) elseif(BACKEND STREQUAL "piper") wasmedge_setup_simdjson() - include(WASINNDeps) - wasmedge_setup_piper() - target_include_directories(wasmedgePluginWasiNN PRIVATE ${piper_SOURCE_DIR}/src/cpp) target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) elseif(BACKEND STREQUAL "whisper") if(APPLE AND CMAKE_SYSTEM_VERSION VERSION_LESS 23) diff --git a/plugins/wasi_nn/piper.patch b/plugins/wasi_nn/piper.patch index 56a42028..b5acad34 100644 --- a/plugins/wasi_nn/piper.patch +++ b/plugins/wasi_nn/piper.patch @@ -1,8 +1,17 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index f96ec44..a759c35 100644 +index f96ec44..ef67ff5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -13,11 +13,13 @@ if(MSVC) +@@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.13) + + project(piper C CXX) + ++option(BUILD_SHARED_LIBS "Build using shared libraries" OFF) ++ + file(READ "${CMAKE_CURRENT_LIST_DIR}/VERSION" piper_version) + + set(CMAKE_CXX_STANDARD 17) +@@ -13,11 +15,13 @@ if(MSVC) add_compile_options("$<$:/utf-8>") elseif(NOT APPLE) # Linux flags @@ -18,23 +27,33 @@ index f96ec44..a759c35 100644 add_executable(test_piper src/cpp/test.cpp src/cpp/piper.cpp) # NOTE: external project prefix are shortened because of path length restrictions on Windows -@@ -60,10 +62,14 @@ endif() - - if(NOT DEFINED PIPER_PHONEMIZE_DIR) - set(PIPER_PHONEMIZE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pi") -+ find_program(GIT_CMD git REQUIRED) - ExternalProject_Add( - piper_phonemize_external - PREFIX "${CMAKE_CURRENT_BINARY_DIR}/p" +@@ -58,59 +62,54 @@ endif() + + # ---- piper-phonemize --- + +-if(NOT DEFINED PIPER_PHONEMIZE_DIR) +- set(PIPER_PHONEMIZE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pi") +- ExternalProject_Add( +- piper_phonemize_external +- PREFIX "${CMAKE_CURRENT_BINARY_DIR}/p" - URL "https://github.com/rhasspy/piper-phonemize/archive/refs/heads/master.zip" -+ GIT_REPOSITORY "https://github.com/rhasspy/piper-phonemize.git" -+ GIT_TAG "bfc2e7549957829b0227c66a305d11cc88167bda" # master -+ UPDATE_DISCONNECTED TRUE -+ PATCH_COMMAND "${GIT_CMD}" "apply" "${CMAKE_CURRENT_SOURCE_DIR}/piper-phonemize.patch" - CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${PIPER_PHONEMIZE_DIR} - ) - add_dependencies(piper piper_phonemize_external) -@@ -74,7 +80,9 @@ endif() +- CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${PIPER_PHONEMIZE_DIR} +- ) +- add_dependencies(piper piper_phonemize_external) +- add_dependencies(test_piper piper_phonemize_external) +-endif() ++include(FetchContent) ++find_program(GIT_CMD git REQUIRED) ++FetchContent_Declare( ++ piper_phonemize ++ GIT_REPOSITORY "https://github.com/rhasspy/piper-phonemize.git" ++ GIT_TAG "bfc2e7549957829b0227c66a305d11cc88167bda" # master ++ UPDATE_DISCONNECTED TRUE ++ PATCH_COMMAND "${GIT_CMD}" "apply" "${CMAKE_CURRENT_SOURCE_DIR}/piper-phonemize.patch" ++) ++FetchContent_MakeAvailable(piper_phonemize) + + # ---- Declare executable ---- if((NOT MSVC) AND (NOT APPLE)) # Linux flags @@ -43,12 +62,43 @@ index f96ec44..a759c35 100644 + list(APPEND CMAKE_BUILD_RPATH "$ORIGIN") + list(APPEND CMAKE_INSTALL_RPATH "$ORIGIN") string(APPEND CMAKE_C_FLAGS " -Wall -Wextra") - target_link_libraries(piper -static-libgcc -static-libstdc++) +- target_link_libraries(piper -static-libgcc -static-libstdc++) ++ target_link_libraries(piper PRIVATE -static-libgcc -static-libstdc++) -@@ -104,14 +112,6 @@ target_include_directories(piper PUBLIC + set(PIPER_EXTRA_LIBRARIES "pthread") + endif() + +-target_link_libraries(piper ++target_link_libraries(piper PRIVATE + fmt + spdlog + espeak-ng +- piper_phonemize + onnxruntime + ${PIPER_EXTRA_LIBRARIES} ++ PUBLIC piper_phonemize + ) + +-target_link_directories(piper PUBLIC ++target_link_directories(piper PRIVATE + ${FMT_DIR}/lib + ${SPDLOG_DIR}/lib +- ${PIPER_PHONEMIZE_DIR}/lib + ) - target_compile_definitions(piper PUBLIC _PIPER_VERSION=${piper_version}) +-target_include_directories(piper PUBLIC ++set(PIPER_INTERFACE_INCLUDE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/include") ++file(COPY src/cpp/piper.hpp src/cpp/json.hpp DESTINATION "${PIPER_INTERFACE_INCLUDE_DIRECTORY}") ++ ++target_include_directories(piper PRIVATE + ${FMT_DIR}/include + ${SPDLOG_DIR}/include +- ${PIPER_PHONEMIZE_DIR}/include ++ INTERFACE "${PIPER_INTERFACE_INCLUDE_DIRECTORY}" + ) +-target_compile_definitions(piper PUBLIC _PIPER_VERSION=${piper_version}) +- -# ---- Declare test ---- -include(CTest) -enable_testing() @@ -56,10 +106,58 @@ index f96ec44..a759c35 100644 - NAME test_piper - COMMAND test_piper "${CMAKE_SOURCE_DIR}/etc/test_voice.onnx" "${PIPER_PHONEMIZE_DIR}/share/espeak-ng-data" "${CMAKE_CURRENT_BINARY_DIR}/test.wav" -) -- ++target_compile_definitions(piper PRIVATE _PIPER_VERSION=${piper_version}) + target_compile_features(test_piper PUBLIC cxx_std_17) - target_include_directories( +@@ -118,14 +117,12 @@ target_include_directories( + test_piper PUBLIC + ${FMT_DIR}/include + ${SPDLOG_DIR}/include +- ${PIPER_PHONEMIZE_DIR}/include + ) + + target_link_directories( + test_piper PUBLIC + ${FMT_DIR}/lib + ${SPDLOG_DIR}/lib +- ${PIPER_PHONEMIZE_DIR}/lib + ) + + target_link_libraries(test_piper PUBLIC +@@ -141,32 +138,3 @@ target_link_libraries(test_piper PUBLIC + install( + TARGETS piper + DESTINATION ${CMAKE_INSTALL_PREFIX}) +- +-# Dependencies +-install( +- DIRECTORY ${PIPER_PHONEMIZE_DIR}/bin/ +- DESTINATION ${CMAKE_INSTALL_PREFIX} +- USE_SOURCE_PERMISSIONS # keep +x +- FILES_MATCHING +- PATTERN "piper_phonemize" +- PATTERN "espeak-ng" +- PATTERN "*.dll" +-) +- +-install( +- DIRECTORY ${PIPER_PHONEMIZE_DIR}/lib/ +- DESTINATION ${CMAKE_INSTALL_PREFIX} +- FILES_MATCHING +- PATTERN "*.dll" +- PATTERN "*.so*" +-) +- +-install( +- DIRECTORY ${PIPER_PHONEMIZE_DIR}/share/espeak-ng-data +- DESTINATION ${CMAKE_INSTALL_PREFIX} +-) +- +-install( +- FILES ${PIPER_PHONEMIZE_DIR}/share/libtashkeel_model.ort +- DESTINATION ${CMAKE_INSTALL_PREFIX} +-) diff --git a/VERSION b/VERSION index 26aaba0..867e524 100644 --- a/VERSION @@ -70,15 +168,24 @@ index 26aaba0..867e524 100644 \ No newline at end of file diff --git a/piper-phonemize.patch b/piper-phonemize.patch new file mode 100644 -index 0000000..f8ca06f +index 0000000..0d91cde --- /dev/null +++ b/piper-phonemize.patch -@@ -0,0 +1,15 @@ +@@ -0,0 +1,213 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt -+index ec7b501..34cf7b1 100644 ++index ec7b501..335ef46 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt -+@@ -17,7 +17,9 @@ if(MSVC) ++@@ -10,6 +10,8 @@ project( ++ LANGUAGES CXX ++ ) ++ +++option(BUILD_SHARED_LIBS "Build using shared libraries" ON) +++ ++ if(MSVC) ++ # Force compiler to use UTF-8 for IPA constants ++ add_compile_options("$<$:/utf-8>") ++@@ -17,12 +19,14 @@ if(MSVC) + + elseif(NOT APPLE) + # Linux flags @@ -89,3 +196,192 @@ index 0000000..f8ca06f + string(APPEND CMAKE_C_FLAGS " -Wall -Wextra") + endif() + ++ add_library( ++- piper_phonemize SHARED +++ piper_phonemize ++ src/phonemize.cpp ++ src/phoneme_ids.cpp ++ src/tashkeel.cpp ++@@ -36,12 +40,33 @@ set_target_properties(piper_phonemize PROPERTIES ++ ++ # ---- onnxruntime --- ++ ++-# Look for onnxruntime files in /lib ++-if(NOT DEFINED ONNXRUNTIME_DIR) ++- if(NOT DEFINED ONNXRUNTIME_VERSION) ++- set(ONNXRUNTIME_VERSION "1.14.1") +++set(onnxruntime_FOUND FALSE) +++ +++if(NOT DEFINED ONNXRUNTIME_VERSION) +++ set(ONNXRUNTIME_VERSION "1.14.1") +++endif() +++ +++if(NOT onnxruntime_FOUND AND NOT DEFINED ONNXRUNTIME_DIR) +++ find_package(onnxruntime "${ONNXRUNTIME_VERSION}") +++ if(onnxruntime_FOUND) +++ list(APPEND ONNXRUNTIME_LINK_LIBRARIES "onnxruntime::onnxruntime") ++ endif() +++endif() +++ +++if(NOT onnxruntime_FOUND AND NOT DEFINED ONNXRUNTIME_DIR) +++ find_library(ONNXRUNTIME_LIBRARY onnxruntime) +++ if(NOT "${ONNXRUNTIME_LIBRARY}" STREQUAL "ONNXRUNTIME_LIBRARY-NOTFOUND") +++ find_path(ONNXRUNTIME_PATH "onnxruntime_cxx_api.h" PATH_SUFFIXES "onnxruntime") +++ if(NOT "${ONNXRUNTIME_PATH}" STREQUAL "ONNXRUNTIME_PATH-NOTFOUND") +++ list(APPEND ONNXRUNTIME_INCLUDE_DIRECTORIES "${ONNXRUNTIME_PATH}") +++ list(APPEND ONNXRUNTIME_LINK_LIBRARIES "${ONNXRUNTIME_LIBRARY}") +++ set(onnxruntime_FOUND TRUE) +++ endif() +++ endif() +++endif() ++ +++# Look for onnxruntime files in /lib +++if(NOT onnxruntime_FOUND AND NOT DEFINED ONNXRUNTIME_DIR) ++ if(WIN32) ++ # Windows x86-64 ++ set(ONNXRUNTIME_PREFIX "onnxruntime-win-x64-${ONNXRUNTIME_VERSION}") ++@@ -95,19 +120,31 @@ if(NOT DEFINED ONNXRUNTIME_DIR) ++ endif() ++ endif() ++ +++if(NOT onnxruntime_FOUND AND DEFINED ONNXRUNTIME_DIR) +++ list(APPEND ONNXRUNTIME_INCLUDE_DIRECTORIES "${ONNXRUNTIME_DIR}/include") +++ list(APPEND ONNXRUNTIME_LINK_DIRECTORIES "${ONNXRUNTIME_DIR}/lib") +++ list(APPEND ONNXRUNTIME_LINK_LIBRARIES "onnxruntime") +++ set(onnxruntime_FOUND TRUE) +++endif() +++ ++ # ---- espeak-ng --- ++ ++ if(NOT DEFINED ESPEAK_NG_DIR) ++ set(ESPEAK_NG_DIR "${CMAKE_CURRENT_BINARY_DIR}/ei") ++ +++ find_program(GIT_PROGRAM "git" REQUIRED) ++ include(ExternalProject) ++ ExternalProject_Add( ++ espeak_ng_external ++ PREFIX "${CMAKE_CURRENT_BINARY_DIR}/e" ++- URL "https://github.com/rhasspy/espeak-ng/archive/0f65aa301e0d6bae5e172cc74197d32a6182200f.zip" +++ GIT_REPOSITORY "https://github.com/rhasspy/espeak-ng" +++ GIT_TAG "0f65aa301e0d6bae5e172cc74197d32a6182200f" +++ GIT_PROGRESS TRUE +++ UPDATE_DISCONNECTED TRUE +++ PATCH_COMMAND "${GIT_PROGRAM}" "apply" "${CMAKE_CURRENT_SOURCE_DIR}/espeak-ng.patch" ++ CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ESPEAK_NG_DIR} ++ CMAKE_ARGS -DUSE_ASYNC:BOOL=OFF ++- CMAKE_ARGS -DBUILD_SHARED_LIBS:BOOL=ON +++ CMAKE_ARGS "-DBUILD_SHARED_LIBS:BOOL=${BUILD_SHARED_LIBS}" ++ CMAKE_ARGS -DUSE_MBROLA:BOOL=OFF ++ CMAKE_ARGS -DUSE_LIBSONIC:BOOL=OFF ++ CMAKE_ARGS -DUSE_LIBPCAUDIO:BOOL=OFF ++@@ -116,6 +153,7 @@ if(NOT DEFINED ESPEAK_NG_DIR) ++ CMAKE_ARGS -DEXTRA_cmn:BOOL=ON ++ CMAKE_ARGS -DEXTRA_ru:BOOL=ON ++ CMAKE_ARGS -DCMAKE_C_FLAGS="-D_FILE_OFFSET_BITS=64" +++ USES_TERMINAL_DOWNLOAD TRUE ++ ) ++ add_dependencies(piper_phonemize espeak_ng_external) ++ endif() ++@@ -123,23 +161,27 @@ endif() ++ ++ # ---- Declare library ---- ++ +++set(PIPER_PHONEMIZE_INTERFACE_INCLUDE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/include") +++file(COPY "src/" DESTINATION "${PIPER_PHONEMIZE_INTERFACE_INCLUDE_DIRECTORY}/piper-phonemize" FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp") +++ ++ target_include_directories( ++ piper_phonemize PUBLIC ++ "$" ++ ${ESPEAK_NG_DIR}/include ++- ${ONNXRUNTIME_DIR}/include +++ ${ONNXRUNTIME_INCLUDE_DIRECTORIES} +++ INTERFACE "${PIPER_PHONEMIZE_INTERFACE_INCLUDE_DIRECTORY}" ++ ) ++ ++ target_link_directories( ++ piper_phonemize PUBLIC ++ ${ESPEAK_NG_DIR}/lib ++- ${ONNXRUNTIME_DIR}/lib +++ ${ONNXRUNTIME_LINK_DIRECTORIES} ++ ) ++ ++ target_link_libraries( ++ piper_phonemize ++ espeak-ng ++- onnxruntime +++ ${ONNXRUNTIME_LINK_LIBRARIES} ++ ) ++ ++ target_compile_features(piper_phonemize PUBLIC cxx_std_17) ++@@ -173,12 +215,13 @@ target_link_libraries(piper_phonemize_exe PUBLIC ++ # ---- Declare test ---- ++ ++ include(CTest) ++-enable_testing() ++ add_executable(test_piper_phonemize src/test.cpp src/phoneme_ids.cpp) ++-add_test( ++- NAME test_piper_phonemize ++- COMMAND test_piper_phonemize "${ESPEAK_NG_DIR}/share/espeak-ng-data" "${CMAKE_SOURCE_DIR}/etc/libtashkeel_model.ort" ++-) +++if(BUILD_TESTING) +++ add_test( +++ NAME test_piper_phonemize +++ COMMAND test_piper_phonemize "${ESPEAK_NG_DIR}/share/espeak-ng-data" "${CMAKE_SOURCE_DIR}/etc/libtashkeel_model.ort" +++ ) +++endif() ++ ++ target_compile_features(test_piper_phonemize PUBLIC cxx_std_17) ++ ++@@ -207,7 +250,7 @@ install( ++ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) ++ ++ install( ++- DIRECTORY ${CMAKE_SOURCE_DIR}/src/ +++ DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/src/" ++ DESTINATION include/piper-phonemize ++ FILES_MATCHING ++ PATTERN "*.h" ++@@ -218,7 +261,7 @@ install( ++ ARCHIVE DESTINATION ${CMAKE_INSTALL_BINDIR}) ++ ++ install( ++- FILES ${CMAKE_SOURCE_DIR}/etc/libtashkeel_model.ort +++ FILES "${CMAKE_CURRENT_SOURCE_DIR}/etc/libtashkeel_model.ort" ++ TYPE DATA) ++ ++ # Dependencies ++@@ -226,10 +269,12 @@ install( ++ DIRECTORY ${ESPEAK_NG_DIR}/ ++ DESTINATION ${CMAKE_INSTALL_PREFIX}) ++ ++-install( ++- DIRECTORY ${ONNXRUNTIME_DIR}/include/ ++- DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) +++if(DEFINED ONNXRUNTIME_DIR) +++ install( +++ DIRECTORY ${ONNXRUNTIME_DIR}/include/ +++ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) ++ ++-install( ++- DIRECTORY ${ONNXRUNTIME_DIR}/lib/ ++- DESTINATION ${CMAKE_INSTALL_LIBDIR}) +++ install( +++ DIRECTORY ${ONNXRUNTIME_DIR}/lib/ +++ DESTINATION ${CMAKE_INSTALL_LIBDIR}) +++endif() ++diff --git a/espeak-ng.patch b/espeak-ng.patch ++new file mode 100644 ++index 0000000..a51d146 ++--- /dev/null +++++ b/espeak-ng.patch ++@@ -0,0 +1,10 @@ +++diff --git a/src/ucd-tools/CMakeLists.txt b/src/ucd-tools/CMakeLists.txt +++index 2050c114..4bd7d17e 100644 +++--- a/src/ucd-tools/CMakeLists.txt ++++++ b/src/ucd-tools/CMakeLists.txt +++@@ -1,4 +1,4 @@ +++-add_library(ucd STATIC ++++add_library(ucd OBJECT +++ src/case.c +++ src/categories.c +++ src/ctype.c diff --git a/utils/wasi-nn/install-onnxruntime.sh b/utils/wasi-nn/install-onnxruntime.sh new file mode 100644 index 00000000..61c2acb7 --- /dev/null +++ b/utils/wasi-nn/install-onnxruntime.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +set -e + +: ${ONNXRUNTIME_VERSION:=1.14.1} + +ONNXRUNTIME_NAME="onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}" +ONNXRUNTIME_TGZ="${ONNXRUNTIME_NAME}.tgz" + +curl -LO "https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/${ONNXRUNTIME_TGZ}" +tar zxf "${ONNXRUNTIME_TGZ}" +mv "${ONNXRUNTIME_NAME}/include/"* /usr/local/include/ +mv "${ONNXRUNTIME_NAME}/lib/"* /usr/local/lib/ +rm -rf "${ONNXRUNTIME_TGZ}" "${ONNXRUNTIME_NAME}" From 7cb630dedc0ecdbb4ccf97d0def9059313921428 Mon Sep 17 00:00:00 2001 From: grorge Date: Sat, 6 Jul 2024 15:26:13 +0800 Subject: [PATCH 380/623] [WASI-NN] ChatTTS: Ceate backend Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 3 +- plugins/wasi_nn/chattts.cpp | 0 plugins/wasi_nn/chattts.h | 59 ++++++++++++++++++++++++++++++++++ plugins/wasi_nn/types.h | 4 ++- plugins/wasi_nn/wasinnenv.cpp | 3 +- plugins/wasi_nn/wasinnenv.h | 1 + plugins/wasi_nn/wasinnfunc.cpp | 4 ++- 7 files changed, 70 insertions(+), 4 deletions(-) create mode 100644 plugins/wasi_nn/chattts.cpp create mode 100644 plugins/wasi_nn/chattts.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index d89583e7..cdd1b332 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -15,6 +15,7 @@ wasmedge_add_library(wasmedgePluginWasiNN neuralspeed.cpp piper.cpp whispercpp.cpp + chattts.cpp ) foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) @@ -152,7 +153,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml/src/ggml-common.h ggml-common.h ) endif() - elseif(BACKEND STREQUAL "neuralspeed") + elseif(BACKEND STREQUAL "neuralspeed" OR BACKEND STREQUAL "chattts") wasmedge_setup_simdjson() find_package(Python3 COMPONENTS Interpreter Development) diff --git a/plugins/wasi_nn/chattts.cpp b/plugins/wasi_nn/chattts.cpp new file mode 100644 index 00000000..e69de29b diff --git a/plugins/wasi_nn/chattts.h b/plugins/wasi_nn/chattts.h new file mode 100644 index 00000000..ef7f387f --- /dev/null +++ b/plugins/wasi_nn/chattts.h @@ -0,0 +1,59 @@ +#pragma once + +#include "plugin/plugin.h" +#include "types.h" +#include +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS +#include +#endif +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::ChatTTS { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS +struct Graph { + bool EnableDebugLog = true; + Graph() noexcept { Py_Initialize(); } + ~Graph() noexcept { + if (Py_IsInitialized()) { + Py_XDECREF(Model); + Py_XDECREF(ChatTTSModule); + } + } + PyObject *Model; + PyObject *ChatTTSModule; +}; +struct Context { + Context(size_t Gid, Graph &) noexcept : GraphId(Gid) {} + size_t GraphId; + std::vector Inputs; + std::vector Outputs; +}; +#else +struct Graph {}; +struct Context { + Context(size_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; +} // namespace WasmEdge::Host::WASINN::ChatTTS \ No newline at end of file diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index 271af483..1c2dad70 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -39,6 +39,7 @@ enum class Backend : uint8_t { NeuralSpeed = 7, Whisper = 9, Piper = 11, + ChatTTS = 8, }; #define FOR_EACH_BACKEND(F) \ @@ -50,7 +51,8 @@ enum class Backend : uint8_t { F(GGML) \ F(NeuralSpeed) \ F(Whisper) \ - F(Piper) + F(Piper) \ + F(ChatTTS) struct TensorData { Span Dimension; diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 89994c5f..b33f8255 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -32,7 +32,8 @@ std::map BackendMap = { {"ggml"sv, Backend::GGML}, {"neuralspeed"sv, Backend::NeuralSpeed}, {"whisper"sv, Backend::Whisper}, - {"piper"sv, Backend::Piper}}; + {"piper"sv, Backend::Piper}, + {"chattts"sv, Backend::ChatTTS}}; std::map DeviceMap = {{"cpu"sv, Device::CPU}, {"gpu"sv, Device::GPU}, diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 7e38cb80..7f67b819 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -11,6 +11,7 @@ #include "ggml.h" #include "neuralspeed.h" +#include "chattts.h" #include "onnx.h" #include "openvino.h" #include "piper.h" diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index aa16fbc3..692028da 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -600,9 +600,11 @@ Expect WasiNNUnload::bodyImpl(const Runtime::CallingFrame &Frame, return WASINN::GGML::unload(Env, GraphId); case WASINN::Backend::NeuralSpeed: return WASINN::NeuralSpeed::unload(Env, GraphId); + case WASINN::Backend::ChatTTS: + return WASINN::ChatTTS::unload(Env, GraphId); default: spdlog::error( - "[WASI-NN] unlaod: Only GGML and Neural speed backend supports unload."sv); + "[WASI-NN] unlaod: Only GGML, Neural speed, and ChatTTS backend supports unload."sv); return WASINN::ErrNo::InvalidArgument; } } From d9b72ab9cb033680dcaff16c6666bb574e88a7b8 Mon Sep 17 00:00:00 2001 From: grorge Date: Sat, 6 Jul 2024 17:05:31 +0800 Subject: [PATCH 381/623] [WASI-NN] ChatTTS: Implement function Signed-off-by: grorge --- plugins/wasi_nn/chattts.cpp | 248 ++++++++++++++++++++++++++++++++++++ plugins/wasi_nn/chattts.h | 8 +- 2 files changed, 252 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/chattts.cpp b/plugins/wasi_nn/chattts.cpp index e69de29b..754dfa2c 100644 --- a/plugins/wasi_nn/chattts.cpp +++ b/plugins/wasi_nn/chattts.cpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2022 Second State INC + +#include "chattts.h" +#include "wasinnenv.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS +#include "simdjson.h" + +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__WIN32__) && \ + !defined(__TOS_WIN__) && !defined(__WINDOWS__) +#include +#endif +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::ChatTTS { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS +#if defined(_WIN32) || defined(_WIN64) || defined(__WIN32__) || \ + defined(__TOS_WIN__) || defined(__WINDOWS__) +HINSTANCE SharedLib = LoadLibrary(PYTHON_LIB_PATH); +#else +void *SharedLib = dlopen(PYTHON_LIB_PATH, RTLD_GLOBAL | RTLD_NOW); +#endif +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, WASINN::Device, + uint32_t &GraphId) noexcept { + // Add a new graph. + Env.NNGraph.emplace_back(Backend::NeuralSpeed); + auto &GraphRef = Env.NNGraph.back().get(); + // Initialize the plugin parameters. + GraphRef.EnableDebugLog = false; + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] ChatTTS backend: Load."sv); + } + + if (Builders.size() > 1) { + std::string Metadata = std::string( + reinterpret_cast(Builders[1].data()), Builders[1].size()); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + spdlog::error("[WASI-NN] ChatTTS backend: Parse metadata error"sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidEncoding; + } + // TODO handle metadata + } + + // Create Model class + if (!Py_IsInitialized()) { + Py_Initialize(); + } + GraphRef.ChatTTSModule = PyImport_Import(PyUnicode_FromString("ChatTTS")); + if (GraphRef.ChatTTSModule == nullptr) { + PyErr_Print(); + spdlog::error("[WASI-NN] ChatTTS backend: Can not find ChatTTS library."sv); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::RuntimeError; + } + PyObject *ChatFunction = + PyObject_GetAttrString(GraphRef.ChatTTSModule, "Chat"); + if (ChatFunction && PyCallable_Check(ChatFunction)) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Can not find Chat class in ChatTTS."sv); + Py_XDECREF(GraphRef.ChatTTSModule); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::RuntimeError; + } + GraphRef.Chat = PyObject_CallObject(ChatFunction, nullptr); + Py_XDECREF(ChatFunction); + if (GraphRef.Chat == nullptr) { + spdlog::error("[WASI-NN] ChatTTS backend: Can not create chat."sv); + Py_XDECREF(GraphRef.ChatTTSModule); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::RuntimeError; + } + PyObject *LoadMethod = PyObject_GetAttrString(GraphRef.Chat, "load"); + if (LoadMethod && PyCallable_Check(LoadMethod)) { + PyObject *Args = PyTuple_Pack(1, Py_False); + PyObject *Value = PyObject_CallObject(LoadMethod, Args); + Py_XDECREF(Value); + Py_DECREF(Args); + Py_DECREF(LoadMethod); + } else { + spdlog::error("[WASI-NN] ChatTTS backend: Can not load chat."sv); + PyErr_Print(); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::RuntimeError; + } + // Store the loaded graph. + GraphId = Env.NNGraph.size() - 1; + + return WASINN::ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + if (!Py_IsInitialized()) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Model has been realse, please reload it."sv); + return WASINN::ErrNo::RuntimeError; + } + Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); + ContextId = Env.NNContext.size() - 1; + return ErrNo::Success; +} + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t, const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (!Py_IsInitialized()) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Model has been realse, please reload it."sv); + return WASINN::ErrNo::RuntimeError; + } + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] ChatTTS backend: setInput"sv); + } + + // Set the input. + std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + CxtRef.Inputs.clear(); + CxtRef.Inputs = Prompt; + + return WASINN::ErrNo::Success; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] ChatTTS backend: getOutput"sv); + } + std::string StringTmp(reinterpret_cast(CxtRef.Outputs.data()), + CxtRef.Outputs.size() * sizeof(long long int)); + std::copy_n(StringTmp.data(), StringTmp.length(), OutBuffer.data()); + BytesWritten = StringTmp.length(); + return WASINN::ErrNo::Success; +} +Expect compute(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + if (!Py_IsInitialized()) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Model has been realse, please reload it."sv); + return WASINN::ErrNo::RuntimeError; + } + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] ChatTTS backend: compute"sv); + } + if (CxtRef.Inputs.size() == 0) { + spdlog::error("[WASI-NN] ChatTTS backend: Input is not set!"sv); + return ErrNo::InvalidArgument; + } + PyObject *InputStr = PyUnicode_FromString(CxtRef.Inputs.c_str()); + PyObject *InferMethod = PyObject_GetAttrString(GraphRef.Chat, "infer"); + PyObject *Result = nullptr; + if (InferMethod && PyCallable_Check(InferMethod)) { + PyObject *Args = PyTuple_Pack(1, InputStr); + Result = PyObject_CallObject(InferMethod, Args); + Py_XDECREF(Args); + } else { + spdlog::error( + "[WASI-NN] ChatTTS backend: Can not find infer method in Chat."sv); + PyErr_Print(); + Py_XDECREF(InferMethod); + return WASINN::ErrNo::RuntimeError; + } + if (Result != nullptr) { + PyObject *Wav0 = PyList_GetItem(Result, 0); + PyObject *BytesObj = PyObject_CallMethod(Wav0, "tobytes", nullptr); + char *Bytes = PyBytes_AsString(BytesObj); + Py_ssize_t size = PyBytes_Size(BytesObj); + CxtRef.Outputs = std::vector(Bytes, Bytes + size); + Py_DECREF(BytesObj); + Py_DECREF(Wav0); + } else { + spdlog::error( + "[WASI-NN] ChatTTS backend: Can not get output from infer method."sv); + PyErr_Print(); + Py_XDECREF(Result); + return WASINN::ErrNo::RuntimeError; + } + Py_XDECREF(Result); + Py_XDECREF(InputStr); + Py_XDECREF(InferMethod); + return WASINN::ErrNo::Success; +} + +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] Neural speed backend: start unload."sv); + } + if (Py_IsInitialized()) { + Py_XDECREF(GraphRef.Chat); + Py_XDECREF(GraphRef.ChatTTSModule); + Py_Finalize(); + } + return WASINN::ErrNo::Success; +} + +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error( + "[WASI-NN] Neural speed backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"NeuralSpeed\" to build it."sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect unload(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif + +} // namespace WasmEdge::Host::WASINN::ChatTTS \ No newline at end of file diff --git a/plugins/wasi_nn/chattts.h b/plugins/wasi_nn/chattts.h index ef7f387f..acf11752 100644 --- a/plugins/wasi_nn/chattts.h +++ b/plugins/wasi_nn/chattts.h @@ -17,18 +17,18 @@ struct Graph { Graph() noexcept { Py_Initialize(); } ~Graph() noexcept { if (Py_IsInitialized()) { - Py_XDECREF(Model); + Py_XDECREF(Chat); Py_XDECREF(ChatTTSModule); } } - PyObject *Model; + PyObject *Chat; PyObject *ChatTTSModule; }; struct Context { Context(size_t Gid, Graph &) noexcept : GraphId(Gid) {} size_t GraphId; - std::vector Inputs; - std::vector Outputs; + std::string Inputs; + std::vector Outputs; }; #else struct Graph {}; From 42d2603654ee2bdd0f557e40a68e91ab587fa8cf Mon Sep 17 00:00:00 2001 From: grorge Date: Mon, 8 Jul 2024 17:57:17 +0800 Subject: [PATCH 382/623] [WASI-NN] ChatTTS: add basic test Signed-off-by: grorge --- plugins/wasi_nn/chattts.cpp | 63 +++--- plugins/wasi_nn/wasinnenv.h | 2 +- test/plugins/wasi_nn/wasi_nn.cpp | 345 +++++++++++++++++++++++++------ 3 files changed, 314 insertions(+), 96 deletions(-) diff --git a/plugins/wasi_nn/chattts.cpp b/plugins/wasi_nn/chattts.cpp index 754dfa2c..0ae1b1ff 100644 --- a/plugins/wasi_nn/chattts.cpp +++ b/plugins/wasi_nn/chattts.cpp @@ -24,30 +24,30 @@ HINSTANCE SharedLib = LoadLibrary(PYTHON_LIB_PATH); void *SharedLib = dlopen(PYTHON_LIB_PATH, RTLD_GLOBAL | RTLD_NOW); #endif Expect load(WASINN::WasiNNEnvironment &Env, - Span> Builders, WASINN::Device, + Span>, WASINN::Device, uint32_t &GraphId) noexcept { // Add a new graph. - Env.NNGraph.emplace_back(Backend::NeuralSpeed); + Env.NNGraph.emplace_back(Backend::ChatTTS); auto &GraphRef = Env.NNGraph.back().get(); // Initialize the plugin parameters. - GraphRef.EnableDebugLog = false; + GraphRef.EnableDebugLog = true; if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN] ChatTTS backend: Load."sv); } - if (Builders.size() > 1) { - std::string Metadata = std::string( - reinterpret_cast(Builders[1].data()), Builders[1].size()); - simdjson::dom::parser Parser; - simdjson::dom::element Doc; - auto ParseError = Parser.parse(Metadata).get(Doc); - if (ParseError) { - spdlog::error("[WASI-NN] ChatTTS backend: Parse metadata error"sv); - Env.NNGraph.pop_back(); - return ErrNo::InvalidEncoding; - } - // TODO handle metadata - } + // if (Builders.size() > 1) { + // std::string Metadata = std::string( + // reinterpret_cast(Builders[1].data()), Builders[1].size()); + // simdjson::dom::parser Parser; + // simdjson::dom::element Doc; + // auto ParseError = Parser.parse(Metadata).get(Doc); + // if (ParseError) { + // spdlog::error("[WASI-NN] ChatTTS backend: Parse metadata error"sv); + // Env.NNGraph.pop_back(); + // return ErrNo::InvalidEncoding; + // } + // // TODO handle metadata + // } // Create Model class if (!Py_IsInitialized()) { @@ -62,7 +62,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, } PyObject *ChatFunction = PyObject_GetAttrString(GraphRef.ChatTTSModule, "Chat"); - if (ChatFunction && PyCallable_Check(ChatFunction)) { + if (ChatFunction == nullptr || !PyCallable_Check(ChatFunction)) { spdlog::error( "[WASI-NN] ChatTTS backend: Can not find Chat class in ChatTTS."sv); Py_XDECREF(GraphRef.ChatTTSModule); @@ -78,18 +78,15 @@ Expect load(WASINN::WasiNNEnvironment &Env, return WASINN::ErrNo::RuntimeError; } PyObject *LoadMethod = PyObject_GetAttrString(GraphRef.Chat, "load"); - if (LoadMethod && PyCallable_Check(LoadMethod)) { - PyObject *Args = PyTuple_Pack(1, Py_False); - PyObject *Value = PyObject_CallObject(LoadMethod, Args); - Py_XDECREF(Value); - Py_DECREF(Args); - Py_DECREF(LoadMethod); - } else { + if (LoadMethod == nullptr || !PyCallable_Check(LoadMethod)) { spdlog::error("[WASI-NN] ChatTTS backend: Can not load chat."sv); PyErr_Print(); Env.NNGraph.pop_back(); return WASINN::ErrNo::RuntimeError; } + PyObject *Value = PyObject_CallObject(LoadMethod, nullptr); + Py_XDECREF(Value); + Py_DECREF(LoadMethod); // Store the loaded graph. GraphId = Env.NNGraph.size() - 1; @@ -163,17 +160,16 @@ Expect compute(WasiNNEnvironment &Env, PyObject *InputStr = PyUnicode_FromString(CxtRef.Inputs.c_str()); PyObject *InferMethod = PyObject_GetAttrString(GraphRef.Chat, "infer"); PyObject *Result = nullptr; - if (InferMethod && PyCallable_Check(InferMethod)) { - PyObject *Args = PyTuple_Pack(1, InputStr); - Result = PyObject_CallObject(InferMethod, Args); - Py_XDECREF(Args); - } else { + if (InferMethod == nullptr || !PyCallable_Check(InferMethod)) { spdlog::error( "[WASI-NN] ChatTTS backend: Can not find infer method in Chat."sv); PyErr_Print(); Py_XDECREF(InferMethod); return WASINN::ErrNo::RuntimeError; } + PyObject *Args = PyTuple_Pack(1, InputStr); + Result = PyObject_CallObject(InferMethod, Args); + Py_XDECREF(Args); if (Result != nullptr) { PyObject *Wav0 = PyList_GetItem(Result, 0); PyObject *BytesObj = PyObject_CallMethod(Wav0, "tobytes", nullptr); @@ -203,8 +199,8 @@ Expect unload(WASINN::WasiNNEnvironment &Env, spdlog::info("[WASI-NN] Neural speed backend: start unload."sv); } if (Py_IsInitialized()) { - Py_XDECREF(GraphRef.Chat); - Py_XDECREF(GraphRef.ChatTTSModule); + Py_XDECREF(GraphRef.Chat); + Py_XDECREF(GraphRef.ChatTTSModule); Py_Finalize(); } return WASINN::ErrNo::Success; @@ -213,9 +209,8 @@ Expect unload(WASINN::WasiNNEnvironment &Env, #else namespace { Expect reportBackendNotSupported() noexcept { - spdlog::error( - "[WASI-NN] Neural speed backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"NeuralSpeed\" to build it."sv); + spdlog::error("[WASI-NN] ChatTTS backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"ChatTTS\" to build it."sv); return WASINN::ErrNo::InvalidArgument; } } // namespace diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 7f67b819..da952fe2 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -9,9 +9,9 @@ #include #include +#include "chattts.h" #include "ggml.h" #include "neuralspeed.h" -#include "chattts.h" #include "onnx.h" #include "openvino.h" #include "piper.h" diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 9cb61300..ebed49b1 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -25,75 +25,81 @@ using WasmEdge::Host::WASINN::ErrNo; defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER) || \ - defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER) -namespace { -WasmEdge::Runtime::Instance::ModuleInstance * -createModule(std::string_view NNRPCURI = "") { - using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load( - std::filesystem::u8path("../../../plugins/wasi_nn/" WASMEDGE_LIB_PREFIX - "wasmedgePluginWasiNN" WASMEDGE_LIB_EXTENSION)); - if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_nn"sv)) { - WasmEdge::PO::ArgumentParser Parser; - Plugin->registerOptions(Parser); - if (NNRPCURI != "") { - Parser.set_raw_value("nn-rpc-uri"sv, std::string(NNRPCURI)); +<<<<<<< HEAD +defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER) +======= + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) +>>>>>>> 6769e4a7 ([WASI-NN] ChatTTS: add basic test) + namespace { + WasmEdge::Runtime::Instance::ModuleInstance *createModule( + std::string_view NNRPCURI = "") { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load( + std::filesystem::u8path("../../../plugins/wasi_nn/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasiNN" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_nn"sv)) { + WasmEdge::PO::ArgumentParser Parser; + Plugin->registerOptions(Parser); + if (NNRPCURI != "") { + Parser.set_raw_value("nn-rpc-uri"sv, + std::string(NNRPCURI)); + } + if (const auto *Module = Plugin->findModule("wasi_nn"sv)) { + return Module->create().release(); + } } - if (const auto *Module = Plugin->findModule("wasi_nn"sv)) { - return Module->create().release(); + return nullptr; + } +#if !defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) + inline std::vector readEntireFile(const std::string &Path) { + std::ifstream Fin(Path, std::ios::in | std::ios::binary | std::ios::ate); + if (!Fin) { + return {}; + } + std::vector Buf(static_cast(Fin.tellg())); + Fin.seekg(0, std::ios::beg); + if (!Fin.read(reinterpret_cast(Buf.data()), + static_cast(Buf.size()))) { + return {}; } + Fin.close(); + return Buf; } - return nullptr; -} +#endif -inline std::vector readEntireFile(const std::string &Path) { - std::ifstream Fin(Path, std::ios::in | std::ios::binary | std::ios::ate); - if (!Fin) { - return {}; + template + void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance & MemInst, + WasmEdge::Span Binaries, uint32_t Ptr) noexcept { + std::copy(Binaries.begin(), Binaries.end(), MemInst.getPointer(Ptr)); } - std::vector Buf(static_cast(Fin.tellg())); - Fin.seekg(0, std::ios::beg); - if (!Fin.read(reinterpret_cast(Buf.data()), - static_cast(Buf.size()))) { - return {}; - } - Fin.close(); - return Buf; -} - -template -void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, - WasmEdge::Span Binaries, uint32_t Ptr) noexcept { - std::copy(Binaries.begin(), Binaries.end(), MemInst.getPointer(Ptr)); -} -void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, - uint32_t Value, uint32_t &Ptr) { - uint32_t *BufPtr = MemInst.getPointer(Ptr); - *BufPtr = Value; - Ptr += 4; -} + void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance & MemInst, + uint32_t Value, uint32_t & Ptr) { + uint32_t *BufPtr = MemInst.getPointer(Ptr); + *BufPtr = Value; + Ptr += 4; + } -void writeFatPointer(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, - uint32_t PtrVal, uint32_t PtrSize, uint32_t &Ptr) { - writeUInt32(MemInst, PtrVal, Ptr); - writeUInt32(MemInst, PtrSize, Ptr); -} + void writeFatPointer(WasmEdge::Runtime::Instance::MemoryInstance & MemInst, + uint32_t PtrVal, uint32_t PtrSize, uint32_t & Ptr) { + writeUInt32(MemInst, PtrVal, Ptr); + writeUInt32(MemInst, PtrSize, Ptr); + } #if defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) -template -std::vector classSort(WasmEdge::Span Array) { - std::vector Indices(Array.size()); - std::iota(Indices.begin(), Indices.end(), 0); - std::sort(Indices.begin(), Indices.end(), - [&Array](size_t Left, size_t Right) -> bool { - // Sort indices according to corresponding array element. - return Array[Left] > Array[Right]; - }); - return Indices; -} + template + std::vector classSort(WasmEdge::Span Array) { + std::vector Indices(Array.size()); + std::iota(Indices.begin(), Indices.end(), 0); + std::sort(Indices.begin(), Indices.end(), + [&Array](size_t Left, size_t Right) -> bool { + // Sort indices according to corresponding array element. + return Array[Left] > Array[Right]; + }); + return Indices; + } #endif } // namespace #endif @@ -2447,7 +2453,6 @@ TEST(WasiNNTest, PiperBackend) { EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } - // Test: set_input -- set input successfully. { EXPECT_TRUE( @@ -2469,7 +2474,6 @@ TEST(WasiNNTest, PiperBackend) { EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } - // Test: compute -- compute successfully. { EXPECT_TRUE(HostFuncCompute.run( @@ -2500,7 +2504,6 @@ TEST(WasiNNTest, PiperBackend) { EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } - // Test: get_output -- get output successfully. { EXPECT_TRUE(HostFuncGetOutput.run( @@ -2515,3 +2518,223 @@ TEST(WasiNNTest, PiperBackend) { } } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS +TEST(WasiNNTest, ChatTTSBackend) { + // Create the wasmedge_process module instance. + auto *NNMod = dynamic_cast(createModule()); + ASSERT_TRUE(NNMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::string Prompt = "This is test prompt."; + std::vector TensorData(Prompt.begin(), Prompt.end()); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "unload". + FuncInst = NNMod->findFuncExports("unload"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncUnload = + dynamic_cast(FuncInst->getHostFunc()); + + // ChatTTS WASI-NN load tests. + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::ChatTTS), + UINT32_C(0), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), + static_cast(Backend::ChatTTS), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- load successfully. + BuilderPtr = LoadEntryPtr; + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::ChatTTS), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + // ChatTTS WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: init_execution_context -- init second context. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // ChatTTS WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // ChatTTS WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // ChatTTS WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 50 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 50); + } + + // ChatTTS WASI-NN unload tests. + // Test: unload -- unload successfully. + { + EXPECT_TRUE(HostFuncUnload.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS From 1e4df3c72da849cf76e68d32ad85ee8b06191d46 Mon Sep 17 00:00:00 2001 From: grorge Date: Tue, 9 Jul 2024 15:49:19 +0800 Subject: [PATCH 383/623] [WASI-NN] ChatTTS: add output metadata Signed-off-by: grorge --- plugins/wasi_nn/chattts.cpp | 18 ++++++++++++------ plugins/wasi_nn/types.h | 2 +- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/plugins/wasi_nn/chattts.cpp b/plugins/wasi_nn/chattts.cpp index 0ae1b1ff..9014e9b6 100644 --- a/plugins/wasi_nn/chattts.cpp +++ b/plugins/wasi_nn/chattts.cpp @@ -128,18 +128,24 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, - uint32_t, Span OutBuffer, + uint32_t Index, Span OutBuffer, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN] ChatTTS backend: getOutput"sv); } - std::string StringTmp(reinterpret_cast(CxtRef.Outputs.data()), - CxtRef.Outputs.size() * sizeof(long long int)); - std::copy_n(StringTmp.data(), StringTmp.length(), OutBuffer.data()); - BytesWritten = StringTmp.length(); - return WASINN::ErrNo::Success; + if (Index == 0) { + std::copy_n(CxtRef.Outputs.data(), CxtRef.Outputs.size(), OutBuffer.data()); + BytesWritten = CxtRef.Outputs.size(); + return WASINN::ErrNo::Success; + } else if (Index == 1) { + uint32_t Size = CxtRef.Outputs.size(); + std::memcpy(OutBuffer.data(), &Size, sizeof(uint32_t)); + BytesWritten = sizeof(uint32_t); + return WASINN::ErrNo::Success; + } + return WASINN::ErrNo::InvalidArgument; } Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index 1c2dad70..c8d74cca 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -39,7 +39,7 @@ enum class Backend : uint8_t { NeuralSpeed = 7, Whisper = 9, Piper = 11, - ChatTTS = 8, + ChatTTS = 12, }; #define FOR_EACH_BACKEND(F) \ From c7273fb912581302d135aae105a03e24538b261d Mon Sep 17 00:00:00 2001 From: grorge Date: Sat, 13 Jul 2024 15:13:55 +0800 Subject: [PATCH 384/623] [WASI-NN] ChatTTS: add input metadata Signed-off-by: grorge --- plugins/wasi_nn/chattts.cpp | 162 ++++++++++++++++++++++++++----- plugins/wasi_nn/chattts.h | 2 + test/plugins/wasi_nn/wasi_nn.cpp | 2 +- 3 files changed, 139 insertions(+), 27 deletions(-) diff --git a/plugins/wasi_nn/chattts.cpp b/plugins/wasi_nn/chattts.cpp index 9014e9b6..681c3e4b 100644 --- a/plugins/wasi_nn/chattts.cpp +++ b/plugins/wasi_nn/chattts.cpp @@ -35,20 +35,6 @@ Expect load(WASINN::WasiNNEnvironment &Env, spdlog::info("[WASI-NN] ChatTTS backend: Load."sv); } - // if (Builders.size() > 1) { - // std::string Metadata = std::string( - // reinterpret_cast(Builders[1].data()), Builders[1].size()); - // simdjson::dom::parser Parser; - // simdjson::dom::element Doc; - // auto ParseError = Parser.parse(Metadata).get(Doc); - // if (ParseError) { - // spdlog::error("[WASI-NN] ChatTTS backend: Parse metadata error"sv); - // Env.NNGraph.pop_back(); - // return ErrNo::InvalidEncoding; - // } - // // TODO handle metadata - // } - // Create Model class if (!Py_IsInitialized()) { Py_Initialize(); @@ -106,7 +92,8 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, } Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, - uint32_t, const TensorData &Tensor) noexcept { + uint32_t Index, + const TensorData &Tensor) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); if (!Py_IsInitialized()) { @@ -117,14 +104,117 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN] ChatTTS backend: setInput"sv); } - - // Set the input. - std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), - Tensor.Tensor.size()); - CxtRef.Inputs.clear(); - CxtRef.Inputs = Prompt; - - return WASINN::ErrNo::Success; + if (Index == 0) { + // Set the input. + std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + CxtRef.Inputs.clear(); + CxtRef.Inputs = Prompt; + return WASINN::ErrNo::Success; + } else if (Index == 1) { + // Set metadata. + std::string Metadata = std::string( + reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + spdlog::error("[WASI-NN] ChatTTS backend: Parse metadata error"sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidEncoding; + } + // Handle Refine Text Params + PyObject *PromptObj = nullptr; + if (Doc.at_key("prompt").error() == simdjson::SUCCESS) { + std::string_view PromptView; + auto Err = Doc["prompt"].get().get(PromptView); + if (Err) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Unable to retrieve the prompt option."sv); + return ErrNo::InvalidArgument; + } + PromptObj = PyUnicode_FromString(std::string(PromptView).c_str()); + } + if (PromptObj != nullptr) { + PyObject *Args = PyTuple_New(0); + PyObject *Kwargs = PyDict_New(); + PyDict_SetItemString(Kwargs, "prompt", PromptObj); + PyObject *RefineTextParamsFun = + PyObject_GetAttrString(GraphRef.Chat, "RefineTextParams"); + GraphRef.ParamsRefineText = + PyObject_Call(RefineTextParamsFun, Args, Kwargs); + Py_XDECREF(PromptObj); + Py_DECREF(Args); + Py_DECREF(Kwargs); + Py_XDECREF(RefineTextParamsFun); + } + // Handle Infer Code Params + PyObject *InferKwargs = PyDict_New(); + if (Doc.at_key("temperature").error() == simdjson::SUCCESS) { + double temperature; + auto Err = Doc["temperature"].get().get(temperature); + if (Err) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Unable to retrieve the temperature option."sv); + return ErrNo::InvalidArgument; + } + PyDict_SetItemString(InferKwargs, "temperature", + PyFloat_FromDouble(temperature)); + } + if (Doc.at_key("top_K").error() == simdjson::SUCCESS) { + double TopK; + auto Err = Doc["top_K"].get().get(TopK); + if (Err) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Unable to retrieve the topK option."sv); + return ErrNo::InvalidArgument; + } + PyDict_SetItemString(InferKwargs, "top_K", PyFloat_FromDouble(TopK)); + } + if (Doc.at_key("top_P").error() == simdjson::SUCCESS) { + double TopP; + auto Err = Doc["top_P"].get().get(TopP); + if (Err) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Unable to retrieve the temperature option."sv); + return ErrNo::InvalidArgument; + } + PyDict_SetItemString(InferKwargs, "top_P", PyFloat_FromDouble(TopP)); + } + if (Doc.at_key("spk_emb").error() == simdjson::SUCCESS) { + std::string_view SpkEmb; + auto Err = Doc["spk_emb"].get().get(SpkEmb); + if (Err) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Unable to retrieve the spk_emb option."sv); + return ErrNo::InvalidArgument; + } + if (SpkEmb == "random") { + PyObject *SampleRandomSpeaker = + PyObject_GetAttrString(GraphRef.Chat, "sample_random_speaker"); + PyObject *Spk = PyObject_CallNoArgs(SampleRandomSpeaker); + PyDict_SetItemString(InferKwargs, "spk_emb", Spk); + Py_XDECREF(SampleRandomSpeaker); + Py_XDECREF(Spk); + } else { + PyObject *Spk = PyUnicode_FromString(std::string(SpkEmb).c_str()); + PyDict_SetItemString(InferKwargs, "spk_emb", Spk); + Py_XDECREF(Spk); + } + } + if (PyDict_Size(InferKwargs) != 0) { + PyObject *Args = PyTuple_New(0); + PyObject *InferCodeParams = + PyObject_GetAttrString(GraphRef.Chat, "InferCodeParams"); + GraphRef.ParamsInferCode = + PyObject_Call(InferCodeParams, Args, InferKwargs); + Py_XDECREF(Args); + Py_XDECREF(InferCodeParams); + } + Py_XDECREF(InferKwargs); + return WASINN::ErrNo::Success; + } + return WASINN::ErrNo::InvalidArgument; } Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, @@ -173,9 +263,27 @@ Expect compute(WasiNNEnvironment &Env, Py_XDECREF(InferMethod); return WASINN::ErrNo::RuntimeError; } - PyObject *Args = PyTuple_Pack(1, InputStr); - Result = PyObject_CallObject(InferMethod, Args); - Py_XDECREF(Args); + if (GraphRef.ParamsRefineText == nullptr && + GraphRef.ParamsInferCode == nullptr) { + PyObject *Args = PyTuple_Pack(1, InputStr); + Result = PyObject_CallObject(InferMethod, Args); + Py_XDECREF(Args); + } else { + PyObject *Args = PyTuple_New(0); + PyObject *Kwargs = PyDict_New(); + PyDict_SetItemString(Kwargs, "text", InputStr); + if (GraphRef.ParamsRefineText != nullptr) { + PyDict_SetItemString(Kwargs, "params_refine_text", + GraphRef.ParamsRefineText); + } + if (GraphRef.ParamsInferCode != nullptr) { + PyDict_SetItemString(Kwargs, "params_infer_code", + GraphRef.ParamsInferCode); + } + Result = PyObject_Call(InferMethod, Args, Kwargs); + Py_DECREF(Args); + Py_DECREF(Kwargs); + } if (Result != nullptr) { PyObject *Wav0 = PyList_GetItem(Result, 0); PyObject *BytesObj = PyObject_CallMethod(Wav0, "tobytes", nullptr); @@ -205,6 +313,8 @@ Expect unload(WASINN::WasiNNEnvironment &Env, spdlog::info("[WASI-NN] Neural speed backend: start unload."sv); } if (Py_IsInitialized()) { + Py_XDECREF(GraphRef.ParamsRefineText); + Py_XDECREF(GraphRef.ParamsInferCode); Py_XDECREF(GraphRef.Chat); Py_XDECREF(GraphRef.ChatTTSModule); Py_Finalize(); diff --git a/plugins/wasi_nn/chattts.h b/plugins/wasi_nn/chattts.h index acf11752..1d50184c 100644 --- a/plugins/wasi_nn/chattts.h +++ b/plugins/wasi_nn/chattts.h @@ -23,6 +23,8 @@ struct Graph { } PyObject *Chat; PyObject *ChatTTSModule; + PyObject *ParamsRefineText = nullptr; + PyObject *ParamsInferCode = nullptr; }; struct Context { Context(size_t Gid, Graph &) noexcept : GraphId(Gid) {} diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index ebed49b1..f4086b48 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -2649,7 +2649,7 @@ TEST(WasiNNTest, ChatTTSBackend) { // ChatTTS WASI-NN set_input tests. SetInputEntryPtr = BuilderPtr; writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeUInt32(MemInst, UINT32_C(2), BuilderPtr); writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), BuilderPtr); writeBinaries(MemInst, TensorDim, StorePtr); From a56cd1cb4721c35da98181f236613e0011675cf7 Mon Sep 17 00:00:00 2001 From: grorge Date: Mon, 29 Jul 2024 09:51:40 +0800 Subject: [PATCH 385/623] [WASI-NN] ChatTTS: fix memory leak and add metadata test Signed-off-by: grorge --- plugins/wasi_nn/chattts.cpp | 100 ++++++++++---------- plugins/wasi_nn/chattts.h | 4 +- test/plugins/wasi_nn/wasi_nn.cpp | 153 ++++++++++++++++++------------- 3 files changed, 143 insertions(+), 114 deletions(-) diff --git a/plugins/wasi_nn/chattts.cpp b/plugins/wasi_nn/chattts.cpp index 681c3e4b..f524e898 100644 --- a/plugins/wasi_nn/chattts.cpp +++ b/plugins/wasi_nn/chattts.cpp @@ -39,40 +39,41 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (!Py_IsInitialized()) { Py_Initialize(); } - GraphRef.ChatTTSModule = PyImport_Import(PyUnicode_FromString("ChatTTS")); if (GraphRef.ChatTTSModule == nullptr) { - PyErr_Print(); - spdlog::error("[WASI-NN] ChatTTS backend: Can not find ChatTTS library."sv); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::RuntimeError; - } - PyObject *ChatFunction = - PyObject_GetAttrString(GraphRef.ChatTTSModule, "Chat"); - if (ChatFunction == nullptr || !PyCallable_Check(ChatFunction)) { - spdlog::error( - "[WASI-NN] ChatTTS backend: Can not find Chat class in ChatTTS."sv); - Py_XDECREF(GraphRef.ChatTTSModule); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::RuntimeError; + GraphRef.ChatTTSModule = PyImport_ImportModule("ChatTTS"); + if (GraphRef.ChatTTSModule == nullptr) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Can not find ChatTTS library."sv); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::RuntimeError; + } } - GraphRef.Chat = PyObject_CallObject(ChatFunction, nullptr); - Py_XDECREF(ChatFunction); if (GraphRef.Chat == nullptr) { - spdlog::error("[WASI-NN] ChatTTS backend: Can not create chat."sv); - Py_XDECREF(GraphRef.ChatTTSModule); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::RuntimeError; - } - PyObject *LoadMethod = PyObject_GetAttrString(GraphRef.Chat, "load"); - if (LoadMethod == nullptr || !PyCallable_Check(LoadMethod)) { - spdlog::error("[WASI-NN] ChatTTS backend: Can not load chat."sv); - PyErr_Print(); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::RuntimeError; + PyObject *ChatFunction = + PyObject_GetAttrString(GraphRef.ChatTTSModule, "Chat"); + if (ChatFunction == nullptr || !PyCallable_Check(ChatFunction)) { + spdlog::error( + "[WASI-NN] ChatTTS backend: Can not find Chat class in ChatTTS."sv); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::RuntimeError; + } + GraphRef.Chat = PyObject_CallObject(ChatFunction, nullptr); + Py_XDECREF(ChatFunction); + if (GraphRef.Chat == nullptr) { + spdlog::error("[WASI-NN] ChatTTS backend: Can not create chat."sv); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::RuntimeError; + } + PyObject *LoadMethod = PyObject_GetAttrString(GraphRef.Chat, "load"); + if (LoadMethod == nullptr || !PyCallable_Check(LoadMethod)) { + spdlog::error("[WASI-NN] ChatTTS backend: Can not load chat."sv); + Env.NNGraph.pop_back(); + return WASINN::ErrNo::RuntimeError; + } + PyObject *Value = PyObject_CallObject(LoadMethod, nullptr); + Py_XDECREF(Value); + Py_XDECREF(LoadMethod); } - PyObject *Value = PyObject_CallObject(LoadMethod, nullptr); - Py_XDECREF(Value); - Py_DECREF(LoadMethod); // Store the loaded graph. GraphId = Env.NNGraph.size() - 1; @@ -144,22 +145,23 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, GraphRef.ParamsRefineText = PyObject_Call(RefineTextParamsFun, Args, Kwargs); Py_XDECREF(PromptObj); - Py_DECREF(Args); - Py_DECREF(Kwargs); + Py_XDECREF(Args); + Py_XDECREF(Kwargs); Py_XDECREF(RefineTextParamsFun); } // Handle Infer Code Params PyObject *InferKwargs = PyDict_New(); if (Doc.at_key("temperature").error() == simdjson::SUCCESS) { - double temperature; - auto Err = Doc["temperature"].get().get(temperature); + double Temperature; + auto Err = Doc["temperature"].get().get(Temperature); if (Err) { spdlog::error( "[WASI-NN] ChatTTS backend: Unable to retrieve the temperature option."sv); return ErrNo::InvalidArgument; } - PyDict_SetItemString(InferKwargs, "temperature", - PyFloat_FromDouble(temperature)); + PyObject *TemperatureObject = PyFloat_FromDouble(Temperature); + PyDict_SetItemString(InferKwargs, "temperature", TemperatureObject); + Py_XDECREF(TemperatureObject); } if (Doc.at_key("top_K").error() == simdjson::SUCCESS) { double TopK; @@ -169,7 +171,9 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, "[WASI-NN] ChatTTS backend: Unable to retrieve the topK option."sv); return ErrNo::InvalidArgument; } - PyDict_SetItemString(InferKwargs, "top_K", PyFloat_FromDouble(TopK)); + PyObject *TopKObject = PyFloat_FromDouble(TopK); + PyDict_SetItemString(InferKwargs, "top_K", TopKObject); + Py_XDECREF(TopKObject); } if (Doc.at_key("top_P").error() == simdjson::SUCCESS) { double TopP; @@ -179,7 +183,9 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, "[WASI-NN] ChatTTS backend: Unable to retrieve the temperature option."sv); return ErrNo::InvalidArgument; } - PyDict_SetItemString(InferKwargs, "top_P", PyFloat_FromDouble(TopP)); + PyObject *TopPObject = PyFloat_FromDouble(TopP); + PyDict_SetItemString(InferKwargs, "top_P", TopPObject); + Py_XDECREF(TopPObject); } if (Doc.at_key("spk_emb").error() == simdjson::SUCCESS) { std::string_view SpkEmb; @@ -229,11 +235,6 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, std::copy_n(CxtRef.Outputs.data(), CxtRef.Outputs.size(), OutBuffer.data()); BytesWritten = CxtRef.Outputs.size(); return WASINN::ErrNo::Success; - } else if (Index == 1) { - uint32_t Size = CxtRef.Outputs.size(); - std::memcpy(OutBuffer.data(), &Size, sizeof(uint32_t)); - BytesWritten = sizeof(uint32_t); - return WASINN::ErrNo::Success; } return WASINN::ErrNo::InvalidArgument; } @@ -281,8 +282,8 @@ Expect compute(WasiNNEnvironment &Env, GraphRef.ParamsInferCode); } Result = PyObject_Call(InferMethod, Args, Kwargs); - Py_DECREF(Args); - Py_DECREF(Kwargs); + Py_XDECREF(Args); + Py_XDECREF(Kwargs); } if (Result != nullptr) { PyObject *Wav0 = PyList_GetItem(Result, 0); @@ -290,13 +291,12 @@ Expect compute(WasiNNEnvironment &Env, char *Bytes = PyBytes_AsString(BytesObj); Py_ssize_t size = PyBytes_Size(BytesObj); CxtRef.Outputs = std::vector(Bytes, Bytes + size); - Py_DECREF(BytesObj); - Py_DECREF(Wav0); + Py_XDECREF(BytesObj); } else { spdlog::error( "[WASI-NN] ChatTTS backend: Can not get output from infer method."sv); - PyErr_Print(); - Py_XDECREF(Result); + Py_XDECREF(InputStr); + Py_XDECREF(InferMethod); return WASINN::ErrNo::RuntimeError; } Py_XDECREF(Result); @@ -317,6 +317,8 @@ Expect unload(WASINN::WasiNNEnvironment &Env, Py_XDECREF(GraphRef.ParamsInferCode); Py_XDECREF(GraphRef.Chat); Py_XDECREF(GraphRef.ChatTTSModule); + GraphRef.Chat = nullptr; + GraphRef.ChatTTSModule = nullptr; Py_Finalize(); } return WASINN::ErrNo::Success; diff --git a/plugins/wasi_nn/chattts.h b/plugins/wasi_nn/chattts.h index 1d50184c..15f70a6f 100644 --- a/plugins/wasi_nn/chattts.h +++ b/plugins/wasi_nn/chattts.h @@ -21,8 +21,8 @@ struct Graph { Py_XDECREF(ChatTTSModule); } } - PyObject *Chat; - PyObject *ChatTTSModule; + PyObject *Chat = nullptr; + PyObject *ChatTTSModule = nullptr; PyObject *ParamsRefineText = nullptr; PyObject *ParamsInferCode = nullptr; }; diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index f4086b48..e0feef93 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -25,81 +25,78 @@ using WasmEdge::Host::WASINN::ErrNo; defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER) || \ -<<<<<<< HEAD -defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER) -======= + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) ->>>>>>> 6769e4a7 ([WASI-NN] ChatTTS: add basic test) - namespace { - WasmEdge::Runtime::Instance::ModuleInstance *createModule( - std::string_view NNRPCURI = "") { - using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load( - std::filesystem::u8path("../../../plugins/wasi_nn/" WASMEDGE_LIB_PREFIX - "wasmedgePluginWasiNN" WASMEDGE_LIB_EXTENSION)); - if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_nn"sv)) { - WasmEdge::PO::ArgumentParser Parser; - Plugin->registerOptions(Parser); - if (NNRPCURI != "") { - Parser.set_raw_value("nn-rpc-uri"sv, - std::string(NNRPCURI)); - } - if (const auto *Module = Plugin->findModule("wasi_nn"sv)) { - return Module->create().release(); - } +namespace { +WasmEdge::Runtime::Instance::ModuleInstance * +createModule(std::string_view NNRPCURI = "") { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load( + std::filesystem::u8path("../../../plugins/wasi_nn/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasiNN" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_nn"sv)) { + WasmEdge::PO::ArgumentParser Parser; + Plugin->registerOptions(Parser); + if (NNRPCURI != "") { + Parser.set_raw_value("nn-rpc-uri"sv, std::string(NNRPCURI)); } - return nullptr; - } -#if !defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) - inline std::vector readEntireFile(const std::string &Path) { - std::ifstream Fin(Path, std::ios::in | std::ios::binary | std::ios::ate); - if (!Fin) { - return {}; - } - std::vector Buf(static_cast(Fin.tellg())); - Fin.seekg(0, std::ios::beg); - if (!Fin.read(reinterpret_cast(Buf.data()), - static_cast(Buf.size()))) { - return {}; + if (const auto *Module = Plugin->findModule("wasi_nn"sv)) { + return Module->create().release(); } - Fin.close(); - return Buf; } + return nullptr; +} + +#if !defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) +inline std::vector readEntireFile(const std::string &Path) { + std::ifstream Fin(Path, std::ios::in | std::ios::binary | std::ios::ate); + if (!Fin) { + return {}; + } + std::vector Buf(static_cast(Fin.tellg())); + Fin.seekg(0, std::ios::beg); + if (!Fin.read(reinterpret_cast(Buf.data()), + static_cast(Buf.size()))) { + return {}; + } + Fin.close(); + return Buf; +} #endif - template - void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance & MemInst, - WasmEdge::Span Binaries, uint32_t Ptr) noexcept { - std::copy(Binaries.begin(), Binaries.end(), MemInst.getPointer(Ptr)); - } +template +void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + WasmEdge::Span Binaries, uint32_t Ptr) noexcept { + std::copy(Binaries.begin(), Binaries.end(), MemInst.getPointer(Ptr)); +} - void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance & MemInst, - uint32_t Value, uint32_t & Ptr) { - uint32_t *BufPtr = MemInst.getPointer(Ptr); - *BufPtr = Value; - Ptr += 4; - } +void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Value, uint32_t &Ptr) { + uint32_t *BufPtr = MemInst.getPointer(Ptr); + *BufPtr = Value; + Ptr += 4; +} - void writeFatPointer(WasmEdge::Runtime::Instance::MemoryInstance & MemInst, - uint32_t PtrVal, uint32_t PtrSize, uint32_t & Ptr) { - writeUInt32(MemInst, PtrVal, Ptr); - writeUInt32(MemInst, PtrSize, Ptr); - } +void writeFatPointer(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t PtrVal, uint32_t PtrSize, uint32_t &Ptr) { + writeUInt32(MemInst, PtrVal, Ptr); + writeUInt32(MemInst, PtrSize, Ptr); +} #if defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) - template - std::vector classSort(WasmEdge::Span Array) { - std::vector Indices(Array.size()); - std::iota(Indices.begin(), Indices.end(), 0); - std::sort(Indices.begin(), Indices.end(), - [&Array](size_t Left, size_t Right) -> bool { - // Sort indices according to corresponding array element. - return Array[Left] > Array[Right]; - }); - return Indices; - } +template +std::vector classSort(WasmEdge::Span Array) { + std::vector Indices(Array.size()); + std::iota(Indices.begin(), Indices.end(), 0); + std::sort(Indices.begin(), Indices.end(), + [&Array](size_t Left, size_t Right) -> bool { + // Sort indices according to corresponding array element. + return Array[Left] > Array[Right]; + }); + return Indices; +} #endif } // namespace #endif @@ -2538,6 +2535,10 @@ TEST(WasiNNTest, ChatTTSBackend) { // Load the files. std::string Prompt = "This is test prompt."; std::vector TensorData(Prompt.begin(), Prompt.end()); + std::string config = + "{\"prompt\":\"[oral_2][laugh_0][break_6]\",\"spk_emb\":\"random\"," + "\"temperature\":0.5,\"top_k\":0,\"top_p\":0.9}"; + std::vector ConfigData(config.begin(), config.end()); std::vector TensorDim{1}; uint32_t BuilderPtr = UINT32_C(0); @@ -2693,6 +2694,32 @@ TEST(WasiNNTest, ChatTTSBackend) { EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); } + // Test: setInput -- set metadata successfully. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(2), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, ConfigData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, ConfigData, StorePtr + TensorDim.size() * 4); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(1), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + ConfigData.size()); + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + // ChatTTS WASI-NN get_output tests. // Test: get_output -- output bytes ptr out of bounds. { From 3119ecc9c5493d50e75fc4f93d3f9851dd8fd133 Mon Sep 17 00:00:00 2001 From: grorge Date: Wed, 31 Jul 2024 11:40:33 +0800 Subject: [PATCH 386/623] [WASI-NN] ChatTTS: fix typo Signed-off-by: grorge --- plugins/wasi_nn/chattts.cpp | 7 +++---- plugins/wasi_nn/chattts.h | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/plugins/wasi_nn/chattts.cpp b/plugins/wasi_nn/chattts.cpp index f524e898..8b8a2a04 100644 --- a/plugins/wasi_nn/chattts.cpp +++ b/plugins/wasi_nn/chattts.cpp @@ -30,7 +30,6 @@ Expect load(WASINN::WasiNNEnvironment &Env, Env.NNGraph.emplace_back(Backend::ChatTTS); auto &GraphRef = Env.NNGraph.back().get(); // Initialize the plugin parameters. - GraphRef.EnableDebugLog = true; if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN] ChatTTS backend: Load."sv); } @@ -84,7 +83,7 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { if (!Py_IsInitialized()) { spdlog::error( - "[WASI-NN] ChatTTS backend: Model has been realse, please reload it."sv); + "[WASI-NN] ChatTTS backend: Model has been released, please reload it."sv); return WASINN::ErrNo::RuntimeError; } Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); @@ -99,7 +98,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); if (!Py_IsInitialized()) { spdlog::error( - "[WASI-NN] ChatTTS backend: Model has been realse, please reload it."sv); + "[WASI-NN] ChatTTS backend: Model has been released, please reload it."sv); return WASINN::ErrNo::RuntimeError; } if (GraphRef.EnableDebugLog) { @@ -242,7 +241,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { if (!Py_IsInitialized()) { spdlog::error( - "[WASI-NN] ChatTTS backend: Model has been realse, please reload it."sv); + "[WASI-NN] ChatTTS backend: Model has been released, please reload it."sv); return WASINN::ErrNo::RuntimeError; } auto &CxtRef = Env.NNContext[ContextId].get(); diff --git a/plugins/wasi_nn/chattts.h b/plugins/wasi_nn/chattts.h index 15f70a6f..84676e25 100644 --- a/plugins/wasi_nn/chattts.h +++ b/plugins/wasi_nn/chattts.h @@ -13,7 +13,7 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::ChatTTS { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS struct Graph { - bool EnableDebugLog = true; + bool EnableDebugLog = false; Graph() noexcept { Py_Initialize(); } ~Graph() noexcept { if (Py_IsInitialized()) { From d1ce1728234430f8c51f1da37d3fbc2d5920511b Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 26 Jun 2024 16:52:02 +0800 Subject: [PATCH 387/623] [WASI-NN] ggml: support compute single in RPC mode Signed-off-by: dm4 --- plugins/wasi_nn/wasinnfunc.cpp | 119 +++++++++++------- test/plugins/wasi_nn/wasi_nn.cpp | 208 +++++++++++++++++++++++++++++++ 2 files changed, 284 insertions(+), 43 deletions(-) diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 692028da..9af5fc96 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -36,6 +36,16 @@ Expect load(WASINN::WasiNNEnvironment &Env, return WASINN::ErrNo::InvalidEncoding; } } +#ifdef WASMEDGE_BUILD_WASI_NN_RPC +WASINN::ErrNo metadataToErrNo( + const std::multimap &Metadata) { + if (Metadata.find("errno") != Metadata.end()) { + auto ErrNo = std::stoi(Metadata.find("errno")->second.data()); + return static_cast(ErrNo); + } + return WASINN::ErrNo::Success; +} +#endif // #ifdef WASMEDGE_BUILD_WASI_NN_RPC } // namespace Expect @@ -136,9 +146,8 @@ WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, wasi_ephemeral_nn::LoadByNameResult Res; auto Status = Stub->LoadByName(&ClientContext, Req, &Res); if (!Status.ok()) { - spdlog::error("[WASI-NN] Failed when calling remote LoadByName: {}"sv, - Status.error_message()); - return WASINN::ErrNo::RuntimeError; + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); } *GraphId = Res.graph_handle(); return WASINN::ErrNo::Success; @@ -196,10 +205,8 @@ Expect WasiNNLoadByNameWithConfig::bodyImpl( wasi_ephemeral_nn::LoadByNameWithConfigResult Res; auto Status = Stub->LoadByNameWithConfig(&ClientContext, Req, &Res); if (!Status.ok()) { - spdlog::error( - "[WASI-NN] Failed when calling remote LoadByNameWithConfig: {}"sv, - Status.error_message()); - return WASINN::ErrNo::RuntimeError; + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); } *GraphId = Res.graph_handle(); return WASINN::ErrNo::Success; @@ -242,10 +249,8 @@ WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, wasi_ephemeral_nn::InitExecutionContextResult Res; auto Status = Stub->InitExecutionContext(&ClientContext, Req, &Res); if (!Status.ok()) { - spdlog::error( - "[WASI-NN] Failed when calling remote InitExecutionContext: {}"sv, - Status.error_message()); - return WASINN::ErrNo::RuntimeError; + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); } *Context = Res.ctx_handle(); return WASINN::ErrNo::Success; @@ -339,9 +344,8 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, google::protobuf::Empty Res; auto Status = Stub->SetInput(&ClientContext, Req, &Res); if (!Status.ok()) { - spdlog::error("[WASI-NN] Failed when calling remote SetInput: {}"sv, - Status.error_message()); - return WASINN::ErrNo::RuntimeError; + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); } return WASINN::ErrNo::Success; } @@ -397,9 +401,8 @@ WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, wasi_ephemeral_nn::GetOutputResult Res; auto Status = Stub->GetOutput(&ClientContext, Req, &Res); if (!Status.ok()) { - spdlog::error("[WASI-NN] Failed when calling remote GetOutput: {}"sv, - Status.error_message()); - return WASINN::ErrNo::RuntimeError; + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); } uint32_t BytesWrittenVal = std::min(static_cast(Res.data().size()), OutBufferMaxSize); @@ -430,25 +433,11 @@ Expect WasiNNGetOutputSingle::bodyImpl( const Runtime::CallingFrame &Frame, uint32_t Context, uint32_t Index, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { -#ifdef WASMEDGE_BUILD_WASI_NN_RPC - if (Env.NNRPCChannel != nullptr) { - // TODO: implement RPC for GetOutputSingle - spdlog::error( - "[WASI-NN] RPC client is not implemented for GetOutputSingle"sv); - return WASINN::ErrNo::UnsupportedOperation; - } -#endif auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); } - if (Env.NNContext.size() <= Context) { - spdlog::error( - "[WASI-NN] get_output_single: Execution Context does not exist"sv); - return WASINN::ErrNo::InvalidArgument; - } - const auto OutBuffer = MemInst->getSpan(OutBufferPtr, OutBufferMaxSize); if (unlikely(OutBuffer.data() == nullptr)) { @@ -462,6 +451,34 @@ Expect WasiNNGetOutputSingle::bodyImpl( return WASINN::ErrNo::InvalidArgument; } +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( + Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::GetOutputRequest Req; + Req.set_resource_handle(Context); + Req.set_index(Index); + wasi_ephemeral_nn::GetOutputResult Res; + auto Status = Stub->GetOutputSingle(&ClientContext, Req, &Res); + if (!Status.ok()) { + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); + } + uint32_t BytesWrittenVal = + std::min(static_cast(Res.data().size()), OutBufferMaxSize); + std::copy_n(Res.data().begin(), BytesWrittenVal, OutBuffer.begin()); + *BytesWritten = BytesWrittenVal; + return WASINN::ErrNo::Success; + } +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC + + if (Env.NNContext.size() <= Context) { + spdlog::error( + "[WASI-NN] get_output_single: Execution Context does not exist"sv); + return WASINN::ErrNo::InvalidArgument; + } + switch (Env.NNContext[Context].getBackend()) { case WASINN::Backend::GGML: return WASINN::GGML::getOutputSingle(Env, Context, Index, OutBuffer, @@ -485,9 +502,8 @@ WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { google::protobuf::Empty Res; auto Status = Stub->Compute(&ClientContext, Req, &Res); if (!Status.ok()) { - spdlog::error("[WASI-NN] Failed when calling remote Compute: {}"sv, - Status.error_message()); - return WASINN::ErrNo::RuntimeError; + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); } return WASINN::ErrNo::Success; } @@ -519,12 +535,20 @@ WasiNNComputeSingle::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { #ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNRPCChannel != nullptr) { - // TODO: implement RPC for ComputeSingle - spdlog::error( - "[WASI-NN] RPC client is not implemented for ComputeSingle"sv); - return WASINN::ErrNo::UnsupportedOperation; + auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( + Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::ComputeRequest Req; + Req.set_resource_handle(Context); + google::protobuf::Empty Res; + auto Status = Stub->ComputeSingle(&ClientContext, Req, &Res); + if (!Status.ok()) { + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); + } + return WASINN::ErrNo::Success; } -#endif +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -551,11 +575,20 @@ WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { #ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNRPCChannel != nullptr) { - // TODO: implement RPC for FiniSingle - spdlog::error("[WASI-NN] RPC client is not implemented for FiniSingle"sv); - return WASINN::ErrNo::UnsupportedOperation; + auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( + Env.NNRPCChannel); + grpc::ClientContext ClientContext; + wasi_ephemeral_nn::FiniSingleRequest Req; + Req.set_resource_handle(Context); + google::protobuf::Empty Res; + auto Status = Stub->FiniSingle(&ClientContext, Req, &Res); + if (!Status.ok()) { + auto Metadata = ClientContext.GetServerTrailingMetadata(); + return metadataToErrNo(Metadata); + } + return WASINN::ErrNo::Success; } -#endif +#endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index e0feef93..93340b40 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1768,6 +1768,214 @@ TEST(WasiNNTest, GGMLBackendWithRPC) { delete NNMod; } + +TEST(WasiNNTest, GGMLBackendComputeSingleWithRPC) { + // wasi_nn_rpcserver has to be started outside this test, + // and the URI has to be set to $WASI_NN_RPC_TEST_URI. + // nn-preload has to be specified for "default". + /* + DIR=/tmp/build + export WASI_NN_RPC_TEST_URI=unix://${DIR}/wasi_nn_rpc.sock + export WASMEDGE_PLUGIN_PATH=${DIR}/plugins/wasi_nn + ${DIR}/tools/wasmedge/wasi_nn_rpcserver \ + --nn-rpc-uri=$WASI_NN_RPC_TEST_URI \ + --nn-preload=default:GGML:AUTO:${DIR}/test/plugins/wasi_nn/wasinn_ggml_fixtures/orca_mini.gguf + */ + const auto NNRPCURI = ::getenv("WASI_NN_RPC_TEST_URI"); + if (NNRPCURI == nullptr) { + GTEST_SKIP() << "WASI_NN_RPC_TEST_URI is unset"; + } + + // Create the wasmedge_process module instance. + auto *NNMod = + dynamic_cast(createModule(NNRPCURI)); + EXPECT_FALSE(NNMod == nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + std::string Prompt = "Once upon a time, "; + std::vector TensorData(Prompt.begin(), Prompt.end()); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load_by_name". + auto FuncInst = NNMod->findFuncExports("load_by_name"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoadByName = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "load_by_name_with_config". + FuncInst = NNMod->findFuncExports("load_by_name_with_config"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoadByNameWithConfig = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output_single"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutputSingle = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute_single"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncComputeSingle = + dynamic_cast( + FuncInst->getHostFunc()); + + // Test: load_by_name -- load successfully. + { + std::string Name = "default"; + std::vector NameVec(Name.begin(), Name.end()); + writeBinaries(MemInst, NameVec, LoadEntryPtr); + EXPECT_TRUE(HostFuncLoadByName.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, static_cast(NameVec.size()), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: load_by_name_with_config -- load successfully. + { + std::string Name = "default"; + std::string Config = "{}"; + std::vector NameVec(Name.begin(), Name.end()); + std::vector ConfigVec(Config.begin(), Config.end()); + uint32_t ConfigPtr = LoadEntryPtr + NameVec.size(); + writeBinaries(MemInst, NameVec, LoadEntryPtr); + writeBinaries(MemInst, ConfigVec, ConfigPtr); + EXPECT_TRUE(HostFuncLoadByNameWithConfig.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, static_cast(NameVec.size()), ConfigPtr, + static_cast(ConfigVec.size()), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // Test: init_execution_context -- init context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // GGML WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- set input successfully. + BuilderPtr = SetInputEntryPtr; + writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), + BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, + StorePtr + static_cast(TensorDim.size()) * 4, + static_cast(TensorData.size()), BuilderPtr); + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // GGML WASI-NN compute_single tests. + // Test: compute_single -- context id exceeds. + { + EXPECT_TRUE(HostFuncComputeSingle.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: compute_single -- call compute_single once follow by a + // get_output_single. + { + // compute_single + EXPECT_TRUE(HostFuncComputeSingle.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // get_output_single + EXPECT_TRUE(HostFuncGetOutputSingle.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // GGML WASI-NN get_output_single tests. + // Test: get_output_single -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutputSingle.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutputSingle.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); + } +} #endif // WASMEDGE_BUILD_WASI_NN_RPC #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML From fb6bfa22c7d8a27c5a96f564a913ac4c384c3460 Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 1 Aug 2024 17:59:36 +0800 Subject: [PATCH 388/623] [WASI-NN] ggml: bump to b3499 Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index cdd1b332..9c6020b3 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -72,7 +72,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3463 + GIT_TAG b3499 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 23dc641a84c7402af3ecc1a8ccfb32a2fb055f5c Mon Sep 17 00:00:00 2001 From: PeterD1524 <53310459+PeterD1524@users.noreply.github.com> Date: Mon, 5 Aug 2024 14:39:18 +0800 Subject: [PATCH 389/623] [Docker] install onnxruntime in manylinux_2_28-plugins-deps for WASI-NN Piper CI (#3622) * install onnxruntime in manylinux_2_28-plugins-deps for WASI-NN Piper CI Signed-off-by: PeterD1524 * update utils/wasi-nn/install-onnxruntime.sh to support both x86_64 and aarch64 Signed-off-by: PeterD1524 --------- Signed-off-by: PeterD1524 --- utils/docker/Dockerfile.manylinux_2_28-plugins-deps | 3 +++ utils/wasi-nn/install-onnxruntime.sh | 13 ++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps index 066d2d95..1f907cb0 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps @@ -38,4 +38,7 @@ ENV LD_LIBRARY_PATH=/root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY ENV OPENVINO_VERSION="2024.2.0" ENV OPENVINO_YEAR="2024" +COPY wasi-nn/install-onnxruntime.sh . +RUN [ "/bin/bash", "install-onnxruntime.sh" ] + RUN yum clean all diff --git a/utils/wasi-nn/install-onnxruntime.sh b/utils/wasi-nn/install-onnxruntime.sh index 61c2acb7..3ff011e3 100644 --- a/utils/wasi-nn/install-onnxruntime.sh +++ b/utils/wasi-nn/install-onnxruntime.sh @@ -2,9 +2,18 @@ set -e +case "$(uname -m)" in + 'x86_64') ARCH='x64' ;; + 'aarch64') ARCH='aarch64' ;; + *) + echo 'Cannot determine architecture for onnxruntime' >&2 + exit 1 + ;; +esac + : ${ONNXRUNTIME_VERSION:=1.14.1} -ONNXRUNTIME_NAME="onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}" +ONNXRUNTIME_NAME="onnxruntime-linux-${ARCH}-${ONNXRUNTIME_VERSION}" ONNXRUNTIME_TGZ="${ONNXRUNTIME_NAME}.tgz" curl -LO "https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/${ONNXRUNTIME_TGZ}" @@ -12,3 +21,5 @@ tar zxf "${ONNXRUNTIME_TGZ}" mv "${ONNXRUNTIME_NAME}/include/"* /usr/local/include/ mv "${ONNXRUNTIME_NAME}/lib/"* /usr/local/lib/ rm -rf "${ONNXRUNTIME_TGZ}" "${ONNXRUNTIME_NAME}" + +ldconfig From 008fdfca7125f11aef150ba32d3444362a1f0181 Mon Sep 17 00:00:00 2001 From: PeterD1524 <53310459+PeterD1524@users.noreply.github.com> Date: Tue, 6 Aug 2024 14:14:11 +0800 Subject: [PATCH 390/623] [CI] manylinux: WASI-NN Piper (#3614) * enable wasi_nn-piper build for manylinux_2_28_x86_64 Signed-off-by: PeterD1524 * update piper patch to allow passing CMAKE_POSITION_INDEPENDENT_CODE to espeak-ng It worked on ubuntu before because gcc on ubuntu (wasmedge/wasmedge:ubuntu-build-gcc) is configured with --enable-default-pie. This will turn R_X86_64_32S into R_X86_64_PC32. gcc on wasmedge/wasmedge:manylinux_2_28_x86_64-plugins-deps is not configured with --enable-default-pie, so the error `libespeak-ng.a(case.c.o): relocation R_X86_64_32S against `.rodata' can not be used when making a shared object; recompile with -fPIC` will occur. The configuration of gcc can be obtained by running `gcc -v`. As a result, passing CMAKE_POSITION_INDEPENDENT_CODE=ON to espeak-ng to compile with -fPIC is required. Some references: https://stackoverflow.com/a/46493456 https://stackoverflow.com/a/6093910 https://www.ucw.cz/~hubicka/papers/abi/node19.html Signed-off-by: PeterD1524 * enable wasi_nn-piper build for manylinux_2_28_aarch64 Signed-off-by: PeterD1524 --------- Signed-off-by: PeterD1524 --- plugins/wasi_nn/piper.patch | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/plugins/wasi_nn/piper.patch b/plugins/wasi_nn/piper.patch index b5acad34..3ac654d6 100644 --- a/plugins/wasi_nn/piper.patch +++ b/plugins/wasi_nn/piper.patch @@ -168,12 +168,12 @@ index 26aaba0..867e524 100644 \ No newline at end of file diff --git a/piper-phonemize.patch b/piper-phonemize.patch new file mode 100644 -index 0000000..0d91cde +index 0000000..c4676cb --- /dev/null +++ b/piper-phonemize.patch -@@ -0,0 +1,213 @@ +@@ -0,0 +1,214 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt -+index ec7b501..335ef46 100644 ++index ec7b501..39275a6 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -10,6 +10,8 @@ project( @@ -274,15 +274,16 @@ index 0000000..0d91cde + CMAKE_ARGS -DUSE_MBROLA:BOOL=OFF + CMAKE_ARGS -DUSE_LIBSONIC:BOOL=OFF + CMAKE_ARGS -DUSE_LIBPCAUDIO:BOOL=OFF -+@@ -116,6 +153,7 @@ if(NOT DEFINED ESPEAK_NG_DIR) ++@@ -116,6 +153,8 @@ if(NOT DEFINED ESPEAK_NG_DIR) + CMAKE_ARGS -DEXTRA_cmn:BOOL=ON + CMAKE_ARGS -DEXTRA_ru:BOOL=ON + CMAKE_ARGS -DCMAKE_C_FLAGS="-D_FILE_OFFSET_BITS=64" +++ CMAKE_ARGS "-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${CMAKE_POSITION_INDEPENDENT_CODE}" ++ USES_TERMINAL_DOWNLOAD TRUE + ) + add_dependencies(piper_phonemize espeak_ng_external) + endif() -+@@ -123,23 +161,27 @@ endif() ++@@ -123,23 +162,27 @@ endif() + + # ---- Declare library ---- + @@ -313,7 +314,7 @@ index 0000000..0d91cde + ) + + target_compile_features(piper_phonemize PUBLIC cxx_std_17) -+@@ -173,12 +215,13 @@ target_link_libraries(piper_phonemize_exe PUBLIC ++@@ -173,12 +216,13 @@ target_link_libraries(piper_phonemize_exe PUBLIC + # ---- Declare test ---- + + include(CTest) @@ -332,7 +333,7 @@ index 0000000..0d91cde + + target_compile_features(test_piper_phonemize PUBLIC cxx_std_17) + -+@@ -207,7 +250,7 @@ install( ++@@ -207,7 +251,7 @@ install( + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) + + install( @@ -341,7 +342,7 @@ index 0000000..0d91cde + DESTINATION include/piper-phonemize + FILES_MATCHING + PATTERN "*.h" -+@@ -218,7 +261,7 @@ install( ++@@ -218,7 +262,7 @@ install( + ARCHIVE DESTINATION ${CMAKE_INSTALL_BINDIR}) + + install( @@ -350,7 +351,7 @@ index 0000000..0d91cde + TYPE DATA) + + # Dependencies -+@@ -226,10 +269,12 @@ install( ++@@ -226,10 +270,12 @@ install( + DIRECTORY ${ESPEAK_NG_DIR}/ + DESTINATION ${CMAKE_INSTALL_PREFIX}) + From c737bb05c1b1dd487204c7b4b3fbb4011d570c51 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Wed, 7 Aug 2024 00:03:53 +0800 Subject: [PATCH 391/623] [Plugin] Stable Diffusion: fix segment fault (#3628) Signed-off-by: grorge --- plugins/wasmedge_stablediffusion/sd_func.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index 8a74cb2f..b9ccf33f 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -275,7 +275,7 @@ Expect SDTextToImage::body( // TODO upscale image int Len; unsigned char *Png = stbi_write_png_to_mem( - reinterpret_cast(Results), 0, Results->width, + reinterpret_cast(Results->data), 0, Results->width, Results->height, Results->channel, &Len, nullptr); if (OutputPathLen != 0) { stbi_write_png(OutputPath.data(), Results->width, Results->height, @@ -381,7 +381,7 @@ Expect SDImageToImage::body( // TODO: upscale image int Len; unsigned char *Png = stbi_write_png_to_mem( - reinterpret_cast(Results), 0, Results->width, + reinterpret_cast(Results->data), 0, Results->width, Results->height, Results->channel, &Len, nullptr); if (OutputPathLen != 0) { stbi_write_png(OutputPath.data(), Results->width, Results->height, From 65ef1ef57f61b7a8ec8767055e17b77cad6b941c Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 5 Aug 2024 16:55:05 +0800 Subject: [PATCH 392/623] [Plugin] Move the wasi-logging plugin to the built-in plugin. Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 6 +- plugins/wasi_logging/CMakeLists.txt | 31 ---- plugins/wasi_logging/base.h | 28 ---- plugins/wasi_logging/env.cpp | 38 ----- plugins/wasi_logging/env.h | 36 ----- plugins/wasi_logging/func.cpp | 99 ------------- plugins/wasi_logging/func.h | 22 --- plugins/wasi_logging/module.cpp | 20 --- plugins/wasi_logging/module.h | 24 --- test/plugins/CMakeLists.txt | 5 +- test/plugins/wasi_logging/CMakeLists.txt | 1 + test/plugins/wasi_logging/wasi_logging.cpp | 140 ++++++++++++------ .../avformat/avformat_func.cpp | 4 +- .../wasmedge_ffmpeg/avutil/avDictionary.cpp | 5 +- .../wasmedge_ffmpeg/avutil/avError.cpp | 6 +- .../wasmedge_ffmpeg/avutil/avSampleFmt.cpp | 4 +- test/plugins/wasmedge_ffmpeg/utils.h | 4 +- .../wasmedge_process/wasmedge_process.cpp | 31 ++-- 18 files changed, 129 insertions(+), 375 deletions(-) delete mode 100644 plugins/wasi_logging/CMakeLists.txt delete mode 100644 plugins/wasi_logging/base.h delete mode 100644 plugins/wasi_logging/env.cpp delete mode 100644 plugins/wasi_logging/env.h delete mode 100644 plugins/wasi_logging/func.cpp delete mode 100644 plugins/wasi_logging/func.h delete mode 100644 plugins/wasi_logging/module.cpp delete mode 100644 plugins/wasi_logging/module.h diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 2061e166..f9ac516c 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -84,10 +84,6 @@ if(WASMEDGE_PLUGIN_OPENCVMINI) endif() endif() -if(WASMEDGE_PLUGIN_WASI_LOGGING) - add_subdirectory(wasi_logging) -endif() - if(WASMEDGE_PLUGIN_ZLIB) add_subdirectory(wasmedge_zlib) endif() @@ -96,6 +92,6 @@ if(WASMEDGE_PLUGIN_FFMPEG) add_subdirectory(wasmedge_ffmpeg) endif() -if (WASMEDGE_PLUGIN_LLM) +if(WASMEDGE_PLUGIN_LLM) add_subdirectory(wasi_llm) endif() diff --git a/plugins/wasi_logging/CMakeLists.txt b/plugins/wasi_logging/CMakeLists.txt deleted file mode 100644 index 0d0b6469..00000000 --- a/plugins/wasi_logging/CMakeLists.txt +++ /dev/null @@ -1,31 +0,0 @@ -wasmedge_add_library(wasmedgePluginWasiLogging - SHARED - env.cpp - func.cpp - module.cpp -) - -target_compile_options(wasmedgePluginWasiLogging - PUBLIC - -DWASMEDGE_PLUGIN -) - -target_include_directories(wasmedgePluginWasiLogging - PUBLIC - $ - ${CMAKE_CURRENT_SOURCE_DIR} -) - -if(WASMEDGE_LINK_PLUGINS_STATIC) - target_link_libraries(wasmedgePluginWasiLogging - PRIVATE - wasmedgeCAPI - ) -else() - target_link_libraries(wasmedgePluginWasiLogging - PRIVATE - wasmedge_shared - ) -endif() - -install(TARGETS wasmedgePluginWasiLogging DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) diff --git a/plugins/wasi_logging/base.h b/plugins/wasi_logging/base.h deleted file mode 100644 index dc1c51ca..00000000 --- a/plugins/wasi_logging/base.h +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#pragma once - -#include "env.h" - -#include "common/errcode.h" -#include "runtime/callingframe.h" -#include "runtime/hostfunc.h" - -namespace WasmEdge { -namespace Host { -namespace WASILogging { - -enum class LogLevel : uint32_t { Trace, Debug, Info, Warn, Error, Critical }; - -template class Func : public Runtime::HostFunction { -public: - Func(LogEnv &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} - -protected: - LogEnv &Env; -}; - -} // namespace WASILogging -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasi_logging/env.cpp b/plugins/wasi_logging/env.cpp deleted file mode 100644 index 97f2b141..00000000 --- a/plugins/wasi_logging/env.cpp +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#include "env.h" -#include "module.h" - -namespace WasmEdge { -namespace Host { - -namespace { - -Runtime::Instance::ModuleInstance * -create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { - return new WasiLoggingModule; -} - -Plugin::Plugin::PluginDescriptor Descriptor{ - .Name = "wasi_logging", - .Description = "", - .APIVersion = Plugin::Plugin::CurrentAPIVersion, - .Version = {0, 1, 0, 0}, - .ModuleCount = 1, - .ModuleDescriptions = - (Plugin::PluginModule::ModuleDescriptor[]){ - { - .Name = "wasi:logging/logging", - .Description = "", - .Create = create, - }, - }, - .AddOptions = nullptr, -}; - -EXPORT_GET_DESCRIPTOR(Descriptor) - -} // namespace -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasi_logging/env.h b/plugins/wasi_logging/env.h deleted file mode 100644 index 838c7850..00000000 --- a/plugins/wasi_logging/env.h +++ /dev/null @@ -1,36 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#pragma once - -#include "plugin/plugin.h" - -#include -#include - -namespace WasmEdge { -namespace Host { -namespace WASILogging { - -class LogEnv { -public: - LogEnv() noexcept { - // TODO: Use the config in WasmEdge to set the logging level. - StdoutLogger->set_level(spdlog::level::trace); - StderrLogger->set_level(spdlog::level::trace); - StdoutLogger->set_pattern(DefFormat); - StderrLogger->set_pattern(DefFormat); - } - - const std::shared_ptr StdoutLogger = - spdlog::stdout_color_mt("wasi_logging_stdout"); - const std::shared_ptr StderrLogger = - spdlog::stderr_color_mt("wasi_logging_stderr"); - const std::string DefFormat = "[%Y-%m-%d %H:%M:%S.%e] [%^%l%$] %v"; - std::shared_ptr FileLogger; - std::string LogFileName; -}; - -} // namespace WASILogging -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasi_logging/func.cpp b/plugins/wasi_logging/func.cpp deleted file mode 100644 index 9d44d4c3..00000000 --- a/plugins/wasi_logging/func.cpp +++ /dev/null @@ -1,99 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#include "func.h" - -#include - -namespace WasmEdge { -namespace Host { -namespace WASILogging { - -using namespace std::literals; - -Expect Log::body(const Runtime::CallingFrame &Frame, uint32_t Level, - uint32_t CxtPtr, uint32_t CxtLen, uint32_t MsgPtr, - uint32_t MsgLen) { - // Check memory instance from module. - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - return Unexpect(ErrCode::Value::HostFuncError); - } - - // Get Buffer Pointer. - char *CxtBuf = MemInst->getPointer(CxtPtr); - char *MsgBuf = MemInst->getPointer(MsgPtr); - if (CxtBuf == nullptr || MsgBuf == nullptr) { - return Unexpect(ErrCode::Value::HostFuncError); - } - - // Get Context and Message string_view - std::string_view CxtSV(CxtBuf, CxtLen); - std::string_view MsgSV(MsgBuf, MsgLen); - - // Setup Logger for Stdout or Stderr - std::shared_ptr Logger; - if (CxtSV == "stdout"sv || CxtSV == ""sv) { - Logger = Env.StdoutLogger; - } else if (CxtSV == "stderr"sv) { - Logger = Env.StderrLogger; - } else { - if (CxtSV != Env.LogFileName) { - try { - spdlog::drop("wasi_logging_file"); - Env.FileLogger = - spdlog::basic_logger_mt("wasi_logging_file", std::string(CxtSV)); - Env.FileLogger->set_pattern(Env.DefFormat); - Env.LogFileName = CxtSV; - // TODO: Use the config in WasmEdge to set the logging level. - Env.FileLogger->set_level(spdlog::level::trace); - } catch (const spdlog::spdlog_ex &Ex) { - spdlog::error("[WasiLogging] Cannot log into file: {}"sv, Ex.what()); - return Unexpect(ErrCode::Value::HostFuncError); - } - } - Logger = Env.FileLogger; - } - - // Print Message by Logging Level - switch (static_cast(Level)) { - case LogLevel::Trace: - Logger->trace(MsgSV); - break; - case LogLevel::Debug: - Logger->debug(MsgSV); - break; - case LogLevel::Info: - Logger->info(MsgSV); - break; - case LogLevel::Warn: - Logger->warn(MsgSV); - break; - case LogLevel::Error: - Logger->error(MsgSV); - break; - case LogLevel::Critical: - Logger->critical(MsgSV); - break; - default: - spdlog::error("[WasiLogging] Unrecognized Logging Level: {}"sv, Level); - spdlog::error("[WasiLogging] Trace Level = {}"sv, - static_cast(LogLevel::Trace)); - spdlog::error("[WasiLogging] Debug Level = {}"sv, - static_cast(LogLevel::Debug)); - spdlog::error("[WasiLogging] Info Level = {}"sv, - static_cast(LogLevel::Info)); - spdlog::error("[WasiLogging] Warn Level = {}"sv, - static_cast(LogLevel::Warn)); - spdlog::error("[WasiLogging] Error Level = {}"sv, - static_cast(LogLevel::Error)); - spdlog::error("[WasiLogging] Critical Level = {}"sv, - static_cast(LogLevel::Critical)); - return Unexpect(ErrCode::Value::HostFuncError); - } - return {}; -} - -} // namespace WASILogging -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasi_logging/func.h b/plugins/wasi_logging/func.h deleted file mode 100644 index 60e81c9d..00000000 --- a/plugins/wasi_logging/func.h +++ /dev/null @@ -1,22 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#pragma once - -#include "base.h" - -namespace WasmEdge { -namespace Host { -namespace WASILogging { - -class Log : public Func { -public: - Log(LogEnv &HostEnv) : Func(HostEnv) {} - Expect body(const Runtime::CallingFrame &Frame, uint32_t Level, - uint32_t CxtPtr, uint32_t CxtLen, uint32_t MsgPtr, - uint32_t MsgLen); -}; - -} // namespace WASILogging -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasi_logging/module.cpp b/plugins/wasi_logging/module.cpp deleted file mode 100644 index 938ecc20..00000000 --- a/plugins/wasi_logging/module.cpp +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#include "module.h" -#include "func.h" - -#include - -namespace WasmEdge { -namespace Host { - -using namespace std::literals; - -WasiLoggingModule::WasiLoggingModule() - : ModuleInstance("wasi:logging/logging"sv) { - addHostFunc("log"sv, std::make_unique(Env)); -} - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasi_logging/module.h b/plugins/wasi_logging/module.h deleted file mode 100644 index 9cf1260e..00000000 --- a/plugins/wasi_logging/module.h +++ /dev/null @@ -1,24 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#pragma once - -#include "env.h" - -#include "runtime/instance/module.h" - -namespace WasmEdge { -namespace Host { - -class WasiLoggingModule : public Runtime::Instance::ModuleInstance { -public: - WasiLoggingModule(); - - WASILogging::LogEnv &getEnv() { return Env; } - -private: - WASILogging::LogEnv Env; -}; - -} // namespace Host -} // namespace WasmEdge diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index e59b215e..571a578d 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -57,8 +57,5 @@ if(WASMEDGE_PLUGIN_STABLEDIFFUSION) add_subdirectory(wasmedge_stablediffusion) endif() -if(WASMEDGE_PLUGIN_WASI_LOGGING) - add_subdirectory(wasi_logging) -endif() - +add_subdirectory(wasi_logging) add_subdirectory(unittest) diff --git a/test/plugins/wasi_logging/CMakeLists.txt b/test/plugins/wasi_logging/CMakeLists.txt index d3373e80..88985294 100644 --- a/test/plugins/wasi_logging/CMakeLists.txt +++ b/test/plugins/wasi_logging/CMakeLists.txt @@ -19,6 +19,7 @@ target_link_libraries(wasiLoggingTests PRIVATE ${GTEST_BOTH_LIBRARIES} ) + # Link to the WasmEdge library if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasiLoggingTests diff --git a/test/plugins/wasi_logging/wasi_logging.cpp b/test/plugins/wasi_logging/wasi_logging.cpp index 471b7297..90cad91e 100644 --- a/test/plugins/wasi_logging/wasi_logging.cpp +++ b/test/plugins/wasi_logging/wasi_logging.cpp @@ -1,15 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "func.h" -#include "module.h" +#include "plugin/wasi_logging/func.h" +#include "plugin/wasi_logging/module.h" #include "common/defines.h" #include "runtime/instance/module.h" #include -#include -#include #include #include @@ -18,9 +16,8 @@ namespace { WasmEdge::Runtime::Instance::ModuleInstance *createModule() { using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( - "../../../plugins/wasi_logging/" WASMEDGE_LIB_PREFIX - "wasmedgePluginWasiLogging" WASMEDGE_LIB_EXTENSION)); + // The built-in plugins are loaded when loading from default paths. + WasmEdge::Plugin::Plugin::loadFromDefaultPaths(); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_logging"sv)) { if (const auto *Module = Plugin->findModule("wasi:logging/logging"sv)) { return Module->create().release(); @@ -35,18 +32,23 @@ void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, } void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, - uint32_t Offset, const std::string &Str) noexcept { + uint32_t Offset, std::string_view Str) noexcept { char *Buf = MemInst.getPointer(Offset); - std::copy_n(Str.c_str(), Str.length(), Buf); + std::copy_n(Str.data(), Str.length(), Buf); } } // namespace TEST(WasiLoggingTests, func_log) { + using namespace std::literals::string_view_literals; // Create the wasi-logging module instance. - auto WasiLoggingMod = + // Here create 2 wasi-logging modules for testing in multiple modules. + auto WasiLoggingMod1 = + dynamic_cast(createModule()); + EXPECT_NE(WasiLoggingMod1, nullptr); + auto WasiLoggingMod2 = dynamic_cast(createModule()); - EXPECT_NE(WasiLoggingMod, nullptr); + EXPECT_NE(WasiLoggingMod2, nullptr); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -58,68 +60,122 @@ TEST(WasiLoggingTests, func_log) { auto &MemInst = *MemInstPtr; WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); - // Clear the memory[0, 32]. - fillMemContent(MemInst, 0, 32); - // Set strings in memory - fillMemContent(MemInst, 0, std::string("stdout")); - fillMemContent(MemInst, 8, std::string("stderr")); - fillMemContent(MemInst, 16, std::string("MsgStr")); - - // Get the function "log" - auto *FuncInst = WasiLoggingMod->findFuncExports("log"); - EXPECT_NE(FuncInst, nullptr); - EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncInst = - dynamic_cast(FuncInst->getHostFunc()); + // Clear the memory[0, 256]. + fillMemContent(MemInst, 0, 256); + // Set strings in memory. + fillMemContent(MemInst, 0, "stdout"sv); + fillMemContent(MemInst, 8, "stderr"sv); + fillMemContent(MemInst, 16, "out.log"sv); + fillMemContent(MemInst, 24, "out2.log"sv); + fillMemContent(MemInst, 128, "This is log message"sv); + fillMemContent(MemInst, 160, "Message 1 to file"sv); + fillMemContent(MemInst, 192, "Message 2 to file"sv); + fillMemContent(MemInst, 224, "Message 3 to file"sv); + + // Get the function "log". + auto *FuncInst1 = WasiLoggingMod1->findFuncExports("log"); + auto *FuncInst2 = WasiLoggingMod2->findFuncExports("log"); + EXPECT_NE(FuncInst1, nullptr); + EXPECT_NE(FuncInst2, nullptr); + EXPECT_TRUE(FuncInst1->isHostFunction()); + EXPECT_TRUE(FuncInst2->isHostFunction()); + auto &HostFuncInst1 = dynamic_cast( + FuncInst1->getHostFunc()); + auto &HostFuncInst2 = dynamic_cast( + FuncInst2->getHostFunc()); // Show All Level - EXPECT_TRUE(HostFuncInst.run( + EXPECT_TRUE(HostFuncInst1.run( CallFrame, std::initializer_list{ - UINT32_C(0), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + UINT32_C(0), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, {})); - EXPECT_TRUE(HostFuncInst.run( + EXPECT_TRUE(HostFuncInst1.run( CallFrame, std::initializer_list{ - UINT32_C(1), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + UINT32_C(1), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, {})); - EXPECT_TRUE(HostFuncInst.run( + EXPECT_TRUE(HostFuncInst1.run( CallFrame, std::initializer_list{ - UINT32_C(2), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + UINT32_C(2), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, {})); - - EXPECT_TRUE(HostFuncInst.run( + EXPECT_TRUE(HostFuncInst1.run( CallFrame, std::initializer_list{ - UINT32_C(3), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + UINT32_C(3), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, {})); - EXPECT_TRUE(HostFuncInst.run( + EXPECT_TRUE(HostFuncInst1.run( CallFrame, std::initializer_list{ - UINT32_C(4), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + UINT32_C(4), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, {})); - EXPECT_TRUE(HostFuncInst.run( + EXPECT_TRUE(HostFuncInst1.run( CallFrame, std::initializer_list{ - UINT32_C(5), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + UINT32_C(5), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, {})); // Stderr Context - EXPECT_TRUE(HostFuncInst.run( + EXPECT_TRUE(HostFuncInst1.run( CallFrame, std::initializer_list{ - UINT32_C(0), UINT32_C(8), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + UINT32_C(0), UINT32_C(8), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); + EXPECT_TRUE(HostFuncInst2.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(8), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); + + // Log to out.txt: message 1 + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(16), UINT32_C(7), UINT32_C(160), UINT32_C(17)}, + {})); + EXPECT_TRUE(HostFuncInst2.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(16), UINT32_C(7), UINT32_C(160), UINT32_C(17)}, + {})); + // Log to out2.txt: message 2 + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(24), UINT32_C(8), UINT32_C(192), UINT32_C(17)}, + {})); + EXPECT_TRUE(HostFuncInst2.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(24), UINT32_C(8), UINT32_C(192), UINT32_C(17)}, + {})); + // Log to out.txt: message 3 + EXPECT_TRUE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(16), UINT32_C(7), UINT32_C(224), UINT32_C(17)}, + {})); + EXPECT_TRUE(HostFuncInst2.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(16), UINT32_C(7), UINT32_C(224), UINT32_C(17)}, {})); // UnKnown Level - EXPECT_FALSE(HostFuncInst.run( + EXPECT_FALSE(HostFuncInst1.run( + CallFrame, + std::initializer_list{ + UINT32_C(6), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, + {})); + EXPECT_FALSE(HostFuncInst2.run( CallFrame, std::initializer_list{ - UINT32_C(6), UINT32_C(0), UINT32_C(6), UINT32_C(16), UINT32_C(6)}, + UINT32_C(6), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, {})); - delete WasiLoggingMod; + delete WasiLoggingMod1; + delete WasiLoggingMod2; } GTEST_API_ int main(int argc, char **argv) { diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp index 0b7c2cb8..6ca7f158 100644 --- a/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp @@ -358,8 +358,8 @@ TEST_F(FFmpegTest, AVOutputFormatFunc) { uint32_t FileLen = 8; fillMemContent(MemInst, FormatStart, FormatLen + FileLen); - fillMemContent(MemInst, FormatStart, std::string("mp4")); - fillMemContent(MemInst, FileStart, std::string("test.mp4")); + fillMemContent(MemInst, FormatStart, "mp4"sv); + fillMemContent(MemInst, FileStart, "test.mp4"sv); auto *FuncInst = AVFormatMod->findFuncExports( "wasmedge_ffmpeg_avformat_avformat_alloc_output_context2"); diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp index cb4ad790..2e3542c4 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp @@ -9,6 +9,7 @@ namespace Host { namespace WasmEdgeFFmpeg { TEST_F(FFmpegTest, AVDictionary) { + using namespace std::literals::string_view_literals; uint32_t KeyStart = UINT32_C(1); uint32_t KeyLen = 3; @@ -31,8 +32,8 @@ TEST_F(FFmpegTest, AVDictionary) { // Fill 0 in WasmMemory. fillMemContent(MemInst, KeyStart, KeyLen + ValueLen); - fillMemContent(MemInst, KeyStart, std::string("KEY")); - fillMemContent(MemInst, ValueStart, std::string("VALUE")); + fillMemContent(MemInst, KeyStart, "KEY"sv); + fillMemContent(MemInst, ValueStart, "VALUE"sv); // Storing the above Key and Value in dict and using these in below tests // (dict_get) to fetch Key,values. diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp index bc5e6dab..bcba324c 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp @@ -9,16 +9,14 @@ namespace Host { namespace WasmEdgeFFmpeg { TEST_F(FFmpegTest, AVError) { - + using namespace std::literals::string_view_literals; ASSERT_TRUE(AVUtilMod != nullptr); int32_t ErrNum = 35; - uint32_t ErrStartPtr = UINT32_C(100); uint32_t ErrSize = 10; fillMemContent(MemInst, ErrStartPtr, ErrSize); - - fillMemContent(MemInst, ErrStartPtr, std::string("Test Error")); + fillMemContent(MemInst, ErrStartPtr, "Test Error"sv); auto *FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_strerror"); diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp index 93593079..14ea8bb5 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp @@ -9,7 +9,7 @@ namespace Host { namespace WasmEdgeFFmpeg { TEST_F(FFmpegTest, AVSampleFmt) { - + using namespace std::literals::string_view_literals; ASSERT_TRUE(AVUtilMod != nullptr); uint32_t BufferPtr = UINT32_C(160); @@ -83,7 +83,7 @@ TEST_F(FFmpegTest, AVSampleFmt) { uint32_t SampleFmtSize = 2; fillMemContent(MemInst, SampleFmtSize, SampleFmtSize); - fillMemContent(MemInst, SampleFmtStart, std::string("u8")); + fillMemContent(MemInst, SampleFmtStart, "u8"sv); { HostFuncAVGetSampleFmt.run(CallFrame, std::initializer_list{ diff --git a/test/plugins/wasmedge_ffmpeg/utils.h b/test/plugins/wasmedge_ffmpeg/utils.h index 6c95138e..5b5175fa 100644 --- a/test/plugins/wasmedge_ffmpeg/utils.h +++ b/test/plugins/wasmedge_ffmpeg/utils.h @@ -26,9 +26,9 @@ inline void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, } inline void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, - uint32_t Offset, const std::string &Str) noexcept { + uint32_t Offset, std::string_view Str) noexcept { char *Buf = MemInst->getPointer(Offset); - std::copy_n(Str.c_str(), Str.length(), Buf); + std::copy_n(Str.data(), Str.length(), Buf); } inline void writeSInt32(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, diff --git a/test/plugins/wasmedge_process/wasmedge_process.cpp b/test/plugins/wasmedge_process/wasmedge_process.cpp index 96e66b12..5f63fd90 100644 --- a/test/plugins/wasmedge_process/wasmedge_process.cpp +++ b/test/plugins/wasmedge_process/wasmedge_process.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace { @@ -36,12 +37,14 @@ void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, } void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, - uint32_t Offset, const std::string &Str) noexcept { + uint32_t Offset, std::string_view Str) noexcept { char *Buf = MemInst.getPointer(Offset); - std::copy_n(Str.c_str(), Str.length(), Buf); + std::copy_n(Str.data(), Str.length(), Buf); } } // namespace +using namespace std::literals::string_view_literals; + TEST(WasmEdgeProcessTest, SetProgName) { // Create the wasmedge_process module instance. auto *ProcMod = @@ -61,7 +64,7 @@ TEST(WasmEdgeProcessTest, SetProgName) { // Clear the memory[0, 64]. fillMemContent(MemInst, 0, 64); // Set the memory[0, 4] as string "echo". - fillMemContent(MemInst, 0, std::string("echo")); + fillMemContent(MemInst, 0, "echo"sv); // Get the function "wasmedge_process_set_prog_name". auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_set_prog_name"); @@ -106,11 +109,11 @@ TEST(WasmEdgeProcessTest, AddArg) { // Clear the memory[0, 64]. fillMemContent(MemInst, 0, 64); // Set the memory[0, 4] as string "echo". - fillMemContent(MemInst, 0, std::string("arg1")); + fillMemContent(MemInst, 0, "arg1"sv); // Set the memory[4, 8] as string "arg2". - fillMemContent(MemInst, 4, std::string("arg2")); + fillMemContent(MemInst, 4, "arg2"sv); // Set the memory[30, 41] as string "--final-arg". - fillMemContent(MemInst, 30, std::string("--final-arg")); + fillMemContent(MemInst, 30, "--final-arg"sv); // Get the function "wasmedge_process_add_arg". auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_add_arg"); @@ -171,13 +174,13 @@ TEST(WasmEdgeProcessTest, AddEnv) { // Clear the memory[0, 256]. fillMemContent(MemInst, 0, 256); // Set the memory[0, 4] as string "ENV1". - fillMemContent(MemInst, 0, std::string("ENV1")); + fillMemContent(MemInst, 0, "ENV1"sv); // Set the memory[4, 10] as string "VALUE1". - fillMemContent(MemInst, 4, std::string("VALUE1")); + fillMemContent(MemInst, 4, "VALUE1"sv); // Set the memory[30, 45] as string "LD_LIBRARY_PATH". - fillMemContent(MemInst, 30, std::string("LD_LIBRARY_PATH")); + fillMemContent(MemInst, 30, "LD_LIBRARY_PATH"sv); // Set the memory[50, 64] as string "/usr/local/lib". - fillMemContent(MemInst, 50, std::string("/usr/local/lib")); + fillMemContent(MemInst, 50, "/usr/local/lib"sv); // Get the function "wasmedge_process_add_env". auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_add_env"); @@ -233,9 +236,9 @@ TEST(WasmEdgeProcessTest, AddStdIn) { // Clear the memory[0, 64]. fillMemContent(MemInst, 0, 64); // Set the memory[0, 4] as string "\01\02\03\04". - fillMemContent(MemInst, 0, std::string("\01\02\03\04")); + fillMemContent(MemInst, 0, "\01\02\03\04"sv); // Set the memory[30, 46] as string "hello, wasmedge\n". - fillMemContent(MemInst, 30, std::string("hello, wasmedge\n")); + fillMemContent(MemInst, 30, "hello, wasmedge\n"sv); // Get the function "wasmedge_process_add_stdin". auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_add_stdin"); @@ -315,9 +318,9 @@ TEST(WasmEdgeProcessTest, Run) { // Clear the memory[0, 64]. fillMemContent(MemInst, 0, 64); // Set the memory[0, 4] as string "\01\02\03\04". - fillMemContent(MemInst, 0, std::string("\01\02\03\04")); + fillMemContent(MemInst, 0, "\01\02\03\04"sv); // Set the memory[30, 46] as string "hello, wasmedge\n". - fillMemContent(MemInst, 30, std::string("hello, wasmedge\n")); + fillMemContent(MemInst, 30, "hello, wasmedge\n"sv); // Get the function "wasmedge_process_run". auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_run"); From ad21a81a78ff7fcac3891eaff41d169f8882e1fa Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Wed, 7 Aug 2024 22:53:22 +0800 Subject: [PATCH 393/623] [Misc] Update the file copyright text and lint. (#3629) Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 2 +- plugins/wasi_crypto/CMakeLists.txt | 2 +- plugins/wasi_crypto/asymmetric_common/ctx.cpp | 2 +- plugins/wasi_crypto/asymmetric_common/ecdsa.h | 2 +- .../wasi_crypto/asymmetric_common/func.cpp | 2 +- plugins/wasi_crypto/asymmetric_common/func.h | 2 +- .../wasi_crypto/asymmetric_common/keypair.cpp | 2 +- .../wasi_crypto/asymmetric_common/keypair.h | 2 +- .../wasi_crypto/asymmetric_common/module.cpp | 2 +- .../wasi_crypto/asymmetric_common/module.h | 2 +- .../asymmetric_common/publickey.cpp | 2 +- .../wasi_crypto/asymmetric_common/publickey.h | 2 +- .../asymmetric_common/registered.h | 2 +- .../asymmetric_common/secretkey.cpp | 2 +- .../wasi_crypto/asymmetric_common/secretkey.h | 2 +- plugins/wasi_crypto/common/array_output.cpp | 2 +- plugins/wasi_crypto/common/array_output.h | 2 +- plugins/wasi_crypto/common/ctx.cpp | 2 +- plugins/wasi_crypto/common/func.cpp | 2 +- plugins/wasi_crypto/common/func.h | 2 +- plugins/wasi_crypto/common/module.cpp | 2 +- plugins/wasi_crypto/common/module.h | 2 +- plugins/wasi_crypto/common/options.cpp | 2 +- plugins/wasi_crypto/common/options.h | 2 +- plugins/wasi_crypto/ctx.cpp | 2 +- plugins/wasi_crypto/ctx.h | 2 +- plugins/wasi_crypto/kx/ctx.cpp | 2 +- plugins/wasi_crypto/kx/dh/ecdsa.cpp | 2 +- plugins/wasi_crypto/kx/dh/ecdsa.h | 2 +- plugins/wasi_crypto/kx/dh/x25519.cpp | 2 +- plugins/wasi_crypto/kx/dh/x25519.h | 2 +- plugins/wasi_crypto/kx/func.cpp | 2 +- plugins/wasi_crypto/kx/func.h | 2 +- plugins/wasi_crypto/kx/kx.cpp | 2 +- plugins/wasi_crypto/kx/kx.h | 2 +- plugins/wasi_crypto/kx/module.cpp | 2 +- plugins/wasi_crypto/kx/module.h | 2 +- plugins/wasi_crypto/kx/options.cpp | 2 +- plugins/wasi_crypto/kx/options.h | 2 +- plugins/wasi_crypto/kx/registered.h | 2 +- plugins/wasi_crypto/signatures/ctx.cpp | 2 +- plugins/wasi_crypto/signatures/ecdsa.cpp | 2 +- plugins/wasi_crypto/signatures/ecdsa.h | 2 +- plugins/wasi_crypto/signatures/eddsa.cpp | 2 +- plugins/wasi_crypto/signatures/eddsa.h | 2 +- plugins/wasi_crypto/signatures/func.cpp | 2 +- plugins/wasi_crypto/signatures/func.h | 2 +- plugins/wasi_crypto/signatures/module.cpp | 2 +- plugins/wasi_crypto/signatures/module.h | 2 +- plugins/wasi_crypto/signatures/options.cpp | 2 +- plugins/wasi_crypto/signatures/options.h | 2 +- plugins/wasi_crypto/signatures/registered.h | 2 +- plugins/wasi_crypto/signatures/rsa.cpp | 2 +- plugins/wasi_crypto/signatures/rsa.h | 2 +- plugins/wasi_crypto/signatures/signatures.cpp | 2 +- plugins/wasi_crypto/signatures/signatures.h | 2 +- plugins/wasi_crypto/signatures/signstate.cpp | 2 +- plugins/wasi_crypto/signatures/signstate.h | 2 +- .../signatures/verificationstate.cpp | 2 +- .../signatures/verificationstate.h | 2 +- plugins/wasi_crypto/symmetric/aeads.cpp | 2 +- plugins/wasi_crypto/symmetric/aeads.h | 2 +- plugins/wasi_crypto/symmetric/ctx.cpp | 2 +- plugins/wasi_crypto/symmetric/func.cpp | 2 +- plugins/wasi_crypto/symmetric/func.h | 2 +- plugins/wasi_crypto/symmetric/hash.cpp | 2 +- plugins/wasi_crypto/symmetric/hash.h | 2 +- plugins/wasi_crypto/symmetric/kdf.h | 2 +- plugins/wasi_crypto/symmetric/key.cpp | 2 +- plugins/wasi_crypto/symmetric/key.h | 2 +- plugins/wasi_crypto/symmetric/mac.cpp | 2 +- plugins/wasi_crypto/symmetric/mac.h | 2 +- plugins/wasi_crypto/symmetric/module.cpp | 2 +- plugins/wasi_crypto/symmetric/module.h | 2 +- plugins/wasi_crypto/symmetric/options.cpp | 2 +- plugins/wasi_crypto/symmetric/options.h | 2 +- plugins/wasi_crypto/symmetric/registered.h | 2 +- plugins/wasi_crypto/symmetric/state.cpp | 2 +- plugins/wasi_crypto/symmetric/state.h | 2 +- plugins/wasi_crypto/symmetric/tag.cpp | 2 +- plugins/wasi_crypto/symmetric/tag.h | 2 +- plugins/wasi_crypto/utils/error.h | 2 +- plugins/wasi_crypto/utils/evp_wrapper.cpp | 2 +- plugins/wasi_crypto/utils/evp_wrapper.h | 2 +- plugins/wasi_crypto/utils/handles_manager.h | 2 +- plugins/wasi_crypto/utils/hostfunction.cpp | 2 +- plugins/wasi_crypto/utils/hostfunction.h | 2 +- plugins/wasi_crypto/utils/optional.h | 2 +- plugins/wasi_crypto/utils/secret_vec.h | 2 +- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/chattts.cpp | 4 +- plugins/wasi_nn/chattts.h | 7 +- plugins/wasi_nn/ggml.cpp | 2 +- plugins/wasi_nn/ggml.h | 2 +- plugins/wasi_nn/neuralspeed.cpp | 2 +- plugins/wasi_nn/neuralspeed.h | 2 +- plugins/wasi_nn/onnx.cpp | 2 +- plugins/wasi_nn/onnx.h | 2 +- plugins/wasi_nn/openvino.cpp | 2 +- plugins/wasi_nn/openvino.h | 2 +- plugins/wasi_nn/piper.cpp | 2 +- plugins/wasi_nn/piper.h | 2 +- plugins/wasi_nn/tf.cpp | 2 +- plugins/wasi_nn/tf.h | 2 +- plugins/wasi_nn/tfl.cpp | 2 +- plugins/wasi_nn/tfl.h | 2 +- plugins/wasi_nn/torch.cpp | 2 +- plugins/wasi_nn/torch.h | 2 +- plugins/wasi_nn/types.h | 4 +- plugins/wasi_nn/wasinnbase.h | 2 +- plugins/wasi_nn/wasinnenv.cpp | 2 +- plugins/wasi_nn/wasinnenv.h | 2 +- plugins/wasi_nn/wasinnfunc.cpp | 2 +- plugins/wasi_nn/wasinnfunc.h | 2 +- plugins/wasi_nn/wasinnmodule.cpp | 2 +- plugins/wasi_nn/wasinnmodule.h | 2 +- plugins/wasi_nn/whispercpp.cpp | 2 +- plugins/wasi_nn/whispercpp.h | 2 +- plugins/wasi_ocr/CMakeLists.txt | 2 +- plugins/wasi_poll/func.h | 2 +- plugins/wasm_bpf/CMakeLists.txt | 2 +- plugins/wasm_bpf/bpf-api.h | 2 +- plugins/wasm_bpf/func-attach-bpf-program.cpp | 2 +- plugins/wasm_bpf/func-attach-bpf-program.h | 2 +- plugins/wasm_bpf/func-bpf-buffer-poll.cpp | 3 +- plugins/wasm_bpf/func-bpf-buffer-poll.h | 2 +- plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp | 2 +- plugins/wasm_bpf/func-bpf-map-fd-by-name.h | 2 +- plugins/wasm_bpf/func-bpf-map-operate.cpp | 3 +- plugins/wasm_bpf/func-bpf-map-operate.h | 2 +- plugins/wasm_bpf/func-close-bpf-object.cpp | 2 +- plugins/wasm_bpf/func-close-bpf-object.h | 2 +- plugins/wasm_bpf/func-load-bpf-object.cpp | 2 +- plugins/wasm_bpf/func-load-bpf-object.h | 2 +- plugins/wasm_bpf/state.h | 2 +- plugins/wasm_bpf/util.cpp | 2 +- plugins/wasm_bpf/util.h | 2 +- plugins/wasm_bpf/wasm-bpf-module.cpp | 2 +- plugins/wasm_bpf/wasm-bpf-module.h | 2 +- plugins/wasm_bpf/wasm-bpf.cpp | 2 +- plugins/wasmedge_ffmpeg/CMakeLists.txt | 2 +- plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp | 50 ++- plugins/wasmedge_ffmpeg/avcodec/avCodec.h | 4 + .../avcodec/avCodecContext.cpp | 94 +---- .../wasmedge_ffmpeg/avcodec/avCodecContext.h | 4 + .../avcodec/avCodecParameters.cpp | 6 +- .../avcodec/avCodecParameters.h | 4 + plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp | 24 +- plugins/wasmedge_ffmpeg/avcodec/avPacket.h | 4 + .../wasmedge_ffmpeg/avcodec/avcodec_base.h | 3 + .../wasmedge_ffmpeg/avcodec/avcodec_func.cpp | 28 +- .../wasmedge_ffmpeg/avcodec/avcodec_func.h | 4 + plugins/wasmedge_ffmpeg/avcodec/module.cpp | 4 +- plugins/wasmedge_ffmpeg/avcodec/module.h | 3 + .../wasmedge_ffmpeg/avdevice/avDevice_base.h | 3 + .../avdevice/avDevice_func.cpp | 10 +- .../wasmedge_ffmpeg/avdevice/avDevice_func.h | 3 + plugins/wasmedge_ffmpeg/avdevice/module.cpp | 4 +- plugins/wasmedge_ffmpeg/avdevice/module.h | 3 + plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp | 21 +- plugins/wasmedge_ffmpeg/avfilter/avFilter.h | 3 + .../wasmedge_ffmpeg/avfilter/avfilter_base.h | 3 + .../avfilter/avfilter_func.cpp | 49 +-- .../wasmedge_ffmpeg/avfilter/avfilter_func.h | 3 + .../avfilter/buffer_source_sink.cpp | 10 +- .../avfilter/buffer_source_sink.h | 3 + plugins/wasmedge_ffmpeg/avfilter/module.cpp | 4 +- plugins/wasmedge_ffmpeg/avfilter/module.h | 3 + .../wasmedge_ffmpeg/avformat/avChapter.cpp | 48 +-- plugins/wasmedge_ffmpeg/avformat/avChapter.h | 3 + .../avformat/avInputOutputFormat.cpp | 27 +- .../avformat/avInputOutputFormat.h | 4 + plugins/wasmedge_ffmpeg/avformat/avStream.cpp | 71 ++-- plugins/wasmedge_ffmpeg/avformat/avStream.h | 3 + .../avformat/avformatContext.cpp | 18 +- .../avformat/avformatContext.h | 3 + .../wasmedge_ffmpeg/avformat/avformat_base.h | 3 + .../avformat/avformat_func.cpp | 34 +- .../wasmedge_ffmpeg/avformat/avformat_func.h | 3 + plugins/wasmedge_ffmpeg/avformat/module.cpp | 4 +- plugins/wasmedge_ffmpeg/avformat/module.h | 3 + .../wasmedge_ffmpeg/avutil/avDictionary.cpp | 26 +- plugins/wasmedge_ffmpeg/avutil/avDictionary.h | 3 + plugins/wasmedge_ffmpeg/avutil/avFrame.cpp | 60 +--- plugins/wasmedge_ffmpeg/avutil/avFrame.h | 3 + plugins/wasmedge_ffmpeg/avutil/avRational.cpp | 14 +- plugins/wasmedge_ffmpeg/avutil/avRational.h | 4 + plugins/wasmedge_ffmpeg/avutil/avTime.cpp | 3 + plugins/wasmedge_ffmpeg/avutil/avTime.h | 4 + plugins/wasmedge_ffmpeg/avutil/avutil_base.h | 3 + .../wasmedge_ffmpeg/avutil/avutil_func.cpp | 14 +- plugins/wasmedge_ffmpeg/avutil/avutil_func.h | 4 + plugins/wasmedge_ffmpeg/avutil/error.cpp | 4 +- plugins/wasmedge_ffmpeg/avutil/error.h | 3 + plugins/wasmedge_ffmpeg/avutil/module.cpp | 4 +- plugins/wasmedge_ffmpeg/avutil/module.h | 3 + plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp | 16 +- plugins/wasmedge_ffmpeg/avutil/pixfmt.h | 4 + plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp | 8 +- plugins/wasmedge_ffmpeg/avutil/samplefmt.h | 4 + plugins/wasmedge_ffmpeg/bindings.h | 327 ++++++++++++------ plugins/wasmedge_ffmpeg/ffmpeg_env.cpp | 4 + plugins/wasmedge_ffmpeg/ffmpeg_env.h | 3 + plugins/wasmedge_ffmpeg/swresample/module.cpp | 4 +- plugins/wasmedge_ffmpeg/swresample/module.h | 3 + .../swresample/swresample_base.h | 3 + .../swresample/swresample_func.cpp | 10 +- .../swresample/swresample_func.h | 3 + plugins/wasmedge_ffmpeg/swscale/module.cpp | 4 +- plugins/wasmedge_ffmpeg/swscale/module.h | 3 + .../wasmedge_ffmpeg/swscale/swscale_base.h | 3 + .../wasmedge_ffmpeg/swscale/swscale_func.cpp | 32 +- .../wasmedge_ffmpeg/swscale/swscale_func.h | 3 + plugins/wasmedge_image/CMakeLists.txt | 2 +- plugins/wasmedge_image/image_base.h | 2 +- plugins/wasmedge_image/image_env.cpp | 2 +- plugins/wasmedge_image/image_env.h | 2 +- plugins/wasmedge_image/image_func.cpp | 2 +- plugins/wasmedge_image/image_func.h | 2 +- plugins/wasmedge_image/image_module.cpp | 2 +- plugins/wasmedge_image/image_module.h | 2 +- plugins/wasmedge_opencvmini/CMakeLists.txt | 2 +- plugins/wasmedge_opencvmini/opencvmini_base.h | 2 +- .../wasmedge_opencvmini/opencvmini_env.cpp | 2 +- plugins/wasmedge_opencvmini/opencvmini_env.h | 2 +- .../wasmedge_opencvmini/opencvmini_func.cpp | 2 +- plugins/wasmedge_opencvmini/opencvmini_func.h | 2 +- .../wasmedge_opencvmini/opencvmini_module.cpp | 2 +- .../wasmedge_opencvmini/opencvmini_module.h | 2 +- plugins/wasmedge_process/CMakeLists.txt | 2 +- plugins/wasmedge_process/processbase.h | 2 +- plugins/wasmedge_process/processenv.cpp | 2 +- plugins/wasmedge_process/processenv.h | 2 +- plugins/wasmedge_process/processfunc.cpp | 2 +- plugins/wasmedge_process/processfunc.h | 2 +- plugins/wasmedge_process/processmodule.cpp | 2 +- plugins/wasmedge_process/processmodule.h | 2 +- .../wasmedge_stablediffusion/CMakeLists.txt | 2 +- plugins/wasmedge_stablediffusion/sd_base.h | 2 +- plugins/wasmedge_stablediffusion/sd_env.cpp | 2 +- plugins/wasmedge_stablediffusion/sd_env.h | 2 +- plugins/wasmedge_stablediffusion/sd_func.cpp | 2 +- plugins/wasmedge_stablediffusion/sd_func.h | 2 +- .../wasmedge_stablediffusion/sd_module.cpp | 2 +- plugins/wasmedge_stablediffusion/sd_module.h | 2 +- plugins/wasmedge_tensorflow/CMakeLists.txt | 2 +- plugins/wasmedge_tensorflow/tensorflow_base.h | 2 +- .../wasmedge_tensorflow/tensorflow_env.cpp | 2 +- plugins/wasmedge_tensorflow/tensorflow_env.h | 2 +- .../wasmedge_tensorflow/tensorflow_func.cpp | 2 +- plugins/wasmedge_tensorflow/tensorflow_func.h | 2 +- .../wasmedge_tensorflow/tensorflow_module.cpp | 2 +- .../wasmedge_tensorflow/tensorflow_module.h | 2 +- .../wasmedge_tensorflowlite/CMakeLists.txt | 2 +- .../tensorflowlite_base.h | 2 +- .../tensorflowlite_env.cpp | 2 +- .../tensorflowlite_env.h | 2 +- .../tensorflowlite_func.cpp | 2 +- .../tensorflowlite_func.h | 2 +- .../tensorflowlite_module.cpp | 2 +- .../tensorflowlite_module.h | 2 +- plugins/wasmedge_zlib/CMakeLists.txt | 2 +- plugins/wasmedge_zlib/zlibbase.h | 2 +- plugins/wasmedge_zlib/zlibenv.cpp | 2 +- plugins/wasmedge_zlib/zlibenv.h | 13 +- plugins/wasmedge_zlib/zlibfunc.cpp | 2 +- plugins/wasmedge_zlib/zlibfunc.h | 2 +- plugins/wasmedge_zlib/zlibmodule.cpp | 2 +- plugins/wasmedge_zlib/zlibmodule.h | 2 +- test/plugins/CMakeLists.txt | 2 +- test/plugins/unittest/CMakeLists.txt | 2 +- test/plugins/unittest/testplugin.c | 2 +- test/plugins/unittest/testplugin.cpp | 2 +- test/plugins/unittest/testplugin.h | 2 +- test/plugins/unittest/unittest_c.cpp | 2 +- test/plugins/unittest/unittest_cpp.cpp | 2 +- test/plugins/wasi_crypto/CMakeLists.txt | 2 +- test/plugins/wasi_crypto/aeads.cpp | 2 +- test/plugins/wasi_crypto/asymmetric.cpp | 2 +- test/plugins/wasi_crypto/common.cpp | 2 +- test/plugins/wasi_crypto/hash.cpp | 2 +- test/plugins/wasi_crypto/helper.cpp | 2 +- test/plugins/wasi_crypto/helper.h | 2 +- test/plugins/wasi_crypto/kdf.cpp | 2 +- test/plugins/wasi_crypto/kx.cpp | 2 +- test/plugins/wasi_crypto/mac.cpp | 2 +- test/plugins/wasi_crypto/notimplement.cpp | 2 +- test/plugins/wasi_crypto/signatures.cpp | 2 +- test/plugins/wasi_logging/CMakeLists.txt | 2 +- test/plugins/wasi_nn/CMakeLists.txt | 2 +- test/plugins/wasi_nn/wasi_nn.cpp | 2 +- test/plugins/wasm_bpf/CMakeLists.txt | 2 +- .../assets/bpf-sources/simple_map.bpf.c | 2 +- .../assets/bpf-sources/simple_ringbuf.bpf.c | 2 +- test/plugins/wasm_bpf/simple_map_test.cpp | 2 +- test/plugins/wasm_bpf/simple_ringbuf_test.cpp | 2 +- test/plugins/wasm_bpf/wasm_bpf.cpp | 2 +- test/plugins/wasmedge_ffmpeg/CMakeLists.txt | 2 +- test/plugins/wasmedge_image/CMakeLists.txt | 2 +- .../plugins/wasmedge_image/wasmedge_image.cpp | 2 +- .../wasmedge_opencvmini/CMakeLists.txt | 2 +- .../wasmedge_opencvmini.cpp | 2 +- test/plugins/wasmedge_process/CMakeLists.txt | 2 +- .../wasmedge_process/wasmedge_process.cpp | 2 +- .../wasmedge_tensorflow/CMakeLists.txt | 2 +- .../wasmedge_tensorflow.cpp | 2 +- .../wasmedge_tensorflowlite/CMakeLists.txt | 2 +- .../wasmedge_tensorflowlite.cpp | 2 +- test/plugins/wasmedge_zlib/CMakeLists.txt | 2 +- test/plugins/wasmedge_zlib/wasmedge_zlib.cpp | 2 +- thirdparty/wasi_crypto/api.hpp | 2 +- utils/docker/Dockerfile.manylinux2014_aarch64 | 2 +- utils/docker/Dockerfile.manylinux2014_x86_64 | 2 +- utils/docker/Dockerfile.manylinux_2_28-base | 2 +- utils/docker/Dockerfile.ubuntu2104_armv7l | 2 +- utils/docker/build-manylinux.sh | 2 +- utils/docker/build.sh | 2 +- utils/ffmpeg/download-ffmpeg-sample-video.sh | 2 +- utils/opencvmini/install-opencvmini.sh | 2 +- utils/wasi-crypto/build-openssl.sh | 2 +- utils/wasi-nn/build-wasinn-ubuntu-openvino.sh | 2 +- utils/wasi-nn/install-neuralspeed.sh | 2 + utils/wasi-nn/install-onnxruntime.sh | 2 + utils/wasi-nn/install-openvino.sh | 2 +- utils/wasi-nn/install-pytorch.sh | 2 +- utils/wasi-nn/test-wasinn-ubuntu-openvino.sh | 4 + utils/wasi-test/run-wasi-test.sh | 2 +- 327 files changed, 864 insertions(+), 834 deletions(-) diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index f9ac516c..b7aab370 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC if(WASMEDGE_PLUGIN_WASI_HTTP) add_subdirectory(wasi_http) diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt index 70ae211e..f8e2b71f 100644 --- a/plugins/wasi_crypto/CMakeLists.txt +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC set(OPENSSL_USE_STATIC_LIBS ON) find_package(OpenSSL REQUIRED) diff --git a/plugins/wasi_crypto/asymmetric_common/ctx.cpp b/plugins/wasi_crypto/asymmetric_common/ctx.cpp index 551bd5cc..4e94a3fd 100644 --- a/plugins/wasi_crypto/asymmetric_common/ctx.cpp +++ b/plugins/wasi_crypto/asymmetric_common/ctx.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "ctx.h" diff --git a/plugins/wasi_crypto/asymmetric_common/ecdsa.h b/plugins/wasi_crypto/asymmetric_common/ecdsa.h index 5e170cf0..ee351b8a 100644 --- a/plugins/wasi_crypto/asymmetric_common/ecdsa.h +++ b/plugins/wasi_crypto/asymmetric_common/ecdsa.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/asymmetric_common/ecdsa.h - Ecdsa alg-===// // diff --git a/plugins/wasi_crypto/asymmetric_common/func.cpp b/plugins/wasi_crypto/asymmetric_common/func.cpp index 6f70f521..698f72a6 100644 --- a/plugins/wasi_crypto/asymmetric_common/func.cpp +++ b/plugins/wasi_crypto/asymmetric_common/func.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "asymmetric_common/func.h" diff --git a/plugins/wasi_crypto/asymmetric_common/func.h b/plugins/wasi_crypto/asymmetric_common/func.h index 1157751b..3c5411a5 100644 --- a/plugins/wasi_crypto/asymmetric_common/func.h +++ b/plugins/wasi_crypto/asymmetric_common/func.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/asymmetric_common/func.h -------------===// // diff --git a/plugins/wasi_crypto/asymmetric_common/keypair.cpp b/plugins/wasi_crypto/asymmetric_common/keypair.cpp index ec543c10..e3881524 100644 --- a/plugins/wasi_crypto/asymmetric_common/keypair.cpp +++ b/plugins/wasi_crypto/asymmetric_common/keypair.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "asymmetric_common/keypair.h" diff --git a/plugins/wasi_crypto/asymmetric_common/keypair.h b/plugins/wasi_crypto/asymmetric_common/keypair.h index 65a32172..0016b612 100644 --- a/plugins/wasi_crypto/asymmetric_common/keypair.h +++ b/plugins/wasi_crypto/asymmetric_common/keypair.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/asymmetric_common/keypair.h ----------===// // diff --git a/plugins/wasi_crypto/asymmetric_common/module.cpp b/plugins/wasi_crypto/asymmetric_common/module.cpp index 816f2232..ed3b3448 100644 --- a/plugins/wasi_crypto/asymmetric_common/module.cpp +++ b/plugins/wasi_crypto/asymmetric_common/module.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "asymmetric_common/module.h" #include "asymmetric_common/func.h" diff --git a/plugins/wasi_crypto/asymmetric_common/module.h b/plugins/wasi_crypto/asymmetric_common/module.h index 4e5281af..9aa39cf4 100644 --- a/plugins/wasi_crypto/asymmetric_common/module.h +++ b/plugins/wasi_crypto/asymmetric_common/module.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/asymmetric_common/module.h - Asym ----===// // diff --git a/plugins/wasi_crypto/asymmetric_common/publickey.cpp b/plugins/wasi_crypto/asymmetric_common/publickey.cpp index 5e9f1bae..2ee7d84a 100644 --- a/plugins/wasi_crypto/asymmetric_common/publickey.cpp +++ b/plugins/wasi_crypto/asymmetric_common/publickey.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "asymmetric_common/publickey.h" diff --git a/plugins/wasi_crypto/asymmetric_common/publickey.h b/plugins/wasi_crypto/asymmetric_common/publickey.h index 1e4915cb..5e50efdd 100644 --- a/plugins/wasi_crypto/asymmetric_common/publickey.h +++ b/plugins/wasi_crypto/asymmetric_common/publickey.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/asymmetric_common/publickey.h --------===// // diff --git a/plugins/wasi_crypto/asymmetric_common/registered.h b/plugins/wasi_crypto/asymmetric_common/registered.h index 989130ac..f74cb0a0 100644 --- a/plugins/wasi_crypto/asymmetric_common/registered.h +++ b/plugins/wasi_crypto/asymmetric_common/registered.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/asymmetric/registered.h - Registered -===// // diff --git a/plugins/wasi_crypto/asymmetric_common/secretkey.cpp b/plugins/wasi_crypto/asymmetric_common/secretkey.cpp index f76b2180..55f518e7 100644 --- a/plugins/wasi_crypto/asymmetric_common/secretkey.cpp +++ b/plugins/wasi_crypto/asymmetric_common/secretkey.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "asymmetric_common/secretkey.h" diff --git a/plugins/wasi_crypto/asymmetric_common/secretkey.h b/plugins/wasi_crypto/asymmetric_common/secretkey.h index 25115a8b..b9f3ec59 100644 --- a/plugins/wasi_crypto/asymmetric_common/secretkey.h +++ b/plugins/wasi_crypto/asymmetric_common/secretkey.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/asymmetric_common/secretkey.h --------===// // diff --git a/plugins/wasi_crypto/common/array_output.cpp b/plugins/wasi_crypto/common/array_output.cpp index 79ec82e7..8f25327e 100644 --- a/plugins/wasi_crypto/common/array_output.cpp +++ b/plugins/wasi_crypto/common/array_output.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/array_output.h" diff --git a/plugins/wasi_crypto/common/array_output.h b/plugins/wasi_crypto/common/array_output.h index 51238281..f5fef5c8 100644 --- a/plugins/wasi_crypto/common/array_output.h +++ b/plugins/wasi_crypto/common/array_output.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/common/array_output.h - ArrayOutput --===// // diff --git a/plugins/wasi_crypto/common/ctx.cpp b/plugins/wasi_crypto/common/ctx.cpp index 03996922..7c8e76b1 100644 --- a/plugins/wasi_crypto/common/ctx.cpp +++ b/plugins/wasi_crypto/common/ctx.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "ctx.h" #include "common/array_output.h" diff --git a/plugins/wasi_crypto/common/func.cpp b/plugins/wasi_crypto/common/func.cpp index e8dcc981..37bfb29f 100644 --- a/plugins/wasi_crypto/common/func.cpp +++ b/plugins/wasi_crypto/common/func.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/func.h" diff --git a/plugins/wasi_crypto/common/func.h b/plugins/wasi_crypto/common/func.h index d29a9f3e..e546672d 100644 --- a/plugins/wasi_crypto/common/func.h +++ b/plugins/wasi_crypto/common/func.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/common/func.h - Common func ----------===// // diff --git a/plugins/wasi_crypto/common/module.cpp b/plugins/wasi_crypto/common/module.cpp index bbb0107b..9cdfd46d 100644 --- a/plugins/wasi_crypto/common/module.cpp +++ b/plugins/wasi_crypto/common/module.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/module.h" #include "common/func.h" diff --git a/plugins/wasi_crypto/common/module.h b/plugins/wasi_crypto/common/module.h index cfa05726..75a1bc11 100644 --- a/plugins/wasi_crypto/common/module.h +++ b/plugins/wasi_crypto/common/module.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/common/module.h - Common Module ------===// // diff --git a/plugins/wasi_crypto/common/options.cpp b/plugins/wasi_crypto/common/options.cpp index f93f2fe0..e95aefa2 100644 --- a/plugins/wasi_crypto/common/options.cpp +++ b/plugins/wasi_crypto/common/options.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/options.h" diff --git a/plugins/wasi_crypto/common/options.h b/plugins/wasi_crypto/common/options.h index de22b64d..e07578c1 100644 --- a/plugins/wasi_crypto/common/options.h +++ b/plugins/wasi_crypto/common/options.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/common/options.h - Options definition ===// // diff --git a/plugins/wasi_crypto/ctx.cpp b/plugins/wasi_crypto/ctx.cpp index db2eefe1..7c7d6811 100644 --- a/plugins/wasi_crypto/ctx.cpp +++ b/plugins/wasi_crypto/ctx.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "ctx.h" #include "asymmetric_common/module.h" diff --git a/plugins/wasi_crypto/ctx.h b/plugins/wasi_crypto/ctx.h index 46a93d84..caa5bdf4 100644 --- a/plugins/wasi_crypto/ctx.h +++ b/plugins/wasi_crypto/ctx.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/ctx.h - Context class definition -----===// // diff --git a/plugins/wasi_crypto/kx/ctx.cpp b/plugins/wasi_crypto/kx/ctx.cpp index f5bb349b..86f7646a 100644 --- a/plugins/wasi_crypto/kx/ctx.cpp +++ b/plugins/wasi_crypto/kx/ctx.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "ctx.h" #include "kx/kx.h" diff --git a/plugins/wasi_crypto/kx/dh/ecdsa.cpp b/plugins/wasi_crypto/kx/dh/ecdsa.cpp index ef4536a6..a228cfe8 100644 --- a/plugins/wasi_crypto/kx/dh/ecdsa.cpp +++ b/plugins/wasi_crypto/kx/dh/ecdsa.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "kx/dh/ecdsa.h" diff --git a/plugins/wasi_crypto/kx/dh/ecdsa.h b/plugins/wasi_crypto/kx/dh/ecdsa.h index d15cbb3c..f41cd63b 100644 --- a/plugins/wasi_crypto/kx/dh/ecdsa.h +++ b/plugins/wasi_crypto/kx/dh/ecdsa.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/kx/dh/ecdsa.h - Ecdsa alg implement --===// // diff --git a/plugins/wasi_crypto/kx/dh/x25519.cpp b/plugins/wasi_crypto/kx/dh/x25519.cpp index a2460217..53d87a72 100644 --- a/plugins/wasi_crypto/kx/dh/x25519.cpp +++ b/plugins/wasi_crypto/kx/dh/x25519.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "kx/dh/x25519.h" diff --git a/plugins/wasi_crypto/kx/dh/x25519.h b/plugins/wasi_crypto/kx/dh/x25519.h index f598b026..806d6dd7 100644 --- a/plugins/wasi_crypto/kx/dh/x25519.h +++ b/plugins/wasi_crypto/kx/dh/x25519.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/kx/dh/x25519.h - X25519 alg implement ===// // diff --git a/plugins/wasi_crypto/kx/func.cpp b/plugins/wasi_crypto/kx/func.cpp index 0cdd305d..4649c278 100644 --- a/plugins/wasi_crypto/kx/func.cpp +++ b/plugins/wasi_crypto/kx/func.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "kx/func.h" diff --git a/plugins/wasi_crypto/kx/func.h b/plugins/wasi_crypto/kx/func.h index c0f83789..3f1040c0 100644 --- a/plugins/wasi_crypto/kx/func.h +++ b/plugins/wasi_crypto/kx/func.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/kx/func.h - Key Exchange funcs -------===// // diff --git a/plugins/wasi_crypto/kx/kx.cpp b/plugins/wasi_crypto/kx/kx.cpp index cd785fc5..79d98e9a 100644 --- a/plugins/wasi_crypto/kx/kx.cpp +++ b/plugins/wasi_crypto/kx/kx.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "kx/kx.h" diff --git a/plugins/wasi_crypto/kx/kx.h b/plugins/wasi_crypto/kx/kx.h index eba5a4c3..5004ebe9 100644 --- a/plugins/wasi_crypto/kx/kx.h +++ b/plugins/wasi_crypto/kx/kx.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/kx/kx.h - Key Exchange related -------===// // diff --git a/plugins/wasi_crypto/kx/module.cpp b/plugins/wasi_crypto/kx/module.cpp index 38f35cf6..12dff877 100644 --- a/plugins/wasi_crypto/kx/module.cpp +++ b/plugins/wasi_crypto/kx/module.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "kx/module.h" #include "kx/func.h" diff --git a/plugins/wasi_crypto/kx/module.h b/plugins/wasi_crypto/kx/module.h index 85992f0a..aa3ea512 100644 --- a/plugins/wasi_crypto/kx/module.h +++ b/plugins/wasi_crypto/kx/module.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/kx/module.h - Kx Module --------------===// // diff --git a/plugins/wasi_crypto/kx/options.cpp b/plugins/wasi_crypto/kx/options.cpp index da383199..c2612e24 100644 --- a/plugins/wasi_crypto/kx/options.cpp +++ b/plugins/wasi_crypto/kx/options.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "kx/options.h" diff --git a/plugins/wasi_crypto/kx/options.h b/plugins/wasi_crypto/kx/options.h index 4074579c..83fb85d3 100644 --- a/plugins/wasi_crypto/kx/options.h +++ b/plugins/wasi_crypto/kx/options.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/kx/options.h - Key exchange Options --===// // diff --git a/plugins/wasi_crypto/kx/registered.h b/plugins/wasi_crypto/kx/registered.h index 655fc392..6020f3b9 100644 --- a/plugins/wasi_crypto/kx/registered.h +++ b/plugins/wasi_crypto/kx/registered.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/kx/registered.h - Registered ---------===// // diff --git a/plugins/wasi_crypto/signatures/ctx.cpp b/plugins/wasi_crypto/signatures/ctx.cpp index 17070d44..ea85c965 100644 --- a/plugins/wasi_crypto/signatures/ctx.cpp +++ b/plugins/wasi_crypto/signatures/ctx.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "ctx.h" #include "signatures/signatures.h" diff --git a/plugins/wasi_crypto/signatures/ecdsa.cpp b/plugins/wasi_crypto/signatures/ecdsa.cpp index f295910c..e7b2162e 100644 --- a/plugins/wasi_crypto/signatures/ecdsa.cpp +++ b/plugins/wasi_crypto/signatures/ecdsa.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "signatures/ecdsa.h" diff --git a/plugins/wasi_crypto/signatures/ecdsa.h b/plugins/wasi_crypto/signatures/ecdsa.h index 4c8e373e..41c0d27c 100644 --- a/plugins/wasi_crypto/signatures/ecdsa.h +++ b/plugins/wasi_crypto/signatures/ecdsa.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/signatures/ecdsa.h - Ecdsa alg -------===// // diff --git a/plugins/wasi_crypto/signatures/eddsa.cpp b/plugins/wasi_crypto/signatures/eddsa.cpp index 820204b8..d1a44f24 100644 --- a/plugins/wasi_crypto/signatures/eddsa.cpp +++ b/plugins/wasi_crypto/signatures/eddsa.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "signatures/eddsa.h" diff --git a/plugins/wasi_crypto/signatures/eddsa.h b/plugins/wasi_crypto/signatures/eddsa.h index 70a17538..99674408 100644 --- a/plugins/wasi_crypto/signatures/eddsa.h +++ b/plugins/wasi_crypto/signatures/eddsa.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/signatures/eddsa.h - Eddsa alg -------===// // diff --git a/plugins/wasi_crypto/signatures/func.cpp b/plugins/wasi_crypto/signatures/func.cpp index 8a5b6510..18e6905c 100644 --- a/plugins/wasi_crypto/signatures/func.cpp +++ b/plugins/wasi_crypto/signatures/func.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "signatures/func.h" diff --git a/plugins/wasi_crypto/signatures/func.h b/plugins/wasi_crypto/signatures/func.h index 56683080..f1b8b8a7 100644 --- a/plugins/wasi_crypto/signatures/func.h +++ b/plugins/wasi_crypto/signatures/func.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/signatures/func.h - Signatures func --===// // diff --git a/plugins/wasi_crypto/signatures/module.cpp b/plugins/wasi_crypto/signatures/module.cpp index 44a3fef7..36d08b2a 100644 --- a/plugins/wasi_crypto/signatures/module.cpp +++ b/plugins/wasi_crypto/signatures/module.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "module.h" #include "asymmetric_common/func.h" diff --git a/plugins/wasi_crypto/signatures/module.h b/plugins/wasi_crypto/signatures/module.h index 71e8540f..1296fff4 100644 --- a/plugins/wasi_crypto/signatures/module.h +++ b/plugins/wasi_crypto/signatures/module.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/signatures/module.h - Module ---------===// // diff --git a/plugins/wasi_crypto/signatures/options.cpp b/plugins/wasi_crypto/signatures/options.cpp index 761ca439..d4de6cf3 100644 --- a/plugins/wasi_crypto/signatures/options.cpp +++ b/plugins/wasi_crypto/signatures/options.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "signatures/options.h" diff --git a/plugins/wasi_crypto/signatures/options.h b/plugins/wasi_crypto/signatures/options.h index a229dc9b..895c5d9f 100644 --- a/plugins/wasi_crypto/signatures/options.h +++ b/plugins/wasi_crypto/signatures/options.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/signatures/options.h - Options -------===// // diff --git a/plugins/wasi_crypto/signatures/registered.h b/plugins/wasi_crypto/signatures/registered.h index 5d436ea3..06dcd0d0 100644 --- a/plugins/wasi_crypto/signatures/registered.h +++ b/plugins/wasi_crypto/signatures/registered.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/signatures/registered.h - Registered //-----===// diff --git a/plugins/wasi_crypto/signatures/rsa.cpp b/plugins/wasi_crypto/signatures/rsa.cpp index 6e7fc8f0..245fe9ed 100644 --- a/plugins/wasi_crypto/signatures/rsa.cpp +++ b/plugins/wasi_crypto/signatures/rsa.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "signatures/rsa.h" diff --git a/plugins/wasi_crypto/signatures/rsa.h b/plugins/wasi_crypto/signatures/rsa.h index 85c9b312..878324ae 100644 --- a/plugins/wasi_crypto/signatures/rsa.h +++ b/plugins/wasi_crypto/signatures/rsa.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/signatures/rsa.h - Rsa alg implement -===// // diff --git a/plugins/wasi_crypto/signatures/signatures.cpp b/plugins/wasi_crypto/signatures/signatures.cpp index 29fae93a..ae3b003e 100644 --- a/plugins/wasi_crypto/signatures/signatures.cpp +++ b/plugins/wasi_crypto/signatures/signatures.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "signatures/signatures.h" diff --git a/plugins/wasi_crypto/signatures/signatures.h b/plugins/wasi_crypto/signatures/signatures.h index 6e548c0f..d2bd2bbb 100644 --- a/plugins/wasi_crypto/signatures/signatures.h +++ b/plugins/wasi_crypto/signatures/signatures.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/signatures/signatures.h - Signatures -===// // diff --git a/plugins/wasi_crypto/signatures/signstate.cpp b/plugins/wasi_crypto/signatures/signstate.cpp index bf12ddbc..a8c45707 100644 --- a/plugins/wasi_crypto/signatures/signstate.cpp +++ b/plugins/wasi_crypto/signatures/signstate.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "signatures/signstate.h" diff --git a/plugins/wasi_crypto/signatures/signstate.h b/plugins/wasi_crypto/signatures/signstate.h index 02fcd7a5..5ed5f61d 100644 --- a/plugins/wasi_crypto/signatures/signstate.h +++ b/plugins/wasi_crypto/signatures/signstate.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/signatures/signstate.h - SignState ---===// // diff --git a/plugins/wasi_crypto/signatures/verificationstate.cpp b/plugins/wasi_crypto/signatures/verificationstate.cpp index 547bbdb8..3f4593ef 100644 --- a/plugins/wasi_crypto/signatures/verificationstate.cpp +++ b/plugins/wasi_crypto/signatures/verificationstate.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "signatures/verificationstate.h" diff --git a/plugins/wasi_crypto/signatures/verificationstate.h b/plugins/wasi_crypto/signatures/verificationstate.h index 3e326daa..81e612f1 100644 --- a/plugins/wasi_crypto/signatures/verificationstate.h +++ b/plugins/wasi_crypto/signatures/verificationstate.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/signatures/verificationstate.h -------===// // diff --git a/plugins/wasi_crypto/symmetric/aeads.cpp b/plugins/wasi_crypto/symmetric/aeads.cpp index 08ba7658..15f7d127 100644 --- a/plugins/wasi_crypto/symmetric/aeads.cpp +++ b/plugins/wasi_crypto/symmetric/aeads.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "symmetric/aeads.h" #include "utils/error.h" diff --git a/plugins/wasi_crypto/symmetric/aeads.h b/plugins/wasi_crypto/symmetric/aeads.h index 96d6b245..4e580485 100644 --- a/plugins/wasi_crypto/symmetric/aeads.h +++ b/plugins/wasi_crypto/symmetric/aeads.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/symmetric/aeads.h - Aeads related ----===// // diff --git a/plugins/wasi_crypto/symmetric/ctx.cpp b/plugins/wasi_crypto/symmetric/ctx.cpp index 4afa2351..54263417 100644 --- a/plugins/wasi_crypto/symmetric/ctx.cpp +++ b/plugins/wasi_crypto/symmetric/ctx.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "ctx.h" #include "symmetric/key.h" diff --git a/plugins/wasi_crypto/symmetric/func.cpp b/plugins/wasi_crypto/symmetric/func.cpp index 47abba79..24ceb43a 100644 --- a/plugins/wasi_crypto/symmetric/func.cpp +++ b/plugins/wasi_crypto/symmetric/func.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "symmetric/func.h" diff --git a/plugins/wasi_crypto/symmetric/func.h b/plugins/wasi_crypto/symmetric/func.h index a122784f..eded49a3 100644 --- a/plugins/wasi_crypto/symmetric/func.h +++ b/plugins/wasi_crypto/symmetric/func.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/symmetric/func.h - Symmetric funcs ---===// // diff --git a/plugins/wasi_crypto/symmetric/hash.cpp b/plugins/wasi_crypto/symmetric/hash.cpp index ea2129b5..2e80f856 100644 --- a/plugins/wasi_crypto/symmetric/hash.cpp +++ b/plugins/wasi_crypto/symmetric/hash.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "symmetric/hash.h" #include "utils/evp_wrapper.h" diff --git a/plugins/wasi_crypto/symmetric/hash.h b/plugins/wasi_crypto/symmetric/hash.h index 66934da5..59257a42 100644 --- a/plugins/wasi_crypto/symmetric/hash.h +++ b/plugins/wasi_crypto/symmetric/hash.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/symmetric/hash.h - Hash related ------===// // diff --git a/plugins/wasi_crypto/symmetric/kdf.h b/plugins/wasi_crypto/symmetric/kdf.h index 73c8897c..4834f385 100644 --- a/plugins/wasi_crypto/symmetric/kdf.h +++ b/plugins/wasi_crypto/symmetric/kdf.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/symmetric/kdf.h - Kdf related --------===// // diff --git a/plugins/wasi_crypto/symmetric/key.cpp b/plugins/wasi_crypto/symmetric/key.cpp index 33797042..970d1ace 100644 --- a/plugins/wasi_crypto/symmetric/key.cpp +++ b/plugins/wasi_crypto/symmetric/key.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "symmetric/key.h" #include "utils/error.h" diff --git a/plugins/wasi_crypto/symmetric/key.h b/plugins/wasi_crypto/symmetric/key.h index bf2d8745..d7e4aa6d 100644 --- a/plugins/wasi_crypto/symmetric/key.h +++ b/plugins/wasi_crypto/symmetric/key.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/symmetric/key.h - Symmetric Key class ===// // diff --git a/plugins/wasi_crypto/symmetric/mac.cpp b/plugins/wasi_crypto/symmetric/mac.cpp index 53c7e2ee..2b0bab77 100644 --- a/plugins/wasi_crypto/symmetric/mac.cpp +++ b/plugins/wasi_crypto/symmetric/mac.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "symmetric/mac.h" #include "utils/secret_vec.h" diff --git a/plugins/wasi_crypto/symmetric/mac.h b/plugins/wasi_crypto/symmetric/mac.h index f1247b95..57fa5d43 100644 --- a/plugins/wasi_crypto/symmetric/mac.h +++ b/plugins/wasi_crypto/symmetric/mac.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/symmetric/mac.h - Mac related --------===// // diff --git a/plugins/wasi_crypto/symmetric/module.cpp b/plugins/wasi_crypto/symmetric/module.cpp index 402c4b19..01540cf8 100644 --- a/plugins/wasi_crypto/symmetric/module.cpp +++ b/plugins/wasi_crypto/symmetric/module.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "symmetric/module.h" #include "symmetric/func.h" diff --git a/plugins/wasi_crypto/symmetric/module.h b/plugins/wasi_crypto/symmetric/module.h index af4f23d4..14052359 100644 --- a/plugins/wasi_crypto/symmetric/module.h +++ b/plugins/wasi_crypto/symmetric/module.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/symmetric/module.h - Module ----------===// // diff --git a/plugins/wasi_crypto/symmetric/options.cpp b/plugins/wasi_crypto/symmetric/options.cpp index 923d4cc8..ae2b3c5f 100644 --- a/plugins/wasi_crypto/symmetric/options.cpp +++ b/plugins/wasi_crypto/symmetric/options.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "symmetric/options.h" diff --git a/plugins/wasi_crypto/symmetric/options.h b/plugins/wasi_crypto/symmetric/options.h index 9d0a6ec3..7e7d1bd1 100644 --- a/plugins/wasi_crypto/symmetric/options.h +++ b/plugins/wasi_crypto/symmetric/options.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/symmetric/options.h - Options --------===// // diff --git a/plugins/wasi_crypto/symmetric/registered.h b/plugins/wasi_crypto/symmetric/registered.h index 0785fc6c..2bf5a622 100644 --- a/plugins/wasi_crypto/symmetric/registered.h +++ b/plugins/wasi_crypto/symmetric/registered.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/symmetric/registered.h - Registered --===// // diff --git a/plugins/wasi_crypto/symmetric/state.cpp b/plugins/wasi_crypto/symmetric/state.cpp index 77852f58..c759d219 100644 --- a/plugins/wasi_crypto/symmetric/state.cpp +++ b/plugins/wasi_crypto/symmetric/state.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "symmetric/state.h" #include "utils/error.h" diff --git a/plugins/wasi_crypto/symmetric/state.h b/plugins/wasi_crypto/symmetric/state.h index d567cf2e..9759753e 100644 --- a/plugins/wasi_crypto/symmetric/state.h +++ b/plugins/wasi_crypto/symmetric/state.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/symmetric/state.h - Symmetric State --===// // diff --git a/plugins/wasi_crypto/symmetric/tag.cpp b/plugins/wasi_crypto/symmetric/tag.cpp index 2fee4fb9..3998aee2 100644 --- a/plugins/wasi_crypto/symmetric/tag.cpp +++ b/plugins/wasi_crypto/symmetric/tag.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "symmetric/tag.h" diff --git a/plugins/wasi_crypto/symmetric/tag.h b/plugins/wasi_crypto/symmetric/tag.h index 5d2f3fad..9daa458b 100644 --- a/plugins/wasi_crypto/symmetric/tag.h +++ b/plugins/wasi_crypto/symmetric/tag.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/symmetric/tag.h - Symmetric Tag class ===// // diff --git a/plugins/wasi_crypto/utils/error.h b/plugins/wasi_crypto/utils/error.h index ba1bc1fc..24018755 100644 --- a/plugins/wasi_crypto/utils/error.h +++ b/plugins/wasi_crypto/utils/error.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/utils/error.h - Error definition -----===// // diff --git a/plugins/wasi_crypto/utils/evp_wrapper.cpp b/plugins/wasi_crypto/utils/evp_wrapper.cpp index e5badb14..a28f9665 100644 --- a/plugins/wasi_crypto/utils/evp_wrapper.cpp +++ b/plugins/wasi_crypto/utils/evp_wrapper.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "utils/evp_wrapper.h" #include "utils/error.h" diff --git a/plugins/wasi_crypto/utils/evp_wrapper.h b/plugins/wasi_crypto/utils/evp_wrapper.h index d6edc89b..893fb0e5 100644 --- a/plugins/wasi_crypto/utils/evp_wrapper.h +++ b/plugins/wasi_crypto/utils/evp_wrapper.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/utils/evp_wrapper.h - Evp Wrapper ----===// // diff --git a/plugins/wasi_crypto/utils/handles_manager.h b/plugins/wasi_crypto/utils/handles_manager.h index 2282a283..8eb45603 100644 --- a/plugins/wasi_crypto/utils/handles_manager.h +++ b/plugins/wasi_crypto/utils/handles_manager.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/utils/handles_manager.h --------------===// // diff --git a/plugins/wasi_crypto/utils/hostfunction.cpp b/plugins/wasi_crypto/utils/hostfunction.cpp index b07c2be2..81c078a9 100644 --- a/plugins/wasi_crypto/utils/hostfunction.cpp +++ b/plugins/wasi_crypto/utils/hostfunction.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "utils/hostfunction.h" diff --git a/plugins/wasi_crypto/utils/hostfunction.h b/plugins/wasi_crypto/utils/hostfunction.h index bbe58859..460130eb 100644 --- a/plugins/wasi_crypto/utils/hostfunction.h +++ b/plugins/wasi_crypto/utils/hostfunction.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/hostfunc.h - HostFunction class ------===// // diff --git a/plugins/wasi_crypto/utils/optional.h b/plugins/wasi_crypto/utils/optional.h index ea8e93f6..85d06ad6 100644 --- a/plugins/wasi_crypto/utils/optional.h +++ b/plugins/wasi_crypto/utils/optional.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/utils/handles_manager.h - OptionalRef ===// // diff --git a/plugins/wasi_crypto/utils/secret_vec.h b/plugins/wasi_crypto/utils/secret_vec.h index a66bed9f..9e2d2f1b 100644 --- a/plugins/wasi_crypto/utils/secret_vec.h +++ b/plugins/wasi_crypto/utils/secret_vec.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC //===-- wasmedge/plugins/wasi_crypto/utils/secret_vec.h - Secret Vec def --===// // diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 9c6020b3..4878108c 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_library(wasmedgePluginWasiNN SHARED diff --git a/plugins/wasi_nn/chattts.cpp b/plugins/wasi_nn/chattts.cpp index 8b8a2a04..ecd4f99e 100644 --- a/plugins/wasi_nn/chattts.cpp +++ b/plugins/wasi_nn/chattts.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "chattts.h" #include "wasinnenv.h" @@ -357,4 +357,4 @@ Expect unload(WASINN::WasiNNEnvironment &, uint32_t) noexcept { } #endif -} // namespace WasmEdge::Host::WASINN::ChatTTS \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::ChatTTS diff --git a/plugins/wasi_nn/chattts.h b/plugins/wasi_nn/chattts.h index 84676e25..46cab0a3 100644 --- a/plugins/wasi_nn/chattts.h +++ b/plugins/wasi_nn/chattts.h @@ -1,11 +1,16 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "plugin/plugin.h" #include "types.h" + #include #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS #include #endif + namespace WasmEdge::Host::WASINN { struct WasiNNEnvironment; } @@ -58,4 +63,4 @@ Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; Expect unload(WASINN::WasiNNEnvironment &Env, uint32_t GraphId) noexcept; -} // namespace WasmEdge::Host::WASINN::ChatTTS \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::ChatTTS diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index c91ea2b7..d196f944 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "ggml.h" #include "wasinnenv.h" diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/ggml.h index 76890216..b29dcd0b 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/ggml.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index e670b342..ef4f11eb 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "neuralspeed.h" #include "wasinnenv.h" diff --git a/plugins/wasi_nn/neuralspeed.h b/plugins/wasi_nn/neuralspeed.h index fcdf6a5a..7893aa74 100644 --- a/plugins/wasi_nn/neuralspeed.h +++ b/plugins/wasi_nn/neuralspeed.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_nn/onnx.cpp b/plugins/wasi_nn/onnx.cpp index 3fa4b326..7b7cb6f9 100644 --- a/plugins/wasi_nn/onnx.cpp +++ b/plugins/wasi_nn/onnx.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "onnx.h" #include "wasinnenv.h" diff --git a/plugins/wasi_nn/onnx.h b/plugins/wasi_nn/onnx.h index e5aaff1a..6d46e02b 100644 --- a/plugins/wasi_nn/onnx.h +++ b/plugins/wasi_nn/onnx.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_nn/openvino.cpp b/plugins/wasi_nn/openvino.cpp index 9c0625ac..72fcb70c 100644 --- a/plugins/wasi_nn/openvino.cpp +++ b/plugins/wasi_nn/openvino.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "openvino.h" #include "wasinnenv.h" diff --git a/plugins/wasi_nn/openvino.h b/plugins/wasi_nn/openvino.h index 0ae3a2a7..b3616666 100644 --- a/plugins/wasi_nn/openvino.h +++ b/plugins/wasi_nn/openvino.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_nn/piper.cpp b/plugins/wasi_nn/piper.cpp index fc0dce77..c764c882 100644 --- a/plugins/wasi_nn/piper.cpp +++ b/plugins/wasi_nn/piper.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "piper.h" #include "wasinnenv.h" diff --git a/plugins/wasi_nn/piper.h b/plugins/wasi_nn/piper.h index 70b4b19b..5ff6c43c 100644 --- a/plugins/wasi_nn/piper.h +++ b/plugins/wasi_nn/piper.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_nn/tf.cpp b/plugins/wasi_nn/tf.cpp index 353d1cf7..6dd02ef3 100644 --- a/plugins/wasi_nn/tf.cpp +++ b/plugins/wasi_nn/tf.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "tf.h" #include "wasinnenv.h" diff --git a/plugins/wasi_nn/tf.h b/plugins/wasi_nn/tf.h index 20e38b59..6f822af4 100644 --- a/plugins/wasi_nn/tf.h +++ b/plugins/wasi_nn/tf.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_nn/tfl.cpp b/plugins/wasi_nn/tfl.cpp index 32d3c5f3..28e1a6a9 100644 --- a/plugins/wasi_nn/tfl.cpp +++ b/plugins/wasi_nn/tfl.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "tfl.h" #include "wasinnenv.h" diff --git a/plugins/wasi_nn/tfl.h b/plugins/wasi_nn/tfl.h index 84d8227a..451fcf6d 100644 --- a/plugins/wasi_nn/tfl.h +++ b/plugins/wasi_nn/tfl.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_nn/torch.cpp b/plugins/wasi_nn/torch.cpp index 153ed32f..7470ce34 100644 --- a/plugins/wasi_nn/torch.cpp +++ b/plugins/wasi_nn/torch.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "torch.h" #include "wasinnenv.h" diff --git a/plugins/wasi_nn/torch.h b/plugins/wasi_nn/torch.h index 8961efc8..fa480cfb 100644 --- a/plugins/wasi_nn/torch.h +++ b/plugins/wasi_nn/torch.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index c8d74cca..12eede09 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -1,9 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once + #include "common/span.h" #include "common/spdlog.h" + #include namespace WasmEdge::Host::WASINN { diff --git a/plugins/wasi_nn/wasinnbase.h b/plugins/wasi_nn/wasinnbase.h index 4568dc76..00898003 100644 --- a/plugins/wasi_nn/wasinnbase.h +++ b/plugins/wasi_nn/wasinnbase.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index b33f8255..331e4963 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "wasinnenv.h" #include "types.h" diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index da952fe2..6ec0077f 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 9af5fc96..d9fb5e69 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "wasinnfunc.h" #include "common/spdlog.h" diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h index dcda910c..b955a300 100644 --- a/plugins/wasi_nn/wasinnfunc.h +++ b/plugins/wasi_nn/wasinnfunc.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_nn/wasinnmodule.cpp b/plugins/wasi_nn/wasinnmodule.cpp index b2b2c6d5..6226f89d 100644 --- a/plugins/wasi_nn/wasinnmodule.cpp +++ b/plugins/wasi_nn/wasinnmodule.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "wasinnmodule.h" #include "wasinnfunc.h" diff --git a/plugins/wasi_nn/wasinnmodule.h b/plugins/wasi_nn/wasinnmodule.h index 0c18bd16..87b0fd4e 100644 --- a/plugins/wasi_nn/wasinnmodule.h +++ b/plugins/wasi_nn/wasinnmodule.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_nn/whispercpp.cpp b/plugins/wasi_nn/whispercpp.cpp index 21632e30..bee0efff 100644 --- a/plugins/wasi_nn/whispercpp.cpp +++ b/plugins/wasi_nn/whispercpp.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "whispercpp.h" #include "wasinnenv.h" diff --git a/plugins/wasi_nn/whispercpp.h b/plugins/wasi_nn/whispercpp.h index 69b8c038..c5d3631e 100644 --- a/plugins/wasi_nn/whispercpp.h +++ b/plugins/wasi_nn/whispercpp.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasi_ocr/CMakeLists.txt b/plugins/wasi_ocr/CMakeLists.txt index b2eef6d3..de9402e9 100644 --- a/plugins/wasi_ocr/CMakeLists.txt +++ b/plugins/wasi_ocr/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC add_library(wasmedgePluginWasiOCR SHARED diff --git a/plugins/wasi_poll/func.h b/plugins/wasi_poll/func.h index a0c1c7d8..4daa13bc 100644 --- a/plugins/wasi_poll/func.h +++ b/plugins/wasi_poll/func.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2023 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasm_bpf/CMakeLists.txt b/plugins/wasm_bpf/CMakeLists.txt index 0e4e8120..924b898a 100644 --- a/plugins/wasm_bpf/CMakeLists.txt +++ b/plugins/wasm_bpf/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC # Try to get libbpf use the following order # - PkgConfig diff --git a/plugins/wasm_bpf/bpf-api.h b/plugins/wasm_bpf/bpf-api.h index 3b2018fc..1970fe69 100644 --- a/plugins/wasm_bpf/bpf-api.h +++ b/plugins/wasm_bpf/bpf-api.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasm_bpf/func-attach-bpf-program.cpp b/plugins/wasm_bpf/func-attach-bpf-program.cpp index be4d6ade..27e2e0d8 100644 --- a/plugins/wasm_bpf/func-attach-bpf-program.cpp +++ b/plugins/wasm_bpf/func-attach-bpf-program.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "func-attach-bpf-program.h" #include "util.h" diff --git a/plugins/wasm_bpf/func-attach-bpf-program.h b/plugins/wasm_bpf/func-attach-bpf-program.h index df5abedd..5eda9088 100644 --- a/plugins/wasm_bpf/func-attach-bpf-program.h +++ b/plugins/wasm_bpf/func-attach-bpf-program.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasm_bpf/func-bpf-buffer-poll.cpp b/plugins/wasm_bpf/func-bpf-buffer-poll.cpp index 45b6e5ca..7fb8ce03 100644 --- a/plugins/wasm_bpf/func-bpf-buffer-poll.cpp +++ b/plugins/wasm_bpf/func-bpf-buffer-poll.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "func-bpf-buffer-poll.h" #include "wasmedge/wasmedge.h" @@ -12,6 +12,7 @@ namespace Host { inline const auto *toCallFrameCxt(const Runtime::CallingFrame *Cxt) noexcept { return reinterpret_cast(Cxt); } + Expect BpfBufferPoll::body(const Runtime::CallingFrame &Frame, handle_t program, int32_t fd, int32_t sample_func, uint32_t ctx, diff --git a/plugins/wasm_bpf/func-bpf-buffer-poll.h b/plugins/wasm_bpf/func-bpf-buffer-poll.h index 2ece1588..ff33f82e 100644 --- a/plugins/wasm_bpf/func-bpf-buffer-poll.h +++ b/plugins/wasm_bpf/func-bpf-buffer-poll.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp b/plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp index c7a98eef..6f6e7172 100644 --- a/plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp +++ b/plugins/wasm_bpf/func-bpf-map-fd-by-name.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "func-bpf-map-fd-by-name.h" #include "util.h" diff --git a/plugins/wasm_bpf/func-bpf-map-fd-by-name.h b/plugins/wasm_bpf/func-bpf-map-fd-by-name.h index d15f67db..f4d93e1e 100644 --- a/plugins/wasm_bpf/func-bpf-map-fd-by-name.h +++ b/plugins/wasm_bpf/func-bpf-map-fd-by-name.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasm_bpf/func-bpf-map-operate.cpp b/plugins/wasm_bpf/func-bpf-map-operate.cpp index e1dbc242..73b1cd6f 100644 --- a/plugins/wasm_bpf/func-bpf-map-operate.cpp +++ b/plugins/wasm_bpf/func-bpf-map-operate.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "func-bpf-map-operate.h" #include "bpf-api.h" @@ -16,6 +16,7 @@ namespace Host { if (var##_span.size() != expected_size) \ return Unexpect(ErrCode::Value::HostFuncError); \ const auto var = var##_span.data(); + Expect BpfMapOperate::body(const WasmEdge::Runtime::CallingFrame &Frame, int32_t fd, int32_t cmd, uint32_t key, uint32_t value, diff --git a/plugins/wasm_bpf/func-bpf-map-operate.h b/plugins/wasm_bpf/func-bpf-map-operate.h index 259f28de..486f7d9b 100644 --- a/plugins/wasm_bpf/func-bpf-map-operate.h +++ b/plugins/wasm_bpf/func-bpf-map-operate.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "bpf-api.h" #include "plugin/plugin.h" diff --git a/plugins/wasm_bpf/func-close-bpf-object.cpp b/plugins/wasm_bpf/func-close-bpf-object.cpp index 909b4f0f..52b428fb 100644 --- a/plugins/wasm_bpf/func-close-bpf-object.cpp +++ b/plugins/wasm_bpf/func-close-bpf-object.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "func-close-bpf-object.h" #include diff --git a/plugins/wasm_bpf/func-close-bpf-object.h b/plugins/wasm_bpf/func-close-bpf-object.h index 3c441238..1d8d8c12 100644 --- a/plugins/wasm_bpf/func-close-bpf-object.h +++ b/plugins/wasm_bpf/func-close-bpf-object.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasm_bpf/func-load-bpf-object.cpp b/plugins/wasm_bpf/func-load-bpf-object.cpp index 99e5f72e..c9c7b682 100644 --- a/plugins/wasm_bpf/func-load-bpf-object.cpp +++ b/plugins/wasm_bpf/func-load-bpf-object.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "func-load-bpf-object.h" diff --git a/plugins/wasm_bpf/func-load-bpf-object.h b/plugins/wasm_bpf/func-load-bpf-object.h index 5a6226ad..39516908 100644 --- a/plugins/wasm_bpf/func-load-bpf-object.h +++ b/plugins/wasm_bpf/func-load-bpf-object.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasm_bpf/state.h b/plugins/wasm_bpf/state.h index 7ea995b8..cd9d4720 100644 --- a/plugins/wasm_bpf/state.h +++ b/plugins/wasm_bpf/state.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasm_bpf/util.cpp b/plugins/wasm_bpf/util.cpp index 9851c2ac..b3c45b7b 100644 --- a/plugins/wasm_bpf/util.cpp +++ b/plugins/wasm_bpf/util.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "util.h" diff --git a/plugins/wasm_bpf/util.h b/plugins/wasm_bpf/util.h index 958290ff..4ac79e44 100644 --- a/plugins/wasm_bpf/util.h +++ b/plugins/wasm_bpf/util.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasm_bpf/wasm-bpf-module.cpp b/plugins/wasm_bpf/wasm-bpf-module.cpp index 6118f6ff..8e564a6d 100644 --- a/plugins/wasm_bpf/wasm-bpf-module.cpp +++ b/plugins/wasm_bpf/wasm-bpf-module.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "wasm-bpf-module.h" #include "func-attach-bpf-program.h" diff --git a/plugins/wasm_bpf/wasm-bpf-module.h b/plugins/wasm_bpf/wasm-bpf-module.h index 8cbdd9dd..65eba255 100644 --- a/plugins/wasm_bpf/wasm-bpf-module.h +++ b/plugins/wasm_bpf/wasm-bpf-module.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasm_bpf/wasm-bpf.cpp b/plugins/wasm_bpf/wasm-bpf.cpp index 9adb55eb..3c60665a 100644 --- a/plugins/wasm_bpf/wasm-bpf.cpp +++ b/plugins/wasm_bpf/wasm-bpf.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include #include diff --git a/plugins/wasmedge_ffmpeg/CMakeLists.txt b/plugins/wasmedge_ffmpeg/CMakeLists.txt index e78dcaa8..4895100d 100644 --- a/plugins/wasmedge_ffmpeg/CMakeLists.txt +++ b/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC find_package(PkgConfig REQUIRED) pkg_check_modules(LIBAV REQUIRED IMPORTED_TARGET diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp b/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp index 242ccab0..3ba24d47 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avCodec.h" extern "C" { @@ -11,35 +14,30 @@ namespace AVcodec { Expect AVCodecID::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); return FFmpegUtils::CodecID::fromAVCodecID(AvCodec->id); } Expect AVCodecType::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); return FFmpegUtils::MediaType::fromMediaType(AvCodec->type); } Expect AVCodecMaxLowres::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); return AvCodec->max_lowres; } Expect AVCodecCapabilities::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); return AvCodec->capabilities; } Expect AVCodecGetNameLen::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); return strlen(AvCodec->name); } @@ -47,7 +45,6 @@ Expect AVCodecGetNameLen::body(const Runtime::CallingFrame &, Expect AVCodecGetName::body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t NamePtr, uint32_t NameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); @@ -59,7 +56,6 @@ Expect AVCodecGetName::body(const Runtime::CallingFrame &Frame, Expect AVCodecGetLongNameLen::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); return strlen(AvCodec->long_name); } @@ -68,7 +64,6 @@ Expect AVCodecGetLongName::body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t LongNamePtr, uint32_t LongNameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(LongNameBuf, MemInst, char, LongNamePtr, LongNameLen, ""); FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); @@ -80,29 +75,29 @@ Expect AVCodecGetLongName::body(const Runtime::CallingFrame &Frame, Expect AVCodecProfiles::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); - if (AvCodec->profiles) + if (AvCodec->profiles) { return 1; + } return 0; } Expect AVCodecPixFmtsIsNull::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); - if (AvCodec->pix_fmts == nullptr) + if (AvCodec->pix_fmts == nullptr) { return 1; + } return 0; } Expect AVCodecPixFmtsIter::body(const Runtime::CallingFrame &, uint32_t AvCodecId, uint32_t Idx) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); AVPixelFormat const *PixelFormat = AvCodec->pix_fmts; - if (PixelFormat == nullptr) + if (PixelFormat == nullptr) { return 0; + } uint32_t Curr = 0; while (Curr < Idx) { @@ -116,10 +111,10 @@ Expect AVCodecPixFmtsIter::body(const Runtime::CallingFrame &, Expect AVCodecSupportedFrameratesIsNull::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); - if (AvCodec->supported_framerates == nullptr) + if (AvCodec->supported_framerates == nullptr) { return 1; + } return 0; } @@ -127,7 +122,6 @@ Expect AVCodecSupportedFrameratesIter::body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t Idx, uint32_t NumPtr, uint32_t DenPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(NumId, MemInst, int32_t, NumPtr, "Failed when accessing the return NumPtr Memory"sv); @@ -157,21 +151,21 @@ AVCodecSupportedFrameratesIter::body(const Runtime::CallingFrame &Frame, Expect AVCodecSupportedSampleRatesIsNull::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); - if (AvCodec->supported_samplerates == nullptr) + if (AvCodec->supported_samplerates == nullptr) { return 1; + } return 0; } Expect AVCodecSupportedSampleRatesIter::body(const Runtime::CallingFrame &, uint32_t AvCodecId, uint32_t Idx) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); const int32_t *SampleRates = AvCodec->supported_samplerates; - if (SampleRates == nullptr) + if (SampleRates == nullptr) { return 0; + } uint32_t Curr = 0; while (Curr < Idx) { @@ -184,10 +178,10 @@ AVCodecSupportedSampleRatesIter::body(const Runtime::CallingFrame &, Expect AVCodecChannelLayoutIsNull::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); - if (AvCodec->channel_layouts == nullptr) + if (AvCodec->channel_layouts == nullptr) { return 1; + } return 0; } @@ -197,8 +191,9 @@ Expect AVCodecChannelLayoutIter::body(const Runtime::CallingFrame &, FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); const uint64_t *ChannelLayout = AvCodec->channel_layouts; - if (ChannelLayout == nullptr) + if (ChannelLayout == nullptr) { return 0; + } uint32_t Curr = 0; while (Curr < Idx) { @@ -211,10 +206,10 @@ Expect AVCodecChannelLayoutIter::body(const Runtime::CallingFrame &, Expect AVCodecSampleFmtsIsNull::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); - if (AvCodec->sample_fmts == nullptr) + if (AvCodec->sample_fmts == nullptr) { return 1; + } return 0; } @@ -223,8 +218,9 @@ Expect AVCodecSampleFmtsIter::body(const Runtime::CallingFrame &, FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); AVSampleFormat const *SampleFormat = AvCodec->sample_fmts; - if (SampleFormat == nullptr) + if (SampleFormat == nullptr) { return 0; + } uint32_t Curr = 0; while (Curr < Idx) { diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodec.h b/plugins/wasmedge_ffmpeg/avcodec/avCodec.h index da1b5fab..8954bd4e 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodec.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodec.h @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "avcodec_base.h" #include "runtime/callingframe.h" diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp index b766ae78..ad6de103 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avCodecContext.h" extern "C" { @@ -11,7 +14,6 @@ namespace AVcodec { Expect AVCodecCtxCodecID::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVCodecID const AvCodecId = AvCodecCtx->codec_id; return FFmpegUtils::CodecID::fromAVCodecID(AvCodecId); @@ -19,7 +21,6 @@ Expect AVCodecCtxCodecID::body(const Runtime::CallingFrame &, Expect AVCodecCtxCodecType::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVMediaType const AvMediaType = AvCodecCtx->codec_type; return FFmpegUtils::MediaType::fromMediaType(AvMediaType); @@ -28,7 +29,6 @@ Expect AVCodecCtxCodecType::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetCodecType::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t CodecTypeId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVMediaType const AvMediaType = FFmpegUtils::MediaType::intoMediaType(CodecTypeId); @@ -40,7 +40,6 @@ Expect AVCodecCtxSetCodecType::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetTimebase::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Num, int32_t Den) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVRational const Rational = av_make_q(Num, Den); AvCodecCtx->time_base = Rational; @@ -50,7 +49,6 @@ Expect AVCodecCtxSetTimebase::body(const Runtime::CallingFrame &, Expect AVCodecCtxTimeBase::body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t NumPtr, uint32_t DenPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, "Failed to access Numerator Ptr for AVRational"sv); @@ -66,14 +64,12 @@ Expect AVCodecCtxTimeBase::body(const Runtime::CallingFrame &Frame, Expect AVCodecCtxWidth::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->width; } Expect AVCodecCtxSetWidth::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Width) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->width = Width; return static_cast(ErrNo::Success); @@ -81,7 +77,6 @@ Expect AVCodecCtxSetWidth::body(const Runtime::CallingFrame &, Expect AVCodecCtxHeight::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->height; } @@ -89,7 +84,6 @@ Expect AVCodecCtxHeight::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetHeight::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Height) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->height = Height; return static_cast(ErrNo::Success); @@ -99,7 +93,6 @@ Expect AVCodecCtxSampleAspectRatio::body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t NumPtr, uint32_t DenPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, "Failed to access Numerator Ptr for AVRational"sv); @@ -117,7 +110,6 @@ Expect AVCodecCtxSetSampleAspectRatio::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Num, int32_t Den) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); const AVRational AspectRatio = av_make_q(Num, Den); AvCodecCtx->sample_aspect_ratio = AspectRatio; @@ -126,7 +118,6 @@ AVCodecCtxSetSampleAspectRatio::body(const Runtime::CallingFrame &, Expect AVCodecCtxChannelLayout::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); // Deprecated method uint64_t const AvChannel = AvCodecCtx->channel_layout; @@ -136,7 +127,6 @@ Expect AVCodecCtxChannelLayout::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetChannelLayout::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, uint64_t ChannelLayoutId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); uint64_t const AvChannel = FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); @@ -154,7 +144,6 @@ Expect AVCodecCtxPixFormat::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetPixFormat::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, uint32_t PixFmtId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVPixelFormat const PixFmt = FFmpegUtils::PixFmt::intoAVPixFmt(PixFmtId); AvCodecCtx->pix_fmt = PixFmt; @@ -163,7 +152,6 @@ Expect AVCodecCtxSetPixFormat::body(const Runtime::CallingFrame &, Expect AVCodecCtxSampleFormat::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVSampleFormat const AvSampleFormat = AvCodecCtx->sample_fmt; return FFmpegUtils::SampleFmt::toSampleID(AvSampleFormat); @@ -172,7 +160,6 @@ Expect AVCodecCtxSampleFormat::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetSampleFormat::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, uint32_t SampleFmtId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVSampleFormat const SampleFormat = FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); @@ -182,7 +169,6 @@ Expect AVCodecCtxSetSampleFormat::body(const Runtime::CallingFrame &, Expect AVCodecCtxSampleRate::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->sample_rate; } @@ -190,7 +176,6 @@ Expect AVCodecCtxSampleRate::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetSampleRate::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t SampleRate) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->sample_rate = SampleRate; return static_cast(ErrNo::Success); @@ -199,7 +184,6 @@ Expect AVCodecCtxSetSampleRate::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetGopSize::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t GopSize) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->gop_size = GopSize; return static_cast(ErrNo::Success); @@ -208,7 +192,6 @@ Expect AVCodecCtxSetGopSize::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetMaxBFrames::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t MaxBFrames) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->max_b_frames = MaxBFrames; return static_cast(ErrNo::Success); @@ -217,7 +200,6 @@ Expect AVCodecCtxSetMaxBFrames::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetBQuantFactor::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, float BQuantFactor) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->b_quant_factor = BQuantFactor; return static_cast(ErrNo::Success); @@ -226,7 +208,6 @@ Expect AVCodecCtxSetBQuantFactor::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetBQuantOffset::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, float BQuantOffset) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->b_quant_offset = BQuantOffset; return static_cast(ErrNo::Success); @@ -235,7 +216,6 @@ Expect AVCodecCtxSetBQuantOffset::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetIQuantFactor::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, float IQuantFactor) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->i_quant_factor = IQuantFactor; return static_cast(ErrNo::Success); @@ -244,7 +224,6 @@ Expect AVCodecCtxSetIQuantFactor::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetIQuantOffset::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, float IQuantOffset) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->i_quant_offset = IQuantOffset; return static_cast(ErrNo::Success); @@ -253,7 +232,6 @@ Expect AVCodecCtxSetIQuantOffset::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetLumiMasking::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, float LumiMasking) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->lumi_masking = LumiMasking; return static_cast(ErrNo::Success); @@ -263,7 +241,6 @@ Expect AVCodecCtxSetTemporalCplxMasking::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, float TemporalCplxMasking) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->temporal_cplx_masking = TemporalCplxMasking; return static_cast(ErrNo::Success); @@ -273,7 +250,6 @@ Expect AVCodecCtxSetSpatialCplxMasking::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, float SpatialCplxMasking) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->spatial_cplx_masking = SpatialCplxMasking; return static_cast(ErrNo::Success); @@ -282,7 +258,6 @@ AVCodecCtxSetSpatialCplxMasking::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetPMasking::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, float PMasking) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->p_masking = PMasking; return static_cast(ErrNo::Success); @@ -291,7 +266,6 @@ Expect AVCodecCtxSetPMasking::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetDarkMasking::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, float DarkMasking) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->dark_masking = DarkMasking; return static_cast(ErrNo::Success); @@ -299,7 +273,6 @@ Expect AVCodecCtxSetDarkMasking::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetMeCmp::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t MeCmp) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->me_cmp = MeCmp; return static_cast(ErrNo::Success); @@ -308,7 +281,6 @@ Expect AVCodecCtxSetMeCmp::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetMeSubCmp::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t MeSubCmp) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->me_sub_cmp = MeSubCmp; return static_cast(ErrNo::Success); @@ -316,7 +288,6 @@ Expect AVCodecCtxSetMeSubCmp::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetMbCmp::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t MbCmp) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->mb_cmp = MbCmp; return static_cast(ErrNo::Success); @@ -325,7 +296,6 @@ Expect AVCodecCtxSetMbCmp::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetIldctCmp::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t IldctCmp) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->ildct_cmp = IldctCmp; return static_cast(ErrNo::Success); @@ -334,7 +304,6 @@ Expect AVCodecCtxSetIldctCmp::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetDiaSize::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t DiaSize) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->dia_size = DiaSize; return static_cast(ErrNo::Success); @@ -344,7 +313,6 @@ Expect AVCodecCtxSetLastPredictorsCount::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t LastPredictorCount) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->last_predictor_count = LastPredictorCount; return static_cast(ErrNo::Success); @@ -353,7 +321,6 @@ AVCodecCtxSetLastPredictorsCount::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetMePreCmp::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t MePreCmp) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->me_pre_cmp = MePreCmp; return static_cast(ErrNo::Success); @@ -362,7 +329,6 @@ Expect AVCodecCtxSetMePreCmp::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetPreDiaSize::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t PreDiaSize) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->pre_dia_size = PreDiaSize; return static_cast(ErrNo::Success); @@ -372,7 +338,6 @@ Expect AVCodecCtxSetMeSubpelQuality::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t MeSubpelQuality) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->me_subpel_quality = MeSubpelQuality; return static_cast(ErrNo::Success); @@ -381,7 +346,6 @@ AVCodecCtxSetMeSubpelQuality::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetMeRange::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t MeRange) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->me_range = MeRange; return static_cast(ErrNo::Success); @@ -390,7 +354,6 @@ Expect AVCodecCtxSetMeRange::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetMbDecision::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t MbDecision) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->mb_decision = MbDecision; return static_cast(ErrNo::Success); @@ -399,7 +362,6 @@ Expect AVCodecCtxSetMbDecision::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetMbLMin::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t MbLMin) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->mb_lmin = MbLMin; return static_cast(ErrNo::Success); @@ -408,7 +370,6 @@ Expect AVCodecCtxSetMbLMin::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetMbLMax::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t MbLMax) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->mb_lmax = MbLMax; return static_cast(ErrNo::Success); @@ -416,7 +377,6 @@ Expect AVCodecCtxSetMbLMax::body(const Runtime::CallingFrame &, Expect AVCodecCtxIntraDcPrecision::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->intra_dc_precision; } @@ -425,7 +385,6 @@ Expect AVCodecCtxSetIntraDcPrecision::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t IntraDcPrecision) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->intra_dc_precision = IntraDcPrecision; return static_cast(ErrNo::Success); @@ -433,7 +392,6 @@ AVCodecCtxSetIntraDcPrecision::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetQMin::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t QMin) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->qmin = QMin; return static_cast(ErrNo::Success); @@ -441,7 +399,6 @@ Expect AVCodecCtxSetQMin::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetQMax::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t QMax) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->qmax = QMax; return static_cast(ErrNo::Success); @@ -450,7 +407,6 @@ Expect AVCodecCtxSetQMax::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetGlobalQuality::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t GlobalQuality) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->global_quality = GlobalQuality; return static_cast(ErrNo::Success); @@ -459,7 +415,6 @@ Expect AVCodecCtxSetGlobalQuality::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetColorspace::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t ColorspaceId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVColorSpace const ColorSpace = FFmpegUtils::ColorSpace::intoAVColorSpace(ColorspaceId); @@ -469,7 +424,6 @@ Expect AVCodecCtxSetColorspace::body(const Runtime::CallingFrame &, Expect AVCodecCtxColorspace::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVColorSpace const Colorspace = AvCodecCtx->colorspace; return FFmpegUtils::ColorSpace::fromAVColorSpace(Colorspace); @@ -478,7 +432,6 @@ Expect AVCodecCtxColorspace::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetColorRange::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t ColorRangeId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->color_range = static_cast(ColorRangeId); return static_cast(ErrNo::Success); @@ -486,7 +439,6 @@ Expect AVCodecCtxSetColorRange::body(const Runtime::CallingFrame &, Expect AVCodecCtxColorRange::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVColorRange const ColorRange = AvCodecCtx->color_range; return static_cast(ColorRange); @@ -494,14 +446,12 @@ Expect AVCodecCtxColorRange::body(const Runtime::CallingFrame &, Expect AVCodecCtxFrameSize::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->frame_size; } Expect AVCodecCtxBitRate::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->bit_rate; } @@ -509,7 +459,6 @@ Expect AVCodecCtxBitRate::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetBitRate::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int64_t BitRate) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->bit_rate = BitRate; return static_cast(ErrNo::Success); @@ -517,7 +466,6 @@ Expect AVCodecCtxSetBitRate::body(const Runtime::CallingFrame &, Expect AVCodecCtxRcMaxRate::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->rc_max_rate; } @@ -525,7 +473,6 @@ Expect AVCodecCtxRcMaxRate::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetRcMaxRate::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int64_t RcMaxRate) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->rc_max_rate = RcMaxRate; return static_cast(ErrNo::Success); @@ -535,7 +482,6 @@ Expect AVCodecCtxSetBitRateTolerance::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t BitRateTolerance) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->bit_rate_tolerance = BitRateTolerance; return static_cast(ErrNo::Success); @@ -545,7 +491,6 @@ Expect AVCodecCtxSetCompressionLevel::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t CompressionLevel) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->compression_level = CompressionLevel; return static_cast(ErrNo::Success); @@ -554,7 +499,6 @@ AVCodecCtxSetCompressionLevel::body(const Runtime::CallingFrame &, Expect AVCodecCtxFrameRate::body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t NumPtr, uint32_t DenPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, "Failed to access Numerator Ptr for AVRational"sv); @@ -572,7 +516,6 @@ Expect AVCodecCtxFrameRate::body(const Runtime::CallingFrame &Frame, Expect AVCodecCtxSetFrameRate::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Num, int32_t Den) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVRational const Rational = av_make_q(Num, Den); AvCodecCtx->framerate = Rational; @@ -581,7 +524,6 @@ Expect AVCodecCtxSetFrameRate::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetFlags::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Flags) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->flags = Flags; return static_cast(ErrNo::Success); @@ -591,7 +533,6 @@ Expect AVCodecCtxSetStrictStdCompliance::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t ComplianceId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->strict_std_compliance = ComplianceId; return static_cast(ErrNo::Success); @@ -599,7 +540,6 @@ AVCodecCtxSetStrictStdCompliance::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetDebug::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Debug) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->debug = Debug; return static_cast(ErrNo::Success); @@ -608,7 +548,6 @@ Expect AVCodecCtxSetDebug::body(const Runtime::CallingFrame &, Expect AVCodecCtxCodec::body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t AvCodecPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AvCodecPtr, "Failed to access Ptr for AvCodecPtr"sv); @@ -626,7 +565,6 @@ Expect AVCodecCtxCodec::body(const Runtime::CallingFrame &Frame, Expect AVCodecCtxChannels::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->channels; } @@ -634,7 +572,6 @@ Expect AVCodecCtxChannels::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetChannels::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Channels) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->channels = Channels; return static_cast(ErrNo::Success); @@ -643,7 +580,6 @@ Expect AVCodecCtxSetChannels::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetSkipLoopFilter::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t AVDiscardId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->skip_loop_filter = static_cast(AVDiscardId); return static_cast(ErrNo::Success); @@ -652,7 +588,6 @@ Expect AVCodecCtxSetSkipLoopFilter::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetSkipFrame::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t AVDiscardId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->skip_frame = static_cast(AVDiscardId); return static_cast(ErrNo::Success); @@ -661,7 +596,6 @@ Expect AVCodecCtxSetSkipFrame::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetSkipIdct::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t AVDiscardId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->skip_idct = static_cast(AVDiscardId); return static_cast(ErrNo::Success); @@ -671,7 +605,6 @@ Expect AVCodecCtxSetErrorConcealment::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t ErrorConcealment) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->error_concealment = ErrorConcealment; return static_cast(ErrNo::Success); @@ -681,7 +614,6 @@ Expect AVCodecCtxSetErrorRecognition::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t ErrRecognition) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->err_recognition = ErrRecognition; return static_cast(ErrNo::Success); @@ -689,7 +621,6 @@ AVCodecCtxSetErrorRecognition::body(const Runtime::CallingFrame &, Expect AVCodecCtxDelay::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->delay; } @@ -697,7 +628,6 @@ Expect AVCodecCtxDelay::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetSkipTop::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Value) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->skip_top = Value; return static_cast(ErrNo::Success); @@ -706,7 +636,6 @@ Expect AVCodecCtxSetSkipTop::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetSkipBottom::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Value) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->skip_bottom = Value; return static_cast(ErrNo::Success); @@ -714,7 +643,6 @@ Expect AVCodecCtxSetSkipBottom::body(const Runtime::CallingFrame &, Expect AVCodecCtxRefs::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->refs; } @@ -722,7 +650,6 @@ Expect AVCodecCtxRefs::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetSliceFlags::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Value) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->slice_flags = Value; return static_cast(ErrNo::Success); @@ -730,7 +657,6 @@ Expect AVCodecCtxSetSliceFlags::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetSliceCount::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Value) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->slice_count = Value; return static_cast(ErrNo::Success); @@ -739,7 +665,6 @@ Expect AVCodecCtxSetSliceCount::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetFieldOrder::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Value) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->field_order = static_cast(Value); return static_cast(ErrNo::Success); @@ -747,7 +672,6 @@ Expect AVCodecCtxSetFieldOrder::body(const Runtime::CallingFrame &, Expect AVCodecCtxColorTrc::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return static_cast(AvCodecCtx->color_trc); } @@ -755,7 +679,6 @@ Expect AVCodecCtxColorTrc::body(const Runtime::CallingFrame &, Expect AVCodecCtxChromaSampleLocation::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVChromaLocation const Chroma = AvCodecCtx->chroma_sample_location; return FFmpegUtils::ChromaLocation::fromAVChromaLocation(Chroma); @@ -763,14 +686,12 @@ AVCodecCtxChromaSampleLocation::body(const Runtime::CallingFrame &, Expect AVCodecCtxFrameNumber::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->frame_number; } Expect AVCodecCtxBlockAlign::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->block_align; } @@ -779,7 +700,6 @@ Expect AVCodecCtxSetRequestSampleFmt::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, uint32_t SampleFmtId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVSampleFormat const SampleFmt = FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); @@ -789,7 +709,6 @@ AVCodecCtxSetRequestSampleFmt::body(const Runtime::CallingFrame &, Expect AVCodecCtxAudioServiceType::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVAudioServiceType const AudioServiceType = AvCodecCtx->audio_service_type; return static_cast(AudioServiceType); @@ -797,7 +716,6 @@ Expect AVCodecCtxAudioServiceType::body(const Runtime::CallingFrame &, Expect AVCodecCtxHasBFrames::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->has_b_frames; } @@ -806,7 +724,6 @@ Expect AVCodecCtxSetRequestChannelLayout::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, uint64_t ChannelLayoutId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->request_channel_layout = FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); @@ -815,7 +732,6 @@ AVCodecCtxSetRequestChannelLayout::body(const Runtime::CallingFrame &, Expect AVCodecCtxActiveThreadType::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->active_thread_type; } @@ -823,7 +739,6 @@ Expect AVCodecCtxActiveThreadType::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetThreadType::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t ThreadType) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->thread_type = ThreadType; return static_cast(ErrNo::Success); @@ -831,7 +746,6 @@ Expect AVCodecCtxSetThreadType::body(const Runtime::CallingFrame &, Expect AVCodecCtxThreadCount::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return AvCodecCtx->thread_count; } @@ -839,7 +753,6 @@ Expect AVCodecCtxThreadCount::body(const Runtime::CallingFrame &, Expect AVCodecCtxSetThreadCount::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t ThreadCount) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AvCodecCtx->thread_count = ThreadCount; return static_cast(ErrNo::Success); @@ -847,7 +760,6 @@ Expect AVCodecCtxSetThreadCount::body(const Runtime::CallingFrame &, Expect AVCodecCtxColorPrimaries::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); AVColorPrimaries const ColorPrimaries = AvCodecCtx->color_primaries; return FFmpegUtils::ColorPrimaries::fromAVColorPrimaries(ColorPrimaries); diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h index 356924d8..ce443aec 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "avcodec_base.h" #include "runtime/callingframe.h" diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp index 9ffbd830..3724aa6e 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avCodecParameters.h" extern "C" { @@ -11,14 +14,12 @@ namespace AVcodec { Expect AVCodecParamCodecId::body(const Runtime::CallingFrame &, uint32_t AvCodecParamId) { - FFMPEG_PTR_FETCH(AvCodecParams, AvCodecParamId, AVCodecParameters); return FFmpegUtils::CodecID::fromAVCodecID(AvCodecParams->codec_id); } Expect AVCodecParamCodecType::body(const Runtime::CallingFrame &, uint32_t AvCodecParamId) { - FFMPEG_PTR_FETCH(AvCodecParams, AvCodecParamId, AVCodecParameters); return FFmpegUtils::MediaType::fromMediaType(AvCodecParams->codec_type); } @@ -26,7 +27,6 @@ Expect AVCodecParamCodecType::body(const Runtime::CallingFrame &, Expect AVCodecParamSetCodecTag::body(const Runtime::CallingFrame &, uint32_t AvCodecParamId, uint32_t CodecTag) { - FFMPEG_PTR_FETCH(AvCodecParams, AvCodecParamId, AVCodecParameters); AvCodecParams->codec_tag = CodecTag; return static_cast(ErrNo::Success); diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h index 9e37b7fa..5b3043d8 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "avcodec_base.h" #include "runtime/callingframe.h" diff --git a/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp b/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp index 892a0395..2510f978 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp +++ b/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avPacket.h" extern "C" { @@ -11,7 +14,6 @@ namespace AVcodec { Expect AVPacketAlloc::body(const Runtime::CallingFrame &Frame, uint32_t AvPacketPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(AvPacketId, MemInst, uint32_t, AvPacketPtr, "Failed when accessing the return AVCodecContext Memory"sv); @@ -24,14 +26,12 @@ Expect AVPacketAlloc::body(const Runtime::CallingFrame &Frame, Expect AVNewPacket::body(const Runtime::CallingFrame &, uint32_t AvPacketId, int32_t Size) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); return av_new_packet(AvPacket, Size); } Expect AVPacketRef::body(const Runtime::CallingFrame &, uint32_t DestPacketId, uint32_t SrcPacketId) { - FFMPEG_PTR_FETCH(DestAvPacket, DestPacketId, AVPacket); FFMPEG_PTR_FETCH(SrcAvPacket, SrcPacketId, AVPacket); @@ -40,7 +40,6 @@ Expect AVPacketRef::body(const Runtime::CallingFrame &, Expect AVPacketUnref::body(const Runtime::CallingFrame &, uint32_t AvPacketId) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); // Free packet. av_packet_unref(AvPacket); FFMPEG_PTR_DELETE(AvPacketId); @@ -49,14 +48,12 @@ Expect AVPacketUnref::body(const Runtime::CallingFrame &, Expect AVGrowPacket::body(const Runtime::CallingFrame &, uint32_t AvPacketId, int32_t Size) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); return av_grow_packet(AvPacket, Size); } Expect AVShrinkPacket::body(const Runtime::CallingFrame &, uint32_t AvPacketId, int32_t Size) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); av_shrink_packet(AvPacket, Size); return static_cast(ErrNo::Success); @@ -64,7 +61,6 @@ Expect AVShrinkPacket::body(const Runtime::CallingFrame &, Expect AVPacketStreamIndex::body(const Runtime::CallingFrame &, uint32_t AvPacketId) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); return AvPacket->stream_index; } @@ -72,7 +68,6 @@ Expect AVPacketStreamIndex::body(const Runtime::CallingFrame &, Expect AVPacketSetStreamIndex::body(const Runtime::CallingFrame &, uint32_t AvPacketId, int32_t StreamIdx) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); AvPacket->stream_index = StreamIdx; return static_cast(ErrNo::Success); @@ -80,21 +75,18 @@ Expect AVPacketSetStreamIndex::body(const Runtime::CallingFrame &, Expect AVPacketSize::body(const Runtime::CallingFrame &, uint32_t AvPacketId) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); return AvPacket->size; } Expect AVPacketFlags::body(const Runtime::CallingFrame &, uint32_t AvPacketId) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); return AvPacket->flags; } Expect AVPacketSetFlags::body(const Runtime::CallingFrame &, uint32_t AvPacketId, int32_t Flags) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); AvPacket->flags = Flags; return static_cast(ErrNo::Success); @@ -102,14 +94,12 @@ Expect AVPacketSetFlags::body(const Runtime::CallingFrame &, Expect AVPacketPos::body(const Runtime::CallingFrame &, uint32_t AvPacketId) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); return AvPacket->pos; } Expect AVPacketSetPos::body(const Runtime::CallingFrame &, uint32_t AvPacketId, int64_t Pos) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); AvPacket->pos = Pos; return static_cast(ErrNo::Success); @@ -117,7 +107,6 @@ Expect AVPacketSetPos::body(const Runtime::CallingFrame &, Expect AVPacketDuration::body(const Runtime::CallingFrame &, uint32_t AvPacketId) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); return AvPacket->duration; } @@ -125,7 +114,6 @@ Expect AVPacketDuration::body(const Runtime::CallingFrame &, Expect AVPacketSetDuration::body(const Runtime::CallingFrame &, uint32_t AvPacketId, int64_t Duration) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); AvPacket->duration = Duration; return static_cast(ErrNo::Success); @@ -133,14 +121,12 @@ Expect AVPacketSetDuration::body(const Runtime::CallingFrame &, Expect AVPacketDts::body(const Runtime::CallingFrame &, uint32_t AvPacketId) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); return AvPacket->dts; } Expect AVPacketSetDts::body(const Runtime::CallingFrame &, uint32_t AvPacketId, int64_t Dts) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); AvPacket->dts = Dts; return static_cast(ErrNo::Success); @@ -148,14 +134,12 @@ Expect AVPacketSetDts::body(const Runtime::CallingFrame &, Expect AVPacketPts::body(const Runtime::CallingFrame &, uint32_t AvPacketId) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); return AvPacket->pts; } Expect AVPacketSetPts::body(const Runtime::CallingFrame &, uint32_t AvPacketId, int64_t Pts) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); AvPacket->pts = Pts; return static_cast(ErrNo::Success); @@ -163,7 +147,6 @@ Expect AVPacketSetPts::body(const Runtime::CallingFrame &, Expect AVPacketIsDataNull::body(const Runtime::CallingFrame &, uint32_t AvPacketId) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); if (AvPacket->data == nullptr) return 1; @@ -173,7 +156,6 @@ Expect AVPacketIsDataNull::body(const Runtime::CallingFrame &, Expect AVPacketData::body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, uint32_t DataPtr, uint32_t DataLen) { - MEMINST_CHECK(MemInst, Frame, 0) MEM_SPAN_CHECK(Buffer, MemInst, uint8_t, DataPtr, DataLen, ""); diff --git a/plugins/wasmedge_ffmpeg/avcodec/avPacket.h b/plugins/wasmedge_ffmpeg/avcodec/avPacket.h index 5ed513a8..c54eee8c 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avPacket.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avPacket.h @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "avcodec_base.h" #include "runtime/callingframe.h" diff --git a/plugins/wasmedge_ffmpeg/avcodec/avcodec_base.h b/plugins/wasmedge_ffmpeg/avcodec/avcodec_base.h index e3365858..126f43b0 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avcodec_base.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avcodec_base.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp index 91c0b9de..5158f4b3 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp +++ b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avcodec_func.h" extern "C" { @@ -13,7 +16,6 @@ namespace AVcodec { Expect AVCodecAllocContext3::body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t AvCodecCtxPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(AvCodecCtxId, MemInst, uint32_t, AvCodecCtxPtr, "Failed when accessing the return AVCodecContext Memory"sv); @@ -28,7 +30,6 @@ Expect AVCodecParametersFromContext::body(const Runtime::CallingFrame &, uint32_t AvCodecParamId, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecParam, AvCodecParamId, AVCodecParameters); FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); return avcodec_parameters_from_context(AvCodecParam, AvCodecCtx); @@ -36,7 +37,6 @@ AVCodecParametersFromContext::body(const Runtime::CallingFrame &, Expect AVCodecParametersFree::body(const Runtime::CallingFrame &, uint32_t AvCodecParamId) { - FFMPEG_PTR_FETCH(AvCodecParam, AvCodecParamId, AVCodecParameters); avcodec_parameters_free(&AvCodecParam); @@ -46,7 +46,6 @@ Expect AVCodecParametersFree::body(const Runtime::CallingFrame &, Expect AVCodecFreeContext::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); avcodec_free_context(&AvCodecCtx); @@ -68,7 +67,6 @@ Expect AVCodecParametersAlloc::body(const Runtime::CallingFrame &Frame, Expect AVCodecGetType::body(const Runtime::CallingFrame &, uint32_t AvCodecIdIndex) { - AVCodecID const AvCodecId = FFmpegUtils::CodecID::intoAVCodecID(AvCodecIdIndex); AVMediaType const MediaType = avcodec_get_type(AvCodecId); @@ -78,7 +76,6 @@ Expect AVCodecGetType::body(const Runtime::CallingFrame &, Expect AVCodecOpen2::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, uint32_t AvCodecId, uint32_t AvDictionaryId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); FFMPEG_PTR_FETCH(AvDictionary, AvDictionaryId, AVDictionary *); FFMPEG_PTR_FETCH(AvCodec, AvCodecId, AVCodec); @@ -87,7 +84,6 @@ Expect AVCodecOpen2::body(const Runtime::CallingFrame &, Expect AVCodecFindDecoder::body(const Runtime::CallingFrame &Frame, uint32_t ID, uint32_t AvCodecPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AvCodecPtr, "Failed when accessing the return AVCodec Memory"sv); @@ -108,21 +104,18 @@ Expect AVCodecFindDecoder::body(const Runtime::CallingFrame &Frame, Expect AVCodecIsEncoder::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); return av_codec_is_encoder(AvCodec); } Expect AVCodecIsDecoder::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); return av_codec_is_decoder(AvCodec); } Expect AVCodecClose::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); int Res = avcodec_close(AvCodecCtx); FFMPEG_PTR_DELETE(AvCodecCtxId); @@ -132,7 +125,6 @@ Expect AVCodecClose::body(const Runtime::CallingFrame &, Expect AVCodecParametersToContext::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, uint32_t AvCodecParamId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); FFMPEG_PTR_FETCH(AvCodecParam, AvCodecParamId, AVCodecParameters); @@ -142,7 +134,6 @@ Expect AVCodecParametersToContext::body(const Runtime::CallingFrame &, Expect AVCodecReceiveFrame::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return avcodec_receive_frame(AvCodecCtx, AvFrame); @@ -151,7 +142,6 @@ Expect AVCodecReceiveFrame::body(const Runtime::CallingFrame &, Expect AVCodecSendPacket::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, uint32_t PacketId) { - FFMPEG_PTR_FETCH(AVCodecCtx, AvCodecCtxId, AVCodecContext); FFMPEG_PTR_FETCH(AvPacket, PacketId, AVPacket); // Can send Null AVPacket, to close the stream. @@ -160,7 +150,6 @@ Expect AVCodecSendPacket::body(const Runtime::CallingFrame &, Expect AVCodecFindEncoder::body(const Runtime::CallingFrame &Frame, uint32_t ID, uint32_t AVCodecPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AVCodecPtr, "Failed when accessing the return AVCodec Memory"sv); @@ -182,7 +171,6 @@ Expect AVCodecFindEncoder::body(const Runtime::CallingFrame &Frame, Expect AVCodecReceivePacket::body(const Runtime::CallingFrame &, uint32_t AVCodecCtxId, uint32_t PacketId) { - FFMPEG_PTR_FETCH(AVCodecCtx, AVCodecCtxId, AVCodecContext); FFMPEG_PTR_FETCH(AvPacket, PacketId, AVPacket); return avcodec_receive_packet(AVCodecCtx, AvPacket); @@ -191,7 +179,6 @@ Expect AVCodecReceivePacket::body(const Runtime::CallingFrame &, Expect AVCodecSendFrame::body(const Runtime::CallingFrame &, uint32_t AVCodecCtxId, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AVCodecCtx, AVCodecCtxId, AVCodecContext); FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return avcodec_send_frame(AVCodecCtx, AvFrame); @@ -201,7 +188,6 @@ Expect AVCodecFindDecoderByName::body(const Runtime::CallingFrame &Frame, uint32_t AVCodecPtr, uint32_t NamePtr, uint32_t NameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AVCodecPtr, "Failed when accessing the return AVCodec Memory"sv); @@ -226,7 +212,6 @@ Expect AVCodecFindEncoderByName::body(const Runtime::CallingFrame &Frame, uint32_t AVCodecPtr, uint32_t NamePtr, uint32_t NameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(AVCodecId, MemInst, uint32_t, AVCodecPtr, "Failed when accessing the return AVCodec Memory"sv); @@ -251,7 +236,6 @@ Expect AVPacketRescaleTs::body(const Runtime::CallingFrame &, uint32_t AvPacketId, int32_t SrcNum, int32_t SrcDen, int32_t DestNum, int32_t DestDen) { - FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); AVRational const Src = av_make_q(SrcNum, SrcDen); AVRational const Dest = av_make_q(DestNum, DestDen); @@ -262,7 +246,6 @@ Expect AVPacketRescaleTs::body(const Runtime::CallingFrame &, Expect AVPacketMakeWritable::body(const Runtime::CallingFrame &, uint32_t AVPacketId) { - FFMPEG_PTR_FETCH(AvPacket, AVPacketId, AVPacket); return av_packet_make_writable(AvPacket); } @@ -271,7 +254,6 @@ Expect AVCodecParametersCopy::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t AVCodecParamId, uint32_t StreamIdx) { - FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); FFMPEG_PTR_FETCH(AvCodecParam, AVCodecParamId, AVCodecParameters); @@ -291,7 +273,6 @@ Expect AVCodecVersion::body(const Runtime::CallingFrame &) { Expect AVCodecFlushBuffers::body(const Runtime::CallingFrame &, uint32_t AVCodecCtxId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AVCodecCtxId, AVCodecContext); avcodec_flush_buffers(AvCodecCtx); return static_cast(ErrNo::Success); @@ -306,7 +287,6 @@ AVCodecConfigurationLength::body(const Runtime::CallingFrame &) { Expect AVCodecConfiguration::body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); @@ -316,14 +296,12 @@ Expect AVCodecConfiguration::body(const Runtime::CallingFrame &Frame, } Expect AVCodecLicenseLength::body(const Runtime::CallingFrame &) { - const char *License = avcodec_license(); return strlen(License); } Expect AVCodecLicense::body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); diff --git a/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h index 3aabf3e1..f7ea4d24 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "avcodec_base.h" #include "runtime/callingframe.h" diff --git a/plugins/wasmedge_ffmpeg/avcodec/module.cpp b/plugins/wasmedge_ffmpeg/avcodec/module.cpp index 1b904755..5f082ba8 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/module.cpp +++ b/plugins/wasmedge_ffmpeg/avcodec/module.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "module.h" #include "avCodec.h" #include "avCodecContext.h" @@ -13,7 +16,6 @@ namespace AVcodec { WasmEdgeFFmpegAVCodecModule::WasmEdgeFFmpegAVCodecModule( std::shared_ptr Env) : ModuleInstance("wasmedge_ffmpeg_avcodec") { - // avcodec_func.h addHostFunc("wasmedge_ffmpeg_avcodec_avcodec_alloc_context3", std::make_unique(Env)); diff --git a/plugins/wasmedge_ffmpeg/avcodec/module.h b/plugins/wasmedge_ffmpeg/avcodec/module.h index 7d3776d6..7bb8ff67 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/module.h +++ b/plugins/wasmedge_ffmpeg/avcodec/module.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/avdevice/avDevice_base.h b/plugins/wasmedge_ffmpeg/avdevice/avDevice_base.h index 06096dd4..ae7c37a1 100644 --- a/plugins/wasmedge_ffmpeg/avdevice/avDevice_base.h +++ b/plugins/wasmedge_ffmpeg/avdevice/avDevice_base.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp index 9352680c..0e19d8aa 100644 --- a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp +++ b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avDevice_func.h" extern "C" { @@ -60,9 +63,7 @@ Expect AVOutputVideoDeviceNext::body(const Runtime::CallingFrame &) { Expect AVDeviceFreeListDevices::body(const Runtime::CallingFrame &, uint32_t AVDeviceInfoListId) { - FFMPEG_PTR_FETCH(AvDeviceInfoList, AVDeviceInfoListId, AVDeviceInfoList *); - avdevice_free_list_devices(AvDeviceInfoList); FFMPEG_PTR_DELETE(AVDeviceInfoListId); return static_cast(ErrNo::Success); @@ -70,14 +71,12 @@ Expect AVDeviceFreeListDevices::body(const Runtime::CallingFrame &, Expect AVDeviceNbDevices::body(const Runtime::CallingFrame &, uint32_t AVDeviceInfoListId) { - FFMPEG_PTR_FETCH(AvDeviceInfoList, AVDeviceInfoListId, AVDeviceInfoList *); return (*AvDeviceInfoList)->nb_devices; } Expect AVDeviceDefaultDevice::body(const Runtime::CallingFrame &, uint32_t AVDeviceInfoListId) { - FFMPEG_PTR_FETCH(AvDeviceInfoList, AVDeviceInfoListId, AVDeviceInfoList *); return (*AvDeviceInfoList)->default_device; } @@ -91,7 +90,6 @@ AVDeviceConfigurationLength::body(const Runtime::CallingFrame &) { Expect AVDeviceConfiguration::body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); @@ -101,7 +99,6 @@ Expect AVDeviceConfiguration::body(const Runtime::CallingFrame &Frame, } Expect AVDeviceLicenseLength::body(const Runtime::CallingFrame &) { - const char *License = avdevice_license(); return strlen(License); } @@ -109,7 +106,6 @@ Expect AVDeviceLicenseLength::body(const Runtime::CallingFrame &) { Expect AVDeviceLicense::body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); diff --git a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h index 92a1309b..cccfa303 100644 --- a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h +++ b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "avDevice_base.h" diff --git a/plugins/wasmedge_ffmpeg/avdevice/module.cpp b/plugins/wasmedge_ffmpeg/avdevice/module.cpp index a47312d5..0e58d788 100644 --- a/plugins/wasmedge_ffmpeg/avdevice/module.cpp +++ b/plugins/wasmedge_ffmpeg/avdevice/module.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "module.h" #include "avDevice_func.h" @@ -9,7 +12,6 @@ namespace AVDevice { WasmEdgeFFmpegAVDeviceModule::WasmEdgeFFmpegAVDeviceModule( std::shared_ptr Env) : ModuleInstance("wasmedge_ffmpeg_avdevice") { - addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_register_all", std::make_unique(Env)); addHostFunc("wasmedge_ffmpeg_avdevice_avdevice_version", diff --git a/plugins/wasmedge_ffmpeg/avdevice/module.h b/plugins/wasmedge_ffmpeg/avdevice/module.h index 57e291eb..748d4cae 100644 --- a/plugins/wasmedge_ffmpeg/avdevice/module.h +++ b/plugins/wasmedge_ffmpeg/avdevice/module.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp b/plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp index 2d7fcf64..9020b759 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp +++ b/plugins/wasmedge_ffmpeg/avfilter/avFilter.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avFilter.h" extern "C" { @@ -18,7 +21,6 @@ Expect AVFilterNameLength::body(const Runtime::CallingFrame &, Expect AVFilterName::body(const Runtime::CallingFrame &Frame, uint32_t FilterId, uint32_t NamePtr, uint32_t NameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); @@ -37,7 +39,6 @@ Expect AVFilterDescriptionLength::body(const Runtime::CallingFrame &, Expect AVFilterDescription::body(const Runtime::CallingFrame &Frame, uint32_t FilterId, uint32_t DescPtr, uint32_t DescLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(DescBuf, MemInst, char, DescPtr, DescLen, ""); @@ -61,7 +62,6 @@ Expect AVFilterNbOutputs::body(const Runtime::CallingFrame &, Expect AVFilterFlags::body(const Runtime::CallingFrame &, uint32_t FilterId) { - FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); return Filter->flags; } @@ -69,7 +69,6 @@ Expect AVFilterFlags::body(const Runtime::CallingFrame &, Expect AVFilterInOutSetName::body(const Runtime::CallingFrame &Frame, uint32_t InOutId, uint32_t NamePtr, uint32_t NameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); @@ -78,8 +77,9 @@ Expect AVFilterInOutSetName::body(const Runtime::CallingFrame &Frame, std::string Name; std::copy_n(NameBuf.data(), NameLen, std::back_inserter(Name)); char *CName = av_strdup(Name.c_str()); - if (CName == nullptr) + if (CName == nullptr) { return static_cast(ErrNo::Success); + } InOut->name = CName; return static_cast(ErrNo::Success); } @@ -87,7 +87,6 @@ Expect AVFilterInOutSetName::body(const Runtime::CallingFrame &Frame, Expect AVFilterInOutSetFilterCtx::body(const Runtime::CallingFrame &, uint32_t InOutId, uint32_t FilterCtxId) { - FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); FFMPEG_PTR_FETCH(FilterCtx, FilterCtxId, AVFilterContext); @@ -97,7 +96,6 @@ Expect AVFilterInOutSetFilterCtx::body(const Runtime::CallingFrame &, Expect AVFilterInOutSetPadIdx::body(const Runtime::CallingFrame &, uint32_t InOutId, int32_t PadIdx) { - FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); InOut->pad_idx = PadIdx; return static_cast(ErrNo::Success); @@ -106,7 +104,6 @@ Expect AVFilterInOutSetPadIdx::body(const Runtime::CallingFrame &, Expect AVFilterInOutSetNext::body(const Runtime::CallingFrame &, uint32_t InOutId, uint32_t NextInOutId) { - FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); FFMPEG_PTR_FETCH(NextInOut, NextInOutId, AVFilterInOut); InOut->next = NextInOut; @@ -116,14 +113,14 @@ Expect AVFilterInOutSetNext::body(const Runtime::CallingFrame &, Expect AVFilterGetInputsFilterPad::body(const Runtime::CallingFrame &Frame, uint32_t FilterId, uint32_t FilterPadPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(FilterPadId, MemInst, uint32_t, FilterPadPtr, "") FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); const AVFilterPad *FilterPad = Filter->inputs; - if (FilterPad == nullptr) + if (FilterPad == nullptr) { return static_cast(ErrNo::Success); + } FFMPEG_PTR_STORE(const_cast(FilterPad), FilterPadId); return static_cast(ErrNo::Success); } @@ -131,14 +128,14 @@ AVFilterGetInputsFilterPad::body(const Runtime::CallingFrame &Frame, Expect AVFilterGetOutputsFilterPad::body(const Runtime::CallingFrame &Frame, uint32_t FilterId, uint32_t FilterPadPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(FilterPadId, MemInst, uint32_t, FilterPadPtr, "") FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); const AVFilterPad *FilterPad = Filter->outputs; - if (FilterPad == nullptr) + if (FilterPad == nullptr) { return static_cast(ErrNo::Success); + } FFMPEG_PTR_STORE(const_cast(FilterPad), FilterPadId); return static_cast(ErrNo::Success); } diff --git a/plugins/wasmedge_ffmpeg/avfilter/avFilter.h b/plugins/wasmedge_ffmpeg/avfilter/avFilter.h index 1dbf650f..58129a15 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/avFilter.h +++ b/plugins/wasmedge_ffmpeg/avfilter/avFilter.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "avfilter_base.h" diff --git a/plugins/wasmedge_ffmpeg/avfilter/avfilter_base.h b/plugins/wasmedge_ffmpeg/avfilter/avfilter_base.h index 7ed333d9..64209977 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/avfilter_base.h +++ b/plugins/wasmedge_ffmpeg/avfilter/avfilter_base.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp index 974a1b37..73ad99ca 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp +++ b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avfilter_func.h" extern "C" { @@ -11,22 +14,21 @@ namespace AVFilter { Expect AVFilterGraphAlloc::body(const Runtime::CallingFrame &Frame, uint32_t FilterGraphPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(FilterGraphId, MemInst, uint32_t, FilterGraphPtr, "") FFMPEG_PTR_FETCH(FilterGraph, *FilterGraphId, AVFilterGraph); FilterGraph = avfilter_graph_alloc(); - if (FilterGraph == nullptr) + if (FilterGraph == nullptr) { return static_cast(ErrNo::Success); + } FFMPEG_PTR_STORE(FilterGraph, FilterGraphId); return static_cast(ErrNo::Success); } Expect AVFilterGraphConfig::body(const Runtime::CallingFrame &, uint32_t FilterGraphId) { - FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); return avfilter_graph_config(FilterGraph, nullptr); // log_ctx always NULL on Rust SDK. @@ -34,7 +36,6 @@ Expect AVFilterGraphConfig::body(const Runtime::CallingFrame &, Expect AVFilterGraphFree::body(const Runtime::CallingFrame &, uint32_t FilterGraphId) { - FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); avfilter_graph_free(&FilterGraph); FFMPEG_PTR_DELETE(FilterGraphId); @@ -46,7 +47,6 @@ Expect AVFilterGraphGetFilter::body(const Runtime::CallingFrame &Frame, uint32_t FilterGraphId, uint32_t NamePtr, uint32_t NameSize) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(NameId, MemInst, char, NamePtr, "Failed when accessing the return Name memory"sv); @@ -59,8 +59,9 @@ Expect AVFilterGraphGetFilter::body(const Runtime::CallingFrame &Frame, std::copy_n(NameId, NameSize, std::back_inserter(Name)); FilterCtx = avfilter_graph_get_filter(FilterGraph, Name.c_str()); - if (FilterCtx == nullptr) + if (FilterCtx == nullptr) { return static_cast(ErrNo::Success); + } FFMPEG_PTR_STORE(FilterCtx, FilterCtxId); return static_cast(ErrNo::Success); } @@ -71,7 +72,6 @@ Expect AVFilterGraphParsePtr::body(const Runtime::CallingFrame &Frame, uint32_t FiltersSize, uint32_t InputsId, uint32_t OutputsId) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(FiltersId, MemInst, char, FiltersString, ""); @@ -87,7 +87,6 @@ Expect AVFilterGraphParsePtr::body(const Runtime::CallingFrame &Frame, Expect AVFilterInOutFree::body(const Runtime::CallingFrame &, uint32_t InOutId) { - FFMPEG_PTR_FETCH(InOut, InOutId, AVFilterInOut); avfilter_inout_free(&InOut); FFMPEG_PTR_DELETE(InOutId); @@ -101,7 +100,6 @@ Expect AVFilterVersion::body(const Runtime::CallingFrame &) { Expect AVFilterGetByName::body(const Runtime::CallingFrame &Frame, uint32_t FilterPtr, uint32_t StrPtr, uint32_t StrLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(StrId, MemInst, char, StrPtr, "Failed when accessing the return Str memory"sv); @@ -113,8 +111,9 @@ Expect AVFilterGetByName::body(const Runtime::CallingFrame &Frame, std::copy_n(StrId, StrLen, std::back_inserter(Name)); Filter = avfilter_get_by_name(Name.c_str()); - if (Filter == nullptr) + if (Filter == nullptr) { return static_cast(ErrNo::Success); + } FFMPEG_PTR_STORE(const_cast(Filter), FilterId); return static_cast(ErrNo::Success); @@ -122,7 +121,6 @@ Expect AVFilterGetByName::body(const Runtime::CallingFrame &Frame, Expect AVFilterConfigurationLength::body(const Runtime::CallingFrame &) { - const char *Config = avfilter_configuration(); return strlen(Config); } @@ -130,7 +128,6 @@ AVFilterConfigurationLength::body(const Runtime::CallingFrame &) { Expect AVFilterConfiguration::body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); @@ -140,7 +137,6 @@ Expect AVFilterConfiguration::body(const Runtime::CallingFrame &Frame, } Expect AVFilterLicenseLength::body(const Runtime::CallingFrame &) { - const char *License = avfilter_license(); return strlen(License); } @@ -148,7 +144,6 @@ Expect AVFilterLicenseLength::body(const Runtime::CallingFrame &) { Expect AVFilterLicense::body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); @@ -161,7 +156,6 @@ Expect AVFilterGraphCreateFilter::body( const Runtime::CallingFrame &Frame, uint32_t FilterCtxPtr, uint32_t FilterId, uint32_t NamePtr, uint32_t NameLen, uint32_t ArgsPtr, uint32_t ArgsLen, uint32_t FilterGraphId) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); MEM_SPAN_CHECK(ArgsBuf, MemInst, char, ArgsPtr, ArgsLen, ""); @@ -178,8 +172,9 @@ Expect AVFilterGraphCreateFilter::body( int Res = avfilter_graph_create_filter(&FilterCtx, Filter, Name.c_str(), Args.c_str(), nullptr, FilterGraph); - if (Res < 0) + if (Res < 0) { return Res; + } FFMPEG_PTR_STORE(FilterCtx, FilterCtxId); return Res; @@ -187,14 +182,14 @@ Expect AVFilterGraphCreateFilter::body( Expect AVFilterInOutAlloc::body(const Runtime::CallingFrame &Frame, uint32_t InOutPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(InOutId, MemInst, uint32_t, InOutPtr, "") FFMPEG_PTR_FETCH(InOut, *InOutId, AVFilterInOut); InOut = avfilter_inout_alloc(); - if (InOut == nullptr) + if (InOut == nullptr) { return static_cast(ErrNo::Success); + } FFMPEG_PTR_STORE(InOut, InOutId); return static_cast(ErrNo::Success); } @@ -202,7 +197,6 @@ Expect AVFilterInOutAlloc::body(const Runtime::CallingFrame &Frame, Expect AVFilterPadGetNameLength::body(const Runtime::CallingFrame &, uint32_t FilterPadId, int32_t Idx) { - FFMPEG_PTR_FETCH(FilterPad, FilterPadId, AVFilterPad); const char *Name = avfilter_pad_get_name(FilterPad, Idx); @@ -212,7 +206,6 @@ Expect AVFilterPadGetNameLength::body(const Runtime::CallingFrame &, Expect AVFilterPadGetName::body(const Runtime::CallingFrame &Frame, uint32_t FilterPadId, int32_t Idx, uint32_t NamePtr, uint32_t NameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); @@ -225,7 +218,6 @@ Expect AVFilterPadGetName::body(const Runtime::CallingFrame &Frame, Expect AVFilterPadGetType::body(const Runtime::CallingFrame &, uint32_t FilterPadId, int32_t Idx) { - FFMPEG_PTR_FETCH(FilterPad, FilterPadId, AVFilterPad); AVMediaType const MediaType = avfilter_pad_get_type(FilterPad, Idx); return FFmpegUtils::MediaType::fromMediaType(MediaType); @@ -233,7 +225,6 @@ Expect AVFilterPadGetType::body(const Runtime::CallingFrame &, Expect AVFilterGraphDumpLength::body(const Runtime::CallingFrame &, uint32_t FilterGraphId) { - FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); char *Graph = avfilter_graph_dump(FilterGraph, nullptr); return strlen(Graph); @@ -243,7 +234,6 @@ Expect AVFilterGraphDump::body(const Runtime::CallingFrame &Frame, uint32_t FilterGraphId, uint32_t GraphStrPtr, uint32_t GraphStrLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(GraphStr, MemInst, char, GraphStrPtr, GraphStrLen, ""); @@ -256,7 +246,6 @@ Expect AVFilterGraphDump::body(const Runtime::CallingFrame &Frame, Expect AVFilterFreeGraphStr::body(const Runtime::CallingFrame &, uint32_t FilterGraphId) { - FFMPEG_PTR_FETCH(FilterGraph, FilterGraphId, AVFilterGraph); char *Graph = avfilter_graph_dump(FilterGraph, nullptr); @@ -266,30 +255,30 @@ Expect AVFilterFreeGraphStr::body(const Runtime::CallingFrame &, Expect AVFilterDrop::body(const Runtime::CallingFrame &, uint32_t FilterId) { - FFMPEG_PTR_FETCH(Filter, FilterId, struct AVFilter); - if (Filter == nullptr) + if (Filter == nullptr) { return static_cast(ErrNo::Success); + } FFMPEG_PTR_DELETE(FilterId); return static_cast(ErrNo::Success); } Expect AVFilterPadDrop::body(const Runtime::CallingFrame &, uint32_t FilterPadId) { - FFMPEG_PTR_FETCH(FilterPad, FilterPadId, AVFilterPad); - if (FilterPad == nullptr) + if (FilterPad == nullptr) { return static_cast(ErrNo::Success); + } FFMPEG_PTR_DELETE(FilterPadId); return static_cast(ErrNo::Success); } Expect AVFilterContextDrop::body(const Runtime::CallingFrame &, uint32_t FilterCtxId) { - FFMPEG_PTR_FETCH(FilterCtx, FilterCtxId, AVFilterContext); - if (FilterCtx == nullptr) + if (FilterCtx == nullptr) { return static_cast(ErrNo::Success); + } FFMPEG_PTR_DELETE(FilterCtxId); return static_cast(ErrNo::Success); } diff --git a/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h index b96914e2..18f037dc 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h +++ b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "avfilter_base.h" diff --git a/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.cpp b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.cpp index 0aed3697..5753477d 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.cpp +++ b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.cpp @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "buffer_source_sink.h" + extern "C" { #include "libavfilter/buffersink.h" #include "libavfilter/buffersrc.h" @@ -12,7 +16,6 @@ namespace AVFilter { Expect AVBufferSinkGetFrame::body(const Runtime::CallingFrame &, uint32_t FilterContextId, uint32_t FrameId) { - FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); FFMPEG_PTR_FETCH(Frame, FrameId, AVFrame); return av_buffersink_get_frame(FilterCtx, Frame); @@ -22,7 +25,6 @@ Expect AVBufferSinkGetSamples::body(const Runtime::CallingFrame &, uint32_t FilterContextId, uint32_t FrameId, int32_t Samples) { - FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); FFMPEG_PTR_FETCH(Frame, FrameId, AVFrame); return av_buffersink_get_samples(FilterCtx, Frame, Samples); @@ -31,7 +33,6 @@ Expect AVBufferSinkGetSamples::body(const Runtime::CallingFrame &, Expect AvBufferSinkSetFrameSize::body(const Runtime::CallingFrame &, uint32_t FilterContextId, int32_t Value) { - FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); av_buffersink_set_frame_size(FilterCtx, Value); return static_cast(ErrNo::Success); @@ -40,7 +41,6 @@ Expect AvBufferSinkSetFrameSize::body(const Runtime::CallingFrame &, Expect AVBufferSrcGetNbFailedRequests::body(const Runtime::CallingFrame &, uint32_t FilterContextId) { - FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); return av_buffersrc_get_nb_failed_requests(FilterCtx); } @@ -48,7 +48,6 @@ AVBufferSrcGetNbFailedRequests::body(const Runtime::CallingFrame &, Expect AVBufferSrcAddFrame::body(const Runtime::CallingFrame &, uint32_t FilterContextId, uint32_t FrameId) { - FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); FFMPEG_PTR_FETCH(Frame, FrameId, AVFrame); return av_buffersrc_add_frame(FilterCtx, Frame); @@ -57,7 +56,6 @@ Expect AVBufferSrcAddFrame::body(const Runtime::CallingFrame &, Expect AVBufferSrcClose::body(const Runtime::CallingFrame &, uint32_t FilterContextId, int64_t Pts, uint32_t Flags) { - FFMPEG_PTR_FETCH(FilterCtx, FilterContextId, AVFilterContext); return av_buffersrc_close(FilterCtx, Pts, Flags); } diff --git a/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h index 7119d5fa..41144b9f 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h +++ b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "avfilter_base.h" diff --git a/plugins/wasmedge_ffmpeg/avfilter/module.cpp b/plugins/wasmedge_ffmpeg/avfilter/module.cpp index 2c31e44f..761d5fba 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/module.cpp +++ b/plugins/wasmedge_ffmpeg/avfilter/module.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "module.h" #include "avFilter.h" #include "avfilter_func.h" @@ -11,7 +14,6 @@ namespace AVFilter { WasmEdgeFFmpegAVFilterModule::WasmEdgeFFmpegAVFilterModule( std::shared_ptr Env) : ModuleInstance("wasmedge_ffmpeg_avfilter") { - addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_alloc", std::make_unique(Env)); addHostFunc("wasmedge_ffmpeg_avfilter_avfilter_graph_config", diff --git a/plugins/wasmedge_ffmpeg/avfilter/module.h b/plugins/wasmedge_ffmpeg/avfilter/module.h index 2515e6ae..7892a925 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/module.h +++ b/plugins/wasmedge_ffmpeg/avfilter/module.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp b/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp index 1d5f5c41..f0657a52 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp +++ b/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avChapter.h" extern "C" { @@ -11,14 +14,14 @@ namespace AVFormat { Expect AVChapterId::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t ChapterIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVChapter **AvChapter = AvFormatContext->chapters; // No check here (Check) // Raw Pointer Iteration. - for (unsigned int I = 1; I <= ChapterIdx; I++) + for (unsigned int I = 1; I <= ChapterIdx; I++) { AvChapter++; + } return static_cast(*AvChapter)->id; } @@ -26,14 +29,14 @@ Expect AVChapterId::body(const Runtime::CallingFrame &, Expect AVChapterSetId::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t ChapterIdx, int64_t ChapterId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVChapter **AvChapter = AvFormatContext->chapters; // No check here (Check) // Raw Pointer Iteration. - for (unsigned int I = 1; I <= ChapterIdx; I++) + for (unsigned int I = 1; I <= ChapterIdx; I++) { AvChapter++; + } (*AvChapter)->id = ChapterId; return static_cast(ErrNo::Success); @@ -43,7 +46,6 @@ Expect AVChapterTimebase::body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, uint32_t DenPtr, uint32_t AvFormatCtxId, uint32_t ChapterIdx) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, ""); MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, ""); @@ -53,8 +55,9 @@ Expect AVChapterTimebase::body(const Runtime::CallingFrame &Frame, // No check here (Check) // Raw Pointer Iteration. - for (unsigned int I = 1; I <= ChapterIdx; I++) + for (unsigned int I = 1; I <= ChapterIdx; I++) { AvChapter++; + } AVRational const AvRational = static_cast(*AvChapter)->time_base; *Num = AvRational.num; @@ -66,7 +69,6 @@ Expect AVChapterSetTimebase::body(const Runtime::CallingFrame &, int32_t Num, int32_t Den, uint32_t AvFormatCtxId, uint32_t ChapterIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVRational const Timebase = av_make_q(Num, Den); @@ -74,8 +76,9 @@ Expect AVChapterSetTimebase::body(const Runtime::CallingFrame &, // No check here (Check) // Raw Pointer Iteration. - for (unsigned int I = 1; I <= ChapterIdx; I++) + for (unsigned int I = 1; I <= ChapterIdx; I++) { AvChapter++; + } (*AvChapter)->time_base = Timebase; return static_cast(ErrNo::Success); @@ -84,14 +87,14 @@ Expect AVChapterSetTimebase::body(const Runtime::CallingFrame &, Expect AVChapterStart::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t ChapterIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVChapter **AvChapter = AvFormatContext->chapters; // No check here (Check) // Raw Pointer Iteration. - for (unsigned int I = 1; I <= ChapterIdx; I++) + for (unsigned int I = 1; I <= ChapterIdx; I++) { AvChapter++; + } return static_cast(*AvChapter)->start; } @@ -100,14 +103,14 @@ Expect AVChapterSetStart::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t ChapterIdx, int64_t StartValue) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVChapter **AvChapter = AvFormatContext->chapters; // No check here (Check) // Raw Pointer Iteration. - for (unsigned int I = 1; I <= ChapterIdx; I++) + for (unsigned int I = 1; I <= ChapterIdx; I++) { AvChapter++; + } (*AvChapter)->start = StartValue; return static_cast(ErrNo::Success); @@ -116,14 +119,14 @@ Expect AVChapterSetStart::body(const Runtime::CallingFrame &, Expect AVChapterEnd::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t ChapterIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVChapter **AvChapter = AvFormatContext->chapters; // No check here (Check) // Raw Pointer Iteration. - for (unsigned int I = 1; I <= ChapterIdx; I++) + for (unsigned int I = 1; I <= ChapterIdx; I++) { AvChapter++; + } return static_cast(*AvChapter)->end; } @@ -131,14 +134,14 @@ Expect AVChapterEnd::body(const Runtime::CallingFrame &, Expect AVChapterSetEnd::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t ChapterIdx, int64_t EndValue) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVChapter **AvChapter = AvFormatContext->chapters; // No check here (Check) // Raw Pointer Iteration. - for (unsigned int I = 1; I <= ChapterIdx; I++) + for (unsigned int I = 1; I <= ChapterIdx; I++) { AvChapter++; + } (*AvChapter)->end = EndValue; return static_cast(ErrNo::Success); @@ -147,7 +150,6 @@ Expect AVChapterSetEnd::body(const Runtime::CallingFrame &, Expect AVChapterMetadata::body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t ChapterIdx, uint32_t DictPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, "Failed when accessing the return AVDictionary memory"sv); @@ -160,8 +162,9 @@ Expect AVChapterMetadata::body(const Runtime::CallingFrame &Frame, // No check here (Check) // Raw Pointer Iteration. - for (unsigned int I = 1; I <= ChapterIdx; I++) + for (unsigned int I = 1; I <= ChapterIdx; I++) { AvChapter++; + } *AvDictionary = (*AvChapter)->metadata; FFMPEG_PTR_STORE(AvDictionary, DictId); @@ -172,7 +175,6 @@ Expect AVChapterSetMetadata::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t ChapterIdx, uint32_t DictId) { - FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); FFMPEG_PTR_FETCH(AvDictionary, DictId, AVDictionary *); @@ -180,13 +182,15 @@ Expect AVChapterSetMetadata::body(const Runtime::CallingFrame &, // No check here (Check) // Raw Pointer Iteration. - for (unsigned int I = 1; I <= ChapterIdx; I++) + for (unsigned int I = 1; I <= ChapterIdx; I++) { AvChapter++; + } - if (AvDictionary == nullptr) + if (AvDictionary == nullptr) { (*AvChapter)->metadata = nullptr; - else + } else { (*AvChapter)->metadata = *AvDictionary; + } return static_cast(ErrNo::Success); } diff --git a/plugins/wasmedge_ffmpeg/avformat/avChapter.h b/plugins/wasmedge_ffmpeg/avformat/avChapter.h index bb40c088..58ed4f54 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avChapter.h +++ b/plugins/wasmedge_ffmpeg/avformat/avChapter.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "avformat_base.h" diff --git a/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp index 072104c5..826da201 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp +++ b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avInputOutputFormat.h" extern "C" { @@ -12,7 +15,6 @@ namespace AVFormat { Expect AVIOFormatNameLength::body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, uint32_t FormatType) { - const char *Name; if (FormatType == 0) { @@ -23,15 +25,15 @@ Expect AVIOFormatNameLength::body(const Runtime::CallingFrame &, Name = AvOutputFormat->name; } - if (Name == nullptr) + if (Name == nullptr) { return 0; + } return strlen(Name); } Expect AVInputFormatName::body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId, uint32_t NamePtr, uint32_t NameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); FFMPEG_PTR_FETCH(AvInputFormat, AVInputFormatId, AVInputFormat); @@ -44,7 +46,6 @@ Expect AVInputFormatName::body(const Runtime::CallingFrame &Frame, Expect AVOutputFormatName::body(const Runtime::CallingFrame &Frame, uint32_t AVOutputFormatId, uint32_t NamePtr, uint32_t NameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); @@ -57,7 +58,6 @@ Expect AVOutputFormatName::body(const Runtime::CallingFrame &Frame, Expect AVIOFormatLongNameLength::body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, uint32_t FormatType) { - const char *LongName; if (FormatType == 0) { @@ -68,8 +68,9 @@ Expect AVIOFormatLongNameLength::body(const Runtime::CallingFrame &, LongName = AvOutputFormat->long_name; } - if (LongName == nullptr) + if (LongName == nullptr) { return 0; + } return strlen(LongName); } @@ -77,7 +78,6 @@ Expect AVInputFormatLongName::body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId, uint32_t LongNamePtr, uint32_t LongNameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(LongNameBuf, MemInst, char, LongNamePtr, LongNameLen, ""); FFMPEG_PTR_FETCH(AvInputFormat, AVInputFormatId, AVInputFormat); @@ -91,7 +91,6 @@ Expect AVOutputFormatLongName::body(const Runtime::CallingFrame &Frame, uint32_t AVOutputFormatId, uint32_t LongNamePtr, uint32_t LongNameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(LongNameBuf, MemInst, char, LongNamePtr, LongNameLen, ""); FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); @@ -104,7 +103,6 @@ Expect AVOutputFormatLongName::body(const Runtime::CallingFrame &Frame, Expect AVIOFormatExtensionsLength::body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, uint32_t FormatType) { - const char *Extensions; if (FormatType == 0) { @@ -115,8 +113,9 @@ Expect AVIOFormatExtensionsLength::body(const Runtime::CallingFrame &, Extensions = AvOutputFormat->extensions; } - if (Extensions == nullptr) + if (Extensions == nullptr) { return 0; + } return strlen(Extensions); } @@ -124,7 +123,6 @@ Expect AVInputFormatExtensions::body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId, uint32_t ExtensionsPtr, uint32_t ExtensionsLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(ExtensionsBuf, MemInst, char, ExtensionsPtr, ExtensionsLen, ""); @@ -139,7 +137,6 @@ Expect AVOutputFormatExtensions::body(const Runtime::CallingFrame &Frame, uint32_t AVOutputFormatId, uint32_t ExtensionsPtr, uint32_t ExtensionsLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(ExtensionsBuf, MemInst, char, ExtensionsPtr, ExtensionsLen, ""); @@ -163,8 +160,9 @@ Expect AVIOFormatMimeTypeLength::body(const Runtime::CallingFrame &, MimeType = AvOutputFormat->mime_type; } - if (MimeType == nullptr) + if (MimeType == nullptr) { return 0; + } return strlen(MimeType); } @@ -172,7 +170,6 @@ Expect AVInputFormatMimeType::body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId, uint32_t MimeTypePtr, uint32_t MimeTypeLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(MimeTypeBuf, MemInst, char, MimeTypePtr, MimeTypeLen, ""); FFMPEG_PTR_FETCH(AvInputFormat, AVInputFormatId, AVInputFormat); @@ -186,7 +183,6 @@ Expect AVOutputFormatMimeType::body(const Runtime::CallingFrame &Frame, uint32_t AVOutputFormatId, uint32_t MimeTypePtr, uint32_t MimeTypeLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(MimeTypeBuf, MemInst, char, MimeTypePtr, MimeTypeLen, ""); FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); @@ -198,7 +194,6 @@ Expect AVOutputFormatMimeType::body(const Runtime::CallingFrame &Frame, Expect AVOutputFormatFlags::body(const Runtime::CallingFrame &, uint32_t AVOutputFormatId) { - FFMPEG_PTR_FETCH(AvOutputFormat, AVOutputFormatId, AVOutputFormat); return AvOutputFormat->flags; } diff --git a/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h index 9c815ff8..b75efd7d 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h +++ b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "avformat_base.h" #include "runtime/callingframe.h" diff --git a/plugins/wasmedge_ffmpeg/avformat/avStream.cpp b/plugins/wasmedge_ffmpeg/avformat/avStream.cpp index 1eaebef6..6f1c60ce 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avStream.cpp +++ b/plugins/wasmedge_ffmpeg/avformat/avStream.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avStream.h" extern "C" { @@ -11,14 +14,14 @@ namespace AVFormat { Expect AVStreamId::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; // No check here (Check) // Raw Pointer Iteration. - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } return static_cast(*AvStream)->id; } @@ -26,12 +29,12 @@ Expect AVStreamId::body(const Runtime::CallingFrame &, Expect AVStreamIndex::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } return static_cast(*AvStream)->index; } @@ -47,8 +50,9 @@ Expect AVStreamCodecPar::body(const Runtime::CallingFrame &Frame, FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } AVCodecParameters *CodecParam = (static_cast(*AvStream))->codecpar; @@ -60,7 +64,6 @@ Expect AVStreamTimebase::body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, uint32_t DenPtr, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, ""); MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, ""); @@ -68,8 +71,9 @@ Expect AVStreamTimebase::body(const Runtime::CallingFrame &Frame, FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } AVRational const AvRational = static_cast(*AvStream)->time_base; *Num = AvRational.num; @@ -81,12 +85,12 @@ Expect AVStreamSetTimebase::body(const Runtime::CallingFrame &, uint32_t Num, uint32_t Den, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } AVRational const Timebase = av_make_q(Num, Den); (*AvStream)->time_base = Timebase; @@ -96,12 +100,12 @@ Expect AVStreamSetTimebase::body(const Runtime::CallingFrame &, Expect AVStreamDuration::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } return static_cast(*AvStream)->duration; } @@ -109,12 +113,12 @@ Expect AVStreamDuration::body(const Runtime::CallingFrame &, Expect AVStreamStartTime::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } return static_cast(*AvStream)->start_time; } @@ -122,12 +126,12 @@ Expect AVStreamStartTime::body(const Runtime::CallingFrame &, Expect AVStreamNbFrames::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } return static_cast(*AvStream)->nb_frames; } @@ -135,13 +139,13 @@ Expect AVStreamNbFrames::body(const Runtime::CallingFrame &, Expect AVStreamDisposition::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } return static_cast(*AvStream)->disposition; } @@ -150,7 +154,6 @@ Expect AVStreamRFrameRate::body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, uint32_t DenPtr, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, ""); MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, ""); @@ -158,8 +161,9 @@ Expect AVStreamRFrameRate::body(const Runtime::CallingFrame &Frame, FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } AVRational const AvRational = static_cast(*AvStream)->r_frame_rate; @@ -172,12 +176,12 @@ Expect AVStreamSetRFrameRate::body(const Runtime::CallingFrame &, int32_t Num, int32_t Den, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } AVRational const RFrameRate = av_make_q(Num, Den); (*AvStream)->r_frame_rate = RFrameRate; @@ -188,7 +192,6 @@ Expect AVStreamAvgFrameRate::body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, uint32_t DenPtr, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(Num, MemInst, int32_t, NumPtr, ""); MEM_PTR_CHECK(Den, MemInst, int32_t, DenPtr, ""); @@ -196,8 +199,9 @@ Expect AVStreamAvgFrameRate::body(const Runtime::CallingFrame &Frame, FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } AVRational const AvRational = static_cast(*AvStream)->avg_frame_rate; @@ -210,13 +214,13 @@ Expect AVStreamSetAvgFrameRate::body(const Runtime::CallingFrame &, int32_t Num, int32_t Den, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } AVRational const AvgFrameRate = av_make_q(Num, Den); (*AvStream)->avg_frame_rate = AvgFrameRate; @@ -226,7 +230,6 @@ Expect AVStreamSetAvgFrameRate::body(const Runtime::CallingFrame &, Expect AVStreamMetadata::body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t StreamIdx, uint32_t DictPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, "Failed when accessing the return AVDictPtr Memory"sv); @@ -235,8 +238,9 @@ Expect AVStreamMetadata::body(const Runtime::CallingFrame &Frame, AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } AVDictionary **AvDictionary = static_cast(av_malloc(sizeof(AVDictionary *))); @@ -249,19 +253,20 @@ Expect AVStreamMetadata::body(const Runtime::CallingFrame &Frame, Expect AVStreamSetMetadata::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t StreamIdx, uint32_t DictId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); FFMPEG_PTR_FETCH(AvDictionary, DictId, AVDictionary *); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } - if (AvDictionary == nullptr) + if (AvDictionary == nullptr) { (*AvStream)->metadata = nullptr; - else + } else { (*AvStream)->metadata = *AvDictionary; + } return static_cast(ErrNo::Success); } @@ -269,12 +274,12 @@ Expect AVStreamSetMetadata::body(const Runtime::CallingFrame &, Expect AVStreamDiscard::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t StreamIdx) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AVStream **AvStream = AvFormatContext->streams; - for (unsigned int I = 1; I <= StreamIdx; I++) + for (unsigned int I = 1; I <= StreamIdx; I++) { AvStream++; + } return static_cast((*AvStream)->discard); } diff --git a/plugins/wasmedge_ffmpeg/avformat/avStream.h b/plugins/wasmedge_ffmpeg/avformat/avStream.h index 4a8956ce..552bb206 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avStream.h +++ b/plugins/wasmedge_ffmpeg/avformat/avStream.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "avformat_base.h" diff --git a/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp b/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp index 5c9515f8..6ab0e259 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp +++ b/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avformatContext.h" extern "C" { @@ -12,7 +15,6 @@ namespace AVFormat { Expect AVFormatCtxIFormat::body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t AvInputFormatPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(AvInputFormatId, MemInst, uint32_t, AvInputFormatPtr, "Failed when accessing the return AVInputFormat Memory"sv); @@ -27,7 +29,6 @@ Expect AVFormatCtxIFormat::body(const Runtime::CallingFrame &Frame, Expect AVFormatCtxOFormat::body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t AvOutputFormatPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(AvOutputFormatId, MemInst, uint32_t, AvOutputFormatPtr, "Failed when accessing the return AVOutputFormat Memory"sv); @@ -42,35 +43,30 @@ Expect AVFormatCtxOFormat::body(const Runtime::CallingFrame &Frame, Expect AVFormatCtxProbeScore::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); return AvFormatContext->probe_score; } Expect AVFormatCtxNbStreams::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); return AvFormatContext->nb_streams; }; Expect AVFormatCtxBitRate::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); return AvFormatContext->bit_rate; } Expect AVFormatCtxDuration::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); return AvFormatContext->duration; } Expect AVFormatCtxNbChapters::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); return AvFormatContext->nb_chapters; } @@ -78,7 +74,6 @@ Expect AVFormatCtxNbChapters::body(const Runtime::CallingFrame &, Expect AVFormatCtxSetNbChapters::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t NbChapters) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); AvFormatContext->nb_chapters = NbChapters; return static_cast(ErrNo::Success); @@ -87,7 +82,6 @@ Expect AVFormatCtxSetNbChapters::body(const Runtime::CallingFrame &, Expect AVFormatCtxMetadata::body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t DictPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, "Failed when accessing the return AVDictionary memory"sv); @@ -105,14 +99,14 @@ Expect AVFormatCtxMetadata::body(const Runtime::CallingFrame &Frame, Expect AVFormatCtxSetMetadata::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t DictId) { - FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); FFMPEG_PTR_FETCH(AvDictionary, DictId, AVDictionary *); - if (AvDictionary == nullptr) + if (AvDictionary == nullptr) { AvFormatCtx->metadata = nullptr; - else + } else { AvFormatCtx->metadata = *AvDictionary; + } return static_cast(ErrNo::Success); } diff --git a/plugins/wasmedge_ffmpeg/avformat/avformatContext.h b/plugins/wasmedge_ffmpeg/avformat/avformatContext.h index 90cd679e..3f131944 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avformatContext.h +++ b/plugins/wasmedge_ffmpeg/avformat/avformatContext.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "avformat_base.h" diff --git a/plugins/wasmedge_ffmpeg/avformat/avformat_base.h b/plugins/wasmedge_ffmpeg/avformat/avformat_base.h index 28318eb2..a53a92df 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avformat_base.h +++ b/plugins/wasmedge_ffmpeg/avformat/avformat_base.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp b/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp index 545894db..35637087 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp +++ b/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avformat_func.h" extern "C" { @@ -15,7 +18,6 @@ Expect AVFormatOpenInput::body(const Runtime::CallingFrame &Frame, uint32_t UrlPtr, uint32_t UrlSize, uint32_t AvInputFormatId, uint32_t AvDictionaryId) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(urlId, MemInst, char, UrlPtr, "Failed when accessing the return URL memory"sv); @@ -38,7 +40,6 @@ Expect AVFormatOpenInput::body(const Runtime::CallingFrame &Frame, Expect AVFormatFindStreamInfo::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t AvDictionaryId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); FFMPEG_PTR_FETCH(AvDictionary, AvDictionaryId, AVDictionary *); return avformat_find_stream_info(AvFormatContext, AvDictionary); @@ -46,7 +47,6 @@ Expect AVFormatFindStreamInfo::body(const Runtime::CallingFrame &, Expect AVFormatCloseInput::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId) { - FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); avformat_close_input(&AvFormatCtx); FFMPEG_PTR_DELETE(AvFormatCtxId); @@ -55,14 +55,12 @@ Expect AVFormatCloseInput::body(const Runtime::CallingFrame &, Expect AVReadPause::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); return av_read_pause(AvFormatContext); } Expect AVReadPlay::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); return av_read_play(AvFormatContext); } @@ -72,7 +70,6 @@ Expect AVFormatSeekFile::body(const Runtime::CallingFrame &, uint32_t StreamIdx, int64_t MinTs, int64_t Ts, int64_t MaxTs, int32_t Flags) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); return avformat_seek_file(AvFormatContext, StreamIdx, MinTs, Ts, MaxTs, Flags); @@ -96,7 +93,6 @@ Expect AVDumpFormat::body(const Runtime::CallingFrame &Frame, Expect AVFormatFreeContext::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId) { - FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); avformat_free_context(AvFormatCtx); FFMPEG_PTR_DELETE(AvFormatCtxId); @@ -109,7 +105,6 @@ Expect AVFindBestStream::body(const Runtime::CallingFrame &, int32_t WantedStream, int32_t RelatedStream, uint32_t DecoderRetId, int32_t Flags) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); FFMPEG_PTR_FETCH(DecoderRet, DecoderRetId, const AVCodec *); @@ -121,7 +116,6 @@ Expect AVFindBestStream::body(const Runtime::CallingFrame &, Expect AVReadFrame::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t PacketId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); FFMPEG_PTR_FETCH(AvPacket, PacketId, AVPacket); @@ -130,7 +124,6 @@ Expect AVReadFrame::body(const Runtime::CallingFrame &, Expect AVIOClose::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId) { - FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); avio_close(AvFormatCtx->pb); return static_cast(ErrNo::Success); @@ -147,7 +140,6 @@ Expect AVFormatNetworkDeInit::body(const Runtime::CallingFrame &) { Expect AVFormatWriteHeader::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t DictId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); return avformat_write_header(AvFormatContext, AvDict); @@ -155,7 +147,6 @@ Expect AVFormatWriteHeader::body(const Runtime::CallingFrame &, Expect AVFormatWriteTrailer::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId) { - FFMPEG_PTR_FETCH(AvFormatContext, AvFormatCtxId, AVFormatContext); return av_write_trailer(AvFormatContext); } @@ -164,7 +155,6 @@ Expect AVFormatAllocOutputContext2::body( const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxPtr, uint32_t AVOutputFormatId, uint32_t FormatNamePtr, uint32_t FormatLen, uint32_t FileNamePtr, uint32_t FileNameLen) { - std::string Format; MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(FileId, MemInst, char, FileNamePtr, @@ -199,7 +189,6 @@ Expect AVFormatAllocOutputContext2::body( Expect AVIOOpen::body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t FileNamePtr, uint32_t FileNameLen, int32_t Flags) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(FileId, MemInst, char, FileNamePtr, "Failed when accessing the return FileName memory"sv); @@ -217,7 +206,6 @@ Expect AVIOOpen2::body(const Runtime::CallingFrame &Frame, uint32_t UrlLen, int32_t Flags, uint32_t AVIOInterruptCBId, uint32_t AVDictionaryId) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(UrlId, MemInst, char, UrlPtr, "Failed when accessing the return Url memory"sv); @@ -239,7 +227,6 @@ Expect AVFormatVersion::body(const Runtime::CallingFrame &) { Expect AVChapterMallocz::body(const Runtime::CallingFrame &Frame, uint32_t AVChapterPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(AvChapterId, MemInst, uint32_t, AVChapterPtr, "Failed to access Memory for AVChapterPtr"sv) @@ -262,14 +249,14 @@ Expect AVChapterDynarrayAdd::body(const Runtime::CallingFrame &Frame, FFMPEG_PTR_FETCH(AvChapter, AvChapterId, AVChapter); av_dynarray_add(&(AvFormatContext->chapters), NbChapters, AvChapter); - if (*(AvFormatContext->chapters) == nullptr && *(NbChapters) == 0) + if (*(AvFormatContext->chapters) == nullptr && *(NbChapters) == 0) { return static_cast(ErrNo::InternalError); + } return static_cast(ErrNo::Success); } Expect AVFreeP::body(const Runtime::CallingFrame &, uint32_t AvChapterId) { - FFMPEG_PTR_FETCH(AvChapter, AvChapterId, AVChapter); av_freep(AvChapter); FFMPEG_PTR_DELETE(AvChapterId); @@ -279,7 +266,6 @@ Expect AVFreeP::body(const Runtime::CallingFrame &, Expect AVInterleavedWriteFrame::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t AvPacketId) { - FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); @@ -289,7 +275,6 @@ Expect AVInterleavedWriteFrame::body(const Runtime::CallingFrame &, Expect AVWriteFrame::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t AvPacketId) { - FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); FFMPEG_PTR_FETCH(AvPacket, AvPacketId, AVPacket); @@ -299,12 +284,12 @@ Expect AVWriteFrame::body(const Runtime::CallingFrame &, Expect AVFormatNewStream::body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t AvCodecId) { - FFMPEG_PTR_FETCH(AvFormatCtx, AvFormatCtxId, AVFormatContext); FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); AVStream *Stream = avformat_new_stream(AvFormatCtx, AvCodec); - if (Stream == nullptr) + if (Stream == nullptr) { return 0; + } return 1; } @@ -314,7 +299,6 @@ Expect AVGuessCodec::body(const Runtime::CallingFrame &Frame, uint32_t ShortNameLen, uint32_t FileNamePtr, uint32_t FileNameLen, uint32_t MimeTypePtr, uint32_t MimeTypeLen, int32_t MediaTypeId) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(ShortNameBuf, MemInst, char, ShortNamePtr, "Failed when accessing the return ShortName memory"sv); @@ -342,7 +326,6 @@ Expect AVGuessCodec::body(const Runtime::CallingFrame &Frame, Expect AVFormatConfigurationLength::body(const Runtime::CallingFrame &) { - const char *Config = avformat_configuration(); return strlen(Config); } @@ -350,7 +333,6 @@ AVFormatConfigurationLength::body(const Runtime::CallingFrame &) { Expect AVFormatConfiguration::body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); @@ -360,7 +342,6 @@ Expect AVFormatConfiguration::body(const Runtime::CallingFrame &Frame, } Expect AVFormatLicenseLength::body(const Runtime::CallingFrame &) { - const char *License = avformat_license(); return strlen(License); } @@ -368,7 +349,6 @@ Expect AVFormatLicenseLength::body(const Runtime::CallingFrame &) { Expect AVFormatLicense::body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); diff --git a/plugins/wasmedge_ffmpeg/avformat/avformat_func.h b/plugins/wasmedge_ffmpeg/avformat/avformat_func.h index 3cc44a33..30d8a52b 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avformat_func.h +++ b/plugins/wasmedge_ffmpeg/avformat/avformat_func.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "avformat_base.h" diff --git a/plugins/wasmedge_ffmpeg/avformat/module.cpp b/plugins/wasmedge_ffmpeg/avformat/module.cpp index 0d1ec108..1246beff 100644 --- a/plugins/wasmedge_ffmpeg/avformat/module.cpp +++ b/plugins/wasmedge_ffmpeg/avformat/module.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "module.h" #include "avChapter.h" #include "avInputOutputFormat.h" @@ -13,7 +16,6 @@ namespace AVFormat { WasmEdgeFFmpegAVFormatModule::WasmEdgeFFmpegAVFormatModule( std::shared_ptr Env) : ModuleInstance("wasmedge_ffmpeg_avformat") { - // avformat_func.h addHostFunc("wasmedge_ffmpeg_avformat_avformat_open_input", std::make_unique(Env)); diff --git a/plugins/wasmedge_ffmpeg/avformat/module.h b/plugins/wasmedge_ffmpeg/avformat/module.h index 8e5d9740..eae89e3a 100644 --- a/plugins/wasmedge_ffmpeg/avformat/module.h +++ b/plugins/wasmedge_ffmpeg/avformat/module.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp b/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp index 047f309c..908c47c6 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avDictionary.h" extern "C" { @@ -13,7 +16,6 @@ Expect AVDictSet::body(const Runtime::CallingFrame &Frame, uint32_t DictPtr, uint32_t KeyPtr, uint32_t KeyLen, uint32_t ValuePtr, uint32_t ValueLen, int32_t Flags) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(KeyBuf, MemInst, char, KeyPtr, "Failed when accessing the return Key memory"sv); @@ -47,7 +49,6 @@ Expect AVDictSet::body(const Runtime::CallingFrame &Frame, Expect AVDictCopy::body(const Runtime::CallingFrame &Frame, uint32_t DestDictPtr, uint32_t SrcDictId, uint32_t Flags) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(DestDictId, MemInst, uint32_t, DestDictPtr, "Failed to access Memory for AVDict"sv) @@ -56,8 +57,9 @@ Expect AVDictCopy::body(const Runtime::CallingFrame &Frame, int Res = 0; - if (SrcAvDict == nullptr) + if (SrcAvDict == nullptr) { return static_cast(ErrNo::InternalError); + } if (*DestDictId) { FFMPEG_PTR_FETCH(DestAvDict, *DestDictId, AVDictionary *); @@ -77,7 +79,6 @@ Expect AVDictGet::body(const Runtime::CallingFrame &Frame, uint32_t KeyLen, uint32_t PrevDictEntryIdx, uint32_t Flags, uint32_t KeyLenPtr, uint32_t ValueLenPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(KeyStr, MemInst, char, KeyPtr, "Failed when accessing the return Key memory"sv); @@ -89,8 +90,9 @@ Expect AVDictGet::body(const Runtime::CallingFrame &Frame, FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); // If Dict Not created return (i.e. 0 is passed as AVDictId) - if (AvDict == nullptr) + if (AvDict == nullptr) { return static_cast(ErrNo::InternalError); + } std::string Key; std::copy_n(KeyStr, KeyLen, std::back_inserter(Key)); @@ -101,8 +103,9 @@ Expect AVDictGet::body(const Runtime::CallingFrame &Frame, Curr++; } - if (DictEntry == nullptr) + if (DictEntry == nullptr) { return static_cast(ErrNo::InternalError); + } *KeyLenId = strlen(DictEntry->key); *ValueLenId = strlen(DictEntry->value); @@ -113,7 +116,6 @@ Expect AVDictGetKeyValue::body( const Runtime::CallingFrame &Frame, uint32_t DictId, uint32_t KeyPtr, uint32_t KeyLen, uint32_t ValBufPtr, uint32_t ValBufLen, uint32_t KeyBufPtr, uint32_t KeyBufLen, uint32_t PrevDictEntryIdx, uint32_t Flags) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(KeyStr, MemInst, char, KeyPtr, "Failed when accessing the return Key memory"sv); @@ -123,8 +125,9 @@ Expect AVDictGetKeyValue::body( FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); // If Dict Not created return (i.e. 0 is passed as AVDictId) - if (AvDict == nullptr) + if (AvDict == nullptr) { return static_cast(ErrNo::InternalError); + } std::string Key; std::copy_n(KeyStr, KeyLen, std::back_inserter(Key)); @@ -135,8 +138,9 @@ Expect AVDictGetKeyValue::body( DictEntry = av_dict_get(*AvDict, Key.c_str(), DictEntry, Flags); Curr++; } - if (DictEntry == nullptr) + if (DictEntry == nullptr) { return static_cast(ErrNo::InternalError); + } std::copy_n(DictEntry->value, strlen(DictEntry->value), ValBuf.data()); std::copy_n(DictEntry->key, strlen(DictEntry->key), KeyBuf.data()); return Curr; @@ -144,9 +148,9 @@ Expect AVDictGetKeyValue::body( Expect AVDictFree::body(const Runtime::CallingFrame &, uint32_t DictId) { - - if (DictId == 0) + if (DictId == 0) { return static_cast(ErrNo::Success); + } FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); av_dict_free(AvDict); FFMPEG_PTR_DELETE(DictId); diff --git a/plugins/wasmedge_ffmpeg/avutil/avDictionary.h b/plugins/wasmedge_ffmpeg/avutil/avDictionary.h index 8d5a5ff1..b8732c76 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avDictionary.h +++ b/plugins/wasmedge_ffmpeg/avutil/avDictionary.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "avutil_base.h" diff --git a/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp b/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp index 414031d9..2dc1449b 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avFrame.h" extern "C" { @@ -23,7 +26,6 @@ Expect AVFrameAlloc::body(const Runtime::CallingFrame &Frame, Expect AVFrameFree::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); av_frame_free(&AvFrame); FFMPEG_PTR_DELETE(FrameId); @@ -32,21 +34,18 @@ Expect AVFrameFree::body(const Runtime::CallingFrame &, Expect AVFrameWidth::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->width; } Expect AVFrameHeight::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->height; } Expect AVFrameSetHeight::body(const Runtime::CallingFrame &, uint32_t FrameId, uint32_t Height) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AvFrame->height = Height; return static_cast(ErrNo::Success); @@ -54,7 +53,6 @@ Expect AVFrameSetHeight::body(const Runtime::CallingFrame &, Expect AVFrameSetWidth::body(const Runtime::CallingFrame &, uint32_t FrameId, uint32_t Width) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AvFrame->width = Width; return static_cast(ErrNo::Success); @@ -62,11 +60,11 @@ Expect AVFrameSetWidth::body(const Runtime::CallingFrame &, Expect AVFrameVideoFormat::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); int const Format = AvFrame->format; - if (Format == -1) + if (Format == -1) { return -1; + } AVPixelFormat const PixelFormat = static_cast(Format); return FFmpegUtils::PixFmt::fromAVPixFmt(PixelFormat); } @@ -74,7 +72,6 @@ Expect AVFrameVideoFormat::body(const Runtime::CallingFrame &, Expect AVFrameSetVideoFormat::body(const Runtime::CallingFrame &, uint32_t FrameId, uint32_t AvPixFormatId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AVPixelFormat const PixelFormat = FFmpegUtils::PixFmt::intoAVPixFmt(AvPixFormatId); @@ -84,14 +81,12 @@ Expect AVFrameSetVideoFormat::body(const Runtime::CallingFrame &, Expect AVFrameIsNull::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->data[0] == nullptr; } Expect AVFrameLinesize::body(const Runtime::CallingFrame &, uint32_t FrameId, uint32_t Idx) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->linesize[Idx]; } @@ -99,7 +94,6 @@ Expect AVFrameLinesize::body(const Runtime::CallingFrame &, Expect AVFrameData::body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint32_t FrameBufPtr, uint32_t FrameBufLen, uint32_t Index) { - MEMINST_CHECK(MemInst, Frame, 0) MEM_SPAN_CHECK(Buffer, MemInst, uint8_t, FrameBufPtr, FrameBufLen, ""); FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); @@ -111,18 +105,17 @@ Expect AVFrameData::body(const Runtime::CallingFrame &Frame, Expect AVFrameGetBuffer::body(const Runtime::CallingFrame &, uint32_t FrameId, int32_t Align) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return av_frame_get_buffer(AvFrame, Align); } Expect AVFrameAudioFormat::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); int const Format = AvFrame->format; - if (Format == -1) + if (Format == -1) { return -1; + } AVSampleFormat const SampleFormat = static_cast(Format); return FFmpegUtils::SampleFmt::toSampleID(SampleFormat); @@ -141,7 +134,6 @@ Expect AVFrameSetAudioFormat::body(const Runtime::CallingFrame &, Expect AVFrameSetChannelLayout::body(const Runtime::CallingFrame &, uint32_t FrameId, uint64_t ChannelLayoutID) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); uint64_t const ChannelLayout = FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutID); @@ -151,7 +143,6 @@ Expect AVFrameSetChannelLayout::body(const Runtime::CallingFrame &, Expect AVFrameSetNbSamples::body(const Runtime::CallingFrame &, uint32_t FrameId, int32_t Samples) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AvFrame->nb_samples = Samples; return static_cast(ErrNo::Success); @@ -159,14 +150,12 @@ Expect AVFrameSetNbSamples::body(const Runtime::CallingFrame &, Expect AVFrameNbSamples::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->nb_samples; } Expect AVFrameSampleRate::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->sample_rate; } @@ -174,7 +163,6 @@ Expect AVFrameSampleRate::body(const Runtime::CallingFrame &, Expect AVFrameSetSampleRate::body(const Runtime::CallingFrame &, uint32_t FrameId, int32_t SampleRate) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AvFrame->sample_rate = SampleRate; return static_cast(ErrNo::Success); @@ -182,14 +170,12 @@ Expect AVFrameSetSampleRate::body(const Runtime::CallingFrame &, Expect AVFrameChannels::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->channels; } Expect AVFrameSetChannels::body(const Runtime::CallingFrame &, uint32_t FrameId, int32_t Channels) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AvFrame->channels = Channels; return static_cast(ErrNo::Success); @@ -204,14 +190,12 @@ Expect AVFrameChannelLayout::body(const Runtime::CallingFrame &, Expect AVFrameBestEffortTimestamp::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->best_effort_timestamp; } Expect AVFramePictType::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AVPictureType const AvPictureType = AvFrame->pict_type; return FFmpegUtils::PictureType::fromAVPictureType(AvPictureType); @@ -219,7 +203,6 @@ Expect AVFramePictType::body(const Runtime::CallingFrame &, Expect AVFrameSetPictType::body(const Runtime::CallingFrame &, uint32_t FrameId, int32_t PictureId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AVPictureType const AvPictureType = FFmpegUtils::PictureType::intoAVPictureType(PictureId); @@ -230,28 +213,24 @@ Expect AVFrameSetPictType::body(const Runtime::CallingFrame &, Expect AVFrameInterlacedFrame::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->interlaced_frame; } Expect AVFrameTopFieldFirst::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->top_field_first; } Expect AVFramePaletteHasChanged::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->palette_has_changed; } Expect AVFrameColorSpace::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AVColorSpace const AvColorSpace = AvFrame->colorspace; return FFmpegUtils::ColorSpace::fromAVColorSpace(AvColorSpace); @@ -260,7 +239,6 @@ Expect AVFrameColorSpace::body(const Runtime::CallingFrame &, Expect AVFrameSetColorSpace::body(const Runtime::CallingFrame &, uint32_t FrameId, int32_t ColorSpaceId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AvFrame->colorspace = FFmpegUtils::ColorSpace::intoAVColorSpace(ColorSpaceId); return static_cast(ErrNo::Success); @@ -268,7 +246,6 @@ Expect AVFrameSetColorSpace::body(const Runtime::CallingFrame &, Expect AVFrameColorRange::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AVColorRange const AvColorRange = AvFrame->color_range; @@ -278,7 +255,6 @@ Expect AVFrameColorRange::body(const Runtime::CallingFrame &, Expect AVFrameSetColorRange::body(const Runtime::CallingFrame &, uint32_t FrameId, int32_t ColorRangeId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AvFrame->color_range = static_cast(ColorRangeId); return static_cast(ErrNo::Success); @@ -297,7 +273,6 @@ AVFrameColorTransferCharacteristic::body(const Runtime::CallingFrame &, Expect AVFrameSetColorTransferCharacteristic::body( const Runtime::CallingFrame &, uint32_t FrameId, int32_t ColorTransferCharacteristicId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AvFrame->color_trc = static_cast(ColorTransferCharacteristicId); @@ -306,7 +281,6 @@ Expect AVFrameSetColorTransferCharacteristic::body( Expect AVFrameChromaLocation::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AVChromaLocation const AvChromaLocation = AvFrame->chroma_location; return FFmpegUtils::ChromaLocation::fromAVChromaLocation(AvChromaLocation); @@ -314,42 +288,36 @@ Expect AVFrameChromaLocation::body(const Runtime::CallingFrame &, Expect AVFrameCodedPictureNumber::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->coded_picture_number; } Expect AVFrameDisplayPictureNumber::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->display_picture_number; } Expect AVFrameRepeatPict::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->repeat_pict; } Expect AVFrameFlags::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->flags; } Expect AVFrameQuality::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->quality; } Expect AVFrameMetadata::body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint32_t DictPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(DictId, MemInst, uint32_t, DictPtr, "Failed when accessing the return AVDictionary memory"sv); @@ -366,34 +334,31 @@ Expect AVFrameMetadata::body(const Runtime::CallingFrame &Frame, Expect AVFrameSetMetadata::body(const Runtime::CallingFrame &, uint32_t FrameId, uint32_t DictId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); - if (AvDict == nullptr) + if (AvDict == nullptr) { AvFrame->metadata = nullptr; - else + } else { AvFrame->metadata = *AvDict; + } return static_cast(ErrNo::Success); } Expect AVFrameKeyFrame::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->key_frame; } Expect AVFramePts::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); return AvFrame->pts; } Expect AVFrameSetPts::body(const Runtime::CallingFrame &, uint32_t FrameId, int64_t Pts) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AvFrame->pts = Pts; return static_cast(ErrNo::Success); @@ -401,7 +366,6 @@ Expect AVFrameSetPts::body(const Runtime::CallingFrame &, Expect AVFrameCopy::body(const Runtime::CallingFrame &, uint32_t DestFrameId, uint32_t SrcFrameId) { - FFMPEG_PTR_FETCH(DestAvFrame, DestFrameId, AVFrame); FFMPEG_PTR_FETCH(SrcAvFrame, SrcFrameId, AVFrame); @@ -412,7 +376,6 @@ Expect AVFrameCopy::body(const Runtime::CallingFrame &, Expect AVFrameCopyProps::body(const Runtime::CallingFrame &, uint32_t DestFrameId, uint32_t SrcFrameId) { - FFMPEG_PTR_FETCH(DestAvFrame, DestFrameId, AVFrame); FFMPEG_PTR_FETCH(SrcAvFrame, SrcFrameId, AVFrame); @@ -424,7 +387,6 @@ Expect AVFrameSampleAspectRatio::body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint32_t NumPtr, uint32_t DenPtr) { - MEMINST_CHECK(MemInst, Frame, 0) FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); @@ -441,7 +403,6 @@ AVFrameSampleAspectRatio::body(const Runtime::CallingFrame &Frame, Expect AVFrameColorPrimaries::body(const Runtime::CallingFrame &, uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AVColorPrimaries const ColorPrimaries = AvFrame->color_primaries; return FFmpegUtils::ColorPrimaries::fromAVColorPrimaries(ColorPrimaries); @@ -450,7 +411,6 @@ Expect AVFrameColorPrimaries::body(const Runtime::CallingFrame &, Expect AVFrameSetColorPrimaries::body(const Runtime::CallingFrame &, uint32_t FrameId, int32_t ColorPrimariesId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AVColorPrimaries const ColorPrimaries = FFmpegUtils::ColorPrimaries::intoAVColorPrimaries(ColorPrimariesId); diff --git a/plugins/wasmedge_ffmpeg/avutil/avFrame.h b/plugins/wasmedge_ffmpeg/avutil/avFrame.h index 3bceaa3e..39cb732e 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avFrame.h +++ b/plugins/wasmedge_ffmpeg/avutil/avFrame.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "avutil_base.h" diff --git a/plugins/wasmedge_ffmpeg/avutil/avRational.cpp b/plugins/wasmedge_ffmpeg/avutil/avRational.cpp index 8fbe81fd..deedd7e4 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avRational.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/avRational.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avRational.h" extern "C" { @@ -12,7 +15,6 @@ namespace AVUtil { Expect AVAddQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen, int32_t BNum, int32_t BDen, uint32_t CNumPtr, uint32_t CDenPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(CNum, MemInst, int32_t, CNumPtr, "Failed to access Numerator Ptr for AVRational"sv); @@ -32,7 +34,6 @@ Expect AVAddQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, Expect AVSubQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen, int32_t BNum, int32_t BDen, uint32_t CNumPtr, uint32_t CDenPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(CNum, MemInst, int32_t, CNumPtr, "Failed to access Numerator Ptr for AVRational"sv); @@ -51,7 +52,6 @@ Expect AVSubQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, Expect AVMulQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen, int32_t BNum, int32_t BDen, uint32_t CNumPtr, uint32_t CDenPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(CNum, MemInst, int32_t, CNumPtr, "Failed to access Numerator Ptr for AVRational"sv); @@ -70,7 +70,6 @@ Expect AVMulQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, Expect AVDivQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen, int32_t BNum, int32_t BDen, uint32_t CNumPtr, uint32_t CDenPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(CNum, MemInst, int32_t, CNumPtr, "Failed to access Numerator Ptr for AVRational"sv); @@ -88,7 +87,6 @@ Expect AVDivQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, Expect AVCmpQ::body(const Runtime::CallingFrame &, int32_t ANum, int32_t ADen, int32_t BNum, int32_t BDen) { - AVRational const A = av_make_q(ANum, ADen); AVRational const B = av_make_q(BNum, BDen); return av_cmp_q(A, B); @@ -97,7 +95,6 @@ Expect AVCmpQ::body(const Runtime::CallingFrame &, int32_t ANum, Expect AVNearerQ::body(const Runtime::CallingFrame &, int32_t ANum, int32_t ADen, int32_t BNum, int32_t BDen, int32_t CNum, int32_t CDen) { - AVRational const A = av_make_q(ANum, ADen); AVRational const B = av_make_q(BNum, BDen); AVRational const C = av_make_q(CNum, CDen); @@ -107,14 +104,12 @@ Expect AVNearerQ::body(const Runtime::CallingFrame &, int32_t ANum, Expect AVQ2d::body(const Runtime::CallingFrame &, int32_t ANum, int32_t ADen) { - AVRational const A = av_make_q(ANum, ADen); return av_q2d(A); } Expect AVD2Q::body(const Runtime::CallingFrame &Frame, double_t D, int32_t Max, uint32_t ANumPtr, uint32_t ADenPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(ANum, MemInst, int32_t, ANumPtr, "Failed to access Numerator Ptr for AVRational"sv); @@ -129,14 +124,12 @@ Expect AVD2Q::body(const Runtime::CallingFrame &Frame, double_t D, Expect AVQ2IntFloat::body(const Runtime::CallingFrame &, int32_t ANum, int32_t ADen) { - AVRational const A = av_make_q(ANum, ADen); return av_q2intfloat(A); } Expect AVInvQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen, uint32_t BNumPtr, uint32_t BDenPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(BNum, MemInst, int32_t, BNumPtr, "Failed to access Numerator Ptr for AVRational"sv); @@ -154,7 +147,6 @@ Expect AVInvQ::body(const Runtime::CallingFrame &Frame, int32_t ANum, Expect AVReduce::body(const Runtime::CallingFrame &Frame, uint32_t ANumPtr, uint32_t ADenPtr, int64_t BNum, int64_t BDen, int64_t Max) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(ANum, MemInst, int32_t, ANumPtr, "Failed to access Numerator Ptr for AVRational"sv); diff --git a/plugins/wasmedge_ffmpeg/avutil/avRational.h b/plugins/wasmedge_ffmpeg/avutil/avRational.h index b158e663..f37195da 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avRational.h +++ b/plugins/wasmedge_ffmpeg/avutil/avRational.h @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "avutil_base.h" #include "runtime/callingframe.h" diff --git a/plugins/wasmedge_ffmpeg/avutil/avTime.cpp b/plugins/wasmedge_ffmpeg/avutil/avTime.cpp index 1ebbb03a..40cdaba1 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avTime.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/avTime.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avTime.h" extern "C" { diff --git a/plugins/wasmedge_ffmpeg/avutil/avTime.h b/plugins/wasmedge_ffmpeg/avutil/avTime.h index 803e404a..6a841844 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avTime.h +++ b/plugins/wasmedge_ffmpeg/avutil/avTime.h @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "avutil_base.h" #include "runtime/callingframe.h" diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_base.h b/plugins/wasmedge_ffmpeg/avutil/avutil_base.h index dcf35283..851f692c 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avutil_base.h +++ b/plugins/wasmedge_ffmpeg/avutil/avutil_base.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp b/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp index a1f20e4f..6dfeaf01 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avutil_func.h" extern "C" { @@ -33,7 +36,6 @@ Expect AVLogSetFlags::body(const Runtime::CallingFrame &, Expect AVRescaleQ::body(const Runtime::CallingFrame &, int64_t A, int32_t BNum, int32_t BDen, int32_t CNum, int32_t CDen) { - AVRational const B = av_make_q(BNum, BDen); AVRational const C = av_make_q(CNum, CDen); return av_rescale_q(A, B, C); @@ -42,7 +44,6 @@ Expect AVRescaleQ::body(const Runtime::CallingFrame &, int64_t A, Expect AVRescaleQRnd::body(const Runtime::CallingFrame &, int64_t A, int32_t BNum, int32_t BDen, int32_t CNum, int32_t CDen, int32_t RoundingId) { - AVRational const B = av_make_q(BNum, BDen); AVRational const C = av_make_q(CNum, CDen); AVRounding const Rounding = FFmpegUtils::Rounding::intoAVRounding(RoundingId); @@ -63,12 +64,12 @@ AVGetChannelLayoutNbChannels::body(const Runtime::CallingFrame &, Expect AVGetChannelLayoutNameLen::body(const Runtime::CallingFrame &, uint64_t ChannelLayoutId) { - uint64_t const ChannelLayout = FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); const char *ChName = av_get_channel_name(ChannelLayout); - if (ChName == nullptr) + if (ChName == nullptr) { return 0; + } return strlen(ChName); } @@ -76,7 +77,6 @@ Expect AVGetChannelLayoutName::body(const Runtime::CallingFrame &Frame, uint64_t ChannelLayoutId, uint32_t NamePtr, uint32_t NameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(NameBuf, MemInst, char, NamePtr, NameLen, ""); @@ -90,7 +90,6 @@ Expect AVGetChannelLayoutName::body(const Runtime::CallingFrame &Frame, Expect AVGetChannelLayoutMask::body(const Runtime::CallingFrame &, uint64_t ChannelLayoutId) { - uint64_t const ChannelLayout = FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); return ChannelLayout; @@ -110,7 +109,6 @@ Expect AVUtilConfigurationLength::body(const Runtime::CallingFrame &) { Expect AVUtilConfiguration::body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); @@ -120,14 +118,12 @@ Expect AVUtilConfiguration::body(const Runtime::CallingFrame &Frame, } Expect AVUtilLicenseLength::body(const Runtime::CallingFrame &) { - const char *License = avutil_license(); return strlen(License); } Expect AVUtilLicense::body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_func.h b/plugins/wasmedge_ffmpeg/avutil/avutil_func.h index 1529e94f..5532eb59 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avutil_func.h +++ b/plugins/wasmedge_ffmpeg/avutil/avutil_func.h @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "avutil_base.h" #include "runtime/callingframe.h" diff --git a/plugins/wasmedge_ffmpeg/avutil/error.cpp b/plugins/wasmedge_ffmpeg/avutil/error.cpp index d918ece5..90df0fbf 100644 --- a/plugins/wasmedge_ffmpeg/avutil/error.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/error.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "error.h" extern "C" { @@ -12,7 +15,6 @@ namespace AVUtil { Expect AVUtilAVStrError::body(const Runtime::CallingFrame &Frame, int32_t ErrNum, uint32_t ErrBuf, uint32_t BufLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(ErrId, MemInst, char, ErrBuf, diff --git a/plugins/wasmedge_ffmpeg/avutil/error.h b/plugins/wasmedge_ffmpeg/avutil/error.h index a8137151..3e2ec1e3 100644 --- a/plugins/wasmedge_ffmpeg/avutil/error.h +++ b/plugins/wasmedge_ffmpeg/avutil/error.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "avutil_base.h" diff --git a/plugins/wasmedge_ffmpeg/avutil/module.cpp b/plugins/wasmedge_ffmpeg/avutil/module.cpp index 12588050..2dd207c3 100644 --- a/plugins/wasmedge_ffmpeg/avutil/module.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/module.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "module.h" #include "avDictionary.h" #include "avFrame.h" @@ -16,7 +19,6 @@ namespace AVUtil { WasmEdgeFFmpegAVUtilModule::WasmEdgeFFmpegAVUtilModule( std::shared_ptr Env) : ModuleInstance("wasmedge_ffmpeg_avutil") { - // error.h addHostFunc("wasmedge_ffmpeg_avutil_av_strerror", std::make_unique(Env)); diff --git a/plugins/wasmedge_ffmpeg/avutil/module.h b/plugins/wasmedge_ffmpeg/avutil/module.h index ebd35dba..0ef5e265 100644 --- a/plugins/wasmedge_ffmpeg/avutil/module.h +++ b/plugins/wasmedge_ffmpeg/avutil/module.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp b/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp index 9d5b55c2..639a9241 100644 --- a/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "pixfmt.h" + extern "C" { #include "libavutil/pixdesc.h" } @@ -21,7 +25,6 @@ AvPixFmtDescriptorNbComponents::body(const Runtime::CallingFrame &, Expect AvPixFmtDescriptorLog2ChromaW::body(const Runtime::CallingFrame &, uint32_t PixFormatId) { - AVPixelFormat const PixelFormat = FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); const AVPixFmtDescriptor *AvPixFmtDescriptor = @@ -32,7 +35,6 @@ AvPixFmtDescriptorLog2ChromaW::body(const Runtime::CallingFrame &, Expect AvPixFmtDescriptorLog2ChromaH::body(const Runtime::CallingFrame &, uint32_t PixFormatId) { - AVPixelFormat const PixelFormat = FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); const AVPixFmtDescriptor *AvPixFmtDescriptor = @@ -42,7 +44,6 @@ AvPixFmtDescriptorLog2ChromaH::body(const Runtime::CallingFrame &, Expect AVColorRangeNameLength::body(const Runtime::CallingFrame &, int32_t RangeId) { - AVColorRange const ColorRange = static_cast(RangeId); const char *Name = av_color_range_name(ColorRange); return strlen(Name); @@ -51,7 +52,6 @@ Expect AVColorRangeNameLength::body(const Runtime::CallingFrame &, Expect AVColorRangeName::body(const Runtime::CallingFrame &Frame, int32_t RangeId, uint32_t RangeNamePtr, uint32_t RangeLength) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(RangeNameBuf, MemInst, char, RangeNamePtr, RangeLength, ""); @@ -63,7 +63,6 @@ Expect AVColorRangeName::body(const Runtime::CallingFrame &Frame, Expect AVColorTransferNameLength::body(const Runtime::CallingFrame &, int32_t TransferId) { - AVColorTransferCharacteristic const Characteristic = static_cast(TransferId); const char *Name = av_color_transfer_name(Characteristic); @@ -74,7 +73,6 @@ Expect AVColorTransferName::body(const Runtime::CallingFrame &Frame, int32_t TransferId, uint32_t TransferNamePtr, uint32_t TransferLength) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(TransferNameBuf, MemInst, char, TransferNamePtr, TransferLength, ""); @@ -88,7 +86,6 @@ Expect AVColorTransferName::body(const Runtime::CallingFrame &Frame, Expect AVColorSpaceNameLength::body(const Runtime::CallingFrame &, int32_t ColorSpaceId) { - AVColorSpace const ColorSpace = static_cast(ColorSpaceId); const char *Name = av_color_space_name(ColorSpace); return strlen(Name); @@ -98,7 +95,6 @@ Expect AVColorSpaceName::body(const Runtime::CallingFrame &Frame, int32_t ColorSpaceId, uint32_t ColorSpaceNamePtr, uint32_t ColorSpaceLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(ColorSpaceBuf, MemInst, char, ColorSpaceNamePtr, ColorSpaceLen, ""); @@ -111,7 +107,6 @@ Expect AVColorSpaceName::body(const Runtime::CallingFrame &Frame, Expect AVColorPrimariesNameLength::body(const Runtime::CallingFrame &, int32_t ColorPrimariesId) { - AVColorPrimaries const ColorPrimaries = FFmpegUtils::ColorPrimaries::intoAVColorPrimaries(ColorPrimariesId); const char *Name = av_color_primaries_name(ColorPrimaries); @@ -122,7 +117,6 @@ Expect AVColorPrimariesName::body(const Runtime::CallingFrame &Frame, int32_t ColorPrimariesId, uint32_t ColorPrimariesNamePtr, uint32_t ColorPrimariesLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(ColorPrimariesBuf, MemInst, char, ColorPrimariesNamePtr, ColorPrimariesLen, ""); @@ -136,7 +130,6 @@ Expect AVColorPrimariesName::body(const Runtime::CallingFrame &Frame, Expect AVPixelFormatNameLength::body(const Runtime::CallingFrame &, uint32_t AvPixFormatId) { - AVPixelFormat const PixFormat = FFmpegUtils::PixFmt::intoAVPixFmt(AvPixFormatId); const AVPixFmtDescriptor *PixFmtDescriptor = av_pix_fmt_desc_get(PixFormat); @@ -148,7 +141,6 @@ Expect AVPixelFormatName::body(const Runtime::CallingFrame &Frame, uint32_t PixFormatId, uint32_t PixFormatNamePtr, uint32_t PixFormatNameLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(PixFormatBuf, MemInst, char, PixFormatNamePtr, PixFormatNameLen, ""); diff --git a/plugins/wasmedge_ffmpeg/avutil/pixfmt.h b/plugins/wasmedge_ffmpeg/avutil/pixfmt.h index 51126aa6..dedbcd30 100644 --- a/plugins/wasmedge_ffmpeg/avutil/pixfmt.h +++ b/plugins/wasmedge_ffmpeg/avutil/pixfmt.h @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "avutil_base.h" #include "runtime/callingframe.h" diff --git a/plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp b/plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp index f38914e6..40150a68 100644 --- a/plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/samplefmt.cpp @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "samplefmt.h" + extern "C" { #include "libavutil/samplefmt.h" } @@ -42,7 +46,6 @@ Expect AVGetBytesPerSample::body(const Runtime::CallingFrame &, Expect AVGetSampleFmt::body(const Runtime::CallingFrame &Frame, uint32_t Str, uint32_t StrLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(StrId, MemInst, char, Str, ""); @@ -88,7 +91,6 @@ AVSamplesAllocArrayAndSamples::body(const Runtime::CallingFrame &Frame, Expect AVGetSampleFmtNameLength::body(const Runtime::CallingFrame &, uint32_t SampleFmtId) { - AVSampleFormat const SampleFmt = FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); @@ -104,7 +106,6 @@ Expect AVGetSampleFmtName::body(const Runtime::CallingFrame &Frame, MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(SampleFmtBuf, MemInst, char, SampleFmtNamePtr, SampleFmtNameLen, ""); - AVSampleFormat const SampleFmt = FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); const char *Name = av_get_sample_fmt_name(SampleFmt); @@ -114,7 +115,6 @@ Expect AVGetSampleFmtName::body(const Runtime::CallingFrame &Frame, Expect AVGetSampleFmtMask::body(const Runtime::CallingFrame &, uint32_t SampleFmtId) { - AVSampleFormat const SampleFmt = FFmpegUtils::SampleFmt::fromSampleID(SampleFmtId); return static_cast(SampleFmt); diff --git a/plugins/wasmedge_ffmpeg/avutil/samplefmt.h b/plugins/wasmedge_ffmpeg/avutil/samplefmt.h index a6190779..02b0d0c4 100644 --- a/plugins/wasmedge_ffmpeg/avutil/samplefmt.h +++ b/plugins/wasmedge_ffmpeg/avutil/samplefmt.h @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "avutil_base.h" #include "runtime/callingframe.h" diff --git a/plugins/wasmedge_ffmpeg/bindings.h b/plugins/wasmedge_ffmpeg/bindings.h index 05858cd4..ac8fb6fb 100644 --- a/plugins/wasmedge_ffmpeg/bindings.h +++ b/plugins/wasmedge_ffmpeg/bindings.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once extern "C" { @@ -3515,244 +3518,352 @@ class ChannelLayout { // Check This function. (Looks good, test it) static uint64_t fromChannelLayoutID(uint64_t ChannelLayout) { uint64_t Channel = 0UL; - if (ChannelLayout & FRONT_LEFT) + if (ChannelLayout & FRONT_LEFT) { Channel |= AV_CH_FRONT_LEFT; - if (ChannelLayout & FRONT_RIGHT) + } + if (ChannelLayout & FRONT_RIGHT) { Channel |= AV_CH_FRONT_RIGHT; - if (ChannelLayout & FRONT_CENTER) + } + if (ChannelLayout & FRONT_CENTER) { Channel |= AV_CH_FRONT_CENTER; - if (ChannelLayout & LOW_FREQUENCY) + } + if (ChannelLayout & LOW_FREQUENCY) { Channel |= AV_CH_LOW_FREQUENCY; - if (ChannelLayout & BACK_LEFT) + } + if (ChannelLayout & BACK_LEFT) { Channel |= AV_CH_BACK_LEFT; - if (ChannelLayout & BACK_RIGHT) + } + if (ChannelLayout & BACK_RIGHT) { Channel |= AV_CH_BACK_RIGHT; - if (ChannelLayout & FRONT_LEFT_OF_CENTER) + } + if (ChannelLayout & FRONT_LEFT_OF_CENTER) { Channel |= AV_CH_FRONT_LEFT_OF_CENTER; - if (ChannelLayout & FRONT_RIGHT_OF_CENTER) + } + if (ChannelLayout & FRONT_RIGHT_OF_CENTER) { Channel |= AV_CH_FRONT_RIGHT_OF_CENTER; - if (ChannelLayout & BACK_CENTER) + } + if (ChannelLayout & BACK_CENTER) { Channel |= AV_CH_BACK_CENTER; - if (ChannelLayout & SIDE_LEFT) + } + if (ChannelLayout & SIDE_LEFT) { Channel |= AV_CH_SIDE_LEFT; - if (ChannelLayout & SIDE_RIGHT) + } + if (ChannelLayout & SIDE_RIGHT) { Channel |= AV_CH_SIDE_RIGHT; - if (ChannelLayout & TOP_CENTER) + } + if (ChannelLayout & TOP_CENTER) { Channel |= AV_CH_TOP_CENTER; - if (ChannelLayout & TOP_FRONT_LEFT) + } + if (ChannelLayout & TOP_FRONT_LEFT) { Channel |= AV_CH_TOP_FRONT_LEFT; - if (ChannelLayout & TOP_FRONT_CENTER) + } + if (ChannelLayout & TOP_FRONT_CENTER) { Channel |= AV_CH_TOP_FRONT_CENTER; - if (ChannelLayout & TOP_FRONT_RIGHT) + } + if (ChannelLayout & TOP_FRONT_RIGHT) { Channel |= AV_CH_TOP_FRONT_RIGHT; - if (ChannelLayout & TOP_BACK_LEFT) + } + if (ChannelLayout & TOP_BACK_LEFT) { Channel |= AV_CH_TOP_BACK_LEFT; - if (ChannelLayout & TOP_BACK_CENTER) + } + if (ChannelLayout & TOP_BACK_CENTER) { Channel |= AV_CH_TOP_BACK_CENTER; - if (ChannelLayout & TOP_BACK_RIGHT) + } + if (ChannelLayout & TOP_BACK_RIGHT) { Channel |= AV_CH_TOP_BACK_RIGHT; - if (ChannelLayout & STEREO_LEFT) + } + if (ChannelLayout & STEREO_LEFT) { Channel |= AV_CH_STEREO_LEFT; - if (ChannelLayout & STEREO_RIGHT) + } + if (ChannelLayout & STEREO_RIGHT) { Channel |= AV_CH_STEREO_RIGHT; - if (ChannelLayout & WIDE_LEFT) + } + if (ChannelLayout & WIDE_LEFT) { Channel |= AV_CH_WIDE_LEFT; - if (ChannelLayout & WIDE_RIGHT) + } + if (ChannelLayout & WIDE_RIGHT) { Channel |= AV_CH_WIDE_RIGHT; - if (ChannelLayout & SURROUND_DIRECT_LEFT) + } + if (ChannelLayout & SURROUND_DIRECT_LEFT) { Channel |= AV_CH_SURROUND_DIRECT_LEFT; - if (ChannelLayout & SURROUND_DIRECT_RIGHT) + } + if (ChannelLayout & SURROUND_DIRECT_RIGHT) { Channel |= AV_CH_SURROUND_DIRECT_RIGHT; - if (ChannelLayout & LOW_FREQUENCY_2) + } + if (ChannelLayout & LOW_FREQUENCY_2) { Channel |= AV_CH_LOW_FREQUENCY_2; - if (ChannelLayout & NATIVE) + } + if (ChannelLayout & NATIVE) { Channel |= AV_CH_LAYOUT_NATIVE; - if (ChannelLayout & MONO) + } + if (ChannelLayout & MONO) { Channel |= AV_CH_LAYOUT_MONO; - if (ChannelLayout & STEREO) + } + if (ChannelLayout & STEREO) { Channel |= AV_CH_LAYOUT_STEREO; - if (ChannelLayout & _2POINT1) + } + if (ChannelLayout & _2POINT1) { Channel |= AV_CH_LAYOUT_2POINT1; - if (ChannelLayout & _2_1) + } + if (ChannelLayout & _2_1) { Channel |= AV_CH_LAYOUT_2_1; - if (ChannelLayout & SURROUND) + } + if (ChannelLayout & SURROUND) { Channel |= AV_CH_LAYOUT_SURROUND; - if (ChannelLayout & _3POINT1) + } + if (ChannelLayout & _3POINT1) { Channel |= AV_CH_LAYOUT_3POINT1; - if (ChannelLayout & _4POINT0) + } + if (ChannelLayout & _4POINT0) { Channel |= AV_CH_LAYOUT_4POINT0; - if (ChannelLayout & _4POINT1) + } + if (ChannelLayout & _4POINT1) { Channel |= AV_CH_LAYOUT_4POINT1; - if (ChannelLayout & _2_2) + } + if (ChannelLayout & _2_2) { Channel |= AV_CH_LAYOUT_2_2; - if (ChannelLayout & QUAD) + } + if (ChannelLayout & QUAD) { Channel |= AV_CH_LAYOUT_QUAD; - if (ChannelLayout & _5POINT0) + } + if (ChannelLayout & _5POINT0) { Channel |= AV_CH_LAYOUT_5POINT0; - if (ChannelLayout & _5POINT1) + } + if (ChannelLayout & _5POINT1) { Channel |= AV_CH_LAYOUT_5POINT1; - if (ChannelLayout & _5POINT0_BACK) + } + if (ChannelLayout & _5POINT0_BACK) { Channel |= AV_CH_LAYOUT_5POINT0_BACK; - if (ChannelLayout & _5POINT1_BACK) + } + if (ChannelLayout & _5POINT1_BACK) { Channel |= AV_CH_LAYOUT_5POINT1_BACK; - if (ChannelLayout & _6POINT0) + } + if (ChannelLayout & _6POINT0) { Channel |= AV_CH_LAYOUT_6POINT0; - if (ChannelLayout & _6POINT0_FRONT) + } + if (ChannelLayout & _6POINT0_FRONT) { Channel |= AV_CH_LAYOUT_6POINT0_FRONT; - if (ChannelLayout & HEXAGONAL) + } + if (ChannelLayout & HEXAGONAL) { Channel |= AV_CH_LAYOUT_HEXAGONAL; - if (ChannelLayout & _6POINT1) + } + if (ChannelLayout & _6POINT1) { Channel |= AV_CH_LAYOUT_6POINT1; - if (ChannelLayout & _6POINT1_BACK) + } + if (ChannelLayout & _6POINT1_BACK) { Channel |= AV_CH_LAYOUT_6POINT1_BACK; - if (ChannelLayout & _6POINT1_FRONT) + } + if (ChannelLayout & _6POINT1_FRONT) { Channel |= AV_CH_LAYOUT_6POINT1_FRONT; - if (ChannelLayout & _7POINT0) + } + if (ChannelLayout & _7POINT0) { Channel |= AV_CH_LAYOUT_7POINT0; - if (ChannelLayout & _7POINT0_FRONT) + } + if (ChannelLayout & _7POINT0_FRONT) { Channel |= AV_CH_LAYOUT_7POINT0_FRONT; - if (ChannelLayout & _7POINT1) + } + if (ChannelLayout & _7POINT1) { Channel |= AV_CH_LAYOUT_7POINT1; - if (ChannelLayout & _7POINT1_WIDE) + } + if (ChannelLayout & _7POINT1_WIDE) { Channel |= AV_CH_LAYOUT_7POINT1_WIDE; - if (ChannelLayout & _7POINT1_WIDE_BACK) + } + if (ChannelLayout & _7POINT1_WIDE_BACK) { Channel |= AV_CH_LAYOUT_7POINT1_WIDE_BACK; - if (ChannelLayout & OCTAGONAL) + } + if (ChannelLayout & OCTAGONAL) { Channel |= AV_CH_LAYOUT_OCTAGONAL; - if (ChannelLayout & HEXADECAGONAL) + } + if (ChannelLayout & HEXADECAGONAL) { Channel |= AV_CH_LAYOUT_HEXADECAGONAL; - if (ChannelLayout & STEREO_DOWNMIX) + } + if (ChannelLayout & STEREO_DOWNMIX) { Channel |= AV_CH_LAYOUT_STEREO_DOWNMIX; + } return Channel; } // Perfect Logic :) static uint64_t intoChannelLayoutID(uint64_t ChannelLayout) { uint64_t Channel = 0; - if ((ChannelLayout & AV_CH_FRONT_LEFT) == AV_CH_FRONT_LEFT) + if ((ChannelLayout & AV_CH_FRONT_LEFT) == AV_CH_FRONT_LEFT) { Channel |= FRONT_LEFT; - if ((ChannelLayout & AV_CH_FRONT_RIGHT) == AV_CH_FRONT_RIGHT) + } + if ((ChannelLayout & AV_CH_FRONT_RIGHT) == AV_CH_FRONT_RIGHT) { Channel |= FRONT_RIGHT; - if ((ChannelLayout & AV_CH_FRONT_CENTER) == AV_CH_FRONT_CENTER) + } + if ((ChannelLayout & AV_CH_FRONT_CENTER) == AV_CH_FRONT_CENTER) { Channel |= FRONT_CENTER; - if ((ChannelLayout & AV_CH_LOW_FREQUENCY) == AV_CH_LOW_FREQUENCY) + } + if ((ChannelLayout & AV_CH_LOW_FREQUENCY) == AV_CH_LOW_FREQUENCY) { Channel |= LOW_FREQUENCY; - if ((ChannelLayout & AV_CH_BACK_LEFT) == AV_CH_BACK_LEFT) + } + if ((ChannelLayout & AV_CH_BACK_LEFT) == AV_CH_BACK_LEFT) { Channel |= BACK_LEFT; - if ((ChannelLayout & AV_CH_BACK_RIGHT) == AV_CH_BACK_RIGHT) + } + if ((ChannelLayout & AV_CH_BACK_RIGHT) == AV_CH_BACK_RIGHT) { Channel |= BACK_RIGHT; + } if ((ChannelLayout & AV_CH_FRONT_LEFT_OF_CENTER) == - AV_CH_FRONT_LEFT_OF_CENTER) + AV_CH_FRONT_LEFT_OF_CENTER) { Channel |= FRONT_LEFT_OF_CENTER; + } if ((ChannelLayout & AV_CH_FRONT_RIGHT_OF_CENTER) == - AV_CH_FRONT_RIGHT_OF_CENTER) + AV_CH_FRONT_RIGHT_OF_CENTER) { Channel |= FRONT_RIGHT_OF_CENTER; - if ((ChannelLayout & AV_CH_BACK_CENTER) == AV_CH_BACK_CENTER) + } + if ((ChannelLayout & AV_CH_BACK_CENTER) == AV_CH_BACK_CENTER) { Channel |= BACK_CENTER; - if ((ChannelLayout & AV_CH_SIDE_LEFT) == AV_CH_SIDE_LEFT) + } + if ((ChannelLayout & AV_CH_SIDE_LEFT) == AV_CH_SIDE_LEFT) { Channel |= SIDE_LEFT; - if ((ChannelLayout & AV_CH_SIDE_RIGHT) == AV_CH_SIDE_RIGHT) + } + if ((ChannelLayout & AV_CH_SIDE_RIGHT) == AV_CH_SIDE_RIGHT) { Channel |= SIDE_RIGHT; - if ((ChannelLayout & AV_CH_TOP_CENTER) == AV_CH_TOP_CENTER) + } + if ((ChannelLayout & AV_CH_TOP_CENTER) == AV_CH_TOP_CENTER) { Channel |= TOP_CENTER; - if ((ChannelLayout & AV_CH_TOP_FRONT_LEFT) == AV_CH_TOP_FRONT_LEFT) + } + if ((ChannelLayout & AV_CH_TOP_FRONT_LEFT) == AV_CH_TOP_FRONT_LEFT) { Channel |= TOP_FRONT_LEFT; - if ((ChannelLayout & AV_CH_TOP_FRONT_CENTER) == AV_CH_TOP_FRONT_CENTER) + } + if ((ChannelLayout & AV_CH_TOP_FRONT_CENTER) == AV_CH_TOP_FRONT_CENTER) { Channel |= TOP_FRONT_CENTER; - if ((ChannelLayout & AV_CH_TOP_FRONT_RIGHT) == AV_CH_TOP_FRONT_RIGHT) + } + if ((ChannelLayout & AV_CH_TOP_FRONT_RIGHT) == AV_CH_TOP_FRONT_RIGHT) { Channel |= TOP_FRONT_RIGHT; - if ((ChannelLayout & AV_CH_TOP_BACK_LEFT) == AV_CH_TOP_BACK_LEFT) + } + if ((ChannelLayout & AV_CH_TOP_BACK_LEFT) == AV_CH_TOP_BACK_LEFT) { Channel |= TOP_BACK_LEFT; - if ((ChannelLayout & AV_CH_TOP_BACK_CENTER) == AV_CH_TOP_BACK_CENTER) + } + if ((ChannelLayout & AV_CH_TOP_BACK_CENTER) == AV_CH_TOP_BACK_CENTER) { Channel |= TOP_BACK_CENTER; - if ((ChannelLayout & AV_CH_TOP_BACK_RIGHT) == AV_CH_TOP_BACK_RIGHT) + } + if ((ChannelLayout & AV_CH_TOP_BACK_RIGHT) == AV_CH_TOP_BACK_RIGHT) { Channel |= TOP_BACK_RIGHT; - if ((ChannelLayout & AV_CH_STEREO_LEFT) == AV_CH_STEREO_LEFT) + } + if ((ChannelLayout & AV_CH_STEREO_LEFT) == AV_CH_STEREO_LEFT) { Channel |= STEREO_LEFT; - if ((ChannelLayout & AV_CH_STEREO_RIGHT) == AV_CH_STEREO_RIGHT) + } + if ((ChannelLayout & AV_CH_STEREO_RIGHT) == AV_CH_STEREO_RIGHT) { Channel |= STEREO_RIGHT; - if ((ChannelLayout & AV_CH_WIDE_LEFT) == AV_CH_WIDE_LEFT) + } + if ((ChannelLayout & AV_CH_WIDE_LEFT) == AV_CH_WIDE_LEFT) { Channel |= WIDE_LEFT; - if ((ChannelLayout & AV_CH_WIDE_RIGHT) == AV_CH_WIDE_RIGHT) + } + if ((ChannelLayout & AV_CH_WIDE_RIGHT) == AV_CH_WIDE_RIGHT) { Channel |= WIDE_RIGHT; + } if ((ChannelLayout & AV_CH_SURROUND_DIRECT_LEFT) == - AV_CH_SURROUND_DIRECT_LEFT) + AV_CH_SURROUND_DIRECT_LEFT) { Channel |= SURROUND_DIRECT_LEFT; + } if ((ChannelLayout & AV_CH_SURROUND_DIRECT_RIGHT) == - AV_CH_SURROUND_DIRECT_RIGHT) + AV_CH_SURROUND_DIRECT_RIGHT) { Channel |= SURROUND_DIRECT_RIGHT; - if ((ChannelLayout & AV_CH_LOW_FREQUENCY_2) == AV_CH_LOW_FREQUENCY_2) + } + if ((ChannelLayout & AV_CH_LOW_FREQUENCY_2) == AV_CH_LOW_FREQUENCY_2) { Channel |= LOW_FREQUENCY_2; + } // Channel Mask C; - if ((ChannelLayout & AV_CH_LAYOUT_NATIVE) == AV_CH_LAYOUT_NATIVE) + if ((ChannelLayout & AV_CH_LAYOUT_NATIVE) == AV_CH_LAYOUT_NATIVE) { Channel |= NATIVE; - if ((ChannelLayout & AV_CH_LAYOUT_MONO) == AV_CH_LAYOUT_MONO) + } + if ((ChannelLayout & AV_CH_LAYOUT_MONO) == AV_CH_LAYOUT_MONO) { Channel |= MONO; - if ((ChannelLayout & AV_CH_LAYOUT_STEREO) == AV_CH_LAYOUT_STEREO) + } + if ((ChannelLayout & AV_CH_LAYOUT_STEREO) == AV_CH_LAYOUT_STEREO) { Channel |= STEREO; - if ((ChannelLayout & AV_CH_LAYOUT_2POINT1) == AV_CH_LAYOUT_2POINT1) + } + if ((ChannelLayout & AV_CH_LAYOUT_2POINT1) == AV_CH_LAYOUT_2POINT1) { Channel |= _2POINT1; - if ((ChannelLayout & AV_CH_LAYOUT_2_1) == AV_CH_LAYOUT_2_1) + } + if ((ChannelLayout & AV_CH_LAYOUT_2_1) == AV_CH_LAYOUT_2_1) { Channel |= _2_1; - if ((ChannelLayout & AV_CH_LAYOUT_SURROUND) == AV_CH_LAYOUT_SURROUND) + } + if ((ChannelLayout & AV_CH_LAYOUT_SURROUND) == AV_CH_LAYOUT_SURROUND) { Channel |= SURROUND; - if ((ChannelLayout & AV_CH_LAYOUT_3POINT1) == AV_CH_LAYOUT_3POINT1) + } + if ((ChannelLayout & AV_CH_LAYOUT_3POINT1) == AV_CH_LAYOUT_3POINT1) { Channel |= _3POINT1; - if ((ChannelLayout & AV_CH_LAYOUT_4POINT0) == AV_CH_LAYOUT_4POINT0) + } + if ((ChannelLayout & AV_CH_LAYOUT_4POINT0) == AV_CH_LAYOUT_4POINT0) { Channel |= _4POINT0; - if ((ChannelLayout & AV_CH_LAYOUT_4POINT1) == AV_CH_LAYOUT_4POINT1) + } + if ((ChannelLayout & AV_CH_LAYOUT_4POINT1) == AV_CH_LAYOUT_4POINT1) { Channel |= _4POINT1; - if ((ChannelLayout & AV_CH_LAYOUT_2_2) == AV_CH_LAYOUT_2_2) + } + if ((ChannelLayout & AV_CH_LAYOUT_2_2) == AV_CH_LAYOUT_2_2) { Channel |= _2_2; - if ((ChannelLayout & AV_CH_LAYOUT_QUAD) == AV_CH_LAYOUT_QUAD) + } + if ((ChannelLayout & AV_CH_LAYOUT_QUAD) == AV_CH_LAYOUT_QUAD) { Channel |= QUAD; - if ((ChannelLayout & AV_CH_LAYOUT_5POINT0) == AV_CH_LAYOUT_5POINT0) + } + if ((ChannelLayout & AV_CH_LAYOUT_5POINT0) == AV_CH_LAYOUT_5POINT0) { Channel |= _5POINT0; - if ((ChannelLayout & AV_CH_LAYOUT_5POINT1) == AV_CH_LAYOUT_5POINT1) + } + if ((ChannelLayout & AV_CH_LAYOUT_5POINT1) == AV_CH_LAYOUT_5POINT1) { Channel |= _5POINT1; + } if ((ChannelLayout & AV_CH_LAYOUT_5POINT0_BACK) == - AV_CH_LAYOUT_5POINT0_BACK) + AV_CH_LAYOUT_5POINT0_BACK) { Channel |= _5POINT0_BACK; + } if ((ChannelLayout & AV_CH_LAYOUT_5POINT1_BACK) == - AV_CH_LAYOUT_5POINT1_BACK) + AV_CH_LAYOUT_5POINT1_BACK) { Channel |= _5POINT1_BACK; - if ((ChannelLayout & AV_CH_LAYOUT_6POINT0) == AV_CH_LAYOUT_6POINT0) + } + if ((ChannelLayout & AV_CH_LAYOUT_6POINT0) == AV_CH_LAYOUT_6POINT0) { Channel |= _6POINT0; + } if ((ChannelLayout & AV_CH_LAYOUT_6POINT0_FRONT) == - AV_CH_LAYOUT_6POINT0_FRONT) + AV_CH_LAYOUT_6POINT0_FRONT) { Channel |= _6POINT0_FRONT; - if ((ChannelLayout & AV_CH_LAYOUT_HEXAGONAL) == AV_CH_LAYOUT_HEXAGONAL) + } + if ((ChannelLayout & AV_CH_LAYOUT_HEXAGONAL) == AV_CH_LAYOUT_HEXAGONAL) { Channel |= HEXAGONAL; - if ((ChannelLayout & AV_CH_LAYOUT_6POINT1) == AV_CH_LAYOUT_6POINT1) + } + if ((ChannelLayout & AV_CH_LAYOUT_6POINT1) == AV_CH_LAYOUT_6POINT1) { Channel |= _6POINT1; + } if ((ChannelLayout & AV_CH_LAYOUT_6POINT1_BACK) == - AV_CH_LAYOUT_6POINT1_BACK) + AV_CH_LAYOUT_6POINT1_BACK) { Channel |= _6POINT1_BACK; + } if ((ChannelLayout & AV_CH_LAYOUT_6POINT1_FRONT) == - AV_CH_LAYOUT_6POINT1_FRONT) + AV_CH_LAYOUT_6POINT1_FRONT) { Channel |= _6POINT1_FRONT; - if ((ChannelLayout & AV_CH_LAYOUT_7POINT0) == AV_CH_LAYOUT_7POINT0) + } + if ((ChannelLayout & AV_CH_LAYOUT_7POINT0) == AV_CH_LAYOUT_7POINT0) { Channel |= _7POINT0; + } if ((ChannelLayout & AV_CH_LAYOUT_7POINT0_FRONT) == - AV_CH_LAYOUT_7POINT0_FRONT) + AV_CH_LAYOUT_7POINT0_FRONT) { Channel |= _7POINT0_FRONT; - if ((ChannelLayout & AV_CH_LAYOUT_7POINT1) == AV_CH_LAYOUT_7POINT1) + } + if ((ChannelLayout & AV_CH_LAYOUT_7POINT1) == AV_CH_LAYOUT_7POINT1) { Channel |= _7POINT1; + } if ((ChannelLayout & AV_CH_LAYOUT_7POINT1_WIDE) == - AV_CH_LAYOUT_7POINT1_WIDE) + AV_CH_LAYOUT_7POINT1_WIDE) { Channel |= _7POINT1_WIDE; + } if ((ChannelLayout & AV_CH_LAYOUT_7POINT1_WIDE_BACK) == - AV_CH_LAYOUT_7POINT1_WIDE_BACK) + AV_CH_LAYOUT_7POINT1_WIDE_BACK) { Channel |= _7POINT1_WIDE_BACK; - if ((ChannelLayout & AV_CH_LAYOUT_OCTAGONAL) == AV_CH_LAYOUT_OCTAGONAL) + } + if ((ChannelLayout & AV_CH_LAYOUT_OCTAGONAL) == AV_CH_LAYOUT_OCTAGONAL) { Channel |= OCTAGONAL; + } if ((ChannelLayout & AV_CH_LAYOUT_HEXADECAGONAL) == - AV_CH_LAYOUT_HEXADECAGONAL) + AV_CH_LAYOUT_HEXADECAGONAL) { Channel |= HEXADECAGONAL; + } if ((ChannelLayout & AV_CH_LAYOUT_STEREO_DOWNMIX) == - AV_CH_LAYOUT_STEREO_DOWNMIX) + AV_CH_LAYOUT_STEREO_DOWNMIX) { Channel |= STEREO_DOWNMIX; + } return Channel; } }; diff --git a/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp b/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp index 8ac02fc1..fe70b0b4 100644 --- a/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp +++ b/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "ffmpeg_env.h" #include "avcodec/module.h" #include "avdevice/module.h" @@ -108,5 +111,6 @@ std::weak_ptr std::make_shared(); std::shared_mutex WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::Mutex; + } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/ffmpeg_env.h b/plugins/wasmedge_ffmpeg/ffmpeg_env.h index ebb8ba98..300de141 100644 --- a/plugins/wasmedge_ffmpeg/ffmpeg_env.h +++ b/plugins/wasmedge_ffmpeg/ffmpeg_env.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "bindings.h" diff --git a/plugins/wasmedge_ffmpeg/swresample/module.cpp b/plugins/wasmedge_ffmpeg/swresample/module.cpp index 00d617db..5b5f9867 100644 --- a/plugins/wasmedge_ffmpeg/swresample/module.cpp +++ b/plugins/wasmedge_ffmpeg/swresample/module.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "module.h" #include "swresample_func.h" @@ -9,7 +12,6 @@ namespace SWResample { WasmEdgeFFmpegSWResampleModule::WasmEdgeFFmpegSWResampleModule( std::shared_ptr Env) : ModuleInstance("wasmedge_ffmpeg_swresample") { - addHostFunc("wasmedge_ffmpeg_swresample_swresample_version", std::make_unique(Env)); addHostFunc("wasmedge_ffmpeg_swresample_swr_get_delay", diff --git a/plugins/wasmedge_ffmpeg/swresample/module.h b/plugins/wasmedge_ffmpeg/swresample/module.h index a47d966b..0d1aa42f 100644 --- a/plugins/wasmedge_ffmpeg/swresample/module.h +++ b/plugins/wasmedge_ffmpeg/swresample/module.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_base.h b/plugins/wasmedge_ffmpeg/swresample/swresample_base.h index 574dcd20..b3bd6078 100644 --- a/plugins/wasmedge_ffmpeg/swresample/swresample_base.h +++ b/plugins/wasmedge_ffmpeg/swresample/swresample_base.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp b/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp index a28e7520..3c81578c 100644 --- a/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp +++ b/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "swresample_func.h" extern "C" { @@ -17,14 +20,12 @@ Expect SWResampleVersion::body(const Runtime::CallingFrame &) { Expect SWRGetDelay::body(const Runtime::CallingFrame &, uint32_t SWRContextId, int64_t Base) { - FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); return swr_get_delay(SWRContext, Base); } Expect SWRInit::body(const Runtime::CallingFrame &, uint32_t SWRContextId) { - FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); return swr_init(SWRContext); } @@ -35,7 +36,6 @@ SWRAllocSetOpts::body(const Runtime::CallingFrame &Frame, uint32_t SwrCtxPtr, uint32_t OutSampleFmtId, int32_t OutSampleRate, uint64_t InChLayoutId, uint32_t InSampleFmtId, int32_t InSampleRate, int32_t LogOffset) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(SwrCtxId, MemInst, uint32_t, SwrCtxPtr, "") FFMPEG_PTR_FETCH(CurrSwrCtx, *SwrCtxId, SwrContext); @@ -59,7 +59,6 @@ SWRAllocSetOpts::body(const Runtime::CallingFrame &Frame, uint32_t SwrCtxPtr, Expect AVOptSetDict::body(const Runtime::CallingFrame &, uint32_t SWRContextId, uint32_t DictId) { - FFMPEG_PTR_FETCH(SWRContext, SWRContextId, SwrContext); FFMPEG_PTR_FETCH(AvDictionary, DictId, AVDictionary *); return av_opt_set_dict(SWRContext, AvDictionary); @@ -93,7 +92,6 @@ SWResampleConfigurationLength::body(const Runtime::CallingFrame &) { Expect SWResampleConfiguration::body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); @@ -103,7 +101,6 @@ SWResampleConfiguration::body(const Runtime::CallingFrame &Frame, } Expect SWResampleLicenseLength::body(const Runtime::CallingFrame &) { - const char *License = swresample_license(); return strlen(License); } @@ -111,7 +108,6 @@ Expect SWResampleLicenseLength::body(const Runtime::CallingFrame &) { Expect SWResampleLicense::body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_func.h b/plugins/wasmedge_ffmpeg/swresample/swresample_func.h index a79ae568..5cfa38c2 100644 --- a/plugins/wasmedge_ffmpeg/swresample/swresample_func.h +++ b/plugins/wasmedge_ffmpeg/swresample/swresample_func.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/swscale/module.cpp b/plugins/wasmedge_ffmpeg/swscale/module.cpp index f33cadd4..da84ee06 100644 --- a/plugins/wasmedge_ffmpeg/swscale/module.cpp +++ b/plugins/wasmedge_ffmpeg/swscale/module.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "module.h" #include "swscale_func.h" @@ -9,7 +12,6 @@ namespace SWScale { WasmEdgeFFmpegSWScaleModule::WasmEdgeFFmpegSWScaleModule( std::shared_ptr Env) : ModuleInstance("wasmedge_ffmpeg_swscale") { - addHostFunc("wasmedge_ffmpeg_swscale_swscale_version", std::make_unique(Env)); addHostFunc("wasmedge_ffmpeg_swscale_swscale_configuration_length", diff --git a/plugins/wasmedge_ffmpeg/swscale/module.h b/plugins/wasmedge_ffmpeg/swscale/module.h index bc53ee2f..69fde0d7 100644 --- a/plugins/wasmedge_ffmpeg/swscale/module.h +++ b/plugins/wasmedge_ffmpeg/swscale/module.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/swscale/swscale_base.h b/plugins/wasmedge_ffmpeg/swscale/swscale_base.h index 32dc9cf1..3ae1b62d 100644 --- a/plugins/wasmedge_ffmpeg/swscale/swscale_base.h +++ b/plugins/wasmedge_ffmpeg/swscale/swscale_base.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "ffmpeg_env.h" diff --git a/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp b/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp index 7b126607..5da058cc 100644 --- a/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp +++ b/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "swscale_func.h" extern "C" { @@ -15,7 +18,6 @@ SwsGetContext::body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxPtr, uint32_t SrcW, uint32_t SrcH, uint32_t SrcPixFormatId, uint32_t DesW, uint32_t DesH, uint32_t DesPixFormatId, int32_t Flags, uint32_t SrcFilterId, uint32_t DesFilterId) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(SwsCtxId, MemInst, uint32_t, SwsCtxPtr, "Failed when accessing the return SWSContext Memory"sv) @@ -31,15 +33,15 @@ SwsGetContext::body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxPtr, SwsCtx = sws_getContext(SrcW, SrcH, SrcPixelFormat, DesW, DesH, DestPixelFormat, Flags, SrcSwsFilter, DesSwsFilter, nullptr); // Not using param anywhere in Rust SDK. - if (SwsCtx == nullptr) + if (SwsCtx == nullptr) { return static_cast(ErrNo::InternalError); + } FFMPEG_PTR_STORE(SwsCtx, SwsCtxId); return static_cast(ErrNo::Success); } Expect SwsFreeContext::body(const Runtime::CallingFrame &, uint32_t SwsCtxId) { - FFMPEG_PTR_FETCH(SwsCtx, SwsCtxId, SwsContext) sws_freeContext(SwsCtx); FFMPEG_PTR_DELETE(SwsCtxId); @@ -49,7 +51,6 @@ Expect SwsFreeContext::body(const Runtime::CallingFrame &, Expect SwsScale::body(const Runtime::CallingFrame &, uint32_t SwsCtxId, uint32_t InputFrameId, int32_t SrcSliceY, int32_t SrcSliceH, uint32_t OutputFrameId) { - FFMPEG_PTR_FETCH(SwsCtx, SwsCtxId, SwsContext); FFMPEG_PTR_FETCH(InputFrame, InputFrameId, AVFrame); FFMPEG_PTR_FETCH(OutputFrame, OutputFrameId, AVFrame); @@ -62,7 +63,6 @@ Expect SwsGetCachedContext::body( uint32_t SwsCtxId, uint32_t SrcW, uint32_t SrcH, uint32_t SrcPixFormatId, uint32_t DesW, uint32_t DesH, uint32_t DesPixFormatId, int32_t Flags, uint32_t SrcFilterId, uint32_t DesFilterId) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(SwsCachedCtxId, MemInst, uint32_t, SwsCachedCtxPtr, "") @@ -78,8 +78,9 @@ Expect SwsGetCachedContext::body( SwsCachedCtx = sws_getCachedContext(SwsCtx, SrcW, SrcH, SrcPixelFormat, DesW, DesH, DestPixelFormat, Flags, SrcSwsFilter, DesSwsFilter, nullptr); - if (SwsCachedCtx == nullptr) + if (SwsCachedCtx == nullptr) { return static_cast(ErrNo::InternalError); + } FFMPEG_PTR_STORE(SwsCachedCtx, SwsCachedCtxId); return static_cast(ErrNo::Success); @@ -111,22 +112,21 @@ Expect SwsGetDefaultFilter::body( const Runtime::CallingFrame &Frame, uint32_t SwsFilterPtr, float LumaGBlur, float ChromaGBlur, float LumaSharpen, float ChromaSharpen, float ChromaHShift, float ChromaVShift, int32_t Verbose) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(SwsFilterId, MemInst, uint32_t, SwsFilterPtr, "") SwsFilter *Filter = sws_getDefaultFilter(LumaGBlur, ChromaGBlur, LumaSharpen, ChromaSharpen, ChromaHShift, ChromaVShift, Verbose); - if (Filter == nullptr) + if (Filter == nullptr) { return static_cast(ErrNo::InternalError); + } FFMPEG_PTR_STORE(Filter, SwsFilterId); return static_cast(ErrNo::Success); } Expect SwsGetLumaH::body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, uint32_t SwsVectorPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); @@ -138,7 +138,6 @@ Expect SwsGetLumaH::body(const Runtime::CallingFrame &Frame, Expect SwsGetLumaV::body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, uint32_t SwsVectorPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); @@ -151,7 +150,6 @@ Expect SwsGetLumaV::body(const Runtime::CallingFrame &Frame, Expect SwsGetChromaH::body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, uint32_t SwsVectorPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); @@ -164,7 +162,6 @@ Expect SwsGetChromaH::body(const Runtime::CallingFrame &Frame, Expect SwsGetChromaV::body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, uint32_t SwsVectorPtr) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); @@ -176,7 +173,6 @@ Expect SwsGetChromaV::body(const Runtime::CallingFrame &Frame, Expect SwsFreeFilter::body(const Runtime::CallingFrame &, uint32_t SwsFilterId) { - FFMPEG_PTR_FETCH(Filter, SwsFilterId, SwsFilter); sws_freeFilter(Filter); FFMPEG_PTR_DELETE(SwsFilterId); @@ -185,7 +181,6 @@ Expect SwsFreeFilter::body(const Runtime::CallingFrame &, Expect SwsAllocVec::body(const Runtime::CallingFrame &Frame, uint32_t SwsVectorPtr, int32_t Length) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") @@ -197,7 +192,6 @@ Expect SwsAllocVec::body(const Runtime::CallingFrame &Frame, Expect SwsGetGaussianVec::body(const Runtime::CallingFrame &Frame, uint32_t SwsVectorPtr, double Variance, double Quality) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_PTR_CHECK(SwsVectorId, MemInst, uint32_t, SwsVectorPtr, "") @@ -208,7 +202,6 @@ Expect SwsGetGaussianVec::body(const Runtime::CallingFrame &Frame, Expect SwsScaleVec::body(const Runtime::CallingFrame &, uint32_t SwsVectorId, double Scalar) { - FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); sws_scaleVec(Vector, Scalar); return static_cast(ErrNo::Success); @@ -216,7 +209,6 @@ Expect SwsScaleVec::body(const Runtime::CallingFrame &, Expect SwsNormalizeVec::body(const Runtime::CallingFrame &, uint32_t SwsVectorId, double Height) { - FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); sws_normalizeVec(Vector, Height); return static_cast(ErrNo::Success); @@ -224,7 +216,6 @@ Expect SwsNormalizeVec::body(const Runtime::CallingFrame &, Expect SwsGetCoeffVecLength::body(const Runtime::CallingFrame &, uint32_t SwsVectorId) { - FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); return Vector->length * sizeof(double); // Getting the size in uint_8* (Cuz Passing uint8_t* @@ -234,7 +225,6 @@ Expect SwsGetCoeffVecLength::body(const Runtime::CallingFrame &, Expect SwsGetCoeff::body(const Runtime::CallingFrame &Frame, uint32_t SwsVectorId, uint32_t CoeffBufPtr, uint32_t Len) { - MEMINST_CHECK(MemInst, Frame, 0) MEM_SPAN_CHECK(Buffer, MemInst, uint8_t, CoeffBufPtr, Len, ""); FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); @@ -246,7 +236,6 @@ Expect SwsGetCoeff::body(const Runtime::CallingFrame &Frame, Expect SwsFreeVec::body(const Runtime::CallingFrame &, uint32_t SwsVectorId) { - FFMPEG_PTR_FETCH(Vector, SwsVectorId, SwsVector); sws_freeVec(Vector); FFMPEG_PTR_DELETE(SwsVectorId); @@ -266,7 +255,6 @@ SwscaleConfigurationLength::body(const Runtime::CallingFrame &) { Expect SwscaleConfiguration::body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); @@ -276,14 +264,12 @@ Expect SwscaleConfiguration::body(const Runtime::CallingFrame &Frame, } Expect SwscaleLicenseLength::body(const Runtime::CallingFrame &) { - const char *License = swscale_license(); return strlen(License); } Expect SwscaleLicense::body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen) { - MEMINST_CHECK(MemInst, Frame, 0); MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); diff --git a/plugins/wasmedge_ffmpeg/swscale/swscale_func.h b/plugins/wasmedge_ffmpeg/swscale/swscale_func.h index dfb7ffbf..30ce313d 100644 --- a/plugins/wasmedge_ffmpeg/swscale/swscale_func.h +++ b/plugins/wasmedge_ffmpeg/swscale/swscale_func.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "runtime/callingframe.h" diff --git a/plugins/wasmedge_image/CMakeLists.txt b/plugins/wasmedge_image/CMakeLists.txt index 01762dbb..bb4b3d1a 100644 --- a/plugins/wasmedge_image/CMakeLists.txt +++ b/plugins/wasmedge_image/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_library(wasmedgePluginWasmEdgeImage SHARED diff --git a/plugins/wasmedge_image/image_base.h b/plugins/wasmedge_image/image_base.h index 733cdf18..47f501b2 100644 --- a/plugins/wasmedge_image/image_base.h +++ b/plugins/wasmedge_image/image_base.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_image/image_env.cpp b/plugins/wasmedge_image/image_env.cpp index a1f104ae..12ab8886 100644 --- a/plugins/wasmedge_image/image_env.cpp +++ b/plugins/wasmedge_image/image_env.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "image_env.h" #include "image_module.h" diff --git a/plugins/wasmedge_image/image_env.h b/plugins/wasmedge_image/image_env.h index b2b59c2c..837b7086 100644 --- a/plugins/wasmedge_image/image_env.h +++ b/plugins/wasmedge_image/image_env.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_image/image_func.cpp b/plugins/wasmedge_image/image_func.cpp index 0088ebff..785a8163 100644 --- a/plugins/wasmedge_image/image_func.cpp +++ b/plugins/wasmedge_image/image_func.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "image_func.h" diff --git a/plugins/wasmedge_image/image_func.h b/plugins/wasmedge_image/image_func.h index 099aa1a4..9b18c3d2 100644 --- a/plugins/wasmedge_image/image_func.h +++ b/plugins/wasmedge_image/image_func.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_image/image_module.cpp b/plugins/wasmedge_image/image_module.cpp index 89ad74c5..745cbb56 100644 --- a/plugins/wasmedge_image/image_module.cpp +++ b/plugins/wasmedge_image/image_module.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "image_module.h" #include "image_func.h" diff --git a/plugins/wasmedge_image/image_module.h b/plugins/wasmedge_image/image_module.h index a23b05f0..8d4f42b4 100644 --- a/plugins/wasmedge_image/image_module.h +++ b/plugins/wasmedge_image/image_module.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_opencvmini/CMakeLists.txt b/plugins/wasmedge_opencvmini/CMakeLists.txt index 6949132d..577ed424 100644 --- a/plugins/wasmedge_opencvmini/CMakeLists.txt +++ b/plugins/wasmedge_opencvmini/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2023 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC find_package(OpenCV 4 REQUIRED) diff --git a/plugins/wasmedge_opencvmini/opencvmini_base.h b/plugins/wasmedge_opencvmini/opencvmini_base.h index 1a4dba74..c9bc9d64 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_base.h +++ b/plugins/wasmedge_opencvmini/opencvmini_base.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2023 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_opencvmini/opencvmini_env.cpp b/plugins/wasmedge_opencvmini/opencvmini_env.cpp index d5211492..499c69e0 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_env.cpp +++ b/plugins/wasmedge_opencvmini/opencvmini_env.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2023 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "opencvmini_env.h" #include "opencvmini_module.h" diff --git a/plugins/wasmedge_opencvmini/opencvmini_env.h b/plugins/wasmedge_opencvmini/opencvmini_env.h index f2e606cb..98c8b785 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_env.h +++ b/plugins/wasmedge_opencvmini/opencvmini_env.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2023 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_opencvmini/opencvmini_func.cpp b/plugins/wasmedge_opencvmini/opencvmini_func.cpp index 7173a0ed..112d2a08 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_func.cpp +++ b/plugins/wasmedge_opencvmini/opencvmini_func.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2023 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "opencvmini_func.h" #include "common/defines.h" diff --git a/plugins/wasmedge_opencvmini/opencvmini_func.h b/plugins/wasmedge_opencvmini/opencvmini_func.h index 80953b7f..e2c80d76 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_func.h +++ b/plugins/wasmedge_opencvmini/opencvmini_func.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2023 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_opencvmini/opencvmini_module.cpp b/plugins/wasmedge_opencvmini/opencvmini_module.cpp index 5aac89db..b4c60ed5 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_module.cpp +++ b/plugins/wasmedge_opencvmini/opencvmini_module.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2023 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "opencvmini_module.h" #include "opencvmini_func.h" diff --git a/plugins/wasmedge_opencvmini/opencvmini_module.h b/plugins/wasmedge_opencvmini/opencvmini_module.h index 0175110d..3d9296a2 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_module.h +++ b/plugins/wasmedge_opencvmini/opencvmini_module.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2023 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt index 0355b905..5ae1dc1a 100644 --- a/plugins/wasmedge_process/CMakeLists.txt +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_library(wasmedgePluginWasmEdgeProcess SHARED diff --git a/plugins/wasmedge_process/processbase.h b/plugins/wasmedge_process/processbase.h index 12c39357..f7d9fe6e 100644 --- a/plugins/wasmedge_process/processbase.h +++ b/plugins/wasmedge_process/processbase.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_process/processenv.cpp b/plugins/wasmedge_process/processenv.cpp index 32a318f7..774989d6 100644 --- a/plugins/wasmedge_process/processenv.cpp +++ b/plugins/wasmedge_process/processenv.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "processenv.h" #include "processmodule.h" diff --git a/plugins/wasmedge_process/processenv.h b/plugins/wasmedge_process/processenv.h index a7dc1161..9b3626dd 100644 --- a/plugins/wasmedge_process/processenv.h +++ b/plugins/wasmedge_process/processenv.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_process/processfunc.cpp b/plugins/wasmedge_process/processfunc.cpp index fe4e85b7..ded6e971 100644 --- a/plugins/wasmedge_process/processfunc.cpp +++ b/plugins/wasmedge_process/processfunc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "processfunc.h" diff --git a/plugins/wasmedge_process/processfunc.h b/plugins/wasmedge_process/processfunc.h index f23a4c41..9746d433 100644 --- a/plugins/wasmedge_process/processfunc.h +++ b/plugins/wasmedge_process/processfunc.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_process/processmodule.cpp b/plugins/wasmedge_process/processmodule.cpp index 163a1cf2..613be81d 100644 --- a/plugins/wasmedge_process/processmodule.cpp +++ b/plugins/wasmedge_process/processmodule.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "processmodule.h" #include "processfunc.h" diff --git a/plugins/wasmedge_process/processmodule.h b/plugins/wasmedge_process/processmodule.h index 0a8e5bac..6482ee68 100644 --- a/plugins/wasmedge_process/processmodule.h +++ b/plugins/wasmedge_process/processmodule.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index cc94f40a..11097974 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC # setup stable diffusion message(STATUS "Downloading stable diffusion source") diff --git a/plugins/wasmedge_stablediffusion/sd_base.h b/plugins/wasmedge_stablediffusion/sd_base.h index c8ae12f0..5ba7441c 100644 --- a/plugins/wasmedge_stablediffusion/sd_base.h +++ b/plugins/wasmedge_stablediffusion/sd_base.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp index 04753eec..73c618fa 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.cpp +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "sd_env.h" #include "sd_module.h" diff --git a/plugins/wasmedge_stablediffusion/sd_env.h b/plugins/wasmedge_stablediffusion/sd_env.h index 6b64b28a..4c3a278c 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.h +++ b/plugins/wasmedge_stablediffusion/sd_env.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index b9ccf33f..7aca1854 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "sd_func.h" #include "common/spdlog.h" diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h index 0e3135cc..88166c79 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_stablediffusion/sd_module.cpp b/plugins/wasmedge_stablediffusion/sd_module.cpp index 5938db3a..a568c4ab 100644 --- a/plugins/wasmedge_stablediffusion/sd_module.cpp +++ b/plugins/wasmedge_stablediffusion/sd_module.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "sd_module.h" #include "sd_func.h" diff --git a/plugins/wasmedge_stablediffusion/sd_module.h b/plugins/wasmedge_stablediffusion/sd_module.h index e681ba06..bfc7ba72 100644 --- a/plugins/wasmedge_stablediffusion/sd_module.h +++ b/plugins/wasmedge_stablediffusion/sd_module.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_tensorflow/CMakeLists.txt b/plugins/wasmedge_tensorflow/CMakeLists.txt index ba93b141..7189c2bb 100644 --- a/plugins/wasmedge_tensorflow/CMakeLists.txt +++ b/plugins/wasmedge_tensorflow/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_library(wasmedgePluginWasmEdgeTensorflow SHARED diff --git a/plugins/wasmedge_tensorflow/tensorflow_base.h b/plugins/wasmedge_tensorflow/tensorflow_base.h index 56686110..fb17fec5 100644 --- a/plugins/wasmedge_tensorflow/tensorflow_base.h +++ b/plugins/wasmedge_tensorflow/tensorflow_base.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_tensorflow/tensorflow_env.cpp b/plugins/wasmedge_tensorflow/tensorflow_env.cpp index 04fb5507..98312b14 100644 --- a/plugins/wasmedge_tensorflow/tensorflow_env.cpp +++ b/plugins/wasmedge_tensorflow/tensorflow_env.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "tensorflow_env.h" #include "tensorflow_module.h" diff --git a/plugins/wasmedge_tensorflow/tensorflow_env.h b/plugins/wasmedge_tensorflow/tensorflow_env.h index 46160dcb..5fd4ef3c 100644 --- a/plugins/wasmedge_tensorflow/tensorflow_env.h +++ b/plugins/wasmedge_tensorflow/tensorflow_env.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_tensorflow/tensorflow_func.cpp b/plugins/wasmedge_tensorflow/tensorflow_func.cpp index 6b4ba440..d86143f4 100644 --- a/plugins/wasmedge_tensorflow/tensorflow_func.cpp +++ b/plugins/wasmedge_tensorflow/tensorflow_func.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "tensorflow_func.h" diff --git a/plugins/wasmedge_tensorflow/tensorflow_func.h b/plugins/wasmedge_tensorflow/tensorflow_func.h index 6a2786dc..54b5e76a 100644 --- a/plugins/wasmedge_tensorflow/tensorflow_func.h +++ b/plugins/wasmedge_tensorflow/tensorflow_func.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_tensorflow/tensorflow_module.cpp b/plugins/wasmedge_tensorflow/tensorflow_module.cpp index 73b079c4..f0703e45 100644 --- a/plugins/wasmedge_tensorflow/tensorflow_module.cpp +++ b/plugins/wasmedge_tensorflow/tensorflow_module.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "tensorflow_module.h" #include "tensorflow_func.h" diff --git a/plugins/wasmedge_tensorflow/tensorflow_module.h b/plugins/wasmedge_tensorflow/tensorflow_module.h index ae60330e..dfb96f9d 100644 --- a/plugins/wasmedge_tensorflow/tensorflow_module.h +++ b/plugins/wasmedge_tensorflow/tensorflow_module.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_tensorflowlite/CMakeLists.txt b/plugins/wasmedge_tensorflowlite/CMakeLists.txt index 62e4707e..56cade99 100644 --- a/plugins/wasmedge_tensorflowlite/CMakeLists.txt +++ b/plugins/wasmedge_tensorflowlite/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_library(wasmedgePluginWasmEdgeTensorflowLite SHARED diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_base.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_base.h index 6b0765ae..075a46f7 100644 --- a/plugins/wasmedge_tensorflowlite/tensorflowlite_base.h +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_base.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp index a6d194ea..12161a6d 100644 --- a/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "tensorflowlite_env.h" #include "tensorflowlite_module.h" diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h index 188ab9c2..02da4069 100644 --- a/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_env.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp index 0976dd13..48782592 100644 --- a/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "tensorflowlite_func.h" diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_func.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.h index d8e9d5b5..90e29f0b 100644 --- a/plugins/wasmedge_tensorflowlite/tensorflowlite_func.h +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_func.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_module.cpp b/plugins/wasmedge_tensorflowlite/tensorflowlite_module.cpp index 5f880229..0681849a 100644 --- a/plugins/wasmedge_tensorflowlite/tensorflowlite_module.cpp +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_module.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "tensorflowlite_module.h" #include "tensorflowlite_func.h" diff --git a/plugins/wasmedge_tensorflowlite/tensorflowlite_module.h b/plugins/wasmedge_tensorflowlite/tensorflowlite_module.h index a93545bb..1f5161b7 100644 --- a/plugins/wasmedge_tensorflowlite/tensorflowlite_module.h +++ b/plugins/wasmedge_tensorflowlite/tensorflowlite_module.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_zlib/CMakeLists.txt b/plugins/wasmedge_zlib/CMakeLists.txt index f12889bf..40f7ae89 100644 --- a/plugins/wasmedge_zlib/CMakeLists.txt +++ b/plugins/wasmedge_zlib/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC find_package(ZLIB REQUIRED) diff --git a/plugins/wasmedge_zlib/zlibbase.h b/plugins/wasmedge_zlib/zlibbase.h index 11640a72..63b9a16e 100644 --- a/plugins/wasmedge_zlib/zlibbase.h +++ b/plugins/wasmedge_zlib/zlibbase.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_zlib/zlibenv.cpp b/plugins/wasmedge_zlib/zlibenv.cpp index f00a9ef7..f3e8eaa4 100644 --- a/plugins/wasmedge_zlib/zlibenv.cpp +++ b/plugins/wasmedge_zlib/zlibenv.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "zlibenv.h" #include "zlibmodule.h" diff --git a/plugins/wasmedge_zlib/zlibenv.h b/plugins/wasmedge_zlib/zlibenv.h index 190025cd..a677d98a 100644 --- a/plugins/wasmedge_zlib/zlibenv.h +++ b/plugins/wasmedge_zlib/zlibenv.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once @@ -42,9 +42,8 @@ struct WasmZStream { /* [Wasm Offset] private data object passed to zalloc and zfree */ uint32_t Opaque; - /* best guess about the data type: binary or text - for deflate, or the decoding state - for inflate */ + /* best guess about the data type: binary or text for deflate, or the decoding + state for inflate */ int32_t DataType; /* Adler-32 or CRC-32 value of the uncompressed data */ @@ -55,8 +54,8 @@ struct WasmZStream { static_assert(sizeof(WasmZStream) == 56, "WasmZStream should be 56 bytes"); /* - gzip header information passed to and from zlib routines. See RFC 1952 - for more details on the meanings of these fields. + gzip header information passed to and from zlib routines. See RFC 1952 for + more details on the meanings of these fields. */ struct WasmGZHeader { int32_t Text; /* true if compressed data believed to be text */ @@ -72,7 +71,7 @@ struct WasmGZHeader { uint32_t CommMax; /* space at comment (only when reading header) */ int32_t HCRC; /* true if there was or will be a header crc */ int32_t Done; /* true when done reading gzip header (not used - when writing a gzip file) */ + when writing a gzip file) */ }; static_assert(sizeof(WasmGZHeader) == 52, "WasmGZHeader should be 52 bytes"); diff --git a/plugins/wasmedge_zlib/zlibfunc.cpp b/plugins/wasmedge_zlib/zlibfunc.cpp index 04aeac05..5ee8d343 100644 --- a/plugins/wasmedge_zlib/zlibfunc.cpp +++ b/plugins/wasmedge_zlib/zlibfunc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "zlibfunc.h" diff --git a/plugins/wasmedge_zlib/zlibfunc.h b/plugins/wasmedge_zlib/zlibfunc.h index d276462d..b7ca1f05 100644 --- a/plugins/wasmedge_zlib/zlibfunc.h +++ b/plugins/wasmedge_zlib/zlibfunc.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/plugins/wasmedge_zlib/zlibmodule.cpp b/plugins/wasmedge_zlib/zlibmodule.cpp index ecd2fd6e..4e39eaa6 100644 --- a/plugins/wasmedge_zlib/zlibmodule.cpp +++ b/plugins/wasmedge_zlib/zlibmodule.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "zlibmodule.h" #include "zlibfunc.h" diff --git a/plugins/wasmedge_zlib/zlibmodule.h b/plugins/wasmedge_zlib/zlibmodule.h index e502993c..dd595124 100644 --- a/plugins/wasmedge_zlib/zlibmodule.h +++ b/plugins/wasmedge_zlib/zlibmodule.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 571a578d..03c014bc 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC if(WASMEDGE_PLUGIN_FFMPEG) add_subdirectory(wasmedge_ffmpeg) diff --git a/test/plugins/unittest/CMakeLists.txt b/test/plugins/unittest/CMakeLists.txt index d1614c06..8a0e2ad5 100644 --- a/test/plugins/unittest/CMakeLists.txt +++ b/test/plugins/unittest/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC # The test plugin module in C API enable_language(C) diff --git a/test/plugins/unittest/testplugin.c b/test/plugins/unittest/testplugin.c index 70eb27bd..a2be6472 100644 --- a/test/plugins/unittest/testplugin.c +++ b/test/plugins/unittest/testplugin.c @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "wasmedge/wasmedge.h" diff --git a/test/plugins/unittest/testplugin.cpp b/test/plugins/unittest/testplugin.cpp index 2c5747b6..589df415 100644 --- a/test/plugins/unittest/testplugin.cpp +++ b/test/plugins/unittest/testplugin.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "testplugin.h" #include "po/helper.h" diff --git a/test/plugins/unittest/testplugin.h b/test/plugins/unittest/testplugin.h index b6bdbd34..82376b3e 100644 --- a/test/plugins/unittest/testplugin.h +++ b/test/plugins/unittest/testplugin.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/test/plugins/unittest/unittest_c.cpp b/test/plugins/unittest/unittest_c.cpp index 5397c9ca..1943648b 100644 --- a/test/plugins/unittest/unittest_c.cpp +++ b/test/plugins/unittest/unittest_c.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/defines.h" #include "wasmedge/wasmedge.h" diff --git a/test/plugins/unittest/unittest_cpp.cpp b/test/plugins/unittest/unittest_cpp.cpp index 81af0d5e..806e855a 100644 --- a/test/plugins/unittest/unittest_cpp.cpp +++ b/test/plugins/unittest/unittest_cpp.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/defines.h" #include "runtime/callingframe.h" diff --git a/test/plugins/wasi_crypto/CMakeLists.txt b/test/plugins/wasi_crypto/CMakeLists.txt index 5e874abd..8935d066 100644 --- a/test/plugins/wasi_crypto/CMakeLists.txt +++ b/test/plugins/wasi_crypto/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_executable(wasiCryptoTests aeads.cpp diff --git a/test/plugins/wasi_crypto/aeads.cpp b/test/plugins/wasi_crypto/aeads.cpp index eefc21c7..bae54a26 100644 --- a/test/plugins/wasi_crypto/aeads.cpp +++ b/test/plugins/wasi_crypto/aeads.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "helper.h" diff --git a/test/plugins/wasi_crypto/asymmetric.cpp b/test/plugins/wasi_crypto/asymmetric.cpp index c218216a..9e9f9c9f 100644 --- a/test/plugins/wasi_crypto/asymmetric.cpp +++ b/test/plugins/wasi_crypto/asymmetric.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "helper.h" diff --git a/test/plugins/wasi_crypto/common.cpp b/test/plugins/wasi_crypto/common.cpp index d50a527c..086701cf 100644 --- a/test/plugins/wasi_crypto/common.cpp +++ b/test/plugins/wasi_crypto/common.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/func.h" #include "helper.h" diff --git a/test/plugins/wasi_crypto/hash.cpp b/test/plugins/wasi_crypto/hash.cpp index 825e0432..ac615978 100644 --- a/test/plugins/wasi_crypto/hash.cpp +++ b/test/plugins/wasi_crypto/hash.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "helper.h" diff --git a/test/plugins/wasi_crypto/helper.cpp b/test/plugins/wasi_crypto/helper.cpp index 51b84955..c0a0ce04 100644 --- a/test/plugins/wasi_crypto/helper.cpp +++ b/test/plugins/wasi_crypto/helper.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "helper.h" #include "asymmetric_common/func.h" diff --git a/test/plugins/wasi_crypto/helper.h b/test/plugins/wasi_crypto/helper.h index 9feefc44..b6512a5f 100644 --- a/test/plugins/wasi_crypto/helper.h +++ b/test/plugins/wasi_crypto/helper.h @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #pragma once diff --git a/test/plugins/wasi_crypto/kdf.cpp b/test/plugins/wasi_crypto/kdf.cpp index 6e3fa5ed..8d64fa9f 100644 --- a/test/plugins/wasi_crypto/kdf.cpp +++ b/test/plugins/wasi_crypto/kdf.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "helper.h" diff --git a/test/plugins/wasi_crypto/kx.cpp b/test/plugins/wasi_crypto/kx.cpp index de107fcb..a47b728c 100644 --- a/test/plugins/wasi_crypto/kx.cpp +++ b/test/plugins/wasi_crypto/kx.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "helper.h" diff --git a/test/plugins/wasi_crypto/mac.cpp b/test/plugins/wasi_crypto/mac.cpp index 2e8521fc..c2616e31 100644 --- a/test/plugins/wasi_crypto/mac.cpp +++ b/test/plugins/wasi_crypto/mac.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "helper.h" diff --git a/test/plugins/wasi_crypto/notimplement.cpp b/test/plugins/wasi_crypto/notimplement.cpp index b87c0b06..0b255180 100644 --- a/test/plugins/wasi_crypto/notimplement.cpp +++ b/test/plugins/wasi_crypto/notimplement.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "helper.h" diff --git a/test/plugins/wasi_crypto/signatures.cpp b/test/plugins/wasi_crypto/signatures.cpp index 8cd2f083..07deb8c4 100644 --- a/test/plugins/wasi_crypto/signatures.cpp +++ b/test/plugins/wasi_crypto/signatures.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "helper.h" diff --git a/test/plugins/wasi_logging/CMakeLists.txt b/test/plugins/wasi_logging/CMakeLists.txt index 88985294..923363a2 100644 --- a/test/plugins/wasi_logging/CMakeLists.txt +++ b/test/plugins/wasi_logging/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_executable(wasiLoggingTests wasi_logging.cpp diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index e0f5f9ed..03669f6f 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_executable(wasiNNTests wasi_nn.cpp diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 93340b40..76411a25 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "wasinnfunc.h" #include "wasinnmodule.h" diff --git a/test/plugins/wasm_bpf/CMakeLists.txt b/test/plugins/wasm_bpf/CMakeLists.txt index c91608c5..625bfafe 100644 --- a/test/plugins/wasm_bpf/CMakeLists.txt +++ b/test/plugins/wasm_bpf/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_executable(wasmBpfTests simple_map_test.cpp diff --git a/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c b/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c index a787ceaa..1049aa5f 100644 --- a/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c +++ b/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #define SEC(name) __attribute__((section(name), used)) #define __uint(name, val) int(*name)[val] diff --git a/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c b/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c index 888c6e1a..0a0b0e8d 100644 --- a/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c +++ b/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c @@ -1,5 +1,5 @@ // SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #define SEC(name) __attribute__((section(name), used)) #define __uint(name, val) int(*name)[val] diff --git a/test/plugins/wasm_bpf/simple_map_test.cpp b/test/plugins/wasm_bpf/simple_map_test.cpp index 51ef2f5a..b6c4c00c 100644 --- a/test/plugins/wasm_bpf/simple_map_test.cpp +++ b/test/plugins/wasm_bpf/simple_map_test.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/defines.h" #include "executor/executor.h" diff --git a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp index ac389934..81fc935b 100644 --- a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp +++ b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/defines.h" #include "executor/executor.h" diff --git a/test/plugins/wasm_bpf/wasm_bpf.cpp b/test/plugins/wasm_bpf/wasm_bpf.cpp index 3e5dbd0e..705015c3 100644 --- a/test/plugins/wasm_bpf/wasm_bpf.cpp +++ b/test/plugins/wasm_bpf/wasm_bpf.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "ast/type.h" #include "common/defines.h" diff --git a/test/plugins/wasmedge_ffmpeg/CMakeLists.txt b/test/plugins/wasmedge_ffmpeg/CMakeLists.txt index c35e0bc4..5b66bf2c 100644 --- a/test/plugins/wasmedge_ffmpeg/CMakeLists.txt +++ b/test/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_executable(wasmedgeFFmpegTests main.cpp diff --git a/test/plugins/wasmedge_image/CMakeLists.txt b/test/plugins/wasmedge_image/CMakeLists.txt index 83a80949..62e1bd25 100644 --- a/test/plugins/wasmedge_image/CMakeLists.txt +++ b/test/plugins/wasmedge_image/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_executable(wasmedgeImageTests wasmedge_image.cpp diff --git a/test/plugins/wasmedge_image/wasmedge_image.cpp b/test/plugins/wasmedge_image/wasmedge_image.cpp index a5ea3356..0c16db43 100644 --- a/test/plugins/wasmedge_image/wasmedge_image.cpp +++ b/test/plugins/wasmedge_image/wasmedge_image.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/defines.h" #include "image_func.h" diff --git a/test/plugins/wasmedge_opencvmini/CMakeLists.txt b/test/plugins/wasmedge_opencvmini/CMakeLists.txt index 33fc1425..9f0946f2 100644 --- a/test/plugins/wasmedge_opencvmini/CMakeLists.txt +++ b/test/plugins/wasmedge_opencvmini/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2023 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_executable(wasmedgeOpencvminiTests wasmedge_opencvmini.cpp diff --git a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp index d3b570c9..60e78db0 100644 --- a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp +++ b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/defines.h" #include "opencvmini_func.h" diff --git a/test/plugins/wasmedge_process/CMakeLists.txt b/test/plugins/wasmedge_process/CMakeLists.txt index ee34f5c1..fc389115 100644 --- a/test/plugins/wasmedge_process/CMakeLists.txt +++ b/test/plugins/wasmedge_process/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_executable(wasmedgeProcessTests wasmedge_process.cpp diff --git a/test/plugins/wasmedge_process/wasmedge_process.cpp b/test/plugins/wasmedge_process/wasmedge_process.cpp index 5f63fd90..fd322f5f 100644 --- a/test/plugins/wasmedge_process/wasmedge_process.cpp +++ b/test/plugins/wasmedge_process/wasmedge_process.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/defines.h" #include "processfunc.h" diff --git a/test/plugins/wasmedge_tensorflow/CMakeLists.txt b/test/plugins/wasmedge_tensorflow/CMakeLists.txt index aba4ee6e..9c0f6823 100644 --- a/test/plugins/wasmedge_tensorflow/CMakeLists.txt +++ b/test/plugins/wasmedge_tensorflow/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_executable(wasmedgeTensorflowTests wasmedge_tensorflow.cpp diff --git a/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp b/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp index 0b0617d3..8b53e4b3 100644 --- a/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp +++ b/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/defines.h" #include "runtime/instance/module.h" diff --git a/test/plugins/wasmedge_tensorflowlite/CMakeLists.txt b/test/plugins/wasmedge_tensorflowlite/CMakeLists.txt index 04eee39a..bbf2ff60 100644 --- a/test/plugins/wasmedge_tensorflowlite/CMakeLists.txt +++ b/test/plugins/wasmedge_tensorflowlite/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_executable(wasmedgeTensorflowLiteTests wasmedge_tensorflowlite.cpp diff --git a/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp b/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp index ebddf267..61c49219 100644 --- a/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp +++ b/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/defines.h" #include "runtime/instance/module.h" diff --git a/test/plugins/wasmedge_zlib/CMakeLists.txt b/test/plugins/wasmedge_zlib/CMakeLists.txt index 7159ab83..5b9ad3ec 100644 --- a/test/plugins/wasmedge_zlib/CMakeLists.txt +++ b/test/plugins/wasmedge_zlib/CMakeLists.txt @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC wasmedge_add_executable(wasmedgeZlibTests wasmedge_zlib.cpp diff --git a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp index 1bf5bdc3..e22f725d 100644 --- a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp +++ b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/defines.h" #include "runtime/instance/module.h" diff --git a/thirdparty/wasi_crypto/api.hpp b/thirdparty/wasi_crypto/api.hpp index 7ee8b90a..71bed3d2 100644 --- a/thirdparty/wasi_crypto/api.hpp +++ b/thirdparty/wasi_crypto/api.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2022 Second State INC +// SPDX-FileCopyrightText: 2019-2024 Second State INC /** * THIS FILE IS AUTO-GENERATED from the following files: diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index f2ef726e..b6665388 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC FROM quay.io/pypa/manylinux2014_aarch64 diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index 924ff57c..da67eae9 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC FROM quay.io/pypa/manylinux2014_x86_64 diff --git a/utils/docker/Dockerfile.manylinux_2_28-base b/utils/docker/Dockerfile.manylinux_2_28-base index ae2ea372..7166f1fa 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-base +++ b/utils/docker/Dockerfile.manylinux_2_28-base @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC ARG BASE_IMAGE FROM ${BASE_IMAGE} diff --git a/utils/docker/Dockerfile.ubuntu2104_armv7l b/utils/docker/Dockerfile.ubuntu2104_armv7l index d4984a11..3329e6f2 100644 --- a/utils/docker/Dockerfile.ubuntu2104_armv7l +++ b/utils/docker/Dockerfile.ubuntu2104_armv7l @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC FROM arm32v7/ubuntu:hirsute diff --git a/utils/docker/build-manylinux.sh b/utils/docker/build-manylinux.sh index 2c4ead7b..5e4affab 100755 --- a/utils/docker/build-manylinux.sh +++ b/utils/docker/build-manylinux.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC git config --global --add safe.directory $(pwd) diff --git a/utils/docker/build.sh b/utils/docker/build.sh index bdcea629..71a485bf 100755 --- a/utils/docker/build.sh +++ b/utils/docker/build.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC NAME=${1:+$1/}wasmedge INTERMEDIATES=() diff --git a/utils/ffmpeg/download-ffmpeg-sample-video.sh b/utils/ffmpeg/download-ffmpeg-sample-video.sh index fabbf84a..5b05b080 100644 --- a/utils/ffmpeg/download-ffmpeg-sample-video.sh +++ b/utils/ffmpeg/download-ffmpeg-sample-video.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC # The below video used is sourced from an ffmpeg-libav-tutorial repository. # Source: https://github.com/leandromoreira/ffmpeg-libav-tutorial/blob/master/LICENSE. diff --git a/utils/opencvmini/install-opencvmini.sh b/utils/opencvmini/install-opencvmini.sh index bd95a902..799f7382 100644 --- a/utils/opencvmini/install-opencvmini.sh +++ b/utils/opencvmini/install-opencvmini.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2023 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC OPENCV_VERSION=${OPENCV_VERSION:-4.8.0} wget -O opencv.zip https://github.com/opencv/opencv/archive/refs/tags/${OPENCV_VERSION}.zip diff --git a/utils/wasi-crypto/build-openssl.sh b/utils/wasi-crypto/build-openssl.sh index 547170a3..45d628bf 100755 --- a/utils/wasi-crypto/build-openssl.sh +++ b/utils/wasi-crypto/build-openssl.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC echo "Building OpenSSL for wasi-crypto..." # Get OpenSSL source diff --git a/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh index 5e3be658..0256ab83 100755 --- a/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh +++ b/utils/wasi-nn/build-wasinn-ubuntu-openvino.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC if [[ ! -v "${CMAKE_BUILD_TYPE}" ]]; then CMAKE_BUILD_TYPE=Release diff --git a/utils/wasi-nn/install-neuralspeed.sh b/utils/wasi-nn/install-neuralspeed.sh index 23838c02..8b8d53df 100644 --- a/utils/wasi-nn/install-neuralspeed.sh +++ b/utils/wasi-nn/install-neuralspeed.sh @@ -1,4 +1,6 @@ #!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC set -e echo "Installing Python library!" diff --git a/utils/wasi-nn/install-onnxruntime.sh b/utils/wasi-nn/install-onnxruntime.sh index 3ff011e3..dad1fb56 100644 --- a/utils/wasi-nn/install-onnxruntime.sh +++ b/utils/wasi-nn/install-onnxruntime.sh @@ -1,4 +1,6 @@ #!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC set -e diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh index fcc9e27b..10a3d089 100755 --- a/utils/wasi-nn/install-openvino.sh +++ b/utils/wasi-nn/install-openvino.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC set -e echo "Installing OpenVINO with version 2024.2.0" diff --git a/utils/wasi-nn/install-pytorch.sh b/utils/wasi-nn/install-pytorch.sh index faec86f2..f3fd070e 100755 --- a/utils/wasi-nn/install-pytorch.sh +++ b/utils/wasi-nn/install-pytorch.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC if [[ ! -n ${PYTORCH_VERSION} ]]; then PYTORCH_VERSION="1.8.2" diff --git a/utils/wasi-nn/test-wasinn-ubuntu-openvino.sh b/utils/wasi-nn/test-wasinn-ubuntu-openvino.sh index 49198538..39ff9d05 100755 --- a/utils/wasi-nn/test-wasinn-ubuntu-openvino.sh +++ b/utils/wasi-nn/test-wasinn-ubuntu-openvino.sh @@ -1,3 +1,7 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + ldconfig export LD_LIBRARY_PATH="$(pwd)/build/lib/api:$LD_LIBRARY_PATH" diff --git a/utils/wasi-test/run-wasi-test.sh b/utils/wasi-test/run-wasi-test.sh index 0acc1c59..478255e4 100755 --- a/utils/wasi-test/run-wasi-test.sh +++ b/utils/wasi-test/run-wasi-test.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2022 Second State INC +# SPDX-FileCopyrightText: 2019-2024 Second State INC # Test WasmEdge WASI layer. # The testcase is from https://github.com/khronosproject/wasi-test From 984d106a4b410e59adcc2ec1eaee94c544f7dd54 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 13 Aug 2024 01:52:35 +0800 Subject: [PATCH 394/623] [WASI-NN] ggml: bump llama.cpp b3567 (#3645) Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 4878108c..8e9edde7 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -72,7 +72,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3499 + GIT_TAG b3567 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) @@ -165,7 +165,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) target_link_libraries(wasmedgePluginWasiNN PRIVATE ${Python3_LIBRARIES}) target_link_directories(wasmedgePluginWasiNN PRIVATE ${Python3_RUNTIME_LIBRARY_DIRS}) else() - message(FATAL_ERROR "Can not find python3.") + message(FATAL_ERROR "Can not find python3.") endif() target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) elseif(BACKEND STREQUAL "piper") From 5011387527b73851baadc04d4a53e3b5314d1ef7 Mon Sep 17 00:00:00 2001 From: PeterD1524 Date: Tue, 13 Aug 2024 09:31:04 +0800 Subject: [PATCH 395/623] [WASI-NN] piper: update piper patch to support find_package for fmt and spdlog Signed-off-by: PeterD1524 --- plugins/wasi_nn/piper.patch | 134 +++++++++++++++++++++++++++++------- 1 file changed, 109 insertions(+), 25 deletions(-) diff --git a/plugins/wasi_nn/piper.patch b/plugins/wasi_nn/piper.patch index 3ac654d6..c4ba79e2 100644 --- a/plugins/wasi_nn/piper.patch +++ b/plugins/wasi_nn/piper.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index f96ec44..ef67ff5 100644 +index f96ec44..1e84722 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.13) @@ -27,8 +27,75 @@ index f96ec44..ef67ff5 100644 add_executable(test_piper src/cpp/test.cpp src/cpp/piper.cpp) # NOTE: external project prefix are shortened because of path length restrictions on Windows -@@ -58,59 +62,54 @@ endif() +@@ -25,7 +29,21 @@ add_executable(test_piper src/cpp/test.cpp src/cpp/piper.cpp) + # ---- fmt --- + +-if(NOT DEFINED FMT_DIR) ++set(fmt_FOUND FALSE) ++ ++if(NOT fmt_FOUND AND TARGET "fmt::fmt") ++ list(APPEND FMT_LINK_LIBRARIES "fmt::fmt") ++ set(fmt_FOUND TRUE) ++endif() ++ ++if(NOT fmt_FOUND AND NOT DEFINED FMT_DIR) ++ find_package(fmt) ++ if(fmt_FOUND) ++ list(APPEND FMT_LINK_LIBRARIES "fmt::fmt") ++ endif() ++endif() ++ ++if(NOT fmt_FOUND AND NOT DEFINED FMT_DIR) + set(FMT_VERSION "10.0.0") + set(FMT_DIR "${CMAKE_CURRENT_BINARY_DIR}/fi") + +@@ -41,11 +59,33 @@ if(NOT DEFINED FMT_DIR) + add_dependencies(test_piper fmt_external) + endif() + ++if(NOT fmt_FOUND AND DEFINED FMT_DIR) ++ list(APPEND FMT_LINK_LIBRARIES "fmt") ++ list(APPEND FMT_LINK_DIRECTORIES "${FMT_DIR}/lib") ++ list(APPEND FMT_INCLUDE_DIRECTORIES "${FMT_DIR}/include") ++ set(fmt_FOUND TRUE) ++endif() ++ + # ---- spdlog --- + +-if(NOT DEFINED SPDLOG_DIR) ++set(spdlog_FOUND FALSE) ++ ++if(NOT spdlog_FOUND AND TARGET "spdlog::spdlog") ++ list(APPEND SPDLOG_LINK_LIBRARIES "spdlog::spdlog") ++ set(spdlog_FOUND TRUE) ++endif() ++ ++if(NOT spdlog_FOUND AND NOT DEFINED SPDLOG_DIR) ++ find_package(spdlog) ++ if(spdlog_FOUND) ++ list(APPEND SPDLOG_LINK_LIBRARIES "spdlog::spdlog") ++ endif() ++endif() ++ ++if(NOT spdlog_FOUND AND NOT DEFINED SPDLOG_DIR) + set(SPDLOG_DIR "${CMAKE_CURRENT_BINARY_DIR}/si") + set(SPDLOG_VERSION "1.12.0") ++ include(ExternalProject) + ExternalProject_Add( + spdlog_external + PREFIX "${CMAKE_CURRENT_BINARY_DIR}/s" +@@ -56,81 +96,81 @@ if(NOT DEFINED SPDLOG_DIR) + add_dependencies(test_piper spdlog_external) + endif() + ++if(NOT spdlog_FOUND AND DEFINED SPDLOG_DIR) ++ list(APPEND SPDLOG_LINK_LIBRARIES "spdlog") ++ list(APPEND SPDLOG_LINK_DIRECTORIES "${SPDLOG_DIR}/lib") ++ list(APPEND SPDLOG_INCLUDE_DIRECTORIES "${SPDLOG_DIR}/include") ++ set(spdlog_FOUND TRUE) ++endif() ++ # ---- piper-phonemize --- -if(NOT DEFINED PIPER_PHONEMIZE_DIR) @@ -69,9 +136,11 @@ index f96ec44..ef67ff5 100644 endif() -target_link_libraries(piper +- fmt +- spdlog +target_link_libraries(piper PRIVATE - fmt - spdlog ++ "${FMT_LINK_LIBRARIES}" ++ "${SPDLOG_LINK_LIBRARIES}" espeak-ng - piper_phonemize onnxruntime @@ -80,52 +149,67 @@ index f96ec44..ef67ff5 100644 ) -target_link_directories(piper PUBLIC -+target_link_directories(piper PRIVATE - ${FMT_DIR}/lib - ${SPDLOG_DIR}/lib +- ${FMT_DIR}/lib +- ${SPDLOG_DIR}/lib - ${PIPER_PHONEMIZE_DIR}/lib - ) - +-) +- -target_include_directories(piper PUBLIC -+set(PIPER_INTERFACE_INCLUDE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/include") -+file(COPY src/cpp/piper.hpp src/cpp/json.hpp DESTINATION "${PIPER_INTERFACE_INCLUDE_DIRECTORY}") -+ -+target_include_directories(piper PRIVATE - ${FMT_DIR}/include - ${SPDLOG_DIR}/include +- ${FMT_DIR}/include +- ${SPDLOG_DIR}/include - ${PIPER_PHONEMIZE_DIR}/include -+ INTERFACE "${PIPER_INTERFACE_INCLUDE_DIRECTORY}" ++target_link_directories(piper PRIVATE ++ "${FMT_LINK_DIRECTORIES}" ++ "${SPDLOG_LINK_DIRECTORIES}" ) -target_compile_definitions(piper PUBLIC _PIPER_VERSION=${piper_version}) -- ++set(PIPER_INTERFACE_INCLUDE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/include") ++file(COPY src/cpp/piper.hpp src/cpp/json.hpp DESTINATION "${PIPER_INTERFACE_INCLUDE_DIRECTORY}") + -# ---- Declare test ---- -include(CTest) -enable_testing() -add_test( - NAME test_piper - COMMAND test_piper "${CMAKE_SOURCE_DIR}/etc/test_voice.onnx" "${PIPER_PHONEMIZE_DIR}/share/espeak-ng-data" "${CMAKE_CURRENT_BINARY_DIR}/test.wav" --) -+target_compile_definitions(piper PRIVATE _PIPER_VERSION=${piper_version}) ++target_include_directories(piper PRIVATE ++ "${FMT_INCLUDE_DIRECTORIES}" ++ "${SPDLOG_INCLUDE_DIRECTORIES}" ++ INTERFACE "${PIPER_INTERFACE_INCLUDE_DIRECTORY}" + ) ++target_compile_definitions(piper PRIVATE _PIPER_VERSION=${piper_version}) ++ target_compile_features(test_piper PUBLIC cxx_std_17) -@@ -118,14 +117,12 @@ target_include_directories( + target_include_directories( test_piper PUBLIC - ${FMT_DIR}/include - ${SPDLOG_DIR}/include +- ${FMT_DIR}/include +- ${SPDLOG_DIR}/include - ${PIPER_PHONEMIZE_DIR}/include ++ "${FMT_INCLUDE_DIRECTORIES}" ++ "${SPDLOG_INCLUDE_DIRECTORIES}" ) target_link_directories( test_piper PUBLIC - ${FMT_DIR}/lib - ${SPDLOG_DIR}/lib +- ${FMT_DIR}/lib +- ${SPDLOG_DIR}/lib - ${PIPER_PHONEMIZE_DIR}/lib ++ "${FMT_LINK_DIRECTORIES}" ++ "${SPDLOG_LINK_DIRECTORIES}" ) target_link_libraries(test_piper PUBLIC -@@ -141,32 +138,3 @@ target_link_libraries(test_piper PUBLIC +- fmt +- spdlog ++ "${FMT_LINK_LIBRARIES}" ++ "${SPDLOG_LINK_LIBRARIES}" + espeak-ng + piper_phonemize + onnxruntime +@@ -141,32 +181,3 @@ target_link_libraries(test_piper PUBLIC install( TARGETS piper DESTINATION ${CMAKE_INSTALL_PREFIX}) From c9103c53d61b05bb9b0fdd84b4edde8d53ded910 Mon Sep 17 00:00:00 2001 From: Jun Zhang Date: Wed, 14 Aug 2024 18:02:19 +0800 Subject: [PATCH 396/623] [WASI-LLM] Add initial CPU backend for wasi_llm (#3624) Signed-off-by: Jun Zhang --- plugins/wasi_llm/CMakeLists.txt | 13 ++ plugins/wasi_llm/llmc_fwd.h | 29 ++++ plugins/wasi_llm/types.h | 1 + plugins/wasi_llm/wasillmbase.h | 4 +- plugins/wasi_llm/wasillmenv.cpp | 88 ++++++++++++ plugins/wasi_llm/wasillmenv.h | 49 +++++++ plugins/wasi_llm/wasillmfunc.cpp | 144 ++++++++++++------- plugins/wasi_llm/wasillmfunc.h | 94 ++++++------- plugins/wasi_llm/wasillmmodule.cpp | 13 +- plugins/wasi_llm/wasillmmodule.h | 3 + test/plugins/CMakeLists.txt | 5 + test/plugins/wasi_llm/CMakeLists.txt | 67 +++++++++ test/plugins/wasi_llm/wasi_llm.cpp | 199 +++++++++++++++++++++++++++ 13 files changed, 598 insertions(+), 111 deletions(-) create mode 100644 plugins/wasi_llm/llmc_fwd.h create mode 100644 plugins/wasi_llm/wasillmenv.cpp create mode 100644 plugins/wasi_llm/wasillmenv.h create mode 100644 test/plugins/wasi_llm/CMakeLists.txt create mode 100644 test/plugins/wasi_llm/wasi_llm.cpp diff --git a/plugins/wasi_llm/CMakeLists.txt b/plugins/wasi_llm/CMakeLists.txt index 6934fa93..ea8c71ea 100644 --- a/plugins/wasi_llm/CMakeLists.txt +++ b/plugins/wasi_llm/CMakeLists.txt @@ -7,6 +7,19 @@ wasmedge_add_library(wasmedgePluginWasiLLM SHARED wasillmfunc.cpp wasillmmodule.cpp + wasillmenv.cpp +) + +message(STATUS "Start fetching llm.c source") +include(FetchContent) +FetchContent_Declare( + llmc + GIT_REPOSITORY https://github.com/WasmEdge/llm.c +) +FetchContent_MakeAvailable(llmc) + +target_link_libraries(wasmedgePluginWasiLLM PRIVATE + train_gpt2_cpu ) target_compile_options(wasmedgePluginWasiLLM diff --git a/plugins/wasi_llm/llmc_fwd.h b/plugins/wasi_llm/llmc_fwd.h new file mode 100644 index 00000000..4486be35 --- /dev/null +++ b/plugins/wasi_llm/llmc_fwd.h @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasillmenv.h" + +extern "C" { + +struct GPT2; +struct Tokenizer; +struct DataLoader; + +GPT2 *gpt2_create(const char *checkpoint_path); + +void gpt2_free(GPT2 *model); + +DataLoader *dataloader_create(const char *filename_pattern, size_t B, size_t T, + int process_rank, int num_processes, + int should_shuffle); +void dataloader_destroy(DataLoader *loader); + +Tokenizer *tokenizer_create(const char *filename); + +void tokenizer_destroy(Tokenizer *tokenizer); + +void gpt2_train(GPT2 *model, DataLoader *train_loader, DataLoader *val_loader, + Tokenizer *tokenizer, int B, int T, float lr, int epoch); +} diff --git a/plugins/wasi_llm/types.h b/plugins/wasi_llm/types.h index 7d858f1b..71e73641 100644 --- a/plugins/wasi_llm/types.h +++ b/plugins/wasi_llm/types.h @@ -10,6 +10,7 @@ namespace WasmEdge::Host::WASILLM { enum class ErrNo : uint32_t { Success = 0, InvalidArgument = 1, + MissingMemory = 2, }; } // namespace WasmEdge::Host::WASILLM diff --git a/plugins/wasi_llm/wasillmbase.h b/plugins/wasi_llm/wasillmbase.h index aace0bd4..fcc941ce 100644 --- a/plugins/wasi_llm/wasillmbase.h +++ b/plugins/wasi_llm/wasillmbase.h @@ -6,18 +6,20 @@ #include "common/errcode.h" #include "runtime/hostfunc.h" #include "types.h" +#include "wasillmenv.h" namespace WasmEdge { namespace Host { template class WasiLLM : public Runtime::HostFunction { public: - WasiLLM() : Runtime::HostFunction(0) {} + WasiLLM(WASILLM::WASILLMEnv &E) : Runtime::HostFunction(0), Env(E) {} protected: static constexpr uint32_t castErrNo(WASILLM::ErrNo E) noexcept { return static_cast(E); } + WASILLM::WASILLMEnv &Env; }; } // namespace Host diff --git a/plugins/wasi_llm/wasillmenv.cpp b/plugins/wasi_llm/wasillmenv.cpp new file mode 100644 index 00000000..336ece58 --- /dev/null +++ b/plugins/wasi_llm/wasillmenv.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasillmenv.h" +#include "llmc_fwd.h" +#include "wasillmmodule.h" + +namespace WasmEdge { +namespace Host { + +namespace WASILLM { + +uint32_t WASILLMEnv::addModel(GPT2 *M) noexcept { + Models.push_back(M); + return Models.size() - 1; +} + +GPT2 *WASILLMEnv::getModel(uint32_t Id) noexcept { + assert(Id < Models.size() && "Out of bounds"); + return Models[Id]; +} + +uint32_t WASILLMEnv::addTokenizer(Tokenizer *T) noexcept { + Tokenizers.push_back(T); + return Tokenizers.size() - 1; +} + +Tokenizer *WASILLMEnv::getTokenizer(uint32_t Id) noexcept { + assert(Id < Tokenizers.size() && "Out of bounds"); + return Tokenizers[Id]; +} + +uint32_t WASILLMEnv::addDataLoader(DataLoader *D) noexcept { + DataLoaders.push_back(D); + return DataLoaders.size() - 1; +} + +DataLoader *WASILLMEnv::getDataLoader(uint32_t Id) noexcept { + assert(Id < DataLoaders.size() && "Out of bounds"); + return DataLoaders[Id]; +} + +WASILLMEnv::~WASILLMEnv() { + for (GPT2 *M : Models) { + gpt2_free(M); + } + for (DataLoader *DL : DataLoaders) { + dataloader_destroy(DL); + } + for (Tokenizer *T : Tokenizers) { + tokenizer_destroy(T); + } +} + +namespace { + +Runtime::Instance::ModuleInstance * +create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasiLLMModule; +} + +static Plugin::PluginModule::ModuleDescriptor MD[] = { + { + /* Name */ "wasi_llm", + /* Description */ "", + /* Create */ create, + }, +}; + +Plugin::Plugin::PluginDescriptor Descriptor{ + /* Name */ "wasi_llm", + /* Description */ "", + /* APIVersion */ Plugin::Plugin::CurrentAPIVersion, + /* Version */ {0, 1, 0, 0}, + /* ModuleCount */ 1, + /* ModuleDescriptions */ MD, + /* ComponentCount */ 0, + /* ComponentDescriptions */ nullptr, + /*AddOptions*/ nullptr, +}; +} // namespace + +EXPORT_GET_DESCRIPTOR(Descriptor) + +} // namespace WASILLM + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_llm/wasillmenv.h b/plugins/wasi_llm/wasillmenv.h new file mode 100644 index 00000000..af409023 --- /dev/null +++ b/plugins/wasi_llm/wasillmenv.h @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "plugin/plugin.h" +#include +#include +#include + +extern "C" { +struct GPT2; +struct Tokenizer; +struct DataLoader; +} + +namespace WasmEdge { +namespace Host { +namespace WASILLM { + +class WASILLMEnv { + std::vector Models; + std::vector Tokenizers; + std::vector DataLoaders; + +public: + uint32_t addModel(GPT2 *M) noexcept; + + GPT2 *getModel(uint32_t Id) noexcept; + + size_t getModelSize() const noexcept { return Models.size(); } + + uint32_t addTokenizer(Tokenizer *T) noexcept; + + Tokenizer *getTokenizer(uint32_t Id) noexcept; + + size_t getTokenizerSize() const noexcept { return Tokenizers.size(); } + + uint32_t addDataLoader(DataLoader *D) noexcept; + + DataLoader *getDataLoader(uint32_t Id) noexcept; + + size_t getDataLoaderSize() const noexcept { return DataLoaders.size(); } + + ~WASILLMEnv(); +}; +} // namespace WASILLM +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_llm/wasillmfunc.cpp b/plugins/wasi_llm/wasillmfunc.cpp index c40090d1..ec794909 100644 --- a/plugins/wasi_llm/wasillmfunc.cpp +++ b/plugins/wasi_llm/wasillmfunc.cpp @@ -2,8 +2,10 @@ // SPDX-FileCopyrightText: 2019-2024 Second State INC #include "wasillmfunc.h" +#include "common/errcode.h" #include "common/spdlog.h" - +#include "llmc_fwd.h" +#include "types.h" #include #include @@ -13,68 +15,110 @@ namespace Host { Expect WasiLLMModelCreate::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t CheckPointPath, - uint32_t CheckPointPathLen) { - (void)Frame; - (void)CheckPointPath; - (void)CheckPointPathLen; - return WASILLM::ErrNo::InvalidArgument; -} + uint32_t CheckPointPathLen, uint32_t ModelIdPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-LLM] Memory instance not found."sv); + return WASILLM::ErrNo::MissingMemory; + } + auto CheckPointPathSpan = + MemInst->getSpan(CheckPointPath, CheckPointPathLen); + if (unlikely(CheckPointPathSpan.size() != CheckPointPathLen)) { + spdlog::error( + "[WasmEdge-LLM] Failed when accessing the input checkpoint path memory."sv); + return WASILLM::ErrNo::MissingMemory; + } -Expect -WasiLLMModelFree::bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t ModelPtr) { - (void)Frame; - (void)ModelPtr; - return WASILLM::ErrNo::InvalidArgument; + auto *ModelId = MemInst->getPointer(ModelIdPtr); + if (unlikely(ModelId == nullptr)) { + spdlog::error( + "[WasmEdge-LLM] Failed when accessing the return model memory."sv); + return WASILLM::ErrNo::InvalidArgument; + } + std::string CheckPointPathStr = + std::string(CheckPointPathSpan.begin(), + CheckPointPathSpan.begin() + CheckPointPathSpan.size()); + GPT2 *Model = gpt2_create(CheckPointPathStr.data()); + *ModelId = Env.addModel(Model); + return WASILLM::ErrNo::Success; } -Expect -WasiLLMDataLoaderCreate::bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t DataPath, uint32_t DataPathLen) { - (void)Frame; - (void)DataPath; - (void)DataPathLen; - return WASILLM::ErrNo::InvalidArgument; -} +Expect WasiLLMDataLoaderCreate::bodyImpl( + const Runtime::CallingFrame &Frame, uint32_t DataPath, uint32_t DataPathLen, + uint32_t B, uint32_t T, uint32_t ProcessRank, uint32_t NumProcesses, + int32_t ShouldShuffle, uint32_t DataLoaderIdPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-LLM] Memory instance not found."sv); + return WASILLM::ErrNo::MissingMemory; + } + auto DataPathSpan = MemInst->getSpan(DataPath, DataPathLen); + if (unlikely(DataPathSpan.size() != DataPathLen)) { + spdlog::error( + "[WasmEdge-LLM] Failed when accessing the input dataloader path memory."sv); + return WASILLM::ErrNo::MissingMemory; + } -Expect -WasiLLMDataLoaderFree::bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t DataLoaderPtr) { - (void)Frame; - (void)DataLoaderPtr; - return WASILLM::ErrNo::InvalidArgument; + auto *DataLoaderId = MemInst->getPointer(DataLoaderIdPtr); + if (unlikely(DataLoaderId == nullptr)) { + spdlog::error( + "[WasmEdge-LLM] Failed when accessing the return dataloader memory."sv); + return WASILLM::ErrNo::InvalidArgument; + } + + std::string DataPathStr = std::string( + DataPathSpan.begin(), DataPathSpan.begin() + DataPathSpan.size()); + DataLoader *D = dataloader_create(DataPathStr.data(), B, T, ProcessRank, + NumProcesses, ShouldShuffle); + *DataLoaderId = Env.addDataLoader(D); + return WASILLM::ErrNo::Success; } Expect WasiLLMTokenizerCreate::bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t FilePath, uint32_t FilePathLen) { - (void)Frame; - (void)FilePath; - (void)FilePathLen; - return WASILLM::ErrNo::InvalidArgument; -} + uint32_t FilePath, uint32_t FilePathLen, + uint32_t TokenizerIdPtr) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-LLM] Memory instance not found."sv); + return WASILLM::ErrNo::MissingMemory; + } + auto FilePathSpan = MemInst->getSpan(FilePath, FilePathLen); + if (unlikely(FilePathSpan.size() != FilePathLen)) { + spdlog::error( + "[WasmEdge-LLM] Failed when accessing the input tokenizer path memory."sv); + return WASILLM::ErrNo::MissingMemory; + } -Expect -WasiLLMTokenizerFree::bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t TokenizerPtr) { - (void)Frame; - (void)TokenizerPtr; - return WASILLM::ErrNo::InvalidArgument; + auto *TokenizerId = MemInst->getPointer(TokenizerIdPtr); + if (unlikely(TokenizerId == nullptr)) { + spdlog::error( + "[WasmEdge-LLM] Failed when accessing the return tokenizer memory."sv); + return WASILLM::ErrNo::InvalidArgument; + } + std::string FilePathStr = std::string( + FilePathSpan.begin(), FilePathSpan.begin() + FilePathSpan.size()); + Tokenizer *T = tokenizer_create(FilePathStr.data()); + *TokenizerId = Env.addTokenizer(T); + return WASILLM::ErrNo::Success; } Expect WasiLLMModelTrain::bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t ModelPtr, uint32_t TrainDataLoaderPtr, - uint32_t ValDataLoaderPtr, uint32_t TokenizerPtr, - uint32_t Lr, uint32_t Epoch) { - (void)Frame; - (void)ModelPtr; - (void)TrainDataLoaderPtr; - (void)ValDataLoaderPtr; - (void)TokenizerPtr; - (void)Lr; - (void)Epoch; - return WASILLM::ErrNo::InvalidArgument; + uint32_t ModelId, uint32_t TrainDataLoaderId, + uint32_t ValDataLoaderId, uint32_t TokenizerId, + uint32_t B, uint32_t T, float Lr, uint32_t Epoch) { + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + spdlog::error("[WasmEdge-LLM] Memory instance not found."sv); + return WASILLM::ErrNo::MissingMemory; + } + GPT2 *Model = Env.getModel(ModelId); + DataLoader *TrainDataLoader = Env.getDataLoader(TrainDataLoaderId); + DataLoader *ValDataLoader = Env.getDataLoader(ValDataLoaderId); + Tokenizer *Tokenizer = Env.getTokenizer(TokenizerId); + gpt2_train(Model, TrainDataLoader, ValDataLoader, Tokenizer, B, T, Lr, Epoch); + return WASILLM::ErrNo::Success; } } // namespace Host diff --git a/plugins/wasi_llm/wasillmfunc.h b/plugins/wasi_llm/wasillmfunc.h index de5a6d21..4c588b19 100644 --- a/plugins/wasi_llm/wasillmfunc.h +++ b/plugins/wasi_llm/wasillmfunc.h @@ -6,6 +6,7 @@ #include "runtime/callingframe.h" #include "types.h" #include "wasillmbase.h" +#include "wasillmenv.h" #include @@ -14,91 +15,78 @@ namespace Host { class WasiLLMModelCreate : public WasiLLM { public: + explicit WasiLLMModelCreate(WASILLM::WASILLMEnv &Env) : WasiLLM(Env) {} + Expect body(const Runtime::CallingFrame &Frame, - uint32_t CheckPointPath, uint32_t CheckPointPathLen) { - return bodyImpl(Frame, CheckPointPath, CheckPointPathLen).map(castErrNo); + uint32_t CheckPointPath, uint32_t CheckPointPathLen, + uint32_t ModelIdPtr) { + return bodyImpl(Frame, CheckPointPath, CheckPointPathLen, ModelIdPtr) + .map(castErrNo); } private: Expect bodyImpl(const Runtime::CallingFrame &Frame, uint32_t CheckPointPath, - uint32_t CheckPointPathLen); -}; - -class WasiLLMModelFree : public WasiLLM { -public: - Expect body(const Runtime::CallingFrame &Frame, uint32_t ModelPtr) { - return bodyImpl(Frame, ModelPtr).map(castErrNo); - } - -private: - Expect bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t ModelPtr); + uint32_t CheckPointPathLen, + uint32_t ModelIdPtr); }; class WasiLLMDataLoaderCreate : public WasiLLM { public: - Expect body(const Runtime::CallingFrame &Frame, uint32_t DataPath, - uint32_t DataPathLen) { - return bodyImpl(Frame, DataPath, DataPathLen).map(castErrNo); - } - -private: - Expect bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t DataPath, uint32_t DataPathLen); -}; + explicit WasiLLMDataLoaderCreate(WASILLM::WASILLMEnv &Env) : WasiLLM(Env) {} -class WasiLLMDataLoaderFree : public WasiLLM { -public: - Expect body(const Runtime::CallingFrame &Frame, - uint32_t DataLoaderPtr) { - return bodyImpl(Frame, DataLoaderPtr).map(castErrNo); + Expect body(const Runtime::CallingFrame &Frame, uint32_t DataPath, + uint32_t DataPathLen, uint32_t B, uint32_t T, + uint32_t ProcessRank, uint32_t NumProcesses, + int32_t ShouldShuffle, uint32_t DataLoaderIdPtr) { + return bodyImpl(Frame, DataPath, DataPathLen, B, T, ProcessRank, + NumProcesses, ShouldShuffle, DataLoaderIdPtr) + .map(castErrNo); } private: Expect bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t DataLoaderPtr); + uint32_t DataPath, uint32_t DataPathLen, + uint32_t B, uint32_t T, uint32_t ProcessRank, + uint32_t NumProcesses, int32_t ShouldShuffle, + uint32_t DataLoaderIdPtr); }; class WasiLLMTokenizerCreate : public WasiLLM { public: - Expect body(const Runtime::CallingFrame &Frame, uint32_t FilePath, - uint32_t FilePathLen) { - return bodyImpl(Frame, FilePath, FilePathLen).map(castErrNo); - } + explicit WasiLLMTokenizerCreate(WASILLM::WASILLMEnv &Env) : WasiLLM(Env) {} -private: - Expect bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t FilePath, uint32_t FilePathLen); -}; - -class WasiLLMTokenizerFree : public WasiLLM { -public: - Expect body(const Runtime::CallingFrame &Frame, - uint32_t TokenizerPtr) { - return bodyImpl(Frame, TokenizerPtr).map(castErrNo); + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilePath, + uint32_t FilePathLen, uint32_t TokenizerIdPtr) { + return bodyImpl(Frame, FilePath, FilePathLen, TokenizerIdPtr) + .map(castErrNo); } private: Expect bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t TokenizerPtr); + uint32_t FilePath, uint32_t FilePathLen, + uint32_t TokenizerIdPtr); }; class WasiLLMModelTrain : public WasiLLM { public: - Expect body(const Runtime::CallingFrame &Frame, uint32_t ModelPtr, - uint32_t TrainDataLoaderPtr, uint32_t ValDataLoaderPtr, - uint32_t TokenizerPtr, uint32_t Lr, uint32_t Epoch) { - return bodyImpl(Frame, ModelPtr, TrainDataLoaderPtr, ValDataLoaderPtr, - TokenizerPtr, Lr, Epoch) + explicit WasiLLMModelTrain(WASILLM::WASILLMEnv &Env) : WasiLLM(Env) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t ModelId, + uint32_t TrainDataLoaderId, uint32_t ValDataLoaderId, + uint32_t TokenizerId, uint32_t B, uint32_t T, float Lr, + uint32_t Epoch) { + return bodyImpl(Frame, ModelId, TrainDataLoaderId, ValDataLoaderId, + TokenizerId, B, T, Lr, Epoch) .map(castErrNo); } private: - Expect - bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ModelPtr, - uint32_t TrainDataLoaderPtr, uint32_t ValDataLoaderPtr, - uint32_t TokenizerPtr, uint32_t Lr, uint32_t Epoch); + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ModelId, uint32_t TrainDataLoaderId, + uint32_t ValDataLoaderId, + uint32_t TokenizerId, uint32_t B, uint32_t T, + float Lr, uint32_t Epoch); }; } // namespace Host diff --git a/plugins/wasi_llm/wasillmmodule.cpp b/plugins/wasi_llm/wasillmmodule.cpp index 76aca85d..36742ac3 100644 --- a/plugins/wasi_llm/wasillmmodule.cpp +++ b/plugins/wasi_llm/wasillmmodule.cpp @@ -8,13 +8,12 @@ namespace WasmEdge { namespace Host { WasiLLMModule::WasiLLMModule() : ModuleInstance("wasi_llm") { - addHostFunc("model_create", std::make_unique()); - addHostFunc("model_free", std::make_unique()); - addHostFunc("dataloader_create", std::make_unique()); - addHostFunc("dataloader_free", std::make_unique()); - addHostFunc("tokenizer_create", std::make_unique()); - addHostFunc("tokenizer_free", std::make_unique()); - addHostFunc("model_train", std::make_unique()); + addHostFunc("model_create", std::make_unique(Env)); + addHostFunc("dataloader_create", + std::make_unique(Env)); + addHostFunc("tokenizer_create", + std::make_unique(Env)); + addHostFunc("model_train", std::make_unique(Env)); } } // namespace Host diff --git a/plugins/wasi_llm/wasillmmodule.h b/plugins/wasi_llm/wasillmmodule.h index 9033fe1c..6a6574c2 100644 --- a/plugins/wasi_llm/wasillmmodule.h +++ b/plugins/wasi_llm/wasillmmodule.h @@ -4,11 +4,14 @@ #pragma once #include "runtime/instance/module.h" +#include "wasillmenv.h" namespace WasmEdge { namespace Host { class WasiLLMModule : public Runtime::Instance::ModuleInstance { + WASILLM::WASILLMEnv Env; + public: WasiLLMModule(); }; diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 03c014bc..e8397a55 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -58,4 +58,9 @@ if(WASMEDGE_PLUGIN_STABLEDIFFUSION) endif() add_subdirectory(wasi_logging) + +if(WASMEDGE_PLUGIN_LLM) + add_subdirectory(wasi_llm) +endif() + add_subdirectory(unittest) diff --git a/test/plugins/wasi_llm/CMakeLists.txt b/test/plugins/wasi_llm/CMakeLists.txt new file mode 100644 index 00000000..a29cd29e --- /dev/null +++ b/test/plugins/wasi_llm/CMakeLists.txt @@ -0,0 +1,67 @@ +wasmedge_add_executable(wasiLLMTests + wasi_llm.cpp +) + +add_dependencies(wasiLLMTests + wasmedgePluginWasiLLM +) + +target_include_directories(wasiLLMTests + PUBLIC + $ + $ +) + +target_link_libraries(wasiLLMTests + PRIVATE + ${GTEST_BOTH_LIBRARIES} +) + +# Link to the WasmEdge library +if(WASMEDGE_LINK_PLUGINS_STATIC) + target_link_libraries(wasiLLMTests + PRIVATE + wasmedgeCAPI + ) +else() + target_link_libraries(wasiLLMTests + PRIVATE + wasmedge_shared + ) +endif() + +function(download URL OUTPUT HASH) + file(DOWNLOAD + ${URL} + ${OUTPUT} + SHOW_PROGRESS + EXPECTED_HASH ${HASH} + ) +endfunction() + +message(STATUS "Downloading GPT2 model check point to ${CMAKE_CURRENT_BINARY_DIR}/gpt2_124M.bin") +download( + https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasi_llm/gpt2_124M.bin + SHA256=3da8b207584030bcdcd207cf7a99952e3421dce92da218b351071857511bf162 +) +message(STATUS "Downloading training dataset to ${CMAKE_CURRENT_BINARY_DIR}/tiny_shakespeare_train.bin") +download( + https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_train.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasi_llm/tiny_shakespeare_train.bin + SHA256=8a70606be574040c26d225694f5f9759973b419852d22f7fe5c118e1b359dcc8 +) +message(STATUS "Downloading validation dataset to ${CMAKE_CURRENT_BINARY_DIR}/tiny_shakespeare_val.bin") +download( + https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_val.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasi_llm/tiny_shakespeare_val.bin + SHA256=fe99db720dc7c83e694806d4e047a952909411da1daccde4ccc2e55f40882a62 +) +message(STATUS "Downloading tokenizer data to ${CMAKE_CURRENT_BINARY_DIR}/gpt2_tokenizer.bin") +download( + https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_tokenizer.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasi_llm/gpt2_tokenizer.bin + SHA256=6f3abc21e444e4e8300e225f4e03da48ea121cf17e30f67009b8dad7a66c2f13 +) + +add_test(wasiLLMTests wasiLLMTests) diff --git a/test/plugins/wasi_llm/wasi_llm.cpp b/test/plugins/wasi_llm/wasi_llm.cpp new file mode 100644 index 00000000..67790397 --- /dev/null +++ b/test/plugins/wasi_llm/wasi_llm.cpp @@ -0,0 +1,199 @@ +#include "common/defines.h" +#include "common/types.h" +#include "plugin/plugin.h" +#include "runtime/callingframe.h" +#include "runtime/instance/module.h" +#include "types.h" +#include "wasillmfunc.h" +#include "wasillmmodule.h" + +#include +#include +#include +#include +#include +#include +#include + +using WasmEdge::Host::WASILLM::ErrNo; + +namespace { +WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + using namespace std::literals::string_view_literals; + WasmEdge::Plugin::Plugin::load( + std::filesystem::u8path("../../../plugins/wasi_llm/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasiLLM" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_llm"sv)) { + if (const auto *Module = Plugin->findModule("wasi_llm"sv)) { + return Module->create().release(); + } + } + return nullptr; +} +} // namespace + +template +void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + WasmEdge::Span Binaries, uint32_t Ptr) noexcept { + std::copy(Binaries.begin(), Binaries.end(), MemInst.getPointer(Ptr)); +} + +void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Value, uint32_t &Ptr) { + uint32_t *BufPtr = MemInst.getPointer(Ptr); + *BufPtr = Value; + Ptr += 4; +} + +TEST(WasiLLMTest, TrainGPT2) { + // Create wasi_llm module instance. + auto *LLMMod = dynamic_cast(createModule()); + EXPECT_NE(LLMMod, nullptr); + EXPECT_EQ(LLMMod->getFuncExportNum(), 4U); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + EXPECT_NE(MemInstPtr, nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + auto *ModelCreate = LLMMod->findFuncExports("model_create"); + EXPECT_NE(ModelCreate, nullptr); + EXPECT_TRUE(ModelCreate->isHostFunction()); + auto &HostFuncModelCreate = + dynamic_cast( + ModelCreate->getHostFunc()); + + auto *DataLoaderCreate = LLMMod->findFuncExports("dataloader_create"); + EXPECT_NE(DataLoaderCreate, nullptr); + EXPECT_TRUE(DataLoaderCreate->isHostFunction()); + auto &HostFuncDataLoadereCreate = + dynamic_cast( + DataLoaderCreate->getHostFunc()); + + auto *TokenizerCreate = LLMMod->findFuncExports("tokenizer_create"); + EXPECT_NE(TokenizerCreate, nullptr); + EXPECT_TRUE(TokenizerCreate->isHostFunction()); + auto &HostFuncTokenizerCreate = + dynamic_cast( + TokenizerCreate->getHostFunc()); + + auto *ModelTrain = LLMMod->findFuncExports("model_train"); + EXPECT_NE(ModelTrain, nullptr); + EXPECT_TRUE(ModelTrain->isHostFunction()); + auto &HostFuncModelTrain = dynamic_cast( + ModelTrain->getHostFunc()); + + std::array Errno = {UINT32_C(0)}; + + std::string CheckPointString = "./wasi_llm/gpt2_124M.bin"; + std::vector CheckPointPath(CheckPointString.begin(), + CheckPointString.end()); + uint32_t CheckPointPathPtr = UINT32_C(0); + writeBinaries(MemInst, CheckPointPath, CheckPointPathPtr); + + std::string TrainDataString = "./wasi_llm/tiny_shakespeare_train.bin"; + std::vector TrainDataPath(TrainDataString.begin(), + TrainDataString.end()); + uint32_t TrainDataPathPtr = CheckPointPathPtr + CheckPointPath.size(); + writeBinaries(MemInst, TrainDataPath, TrainDataPathPtr); + + std::string ValDataString = "./wasi_llm/tiny_shakespeare_val.bin"; + std::vector ValDataPath(ValDataString.begin(), ValDataString.end()); + uint32_t ValDataPathPtr = TrainDataPathPtr + TrainDataPath.size(); + writeBinaries(MemInst, ValDataPath, ValDataPathPtr); + + std::string TokenizerBin = "./wasi_llm/gpt2_tokenizer.bin"; + std::vector TokenizerBinPath(TokenizerBin.begin(), TokenizerBin.end()); + uint32_t TokenizerBinPtr = ValDataPathPtr + ValDataPath.size(); + writeBinaries(MemInst, TokenizerBinPath, TokenizerBinPtr); + + uint32_t ModelIdPtr = UINT32_C(0); + uint32_t ModelId = UINT32_C(0); + uint32_t TrainDataLoaderIdPtr = UINT32_C(0); + uint32_t TrainDataLoaderId = UINT32_C(0); + uint32_t ValDataLoaderIdPtr = UINT32_C(0); + uint32_t ValDataLoaderId = UINT32_C(0); + uint32_t TokenizerIdPtr = UINT32_C(0); + uint32_t TokenizerId = UINT32_C(0); + + { + EXPECT_TRUE(HostFuncModelCreate.run( + CallFrame, + std::initializer_list{ + CheckPointPathPtr, static_cast(CheckPointPath.size()), + ModelIdPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + ModelId = *MemInst.getPointer(ModelIdPtr); + EXPECT_EQ(ModelId, 0); + } + + { + EXPECT_TRUE(HostFuncDataLoadereCreate.run( + CallFrame, + std::initializer_list{ + TrainDataPathPtr, static_cast(TrainDataPath.size()), + /*B*/ 4, + /*T*/ 64, + /*ProcessRank*/ 0, + /*NumProcesses*/ 1, + /*ShouldShuffle*/ 1, TrainDataLoaderIdPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + TrainDataLoaderId = *MemInst.getPointer(TrainDataLoaderIdPtr); + EXPECT_EQ(TrainDataLoaderId, 0); + } + + { + EXPECT_TRUE(HostFuncDataLoadereCreate.run( + CallFrame, + std::initializer_list{ + ValDataPathPtr, static_cast(ValDataPath.size()), + /*B*/ 4, + /*T*/ 64, + /*ProcessRank*/ 0, + /*NumProcesses*/ 1, + /*ShouldShuffle*/ 0, ValDataLoaderIdPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + ValDataLoaderId = *MemInst.getPointer(ValDataLoaderIdPtr); + EXPECT_EQ(ValDataLoaderId, 1); + } + + { + EXPECT_TRUE(HostFuncTokenizerCreate.run( + CallFrame, + std::initializer_list{ + TokenizerBinPtr, static_cast(TokenizerBinPath.size()), + TokenizerIdPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + TokenizerId = *MemInst.getPointer(TokenizerIdPtr); + EXPECT_EQ(TokenizerId, 0); + } + + { + + EXPECT_TRUE(HostFuncModelTrain.run( + CallFrame, + std::initializer_list{ + ModelId, TrainDataLoaderId, ValDataLoaderId, TokenizerId, + /*B*/ 4, + /*T*/ 64, + /*Lr*/ 1e-4f, + /*Epoch*/ 20}, + Errno)); + } + + delete LLMMod; +} + +GTEST_API_ int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 51ec0b10171bd2dbad10034dbae387df952e3550 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 14 Aug 2024 19:46:01 +0800 Subject: [PATCH 397/623] [Plugin] Stable Diffusion: enable CUBLAS (#3652) Signed-off-by: dm4 --- plugins/wasmedge_stablediffusion/CMakeLists.txt | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 11097974..ff926ba0 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -1,6 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2024 Second State INC + +if(WASMEDGE_PLUGIN_STABLEDIFFUSION_CUBLAS) + message(STATUS "Stable diffusion plugin: Enable SD_CUBLAS") + set(SD_CUBLAS ON) +else() + message(STATUS "Stable diffusion plugin: Disable SD_CUBLAS") + set(SD_CUBLAS OFF) +endif() + # setup stable diffusion message(STATUS "Downloading stable diffusion source") FetchContent_Declare( @@ -13,7 +22,7 @@ FetchContent_MakeAvailable(stable-diffusion) set_property(TARGET stable-diffusion PROPERTY POSITION_INDEPENDENT_CODE ON) if(APPLE AND CMAKE_SYSTEM_VERSION VERSION_LESS 23) # `cblas_sgemm()` introduced in macOS 13.3. - set(GGML_NO_ACCELERATE ON CACHE INTERNAL "Stable diffusion turn off accelerate") + set(GGML_NO_ACCELERATE ON CACHE INTERNAL "Stable diffusion plugin: Turn off accelerate") endif() get_target_property(SD_DEPS stable-diffusion LINK_LIBRARIES) foreach(dep ${SD_DEPS}) From d91a711ddfc7fd8ba272e8d58545f26e45ad1542 Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Fri, 16 Aug 2024 08:41:20 +0800 Subject: [PATCH 398/623] [CMake] Add the WasmEdge component of cpack. (#3662) Signed-off-by: YiYing He --- plugins/wasi_crypto/CMakeLists.txt | 6 +++++- plugins/wasi_http/CMakeLists.txt | 6 +++++- plugins/wasi_llm/CMakeLists.txt | 6 +++++- plugins/wasi_nn/CMakeLists.txt | 6 +++++- plugins/wasi_ocr/CMakeLists.txt | 6 +++++- plugins/wasi_poll/CMakeLists.txt | 6 +++++- plugins/wasmedge_ffmpeg/CMakeLists.txt | 6 +++++- plugins/wasmedge_image/CMakeLists.txt | 6 +++++- plugins/wasmedge_opencvmini/CMakeLists.txt | 6 +++++- plugins/wasmedge_process/CMakeLists.txt | 6 +++++- plugins/wasmedge_stablediffusion/CMakeLists.txt | 6 +++++- plugins/wasmedge_tensorflow/CMakeLists.txt | 6 +++++- plugins/wasmedge_tensorflowlite/CMakeLists.txt | 6 +++++- plugins/wasmedge_zlib/CMakeLists.txt | 6 +++++- 14 files changed, 70 insertions(+), 14 deletions(-) diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt index f8e2b71f..48ae00c9 100644 --- a/plugins/wasi_crypto/CMakeLists.txt +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -79,4 +79,8 @@ else() ) endif() -install(TARGETS wasmedgePluginWasiCrypto DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasiCrypto + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasi_http/CMakeLists.txt b/plugins/wasi_http/CMakeLists.txt index d85d4e7f..45bc0030 100644 --- a/plugins/wasi_http/CMakeLists.txt +++ b/plugins/wasi_http/CMakeLists.txt @@ -42,4 +42,8 @@ else() ) endif() -install(TARGETS wasmedgePluginWasiHttp DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasiHttp + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasi_llm/CMakeLists.txt b/plugins/wasi_llm/CMakeLists.txt index ea8c71ea..440b5f51 100644 --- a/plugins/wasi_llm/CMakeLists.txt +++ b/plugins/wasi_llm/CMakeLists.txt @@ -45,4 +45,8 @@ else() ) endif() -install(TARGETS wasmedgePluginWasiLLM DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasiLLM + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 8e9edde7..a4bfb187 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -229,4 +229,8 @@ endif() include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN) -install(TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasiNN + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasi_ocr/CMakeLists.txt b/plugins/wasi_ocr/CMakeLists.txt index de9402e9..4690ebae 100644 --- a/plugins/wasi_ocr/CMakeLists.txt +++ b/plugins/wasi_ocr/CMakeLists.txt @@ -31,7 +31,11 @@ else() ) endif() -install(TARGETS wasmedgePluginWasiOCR DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasiOCR + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) message(STATUS "WASI-OCR: Build Tesseract backend for WASI-OCR") find_package(PkgConfig REQUIRED) diff --git a/plugins/wasi_poll/CMakeLists.txt b/plugins/wasi_poll/CMakeLists.txt index cfa00821..9c641135 100644 --- a/plugins/wasi_poll/CMakeLists.txt +++ b/plugins/wasi_poll/CMakeLists.txt @@ -36,4 +36,8 @@ else() ) endif() -install(TARGETS wasmedgePluginWasiPoll DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasiPoll + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_ffmpeg/CMakeLists.txt b/plugins/wasmedge_ffmpeg/CMakeLists.txt index 4895100d..f5fe627f 100644 --- a/plugins/wasmedge_ffmpeg/CMakeLists.txt +++ b/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -85,4 +85,8 @@ else() ) endif() -install(TARGETS wasmedgePluginWasmEdgeFFmpeg DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasmEdgeFFmpeg + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_image/CMakeLists.txt b/plugins/wasmedge_image/CMakeLists.txt index bb4b3d1a..800c3a94 100644 --- a/plugins/wasmedge_image/CMakeLists.txt +++ b/plugins/wasmedge_image/CMakeLists.txt @@ -140,4 +140,8 @@ else() ) endif() -install(TARGETS wasmedgePluginWasmEdgeImage DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasmEdgeImage + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_opencvmini/CMakeLists.txt b/plugins/wasmedge_opencvmini/CMakeLists.txt index 577ed424..ed10e816 100644 --- a/plugins/wasmedge_opencvmini/CMakeLists.txt +++ b/plugins/wasmedge_opencvmini/CMakeLists.txt @@ -36,4 +36,8 @@ else() ) endif() -install(TARGETS wasmedgePluginWasmEdgeOpenCVMini DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasmEdgeOpenCVMini + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt index 5ae1dc1a..28a4bcce 100644 --- a/plugins/wasmedge_process/CMakeLists.txt +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -31,4 +31,8 @@ else() ) endif() -install(TARGETS wasmedgePluginWasmEdgeProcess DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasmEdgeProcess + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index ff926ba0..d8cb20d2 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -86,4 +86,8 @@ else() ) endif() -install(TARGETS wasmedgePluginWasmEdgeStableDiffusion DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasmEdgeStableDiffusion + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_tensorflow/CMakeLists.txt b/plugins/wasmedge_tensorflow/CMakeLists.txt index 7189c2bb..ccfe25ed 100644 --- a/plugins/wasmedge_tensorflow/CMakeLists.txt +++ b/plugins/wasmedge_tensorflow/CMakeLists.txt @@ -34,4 +34,8 @@ endif() include(WASINNDeps) wasmedge_setup_tf_target(wasmedgePluginWasmEdgeTensorflow) -install(TARGETS wasmedgePluginWasmEdgeTensorflow DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasmEdgeTensorflow + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_tensorflowlite/CMakeLists.txt b/plugins/wasmedge_tensorflowlite/CMakeLists.txt index 56cade99..f8ee177d 100644 --- a/plugins/wasmedge_tensorflowlite/CMakeLists.txt +++ b/plugins/wasmedge_tensorflowlite/CMakeLists.txt @@ -34,4 +34,8 @@ endif() include(WASINNDeps) wasmedge_setup_tflite_target(wasmedgePluginWasmEdgeTensorflowLite) -install(TARGETS wasmedgePluginWasmEdgeTensorflowLite DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasmEdgeTensorflowLite + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) diff --git a/plugins/wasmedge_zlib/CMakeLists.txt b/plugins/wasmedge_zlib/CMakeLists.txt index 40f7ae89..56745021 100644 --- a/plugins/wasmedge_zlib/CMakeLists.txt +++ b/plugins/wasmedge_zlib/CMakeLists.txt @@ -37,4 +37,8 @@ else() ) endif() -install(TARGETS wasmedgePluginWasmEdgeZlib DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge) +install( + TARGETS wasmedgePluginWasmEdgeZlib + DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge +) From 4ff39745033bea95648e9b55aa6180be85b1b32e Mon Sep 17 00:00:00 2001 From: YiYing He Date: Fri, 16 Aug 2024 16:43:25 +0800 Subject: [PATCH 399/623] [Plugin] Fix the wasmedge-ffmpeg linking issue. Signed-off-by: YiYing He --- plugins/wasmedge_ffmpeg/CMakeLists.txt | 28 +- plugins/wasmedge_ffmpeg/avcodec/avCodec.h | 103 ++-- .../wasmedge_ffmpeg/avcodec/avCodecContext.h | 504 +++++++----------- .../avcodec/avCodecParameters.h | 20 +- plugins/wasmedge_ffmpeg/avcodec/avPacket.h | 109 ++-- .../wasmedge_ffmpeg/avcodec/avcodec_func.h | 153 ++---- plugins/wasmedge_ffmpeg/avcodec/module.h | 1 + .../wasmedge_ffmpeg/avdevice/avDevice_base.h | 28 - .../wasmedge_ffmpeg/avdevice/avDevice_func.h | 81 +-- plugins/wasmedge_ffmpeg/avdevice/module.h | 1 + plugins/wasmedge_ffmpeg/avfilter/avFilter.h | 71 +-- .../wasmedge_ffmpeg/avfilter/avfilter_base.h | 28 - .../wasmedge_ffmpeg/avfilter/avfilter_func.h | 125 ++--- .../avfilter/buffer_source_sink.h | 36 +- plugins/wasmedge_ffmpeg/avfilter/module.h | 1 + plugins/wasmedge_ffmpeg/avformat/avChapter.h | 55 +- .../avformat/avInputOutputFormat.h | 83 +-- plugins/wasmedge_ffmpeg/avformat/avStream.h | 86 ++- .../avformat/avformatContext.h | 58 +- .../wasmedge_ffmpeg/avformat/avformat_base.h | 28 - .../wasmedge_ffmpeg/avformat/avformat_func.h | 164 +++--- plugins/wasmedge_ffmpeg/avformat/module.h | 1 + plugins/wasmedge_ffmpeg/avutil/avDictionary.h | 28 +- plugins/wasmedge_ffmpeg/avutil/avFrame.h | 263 ++++----- plugins/wasmedge_ffmpeg/avutil/avRational.h | 58 +- plugins/wasmedge_ffmpeg/avutil/avTime.h | 23 +- plugins/wasmedge_ffmpeg/avutil/avutil_base.h | 28 - plugins/wasmedge_ffmpeg/avutil/avutil_func.h | 132 ++--- plugins/wasmedge_ffmpeg/avutil/error.h | 18 +- plugins/wasmedge_ffmpeg/avutil/module.h | 1 + plugins/wasmedge_ffmpeg/avutil/pixfmt.h | 76 +-- plugins/wasmedge_ffmpeg/avutil/samplefmt.h | 60 +-- .../{avcodec/avcodec_base.h => ffmpeg_base.h} | 9 +- plugins/wasmedge_ffmpeg/ffmpeg_env.cpp | 11 +- plugins/wasmedge_ffmpeg/ffmpeg_env.h | 3 +- plugins/wasmedge_ffmpeg/swresample/module.h | 1 + .../swresample/swresample_base.h | 28 - .../swresample/swresample_func.h | 62 +-- plugins/wasmedge_ffmpeg/swscale/module.h | 1 + .../wasmedge_ffmpeg/swscale/swscale_base.h | 28 - .../wasmedge_ffmpeg/swscale/swscale_func.h | 132 ++--- test/plugins/wasmedge_ffmpeg/CMakeLists.txt | 5 +- .../wasmedge_ffmpeg/avcodec/avCodec.cpp | 4 +- .../wasmedge_ffmpeg/avcodec/avCodecCtx.cpp | 4 + .../avcodec/avCodecParameters.cpp | 4 + .../wasmedge_ffmpeg/avcodec/avPacket.cpp | 5 +- .../wasmedge_ffmpeg/avcodec/avcodec_func.cpp | 5 +- .../wasmedge_ffmpeg/avfilter/avfilter.cpp | 4 + .../avfilter/avfilter_func.cpp | 5 +- .../wasmedge_ffmpeg/avformat/avChapter.cpp | 5 +- .../avformat/avInputOutputContext.cpp | 5 +- .../wasmedge_ffmpeg/avformat/avStream.cpp | 5 +- .../avformat/avformatContext.cpp | 5 +- .../avformat/avformat_func.cpp | 5 +- .../wasmedge_ffmpeg/avutil/avDictionary.cpp | 4 + .../wasmedge_ffmpeg/avutil/avError.cpp | 4 + .../wasmedge_ffmpeg/avutil/avFrame.cpp | 5 +- .../wasmedge_ffmpeg/avutil/avPixfmt.cpp | 5 +- .../wasmedge_ffmpeg/avutil/avRational.cpp | 5 +- .../wasmedge_ffmpeg/avutil/avSampleFmt.cpp | 4 + .../wasmedge_ffmpeg/avutil/avutil_func.cpp | 5 +- test/plugins/wasmedge_ffmpeg/main.cpp | 3 + .../swresample/swresample_func.cpp | 5 +- .../wasmedge_ffmpeg/swscale/swscale_func.cpp | 8 +- test/plugins/wasmedge_ffmpeg/utils.cpp | 12 +- test/plugins/wasmedge_ffmpeg/utils.h | 12 +- 66 files changed, 1077 insertions(+), 1782 deletions(-) delete mode 100644 plugins/wasmedge_ffmpeg/avdevice/avDevice_base.h delete mode 100644 plugins/wasmedge_ffmpeg/avfilter/avfilter_base.h delete mode 100644 plugins/wasmedge_ffmpeg/avformat/avformat_base.h delete mode 100644 plugins/wasmedge_ffmpeg/avutil/avutil_base.h rename plugins/wasmedge_ffmpeg/{avcodec/avcodec_base.h => ffmpeg_base.h} (68%) delete mode 100644 plugins/wasmedge_ffmpeg/swresample/swresample_base.h delete mode 100644 plugins/wasmedge_ffmpeg/swscale/swscale_base.h diff --git a/plugins/wasmedge_ffmpeg/CMakeLists.txt b/plugins/wasmedge_ffmpeg/CMakeLists.txt index f5fe627f..0a1ff4a8 100644 --- a/plugins/wasmedge_ffmpeg/CMakeLists.txt +++ b/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -14,32 +14,29 @@ pkg_check_modules(LIBAV REQUIRED IMPORTED_TARGET wasmedge_add_library(wasmedgePluginWasmEdgeFFmpeg SHARED - ffmpeg_env.cpp - - avcodec/module.cpp - avcodec/avcodec_func.cpp + avcodec/avCodecContext.cpp avcodec/avCodec.cpp avcodec/avCodecParameters.cpp avcodec/avPacket.cpp + avcodec/avcodec_func.cpp + avcodec/module.cpp - avdevice/module.cpp avdevice/avDevice_func.cpp + avdevice/module.cpp - avfilter/module.cpp - avfilter/avfilter_func.cpp avfilter/buffer_source_sink.cpp avfilter/avFilter.cpp + avfilter/avfilter_func.cpp + avfilter/module.cpp - avformat/module.cpp - avformat/avformat_func.cpp avformat/avformatContext.cpp avformat/avInputOutputFormat.cpp avformat/avStream.cpp avformat/avChapter.cpp + avformat/avformat_func.cpp + avformat/module.cpp - avutil/module.cpp - avutil/avutil_func.cpp avutil/error.cpp avutil/avRational.cpp avutil/avFrame.cpp @@ -47,13 +44,16 @@ wasmedge_add_library(wasmedgePluginWasmEdgeFFmpeg avutil/samplefmt.cpp avutil/avDictionary.cpp avutil/avTime.cpp + avutil/avutil_func.cpp + avutil/module.cpp - swresample/module.cpp swresample/swresample_func.cpp + swresample/module.cpp - swscale/module.cpp swscale/swscale_func.cpp - + swscale/module.cpp + + ffmpeg_env.cpp ) target_compile_options(wasmedgePluginWasmEdgeFFmpeg diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodec.h b/plugins/wasmedge_ffmpeg/avcodec/avCodec.h index 8954bd4e..70676c2f 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodec.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodec.h @@ -3,160 +3,135 @@ #pragma once -#include "avcodec_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVcodec { -class AVCodecID : public WasmEdgeFFmpegAVCodec { +class AVCodecID : public HostFunction { public: - AVCodecID(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecType : public WasmEdgeFFmpegAVCodec { +class AVCodecType : public HostFunction { public: - AVCodecType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecMaxLowres : public WasmEdgeFFmpegAVCodec { +class AVCodecMaxLowres : public HostFunction { public: - AVCodecMaxLowres(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecCapabilities : public WasmEdgeFFmpegAVCodec { +class AVCodecCapabilities : public HostFunction { public: - AVCodecCapabilities(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecGetNameLen : public WasmEdgeFFmpegAVCodec { +class AVCodecGetNameLen : public HostFunction { public: - AVCodecGetNameLen(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecGetName : public WasmEdgeFFmpegAVCodec { +class AVCodecGetName : public HostFunction { public: - AVCodecGetName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t NamePtr, uint32_t NameLen); }; -class AVCodecGetLongNameLen - : public WasmEdgeFFmpegAVCodec { +class AVCodecGetLongNameLen : public HostFunction { public: - AVCodecGetLongNameLen(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecGetLongName : public WasmEdgeFFmpegAVCodec { +class AVCodecGetLongName : public HostFunction { public: - AVCodecGetLongName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t LongNamePtr, uint32_t LongNameLen); }; -class AVCodecProfiles : public WasmEdgeFFmpegAVCodec { +class AVCodecProfiles : public HostFunction { public: - AVCodecProfiles(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecPixFmtsIsNull - : public WasmEdgeFFmpegAVCodec { +class AVCodecPixFmtsIsNull : public HostFunction { public: - AVCodecPixFmtsIsNull(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecPixFmtsIter : public WasmEdgeFFmpegAVCodec { +class AVCodecPixFmtsIter : public HostFunction { public: - AVCodecPixFmtsIter(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t Idx); }; class AVCodecSupportedFrameratesIsNull - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecSupportedFrameratesIsNull(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; class AVCodecSupportedFrameratesIter - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecSupportedFrameratesIter(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t Idx, uint32_t NumPtr, uint32_t DenPtr); }; class AVCodecSupportedSampleRatesIsNull - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecSupportedSampleRatesIsNull(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; class AVCodecSupportedSampleRatesIter - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecSupportedSampleRatesIter(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t Idx); }; class AVCodecChannelLayoutIsNull - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecChannelLayoutIsNull(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecChannelLayoutIter - : public WasmEdgeFFmpegAVCodec { +class AVCodecChannelLayoutIter : public HostFunction { public: - AVCodecChannelLayoutIter(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t Idx); }; -class AVCodecSampleFmtsIsNull - : public WasmEdgeFFmpegAVCodec { +class AVCodecSampleFmtsIsNull : public HostFunction { public: - AVCodecSampleFmtsIsNull(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecSampleFmtsIter - : public WasmEdgeFFmpegAVCodec { +class AVCodecSampleFmtsIter : public HostFunction { public: - AVCodecSampleFmtsIter(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t Idx); }; diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h index ce443aec..88d86558 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h @@ -3,813 +3,679 @@ #pragma once -#include "avcodec_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVcodec { -class AVCodecCtxCodecID : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxCodecID : public HostFunction { public: - AVCodecCtxCodecID(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxCodecType : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxCodecType : public HostFunction { public: - AVCodecCtxCodecType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetCodecType - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetCodecType : public HostFunction { public: - AVCodecCtxSetCodecType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t CodecTypeId); }; -class AVCodecCtxSetTimebase - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetTimebase : public HostFunction { public: - AVCodecCtxSetTimebase(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Num, int32_t Den); }; -class AVCodecCtxTimeBase : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxTimeBase : public HostFunction { public: - AVCodecCtxTimeBase(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t NumPtr, uint32_t DenPtr); }; -class AVCodecCtxWidth : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxWidth : public HostFunction { public: - AVCodecCtxWidth(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetWidth : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetWidth : public HostFunction { public: - AVCodecCtxSetWidth(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Width); }; -class AVCodecCtxHeight : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxHeight : public HostFunction { public: - AVCodecCtxHeight(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetHeight : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetHeight : public HostFunction { public: - AVCodecCtxSetHeight(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Height); }; class AVCodecCtxSampleAspectRatio - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSampleAspectRatio(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t NumPtr, uint32_t DenPtr); }; class AVCodecCtxSetSampleAspectRatio - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetSampleAspectRatio(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Num, int32_t Den); }; -class AVCodecCtxChannelLayout - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxChannelLayout : public HostFunction { public: - AVCodecCtxChannelLayout(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; class AVCodecCtxSetChannelLayout - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetChannelLayout(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint64_t ChannelLayoutId); }; -class AVCodecCtxPixFormat : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxPixFormat : public HostFunction { public: - AVCodecCtxPixFormat(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetPixFormat - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetPixFormat : public HostFunction { public: - AVCodecCtxSetPixFormat(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t PixFmtId); }; -class AVCodecCtxSampleFormat - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSampleFormat : public HostFunction { public: - AVCodecCtxSampleFormat(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; class AVCodecCtxSetSampleFormat - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetSampleFormat(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t SampleFmtId); }; -class AVCodecCtxSampleRate - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSampleRate : public HostFunction { public: - AVCodecCtxSampleRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetSampleRate - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetSampleRate : public HostFunction { public: - AVCodecCtxSetSampleRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t SampleRate); }; -class AVCodecCtxSetGopSize - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetGopSize : public HostFunction { public: - AVCodecCtxSetGopSize(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t GopSize); }; -class AVCodecCtxSetMaxBFrames - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetMaxBFrames : public HostFunction { public: - AVCodecCtxSetMaxBFrames(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t MaxBFrames); }; class AVCodecCtxSetBQuantFactor - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetBQuantFactor(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, float BQuantFactor); }; class AVCodecCtxSetBQuantOffset - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetBQuantOffset(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, float BQuantOffset); }; class AVCodecCtxSetIQuantFactor - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetIQuantFactor(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, float IQuantFactor); }; class AVCodecCtxSetIQuantOffset - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetIQuantOffset(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, float IQuantOffset); }; -class AVCodecCtxSetLumiMasking - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetLumiMasking : public HostFunction { public: - AVCodecCtxSetLumiMasking(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, float LumiMasking); }; class AVCodecCtxSetTemporalCplxMasking - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetTemporalCplxMasking(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, float TemporalCplxMasking); }; class AVCodecCtxSetSpatialCplxMasking - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetSpatialCplxMasking(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, float SpatialCplxMasking); }; -class AVCodecCtxSetPMasking - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetPMasking : public HostFunction { public: - AVCodecCtxSetPMasking(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, float PMasking); }; -class AVCodecCtxSetDarkMasking - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetDarkMasking : public HostFunction { public: - AVCodecCtxSetDarkMasking(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, float DarkMasking); }; -class AVCodecCtxSetMeCmp : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetMeCmp : public HostFunction { public: - AVCodecCtxSetMeCmp(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t MeCmp); }; -class AVCodecCtxSetMeSubCmp - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetMeSubCmp : public HostFunction { public: - AVCodecCtxSetMeSubCmp(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t MeSubCmp); }; -class AVCodecCtxSetMbCmp : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetMbCmp : public HostFunction { public: - AVCodecCtxSetMbCmp(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t MbCmp); }; -class AVCodecCtxSetIldctCmp - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetIldctCmp : public HostFunction { public: - AVCodecCtxSetIldctCmp(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t IldctCmp); }; -class AVCodecCtxSetDiaSize - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetDiaSize : public HostFunction { public: - AVCodecCtxSetDiaSize(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t DiaSize); }; class AVCodecCtxSetLastPredictorsCount - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetLastPredictorsCount(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t LastPredictorCount); }; -class AVCodecCtxSetMePreCmp - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetMePreCmp : public HostFunction { public: - AVCodecCtxSetMePreCmp(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t MePreCmp); }; -class AVCodecCtxSetPreDiaSize - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetPreDiaSize : public HostFunction { public: - AVCodecCtxSetPreDiaSize(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t PreDiaSize); }; class AVCodecCtxSetMeSubpelQuality - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetMeSubpelQuality(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t MeSubpelQuality); }; -class AVCodecCtxSetMeRange - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetMeRange : public HostFunction { public: - AVCodecCtxSetMeRange(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t MeRange); }; -class AVCodecCtxSetMbDecision - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetMbDecision : public HostFunction { public: - AVCodecCtxSetMbDecision(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t MbDecision); }; -class AVCodecCtxSetMbLMin : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetMbLMin : public HostFunction { public: - AVCodecCtxSetMbLMin(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t MbLMin); }; -class AVCodecCtxSetMbLMax : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetMbLMax : public HostFunction { public: - AVCodecCtxSetMbLMax(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t MbLMax); }; class AVCodecCtxIntraDcPrecision - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxIntraDcPrecision(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; class AVCodecCtxSetIntraDcPrecision - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetIntraDcPrecision(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t IntraDcPrecision); }; -class AVCodecCtxSetQMin : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetQMin : public HostFunction { public: - AVCodecCtxSetQMin(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t QMin); }; -class AVCodecCtxSetQMax : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetQMax : public HostFunction { public: - AVCodecCtxSetQMax(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t QMax); }; class AVCodecCtxSetGlobalQuality - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetGlobalQuality(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t GlobalQuality); }; -class AVCodecCtxSetColorspace - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetColorspace : public HostFunction { public: - AVCodecCtxSetColorspace(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t ColorspaceId); }; -class AVCodecCtxColorspace - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxColorspace : public HostFunction { public: - AVCodecCtxColorspace(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetColorRange - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetColorRange : public HostFunction { public: - AVCodecCtxSetColorRange(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t ColorRange); }; -class AVCodecCtxColorRange - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxColorRange : public HostFunction { public: - AVCodecCtxColorRange(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxFrameSize : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxFrameSize : public HostFunction { public: - AVCodecCtxFrameSize(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxBitRate : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxBitRate : public HostFunction { public: - AVCodecCtxBitRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetBitRate - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetBitRate : public HostFunction { public: - AVCodecCtxSetBitRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int64_t BitRate); }; -class AVCodecCtxRcMaxRate : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxRcMaxRate : public HostFunction { public: - AVCodecCtxRcMaxRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetRcMaxRate - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetRcMaxRate : public HostFunction { public: - AVCodecCtxSetRcMaxRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int64_t RcMaxRate); }; class AVCodecCtxSetBitRateTolerance - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetBitRateTolerance(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t BitRateTolerance); }; class AVCodecCtxSetCompressionLevel - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetCompressionLevel(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t CompressionLevel); }; -class AVCodecCtxFrameRate : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxFrameRate : public HostFunction { public: - AVCodecCtxFrameRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t NumPtr, uint32_t DenPtr); }; -class AVCodecCtxSetFrameRate - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetFrameRate : public HostFunction { public: - AVCodecCtxSetFrameRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Num, int32_t Den); }; -class AVCodecCtxSetFlags : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetFlags : public HostFunction { public: - AVCodecCtxSetFlags(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Flags); }; class AVCodecCtxSetStrictStdCompliance - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetStrictStdCompliance(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t ComplianceId); }; -class AVCodecCtxSetDebug : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetDebug : public HostFunction { public: - AVCodecCtxSetDebug(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Debug); }; -class AVCodecCtxCodec : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxCodec : public HostFunction { public: - AVCodecCtxCodec(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t AvCodecPtr); }; -class AVCodecCtxChannels : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxChannels : public HostFunction { public: - AVCodecCtxChannels(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetChannels - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetChannels : public HostFunction { public: - AVCodecCtxSetChannels(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Channels); }; class AVCodecCtxSetSkipLoopFilter - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetSkipLoopFilter(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t AVDicardId); }; -class AVCodecCtxSetSkipFrame - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetSkipFrame : public HostFunction { public: - AVCodecCtxSetSkipFrame(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t AVDiscardId); }; -class AVCodecCtxSetSkipIdct - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetSkipIdct : public HostFunction { public: - AVCodecCtxSetSkipIdct(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t AVDicardId); }; class AVCodecCtxSetErrorConcealment - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetErrorConcealment(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t ErrorConcealment); }; class AVCodecCtxSetErrorRecognition - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetErrorRecognition(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t ErrorRecognition); }; -class AVCodecCtxDelay : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxDelay : public HostFunction { public: - AVCodecCtxDelay(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetSkipTop - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetSkipTop : public HostFunction { public: - AVCodecCtxSetSkipTop(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Value); }; -class AVCodecCtxSetSkipBottom - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetSkipBottom : public HostFunction { public: - AVCodecCtxSetSkipBottom(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Value); }; -class AVCodecCtxRefs : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxRefs : public HostFunction { public: - AVCodecCtxRefs(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetSliceFlags - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetSliceFlags : public HostFunction { public: - AVCodecCtxSetSliceFlags(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Flags); }; -class AVCodecCtxSetSliceCount - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetSliceCount : public HostFunction { public: - AVCodecCtxSetSliceCount(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Value); }; -class AVCodecCtxSetFieldOrder - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetFieldOrder : public HostFunction { public: - AVCodecCtxSetFieldOrder(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t Value); }; -class AVCodecCtxColorTrc : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxColorTrc : public HostFunction { public: - AVCodecCtxColorTrc(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; class AVCodecCtxChromaSampleLocation - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxChromaSampleLocation(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxFrameNumber - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxFrameNumber : public HostFunction { public: - AVCodecCtxFrameNumber(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxBlockAlign - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxBlockAlign : public HostFunction { public: - AVCodecCtxBlockAlign(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; class AVCodecCtxSetRequestSampleFmt - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetRequestSampleFmt(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t SampleFmtId); }; class AVCodecCtxAudioServiceType - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxAudioServiceType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxHasBFrames - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxHasBFrames : public HostFunction { public: - AVCodecCtxHasBFrames(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; class AVCodecCtxSetRequestChannelLayout - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxSetRequestChannelLayout(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint64_t ChannelLayoutId); }; class AVCodecCtxActiveThreadType - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecCtxActiveThreadType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetThreadType - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetThreadType : public HostFunction { public: - AVCodecCtxSetThreadType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t ThreadType); }; -class AVCodecCtxThreadCount - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxThreadCount : public HostFunction { public: - AVCodecCtxThreadCount(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecCtxSetThreadCount - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxSetThreadCount : public HostFunction { public: - AVCodecCtxSetThreadCount(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, int32_t ThreadCount); }; -class AVCodecCtxColorPrimaries - : public WasmEdgeFFmpegAVCodec { +class AVCodecCtxColorPrimaries : public HostFunction { public: - AVCodecCtxColorPrimaries(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h index 5b3043d8..fc6557d3 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.h @@ -3,36 +3,30 @@ #pragma once -#include "avcodec_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVcodec { -class AVCodecParamCodecId : public WasmEdgeFFmpegAVCodec { +class AVCodecParamCodecId : public HostFunction { public: - AVCodecParamCodecId(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecParamId); }; -class AVCodecParamCodecType - : public WasmEdgeFFmpegAVCodec { +class AVCodecParamCodecType : public HostFunction { public: - AVCodecParamCodecType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecParamId); }; -class AVCodecParamSetCodecTag - : public WasmEdgeFFmpegAVCodec { +class AVCodecParamSetCodecTag : public HostFunction { public: - AVCodecParamSetCodecTag(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecParamId, uint32_t CodecTag); }; diff --git a/plugins/wasmedge_ffmpeg/avcodec/avPacket.h b/plugins/wasmedge_ffmpeg/avcodec/avPacket.h index c54eee8c..403d55a2 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avPacket.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avPacket.h @@ -3,170 +3,147 @@ #pragma once -#include "avcodec_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVcodec { -class AVPacketAlloc : public WasmEdgeFFmpegAVCodec { +class AVPacketAlloc : public HostFunction { public: - AVPacketAlloc(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketPtr); }; -class AVNewPacket : public WasmEdgeFFmpegAVCodec { +class AVNewPacket : public HostFunction { public: - AVNewPacket(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, int32_t Size); }; -class AVPacketRef : public WasmEdgeFFmpegAVCodec { +class AVPacketRef : public HostFunction { public: - AVPacketRef(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t DestPacketId, uint32_t SrcPacketId); }; -class AVPacketUnref : public WasmEdgeFFmpegAVCodec { +class AVPacketUnref : public HostFunction { public: - AVPacketUnref(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); }; -class AVGrowPacket : public WasmEdgeFFmpegAVCodec { +class AVGrowPacket : public HostFunction { public: - AVGrowPacket(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, int32_t Size); }; -class AVShrinkPacket : public WasmEdgeFFmpegAVCodec { +class AVShrinkPacket : public HostFunction { public: - AVShrinkPacket(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, int32_t Size); }; -class AVPacketStreamIndex : public WasmEdgeFFmpegAVCodec { +class AVPacketStreamIndex : public HostFunction { public: - AVPacketStreamIndex(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); }; -class AVPacketSetStreamIndex - : public WasmEdgeFFmpegAVCodec { +class AVPacketSetStreamIndex : public HostFunction { public: - AVPacketSetStreamIndex(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, int32_t StreamIdx); }; -class AVPacketSize : public WasmEdgeFFmpegAVCodec { +class AVPacketSize : public HostFunction { public: - AVPacketSize(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); }; -class AVPacketFlags : public WasmEdgeFFmpegAVCodec { +class AVPacketFlags : public HostFunction { public: - AVPacketFlags(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); }; -class AVPacketSetFlags : public WasmEdgeFFmpegAVCodec { +class AVPacketSetFlags : public HostFunction { public: - AVPacketSetFlags(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, int32_t Flags); }; -class AVPacketPos : public WasmEdgeFFmpegAVCodec { +class AVPacketPos : public HostFunction { public: - AVPacketPos(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); }; -class AVPacketSetPos : public WasmEdgeFFmpegAVCodec { +class AVPacketSetPos : public HostFunction { public: - AVPacketSetPos(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, int64_t Pos); }; -class AVPacketDuration : public WasmEdgeFFmpegAVCodec { +class AVPacketDuration : public HostFunction { public: - AVPacketDuration(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); }; -class AVPacketSetDuration : public WasmEdgeFFmpegAVCodec { +class AVPacketSetDuration : public HostFunction { public: - AVPacketSetDuration(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, int64_t Duration); }; -class AVPacketDts : public WasmEdgeFFmpegAVCodec { +class AVPacketDts : public HostFunction { public: - AVPacketDts(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); }; -class AVPacketSetDts : public WasmEdgeFFmpegAVCodec { +class AVPacketSetDts : public HostFunction { public: - AVPacketSetDts(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, int64_t Dts); }; -class AVPacketPts : public WasmEdgeFFmpegAVCodec { +class AVPacketPts : public HostFunction { public: - AVPacketPts(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); }; -class AVPacketSetPts : public WasmEdgeFFmpegAVCodec { +class AVPacketSetPts : public HostFunction { public: - AVPacketSetPts(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, int64_t Pts); }; -class AVPacketIsDataNull : public WasmEdgeFFmpegAVCodec { +class AVPacketIsDataNull : public HostFunction { public: - AVPacketIsDataNull(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId); }; -class AVPacketData : public WasmEdgeFFmpegAVCodec { +class AVPacketData : public HostFunction { public: - AVPacketData(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPacketId, uint32_t DataPtr, uint32_t DataLen); }; diff --git a/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h index f7ea4d24..c560e065 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.h @@ -3,242 +3,203 @@ #pragma once -#include "avcodec_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVcodec { -class AVCodecAllocContext3 - : public WasmEdgeFFmpegAVCodec { +class AVCodecAllocContext3 : public HostFunction { public: - AVCodecAllocContext3(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t AvCodecCtxPtr); }; class AVCodecParametersFromContext - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecParametersFromContext(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecParamId, uint32_t AvCodecCtxId); }; -class AVCodecParametersFree - : public WasmEdgeFFmpegAVCodec { +class AVCodecParametersFree : public HostFunction { public: - AVCodecParametersFree(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecParamId); }; -class AVCodecFreeContext : public WasmEdgeFFmpegAVCodec { +class AVCodecFreeContext : public HostFunction { public: - AVCodecFreeContext(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId); }; -class AVCodecParametersAlloc - : public WasmEdgeFFmpegAVCodec { +class AVCodecParametersAlloc : public HostFunction { public: - AVCodecParametersAlloc(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecParamPtr); }; -class AVCodecGetType : public WasmEdgeFFmpegAVCodec { +class AVCodecGetType : public HostFunction { public: - AVCodecGetType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecOpen2 : public WasmEdgeFFmpegAVCodec { +class AVCodecOpen2 : public HostFunction { public: - AVCodecOpen2(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t AvCodecId, uint32_t AvDictionaryId); }; -class AVCodecFindDecoder : public WasmEdgeFFmpegAVCodec { +class AVCodecFindDecoder : public HostFunction { public: - AVCodecFindDecoder(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t ID, uint32_t AvCodecId); }; -class AVCodecIsEncoder : public WasmEdgeFFmpegAVCodec { +class AVCodecIsEncoder : public HostFunction { public: - AVCodecIsEncoder(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecIsDecoder : public WasmEdgeFFmpegAVCodec { +class AVCodecIsDecoder : public HostFunction { public: - AVCodecIsDecoder(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; -class AVCodecClose : public WasmEdgeFFmpegAVCodec { +class AVCodecClose : public HostFunction { public: - AVCodecClose(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId); }; class AVCodecParametersToContext - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecParametersToContext(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecId, uint32_t AvCodecParamId); }; -class AVCodecReceiveFrame : public WasmEdgeFFmpegAVCodec { +class AVCodecReceiveFrame : public HostFunction { public: - AVCodecReceiveFrame(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t FrameId); }; -class AVCodecSendPacket : public WasmEdgeFFmpegAVCodec { +class AVCodecSendPacket : public HostFunction { public: - AVCodecSendPacket(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvCodecCtxId, uint32_t PacketId); }; -class AVCodecFindEncoder : public WasmEdgeFFmpegAVCodec { +class AVCodecFindEncoder : public HostFunction { public: - AVCodecFindEncoder(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t ID, uint32_t AVCodecPtr); }; -class AVCodecReceivePacket - : public WasmEdgeFFmpegAVCodec { +class AVCodecReceivePacket : public HostFunction { public: - AVCodecReceivePacket(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVCodecCtxId, uint32_t PacketId); }; -class AVCodecSendFrame : public WasmEdgeFFmpegAVCodec { +class AVCodecSendFrame : public HostFunction { public: - AVCodecSendFrame(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVCodecCtxId, uint32_t FrameId); }; -class AVCodecFindDecoderByName - : public WasmEdgeFFmpegAVCodec { +class AVCodecFindDecoderByName : public HostFunction { public: - AVCodecFindDecoderByName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVCodecPtr, uint32_t NamePtr, uint32_t NameLen); }; -class AVCodecFindEncoderByName - : public WasmEdgeFFmpegAVCodec { +class AVCodecFindEncoderByName : public HostFunction { public: - AVCodecFindEncoderByName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVCodecPtr, uint32_t NamePtr, uint32_t NameLen); }; -class AVPacketRescaleTs : public WasmEdgeFFmpegAVCodec { +class AVPacketRescaleTs : public HostFunction { public: - AVPacketRescaleTs(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVPacketId, int32_t SrcNum, int32_t SrcDen, int32_t DestNum, int32_t DestDen); }; -class AVPacketMakeWritable - : public WasmEdgeFFmpegAVCodec { +class AVPacketMakeWritable : public HostFunction { public: - AVPacketMakeWritable(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVPacketId); }; -class AVCodecParametersCopy - : public WasmEdgeFFmpegAVCodec { +class AVCodecParametersCopy : public HostFunction { public: - AVCodecParametersCopy(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVFormatCtxId, uint32_t AVCodecParamId, uint32_t StreamIdx); }; -class AVCodecVersion : public WasmEdgeFFmpegAVCodec { +class AVCodecVersion : public HostFunction { public: - AVCodecVersion(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVCodecFlushBuffers : public WasmEdgeFFmpegAVCodec { +class AVCodecFlushBuffers : public HostFunction { public: - AVCodecFlushBuffers(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVCodecCtxId); }; class AVCodecConfigurationLength - : public WasmEdgeFFmpegAVCodec { + : public HostFunction { public: - AVCodecConfigurationLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVCodecConfiguration - : public WasmEdgeFFmpegAVCodec { +class AVCodecConfiguration : public HostFunction { public: - AVCodecConfiguration(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen); }; -class AVCodecLicenseLength - : public WasmEdgeFFmpegAVCodec { +class AVCodecLicenseLength : public HostFunction { public: - AVCodecLicenseLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVCodecLicense : public WasmEdgeFFmpegAVCodec { +class AVCodecLicense : public HostFunction { public: - AVCodecLicense(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVCodec(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen); }; diff --git a/plugins/wasmedge_ffmpeg/avcodec/module.h b/plugins/wasmedge_ffmpeg/avcodec/module.h index 7bb8ff67..b7e63947 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/module.h +++ b/plugins/wasmedge_ffmpeg/avcodec/module.h @@ -4,6 +4,7 @@ #pragma once #include "ffmpeg_env.h" + #include "runtime/instance/module.h" namespace WasmEdge { diff --git a/plugins/wasmedge_ffmpeg/avdevice/avDevice_base.h b/plugins/wasmedge_ffmpeg/avdevice/avDevice_base.h deleted file mode 100644 index ae7c37a1..00000000 --- a/plugins/wasmedge_ffmpeg/avdevice/avDevice_base.h +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#pragma once - -#include "ffmpeg_env.h" -#include "runtime/hostfunc.h" - -namespace WasmEdge { -namespace Host { -namespace WasmEdgeFFmpeg { -namespace AVDevice { - -template -class WasmEdgeFFmpegAVDevice : public Runtime::HostFunction { -public: - WasmEdgeFFmpegAVDevice( - std::shared_ptr HostEnv) - : Runtime::HostFunction(0), Env(HostEnv) {} - -protected: - std::shared_ptr Env; -}; - -} // namespace AVDevice -} // namespace WasmEdgeFFmpeg -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h index cccfa303..6b5832f5 100644 --- a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h +++ b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.h @@ -3,123 +3,100 @@ #pragma once -#include "avDevice_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVDevice { -class AVDeviceRegisterAll : public WasmEdgeFFmpegAVDevice { +class AVDeviceRegisterAll : public HostFunction { public: - AVDeviceRegisterAll(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVDeviceVersion : public WasmEdgeFFmpegAVDevice { +class AVDeviceVersion : public HostFunction { public: - AVDeviceVersion(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVDeviceListDevices : public WasmEdgeFFmpegAVDevice { +class AVDeviceListDevices : public HostFunction { public: - AVDeviceListDevices(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVFormatCtxId, uint32_t AVDeviceInfoListPtr); }; -class AVInputAudioDeviceNext - : public WasmEdgeFFmpegAVDevice { +class AVInputAudioDeviceNext : public HostFunction { public: - AVInputAudioDeviceNext(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &); }; -class AVInputVideoDeviceNext - : public WasmEdgeFFmpegAVDevice { +class AVInputVideoDeviceNext : public HostFunction { public: - AVInputVideoDeviceNext(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &); }; -class AVOutputAudioDeviceNext - : public WasmEdgeFFmpegAVDevice { +class AVOutputAudioDeviceNext : public HostFunction { public: - AVOutputAudioDeviceNext(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &); }; -class AVOutputVideoDeviceNext - : public WasmEdgeFFmpegAVDevice { +class AVOutputVideoDeviceNext : public HostFunction { public: - AVOutputVideoDeviceNext(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &); }; -class AVDeviceFreeListDevices - : public WasmEdgeFFmpegAVDevice { +class AVDeviceFreeListDevices : public HostFunction { public: - AVDeviceFreeListDevices(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVDeviceInfoListId); }; -class AVDeviceNbDevices : public WasmEdgeFFmpegAVDevice { +class AVDeviceNbDevices : public HostFunction { public: - AVDeviceNbDevices(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVDeviceInfoListId); }; -class AVDeviceDefaultDevice - : public WasmEdgeFFmpegAVDevice { +class AVDeviceDefaultDevice : public HostFunction { public: - AVDeviceDefaultDevice(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVDeviceInfoListId); }; class AVDeviceConfigurationLength - : public WasmEdgeFFmpegAVDevice { + : public HostFunction { public: - AVDeviceConfigurationLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVDeviceConfiguration - : public WasmEdgeFFmpegAVDevice { +class AVDeviceConfiguration : public HostFunction { public: - AVDeviceConfiguration(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen); }; -class AVDeviceLicenseLength - : public WasmEdgeFFmpegAVDevice { +class AVDeviceLicenseLength : public HostFunction { public: - AVDeviceLicenseLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVDeviceLicense : public WasmEdgeFFmpegAVDevice { +class AVDeviceLicense : public HostFunction { public: - AVDeviceLicense(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVDevice(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen); }; diff --git a/plugins/wasmedge_ffmpeg/avdevice/module.h b/plugins/wasmedge_ffmpeg/avdevice/module.h index 748d4cae..26ed72df 100644 --- a/plugins/wasmedge_ffmpeg/avdevice/module.h +++ b/plugins/wasmedge_ffmpeg/avdevice/module.h @@ -4,6 +4,7 @@ #pragma once #include "ffmpeg_env.h" + #include "runtime/instance/module.h" namespace WasmEdge { diff --git a/plugins/wasmedge_ffmpeg/avfilter/avFilter.h b/plugins/wasmedge_ffmpeg/avfilter/avFilter.h index 58129a15..8141d3d7 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/avFilter.h +++ b/plugins/wasmedge_ffmpeg/avfilter/avFilter.h @@ -3,116 +3,99 @@ #pragma once -#include "avfilter_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVFilter { -class AVFilterNameLength : public WasmEdgeFFmpegAVFilter { +class AVFilterNameLength : public HostFunction { public: - AVFilterNameLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); }; -class AVFilterName : public WasmEdgeFFmpegAVFilter { +class AVFilterName : public HostFunction { public: - AVFilterName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId, uint32_t NamePtr, uint32_t NameLen); }; class AVFilterDescriptionLength - : public WasmEdgeFFmpegAVFilter { + : public HostFunction { public: - AVFilterDescriptionLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); }; -class AVFilterDescription : public WasmEdgeFFmpegAVFilter { +class AVFilterDescription : public HostFunction { public: - AVFilterDescription(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId, uint32_t DescPtr, uint32_t DescLen); }; -class AVFilterNbInputs : public WasmEdgeFFmpegAVFilter { +class AVFilterNbInputs : public HostFunction { public: - AVFilterNbInputs(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); }; -class AVFilterNbOutputs : public WasmEdgeFFmpegAVFilter { +class AVFilterNbOutputs : public HostFunction { public: - AVFilterNbOutputs(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); }; -class AVFilterFlags : public WasmEdgeFFmpegAVFilter { +class AVFilterFlags : public HostFunction { public: - AVFilterFlags(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); }; -class AVFilterInOutSetName - : public WasmEdgeFFmpegAVFilter { +class AVFilterInOutSetName : public HostFunction { public: - AVFilterInOutSetName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId, uint32_t NamePtr, uint32_t NameLen); }; class AVFilterInOutSetFilterCtx - : public WasmEdgeFFmpegAVFilter { + : public HostFunction { public: - AVFilterInOutSetFilterCtx(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId, uint32_t FilterCtxId); }; -class AVFilterInOutSetPadIdx - : public WasmEdgeFFmpegAVFilter { +class AVFilterInOutSetPadIdx : public HostFunction { public: - AVFilterInOutSetPadIdx(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId, int32_t PadIdx); }; -class AVFilterInOutSetNext - : public WasmEdgeFFmpegAVFilter { +class AVFilterInOutSetNext : public HostFunction { public: - AVFilterInOutSetNext(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId, uint32_t NextInOutId); }; class AVFilterGetInputsFilterPad - : public WasmEdgeFFmpegAVFilter { + : public HostFunction { public: - AVFilterGetInputsFilterPad(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId, uint32_t FilterPadPtr); }; class AVFilterGetOutputsFilterPad - : public WasmEdgeFFmpegAVFilter { + : public HostFunction { public: - AVFilterGetOutputsFilterPad(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId, uint32_t FilterPadPtr); }; diff --git a/plugins/wasmedge_ffmpeg/avfilter/avfilter_base.h b/plugins/wasmedge_ffmpeg/avfilter/avfilter_base.h deleted file mode 100644 index 64209977..00000000 --- a/plugins/wasmedge_ffmpeg/avfilter/avfilter_base.h +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#pragma once - -#include "ffmpeg_env.h" -#include "runtime/hostfunc.h" - -namespace WasmEdge { -namespace Host { -namespace WasmEdgeFFmpeg { -namespace AVFilter { - -template -class WasmEdgeFFmpegAVFilter : public Runtime::HostFunction { -public: - WasmEdgeFFmpegAVFilter( - std::shared_ptr HostEnv) - : Runtime::HostFunction(0), Env(HostEnv) {} - -protected: - std::shared_ptr Env; -}; - -} // namespace AVFilter -} // namespace WasmEdgeFFmpeg -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h index 18f037dc..6f8b486d 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h +++ b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.h @@ -3,203 +3,172 @@ #pragma once -#include "avfilter_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVFilter { -class AVFilterGraphAlloc : public WasmEdgeFFmpegAVFilter { +class AVFilterGraphAlloc : public HostFunction { public: - AVFilterGraphAlloc(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterGraphPtr); }; -class AVFilterGraphConfig : public WasmEdgeFFmpegAVFilter { +class AVFilterGraphConfig : public HostFunction { public: - AVFilterGraphConfig(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterGraphId); }; -class AVFilterGraphFree : public WasmEdgeFFmpegAVFilter { +class AVFilterGraphFree : public HostFunction { public: - AVFilterGraphFree(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterGraphId); }; -class AVFilterGraphGetFilter - : public WasmEdgeFFmpegAVFilter { +class AVFilterGraphGetFilter : public HostFunction { public: - AVFilterGraphGetFilter(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterCtxPtr, uint32_t FilterGraphId, uint32_t NamePtr, uint32_t NameSize); }; -class AVFilterGraphParsePtr - : public WasmEdgeFFmpegAVFilter { +class AVFilterGraphParsePtr : public HostFunction { public: - AVFilterGraphParsePtr(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterGraphId, uint32_t FiltersString, uint32_t FiltersSize, uint32_t InputsId, uint32_t OutputsId); }; -class AVFilterInOutFree : public WasmEdgeFFmpegAVFilter { +class AVFilterInOutFree : public HostFunction { public: - AVFilterInOutFree(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutId); }; -class AVFilterVersion : public WasmEdgeFFmpegAVFilter { +class AVFilterVersion : public HostFunction { public: - AVFilterVersion(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVFilterGetByName : public WasmEdgeFFmpegAVFilter { +class AVFilterGetByName : public HostFunction { public: - AVFilterGetByName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPtr, uint32_t StrPtr, uint32_t StrLen); }; class AVFilterConfigurationLength - : public WasmEdgeFFmpegAVFilter { + : public HostFunction { public: - AVFilterConfigurationLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVFilterConfiguration - : public WasmEdgeFFmpegAVFilter { +class AVFilterConfiguration : public HostFunction { public: - AVFilterConfiguration(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen); }; -class AVFilterLicenseLength - : public WasmEdgeFFmpegAVFilter { +class AVFilterLicenseLength : public HostFunction { public: - AVFilterLicenseLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVFilterLicense : public WasmEdgeFFmpegAVFilter { +class AVFilterLicense : public HostFunction { public: - AVFilterLicense(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen); }; class AVFilterGraphCreateFilter - : public WasmEdgeFFmpegAVFilter { + : public HostFunction { public: - AVFilterGraphCreateFilter(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterCtxPtr, uint32_t FilterId, uint32_t NamePtr, uint32_t NameLen, uint32_t ArgsPtr, uint32_t ArgsLen, uint32_t FilterGraphId); }; -class AVFilterInOutAlloc : public WasmEdgeFFmpegAVFilter { +class AVFilterInOutAlloc : public HostFunction { public: - AVFilterInOutAlloc(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t InOutPtr); }; -class AVFilterPadGetNameLength - : public WasmEdgeFFmpegAVFilter { +class AVFilterPadGetNameLength : public HostFunction { public: - AVFilterPadGetNameLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPadId, int32_t Idx); }; -class AVFilterPadGetName : public WasmEdgeFFmpegAVFilter { +class AVFilterPadGetName : public HostFunction { public: - AVFilterPadGetName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPadId, int32_t Idx, uint32_t NamePtr, uint32_t NameLen); }; -class AVFilterPadGetType : public WasmEdgeFFmpegAVFilter { +class AVFilterPadGetType : public HostFunction { public: - AVFilterPadGetType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPadId, int32_t Idx); }; -class AVFilterGraphDumpLength - : public WasmEdgeFFmpegAVFilter { +class AVFilterGraphDumpLength : public HostFunction { public: - AVFilterGraphDumpLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterGraphId); }; -class AVFilterGraphDump : public WasmEdgeFFmpegAVFilter { +class AVFilterGraphDump : public HostFunction { public: - AVFilterGraphDump(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterGraphId, uint32_t GraphStrPtr, uint32_t GraphStrLen); }; -class AVFilterFreeGraphStr - : public WasmEdgeFFmpegAVFilter { +class AVFilterFreeGraphStr : public HostFunction { public: - AVFilterFreeGraphStr(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterGraphId); }; -class AVFilterDrop : public WasmEdgeFFmpegAVFilter { +class AVFilterDrop : public HostFunction { public: - AVFilterDrop(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterId); }; -class AVFilterPadDrop : public WasmEdgeFFmpegAVFilter { +class AVFilterPadDrop : public HostFunction { public: - AVFilterPadDrop(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterPadId); }; -class AVFilterContextDrop : public WasmEdgeFFmpegAVFilter { +class AVFilterContextDrop : public HostFunction { public: - AVFilterContextDrop(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterCtxId); }; diff --git a/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h index 41144b9f..40db402b 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h +++ b/plugins/wasmedge_ffmpeg/avfilter/buffer_source_sink.h @@ -3,63 +3,53 @@ #pragma once -#include "avfilter_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVFilter { -class AVBufferSinkGetFrame - : public WasmEdgeFFmpegAVFilter { +class AVBufferSinkGetFrame : public HostFunction { public: - AVBufferSinkGetFrame(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterContextId, uint32_t FrameId); }; -class AVBufferSinkGetSamples - : public WasmEdgeFFmpegAVFilter { +class AVBufferSinkGetSamples : public HostFunction { public: - AVBufferSinkGetSamples(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterContextId, uint32_t FrameId, int32_t Samples); }; -class AvBufferSinkSetFrameSize - : public WasmEdgeFFmpegAVFilter { +class AvBufferSinkSetFrameSize : public HostFunction { public: - AvBufferSinkSetFrameSize(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterContextId, int32_t Value); }; class AVBufferSrcGetNbFailedRequests - : public WasmEdgeFFmpegAVFilter { + : public HostFunction { public: - AVBufferSrcGetNbFailedRequests(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterContextId); }; -class AVBufferSrcAddFrame : public WasmEdgeFFmpegAVFilter { +class AVBufferSrcAddFrame : public HostFunction { public: - AVBufferSrcAddFrame(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterContextId, uint32_t FrameId); }; -class AVBufferSrcClose : public WasmEdgeFFmpegAVFilter { +class AVBufferSrcClose : public HostFunction { public: - AVBufferSrcClose(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFilter(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FilterContextId, int64_t Pts, uint32_t Flags); }; diff --git a/plugins/wasmedge_ffmpeg/avfilter/module.h b/plugins/wasmedge_ffmpeg/avfilter/module.h index 7892a925..176704fa 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/module.h +++ b/plugins/wasmedge_ffmpeg/avfilter/module.h @@ -4,6 +4,7 @@ #pragma once #include "ffmpeg_env.h" + #include "runtime/instance/module.h" namespace WasmEdge { diff --git a/plugins/wasmedge_ffmpeg/avformat/avChapter.h b/plugins/wasmedge_ffmpeg/avformat/avChapter.h index 58ed4f54..822c7367 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avChapter.h +++ b/plugins/wasmedge_ffmpeg/avformat/avChapter.h @@ -3,98 +3,85 @@ #pragma once -#include "avformat_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVFormat { -class AVChapterId : public WasmEdgeFFmpegAVFormat { +class AVChapterId : public HostFunction { public: - AVChapterId(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t ChapterIdx); }; -class AVChapterSetId : public WasmEdgeFFmpegAVFormat { +class AVChapterSetId : public HostFunction { public: - AVChapterSetId(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t ChapterIdx, int64_t ChapterId); }; -class AVChapterTimebase : public WasmEdgeFFmpegAVFormat { +class AVChapterTimebase : public HostFunction { public: - AVChapterTimebase(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, uint32_t DenPtr, uint32_t AvFormatCtxId, uint32_t ChapterIdx); }; -class AVChapterSetTimebase - : public WasmEdgeFFmpegAVFormat { +class AVChapterSetTimebase : public HostFunction { public: - AVChapterSetTimebase(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t Num, int32_t Den, uint32_t AvFormatCtxId, uint32_t ChapterIdx); }; -class AVChapterStart : public WasmEdgeFFmpegAVFormat { +class AVChapterStart : public HostFunction { public: - AVChapterStart(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t ChapterIdx); }; -class AVChapterSetStart : public WasmEdgeFFmpegAVFormat { +class AVChapterSetStart : public HostFunction { public: - AVChapterSetStart(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t ChapterIdx, int64_t StartValue); }; -class AVChapterEnd : public WasmEdgeFFmpegAVFormat { +class AVChapterEnd : public HostFunction { public: - AVChapterEnd(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t ChapterIdx); }; -class AVChapterSetEnd : public WasmEdgeFFmpegAVFormat { +class AVChapterSetEnd : public HostFunction { public: - AVChapterSetEnd(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t ChapterIdx, int64_t EndValue); }; -class AVChapterMetadata : public WasmEdgeFFmpegAVFormat { +class AVChapterMetadata : public HostFunction { public: - AVChapterMetadata(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t ChapterIdx, uint32_t DictPtr); }; -class AVChapterSetMetadata - : public WasmEdgeFFmpegAVFormat { +class AVChapterSetMetadata : public HostFunction { public: - AVChapterSetMetadata(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t ChapterIdx, uint32_t DictId); diff --git a/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h index b75efd7d..8f01ba9d 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h +++ b/plugins/wasmedge_ffmpeg/avformat/avInputOutputFormat.h @@ -3,141 +3,116 @@ #pragma once -#include "avformat_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVFormat { -class AVIOFormatNameLength - : public WasmEdgeFFmpegAVFormat { +class AVIOFormatNameLength : public HostFunction { public: - AVIOFormatNameLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, uint32_t FormatType); }; -class AVInputFormatName : public WasmEdgeFFmpegAVFormat { +class AVInputFormatName : public HostFunction { public: - AVInputFormatName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId, uint32_t NamePtr, uint32_t NameLen); }; -class AVOutputFormatName : public WasmEdgeFFmpegAVFormat { +class AVOutputFormatName : public HostFunction { public: - AVOutputFormatName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId, uint32_t NamePtr, uint32_t NameLen); }; -class AVIOFormatLongNameLength - : public WasmEdgeFFmpegAVFormat { +class AVIOFormatLongNameLength : public HostFunction { public: - AVIOFormatLongNameLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, uint32_t FormatType); }; -class AVInputFormatLongName - : public WasmEdgeFFmpegAVFormat { +class AVInputFormatLongName : public HostFunction { public: - AVInputFormatLongName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId, uint32_t LongNamePtr, uint32_t LongNameLen); }; -class AVOutputFormatLongName - : public WasmEdgeFFmpegAVFormat { +class AVOutputFormatLongName : public HostFunction { public: - AVOutputFormatLongName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId, uint32_t LongNamePtr, uint32_t LongNameLen); }; class AVIOFormatExtensionsLength - : public WasmEdgeFFmpegAVFormat { + : public HostFunction { public: - AVIOFormatExtensionsLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, uint32_t FormatType); }; -class AVInputFormatExtensions - : public WasmEdgeFFmpegAVFormat { +class AVInputFormatExtensions : public HostFunction { public: - AVInputFormatExtensions(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId, uint32_t Extensions, uint32_t ExtensionsLen); }; -class AVOutputFormatExtensions - : public WasmEdgeFFmpegAVFormat { +class AVOutputFormatExtensions : public HostFunction { public: - AVOutputFormatExtensions(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId, uint32_t Extensions, uint32_t ExtensionsLen); }; -class AVIOFormatMimeTypeLength - : public WasmEdgeFFmpegAVFormat { +class AVIOFormatMimeTypeLength : public HostFunction { public: - AVIOFormatMimeTypeLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AVIOFormatId, uint32_t FormatType); }; -class AVInputFormatMimeType - : public WasmEdgeFFmpegAVFormat { +class AVInputFormatMimeType : public HostFunction { public: - AVInputFormatMimeType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId, uint32_t MimeTypePtr, uint32_t MimeTypeLen); }; -class AVOutputFormatMimeType - : public WasmEdgeFFmpegAVFormat { +class AVOutputFormatMimeType : public HostFunction { public: - AVOutputFormatMimeType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId, uint32_t MimeTypePtr, uint32_t MimeTypeLen); }; -class AVOutputFormatFlags : public WasmEdgeFFmpegAVFormat { +class AVOutputFormatFlags : public HostFunction { public: - AVOutputFormatFlags(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVInputFormatId); }; -class AVInputOutputFormatFree - : public WasmEdgeFFmpegAVFormat { +class AVInputOutputFormatFree : public HostFunction { public: - AVInputOutputFormatFree(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVInputOutputId); }; diff --git a/plugins/wasmedge_ffmpeg/avformat/avStream.h b/plugins/wasmedge_ffmpeg/avformat/avStream.h index 552bb206..282d23b0 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avStream.h +++ b/plugins/wasmedge_ffmpeg/avformat/avStream.h @@ -3,148 +3,128 @@ #pragma once -#include "avformat_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVFormat { -class AVStreamId : public WasmEdgeFFmpegAVFormat { +class AVStreamId : public HostFunction { public: - AVStreamId(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; -class AVStreamIndex : public WasmEdgeFFmpegAVFormat { +class AVStreamIndex : public HostFunction { public: - AVStreamIndex(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; -class AVStreamCodecPar : public WasmEdgeFFmpegAVFormat { +class AVStreamCodecPar : public HostFunction { public: - AVStreamCodecPar(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t StreamIdx, uint32_t CodecParameterPtr); }; -class AVStreamTimebase : public WasmEdgeFFmpegAVFormat { +class AVStreamTimebase : public HostFunction { public: - AVStreamTimebase(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, uint32_t DenPtr, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; -class AVStreamSetTimebase : public WasmEdgeFFmpegAVFormat { +class AVStreamSetTimebase : public HostFunction { public: - AVStreamSetTimebase(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t Num, uint32_t Den, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; -class AVStreamDuration : public WasmEdgeFFmpegAVFormat { +class AVStreamDuration : public HostFunction { public: - AVStreamDuration(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; -class AVStreamStartTime : public WasmEdgeFFmpegAVFormat { +class AVStreamStartTime : public HostFunction { public: - AVStreamStartTime(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; -class AVStreamNbFrames : public WasmEdgeFFmpegAVFormat { +class AVStreamNbFrames : public HostFunction { public: - AVStreamNbFrames(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; -class AVStreamDisposition : public WasmEdgeFFmpegAVFormat { +class AVStreamDisposition : public HostFunction { public: - AVStreamDisposition(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; -class AVStreamRFrameRate : public WasmEdgeFFmpegAVFormat { +class AVStreamRFrameRate : public HostFunction { public: - AVStreamRFrameRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, uint32_t DenPtr, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; -class AVStreamSetRFrameRate - : public WasmEdgeFFmpegAVFormat { +class AVStreamSetRFrameRate : public HostFunction { public: - AVStreamSetRFrameRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t Num, int32_t Den, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; -class AVStreamAvgFrameRate - : public WasmEdgeFFmpegAVFormat { +class AVStreamAvgFrameRate : public HostFunction { public: - AVStreamAvgFrameRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t NumPtr, uint32_t DenPtr, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; -class AVStreamSetAvgFrameRate - : public WasmEdgeFFmpegAVFormat { +class AVStreamSetAvgFrameRate : public HostFunction { public: - AVStreamSetAvgFrameRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t Num, int32_t Den, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; -class AVStreamMetadata : public WasmEdgeFFmpegAVFormat { +class AVStreamMetadata : public HostFunction { public: - AVStreamMetadata(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t StreamIdx, uint32_t DictPtr); }; -class AVStreamSetMetadata : public WasmEdgeFFmpegAVFormat { +class AVStreamSetMetadata : public HostFunction { public: - AVStreamSetMetadata(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t StreamIdx, uint32_t DictId); }; -class AVStreamDiscard : public WasmEdgeFFmpegAVFormat { +class AVStreamDiscard : public HostFunction { public: - AVStreamDiscard(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t StreamIdx); }; diff --git a/plugins/wasmedge_ffmpeg/avformat/avformatContext.h b/plugins/wasmedge_ffmpeg/avformat/avformatContext.h index 3f131944..a0d3b1b2 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avformatContext.h +++ b/plugins/wasmedge_ffmpeg/avformat/avformatContext.h @@ -3,95 +3,79 @@ #pragma once -#include "avformat_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVFormat { -class AVFormatCtxIFormat : public WasmEdgeFFmpegAVFormat { +class AVFormatCtxIFormat : public HostFunction { public: - AVFormatCtxIFormat(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t AvInputFormatPtr); }; -class AVFormatCtxOFormat : public WasmEdgeFFmpegAVFormat { +class AVFormatCtxOFormat : public HostFunction { public: - AVFormatCtxOFormat(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t AvOutputFormatPtr); }; -class AVFormatCtxProbeScore - : public WasmEdgeFFmpegAVFormat { +class AVFormatCtxProbeScore : public HostFunction { public: - AVFormatCtxProbeScore(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId); }; -class AVFormatCtxNbStreams - : public WasmEdgeFFmpegAVFormat { +class AVFormatCtxNbStreams : public HostFunction { public: - AVFormatCtxNbStreams(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId); }; -class AVFormatCtxBitRate : public WasmEdgeFFmpegAVFormat { +class AVFormatCtxBitRate : public HostFunction { public: - AVFormatCtxBitRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId); }; -class AVFormatCtxDuration : public WasmEdgeFFmpegAVFormat { +class AVFormatCtxDuration : public HostFunction { public: - AVFormatCtxDuration(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId); }; -class AVFormatCtxNbChapters - : public WasmEdgeFFmpegAVFormat { +class AVFormatCtxNbChapters : public HostFunction { public: - AVFormatCtxNbChapters(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId); }; -class AVFormatCtxSetNbChapters - : public WasmEdgeFFmpegAVFormat { +class AVFormatCtxSetNbChapters : public HostFunction { public: - AVFormatCtxSetNbChapters(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t NbChapters); }; -class AVFormatCtxMetadata : public WasmEdgeFFmpegAVFormat { +class AVFormatCtxMetadata : public HostFunction { public: - AVFormatCtxMetadata(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t DictPtr); }; -class AVFormatCtxSetMetadata - : public WasmEdgeFFmpegAVFormat { +class AVFormatCtxSetMetadata : public HostFunction { public: - AVFormatCtxSetMetadata(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t DictId); }; diff --git a/plugins/wasmedge_ffmpeg/avformat/avformat_base.h b/plugins/wasmedge_ffmpeg/avformat/avformat_base.h deleted file mode 100644 index a53a92df..00000000 --- a/plugins/wasmedge_ffmpeg/avformat/avformat_base.h +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#pragma once - -#include "ffmpeg_env.h" -#include "runtime/hostfunc.h" - -namespace WasmEdge { -namespace Host { -namespace WasmEdgeFFmpeg { -namespace AVFormat { - -template -class WasmEdgeFFmpegAVFormat : public Runtime::HostFunction { -public: - WasmEdgeFFmpegAVFormat( - std::shared_ptr HostEnv) - : Runtime::HostFunction(0), Env(HostEnv) {} - -protected: - std::shared_ptr Env; -}; - -} // namespace AVFormat -} // namespace WasmEdgeFFmpeg -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avformat/avformat_func.h b/plugins/wasmedge_ffmpeg/avformat/avformat_func.h index 30d8a52b..533017b6 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avformat_func.h +++ b/plugins/wasmedge_ffmpeg/avformat/avformat_func.h @@ -3,233 +3,197 @@ #pragma once -#include "avformat_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVFormat { -class AVFormatOpenInput : public WasmEdgeFFmpegAVFormat { +class AVFormatOpenInput : public HostFunction { public: - AVFormatOpenInput(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxPtr, uint32_t UrlPtr, uint32_t UrlSize, uint32_t AvInputFormatId, uint32_t AvDictionaryId); }; -class AVFormatFindStreamInfo - : public WasmEdgeFFmpegAVFormat { +class AVFormatFindStreamInfo : public HostFunction { public: - AVFormatFindStreamInfo(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t AvDictionaryId); }; -class AVFormatCloseInput : public WasmEdgeFFmpegAVFormat { +class AVFormatCloseInput : public HostFunction { public: - AVFormatCloseInput(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId); }; -class AVReadPause : public WasmEdgeFFmpegAVFormat { +class AVReadPause : public HostFunction { public: - AVReadPause(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId); }; -class AVReadPlay : public WasmEdgeFFmpegAVFormat { +class AVReadPlay : public HostFunction { public: - AVReadPlay(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId); }; -class AVFormatSeekFile : public WasmEdgeFFmpegAVFormat { +class AVFormatSeekFile : public HostFunction { public: - AVFormatSeekFile(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t StreamIdx, int64_t MinTs, int64_t Ts, int64_t MaxTs, int32_t Flags); }; -class AVDumpFormat : public WasmEdgeFFmpegAVFormat { +class AVDumpFormat : public HostFunction { public: - AVDumpFormat(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, int32_t Idx, uint32_t UrlPtr, uint32_t UrlSize, int32_t IsOutput); }; -class AVFormatFreeContext : public WasmEdgeFFmpegAVFormat { +class AVFormatFreeContext : public HostFunction { public: - AVFormatFreeContext(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxPtr); }; -class AVFindBestStream : public WasmEdgeFFmpegAVFormat { +class AVFindBestStream : public HostFunction { public: - AVFindBestStream(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, int32_t MediaTypeId, int32_t WantedStream, int32_t RelatedStream, uint32_t DecoderRetId, int32_t Flags); }; -class AVReadFrame : public WasmEdgeFFmpegAVFormat { +class AVReadFrame : public HostFunction { public: - AVReadFrame(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t PacketId); }; -class AVIOClose : public WasmEdgeFFmpegAVFormat { +class AVIOClose : public HostFunction { public: - AVIOClose(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId); }; -class AVFormatNetworkInit : public WasmEdgeFFmpegAVFormat { +class AVFormatNetworkInit : public HostFunction { public: - AVFormatNetworkInit(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVFormatNetworkDeInit - : public WasmEdgeFFmpegAVFormat { +class AVFormatNetworkDeInit : public HostFunction { public: - AVFormatNetworkDeInit(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVFormatWriteHeader : public WasmEdgeFFmpegAVFormat { +class AVFormatWriteHeader : public HostFunction { public: - AVFormatWriteHeader(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t DictId); }; -class AVFormatWriteTrailer - : public WasmEdgeFFmpegAVFormat { +class AVFormatWriteTrailer : public HostFunction { public: - AVFormatWriteTrailer(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId); }; class AVFormatAllocOutputContext2 - : public WasmEdgeFFmpegAVFormat { + : public HostFunction { public: - AVFormatAllocOutputContext2(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxPtr, uint32_t AVOutputFormatId, uint32_t FormatNamePtr, uint32_t FormatLen, uint32_t FileNamePtr, uint32_t FileNameLen); }; -class AVIOOpen : public WasmEdgeFFmpegAVFormat { +class AVIOOpen : public HostFunction { public: - AVIOOpen(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxId, uint32_t FileNamePtr, uint32_t FileNameLen, int32_t Flags); }; -class AVIOOpen2 : public WasmEdgeFFmpegAVFormat { +class AVIOOpen2 : public HostFunction { public: - AVIOOpen2(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvFormatCtxtId, uint32_t UrlPtr, uint32_t UrlLen, int32_t Flags, uint32_t AVIOInterruptCBId, uint32_t AVDictionaryId); }; -class AVFormatVersion : public WasmEdgeFFmpegAVFormat { +class AVFormatVersion : public HostFunction { public: - AVFormatVersion(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVChapterMallocz : public WasmEdgeFFmpegAVFormat { +class AVChapterMallocz : public HostFunction { public: - AVChapterMallocz(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVChapterPtr); }; -class AVChapterDynarrayAdd - : public WasmEdgeFFmpegAVFormat { +class AVChapterDynarrayAdd : public HostFunction { public: - AVChapterDynarrayAdd(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} - + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, int32_t NbChaptersPtr, uint32_t AvChapterId); }; -class AVFreeP : public WasmEdgeFFmpegAVFormat { +class AVFreeP : public HostFunction { public: - AVFreeP(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} - + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AvChapterId); }; -class AVInterleavedWriteFrame - : public WasmEdgeFFmpegAVFormat { +class AVInterleavedWriteFrame : public HostFunction { public: - AVInterleavedWriteFrame(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} - + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t AvPacketId); }; -class AVWriteFrame : public WasmEdgeFFmpegAVFormat { +class AVWriteFrame : public HostFunction { public: - AVWriteFrame(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} - + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t AvFormatCtxId, uint32_t AvPacketId); }; -class AVFormatNewStream : public WasmEdgeFFmpegAVFormat { +class AVFormatNewStream : public HostFunction { public: - AVFormatNewStream(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVFormatCtxId, uint32_t AVCodecId); }; -class AVGuessCodec : public WasmEdgeFFmpegAVFormat { +class AVGuessCodec : public HostFunction { public: - AVGuessCodec(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AVInputOutputId, uint32_t ShortNamePtr, uint32_t ShortNameLen, uint32_t FileNamePtr, @@ -238,34 +202,28 @@ class AVGuessCodec : public WasmEdgeFFmpegAVFormat { }; class AVFormatConfigurationLength - : public WasmEdgeFFmpegAVFormat { + : public HostFunction { public: - AVFormatConfigurationLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVFormatConfiguration - : public WasmEdgeFFmpegAVFormat { +class AVFormatConfiguration : public HostFunction { public: - AVFormatConfiguration(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen); }; -class AVFormatLicenseLength - : public WasmEdgeFFmpegAVFormat { +class AVFormatLicenseLength : public HostFunction { public: - AVFormatLicenseLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVFormatLicense : public WasmEdgeFFmpegAVFormat { +class AVFormatLicense : public HostFunction { public: - AVFormatLicense(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVFormat(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen); }; diff --git a/plugins/wasmedge_ffmpeg/avformat/module.h b/plugins/wasmedge_ffmpeg/avformat/module.h index eae89e3a..47b604d9 100644 --- a/plugins/wasmedge_ffmpeg/avformat/module.h +++ b/plugins/wasmedge_ffmpeg/avformat/module.h @@ -4,6 +4,7 @@ #pragma once #include "ffmpeg_env.h" + #include "runtime/instance/module.h" namespace WasmEdge { diff --git a/plugins/wasmedge_ffmpeg/avutil/avDictionary.h b/plugins/wasmedge_ffmpeg/avutil/avDictionary.h index b8732c76..c1731a33 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avDictionary.h +++ b/plugins/wasmedge_ffmpeg/avutil/avDictionary.h @@ -3,37 +3,33 @@ #pragma once -#include "avutil_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVUtil { -class AVDictSet : public WasmEdgeFFmpegAVUtil { +class AVDictSet : public HostFunction { public: - AVDictSet(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t DictId, uint32_t KeyPtr, uint32_t KeyLen, uint32_t ValuePtr, uint32_t ValueLen, int32_t Flags); }; -class AVDictGet : public WasmEdgeFFmpegAVUtil { +class AVDictGet : public HostFunction { public: - AVDictGet(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t DictId, uint32_t KeyPtr, uint32_t KeyLen, uint32_t PrevDictEntryIdx, uint32_t Flags, uint32_t KeyLenPtr, uint32_t ValueLenPtr); }; -class AVDictGetKeyValue : public WasmEdgeFFmpegAVUtil { +class AVDictGetKeyValue : public HostFunction { public: - AVDictGetKeyValue(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t DictId, uint32_t KeyPtr, uint32_t KeyLen, uint32_t ValBufPtr, uint32_t ValBufLen, uint32_t KeyBufPtr, @@ -41,18 +37,16 @@ class AVDictGetKeyValue : public WasmEdgeFFmpegAVUtil { uint32_t Flags); }; -class AVDictCopy : public WasmEdgeFFmpegAVUtil { +class AVDictCopy : public HostFunction { public: - AVDictCopy(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t DestDictId, uint32_t SrcDictId, uint32_t Flags); }; -class AVDictFree : public WasmEdgeFFmpegAVUtil { +class AVDictFree : public HostFunction { public: - AVDictFree(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t DictId); }; diff --git a/plugins/wasmedge_ffmpeg/avutil/avFrame.h b/plugins/wasmedge_ffmpeg/avutil/avFrame.h index 39cb732e..9e4a3639 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avFrame.h +++ b/plugins/wasmedge_ffmpeg/avutil/avFrame.h @@ -3,255 +3,217 @@ #pragma once -#include "avutil_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVUtil { -class AVFrameAlloc : public WasmEdgeFFmpegAVUtil { +class AVFrameAlloc : public HostFunction { public: - AVFrameAlloc(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FramePtr); }; -class AVFrameFree : public WasmEdgeFFmpegAVUtil { +class AVFrameFree : public HostFunction { public: - AVFrameFree(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameWidth : public WasmEdgeFFmpegAVUtil { +class AVFrameWidth : public HostFunction { public: - AVFrameWidth(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameHeight : public WasmEdgeFFmpegAVUtil { +class AVFrameHeight : public HostFunction { public: - AVFrameHeight(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameSetWidth : public WasmEdgeFFmpegAVUtil { +class AVFrameSetWidth : public HostFunction { public: - AVFrameSetWidth(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint32_t Width); }; -class AVFrameSetHeight : public WasmEdgeFFmpegAVUtil { +class AVFrameSetHeight : public HostFunction { public: - AVFrameSetHeight(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint32_t Height); }; -class AVFrameVideoFormat : public WasmEdgeFFmpegAVUtil { +class AVFrameVideoFormat : public HostFunction { public: - AVFrameVideoFormat(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameSetVideoFormat - : public WasmEdgeFFmpegAVUtil { +class AVFrameSetVideoFormat : public HostFunction { public: - AVFrameSetVideoFormat(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint32_t AvPixFormatId); }; -class AVFrameIsNull : public WasmEdgeFFmpegAVUtil { +class AVFrameIsNull : public HostFunction { public: - AVFrameIsNull(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameLinesize : public WasmEdgeFFmpegAVUtil { +class AVFrameLinesize : public HostFunction { public: - AVFrameLinesize(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint32_t Idx); }; -class AVFrameData : public WasmEdgeFFmpegAVUtil { +class AVFrameData : public HostFunction { public: - AVFrameData(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint32_t FrameBufPtr, uint32_t FrameBufLen, uint32_t Index); }; -class AVFrameGetBuffer : public WasmEdgeFFmpegAVUtil { +class AVFrameGetBuffer : public HostFunction { public: - AVFrameGetBuffer(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, int32_t Align); }; -class AVFrameAudioFormat : public WasmEdgeFFmpegAVUtil { +class AVFrameAudioFormat : public HostFunction { public: - AVFrameAudioFormat(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameSetAudioFormat - : public WasmEdgeFFmpegAVUtil { +class AVFrameSetAudioFormat : public HostFunction { public: - AVFrameSetAudioFormat(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint32_t SampleFormatId); }; -class AVFrameSetChannelLayout - : public WasmEdgeFFmpegAVUtil { +class AVFrameSetChannelLayout : public HostFunction { public: - AVFrameSetChannelLayout(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint64_t ChannelLayoutID); }; -class AVFrameSetNbSamples : public WasmEdgeFFmpegAVUtil { +class AVFrameSetNbSamples : public HostFunction { public: - AVFrameSetNbSamples(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, int32_t Samples); }; -class AVFrameNbSamples : public WasmEdgeFFmpegAVUtil { +class AVFrameNbSamples : public HostFunction { public: - AVFrameNbSamples(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameSampleRate : public WasmEdgeFFmpegAVUtil { +class AVFrameSampleRate : public HostFunction { public: - AVFrameSampleRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameSetSampleRate : public WasmEdgeFFmpegAVUtil { +class AVFrameSetSampleRate : public HostFunction { public: - AVFrameSetSampleRate(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, int32_t SampleRate); }; -class AVFrameChannels : public WasmEdgeFFmpegAVUtil { +class AVFrameChannels : public HostFunction { public: - AVFrameChannels(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameSetChannels : public WasmEdgeFFmpegAVUtil { +class AVFrameSetChannels : public HostFunction { public: - AVFrameSetChannels(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, int32_t Channels); }; -class AVFrameChannelLayout : public WasmEdgeFFmpegAVUtil { +class AVFrameChannelLayout : public HostFunction { public: - AVFrameChannelLayout(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; class AVFrameBestEffortTimestamp - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVFrameBestEffortTimestamp(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFramePictType : public WasmEdgeFFmpegAVUtil { +class AVFramePictType : public HostFunction { public: - AVFramePictType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameSetPictType : public WasmEdgeFFmpegAVUtil { +class AVFrameSetPictType : public HostFunction { public: - AVFrameSetPictType(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, int32_t PictureId); }; -class AVFrameInterlacedFrame - : public WasmEdgeFFmpegAVUtil { +class AVFrameInterlacedFrame : public HostFunction { public: - AVFrameInterlacedFrame(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameTopFieldFirst : public WasmEdgeFFmpegAVUtil { +class AVFrameTopFieldFirst : public HostFunction { public: - AVFrameTopFieldFirst(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFramePaletteHasChanged - : public WasmEdgeFFmpegAVUtil { +class AVFramePaletteHasChanged : public HostFunction { public: - AVFramePaletteHasChanged(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameColorSpace : public WasmEdgeFFmpegAVUtil { +class AVFrameColorSpace : public HostFunction { public: - AVFrameColorSpace(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameSetColorSpace : public WasmEdgeFFmpegAVUtil { +class AVFrameSetColorSpace : public HostFunction { public: - AVFrameSetColorSpace(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, int32_t ColorSpaceId); }; -class AVFrameColorRange : public WasmEdgeFFmpegAVUtil { +class AVFrameColorRange : public HostFunction { public: - AVFrameColorRange(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameSetColorRange : public WasmEdgeFFmpegAVUtil { +class AVFrameSetColorRange : public HostFunction { public: - AVFrameSetColorRange(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, int32_t ColorRangeId); }; @@ -259,144 +221,121 @@ class AVFrameSetColorRange : public WasmEdgeFFmpegAVUtil { // color_transfer_characteristic class AVFrameColorTransferCharacteristic - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVFrameColorTransferCharacteristic(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; class AVFrameSetColorTransferCharacteristic - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVFrameSetColorTransferCharacteristic( - std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, int32_t ColorTransferCharacteristicId); }; -class AVFrameChromaLocation - : public WasmEdgeFFmpegAVUtil { +class AVFrameChromaLocation : public HostFunction { public: - AVFrameChromaLocation(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; class AVFrameCodedPictureNumber - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVFrameCodedPictureNumber(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; class AVFrameDisplayPictureNumber - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVFrameDisplayPictureNumber(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameRepeatPict : public WasmEdgeFFmpegAVUtil { +class AVFrameRepeatPict : public HostFunction { public: - AVFrameRepeatPict(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameFlags : public WasmEdgeFFmpegAVUtil { +class AVFrameFlags : public HostFunction { public: - AVFrameFlags(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameQuality : public WasmEdgeFFmpegAVUtil { +class AVFrameQuality : public HostFunction { public: - AVFrameQuality(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameMetadata : public WasmEdgeFFmpegAVUtil { +class AVFrameMetadata : public HostFunction { public: - AVFrameMetadata(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint32_t DictPtr); }; -class AVFrameSetMetadata : public WasmEdgeFFmpegAVUtil { +class AVFrameSetMetadata : public HostFunction { public: - AVFrameSetMetadata(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint32_t DictId); }; -class AVFrameKeyFrame : public WasmEdgeFFmpegAVUtil { +class AVFrameKeyFrame : public HostFunction { public: - AVFrameKeyFrame(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFramePts : public WasmEdgeFFmpegAVUtil { +class AVFramePts : public HostFunction { public: - AVFramePts(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameSetPts : public WasmEdgeFFmpegAVUtil { +class AVFrameSetPts : public HostFunction { public: - AVFrameSetPts(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, int64_t Pts); }; -class AVFrameCopy : public WasmEdgeFFmpegAVUtil { +class AVFrameCopy : public HostFunction { public: - AVFrameCopy(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t DestFrameId, uint32_t SrcFrameId); }; -class AVFrameCopyProps : public WasmEdgeFFmpegAVUtil { +class AVFrameCopyProps : public HostFunction { public: - AVFrameCopyProps(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t DestFrameId, uint32_t SrcFrameId); }; -class AVFrameSampleAspectRatio - : public WasmEdgeFFmpegAVUtil { +class AVFrameSampleAspectRatio : public HostFunction { public: - AVFrameSampleAspectRatio(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, uint32_t NumPtr, uint32_t DenPtr); }; -class AVFrameColorPrimaries - : public WasmEdgeFFmpegAVUtil { +class AVFrameColorPrimaries : public HostFunction { public: - AVFrameColorPrimaries(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameSetColorPrimaries - : public WasmEdgeFFmpegAVUtil { +class AVFrameSetColorPrimaries : public HostFunction { public: - AVFrameSetColorPrimaries(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId, int32_t ColorPrimariesId); }; diff --git a/plugins/wasmedge_ffmpeg/avutil/avRational.h b/plugins/wasmedge_ffmpeg/avutil/avRational.h index f37195da..82866940 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avRational.h +++ b/plugins/wasmedge_ffmpeg/avutil/avRational.h @@ -3,103 +3,91 @@ #pragma once -#include "avutil_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVUtil { -class AVAddQ : public WasmEdgeFFmpegAVUtil { +class AVAddQ : public HostFunction { public: - AVAddQ(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen, int32_t BNum, int32_t BDen, uint32_t CNumPtr, uint32_t CDenPtr); }; -class AVSubQ : public WasmEdgeFFmpegAVUtil { +class AVSubQ : public HostFunction { public: - AVSubQ(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen, int32_t BNum, int32_t BDen, uint32_t CNumPtr, uint32_t CDenPtr); }; -class AVMulQ : public WasmEdgeFFmpegAVUtil { +class AVMulQ : public HostFunction { public: - AVMulQ(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen, int32_t BNum, int32_t BDen, uint32_t CNumPtr, uint32_t CDenPtr); }; -class AVDivQ : public WasmEdgeFFmpegAVUtil { +class AVDivQ : public HostFunction { public: - AVDivQ(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen, int32_t BNum, int32_t BDen, uint32_t CNumPtr, uint32_t CDenPtr); }; -class AVCmpQ : public WasmEdgeFFmpegAVUtil { +class AVCmpQ : public HostFunction { public: - AVCmpQ(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen, int32_t BNum, int32_t BDen); }; -class AVNearerQ : public WasmEdgeFFmpegAVUtil { +class AVNearerQ : public HostFunction { public: - AVNearerQ(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen, int32_t BNum, int32_t BDen, int32_t CNum, int32_t CDen); }; -class AVQ2d : public WasmEdgeFFmpegAVUtil { +class AVQ2d : public HostFunction { public: - AVQ2d(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen); }; -class AVD2Q : public WasmEdgeFFmpegAVUtil { +class AVD2Q : public HostFunction { public: - AVD2Q(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, double_t D, int32_t Max, uint32_t ANumPtr, uint32_t ADenPtr); }; -class AVQ2IntFloat : public WasmEdgeFFmpegAVUtil { +class AVQ2IntFloat : public HostFunction { public: - AVQ2IntFloat(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen); }; -class AVInvQ : public WasmEdgeFFmpegAVUtil { +class AVInvQ : public HostFunction { public: - AVInvQ(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ANum, int32_t ADen, uint32_t BNumPtr, uint32_t BDenPtr); }; -class AVReduce : public WasmEdgeFFmpegAVUtil { +class AVReduce : public HostFunction { public: - AVReduce(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t ANumPtr, uint32_t ADenPtr, int64_t BNum, int64_t BDen, int64_t Max); diff --git a/plugins/wasmedge_ffmpeg/avutil/avTime.h b/plugins/wasmedge_ffmpeg/avutil/avTime.h index 6a841844..6ec6e2c6 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avTime.h +++ b/plugins/wasmedge_ffmpeg/avutil/avTime.h @@ -3,40 +3,35 @@ #pragma once -#include "avutil_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVUtil { -class AVGetTime : public WasmEdgeFFmpegAVUtil { +class AVGetTime : public HostFunction { public: - AVGetTime(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVGetTimeRelative : public WasmEdgeFFmpegAVUtil { +class AVGetTimeRelative : public HostFunction { public: - AVGetTimeRelative(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; class AVGetTimeRelativeIsMonotonic - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVGetTimeRelativeIsMonotonic(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVUSleep : public WasmEdgeFFmpegAVUtil { +class AVUSleep : public HostFunction { public: - AVUSleep(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t USec); }; diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_base.h b/plugins/wasmedge_ffmpeg/avutil/avutil_base.h deleted file mode 100644 index 851f692c..00000000 --- a/plugins/wasmedge_ffmpeg/avutil/avutil_base.h +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#pragma once - -#include "ffmpeg_env.h" -#include "runtime/hostfunc.h" - -namespace WasmEdge { -namespace Host { -namespace WasmEdgeFFmpeg { -namespace AVUtil { - -template -class WasmEdgeFFmpegAVUtil : public Runtime::HostFunction { -public: - WasmEdgeFFmpegAVUtil( - std::shared_ptr HostEnv) - : Runtime::HostFunction(0), Env(HostEnv) {} - -protected: - std::shared_ptr Env; -}; - -} // namespace AVUtil -} // namespace WasmEdgeFFmpeg -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_func.h b/plugins/wasmedge_ffmpeg/avutil/avutil_func.h index 5532eb59..05a2c7b5 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avutil_func.h +++ b/plugins/wasmedge_ffmpeg/avutil/avutil_func.h @@ -3,205 +3,175 @@ #pragma once -#include "avutil_base.h" - -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVUtil { -class AVLogSetLevel : public WasmEdgeFFmpegAVUtil { +class AVLogSetLevel : public HostFunction { public: - AVLogSetLevel(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t LogLevelId); }; -class AVLogGetLevel : public WasmEdgeFFmpegAVUtil { +class AVLogGetLevel : public HostFunction { public: - AVLogGetLevel(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVLogSetFlags : public WasmEdgeFFmpegAVUtil { +class AVLogSetFlags : public HostFunction { public: - AVLogSetFlags(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t FlagsId); }; -class AVLogGetFlags : public WasmEdgeFFmpegAVUtil { +class AVLogGetFlags : public HostFunction { public: - AVLogGetFlags(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; // Option funcs. -class AVOptSetBin : public WasmEdgeFFmpegAVUtil { +class AVOptSetBin : public HostFunction { public: - AVOptSetBin(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVOptSet : public WasmEdgeFFmpegAVUtil { +class AVOptSet : public HostFunction { public: - AVOptSet(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVOptSetInt : public WasmEdgeFFmpegAVUtil { +class AVOptSetInt : public HostFunction { public: - AVOptSetInt(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVOptSetDouble : public WasmEdgeFFmpegAVUtil { +class AVOptSetDouble : public HostFunction { public: - AVOptSetDouble(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVOptSetQ : public WasmEdgeFFmpegAVUtil { +class AVOptSetQ : public HostFunction { public: - AVOptSetQ(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVOptSetImageSize : public WasmEdgeFFmpegAVUtil { +class AVOptSetImageSize : public HostFunction { public: - AVOptSetImageSize(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVOptSetPixelFmt : public WasmEdgeFFmpegAVUtil { +class AVOptSetPixelFmt : public HostFunction { public: - AVOptSetPixelFmt(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVOptSetSampleFmt : public WasmEdgeFFmpegAVUtil { +class AVOptSetSampleFmt : public HostFunction { public: - AVOptSetSampleFmt(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVOptSetChannelLayout - : public WasmEdgeFFmpegAVUtil { +class AVOptSetChannelLayout : public HostFunction { public: - AVOptSetChannelLayout(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVRescaleQ : public WasmEdgeFFmpegAVUtil { +class AVRescaleQ : public HostFunction { public: - AVRescaleQ(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int64_t A, int32_t BNum, int32_t BDen, int32_t CNum, int32_t CDen); }; -class AVRescaleQRnd : public WasmEdgeFFmpegAVUtil { +class AVRescaleQRnd : public HostFunction { public: - AVRescaleQRnd(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, int64_t A, int32_t BNum, int32_t BDen, int32_t CNum, int32_t CDen, int32_t RoundingId); }; -class AVUtilVersion : public WasmEdgeFFmpegAVUtil { +class AVUtilVersion : public HostFunction { public: - AVUtilVersion(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &); }; class AVGetChannelLayoutNbChannels - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVGetChannelLayoutNbChannels(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint64_t ChannelLayoutId); }; class AVGetChannelLayoutNameLen - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVGetChannelLayoutNameLen(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint64_t ChannelLayoutId); }; -class AVGetChannelLayoutName - : public WasmEdgeFFmpegAVUtil { +class AVGetChannelLayoutName : public HostFunction { public: - AVGetChannelLayoutName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint64_t ChannelLayoutId, uint32_t NamePtr, uint32_t NameLen); }; -class AVGetChannelLayoutMask - : public WasmEdgeFFmpegAVUtil { +class AVGetChannelLayoutMask : public HostFunction { public: - AVGetChannelLayoutMask(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint64_t ChannelLayoutId); }; class AVGetDefaultChannelLayout - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVGetDefaultChannelLayout(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ChannelLayoutId); }; class AVUtilConfigurationLength - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVUtilConfigurationLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVUtilConfiguration : public WasmEdgeFFmpegAVUtil { +class AVUtilConfiguration : public HostFunction { public: - AVUtilConfiguration(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen); }; -class AVUtilLicenseLength : public WasmEdgeFFmpegAVUtil { +class AVUtilLicenseLength : public HostFunction { public: - AVUtilLicenseLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class AVUtilLicense : public WasmEdgeFFmpegAVUtil { +class AVUtilLicense : public HostFunction { public: - AVUtilLicense(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen); }; diff --git a/plugins/wasmedge_ffmpeg/avutil/error.h b/plugins/wasmedge_ffmpeg/avutil/error.h index 3e2ec1e3..0ff38ce5 100644 --- a/plugins/wasmedge_ffmpeg/avutil/error.h +++ b/plugins/wasmedge_ffmpeg/avutil/error.h @@ -3,33 +3,29 @@ #pragma once -#include "avutil_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVUtil { -class AVUtilAVStrError : public WasmEdgeFFmpegAVUtil { +class AVUtilAVStrError : public HostFunction { public: - AVUtilAVStrError(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ErrNum, uint32_t ErrBuf, uint32_t BufLen); }; -class AVUtilAVError : public WasmEdgeFFmpegAVUtil { +class AVUtilAVError : public HostFunction { public: - AVUtilAVError(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ErrNum); }; -class AVUtilAVUNError : public WasmEdgeFFmpegAVUtil { +class AVUtilAVUNError : public HostFunction { public: - AVUtilAVUNError(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ErrNum); }; diff --git a/plugins/wasmedge_ffmpeg/avutil/module.h b/plugins/wasmedge_ffmpeg/avutil/module.h index 0ef5e265..6c22537b 100644 --- a/plugins/wasmedge_ffmpeg/avutil/module.h +++ b/plugins/wasmedge_ffmpeg/avutil/module.h @@ -4,6 +4,7 @@ #pragma once #include "ffmpeg_env.h" + #include "runtime/instance/module.h" namespace WasmEdge { diff --git a/plugins/wasmedge_ffmpeg/avutil/pixfmt.h b/plugins/wasmedge_ffmpeg/avutil/pixfmt.h index dedbcd30..0b5dfc64 100644 --- a/plugins/wasmedge_ffmpeg/avutil/pixfmt.h +++ b/plugins/wasmedge_ffmpeg/avutil/pixfmt.h @@ -3,8 +3,7 @@ #pragma once -#include "avutil_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { @@ -12,120 +11,103 @@ namespace WasmEdgeFFmpeg { namespace AVUtil { class AvPixFmtDescriptorNbComponents - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AvPixFmtDescriptorNbComponents(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t PixFormatId); }; class AvPixFmtDescriptorLog2ChromaW - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AvPixFmtDescriptorLog2ChromaW(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t PixFormatId); }; class AvPixFmtDescriptorLog2ChromaH - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AvPixFmtDescriptorLog2ChromaH(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t PixFormatId); }; -class AVColorRangeNameLength - : public WasmEdgeFFmpegAVUtil { +class AVColorRangeNameLength : public HostFunction { public: - AVColorRangeNameLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t RangeId); }; -class AVColorRangeName : public WasmEdgeFFmpegAVUtil { +class AVColorRangeName : public HostFunction { public: - AVColorRangeName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t RangeId, uint32_t RangeName, uint32_t RangeLength); }; class AVColorTransferNameLength - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVColorTransferNameLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t TransferId); }; -class AVColorTransferName : public WasmEdgeFFmpegAVUtil { +class AVColorTransferName : public HostFunction { public: - AVColorTransferName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t TransferId, uint32_t TransferNamePtr, uint32_t TransferLength); }; -class AVColorSpaceNameLength - : public WasmEdgeFFmpegAVUtil { +class AVColorSpaceNameLength : public HostFunction { public: - AVColorSpaceNameLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ColorSpaceId); }; -class AVColorSpaceName : public WasmEdgeFFmpegAVUtil { +class AVColorSpaceName : public HostFunction { public: - AVColorSpaceName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ColorSpaceId, uint32_t ColorSpaceNamePtr, uint32_t ColorSpaceLen); }; class AVColorPrimariesNameLength - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVColorPrimariesNameLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ColorPrimariesId); }; -class AVColorPrimariesName : public WasmEdgeFFmpegAVUtil { +class AVColorPrimariesName : public HostFunction { public: - AVColorPrimariesName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t ColorPrimariesId, uint32_t ColorPrimariesNamePtr, uint32_t ColorPrimariesLen); }; -class AVPixelFormatNameLength - : public WasmEdgeFFmpegAVUtil { +class AVPixelFormatNameLength : public HostFunction { public: - AVPixelFormatNameLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t AvPixFormatId); }; -class AVPixelFormatName : public WasmEdgeFFmpegAVUtil { +class AVPixelFormatName : public HostFunction { public: - AVPixelFormatName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t PixFormatId, uint32_t PixFormatNamePtr, uint32_t PixFormatNameLen); }; -class AVPixelFormatMask : public WasmEdgeFFmpegAVUtil { +class AVPixelFormatMask : public HostFunction { public: - AVPixelFormatMask(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t PixFormatId); }; diff --git a/plugins/wasmedge_ffmpeg/avutil/samplefmt.h b/plugins/wasmedge_ffmpeg/avutil/samplefmt.h index 02b0d0c4..373ec2b7 100644 --- a/plugins/wasmedge_ffmpeg/avutil/samplefmt.h +++ b/plugins/wasmedge_ffmpeg/avutil/samplefmt.h @@ -3,103 +3,89 @@ #pragma once -#include "avutil_base.h" -#include "runtime/callingframe.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace AVUtil { -class AVGetPlanarSampleFmt : public WasmEdgeFFmpegAVUtil { +class AVGetPlanarSampleFmt : public HostFunction { public: - AVGetPlanarSampleFmt(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SampleFormatId); }; -class AVGetPackedSampleFmt : public WasmEdgeFFmpegAVUtil { +class AVGetPackedSampleFmt : public HostFunction { public: - AVGetPackedSampleFmt(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SampleFormatId); }; -class AVSampleFmtIsPlanar : public WasmEdgeFFmpegAVUtil { +class AVSampleFmtIsPlanar : public HostFunction { public: - AVSampleFmtIsPlanar(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SampleFormatId); }; -class AVGetBytesPerSample : public WasmEdgeFFmpegAVUtil { +class AVGetBytesPerSample : public HostFunction { public: - AVGetBytesPerSample(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SampleFormatId); }; -class AVGetSampleFmt : public WasmEdgeFFmpegAVUtil { +class AVGetSampleFmt : public HostFunction { public: - AVGetSampleFmt(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t Str, uint32_t StrLen); }; -class AVSamplesGetBufferSize - : public WasmEdgeFFmpegAVUtil { +class AVSamplesGetBufferSize : public HostFunction { public: - AVSamplesGetBufferSize(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, int32_t NbChannels, int32_t NbSamples, uint32_t SampleFormatId, int32_t Align); }; class AVSamplesAllocArrayAndSamples - : public WasmEdgeFFmpegAVUtil { + : public HostFunction { public: - AVSamplesAllocArrayAndSamples(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t BufferPtr, uint32_t LinesizePtr, int32_t NbChannels, int32_t NbSamples, uint32_t SampleFmtId, int32_t Align); }; -class AVGetSampleFmtNameLength - : public WasmEdgeFFmpegAVUtil { +class AVGetSampleFmtNameLength : public HostFunction { public: - AVGetSampleFmtNameLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SampleFmtId); }; -class AVGetSampleFmtName : public WasmEdgeFFmpegAVUtil { +class AVGetSampleFmtName : public HostFunction { public: - AVGetSampleFmtName(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SampleFmtId, uint32_t SampleFmtNamePtr, uint32_t SampleFmtNameLen); }; -class AVGetSampleFmtMask : public WasmEdgeFFmpegAVUtil { +class AVGetSampleFmtMask : public HostFunction { public: - AVGetSampleFmtMask(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SampleFmtId); }; -class AVFreep : public WasmEdgeFFmpegAVUtil { +class AVFreep : public HostFunction { public: - AVFreep(std::shared_ptr HostEnv) - : WasmEdgeFFmpegAVUtil(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t BufferId); }; diff --git a/plugins/wasmedge_ffmpeg/avcodec/avcodec_base.h b/plugins/wasmedge_ffmpeg/ffmpeg_base.h similarity index 68% rename from plugins/wasmedge_ffmpeg/avcodec/avcodec_base.h rename to plugins/wasmedge_ffmpeg/ffmpeg_base.h index 126f43b0..faac0e3d 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avcodec_base.h +++ b/plugins/wasmedge_ffmpeg/ffmpeg_base.h @@ -4,24 +4,23 @@ #pragma once #include "ffmpeg_env.h" + +#include "runtime/callingframe.h" #include "runtime/hostfunc.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { -namespace AVcodec { -template -class WasmEdgeFFmpegAVCodec : public Runtime::HostFunction { +template class HostFunction : public Runtime::HostFunction { public: - WasmEdgeFFmpegAVCodec(std::shared_ptr HostEnv) + HostFunction(std::shared_ptr HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} protected: std::shared_ptr Env; }; -} // namespace AVcodec } // namespace WasmEdgeFFmpeg } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp b/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp index fe70b0b4..14ded24c 100644 --- a/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp +++ b/plugins/wasmedge_ffmpeg/ffmpeg_env.cpp @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "ffmpeg_env.h" #include "avcodec/module.h" #include "avdevice/module.h" #include "avfilter/module.h" @@ -10,6 +9,8 @@ #include "swresample/module.h" #include "swscale/module.h" +#include "ffmpeg_env.h" + namespace WasmEdge { namespace Host { namespace { @@ -45,14 +46,14 @@ createAVUtil(const Plugin::PluginModule::ModuleDescriptor *) noexcept { } Runtime::Instance::ModuleInstance * -createSWScale(const Plugin::PluginModule::ModuleDescriptor *) noexcept { - return new WasmEdgeFFmpeg::SWScale::WasmEdgeFFmpegSWScaleModule( +createSWResample(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::SWResample::WasmEdgeFFmpegSWResampleModule( WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); } Runtime::Instance::ModuleInstance * -createSWResample(const Plugin::PluginModule::ModuleDescriptor *) noexcept { - return new WasmEdgeFFmpeg::SWResample::WasmEdgeFFmpegSWResampleModule( +createSWScale(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new WasmEdgeFFmpeg::SWScale::WasmEdgeFFmpegSWScaleModule( WasmEdgeFFmpeg::WasmEdgeFFmpegEnv::getInstance()); } diff --git a/plugins/wasmedge_ffmpeg/ffmpeg_env.h b/plugins/wasmedge_ffmpeg/ffmpeg_env.h index 300de141..c2ad46b7 100644 --- a/plugins/wasmedge_ffmpeg/ffmpeg_env.h +++ b/plugins/wasmedge_ffmpeg/ffmpeg_env.h @@ -6,7 +6,8 @@ #include "bindings.h" #include "plugin/plugin.h" -#include "vector" +#include +#include namespace WasmEdge { namespace Host { diff --git a/plugins/wasmedge_ffmpeg/swresample/module.h b/plugins/wasmedge_ffmpeg/swresample/module.h index 0d1aa42f..00d4c839 100644 --- a/plugins/wasmedge_ffmpeg/swresample/module.h +++ b/plugins/wasmedge_ffmpeg/swresample/module.h @@ -4,6 +4,7 @@ #pragma once #include "ffmpeg_env.h" + #include "runtime/instance/module.h" namespace WasmEdge { diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_base.h b/plugins/wasmedge_ffmpeg/swresample/swresample_base.h deleted file mode 100644 index b3bd6078..00000000 --- a/plugins/wasmedge_ffmpeg/swresample/swresample_base.h +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#pragma once - -#include "ffmpeg_env.h" -#include "runtime/hostfunc.h" - -namespace WasmEdge { -namespace Host { -namespace WasmEdgeFFmpeg { -namespace SWResample { - -template -class WasmEdgeFFmpegSWResample : public Runtime::HostFunction { -public: - WasmEdgeFFmpegSWResample( - std::shared_ptr HostEnv) - : Runtime::HostFunction(0), Env(HostEnv) {} - -protected: - std::shared_ptr Env; -}; - -} // namespace SWResample -} // namespace WasmEdgeFFmpeg -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_func.h b/plugins/wasmedge_ffmpeg/swresample/swresample_func.h index 5cfa38c2..61104d70 100644 --- a/plugins/wasmedge_ffmpeg/swresample/swresample_func.h +++ b/plugins/wasmedge_ffmpeg/swresample/swresample_func.h @@ -3,42 +3,37 @@ #pragma once -#include "ffmpeg_env.h" -#include "runtime/callingframe.h" -#include "swresample_base.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace SWResample { -class SWResampleVersion : public WasmEdgeFFmpegSWResample { +class SWResampleVersion : public HostFunction { public: - SWResampleVersion(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWResample(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class SWRGetDelay : public WasmEdgeFFmpegSWResample { +class SWRGetDelay : public HostFunction { public: - SWRGetDelay(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWResample(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SWRContextId, int64_t Base); }; -class SWRInit : public WasmEdgeFFmpegSWResample { +class SWRInit : public HostFunction { public: - SWRInit(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWResample(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SWRContextId); }; -class SWRAllocSetOpts : public WasmEdgeFFmpegSWResample { +class SWRAllocSetOpts : public HostFunction { public: - SWRAllocSetOpts(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWResample(HostEnv) {} + using HostFunction::HostFunction; + Expect body(const Runtime::CallingFrame &Frame, uint32_t SwrCtxPtr, uint32_t SWRContextId, uint64_t OutChLayout, uint32_t OutSampleFmtId, int32_t OutSampleRate, @@ -46,60 +41,51 @@ class SWRAllocSetOpts : public WasmEdgeFFmpegSWResample { int32_t InSampleRate, int32_t LogOffset); }; -class AVOptSetDict : public WasmEdgeFFmpegSWResample { +class AVOptSetDict : public HostFunction { public: - AVOptSetDict(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWResample(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SWRContextId, uint32_t DictId); }; -class SWRConvertFrame : public WasmEdgeFFmpegSWResample { +class SWRConvertFrame : public HostFunction { public: - SWRConvertFrame(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWResample(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SWRContextId, uint32_t FrameOutputId, uint32_t FrameInputId); }; -class SWRFree : public WasmEdgeFFmpegSWResample { +class SWRFree : public HostFunction { public: - SWRFree(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWResample(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SWRContextId); }; class SWResampleConfigurationLength - : public WasmEdgeFFmpegSWResample { + : public HostFunction { public: - SWResampleConfigurationLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWResample(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class SWResampleConfiguration - : public WasmEdgeFFmpegSWResample { +class SWResampleConfiguration : public HostFunction { public: - SWResampleConfiguration(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWResample(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen); }; -class SWResampleLicenseLength - : public WasmEdgeFFmpegSWResample { +class SWResampleLicenseLength : public HostFunction { public: - SWResampleLicenseLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWResample(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class SWResampleLicense : public WasmEdgeFFmpegSWResample { +class SWResampleLicense : public HostFunction { public: - SWResampleLicense(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWResample(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen); }; diff --git a/plugins/wasmedge_ffmpeg/swscale/module.h b/plugins/wasmedge_ffmpeg/swscale/module.h index 69fde0d7..e9ca104d 100644 --- a/plugins/wasmedge_ffmpeg/swscale/module.h +++ b/plugins/wasmedge_ffmpeg/swscale/module.h @@ -4,6 +4,7 @@ #pragma once #include "ffmpeg_env.h" + #include "runtime/instance/module.h" namespace WasmEdge { diff --git a/plugins/wasmedge_ffmpeg/swscale/swscale_base.h b/plugins/wasmedge_ffmpeg/swscale/swscale_base.h deleted file mode 100644 index 3ae1b62d..00000000 --- a/plugins/wasmedge_ffmpeg/swscale/swscale_base.h +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#pragma once - -#include "ffmpeg_env.h" -#include "runtime/hostfunc.h" - -namespace WasmEdge { -namespace Host { -namespace WasmEdgeFFmpeg { -namespace SWScale { - -template -class WasmEdgeFFmpegSWScale : public Runtime::HostFunction { -public: - WasmEdgeFFmpegSWScale( - std::shared_ptr HostEnv) - : Runtime::HostFunction(0), Env(HostEnv) {} - -protected: - std::shared_ptr Env; -}; - -} // namespace SWScale -} // namespace WasmEdgeFFmpeg -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasmedge_ffmpeg/swscale/swscale_func.h b/plugins/wasmedge_ffmpeg/swscale/swscale_func.h index 30ce313d..12edd643 100644 --- a/plugins/wasmedge_ffmpeg/swscale/swscale_func.h +++ b/plugins/wasmedge_ffmpeg/swscale/swscale_func.h @@ -3,18 +3,16 @@ #pragma once -#include "runtime/callingframe.h" -#include "swscale_base.h" +#include "ffmpeg_base.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { namespace SWScale { -class SwsGetContext : public WasmEdgeFFmpegSWScale { +class SwsGetContext : public HostFunction { public: - SwsGetContext(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxPtr, uint32_t SrcW, uint32_t SrcH, uint32_t SrcPixFormatId, uint32_t DesW, uint32_t DesH, uint32_t DesPixFormatId, @@ -22,26 +20,23 @@ class SwsGetContext : public WasmEdgeFFmpegSWScale { uint32_t DesFilterId); }; -class SwsFreeContext : public WasmEdgeFFmpegSWScale { +class SwsFreeContext : public HostFunction { public: - SwsFreeContext(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxId); }; -class SwsScale : public WasmEdgeFFmpegSWScale { +class SwsScale : public HostFunction { public: - SwsScale(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsCtxId, uint32_t InputFrameId, int32_t SrcSliceY, int32_t SrcSliceH, uint32_t OutputFrameId); }; -class SwsGetCachedContext : public WasmEdgeFFmpegSWScale { +class SwsGetCachedContext : public HostFunction { public: - SwsGetCachedContext(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsCachedCtxPtr, uint32_t SwsCtxPtr, uint32_t SrcW, uint32_t SrcH, uint32_t SrcPixFormatId, @@ -50,36 +45,31 @@ class SwsGetCachedContext : public WasmEdgeFFmpegSWScale { uint32_t DesFilterId); }; -class SwsIsSupportedInput : public WasmEdgeFFmpegSWScale { +class SwsIsSupportedInput : public HostFunction { public: - SwsIsSupportedInput(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t PixFormatId); }; -class SwsIsSupportedOutput - : public WasmEdgeFFmpegSWScale { +class SwsIsSupportedOutput : public HostFunction { public: - SwsIsSupportedOutput(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t PixFormatId); }; class SwsIsSupportedEndiannessConversion - : public WasmEdgeFFmpegSWScale { + : public HostFunction { public: - SwsIsSupportedEndiannessConversion(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t PixFormatId); }; -class SwsGetDefaultFilter : public WasmEdgeFFmpegSWScale { +class SwsGetDefaultFilter : public HostFunction { public: - SwsGetDefaultFilter(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterPtr, float LumaGBlur, float ChromaGBlur, float LumaSharpen, @@ -87,138 +77,118 @@ class SwsGetDefaultFilter : public WasmEdgeFFmpegSWScale { float ChromaVShift, int32_t Verbose); }; -class SwsGetLumaH : public WasmEdgeFFmpegSWScale { +class SwsGetLumaH : public HostFunction { public: - SwsGetLumaH(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, uint32_t SwsVectorPtr); }; -class SwsGetLumaV : public WasmEdgeFFmpegSWScale { +class SwsGetLumaV : public HostFunction { public: - SwsGetLumaV(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, uint32_t SwsVectorPtr); }; -class SwsGetChromaH : public WasmEdgeFFmpegSWScale { +class SwsGetChromaH : public HostFunction { public: - SwsGetChromaH(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, uint32_t SwsVectorPtr); }; -class SwsGetChromaV : public WasmEdgeFFmpegSWScale { +class SwsGetChromaV : public HostFunction { public: - SwsGetChromaV(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId, uint32_t SwsVectorPtr); }; -class SwsFreeFilter : public WasmEdgeFFmpegSWScale { +class SwsFreeFilter : public HostFunction { public: - SwsFreeFilter(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsFilterId); }; -class SwsAllocVec : public WasmEdgeFFmpegSWScale { +class SwsAllocVec : public HostFunction { public: - SwsAllocVec(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsVectorPtr, int32_t Length); }; -class SwsGetGaussianVec : public WasmEdgeFFmpegSWScale { +class SwsGetGaussianVec : public HostFunction { public: - SwsGetGaussianVec(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsVectorPtr, double Variance, double Quality); }; -class SwsScaleVec : public WasmEdgeFFmpegSWScale { +class SwsScaleVec : public HostFunction { public: - SwsScaleVec(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsVectorId, double Scalar); }; -class SwsNormalizeVec : public WasmEdgeFFmpegSWScale { +class SwsNormalizeVec : public HostFunction { public: - SwsNormalizeVec(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsVectorId, double Height); }; -class SwsGetCoeffVecLength - : public WasmEdgeFFmpegSWScale { +class SwsGetCoeffVecLength : public HostFunction { public: - SwsGetCoeffVecLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t SwsVectorId); }; -class SwsGetCoeff : public WasmEdgeFFmpegSWScale { +class SwsGetCoeff : public HostFunction { public: - SwsGetCoeff(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &, uint32_t SwsVectorId, uint32_t CoeffBuf, uint32_t Len); }; -class SwsFreeVec : public WasmEdgeFFmpegSWScale { +class SwsFreeVec : public HostFunction { public: - SwsFreeVec(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t SwsVectorId); }; -class SwscaleVersion : public WasmEdgeFFmpegSWScale { +class SwscaleVersion : public HostFunction { public: - SwscaleVersion(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; class SwscaleConfigurationLength - : public WasmEdgeFFmpegSWScale { + : public HostFunction { public: - SwscaleConfigurationLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class SwscaleConfiguration - : public WasmEdgeFFmpegSWScale { +class SwscaleConfiguration : public HostFunction { public: - SwscaleConfiguration(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t ConfigPtr, uint32_t ConfigLen); }; -class SwscaleLicenseLength - : public WasmEdgeFFmpegSWScale { +class SwscaleLicenseLength : public HostFunction { public: - SwscaleLicenseLength(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame); }; -class SwscaleLicense : public WasmEdgeFFmpegSWScale { +class SwscaleLicense : public HostFunction { public: - SwscaleLicense(std::shared_ptr HostEnv) - : WasmEdgeFFmpegSWScale(HostEnv) {} + using HostFunction::HostFunction; Expect body(const Runtime::CallingFrame &Frame, uint32_t LicensePtr, uint32_t LicenseLen); }; diff --git a/test/plugins/wasmedge_ffmpeg/CMakeLists.txt b/test/plugins/wasmedge_ffmpeg/CMakeLists.txt index 5b66bf2c..f1c0ca15 100644 --- a/test/plugins/wasmedge_ffmpeg/CMakeLists.txt +++ b/test/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -27,9 +27,9 @@ wasmedge_add_executable(wasmedgeFFmpegTests avutil/avSampleFmt.cpp avutil/avPixfmt.cpp - swscale/swscale_func.cpp - swresample/swresample_func.cpp + + swscale/swscale_func.cpp utils.cpp ) @@ -54,7 +54,6 @@ target_include_directories(wasmedgeFFmpegTests target_link_libraries(wasmedgeFFmpegTests PRIVATE - wasmedgePluginWasmEdgeFFmpeg ${GTEST_BOTH_LIBRARIES} ) diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp index eb2e8a95..17dc5a49 100644 --- a/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avcodec/avCodec.h" #include "avcodec/module.h" #include "utils.h" @@ -11,7 +14,6 @@ namespace Host { namespace WasmEdgeFFmpeg { TEST_F(FFmpegTest, AVCodec) { - ASSERT_TRUE(AVCodecMod != nullptr); uint32_t AVCodecPtr = UINT32_C(20); diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp index b89128f8..d7adca94 100644 --- a/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avcodec/avCodecContext.h" #include "avcodec/module.h" + #include "utils.h" #include diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp index 9fc5029c..2bcfdfd4 100644 --- a/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecParameters.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avcodec/avCodecParameters.h" #include "avcodec/module.h" + #include "utils.h" #include diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp index 88b5a226..fa163e49 100644 --- a/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avPacket.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avcodec/avPacket.h" #include "avcodec/module.h" + #include "utils.h" #include @@ -11,7 +15,6 @@ namespace Host { namespace WasmEdgeFFmpeg { TEST_F(FFmpegTest, AVPacketTest) { - ASSERT_TRUE(AVCodecMod != nullptr); uint32_t PacketPtr = UINT32_C(4); diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp index d5fcc6b0..e85bb769 100644 --- a/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp @@ -1,7 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avcodec/avcodec_func.h" #include "avcodec/module.h" #include "utils.h" + #include namespace WasmEdge { @@ -404,7 +408,6 @@ TEST_F(FFmpegTest, AVCodecFunc) { } TEST_F(FFmpegTest, SendPacketReceiveFrame) { - std::string FileName = "ffmpeg-assets/dummy.mp4"; // 32 chars uint32_t CodecCtxPtr = UINT32_C(64); uint32_t FramePtr = UINT32_C(72); diff --git a/test/plugins/wasmedge_ffmpeg/avfilter/avfilter.cpp b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter.cpp index 58930164..c96448a2 100644 --- a/test/plugins/wasmedge_ffmpeg/avfilter/avfilter.cpp +++ b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter.cpp @@ -1,6 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avfilter/avFilter.h" #include "avfilter//avfilter_func.h" #include "avfilter/module.h" + #include "utils.h" #include diff --git a/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp index 2a936fbd..1e50b401 100644 --- a/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp @@ -1,7 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avfilter//avfilter_func.h" #include "avfilter/avFilter.h" #include "avfilter/buffer_source_sink.h" #include "avfilter/module.h" + #include "utils.h" #include @@ -11,7 +15,6 @@ namespace Host { namespace WasmEdgeFFmpeg { TEST_F(FFmpegTest, AVFilterFunc) { - ASSERT_TRUE(AVFilterMod != nullptr); // Structs Ptr diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp index 3dbef770..299d26f9 100644 --- a/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp +++ b/test/plugins/wasmedge_ffmpeg/avformat/avChapter.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avformat/avChapter.h" #include "avformat/module.h" + #include "utils.h" #include @@ -10,7 +14,6 @@ namespace WasmEdgeFFmpeg { // Sample Video under test has only Single Chapter. TEST_F(FFmpegTest, AVChapter) { - ASSERT_TRUE(AVFormatMod != nullptr); uint32_t ChapterIdx = 0; diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp index c78db2f6..99a1fcbe 100644 --- a/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp +++ b/test/plugins/wasmedge_ffmpeg/avformat/avInputOutputContext.cpp @@ -1,6 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avformat/avInputOutputFormat.h" #include "avformat/avformatContext.h" #include "avformat/module.h" + #include "utils.h" #include @@ -10,7 +14,6 @@ namespace Host { namespace WasmEdgeFFmpeg { TEST_F(FFmpegTest, AVInputFormat) { - std::string FileName = "ffmpeg-assets/sample_video.mp4"; // 32 chars uint32_t FormatCtxPtr = UINT32_C(24); uint32_t InputFormatPtr = UINT32_C(28); diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avStream.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avStream.cpp index 1fc2dadc..0cc0655b 100644 --- a/test/plugins/wasmedge_ffmpeg/avformat/avStream.cpp +++ b/test/plugins/wasmedge_ffmpeg/avformat/avStream.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avformat/avStream.h" #include "avformat/module.h" + #include "utils.h" #include @@ -10,7 +14,6 @@ namespace WasmEdgeFFmpeg { // Testing all AVFormat_funcs. TEST_F(FFmpegTest, AVStreamStruct) { - ASSERT_TRUE(AVFormatMod != nullptr); uint32_t StreamIdx = 0; diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp index 27c7f511..3d23d403 100644 --- a/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp +++ b/test/plugins/wasmedge_ffmpeg/avformat/avformatContext.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avformat/avformatContext.h" #include "avformat/module.h" + #include "utils.h" #include @@ -10,7 +14,6 @@ namespace WasmEdgeFFmpeg { // Testing all AVFormat_funcs. TEST_F(FFmpegTest, AVFormatContextStruct) { - uint32_t FormatCtxPtr = UINT32_C(4); uint32_t InputFormatPtr = UINT32_C(8); uint32_t OutputFormatPtr = UINT32_C(12); diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp index 6ca7f158..788e7495 100644 --- a/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avformat/avformat_func.h" #include "avformat/module.h" + #include "utils.h" #include @@ -10,7 +14,6 @@ namespace WasmEdgeFFmpeg { // Testing all AVFormat_funcs. TEST_F(FFmpegTest, AVInputFormatFunc) { - uint32_t FormatCtxPtr = UINT32_C(4); uint32_t DictPtr = UINT32_C(16); uint32_t KeyPtr = UINT32_C(100); diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp index 2e3542c4..bba090c8 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avutil/avDictionary.h" #include "avutil/module.h" + #include "utils.h" #include diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp index bcba324c..8584c97b 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avError.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avutil/error.h" #include "avutil/module.h" + #include "utils.h" #include diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp index 5ffa1d9d..b30fe2ab 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avutil/avFrame.h" #include "avutil/module.h" + #include "utils.h" #include @@ -9,7 +13,6 @@ namespace Host { namespace WasmEdgeFFmpeg { TEST_F(FFmpegTest, AVFrame) { - uint32_t AVFramePtr = UINT32_C(72); uint32_t AVFrame2Ptr = UINT32_C(40); uint32_t DictPtr = UINT32_C(36); diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp index 91b3399a..625acd10 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avPixfmt.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avutil/module.h" #include "avutil/pixfmt.h" + #include "utils.h" #include @@ -9,7 +13,6 @@ namespace Host { namespace WasmEdgeFFmpeg { TEST_F(FFmpegTest, AVPixFmt) { - uint32_t NamePtr = UINT32_C(4); auto *FuncInst = AVUtilMod->findFuncExports( diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avRational.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avRational.cpp index 250c0fec..85223ad9 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avRational.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avRational.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avutil/avRational.h" #include "avutil/module.h" + #include "utils.h" #include @@ -9,7 +13,6 @@ namespace Host { namespace WasmEdgeFFmpeg { TEST_F(FFmpegTest, AVRational) { - ASSERT_TRUE(AVUtilMod != nullptr); uint32_t NumPtr = UINT32_C(4); diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp index 14ea8bb5..d12cd16d 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avSampleFmt.cpp @@ -1,5 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avutil/module.h" #include "avutil/samplefmt.h" + #include "utils.h" #include diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp index 43c75cee..2406cea1 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp @@ -1,6 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "avutil/avutil_func.h" #include "avutil/avTime.h" #include "avutil/module.h" + #include "utils.h" #include @@ -10,7 +14,6 @@ namespace Host { namespace WasmEdgeFFmpeg { TEST_F(FFmpegTest, AVUtilFunc) { - ASSERT_TRUE(AVUtilMod != nullptr); uint32_t NamePtr = UINT32_C(4); diff --git a/test/plugins/wasmedge_ffmpeg/main.cpp b/test/plugins/wasmedge_ffmpeg/main.cpp index 852694a0..c2be683b 100644 --- a/test/plugins/wasmedge_ffmpeg/main.cpp +++ b/test/plugins/wasmedge_ffmpeg/main.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include GTEST_API_ int main(int Argc, char **Argv) { diff --git a/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp b/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp index 1fc4f905..6f84db6f 100644 --- a/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp @@ -1,7 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "swresample/swresample_func.h" #include "swresample/module.h" #include "utils.h" + #include namespace WasmEdge { @@ -9,7 +13,6 @@ namespace Host { namespace WasmEdgeFFmpeg { TEST_F(FFmpegTest, SWResampleFunc) { - ASSERT_TRUE(SWResampleMod != nullptr); uint32_t DictPtr = UINT32_C(4); diff --git a/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp b/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp index cfe4b38c..2f0fa314 100644 --- a/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp @@ -1,7 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "swscale/swscale_func.h" #include "swscale/module.h" #include "utils.h" + #include namespace WasmEdge { @@ -13,7 +17,6 @@ namespace WasmEdgeFFmpeg { // ============================================================================ TEST_F(FFmpegTest, SwsContext) { - ASSERT_TRUE(SWScaleMod != nullptr); auto *FuncInst = @@ -199,7 +202,6 @@ TEST_F(FFmpegTest, SwsContext) { // ============================================================================ TEST_F(FFmpegTest, SwsFilter) { - ASSERT_TRUE(SWScaleMod != nullptr); auto *FuncInst = SWScaleMod->findFuncExports( "wasmedge_ffmpeg_swscale_sws_getDefaultFilter"); @@ -321,7 +323,6 @@ TEST_F(FFmpegTest, SwsFilter) { // ============================================================================ TEST_F(FFmpegTest, SwsVector) { - ASSERT_TRUE(SWScaleMod != nullptr); uint32_t SwsVectorPtr = UINT32_C(40); uint32_t CoeffPtr = UINT32_C(100); @@ -456,7 +457,6 @@ TEST_F(FFmpegTest, SwsVector) { // ============================================================================ TEST_F(FFmpegTest, SWScaleVersion) { - ASSERT_TRUE(SWScaleMod != nullptr); uint32_t Length = 0; diff --git a/test/plugins/wasmedge_ffmpeg/utils.cpp b/test/plugins/wasmedge_ffmpeg/utils.cpp index cef6ba40..2656cbb3 100644 --- a/test/plugins/wasmedge_ffmpeg/utils.cpp +++ b/test/plugins/wasmedge_ffmpeg/utils.cpp @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "utils.h" + #include "avcodec/avCodecContext.h" #include "avcodec/avPacket.h" #include "avcodec/avcodec_func.h" @@ -12,7 +16,6 @@ namespace Host { namespace WasmEdgeFFmpeg { void FFmpegTest::initEmptyFrame(uint32_t FramePtr) { - auto *FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_alloc"); auto &HostFuncAVFrameAlloc = @@ -150,18 +153,16 @@ void FFmpegTest::initFFmpegStructs(uint32_t AVCodecPtr, uint32_t AVFormatCtxPtr, FuncInst->getHostFunc()); while (true) { - HostFuncAVCodecReceiveFrame.run( CallFrame, std::initializer_list{AVCodecCtxId, FrameId}, Result); // Error returned by FFmpeg are negative. - int32_t Error = Result[0].get() * -1; + int32_t Error = Result[0].get() * (-1); if (Error == EAGAIN) { while (true) { - allocPacket(PacketPtr); uint32_t PackedId = readUInt32(MemInst, PacketPtr); @@ -203,7 +204,6 @@ void FFmpegTest::initFFmpegStructs(uint32_t AVCodecPtr, uint32_t AVFormatCtxPtr, void FFmpegTest::initFormatCtx(uint32_t AVFormatCtxPtr, uint32_t FilePtr, std::string FileName) { - int32_t Length = FileName.length(); fillMemContent(MemInst, FilePtr, Length); fillMemContent(MemInst, FilePtr, FileName); @@ -222,7 +222,6 @@ void FFmpegTest::initFormatCtx(uint32_t AVFormatCtxPtr, uint32_t FilePtr, void FFmpegTest::initDict(uint32_t DictPtr, uint32_t KeyPtr, std::string Key, uint32_t ValuePtr, std::string Value) { - uint32_t KeyLen = Key.length(); uint32_t ValueLen = Value.length(); fillMemContent(MemInst, KeyPtr, KeyLen + ValueLen); @@ -242,7 +241,6 @@ void FFmpegTest::initDict(uint32_t DictPtr, uint32_t KeyPtr, std::string Key, } void FFmpegTest::allocPacket(uint32_t PacketPtr) { - auto *FuncInst = AVCodecMod->findFuncExports("wasmedge_ffmpeg_avcodec_av_packet_alloc"); auto &HostFuncAVPacketAlloc = diff --git a/test/plugins/wasmedge_ffmpeg/utils.h b/test/plugins/wasmedge_ffmpeg/utils.h index 5b5175fa..7a66f18f 100644 --- a/test/plugins/wasmedge_ffmpeg/utils.h +++ b/test/plugins/wasmedge_ffmpeg/utils.h @@ -1,18 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "avcodec/module.h" #include "avfilter/module.h" #include "avformat/module.h" #include "avutil/module.h" +#include "swresample/module.h" +#include "swscale/module.h" + #include "common/types.h" #include "runtime/callingframe.h" #include "runtime/instance/module.h" -#include "swresample/module.h" -#include "swscale/module.h" + #include "gtest/gtest.h" namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { + inline void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, uint32_t Value, uint32_t &Ptr) { uint32_t *BufPtr = MemInst->getPointer(Ptr); @@ -159,6 +166,7 @@ class FFmpegTest : public ::testing::Test { WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::WasmEdgeFFmpegAVFilterModule *AVFilterMod = nullptr; }; + } // namespace WasmEdgeFFmpeg } // namespace Host } // namespace WasmEdge From f3a7ea83d010c9f720ec2d844585105652c081ff Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Mon, 19 Aug 2024 21:24:01 +0800 Subject: [PATCH 400/623] [CI] Add build and release CI for WASI-NN plugins on MacOS. (#3668) Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index a4bfb187..bc89d3c8 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -53,11 +53,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) set(GGML_CUDA OFF) endif() - if(NOT APPLE) - set(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL OFF) - endif() - - if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) + if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" AND WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) message(STATUS "WASI-NN GGML LLAMA backend: Enable GGML_METAL") set(GGML_METAL ON) set(GGML_METAL_EMBED_LIBRARY ON) @@ -145,7 +141,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) simdjson::simdjson llava ) - if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) + if(APPLE AND WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) add_custom_command( TARGET wasmedgePluginWasiNN POST_BUILD From 0f5e58e75116fa0861767d57f5c8093e8f94bd66 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 19 Aug 2024 22:59:40 +0800 Subject: [PATCH 401/623] [CI] Fix the CUDA asset on CI (#3671) Signed-off-by: dm4 --- test/plugins/wasmedge_stablediffusion/CMakeLists.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/plugins/wasmedge_stablediffusion/CMakeLists.txt b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt index dac8abd7..ac4e5dd0 100644 --- a/test/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -1,3 +1,11 @@ +if(WASMEDGE_PLUGIN_STABLEDIFFUSION_CUBLAS) + message(STATUS "Stable diffusion plugin: Enable SD_CUBLAS") + set(SD_CUBLAS ON) +else() + message(STATUS "Stable diffusion plugin: Disable SD_CUBLAS") + set(SD_CUBLAS OFF) +endif() + wasmedge_add_executable(wasmedgeStableDiffusionTests wasmedge_stablediffusion.cpp ) From 4f056a955e01c4738b161e6f12272c1c46e03648 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 20 Aug 2024 17:26:19 +0800 Subject: [PATCH 402/623] [Plugin] stable diffusion: fix SD_CUBLAS configuration Signed-off-by: dm4 --- plugins/wasmedge_stablediffusion/CMakeLists.txt | 4 ++-- test/plugins/wasmedge_stablediffusion/CMakeLists.txt | 8 -------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index d8cb20d2..9c57724c 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -4,10 +4,10 @@ if(WASMEDGE_PLUGIN_STABLEDIFFUSION_CUBLAS) message(STATUS "Stable diffusion plugin: Enable SD_CUBLAS") - set(SD_CUBLAS ON) + set(SD_CUBLAS ON CACHE BOOL "Stable diffusion plugin: Enable SD_CUBLAS") else() message(STATUS "Stable diffusion plugin: Disable SD_CUBLAS") - set(SD_CUBLAS OFF) + set(SD_CUBLAS OFF CACHE BOOL "Stable diffusion plugin: Disable SD_CUBLAS") endif() # setup stable diffusion diff --git a/test/plugins/wasmedge_stablediffusion/CMakeLists.txt b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt index ac4e5dd0..dac8abd7 100644 --- a/test/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/test/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -1,11 +1,3 @@ -if(WASMEDGE_PLUGIN_STABLEDIFFUSION_CUBLAS) - message(STATUS "Stable diffusion plugin: Enable SD_CUBLAS") - set(SD_CUBLAS ON) -else() - message(STATUS "Stable diffusion plugin: Disable SD_CUBLAS") - set(SD_CUBLAS OFF) -endif() - wasmedge_add_executable(wasmedgeStableDiffusionTests wasmedge_stablediffusion.cpp ) From 48453608447a5fbda1d63c6fb16eb6a4df79cd72 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 21 Aug 2024 21:33:47 +0800 Subject: [PATCH 403/623] [Plugin] stable diffusion: add Metal support (#3680) Signed-off-by: dm4 --- .../wasmedge_stablediffusion/CMakeLists.txt | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 9c57724c..313449b6 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -10,6 +10,15 @@ else() set(SD_CUBLAS OFF CACHE BOOL "Stable diffusion plugin: Disable SD_CUBLAS") endif() +if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" AND WASMEDGE_PLUGIN_STABLEDIFFUSION_METAL) + message(STATUS "Stable diffusion plugin: Enable SD_METAL") + set(SD_METAL ON CACHE BOOL "Stable diffusion plugin: Enable SD_METAL") + set(GGML_METAL_EMBED_LIBRARY ON) +else() + message(STATUS "Stable diffusion plugin: Disable SD_METAL") + set(SD_METAL OFF CACHE BOOL "Stable diffusion plugin: Disable SD_METAL") +endif() + # setup stable diffusion message(STATUS "Downloading stable diffusion source") FetchContent_Declare( @@ -86,6 +95,15 @@ else() ) endif() +if(WASMEDGE_PLUGIN_STABLEDIFFUSION_METAL) + add_custom_command( + TARGET wasmedgePluginWasmEdgeStableDiffusion + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${stable-diffusion_SOURCE_DIR}/ggml/src/ggml-metal.metal ggml-metal.metal + COMMAND ${CMAKE_COMMAND} -E copy ${stable-diffusion_SOURCE_DIR}/ggml/ggml-common.h ggml-common.h + ) +endif() + install( TARGETS wasmedgePluginWasmEdgeStableDiffusion DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge From 10a48b875750aa99663cff7104cb69ce69ea9995 Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 22 Aug 2024 02:54:14 +0800 Subject: [PATCH 404/623] [WASI-NN] ggml: bump to b3613 Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index bc89d3c8..780ec898 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -68,7 +68,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3567 + GIT_TAG b3613 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 020c8596dc80d39906c6a48927708f5848160cd6 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Wed, 21 Aug 2024 20:42:12 +0800 Subject: [PATCH 405/623] [Docker] Refactor ubuntu images with docker bake Signed-off-by: Yi Huang --- utils/docker/Dockerfile.base | 19 ------- utils/docker/Dockerfile.build-clang | 10 ---- utils/docker/Dockerfile.build-gcc | 11 ---- utils/docker/Dockerfile.build-plugins-deps | 12 ----- utils/docker/Dockerfile.ubuntu-base | 43 ++++++++++++++++ utils/docker/Dockerfile.ubuntu-plugins-deps | 13 +++++ utils/docker/build.sh | 49 ------------------ utils/docker/docker-bake.ubuntu.hcl | 56 +++++++++++++++++++++ 8 files changed, 112 insertions(+), 101 deletions(-) delete mode 100644 utils/docker/Dockerfile.base delete mode 100644 utils/docker/Dockerfile.build-clang delete mode 100644 utils/docker/Dockerfile.build-gcc delete mode 100644 utils/docker/Dockerfile.build-plugins-deps create mode 100644 utils/docker/Dockerfile.ubuntu-base create mode 100644 utils/docker/Dockerfile.ubuntu-plugins-deps delete mode 100755 utils/docker/build.sh create mode 100644 utils/docker/docker-bake.ubuntu.hcl diff --git a/utils/docker/Dockerfile.base b/utils/docker/Dockerfile.base deleted file mode 100644 index db431fb9..00000000 --- a/utils/docker/Dockerfile.base +++ /dev/null @@ -1,19 +0,0 @@ -FROM ubuntu:22.04 - -MAINTAINER hydai hydai@secondstate.io -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt update && apt upgrade -y \ - && apt install -y \ - software-properties-common \ - dpkg-dev \ - wget \ - cmake \ - ninja-build \ - curl \ - git \ - zlib1g-dev \ - llvm-15-dev \ - liblld-15-dev - -RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/Dockerfile.build-clang b/utils/docker/Dockerfile.build-clang deleted file mode 100644 index bed2409b..00000000 --- a/utils/docker/Dockerfile.build-clang +++ /dev/null @@ -1,10 +0,0 @@ -ARG BASE=wasmedge/wasmedge:ubuntu-base -FROM ${BASE} - -RUN apt update && apt install -y \ - clang-15 - -RUN rm -rf /var/lib/apt/lists/* - -ENV CC=/usr/bin/clang-15 -ENV CXX=/usr/bin/clang++-15 diff --git a/utils/docker/Dockerfile.build-gcc b/utils/docker/Dockerfile.build-gcc deleted file mode 100644 index 12064114..00000000 --- a/utils/docker/Dockerfile.build-gcc +++ /dev/null @@ -1,11 +0,0 @@ -ARG BASE=wasmedge/wasmedge:ubuntu-base -FROM ${BASE} - -RUN apt update && apt install -y \ - gcc \ - g++ - -RUN rm -rf /var/lib/apt/lists/* - -ENV CC=gcc -ENV CXX=g++ diff --git a/utils/docker/Dockerfile.build-plugins-deps b/utils/docker/Dockerfile.build-plugins-deps deleted file mode 100644 index ae797d66..00000000 --- a/utils/docker/Dockerfile.build-plugins-deps +++ /dev/null @@ -1,12 +0,0 @@ -ARG BASE=wasmedge/wasmedge:ubuntu-build-clang -FROM ${BASE} - -RUN apt update && apt install -y \ - wget \ - unzip - -RUN rm -rf /var/lib/apt/lists/* - -COPY opencvmini/install-opencvmini.sh . -ENV OPENCV_VERSION=4.8.0 -RUN [ "/bin/bash", "install-opencvmini.sh" ] diff --git a/utils/docker/Dockerfile.ubuntu-base b/utils/docker/Dockerfile.ubuntu-base new file mode 100644 index 00000000..18009a71 --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu-base @@ -0,0 +1,43 @@ +ARG TOOLCHAIN=clang +FROM ubuntu:22.04 AS base + +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y \ + cmake \ + curl \ + dpkg-dev \ + git \ + llvm-15-dev \ + liblld-15-dev \ + ninja-build \ + software-properties-common \ + wget \ + zlib1g-dev + +### deps for clang ### +FROM base AS deps-clang + +RUN apt-get update && \ + apt-get install -y \ + clang-15 + +ENV CC=/usr/bin/clang-15 +ENV CXX=/usr/bin/clang++-15 + +### deps for gcc ### +FROM base AS deps-gcc + +RUN apt-get update && \ + apt-get install -y \ + gcc \ + g++ + +ENV CC=gcc +ENV CXX=g++ + +### deps for all ### +FROM deps-${TOOLCHAIN} AS final + +RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps new file mode 100644 index 00000000..a10ccec5 --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -0,0 +1,13 @@ +ARG BASE_IMAGE=wasmedge/wasmedge:latest +FROM ${BASE_IMAGE} AS base + +RUN apt-get update && \ + apt-get install -y \ + unzip \ + wget + +COPY opencvmini/install-opencvmini.sh . +ENV OPENCV_VERSION="4.8.0" +RUN [ "/bin/bash", "install-opencvmini.sh" ] + +RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/build.sh b/utils/docker/build.sh deleted file mode 100755 index 71a485bf..00000000 --- a/utils/docker/build.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env bash -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# SPDX-FileCopyrightText: 2019-2024 Second State INC - -NAME=${1:+$1/}wasmedge -INTERMEDIATES=() -IMAGES=() - -set -e - -function docker_build -{ - local FILENAME=$1; shift - local TAG=$1; shift - local NAME_TAG=${NAME}:${TAG} - echo "Building docker image \"${NAME_TAG}\" from file \"${FILENAME}\"." - - ( set -x; docker build "$@" -f "docker/${FILENAME}" -t "${NAME_TAG}" . ) - - if [[ "${TAG}" == im-* ]]; then - INTERMEDIATES+=( "${NAME_TAG}" ) - else - IMAGES+=( "${NAME_TAG}" ) - fi -} - -# Build all images. -docker_build Dockerfile.base ubuntu-base -docker_build Dockerfile.ci-image-base ci-image-base -docker_build Dockerfile.build-clang ubuntu-build-clang \ - --build-arg "BASE=${NAME}:ubuntu-base" -docker_build Dockerfile.build-clang latest \ - --build-arg "BASE=${NAME}:ubuntu-base" -docker_build Dockerfile.build-gcc ubuntu-build-gcc \ - --build-arg "BASE=${NAME}:ubuntu-base" -docker_build Dockerfile.build-plugins-deps ubuntu-build-clang-plugins-deps \ - --build-arg "BASE=${NAME}:ubuntu-build-clang" -docker_build Dockerfile.build-plugins-deps ubuntu-build-gcc-plugins-deps \ - --build-arg "BASE=${NAME}:ubuntu-build-gcc" - -# Remove intermediate images. -for NAME_TAG in "${INTERMEDIATES[@]}"; do - ( set -x; docker rmi "${NAME_TAG}" ) -done - -# Push all images. -for NAME_TAG in "${IMAGES[@]}"; do - ( set -x; docker push "${NAME_TAG}" ) -done diff --git a/utils/docker/docker-bake.ubuntu.hcl b/utils/docker/docker-bake.ubuntu.hcl new file mode 100644 index 00000000..eb3a534d --- /dev/null +++ b/utils/docker/docker-bake.ubuntu.hcl @@ -0,0 +1,56 @@ +group "default" { + targets = [ + "clang", + "clang-plugins", + "gcc", + "gcc-plugins" + ] +} + +target "base" { + dockerfile = "Dockerfile.ubuntu-base" + context = "./utils/docker" +} + +target "plugins-base" { + dockerfile = "./docker/Dockerfile.ubuntu-plugins-deps" + context = "./utils" +} + +target "clang" { + inherits = ["base"] + tags = [ + "wasmedge/wasmedge:latest", + "wasmedge/wasmedge:ubuntu-build-clang" + ] +} + +target "clang-plugins" { + inherits = ["plugins-base"] + tags = ["wasmedge/wasmedge:ubuntu-build-clang-plugins-deps"] + contexts = { + "wasmedge/wasmedge:ubuntu-build-clang" = "target:base" + } + args = { + BASE_IMAGE = "wasmedge/wasmedge:ubuntu-build-clang" + } +} + +target "gcc" { + inherits = ["base"] + tags = ["wasmedge/wasmedge:ubuntu-build-gcc"] + args = { + TOOLCHAIN = "gcc" + } +} + +target "gcc-plugins" { + inherits = ["plugins-base"] + tags = ["wasmedge/wasmedge:ubuntu-build-gcc-plugins-deps"] + contexts = { + "wasmedge/wasmedge:ubuntu-build-gcc" = "target:base" + } + args = { + BASE_IMAGE = "wasmedge/wasmedge:ubuntu-build-gcc" + } +} From 9c9ee9b9fbfd79fe2da1c7973cb3e59c2f0d17c6 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Wed, 21 Aug 2024 19:43:16 +0800 Subject: [PATCH 406/623] [Docker] Fix warnings for manylinux images - FromAsCasing - InvalidDefaultArgInFrom - MaintainerDeprecated Signed-off-by: Yi Huang --- utils/docker/Dockerfile.manylinux_2_28-base | 5 +---- utils/docker/Dockerfile.manylinux_2_28-plugins-deps | 6 +++--- utils/docker/docker-bake.manylinux.hcl | 4 ---- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux_2_28-base b/utils/docker/Dockerfile.manylinux_2_28-base index 7166f1fa..30b64c31 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-base +++ b/utils/docker/Dockerfile.manylinux_2_28-base @@ -1,11 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2024 Second State INC - -ARG BASE_IMAGE +ARG BASE_IMAGE="quay.io/pypa/manylinux_2_28_x86_64" FROM ${BASE_IMAGE} -MAINTAINER hydai hydai@secondstate.io - ADD SHA256SUM.manylinux_2_28 /root/ # See /opt/rh/gcc-toolset-13/enable diff --git a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps index 1f907cb0..ec7be56d 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps @@ -1,5 +1,5 @@ -ARG BASE_IMAGE -FROM ${BASE_IMAGE} as base +ARG BASE_IMAGE="wasmedge/wasmedge:manylinux_2_28_x86_64" +FROM ${BASE_IMAGE} AS base WORKDIR /root @@ -20,7 +20,7 @@ RUN cd && (yum check-update || true) && \ yum install -y wget unzip zlib-devel zlib-static ### deps for all ### -FROM deps-${TARGETARCH} as final +FROM deps-${TARGETARCH} AS final COPY opencvmini/install-opencvmini.sh . ENV OPENCV_VERSION="4.8.0" diff --git a/utils/docker/docker-bake.manylinux.hcl b/utils/docker/docker-bake.manylinux.hcl index 306081cb..0d122ffc 100644 --- a/utils/docker/docker-bake.manylinux.hcl +++ b/utils/docker/docker-bake.manylinux.hcl @@ -13,7 +13,6 @@ target "x86_64" { platforms = ["linux/amd64"] tags = ["wasmedge/wasmedge:manylinux_2_28_x86_64"] args = { - BASE_IMAGE = "quay.io/pypa/manylinux_2_28_x86_64", LLVM_TARGETS = "X86;BPF", LLVM_TRIPLE = "x86_64-pc-linux-gnu" } @@ -26,9 +25,6 @@ target "x86_64-plugins" { contexts = { "wasmedge/wasmedge:manylinux_2_28_x86_64"= "target:x86_64" } - args = { - BASE_IMAGE = "wasmedge/wasmedge:manylinux_2_28_x86_64" - } } target "aarch64" { From b98b0526b8fb5dfd510483c5361b2086b7e960c2 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Thu, 22 Aug 2024 14:36:26 +0800 Subject: [PATCH 407/623] [Docker] Fix Github warnings: remove MAINTAINER field Signed-off-by: Yi Huang --- utils/docker/Dockerfile.ci-image-base | 1 - 1 file changed, 1 deletion(-) diff --git a/utils/docker/Dockerfile.ci-image-base b/utils/docker/Dockerfile.ci-image-base index 92f7bd67..15a854af 100644 --- a/utils/docker/Dockerfile.ci-image-base +++ b/utils/docker/Dockerfile.ci-image-base @@ -1,6 +1,5 @@ FROM ubuntu:22.04 -MAINTAINER hydai hydai@secondstate.io ENV DEBIAN_FRONTEND=noninteractive RUN apt update && apt upgrade -y \ From 1d0b57c6683cd55d74ba56cec833aef8db76dddb Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Sat, 24 Aug 2024 00:18:07 +0800 Subject: [PATCH 408/623] [CMake] Revert ff06762d: remove the component of cpack until 0.15.0. (#3689) Revert "[CMake] Add the WasmEdge component of cpack. (#3662)" This reverts commit ff06762dbab774a8e9c617dd847e34285297178b. Signed-off-by: YiYing He --- plugins/wasi_crypto/CMakeLists.txt | 1 - plugins/wasi_http/CMakeLists.txt | 1 - plugins/wasi_llm/CMakeLists.txt | 1 - plugins/wasi_nn/CMakeLists.txt | 1 - plugins/wasi_ocr/CMakeLists.txt | 1 - plugins/wasi_poll/CMakeLists.txt | 1 - plugins/wasmedge_ffmpeg/CMakeLists.txt | 1 - plugins/wasmedge_image/CMakeLists.txt | 1 - plugins/wasmedge_opencvmini/CMakeLists.txt | 1 - plugins/wasmedge_process/CMakeLists.txt | 1 - plugins/wasmedge_stablediffusion/CMakeLists.txt | 1 - plugins/wasmedge_tensorflow/CMakeLists.txt | 1 - plugins/wasmedge_tensorflowlite/CMakeLists.txt | 1 - plugins/wasmedge_zlib/CMakeLists.txt | 1 - 14 files changed, 14 deletions(-) diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt index 48ae00c9..cc8c8dbe 100644 --- a/plugins/wasi_crypto/CMakeLists.txt +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -82,5 +82,4 @@ endif() install( TARGETS wasmedgePluginWasiCrypto DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) diff --git a/plugins/wasi_http/CMakeLists.txt b/plugins/wasi_http/CMakeLists.txt index 45bc0030..3c7321fd 100644 --- a/plugins/wasi_http/CMakeLists.txt +++ b/plugins/wasi_http/CMakeLists.txt @@ -45,5 +45,4 @@ endif() install( TARGETS wasmedgePluginWasiHttp DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) diff --git a/plugins/wasi_llm/CMakeLists.txt b/plugins/wasi_llm/CMakeLists.txt index 440b5f51..11b96d92 100644 --- a/plugins/wasi_llm/CMakeLists.txt +++ b/plugins/wasi_llm/CMakeLists.txt @@ -48,5 +48,4 @@ endif() install( TARGETS wasmedgePluginWasiLLM DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 780ec898..a40f3b80 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -228,5 +228,4 @@ wasmedge_setup_wasinn_target(wasmedgePluginWasiNN) install( TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) diff --git a/plugins/wasi_ocr/CMakeLists.txt b/plugins/wasi_ocr/CMakeLists.txt index 4690ebae..40fe5f8b 100644 --- a/plugins/wasi_ocr/CMakeLists.txt +++ b/plugins/wasi_ocr/CMakeLists.txt @@ -34,7 +34,6 @@ endif() install( TARGETS wasmedgePluginWasiOCR DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) message(STATUS "WASI-OCR: Build Tesseract backend for WASI-OCR") diff --git a/plugins/wasi_poll/CMakeLists.txt b/plugins/wasi_poll/CMakeLists.txt index 9c641135..0ffa67c6 100644 --- a/plugins/wasi_poll/CMakeLists.txt +++ b/plugins/wasi_poll/CMakeLists.txt @@ -39,5 +39,4 @@ endif() install( TARGETS wasmedgePluginWasiPoll DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_ffmpeg/CMakeLists.txt b/plugins/wasmedge_ffmpeg/CMakeLists.txt index 0a1ff4a8..47e72cdd 100644 --- a/plugins/wasmedge_ffmpeg/CMakeLists.txt +++ b/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -88,5 +88,4 @@ endif() install( TARGETS wasmedgePluginWasmEdgeFFmpeg DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_image/CMakeLists.txt b/plugins/wasmedge_image/CMakeLists.txt index 800c3a94..8360a40d 100644 --- a/plugins/wasmedge_image/CMakeLists.txt +++ b/plugins/wasmedge_image/CMakeLists.txt @@ -143,5 +143,4 @@ endif() install( TARGETS wasmedgePluginWasmEdgeImage DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_opencvmini/CMakeLists.txt b/plugins/wasmedge_opencvmini/CMakeLists.txt index ed10e816..eaec2009 100644 --- a/plugins/wasmedge_opencvmini/CMakeLists.txt +++ b/plugins/wasmedge_opencvmini/CMakeLists.txt @@ -39,5 +39,4 @@ endif() install( TARGETS wasmedgePluginWasmEdgeOpenCVMini DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt index 28a4bcce..819a64d7 100644 --- a/plugins/wasmedge_process/CMakeLists.txt +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -34,5 +34,4 @@ endif() install( TARGETS wasmedgePluginWasmEdgeProcess DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 313449b6..28214e52 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -107,5 +107,4 @@ endif() install( TARGETS wasmedgePluginWasmEdgeStableDiffusion DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_tensorflow/CMakeLists.txt b/plugins/wasmedge_tensorflow/CMakeLists.txt index ccfe25ed..96147ee7 100644 --- a/plugins/wasmedge_tensorflow/CMakeLists.txt +++ b/plugins/wasmedge_tensorflow/CMakeLists.txt @@ -37,5 +37,4 @@ wasmedge_setup_tf_target(wasmedgePluginWasmEdgeTensorflow) install( TARGETS wasmedgePluginWasmEdgeTensorflow DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_tensorflowlite/CMakeLists.txt b/plugins/wasmedge_tensorflowlite/CMakeLists.txt index f8ee177d..30695009 100644 --- a/plugins/wasmedge_tensorflowlite/CMakeLists.txt +++ b/plugins/wasmedge_tensorflowlite/CMakeLists.txt @@ -37,5 +37,4 @@ wasmedge_setup_tflite_target(wasmedgePluginWasmEdgeTensorflowLite) install( TARGETS wasmedgePluginWasmEdgeTensorflowLite DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_zlib/CMakeLists.txt b/plugins/wasmedge_zlib/CMakeLists.txt index 56745021..d75e577c 100644 --- a/plugins/wasmedge_zlib/CMakeLists.txt +++ b/plugins/wasmedge_zlib/CMakeLists.txt @@ -40,5 +40,4 @@ endif() install( TARGETS wasmedgePluginWasmEdgeZlib DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge - COMPONENT WasmEdge ) From 7aee627566576089aa7978be8a0d719a2fe7a6a3 Mon Sep 17 00:00:00 2001 From: Yi Date: Tue, 27 Aug 2024 02:21:34 +0800 Subject: [PATCH 409/623] [Docker] Add ubuntu 20.04 to bake list (#3694) Signed-off-by: Yi Huang --- utils/docker/docker-bake.ubuntu.hcl | 80 ++++++++++++++++------------- 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/utils/docker/docker-bake.ubuntu.hcl b/utils/docker/docker-bake.ubuntu.hcl index eb3a534d..6f6a5afc 100644 --- a/utils/docker/docker-bake.ubuntu.hcl +++ b/utils/docker/docker-bake.ubuntu.hcl @@ -1,56 +1,66 @@ group "default" { targets = [ - "clang", - "clang-plugins", - "gcc", - "gcc-plugins" + "base", + "latest", + "plugins" ] } +function "name" { + params = [toolchain, ubuntu] + result = "${toolchain}-ubuntu${replace(ubuntu, ".", "")}" +} + +function "tag" { + params = [toolchain, ubuntu] + result = equal(ubuntu, "22.04") ? "ubuntu-build-${toolchain}" : "ubuntu-${ubuntu}-build-${toolchain}" +} + +variable "matrix" { + default = { + toolchain = ["clang", "gcc"] + ubuntu = ["20.04", "22.04"] + } +} + target "base" { + matrix = matrix + name = name(toolchain, ubuntu) + dockerfile = "Dockerfile.ubuntu-base" context = "./utils/docker" + + tags = ["wasmedge/wasmedge:${tag(toolchain, ubuntu)}"] + args = { + TOOLCHAIN = toolchain + } } -target "plugins-base" { +target "plugins" { + matrix = matrix + name = "${name(toolchain, ubuntu)}-plugins" + dockerfile = "./docker/Dockerfile.ubuntu-plugins-deps" context = "./utils" -} - -target "clang" { - inherits = ["base"] - tags = [ - "wasmedge/wasmedge:latest", - "wasmedge/wasmedge:ubuntu-build-clang" - ] -} -target "clang-plugins" { - inherits = ["plugins-base"] - tags = ["wasmedge/wasmedge:ubuntu-build-clang-plugins-deps"] - contexts = { - "wasmedge/wasmedge:ubuntu-build-clang" = "target:base" - } - args = { - BASE_IMAGE = "wasmedge/wasmedge:ubuntu-build-clang" + contexts = { + "wasmedge/wasmedge:${tag(toolchain, ubuntu)}" = "target:${name(toolchain, ubuntu)}" } -} -target "gcc" { - inherits = ["base"] - tags = ["wasmedge/wasmedge:ubuntu-build-gcc"] - args = { - TOOLCHAIN = "gcc" + tags = ["wasmedge/wasmedge:${tag(toolchain, ubuntu)}-plugins-deps"] + args = { + BASE_IMAGE = "wasmedge/wasmedge:${tag(toolchain, ubuntu)}" } } -target "gcc-plugins" { - inherits = ["plugins-base"] - tags = ["wasmedge/wasmedge:ubuntu-build-gcc-plugins-deps"] - contexts = { - "wasmedge/wasmedge:ubuntu-build-gcc" = "target:base" +target "latest" { + matrix = { + toolchain = ["clang"] + ubuntu = ["22.04"] } - args = { - BASE_IMAGE = "wasmedge/wasmedge:ubuntu-build-gcc" + inherits = ["${name(toolchain, ubuntu)}"] + contexts = { + "wasmedge/wasmedge:${tag(toolchain, ubuntu)}" = "target:${name(toolchain, ubuntu)}" } + tags = ["wasmedge/wasmedge:latest"] } From 138f382f9b708eac69fb49d6b69638fecdad30ca Mon Sep 17 00:00:00 2001 From: Yi Date: Wed, 28 Aug 2024 21:17:38 +0800 Subject: [PATCH 410/623] [Docker] Apply ubuntu version from bake file (#3703) Signed-off-by: Yi Huang --- utils/docker/Dockerfile.ubuntu-base | 54 ++++++++++++++++++++++------- utils/docker/docker-bake.ubuntu.hcl | 3 +- 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/utils/docker/Dockerfile.ubuntu-base b/utils/docker/Dockerfile.ubuntu-base index 18009a71..195dd4a4 100644 --- a/utils/docker/Dockerfile.ubuntu-base +++ b/utils/docker/Dockerfile.ubuntu-base @@ -1,5 +1,6 @@ +ARG UBUNTU_VER=22 ARG TOOLCHAIN=clang -FROM ubuntu:22.04 AS base +FROM ubuntu:${UBUNTU_VER}.04 AS base ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && \ @@ -9,28 +10,57 @@ RUN apt-get update && \ curl \ dpkg-dev \ git \ - llvm-15-dev \ - liblld-15-dev \ ninja-build \ software-properties-common \ wget \ zlib1g-dev -### deps for clang ### -FROM base AS deps-clang +### deps for ubuntu 20.04 ### +FROM base AS deps-20 -RUN apt-get update && \ - apt-get install -y \ +RUN apt-get install -y \ + llvm-12-dev \ + liblld-12-dev + +### deps for ubuntu 22.04 ### +FROM base AS deps-22 + +RUN apt-get install -y \ + llvm-15-dev \ + liblld-15-dev + +### deps for clang / ubuntu 20.04 ### +FROM deps-20 AS deps-20-clang + +RUN apt-get install -y \ + clang-12 + +ENV CC=/usr/bin/clang-12 +ENV CXX=/usr/bin/clang++-12 + +### deps for clang / ubuntu 22.04 ### +FROM deps-22 AS deps-22-clang + +RUN apt-get install -y \ clang-15 ENV CC=/usr/bin/clang-15 ENV CXX=/usr/bin/clang++-15 -### deps for gcc ### -FROM base AS deps-gcc +### deps for gcc / ubuntu 20.04 ### +FROM deps-20 AS deps-20-gcc -RUN apt-get update && \ - apt-get install -y \ +RUN apt-get install -y \ + gcc \ + g++ + +ENV CC=gcc +ENV CXX=g++ + +### deps for gcc / ubuntu 22.04 ### +FROM deps-22 AS deps-22-gcc + +RUN apt-get install -y \ gcc \ g++ @@ -38,6 +68,6 @@ ENV CC=gcc ENV CXX=g++ ### deps for all ### -FROM deps-${TOOLCHAIN} AS final +FROM deps-${UBUNTU_VER}-${TOOLCHAIN} AS final RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/docker-bake.ubuntu.hcl b/utils/docker/docker-bake.ubuntu.hcl index 6f6a5afc..9232c75f 100644 --- a/utils/docker/docker-bake.ubuntu.hcl +++ b/utils/docker/docker-bake.ubuntu.hcl @@ -32,7 +32,8 @@ target "base" { tags = ["wasmedge/wasmedge:${tag(toolchain, ubuntu)}"] args = { - TOOLCHAIN = toolchain + TOOLCHAIN = toolchain + UBUNTU_VER = replace(ubuntu, ".04", "") } } From 26703b22c727366682b17dfb25626d73e3551786 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Thu, 29 Aug 2024 16:12:38 +0800 Subject: [PATCH 411/623] [Docker] Add aarch64 for ubuntu 20.04 Signed-off-by: Yi Huang --- utils/docker/docker-bake.ubuntu.hcl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/utils/docker/docker-bake.ubuntu.hcl b/utils/docker/docker-bake.ubuntu.hcl index 9232c75f..68b9b87d 100644 --- a/utils/docker/docker-bake.ubuntu.hcl +++ b/utils/docker/docker-bake.ubuntu.hcl @@ -65,3 +65,17 @@ target "latest" { } tags = ["wasmedge/wasmedge:latest"] } + +target "clang-ubuntu2004-aarch64" { + matrix = { + toolchain = ["clang"] + ubuntu = ["20.04"] + } + inherits = ["${name(toolchain, ubuntu)}"] + contexts = { + "wasmedge/wasmedge:${tag(toolchain, ubuntu)}" = "target:${name(toolchain, ubuntu)}" + } + + tags = ["wasmedge/wasmedge:${tag(toolchain, ubuntu)}-aarch64"] + platforms = ["linux/arm64"] +} From 5f304eae3dac329cd9da571af616b5c9e2588c02 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 27 Aug 2024 16:25:29 +0800 Subject: [PATCH 412/623] [Misc] Replace all usage of `std::cout` to `fmt::print` * Replace `std::stringstream` to `fmt::format` * Add `fmt::format` support for `uint128_t` * Add tests for `fmt::print` `uint128_t` * Fix potential buffer overflow in wasi_ocr Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_nn/ggml.cpp | 29 +++++++++++------------- plugins/wasi_nn/piper.cpp | 1 - plugins/wasi_ocr/wasiocrfunc.cpp | 8 +++---- plugins/wasm_bpf/wasm-bpf-module.cpp | 1 - plugins/wasmedge_zlib/zlibfunc.cpp | 1 - test/plugins/wasm_bpf/wasm_bpf.cpp | 33 ++++++++++------------------ 6 files changed, 29 insertions(+), 44 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index d196f944..6268e322 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -355,13 +356,12 @@ Expect setupContextParam(Graph &GraphRef, Expect buildOutputMetadata(Context &CxtRef, std::string &Metadata) noexcept { - std::ostringstream OS; - OS << R"({"input_tokens": )" << CxtRef.LlamaNInputs - << R"(, "output_tokens": )" << CxtRef.LlamaOutputTokens.size() - << R"(, "llama_build_number": )" << LLAMA_BUILD_NUMBER - << R"(, "llama_commit": ")" << LLAMA_COMMIT << R"("})"; - Metadata = OS.str(); - + Metadata = fmt::format(R"({{"input_tokens": {}, )" + R"("output_tokens": {}, )" + R"("llama_build_number": {}, )" + R"("llama_commit": "{}"}})"sv, + CxtRef.LlamaNInputs, CxtRef.LlamaOutputTokens.size(), + LLAMA_BUILD_NUMBER, LLAMA_COMMIT); return ErrNo::Success; } @@ -378,14 +378,10 @@ void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, // | (n_embedding-1)*(',') | // | ']' | // | '}' | - std::ostringstream OS; - OS.precision(10); - OS << R"({"n_embedding": )" << NEmbd << R"(, "embedding": [)"; - for (int32_t Idx = 0; Idx < NEmbd - 1; Idx++) { - OS << Embeddings[Idx] << ","; - } - OS << Embeddings[NEmbd - 1] << "]}"; - Embedding = OS.str(); + Embedding = + fmt::format(R"({{"n_embedding": {:.10}, )" + R"("embedding": [{:.10}]}})"sv, + NEmbd, fmt::join(Embeddings, Embeddings + NEmbd, ","sv)); } ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, @@ -1188,7 +1184,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { CxtRef.LlamaOutputs += llama_token_to_piece(LlamaContext, Id); // When setting StreamStdout, we print the output to stdout. if (GraphRef.StreamStdout) { - std::cout << llama_token_to_piece(LlamaContext, Id) << std::flush; + fmt::print("{}"sv, llama_token_to_piece(LlamaContext, Id)); + std::fflush(stdout); } // Break if reverse prompt is found. if (!GraphRef.ReversePrompt.empty() && diff --git a/plugins/wasi_nn/piper.cpp b/plugins/wasi_nn/piper.cpp index c764c882..fdc5c262 100644 --- a/plugins/wasi_nn/piper.cpp +++ b/plugins/wasi_nn/piper.cpp @@ -12,7 +12,6 @@ #include #include #include -#include #include #include #include diff --git a/plugins/wasi_ocr/wasiocrfunc.cpp b/plugins/wasi_ocr/wasiocrfunc.cpp index f63bd021..da97b794 100644 --- a/plugins/wasi_ocr/wasiocrfunc.cpp +++ b/plugins/wasi_ocr/wasiocrfunc.cpp @@ -4,7 +4,7 @@ #include "wasiocrfunc.h" #include "common/spdlog.h" -#include +#include #include namespace WasmEdge { @@ -54,12 +54,12 @@ Expect WasiOCRGetOutput::body(const Runtime::CallingFrame &Frame, } tesseract::PageIteratorLevel level = tesseract::RIL_WORD; - const char *outText = Env.TesseractApi->GetTSVText(level); - std::strcpy(Buf.data(), outText); + std::unique_ptr outText = Env.TesseractApi->GetTSVText(level); + std::copy_n(outText, std::min(std::strlen(outText.get()), Buf.size()), + Buf.begin()); // remaining free and deltee memory stuff Env.TesseractApi->End(); - delete[] outText; // USE WHEN USING TESS API return static_cast(WASIOCR::ErrNo::Success); // return outText; diff --git a/plugins/wasm_bpf/wasm-bpf-module.cpp b/plugins/wasm_bpf/wasm-bpf-module.cpp index 8e564a6d..a240e6a4 100644 --- a/plugins/wasm_bpf/wasm-bpf-module.cpp +++ b/plugins/wasm_bpf/wasm-bpf-module.cpp @@ -13,7 +13,6 @@ #include "runtime/callingframe.h" #include "state.h" #include -#include namespace WasmEdge { namespace Host { diff --git a/plugins/wasmedge_zlib/zlibfunc.cpp b/plugins/wasmedge_zlib/zlibfunc.cpp index 5ee8d343..e2f7d3ef 100644 --- a/plugins/wasmedge_zlib/zlibfunc.cpp +++ b/plugins/wasmedge_zlib/zlibfunc.cpp @@ -4,7 +4,6 @@ #include "zlibfunc.h" #include -#include namespace WasmEdge { namespace Host { diff --git a/test/plugins/wasm_bpf/wasm_bpf.cpp b/test/plugins/wasm_bpf/wasm_bpf.cpp index 705015c3..bdef8d10 100644 --- a/test/plugins/wasm_bpf/wasm_bpf.cpp +++ b/test/plugins/wasm_bpf/wasm_bpf.cpp @@ -16,13 +16,12 @@ #include #include #include -#include #include #include +#include +#include #include #include -#include -#include #include #include #include @@ -106,24 +105,18 @@ class PollCallbackFunction if (unlikely(!dataPtr)) { return WasmEdge::Unexpect(WasmEdge::ErrCode::Value::HostFuncError); } - auto nowTime = chrono::system_clock::to_time_t(chrono::system_clock::now()); - tm nowTimeRepr; - localtime_r(&nowTime, &nowTimeRepr); + auto nowTime = chrono::system_clock::now(); if (dataPtr->exit_event == 1) { - cout.setf(ios::left); - cout << std::put_time(&nowTimeRepr, "%H:%M:%S") << " EXIT " << setw(16) - << setfill(' ') << dataPtr->comm << " " << setw(7) << setfill(' ') - << dataPtr->pid << " " << setw(7) << setfill(' ') << dataPtr->ppid - << " [" << dataPtr->exit_code << "]"; + fmt::print("{:%H:%M:%S} EXIT {:<16} {:<7} {:<7} [{}]"sv, nowTime, + dataPtr->comm, dataPtr->pid, dataPtr->ppid, + dataPtr->exit_code); if (dataPtr->duration_ns != 0) { - cout << " (" << dataPtr->duration_ns / 1000000 << ")" << endl; + fmt::print(" ({})"sv, dataPtr->duration_ns / 1000000); } + fmt::print("\n"sv); } else { - cout.setf(ios::left); - cout << std::put_time(&nowTimeRepr, "%H:%M:%S") << " EXEC " << setw(16) - << setfill(' ') << dataPtr->comm << " " << setw(7) << setfill(' ') - << dataPtr->pid << " " << setw(7) << setfill(' ') << dataPtr->ppid - << " " << dataPtr->filename << endl; + fmt::print("{:%H:%M:%S} EXEC {:<16} {:<7} {:<7} {}\n"sv, nowTime, + dataPtr->comm, dataPtr->pid, dataPtr->ppid, dataPtr->filename); } return 0; } @@ -543,9 +536,7 @@ TEST(WasmBpfTest, RunBpfProgramWithMapOperation) { for (size_t i = 0; i < maxIdx; i++) { auto low = UINT64_C(1) << (i); auto high = (UINT64_C(1) << (i + 1)) - 1; - cout.setf(ios::left); - cout << setw(6) << low << "..." << setw(6) << high << " " << setw(6) - << histRef.slots[i] << endl; + fmt::print("{:<6}...{:<6} {:<6}\n"sv, low, high, histRef.slots[i]); } writeU32(lookUpKeyOffset, readU32(nextKeyOffset)); } @@ -554,7 +545,7 @@ TEST(WasmBpfTest, RunBpfProgramWithMapOperation) { EXPECT_GE(mapDeleteElem(histsFd, nextKeyOffset), 0); writeU32(lookUpKeyOffset, readU32(nextKeyOffset)); } - cout << endl; + fmt::print("\n"sv); } // Get function `wasm_close_bpf_object` From a89da8f6517588e804a1e68271b6f954ade3029e Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Fri, 30 Aug 2024 16:40:20 +0800 Subject: [PATCH 413/623] [Misc] Replace raw pointers to `std::unique_ptr` in unittest Signed-off-by: Shen-Ta Hsieh --- test/plugins/unittest/unittest_cpp.cpp | 48 ++++---- test/plugins/wasi_crypto/common.cpp | 3 +- test/plugins/wasi_crypto/helper.cpp | 3 +- test/plugins/wasi_crypto/helper.h | 62 +++++------ test/plugins/wasi_llm/wasi_llm.cpp | 25 +++-- test/plugins/wasi_logging/wasi_logging.cpp | 33 +++--- test/plugins/wasi_nn/wasi_nn.cpp | 72 ++++++------ test/plugins/wasm_bpf/wasm_bpf.cpp | 32 ++++-- test/plugins/wasmedge_ffmpeg/utils.h | 104 +++++++++--------- .../plugins/wasmedge_image/wasmedge_image.cpp | 26 +++-- .../wasmedge_opencvmini.cpp | 25 +++-- .../wasmedge_process/wasmedge_process.cpp | 89 +++++++-------- .../wasmedge_stablediffusion.cpp | 22 +++- .../wasmedge_tensorflow.cpp | 26 +++-- .../wasmedge_tensorflowlite.cpp | 27 +++-- test/plugins/wasmedge_zlib/wasmedge_zlib.cpp | 32 ++++-- 16 files changed, 353 insertions(+), 276 deletions(-) diff --git a/test/plugins/unittest/unittest_cpp.cpp b/test/plugins/unittest/unittest_cpp.cpp index 806e855a..9410956e 100644 --- a/test/plugins/unittest/unittest_cpp.cpp +++ b/test/plugins/unittest/unittest_cpp.cpp @@ -11,11 +11,23 @@ #include #include #include +#include #include #include namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModuleC() { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModuleC() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "./" WASMEDGE_LIB_PREFIX @@ -24,13 +36,13 @@ WasmEdge::Runtime::Instance::ModuleInstance *createModuleC() { WasmEdge::Plugin::Plugin::find("wasmedge_plugintest_c"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_plugintest_c_module"sv)) { - return Module->create().release(); + return Module->create(); } } - return nullptr; + return {}; } -WasmEdge::Runtime::Instance::ModuleInstance *createModuleCPP() { +std::unique_ptr createModuleCPP() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "./" WASMEDGE_LIB_PREFIX @@ -45,18 +57,18 @@ WasmEdge::Runtime::Instance::ModuleInstance *createModuleCPP() { Parser.set_raw_value("opt"sv); if (const auto *Module = Plugin->findModule("wasmedge_plugintest_cpp_module"sv)) { - return Module->create().release(); + return dynamicPointerCast( + Module->create()); } } - return nullptr; + return {}; } } // namespace TEST(wasmedgePluginTests, CPP_Run) { // Create the wasmedge_plugintest_cpp_module module instance. - auto *TestModCPP = dynamic_cast( - createModuleCPP()); - ASSERT_FALSE(TestModCPP == nullptr); + auto TestModCPP = createModuleCPP(); + ASSERT_TRUE(TestModCPP); WasmEdge::Runtime::Instance::ModuleInstance Mod(""); WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); @@ -98,36 +110,30 @@ TEST(wasmedgePluginTests, CPP_Run) { EXPECT_TRUE(HostFuncInst3.run(CallFrame, {}, RetVal)); EXPECT_EQ(RetVal[0].get(), 1); - delete TestModCPP; - // Create the wasmedge_plugintest_c_module module instance. - auto *TestModC = createModuleC(); - ASSERT_FALSE(TestModC == nullptr); + auto TestModC = createModuleC(); + ASSERT_TRUE(TestModC); // The host functions are implemented in the C API. // Therefore not test to invoke them here. - delete TestModC; } TEST(wasmedgePluginTests, CPP_Module) { // Create the wasmedge_plugintest_cpp_module module instance. - auto *TestModCPP = dynamic_cast( - createModuleCPP()); - ASSERT_FALSE(TestModCPP == nullptr); + auto TestModCPP = createModuleCPP(); + ASSERT_TRUE(TestModCPP); EXPECT_EQ(TestModCPP->getFuncExportNum(), 5U); EXPECT_NE(TestModCPP->findFuncExports("add"), nullptr); EXPECT_NE(TestModCPP->findFuncExports("sub"), nullptr); EXPECT_NE(TestModCPP->findFuncExports("arg_len"), nullptr); EXPECT_NE(TestModCPP->findFuncExports("opt"), nullptr); EXPECT_NE(TestModCPP->findFuncExports("name_size"), nullptr); - delete TestModCPP; // Create the wasmedge_plugintest_c_module module instance. - auto *TestModC = createModuleC(); - ASSERT_FALSE(TestModC == nullptr); + auto TestModC = createModuleC(); + ASSERT_TRUE(TestModC); EXPECT_EQ(TestModC->getFuncExportNum(), 2U); EXPECT_NE(TestModC->findFuncExports("add"), nullptr); EXPECT_NE(TestModC->findFuncExports("sub"), nullptr); - delete TestModC; } GTEST_API_ int main(int argc, char **argv) { diff --git a/test/plugins/wasi_crypto/common.cpp b/test/plugins/wasi_crypto/common.cpp index 086701cf..e8310c33 100644 --- a/test/plugins/wasi_crypto/common.cpp +++ b/test/plugins/wasi_crypto/common.cpp @@ -5,7 +5,8 @@ #include "helper.h" namespace { -template T *getHostFunc(M *Mod, const char *Name) { +template +inline T *getHostFunc(const M &Mod, const char *Name) { if (Mod) { auto *FuncInst = Mod->findFuncExports(Name); if (FuncInst && FuncInst->isHostFunction()) { diff --git a/test/plugins/wasi_crypto/helper.cpp b/test/plugins/wasi_crypto/helper.cpp index c0a0ce04..f3ca00a7 100644 --- a/test/plugins/wasi_crypto/helper.cpp +++ b/test/plugins/wasi_crypto/helper.cpp @@ -22,7 +22,8 @@ } while (0) namespace { -template T *getHostFunc(M *Mod, const char *Name) { +template +inline T *getHostFunc(M &Mod, const char *Name) { if (Mod) { auto *FuncInst = Mod->findFuncExports(Name); if (FuncInst && FuncInst->isHostFunction()) { diff --git a/test/plugins/wasi_crypto/helper.h b/test/plugins/wasi_crypto/helper.h index b6512a5f..0fd03a50 100644 --- a/test/plugins/wasi_crypto/helper.h +++ b/test/plugins/wasi_crypto/helper.h @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -52,6 +53,16 @@ std::vector operator"" _u8(const char *Str, std::size_t Len); std::vector operator"" _u8v(const char *Str, std::size_t Len); +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + /// Designed for testing. class WasiCryptoTest : public ::testing::Test { public: @@ -68,50 +79,32 @@ class WasiCryptoTest : public ::testing::Test { if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_crypto"sv)) { if (const auto *Module = Plugin->findModule("wasi_crypto_asymmetric_common"sv)) { - WasiCryptoAsymCommonMod = - dynamic_cast( - Module->create().release()); + WasiCryptoAsymCommonMod = dynamicPointerCast< + WasmEdge::Host::WasiCryptoAsymmetricCommonModule>(Module->create()); } if (const auto *Module = Plugin->findModule("wasi_crypto_common"sv)) { WasiCryptoCommonMod = - dynamic_cast( - Module->create().release()); + dynamicPointerCast( + Module->create()); } if (const auto *Module = Plugin->findModule("wasi_crypto_kx"sv)) { - WasiCryptoKxMod = dynamic_cast( - Module->create().release()); + WasiCryptoKxMod = + dynamicPointerCast( + Module->create()); } if (const auto *Module = Plugin->findModule("wasi_crypto_signatures"sv)) { WasiCryptoSignMod = - dynamic_cast( - Module->create().release()); + dynamicPointerCast( + Module->create()); } if (const auto *Module = Plugin->findModule("wasi_crypto_symmetric"sv)) { WasiCryptoSymmMod = - dynamic_cast( - Module->create().release()); + dynamicPointerCast( + Module->create()); } } } - ~WasiCryptoTest() override { - if (WasiCryptoAsymCommonMod) { - delete WasiCryptoAsymCommonMod; - } - if (WasiCryptoCommonMod) { - delete WasiCryptoCommonMod; - } - if (WasiCryptoKxMod) { - delete WasiCryptoKxMod; - } - if (WasiCryptoSignMod) { - delete WasiCryptoSignMod; - } - if (WasiCryptoSymmMod) { - delete WasiCryptoSymmMod; - } - } - protected: void writeDummyMemoryContent(); @@ -398,11 +391,12 @@ class WasiCryptoTest : public ::testing::Test { std::array Errno; - Host::WasiCryptoAsymmetricCommonModule *WasiCryptoAsymCommonMod = nullptr; - Host::WasiCryptoCommonModule *WasiCryptoCommonMod = nullptr; - Host::WasiCryptoKxModule *WasiCryptoKxMod = nullptr; - Host::WasiCryptoSignaturesModule *WasiCryptoSignMod = nullptr; - Host::WasiCryptoSymmetricModule *WasiCryptoSymmMod = nullptr; + std::unique_ptr + WasiCryptoAsymCommonMod; + std::unique_ptr WasiCryptoCommonMod; + std::unique_ptr WasiCryptoKxMod; + std::unique_ptr WasiCryptoSignMod; + std::unique_ptr WasiCryptoSymmMod; }; } // namespace WasiCrypto diff --git a/test/plugins/wasi_llm/wasi_llm.cpp b/test/plugins/wasi_llm/wasi_llm.cpp index 67790397..408c06c6 100644 --- a/test/plugins/wasi_llm/wasi_llm.cpp +++ b/test/plugins/wasi_llm/wasi_llm.cpp @@ -12,23 +12,36 @@ #include #include #include +#include #include #include using WasmEdge::Host::WASILLM::ErrNo; namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load( std::filesystem::u8path("../../../plugins/wasi_llm/" WASMEDGE_LIB_PREFIX "wasmedgePluginWasiLLM" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_llm"sv)) { if (const auto *Module = Plugin->findModule("wasi_llm"sv)) { - return Module->create().release(); + return dynamicPointerCast( + Module->create()); } } - return nullptr; + return {}; } } // namespace @@ -47,8 +60,8 @@ void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, TEST(WasiLLMTest, TrainGPT2) { // Create wasi_llm module instance. - auto *LLMMod = dynamic_cast(createModule()); - EXPECT_NE(LLMMod, nullptr); + auto LLMMod = createModule(); + ASSERT_TRUE(LLMMod); EXPECT_EQ(LLMMod->getFuncExportNum(), 4U); // Create the calling frame with memory instance. @@ -189,8 +202,6 @@ TEST(WasiLLMTest, TrainGPT2) { /*Epoch*/ 20}, Errno)); } - - delete LLMMod; } GTEST_API_ int main(int argc, char **argv) { diff --git a/test/plugins/wasi_logging/wasi_logging.cpp b/test/plugins/wasi_logging/wasi_logging.cpp index 90cad91e..706aeada 100644 --- a/test/plugins/wasi_logging/wasi_logging.cpp +++ b/test/plugins/wasi_logging/wasi_logging.cpp @@ -8,22 +8,32 @@ #include "runtime/instance/module.h" #include - -#include +#include #include namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { using namespace std::literals::string_view_literals; // The built-in plugins are loaded when loading from default paths. WasmEdge::Plugin::Plugin::loadFromDefaultPaths(); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_logging"sv)) { if (const auto *Module = Plugin->findModule("wasi:logging/logging"sv)) { - return Module->create().release(); + return dynamicPointerCast( + Module->create()); } } - return nullptr; + return {}; } void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, @@ -43,12 +53,10 @@ TEST(WasiLoggingTests, func_log) { using namespace std::literals::string_view_literals; // Create the wasi-logging module instance. // Here create 2 wasi-logging modules for testing in multiple modules. - auto WasiLoggingMod1 = - dynamic_cast(createModule()); - EXPECT_NE(WasiLoggingMod1, nullptr); - auto WasiLoggingMod2 = - dynamic_cast(createModule()); - EXPECT_NE(WasiLoggingMod2, nullptr); + auto WasiLoggingMod1 = createModule(); + ASSERT_TRUE(WasiLoggingMod1); + auto WasiLoggingMod2 = createModule(); + ASSERT_TRUE(WasiLoggingMod2); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -173,9 +181,6 @@ TEST(WasiLoggingTests, func_log) { std::initializer_list{ UINT32_C(6), UINT32_C(0), UINT32_C(6), UINT32_C(128), UINT32_C(19)}, {})); - - delete WasiLoggingMod1; - delete WasiLoggingMod2; } GTEST_API_ int main(int argc, char **argv) { diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 76411a25..4d0f4f40 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -28,7 +29,18 @@ using WasmEdge::Host::WASINN::ErrNo; defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) namespace { -WasmEdge::Runtime::Instance::ModuleInstance * + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule(std::string_view NNRPCURI = "") { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load( @@ -41,10 +53,10 @@ createModule(std::string_view NNRPCURI = "") { Parser.set_raw_value("nn-rpc-uri"sv, std::string(NNRPCURI)); } if (const auto *Module = Plugin->findModule("wasi_nn"sv)) { - return Module->create().release(); + return dynamicPointerCast(Module->create()); } } - return nullptr; + return {}; } #if !defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) @@ -104,8 +116,8 @@ std::vector classSort(WasmEdge::Span Array) { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO TEST(WasiNNTest, OpenVINOBackend) { // Create the wasi_nn module instance. - auto *NNMod = dynamic_cast(createModule()); - ASSERT_TRUE(NNMod != nullptr); + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -521,16 +533,14 @@ TEST(WasiNNTest, OpenVINOBackend) { EXPECT_EQ(SortedIndex[I], CorrectClasses[I]); } } - - delete NNMod; } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH TEST(WasiNNTest, PyTorchBackend) { // Create the wasi_nn module instance. - auto *NNMod = dynamic_cast(createModule()); - EXPECT_FALSE(NNMod == nullptr); + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -895,16 +905,14 @@ TEST(WasiNNTest, PyTorchBackend) { EXPECT_EQ(SortedIndex[I], CorrectClasses[I]); } } - - delete NNMod; } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE TEST(WasiNNTest, TFLiteBackend) { // Create the wasi_nn module instance. - auto *NNMod = dynamic_cast(createModule()); - EXPECT_FALSE(NNMod == nullptr); + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -1262,16 +1270,14 @@ TEST(WasiNNTest, TFLiteBackend) { OutputClassification[CorrectClasses[I]]); } } - - delete NNMod; } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML TEST(WasiNNTest, GGMLBackend) { // Create the wasi_nn module instance. - auto *NNMod = dynamic_cast(createModule()); - EXPECT_FALSE(NNMod == nullptr); + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -1551,9 +1557,8 @@ TEST(WasiNNTest, GGMLBackendWithRPC) { } // Create the wasi_nn module instance. - auto *NNMod = - dynamic_cast(createModule(NNRPCURI)); - EXPECT_FALSE(NNMod == nullptr); + auto NNMod = createModule(NNRPCURI); + ASSERT_TRUE(NNMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -1765,8 +1770,6 @@ TEST(WasiNNTest, GGMLBackendWithRPC) { auto BytesWritten = *MemInst.getPointer(BuilderPtr); EXPECT_GE(BytesWritten, 50); } - - delete NNMod; } TEST(WasiNNTest, GGMLBackendComputeSingleWithRPC) { @@ -1787,9 +1790,8 @@ TEST(WasiNNTest, GGMLBackendComputeSingleWithRPC) { } // Create the wasmedge_process module instance. - auto *NNMod = - dynamic_cast(createModule(NNRPCURI)); - EXPECT_FALSE(NNMod == nullptr); + auto NNMod = createModule(NNRPCURI); + ASSERT_TRUE(NNMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -1982,8 +1984,8 @@ TEST(WasiNNTest, GGMLBackendComputeSingleWithRPC) { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED TEST(WasiNNTest, NeuralSpeedBackend) { // Create the wasi_nn module instance. - auto *NNMod = dynamic_cast(createModule()); - ASSERT_TRUE(NNMod != nullptr); + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -2240,16 +2242,14 @@ TEST(WasiNNTest, NeuralSpeedBackend) { Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); } - - delete NNMod; } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER TEST(WasiNNTest, WhisperBackend) { // Create the wasi_nn module instance. - auto *NNMod = dynamic_cast(createModule()); - ASSERT_TRUE(NNMod != nullptr); + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -2473,16 +2473,14 @@ TEST(WasiNNTest, WhisperBackend) { auto BytesWritten = *MemInst.getPointer(BuilderPtr); EXPECT_GE(BytesWritten, 50); } - - delete NNMod; } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER TEST(WasiNNTest, PiperBackend) { // Create the wasmedge_process module instance. - auto *NNMod = dynamic_cast(createModule()); - ASSERT_TRUE(NNMod != nullptr); + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -2727,8 +2725,8 @@ TEST(WasiNNTest, PiperBackend) { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS TEST(WasiNNTest, ChatTTSBackend) { // Create the wasmedge_process module instance. - auto *NNMod = dynamic_cast(createModule()); - ASSERT_TRUE(NNMod != nullptr); + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); diff --git a/test/plugins/wasm_bpf/wasm_bpf.cpp b/test/plugins/wasm_bpf/wasm_bpf.cpp index bdef8d10..827a491d 100644 --- a/test/plugins/wasm_bpf/wasm_bpf.cpp +++ b/test/plugins/wasm_bpf/wasm_bpf.cpp @@ -28,17 +28,29 @@ #include #include namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load( std::filesystem::u8path("../../../plugins/wasm_bpf/" WASMEDGE_LIB_PREFIX "wasmedgePluginWasmBpf" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasm_bpf"sv)) { if (const auto *Module = Plugin->findModule("wasm_bpf"sv)) { - return Module->create().release(); + return dynamicPointerCast( + Module->create()); } } - return nullptr; + return {}; } std::filesystem::path getAssertsPath() { @@ -59,8 +71,8 @@ void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &memInst, } // namespace TEST(WasmBpfTest, Module) { - auto module = dynamic_cast(createModule()); - EXPECT_NE(module, nullptr); + auto module = createModule(); + ASSERT_TRUE(module); // Test whether functions are exported EXPECT_EQ(module->getFuncExportNum(), 6U); EXPECT_NE(module->findFuncExports("wasm_load_bpf_object"), nullptr); @@ -69,8 +81,6 @@ TEST(WasmBpfTest, Module) { EXPECT_NE(module->findFuncExports("wasm_bpf_buffer_poll"), nullptr); EXPECT_NE(module->findFuncExports("wasm_bpf_map_fd_by_name"), nullptr); EXPECT_NE(module->findFuncExports("wasm_bpf_map_operate"), nullptr); - - delete module; } static const size_t TASK_COMM_LEN = 16; @@ -125,8 +135,8 @@ class PollCallbackFunction TEST(WasmBpfTest, RunBpfProgramWithPolling) { using namespace std::literals::string_view_literals; // Test loading and attaching a bpf program, and polling buffer - auto module = dynamic_cast(createModule()); - EXPECT_NE(module, nullptr); + auto module = createModule(); + ASSERT_TRUE(module); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); @@ -339,8 +349,8 @@ struct hist { TEST(WasmBpfTest, RunBpfProgramWithMapOperation) { // Test loading and attaching a bpf program, and polling buffer - auto module = dynamic_cast(createModule()); - EXPECT_NE(module, nullptr); + auto module = createModule(); + ASSERT_TRUE(module); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance moduleInst(""); diff --git a/test/plugins/wasmedge_ffmpeg/utils.h b/test/plugins/wasmedge_ffmpeg/utils.h index 7a66f18f..ce4d771c 100644 --- a/test/plugins/wasmedge_ffmpeg/utils.h +++ b/test/plugins/wasmedge_ffmpeg/utils.h @@ -14,12 +14,23 @@ #include "runtime/callingframe.h" #include "runtime/instance/module.h" -#include "gtest/gtest.h" +#include +#include namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + inline void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance *MemInst, uint32_t Value, uint32_t &Ptr) { uint32_t *BufPtr = MemInst->getPointer(Ptr); @@ -72,64 +83,47 @@ class FFmpegTest : public ::testing::Test { WasmEdge::Plugin::Plugin::find("wasmedge_ffmpeg"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_ffmpeg_avformat"sv)) { - AVFormatMod = dynamic_cast( - Module->create().release()); + AVFormatMod = + dynamicPointerCast( + Module->create()); } if (const auto *Module = Plugin->findModule("wasmedge_ffmpeg_avutil"sv)) { - AVUtilMod = dynamic_cast< - WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::WasmEdgeFFmpegAVUtilModule - *>(Module->create().release()); + AVUtilMod = dynamicPointerCast< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::WasmEdgeFFmpegAVUtilModule>( + Module->create()); } if (const auto *Module = Plugin->findModule("wasmedge_ffmpeg_swscale"sv)) { - SWScaleMod = dynamic_cast< - WasmEdge::Host::WasmEdgeFFmpeg::SWScale::WasmEdgeFFmpegSWScaleModule - *>(Module->create().release()); + SWScaleMod = + dynamicPointerCast( + Module->create()); } if (const auto *Module = Plugin->findModule("wasmedge_ffmpeg_avcodec"sv)) { - AVCodecMod = dynamic_cast< - WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::WasmEdgeFFmpegAVCodecModule - *>(Module->create().release()); + AVCodecMod = + dynamicPointerCast( + Module->create()); } if (const auto *Module = Plugin->findModule("wasmedge_ffmpeg_swresample"sv)) { SWResampleMod = - dynamic_cast( - Module->create().release()); + dynamicPointerCast( + Module->create()); } if (const auto *Module = Plugin->findModule("wasmedge_ffmpeg_avfilter"sv)) { - AVFilterMod = dynamic_cast( - Module->create().release()); + AVFilterMod = + dynamicPointerCast( + Module->create()); } } } - ~FFmpegTest() override { - if (AVUtilMod) { - delete AVUtilMod; - } - if (AVCodecMod) { - delete AVCodecMod; - } - if (SWScaleMod) { - delete SWScaleMod; - } - if (SWResampleMod) { - delete SWResampleMod; - } - if (AVFormatMod) { - delete AVFormatMod; - } - if (AVFilterMod) { - delete AVFilterMod; - } - } - protected: void initEmptyFrame(uint32_t FramePtr); @@ -153,18 +147,24 @@ class FFmpegTest : public ::testing::Test { WasmEdge::Runtime::CallingFrame CallFrame; // Wasm Modules. - WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::WasmEdgeFFmpegAVFormatModule - *AVFormatMod = nullptr; - WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::WasmEdgeFFmpegAVUtilModule - *AVUtilMod = nullptr; - WasmEdge::Host::WasmEdgeFFmpeg::SWResample::WasmEdgeFFmpegSWResampleModule - *SWResampleMod = nullptr; - WasmEdge::Host::WasmEdgeFFmpeg::SWScale::WasmEdgeFFmpegSWScaleModule - *SWScaleMod = nullptr; - WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::WasmEdgeFFmpegAVCodecModule - *AVCodecMod = nullptr; - WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::WasmEdgeFFmpegAVFilterModule - *AVFilterMod = nullptr; + std::unique_ptr< + WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::WasmEdgeFFmpegAVFormatModule> + AVFormatMod; + std::unique_ptr< + WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::WasmEdgeFFmpegAVUtilModule> + AVUtilMod; + std::unique_ptr + SWResampleMod; + std::unique_ptr< + WasmEdge::Host::WasmEdgeFFmpeg::SWScale::WasmEdgeFFmpegSWScaleModule> + SWScaleMod; + std::unique_ptr< + WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::WasmEdgeFFmpegAVCodecModule> + AVCodecMod; + std::unique_ptr< + WasmEdge::Host::WasmEdgeFFmpeg::AVFilter::WasmEdgeFFmpegAVFilterModule> + AVFilterMod; }; } // namespace WasmEdgeFFmpeg diff --git a/test/plugins/wasmedge_image/wasmedge_image.cpp b/test/plugins/wasmedge_image/wasmedge_image.cpp index 0c16db43..be767b4e 100644 --- a/test/plugins/wasmedge_image/wasmedge_image.cpp +++ b/test/plugins/wasmedge_image/wasmedge_image.cpp @@ -10,35 +10,47 @@ #include #include #include +#include #include #include namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "../../../plugins/wasmedge_image/" WASMEDGE_LIB_PREFIX "wasmedgePluginWasmEdgeImage" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_image"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_image"sv)) { - return Module->create().release(); + return dynamicPointerCast( + Module->create()); } } - return nullptr; + return {}; } + } // namespace // TODO: unit tests for every functions. TEST(WasmEdgeImageTest, Module) { // Create the wasmedge_image module instance. - auto *ImgMod = - dynamic_cast(createModule()); - EXPECT_FALSE(ImgMod == nullptr); + auto ImgMod = createModule(); + ASSERT_TRUE(ImgMod); EXPECT_EQ(ImgMod->getFuncExportNum(), 2U); EXPECT_NE(ImgMod->findFuncExports("load_jpg"), nullptr); EXPECT_NE(ImgMod->findFuncExports("load_png"), nullptr); - delete ImgMod; } GTEST_API_ int main(int argc, char **argv) { diff --git a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp index 60e78db0..d2905c80 100644 --- a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp +++ b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp @@ -14,7 +14,18 @@ #include namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "../../../plugins/wasmedge_opencvmini/" WASMEDGE_LIB_PREFIX @@ -22,26 +33,26 @@ WasmEdge::Runtime::Instance::ModuleInstance *createModule() { if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_opencvmini"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_opencvmini"sv)) { - return Module->create().release(); + return dynamicPointerCast( + Module->create()); } } - return nullptr; + return {}; } + } // namespace // TODO: unit tests for every functions. TEST(WasmEdgeOpecvminiTest, Module) { // Create the wasmedge_opencvmini module instance. - auto *ImgMod = - dynamic_cast(createModule()); - EXPECT_FALSE(ImgMod == nullptr); + auto ImgMod = createModule(); + ASSERT_TRUE(ImgMod); EXPECT_EQ(ImgMod->getFuncExportNum(), 19U); EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_imdecode"), nullptr); EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_imencode"), nullptr); EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_rectangle"), nullptr); EXPECT_NE(ImgMod->findFuncExports("wasmedge_opencvmini_cvt_color"), nullptr); - delete ImgMod; } GTEST_API_ int main(int argc, char **argv) { diff --git a/test/plugins/wasmedge_process/wasmedge_process.cpp b/test/plugins/wasmedge_process/wasmedge_process.cpp index fd322f5f..d74e245f 100644 --- a/test/plugins/wasmedge_process/wasmedge_process.cpp +++ b/test/plugins/wasmedge_process/wasmedge_process.cpp @@ -10,14 +10,26 @@ #include #include #include +#include #include #include #include namespace { + WasmEdge::Runtime::CallingFrame DummyCallFrame(nullptr, nullptr); -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "../../../plugins/wasmedge_process/" WASMEDGE_LIB_PREFIX @@ -25,10 +37,11 @@ WasmEdge::Runtime::Instance::ModuleInstance *createModule() { if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_process"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_process"sv)) { - return Module->create().release(); + return dynamicPointerCast( + Module->create()); } } - return nullptr; + return {}; } void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, @@ -47,9 +60,8 @@ using namespace std::literals::string_view_literals; TEST(WasmEdgeProcessTest, SetProgName) { // Create the wasmedge_process module instance. - auto *ProcMod = - dynamic_cast(createModule()); - ASSERT_TRUE(ProcMod != nullptr); + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -86,15 +98,12 @@ TEST(WasmEdgeProcessTest, SetProgName) { DummyCallFrame, std::initializer_list{UINT32_C(0), UINT32_C(4)}, {})); - - delete ProcMod; } TEST(WasmEdgeProcessTest, AddArg) { // Create the wasmedge_process module instance. - auto *ProcMod = - dynamic_cast(createModule()); - ASSERT_TRUE(ProcMod != nullptr); + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -151,15 +160,12 @@ TEST(WasmEdgeProcessTest, AddArg) { DummyCallFrame, std::initializer_list{UINT32_C(0), UINT32_C(4)}, {})); - - delete ProcMod; } TEST(WasmEdgeProcessTest, AddEnv) { // Create the wasmedge_process module instance. - auto *ProcMod = - dynamic_cast(createModule()); - ASSERT_TRUE(ProcMod != nullptr); + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -213,15 +219,12 @@ TEST(WasmEdgeProcessTest, AddEnv) { std::initializer_list{ UINT32_C(0), UINT32_C(4), UINT32_C(4), UINT32_C(6)}, {})); - - delete ProcMod; } TEST(WasmEdgeProcessTest, AddStdIn) { // Create the wasmedge_process module instance. - auto *ProcMod = - dynamic_cast(createModule()); - ASSERT_TRUE(ProcMod != nullptr); + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -272,15 +275,12 @@ TEST(WasmEdgeProcessTest, AddStdIn) { DummyCallFrame, std::initializer_list{UINT32_C(0), UINT32_C(4)}, {})); - - delete ProcMod; } TEST(WasmEdgeProcessTest, SetTimeOut) { // Create the wasmedge_process module instance. - auto *ProcMod = - dynamic_cast(createModule()); - ASSERT_TRUE(ProcMod != nullptr); + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); // Get the function "wasmedge_process_set_timeout". auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_set_timeout"); @@ -295,15 +295,12 @@ TEST(WasmEdgeProcessTest, SetTimeOut) { DummyCallFrame, std::initializer_list{UINT32_C(100)}, {})); EXPECT_EQ(ProcMod->getEnv().TimeOut, 100U); - - delete ProcMod; } TEST(WasmEdgeProcessTest, Run) { // Create the wasmedge_process module instance. - auto *ProcMod = - dynamic_cast(createModule()); - ASSERT_TRUE(ProcMod != nullptr); + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -377,15 +374,12 @@ TEST(WasmEdgeProcessTest, Run) { std::string OutStr = "123456 test\n"; EXPECT_TRUE(std::equal(ProcMod->getEnv().StdOut.begin(), ProcMod->getEnv().StdOut.end(), OutStr.begin())); - - delete ProcMod; } TEST(WasmEdgeProcessTest, GetExitCode) { // Create the wasmedge_process module instance. - auto *ProcMod = - dynamic_cast(createModule()); - ASSERT_TRUE(ProcMod != nullptr); + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); // Get the function "wasmedge_process_get_exit_code". auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_get_exit_code"); @@ -399,15 +393,12 @@ TEST(WasmEdgeProcessTest, GetExitCode) { std::array RetVal; EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); EXPECT_EQ(RetVal[0].get(), 0); - - delete ProcMod; } TEST(WasmEdgeProcessTest, GetStdOut) { // Create the wasmedge_process module instance. - auto *ProcMod = - dynamic_cast(createModule()); - ASSERT_TRUE(ProcMod != nullptr); + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -469,15 +460,12 @@ TEST(WasmEdgeProcessTest, GetStdOut) { EXPECT_TRUE(std::equal(ProcMod->getEnv().StdOut.begin(), ProcMod->getEnv().StdOut.end(), MemInst.getPointer(0))); - - delete ProcMod; } TEST(WasmEdgeProcessTest, GetStdErr) { // Create the wasmedge_process module instance. - auto *ProcMod = - dynamic_cast(createModule()); - ASSERT_TRUE(ProcMod != nullptr); + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -538,15 +526,13 @@ TEST(WasmEdgeProcessTest, GetStdErr) { EXPECT_TRUE(std::equal(ProcMod->getEnv().StdOut.begin(), ProcMod->getEnv().StdOut.end(), MemInst.getPointer(0))); - - delete ProcMod; } TEST(WasmEdgeProcessTest, Module) { // Create the wasmedge_process module instance. - auto *ProcMod = - dynamic_cast(createModule()); - EXPECT_FALSE(ProcMod == nullptr); + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + EXPECT_EQ(ProcMod->getEnv().ExitCode, 0U); EXPECT_EQ(ProcMod->getFuncExportNum(), 11U); EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_set_prog_name"), @@ -564,7 +550,6 @@ TEST(WasmEdgeProcessTest, Module) { EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_get_stderr_len"), nullptr); EXPECT_NE(ProcMod->findFuncExports("wasmedge_process_get_stderr"), nullptr); - delete ProcMod; } GTEST_API_ int main(int argc, char **argv) { diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index cbed4455..77529201 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -13,7 +13,18 @@ using WasmEdge::Host::StableDiffusion::ErrNo; namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "../../../plugins/wasmedge_stablediffusion/" WASMEDGE_LIB_PREFIX @@ -21,10 +32,10 @@ WasmEdge::Runtime::Instance::ModuleInstance *createModule() { if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_stablediffusion"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_stablediffusion"sv)) { - return Module->create().release(); + return dynamicPointerCast(Module->create()); } } - return nullptr; + return {}; } } // namespace @@ -51,8 +62,8 @@ void writeFatPointer(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { // Create the stable diffusion module instance. - auto *SBMod = dynamic_cast(createModule()); - EXPECT_FALSE(SBMod == nullptr); + auto SBMod = createModule(); + ASSERT_TRUE(SBMod); EXPECT_EQ(SBMod->getFuncExportNum(), 4U); // Create the calling frame with memory instance. @@ -304,7 +315,6 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { EXPECT_GE(BytesWritten, 50); EXPECT_TRUE(std::filesystem::exists(OutputPathString2)); } - delete SBMod; } GTEST_API_ int main(int argc, char **argv) { diff --git a/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp b/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp index 8b53e4b3..8ce35675 100644 --- a/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp +++ b/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp @@ -10,11 +10,23 @@ #include #include #include +#include #include #include namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "../../../plugins/wasmedge_tensorflow/" WASMEDGE_LIB_PREFIX @@ -22,10 +34,11 @@ WasmEdge::Runtime::Instance::ModuleInstance *createModule() { if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_tensorflow"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_tensorflow"sv)) { - return Module->create().release(); + return dynamicPointerCast( + Module->create()); } } - return nullptr; + return {}; } } // namespace @@ -33,9 +46,9 @@ WasmEdge::Runtime::Instance::ModuleInstance *createModule() { TEST(WasmEdgeTensorflowTest, Module) { // Create the wasmedge_tensorflow module instance. - auto *TFMod = - dynamic_cast(createModule()); - EXPECT_FALSE(TFMod == nullptr); + auto TFMod = createModule(); + ASSERT_TRUE(TFMod); + EXPECT_EQ(TFMod->getFuncExportNum(), 11U); EXPECT_NE(TFMod->findFuncExports("create_session"), nullptr); EXPECT_NE(TFMod->findFuncExports("create_session_saved_model"), nullptr); @@ -48,7 +61,6 @@ TEST(WasmEdgeTensorflowTest, Module) { EXPECT_NE(TFMod->findFuncExports("append_output"), nullptr); EXPECT_NE(TFMod->findFuncExports("clear_input"), nullptr); EXPECT_NE(TFMod->findFuncExports("clear_output"), nullptr); - delete TFMod; } GTEST_API_ int main(int argc, char **argv) { diff --git a/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp b/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp index 61c49219..843e5fc5 100644 --- a/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp +++ b/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp @@ -10,11 +10,23 @@ #include #include #include +#include #include #include namespace { -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { + +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "../../../plugins/wasmedge_tensorflowlite/" WASMEDGE_LIB_PREFIX @@ -22,21 +34,21 @@ WasmEdge::Runtime::Instance::ModuleInstance *createModule() { if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_tensorflowlite"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_tensorflowlite"sv)) { - return Module->create().release(); + return dynamicPointerCast( + Module->create()); } } - return nullptr; + return {}; } + } // namespace // TODO: unit tests for every functions. TEST(WasmEdgeTensorflowLiteTest, Module) { // Create the wasmedge_tensorflowlite module instance. - auto *TFLiteMod = - dynamic_cast( - createModule()); - EXPECT_FALSE(TFLiteMod == nullptr); + auto TFLiteMod = createModule(); + ASSERT_TRUE(TFLiteMod); EXPECT_EQ(TFLiteMod->getFuncExportNum(), 7U); EXPECT_NE(TFLiteMod->findFuncExports("create_session"), nullptr); EXPECT_NE(TFLiteMod->findFuncExports("delete_session"), nullptr); @@ -45,7 +57,6 @@ TEST(WasmEdgeTensorflowLiteTest, Module) { EXPECT_NE(TFLiteMod->findFuncExports("get_tensor_len"), nullptr); EXPECT_NE(TFLiteMod->findFuncExports("get_tensor_data"), nullptr); EXPECT_NE(TFLiteMod->findFuncExports("append_input"), nullptr); - delete TFLiteMod; } GTEST_API_ int main(int argc, char **argv) { diff --git a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp index e22f725d..46b3cb78 100644 --- a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp +++ b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp @@ -11,23 +11,36 @@ #include #include #include +#include #include #include namespace { + WasmEdge::Runtime::CallingFrame DummyCallFrame(nullptr, nullptr); -WasmEdge::Runtime::Instance::ModuleInstance *createModule() { +template +inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { + static_assert(std::has_virtual_destructor_v); + T *P = dynamic_cast(R.get()); + if (P) { + R.release(); + } + return std::unique_ptr(P); +} + +std::unique_ptr createModule() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "../../../plugins/wasmedge_zlib/" WASMEDGE_LIB_PREFIX "wasmedgePluginWasmEdgeZlib" WASMEDGE_LIB_EXTENSION)); if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_zlib"sv)) { if (const auto *Module = Plugin->findModule("wasmedge_zlib"sv)) { - return Module->create().release(); + return dynamicPointerCast( + Module->create()); } } - return nullptr; + return {}; } } // namespace @@ -49,9 +62,8 @@ constexpr auto RandChar = []() -> char { }; TEST(WasmEdgeZlibTest, DeflateInflateCycle) { - auto *ZlibMod = - dynamic_cast(createModule()); - ASSERT_TRUE(ZlibMod != nullptr); + auto ZlibMod = createModule(); + ASSERT_TRUE(ZlibMod); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -250,9 +262,9 @@ TEST(WasmEdgeZlibTest, DeflateInflateCycle) { TEST(WasmEdgeZlibTest, Module) { // Create the wasmedge_zlib module instance. - auto *ZlibMod = - dynamic_cast(createModule()); - EXPECT_FALSE(ZlibMod == nullptr); + auto ZlibMod = createModule(); + ASSERT_TRUE(ZlibMod); + EXPECT_TRUE(ZlibMod->getEnv().ZStreamMap.empty()); EXPECT_EQ(ZlibMod->getFuncExportNum(), 76U); @@ -332,8 +344,6 @@ TEST(WasmEdgeZlibTest, Module) { EXPECT_NE(ZlibMod->findFuncExports("inflateCodesUsed"), nullptr); EXPECT_NE(ZlibMod->findFuncExports("inflateResetKeep"), nullptr); EXPECT_NE(ZlibMod->findFuncExports("deflateResetKeep"), nullptr); - - delete ZlibMod; } GTEST_API_ int main(int ArgC, char **ArgV) { From 4a2b6b1f4f653b3f1bb1d484c02dbaea63c5ed74 Mon Sep 17 00:00:00 2001 From: vincent Date: Mon, 29 Apr 2024 14:11:12 +0800 Subject: [PATCH 414/623] [WASI-NN] burn: implement wasi_nn_rust A wasi_nn plugin written in Rust, with the current backend using the burn.rs framework. Signed-off-by: vincent --- plugins/CMakeLists.txt | 4 + plugins/wasi_nn_rust/.gitignore | 1 + plugins/wasi_nn_rust/CMakeLists.txt | 19 ++ plugins/wasi_nn_rust/Cargo.toml | 17 ++ plugins/wasi_nn_rust/src/helper.rs | 11 + plugins/wasi_nn_rust/src/lib.rs | 431 ++++++++++++++++++++++++++++ 6 files changed, 483 insertions(+) create mode 100644 plugins/wasi_nn_rust/.gitignore create mode 100644 plugins/wasi_nn_rust/CMakeLists.txt create mode 100644 plugins/wasi_nn_rust/Cargo.toml create mode 100644 plugins/wasi_nn_rust/src/helper.rs create mode 100644 plugins/wasi_nn_rust/src/lib.rs diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index b7aab370..49ab84b1 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -84,6 +84,10 @@ if(WASMEDGE_PLUGIN_OPENCVMINI) endif() endif() +if(WASMEDGE_PLUGIN_WASI_NN_RUST_BACKEND) + add_subdirectory(wasi_nn_rust) +endif() + if(WASMEDGE_PLUGIN_ZLIB) add_subdirectory(wasmedge_zlib) endif() diff --git a/plugins/wasi_nn_rust/.gitignore b/plugins/wasi_nn_rust/.gitignore new file mode 100644 index 00000000..eb5a316c --- /dev/null +++ b/plugins/wasi_nn_rust/.gitignore @@ -0,0 +1 @@ +target diff --git a/plugins/wasi_nn_rust/CMakeLists.txt b/plugins/wasi_nn_rust/CMakeLists.txt new file mode 100644 index 00000000..b4b4c166 --- /dev/null +++ b/plugins/wasi_nn_rust/CMakeLists.txt @@ -0,0 +1,19 @@ +if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(CARGO_CMD cargo build) + set(TARGET_DIR "debug") +else() + set(CARGO_CMD cargo build --release) + set(TARGET_DIR "release") +endif() + +set(RS_SO ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR}/libwasmedgePluginWasiNN${CMAKE_SHARED_LIBRARY_SUFFIX}) + +set(WASMEDGE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/api) + +add_custom_target(wasi_nn_rust ALL + COMMAND WASMEDGE_LIB_DIR=${WASMEDGE_LIB_DIR} LD_LIBARAY_PATH=${WASMEDGE_LIB_DIR} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD} + COMMAND cp ${RS_SO} ${CMAKE_CURRENT_BINARY_DIR} + COMMAND rm -rf ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + DEPENDS wasmedge_shared +) diff --git a/plugins/wasi_nn_rust/Cargo.toml b/plugins/wasi_nn_rust/Cargo.toml new file mode 100644 index 00000000..f4855dc3 --- /dev/null +++ b/plugins/wasi_nn_rust/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "wasi_nn_rust" +version = "0.0.1" +edition = "2021" + +[lib] +name = "wasmedgePluginWasiNN" +path = "src/lib.rs" +crate-type = ["cdylib"] + +[dependencies] +squeezenet-burn = { git = "https://github.com/tracel-ai/models.git", features = ["weights_file"], default-features = false } +burn = { version = "0.13.2", features = ["ndarray", "wgpu"] } +wasmedge_plugin_sdk = { git = "https://github.com/second-state/wasmedge_plugin_rust_sdk.git" } +wasi-nn = { git = "https://github.com/CaptainVincent/wasi-nn.git", branch = "burn" } +lazy_static = "1.4.0" +bytemuck = "1.16.0" \ No newline at end of file diff --git a/plugins/wasi_nn_rust/src/helper.rs b/plugins/wasi_nn_rust/src/helper.rs new file mode 100644 index 00000000..81375c34 --- /dev/null +++ b/plugins/wasi_nn_rust/src/helper.rs @@ -0,0 +1,11 @@ +#[macro_export] +macro_rules! get_slice { + ($memory:expr, $ptr:expr, $length:expr, $ty:ty) => {{ + let raw_bytes = $memory + .data_pointer($ptr as usize, $length as usize) + .expect("Failed to get data pointer"); + bytemuck::cast_slice::(raw_bytes) + }}; +} + +pub use get_slice; diff --git a/plugins/wasi_nn_rust/src/lib.rs b/plugins/wasi_nn_rust/src/lib.rs new file mode 100644 index 00000000..a0fc1ff5 --- /dev/null +++ b/plugins/wasi_nn_rust/src/lib.rs @@ -0,0 +1,431 @@ +mod helper; + +pub enum ErrNo { + Success = 0, // No error occurred. + InvalidArgument = 1, // Caller module passed an invalid argument. + InvalidEncoding = 2, // Invalid encoding. + MissingMemory = 3, // Caller module is missing a memory export. + Busy = 4, // Device or resource busy. + RuntimeError = 5, // Runtime Error. + UnsupportedOperation = 6, // Unsupported Operation. + TooLarge = 7, // Too Large. + NotFound = 8, // Not Found. +} +mod wasi_nn { + use crate::helper::get_slice; + use crate::ErrNo; + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu}; + use burn::backend::NdArray; + use burn::prelude::Backend; + use burn::tensor::Tensor; + use lazy_static::lazy_static; + use squeezenet_burn::model::squeezenet1::Model; + use std::collections::HashMap; + use std::env; + use std::mem; + use std::sync::Mutex; + use wasi_nn::TensorType; + use wasmedge_plugin_sdk::{ + error::CoreError, + memory::Memory, + module::{PluginModule, SyncInstanceRef}, + types::{ValType, WasmVal}, + }; + + type NdArrayBackend = NdArray; + type WgpuBackend = Wgpu; + + enum GraphType { + /// The model is loaded to the NdArray backend + WithNdArrayBackend(Graph), + /// The model is loaded to the Wgpu backend + WithWgpuBackend(Graph), + } + + struct Graph { + model: Model, + } + + impl Graph { + /// Constructor + pub fn new(device: &B::Device, path: &str) -> Self { + Self { + model: Model::from_file(path, device), + } + } + } + + enum ExecutionContextType { + /// The model is loaded to the NdArray backend + WithNdArrayBackend(ExecutionContext), + /// The model is loaded to the Wgpu backend + WithWgpuBackend(ExecutionContext), + } + + const INPUT_DIM: usize = 4; + const OUTPUT_DIM: usize = 2; + struct ExecutionContext { + inputs: HashMap>, + outputs: Vec>, + } + + lazy_static! { + static ref GRAPH_HANDLE_MAP: Mutex> = Mutex::new(HashMap::new()); + static ref GRAPH_NAME_MAP: Mutex> = Mutex::new(HashMap::new()); + static ref CONTEXT_HANDLE_MAP: Mutex> = + Mutex::new(HashMap::new()); + } + + fn parse_opts() { + fn process_nn_preload(nn_preload: String) { + let parts: Vec<&str> = nn_preload.split(':').collect(); + if parts.len() < 4 { + panic!("[WASI_NN] Invalid nn-preload format. {:?} len < 4", parts); + } + + let graph_encoding = parts[1].to_string(); + if graph_encoding.to_lowercase() != "burn" { + panic!("[WASI_NN] Unsupported graph encoding. {:?}", graph_encoding); + } + + let file_path = parts[3].to_string(); + if let Ok(metadata) = fs::metadata(&file_path) { + if !metadata.is_file() { + panic!("[WASI_NN] File does not exist. {:?}", file_path); + } + } else { + panic!( + "[WASI_NN] Failed to retrieve metadata for file. {:?}", + file_path + ); + } + + let name = parts[0].to_string(); + let mut graph_map = GRAPH_HANDLE_MAP.lock().unwrap(); + let graph_handle = graph_map.len() as u32; + let mut name_map = GRAPH_NAME_MAP.lock().unwrap(); + name_map.insert(name.clone(), graph_handle); + let execution_target = parts[2].to_string().to_lowercase(); + if execution_target == "gpu" { + let device = Default::default(); + graph_map.insert( + graph_handle, + GraphType::WithWgpuBackend(Graph::new(&device, &file_path)), + ); + } else { + let device = Default::default(); + graph_map.insert( + graph_handle, + GraphType::WithNdArrayBackend(Graph::new(&device, &file_path)), + ); + }; + } + + unsafe { + if let Ok(nn_preload) = (*crate::nn_preload()).to_string() { + println!("nn-preload: {:?}", nn_preload); + process_nn_preload(nn_preload); + } else if let Ok(env_nn_preload) = env::var("WASMEDGE_WASINN_PRELOAD") { + process_nn_preload(env_nn_preload); + } + } + } + + pub fn create_module() -> PluginModule<()> { + fn load<'a>( + _inst: &'a mut SyncInstanceRef, + _memory: &'a mut Memory, + _data: &'a mut (), + _args: Vec, + ) -> Result, CoreError> { + Ok(vec![WasmVal::I32(ErrNo::UnsupportedOperation as i32)]) + } + + fn load_by_name<'a>( + _inst: &'a mut SyncInstanceRef, + memory: &'a mut Memory, + _data: &'a mut (), + args: Vec, + ) -> Result, CoreError> { + if let [WasmVal::I32(data_ptr), WasmVal::I32(data_len), WasmVal::I32(graph_handle_ptr)] = + &args[..] + { + let bytes = memory + .data_pointer(*data_ptr as usize, *data_len as usize) + .unwrap(); + let name = String::from_utf8_lossy(&bytes); + let name_map = GRAPH_NAME_MAP.lock().unwrap(); + if let Some(handle) = name_map.get(name.as_ref()) { + memory.write_data((*graph_handle_ptr as usize).into(), *handle); + Ok(vec![WasmVal::I32(ErrNo::Success as i32)]) + } else { + Ok(vec![WasmVal::I32(ErrNo::NotFound as i32)]) + } + } else { + Ok(vec![WasmVal::I32(ErrNo::InvalidArgument as i32)]) + } + } + + fn init_execution_context<'a>( + _inst: &'a mut SyncInstanceRef, + memory: &'a mut Memory, + _data: &'a mut (), + args: Vec, + ) -> Result, CoreError> { + if let [WasmVal::I32(graph_handle), WasmVal::I32(context_handle_ptr)] = &args[..] { + let mut context_map = CONTEXT_HANDLE_MAP.lock().unwrap(); + let context_handle = context_map.len() as u32; + let graph_map = GRAPH_HANDLE_MAP.lock().unwrap(); + let graph = graph_map + .get(&(*graph_handle as u32)) + .unwrap_or_else(|| unreachable!()); + match graph { + GraphType::WithNdArrayBackend(_) => { + context_map.insert( + context_handle, + ( + *graph_handle as u32, + ExecutionContextType::WithNdArrayBackend(ExecutionContext { + inputs: HashMap::new(), + outputs: vec![], + }), + ), + ); + } + GraphType::WithWgpuBackend(_) => { + context_map.insert( + context_handle, + ( + *graph_handle as u32, + ExecutionContextType::WithWgpuBackend(ExecutionContext { + inputs: HashMap::new(), + outputs: vec![], + }), + ), + ); + } + } + memory.write_data((*context_handle_ptr as usize).into(), context_handle); + Ok(vec![WasmVal::I32(ErrNo::Success as i32)]) + } else { + Ok(vec![WasmVal::I32(ErrNo::InvalidArgument as i32)]) + } + } + + fn set_input<'a>( + _inst: &'a mut SyncInstanceRef, + memory: &'a mut Memory, + _data: &'a mut (), + args: Vec, + ) -> Result, CoreError> { + #[derive(Debug)] + #[repr(C)] + struct WasiTensorData { + dimens_ptr: u32, + dimens_length: u32, + tensor_type: TensorType, + tensor_ptr: u32, + tensor_length: u32, + } + if let [WasmVal::I32(context_handle), WasmVal::I32(input_index), WasmVal::I32(input_tensor_ptr)] = + &args[..] + { + match memory.get_data::((*input_tensor_ptr as usize).into()) { + Some(input_tensor) => { + let raw_dimens = get_slice!( + memory, + input_tensor.dimens_ptr, + INPUT_DIM * mem::size_of::(), + u32 + ); + let dimens: [usize; 4] = raw_dimens + .iter() + .map(|&x| x as usize) + .collect::>() + .try_into() + .unwrap(); + + // FIXME: The type of f32 should be decided at runtime based on input_tensor.tensor_type. + let tensor = get_slice!( + memory, + input_tensor.tensor_ptr, + input_tensor.tensor_length, + f32 + ); + + let mut context_map = CONTEXT_HANDLE_MAP.lock().unwrap(); + let (_, context) = context_map + .get_mut(&(*context_handle as u32)) + .unwrap_or_else(|| unreachable!()); + + match context { + ExecutionContextType::WithNdArrayBackend(inner) => { + let device = Default::default(); + let tensor = + Tensor::::from_data(&tensor[..], &device) + .reshape(dimens); + + inner.inputs.insert(*input_index as u32, tensor); + } + ExecutionContextType::WithWgpuBackend(inner) => { + let device = Default::default(); + let tensor = + Tensor::::from_data(&tensor[..], &device) + .reshape(dimens); + inner.inputs.insert(*input_index as u32, tensor); + } + } + Ok(vec![WasmVal::I32(ErrNo::Success as i32)]) + } + None => Ok(vec![WasmVal::I32(ErrNo::MissingMemory as i32)]), + } + } else { + Ok(vec![WasmVal::I32(ErrNo::InvalidArgument as i32)]) + } + } + + fn compute<'a>( + _inst: &'a mut SyncInstanceRef, + _memory: &'a mut Memory, + _data: &'a mut (), + args: Vec, + ) -> Result, CoreError> { + if let [WasmVal::I32(context_handle)] = &args[..] { + let mut context_map = CONTEXT_HANDLE_MAP.lock().unwrap(); + let (graph_handle, context) = context_map + .get_mut(&(*context_handle as u32)) + .unwrap_or_else(|| unreachable!()); + + let graph_map = GRAPH_HANDLE_MAP.lock().unwrap(); + let graph = graph_map + .get(graph_handle) + .unwrap_or_else(|| unreachable!()); + + match context { + ExecutionContextType::WithNdArrayBackend(ctx_inner) => { + let GraphType::WithNdArrayBackend(graph_inner) = graph else { + unreachable!() + }; + let output = graph_inner + .model + .forward((*ctx_inner.inputs.get(&0).unwrap()).clone()); + ctx_inner.outputs.push(output); + } + ExecutionContextType::WithWgpuBackend(ctx_inner) => { + let GraphType::WithWgpuBackend(graph_inner) = graph else { + unreachable!() + }; + let output = graph_inner + .model + .forward((*ctx_inner.inputs.get(&0).unwrap()).clone()); + ctx_inner.outputs.push(output); + } + }; + Ok(vec![WasmVal::I32(ErrNo::Success as i32)]) + } else { + Ok(vec![WasmVal::I32(ErrNo::InvalidArgument as i32)]) + } + } + + fn get_output<'a>( + _inst: &'a mut SyncInstanceRef, + memory: &'a mut Memory, + _data: &'a mut (), + args: Vec, + ) -> Result, CoreError> { + if let [WasmVal::I32(context_handle), WasmVal::I32(output_index), WasmVal::I32(output_ptr), WasmVal::I32(output_max_size), WasmVal::I32(written_length)] = + &args[..] + { + let mut context_map = CONTEXT_HANDLE_MAP.lock().unwrap(); + let (_, context) = context_map + .get_mut(&(*context_handle as u32)) + .unwrap_or_else(|| unreachable!()); + let raw_output = match context { + ExecutionContextType::WithNdArrayBackend(ctx_inner) => { + ctx_inner.outputs[*output_index as usize] + .clone() + .into_data() + .value + } + ExecutionContextType::WithWgpuBackend(ctx_inner) => { + ctx_inner.outputs[*output_index as usize] + .clone() + .into_data() + .value + } + }; + let output: &[u8] = bytemuck::cast_slice(&raw_output); + if output.len() > *output_max_size as usize { + return Ok(vec![WasmVal::I32(ErrNo::TooLarge as i32)]); + } + memory.write_bytes(output, *output_ptr as u32).unwrap(); + memory.write_data((*written_length as usize).into(), output.len()); + Ok(vec![WasmVal::I32(ErrNo::Success as i32)]) + } else { + Ok(vec![WasmVal::I32(ErrNo::InvalidArgument as i32)]) + } + } + + parse_opts(); + + let mut module = PluginModule::create("wasi_ephemeral_nn", ()).unwrap(); + module + .add_func("load", (vec![ValType::I32; 5], vec![ValType::I32]), load) + .unwrap(); + module + .add_func( + "load_by_name", + (vec![ValType::I32; 3], vec![ValType::I32]), + load_by_name, + ) + .unwrap(); + module + .add_func( + "init_execution_context", + (vec![ValType::I32; 2], vec![ValType::I32]), + init_execution_context, + ) + .unwrap(); + module + .add_func( + "set_input", + (vec![ValType::I32; 3], vec![ValType::I32]), + set_input, + ) + .unwrap(); + module + .add_func( + "compute", + (vec![ValType::I32; 1], vec![ValType::I32]), + compute, + ) + .unwrap(); + module + .add_func( + "get_output", + (vec![ValType::I32; 5], vec![ValType::I32]), + get_output, + ) + .unwrap(); + module + } +} + +use wasi_nn::create_module; +use wasmedge_plugin_sdk::plugin::{option_string, register_plugin, OptionString}; +register_plugin!( + plugin_name = "wasi_nn", + plugin_description = "burn framework adapter as wasi-nn plugin", + version = (0,0,0,1), + modules = [ + {"wasi_nn", "wasinn with burn backend module", create_module} + ], + options = [ + { + "nn-preload", + "Allow preload models from wasinn plugin. Each NN model can be specified as --nn-preload `COMMAND`.", + OptionString, + option_string!("none") + } + ] +); From 4a67d2a9ee5312ef8513a2b61206ea6b1f08ced5 Mon Sep 17 00:00:00 2001 From: vincent Date: Sun, 2 Jun 2024 22:57:24 +0800 Subject: [PATCH 415/623] [WASI-NN] burn: switch squeezenet repo to use prebuilt model feature Signed-off-by: vincent --- plugins/wasi_nn_rust/Cargo.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn_rust/Cargo.toml b/plugins/wasi_nn_rust/Cargo.toml index f4855dc3..2cf2ef01 100644 --- a/plugins/wasi_nn_rust/Cargo.toml +++ b/plugins/wasi_nn_rust/Cargo.toml @@ -9,9 +9,11 @@ path = "src/lib.rs" crate-type = ["cdylib"] [dependencies] -squeezenet-burn = { git = "https://github.com/tracel-ai/models.git", features = ["weights_file"], default-features = false } burn = { version = "0.13.2", features = ["ndarray", "wgpu"] } +squeezenet-burn = { git = "https://github.com/CaptainVincent/models.git", branch = "prebuilt-feature", features = [ + "weights_file", +], default-features = false } wasmedge_plugin_sdk = { git = "https://github.com/second-state/wasmedge_plugin_rust_sdk.git" } wasi-nn = { git = "https://github.com/CaptainVincent/wasi-nn.git", branch = "burn" } lazy_static = "1.4.0" -bytemuck = "1.16.0" \ No newline at end of file +bytemuck = "1.16.0" From 307a4eb39901bd2e13655c739d08ae104720afc6 Mon Sep 17 00:00:00 2001 From: vincent Date: Sun, 2 Jun 2024 23:56:09 +0800 Subject: [PATCH 416/623] [WASI-NN] burn: refactor and add support for the whisper model Use -DWASMEDGE_PLUGIN_WASI_NN_RUST_MODEL=squeezenet or whisper for switch plugin model Signed-off-by: vincent --- plugins/CMakeLists.txt | 2 +- plugins/wasi_nn_rust/CMakeLists.txt | 12 +- plugins/wasi_nn_rust/Cargo.toml | 25 ++- plugins/wasi_nn_rust/src/lib.rs | 143 ++++++------------ plugins/wasi_nn_rust/src/models/mod.rs | 4 + plugins/wasi_nn_rust/src/models/squeezenet.rs | 45 ++++++ plugins/wasi_nn_rust/src/models/whisper.rs | 96 ++++++++++++ 7 files changed, 224 insertions(+), 103 deletions(-) create mode 100644 plugins/wasi_nn_rust/src/models/mod.rs create mode 100644 plugins/wasi_nn_rust/src/models/squeezenet.rs create mode 100644 plugins/wasi_nn_rust/src/models/whisper.rs diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 49ab84b1..8aabe099 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -84,7 +84,7 @@ if(WASMEDGE_PLUGIN_OPENCVMINI) endif() endif() -if(WASMEDGE_PLUGIN_WASI_NN_RUST_BACKEND) +if(WASMEDGE_PLUGIN_WASI_NN_RUST_MODEL) add_subdirectory(wasi_nn_rust) endif() diff --git a/plugins/wasi_nn_rust/CMakeLists.txt b/plugins/wasi_nn_rust/CMakeLists.txt index b4b4c166..e85565d1 100644 --- a/plugins/wasi_nn_rust/CMakeLists.txt +++ b/plugins/wasi_nn_rust/CMakeLists.txt @@ -1,3 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + if(CMAKE_BUILD_TYPE STREQUAL "Debug") set(CARGO_CMD cargo build) set(TARGET_DIR "debug") @@ -6,14 +9,17 @@ else() set(TARGET_DIR "release") endif() +message(STATUS "WasmEdge Wasi-NN Rust plugin model: ${WASMEDGE_PLUGIN_WASI_NN_RUST_MODEL}") +set(CARGO_FEATURES "--features=${WASMEDGE_PLUGIN_WASI_NN_RUST_MODEL}") + set(RS_SO ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR}/libwasmedgePluginWasiNN${CMAKE_SHARED_LIBRARY_SUFFIX}) set(WASMEDGE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/api) add_custom_target(wasi_nn_rust ALL - COMMAND WASMEDGE_LIB_DIR=${WASMEDGE_LIB_DIR} LD_LIBARAY_PATH=${WASMEDGE_LIB_DIR} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD} - COMMAND cp ${RS_SO} ${CMAKE_CURRENT_BINARY_DIR} - COMMAND rm -rf ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR} + COMMAND WASMEDGE_LIB_DIR=${WASMEDGE_LIB_DIR} LD_LIBARAY_PATH=${WASMEDGE_LIB_DIR} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD} ${CARGO_FEATURES} + COMMAND ${CMAKE_COMMAND} -E copy ${RS_SO} ${CMAKE_CURRENT_BINARY_DIR} + COMMAND ${CMAKE_COMMAND} -E remove_directory ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS wasmedge_shared ) diff --git a/plugins/wasi_nn_rust/Cargo.toml b/plugins/wasi_nn_rust/Cargo.toml index 2cf2ef01..005a0bc0 100644 --- a/plugins/wasi_nn_rust/Cargo.toml +++ b/plugins/wasi_nn_rust/Cargo.toml @@ -8,12 +8,31 @@ name = "wasmedgePluginWasiNN" path = "src/lib.rs" crate-type = ["cdylib"] +[features] +default = [] +squeezenet = ["squeezenet-burn"] +whisper = ["whisper-burn", "strum", "strum_macros"] + +[dependencies.squeezenet-burn] +package = "squeezenet-burn" +branch = "prebuilt-feature" +git = "https://github.com/CaptainVincent/models.git" +features = ["weights_file"] +default-features = false +optional = true + +[dependencies.whisper-burn] +package = "whisper" +branch = "dev" +git = "https://github.com/CaptainVincent/whisper-burn.git" +optional = true + [dependencies] burn = { version = "0.13.2", features = ["ndarray", "wgpu"] } -squeezenet-burn = { git = "https://github.com/CaptainVincent/models.git", branch = "prebuilt-feature", features = [ - "weights_file", -], default-features = false } wasmedge_plugin_sdk = { git = "https://github.com/second-state/wasmedge_plugin_rust_sdk.git" } wasi-nn = { git = "https://github.com/CaptainVincent/wasi-nn.git", branch = "burn" } lazy_static = "1.4.0" bytemuck = "1.16.0" +cfg-if = "1.0.0" +strum = { version = "0.25.0", optional = true } +strum_macros = { version = "0.25.0", optional = true } diff --git a/plugins/wasi_nn_rust/src/lib.rs b/plugins/wasi_nn_rust/src/lib.rs index a0fc1ff5..3fbbf9af 100644 --- a/plugins/wasi_nn_rust/src/lib.rs +++ b/plugins/wasi_nn_rust/src/lib.rs @@ -1,4 +1,5 @@ mod helper; +mod models; pub enum ErrNo { Success = 0, // No error occurred. @@ -13,17 +14,20 @@ pub enum ErrNo { } mod wasi_nn { use crate::helper::get_slice; + #[cfg(feature = "squeezenet")] + use crate::models::squeezenet::*; + #[cfg(feature = "whisper")] + use crate::models::whisper::*; use crate::ErrNo; - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu}; + use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; use burn::backend::NdArray; - use burn::prelude::Backend; - use burn::tensor::Tensor; use lazy_static::lazy_static; - use squeezenet_burn::model::squeezenet1::Model; use std::collections::HashMap; use std::env; use std::mem; + use std::process; use std::sync::Mutex; + use wasi_nn::TensorType; use wasmedge_plugin_sdk::{ error::CoreError, @@ -35,69 +39,40 @@ mod wasi_nn { type NdArrayBackend = NdArray; type WgpuBackend = Wgpu; - enum GraphType { + pub enum Graph { /// The model is loaded to the NdArray backend - WithNdArrayBackend(Graph), + WithNdArrayBackend(GraphInner), /// The model is loaded to the Wgpu backend - WithWgpuBackend(Graph), - } - - struct Graph { - model: Model, + WithWgpuBackend(GraphInner), } - impl Graph { - /// Constructor - pub fn new(device: &B::Device, path: &str) -> Self { - Self { - model: Model::from_file(path, device), - } - } - } - - enum ExecutionContextType { + enum ExecutionContext { /// The model is loaded to the NdArray backend - WithNdArrayBackend(ExecutionContext), + WithNdArrayBackend(ContextInner), /// The model is loaded to the Wgpu backend - WithWgpuBackend(ExecutionContext), - } - - const INPUT_DIM: usize = 4; - const OUTPUT_DIM: usize = 2; - struct ExecutionContext { - inputs: HashMap>, - outputs: Vec>, + WithWgpuBackend(ContextInner), } lazy_static! { - static ref GRAPH_HANDLE_MAP: Mutex> = Mutex::new(HashMap::new()); + static ref GRAPH_HANDLE_MAP: Mutex> = Mutex::new(HashMap::new()); static ref GRAPH_NAME_MAP: Mutex> = Mutex::new(HashMap::new()); - static ref CONTEXT_HANDLE_MAP: Mutex> = + static ref CONTEXT_HANDLE_MAP: Mutex> = Mutex::new(HashMap::new()); } fn parse_opts() { fn process_nn_preload(nn_preload: String) { let parts: Vec<&str> = nn_preload.split(':').collect(); + if parts.len() < 4 { - panic!("[WASI_NN] Invalid nn-preload format. {:?} len < 4", parts); + eprintln!("[WASI_NN] Invalid nn-preload format. {:?} len < 4", parts); + process::exit(1); } let graph_encoding = parts[1].to_string(); if graph_encoding.to_lowercase() != "burn" { - panic!("[WASI_NN] Unsupported graph encoding. {:?}", graph_encoding); - } - - let file_path = parts[3].to_string(); - if let Ok(metadata) = fs::metadata(&file_path) { - if !metadata.is_file() { - panic!("[WASI_NN] File does not exist. {:?}", file_path); - } - } else { - panic!( - "[WASI_NN] Failed to retrieve metadata for file. {:?}", - file_path - ); + eprintln!("[WASI_NN] Unsupported graph encoding. {:?}", graph_encoding); + process::exit(1); } let name = parts[0].to_string(); @@ -105,25 +80,24 @@ mod wasi_nn { let graph_handle = graph_map.len() as u32; let mut name_map = GRAPH_NAME_MAP.lock().unwrap(); name_map.insert(name.clone(), graph_handle); - let execution_target = parts[2].to_string().to_lowercase(); - if execution_target == "gpu" { - let device = Default::default(); + let target = parts[2].to_string().to_lowercase(); + if target == "gpu" { + let device = WgpuDevice::BestAvailable; graph_map.insert( graph_handle, - GraphType::WithWgpuBackend(Graph::new(&device, &file_path)), + Graph::WithWgpuBackend(GraphInner::create(parts[3..].to_vec(), &device)), ); } else { let device = Default::default(); graph_map.insert( graph_handle, - GraphType::WithNdArrayBackend(Graph::new(&device, &file_path)), + Graph::WithNdArrayBackend(GraphInner::create(parts[3..].to_vec(), &device)), ); }; } unsafe { if let Ok(nn_preload) = (*crate::nn_preload()).to_string() { - println!("nn-preload: {:?}", nn_preload); process_nn_preload(nn_preload); } else if let Ok(env_nn_preload) = env::var("WASMEDGE_WASINN_PRELOAD") { process_nn_preload(env_nn_preload); @@ -180,27 +154,21 @@ mod wasi_nn { .get(&(*graph_handle as u32)) .unwrap_or_else(|| unreachable!()); match graph { - GraphType::WithNdArrayBackend(_) => { + Graph::WithNdArrayBackend(_) => { context_map.insert( context_handle, ( *graph_handle as u32, - ExecutionContextType::WithNdArrayBackend(ExecutionContext { - inputs: HashMap::new(), - outputs: vec![], - }), + ExecutionContext::WithNdArrayBackend(ContextInner::new()), ), ); } - GraphType::WithWgpuBackend(_) => { + Graph::WithWgpuBackend(_) => { context_map.insert( context_handle, ( *graph_handle as u32, - ExecutionContextType::WithWgpuBackend(ExecutionContext { - inputs: HashMap::new(), - outputs: vec![], - }), + ExecutionContext::WithWgpuBackend(ContextInner::new()), ), ); } @@ -238,7 +206,7 @@ mod wasi_nn { INPUT_DIM * mem::size_of::(), u32 ); - let dimens: [usize; 4] = raw_dimens + let dimens: [usize; INPUT_DIM] = raw_dimens .iter() .map(|&x| x as usize) .collect::>() @@ -259,20 +227,11 @@ mod wasi_nn { .unwrap_or_else(|| unreachable!()); match context { - ExecutionContextType::WithNdArrayBackend(inner) => { - let device = Default::default(); - let tensor = - Tensor::::from_data(&tensor[..], &device) - .reshape(dimens); - - inner.inputs.insert(*input_index as u32, tensor); + ExecutionContext::WithNdArrayBackend(inner) => { + inner.set_input(*input_index as u32, tensor, dimens); } - ExecutionContextType::WithWgpuBackend(inner) => { - let device = Default::default(); - let tensor = - Tensor::::from_data(&tensor[..], &device) - .reshape(dimens); - inner.inputs.insert(*input_index as u32, tensor); + ExecutionContext::WithWgpuBackend(inner) => { + inner.set_input(*input_index as u32, tensor, dimens); } } Ok(vec![WasmVal::I32(ErrNo::Success as i32)]) @@ -302,22 +261,20 @@ mod wasi_nn { .unwrap_or_else(|| unreachable!()); match context { - ExecutionContextType::WithNdArrayBackend(ctx_inner) => { - let GraphType::WithNdArrayBackend(graph_inner) = graph else { + ExecutionContext::WithNdArrayBackend(ctx_inner) => { + let Graph::WithNdArrayBackend(graph_inner) = graph else { unreachable!() }; - let output = graph_inner - .model - .forward((*ctx_inner.inputs.get(&0).unwrap()).clone()); + let output = + graph_inner.compute((*ctx_inner.inputs.get(&0).unwrap()).clone()); ctx_inner.outputs.push(output); } - ExecutionContextType::WithWgpuBackend(ctx_inner) => { - let GraphType::WithWgpuBackend(graph_inner) = graph else { + ExecutionContext::WithWgpuBackend(ctx_inner) => { + let Graph::WithWgpuBackend(graph_inner) = graph else { unreachable!() }; - let output = graph_inner - .model - .forward((*ctx_inner.inputs.get(&0).unwrap()).clone()); + let output = + graph_inner.compute((*ctx_inner.inputs.get(&0).unwrap()).clone()); ctx_inner.outputs.push(output); } }; @@ -341,17 +298,11 @@ mod wasi_nn { .get_mut(&(*context_handle as u32)) .unwrap_or_else(|| unreachable!()); let raw_output = match context { - ExecutionContextType::WithNdArrayBackend(ctx_inner) => { - ctx_inner.outputs[*output_index as usize] - .clone() - .into_data() - .value + ExecutionContext::WithNdArrayBackend(ctx_inner) => { + ctx_inner.get_output(*output_index as usize) } - ExecutionContextType::WithWgpuBackend(ctx_inner) => { - ctx_inner.outputs[*output_index as usize] - .clone() - .into_data() - .value + ExecutionContext::WithWgpuBackend(ctx_inner) => { + ctx_inner.get_output(*output_index as usize) } }; let output: &[u8] = bytemuck::cast_slice(&raw_output); diff --git a/plugins/wasi_nn_rust/src/models/mod.rs b/plugins/wasi_nn_rust/src/models/mod.rs new file mode 100644 index 00000000..f7f7aadf --- /dev/null +++ b/plugins/wasi_nn_rust/src/models/mod.rs @@ -0,0 +1,4 @@ +#[cfg(feature = "squeezenet")] +pub mod squeezenet; +#[cfg(feature = "whisper")] +pub mod whisper; diff --git a/plugins/wasi_nn_rust/src/models/squeezenet.rs b/plugins/wasi_nn_rust/src/models/squeezenet.rs new file mode 100644 index 00000000..c4c7d0d1 --- /dev/null +++ b/plugins/wasi_nn_rust/src/models/squeezenet.rs @@ -0,0 +1,45 @@ +use burn::tensor::backend::Backend; +use burn::tensor::Tensor; +use squeezenet_burn::model::squeezenet1::Model; +use std::collections::HashMap; + +pub struct GraphInner { + pub model: Model, +} + +impl GraphInner { + pub fn create(args: Vec<&str>, device: &B::Device) -> Self { + let weights_path = args[0]; + Self { + model: Model::from_file(weights_path, device), + } + } + pub fn compute(&self, input: Tensor) -> Tensor { + self.model.forward(input) + } +} + +pub const INPUT_DIM: usize = 4; +pub const OUTPUT_DIM: usize = 2; + +pub struct ContextInner { + pub inputs: HashMap>, + pub outputs: Vec>, +} + +impl ContextInner { + pub fn new() -> Self { + Self { + inputs: HashMap::new(), + outputs: Vec::new(), + } + } + pub fn set_input(&mut self, key: u32, input: &[B::FloatElem], dimens: [usize; INPUT_DIM]) { + let device = Default::default(); + let tensor = Tensor::::from_data(&*input, &device).reshape(dimens); + self.inputs.insert(key, tensor); + } + pub fn get_output(&mut self, key: usize) -> Vec<::FloatElem> { + self.outputs[key].clone().into_data().value + } +} diff --git a/plugins/wasi_nn_rust/src/models/whisper.rs b/plugins/wasi_nn_rust/src/models/whisper.rs new file mode 100644 index 00000000..987f77ed --- /dev/null +++ b/plugins/wasi_nn_rust/src/models/whisper.rs @@ -0,0 +1,96 @@ +use burn::config::Config; +use burn::module::Module; +use burn::record::{DefaultRecorder, Recorder}; +use burn::tensor::backend::Backend; +use std::collections::HashMap; +use std::marker::PhantomData; +use std::process; +use strum::IntoEnumIterator; +use whisper_burn::model::Whisper as Model; +use whisper_burn::model::WhisperConfig as ModelConfig; +use whisper_burn::token::{Gpt2Tokenizer, Language}; +use whisper_burn::transcribe::waveform_to_text; + +pub struct GraphInner { + pub model: Model, + pub metadata: Vec, +} + +impl GraphInner { + pub fn create(args: Vec<&str>, device: &B::Device) -> Self { + if args.len() < 4 { + eprintln!( + "[WASI_NN] Invalid nn-preload model format. {:?} len < 4", + args + ); + process::exit(1); + } + let weights_path = args[0]; + let config_path = args[1]; + let config = match ModelConfig::load(config_path) { + Ok(config) => config, + Err(e) => { + eprintln!("Failed to load whisper config: {}", e); + process::exit(1); + } + }; + let recorder = DefaultRecorder::new().load(weights_path.into(), device); + let model = recorder + .map(|record| config.init(device).load_record(record)) + .unwrap(); + Self { + model: model, + metadata: args[2..].iter().map(|&s| s.to_string()).collect(), + } + } + pub fn compute(&self, input: Vec) -> Vec { + let tokenizer_path = &self.metadata[0].to_string(); + let lang_str = &self.metadata[1].to_string(); + let lang = match Language::iter().find(|lang| lang.as_str() == lang_str) { + Some(lang) => lang, + None => { + eprintln!("Invalid language abbreviation: {}", lang_str); + process::exit(1); + } + }; + let bpe = match Gpt2Tokenizer::new_with_path(tokenizer_path) { + Ok(bpe) => bpe, + Err(e) => { + eprintln!("Failed to load tokenizer: {}", e); + process::exit(1); + } + }; + let (text, _) = match waveform_to_text(&self.model, &bpe, lang, input, 16000) { + Ok((text, tokens)) => (text, tokens), + Err(e) => { + eprintln!("Error during transcription: {}", e); + process::exit(1); + } + }; + return text.into_bytes(); + } +} + +pub const INPUT_DIM: usize = 2; + +pub struct ContextInner { + pub inputs: HashMap>, + pub outputs: Vec>, + _marker: PhantomData, +} + +impl ContextInner { + pub fn new() -> Self { + Self { + inputs: HashMap::new(), + outputs: Vec::new(), + _marker: PhantomData, + } + } + pub fn set_input(&mut self, key: u32, input: &[f32], _: [usize; INPUT_DIM]) { + self.inputs.insert(key, input.to_vec()); + } + pub fn get_output(&mut self, key: usize) -> &Vec { + &self.outputs[key] + } +} From 4c22e53eab5f160d4ee1867536613991753e6619 Mon Sep 17 00:00:00 2001 From: vincent Date: Mon, 22 Jul 2024 15:12:36 +0800 Subject: [PATCH 417/623] [WASI-NN] burn: replace dependencies Signed-off-by: vincent --- plugins/wasi_nn_rust/Cargo.toml | 8 ++++---- plugins/wasi_nn_rust/src/lib.rs | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/plugins/wasi_nn_rust/Cargo.toml b/plugins/wasi_nn_rust/Cargo.toml index 005a0bc0..def15253 100644 --- a/plugins/wasi_nn_rust/Cargo.toml +++ b/plugins/wasi_nn_rust/Cargo.toml @@ -16,7 +16,7 @@ whisper = ["whisper-burn", "strum", "strum_macros"] [dependencies.squeezenet-burn] package = "squeezenet-burn" branch = "prebuilt-feature" -git = "https://github.com/CaptainVincent/models.git" +git = "https://github.com/second-state/burn-rs-models.git" features = ["weights_file"] default-features = false optional = true @@ -24,13 +24,13 @@ optional = true [dependencies.whisper-burn] package = "whisper" branch = "dev" -git = "https://github.com/CaptainVincent/whisper-burn.git" +git = "https://github.com/second-state/burn-rs-whisper.git" optional = true [dependencies] burn = { version = "0.13.2", features = ["ndarray", "wgpu"] } -wasmedge_plugin_sdk = { git = "https://github.com/second-state/wasmedge_plugin_rust_sdk.git" } -wasi-nn = { git = "https://github.com/CaptainVincent/wasi-nn.git", branch = "burn" } +wasmedge_plugin_sdk = { git = "https://github.com/second-state/wasmedge_plugin_rust_sdk.git", features = ["standalone"] } +wasmedge-wasi-nn = "0.8.0" lazy_static = "1.4.0" bytemuck = "1.16.0" cfg-if = "1.0.0" diff --git a/plugins/wasi_nn_rust/src/lib.rs b/plugins/wasi_nn_rust/src/lib.rs index 3fbbf9af..562f8304 100644 --- a/plugins/wasi_nn_rust/src/lib.rs +++ b/plugins/wasi_nn_rust/src/lib.rs @@ -28,7 +28,7 @@ mod wasi_nn { use std::process; use std::sync::Mutex; - use wasi_nn::TensorType; + use wasmedge_wasi_nn::TensorType; use wasmedge_plugin_sdk::{ error::CoreError, memory::Memory, From cc654ebfd4bf7eace038428ac5da8918ea33f822 Mon Sep 17 00:00:00 2001 From: "Shen-Ta Hsieh(BestSteve)" Date: Mon, 2 Sep 2024 21:24:01 +0800 Subject: [PATCH 418/623] [Test] Fix undefined variable in newer gcc (#3714) Signed-off-by: Shen-Ta Hsieh --- test/plugins/unittest/CMakeLists.txt | 1 + test/plugins/unittest/testplugin.cpp | 6 +++--- test/plugins/unittest/testplugin.h | 6 +++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/test/plugins/unittest/CMakeLists.txt b/test/plugins/unittest/CMakeLists.txt index 8a0e2ad5..979cc968 100644 --- a/test/plugins/unittest/CMakeLists.txt +++ b/test/plugins/unittest/CMakeLists.txt @@ -69,6 +69,7 @@ target_include_directories(wasmedgePluginUnittestsCPP target_link_libraries(wasmedgePluginUnittestsCPP PRIVATE + wasmedgePluginTestModuleCPP ${GTEST_BOTH_LIBRARIES} ) diff --git a/test/plugins/unittest/testplugin.cpp b/test/plugins/unittest/testplugin.cpp index 589df415..e54fba13 100644 --- a/test/plugins/unittest/testplugin.cpp +++ b/test/plugins/unittest/testplugin.cpp @@ -11,15 +11,15 @@ namespace Host { using namespace std::literals::string_view_literals; -PO::List +WASMEDGE_EXPORT PO::List WasmEdgePluginTestEnv::CmdArgs(PO::Description("Test for args."sv), PO::MetaVar("ARG"sv)); -PO::Option +WASMEDGE_EXPORT PO::Option WasmEdgePluginTestEnv::CmdName(PO::Description("Test for input name."sv), PO::DefaultValue(std::string(""))); -PO::Option +WASMEDGE_EXPORT PO::Option WasmEdgePluginTestEnv::CmdOpt(PO::Description("Test for option."sv)); namespace { diff --git a/test/plugins/unittest/testplugin.h b/test/plugins/unittest/testplugin.h index 82376b3e..4418ed95 100644 --- a/test/plugins/unittest/testplugin.h +++ b/test/plugins/unittest/testplugin.h @@ -18,9 +18,9 @@ class WasmEdgePluginTestEnv { public: WasmEdgePluginTestEnv() noexcept = default; - static PO::List CmdArgs; - static PO::Option CmdName; - static PO::Option CmdOpt; + WASMEDGE_EXPORT static PO::List CmdArgs; + WASMEDGE_EXPORT static PO::Option CmdName; + WASMEDGE_EXPORT static PO::Option CmdOpt; }; template From 7c6907fc063357342c7ffc1bd079ff3a53e7062d Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 2 Sep 2024 22:09:46 +0800 Subject: [PATCH 419/623] [WASI-NN] ggml: bump to b3651 (#3719) Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index a40f3b80..8006f715 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -68,7 +68,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_Declare( llama GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3613 + GIT_TAG b3651 GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(llama) From 18c35bc9e4e1f57c9d8d79024324f73661493806 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 2 Sep 2024 18:17:39 +0800 Subject: [PATCH 420/623] [Plugin] Rename the option name for WASI-NN burn.rs backend. Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 4 ++-- plugins/{wasi_nn_rust => wasi_nn_burnrs}/.gitignore | 0 plugins/{wasi_nn_rust => wasi_nn_burnrs}/CMakeLists.txt | 6 +++--- plugins/{wasi_nn_rust => wasi_nn_burnrs}/Cargo.toml | 2 +- plugins/{wasi_nn_rust => wasi_nn_burnrs}/src/helper.rs | 0 plugins/{wasi_nn_rust => wasi_nn_burnrs}/src/lib.rs | 0 plugins/{wasi_nn_rust => wasi_nn_burnrs}/src/models/mod.rs | 0 .../src/models/squeezenet.rs | 0 .../{wasi_nn_rust => wasi_nn_burnrs}/src/models/whisper.rs | 0 9 files changed, 6 insertions(+), 6 deletions(-) rename plugins/{wasi_nn_rust => wasi_nn_burnrs}/.gitignore (100%) rename plugins/{wasi_nn_rust => wasi_nn_burnrs}/CMakeLists.txt (79%) rename plugins/{wasi_nn_rust => wasi_nn_burnrs}/Cargo.toml (97%) rename plugins/{wasi_nn_rust => wasi_nn_burnrs}/src/helper.rs (100%) rename plugins/{wasi_nn_rust => wasi_nn_burnrs}/src/lib.rs (100%) rename plugins/{wasi_nn_rust => wasi_nn_burnrs}/src/models/mod.rs (100%) rename plugins/{wasi_nn_rust => wasi_nn_burnrs}/src/models/squeezenet.rs (100%) rename plugins/{wasi_nn_rust => wasi_nn_burnrs}/src/models/whisper.rs (100%) diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 8aabe099..4f14f1dc 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -84,8 +84,8 @@ if(WASMEDGE_PLUGIN_OPENCVMINI) endif() endif() -if(WASMEDGE_PLUGIN_WASI_NN_RUST_MODEL) - add_subdirectory(wasi_nn_rust) +if(WASMEDGE_PLUGIN_WASI_NN_BURNRS_MODEL) + add_subdirectory(wasi_nn_burnrs) endif() if(WASMEDGE_PLUGIN_ZLIB) diff --git a/plugins/wasi_nn_rust/.gitignore b/plugins/wasi_nn_burnrs/.gitignore similarity index 100% rename from plugins/wasi_nn_rust/.gitignore rename to plugins/wasi_nn_burnrs/.gitignore diff --git a/plugins/wasi_nn_rust/CMakeLists.txt b/plugins/wasi_nn_burnrs/CMakeLists.txt similarity index 79% rename from plugins/wasi_nn_rust/CMakeLists.txt rename to plugins/wasi_nn_burnrs/CMakeLists.txt index e85565d1..f7967a46 100644 --- a/plugins/wasi_nn_rust/CMakeLists.txt +++ b/plugins/wasi_nn_burnrs/CMakeLists.txt @@ -9,14 +9,14 @@ else() set(TARGET_DIR "release") endif() -message(STATUS "WasmEdge Wasi-NN Rust plugin model: ${WASMEDGE_PLUGIN_WASI_NN_RUST_MODEL}") -set(CARGO_FEATURES "--features=${WASMEDGE_PLUGIN_WASI_NN_RUST_MODEL}") +message(STATUS "WasmEdge WASI-NN Burn.rs backend plugin model: ${WASMEDGE_PLUGIN_WASI_NN_BURNRS_MODEL}") +set(CARGO_FEATURES "--features=${WASMEDGE_PLUGIN_WASI_NN_BURNRS_MODEL}") set(RS_SO ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR}/libwasmedgePluginWasiNN${CMAKE_SHARED_LIBRARY_SUFFIX}) set(WASMEDGE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/api) -add_custom_target(wasi_nn_rust ALL +add_custom_target(wasi_nn_burnrs ALL COMMAND WASMEDGE_LIB_DIR=${WASMEDGE_LIB_DIR} LD_LIBARAY_PATH=${WASMEDGE_LIB_DIR} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD} ${CARGO_FEATURES} COMMAND ${CMAKE_COMMAND} -E copy ${RS_SO} ${CMAKE_CURRENT_BINARY_DIR} COMMAND ${CMAKE_COMMAND} -E remove_directory ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR} diff --git a/plugins/wasi_nn_rust/Cargo.toml b/plugins/wasi_nn_burnrs/Cargo.toml similarity index 97% rename from plugins/wasi_nn_rust/Cargo.toml rename to plugins/wasi_nn_burnrs/Cargo.toml index def15253..6a71a217 100644 --- a/plugins/wasi_nn_rust/Cargo.toml +++ b/plugins/wasi_nn_burnrs/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "wasi_nn_rust" +name = "wasi_nn_burnrs" version = "0.0.1" edition = "2021" diff --git a/plugins/wasi_nn_rust/src/helper.rs b/plugins/wasi_nn_burnrs/src/helper.rs similarity index 100% rename from plugins/wasi_nn_rust/src/helper.rs rename to plugins/wasi_nn_burnrs/src/helper.rs diff --git a/plugins/wasi_nn_rust/src/lib.rs b/plugins/wasi_nn_burnrs/src/lib.rs similarity index 100% rename from plugins/wasi_nn_rust/src/lib.rs rename to plugins/wasi_nn_burnrs/src/lib.rs diff --git a/plugins/wasi_nn_rust/src/models/mod.rs b/plugins/wasi_nn_burnrs/src/models/mod.rs similarity index 100% rename from plugins/wasi_nn_rust/src/models/mod.rs rename to plugins/wasi_nn_burnrs/src/models/mod.rs diff --git a/plugins/wasi_nn_rust/src/models/squeezenet.rs b/plugins/wasi_nn_burnrs/src/models/squeezenet.rs similarity index 100% rename from plugins/wasi_nn_rust/src/models/squeezenet.rs rename to plugins/wasi_nn_burnrs/src/models/squeezenet.rs diff --git a/plugins/wasi_nn_rust/src/models/whisper.rs b/plugins/wasi_nn_burnrs/src/models/whisper.rs similarity index 100% rename from plugins/wasi_nn_rust/src/models/whisper.rs rename to plugins/wasi_nn_burnrs/src/models/whisper.rs From 0c1b53ec7d42c4c329a16deae91fdf440818611c Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 3 Sep 2024 16:53:05 +0800 Subject: [PATCH 421/623] [Test] Decouple plugin content and plugin test Signed-off-by: Shen-Ta Hsieh --- test/plugins/unittest/CMakeLists.txt | 1 - test/plugins/unittest/testplugin.cpp | 6 +++--- test/plugins/unittest/testplugin.h | 6 +++--- test/plugins/unittest/unittest_cpp.cpp | 30 ++++++-------------------- 4 files changed, 12 insertions(+), 31 deletions(-) diff --git a/test/plugins/unittest/CMakeLists.txt b/test/plugins/unittest/CMakeLists.txt index 979cc968..8a0e2ad5 100644 --- a/test/plugins/unittest/CMakeLists.txt +++ b/test/plugins/unittest/CMakeLists.txt @@ -69,7 +69,6 @@ target_include_directories(wasmedgePluginUnittestsCPP target_link_libraries(wasmedgePluginUnittestsCPP PRIVATE - wasmedgePluginTestModuleCPP ${GTEST_BOTH_LIBRARIES} ) diff --git a/test/plugins/unittest/testplugin.cpp b/test/plugins/unittest/testplugin.cpp index e54fba13..589df415 100644 --- a/test/plugins/unittest/testplugin.cpp +++ b/test/plugins/unittest/testplugin.cpp @@ -11,15 +11,15 @@ namespace Host { using namespace std::literals::string_view_literals; -WASMEDGE_EXPORT PO::List +PO::List WasmEdgePluginTestEnv::CmdArgs(PO::Description("Test for args."sv), PO::MetaVar("ARG"sv)); -WASMEDGE_EXPORT PO::Option +PO::Option WasmEdgePluginTestEnv::CmdName(PO::Description("Test for input name."sv), PO::DefaultValue(std::string(""))); -WASMEDGE_EXPORT PO::Option +PO::Option WasmEdgePluginTestEnv::CmdOpt(PO::Description("Test for option."sv)); namespace { diff --git a/test/plugins/unittest/testplugin.h b/test/plugins/unittest/testplugin.h index 4418ed95..82376b3e 100644 --- a/test/plugins/unittest/testplugin.h +++ b/test/plugins/unittest/testplugin.h @@ -18,9 +18,9 @@ class WasmEdgePluginTestEnv { public: WasmEdgePluginTestEnv() noexcept = default; - WASMEDGE_EXPORT static PO::List CmdArgs; - WASMEDGE_EXPORT static PO::Option CmdName; - WASMEDGE_EXPORT static PO::Option CmdOpt; + static PO::List CmdArgs; + static PO::Option CmdName; + static PO::Option CmdOpt; }; template diff --git a/test/plugins/unittest/unittest_cpp.cpp b/test/plugins/unittest/unittest_cpp.cpp index 9410956e..f800c29e 100644 --- a/test/plugins/unittest/unittest_cpp.cpp +++ b/test/plugins/unittest/unittest_cpp.cpp @@ -2,11 +2,10 @@ // SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/defines.h" +#include "plugin/plugin.h" #include "runtime/callingframe.h" #include "runtime/instance/module.h" -#include "testplugin.h" - #include #include #include @@ -17,16 +16,6 @@ namespace { -template -inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { - static_assert(std::has_virtual_destructor_v); - T *P = dynamic_cast(R.get()); - if (P) { - R.release(); - } - return std::unique_ptr(P); -} - std::unique_ptr createModuleC() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( @@ -42,7 +31,7 @@ std::unique_ptr createModuleC() { return {}; } -std::unique_ptr createModuleCPP() { +std::unique_ptr createModuleCPP() { using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( "./" WASMEDGE_LIB_PREFIX @@ -57,8 +46,7 @@ std::unique_ptr createModuleCPP() { Parser.set_raw_value("opt"sv); if (const auto *Module = Plugin->findModule("wasmedge_plugintest_cpp_module"sv)) { - return dynamicPointerCast( - Module->create()); + return Module->create(); } } return {}; @@ -78,9 +66,7 @@ TEST(wasmedgePluginTests, CPP_Run) { auto *FuncInst1 = TestModCPP->findFuncExports("arg_len"); EXPECT_NE(FuncInst1, nullptr); EXPECT_TRUE(FuncInst1->isHostFunction()); - auto &HostFuncInst1 = - dynamic_cast( - FuncInst1->getHostFunc()); + auto &HostFuncInst1 = FuncInst1->getHostFunc(); // Test: Run function successfully. EXPECT_TRUE(HostFuncInst1.run(CallFrame, {}, RetVal)); @@ -90,9 +76,7 @@ TEST(wasmedgePluginTests, CPP_Run) { auto *FuncInst2 = TestModCPP->findFuncExports("name_size"); EXPECT_NE(FuncInst2, nullptr); EXPECT_TRUE(FuncInst2->isHostFunction()); - auto &HostFuncInst2 = - dynamic_cast( - FuncInst2->getHostFunc()); + auto &HostFuncInst2 = FuncInst2->getHostFunc(); // Test: Run function successfully. EXPECT_TRUE(HostFuncInst2.run(CallFrame, {}, RetVal)); @@ -102,9 +86,7 @@ TEST(wasmedgePluginTests, CPP_Run) { auto *FuncInst3 = TestModCPP->findFuncExports("opt"); EXPECT_NE(FuncInst3, nullptr); EXPECT_TRUE(FuncInst3->isHostFunction()); - auto &HostFuncInst3 = - dynamic_cast( - FuncInst3->getHostFunc()); + auto &HostFuncInst3 = FuncInst3->getHostFunc(); // Test: Run function successfully. EXPECT_TRUE(HostFuncInst3.run(CallFrame, {}, RetVal)); From 64fcaee59fd9cf714906184dacc14719edc19712 Mon Sep 17 00:00:00 2001 From: grorge Date: Thu, 29 Aug 2024 23:04:01 +0800 Subject: [PATCH 422/623] [plugin] Stable Diffusion: upgrade version to e71ddce Signed-off-by: grorge --- .../wasmedge_stablediffusion/CMakeLists.txt | 13 +- plugins/wasmedge_stablediffusion/sd_func.cpp | 64 ++++++---- plugins/wasmedge_stablediffusion/sd_func.h | 13 +- .../wasmedge_stablediffusion.cpp | 116 ++++++++++-------- 4 files changed, 123 insertions(+), 83 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 28214e52..5568673a 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -24,7 +24,7 @@ message(STATUS "Downloading stable diffusion source") FetchContent_Declare( stable-diffusion GIT_REPOSITORY https://github.com/leejet/stable-diffusion.cpp.git - GIT_TAG master-9c51d87 + GIT_TAG master-e71ddce GIT_SHALLOW TRUE ) FetchContent_MakeAvailable(stable-diffusion) @@ -73,10 +73,12 @@ target_include_directories(wasmedgePluginWasmEdgeStableDiffusion $ ${CMAKE_CURRENT_SOURCE_DIR} ) - +target_include_directories( wasmedgePluginWasmEdgeStableDiffusion SYSTEM PRIVATE + "${CMAKE_BINARY_DIR}/_deps/stable-diffusion-src/thirdparty" +) if (MSVC) target_compile_options( - wasmedgePluginWasmEdgeStableDiffusion + stable-diffusion PRIVATE /wd4459 /wd4100 @@ -85,13 +87,16 @@ if (MSVC) ) else() target_compile_options( - wasmedgePluginWasmEdgeStableDiffusion + stable-diffusion PRIVATE -Wno-unused-function -Wno-unused-variable -Wno-unused-parameter -Wno-missing-field-initializers -Wno-deprecated-declarations + -Wno-braced-scalar-init + -Wno-unused-value + -Wno-uninitialized ) endif() diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index 7aca1854..f8879122 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -4,6 +4,7 @@ #include "sd_func.h" #include "common/spdlog.h" #include "sd_env.h" +#include "spdlog/spdlog.h" #include "stable-diffusion.h" #define STB_IMAGE_IMPLEMENTATION @@ -150,8 +151,11 @@ Expect SDConvert::body(const Runtime::CallingFrame &Frame, Expect SDCreateContext::body( const Runtime::CallingFrame &Frame, uint32_t ModelPathPtr, - uint32_t ModelPathLen, uint32_t VaePathPtr, uint32_t VaePathLen, - uint32_t TaesdPathPtr, uint32_t TaesdPathLen, uint32_t ControlNetPathPtr, + uint32_t ModelPathLen, uint32_t clipLPathPtr, uint32_t clipLPathLen, + uint32_t t5xxlPathPtr, uint32_t t5xxlPathLen, + uint32_t diffusionModelPathPtr, uint32_t diffusionModelPathLen, + uint32_t VaePathPtr, uint32_t VaePathLen, uint32_t TaesdPathPtr, + uint32_t TaesdPathLen, uint32_t ControlNetPathPtr, uint32_t ControlNetPathLen, uint32_t LoraModelDirPtr, uint32_t LoraModelDirLen, uint32_t EmbedDirPtr, uint32_t EmbedDirLen, uint32_t IdEmbedDirPtr, uint32_t IdEmbedDirLen, uint32_t VaeDecodeOnly, @@ -164,6 +168,14 @@ Expect SDCreateContext::body( // Check the input model buffer. MEM_SPAN_CHECK(ModelPathSpan, MemInst, char, ModelPathPtr, ModelPathLen, "Failed when accessing the input model path memory."sv) + MEM_SPAN_CHECK(clipLPathSpan, MemInst, char, clipLPathPtr, clipLPathLen, + "Failed when accessing the input clipL path memory."sv) + MEM_SPAN_CHECK(t5xxlPathSpan, MemInst, char, t5xxlPathPtr, t5xxlPathLen, + "Failed when accessing the input t5xxl path memory."sv) + MEM_SPAN_CHECK( + diffusionModelPathSpan, MemInst, char, diffusionModelPathPtr, + diffusionModelPathLen, + "Failed when accessing the input diffusion model path memory."sv) MEM_SPAN_CHECK(VaePathSpan, MemInst, char, VaePathPtr, VaePathLen, "Failed when accessing the input vae path memory."sv) MEM_SPAN_CHECK(ControlNetPathSpan, MemInst, char, ControlNetPathPtr, @@ -194,6 +206,12 @@ Expect SDCreateContext::body( std::string EmbedDir = std::string(EmbedDirSpan.begin(), EmbedDirSpan.end()); std::string IdEmbedDir = std::string(IdEmbedDirSpan.begin(), IdEmbedDirSpan.end()); + std::string clipLPath = + std::string(clipLPathSpan.begin(), clipLPathSpan.end()); + std::string t5xxlPath = + std::string(t5xxlPathSpan.begin(), t5xxlPathSpan.end()); + std::string diffusionModelPath = + std::string(diffusionModelPathSpan.begin(), diffusionModelPathSpan.end()); if (NThreads == -1) { NThreads = get_num_physical_cores(); } @@ -202,10 +220,12 @@ Expect SDCreateContext::body( // Create context and import graph. sd_ctx_t *Ctx = new_sd_ctx( - ModelPath.data(), VaePath.data(), TaesdPath.data(), ControlNetPath.data(), - LoraModelDir.data(), EmbedDir.data(), IdEmbedDir.data(), - static_cast(VaeDecodeOnly), static_cast(VaeTiling), true, - NThreads, static_cast(Wtype), static_cast(RngType), + ModelPath.data(), clipLPath.data(), t5xxlPath.data(), + diffusionModelPath.data(), VaePath.data(), TaesdPath.data(), + ControlNetPath.data(), LoraModelDir.data(), EmbedDir.data(), + IdEmbedDir.data(), static_cast(VaeDecodeOnly), + static_cast(VaeTiling), true, NThreads, + static_cast(Wtype), static_cast(RngType), static_cast(Schedule), ClipOnCpu, ControlNetCpu, VaeOnCpu); if (Ctx == nullptr) { spdlog::error("[WasmEdge-StableDiffusion] Failed to create context."); @@ -219,15 +239,15 @@ Expect SDCreateContext::body( Expect SDTextToImage::body( const Runtime::CallingFrame &Frame, uint32_t PromptPtr, uint32_t PromptLen, uint32_t SessionId, uint32_t ControlImagePtr, uint32_t ControlImageLen, - uint32_t NegativePromptPtr, uint32_t NegativePromptLen, uint32_t Width, - uint32_t Height, int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, - uint32_t SampleSteps, uint32_t Seed, uint32_t BatchCount, - float ControlStrength, float StyleRatio, uint32_t NormalizeInput, - uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, - uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, - uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, - uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, - uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { + uint32_t NegativePromptPtr, uint32_t NegativePromptLen, float Guidance, + uint32_t Width, uint32_t Height, int32_t ClipSkip, float CfgScale, + uint32_t SampleMethod, uint32_t SampleSteps, uint32_t Seed, + uint32_t BatchCount, float ControlStrength, float StyleRatio, + uint32_t NormalizeInput, uint32_t InputIdImagesDirPtr, + uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, uint32_t, uint32_t, + uint32_t, uint32_t OutputPathPtr, uint32_t OutputPathLen, + uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr) { // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) // Check the input model buffer. @@ -269,9 +289,9 @@ Expect SDTextToImage::body( spdlog::info("[WasmEdge-StableDiffusion] Start to generate image."sv); Results = txt2img(SDCtx, Prompt.data(), NegativePrompt.data(), ClipSkip, CfgScale, - Width, Height, sample_method_t(SampleMethod), SampleSteps, Seed, - BatchCount, ControlImage, ControlStrength, StyleRatio, - NormalizeInput, InputIdImagesDir.data()); + Guidance, Width, Height, sample_method_t(SampleMethod), + SampleSteps, Seed, BatchCount, ControlImage, ControlStrength, + StyleRatio, NormalizeInput, InputIdImagesDir.data()); // TODO upscale image int Len; unsigned char *Png = stbi_write_png_to_mem( @@ -298,15 +318,14 @@ Expect SDTextToImage::body( Expect SDImageToImage::body( const Runtime::CallingFrame &Frame, uint32_t ImagePtr, uint32_t ImageLen, - uint32_t SessionId, uint32_t Width, uint32_t Height, + uint32_t SessionId, float Guidance, uint32_t Width, uint32_t Height, uint32_t ControlImagePtr, uint32_t ControlImageLen, uint32_t PromptPtr, uint32_t PromptLen, uint32_t NegativePromptPtr, uint32_t NegativePromptLen, int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, uint32_t SampleSteps, float Strength, uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, - uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, - uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, + uint32_t CannyPreprocess, uint32_t, uint32_t, uint32_t, uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { // Check memory instance from module. @@ -360,7 +379,6 @@ Expect SDImageToImage::body( stbi_load_from_memory(ImageSpan.data(), ImageSpan.size(), &ImageWidth, &ImageHeight, &Channel, 3); } - // TODO: Resize image when image size not matches width and height sd_image_t InputImage = {Width, Height, 3, InputImageBuffer}; sd_image_t *ControlImage = nullptr; @@ -374,7 +392,7 @@ Expect SDImageToImage::body( sd_image_t *Results = nullptr; spdlog::info("[WasmEdge-StableDiffusion] Start to generate image."sv); Results = img2img(SDCtx, InputImage, Prompt.data(), NegativePrompt.data(), - ClipSkip, CfgScale, Width, Height, + ClipSkip, CfgScale, Guidance, Width, Height, sample_method_t(SampleMethod), SampleSteps, Strength, Seed, BatchCount, ControlImage, ControlStrength, StyleRatio, NormalizeInput, InputIdImagesDir.data()); diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h index 88166c79..70b9d831 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -16,8 +16,11 @@ class SDCreateContext : public StableDiffusion::Func { SDCreateContext(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} Expect body(const Runtime::CallingFrame &Frame, uint32_t ModelPathPtr, - uint32_t ModelPathLen, uint32_t VaePathPtr, uint32_t VaePathLen, - uint32_t TaesdPathPtr, uint32_t TaesdPathLen, uint32_t ControlNetPathPtr, + uint32_t ModelPathLen, uint32_t clipLPathPtr, uint32_t clipLPathLen, + uint32_t t5xxlPathPtr, uint32_t t5xxlPathLen, + uint32_t diffusionModelPathPtr, uint32_t diffusionModelPathLen, + uint32_t VaePathPtr, uint32_t VaePathLen, uint32_t TaesdPathPtr, + uint32_t TaesdPathLen, uint32_t ControlNetPathPtr, uint32_t ControlNetPathLen, uint32_t LoraModelDirPtr, uint32_t LoraModelDirLen, uint32_t EmbedDirPtr, uint32_t EmbedDirLen, uint32_t IdEmbedDirPtr, uint32_t IdEmbedDirLen, uint32_t VaeDecodeOnly, @@ -31,7 +34,7 @@ class SDImageToImage : public StableDiffusion::Func { SDImageToImage(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} Expect body(const Runtime::CallingFrame &Frame, uint32_t ImagePtr, uint32_t ImageLen, - uint32_t SessionId, uint32_t Width, uint32_t Height, + uint32_t SessionId, float Guidance, uint32_t Width, uint32_t Height, uint32_t ControlImagePtr, uint32_t ControlImageLen, uint32_t PromptPtr, uint32_t PromptLen, uint32_t NegativePromptPtr, uint32_t NegativePromptLen, int32_t ClipSkip, float CfgScale, @@ -52,8 +55,8 @@ class SDTextToImage : public StableDiffusion::Func { body(const Runtime::CallingFrame &Frame, uint32_t PromptPtr, uint32_t PromptLen, uint32_t SessionId, uint32_t ControlImagePtr, uint32_t ControlImageLen, uint32_t NegativePromptPtr, - uint32_t NegativePromptLen, uint32_t Width, uint32_t Height, - int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, + uint32_t NegativePromptLen, float Guidance, uint32_t Width, + uint32_t Height, int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, uint32_t SampleSteps, uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index 77529201..a589a9f2 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -70,7 +70,7 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { WasmEdge::Runtime::Instance::ModuleInstance Mod(""); Mod.addHostMemory( "memory", std::make_unique( - WasmEdge::AST::MemoryType(60000))); + WasmEdge::AST::MemoryType(2097024))); auto *MemInstPtr = Mod.findMemoryExports("memory"); ASSERT_TRUE(MemInstPtr != nullptr); auto &MemInst = *MemInstPtr; @@ -154,28 +154,34 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { std::initializer_list{ QuantModelPathPtr, // ModelPathPtr static_cast(QuantModelPath.size()), // ModelPathLen - 0, // VaePathPtr - 0, // VaePathLen - 0, // TaesdPathPtr - 0, // TaesdPathLen - 0, // ControlNetPathPtr - 0, // ControlNetPathLen - 0, // LoraModelDirPtr - 0, // LoraModelDirLen - 0, // EmbedDirPtr - 0, // EmbedDirLen - 0, // IdEmbedDirPtr - 0, // IdEmbedDirLen - 1, // VaeDecodeOnly - 0, // VaeTiling - -1, // NThreads - 31, // Wtype - 1, // RngType - 0, // Schedule - 0, // ClipOnCpu - 0, // ControlNetCpu - 0, // VaeOnCpu - SessionPtr}, // SessiontIdPtr + 0, // ClipLPathPtr + 0, // ClipLPathLen + 0, // T5xxlPathPtr + 0, // T5xxlPathLen + 0, // DiffusionModelPathPtr + 0, // DiffusionModelPathLen + 0, // VaePathPtr + 0, // VaePathLen + 0, // TaesdPathPtr + 0, // TaesdPathLen + 0, // ControlNetPathPtr + 0, // ControlNetPathLen + 0, // LoraModelDirPtr + 0, // LoraModelDirLen + 0, // EmbedDirPtr + 0, // EmbedDirLen + 0, // IdEmbedDirPtr + 0, // IdEmbedDirLen + 1, // VaeDecodeOnly + 0, // VaeTiling + -1, // NThreads + 34, // Wtype + 1, // RngType + 0, // Schedule + 0, // ClipOnCpu + 0, // ControlNetCpu + 0, // VaeOnCpu + SessionPtr}, // SessiontIdPtr Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); SessionId = *MemInst.getPointer(SessionPtr); @@ -200,8 +206,9 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 0, // ControlImageLen 0, // NegativePromptPtr 0, // NegativePromptLen - 64, // Width - 64, // Height + 3.5f, // Guidance + 256, // Width + 256, // Height -1, // ClipSkip 7.0f, // CfgScale 0, // SampleMethod @@ -220,7 +227,7 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { OutputPathPtr, // OutputPathPtr static_cast(OutputPath.size()), // OutputPathLen OutputPtr, // OutBufferPtr - 65532, // OutBufferMaxSize + 1048512, // OutBufferMaxSize BytesWrittenPtr}, // BytesWrittenPtr Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); @@ -237,28 +244,34 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { std::initializer_list{ QuantModelPathPtr, // ModelPathPtr static_cast(QuantModelPath.size()), // ModelPathLen - 0, // VaePathPtr - 0, // VaePathLen - 0, // TaesdPathPtr - 0, // TaesdPathLen - 0, // ControlNetPathPtr - 0, // ControlNetPathLen - 0, // LoraModelDirPtr - 0, // LoraModelDirLen - 0, // EmbedDirPtr - 0, // EmbedDirLen - 0, // IdEmbedDirPtr - 0, // IdEmbedDirLen - 0, // VaeDecodeOnly - 0, // VaeTiling - -1, // NThreads - 31, // Wtype - 1, // RngType - 0, // Schedule - 0, // ClipOnCpu - 0, // ControlNetCpu - 0, // VaeOnCpu - SessionPtr}, // SessiontIdPtr + 0, // ClipLPathPtr + 0, // ClipLPathLen + 0, // T5xxlPathPtr + 0, // T5xxlPathLen + 0, // DiffusionModelPathPtr + 0, // DiffusionModelPathLen + 0, // VaePathPtr + 0, // VaePathLen + 0, // TaesdPathPtr + 0, // TaesdPathLen + 0, // ControlNetPathPtr + 0, // ControlNetPathLen + 0, // LoraModelDirPtr + 0, // LoraModelDirLen + 0, // EmbedDirPtr + 0, // EmbedDirLen + 0, // IdEmbedDirPtr + 0, // IdEmbedDirLen + 0, // VaeDecodeOnly + 0, // VaeTiling + -1, // NThreads + 34, // Wtype + 1, // RngType + 0, // Schedule + 0, // ClipOnCpu + 0, // ControlNetCpu + 0, // VaeOnCpu + SessionPtr}, // SessiontIdPtr Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); SessionId = *MemInst.getPointer(SessionPtr); @@ -280,8 +293,9 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { InputPathPtr, // ImagePtr static_cast(InputPath.size()), // ImageLen SessionId, // SessionId - 64, // Width - 64, // Height + 3.5f, // Guidance + 256, // Width + 256, // Height 0, // ControlImagePtr 0, // ControlImageLen PromptPtr, // PromptPtr @@ -307,7 +321,7 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { OutputPathPtr, // OutputPathPtr static_cast(OutputPath2.size()), // OutputPathLen OutputPtr, // OutBufferPtr - 65532, // OutBufferMaxSize + 1048512, // OutBufferMaxSize BytesWrittenPtr}, // BytesWrittenPtr Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); From 9639d32dba412dfc6b6a518127e9aa0fe278ee8d Mon Sep 17 00:00:00 2001 From: grorge Date: Fri, 30 Aug 2024 20:46:04 +0800 Subject: [PATCH 423/623] [plugin] Stable Diffusion: fix ggml-common path Signed-off-by: grorge --- plugins/wasmedge_stablediffusion/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 5568673a..8c7d30db 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -97,6 +97,7 @@ else() -Wno-braced-scalar-init -Wno-unused-value -Wno-uninitialized + -Wno-format ) endif() @@ -105,7 +106,7 @@ if(WASMEDGE_PLUGIN_STABLEDIFFUSION_METAL) TARGET wasmedgePluginWasmEdgeStableDiffusion POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${stable-diffusion_SOURCE_DIR}/ggml/src/ggml-metal.metal ggml-metal.metal - COMMAND ${CMAKE_COMMAND} -E copy ${stable-diffusion_SOURCE_DIR}/ggml/ggml-common.h ggml-common.h + COMMAND ${CMAKE_COMMAND} -E copy ${stable-diffusion_SOURCE_DIR}/ggml/src/ggml-common.h ggml-common.h ) endif() From e857cdd031ee4194a556f3233eae5d1d1e0f8ba9 Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Wed, 4 Sep 2024 19:42:50 +0800 Subject: [PATCH 424/623] [Plugin] Support `translate` and `language` option for WASI-NN (#3716) whisper backend. Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 2 + plugins/wasi_nn/whispercpp.cpp | 131 +++++++++++++++++++++++++++++++-- plugins/wasi_nn/whispercpp.h | 5 +- 3 files changed, 131 insertions(+), 7 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 8006f715..cf0af84a 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -168,6 +168,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) wasmedge_setup_simdjson() target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) elseif(BACKEND STREQUAL "whisper") + wasmedge_setup_simdjson() if(APPLE AND CMAKE_SYSTEM_VERSION VERSION_LESS 23) # `cblas_sgemm()` introduced in macOS 13.3. set(WHISPER_NO_ACCELERATE ON CACHE INTERNAL "Stable diffusion turn off accelerate") @@ -184,6 +185,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) set_property(TARGET whisper PROPERTY POSITION_INDEPENDENT_CODE ON) target_link_libraries(wasmedgePluginWasiNN PRIVATE whisper + simdjson::simdjson ) endif() endforeach() diff --git a/plugins/wasi_nn/whispercpp.cpp b/plugins/wasi_nn/whispercpp.cpp index bee0efff..adf42071 100644 --- a/plugins/wasi_nn/whispercpp.cpp +++ b/plugins/wasi_nn/whispercpp.cpp @@ -6,6 +6,7 @@ #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER #define DR_WAV_IMPLEMENTATION +#include "simdjson.h" #include #include @@ -137,6 +138,87 @@ void WhisperOutputSegmentCallback(struct whisper_context *WhisperCtx, } } +Expect parseMetadata(Graph &GraphRef, + const std::string &Metadata) noexcept { + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + spdlog::error("[WASI-NN] Whisper backend: Parse metadata error."sv); + return ErrNo::InvalidEncoding; + } + + // Get metadata from the json. + // Currently supported metadata: + // Plugin parameters (used by this plugin): + // enable-log: bool + // enable-debug-log: bool + // translate: bool + // language: string + // detect-language: bool + // prompt: string + + // The plugin parameters. + if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-log"].get().get(GraphRef.EnableLog); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the enable-log " + "option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("enable-debug-log").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-debug-log"].get().get(GraphRef.EnableDebugLog); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the enable-debug-log " + "option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("translate").error() == simdjson::SUCCESS) { + auto Err = Doc["translate"].get().get(GraphRef.Translate); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the translate " + "option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("language").error() == simdjson::SUCCESS) { + std::string_view Language; + auto Err = Doc["language"].get().get(Language); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the language " + "option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.SpokenLanguage = Language; + } + if (Doc.at_key("detect-language").error() == simdjson::SUCCESS) { + auto Err = Doc["detect-language"].get().get(GraphRef.DetectLanguage); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the detect-language " + "option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("prompt").error() == simdjson::SUCCESS) { + std::string_view Prompt; + auto Err = Doc["prompt"].get().get(Prompt); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the prompt option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.InitialPrompt = Prompt; + } + return ErrNo::Success; +} + } // namespace Expect load(WasiNNEnvironment &Env, Span> Builders, @@ -147,17 +229,26 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize the parameters. auto CParam = whisper_context_default_params(); - GraphRef.EnableLog = false; - GraphRef.EnableDebugLog = false; + GraphRef.ModelFilePath = ""sv; + GraphRef.SpokenLanguage = "en"sv; GraphRef.UseGPU = CParam.use_gpu; GraphRef.MainGPU = CParam.gpu_device; - GraphRef.ModelFilePath = ""sv; - GraphRef.ModelLanguage = "en"sv; // Set whisper log callback. whisper_log_set(WhisperLogCallback, &GraphRef); - // TODO: Use the metadata to pass data. + // If the graph builder length > 1, the data of builder[1] is the metadata. + if (Builders.size() > 1) { + const std::string Metadata(reinterpret_cast(Builders[1].data()), + Builders[1].size()); + // Ignore context or model updates when initializing the graph. + auto Res = parseMetadata(GraphRef, Metadata); + if (Res != ErrNo::Success) { + spdlog::error("[WASI-NN] Whisper backend: Failed to parse metadata."sv); + Env.NNGraph.pop_back(); + return Res; + } + } // Handle the model path. if (GraphRef.EnableDebugLog) { @@ -196,6 +287,31 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, "given parameters...Done"sv); } + // Check the language. + if (GraphRef.SpokenLanguage != "auto"sv && + whisper_lang_id(GraphRef.SpokenLanguage.c_str()) == -1) { + spdlog::error("[WASI-NN] Whisper backend: Error: unknown language {}."sv, + GraphRef.SpokenLanguage); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + + // Check the translate option. + if (!whisper_is_multilingual(GraphRef.WhisperCtx)) { + if (GraphRef.SpokenLanguage != "en"sv || GraphRef.Translate) { + GraphRef.SpokenLanguage = "en"sv; + GraphRef.Translate = false; + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] Whisper backend: Model is not multilingual. Ignoring " + "language and translation options"sv); + } + } + } + if (GraphRef.DetectLanguage) { + GraphRef.SpokenLanguage = "auto"sv; + } + // Store the loaded graph. GraphId = Env.NNGraph.size() - 1; @@ -214,7 +330,10 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, auto &WParam = CxtRef.WhisperParams; WParam.print_progress = false; WParam.thold_pt = GraphRef.WordThreshold; - WParam.language = GraphRef.ModelLanguage.c_str(); + WParam.translate = GraphRef.Translate; + WParam.language = GraphRef.SpokenLanguage.c_str(); + WParam.detect_language = GraphRef.DetectLanguage; + WParam.initial_prompt = GraphRef.InitialPrompt.c_str(); WParam.temperature_inc = GraphRef.TemperatureInc; WParam.temperature = GraphRef.Temperature; WParam.entropy_thold = GraphRef.EntropyThreshold; diff --git a/plugins/wasi_nn/whispercpp.h b/plugins/wasi_nn/whispercpp.h index c5d3631e..9dd36bbb 100644 --- a/plugins/wasi_nn/whispercpp.h +++ b/plugins/wasi_nn/whispercpp.h @@ -22,10 +22,13 @@ namespace WasmEdge::Host::WASINN::Whisper { struct Graph { whisper_context *WhisperCtx = nullptr; std::string ModelFilePath; - std::string ModelLanguage; // Whisper parameters: bool EnableLog = false; bool EnableDebugLog = false; + bool Translate = false; + bool DetectLanguage = false; + std::string SpokenLanguage; + std::string InitialPrompt; // Context parameters: bool UseGPU = true; int64_t MainGPU = 0; // Use GPU 0 by default From cb61d6bde1e81623cc87950bc616d3dc27cdbe49 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 22 Aug 2024 15:43:53 +0800 Subject: [PATCH 425/623] [Plugin] Correct the plugin names. 1. Move the `wasi_llm` to `wasmedge_llmc`. 2. Move the `wasi_ocr` to `wasmedge_ocr`. Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 4 +- plugins/wasi_llm/types.h | 16 ---- plugins/wasi_llm/wasillmfunc.h | 93 ------------------- plugins/wasi_llm/wasillmmodule.cpp | 20 ---- plugins/wasi_ocr/wasiocrmodule.cpp | 17 ---- .../CMakeLists.txt | 20 ++-- .../llmc_base.h} | 14 +-- .../llmc_env.cpp} | 31 +++---- .../wasillmenv.h => wasmedge_llmc/llmc_env.h} | 16 +++- .../llmc_func.cpp} | 84 +++++++++-------- plugins/wasmedge_llmc/llmc_func.h | 91 ++++++++++++++++++ .../{wasi_llm => wasmedge_llmc}/llmc_fwd.h | 2 +- plugins/wasmedge_llmc/llmc_module.cpp | 20 ++++ .../llmc_module.h} | 12 ++- .../{wasi_ocr => wasmedge_ocr}/CMakeLists.txt | 22 ++--- .../wasiocrbase.h => wasmedge_ocr/ocr_base.h} | 12 ++- .../ocr_env.cpp} | 16 ++-- .../wasiocrenv.h => wasmedge_ocr/ocr_env.h} | 16 ++-- .../ocr_func.cpp} | 24 ++--- .../wasiocrfunc.h => wasmedge_ocr/ocr_func.h} | 14 +-- plugins/wasmedge_ocr/ocr_module.cpp | 17 ++++ .../ocr_module.h} | 11 ++- test/plugins/CMakeLists.txt | 2 +- .../CMakeLists.txt | 31 ++++--- .../wasmedge_llmc.cpp} | 62 +++++++------ 25 files changed, 337 insertions(+), 330 deletions(-) delete mode 100644 plugins/wasi_llm/types.h delete mode 100644 plugins/wasi_llm/wasillmfunc.h delete mode 100644 plugins/wasi_llm/wasillmmodule.cpp delete mode 100644 plugins/wasi_ocr/wasiocrmodule.cpp rename plugins/{wasi_llm => wasmedge_llmc}/CMakeLists.txt (61%) rename plugins/{wasi_llm/wasillmbase.h => wasmedge_llmc/llmc_base.h} (51%) rename plugins/{wasi_llm/wasillmenv.cpp => wasmedge_llmc/llmc_env.cpp} (71%) rename plugins/{wasi_llm/wasillmenv.h => wasmedge_llmc/llmc_env.h} (84%) rename plugins/{wasi_llm/wasillmfunc.cpp => wasmedge_llmc/llmc_func.cpp} (56%) create mode 100644 plugins/wasmedge_llmc/llmc_func.h rename plugins/{wasi_llm => wasmedge_llmc}/llmc_fwd.h (96%) create mode 100644 plugins/wasmedge_llmc/llmc_module.cpp rename plugins/{wasi_llm/wasillmmodule.h => wasmedge_llmc/llmc_module.h} (61%) rename plugins/{wasi_ocr => wasmedge_ocr}/CMakeLists.txt (63%) rename plugins/{wasi_ocr/wasiocrbase.h => wasmedge_ocr/ocr_base.h} (52%) rename plugins/{wasi_ocr/wasiocrenv.cpp => wasmedge_ocr/ocr_env.cpp} (81%) rename plugins/{wasi_ocr/wasiocrenv.h => wasmedge_ocr/ocr_env.h} (81%) rename plugins/{wasi_ocr/wasiocrfunc.cpp => wasmedge_ocr/ocr_func.cpp} (70%) rename plugins/{wasi_ocr/wasiocrfunc.h => wasmedge_ocr/ocr_func.h} (61%) create mode 100644 plugins/wasmedge_ocr/ocr_module.cpp rename plugins/{wasi_ocr/wasiocrmodule.h => wasmedge_ocr/ocr_module.h} (55%) rename test/plugins/{wasi_llm => wasmedge_llmc}/CMakeLists.txt (66%) rename test/plugins/{wasi_llm/wasi_llm.cpp => wasmedge_llmc/wasmedge_llmc.cpp} (77%) diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 4f14f1dc..ff6a2b14 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -96,6 +96,6 @@ if(WASMEDGE_PLUGIN_FFMPEG) add_subdirectory(wasmedge_ffmpeg) endif() -if(WASMEDGE_PLUGIN_LLM) - add_subdirectory(wasi_llm) +if(WASMEDGE_PLUGIN_LLMC) + add_subdirectory(wasmedge_llmc) endif() diff --git a/plugins/wasi_llm/types.h b/plugins/wasi_llm/types.h deleted file mode 100644 index 71e73641..00000000 --- a/plugins/wasi_llm/types.h +++ /dev/null @@ -1,16 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#pragma once - -#include - -namespace WasmEdge::Host::WASILLM { - -enum class ErrNo : uint32_t { - Success = 0, - InvalidArgument = 1, - MissingMemory = 2, -}; - -} // namespace WasmEdge::Host::WASILLM diff --git a/plugins/wasi_llm/wasillmfunc.h b/plugins/wasi_llm/wasillmfunc.h deleted file mode 100644 index 4c588b19..00000000 --- a/plugins/wasi_llm/wasillmfunc.h +++ /dev/null @@ -1,93 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#pragma once - -#include "runtime/callingframe.h" -#include "types.h" -#include "wasillmbase.h" -#include "wasillmenv.h" - -#include - -namespace WasmEdge { -namespace Host { - -class WasiLLMModelCreate : public WasiLLM { -public: - explicit WasiLLMModelCreate(WASILLM::WASILLMEnv &Env) : WasiLLM(Env) {} - - Expect body(const Runtime::CallingFrame &Frame, - uint32_t CheckPointPath, uint32_t CheckPointPathLen, - uint32_t ModelIdPtr) { - return bodyImpl(Frame, CheckPointPath, CheckPointPathLen, ModelIdPtr) - .map(castErrNo); - } - -private: - Expect bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t CheckPointPath, - uint32_t CheckPointPathLen, - uint32_t ModelIdPtr); -}; - -class WasiLLMDataLoaderCreate : public WasiLLM { -public: - explicit WasiLLMDataLoaderCreate(WASILLM::WASILLMEnv &Env) : WasiLLM(Env) {} - - Expect body(const Runtime::CallingFrame &Frame, uint32_t DataPath, - uint32_t DataPathLen, uint32_t B, uint32_t T, - uint32_t ProcessRank, uint32_t NumProcesses, - int32_t ShouldShuffle, uint32_t DataLoaderIdPtr) { - return bodyImpl(Frame, DataPath, DataPathLen, B, T, ProcessRank, - NumProcesses, ShouldShuffle, DataLoaderIdPtr) - .map(castErrNo); - } - -private: - Expect bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t DataPath, uint32_t DataPathLen, - uint32_t B, uint32_t T, uint32_t ProcessRank, - uint32_t NumProcesses, int32_t ShouldShuffle, - uint32_t DataLoaderIdPtr); -}; - -class WasiLLMTokenizerCreate : public WasiLLM { -public: - explicit WasiLLMTokenizerCreate(WASILLM::WASILLMEnv &Env) : WasiLLM(Env) {} - - Expect body(const Runtime::CallingFrame &Frame, uint32_t FilePath, - uint32_t FilePathLen, uint32_t TokenizerIdPtr) { - return bodyImpl(Frame, FilePath, FilePathLen, TokenizerIdPtr) - .map(castErrNo); - } - -private: - Expect bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t FilePath, uint32_t FilePathLen, - uint32_t TokenizerIdPtr); -}; - -class WasiLLMModelTrain : public WasiLLM { -public: - explicit WasiLLMModelTrain(WASILLM::WASILLMEnv &Env) : WasiLLM(Env) {} - - Expect body(const Runtime::CallingFrame &Frame, uint32_t ModelId, - uint32_t TrainDataLoaderId, uint32_t ValDataLoaderId, - uint32_t TokenizerId, uint32_t B, uint32_t T, float Lr, - uint32_t Epoch) { - return bodyImpl(Frame, ModelId, TrainDataLoaderId, ValDataLoaderId, - TokenizerId, B, T, Lr, Epoch) - .map(castErrNo); - } - -private: - Expect bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t ModelId, uint32_t TrainDataLoaderId, - uint32_t ValDataLoaderId, - uint32_t TokenizerId, uint32_t B, uint32_t T, - float Lr, uint32_t Epoch); -}; - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasi_llm/wasillmmodule.cpp b/plugins/wasi_llm/wasillmmodule.cpp deleted file mode 100644 index 36742ac3..00000000 --- a/plugins/wasi_llm/wasillmmodule.cpp +++ /dev/null @@ -1,20 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#include "wasillmmodule.h" -#include "wasillmfunc.h" - -namespace WasmEdge { -namespace Host { - -WasiLLMModule::WasiLLMModule() : ModuleInstance("wasi_llm") { - addHostFunc("model_create", std::make_unique(Env)); - addHostFunc("dataloader_create", - std::make_unique(Env)); - addHostFunc("tokenizer_create", - std::make_unique(Env)); - addHostFunc("model_train", std::make_unique(Env)); -} - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasi_ocr/wasiocrmodule.cpp b/plugins/wasi_ocr/wasiocrmodule.cpp deleted file mode 100644 index ea16f215..00000000 --- a/plugins/wasi_ocr/wasiocrmodule.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2023 Second State INC - -#include "wasiocrmodule.h" -#include "wasiocrfunc.h" - -namespace WasmEdge { -namespace Host { - -WasiOCRModule::WasiOCRModule() : ModuleInstance("wasi_ephemeral_ocr") { - addHostFunc("num_of_extractions", - std::make_unique(Env)); - addHostFunc("get_output", std::make_unique(Env)); -} - -} // namespace Host -} // namespace WasmEdge diff --git a/plugins/wasi_llm/CMakeLists.txt b/plugins/wasmedge_llmc/CMakeLists.txt similarity index 61% rename from plugins/wasi_llm/CMakeLists.txt rename to plugins/wasmedge_llmc/CMakeLists.txt index 11b96d92..7a486e1b 100644 --- a/plugins/wasi_llm/CMakeLists.txt +++ b/plugins/wasmedge_llmc/CMakeLists.txt @@ -3,11 +3,11 @@ # TODO: Fetch llm.c source. -wasmedge_add_library(wasmedgePluginWasiLLM +wasmedge_add_library(wasmedgePluginWasmEdgeLLMC SHARED - wasillmfunc.cpp - wasillmmodule.cpp - wasillmenv.cpp + llmc_func.cpp + llmc_module.cpp + llmc_env.cpp ) message(STATUS "Start fetching llm.c source") @@ -18,34 +18,34 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(llmc) -target_link_libraries(wasmedgePluginWasiLLM PRIVATE +target_link_libraries(wasmedgePluginWasmEdgeLLMC PRIVATE train_gpt2_cpu ) -target_compile_options(wasmedgePluginWasiLLM +target_compile_options(wasmedgePluginWasmEdgeLLMC PUBLIC -DWASMEDGE_PLUGIN ) -target_include_directories(wasmedgePluginWasiLLM +target_include_directories(wasmedgePluginWasmEdgeLLMC PUBLIC $ ${CMAKE_CURRENT_SOURCE_DIR} ) if(WASMEDGE_LINK_PLUGINS_STATIC) - target_link_libraries(wasmedgePluginWasiLLM + target_link_libraries(wasmedgePluginWasmEdgeLLMC PRIVATE wasmedgeCAPI ) else() - target_link_libraries(wasmedgePluginWasiLLM + target_link_libraries(wasmedgePluginWasmEdgeLLMC PRIVATE wasmedge_shared ) endif() install( - TARGETS wasmedgePluginWasiLLM + TARGETS wasmedgePluginWasmEdgeLLMC DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge ) diff --git a/plugins/wasi_llm/wasillmbase.h b/plugins/wasmedge_llmc/llmc_base.h similarity index 51% rename from plugins/wasi_llm/wasillmbase.h rename to plugins/wasmedge_llmc/llmc_base.h index fcc941ce..6af0d36a 100644 --- a/plugins/wasi_llm/wasillmbase.h +++ b/plugins/wasmedge_llmc/llmc_base.h @@ -3,24 +3,26 @@ #pragma once +#include "llmc_env.h" + #include "common/errcode.h" #include "runtime/hostfunc.h" -#include "types.h" -#include "wasillmenv.h" namespace WasmEdge { namespace Host { +namespace WasmEdgeLLMC { -template class WasiLLM : public Runtime::HostFunction { +template class HostFunction : public Runtime::HostFunction { public: - WasiLLM(WASILLM::WASILLMEnv &E) : Runtime::HostFunction(0), Env(E) {} + HostFunction(LLMCEnv &E) : Runtime::HostFunction(0), Env(E) {} protected: - static constexpr uint32_t castErrNo(WASILLM::ErrNo E) noexcept { + static constexpr uint32_t castErrNo(ErrNo E) noexcept { return static_cast(E); } - WASILLM::WASILLMEnv &Env; + LLMCEnv &Env; }; +} // namespace WasmEdgeLLMC } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_llm/wasillmenv.cpp b/plugins/wasmedge_llmc/llmc_env.cpp similarity index 71% rename from plugins/wasi_llm/wasillmenv.cpp rename to plugins/wasmedge_llmc/llmc_env.cpp index 336ece58..d254ebce 100644 --- a/plugins/wasi_llm/wasillmenv.cpp +++ b/plugins/wasmedge_llmc/llmc_env.cpp @@ -1,46 +1,45 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "wasillmenv.h" +#include "llmc_env.h" #include "llmc_fwd.h" -#include "wasillmmodule.h" +#include "llmc_module.h" namespace WasmEdge { namespace Host { +namespace WasmEdgeLLMC { -namespace WASILLM { - -uint32_t WASILLMEnv::addModel(GPT2 *M) noexcept { +uint32_t LLMCEnv::addModel(GPT2 *M) noexcept { Models.push_back(M); return Models.size() - 1; } -GPT2 *WASILLMEnv::getModel(uint32_t Id) noexcept { +GPT2 *LLMCEnv::getModel(uint32_t Id) noexcept { assert(Id < Models.size() && "Out of bounds"); return Models[Id]; } -uint32_t WASILLMEnv::addTokenizer(Tokenizer *T) noexcept { +uint32_t LLMCEnv::addTokenizer(Tokenizer *T) noexcept { Tokenizers.push_back(T); return Tokenizers.size() - 1; } -Tokenizer *WASILLMEnv::getTokenizer(uint32_t Id) noexcept { +Tokenizer *LLMCEnv::getTokenizer(uint32_t Id) noexcept { assert(Id < Tokenizers.size() && "Out of bounds"); return Tokenizers[Id]; } -uint32_t WASILLMEnv::addDataLoader(DataLoader *D) noexcept { +uint32_t LLMCEnv::addDataLoader(DataLoader *D) noexcept { DataLoaders.push_back(D); return DataLoaders.size() - 1; } -DataLoader *WASILLMEnv::getDataLoader(uint32_t Id) noexcept { +DataLoader *LLMCEnv::getDataLoader(uint32_t Id) noexcept { assert(Id < DataLoaders.size() && "Out of bounds"); return DataLoaders[Id]; } -WASILLMEnv::~WASILLMEnv() { +LLMCEnv::~LLMCEnv() { for (GPT2 *M : Models) { gpt2_free(M); } @@ -53,22 +52,21 @@ WASILLMEnv::~WASILLMEnv() { } namespace { - Runtime::Instance::ModuleInstance * create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { - return new WasiLLMModule; + return new WasmEdgeLLMCModule; } static Plugin::PluginModule::ModuleDescriptor MD[] = { { - /* Name */ "wasi_llm", + /* Name */ "wasmedge_llmc", /* Description */ "", /* Create */ create, }, }; Plugin::Plugin::PluginDescriptor Descriptor{ - /* Name */ "wasi_llm", + /* Name */ "wasmedge_llmc", /* Description */ "", /* APIVersion */ Plugin::Plugin::CurrentAPIVersion, /* Version */ {0, 1, 0, 0}, @@ -82,7 +80,6 @@ Plugin::Plugin::PluginDescriptor Descriptor{ EXPORT_GET_DESCRIPTOR(Descriptor) -} // namespace WASILLM - +} // namespace WasmEdgeLLMC } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_llm/wasillmenv.h b/plugins/wasmedge_llmc/llmc_env.h similarity index 84% rename from plugins/wasi_llm/wasillmenv.h rename to plugins/wasmedge_llmc/llmc_env.h index af409023..66b9b8c1 100644 --- a/plugins/wasi_llm/wasillmenv.h +++ b/plugins/wasmedge_llmc/llmc_env.h @@ -4,6 +4,7 @@ #pragma once #include "plugin/plugin.h" + #include #include #include @@ -16,9 +17,15 @@ struct DataLoader; namespace WasmEdge { namespace Host { -namespace WASILLM { +namespace WasmEdgeLLMC { + +enum class ErrNo : uint32_t { + Success = 0, + InvalidArgument = 1, + MissingMemory = 2, +}; -class WASILLMEnv { +class LLMCEnv { std::vector Models; std::vector Tokenizers; std::vector DataLoaders; @@ -42,8 +49,9 @@ class WASILLMEnv { size_t getDataLoaderSize() const noexcept { return DataLoaders.size(); } - ~WASILLMEnv(); + ~LLMCEnv(); }; -} // namespace WASILLM + +} // namespace WasmEdgeLLMC } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_llm/wasillmfunc.cpp b/plugins/wasmedge_llmc/llmc_func.cpp similarity index 56% rename from plugins/wasi_llm/wasillmfunc.cpp rename to plugins/wasmedge_llmc/llmc_func.cpp index ec794909..b90fba78 100644 --- a/plugins/wasi_llm/wasillmfunc.cpp +++ b/plugins/wasmedge_llmc/llmc_func.cpp @@ -1,69 +1,71 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "wasillmfunc.h" +#include "llmc_func.h" +#include "llmc_fwd.h" + #include "common/errcode.h" #include "common/spdlog.h" -#include "llmc_fwd.h" -#include "types.h" + #include #include namespace WasmEdge { namespace Host { +namespace WasmEdgeLLMC { -Expect -WasiLLMModelCreate::bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t CheckPointPath, - uint32_t CheckPointPathLen, uint32_t ModelIdPtr) { +Expect ModelCreate::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t CheckPointPath, + uint32_t CheckPointPathLen, + uint32_t ModelIdPtr) { auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { - spdlog::error("[WasmEdge-LLM] Memory instance not found."sv); - return WASILLM::ErrNo::MissingMemory; + spdlog::error("[WasmEdge-LLMC] Memory instance not found."sv); + return ErrNo::MissingMemory; } auto CheckPointPathSpan = MemInst->getSpan(CheckPointPath, CheckPointPathLen); if (unlikely(CheckPointPathSpan.size() != CheckPointPathLen)) { spdlog::error( - "[WasmEdge-LLM] Failed when accessing the input checkpoint path memory."sv); - return WASILLM::ErrNo::MissingMemory; + "[WasmEdge-LLMC] Failed when accessing the input checkpoint path memory."sv); + return ErrNo::MissingMemory; } auto *ModelId = MemInst->getPointer(ModelIdPtr); if (unlikely(ModelId == nullptr)) { spdlog::error( - "[WasmEdge-LLM] Failed when accessing the return model memory."sv); - return WASILLM::ErrNo::InvalidArgument; + "[WasmEdge-LLMC] Failed when accessing the return model memory."sv); + return ErrNo::InvalidArgument; } std::string CheckPointPathStr = std::string(CheckPointPathSpan.begin(), CheckPointPathSpan.begin() + CheckPointPathSpan.size()); GPT2 *Model = gpt2_create(CheckPointPathStr.data()); *ModelId = Env.addModel(Model); - return WASILLM::ErrNo::Success; + return ErrNo::Success; } -Expect WasiLLMDataLoaderCreate::bodyImpl( +Expect DataLoaderCreate::bodyImpl( const Runtime::CallingFrame &Frame, uint32_t DataPath, uint32_t DataPathLen, uint32_t B, uint32_t T, uint32_t ProcessRank, uint32_t NumProcesses, int32_t ShouldShuffle, uint32_t DataLoaderIdPtr) { auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { - spdlog::error("[WasmEdge-LLM] Memory instance not found."sv); - return WASILLM::ErrNo::MissingMemory; + spdlog::error("[WasmEdge-LLMC] Memory instance not found."sv); + return ErrNo::MissingMemory; } auto DataPathSpan = MemInst->getSpan(DataPath, DataPathLen); if (unlikely(DataPathSpan.size() != DataPathLen)) { spdlog::error( - "[WasmEdge-LLM] Failed when accessing the input dataloader path memory."sv); - return WASILLM::ErrNo::MissingMemory; + "[WasmEdge-LLMC] Failed when accessing the input dataloader path memory."sv); + return ErrNo::MissingMemory; } auto *DataLoaderId = MemInst->getPointer(DataLoaderIdPtr); if (unlikely(DataLoaderId == nullptr)) { spdlog::error( - "[WasmEdge-LLM] Failed when accessing the return dataloader memory."sv); - return WASILLM::ErrNo::InvalidArgument; + "[WasmEdge-LLMC] Failed when accessing the return dataloader memory."sv); + return ErrNo::InvalidArgument; } std::string DataPathStr = std::string( @@ -71,55 +73,55 @@ Expect WasiLLMDataLoaderCreate::bodyImpl( DataLoader *D = dataloader_create(DataPathStr.data(), B, T, ProcessRank, NumProcesses, ShouldShuffle); *DataLoaderId = Env.addDataLoader(D); - return WASILLM::ErrNo::Success; + return ErrNo::Success; } -Expect -WasiLLMTokenizerCreate::bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t FilePath, uint32_t FilePathLen, - uint32_t TokenizerIdPtr) { +Expect TokenizerCreate::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t FilePath, uint32_t FilePathLen, + uint32_t TokenizerIdPtr) { auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { - spdlog::error("[WasmEdge-LLM] Memory instance not found."sv); - return WASILLM::ErrNo::MissingMemory; + spdlog::error("[WasmEdge-LLMC] Memory instance not found."sv); + return ErrNo::MissingMemory; } auto FilePathSpan = MemInst->getSpan(FilePath, FilePathLen); if (unlikely(FilePathSpan.size() != FilePathLen)) { spdlog::error( - "[WasmEdge-LLM] Failed when accessing the input tokenizer path memory."sv); - return WASILLM::ErrNo::MissingMemory; + "[WasmEdge-LLMC] Failed when accessing the input tokenizer path memory."sv); + return ErrNo::MissingMemory; } auto *TokenizerId = MemInst->getPointer(TokenizerIdPtr); if (unlikely(TokenizerId == nullptr)) { spdlog::error( - "[WasmEdge-LLM] Failed when accessing the return tokenizer memory."sv); - return WASILLM::ErrNo::InvalidArgument; + "[WasmEdge-LLMC] Failed when accessing the return tokenizer memory."sv); + return ErrNo::InvalidArgument; } std::string FilePathStr = std::string( FilePathSpan.begin(), FilePathSpan.begin() + FilePathSpan.size()); Tokenizer *T = tokenizer_create(FilePathStr.data()); *TokenizerId = Env.addTokenizer(T); - return WASILLM::ErrNo::Success; + return ErrNo::Success; } -Expect -WasiLLMModelTrain::bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t ModelId, uint32_t TrainDataLoaderId, - uint32_t ValDataLoaderId, uint32_t TokenizerId, - uint32_t B, uint32_t T, float Lr, uint32_t Epoch) { +Expect ModelTrain::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ModelId, uint32_t TrainDataLoaderId, + uint32_t ValDataLoaderId, + uint32_t TokenizerId, uint32_t B, uint32_t T, + float Lr, uint32_t Epoch) { auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { - spdlog::error("[WasmEdge-LLM] Memory instance not found."sv); - return WASILLM::ErrNo::MissingMemory; + spdlog::error("[WasmEdge-LLMC] Memory instance not found."sv); + return ErrNo::MissingMemory; } GPT2 *Model = Env.getModel(ModelId); DataLoader *TrainDataLoader = Env.getDataLoader(TrainDataLoaderId); DataLoader *ValDataLoader = Env.getDataLoader(ValDataLoaderId); Tokenizer *Tokenizer = Env.getTokenizer(TokenizerId); gpt2_train(Model, TrainDataLoader, ValDataLoader, Tokenizer, B, T, Lr, Epoch); - return WASILLM::ErrNo::Success; + return ErrNo::Success; } +} // namespace WasmEdgeLLMC } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_llmc/llmc_func.h b/plugins/wasmedge_llmc/llmc_func.h new file mode 100644 index 00000000..4bae88fb --- /dev/null +++ b/plugins/wasmedge_llmc/llmc_func.h @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "llmc_base.h" +#include "llmc_env.h" + +#include "runtime/callingframe.h" + +#include + +namespace WasmEdge { +namespace Host { +namespace WasmEdgeLLMC { + +class ModelCreate : public HostFunction { +public: + explicit ModelCreate(LLMCEnv &Env) : HostFunction(Env) {} + + Expect body(const Runtime::CallingFrame &Frame, + uint32_t CheckPointPath, uint32_t CheckPointPathLen, + uint32_t ModelIdPtr) { + return bodyImpl(Frame, CheckPointPath, CheckPointPathLen, ModelIdPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t CheckPointPath, uint32_t CheckPointPathLen, + uint32_t ModelIdPtr); +}; + +class DataLoaderCreate : public HostFunction { +public: + explicit DataLoaderCreate(LLMCEnv &Env) : HostFunction(Env) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t DataPath, + uint32_t DataPathLen, uint32_t B, uint32_t T, + uint32_t ProcessRank, uint32_t NumProcesses, + int32_t ShouldShuffle, uint32_t DataLoaderIdPtr) { + return bodyImpl(Frame, DataPath, DataPathLen, B, T, ProcessRank, + NumProcesses, ShouldShuffle, DataLoaderIdPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, uint32_t DataPath, + uint32_t DataPathLen, uint32_t B, uint32_t T, + uint32_t ProcessRank, uint32_t NumProcesses, + int32_t ShouldShuffle, uint32_t DataLoaderIdPtr); +}; + +class TokenizerCreate : public HostFunction { +public: + explicit TokenizerCreate(LLMCEnv &Env) : HostFunction(Env) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t FilePath, + uint32_t FilePathLen, uint32_t TokenizerIdPtr) { + return bodyImpl(Frame, FilePath, FilePathLen, TokenizerIdPtr) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, uint32_t FilePath, + uint32_t FilePathLen, uint32_t TokenizerIdPtr); +}; + +class ModelTrain : public HostFunction { +public: + explicit ModelTrain(LLMCEnv &Env) : HostFunction(Env) {} + + Expect body(const Runtime::CallingFrame &Frame, uint32_t ModelId, + uint32_t TrainDataLoaderId, uint32_t ValDataLoaderId, + uint32_t TokenizerId, uint32_t B, uint32_t T, float Lr, + uint32_t Epoch) { + return bodyImpl(Frame, ModelId, TrainDataLoaderId, ValDataLoaderId, + TokenizerId, B, T, Lr, Epoch) + .map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ModelId, + uint32_t TrainDataLoaderId, uint32_t ValDataLoaderId, + uint32_t TokenizerId, uint32_t B, uint32_t T, float Lr, + uint32_t Epoch); +}; + +} // namespace WasmEdgeLLMC +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_llm/llmc_fwd.h b/plugins/wasmedge_llmc/llmc_fwd.h similarity index 96% rename from plugins/wasi_llm/llmc_fwd.h rename to plugins/wasmedge_llmc/llmc_fwd.h index 4486be35..804893eb 100644 --- a/plugins/wasi_llm/llmc_fwd.h +++ b/plugins/wasmedge_llmc/llmc_fwd.h @@ -3,7 +3,7 @@ #pragma once -#include "wasillmenv.h" +#include "llmc_env.h" extern "C" { diff --git a/plugins/wasmedge_llmc/llmc_module.cpp b/plugins/wasmedge_llmc/llmc_module.cpp new file mode 100644 index 00000000..8914eb03 --- /dev/null +++ b/plugins/wasmedge_llmc/llmc_module.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "llmc_module.h" +#include "llmc_func.h" + +namespace WasmEdge { +namespace Host { + +WasmEdgeLLMCModule::WasmEdgeLLMCModule() : ModuleInstance("wasmedge_llmc") { + addHostFunc("model_create", std::make_unique(Env)); + addHostFunc("dataloader_create", + std::make_unique(Env)); + addHostFunc("tokenizer_create", + std::make_unique(Env)); + addHostFunc("model_train", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_llm/wasillmmodule.h b/plugins/wasmedge_llmc/llmc_module.h similarity index 61% rename from plugins/wasi_llm/wasillmmodule.h rename to plugins/wasmedge_llmc/llmc_module.h index 6a6574c2..86a923c3 100644 --- a/plugins/wasi_llm/wasillmmodule.h +++ b/plugins/wasmedge_llmc/llmc_module.h @@ -3,17 +3,19 @@ #pragma once +#include "llmc_env.h" + #include "runtime/instance/module.h" -#include "wasillmenv.h" namespace WasmEdge { namespace Host { -class WasiLLMModule : public Runtime::Instance::ModuleInstance { - WASILLM::WASILLMEnv Env; - +class WasmEdgeLLMCModule : public Runtime::Instance::ModuleInstance { public: - WasiLLMModule(); + WasmEdgeLLMCModule(); + +private: + WasmEdgeLLMC::LLMCEnv Env; }; } // namespace Host diff --git a/plugins/wasi_ocr/CMakeLists.txt b/plugins/wasmedge_ocr/CMakeLists.txt similarity index 63% rename from plugins/wasi_ocr/CMakeLists.txt rename to plugins/wasmedge_ocr/CMakeLists.txt index 40fe5f8b..e3139149 100644 --- a/plugins/wasi_ocr/CMakeLists.txt +++ b/plugins/wasmedge_ocr/CMakeLists.txt @@ -1,38 +1,38 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2024 Second State INC -add_library(wasmedgePluginWasiOCR +add_library(wasmedgePluginWasmEdgeOCR SHARED - wasiocrenv.cpp - wasiocrfunc.cpp - wasiocrmodule.cpp + ocr_env.cpp + ocr_func.cpp + ocr_module.cpp ) -target_compile_options(wasmedgePluginWasiOCR +target_compile_options(wasmedgePluginWasmEdgeOCR PUBLIC -DWASMEDGE_PLUGIN ) -target_include_directories(wasmedgePluginWasiOCR +target_include_directories(wasmedgePluginWasmEdgeOCR PUBLIC $ ${CMAKE_CURRENT_SOURCE_DIR} ) if(WASMEDGE_LINK_PLUGINS_STATIC) - target_link_libraries(wasmedgePluginWasiOCR + target_link_libraries(wasmedgePluginWasmEdgeOCR PRIVATE wasmedgeCAPI ) else() - target_link_libraries(wasmedgePluginWasiOCR + target_link_libraries(wasmedgePluginWasmEdgeOCR PRIVATE wasmedge_shared ) endif() install( - TARGETS wasmedgePluginWasiOCR + TARGETS wasmedgePluginWasmEdgeOCR DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge ) @@ -41,13 +41,13 @@ find_package(PkgConfig REQUIRED) pkg_search_module(TESSERACT REQUIRED tesseract) pkg_search_module(LEPTONICA REQUIRED lept) -target_include_directories(wasmedgePluginWasiOCR +target_include_directories(wasmedgePluginWasmEdgeOCR PUBLIC ${TESSERACT_INCLUDE_DIRS} ${LEPTONICA_INCLUDE_DIRS} ) -target_link_libraries(wasmedgePluginWasiOCR +target_link_libraries(wasmedgePluginWasmEdgeOCR PUBLIC ${TESSERACT_LIBRARIES} ${LEPTONICA_LIBRARIES} diff --git a/plugins/wasi_ocr/wasiocrbase.h b/plugins/wasmedge_ocr/ocr_base.h similarity index 52% rename from plugins/wasi_ocr/wasiocrbase.h rename to plugins/wasmedge_ocr/ocr_base.h index 38fd9d72..dc525dd2 100644 --- a/plugins/wasi_ocr/wasiocrbase.h +++ b/plugins/wasmedge_ocr/ocr_base.h @@ -3,21 +3,23 @@ #pragma once +#include "ocr_env.h" + #include "common/errcode.h" #include "runtime/hostfunc.h" -#include "wasiocrenv.h" namespace WasmEdge { namespace Host { +namespace WasmEdgeOCR { -template class WasiOCR : public Runtime::HostFunction { +template class HostFunction : public Runtime::HostFunction { public: - WasiOCR(WASIOCR::WasiOCREnvironment &HostEnv) - : Runtime::HostFunction(0), Env(HostEnv) {} + HostFunction(OCREnv &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} protected: - WASIOCR::WasiOCREnvironment &Env; + OCREnv &Env; }; +} // namespace WasmEdgeOCR } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_ocr/wasiocrenv.cpp b/plugins/wasmedge_ocr/ocr_env.cpp similarity index 81% rename from plugins/wasi_ocr/wasiocrenv.cpp rename to plugins/wasmedge_ocr/ocr_env.cpp index 3df9d69c..111b5136 100644 --- a/plugins/wasi_ocr/wasiocrenv.cpp +++ b/plugins/wasmedge_ocr/ocr_env.cpp @@ -1,21 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2023 Second State INC -#include "wasiocrenv.h" -#include "wasiocrmodule.h" +#include "ocr_env.h" +#include "ocr_module.h" namespace WasmEdge { namespace Host { - namespace { Runtime::Instance::ModuleInstance * create(const Plugin::PluginModule::ModuleDescriptor *) noexcept { - return new WasiOCRModule; + return new WasmEdgeOCRModule; } Plugin::Plugin::PluginDescriptor Descriptor{ - .Name = "wasi_ocr", + .Name = "wasmedge_ocr", .Description = "A WasmEdge Plugin for Optical Character Recognition (OCR) " "powered by the Tesseract API.", .APIVersion = Plugin::Plugin::CurrentAPIVersion, @@ -24,7 +23,7 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .ModuleDescriptions = (Plugin::PluginModule::ModuleDescriptor[]){ { - .Name = "wasi_ocr", + .Name = "wasmedge_ocr", .Description = "A WasmEdge Plugin for Optical Character Recognition (OCR) " "powered by the Tesseract API.", @@ -34,9 +33,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .AddOptions = nullptr, }; -} // namespace - -Plugin::PluginRegister WASIOCR::WasiOCREnvironment::Register(&Descriptor); +EXPORT_GET_DESCRIPTOR(Descriptor) +} // namespace } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_ocr/wasiocrenv.h b/plugins/wasmedge_ocr/ocr_env.h similarity index 81% rename from plugins/wasi_ocr/wasiocrenv.h rename to plugins/wasmedge_ocr/ocr_env.h index 14ec7562..e4a0ddf4 100644 --- a/plugins/wasi_ocr/wasiocrenv.h +++ b/plugins/wasmedge_ocr/ocr_env.h @@ -5,14 +5,15 @@ #include "common/spdlog.h" #include "plugin/plugin.h" -#include #include #include +#include + namespace WasmEdge { namespace Host { -namespace WASIOCR { +namespace WasmEdgeOCR { enum class ErrNo : uint32_t { Success = 0, // No error occurred. @@ -21,16 +22,17 @@ enum class ErrNo : uint32_t { Busy = 3 // Device or resource busy. }; -class WasiOCREnvironment { +class OCREnv { public: - WasiOCREnvironment() noexcept { + OCREnv() noexcept { // check Tesseract API by initializing tesseract-ocr with English, without // specifying tessdata path if (TesseractApi->Init(NULL, "eng")) { - spdlog::error("[WASI-OCR] Error occurred when initializing tesseract."); + spdlog::error( + "[WasmEdge-OCR] Error occurred when initializing tesseract."); } } - ~WasiOCREnvironment() noexcept { + ~OCREnv() noexcept { if (TesseractApi) { TesseractApi->End(); ; @@ -41,6 +43,6 @@ class WasiOCREnvironment { static Plugin::PluginRegister Register; }; -} // namespace WASIOCR +} // namespace WasmEdgeOCR } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_ocr/wasiocrfunc.cpp b/plugins/wasmedge_ocr/ocr_func.cpp similarity index 70% rename from plugins/wasi_ocr/wasiocrfunc.cpp rename to plugins/wasmedge_ocr/ocr_func.cpp index da97b794..e75ef40f 100644 --- a/plugins/wasi_ocr/wasiocrfunc.cpp +++ b/plugins/wasmedge_ocr/ocr_func.cpp @@ -1,7 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2023 Second State INC -#include "wasiocrfunc.h" +#include "ocr_func.h" + #include "common/spdlog.h" #include @@ -9,10 +10,11 @@ namespace WasmEdge { namespace Host { +namespace WasmEdgeOCR { -Expect -WasiOCRNumOfExtractions::body(const Runtime::CallingFrame &Frame, - uint32_t ImagePathPtr, uint32_t ImagePathLen) { +Expect NumOfExtractions::body(const Runtime::CallingFrame &Frame, + uint32_t ImagePathPtr, + uint32_t ImagePathLen) { // Check memory instance from module. auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { @@ -35,10 +37,9 @@ WasiOCRNumOfExtractions::body(const Runtime::CallingFrame &Frame, return static_cast(length); } -Expect WasiOCRGetOutput::body(const Runtime::CallingFrame &Frame, - uint32_t OutBufferPtr [[maybe_unused]], - uint32_t OutBufferMaxSize - [[maybe_unused]]) { +Expect GetOutput::body(const Runtime::CallingFrame &Frame, + uint32_t OutBufferPtr [[maybe_unused]], + uint32_t OutBufferMaxSize [[maybe_unused]]) { // Check memory instance from module. auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { @@ -49,8 +50,8 @@ Expect WasiOCRGetOutput::body(const Runtime::CallingFrame &Frame, auto Buf = MemInst->getSpan(OutBufferPtr, OutBufferMaxSize); if (unlikely(Buf.empty())) { spdlog::error( - "[WASI-OCR] Failed when accessing the return OutBufferPtr memory."); - return static_cast(WASIOCR::ErrNo::InvalidArgument); + "[WasmEdge-OCR] Failed when accessing the return OutBufferPtr memory."); + return static_cast(ErrNo::InvalidArgument); } tesseract::PageIteratorLevel level = tesseract::RIL_WORD; @@ -61,9 +62,10 @@ Expect WasiOCRGetOutput::body(const Runtime::CallingFrame &Frame, // remaining free and deltee memory stuff Env.TesseractApi->End(); - return static_cast(WASIOCR::ErrNo::Success); + return static_cast(ErrNo::Success); // return outText; } +} // namespace WasmEdgeOCR } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_ocr/wasiocrfunc.h b/plugins/wasmedge_ocr/ocr_func.h similarity index 61% rename from plugins/wasi_ocr/wasiocrfunc.h rename to plugins/wasmedge_ocr/ocr_func.h index ad98639b..b00f191d 100644 --- a/plugins/wasi_ocr/wasiocrfunc.h +++ b/plugins/wasmedge_ocr/ocr_func.h @@ -3,28 +3,30 @@ #pragma once +#include "ocr_base.h" + #include "runtime/callingframe.h" -#include "wasiocrbase.h" #include namespace WasmEdge { namespace Host { +namespace WasmEdgeOCR { -class WasiOCRNumOfExtractions : public WasiOCR { +class NumOfExtractions : public HostFunction { public: - WasiOCRNumOfExtractions(WASIOCR::WasiOCREnvironment &HostEnv) - : WasiOCR(HostEnv) {} + NumOfExtractions(OCREnv &HostEnv) : HostFunction(HostEnv) {} Expect body(const Runtime::CallingFrame &, uint32_t ImagePathPtr, uint32_t ImagePathLen); }; -class WasiOCRGetOutput : public WasiOCR { +class GetOutput : public HostFunction { public: - WasiOCRGetOutput(WASIOCR::WasiOCREnvironment &HostEnv) : WasiOCR(HostEnv) {} + GetOutput(OCREnv &HostEnv) : HostFunction(HostEnv) {} Expect body(const Runtime::CallingFrame &, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize); }; +} // namespace WasmEdgeOCR } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_ocr/ocr_module.cpp b/plugins/wasmedge_ocr/ocr_module.cpp new file mode 100644 index 00000000..667c69c1 --- /dev/null +++ b/plugins/wasmedge_ocr/ocr_module.cpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2023 Second State INC + +#include "ocr_module.h" +#include "ocr_func.h" + +namespace WasmEdge { +namespace Host { + +WasmEdgeOCRModule::WasmEdgeOCRModule() : ModuleInstance("wasmedge_ocr") { + addHostFunc("num_of_extractions", + std::make_unique(Env)); + addHostFunc("get_output", std::make_unique(Env)); +} + +} // namespace Host +} // namespace WasmEdge diff --git a/plugins/wasi_ocr/wasiocrmodule.h b/plugins/wasmedge_ocr/ocr_module.h similarity index 55% rename from plugins/wasi_ocr/wasiocrmodule.h rename to plugins/wasmedge_ocr/ocr_module.h index cbab6d5c..4f2b64c1 100644 --- a/plugins/wasi_ocr/wasiocrmodule.h +++ b/plugins/wasmedge_ocr/ocr_module.h @@ -3,20 +3,21 @@ #pragma once +#include "ocr_env.h" + #include "runtime/instance/module.h" -#include "wasiocrenv.h" namespace WasmEdge { namespace Host { -class WasiOCRModule : public Runtime::Instance::ModuleInstance { +class WasmEdgeOCRModule : public Runtime::Instance::ModuleInstance { public: - WasiOCRModule(); + WasmEdgeOCRModule(); - WASIOCR::WasiOCREnvironment &getEnv() { return Env; } + WasmEdgeOCR::OCREnv &getEnv() { return Env; } private: - WASIOCR::WasiOCREnvironment Env; + WasmEdgeOCR::OCREnv Env; }; } // namespace Host diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index e8397a55..a3f35f49 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -59,7 +59,7 @@ endif() add_subdirectory(wasi_logging) -if(WASMEDGE_PLUGIN_LLM) +if(WASMEDGE_PLUGIN_LLMC) add_subdirectory(wasi_llm) endif() diff --git a/test/plugins/wasi_llm/CMakeLists.txt b/test/plugins/wasmedge_llmc/CMakeLists.txt similarity index 66% rename from test/plugins/wasi_llm/CMakeLists.txt rename to test/plugins/wasmedge_llmc/CMakeLists.txt index a29cd29e..820e6503 100644 --- a/test/plugins/wasi_llm/CMakeLists.txt +++ b/test/plugins/wasmedge_llmc/CMakeLists.txt @@ -1,30 +1,33 @@ -wasmedge_add_executable(wasiLLMTests - wasi_llm.cpp +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +wasmedge_add_executable(wasmedgeLLMCTests + wasmedge_llmc.cpp ) -add_dependencies(wasiLLMTests - wasmedgePluginWasiLLM +add_dependencies(wasmedgeLLMCTests + wasmedgePluginWasmEdgeLLMC ) -target_include_directories(wasiLLMTests +target_include_directories(wasmedgeLLMCTests PUBLIC $ - $ + $ ) -target_link_libraries(wasiLLMTests +target_link_libraries(wasmedgeLLMCTests PRIVATE ${GTEST_BOTH_LIBRARIES} ) # Link to the WasmEdge library if(WASMEDGE_LINK_PLUGINS_STATIC) - target_link_libraries(wasiLLMTests + target_link_libraries(wasmedgeLLMCTests PRIVATE wasmedgeCAPI ) else() - target_link_libraries(wasiLLMTests + target_link_libraries(wasmedgeLLMCTests PRIVATE wasmedge_shared ) @@ -42,26 +45,26 @@ endfunction() message(STATUS "Downloading GPT2 model check point to ${CMAKE_CURRENT_BINARY_DIR}/gpt2_124M.bin") download( https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M.bin - ${CMAKE_CURRENT_BINARY_DIR}/wasi_llm/gpt2_124M.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasmedge_llmc/gpt2_124M.bin SHA256=3da8b207584030bcdcd207cf7a99952e3421dce92da218b351071857511bf162 ) message(STATUS "Downloading training dataset to ${CMAKE_CURRENT_BINARY_DIR}/tiny_shakespeare_train.bin") download( https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_train.bin - ${CMAKE_CURRENT_BINARY_DIR}/wasi_llm/tiny_shakespeare_train.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasmedge_llmc/tiny_shakespeare_train.bin SHA256=8a70606be574040c26d225694f5f9759973b419852d22f7fe5c118e1b359dcc8 ) message(STATUS "Downloading validation dataset to ${CMAKE_CURRENT_BINARY_DIR}/tiny_shakespeare_val.bin") download( https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_val.bin - ${CMAKE_CURRENT_BINARY_DIR}/wasi_llm/tiny_shakespeare_val.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasmedge_llmc/tiny_shakespeare_val.bin SHA256=fe99db720dc7c83e694806d4e047a952909411da1daccde4ccc2e55f40882a62 ) message(STATUS "Downloading tokenizer data to ${CMAKE_CURRENT_BINARY_DIR}/gpt2_tokenizer.bin") download( https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_tokenizer.bin - ${CMAKE_CURRENT_BINARY_DIR}/wasi_llm/gpt2_tokenizer.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasmedge_llmc/gpt2_tokenizer.bin SHA256=6f3abc21e444e4e8300e225f4e03da48ea121cf17e30f67009b8dad7a66c2f13 ) -add_test(wasiLLMTests wasiLLMTests) +add_test(wasmedgeLLMCTests wasmedgeLLMCTests) diff --git a/test/plugins/wasi_llm/wasi_llm.cpp b/test/plugins/wasmedge_llmc/wasmedge_llmc.cpp similarity index 77% rename from test/plugins/wasi_llm/wasi_llm.cpp rename to test/plugins/wasmedge_llmc/wasmedge_llmc.cpp index 408c06c6..4fcbff20 100644 --- a/test/plugins/wasi_llm/wasi_llm.cpp +++ b/test/plugins/wasmedge_llmc/wasmedge_llmc.cpp @@ -1,11 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "llmc_func.h" +#include "llmc_module.h" + #include "common/defines.h" #include "common/types.h" #include "plugin/plugin.h" #include "runtime/callingframe.h" #include "runtime/instance/module.h" -#include "types.h" -#include "wasillmfunc.h" -#include "wasillmmodule.h" #include #include @@ -16,7 +19,7 @@ #include #include -using WasmEdge::Host::WASILLM::ErrNo; +using WasmEdge::Host::WasmEdgeLLMC::ErrNo; namespace { @@ -30,14 +33,14 @@ inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { return std::unique_ptr(P); } -std::unique_ptr createModule() { +std::unique_ptr createModule() { using namespace std::literals::string_view_literals; - WasmEdge::Plugin::Plugin::load( - std::filesystem::u8path("../../../plugins/wasi_llm/" WASMEDGE_LIB_PREFIX - "wasmedgePluginWasiLLM" WASMEDGE_LIB_EXTENSION)); - if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasi_llm"sv)) { - if (const auto *Module = Plugin->findModule("wasi_llm"sv)) { - return dynamicPointerCast( + WasmEdge::Plugin::Plugin::load(std::filesystem::u8path( + "../../../plugins/wasmedge_llmc/" WASMEDGE_LIB_PREFIX + "wasmedgePluginWasmEdgeLLMC" WASMEDGE_LIB_EXTENSION)); + if (const auto *Plugin = WasmEdge::Plugin::Plugin::find("wasmedge_llmc"sv)) { + if (const auto *Module = Plugin->findModule("wasmedge_llmc"sv)) { + return dynamicPointerCast( Module->create()); } } @@ -58,11 +61,11 @@ void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, Ptr += 4; } -TEST(WasiLLMTest, TrainGPT2) { - // Create wasi_llm module instance. - auto LLMMod = createModule(); - ASSERT_TRUE(LLMMod); - EXPECT_EQ(LLMMod->getFuncExportNum(), 4U); +TEST(WasmEdgeLLMTest, TrainGPT2) { + // Create wasmedge_llmc module instance. + auto LLMCMod = createModule(); + ASSERT_TRUE(LLMCMod); + EXPECT_EQ(LLMCMod->getFuncExportNum(), 4U); // Create the calling frame with memory instance. WasmEdge::Runtime::Instance::ModuleInstance Mod(""); @@ -74,53 +77,54 @@ TEST(WasiLLMTest, TrainGPT2) { auto &MemInst = *MemInstPtr; WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); - auto *ModelCreate = LLMMod->findFuncExports("model_create"); + auto *ModelCreate = LLMCMod->findFuncExports("model_create"); EXPECT_NE(ModelCreate, nullptr); EXPECT_TRUE(ModelCreate->isHostFunction()); auto &HostFuncModelCreate = - dynamic_cast( + dynamic_cast( ModelCreate->getHostFunc()); - auto *DataLoaderCreate = LLMMod->findFuncExports("dataloader_create"); + auto *DataLoaderCreate = LLMCMod->findFuncExports("dataloader_create"); EXPECT_NE(DataLoaderCreate, nullptr); EXPECT_TRUE(DataLoaderCreate->isHostFunction()); auto &HostFuncDataLoadereCreate = - dynamic_cast( + dynamic_cast( DataLoaderCreate->getHostFunc()); - auto *TokenizerCreate = LLMMod->findFuncExports("tokenizer_create"); + auto *TokenizerCreate = LLMCMod->findFuncExports("tokenizer_create"); EXPECT_NE(TokenizerCreate, nullptr); EXPECT_TRUE(TokenizerCreate->isHostFunction()); auto &HostFuncTokenizerCreate = - dynamic_cast( + dynamic_cast( TokenizerCreate->getHostFunc()); - auto *ModelTrain = LLMMod->findFuncExports("model_train"); + auto *ModelTrain = LLMCMod->findFuncExports("model_train"); EXPECT_NE(ModelTrain, nullptr); EXPECT_TRUE(ModelTrain->isHostFunction()); - auto &HostFuncModelTrain = dynamic_cast( - ModelTrain->getHostFunc()); + auto &HostFuncModelTrain = + dynamic_cast( + ModelTrain->getHostFunc()); std::array Errno = {UINT32_C(0)}; - std::string CheckPointString = "./wasi_llm/gpt2_124M.bin"; + std::string CheckPointString = "./wasmedge_llmc/gpt2_124M.bin"; std::vector CheckPointPath(CheckPointString.begin(), CheckPointString.end()); uint32_t CheckPointPathPtr = UINT32_C(0); writeBinaries(MemInst, CheckPointPath, CheckPointPathPtr); - std::string TrainDataString = "./wasi_llm/tiny_shakespeare_train.bin"; + std::string TrainDataString = "./wasmedge_llmc/tiny_shakespeare_train.bin"; std::vector TrainDataPath(TrainDataString.begin(), TrainDataString.end()); uint32_t TrainDataPathPtr = CheckPointPathPtr + CheckPointPath.size(); writeBinaries(MemInst, TrainDataPath, TrainDataPathPtr); - std::string ValDataString = "./wasi_llm/tiny_shakespeare_val.bin"; + std::string ValDataString = "./wasmedge_llmc/tiny_shakespeare_val.bin"; std::vector ValDataPath(ValDataString.begin(), ValDataString.end()); uint32_t ValDataPathPtr = TrainDataPathPtr + TrainDataPath.size(); writeBinaries(MemInst, ValDataPath, ValDataPathPtr); - std::string TokenizerBin = "./wasi_llm/gpt2_tokenizer.bin"; + std::string TokenizerBin = "./wasmedge_llmc/gpt2_tokenizer.bin"; std::vector TokenizerBinPath(TokenizerBin.begin(), TokenizerBin.end()); uint32_t TokenizerBinPtr = ValDataPathPtr + ValDataPath.size(); writeBinaries(MemInst, TokenizerBinPath, TokenizerBinPtr); From f4edc5f5bce981e94f9e68944e68487cde674e24 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 4 Sep 2024 20:04:13 +0800 Subject: [PATCH 426/623] [CMake] Refine the order of plugins in CMake. Signed-off-by: YiYing He --- plugins/CMakeLists.txt | 104 +++++++++++++++++++++--------------- test/plugins/CMakeLists.txt | 93 ++++++++++++++++++++++---------- 2 files changed, 127 insertions(+), 70 deletions(-) diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index ff6a2b14..8f57a5fa 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -1,22 +1,51 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2024 Second State INC +# WASI plug-in: WASI-Crypto proposal. +if(WASMEDGE_PLUGIN_WASI_CRYPTO) + add_subdirectory(wasi_crypto) +endif() + +# WASI plug-in: WASI-Http proposal. if(WASMEDGE_PLUGIN_WASI_HTTP) add_subdirectory(wasi_http) endif() -if(WASMEDGE_PLUGIN_WASI_POLL) - add_subdirectory(wasi_poll) +# WASI plug-in: WASI-Logging proposal. +if(WASMEDGE_PLUGIN_WASI_LOGGING) + # BUILTIN-PLUGIN: Add the wasi-logging plugin here after the new plugin + # architecture ready in 0.15.0. endif() +# WASI plug-in: WASI-NN proposal with backends. if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) add_subdirectory(wasi_nn) endif() +if(WASMEDGE_PLUGIN_WASI_NN_BURNRS_MODEL) + add_subdirectory(wasi_nn_burnrs) +endif() -if(WASMEDGE_PLUGIN_WASI_CRYPTO) - add_subdirectory(wasi_crypto) +# WASI plug-in: WASI-Poll proposal. +if(WASMEDGE_PLUGIN_WASI_POLL) + add_subdirectory(wasi_poll) endif() +# WasmEdge plug-in: wasm-bpf. +if(WASMEDGE_PLUGIN_WASM_BPF) + # Only Linux systems support wasm_bpf now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasm_bpf) + else() + message(WARNING "Only Linux platforms support wasm_bpf plug-in now.") + endif() +endif() + +# WasmEdge plug-in: ffmpeg. +if(WASMEDGE_PLUGIN_FFMPEG) + add_subdirectory(wasmedge_ffmpeg) +endif() + +# WasmEdge plug-in: Image. if(WASMEDGE_PLUGIN_IMAGE) # Only Linux and MacOS support wasmedge_image now. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") @@ -26,24 +55,27 @@ if(WASMEDGE_PLUGIN_IMAGE) endif() endif() -if(WASMEDGE_PLUGIN_TENSORFLOW) - # Only Linux and MacOS support wasmedge_tensorflow now. - if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") - add_subdirectory(wasmedge_tensorflow) - else() - message(WARNING "Only Linux and Darwin platforms support WasmEdge_Tensorflow plug-in now.") - endif() +# WasmEdge plug-in: LLMC. +if(WASMEDGE_PLUGIN_LLMC) + add_subdirectory(wasmedge_llmc) endif() -if(WASMEDGE_PLUGIN_TENSORFLOWLITE) - # Only Linux and MacOS support wasmedge_tensorflowlite now. +# WasmEdge plug-in: OCR. +if(WASMEDGE_PLUGIN_OCR) + add_subdirectory(wasmedge_ocr) +endif() + +# WasmEdge plug-in: OpenCV-mini. +if(WASMEDGE_PLUGIN_OPENCVMINI) + # Only Linux and MacOS support wasmedge_opencvmini now. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") - add_subdirectory(wasmedge_tensorflowlite) + add_subdirectory(wasmedge_opencvmini) else() - message(WARNING "Only Linux and Darwin platforms support WasmEdge_TensorflowLite plug-in now.") + message(WARNING "Only Linux and Darwin platforms support WasmEdge_OpenCVMini plug-in now.") endif() endif() +# WasmEdge plug-in: Process. if(WASMEDGE_PLUGIN_PROCESS) # Only Linux systems support wasmedge_process now. if(CMAKE_SYSTEM_NAME MATCHES "Linux") @@ -53,19 +85,7 @@ if(WASMEDGE_PLUGIN_PROCESS) endif() endif() -if(WASMEDGE_PLUGIN_WASM_BPF) - # Only Linux systems support wasm_bpf now. - if(CMAKE_SYSTEM_NAME MATCHES "Linux") - add_subdirectory(wasm_bpf) - else() - message(WARNING "Only Linux platforms support wasm_bpf plug-in now.") - endif() -endif() - -if(WASMEDGE_PLUGIN_WASI_OCR) - add_subdirectory(wasi_ocr) -endif() - +# WasmEdge plug-in: Stable-diffusion. if(WASMEDGE_PLUGIN_STABLEDIFFUSION) # Only Linux and MacOS support wasmedge_stablediffusion now. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") @@ -75,27 +95,27 @@ if(WASMEDGE_PLUGIN_STABLEDIFFUSION) endif() endif() -if(WASMEDGE_PLUGIN_OPENCVMINI) - # Only Linux and MacOS support wasmedge_opencvmini now. +# WasmEdge plug-in: TensorFlow. +if(WASMEDGE_PLUGIN_TENSORFLOW) + # Only Linux and MacOS support wasmedge_tensorflow now. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") - add_subdirectory(wasmedge_opencvmini) + add_subdirectory(wasmedge_tensorflow) else() - message(WARNING "Only Linux and Darwin platforms support WasmEdge_OpenCVMini plug-in now.") + message(WARNING "Only Linux and Darwin platforms support WasmEdge_Tensorflow plug-in now.") endif() endif() -if(WASMEDGE_PLUGIN_WASI_NN_BURNRS_MODEL) - add_subdirectory(wasi_nn_burnrs) +# WasmEdge plug-in: TensorFlow-Lite. +if(WASMEDGE_PLUGIN_TENSORFLOWLITE) + # Only Linux and MacOS support wasmedge_tensorflowlite now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_tensorflowlite) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_TensorflowLite plug-in now.") + endif() endif() +# WasmEdge plug-in: zlib. if(WASMEDGE_PLUGIN_ZLIB) add_subdirectory(wasmedge_zlib) endif() - -if(WASMEDGE_PLUGIN_FFMPEG) - add_subdirectory(wasmedge_ffmpeg) -endif() - -if(WASMEDGE_PLUGIN_LLMC) - add_subdirectory(wasmedge_llmc) -endif() diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index a3f35f49..457b1653 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -1,66 +1,103 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2024 Second State INC -if(WASMEDGE_PLUGIN_FFMPEG) - add_subdirectory(wasmedge_ffmpeg) -endif() - -if(WASMEDGE_PLUGIN_PROCESS) - if(CMAKE_SYSTEM_NAME MATCHES "Linux") - add_subdirectory(wasmedge_process) - endif() +# WASI plug-in: WASI-Crypto proposal. +if(WASMEDGE_PLUGIN_WASI_CRYPTO) + add_subdirectory(wasi_crypto) endif() -if(WASMEDGE_PLUGIN_ZLIB) - add_subdirectory(wasmedge_zlib) -endif() +# WASI plug-in: WASI-Logging proposal. +add_subdirectory(wasi_logging) +# WASI plug-in: WASI-NN proposal with backends. if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) add_subdirectory(wasi_nn) endif() -if(WASMEDGE_PLUGIN_WASI_CRYPTO) - add_subdirectory(wasi_crypto) -endif() - -if(WASMEDGE_PLUGIN_TENSORFLOW) - if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") - add_subdirectory(wasmedge_tensorflow) +# WasmEdge plug-in: wasm-bpf. +if(WASMEDGE_PLUGIN_WASM_BPF) + # Only Linux systems support wasm_bpf now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux") + add_subdirectory(wasm_bpf) + else() + message(WARNING "Only Linux platforms support wasm_bpf plug-in now.") endif() endif() -if(WASMEDGE_PLUGIN_TENSORFLOWLITE) - if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") - add_subdirectory(wasmedge_tensorflowlite) - endif() +# WasmEdge plug-in: ffmpeg. +if(WASMEDGE_PLUGIN_FFMPEG) + add_subdirectory(wasmedge_ffmpeg) endif() +# WasmEdge plug-in: Image. if(WASMEDGE_PLUGIN_IMAGE) + # Only Linux and MacOS support wasmedge_image now. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_image) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_Image plug-in now.") endif() endif() +# WasmEdge plug-in: LLMC. +if(WASMEDGE_PLUGIN_LLMC) + add_subdirectory(wasmedge_llmc) +endif() + +# WasmEdge plug-in: OpenCV-mini. if(WASMEDGE_PLUGIN_OPENCVMINI) + # Only Linux and MacOS support wasmedge_opencvmini now. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_opencvmini) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_OpenCVMini plug-in now.") endif() endif() -if(WASMEDGE_PLUGIN_WASM_BPF) +# WasmEdge plug-in: Process. +if(WASMEDGE_PLUGIN_PROCESS) + # Only Linux systems support wasmedge_process now. if(CMAKE_SYSTEM_NAME MATCHES "Linux") - add_subdirectory(wasm_bpf) + add_subdirectory(wasmedge_process) + else() + message(WARNING "Only Linux platforms support WasmEdge_Process plug-in now.") endif() endif() +# WasmEdge plug-in: Stable-diffusion. if(WASMEDGE_PLUGIN_STABLEDIFFUSION) - add_subdirectory(wasmedge_stablediffusion) + # Only Linux and MacOS support wasmedge_stablediffusion now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_stablediffusion) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_StableDiffusion plug-in now.") + endif() endif() -add_subdirectory(wasi_logging) +# WasmEdge plug-in: TensorFlow. +if(WASMEDGE_PLUGIN_TENSORFLOW) + # Only Linux and MacOS support wasmedge_tensorflow now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_tensorflow) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_Tensorflow plug-in now.") + endif() +endif() -if(WASMEDGE_PLUGIN_LLMC) - add_subdirectory(wasi_llm) +# WasmEdge plug-in: TensorFlow-Lite. +if(WASMEDGE_PLUGIN_TENSORFLOWLITE) + # Only Linux and MacOS support wasmedge_tensorflowlite now. + if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") + add_subdirectory(wasmedge_tensorflowlite) + else() + message(WARNING "Only Linux and Darwin platforms support WasmEdge_TensorflowLite plug-in now.") + endif() +endif() + +# WasmEdge plug-in: zlib. +if(WASMEDGE_PLUGIN_ZLIB) + add_subdirectory(wasmedge_zlib) endif() +# Plug-in unit tests. add_subdirectory(unittest) From b8e5f797ccf479b48513402200bb337a8a376b74 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 4 Sep 2024 20:39:39 +0800 Subject: [PATCH 427/623] [Plugin] Refine the include directories in stable-diffusion. Signed-off-by: YiYing He --- plugins/wasmedge_stablediffusion/CMakeLists.txt | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 8c7d30db..9936d3f6 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -49,7 +49,11 @@ wasmedge_add_library(wasmedgePluginWasmEdgeStableDiffusion sd_module.cpp ) -target_link_libraries(wasmedgePluginWasmEdgeStableDiffusion PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(wasmedgePluginWasmEdgeStableDiffusion + PRIVATE + stable-diffusion + ${CMAKE_THREAD_LIBS_INIT} +) target_compile_options(wasmedgePluginWasmEdgeStableDiffusion PUBLIC @@ -73,9 +77,13 @@ target_include_directories(wasmedgePluginWasmEdgeStableDiffusion $ ${CMAKE_CURRENT_SOURCE_DIR} ) -target_include_directories( wasmedgePluginWasmEdgeStableDiffusion SYSTEM PRIVATE - "${CMAKE_BINARY_DIR}/_deps/stable-diffusion-src/thirdparty" + +target_include_directories(wasmedgePluginWasmEdgeStableDiffusion + SYSTEM + PRIVATE + "${stable-diffusion_SOURCE_DIR}/thirdparty" ) + if (MSVC) target_compile_options( stable-diffusion From 088a4ba89ef5c2678a1a021851b81152b8e06818 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 5 Sep 2024 16:18:28 +0800 Subject: [PATCH 428/623] [CI] Fix the CI for WASI-NN burn.rs backend. Signed-off-by: YiYing He --- plugins/wasi_nn_burnrs/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn_burnrs/CMakeLists.txt b/plugins/wasi_nn_burnrs/CMakeLists.txt index f7967a46..ad7c1344 100644 --- a/plugins/wasi_nn_burnrs/CMakeLists.txt +++ b/plugins/wasi_nn_burnrs/CMakeLists.txt @@ -16,7 +16,7 @@ set(RS_SO ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR}/libwasmedgePluginWasiNN${CMA set(WASMEDGE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/api) -add_custom_target(wasi_nn_burnrs ALL +add_custom_target(wasmedgePluginWasiNNBurnRS ALL COMMAND WASMEDGE_LIB_DIR=${WASMEDGE_LIB_DIR} LD_LIBARAY_PATH=${WASMEDGE_LIB_DIR} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD} ${CARGO_FEATURES} COMMAND ${CMAKE_COMMAND} -E copy ${RS_SO} ${CMAKE_CURRENT_BINARY_DIR} COMMAND ${CMAKE_COMMAND} -E remove_directory ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR} From 57454a974c1069e6e35669a03258a7fd4714006b Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 9 Sep 2024 06:11:14 +0800 Subject: [PATCH 429/623] [WASI-NN] neural speed: removed due to the upstream end-of-life (#3745) Ref: https://github.com/intel/neural-speed Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 5 +- plugins/wasi_nn/neuralspeed.cpp | 322 +-------------------------- plugins/wasi_nn/neuralspeed.h | 31 --- test/plugins/wasi_nn/CMakeLists.txt | 7 +- test/plugins/wasi_nn/wasi_nn.cpp | 265 ---------------------- utils/wasi-nn/install-neuralspeed.sh | 13 -- 6 files changed, 8 insertions(+), 635 deletions(-) delete mode 100644 utils/wasi-nn/install-neuralspeed.sh diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index cf0af84a..328ba9dc 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -149,7 +149,10 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml/src/ggml-common.h ggml-common.h ) endif() - elseif(BACKEND STREQUAL "neuralspeed" OR BACKEND STREQUAL "chattts") + elseif(BACKEND STREQUAL "neuralspeed") + message(NOTICE "WASI-NN NeuralSpeed backend is removed due to the upstream end-of-life.") + message(NOTICE "Reference: https://github.com/intel/neural-speed") + elseif(BACKEND STREQUAL "chattts") wasmedge_setup_simdjson() find_package(Python3 COMPONENTS Interpreter Development) diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/neuralspeed.cpp index ef4f11eb..0d29c59c 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/neuralspeed.cpp @@ -4,327 +4,12 @@ #include "neuralspeed.h" #include "wasinnenv.h" -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED -#include "simdjson.h" - -#if !defined(_WIN32) && !defined(_WIN64) && !defined(__WIN32__) && \ - !defined(__TOS_WIN__) && !defined(__WINDOWS__) -#include -#endif -#include -#include -#endif - namespace WasmEdge::Host::WASINN::NeuralSpeed { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED - -#if defined(_WIN32) || defined(_WIN64) || defined(__WIN32__) || \ - defined(__TOS_WIN__) || defined(__WINDOWS__) -HINSTANCE SharedLib = LoadLibrary(PYTHON_LIB_PATH); -#else -void *SharedLib = dlopen(PYTHON_LIB_PATH, RTLD_GLOBAL | RTLD_NOW); -#endif - -void printImformation(Graph &GraphRef, Context &CxtRef) { - spdlog::info("[WASI-NN] Neural speed backend: Number of input tokens: {}"sv, - CxtRef.Inputs.size()); - spdlog::info("[WASI-NN] Neural speed backend: Number of Output tokens: {}"sv, - CxtRef.Outputs.size()); - spdlog::info("[WASI-NN] Neural speed backend: Load time: {} ms"sv, - GraphRef.LoadTime); - spdlog::info("[WASI-NN] Neural speed backend: Compute time: {} ms "sv, - GraphRef.ComputeTime); -} - -Expect load(WASINN::WasiNNEnvironment &Env, - Span> Builders, WASINN::Device, - uint32_t &GraphId) noexcept { - // Add a new graph. - Env.NNGraph.emplace_back(Backend::NeuralSpeed); - auto &GraphRef = Env.NNGraph.back().get(); - const auto StartTime = std::chrono::steady_clock::now(); - // Initialize the plugin parameters. - GraphRef.EnableDebugLog = false; - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN] Neural speed backend: Load."sv); - } - - if (Builders.size() > 1) { - std::string Metadata = std::string( - reinterpret_cast(Builders[1].data()), Builders[1].size()); - simdjson::dom::parser Parser; - simdjson::dom::element Doc; - auto ParseError = Parser.parse(Metadata).get(Doc); - if (ParseError) { - spdlog::error("[WASI-NN] Neural speed backend: Parse metadata error"sv); - Env.NNGraph.pop_back(); - return ErrNo::InvalidEncoding; - } - if (Doc.at_key("model_type").error() == simdjson::SUCCESS) { - std::string_view ModelType; - auto Err = Doc["model_type"].get().get(ModelType); - if (Err) { - spdlog::error( - "[WASI-NN] Neural speed backend: Unable to retrieve the model_type option."sv); - Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; - } - GraphRef.ModelType = ModelType; - } - } - - // Handle the model path. - auto Weight = Builders[0]; - const std::string BinModel(reinterpret_cast(Weight.data()), - Weight.size()); - spdlog::info("[WASI-NN] Neural speed: BinModel: {}"sv, BinModel.size()); - std::string ModelFilePath; - if (BinModel.substr(0, 8) == "preload:") { - ModelFilePath = BinModel.substr(8); - } else { - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN] Neural speed: Model path not found in nn-preload, " - "write model into a tmpfile."sv); - } - // Write neural speed model to file. - ModelFilePath = "neural-speed-model.bin"sv; - std::ofstream TempFile(ModelFilePath, std::ios::binary); - TempFile.imbue( - std::locale(TempFile.getloc(), new std::codecvt_utf8)); - if (!TempFile) { - spdlog::error( - "[WASI-NN] Neural speed: Failed to create the temporary file. " - "Currently, our workaround involves creating a temporary model."sv); - Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; - } - TempFile << BinModel; - TempFile.close(); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN] Neural speed: Write model into a tmpfile...Done"sv); - } - } - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN] Neural speed: Finished handling model path."sv); - } - - // Create Model class - if (!Py_IsInitialized()) { - Py_Initialize(); - } - if (GraphRef.NeuralSpeedModule == nullptr) { - GraphRef.NeuralSpeedModule = PyImport_ImportModule("neural_speed"); - } - if (GraphRef.NeuralSpeedModule == nullptr) { - PyErr_Print(); - spdlog::error( - "[WASI-NN] neural speed backend: Can not find neural speed library."sv); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::RuntimeError; - } - if (GraphRef.ModelClass == nullptr) { - GraphRef.ModelClass = - PyObject_GetAttrString(GraphRef.NeuralSpeedModule, "Model"); - } - if (GraphRef.ModelClass == nullptr || - !PyCallable_Check(GraphRef.ModelClass)) { - spdlog::error( - "[WASI-NN] neural speed backend: Can not find Model class in neural speed."sv); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::RuntimeError; - } - GraphRef.Model = PyObject_CallObject(GraphRef.ModelClass, NULL); - if (GraphRef.Model == nullptr) { - spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::InvalidArgument; - } - PyObject *LoadResult = - PyObject_CallMethod(GraphRef.Model, "init_from_bin", "(ss)", - GraphRef.ModelType.c_str(), ModelFilePath.c_str()); - if (LoadResult == nullptr) { - spdlog::error("[WASI-NN] neural speed backend: Load model error."sv); - Py_XDECREF(GraphRef.Model); - Env.NNGraph.pop_back(); - return WASINN::ErrNo::InvalidArgument; - } - Py_XDECREF(LoadResult); - GraphRef.LoadTime = std::chrono::duration_cast( - std::chrono::steady_clock::now() - StartTime) - .count(); - - // Store the loaded graph. - GraphId = Env.NNGraph.size() - 1; - - return WASINN::ErrNo::Success; -} - -Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, - uint32_t &ContextId) noexcept { - if (!Py_IsInitialized()) { - spdlog::error( - "[WASI-NN] Neural speed backend: Model has been realse, please reload it."sv); - return WASINN::ErrNo::RuntimeError; - } - Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); - ContextId = Env.NNContext.size() - 1; - return ErrNo::Success; -} - -Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, - uint32_t, const TensorData &Tensor) noexcept { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - if (!Py_IsInitialized()) { - spdlog::error( - "[WASI-NN] Neural speed backend: Model has been realse, please reload it."sv); - return WASINN::ErrNo::RuntimeError; - } - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN] Neural speed backend: setInput"sv); - } - - // Set the input. - if (Tensor.Tensor.size() % sizeof(long long int) != 0) { - spdlog::error("[WASI-NN] neural speed backend: Input tensor size is not a " - "multiple of " - "4 bytes."sv); - return WASINN::ErrNo::InvalidArgument; - } - const std::vector Prompt{ - reinterpret_cast(Tensor.Tensor.data()), - reinterpret_cast(Tensor.Tensor.data() + - Tensor.Tensor.size())}; - CxtRef.Inputs.clear(); - CxtRef.Inputs = Prompt; - - return WASINN::ErrNo::Success; -} - -Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, - uint32_t, Span OutBuffer, - uint32_t &BytesWritten) noexcept { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN] Neural speed backend: getOutput"sv); - } - std::string StringTmp(reinterpret_cast(CxtRef.Outputs.data()), - CxtRef.Outputs.size() * sizeof(long long int)); - std::copy_n(StringTmp.data(), StringTmp.length(), OutBuffer.data()); - BytesWritten = StringTmp.length(); - return WASINN::ErrNo::Success; -} - -Expect compute(WasiNNEnvironment &Env, - uint32_t ContextId) noexcept { - if (!Py_IsInitialized()) { - spdlog::error( - "[WASI-NN]Neural speed backend: Model has been realse, please reload it."sv); - return WASINN::ErrNo::RuntimeError; - } - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - const auto StartTime = std::chrono::steady_clock::now(); - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN] Neural speed backend: compute"sv); - } - if (CxtRef.Inputs.size() == 0) { - spdlog::error("[WASI-NN] Neural speed backend: Llama input is not set!"sv); - return ErrNo::InvalidArgument; - } - CxtRef.Outputs.clear(); - PyObject *TensorList = PyList_New(0); - for (size_t Cnt = 0; Cnt < CxtRef.Inputs.size(); ++Cnt) { - PyObject *Num = PyLong_FromLong(CxtRef.Inputs[Cnt]); - PyList_Append(TensorList, Num); - Py_DECREF(Num); - } - PyObject *TmpArg = PyList_New(0); - PyList_Append(TmpArg, TensorList); - PyObject *TorchModule = PyImport_ImportModule("torch"); - if (TorchModule == nullptr) { - spdlog::error( - "[WASI-NN] neural speed backend: Can not find torch library."sv); - return WASINN::ErrNo::RuntimeError; - } - PyObject *LongTensorFunc = PyObject_GetAttrString(TorchModule, "LongTensor"); - PyObject *LongTensorArgs = PyTuple_Pack(1, TmpArg); - PyObject *LongTensor = PyObject_CallObject(LongTensorFunc, LongTensorArgs); - Py_DECREF(TensorList); - Py_DECREF(TorchModule); - Py_DECREF(LongTensorFunc); - Py_DECREF(TmpArg); - Py_DECREF(LongTensorArgs); - if (LongTensor == nullptr) { - spdlog::error( - "[WASI-NN] neural speed backend: Input transfer tensor failed."sv); - return WASINN::ErrNo::InvalidArgument; - } - PyObject *GenerateString = PyUnicode_FromString("generate"); - PyObject *Result = PyObject_CallMethodObjArgs(GraphRef.Model, GenerateString, - LongTensor, NULL); - Py_DECREF(GenerateString); - if (Result == nullptr) { - PyErr_Print(); - spdlog::error( - "[WASI-NN] neural speed backend: Neural Speed runtime error."sv); - return WASINN::ErrNo::RuntimeError; - } - if (PyList_Check(Result)) { - const Py_ssize_t OuterSize = PyList_Size(Result); - for (Py_ssize_t OutterCnt = 0; OutterCnt < OuterSize; ++OutterCnt) { - PyObject *InnerList = PyList_GetItem(Result, OutterCnt); - if (PyList_Check(InnerList)) { - std::vector InnerVec; - const Py_ssize_t InnerSize = PyList_Size(InnerList); - for (Py_ssize_t InnerCnt = 0; InnerCnt < InnerSize; ++InnerCnt) { - PyObject *Num = PyList_GetItem(InnerList, InnerCnt); - if (PyLong_Check(Num)) { - InnerVec.push_back(PyLong_AsLong(Num)); - } - } - CxtRef.Outputs = InnerVec; - } - } - } - Py_DECREF(Result); - Py_DECREF(LongTensor); - GraphRef.ComputeTime = std::chrono::duration_cast( - std::chrono::steady_clock::now() - StartTime) - .count(); - if (GraphRef.EnableDebugLog) { - printImformation(GraphRef, CxtRef); - } - return WASINN::ErrNo::Success; -} - -Expect unload(WASINN::WasiNNEnvironment &Env, - uint32_t ContextId) noexcept { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN] Neural speed backend: start unload."sv); - } - if (Py_IsInitialized()) { - Py_XDECREF(GraphRef.Model); - Py_XDECREF(GraphRef.ModelClass); - Py_XDECREF(GraphRef.NeuralSpeedModule); - GraphRef.ModelClass = nullptr; - GraphRef.NeuralSpeedModule = nullptr; - Py_Finalize(); - } - return WASINN::ErrNo::Success; -} - -#else namespace { Expect reportBackendNotSupported() noexcept { - spdlog::error( - "[WASI-NN] Neural speed backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"NeuralSpeed\" to build it."sv); + spdlog::error("[WASI-NN] Neural Speed backend is removed due to the upstream " + "end-of-life. " + "Reference: https://github.com/intel/neural-speed"sv); return WASINN::ErrNo::InvalidArgument; } } // namespace @@ -352,5 +37,4 @@ Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { Expect unload(WASINN::WasiNNEnvironment &, uint32_t) noexcept { return reportBackendNotSupported(); } -#endif } // namespace WasmEdge::Host::WASINN::NeuralSpeed diff --git a/plugins/wasi_nn/neuralspeed.h b/plugins/wasi_nn/neuralspeed.h index 7893aa74..b74b3c27 100644 --- a/plugins/wasi_nn/neuralspeed.h +++ b/plugins/wasi_nn/neuralspeed.h @@ -6,46 +6,15 @@ #include "plugin/plugin.h" #include "types.h" -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED -#include -#endif - namespace WasmEdge::Host::WASINN { struct WasiNNEnvironment; } namespace WasmEdge::Host::WASINN::NeuralSpeed { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED -struct Graph { - bool EnableDebugLog = true; - std::string ModelType = "llama"; - inline static int GraphNumber = 0; - Graph() noexcept { Py_Initialize(); } - ~Graph() noexcept { - if (Py_IsInitialized()) { - Py_XDECREF(Model); - Py_XDECREF(ModelClass); - Py_XDECREF(NeuralSpeedModule); - } - } - PyObject *Model = nullptr; - PyObject *NeuralSpeedModule = nullptr; - PyObject *ModelClass = nullptr; - int64_t LoadTime = 0; - int64_t ComputeTime = 0; -}; -struct Context { - Context(size_t Gid, Graph &) noexcept : GraphId(Gid) {} - size_t GraphId; - std::vector Inputs; - std::vector Outputs; -}; -#else struct Graph {}; struct Context { Context(size_t, Graph &) noexcept {} }; -#endif struct Environ {}; diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 03669f6f..88fad025 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -71,12 +71,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) ) endif() elseif(BACKEND STREQUAL "neuralspeed") - message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures") - download( - https://huggingface.co/grorge123/phi-2-GPTQ/resolve/main/ne_phi_q_nf4_bestla_cfp32_g32.bin - ${CMAKE_CURRENT_BINARY_DIR}/wasinn_neural_speed_fixtures/ne_phi_q_nf4_bestla_cfp32_g32.bin - MD5=5e055b41f8cc1a42f26ff8742719ef1e - ) + message(NOTICE "Neural Speed backend is removed due to the upstream end-of-life.") elseif(BACKEND STREQUAL "piper") message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures") download( diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 4d0f4f40..0ed4990b 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -24,7 +24,6 @@ using WasmEdge::Host::WASINN::ErrNo; defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML) || \ - defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) @@ -1981,270 +1980,6 @@ TEST(WasiNNTest, GGMLBackendComputeSingleWithRPC) { #endif // WASMEDGE_BUILD_WASI_NN_RPC #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED -TEST(WasiNNTest, NeuralSpeedBackend) { - // Create the wasi_nn module instance. - auto NNMod = createModule(); - ASSERT_TRUE(NNMod); - - // Create the calling frame with memory instance. - WasmEdge::Runtime::Instance::ModuleInstance Mod(""); - Mod.addHostMemory( - "memory", std::make_unique( - WasmEdge::AST::MemoryType(60000))); - auto *MemInstPtr = Mod.findMemoryExports("memory"); - ASSERT_TRUE(MemInstPtr != nullptr); - auto &MemInst = *MemInstPtr; - WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); - - // Load the files. - std::vector Prompt = {7454, 2402, 257, 640, 11, 612, - 11196, 257, 1310, 2576, 11}; - std::string tmp(reinterpret_cast(Prompt.data()), - Prompt.size() * sizeof(long long int)); - std::vector TensorData(tmp.begin(), tmp.end()); - std::vector WeightRead = readEntireFile( - "./wasinn_neural_speed_fixtures/ne_phi_q_nf4_bestla_cfp32_g32.bin"); - - std::vector TensorDim{1}; - uint32_t BuilderPtr = UINT32_C(0); - uint32_t LoadEntryPtr = UINT32_C(0); - uint32_t SetInputEntryPtr = UINT32_C(0); - uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); - uint32_t StorePtr = UINT32_C(65536); - - // Return value. - std::array Errno = {UINT32_C(0)}; - - // Get the function "load". - auto *FuncInst = NNMod->findFuncExports("load"); - EXPECT_NE(FuncInst, nullptr); - EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncLoad = - dynamic_cast(FuncInst->getHostFunc()); - // Get the function "init_execution_context". - FuncInst = NNMod->findFuncExports("init_execution_context"); - EXPECT_NE(FuncInst, nullptr); - EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncInit = dynamic_cast( - FuncInst->getHostFunc()); - // Get the function "set_input". - FuncInst = NNMod->findFuncExports("set_input"); - EXPECT_NE(FuncInst, nullptr); - EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncSetInput = - dynamic_cast(FuncInst->getHostFunc()); - // Get the function "get_output". - FuncInst = NNMod->findFuncExports("get_output"); - EXPECT_NE(FuncInst, nullptr); - EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncGetOutput = - dynamic_cast(FuncInst->getHostFunc()); - // Get the function "compute". - FuncInst = NNMod->findFuncExports("compute"); - EXPECT_NE(FuncInst, nullptr); - EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncCompute = - dynamic_cast(FuncInst->getHostFunc()); - // Get the function "unload". - FuncInst = NNMod->findFuncExports("unload"); - EXPECT_NE(FuncInst, nullptr); - EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncUnload = - dynamic_cast(FuncInst->getHostFunc()); - - // Neural Speed WASI-NN load tests. - // Test: load -- meaningless binaries. - { - EXPECT_TRUE( - HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::NeuralSpeed), - static_cast(Device::CPU), BuilderPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), - static_cast(ErrNo::InvalidArgument)); - } - // Test: load -- graph id ptr out of bounds. - { - EXPECT_TRUE( - HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::NeuralSpeed), - static_cast(Device::CPU), OutBoundPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), - static_cast(ErrNo::InvalidArgument)); - } - - // Test: load -- graph builder ptr out of bounds. - { - EXPECT_TRUE( - HostFuncLoad.run(CallFrame, - std::initializer_list{ - OutBoundPtr, UINT32_C(1), - static_cast(Backend::NeuralSpeed), - static_cast(Device::CPU), BuilderPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), - static_cast(ErrNo::InvalidArgument)); - } - - // Test: load -- Neural Speed model bin ptr out of bounds. - BuilderPtr = LoadEntryPtr; - writeFatPointer(MemInst, OutBoundPtr, - static_cast(WeightRead.size()), BuilderPtr); - { - EXPECT_TRUE( - HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(1), - static_cast(Backend::NeuralSpeed), - static_cast(Device::CPU), BuilderPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), - static_cast(ErrNo::InvalidArgument)); - } - // Test: load -- load successfully. - std::string Config = "{\"model_type\":\"phi\"}"; - std::vector ConfigData(Config.begin(), Config.end()); - BuilderPtr = LoadEntryPtr; - writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); - writeFatPointer(MemInst, StorePtr + WeightRead.size(), ConfigData.size(), - BuilderPtr); - writeBinaries(MemInst, WeightRead, StorePtr); - writeBinaries(MemInst, ConfigData, StorePtr + WeightRead.size()); - StorePtr += WeightRead.size() + ConfigData.size(); - { - EXPECT_TRUE( - HostFuncLoad.run(CallFrame, - std::initializer_list{ - LoadEntryPtr, UINT32_C(2), - static_cast(Backend::NeuralSpeed), - static_cast(Device::CPU), BuilderPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); - EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); - BuilderPtr += 4; - } - - // Neural Speed WASI-NN init_execution_context tests. - // Test: init_execution_context -- graph id invalid. - { - EXPECT_TRUE(HostFuncInit.run( - CallFrame, - std::initializer_list{UINT32_C(2), BuilderPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), - static_cast(ErrNo::InvalidArgument)); - } - - // Test: init_execution_context -- init second context. - { - EXPECT_TRUE(HostFuncInit.run( - CallFrame, - std::initializer_list{UINT32_C(0), BuilderPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); - EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); - BuilderPtr += 4; - } - - // Neural Speed WASI-NN set_input tests. - SetInputEntryPtr = BuilderPtr; - writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(1), BuilderPtr); - writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), - BuilderPtr); - writeBinaries(MemInst, TensorDim, StorePtr); - writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); - - // Test: set_input -- context id exceeds. - { - EXPECT_TRUE( - HostFuncSetInput.run(CallFrame, - std::initializer_list{ - UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), - static_cast(ErrNo::InvalidArgument)); - } - // Test: set_input -- set input successfully. - { - EXPECT_TRUE( - HostFuncSetInput.run(CallFrame, - std::initializer_list{ - UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); - } - StorePtr += (TensorDim.size() * 4 + TensorData.size()); - - // Neural Speed WASI-NN compute tests. - // Test: compute -- context id exceeds. - { - EXPECT_TRUE(HostFuncCompute.run( - CallFrame, std::initializer_list{UINT32_C(3)}, - Errno)); - EXPECT_EQ(Errno[0].get(), - static_cast(ErrNo::InvalidArgument)); - } - // Test: compute -- compute successfully. - { - EXPECT_TRUE(HostFuncCompute.run( - CallFrame, std::initializer_list{UINT32_C(0)}, - Errno)); - EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); - } - - // Neural Speed WASI-NN get_output tests. - // Test: get_output -- output bytes ptr out of bounds. - { - EXPECT_TRUE(HostFuncGetOutput.run( - CallFrame, - std::initializer_list{ - UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), - static_cast(ErrNo::InvalidArgument)); - } - - // Test: get_output -- output buffer ptr out of bounds. - { - EXPECT_TRUE(HostFuncGetOutput.run( - CallFrame, - std::initializer_list{ - UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), - static_cast(ErrNo::InvalidArgument)); - } - // Test: get_output -- get output successfully. - { - EXPECT_TRUE(HostFuncGetOutput.run( - CallFrame, - std::initializer_list{ - UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, - Errno)); - EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); - // Should output more than 50 bytes. - auto BytesWritten = *MemInst.getPointer(BuilderPtr); - EXPECT_GE(BytesWritten, 50); - } - - // Neural Speed WASI-NN unload tests. - // Test: unload -- unload successfully. - { - EXPECT_TRUE(HostFuncUnload.run( - CallFrame, std::initializer_list{UINT32_C(0)}, - Errno)); - EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); - } -} -#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_NEURAL_SPEED - #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER TEST(WasiNNTest, WhisperBackend) { // Create the wasi_nn module instance. diff --git a/utils/wasi-nn/install-neuralspeed.sh b/utils/wasi-nn/install-neuralspeed.sh deleted file mode 100644 index 8b8d53df..00000000 --- a/utils/wasi-nn/install-neuralspeed.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env bash -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2024 Second State INC - -set -e -echo "Installing Python library!" -apt update -apt install -y python3-dev python3-pip - -echo "Installing Neural Speed!" -wget https://raw.githubusercontent.com/intel/neural-speed/main/requirements.txt -pip install -r requirements.txt -pip install neural-speed==${NEURALSPEED_VERSION} From 8c559e2f3b521c439c116d5eb07b0945688bbb5e Mon Sep 17 00:00:00 2001 From: Jun Zhang Date: Thu, 12 Sep 2024 17:07:53 +0800 Subject: [PATCH 430/623] [Plugin] llmc: Support training GPT2 on GPU (CUDA) (#3750) Signed-off-by: Jun Zhang --- plugins/wasmedge_llmc/CMakeLists.txt | 24 ++++++++++++++++++----- plugins/wasmedge_llmc/llmc_env.cpp | 2 +- plugins/wasmedge_llmc/llmc_fwd.h | 2 +- test/plugins/wasmedge_llmc/CMakeLists.txt | 18 ++++++++++++----- 4 files changed, 34 insertions(+), 12 deletions(-) diff --git a/plugins/wasmedge_llmc/CMakeLists.txt b/plugins/wasmedge_llmc/CMakeLists.txt index 7a486e1b..7e9e4c59 100644 --- a/plugins/wasmedge_llmc/CMakeLists.txt +++ b/plugins/wasmedge_llmc/CMakeLists.txt @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2024 Second State INC -# TODO: Fetch llm.c source. - wasmedge_add_library(wasmedgePluginWasmEdgeLLMC SHARED llmc_func.cpp @@ -10,17 +8,33 @@ wasmedge_add_library(wasmedgePluginWasmEdgeLLMC llmc_env.cpp ) +option(WASMEDGE_PLUGIN_LLMC_CUDA "Training GPT2 with CUDA" OFF) + message(STATUS "Start fetching llm.c source") include(FetchContent) + +if (WASMEDGE_PLUGIN_LLMC_CUDA) + set(CUDALIB ON) + message(STATUS "Build wasmedge_llmc with CUDA backend") +else() + message(STATUS "Build wasmedge_llmc with CPU backend") +endif() + FetchContent_Declare( llmc GIT_REPOSITORY https://github.com/WasmEdge/llm.c ) FetchContent_MakeAvailable(llmc) -target_link_libraries(wasmedgePluginWasmEdgeLLMC PRIVATE - train_gpt2_cpu -) +if (WASMEDGE_PLUGIN_LLMC_CUDA) + target_link_libraries(wasmedgePluginWasmEdgeLLMC PRIVATE + train_gpt2_cuda + ) +else() + target_link_libraries(wasmedgePluginWasmEdgeLLMC PRIVATE + train_gpt2_cpu + ) +endif() target_compile_options(wasmedgePluginWasmEdgeLLMC PUBLIC diff --git a/plugins/wasmedge_llmc/llmc_env.cpp b/plugins/wasmedge_llmc/llmc_env.cpp index d254ebce..a960ec16 100644 --- a/plugins/wasmedge_llmc/llmc_env.cpp +++ b/plugins/wasmedge_llmc/llmc_env.cpp @@ -41,7 +41,7 @@ DataLoader *LLMCEnv::getDataLoader(uint32_t Id) noexcept { LLMCEnv::~LLMCEnv() { for (GPT2 *M : Models) { - gpt2_free(M); + gpt2_destroy(M); } for (DataLoader *DL : DataLoaders) { dataloader_destroy(DL); diff --git a/plugins/wasmedge_llmc/llmc_fwd.h b/plugins/wasmedge_llmc/llmc_fwd.h index 804893eb..afa39f10 100644 --- a/plugins/wasmedge_llmc/llmc_fwd.h +++ b/plugins/wasmedge_llmc/llmc_fwd.h @@ -13,7 +13,7 @@ struct DataLoader; GPT2 *gpt2_create(const char *checkpoint_path); -void gpt2_free(GPT2 *model); +void gpt2_destroy(GPT2 *model); DataLoader *dataloader_create(const char *filename_pattern, size_t B, size_t T, int process_rank, int num_processes, diff --git a/test/plugins/wasmedge_llmc/CMakeLists.txt b/test/plugins/wasmedge_llmc/CMakeLists.txt index 820e6503..7fce5e8a 100644 --- a/test/plugins/wasmedge_llmc/CMakeLists.txt +++ b/test/plugins/wasmedge_llmc/CMakeLists.txt @@ -43,11 +43,19 @@ function(download URL OUTPUT HASH) endfunction() message(STATUS "Downloading GPT2 model check point to ${CMAKE_CURRENT_BINARY_DIR}/gpt2_124M.bin") -download( - https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M.bin - ${CMAKE_CURRENT_BINARY_DIR}/wasmedge_llmc/gpt2_124M.bin - SHA256=3da8b207584030bcdcd207cf7a99952e3421dce92da218b351071857511bf162 -) +if (WASMEDGE_PLUGIN_LLMC_CUDA) + download( + https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M_bf16.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasmedge_llmc/gpt2_124M.bin + SHA256=6661f45628102b4c6e86835d9057b5ba2c024dbf9b81445175e258b7878a1a6f + ) +else() + download( + https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/gpt2_124M.bin + ${CMAKE_CURRENT_BINARY_DIR}/wasmedge_llmc/gpt2_124M.bin + SHA256=3da8b207584030bcdcd207cf7a99952e3421dce92da218b351071857511bf162 + ) +endif() message(STATUS "Downloading training dataset to ${CMAKE_CURRENT_BINARY_DIR}/tiny_shakespeare_train.bin") download( https://huggingface.co/datasets/karpathy/llmc-starter-pack/resolve/main/tiny_shakespeare_train.bin From afa3302fb2ecdb8898555b6e9f030568b2628949 Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Fri, 13 Sep 2024 14:33:10 +0800 Subject: [PATCH 431/623] [Plugin] WASI-NN Whisper: supporting metadata in set_input. (#3755) Signed-off-by: YiYing He --- plugins/wasi_nn/whispercpp.cpp | 163 +++++++++++++++++++++------------ plugins/wasi_nn/whispercpp.h | 28 ++++-- 2 files changed, 122 insertions(+), 69 deletions(-) diff --git a/plugins/wasi_nn/whispercpp.cpp b/plugins/wasi_nn/whispercpp.cpp index adf42071..3cc1452d 100644 --- a/plugins/wasi_nn/whispercpp.cpp +++ b/plugins/wasi_nn/whispercpp.cpp @@ -83,7 +83,7 @@ bool loadWAV(Span Buf, std::vector &PCMF32) { void WhisperLogCallback(ggml_log_level LogLevel, const char *LogText, void *UserData) { const Graph &GraphRef = *reinterpret_cast(UserData); - if (!GraphRef.EnableLog) { + if (!GraphRef.WhisperConfig.EnableLog) { return; } std::string Text(LogText); @@ -138,7 +138,25 @@ void WhisperOutputSegmentCallback(struct whisper_context *WhisperCtx, } } -Expect parseMetadata(Graph &GraphRef, +void setWhisperParams(Context &CxtRef) noexcept { + auto &WParam = CxtRef.WhisperParams; + auto &ConfigRef = CxtRef.WhisperConfig; + WParam.print_progress = false; + WParam.thold_pt = ConfigRef.WordThreshold; + WParam.translate = ConfigRef.Translate; + WParam.language = ConfigRef.SpokenLanguage.c_str(); + WParam.detect_language = ConfigRef.DetectLanguage; + WParam.initial_prompt = ConfigRef.InitialPrompt.c_str(); + WParam.temperature_inc = ConfigRef.TemperatureInc; + WParam.temperature = ConfigRef.Temperature; + WParam.entropy_thold = ConfigRef.EntropyThreshold; + WParam.logprob_thold = ConfigRef.LogprobThreshold; + WParam.grammar_penalty = ConfigRef.GrammarPenalty; + WParam.new_segment_callback = WhisperOutputSegmentCallback; + WParam.new_segment_callback_user_data = &CxtRef; +} + +Expect parseMetadata(Config &ConfigRef, const std::string &Metadata) noexcept { simdjson::dom::parser Parser; simdjson::dom::element Doc; @@ -160,7 +178,7 @@ Expect parseMetadata(Graph &GraphRef, // The plugin parameters. if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { - auto Err = Doc["enable-log"].get().get(GraphRef.EnableLog); + auto Err = Doc["enable-log"].get().get(ConfigRef.EnableLog); if (Err) { spdlog::error( "[WASI-NN] Whisper backend: Unable to retrieve the enable-log " @@ -169,7 +187,8 @@ Expect parseMetadata(Graph &GraphRef, } } if (Doc.at_key("enable-debug-log").error() == simdjson::SUCCESS) { - auto Err = Doc["enable-debug-log"].get().get(GraphRef.EnableDebugLog); + auto Err = + Doc["enable-debug-log"].get().get(ConfigRef.EnableDebugLog); if (Err) { spdlog::error( "[WASI-NN] Whisper backend: Unable to retrieve the enable-debug-log " @@ -178,7 +197,7 @@ Expect parseMetadata(Graph &GraphRef, } } if (Doc.at_key("translate").error() == simdjson::SUCCESS) { - auto Err = Doc["translate"].get().get(GraphRef.Translate); + auto Err = Doc["translate"].get().get(ConfigRef.Translate); if (Err) { spdlog::error( "[WASI-NN] Whisper backend: Unable to retrieve the translate " @@ -195,10 +214,10 @@ Expect parseMetadata(Graph &GraphRef, "option."sv); return ErrNo::InvalidArgument; } - GraphRef.SpokenLanguage = Language; + ConfigRef.SpokenLanguage = Language; } if (Doc.at_key("detect-language").error() == simdjson::SUCCESS) { - auto Err = Doc["detect-language"].get().get(GraphRef.DetectLanguage); + auto Err = Doc["detect-language"].get().get(ConfigRef.DetectLanguage); if (Err) { spdlog::error( "[WASI-NN] Whisper backend: Unable to retrieve the detect-language " @@ -214,7 +233,37 @@ Expect parseMetadata(Graph &GraphRef, "[WASI-NN] Whisper backend: Unable to retrieve the prompt option."sv); return ErrNo::InvalidArgument; } - GraphRef.InitialPrompt = Prompt; + ConfigRef.InitialPrompt = Prompt; + } + return ErrNo::Success; +} + +Expect handleTranslationConfig(whisper_context *WhisperCtx, + Config &ConfigRef) noexcept { + assuming(WhisperCtx); + + // Check the language. + if (ConfigRef.SpokenLanguage != "auto"sv && + whisper_lang_id(ConfigRef.SpokenLanguage.c_str()) == -1) { + spdlog::error("[WASI-NN] Whisper backend: Error: unknown language {}."sv, + ConfigRef.SpokenLanguage); + return ErrNo::InvalidArgument; + } + + // Check the translate option. + if (!whisper_is_multilingual(WhisperCtx)) { + if (ConfigRef.SpokenLanguage != "en"sv || ConfigRef.Translate) { + ConfigRef.SpokenLanguage = "en"sv; + ConfigRef.Translate = false; + if (ConfigRef.EnableLog) { + spdlog::info( + "[WASI-NN] Whisper backend: Model is not multilingual. Ignoring " + "language and translation options"sv); + } + } + } + if (ConfigRef.DetectLanguage) { + ConfigRef.SpokenLanguage = "auto"sv; } return ErrNo::Success; } @@ -230,7 +279,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize the parameters. auto CParam = whisper_context_default_params(); GraphRef.ModelFilePath = ""sv; - GraphRef.SpokenLanguage = "en"sv; + GraphRef.WhisperConfig.SpokenLanguage = "en"sv; GraphRef.UseGPU = CParam.use_gpu; GraphRef.MainGPU = CParam.gpu_device; @@ -242,7 +291,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, const std::string Metadata(reinterpret_cast(Builders[1].data()), Builders[1].size()); // Ignore context or model updates when initializing the graph. - auto Res = parseMetadata(GraphRef, Metadata); + auto Res = parseMetadata(GraphRef.WhisperConfig, Metadata); if (Res != ErrNo::Success) { spdlog::error("[WASI-NN] Whisper backend: Failed to parse metadata."sv); Env.NNGraph.pop_back(); @@ -251,7 +300,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, } // Handle the model path. - if (GraphRef.EnableDebugLog) { + if (GraphRef.WhisperConfig.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: Handling model path."sv); } auto Weight = Builders[0]; @@ -262,7 +311,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, } // Initialize whisper context from model file with parameters. - if (GraphRef.EnableDebugLog) { + if (GraphRef.WhisperConfig.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] Whisper backend: Initialize whisper context with " "given parameters"sv); @@ -281,35 +330,17 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } - if (GraphRef.EnableDebugLog) { + if (GraphRef.WhisperConfig.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] Whisper backend: Initialize whisper context with " "given parameters...Done"sv); } - // Check the language. - if (GraphRef.SpokenLanguage != "auto"sv && - whisper_lang_id(GraphRef.SpokenLanguage.c_str()) == -1) { - spdlog::error("[WASI-NN] Whisper backend: Error: unknown language {}."sv, - GraphRef.SpokenLanguage); + auto ResTranslateConfig = + handleTranslationConfig(GraphRef.WhisperCtx, GraphRef.WhisperConfig); + if (ResTranslateConfig != ErrNo::Success) { Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; - } - - // Check the translate option. - if (!whisper_is_multilingual(GraphRef.WhisperCtx)) { - if (GraphRef.SpokenLanguage != "en"sv || GraphRef.Translate) { - GraphRef.SpokenLanguage = "en"sv; - GraphRef.Translate = false; - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] Whisper backend: Model is not multilingual. Ignoring " - "language and translation options"sv); - } - } - } - if (GraphRef.DetectLanguage) { - GraphRef.SpokenLanguage = "auto"sv; + return ResTranslateConfig; } // Store the loaded graph. @@ -321,31 +352,17 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { auto &GraphRef = Env.NNGraph[GraphId].get(); - if (GraphRef.EnableDebugLog) { + if (GraphRef.WhisperConfig.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: initExecCtx"sv); } Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); ContextId = Env.NNContext.size() - 1; - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &WParam = CxtRef.WhisperParams; - WParam.print_progress = false; - WParam.thold_pt = GraphRef.WordThreshold; - WParam.translate = GraphRef.Translate; - WParam.language = GraphRef.SpokenLanguage.c_str(); - WParam.detect_language = GraphRef.DetectLanguage; - WParam.initial_prompt = GraphRef.InitialPrompt.c_str(); - WParam.temperature_inc = GraphRef.TemperatureInc; - WParam.temperature = GraphRef.Temperature; - WParam.entropy_thold = GraphRef.EntropyThreshold; - WParam.logprob_thold = GraphRef.LogprobThreshold; - WParam.grammar_penalty = GraphRef.GrammarPenalty; - WParam.new_segment_callback = WhisperOutputSegmentCallback; - WParam.new_segment_callback_user_data = &CxtRef; - if (GraphRef.EnableLog) { + setWhisperParams(Env.NNContext[ContextId].get()); + if (GraphRef.WhisperConfig.EnableLog) { spdlog::info("[WASI-NN] Whisper backend: whisper_system_info: {}"sv, whisper_print_system_info()); } - if (GraphRef.EnableDebugLog) { + if (GraphRef.WhisperConfig.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: initExecCtx...Done"sv); } return ErrNo::Success; @@ -355,11 +372,36 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t Index [[maybe_unused]], const TensorData &Tensor) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - if (GraphRef.EnableDebugLog) { + if (CxtRef.WhisperConfig.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: setInput"sv); } + // Use index 1 for metadata. + if (Index == 1) { + if (CxtRef.WhisperConfig.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: found Metadata, processing"sv); + } + const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + auto Res = parseMetadata(CxtRef.WhisperConfig, Metadata); + if (Res != ErrNo::Success) { + spdlog::error("[WASI-NN] Whisper backend: Failed to parse metadata."sv); + return Res; + } + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + Res = handleTranslationConfig(GraphRef.WhisperCtx, CxtRef.WhisperConfig); + if (Res != ErrNo::Success) { + return Res; + } + setWhisperParams(CxtRef); + if (CxtRef.WhisperConfig.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: found Metadata, processing...Done"sv); + } + return ErrNo::Success; + } + if (Tensor.Dimension.size() != 2) { spdlog::error("[WASI-NN] Tensor dimension is out of range, expect 2-dim, " "but got {}-dim.", @@ -385,7 +427,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, return WASINN::ErrNo::InvalidArgument; } - if (GraphRef.EnableDebugLog) { + if (CxtRef.WhisperConfig.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: setInput...Done"sv); } return ErrNo::Success; @@ -395,8 +437,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t Index, Span OutBuffer, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - if (GraphRef.EnableDebugLog) { + if (CxtRef.WhisperConfig.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: getOutput with Index {}"sv, Index); } @@ -410,7 +451,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, std::copy_n(CxtRef.Outputs.data(), CxtRef.Outputs.length(), OutBuffer.data()); BytesWritten = CxtRef.Outputs.length(); - if (GraphRef.EnableDebugLog) { + if (CxtRef.WhisperConfig.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] Whisper backend: getOutput with Index {}...Done"sv, Index); @@ -421,7 +462,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - if (GraphRef.EnableDebugLog) { + if (CxtRef.WhisperConfig.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: compute"sv); } @@ -433,7 +474,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { return ErrNo::RuntimeError; } - if (GraphRef.EnableDebugLog) { + if (CxtRef.WhisperConfig.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: compute...Done"sv); } return ErrNo::Success; diff --git a/plugins/wasi_nn/whispercpp.h b/plugins/wasi_nn/whispercpp.h index 9dd36bbb..70f727a3 100644 --- a/plugins/wasi_nn/whispercpp.h +++ b/plugins/wasi_nn/whispercpp.h @@ -19,9 +19,8 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::Whisper { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER -struct Graph { - whisper_context *WhisperCtx = nullptr; - std::string ModelFilePath; + +struct Config { // Whisper parameters: bool EnableLog = false; bool EnableDebugLog = false; @@ -29,9 +28,6 @@ struct Graph { bool DetectLanguage = false; std::string SpokenLanguage; std::string InitialPrompt; - // Context parameters: - bool UseGPU = true; - int64_t MainGPU = 0; // Use GPU 0 by default // Sampling parameters: float WordThreshold = 0.01f; float EntropyThreshold = 2.40f; @@ -41,13 +37,29 @@ struct Graph { float GrammarPenalty = 100.0f; }; +struct Graph { + whisper_context *WhisperCtx = nullptr; + std::string ModelFilePath; + // Whisper config: + Config WhisperConfig; + // Context parameters: + bool UseGPU = true; + int64_t MainGPU = 0; // Use GPU 0 by default +}; + struct Context { public: - Context(size_t GId, Graph &) noexcept : GraphId(GId) {} + Context(size_t GId, Graph &G) noexcept + : GraphId(GId), WhisperConfig(G.WhisperConfig) {} size_t GraphId; - std::vector InputPCM; // mono-channel F32 PCM input. + // mono-channel F32 PCM input. + std::vector InputPCM; + // Whisper config. Inherit from the graph and accept metadata when setting + // input. + Config WhisperConfig; whisper_full_params WhisperParams = whisper_full_default_params( whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH); + // Recognition outputs. std::string Outputs; }; #else From 43be211888b1e361499eb2bb27a1c2061a4e0c05 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Thu, 12 Sep 2024 18:33:36 +0800 Subject: [PATCH 432/623] [Docker] Ubuntu: Add plugin dependencies Signed-off-by: Yi Huang --- utils/docker/Dockerfile.ubuntu-base | 24 ++++++++++++--------- utils/docker/Dockerfile.ubuntu-plugins-deps | 17 ++++++++++++++- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/utils/docker/Dockerfile.ubuntu-base b/utils/docker/Dockerfile.ubuntu-base index 195dd4a4..385af5ed 100644 --- a/utils/docker/Dockerfile.ubuntu-base +++ b/utils/docker/Dockerfile.ubuntu-base @@ -9,6 +9,8 @@ RUN apt-get update && \ cmake \ curl \ dpkg-dev \ + g++ \ + gcc \ git \ ninja-build \ software-properties-common \ @@ -20,14 +22,24 @@ FROM base AS deps-20 RUN apt-get install -y \ llvm-12-dev \ - liblld-12-dev + liblld-12-dev \ + clang-12 + +RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-12 100 && \ + update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-12 100 && \ + update-alternatives --install /usr/bin/llvm-strip llvm-strip /usr/bin/llvm-strip-12 100 ### deps for ubuntu 22.04 ### FROM base AS deps-22 RUN apt-get install -y \ llvm-15-dev \ - liblld-15-dev + liblld-15-dev \ + clang-15 + +RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-15 100 && \ + update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-15 100 && \ + update-alternatives --install /usr/bin/llvm-strip llvm-strip /usr/bin/llvm-strip-15 100 ### deps for clang / ubuntu 20.04 ### FROM deps-20 AS deps-20-clang @@ -50,20 +62,12 @@ ENV CXX=/usr/bin/clang++-15 ### deps for gcc / ubuntu 20.04 ### FROM deps-20 AS deps-20-gcc -RUN apt-get install -y \ - gcc \ - g++ - ENV CC=gcc ENV CXX=g++ ### deps for gcc / ubuntu 22.04 ### FROM deps-22 AS deps-22-gcc -RUN apt-get install -y \ - gcc \ - g++ - ENV CC=gcc ENV CXX=g++ diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index a10ccec5..dcff0c53 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -1,13 +1,28 @@ ARG BASE_IMAGE=wasmedge/wasmedge:latest FROM ${BASE_IMAGE} AS base +WORKDIR /root + RUN apt-get update && \ apt-get install -y \ + cargo \ + libelf-dev \ + libomp-dev \ + libssl-dev \ + pkg-config \ unzip \ - wget + yasm COPY opencvmini/install-opencvmini.sh . ENV OPENCV_VERSION="4.8.0" RUN [ "/bin/bash", "install-opencvmini.sh" ] +COPY ffmpeg/install-ffmpeg-v6.0.sh . +RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] +ENV PKG_CONFIG_PATH=/root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV LD_LIBRARY_PATH=/root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} + +### cleanup +FROM base AS clean-apt + RUN rm -rf /var/lib/apt/lists/* From 0677352ab60ac39a2283116ddf354912d7a62241 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Thu, 12 Sep 2024 18:29:29 +0800 Subject: [PATCH 433/623] [Docker] Ubuntu: Refine for caching Signed-off-by: Yi Huang --- utils/docker/Dockerfile.ubuntu-base | 35 ++------ utils/docker/Dockerfile.ubuntu-gcc | 5 ++ utils/docker/docker-bake.ubuntu.hcl | 126 +++++++++++++++++++--------- 3 files changed, 95 insertions(+), 71 deletions(-) create mode 100644 utils/docker/Dockerfile.ubuntu-gcc diff --git a/utils/docker/Dockerfile.ubuntu-base b/utils/docker/Dockerfile.ubuntu-base index 385af5ed..8dc3f7b3 100644 --- a/utils/docker/Dockerfile.ubuntu-base +++ b/utils/docker/Dockerfile.ubuntu-base @@ -1,5 +1,4 @@ ARG UBUNTU_VER=22 -ARG TOOLCHAIN=clang FROM ubuntu:${UBUNTU_VER}.04 AS base ENV DEBIAN_FRONTEND=noninteractive @@ -29,6 +28,9 @@ RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-12 100 && update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-12 100 && \ update-alternatives --install /usr/bin/llvm-strip llvm-strip /usr/bin/llvm-strip-12 100 +ENV CC=/usr/bin/clang-12 +ENV CXX=/usr/bin/clang++-12 + ### deps for ubuntu 22.04 ### FROM base AS deps-22 @@ -41,37 +43,10 @@ RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-15 100 && update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-15 100 && \ update-alternatives --install /usr/bin/llvm-strip llvm-strip /usr/bin/llvm-strip-15 100 -### deps for clang / ubuntu 20.04 ### -FROM deps-20 AS deps-20-clang - -RUN apt-get install -y \ - clang-12 - -ENV CC=/usr/bin/clang-12 -ENV CXX=/usr/bin/clang++-12 - -### deps for clang / ubuntu 22.04 ### -FROM deps-22 AS deps-22-clang - -RUN apt-get install -y \ - clang-15 - ENV CC=/usr/bin/clang-15 ENV CXX=/usr/bin/clang++-15 -### deps for gcc / ubuntu 20.04 ### -FROM deps-20 AS deps-20-gcc - -ENV CC=gcc -ENV CXX=g++ - -### deps for gcc / ubuntu 22.04 ### -FROM deps-22 AS deps-22-gcc - -ENV CC=gcc -ENV CXX=g++ - -### deps for all ### -FROM deps-${UBUNTU_VER}-${TOOLCHAIN} AS final +### cleanup +FROM deps-${UBUNTU_VER} AS clean-apt RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/Dockerfile.ubuntu-gcc b/utils/docker/Dockerfile.ubuntu-gcc new file mode 100644 index 00000000..1b201ae2 --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu-gcc @@ -0,0 +1,5 @@ +ARG BASE_IMAGE=wasmedge/wasmedge:latest +FROM ${BASE_IMAGE} AS base + +ENV CC=/usr/bin/gcc +ENV CXX=/usr/bin/g++ diff --git a/utils/docker/docker-bake.ubuntu.hcl b/utils/docker/docker-bake.ubuntu.hcl index 68b9b87d..fb9bf956 100644 --- a/utils/docker/docker-bake.ubuntu.hcl +++ b/utils/docker/docker-bake.ubuntu.hcl @@ -1,81 +1,125 @@ group "default" { targets = [ - "base", - "latest", - "plugins" + "clang", + "gcc" ] } -function "name" { - params = [toolchain, ubuntu] - result = "${toolchain}-ubuntu${replace(ubuntu, ".", "")}" +function "no-dot" { + params = [ubuntu] + result = replace(ubuntu, ".", "") } -function "tag" { - params = [toolchain, ubuntu] - result = equal(ubuntu, "22.04") ? "ubuntu-build-${toolchain}" : "ubuntu-${ubuntu}-build-${toolchain}" +function "major" { + params = [ubuntu] + result = regex("^[[:digit:]]+", ubuntu) } -variable "matrix" { - default = { - toolchain = ["clang", "gcc"] - ubuntu = ["20.04", "22.04"] - } +function "tags-latest" { + params = [target, ubuntu, toolchain] + result = target == "base" && ubuntu == "22.04" && toolchain == "clang" ? "latest" : "" } -target "base" { - matrix = matrix - name = name(toolchain, ubuntu) +function "tags-backports" { + params = [target, ubuntu, toolchain] + result = join("-", compact([ + "ubuntu", + ubuntu != "22.04" ? ubuntu : "", + "build", + toolchain, + target == "plugins" ? "plugins-deps" : "", + ])) +} + +function "tags-simplified" { + params = [target, ubuntu, toolchain] + result = target == "base" && toolchain == "clang" ? "ubuntu-${ubuntu}" : "" +} + +function "tags" { + params = [target, ubuntu, toolchain] + result = [for tag in compact([ + tags-latest(target, ubuntu, toolchain), + tags-backports(target, ubuntu, toolchain), + tags-simplified(target, ubuntu, toolchain), + ]) : "wasmedge/wasmedge:${tag}"] +} +target "base" { dockerfile = "Dockerfile.ubuntu-base" context = "./utils/docker" - tags = ["wasmedge/wasmedge:${tag(toolchain, ubuntu)}"] + matrix = { + ubuntu = ["20.04", "22.04"] + } + + name = "base-${no-dot(ubuntu)}" + tags = ["local/tmp:base-${ubuntu}"] args = { - TOOLCHAIN = toolchain - UBUNTU_VER = replace(ubuntu, ".04", "") + UBUNTU_VER = major(ubuntu) } } target "plugins" { - matrix = matrix - name = "${name(toolchain, ubuntu)}-plugins" - dockerfile = "./docker/Dockerfile.ubuntu-plugins-deps" context = "./utils" - contexts = { - "wasmedge/wasmedge:${tag(toolchain, ubuntu)}" = "target:${name(toolchain, ubuntu)}" + matrix = { + ubuntu = ["20.04", "22.04"] } - tags = ["wasmedge/wasmedge:${tag(toolchain, ubuntu)}-plugins-deps"] + inherits = ["base-${no-dot(ubuntu)}"] + name = "plugins-${no-dot(ubuntu)}" + contexts = { + "local/tmp:base-${ubuntu}" = "target:base-${no-dot(ubuntu)}" + } + tags = ["local/tmp:plugins-${ubuntu}"] args = { - BASE_IMAGE = "wasmedge/wasmedge:${tag(toolchain, ubuntu)}" + BASE_IMAGE = "local/tmp:base-${ubuntu}" + UBUNTU_VER = major(ubuntu) } } -target "latest" { +target "clang" { matrix = { - toolchain = ["clang"] - ubuntu = ["22.04"] + parent = ["base", "plugins"] + ubuntu = ["20.04", "22.04"] } - inherits = ["${name(toolchain, ubuntu)}"] - contexts = { - "wasmedge/wasmedge:${tag(toolchain, ubuntu)}" = "target:${name(toolchain, ubuntu)}" + + inherits = ["${parent}-${no-dot(ubuntu)}"] + name = "${parent}-${no-dot(ubuntu)}-clang" + contexts = { + "local/tmp:${parent}-${ubuntu}" = "target:${parent}-${no-dot(ubuntu)}" } - tags = ["wasmedge/wasmedge:latest"] + tags = tags(parent, ubuntu, "clang") } -target "clang-ubuntu2004-aarch64" { +target "gcc" { + dockerfile = "Dockerfile.ubuntu-gcc" + context = "./utils/docker" + matrix = { - toolchain = ["clang"] - ubuntu = ["20.04"] + parent = ["base", "plugins"] + ubuntu = ["20.04", "22.04"] } - inherits = ["${name(toolchain, ubuntu)}"] - contexts = { - "wasmedge/wasmedge:${tag(toolchain, ubuntu)}" = "target:${name(toolchain, ubuntu)}" + + inherits = ["${parent}-${no-dot(ubuntu)}"] + name = "${parent}-${no-dot(ubuntu)}-gcc" + contexts = { + "local/tmp:${parent}-${ubuntu}" = "target:${parent}-${no-dot(ubuntu)}" + } + tags = tags(parent, ubuntu, "gcc") + args = { + BASE_IMAGE = "local/tmp:${parent}-${ubuntu}" } +} - tags = ["wasmedge/wasmedge:${tag(toolchain, ubuntu)}-aarch64"] +# TODO: Refactor with multi-arch image +target "base-2004-clang-aarch64" { + inherits = ["base-2004"] + contexts = { + "local/tmp:base-2004" = "target:base-2004" + } + tags = [for tag in tags("base", "20.04", "clang") : "${tag}-aarch64"] platforms = ["linux/arm64"] } From 3f10df2a0104fb447221ebe99a58141acef40ab8 Mon Sep 17 00:00:00 2001 From: "Shen-Ta Hsieh(BestSteve)" Date: Thu, 19 Sep 2024 14:53:38 +0800 Subject: [PATCH 434/623] [WASI-NN] Fix `fmt::format` error in embedding (#3779) Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_nn/ggml.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 6268e322..6d97588a 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -379,7 +379,7 @@ void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, // | ']' | // | '}' | Embedding = - fmt::format(R"({{"n_embedding": {:.10}, )" + fmt::format(R"({{"n_embedding": {}, )" R"("embedding": [{:.10}]}})"sv, NEmbd, fmt::join(Embeddings, Embeddings + NEmbd, ","sv)); } From 5f056a57854904a3e683bbbf667af73f52838328 Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Fri, 20 Sep 2024 19:47:16 +0800 Subject: [PATCH 435/623] [WASI-NN] ggml: fix accessing freed data after unload. (#3785) Signed-off-by: YiYing He --- plugins/wasi_nn/ggml.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/ggml.cpp index 6d97588a..65c58e9e 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/ggml.cpp @@ -1480,23 +1480,24 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { auto &GraphRef = Env.NNGraph[GraphId].get(); - if (GraphRef.EnableDebugLog) { + const bool IsDebugLog = GraphRef.EnableDebugLog; + if (IsDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: unload"sv); } if (GraphRef.LlamaModel != nullptr) { - if (GraphRef.EnableDebugLog) { + if (IsDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: unload: free llama model"sv); } llama_free_model(GraphRef.LlamaModel); GraphRef.LlamaModel = nullptr; - if (GraphRef.EnableDebugLog) { + if (IsDebugLog) { spdlog::info( "[WASI-NN][Debug] GGML backend: unload: free llama model...Done"sv); } } Env.NNGraph.erase(Env.NNGraph.begin() + GraphId); Env.mdRemoveById(GraphId); - if (GraphRef.EnableDebugLog) { + if (IsDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: unload...Done"sv); } return ErrNo::Success; From 4683d938647305f02b51c89b0024352dca43bd05 Mon Sep 17 00:00:00 2001 From: PeterD1524 Date: Sun, 22 Sep 2024 02:41:42 +0800 Subject: [PATCH 436/623] [Docker] install latest cmake from Kitware APT Repository for ubuntu 20.04 Kitware APT Repository: https://apt.kitware.com/ Signed-off-by: PeterD1524 --- utils/docker/Dockerfile.ubuntu-base | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/utils/docker/Dockerfile.ubuntu-base b/utils/docker/Dockerfile.ubuntu-base index 8dc3f7b3..416b6611 100644 --- a/utils/docker/Dockerfile.ubuntu-base +++ b/utils/docker/Dockerfile.ubuntu-base @@ -5,7 +5,6 @@ ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && \ apt-get upgrade -y && \ apt-get install -y \ - cmake \ curl \ dpkg-dev \ g++ \ @@ -19,6 +18,9 @@ RUN apt-get update && \ ### deps for ubuntu 20.04 ### FROM base AS deps-20 +RUN curl -sSf https://apt.kitware.com/kitware-archive.sh | sh +RUN apt-get install -y cmake + RUN apt-get install -y \ llvm-12-dev \ liblld-12-dev \ @@ -34,6 +36,8 @@ ENV CXX=/usr/bin/clang++-12 ### deps for ubuntu 22.04 ### FROM base AS deps-22 +RUN apt-get install -y cmake + RUN apt-get install -y \ llvm-15-dev \ liblld-15-dev \ From e26fda9453f1e695a2ef4f1b96b8e2e577fa0986 Mon Sep 17 00:00:00 2001 From: Yi Date: Mon, 23 Sep 2024 17:03:49 +0800 Subject: [PATCH 437/623] [CI] Refactor ubuntu plugins (WASI-NN) (#3786) * [CI] Ubuntu: Rename 22.04 to latest Signed-off-by: Yi Huang * [Docker] Ubuntu: Tag 22.04 Signed-off-by: Yi Huang * [Docker] Ubuntu: Add WASI-NN dependencies Signed-off-by: Yi Huang * [Docker] Ubuntu: Clean-up files after installations Signed-off-by: Yi Huang * [CI] Plugins: Add ubuntu 20.04 and 22.04 to matrix - wasi_nn-ggml - wasi_nn-openvino - wasi_nn-piper - wasi_nn-pytorch - wasi_nn-tensorflowlite - wasi_nn-whisper Signed-off-by: Yi Huang * [CI] Plugins: Remove ubuntu plugins refactored with matrix - wasi_nn-ggml - wasi_nn-openvino - wasi_nn-piper - wasi_nn-pytorch - wasi_nn-tensorflowlite - wasi_nn-whisper Signed-off-by: Yi Huang --------- Signed-off-by: Yi Huang --- utils/docker/Dockerfile.ubuntu-plugins-deps | 27 +++++++++++++++++++++ utils/docker/docker-bake.ubuntu.hcl | 19 ++++++++++++++- utils/ffmpeg/install-ffmpeg-v6.0.sh | 5 ++-- utils/opencvmini/install-opencvmini.sh | 2 ++ utils/wasi-nn/install-openvino.sh | 6 +++-- 5 files changed, 54 insertions(+), 5 deletions(-) diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index dcff0c53..eaa2fd82 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -8,11 +8,17 @@ RUN apt-get update && \ cargo \ libelf-dev \ libomp-dev \ + libopenblas-dev \ libssl-dev \ pkg-config \ unzip \ yasm +RUN apt-get install -y \ + libgrpc++-dev \ + libgrpc-dev \ + protobuf-compiler-grpc + COPY opencvmini/install-opencvmini.sh . ENV OPENCV_VERSION="4.8.0" RUN [ "/bin/bash", "install-opencvmini.sh" ] @@ -22,7 +28,28 @@ RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] ENV PKG_CONFIG_PATH=/root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} ENV LD_LIBRARY_PATH=/root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} +COPY wasi-nn/install-pytorch.sh . +ENV PYTORCH_VERSION="1.8.2" +ENV PYTORCH_INSTALL_TO="/root" +ENV Torch_DIR="/root/libtorch" +RUN [ "/bin/bash", "install-pytorch.sh" ] + +COPY wasi-nn/install-openvino.sh . +ENV OPENVINO_VERSION="2024.2.0" +ENV OPENVINO_YEAR="2024" +RUN [ "/bin/bash", "install-openvino.sh" ] + +COPY wasi-nn/install-onnxruntime.sh . +RUN [ "/bin/bash", "install-onnxruntime.sh" ] + ### cleanup FROM base AS clean-apt +RUN rm -f \ + install-opencvmini.sh \ + install-ffmpeg-v6.0.sh \ + install-pytorch.sh \ + install-openvino.sh \ + install-onnxruntime.sh + RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/docker-bake.ubuntu.hcl b/utils/docker/docker-bake.ubuntu.hcl index fb9bf956..42f7c9d2 100644 --- a/utils/docker/docker-bake.ubuntu.hcl +++ b/utils/docker/docker-bake.ubuntu.hcl @@ -5,6 +5,12 @@ group "default" { ] } +group "latest" { + targets = [ + "base-2204-clang", + ] +} + function "no-dot" { params = [ubuntu] result = replace(ubuntu, ".", "") @@ -20,11 +26,21 @@ function "tags-latest" { result = target == "base" && ubuntu == "22.04" && toolchain == "clang" ? "latest" : "" } +function "tags-latest-backports" { + params = [target, ubuntu, toolchain] + result = ubuntu == "22.04" ? join("-", compact([ + "ubuntu", + "build", + toolchain, + target == "plugins" ? "plugins-deps" : "", + ])) : "" +} + function "tags-backports" { params = [target, ubuntu, toolchain] result = join("-", compact([ "ubuntu", - ubuntu != "22.04" ? ubuntu : "", + ubuntu, "build", toolchain, target == "plugins" ? "plugins-deps" : "", @@ -40,6 +56,7 @@ function "tags" { params = [target, ubuntu, toolchain] result = [for tag in compact([ tags-latest(target, ubuntu, toolchain), + tags-latest-backports(target, ubuntu, toolchain), tags-backports(target, ubuntu, toolchain), tags-simplified(target, ubuntu, toolchain), ]) : "wasmedge/wasmedge:${tag}"] diff --git a/utils/ffmpeg/install-ffmpeg-v6.0.sh b/utils/ffmpeg/install-ffmpeg-v6.0.sh index 479908ce..d02dc733 100755 --- a/utils/ffmpeg/install-ffmpeg-v6.0.sh +++ b/utils/ffmpeg/install-ffmpeg-v6.0.sh @@ -1,6 +1,5 @@ #!/usr/bin/env bash -rm -rf FFmpeg-n6.0 ffmpeg.zip -echo $(pwd) +set -e curl -sL https://github.com/FFmpeg/FFmpeg/archive/refs/tags/n6.0.zip -o ffmpeg.zip @@ -11,3 +10,5 @@ cd FFmpeg-n6.0 ./configure --prefix=$(pwd)/output --enable-gpl --enable-nonfree --enable-shared --disable-static make && make install cd .. + +rm -rf ffmpeg.zip diff --git a/utils/opencvmini/install-opencvmini.sh b/utils/opencvmini/install-opencvmini.sh index 799f7382..6e47fd99 100644 --- a/utils/opencvmini/install-opencvmini.sh +++ b/utils/opencvmini/install-opencvmini.sh @@ -15,3 +15,5 @@ cmake -GNinja .. cmake --build . # Install to system cmake --install . + +rm -f opencv.zip diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh index 10a3d089..cd4030c0 100755 --- a/utils/wasi-nn/install-openvino.sh +++ b/utils/wasi-nn/install-openvino.sh @@ -4,8 +4,10 @@ set -e echo "Installing OpenVINO with version 2024.2.0" -wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB -apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB +KEY_FILE=GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB +wget https://apt.repos.intel.com/intel-gpg-keys/$KEY_FILE && \ + apt-key add $KEY_FILE && \ + rm -f $KEY_FILE echo "deb https://apt.repos.intel.com/openvino/2024 ubuntu20 main" | tee /etc/apt/sources.list.d/intel-openvino-2024.list apt update apt-get -y install openvino-2024.2.0 From 82f48910d6537bf816389e1d8dedd25a830a135b Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Tue, 24 Sep 2024 15:44:46 +0800 Subject: [PATCH 438/623] [Docker] Ubuntu: Pre-install CUDA Signed-off-by: Yi Huang --- utils/docker/Dockerfile.ubuntu-cuda | 24 ++++++++++++++++++++++++ utils/docker/docker-bake.ubuntu.hcl | 20 ++++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 utils/docker/Dockerfile.ubuntu-cuda diff --git a/utils/docker/Dockerfile.ubuntu-cuda b/utils/docker/Dockerfile.ubuntu-cuda new file mode 100644 index 00000000..f10b8b5f --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu-cuda @@ -0,0 +1,24 @@ +ARG BASE_IMAGE=wasmedge/wasmedge:latest +FROM ${BASE_IMAGE} AS base + +WORKDIR /root + +ARG CUDA_KEYRING=cuda-keyring_1.1-1_all.deb +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/${CUDA_KEYRING} && \ + dpkg -i ${CUDA_KEYRING} && \ + rm -f ${CUDA_KEYRING} + +ARG NVCC_VER=12-0 +RUN apt-get update && \ + apt-get install -y \ + cuda-nvcc-${NVCC_VER} \ + libcublas-dev-${NVCC_VER} \ + pkg-config \ + unzip + +ENV CXXFLAGS="-Wno-error" + +### cleanup +FROM base AS clean-apt + +RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/docker/docker-bake.ubuntu.hcl b/utils/docker/docker-bake.ubuntu.hcl index 42f7c9d2..1c7ad294 100644 --- a/utils/docker/docker-bake.ubuntu.hcl +++ b/utils/docker/docker-bake.ubuntu.hcl @@ -1,5 +1,6 @@ group "default" { targets = [ + "cuda", "clang", "gcc" ] @@ -131,6 +132,25 @@ target "gcc" { } } +target "cuda" { + dockerfile = "Dockerfile.ubuntu-cuda" + context = "./utils/docker" + + matrix = { + cuda = ["11.3", "12.0"] + } + + name = "base-2004-gcc-cuda${major(cuda)}" + contexts = { + "wasmedge/wasmedge:ubuntu-20.04-build-gcc" = "target:base-2004-gcc" + } + tags = ["wasmedge/wasmedge:ubuntu-20.04-build-gcc-cuda${major(cuda)}"] + args = { + BASE_IMAGE = "wasmedge/wasmedge:ubuntu-20.04-build-gcc" + NVCC_VER = replace(cuda, ".", "-") + } +} + # TODO: Refactor with multi-arch image target "base-2004-clang-aarch64" { inherits = ["base-2004"] From 684afe605dbcbcf6ec0a747ad7d2799dd8c51360 Mon Sep 17 00:00:00 2001 From: grorge Date: Thu, 18 Jul 2024 23:00:33 +0800 Subject: [PATCH 439/623] [WASI-NN] mlx: add cmake Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 21 +++++++++++++++++++++ plugins/wasi_nn/mlx.cpp | 0 plugins/wasi_nn/mlx.h | 0 3 files changed, 21 insertions(+) create mode 100644 plugins/wasi_nn/mlx.cpp create mode 100644 plugins/wasi_nn/mlx.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 328ba9dc..138ba623 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -16,6 +16,7 @@ wasmedge_add_library(wasmedgePluginWasiNN piper.cpp whispercpp.cpp chattts.cpp + mlx.cpp ) foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) @@ -190,6 +191,26 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) whisper simdjson::simdjson ) + elseif(BACKEND STREQUAL "mlx") + include(FetchContent) + find_package(LAPACK REQUIRED) + message(STATUS "LAPACK_LIBRARIES: ${LAPACK_LIBRARIES} ${LAPACK_INCLUDE_DIRS}") + if (!${APPLE}) + set(MLX_BUILD_METAL OFF) + endif() + FetchContent_Declare( + mlx + GIT_REPOSITORY https://github.com/ml-explore/mlx.git + GIT_TAG v0.16.0 + GIT_SHALLOW FALSE + ) + FetchContent_MakeAvailable(mlx) + # install liblapack-dev libopenblas-dev + set_property(TARGET mlx PROPERTY POSITION_INDEPENDENT_CODE ON) + target_link_libraries(mlx PRIVATE OpenBLAS) + target_link_libraries(wasmedgePluginWasiNN PRIVATE + mlx + ) endif() endforeach() diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp new file mode 100644 index 00000000..e69de29b diff --git a/plugins/wasi_nn/mlx.h b/plugins/wasi_nn/mlx.h new file mode 100644 index 00000000..e69de29b From 755d8ff6671b7ff368b803d28e97dd3e91d9e5e7 Mon Sep 17 00:00:00 2001 From: grorge Date: Sun, 21 Jul 2024 21:21:23 +0800 Subject: [PATCH 440/623] [WASI-NN] mlx: add mlx to wasi-nn backend Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 54 ++++++++++++++++-------- plugins/wasi_nn/mlx.cpp | 77 ++++++++++++++++++++++++++++++++++ plugins/wasi_nn/mlx.h | 48 +++++++++++++++++++++ plugins/wasi_nn/types.h | 4 +- plugins/wasi_nn/wasinnenv.cpp | 11 +++-- plugins/wasi_nn/wasinnenv.h | 1 + 6 files changed, 172 insertions(+), 23 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 138ba623..6ef6ab6f 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -192,25 +192,43 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) simdjson::simdjson ) elseif(BACKEND STREQUAL "mlx") - include(FetchContent) - find_package(LAPACK REQUIRED) - message(STATUS "LAPACK_LIBRARIES: ${LAPACK_LIBRARIES} ${LAPACK_INCLUDE_DIRS}") - if (!${APPLE}) - set(MLX_BUILD_METAL OFF) + find_package(MLX CONFIG) + if(MLX_FOUND) + message(STATUS "Found MLX: ${MLX_INCLUDE_DIRS}") + else() + message(STATUS "MLX not found, downloading from source") + include(FetchContent) + set(MLX_BUILD_TESTS OFF) + if (NOT APPLE) + set(MLX_BUILD_METAL OFF) + set(MLX_BUILD_CPU OFF) + endif() + + FetchContent_Declare( + mlx + GIT_REPOSITORY https://github.com/ml-explore/mlx.git + GIT_TAG v0.16.0 + GIT_SHALLOW FALSE + ) + FetchContent_MakeAvailable(mlx) + set_property(TARGET mlx PROPERTY POSITION_INDEPENDENT_CODE ON) + set_target_properties(mlx PROPERTIES + INTERFACE_LINK_LIBRARIES "$" + ) + target_compile_options(mlx + PRIVATE + -Wno-unused-parameter + -Wno-deprecated-copy + -Wno-format + ) + target_compile_options(wasmedgePluginWasiNN + PRIVATE + -Wno-unused-parameter + ) + target_link_libraries(wasmedgePluginWasiNN PRIVATE + mlx + ) endif() - FetchContent_Declare( - mlx - GIT_REPOSITORY https://github.com/ml-explore/mlx.git - GIT_TAG v0.16.0 - GIT_SHALLOW FALSE - ) - FetchContent_MakeAvailable(mlx) - # install liblapack-dev libopenblas-dev - set_property(TARGET mlx PROPERTY POSITION_INDEPENDENT_CODE ON) - target_link_libraries(mlx PRIVATE OpenBLAS) - target_link_libraries(wasmedgePluginWasiNN PRIVATE - mlx - ) endif() endforeach() diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp index e69de29b..e1e48c1e 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/mlx.cpp @@ -0,0 +1,77 @@ +#include "mlx.h" +#include "wasinnenv.h" + +namespace WasmEdge::Host::WASINN::MLX { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX +Expect load(WASINN::WasiNNEnvironment &Env, + Span>, WASINN::Device, + uint32_t &GraphId) noexcept { + // Add a new graph. + // Env.NNGraph.emplace_back(Backend::MLX); + // auto &GraphRef = Env.NNGraph.back().get(); + + return WASINN::ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); + ContextId = Env.NNContext.size() - 1; + return ErrNo::Success; +} + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, + const TensorData &Tensor) noexcept { + // auto &CxtRef = Env.NNContext[ContextId].get(); + // auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + return WASINN::ErrNo::Success; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + // auto &CxtRef = Env.NNContext[ContextId].get(); + // auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + return WASINN::ErrNo::Success; +} +Expect compute(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + // auto &CxtRef = Env.NNContext[ContextId].get(); + // auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + return WASINN::ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] MLX backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"MLX\" to build it."sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif + +} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file diff --git a/plugins/wasi_nn/mlx.h b/plugins/wasi_nn/mlx.h index e69de29b..56c61674 100644 --- a/plugins/wasi_nn/mlx.h +++ b/plugins/wasi_nn/mlx.h @@ -0,0 +1,48 @@ +#pragma once + +#include "plugin/plugin.h" +#include "types.h" +#include + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::MLX { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX +struct Graph { + mlx::core::StreamOrDevice MLXDevice = mlx::core::metal::is_available() + ? mlx::core::Device::gpu + : mlx::core::Device::cpu; +}; +struct Context { + Context(size_t Gid, Graph &) noexcept : GraphId(Gid) {} + size_t GraphId; +}; +#else +struct Graph {}; +struct Context { + Context(size_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; +} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/types.h index 12eede09..6bf716b3 100644 --- a/plugins/wasi_nn/types.h +++ b/plugins/wasi_nn/types.h @@ -40,6 +40,7 @@ enum class Backend : uint8_t { GGML = 6, NeuralSpeed = 7, Whisper = 9, + MLX = 10, Piper = 11, ChatTTS = 12, }; @@ -54,7 +55,8 @@ enum class Backend : uint8_t { F(NeuralSpeed) \ F(Whisper) \ F(Piper) \ - F(ChatTTS) + F(ChatTTS) \ + F(MLX) struct TensorData { Span Dimension; diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 331e4963..61c06f55 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -32,6 +32,7 @@ std::map BackendMap = { {"ggml"sv, Backend::GGML}, {"neuralspeed"sv, Backend::NeuralSpeed}, {"whisper"sv, Backend::Whisper}, + {"mlx"sv, Backend::MLX}, {"piper"sv, Backend::Piper}, {"chattts"sv, Backend::ChatTTS}}; @@ -105,9 +106,10 @@ WasiNNEnvironment::WasiNNEnvironment() noexcept { auto Device = DeviceMap.find(Target); if (Backend != BackendMap.end() && Device != DeviceMap.end()) { if (Backend->second == Backend::GGML) { - // In GGML, we only support loading one model from nn-preload config. - // To handle paths on Windows that contains `:` in the path, we combine - // the Paths into a single string separated by `:`. + // In GGML, we only support loading one model from nn-preload + // config. To handle paths on Windows that contains `:` in the + // path, we combine the Paths into a single string separated by + // `:`. std::string P; for (const std::string &PathSegment : Paths) { P += PathSegment; @@ -115,7 +117,8 @@ WasiNNEnvironment::WasiNNEnvironment() noexcept { P += ":"; } } - // We write model path to model data to avoid file IO in llama.cpp. + // We write model path to model data to avoid file IO in + // llama.cpp. std::string ModelPath = "preload:" + P; std::vector ModelPathData(ModelPath.begin(), ModelPath.end()); Models.push_back(std::move(ModelPathData)); diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 6ec0077f..9a5100de 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -11,6 +11,7 @@ #include "chattts.h" #include "ggml.h" +#include "mlx.h" #include "neuralspeed.h" #include "onnx.h" #include "openvino.h" From 029377c6204a4456264ed991938d16b3eba9fd50 Mon Sep 17 00:00:00 2001 From: grorge Date: Fri, 23 Aug 2024 08:51:29 +0800 Subject: [PATCH 441/623] [WASI-NN] mlx: add mlx_cpp to wasi-nn backend Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 35 ++- plugins/wasi_nn/MLX/CMakeLists.txt | 34 +++ plugins/wasi_nn/MLX/mlx/activations.cpp | 6 + plugins/wasi_nn/MLX/mlx/activations.h | 8 + plugins/wasi_nn/MLX/mlx/base.cpp | 36 +++ plugins/wasi_nn/MLX/mlx/base.h | 53 ++++ plugins/wasi_nn/MLX/mlx/embedding.cpp | 12 + plugins/wasi_nn/MLX/mlx/embedding.h | 18 ++ plugins/wasi_nn/MLX/mlx/linear.cpp | 12 + plugins/wasi_nn/MLX/mlx/linear.h | 29 +++ plugins/wasi_nn/MLX/mlx/normalization.cpp | 6 + plugins/wasi_nn/MLX/mlx/normalization.h | 14 ++ .../wasi_nn/MLX/mlx/positional_encoding.cpp | 8 + plugins/wasi_nn/MLX/mlx/positional_encoding.h | 20 ++ plugins/wasi_nn/MLX/mlx/transformer.cpp | 12 + plugins/wasi_nn/MLX/mlx/transformer.h | 46 ++++ plugins/wasi_nn/MLX/model/converter.cpp | 76 ++++++ plugins/wasi_nn/MLX/model/converter.h | 11 + plugins/wasi_nn/MLX/model/registry.cpp | 22 ++ plugins/wasi_nn/MLX/model/registry.h | 14 ++ plugins/wasi_nn/MLX/model/transformer.cpp | 186 ++++++++++++++ plugins/wasi_nn/MLX/model/transformer.h | 198 +++++++++++++++ plugins/wasi_nn/MLX/model/utils.cpp | 32 +++ plugins/wasi_nn/MLX/model/utils.h | 8 + plugins/wasi_nn/MLX/prompt/llama.cpp | 5 + plugins/wasi_nn/MLX/prompt/llama.h | 24 ++ plugins/wasi_nn/mlx.cpp | 233 +++++++++++++++++- plugins/wasi_nn/mlx.h | 20 +- 28 files changed, 1154 insertions(+), 24 deletions(-) create mode 100644 plugins/wasi_nn/MLX/CMakeLists.txt create mode 100644 plugins/wasi_nn/MLX/mlx/activations.cpp create mode 100644 plugins/wasi_nn/MLX/mlx/activations.h create mode 100644 plugins/wasi_nn/MLX/mlx/base.cpp create mode 100644 plugins/wasi_nn/MLX/mlx/base.h create mode 100644 plugins/wasi_nn/MLX/mlx/embedding.cpp create mode 100644 plugins/wasi_nn/MLX/mlx/embedding.h create mode 100644 plugins/wasi_nn/MLX/mlx/linear.cpp create mode 100644 plugins/wasi_nn/MLX/mlx/linear.h create mode 100644 plugins/wasi_nn/MLX/mlx/normalization.cpp create mode 100644 plugins/wasi_nn/MLX/mlx/normalization.h create mode 100644 plugins/wasi_nn/MLX/mlx/positional_encoding.cpp create mode 100644 plugins/wasi_nn/MLX/mlx/positional_encoding.h create mode 100644 plugins/wasi_nn/MLX/mlx/transformer.cpp create mode 100644 plugins/wasi_nn/MLX/mlx/transformer.h create mode 100644 plugins/wasi_nn/MLX/model/converter.cpp create mode 100644 plugins/wasi_nn/MLX/model/converter.h create mode 100644 plugins/wasi_nn/MLX/model/registry.cpp create mode 100644 plugins/wasi_nn/MLX/model/registry.h create mode 100644 plugins/wasi_nn/MLX/model/transformer.cpp create mode 100644 plugins/wasi_nn/MLX/model/transformer.h create mode 100644 plugins/wasi_nn/MLX/model/utils.cpp create mode 100644 plugins/wasi_nn/MLX/model/utils.h create mode 100644 plugins/wasi_nn/MLX/prompt/llama.cpp create mode 100644 plugins/wasi_nn/MLX/prompt/llama.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 6ef6ab6f..6cc6386a 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -192,10 +192,12 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) simdjson::simdjson ) elseif(BACKEND STREQUAL "mlx") + wasmedge_setup_simdjson() find_package(MLX CONFIG) if(MLX_FOUND) message(STATUS "Found MLX: ${MLX_INCLUDE_DIRS}") else() + # TODO: TEST message(STATUS "MLX not found, downloading from source") include(FetchContent) set(MLX_BUILD_TESTS OFF) @@ -216,19 +218,34 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) INTERFACE_LINK_LIBRARIES "$" ) target_compile_options(mlx - PRIVATE - -Wno-unused-parameter - -Wno-deprecated-copy - -Wno-format - ) - target_compile_options(wasmedgePluginWasiNN PRIVATE -Wno-unused-parameter - ) - target_link_libraries(wasmedgePluginWasiNN PRIVATE - mlx + -Wno-deprecated-copy + -Wno-format ) endif() + add_subdirectory(MLX) + target_include_directories(wasmedgePluginWasiNN PRIVATE MLX/model MLX/prompt MLX/mlx) + + message(STATUS "Downloading tokenizers") + FetchContent_Declare( + tokenizers + GIT_REPOSITORY git@github.com:mlc-ai/tokenizers-cpp.git + GIT_TAG 5de6f65 + GIT_SHALLOW FALSE + ) + FetchContent_MakeAvailable(tokenizers) + target_include_directories(wasmedgePluginWasiNN PRIVATE ${tokenizers_SOURCE_DIR}/include) + target_link_libraries(wasmedgePluginWasiNN PRIVATE tokenizers_cpp) + target_compile_options(wasmedgePluginWasiNN + PRIVATE + -Wno-unused-parameter + ) + target_link_libraries(wasmedgePluginWasiNN PRIVATE + mlx_cpp + mlx + simdjson::simdjson + ) endif() endforeach() diff --git a/plugins/wasi_nn/MLX/CMakeLists.txt b/plugins/wasi_nn/MLX/CMakeLists.txt new file mode 100644 index 00000000..e83328d7 --- /dev/null +++ b/plugins/wasi_nn/MLX/CMakeLists.txt @@ -0,0 +1,34 @@ +wasmedge_add_library( + mlx_cpp + prompt/llama.cpp + model/transformer.cpp + model/converter.cpp + model/utils.cpp + model/registry.cpp + mlx/base.cpp + mlx/linear.cpp + mlx/positional_encoding.cpp + mlx/activations.cpp + mlx/embedding.cpp + mlx/normalization.cpp + mlx/transformer.cpp) + +target_link_libraries(mlx_cpp PUBLIC mlx) +target_include_directories(mlx_cpp PUBLIC ./mlx) +target_include_directories(mlx_cpp PUBLIC ${MLX_INCLUDE_DIRS}) +target_link_libraries(mlx_cpp PUBLIC ${MLX_LIBRARIES}) + +message(STATUS "Downloading gguflib") +FetchContent_Declare( + gguflib + GIT_REPOSITORY https://github.com/antirez/gguf-tools/ + GIT_TAG af7d88d808a7608a33723fba067036202910acb3 + GIT_SHALLOW FALSE +) +FetchContent_MakeAvailable(gguflib) +target_include_directories(mlx_cpp + PRIVATE $) +add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c + ${gguflib_SOURCE_DIR}/gguflib.c) +set_target_properties(gguflib PROPERTIES LINKER_LANGUAGE CXX) +target_link_libraries(mlx_cpp PRIVATE gguflib) diff --git a/plugins/wasi_nn/MLX/mlx/activations.cpp b/plugins/wasi_nn/MLX/mlx/activations.cpp new file mode 100644 index 00000000..cb0b559a --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/activations.cpp @@ -0,0 +1,6 @@ +#include "activations.h" +#include +namespace mlx::core { +mx::array gelu(mx::array X) { return X * (1 + mx::erf(X / std::sqrt(2))) / 2; } +mx::array silu(mx::array X) { return X * mx::sigmoid(X); } +} // namespace mlx::core \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/activations.h b/plugins/wasi_nn/MLX/mlx/activations.h new file mode 100644 index 00000000..f3a02d83 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/activations.h @@ -0,0 +1,8 @@ +#pragma once +#include "base.h" +#include +#include +namespace mlx::core { +mx::array gelu(mx::array X); +mx::array silu(mx::array X); +} // namespace mlx::core \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/base.cpp b/plugins/wasi_nn/MLX/mlx/base.cpp new file mode 100644 index 00000000..03dff5f0 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/base.cpp @@ -0,0 +1,36 @@ +#include "base.h" +#include "../model/utils.h" +#include + +namespace mlx::core::nn { + +mx::array &Module::registerParameter(std::string Name, array &&W) { + Parameters.insert({Name, W}); + return Parameters.at(Name); +} +void Module::update(std::unordered_map Parameters) { + for (auto &[k, v] : Parameters) { + apply(k, v); + } +} +void Module::apply(std::string Key, mx::array Value) { + std::vector SplitKey = splitString(Key, '.'); + if (SplitKey.size() == 1) { + if (Parameters.find(Key) == Parameters.end()) { + throw std::invalid_argument("Unsupported weight: " + Key); + } + this->Parameters.at(Key) = Value; + } else { + std::string LayerName = SplitKey[0]; + SplitKey.erase(SplitKey.begin()); + if (LayerName == "layers") { + LayerName += "." + SplitKey[0]; + SplitKey.erase(SplitKey.begin()); + } + if (Submodules.find(LayerName) == Submodules.end()) { + throw std::invalid_argument("Unsupported Tensor: " + LayerName); + } + Submodules.at(LayerName)->apply(joinString(SplitKey, '.'), Value); + } +} +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/base.h b/plugins/wasi_nn/MLX/mlx/base.h new file mode 100644 index 00000000..ddf68e99 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/base.h @@ -0,0 +1,53 @@ +#pragma once + +#include "mlx/mlx.h" +#include +#include +#include +#include +namespace mx = mlx::core; + +namespace mlx::core::nn { +class Module { +public: + virtual ~Module() { + for (auto Module : Submodules) { + delete Module.second; + } + } + std::string Name; + std::unordered_map Parameters{}; + std::unordered_map Submodules{}; + mx::array ®isterParameter(std::string Name, array &&W); + + void update(std::unordered_map Parameters); + void apply(std::string Key, mx::array Parameters); + template void registerModule(std::string ModuleName, T *M) { + using DecayedT = std::decay_t; + if (!std::is_base_of::value) { + throw std::invalid_argument("Invalid subModule."); + } + + if (Submodules.find(ModuleName) == Submodules.end()) { + Submodules.insert({ModuleName, M}); + Submodules.at(ModuleName)->Name = ModuleName; + } + } + template + void registerLayer(std::string ModuleName, std::vector &Layers) { + if (!std::is_base_of::value) { + throw std::invalid_argument("Invalid subModule."); + } + for (size_t Idx = 0; Idx < Layers.size(); Idx++) { + registerModule(ModuleName + "." + std::to_string(Idx), Layers[Idx]); + } + } +}; +} // namespace mlx::core::nn + +template void printVec(std::vector Ve) { + for (auto I : Ve) { + std::cout << I << " "; + } + std::cout << std::endl; +} \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/embedding.cpp b/plugins/wasi_nn/MLX/mlx/embedding.cpp new file mode 100644 index 00000000..96c05088 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/embedding.cpp @@ -0,0 +1,12 @@ +#include "embedding.h" +#include +#include + +namespace mlx::core::nn { +mx::array Embedding::forward(mx::array Input) { + return take(Parameters.at("weight"), Input, 0); +} +mx::array Embedding::asLinear(mx::array Input) { + return matmul(Input, transpose(Parameters.at("weight"))); +} +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/embedding.h b/plugins/wasi_nn/MLX/mlx/embedding.h new file mode 100644 index 00000000..cd11b26e --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/embedding.h @@ -0,0 +1,18 @@ +#pragma once +#include "base.h" +#include +#include +#include +namespace mlx::core::nn { +class Embedding : public Module { +public: + Embedding(int NumEmbeddings, int Dims) { + const double Scale = std::sqrt(1 / Dims); + registerParameter("weight", + mx::random::normal({NumEmbeddings, Dims}, 0.0, Scale)); + } + mx::array forward(mx::array Input); + mx::array asLinear(mx::array Input); +}; + +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/linear.cpp b/plugins/wasi_nn/MLX/mlx/linear.cpp new file mode 100644 index 00000000..5cd7b0a2 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/linear.cpp @@ -0,0 +1,12 @@ +#include "linear.h" + +namespace mlx::core::nn { +mx::array Linear::forward(mx::array Input) { + if (EnableBias) { + return mx::addmm(Parameters.at("bias"), Input, + transpose(Parameters.at("weight"))); + } + return matmul(Input, transpose(Parameters.at("weight"))); +} + +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/linear.h b/plugins/wasi_nn/MLX/mlx/linear.h new file mode 100644 index 00000000..1b7c5c75 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/linear.h @@ -0,0 +1,29 @@ +#pragma once + +#include "base.h" +#include "mlx/mlx.h" +#include +#include +#include + +namespace mlx::core::nn { + +class Linear : public Module { + bool EnableBias = true; + +public: + Linear(int InputDims, int OutputDims, bool EnableBias = true) + : EnableBias(EnableBias) { + const double Scale = std::sqrt(1.0 / InputDims); + registerParameter( + "weight", mx::random::uniform(-Scale, Scale, {OutputDims, InputDims})); + if (EnableBias) { + registerParameter("bias", mx::random::uniform(-Scale, Scale, + { + OutputDims, + })); + } + } + mx::array forward(mx::array Input); +}; +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/normalization.cpp b/plugins/wasi_nn/MLX/mlx/normalization.cpp new file mode 100644 index 00000000..75bf0708 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/normalization.cpp @@ -0,0 +1,6 @@ +#include "normalization.h" +namespace mlx::core::nn { +mx::array RMSNorm::forward(mx::array Input) { + return mx::fast::rms_norm(Input, Parameters.at("weight"), Eps); +} +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/normalization.h b/plugins/wasi_nn/MLX/mlx/normalization.h new file mode 100644 index 00000000..1e6d9072 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/normalization.h @@ -0,0 +1,14 @@ +#include "base.h" + +namespace mlx::core::nn { +class RMSNorm : public nn::Module { + float Eps; + +public: + RMSNorm(int Dims, float Eps = 1e-5) : Eps(Eps) { + registerParameter("weight", mx::ones({Dims})); + } + mx::array forward(mx::array Input); +}; + +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp b/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp new file mode 100644 index 00000000..fba24b14 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp @@ -0,0 +1,8 @@ +#include "positional_encoding.h" + +namespace mlx::core::nn { +mx::array RoPE::forward(mx::array Input, int Offset) { + return mx::fast::rope(Input, Dims, Tranditional, Base, Scale, Offset); +} + +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/positional_encoding.h b/plugins/wasi_nn/MLX/mlx/positional_encoding.h new file mode 100644 index 00000000..5d997909 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/positional_encoding.h @@ -0,0 +1,20 @@ +#pragma once +#include "base.h" +#include +#include + +namespace mlx::core::nn { +class RoPE : public Module { + int Dims; + bool Tranditional; + float Base; + float Scale; + +public: + RoPE(int Dims, bool Traditional = false, float Base = 10000, + float Scale = 1.0) + : Dims(Dims), Tranditional(Traditional), Base(Base), Scale(Scale) {} + mx::array forward(mx::array Input, int Offset = 0); +}; + +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/transformer.cpp b/plugins/wasi_nn/MLX/mlx/transformer.cpp new file mode 100644 index 00000000..7564043f --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/transformer.cpp @@ -0,0 +1,12 @@ +#include "transformer.h" +#include + +namespace mlx::core::nn { +mx::array MultiHeadAttention::createAdditiveCausalMask(int N, mx::Dtype DType) { + auto Indices = mx::arange(N); + // mask = indices[:, None] < indices[None] + auto Mask = reshape(Indices, {N, 1}) < reshape(Indices, {1, N}); + Mask = astype(Mask, DType) * -1e9; + return Mask; +} +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/transformer.h b/plugins/wasi_nn/MLX/mlx/transformer.h new file mode 100644 index 00000000..4227222c --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/transformer.h @@ -0,0 +1,46 @@ +#pragma once +#include "base.h" +#include "linear.h" +#include + +namespace mlx::core::nn { +class MultiHeadAttention : public Module { + int NumHeads; + +public: + MultiHeadAttention(int Dims, int NumHeads, + std::optional QueryInputDims = {}, + std::optional KeyInputDims = {}, + std::optional ValueInputDims = {}, + std::optional ValueDims = {}, + std::optional ValueOutputDims = {}, bool Bias = false) + : NumHeads(NumHeads) { + if (Dims % NumHeads != 0) { + throw std::invalid_argument("Dims must be divisible by NumHeads"); + } + if (!QueryInputDims) { + QueryInputDims = Dims; + } + if (!KeyInputDims) { + KeyInputDims = Dims; + } + if (!ValueInputDims) { + ValueInputDims = KeyInputDims; + } + if (!ValueDims) { + ValueDims = Dims; + } + if (!ValueOutputDims) { + ValueOutputDims = Dims; + } + registerModule("query_proj", new Linear(*QueryInputDims, Dims, Bias)); + registerModule("key_proj", new Linear(*KeyInputDims, Dims, Bias)); + registerModule("value_proj", new Linear(*ValueInputDims, *ValueDims, Bias)); + registerModule("out_proj", new Linear(*ValueDims, *ValueOutputDims, Bias)); + }; + mx::array forward(mx::array Queries, mx::array Keys, mx::array Values, + mx::array Mask); + static mx::array createAdditiveCausalMask(int N, + mx::Dtype DType = mx::float32); +}; +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/converter.cpp b/plugins/wasi_nn/MLX/model/converter.cpp new file mode 100644 index 00000000..0b82e979 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/converter.cpp @@ -0,0 +1,76 @@ +#include "converter.h" +#include "utils.h" +#include +#include +#include +#include + +std::unordered_map +weightsToMlx(std::string WeightPath) { + const std::filesystem::path Path(WeightPath); + if (std::filesystem::is_directory(Path)) { + std::unordered_map Loaded; + for (const auto &Entry : std::filesystem::directory_iterator(Path)) { + if (Entry.path().extension() == ".safetensors") { + auto SubWeight = weightsToMlx(Entry.path()); + Loaded.insert(SubWeight.begin(), SubWeight.end()); + } + } + return Loaded; + } + if (endsWith(WeightPath, ".safetensors")) { + std::cout << "Loading model from .safetensors file...\n"; + const mx::SafetensorsLoad Loaded = mx::load_safetensors(WeightPath); + return Loaded.first; + } + if (endsWith(WeightPath, ".gguf")) { + std::cout << "Loading model from .gguf file...\n"; + const mx::GGUFLoad Loaded = mx::load_gguf(WeightPath); + return Loaded.first; + } + std::cout << "Can not regonize model file\n"; + throw std::invalid_argument("Invalid model path."); +} + +std::unordered_map +llamaToMlxllm(std::string WeightPath) { + std::unordered_map ModelWeights; + auto Weight = weightsToMlx(WeightPath); + for (auto &[k, v] : Weight) { + std::string NewKey = k; + if (startsWith(NewKey, "model.")) { + strReplace(NewKey, "model.", ""); + } + std::vector SplitKey = splitString(NewKey, '.'); + if (find(SplitKey.begin(), SplitKey.end(), "layers") != SplitKey.end()) { + if (find(SplitKey.begin(), SplitKey.end(), "rotary_emb") != + SplitKey.end()) { + continue; + } + if (find(SplitKey.begin(), SplitKey.end(), "self_attn") != + SplitKey.end()) { + ModelWeights.insert({SplitKey[0] + "." + SplitKey[1] + ".attention." + + SplitKey[3] + "." + SplitKey[4], + v}); + } else if (find(SplitKey.begin(), SplitKey.end(), "mlp") != + SplitKey.end()) { + + ModelWeights.insert({NewKey, v}); + } else { + const std::unordered_map KeyMap = { + {"input_layernorm", "attention_norm"}, + {"post_attention_layernorm", "mlp_norm"}}; + ModelWeights.insert({SplitKey[0] + "." + SplitKey[1] + "." + + KeyMap.at(SplitKey[2]) + "." + SplitKey[3], + v}); + } + } else { + const std::unordered_map KeyMap = { + {"embed_tokens", "token_embed"}, + {"lm_head", "head"}, + {"norm", "norm"}}; + ModelWeights.insert({KeyMap.at(SplitKey[0]) + "." + SplitKey[1], v}); + } + } + return ModelWeights; +} \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/converter.h b/plugins/wasi_nn/MLX/model/converter.h new file mode 100644 index 00000000..c5dbd4b5 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/converter.h @@ -0,0 +1,11 @@ +#pragma once +#include "base.h" +#include "mlx/mlx.h" +#include + +#define strReplace(Str, From, To) Str.replace(Str.find(From), strlen(From), To) + +std::unordered_map weightsToMlx(std::string WeightPath); + +std::unordered_map +llamaToMlxllm(std::string WeightPath); \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/registry.cpp b/plugins/wasi_nn/MLX/model/registry.cpp new file mode 100644 index 00000000..2ff48b92 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/registry.cpp @@ -0,0 +1,22 @@ +#include "registry.h" + +Transformer *llama38b(int VocabSize, float NormEps, float RopeTheta, + bool RopeTraditional) { + return new Transformer(4096, std::vector{14336}, VocabSize, 32, + std::vector{32}, std::vector{8}, NormEps, {}, + RopeTraditional, RopeTheta); +} + +Transformer *llama27bChat(int VocabSize, float NormEps, float RopeTheta, + bool RopeTraditional) { + return new Transformer(4096, std::vector{11008}, VocabSize, 32, + std::vector{32}, std::vector{32}, NormEps, + {}, RopeTraditional, RopeTheta); +} + +Transformer *tinyLlama11BChatV10(int VocabSize, float NormEps, float RopeTheta, + bool RopeTraditional) { + return new Transformer(2048, std::vector{5632}, VocabSize, 22, + std::vector{32}, std::vector{4}, NormEps, {}, + RopeTraditional, RopeTheta); +} \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/registry.h b/plugins/wasi_nn/MLX/model/registry.h new file mode 100644 index 00000000..81774bd0 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/registry.h @@ -0,0 +1,14 @@ +#pragma once + +#include "transformer.h" + +Transformer *llama38b(int VocabSize = 32000, float NormEps = 1e-5, + float RopeTheta = 10000.0, bool RopeTraditional = false); + +Transformer *llama27bChat(int VocabSize = 32000, float NormEps = 1e-5, + float RopeTheta = 10000.0, + bool RopeTraditional = false); + +Transformer *tinyLlama11BChatV10(int VocabSize = 32000, float NormEps = 1e-5, + float RopeTheta = 10000.0, + bool RopeTraditional = false); \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/transformer.cpp b/plugins/wasi_nn/MLX/model/transformer.cpp new file mode 100644 index 00000000..f6271664 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/transformer.cpp @@ -0,0 +1,186 @@ +#include "../mlx/transformer.h" +#include "base.h" +#include "embedding.h" +#include "linear.h" +#include "transformer.h" +#include +#include +#include +#include +#include +#include +#include + +mx::array RMSNorm::forward(mx::array Input) { + return mx::fast::rms_norm(Input, 1.0 + Parameters.at("weight"), Eps); +} +std::tuple> +Attention::forward(mx::array Input, std::optional Mask, + std::optional> KVCache) { + const auto &[B, L, D] = + std::tie(Input.shape()[0], Input.shape()[1], Input.shape()[2]); + mx::array Queries = + dynamic_cast(Submodules["q_proj"])->forward(Input); + mx::array Keys = + dynamic_cast(Submodules["k_proj"])->forward(Input); + mx::array Values = + dynamic_cast(Submodules["v_proj"])->forward(Input); + Queries = transpose(reshape(Queries, {B, L, NHeads, -1}), {0, 2, 1, 3}); + Keys = transpose(reshape(Keys, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); + Values = transpose(reshape(Values, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); + + if (NormQKProj) { + Queries = + dynamic_cast(Submodules["q_norm"])->forward(Queries); + Keys = dynamic_cast(Submodules["k_norm"])->forward(Keys); + } + if (KVCache) { + const auto &[KeyCache, ValueCache] = *KVCache; + Queries = dynamic_cast(Submodules["rope"]) + ->forward(Queries, KeyCache.shape(2)); + Keys = dynamic_cast(Submodules["rope"]) + ->forward(Keys, KeyCache.shape(2)); + Keys = mx::concatenate({KeyCache, Keys}, 2); + Values = mx::concatenate({ValueCache, Values}, 2); + } else { + Queries = dynamic_cast(Submodules["rope"])->forward(Queries); + Keys = dynamic_cast(Submodules["rope"])->forward(Keys); + } + mx::array Output = mx::fast::scaled_dot_product_attention( + Queries, Keys, Values, Scale, Mask); + Output = reshape(transpose(Output, {0, 2, 1, 3}), {B, L, -1}); + return {dynamic_cast(Submodules["o_proj"])->forward(Output), + {Keys, Values}}; +} +mx::array MLP::forward(mx::array Input) { + if (Gemma) { + return dynamic_cast(Submodules["down_proj"]) + ->forward( + gelu(dynamic_cast(Submodules["gate_proj"]) + ->forward(Input)) * + dynamic_cast(Submodules["up_proj"])->forward(Input)); + } + return dynamic_cast(Submodules["down_proj"]) + ->forward( + silu(dynamic_cast(Submodules["gate_proj"]) + ->forward(Input)) * + dynamic_cast(Submodules["up_proj"])->forward(Input)); +} +std::tuple> +TransformerBlock::forward( + mx::array Input, std::optional Mask, + std::optional> KVCachePar) { + mx::array NormOutput = {}; + if (!Gemma) { + NormOutput = dynamic_cast(Submodules["attention_norm"]) + ->forward(Input); + } else { + NormOutput = + dynamic_cast(Submodules["attention_norm"])->forward(Input); + } + auto [R, KVCache] = dynamic_cast(Submodules["attention"]) + ->forward(NormOutput, Mask, KVCachePar); + auto H = Input + R; + if (!Gemma) { + R = dynamic_cast(Submodules["mlp"]) + ->forward(dynamic_cast(Submodules["mlp_norm"]) + ->forward(H)); + } else { + R = dynamic_cast(Submodules["mlp"]) + ->forward( + dynamic_cast(Submodules["mlp_norm"])->forward(H)); + } + return {H + R, KVCache}; +} +std::tuple>>> +Transformer::embed( + mx::array Input, + std::optional>> KVCachePar, + bool Norm) { + mx::array H = dynamic_cast(Submodules["token_embed"]) + ->forward(Input); + if (Gemma) { + H = H * (pow(Dim, 0.5)); + } + std::optional Mask; + if (H.shape()[1] > 1) { + Mask = mx::nn::MultiHeadAttention::createAdditiveCausalMask(H.shape()[1]); + Mask = astype(*Mask, H.dtype()); + } + std::vector> KVCache; + KVCache.reserve(Layers.size()); + for (size_t Idx = 0; Idx < Layers.size(); Idx++) { + std::tuple> Result = {{}, + {{}, {}}}; + if (KVCachePar) { + Result = Layers[Idx]->forward(H, Mask, (*KVCachePar)[Idx]); + } else { + Result = Layers[Idx]->forward(H, Mask, {}); + } + H = std::get<0>(Result); + KVCache.emplace_back(std::get<1>(Result)); + } + if (Norm) { + if (!Gemma) { + return {dynamic_cast(Submodules["norm"])->forward(H), + KVCache}; + } + return {dynamic_cast(Submodules["norm"])->forward(H), KVCache}; + } + return {H, KVCache}; +} +std::tuple>>> +Transformer::forward( + mx::array Input, + std::optional>> KVCachePar) { + auto [X, KVCache] = embed(Input, KVCachePar, true); + mx::array Out = {}; + if (EmbedAsHead) { + Out = dynamic_cast(Submodules["token_embed"]) + ->asLinear(X); + } else { + Out = dynamic_cast(Submodules["head"])->forward(X); + } + return {Out, KVCache}; +} +std::tuple>>> +Transformer::generate(mx::array Input, std::optional Temp) { + // Reshape Input to input[:, None] + std::vector ReshapeDim = Input.shape(); + ReshapeDim.insert(ReshapeDim.begin(), 1); + auto [Logits, KVCache] = forward(reshape(Input, ReshapeDim)); + const int H = Logits.shape()[1] - 1; + // take logits[:, -1, :] + Logits = take(Logits, mx::array({H}), 1); + ReshapeDim = Logits.shape(); + ReshapeDim.erase(ReshapeDim.begin() + 1); + Logits = reshape(Logits, ReshapeDim); + mx::array Y = {}; + if (Temp == 0) { + Y = mx::argmax(Logits, -1); + } else { + Y = mx::random::categorical(Logits * (1.0 / *Temp)); + } + return {Y, KVCache}; +} +std::tuple>>> +Transformer::nextGenerate( + mx::array Y, std::optional Temp, + std::optional>> KVCachePar) { + // Reshape Y to y[:, None] + std::vector ReshapeDim = Y.shape(); + ReshapeDim.insert(ReshapeDim.begin() + 1, 1); + auto [Logits, KVCache] = forward(reshape(Y, ReshapeDim), KVCachePar); + Logits = squeeze(Logits, 1); + mx::array NextY = {}; + if (Temp == 0) { + NextY = mx::argmax(Logits, -1); + } else { + NextY = mx::random::categorical(Logits * (1.0 / *Temp)); + } + return {NextY, KVCache}; +} \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/transformer.h b/plugins/wasi_nn/MLX/model/transformer.h new file mode 100644 index 00000000..4be559d7 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/transformer.h @@ -0,0 +1,198 @@ +#pragma once +#include "activations.h" +#include "base.h" +#include "embedding.h" +#include "linear.h" +#include "normalization.h" +#include "positional_encoding.h" +#include +#include +#include +#include +#include +#include +#include + +namespace nn = mlx::core::nn; + +class RMSNorm : public nn::Module { + float Eps; + +public: + RMSNorm(int Dims, float Eps = 1e-5) : Eps(Eps) { + registerParameter("weight", mx::ones({Dims})); + } + mx::array forward(mx::array Input); +}; +class Attention : public nn::Module { + + int NHeads; + int NKVHeads; + bool NormQKProj; + double Scale; + +public: + Attention(int Dim, int NHeads, int NKVHeads, + std::optional HeadDimPar = 0, bool RopeTraditional = false, + float RopeTheta = 1000, + std::optional> + RopeScaling = {}, + bool NormQKProj = false, float AttentionNormEps = 1e-6) + : NHeads(NHeads), NKVHeads(NKVHeads), NormQKProj(NormQKProj) { + int HeadDim; + if (HeadDimPar) { + HeadDim = *HeadDimPar; + } else { + HeadDim = Dim / NHeads; + } + Scale = pow(HeadDim, -0.5); + registerModule("q_proj", new nn::Linear(Dim, NHeads * HeadDim, false)); + registerModule("k_proj", new nn::Linear(Dim, NKVHeads * HeadDim, false)); + registerModule("v_proj", new nn::Linear(Dim, NKVHeads * HeadDim, false)); + registerModule("o_proj", new nn::Linear(NHeads * HeadDim, Dim, false)); + + if (NormQKProj) { + registerModule("q_norm", new nn::RMSNorm(HeadDim, AttentionNormEps)); + registerModule("k_norm", new nn::RMSNorm(HeadDim, AttentionNormEps)); + } + float RopeScale; + if (RopeScaling && (*RopeScaling)["type"] == "linear") { + RopeScale = 1 / stof((*RopeScaling)["factor"]); + } else { + RopeScale = 1; + } + + registerModule( + "rope", new nn::RoPE(HeadDim, RopeTraditional, RopeTheta, RopeScale)); + } + std::tuple> + forward(mx::array Input, std::optional Mask = {}, + std::optional> KVCache = {}); +}; +class MLP : public nn::Module { + bool Gemma; + +public: + MLP(int Dim, int HiddenDim, bool Gemma = false) : Gemma(Gemma) { + registerModule("gate_proj", new nn::Linear(Dim, HiddenDim, false)); + registerModule("down_proj", new nn::Linear(HiddenDim, Dim, false)); + registerModule("up_proj", new nn::Linear(Dim, HiddenDim, false)); + } + mx::array forward(mx::array Input); +}; +class TransformerBlock : public nn::Module { + bool Gemma; + +public: + TransformerBlock(int Dim, int NHeads, int NKVHeads, int HiddenDim, + float NormEps, std::optional HeadDim = {}, + bool RopeTraditional = false, float RopeTheta = 1000, + std::optional> + RopeScaling = {}, + bool NormQKProj = false, float AttentionNormEps = 1e-6, + bool Gemma = false) + : Gemma(Gemma) { + registerModule("attention", + new Attention(Dim, NHeads, NKVHeads, HeadDim, + RopeTraditional, RopeTheta, RopeScaling, + NormQKProj, AttentionNormEps)); + registerModule("mlp", new MLP(Dim, HiddenDim, Gemma)); + if (!Gemma) { + registerModule("attention_norm", new nn::RMSNorm(Dim, NormEps)); + registerModule("mlp_norm", new nn::RMSNorm(Dim, NormEps)); + } else { + registerModule("attention_norm", new RMSNorm(Dim, NormEps)); + registerModule("mlp_norm", new RMSNorm(Dim, NormEps)); + } + } + std::tuple> + forward(mx::array Input, std::optional Mask = {}, + std::optional> KVCachePar = {}); +}; +class Transformer : public nn::Module { + int Dim; + std::optional> HiddenDim; + bool Gemma; + bool EmbedAsHead; + std::vector Layers{}; + +public: + Transformer( + int Dim, std::optional> HiddenDim, int VocabSize, + int NLayers, std::optional> NHeads, + std::optional> NKVHeads = {}, float NormEps = 1e-5, + std::optional HeadDim = {}, bool RopeTraditional = false, + float RopeTheta = 1000, + std::optional>> + RopeScaling = {}, + bool NormQKProj = false, float AttentionNormEps = 1e-6, + bool Gemma = false, bool EmbedAsHeadPar = false) + : Dim(Dim), HiddenDim(HiddenDim), Gemma(Gemma), + EmbedAsHead(EmbedAsHeadPar) { + if (VocabSize <= 0) { + throw std::invalid_argument("VocabSize must be greater than 0."); + } + EmbedAsHead = Gemma ? true : EmbedAsHead; + if (!NKVHeads) { + NKVHeads = NHeads; + } + registerModule("token_embed", new nn::Embedding(VocabSize, Dim)); + if (HiddenDim->size() == 1) { + while (static_cast(HiddenDim->size()) < NLayers) { + HiddenDim->emplace_back((*HiddenDim)[0]); + } + } + if (NHeads->size() == 1) { + while (static_cast(NHeads->size()) < NLayers) { + NHeads->emplace_back((*NHeads)[0]); + } + } + if (NKVHeads->size() == 1) { + while (static_cast(NKVHeads->size()) < NLayers) { + NKVHeads->emplace_back((*NKVHeads)[0]); + } + } + Layers.reserve(NLayers); + for (int Idx = 0; Idx < NLayers; Idx++) { + if (RopeScaling) { + Layers.push_back(new TransformerBlock( + Dim, (*NHeads)[Idx], (*NKVHeads)[Idx], (*HiddenDim)[Idx], NormEps, + HeadDim, RopeTraditional, RopeTheta, (*RopeScaling)[Idx], + NormQKProj, AttentionNormEps, Gemma)); + } else { + Layers.push_back(new TransformerBlock( + Dim, (*NHeads)[Idx], (*NKVHeads)[Idx], (*HiddenDim)[Idx], NormEps, + HeadDim, RopeTraditional, RopeTheta, {}, NormQKProj, + AttentionNormEps, Gemma)); + } + } + registerLayer("layers", Layers); + if (!Gemma) { + registerModule("norm", new nn::RMSNorm(Dim, NormEps)); + } else { + registerModule("norm", new RMSNorm(Dim, NormEps)); + } + if (!EmbedAsHead) { + registerModule("head", new nn::Linear(Dim, VocabSize, false)); + } + } + std::tuple>>> + embed(mx::array Input, + std::optional>> + KVCachePar = {}, + bool Norm = false); + std::tuple>>> + forward(mx::array Input, + std::optional>> + KVCachePar = {}); + std::tuple>>> + generate(mx::array Input, std::optional Temp = 0.0); + std::tuple>>> + nextGenerate(mx::array Y, std::optional Temp = 0.0, + std::optional>> + KVCachePar = {}); +}; \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/utils.cpp b/plugins/wasi_nn/MLX/model/utils.cpp new file mode 100644 index 00000000..283a6059 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/utils.cpp @@ -0,0 +1,32 @@ +#include "utils.h" +#include +std::vector splitString(const std::string &S, char Delim) { + std::vector Result; + std::stringstream SS(S); + std::string Item; + while (std::getline(SS, Item, Delim)) { + Result.emplace_back(Item); + } + return Result; +} +std::string joinString(std::vector &S, char Delim) { + std::string Result; + for (size_t Idx = 0; Idx < S.size(); Idx++) { + Result += S[Idx]; + if (Idx != S.size() - 1) { + Result += Delim; + } + } + return Result; +} + +bool endsWith(std::string const &Value, std::string const &Ending) { + if (Ending.size() > Value.size()) + return false; + return std::equal(Ending.rbegin(), Ending.rend(), Value.rbegin()); +} +bool startsWith(std::string const &Value, std::string const &Starting) { + if (Starting.size() > Value.size()) + return false; + return std::equal(Starting.begin(), Starting.end(), Value.begin()); +} \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/utils.h b/plugins/wasi_nn/MLX/model/utils.h new file mode 100644 index 00000000..5c3d9682 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/utils.h @@ -0,0 +1,8 @@ +#pragma once +#include +#include +#include +std::vector splitString(const std::string &S, char Delim); +std::string joinString(std::vector &S, char Delim); +bool endsWith(std::string const &Value, std::string const &Ending); +bool startsWith(std::string const &Value, std::string const &Starting); \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/prompt/llama.cpp b/plugins/wasi_nn/MLX/prompt/llama.cpp new file mode 100644 index 00000000..81b41411 --- /dev/null +++ b/plugins/wasi_nn/MLX/prompt/llama.cpp @@ -0,0 +1,5 @@ +#include "llama.h" + +std::string TinyLLaMAPrompt::prepare(std::string Prompt) { + return SystemStart + TextEnd + Prompt + TextEnd + Assistant; +} \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/prompt/llama.h b/plugins/wasi_nn/MLX/prompt/llama.h new file mode 100644 index 00000000..21261819 --- /dev/null +++ b/plugins/wasi_nn/MLX/prompt/llama.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +class BasePrompt { +public: + std::string SystemStart; + std::string User; + std::string Assistant; + std::string TextEnd; + virtual std::string prepare(std::string Prompt) { + return SystemStart + TextEnd + Prompt + TextEnd + Assistant; + }; +}; + +class TinyLLaMAPrompt : public BasePrompt { +public: + std::string SystemStart = "<|system|>"; + std::string User = "<|user|>"; + std::string Assistant = "<|assistant|>"; + std::string TextEnd = ""; + std::string prepare(std::string Prompt); +}; \ No newline at end of file diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp index e1e48c1e..a4ead7ce 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/mlx.cpp @@ -1,15 +1,173 @@ #include "mlx.h" #include "wasinnenv.h" +#include +#include +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX +#include "converter.h" +#include "registry.h" +#include "simdjson.h" +#include "utils.h" +#endif namespace WasmEdge::Host::WASINN::MLX { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX +std::string loadBytesFromFile(const std::string &Path) { + std::ifstream Fs(Path, std::ios::in | std::ios::binary); + if (Fs.fail()) { + std::cerr << "Cannot open " << Path << std::endl; + exit(1); + } + std::string Data; + Fs.seekg(0, std::ios::end); + const size_t Size = static_cast(Fs.tellg()); + Fs.seekg(0, std::ios::beg); + Data.resize(Size); + Fs.read(Data.data(), Size); + return Data; +} +enum AnserSataus { + STOP, + WAIT, + GO, +}; +AnserSataus answerSataus(std::string Text, std::string End) { + if (endsWith(Text, End)) { + return STOP; + } + for (int Idx = 1; Idx < static_cast(End.size()); Idx++) { + if (endsWith(Text, End.substr(0, Idx))) { + return WAIT; + } + } + return GO; +} Expect load(WASINN::WasiNNEnvironment &Env, - Span>, WASINN::Device, + Span> Builders, WASINN::Device, uint32_t &GraphId) noexcept { // Add a new graph. - // Env.NNGraph.emplace_back(Backend::MLX); - // auto &GraphRef = Env.NNGraph.back().get(); - + Env.NNGraph.emplace_back(Backend::MLX); + auto &GraphRef = Env.NNGraph.back().get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] MLX backend: Load."sv); + } + std::string TokenizerPath; + // Parse metadata. + if (Builders.size() > 1) { + const std::string Metadata = std::string( + reinterpret_cast(Builders[1].data()), Builders[1].size()); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + spdlog::error("[WASI-NN] MLX backend: Parse metadata error"sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidEncoding; + } + if (Doc.at_key("model_type").error() == simdjson::SUCCESS) { + std::string_view ModelType; + auto Err = Doc["model_type"].get().get(ModelType); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the model_type option."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + GraphRef.ModelType = ModelType; + } + if (Doc.at_key("enable_debug_log").error() == simdjson::SUCCESS) { + bool EnableDebugLog; + auto Err = Doc["enable_debug_log"].get().get(EnableDebugLog); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the enable_debug_log option."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + GraphRef.EnableDebugLog = EnableDebugLog; + } + if (Doc.at_key("tokenizer_path").error() == simdjson::SUCCESS) { + std::string_view TokenizerPathView; + auto Err = + Doc["tokenizer_path"].get().get(TokenizerPathView); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the tokenizer_path option."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + TokenizerPath = TokenizerPathView; + } + } + + // Load tokenizer. + if (!TokenizerPath.empty()) { + GraphRef.Tok = + tokenizers::Tokenizer::FromBlobJSON(loadBytesFromFile(TokenizerPath)); + } + + // Handle the model path. + auto Weight = Builders[0]; + const std::string BinModel(reinterpret_cast(Weight.data()), + Weight.size()); + spdlog::info("[WASI-NN] Neural speed: BinModel: {}"sv, BinModel.size()); + std::string ModelFilePath; + if (BinModel.substr(0, 8) == "preload:"sv) { + ModelFilePath = BinModel.substr(8); + } else { + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] MLX backend: Model path not found in nn-preload, " + "write model into a tmpfile."sv); + } + // Write model to file. + ModelFilePath = "MLX.bin"sv; + std::ofstream TempFile(ModelFilePath, std::ios::out | std::ios::binary); + if (!TempFile) { + spdlog::error( + "[WASI-NN] MLX backend: Failed to create the temporary file. " + "Currently, our workaround involves creating a temporary model " + "file named \"MLX.bin\" and passing this filename as a " + "parameter to the ggml llama library."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + TempFile.write(BinModel.data(), BinModel.size()); + TempFile.close(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] MLX backend: Write model into a tmpfile...Done"sv); + } + } + + // Create Model. + if (GraphRef.Model == nullptr) { + if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { + GraphRef.Model = tinyLlama11BChatV10(); + GraphRef.Prmopt = TinyLLaMAPrompt(); + } else if (GraphRef.ModelType == "llama_3_8b") { + GraphRef.Model = llama38b(); + } else if (GraphRef.ModelType == "llama_2_7b_chat_hf") { + GraphRef.Model = llama27bChat(); + } else { + spdlog::error("[WASI-NN] MLX backend: Model type not supported."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + } + + // Load weight. + if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { + GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); + } else if (GraphRef.ModelType == "llama_3_8b") { + GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); + } else if (GraphRef.ModelType == "llama_2_7b_chat_hf") { + GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); + } else { + spdlog::error("[WASI-NN] MLX backend: Model type not supported."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + return WASINN::ErrNo::Success; } @@ -23,24 +181,75 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t Index, const TensorData &Tensor) noexcept { - // auto &CxtRef = Env.NNContext[ContextId].get(); - // auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] MLX backend: setInput"sv); + } + CxtRef.Inputs = + std::string(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); return WASINN::ErrNo::Success; } Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t Index, Span OutBuffer, uint32_t &BytesWritten) noexcept { - // auto &CxtRef = Env.NNContext[ContextId].get(); - // auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] MLX backend: getOutput"sv); + } + std::string StringTmp(reinterpret_cast(CxtRef.Outputs.data()), + CxtRef.Outputs.size() * sizeof(long long int)); + std::copy_n(StringTmp.data(), StringTmp.length(), OutBuffer.data()); + BytesWritten = StringTmp.length(); return WASINN::ErrNo::Success; } Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { - // auto &CxtRef = Env.NNContext[ContextId].get(); - // auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + if (GraphRef.Tok == nullptr) { + spdlog::error("[WASI-NN] MLX backend: Tokenizer not loaded."sv); + return ErrNo::InvalidArgument; + } + if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] MLX backend: compute"sv); + } + const std::vector Ids = GraphRef.Tok->Encode(CxtRef.Inputs); + auto Token = mx::array(Ids.data(), {static_cast(Ids.size())}, mx::int32); + std::vector TokenList; + std::string Answer; + int Skip = 0; + int TokenCount = 0; + auto [Y, KVCache] = GraphRef.Model->generate(Token, 0.1); + while (true) { + TokenCount++; + if (TokenCount > GraphRef.MaxToken) { + break; + } + eval(Y); + std::vector Tokens; + auto *Data = Y.data(); + for (int Idx = 0; Idx < static_cast(Y.size()); Idx++) { + Tokens.emplace_back(Data[Idx]); + } + // TODO: break when the token is the eos_token_id + TokenList.insert(TokenList.end(), Tokens.begin(), Tokens.end()); + Answer = GraphRef.Tok->Decode(TokenList); + const AnserSataus Status = answerSataus(Answer, GraphRef.Prmopt.TextEnd); + if (Status == STOP) { + break; + } + if (Status == GO) { + CxtRef.Outputs += Answer.substr(Skip); + Skip = Answer.size(); + } + auto [NY, NKVCache] = + GraphRef.Model->nextGenerate(Y, GraphRef.Temp, KVCache); + Y = NY, KVCache = NKVCache; + } return WASINN::ErrNo::Success; } #else diff --git a/plugins/wasi_nn/mlx.h b/plugins/wasi_nn/mlx.h index 56c61674..f39a17d7 100644 --- a/plugins/wasi_nn/mlx.h +++ b/plugins/wasi_nn/mlx.h @@ -2,7 +2,14 @@ #include "plugin/plugin.h" #include "types.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX +#include "MLX/model/transformer.h" +#include "llama.h" +#include "transformer.h" #include +#include +#endif namespace WasmEdge::Host::WASINN { struct WasiNNEnvironment; @@ -11,13 +18,20 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::MLX { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX struct Graph { - mlx::core::StreamOrDevice MLXDevice = mlx::core::metal::is_available() - ? mlx::core::Device::gpu - : mlx::core::Device::cpu; + std::string ModelType = "tiny_llama_1.1B_chat_v1.0"; + std::unique_ptr Tok; + Transformer *Model; + inline static int GraphNumber = 0; + double Temp = 0.0; + bool EnableDebugLog = true; + int MaxToken = 1024; + BasePrompt Prmopt; }; struct Context { Context(size_t Gid, Graph &) noexcept : GraphId(Gid) {} size_t GraphId; + std::string Inputs; + std::string Outputs; }; #else struct Graph {}; From 3c2f68978897a3be2f9b88e3568f3757c4b034e9 Mon Sep 17 00:00:00 2001 From: grorge Date: Mon, 26 Aug 2024 23:11:43 +0800 Subject: [PATCH 442/623] [WASI-NN] mlx: add test Signed-off-by: grorge --- plugins/wasi_nn/mlx.cpp | 3 +- plugins/wasi_nn/mlx.h | 11 +- test/plugins/wasi_nn/CMakeLists.txt | 12 ++ test/plugins/wasi_nn/wasi_nn.cpp | 249 +++++++++++++++++++++++++++- 4 files changed, 271 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp index a4ead7ce..48a04f52 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/mlx.cpp @@ -102,7 +102,8 @@ Expect load(WASINN::WasiNNEnvironment &Env, // Load tokenizer. if (!TokenizerPath.empty()) { GraphRef.Tok = - tokenizers::Tokenizer::FromBlobJSON(loadBytesFromFile(TokenizerPath)); + tokenizers::Tokenizer::FromBlobJSON(loadBytesFromFile(TokenizerPath)) + .release(); } // Handle the model path. diff --git a/plugins/wasi_nn/mlx.h b/plugins/wasi_nn/mlx.h index f39a17d7..d07999b9 100644 --- a/plugins/wasi_nn/mlx.h +++ b/plugins/wasi_nn/mlx.h @@ -18,10 +18,17 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::MLX { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX struct Graph { + ~Graph() noexcept { + if (Model != nullptr) { + delete Model; + } + if (Tok != nullptr) { + delete Tok; + } + } std::string ModelType = "tiny_llama_1.1B_chat_v1.0"; - std::unique_ptr Tok; + tokenizers::Tokenizer *Tok = nullptr; Transformer *Model; - inline static int GraphNumber = 0; double Temp = 0.0; bool EnableDebugLog = true; int MaxToken = 1024; diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 88fad025..8d21ab61 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -106,6 +106,18 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) ${CMAKE_CURRENT_BINARY_DIR}/wasinn_whisper_fixtures/test.wav MD5=6cf3f7af1ebbd6b29c373e526b548dba ) + elseif(BACKEND STREQUAL "mlx") + message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_mlx_fixtures") + download( + https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/model.safetensors + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_mlx_fixtures/model.safetensors + MD5=59e1605b3af5f1673eb8396251d6bc46 + ) + download( + https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_mlx_fixtures/tokenizer.json + MD5=c9dc953a24ad2b76b4bae4bf456f18bd + ) else() # Add the other backend test files fetching here. endif() diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 0ed4990b..17473278 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -26,7 +26,8 @@ using WasmEdge::Host::WASINN::ErrNo; defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER) || \ - defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX) namespace { template @@ -2706,3 +2707,249 @@ TEST(WasiNNTest, ChatTTSBackend) { } } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX +TEST(WasiNNTest, MLXBackend) { + // Create the wasi_nn module instance. + auto *NNMod = dynamic_cast(createModule()); + ASSERT_TRUE(NNMod != nullptr); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // Load the files. + std::string Prompt = "How are you?"; + std::string Tokenizer = "./wasinn_mlx_fixtures/tokenizer.json"; + std::vector TensorData(Prompt.begin(), Prompt.end()); + std::vector WeightRead = + readEntireFile("./wasinn_mlx_fixtures/model.safetensors"); + + std::vector TensorDim{1}; + uint32_t BuilderPtr = UINT32_C(0); + uint32_t LoadEntryPtr = UINT32_C(0); + uint32_t SetInputEntryPtr = UINT32_C(0); + uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + uint32_t StorePtr = UINT32_C(65536); + + // Return value. + std::array Errno = {UINT32_C(0)}; + + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + + // MLX WASI-NN load tests. + // Test: load -- meaningless binaries. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::MLX), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- graph id ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::MLX), + static_cast(Device::CPU), OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + OutBoundPtr, UINT32_C(1), static_cast(Backend::MLX), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: load -- MLX model bin ptr out of bounds. + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, + static_cast(WeightRead.size()), BuilderPtr); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), static_cast(Backend::MLX), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- load successfully. + std::string Config = "{\"model_type\":\"tiny_llama_1.1B_chat_v1.0\", " + "\"tokenizer\":\"" + + Tokenizer + "\"}"; + std::vector ConfigData(Config.begin(), Config.end()); + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr + WeightRead.size(), ConfigData.size(), + BuilderPtr); + writeBinaries(MemInst, WeightRead, StorePtr); + writeBinaries(MemInst, ConfigData, StorePtr + WeightRead.size()); + StorePtr += WeightRead.size() + ConfigData.size(); + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(2), static_cast(Backend::MLX), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // MLX WASI-NN init_execution_context tests. + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(2), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: init_execution_context -- init second context. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 0); + BuilderPtr += 4; + } + + // MLX WASI-NN set_input tests. + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, UINT32_C(1), BuilderPtr); + writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), + BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries(MemInst, TensorData, StorePtr + TensorDim.size() * 4); + + // Test: set_input -- context id exceeds. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(3), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += (TensorDim.size() * 4 + TensorData.size()); + + // MLX WASI-NN compute tests. + // Test: compute -- context id exceeds. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(3)}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(0)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // MLX WASI-NN get_output tests. + // Test: get_output -- output bytes ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: get_output -- output buffer ptr out of bounds. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), OutBoundPtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(0), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Should output more than 50 bytes. + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 50); + } + + delete NNMod; +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX \ No newline at end of file From 2d2d09af7fe94a659bd215143b28b86cc65b353d Mon Sep 17 00:00:00 2001 From: grorge Date: Tue, 27 Aug 2024 15:46:24 +0800 Subject: [PATCH 443/623] [WASI-NN] mlx: handle lake weight Signed-off-by: grorge --- plugins/wasi_nn/MLX/CMakeLists.txt | 2 +- .../MLX/prompt/{llama.cpp => prompt.cpp} | 2 +- .../wasi_nn/MLX/prompt/{llama.h => prompt.h} | 12 +++++----- plugins/wasi_nn/mlx.cpp | 22 ++++++++++--------- plugins/wasi_nn/mlx.h | 2 +- 5 files changed, 22 insertions(+), 18 deletions(-) rename plugins/wasi_nn/MLX/prompt/{llama.cpp => prompt.cpp} (86%) rename plugins/wasi_nn/MLX/prompt/{llama.h => prompt.h} (64%) diff --git a/plugins/wasi_nn/MLX/CMakeLists.txt b/plugins/wasi_nn/MLX/CMakeLists.txt index e83328d7..d05d3069 100644 --- a/plugins/wasi_nn/MLX/CMakeLists.txt +++ b/plugins/wasi_nn/MLX/CMakeLists.txt @@ -1,6 +1,6 @@ wasmedge_add_library( mlx_cpp - prompt/llama.cpp + prompt/prompt.cpp model/transformer.cpp model/converter.cpp model/utils.cpp diff --git a/plugins/wasi_nn/MLX/prompt/llama.cpp b/plugins/wasi_nn/MLX/prompt/prompt.cpp similarity index 86% rename from plugins/wasi_nn/MLX/prompt/llama.cpp rename to plugins/wasi_nn/MLX/prompt/prompt.cpp index 81b41411..4904abbd 100644 --- a/plugins/wasi_nn/MLX/prompt/llama.cpp +++ b/plugins/wasi_nn/MLX/prompt/prompt.cpp @@ -1,4 +1,4 @@ -#include "llama.h" +#include "prompt.h" std::string TinyLLaMAPrompt::prepare(std::string Prompt) { return SystemStart + TextEnd + Prompt + TextEnd + Assistant; diff --git a/plugins/wasi_nn/MLX/prompt/llama.h b/plugins/wasi_nn/MLX/prompt/prompt.h similarity index 64% rename from plugins/wasi_nn/MLX/prompt/llama.h rename to plugins/wasi_nn/MLX/prompt/prompt.h index 21261819..e16c61bb 100644 --- a/plugins/wasi_nn/MLX/prompt/llama.h +++ b/plugins/wasi_nn/MLX/prompt/prompt.h @@ -16,9 +16,11 @@ class BasePrompt { class TinyLLaMAPrompt : public BasePrompt { public: - std::string SystemStart = "<|system|>"; - std::string User = "<|user|>"; - std::string Assistant = "<|assistant|>"; - std::string TextEnd = ""; - std::string prepare(std::string Prompt); + TinyLLaMAPrompt() { + SystemStart = "<|system|>"; + Assistant = "<|assistant|>"; + User = "<|user|>"; + TextEnd = ""; + } + std::string prepare(std::string Prompt) override; }; \ No newline at end of file diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp index 48a04f52..eca7622d 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/mlx.cpp @@ -1,4 +1,5 @@ #include "mlx.h" +#include "spdlog/spdlog.h" #include "wasinnenv.h" #include #include @@ -85,13 +86,13 @@ Expect load(WASINN::WasiNNEnvironment &Env, } GraphRef.EnableDebugLog = EnableDebugLog; } - if (Doc.at_key("tokenizer_path").error() == simdjson::SUCCESS) { + if (Doc.at_key("tokenizer").error() == simdjson::SUCCESS) { std::string_view TokenizerPathView; auto Err = - Doc["tokenizer_path"].get().get(TokenizerPathView); + Doc["tokenizer"].get().get(TokenizerPathView); if (Err) { spdlog::error( - "[WASI-NN] MLX backend: Unable to retrieve the tokenizer_path option."sv); + "[WASI-NN] MLX backend: Unable to retrieve the tokenizer option."sv); Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } @@ -110,7 +111,11 @@ Expect load(WASINN::WasiNNEnvironment &Env, auto Weight = Builders[0]; const std::string BinModel(reinterpret_cast(Weight.data()), Weight.size()); - spdlog::info("[WASI-NN] Neural speed: BinModel: {}"sv, BinModel.size()); + spdlog::info("[WASI-NN] MLX BinModel: {}"sv, BinModel.size()); + if (BinModel.size() == 0) { + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } std::string ModelFilePath; if (BinModel.substr(0, 8) == "preload:"sv) { ModelFilePath = BinModel.substr(8); @@ -121,14 +126,12 @@ Expect load(WASINN::WasiNNEnvironment &Env, "write model into a tmpfile."sv); } // Write model to file. - ModelFilePath = "MLX.bin"sv; + // TODO: handle different model format. + ModelFilePath = "MLX.safetensors"sv; std::ofstream TempFile(ModelFilePath, std::ios::out | std::ios::binary); if (!TempFile) { spdlog::error( - "[WASI-NN] MLX backend: Failed to create the temporary file. " - "Currently, our workaround involves creating a temporary model " - "file named \"MLX.bin\" and passing this filename as a " - "parameter to the ggml llama library."sv); + "[WASI-NN] MLX backend: Failed to create the temporary file. "sv); Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } @@ -155,7 +158,6 @@ Expect load(WASINN::WasiNNEnvironment &Env, return ErrNo::InvalidArgument; } } - // Load weight. if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); diff --git a/plugins/wasi_nn/mlx.h b/plugins/wasi_nn/mlx.h index d07999b9..4b40c473 100644 --- a/plugins/wasi_nn/mlx.h +++ b/plugins/wasi_nn/mlx.h @@ -5,7 +5,7 @@ #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX #include "MLX/model/transformer.h" -#include "llama.h" +#include "prompt.h" #include "transformer.h" #include #include From 18751205ce334f67726867950844d0c5f5c5dfae Mon Sep 17 00:00:00 2001 From: grorge Date: Sun, 8 Sep 2024 14:19:20 +0800 Subject: [PATCH 444/623] [WASI-NN] mlx: handle different model type Signed-off-by: grorge --- plugins/wasi_nn/MLX/prompt/prompt.cpp | 9 ++ plugins/wasi_nn/MLX/prompt/prompt.h | 38 ++++- plugins/wasi_nn/mlx.cpp | 220 +++++++++++++++----------- plugins/wasi_nn/mlx.h | 4 +- test/plugins/wasi_nn/wasi_nn.cpp | 4 +- 5 files changed, 171 insertions(+), 104 deletions(-) diff --git a/plugins/wasi_nn/MLX/prompt/prompt.cpp b/plugins/wasi_nn/MLX/prompt/prompt.cpp index 4904abbd..9434b04e 100644 --- a/plugins/wasi_nn/MLX/prompt/prompt.cpp +++ b/plugins/wasi_nn/MLX/prompt/prompt.cpp @@ -1,5 +1,14 @@ #include "prompt.h" +#include std::string TinyLLaMAPrompt::prepare(std::string Prompt) { return SystemStart + TextEnd + Prompt + TextEnd + Assistant; +} +std::string LLaMA2Prompt::prepare(std::string Prompt) { + return InstStart + SysStart + SysEnd + Prompt + TextEnd; +} +std::string LLaMA3Prompt::prepare(std::string Prompt) { + return PropmtStart + StartHeader + "system" + EndHeader + TextEnd + Prompt + + EndHeader + TextEnd + StartHeader + "user" + EndHeader + Prompt + + TextEnd + StartHeader + "assistant" + EndHeader; } \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/prompt/prompt.h b/plugins/wasi_nn/MLX/prompt/prompt.h index e16c61bb..06eee87a 100644 --- a/plugins/wasi_nn/MLX/prompt/prompt.h +++ b/plugins/wasi_nn/MLX/prompt/prompt.h @@ -2,20 +2,19 @@ #include #include +#include class BasePrompt { public: - std::string SystemStart; - std::string User; - std::string Assistant; std::string TextEnd; - virtual std::string prepare(std::string Prompt) { - return SystemStart + TextEnd + Prompt + TextEnd + Assistant; - }; + virtual std::string prepare(std::string Prompt) { return Prompt + TextEnd; }; }; class TinyLLaMAPrompt : public BasePrompt { public: + std::string SystemStart; + std::string User; + std::string Assistant; TinyLLaMAPrompt() { SystemStart = "<|system|>"; Assistant = "<|assistant|>"; @@ -23,4 +22,31 @@ class TinyLLaMAPrompt : public BasePrompt { TextEnd = ""; } std::string prepare(std::string Prompt) override; +}; + +class LLaMA2Prompt : public BasePrompt { +public: + std::string SysStart; + std::string SysEnd; + std::string InstStart; + LLaMA2Prompt() { + SysStart = "<>"; + SysEnd = "<>"; + InstStart = "[INST]"; + TextEnd = "[/INST]"; + } + std::string prepare(std::string Prompt) override; +}; +class LLaMA3Prompt : public BasePrompt { +public: + std::string PropmtStart; + std::string StartHeader; + std::string EndHeader; + LLaMA3Prompt() { + PropmtStart = "<|begin_of_text|>"; + StartHeader = "<|start_header_id|>"; + EndHeader = "<|end_header_id|>"; + TextEnd = "<|eot_id|>"; + } + std::string prepare(std::string Prompt) override; }; \ No newline at end of file diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp index eca7622d..807e3d91 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/mlx.cpp @@ -1,14 +1,17 @@ #include "mlx.h" +#include "prompt.h" #include "spdlog/spdlog.h" #include "wasinnenv.h" +#include <_types/_uint8_t.h> +#include #include #include #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX #include "converter.h" #include "registry.h" -#include "simdjson.h" #include "utils.h" +#include #endif namespace WasmEdge::Host::WASINN::MLX { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX @@ -53,51 +56,75 @@ Expect load(WASINN::WasiNNEnvironment &Env, } std::string TokenizerPath; // Parse metadata. - if (Builders.size() > 1) { - const std::string Metadata = std::string( - reinterpret_cast(Builders[1].data()), Builders[1].size()); - simdjson::dom::parser Parser; - simdjson::dom::element Doc; - auto ParseError = Parser.parse(Metadata).get(Doc); - if (ParseError) { - spdlog::error("[WASI-NN] MLX backend: Parse metadata error"sv); + if (Builders.size() <= 1) { + spdlog::error( + "[WASI-NN] MLX backend: Lack necessary metadata(tokenizer, model_type)."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + const std::string Metadata = std::string( + reinterpret_cast(Builders.back().data()), Builders.back().size()); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + spdlog::error("[WASI-NN] MLX backend: Parse metadata error"sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidEncoding; + } + if (Doc.at_key("model_type").error() == simdjson::SUCCESS) { + std::string_view ModelType; + auto Err = Doc["model_type"].get().get(ModelType); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the model_type option."sv); Env.NNGraph.pop_back(); - return ErrNo::InvalidEncoding; + return ErrNo::InvalidArgument; } - if (Doc.at_key("model_type").error() == simdjson::SUCCESS) { - std::string_view ModelType; - auto Err = Doc["model_type"].get().get(ModelType); - if (Err) { - spdlog::error( - "[WASI-NN] MLX backend: Unable to retrieve the model_type option."sv); - Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; - } - GraphRef.ModelType = ModelType; + GraphRef.ModelType = ModelType; + } else { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the model_type option."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + if (Doc.at_key("enable_debug_log").error() == simdjson::SUCCESS) { + bool EnableDebugLog; + auto Err = Doc["enable_debug_log"].get().get(EnableDebugLog); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the enable_debug_log option."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; } - if (Doc.at_key("enable_debug_log").error() == simdjson::SUCCESS) { - bool EnableDebugLog; - auto Err = Doc["enable_debug_log"].get().get(EnableDebugLog); - if (Err) { - spdlog::error( - "[WASI-NN] MLX backend: Unable to retrieve the enable_debug_log option."sv); - Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; - } - GraphRef.EnableDebugLog = EnableDebugLog; + GraphRef.EnableDebugLog = EnableDebugLog; + } + if (Doc.at_key("tokenizer").error() == simdjson::SUCCESS) { + std::string_view TokenizerPathView; + auto Err = Doc["tokenizer"].get().get(TokenizerPathView); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the tokenizer option."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; } - if (Doc.at_key("tokenizer").error() == simdjson::SUCCESS) { - std::string_view TokenizerPathView; - auto Err = - Doc["tokenizer"].get().get(TokenizerPathView); - if (Err) { - spdlog::error( - "[WASI-NN] MLX backend: Unable to retrieve the tokenizer option."sv); - Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; - } - TokenizerPath = TokenizerPathView; + TokenizerPath = TokenizerPathView; + } else { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the tokenizer option."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + if (Doc.at_key("max_token").error() == simdjson::SUCCESS) { + uint64_t MaxToken; + auto Err = Doc["max_token"].get().get(MaxToken); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the max_token option."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; } + GraphRef.MaxToken = MaxToken; } // Load tokenizer. @@ -105,71 +132,77 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.Tok = tokenizers::Tokenizer::FromBlobJSON(loadBytesFromFile(TokenizerPath)) .release(); + } else { + spdlog::error("[WASI-NN] MLX backend: Tokenizer path not found."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; } - // Handle the model path. - auto Weight = Builders[0]; - const std::string BinModel(reinterpret_cast(Weight.data()), - Weight.size()); - spdlog::info("[WASI-NN] MLX BinModel: {}"sv, BinModel.size()); - if (BinModel.size() == 0) { + // Create Model. + if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { + GraphRef.Model = tinyLlama11BChatV10(); + GraphRef.Prmopt = TinyLLaMAPrompt(); + } else if (GraphRef.ModelType == "llama_3_8b") { + GraphRef.Model = llama38b(); + GraphRef.Prmopt = LLaMA3Prompt(); + } else if (GraphRef.ModelType == "llama_2_7b_chat_hf") { + GraphRef.Model = llama27bChat(); + GraphRef.Prmopt = LLaMA2Prompt(); + } else { + spdlog::error("[WASI-NN] MLX backend: Model type not supported."sv); Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } - std::string ModelFilePath; - if (BinModel.substr(0, 8) == "preload:"sv) { - ModelFilePath = BinModel.substr(8); - } else { - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] MLX backend: Model path not found in nn-preload, " - "write model into a tmpfile."sv); - } - // Write model to file. - // TODO: handle different model format. - ModelFilePath = "MLX.safetensors"sv; - std::ofstream TempFile(ModelFilePath, std::ios::out | std::ios::binary); - if (!TempFile) { - spdlog::error( - "[WASI-NN] MLX backend: Failed to create the temporary file. "sv); + + // Handle the model path. + for (size_t Idx = 0; Idx < Builders.size() - 1; Idx++) { + auto Weight = Builders[Idx]; + const std::string BinModel(reinterpret_cast(Weight.data()), + Weight.size()); + spdlog::info("[WASI-NN] MLX BinModel: {}"sv, BinModel.size()); + if (BinModel.size() == 0) { Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } - TempFile.write(BinModel.data(), BinModel.size()); - TempFile.close(); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] MLX backend: Write model into a tmpfile...Done"sv); + std::string ModelFilePath; + if (BinModel.substr(0, 8) == "preload:"sv) { + ModelFilePath = BinModel.substr(8); + } else { + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] MLX backend: Model path not found in nn-preload, " + "write model into a tmpfile."sv); + } + // Write model to file. + // TODO: handle different model format. + ModelFilePath = "MLX" + std::to_string(Idx) + ".safetensors"; + std::ofstream TempFile(ModelFilePath, std::ios::out | std::ios::binary); + if (!TempFile) { + spdlog::error( + "[WASI-NN] MLX backend: Failed to create the temporary file. "sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + TempFile.write(BinModel.data(), BinModel.size()); + TempFile.close(); + if (GraphRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] MLX backend: Write model into a tmpfile...Done"sv); + } } - } - - // Create Model. - if (GraphRef.Model == nullptr) { + // Load weight. if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { - GraphRef.Model = tinyLlama11BChatV10(); - GraphRef.Prmopt = TinyLLaMAPrompt(); + GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); } else if (GraphRef.ModelType == "llama_3_8b") { - GraphRef.Model = llama38b(); + GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); } else if (GraphRef.ModelType == "llama_2_7b_chat_hf") { - GraphRef.Model = llama27bChat(); + GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); } else { spdlog::error("[WASI-NN] MLX backend: Model type not supported."sv); Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } } - // Load weight. - if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { - GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); - } else if (GraphRef.ModelType == "llama_3_8b") { - GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); - } else if (GraphRef.ModelType == "llama_2_7b_chat_hf") { - GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); - } else { - spdlog::error("[WASI-NN] MLX backend: Model type not supported."sv); - Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; - } return WASINN::ErrNo::Success; } @@ -204,7 +237,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, spdlog::info("[WASI-NN] MLX backend: getOutput"sv); } std::string StringTmp(reinterpret_cast(CxtRef.Outputs.data()), - CxtRef.Outputs.size() * sizeof(long long int)); + CxtRef.Outputs.size()); std::copy_n(StringTmp.data(), StringTmp.length(), OutBuffer.data()); BytesWritten = StringTmp.length(); return WASINN::ErrNo::Success; @@ -220,12 +253,13 @@ Expect compute(WasiNNEnvironment &Env, if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN] MLX backend: compute"sv); } - const std::vector Ids = GraphRef.Tok->Encode(CxtRef.Inputs); - auto Token = mx::array(Ids.data(), {static_cast(Ids.size())}, mx::int32); + const std::vector Ids = GraphRef.Tok->Encode(CxtRef.Inputs); + auto Token = + mx::array(Ids.data(), {static_cast(Ids.size())}, mx::int32); std::vector TokenList; std::string Answer; - int Skip = 0; - int TokenCount = 0; + int32_t Skip = 0; + uint64_t TokenCount = 0; auto [Y, KVCache] = GraphRef.Model->generate(Token, 0.1); while (true) { TokenCount++; diff --git a/plugins/wasi_nn/mlx.h b/plugins/wasi_nn/mlx.h index 4b40c473..e37db5bf 100644 --- a/plugins/wasi_nn/mlx.h +++ b/plugins/wasi_nn/mlx.h @@ -26,12 +26,12 @@ struct Graph { delete Tok; } } - std::string ModelType = "tiny_llama_1.1B_chat_v1.0"; + std::string ModelType; tokenizers::Tokenizer *Tok = nullptr; Transformer *Model; double Temp = 0.0; bool EnableDebugLog = true; - int MaxToken = 1024; + uint64_t MaxToken = 1024; BasePrompt Prmopt; }; struct Context { diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 17473278..45a18c4b 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -2711,7 +2711,7 @@ TEST(WasiNNTest, ChatTTSBackend) { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX TEST(WasiNNTest, MLXBackend) { // Create the wasi_nn module instance. - auto *NNMod = dynamic_cast(createModule()); + auto NNMod = createModule(); ASSERT_TRUE(NNMod != nullptr); // Create the calling frame with memory instance. @@ -2949,7 +2949,5 @@ TEST(WasiNNTest, MLXBackend) { auto BytesWritten = *MemInst.getPointer(BuilderPtr); EXPECT_GE(BytesWritten, 50); } - - delete NNMod; } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX \ No newline at end of file From 591d8768680b5332fa79f15d11353e09e11594f4 Mon Sep 17 00:00:00 2001 From: grorge Date: Sun, 8 Sep 2024 16:52:54 +0800 Subject: [PATCH 445/623] [WASI-NN] mlx: fix include scope Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 23 ++++++++++++----------- plugins/wasi_nn/mlx.cpp | 6 +----- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 6cc6386a..0f0d2765 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -197,15 +197,13 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) if(MLX_FOUND) message(STATUS "Found MLX: ${MLX_INCLUDE_DIRS}") else() - # TODO: TEST + # Not support directly download from source + find_library(ACCELERATE_LIBRARY Accelerate) + find_library(METAL_LIB Metal) + find_library(FOUNDATION_LIB Foundation) + find_library(QUARTZ_LIB QuartzCore) message(STATUS "MLX not found, downloading from source") include(FetchContent) - set(MLX_BUILD_TESTS OFF) - if (NOT APPLE) - set(MLX_BUILD_METAL OFF) - set(MLX_BUILD_CPU OFF) - endif() - FetchContent_Declare( mlx GIT_REPOSITORY https://github.com/ml-explore/mlx.git @@ -217,6 +215,13 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) set_target_properties(mlx PROPERTIES INTERFACE_LINK_LIBRARIES "$" ) + target_link_libraries( + mlx PUBLIC + ${ACCELERATE_LIBRARY} + ${METAL_LIB} + ${FOUNDATION_LIB} + ${QUARTZ_LIB} + ) target_compile_options(mlx PRIVATE -Wno-unused-parameter @@ -237,10 +242,6 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_MakeAvailable(tokenizers) target_include_directories(wasmedgePluginWasiNN PRIVATE ${tokenizers_SOURCE_DIR}/include) target_link_libraries(wasmedgePluginWasiNN PRIVATE tokenizers_cpp) - target_compile_options(wasmedgePluginWasiNN - PRIVATE - -Wno-unused-parameter - ) target_link_libraries(wasmedgePluginWasiNN PRIVATE mlx_cpp mlx diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp index 807e3d91..70f565b8 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/mlx.cpp @@ -1,14 +1,10 @@ #include "mlx.h" -#include "prompt.h" #include "spdlog/spdlog.h" #include "wasinnenv.h" -#include <_types/_uint8_t.h> -#include -#include -#include #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX #include "converter.h" +#include "prompt.h" #include "registry.h" #include "utils.h" #include From bf4f293902b679d16e13fe5d7d712ce47cc738e9 Mon Sep 17 00:00:00 2001 From: grorge Date: Mon, 9 Sep 2024 12:25:50 +0800 Subject: [PATCH 446/623] [WASI-NN] mlx: fix mlx fetch content Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 7 ++++--- plugins/wasi_nn/MLX/CMakeLists.txt | 9 ++++----- test/plugins/wasi_nn/CMakeLists.txt | 3 +++ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 0f0d2765..c0f8e90c 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -204,6 +204,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) find_library(QUARTZ_LIB QuartzCore) message(STATUS "MLX not found, downloading from source") include(FetchContent) + set(MLX_BUILD_GGUF OFF) FetchContent_Declare( mlx GIT_REPOSITORY https://github.com/ml-explore/mlx.git @@ -223,14 +224,14 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) ${QUARTZ_LIB} ) target_compile_options(mlx - PRIVATE + PUBLIC -Wno-unused-parameter -Wno-deprecated-copy -Wno-format ) endif() add_subdirectory(MLX) - target_include_directories(wasmedgePluginWasiNN PRIVATE MLX/model MLX/prompt MLX/mlx) + target_include_directories(wasmedgePluginWasiNN PUBLIC MLX/model MLX/prompt MLX/mlx) message(STATUS "Downloading tokenizers") FetchContent_Declare( @@ -242,7 +243,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_MakeAvailable(tokenizers) target_include_directories(wasmedgePluginWasiNN PRIVATE ${tokenizers_SOURCE_DIR}/include) target_link_libraries(wasmedgePluginWasiNN PRIVATE tokenizers_cpp) - target_link_libraries(wasmedgePluginWasiNN PRIVATE + target_link_libraries(wasmedgePluginWasiNN PUBLIC mlx_cpp mlx simdjson::simdjson diff --git a/plugins/wasi_nn/MLX/CMakeLists.txt b/plugins/wasi_nn/MLX/CMakeLists.txt index d05d3069..881af0ce 100644 --- a/plugins/wasi_nn/MLX/CMakeLists.txt +++ b/plugins/wasi_nn/MLX/CMakeLists.txt @@ -14,8 +14,8 @@ wasmedge_add_library( mlx/transformer.cpp) target_link_libraries(mlx_cpp PUBLIC mlx) -target_include_directories(mlx_cpp PUBLIC ./mlx) -target_include_directories(mlx_cpp PUBLIC ${MLX_INCLUDE_DIRS}) +target_include_directories(mlx_cpp SYSTEM PUBLIC ./mlx) +target_include_directories(mlx_cpp SYSTEM PUBLIC ${MLX_INCLUDE_DIRS}) target_link_libraries(mlx_cpp PUBLIC ${MLX_LIBRARIES}) message(STATUS "Downloading gguflib") @@ -26,9 +26,8 @@ FetchContent_Declare( GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(gguflib) -target_include_directories(mlx_cpp - PRIVATE $) +target_include_directories(mlx_cpp SYSTEM PUBLIC $) add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c ${gguflib_SOURCE_DIR}/gguflib.c) set_target_properties(gguflib PROPERTIES LINKER_LANGUAGE CXX) -target_link_libraries(mlx_cpp PRIVATE gguflib) +target_link_libraries(mlx_cpp PUBLIC gguflib) diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 8d21ab61..86e3b4cb 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -118,6 +118,9 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) ${CMAKE_CURRENT_BINARY_DIR}/wasinn_mlx_fixtures/tokenizer.json MD5=c9dc953a24ad2b76b4bae4bf456f18bd ) + target_compile_options(wasiNNTests PUBLIC + -Wno-unused-parameter + ) else() # Add the other backend test files fetching here. endif() From 0625ff8ff3665dca54dcf47bb6f38895f90cc429 Mon Sep 17 00:00:00 2001 From: grorge Date: Mon, 9 Sep 2024 13:32:42 +0800 Subject: [PATCH 447/623] [WASI-NN] mlx: change parameter Signed-off-by: grorge --- plugins/wasi_nn/mlx.cpp | 2 +- plugins/wasi_nn/mlx.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp index 70f565b8..6078165b 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/mlx.cpp @@ -54,7 +54,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, // Parse metadata. if (Builders.size() <= 1) { spdlog::error( - "[WASI-NN] MLX backend: Lack necessary metadata(tokenizer, model_type)."sv); + "[WASI-NN] MLX backend: Lack model weight or required metadata (tokenizer, model_type)."sv); Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } diff --git a/plugins/wasi_nn/mlx.h b/plugins/wasi_nn/mlx.h index e37db5bf..5be43305 100644 --- a/plugins/wasi_nn/mlx.h +++ b/plugins/wasi_nn/mlx.h @@ -30,7 +30,7 @@ struct Graph { tokenizers::Tokenizer *Tok = nullptr; Transformer *Model; double Temp = 0.0; - bool EnableDebugLog = true; + bool EnableDebugLog = false; uint64_t MaxToken = 1024; BasePrompt Prmopt; }; From 616c441d25df85f5d362e81e449b0363823eb4d7 Mon Sep 17 00:00:00 2001 From: grorge Date: Sun, 15 Sep 2024 18:17:32 +0800 Subject: [PATCH 448/623] [WASI-NN] mlx: add quantized function Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 35 +++++++++++++- plugins/wasi_nn/MLX/CMakeLists.txt | 33 ------------- plugins/wasi_nn/MLX/mlx/activations.cpp | 4 +- plugins/wasi_nn/MLX/mlx/base.cpp | 29 +++++++++++- plugins/wasi_nn/MLX/mlx/base.h | 18 +++++-- plugins/wasi_nn/MLX/mlx/embedding.cpp | 7 +++ plugins/wasi_nn/MLX/mlx/embedding.h | 6 ++- plugins/wasi_nn/MLX/mlx/linear.cpp | 9 +++- plugins/wasi_nn/MLX/mlx/linear.h | 4 +- plugins/wasi_nn/MLX/mlx/quantized.cpp | 62 +++++++++++++++++++++++++ plugins/wasi_nn/MLX/mlx/quantized.h | 56 ++++++++++++++++++++++ plugins/wasi_nn/MLX/mlx/transformer.h | 3 +- plugins/wasi_nn/MLX/model/converter.cpp | 24 ++++++---- plugins/wasi_nn/MLX/model/transformer.h | 3 +- plugins/wasi_nn/MLX/model/utils.cpp | 17 +++++++ plugins/wasi_nn/MLX/model/utils.h | 7 ++- plugins/wasi_nn/mlx.cpp | 1 - 17 files changed, 259 insertions(+), 59 deletions(-) delete mode 100644 plugins/wasi_nn/MLX/CMakeLists.txt create mode 100644 plugins/wasi_nn/MLX/mlx/quantized.cpp create mode 100644 plugins/wasi_nn/MLX/mlx/quantized.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index c0f8e90c..b7a1cbdc 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -17,6 +17,19 @@ wasmedge_add_library(wasmedgePluginWasiNN whispercpp.cpp chattts.cpp mlx.cpp + MLX/prompt/prompt.cpp + MLX/model/transformer.cpp + MLX/model/converter.cpp + MLX/model/utils.cpp + MLX/model/registry.cpp + MLX/mlx/base.cpp + MLX/mlx/linear.cpp + MLX/mlx/positional_encoding.cpp + MLX/mlx/activations.cpp + MLX/mlx/embedding.cpp + MLX/mlx/normalization.cpp + MLX/mlx/transformer.cpp + MLX/mlx/quantized.cpp ) foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) @@ -230,7 +243,6 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) -Wno-format ) endif() - add_subdirectory(MLX) target_include_directories(wasmedgePluginWasiNN PUBLIC MLX/model MLX/prompt MLX/mlx) message(STATUS "Downloading tokenizers") @@ -243,8 +255,27 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) FetchContent_MakeAvailable(tokenizers) target_include_directories(wasmedgePluginWasiNN PRIVATE ${tokenizers_SOURCE_DIR}/include) target_link_libraries(wasmedgePluginWasiNN PRIVATE tokenizers_cpp) + + target_include_directories(wasmedgePluginWasiNN SYSTEM PUBLIC ./mlx) + target_include_directories(wasmedgePluginWasiNN SYSTEM PUBLIC ${MLX_INCLUDE_DIRS}) + target_link_libraries(wasmedgePluginWasiNN PUBLIC ${MLX_LIBRARIES}) + + message(STATUS "Downloading gguflib") + FetchContent_Declare( + gguflib + GIT_REPOSITORY https://github.com/antirez/gguf-tools/ + GIT_TAG af7d88d808a7608a33723fba067036202910acb3 + GIT_SHALLOW FALSE + ) + FetchContent_MakeAvailable(gguflib) + target_include_directories(wasmedgePluginWasiNN SYSTEM PUBLIC $) + add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c + ${gguflib_SOURCE_DIR}/gguflib.c) + set_target_properties(gguflib PROPERTIES LINKER_LANGUAGE CXX) + target_link_libraries(wasmedgePluginWasiNN PUBLIC gguflib) + + target_link_libraries(wasmedgePluginWasiNN PUBLIC - mlx_cpp mlx simdjson::simdjson ) diff --git a/plugins/wasi_nn/MLX/CMakeLists.txt b/plugins/wasi_nn/MLX/CMakeLists.txt deleted file mode 100644 index 881af0ce..00000000 --- a/plugins/wasi_nn/MLX/CMakeLists.txt +++ /dev/null @@ -1,33 +0,0 @@ -wasmedge_add_library( - mlx_cpp - prompt/prompt.cpp - model/transformer.cpp - model/converter.cpp - model/utils.cpp - model/registry.cpp - mlx/base.cpp - mlx/linear.cpp - mlx/positional_encoding.cpp - mlx/activations.cpp - mlx/embedding.cpp - mlx/normalization.cpp - mlx/transformer.cpp) - -target_link_libraries(mlx_cpp PUBLIC mlx) -target_include_directories(mlx_cpp SYSTEM PUBLIC ./mlx) -target_include_directories(mlx_cpp SYSTEM PUBLIC ${MLX_INCLUDE_DIRS}) -target_link_libraries(mlx_cpp PUBLIC ${MLX_LIBRARIES}) - -message(STATUS "Downloading gguflib") -FetchContent_Declare( - gguflib - GIT_REPOSITORY https://github.com/antirez/gguf-tools/ - GIT_TAG af7d88d808a7608a33723fba067036202910acb3 - GIT_SHALLOW FALSE -) -FetchContent_MakeAvailable(gguflib) -target_include_directories(mlx_cpp SYSTEM PUBLIC $) -add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c - ${gguflib_SOURCE_DIR}/gguflib.c) -set_target_properties(gguflib PROPERTIES LINKER_LANGUAGE CXX) -target_link_libraries(mlx_cpp PUBLIC gguflib) diff --git a/plugins/wasi_nn/MLX/mlx/activations.cpp b/plugins/wasi_nn/MLX/mlx/activations.cpp index cb0b559a..688609be 100644 --- a/plugins/wasi_nn/MLX/mlx/activations.cpp +++ b/plugins/wasi_nn/MLX/mlx/activations.cpp @@ -1,6 +1,8 @@ #include "activations.h" #include namespace mlx::core { -mx::array gelu(mx::array X) { return X * (1 + mx::erf(X / std::sqrt(2))) / 2; } +mx::array gelu(mx::array X) { + return X * (1 + mx::erf(X / std::sqrt(2.0))) / 2.0; +} mx::array silu(mx::array X) { return X * mx::sigmoid(X); } } // namespace mlx::core \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/base.cpp b/plugins/wasi_nn/MLX/mlx/base.cpp index 03dff5f0..a239bdbb 100644 --- a/plugins/wasi_nn/MLX/mlx/base.cpp +++ b/plugins/wasi_nn/MLX/mlx/base.cpp @@ -1,6 +1,7 @@ #include "base.h" #include "../model/utils.h" #include +#include namespace mlx::core::nn { @@ -13,11 +14,22 @@ void Module::update(std::unordered_map Parameters) { apply(k, v); } } +nn::Module *Module::toQuantized(int GroupSize, int Bits) { + for (auto &[k, v] : Submodules) { + auto *OldModule = v; + v = v->toQuantized(GroupSize, Bits); + if (OldModule != v) { + delete OldModule; + } + } + return this; +} void Module::apply(std::string Key, mx::array Value) { std::vector SplitKey = splitString(Key, '.'); if (SplitKey.size() == 1) { if (Parameters.find(Key) == Parameters.end()) { - throw std::invalid_argument("Unsupported weight: " + Key); + spdlog::error("Unsupported weight: {}", Key); + assumingUnreachable(); } this->Parameters.at(Key) = Value; } else { @@ -28,9 +40,22 @@ void Module::apply(std::string Key, mx::array Value) { SplitKey.erase(SplitKey.begin()); } if (Submodules.find(LayerName) == Submodules.end()) { - throw std::invalid_argument("Unsupported Tensor: " + LayerName); + spdlog::error("Unsupported Layer: {}", LayerName); + assumingUnreachable(); } Submodules.at(LayerName)->apply(joinString(SplitKey, '.'), Value); } } +std::unordered_map +Module::getWeigts(const std::string &Prefix) { + std::unordered_map Weights; + for (auto &[k, v] : Submodules) { + auto Subweights = v->getWeigts(Prefix + Name + "."); + Weights.insert(Subweights.begin(), Subweights.end()); + } + for (auto &[k, v] : Parameters) { + Weights.insert({Prefix + Name + "." + k, v}); + } + return Weights; +} } // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/base.h b/plugins/wasi_nn/MLX/mlx/base.h index ddf68e99..c063f491 100644 --- a/plugins/wasi_nn/MLX/mlx/base.h +++ b/plugins/wasi_nn/MLX/mlx/base.h @@ -1,9 +1,11 @@ #pragma once +#include "common/errcode.h" #include "mlx/mlx.h" #include #include #include +#include #include namespace mx = mlx::core; @@ -19,24 +21,31 @@ class Module { std::unordered_map Parameters{}; std::unordered_map Submodules{}; mx::array ®isterParameter(std::string Name, array &&W); - + std::unordered_map + getWeigts(const std::string &Prefix = "model"); + virtual nn::Module *toQuantized(int GroupSize = 64, int Bits = 4); void update(std::unordered_map Parameters); void apply(std::string Key, mx::array Parameters); template void registerModule(std::string ModuleName, T *M) { using DecayedT = std::decay_t; if (!std::is_base_of::value) { - throw std::invalid_argument("Invalid subModule."); + spdlog::error("Invalid subModule."); + assumingUnreachable(); } if (Submodules.find(ModuleName) == Submodules.end()) { Submodules.insert({ModuleName, M}); Submodules.at(ModuleName)->Name = ModuleName; + } else { + spdlog::error("Module already exists."); + assumingUnreachable(); } } template void registerLayer(std::string ModuleName, std::vector &Layers) { if (!std::is_base_of::value) { - throw std::invalid_argument("Invalid subModule."); + spdlog::error("Invalid subModule."); + assumingUnreachable(); } for (size_t Idx = 0; Idx < Layers.size(); Idx++) { registerModule(ModuleName + "." + std::to_string(Idx), Layers[Idx]); @@ -47,7 +56,6 @@ class Module { template void printVec(std::vector Ve) { for (auto I : Ve) { - std::cout << I << " "; + spdlog::debug("{} ", I); } - std::cout << std::endl; } \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/embedding.cpp b/plugins/wasi_nn/MLX/mlx/embedding.cpp index 96c05088..0f5326c2 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.cpp +++ b/plugins/wasi_nn/MLX/mlx/embedding.cpp @@ -1,4 +1,6 @@ #include "embedding.h" +#include "base.h" +#include "quantized.h" #include #include @@ -9,4 +11,9 @@ mx::array Embedding::forward(mx::array Input) { mx::array Embedding::asLinear(mx::array Input) { return matmul(Input, transpose(Parameters.at("weight"))); } +nn::Module *Embedding::toQuantized(int GroupSize, int Bits) { + auto *QuantModel = QuantizedEmbedding::fromEmbedding(this, GroupSize, Bits); + QuantModel->Name = Name; + return QuantModel; +} } // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/embedding.h b/plugins/wasi_nn/MLX/mlx/embedding.h index cd11b26e..fda27df3 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.h +++ b/plugins/wasi_nn/MLX/mlx/embedding.h @@ -6,13 +6,15 @@ namespace mlx::core::nn { class Embedding : public Module { public: + Embedding() = default; Embedding(int NumEmbeddings, int Dims) { - const double Scale = std::sqrt(1 / Dims); + const double Scale = std::sqrt(1.0 / Dims); registerParameter("weight", mx::random::normal({NumEmbeddings, Dims}, 0.0, Scale)); } - mx::array forward(mx::array Input); + virtual mx::array forward(mx::array Input); mx::array asLinear(mx::array Input); + nn::Module *toQuantized(int GroupSize = 64, int Bits = 4) override; }; } // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/linear.cpp b/plugins/wasi_nn/MLX/mlx/linear.cpp index 5cd7b0a2..d0c9722b 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.cpp +++ b/plugins/wasi_nn/MLX/mlx/linear.cpp @@ -1,5 +1,6 @@ #include "linear.h" - +#include "base.h" +#include "quantized.h" namespace mlx::core::nn { mx::array Linear::forward(mx::array Input) { if (EnableBias) { @@ -9,4 +10,10 @@ mx::array Linear::forward(mx::array Input) { return matmul(Input, transpose(Parameters.at("weight"))); } +nn::Module *Linear::toQuantized(int GroupSize, int Bits) { + auto *QuantModel = QuantizedLinear::fromLinear(this, GroupSize, Bits); + QuantModel->Name = Name; + return QuantModel; +} + } // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/linear.h b/plugins/wasi_nn/MLX/mlx/linear.h index 1b7c5c75..502a94e9 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.h +++ b/plugins/wasi_nn/MLX/mlx/linear.h @@ -12,6 +12,7 @@ class Linear : public Module { bool EnableBias = true; public: + Linear() = default; Linear(int InputDims, int OutputDims, bool EnableBias = true) : EnableBias(EnableBias) { const double Scale = std::sqrt(1.0 / InputDims); @@ -24,6 +25,7 @@ class Linear : public Module { })); } } - mx::array forward(mx::array Input); + virtual mx::array forward(mx::array Input); + nn::Module *toQuantized(int GroupSize = 64, int Bits = 4) override; }; } // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/quantized.cpp b/plugins/wasi_nn/MLX/mlx/quantized.cpp new file mode 100644 index 00000000..d1f5fda1 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/quantized.cpp @@ -0,0 +1,62 @@ +#include "quantized.h" +#include +#include +#include +#include +namespace mlx::core::nn { +mx::array QuantizedEmbedding::forward(mx::array Input) { + auto S = Input.shape(); + auto X = mx::flatten(Input); + auto Out = + mx::dequantize(take(Parameters.at("weight"), Input, 0), + take(Parameters.at("scales"), Input, 0), + take(Parameters.at("biases"), Input, 0), GroupSize, Bits); + S.emplace_back(-1); + return reshape(Out, {S}); +} + +mx::array QuantizedLinear::forward(mx::array Input) { + auto Out = mx::quantized_matmul( + Input, Parameters.at("weight"), Parameters.at("scales"), + Parameters.at("biases"), true, GroupSize, Bits); + if (Parameters.find("bias") != Parameters.end()) { + Out = add(Out, Parameters.at("bias")); + } + return Out; +} +QuantizedEmbedding * +QuantizedEmbedding::fromEmbedding(Embedding *EmbeddingModule, int GroupSize, + int Bits) { + auto EmbeddingShape = EmbeddingModule->Parameters.at("weight").shape(); + auto *QuantizedModel = new QuantizedEmbedding( + EmbeddingShape[0], EmbeddingShape[1], GroupSize, Bits); + auto Quantized = + mx::quantize(EmbeddingModule->Parameters.at("weight"), GroupSize, Bits); + QuantizedModel->Parameters.insert_or_assign("weight", std::get<0>(Quantized)); + QuantizedModel->Parameters.insert_or_assign( + "scales", std::move(std::get<1>(Quantized))); + QuantizedModel->Parameters.insert_or_assign( + "biases", std::move(std::get<2>(Quantized))); + return QuantizedModel; +} +QuantizedLinear *QuantizedLinear::fromLinear(Linear *LinearModule, + int GroupSize, int Bits) { + auto LinearShape = LinearModule->Parameters.at("weight").shape(); + const bool EnableBias = + LinearModule->Parameters.find("bias") != LinearModule->Parameters.end(); + auto *QuantizedModel = new QuantizedLinear(LinearShape[0], LinearShape[1], + EnableBias, GroupSize, Bits); + auto Quantized = + mx::quantize(LinearModule->Parameters.at("weight"), GroupSize, Bits); + QuantizedModel->Parameters.insert_or_assign("weight", std::get<0>(Quantized)); + QuantizedModel->Parameters.insert_or_assign( + "scales", std::move(std::get<1>(Quantized))); + QuantizedModel->Parameters.insert_or_assign( + "biases", std::move(std::get<2>(Quantized))); + if (EnableBias) { + QuantizedModel->Parameters.insert_or_assign( + "bias", LinearModule->Parameters.at("bias")); + } + return QuantizedModel; +} +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/quantized.h b/plugins/wasi_nn/MLX/mlx/quantized.h new file mode 100644 index 00000000..3f884ec3 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/quantized.h @@ -0,0 +1,56 @@ +#pragma once +#include "base.h" +#include "embedding.h" +#include "linear.h" +#include +#include + +namespace mlx::core::nn { + +class QuantizedEmbedding : public Embedding { +public: + int GroupSize; + int Bits; + int NumEmbeddings; + int Dims; + QuantizedEmbedding(int NumEmbeddings, int Dims, int GroupSize = 64, + int Bits = 4) + : GroupSize(GroupSize), Bits(Bits), NumEmbeddings(NumEmbeddings), + Dims(Dims) { + const double Scale = std::sqrt(1.0 / Dims); + registerParameter("weight", + mx::random::normal({NumEmbeddings, Dims}, 0.0, Scale)); + auto Quantized = mx::quantize(Parameters.at("weight"), GroupSize, Bits); + Parameters.insert_or_assign("weight", std::get<0>(Quantized)); + registerParameter("scales", std::move(std::get<1>(Quantized))); + registerParameter("biases", std::move(std::get<2>(Quantized))); + } + mx::array forward(mx::array Input) override; + static QuantizedEmbedding *fromEmbedding(Embedding *EmbeddingModule, + int GroupSize = 64, int Bits = 4); +}; + +class QuantizedLinear : public Linear { +public: + int GroupSize; + int Bits; + QuantizedLinear(int InputDims, int OutputDim, bool Bias = true, + int GroupSize = 64, int Bits = 4) + : GroupSize(GroupSize), Bits(Bits) { + const double Scale = std::sqrt(1.0 / InputDims); + registerParameter( + "weight", mx::random::uniform(-Scale, Scale, {OutputDim, InputDims})); + auto Quantized = mx::quantize(Parameters.at("weight"), GroupSize, Bits); + Parameters.insert_or_assign("weight", std::get<0>(Quantized)); + registerParameter("scales", std::move(std::get<1>(Quantized))); + registerParameter("biases", std::move(std::get<2>(Quantized))); + if (Bias) { + registerParameter("bias", mx::zeros({OutputDim})); + } + } + mx::array forward(mx::array Input) override; + static QuantizedLinear *fromLinear(Linear *LinearModule, int GroupSize = 64, + int Bits = 4); +}; + +} // namespace mlx::core::nn \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/mlx/transformer.h b/plugins/wasi_nn/MLX/mlx/transformer.h index 4227222c..dce95ea2 100644 --- a/plugins/wasi_nn/MLX/mlx/transformer.h +++ b/plugins/wasi_nn/MLX/mlx/transformer.h @@ -16,7 +16,8 @@ class MultiHeadAttention : public Module { std::optional ValueOutputDims = {}, bool Bias = false) : NumHeads(NumHeads) { if (Dims % NumHeads != 0) { - throw std::invalid_argument("Dims must be divisible by NumHeads"); + spdlog::error("Dims must be divisible by NumHeads"); + assumingUnreachable(); } if (!QueryInputDims) { QueryInputDims = Dims; diff --git a/plugins/wasi_nn/MLX/model/converter.cpp b/plugins/wasi_nn/MLX/model/converter.cpp index 0b82e979..a30e6144 100644 --- a/plugins/wasi_nn/MLX/model/converter.cpp +++ b/plugins/wasi_nn/MLX/model/converter.cpp @@ -19,17 +19,17 @@ weightsToMlx(std::string WeightPath) { return Loaded; } if (endsWith(WeightPath, ".safetensors")) { - std::cout << "Loading model from .safetensors file...\n"; + spdlog::info("Loading model from .safetensors file...\n"); const mx::SafetensorsLoad Loaded = mx::load_safetensors(WeightPath); return Loaded.first; } if (endsWith(WeightPath, ".gguf")) { - std::cout << "Loading model from .gguf file...\n"; + spdlog::info("Loading model from .gguf file...\n"); const mx::GGUFLoad Loaded = mx::load_gguf(WeightPath); return Loaded.first; } - std::cout << "Can not regonize model file\n"; - throw std::invalid_argument("Invalid model path."); + spdlog::error("Can not regonize model file\n"); + assumingUnreachable(); } std::unordered_map @@ -60,16 +60,24 @@ llamaToMlxllm(std::string WeightPath) { const std::unordered_map KeyMap = { {"input_layernorm", "attention_norm"}, {"post_attention_layernorm", "mlp_norm"}}; - ModelWeights.insert({SplitKey[0] + "." + SplitKey[1] + "." + - KeyMap.at(SplitKey[2]) + "." + SplitKey[3], - v}); + if (KeyMap.find(SplitKey[2]) == KeyMap.end()) { + ModelWeights.insert({NewKey, v}); + } else { + ModelWeights.insert({SplitKey[0] + "." + SplitKey[1] + "." + + KeyMap.at(SplitKey[2]) + "." + SplitKey[3], + v}); + } } } else { const std::unordered_map KeyMap = { {"embed_tokens", "token_embed"}, {"lm_head", "head"}, {"norm", "norm"}}; - ModelWeights.insert({KeyMap.at(SplitKey[0]) + "." + SplitKey[1], v}); + if (KeyMap.find(SplitKey[0]) == KeyMap.end()) { + ModelWeights.insert({NewKey, v}); + } else { + ModelWeights.insert({KeyMap.at(SplitKey[0]) + "." + SplitKey[1], v}); + } } } return ModelWeights; diff --git a/plugins/wasi_nn/MLX/model/transformer.h b/plugins/wasi_nn/MLX/model/transformer.h index 4be559d7..8db5d392 100644 --- a/plugins/wasi_nn/MLX/model/transformer.h +++ b/plugins/wasi_nn/MLX/model/transformer.h @@ -130,7 +130,8 @@ class Transformer : public nn::Module { : Dim(Dim), HiddenDim(HiddenDim), Gemma(Gemma), EmbedAsHead(EmbedAsHeadPar) { if (VocabSize <= 0) { - throw std::invalid_argument("VocabSize must be greater than 0."); + spdlog::error("VocabSize must be greater than 0."); + assumingUnreachable(); } EmbedAsHead = Gemma ? true : EmbedAsHead; if (!NKVHeads) { diff --git a/plugins/wasi_nn/MLX/model/utils.cpp b/plugins/wasi_nn/MLX/model/utils.cpp index 283a6059..aba35791 100644 --- a/plugins/wasi_nn/MLX/model/utils.cpp +++ b/plugins/wasi_nn/MLX/model/utils.cpp @@ -29,4 +29,21 @@ bool startsWith(std::string const &Value, std::string const &Starting) { if (Starting.size() > Value.size()) return false; return std::equal(Starting.begin(), Starting.end(), Value.begin()); +} +void saveWeights(const std::unordered_map &Weights, + const std::string Path) { + if (endsWith(Path, ".safetensors")) { + mx::save_safetensors(Path, Weights, {{"format", "mlx"}}); + } else { + spdlog::error("Unsupported file format"); + assumingUnreachable(); + } +} +void saveWeights(const mx::array &Weights, const std::string &Path) { + if (endsWith(Path, ".npz")) { + mx::save(Path, Weights); + } else { + spdlog::error("Unsupported file format"); + assumingUnreachable(); + } } \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/utils.h b/plugins/wasi_nn/MLX/model/utils.h index 5c3d9682..7faa16eb 100644 --- a/plugins/wasi_nn/MLX/model/utils.h +++ b/plugins/wasi_nn/MLX/model/utils.h @@ -1,8 +1,13 @@ #pragma once +#include "base.h" #include #include #include + std::vector splitString(const std::string &S, char Delim); std::string joinString(std::vector &S, char Delim); bool endsWith(std::string const &Value, std::string const &Ending); -bool startsWith(std::string const &Value, std::string const &Starting); \ No newline at end of file +bool startsWith(std::string const &Value, std::string const &Starting); +void saveWeights(const std::unordered_map &Weights, + const std::string Path); +void saveWeights(const mx::array &Weights, const std::string &Path); \ No newline at end of file diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp index 6078165b..706382c1 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/mlx.cpp @@ -1,5 +1,4 @@ #include "mlx.h" -#include "spdlog/spdlog.h" #include "wasinnenv.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX From 1c3709f38c142b6f306ef7654ed9b4453a017c10 Mon Sep 17 00:00:00 2001 From: grorge Date: Sun, 15 Sep 2024 21:25:36 +0800 Subject: [PATCH 449/623] [WASI-NN] mlx: add namespace for mlx Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 81 ++++++++++++------- plugins/wasi_nn/MLX/mlx/activations.cpp | 5 +- plugins/wasi_nn/MLX/mlx/activations.h | 5 +- plugins/wasi_nn/MLX/mlx/base.cpp | 7 +- plugins/wasi_nn/MLX/mlx/base.h | 12 +-- plugins/wasi_nn/MLX/mlx/embedding.cpp | 5 +- plugins/wasi_nn/MLX/mlx/embedding.h | 5 +- plugins/wasi_nn/MLX/mlx/linear.cpp | 5 +- plugins/wasi_nn/MLX/mlx/linear.h | 5 +- plugins/wasi_nn/MLX/mlx/normalization.cpp | 5 +- plugins/wasi_nn/MLX/mlx/normalization.h | 4 +- .../wasi_nn/MLX/mlx/positional_encoding.cpp | 4 +- plugins/wasi_nn/MLX/mlx/positional_encoding.h | 4 +- plugins/wasi_nn/MLX/mlx/quantized.cpp | 5 +- plugins/wasi_nn/MLX/mlx/quantized.h | 4 +- plugins/wasi_nn/MLX/mlx/transformer.cpp | 4 +- plugins/wasi_nn/MLX/mlx/transformer.h | 4 +- plugins/wasi_nn/MLX/model/converter.cpp | 4 +- plugins/wasi_nn/MLX/model/converter.h | 4 +- plugins/wasi_nn/MLX/model/registry.cpp | 4 +- plugins/wasi_nn/MLX/model/registry.h | 4 +- plugins/wasi_nn/MLX/model/transformer.cpp | 22 ++--- plugins/wasi_nn/MLX/model/transformer.h | 4 +- plugins/wasi_nn/MLX/model/utils.cpp | 5 +- plugins/wasi_nn/MLX/model/utils.h | 5 +- plugins/wasi_nn/MLX/prompt/prompt.cpp | 5 +- plugins/wasi_nn/MLX/prompt/prompt.h | 5 +- plugins/wasi_nn/mlx.cpp | 7 +- 28 files changed, 159 insertions(+), 74 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index b7a1cbdc..ec5c7111 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -1,36 +1,55 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2024 Second State INC - -wasmedge_add_library(wasmedgePluginWasiNN - SHARED - wasinnenv.cpp - wasinnfunc.cpp - wasinnmodule.cpp - openvino.cpp - onnx.cpp - tf.cpp - torch.cpp - tfl.cpp - ggml.cpp - neuralspeed.cpp - piper.cpp - whispercpp.cpp - chattts.cpp - mlx.cpp - MLX/prompt/prompt.cpp - MLX/model/transformer.cpp - MLX/model/converter.cpp - MLX/model/utils.cpp - MLX/model/registry.cpp - MLX/mlx/base.cpp - MLX/mlx/linear.cpp - MLX/mlx/positional_encoding.cpp - MLX/mlx/activations.cpp - MLX/mlx/embedding.cpp - MLX/mlx/normalization.cpp - MLX/mlx/transformer.cpp - MLX/mlx/quantized.cpp -) +if(BACKEND STREQUAL "mlx") + wasmedge_add_library(wasmedgePluginWasiNN + SHARED + wasinnenv.cpp + wasinnfunc.cpp + wasinnmodule.cpp + openvino.cpp + onnx.cpp + tf.cpp + torch.cpp + tfl.cpp + ggml.cpp + neuralspeed.cpp + piper.cpp + whispercpp.cpp + chattts.cpp + mlx.cpp + MLX/prompt/prompt.cpp + MLX/model/transformer.cpp + MLX/model/converter.cpp + MLX/model/utils.cpp + MLX/model/registry.cpp + MLX/mlx/base.cpp + MLX/mlx/linear.cpp + MLX/mlx/positional_encoding.cpp + MLX/mlx/activations.cpp + MLX/mlx/embedding.cpp + MLX/mlx/normalization.cpp + MLX/mlx/transformer.cpp + MLX/mlx/quantized.cpp + ) +else() + wasmedge_add_library(wasmedgePluginWasiNN + SHARED + wasinnenv.cpp + wasinnfunc.cpp + wasinnmodule.cpp + openvino.cpp + onnx.cpp + tf.cpp + torch.cpp + tfl.cpp + ggml.cpp + neuralspeed.cpp + piper.cpp + whispercpp.cpp + chattts.cpp + mlx.cpp + ) +endif() foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) string(TOLOWER ${BACKEND} BACKEND) diff --git a/plugins/wasi_nn/MLX/mlx/activations.cpp b/plugins/wasi_nn/MLX/mlx/activations.cpp index 688609be..4149651a 100644 --- a/plugins/wasi_nn/MLX/mlx/activations.cpp +++ b/plugins/wasi_nn/MLX/mlx/activations.cpp @@ -1,8 +1,11 @@ #include "activations.h" #include + +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core { mx::array gelu(mx::array X) { return X * (1 + mx::erf(X / std::sqrt(2.0))) / 2.0; } mx::array silu(mx::array X) { return X * mx::sigmoid(X); } -} // namespace mlx::core \ No newline at end of file +} // namespace mlx::core +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/activations.h b/plugins/wasi_nn/MLX/mlx/activations.h index f3a02d83..e997a3d6 100644 --- a/plugins/wasi_nn/MLX/mlx/activations.h +++ b/plugins/wasi_nn/MLX/mlx/activations.h @@ -2,7 +2,10 @@ #include "base.h" #include #include + +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core { mx::array gelu(mx::array X); mx::array silu(mx::array X); -} // namespace mlx::core \ No newline at end of file +} // namespace mlx::core +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/base.cpp b/plugins/wasi_nn/MLX/mlx/base.cpp index a239bdbb..ac9ca063 100644 --- a/plugins/wasi_nn/MLX/mlx/base.cpp +++ b/plugins/wasi_nn/MLX/mlx/base.cpp @@ -3,9 +3,11 @@ #include #include +namespace WasmEdge::Host::WASINN::MLX { + namespace mlx::core::nn { -mx::array &Module::registerParameter(std::string Name, array &&W) { +mx::array &Module::registerParameter(std::string Name, mx::array &&W) { Parameters.insert({Name, W}); return Parameters.at(Name); } @@ -58,4 +60,5 @@ Module::getWeigts(const std::string &Prefix) { } return Weights; } -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/base.h b/plugins/wasi_nn/MLX/mlx/base.h index c063f491..d224c034 100644 --- a/plugins/wasi_nn/MLX/mlx/base.h +++ b/plugins/wasi_nn/MLX/mlx/base.h @@ -7,8 +7,9 @@ #include #include #include -namespace mx = mlx::core; +namespace mx = mlx::core; +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { class Module { public: @@ -20,7 +21,7 @@ class Module { std::string Name; std::unordered_map Parameters{}; std::unordered_map Submodules{}; - mx::array ®isterParameter(std::string Name, array &&W); + mx::array ®isterParameter(std::string Name, mx::array &&W); std::unordered_map getWeigts(const std::string &Prefix = "model"); virtual nn::Module *toQuantized(int GroupSize = 64, int Bits = 4); @@ -52,10 +53,11 @@ class Module { } } }; -} // namespace mlx::core::nn - template void printVec(std::vector Ve) { for (auto I : Ve) { spdlog::debug("{} ", I); } -} \ No newline at end of file +} +} // namespace mlx::core::nn + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/embedding.cpp b/plugins/wasi_nn/MLX/mlx/embedding.cpp index 0f5326c2..d14d7ccc 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.cpp +++ b/plugins/wasi_nn/MLX/mlx/embedding.cpp @@ -4,6 +4,8 @@ #include #include +namespace WasmEdge::Host::WASINN::MLX { + namespace mlx::core::nn { mx::array Embedding::forward(mx::array Input) { return take(Parameters.at("weight"), Input, 0); @@ -16,4 +18,5 @@ nn::Module *Embedding::toQuantized(int GroupSize, int Bits) { QuantModel->Name = Name; return QuantModel; } -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/embedding.h b/plugins/wasi_nn/MLX/mlx/embedding.h index fda27df3..0bed7181 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.h +++ b/plugins/wasi_nn/MLX/mlx/embedding.h @@ -3,6 +3,8 @@ #include #include #include + +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { class Embedding : public Module { public: @@ -17,4 +19,5 @@ class Embedding : public Module { nn::Module *toQuantized(int GroupSize = 64, int Bits = 4) override; }; -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/linear.cpp b/plugins/wasi_nn/MLX/mlx/linear.cpp index d0c9722b..1b7addf0 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.cpp +++ b/plugins/wasi_nn/MLX/mlx/linear.cpp @@ -1,6 +1,8 @@ #include "linear.h" #include "base.h" #include "quantized.h" + +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { mx::array Linear::forward(mx::array Input) { if (EnableBias) { @@ -16,4 +18,5 @@ nn::Module *Linear::toQuantized(int GroupSize, int Bits) { return QuantModel; } -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/linear.h b/plugins/wasi_nn/MLX/mlx/linear.h index 502a94e9..5724120a 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.h +++ b/plugins/wasi_nn/MLX/mlx/linear.h @@ -6,6 +6,8 @@ #include #include +namespace WasmEdge::Host::WASINN::MLX { + namespace mlx::core::nn { class Linear : public Module { @@ -28,4 +30,5 @@ class Linear : public Module { virtual mx::array forward(mx::array Input); nn::Module *toQuantized(int GroupSize = 64, int Bits = 4) override; }; -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/normalization.cpp b/plugins/wasi_nn/MLX/mlx/normalization.cpp index 75bf0708..3805a6ce 100644 --- a/plugins/wasi_nn/MLX/mlx/normalization.cpp +++ b/plugins/wasi_nn/MLX/mlx/normalization.cpp @@ -1,6 +1,9 @@ #include "normalization.h" + +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { mx::array RMSNorm::forward(mx::array Input) { return mx::fast::rms_norm(Input, Parameters.at("weight"), Eps); } -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/normalization.h b/plugins/wasi_nn/MLX/mlx/normalization.h index 1e6d9072..de3a6ce2 100644 --- a/plugins/wasi_nn/MLX/mlx/normalization.h +++ b/plugins/wasi_nn/MLX/mlx/normalization.h @@ -1,5 +1,6 @@ #include "base.h" +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { class RMSNorm : public nn::Module { float Eps; @@ -11,4 +12,5 @@ class RMSNorm : public nn::Module { mx::array forward(mx::array Input); }; -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp b/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp index fba24b14..f1d213f6 100644 --- a/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp +++ b/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp @@ -1,8 +1,10 @@ #include "positional_encoding.h" +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { mx::array RoPE::forward(mx::array Input, int Offset) { return mx::fast::rope(Input, Dims, Tranditional, Base, Scale, Offset); } -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/positional_encoding.h b/plugins/wasi_nn/MLX/mlx/positional_encoding.h index 5d997909..ef706e49 100644 --- a/plugins/wasi_nn/MLX/mlx/positional_encoding.h +++ b/plugins/wasi_nn/MLX/mlx/positional_encoding.h @@ -3,6 +3,7 @@ #include #include +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { class RoPE : public Module { int Dims; @@ -17,4 +18,5 @@ class RoPE : public Module { mx::array forward(mx::array Input, int Offset = 0); }; -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/quantized.cpp b/plugins/wasi_nn/MLX/mlx/quantized.cpp index d1f5fda1..90f5c870 100644 --- a/plugins/wasi_nn/MLX/mlx/quantized.cpp +++ b/plugins/wasi_nn/MLX/mlx/quantized.cpp @@ -3,6 +3,8 @@ #include #include #include + +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { mx::array QuantizedEmbedding::forward(mx::array Input) { auto S = Input.shape(); @@ -59,4 +61,5 @@ QuantizedLinear *QuantizedLinear::fromLinear(Linear *LinearModule, } return QuantizedModel; } -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/quantized.h b/plugins/wasi_nn/MLX/mlx/quantized.h index 3f884ec3..ff49ab29 100644 --- a/plugins/wasi_nn/MLX/mlx/quantized.h +++ b/plugins/wasi_nn/MLX/mlx/quantized.h @@ -5,6 +5,7 @@ #include #include +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { class QuantizedEmbedding : public Embedding { @@ -53,4 +54,5 @@ class QuantizedLinear : public Linear { int Bits = 4); }; -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/transformer.cpp b/plugins/wasi_nn/MLX/mlx/transformer.cpp index 7564043f..853b83f8 100644 --- a/plugins/wasi_nn/MLX/mlx/transformer.cpp +++ b/plugins/wasi_nn/MLX/mlx/transformer.cpp @@ -1,6 +1,7 @@ #include "transformer.h" #include +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { mx::array MultiHeadAttention::createAdditiveCausalMask(int N, mx::Dtype DType) { auto Indices = mx::arange(N); @@ -9,4 +10,5 @@ mx::array MultiHeadAttention::createAdditiveCausalMask(int N, mx::Dtype DType) { Mask = astype(Mask, DType) * -1e9; return Mask; } -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/transformer.h b/plugins/wasi_nn/MLX/mlx/transformer.h index dce95ea2..a8c94897 100644 --- a/plugins/wasi_nn/MLX/mlx/transformer.h +++ b/plugins/wasi_nn/MLX/mlx/transformer.h @@ -3,6 +3,7 @@ #include "linear.h" #include +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { class MultiHeadAttention : public Module { int NumHeads; @@ -44,4 +45,5 @@ class MultiHeadAttention : public Module { static mx::array createAdditiveCausalMask(int N, mx::Dtype DType = mx::float32); }; -} // namespace mlx::core::nn \ No newline at end of file +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/converter.cpp b/plugins/wasi_nn/MLX/model/converter.cpp index a30e6144..13fb54b7 100644 --- a/plugins/wasi_nn/MLX/model/converter.cpp +++ b/plugins/wasi_nn/MLX/model/converter.cpp @@ -5,6 +5,7 @@ #include #include +namespace WasmEdge::Host::WASINN::MLX { std::unordered_map weightsToMlx(std::string WeightPath) { const std::filesystem::path Path(WeightPath); @@ -81,4 +82,5 @@ llamaToMlxllm(std::string WeightPath) { } } return ModelWeights; -} \ No newline at end of file +} +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/converter.h b/plugins/wasi_nn/MLX/model/converter.h index c5dbd4b5..063556b1 100644 --- a/plugins/wasi_nn/MLX/model/converter.h +++ b/plugins/wasi_nn/MLX/model/converter.h @@ -3,9 +3,11 @@ #include "mlx/mlx.h" #include +namespace WasmEdge::Host::WASINN::MLX { #define strReplace(Str, From, To) Str.replace(Str.find(From), strlen(From), To) std::unordered_map weightsToMlx(std::string WeightPath); std::unordered_map -llamaToMlxllm(std::string WeightPath); \ No newline at end of file +llamaToMlxllm(std::string WeightPath); +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/registry.cpp b/plugins/wasi_nn/MLX/model/registry.cpp index 2ff48b92..95218fb3 100644 --- a/plugins/wasi_nn/MLX/model/registry.cpp +++ b/plugins/wasi_nn/MLX/model/registry.cpp @@ -1,5 +1,6 @@ #include "registry.h" +namespace WasmEdge::Host::WASINN::MLX { Transformer *llama38b(int VocabSize, float NormEps, float RopeTheta, bool RopeTraditional) { return new Transformer(4096, std::vector{14336}, VocabSize, 32, @@ -19,4 +20,5 @@ Transformer *tinyLlama11BChatV10(int VocabSize, float NormEps, float RopeTheta, return new Transformer(2048, std::vector{5632}, VocabSize, 22, std::vector{32}, std::vector{4}, NormEps, {}, RopeTraditional, RopeTheta); -} \ No newline at end of file +} +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/registry.h b/plugins/wasi_nn/MLX/model/registry.h index 81774bd0..871a27d3 100644 --- a/plugins/wasi_nn/MLX/model/registry.h +++ b/plugins/wasi_nn/MLX/model/registry.h @@ -2,6 +2,7 @@ #include "transformer.h" +namespace WasmEdge::Host::WASINN::MLX { Transformer *llama38b(int VocabSize = 32000, float NormEps = 1e-5, float RopeTheta = 10000.0, bool RopeTraditional = false); @@ -11,4 +12,5 @@ Transformer *llama27bChat(int VocabSize = 32000, float NormEps = 1e-5, Transformer *tinyLlama11BChatV10(int VocabSize = 32000, float NormEps = 1e-5, float RopeTheta = 10000.0, - bool RopeTraditional = false); \ No newline at end of file + bool RopeTraditional = false); +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/transformer.cpp b/plugins/wasi_nn/MLX/model/transformer.cpp index f6271664..d914ca39 100644 --- a/plugins/wasi_nn/MLX/model/transformer.cpp +++ b/plugins/wasi_nn/MLX/model/transformer.cpp @@ -11,6 +11,7 @@ #include #include +namespace WasmEdge::Host::WASINN::MLX { mx::array RMSNorm::forward(mx::array Input) { return mx::fast::rms_norm(Input, 1.0 + Parameters.at("weight"), Eps); } @@ -56,14 +57,14 @@ mx::array MLP::forward(mx::array Input) { if (Gemma) { return dynamic_cast(Submodules["down_proj"]) ->forward( - gelu(dynamic_cast(Submodules["gate_proj"]) - ->forward(Input)) * + mlx::core::gelu(dynamic_cast(Submodules["gate_proj"]) + ->forward(Input)) * dynamic_cast(Submodules["up_proj"])->forward(Input)); } return dynamic_cast(Submodules["down_proj"]) ->forward( - silu(dynamic_cast(Submodules["gate_proj"]) - ->forward(Input)) * + mlx::core::silu(dynamic_cast(Submodules["gate_proj"]) + ->forward(Input)) * dynamic_cast(Submodules["up_proj"])->forward(Input)); } std::tuple> @@ -98,14 +99,16 @@ Transformer::embed( mx::array Input, std::optional>> KVCachePar, bool Norm) { - mx::array H = dynamic_cast(Submodules["token_embed"]) - ->forward(Input); + mx::array H = + dynamic_cast(Submodules["token_embed"]) + ->forward(Input); if (Gemma) { H = H * (pow(Dim, 0.5)); } std::optional Mask; if (H.shape()[1] > 1) { - Mask = mx::nn::MultiHeadAttention::createAdditiveCausalMask(H.shape()[1]); + Mask = mlx::core::nn::MultiHeadAttention::createAdditiveCausalMask( + H.shape()[1]); Mask = astype(*Mask, H.dtype()); } std::vector> KVCache; @@ -138,7 +141,7 @@ Transformer::forward( auto [X, KVCache] = embed(Input, KVCachePar, true); mx::array Out = {}; if (EmbedAsHead) { - Out = dynamic_cast(Submodules["token_embed"]) + Out = dynamic_cast(Submodules["token_embed"]) ->asLinear(X); } else { Out = dynamic_cast(Submodules["head"])->forward(X); @@ -183,4 +186,5 @@ Transformer::nextGenerate( NextY = mx::random::categorical(Logits * (1.0 / *Temp)); } return {NextY, KVCache}; -} \ No newline at end of file +} +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/transformer.h b/plugins/wasi_nn/MLX/model/transformer.h index 8db5d392..d0dac595 100644 --- a/plugins/wasi_nn/MLX/model/transformer.h +++ b/plugins/wasi_nn/MLX/model/transformer.h @@ -13,6 +13,7 @@ #include #include +namespace WasmEdge::Host::WASINN::MLX { namespace nn = mlx::core::nn; class RMSNorm : public nn::Module { @@ -196,4 +197,5 @@ class Transformer : public nn::Module { nextGenerate(mx::array Y, std::optional Temp = 0.0, std::optional>> KVCachePar = {}); -}; \ No newline at end of file +}; +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/utils.cpp b/plugins/wasi_nn/MLX/model/utils.cpp index aba35791..a92eba7f 100644 --- a/plugins/wasi_nn/MLX/model/utils.cpp +++ b/plugins/wasi_nn/MLX/model/utils.cpp @@ -1,5 +1,7 @@ #include "utils.h" #include + +namespace WasmEdge::Host::WASINN::MLX { std::vector splitString(const std::string &S, char Delim) { std::vector Result; std::stringstream SS(S); @@ -46,4 +48,5 @@ void saveWeights(const mx::array &Weights, const std::string &Path) { spdlog::error("Unsupported file format"); assumingUnreachable(); } -} \ No newline at end of file +} +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/utils.h b/plugins/wasi_nn/MLX/model/utils.h index 7faa16eb..606a2245 100644 --- a/plugins/wasi_nn/MLX/model/utils.h +++ b/plugins/wasi_nn/MLX/model/utils.h @@ -4,10 +4,13 @@ #include #include +namespace WasmEdge::Host::WASINN::MLX { + std::vector splitString(const std::string &S, char Delim); std::string joinString(std::vector &S, char Delim); bool endsWith(std::string const &Value, std::string const &Ending); bool startsWith(std::string const &Value, std::string const &Starting); void saveWeights(const std::unordered_map &Weights, const std::string Path); -void saveWeights(const mx::array &Weights, const std::string &Path); \ No newline at end of file +void saveWeights(const mx::array &Weights, const std::string &Path); +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/prompt/prompt.cpp b/plugins/wasi_nn/MLX/prompt/prompt.cpp index 9434b04e..8e22fc71 100644 --- a/plugins/wasi_nn/MLX/prompt/prompt.cpp +++ b/plugins/wasi_nn/MLX/prompt/prompt.cpp @@ -1,6 +1,8 @@ #include "prompt.h" #include +namespace WasmEdge::Host::WASINN::MLX { + std::string TinyLLaMAPrompt::prepare(std::string Prompt) { return SystemStart + TextEnd + Prompt + TextEnd + Assistant; } @@ -11,4 +13,5 @@ std::string LLaMA3Prompt::prepare(std::string Prompt) { return PropmtStart + StartHeader + "system" + EndHeader + TextEnd + Prompt + EndHeader + TextEnd + StartHeader + "user" + EndHeader + Prompt + TextEnd + StartHeader + "assistant" + EndHeader; -} \ No newline at end of file +} +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/prompt/prompt.h b/plugins/wasi_nn/MLX/prompt/prompt.h index 06eee87a..20d46530 100644 --- a/plugins/wasi_nn/MLX/prompt/prompt.h +++ b/plugins/wasi_nn/MLX/prompt/prompt.h @@ -4,6 +4,8 @@ #include #include +namespace WasmEdge::Host::WASINN::MLX { + class BasePrompt { public: std::string TextEnd; @@ -49,4 +51,5 @@ class LLaMA3Prompt : public BasePrompt { TextEnd = "<|eot_id|>"; } std::string prepare(std::string Prompt) override; -}; \ No newline at end of file +}; +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp index 706382c1..bbaa483e 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/mlx.cpp @@ -198,7 +198,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, return ErrNo::InvalidArgument; } } - + GraphId = Env.NNGraph.size() - 1; return WASINN::ErrNo::Success; } @@ -210,8 +210,7 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, } Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, - uint32_t Index, - const TensorData &Tensor) noexcept { + uint32_t, const TensorData &Tensor) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); if (GraphRef.EnableDebugLog) { @@ -224,7 +223,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, - uint32_t Index, Span OutBuffer, + uint32_t, Span OutBuffer, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); From 8047ff79b79df754ae368588f4c7188db0cee358 Mon Sep 17 00:00:00 2001 From: grorge Date: Mon, 16 Sep 2024 11:13:43 +0800 Subject: [PATCH 450/623] [WASI-NN] mlx: add quantize option Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 85 +++++++++++++------------------- plugins/wasi_nn/mlx.cpp | 28 +++++++++++ plugins/wasi_nn/mlx.h | 3 ++ test/plugins/wasi_nn/wasi_nn.cpp | 8 +-- 4 files changed, 71 insertions(+), 53 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index ec5c7111..789e4b3b 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -1,55 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2024 Second State INC -if(BACKEND STREQUAL "mlx") - wasmedge_add_library(wasmedgePluginWasiNN - SHARED - wasinnenv.cpp - wasinnfunc.cpp - wasinnmodule.cpp - openvino.cpp - onnx.cpp - tf.cpp - torch.cpp - tfl.cpp - ggml.cpp - neuralspeed.cpp - piper.cpp - whispercpp.cpp - chattts.cpp - mlx.cpp - MLX/prompt/prompt.cpp - MLX/model/transformer.cpp - MLX/model/converter.cpp - MLX/model/utils.cpp - MLX/model/registry.cpp - MLX/mlx/base.cpp - MLX/mlx/linear.cpp - MLX/mlx/positional_encoding.cpp - MLX/mlx/activations.cpp - MLX/mlx/embedding.cpp - MLX/mlx/normalization.cpp - MLX/mlx/transformer.cpp - MLX/mlx/quantized.cpp - ) -else() - wasmedge_add_library(wasmedgePluginWasiNN - SHARED - wasinnenv.cpp - wasinnfunc.cpp - wasinnmodule.cpp - openvino.cpp - onnx.cpp - tf.cpp - torch.cpp - tfl.cpp - ggml.cpp - neuralspeed.cpp - piper.cpp - whispercpp.cpp - chattts.cpp - mlx.cpp - ) -endif() + +wasmedge_add_library(wasmedgePluginWasiNN + SHARED + wasinnenv.cpp + wasinnfunc.cpp + wasinnmodule.cpp + openvino.cpp + onnx.cpp + tf.cpp + torch.cpp + tfl.cpp + ggml.cpp + neuralspeed.cpp + piper.cpp + whispercpp.cpp + chattts.cpp + mlx.cpp +) foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) string(TOLOWER ${BACKEND} BACKEND) @@ -224,6 +192,22 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) simdjson::simdjson ) elseif(BACKEND STREQUAL "mlx") + target_sources(wasmedgePluginWasiNN + PRIVATE + MLX/prompt/prompt.cpp + MLX/model/transformer.cpp + MLX/model/converter.cpp + MLX/model/utils.cpp + MLX/model/registry.cpp + MLX/mlx/base.cpp + MLX/mlx/linear.cpp + MLX/mlx/positional_encoding.cpp + MLX/mlx/activations.cpp + MLX/mlx/embedding.cpp + MLX/mlx/normalization.cpp + MLX/mlx/transformer.cpp + MLX/mlx/quantized.cpp + ) wasmedge_setup_simdjson() find_package(MLX CONFIG) if(MLX_FOUND) @@ -272,6 +256,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(tokenizers) + set_property(TARGET tokenizer_cpp_objs PROPERTY POSITION_INDEPENDENT_CODE ON) target_include_directories(wasmedgePluginWasiNN PRIVATE ${tokenizers_SOURCE_DIR}/include) target_link_libraries(wasmedgePluginWasiNN PRIVATE tokenizers_cpp) diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp index bbaa483e..b4de8623 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/mlx.cpp @@ -121,6 +121,25 @@ Expect load(WASINN::WasiNNEnvironment &Env, } GraphRef.MaxToken = MaxToken; } + if (Doc.at_key("q_bits").error() == simdjson::SUCCESS && + Doc.at_key("group_size").error() == simdjson::SUCCESS && + Doc.at_key("is_quantized").error() == simdjson::SUCCESS) { + uint64_t QBits; + uint64_t GroupSize; + bool IsQuantized; + auto ErrQBits = Doc["q_bits"].get().get(QBits); + auto ErrGroupSize = Doc["group_size"].get().get(GroupSize); + auto ErrIsQuantized = Doc["is_quantized"].get().get(IsQuantized); + if (ErrQBits || ErrGroupSize || ErrIsQuantized) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the q_bits or group_size option."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + GraphRef.IsQuantized = IsQuantized; + GraphRef.QBits = QBits; + GraphRef.GroupSize = GroupSize; + } // Load tokenizer. if (!TokenizerPath.empty()) { @@ -149,6 +168,10 @@ Expect load(WASINN::WasiNNEnvironment &Env, return ErrNo::InvalidArgument; } + if (GraphRef.QBits != 0 && GraphRef.GroupSize != 0 && GraphRef.IsQuantized) { + GraphRef.Model->toQuantized(GraphRef.GroupSize, GraphRef.QBits); + } + // Handle the model path. for (size_t Idx = 0; Idx < Builders.size() - 1; Idx++) { auto Weight = Builders[Idx]; @@ -198,6 +221,11 @@ Expect load(WASINN::WasiNNEnvironment &Env, return ErrNo::InvalidArgument; } } + + if (GraphRef.QBits != 0 && GraphRef.GroupSize != 0 && !GraphRef.IsQuantized) { + GraphRef.Model->toQuantized(GraphRef.GroupSize, GraphRef.QBits); + } + GraphId = Env.NNGraph.size() - 1; return WASINN::ErrNo::Success; } diff --git a/plugins/wasi_nn/mlx.h b/plugins/wasi_nn/mlx.h index 5be43305..0c680867 100644 --- a/plugins/wasi_nn/mlx.h +++ b/plugins/wasi_nn/mlx.h @@ -31,7 +31,10 @@ struct Graph { Transformer *Model; double Temp = 0.0; bool EnableDebugLog = false; + bool IsQuantized = false; uint64_t MaxToken = 1024; + uint64_t QBits = 0; + uint64_t GroupSize = 0; BasePrompt Prmopt; }; struct Context { diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 45a18c4b..01ca1d53 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -2823,9 +2823,11 @@ TEST(WasiNNTest, MLXBackend) { static_cast(ErrNo::InvalidArgument)); } // Test: load -- load successfully. - std::string Config = "{\"model_type\":\"tiny_llama_1.1B_chat_v1.0\", " - "\"tokenizer\":\"" + - Tokenizer + "\"}"; + std::string Config = + "{\"model_type\":\"tiny_llama_1.1B_chat_v1.0\", " + "\"tokenizer\":\"" + + Tokenizer + + "\", \"q_bits\": 4, \"group_size\": 128, \"is_quantized\": false}"; std::vector ConfigData(Config.begin(), Config.end()); BuilderPtr = LoadEntryPtr; writeFatPointer(MemInst, StorePtr, WeightRead.size(), BuilderPtr); From a7167576191d3b1c0c884e1d1c985d1a1eb73d85 Mon Sep 17 00:00:00 2001 From: grorge Date: Sat, 21 Sep 2024 12:48:17 +0800 Subject: [PATCH 451/623] [WASI-NN] mlx: change raw pointer to smart pointer Signed-off-by: grorge --- plugins/wasi_nn/MLX/mlx/base.cpp | 31 +++--- plugins/wasi_nn/MLX/mlx/base.h | 37 +++---- plugins/wasi_nn/MLX/mlx/embedding.cpp | 7 +- plugins/wasi_nn/MLX/mlx/embedding.h | 3 +- plugins/wasi_nn/MLX/mlx/linear.cpp | 6 +- plugins/wasi_nn/MLX/mlx/linear.h | 3 +- plugins/wasi_nn/MLX/mlx/normalization.h | 1 + .../wasi_nn/MLX/mlx/positional_encoding.cpp | 1 + plugins/wasi_nn/MLX/mlx/positional_encoding.h | 1 + plugins/wasi_nn/MLX/mlx/quantized.cpp | 20 ++-- plugins/wasi_nn/MLX/mlx/quantized.h | 11 ++- plugins/wasi_nn/MLX/mlx/transformer.cpp | 2 + plugins/wasi_nn/MLX/mlx/transformer.h | 16 +++- plugins/wasi_nn/MLX/model/converter.cpp | 24 ++--- plugins/wasi_nn/MLX/model/converter.h | 1 + plugins/wasi_nn/MLX/model/registry.cpp | 34 ++++--- plugins/wasi_nn/MLX/model/registry.h | 21 ++-- plugins/wasi_nn/MLX/model/transformer.cpp | 96 +++++++++++-------- plugins/wasi_nn/MLX/model/transformer.h | 77 +++++++++------ plugins/wasi_nn/MLX/model/utils.cpp | 4 +- plugins/wasi_nn/mlx.cpp | 3 +- plugins/wasi_nn/mlx.h | 13 +-- 22 files changed, 236 insertions(+), 176 deletions(-) diff --git a/plugins/wasi_nn/MLX/mlx/base.cpp b/plugins/wasi_nn/MLX/mlx/base.cpp index ac9ca063..4a395e32 100644 --- a/plugins/wasi_nn/MLX/mlx/base.cpp +++ b/plugins/wasi_nn/MLX/mlx/base.cpp @@ -1,5 +1,6 @@ #include "base.h" #include "../model/utils.h" +#include #include #include @@ -12,25 +13,22 @@ mx::array &Module::registerParameter(std::string Name, mx::array &&W) { return Parameters.at(Name); } void Module::update(std::unordered_map Parameters) { - for (auto &[k, v] : Parameters) { - apply(k, v); + for (auto &[K, V] : Parameters) { + apply(K, V); } } -nn::Module *Module::toQuantized(int GroupSize, int Bits) { - for (auto &[k, v] : Submodules) { - auto *OldModule = v; - v = v->toQuantized(GroupSize, Bits); - if (OldModule != v) { - delete OldModule; - } +std::shared_ptr Module::toQuantized(int GroupSize, int Bits) { + for (auto &[K, V] : Submodules) { + const auto OldModule = V; + V = V->toQuantized(GroupSize, Bits); } - return this; + return shared_from_this(); } void Module::apply(std::string Key, mx::array Value) { std::vector SplitKey = splitString(Key, '.'); if (SplitKey.size() == 1) { if (Parameters.find(Key) == Parameters.end()) { - spdlog::error("Unsupported weight: {}", Key); + spdlog::error("[WASI-NN] MLX backend: Unsupported weight: {}"sv, Key); assumingUnreachable(); } this->Parameters.at(Key) = Value; @@ -42,7 +40,8 @@ void Module::apply(std::string Key, mx::array Value) { SplitKey.erase(SplitKey.begin()); } if (Submodules.find(LayerName) == Submodules.end()) { - spdlog::error("Unsupported Layer: {}", LayerName); + spdlog::error("[WASI-NN] MLX backend: Unsupported Layer: {}"sv, + LayerName); assumingUnreachable(); } Submodules.at(LayerName)->apply(joinString(SplitKey, '.'), Value); @@ -51,12 +50,12 @@ void Module::apply(std::string Key, mx::array Value) { std::unordered_map Module::getWeigts(const std::string &Prefix) { std::unordered_map Weights; - for (auto &[k, v] : Submodules) { - auto Subweights = v->getWeigts(Prefix + Name + "."); + for (auto &[K, V] : Submodules) { + auto Subweights = V->getWeigts(Prefix + Name + "."); Weights.insert(Subweights.begin(), Subweights.end()); } - for (auto &[k, v] : Parameters) { - Weights.insert({Prefix + Name + "." + k, v}); + for (auto &[K, V] : Parameters) { + Weights.insert({Prefix + Name + "." + K, V}); } return Weights; } diff --git a/plugins/wasi_nn/MLX/mlx/base.h b/plugins/wasi_nn/MLX/mlx/base.h index d224c034..66482f36 100644 --- a/plugins/wasi_nn/MLX/mlx/base.h +++ b/plugins/wasi_nn/MLX/mlx/base.h @@ -1,36 +1,36 @@ #pragma once - #include "common/errcode.h" #include "mlx/mlx.h" #include #include #include +#include #include #include +using namespace std::literals::string_view_literals; -namespace mx = mlx::core; namespace WasmEdge::Host::WASINN::MLX { +namespace mx = mlx::core; + namespace mlx::core::nn { -class Module { +class Module : public std::enable_shared_from_this { public: - virtual ~Module() { - for (auto Module : Submodules) { - delete Module.second; - } - } + virtual ~Module() = default; std::string Name; std::unordered_map Parameters{}; - std::unordered_map Submodules{}; + std::unordered_map> Submodules{}; mx::array ®isterParameter(std::string Name, mx::array &&W); std::unordered_map getWeigts(const std::string &Prefix = "model"); - virtual nn::Module *toQuantized(int GroupSize = 64, int Bits = 4); + virtual std::shared_ptr toQuantized(int GroupSize = 64, + int Bits = 4); void update(std::unordered_map Parameters); void apply(std::string Key, mx::array Parameters); - template void registerModule(std::string ModuleName, T *M) { + template + void registerModule(std::string ModuleName, std::shared_ptr M) { using DecayedT = std::decay_t; if (!std::is_base_of::value) { - spdlog::error("Invalid subModule."); + spdlog::error("[WASI-NN] MLX backend: Invalid subModule."sv); assumingUnreachable(); } @@ -38,14 +38,15 @@ class Module { Submodules.insert({ModuleName, M}); Submodules.at(ModuleName)->Name = ModuleName; } else { - spdlog::error("Module already exists."); + spdlog::error("[WASI-NN] MLX backend: Module already exists."sv); assumingUnreachable(); } } template - void registerLayer(std::string ModuleName, std::vector &Layers) { + void registerLayer(std::string ModuleName, + std::vector> &Layers) { if (!std::is_base_of::value) { - spdlog::error("Invalid subModule."); + spdlog::error("[WASI-NN] MLX backend: Invalid subModule."sv); assumingUnreachable(); } for (size_t Idx = 0; Idx < Layers.size(); Idx++) { @@ -53,11 +54,11 @@ class Module { } } }; +} // namespace mlx::core::nn + template void printVec(std::vector Ve) { for (auto I : Ve) { - spdlog::debug("{} ", I); + spdlog::debug("[WASI-NN] MLX backend: {} ."sv, I); } } -} // namespace mlx::core::nn - } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/embedding.cpp b/plugins/wasi_nn/MLX/mlx/embedding.cpp index d14d7ccc..5ec56806 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.cpp +++ b/plugins/wasi_nn/MLX/mlx/embedding.cpp @@ -1,6 +1,7 @@ #include "embedding.h" #include "base.h" #include "quantized.h" +#include #include #include @@ -13,8 +14,10 @@ mx::array Embedding::forward(mx::array Input) { mx::array Embedding::asLinear(mx::array Input) { return matmul(Input, transpose(Parameters.at("weight"))); } -nn::Module *Embedding::toQuantized(int GroupSize, int Bits) { - auto *QuantModel = QuantizedEmbedding::fromEmbedding(this, GroupSize, Bits); +std::shared_ptr Embedding::toQuantized(int GroupSize, int Bits) { + auto QuantModel = QuantizedEmbedding::fromEmbedding( + std::dynamic_pointer_cast(shared_from_this()), GroupSize, + Bits); QuantModel->Name = Name; return QuantModel; } diff --git a/plugins/wasi_nn/MLX/mlx/embedding.h b/plugins/wasi_nn/MLX/mlx/embedding.h index 0bed7181..f2943eba 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.h +++ b/plugins/wasi_nn/MLX/mlx/embedding.h @@ -16,7 +16,8 @@ class Embedding : public Module { } virtual mx::array forward(mx::array Input); mx::array asLinear(mx::array Input); - nn::Module *toQuantized(int GroupSize = 64, int Bits = 4) override; + std::shared_ptr toQuantized(int GroupSize = 64, + int Bits = 4) override; }; } // namespace mlx::core::nn diff --git a/plugins/wasi_nn/MLX/mlx/linear.cpp b/plugins/wasi_nn/MLX/mlx/linear.cpp index 1b7addf0..60018a0c 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.cpp +++ b/plugins/wasi_nn/MLX/mlx/linear.cpp @@ -1,6 +1,7 @@ #include "linear.h" #include "base.h" #include "quantized.h" +#include namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { @@ -12,8 +13,9 @@ mx::array Linear::forward(mx::array Input) { return matmul(Input, transpose(Parameters.at("weight"))); } -nn::Module *Linear::toQuantized(int GroupSize, int Bits) { - auto *QuantModel = QuantizedLinear::fromLinear(this, GroupSize, Bits); +std::shared_ptr Linear::toQuantized(int GroupSize, int Bits) { + auto QuantModel = QuantizedLinear::fromLinear( + std::dynamic_pointer_cast(shared_from_this()), GroupSize, Bits); QuantModel->Name = Name; return QuantModel; } diff --git a/plugins/wasi_nn/MLX/mlx/linear.h b/plugins/wasi_nn/MLX/mlx/linear.h index 5724120a..d6d0eb58 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.h +++ b/plugins/wasi_nn/MLX/mlx/linear.h @@ -28,7 +28,8 @@ class Linear : public Module { } } virtual mx::array forward(mx::array Input); - nn::Module *toQuantized(int GroupSize = 64, int Bits = 4) override; + std::shared_ptr toQuantized(int GroupSize = 64, + int Bits = 4) override; }; } // namespace mlx::core::nn } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/normalization.h b/plugins/wasi_nn/MLX/mlx/normalization.h index de3a6ce2..72bf3ba7 100644 --- a/plugins/wasi_nn/MLX/mlx/normalization.h +++ b/plugins/wasi_nn/MLX/mlx/normalization.h @@ -1,6 +1,7 @@ #include "base.h" namespace WasmEdge::Host::WASINN::MLX { + namespace mlx::core::nn { class RMSNorm : public nn::Module { float Eps; diff --git a/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp b/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp index f1d213f6..df04a92e 100644 --- a/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp +++ b/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp @@ -1,6 +1,7 @@ #include "positional_encoding.h" namespace WasmEdge::Host::WASINN::MLX { + namespace mlx::core::nn { mx::array RoPE::forward(mx::array Input, int Offset) { return mx::fast::rope(Input, Dims, Tranditional, Base, Scale, Offset); diff --git a/plugins/wasi_nn/MLX/mlx/positional_encoding.h b/plugins/wasi_nn/MLX/mlx/positional_encoding.h index ef706e49..8f8f97ae 100644 --- a/plugins/wasi_nn/MLX/mlx/positional_encoding.h +++ b/plugins/wasi_nn/MLX/mlx/positional_encoding.h @@ -4,6 +4,7 @@ #include namespace WasmEdge::Host::WASINN::MLX { + namespace mlx::core::nn { class RoPE : public Module { int Dims; diff --git a/plugins/wasi_nn/MLX/mlx/quantized.cpp b/plugins/wasi_nn/MLX/mlx/quantized.cpp index 90f5c870..c4b122dc 100644 --- a/plugins/wasi_nn/MLX/mlx/quantized.cpp +++ b/plugins/wasi_nn/MLX/mlx/quantized.cpp @@ -1,5 +1,6 @@ #include "quantized.h" #include +#include #include #include #include @@ -26,12 +27,12 @@ mx::array QuantizedLinear::forward(mx::array Input) { } return Out; } -QuantizedEmbedding * -QuantizedEmbedding::fromEmbedding(Embedding *EmbeddingModule, int GroupSize, - int Bits) { +std::shared_ptr +QuantizedEmbedding::fromEmbedding(std::shared_ptr EmbeddingModule, + int GroupSize, int Bits) { auto EmbeddingShape = EmbeddingModule->Parameters.at("weight").shape(); - auto *QuantizedModel = new QuantizedEmbedding( - EmbeddingShape[0], EmbeddingShape[1], GroupSize, Bits); + auto QuantizedModel = std::make_shared(QuantizedEmbedding( + EmbeddingShape[0], EmbeddingShape[1], GroupSize, Bits)); auto Quantized = mx::quantize(EmbeddingModule->Parameters.at("weight"), GroupSize, Bits); QuantizedModel->Parameters.insert_or_assign("weight", std::get<0>(Quantized)); @@ -41,13 +42,14 @@ QuantizedEmbedding::fromEmbedding(Embedding *EmbeddingModule, int GroupSize, "biases", std::move(std::get<2>(Quantized))); return QuantizedModel; } -QuantizedLinear *QuantizedLinear::fromLinear(Linear *LinearModule, - int GroupSize, int Bits) { +std::shared_ptr +QuantizedLinear::fromLinear(std::shared_ptr LinearModule, int GroupSize, + int Bits) { auto LinearShape = LinearModule->Parameters.at("weight").shape(); const bool EnableBias = LinearModule->Parameters.find("bias") != LinearModule->Parameters.end(); - auto *QuantizedModel = new QuantizedLinear(LinearShape[0], LinearShape[1], - EnableBias, GroupSize, Bits); + auto QuantizedModel = std::make_shared(QuantizedLinear( + LinearShape[0], LinearShape[1], EnableBias, GroupSize, Bits)); auto Quantized = mx::quantize(LinearModule->Parameters.at("weight"), GroupSize, Bits); QuantizedModel->Parameters.insert_or_assign("weight", std::get<0>(Quantized)); diff --git a/plugins/wasi_nn/MLX/mlx/quantized.h b/plugins/wasi_nn/MLX/mlx/quantized.h index ff49ab29..5b3ad9c8 100644 --- a/plugins/wasi_nn/MLX/mlx/quantized.h +++ b/plugins/wasi_nn/MLX/mlx/quantized.h @@ -6,6 +6,7 @@ #include namespace WasmEdge::Host::WASINN::MLX { + namespace mlx::core::nn { class QuantizedEmbedding : public Embedding { @@ -27,8 +28,9 @@ class QuantizedEmbedding : public Embedding { registerParameter("biases", std::move(std::get<2>(Quantized))); } mx::array forward(mx::array Input) override; - static QuantizedEmbedding *fromEmbedding(Embedding *EmbeddingModule, - int GroupSize = 64, int Bits = 4); + static std::shared_ptr + fromEmbedding(std::shared_ptr EmbeddingModule, int GroupSize = 64, + int Bits = 4); }; class QuantizedLinear : public Linear { @@ -50,8 +52,9 @@ class QuantizedLinear : public Linear { } } mx::array forward(mx::array Input) override; - static QuantizedLinear *fromLinear(Linear *LinearModule, int GroupSize = 64, - int Bits = 4); + static std::shared_ptr + fromLinear(std::shared_ptr LinearModule, int GroupSize = 64, + int Bits = 4); }; } // namespace mlx::core::nn diff --git a/plugins/wasi_nn/MLX/mlx/transformer.cpp b/plugins/wasi_nn/MLX/mlx/transformer.cpp index 853b83f8..da15eba4 100644 --- a/plugins/wasi_nn/MLX/mlx/transformer.cpp +++ b/plugins/wasi_nn/MLX/mlx/transformer.cpp @@ -2,11 +2,13 @@ #include namespace WasmEdge::Host::WASINN::MLX { + namespace mlx::core::nn { mx::array MultiHeadAttention::createAdditiveCausalMask(int N, mx::Dtype DType) { auto Indices = mx::arange(N); // mask = indices[:, None] < indices[None] auto Mask = reshape(Indices, {N, 1}) < reshape(Indices, {1, N}); + // using 1e9 instead of INF, and softmax(full(1e9)) != nan Mask = astype(Mask, DType) * -1e9; return Mask; } diff --git a/plugins/wasi_nn/MLX/mlx/transformer.h b/plugins/wasi_nn/MLX/mlx/transformer.h index a8c94897..9873ec76 100644 --- a/plugins/wasi_nn/MLX/mlx/transformer.h +++ b/plugins/wasi_nn/MLX/mlx/transformer.h @@ -4,6 +4,7 @@ #include namespace WasmEdge::Host::WASINN::MLX { + namespace mlx::core::nn { class MultiHeadAttention : public Module { int NumHeads; @@ -17,7 +18,8 @@ class MultiHeadAttention : public Module { std::optional ValueOutputDims = {}, bool Bias = false) : NumHeads(NumHeads) { if (Dims % NumHeads != 0) { - spdlog::error("Dims must be divisible by NumHeads"); + spdlog::error( + "[WASI-NN] MLX backend: Dims must be divisible by NumHeads"sv); assumingUnreachable(); } if (!QueryInputDims) { @@ -35,10 +37,14 @@ class MultiHeadAttention : public Module { if (!ValueOutputDims) { ValueOutputDims = Dims; } - registerModule("query_proj", new Linear(*QueryInputDims, Dims, Bias)); - registerModule("key_proj", new Linear(*KeyInputDims, Dims, Bias)); - registerModule("value_proj", new Linear(*ValueInputDims, *ValueDims, Bias)); - registerModule("out_proj", new Linear(*ValueDims, *ValueOutputDims, Bias)); + registerModule("query_proj", std::make_shared( + Linear(*QueryInputDims, Dims, Bias))); + registerModule("key_proj", + std::make_shared(Linear(*KeyInputDims, Dims, Bias))); + registerModule("value_proj", std::make_shared(Linear( + *ValueInputDims, *ValueDims, Bias))); + registerModule("out_proj", std::make_shared( + Linear(*ValueDims, *ValueOutputDims, Bias))); }; mx::array forward(mx::array Queries, mx::array Keys, mx::array Values, mx::array Mask); diff --git a/plugins/wasi_nn/MLX/model/converter.cpp b/plugins/wasi_nn/MLX/model/converter.cpp index 13fb54b7..35e57265 100644 --- a/plugins/wasi_nn/MLX/model/converter.cpp +++ b/plugins/wasi_nn/MLX/model/converter.cpp @@ -6,6 +6,7 @@ #include namespace WasmEdge::Host::WASINN::MLX { + std::unordered_map weightsToMlx(std::string WeightPath) { const std::filesystem::path Path(WeightPath); @@ -20,16 +21,17 @@ weightsToMlx(std::string WeightPath) { return Loaded; } if (endsWith(WeightPath, ".safetensors")) { - spdlog::info("Loading model from .safetensors file...\n"); + spdlog::info( + "[WASI-NN] MLX backend: Loading model from .safetensors file...\n"sv); const mx::SafetensorsLoad Loaded = mx::load_safetensors(WeightPath); return Loaded.first; } if (endsWith(WeightPath, ".gguf")) { - spdlog::info("Loading model from .gguf file...\n"); + spdlog::info("[WASI-NN] MLX backend: Loading model from .gguf file...\n"sv); const mx::GGUFLoad Loaded = mx::load_gguf(WeightPath); return Loaded.first; } - spdlog::error("Can not regonize model file\n"); + spdlog::error("[WASI-NN] MLX backend: Can not regonize model file\n"sv); assumingUnreachable(); } @@ -37,8 +39,8 @@ std::unordered_map llamaToMlxllm(std::string WeightPath) { std::unordered_map ModelWeights; auto Weight = weightsToMlx(WeightPath); - for (auto &[k, v] : Weight) { - std::string NewKey = k; + for (auto &[K, V] : Weight) { + std::string NewKey = K; if (startsWith(NewKey, "model.")) { strReplace(NewKey, "model.", ""); } @@ -52,21 +54,21 @@ llamaToMlxllm(std::string WeightPath) { SplitKey.end()) { ModelWeights.insert({SplitKey[0] + "." + SplitKey[1] + ".attention." + SplitKey[3] + "." + SplitKey[4], - v}); + V}); } else if (find(SplitKey.begin(), SplitKey.end(), "mlp") != SplitKey.end()) { - ModelWeights.insert({NewKey, v}); + ModelWeights.insert({NewKey, V}); } else { const std::unordered_map KeyMap = { {"input_layernorm", "attention_norm"}, {"post_attention_layernorm", "mlp_norm"}}; if (KeyMap.find(SplitKey[2]) == KeyMap.end()) { - ModelWeights.insert({NewKey, v}); + ModelWeights.insert({NewKey, V}); } else { ModelWeights.insert({SplitKey[0] + "." + SplitKey[1] + "." + KeyMap.at(SplitKey[2]) + "." + SplitKey[3], - v}); + V}); } } } else { @@ -75,9 +77,9 @@ llamaToMlxllm(std::string WeightPath) { {"lm_head", "head"}, {"norm", "norm"}}; if (KeyMap.find(SplitKey[0]) == KeyMap.end()) { - ModelWeights.insert({NewKey, v}); + ModelWeights.insert({NewKey, V}); } else { - ModelWeights.insert({KeyMap.at(SplitKey[0]) + "." + SplitKey[1], v}); + ModelWeights.insert({KeyMap.at(SplitKey[0]) + "." + SplitKey[1], V}); } } } diff --git a/plugins/wasi_nn/MLX/model/converter.h b/plugins/wasi_nn/MLX/model/converter.h index 063556b1..dda09837 100644 --- a/plugins/wasi_nn/MLX/model/converter.h +++ b/plugins/wasi_nn/MLX/model/converter.h @@ -4,6 +4,7 @@ #include namespace WasmEdge::Host::WASINN::MLX { + #define strReplace(Str, From, To) Str.replace(Str.find(From), strlen(From), To) std::unordered_map weightsToMlx(std::string WeightPath); diff --git a/plugins/wasi_nn/MLX/model/registry.cpp b/plugins/wasi_nn/MLX/model/registry.cpp index 95218fb3..7b4d5661 100644 --- a/plugins/wasi_nn/MLX/model/registry.cpp +++ b/plugins/wasi_nn/MLX/model/registry.cpp @@ -1,24 +1,28 @@ #include "registry.h" +#include "transformer.h" namespace WasmEdge::Host::WASINN::MLX { -Transformer *llama38b(int VocabSize, float NormEps, float RopeTheta, - bool RopeTraditional) { - return new Transformer(4096, std::vector{14336}, VocabSize, 32, - std::vector{32}, std::vector{8}, NormEps, {}, - RopeTraditional, RopeTheta); + +std::shared_ptr llama38b(int VocabSize, float NormEps, + float RopeTheta, bool RopeTraditional) { + return std::make_shared(Transformer( + 4096, std::vector{14336}, VocabSize, 32, std::vector{32}, + std::vector{8}, NormEps, {}, RopeTraditional, RopeTheta)); } -Transformer *llama27bChat(int VocabSize, float NormEps, float RopeTheta, - bool RopeTraditional) { - return new Transformer(4096, std::vector{11008}, VocabSize, 32, - std::vector{32}, std::vector{32}, NormEps, - {}, RopeTraditional, RopeTheta); +std::shared_ptr llama27bChat(int VocabSize, float NormEps, + float RopeTheta, + bool RopeTraditional) { + return std::make_shared(Transformer( + 4096, std::vector{11008}, VocabSize, 32, std::vector{32}, + std::vector{32}, NormEps, {}, RopeTraditional, RopeTheta)); } -Transformer *tinyLlama11BChatV10(int VocabSize, float NormEps, float RopeTheta, - bool RopeTraditional) { - return new Transformer(2048, std::vector{5632}, VocabSize, 22, - std::vector{32}, std::vector{4}, NormEps, {}, - RopeTraditional, RopeTheta); +std::shared_ptr tinyLlama11BChatV10(int VocabSize, float NormEps, + float RopeTheta, + bool RopeTraditional) { + return std::make_shared(Transformer( + 2048, std::vector{5632}, VocabSize, 22, std::vector{32}, + std::vector{4}, NormEps, {}, RopeTraditional, RopeTheta)); } } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/registry.h b/plugins/wasi_nn/MLX/model/registry.h index 871a27d3..4ed1d0c3 100644 --- a/plugins/wasi_nn/MLX/model/registry.h +++ b/plugins/wasi_nn/MLX/model/registry.h @@ -3,14 +3,19 @@ #include "transformer.h" namespace WasmEdge::Host::WASINN::MLX { -Transformer *llama38b(int VocabSize = 32000, float NormEps = 1e-5, - float RopeTheta = 10000.0, bool RopeTraditional = false); -Transformer *llama27bChat(int VocabSize = 32000, float NormEps = 1e-5, - float RopeTheta = 10000.0, - bool RopeTraditional = false); +std::shared_ptr llama38b(int VocabSize = 32000, + float NormEps = 1e-5, + float RopeTheta = 10000.0, + bool RopeTraditional = false); -Transformer *tinyLlama11BChatV10(int VocabSize = 32000, float NormEps = 1e-5, - float RopeTheta = 10000.0, - bool RopeTraditional = false); +std::shared_ptr llama27bChat(int VocabSize = 32000, + float NormEps = 1e-5, + float RopeTheta = 10000.0, + bool RopeTraditional = false); + +std::shared_ptr tinyLlama11BChatV10(int VocabSize = 32000, + float NormEps = 1e-5, + float RopeTheta = 10000.0, + bool RopeTraditional = false); } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/transformer.cpp b/plugins/wasi_nn/MLX/model/transformer.cpp index d914ca39..00cd0e70 100644 --- a/plugins/wasi_nn/MLX/model/transformer.cpp +++ b/plugins/wasi_nn/MLX/model/transformer.cpp @@ -5,6 +5,7 @@ #include "transformer.h" #include #include +#include #include #include #include @@ -12,6 +13,7 @@ #include namespace WasmEdge::Host::WASINN::MLX { + mx::array RMSNorm::forward(mx::array Input) { return mx::fast::rms_norm(Input, 1.0 + Parameters.at("weight"), Eps); } @@ -21,51 +23,58 @@ Attention::forward(mx::array Input, std::optional Mask, const auto &[B, L, D] = std::tie(Input.shape()[0], Input.shape()[1], Input.shape()[2]); mx::array Queries = - dynamic_cast(Submodules["q_proj"])->forward(Input); - mx::array Keys = - dynamic_cast(Submodules["k_proj"])->forward(Input); - mx::array Values = - dynamic_cast(Submodules["v_proj"])->forward(Input); + std::dynamic_pointer_cast(Submodules["q_proj"]) + ->forward(Input); + mx::array Keys = std::dynamic_pointer_cast(Submodules["k_proj"]) + ->forward(Input); + mx::array Values = std::dynamic_pointer_cast(Submodules["v_proj"]) + ->forward(Input); Queries = transpose(reshape(Queries, {B, L, NHeads, -1}), {0, 2, 1, 3}); Keys = transpose(reshape(Keys, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); Values = transpose(reshape(Values, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); if (NormQKProj) { - Queries = - dynamic_cast(Submodules["q_norm"])->forward(Queries); - Keys = dynamic_cast(Submodules["k_norm"])->forward(Keys); + Queries = std::dynamic_pointer_cast(Submodules["q_norm"]) + ->forward(Queries); + Keys = std::dynamic_pointer_cast(Submodules["k_norm"]) + ->forward(Keys); } if (KVCache) { const auto &[KeyCache, ValueCache] = *KVCache; - Queries = dynamic_cast(Submodules["rope"]) + Queries = std::dynamic_pointer_cast(Submodules["rope"]) ->forward(Queries, KeyCache.shape(2)); - Keys = dynamic_cast(Submodules["rope"]) + Keys = std::dynamic_pointer_cast(Submodules["rope"]) ->forward(Keys, KeyCache.shape(2)); Keys = mx::concatenate({KeyCache, Keys}, 2); Values = mx::concatenate({ValueCache, Values}, 2); } else { - Queries = dynamic_cast(Submodules["rope"])->forward(Queries); - Keys = dynamic_cast(Submodules["rope"])->forward(Keys); + Queries = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Queries); + Keys = + std::dynamic_pointer_cast(Submodules["rope"])->forward(Keys); } mx::array Output = mx::fast::scaled_dot_product_attention( Queries, Keys, Values, Scale, Mask); Output = reshape(transpose(Output, {0, 2, 1, 3}), {B, L, -1}); - return {dynamic_cast(Submodules["o_proj"])->forward(Output), + return {std::dynamic_pointer_cast(Submodules["o_proj"]) + ->forward(Output), {Keys, Values}}; } mx::array MLP::forward(mx::array Input) { if (Gemma) { - return dynamic_cast(Submodules["down_proj"]) - ->forward( - mlx::core::gelu(dynamic_cast(Submodules["gate_proj"]) - ->forward(Input)) * - dynamic_cast(Submodules["up_proj"])->forward(Input)); + return std::dynamic_pointer_cast(Submodules["down_proj"]) + ->forward(mlx::core::gelu(std::dynamic_pointer_cast( + Submodules["gate_proj"]) + ->forward(Input)) * + std::dynamic_pointer_cast(Submodules["up_proj"]) + ->forward(Input)); } - return dynamic_cast(Submodules["down_proj"]) - ->forward( - mlx::core::silu(dynamic_cast(Submodules["gate_proj"]) - ->forward(Input)) * - dynamic_cast(Submodules["up_proj"])->forward(Input)); + return std::dynamic_pointer_cast(Submodules["down_proj"]) + ->forward(mlx::core::silu(std::dynamic_pointer_cast( + Submodules["gate_proj"]) + ->forward(Input)) * + std::dynamic_pointer_cast(Submodules["up_proj"]) + ->forward(Input)); } std::tuple> TransformerBlock::forward( @@ -73,23 +82,27 @@ TransformerBlock::forward( std::optional> KVCachePar) { mx::array NormOutput = {}; if (!Gemma) { - NormOutput = dynamic_cast(Submodules["attention_norm"]) - ->forward(Input); + NormOutput = + std::dynamic_pointer_cast(Submodules["attention_norm"]) + ->forward(Input); } else { NormOutput = - dynamic_cast(Submodules["attention_norm"])->forward(Input); + std::dynamic_pointer_cast(Submodules["attention_norm"]) + ->forward(Input); } - auto [R, KVCache] = dynamic_cast(Submodules["attention"]) - ->forward(NormOutput, Mask, KVCachePar); + auto [R, KVCache] = + std::dynamic_pointer_cast(Submodules["attention"]) + ->forward(NormOutput, Mask, KVCachePar); auto H = Input + R; if (!Gemma) { - R = dynamic_cast(Submodules["mlp"]) - ->forward(dynamic_cast(Submodules["mlp_norm"]) - ->forward(H)); - } else { - R = dynamic_cast(Submodules["mlp"]) + R = std::dynamic_pointer_cast(Submodules["mlp"]) ->forward( - dynamic_cast(Submodules["mlp_norm"])->forward(H)); + std::dynamic_pointer_cast(Submodules["mlp_norm"]) + ->forward(H)); + } else { + R = std::dynamic_pointer_cast(Submodules["mlp"]) + ->forward(std::dynamic_pointer_cast(Submodules["mlp_norm"]) + ->forward(H)); } return {H + R, KVCache}; } @@ -100,15 +113,14 @@ Transformer::embed( std::optional>> KVCachePar, bool Norm) { mx::array H = - dynamic_cast(Submodules["token_embed"]) + std::dynamic_pointer_cast(Submodules["token_embed"]) ->forward(Input); if (Gemma) { H = H * (pow(Dim, 0.5)); } std::optional Mask; if (H.shape()[1] > 1) { - Mask = mlx::core::nn::MultiHeadAttention::createAdditiveCausalMask( - H.shape()[1]); + Mask = nn::MultiHeadAttention::createAdditiveCausalMask(H.shape()[1]); Mask = astype(*Mask, H.dtype()); } std::vector> KVCache; @@ -126,10 +138,12 @@ Transformer::embed( } if (Norm) { if (!Gemma) { - return {dynamic_cast(Submodules["norm"])->forward(H), + return {std::dynamic_pointer_cast(Submodules["norm"]) + ->forward(H), KVCache}; } - return {dynamic_cast(Submodules["norm"])->forward(H), KVCache}; + return {std::dynamic_pointer_cast(Submodules["norm"])->forward(H), + KVCache}; } return {H, KVCache}; } @@ -141,10 +155,10 @@ Transformer::forward( auto [X, KVCache] = embed(Input, KVCachePar, true); mx::array Out = {}; if (EmbedAsHead) { - Out = dynamic_cast(Submodules["token_embed"]) + Out = std::dynamic_pointer_cast(Submodules["token_embed"]) ->asLinear(X); } else { - Out = dynamic_cast(Submodules["head"])->forward(X); + Out = std::dynamic_pointer_cast(Submodules["head"])->forward(X); } return {Out, KVCache}; } diff --git a/plugins/wasi_nn/MLX/model/transformer.h b/plugins/wasi_nn/MLX/model/transformer.h index d0dac595..6b99e599 100644 --- a/plugins/wasi_nn/MLX/model/transformer.h +++ b/plugins/wasi_nn/MLX/model/transformer.h @@ -14,6 +14,7 @@ #include namespace WasmEdge::Host::WASINN::MLX { + namespace nn = mlx::core::nn; class RMSNorm : public nn::Module { @@ -47,14 +48,20 @@ class Attention : public nn::Module { HeadDim = Dim / NHeads; } Scale = pow(HeadDim, -0.5); - registerModule("q_proj", new nn::Linear(Dim, NHeads * HeadDim, false)); - registerModule("k_proj", new nn::Linear(Dim, NKVHeads * HeadDim, false)); - registerModule("v_proj", new nn::Linear(Dim, NKVHeads * HeadDim, false)); - registerModule("o_proj", new nn::Linear(NHeads * HeadDim, Dim, false)); + registerModule("q_proj", std::make_shared( + nn::Linear(Dim, NHeads * HeadDim, false))); + registerModule("k_proj", std::make_shared( + nn::Linear(Dim, NKVHeads * HeadDim, false))); + registerModule("v_proj", std::make_shared( + nn::Linear(Dim, NKVHeads * HeadDim, false))); + registerModule("o_proj", std::make_shared( + nn::Linear(NHeads * HeadDim, Dim, false))); if (NormQKProj) { - registerModule("q_norm", new nn::RMSNorm(HeadDim, AttentionNormEps)); - registerModule("k_norm", new nn::RMSNorm(HeadDim, AttentionNormEps)); + registerModule("q_norm", std::make_shared( + nn::RMSNorm(HeadDim, AttentionNormEps))); + registerModule("k_norm", std::make_shared( + nn::RMSNorm(HeadDim, AttentionNormEps))); } float RopeScale; if (RopeScaling && (*RopeScaling)["type"] == "linear") { @@ -63,8 +70,9 @@ class Attention : public nn::Module { RopeScale = 1; } - registerModule( - "rope", new nn::RoPE(HeadDim, RopeTraditional, RopeTheta, RopeScale)); + registerModule("rope", + std::make_shared(nn::RoPE(HeadDim, RopeTraditional, + RopeTheta, RopeScale))); } std::tuple> forward(mx::array Input, std::optional Mask = {}, @@ -75,9 +83,12 @@ class MLP : public nn::Module { public: MLP(int Dim, int HiddenDim, bool Gemma = false) : Gemma(Gemma) { - registerModule("gate_proj", new nn::Linear(Dim, HiddenDim, false)); - registerModule("down_proj", new nn::Linear(HiddenDim, Dim, false)); - registerModule("up_proj", new nn::Linear(Dim, HiddenDim, false)); + registerModule("gate_proj", std::make_shared( + nn::Linear(Dim, HiddenDim, false))); + registerModule("down_proj", std::make_shared( + nn::Linear(HiddenDim, Dim, false))); + registerModule("up_proj", std::make_shared( + nn::Linear(Dim, HiddenDim, false))); } mx::array forward(mx::array Input); }; @@ -94,16 +105,20 @@ class TransformerBlock : public nn::Module { bool Gemma = false) : Gemma(Gemma) { registerModule("attention", - new Attention(Dim, NHeads, NKVHeads, HeadDim, - RopeTraditional, RopeTheta, RopeScaling, - NormQKProj, AttentionNormEps)); - registerModule("mlp", new MLP(Dim, HiddenDim, Gemma)); + std::make_shared(Attention( + Dim, NHeads, NKVHeads, HeadDim, RopeTraditional, + RopeTheta, RopeScaling, NormQKProj, AttentionNormEps))); + registerModule("mlp", std::make_shared(MLP(Dim, HiddenDim, Gemma))); if (!Gemma) { - registerModule("attention_norm", new nn::RMSNorm(Dim, NormEps)); - registerModule("mlp_norm", new nn::RMSNorm(Dim, NormEps)); + registerModule("attention_norm", + std::make_shared(nn::RMSNorm(Dim, NormEps))); + registerModule("mlp_norm", + std::make_shared(nn::RMSNorm(Dim, NormEps))); } else { - registerModule("attention_norm", new RMSNorm(Dim, NormEps)); - registerModule("mlp_norm", new RMSNorm(Dim, NormEps)); + registerModule("attention_norm", + std::make_shared(RMSNorm(Dim, NormEps))); + registerModule("mlp_norm", + std::make_shared(RMSNorm(Dim, NormEps))); } } std::tuple> @@ -115,7 +130,7 @@ class Transformer : public nn::Module { std::optional> HiddenDim; bool Gemma; bool EmbedAsHead; - std::vector Layers{}; + std::vector> Layers{}; public: Transformer( @@ -131,14 +146,16 @@ class Transformer : public nn::Module { : Dim(Dim), HiddenDim(HiddenDim), Gemma(Gemma), EmbedAsHead(EmbedAsHeadPar) { if (VocabSize <= 0) { - spdlog::error("VocabSize must be greater than 0."); + spdlog::error( + "[WASI-NN] MLX backend: VocabSize must be greater than 0."sv); assumingUnreachable(); } EmbedAsHead = Gemma ? true : EmbedAsHead; if (!NKVHeads) { NKVHeads = NHeads; } - registerModule("token_embed", new nn::Embedding(VocabSize, Dim)); + registerModule("token_embed", std::make_shared( + nn::Embedding(VocabSize, Dim))); if (HiddenDim->size() == 1) { while (static_cast(HiddenDim->size()) < NLayers) { HiddenDim->emplace_back((*HiddenDim)[0]); @@ -157,25 +174,27 @@ class Transformer : public nn::Module { Layers.reserve(NLayers); for (int Idx = 0; Idx < NLayers; Idx++) { if (RopeScaling) { - Layers.push_back(new TransformerBlock( + Layers.push_back(std::make_shared(TransformerBlock( Dim, (*NHeads)[Idx], (*NKVHeads)[Idx], (*HiddenDim)[Idx], NormEps, HeadDim, RopeTraditional, RopeTheta, (*RopeScaling)[Idx], - NormQKProj, AttentionNormEps, Gemma)); + NormQKProj, AttentionNormEps, Gemma))); } else { - Layers.push_back(new TransformerBlock( + Layers.push_back(std::make_shared(TransformerBlock( Dim, (*NHeads)[Idx], (*NKVHeads)[Idx], (*HiddenDim)[Idx], NormEps, HeadDim, RopeTraditional, RopeTheta, {}, NormQKProj, - AttentionNormEps, Gemma)); + AttentionNormEps, Gemma))); } } registerLayer("layers", Layers); if (!Gemma) { - registerModule("norm", new nn::RMSNorm(Dim, NormEps)); + registerModule("norm", + std::make_shared(nn::RMSNorm(Dim, NormEps))); } else { - registerModule("norm", new RMSNorm(Dim, NormEps)); + registerModule("norm", std::make_shared(RMSNorm(Dim, NormEps))); } if (!EmbedAsHead) { - registerModule("head", new nn::Linear(Dim, VocabSize, false)); + registerModule("head", std::make_shared( + nn::Linear(Dim, VocabSize, false))); } } std::tuple &Weights, if (endsWith(Path, ".safetensors")) { mx::save_safetensors(Path, Weights, {{"format", "mlx"}}); } else { - spdlog::error("Unsupported file format"); + spdlog::error("[WASI-NN] MLX backend: Unsupported file format"sv); assumingUnreachable(); } } @@ -45,7 +45,7 @@ void saveWeights(const mx::array &Weights, const std::string &Path) { if (endsWith(Path, ".npz")) { mx::save(Path, Weights); } else { - spdlog::error("Unsupported file format"); + spdlog::error("[WASI-NN] MLX backend: Unsupported file format"sv); assumingUnreachable(); } } diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp index b4de8623..c5cf045d 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/mlx.cpp @@ -144,8 +144,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, // Load tokenizer. if (!TokenizerPath.empty()) { GraphRef.Tok = - tokenizers::Tokenizer::FromBlobJSON(loadBytesFromFile(TokenizerPath)) - .release(); + tokenizers::Tokenizer::FromBlobJSON(loadBytesFromFile(TokenizerPath)); } else { spdlog::error("[WASI-NN] MLX backend: Tokenizer path not found."sv); Env.NNGraph.pop_back(); diff --git a/plugins/wasi_nn/mlx.h b/plugins/wasi_nn/mlx.h index 0c680867..af3e152e 100644 --- a/plugins/wasi_nn/mlx.h +++ b/plugins/wasi_nn/mlx.h @@ -2,6 +2,7 @@ #include "plugin/plugin.h" #include "types.h" +#include #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX #include "MLX/model/transformer.h" @@ -18,17 +19,9 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::MLX { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX struct Graph { - ~Graph() noexcept { - if (Model != nullptr) { - delete Model; - } - if (Tok != nullptr) { - delete Tok; - } - } std::string ModelType; - tokenizers::Tokenizer *Tok = nullptr; - Transformer *Model; + std::unique_ptr Tok = nullptr; + std::shared_ptr Model; double Temp = 0.0; bool EnableDebugLog = false; bool IsQuantized = false; From 5714f26cebed522867cd8e3b268f0325dae007ae Mon Sep 17 00:00:00 2001 From: grorge Date: Tue, 24 Sep 2024 16:50:31 +0800 Subject: [PATCH 452/623] [WASI-NN] mlx: handle load tokenizer failed Signed-off-by: grorge --- plugins/wasi_nn/mlx.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/mlx.cpp index c5cf045d..5e8406f4 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/mlx.cpp @@ -13,8 +13,8 @@ namespace WasmEdge::Host::WASINN::MLX { std::string loadBytesFromFile(const std::string &Path) { std::ifstream Fs(Path, std::ios::in | std::ios::binary); if (Fs.fail()) { - std::cerr << "Cannot open " << Path << std::endl; - exit(1); + spdlog::error("[WASI-NN] MLX backend: Cannot open {}."sv, Path); + return ""; } std::string Data; Fs.seekg(0, std::ios::end); @@ -143,8 +143,13 @@ Expect load(WASINN::WasiNNEnvironment &Env, // Load tokenizer. if (!TokenizerPath.empty()) { - GraphRef.Tok = - tokenizers::Tokenizer::FromBlobJSON(loadBytesFromFile(TokenizerPath)); + auto Bytes = loadBytesFromFile(TokenizerPath); + if (Bytes.empty()) { + spdlog::error("[WASI-NN] MLX backend: Load tokenizer failed."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + GraphRef.Tok = tokenizers::Tokenizer::FromBlobJSON(Bytes); } else { spdlog::error("[WASI-NN] MLX backend: Tokenizer path not found."sv); Env.NNGraph.pop_back(); From d893ddd47d08aa4f7c8cebbb1fa12552772942ec Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Thu, 26 Sep 2024 17:47:26 +0800 Subject: [PATCH 453/623] [CMake] Apply the WasmEdge component in the cpack. (#3690) Signed-off-by: YiYing He --- plugins/wasi_crypto/CMakeLists.txt | 1 + plugins/wasi_http/CMakeLists.txt | 1 + plugins/wasi_nn/CMakeLists.txt | 1 + plugins/wasi_poll/CMakeLists.txt | 1 + plugins/wasmedge_ffmpeg/CMakeLists.txt | 1 + plugins/wasmedge_image/CMakeLists.txt | 1 + plugins/wasmedge_llmc/CMakeLists.txt | 1 + plugins/wasmedge_ocr/CMakeLists.txt | 1 + plugins/wasmedge_opencvmini/CMakeLists.txt | 1 + plugins/wasmedge_process/CMakeLists.txt | 1 + plugins/wasmedge_stablediffusion/CMakeLists.txt | 1 + plugins/wasmedge_tensorflow/CMakeLists.txt | 1 + plugins/wasmedge_tensorflowlite/CMakeLists.txt | 1 + plugins/wasmedge_zlib/CMakeLists.txt | 1 + 14 files changed, 14 insertions(+) diff --git a/plugins/wasi_crypto/CMakeLists.txt b/plugins/wasi_crypto/CMakeLists.txt index cc8c8dbe..48ae00c9 100644 --- a/plugins/wasi_crypto/CMakeLists.txt +++ b/plugins/wasi_crypto/CMakeLists.txt @@ -82,4 +82,5 @@ endif() install( TARGETS wasmedgePluginWasiCrypto DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) diff --git a/plugins/wasi_http/CMakeLists.txt b/plugins/wasi_http/CMakeLists.txt index 3c7321fd..45bc0030 100644 --- a/plugins/wasi_http/CMakeLists.txt +++ b/plugins/wasi_http/CMakeLists.txt @@ -45,4 +45,5 @@ endif() install( TARGETS wasmedgePluginWasiHttp DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 789e4b3b..7c74e8de 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -326,4 +326,5 @@ wasmedge_setup_wasinn_target(wasmedgePluginWasiNN) install( TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) diff --git a/plugins/wasi_poll/CMakeLists.txt b/plugins/wasi_poll/CMakeLists.txt index 0ffa67c6..9c641135 100644 --- a/plugins/wasi_poll/CMakeLists.txt +++ b/plugins/wasi_poll/CMakeLists.txt @@ -39,4 +39,5 @@ endif() install( TARGETS wasmedgePluginWasiPoll DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_ffmpeg/CMakeLists.txt b/plugins/wasmedge_ffmpeg/CMakeLists.txt index 47e72cdd..0a1ff4a8 100644 --- a/plugins/wasmedge_ffmpeg/CMakeLists.txt +++ b/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -88,4 +88,5 @@ endif() install( TARGETS wasmedgePluginWasmEdgeFFmpeg DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_image/CMakeLists.txt b/plugins/wasmedge_image/CMakeLists.txt index 8360a40d..800c3a94 100644 --- a/plugins/wasmedge_image/CMakeLists.txt +++ b/plugins/wasmedge_image/CMakeLists.txt @@ -143,4 +143,5 @@ endif() install( TARGETS wasmedgePluginWasmEdgeImage DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_llmc/CMakeLists.txt b/plugins/wasmedge_llmc/CMakeLists.txt index 7e9e4c59..da37ed2f 100644 --- a/plugins/wasmedge_llmc/CMakeLists.txt +++ b/plugins/wasmedge_llmc/CMakeLists.txt @@ -62,4 +62,5 @@ endif() install( TARGETS wasmedgePluginWasmEdgeLLMC DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_ocr/CMakeLists.txt b/plugins/wasmedge_ocr/CMakeLists.txt index e3139149..a540eb06 100644 --- a/plugins/wasmedge_ocr/CMakeLists.txt +++ b/plugins/wasmedge_ocr/CMakeLists.txt @@ -34,6 +34,7 @@ endif() install( TARGETS wasmedgePluginWasmEdgeOCR DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) message(STATUS "WASI-OCR: Build Tesseract backend for WASI-OCR") diff --git a/plugins/wasmedge_opencvmini/CMakeLists.txt b/plugins/wasmedge_opencvmini/CMakeLists.txt index eaec2009..ed10e816 100644 --- a/plugins/wasmedge_opencvmini/CMakeLists.txt +++ b/plugins/wasmedge_opencvmini/CMakeLists.txt @@ -39,4 +39,5 @@ endif() install( TARGETS wasmedgePluginWasmEdgeOpenCVMini DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_process/CMakeLists.txt b/plugins/wasmedge_process/CMakeLists.txt index 819a64d7..28a4bcce 100644 --- a/plugins/wasmedge_process/CMakeLists.txt +++ b/plugins/wasmedge_process/CMakeLists.txt @@ -34,4 +34,5 @@ endif() install( TARGETS wasmedgePluginWasmEdgeProcess DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 9936d3f6..44827755 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -121,4 +121,5 @@ endif() install( TARGETS wasmedgePluginWasmEdgeStableDiffusion DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_tensorflow/CMakeLists.txt b/plugins/wasmedge_tensorflow/CMakeLists.txt index 96147ee7..ccfe25ed 100644 --- a/plugins/wasmedge_tensorflow/CMakeLists.txt +++ b/plugins/wasmedge_tensorflow/CMakeLists.txt @@ -37,4 +37,5 @@ wasmedge_setup_tf_target(wasmedgePluginWasmEdgeTensorflow) install( TARGETS wasmedgePluginWasmEdgeTensorflow DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_tensorflowlite/CMakeLists.txt b/plugins/wasmedge_tensorflowlite/CMakeLists.txt index 30695009..f8ee177d 100644 --- a/plugins/wasmedge_tensorflowlite/CMakeLists.txt +++ b/plugins/wasmedge_tensorflowlite/CMakeLists.txt @@ -37,4 +37,5 @@ wasmedge_setup_tflite_target(wasmedgePluginWasmEdgeTensorflowLite) install( TARGETS wasmedgePluginWasmEdgeTensorflowLite DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) diff --git a/plugins/wasmedge_zlib/CMakeLists.txt b/plugins/wasmedge_zlib/CMakeLists.txt index d75e577c..56745021 100644 --- a/plugins/wasmedge_zlib/CMakeLists.txt +++ b/plugins/wasmedge_zlib/CMakeLists.txt @@ -40,4 +40,5 @@ endif() install( TARGETS wasmedgePluginWasmEdgeZlib DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge + COMPONENT WasmEdge ) From 32462f1ec446bc129cd0d114adde99aaea369f98 Mon Sep 17 00:00:00 2001 From: PeterD1524 <53310459+PeterD1524@users.noreply.github.com> Date: Fri, 27 Sep 2024 05:48:05 +0800 Subject: [PATCH 454/623] [WASI-NN] piper: fix arguments for target linking and including in piper patch (#3798) The variables FMT_LINK_LIBRARIES, FMT_LINK_DIRECTORIES, FMT_INCLUDE_DIRECTORIES, SPDLOG_LINK_LIBRARIES, SPDLOG_LINK_DIRECTORIES, SPDLOG_INCLUDE_DIRECTORIES are semicolon-separated lists. They are intended to be used as unquoted arguments for target_link_libraries, target_link_directories, target_include_directories. They were incorrectly quoted. unquoted argument: https://cmake.org/cmake/help/latest/manual/cmake-language.7.html#unquoted-argument Signed-off-by: PeterD1524 --- plugins/wasi_nn/piper.patch | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/plugins/wasi_nn/piper.patch b/plugins/wasi_nn/piper.patch index c4ba79e2..bca50ca7 100644 --- a/plugins/wasi_nn/piper.patch +++ b/plugins/wasi_nn/piper.patch @@ -1,5 +1,5 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index f96ec44..1e84722 100644 +index f96ec44..6a2d6c4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.13) @@ -139,8 +139,8 @@ index f96ec44..1e84722 100644 - fmt - spdlog +target_link_libraries(piper PRIVATE -+ "${FMT_LINK_LIBRARIES}" -+ "${SPDLOG_LINK_LIBRARIES}" ++ ${FMT_LINK_LIBRARIES} ++ ${SPDLOG_LINK_LIBRARIES} espeak-ng - piper_phonemize onnxruntime @@ -159,8 +159,8 @@ index f96ec44..1e84722 100644 - ${SPDLOG_DIR}/include - ${PIPER_PHONEMIZE_DIR}/include +target_link_directories(piper PRIVATE -+ "${FMT_LINK_DIRECTORIES}" -+ "${SPDLOG_LINK_DIRECTORIES}" ++ ${FMT_LINK_DIRECTORIES} ++ ${SPDLOG_LINK_DIRECTORIES} ) -target_compile_definitions(piper PUBLIC _PIPER_VERSION=${piper_version}) @@ -174,8 +174,8 @@ index f96ec44..1e84722 100644 - NAME test_piper - COMMAND test_piper "${CMAKE_SOURCE_DIR}/etc/test_voice.onnx" "${PIPER_PHONEMIZE_DIR}/share/espeak-ng-data" "${CMAKE_CURRENT_BINARY_DIR}/test.wav" +target_include_directories(piper PRIVATE -+ "${FMT_INCLUDE_DIRECTORIES}" -+ "${SPDLOG_INCLUDE_DIRECTORIES}" ++ ${FMT_INCLUDE_DIRECTORIES} ++ ${SPDLOG_INCLUDE_DIRECTORIES} + INTERFACE "${PIPER_INTERFACE_INCLUDE_DIRECTORY}" ) @@ -188,8 +188,8 @@ index f96ec44..1e84722 100644 - ${FMT_DIR}/include - ${SPDLOG_DIR}/include - ${PIPER_PHONEMIZE_DIR}/include -+ "${FMT_INCLUDE_DIRECTORIES}" -+ "${SPDLOG_INCLUDE_DIRECTORIES}" ++ ${FMT_INCLUDE_DIRECTORIES} ++ ${SPDLOG_INCLUDE_DIRECTORIES} ) target_link_directories( @@ -197,15 +197,15 @@ index f96ec44..1e84722 100644 - ${FMT_DIR}/lib - ${SPDLOG_DIR}/lib - ${PIPER_PHONEMIZE_DIR}/lib -+ "${FMT_LINK_DIRECTORIES}" -+ "${SPDLOG_LINK_DIRECTORIES}" ++ ${FMT_LINK_DIRECTORIES} ++ ${SPDLOG_LINK_DIRECTORIES} ) target_link_libraries(test_piper PUBLIC - fmt - spdlog -+ "${FMT_LINK_LIBRARIES}" -+ "${SPDLOG_LINK_LIBRARIES}" ++ ${FMT_LINK_LIBRARIES} ++ ${SPDLOG_LINK_LIBRARIES} espeak-ng piper_phonemize onnxruntime From 7347b377692f5726710fae4e1df486c780cd744d Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Sun, 29 Sep 2024 20:39:48 +0800 Subject: [PATCH 455/623] [Plugin] Stable Diffusion: update SD version to 14206fd (#3801) Signed-off-by: grorge --- plugins/wasmedge_stablediffusion/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 44827755..4bcff1ea 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -24,7 +24,7 @@ message(STATUS "Downloading stable diffusion source") FetchContent_Declare( stable-diffusion GIT_REPOSITORY https://github.com/leejet/stable-diffusion.cpp.git - GIT_TAG master-e71ddce + GIT_TAG 14206fd48832ab600d9db75f15acb5062ae2c296 GIT_SHALLOW TRUE ) FetchContent_MakeAvailable(stable-diffusion) From bfaf4f8b72bc3f2522f5af83539f5a3f160727b8 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 16 Sep 2024 03:40:16 +0800 Subject: [PATCH 456/623] [WASI-NN] Whisper: support several options. 1. threads 2. processors 3. max-context 4. max-len 5. offset-t 6. duration 7. split-on-word 8. temperature Signed-off-by: YiYing He --- plugins/wasi_nn/wasinnfunc.cpp | 4 +- plugins/wasi_nn/whispercpp.cpp | 168 ++++++++++++++++++++++++++++++++- plugins/wasi_nn/whispercpp.h | 13 +++ 3 files changed, 181 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index d9fb5e69..7ecc138a 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -631,13 +631,15 @@ Expect WasiNNUnload::bodyImpl(const Runtime::CallingFrame &Frame, switch (Env.NNGraph[GraphId].getBackend()) { case WASINN::Backend::GGML: return WASINN::GGML::unload(Env, GraphId); + case WASINN::Backend::Whisper: + return WASINN::Whisper::unload(Env, GraphId); case WASINN::Backend::NeuralSpeed: return WASINN::NeuralSpeed::unload(Env, GraphId); case WASINN::Backend::ChatTTS: return WASINN::ChatTTS::unload(Env, GraphId); default: spdlog::error( - "[WASI-NN] unlaod: Only GGML, Neural speed, and ChatTTS backend supports unload."sv); + "[WASI-NN] unlaod: Only GGML, Whisper, Neural speed, and ChatTTS backend supports unload."sv); return WASINN::ErrNo::InvalidArgument; } } diff --git a/plugins/wasi_nn/whispercpp.cpp b/plugins/wasi_nn/whispercpp.cpp index 3cc1452d..79bc23b2 100644 --- a/plugins/wasi_nn/whispercpp.cpp +++ b/plugins/wasi_nn/whispercpp.cpp @@ -141,8 +141,14 @@ void WhisperOutputSegmentCallback(struct whisper_context *WhisperCtx, void setWhisperParams(Context &CxtRef) noexcept { auto &WParam = CxtRef.WhisperParams; auto &ConfigRef = CxtRef.WhisperConfig; + WParam.n_threads = ConfigRef.ThreadsNum; + WParam.n_max_text_ctx = ConfigRef.MaxTokenContext; + WParam.offset_ms = ConfigRef.TimeOffsetMS; + WParam.duration_ms = ConfigRef.DurationMS; WParam.print_progress = false; WParam.thold_pt = ConfigRef.WordThreshold; + WParam.max_len = ConfigRef.MaxSegmentLength; + WParam.split_on_word = ConfigRef.SplitOnWord; WParam.translate = ConfigRef.Translate; WParam.language = ConfigRef.SpokenLanguage.c_str(); WParam.detect_language = ConfigRef.DetectLanguage; @@ -154,6 +160,34 @@ void setWhisperParams(Context &CxtRef) noexcept { WParam.grammar_penalty = ConfigRef.GrammarPenalty; WParam.new_segment_callback = WhisperOutputSegmentCallback; WParam.new_segment_callback_user_data = &CxtRef; + + if (ConfigRef.EnableDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: threads: {}", + ConfigRef.ThreadsNum); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: processors: {}", + ConfigRef.ProcessorsNum); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: max-context: {}", + ConfigRef.MaxTokenContext); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: offset-t: {}", + ConfigRef.TimeOffsetMS); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: duration: {}", + ConfigRef.DurationMS); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: max-len: {}", + ConfigRef.MaxSegmentLength); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: split-on-word : {}", + ConfigRef.SplitOnWord); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: translate: {}", + ConfigRef.Translate); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: language: \"{}\"", + ConfigRef.SpokenLanguage); + spdlog::info( + "[WASI-NN][Debug] Whisper backend: Config: detect-language: {}", + ConfigRef.DetectLanguage); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: temperature: {}", + ConfigRef.Temperature); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: prompt: \"{}\"", + ConfigRef.InitialPrompt); + } } Expect parseMetadata(Config &ConfigRef, @@ -166,14 +200,30 @@ Expect parseMetadata(Config &ConfigRef, return ErrNo::InvalidEncoding; } + auto PrintParsedOption = [=](std::string_view Name, const auto &Val) { + if (ConfigRef.EnableDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: Parsed metadata -- {}:{}"sv, Name, + Val); + } + }; + // Get metadata from the json. // Currently supported metadata: // Plugin parameters (used by this plugin): // enable-log: bool // enable-debug-log: bool + // threads: uint32_t + // processors: uint32_t + // offset-t: uint32_t + // duration: uint32_t + // max-context: uint32_t + // max-len: uint32_t + // split-on-word: bool // translate: bool // language: string // detect-language: bool + // temperature: float // prompt: string // The plugin parameters. @@ -196,6 +246,71 @@ Expect parseMetadata(Config &ConfigRef, return ErrNo::InvalidArgument; } } + if (Doc.at_key("threads").error() == simdjson::SUCCESS) { + auto Err = Doc["threads"].get().get(ConfigRef.ThreadsNum); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the threads option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("threads"sv, ConfigRef.ThreadsNum); + } + if (Doc.at_key("processors").error() == simdjson::SUCCESS) { + auto Err = Doc["processors"].get().get(ConfigRef.ProcessorsNum); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the processors option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("processors"sv, ConfigRef.ProcessorsNum); + } + if (Doc.at_key("offset-t").error() == simdjson::SUCCESS) { + auto Err = Doc["offset-t"].get().get(ConfigRef.TimeOffsetMS); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the offset-t option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("offset-t"sv, ConfigRef.TimeOffsetMS); + } + if (Doc.at_key("duration").error() == simdjson::SUCCESS) { + auto Err = Doc["duration"].get().get(ConfigRef.DurationMS); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the duration option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("duration"sv, ConfigRef.DurationMS); + } + if (Doc.at_key("max-context").error() == simdjson::SUCCESS) { + auto Err = + Doc["max-context"].get().get(ConfigRef.MaxTokenContext); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the max-context option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("max-context"sv, ConfigRef.MaxTokenContext); + } + if (Doc.at_key("max-len").error() == simdjson::SUCCESS) { + auto Err = Doc["max-len"].get().get(ConfigRef.MaxSegmentLength); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the max-len option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("max-len"sv, ConfigRef.MaxSegmentLength); + } + if (Doc.at_key("split-on-word").error() == simdjson::SUCCESS) { + auto Err = Doc["split-on-word"].get().get(ConfigRef.SplitOnWord); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the split-on-word " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("split-on-word"sv, ConfigRef.SplitOnWord); + } if (Doc.at_key("translate").error() == simdjson::SUCCESS) { auto Err = Doc["translate"].get().get(ConfigRef.Translate); if (Err) { @@ -204,6 +319,7 @@ Expect parseMetadata(Config &ConfigRef, "option."sv); return ErrNo::InvalidArgument; } + PrintParsedOption("translate"sv, ConfigRef.Translate); } if (Doc.at_key("language").error() == simdjson::SUCCESS) { std::string_view Language; @@ -215,6 +331,7 @@ Expect parseMetadata(Config &ConfigRef, return ErrNo::InvalidArgument; } ConfigRef.SpokenLanguage = Language; + PrintParsedOption("language"sv, ConfigRef.SpokenLanguage); } if (Doc.at_key("detect-language").error() == simdjson::SUCCESS) { auto Err = Doc["detect-language"].get().get(ConfigRef.DetectLanguage); @@ -224,6 +341,18 @@ Expect parseMetadata(Config &ConfigRef, "option."sv); return ErrNo::InvalidArgument; } + PrintParsedOption("detect-language"sv, ConfigRef.DetectLanguage); + } + if (Doc.at_key("temperature").error() == simdjson::SUCCESS) { + double Temperature; + auto Err = Doc["temperature"].get().get(Temperature); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the temperature option."sv); + return ErrNo::InvalidArgument; + } + ConfigRef.Temperature = static_cast(Temperature); + PrintParsedOption("temperature"sv, ConfigRef.Temperature); } if (Doc.at_key("prompt").error() == simdjson::SUCCESS) { std::string_view Prompt; @@ -234,6 +363,7 @@ Expect parseMetadata(Config &ConfigRef, return ErrNo::InvalidArgument; } ConfigRef.InitialPrompt = Prompt; + PrintParsedOption("prompt"sv, ConfigRef.InitialPrompt); } return ErrNo::Success; } @@ -382,6 +512,10 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, spdlog::info( "[WASI-NN][Debug] Whisper backend: found Metadata, processing"sv); } + // Set the whisper config of this context as the graph default first. + // This will reset the config and inherit settings from the graph metadata. + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + CxtRef.WhisperConfig = GraphRef.WhisperConfig; const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); auto Res = parseMetadata(CxtRef.WhisperConfig, Metadata); @@ -389,7 +523,6 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, spdlog::error("[WASI-NN] Whisper backend: Failed to parse metadata."sv); return Res; } - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); Res = handleTranslationConfig(GraphRef.WhisperCtx, CxtRef.WhisperConfig); if (Res != ErrNo::Success) { return Res; @@ -467,8 +600,9 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } CxtRef.Outputs.clear(); - if (whisper_full(GraphRef.WhisperCtx, CxtRef.WhisperParams, - CxtRef.InputPCM.data(), CxtRef.InputPCM.size()) != 0) { + if (whisper_full_parallel(GraphRef.WhisperCtx, CxtRef.WhisperParams, + CxtRef.InputPCM.data(), CxtRef.InputPCM.size(), + CxtRef.WhisperConfig.ProcessorsNum) != 0) { spdlog::error( "[WASI-NN] Whisper backend: Error: failed to process audio."sv); return ErrNo::RuntimeError; @@ -480,6 +614,31 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { return ErrNo::Success; } +Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + const bool IsDebugLog = GraphRef.WhisperConfig.EnableDebugLog; + if (IsDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: unload"sv); + } + if (GraphRef.WhisperCtx != nullptr) { + if (IsDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: unload: free whisper context"sv); + } + whisper_free(GraphRef.WhisperCtx); + GraphRef.WhisperCtx = nullptr; + if (IsDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: unload: free whisper context...Done"sv); + } + } + Env.NNGraph.erase(Env.NNGraph.begin() + GraphId); + if (IsDebugLog) { + spdlog::info("[WASI-NN][Debug] Whisper backend: unload...Done"sv); + } + return ErrNo::Success; +} + #else namespace { @@ -508,6 +667,9 @@ Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, Expect compute(WasiNNEnvironment &, uint32_t) noexcept { return reportBackendNotSupported(); } +Expect unload(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} #endif } // namespace WasmEdge::Host::WASINN::Whisper diff --git a/plugins/wasi_nn/whispercpp.h b/plugins/wasi_nn/whispercpp.h index 70f727a3..6958f9b3 100644 --- a/plugins/wasi_nn/whispercpp.h +++ b/plugins/wasi_nn/whispercpp.h @@ -9,7 +9,9 @@ #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER #include +#include #include +#include #include #endif @@ -22,10 +24,19 @@ namespace WasmEdge::Host::WASINN::Whisper { struct Config { // Whisper parameters: + uint64_t ThreadsNum = + std::min(static_cast(4), + static_cast(std::thread::hardware_concurrency())); + uint64_t ProcessorsNum = 1; + uint64_t MaxTokenContext = 16384; + uint64_t TimeOffsetMS = 0; + uint64_t DurationMS = 0; + uint64_t MaxSegmentLength = 0; bool EnableLog = false; bool EnableDebugLog = false; bool Translate = false; bool DetectLanguage = false; + bool SplitOnWord = false; std::string SpokenLanguage; std::string InitialPrompt; // Sampling parameters: @@ -86,4 +97,6 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, uint32_t &BytesWritten) noexcept; Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; } // namespace WasmEdge::Host::WASINN::Whisper From ad5f2e70f8c3f39282545cf14c998ba63dbc65d4 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Fri, 27 Sep 2024 03:32:46 +0800 Subject: [PATCH 457/623] [WASI-NN] Whisper: update to the latest and support Metal on MacOS. Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 7c74e8de..35c64c5a 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -173,20 +173,24 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) elseif(BACKEND STREQUAL "whisper") wasmedge_setup_simdjson() - if(APPLE AND CMAKE_SYSTEM_VERSION VERSION_LESS 23) - # `cblas_sgemm()` introduced in macOS 13.3. - set(WHISPER_NO_ACCELERATE ON CACHE INTERNAL "Stable diffusion turn off accelerate") - endif() set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "Whisper not build shared") + set(GGML_OPENMP OFF) + set(GGML_ACCELERATE OFF) + set(GGML_BLAS OFF) include(FetchContent) FetchContent_Declare( whisper GIT_REPOSITORY https://github.com/ggerganov/whisper.cpp.git - GIT_TAG v1.6.2 + GIT_TAG 69339af2d104802f3f201fd419163defba52890e GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(whisper) set_property(TARGET whisper PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET ggml PROPERTY POSITION_INDEPENDENT_CODE ON) + target_include_directories(wasmedgePluginWasiNN PRIVATE + ${whisper_SOURCE_DIR} + ${whisper_SOURCE_DIR}/ggml/include + ) target_link_libraries(wasmedgePluginWasiNN PRIVATE whisper simdjson::simdjson From f1ccff89979e5e1c0af10d5bef36620f3a0b7290 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 4 Jun 2024 18:16:14 +0800 Subject: [PATCH 458/623] [Misc] Add fast string hash for speed Signed-off-by: Shen-Ta Hsieh --- plugins/wasmedge_process/processenv.h | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/plugins/wasmedge_process/processenv.h b/plugins/wasmedge_process/processenv.h index 9b3626dd..07e9aa0b 100644 --- a/plugins/wasmedge_process/processenv.h +++ b/plugins/wasmedge_process/processenv.h @@ -3,6 +3,7 @@ #pragma once +#include "common/hash.h" #include "plugin/plugin.h" #include "po/argument_parser.h" #include "po/list.h" @@ -37,9 +38,12 @@ class WasmEdgeProcessEnvironment { std::vector StdErr; /// Configurations - uint32_t TimeOut = DEFAULT_TIMEOUT; /// Timeout in milliseconds. - std::unordered_set AllowedCmd; /// Programs in white list. - bool AllowedAll; /// Flag to allow all programs. + /// Timeout in milliseconds. + uint32_t TimeOut = DEFAULT_TIMEOUT; + /// Programs in white list. + std::unordered_set AllowedCmd; + /// Flag to allow all programs. + bool AllowedAll; /// Results uint32_t ExitCode = 0; From 912ceeb24b48a06457cad2f4b1c16caa4b3c529a Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Wed, 2 Oct 2024 00:11:38 +0800 Subject: [PATCH 459/623] [WASI-NN] mlx: Lint codes and refine cmake. (#3808) Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 49 +++++++++++++------ plugins/wasi_nn/MLX/mlx/activations.cpp | 9 ++++ plugins/wasi_nn/MLX/mlx/activations.h | 9 +++- plugins/wasi_nn/MLX/mlx/base.cpp | 12 +++-- plugins/wasi_nn/MLX/mlx/base.h | 28 +++++++++-- plugins/wasi_nn/MLX/mlx/embedding.cpp | 11 +++-- plugins/wasi_nn/MLX/mlx/embedding.h | 14 +++++- plugins/wasi_nn/MLX/mlx/linear.cpp | 6 ++- plugins/wasi_nn/MLX/mlx/linear.h | 12 ++++- plugins/wasi_nn/MLX/mlx/normalization.cpp | 5 ++ plugins/wasi_nn/MLX/mlx/normalization.h | 6 ++- .../wasi_nn/MLX/mlx/positional_encoding.cpp | 5 +- plugins/wasi_nn/MLX/mlx/positional_encoding.h | 8 ++- plugins/wasi_nn/MLX/mlx/quantized.cpp | 12 ++++- plugins/wasi_nn/MLX/mlx/quantized.h | 14 +++++- plugins/wasi_nn/MLX/mlx/transformer.cpp | 7 ++- plugins/wasi_nn/MLX/mlx/transformer.h | 10 +++- plugins/wasi_nn/MLX/model/converter.cpp | 7 ++- plugins/wasi_nn/MLX/model/converter.h | 10 +++- plugins/wasi_nn/MLX/model/registry.cpp | 4 ++ plugins/wasi_nn/MLX/model/registry.h | 4 ++ plugins/wasi_nn/MLX/model/transformer.cpp | 19 +++++-- plugins/wasi_nn/MLX/model/transformer.h | 22 ++++++++- plugins/wasi_nn/MLX/model/utils.cpp | 10 ++++ plugins/wasi_nn/MLX/model/utils.h | 11 +++++ plugins/wasi_nn/MLX/prompt/prompt.cpp | 7 +++ plugins/wasi_nn/MLX/prompt/prompt.h | 12 +++++ 27 files changed, 275 insertions(+), 48 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 35c64c5a..c6b07cd0 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -196,6 +196,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) simdjson::simdjson ) elseif(BACKEND STREQUAL "mlx") + wasmedge_setup_simdjson() target_sources(wasmedgePluginWasiNN PRIVATE MLX/prompt/prompt.cpp @@ -212,7 +213,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) MLX/mlx/transformer.cpp MLX/mlx/quantized.cpp ) - wasmedge_setup_simdjson() + find_package(MLX CONFIG) if(MLX_FOUND) message(STATUS "Found MLX: ${MLX_INCLUDE_DIRS}") @@ -236,8 +237,8 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) set_target_properties(mlx PROPERTIES INTERFACE_LINK_LIBRARIES "$" ) - target_link_libraries( - mlx PUBLIC + target_link_libraries(mlx + PUBLIC ${ACCELERATE_LIBRARY} ${METAL_LIB} ${FOUNDATION_LIB} @@ -250,7 +251,6 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) -Wno-format ) endif() - target_include_directories(wasmedgePluginWasiNN PUBLIC MLX/model MLX/prompt MLX/mlx) message(STATUS "Downloading tokenizers") FetchContent_Declare( @@ -261,12 +261,6 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) ) FetchContent_MakeAvailable(tokenizers) set_property(TARGET tokenizer_cpp_objs PROPERTY POSITION_INDEPENDENT_CODE ON) - target_include_directories(wasmedgePluginWasiNN PRIVATE ${tokenizers_SOURCE_DIR}/include) - target_link_libraries(wasmedgePluginWasiNN PRIVATE tokenizers_cpp) - - target_include_directories(wasmedgePluginWasiNN SYSTEM PUBLIC ./mlx) - target_include_directories(wasmedgePluginWasiNN SYSTEM PUBLIC ${MLX_INCLUDE_DIRS}) - target_link_libraries(wasmedgePluginWasiNN PUBLIC ${MLX_LIBRARIES}) message(STATUS "Downloading gguflib") FetchContent_Declare( @@ -276,14 +270,37 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) GIT_SHALLOW FALSE ) FetchContent_MakeAvailable(gguflib) - target_include_directories(wasmedgePluginWasiNN SYSTEM PUBLIC $) - add_library(gguflib STATIC ${gguflib_SOURCE_DIR}/fp16.c - ${gguflib_SOURCE_DIR}/gguflib.c) + add_library(gguflib + STATIC + ${gguflib_SOURCE_DIR}/fp16.c + ${gguflib_SOURCE_DIR}/gguflib.c + ) set_target_properties(gguflib PROPERTIES LINKER_LANGUAGE CXX) - target_link_libraries(wasmedgePluginWasiNN PUBLIC gguflib) - - target_link_libraries(wasmedgePluginWasiNN PUBLIC + target_include_directories(wasmedgePluginWasiNN + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/MLX/model + ${CMAKE_CURRENT_SOURCE_DIR}/MLX/prompt + ${CMAKE_CURRENT_SOURCE_DIR}/MLX/mlx + ) + target_include_directories(wasmedgePluginWasiNN + PRIVATE + ${tokenizers_SOURCE_DIR}/include + ) + target_include_directories(wasmedgePluginWasiNN + SYSTEM PUBLIC + ./mlx + ${MLX_INCLUDE_DIRS} + $ + ) + target_link_libraries(wasmedgePluginWasiNN + PRIVATE + tokenizers_cpp + ) + target_link_libraries(wasmedgePluginWasiNN + PUBLIC + ${MLX_LIBRARIES} + gguflib mlx simdjson::simdjson ) diff --git a/plugins/wasi_nn/MLX/mlx/activations.cpp b/plugins/wasi_nn/MLX/mlx/activations.cpp index 4149651a..0b6fc13f 100644 --- a/plugins/wasi_nn/MLX/mlx/activations.cpp +++ b/plugins/wasi_nn/MLX/mlx/activations.cpp @@ -1,11 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "activations.h" + #include +#include + namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core { + mx::array gelu(mx::array X) { return X * (1 + mx::erf(X / std::sqrt(2.0))) / 2.0; } + mx::array silu(mx::array X) { return X * mx::sigmoid(X); } + } // namespace mlx::core } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/activations.h b/plugins/wasi_nn/MLX/mlx/activations.h index e997a3d6..68fda485 100644 --- a/plugins/wasi_nn/MLX/mlx/activations.h +++ b/plugins/wasi_nn/MLX/mlx/activations.h @@ -1,11 +1,16 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once -#include "base.h" -#include + #include namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core { + mx::array gelu(mx::array X); + mx::array silu(mx::array X); + } // namespace mlx::core } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/base.cpp b/plugins/wasi_nn/MLX/mlx/base.cpp index 4a395e32..0c7ac715 100644 --- a/plugins/wasi_nn/MLX/mlx/base.cpp +++ b/plugins/wasi_nn/MLX/mlx/base.cpp @@ -1,22 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "base.h" #include "../model/utils.h" -#include + #include -#include namespace WasmEdge::Host::WASINN::MLX { - namespace mlx::core::nn { mx::array &Module::registerParameter(std::string Name, mx::array &&W) { Parameters.insert({Name, W}); return Parameters.at(Name); } + void Module::update(std::unordered_map Parameters) { for (auto &[K, V] : Parameters) { apply(K, V); } } + std::shared_ptr Module::toQuantized(int GroupSize, int Bits) { for (auto &[K, V] : Submodules) { const auto OldModule = V; @@ -24,6 +27,7 @@ std::shared_ptr Module::toQuantized(int GroupSize, int Bits) { } return shared_from_this(); } + void Module::apply(std::string Key, mx::array Value) { std::vector SplitKey = splitString(Key, '.'); if (SplitKey.size() == 1) { @@ -47,6 +51,7 @@ void Module::apply(std::string Key, mx::array Value) { Submodules.at(LayerName)->apply(joinString(SplitKey, '.'), Value); } } + std::unordered_map Module::getWeigts(const std::string &Prefix) { std::unordered_map Weights; @@ -59,5 +64,6 @@ Module::getWeigts(const std::string &Prefix) { } return Weights; } + } // namespace mlx::core::nn } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/base.h b/plugins/wasi_nn/MLX/mlx/base.h index 66482f36..65bbff67 100644 --- a/plugins/wasi_nn/MLX/mlx/base.h +++ b/plugins/wasi_nn/MLX/mlx/base.h @@ -1,31 +1,48 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once -#include "common/errcode.h" + #include "mlx/mlx.h" + +#include "common/errcode.h" + #include +#include #include #include #include #include #include + using namespace std::literals::string_view_literals; namespace WasmEdge::Host::WASINN::MLX { + namespace mx = mlx::core; namespace mlx::core::nn { + class Module : public std::enable_shared_from_this { public: - virtual ~Module() = default; std::string Name; - std::unordered_map Parameters{}; - std::unordered_map> Submodules{}; + std::unordered_map Parameters; + std::unordered_map> Submodules; + + virtual ~Module() = default; + mx::array ®isterParameter(std::string Name, mx::array &&W); + std::unordered_map getWeigts(const std::string &Prefix = "model"); + virtual std::shared_ptr toQuantized(int GroupSize = 64, int Bits = 4); + void update(std::unordered_map Parameters); + void apply(std::string Key, mx::array Parameters); + template void registerModule(std::string ModuleName, std::shared_ptr M) { using DecayedT = std::decay_t; @@ -42,6 +59,7 @@ class Module : public std::enable_shared_from_this { assumingUnreachable(); } } + template void registerLayer(std::string ModuleName, std::vector> &Layers) { @@ -54,6 +72,7 @@ class Module : public std::enable_shared_from_this { } } }; + } // namespace mlx::core::nn template void printVec(std::vector Ve) { @@ -61,4 +80,5 @@ template void printVec(std::vector Ve) { spdlog::debug("[WASI-NN] MLX backend: {} ."sv, I); } } + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/embedding.cpp b/plugins/wasi_nn/MLX/mlx/embedding.cpp index 5ec56806..e12264e6 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.cpp +++ b/plugins/wasi_nn/MLX/mlx/embedding.cpp @@ -1,19 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "embedding.h" -#include "base.h" #include "quantized.h" -#include + #include #include namespace WasmEdge::Host::WASINN::MLX { - namespace mlx::core::nn { + mx::array Embedding::forward(mx::array Input) { return take(Parameters.at("weight"), Input, 0); } + mx::array Embedding::asLinear(mx::array Input) { return matmul(Input, transpose(Parameters.at("weight"))); } + std::shared_ptr Embedding::toQuantized(int GroupSize, int Bits) { auto QuantModel = QuantizedEmbedding::fromEmbedding( std::dynamic_pointer_cast(shared_from_this()), GroupSize, @@ -21,5 +25,6 @@ std::shared_ptr Embedding::toQuantized(int GroupSize, int Bits) { QuantModel->Name = Name; return QuantModel; } + } // namespace mlx::core::nn } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/embedding.h b/plugins/wasi_nn/MLX/mlx/embedding.h index f2943eba..94eba689 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.h +++ b/plugins/wasi_nn/MLX/mlx/embedding.h @@ -1,21 +1,33 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "base.h" -#include + #include #include +#include +#include + namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { + class Embedding : public Module { public: Embedding() = default; + Embedding(int NumEmbeddings, int Dims) { const double Scale = std::sqrt(1.0 / Dims); registerParameter("weight", mx::random::normal({NumEmbeddings, Dims}, 0.0, Scale)); } + virtual mx::array forward(mx::array Input); + mx::array asLinear(mx::array Input); + std::shared_ptr toQuantized(int GroupSize = 64, int Bits = 4) override; }; diff --git a/plugins/wasi_nn/MLX/mlx/linear.cpp b/plugins/wasi_nn/MLX/mlx/linear.cpp index 60018a0c..1d5a0104 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.cpp +++ b/plugins/wasi_nn/MLX/mlx/linear.cpp @@ -1,10 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "linear.h" -#include "base.h" #include "quantized.h" + #include namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { + mx::array Linear::forward(mx::array Input) { if (EnableBias) { return mx::addmm(Parameters.at("bias"), Input, diff --git a/plugins/wasi_nn/MLX/mlx/linear.h b/plugins/wasi_nn/MLX/mlx/linear.h index d6d0eb58..116317fc 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.h +++ b/plugins/wasi_nn/MLX/mlx/linear.h @@ -1,13 +1,18 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "base.h" #include "mlx/mlx.h" -#include + #include #include -namespace WasmEdge::Host::WASINN::MLX { +#include +#include +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { class Linear : public Module { @@ -27,9 +32,12 @@ class Linear : public Module { })); } } + virtual mx::array forward(mx::array Input); + std::shared_ptr toQuantized(int GroupSize = 64, int Bits = 4) override; }; + } // namespace mlx::core::nn } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/normalization.cpp b/plugins/wasi_nn/MLX/mlx/normalization.cpp index 3805a6ce..919b2d96 100644 --- a/plugins/wasi_nn/MLX/mlx/normalization.cpp +++ b/plugins/wasi_nn/MLX/mlx/normalization.cpp @@ -1,9 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "normalization.h" namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { + mx::array RMSNorm::forward(mx::array Input) { return mx::fast::rms_norm(Input, Parameters.at("weight"), Eps); } + } // namespace mlx::core::nn } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/normalization.h b/plugins/wasi_nn/MLX/mlx/normalization.h index 72bf3ba7..3a2b901c 100644 --- a/plugins/wasi_nn/MLX/mlx/normalization.h +++ b/plugins/wasi_nn/MLX/mlx/normalization.h @@ -1,8 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "base.h" namespace WasmEdge::Host::WASINN::MLX { - namespace mlx::core::nn { + class RMSNorm : public nn::Module { float Eps; @@ -10,6 +13,7 @@ class RMSNorm : public nn::Module { RMSNorm(int Dims, float Eps = 1e-5) : Eps(Eps) { registerParameter("weight", mx::ones({Dims})); } + mx::array forward(mx::array Input); }; diff --git a/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp b/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp index df04a92e..a5be4b92 100644 --- a/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp +++ b/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp @@ -1,8 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "positional_encoding.h" namespace WasmEdge::Host::WASINN::MLX { - namespace mlx::core::nn { + mx::array RoPE::forward(mx::array Input, int Offset) { return mx::fast::rope(Input, Dims, Tranditional, Base, Scale, Offset); } diff --git a/plugins/wasi_nn/MLX/mlx/positional_encoding.h b/plugins/wasi_nn/MLX/mlx/positional_encoding.h index 8f8f97ae..42abaff9 100644 --- a/plugins/wasi_nn/MLX/mlx/positional_encoding.h +++ b/plugins/wasi_nn/MLX/mlx/positional_encoding.h @@ -1,11 +1,16 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "base.h" + #include #include namespace WasmEdge::Host::WASINN::MLX { - namespace mlx::core::nn { + class RoPE : public Module { int Dims; bool Tranditional; @@ -16,6 +21,7 @@ class RoPE : public Module { RoPE(int Dims, bool Traditional = false, float Base = 10000, float Scale = 1.0) : Dims(Dims), Tranditional(Traditional), Base(Base), Scale(Scale) {} + mx::array forward(mx::array Input, int Offset = 0); }; diff --git a/plugins/wasi_nn/MLX/mlx/quantized.cpp b/plugins/wasi_nn/MLX/mlx/quantized.cpp index c4b122dc..6ed72678 100644 --- a/plugins/wasi_nn/MLX/mlx/quantized.cpp +++ b/plugins/wasi_nn/MLX/mlx/quantized.cpp @@ -1,12 +1,17 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "quantized.h" -#include -#include + #include #include + +#include #include namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { + mx::array QuantizedEmbedding::forward(mx::array Input) { auto S = Input.shape(); auto X = mx::flatten(Input); @@ -27,6 +32,7 @@ mx::array QuantizedLinear::forward(mx::array Input) { } return Out; } + std::shared_ptr QuantizedEmbedding::fromEmbedding(std::shared_ptr EmbeddingModule, int GroupSize, int Bits) { @@ -42,6 +48,7 @@ QuantizedEmbedding::fromEmbedding(std::shared_ptr EmbeddingModule, "biases", std::move(std::get<2>(Quantized))); return QuantizedModel; } + std::shared_ptr QuantizedLinear::fromLinear(std::shared_ptr LinearModule, int GroupSize, int Bits) { @@ -63,5 +70,6 @@ QuantizedLinear::fromLinear(std::shared_ptr LinearModule, int GroupSize, } return QuantizedModel; } + } // namespace mlx::core::nn } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/quantized.h b/plugins/wasi_nn/MLX/mlx/quantized.h index 5b3ad9c8..631f32a9 100644 --- a/plugins/wasi_nn/MLX/mlx/quantized.h +++ b/plugins/wasi_nn/MLX/mlx/quantized.h @@ -1,12 +1,18 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "base.h" #include "embedding.h" #include "linear.h" + #include #include -namespace WasmEdge::Host::WASINN::MLX { +#include +namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { class QuantizedEmbedding : public Embedding { @@ -15,6 +21,7 @@ class QuantizedEmbedding : public Embedding { int Bits; int NumEmbeddings; int Dims; + QuantizedEmbedding(int NumEmbeddings, int Dims, int GroupSize = 64, int Bits = 4) : GroupSize(GroupSize), Bits(Bits), NumEmbeddings(NumEmbeddings), @@ -27,7 +34,9 @@ class QuantizedEmbedding : public Embedding { registerParameter("scales", std::move(std::get<1>(Quantized))); registerParameter("biases", std::move(std::get<2>(Quantized))); } + mx::array forward(mx::array Input) override; + static std::shared_ptr fromEmbedding(std::shared_ptr EmbeddingModule, int GroupSize = 64, int Bits = 4); @@ -37,6 +46,7 @@ class QuantizedLinear : public Linear { public: int GroupSize; int Bits; + QuantizedLinear(int InputDims, int OutputDim, bool Bias = true, int GroupSize = 64, int Bits = 4) : GroupSize(GroupSize), Bits(Bits) { @@ -51,7 +61,9 @@ class QuantizedLinear : public Linear { registerParameter("bias", mx::zeros({OutputDim})); } } + mx::array forward(mx::array Input) override; + static std::shared_ptr fromLinear(std::shared_ptr LinearModule, int GroupSize = 64, int Bits = 4); diff --git a/plugins/wasi_nn/MLX/mlx/transformer.cpp b/plugins/wasi_nn/MLX/mlx/transformer.cpp index da15eba4..6933dd5c 100644 --- a/plugins/wasi_nn/MLX/mlx/transformer.cpp +++ b/plugins/wasi_nn/MLX/mlx/transformer.cpp @@ -1,9 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "transformer.h" + #include namespace WasmEdge::Host::WASINN::MLX { - namespace mlx::core::nn { + mx::array MultiHeadAttention::createAdditiveCausalMask(int N, mx::Dtype DType) { auto Indices = mx::arange(N); // mask = indices[:, None] < indices[None] @@ -12,5 +16,6 @@ mx::array MultiHeadAttention::createAdditiveCausalMask(int N, mx::Dtype DType) { Mask = astype(Mask, DType) * -1e9; return Mask; } + } // namespace mlx::core::nn } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/transformer.h b/plugins/wasi_nn/MLX/mlx/transformer.h index 9873ec76..84cb0b7d 100644 --- a/plugins/wasi_nn/MLX/mlx/transformer.h +++ b/plugins/wasi_nn/MLX/mlx/transformer.h @@ -1,11 +1,16 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "base.h" #include "linear.h" + #include namespace WasmEdge::Host::WASINN::MLX { - namespace mlx::core::nn { + class MultiHeadAttention : public Module { int NumHeads; @@ -46,10 +51,13 @@ class MultiHeadAttention : public Module { registerModule("out_proj", std::make_shared( Linear(*ValueDims, *ValueOutputDims, Bias))); }; + mx::array forward(mx::array Queries, mx::array Keys, mx::array Values, mx::array Mask); + static mx::array createAdditiveCausalMask(int N, mx::Dtype DType = mx::float32); }; + } // namespace mlx::core::nn } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/converter.cpp b/plugins/wasi_nn/MLX/model/converter.cpp index 35e57265..21c56a60 100644 --- a/plugins/wasi_nn/MLX/model/converter.cpp +++ b/plugins/wasi_nn/MLX/model/converter.cpp @@ -1,8 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "converter.h" #include "utils.h" + #include -#include -#include #include namespace WasmEdge::Host::WASINN::MLX { @@ -85,4 +87,5 @@ llamaToMlxllm(std::string WeightPath) { } return ModelWeights; } + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/converter.h b/plugins/wasi_nn/MLX/model/converter.h index dda09837..b4c4a438 100644 --- a/plugins/wasi_nn/MLX/model/converter.h +++ b/plugins/wasi_nn/MLX/model/converter.h @@ -1,7 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "base.h" + #include "mlx/mlx.h" -#include + +#include +#include namespace WasmEdge::Host::WASINN::MLX { @@ -11,4 +18,5 @@ std::unordered_map weightsToMlx(std::string WeightPath); std::unordered_map llamaToMlxllm(std::string WeightPath); + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/registry.cpp b/plugins/wasi_nn/MLX/model/registry.cpp index 7b4d5661..7fa15b86 100644 --- a/plugins/wasi_nn/MLX/model/registry.cpp +++ b/plugins/wasi_nn/MLX/model/registry.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "registry.h" #include "transformer.h" @@ -25,4 +28,5 @@ std::shared_ptr tinyLlama11BChatV10(int VocabSize, float NormEps, 2048, std::vector{5632}, VocabSize, 22, std::vector{32}, std::vector{4}, NormEps, {}, RopeTraditional, RopeTheta)); } + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/registry.h b/plugins/wasi_nn/MLX/model/registry.h index 4ed1d0c3..456b8aaa 100644 --- a/plugins/wasi_nn/MLX/model/registry.h +++ b/plugins/wasi_nn/MLX/model/registry.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include "transformer.h" @@ -18,4 +21,5 @@ std::shared_ptr tinyLlama11BChatV10(int VocabSize = 32000, float NormEps = 1e-5, float RopeTheta = 10000.0, bool RopeTraditional = false); + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/transformer.cpp b/plugins/wasi_nn/MLX/model/transformer.cpp index 00cd0e70..14503a1d 100644 --- a/plugins/wasi_nn/MLX/model/transformer.cpp +++ b/plugins/wasi_nn/MLX/model/transformer.cpp @@ -1,14 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "../mlx/transformer.h" #include "base.h" #include "embedding.h" #include "linear.h" #include "transformer.h" -#include -#include -#include + #include #include #include + +#include +#include +#include #include #include @@ -17,6 +22,7 @@ namespace WasmEdge::Host::WASINN::MLX { mx::array RMSNorm::forward(mx::array Input) { return mx::fast::rms_norm(Input, 1.0 + Parameters.at("weight"), Eps); } + std::tuple> Attention::forward(mx::array Input, std::optional Mask, std::optional> KVCache) { @@ -60,6 +66,7 @@ Attention::forward(mx::array Input, std::optional Mask, ->forward(Output), {Keys, Values}}; } + mx::array MLP::forward(mx::array Input) { if (Gemma) { return std::dynamic_pointer_cast(Submodules["down_proj"]) @@ -76,6 +83,7 @@ mx::array MLP::forward(mx::array Input) { std::dynamic_pointer_cast(Submodules["up_proj"]) ->forward(Input)); } + std::tuple> TransformerBlock::forward( mx::array Input, std::optional Mask, @@ -106,6 +114,7 @@ TransformerBlock::forward( } return {H + R, KVCache}; } + std::tuple>>> Transformer::embed( @@ -147,6 +156,7 @@ Transformer::embed( } return {H, KVCache}; } + std::tuple>>> Transformer::forward( @@ -162,6 +172,7 @@ Transformer::forward( } return {Out, KVCache}; } + std::tuple>>> Transformer::generate(mx::array Input, std::optional Temp) { @@ -183,6 +194,7 @@ Transformer::generate(mx::array Input, std::optional Temp) { } return {Y, KVCache}; } + std::tuple>>> Transformer::nextGenerate( @@ -201,4 +213,5 @@ Transformer::nextGenerate( } return {NextY, KVCache}; } + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/transformer.h b/plugins/wasi_nn/MLX/model/transformer.h index 6b99e599..680c08a4 100644 --- a/plugins/wasi_nn/MLX/model/transformer.h +++ b/plugins/wasi_nn/MLX/model/transformer.h @@ -1,13 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "activations.h" #include "base.h" #include "embedding.h" #include "linear.h" #include "normalization.h" #include "positional_encoding.h" + #include #include #include + #include #include #include @@ -24,10 +30,11 @@ class RMSNorm : public nn::Module { RMSNorm(int Dims, float Eps = 1e-5) : Eps(Eps) { registerParameter("weight", mx::ones({Dims})); } + mx::array forward(mx::array Input); }; -class Attention : public nn::Module { +class Attention : public nn::Module { int NHeads; int NKVHeads; bool NormQKProj; @@ -74,10 +81,12 @@ class Attention : public nn::Module { std::make_shared(nn::RoPE(HeadDim, RopeTraditional, RopeTheta, RopeScale))); } + std::tuple> forward(mx::array Input, std::optional Mask = {}, std::optional> KVCache = {}); }; + class MLP : public nn::Module { bool Gemma; @@ -90,8 +99,10 @@ class MLP : public nn::Module { registerModule("up_proj", std::make_shared( nn::Linear(Dim, HiddenDim, false))); } + mx::array forward(mx::array Input); }; + class TransformerBlock : public nn::Module { bool Gemma; @@ -121,16 +132,18 @@ class TransformerBlock : public nn::Module { std::make_shared(RMSNorm(Dim, NormEps))); } } + std::tuple> forward(mx::array Input, std::optional Mask = {}, std::optional> KVCachePar = {}); }; + class Transformer : public nn::Module { int Dim; std::optional> HiddenDim; bool Gemma; bool EmbedAsHead; - std::vector> Layers{}; + std::vector> Layers; public: Transformer( @@ -197,24 +210,29 @@ class Transformer : public nn::Module { nn::Linear(Dim, VocabSize, false))); } } + std::tuple>>> embed(mx::array Input, std::optional>> KVCachePar = {}, bool Norm = false); + std::tuple>>> forward(mx::array Input, std::optional>> KVCachePar = {}); + std::tuple>>> generate(mx::array Input, std::optional Temp = 0.0); + std::tuple>>> nextGenerate(mx::array Y, std::optional Temp = 0.0, std::optional>> KVCachePar = {}); }; + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/utils.cpp b/plugins/wasi_nn/MLX/model/utils.cpp index 6b21eb9d..451d9fc3 100644 --- a/plugins/wasi_nn/MLX/model/utils.cpp +++ b/plugins/wasi_nn/MLX/model/utils.cpp @@ -1,7 +1,12 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "utils.h" + #include namespace WasmEdge::Host::WASINN::MLX { + std::vector splitString(const std::string &S, char Delim) { std::vector Result; std::stringstream SS(S); @@ -11,6 +16,7 @@ std::vector splitString(const std::string &S, char Delim) { } return Result; } + std::string joinString(std::vector &S, char Delim) { std::string Result; for (size_t Idx = 0; Idx < S.size(); Idx++) { @@ -27,11 +33,13 @@ bool endsWith(std::string const &Value, std::string const &Ending) { return false; return std::equal(Ending.rbegin(), Ending.rend(), Value.rbegin()); } + bool startsWith(std::string const &Value, std::string const &Starting) { if (Starting.size() > Value.size()) return false; return std::equal(Starting.begin(), Starting.end(), Value.begin()); } + void saveWeights(const std::unordered_map &Weights, const std::string Path) { if (endsWith(Path, ".safetensors")) { @@ -41,6 +49,7 @@ void saveWeights(const std::unordered_map &Weights, assumingUnreachable(); } } + void saveWeights(const mx::array &Weights, const std::string &Path) { if (endsWith(Path, ".npz")) { mx::save(Path, Weights); @@ -49,4 +58,5 @@ void saveWeights(const mx::array &Weights, const std::string &Path) { assumingUnreachable(); } } + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/utils.h b/plugins/wasi_nn/MLX/model/utils.h index 606a2245..e0a86f4e 100644 --- a/plugins/wasi_nn/MLX/model/utils.h +++ b/plugins/wasi_nn/MLX/model/utils.h @@ -1,5 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once + #include "base.h" + #include #include #include @@ -7,10 +12,16 @@ namespace WasmEdge::Host::WASINN::MLX { std::vector splitString(const std::string &S, char Delim); + std::string joinString(std::vector &S, char Delim); + bool endsWith(std::string const &Value, std::string const &Ending); + bool startsWith(std::string const &Value, std::string const &Starting); + void saveWeights(const std::unordered_map &Weights, const std::string Path); + void saveWeights(const mx::array &Weights, const std::string &Path); + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/prompt/prompt.cpp b/plugins/wasi_nn/MLX/prompt/prompt.cpp index 8e22fc71..0b2d2ceb 100644 --- a/plugins/wasi_nn/MLX/prompt/prompt.cpp +++ b/plugins/wasi_nn/MLX/prompt/prompt.cpp @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #include "prompt.h" + #include namespace WasmEdge::Host::WASINN::MLX { @@ -6,12 +10,15 @@ namespace WasmEdge::Host::WASINN::MLX { std::string TinyLLaMAPrompt::prepare(std::string Prompt) { return SystemStart + TextEnd + Prompt + TextEnd + Assistant; } + std::string LLaMA2Prompt::prepare(std::string Prompt) { return InstStart + SysStart + SysEnd + Prompt + TextEnd; } + std::string LLaMA3Prompt::prepare(std::string Prompt) { return PropmtStart + StartHeader + "system" + EndHeader + TextEnd + Prompt + EndHeader + TextEnd + StartHeader + "user" + EndHeader + Prompt + TextEnd + StartHeader + "assistant" + EndHeader; } + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/prompt/prompt.h b/plugins/wasi_nn/MLX/prompt/prompt.h index 20d46530..d66cfc5b 100644 --- a/plugins/wasi_nn/MLX/prompt/prompt.h +++ b/plugins/wasi_nn/MLX/prompt/prompt.h @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once #include @@ -9,6 +12,7 @@ namespace WasmEdge::Host::WASINN::MLX { class BasePrompt { public: std::string TextEnd; + virtual std::string prepare(std::string Prompt) { return Prompt + TextEnd; }; }; @@ -17,12 +21,14 @@ class TinyLLaMAPrompt : public BasePrompt { std::string SystemStart; std::string User; std::string Assistant; + TinyLLaMAPrompt() { SystemStart = "<|system|>"; Assistant = "<|assistant|>"; User = "<|user|>"; TextEnd = ""; } + std::string prepare(std::string Prompt) override; }; @@ -31,25 +37,31 @@ class LLaMA2Prompt : public BasePrompt { std::string SysStart; std::string SysEnd; std::string InstStart; + LLaMA2Prompt() { SysStart = "<>"; SysEnd = "<>"; InstStart = "[INST]"; TextEnd = "[/INST]"; } + std::string prepare(std::string Prompt) override; }; + class LLaMA3Prompt : public BasePrompt { public: std::string PropmtStart; std::string StartHeader; std::string EndHeader; + LLaMA3Prompt() { PropmtStart = "<|begin_of_text|>"; StartHeader = "<|start_header_id|>"; EndHeader = "<|end_header_id|>"; TextEnd = "<|eot_id|>"; } + std::string prepare(std::string Prompt) override; }; + } // namespace WasmEdge::Host::WASINN::MLX From 6427e59c9713b8ddcc973f866f39cb8486a456d2 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 2 Oct 2024 02:29:25 +0800 Subject: [PATCH 460/623] [WASI-NN] mlx: fix missing header. Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 4 ++-- plugins/wasi_nn/MLX/mlx/activations.h | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index c6b07cd0..775d9564 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -255,7 +255,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) message(STATUS "Downloading tokenizers") FetchContent_Declare( tokenizers - GIT_REPOSITORY git@github.com:mlc-ai/tokenizers-cpp.git + GIT_REPOSITORY https://github.com/mlc-ai/tokenizers-cpp.git GIT_TAG 5de6f65 GIT_SHALLOW FALSE ) @@ -289,7 +289,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) ) target_include_directories(wasmedgePluginWasiNN SYSTEM PUBLIC - ./mlx + ${CMAKE_CURRENT_SOURCE_DIR}/mlx ${MLX_INCLUDE_DIRS} $ ) diff --git a/plugins/wasi_nn/MLX/mlx/activations.h b/plugins/wasi_nn/MLX/mlx/activations.h index 68fda485..039863e2 100644 --- a/plugins/wasi_nn/MLX/mlx/activations.h +++ b/plugins/wasi_nn/MLX/mlx/activations.h @@ -3,6 +3,8 @@ #pragma once +#include "base.h" + #include namespace WasmEdge::Host::WASINN::MLX { From 29b207090d227ae942183286936e6b8847499e68 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Thu, 3 Oct 2024 00:51:11 +0800 Subject: [PATCH 461/623] [Plugin] Stable Diffusion: add enable OpenMP option (#3810) * [Plugin] Stable Diffusion: add enable OpenMP option Signed-off-by: grorge * [Plugin] Stable Diffusion: set openMP default to off Signed-off-by: grorge --------- Signed-off-by: grorge --- plugins/wasmedge_stablediffusion/CMakeLists.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 4bcff1ea..6ff6af38 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -19,6 +19,14 @@ else() set(SD_METAL OFF CACHE BOOL "Stable diffusion plugin: Disable SD_METAL") endif() +if(WASMEDGE_PLUGIN_STABLEDIFFUSION_OPENMP) + message(STATUS "Stable diffusion plugin: Enable SD_OPENMP") + set(GGML_OPENMP ON) +else() + message(STATUS "Stable diffusion plugin: Disable SD_OPENMP") + set(GGML_OPENMP OFF) +endif() + # setup stable diffusion message(STATUS "Downloading stable diffusion source") FetchContent_Declare( From f6d81c45b26072b9a575a77dbcec7faeb7a9c87d Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Sat, 5 Oct 2024 16:22:24 +0800 Subject: [PATCH 462/623] [WASI-NN] whisper: add metal and cuda option (#3815) * [WASI-NN] whisper: add metal and cuda option Signed-off-by: grorge * [WASI-NN] whisper: remove copy header Signed-off-by: grorge --------- Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 775d9564..5d714dec 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -177,6 +177,21 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) set(GGML_OPENMP OFF) set(GGML_ACCELERATE OFF) set(GGML_BLAS OFF) + if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" AND WASMEDGE_PLUGIN_WASI_NN_WHISPER_METAL) + message(STATUS "WASI-NN Whisper backend: Enable GGML_METAL") + set(GGML_METAL ON) + set(GGML_METAL_EMBED_LIBRARY ON) + else() + message(STATUS "WASI-NN Whisper backend: Disable GGML_METAL") + set(GGML_METAL OFF) + endif() + if(WASMEDGE_PLUGIN_WASI_NN_WHISPER_CUDA) + message(STATUS "WASI-NN Whisper backend: Enable GGML_CUDA") + set(GGML_CUDA ON) + else() + message(STATUS "WASI-NN Whisper backend: Disable GGML_CUDA") + set(GGML_CUDA OFF) + endif() include(FetchContent) FetchContent_Declare( whisper From 87272829aaa3bd94b98e1fba05f9ae619c500664 Mon Sep 17 00:00:00 2001 From: LFsWang <7088579+LFsWang@users.noreply.github.com> Date: Mon, 7 Oct 2024 18:41:50 +0800 Subject: [PATCH 463/623] [Plugin] Update pytorch version (#3818) --- utils/docker/Dockerfile.manylinux2014-build-plugins-deps | 2 +- utils/docker/Dockerfile.manylinux_2_28-plugins-deps | 2 +- utils/docker/Dockerfile.ubuntu-plugins-deps | 2 +- utils/wasi-nn/install-pytorch.sh | 8 ++++---- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps index e2aab3ea..44b24d4f 100644 --- a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps @@ -18,7 +18,7 @@ ENV OPENCV_VERSION "4.8.0" RUN [ "/bin/bash", "install-opencvmini.sh" ] COPY wasi-nn/install-pytorch.sh . -ENV PYTORCH_VERSION "1.8.2" +ENV PYTORCH_VERSION "2.4.1" ENV PYTORCH_INSTALL_TO "/root" ENV Torch_DIR "/root/libtorch" RUN [ "/bin/bash", "install-pytorch.sh", "--disable-cxx11-abi" ] diff --git a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps index ec7be56d..d983fcb4 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps @@ -9,7 +9,7 @@ RUN cd && (yum check-update || true) && \ yum install -y wget unzip zlib-devel zlib-static elfutils-libelf-devel COPY wasi-nn/install-pytorch.sh . -ENV PYTORCH_VERSION="1.8.2" +ENV PYTORCH_VERSION="2.4.1" ENV PYTORCH_INSTALL_TO="/root" ENV Torch_DIR="/root/libtorch" RUN [ "/bin/bash", "install-pytorch.sh" ] diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index eaa2fd82..3684c463 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -29,7 +29,7 @@ ENV PKG_CONFIG_PATH=/root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${ ENV LD_LIBRARY_PATH=/root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} COPY wasi-nn/install-pytorch.sh . -ENV PYTORCH_VERSION="1.8.2" +ENV PYTORCH_VERSION="2.4.1" ENV PYTORCH_INSTALL_TO="/root" ENV Torch_DIR="/root/libtorch" RUN [ "/bin/bash", "install-pytorch.sh" ] diff --git a/utils/wasi-nn/install-pytorch.sh b/utils/wasi-nn/install-pytorch.sh index f3fd070e..1f6ac1d6 100755 --- a/utils/wasi-nn/install-pytorch.sh +++ b/utils/wasi-nn/install-pytorch.sh @@ -3,7 +3,7 @@ # SPDX-FileCopyrightText: 2019-2024 Second State INC if [[ ! -n ${PYTORCH_VERSION} ]]; then - PYTORCH_VERSION="1.8.2" + PYTORCH_VERSION="2.4.1" fi if [[ ! -n ${PYTORCH_INSTALL_TO} ]]; then @@ -11,20 +11,20 @@ if [[ ! -n ${PYTORCH_INSTALL_TO} ]]; then fi PYTORCH_LINK="libtorch-cxx11-abi" -PYTORCH_SHA="b76d6dd4380e2233ce6f7654e672e13aae7c871231d223a4267ef018dcbfb616" +PYTORCH_SHA="415c3ed51c766a6ef20dc10b2e60fae7f10a3ae8aa62223d6f4bccc1fc98740b" for i in "$@"; do case $i in --disable-cxx11-abi) PYTORCH_LINK="libtorch" - PYTORCH_SHA="b5ddadc9addc054d8503f4086546f0cbcfdc3fc70087863bbd7b0e3300e3247f" + PYTORCH_SHA="f49d55df661c566c29a7a75bcae2fad69177eaebd330618d42ca162eb3a1fad1" shift ;; esac done if [ ! -d ${PYTORCH_INSTALL_TO}/libtorch ]; then - curl -s -L -O --remote-name-all https://download.pytorch.org/libtorch/lts/1.8/cpu/${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip + curl -s -L -O --remote-name-all https://download.pytorch.org/libtorch/cpu/${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip echo "${PYTORCH_SHA} ${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" | sha256sum -c unzip -q "${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" -d ${PYTORCH_INSTALL_TO} rm -f "${PYTORCH_LINK}-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip" From 4eb7d03155fa09496dcbbb04a8b63f0b839333f4 Mon Sep 17 00:00:00 2001 From: Yi Date: Tue, 8 Oct 2024 14:27:30 +0800 Subject: [PATCH 464/623] [Docker] Ubuntu: Fix OpenVINO apt source (#3819) Signed-off-by: Yi Huang --- utils/docker/Dockerfile.ubuntu-env | 15 +++++++++++ utils/docker/Dockerfile.ubuntu-gcc | 5 ---- utils/docker/Dockerfile.ubuntu-plugins-deps | 4 +++ utils/docker/docker-bake.ubuntu.hcl | 29 +++++---------------- utils/wasi-nn/install-openvino.sh | 4 +-- 5 files changed, 28 insertions(+), 29 deletions(-) create mode 100644 utils/docker/Dockerfile.ubuntu-env delete mode 100644 utils/docker/Dockerfile.ubuntu-gcc diff --git a/utils/docker/Dockerfile.ubuntu-env b/utils/docker/Dockerfile.ubuntu-env new file mode 100644 index 00000000..dde599e0 --- /dev/null +++ b/utils/docker/Dockerfile.ubuntu-env @@ -0,0 +1,15 @@ +ARG BASE_IMAGE=wasmedge/wasmedge:latest +ARG TOOLCHAIN=clang +FROM ${BASE_IMAGE} AS base + +### env for clang +FROM base AS deps-clang + +### env for gcc +FROM base AS deps-gcc + +ENV CC=/usr/bin/gcc +ENV CXX=/usr/bin/g++ + +### final +FROM deps-${TOOLCHAIN} AS final diff --git a/utils/docker/Dockerfile.ubuntu-gcc b/utils/docker/Dockerfile.ubuntu-gcc deleted file mode 100644 index 1b201ae2..00000000 --- a/utils/docker/Dockerfile.ubuntu-gcc +++ /dev/null @@ -1,5 +0,0 @@ -ARG BASE_IMAGE=wasmedge/wasmedge:latest -FROM ${BASE_IMAGE} AS base - -ENV CC=/usr/bin/gcc -ENV CXX=/usr/bin/g++ diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index 3684c463..1ba0da63 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -1,4 +1,5 @@ ARG BASE_IMAGE=wasmedge/wasmedge:latest +ARG UBUNTU_VER=20 FROM ${BASE_IMAGE} AS base WORKDIR /root @@ -34,7 +35,10 @@ ENV PYTORCH_INSTALL_TO="/root" ENV Torch_DIR="/root/libtorch" RUN [ "/bin/bash", "install-pytorch.sh" ] +ARG UBUNTU_VER + COPY wasi-nn/install-openvino.sh . +ENV OPENVINO_UBUNTU_VERSION=${UBUNTU_VER} ENV OPENVINO_VERSION="2024.2.0" ENV OPENVINO_YEAR="2024" RUN [ "/bin/bash", "install-openvino.sh" ] diff --git a/utils/docker/docker-bake.ubuntu.hcl b/utils/docker/docker-bake.ubuntu.hcl index 1c7ad294..d7ae601d 100644 --- a/utils/docker/docker-bake.ubuntu.hcl +++ b/utils/docker/docker-bake.ubuntu.hcl @@ -1,8 +1,7 @@ group "default" { targets = [ "cuda", - "clang", - "gcc" + "final" ] } @@ -86,7 +85,6 @@ target "plugins" { ubuntu = ["20.04", "22.04"] } - inherits = ["base-${no-dot(ubuntu)}"] name = "plugins-${no-dot(ubuntu)}" contexts = { "local/tmp:base-${ubuntu}" = "target:base-${no-dot(ubuntu)}" @@ -98,37 +96,24 @@ target "plugins" { } } -target "clang" { +target "final" { matrix = { parent = ["base", "plugins"] ubuntu = ["20.04", "22.04"] + toolchain = ["clang", "gcc"] } - inherits = ["${parent}-${no-dot(ubuntu)}"] - name = "${parent}-${no-dot(ubuntu)}-clang" - contexts = { - "local/tmp:${parent}-${ubuntu}" = "target:${parent}-${no-dot(ubuntu)}" - } - tags = tags(parent, ubuntu, "clang") -} - -target "gcc" { - dockerfile = "Dockerfile.ubuntu-gcc" + dockerfile = "Dockerfile.ubuntu-env" context = "./utils/docker" - matrix = { - parent = ["base", "plugins"] - ubuntu = ["20.04", "22.04"] - } - - inherits = ["${parent}-${no-dot(ubuntu)}"] - name = "${parent}-${no-dot(ubuntu)}-gcc" + name = "${parent}-${no-dot(ubuntu)}-${toolchain}" contexts = { "local/tmp:${parent}-${ubuntu}" = "target:${parent}-${no-dot(ubuntu)}" } - tags = tags(parent, ubuntu, "gcc") + tags = tags(parent, ubuntu, toolchain) args = { BASE_IMAGE = "local/tmp:${parent}-${ubuntu}" + TOOLCHAIN = toolchain } } diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh index cd4030c0..9950caca 100755 --- a/utils/wasi-nn/install-openvino.sh +++ b/utils/wasi-nn/install-openvino.sh @@ -1,14 +1,14 @@ #!/usr/bin/env bash # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2024 Second State INC - set -e echo "Installing OpenVINO with version 2024.2.0" KEY_FILE=GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB wget https://apt.repos.intel.com/intel-gpg-keys/$KEY_FILE && \ apt-key add $KEY_FILE && \ rm -f $KEY_FILE -echo "deb https://apt.repos.intel.com/openvino/2024 ubuntu20 main" | tee /etc/apt/sources.list.d/intel-openvino-2024.list +UBUNTU_VERSION="ubuntu${OPENVINO_UBUNTU_VERSION:-20}" +echo "deb https://apt.repos.intel.com/openvino/2024 ${UBUNTU_VERSION} main" | tee /etc/apt/sources.list.d/intel-openvino-2024.list apt update apt-get -y install openvino-2024.2.0 ldconfig From 469e618481d77a364669e270e368873af3ce352b Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Fri, 11 Oct 2024 14:25:27 +0800 Subject: [PATCH 465/623] [Plugin] Stable Diffusion: fix reuse context segment fault (#3824) --- plugins/wasmedge_stablediffusion/sd_func.cpp | 2 +- .../wasmedge_stablediffusion.cpp | 47 +++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index f8879122..14ec327a 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -224,7 +224,7 @@ Expect SDCreateContext::body( diffusionModelPath.data(), VaePath.data(), TaesdPath.data(), ControlNetPath.data(), LoraModelDir.data(), EmbedDir.data(), IdEmbedDir.data(), static_cast(VaeDecodeOnly), - static_cast(VaeTiling), true, NThreads, + static_cast(VaeTiling), false, NThreads, static_cast(Wtype), static_cast(RngType), static_cast(Schedule), ClipOnCpu, ControlNetCpu, VaeOnCpu); if (Ctx == nullptr) { diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index a589a9f2..318c9cc7 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -189,6 +189,53 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { } // Test: text_to_image -- generate image from text. + { + uint32_t PromptPtr = UINT32_C(0); + uint32_t OutputPathPtr = PromptPtr + PromptData.size(); + uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath.size(); + OutputPtr = BytesWrittenPtr + 4; + writeBinaries(MemInst, PromptData, PromptPtr); + writeBinaries(MemInst, OutputPath, OutputPathPtr); + EXPECT_TRUE(HostFuncTextToImage.run( + CallFrame, + std::initializer_list{ + PromptPtr, // PromptPtr + static_cast(PromptData.size()), // PromptLen + SessionId, // SessionId + 0, // ControlImagePtr + 0, // ControlImageLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + 3.5f, // Guidance + 256, // Width + 256, // Height + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 1, // SampleSteps + 42, // Seed + 1, // BatchCount + 0.90f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + OutputPathPtr, // OutputPathPtr + static_cast(OutputPath.size()), // OutputPathLen + OutputPtr, // OutBufferPtr + 1048512, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); + EXPECT_GE(BytesWritten, 50); + EXPECT_TRUE(std::filesystem::exists(OutputPathString)); + } + // Test: text_to_image -- reuse context to generate image from text. { uint32_t PromptPtr = UINT32_C(0); uint32_t OutputPathPtr = PromptPtr + PromptData.size(); From c3f7a944791dd894812643f4c157b108c5b825dd Mon Sep 17 00:00:00 2001 From: LFsWang <7088579+LFsWang@users.noreply.github.com> Date: Mon, 14 Oct 2024 16:32:49 +0800 Subject: [PATCH 466/623] [Docker] Disable pytorch c++11 abi on manylinux_2_28 (#3826) Signed-off-by: Sylveon --- utils/docker/Dockerfile.manylinux_2_28-plugins-deps | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps index d983fcb4..5db52278 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps @@ -12,7 +12,7 @@ COPY wasi-nn/install-pytorch.sh . ENV PYTORCH_VERSION="2.4.1" ENV PYTORCH_INSTALL_TO="/root" ENV Torch_DIR="/root/libtorch" -RUN [ "/bin/bash", "install-pytorch.sh" ] +RUN [ "/bin/bash", "install-pytorch.sh", "--disable-cxx11-abi" ] ### deps for aarch64 ### FROM base AS deps-arm64 From 48f518add9612bd80f6baf56c38465e0739d70a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=AEm=20Ts=C3=BA-thu=C3=A0n?= Date: Thu, 2 May 2024 15:48:53 +0800 Subject: [PATCH 467/623] [Draft] interface types and values MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * create ValInterface as std::variant on ValVariant * InterfaceType inherit ValType * execute APIs and type checker Signed-off-by: Lîm Tsú-thuàn solve `execute` signature problem Signed-off-by: Lîm Tsú-thuàn use the correct HostFunc Signed-off-by: Lîm Tsú-thuàn remove CallingFrame Signed-off-by: Lîm Tsú-thuàn conversion for interface type Signed-off-by: Lîm Tsú-thuàn --- plugins/wasi_http/base.h | 7 ++++--- plugins/wasi_http/func.cpp | 15 ++++++--------- plugins/wasi_http/func.h | 4 ++-- plugins/wasi_poll/base.h | 5 +++-- plugins/wasi_poll/func.cpp | 2 +- plugins/wasi_poll/func.h | 2 +- 6 files changed, 17 insertions(+), 18 deletions(-) diff --git a/plugins/wasi_http/base.h b/plugins/wasi_http/base.h index 0c5fdbd2..48243ec7 100644 --- a/plugins/wasi_http/base.h +++ b/plugins/wasi_http/base.h @@ -6,15 +6,16 @@ #include "env.h" #include "common/errcode.h" -#include "runtime/hostfunc.h" +#include "runtime/component/hostfunc.h" namespace WasmEdge { namespace Host { -template class WasiHttp : public Runtime::HostFunction { +template +class WasiHttp : public Runtime::Component::HostFunction { public: WasiHttp(WasiHttpEnvironment &HostEnv) - : Runtime::HostFunction(0), Env(HostEnv) {} + : Runtime::Component::HostFunction(), Env(HostEnv) {} protected: WasiHttpEnvironment &Env; diff --git a/plugins/wasi_http/func.cpp b/plugins/wasi_http/func.cpp index f00b2121..245c9942 100644 --- a/plugins/wasi_http/func.cpp +++ b/plugins/wasi_http/func.cpp @@ -13,21 +13,18 @@ namespace WasmEdge { namespace Host { -Expect WasiHttpPrint::body(const Runtime::CallingFrame &, - StrVariant Str) { - spdlog::info("[WASI-HTTP] print: {}", Str.getString()); +Expect WasiHttpPrint::body(std::string S) { + spdlog::info("[WASI-HTTP] print: {}", S); return {}; } -Expect WasiHttpGet::body(const Runtime::CallingFrame &, - StrVariant URI) { - const auto &S = URI.getString(); - spdlog::info("[WASI-HTTP] URI: {}", S); +Expect WasiHttpGet::body(std::string URI) { + spdlog::info("[WASI-HTTP] URI: {}", URI); cpr::Response Res = cpr::Get( - cpr::Url{S}, cpr::Authentication{"user", "pass", cpr::AuthMode::BASIC}); + cpr::Url{URI}, cpr::Authentication{"user", "pass", cpr::AuthMode::BASIC}); spdlog::info("[WASI-HTTP] status: {}", Res.status_code); - return StrVariant(std::move(Res.text)); + return std::move(Res.text); } } // namespace Host diff --git a/plugins/wasi_http/func.h b/plugins/wasi_http/func.h index 79ce1c83..855f3855 100644 --- a/plugins/wasi_http/func.h +++ b/plugins/wasi_http/func.h @@ -12,13 +12,13 @@ namespace Host { class WasiHttpPrint : public WasiHttp { public: WasiHttpPrint(WasiHttpEnvironment &HostEnv) : WasiHttp(HostEnv) {} - Expect body(const Runtime::CallingFrame &Frame, StrVariant Str); + Expect body(std::string Str); }; class WasiHttpGet : public WasiHttp { public: WasiHttpGet(WasiHttpEnvironment &HostEnv) : WasiHttp(HostEnv) {} - Expect body(const Runtime::CallingFrame &Frame, StrVariant URI); + Expect body(std::string URI); }; } // namespace Host diff --git a/plugins/wasi_poll/base.h b/plugins/wasi_poll/base.h index e1c50bb7..0a8f86fb 100644 --- a/plugins/wasi_poll/base.h +++ b/plugins/wasi_poll/base.h @@ -11,10 +11,11 @@ namespace WasmEdge { namespace Host { -template class WasiPoll : public Runtime::HostFunction { +template +class WasiPoll : public Runtime::Component::HostFunction { public: WasiPoll(WasiPollEnvironment &HostEnv) - : Runtime::HostFunction(0), Env(HostEnv) {} + : Runtime::Component::HostFunction(), Env(HostEnv) {} protected: WasiPollEnvironment &Env; diff --git a/plugins/wasi_poll/func.cpp b/plugins/wasi_poll/func.cpp index 732208dd..8cabd345 100644 --- a/plugins/wasi_poll/func.cpp +++ b/plugins/wasi_poll/func.cpp @@ -8,7 +8,7 @@ namespace WasmEdge { namespace Host { -Expect Drop::body(const Runtime::CallingFrame &, Pollable) { return {}; } +Expect Drop::body(Pollable) { return {}; } } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_poll/func.h b/plugins/wasi_poll/func.h index 4daa13bc..1fb1fa26 100644 --- a/plugins/wasi_poll/func.h +++ b/plugins/wasi_poll/func.h @@ -14,7 +14,7 @@ using Pollable = uint32_t; class Drop : public WasiPoll { public: Drop(WasiPollEnvironment &HostEnv) : WasiPoll(HostEnv) {} - Expect body(const Runtime::CallingFrame &Frame, Pollable This); + Expect body(Pollable This); }; // poll-oneoff: func(in: list) -> list From 3213a5efaf3c06567f3730fa4216bec676282fbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=AEm=20Ts=C3=BA-thu=C3=A0n?= Date: Wed, 26 Jun 2024 15:06:30 +0800 Subject: [PATCH 468/623] [Plugin] impl wasi-poll poll-oneoff MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update hostfunc's fromTuple to update Span of ValInterface * add emplace helper * fix `Wit` type converter, it still convert `vector` rather than `List` Signed-off-by: Lîm Tsú-thuàn --- plugins/wasi_poll/func.cpp | 18 +++++++++++++++++- plugins/wasi_poll/func.h | 8 +++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_poll/func.cpp b/plugins/wasi_poll/func.cpp index 8cabd345..b6326776 100644 --- a/plugins/wasi_poll/func.cpp +++ b/plugins/wasi_poll/func.cpp @@ -8,7 +8,23 @@ namespace WasmEdge { namespace Host { -Expect Drop::body(Pollable) { return {}; } +bool isPollable(Pollable) { + // TODO: use a global HashMap to note this + return false; +} + +Expect Drop::body(Pollable) { + // TODO: ensure this affect the global HashMap + return {}; +} + +Expect> PollOneoff::body(List In) { + std::vector Res; + for (auto P : In.collection()) { + Res.push_back(isPollable(P)); + } + return List(std::move(Res)); +} } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_poll/func.h b/plugins/wasi_poll/func.h index 1fb1fa26..71e15f58 100644 --- a/plugins/wasi_poll/func.h +++ b/plugins/wasi_poll/func.h @@ -11,13 +11,19 @@ namespace Host { using Pollable = uint32_t; +bool isPollable(Pollable); + class Drop : public WasiPoll { public: Drop(WasiPollEnvironment &HostEnv) : WasiPoll(HostEnv) {} Expect body(Pollable This); }; -// poll-oneoff: func(in: list) -> list +class PollOneoff : public WasiPoll { +public: + PollOneoff(WasiPollEnvironment &HostEnv) : WasiPoll(HostEnv) {} + Expect> body(List In); +}; } // namespace Host } // namespace WasmEdge From 55f7f296ad7679680bc0f8ac5823804384c072e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=AEm=20Ts=C3=BA-thu=C3=A0n?= Date: Thu, 27 Jun 2024 11:00:40 +0800 Subject: [PATCH 469/623] [Plugin] put pollable map into Environment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Lîm Tsú-thuàn --- plugins/wasi_poll/env.cpp | 4 +++- plugins/wasi_poll/env.h | 8 ++++++++ plugins/wasi_poll/func.cpp | 11 +++-------- plugins/wasi_poll/func.h | 4 ---- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/plugins/wasi_poll/env.cpp b/plugins/wasi_poll/env.cpp index 90b5e632..0acd2fbb 100644 --- a/plugins/wasi_poll/env.cpp +++ b/plugins/wasi_poll/env.cpp @@ -7,8 +7,10 @@ namespace WasmEdge { namespace Host { -WasiPollEnvironment::WasiPollEnvironment() noexcept {} +WasiPollEnvironment::WasiPollEnvironment() noexcept : PollableMap() {} +bool WasiPollEnvironment::isPollable(Pollable P) { return PollableMap.at(P); } +void WasiPollEnvironment::dropPollable(Pollable P) { PollableMap.erase(P); } namespace { Runtime::Instance::ComponentInstance * diff --git a/plugins/wasi_poll/env.h b/plugins/wasi_poll/env.h index e7d8ab1d..8ac49f6c 100644 --- a/plugins/wasi_poll/env.h +++ b/plugins/wasi_poll/env.h @@ -10,9 +10,17 @@ namespace WasmEdge { namespace Host { +using Pollable = uint32_t; + class WasiPollEnvironment { public: WasiPollEnvironment() noexcept; + + bool isPollable(Pollable P); + void dropPollable(Pollable P); + +private: + std::unordered_map PollableMap; }; } // namespace Host diff --git a/plugins/wasi_poll/func.cpp b/plugins/wasi_poll/func.cpp index b6326776..8c622448 100644 --- a/plugins/wasi_poll/func.cpp +++ b/plugins/wasi_poll/func.cpp @@ -8,20 +8,15 @@ namespace WasmEdge { namespace Host { -bool isPollable(Pollable) { - // TODO: use a global HashMap to note this - return false; -} - -Expect Drop::body(Pollable) { - // TODO: ensure this affect the global HashMap +Expect Drop::body(Pollable P) { + Env.dropPollable(P); return {}; } Expect> PollOneoff::body(List In) { std::vector Res; for (auto P : In.collection()) { - Res.push_back(isPollable(P)); + Res.push_back(Env.isPollable(P)); } return List(std::move(Res)); } diff --git a/plugins/wasi_poll/func.h b/plugins/wasi_poll/func.h index 71e15f58..0c5c4949 100644 --- a/plugins/wasi_poll/func.h +++ b/plugins/wasi_poll/func.h @@ -9,10 +9,6 @@ namespace WasmEdge { namespace Host { -using Pollable = uint32_t; - -bool isPollable(Pollable); - class Drop : public WasiPoll { public: Drop(WasiPollEnvironment &HostEnv) : WasiPoll(HostEnv) {} From 2b14f9f58173921852e35ff2f3d7b27f5614e74a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=AEm=20Ts=C3=BA-thu=C3=A0n?= Date: Mon, 12 Aug 2024 13:41:13 +0800 Subject: [PATCH 470/623] [Misc] use default constructor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Lîm Tsú-thuàn --- plugins/wasi_poll/env.cpp | 2 -- plugins/wasi_poll/env.h | 2 -- 2 files changed, 4 deletions(-) diff --git a/plugins/wasi_poll/env.cpp b/plugins/wasi_poll/env.cpp index 0acd2fbb..8fb42f4d 100644 --- a/plugins/wasi_poll/env.cpp +++ b/plugins/wasi_poll/env.cpp @@ -7,8 +7,6 @@ namespace WasmEdge { namespace Host { -WasiPollEnvironment::WasiPollEnvironment() noexcept : PollableMap() {} - bool WasiPollEnvironment::isPollable(Pollable P) { return PollableMap.at(P); } void WasiPollEnvironment::dropPollable(Pollable P) { PollableMap.erase(P); } namespace { diff --git a/plugins/wasi_poll/env.h b/plugins/wasi_poll/env.h index 8ac49f6c..65ae096c 100644 --- a/plugins/wasi_poll/env.h +++ b/plugins/wasi_poll/env.h @@ -14,8 +14,6 @@ using Pollable = uint32_t; class WasiPollEnvironment { public: - WasiPollEnvironment() noexcept; - bool isPollable(Pollable P); void dropPollable(Pollable P); From 254824d8412544c1896ad4a23977f24cc13658c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=AEm=20Ts=C3=BA-thu=C3=A0n?= Date: Mon, 12 Aug 2024 13:43:39 +0800 Subject: [PATCH 471/623] [Misc] add `noexcept` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Lîm Tsú-thuàn --- plugins/wasi_poll/env.cpp | 4 +++- plugins/wasi_poll/env.h | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_poll/env.cpp b/plugins/wasi_poll/env.cpp index 8fb42f4d..223839ad 100644 --- a/plugins/wasi_poll/env.cpp +++ b/plugins/wasi_poll/env.cpp @@ -7,7 +7,9 @@ namespace WasmEdge { namespace Host { -bool WasiPollEnvironment::isPollable(Pollable P) { return PollableMap.at(P); } +bool WasiPollEnvironment::isPollable(Pollable P) noexcept { + return PollableMap.at(P); +} void WasiPollEnvironment::dropPollable(Pollable P) { PollableMap.erase(P); } namespace { diff --git a/plugins/wasi_poll/env.h b/plugins/wasi_poll/env.h index 65ae096c..4852f7b0 100644 --- a/plugins/wasi_poll/env.h +++ b/plugins/wasi_poll/env.h @@ -14,7 +14,7 @@ using Pollable = uint32_t; class WasiPollEnvironment { public: - bool isPollable(Pollable P); + bool isPollable(Pollable P) noexcept; void dropPollable(Pollable P); private: From 0ff7e7bb8b1d4225d19798a38261091f6aaa6c11 Mon Sep 17 00:00:00 2001 From: vincent Date: Mon, 14 Oct 2024 11:08:09 +0800 Subject: [PATCH 472/623] [Misc] move the rust plugin into a separate repository https://github.com/WasmEdge/rust-plugins Signed-off-by: vincent --- plugins/CMakeLists.txt | 3 - plugins/wasi_nn_burnrs/.gitignore | 1 - plugins/wasi_nn_burnrs/CMakeLists.txt | 25 -- plugins/wasi_nn_burnrs/Cargo.toml | 38 -- plugins/wasi_nn_burnrs/src/helper.rs | 11 - plugins/wasi_nn_burnrs/src/lib.rs | 382 ------------------ plugins/wasi_nn_burnrs/src/models/mod.rs | 4 - .../wasi_nn_burnrs/src/models/squeezenet.rs | 45 --- plugins/wasi_nn_burnrs/src/models/whisper.rs | 96 ----- 9 files changed, 605 deletions(-) delete mode 100644 plugins/wasi_nn_burnrs/.gitignore delete mode 100644 plugins/wasi_nn_burnrs/CMakeLists.txt delete mode 100644 plugins/wasi_nn_burnrs/Cargo.toml delete mode 100644 plugins/wasi_nn_burnrs/src/helper.rs delete mode 100644 plugins/wasi_nn_burnrs/src/lib.rs delete mode 100644 plugins/wasi_nn_burnrs/src/models/mod.rs delete mode 100644 plugins/wasi_nn_burnrs/src/models/squeezenet.rs delete mode 100644 plugins/wasi_nn_burnrs/src/models/whisper.rs diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index 8f57a5fa..b4d89734 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -21,9 +21,6 @@ endif() if(WASMEDGE_PLUGIN_WASI_NN_BACKEND) add_subdirectory(wasi_nn) endif() -if(WASMEDGE_PLUGIN_WASI_NN_BURNRS_MODEL) - add_subdirectory(wasi_nn_burnrs) -endif() # WASI plug-in: WASI-Poll proposal. if(WASMEDGE_PLUGIN_WASI_POLL) diff --git a/plugins/wasi_nn_burnrs/.gitignore b/plugins/wasi_nn_burnrs/.gitignore deleted file mode 100644 index eb5a316c..00000000 --- a/plugins/wasi_nn_burnrs/.gitignore +++ /dev/null @@ -1 +0,0 @@ -target diff --git a/plugins/wasi_nn_burnrs/CMakeLists.txt b/plugins/wasi_nn_burnrs/CMakeLists.txt deleted file mode 100644 index ad7c1344..00000000 --- a/plugins/wasi_nn_burnrs/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 2019-2024 Second State INC - -if(CMAKE_BUILD_TYPE STREQUAL "Debug") - set(CARGO_CMD cargo build) - set(TARGET_DIR "debug") -else() - set(CARGO_CMD cargo build --release) - set(TARGET_DIR "release") -endif() - -message(STATUS "WasmEdge WASI-NN Burn.rs backend plugin model: ${WASMEDGE_PLUGIN_WASI_NN_BURNRS_MODEL}") -set(CARGO_FEATURES "--features=${WASMEDGE_PLUGIN_WASI_NN_BURNRS_MODEL}") - -set(RS_SO ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR}/libwasmedgePluginWasiNN${CMAKE_SHARED_LIBRARY_SUFFIX}) - -set(WASMEDGE_LIB_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/api) - -add_custom_target(wasmedgePluginWasiNNBurnRS ALL - COMMAND WASMEDGE_LIB_DIR=${WASMEDGE_LIB_DIR} LD_LIBARAY_PATH=${WASMEDGE_LIB_DIR} CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR} ${CARGO_CMD} ${CARGO_FEATURES} - COMMAND ${CMAKE_COMMAND} -E copy ${RS_SO} ${CMAKE_CURRENT_BINARY_DIR} - COMMAND ${CMAKE_COMMAND} -E remove_directory ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_DIR} - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - DEPENDS wasmedge_shared -) diff --git a/plugins/wasi_nn_burnrs/Cargo.toml b/plugins/wasi_nn_burnrs/Cargo.toml deleted file mode 100644 index 6a71a217..00000000 --- a/plugins/wasi_nn_burnrs/Cargo.toml +++ /dev/null @@ -1,38 +0,0 @@ -[package] -name = "wasi_nn_burnrs" -version = "0.0.1" -edition = "2021" - -[lib] -name = "wasmedgePluginWasiNN" -path = "src/lib.rs" -crate-type = ["cdylib"] - -[features] -default = [] -squeezenet = ["squeezenet-burn"] -whisper = ["whisper-burn", "strum", "strum_macros"] - -[dependencies.squeezenet-burn] -package = "squeezenet-burn" -branch = "prebuilt-feature" -git = "https://github.com/second-state/burn-rs-models.git" -features = ["weights_file"] -default-features = false -optional = true - -[dependencies.whisper-burn] -package = "whisper" -branch = "dev" -git = "https://github.com/second-state/burn-rs-whisper.git" -optional = true - -[dependencies] -burn = { version = "0.13.2", features = ["ndarray", "wgpu"] } -wasmedge_plugin_sdk = { git = "https://github.com/second-state/wasmedge_plugin_rust_sdk.git", features = ["standalone"] } -wasmedge-wasi-nn = "0.8.0" -lazy_static = "1.4.0" -bytemuck = "1.16.0" -cfg-if = "1.0.0" -strum = { version = "0.25.0", optional = true } -strum_macros = { version = "0.25.0", optional = true } diff --git a/plugins/wasi_nn_burnrs/src/helper.rs b/plugins/wasi_nn_burnrs/src/helper.rs deleted file mode 100644 index 81375c34..00000000 --- a/plugins/wasi_nn_burnrs/src/helper.rs +++ /dev/null @@ -1,11 +0,0 @@ -#[macro_export] -macro_rules! get_slice { - ($memory:expr, $ptr:expr, $length:expr, $ty:ty) => {{ - let raw_bytes = $memory - .data_pointer($ptr as usize, $length as usize) - .expect("Failed to get data pointer"); - bytemuck::cast_slice::(raw_bytes) - }}; -} - -pub use get_slice; diff --git a/plugins/wasi_nn_burnrs/src/lib.rs b/plugins/wasi_nn_burnrs/src/lib.rs deleted file mode 100644 index 562f8304..00000000 --- a/plugins/wasi_nn_burnrs/src/lib.rs +++ /dev/null @@ -1,382 +0,0 @@ -mod helper; -mod models; - -pub enum ErrNo { - Success = 0, // No error occurred. - InvalidArgument = 1, // Caller module passed an invalid argument. - InvalidEncoding = 2, // Invalid encoding. - MissingMemory = 3, // Caller module is missing a memory export. - Busy = 4, // Device or resource busy. - RuntimeError = 5, // Runtime Error. - UnsupportedOperation = 6, // Unsupported Operation. - TooLarge = 7, // Too Large. - NotFound = 8, // Not Found. -} -mod wasi_nn { - use crate::helper::get_slice; - #[cfg(feature = "squeezenet")] - use crate::models::squeezenet::*; - #[cfg(feature = "whisper")] - use crate::models::whisper::*; - use crate::ErrNo; - use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; - use burn::backend::NdArray; - use lazy_static::lazy_static; - use std::collections::HashMap; - use std::env; - use std::mem; - use std::process; - use std::sync::Mutex; - - use wasmedge_wasi_nn::TensorType; - use wasmedge_plugin_sdk::{ - error::CoreError, - memory::Memory, - module::{PluginModule, SyncInstanceRef}, - types::{ValType, WasmVal}, - }; - - type NdArrayBackend = NdArray; - type WgpuBackend = Wgpu; - - pub enum Graph { - /// The model is loaded to the NdArray backend - WithNdArrayBackend(GraphInner), - /// The model is loaded to the Wgpu backend - WithWgpuBackend(GraphInner), - } - - enum ExecutionContext { - /// The model is loaded to the NdArray backend - WithNdArrayBackend(ContextInner), - /// The model is loaded to the Wgpu backend - WithWgpuBackend(ContextInner), - } - - lazy_static! { - static ref GRAPH_HANDLE_MAP: Mutex> = Mutex::new(HashMap::new()); - static ref GRAPH_NAME_MAP: Mutex> = Mutex::new(HashMap::new()); - static ref CONTEXT_HANDLE_MAP: Mutex> = - Mutex::new(HashMap::new()); - } - - fn parse_opts() { - fn process_nn_preload(nn_preload: String) { - let parts: Vec<&str> = nn_preload.split(':').collect(); - - if parts.len() < 4 { - eprintln!("[WASI_NN] Invalid nn-preload format. {:?} len < 4", parts); - process::exit(1); - } - - let graph_encoding = parts[1].to_string(); - if graph_encoding.to_lowercase() != "burn" { - eprintln!("[WASI_NN] Unsupported graph encoding. {:?}", graph_encoding); - process::exit(1); - } - - let name = parts[0].to_string(); - let mut graph_map = GRAPH_HANDLE_MAP.lock().unwrap(); - let graph_handle = graph_map.len() as u32; - let mut name_map = GRAPH_NAME_MAP.lock().unwrap(); - name_map.insert(name.clone(), graph_handle); - let target = parts[2].to_string().to_lowercase(); - if target == "gpu" { - let device = WgpuDevice::BestAvailable; - graph_map.insert( - graph_handle, - Graph::WithWgpuBackend(GraphInner::create(parts[3..].to_vec(), &device)), - ); - } else { - let device = Default::default(); - graph_map.insert( - graph_handle, - Graph::WithNdArrayBackend(GraphInner::create(parts[3..].to_vec(), &device)), - ); - }; - } - - unsafe { - if let Ok(nn_preload) = (*crate::nn_preload()).to_string() { - process_nn_preload(nn_preload); - } else if let Ok(env_nn_preload) = env::var("WASMEDGE_WASINN_PRELOAD") { - process_nn_preload(env_nn_preload); - } - } - } - - pub fn create_module() -> PluginModule<()> { - fn load<'a>( - _inst: &'a mut SyncInstanceRef, - _memory: &'a mut Memory, - _data: &'a mut (), - _args: Vec, - ) -> Result, CoreError> { - Ok(vec![WasmVal::I32(ErrNo::UnsupportedOperation as i32)]) - } - - fn load_by_name<'a>( - _inst: &'a mut SyncInstanceRef, - memory: &'a mut Memory, - _data: &'a mut (), - args: Vec, - ) -> Result, CoreError> { - if let [WasmVal::I32(data_ptr), WasmVal::I32(data_len), WasmVal::I32(graph_handle_ptr)] = - &args[..] - { - let bytes = memory - .data_pointer(*data_ptr as usize, *data_len as usize) - .unwrap(); - let name = String::from_utf8_lossy(&bytes); - let name_map = GRAPH_NAME_MAP.lock().unwrap(); - if let Some(handle) = name_map.get(name.as_ref()) { - memory.write_data((*graph_handle_ptr as usize).into(), *handle); - Ok(vec![WasmVal::I32(ErrNo::Success as i32)]) - } else { - Ok(vec![WasmVal::I32(ErrNo::NotFound as i32)]) - } - } else { - Ok(vec![WasmVal::I32(ErrNo::InvalidArgument as i32)]) - } - } - - fn init_execution_context<'a>( - _inst: &'a mut SyncInstanceRef, - memory: &'a mut Memory, - _data: &'a mut (), - args: Vec, - ) -> Result, CoreError> { - if let [WasmVal::I32(graph_handle), WasmVal::I32(context_handle_ptr)] = &args[..] { - let mut context_map = CONTEXT_HANDLE_MAP.lock().unwrap(); - let context_handle = context_map.len() as u32; - let graph_map = GRAPH_HANDLE_MAP.lock().unwrap(); - let graph = graph_map - .get(&(*graph_handle as u32)) - .unwrap_or_else(|| unreachable!()); - match graph { - Graph::WithNdArrayBackend(_) => { - context_map.insert( - context_handle, - ( - *graph_handle as u32, - ExecutionContext::WithNdArrayBackend(ContextInner::new()), - ), - ); - } - Graph::WithWgpuBackend(_) => { - context_map.insert( - context_handle, - ( - *graph_handle as u32, - ExecutionContext::WithWgpuBackend(ContextInner::new()), - ), - ); - } - } - memory.write_data((*context_handle_ptr as usize).into(), context_handle); - Ok(vec![WasmVal::I32(ErrNo::Success as i32)]) - } else { - Ok(vec![WasmVal::I32(ErrNo::InvalidArgument as i32)]) - } - } - - fn set_input<'a>( - _inst: &'a mut SyncInstanceRef, - memory: &'a mut Memory, - _data: &'a mut (), - args: Vec, - ) -> Result, CoreError> { - #[derive(Debug)] - #[repr(C)] - struct WasiTensorData { - dimens_ptr: u32, - dimens_length: u32, - tensor_type: TensorType, - tensor_ptr: u32, - tensor_length: u32, - } - if let [WasmVal::I32(context_handle), WasmVal::I32(input_index), WasmVal::I32(input_tensor_ptr)] = - &args[..] - { - match memory.get_data::((*input_tensor_ptr as usize).into()) { - Some(input_tensor) => { - let raw_dimens = get_slice!( - memory, - input_tensor.dimens_ptr, - INPUT_DIM * mem::size_of::(), - u32 - ); - let dimens: [usize; INPUT_DIM] = raw_dimens - .iter() - .map(|&x| x as usize) - .collect::>() - .try_into() - .unwrap(); - - // FIXME: The type of f32 should be decided at runtime based on input_tensor.tensor_type. - let tensor = get_slice!( - memory, - input_tensor.tensor_ptr, - input_tensor.tensor_length, - f32 - ); - - let mut context_map = CONTEXT_HANDLE_MAP.lock().unwrap(); - let (_, context) = context_map - .get_mut(&(*context_handle as u32)) - .unwrap_or_else(|| unreachable!()); - - match context { - ExecutionContext::WithNdArrayBackend(inner) => { - inner.set_input(*input_index as u32, tensor, dimens); - } - ExecutionContext::WithWgpuBackend(inner) => { - inner.set_input(*input_index as u32, tensor, dimens); - } - } - Ok(vec![WasmVal::I32(ErrNo::Success as i32)]) - } - None => Ok(vec![WasmVal::I32(ErrNo::MissingMemory as i32)]), - } - } else { - Ok(vec![WasmVal::I32(ErrNo::InvalidArgument as i32)]) - } - } - - fn compute<'a>( - _inst: &'a mut SyncInstanceRef, - _memory: &'a mut Memory, - _data: &'a mut (), - args: Vec, - ) -> Result, CoreError> { - if let [WasmVal::I32(context_handle)] = &args[..] { - let mut context_map = CONTEXT_HANDLE_MAP.lock().unwrap(); - let (graph_handle, context) = context_map - .get_mut(&(*context_handle as u32)) - .unwrap_or_else(|| unreachable!()); - - let graph_map = GRAPH_HANDLE_MAP.lock().unwrap(); - let graph = graph_map - .get(graph_handle) - .unwrap_or_else(|| unreachable!()); - - match context { - ExecutionContext::WithNdArrayBackend(ctx_inner) => { - let Graph::WithNdArrayBackend(graph_inner) = graph else { - unreachable!() - }; - let output = - graph_inner.compute((*ctx_inner.inputs.get(&0).unwrap()).clone()); - ctx_inner.outputs.push(output); - } - ExecutionContext::WithWgpuBackend(ctx_inner) => { - let Graph::WithWgpuBackend(graph_inner) = graph else { - unreachable!() - }; - let output = - graph_inner.compute((*ctx_inner.inputs.get(&0).unwrap()).clone()); - ctx_inner.outputs.push(output); - } - }; - Ok(vec![WasmVal::I32(ErrNo::Success as i32)]) - } else { - Ok(vec![WasmVal::I32(ErrNo::InvalidArgument as i32)]) - } - } - - fn get_output<'a>( - _inst: &'a mut SyncInstanceRef, - memory: &'a mut Memory, - _data: &'a mut (), - args: Vec, - ) -> Result, CoreError> { - if let [WasmVal::I32(context_handle), WasmVal::I32(output_index), WasmVal::I32(output_ptr), WasmVal::I32(output_max_size), WasmVal::I32(written_length)] = - &args[..] - { - let mut context_map = CONTEXT_HANDLE_MAP.lock().unwrap(); - let (_, context) = context_map - .get_mut(&(*context_handle as u32)) - .unwrap_or_else(|| unreachable!()); - let raw_output = match context { - ExecutionContext::WithNdArrayBackend(ctx_inner) => { - ctx_inner.get_output(*output_index as usize) - } - ExecutionContext::WithWgpuBackend(ctx_inner) => { - ctx_inner.get_output(*output_index as usize) - } - }; - let output: &[u8] = bytemuck::cast_slice(&raw_output); - if output.len() > *output_max_size as usize { - return Ok(vec![WasmVal::I32(ErrNo::TooLarge as i32)]); - } - memory.write_bytes(output, *output_ptr as u32).unwrap(); - memory.write_data((*written_length as usize).into(), output.len()); - Ok(vec![WasmVal::I32(ErrNo::Success as i32)]) - } else { - Ok(vec![WasmVal::I32(ErrNo::InvalidArgument as i32)]) - } - } - - parse_opts(); - - let mut module = PluginModule::create("wasi_ephemeral_nn", ()).unwrap(); - module - .add_func("load", (vec![ValType::I32; 5], vec![ValType::I32]), load) - .unwrap(); - module - .add_func( - "load_by_name", - (vec![ValType::I32; 3], vec![ValType::I32]), - load_by_name, - ) - .unwrap(); - module - .add_func( - "init_execution_context", - (vec![ValType::I32; 2], vec![ValType::I32]), - init_execution_context, - ) - .unwrap(); - module - .add_func( - "set_input", - (vec![ValType::I32; 3], vec![ValType::I32]), - set_input, - ) - .unwrap(); - module - .add_func( - "compute", - (vec![ValType::I32; 1], vec![ValType::I32]), - compute, - ) - .unwrap(); - module - .add_func( - "get_output", - (vec![ValType::I32; 5], vec![ValType::I32]), - get_output, - ) - .unwrap(); - module - } -} - -use wasi_nn::create_module; -use wasmedge_plugin_sdk::plugin::{option_string, register_plugin, OptionString}; -register_plugin!( - plugin_name = "wasi_nn", - plugin_description = "burn framework adapter as wasi-nn plugin", - version = (0,0,0,1), - modules = [ - {"wasi_nn", "wasinn with burn backend module", create_module} - ], - options = [ - { - "nn-preload", - "Allow preload models from wasinn plugin. Each NN model can be specified as --nn-preload `COMMAND`.", - OptionString, - option_string!("none") - } - ] -); diff --git a/plugins/wasi_nn_burnrs/src/models/mod.rs b/plugins/wasi_nn_burnrs/src/models/mod.rs deleted file mode 100644 index f7f7aadf..00000000 --- a/plugins/wasi_nn_burnrs/src/models/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -#[cfg(feature = "squeezenet")] -pub mod squeezenet; -#[cfg(feature = "whisper")] -pub mod whisper; diff --git a/plugins/wasi_nn_burnrs/src/models/squeezenet.rs b/plugins/wasi_nn_burnrs/src/models/squeezenet.rs deleted file mode 100644 index c4c7d0d1..00000000 --- a/plugins/wasi_nn_burnrs/src/models/squeezenet.rs +++ /dev/null @@ -1,45 +0,0 @@ -use burn::tensor::backend::Backend; -use burn::tensor::Tensor; -use squeezenet_burn::model::squeezenet1::Model; -use std::collections::HashMap; - -pub struct GraphInner { - pub model: Model, -} - -impl GraphInner { - pub fn create(args: Vec<&str>, device: &B::Device) -> Self { - let weights_path = args[0]; - Self { - model: Model::from_file(weights_path, device), - } - } - pub fn compute(&self, input: Tensor) -> Tensor { - self.model.forward(input) - } -} - -pub const INPUT_DIM: usize = 4; -pub const OUTPUT_DIM: usize = 2; - -pub struct ContextInner { - pub inputs: HashMap>, - pub outputs: Vec>, -} - -impl ContextInner { - pub fn new() -> Self { - Self { - inputs: HashMap::new(), - outputs: Vec::new(), - } - } - pub fn set_input(&mut self, key: u32, input: &[B::FloatElem], dimens: [usize; INPUT_DIM]) { - let device = Default::default(); - let tensor = Tensor::::from_data(&*input, &device).reshape(dimens); - self.inputs.insert(key, tensor); - } - pub fn get_output(&mut self, key: usize) -> Vec<::FloatElem> { - self.outputs[key].clone().into_data().value - } -} diff --git a/plugins/wasi_nn_burnrs/src/models/whisper.rs b/plugins/wasi_nn_burnrs/src/models/whisper.rs deleted file mode 100644 index 987f77ed..00000000 --- a/plugins/wasi_nn_burnrs/src/models/whisper.rs +++ /dev/null @@ -1,96 +0,0 @@ -use burn::config::Config; -use burn::module::Module; -use burn::record::{DefaultRecorder, Recorder}; -use burn::tensor::backend::Backend; -use std::collections::HashMap; -use std::marker::PhantomData; -use std::process; -use strum::IntoEnumIterator; -use whisper_burn::model::Whisper as Model; -use whisper_burn::model::WhisperConfig as ModelConfig; -use whisper_burn::token::{Gpt2Tokenizer, Language}; -use whisper_burn::transcribe::waveform_to_text; - -pub struct GraphInner { - pub model: Model, - pub metadata: Vec, -} - -impl GraphInner { - pub fn create(args: Vec<&str>, device: &B::Device) -> Self { - if args.len() < 4 { - eprintln!( - "[WASI_NN] Invalid nn-preload model format. {:?} len < 4", - args - ); - process::exit(1); - } - let weights_path = args[0]; - let config_path = args[1]; - let config = match ModelConfig::load(config_path) { - Ok(config) => config, - Err(e) => { - eprintln!("Failed to load whisper config: {}", e); - process::exit(1); - } - }; - let recorder = DefaultRecorder::new().load(weights_path.into(), device); - let model = recorder - .map(|record| config.init(device).load_record(record)) - .unwrap(); - Self { - model: model, - metadata: args[2..].iter().map(|&s| s.to_string()).collect(), - } - } - pub fn compute(&self, input: Vec) -> Vec { - let tokenizer_path = &self.metadata[0].to_string(); - let lang_str = &self.metadata[1].to_string(); - let lang = match Language::iter().find(|lang| lang.as_str() == lang_str) { - Some(lang) => lang, - None => { - eprintln!("Invalid language abbreviation: {}", lang_str); - process::exit(1); - } - }; - let bpe = match Gpt2Tokenizer::new_with_path(tokenizer_path) { - Ok(bpe) => bpe, - Err(e) => { - eprintln!("Failed to load tokenizer: {}", e); - process::exit(1); - } - }; - let (text, _) = match waveform_to_text(&self.model, &bpe, lang, input, 16000) { - Ok((text, tokens)) => (text, tokens), - Err(e) => { - eprintln!("Error during transcription: {}", e); - process::exit(1); - } - }; - return text.into_bytes(); - } -} - -pub const INPUT_DIM: usize = 2; - -pub struct ContextInner { - pub inputs: HashMap>, - pub outputs: Vec>, - _marker: PhantomData, -} - -impl ContextInner { - pub fn new() -> Self { - Self { - inputs: HashMap::new(), - outputs: Vec::new(), - _marker: PhantomData, - } - } - pub fn set_input(&mut self, key: u32, input: &[f32], _: [usize; INPUT_DIM]) { - self.inputs.insert(key, input.to_vec()); - } - pub fn get_output(&mut self, key: usize) -> &Vec { - &self.outputs[key] - } -} From fb538867b233f4d7630f61906c7d45d81e4e4121 Mon Sep 17 00:00:00 2001 From: PeterD1524 <53310459+PeterD1524@users.noreply.github.com> Date: Wed, 16 Oct 2024 04:07:07 +0800 Subject: [PATCH 473/623] [WASI-NN] piper: extend the json_input functionality to allow setting various parameters at runtime (#3825) The added parameters are output_type, noise_scale, length_scale, noise_w, sentence_silence, phoneme_silence. Signed-off-by: PeterD1524 --- plugins/wasi_nn/piper.cpp | 301 +++++++++++++++++-------------- plugins/wasi_nn/piper.h | 30 +-- test/plugins/wasi_nn/wasi_nn.cpp | 134 ++++++++++++++ 3 files changed, 319 insertions(+), 146 deletions(-) diff --git a/plugins/wasi_nn/piper.cpp b/plugins/wasi_nn/piper.cpp index fdc5c262..14f7752f 100644 --- a/plugins/wasi_nn/piper.cpp +++ b/plugins/wasi_nn/piper.cpp @@ -9,7 +9,6 @@ #include "types.h" #include #include -#include #include #include #include @@ -52,60 +51,9 @@ WASINN::ErrNo getOptionalOption(simdjson::dom::object &Object, return Err; } -Expect parseRunConfig(RunConfig &RunConfig, - const std::string &String) noexcept { - simdjson::dom::parser Parser; - simdjson::dom::element Doc; - if (auto Error = Parser.parse(String).get(Doc)) { - spdlog::error("[WASI-NN] Piper backend: Parse run config error: {}"sv, - simdjson::error_message(Error)); - return WASINN::ErrNo::InvalidEncoding; - } - simdjson::dom::object Object; - if (auto Error = Doc.get(Object)) { - spdlog::error( - "[WASI-NN] Piper backend: The run config is not an object: {}"sv, - simdjson::error_message(Error)); - return WASINN::ErrNo::InvalidArgument; - } - - auto ModelPath = std::optional{}; - if (auto Err = getOptionalOption(Object, "model", ModelPath); - Err != WASINN::ErrNo::Success) { - return Err; - } - // Verify model file exists - if (ModelPath) { - auto Path = std::filesystem::u8path(ModelPath.value()); - if (!std::filesystem::exists(Path)) { - spdlog::error("[WASI-NN] Piper backend: Model file doesn't exist"sv); - return WASINN::ErrNo::InvalidArgument; - } - RunConfig.ModelPath = Path; - } else { - spdlog::error( - "[WASI-NN] Piper backend: The model option is required but not provided"sv); - return WASINN::ErrNo::InvalidArgument; - } - - auto ModelConfigPath = std::optional{}; - if (auto Err = getOptionalOption(Object, "config", ModelConfigPath); - Err != WASINN::ErrNo::Success) { - return Err; - } - if (ModelConfigPath) { - RunConfig.ModelConfigPath = - std::filesystem::u8path(ModelConfigPath.value()); - } else { - RunConfig.ModelConfigPath = RunConfig.ModelPath; - RunConfig.ModelConfigPath += ".json"; - } - // Verify model config exists - if (!std::filesystem::exists(RunConfig.ModelConfigPath)) { - spdlog::error("[WASI-NN] Piper backend: Model config doesn't exist"sv); - return WASINN::ErrNo::InvalidArgument; - } - +WASINN::ErrNo parseSynthesisConfig(SynthesisConfig &SynthesisConfig, + simdjson::dom::object &Object, + const bool JsonInput) { { auto Value = std::optional{}; if (auto Err = getOptionalOption(Object, "output_type", Value); @@ -114,9 +62,9 @@ Expect parseRunConfig(RunConfig &RunConfig, } if (Value) { if (Value.value() == "wav") { - RunConfig.OutputType = RunConfigOutputType::OUTPUT_WAV; + SynthesisConfig.OutputType = SynthesisConfigOutputType::OUTPUT_WAV; } else if (Value.value() == "raw") { - RunConfig.OutputType = RunConfigOutputType::OUTPUT_RAW; + SynthesisConfig.OutputType = SynthesisConfigOutputType::OUTPUT_RAW; } else { spdlog::error( "[WASI-NN] Piper backend: The output_type option has an unknown value {}."sv, @@ -125,27 +73,36 @@ Expect parseRunConfig(RunConfig &RunConfig, } } } - if (auto Err = getOptionalOption(Object, "speaker", RunConfig.SpeakerId); - Err != WASINN::ErrNo::Success) { - return Err; + if (JsonInput) { + if (auto Err = + getOptionalOption(Object, "speaker_id", SynthesisConfig.SpeakerId); + Err != WASINN::ErrNo::Success) { + return Err; + } + } else { + if (auto Err = + getOptionalOption(Object, "speaker", SynthesisConfig.SpeakerId); + Err != WASINN::ErrNo::Success) { + return Err; + } } if (auto Err = getOptionalOption(Object, "noise_scale", - RunConfig.NoiseScale); + SynthesisConfig.NoiseScale); Err != WASINN::ErrNo::Success) { return Err; } if (auto Err = getOptionalOption(Object, "length_scale", - RunConfig.LengthScale); + SynthesisConfig.LengthScale); Err != WASINN::ErrNo::Success) { return Err; } - if (auto Err = - getOptionalOption(Object, "noise_w", RunConfig.NoiseW); + if (auto Err = getOptionalOption(Object, "noise_w", + SynthesisConfig.NoiseW); Err != WASINN::ErrNo::Success) { return Err; } if (auto Err = getOptionalOption( - Object, "sentence_silence", RunConfig.SentenceSilenceSeconds); + Object, "sentence_silence", SynthesisConfig.SentenceSilenceSeconds); Err != WASINN::ErrNo::Success) { return Err; } @@ -171,14 +128,77 @@ Expect parseRunConfig(RunConfig &RunConfig, PhonemeStr, simdjson::error_message(Error)); return WASINN::ErrNo::InvalidArgument; } - if (!RunConfig.PhonemeSilenceSeconds) { - RunConfig.PhonemeSilenceSeconds.emplace(); + if (!SynthesisConfig.PhonemeSilenceSeconds) { + SynthesisConfig.PhonemeSilenceSeconds.emplace(); } auto Phoneme = piper::getCodepoint(PhonemeStr); - RunConfig.PhonemeSilenceSeconds.value()[Phoneme] = Seconds.value(); + SynthesisConfig.PhonemeSilenceSeconds.value()[Phoneme] = + Seconds.value(); } } } + return WASINN::ErrNo::Success; +} + +WASINN::ErrNo parseRunConfig(RunConfig &RunConfig, + const std::string &String) noexcept { + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + if (auto Error = Parser.parse(String).get(Doc)) { + spdlog::error("[WASI-NN] Piper backend: Parse run config error: {}"sv, + simdjson::error_message(Error)); + return WASINN::ErrNo::InvalidEncoding; + } + simdjson::dom::object Object; + if (auto Error = Doc.get(Object)) { + spdlog::error( + "[WASI-NN] Piper backend: The run config is not an object: {}"sv, + simdjson::error_message(Error)); + return WASINN::ErrNo::InvalidArgument; + } + + auto ModelPath = std::optional{}; + if (auto Err = getOptionalOption(Object, "model", ModelPath); + Err != WASINN::ErrNo::Success) { + return Err; + } + // Verify model file exists + if (ModelPath) { + auto Path = std::filesystem::u8path(ModelPath.value()); + if (!std::filesystem::exists(Path)) { + spdlog::error("[WASI-NN] Piper backend: Model file doesn't exist"sv); + return WASINN::ErrNo::InvalidArgument; + } + RunConfig.ModelPath = Path; + } else { + spdlog::error( + "[WASI-NN] Piper backend: The model option is required but not provided"sv); + return WASINN::ErrNo::InvalidArgument; + } + + auto ModelConfigPath = std::optional{}; + if (auto Err = getOptionalOption(Object, "config", ModelConfigPath); + Err != WASINN::ErrNo::Success) { + return Err; + } + if (ModelConfigPath) { + RunConfig.ModelConfigPath = + std::filesystem::u8path(ModelConfigPath.value()); + } else { + RunConfig.ModelConfigPath = RunConfig.ModelPath; + RunConfig.ModelConfigPath += ".json"; + } + // Verify model config exists + if (!std::filesystem::exists(RunConfig.ModelConfigPath)) { + spdlog::error("[WASI-NN] Piper backend: Model config doesn't exist"sv); + return WASINN::ErrNo::InvalidArgument; + } + + if (auto Err = + parseSynthesisConfig(RunConfig.DefaultSynthesisConfig, Object, false); + Err != WASINN::ErrNo::Success) { + return Err; + } { auto Path = std::optional{}; if (auto Err = getOptionalOption(Object, "espeak_data", Path); @@ -207,6 +227,41 @@ Expect parseRunConfig(RunConfig &RunConfig, return WASINN::ErrNo::Success; } +void updateSynthesisConfig(SynthesisConfig &SynthesisConfig, + piper::SynthesisConfig &PiperSynthesisConfig, + const bool ForceOverwritePhonemeSilenceSeconds) { + if (SynthesisConfig.NoiseScale) { + PiperSynthesisConfig.noiseScale = SynthesisConfig.NoiseScale.value(); + } + if (SynthesisConfig.LengthScale) { + PiperSynthesisConfig.lengthScale = SynthesisConfig.LengthScale.value(); + } + if (SynthesisConfig.NoiseW) { + PiperSynthesisConfig.noiseW = SynthesisConfig.NoiseW.value(); + } + if (SynthesisConfig.SentenceSilenceSeconds) { + PiperSynthesisConfig.sentenceSilenceSeconds = + SynthesisConfig.SentenceSilenceSeconds.value(); + } + if (ForceOverwritePhonemeSilenceSeconds) { + PiperSynthesisConfig.phonemeSilenceSeconds = + SynthesisConfig.PhonemeSilenceSeconds; + } else if (SynthesisConfig.PhonemeSilenceSeconds) { + if (!PiperSynthesisConfig.phonemeSilenceSeconds) { + // Overwrite + PiperSynthesisConfig.phonemeSilenceSeconds = + SynthesisConfig.PhonemeSilenceSeconds; + } else { + // Merge + for (const auto &[Phoneme, SilenceSeconds] : + *SynthesisConfig.PhonemeSilenceSeconds) { + PiperSynthesisConfig.phonemeSilenceSeconds->try_emplace(Phoneme, + SilenceSeconds); + } + } + } +} + Expect load(WASINN::WasiNNEnvironment &Env, Span> Builders, WASINN::Device, uint32_t &GraphId) noexcept { @@ -233,8 +288,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.Voice = std::make_unique(); piper::loadVoice(*GraphRef.PiperConfig, GraphRef.Config->ModelPath.string(), GraphRef.Config->ModelConfigPath.string(), *GraphRef.Voice, - GraphRef.Config->SpeakerId); - GraphRef.SpeakerId = GraphRef.Config->SpeakerId; + GraphRef.Config->DefaultSynthesisConfig.SpeakerId); if (GraphRef.Voice->phonemizeConfig.phonemeType == piper::PhonemeType::eSpeakPhonemes) { @@ -280,40 +334,20 @@ Expect load(WASINN::WasiNNEnvironment &Env, piper::initialize(*GraphRef.PiperConfig); - // Scales - if (GraphRef.Config->NoiseScale) { - GraphRef.Voice->synthesisConfig.noiseScale = - GraphRef.Config->NoiseScale.value(); - } - - if (GraphRef.Config->LengthScale) { - GraphRef.Voice->synthesisConfig.lengthScale = - GraphRef.Config->LengthScale.value(); - } - - if (GraphRef.Config->NoiseW) { - GraphRef.Voice->synthesisConfig.noiseW = GraphRef.Config->NoiseW.value(); - } - - if (GraphRef.Config->SentenceSilenceSeconds) { - GraphRef.Voice->synthesisConfig.sentenceSilenceSeconds = - GraphRef.Config->SentenceSilenceSeconds.value(); - } - - if (GraphRef.Config->PhonemeSilenceSeconds) { - if (!GraphRef.Voice->synthesisConfig.phonemeSilenceSeconds) { - // Overwrite - GraphRef.Voice->synthesisConfig.phonemeSilenceSeconds = - GraphRef.Config->PhonemeSilenceSeconds; - } else { - // Merge - for (const auto &[Phoneme, SilenceSeconds] : - *GraphRef.Config->PhonemeSilenceSeconds) { - GraphRef.Voice->synthesisConfig.phonemeSilenceSeconds->try_emplace( - Phoneme, SilenceSeconds); - } - } - } // if phonemeSilenceSeconds + // Update the default config + updateSynthesisConfig(GraphRef.Config->DefaultSynthesisConfig, + GraphRef.Voice->synthesisConfig, false); + // Copy back the result + GraphRef.Config->DefaultSynthesisConfig.NoiseScale = + GraphRef.Voice->synthesisConfig.noiseScale; + GraphRef.Config->DefaultSynthesisConfig.LengthScale = + GraphRef.Voice->synthesisConfig.lengthScale; + GraphRef.Config->DefaultSynthesisConfig.NoiseW = + GraphRef.Voice->synthesisConfig.noiseW; + GraphRef.Config->DefaultSynthesisConfig.SentenceSilenceSeconds = + GraphRef.Voice->synthesisConfig.sentenceSilenceSeconds; + GraphRef.Config->DefaultSynthesisConfig.PhonemeSilenceSeconds = + GraphRef.Voice->synthesisConfig.phonemeSilenceSeconds; // Store the loaded graph. GraphId = Env.NNGraph.size() - 1; @@ -329,24 +363,6 @@ Expect initExecCtx(WASINN::WasiNNEnvironment &Env, return WASINN::ErrNo::Success; } -template -WASINN::ErrNo getOptionalInputOption(simdjson::dom::object &Object, - std::string_view Key, - std::optional &Result) { - auto Value = T{}; - if (auto Error = Object[Key].get(Value)) { - if (Error == simdjson::error_code::NO_SUCH_FIELD) { - return WASINN::ErrNo::Success; - } - spdlog::error( - "[WASI-NN] Piper backend: Unable to retrieve \"{}\" from json input: {}"sv, - Key, simdjson::error_message(Error)); - return WASINN::ErrNo::InvalidArgument; - } - Result = Value; - return WASINN::ErrNo::Success; -} - Expect setInput(WASINN::WasiNNEnvironment &Env, uint32_t ContextId, uint32_t Index, const TensorData &Tensor) noexcept { @@ -391,17 +407,15 @@ Expect setInput(WASINN::WasiNNEnvironment &Env, } Line = Text; - // Override speaker id - auto SpeakerId = std::optional{}; - if (auto Err = getOptionalInputOption(Object, "speaker_id", SpeakerId); + // Parse override config + auto JsonInputSynthesisConfig = SynthesisConfig{}; + if (auto Err = parseSynthesisConfig(JsonInputSynthesisConfig, Object, true); Err != WASINN::ErrNo::Success) { return Err; } - if (SpeakerId) { - GraphRef.Voice->synthesisConfig.speakerId = SpeakerId; - } else { + if (!JsonInputSynthesisConfig.SpeakerId) { auto SpeakerName = std::optional{}; - if (auto Err = getOptionalInputOption(Object, "speaker", SpeakerName); + if (auto Err = getOptionalOption(Object, "speaker", SpeakerName); Err != WASINN::ErrNo::Success) { return Err; } @@ -410,13 +424,18 @@ Expect setInput(WASINN::WasiNNEnvironment &Env, auto Name = std::string{SpeakerName.value()}; if (GraphRef.Voice->modelConfig.speakerIdMap && GraphRef.Voice->modelConfig.speakerIdMap->count(Name) > 0) { - GraphRef.Voice->synthesisConfig.speakerId = + JsonInputSynthesisConfig.SpeakerId = GraphRef.Voice->modelConfig.speakerIdMap.value()[Name]; } else { spdlog::warn("[WASI-NN] Piper backend: No speaker named: {}"sv, Name); } } } + if (!CxtRef.JsonInputSynthesisConfig) { + CxtRef.JsonInputSynthesisConfig = + std::make_unique>(); + } + *CxtRef.JsonInputSynthesisConfig = JsonInputSynthesisConfig; } CxtRef.Line = Line; return WASINN::ErrNo::Success; @@ -467,15 +486,30 @@ Expect compute(WASINN::WasiNNEnvironment &Env, return WASINN::ErrNo::InvalidArgument; } + auto OutputType = SynthesisConfigOutputType::OUTPUT_WAV; + if (GraphRef.Config->DefaultSynthesisConfig.OutputType) { + OutputType = GraphRef.Config->DefaultSynthesisConfig.OutputType.value(); + } + + // Override config + if (CxtRef.JsonInputSynthesisConfig && + CxtRef.JsonInputSynthesisConfig->has_value()) { + updateSynthesisConfig(CxtRef.JsonInputSynthesisConfig->value(), + GraphRef.Voice->synthesisConfig, false); + if (CxtRef.JsonInputSynthesisConfig->value().OutputType) { + OutputType = CxtRef.JsonInputSynthesisConfig->value().OutputType.value(); + } + } + auto Result = piper::SynthesisResult{}; - if (GraphRef.Config->OutputType == RunConfigOutputType::OUTPUT_WAV) { + if (OutputType == SynthesisConfigOutputType::OUTPUT_WAV) { auto AudioFile = std::stringstream{std::ios::binary | std::ios::in | std::ios::out}; piper::textToWavFile(*GraphRef.PiperConfig, *GraphRef.Voice, CxtRef.Line.value(), AudioFile, Result); auto String = AudioFile.str(); CxtRef.Output = std::vector{String.begin(), String.end()}; - } else if (GraphRef.Config->OutputType == RunConfigOutputType::OUTPUT_RAW) { + } else if (OutputType == SynthesisConfigOutputType::OUTPUT_RAW) { auto AudioBuffer = std::vector{}; piper::textToAudio(*GraphRef.PiperConfig, *GraphRef.Voice, CxtRef.Line.value(), AudioBuffer, Result, nullptr); @@ -486,7 +520,8 @@ Expect compute(WASINN::WasiNNEnvironment &Env, } // Restore config (json_input) - GraphRef.Voice->synthesisConfig.speakerId = GraphRef.SpeakerId; + updateSynthesisConfig(GraphRef.Config->DefaultSynthesisConfig, + GraphRef.Voice->synthesisConfig, true); return WASINN::ErrNo::Success; } #else diff --git a/plugins/wasi_nn/piper.h b/plugins/wasi_nn/piper.h index 5ff6c43c..3a24f17a 100644 --- a/plugins/wasi_nn/piper.h +++ b/plugins/wasi_nn/piper.h @@ -21,17 +21,11 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::Piper { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER -enum class RunConfigOutputType { OUTPUT_WAV, OUTPUT_RAW }; -struct RunConfig { - // Path to .onnx voice file - std::filesystem::path ModelPath; - - // Path to JSON voice config file - std::filesystem::path ModelConfigPath; - +enum class SynthesisConfigOutputType { OUTPUT_WAV, OUTPUT_RAW }; +struct SynthesisConfig { // Type of output to produce. // Default is a WAV file. - RunConfigOutputType OutputType = RunConfigOutputType::OUTPUT_WAV; + std::optional OutputType; // Numerical id of the default speaker (multi-speaker voices) std::optional SpeakerId; @@ -48,6 +42,16 @@ struct RunConfig { // Seconds of silence to add after each sentence std::optional SentenceSilenceSeconds; + // Seconds of extra silence to insert after a single phoneme + std::optional> PhonemeSilenceSeconds; +}; +struct RunConfig { + // Path to .onnx voice file + std::filesystem::path ModelPath; + + // Path to JSON voice config file + std::filesystem::path ModelConfigPath; + // Path to espeak-ng data directory std::optional ESpeakDataPath; @@ -55,27 +59,27 @@ struct RunConfig { // https://github.com/mush42/libtashkeel/ std::optional TashkeelModelPath; - // input is JSON instead of text with format: + // input is JSON with format: // { // "text": str, (required) // "speaker_id": int, (optional) // "speaker": str, (optional) // } + // including options in SynthesisConfig bool JsonInput = false; - // Seconds of extra silence to insert after a single phoneme - std::optional> PhonemeSilenceSeconds; + SynthesisConfig DefaultSynthesisConfig; }; struct Graph { std::unique_ptr Config; std::unique_ptr PiperConfig; std::unique_ptr Voice; - std::optional SpeakerId; }; struct Context { Context(size_t GId, Graph &) noexcept : GraphId(GId) {} size_t GraphId; std::optional Line; + std::unique_ptr> JsonInputSynthesisConfig; std::optional> Output; }; #else diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 01ca1d53..aa8be40c 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -2455,6 +2455,140 @@ TEST(WasiNNTest, PiperBackend) { auto BytesWritten = *MemInst.getPointer(BuilderPtr); EXPECT_GE(BytesWritten, 10000); } + + // Piper json_input tests. + // Test: load -- load successfully. + Config = + "{\"model\": " + "\"./wasinn_piper_fixtures/test_voice.onnx\",\"espeak_data\": " + "\"./wasinn_piper_fixtures/piper/espeak-ng-data\",\"json_input\":true}"; + ConfigData = {Config.begin(), Config.end()}; + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, ConfigData.size(), BuilderPtr); + writeBinaries(MemInst, ConfigData, StorePtr); + { + EXPECT_TRUE(HostFuncLoad.run(CallFrame, + std::initializer_list{ + LoadEntryPtr, UINT32_C(1), + static_cast(Backend::Piper), + UINT32_C(0), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // Test: init_execution_context -- init context successfully. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{UINT32_C(1), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(*MemInst.getPointer(BuilderPtr), 1); + BuilderPtr += 4; + } + + // First json input with parameters overridden + Text = "{\"text\": \"This is a test.\", \"noise_scale\": 0.0, " + "\"length_scale\": 2.0, \"noise_w\": 0.0, \"sentence_silence\": 1.0, " + "\"phoneme_silence\": {\"t\": 0.0}}"; + TensorData = {Text.begin(), Text.end()}; + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, 2, BuilderPtr); + writeFatPointer(MemInst, + StorePtr + TensorDim.size() * + sizeof(decltype(TensorDim)::value_type), + TensorData.size(), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries( + MemInst, TensorData, + StorePtr + TensorDim.size() * sizeof(decltype(TensorDim)::value_type)); + + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += TensorDim.size() * sizeof(decltype(TensorDim)::value_type) + + TensorData.size(); + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(1)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + // Should output more than 50000 bytes. + EXPECT_GE(BytesWritten, 50000); + } + + // Second json input to check if one-time overriding is working properly + Text = "{\"text\": \"This is a test.\", \"output_type\": \"raw\", " + "\"noise_scale\": 0.0, \"noise_w\": 0.0}"; + TensorData = {Text.begin(), Text.end()}; + SetInputEntryPtr = BuilderPtr; + writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); + writeUInt32(MemInst, 2, BuilderPtr); + writeFatPointer(MemInst, + StorePtr + TensorDim.size() * + sizeof(decltype(TensorDim)::value_type), + TensorData.size(), BuilderPtr); + writeBinaries(MemInst, TensorDim, StorePtr); + writeBinaries( + MemInst, TensorData, + StorePtr + TensorDim.size() * sizeof(decltype(TensorDim)::value_type)); + + // Test: set_input -- set input successfully. + { + EXPECT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + StorePtr += TensorDim.size() * sizeof(decltype(TensorDim)::value_type) + + TensorData.size(); + + // Test: compute -- compute successfully. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{UINT32_C(1)}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: get_output -- get output successfully. + { + EXPECT_TRUE(HostFuncGetOutput.run( + CallFrame, + std::initializer_list{ + UINT32_C(1), UINT32_C(0), StorePtr, 65532, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + auto BytesWritten = *MemInst.getPointer(BuilderPtr); + EXPECT_GE(BytesWritten, 30000); + // Should output less than 40000 bytes. + EXPECT_LT(BytesWritten, 40000); + EXPECT_EQ(BytesWritten, 34048); + } } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER From a29eb1d3245acb1116e4da3b7c4287b6d4c8527d Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Wed, 16 Oct 2024 16:20:37 +0800 Subject: [PATCH 474/623] [Test] wasmedge_zlib: Fix -Wformat-truncation error: 'snprintf' will always be truncated; specified size is 3, but format string expands to at least 4 [-Werror,-Wformat-truncation] Signed-off-by: Yi Huang --- test/plugins/wasmedge_zlib/wasmedge_zlib.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp index 46b3cb78..b777abf2 100644 --- a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp +++ b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp @@ -120,8 +120,8 @@ TEST(WasmEdgeZlibTest, DeflateInflateCycle) { std::array RetVal; WasmZlibVersion = WasmHP; - std::snprintf(MemInst.getPointer(WasmHP), std::strlen(ZLIB_VERSION), - ZLIB_VERSION); + std::snprintf(MemInst.getPointer(WasmHP), + std::strlen(ZLIB_VERSION) + 1, ZLIB_VERSION); WasmHP += std::strlen(ZLIB_VERSION); WasmData = WasmHP; From 201edaf89e102374a92d6b7d35d1b8f9a8af1efa Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Wed, 16 Oct 2024 16:43:10 +0800 Subject: [PATCH 475/623] [Plugins] wasm_bpf: Fix error of poisoned identifier See https://github.com/libbpf/libbpf/commit/950cffc0366981d4e41b08f007b37bd6af931f25 Signed-off-by: Yi Huang --- plugins/wasm_bpf/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasm_bpf/CMakeLists.txt b/plugins/wasm_bpf/CMakeLists.txt index 924b898a..4926df6d 100644 --- a/plugins/wasm_bpf/CMakeLists.txt +++ b/plugins/wasm_bpf/CMakeLists.txt @@ -92,7 +92,7 @@ if(NOT ${LIBBPF_FOUND}) FetchContent_Declare( libbpf GIT_REPOSITORY https://github.com/libbpf/libbpf - GIT_TAG cf46d44f0a06aa8b9400691ea3eb86ca4f066d5c + GIT_TAG 950cffc0366981d4e41b08f007b37bd6af931f25 ) FetchContent_GetProperties(libbpf) From 5de0eb4594febf11ea97823ab85cb69f5a6084a9 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 16 Oct 2024 22:49:42 +0800 Subject: [PATCH 476/623] [CMake] Add option for disable cxx11-abi and turn off cxx11-abi on manylinux. Signed-off-by: YiYing He --- plugins/wasmedge_ocr/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasmedge_ocr/CMakeLists.txt b/plugins/wasmedge_ocr/CMakeLists.txt index a540eb06..4079fce0 100644 --- a/plugins/wasmedge_ocr/CMakeLists.txt +++ b/plugins/wasmedge_ocr/CMakeLists.txt @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2024 Second State INC -add_library(wasmedgePluginWasmEdgeOCR +wasmedge_add_library(wasmedgePluginWasmEdgeOCR SHARED ocr_env.cpp ocr_func.cpp From 6d73d5d8dc1d5caebdb2250192f986f57e67a06d Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 9 Oct 2024 02:20:06 +0800 Subject: [PATCH 477/623] [WASI-NN] Rename the source names to avoid conflict header names from dependency. Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 22 +++++++-------- plugins/wasi_nn/MLX/mlx/activations.cpp | 2 +- plugins/wasi_nn/MLX/mlx/activations.h | 2 +- plugins/wasi_nn/MLX/mlx/base.cpp | 4 +-- plugins/wasi_nn/MLX/mlx/base.h | 4 +-- plugins/wasi_nn/MLX/mlx/embedding.cpp | 4 +-- plugins/wasi_nn/MLX/mlx/embedding.h | 2 +- plugins/wasi_nn/MLX/mlx/linear.cpp | 4 +-- plugins/wasi_nn/MLX/mlx/linear.h | 4 +-- plugins/wasi_nn/MLX/mlx/normalization.cpp | 2 +- plugins/wasi_nn/MLX/mlx/normalization.h | 2 +- .../wasi_nn/MLX/mlx/positional_encoding.cpp | 2 +- plugins/wasi_nn/MLX/mlx/positional_encoding.h | 2 +- plugins/wasi_nn/MLX/mlx/quantized.cpp | 2 +- plugins/wasi_nn/MLX/mlx/quantized.h | 6 ++--- plugins/wasi_nn/MLX/mlx/transformer.cpp | 2 +- plugins/wasi_nn/MLX/mlx/transformer.h | 4 +-- plugins/wasi_nn/MLX/model/converter.cpp | 4 +-- plugins/wasi_nn/MLX/model/converter.h | 4 +-- plugins/wasi_nn/MLX/model/registry.cpp | 4 +-- plugins/wasi_nn/MLX/model/registry.h | 2 +- plugins/wasi_nn/MLX/model/transformer.cpp | 10 +++---- plugins/wasi_nn/MLX/model/transformer.h | 12 ++++----- plugins/wasi_nn/MLX/model/utils.cpp | 2 +- plugins/wasi_nn/MLX/model/utils.h | 2 +- plugins/wasi_nn/MLX/prompt/prompt.cpp | 2 +- .../{chattts.cpp => wasinn_chattts.cpp} | 2 +- .../wasi_nn/{chattts.h => wasinn_chattts.h} | 4 +-- plugins/wasi_nn/{ggml.cpp => wasinn_ggml.cpp} | 7 ++--- plugins/wasi_nn/{ggml.h => wasinn_ggml.h} | 3 ++- plugins/wasi_nn/{mlx.cpp => wasinn_mlx.cpp} | 20 +++++++++----- plugins/wasi_nn/{mlx.h => wasinn_mlx.h} | 14 +++++++--- ...neuralspeed.cpp => wasinn_neuralspeed.cpp} | 6 ++--- .../{neuralspeed.h => wasinn_neuralspeed.h} | 3 ++- plugins/wasi_nn/{onnx.cpp => wasinn_onnx.cpp} | 2 +- plugins/wasi_nn/{onnx.h => wasinn_onnx.h} | 3 ++- .../{openvino.cpp => wasinn_openvino.cpp} | 3 ++- .../wasi_nn/{openvino.h => wasinn_openvino.h} | 3 ++- .../wasi_nn/{piper.cpp => wasinn_piper.cpp} | 5 ++-- plugins/wasi_nn/{piper.h => wasinn_piper.h} | 6 +++-- plugins/wasi_nn/{tf.cpp => wasinn_tf.cpp} | 2 +- plugins/wasi_nn/{tf.h => wasinn_tf.h} | 3 ++- plugins/wasi_nn/{tfl.cpp => wasinn_tfl.cpp} | 2 +- plugins/wasi_nn/{tfl.h => wasinn_tfl.h} | 3 ++- .../wasi_nn/{torch.cpp => wasinn_torch.cpp} | 2 +- plugins/wasi_nn/{torch.h => wasinn_torch.h} | 3 ++- .../{whispercpp.cpp => wasinn_whisper.cpp} | 2 +- .../{whispercpp.h => wasinn_whisper.h} | 3 ++- plugins/wasi_nn/wasinnbase.h | 3 ++- plugins/wasi_nn/wasinnenv.cpp | 2 +- plugins/wasi_nn/wasinnenv.h | 27 ++++++++++--------- plugins/wasi_nn/wasinnfunc.cpp | 3 ++- plugins/wasi_nn/wasinnfunc.h | 3 ++- plugins/wasi_nn/wasinnmodule.h | 3 ++- plugins/wasi_nn/{types.h => wasinntypes.h} | 0 55 files changed, 142 insertions(+), 112 deletions(-) rename plugins/wasi_nn/{chattts.cpp => wasinn_chattts.cpp} (99%) rename plugins/wasi_nn/{chattts.h => wasinn_chattts.h} (98%) rename plugins/wasi_nn/{ggml.cpp => wasinn_ggml.cpp} (99%) rename plugins/wasi_nn/{ggml.h => wasinn_ggml.h} (99%) rename plugins/wasi_nn/{mlx.cpp => wasinn_mlx.cpp} (97%) rename plugins/wasi_nn/{mlx.h => wasinn_mlx.h} (89%) rename plugins/wasi_nn/{neuralspeed.cpp => wasinn_neuralspeed.cpp} (91%) rename plugins/wasi_nn/{neuralspeed.h => wasinn_neuralspeed.h} (98%) rename plugins/wasi_nn/{onnx.cpp => wasinn_onnx.cpp} (98%) rename plugins/wasi_nn/{onnx.h => wasinn_onnx.h} (98%) rename plugins/wasi_nn/{openvino.cpp => wasinn_openvino.cpp} (99%) rename plugins/wasi_nn/{openvino.h => wasinn_openvino.h} (98%) rename plugins/wasi_nn/{piper.cpp => wasinn_piper.cpp} (99%) rename plugins/wasi_nn/{piper.h => wasinn_piper.h} (99%) rename plugins/wasi_nn/{tf.cpp => wasinn_tf.cpp} (98%) rename plugins/wasi_nn/{tf.h => wasinn_tf.h} (98%) rename plugins/wasi_nn/{tfl.cpp => wasinn_tfl.cpp} (99%) rename plugins/wasi_nn/{tfl.h => wasinn_tfl.h} (98%) rename plugins/wasi_nn/{torch.cpp => wasinn_torch.cpp} (99%) rename plugins/wasi_nn/{torch.h => wasinn_torch.h} (98%) rename plugins/wasi_nn/{whispercpp.cpp => wasinn_whisper.cpp} (99%) rename plugins/wasi_nn/{whispercpp.h => wasinn_whisper.h} (99%) rename plugins/wasi_nn/{types.h => wasinntypes.h} (100%) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 5d714dec..eb34d036 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -6,17 +6,17 @@ wasmedge_add_library(wasmedgePluginWasiNN wasinnenv.cpp wasinnfunc.cpp wasinnmodule.cpp - openvino.cpp - onnx.cpp - tf.cpp - torch.cpp - tfl.cpp - ggml.cpp - neuralspeed.cpp - piper.cpp - whispercpp.cpp - chattts.cpp - mlx.cpp + wasinn_openvino.cpp + wasinn_onnx.cpp + wasinn_tf.cpp + wasinn_torch.cpp + wasinn_tfl.cpp + wasinn_ggml.cpp + wasinn_neuralspeed.cpp + wasinn_piper.cpp + wasinn_whisper.cpp + wasinn_chattts.cpp + wasinn_mlx.cpp ) foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) diff --git a/plugins/wasi_nn/MLX/mlx/activations.cpp b/plugins/wasi_nn/MLX/mlx/activations.cpp index 0b6fc13f..64b7202e 100644 --- a/plugins/wasi_nn/MLX/mlx/activations.cpp +++ b/plugins/wasi_nn/MLX/mlx/activations.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "activations.h" +#include "mlx/activations.h" #include diff --git a/plugins/wasi_nn/MLX/mlx/activations.h b/plugins/wasi_nn/MLX/mlx/activations.h index 039863e2..b6d75127 100644 --- a/plugins/wasi_nn/MLX/mlx/activations.h +++ b/plugins/wasi_nn/MLX/mlx/activations.h @@ -3,7 +3,7 @@ #pragma once -#include "base.h" +#include "mlx/base.h" #include diff --git a/plugins/wasi_nn/MLX/mlx/base.cpp b/plugins/wasi_nn/MLX/mlx/base.cpp index 0c7ac715..9afc8b6f 100644 --- a/plugins/wasi_nn/MLX/mlx/base.cpp +++ b/plugins/wasi_nn/MLX/mlx/base.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "base.h" -#include "../model/utils.h" +#include "mlx/base.h" +#include "model/utils.h" #include diff --git a/plugins/wasi_nn/MLX/mlx/base.h b/plugins/wasi_nn/MLX/mlx/base.h index 65bbff67..24ce0e42 100644 --- a/plugins/wasi_nn/MLX/mlx/base.h +++ b/plugins/wasi_nn/MLX/mlx/base.h @@ -3,10 +3,10 @@ #pragma once -#include "mlx/mlx.h" - #include "common/errcode.h" +#include + #include #include #include diff --git a/plugins/wasi_nn/MLX/mlx/embedding.cpp b/plugins/wasi_nn/MLX/mlx/embedding.cpp index e12264e6..8ceb7f57 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.cpp +++ b/plugins/wasi_nn/MLX/mlx/embedding.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "embedding.h" -#include "quantized.h" +#include "mlx/embedding.h" +#include "mlx/quantized.h" #include #include diff --git a/plugins/wasi_nn/MLX/mlx/embedding.h b/plugins/wasi_nn/MLX/mlx/embedding.h index 94eba689..778dd1bc 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.h +++ b/plugins/wasi_nn/MLX/mlx/embedding.h @@ -3,7 +3,7 @@ #pragma once -#include "base.h" +#include "mlx/base.h" #include #include diff --git a/plugins/wasi_nn/MLX/mlx/linear.cpp b/plugins/wasi_nn/MLX/mlx/linear.cpp index 1d5a0104..99f3ad3f 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.cpp +++ b/plugins/wasi_nn/MLX/mlx/linear.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "linear.h" -#include "quantized.h" +#include "mlx/linear.h" +#include "mlx/quantized.h" #include diff --git a/plugins/wasi_nn/MLX/mlx/linear.h b/plugins/wasi_nn/MLX/mlx/linear.h index 116317fc..5afa7a19 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.h +++ b/plugins/wasi_nn/MLX/mlx/linear.h @@ -3,9 +3,9 @@ #pragma once -#include "base.h" -#include "mlx/mlx.h" +#include "mlx/base.h" +#include #include #include diff --git a/plugins/wasi_nn/MLX/mlx/normalization.cpp b/plugins/wasi_nn/MLX/mlx/normalization.cpp index 919b2d96..501bd5f0 100644 --- a/plugins/wasi_nn/MLX/mlx/normalization.cpp +++ b/plugins/wasi_nn/MLX/mlx/normalization.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "normalization.h" +#include "mlx/normalization.h" namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { diff --git a/plugins/wasi_nn/MLX/mlx/normalization.h b/plugins/wasi_nn/MLX/mlx/normalization.h index 3a2b901c..a09ccddd 100644 --- a/plugins/wasi_nn/MLX/mlx/normalization.h +++ b/plugins/wasi_nn/MLX/mlx/normalization.h @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "base.h" +#include "mlx/base.h" namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { diff --git a/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp b/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp index a5be4b92..635bd990 100644 --- a/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp +++ b/plugins/wasi_nn/MLX/mlx/positional_encoding.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "positional_encoding.h" +#include "mlx/positional_encoding.h" namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { diff --git a/plugins/wasi_nn/MLX/mlx/positional_encoding.h b/plugins/wasi_nn/MLX/mlx/positional_encoding.h index 42abaff9..95d1b676 100644 --- a/plugins/wasi_nn/MLX/mlx/positional_encoding.h +++ b/plugins/wasi_nn/MLX/mlx/positional_encoding.h @@ -3,7 +3,7 @@ #pragma once -#include "base.h" +#include "mlx/base.h" #include #include diff --git a/plugins/wasi_nn/MLX/mlx/quantized.cpp b/plugins/wasi_nn/MLX/mlx/quantized.cpp index 6ed72678..dee36264 100644 --- a/plugins/wasi_nn/MLX/mlx/quantized.cpp +++ b/plugins/wasi_nn/MLX/mlx/quantized.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "quantized.h" +#include "mlx/quantized.h" #include #include diff --git a/plugins/wasi_nn/MLX/mlx/quantized.h b/plugins/wasi_nn/MLX/mlx/quantized.h index 631f32a9..3bbe67fe 100644 --- a/plugins/wasi_nn/MLX/mlx/quantized.h +++ b/plugins/wasi_nn/MLX/mlx/quantized.h @@ -3,9 +3,9 @@ #pragma once -#include "base.h" -#include "embedding.h" -#include "linear.h" +#include "mlx/base.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" #include #include diff --git a/plugins/wasi_nn/MLX/mlx/transformer.cpp b/plugins/wasi_nn/MLX/mlx/transformer.cpp index 6933dd5c..76f4d2b6 100644 --- a/plugins/wasi_nn/MLX/mlx/transformer.cpp +++ b/plugins/wasi_nn/MLX/mlx/transformer.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "transformer.h" +#include "mlx/transformer.h" #include diff --git a/plugins/wasi_nn/MLX/mlx/transformer.h b/plugins/wasi_nn/MLX/mlx/transformer.h index 84cb0b7d..fc2c8410 100644 --- a/plugins/wasi_nn/MLX/mlx/transformer.h +++ b/plugins/wasi_nn/MLX/mlx/transformer.h @@ -3,8 +3,8 @@ #pragma once -#include "base.h" -#include "linear.h" +#include "mlx/base.h" +#include "mlx/linear.h" #include diff --git a/plugins/wasi_nn/MLX/model/converter.cpp b/plugins/wasi_nn/MLX/model/converter.cpp index 21c56a60..4d20b36c 100644 --- a/plugins/wasi_nn/MLX/model/converter.cpp +++ b/plugins/wasi_nn/MLX/model/converter.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "converter.h" -#include "utils.h" +#include "model/converter.h" +#include "model/utils.h" #include #include diff --git a/plugins/wasi_nn/MLX/model/converter.h b/plugins/wasi_nn/MLX/model/converter.h index b4c4a438..bcf9a41b 100644 --- a/plugins/wasi_nn/MLX/model/converter.h +++ b/plugins/wasi_nn/MLX/model/converter.h @@ -3,9 +3,9 @@ #pragma once -#include "base.h" +#include "mlx/base.h" -#include "mlx/mlx.h" +#include #include #include diff --git a/plugins/wasi_nn/MLX/model/registry.cpp b/plugins/wasi_nn/MLX/model/registry.cpp index 7fa15b86..0c6f2757 100644 --- a/plugins/wasi_nn/MLX/model/registry.cpp +++ b/plugins/wasi_nn/MLX/model/registry.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "registry.h" -#include "transformer.h" +#include "model/registry.h" +#include "model/transformer.h" namespace WasmEdge::Host::WASINN::MLX { diff --git a/plugins/wasi_nn/MLX/model/registry.h b/plugins/wasi_nn/MLX/model/registry.h index 456b8aaa..bd1a908d 100644 --- a/plugins/wasi_nn/MLX/model/registry.h +++ b/plugins/wasi_nn/MLX/model/registry.h @@ -3,7 +3,7 @@ #pragma once -#include "transformer.h" +#include "model/transformer.h" namespace WasmEdge::Host::WASINN::MLX { diff --git a/plugins/wasi_nn/MLX/model/transformer.cpp b/plugins/wasi_nn/MLX/model/transformer.cpp index 14503a1d..683eb968 100644 --- a/plugins/wasi_nn/MLX/model/transformer.cpp +++ b/plugins/wasi_nn/MLX/model/transformer.cpp @@ -1,11 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "../mlx/transformer.h" -#include "base.h" -#include "embedding.h" -#include "linear.h" -#include "transformer.h" +#include "mlx/transformer.h" +#include "mlx/base.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "model/transformer.h" #include #include diff --git a/plugins/wasi_nn/MLX/model/transformer.h b/plugins/wasi_nn/MLX/model/transformer.h index 680c08a4..d51be4d1 100644 --- a/plugins/wasi_nn/MLX/model/transformer.h +++ b/plugins/wasi_nn/MLX/model/transformer.h @@ -3,12 +3,12 @@ #pragma once -#include "activations.h" -#include "base.h" -#include "embedding.h" -#include "linear.h" -#include "normalization.h" -#include "positional_encoding.h" +#include "mlx/activations.h" +#include "mlx/base.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "mlx/normalization.h" +#include "mlx/positional_encoding.h" #include #include diff --git a/plugins/wasi_nn/MLX/model/utils.cpp b/plugins/wasi_nn/MLX/model/utils.cpp index 451d9fc3..0674cf17 100644 --- a/plugins/wasi_nn/MLX/model/utils.cpp +++ b/plugins/wasi_nn/MLX/model/utils.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "utils.h" +#include "model/utils.h" #include diff --git a/plugins/wasi_nn/MLX/model/utils.h b/plugins/wasi_nn/MLX/model/utils.h index e0a86f4e..9d8118ca 100644 --- a/plugins/wasi_nn/MLX/model/utils.h +++ b/plugins/wasi_nn/MLX/model/utils.h @@ -3,7 +3,7 @@ #pragma once -#include "base.h" +#include "mlx/base.h" #include #include diff --git a/plugins/wasi_nn/MLX/prompt/prompt.cpp b/plugins/wasi_nn/MLX/prompt/prompt.cpp index 0b2d2ceb..7a944c29 100644 --- a/plugins/wasi_nn/MLX/prompt/prompt.cpp +++ b/plugins/wasi_nn/MLX/prompt/prompt.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "prompt.h" +#include "prompt/prompt.h" #include diff --git a/plugins/wasi_nn/chattts.cpp b/plugins/wasi_nn/wasinn_chattts.cpp similarity index 99% rename from plugins/wasi_nn/chattts.cpp rename to plugins/wasi_nn/wasinn_chattts.cpp index ecd4f99e..a889ce1d 100644 --- a/plugins/wasi_nn/chattts.cpp +++ b/plugins/wasi_nn/wasinn_chattts.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "chattts.h" +#include "wasinn_chattts.h" #include "wasinnenv.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS diff --git a/plugins/wasi_nn/chattts.h b/plugins/wasi_nn/wasinn_chattts.h similarity index 98% rename from plugins/wasi_nn/chattts.h rename to plugins/wasi_nn/wasinn_chattts.h index 46cab0a3..9b354a66 100644 --- a/plugins/wasi_nn/chattts.h +++ b/plugins/wasi_nn/wasinn_chattts.h @@ -3,10 +3,10 @@ #pragma once +#include "wasinntypes.h" + #include "plugin/plugin.h" -#include "types.h" -#include #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS #include #endif diff --git a/plugins/wasi_nn/ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp similarity index 99% rename from plugins/wasi_nn/ggml.cpp rename to plugins/wasi_nn/wasinn_ggml.cpp index 65c58e9e..dd5d3084 100644 --- a/plugins/wasi_nn/ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -1,22 +1,23 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "ggml.h" +#include "wasinn_ggml.h" #include "wasinnenv.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML #include "simdjson.h" -#include #include #include #include #include -#include #include #include #include #include #include + +#include +#include #include #endif diff --git a/plugins/wasi_nn/ggml.h b/plugins/wasi_nn/wasinn_ggml.h similarity index 99% rename from plugins/wasi_nn/ggml.h rename to plugins/wasi_nn/wasinn_ggml.h index b29dcd0b..7d2ad8ab 100644 --- a/plugins/wasi_nn/ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -3,8 +3,9 @@ #pragma once +#include "wasinntypes.h" + #include "plugin/plugin.h" -#include "types.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML #include diff --git a/plugins/wasi_nn/mlx.cpp b/plugins/wasi_nn/wasinn_mlx.cpp similarity index 97% rename from plugins/wasi_nn/mlx.cpp rename to plugins/wasi_nn/wasinn_mlx.cpp index 5e8406f4..7f4cb443 100644 --- a/plugins/wasi_nn/mlx.cpp +++ b/plugins/wasi_nn/wasinn_mlx.cpp @@ -1,13 +1,18 @@ -#include "mlx.h" +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_mlx.h" #include "wasinnenv.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX -#include "converter.h" -#include "prompt.h" -#include "registry.h" -#include "utils.h" +#include "MLX/model/converter.h" +#include "MLX/model/registry.h" +#include "MLX/model/utils.h" +#include "MLX/prompt/prompt.h" + #include #endif + namespace WasmEdge::Host::WASINN::MLX { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX std::string loadBytesFromFile(const std::string &Path) { @@ -24,11 +29,13 @@ std::string loadBytesFromFile(const std::string &Path) { Fs.read(Data.data(), Size); return Data; } + enum AnserSataus { STOP, WAIT, GO, }; + AnserSataus answerSataus(std::string Text, std::string End) { if (endsWith(Text, End)) { return STOP; @@ -40,6 +47,7 @@ AnserSataus answerSataus(std::string Text, std::string End) { } return GO; } + Expect load(WASINN::WasiNNEnvironment &Env, Span> Builders, WASINN::Device, uint32_t &GraphId) noexcept { @@ -346,4 +354,4 @@ Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { } #endif -} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/mlx.h b/plugins/wasi_nn/wasinn_mlx.h similarity index 89% rename from plugins/wasi_nn/mlx.h rename to plugins/wasi_nn/wasinn_mlx.h index af3e152e..f487af26 100644 --- a/plugins/wasi_nn/mlx.h +++ b/plugins/wasi_nn/wasinn_mlx.h @@ -1,13 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + #pragma once +#include "wasinntypes.h" + #include "plugin/plugin.h" -#include "types.h" + #include #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX +#include "MLX/mlx/transformer.h" #include "MLX/model/transformer.h" -#include "prompt.h" -#include "transformer.h" +#include "MLX/prompt/prompt.h" + #include #include #endif @@ -62,4 +68,4 @@ Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; Expect unload(WASINN::WasiNNEnvironment &Env, uint32_t GraphId) noexcept; -} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/neuralspeed.cpp b/plugins/wasi_nn/wasinn_neuralspeed.cpp similarity index 91% rename from plugins/wasi_nn/neuralspeed.cpp rename to plugins/wasi_nn/wasinn_neuralspeed.cpp index 0d29c59c..9937930d 100644 --- a/plugins/wasi_nn/neuralspeed.cpp +++ b/plugins/wasi_nn/wasinn_neuralspeed.cpp @@ -1,15 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "neuralspeed.h" +#include "wasinn_neuralspeed.h" #include "wasinnenv.h" namespace WasmEdge::Host::WASINN::NeuralSpeed { namespace { Expect reportBackendNotSupported() noexcept { spdlog::error("[WASI-NN] Neural Speed backend is removed due to the upstream " - "end-of-life. " - "Reference: https://github.com/intel/neural-speed"sv); + "end-of-life. Reference: " + "https://github.com/intel/neural-speed"sv); return WASINN::ErrNo::InvalidArgument; } } // namespace diff --git a/plugins/wasi_nn/neuralspeed.h b/plugins/wasi_nn/wasinn_neuralspeed.h similarity index 98% rename from plugins/wasi_nn/neuralspeed.h rename to plugins/wasi_nn/wasinn_neuralspeed.h index b74b3c27..0f53b675 100644 --- a/plugins/wasi_nn/neuralspeed.h +++ b/plugins/wasi_nn/wasinn_neuralspeed.h @@ -3,8 +3,9 @@ #pragma once +#include "wasinntypes.h" + #include "plugin/plugin.h" -#include "types.h" namespace WasmEdge::Host::WASINN { struct WasiNNEnvironment; diff --git a/plugins/wasi_nn/onnx.cpp b/plugins/wasi_nn/wasinn_onnx.cpp similarity index 98% rename from plugins/wasi_nn/onnx.cpp rename to plugins/wasi_nn/wasinn_onnx.cpp index 7b7cb6f9..236d3d32 100644 --- a/plugins/wasi_nn/onnx.cpp +++ b/plugins/wasi_nn/wasinn_onnx.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "onnx.h" +#include "wasinn_onnx.h" #include "wasinnenv.h" namespace WasmEdge::Host::WASINN::ONNX { diff --git a/plugins/wasi_nn/onnx.h b/plugins/wasi_nn/wasinn_onnx.h similarity index 98% rename from plugins/wasi_nn/onnx.h rename to plugins/wasi_nn/wasinn_onnx.h index 6d46e02b..7e5343d8 100644 --- a/plugins/wasi_nn/onnx.h +++ b/plugins/wasi_nn/wasinn_onnx.h @@ -3,8 +3,9 @@ #pragma once +#include "wasinntypes.h" + #include "plugin/plugin.h" -#include "types.h" namespace WasmEdge::Host::WASINN { struct WasiNNEnvironment; diff --git a/plugins/wasi_nn/openvino.cpp b/plugins/wasi_nn/wasinn_openvino.cpp similarity index 99% rename from plugins/wasi_nn/openvino.cpp rename to plugins/wasi_nn/wasinn_openvino.cpp index 72fcb70c..91a01893 100644 --- a/plugins/wasi_nn/openvino.cpp +++ b/plugins/wasi_nn/wasinn_openvino.cpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "openvino.h" +#include "wasinn_openvino.h" #include "wasinnenv.h" + #include namespace WasmEdge::Host::WASINN::OpenVINO { diff --git a/plugins/wasi_nn/openvino.h b/plugins/wasi_nn/wasinn_openvino.h similarity index 98% rename from plugins/wasi_nn/openvino.h rename to plugins/wasi_nn/wasinn_openvino.h index b3616666..57691b7c 100644 --- a/plugins/wasi_nn/openvino.h +++ b/plugins/wasi_nn/wasinn_openvino.h @@ -3,8 +3,9 @@ #pragma once +#include "wasinntypes.h" + #include "plugin/plugin.h" -#include "types.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO #include "openvino/openvino.hpp" diff --git a/plugins/wasi_nn/piper.cpp b/plugins/wasi_nn/wasinn_piper.cpp similarity index 99% rename from plugins/wasi_nn/piper.cpp rename to plugins/wasi_nn/wasinn_piper.cpp index 14f7752f..6bfe55d0 100644 --- a/plugins/wasi_nn/piper.cpp +++ b/plugins/wasi_nn/wasinn_piper.cpp @@ -1,12 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "piper.h" +#include "wasinn_piper.h" #include "wasinnenv.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER #include "simdjson.h" -#include "types.h" + #include #include #include @@ -14,7 +14,6 @@ #include #include #include -#include #include #include #include diff --git a/plugins/wasi_nn/piper.h b/plugins/wasi_nn/wasinn_piper.h similarity index 99% rename from plugins/wasi_nn/piper.h rename to plugins/wasi_nn/wasinn_piper.h index 3a24f17a..5516b560 100644 --- a/plugins/wasi_nn/piper.h +++ b/plugins/wasi_nn/wasinn_piper.h @@ -3,14 +3,16 @@ #pragma once +#include "wasinntypes.h" + #include "plugin/plugin.h" -#include "types.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER +#include + #include #include #include -#include #include #include #endif diff --git a/plugins/wasi_nn/tf.cpp b/plugins/wasi_nn/wasinn_tf.cpp similarity index 98% rename from plugins/wasi_nn/tf.cpp rename to plugins/wasi_nn/wasinn_tf.cpp index 6dd02ef3..2d860429 100644 --- a/plugins/wasi_nn/tf.cpp +++ b/plugins/wasi_nn/wasinn_tf.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "tf.h" +#include "wasinn_tf.h" #include "wasinnenv.h" namespace WasmEdge::Host::WASINN::Tensorflow { diff --git a/plugins/wasi_nn/tf.h b/plugins/wasi_nn/wasinn_tf.h similarity index 98% rename from plugins/wasi_nn/tf.h rename to plugins/wasi_nn/wasinn_tf.h index 6f822af4..69e42941 100644 --- a/plugins/wasi_nn/tf.h +++ b/plugins/wasi_nn/wasinn_tf.h @@ -3,8 +3,9 @@ #pragma once +#include "wasinntypes.h" + #include "plugin/plugin.h" -#include "types.h" namespace WasmEdge::Host::WASINN { struct WasiNNEnvironment; diff --git a/plugins/wasi_nn/tfl.cpp b/plugins/wasi_nn/wasinn_tfl.cpp similarity index 99% rename from plugins/wasi_nn/tfl.cpp rename to plugins/wasi_nn/wasinn_tfl.cpp index 28e1a6a9..69406813 100644 --- a/plugins/wasi_nn/tfl.cpp +++ b/plugins/wasi_nn/wasinn_tfl.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "tfl.h" +#include "wasinn_tfl.h" #include "wasinnenv.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE diff --git a/plugins/wasi_nn/tfl.h b/plugins/wasi_nn/wasinn_tfl.h similarity index 98% rename from plugins/wasi_nn/tfl.h rename to plugins/wasi_nn/wasinn_tfl.h index 451fcf6d..b640f203 100644 --- a/plugins/wasi_nn/tfl.h +++ b/plugins/wasi_nn/wasinn_tfl.h @@ -3,8 +3,9 @@ #pragma once +#include "wasinntypes.h" + #include "plugin/plugin.h" -#include "types.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE #include "tensorflow/lite/c/c_api.h" diff --git a/plugins/wasi_nn/torch.cpp b/plugins/wasi_nn/wasinn_torch.cpp similarity index 99% rename from plugins/wasi_nn/torch.cpp rename to plugins/wasi_nn/wasinn_torch.cpp index 7470ce34..46412b5f 100644 --- a/plugins/wasi_nn/torch.cpp +++ b/plugins/wasi_nn/wasinn_torch.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "torch.h" +#include "wasinn_torch.h" #include "wasinnenv.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH diff --git a/plugins/wasi_nn/torch.h b/plugins/wasi_nn/wasinn_torch.h similarity index 98% rename from plugins/wasi_nn/torch.h rename to plugins/wasi_nn/wasinn_torch.h index fa480cfb..96f8e1b4 100644 --- a/plugins/wasi_nn/torch.h +++ b/plugins/wasi_nn/wasinn_torch.h @@ -3,8 +3,9 @@ #pragma once +#include "wasinntypes.h" + #include "plugin/plugin.h" -#include "types.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH #include diff --git a/plugins/wasi_nn/whispercpp.cpp b/plugins/wasi_nn/wasinn_whisper.cpp similarity index 99% rename from plugins/wasi_nn/whispercpp.cpp rename to plugins/wasi_nn/wasinn_whisper.cpp index 79bc23b2..3327096b 100644 --- a/plugins/wasi_nn/whispercpp.cpp +++ b/plugins/wasi_nn/wasinn_whisper.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "whispercpp.h" +#include "wasinn_whisper.h" #include "wasinnenv.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER diff --git a/plugins/wasi_nn/whispercpp.h b/plugins/wasi_nn/wasinn_whisper.h similarity index 99% rename from plugins/wasi_nn/whispercpp.h rename to plugins/wasi_nn/wasinn_whisper.h index 6958f9b3..e1fe78e3 100644 --- a/plugins/wasi_nn/whispercpp.h +++ b/plugins/wasi_nn/wasinn_whisper.h @@ -3,8 +3,9 @@ #pragma once +#include "wasinntypes.h" + #include "plugin/plugin.h" -#include "types.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER #include diff --git a/plugins/wasi_nn/wasinnbase.h b/plugins/wasi_nn/wasinnbase.h index 00898003..1f5f70d8 100644 --- a/plugins/wasi_nn/wasinnbase.h +++ b/plugins/wasi_nn/wasinnbase.h @@ -3,9 +3,10 @@ #pragma once +#include "wasinnenv.h" + #include "common/errcode.h" #include "runtime/hostfunc.h" -#include "wasinnenv.h" namespace WasmEdge { namespace Host { diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 61c06f55..088a2b49 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -2,8 +2,8 @@ // SPDX-FileCopyrightText: 2019-2024 Second State INC #include "wasinnenv.h" -#include "types.h" #include "wasinnmodule.h" +#include "wasinntypes.h" #include diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 9a5100de..4fea1411 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -3,25 +3,26 @@ #pragma once +#include "wasinn_chattts.h" +#include "wasinn_ggml.h" +#include "wasinn_mlx.h" +#include "wasinn_neuralspeed.h" +#include "wasinn_onnx.h" +#include "wasinn_openvino.h" +#include "wasinn_piper.h" +#include "wasinn_tf.h" +#include "wasinn_tfl.h" +#include "wasinn_torch.h" +#include "wasinn_whisper.h" +#include "wasinntypes.h" + #include "common/spdlog.h" #include "plugin/plugin.h" + #include #include #include -#include "chattts.h" -#include "ggml.h" -#include "mlx.h" -#include "neuralspeed.h" -#include "onnx.h" -#include "openvino.h" -#include "piper.h" -#include "tf.h" -#include "tfl.h" -#include "torch.h" -#include "types.h" -#include "whispercpp.h" - #ifdef WASMEDGE_BUILD_WASI_NN_RPC #include #include diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 7ecc138a..3de7d13b 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -2,9 +2,10 @@ // SPDX-FileCopyrightText: 2019-2024 Second State INC #include "wasinnfunc.h" -#include "common/spdlog.h" #include "wasinnenv.h" +#include "common/spdlog.h" + #include #include diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h index b955a300..cb399cc7 100644 --- a/plugins/wasi_nn/wasinnfunc.h +++ b/plugins/wasi_nn/wasinnfunc.h @@ -3,9 +3,10 @@ #pragma once -#include "runtime/callingframe.h" #include "wasinnbase.h" +#include "runtime/callingframe.h" + #include namespace WasmEdge { diff --git a/plugins/wasi_nn/wasinnmodule.h b/plugins/wasi_nn/wasinnmodule.h index 87b0fd4e..3a428eec 100644 --- a/plugins/wasi_nn/wasinnmodule.h +++ b/plugins/wasi_nn/wasinnmodule.h @@ -3,9 +3,10 @@ #pragma once -#include "runtime/instance/module.h" #include "wasinnenv.h" +#include "runtime/instance/module.h" + namespace WasmEdge { namespace Host { diff --git a/plugins/wasi_nn/types.h b/plugins/wasi_nn/wasinntypes.h similarity index 100% rename from plugins/wasi_nn/types.h rename to plugins/wasi_nn/wasinntypes.h From 12893162aafdd072a113e977277ed4602c60e351 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 14 Oct 2024 14:46:40 +0800 Subject: [PATCH 478/623] [WASI-NN] Refactor dependency cmake. Signed-off-by: YiYing He --- plugins/wasi_nn/CMakeLists.txt | 293 +--------------------------- test/plugins/wasi_nn/CMakeLists.txt | 3 +- test/plugins/wasi_nn/wasi_nn.cpp | 6 +- 3 files changed, 10 insertions(+), 292 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index eb34d036..9f0036bf 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -19,200 +19,15 @@ wasmedge_add_library(wasmedgePluginWasiNN wasinn_mlx.cpp ) +include(WASINNDeps) +wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) + +# This for-each iteration is for the additional sources. +# The dependencies are moved into `cmake/WASINNDeps.cmake`. foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) string(TOLOWER ${BACKEND} BACKEND) - if(BACKEND STREQUAL "ggml") - wasmedge_setup_simdjson() - # llama.cpp options - # Disable warnings and debug messages - set(LLAMA_ALL_WARNINGS OFF) - set(LLAMA_METAL_NDEBUG ON) - set(GGML_ACCELERATE OFF) - set(GGML_BLAS OFF) - set(GGML_OPENMP OFF) - set(BUILD_SHARED_LIBS OFF) - - if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_NATIVE) - message(STATUS "WASI-NN GGML LLAMA backend: Enable GGML_NATIVE(AVX/AVX2/FMA/F16C)") - set(GGML_NATIVE ON) - else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable GGML_NATIVE(AVX/AVX2/FMA/F16C)") - set(GGML_NATIVE OFF) - set(GGML_AVX OFF) - set(GGML_AVX2 OFF) - set(GGML_FMA OFF) - set(GGML_F16C OFF) - endif() - - if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_CUBLAS) - message(STATUS "WASI-NN GGML LLAMA backend: Enable GGML_CUDA") - set(GGML_CUDA ON) - # We need to set GGML_USE_CUDA for clip from llava. - add_compile_definitions(GGML_USE_CUDA) - else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable GGML_CUDA") - set(GGML_CUDA OFF) - endif() - - if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" AND WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) - message(STATUS "WASI-NN GGML LLAMA backend: Enable GGML_METAL") - set(GGML_METAL ON) - set(GGML_METAL_EMBED_LIBRARY ON) - else() - message(STATUS "WASI-NN GGML LLAMA backend: Disable GGML_METAL") - set(GGML_METAL OFF) - endif() - - # setup llama.cpp - message(STATUS "Downloading llama.cpp source") - include(FetchContent) - FetchContent_Declare( - llama - GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b3651 - GIT_SHALLOW FALSE - ) - FetchContent_MakeAvailable(llama) - set_property(TARGET ggml PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET common PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET llama PROPERTY POSITION_INDEPENDENT_CODE ON) - - # Setup llava from llama.cpp - wasmedge_add_library(llava OBJECT - ${llama_SOURCE_DIR}/examples/llava/clip.cpp - ${llama_SOURCE_DIR}/examples/llava/llava.cpp - ) - if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") - target_compile_options(llava - PRIVATE - $<$:/utf-8> - $<$:-Xcompiler=/utf-8> - $<$:/wd4067> # unexpected tokens following preprocessor directive - expected a newline - $<$:/wd4101> # 'identifier' : unreferenced local variable - $<$:/wd4189> # 'identifier' : local variable is initialized but not referenced - $<$:/wd4244> # 'argument' : conversion from 'type1' to 'type2', possible loss of data - $<$:/wd4267> # 'var' : conversion from 'size_t' to 'type', possible loss of data - $<$:/wd4297> # 'function' : function assumed not to throw an exception but does - $<$:/wd4456> # declaration of 'identifier' hides previous local declaration - $<$:/wd4505> # 'function' : unreferenced local function has been removed - ) - elseif(CMAKE_CXX_COMPILER_ID MATCHES "GNU") - target_compile_options(llava - PRIVATE - $<$:-Wno-exceptions> - -Wno-cast-align - -Wno-cast-qual - -Wno-float-conversion - -Wno-implicit-fallthrough - -Wno-unused-macros - -Wno-unused-function - -Wno-unused-variable - ) - elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - target_compile_options(llava - PRIVATE - $<$:-Wno-exceptions> - -Wno-cast-align - -Wno-cast-qual - -Wno-disabled-macro-expansion - -Wno-float-conversion - -Wno-implicit-fallthrough - -Wno-implicit-float-conversion - -Wno-unused-macros - -Wno-unused-function - -Wno-unused-variable - -Wno-sign-conversion - -Wno-shorten-64-to-32 - -Wno-implicit-int-conversion - -Wno-old-style-cast - -Wno-extra-semi-stmt - -Wno-format-nonliteral - -Wno-documentation - -Wno-unused-template - ) - endif() - target_link_libraries(llava PRIVATE ggml llama) - target_include_directories(llava PUBLIC - ${llama_SOURCE_DIR} - ${llama_SOURCE_DIR}/common - ${llama_SOURCE_DIR}/examples/llava - ) - target_link_libraries(wasmedgePluginWasiNN PRIVATE - common - simdjson::simdjson - llava - ) - if(APPLE AND WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_METAL) - add_custom_command( - TARGET wasmedgePluginWasiNN - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml/src/ggml-metal.metal ggml-metal.metal - COMMAND ${CMAKE_COMMAND} -E copy ${llama_SOURCE_DIR}/ggml/src/ggml-common.h ggml-common.h - ) - endif() - elseif(BACKEND STREQUAL "neuralspeed") - message(NOTICE "WASI-NN NeuralSpeed backend is removed due to the upstream end-of-life.") - message(NOTICE "Reference: https://github.com/intel/neural-speed") - elseif(BACKEND STREQUAL "chattts") - wasmedge_setup_simdjson() - - find_package(Python3 COMPONENTS Interpreter Development) - if(Python3_FOUND) - target_compile_definitions(wasmedgePluginWasiNN - PRIVATE PYTHON_LIB_PATH="${Python3_LIBRARIES}" - ) - include_directories(${Python3_INCLUDE_DIRS}) - target_link_libraries(wasmedgePluginWasiNN PRIVATE ${Python3_LIBRARIES}) - target_link_directories(wasmedgePluginWasiNN PRIVATE ${Python3_RUNTIME_LIBRARY_DIRS}) - else() - message(FATAL_ERROR "Can not find python3.") - endif() - target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) - elseif(BACKEND STREQUAL "piper") - wasmedge_setup_simdjson() - target_link_libraries(wasmedgePluginWasiNN PRIVATE simdjson::simdjson) - elseif(BACKEND STREQUAL "whisper") - wasmedge_setup_simdjson() - set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "Whisper not build shared") - set(GGML_OPENMP OFF) - set(GGML_ACCELERATE OFF) - set(GGML_BLAS OFF) - if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" AND WASMEDGE_PLUGIN_WASI_NN_WHISPER_METAL) - message(STATUS "WASI-NN Whisper backend: Enable GGML_METAL") - set(GGML_METAL ON) - set(GGML_METAL_EMBED_LIBRARY ON) - else() - message(STATUS "WASI-NN Whisper backend: Disable GGML_METAL") - set(GGML_METAL OFF) - endif() - if(WASMEDGE_PLUGIN_WASI_NN_WHISPER_CUDA) - message(STATUS "WASI-NN Whisper backend: Enable GGML_CUDA") - set(GGML_CUDA ON) - else() - message(STATUS "WASI-NN Whisper backend: Disable GGML_CUDA") - set(GGML_CUDA OFF) - endif() - include(FetchContent) - FetchContent_Declare( - whisper - GIT_REPOSITORY https://github.com/ggerganov/whisper.cpp.git - GIT_TAG 69339af2d104802f3f201fd419163defba52890e - GIT_SHALLOW FALSE - ) - FetchContent_MakeAvailable(whisper) - set_property(TARGET whisper PROPERTY POSITION_INDEPENDENT_CODE ON) - set_property(TARGET ggml PROPERTY POSITION_INDEPENDENT_CODE ON) - target_include_directories(wasmedgePluginWasiNN PRIVATE - ${whisper_SOURCE_DIR} - ${whisper_SOURCE_DIR}/ggml/include - ) - target_link_libraries(wasmedgePluginWasiNN PRIVATE - whisper - simdjson::simdjson - ) - elseif(BACKEND STREQUAL "mlx") - wasmedge_setup_simdjson() - target_sources(wasmedgePluginWasiNN + if(BACKEND STREQUAL "mlx") + target_sources(wasmedgePluginWasiNN PRIVATE MLX/prompt/prompt.cpp MLX/model/transformer.cpp @@ -228,97 +43,6 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) MLX/mlx/transformer.cpp MLX/mlx/quantized.cpp ) - - find_package(MLX CONFIG) - if(MLX_FOUND) - message(STATUS "Found MLX: ${MLX_INCLUDE_DIRS}") - else() - # Not support directly download from source - find_library(ACCELERATE_LIBRARY Accelerate) - find_library(METAL_LIB Metal) - find_library(FOUNDATION_LIB Foundation) - find_library(QUARTZ_LIB QuartzCore) - message(STATUS "MLX not found, downloading from source") - include(FetchContent) - set(MLX_BUILD_GGUF OFF) - FetchContent_Declare( - mlx - GIT_REPOSITORY https://github.com/ml-explore/mlx.git - GIT_TAG v0.16.0 - GIT_SHALLOW FALSE - ) - FetchContent_MakeAvailable(mlx) - set_property(TARGET mlx PROPERTY POSITION_INDEPENDENT_CODE ON) - set_target_properties(mlx PROPERTIES - INTERFACE_LINK_LIBRARIES "$" - ) - target_link_libraries(mlx - PUBLIC - ${ACCELERATE_LIBRARY} - ${METAL_LIB} - ${FOUNDATION_LIB} - ${QUARTZ_LIB} - ) - target_compile_options(mlx - PUBLIC - -Wno-unused-parameter - -Wno-deprecated-copy - -Wno-format - ) - endif() - - message(STATUS "Downloading tokenizers") - FetchContent_Declare( - tokenizers - GIT_REPOSITORY https://github.com/mlc-ai/tokenizers-cpp.git - GIT_TAG 5de6f65 - GIT_SHALLOW FALSE - ) - FetchContent_MakeAvailable(tokenizers) - set_property(TARGET tokenizer_cpp_objs PROPERTY POSITION_INDEPENDENT_CODE ON) - - message(STATUS "Downloading gguflib") - FetchContent_Declare( - gguflib - GIT_REPOSITORY https://github.com/antirez/gguf-tools/ - GIT_TAG af7d88d808a7608a33723fba067036202910acb3 - GIT_SHALLOW FALSE - ) - FetchContent_MakeAvailable(gguflib) - add_library(gguflib - STATIC - ${gguflib_SOURCE_DIR}/fp16.c - ${gguflib_SOURCE_DIR}/gguflib.c - ) - set_target_properties(gguflib PROPERTIES LINKER_LANGUAGE CXX) - - target_include_directories(wasmedgePluginWasiNN - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/MLX/model - ${CMAKE_CURRENT_SOURCE_DIR}/MLX/prompt - ${CMAKE_CURRENT_SOURCE_DIR}/MLX/mlx - ) - target_include_directories(wasmedgePluginWasiNN - PRIVATE - ${tokenizers_SOURCE_DIR}/include - ) - target_include_directories(wasmedgePluginWasiNN - SYSTEM PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/mlx - ${MLX_INCLUDE_DIRS} - $ - ) - target_link_libraries(wasmedgePluginWasiNN - PRIVATE - tokenizers_cpp - ) - target_link_libraries(wasmedgePluginWasiNN - PUBLIC - ${MLX_LIBRARIES} - gguflib - mlx - simdjson::simdjson - ) endif() endforeach() @@ -356,9 +80,6 @@ else() ) endif() -include(WASINNDeps) -wasmedge_setup_wasinn_target(wasmedgePluginWasiNN) - install( TARGETS wasmedgePluginWasiNN DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 86e3b4cb..4de4d4c4 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -70,8 +70,6 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) /wd4067 # unexpected tokens following preprocessor directive - expected a newline ) endif() - elseif(BACKEND STREQUAL "neuralspeed") - message(NOTICE "Neural Speed backend is removed due to the upstream end-of-life.") elseif(BACKEND STREQUAL "piper") message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures") download( @@ -143,6 +141,7 @@ target_link_libraries(wasiNNTests PRIVATE ${GTEST_BOTH_LIBRARIES} ) + # Link to the WasmEdge library if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasiNNTests diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index aa8be40c..ab14d411 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -59,8 +59,8 @@ createModule(std::string_view NNRPCURI = "") { return {}; } -#if !defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) -inline std::vector readEntireFile(const std::string &Path) { +inline std::vector readEntireFile + [[maybe_unused]] (const std::string &Path) { std::ifstream Fin(Path, std::ios::in | std::ios::binary | std::ios::ate); if (!Fin) { return {}; @@ -74,7 +74,6 @@ inline std::vector readEntireFile(const std::string &Path) { Fin.close(); return Buf; } -#endif template void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, @@ -2229,7 +2228,6 @@ TEST(WasiNNTest, PiperBackend) { WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Load the files. - (void)readEntireFile; std::string Text = "This is a test."; std::vector TensorData(Text.begin(), Text.end()); From db84026e73c3f0f04fe3963aa24dbc2b36212d4f Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Sat, 19 Oct 2024 22:06:54 +0800 Subject: [PATCH 479/623] [Plugin] Stable Diffusion: add more test (#3843) Signed-off-by: grorge --- plugins/wasmedge_stablediffusion/sd_env.cpp | 3 + plugins/wasmedge_stablediffusion/sd_func.cpp | 6 +- .../wasmedge_stablediffusion.cpp | 153 +++++++++++++++++- 3 files changed, 159 insertions(+), 3 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp index 73c618fa..f99ae60d 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.cpp +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -43,6 +43,9 @@ uint32_t SDEnviornment::addContext(sd_ctx_t *Ctx) noexcept { } sd_ctx_t *SDEnviornment::getContext(const uint32_t Id) noexcept { + if (Id >= Contexts.size()) { + return nullptr; + } return Contexts[Id]; } diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index 14ec327a..7a12898d 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -275,7 +275,8 @@ Expect SDTextToImage::body( if (!parameterCheck(Env, Width, Height, SessionId)) { return static_cast(ErrNo::InvalidArgument); } - sd_ctx_t *SDCtx = Env.getContext(SessionId); + SESSION_CHECK(SDCtx, SessionId, "Session ID is invalid."sv, + ErrNo::InvalidArgument) sd_image_t *Results = nullptr; sd_image_t *ControlImage = nullptr; uint8_t *ControlImageBuffer = nullptr; @@ -353,7 +354,8 @@ Expect SDImageToImage::body( if (!parameterCheck(Env, Width, Height, SessionId)) { return static_cast(ErrNo::InvalidArgument); } - sd_ctx_t *SDCtx = Env.getContext(SessionId); + SESSION_CHECK(SDCtx, SessionId, "Session ID is invalid."sv, + ErrNo::InvalidArgument) std::string Prompt(PromptSpan.begin(), PromptSpan.end()); std::string NegativePrompt(NegativePromptSpan.begin(), NegativePromptSpan.end()); diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index 318c9cc7..af8ce453 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -147,6 +147,7 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); EXPECT_TRUE(std::filesystem::exists(QuantModelPathString)); } + // Test: create_context -- create context for text to image. { EXPECT_TRUE(HostFuncCreateContext.run( @@ -212,7 +213,7 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { -1, // ClipSkip 7.0f, // CfgScale 0, // SampleMethod - 1, // SampleSteps + 1, // SampleSteps 42, // Seed 1, // BatchCount 0.90f, // ControlStrength @@ -235,6 +236,7 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { EXPECT_GE(BytesWritten, 50); EXPECT_TRUE(std::filesystem::exists(OutputPathString)); } + // Test: text_to_image -- reuse context to generate image from text. { uint32_t PromptPtr = UINT32_C(0); @@ -325,6 +327,58 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { EXPECT_EQ(SessionId, 1); } // Test: image_to_image -- generate image from image. + { + uint32_t PromptPtr = UINT32_C(0); + uint32_t InputPathPtr = PromptPtr + PromptData2.size(); + uint32_t OutputPathPtr = InputPathPtr + InputPath.size(); + uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath2.size(); + OutputPtr = BytesWrittenPtr + 4; + writeBinaries(MemInst, PromptData2, PromptPtr); + writeBinaries(MemInst, InputPath, InputPathPtr); + writeBinaries(MemInst, OutputPath2, OutputPathPtr); + EXPECT_TRUE(HostFuncImageToImage.run( + CallFrame, + std::initializer_list{ + InputPathPtr, // ImagePtr + static_cast(InputPath.size()), // ImageLen + SessionId, // SessionId + 3.5f, // Guidance + 256, // Width + 256, // Height + 0, // ControlImagePtr + 0, // ControlImageLen + PromptPtr, // PromptPtr + static_cast(PromptData2.size()), // PromptLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 1, // SampleSteps + 0.75f, // Strength + 42, // Seed + 1, // BatchCount + 0.9f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + OutputPathPtr, // OutputPathPtr + static_cast(OutputPath2.size()), // OutputPathLen + OutputPtr, // OutBufferPtr + 1048512, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + auto BytesWritten = *MemInst.getPointer(BytesWrittenPtr); + EXPECT_GE(BytesWritten, 50); + EXPECT_TRUE(std::filesystem::exists(OutputPathString2)); + } + // Test: image_to_image -- reuse context to generate image from image. { uint32_t PromptPtr = UINT32_C(0); uint32_t InputPathPtr = PromptPtr + PromptData2.size(); @@ -376,6 +430,103 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { EXPECT_GE(BytesWritten, 50); EXPECT_TRUE(std::filesystem::exists(OutputPathString2)); } + + // Test: text_to_image -- non exist SessionId. + { + uint32_t PromptPtr = UINT32_C(0); + uint32_t OutputPathPtr = PromptPtr + PromptData.size(); + uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath.size(); + OutputPtr = BytesWrittenPtr + 4; + writeBinaries(MemInst, PromptData, PromptPtr); + writeBinaries(MemInst, OutputPath, OutputPathPtr); + EXPECT_TRUE(HostFuncTextToImage.run( + CallFrame, + std::initializer_list{ + PromptPtr, // PromptPtr + static_cast(PromptData.size()), // PromptLen + 99, // SessionId + 0, // ControlImagePtr + 0, // ControlImageLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + 3.5f, // Guidance + 256, // Width + 256, // Height + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 1, // SampleSteps + 42, // Seed + 1, // BatchCount + 0.90f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + OutputPathPtr, // OutputPathPtr + static_cast(OutputPath.size()), // OutputPathLen + OutputPtr, // OutBufferPtr + 1048512, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + + // Test: image_to_image -- non exist SessionId. + { + uint32_t PromptPtr = UINT32_C(0); + uint32_t InputPathPtr = PromptPtr + PromptData2.size(); + uint32_t OutputPathPtr = InputPathPtr + InputPath.size(); + uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath2.size(); + OutputPtr = BytesWrittenPtr + 4; + writeBinaries(MemInst, PromptData2, PromptPtr); + writeBinaries(MemInst, InputPath, InputPathPtr); + writeBinaries(MemInst, OutputPath2, OutputPathPtr); + EXPECT_TRUE(HostFuncImageToImage.run( + CallFrame, + std::initializer_list{ + InputPathPtr, // ImagePtr + static_cast(InputPath.size()), // ImageLen + -1, // SessionId + 3.5f, // Guidance + 256, // Width + 256, // Height + 0, // ControlImagePtr + 0, // ControlImageLen + PromptPtr, // PromptPtr + static_cast(PromptData2.size()), // PromptLen + 0, // NegativePromptPtr + 0, // NegativePromptLen + -1, // ClipSkip + 7.0f, // CfgScale + 0, // SampleMethod + 20, // SampleSteps + 0.75f, // Strength + 42, // Seed + 1, // BatchCount + 0.9f, // ControlStrength + 20.0f, // StyleRatio + 0, // NormalizeInput + 0, // InputIdImagesDirPtr + 0, // InputIdImagesDirLen + 0, // CannyPreprocess + 0, // UpscaleModelPathPtr + 0, // UpscaleModelPathLen + 1, // UpscaleRepeats + OutputPathPtr, // OutputPathPtr + static_cast(OutputPath2.size()), // OutputPathLen + OutputPtr, // OutBufferPtr + 1048512, // OutBufferMaxSize + BytesWrittenPtr}, // BytesWrittenPtr + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } } GTEST_API_ int main(int argc, char **argv) { From 4dd960a4d6780dac148e339a57e4bc7b442ce5b6 Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Sat, 19 Oct 2024 22:25:12 +0800 Subject: [PATCH 480/623] [CMake] Update the FetchContent usage. (#3842) Signed-off-by: YiYing He --- plugins/wasm_bpf/CMakeLists.txt | 11 +++-------- plugins/wasmedge_image/CMakeLists.txt | 22 ++++++++-------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/plugins/wasm_bpf/CMakeLists.txt b/plugins/wasm_bpf/CMakeLists.txt index 4926df6d..dc9ba54a 100644 --- a/plugins/wasm_bpf/CMakeLists.txt +++ b/plugins/wasm_bpf/CMakeLists.txt @@ -87,20 +87,15 @@ endif() # Try FetchContent if(NOT ${LIBBPF_FOUND}) - message(STATUS "Try to get libbpf through FetchContent") + message(STATUS "Downloading libbpf source") include(FetchContent) FetchContent_Declare( libbpf GIT_REPOSITORY https://github.com/libbpf/libbpf GIT_TAG 950cffc0366981d4e41b08f007b37bd6af931f25 ) - FetchContent_GetProperties(libbpf) - - if(NOT libbpf_POPULATED) - message(STATUS "Fetching libbpf..") - FetchContent_Populate(libbpf) - message(STATUS "Fetched libbpf") - endif() + FetchContent_MakeAvailable(libbpf) + message(STATUS "Downloading libbpf source - done") set(LIBBPF_DOWNLOAD_SOURCE_DIR "${libbpf_SOURCE_DIR}") message(DEBUG "libbpf saved at: ${LIBBPF_DOWNLOAD_SOURCE_DIR}") diff --git a/plugins/wasmedge_image/CMakeLists.txt b/plugins/wasmedge_image/CMakeLists.txt index 800c3a94..34bb9285 100644 --- a/plugins/wasmedge_image/CMakeLists.txt +++ b/plugins/wasmedge_image/CMakeLists.txt @@ -42,29 +42,23 @@ if(APPLE) elseif(UNIX) # Fetch and build libjpeg and libpng. include(FetchContent) + message(STATUS "Downloading libpng source") FetchContent_Declare( wasmedge_image_libpng URL "https://downloads.sourceforge.net/libpng/libpng-1.6.39.tar.gz" URL_HASH "SHA256=af4fb7f260f839919e5958e5ab01a275d4fe436d45442a36ee62f73e5beb75ba" ) - FetchContent_GetProperties(wasmedge_image_libpng) - if(NOT wasmedge_image_libpng_POPULATED) - message(STATUS "Downloading libpng source") - FetchContent_Populate(wasmedge_image_libpng) - message(STATUS "Downloading libpng source - done") - endif() + FetchContent_MakeAvailable(wasmedge_image_libpng) + message(STATUS "Downloading libpng source - done") + message(STATUS "Downloading libjpeg source") FetchContent_Declare( wasmedge_image_libjpeg URL "http://ijg.org/files/jpegsrc.v9e.tar.gz" URL_HASH "SHA256=4077d6a6a75aeb01884f708919d25934c93305e49f7e3f36db9129320e6f4f3d" ) - FetchContent_GetProperties(wasmedge_image_libjpeg) - if(NOT wasmedge_image_libjpeg_POPULATED) - message(STATUS "Downloading libjpeg source") - FetchContent_Populate(wasmedge_image_libjpeg) - message(STATUS "Downloading libjpeg source - done") - endif() + FetchContent_MakeAvailable(wasmedge_image_libjpeg) + message(STATUS "Downloading libjpeg source - done") add_custom_command( OUTPUT ${wasmedge_image_libjpeg_SOURCE_DIR}/.libs/libjpeg.a @@ -102,10 +96,11 @@ endif() # Need zlib and boost. find_package(ZLIB REQUIRED) -find_package(Boost 1.74.0) +find_package(Boost 1.74.0 CONFIG) if(${Boost_FOUND}) else() include(FetchContent) + message(STATUS "Downloading boost 1.82.0 source") FetchContent_Declare( Boost URL http://sources.buildroot.net/boost/boost_1_82_0.tar.bz2 @@ -113,7 +108,6 @@ else() ) set(BOOST_ENABLE_CMAKE ON) set(BOOST_RUNTIME_LINK static) - message(STATUS "Downloading boost 1.82.0 source") FetchContent_MakeAvailable(Boost) message(STATUS "Downloading boost 1.82.0 source - done") add_library(Boost_boost INTERFACE) From 9a4346b6c22532193fe0dbc59bd2cb7dec18fb2a Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 18 Oct 2024 17:02:05 +0800 Subject: [PATCH 481/623] [WASI-NN] ggml: bump llama.cpp b3942 Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 268 +++++++++++++++----------------- plugins/wasi_nn/wasinn_ggml.h | 6 +- 2 files changed, 129 insertions(+), 145 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index dd5d3084..e7c1e090 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -334,24 +335,20 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::Success; } -Expect setupGPTParam(Graph &GraphRef, gpt_params &GPTParams) { - GPTParams.sparams.temp = static_cast(GraphRef.Temp); - GPTParams.sparams.top_p = static_cast(GraphRef.TopP); - GPTParams.sparams.penalty_repeat = static_cast(GraphRef.RepeatPenalty); - GPTParams.sparams.penalty_present = - static_cast(GraphRef.PresencePenalty); - GPTParams.sparams.grammar = GraphRef.Grammar; - return ErrNo::Success; -} - -Expect setupContextParam(Graph &GraphRef, - llama_context_params &ContextParams) { - ContextParams.n_ctx = static_cast(GraphRef.CtxSize); - ContextParams.n_batch = static_cast(GraphRef.BatchSize); - ContextParams.n_ubatch = static_cast(GraphRef.UBatchSize); - ContextParams.n_threads = static_cast(GraphRef.Threads); - ContextParams.n_threads_batch = static_cast(GraphRef.Threads); - ContextParams.embeddings = GraphRef.Embedding; +Expect setupParams(Graph &GraphRef, common_params &Params) { + Params.model = GraphRef.ModelFilePath; + Params.n_gpu_layers = static_cast(GraphRef.NGPULayers); + Params.n_ctx = static_cast(GraphRef.CtxSize); + Params.n_batch = static_cast(GraphRef.BatchSize); + Params.n_ubatch = static_cast(GraphRef.UBatchSize); + Params.cpuparams.n_threads = static_cast(GraphRef.Threads); + Params.cpuparams_batch.n_threads = static_cast(GraphRef.Threads); + Params.embedding = GraphRef.Embedding; + Params.sparams.temp = static_cast(GraphRef.Temp); + Params.sparams.top_p = static_cast(GraphRef.TopP); + Params.sparams.penalty_repeat = static_cast(GraphRef.RepeatPenalty); + Params.sparams.penalty_present = static_cast(GraphRef.PresencePenalty); + Params.sparams.grammar = GraphRef.Grammar; return ErrNo::Success; } @@ -432,8 +429,8 @@ void batchAddSeq(llama_batch &Batch, const std::vector &Tokens, for (int I = 0; I < static_cast(Tokens.size()); I++) { // llama_batch_add_seq(llama_batch, llama_token, llama_pos, // std::vector, logits); - llama_batch_add(Batch, Tokens[I], I, {SequenceId}, - I == static_cast(Tokens.size()) - 1); + common_batch_add(Batch, Tokens[I], I, {SequenceId}, + I == static_cast(Tokens.size()) - 1); } } @@ -472,7 +469,7 @@ ErrNo batchDecode(llama_context *LlamaContext, llama_batch &Batch, } // Normalize the embeddings. - llama_embd_normalize(Embd, Output, NEmbd); + common_embd_normalize(Embd, Output, NEmbd); } return ErrNo::Success; @@ -507,13 +504,8 @@ Expect getEmbedding(WasiNNEnvironment &Env, if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: handle embedding"sv); } - // Initialize the llama context. - llama_context_params ContextParams = llama_context_default_params(); - setupContextParam(GraphRef, ContextParams); - // For non-causal models, batch size must be equal to ubatch size - ContextParams.n_ubatch = ContextParams.n_batch; - auto *LlamaContext = - llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); + // Clear the llama context. + llama_kv_cache_clear(GraphRef.LlamaContext); // Use the const sequence id here. const llama_seq_id SequenceId = 0; @@ -526,14 +518,13 @@ Expect getEmbedding(WasiNNEnvironment &Env, } // Check if the input is too long. - if (static_cast(CxtRef.LlamaInputs.size()) > - ContextParams.n_batch) { + if (static_cast(CxtRef.LlamaInputs.size()) > GraphRef.BatchSize) { if (GraphRef.EnableLog) { spdlog::info( "[WASI-NN] GGML backend: the prompt is too long. " "Your input has {} tokens exceeds batch size {}. " "Please reduce the input size or increase your batch-size."sv, - CxtRef.LlamaInputs.size(), ContextParams.n_batch); + CxtRef.LlamaInputs.size(), GraphRef.BatchSize); } return ErrNo::PromptTooLong; } @@ -545,7 +536,8 @@ Expect getEmbedding(WasiNNEnvironment &Env, /* n_seq_max */ 1); std::vector Embeddings(NEmbd); batchAddSeq(Batch, CxtRef.LlamaInputs, SequenceId); - ReturnCode = batchDecode(LlamaContext, Batch, Embeddings.data(), NEmbd); + ReturnCode = + batchDecode(GraphRef.LlamaContext, Batch, Embeddings.data(), NEmbd); if (ReturnCode != ErrNo::Success) { spdlog::error("[WASI-NN] GGML backend: failed to evaluate input tokens."sv); return ReturnCode; @@ -559,12 +551,12 @@ Expect getEmbedding(WasiNNEnvironment &Env, } if (GraphRef.EnableLog) { - llama_print_timings(LlamaContext); + common_perf_print(GraphRef.LlamaContext, /* Sampler */ nullptr); } - // We free the contexts here to keep the ggml plugin stateless. + // We clear the contexts here to keep the ggml plugin stateless. // Users could fully control the contexts by themselves via their prompt. - llama_free(LlamaContext); + llama_kv_cache_clear(GraphRef.LlamaContext); llama_batch_free(Batch); if (GraphRef.EnableDebugLog) { @@ -703,19 +695,20 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.MMProjModelPath = ""sv; GraphRef.ImagePath = ""sv; // Initialize the model parameters. - GraphRef.NGPULayers = 0; + llama_model_params ModelParams = llama_model_default_params(); + GraphRef.NGPULayers = ModelParams.n_gpu_layers; // Initialize the context parameters. GraphRef.CtxSize = ContextDefault.n_ctx; GraphRef.BatchSize = ContextDefault.n_batch; GraphRef.Threads = ContextDefault.n_threads; // Initialize the sampling parameters. - const llama_sampling_params SamplingDefault; - GraphRef.Temp = SamplingDefault.temp; - GraphRef.TopP = SamplingDefault.top_p; - GraphRef.RepeatPenalty = SamplingDefault.penalty_repeat; - GraphRef.PresencePenalty = SamplingDefault.penalty_present; - GraphRef.FrequencyPenalty = SamplingDefault.penalty_freq; - GraphRef.Grammar = SamplingDefault.grammar; + const common_sampler_params SamplerDefault; + GraphRef.Temp = SamplerDefault.temp; + GraphRef.TopP = SamplerDefault.top_p; + GraphRef.RepeatPenalty = SamplerDefault.penalty_repeat; + GraphRef.PresencePenalty = SamplerDefault.penalty_present; + GraphRef.FrequencyPenalty = SamplerDefault.penalty_freq; + GraphRef.Grammar = SamplerDefault.grammar; // Set llama log callback. llama_log_set(LlamaLogCallback, &GraphRef); @@ -745,9 +738,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, auto Weight = Builders[0]; const std::string_view BinModel(reinterpret_cast(Weight.data()), Weight.size()); - std::string ModelFilePath; if (BinModel.substr(0, 8) == "preload:"sv) { - ModelFilePath = BinModel.substr(8); + GraphRef.ModelFilePath = BinModel.substr(8); } else { if (GraphRef.EnableDebugLog) { spdlog::info( @@ -756,8 +748,9 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, } // TODO: pass the model directly to ggml // Write ggml model to file. - ModelFilePath = "ggml-model.bin"sv; - std::ofstream TempFile(ModelFilePath, std::ios::out | std::ios::binary); + GraphRef.ModelFilePath = "ggml-model.bin"sv; + std::ofstream TempFile(GraphRef.ModelFilePath, + std::ios::out | std::ios::binary); if (!TempFile) { spdlog::error( "[WASI-NN] GGML backend: Failed to create the temporary file. " @@ -779,7 +772,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, "[WASI-NN][Debug] GGML backend: Finished handling model path."sv); } // Check if the model exists. - if (!std::filesystem::exists(std::filesystem::u8path(ModelFilePath))) { + if (!std::filesystem::exists( + std::filesystem::u8path(GraphRef.ModelFilePath))) { spdlog::error("[WASI-NN] GGML backend: Model file not found."sv); Env.NNGraph.pop_back(); return ErrNo::ModelNotFound; @@ -789,20 +783,26 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, spdlog::info( "[WASI-NN][Debug] GGML backend: Initialize ggml model with given parameters"sv); } - // Initialize ggml model with model parameters. - GraphRef.ModelFilePath = ModelFilePath; - llama_model_params ModelParams = llama_model_default_params(); - ModelParams.n_gpu_layers = static_cast(GraphRef.NGPULayers); - ModelParams.main_gpu = static_cast(GraphRef.MainGPU); - ModelParams.tensor_split = GraphRef.TensorSplit.data(); - ModelParams.use_mmap = GraphRef.UseMMap; - GraphRef.LlamaModel = - llama_load_model_from_file(GraphRef.ModelFilePath.c_str(), ModelParams); + // Initialize ggml parameters. + common_params Params; + setupParams(GraphRef, Params); + llama_backend_init(); + llama_numa_init(Params.numa); + + // Initialize the llama model and context. + common_init_result LlamaInit = common_init_from_params(Params); + GraphRef.LlamaModel = LlamaInit.model; + GraphRef.LlamaContext = LlamaInit.context; if (GraphRef.LlamaModel == nullptr) { spdlog::error("[WASI-NN] GGML backend: Error: unable to init model."sv); Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } + if (GraphRef.LlamaContext == nullptr) { + spdlog::error("[WASI-NN] GGML backend: Error: unable to init context."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } if (GraphRef.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] GGML backend: Initialize ggml model with given parameters...Done"sv); @@ -811,9 +811,6 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Store the loaded graph. GraphId = static_cast(Env.NNGraph.size() - 1); - // Disable llama log by default. - log_disable(); - return ErrNo::Success; } @@ -889,16 +886,13 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, return ErrNo::Success; } - // Initialize the llama context. + // Clear the llama context. if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: init llama context"sv); + spdlog::info("[WASI-NN][Debug] GGML backend: clear llama context"sv); } - llama_context_params ContextParams = llama_context_default_params(); - setupContextParam(GraphRef, ContextParams); - auto LlamaContext = - llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); + llama_kv_cache_clear(GraphRef.LlamaContext); if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: init llama context...Done"sv); + spdlog::info("[WASI-NN][Debug] GGML backend: clear llama context...Done"sv); } // Set the input. @@ -915,8 +909,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: tokenize text prompt"sv); } - CxtRef.LlamaInputs = - llama_tokenize(LlamaContext, Prompt, AddSpecial, ParseSpecial); + CxtRef.LlamaInputs = common_tokenize(GraphRef.LlamaContext, Prompt, + AddSpecial, ParseSpecial); if (GraphRef.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] GGML backend: tokenize text prompt...Done"sv); @@ -990,12 +984,12 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, std::string PromptBeforeImage = Prompt.substr(0, PlaceholderPosition); std::string PromptAfterImage = Prompt.substr(PlaceholderPosition + PromptImagePlaceholder.length()); - std::vector EmbdInputBeforeImage = llama_tokenize( - LlamaContext, PromptBeforeImage, AddSpecial, ParseSpecial); + std::vector EmbdInputBeforeImage = common_tokenize( + GraphRef.LlamaContext, PromptBeforeImage, AddSpecial, ParseSpecial); // Do not add special token (such as , , ... tokens.) to the // tokens after the image. - std::vector EmbdInputAfterImage = - llama_tokenize(LlamaContext, PromptAfterImage, false, ParseSpecial); + std::vector EmbdInputAfterImage = common_tokenize( + GraphRef.LlamaContext, PromptAfterImage, false, ParseSpecial); CxtRef.LlavaImagePosition = EmbdInputBeforeImage.size(); CxtRef.LlamaInputs.reserve(EmbdInputBeforeImage.size() + EmbdInputAfterImage.size()); @@ -1011,20 +1005,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } } CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: set the input...Done"sv); - } - - // Delete the llama context. - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: delete llama context to make it stateless"sv); - } - llama_free(LlamaContext); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: delete llama context to make it stateless...Done"sv); - } + GraphRef.ComputeSingleStarted = false; if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: setInput...Done"sv); @@ -1099,20 +1080,20 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { "[WASI-NN][Debug] GGML backend: clear the previous output and tokens...Done"sv); } - // Initialize the llama context. - gpt_params GPTParams; - llama_context_params ContextParams = llama_context_default_params(); - setupGPTParam(GraphRef, GPTParams); - setupContextParam(GraphRef, ContextParams); - auto LlamaContext = - llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); - struct llama_sampling_context *CtxSampling = - llama_sampling_init(GPTParams.sparams); + // Clear the llama context. + llama_kv_cache_clear(GraphRef.LlamaContext); + + // Setup the parameters and sampler. + common_params Params; + setupParams(GraphRef, Params); + struct common_sampler *Sampler = + common_sampler_init(GraphRef.LlamaModel, Params.sparams); + // Prepare variables; int32_t NPast = 0; uint64_t NRemain = GraphRef.NPredict; // Get the context size. - const uint64_t NCtx = llama_n_ctx(LlamaContext); + const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext); // Minus 4 for the special tokens. (Such as , , ... tokens.) const uint64_t MaxTokensListSize = NCtx - 4; // Return value. @@ -1131,7 +1112,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Evaluate input tokens. if (CxtRef.LlavaImageEmbd == nullptr) { // Text only prompt. - ReturnCode = evaluateTokens(GraphRef, LlamaContext, + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, std::move(CxtRef.LlamaInputs), NPast); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -1146,7 +1127,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { std::vector EmbdInputAfterImage(CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition, CxtRef.LlamaInputs.end()); - ReturnCode = evaluateTokens(GraphRef, LlamaContext, + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, std::move(EmbdInputBeforeImage), NPast); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -1154,14 +1135,14 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { return ReturnCode; } bool EvalImageStatus = - llava_eval_image_embed(LlamaContext, CxtRef.LlavaImageEmbd, + llava_eval_image_embed(GraphRef.LlamaContext, CxtRef.LlavaImageEmbd, static_cast(GraphRef.BatchSize), &NPast); if (!EvalImageStatus) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); return ErrNo::RuntimeError; } - ReturnCode = evaluateTokens(GraphRef, LlamaContext, + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, std::move(EmbdInputAfterImage), NPast); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -1175,17 +1156,18 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { spdlog::info("[WASI-NN][Debug] GGML backend: enter main predict loop"sv); } while (NRemain > 0) { + // Use idx = -1 to sample the next token. const llama_token Id = - llama_sampling_sample(CtxSampling, LlamaContext, nullptr); - llama_sampling_accept(CtxSampling, LlamaContext, Id, true); + common_sampler_sample(Sampler, GraphRef.LlamaContext, /* idx */ -1); + common_sampler_accept(Sampler, Id, /* accept_grammar */ true); --NRemain; // Save the output token. CxtRef.LlamaOutputTokens.emplace_back(Id); - CxtRef.LlamaOutputs += llama_token_to_piece(LlamaContext, Id); + CxtRef.LlamaOutputs += common_token_to_piece(GraphRef.LlamaContext, Id); // When setting StreamStdout, we print the output to stdout. if (GraphRef.StreamStdout) { - fmt::print("{}"sv, llama_token_to_piece(LlamaContext, Id)); + fmt::print("{}"sv, common_token_to_piece(GraphRef.LlamaContext, Id)); std::fflush(stdout); } // Break if reverse prompt is found. @@ -1197,15 +1179,14 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { break; } // Deal with end of text token. - if (llama_token_is_eog(GraphRef.LlamaModel, - llama_sampling_last(CtxSampling))) { + if (llama_token_is_eog(GraphRef.LlamaModel, common_sampler_last(Sampler))) { if (GraphRef.EnableLog) { spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); } break; } // Evaluate the output token. - ReturnCode = evaluateTokens(GraphRef, LlamaContext, {Id}, NPast); + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, {Id}, NPast); if (ReturnCode != ErrNo::Success) { break; } @@ -1217,24 +1198,23 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // End of main predict loop. if (GraphRef.EnableLog) { - llama_print_timings(LlamaContext); + common_perf_print(GraphRef.LlamaContext, Sampler); } // We free the contexts here to keep the ggml plugin stateless. // Users could fully control the contexts by themselves via their prompt. if (GraphRef.EnableDebugLog) { spdlog::info( - "[WASI-NN][Debug] GGML backend: delete llama context to make it stateless"sv); + "[WASI-NN][Debug] GGML backend: delete llama sampler to make it stateless"sv); } - llama_sampling_free(CtxSampling); - llama_free(LlamaContext); + common_sampler_free(Sampler); if (CxtRef.LlavaImageEmbd != nullptr) { llava_image_embed_free(CxtRef.LlavaImageEmbd); CxtRef.LlavaImageEmbd = nullptr; } if (GraphRef.EnableDebugLog) { spdlog::info( - "[WASI-NN][Debug] GGML backend: delete llama context to make it stateless...Done"sv); + "[WASI-NN][Debug] GGML backend: delete llama sampler to make it stateless...Done"sv); } if (GraphRef.EnableDebugLog) { @@ -1272,8 +1252,8 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, } return ErrNo::Success; } - std::string LastToken = llama_token_to_piece(CxtRef.LlamaContext, - CxtRef.LlamaOutputTokens.back()); + std::string LastToken = common_token_to_piece( + GraphRef.LlamaContext, CxtRef.LlamaOutputTokens.back()); std::copy_n(LastToken.data(), LastToken.length(), OutBuffer.data()); BytesWritten = static_cast(LastToken.length()); if (GraphRef.EnableDebugLog) { @@ -1295,7 +1275,8 @@ Expect computeSingle(WasiNNEnvironment &Env, } // New compute single token context. - if (CxtRef.LlamaContext == nullptr) { + if (!GraphRef.ComputeSingleStarted) { + GraphRef.ComputeSingleStarted = true; // Check if the input is set before setting up the context. if (CxtRef.LlamaInputs.size() == 0) { spdlog::error("[WASI-NN] GGML backend: Llama input is not set!"sv); @@ -1314,18 +1295,18 @@ Expect computeSingle(WasiNNEnvironment &Env, "[WASI-NN][Debug] GGML backend: clear the previous output and tokens...Done"sv); } - // Initialize the llama context. - gpt_params GPTParams; - llama_context_params ContextParams = llama_context_default_params(); - setupGPTParam(GraphRef, GPTParams); - setupContextParam(GraphRef, ContextParams); - CxtRef.LlamaContext = - llama_new_context_with_model(GraphRef.LlamaModel, ContextParams); - CxtRef.LlamaSampling = llama_sampling_init(GPTParams.sparams); + // Clear the llama context. + llama_kv_cache_clear(GraphRef.LlamaContext); + + // Setup the parameters and sampler. + common_params Params; + setupParams(GraphRef, Params); + CxtRef.LlamaSampler = + common_sampler_init(GraphRef.LlamaModel, Params.sparams); CxtRef.LlamaNPast = 0; // Get the context size. - const uint64_t NCtx = llama_n_ctx(CxtRef.LlamaContext); + const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext); // Minus 4 for the special tokens. (Such as , , ... tokens.) const uint64_t MaxTokensListSize = NCtx - 4; // Return value. @@ -1345,7 +1326,7 @@ Expect computeSingle(WasiNNEnvironment &Env, if (CxtRef.LlavaImageEmbd == nullptr) { // Text only prompt. ReturnCode = - evaluateTokens(GraphRef, CxtRef.LlamaContext, + evaluateTokens(GraphRef, GraphRef.LlamaContext, std::move(CxtRef.LlamaInputs), CxtRef.LlamaNPast); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -1361,7 +1342,7 @@ Expect computeSingle(WasiNNEnvironment &Env, CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition, CxtRef.LlamaInputs.end()); ReturnCode = - evaluateTokens(GraphRef, CxtRef.LlamaContext, + evaluateTokens(GraphRef, GraphRef.LlamaContext, std::move(EmbdInputBeforeImage), CxtRef.LlamaNPast); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -1369,7 +1350,7 @@ Expect computeSingle(WasiNNEnvironment &Env, return ReturnCode; } bool EvalImageStatus = llava_eval_image_embed( - CxtRef.LlamaContext, CxtRef.LlavaImageEmbd, + GraphRef.LlamaContext, CxtRef.LlavaImageEmbd, static_cast(GraphRef.BatchSize), &CxtRef.LlamaNPast); if (!EvalImageStatus) { spdlog::error( @@ -1377,7 +1358,7 @@ Expect computeSingle(WasiNNEnvironment &Env, return ErrNo::RuntimeError; } ReturnCode = - evaluateTokens(GraphRef, CxtRef.LlamaContext, + evaluateTokens(GraphRef, GraphRef.LlamaContext, std::move(EmbdInputAfterImage), CxtRef.LlamaNPast); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -1392,17 +1373,18 @@ Expect computeSingle(WasiNNEnvironment &Env, spdlog::info("[WASI-NN][Debug] GGML backend: enter main predict process"sv); } auto ReturnCode = ErrNo::Success; - const llama_token Id = - llama_sampling_sample(CxtRef.LlamaSampling, CxtRef.LlamaContext, nullptr); - llama_sampling_accept(CxtRef.LlamaSampling, CxtRef.LlamaContext, Id, true); + // Use idx = -1 to sample the next token. + const llama_token Id = common_sampler_sample( + CxtRef.LlamaSampler, GraphRef.LlamaContext, /* idx */ -1); + common_sampler_accept(CxtRef.LlamaSampler, Id, /* accept_grammar */ true); // Save the output token. // In single token mode, we do not handle StreamStdout and ReversePrompt. CxtRef.LlamaOutputTokens.emplace_back(Id); - CxtRef.LlamaOutputs += llama_token_to_piece(CxtRef.LlamaContext, Id); + CxtRef.LlamaOutputs += common_token_to_piece(GraphRef.LlamaContext, Id); // Deal with end of text token. if (llama_token_is_eog(GraphRef.LlamaModel, - llama_sampling_last(CxtRef.LlamaSampling))) { + common_sampler_last(CxtRef.LlamaSampler))) { ReturnCode = ErrNo::EndOfSequence; if (GraphRef.EnableLog) { spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); @@ -1410,8 +1392,8 @@ Expect computeSingle(WasiNNEnvironment &Env, } // Evaluate the output token if not EOS. if (ReturnCode != ErrNo::EndOfSequence) { - ReturnCode = - evaluateTokens(GraphRef, CxtRef.LlamaContext, {Id}, CxtRef.LlamaNPast); + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, {Id}, + CxtRef.LlamaNPast); } if (GraphRef.EnableDebugLog) { spdlog::info( @@ -1436,7 +1418,7 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Logging for the llama timings. if (GraphRef.EnableLog) { - llama_print_timings(CxtRef.LlamaContext); + common_perf_print(GraphRef.LlamaContext, CxtRef.LlamaSampler); } // Clear the outputs. @@ -1451,15 +1433,15 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { "[WASI-NN][Debug] GGML backend: finiSingle: clear the previous output and tokens...Done"sv); } - // Delete the llama context. + // Clear the llama context. if (GraphRef.EnableDebugLog) { spdlog::info( - "[WASI-NN][Debug] GGML backend: finiSingle: free the llama context"sv); + "[WASI-NN][Debug] GGML backend: finiSingle: clear the llama context"sv); } - llama_sampling_free(CxtRef.LlamaSampling); - llama_free(CxtRef.LlamaContext); - CxtRef.LlamaSampling = nullptr; - CxtRef.LlamaContext = nullptr; + llama_kv_cache_clear(GraphRef.LlamaContext); + common_sampler_reset(CxtRef.LlamaSampler); + common_sampler_free(CxtRef.LlamaSampler); + CxtRef.LlamaSampler = nullptr; if (CxtRef.LlavaImageEmbd != nullptr) { llava_image_embed_free(CxtRef.LlavaImageEmbd); CxtRef.LlavaImageEmbd = nullptr; diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 7d2ad8ab..0504bc0f 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -11,6 +11,7 @@ #include #include #include +#include #endif namespace WasmEdge::Host::WASINN { @@ -23,11 +24,13 @@ namespace WasmEdge::Host::WASINN::GGML { struct Graph { llama_model *LlamaModel = nullptr; std::string ModelFilePath; + llama_context *LlamaContext = nullptr; // Plugin parameters: bool EnableLog = false; bool EnableDebugLog = false; bool StreamStdout = false; bool Embedding = false; + bool ComputeSingleStarted = false; uint64_t NPredict; std::string ReversePrompt; std::string MMProjModelPath; @@ -60,8 +63,7 @@ struct Context { std::string LlamaOutputs; std::vector LlamaOutputTokens; // Preserve for computing single token - llama_context *LlamaContext = nullptr; - struct llama_sampling_context *LlamaSampling = nullptr; + common_sampler *LlamaSampler = nullptr; int32_t LlamaNPast = 0; // Preserve for llava struct llava_image_embed *LlavaImageEmbd = nullptr; From 402f2745c0d5b0114863b5a39179542262391ce8 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Wed, 18 Sep 2024 16:31:01 +0800 Subject: [PATCH 482/623] [Docker] Ubuntu: Add 24.04 Signed-off-by: Yi Huang --- utils/docker/Dockerfile.ubuntu-base | 18 ++++++++++++++ utils/docker/docker-bake.ubuntu.hcl | 37 +++++++++++++++++++++++++---- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/utils/docker/Dockerfile.ubuntu-base b/utils/docker/Dockerfile.ubuntu-base index 416b6611..a6ed6278 100644 --- a/utils/docker/Dockerfile.ubuntu-base +++ b/utils/docker/Dockerfile.ubuntu-base @@ -50,6 +50,24 @@ RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-15 100 && ENV CC=/usr/bin/clang-15 ENV CXX=/usr/bin/clang++-15 +### deps for ubuntu 24.04 ### +FROM base AS deps-24 + +RUN apt-get install -y cmake + +RUN apt-get install -y \ + llvm-18-dev \ + liblld-18-dev \ + libpolly-18-dev \ + clang-18 + +RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-18 100 && \ + update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-18 100 && \ + update-alternatives --install /usr/bin/llvm-strip llvm-strip /usr/bin/llvm-strip-18 100 + +ENV CC=/usr/bin/clang-18 +ENV CXX=/usr/bin/clang++-18 + ### cleanup FROM deps-${UBUNTU_VER} AS clean-apt diff --git a/utils/docker/docker-bake.ubuntu.hcl b/utils/docker/docker-bake.ubuntu.hcl index d7ae601d..f37c51da 100644 --- a/utils/docker/docker-bake.ubuntu.hcl +++ b/utils/docker/docker-bake.ubuntu.hcl @@ -6,8 +6,35 @@ group "default" { } group "latest" { + targets = [ + "base-2404-clang", + ] +} + +group "focal" { + targets = [ + "base-2004-clang", + "base-2004-gcc", + "plugins-2004-clang", + "plugins-2004-gcc", + ] +} + +group "jammy" { targets = [ "base-2204-clang", + "base-2204-gcc", + "plugins-2204-clang", + "plugins-2204-gcc", + ] +} + +group "noble" { + targets = [ + "base-2404-clang", + "base-2404-gcc", + "plugins-2404-clang", + "plugins-2404-gcc", ] } @@ -23,12 +50,12 @@ function "major" { function "tags-latest" { params = [target, ubuntu, toolchain] - result = target == "base" && ubuntu == "22.04" && toolchain == "clang" ? "latest" : "" + result = target == "base" && ubuntu == "24.04" && toolchain == "clang" ? "latest" : "" } function "tags-latest-backports" { params = [target, ubuntu, toolchain] - result = ubuntu == "22.04" ? join("-", compact([ + result = ubuntu == "24.04" ? join("-", compact([ "ubuntu", "build", toolchain, @@ -67,7 +94,7 @@ target "base" { context = "./utils/docker" matrix = { - ubuntu = ["20.04", "22.04"] + ubuntu = ["20.04", "22.04", "24.04"] } name = "base-${no-dot(ubuntu)}" @@ -82,7 +109,7 @@ target "plugins" { context = "./utils" matrix = { - ubuntu = ["20.04", "22.04"] + ubuntu = ["20.04", "22.04", "24.04"] } name = "plugins-${no-dot(ubuntu)}" @@ -99,7 +126,7 @@ target "plugins" { target "final" { matrix = { parent = ["base", "plugins"] - ubuntu = ["20.04", "22.04"] + ubuntu = ["20.04", "22.04", "24.04"] toolchain = ["clang", "gcc"] } From 5724e965cd1918c3d7b6de1f48bad3b36a63f48d Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 28 Oct 2024 15:29:19 +0800 Subject: [PATCH 483/623] [WASI-NN] Whisper: fix the token timestamp option. Signed-off-by: YiYing He --- plugins/wasi_nn/wasinn_whisper.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/wasinn_whisper.cpp b/plugins/wasi_nn/wasinn_whisper.cpp index 3327096b..8a4c2404 100644 --- a/plugins/wasi_nn/wasinn_whisper.cpp +++ b/plugins/wasi_nn/wasinn_whisper.cpp @@ -148,6 +148,7 @@ void setWhisperParams(Context &CxtRef) noexcept { WParam.print_progress = false; WParam.thold_pt = ConfigRef.WordThreshold; WParam.max_len = ConfigRef.MaxSegmentLength; + WParam.token_timestamps = (WParam.max_len > 0); WParam.split_on_word = ConfigRef.SplitOnWord; WParam.translate = ConfigRef.Translate; WParam.language = ConfigRef.SpokenLanguage.c_str(); @@ -283,14 +284,17 @@ Expect parseMetadata(Config &ConfigRef, PrintParsedOption("duration"sv, ConfigRef.DurationMS); } if (Doc.at_key("max-context").error() == simdjson::SUCCESS) { - auto Err = - Doc["max-context"].get().get(ConfigRef.MaxTokenContext); + int64_t MaxContext = 0; + auto Err = Doc["max-context"].get().get(MaxContext); if (Err) { spdlog::error( "[WASI-NN] Whisper backend: Unable to retrieve the max-context option."sv); return ErrNo::InvalidArgument; } - PrintParsedOption("max-context"sv, ConfigRef.MaxTokenContext); + if (MaxContext >= 0) { + ConfigRef.MaxTokenContext = static_cast(MaxContext); + PrintParsedOption("max-context"sv, ConfigRef.MaxTokenContext); + } } if (Doc.at_key("max-len").error() == simdjson::SUCCESS) { auto Err = Doc["max-len"].get().get(ConfigRef.MaxSegmentLength); @@ -487,7 +491,10 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, } Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); ContextId = Env.NNContext.size() - 1; - setWhisperParams(Env.NNContext[ContextId].get()); + auto &CxtRef = Env.NNContext[ContextId].get(); + CxtRef.WhisperParams = whisper_full_default_params( + whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH); + setWhisperParams(CxtRef); if (GraphRef.WhisperConfig.EnableLog) { spdlog::info("[WASI-NN] Whisper backend: whisper_system_info: {}"sv, whisper_print_system_info()); From 126e93748f8f297f34f8233fe1b597839dc5a203 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 28 Oct 2024 15:30:41 +0800 Subject: [PATCH 484/623] [WASI-NN] Whisper: move the whisper.cpp linking out of header. Signed-off-by: YiYing He --- plugins/wasi_nn/wasinn_whisper.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/plugins/wasi_nn/wasinn_whisper.h b/plugins/wasi_nn/wasinn_whisper.h index e1fe78e3..f468e9e0 100644 --- a/plugins/wasi_nn/wasinn_whisper.h +++ b/plugins/wasi_nn/wasinn_whisper.h @@ -69,8 +69,7 @@ struct Context { // Whisper config. Inherit from the graph and accept metadata when setting // input. Config WhisperConfig; - whisper_full_params WhisperParams = whisper_full_default_params( - whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH); + whisper_full_params WhisperParams; // Recognition outputs. std::string Outputs; }; From 37acf1d6db4edd734c7f5762a9d7f0c2e331eec7 Mon Sep 17 00:00:00 2001 From: Tenderyi <1559342051@qq.com> Date: Thu, 31 Oct 2024 15:16:33 +0800 Subject: [PATCH 485/623] [Plugin] stable diffusion: achieve new functions (#3763) * sd functions: Using ESRGAN to upscale results, change the size of pics, batch_count Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp] achieve new functions & fix the format Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp] achieve new functions & fix the format2 Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp] achieve new functions & fix the format3 Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp] achieve new functions & fix the format Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp]Using ESRGAN to upscale results, change the size of pics, batch_count Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp] fix the params problems for upscaleModel Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp] fix the params problems for upscaleModel Signed-off-by: Tenderyi <1559342051@qq.com> * [Stable-Diffusion.cpp] fix the problems Signed-off-by: Tenderyi <1559342051@qq.com> * [Stable-Diffusion.cpp] fix the problems Signed-off-by: Tenderyi <1559342051@qq.com> * [Stable-Diffusion.cpp] fix problems Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp] reuse params Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp] fix CMAKE problems Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp] fix merge problems Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp] test for slove error Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp] add sv Signed-off-by: Tenderyi <1559342051@qq.com> * [OSPP][stable-difusion.cpp] add sv Signed-off-by: Tenderyi <1559342051@qq.com> --------- Signed-off-by: Tenderyi <1559342051@qq.com> Co-authored-by: Han-Wen Tsao Co-authored-by: dm4 --- .../wasmedge_stablediffusion/CMakeLists.txt | 2 +- plugins/wasmedge_stablediffusion/sd_env.cpp | 13 +- plugins/wasmedge_stablediffusion/sd_env.h | 15 +- plugins/wasmedge_stablediffusion/sd_func.cpp | 273 +++++++++++++----- plugins/wasmedge_stablediffusion/sd_func.h | 4 +- 5 files changed, 227 insertions(+), 80 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 6ff6af38..d0de98d9 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -130,4 +130,4 @@ install( TARGETS wasmedgePluginWasmEdgeStableDiffusion DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge COMPONENT WasmEdge -) +) \ No newline at end of file diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp index f99ae60d..ed1ec4a5 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.cpp +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -37,16 +37,23 @@ EXPORT_GET_DESCRIPTOR(Descriptor) namespace StableDiffusion { -uint32_t SDEnviornment::addContext(sd_ctx_t *Ctx) noexcept { - Contexts.push_back(Ctx); +uint32_t SDEnviornment::addContext(sd_ctx_t *Ctx, int32_t Nthreads, + uint32_t Wtype) noexcept { + Contexts.push_back({Ctx, Nthreads, Wtype}); return Contexts.size() - 1; } +void SDEnviornment::freeContext(const uint32_t Id) noexcept { + sd_ctx_t *SDCtx = Contexts[Id].Context; + free_sd_ctx(SDCtx); + Contexts.erase(Contexts.begin() + Id - 1); +} + sd_ctx_t *SDEnviornment::getContext(const uint32_t Id) noexcept { if (Id >= Contexts.size()) { return nullptr; } - return Contexts[Id]; + return Contexts[Id].Context; } void SBLog(enum sd_log_level_t Level, const char *Log, void *) { diff --git a/plugins/wasmedge_stablediffusion/sd_env.h b/plugins/wasmedge_stablediffusion/sd_env.h index 4c3a278c..b58dd28f 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.h +++ b/plugins/wasmedge_stablediffusion/sd_env.h @@ -23,6 +23,12 @@ enum class ErrNo : uint32_t { RuntimeError = 5, // Runtime Error. }; +struct ContextInfo { + sd_ctx_t *Context; + int32_t NThreads; + uint32_t Wtype; +}; + class SDEnviornment { public: SDEnviornment() noexcept { @@ -30,13 +36,18 @@ class SDEnviornment { sd_set_log_callback(SBLog, nullptr); } }; - uint32_t addContext(sd_ctx_t *Ctx) noexcept; + uint32_t addContext(sd_ctx_t *Ctx, int32_t Nthreads, uint32_t Wtype) noexcept; + void freeContext(const uint32_t Id) noexcept; sd_ctx_t *getContext(const uint32_t Id) noexcept; size_t getContextSize() noexcept { return Contexts.size(); } + int32_t getNThreads(const uint32_t Id) noexcept { + return Contexts[Id].NThreads; + } + uint32_t getWtype(const uint32_t Id) noexcept { return Contexts[Id].Wtype; } private: - std::vector Contexts; bool EnableSDLog = false; + std::vector Contexts; }; } // namespace StableDiffusion diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index 7a12898d..cccbe0e0 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -15,6 +15,10 @@ #define STB_IMAGE_WRITE_STATIC #include "stb_image_write.h" +#define STB_IMAGE_RESIZE_IMPLEMENTATION +#define STB_IMAGE_RESIZE_STATIC +#include "stb_image_resize.h" + namespace WasmEdge { namespace Host { namespace StableDiffusion { @@ -57,53 +61,135 @@ namespace StableDiffusion { bool parameterCheck(SDEnviornment &Env, uint32_t Width, uint32_t Height, uint32_t SessionId) { if (SessionId >= Env.getContextSize()) { - spdlog::error("[WasmEdge-StableDiffusion] Session ID is invalid."); + spdlog::error("[WasmEdge-StableDiffusion] Session ID is invalid."sv); return false; } if (Width % 64 != 0) { - spdlog::error("[WasmEdge-StableDiffusion] Width must be a multiple of 64 " - "and greater than 0."); + spdlog::error( + "[WasmEdge-StableDiffusion] Width must be a multiple of 64 and greater than 0."sv); return false; } if (Height % 64 != 0) { - spdlog::error("[WasmEdge-StableDiffusion] Height must be a multiple of 64 " - "and greater than 0."); + spdlog::error( + "[WasmEdge-StableDiffusion] Height must be a multiple of 64 and greater than 0."sv); return false; } return true; } -sd_image_t *readControlImage(Span ControlImage, - uint8_t *ControlImageBuf, int Width, int Height, +sd_image_t *readControlImage(Span ControlImage, int Width, int Height, bool CannyPreprocess) { + uint8_t *ControlImageBuffer = nullptr; sd_image_t *ControlImg = nullptr; int Channel = 0; std::string ControlImagePath(ControlImage.begin(), ControlImage.end()); if (ControlImagePath.substr(0, 5) == "path:"sv) { - ControlImageBuf = stbi_load(ControlImagePath.substr(5).data(), &Width, - &Height, &Channel, 3); + ControlImageBuffer = stbi_load(ControlImagePath.substr(5).data(), &Width, + &Height, &Channel, 3); } else { - ControlImageBuf = stbi_load_from_memory( + ControlImageBuffer = stbi_load_from_memory( ControlImage.data(), ControlImage.size(), &Width, &Height, &Channel, 3); } - if (ControlImageBuf == nullptr) { + if (ControlImageBuffer == nullptr) { spdlog::error( "[WasmEdge-StableDiffusion] Load image from control image failed."sv); return nullptr; } ControlImg = new sd_image_t{static_cast(Width), - static_cast(Height), 3, ControlImageBuf}; + static_cast(Height), 3, ControlImageBuffer}; if (CannyPreprocess) { // apply preprocessor ControlImg->data = preprocess_canny(ControlImg->data, ControlImg->width, ControlImg->height, 0.08f, 0.08f, 0.8f, 1.0f, false); } + free(ControlImageBuffer); return ControlImg; } +void upscalerModel(const char *UpscaleModelPath, uint32_t UpscaleRepeats, + int32_t NThreads, uint32_t Wtype, uint32_t BatchCount, + sd_image_t *Results) { + // unused for RealESRGAN_x4plus_anime_6B.pth + int UpscaleFactor = 4; + upscaler_ctx_t *UpscalerCtx = new_upscaler_ctx(UpscaleModelPath, NThreads, + static_cast(Wtype)); + if (UpscalerCtx == nullptr) { + spdlog::error("[WasmEdge-StableDiffusion] Create upscaler ctx failed."sv); + } else { + for (uint32_t I = 0; I < BatchCount; I++) { + if (Results[I].data == nullptr) { + continue; + } + sd_image_t CurrentImage = Results[I]; + for (uint32_t U = 0; U < UpscaleRepeats; ++U) { + sd_image_t UpscaledImage = + upscale(UpscalerCtx, CurrentImage, UpscaleFactor); + if (UpscaledImage.data == nullptr) { + spdlog::error("[WasmEdge-StableDiffusion] Upscale failed."sv); + break; + } + free(CurrentImage.data); + CurrentImage = UpscaledImage; + } + // Set the final upscaled image as the result + Results[I] = CurrentImage; + } + } + free(UpscalerCtx); +} + +bool saveResults(sd_image_t *Results, uint32_t BatchCount, + std::string OutputPath, uint32_t OutputPathLen, + uint32_t *BytesWritten, uint32_t OutBufferMaxSize, + uint8_t *OutputBufferSpanPtr) { + int Len; + unsigned char *Png = stbi_write_png_to_mem( + reinterpret_cast(Results), 0, Results->width, + Results->height, Results->channel, &Len, nullptr); + if (OutputPathLen != 0) { + size_t Last = OutputPath.find_last_of("."); + std::string DummyName = OutputPath; + std::string Extension = ".png"; + if (Last != std::string::npos) { + std::string LastStr = OutputPath.substr(Last); + if (LastStr == ".png" || LastStr == ".PNG") { + DummyName = OutputPath.substr(0, Last); + Extension = LastStr; + } + } + for (uint32_t I = 0; I < BatchCount; I++) { + if (Results[I].data != nullptr) { + std::string FinalImagePath; + if (I <= 0) + FinalImagePath = DummyName + Extension; + else + FinalImagePath += "_" + std::to_string(I + 1) + Extension; + stbi_write_png(FinalImagePath.c_str(), Results[I].width, + Results[I].height, Results[I].channel, Results[I].data, + 0, nullptr); + spdlog::info("[WasmEdge-StableDiffusion] Save result image to {}."sv, + FinalImagePath.c_str()); + free(Results[I].data); + Results[I].data = nullptr; + } + } + } + *BytesWritten = Len; + if (OutBufferMaxSize < *BytesWritten) { + spdlog::error("[WasmEdge-StableDiffusion] Output buffer is not enough."sv); + free(Png); + free(Results); + return false; + } + std::copy_n(Png, *BytesWritten, OutputBufferSpanPtr); + free(Png); + free(Results); + return true; +} + Expect SDConvert::body(const Runtime::CallingFrame &Frame, uint32_t ModelPathPtr, uint32_t ModelPathLen, uint32_t VaeModelPathPtr, @@ -134,7 +220,7 @@ Expect SDConvert::body(const Runtime::CallingFrame &Frame, std::ifstream Fin(ModelPath.data(), std::ios::in | std::ios::binary); if (!Fin) { Fin.close(); - spdlog::error("[WasmEdge-StableDiffusion] Model not found."); + spdlog::error("[WasmEdge-StableDiffusion] Model not found."sv); return static_cast(ErrNo::InvalidArgument); } Fin.close(); @@ -142,7 +228,7 @@ Expect SDConvert::body(const Runtime::CallingFrame &Frame, bool Ret = ::convert(ModelPath.data(), VaeModelPath.data(), OutputPath.data(), static_cast(WType)); if (!Ret) { - spdlog::error("[WasmEdge-StableDiffusion] Failed to convert model."); + spdlog::error("[WasmEdge-StableDiffusion] Failed to convert model."sv); return static_cast(ErrNo::InvalidArgument); } @@ -164,7 +250,6 @@ Expect SDCreateContext::body( uint32_t VaeOnCpu, uint32_t SessiontIdPtr) { // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) - // Check the input model buffer. MEM_SPAN_CHECK(ModelPathSpan, MemInst, char, ModelPathPtr, ModelPathLen, "Failed when accessing the input model path memory."sv) @@ -193,7 +278,6 @@ Expect SDCreateContext::body( "Failed when accessing the input id dembed directory memory."sv) MEM_PTR_CHECK(SessionId, MemInst, uint32_t, SessiontIdPtr, "Failed when accessing the return SessionID memory."sv) - std::string ModelPath = std::string(ModelPathSpan.begin(), ModelPathSpan.end()); std::string VaePath = std::string(VaePathSpan.begin(), VaePathSpan.end()); @@ -215,10 +299,14 @@ Expect SDCreateContext::body( if (NThreads == -1) { NThreads = get_num_physical_cores(); } - - spdlog::info("[WasmEdge-StableDiffusion] Create context."sv); + // Check parameters + if (ModelPathLen == 0 && diffusionModelPathLen == 0) { + spdlog::error( + "[WasmEdge-StableDiffusion] The following arguments are required: ModelPath / DiffusionModelPath"sv); + return static_cast(ErrNo::InvalidArgument); + } // Create context and import graph. - + spdlog::debug("[WasmEdge-StableDiffusion] Create context."sv); sd_ctx_t *Ctx = new_sd_ctx( ModelPath.data(), clipLPath.data(), t5xxlPath.data(), diffusionModelPath.data(), VaePath.data(), TaesdPath.data(), @@ -228,11 +316,10 @@ Expect SDCreateContext::body( static_cast(Wtype), static_cast(RngType), static_cast(Schedule), ClipOnCpu, ControlNetCpu, VaeOnCpu); if (Ctx == nullptr) { - spdlog::error("[WasmEdge-StableDiffusion] Failed to create context."); + spdlog::error("[WasmEdge-StableDiffusion] Failed to create context."sv); return static_cast(ErrNo::InvalidArgument); } - *SessionId = Env.addContext(Ctx); - + *SessionId = Env.addContext(Ctx, NThreads, static_cast(Wtype)); return static_cast(ErrNo::Success); } @@ -244,8 +331,9 @@ Expect SDTextToImage::body( uint32_t SampleMethod, uint32_t SampleSteps, uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesDirPtr, - uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, uint32_t, uint32_t, - uint32_t, uint32_t OutputPathPtr, uint32_t OutputPathLen, + uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, + uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, + uint32_t UpscaleRepeats, uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { // Check memory instance from module. @@ -279,41 +367,43 @@ Expect SDTextToImage::body( ErrNo::InvalidArgument) sd_image_t *Results = nullptr; sd_image_t *ControlImage = nullptr; - uint8_t *ControlImageBuffer = nullptr; + // Read control image if (ControlImageLen != 0) { MEM_SPAN_CHECK(ControlImageSpan, MemInst, uint8_t, ControlImagePtr, ControlImageLen, "Failed when accessing the control image memory."sv) - ControlImage = readControlImage(ControlImageSpan, ControlImageBuffer, Width, - Height, CannyPreprocess); + ControlImage = + readControlImage(ControlImageSpan, Width, Height, CannyPreprocess); } + // Generate images spdlog::info("[WasmEdge-StableDiffusion] Start to generate image."sv); Results = txt2img(SDCtx, Prompt.data(), NegativePrompt.data(), ClipSkip, CfgScale, Guidance, Width, Height, sample_method_t(SampleMethod), SampleSteps, Seed, BatchCount, ControlImage, ControlStrength, StyleRatio, NormalizeInput, InputIdImagesDir.data()); - // TODO upscale image - int Len; - unsigned char *Png = stbi_write_png_to_mem( - reinterpret_cast(Results->data), 0, Results->width, - Results->height, Results->channel, &Len, nullptr); - if (OutputPathLen != 0) { - stbi_write_png(OutputPath.data(), Results->width, Results->height, - Results->channel, Results->data, 0, nullptr); + free(ControlImage); + if (Results == nullptr) { + spdlog::error("[WasmEdge-StableDiffusion] Generate failed."sv); + Env.freeContext(SessionId); + return static_cast(ErrNo::RuntimeError); } - *BytesWritten = Len; - if (OutBufferMaxSize < *BytesWritten) { - spdlog::error("[WasmEdge-StableDiffusion] Output buffer is not enough."sv); - free(Png); - free(Results); - free(ControlImageBuffer); + // Upscale image + if (UpscaleModelPathLen != 0) { + MEM_SPAN_CHECK(UpscaleModelSpan, MemInst, char, UpscaleModelPathPtr, + UpscaleModelPathLen, + "Failed when accessing the Upscaler Image memory."sv) + std::string UpscaleModelPath(UpscaleModelSpan.begin(), + UpscaleModelSpan.end()); + upscalerModel(UpscaleModelPath.data(), UpscaleRepeats, + Env.getNThreads(SessionId), Env.getWtype(SessionId), + BatchCount, Results); + } + // Save results + if (!saveResults(Results, BatchCount, OutputPath, OutputPathLen, BytesWritten, + OutBufferMaxSize, OutputBufferSpan.data())) { return static_cast(ErrNo::RuntimeError); } - std::copy_n(Png, *BytesWritten, OutputBufferSpan.data()); - free(Png); - free(Results); - free(ControlImageBuffer); return static_cast(ErrNo::Success); } @@ -326,16 +416,15 @@ Expect SDImageToImage::body( uint32_t SampleSteps, float Strength, uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, - uint32_t CannyPreprocess, uint32_t, uint32_t, uint32_t, + uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, + uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) - // Check the input parameter valid. MEM_SPAN_CHECK(ImageSpan, MemInst, uint8_t, ImagePtr, ImageLen, "Failed when accessing the input image memory."sv) - MEM_SPAN_CHECK(PromptSpan, MemInst, char, PromptPtr, PromptLen, "Failed when accessing the promp memory."sv) MEM_SPAN_CHECK(NegativePromptSpan, MemInst, char, NegativePromptPtr, @@ -362,8 +451,8 @@ Expect SDImageToImage::body( std::string InputIdImagesDir(InputIdImagesDirSpan.begin(), InputIdImagesDirSpan.end()); std::string OutputPath(OutputPathSpan.begin(), OutputPathSpan.end()); + // Read input image uint8_t *InputImageBuffer = nullptr; - uint8_t *ControlImageBuffer = nullptr; int Channel = 0; int ImageWidth = 0; int ImageHeight = 0; @@ -376,21 +465,61 @@ Expect SDImageToImage::body( "[WasmEdge-StableDiffusion] Load image from input image failed."sv); return static_cast(ErrNo::InvalidArgument); } + if (Channel < 3) { + spdlog::error( + "[WasmEdge-StableDiffusion] The number of channels for the input image must be >= 3."sv); + free(InputImageBuffer); + return static_cast(ErrNo::InvalidArgument); + } + if (ImageWidth <= 0) { + spdlog::error( + "[WasmEdge-StableDiffusion] The width of image must be greater than 0."sv); + free(InputImageBuffer); + return static_cast(ErrNo::InvalidArgument); + } + if (ImageHeight <= 0) { + spdlog::error( + "[WasmEdge-StableDiffusion] The height of image must be greater than 0."sv); + free(InputImageBuffer); + return static_cast(ErrNo::InvalidArgument); + } + // Resize image when image size not matches width and height + if (Height != static_cast(ImageHeight) || + Width != static_cast(ImageWidth)) { + int ResizedHeight = Height; + int ResizedWidth = Width; + uint8_t *ResizedImageBuffer = + (uint8_t *)malloc(ResizedHeight * ResizedWidth * 3); + if (ResizedImageBuffer == nullptr) { + spdlog::error( + "[WasmEdge-StableDiffusion] Failed to allocate memory for resize input image."sv); + free(InputImageBuffer); + return static_cast(ErrNo::InvalidArgument); + } + stbir_resize(InputImageBuffer, ImageWidth, ImageHeight, 0, + ResizedImageBuffer, ResizedWidth, ResizedHeight, 0, + STBIR_TYPE_UINT8, 3, STBIR_ALPHA_CHANNEL_NONE, 0, + STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_FILTER_BOX, + STBIR_FILTER_BOX, STBIR_COLORSPACE_SRGB, nullptr); + free(InputImageBuffer); + InputImageBuffer = ResizedImageBuffer; + } } else { InputImageBuffer = stbi_load_from_memory(ImageSpan.data(), ImageSpan.size(), &ImageWidth, &ImageHeight, &Channel, 3); } - // TODO: Resize image when image size not matches width and height sd_image_t InputImage = {Width, Height, 3, InputImageBuffer}; + // Read control image sd_image_t *ControlImage = nullptr; if (ControlImageLen != 0) { MEM_SPAN_CHECK(ControlImageSpan, MemInst, uint8_t, ControlImagePtr, ControlImageLen, "Failed when accessing the control image memory."sv) - ControlImage = readControlImage(ControlImageSpan, ControlImageBuffer, Width, - Height, CannyPreprocess); + ControlImage = + readControlImage(ControlImageSpan, Width, Height, CannyPreprocess); } + // Generate images sd_image_t *Results = nullptr; spdlog::info("[WasmEdge-StableDiffusion] Start to generate image."sv); Results = img2img(SDCtx, InputImage, Prompt.data(), NegativePrompt.data(), @@ -398,29 +527,29 @@ Expect SDImageToImage::body( sample_method_t(SampleMethod), SampleSteps, Strength, Seed, BatchCount, ControlImage, ControlStrength, StyleRatio, NormalizeInput, InputIdImagesDir.data()); - // TODO: upscale image - int Len; - unsigned char *Png = stbi_write_png_to_mem( - reinterpret_cast(Results->data), 0, Results->width, - Results->height, Results->channel, &Len, nullptr); - if (OutputPathLen != 0) { - stbi_write_png(OutputPath.data(), Results->width, Results->height, - Results->channel, Results->data, 0, nullptr); + free(ControlImage); + free(InputImageBuffer); + if (Results == nullptr) { + spdlog::error("[WasmEdge-StableDiffusion] Generate failed."sv); + Env.freeContext(SessionId); + return static_cast(ErrNo::RuntimeError); } - *BytesWritten = Len; - if (OutBufferMaxSize < *BytesWritten) { - spdlog::error("[WasmEdge-StableDiffusion] Output buffer is not enough."sv); - free(Png); - free(Results); - free(InputImageBuffer); - free(ControlImageBuffer); + // Upscale image + if (UpscaleModelPathLen != 0) { + MEM_SPAN_CHECK(UpscaleModelSpan, MemInst, char, UpscaleModelPathPtr, + UpscaleModelPathLen, + "Failed when accessing the Upscaler Image memory."sv) + std::string UpscaleModelPath(UpscaleModelSpan.begin(), + UpscaleModelSpan.end()); + upscalerModel(UpscaleModelPath.data(), UpscaleRepeats, + Env.getNThreads(SessionId), Env.getWtype(SessionId), + BatchCount, Results); + } + // Save results + if (!saveResults(Results, BatchCount, OutputPath, OutputPathLen, BytesWritten, + OutBufferMaxSize, OutputBufferSpan.data())) { return static_cast(ErrNo::RuntimeError); } - std::copy_n(Png, *BytesWritten, OutputBufferSpan.data()); - free(Png); - free(Results); - free(InputImageBuffer); - free(ControlImageBuffer); return static_cast(ErrNo::Success); } diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h index 70b9d831..13586352 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -3,9 +3,9 @@ #pragma once -#include "sd_base.h" - #include "runtime/callingframe.h" +#include "sd_base.h" +#include "stable-diffusion.h" namespace WasmEdge { namespace Host { From f023c2865e983e22ac0f03576242891fc1d59862 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 6 Nov 2024 14:27:00 +0800 Subject: [PATCH 486/623] [WASI-NN] ggml: bump llama.cpp b4034 Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 10 ++++------ test/plugins/wasi_nn/CMakeLists.txt | 6 ++++++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index e7c1e090..5006f593 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -402,13 +402,11 @@ ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, if (NEval > static_cast(GraphRef.BatchSize)) { NEval = static_cast(GraphRef.BatchSize); } - // llama_batch_get_one(*token, n_tokens, position, sequence_id) - // This will return batch for single sequence of tokens starting at - // position. - const llama_seq_id SequenceId = 0; + // llama_batch_get_one(*token, n_tokens) + // - Return batch for single sequence of tokens. + // - The sequence ID will be fixed to 0. auto Status = - llama_decode(LlamaContext, - llama_batch_get_one(&Tokens[I], NEval, NPast, SequenceId)); + llama_decode(LlamaContext, llama_batch_get_one(&Tokens[I], NEval)); if (Status == 1) { spdlog::error( "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 4de4d4c4..432c008b 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -68,6 +68,12 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") target_compile_options(wasiNNTests PUBLIC /wd4067 # unexpected tokens following preprocessor directive - expected a newline + /wd4505 # unreferenced local function has been removed + ) + else() + # string_split in common.h unused + target_compile_options(wasiNNTests PUBLIC + -Wno-unused-function ) endif() elseif(BACKEND STREQUAL "piper") From a3b44858fabfe06d3bdb81ce8d70547ff3744245 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Mon, 4 Nov 2024 17:44:45 +0800 Subject: [PATCH 487/623] [Docker] Install FFmpeg 6.1 via apt on Ubuntu 24.04 Signed-off-by: Yi Huang --- utils/docker/Dockerfile.ubuntu-plugins-deps | 28 ++++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index 1ba0da63..04c36940 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -20,15 +20,35 @@ RUN apt-get install -y \ libgrpc-dev \ protobuf-compiler-grpc -COPY opencvmini/install-opencvmini.sh . -ENV OPENCV_VERSION="4.8.0" -RUN [ "/bin/bash", "install-opencvmini.sh" ] +# FFmpeg 6.1 (ubuntu 24.04) +FROM base AS deps-24 + +RUN apt-get install -y \ + libavcodec-dev \ + libavdevice-dev \ + libavfilter-dev \ + libavformat-dev \ + libavutil-dev \ + libswresample-dev \ + libswscale-dev + +# FFmpeg 6.0 (ubuntu 20.04, 22.04) +FROM base AS deps-20 COPY ffmpeg/install-ffmpeg-v6.0.sh . RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] ENV PKG_CONFIG_PATH=/root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} ENV LD_LIBRARY_PATH=/root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} +FROM deps-20 AS deps-22 + +# Other dependencies +FROM deps-${UBUNTU_VER} AS deps-all + +COPY opencvmini/install-opencvmini.sh . +ENV OPENCV_VERSION="4.8.0" +RUN [ "/bin/bash", "install-opencvmini.sh" ] + COPY wasi-nn/install-pytorch.sh . ENV PYTORCH_VERSION="2.4.1" ENV PYTORCH_INSTALL_TO="/root" @@ -47,7 +67,7 @@ COPY wasi-nn/install-onnxruntime.sh . RUN [ "/bin/bash", "install-onnxruntime.sh" ] ### cleanup -FROM base AS clean-apt +FROM deps-all AS clean-apt RUN rm -f \ install-opencvmini.sh \ From 23e177e0668953e8fcf50dfaf2fb4922e41197d2 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Tue, 5 Nov 2024 15:34:06 +0800 Subject: [PATCH 488/623] [Docker] Ubuntu: Install deps to /usr/local instead of /root * Install FFmpeg 6.0 to /usr/local * Install PyTorch to /usr/local * Fix clean-up files for install-opencvmini.sh Signed-off-by: Yi Huang --- utils/docker/Dockerfile.ubuntu-plugins-deps | 14 ++++++++------ utils/opencvmini/install-opencvmini.sh | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index 04c36940..4b321f20 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -2,8 +2,6 @@ ARG BASE_IMAGE=wasmedge/wasmedge:latest ARG UBUNTU_VER=20 FROM ${BASE_IMAGE} AS base -WORKDIR /root - RUN apt-get update && \ apt-get install -y \ cargo \ @@ -35,24 +33,28 @@ RUN apt-get install -y \ # FFmpeg 6.0 (ubuntu 20.04, 22.04) FROM base AS deps-20 +WORKDIR /usr/local + COPY ffmpeg/install-ffmpeg-v6.0.sh . RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] -ENV PKG_CONFIG_PATH=/root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -ENV LD_LIBRARY_PATH=/root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} +ENV PKG_CONFIG_PATH=/usr/local/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV LD_LIBRARY_PATH=/usr/local/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} FROM deps-20 AS deps-22 # Other dependencies FROM deps-${UBUNTU_VER} AS deps-all +WORKDIR /root + COPY opencvmini/install-opencvmini.sh . ENV OPENCV_VERSION="4.8.0" RUN [ "/bin/bash", "install-opencvmini.sh" ] COPY wasi-nn/install-pytorch.sh . ENV PYTORCH_VERSION="2.4.1" -ENV PYTORCH_INSTALL_TO="/root" -ENV Torch_DIR="/root/libtorch" +ENV PYTORCH_INSTALL_TO="/usr/local" +ENV Torch_DIR="/usr/local/libtorch" RUN [ "/bin/bash", "install-pytorch.sh" ] ARG UBUNTU_VER diff --git a/utils/opencvmini/install-opencvmini.sh b/utils/opencvmini/install-opencvmini.sh index 6e47fd99..43afef93 100644 --- a/utils/opencvmini/install-opencvmini.sh +++ b/utils/opencvmini/install-opencvmini.sh @@ -16,4 +16,4 @@ cmake --build . # Install to system cmake --install . -rm -f opencv.zip +cd - && rm -rf opencv opencv.zip From 5a1e8892362e8f0f8a3d689816e5178297f1892f Mon Sep 17 00:00:00 2001 From: grorge Date: Thu, 21 Nov 2024 16:11:20 +0800 Subject: [PATCH 489/623] [Plugin] Stable Diffusion: support clip_g option Signed-off-by: grorge --- plugins/wasmedge_stablediffusion/CMakeLists.txt | 2 +- plugins/wasmedge_stablediffusion/sd_func.cpp | 14 +++++++++----- plugins/wasmedge_stablediffusion/sd_func.h | 8 ++++---- .../wasmedge_stablediffusion.cpp | 4 ++++ 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index d0de98d9..c51a8bb0 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -32,7 +32,7 @@ message(STATUS "Downloading stable diffusion source") FetchContent_Declare( stable-diffusion GIT_REPOSITORY https://github.com/leejet/stable-diffusion.cpp.git - GIT_TAG 14206fd48832ab600d9db75f15acb5062ae2c296 + GIT_TAG ac54e0076052a196b7df961eb1f792c9ff4d7f22 GIT_SHALLOW TRUE ) FetchContent_MakeAvailable(stable-diffusion) diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index cccbe0e0..62455cda 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -238,10 +238,10 @@ Expect SDConvert::body(const Runtime::CallingFrame &Frame, Expect SDCreateContext::body( const Runtime::CallingFrame &Frame, uint32_t ModelPathPtr, uint32_t ModelPathLen, uint32_t clipLPathPtr, uint32_t clipLPathLen, - uint32_t t5xxlPathPtr, uint32_t t5xxlPathLen, - uint32_t diffusionModelPathPtr, uint32_t diffusionModelPathLen, - uint32_t VaePathPtr, uint32_t VaePathLen, uint32_t TaesdPathPtr, - uint32_t TaesdPathLen, uint32_t ControlNetPathPtr, + uint32_t clipGPathPtr, uint32_t clipGPathLen, uint32_t t5xxlPathPtr, + uint32_t t5xxlPathLen, uint32_t diffusionModelPathPtr, + uint32_t diffusionModelPathLen, uint32_t VaePathPtr, uint32_t VaePathLen, + uint32_t TaesdPathPtr, uint32_t TaesdPathLen, uint32_t ControlNetPathPtr, uint32_t ControlNetPathLen, uint32_t LoraModelDirPtr, uint32_t LoraModelDirLen, uint32_t EmbedDirPtr, uint32_t EmbedDirLen, uint32_t IdEmbedDirPtr, uint32_t IdEmbedDirLen, uint32_t VaeDecodeOnly, @@ -255,6 +255,8 @@ Expect SDCreateContext::body( "Failed when accessing the input model path memory."sv) MEM_SPAN_CHECK(clipLPathSpan, MemInst, char, clipLPathPtr, clipLPathLen, "Failed when accessing the input clipL path memory."sv) + MEM_SPAN_CHECK(clipGPathSpan, MemInst, char, clipGPathPtr, clipGPathLen, + "Failed when accessing the input clipG path memory."sv) MEM_SPAN_CHECK(t5xxlPathSpan, MemInst, char, t5xxlPathPtr, t5xxlPathLen, "Failed when accessing the input t5xxl path memory."sv) MEM_SPAN_CHECK( @@ -292,6 +294,8 @@ Expect SDCreateContext::body( std::string(IdEmbedDirSpan.begin(), IdEmbedDirSpan.end()); std::string clipLPath = std::string(clipLPathSpan.begin(), clipLPathSpan.end()); + std::string clipGPath = + std::string(clipGPathSpan.begin(), clipGPathSpan.end()); std::string t5xxlPath = std::string(t5xxlPathSpan.begin(), t5xxlPathSpan.end()); std::string diffusionModelPath = @@ -308,7 +312,7 @@ Expect SDCreateContext::body( // Create context and import graph. spdlog::debug("[WasmEdge-StableDiffusion] Create context."sv); sd_ctx_t *Ctx = new_sd_ctx( - ModelPath.data(), clipLPath.data(), t5xxlPath.data(), + ModelPath.data(), clipLPath.data(), clipGPath.data(), t5xxlPath.data(), diffusionModelPath.data(), VaePath.data(), TaesdPath.data(), ControlNetPath.data(), LoraModelDir.data(), EmbedDir.data(), IdEmbedDir.data(), static_cast(VaeDecodeOnly), diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h index 13586352..79592790 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -17,10 +17,10 @@ class SDCreateContext : public StableDiffusion::Func { Expect body(const Runtime::CallingFrame &Frame, uint32_t ModelPathPtr, uint32_t ModelPathLen, uint32_t clipLPathPtr, uint32_t clipLPathLen, - uint32_t t5xxlPathPtr, uint32_t t5xxlPathLen, - uint32_t diffusionModelPathPtr, uint32_t diffusionModelPathLen, - uint32_t VaePathPtr, uint32_t VaePathLen, uint32_t TaesdPathPtr, - uint32_t TaesdPathLen, uint32_t ControlNetPathPtr, + uint32_t clipGPathPtr, uint32_t clipGPathLen, uint32_t t5xxlPathPtr, + uint32_t t5xxlPathLen, uint32_t diffusionModelPathPtr, + uint32_t diffusionModelPathLen, uint32_t VaePathPtr, uint32_t VaePathLen, + uint32_t TaesdPathPtr, uint32_t TaesdPathLen, uint32_t ControlNetPathPtr, uint32_t ControlNetPathLen, uint32_t LoraModelDirPtr, uint32_t LoraModelDirLen, uint32_t EmbedDirPtr, uint32_t EmbedDirLen, uint32_t IdEmbedDirPtr, uint32_t IdEmbedDirLen, uint32_t VaeDecodeOnly, diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index af8ce453..1413b737 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -157,6 +157,8 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { static_cast(QuantModelPath.size()), // ModelPathLen 0, // ClipLPathPtr 0, // ClipLPathLen + 0, // ClipGPathPtr + 0, // ClipGPathLen 0, // T5xxlPathPtr 0, // T5xxlPathLen 0, // DiffusionModelPathPtr @@ -295,6 +297,8 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { static_cast(QuantModelPath.size()), // ModelPathLen 0, // ClipLPathPtr 0, // ClipLPathLen + 0, // ClipGPathPtr + 0, // ClipGPathLen 0, // T5xxlPathPtr 0, // T5xxlPathLen 0, // DiffusionModelPathPtr From 1461d07f7483e106ccf384767c3c099fd6d65fe3 Mon Sep 17 00:00:00 2001 From: grorge Date: Fri, 22 Nov 2024 18:58:32 +0800 Subject: [PATCH 490/623] [Plugin] Stable Diffusion: Bump the version Signed-off-by: grorge --- plugins/wasmedge_stablediffusion/sd_env.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp index ed1ec4a5..3a936bd4 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.cpp +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -17,7 +17,7 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .Name = "wasmedge_stablediffusion", .Description = "Stable Diffusion plug-in for WasmEdge.", .APIVersion = Plugin::Plugin::CurrentAPIVersion, - .Version = {0, 1, 0, 0}, + .Version = {0, 2, 0, 0}, .ModuleCount = 1, .ModuleDescriptions = (Plugin::PluginModule::ModuleDescriptor[]){ From 932e3f0dc391dce5801ea18d27a2ff2bbe28c151 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 26 Nov 2024 22:54:04 +0800 Subject: [PATCH 491/623] [WASI-NN] ggml: bump llama.cpp b4179 Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 5006f593..68a8388c 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -344,11 +344,12 @@ Expect setupParams(Graph &GraphRef, common_params &Params) { Params.cpuparams.n_threads = static_cast(GraphRef.Threads); Params.cpuparams_batch.n_threads = static_cast(GraphRef.Threads); Params.embedding = GraphRef.Embedding; - Params.sparams.temp = static_cast(GraphRef.Temp); - Params.sparams.top_p = static_cast(GraphRef.TopP); - Params.sparams.penalty_repeat = static_cast(GraphRef.RepeatPenalty); - Params.sparams.penalty_present = static_cast(GraphRef.PresencePenalty); - Params.sparams.grammar = GraphRef.Grammar; + Params.sampling.temp = static_cast(GraphRef.Temp); + Params.sampling.top_p = static_cast(GraphRef.TopP); + Params.sampling.penalty_repeat = static_cast(GraphRef.RepeatPenalty); + Params.sampling.penalty_present = + static_cast(GraphRef.PresencePenalty); + Params.sampling.grammar = GraphRef.Grammar; return ErrNo::Success; } @@ -700,7 +701,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.BatchSize = ContextDefault.n_batch; GraphRef.Threads = ContextDefault.n_threads; // Initialize the sampling parameters. - const common_sampler_params SamplerDefault; + const common_params_sampling SamplerDefault; GraphRef.Temp = SamplerDefault.temp; GraphRef.TopP = SamplerDefault.top_p; GraphRef.RepeatPenalty = SamplerDefault.penalty_repeat; @@ -1085,7 +1086,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { common_params Params; setupParams(GraphRef, Params); struct common_sampler *Sampler = - common_sampler_init(GraphRef.LlamaModel, Params.sparams); + common_sampler_init(GraphRef.LlamaModel, Params.sampling); // Prepare variables; int32_t NPast = 0; @@ -1300,7 +1301,7 @@ Expect computeSingle(WasiNNEnvironment &Env, common_params Params; setupParams(GraphRef, Params); CxtRef.LlamaSampler = - common_sampler_init(GraphRef.LlamaModel, Params.sparams); + common_sampler_init(GraphRef.LlamaModel, Params.sampling); CxtRef.LlamaNPast = 0; // Get the context size. From 2fb49df448f9de204d305d43bd70d9aa0974f4fa Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 3 Dec 2024 11:06:45 +0800 Subject: [PATCH 492/623] [WASI-NN] ggml: add the warmup option Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 10 ++++++++++ plugins/wasi_nn/wasinn_ggml.h | 1 + 2 files changed, 11 insertions(+) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 68a8388c..03077b28 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -79,6 +79,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // main-gpu: int64_t // tensor-split: string, comma-separated floating number list // use-mmap: use mmap + // warmup: bool // Context parameters (used by the llama context): // ctx-size: uint64_t // batch-size: uint64_t @@ -226,6 +227,14 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } + if (Doc.at_key("warmup").error() == simdjson::SUCCESS) { + auto Err = Doc["warmup"].get().get(GraphRef.WarmUp); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the warmup option."sv); + return ErrNo::InvalidArgument; + } + } // The context parameters. if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { @@ -341,6 +350,7 @@ Expect setupParams(Graph &GraphRef, common_params &Params) { Params.n_ctx = static_cast(GraphRef.CtxSize); Params.n_batch = static_cast(GraphRef.BatchSize); Params.n_ubatch = static_cast(GraphRef.UBatchSize); + Params.warmup = GraphRef.WarmUp; Params.cpuparams.n_threads = static_cast(GraphRef.Threads); Params.cpuparams_batch.n_threads = static_cast(GraphRef.Threads); Params.embedding = GraphRef.Embedding; diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 0504bc0f..89223605 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -40,6 +40,7 @@ struct Graph { int64_t NGPULayers = 0; std::vector TensorSplit; bool UseMMap = true; + bool WarmUp = true; // Context parameters: uint64_t CtxSize; uint64_t BatchSize; From 80d59b74bfc0800140cd27fc6d72160d5964a97b Mon Sep 17 00:00:00 2001 From: LFsWang <7088579+LFsWang@users.noreply.github.com> Date: Wed, 4 Dec 2024 13:52:33 +0800 Subject: [PATCH 493/623] [Plugin] Update pytorch version (#3901) Signed-off-by: Sylveon --- utils/docker/Dockerfile.manylinux2014-build-plugins-deps | 2 +- utils/docker/Dockerfile.manylinux_2_28-plugins-deps | 2 +- utils/docker/Dockerfile.ubuntu-plugins-deps | 2 +- utils/wasi-nn/install-pytorch.sh | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps index 44b24d4f..38223ed2 100644 --- a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps @@ -18,7 +18,7 @@ ENV OPENCV_VERSION "4.8.0" RUN [ "/bin/bash", "install-opencvmini.sh" ] COPY wasi-nn/install-pytorch.sh . -ENV PYTORCH_VERSION "2.4.1" +ENV PYTORCH_VERSION "2.5.1" ENV PYTORCH_INSTALL_TO "/root" ENV Torch_DIR "/root/libtorch" RUN [ "/bin/bash", "install-pytorch.sh", "--disable-cxx11-abi" ] diff --git a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps index 5db52278..ac6791fb 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps @@ -9,7 +9,7 @@ RUN cd && (yum check-update || true) && \ yum install -y wget unzip zlib-devel zlib-static elfutils-libelf-devel COPY wasi-nn/install-pytorch.sh . -ENV PYTORCH_VERSION="2.4.1" +ENV PYTORCH_VERSION="2.5.1" ENV PYTORCH_INSTALL_TO="/root" ENV Torch_DIR="/root/libtorch" RUN [ "/bin/bash", "install-pytorch.sh", "--disable-cxx11-abi" ] diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index 4b321f20..940e2157 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -52,7 +52,7 @@ ENV OPENCV_VERSION="4.8.0" RUN [ "/bin/bash", "install-opencvmini.sh" ] COPY wasi-nn/install-pytorch.sh . -ENV PYTORCH_VERSION="2.4.1" +ENV PYTORCH_VERSION="2.5.1" ENV PYTORCH_INSTALL_TO="/usr/local" ENV Torch_DIR="/usr/local/libtorch" RUN [ "/bin/bash", "install-pytorch.sh" ] diff --git a/utils/wasi-nn/install-pytorch.sh b/utils/wasi-nn/install-pytorch.sh index 1f6ac1d6..55f77ac3 100755 --- a/utils/wasi-nn/install-pytorch.sh +++ b/utils/wasi-nn/install-pytorch.sh @@ -3,7 +3,7 @@ # SPDX-FileCopyrightText: 2019-2024 Second State INC if [[ ! -n ${PYTORCH_VERSION} ]]; then - PYTORCH_VERSION="2.4.1" + PYTORCH_VERSION="2.5.1" fi if [[ ! -n ${PYTORCH_INSTALL_TO} ]]; then @@ -11,13 +11,13 @@ if [[ ! -n ${PYTORCH_INSTALL_TO} ]]; then fi PYTORCH_LINK="libtorch-cxx11-abi" -PYTORCH_SHA="415c3ed51c766a6ef20dc10b2e60fae7f10a3ae8aa62223d6f4bccc1fc98740b" +PYTORCH_SHA="618ca54eef82a1dca46ff1993d5807d9c0deb0bae147da4974166a147cb562fa" for i in "$@"; do case $i in --disable-cxx11-abi) PYTORCH_LINK="libtorch" - PYTORCH_SHA="f49d55df661c566c29a7a75bcae2fad69177eaebd330618d42ca162eb3a1fad1" + PYTORCH_SHA="21d05ad61935fc70912c779443dba112bda9c9ec1c999345d724935828f81c55" shift ;; esac From 6e8ef7b90a1cd2d9c4353ecd45741725bb919870 Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 4 Dec 2024 17:28:34 +0800 Subject: [PATCH 494/623] [WASI-NN] ggml: disable warmup by default to match previous behavior Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 89223605..7a65b328 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -40,7 +40,7 @@ struct Graph { int64_t NGPULayers = 0; std::vector TensorSplit; bool UseMMap = true; - bool WarmUp = true; + bool WarmUp = false; // Context parameters: uint64_t CtxSize; uint64_t BatchSize; From 0af6d36fea164d495ea6fd25eef6cf250bf7944d Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 4 Dec 2024 17:56:08 +0800 Subject: [PATCH 495/623] [WASI-NN] ggml: reload llama context if embedding status changed Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 77 ++++++++++++++++++++++----------- 1 file changed, 51 insertions(+), 26 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 03077b28..5892b2b9 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -51,8 +51,28 @@ void LlamaLogCallback(ggml_log_level LogLevel, const char *LogText, } } +Expect setupParams(Graph &GraphRef, common_params &Params) { + Params.model = GraphRef.ModelFilePath; + Params.n_gpu_layers = static_cast(GraphRef.NGPULayers); + Params.n_ctx = static_cast(GraphRef.CtxSize); + Params.n_batch = static_cast(GraphRef.BatchSize); + Params.n_ubatch = static_cast(GraphRef.UBatchSize); + Params.warmup = GraphRef.WarmUp; + Params.cpuparams.n_threads = static_cast(GraphRef.Threads); + Params.cpuparams_batch.n_threads = static_cast(GraphRef.Threads); + Params.embedding = GraphRef.Embedding; + Params.sampling.temp = static_cast(GraphRef.Temp); + Params.sampling.top_p = static_cast(GraphRef.TopP); + Params.sampling.penalty_repeat = static_cast(GraphRef.RepeatPenalty); + Params.sampling.penalty_present = + static_cast(GraphRef.PresencePenalty); + Params.sampling.grammar = GraphRef.Grammar; + return ErrNo::Success; +} + Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, - bool *IsModelUpdated = nullptr) noexcept { + bool *IsModelUpdated = nullptr, + bool *IsContextUpdated = nullptr) noexcept { simdjson::dom::parser Parser; simdjson::dom::element Doc; auto ParseError = Parser.parse(Metadata).get(Doc); @@ -94,11 +114,8 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // grammar: string // Get the current llama parameters. - llama_model_params ModelParams = llama_model_default_params(); - ModelParams.n_gpu_layers = static_cast(GraphRef.NGPULayers); - ModelParams.main_gpu = static_cast(GraphRef.MainGPU); - ModelParams.tensor_split = GraphRef.TensorSplit.data(); - ModelParams.use_mmap = GraphRef.UseMMap; + common_params Params; + setupParams(GraphRef, Params); // The plugin parameters. if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { @@ -337,29 +354,15 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } // Check if the model is updated. - if (IsModelUpdated && ModelParams.n_gpu_layers != GraphRef.NGPULayers) { + if (IsModelUpdated && Params.n_gpu_layers != GraphRef.NGPULayers) { *IsModelUpdated = true; } - return ErrNo::Success; -} + // Check if the context parameters are updated. + if (IsContextUpdated && Params.embedding != GraphRef.Embedding) { + *IsContextUpdated = true; + } -Expect setupParams(Graph &GraphRef, common_params &Params) { - Params.model = GraphRef.ModelFilePath; - Params.n_gpu_layers = static_cast(GraphRef.NGPULayers); - Params.n_ctx = static_cast(GraphRef.CtxSize); - Params.n_batch = static_cast(GraphRef.BatchSize); - Params.n_ubatch = static_cast(GraphRef.UBatchSize); - Params.warmup = GraphRef.WarmUp; - Params.cpuparams.n_threads = static_cast(GraphRef.Threads); - Params.cpuparams_batch.n_threads = static_cast(GraphRef.Threads); - Params.embedding = GraphRef.Embedding; - Params.sampling.temp = static_cast(GraphRef.Temp); - Params.sampling.top_p = static_cast(GraphRef.TopP); - Params.sampling.penalty_repeat = static_cast(GraphRef.RepeatPenalty); - Params.sampling.penalty_present = - static_cast(GraphRef.PresencePenalty); - Params.sampling.grammar = GraphRef.Grammar; return ErrNo::Success; } @@ -850,6 +853,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } bool IsModelParamsUpdated = false; + bool IsContextParamsUpdated = false; // Use index 1 for metadata. if (Index == 1) { if (GraphRef.EnableDebugLog) { @@ -858,7 +862,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); - auto Res = parseMetadata(GraphRef, Metadata, &IsModelParamsUpdated); + auto Res = parseMetadata(GraphRef, Metadata, &IsModelParamsUpdated, + &IsContextParamsUpdated); if (Res != ErrNo::Success) { spdlog::error("[WASI-NN] GGML backend: Failed to parse metadata."sv); @@ -888,6 +893,26 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } #endif + // Some changes of context parameters will require the context to be + // reloaded. + if (IsContextParamsUpdated) { + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: Reloaded model due to parameters change."sv); + } + llama_free(GraphRef.LlamaContext); + common_params Params; + setupParams(GraphRef, Params); + llama_new_context_with_model(GraphRef.LlamaModel, + common_context_params_to_llama(Params)); + if (GraphRef.LlamaContext == nullptr) { + spdlog::error( + "[WASI-NN] GGML backend: Error: unable to init context."sv); + Env.NNGraph.pop_back(); + return ErrNo::InvalidArgument; + } + } + if (GraphRef.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] GGML backend: found Metadata, processing...Done"sv); From 55a136874714796ba1e9dae37bc128b0a307ba17 Mon Sep 17 00:00:00 2001 From: dm4 Date: Fri, 6 Dec 2024 15:45:56 +0800 Subject: [PATCH 496/623] [WASI-NN] ggml: fix reloading llama context Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 5892b2b9..2d544c81 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -903,8 +903,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, llama_free(GraphRef.LlamaContext); common_params Params; setupParams(GraphRef, Params); - llama_new_context_with_model(GraphRef.LlamaModel, - common_context_params_to_llama(Params)); + GraphRef.LlamaContext = llama_new_context_with_model( + GraphRef.LlamaModel, common_context_params_to_llama(Params)); if (GraphRef.LlamaContext == nullptr) { spdlog::error( "[WASI-NN] GGML backend: Error: unable to init context."sv); From 988b819d112ccb3c255fef22da42937057085073 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Mon, 9 Dec 2024 17:45:33 +0800 Subject: [PATCH 497/623] [Lint] Fix `markdownlint` errors Signed-off-by: Yi Huang --- plugins/wasm_bpf/README.md | 5 +++-- test/plugins/wasm_bpf/assets/README.md | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/plugins/wasm_bpf/README.md b/plugins/wasm_bpf/README.md index 824bd062..18438062 100644 --- a/plugins/wasm_bpf/README.md +++ b/plugins/wasm_bpf/README.md @@ -3,6 +3,7 @@ This plugin added six host functions that give you Wasm application access to eBPF. Six functions are listed here. And all of them are in the module `wasm_bpf`, if you loaded this plugin. + ```c /// lookup a bpf map fd by name. i32 wasm_bpf_map_fd_by_name(u64 obj, u32 name); @@ -30,7 +31,7 @@ i32 wasm_bpf_map_operate(u64 fd, i32 cmd, u32 key, u32 value, ### Install dependencies -See the https://wasmedge.org/book/en/contribute/build_from_src/linux.html for how to build `WasmEdge` from source. +See the for how to build `WasmEdge` from source. #### libbpf @@ -44,7 +45,7 @@ Run the following commands at the root of the `WasmEdge` project: - Note: It's important to set `WASMEDGE_PLUGIN_WASM_BPF` to `TRUE` in the command line. This toggle controls the build of `wasm_bpf` plugin. -``` +```sh cmake -DWASMEDGE_PLUGIN_WASM_BPF:BOOL=TRUE -B ./build -G "Unix Makefiles" cmake --build ./build ``` diff --git a/test/plugins/wasm_bpf/assets/README.md b/test/plugins/wasm_bpf/assets/README.md index 33fa78dd..f2435fff 100644 --- a/test/plugins/wasm_bpf/assets/README.md +++ b/test/plugins/wasm_bpf/assets/README.md @@ -1,5 +1,6 @@ -This file contains bpf programs that will be used during testing. +# wasm_bpf Plugin tests +This file contains bpf programs that will be used during testing. - `bootstrap` and `runqlat`: examples copied from `wasm-bpf`. See [here](https://github.com/eunomia-bpf/wasm-bpf/tree/main/examples) for build instructions. From 6292e2d02e8e9d341736c8867160f64a2f11d59d Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Mon, 9 Dec 2024 17:58:54 +0800 Subject: [PATCH 498/623] [Lint] Run `clang-format-18` Signed-off-by: Yi Huang --- plugins/wasm_bpf/wasm-bpf.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasm_bpf/wasm-bpf.cpp b/plugins/wasm_bpf/wasm-bpf.cpp index 3c60665a..0c45fc08 100644 --- a/plugins/wasm_bpf/wasm-bpf.cpp +++ b/plugins/wasm_bpf/wasm-bpf.cpp @@ -209,7 +209,7 @@ int32_t wasm_bpf_program::attach_bpf_program(const char *name, if (!link) { return static_cast(libbpf_get_error(link)); } - links.emplace(std::unique_ptr{ + links.emplace(std::unique_ptr{ link, bpf_link__destroy}); return 0; } From 4effce57b827d09e580423af832d39021ddae423 Mon Sep 17 00:00:00 2001 From: grorge Date: Thu, 5 Dec 2024 22:43:59 +0800 Subject: [PATCH 499/623] [Plugin] Stable Diffusion: Upgrade to 9578fdc Signed-off-by: grorge --- .../wasmedge_stablediffusion/CMakeLists.txt | 14 ++--- plugins/wasmedge_stablediffusion/sd_env.cpp | 2 +- plugins/wasmedge_stablediffusion/sd_func.cpp | 57 +++++++++-------- plugins/wasmedge_stablediffusion/sd_func.h | 15 +++-- .../wasmedge_stablediffusion.cpp | 62 ++++++++++++++++--- 5 files changed, 100 insertions(+), 50 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index c51a8bb0..1a7f79a1 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -32,23 +32,17 @@ message(STATUS "Downloading stable diffusion source") FetchContent_Declare( stable-diffusion GIT_REPOSITORY https://github.com/leejet/stable-diffusion.cpp.git - GIT_TAG ac54e0076052a196b7df961eb1f792c9ff4d7f22 + GIT_TAG master-9578fdc GIT_SHALLOW TRUE -) + ) +set(SD_BUILD_SHARED_LIBS ON CACHE INTERNAL "Stable diffusion plugin: Build shared libs") FetchContent_MakeAvailable(stable-diffusion) set_property(TARGET stable-diffusion PROPERTY POSITION_INDEPENDENT_CODE ON) if(APPLE AND CMAKE_SYSTEM_VERSION VERSION_LESS 23) # `cblas_sgemm()` introduced in macOS 13.3. set(GGML_NO_ACCELERATE ON CACHE INTERNAL "Stable diffusion plugin: Turn off accelerate") endif() -get_target_property(SD_DEPS stable-diffusion LINK_LIBRARIES) -foreach(dep ${SD_DEPS}) - if(TARGET ${dep}) - set_target_properties(${dep} PROPERTIES - POSITION_INDEPENDENT_CODE ON - ) - endif() -endforeach() + wasmedge_add_library(wasmedgePluginWasmEdgeStableDiffusion SHARED diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp index 3a936bd4..f981cd44 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.cpp +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -17,7 +17,7 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .Name = "wasmedge_stablediffusion", .Description = "Stable Diffusion plug-in for WasmEdge.", .APIVersion = Plugin::Plugin::CurrentAPIVersion, - .Version = {0, 2, 0, 0}, + .Version = {0, 3, 0, 0}, .ModuleCount = 1, .ModuleDescriptions = (Plugin::PluginModule::ModuleDescriptor[]){ diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index 62455cda..c92c5e77 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -110,12 +110,10 @@ sd_image_t *readControlImage(Span ControlImage, int Width, int Height, } void upscalerModel(const char *UpscaleModelPath, uint32_t UpscaleRepeats, - int32_t NThreads, uint32_t Wtype, uint32_t BatchCount, - sd_image_t *Results) { + int32_t NThreads, uint32_t BatchCount, sd_image_t *Results) { // unused for RealESRGAN_x4plus_anime_6B.pth int UpscaleFactor = 4; - upscaler_ctx_t *UpscalerCtx = new_upscaler_ctx(UpscaleModelPath, NThreads, - static_cast(Wtype)); + upscaler_ctx_t *UpscalerCtx = new_upscaler_ctx(UpscaleModelPath, NThreads); if (UpscalerCtx == nullptr) { spdlog::error("[WasmEdge-StableDiffusion] Create upscaler ctx failed."sv); } else { @@ -247,7 +245,7 @@ Expect SDCreateContext::body( uint32_t IdEmbedDirPtr, uint32_t IdEmbedDirLen, uint32_t VaeDecodeOnly, uint32_t VaeTiling, int32_t NThreads, uint32_t Wtype, uint32_t RngType, uint32_t Schedule, uint32_t ClipOnCpu, uint32_t ControlNetCpu, - uint32_t VaeOnCpu, uint32_t SessiontIdPtr) { + uint32_t VaeOnCpu, uint32_t DiffusionFlashAttn, uint32_t SessiontIdPtr) { // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) // Check the input model buffer. @@ -318,7 +316,8 @@ Expect SDCreateContext::body( IdEmbedDir.data(), static_cast(VaeDecodeOnly), static_cast(VaeTiling), false, NThreads, static_cast(Wtype), static_cast(RngType), - static_cast(Schedule), ClipOnCpu, ControlNetCpu, VaeOnCpu); + static_cast(Schedule), ClipOnCpu, ControlNetCpu, VaeOnCpu, + DiffusionFlashAttn); if (Ctx == nullptr) { spdlog::error("[WasmEdge-StableDiffusion] Failed to create context."sv); return static_cast(ErrNo::InvalidArgument); @@ -337,9 +336,10 @@ Expect SDTextToImage::body( uint32_t NormalizeInput, uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, - uint32_t UpscaleRepeats, uint32_t OutputPathPtr, uint32_t OutputPathLen, - uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, - uint32_t BytesWrittenPtr) { + uint32_t UpscaleRepeats, uint32_t SkipLayersPtr, uint32_t SkipLayersLen, + float SlgScale, float SkipLayerStart, float SkipLayerEnd, + uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) // Check the input model buffer. @@ -358,6 +358,8 @@ Expect SDTextToImage::body( "Failed when accessing the return bytes written memory."sv) MEM_SPAN_CHECK(OutputPathSpan, MemInst, char, OutputPathPtr, OutputPathLen, "Failed when accessing the output path memory."sv) + MEM_SPAN_CHECK(SkipLayersSpan, MemInst, int32_t, SkipLayersPtr, SkipLayersLen, + "Failed when accessing the SkipLayers memory."sv) std::string Prompt(PromptSpan.begin(), PromptSpan.end()); std::string NegativePrompt(NegativePromptSpan.begin(), NegativePromptSpan.end()); @@ -381,11 +383,12 @@ Expect SDTextToImage::body( } // Generate images spdlog::info("[WasmEdge-StableDiffusion] Start to generate image."sv); - Results = - txt2img(SDCtx, Prompt.data(), NegativePrompt.data(), ClipSkip, CfgScale, - Guidance, Width, Height, sample_method_t(SampleMethod), - SampleSteps, Seed, BatchCount, ControlImage, ControlStrength, - StyleRatio, NormalizeInput, InputIdImagesDir.data()); + Results = txt2img( + SDCtx, Prompt.data(), NegativePrompt.data(), ClipSkip, CfgScale, Guidance, + Width, Height, sample_method_t(SampleMethod), SampleSteps, Seed, + BatchCount, ControlImage, ControlStrength, StyleRatio, NormalizeInput, + InputIdImagesDir.data(), SkipLayersSpan.data(), SkipLayersSpan.size(), + SlgScale, SkipLayerStart, SkipLayerEnd); free(ControlImage); if (Results == nullptr) { spdlog::error("[WasmEdge-StableDiffusion] Generate failed."sv); @@ -400,8 +403,7 @@ Expect SDTextToImage::body( std::string UpscaleModelPath(UpscaleModelSpan.begin(), UpscaleModelSpan.end()); upscalerModel(UpscaleModelPath.data(), UpscaleRepeats, - Env.getNThreads(SessionId), Env.getWtype(SessionId), - BatchCount, Results); + Env.getNThreads(SessionId), BatchCount, Results); } // Save results if (!saveResults(Results, BatchCount, OutputPath, OutputPathLen, BytesWritten, @@ -422,8 +424,10 @@ Expect SDImageToImage::body( uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, - uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, - uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { + uint32_t SkipLayersPtr, uint32_t SkipLayersLen, float SlgScale, + float SkipLayerStart, float SkipLayerEnd, uint32_t OutputPathPtr, + uint32_t OutputPathLen, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr) { // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) // Check the input parameter valid. @@ -444,6 +448,8 @@ Expect SDImageToImage::body( "Failed when accessing the return bytes written memory."sv) MEM_SPAN_CHECK(OutputPathSpan, MemInst, char, OutputPathPtr, OutputPathLen, "Failed when accessing the output path memory."sv) + MEM_SPAN_CHECK(SkipLayersSpan, MemInst, int32_t, SkipLayersPtr, SkipLayersLen, + "Failed when accessing the SkipLayers memory."sv) if (!parameterCheck(Env, Width, Height, SessionId)) { return static_cast(ErrNo::InvalidArgument); } @@ -526,11 +532,13 @@ Expect SDImageToImage::body( // Generate images sd_image_t *Results = nullptr; spdlog::info("[WasmEdge-StableDiffusion] Start to generate image."sv); - Results = img2img(SDCtx, InputImage, Prompt.data(), NegativePrompt.data(), - ClipSkip, CfgScale, Guidance, Width, Height, - sample_method_t(SampleMethod), SampleSteps, Strength, Seed, - BatchCount, ControlImage, ControlStrength, StyleRatio, - NormalizeInput, InputIdImagesDir.data()); + Results = + img2img(SDCtx, InputImage, Prompt.data(), NegativePrompt.data(), ClipSkip, + CfgScale, Guidance, Width, Height, sample_method_t(SampleMethod), + SampleSteps, Strength, Seed, BatchCount, ControlImage, + ControlStrength, StyleRatio, NormalizeInput, + InputIdImagesDir.data(), SkipLayersSpan.data(), + SkipLayersSpan.size(), SlgScale, SkipLayerStart, SkipLayerEnd); free(ControlImage); free(InputImageBuffer); if (Results == nullptr) { @@ -546,8 +554,7 @@ Expect SDImageToImage::body( std::string UpscaleModelPath(UpscaleModelSpan.begin(), UpscaleModelSpan.end()); upscalerModel(UpscaleModelPath.data(), UpscaleRepeats, - Env.getNThreads(SessionId), Env.getWtype(SessionId), - BatchCount, Results); + Env.getNThreads(SessionId), BatchCount, Results); } // Save results if (!saveResults(Results, BatchCount, OutputPath, OutputPathLen, BytesWritten, diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h index 79592790..cd4019ba 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -26,7 +26,7 @@ class SDCreateContext : public StableDiffusion::Func { uint32_t IdEmbedDirPtr, uint32_t IdEmbedDirLen, uint32_t VaeDecodeOnly, uint32_t VaeTiling, int32_t NThreads, uint32_t Wtype, uint32_t RngType, uint32_t Schedule, uint32_t ClipOnCpu, uint32_t ControlNetCpu, - uint32_t VaeOnCpu, uint32_t SessiontIdPtr); + uint32_t VaeOnCpu, uint32_t DiffusionFlashAttn, uint32_t SessiontIdPtr); }; class SDImageToImage : public StableDiffusion::Func { @@ -43,9 +43,10 @@ class SDImageToImage : public StableDiffusion::Func { float StyleRatio, uint32_t NormalizeInput, uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, - uint32_t UpscaleRepeats, uint32_t OutputPathPtr, uint32_t OutputPathLen, - uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, - uint32_t BytesWrittenPtr); + uint32_t UpscaleRepeats, uint32_t SkipLayersPtr, uint32_t SkipLayersLen, + float SlgScale, float SkipLayerStart, float SkipLayerEnd, + uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); }; class SDTextToImage : public StableDiffusion::Func { @@ -62,8 +63,10 @@ class SDTextToImage : public StableDiffusion::Func { uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, - uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, - uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr); + uint32_t SkipLayersPtr, uint32_t SkipLayersLen, float SlgScale, + float SkipLayerStart, float SkipLayerEnd, uint32_t OutputPathPtr, + uint32_t OutputPathLen, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr); }; class SDConvert : public StableDiffusion::Func { diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index 1413b737..83e89a65 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include using WasmEdge::Host::StableDiffusion::ErrNo; @@ -115,6 +116,7 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { std::string Prompt = "a lovely cat"; std::string Prompt2 = "with blue eyes"; + std::vector SkipLayers = {7, 8, 9}; std::string OutputPathString = "./stableDiffusion/output.png"; std::vector OutputPath(OutputPathString.begin(), OutputPathString.end()); @@ -178,12 +180,13 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 1, // VaeDecodeOnly 0, // VaeTiling -1, // NThreads - 34, // Wtype + 36, // Wtype 1, // RngType 0, // Schedule 0, // ClipOnCpu 0, // ControlNetCpu 0, // VaeOnCpu + 0, // DiffusionFlashAttn SessionPtr}, // SessiontIdPtr Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); @@ -194,10 +197,12 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { // Test: text_to_image -- generate image from text. { uint32_t PromptPtr = UINT32_C(0); - uint32_t OutputPathPtr = PromptPtr + PromptData.size(); + uint32_t SkipLayersPtr = PromptPtr + PromptData.size(); + uint32_t OutputPathPtr = SkipLayersPtr + SkipLayers.size() * 4; uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath.size(); OutputPtr = BytesWrittenPtr + 4; writeBinaries(MemInst, PromptData, PromptPtr); + writeBinaries(MemInst, SkipLayers, SkipLayersPtr); writeBinaries(MemInst, OutputPath, OutputPathPtr); EXPECT_TRUE(HostFuncTextToImage.run( CallFrame, @@ -227,6 +232,11 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 0, // UpscaleModelPathPtr 0, // UpscaleModelPathLen 1, // UpscaleRepeats + SkipLayersPtr, // SkipLayersPtr + static_cast(SkipLayers.size()), // SkipLayersLen + 0.0f, // SlgScale + 0.01f, // SkipLayerStart + 0.2, // SkipLayerEnd OutputPathPtr, // OutputPathPtr static_cast(OutputPath.size()), // OutputPathLen OutputPtr, // OutBufferPtr @@ -242,10 +252,12 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { // Test: text_to_image -- reuse context to generate image from text. { uint32_t PromptPtr = UINT32_C(0); - uint32_t OutputPathPtr = PromptPtr + PromptData.size(); + uint32_t SkipLayersPtr = PromptPtr + PromptData.size(); + uint32_t OutputPathPtr = SkipLayersPtr + SkipLayers.size() * 4; uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath.size(); OutputPtr = BytesWrittenPtr + 4; writeBinaries(MemInst, PromptData, PromptPtr); + writeBinaries(MemInst, SkipLayers, SkipLayersPtr); writeBinaries(MemInst, OutputPath, OutputPathPtr); EXPECT_TRUE(HostFuncTextToImage.run( CallFrame, @@ -275,6 +287,11 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 0, // UpscaleModelPathPtr 0, // UpscaleModelPathLen 1, // UpscaleRepeats + SkipLayersPtr, // SkipLayersPtr + static_cast(SkipLayers.size()), // SkipLayersLen + 0.0f, // SlgScale + 0.01f, // SkipLayerStart + 0.2, // SkipLayerEnd OutputPathPtr, // OutputPathPtr static_cast(OutputPath.size()), // OutputPathLen OutputPtr, // OutBufferPtr @@ -318,12 +335,13 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 0, // VaeDecodeOnly 0, // VaeTiling -1, // NThreads - 34, // Wtype + 36, // Wtype 1, // RngType 0, // Schedule 0, // ClipOnCpu 0, // ControlNetCpu 0, // VaeOnCpu + 0, // DiffusionFlashAttn SessionPtr}, // SessiontIdPtr Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); @@ -334,11 +352,13 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { { uint32_t PromptPtr = UINT32_C(0); uint32_t InputPathPtr = PromptPtr + PromptData2.size(); - uint32_t OutputPathPtr = InputPathPtr + InputPath.size(); + uint32_t SkipLayersPtr = InputPathPtr + InputPath.size(); + uint32_t OutputPathPtr = SkipLayersPtr + SkipLayers.size() * 4; uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath2.size(); OutputPtr = BytesWrittenPtr + 4; writeBinaries(MemInst, PromptData2, PromptPtr); writeBinaries(MemInst, InputPath, InputPathPtr); + writeBinaries(MemInst, SkipLayers, SkipLayersPtr); writeBinaries(MemInst, OutputPath2, OutputPathPtr); EXPECT_TRUE(HostFuncImageToImage.run( CallFrame, @@ -371,6 +391,11 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 0, // UpscaleModelPathPtr 0, // UpscaleModelPathLen 1, // UpscaleRepeats + SkipLayersPtr, // SkipLayersPtr + static_cast(SkipLayers.size()), // SkipLayersLen + 0.0f, // SlgScale + 0.01f, // SkipLayerStart + 0.2, // SkipLayerEnd OutputPathPtr, // OutputPathPtr static_cast(OutputPath2.size()), // OutputPathLen OutputPtr, // OutBufferPtr @@ -386,11 +411,13 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { { uint32_t PromptPtr = UINT32_C(0); uint32_t InputPathPtr = PromptPtr + PromptData2.size(); - uint32_t OutputPathPtr = InputPathPtr + InputPath.size(); + uint32_t SkipLayersPtr = InputPathPtr + InputPath.size(); + uint32_t OutputPathPtr = SkipLayersPtr + SkipLayers.size() * 4; uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath2.size(); OutputPtr = BytesWrittenPtr + 4; writeBinaries(MemInst, PromptData2, PromptPtr); writeBinaries(MemInst, InputPath, InputPathPtr); + writeBinaries(MemInst, SkipLayers, SkipLayersPtr); writeBinaries(MemInst, OutputPath2, OutputPathPtr); EXPECT_TRUE(HostFuncImageToImage.run( CallFrame, @@ -423,6 +450,11 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 0, // UpscaleModelPathPtr 0, // UpscaleModelPathLen 1, // UpscaleRepeats + SkipLayersPtr, // SkipLayersPtr + static_cast(SkipLayers.size()), // SkipLayersLen + 0.0f, // SlgScale + 0.01f, // SkipLayerStart + 0.2, // SkipLayerEnd OutputPathPtr, // OutputPathPtr static_cast(OutputPath2.size()), // OutputPathLen OutputPtr, // OutBufferPtr @@ -438,10 +470,12 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { // Test: text_to_image -- non exist SessionId. { uint32_t PromptPtr = UINT32_C(0); - uint32_t OutputPathPtr = PromptPtr + PromptData.size(); + uint32_t SkipLayersPtr = PromptPtr + PromptData.size(); + uint32_t OutputPathPtr = SkipLayersPtr + SkipLayers.size() * 4; uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath.size(); OutputPtr = BytesWrittenPtr + 4; writeBinaries(MemInst, PromptData, PromptPtr); + writeBinaries(MemInst, SkipLayers, SkipLayersPtr); writeBinaries(MemInst, OutputPath, OutputPathPtr); EXPECT_TRUE(HostFuncTextToImage.run( CallFrame, @@ -471,6 +505,11 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 0, // UpscaleModelPathPtr 0, // UpscaleModelPathLen 1, // UpscaleRepeats + SkipLayersPtr, // SkipLayersPtr + static_cast(SkipLayers.size()), // SkipLayersLen + 0.0f, // SlgScale + 0.01f, // SkipLayerStart + 0.2, // SkipLayerEnd OutputPathPtr, // OutputPathPtr static_cast(OutputPath.size()), // OutputPathLen OutputPtr, // OutBufferPtr @@ -485,11 +524,13 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { { uint32_t PromptPtr = UINT32_C(0); uint32_t InputPathPtr = PromptPtr + PromptData2.size(); - uint32_t OutputPathPtr = InputPathPtr + InputPath.size(); + uint32_t SkipLayersPtr = InputPathPtr + InputPath.size(); + uint32_t OutputPathPtr = SkipLayersPtr + SkipLayers.size() * 4; uint32_t BytesWrittenPtr = OutputPathPtr + OutputPath2.size(); OutputPtr = BytesWrittenPtr + 4; writeBinaries(MemInst, PromptData2, PromptPtr); writeBinaries(MemInst, InputPath, InputPathPtr); + writeBinaries(MemInst, SkipLayers, SkipLayersPtr); writeBinaries(MemInst, OutputPath2, OutputPathPtr); EXPECT_TRUE(HostFuncImageToImage.run( CallFrame, @@ -522,6 +563,11 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { 0, // UpscaleModelPathPtr 0, // UpscaleModelPathLen 1, // UpscaleRepeats + SkipLayersPtr, // SkipLayersPtr + static_cast(SkipLayers.size()), // SkipLayersLen + 0.0f, // SlgScale + 0.01f, // SkipLayerStart + 0.2, // SkipLayerEnd OutputPathPtr, // OutputPathPtr static_cast(OutputPath2.size()), // OutputPathLen OutputPtr, // OutBufferPtr From a278b7512331a5f7c390439e24ce4e08b45bb600 Mon Sep 17 00:00:00 2001 From: grorge Date: Thu, 12 Dec 2024 20:52:12 +0800 Subject: [PATCH 500/623] [Plugin] Stable Diffusion: fix build metal failed Signed-off-by: grorge --- plugins/wasmedge_stablediffusion/CMakeLists.txt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index 1a7f79a1..fb7db00f 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -31,8 +31,8 @@ endif() message(STATUS "Downloading stable diffusion source") FetchContent_Declare( stable-diffusion - GIT_REPOSITORY https://github.com/leejet/stable-diffusion.cpp.git - GIT_TAG master-9578fdc + GIT_REPOSITORY https://github.com/second-state/stable-diffusion.cpp.git + GIT_TAG d08889cb3f86f49d3f4f9c0c7e3781238c44bd3d GIT_SHALLOW TRUE ) set(SD_BUILD_SHARED_LIBS ON CACHE INTERNAL "Stable diffusion plugin: Build shared libs") @@ -108,6 +108,7 @@ else() -Wno-unused-value -Wno-uninitialized -Wno-format + -Wno-enum-compare ) endif() @@ -115,7 +116,7 @@ if(WASMEDGE_PLUGIN_STABLEDIFFUSION_METAL) add_custom_command( TARGET wasmedgePluginWasmEdgeStableDiffusion POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${stable-diffusion_SOURCE_DIR}/ggml/src/ggml-metal.metal ggml-metal.metal + COMMAND ${CMAKE_COMMAND} -E copy ${stable-diffusion_SOURCE_DIR}/ggml/src/ggml-metal/ggml-metal.metal ggml-metal.metal COMMAND ${CMAKE_COMMAND} -E copy ${stable-diffusion_SOURCE_DIR}/ggml/src/ggml-common.h ggml-common.h ) endif() From 50a43cb0a0a8717f762b235c82cf865b0c962267 Mon Sep 17 00:00:00 2001 From: PeterD1524 <53310459+PeterD1524@users.noreply.github.com> Date: Tue, 17 Dec 2024 16:48:46 +0800 Subject: [PATCH 501/623] [WASI-NN]: add finalize_execution_context function (#3917) [WASI-NN]: add finalize_execution_context function Signed-off-by: PeterD1524 --- plugins/wasi_nn/wasinnfunc.cpp | 27 +++++++++++++++++++++++++++ plugins/wasi_nn/wasinnfunc.h | 12 ++++++++++++ plugins/wasi_nn/wasinnmodule.cpp | 2 ++ 3 files changed, 41 insertions(+) diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 3de7d13b..b9c9abd7 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -645,5 +645,32 @@ Expect WasiNNUnload::bodyImpl(const Runtime::CallingFrame &Frame, } } +Expect +WasiNNFinalizeExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context) { +#ifdef WASMEDGE_BUILD_WASI_NN_RPC + if (Env.NNRPCChannel != nullptr) { + // TODO: implement RPC for finalize_execution_context + spdlog::error( + "[WASI-NN] RPC client is not implemented for finalize_execution_context"sv); + return WASINN::ErrNo::UnsupportedOperation; + } +#endif + auto *MemInst = Frame.getMemoryByIndex(0); + if (MemInst == nullptr) { + return Unexpect(ErrCode::Value::HostFuncError); + } + + if (Env.NNContext.size() <= Context) { + spdlog::error( + "[WASI-NN] finalize_execution_context: Execution Context does not exist."sv); + return WASINN::ErrNo::InvalidArgument; + } + + spdlog::error( + "[WASI-NN] finalize_execution_context: No backend supports finalize_execution_context."sv); + return WASINN::ErrNo::InvalidArgument; +} + } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnfunc.h b/plugins/wasi_nn/wasinnfunc.h index cb399cc7..769955dc 100644 --- a/plugins/wasi_nn/wasinnfunc.h +++ b/plugins/wasi_nn/wasinnfunc.h @@ -174,5 +174,17 @@ class WasiNNUnload : public WasiNN { uint32_t GraphId); }; +class WasiNNFinalizeExecCtx : public WasiNN { +public: + WasiNNFinalizeExecCtx(WASINN::WasiNNEnvironment &HostEnv) : WasiNN(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, uint32_t Context) { + return bodyImpl(Frame, Context).map(castErrNo); + } + +private: + Expect bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t Context); +}; + } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasi_nn/wasinnmodule.cpp b/plugins/wasi_nn/wasinnmodule.cpp index 6226f89d..33c3d5b1 100644 --- a/plugins/wasi_nn/wasinnmodule.cpp +++ b/plugins/wasi_nn/wasinnmodule.cpp @@ -22,6 +22,8 @@ WasiNNModule::WasiNNModule() : ModuleInstance("wasi_ephemeral_nn") { addHostFunc("compute_single", std::make_unique(Env)); addHostFunc("fini_single", std::make_unique(Env)); addHostFunc("unload", std::make_unique(Env)); + addHostFunc("finalize_execution_context", + std::make_unique(Env)); } } // namespace Host From 669dd2b2b3ff6a400a084a1471fcf20ba45ef339 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 23 Dec 2024 21:28:53 +0800 Subject: [PATCH 502/623] [WASI-NN] ggml: support Qwen2VL and bump llama.cpp b4381 (#3930) Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 194 ++++++++++++++++++++++++++------ plugins/wasi_nn/wasinn_ggml.h | 20 ++++ 2 files changed, 177 insertions(+), 37 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 2d544c81..34ab55af 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -396,8 +396,72 @@ void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, NEmbd, fmt::join(Embeddings, Embeddings + NEmbd, ","sv)); } +static bool evaluateQwen2vlImageEmbed( + llama_context *CtxLlama, const struct llava_image_embed *ImageEmbed, + int NBatch, int *NPast, int *StPosId, struct clip_image_size *ImageSize) { + int NEmbd = llama_n_embd(llama_get_model(CtxLlama)); + const int PatchSize = 14 * 2; + const int Ph = + ImageSize->height / PatchSize + (ImageSize->height % PatchSize > 0); + const int Pw = + ImageSize->width / PatchSize + (ImageSize->width % PatchSize > 0); + auto ImgTokens = ImageEmbed->n_image_pos; + std::vector MRopePos; + MRopePos.resize(ImgTokens * 4); + + for (int Y = 0; Y < Ph; Y++) { + for (int X = 0; X < Pw; X++) { + int I = Y * Pw + X; + MRopePos[I] = *StPosId; + MRopePos[I + ImgTokens] = *StPosId + Y; + MRopePos[I + ImgTokens * 2] = *StPosId + X; + MRopePos[I + ImgTokens * 3] = 0; + } + } + *StPosId += std::max(Pw, Ph); + + int Processed = 0; + std::vector BatchMRopePos; + BatchMRopePos.resize(ImgTokens * 4); + + for (int I = 0; I < ImgTokens; I += NBatch) { + int NEval = ImgTokens - I; + if (NEval > NBatch) { + NEval = NBatch; + } + + std::fill(BatchMRopePos.begin(), BatchMRopePos.end(), 0); + std::copy_n(&MRopePos[Processed], NEval, BatchMRopePos.data()); + std::copy_n(&MRopePos[ImgTokens * 1 + Processed], NEval, + &BatchMRopePos[NEval * 1]); + std::copy_n(&MRopePos[ImgTokens * 2 + Processed], NEval, + &BatchMRopePos[NEval * 2]); + std::copy_n(&MRopePos[ImgTokens * 3 + Processed], NEval, + &BatchMRopePos[NEval * 3]); + + llama_batch Batch = { + static_cast(NEval), // n_tokens + nullptr, // token + (ImageEmbed->embed + I * NEmbd), // embed + BatchMRopePos.data(), // pos + nullptr, // n_seq_id + nullptr, // seq_id + nullptr, // logits + }; + if (llama_decode(CtxLlama, Batch)) { + spdlog::error( + "[WASI-NN] GGML backend: evaluateQwen2vlImageEmbed failed to eval"sv); + return false; + } + *NPast += NEval; + Processed += NEval; + } + return true; +} + ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, - std::vector Tokens, int &NPast) noexcept { + std::vector Tokens, int &NPast, + int &NPos) noexcept { uint32_t NCtx = llama_n_ctx(LlamaContext); // End the inference if the context is full. @@ -410,17 +474,28 @@ ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, return ErrNo::ContextFull; } + std::vector LlamaPos; for (int I = 0; I < static_cast(Tokens.size()); I += static_cast(GraphRef.BatchSize)) { int NEval = static_cast(Tokens.size()) - I; if (NEval > static_cast(GraphRef.BatchSize)) { NEval = static_cast(GraphRef.BatchSize); } - // llama_batch_get_one(*token, n_tokens) - // - Return batch for single sequence of tokens. - // - The sequence ID will be fixed to 0. - auto Status = - llama_decode(LlamaContext, llama_batch_get_one(&Tokens[I], NEval)); + // Get a batch for single sequence of tokens. + auto Batch = llama_batch_get_one(&Tokens[I], NEval); + + // Add pos information for Qwen2vl. + if (GraphRef.VisionModelType == VisionModel::Qwen2VL) { + LlamaPos.resize(Batch.n_tokens * 4); + std::fill(LlamaPos.begin(), LlamaPos.end(), 0); + for (int J = 0; J < Batch.n_tokens * 3; J++) { + LlamaPos[J] = NPos + (J % Batch.n_tokens); + } + Batch.pos = LlamaPos.data(); + } + + // Decode the batch. + auto Status = llama_decode(LlamaContext, Batch); if (Status == 1) { spdlog::error( "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); @@ -431,6 +506,7 @@ ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, return ErrNo::RuntimeError; } NPast += NEval; + NPos += NEval; } return ErrNo::Success; @@ -447,7 +523,8 @@ void batchAddSeq(llama_batch &Batch, const std::vector &Tokens, } ErrNo batchDecode(llama_context *LlamaContext, llama_batch &Batch, - float *Output, int NEmbd) noexcept { + float *Output, int NEmbd, + EmbdNormalizeType EmbdNormalize) noexcept { // Clear previous kv_cache values (irrelevant for embeddings) llama_kv_cache_clear(LlamaContext); @@ -481,7 +558,8 @@ ErrNo batchDecode(llama_context *LlamaContext, llama_batch &Batch, } // Normalize the embeddings. - common_embd_normalize(Embd, Output, NEmbd); + common_embd_normalize(Embd, Output, NEmbd, + static_cast(EmbdNormalize)); } return ErrNo::Success; @@ -548,8 +626,8 @@ Expect getEmbedding(WasiNNEnvironment &Env, /* n_seq_max */ 1); std::vector Embeddings(NEmbd); batchAddSeq(Batch, CxtRef.LlamaInputs, SequenceId); - ReturnCode = - batchDecode(GraphRef.LlamaContext, Batch, Embeddings.data(), NEmbd); + ReturnCode = batchDecode(GraphRef.LlamaContext, Batch, Embeddings.data(), + NEmbd, GraphRef.EmbdNormalize); if (ReturnCode != ErrNo::Success) { spdlog::error("[WASI-NN] GGML backend: failed to evaluate input tokens."sv); return ReturnCode; @@ -699,6 +777,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize the plugin parameters. auto ContextDefault = llama_context_default_params(); + const common_params ParamsDefault; GraphRef.EnableLog = false; GraphRef.EnableDebugLog = false; GraphRef.StreamStdout = false; @@ -706,12 +785,15 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.ReversePrompt = ""sv; GraphRef.MMProjModelPath = ""sv; GraphRef.ImagePath = ""sv; + GraphRef.EmbdNormalize = + static_cast(ParamsDefault.embd_normalize); // Initialize the model parameters. llama_model_params ModelParams = llama_model_default_params(); GraphRef.NGPULayers = ModelParams.n_gpu_layers; // Initialize the context parameters. GraphRef.CtxSize = ContextDefault.n_ctx; GraphRef.BatchSize = ContextDefault.n_batch; + GraphRef.UBatchSize = ContextDefault.n_ubatch; GraphRef.Threads = ContextDefault.n_threads; // Initialize the sampling parameters. const common_params_sampling SamplerDefault; @@ -976,32 +1058,50 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } } - // Load image for llava. - int LlavaVerbosity = 0; - if (GraphRef.EnableLog) { - LlavaVerbosity = 1; + // Load the clip model if not loaded. + if (GraphRef.ClipContext == nullptr) { + if (GraphRef.EnableLog) { + spdlog::info( + "[WASI-NN] GGML backend: Load the clip model. " + "Because llama.cpp disabled the GPU support for CLIP, " + "the step of loading images in CLIP can only use the CPU, " + "which may result in reduced efficiency. " + "(You can refer to PR https://github.com/ggerganov/llama.cpp/pull/10896)"sv); + } + GraphRef.ClipContext = clip_model_load(GraphRef.MMProjModelPath.c_str(), + GraphRef.EnableLog ? 1 : 0); + if (GraphRef.ClipContext == nullptr) { + spdlog::error( + "[WASI-NN] GGML backend: Error: unable to load the clip model."sv); + return ErrNo::InvalidArgument; + } + if (clip_is_qwen2vl(GraphRef.ClipContext)) { + GraphRef.VisionModelType = VisionModel::Qwen2VL; + if (GraphRef.EnableLog) { + spdlog::info("[WASI-NN] GGML backend: Qwen2vl model detected."sv); + } + } } - auto ClipContext = - clip_model_load(GraphRef.MMProjModelPath.c_str(), LlavaVerbosity); + + // Get image embed. if (ContainsBase64Image) { // Load the base64 image from the prompt. CxtRef.LlavaImageEmbd = - loadBase64ImageFromPrompt(GraphRef, ClipContext, Prompt); + loadBase64ImageFromPrompt(GraphRef, GraphRef.ClipContext, Prompt); // Replace the base64 image in the prompt with a placeholder. auto Res = replaceBase64ImagePlaceholderInPrompt(Prompt); if (Res != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: Error: unable to replace the base64 image in the prompt."sv); - clip_free(ClipContext); + clip_free(GraphRef.ClipContext); return Res; } } else { // Load the image from the file. CxtRef.LlavaImageEmbd = llava_image_embed_make_with_filename( - ClipContext, static_cast(GraphRef.Threads), + GraphRef.ClipContext, static_cast(GraphRef.Threads), GraphRef.ImagePath.c_str()); } - clip_free(ClipContext); if (CxtRef.LlavaImageEmbd == nullptr) { spdlog::error( "[WASI-NN] GGML backend: Error: unable to load the image."sv); @@ -1125,6 +1225,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Prepare variables; int32_t NPast = 0; + int32_t NPos = 0; uint64_t NRemain = GraphRef.NPredict; // Get the context size. const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext); @@ -1147,7 +1248,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { if (CxtRef.LlavaImageEmbd == nullptr) { // Text only prompt. ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, - std::move(CxtRef.LlamaInputs), NPast); + std::move(CxtRef.LlamaInputs), NPast, NPos); if (ReturnCode != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate input tokens."sv); @@ -1162,22 +1263,35 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { CxtRef.LlavaImagePosition, CxtRef.LlamaInputs.end()); ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, - std::move(EmbdInputBeforeImage), NPast); + std::move(EmbdInputBeforeImage), NPast, NPos); if (ReturnCode != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate input tokens before image."sv); return ReturnCode; } - bool EvalImageStatus = - llava_eval_image_embed(GraphRef.LlamaContext, CxtRef.LlavaImageEmbd, - static_cast(GraphRef.BatchSize), &NPast); + + bool EvalImageStatus = false; + switch (GraphRef.VisionModelType) { + case VisionModel::Llava: + EvalImageStatus = + llava_eval_image_embed(GraphRef.LlamaContext, CxtRef.LlavaImageEmbd, + static_cast(GraphRef.BatchSize), &NPast); + break; + case VisionModel::Qwen2VL: + auto ImageSize = clip_get_load_image_size(GraphRef.ClipContext); + EvalImageStatus = evaluateQwen2vlImageEmbed( + GraphRef.LlamaContext, CxtRef.LlavaImageEmbd, + static_cast(GraphRef.BatchSize), &NPast, &NPos, ImageSize); + break; + } + if (!EvalImageStatus) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); return ErrNo::RuntimeError; } ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, - std::move(EmbdInputAfterImage), NPast); + std::move(EmbdInputAfterImage), NPast, NPos); if (ReturnCode != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate input tokens after image."sv); @@ -1220,7 +1334,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { break; } // Evaluate the output token. - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, {Id}, NPast); + ReturnCode = + evaluateTokens(GraphRef, GraphRef.LlamaContext, {Id}, NPast, NPos); if (ReturnCode != ErrNo::Success) { break; } @@ -1338,6 +1453,7 @@ Expect computeSingle(WasiNNEnvironment &Env, CxtRef.LlamaSampler = common_sampler_init(GraphRef.LlamaModel, Params.sampling); CxtRef.LlamaNPast = 0; + CxtRef.LlamaNPos = 0; // Get the context size. const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext); @@ -1359,9 +1475,9 @@ Expect computeSingle(WasiNNEnvironment &Env, // Evaluate input tokens. if (CxtRef.LlavaImageEmbd == nullptr) { // Text only prompt. - ReturnCode = - evaluateTokens(GraphRef, GraphRef.LlamaContext, - std::move(CxtRef.LlamaInputs), CxtRef.LlamaNPast); + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, + std::move(CxtRef.LlamaInputs), + CxtRef.LlamaNPast, CxtRef.LlamaNPos); if (ReturnCode != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate input tokens."sv); @@ -1375,9 +1491,9 @@ Expect computeSingle(WasiNNEnvironment &Env, std::vector EmbdInputAfterImage( CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition, CxtRef.LlamaInputs.end()); - ReturnCode = - evaluateTokens(GraphRef, GraphRef.LlamaContext, - std::move(EmbdInputBeforeImage), CxtRef.LlamaNPast); + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, + std::move(EmbdInputBeforeImage), + CxtRef.LlamaNPast, CxtRef.LlamaNPos); if (ReturnCode != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate input tokens before image."sv); @@ -1391,9 +1507,9 @@ Expect computeSingle(WasiNNEnvironment &Env, "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); return ErrNo::RuntimeError; } - ReturnCode = - evaluateTokens(GraphRef, GraphRef.LlamaContext, - std::move(EmbdInputAfterImage), CxtRef.LlamaNPast); + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, + std::move(EmbdInputAfterImage), + CxtRef.LlamaNPast, CxtRef.LlamaNPos); if (ReturnCode != ErrNo::Success) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate input tokens after image."sv); @@ -1427,7 +1543,7 @@ Expect computeSingle(WasiNNEnvironment &Env, // Evaluate the output token if not EOS. if (ReturnCode != ErrNo::EndOfSequence) { ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, {Id}, - CxtRef.LlamaNPast); + CxtRef.LlamaNPast, CxtRef.LlamaNPos); } if (GraphRef.EnableDebugLog) { spdlog::info( @@ -1476,6 +1592,10 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { common_sampler_reset(CxtRef.LlamaSampler); common_sampler_free(CxtRef.LlamaSampler); CxtRef.LlamaSampler = nullptr; + if (GraphRef.ClipContext != nullptr) { + clip_free(GraphRef.ClipContext); + GraphRef.ClipContext = nullptr; + } if (CxtRef.LlavaImageEmbd != nullptr) { llava_image_embed_free(CxtRef.LlavaImageEmbd); CxtRef.LlavaImageEmbd = nullptr; diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 7a65b328..241696de 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -21,20 +21,39 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML + +enum class EmbdNormalizeType : int32_t { + // Follow: + // https://github.com/ggerganov/llama.cpp/blob/0bf2d10c5514ff61b99897a4a5054f846e384e1e/common/common.h#L312 + None = -1, + MaxAbsolute = 0, + Taxicab = 1, + Euclidean = 2, + PNorm = 3, +}; + +enum class VisionModel : uint8_t { + Llava = 0, + Qwen2VL = 1, +}; + struct Graph { llama_model *LlamaModel = nullptr; std::string ModelFilePath; llama_context *LlamaContext = nullptr; + struct clip_ctx *ClipContext = nullptr; // Plugin parameters: bool EnableLog = false; bool EnableDebugLog = false; bool StreamStdout = false; bool Embedding = false; + EmbdNormalizeType EmbdNormalize = EmbdNormalizeType::Euclidean; bool ComputeSingleStarted = false; uint64_t NPredict; std::string ReversePrompt; std::string MMProjModelPath; std::string ImagePath; + VisionModel VisionModelType = VisionModel::Llava; // Model parameters: int64_t MainGPU = 0; // Use GPU 0 by default int64_t NGPULayers = 0; @@ -66,6 +85,7 @@ struct Context { // Preserve for computing single token common_sampler *LlamaSampler = nullptr; int32_t LlamaNPast = 0; + int32_t LlamaNPos = 0; // Preserve for llava struct llava_image_embed *LlavaImageEmbd = nullptr; size_t LlavaImagePosition = 0; From e59fef129d682038ceefc701e32611dadd4154e6 Mon Sep 17 00:00:00 2001 From: "Shen-Ta Hsieh(BestSteve)" Date: Sun, 29 Dec 2024 23:42:40 +0800 Subject: [PATCH 503/623] [Docker] Update llvm, zstd, ninja, cmake in dockerfile (#3924) [Docker] Update llvm, zstd, ninja, cmake in dockerfile and workflows * Remove unused zstd flags in manylinux_2_28 * Update windows sdk to 26100 Signed-off-by: Shen-Ta Hsieh --- utils/docker/Dockerfile.manylinux2014_aarch64 | 54 +++++++++---------- utils/docker/Dockerfile.manylinux2014_x86_64 | 54 +++++++++---------- utils/docker/Dockerfile.manylinux_2_28-base | 43 ++++++++------- utils/docker/SHA256SUM.manylinux2014 | 18 +++---- utils/docker/SHA256SUM.manylinux_2_28 | 14 ++--- 5 files changed, 91 insertions(+), 92 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux2014_aarch64 b/utils/docker/Dockerfile.manylinux2014_aarch64 index b6665388..30a3e87a 100644 --- a/utils/docker/Dockerfile.manylinux2014_aarch64 +++ b/utils/docker/Dockerfile.manylinux2014_aarch64 @@ -18,39 +18,39 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/opt/rh/devtoolset-10/root/usr --disable-shared --libdir=/opt/rh/devtoolset-10/root/usr/lib64" && \ curl -s -L -O --remote-name-all \ - https://github.com/facebook/zstd/releases/download/v1.5.5/zstd-1.5.5.tar.gz \ - https://github.com/Kitware/CMake/releases/download/v3.29.3/cmake-3.29.3.tar.gz \ - https://github.com/ninja-build/ninja/archive/refs/tags/v1.11.1.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/llvm-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/lld-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/libunwind-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/cmake-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/third-party-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/clang-17.0.6.src.tar.xz && \ + https://github.com/facebook/zstd/releases/download/v1.5.6/zstd-1.5.6.tar.gz \ + https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-v3.31.2.tar.gz \ + https://github.com/ninja-build/ninja/archive/refs/tags/v1.12.1.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/llvm-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/lld-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/libunwind-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/cmake-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/third-party-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/clang-19.1.5.src.tar.xz && \ sha256sum -c SHA256SUM.manylinux2014 && \ - gzip -dc zstd-1.5.5.tar.gz | tar -xf - && \ - gzip -dc cmake-3.29.3.tar.gz | tar -xf - && \ - gzip -dc v1.11.1.tar.gz | tar -xf - && \ - xz -dc llvm-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc lld-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc libunwind-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc cmake-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc third-party-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc clang-17.0.6.src.tar.xz | tar -xf - && \ + gzip -dc zstd-1.5.6.tar.gz | tar -xf - && \ + gzip -dc cmake-v3.31.2.tar.gz | tar -xf - && \ + gzip -dc v1.12.1.tar.gz | tar -xf - && \ + xz -dc llvm-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc lld-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc libunwind-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc cmake-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc third-party-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc clang-19.1.5.src.tar.xz | tar -xf - && \ export ZSTDFLAGS=(PREFIX=/opt/rh/devtoolset-10/root/usr LIBDIR=/opt/rh/devtoolset-10/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ - cd zstd-1.5.5 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-10/root/usr/lib64/libzstd.so* && cd - && \ + cd zstd-1.5.6 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-10/root/usr/lib64/libzstd.so* && cd - && \ mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ - ../ninja-1.11.1/configure.py --bootstrap \ + ../ninja-1.12.1/configure.py --bootstrap \ --with-python=/opt/python/cp311-cp311/bin/python && \ cp -v ninja /opt/rh/devtoolset-10/root/usr/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.29.3/configure --prefix=/opt/rh/devtoolset-10/root/usr \ + mkdir build && cd build && ../cmake-v3.31.2/configure --prefix=/opt/rh/devtoolset-10/root/usr \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v llvm-17.0.6.src llvm && \ - mv -v lld-17.0.6.src lld && \ - mv -v libunwind-17.0.6.src libunwind && \ - mv -v cmake-17.0.6.src cmake && \ - mv -v third-party-17.0.6.src third-party && \ - mv -v clang-17.0.6.src clang && \ + mv -v llvm-19.1.5.src llvm && \ + mv -v lld-19.1.5.src lld && \ + mv -v libunwind-19.1.5.src libunwind && \ + mv -v cmake-19.1.5.src cmake && \ + mv -v third-party-19.1.5.src third-party && \ + mv -v clang-19.1.5.src clang && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/opt/rh/devtoolset-10/root/usr \ -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ diff --git a/utils/docker/Dockerfile.manylinux2014_x86_64 b/utils/docker/Dockerfile.manylinux2014_x86_64 index da67eae9..6ba01f13 100644 --- a/utils/docker/Dockerfile.manylinux2014_x86_64 +++ b/utils/docker/Dockerfile.manylinux2014_x86_64 @@ -18,39 +18,39 @@ RUN cd && (yum check-update || true) && yum install -y xz openssl-devel rpm-buil 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/opt/rh/devtoolset-11/root/usr --disable-shared --libdir=/opt/rh/devtoolset-11/root/usr/lib64" && \ curl -s -L -O --remote-name-all \ - https://github.com/facebook/zstd/releases/download/v1.5.5/zstd-1.5.5.tar.gz \ - https://github.com/Kitware/CMake/releases/download/v3.29.3/cmake-3.29.3.tar.gz \ - https://github.com/ninja-build/ninja/archive/refs/tags/v1.11.1.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/llvm-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/lld-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/libunwind-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/cmake-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/third-party-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/clang-17.0.6.src.tar.xz && \ + https://github.com/facebook/zstd/releases/download/v1.5.6/zstd-1.5.6.tar.gz \ + https://github.com/Kitware/CMake/releases/download/v3.31.2/cmake-v3.31.2.tar.gz \ + https://github.com/ninja-build/ninja/archive/refs/tags/v1.12.1.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/llvm-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/lld-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/libunwind-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/cmake-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/third-party-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/clang-19.1.5.src.tar.xz && \ sha256sum -c SHA256SUM.manylinux2014 && \ - gzip -dc zstd-1.5.5.tar.gz | tar -xf - && \ - gzip -dc cmake-3.29.3.tar.gz | tar -xf - && \ - gzip -dc v1.11.1.tar.gz | tar -xf - && \ - xz -dc llvm-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc lld-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc libunwind-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc cmake-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc third-party-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc clang-17.0.6.src.tar.xz | tar -xf - && \ + gzip -dc zstd-1.5.6.tar.gz | tar -xf - && \ + gzip -dc cmake-v3.31.2.tar.gz | tar -xf - && \ + gzip -dc v1.12.1.tar.gz | tar -xf - && \ + xz -dc llvm-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc lld-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc libunwind-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc cmake-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc third-party-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc clang-19.1.5.src.tar.xz | tar -xf - && \ export ZSTDFLAGS=(PREFIX=/opt/rh/devtoolset-11/root/usr LIBDIR=/opt/rh/devtoolset-11/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ - cd zstd-1.5.5 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-11/root/usr/lib64/libzstd.so* && cd - && \ + cd zstd-1.5.6 && make -s "${ZSTDFLAGS[@]}" -j $CPU && make -s "${ZSTDFLAGS[@]}" install && rm -vf /opt/rh/devtoolset-11/root/usr/lib64/libzstd.so* && cd - && \ mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ - ../ninja-1.11.1/configure.py --bootstrap \ + ../ninja-1.12.1/configure.py --bootstrap \ --with-python=/opt/python/cp311-cp311/bin/python && \ cp -v ninja /opt/rh/devtoolset-11/root/usr/bin/ninja && cd - && rm -rf build && \ - mkdir build && cd build && ../cmake-3.29.3/configure --prefix=/opt/rh/devtoolset-11/root/usr \ + mkdir build && cd build && ../cmake-v3.31.2/configure --prefix=/opt/rh/devtoolset-11/root/usr \ --parallel=$CPU && make -s -j $CPU && make -s install && cd - && rm -rf build && \ - mv -v llvm-17.0.6.src llvm && \ - mv -v lld-17.0.6.src lld && \ - mv -v libunwind-17.0.6.src libunwind && \ - mv -v cmake-17.0.6.src cmake && \ - mv -v third-party-17.0.6.src third-party && \ - mv -v clang-17.0.6.src clang && \ + mv -v llvm-19.1.5.src llvm && \ + mv -v lld-19.1.5.src lld && \ + mv -v libunwind-19.1.5.src libunwind && \ + mv -v cmake-19.1.5.src cmake && \ + mv -v third-party-19.1.5.src third-party && \ + mv -v clang-19.1.5.src clang && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/opt/rh/devtoolset-11/root/usr \ -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ diff --git a/utils/docker/Dockerfile.manylinux_2_28-base b/utils/docker/Dockerfile.manylinux_2_28-base index 30b64c31..da7b1449 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-base +++ b/utils/docker/Dockerfile.manylinux_2_28-base @@ -21,32 +21,31 @@ RUN cd && (yum check-update || true) && yum install -y openssl-devel rpm-build c 'import multiprocessing; print(multiprocessing.cpu_count())') && \ export CFGFLAGS="--prefix=/opt/rh/gcc-toolset-13/root/usr --disable-shared --libdir=/opt/rh/gcc-toolset-13/root/usr/lib64" && \ curl -s -L -O --remote-name-all \ - https://github.com/ninja-build/ninja/archive/refs/tags/v1.11.1.tar.gz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/llvm-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/lld-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/libunwind-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/cmake-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/third-party-17.0.6.src.tar.xz \ - https://github.com/llvm/llvm-project/releases/download/llvmorg-17.0.6/clang-17.0.6.src.tar.xz && \ + https://github.com/ninja-build/ninja/archive/refs/tags/v1.12.1.tar.gz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/llvm-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/lld-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/libunwind-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/cmake-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/third-party-19.1.5.src.tar.xz \ + https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.5/clang-19.1.5.src.tar.xz && \ sha256sum -c SHA256SUM.manylinux_2_28 && \ - gzip -dc v1.11.1.tar.gz | tar -xf - && \ - xz -dc llvm-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc lld-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc libunwind-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc cmake-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc third-party-17.0.6.src.tar.xz | tar -xf - && \ - xz -dc clang-17.0.6.src.tar.xz | tar -xf - && \ - export ZSTDFLAGS=(PREFIX=/opt/rh/gcc-toolset-13/root/usr LIBDIR=/opt/rh/gcc-toolset-13/root/usr/lib64 SED_ERE_OPT=--regexp-extended MOREFLAGS="-std=c17 -O3 -fPIC -fPIE -fvisibility=hidden") && \ + gzip -dc v1.12.1.tar.gz | tar -xf - && \ + xz -dc llvm-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc lld-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc libunwind-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc cmake-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc third-party-19.1.5.src.tar.xz | tar -xf - && \ + xz -dc clang-19.1.5.src.tar.xz | tar -xf - && \ mkdir build && cd build && /opt/python/cp311-cp311/bin/python \ - ../ninja-1.11.1/configure.py --bootstrap \ + ../ninja-1.12.1/configure.py --bootstrap \ --with-python=/opt/python/cp311-cp311/bin/python && \ cp -v ninja /opt/rh/gcc-toolset-13/root/usr/bin/ninja && cd - && rm -rf build && \ - mv -v llvm-17.0.6.src llvm && \ - mv -v lld-17.0.6.src lld && \ - mv -v libunwind-17.0.6.src libunwind && \ - mv -v cmake-17.0.6.src cmake && \ - mv -v third-party-17.0.6.src third-party && \ - mv -v clang-17.0.6.src clang && \ + mv -v llvm-19.1.5.src llvm && \ + mv -v lld-19.1.5.src lld && \ + mv -v libunwind-19.1.5.src libunwind && \ + mv -v cmake-19.1.5.src cmake && \ + mv -v third-party-19.1.5.src third-party && \ + mv -v clang-19.1.5.src clang && \ cmake -Bbuild -GNinja -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=/opt/rh/gcc-toolset-13/root/usr \ -DPython3_ROOT_DIR=/opt/python/cp311-cp311 -DLLVM_LIBDIR_SUFFIX=64 \ diff --git a/utils/docker/SHA256SUM.manylinux2014 b/utils/docker/SHA256SUM.manylinux2014 index 2e48695e..63aa74fb 100644 --- a/utils/docker/SHA256SUM.manylinux2014 +++ b/utils/docker/SHA256SUM.manylinux2014 @@ -1,9 +1,9 @@ -a78f668a726ae1d3d9a7179996d97b12b90fb76ab9442a43110b972ff7ad9029 clang-17.0.6.src.tar.xz -807f069c54dc20cb47b21c1f6acafdd9c649f3ae015609040d6182cab01140f4 cmake-17.0.6.src.tar.xz -252aee1448d49caa04954fd5e27d189dd51570557313e7b281636716a238bccb cmake-3.29.3.tar.gz -9e7535a353aa862730b4ba38df42e06f6856b40c4cc51b57f27b5046dc21d70d libunwind-17.0.6.src.tar.xz -4ac13125616dc44905b85820aa403d27ec1226329b7f674daeb5f5584c6f0b22 lld-17.0.6.src.tar.xz -b638167da139126ca11917b6880207cc6e8f9d1cbb1a48d87d017f697ef78188 llvm-17.0.6.src.tar.xz -3054d0a9c9375dab1a4539cc2cc45ab340341c5d71475f9599ba7752e222947b third-party-17.0.6.src.tar.xz -31747ae633213f1eda3842686f83c2aa1412e0f5691d1c14dbbcc67fe7400cea v1.11.1.tar.gz -9c4396cc829cfae319a6e2615202e82aad41372073482fce286fac78646d3ee4 zstd-1.5.5.tar.gz +e7dfc8050407b5cc564c1c1afe19517255c9229cccd886dbd5bac9b652828d85 clang-19.1.5.src.tar.xz +a08ae477571fd5e929c27d3d0d28c6168d58dd00b6354c2de3266ae0d86ad44f cmake-19.1.5.src.tar.xz +0019dfc4b32d63c1392aa264aed2253c1e0c2fb09216f8e2cc269bbfb8bb49b5 cmake-v3.31.2.tar.gz +997b493fb604e5e2c5b11c765a4c42b37acf00a4d6e8a14f8108d5c1051d760f libunwind-19.1.5.src.tar.xz +f71835d49461a15c283aa9030a854abfd7c651a685d711a67158644b043f6f14 lld-19.1.5.src.tar.xz +7d71635948e4da1814ce8e15ec45399e4094a5442e86d352c96ded0f2b3171b6 llvm-19.1.5.src.tar.xz +22b352c35b034a4ab3f2b852b6a2602a4da8971abe459080450d9e3462550d1d third-party-19.1.5.src.tar.xz +821bdff48a3f683bc4bb3b6f0b5fe7b2d647cf65d52aeb63328c91a6c6df285a v1.12.1.tar.gz +8c29e06cf42aacc1eafc4077ae2ec6c6fcb96a626157e0593d5e82a34fd403c1 zstd-1.5.6.tar.gz diff --git a/utils/docker/SHA256SUM.manylinux_2_28 b/utils/docker/SHA256SUM.manylinux_2_28 index 0b732009..136490f8 100644 --- a/utils/docker/SHA256SUM.manylinux_2_28 +++ b/utils/docker/SHA256SUM.manylinux_2_28 @@ -1,7 +1,7 @@ -31747ae633213f1eda3842686f83c2aa1412e0f5691d1c14dbbcc67fe7400cea v1.11.1.tar.gz -a78f668a726ae1d3d9a7179996d97b12b90fb76ab9442a43110b972ff7ad9029 clang-17.0.6.src.tar.xz -807f069c54dc20cb47b21c1f6acafdd9c649f3ae015609040d6182cab01140f4 cmake-17.0.6.src.tar.xz -9e7535a353aa862730b4ba38df42e06f6856b40c4cc51b57f27b5046dc21d70d libunwind-17.0.6.src.tar.xz -4ac13125616dc44905b85820aa403d27ec1226329b7f674daeb5f5584c6f0b22 lld-17.0.6.src.tar.xz -b638167da139126ca11917b6880207cc6e8f9d1cbb1a48d87d017f697ef78188 llvm-17.0.6.src.tar.xz -3054d0a9c9375dab1a4539cc2cc45ab340341c5d71475f9599ba7752e222947b third-party-17.0.6.src.tar.xz +e7dfc8050407b5cc564c1c1afe19517255c9229cccd886dbd5bac9b652828d85 clang-19.1.5.src.tar.xz +a08ae477571fd5e929c27d3d0d28c6168d58dd00b6354c2de3266ae0d86ad44f cmake-19.1.5.src.tar.xz +997b493fb604e5e2c5b11c765a4c42b37acf00a4d6e8a14f8108d5c1051d760f libunwind-19.1.5.src.tar.xz +f71835d49461a15c283aa9030a854abfd7c651a685d711a67158644b043f6f14 lld-19.1.5.src.tar.xz +7d71635948e4da1814ce8e15ec45399e4094a5442e86d352c96ded0f2b3171b6 llvm-19.1.5.src.tar.xz +22b352c35b034a4ab3f2b852b6a2602a4da8971abe459080450d9e3462550d1d third-party-19.1.5.src.tar.xz +821bdff48a3f683bc4bb3b6f0b5fe7b2d647cf65d52aeb63328c91a6c6df285a v1.12.1.tar.gz From 8b4a2b96ff822a34a2d184409f9e1854b9964884 Mon Sep 17 00:00:00 2001 From: Sylveon Date: Tue, 3 Sep 2024 15:48:37 +0800 Subject: [PATCH 504/623] [Plugin] torchaoti Signed-off-by: Sylveon --- plugins/wasi_nn/wasinn_torch.cpp | 211 +++++++++++++++++++++++++------ plugins/wasi_nn/wasinn_torch.h | 61 ++++++++- 2 files changed, 228 insertions(+), 44 deletions(-) diff --git a/plugins/wasi_nn/wasinn_torch.cpp b/plugins/wasi_nn/wasinn_torch.cpp index 46412b5f..3cea4699 100644 --- a/plugins/wasi_nn/wasinn_torch.cpp +++ b/plugins/wasi_nn/wasinn_torch.cpp @@ -10,6 +10,142 @@ namespace WasmEdge::Host::WASINN::PyTorch { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH + +Expect TorchScript::setDevice(Device Device) { + if (Device == Device::CPU) { + TorchDevice = at::kCPU; + return ErrNo::Success; + } else if (Device == Device::GPU) { + if (!torch::cuda::is_available()) { + spdlog::error( + "[WASI-NN] CUDA Unavailable, platform Cannot support GPU target."); + return ErrNo::InvalidArgument; + } + TorchDevice = at::kCUDA; + return ErrNo::Success; + } + + spdlog::error("[WASI-NN] PyTorch Only support CPU and GPU target."); + return ErrNo::InvalidArgument; +} + +Expect TorchScript::loadFromBiary(std::istream &In, Device Device) { + if (auto Err = setDevice(Device); Err != ErrNo::Success) { + return Err; + } + TorchModel = torch::jit::load(In); + return ErrNo::Success; +} + +Expect TorchScript::loadFromPath(const std::string &Path, + Device Device) { + if (auto Err = setDevice(Device); Err != ErrNo::Success) { + return Err; + } + TorchModel = torch::jit::load(Path); + return ErrNo::Success; +} + +Expect TorchScript::run(std::vector In, + std::vector &Out) { + std::vector Inputs; + std::vector Outputs; + for (auto &OneOf : In) { + Inputs.push_back(OneOf); + } + auto RawOutput = TorchModel.forward(Inputs); + if (RawOutput.isTensorList()) { + auto OutTensors = RawOutput.toTensorVector(); + for (auto &OneOf : OutTensors) { + Out.push_back(OneOf.clone()); + } + } else if (RawOutput.isTuple()) { + auto OutTensorsTuple = RawOutput.toTuple()->elements(); + for (auto &OneOf : OutTensorsTuple) { + Out.push_back(OneOf.toTensor().clone()); + } + } else if (RawOutput.isTensor()) { + auto OutTensor = RawOutput.toTensor(); + Out.push_back(OutTensor.clone()); + } else { + spdlog::error("[WASI-NN] PyTorch backend only supports output a tensor, " + "a list of tensor or a tuple of tensor"); + return ErrNo::InvalidArgument; + } + return ErrNo::Success; +} + +Expect AOTInductor::setDevice(Device Device) { + if (Device == Device::CPU) { + TorchDevice = at::kCPU; + return ErrNo::Success; + } else if (Device == Device::GPU) { +#ifdef TORCHAOTI_USE_CUDA + TorchDevice = at::kCUDA; + return ErrNo::Success; +#else + spdlog::error( + "[WASI-NN] Please rebuild the plugin with AOTInductor CUDA support."); + return ErrNo::InvalidArgument; +#endif + } + + spdlog::error("[WASI-NN] AOTInductor Only support CPU and GPU target."); + return ErrNo::InvalidArgument; +} + +Expect AOTInductor::loadFromBiary(std::istream &, Device) { + spdlog::error("[WASI-NN] AOTInductor can not load by binary data. Please " + "pass the share library name (*.so) in nn-preload"); + return ErrNo::InvalidArgument; +} + +Expect AOTInductor::loadFromPath(const std::string &Path, + Device Device) { + if (auto Err = setDevice(Device); Err != ErrNo::Success) { + return Err; + } + if (TorchDevice == at::kCPU) { + TorchModel = new torch::inductor::AOTIModelContainerRunnerCpu(Path.c_str()); + } else if (TorchDevice == at::kCUDA) { +#ifdef TORCHAOTI_USE_CUDA + TorchModel = + new torch::inductor::AOTIModelContainerRunnerCuda(Path.c_str()); +#else + spdlog::error( + "[WASI-NN] Please rebuild the plugin with AOTInductor CUDA support."); + return ErrNo::InvalidArgument; +#endif + } else { + spdlog::error("[WASI-NN] Can not load the AOTInductor."); + return ErrNo::InvalidArgument; + } + return ErrNo::Success; +} + +Expect AOTInductor::run(std::vector In, + std::vector &Out) { + std::vector RawOutput = TorchModel->run(In); + + for (auto &OneOf : RawOutput) { + Out.push_back(OneOf.clone()); + } + return ErrNo::Success; +} + +PyModelBackend GuessPyModelBackendType(const std::string_view &Model) { + if (Model.substr(0, 8) == "preload:"sv) { + if (Model.substr(Model.size() - 3, 3) == ".so"sv) { + // AOTInductor only accept the shared library. + return PyModelBackend::AOTInductor; + } + } + + // Fall back to TorchScript if the model type is not set. + // This keep the compatibility with the old version. + return PyModelBackend::TorchScript; +} + Expect load(WasiNNEnvironment &Env, Span> Builders, Device Device, uint32_t &GraphId) noexcept { // The graph builder length must be 1. @@ -23,33 +159,37 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Add a new graph. Env.NNGraph.emplace_back(Backend::PyTorch); auto &GraphRef = Env.NNGraph.back().get(); - // Setup Graph Device - if (Device == Device::CPU) { - GraphRef.TorchDevice = at::kCPU; - } else if (Device == Device::GPU) { - if (!torch::cuda::is_available()) { - spdlog::error( - "[WASI-NN] CUDA Unavailable, platform Cannot support GPU target."); + + // Load the model from the binary data. + // Note: Pytorch use try catch to handle the error. + try { + const std::string_view BinModel(reinterpret_cast(Weight.data()), + Weight.size()); + PyModelBackend ModelType = GuessPyModelBackendType(BinModel); + + if (ModelType == PyModelBackend::TorchScript) { + GraphRef.Model = new TorchScript(); + } else if (ModelType == PyModelBackend::AOTInductor) { + GraphRef.Model = new AOTInductor(); + } else { + spdlog::error("[WASI-NN] Unknown model type."); return ErrNo::InvalidArgument; } - GraphRef.TorchDevice = at::kCUDA; - } else { - spdlog::error("[WASI-NN] PyTorch Only support CPU and GPU target."); - return ErrNo::InvalidArgument; - } - - std::istringstream BinRead( - std::string(reinterpret_cast(Weight.data()), Weight.size())); - try { - GraphRef.TorchModel = torch::jit::load(BinRead); - GraphRef.TorchModel.to(GraphRef.TorchDevice); + if (BinModel.substr(0, 8) == "preload:"sv) { + const std::string ModelFilePath(BinModel.substr(8)); + GraphRef.Model->loadFromPath(ModelFilePath, Device); + } else { + std::istringstream BinRead{std::string(BinModel)}; + // std::istringstream BinRead(BinModel); // Need C++26... + GraphRef.Model->loadFromBiary(BinRead, Device); + } } catch (const c10::Error &e) { spdlog::error("[WASI-NN] Failed when load the TorchScript model."); Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } - // Store the loaded graph. + GraphId = Env.NNGraph.size() - 1; return ErrNo::Success; } @@ -83,7 +223,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, torch::Tensor InTensor = torch::from_blob(reinterpret_cast(Tensor.Tensor.data()), Dims, Options) - .to(GraphRef.TorchDevice); + .to(GraphRef.Model->getDevice()); CxtRef.TorchInputs[Index] = InTensor.clone(); return ErrNo::Success; @@ -129,29 +269,15 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } } auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - torch::jit::IValue RawOutput = - GraphRef.TorchModel.forward(CxtRef.TorchInputs); - // TODO: more output type should be supported here - if (RawOutput.isTensorList()) { - auto OutTensors = RawOutput.toTensorVector(); - for (auto &OneOf : OutTensors) { - CxtRef.TorchOutputs.push_back(OneOf.clone()); - } - } else if (RawOutput.isTuple()) { - auto OutTensorsTuple = RawOutput.toTuple()->elements(); - for (auto &OneOf : OutTensorsTuple) { - CxtRef.TorchOutputs.push_back(OneOf.toTensor().clone()); - } - } else if (RawOutput.isTensor()) { - auto OutTensor = RawOutput.toTensor(); - CxtRef.TorchOutputs.push_back(OutTensor.clone()); - } else { - spdlog::error("[WASI-NN] PyTorch backend only supports output a tensor, " - "a list of tensor or a tuple of tensor"); - return ErrNo::InvalidArgument; - } + return GraphRef.Model->run(CxtRef.TorchInputs, CxtRef.TorchOutputs); +} + +Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + delete GraphRef.Model; return ErrNo::Success; } + #else namespace { Expect reportBackendNotSupported() noexcept { @@ -179,6 +305,9 @@ Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, Expect compute(WasiNNEnvironment &, uint32_t) noexcept { return reportBackendNotSupported(); } +Expect unload(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} #endif } // namespace WasmEdge::Host::WASINN::PyTorch diff --git a/plugins/wasi_nn/wasinn_torch.h b/plugins/wasi_nn/wasinn_torch.h index 96f8e1b4..815675a7 100644 --- a/plugins/wasi_nn/wasinn_torch.h +++ b/plugins/wasi_nn/wasinn_torch.h @@ -8,6 +8,11 @@ #include "plugin/plugin.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH +#include +#include +#ifdef TORCHAOTI_USE_CUDA +#include +#endif #include #include #endif @@ -19,16 +24,64 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::PyTorch { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH -struct Graph { - torch::jit::Module TorchModel; + +class PyBaseModule { +public: + virtual ~PyBaseModule() = default; + virtual Expect setDevice(Device Device) = 0; + virtual Expect loadFromPath(const std::string &Path, + Device Device) = 0; + virtual Expect loadFromBiary(std::istream &In, Device Device) = 0; + virtual Expect run(std::vector In, + std::vector &Out) = 0; + + torch::DeviceType getDevice() const { return TorchDevice; } + +protected: torch::DeviceType TorchDevice = at::kCPU; }; +class TorchScript : public PyBaseModule { + Expect setDevice(Device Device) override; + +public: + Expect loadFromPath(const std::string &Path, Device Device) override; + Expect loadFromBiary(std::istream &In, Device Device) override; + Expect run(std::vector In, + std::vector &Out) override; + + torch::jit::Module TorchModel; +}; + +class AOTInductor : public PyBaseModule { + Expect setDevice(Device Device) override; + +public: + Expect loadFromPath(const std::string &Path, Device Device) override; + Expect loadFromBiary(std::istream &In, Device Device) override; + Expect run(std::vector In, + std::vector &Out) override; + + torch::inductor::AOTIModelContainerRunner *TorchModel; + + ~AOTInductor() { + if (TorchModel) { + delete TorchModel; + } + } +}; + +enum class PyModelBackend { TorchScript, AOTInductor, UNKNOWN }; + +struct Graph { + PyBaseModule *Model = nullptr; +}; + struct Context { public: Context(size_t GId, Graph &) noexcept : GraphId(GId) {} size_t GraphId; - std::vector TorchInputs; + std::vector TorchInputs; std::vector TorchOutputs; }; #else @@ -55,4 +108,6 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, uint32_t &BytesWritten) noexcept; Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; } // namespace WasmEdge::Host::WASINN::PyTorch From 02d89b03c7ad20c6b64ca1b510c040ac033cbec4 Mon Sep 17 00:00:00 2001 From: Sylveon Date: Sat, 12 Oct 2024 22:31:24 +0800 Subject: [PATCH 505/623] [Misc] add warning if the abi may incompatible Signed-off-by: Sylveon --- plugins/wasi_nn/wasinn_torch.cpp | 76 ++++++++++++++++++++------------ plugins/wasi_nn/wasinn_torch.h | 7 +-- plugins/wasi_nn/wasinnenv.cpp | 4 +- 3 files changed, 56 insertions(+), 31 deletions(-) diff --git a/plugins/wasi_nn/wasinn_torch.cpp b/plugins/wasi_nn/wasinn_torch.cpp index 3cea4699..d87a910f 100644 --- a/plugins/wasi_nn/wasinn_torch.cpp +++ b/plugins/wasi_nn/wasinn_torch.cpp @@ -17,19 +17,20 @@ Expect TorchScript::setDevice(Device Device) { return ErrNo::Success; } else if (Device == Device::GPU) { if (!torch::cuda::is_available()) { - spdlog::error( - "[WASI-NN] CUDA Unavailable, platform Cannot support GPU target."); + spdlog::error("[WASI-NN] Torch: CUDA Unavailable. Please check if the " + "installed Torch version or platform supports CUDA."sv); return ErrNo::InvalidArgument; } TorchDevice = at::kCUDA; return ErrNo::Success; } - spdlog::error("[WASI-NN] PyTorch Only support CPU and GPU target."); + spdlog::error("[WASI-NN] Torch: Unknown target device. We currently support " + "only CPU and GPU targets."sv); return ErrNo::InvalidArgument; } -Expect TorchScript::loadFromBiary(std::istream &In, Device Device) { +Expect TorchScript::loadFromBinary(std::istream &In, Device Device) { if (auto Err = setDevice(Device); Err != ErrNo::Success) { return Err; } @@ -68,13 +69,23 @@ Expect TorchScript::run(std::vector In, auto OutTensor = RawOutput.toTensor(); Out.push_back(OutTensor.clone()); } else { - spdlog::error("[WASI-NN] PyTorch backend only supports output a tensor, " - "a list of tensor or a tuple of tensor"); + spdlog::error( + "[WASI-NN] Torch: The output can only be one of the following tensor " + "types: a tensor, a list of tensors, or a tuple of tensors."sv); return ErrNo::InvalidArgument; } return ErrNo::Success; } +AOTInductor::AOTInductor() : TorchModel(nullptr) { +#if defined(_GLIBCXX_USE_CXX11_ABI) && _GLIBCXX_USE_CXX11_ABI == 1 + spdlog::warn( + "[WASI-NN] Torch: AOTInductor build by pip default is not supported in " + "_GLIBCXX_USE_CXX11_ABI=1. Please rebuild the WasmEdge with " + "_GLIBCXX_USE_CXX11_ABI=0."sv); +#endif +} + Expect AOTInductor::setDevice(Device Device) { if (Device == Device::CPU) { TorchDevice = at::kCPU; @@ -84,19 +95,21 @@ Expect AOTInductor::setDevice(Device Device) { TorchDevice = at::kCUDA; return ErrNo::Success; #else - spdlog::error( - "[WASI-NN] Please rebuild the plugin with AOTInductor CUDA support."); + spdlog::error("[WASI-NN] Torch: Please rebuild the plugin with AOTInductor " + "CUDA support."sv); return ErrNo::InvalidArgument; #endif } - spdlog::error("[WASI-NN] AOTInductor Only support CPU and GPU target."); + spdlog::error("[WASI-NN] Torch: Unknown target device. We currently support " + "only CPU and GPU targets."sv); return ErrNo::InvalidArgument; } -Expect AOTInductor::loadFromBiary(std::istream &, Device) { - spdlog::error("[WASI-NN] AOTInductor can not load by binary data. Please " - "pass the share library name (*.so) in nn-preload"); +Expect AOTInductor::loadFromBinary(std::istream &, Device) { + spdlog::error( + "[WASI-NN] Torch: AOTInductor can not load by binary data. Please " + "pass the share library name (*.so) in nn-preload"sv); return ErrNo::InvalidArgument; } @@ -112,12 +125,12 @@ Expect AOTInductor::loadFromPath(const std::string &Path, TorchModel = new torch::inductor::AOTIModelContainerRunnerCuda(Path.c_str()); #else - spdlog::error( - "[WASI-NN] Please rebuild the plugin with AOTInductor CUDA support."); + spdlog::error("[WASI-NN] Torch: Please rebuild the plugin with AOTInductor " + "CUDA support."sv); return ErrNo::InvalidArgument; #endif } else { - spdlog::error("[WASI-NN] Can not load the AOTInductor."); + spdlog::error("[WASI-NN] Torch: Can not load the AOTInductor."sv); return ErrNo::InvalidArgument; } return ErrNo::Success; @@ -133,7 +146,9 @@ Expect AOTInductor::run(std::vector In, return ErrNo::Success; } -PyModelBackend GuessPyModelBackendType(const std::string_view &Model) { +PyModelBackend guessPyModelBackendType(const std::string_view &Model) { + // TODO: Add more model type detection when we supporet more OS. + // ex .dll, .dylib, etc. if (Model.substr(0, 8) == "preload:"sv) { if (Model.substr(Model.size() - 3, 3) == ".so"sv) { // AOTInductor only accept the shared library. @@ -141,6 +156,11 @@ PyModelBackend GuessPyModelBackendType(const std::string_view &Model) { } } + // ELF Header: 0x7f 'E' 'L' 'F' + if (Model.substr(0, 4) == "\x7f\x45\x4c\x46"sv) { + return PyModelBackend::AOTInductor; + } + // Fall back to TorchScript if the model type is not set. // This keep the compatibility with the old version. return PyModelBackend::TorchScript; @@ -150,7 +170,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, Device Device, uint32_t &GraphId) noexcept { // The graph builder length must be 1. if (Builders.size() != 1) { - spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1", + spdlog::error("[WASI-NN] Torch: Wrong GraphBuilder Length {:d}, expect 1"sv, Builders.size()); return ErrNo::InvalidArgument; } @@ -165,14 +185,14 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, try { const std::string_view BinModel(reinterpret_cast(Weight.data()), Weight.size()); - PyModelBackend ModelType = GuessPyModelBackendType(BinModel); + PyModelBackend ModelType = guessPyModelBackendType(BinModel); if (ModelType == PyModelBackend::TorchScript) { GraphRef.Model = new TorchScript(); } else if (ModelType == PyModelBackend::AOTInductor) { GraphRef.Model = new AOTInductor(); } else { - spdlog::error("[WASI-NN] Unknown model type."); + spdlog::error("[WASI-NN] Torch: Unknown model type."sv); return ErrNo::InvalidArgument; } @@ -182,10 +202,10 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, } else { std::istringstream BinRead{std::string(BinModel)}; // std::istringstream BinRead(BinModel); // Need C++26... - GraphRef.Model->loadFromBiary(BinRead, Device); + GraphRef.Model->loadFromBinary(BinRead, Device); } } catch (const c10::Error &e) { - spdlog::error("[WASI-NN] Failed when load the TorchScript model."); + spdlog::error("[WASI-NN] Torch: Failed when load the TorchScript model."sv); Env.NNGraph.pop_back(); return ErrNo::InvalidArgument; } @@ -210,7 +230,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } if (Tensor.RType != TensorType::F32) { spdlog::error( - "[WASI-NN] Only F32 inputs and outputs are supported for now."); + "[WASI-NN] Torch: Only F32 inputs and outputs are supported for now."sv); return ErrNo::InvalidArgument; } auto Options = @@ -235,8 +255,8 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, auto &CxtRef = Env.NNContext[ContextId].get(); if (CxtRef.TorchOutputs.size() <= Index) { spdlog::error( - "[WASI-NN] The output index {} exceeds the outputs number {}.", Index, - CxtRef.TorchOutputs.size()); + "[WASI-NN] Torch: The output index {} exceeds the outputs number {}."sv, + Index, CxtRef.TorchOutputs.size()); return ErrNo::InvalidArgument; } torch::Tensor OutTensor = @@ -258,13 +278,13 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); if (CxtRef.TorchInputs.size() == 0) { - spdlog::error("[WASI-NN] Input is not set!"); + spdlog::error("[WASI-NN] Torch: Input is not set!"sv); return ErrNo::InvalidArgument; } for (size_t I = 0; I < CxtRef.TorchInputs.size(); I++) { torch::jit::IValue InTensor = CxtRef.TorchInputs[I]; if (InTensor.isNone()) { - spdlog::error("[WASI-NN] Input [{}] is not set!", I); + spdlog::error("[WASI-NN] Torch: Input [{}] is not set!"sv, I); return ErrNo::InvalidArgument; } } @@ -274,7 +294,9 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { auto &GraphRef = Env.NNGraph[GraphId].get(); - delete GraphRef.Model; + if (GraphRef.Model) { + delete GraphRef.Model; + } return ErrNo::Success; } diff --git a/plugins/wasi_nn/wasinn_torch.h b/plugins/wasi_nn/wasinn_torch.h index 815675a7..579ba354 100644 --- a/plugins/wasi_nn/wasinn_torch.h +++ b/plugins/wasi_nn/wasinn_torch.h @@ -31,7 +31,7 @@ class PyBaseModule { virtual Expect setDevice(Device Device) = 0; virtual Expect loadFromPath(const std::string &Path, Device Device) = 0; - virtual Expect loadFromBiary(std::istream &In, Device Device) = 0; + virtual Expect loadFromBinary(std::istream &In, Device Device) = 0; virtual Expect run(std::vector In, std::vector &Out) = 0; @@ -46,7 +46,7 @@ class TorchScript : public PyBaseModule { public: Expect loadFromPath(const std::string &Path, Device Device) override; - Expect loadFromBiary(std::istream &In, Device Device) override; + Expect loadFromBinary(std::istream &In, Device Device) override; Expect run(std::vector In, std::vector &Out) override; @@ -57,8 +57,9 @@ class AOTInductor : public PyBaseModule { Expect setDevice(Device Device) override; public: + AOTInductor(); Expect loadFromPath(const std::string &Path, Device Device) override; - Expect loadFromBiary(std::istream &In, Device Device) override; + Expect loadFromBinary(std::istream &In, Device Device) override; Expect run(std::vector In, std::vector &Out) override; diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 088a2b49..c6b91ac6 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -27,6 +27,7 @@ std::map BackendMap = { {"onnx"sv, Backend::ONNX}, {"tensorflow"sv, Backend::Tensorflow}, {"pytorch"sv, Backend::PyTorch}, + {"pytorchaoti"sv, Backend::PyTorch}, {"tensorflowlite"sv, Backend::TensorflowLite}, {"autodetect"sv, Backend::Autodetect}, {"ggml"sv, Backend::GGML}, @@ -105,7 +106,8 @@ WasiNNEnvironment::WasiNNEnvironment() noexcept { auto Backend = BackendMap.find(Encode); auto Device = DeviceMap.find(Target); if (Backend != BackendMap.end() && Device != DeviceMap.end()) { - if (Backend->second == Backend::GGML) { + if (Backend->second == Backend::GGML || + (Backend->second == Backend::PyTorch && Encode == "pytorchaoti"sv)) { // In GGML, we only support loading one model from nn-preload // config. To handle paths on Windows that contains `:` in the // path, we combine the Paths into a single string separated by From c9e73d0613c206d7bc2fba8004351967b58f0f87 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Mon, 30 Dec 2024 18:29:06 +0800 Subject: [PATCH 506/623] [WASI-NN] whisper: add new options, no-timestamp and audio-ctx (#3931) * [WASI-NN] whisper: add new option Signed-off-by: grorge * [WASI-NN] whisper: fix naming issue Signed-off-by: grorge * [WASI-NN] whisper: add no-timestamp and audio-ctx Signed-off-by: grorge --------- Signed-off-by: grorge --- plugins/wasi_nn/wasinn_whisper.cpp | 511 +++++++++++++++++++++++++++-- plugins/wasi_nn/wasinn_whisper.h | 14 + 2 files changed, 496 insertions(+), 29 deletions(-) diff --git a/plugins/wasi_nn/wasinn_whisper.cpp b/plugins/wasi_nn/wasinn_whisper.cpp index 8a4c2404..a23a8c8c 100644 --- a/plugins/wasi_nn/wasinn_whisper.cpp +++ b/plugins/wasi_nn/wasinn_whisper.cpp @@ -3,6 +3,8 @@ #include "wasinn_whisper.h" #include "wasinnenv.h" +#include +#include #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER #define DR_WAV_IMPLEMENTATION @@ -16,6 +18,309 @@ namespace WasmEdge::Host::WASINN::Whisper { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER namespace { +int timestampToSample(int64_t T, int NSamples, int WhisperSampleRate) { + return std::max(0, std::min(static_cast(NSamples) - 1, + static_cast((T * WhisperSampleRate) / 100))); +} +std::string toTimestamp(int64_t T, bool Comma) { + int64_t Msec = T * 10; + int64_t Hr = Msec / (1000 * 60 * 60); + Msec = Msec - Hr * (1000 * 60 * 60); + int64_t Min = Msec / (1000 * 60); + Msec = Msec - Min * (1000 * 60); + int64_t Sec = Msec / 1000; + Msec = Msec - Sec * 1000; + + char Buf[32] = {}; + snprintf(Buf, sizeof(Buf), "%02d:%02d:%02d%s%03d", static_cast(Hr), + static_cast(Min), static_cast(Sec), Comma ? "," : ".", + static_cast(Msec)); + + return std::string(Buf); +} + +std::string +estimateDiarizationSpeaker(const std::vector> PCMF32s, + int64_t T0, int64_t T1, bool IdOnly = false) { + std::string Speaker = ""; + const int64_t NSamples = PCMF32s[0].size(); + + const int64_t Is0 = timestampToSample(T0, NSamples, WHISPER_SAMPLE_RATE); + const int64_t Is1 = timestampToSample(T1, NSamples, WHISPER_SAMPLE_RATE); + + double Energy0 = 0.0f; + double Energy1 = 0.0f; + + for (int64_t I = Is0; I < Is1; I++) { + Energy0 += fabs(PCMF32s[0][I]); + Energy1 += fabs(PCMF32s[1][I]); + } + + if (Energy0 > 1.1 * Energy1) { + Speaker = "0"; + } else if (Energy1 > 1.1 * Energy0) { + Speaker = "1"; + } else { + Speaker = "?"; + } + + if (!IdOnly) { + Speaker.insert(0, "(speaker "); + Speaker.append(")"); + } + + return Speaker; +} + +bool outputSrt(whisper_context *Ctx, const std::string &Fname, + const Config &Params, + const std::vector> &PCMF32s) { + std::ofstream Fout(Fname); + if (!Fout.is_open()) { + spdlog::error("[WASI-NN] Whisper backend: failed to open {} for writing."sv, + Fname); + return false; + } + spdlog::info("[WASI-NN] Whisper backend: saving srt output to {}."sv, Fname); + + const int NSegments = whisper_full_n_segments(Ctx); + for (int I = 0; I < NSegments; ++I) { + const std::string &Text = whisper_full_get_segment_text(Ctx, I); + const int64_t T0 = whisper_full_get_segment_t0(Ctx, I); + const int64_t T1 = whisper_full_get_segment_t1(Ctx, I); + std::string Speaker = ""; + + if (Params.Diarize && PCMF32s.size() == 2) { + Speaker = estimateDiarizationSpeaker(PCMF32s, T0, T1); + } + + Fout << I + 1 + Params.OffsetN << "\n"; + Fout << toTimestamp(T0, true) << " --> " << toTimestamp(T1, true) << "\n"; + Fout << Speaker << Text << "\n\n"; + } + return true; +} + +static bool outputLrc(whisper_context *Ctx, const std::string &Fname, + const Config &Params, + const std::vector> &PCMF32s) { + std::ofstream Fout(Fname); + if (!Fout.is_open()) { + spdlog::error("[WASI-NN] Whisper backend: failed to open {} for writing."sv, + Fname); + return false; + } + + spdlog::info("[WASI-NN] Whisper backend: saving lrc output to {}."sv, Fname); + + Fout << "[by:whisper.cpp]\n"; + + const int NSegments = whisper_full_n_segments(Ctx); + for (int I = 0; I < NSegments; ++I) { + const std::string &text = whisper_full_get_segment_text(Ctx, I); + const int64_t T = whisper_full_get_segment_t0(Ctx, I); + + int64_t Msec = T * 10; + int64_t Min = Msec / (1000 * 60); + Msec = Msec - Min * (1000 * 60); + int64_t Sec = Msec / 1000; + Msec = Msec - Sec * 1000; + + char Buf[16]; + snprintf(Buf, sizeof(Buf), "%02d:%02d.%02d", static_cast(Min), + static_cast(Sec), static_cast((Msec / 10))); + std::string TimestampLrc = std::string(Buf); + std::string Speaker = ""; + + if (Params.Diarize && PCMF32s.size() == 2) { + const int64_t t0 = whisper_full_get_segment_t0(Ctx, I); + const int64_t t1 = whisper_full_get_segment_t1(Ctx, I); + Speaker = estimateDiarizationSpeaker(PCMF32s, t0, t1); + } + + Fout << '[' << TimestampLrc << ']' << Speaker << text << "\n"; + } + + return true; +} + +std::string escapeDoubleQuotesAndBackslashes(const std::string &Str) { + std::string Escaped; + for (auto W : Str) { + if (W == '"' || W == '\\') { + Escaped += '\\'; + } + Escaped += W; + } + return Escaped; +} + +bool outputJson(whisper_context *Ctx, const std::string &Fname, + const Config &Params, + const std::vector> &PCMF32s, bool Full) { + std::ofstream Fout(Fname); + int Indent = 0; + + auto Doindent = [&]() { + for (int i = 0; i < Indent; i++) + Fout << "\t"; + }; + + auto StartArr = [&](const char *Name) { + Doindent(); + Fout << "\"" << Name << "\": [\n"; + Indent++; + }; + + auto EndArr = [&](bool End) { + Indent--; + Doindent(); + Fout << (End ? "]\n" : "],\n"); + }; + + auto StartObj = [&](const char *Name) { + Doindent(); + if (Name) { + Fout << "\"" << Name << "\": {\n"; + } else { + Fout << "{\n"; + } + Indent++; + }; + + auto EndObj = [&](bool End) { + Indent--; + Doindent(); + Fout << (End ? "}\n" : "},\n"); + }; + + auto StartValue = [&](const char *Name) { + Doindent(); + Fout << "\"" << Name << "\": "; + }; + + auto ValueS = [&](const char *Name, const std::string &Val, bool End) { + StartValue(Name); + std::string ValEscaped = escapeDoubleQuotesAndBackslashes(Val); + Fout << "\"" << ValEscaped << (End ? "\"\n" : "\",\n"); + }; + + auto EndValue = [&](bool End) { Fout << (End ? "\n" : ",\n"); }; + + auto ValueI = [&](const char *Name, const int64_t Val, bool End) { + StartValue(Name); + Fout << Val; + EndValue(End); + }; + + auto ValueF = [&](const char *Name, const float Val, bool End) { + StartValue(Name); + Fout << Val; + EndValue(End); + }; + + auto ValueB = [&](const char *Name, const bool Val, bool End) { + StartValue(Name); + Fout << (Val ? "true" : "false"); + EndValue(End); + }; + + auto TimesO = [&](int64_t T0, int64_t T1, bool End) { + StartObj("timestamps"); + ValueS("from", toTimestamp(T0, true), false); + ValueS("to", toTimestamp(T1, true), true); + EndObj(false); + StartObj("offsets"); + ValueI("from", T0 * 10, false); + ValueI("to", T1 * 10, true); + EndObj(End); + }; + + if (!Fout.is_open()) { + spdlog::error("[WASI-NN] Whisper backend: failed to open {} for writing."sv, + Fname); + return false; + } + + spdlog::info("[WASI-NN] Whisper backend: saving json output to {}."sv, Fname); + + StartObj(nullptr); + ValueS("systeminfo", whisper_print_system_info(), false); + StartObj("model"); + ValueS("type", whisper_model_type_readable(Ctx), false); + ValueB("multilingual", whisper_is_multilingual(Ctx), false); + ValueI("vocab", whisper_model_n_vocab(Ctx), false); + StartObj("audio"); + ValueI("ctx", whisper_model_n_audio_ctx(Ctx), false); + ValueI("state", whisper_model_n_audio_state(Ctx), false); + ValueI("head", whisper_model_n_audio_head(Ctx), false); + ValueI("layer", whisper_model_n_audio_layer(Ctx), true); + EndObj(false); + StartObj("text"); + ValueI("ctx", whisper_model_n_text_ctx(Ctx), false); + ValueI("state", whisper_model_n_text_state(Ctx), false); + ValueI("head", whisper_model_n_text_head(Ctx), false); + ValueI("layer", whisper_model_n_text_layer(Ctx), true); + EndObj(false); + ValueI("mels", whisper_model_n_mels(Ctx), false); + ValueI("ftype", whisper_model_ftype(Ctx), true); + EndObj(false); + StartObj("params"); + ValueS("model", "Wasi-nn preload", false); + ValueS("language", Params.SpokenLanguage, false); + ValueB("translate", Params.Translate, true); + EndObj(false); + StartObj("result"); + ValueS("language", whisper_lang_str(whisper_full_lang_id(Ctx)), true); + EndObj(false); + StartArr("transcription"); + + const int NSegments = whisper_full_n_segments(Ctx); + for (int I = 0; I < NSegments; ++I) { + const std::string &Text = whisper_full_get_segment_text(Ctx, I); + + const int64_t T0 = whisper_full_get_segment_t0(Ctx, I); + const int64_t T1 = whisper_full_get_segment_t1(Ctx, I); + + StartObj(nullptr); + TimesO(T0, T1, false); + ValueS("text", Text, !Params.Diarize && !Params.TinyDiarize && !Full); + + if (Full) { + StartArr("tokens"); + const int n = whisper_full_n_tokens(Ctx, I); + for (int j = 0; j < n; ++j) { + auto token = whisper_full_get_token_data(Ctx, I, j); + StartObj(nullptr); + ValueS("text", whisper_token_to_str(Ctx, token.id), false); + if (token.t0 > -1 && token.t1 > -1) { + // If we have per-token timestamps, write them out + TimesO(token.t0, token.t1, false); + } + ValueI("id", token.id, false); + ValueF("p", token.p, false); + ValueF("t_dtw", token.t_dtw, true); + EndObj(j == (n - 1)); + } + EndArr(!Params.Diarize && !Params.TinyDiarize); + } + + if (Params.Diarize && PCMF32s.size() == 2) { + ValueS("speaker", estimateDiarizationSpeaker(PCMF32s, T0, T1, true), + true); + } + + if (Params.TinyDiarize) { + ValueB("speaker_turn_next", + whisper_full_get_segment_speaker_turn_next(Ctx, I), true); + } + EndObj(I == (NSegments - 1)); + } + + EndArr(true); + EndObj(true); + return true; +} bool checkAudioRIFF(const std::string_view Buf, const std::string_view Format) { if (Buf.size() < 12 || Buf.substr(0, 4) != "RIFF"sv) { @@ -31,7 +336,8 @@ bool checkAudioRIFF(const std::string_view Buf, const std::string_view Format) { return true; } -bool loadWAV(Span Buf, std::vector &PCMF32) { +bool loadWAV(Span Buf, std::vector &PCMF32, + std::vector> &PCMF32s, bool Stereo) { // Not to use the helper function in examples of whisper.cpp to prevent from // copy. drwav WAV; @@ -77,6 +383,16 @@ bool loadWAV(Span Buf, std::vector &PCMF32) { static_cast(PCM16[2 * I] + PCM16[2 * I + 1]) / 65536.0f; } } + if (Stereo) { + PCMF32s.resize(2); + + PCMF32s[0].resize(N); + PCMF32s[1].resize(N); + for (uint64_t I = 0; I < N; I++) { + PCMF32s[0][I] = float(PCM16[2 * I]) / 32768.0f; + PCMF32s[1][I] = float(PCM16[2 * I + 1]) / 32768.0f; + } + } return true; } @@ -110,31 +426,27 @@ void WhisperOutputSegmentCallback(struct whisper_context *WhisperCtx, auto &CxtRef = *reinterpret_cast(UserData); const int SegN = whisper_full_n_segments(WhisperCtx); - auto ToTimeStr = [](int64_t T) -> std::string { - T *= 10; - uint32_t HR = static_cast(T / (1000 * 60 * 60)); - T %= 1000 * 60 * 60; - uint32_t M = static_cast(T / (1000 * 60)); - T %= 1000 * 60; - uint32_t S = static_cast(T / 1000); - uint32_t MS = static_cast(T % 1000); - char Buf[32]; - snprintf(Buf, sizeof(Buf), "%02d:%02d:%02d.%03d", HR, M, S, MS); - return std::string(Buf); - }; - + std::string Speaker = ""; // Output the last new N segments. for (int I = SegN - NewN; I < SegN; I++) { - int64_t T0 = whisper_full_get_segment_t0(WhisperCtx, I); - int64_t T1 = whisper_full_get_segment_t1(WhisperCtx, I); - // TODO: Add the print timestamp config. - CxtRef.Outputs += "["; - CxtRef.Outputs += ToTimeStr(T0); - CxtRef.Outputs += " --> "; - CxtRef.Outputs += ToTimeStr(T1); - CxtRef.Outputs += "] "; - CxtRef.Outputs += whisper_full_get_segment_text(WhisperCtx, I); - CxtRef.Outputs += "\n"; + int64_t T0 = 0; + int64_t T1 = 0; + if (!CxtRef.WhisperConfig.NoTimestamps) { + T0 = whisper_full_get_segment_t0(WhisperCtx, I); + T1 = whisper_full_get_segment_t1(WhisperCtx, I); + CxtRef.Outputs += "["; + CxtRef.Outputs += toTimestamp(T0, false); + CxtRef.Outputs += " --> "; + CxtRef.Outputs += toTimestamp(T1, false); + CxtRef.Outputs += "] "; + } + if (CxtRef.WhisperConfig.Diarize && CxtRef.InputPCMs.size() == 2) { + Speaker = estimateDiarizationSpeaker(CxtRef.InputPCMs, T0, T1); + } + CxtRef.Outputs += Speaker + whisper_full_get_segment_text(WhisperCtx, I); + if (!CxtRef.WhisperConfig.NoTimestamps || CxtRef.WhisperConfig.Diarize) { + CxtRef.Outputs += "\n"; + } } } @@ -161,6 +473,15 @@ void setWhisperParams(Context &CxtRef) noexcept { WParam.grammar_penalty = ConfigRef.GrammarPenalty; WParam.new_segment_callback = WhisperOutputSegmentCallback; WParam.new_segment_callback_user_data = &CxtRef; + WParam.greedy.best_of = ConfigRef.BestOf; + WParam.print_timestamps = !ConfigRef.NoTimestamps; + WParam.no_timestamps = ConfigRef.NoTimestamps; + WParam.audio_ctx = ConfigRef.AudioCtx; + WParam.strategy = + (ConfigRef.BeamSize > 1) + ? whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH + : whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY; + WParam.beam_search.beam_size = ConfigRef.BeamSize; if (ConfigRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: Config: threads: {}", @@ -201,7 +522,7 @@ Expect parseMetadata(Config &ConfigRef, return ErrNo::InvalidEncoding; } - auto PrintParsedOption = [=](std::string_view Name, const auto &Val) { + auto PrintParsedOption = [&](std::string_view Name, const auto &Val) { if (ConfigRef.EnableDebugLog) { spdlog::info( "[WASI-NN][Debug] Whisper backend: Parsed metadata -- {}:{}"sv, Name, @@ -369,6 +690,117 @@ Expect parseMetadata(Config &ConfigRef, ConfigRef.InitialPrompt = Prompt; PrintParsedOption("prompt"sv, ConfigRef.InitialPrompt); } + if (Doc.at_key("best-of").error() == simdjson::SUCCESS) { + auto Err = Doc["best-of"].get().get(ConfigRef.BestOf); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the best-of option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("best-of"sv, ConfigRef.BestOf); + } + if (Doc.at_key("beam-size").error() == simdjson::SUCCESS) { + auto Err = Doc["beam-size"].get().get(ConfigRef.BeamSize); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the beam-size option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("beam-size"sv, ConfigRef.BeamSize); + } + if (Doc.at_key("output-srt").error() == simdjson::SUCCESS) { + auto Err = Doc["output-srt"].get().get(ConfigRef.OutputSrt); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the output-srt " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("output-srt"sv, ConfigRef.OutputLrc); + } + if (Doc.at_key("output-lrc").error() == simdjson::SUCCESS) { + auto Err = Doc["output-lrc"].get().get(ConfigRef.OutputLrc); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the output-lrc" + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("output-lrc"sv, ConfigRef.OutputLrc); + } + if (Doc.at_key("output-json").error() == simdjson::SUCCESS) { + auto Err = Doc["output-json"].get().get(ConfigRef.OutputJson); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the output-json " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("output-json"sv, ConfigRef.OutputJson); + } + if (Doc.at_key("output-json-full").error() == simdjson::SUCCESS) { + auto Err = + Doc["output-json-full"].get().get(ConfigRef.OutputJsonFull); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the output-json-full " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("output-json-full"sv, ConfigRef.OutputJsonFull); + } + if (Doc.at_key("no-timestamps").error() == simdjson::SUCCESS) { + auto Err = Doc["no-timestamps"].get().get(ConfigRef.NoTimestamps); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the no-timestamps " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("no-timestamps"sv, ConfigRef.NoTimestamps); + } + if (Doc.at_key("output-file").error() == simdjson::SUCCESS) { + std::string_view FileName; + auto Err = Doc["output-file"].get().get(FileName); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the output file" + "option."sv); + return ErrNo::InvalidArgument; + } + ConfigRef.FileName = FileName; + PrintParsedOption("output-file"sv, ConfigRef.FileName); + } + if (Doc.at_key("audio-ctx").error() == simdjson::SUCCESS) { + auto Err = Doc["audio-ctx"].get().get(ConfigRef.AudioCtx); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the audio-ctx " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("audio-ctx"sv, ConfigRef.AudioCtx); + } + if (Doc.at_key("diarize").error() == simdjson::SUCCESS) { + auto Err = Doc["diarize"].get().get(ConfigRef.Diarize); + if (Err) { + spdlog::error("[WASI-NN] Whisper backend: Unable to retrieve the diarize " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("diarize"sv, ConfigRef.Diarize); + } + if (Doc.at_key("offset-n").error() == simdjson::SUCCESS) { + auto Err = Doc["offset-n"].get().get(ConfigRef.OffsetN); + if (Err) { + spdlog::error( + "[WASI-NN] Whisper backend: Unable to retrieve the offset-n " + "option."sv); + return ErrNo::InvalidArgument; + } + PrintParsedOption("offset-n"sv, ConfigRef.OffsetN); + } + return ErrNo::Success; } @@ -402,7 +834,7 @@ Expect handleTranslationConfig(whisper_context *WhisperCtx, return ErrNo::Success; } -} // namespace +} // Namespace Expect load(WasiNNEnvironment &Env, Span> Builders, [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { @@ -563,7 +995,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, spdlog::error("[WASI-NN] Only WAV format supported now."sv); return WASINN::ErrNo::InvalidArgument; } - if (!loadWAV(Tensor.Tensor, CxtRef.InputPCM)) { + if (!loadWAV(Tensor.Tensor, CxtRef.InputPCM, CxtRef.InputPCMs, + CxtRef.WhisperConfig.Diarize)) { return WASINN::ErrNo::InvalidArgument; } @@ -577,6 +1010,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t Index, Span OutBuffer, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); if (CxtRef.WhisperConfig.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: getOutput with Index {}"sv, Index); @@ -596,6 +1030,25 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, "[WASI-NN][Debug] Whisper backend: getOutput with Index {}...Done"sv, Index); } + + if (CxtRef.WhisperConfig.OutputSrt) { + const auto Fname = CxtRef.WhisperConfig.FileName + ".srt"; + outputSrt(GraphRef.WhisperCtx, Fname, CxtRef.WhisperConfig, + CxtRef.InputPCMs); + } + + if (CxtRef.WhisperConfig.OutputLrc) { + const auto Fname = CxtRef.WhisperConfig.FileName + ".lrc"; + outputLrc(GraphRef.WhisperCtx, Fname, CxtRef.WhisperConfig, + CxtRef.InputPCMs); + } + + if (CxtRef.WhisperConfig.OutputJson) { + const auto Fname = CxtRef.WhisperConfig.FileName + ".json"; + outputJson(GraphRef.WhisperCtx, Fname, CxtRef.WhisperConfig, + CxtRef.InputPCMs, CxtRef.WhisperConfig.OutputJsonFull); + } + return ErrNo::Success; } @@ -654,7 +1107,7 @@ Expect reportBackendNotSupported() noexcept { "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"whisper\" to build it."sv); return ErrNo::InvalidArgument; } -} // namespace +} // Namespace Expect load(WasiNNEnvironment &, Span>, Device, uint32_t &) noexcept { @@ -679,4 +1132,4 @@ Expect unload(WasiNNEnvironment &, uint32_t) noexcept { } #endif -} // namespace WasmEdge::Host::WASINN::Whisper +} // Namespace WasmEdge::Host::WASINN::Whisper diff --git a/plugins/wasi_nn/wasinn_whisper.h b/plugins/wasi_nn/wasinn_whisper.h index f468e9e0..7cbae0c9 100644 --- a/plugins/wasi_nn/wasinn_whisper.h +++ b/plugins/wasi_nn/wasinn_whisper.h @@ -6,6 +6,7 @@ #include "wasinntypes.h" #include "plugin/plugin.h" +#include #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER #include @@ -38,8 +39,20 @@ struct Config { bool Translate = false; bool DetectLanguage = false; bool SplitOnWord = false; + bool Diarize = false; + bool TinyDiarize = false; std::string SpokenLanguage; std::string InitialPrompt; + uint64_t BestOf = 5; + uint64_t BeamSize = 5; + uint64_t OffsetN = 0; + std::string FileName = "output"; + bool OutputSrt = false; + bool OutputLrc = false; + bool OutputJson = false; + bool OutputJsonFull = false; + bool NoTimestamps = false; + uint64_t AudioCtx = 0; // Sampling parameters: float WordThreshold = 0.01f; float EntropyThreshold = 2.40f; @@ -66,6 +79,7 @@ struct Context { size_t GraphId; // mono-channel F32 PCM input. std::vector InputPCM; + std::vector> InputPCMs; // Whisper config. Inherit from the graph and accept metadata when setting // input. Config WhisperConfig; From 23ae18887693652afe80281ef0f3b4fbbb190355 Mon Sep 17 00:00:00 2001 From: PeterD1524 <53310459+PeterD1524@users.noreply.github.com> Date: Tue, 31 Dec 2024 02:53:08 +0800 Subject: [PATCH 507/623] [WASI-NN] ChatTTS: fix GIL problem and do not call Py_Finalize (#3940) [WASI-NN] ChatTTS: handle the GIL and do not call Py_Finalize in unload Call PyEval_SaveThread to release the GIL if Py_Initialize. Added a new class GIL. When constructed, it will ensure holding the GIL by calling PyGILState_Check. When destructed, it will release the GIL if the GIL was acquired previously by calling PyGILState_Release. This class should be used before calling Python/C API. Removed the Py_Finalize in unload because numpy does not allow multiple Python interpreter initialization, and because Py_Finalize will release the resources of all graphs. Signed-off-by: PeterD1524 --- plugins/wasi_nn/wasinn_chattts.cpp | 15 +++++++++++---- plugins/wasi_nn/wasinn_chattts.h | 20 +++++++++++++++++++- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/plugins/wasi_nn/wasinn_chattts.cpp b/plugins/wasi_nn/wasinn_chattts.cpp index a889ce1d..3cb47867 100644 --- a/plugins/wasi_nn/wasinn_chattts.cpp +++ b/plugins/wasi_nn/wasinn_chattts.cpp @@ -37,7 +37,11 @@ Expect load(WASINN::WasiNNEnvironment &Env, // Create Model class if (!Py_IsInitialized()) { Py_Initialize(); + if (PyGILState_Check()) { + PyEval_SaveThread(); + } } + GIL Lock; if (GraphRef.ChatTTSModule == nullptr) { GraphRef.ChatTTSModule = PyImport_ImportModule("ChatTTS"); if (GraphRef.ChatTTSModule == nullptr) { @@ -123,6 +127,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, Env.NNGraph.pop_back(); return ErrNo::InvalidEncoding; } + GIL Lock; // Handle Refine Text Params PyObject *PromptObj = nullptr; if (Doc.at_key("prompt").error() == simdjson::SUCCESS) { @@ -253,6 +258,7 @@ Expect compute(WasiNNEnvironment &Env, spdlog::error("[WASI-NN] ChatTTS backend: Input is not set!"sv); return ErrNo::InvalidArgument; } + GIL Lock; PyObject *InputStr = PyUnicode_FromString(CxtRef.Inputs.c_str()); PyObject *InferMethod = PyObject_GetAttrString(GraphRef.Chat, "infer"); PyObject *Result = nullptr; @@ -305,20 +311,21 @@ Expect compute(WasiNNEnvironment &Env, } Expect unload(WASINN::WasiNNEnvironment &Env, - uint32_t ContextId) noexcept { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + uint32_t GraphId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN] Neural speed backend: start unload."sv); } if (Py_IsInitialized()) { + GIL Lock; Py_XDECREF(GraphRef.ParamsRefineText); Py_XDECREF(GraphRef.ParamsInferCode); Py_XDECREF(GraphRef.Chat); Py_XDECREF(GraphRef.ChatTTSModule); + GraphRef.ParamsRefineText = nullptr; + GraphRef.ParamsInferCode = nullptr; GraphRef.Chat = nullptr; GraphRef.ChatTTSModule = nullptr; - Py_Finalize(); } return WASINN::ErrNo::Success; } diff --git a/plugins/wasi_nn/wasinn_chattts.h b/plugins/wasi_nn/wasinn_chattts.h index 9b354a66..c59010db 100644 --- a/plugins/wasi_nn/wasinn_chattts.h +++ b/plugins/wasi_nn/wasinn_chattts.h @@ -17,11 +17,29 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::ChatTTS { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS +class GIL { +private: + PyGILState_STATE GState; + +public: + GIL() : GState(PyGILState_Ensure()) {} + ~GIL() { PyGILState_Release(GState); } + GIL(const GIL &) = delete; + GIL &operator=(const GIL &) = delete; +}; struct Graph { bool EnableDebugLog = false; - Graph() noexcept { Py_Initialize(); } + Graph() noexcept { + if (!Py_IsInitialized()) { + Py_Initialize(); + if (PyGILState_Check()) { + PyEval_SaveThread(); + } + } + } ~Graph() noexcept { if (Py_IsInitialized()) { + GIL Lock; Py_XDECREF(Chat); Py_XDECREF(ChatTTSModule); } From 16cf9d6a59fb89008dd814f8d99d03805854de4d Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 31 Dec 2024 15:05:11 +0800 Subject: [PATCH 508/623] [WASI-NN] ggml: show clip uses pure CPU log (#3943) Signed-off-by: hydai --- plugins/wasi_nn/wasinn_ggml.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 34ab55af..f85377d7 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -1060,14 +1060,12 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Load the clip model if not loaded. if (GraphRef.ClipContext == nullptr) { - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: Load the clip model. " - "Because llama.cpp disabled the GPU support for CLIP, " - "the step of loading images in CLIP can only use the CPU, " - "which may result in reduced efficiency. " - "(You can refer to PR https://github.com/ggerganov/llama.cpp/pull/10896)"sv); - } + spdlog::info( + "[WASI-NN] GGML backend: Load the clip model. " + "Because llama.cpp disabled the GPU support for CLIP, " + "the step of loading images in CLIP can only use the CPU, " + "which may result in reduced efficiency. " + "(You can refer to PR https://github.com/ggerganov/llama.cpp/pull/10896)"sv); GraphRef.ClipContext = clip_model_load(GraphRef.MMProjModelPath.c_str(), GraphRef.EnableLog ? 1 : 0); if (GraphRef.ClipContext == nullptr) { @@ -1077,6 +1075,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } if (clip_is_qwen2vl(GraphRef.ClipContext)) { GraphRef.VisionModelType = VisionModel::Qwen2VL; + spdlog::info( + "[WASI-NN] GGML backend: The Qwen2VL clip model is loaded."sv); if (GraphRef.EnableLog) { spdlog::info("[WASI-NN] GGML backend: Qwen2vl model detected."sv); } From 30c83c4d0849094b55c6fd8237e81a85416418e5 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 31 Dec 2024 15:24:27 +0800 Subject: [PATCH 509/623] [WASI-NN] ggml: update option types according to llama.cpp Ref: - https://github.com/ggerganov/llama.cpp/blob/master/common/common.h - https://github.com/ggerganov/llama.cpp/blob/master/include/llama.h Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 34 ++++++++++++++++----------------- plugins/wasi_nn/wasinn_ggml.h | 10 +++++----- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index f85377d7..d8ec5b7b 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -54,12 +54,12 @@ void LlamaLogCallback(ggml_log_level LogLevel, const char *LogText, Expect setupParams(Graph &GraphRef, common_params &Params) { Params.model = GraphRef.ModelFilePath; Params.n_gpu_layers = static_cast(GraphRef.NGPULayers); - Params.n_ctx = static_cast(GraphRef.CtxSize); - Params.n_batch = static_cast(GraphRef.BatchSize); - Params.n_ubatch = static_cast(GraphRef.UBatchSize); + Params.n_ctx = static_cast(GraphRef.CtxSize); + Params.n_batch = static_cast(GraphRef.BatchSize); + Params.n_ubatch = static_cast(GraphRef.UBatchSize); Params.warmup = GraphRef.WarmUp; - Params.cpuparams.n_threads = static_cast(GraphRef.Threads); - Params.cpuparams_batch.n_threads = static_cast(GraphRef.Threads); + Params.cpuparams.n_threads = static_cast(GraphRef.Threads); + Params.cpuparams_batch.n_threads = static_cast(GraphRef.Threads); Params.embedding = GraphRef.Embedding; Params.sampling.temp = static_cast(GraphRef.Temp); Params.sampling.top_p = static_cast(GraphRef.TopP); @@ -89,7 +89,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // enable-debug-log: bool // stream-stdout: bool // embedding: bool - // n-predict: uint64_t + // n-predict: int64_t // reverse-prompt: string // mmproj: string // image: string @@ -101,10 +101,10 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // use-mmap: use mmap // warmup: bool // Context parameters (used by the llama context): - // ctx-size: uint64_t - // batch-size: uint64_t - // ubatch-size: uint64_t - // threads: uint64_t + // ctx-size: int64_t + // batch-size: int64_t + // ubatch-size: int64_t + // threads: int64_t // Sampling parameters (used by the llama sampling context). // temp: double // top-p: double @@ -151,7 +151,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } } if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { - auto Err = Doc["n-predict"].get().get(GraphRef.NPredict); + auto Err = Doc["n-predict"].get().get(GraphRef.NPredict); if (Err) { spdlog::error( "[WASI-NN] GGML backend: Unable to retrieve the n-predict option."sv); @@ -255,7 +255,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // The context parameters. if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { - auto Err = Doc["ctx-size"].get().get(GraphRef.CtxSize); + auto Err = Doc["ctx-size"].get().get(GraphRef.CtxSize); if (Err) { spdlog::error( "[WASI-NN] GGML backend: Unable to retrieve the ctx-size option."sv); @@ -263,7 +263,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } } if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { - auto Err = Doc["batch-size"].get().get(GraphRef.BatchSize); + auto Err = Doc["batch-size"].get().get(GraphRef.BatchSize); if (Err) { spdlog::error( "[WASI-NN] GGML backend: Unable to retrieve the batch-size option."sv); @@ -271,7 +271,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } } if (Doc.at_key("ubatch-size").error() == simdjson::SUCCESS) { - auto Err = Doc["ubatch-size"].get().get(GraphRef.UBatchSize); + auto Err = Doc["ubatch-size"].get().get(GraphRef.UBatchSize); if (Err) { spdlog::error( "[WASI-NN] GGML backend: Unable to retrieve the ubatch-size option."sv); @@ -279,7 +279,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } } if (Doc.at_key("threads").error() == simdjson::SUCCESS) { - auto Err = Doc["threads"].get().get(GraphRef.Threads); + auto Err = Doc["threads"].get().get(GraphRef.Threads); if (Err) { spdlog::error( "[WASI-NN] GGML backend: Unable to retrieve the threads option."sv); @@ -608,7 +608,7 @@ Expect getEmbedding(WasiNNEnvironment &Env, } // Check if the input is too long. - if (static_cast(CxtRef.LlamaInputs.size()) > GraphRef.BatchSize) { + if (static_cast(CxtRef.LlamaInputs.size()) > GraphRef.BatchSize) { if (GraphRef.EnableLog) { spdlog::info( "[WASI-NN] GGML backend: the prompt is too long. " @@ -1226,7 +1226,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Prepare variables; int32_t NPast = 0; int32_t NPos = 0; - uint64_t NRemain = GraphRef.NPredict; + int64_t NRemain = GraphRef.NPredict; // Get the context size. const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext); // Minus 4 for the special tokens. (Such as , , ... tokens.) diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 241696de..887a35ba 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -49,7 +49,7 @@ struct Graph { bool Embedding = false; EmbdNormalizeType EmbdNormalize = EmbdNormalizeType::Euclidean; bool ComputeSingleStarted = false; - uint64_t NPredict; + int64_t NPredict; std::string ReversePrompt; std::string MMProjModelPath; std::string ImagePath; @@ -61,10 +61,10 @@ struct Graph { bool UseMMap = true; bool WarmUp = false; // Context parameters: - uint64_t CtxSize; - uint64_t BatchSize; - uint64_t UBatchSize; - uint64_t Threads; + int64_t CtxSize; + int64_t BatchSize; + int64_t UBatchSize; + int64_t Threads; // Sampling parameters: double Temp = 0.80; double TopP = 0.95; From 1de155231e012a851e39a7eaf32c2aa59290c42a Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 31 Dec 2024 15:37:55 +0800 Subject: [PATCH 510/623] [WASI-NN] ggml: do not append SEP when get embeddings Ref: https://github.com/ggerganov/llama.cpp/blob/7eee341/examples/embedding/embedding.cpp#L147-L154 Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index d8ec5b7b..c7af7f51 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -604,7 +604,8 @@ Expect getEmbedding(WasiNNEnvironment &Env, // Add SEP if not present. if (CxtRef.LlamaInputs.back() != llama_token_sep(GraphRef.LlamaModel)) { - CxtRef.LlamaInputs.push_back(llama_token_sep(GraphRef.LlamaModel)); + spdlog::warn( + "[WASI-NN] GGML backend: last token in the prompt is not SEP, 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header"sv); } // Check if the input is too long. From 6cd4aba9ad59e2a81e23f607a11ec1c8830618b7 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Sun, 5 Jan 2025 02:45:30 +0800 Subject: [PATCH 511/623] [Plugin] Stable Diffusion: upgrade version to dcf91 (#3950) Signed-off-by: grorge --- .../wasmedge_stablediffusion/CMakeLists.txt | 14 ++--- plugins/wasmedge_stablediffusion/sd_env.cpp | 2 +- plugins/wasmedge_stablediffusion/sd_func.cpp | 58 +++++++++++++------ plugins/wasmedge_stablediffusion/sd_func.h | 3 +- .../wasmedge_stablediffusion.cpp | 6 ++ 5 files changed, 56 insertions(+), 27 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index fb7db00f..c716e06a 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -2,12 +2,12 @@ # SPDX-FileCopyrightText: 2019-2024 Second State INC -if(WASMEDGE_PLUGIN_STABLEDIFFUSION_CUBLAS) - message(STATUS "Stable diffusion plugin: Enable SD_CUBLAS") - set(SD_CUBLAS ON CACHE BOOL "Stable diffusion plugin: Enable SD_CUBLAS") +if(WASMEDGE_PLUGIN_STABLEDIFFUSION_CUDA) + message(STATUS "Stable diffusion plugin: Enable SD_CUDA") + set(SD_CUDA ON CACHE BOOL "Stable diffusion plugin: Enable SD_CUDA") else() - message(STATUS "Stable diffusion plugin: Disable SD_CUBLAS") - set(SD_CUBLAS OFF CACHE BOOL "Stable diffusion plugin: Disable SD_CUBLAS") + message(STATUS "Stable diffusion plugin: Disable SD_CUDA") + set(SD_CUDA OFF CACHE BOOL "Stable diffusion plugin: Disable SD_CUDA") endif() if(APPLE AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" AND WASMEDGE_PLUGIN_STABLEDIFFUSION_METAL) @@ -31,8 +31,8 @@ endif() message(STATUS "Downloading stable diffusion source") FetchContent_Declare( stable-diffusion - GIT_REPOSITORY https://github.com/second-state/stable-diffusion.cpp.git - GIT_TAG d08889cb3f86f49d3f4f9c0c7e3781238c44bd3d + GIT_REPOSITORY https://github.com/leejet/stable-diffusion.cpp.git + GIT_TAG dcf91f9e0f2cbf9da472ee2a556751ed4bab2d2a GIT_SHALLOW TRUE ) set(SD_BUILD_SHARED_LIBS ON CACHE INTERNAL "Stable diffusion plugin: Build shared libs") diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp index f981cd44..f2aa0ca2 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.cpp +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -17,7 +17,7 @@ Plugin::Plugin::PluginDescriptor Descriptor{ .Name = "wasmedge_stablediffusion", .Description = "Stable Diffusion plug-in for WasmEdge.", .APIVersion = Plugin::Plugin::CurrentAPIVersion, - .Version = {0, 3, 0, 0}, + .Version = {0, 4, 0, 0}, .ModuleCount = 1, .ModuleDescriptions = (Plugin::PluginModule::ModuleDescriptor[]){ diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index c92c5e77..f0cca7a3 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -109,6 +109,24 @@ sd_image_t *readControlImage(Span ControlImage, int Width, int Height, return ControlImg; } +sd_image_t readMaskImage(Span MaskImage, int Width, int Height) { + uint8_t *MaskImageBuffer = NULL; + std::string MaskImagePath(MaskImage.begin(), MaskImage.end()); + int Channel = 0; + if (MaskImagePath.substr(0, 5) == "path:"sv) { + MaskImageBuffer = + stbi_load(MaskImagePath.substr(5).data(), &Width, &Height, &Channel, 3); + } else if (MaskImage.size() != 0) { + MaskImageBuffer = stbi_load_from_memory(MaskImage.data(), MaskImage.size(), + &Width, &Height, &Channel, 3); + } else { + std::vector Arr(Width * Height, 255); + MaskImageBuffer = Arr.data(); + } + return {static_cast(Width), static_cast(Height), 1, + MaskImageBuffer}; +} + void upscalerModel(const char *UpscaleModelPath, uint32_t UpscaleRepeats, int32_t NThreads, uint32_t BatchCount, sd_image_t *Results) { // unused for RealESRGAN_x4plus_anime_6B.pth @@ -415,24 +433,26 @@ Expect SDTextToImage::body( Expect SDImageToImage::body( const Runtime::CallingFrame &Frame, uint32_t ImagePtr, uint32_t ImageLen, - uint32_t SessionId, float Guidance, uint32_t Width, uint32_t Height, - uint32_t ControlImagePtr, uint32_t ControlImageLen, uint32_t PromptPtr, - uint32_t PromptLen, uint32_t NegativePromptPtr, uint32_t NegativePromptLen, - int32_t ClipSkip, float CfgScale, uint32_t SampleMethod, - uint32_t SampleSteps, float Strength, uint32_t Seed, uint32_t BatchCount, - float ControlStrength, float StyleRatio, uint32_t NormalizeInput, - uint32_t InputIdImagesDirPtr, uint32_t InputIdImagesDirLen, - uint32_t CannyPreprocess, uint32_t UpscaleModelPathPtr, - uint32_t UpscaleModelPathLen, uint32_t UpscaleRepeats, - uint32_t SkipLayersPtr, uint32_t SkipLayersLen, float SlgScale, - float SkipLayerStart, float SkipLayerEnd, uint32_t OutputPathPtr, - uint32_t OutputPathLen, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, - uint32_t BytesWrittenPtr) { + uint32_t MaskImagePtr, uint32_t MaskImageLen, uint32_t SessionId, + float Guidance, uint32_t Width, uint32_t Height, uint32_t ControlImagePtr, + uint32_t ControlImageLen, uint32_t PromptPtr, uint32_t PromptLen, + uint32_t NegativePromptPtr, uint32_t NegativePromptLen, int32_t ClipSkip, + float CfgScale, uint32_t SampleMethod, uint32_t SampleSteps, float Strength, + uint32_t Seed, uint32_t BatchCount, float ControlStrength, float StyleRatio, + uint32_t NormalizeInput, uint32_t InputIdImagesDirPtr, + uint32_t InputIdImagesDirLen, uint32_t CannyPreprocess, + uint32_t UpscaleModelPathPtr, uint32_t UpscaleModelPathLen, + uint32_t UpscaleRepeats, uint32_t SkipLayersPtr, uint32_t SkipLayersLen, + float SlgScale, float SkipLayerStart, float SkipLayerEnd, + uint32_t OutputPathPtr, uint32_t OutputPathLen, uint32_t OutBufferPtr, + uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { // Check memory instance from module. MEMINST_CHECK(MemInst, Frame, 0) // Check the input parameter valid. MEM_SPAN_CHECK(ImageSpan, MemInst, uint8_t, ImagePtr, ImageLen, "Failed when accessing the input image memory."sv) + MEM_SPAN_CHECK(MaskImageSpan, MemInst, uint8_t, MaskImagePtr, MaskImageLen, + "Failed when accessing the input mask image memory."sv) MEM_SPAN_CHECK(PromptSpan, MemInst, char, PromptPtr, PromptLen, "Failed when accessing the promp memory."sv) MEM_SPAN_CHECK(NegativePromptSpan, MemInst, char, NegativePromptPtr, @@ -529,15 +549,17 @@ Expect SDImageToImage::body( ControlImage = readControlImage(ControlImageSpan, Width, Height, CannyPreprocess); } + // Read mask image + sd_image_t MaskImage = readMaskImage(MaskImageSpan, Width, Height); // Generate images sd_image_t *Results = nullptr; spdlog::info("[WasmEdge-StableDiffusion] Start to generate image."sv); Results = - img2img(SDCtx, InputImage, Prompt.data(), NegativePrompt.data(), ClipSkip, - CfgScale, Guidance, Width, Height, sample_method_t(SampleMethod), - SampleSteps, Strength, Seed, BatchCount, ControlImage, - ControlStrength, StyleRatio, NormalizeInput, - InputIdImagesDir.data(), SkipLayersSpan.data(), + img2img(SDCtx, InputImage, MaskImage, Prompt.data(), + NegativePrompt.data(), ClipSkip, CfgScale, Guidance, Width, + Height, sample_method_t(SampleMethod), SampleSteps, Strength, + Seed, BatchCount, ControlImage, ControlStrength, StyleRatio, + NormalizeInput, InputIdImagesDir.data(), SkipLayersSpan.data(), SkipLayersSpan.size(), SlgScale, SkipLayerStart, SkipLayerEnd); free(ControlImage); free(InputImageBuffer); diff --git a/plugins/wasmedge_stablediffusion/sd_func.h b/plugins/wasmedge_stablediffusion/sd_func.h index cd4019ba..aa9285bc 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.h +++ b/plugins/wasmedge_stablediffusion/sd_func.h @@ -34,7 +34,8 @@ class SDImageToImage : public StableDiffusion::Func { SDImageToImage(StableDiffusion::SDEnviornment &HostEnv) : Func(HostEnv) {} Expect body(const Runtime::CallingFrame &Frame, uint32_t ImagePtr, uint32_t ImageLen, - uint32_t SessionId, float Guidance, uint32_t Width, uint32_t Height, + uint32_t MaskImagePtr, uint32_t MaskImageLen, uint32_t SessionId, + float Guidance, uint32_t Width, uint32_t Height, uint32_t ControlImagePtr, uint32_t ControlImageLen, uint32_t PromptPtr, uint32_t PromptLen, uint32_t NegativePromptPtr, uint32_t NegativePromptLen, int32_t ClipSkip, float CfgScale, diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index 83e89a65..f2114669 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -365,6 +365,8 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { std::initializer_list{ InputPathPtr, // ImagePtr static_cast(InputPath.size()), // ImageLen + 0, // MaskInputPtr + 0, // MaskInputLen SessionId, // SessionId 3.5f, // Guidance 256, // Width @@ -424,6 +426,8 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { std::initializer_list{ InputPathPtr, // ImagePtr static_cast(InputPath.size()), // ImageLen + 0, // MaskInputPtr + 0, // MaskInputLen SessionId, // SessionId 3.5f, // Guidance 256, // Width @@ -537,6 +541,8 @@ TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { std::initializer_list{ InputPathPtr, // ImagePtr static_cast(InputPath.size()), // ImageLen + 0, // MaskInputPtr + 0, // MaskInputLen -1, // SessionId 3.5f, // Guidance 256, // Width From e03700ae3fada757f241ced329ac818cd6f79c43 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 6 Jan 2025 16:16:11 +0800 Subject: [PATCH 512/623] [WASI-NN] ggml: bump llama.cpp b4419 Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 113 +++++++++++++++++--------------- plugins/wasi_nn/wasinn_ggml.h | 5 +- 2 files changed, 62 insertions(+), 56 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index c7af7f51..44c98d42 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -595,7 +595,7 @@ Expect getEmbedding(WasiNNEnvironment &Env, spdlog::info("[WASI-NN][Debug] GGML backend: handle embedding"sv); } // Clear the llama context. - llama_kv_cache_clear(GraphRef.LlamaContext); + llama_kv_cache_clear(GraphRef.LlamaContext.get()); // Use the const sequence id here. const llama_seq_id SequenceId = 0; @@ -603,7 +603,7 @@ Expect getEmbedding(WasiNNEnvironment &Env, auto ReturnCode = ErrNo::Success; // Add SEP if not present. - if (CxtRef.LlamaInputs.back() != llama_token_sep(GraphRef.LlamaModel)) { + if (CxtRef.LlamaInputs.back() != llama_token_sep(GraphRef.LlamaModel.get())) { spdlog::warn( "[WASI-NN] GGML backend: last token in the prompt is not SEP, 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header"sv); } @@ -620,15 +620,15 @@ Expect getEmbedding(WasiNNEnvironment &Env, return ErrNo::PromptTooLong; } - const int32_t NEmbd = llama_n_embd(GraphRef.LlamaModel); + const int32_t NEmbd = llama_n_embd(GraphRef.LlamaModel.get()); struct llama_batch Batch = llama_batch_init( /* n_tokens_alloc */ static_cast(GraphRef.BatchSize), /* embd */ 0, /* n_seq_max */ 1); std::vector Embeddings(NEmbd); batchAddSeq(Batch, CxtRef.LlamaInputs, SequenceId); - ReturnCode = batchDecode(GraphRef.LlamaContext, Batch, Embeddings.data(), - NEmbd, GraphRef.EmbdNormalize); + ReturnCode = batchDecode(GraphRef.LlamaContext.get(), Batch, + Embeddings.data(), NEmbd, GraphRef.EmbdNormalize); if (ReturnCode != ErrNo::Success) { spdlog::error("[WASI-NN] GGML backend: failed to evaluate input tokens."sv); return ReturnCode; @@ -642,12 +642,12 @@ Expect getEmbedding(WasiNNEnvironment &Env, } if (GraphRef.EnableLog) { - common_perf_print(GraphRef.LlamaContext, /* Sampler */ nullptr); + common_perf_print(GraphRef.LlamaContext.get(), /* Sampler */ nullptr); } // We clear the contexts here to keep the ggml plugin stateless. // Users could fully control the contexts by themselves via their prompt. - llama_kv_cache_clear(GraphRef.LlamaContext); + llama_kv_cache_clear(GraphRef.LlamaContext.get()); llama_batch_free(Batch); if (GraphRef.EnableDebugLog) { @@ -886,8 +886,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize the llama model and context. common_init_result LlamaInit = common_init_from_params(Params); - GraphRef.LlamaModel = LlamaInit.model; - GraphRef.LlamaContext = LlamaInit.context; + GraphRef.LlamaModel = std::move(LlamaInit.model); + GraphRef.LlamaContext = std::move(LlamaInit.context); if (GraphRef.LlamaModel == nullptr) { spdlog::error("[WASI-NN] GGML backend: Error: unable to init model."sv); Env.NNGraph.pop_back(); @@ -963,9 +963,9 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (IsModelParamsUpdated) { llama_model_params ModelParams = llama_model_default_params(); ModelParams.n_gpu_layers = static_cast(GraphRef.NGPULayers); - llama_free_model(GraphRef.LlamaModel); - GraphRef.LlamaModel = llama_load_model_from_file( - GraphRef.ModelFilePath.c_str(), ModelParams); + GraphRef.LlamaModel.reset(); + GraphRef.LlamaModel = llama_model_ptr(llama_load_model_from_file( + GraphRef.ModelFilePath.c_str(), ModelParams)); if (GraphRef.LlamaModel == nullptr) { spdlog::error( "[WASI-NN] GGML backend: Error: unable to init model."sv); @@ -983,11 +983,13 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, spdlog::info( "[WASI-NN] GGML backend: Reloaded model due to parameters change."sv); } - llama_free(GraphRef.LlamaContext); + GraphRef.LlamaContext.reset(); common_params Params; setupParams(GraphRef, Params); - GraphRef.LlamaContext = llama_new_context_with_model( - GraphRef.LlamaModel, common_context_params_to_llama(Params)); + GraphRef.LlamaContext = + llama_context_ptr(llama_context_ptr(llama_new_context_with_model( + GraphRef.LlamaModel.get(), + common_context_params_to_llama(Params)))); if (GraphRef.LlamaContext == nullptr) { spdlog::error( "[WASI-NN] GGML backend: Error: unable to init context."sv); @@ -1007,7 +1009,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: clear llama context"sv); } - llama_kv_cache_clear(GraphRef.LlamaContext); + llama_kv_cache_clear(GraphRef.LlamaContext.get()); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: clear llama context...Done"sv); } @@ -1026,7 +1028,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: tokenize text prompt"sv); } - CxtRef.LlamaInputs = common_tokenize(GraphRef.LlamaContext, Prompt, + CxtRef.LlamaInputs = common_tokenize(GraphRef.LlamaContext.get(), Prompt, AddSpecial, ParseSpecial); if (GraphRef.EnableDebugLog) { spdlog::info( @@ -1119,12 +1121,13 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, std::string PromptBeforeImage = Prompt.substr(0, PlaceholderPosition); std::string PromptAfterImage = Prompt.substr(PlaceholderPosition + PromptImagePlaceholder.length()); - std::vector EmbdInputBeforeImage = common_tokenize( - GraphRef.LlamaContext, PromptBeforeImage, AddSpecial, ParseSpecial); + std::vector EmbdInputBeforeImage = + common_tokenize(GraphRef.LlamaContext.get(), PromptBeforeImage, + AddSpecial, ParseSpecial); // Do not add special token (such as , , ... tokens.) to the // tokens after the image. std::vector EmbdInputAfterImage = common_tokenize( - GraphRef.LlamaContext, PromptAfterImage, false, ParseSpecial); + GraphRef.LlamaContext.get(), PromptAfterImage, false, ParseSpecial); CxtRef.LlavaImagePosition = EmbdInputBeforeImage.size(); CxtRef.LlamaInputs.reserve(EmbdInputBeforeImage.size() + EmbdInputAfterImage.size()); @@ -1216,20 +1219,20 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } // Clear the llama context. - llama_kv_cache_clear(GraphRef.LlamaContext); + llama_kv_cache_clear(GraphRef.LlamaContext.get()); // Setup the parameters and sampler. common_params Params; setupParams(GraphRef, Params); struct common_sampler *Sampler = - common_sampler_init(GraphRef.LlamaModel, Params.sampling); + common_sampler_init(GraphRef.LlamaModel.get(), Params.sampling); // Prepare variables; int32_t NPast = 0; int32_t NPos = 0; int64_t NRemain = GraphRef.NPredict; // Get the context size. - const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext); + const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); // Minus 4 for the special tokens. (Such as , , ... tokens.) const uint64_t MaxTokensListSize = NCtx - 4; // Return value. @@ -1248,7 +1251,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Evaluate input tokens. if (CxtRef.LlavaImageEmbd == nullptr) { // Text only prompt. - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), std::move(CxtRef.LlamaInputs), NPast, NPos); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -1263,7 +1266,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { std::vector EmbdInputAfterImage(CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition, CxtRef.LlamaInputs.end()); - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), std::move(EmbdInputBeforeImage), NPast, NPos); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -1274,14 +1277,14 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { bool EvalImageStatus = false; switch (GraphRef.VisionModelType) { case VisionModel::Llava: - EvalImageStatus = - llava_eval_image_embed(GraphRef.LlamaContext, CxtRef.LlavaImageEmbd, - static_cast(GraphRef.BatchSize), &NPast); + EvalImageStatus = llava_eval_image_embed( + GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, + static_cast(GraphRef.BatchSize), &NPast); break; case VisionModel::Qwen2VL: auto ImageSize = clip_get_load_image_size(GraphRef.ClipContext); EvalImageStatus = evaluateQwen2vlImageEmbed( - GraphRef.LlamaContext, CxtRef.LlavaImageEmbd, + GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, static_cast(GraphRef.BatchSize), &NPast, &NPos, ImageSize); break; } @@ -1291,7 +1294,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); return ErrNo::RuntimeError; } - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), std::move(EmbdInputAfterImage), NPast, NPos); if (ReturnCode != ErrNo::Success) { spdlog::error( @@ -1306,17 +1309,19 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } while (NRemain > 0) { // Use idx = -1 to sample the next token. - const llama_token Id = - common_sampler_sample(Sampler, GraphRef.LlamaContext, /* idx */ -1); + const llama_token Id = common_sampler_sample( + Sampler, GraphRef.LlamaContext.get(), /* idx */ -1); common_sampler_accept(Sampler, Id, /* accept_grammar */ true); --NRemain; // Save the output token. CxtRef.LlamaOutputTokens.emplace_back(Id); - CxtRef.LlamaOutputs += common_token_to_piece(GraphRef.LlamaContext, Id); + CxtRef.LlamaOutputs += + common_token_to_piece(GraphRef.LlamaContext.get(), Id); // When setting StreamStdout, we print the output to stdout. if (GraphRef.StreamStdout) { - fmt::print("{}"sv, common_token_to_piece(GraphRef.LlamaContext, Id)); + fmt::print("{}"sv, + common_token_to_piece(GraphRef.LlamaContext.get(), Id)); std::fflush(stdout); } // Break if reverse prompt is found. @@ -1328,15 +1333,16 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { break; } // Deal with end of text token. - if (llama_token_is_eog(GraphRef.LlamaModel, common_sampler_last(Sampler))) { + if (llama_token_is_eog(GraphRef.LlamaModel.get(), + common_sampler_last(Sampler))) { if (GraphRef.EnableLog) { spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); } break; } // Evaluate the output token. - ReturnCode = - evaluateTokens(GraphRef, GraphRef.LlamaContext, {Id}, NPast, NPos); + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), {Id}, + NPast, NPos); if (ReturnCode != ErrNo::Success) { break; } @@ -1348,7 +1354,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // End of main predict loop. if (GraphRef.EnableLog) { - common_perf_print(GraphRef.LlamaContext, Sampler); + common_perf_print(GraphRef.LlamaContext.get(), Sampler); } // We free the contexts here to keep the ggml plugin stateless. @@ -1403,7 +1409,7 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, return ErrNo::Success; } std::string LastToken = common_token_to_piece( - GraphRef.LlamaContext, CxtRef.LlamaOutputTokens.back()); + GraphRef.LlamaContext.get(), CxtRef.LlamaOutputTokens.back()); std::copy_n(LastToken.data(), LastToken.length(), OutBuffer.data()); BytesWritten = static_cast(LastToken.length()); if (GraphRef.EnableDebugLog) { @@ -1446,18 +1452,18 @@ Expect computeSingle(WasiNNEnvironment &Env, } // Clear the llama context. - llama_kv_cache_clear(GraphRef.LlamaContext); + llama_kv_cache_clear(GraphRef.LlamaContext.get()); // Setup the parameters and sampler. common_params Params; setupParams(GraphRef, Params); CxtRef.LlamaSampler = - common_sampler_init(GraphRef.LlamaModel, Params.sampling); + common_sampler_init(GraphRef.LlamaModel.get(), Params.sampling); CxtRef.LlamaNPast = 0; CxtRef.LlamaNPos = 0; // Get the context size. - const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext); + const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); // Minus 4 for the special tokens. (Such as , , ... tokens.) const uint64_t MaxTokensListSize = NCtx - 4; // Return value. @@ -1476,7 +1482,7 @@ Expect computeSingle(WasiNNEnvironment &Env, // Evaluate input tokens. if (CxtRef.LlavaImageEmbd == nullptr) { // Text only prompt. - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), std::move(CxtRef.LlamaInputs), CxtRef.LlamaNPast, CxtRef.LlamaNPos); if (ReturnCode != ErrNo::Success) { @@ -1492,7 +1498,7 @@ Expect computeSingle(WasiNNEnvironment &Env, std::vector EmbdInputAfterImage( CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition, CxtRef.LlamaInputs.end()); - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), std::move(EmbdInputBeforeImage), CxtRef.LlamaNPast, CxtRef.LlamaNPos); if (ReturnCode != ErrNo::Success) { @@ -1501,14 +1507,14 @@ Expect computeSingle(WasiNNEnvironment &Env, return ReturnCode; } bool EvalImageStatus = llava_eval_image_embed( - GraphRef.LlamaContext, CxtRef.LlavaImageEmbd, + GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, static_cast(GraphRef.BatchSize), &CxtRef.LlamaNPast); if (!EvalImageStatus) { spdlog::error( "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); return ErrNo::RuntimeError; } - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), std::move(EmbdInputAfterImage), CxtRef.LlamaNPast, CxtRef.LlamaNPos); if (ReturnCode != ErrNo::Success) { @@ -1526,15 +1532,15 @@ Expect computeSingle(WasiNNEnvironment &Env, auto ReturnCode = ErrNo::Success; // Use idx = -1 to sample the next token. const llama_token Id = common_sampler_sample( - CxtRef.LlamaSampler, GraphRef.LlamaContext, /* idx */ -1); + CxtRef.LlamaSampler, GraphRef.LlamaContext.get(), /* idx */ -1); common_sampler_accept(CxtRef.LlamaSampler, Id, /* accept_grammar */ true); // Save the output token. // In single token mode, we do not handle StreamStdout and ReversePrompt. CxtRef.LlamaOutputTokens.emplace_back(Id); - CxtRef.LlamaOutputs += common_token_to_piece(GraphRef.LlamaContext, Id); + CxtRef.LlamaOutputs += common_token_to_piece(GraphRef.LlamaContext.get(), Id); // Deal with end of text token. - if (llama_token_is_eog(GraphRef.LlamaModel, + if (llama_token_is_eog(GraphRef.LlamaModel.get(), common_sampler_last(CxtRef.LlamaSampler))) { ReturnCode = ErrNo::EndOfSequence; if (GraphRef.EnableLog) { @@ -1543,7 +1549,7 @@ Expect computeSingle(WasiNNEnvironment &Env, } // Evaluate the output token if not EOS. if (ReturnCode != ErrNo::EndOfSequence) { - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext, {Id}, + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), {Id}, CxtRef.LlamaNPast, CxtRef.LlamaNPos); } if (GraphRef.EnableDebugLog) { @@ -1569,7 +1575,7 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Logging for the llama timings. if (GraphRef.EnableLog) { - common_perf_print(GraphRef.LlamaContext, CxtRef.LlamaSampler); + common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler); } // Clear the outputs. @@ -1589,7 +1595,7 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { spdlog::info( "[WASI-NN][Debug] GGML backend: finiSingle: clear the llama context"sv); } - llama_kv_cache_clear(GraphRef.LlamaContext); + llama_kv_cache_clear(GraphRef.LlamaContext.get()); common_sampler_reset(CxtRef.LlamaSampler); common_sampler_free(CxtRef.LlamaSampler); CxtRef.LlamaSampler = nullptr; @@ -1626,8 +1632,7 @@ Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { if (IsDebugLog) { spdlog::info("[WASI-NN][Debug] GGML backend: unload: free llama model"sv); } - llama_free_model(GraphRef.LlamaModel); - GraphRef.LlamaModel = nullptr; + GraphRef.LlamaModel.reset(); if (IsDebugLog) { spdlog::info( "[WASI-NN][Debug] GGML backend: unload: free llama model...Done"sv); diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 887a35ba..b6b31455 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -9,6 +9,7 @@ #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML #include +#include #include #include #include @@ -38,9 +39,9 @@ enum class VisionModel : uint8_t { }; struct Graph { - llama_model *LlamaModel = nullptr; + llama_model_ptr LlamaModel = nullptr; std::string ModelFilePath; - llama_context *LlamaContext = nullptr; + llama_context_ptr LlamaContext = nullptr; struct clip_ctx *ClipContext = nullptr; // Plugin parameters: bool EnableLog = false; From f00c77c1390410038b5ee4ebb9eba71b62982aa6 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 6 Jan 2025 16:26:34 +0800 Subject: [PATCH 513/623] [WASI-NN] ggml: support split-mode option Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 22 ++++++++++++++++++++++ plugins/wasi_nn/wasinn_ggml.h | 1 + 2 files changed, 23 insertions(+) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 44c98d42..69af671f 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -58,6 +58,7 @@ Expect setupParams(Graph &GraphRef, common_params &Params) { Params.n_batch = static_cast(GraphRef.BatchSize); Params.n_ubatch = static_cast(GraphRef.UBatchSize); Params.warmup = GraphRef.WarmUp; + Params.split_mode = GraphRef.SplitMode; Params.cpuparams.n_threads = static_cast(GraphRef.Threads); Params.cpuparams_batch.n_threads = static_cast(GraphRef.Threads); Params.embedding = GraphRef.Embedding; @@ -100,6 +101,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // tensor-split: string, comma-separated floating number list // use-mmap: use mmap // warmup: bool + // split-mode: string, {none,layer,row} // Context parameters (used by the llama context): // ctx-size: int64_t // batch-size: int64_t @@ -252,6 +254,26 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } + if (Doc.at_key("split-mode").error() == simdjson::SUCCESS) { + std::string_view SplitMode; + auto Err = Doc["split-mode"].get().get(SplitMode); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the split-mode option."sv); + return ErrNo::InvalidArgument; + } + if (SplitMode == "none") { + GraphRef.SplitMode = LLAMA_SPLIT_MODE_NONE; + } else if (SplitMode == "layer") { + GraphRef.SplitMode = LLAMA_SPLIT_MODE_LAYER; + } else if (SplitMode == "row") { + GraphRef.SplitMode = LLAMA_SPLIT_MODE_ROW; + } else { + spdlog::error("[WASI-NN] GGML backend: Invalid split-mode option: {}"sv, + SplitMode); + return ErrNo::InvalidArgument; + } + } // The context parameters. if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index b6b31455..8e6717b5 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -61,6 +61,7 @@ struct Graph { std::vector TensorSplit; bool UseMMap = true; bool WarmUp = false; + enum llama_split_mode SplitMode = LLAMA_SPLIT_MODE_LAYER; // Context parameters: int64_t CtxSize; int64_t BatchSize; From c43d63319c37a219db99fe11bd93aa3f5ab746c9 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 6 Jan 2025 16:31:34 +0800 Subject: [PATCH 514/623] [WASI-NN] ggml: support seed option Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 10 ++++++++++ plugins/wasi_nn/wasinn_ggml.h | 1 + 2 files changed, 11 insertions(+) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 69af671f..bcb5d397 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -68,6 +68,7 @@ Expect setupParams(Graph &GraphRef, common_params &Params) { Params.sampling.penalty_present = static_cast(GraphRef.PresencePenalty); Params.sampling.grammar = GraphRef.Grammar; + Params.sampling.seed = static_cast(GraphRef.Seed); return ErrNo::Success; } @@ -114,6 +115,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // presence-penalty: double // frequency-penalty: double // grammar: string + // seed: uint64_t // Get the current llama parameters. common_params Params; @@ -374,6 +376,14 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, GraphRef.Grammar = json_schema_to_grammar(nlohmann::ordered_json::parse(JsonSchema)); } + if (Doc.at_key("seed").error() == simdjson::SUCCESS) { + auto Err = Doc["seed"].get().get(GraphRef.Seed); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the seed option."sv); + return ErrNo::InvalidArgument; + } + } // Check if the model is updated. if (IsModelUpdated && Params.n_gpu_layers != GraphRef.NGPULayers) { diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 8e6717b5..e082710c 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -74,6 +74,7 @@ struct Graph { double PresencePenalty = 0.00; double FrequencyPenalty = 0.00; std::string Grammar; + uint64_t Seed = LLAMA_DEFAULT_SEED; }; struct Context { From 52ab729f131c82a8798592269a615dcb6d4c8e9a Mon Sep 17 00:00:00 2001 From: YiYing He Date: Wed, 18 Dec 2024 21:43:50 +0800 Subject: [PATCH 515/623] [CMake] Move the boost, libpng, and libjpeg dependency to the helper. Signed-off-by: YiYing He --- plugins/wasmedge_image/CMakeLists.txt | 102 ++------------------------ 1 file changed, 6 insertions(+), 96 deletions(-) diff --git a/plugins/wasmedge_image/CMakeLists.txt b/plugins/wasmedge_image/CMakeLists.txt index 34bb9285..af130881 100644 --- a/plugins/wasmedge_image/CMakeLists.txt +++ b/plugins/wasmedge_image/CMakeLists.txt @@ -19,107 +19,17 @@ target_include_directories(wasmedgePluginWasmEdgeImage ${CMAKE_CURRENT_SOURCE_DIR} ) -# Find the libjpeg and libpng. -add_library(wasmedgePluginWasmEdgeImageJPEG STATIC IMPORTED GLOBAL) -add_library(wasmedgePluginWasmEdgeImagePNG STATIC IMPORTED GLOBAL) -if(APPLE) - # For MacOS, use the installed libjpeg, libpng, and zlib static library. - find_package(JPEG REQUIRED) - find_package(PNG REQUIRED) - # The find_package will get the shared library. Therefore find the static one. - find_library(JPEG_STATIC NAMES libjpeg.a) - find_library(PNG_STATIC NAMES libpng16.a) - set_target_properties(wasmedgePluginWasmEdgeImageJPEG - PROPERTIES - IMPORTED_LOCATION ${JPEG_STATIC} - INTERFACE_INCLUDE_DIRECTORIES "${JPEG_INCLUDE_DIR}" - ) - set_target_properties(wasmedgePluginWasmEdgeImagePNG - PROPERTIES - IMPORTED_LOCATION ${PNG_STATIC} - INTERFACE_INCLUDE_DIRECTORIES "${PNG_INCLUDE_DIR}" - ) -elseif(UNIX) - # Fetch and build libjpeg and libpng. - include(FetchContent) - message(STATUS "Downloading libpng source") - FetchContent_Declare( - wasmedge_image_libpng - URL "https://downloads.sourceforge.net/libpng/libpng-1.6.39.tar.gz" - URL_HASH "SHA256=af4fb7f260f839919e5958e5ab01a275d4fe436d45442a36ee62f73e5beb75ba" - ) - FetchContent_MakeAvailable(wasmedge_image_libpng) - message(STATUS "Downloading libpng source - done") - - message(STATUS "Downloading libjpeg source") - FetchContent_Declare( - wasmedge_image_libjpeg - URL "http://ijg.org/files/jpegsrc.v9e.tar.gz" - URL_HASH "SHA256=4077d6a6a75aeb01884f708919d25934c93305e49f7e3f36db9129320e6f4f3d" - ) - FetchContent_MakeAvailable(wasmedge_image_libjpeg) - message(STATUS "Downloading libjpeg source - done") - - add_custom_command( - OUTPUT ${wasmedge_image_libjpeg_SOURCE_DIR}/.libs/libjpeg.a - COMMAND ${CMAKE_COMMAND} -E env CFLAGS=-fPIC ./configure --enable-shared=off - COMMAND make - WORKING_DIRECTORY ${wasmedge_image_libjpeg_SOURCE_DIR} - ) - add_custom_command( - OUTPUT ${wasmedge_image_libpng_SOURCE_DIR}/.libs/libpng16.a - COMMAND ${CMAKE_COMMAND} -E env CFLAGS=-fPIC ./configure --enable-shared=off - COMMAND make - WORKING_DIRECTORY ${wasmedge_image_libpng_SOURCE_DIR} - ) - add_custom_target(wasmedgePluginWasmEdgeImageJPEG_target - ALL DEPENDS - ${wasmedge_image_libjpeg_SOURCE_DIR}/.libs/libjpeg.a - ) - add_custom_target(wasmedgePluginWasmEdgeImagePNG_target - ALL DEPENDS - ${wasmedge_image_libpng_SOURCE_DIR}/.libs/libpng16.a - ) - add_dependencies(wasmedgePluginWasmEdgeImageJPEG wasmedgePluginWasmEdgeImageJPEG_target) - add_dependencies(wasmedgePluginWasmEdgeImagePNG wasmedgePluginWasmEdgeImagePNG_target) - set_target_properties(wasmedgePluginWasmEdgeImageJPEG - PROPERTIES - IMPORTED_LOCATION ${wasmedge_image_libjpeg_SOURCE_DIR}/.libs/libjpeg.a - INTERFACE_INCLUDE_DIRECTORIES ${wasmedge_image_libjpeg_SOURCE_DIR} - ) - set_target_properties(wasmedgePluginWasmEdgeImagePNG - PROPERTIES - IMPORTED_LOCATION ${wasmedge_image_libpng_SOURCE_DIR}/.libs/libpng16.a - INTERFACE_INCLUDE_DIRECTORIES ${wasmedge_image_libpng_SOURCE_DIR} - ) -endif() - -# Need zlib and boost. +# Need libjpeg, libpng, zlib, and boost. find_package(ZLIB REQUIRED) -find_package(Boost 1.74.0 CONFIG) -if(${Boost_FOUND}) -else() - include(FetchContent) - message(STATUS "Downloading boost 1.82.0 source") - FetchContent_Declare( - Boost - URL http://sources.buildroot.net/boost/boost_1_82_0.tar.bz2 - URL_HASH SHA256=a6e1ab9b0860e6a2881dd7b21fe9f737a095e5f33a3a874afc6a345228597ee6 - ) - set(BOOST_ENABLE_CMAKE ON) - set(BOOST_RUNTIME_LINK static) - FetchContent_MakeAvailable(Boost) - message(STATUS "Downloading boost 1.82.0 source - done") - add_library(Boost_boost INTERFACE) - add_library(Boost::boost ALIAS Boost_boost) - target_include_directories(Boost_boost SYSTEM INTERFACE ${boost_SOURCE_DIR}) -endif() +wasmedge_setup_jpeg() +wasmedge_setup_png() +wasmedge_setup_boost() target_link_libraries(wasmedgePluginWasmEdgeImage PUBLIC Boost::boost - wasmedgePluginWasmEdgeImageJPEG - wasmedgePluginWasmEdgeImagePNG + wasmedgeDepsJPEG + wasmedgeDepsPNG z ) if(WASMEDGE_LINK_PLUGINS_STATIC) From 7acc73fb31c86200b37d4b4343060b54e1251be5 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Thu, 19 Dec 2024 01:39:12 +0800 Subject: [PATCH 516/623] [WASI-NN] ggml: refine the logging. Signed-off-by: YiYing He --- plugins/wasi_nn/wasinn_ggml.cpp | 884 ++++++++++++-------------------- plugins/wasi_nn/wasinn_ggml.h | 2 +- 2 files changed, 337 insertions(+), 549 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index bcb5d397..5239a01c 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -27,6 +27,27 @@ namespace WasmEdge::Host::WASINN::GGML { namespace { +// Macro for logging debug message. +#define LOG_DEBUG(Debug, ...) \ + if (Debug) { \ + spdlog::info("[WASI-NN][Debug] GGML backend: "sv __VA_ARGS__); \ + } + +// Macro for logging info message. +#define LOG_INFO(Info, ...) \ + if (Info) { \ + spdlog::info("[WASI-NN] GGML backend: "sv __VA_ARGS__); \ + } + +// Macro for logging warning message. +#define LOG_WARN(...) spdlog::warn("[WASI-NN] GGML backend: "sv __VA_ARGS__); + +// Macro for logging error message and return. +#define RET_ERROR(Error, ...) \ + spdlog::error("[WASI-NN] GGML backend: "sv __VA_ARGS__); \ + return Error; + +// Llama logging callback. void LlamaLogCallback(ggml_log_level LogLevel, const char *LogText, void *UserData) { Graph &GraphRef = *reinterpret_cast(UserData); @@ -51,7 +72,8 @@ void LlamaLogCallback(ggml_log_level LogLevel, const char *LogText, } } -Expect setupParams(Graph &GraphRef, common_params &Params) { +// Setup llama common params from graph. +void setupParams(Graph &GraphRef, common_params &Params) { Params.model = GraphRef.ModelFilePath; Params.n_gpu_layers = static_cast(GraphRef.NGPULayers); Params.n_ctx = static_cast(GraphRef.CtxSize); @@ -69,9 +91,9 @@ Expect setupParams(Graph &GraphRef, common_params &Params) { static_cast(GraphRef.PresencePenalty); Params.sampling.grammar = GraphRef.Grammar; Params.sampling.seed = static_cast(GraphRef.Seed); - return ErrNo::Success; } +// Parse metadata from json. Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, bool *IsModelUpdated = nullptr, bool *IsContextUpdated = nullptr) noexcept { @@ -79,8 +101,7 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, simdjson::dom::element Doc; auto ParseError = Parser.parse(Metadata).get(Doc); if (ParseError) { - spdlog::error("[WASI-NN] GGML backend: Parse metadata error"sv); - return ErrNo::InvalidEncoding; + RET_ERROR(ErrNo::InvalidEncoding, "parse metadata error."sv) } // Get metadata from the json. @@ -125,50 +146,44 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { auto Err = Doc["enable-log"].get().get(GraphRef.EnableLog); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the enable-log option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the enable-log option."sv) } } if (Doc.at_key("enable-debug-log").error() == simdjson::SUCCESS) { auto Err = Doc["enable-debug-log"].get().get(GraphRef.EnableDebugLog); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the enable-debug-log option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the enable-debug-log option."sv) } } if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { auto Err = Doc["stream-stdout"].get().get(GraphRef.StreamStdout); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the stream-stdout option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the stream-stdout option."sv) } } if (Doc.at_key("embedding").error() == simdjson::SUCCESS) { auto Err = Doc["embedding"].get().get(GraphRef.Embedding); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the embedding option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embedding option."sv) } } if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { auto Err = Doc["n-predict"].get().get(GraphRef.NPredict); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the n-predict option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-predict option."sv) } } if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { std::string_view ReversePrompt; auto Err = Doc["reverse-prompt"].get().get(ReversePrompt); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the reverse-prompt option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the reverse-prompt option."sv) } GraphRef.ReversePrompt = ReversePrompt; } @@ -176,9 +191,8 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, std::string_view MMProjModelPath; auto Err = Doc["mmproj"].get().get(MMProjModelPath); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the mmproj option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mmproj option."sv) } GraphRef.MMProjModelPath = MMProjModelPath; } @@ -186,9 +200,8 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, std::string_view ImagePath; auto Err = Doc["image"].get().get(ImagePath); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the image option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the image option."sv) } GraphRef.ImagePath = ImagePath; } @@ -197,17 +210,15 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { auto Err = Doc["n-gpu-layers"].get().get(GraphRef.NGPULayers); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the n-gpu-layers option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-gpu-layers option."sv) } } if (Doc.at_key("main-gpu").error() == simdjson::SUCCESS) { auto Err = Doc["main-gpu"].get().get(GraphRef.MainGPU); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the main-gpu option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the main-gpu option."sv) } } if (Doc.at_key("tensor-split").error() == simdjson::SUCCESS) { @@ -216,9 +227,8 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, std::string_view TSV; auto Err = Doc["tensor-split"].get().get(TSV); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the tensor-split option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the tensor-split option."sv) } std::string TS(TSV); std::replace(TS.begin(), TS.end(), ',', ' '); @@ -231,10 +241,10 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } size_t NDevices = llama_max_devices(); if (GraphRef.TensorSplit.size() > NDevices) { - spdlog::error( - "[WASI-NN] GGML backend: Number of Tensor-Split is larger " - "than MaxDevices, please reduce the size of tensor-split."sv); - return ErrNo::InvalidArgument; + RET_ERROR( + ErrNo::InvalidArgument, + "Number of Tensor-Split is larger than MaxDevices, please reduce "sv + "the size of tensor-split."sv) } for (size_t Idx = GraphRef.TensorSplit.size(); Idx < NDevices; Idx++) { GraphRef.TensorSplit.push_back(0.0f); @@ -243,17 +253,15 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, if (Doc.at_key("use-mmap").error() == simdjson::SUCCESS) { auto Err = Doc["use-mmap"].get().get(GraphRef.UseMMap); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the use-mmap option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the use-mmap option."sv) } } if (Doc.at_key("warmup").error() == simdjson::SUCCESS) { auto Err = Doc["warmup"].get().get(GraphRef.WarmUp); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the warmup option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the warmup option."sv) } } if (Doc.at_key("split-mode").error() == simdjson::SUCCESS) { @@ -281,33 +289,29 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { auto Err = Doc["ctx-size"].get().get(GraphRef.CtxSize); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the ctx-size option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ctx-size option."sv) } } if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { auto Err = Doc["batch-size"].get().get(GraphRef.BatchSize); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the batch-size option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the batch-size option."sv) } } if (Doc.at_key("ubatch-size").error() == simdjson::SUCCESS) { auto Err = Doc["ubatch-size"].get().get(GraphRef.UBatchSize); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the ubatch-size option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ubatch-size option."sv) } } if (Doc.at_key("threads").error() == simdjson::SUCCESS) { auto Err = Doc["threads"].get().get(GraphRef.Threads); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the threads option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the threads option."sv) } } @@ -315,53 +319,46 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, if (Doc.at_key("temp").error() == simdjson::SUCCESS) { auto Err = Doc["temp"].get().get(GraphRef.Temp); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the temp option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the temp option."sv) } GraphRef.Temp = std::max(0.0, GraphRef.Temp); } if (Doc.at_key("top-p").error() == simdjson::SUCCESS) { auto Err = Doc["top-p"].get().get(GraphRef.TopP); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the top-p option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the top-p option."sv) } } if (Doc.at_key("repeat-penalty").error() == simdjson::SUCCESS) { auto Err = Doc["repeat-penalty"].get().get(GraphRef.RepeatPenalty); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the repeat-penalty option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the repeat-penalty option."sv) } } if (Doc.at_key("presence-penalty").error() == simdjson::SUCCESS) { auto Err = Doc["presence-penalty"].get().get(GraphRef.PresencePenalty); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the presence-penalty option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the presence-penalty option."sv) } } if (Doc.at_key("frequency-penalty").error() == simdjson::SUCCESS) { auto Err = Doc["frequency-penalty"].get().get(GraphRef.FrequencyPenalty); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the frequency-penalty option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the frequency-penalty option."sv) } } if (Doc.at_key("grammar").error() == simdjson::SUCCESS) { std::string_view Grammar; auto Err = Doc["grammar"].get().get(Grammar); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the grammar option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the grammar option."sv) } GraphRef.Grammar = Grammar; } @@ -369,9 +366,8 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, std::string_view JsonSchema; auto Err = Doc["json-schema"].get().get(JsonSchema); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the json-schema option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the json-schema option."sv) } GraphRef.Grammar = json_schema_to_grammar(nlohmann::ordered_json::parse(JsonSchema)); @@ -398,15 +394,13 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::Success; } -Expect buildOutputMetadata(Context &CxtRef, - std::string &Metadata) noexcept { +void buildOutputMetadata(Context &CxtRef, std::string &Metadata) noexcept { Metadata = fmt::format(R"({{"input_tokens": {}, )" R"("output_tokens": {}, )" R"("llama_build_number": {}, )" R"("llama_commit": "{}"}})"sv, CxtRef.LlamaNInputs, CxtRef.LlamaOutputTokens.size(), LLAMA_BUILD_NUMBER, LLAMA_COMMIT); - return ErrNo::Success; } void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, @@ -481,9 +475,7 @@ static bool evaluateQwen2vlImageEmbed( nullptr, // logits }; if (llama_decode(CtxLlama, Batch)) { - spdlog::error( - "[WASI-NN] GGML backend: evaluateQwen2vlImageEmbed failed to eval"sv); - return false; + RET_ERROR(false, "evaluateQwen2vlImageEmbed: fail to eval."sv) } *NPast += NEval; Processed += NEval; @@ -498,11 +490,11 @@ ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, // End the inference if the context is full. if (NPast + static_cast(Tokens.size()) > NCtx) { - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: the context if full ({} / {} tokens). Please increase your context size."sv, - NPast + static_cast(Tokens.size()), NCtx); - } + LOG_INFO( + GraphRef.EnableLog, + "the context if full ({} / {} tokens). Please increase your context "sv + "size."sv, + NPast + static_cast(Tokens.size()), NCtx) return ErrNo::ContextFull; } @@ -529,13 +521,13 @@ ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, // Decode the batch. auto Status = llama_decode(LlamaContext, Batch); if (Status == 1) { - spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); - return ErrNo::RuntimeError; + RET_ERROR(ErrNo::RuntimeError, + "failed to llama_decode: try reducing the size of the batch "sv + "or increasing the size of context."sv) } else if (Status < 0) { - spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); - return ErrNo::RuntimeError; + RET_ERROR(ErrNo::RuntimeError, + "failed to llama_decode: internal fatal error. Please open "sv + "an issue on GitHub."sv) } NPast += NEval; NPos += NEval; @@ -563,13 +555,13 @@ ErrNo batchDecode(llama_context *LlamaContext, llama_batch &Batch, // Decode the batch. auto Status = llama_decode(LlamaContext, Batch); if (Status == 1) { - spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: try reducing the size of the batch or increasing the size of context"sv); - return ErrNo::RuntimeError; + RET_ERROR(ErrNo::RuntimeError, + "failed to llama_decode: try reducing the size of the batch or "sv + "increasing the size of context."sv) } else if (Status < 0) { - spdlog::error( - "[WASI-NN] GGML backend: failed to llama_decode: internal fatal error. Please open an issue on GitHub"sv); - return ErrNo::RuntimeError; + RET_ERROR(ErrNo::RuntimeError, + "failed to llama_decode: internal fatal error. Please open an "sv + "issue on GitHub."sv) } for (int I = 0; I < Batch.n_tokens; I++) { @@ -601,31 +593,23 @@ Expect getEmbedding(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: getEmbedding"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding"sv) if (CxtRef.LlamaInputs.size() == 0) { - spdlog::error("[WASI-NN] GGML backend: Llama input is not set!"sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, "Llama input is not set!"sv) } // Clear the outputs. - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: clear the previous output and tokens"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "getEmbedding: clear the previous output and tokens"sv) CxtRef.LlamaOutputs.clear(); CxtRef.LlamaOutputTokens.clear(); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: clear the previous output and tokens...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "getEmbedding: clear the previous output and tokens...Done"sv) + + // Main prediction loop. + LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding: enter embedding loop"sv) - // Main predict loop. - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: handle embedding"sv); - } // Clear the llama context. llama_kv_cache_clear(GraphRef.LlamaContext.get()); @@ -636,20 +620,19 @@ Expect getEmbedding(WasiNNEnvironment &Env, // Add SEP if not present. if (CxtRef.LlamaInputs.back() != llama_token_sep(GraphRef.LlamaModel.get())) { - spdlog::warn( - "[WASI-NN] GGML backend: last token in the prompt is not SEP, 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header"sv); + LOG_WARN( + "last token in the prompt is not SEP, "sv + "'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF "sv + "header."sv) } // Check if the input is too long. if (static_cast(CxtRef.LlamaInputs.size()) > GraphRef.BatchSize) { - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: the prompt is too long. " - "Your input has {} tokens exceeds batch size {}. " - "Please reduce the input size or increase your batch-size."sv, - CxtRef.LlamaInputs.size(), GraphRef.BatchSize); - } - return ErrNo::PromptTooLong; + RET_ERROR( + ErrNo::PromptTooLong, + "the prompt is too long. Your input has {} tokens exceeds batch "sv + "size {}. Please reduce the input size or increase your batch-size."sv, + CxtRef.LlamaInputs.size(), GraphRef.BatchSize) } const int32_t NEmbd = llama_n_embd(GraphRef.LlamaModel.get()); @@ -662,16 +645,12 @@ Expect getEmbedding(WasiNNEnvironment &Env, ReturnCode = batchDecode(GraphRef.LlamaContext.get(), Batch, Embeddings.data(), NEmbd, GraphRef.EmbdNormalize); if (ReturnCode != ErrNo::Success) { - spdlog::error("[WASI-NN] GGML backend: failed to evaluate input tokens."sv); - return ReturnCode; + RET_ERROR(ReturnCode, "failed to evaluate input tokens."sv) } - buildOutputEmbedding(CxtRef.LlamaOutputs, NEmbd, Embeddings.data()); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: enter embedding loop...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "getEmbedding: enter embedding loop...Done"sv) if (GraphRef.EnableLog) { common_perf_print(GraphRef.LlamaContext.get(), /* Sampler */ nullptr); @@ -682,17 +661,14 @@ Expect getEmbedding(WasiNNEnvironment &Env, llama_kv_cache_clear(GraphRef.LlamaContext.get()); llama_batch_free(Batch); - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: getEmbedding...Done"sv); - } - + LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding...Done"sv) return ErrNo::Success; } const std::string_view Base64ImageTagPrefix = ""sv; -const std::string_view PromptImagePlaceholder = ""sv; +const std::string_view LlavaPromptImagePlaceholder = ""sv; bool containsBase64Image(Graph &GraphRef, std::string_view Prompt) noexcept { // Check if the prompt contains a base64 image. @@ -701,82 +677,74 @@ bool containsBase64Image(Graph &GraphRef, std::string_view Prompt) noexcept { auto Base64ImageTagBeginPos = Prompt.find(Base64ImageTagPrefix); if (Base64ImageTagBeginPos == std::string::npos) { - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: No base64 image tag found in the prompt."sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "No base64 image tag found in the prompt."sv) return false; } auto Base64ImageTagEndPos = Prompt.find(Base64ImageTagSuffix, Base64ImageTagBeginPos); if (Base64ImageTagEndPos == std::string::npos) { - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: Found an unclosed base64 image tag."sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "Found an unclosed base64 image tag."sv) return false; } return true; } -struct llava_image_embed * -loadBase64ImageFromPrompt(Graph &GraphRef, clip_ctx *ClipContext, - std::string_view Prompt) noexcept { - // Load the base64 image from the prompt. - // Follow this link for the supported image formats: - // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h - - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: loadBase64ImageFromPrompt"sv); - } - +std::string_view findBase64ImagePayload(std::string_view Prompt) noexcept { // Find `` auto Base64ImageTagEndPos = Prompt.find(Base64ImageTagSuffix, Base64ImageBytesBeginPos); if (Base64ImageTagEndPos == std::string::npos) { - return nullptr; + return Prompt.substr(); } - auto Base64Str = - Prompt.substr(Base64ImageBytesBeginPos + Base64ImageBytesPrefix.size(), - Base64ImageTagEndPos - Base64ImageBytesBeginPos - - Base64ImageBytesPrefix.size()); + return Prompt.substr(Base64ImageBytesBeginPos + Base64ImageBytesPrefix.size(), + Base64ImageTagEndPos - Base64ImageBytesBeginPos - + Base64ImageBytesPrefix.size()); +} + +struct llava_image_embed * +llavaLoadBase64ImageFromPrompt(Graph &GraphRef, clip_ctx *ClipContext, + std::string_view Prompt) noexcept { + // Load the base64 image from the prompt. + // Follow this link for the supported image formats: + // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h + LOG_DEBUG(GraphRef.EnableDebugLog, "llavaLoadBase64ImageFromPrompt"sv) // Decode the base64 image. + auto Base64Str = findBase64ImagePayload(Prompt); + if (Base64Str.size() == 0) { + return nullptr; + } auto RequiredBytes = base64::required_encode_size(Base64Str.size()); auto ImageBytes = std::vector(RequiredBytes); try { base64::decode(Base64Str.begin(), Base64Str.end(), ImageBytes.begin()); } catch (const base64_error &E) { - spdlog::error("[WASI-NN] GGML backend: Error when base64::decode: {}"sv, - E.what()); - return nullptr; - } - - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: loadBase64ImageFromPrompt...Done"sv); + RET_ERROR(nullptr, "Error when base64::decode: {}"sv, E.what()) } + LOG_DEBUG(GraphRef.EnableDebugLog, "llavaLoadBase64ImageFromPrompt...Done"sv) return llava_image_embed_make_with_bytes( ClipContext, static_cast(GraphRef.Threads), ImageBytes.data(), static_cast(ImageBytes.size())); } -ErrNo replaceBase64ImagePlaceholderInPrompt(std::string &Prompt) noexcept { +ErrNo replaceBase64ImagePlaceholderInPrompt( + std::string &Prompt, const std::string_view Placeholder) noexcept { // Replace the base64 image in the prompt with a placeholder. // Find `(Weight.data()), Weight.size()); if (BinModel.substr(0, 8) == "preload:"sv) { GraphRef.ModelFilePath = BinModel.substr(8); } else { - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: Model path not found in nn-preload, " - "write model into a tmpfile."sv); - } - // TODO: pass the model directly to ggml + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: Model path not found in nn-preload, write model into "sv + "a tmpfile."sv) + // TODO: pass the model directly to ggml. // Write ggml model to file. GraphRef.ModelFilePath = "ggml-model.bin"sv; std::ofstream TempFile(GraphRef.ModelFilePath, std::ios::out | std::ios::binary); if (!TempFile) { - spdlog::error( - "[WASI-NN] GGML backend: Failed to create the temporary file. " - "Currently, our workaround involves creating a temporary model " - "file named \"ggml-model.bin\" and passing this filename as a " - "parameter to the ggml llama library."sv); Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Failed to create the temporary file. Currently, our "sv + "workaround involves creating a temporary model file named "sv + "\"ggml-model.bin\" and passing this filename as a "sv + "parameter to the ggml llama library."sv) } TempFile.write(BinModel.data(), BinModel.size()); TempFile.close(); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: Write model into a tmpfile...Done"sv); - } - } - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: Finished handling model path."sv); + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: Write model into a tmpfile...Done"sv) } + LOG_DEBUG(GraphRef.EnableDebugLog, "load: handling model path...Done"sv) + // Check if the model exists. if (!std::filesystem::exists( std::filesystem::u8path(GraphRef.ModelFilePath))) { - spdlog::error("[WASI-NN] GGML backend: Model file not found."sv); Env.NNGraph.pop_back(); - return ErrNo::ModelNotFound; + RET_ERROR(ErrNo::ModelNotFound, "model file not found."sv) } - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: Initialize ggml model with given parameters"sv); - } // Initialize ggml parameters. + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: initialize ggml model with given parameters."sv) common_params Params; setupParams(GraphRef, Params); llama_backend_init(); @@ -921,41 +872,32 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.LlamaModel = std::move(LlamaInit.model); GraphRef.LlamaContext = std::move(LlamaInit.context); if (GraphRef.LlamaModel == nullptr) { - spdlog::error("[WASI-NN] GGML backend: Error: unable to init model."sv); Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, "unable to init model."sv) } if (GraphRef.LlamaContext == nullptr) { - spdlog::error("[WASI-NN] GGML backend: Error: unable to init context."sv); Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; - } - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: Initialize ggml model with given parameters...Done"sv); + RET_ERROR(ErrNo::InvalidArgument, "unable to init context."sv) } + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: initialize ggml model with given parameters...Done"sv) // Store the loaded graph. GraphId = static_cast(Env.NNGraph.size() - 1); + LOG_DEBUG(GraphRef.EnableDebugLog, "load...Done"sv) return ErrNo::Success; } Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { auto &GraphRef = Env.NNGraph[GraphId].get(); - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: initExecCtx"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx"sv) Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); ContextId = static_cast(Env.NNContext.size() - 1); - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: llama_system_info: {}"sv, - llama_print_system_info()); - } - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: initExecCtx...Done"sv); - } + LOG_INFO(GraphRef.EnableLog, "llama_system_info: {}"sv, + llama_print_system_info()) + LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx...Done"sv) return ErrNo::Success; } @@ -963,46 +905,38 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t Index, const TensorData &Tensor) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: setInput"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput"sv) + // Use index 1 for metadata. bool IsModelParamsUpdated = false; bool IsContextParamsUpdated = false; - // Use index 1 for metadata. if (Index == 1) { - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: found Metadata, processing"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: found Metadata, processing"sv) const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); auto Res = parseMetadata(GraphRef, Metadata, &IsModelParamsUpdated, &IsContextParamsUpdated); - if (Res != ErrNo::Success) { - spdlog::error("[WASI-NN] GGML backend: Failed to parse metadata."sv); - return Res; + RET_ERROR(Res, "failed to parse metadata."sv) } #ifndef __APPLE__ - // XXX: Due to the limitation of WASI-NN proposal, - // this is a workaround for non-macOS devices. - // However, if the model params is updated in Config stage, - // then, we doesn't encourage to use this to avoid the model + // XXX: Due to the limitation of WASI-NN proposal, this is a workaround for + // non-macOS devices. However, if the model params is updated in Config + // stage, then, we don't encourage to use this to avoid the model // reloading. { if (IsModelParamsUpdated) { + LOG_INFO(GraphRef.EnableLog, + "Reloaded model due to parameters change."sv) llama_model_params ModelParams = llama_model_default_params(); ModelParams.n_gpu_layers = static_cast(GraphRef.NGPULayers); GraphRef.LlamaModel.reset(); GraphRef.LlamaModel = llama_model_ptr(llama_load_model_from_file( GraphRef.ModelFilePath.c_str(), ModelParams)); if (GraphRef.LlamaModel == nullptr) { - spdlog::error( - "[WASI-NN] GGML backend: Error: unable to init model."sv); Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, "unable to init model."sv) } } } @@ -1011,10 +945,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Some changes of context parameters will require the context to be // reloaded. if (IsContextParamsUpdated) { - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: Reloaded model due to parameters change."sv); - } + LOG_INFO(GraphRef.EnableLog, + "Reloaded context due to parameters change."sv) GraphRef.LlamaContext.reset(); common_params Params; setupParams(GraphRef, Params); @@ -1023,113 +955,79 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, GraphRef.LlamaModel.get(), common_context_params_to_llama(Params)))); if (GraphRef.LlamaContext == nullptr) { - spdlog::error( - "[WASI-NN] GGML backend: Error: unable to init context."sv); Env.NNGraph.pop_back(); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, "unable to init context."sv) } } - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: found Metadata, processing...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: found Metadata, processing...Done"sv) return ErrNo::Success; } // Clear the llama context. - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: clear llama context"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context"sv) llama_kv_cache_clear(GraphRef.LlamaContext.get()); - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: clear llama context...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context...Done"sv) // Set the input. - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: set the input"sv); - } const bool AddSpecial = true; const bool ParseSpecial = true; std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); CxtRef.LlamaInputs.clear(); - if (GraphRef.MMProjModelPath == ""sv) { - // Text only prompt. - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: tokenize text prompt"sv); - } - CxtRef.LlamaInputs = common_tokenize(GraphRef.LlamaContext.get(), Prompt, - AddSpecial, ParseSpecial); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: tokenize text prompt...Done"sv); - } - } else { + if (GraphRef.MMProjModelPath != ""sv) { // Handle llava format prompt. - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: handle llava format prompt"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: handle llava format prompt"sv) // Check if the prompt contains a base64 image. bool ContainsBase64Image = containsBase64Image(GraphRef, Prompt); if (GraphRef.ImagePath == ""sv && ContainsBase64Image == false) { - spdlog::error( - "[WASI-NN] GGML backend: Error: when using llava model, " - "you need to specify the image path or have the base64 encoded " - "image in the prompt."sv); - return ErrNo::InvalidArgument; + RET_ERROR( + ErrNo::InvalidArgument, + "when using llava model, you need to specify the image path or "sv + "have the base64 encoded image in the prompt."sv) } // Show some warnings. - if (GraphRef.EnableLog) { - if (GraphRef.CtxSize < 4096) { - spdlog::info( - "[WASI-NN] GGML backend: Context size is {}, " - "we recommend context size >= 2048 when using llava-v1.5 " - "and context size >= 4096 when using llava-v1.6 for better results."sv, - GraphRef.CtxSize); - } + if (GraphRef.CtxSize < 4096) { + LOG_INFO( + GraphRef.EnableLog, + "Context size is {}, we recommend context size >= 2048 when using "sv + "llava-v1.5 and context size >= 4096 when using llava-v1.6 for "sv + "better results."sv, + GraphRef.CtxSize) } // Load the clip model if not loaded. if (GraphRef.ClipContext == nullptr) { - spdlog::info( - "[WASI-NN] GGML backend: Load the clip model. " - "Because llama.cpp disabled the GPU support for CLIP, " - "the step of loading images in CLIP can only use the CPU, " - "which may result in reduced efficiency. " - "(You can refer to PR https://github.com/ggerganov/llama.cpp/pull/10896)"sv); + LOG_INFO( + true, + "Load the clip model. Because llama.cpp disabled the GPU support "sv + "for CLIP, the step of loading images in CLIP can only use the "sv + "CPU, which may result in reduced efficiency. (You can refer to "sv + "PR https://github.com/ggerganov/llama.cpp/pull/10896)"sv) GraphRef.ClipContext = clip_model_load(GraphRef.MMProjModelPath.c_str(), GraphRef.EnableLog ? 1 : 0); if (GraphRef.ClipContext == nullptr) { - spdlog::error( - "[WASI-NN] GGML backend: Error: unable to load the clip model."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, "unable to load the clip model."sv) } if (clip_is_qwen2vl(GraphRef.ClipContext)) { GraphRef.VisionModelType = VisionModel::Qwen2VL; - spdlog::info( - "[WASI-NN] GGML backend: The Qwen2VL clip model is loaded."sv); - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: Qwen2vl model detected."sv); - } + LOG_INFO(true, "Qwen2vl model loaded."sv) } } // Get image embed. if (ContainsBase64Image) { // Load the base64 image from the prompt. - CxtRef.LlavaImageEmbd = - loadBase64ImageFromPrompt(GraphRef, GraphRef.ClipContext, Prompt); + CxtRef.LlavaImageEmbd = llavaLoadBase64ImageFromPrompt( + GraphRef, GraphRef.ClipContext, Prompt); // Replace the base64 image in the prompt with a placeholder. - auto Res = replaceBase64ImagePlaceholderInPrompt(Prompt); + auto Res = replaceBase64ImagePlaceholderInPrompt( + Prompt, LlavaPromptImagePlaceholder); if (Res != ErrNo::Success) { - spdlog::error( - "[WASI-NN] GGML backend: Error: unable to replace the base64 image in the prompt."sv); clip_free(GraphRef.ClipContext); - return Res; + RET_ERROR(Res, "unable to replace the base64 image in the prompt."sv) } } else { // Load the image from the file. @@ -1138,21 +1036,18 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, GraphRef.ImagePath.c_str()); } if (CxtRef.LlavaImageEmbd == nullptr) { - spdlog::error( - "[WASI-NN] GGML backend: Error: unable to load the image."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, "unable to load the image."sv) } // We split prompt by as placeholder and save the position. - auto PlaceholderPosition = Prompt.find(PromptImagePlaceholder); + auto PlaceholderPosition = Prompt.find(LlavaPromptImagePlaceholder); if (PlaceholderPosition == std::string::npos) { - spdlog::error( - "[WASI-NN] GGML backend: Error: unable to find the placeholder in the llava prompt."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "unable to find the placeholder in the llava prompt."sv) } std::string PromptBeforeImage = Prompt.substr(0, PlaceholderPosition); - std::string PromptAfterImage = - Prompt.substr(PlaceholderPosition + PromptImagePlaceholder.length()); + std::string PromptAfterImage = Prompt.substr( + PlaceholderPosition + LlavaPromptImagePlaceholder.length()); std::vector EmbdInputBeforeImage = common_tokenize(GraphRef.LlamaContext.get(), PromptBeforeImage, AddSpecial, ParseSpecial); @@ -1160,7 +1055,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // tokens after the image. std::vector EmbdInputAfterImage = common_tokenize( GraphRef.LlamaContext.get(), PromptAfterImage, false, ParseSpecial); - CxtRef.LlavaImagePosition = EmbdInputBeforeImage.size(); + CxtRef.ImagePosition = EmbdInputBeforeImage.size(); CxtRef.LlamaInputs.reserve(EmbdInputBeforeImage.size() + EmbdInputAfterImage.size()); CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.end(), @@ -1169,17 +1064,20 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.end(), EmbdInputAfterImage.begin(), EmbdInputAfterImage.end()); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: handle llava format prompt...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: handle llava format prompt...Done"sv) + } else { + // Text only prompt. + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt"sv) + CxtRef.LlamaInputs = common_tokenize(GraphRef.LlamaContext.get(), Prompt, + AddSpecial, ParseSpecial); + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: tokenize text prompt...Done"sv) } CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); GraphRef.ComputeSingleStarted = false; - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: setInput...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput...Done"sv) return ErrNo::Success; } @@ -1188,67 +1086,45 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: getOutput with Index {}"sv, - Index); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput with Index {}"sv, Index) // Index 1 is for the metadata of the outputs. if (Index == 1) { std::string Metadata; - auto Res = buildOutputMetadata(CxtRef, Metadata); - if (Res != ErrNo::Success) { - spdlog::error( - "[WASI-NN] GGML backend: Failed to build output metadata."sv); - return Res; - } + buildOutputMetadata(CxtRef, Metadata); std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); BytesWritten = static_cast(Metadata.length()); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: getOutput with Index {}...Done"sv, - Index); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput with Index {}...Done"sv, + Index) return ErrNo::Success; } std::copy_n(CxtRef.LlamaOutputs.data(), CxtRef.LlamaOutputs.length(), OutBuffer.data()); BytesWritten = static_cast(CxtRef.LlamaOutputs.length()); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: getOutput with Index {}...Done"sv, - Index); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput with Index {}...Done"sv, Index) return ErrNo::Success; } Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: compute"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "compute") if (GraphRef.Embedding) { return getEmbedding(Env, ContextId); } if (CxtRef.LlamaInputs.size() == 0) { - spdlog::error("[WASI-NN] GGML backend: Llama input is not set!"sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, "llama input is not set!"sv) } // Clear the outputs. - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: clear the previous output and tokens"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: clear the previous output and tokens") CxtRef.LlamaOutputs.clear(); CxtRef.LlamaOutputTokens.clear(); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: clear the previous output and tokens...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: clear the previous output and tokens...Done") // Clear the llama context. llama_kv_cache_clear(GraphRef.LlamaContext.get()); @@ -1272,38 +1148,26 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Check if the input is too long. if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: the prompt is too long. Your input " - "has {} tokens. Please reduce it to {} tokens."sv, - CxtRef.LlamaInputs.size(), MaxTokensListSize); - } - return ErrNo::PromptTooLong; + RET_ERROR( + ErrNo::PromptTooLong, + "the prompt is too long. Your input has {} tokens. Please reduce it "sv + "to {} tokens."sv, + CxtRef.LlamaInputs.size(), MaxTokensListSize) } // Evaluate input tokens. - if (CxtRef.LlavaImageEmbd == nullptr) { - // Text only prompt. - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), - std::move(CxtRef.LlamaInputs), NPast, NPos); - if (ReturnCode != ErrNo::Success) { - spdlog::error( - "[WASI-NN] GGML backend: failed to evaluate input tokens."sv); - return ReturnCode; - } - } else { + if (CxtRef.LlavaImageEmbd != nullptr) { // Llava format prompt with image data. - std::vector EmbdInputBeforeImage( - CxtRef.LlamaInputs.begin(), - CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition); + std::vector EmbdInputBeforeImage(CxtRef.LlamaInputs.begin(), + CxtRef.LlamaInputs.begin() + + CxtRef.ImagePosition); std::vector EmbdInputAfterImage(CxtRef.LlamaInputs.begin() + - CxtRef.LlavaImagePosition, + CxtRef.ImagePosition, CxtRef.LlamaInputs.end()); ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), std::move(EmbdInputBeforeImage), NPast, NPos); if (ReturnCode != ErrNo::Success) { - spdlog::error( - "[WASI-NN] GGML backend: failed to evaluate input tokens before image."sv); - return ReturnCode; + RET_ERROR(ReturnCode, "failed to evaluate input tokens before image."sv) } bool EvalImageStatus = false; @@ -1322,23 +1186,24 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } if (!EvalImageStatus) { - spdlog::error( - "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); - return ErrNo::RuntimeError; + RET_ERROR(ErrNo::RuntimeError, "failed to evaluate embed image tokens."sv) } ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), std::move(EmbdInputAfterImage), NPast, NPos); if (ReturnCode != ErrNo::Success) { - spdlog::error( - "[WASI-NN] GGML backend: failed to evaluate input tokens after image."sv); - return ReturnCode; + RET_ERROR(ReturnCode, "failed to evaluate input tokens after image."sv) + } + } else { + // Text only prompt. + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), + std::move(CxtRef.LlamaInputs), NPast, NPos); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, "failed to evaluate input tokens."sv) } } - // Main predict loop. - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: enter main predict loop"sv); - } + // Main prediction loop. + LOG_DEBUG(GraphRef.EnableDebugLog, "compute: enter main prediction loop"sv) while (NRemain > 0) { // Use idx = -1 to sample the next token. const llama_token Id = common_sampler_sample( @@ -1359,17 +1224,13 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Break if reverse prompt is found. if (!GraphRef.ReversePrompt.empty() && CxtRef.LlamaOutputs.find(GraphRef.ReversePrompt) != std::string::npos) { - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: reverse prompt found"sv); - } + LOG_INFO(GraphRef.EnableLog, "reverse prompt found."sv) break; } // Deal with end of text token. if (llama_token_is_eog(GraphRef.LlamaModel.get(), common_sampler_last(Sampler))) { - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); - } + LOG_INFO(GraphRef.EnableLog, "EOS token found."sv) break; } // Evaluate the output token. @@ -1379,10 +1240,8 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { break; } } - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: enter main predict loop...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: enter main prediction loop...Done"sv) // End of main predict loop. if (GraphRef.EnableLog) { @@ -1391,24 +1250,16 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // We free the contexts here to keep the ggml plugin stateless. // Users could fully control the contexts by themselves via their prompt. - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: delete llama sampler to make it stateless"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: delete llama sampler to make it stateless"sv) common_sampler_free(Sampler); if (CxtRef.LlavaImageEmbd != nullptr) { llava_image_embed_free(CxtRef.LlavaImageEmbd); CxtRef.LlavaImageEmbd = nullptr; } - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: delete llama sampler to make it stateless...Done"sv); - } - - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: compute...Done"sv); - } - + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: delete llama sampler to make it stateless...Done"sv) + LOG_DEBUG(GraphRef.EnableDebugLog, "compute...Done"sv) return ReturnCode; } @@ -1417,38 +1268,24 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: getOutputSingle with Index {}"sv, - Index); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle with Index {}"sv, Index) + // Index 1 is for the metadata of the outputs. if (Index == 1) { std::string Metadata; - auto Res = buildOutputMetadata(CxtRef, Metadata); - if (Res != ErrNo::Success) { - spdlog::error( - "[WASI-NN] GGML backend: Failed to build output metadata."sv); - return Res; - } + buildOutputMetadata(CxtRef, Metadata); std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); BytesWritten = static_cast(Metadata.length()); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: getOutputSingle with Index {}...Done"sv, - Index); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle with Index {}...Done"sv, + Index) return ErrNo::Success; } std::string LastToken = common_token_to_piece( GraphRef.LlamaContext.get(), CxtRef.LlamaOutputTokens.back()); std::copy_n(LastToken.data(), LastToken.length(), OutBuffer.data()); BytesWritten = static_cast(LastToken.length()); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: getOutputSingle with Index {}...Done"sv, - Index); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle with Index {}...Done"sv, + Index) return ErrNo::Success; } @@ -1456,32 +1293,23 @@ Expect computeSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - - // Logging. - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: computeSingleToken"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle"sv) // New compute single token context. if (!GraphRef.ComputeSingleStarted) { GraphRef.ComputeSingleStarted = true; // Check if the input is set before setting up the context. if (CxtRef.LlamaInputs.size() == 0) { - spdlog::error("[WASI-NN] GGML backend: Llama input is not set!"sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, "llama input is not set!"sv) } // Clear the outputs. - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: clear the previous output and tokens"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "computeSingle: clear the previous output and tokens"sv) CxtRef.LlamaOutputs.clear(); CxtRef.LlamaOutputTokens.clear(); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: clear the previous output and tokens...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "computeSingle: clear the previous output and tokens...Done"sv) // Clear the llama context. llama_kv_cache_clear(GraphRef.LlamaContext.get()); @@ -1503,64 +1331,55 @@ Expect computeSingle(WasiNNEnvironment &Env, // Check if the input is too long. if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { - if (GraphRef.EnableLog) { - spdlog::info( - "[WASI-NN] GGML backend: the prompt is too long. Your input has {} tokens. Please reduce it to {} tokens."sv, - CxtRef.LlamaInputs.size(), MaxTokensListSize); - } - return ErrNo::PromptTooLong; + RET_ERROR( + ErrNo::PromptTooLong, + "the prompt is too long. Your input has {} tokens. Please reduce "sv + "it to {} tokens."sv, + CxtRef.LlamaInputs.size(), MaxTokensListSize) } // Evaluate input tokens. - if (CxtRef.LlavaImageEmbd == nullptr) { - // Text only prompt. - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), - std::move(CxtRef.LlamaInputs), - CxtRef.LlamaNPast, CxtRef.LlamaNPos); - if (ReturnCode != ErrNo::Success) { - spdlog::error( - "[WASI-NN] GGML backend: failed to evaluate input tokens."sv); - return ReturnCode; - } - } else { + if (CxtRef.LlavaImageEmbd != nullptr) { // Llava format prompt with image data. - std::vector EmbdInputBeforeImage( - CxtRef.LlamaInputs.begin(), - CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition); - std::vector EmbdInputAfterImage( - CxtRef.LlamaInputs.begin() + CxtRef.LlavaImagePosition, - CxtRef.LlamaInputs.end()); + std::vector EmbdInputBeforeImage(CxtRef.LlamaInputs.begin(), + CxtRef.LlamaInputs.begin() + + CxtRef.ImagePosition); + std::vector EmbdInputAfterImage(CxtRef.LlamaInputs.begin() + + CxtRef.ImagePosition, + CxtRef.LlamaInputs.end()); ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), std::move(EmbdInputBeforeImage), CxtRef.LlamaNPast, CxtRef.LlamaNPos); if (ReturnCode != ErrNo::Success) { - spdlog::error( - "[WASI-NN] GGML backend: failed to evaluate input tokens before image."sv); - return ReturnCode; + RET_ERROR(ReturnCode, "failed to evaluate input tokens before image."sv) } bool EvalImageStatus = llava_eval_image_embed( GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, static_cast(GraphRef.BatchSize), &CxtRef.LlamaNPast); if (!EvalImageStatus) { - spdlog::error( - "[WASI-NN] GGML backend: failed to evaluate embed image tokens."sv); - return ErrNo::RuntimeError; + RET_ERROR(ErrNo::RuntimeError, + "failed to evaluate embed image tokens."sv) } ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), std::move(EmbdInputAfterImage), CxtRef.LlamaNPast, CxtRef.LlamaNPos); if (ReturnCode != ErrNo::Success) { - spdlog::error( - "[WASI-NN] GGML backend: failed to evaluate input tokens after image."sv); - return ReturnCode; + RET_ERROR(ReturnCode, "failed to evaluate input tokens after image."sv) + } + } else { + // Text only prompt. + ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), + std::move(CxtRef.LlamaInputs), + CxtRef.LlamaNPast, CxtRef.LlamaNPos); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, "failed to evaluate input tokens."sv) } } } - // Main predict process. - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: enter main predict process"sv); - } + // Main prediction process. + LOG_DEBUG(GraphRef.EnableDebugLog, + "computeSingle: enter main prediction process"sv) auto ReturnCode = ErrNo::Success; // Use idx = -1 to sample the next token. const llama_token Id = common_sampler_sample( @@ -1575,35 +1394,25 @@ Expect computeSingle(WasiNNEnvironment &Env, if (llama_token_is_eog(GraphRef.LlamaModel.get(), common_sampler_last(CxtRef.LlamaSampler))) { ReturnCode = ErrNo::EndOfSequence; - if (GraphRef.EnableLog) { - spdlog::info("[WASI-NN] GGML backend: EOS token found"sv); - } + LOG_INFO(GraphRef.EnableLog, "EOS token found."sv) } // Evaluate the output token if not EOS. if (ReturnCode != ErrNo::EndOfSequence) { ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), {Id}, CxtRef.LlamaNPast, CxtRef.LlamaNPos); } - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: enter main predict process...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "computeSingle: enter main prediction process...Done"sv) // End of main predict process. - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: computeSingleToken...Done"sv); - } - + LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle...Done"sv) return ReturnCode; } Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: finiSingle"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle"sv) // Logging for the llama timings. if (GraphRef.EnableLog) { @@ -1611,22 +1420,15 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { } // Clear the outputs. - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: finiSingle: clear the previous output and tokens"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "finiSingle: clear the previous output and tokens"sv) CxtRef.LlamaOutputs.clear(); CxtRef.LlamaOutputTokens.clear(); - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: finiSingle: clear the previous output and tokens...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "finiSingle: clear the previous output and tokens...Done"sv) // Clear the llama context. - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: finiSingle: clear the llama context"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle: clear the llama context"sv) llama_kv_cache_clear(GraphRef.LlamaContext.get()); common_sampler_reset(CxtRef.LlamaSampler); common_sampler_free(CxtRef.LlamaSampler); @@ -1639,42 +1441,28 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { llava_image_embed_free(CxtRef.LlavaImageEmbd); CxtRef.LlavaImageEmbd = nullptr; } - if (GraphRef.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: finiSingle: free the llama context...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, + "finiSingle: clear the llama context...Done"sv) // Reset the context variables. CxtRef.LlamaNPast = 0; - if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: finiSingle...Done"sv); - } - + LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle...Done"sv) return ErrNo::Success; } Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { auto &GraphRef = Env.NNGraph[GraphId].get(); - const bool IsDebugLog = GraphRef.EnableDebugLog; - if (IsDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: unload"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "unload"sv) + if (GraphRef.LlamaModel != nullptr) { - if (IsDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: unload: free llama model"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "unload: free llama model"sv) GraphRef.LlamaModel.reset(); - if (IsDebugLog) { - spdlog::info( - "[WASI-NN][Debug] GGML backend: unload: free llama model...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "unload: free llama model...Done"sv) } Env.NNGraph.erase(Env.NNGraph.begin() + GraphId); Env.mdRemoveById(GraphId); - if (IsDebugLog) { - spdlog::info("[WASI-NN][Debug] GGML backend: unload...Done"sv); - } + LOG_DEBUG(GraphRef.EnableDebugLog, "unload...Done"sv) return ErrNo::Success; } diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index e082710c..82aefca5 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -91,7 +91,7 @@ struct Context { int32_t LlamaNPos = 0; // Preserve for llava struct llava_image_embed *LlavaImageEmbd = nullptr; - size_t LlavaImagePosition = 0; + size_t ImagePosition = 0; }; #else struct Graph {}; From 5f81a0bfe70e24ae90281092ad7ac3de923c9151 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 23 Dec 2024 22:59:13 +0800 Subject: [PATCH 517/623] [WASI-NN] Add the graph and context management mechanism. Signed-off-by: YiYing He --- plugins/wasi_nn/wasinnenv.h | 188 +++++++++++++++++++++++++++++++----- 1 file changed, 166 insertions(+), 22 deletions(-) diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 4fea1411..3aec8817 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -21,6 +21,8 @@ #include #include +#include +#include #include #ifdef WASMEDGE_BUILD_WASI_NN_RPC @@ -64,17 +66,9 @@ class Graph { public: Graph() = delete; Graph(Backend BE) noexcept : Impl(std::in_place_type_t()) { - switch (BE) { -#define EACH(B) \ - case Backend::B: \ - Impl.emplace(); \ - break; - FOR_EACH_BACKEND(EACH) -#undef EACH - default: - __builtin_unreachable(); - } + init(BE); } + Backend getBackend() const noexcept { using V = std::decay_t; switch (Impl.index()) { @@ -87,6 +81,7 @@ class Graph { __builtin_unreachable(); } } + template auto &get() noexcept { return *std::get_if>(&Impl); } @@ -97,29 +92,65 @@ class Graph { template const auto &get() const noexcept { return *std::get_if(&Impl); } + + void init(Backend BE) noexcept { + switch (BE) { +#define EACH(B) \ + case Backend::B: \ + Impl.emplace(); \ + break; + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + __builtin_unreachable(); + } + Stat = Status::Uninitialized; + CtxCnt = 0; + } + void reset() noexcept { + Impl = std::monostate{}; + Stat = Status::Uninitialized; + CtxCnt = 0; + } + void increaseContext() noexcept { CtxCnt++; } + void decreaseContext() noexcept { + assuming(CtxCnt > 0); + CtxCnt--; + } + uint32_t getContextCount() const noexcept { return CtxCnt; } + bool isFinalized() const noexcept { + return Stat == Status::Uninitialized || Stat == Status::Finalized; + } + bool isReady() const noexcept { return Stat == Status::Ready; } + void setInvalid() noexcept { Stat = Status::Invalid; } + void setFinalized() noexcept { Stat = Status::Finalized; } + void setReady() noexcept { Stat = Status::Ready; } + +private: std::variant< #define EACH(B) B::Graph, FOR_EACH_BACKEND(EACH) #undef EACH std::monostate> Impl; + // Graph status. + // Uninitialized: A new graph in monostate. + // Invalid: The graph loaded failed in set_input with metadata. Can be + // reload with a new metadata in set_input. + // Finalized: The graph being deleted, but there are contexts linked. This + // graph ID will be released once the contexts are deleted. + // Ready: This graph can be used to create a context. + enum class Status : uint8_t { Uninitialized, Invalid, Finalized, Ready }; + Status Stat; + uint32_t CtxCnt; }; class Context { public: Context() = delete; - Context(size_t GId, Graph &G) noexcept + Context(uint32_t GId, Graph &G) noexcept : Impl(std::in_place_type_t()) { - switch (G.getBackend()) { -#define EACH(B) \ - case Backend::B: \ - Impl.emplace(GId, G.get()); \ - break; - FOR_EACH_BACKEND(EACH) -#undef EACH - default: - __builtin_unreachable(); - } + init(GId, G); } Backend getBackend() const noexcept { @@ -145,12 +176,45 @@ class Context { template const auto &get() const noexcept { return *std::get_if(&Impl); } + + void init(uint32_t GId, Graph &G) noexcept { + switch (G.getBackend()) { +#define EACH(B) \ + case Backend::B: \ + Impl.emplace(GId, G.get()); \ + break; + FOR_EACH_BACKEND(EACH) +#undef EACH + default: + __builtin_unreachable(); + } + Stat = Status::Uninitialized; + GraphId = GId; + } + void reset() noexcept { + Impl = std::monostate{}; + Stat = Status::Uninitialized; + GraphId = 0; + } + uint32_t getGraphId() const noexcept { + return static_cast(GraphId); + } + bool isReady() const noexcept { return Stat == Status::Ready; } + void setReady() noexcept { Stat = Status::Ready; } + +private: std::variant< #define EACH(B) B::Context, FOR_EACH_BACKEND(EACH) #undef EACH std::monostate> Impl; + // Context status. + // Uninitialized: A new context in monostate. + // Ready: This context can be used to infer. + enum class Status : uint8_t { Uninitialized, Ready }; + Status Stat; + uint32_t GraphId; }; struct WasiNNEnvironment : @@ -211,13 +275,93 @@ struct WasiNNEnvironment : return WASINN::ErrNo::NotFound; } - mutable std::shared_mutex MdMutex; ///< Protect MdMap + uint32_t newGraph(Backend BE) noexcept { + std::unique_lock Lock(MdMutex); + uint32_t ID = static_cast(NNGraph.size()); + if (NNGraphRecycle.empty()) { + NNGraph.emplace_back(BE); + } else { + ID = *NNGraphRecycle.begin(); + NNGraph[ID].init(BE); + NNGraphRecycle.erase(ID); + } + return ID; + } + + uint32_t newContext(uint32_t GId, Graph &G) noexcept { + std::unique_lock Lock(MdMutex); + assuming(NNGraph.size() > GId); + // TODO: Merge GId into graph class. + uint32_t ID = static_cast(NNContext.size()); + if (NNContextRecycle.empty()) { + NNContext.emplace_back(GId, G); + } else { + ID = *NNContextRecycle.begin(); + NNContext[ID].init(GId, G); + NNContextRecycle.erase(ID); + } + G.increaseContext(); + return ID; + } + + void deleteGraph(const uint32_t Id) noexcept { + // TODO: Add the deallocation callback. + std::unique_lock Lock(MdMutex); + if (Id < NNGraph.size()) { + auto &G = NNGraph[Id]; + G.setFinalized(); + if (G.getContextCount() == 0) { + // Checked all contexts are deleted. Release the graph id. + if (Id == NNGraph.size() - 1) { + NNGraph.pop_back(); + } else { + G.reset(); + NNGraphRecycle.insert(Id); + } + } + } + } + + void deleteContext(const uint32_t Id) noexcept { + // TODO: Add the deallocation callback. + std::unique_lock Lock(MdMutex); + if (Id < NNContext.size() && + NNContextRecycle.find(Id) == NNContextRecycle.end()) { + auto GId = NNContext[Id].getGraphId(); + auto &G = NNGraph[GId]; + G.decreaseContext(); + if (G.getContextCount() == 0 && G.isFinalized()) { + // Checked all contexts are deleted. Release the graph id. + if (GId == NNGraph.size() - 1) { + NNGraph.pop_back(); + } else { + G.reset(); + NNGraphRecycle.insert(GId); + } + } + if (Id == NNContext.size() - 1) { + NNContext.pop_back(); + } else { + NNContext[Id].reset(); + NNContextRecycle.insert(Id); + } + } + } + + // Md storage + mutable std::shared_mutex MdMutex; std::unordered_map>, Backend, Device>> RawMdMap; std::unordered_map MdMap; + + // Graph and context + std::unordered_set NNGraphRecycle; std::vector NNGraph; + std::unordered_set NNContextRecycle; std::vector NNContext; + + // Preload model list static PO::List NNModels; #ifdef WASMEDGE_BUILD_WASI_NN_RPC static PO::Option NNRPCURI; // For RPC client mode From 86382d550cd803b69c95f710d8db08bee39cab35 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 23 Dec 2024 23:10:17 +0800 Subject: [PATCH 518/623] [WASI-NN] Apply the graph and context management mechanism for all backends. Signed-off-by: YiYing He --- plugins/wasi_nn/wasinn_chattts.cpp | 24 ++-- plugins/wasi_nn/wasinn_chattts.h | 6 +- plugins/wasi_nn/wasinn_ggml.cpp | 80 ++++++++---- plugins/wasi_nn/wasinn_ggml.h | 8 +- plugins/wasi_nn/wasinn_mlx.cpp | 42 ++++--- plugins/wasi_nn/wasinn_mlx.h | 6 +- plugins/wasi_nn/wasinn_neuralspeed.cpp | 3 - plugins/wasi_nn/wasinn_neuralspeed.h | 5 +- plugins/wasi_nn/wasinn_onnx.h | 2 +- plugins/wasi_nn/wasinn_openvino.cpp | 13 +- plugins/wasi_nn/wasinn_openvino.h | 6 +- plugins/wasi_nn/wasinn_piper.cpp | 20 +-- plugins/wasi_nn/wasinn_piper.h | 6 +- plugins/wasi_nn/wasinn_tf.h | 2 +- plugins/wasi_nn/wasinn_tfl.cpp | 18 +-- plugins/wasi_nn/wasinn_tfl.h | 6 +- plugins/wasi_nn/wasinn_torch.cpp | 12 +- plugins/wasi_nn/wasinn_torch.h | 6 +- plugins/wasi_nn/wasinn_whisper.cpp | 39 ++++-- plugins/wasi_nn/wasinn_whisper.h | 6 +- plugins/wasi_nn/wasinnenv.h | 9 +- plugins/wasi_nn/wasinnfunc.cpp | 167 ++++++++++++++++--------- test/plugins/wasi_nn/wasi_nn.cpp | 6 + 23 files changed, 301 insertions(+), 191 deletions(-) diff --git a/plugins/wasi_nn/wasinn_chattts.cpp b/plugins/wasi_nn/wasinn_chattts.cpp index 3cb47867..f15fdb75 100644 --- a/plugins/wasi_nn/wasinn_chattts.cpp +++ b/plugins/wasi_nn/wasinn_chattts.cpp @@ -27,8 +27,8 @@ Expect load(WASINN::WasiNNEnvironment &Env, Span>, WASINN::Device, uint32_t &GraphId) noexcept { // Add a new graph. - Env.NNGraph.emplace_back(Backend::ChatTTS); - auto &GraphRef = Env.NNGraph.back().get(); + uint32_t GId = Env.newGraph(Backend::ChatTTS); + auto &GraphRef = Env.NNGraph[GId].get(); // Initialize the plugin parameters. if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN] ChatTTS backend: Load."sv); @@ -47,7 +47,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (GraphRef.ChatTTSModule == nullptr) { spdlog::error( "[WASI-NN] ChatTTS backend: Can not find ChatTTS library."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return WASINN::ErrNo::RuntimeError; } } @@ -57,20 +57,20 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (ChatFunction == nullptr || !PyCallable_Check(ChatFunction)) { spdlog::error( "[WASI-NN] ChatTTS backend: Can not find Chat class in ChatTTS."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return WASINN::ErrNo::RuntimeError; } GraphRef.Chat = PyObject_CallObject(ChatFunction, nullptr); Py_XDECREF(ChatFunction); if (GraphRef.Chat == nullptr) { spdlog::error("[WASI-NN] ChatTTS backend: Can not create chat."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return WASINN::ErrNo::RuntimeError; } PyObject *LoadMethod = PyObject_GetAttrString(GraphRef.Chat, "load"); if (LoadMethod == nullptr || !PyCallable_Check(LoadMethod)) { spdlog::error("[WASI-NN] ChatTTS backend: Can not load chat."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return WASINN::ErrNo::RuntimeError; } PyObject *Value = PyObject_CallObject(LoadMethod, nullptr); @@ -78,7 +78,8 @@ Expect load(WASINN::WasiNNEnvironment &Env, Py_XDECREF(LoadMethod); } // Store the loaded graph. - GraphId = Env.NNGraph.size() - 1; + GraphId = GId; + Env.NNGraph[GId].setReady(); return WASINN::ErrNo::Success; } @@ -90,8 +91,8 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, "[WASI-NN] ChatTTS backend: Model has been released, please reload it."sv); return WASINN::ErrNo::RuntimeError; } - Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); - ContextId = Env.NNContext.size() - 1; + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + Env.NNContext[ContextId].setReady(); return ErrNo::Success; } @@ -124,7 +125,6 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, auto ParseError = Parser.parse(Metadata).get(Doc); if (ParseError) { spdlog::error("[WASI-NN] ChatTTS backend: Parse metadata error"sv); - Env.NNGraph.pop_back(); return ErrNo::InvalidEncoding; } GIL Lock; @@ -242,6 +242,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, } return WASINN::ErrNo::InvalidArgument; } + Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { if (!Py_IsInitialized()) { @@ -314,7 +315,7 @@ Expect unload(WASINN::WasiNNEnvironment &Env, uint32_t GraphId) noexcept { auto &GraphRef = Env.NNGraph[GraphId].get(); if (GraphRef.EnableDebugLog) { - spdlog::info("[WASI-NN] Neural speed backend: start unload."sv); + spdlog::info("[WASI-NN] ChatTTS backend: start unload."sv); } if (Py_IsInitialized()) { GIL Lock; @@ -327,6 +328,7 @@ Expect unload(WASINN::WasiNNEnvironment &Env, GraphRef.Chat = nullptr; GraphRef.ChatTTSModule = nullptr; } + Env.deleteGraph(GraphId); return WASINN::ErrNo::Success; } diff --git a/plugins/wasi_nn/wasinn_chattts.h b/plugins/wasi_nn/wasinn_chattts.h index c59010db..8adf8429 100644 --- a/plugins/wasi_nn/wasinn_chattts.h +++ b/plugins/wasi_nn/wasinn_chattts.h @@ -50,15 +50,15 @@ struct Graph { PyObject *ParamsInferCode = nullptr; }; struct Context { - Context(size_t Gid, Graph &) noexcept : GraphId(Gid) {} - size_t GraphId; + Context(uint32_t Gid, Graph &) noexcept : GraphId(Gid) {} + uint32_t GraphId; std::string Inputs; std::vector Outputs; }; #else struct Graph {}; struct Context { - Context(size_t, Graph &) noexcept {} + Context(uint32_t, Graph &) noexcept {} }; #endif diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 5239a01c..1dceb7dd 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -771,11 +771,11 @@ ErrNo replaceBase64ImagePlaceholderInPrompt( Expect load(WasiNNEnvironment &Env, Span> Builders, [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { // Add a new graph. - Env.NNGraph.emplace_back(Backend::GGML); - auto &GraphRef = Env.NNGraph.back().get(); + uint32_t GId = Env.newGraph(Backend::GGML); + auto &GraphRef = Env.NNGraph[GId].get(); // Initialize the plugin parameters. - auto ContextDefault = llama_context_default_params(); + llama_context_params ContextDefault = llama_context_default_params(); const common_params ParamsDefault; GraphRef.EnableLog = false; GraphRef.EnableDebugLog = false; @@ -813,10 +813,12 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Ignore context or model updates when initializing the graph. auto Res = parseMetadata(GraphRef, Metadata); if (Res != ErrNo::Success) { - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); RET_ERROR(Res, "Failed to parse metadata."sv) } } + + // Logging. LOG_DEBUG(GraphRef.EnableDebugLog, "load"sv) LOG_INFO(GraphRef.EnableLog, "LLAMA_COMMIT {}"sv, LLAMA_COMMIT) LOG_INFO(GraphRef.EnableLog, "LLAMA_BUILD_NUMBER {}"sv, LLAMA_BUILD_NUMBER) @@ -838,7 +840,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, std::ofstream TempFile(GraphRef.ModelFilePath, std::ios::out | std::ios::binary); if (!TempFile) { - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); RET_ERROR(ErrNo::InvalidArgument, "Failed to create the temporary file. Currently, our "sv "workaround involves creating a temporary model file named "sv @@ -855,7 +857,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Check if the model exists. if (!std::filesystem::exists( std::filesystem::u8path(GraphRef.ModelFilePath))) { - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); RET_ERROR(ErrNo::ModelNotFound, "model file not found."sv) } @@ -872,18 +874,19 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.LlamaModel = std::move(LlamaInit.model); GraphRef.LlamaContext = std::move(LlamaInit.context); if (GraphRef.LlamaModel == nullptr) { - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); RET_ERROR(ErrNo::InvalidArgument, "unable to init model."sv) } if (GraphRef.LlamaContext == nullptr) { - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); RET_ERROR(ErrNo::InvalidArgument, "unable to init context."sv) } LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize ggml model with given parameters...Done"sv) // Store the loaded graph. - GraphId = static_cast(Env.NNGraph.size() - 1); + GraphId = GId; + Env.NNGraph[GId].setReady(); LOG_DEBUG(GraphRef.EnableDebugLog, "load...Done"sv) return ErrNo::Success; @@ -893,10 +896,10 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { auto &GraphRef = Env.NNGraph[GraphId].get(); LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx"sv) - Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); - ContextId = static_cast(Env.NNContext.size() - 1); + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); LOG_INFO(GraphRef.EnableLog, "llama_system_info: {}"sv, llama_print_system_info()) + Env.NNContext[ContextId].setReady(); LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx...Done"sv) return ErrNo::Success; } @@ -926,7 +929,10 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // stage, then, we don't encourage to use this to avoid the model // reloading. { - if (IsModelParamsUpdated) { + if (IsModelParamsUpdated || GraphRef.LlamaModel == nullptr) { + // The llama model may be nullptr if set_input with updated model params + // last time. Therefore besides the model params updated, we should + // reload the llama model if the model is nullptr. LOG_INFO(GraphRef.EnableLog, "Reloaded model due to parameters change."sv) llama_model_params ModelParams = llama_model_default_params(); @@ -935,7 +941,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, GraphRef.LlamaModel = llama_model_ptr(llama_load_model_from_file( GraphRef.ModelFilePath.c_str(), ModelParams)); if (GraphRef.LlamaModel == nullptr) { - Env.NNGraph.pop_back(); + Env.NNGraph[CxtRef.GraphId].setInvalid(); RET_ERROR(ErrNo::InvalidArgument, "unable to init model."sv) } } @@ -944,27 +950,32 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Some changes of context parameters will require the context to be // reloaded. - if (IsContextParamsUpdated) { + if (IsContextParamsUpdated || GraphRef.LlamaContext == nullptr) { LOG_INFO(GraphRef.EnableLog, - "Reloaded context due to parameters change."sv) + "Reloaded llama context due to parameters change."sv) GraphRef.LlamaContext.reset(); common_params Params; setupParams(GraphRef, Params); - GraphRef.LlamaContext = - llama_context_ptr(llama_context_ptr(llama_new_context_with_model( - GraphRef.LlamaModel.get(), - common_context_params_to_llama(Params)))); + GraphRef.LlamaContext = llama_context_ptr(llama_new_context_with_model( + GraphRef.LlamaModel.get(), common_context_params_to_llama(Params))); if (GraphRef.LlamaContext == nullptr) { - Env.NNGraph.pop_back(); + Env.NNGraph[CxtRef.GraphId].setInvalid(); RET_ERROR(ErrNo::InvalidArgument, "unable to init context."sv) } } + Env.NNGraph[CxtRef.GraphId].setReady(); LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: found Metadata, processing...Done"sv) return ErrNo::Success; } + if (!Env.NNGraph[CxtRef.GraphId].isReady()) { + RET_ERROR(ErrNo::InvalidArgument, + "Graph is invalid. Please reload again by passing metadata "sv + "in set_input or unload graph."sv) + } + // Clear the llama context. LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context"sv) llama_kv_cache_clear(GraphRef.LlamaContext.get()); @@ -1453,16 +1464,32 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { auto &GraphRef = Env.NNGraph[GraphId].get(); - LOG_DEBUG(GraphRef.EnableDebugLog, "unload"sv) + const bool IsDebugLog = GraphRef.EnableDebugLog; + LOG_DEBUG(IsDebugLog, "unload"sv) if (GraphRef.LlamaModel != nullptr) { - LOG_DEBUG(GraphRef.EnableDebugLog, "unload: free llama model"sv) + LOG_DEBUG(IsDebugLog, "unload: free llama model"sv) GraphRef.LlamaModel.reset(); - LOG_DEBUG(GraphRef.EnableDebugLog, "unload: free llama model...Done"sv) + LOG_DEBUG(IsDebugLog, "unload: free llama model...Done"sv) } - Env.NNGraph.erase(Env.NNGraph.begin() + GraphId); + if (GraphRef.LlamaContext != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free llama context"sv) + GraphRef.LlamaContext.reset(); + LOG_DEBUG(IsDebugLog, "unload: free llama context...Done"sv) + } + Env.deleteGraph(GraphId); Env.mdRemoveById(GraphId); - LOG_DEBUG(GraphRef.EnableDebugLog, "unload...Done"sv) + LOG_DEBUG(IsDebugLog, "unload...Done"sv) + return ErrNo::Success; +} + +Expect finalizeExecCtx(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "finalize_execution_context"sv) + Env.deleteContext(ContextId); + LOG_DEBUG(GraphRef.EnableDebugLog, "finalize_execution_context...Done"sv) return ErrNo::Success; } @@ -1506,6 +1533,9 @@ Expect finiSingle(WasiNNEnvironment &, uint32_t) noexcept { Expect unload(WasiNNEnvironment &, uint32_t) noexcept { return reportBackendNotSupported(); } +Expect finalizeExecCtx(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} #endif } // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 82aefca5..abdbba09 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -79,8 +79,8 @@ struct Graph { struct Context { public: - Context(size_t GId, Graph &) noexcept : GraphId(GId) {} - size_t GraphId; + Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} + uint32_t GraphId; std::vector LlamaInputs; uint64_t LlamaNInputs = 0; std::string LlamaOutputs; @@ -96,7 +96,7 @@ struct Context { #else struct Graph {}; struct Context { - Context(size_t, Graph &) noexcept {} + Context(uint32_t, Graph &) noexcept {} }; #endif @@ -127,4 +127,6 @@ Expect finiSingle(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; Expect unload(WASINN::WasiNNEnvironment &Env, uint32_t GraphId) noexcept; +Expect finalizeExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; } // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/wasinn_mlx.cpp b/plugins/wasi_nn/wasinn_mlx.cpp index 7f4cb443..3f600228 100644 --- a/plugins/wasi_nn/wasinn_mlx.cpp +++ b/plugins/wasi_nn/wasinn_mlx.cpp @@ -52,8 +52,8 @@ Expect load(WASINN::WasiNNEnvironment &Env, Span> Builders, WASINN::Device, uint32_t &GraphId) noexcept { // Add a new graph. - Env.NNGraph.emplace_back(Backend::MLX); - auto &GraphRef = Env.NNGraph.back().get(); + uint32_t GId = Env.newGraph(Backend::MLX); + auto &GraphRef = Env.NNGraph[GId].get(); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN] MLX backend: Load."sv); } @@ -62,7 +62,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (Builders.size() <= 1) { spdlog::error( "[WASI-NN] MLX backend: Lack model weight or required metadata (tokenizer, model_type)."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } const std::string Metadata = std::string( @@ -72,7 +72,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, auto ParseError = Parser.parse(Metadata).get(Doc); if (ParseError) { spdlog::error("[WASI-NN] MLX backend: Parse metadata error"sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidEncoding; } if (Doc.at_key("model_type").error() == simdjson::SUCCESS) { @@ -81,14 +81,14 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (Err) { spdlog::error( "[WASI-NN] MLX backend: Unable to retrieve the model_type option."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } GraphRef.ModelType = ModelType; } else { spdlog::error( "[WASI-NN] MLX backend: Unable to retrieve the model_type option."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } if (Doc.at_key("enable_debug_log").error() == simdjson::SUCCESS) { @@ -97,7 +97,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (Err) { spdlog::error( "[WASI-NN] MLX backend: Unable to retrieve the enable_debug_log option."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } GraphRef.EnableDebugLog = EnableDebugLog; @@ -108,14 +108,14 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (Err) { spdlog::error( "[WASI-NN] MLX backend: Unable to retrieve the tokenizer option."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } TokenizerPath = TokenizerPathView; } else { spdlog::error( "[WASI-NN] MLX backend: Unable to retrieve the tokenizer option."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } if (Doc.at_key("max_token").error() == simdjson::SUCCESS) { @@ -124,7 +124,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (Err) { spdlog::error( "[WASI-NN] MLX backend: Unable to retrieve the max_token option."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } GraphRef.MaxToken = MaxToken; @@ -141,7 +141,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (ErrQBits || ErrGroupSize || ErrIsQuantized) { spdlog::error( "[WASI-NN] MLX backend: Unable to retrieve the q_bits or group_size option."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } GraphRef.IsQuantized = IsQuantized; @@ -154,13 +154,13 @@ Expect load(WASINN::WasiNNEnvironment &Env, auto Bytes = loadBytesFromFile(TokenizerPath); if (Bytes.empty()) { spdlog::error("[WASI-NN] MLX backend: Load tokenizer failed."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } GraphRef.Tok = tokenizers::Tokenizer::FromBlobJSON(Bytes); } else { spdlog::error("[WASI-NN] MLX backend: Tokenizer path not found."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } @@ -176,7 +176,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.Prmopt = LLaMA2Prompt(); } else { spdlog::error("[WASI-NN] MLX backend: Model type not supported."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } @@ -191,7 +191,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, Weight.size()); spdlog::info("[WASI-NN] MLX BinModel: {}"sv, BinModel.size()); if (BinModel.size() == 0) { - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } std::string ModelFilePath; @@ -210,7 +210,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (!TempFile) { spdlog::error( "[WASI-NN] MLX backend: Failed to create the temporary file. "sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } TempFile.write(BinModel.data(), BinModel.size()); @@ -229,7 +229,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); } else { spdlog::error("[WASI-NN] MLX backend: Model type not supported."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } } @@ -238,14 +238,15 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.Model->toQuantized(GraphRef.GroupSize, GraphRef.QBits); } - GraphId = Env.NNGraph.size() - 1; + GraphId = GId; + Env.NNGraph[GId].setReady(); return WASINN::ErrNo::Success; } Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { - Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); - ContextId = Env.NNContext.size() - 1; + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + Env.NNContext[ContextId].setReady(); return ErrNo::Success; } @@ -276,6 +277,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, BytesWritten = StringTmp.length(); return WASINN::ErrNo::Success; } + Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); diff --git a/plugins/wasi_nn/wasinn_mlx.h b/plugins/wasi_nn/wasinn_mlx.h index f487af26..e2cd9b4e 100644 --- a/plugins/wasi_nn/wasinn_mlx.h +++ b/plugins/wasi_nn/wasinn_mlx.h @@ -37,15 +37,15 @@ struct Graph { BasePrompt Prmopt; }; struct Context { - Context(size_t Gid, Graph &) noexcept : GraphId(Gid) {} - size_t GraphId; + Context(uint32_t Gid, Graph &) noexcept : GraphId(Gid) {} + uint32_t GraphId; std::string Inputs; std::string Outputs; }; #else struct Graph {}; struct Context { - Context(size_t, Graph &) noexcept {} + Context(uint32_t, Graph &) noexcept {} }; #endif diff --git a/plugins/wasi_nn/wasinn_neuralspeed.cpp b/plugins/wasi_nn/wasinn_neuralspeed.cpp index 9937930d..cc36ea09 100644 --- a/plugins/wasi_nn/wasinn_neuralspeed.cpp +++ b/plugins/wasi_nn/wasinn_neuralspeed.cpp @@ -34,7 +34,4 @@ Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { return reportBackendNotSupported(); } -Expect unload(WASINN::WasiNNEnvironment &, uint32_t) noexcept { - return reportBackendNotSupported(); -} } // namespace WasmEdge::Host::WASINN::NeuralSpeed diff --git a/plugins/wasi_nn/wasinn_neuralspeed.h b/plugins/wasi_nn/wasinn_neuralspeed.h index 0f53b675..974f3dc8 100644 --- a/plugins/wasi_nn/wasinn_neuralspeed.h +++ b/plugins/wasi_nn/wasinn_neuralspeed.h @@ -14,7 +14,7 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::NeuralSpeed { struct Graph {}; struct Context { - Context(size_t, Graph &) noexcept {} + Context(uint32_t, Graph &) noexcept {} }; struct Environ {}; @@ -34,7 +34,4 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, uint32_t &BytesWritten) noexcept; Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; -Expect unload(WASINN::WasiNNEnvironment &Env, - uint32_t GraphId) noexcept; - } // namespace WasmEdge::Host::WASINN::NeuralSpeed diff --git a/plugins/wasi_nn/wasinn_onnx.h b/plugins/wasi_nn/wasinn_onnx.h index 7e5343d8..b546e1d9 100644 --- a/plugins/wasi_nn/wasinn_onnx.h +++ b/plugins/wasi_nn/wasinn_onnx.h @@ -14,7 +14,7 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::ONNX { struct Graph {}; struct Context { - Context(size_t, Graph &) noexcept {} + Context(uint32_t, Graph &) noexcept {} }; struct Environ {}; diff --git a/plugins/wasi_nn/wasinn_openvino.cpp b/plugins/wasi_nn/wasinn_openvino.cpp index 91a01893..744bf7ed 100644 --- a/plugins/wasi_nn/wasinn_openvino.cpp +++ b/plugins/wasi_nn/wasinn_openvino.cpp @@ -25,8 +25,8 @@ Expect load(WASINN::WasiNNEnvironment &Env, auto Weight = Builders[1]; // Add a new graph. - Env.NNGraph.emplace_back(Backend::OpenVINO); - auto &GraphRef = Env.NNGraph.back().get(); + uint32_t GId = Env.newGraph(Backend::OpenVINO); + auto &GraphRef = Env.NNGraph[GId].get(); // Store device information GraphRef.TargetDevice = Device; @@ -42,11 +42,12 @@ Expect load(WASINN::WasiNNEnvironment &Env, ModelString, GraphRef.OpenVINOIWeightTensor); } catch (const std::exception &EX) { spdlog::error("[WASI-NN] Model Load Exception: {}", EX.what()); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return WASINN::ErrNo::RuntimeError; } // Store the loaded graph. - GraphId = Env.NNGraph.size() - 1; + GraphId = GId; + Env.NNGraph[GId].setReady(); return WASINN::ErrNo::Success; } @@ -60,8 +61,8 @@ Expect initExecCtx(WASINN::WasiNNEnvironment &Env, return WASINN::ErrNo::MissingMemory; } // Create context. - Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); - ContextId = Env.NNContext.size() - 1; + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + Env.NNContext[ContextId].setReady(); return WASINN::ErrNo::Success; } diff --git a/plugins/wasi_nn/wasinn_openvino.h b/plugins/wasi_nn/wasinn_openvino.h index 57691b7c..f7472ae9 100644 --- a/plugins/wasi_nn/wasinn_openvino.h +++ b/plugins/wasi_nn/wasinn_openvino.h @@ -25,9 +25,9 @@ struct Graph { }; struct Context { - Context(size_t GId, Graph &) noexcept : GraphId(GId) {} + Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} ~Context() noexcept {} - size_t GraphId; + uint32_t GraphId; ov::InferRequest OpenVINOInferRequest; }; @@ -39,7 +39,7 @@ struct Environ { #else struct Graph {}; struct Context { - Context(size_t, Graph &) noexcept {} + Context(uint32_t, Graph &) noexcept {} }; struct Environ {}; #endif diff --git a/plugins/wasi_nn/wasinn_piper.cpp b/plugins/wasi_nn/wasinn_piper.cpp index 6bfe55d0..c9bf9e0b 100644 --- a/plugins/wasi_nn/wasinn_piper.cpp +++ b/plugins/wasi_nn/wasinn_piper.cpp @@ -273,12 +273,13 @@ Expect load(WASINN::WasiNNEnvironment &Env, } // Add a new graph. - auto &GraphRef = Env.NNGraph.emplace_back(Backend::Piper).get(); + uint32_t GId = Env.newGraph(Backend::Piper); + auto &GraphRef = Env.NNGraph[GId].get(); GraphRef.Config = std::make_unique(); auto String = std::string{Builders[0].begin(), Builders[0].end()}; if (auto Res = parseRunConfig(*GraphRef.Config, String); Res != WASINN::ErrNo::Success) { - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); spdlog::error("[WASI-NN] Piper backend: Failed to parse run config."sv); return Res; } @@ -294,13 +295,13 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (!GraphRef.Config->ESpeakDataPath) { spdlog::error( "[WASI-NN] Piper backend: espeak-ng data directory is required for eSpeakPhonemes"sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return WASINN::ErrNo::InvalidArgument; } if (!std::filesystem::exists(GraphRef.Config->ESpeakDataPath.value())) { spdlog::error( "[WASI-NN] Piper backend: espeak-ng data directory doesn't exist"sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return WASINN::ErrNo::InvalidArgument; } // User provided path @@ -316,13 +317,13 @@ Expect load(WASINN::WasiNNEnvironment &Env, if (!GraphRef.Config->TashkeelModelPath) { spdlog::error( "[WASI-NN] Piper backend: libtashkeel ort model is required for Arabic"sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return WASINN::ErrNo::InvalidArgument; } if (!std::filesystem::exists(GraphRef.Config->TashkeelModelPath.value())) { spdlog::error( "[WASI-NN] Piper backend: libtashkeel ort model doesn't exist"sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return WASINN::ErrNo::InvalidArgument; } GraphRef.PiperConfig->useTashkeel = true; @@ -349,7 +350,8 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.Voice->synthesisConfig.phonemeSilenceSeconds; // Store the loaded graph. - GraphId = Env.NNGraph.size() - 1; + GraphId = GId; + Env.NNGraph[GId].setReady(); return WASINN::ErrNo::Success; } @@ -357,8 +359,8 @@ Expect initExecCtx(WASINN::WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { // Create context. - Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); - ContextId = Env.NNContext.size() - 1; + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + Env.NNContext[ContextId].setReady(); return WASINN::ErrNo::Success; } diff --git a/plugins/wasi_nn/wasinn_piper.h b/plugins/wasi_nn/wasinn_piper.h index 5516b560..1dec1455 100644 --- a/plugins/wasi_nn/wasinn_piper.h +++ b/plugins/wasi_nn/wasinn_piper.h @@ -78,8 +78,8 @@ struct Graph { std::unique_ptr Voice; }; struct Context { - Context(size_t GId, Graph &) noexcept : GraphId(GId) {} - size_t GraphId; + Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} + uint32_t GraphId; std::optional Line; std::unique_ptr> JsonInputSynthesisConfig; std::optional> Output; @@ -87,7 +87,7 @@ struct Context { #else struct Graph {}; struct Context { - Context(size_t, Graph &) noexcept {} + Context(uint32_t, Graph &) noexcept {} }; #endif diff --git a/plugins/wasi_nn/wasinn_tf.h b/plugins/wasi_nn/wasinn_tf.h index 69e42941..c87329cd 100644 --- a/plugins/wasi_nn/wasinn_tf.h +++ b/plugins/wasi_nn/wasinn_tf.h @@ -14,7 +14,7 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::Tensorflow { struct Graph {}; struct Context { - Context(size_t, Graph &) noexcept {} + Context(uint32_t, Graph &) noexcept {} }; struct Environ {}; diff --git a/plugins/wasi_nn/wasinn_tfl.cpp b/plugins/wasi_nn/wasinn_tfl.cpp index 69406813..acdda8ee 100644 --- a/plugins/wasi_nn/wasinn_tfl.cpp +++ b/plugins/wasi_nn/wasinn_tfl.cpp @@ -24,8 +24,8 @@ Expect load(WASINN::WasiNNEnvironment &Env, return WASINN::ErrNo::InvalidArgument; } // Add a new graph. - Env.NNGraph.emplace_back(WASINN::Backend::TensorflowLite); - auto &GraphRef = Env.NNGraph.back().get(); + uint32_t GId = Env.newGraph(Backend::TensorflowLite); + auto &GraphRef = Env.NNGraph[GId].get(); // Copy graph builder data to TfLiteModData and create a new TfLiteModel. GraphRef.TfLiteModData.assign(Builders[0].begin(), Builders[0].end()); @@ -33,12 +33,13 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.TfLiteModData.size()); if (unlikely(GraphRef.TFLiteMod == nullptr)) { spdlog::error("[WASI-NN] Cannot import TFLite model"); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return WASINN::ErrNo::InvalidArgument; } // Store the loaded graph. - GraphId = Env.NNGraph.size() - 1; + GraphId = GId; + Env.NNGraph[GId].setReady(); return WASINN::ErrNo::Success; } @@ -52,8 +53,8 @@ Expect initExecCtx(WASINN::WasiNNEnvironment &Env, } // Create context. - Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); - auto &CxtRef = Env.NNContext.back().get(); + uint32_t CId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + auto &CxtRef = Env.NNContext[CId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); auto *TFLiteOps = TfLiteInterpreterOptionsCreate(); TfLiteInterpreterOptionsSetNumThreads(TFLiteOps, 2); @@ -61,12 +62,13 @@ Expect initExecCtx(WASINN::WasiNNEnvironment &Env, TfLiteInterpreterOptionsDelete(TFLiteOps); if (unlikely(CxtRef.TFLiteInterp == nullptr)) { spdlog::error("[WASI-NN] Cannot create TFLite interpreter."); - Env.NNContext.pop_back(); + Env.deleteContext(CId); return WASINN::ErrNo::Busy; } TfLiteInterpreterAllocateTensors(CxtRef.TFLiteInterp); - ContextId = Env.NNContext.size() - 1; + ContextId = CId; + Env.NNContext[ContextId].setReady(); return WASINN::ErrNo::Success; } diff --git a/plugins/wasi_nn/wasinn_tfl.h b/plugins/wasi_nn/wasinn_tfl.h index b640f203..e2a02419 100644 --- a/plugins/wasi_nn/wasinn_tfl.h +++ b/plugins/wasi_nn/wasinn_tfl.h @@ -31,19 +31,19 @@ struct Graph { struct Context { public: - Context(size_t GId, Graph &) noexcept : GraphId(GId) {} + Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} ~Context() noexcept { if (TFLiteInterp) { TfLiteInterpreterDelete(TFLiteInterp); } } - size_t GraphId; + uint32_t GraphId; TfLiteInterpreter *TFLiteInterp = nullptr; }; #else struct Graph {}; struct Context { - Context(size_t, Graph &) noexcept {} + Context(uint32_t, Graph &) noexcept {} }; #endif diff --git a/plugins/wasi_nn/wasinn_torch.cpp b/plugins/wasi_nn/wasinn_torch.cpp index d87a910f..95cfc03c 100644 --- a/plugins/wasi_nn/wasinn_torch.cpp +++ b/plugins/wasi_nn/wasinn_torch.cpp @@ -177,8 +177,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, auto Weight = Builders[0]; // Add a new graph. - Env.NNGraph.emplace_back(Backend::PyTorch); - auto &GraphRef = Env.NNGraph.back().get(); + uint32_t GId = Env.newGraph(Backend::PyTorch); + auto &GraphRef = Env.NNGraph[GId].get(); // Load the model from the binary data. // Note: Pytorch use try catch to handle the error. @@ -210,15 +210,15 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, return ErrNo::InvalidArgument; } - GraphId = Env.NNGraph.size() - 1; + GraphId = GId; + Env.NNGraph[GId].setReady(); return ErrNo::Success; } Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { - Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); - - ContextId = Env.NNContext.size() - 1; + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + Env.NNContext[ContextId].setReady(); return ErrNo::Success; } diff --git a/plugins/wasi_nn/wasinn_torch.h b/plugins/wasi_nn/wasinn_torch.h index 579ba354..8a9d8cde 100644 --- a/plugins/wasi_nn/wasinn_torch.h +++ b/plugins/wasi_nn/wasinn_torch.h @@ -80,15 +80,15 @@ struct Graph { struct Context { public: - Context(size_t GId, Graph &) noexcept : GraphId(GId) {} - size_t GraphId; + Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} + uint32_t GraphId; std::vector TorchInputs; std::vector TorchOutputs; }; #else struct Graph {}; struct Context { - Context(size_t, Graph &) noexcept {} + Context(uint32_t, Graph &) noexcept {} }; #endif diff --git a/plugins/wasi_nn/wasinn_whisper.cpp b/plugins/wasi_nn/wasinn_whisper.cpp index a23a8c8c..7bfefcd2 100644 --- a/plugins/wasi_nn/wasinn_whisper.cpp +++ b/plugins/wasi_nn/wasinn_whisper.cpp @@ -839,8 +839,8 @@ Expect handleTranslationConfig(whisper_context *WhisperCtx, Expect load(WasiNNEnvironment &Env, Span> Builders, [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { // Add a new graph. - Env.NNGraph.emplace_back(Backend::Whisper); - auto &GraphRef = Env.NNGraph.back().get(); + uint32_t GId = Env.newGraph(Backend::Whisper); + auto &GraphRef = Env.NNGraph[GId].get(); // Initialize the parameters. auto CParam = whisper_context_default_params(); @@ -860,7 +860,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, auto Res = parseMetadata(GraphRef.WhisperConfig, Metadata); if (Res != ErrNo::Success) { spdlog::error("[WASI-NN] Whisper backend: Failed to parse metadata."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return Res; } } @@ -893,7 +893,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, spdlog::error( "[WASI-NN] Whisper backend: Error: unable to init whisper context from " "model."sv); - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ErrNo::InvalidArgument; } if (GraphRef.WhisperConfig.EnableDebugLog) { @@ -905,12 +905,13 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, auto ResTranslateConfig = handleTranslationConfig(GraphRef.WhisperCtx, GraphRef.WhisperConfig); if (ResTranslateConfig != ErrNo::Success) { - Env.NNGraph.pop_back(); + Env.deleteGraph(GId); return ResTranslateConfig; } // Store the loaded graph. - GraphId = Env.NNGraph.size() - 1; + GraphId = GId; + Env.NNGraph[GId].setReady(); return ErrNo::Success; } @@ -921,8 +922,7 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, if (GraphRef.WhisperConfig.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: initExecCtx"sv); } - Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]); - ContextId = Env.NNContext.size() - 1; + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); auto &CxtRef = Env.NNContext[ContextId].get(); CxtRef.WhisperParams = whisper_full_default_params( whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH); @@ -931,6 +931,7 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, spdlog::info("[WASI-NN] Whisper backend: whisper_system_info: {}"sv, whisper_print_system_info()); } + Env.NNContext[ContextId].setReady(); if (GraphRef.WhisperConfig.EnableDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: initExecCtx...Done"sv); } @@ -1092,13 +1093,30 @@ Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { "[WASI-NN][Debug] Whisper backend: unload: free whisper context...Done"sv); } } - Env.NNGraph.erase(Env.NNGraph.begin() + GraphId); + Env.deleteGraph(GraphId); + Env.mdRemoveById(GraphId); if (IsDebugLog) { spdlog::info("[WASI-NN][Debug] Whisper backend: unload...Done"sv); } return ErrNo::Success; } +Expect finalizeExecCtx(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + const bool IsDebugLog = CxtRef.WhisperConfig.EnableDebugLog; + if (IsDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: finalize_execution_context"sv); + } + // TODO: Free resources + Env.deleteContext(ContextId); + if (IsDebugLog) { + spdlog::info( + "[WASI-NN][Debug] Whisper backend: finalize_execution_context...Done"sv); + } + return ErrNo::Success; +} #else namespace { @@ -1130,6 +1148,9 @@ Expect compute(WasiNNEnvironment &, uint32_t) noexcept { Expect unload(WasiNNEnvironment &, uint32_t) noexcept { return reportBackendNotSupported(); } +Expect finalizeExecCtx(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} #endif } // Namespace WasmEdge::Host::WASINN::Whisper diff --git a/plugins/wasi_nn/wasinn_whisper.h b/plugins/wasi_nn/wasinn_whisper.h index 7cbae0c9..2071ef29 100644 --- a/plugins/wasi_nn/wasinn_whisper.h +++ b/plugins/wasi_nn/wasinn_whisper.h @@ -74,9 +74,9 @@ struct Graph { struct Context { public: - Context(size_t GId, Graph &G) noexcept + Context(uint32_t GId, Graph &G) noexcept : GraphId(GId), WhisperConfig(G.WhisperConfig) {} - size_t GraphId; + uint32_t GraphId; // mono-channel F32 PCM input. std::vector InputPCM; std::vector> InputPCMs; @@ -90,7 +90,7 @@ struct Context { #else struct Graph {}; struct Context { - Context(size_t, Graph &) noexcept {} + Context(uint32_t, Graph &) noexcept {} }; #endif diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 3aec8817..8d2d575e 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -276,7 +276,7 @@ struct WasiNNEnvironment : } uint32_t newGraph(Backend BE) noexcept { - std::unique_lock Lock(MdMutex); + std::unique_lock Lock(GraphMutex); uint32_t ID = static_cast(NNGraph.size()); if (NNGraphRecycle.empty()) { NNGraph.emplace_back(BE); @@ -289,7 +289,7 @@ struct WasiNNEnvironment : } uint32_t newContext(uint32_t GId, Graph &G) noexcept { - std::unique_lock Lock(MdMutex); + std::unique_lock Lock(GraphMutex); assuming(NNGraph.size() > GId); // TODO: Merge GId into graph class. uint32_t ID = static_cast(NNContext.size()); @@ -306,7 +306,7 @@ struct WasiNNEnvironment : void deleteGraph(const uint32_t Id) noexcept { // TODO: Add the deallocation callback. - std::unique_lock Lock(MdMutex); + std::unique_lock Lock(GraphMutex); if (Id < NNGraph.size()) { auto &G = NNGraph[Id]; G.setFinalized(); @@ -324,7 +324,7 @@ struct WasiNNEnvironment : void deleteContext(const uint32_t Id) noexcept { // TODO: Add the deallocation callback. - std::unique_lock Lock(MdMutex); + std::unique_lock Lock(GraphMutex); if (Id < NNContext.size() && NNContextRecycle.find(Id) == NNContextRecycle.end()) { auto GId = NNContext[Id].getGraphId(); @@ -356,6 +356,7 @@ struct WasiNNEnvironment : std::unordered_map MdMap; // Graph and context + mutable std::shared_mutex GraphMutex; std::unordered_set NNGraphRecycle; std::vector NNGraph; std::unordered_set NNContextRecycle; diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index b9c9abd7..32c02c03 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -80,10 +80,10 @@ WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, case WASINN::Device::TPU: break; default: - spdlog::error("[WASI-NN] Unknown device {};"sv, Target); + spdlog::error("[WASI-NN] Unknown device {}."sv, Target); return WASINN::ErrNo::InvalidArgument; } - spdlog::debug("[WASI-NN] Using device: {}", Device); + spdlog::debug("[WASI-NN] Using device: {}.", Device); // Builders' Layout: // | builder-0 | builder-0 len | builder-1 | builder-1 len | ... @@ -258,9 +258,18 @@ WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, } #endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC - if (Env.NNGraph.size() <= GraphId) { + if (Env.NNGraph.size() <= GraphId || Env.NNGraph[GraphId].isFinalized()) { spdlog::error( - "[WASI-NN] init_execution_context: Graph Id does not exist."sv); + "[WASI-NN] init_execution_context: Graph ID {} does not exist or is "sv + "unloaded."sv, + GraphId); + return WASINN::ErrNo::InvalidArgument; + } + if (!Env.NNGraph[GraphId].isReady()) { + spdlog::error( + "[WASI-NN] init_execution_context: Graph ID {} is invalid. Please "sv + "reload or unload this graph."sv, + GraphId); return WASINN::ErrNo::InvalidArgument; } @@ -277,7 +286,7 @@ WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, } Expect -WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, +WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ContextId, uint32_t Index, uint32_t TensorPtr) { auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { @@ -333,7 +342,7 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, Env.NNRPCChannel); grpc::ClientContext ClientContext; wasi_ephemeral_nn::SetInputRequest Req; - Req.set_resource_handle(Context); + Req.set_resource_handle(ContextId); Req.set_index(Index); wasi_ephemeral_nn::Tensor RPCTensor; RPCTensor.mutable_dimensions()->Add(Tensor.Dimension.begin(), @@ -352,15 +361,17 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, } #endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC - if (Env.NNContext.size() <= Context) { - spdlog::error("[WASI-NN] set_input: Execution Context does not exist."sv); + if (Env.NNContext.size() <= ContextId || + !Env.NNContext[ContextId].isReady()) { + spdlog::error("[WASI-NN] set_input: Context ID {} does not exist."sv, + ContextId); return WASINN::ErrNo::InvalidArgument; } - switch (const auto Backend = Env.NNContext[Context].getBackend()) { + switch (const auto Backend = Env.NNContext[ContextId].getBackend()) { #define EACH(B) \ case WASINN::Backend::B: \ - return WASINN::B::setInput(Env, Context, Index, Tensor); + return WASINN::B::setInput(Env, ContextId, Index, Tensor); FOR_EACH_BACKEND(EACH) #undef EACH default: @@ -370,9 +381,10 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, } Expect -WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, - uint32_t Index, uint32_t OutBufferPtr, - uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { +WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ContextId, uint32_t Index, + uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, + uint32_t BytesWrittenPtr) { auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -397,7 +409,7 @@ WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, Env.NNRPCChannel); grpc::ClientContext ClientContext; wasi_ephemeral_nn::GetOutputRequest Req; - Req.set_resource_handle(Context); + Req.set_resource_handle(ContextId); Req.set_index(Index); wasi_ephemeral_nn::GetOutputResult Res; auto Status = Stub->GetOutput(&ClientContext, Req, &Res); @@ -413,15 +425,18 @@ WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, } #endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC - if (Env.NNContext.size() <= Context) { - spdlog::error("[WASI-NN] get_output: Execution Context does not exist"sv); + if (Env.NNContext.size() <= ContextId || + !Env.NNContext[ContextId].isReady()) { + spdlog::error("[WASI-NN] get_output: Context ID {} does not exist."sv, + ContextId); return WASINN::ErrNo::InvalidArgument; } - switch (const auto Backend = Env.NNContext[Context].getBackend()) { + switch (const auto Backend = Env.NNContext[ContextId].getBackend()) { #define EACH(B) \ case WASINN::Backend::B: \ - return WASINN::B::getOutput(Env, Context, Index, OutBuffer, *BytesWritten); + return WASINN::B::getOutput(Env, ContextId, Index, OutBuffer, \ + *BytesWritten); FOR_EACH_BACKEND(EACH) #undef EACH default: @@ -431,7 +446,7 @@ WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context, } Expect WasiNNGetOutputSingle::bodyImpl( - const Runtime::CallingFrame &Frame, uint32_t Context, uint32_t Index, + const Runtime::CallingFrame &Frame, uint32_t ContextId, uint32_t Index, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { auto *MemInst = Frame.getMemoryByIndex(0); @@ -458,7 +473,7 @@ Expect WasiNNGetOutputSingle::bodyImpl( Env.NNRPCChannel); grpc::ClientContext ClientContext; wasi_ephemeral_nn::GetOutputRequest Req; - Req.set_resource_handle(Context); + Req.set_resource_handle(ContextId); Req.set_index(Index); wasi_ephemeral_nn::GetOutputResult Res; auto Status = Stub->GetOutputSingle(&ClientContext, Req, &Res); @@ -474,32 +489,35 @@ Expect WasiNNGetOutputSingle::bodyImpl( } #endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC - if (Env.NNContext.size() <= Context) { + if (Env.NNContext.size() <= ContextId || + !Env.NNContext[ContextId].isReady()) { spdlog::error( - "[WASI-NN] get_output_single: Execution Context does not exist"sv); + "[WASI-NN] get_output_single: Context ID {} does not exist."sv, + ContextId); return WASINN::ErrNo::InvalidArgument; } - switch (Env.NNContext[Context].getBackend()) { + switch (Env.NNContext[ContextId].getBackend()) { case WASINN::Backend::GGML: - return WASINN::GGML::getOutputSingle(Env, Context, Index, OutBuffer, + return WASINN::GGML::getOutputSingle(Env, ContextId, Index, OutBuffer, *BytesWritten); default: - spdlog::error( - "[WASI-NN] get_output_single: Only GGML backend supports get_output_single."sv); + spdlog::error("[WASI-NN] get_output_single: Only GGML backend supports "sv + "get_output_single."sv); return WASINN::ErrNo::InvalidArgument; } } Expect -WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { +WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, + uint32_t ContextId) { #ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNRPCChannel != nullptr) { auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( Env.NNRPCChannel); grpc::ClientContext ClientContext; wasi_ephemeral_nn::ComputeRequest Req; - Req.set_resource_handle(Context); + Req.set_resource_handle(ContextId); google::protobuf::Empty Res; auto Status = Stub->Compute(&ClientContext, Req, &Res); if (!Status.ok()) { @@ -514,15 +532,26 @@ WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { return Unexpect(ErrCode::Value::HostFuncError); } - if (Env.NNContext.size() <= Context) { - spdlog::error("[WASI-NN] compute: Execution Context does not exist."sv); + if (Env.NNContext.size() <= ContextId || + !Env.NNContext[ContextId].isReady()) { + spdlog::error("[WASI-NN] compute: Context ID {} does not exist."sv, + ContextId); return WASINN::ErrNo::InvalidArgument; } - switch (const auto Backend = Env.NNContext[Context].getBackend()) { + auto GraphId = Env.NNContext[ContextId].getGraphId(); + assuming(Env.NNGraph.size() > GraphId); + if (!Env.NNGraph[GraphId].isReady()) { + spdlog::error("[WASI-NN] compute: Graph ID {} for context ID {} does not "sv + "exist or has released."sv, + GraphId, ContextId); + return WASINN::ErrNo::InvalidArgument; + } + + switch (const auto Backend = Env.NNContext[ContextId].getBackend()) { #define EACH(B) \ case WASINN::Backend::B: \ - return WASINN::B::compute(Env, Context); + return WASINN::B::compute(Env, ContextId); FOR_EACH_BACKEND(EACH) #undef EACH default: @@ -533,14 +562,14 @@ WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t Context) { Expect WasiNNComputeSingle::bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t Context) { + uint32_t ContextId) { #ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNRPCChannel != nullptr) { auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( Env.NNRPCChannel); grpc::ClientContext ClientContext; wasi_ephemeral_nn::ComputeRequest Req; - Req.set_resource_handle(Context); + Req.set_resource_handle(ContextId); google::protobuf::Empty Res; auto Status = Stub->ComputeSingle(&ClientContext, Req, &Res); if (!Status.ok()) { @@ -555,32 +584,42 @@ WasiNNComputeSingle::bodyImpl(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - if (Env.NNContext.size() <= Context) { - spdlog::error( - "[WASI-NN] compute_single: Execution Context does not exist."sv); + if (Env.NNContext.size() <= ContextId || + !Env.NNContext[ContextId].isReady()) { + spdlog::error("[WASI-NN] compute_single: Context ID {} does not exist."sv, + ContextId); + return WASINN::ErrNo::InvalidArgument; + } + + auto GraphId = Env.NNContext[ContextId].getGraphId(); + assuming(Env.NNGraph.size() > GraphId); + if (!Env.NNGraph[GraphId].isReady()) { + spdlog::error("[WASI-NN] compute_single: Graph ID {} for context ID {} "sv + "does not exist or has released."sv, + GraphId, ContextId); return WASINN::ErrNo::InvalidArgument; } - switch (Env.NNContext[Context].getBackend()) { + switch (Env.NNContext[ContextId].getBackend()) { case WASINN::Backend::GGML: - return WASINN::GGML::computeSingle(Env, Context); + return WASINN::GGML::computeSingle(Env, ContextId); default: - spdlog::error( - "[WASI-NN] compute_single: Only GGML backend supports compute_single."sv); + spdlog::error("[WASI-NN] compute_single: Only GGML backend supports "sv + "compute_single."sv); return WASINN::ErrNo::InvalidArgument; } } Expect WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t Context) { + uint32_t ContextId) { #ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNRPCChannel != nullptr) { auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( Env.NNRPCChannel); grpc::ClientContext ClientContext; wasi_ephemeral_nn::FiniSingleRequest Req; - Req.set_resource_handle(Context); + Req.set_resource_handle(ContextId); google::protobuf::Empty Res; auto Status = Stub->FiniSingle(&ClientContext, Req, &Res); if (!Status.ok()) { @@ -595,17 +634,19 @@ WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - if (Env.NNContext.size() <= Context) { - spdlog::error("[WASI-NN] fini_single: Execution Context does not exist."sv); + if (Env.NNContext.size() <= ContextId || + !Env.NNContext[ContextId].isReady()) { + spdlog::error("[WASI-NN] fini_single: Context ID {} does not exist."sv, + ContextId); return WASINN::ErrNo::InvalidArgument; } - switch (Env.NNContext[Context].getBackend()) { + switch (Env.NNContext[ContextId].getBackend()) { case WASINN::Backend::GGML: - return WASINN::GGML::finiSingle(Env, Context); + return WASINN::GGML::finiSingle(Env, ContextId); default: spdlog::error( - "[WASI-NN] fini_single: Only GGML backend supports compute_single."sv); + "[WASI-NN] fini_single: Only GGML backend supports fini_single."sv); return WASINN::ErrNo::InvalidArgument; } } @@ -634,25 +675,23 @@ Expect WasiNNUnload::bodyImpl(const Runtime::CallingFrame &Frame, return WASINN::GGML::unload(Env, GraphId); case WASINN::Backend::Whisper: return WASINN::Whisper::unload(Env, GraphId); - case WASINN::Backend::NeuralSpeed: - return WASINN::NeuralSpeed::unload(Env, GraphId); case WASINN::Backend::ChatTTS: return WASINN::ChatTTS::unload(Env, GraphId); default: - spdlog::error( - "[WASI-NN] unlaod: Only GGML, Whisper, Neural speed, and ChatTTS backend supports unload."sv); + spdlog::error("[WASI-NN] unlaod: Only GGML, Whisper, and ChatTTS "sv + "backends support unload."sv); return WASINN::ErrNo::InvalidArgument; } } Expect WasiNNFinalizeExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, - uint32_t Context) { + uint32_t ContextId) { #ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNRPCChannel != nullptr) { // TODO: implement RPC for finalize_execution_context - spdlog::error( - "[WASI-NN] RPC client is not implemented for finalize_execution_context"sv); + spdlog::error("[WASI-NN] RPC client is not implemented for "sv + "finalize_execution_context"sv); return WASINN::ErrNo::UnsupportedOperation; } #endif @@ -661,15 +700,23 @@ WasiNNFinalizeExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - if (Env.NNContext.size() <= Context) { + if (Env.NNContext.size() <= ContextId) { spdlog::error( - "[WASI-NN] finalize_execution_context: Execution Context does not exist."sv); + "[WASI-NN] finalize_execution_context: Context ID {} does not exist."sv, + ContextId); return WASINN::ErrNo::InvalidArgument; } - spdlog::error( - "[WASI-NN] finalize_execution_context: No backend supports finalize_execution_context."sv); - return WASINN::ErrNo::InvalidArgument; + switch (Env.NNContext[ContextId].getBackend()) { + case WASINN::Backend::GGML: + return WASINN::GGML::unload(Env, ContextId); + case WASINN::Backend::Whisper: + return WASINN::Whisper::unload(Env, ContextId); + default: + spdlog::error("[WASI-NN] finalize_execution_context: Only GGML and "sv + "Whisper backends support finalize_execution_context."sv); + return WASINN::ErrNo::InvalidArgument; + } } } // namespace Host diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index ab14d411..e994b75d 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -323,6 +323,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Swap to the tmp. env. NNGraphTmp.emplace_back(Backend::OpenVINO); + NNGraphTmp.back().setReady(); NNGraphTmp.swap(NNMod->getEnv().NNGraph); NNContextTmp.swap(NNMod->getEnv().NNContext); // Test: init_execution_context -- graph id exceeds. @@ -373,6 +374,7 @@ TEST(WasiNNTest, OpenVINOBackend) { // Swap to the tmp. env. NNContextTmp.emplace_back(0, NNGraphTmp[0]); + NNContextTmp.back().setReady(); NNGraphTmp.swap(NNMod->getEnv().NNGraph); NNContextTmp.swap(NNMod->getEnv().NNContext); // Test: set_input -- context id exceeds. @@ -721,6 +723,7 @@ TEST(WasiNNTest, PyTorchBackend) { // Swap to the tmp. env. NNGraphTmp.emplace_back(Backend::PyTorch); + NNGraphTmp.back().setReady(); // Test: init_execution_context -- graph id exceeds. // TODO: not null test for pytorch now // NNGraphTmp.swap(NNMod->getEnv().NNGraph); @@ -782,6 +785,7 @@ TEST(WasiNNTest, PyTorchBackend) { } NNContextTmp.emplace_back(0, NNGraphTmp[0]); + NNContextTmp.back().setReady(); // Test: set_input -- tensor type not FP32. BuilderPtr = SetInputEntryPtr; @@ -1103,6 +1107,7 @@ TEST(WasiNNTest, TFLiteBackend) { // Swap to the tmp. env. // Test: init_execution_context -- graph id exceeds. NNGraphTmp.emplace_back(Backend::TensorflowLite); + NNGraphTmp.back().setReady(); NNGraphTmp.swap(NNMod->getEnv().NNGraph); NNContextTmp.swap(NNMod->getEnv().NNContext); { @@ -1162,6 +1167,7 @@ TEST(WasiNNTest, TFLiteBackend) { } NNContextTmp.emplace_back(0, NNGraphTmp[0]); + NNContextTmp.back().setReady(); // Test: set_input -- set input successfully. BuilderPtr = SetInputEntryPtr; From 2ccb994eebf64edb79d5edbfa6f7893655d8d525 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Mon, 30 Dec 2024 08:13:34 +0800 Subject: [PATCH 519/623] [WASI-NN] ggml: various refactoring. 1. Move the sampler to context level. 2. Fix the batch encoding. 3. Update the metadata and config structure. 4. Merge the redundant codes. 5. Refine the model/sampler reloading in set_input. Signed-off-by: YiYing He --- plugins/wasi_nn/wasinn_ggml.cpp | 1323 +++++++++++++++---------------- plugins/wasi_nn/wasinn_ggml.h | 53 +- 2 files changed, 690 insertions(+), 686 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 1dceb7dd..c4141cd1 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -19,7 +19,9 @@ #include #include +#include #include +#include #endif namespace WasmEdge::Host::WASINN::GGML { @@ -72,8 +74,22 @@ void LlamaLogCallback(ggml_log_level LogLevel, const char *LogText, } } +// >>>>>>>> Metadata related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// Setup llama sampler params from graph. +void setupSamplerParams(Graph &GraphRef, + common_params_sampling &Sampling) noexcept { + Sampling.temp = static_cast(GraphRef.Temp); + Sampling.top_p = static_cast(GraphRef.TopP); + Sampling.penalty_repeat = static_cast(GraphRef.RepeatPenalty); + Sampling.penalty_present = static_cast(GraphRef.PresencePenalty); + Sampling.penalty_freq = static_cast(GraphRef.FrequencyPenalty); + Sampling.grammar = GraphRef.Grammar; + Sampling.seed = static_cast(GraphRef.Seed); +} + // Setup llama common params from graph. -void setupParams(Graph &GraphRef, common_params &Params) { +void setupCommonParams(Graph &GraphRef, common_params &Params) noexcept { Params.model = GraphRef.ModelFilePath; Params.n_gpu_layers = static_cast(GraphRef.NGPULayers); Params.n_ctx = static_cast(GraphRef.CtxSize); @@ -84,19 +100,15 @@ void setupParams(Graph &GraphRef, common_params &Params) { Params.cpuparams.n_threads = static_cast(GraphRef.Threads); Params.cpuparams_batch.n_threads = static_cast(GraphRef.Threads); Params.embedding = GraphRef.Embedding; - Params.sampling.temp = static_cast(GraphRef.Temp); - Params.sampling.top_p = static_cast(GraphRef.TopP); - Params.sampling.penalty_repeat = static_cast(GraphRef.RepeatPenalty); - Params.sampling.penalty_present = - static_cast(GraphRef.PresencePenalty); - Params.sampling.grammar = GraphRef.Grammar; - Params.sampling.seed = static_cast(GraphRef.Seed); + setupSamplerParams(GraphRef, Params.sampling); } // Parse metadata from json. -Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, - bool *IsModelUpdated = nullptr, - bool *IsContextUpdated = nullptr) noexcept { +ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, + const std::string &Metadata, bool *IsModelUpdated = nullptr, + bool *IsContextUpdated = nullptr, + bool *IsSamplerUpdated = nullptr) noexcept { + // Parse metadata from the json. simdjson::dom::parser Parser; simdjson::dom::element Doc; auto ParseError = Parser.parse(Metadata).get(Doc); @@ -104,32 +116,25 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, RET_ERROR(ErrNo::InvalidEncoding, "parse metadata error."sv) } - // Get metadata from the json. - // Currently supported metadata: - // Plugin parameters (used by this plugin): + // Plugin parameters (used by this graph and created contexts): // enable-log: bool // enable-debug-log: bool - // stream-stdout: bool - // embedding: bool - // n-predict: int64_t - // reverse-prompt: string - // mmproj: string - // image: string - // use-mmap: bool // Model parameters (need to reload the model if updated): - // n-gpu-layers: int64_t // main-gpu: int64_t + // n-gpu-layers: int64_t // tensor-split: string, comma-separated floating number list - // use-mmap: use mmap + // embedding: bool + // use-mmap: bool // warmup: bool // split-mode: string, {none,layer,row} + // mmproj: string // Context parameters (used by the llama context): // ctx-size: int64_t // batch-size: int64_t // ubatch-size: int64_t // threads: int64_t - // Sampling parameters (used by the llama sampling context). + // Sampling parameters (used by the llama sampling context): // temp: double // top-p: double // repeat-penalty: double @@ -137,10 +142,23 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, // frequency-penalty: double // grammar: string // seed: uint64_t + // Config parameters (mutable config at runtime for contexts): + // stream-stdout: bool + // n-predict: int64_t + // reverse-prompt: string + // image: string // Get the current llama parameters. - common_params Params; - setupParams(GraphRef, Params); + int64_t PrevNGPULayers = GraphRef.NGPULayers; + bool PrevEmbedding = GraphRef.Embedding; + // Get the current sampler parameters. + double PrevTemp = GraphRef.Temp; + double PrevTopP = GraphRef.TopP; + double PrevRepeatPenalty = GraphRef.RepeatPenalty; + double PrevPresencePenalty = GraphRef.PresencePenalty; + double PrevFrequencyPenalty = GraphRef.FrequencyPenalty; + std::string PrevGrammar = GraphRef.Grammar; + uint64_t PrevSeed = GraphRef.Seed; // The plugin parameters. if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { @@ -157,56 +175,15 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, "Unable to retrieve the enable-debug-log option."sv) } } - if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { - auto Err = Doc["stream-stdout"].get().get(GraphRef.StreamStdout); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the stream-stdout option."sv) - } - } - if (Doc.at_key("embedding").error() == simdjson::SUCCESS) { - auto Err = Doc["embedding"].get().get(GraphRef.Embedding); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the embedding option."sv) - } - } - if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { - auto Err = Doc["n-predict"].get().get(GraphRef.NPredict); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-predict option."sv) - } - } - if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { - std::string_view ReversePrompt; - auto Err = Doc["reverse-prompt"].get().get(ReversePrompt); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the reverse-prompt option."sv) - } - GraphRef.ReversePrompt = ReversePrompt; - } - if (Doc.at_key("mmproj").error() == simdjson::SUCCESS) { - std::string_view MMProjModelPath; - auto Err = Doc["mmproj"].get().get(MMProjModelPath); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the mmproj option."sv) - } - GraphRef.MMProjModelPath = MMProjModelPath; - } - if (Doc.at_key("image").error() == simdjson::SUCCESS) { - std::string_view ImagePath; - auto Err = Doc["image"].get().get(ImagePath); + + // The model parameters. + if (Doc.at_key("main-gpu").error() == simdjson::SUCCESS) { + auto Err = Doc["main-gpu"].get().get(GraphRef.MainGPU); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the image option."sv) + "Unable to retrieve the main-gpu option."sv) } - GraphRef.ImagePath = ImagePath; } - - // The model parameters. if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { auto Err = Doc["n-gpu-layers"].get().get(GraphRef.NGPULayers); if (Err) { @@ -214,13 +191,6 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, "Unable to retrieve the n-gpu-layers option."sv) } } - if (Doc.at_key("main-gpu").error() == simdjson::SUCCESS) { - auto Err = Doc["main-gpu"].get().get(GraphRef.MainGPU); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the main-gpu option."sv) - } - } if (Doc.at_key("tensor-split").error() == simdjson::SUCCESS) { // The TensorSplit is a comma-separated list of non-negative values. // E.g., "3,2" presents 60% of the data to GPU 0 and 40% to GPU 1. @@ -250,6 +220,13 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, GraphRef.TensorSplit.push_back(0.0f); } } + if (Doc.at_key("embedding").error() == simdjson::SUCCESS) { + auto Err = Doc["embedding"].get().get(GraphRef.Embedding); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embedding option."sv) + } + } if (Doc.at_key("use-mmap").error() == simdjson::SUCCESS) { auto Err = Doc["use-mmap"].get().get(GraphRef.UseMMap); if (Err) { @@ -284,6 +261,15 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, return ErrNo::InvalidArgument; } } + if (Doc.at_key("mmproj").error() == simdjson::SUCCESS) { + std::string_view MMProjModelPath; + auto Err = Doc["mmproj"].get().get(MMProjModelPath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mmproj option."sv) + } + GraphRef.MMProjModelPath = MMProjModelPath; + } // The context parameters. if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { @@ -381,28 +367,145 @@ Expect parseMetadata(Graph &GraphRef, const std::string &Metadata, } } - // Check if the model is updated. - if (IsModelUpdated && Params.n_gpu_layers != GraphRef.NGPULayers) { + // The config parameters. + if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { + auto Err = Doc["stream-stdout"].get().get(ConfRef.StreamStdout); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the stream-stdout option."sv) + } + } + if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { + auto Err = Doc["n-predict"].get().get(ConfRef.NPredict); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-predict option."sv) + } + } + if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { + std::string_view ReversePrompt; + auto Err = Doc["reverse-prompt"].get().get(ReversePrompt); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the reverse-prompt option."sv) + } + ConfRef.ReversePrompt = ReversePrompt; + } + if (Doc.at_key("image").error() == simdjson::SUCCESS) { + std::string_view ImagePath; + auto Err = Doc["image"].get().get(ImagePath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the image option."sv) + } + ConfRef.ImagePath = ImagePath; + } + + // Check if the model parameters are updated. + if (IsModelUpdated && PrevNGPULayers != GraphRef.NGPULayers) { *IsModelUpdated = true; } // Check if the context parameters are updated. - if (IsContextUpdated && Params.embedding != GraphRef.Embedding) { + if (IsContextUpdated && PrevEmbedding != GraphRef.Embedding) { *IsContextUpdated = true; } + // Check if the sampler parameters are updated. + if (IsSamplerUpdated && + (PrevTemp != GraphRef.Temp || PrevTopP != GraphRef.TopP || + PrevRepeatPenalty != GraphRef.RepeatPenalty || + PrevPresencePenalty != GraphRef.PresencePenalty || + PrevFrequencyPenalty != GraphRef.FrequencyPenalty || + PrevGrammar != GraphRef.Grammar || PrevSeed != GraphRef.Seed)) { + *IsSamplerUpdated = true; + } + return ErrNo::Success; } -void buildOutputMetadata(Context &CxtRef, std::string &Metadata) noexcept { - Metadata = fmt::format(R"({{"input_tokens": {}, )" - R"("output_tokens": {}, )" - R"("llama_build_number": {}, )" - R"("llama_commit": "{}"}})"sv, - CxtRef.LlamaNInputs, CxtRef.LlamaOutputTokens.size(), - LLAMA_BUILD_NUMBER, LLAMA_COMMIT); +// <<<<<<<< Metadata related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + +// >>>>>>>> Input related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +const std::string_view Base64ImageTagPrefix = ""sv; +const std::string_view LlavaPromptImagePlaceholder = ""sv; + +// Get base64 image position if found in prompt. +std::optional> +findBase64ImagePayload(std::string_view Prompt, + bool IsDebugLog = false) noexcept { + // Find `` + auto EndTagPos = Prompt.find(Base64ImageTagSuffix, PayloadPos); + if (EndTagPos == std::string::npos) { + LOG_DEBUG(IsDebugLog, "Base64 image tag unclosed."sv) + return std::nullopt; + } + return std::make_tuple(BeginTagPos, PayloadPos, EndTagPos); } +// Extract base64 image payload and image type. Replace it with placeholder. +std::optional, std::string>> +extractBase64ImagePayload(std::string &Prompt, + std::tuple ImagePos, + const std::string_view Placeholder) noexcept { + // Locate the payload and image type. + size_t BeginTagPos = std::get<0>(ImagePos); + size_t TypePos = std::get<0>(ImagePos) + Base64ImageTagPrefix.size(); + size_t PayloadPos = std::get<1>(ImagePos); + size_t BeginBytePos = std::get<1>(ImagePos) + Base64ImageBytesPrefix.size(); + size_t EndTagPos = std::get<2>(ImagePos); + std::string_view Payload = + std::string_view(Prompt).substr(BeginBytePos, EndTagPos - BeginBytePos); + std::string ImageType = Prompt.substr(TypePos, PayloadPos - TypePos); + + // Decode the base64 payload. + auto RequiredBytes = base64::required_encode_size(Payload.size()); + std::vector ImageBytes(RequiredBytes); + try { + base64::decode(Payload.begin(), Payload.end(), ImageBytes.begin()); + } catch (const base64_error &E) { + RET_ERROR(std::make_pair(std::vector(), ""), + "Error when base64::decode: {}"sv, E.what()) + } + + // Replace the base64 image with the placeholder. + Prompt.replace(BeginTagPos, + EndTagPos - BeginTagPos + Base64ImageTagSuffix.size(), + Placeholder); + return std::make_pair(ImageBytes, ImageType); +} + +// <<<<<<<< Input related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + +// >>>>>>>> Output related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// Generate output metadata. +std::string buildOutputMetadata(Context &CxtRef) noexcept { + return fmt::format(R"({{"input_tokens": {}, )" + R"("output_tokens": {}, )" + R"("llama_build_number": {}, )" + R"("llama_commit": "{}"}})"sv, + CxtRef.LlamaNInputs, CxtRef.LlamaOutputTokens.size(), + LLAMA_BUILD_NUMBER, LLAMA_COMMIT); +} + +// Generate output embedding. void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, const float *Embeddings) noexcept { // Embedding vector format @@ -422,36 +525,82 @@ void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, NEmbd, fmt::join(Embeddings, Embeddings + NEmbd, ","sv)); } -static bool evaluateQwen2vlImageEmbed( - llama_context *CtxLlama, const struct llava_image_embed *ImageEmbed, - int NBatch, int *NPast, int *StPosId, struct clip_image_size *ImageSize) { - int NEmbd = llama_n_embd(llama_get_model(CtxLlama)); +// <<<<<<<< Output related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + +// >>>>>>>> Compute related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// Helper to init a llama batch. +struct llama_batch allocBatch(int64_t NTokens, int64_t Embd = 0, + int32_t NSeqMax = 1) noexcept { + struct llama_batch Batch = llama_batch_init( + /* n_tokens_alloc */ static_cast(NTokens), + /* embd */ static_cast(Embd), + /* n_seq_max */ static_cast(NSeqMax)); + std::fill(Batch.n_seq_id, Batch.n_seq_id + NTokens, + static_cast(NSeqMax)); + for (int64_t I = 0; I < NTokens; I++) { + std::fill(Batch.seq_id[I], Batch.seq_id[I] + NSeqMax, 0); + } + std::fill(Batch.logits, Batch.logits + NTokens, false); + return Batch; +} + +// Fill tokens (smaller than batch size) into a batch with position data. +void fillBatch(Span Tokens, Graph &GraphRef, + llama_batch &Batch, int &NPos, bool IsLogit = false) { + assuming(GraphRef.BatchSize >= static_cast(Tokens.size())); + assuming(Batch.token != nullptr); + assuming(Batch.pos != nullptr); + assuming(Batch.logits != nullptr); + // Fill the batch with pos information. + Batch.n_tokens = static_cast(Tokens.size()); + for (uint32_t I = 0; I < Tokens.size(); I++) { + Batch.token[I] = Tokens[I]; + Batch.pos[I] = NPos + I; + Batch.logits[I] = false; + } + + // Logits of sampling or end of inputs. + if (IsLogit) { + Batch.logits[Tokens.size() - 1] = true; + } + + // Move the position. + NPos += static_cast(Tokens.size()); +} + +// Evaluate Qwen2vl image embedding. +bool evaluateQwen2vlImageEmbed(llama_context *LlamaCxt, + const struct llava_image_embed *ImageEmbed, + int64_t NBatch, int32_t &NPos, + struct clip_image_size *ImageSize) { + int NEmbd = llama_n_embd(llama_get_model(LlamaCxt)); const int PatchSize = 14 * 2; const int Ph = ImageSize->height / PatchSize + (ImageSize->height % PatchSize > 0); const int Pw = ImageSize->width / PatchSize + (ImageSize->width % PatchSize > 0); - auto ImgTokens = ImageEmbed->n_image_pos; + const int ImgTokens = ImageEmbed->n_image_pos; std::vector MRopePos; MRopePos.resize(ImgTokens * 4); + int32_t StPosId = NPos; for (int Y = 0; Y < Ph; Y++) { for (int X = 0; X < Pw; X++) { int I = Y * Pw + X; - MRopePos[I] = *StPosId; - MRopePos[I + ImgTokens] = *StPosId + Y; - MRopePos[I + ImgTokens * 2] = *StPosId + X; + MRopePos[I] = StPosId; + MRopePos[I + ImgTokens] = StPosId + Y; + MRopePos[I + ImgTokens * 2] = StPosId + X; MRopePos[I + ImgTokens * 3] = 0; } } - *StPosId += std::max(Pw, Ph); - int Processed = 0; + int32_t Processed = 0; std::vector BatchMRopePos; BatchMRopePos.resize(ImgTokens * 4); - for (int I = 0; I < ImgTokens; I += NBatch) { - int NEval = ImgTokens - I; + for (int64_t I = 0; I < ImgTokens; I += NBatch) { + int64_t NEval = ImgTokens - I; if (NEval > NBatch) { NEval = NBatch; } @@ -474,52 +623,43 @@ static bool evaluateQwen2vlImageEmbed( nullptr, // seq_id nullptr, // logits }; - if (llama_decode(CtxLlama, Batch)) { + if (llama_decode(LlamaCxt, Batch)) { RET_ERROR(false, "evaluateQwen2vlImageEmbed: fail to eval."sv) } - *NPast += NEval; - Processed += NEval; + NPos += static_cast(NEval); + Processed += static_cast(NEval); } return true; } -ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, - std::vector Tokens, int &NPast, - int &NPos) noexcept { - uint32_t NCtx = llama_n_ctx(LlamaContext); - +// Evaluate tokens. Construct the tokens into batch and decode. +ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, + llama_batch &Batch, int &NPos, + bool IsLogits = false) noexcept { // End the inference if the context is full. - if (NPast + static_cast(Tokens.size()) > NCtx) { - LOG_INFO( - GraphRef.EnableLog, - "the context if full ({} / {} tokens). Please increase your context "sv - "size."sv, - NPast + static_cast(Tokens.size()), NCtx) + uint32_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); + if (NPos + static_cast(Tokens.size()) > NCtx) { + LOG_INFO(GraphRef.EnableLog, + "the context if full ({} / {} tokens). Please increase your "sv + "context size."sv, + NPos + static_cast(Tokens.size()), NCtx) return ErrNo::ContextFull; } - std::vector LlamaPos; + // Loop for decode batch. Split tokens into batch size length. for (int I = 0; I < static_cast(Tokens.size()); I += static_cast(GraphRef.BatchSize)) { int NEval = static_cast(Tokens.size()) - I; if (NEval > static_cast(GraphRef.BatchSize)) { NEval = static_cast(GraphRef.BatchSize); } - // Get a batch for single sequence of tokens. - auto Batch = llama_batch_get_one(&Tokens[I], NEval); - - // Add pos information for Qwen2vl. - if (GraphRef.VisionModelType == VisionModel::Qwen2VL) { - LlamaPos.resize(Batch.n_tokens * 4); - std::fill(LlamaPos.begin(), LlamaPos.end(), 0); - for (int J = 0; J < Batch.n_tokens * 3; J++) { - LlamaPos[J] = NPos + (J % Batch.n_tokens); - } - Batch.pos = LlamaPos.data(); - } + // Fill the batch with pos information. + fillBatch(Span(Tokens.begin() + I, NEval), GraphRef, + Batch, NPos, + IsLogits && I + NEval >= static_cast(Tokens.size())); // Decode the batch. - auto Status = llama_decode(LlamaContext, Batch); + auto Status = llama_decode(GraphRef.LlamaContext.get(), Batch); if (Status == 1) { RET_ERROR(ErrNo::RuntimeError, "failed to llama_decode: try reducing the size of the batch "sv @@ -529,243 +669,213 @@ ErrNo evaluateTokens(Graph &GraphRef, struct llama_context *LlamaContext, "failed to llama_decode: internal fatal error. Please open "sv "an issue on GitHub."sv) } - NPast += NEval; - NPos += NEval; - } - - return ErrNo::Success; -} - -void batchAddSeq(llama_batch &Batch, const std::vector &Tokens, - llama_seq_id SequenceId) noexcept { - for (int I = 0; I < static_cast(Tokens.size()); I++) { - // llama_batch_add_seq(llama_batch, llama_token, llama_pos, - // std::vector, logits); - common_batch_add(Batch, Tokens[I], I, {SequenceId}, - I == static_cast(Tokens.size()) - 1); - } -} - -ErrNo batchDecode(llama_context *LlamaContext, llama_batch &Batch, - float *Output, int NEmbd, - EmbdNormalizeType EmbdNormalize) noexcept { - // Clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(LlamaContext); - - // Decode the batch. - auto Status = llama_decode(LlamaContext, Batch); - if (Status == 1) { - RET_ERROR(ErrNo::RuntimeError, - "failed to llama_decode: try reducing the size of the batch or "sv - "increasing the size of context."sv) - } else if (Status < 0) { - RET_ERROR(ErrNo::RuntimeError, - "failed to llama_decode: internal fatal error. Please open an "sv - "issue on GitHub."sv) - } - - for (int I = 0; I < Batch.n_tokens; I++) { - if (!Batch.logits[I]) { - continue; - } - - // Try to get sequence embeddings. - auto *Embd = llama_get_embeddings_seq(LlamaContext, Batch.seq_id[I][0]); - if (Embd == nullptr) { - Embd = llama_get_embeddings_ith(LlamaContext, I); - if (Embd == nullptr) { - spdlog::error( - "[WASI-NN] GGML backend: failed to get embeddings for token {}"sv, - I); - continue; - } - } - - // Normalize the embeddings. - common_embd_normalize(Embd, Output, NEmbd, - static_cast(EmbdNormalize)); } return ErrNo::Success; } -Expect getEmbedding(WasiNNEnvironment &Env, - uint32_t ContextId) noexcept { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding"sv) - +// Evaluate the input tokens. Clean all inputs if succeeded. +ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, + std::string_view LogPrefix) noexcept { + // Check if the input is set before setting up the context. if (CxtRef.LlamaInputs.size() == 0) { - RET_ERROR(ErrNo::InvalidArgument, "Llama input is not set!"sv) + RET_ERROR(ErrNo::InvalidArgument, "{}: llama input is not set!"sv, + LogPrefix) } // Clear the outputs. LOG_DEBUG(GraphRef.EnableDebugLog, - "getEmbedding: clear the previous output and tokens"sv) + "{}: clear the previous output and tokens"sv, LogPrefix) CxtRef.LlamaOutputs.clear(); CxtRef.LlamaOutputTokens.clear(); LOG_DEBUG(GraphRef.EnableDebugLog, - "getEmbedding: clear the previous output and tokens...Done"sv) - - // Main prediction loop. - LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding: enter embedding loop"sv) + "{}: clear the previous output and tokens...Done"sv, LogPrefix) // Clear the llama context. llama_kv_cache_clear(GraphRef.LlamaContext.get()); - // Use the const sequence id here. - const llama_seq_id SequenceId = 0; + // Prepare variables; + CxtRef.NPos = 0; + // Get the context size. + const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); + // Minus 4 for the special tokens. (Such as , , ... tokens.) + const uint64_t MaxTokensListSize = NCtx - 4; // Return value. auto ReturnCode = ErrNo::Success; - // Add SEP if not present. - if (CxtRef.LlamaInputs.back() != llama_token_sep(GraphRef.LlamaModel.get())) { - LOG_WARN( - "last token in the prompt is not SEP, "sv - "'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF "sv - "header."sv) - } - // Check if the input is too long. - if (static_cast(CxtRef.LlamaInputs.size()) > GraphRef.BatchSize) { - RET_ERROR( - ErrNo::PromptTooLong, - "the prompt is too long. Your input has {} tokens exceeds batch "sv - "size {}. Please reduce the input size or increase your batch-size."sv, - CxtRef.LlamaInputs.size(), GraphRef.BatchSize) + if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { + RET_ERROR(ErrNo::PromptTooLong, + "{}: the prompt is too long. Your input has {} tokens. "sv + "Please reduce it to {} tokens."sv, + LogPrefix, CxtRef.LlamaInputs.size(), MaxTokensListSize) } - const int32_t NEmbd = llama_n_embd(GraphRef.LlamaModel.get()); - struct llama_batch Batch = llama_batch_init( - /* n_tokens_alloc */ static_cast(GraphRef.BatchSize), - /* embd */ 0, - /* n_seq_max */ 1); - std::vector Embeddings(NEmbd); - batchAddSeq(Batch, CxtRef.LlamaInputs, SequenceId); - ReturnCode = batchDecode(GraphRef.LlamaContext.get(), Batch, - Embeddings.data(), NEmbd, GraphRef.EmbdNormalize); - if (ReturnCode != ErrNo::Success) { - RET_ERROR(ReturnCode, "failed to evaluate input tokens."sv) - } - buildOutputEmbedding(CxtRef.LlamaOutputs, NEmbd, Embeddings.data()); + // Evaluate input tokens. + if (CxtRef.LlavaImageEmbd != nullptr) { + // Llava format prompt with image data. + ReturnCode = + evaluateTokens(Span(CxtRef.LlamaInputs.begin(), + CxtRef.ImagePosition), + GraphRef, CxtRef.LlamaBatch, CxtRef.NPos); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, + "{}: failed to evaluate input tokens before image."sv, + LogPrefix) + } - LOG_DEBUG(GraphRef.EnableDebugLog, - "getEmbedding: enter embedding loop...Done"sv) + bool EvalImageStatus = false; + switch (GraphRef.VisionModelType) { + case VisionModel::Llava: + EvalImageStatus = llava_eval_image_embed( + GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, + static_cast(GraphRef.BatchSize), &CxtRef.NPos); + break; + case VisionModel::Qwen2VL: + auto ImageSize = clip_get_load_image_size(GraphRef.ClipContext); + EvalImageStatus = evaluateQwen2vlImageEmbed( + GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, + static_cast(GraphRef.BatchSize), CxtRef.NPos, ImageSize); + break; + } - if (GraphRef.EnableLog) { - common_perf_print(GraphRef.LlamaContext.get(), /* Sampler */ nullptr); + if (!EvalImageStatus) { + RET_ERROR(ErrNo::RuntimeError, + "{}: failed to evaluate embed image tokens."sv, LogPrefix) + } + ReturnCode = + evaluateTokens(Span( + CxtRef.LlamaInputs.begin() + CxtRef.ImagePosition, + CxtRef.LlamaInputs.size() - CxtRef.ImagePosition), + GraphRef, CxtRef.LlamaBatch, CxtRef.NPos, true); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, + "{}: failed to evaluate input tokens after image."sv, LogPrefix) + } + } else { + // Text only prompt. + ReturnCode = + evaluateTokens(Span(CxtRef.LlamaInputs.begin(), + CxtRef.LlamaInputs.size()), + GraphRef, CxtRef.LlamaBatch, CxtRef.NPos, true); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, "{}: failed to evaluate input tokens."sv, LogPrefix) + } } - // We clear the contexts here to keep the ggml plugin stateless. - // Users could fully control the contexts by themselves via their prompt. - llama_kv_cache_clear(GraphRef.LlamaContext.get()); - llama_batch_free(Batch); - - LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding...Done"sv) + CxtRef.Conf.ImagePath = ""sv; + if (CxtRef.LlavaImageEmbd != nullptr) { + llava_image_embed_free(CxtRef.LlavaImageEmbd); + CxtRef.LlavaImageEmbd = nullptr; + } return ErrNo::Success; } -const std::string_view Base64ImageTagPrefix = ""sv; -const std::string_view LlavaPromptImagePlaceholder = ""sv; - -bool containsBase64Image(Graph &GraphRef, std::string_view Prompt) noexcept { - // Check if the prompt contains a base64 image. - // Follow this link for the supported image formats: - // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h +// Sample and get the output token. +ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, + bool IsSingleTokenMode = false) noexcept { + // Use idx = -1 to sample the next token. + const llama_token Id = common_sampler_sample( + CxtRef.LlamaSampler, GraphRef.LlamaContext.get(), /* idx */ -1); + common_sampler_accept(CxtRef.LlamaSampler, Id, /* accept_grammar */ true); - auto Base64ImageTagBeginPos = Prompt.find(Base64ImageTagPrefix); - if (Base64ImageTagBeginPos == std::string::npos) { - LOG_DEBUG(GraphRef.EnableDebugLog, - "No base64 image tag found in the prompt."sv) - return false; + // Save the output token. + CxtRef.LlamaOutputTokens.emplace_back(Id); + CxtRef.LlamaOutputs += common_token_to_piece(GraphRef.LlamaContext.get(), Id); + // In single token mode, we do not handle StreamStdout and ReversePrompt. + if (!IsSingleTokenMode) { + // When setting StreamStdout, we print the output to stdout. + if (CxtRef.Conf.StreamStdout) { + fmt::print("{}"sv, + common_token_to_piece(GraphRef.LlamaContext.get(), Id)); + std::fflush(stdout); + } + // Break if reverse prompt is found. + if (!CxtRef.Conf.ReversePrompt.empty() && + CxtRef.LlamaOutputs.find(CxtRef.Conf.ReversePrompt) != + std::string::npos) { + LOG_INFO(GraphRef.EnableLog, "reverse prompt found."sv) + return ErrNo::EndOfSequence; + } } - auto Base64ImageTagEndPos = - Prompt.find(Base64ImageTagSuffix, Base64ImageTagBeginPos); - if (Base64ImageTagEndPos == std::string::npos) { - LOG_DEBUG(GraphRef.EnableDebugLog, "Found an unclosed base64 image tag."sv) - return false; + // Deal with end of text token. + if (llama_token_is_eog(GraphRef.LlamaModel.get(), + common_sampler_last(CxtRef.LlamaSampler))) { + LOG_INFO(GraphRef.EnableLog, "EOS token found."sv) + return ErrNo::EndOfSequence; } - return true; + // Evaluate the output token. + return evaluateTokens(Span(&Id, 1), GraphRef, + CxtRef.OutputBatch, CxtRef.NPos, true); } -std::string_view findBase64ImagePayload(std::string_view Prompt) noexcept { - // Find ` 0 && + CxtRef.LlamaInputs.back() != llama_token_sep(GraphRef.LlamaModel.get())) { + LOG_WARN( + "last token in the prompt is not SEP, "sv + "'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF "sv + "header."sv) } - // Find `">` - auto Base64ImageTagEndPos = - Prompt.find(Base64ImageTagSuffix, Base64ImageBytesBeginPos); - if (Base64ImageTagEndPos == std::string::npos) { - return Prompt.substr(); + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > GraphRef.BatchSize) { + RET_ERROR( + ErrNo::PromptTooLong, + "the prompt is too long. Your input has {} tokens exceeds batch "sv + "size {}. Please reduce the input size or increase your batch-size."sv, + CxtRef.LlamaInputs.size(), GraphRef.BatchSize) } - return Prompt.substr(Base64ImageBytesBeginPos + Base64ImageBytesPrefix.size(), - Base64ImageTagEndPos - Base64ImageBytesBeginPos - - Base64ImageBytesPrefix.size()); -} - -struct llava_image_embed * -llavaLoadBase64ImageFromPrompt(Graph &GraphRef, clip_ctx *ClipContext, - std::string_view Prompt) noexcept { - // Load the base64 image from the prompt. - // Follow this link for the supported image formats: - // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h - LOG_DEBUG(GraphRef.EnableDebugLog, "llavaLoadBase64ImageFromPrompt"sv) - - // Decode the base64 image. - auto Base64Str = findBase64ImagePayload(Prompt); - if (Base64Str.size() == 0) { - return nullptr; - } - auto RequiredBytes = base64::required_encode_size(Base64Str.size()); - auto ImageBytes = std::vector(RequiredBytes); - try { - base64::decode(Base64Str.begin(), Base64Str.end(), ImageBytes.begin()); - } catch (const base64_error &E) { - RET_ERROR(nullptr, "Error when base64::decode: {}"sv, E.what()) + // Evaluate the input tokens. + auto ReturnCode = evaluateInput(GraphRef, CxtRef, "getEmbedding"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; } - LOG_DEBUG(GraphRef.EnableDebugLog, "llavaLoadBase64ImageFromPrompt...Done"sv) - return llava_image_embed_make_with_bytes( - ClipContext, static_cast(GraphRef.Threads), ImageBytes.data(), - static_cast(ImageBytes.size())); -} + // Main prediction loop. + const int32_t NEmbd = llama_n_embd(GraphRef.LlamaModel.get()); + std::vector Embeddings(NEmbd); -ErrNo replaceBase64ImagePlaceholderInPrompt( - std::string &Prompt, const std::string_view Placeholder) noexcept { - // Replace the base64 image in the prompt with a placeholder. + for (int I = 0; I < CxtRef.LlamaBatch.n_tokens; I++) { + if (!CxtRef.LlamaBatch.logits[I]) { + continue; + } - // Find `(CxtRef.Conf.EmbdNormalize)); } - // Find `">` - auto Base64ImageTagEndPos = - Prompt.find(Base64ImageTagSuffix, Base64ImageTagBeginPos); - if (Base64ImageTagEndPos == std::string::npos) { - return ErrNo::InvalidArgument; + buildOutputEmbedding(CxtRef.LlamaOutputs, NEmbd, Embeddings.data()); + + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), /* Sampler */ nullptr); } - auto Base64ImageTagLength = Base64ImageTagEndPos - Base64ImageTagBeginPos + - Base64ImageTagSuffix.size(); - Prompt.replace(Base64ImageTagBeginPos, Base64ImageTagLength, Placeholder); + LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding...Done"sv) return ErrNo::Success; } +// <<<<<<<< Compute related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + } // namespace Expect load(WasiNNEnvironment &Env, Span> Builders, @@ -775,33 +885,34 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, auto &GraphRef = Env.NNGraph[GId].get(); // Initialize the plugin parameters. - llama_context_params ContextDefault = llama_context_default_params(); - const common_params ParamsDefault; GraphRef.EnableLog = false; GraphRef.EnableDebugLog = false; - GraphRef.StreamStdout = false; - GraphRef.NPredict = ContextDefault.n_ctx; - GraphRef.ReversePrompt = ""sv; - GraphRef.MMProjModelPath = ""sv; - GraphRef.ImagePath = ""sv; - GraphRef.EmbdNormalize = - static_cast(ParamsDefault.embd_normalize); // Initialize the model parameters. - llama_model_params ModelParams = llama_model_default_params(); - GraphRef.NGPULayers = ModelParams.n_gpu_layers; + llama_model_params ModelParamsDefault = llama_model_default_params(); + GraphRef.NGPULayers = ModelParamsDefault.n_gpu_layers; + GraphRef.MMProjModelPath = ""sv; // Initialize the context parameters. - GraphRef.CtxSize = ContextDefault.n_ctx; - GraphRef.BatchSize = ContextDefault.n_batch; - GraphRef.UBatchSize = ContextDefault.n_ubatch; - GraphRef.Threads = ContextDefault.n_threads; + llama_context_params ContextParamsDefault = llama_context_default_params(); + GraphRef.CtxSize = ContextParamsDefault.n_ctx; + GraphRef.BatchSize = ContextParamsDefault.n_batch; + GraphRef.UBatchSize = ContextParamsDefault.n_ubatch; + GraphRef.Threads = ContextParamsDefault.n_threads; // Initialize the sampling parameters. - const common_params_sampling SamplerDefault; - GraphRef.Temp = SamplerDefault.temp; - GraphRef.TopP = SamplerDefault.top_p; - GraphRef.RepeatPenalty = SamplerDefault.penalty_repeat; - GraphRef.PresencePenalty = SamplerDefault.penalty_present; - GraphRef.FrequencyPenalty = SamplerDefault.penalty_freq; - GraphRef.Grammar = SamplerDefault.grammar; + const common_params_sampling SamplerParamsDefault; + GraphRef.Temp = SamplerParamsDefault.temp; + GraphRef.TopP = SamplerParamsDefault.top_p; + GraphRef.RepeatPenalty = SamplerParamsDefault.penalty_repeat; + GraphRef.PresencePenalty = SamplerParamsDefault.penalty_present; + GraphRef.FrequencyPenalty = SamplerParamsDefault.penalty_freq; + GraphRef.Grammar = SamplerParamsDefault.grammar; + // Initialize the config parameters. + const common_params CommonParamsDefault; + GraphRef.Conf.StreamStdout = false; + GraphRef.Conf.EmbdNormalize = + static_cast(CommonParamsDefault.embd_normalize); + GraphRef.Conf.NPredict = ContextParamsDefault.n_ctx; + GraphRef.Conf.ReversePrompt = ""sv; + GraphRef.Conf.ImagePath = ""sv; // Set llama log callback. llama_log_set(LlamaLogCallback, &GraphRef); @@ -811,7 +922,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, const std::string Metadata(reinterpret_cast(Builders[1].data()), Builders[1].size()); // Ignore context or model updates when initializing the graph. - auto Res = parseMetadata(GraphRef, Metadata); + auto Res = parseMetadata(GraphRef, GraphRef.Conf, Metadata); if (Res != ErrNo::Success) { Env.deleteGraph(GId); RET_ERROR(Res, "Failed to parse metadata."sv) @@ -865,7 +976,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize ggml model with given parameters."sv) common_params Params; - setupParams(GraphRef, Params); + setupCommonParams(GraphRef, Params); llama_backend_init(); llama_numa_init(Params.numa); @@ -899,6 +1010,21 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); LOG_INFO(GraphRef.EnableLog, "llama_system_info: {}"sv, llama_print_system_info()) + + auto &CxtRef = Env.NNContext[ContextId].get(); + // Allocate the batch for input string prompt tokens. + CxtRef.LlamaBatch = allocBatch(GraphRef.BatchSize); + CxtRef.CurrentBatchSize = GraphRef.BatchSize; + + // Allocate the batch for output sampling. The batch size is always 1. + CxtRef.OutputBatch = allocBatch(1); + + // Allocate sampler. + common_params_sampling CommonSampling; + setupSamplerParams(GraphRef, CommonSampling); + CxtRef.LlamaSampler = + common_sampler_init(GraphRef.LlamaModel.get(), CommonSampling); + Env.NNContext[ContextId].setReady(); LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx...Done"sv) return ErrNo::Success; @@ -911,33 +1037,43 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, LOG_DEBUG(GraphRef.EnableDebugLog, "setInput"sv) // Use index 1 for metadata. - bool IsModelParamsUpdated = false; - bool IsContextParamsUpdated = false; if (Index == 1) { LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: found Metadata, processing"sv) + bool IsModelParamsUpdated = false; + bool IsContextParamsUpdated = false; + bool IsSamplerParamsUpdated = false; const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); - auto Res = parseMetadata(GraphRef, Metadata, &IsModelParamsUpdated, - &IsContextParamsUpdated); + auto Res = + parseMetadata(GraphRef, CxtRef.Conf, Metadata, &IsModelParamsUpdated, + &IsContextParamsUpdated, &IsSamplerParamsUpdated); if (Res != ErrNo::Success) { RET_ERROR(Res, "failed to parse metadata."sv) } #ifndef __APPLE__ - // XXX: Due to the limitation of WASI-NN proposal, this is a workaround for - // non-macOS devices. However, if the model params is updated in Config - // stage, then, we don't encourage to use this to avoid the model + // XXX: Due to the limitation of WASI-NN proposal, this is a workaround + // for non-macOS devices. However, if the model params is updated in + // Config stage, then, we don't encourage to use this to avoid the model // reloading. { if (IsModelParamsUpdated || GraphRef.LlamaModel == nullptr) { // The llama model may be nullptr if set_input with updated model params // last time. Therefore besides the model params updated, we should // reload the llama model if the model is nullptr. - LOG_INFO(GraphRef.EnableLog, - "Reloaded model due to parameters change."sv) + LOG_INFO(GraphRef.EnableLog, "Reload model due to parameters change."sv) llama_model_params ModelParams = llama_model_default_params(); ModelParams.n_gpu_layers = static_cast(GraphRef.NGPULayers); GraphRef.LlamaModel.reset(); + // Due to the model change, the context and sampler should also be + // reloaded. The new context and sampler will be created in the next + // block. + GraphRef.LlamaContext.reset(); + if (CxtRef.LlamaSampler) { + // TODO: Trigger the sampler in other contexts to reallocate. + common_sampler_free(CxtRef.LlamaSampler); + CxtRef.LlamaSampler = nullptr; + } GraphRef.LlamaModel = llama_model_ptr(llama_load_model_from_file( GraphRef.ModelFilePath.c_str(), ModelParams)); if (GraphRef.LlamaModel == nullptr) { @@ -952,10 +1088,10 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // reloaded. if (IsContextParamsUpdated || GraphRef.LlamaContext == nullptr) { LOG_INFO(GraphRef.EnableLog, - "Reloaded llama context due to parameters change."sv) + "Reload llama context due to parameters change."sv) GraphRef.LlamaContext.reset(); common_params Params; - setupParams(GraphRef, Params); + setupCommonParams(GraphRef, Params); GraphRef.LlamaContext = llama_context_ptr(llama_new_context_with_model( GraphRef.LlamaModel.get(), common_context_params_to_llama(Params))); if (GraphRef.LlamaContext == nullptr) { @@ -964,12 +1100,38 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } } + // Some changes of sampling parameters will require the sampler to be + // reallocated. + if (IsSamplerParamsUpdated || CxtRef.LlamaSampler == nullptr) { + LOG_INFO(GraphRef.EnableLog, + "Reallocate llama sampler due to parameters change."sv) + if (CxtRef.LlamaSampler) { + common_sampler_free(CxtRef.LlamaSampler); + } + common_params_sampling CommonSampling; + setupSamplerParams(GraphRef, CommonSampling); + CxtRef.LlamaSampler = + common_sampler_init(GraphRef.LlamaModel.get(), CommonSampling); + if (GraphRef.LlamaContext == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "unable to init sampler."sv) + } + } + + // Check that is batch size changed. + if (CxtRef.CurrentBatchSize != GraphRef.BatchSize) { + llama_batch_free(CxtRef.LlamaBatch); + CxtRef.LlamaBatch = allocBatch(GraphRef.BatchSize); + CxtRef.CurrentBatchSize = GraphRef.BatchSize; + } + Env.NNGraph[CxtRef.GraphId].setReady(); LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: found Metadata, processing...Done"sv) return ErrNo::Success; } + // Check the graph is valid after reloading during previous set_input. if (!Env.NNGraph[CxtRef.GraphId].isReady()) { RET_ERROR(ErrNo::InvalidArgument, "Graph is invalid. Please reload again by passing metadata "sv @@ -987,29 +1149,16 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); CxtRef.LlamaInputs.clear(); - if (GraphRef.MMProjModelPath != ""sv) { - // Handle llava format prompt. - LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: handle llava format prompt"sv) - // Check if the prompt contains a base64 image. - bool ContainsBase64Image = containsBase64Image(GraphRef, Prompt); - if (GraphRef.ImagePath == ""sv && ContainsBase64Image == false) { - RET_ERROR( - ErrNo::InvalidArgument, - "when using llava model, you need to specify the image path or "sv - "have the base64 encoded image in the prompt."sv) - } + if (CxtRef.LlavaImageEmbd != nullptr) { + llava_image_embed_free(CxtRef.LlavaImageEmbd); + CxtRef.LlavaImageEmbd = nullptr; + } + auto Base64ImagePos = findBase64ImagePayload(Prompt); - // Show some warnings. - if (GraphRef.CtxSize < 4096) { - LOG_INFO( - GraphRef.EnableLog, - "Context size is {}, we recommend context size >= 2048 when using "sv - "llava-v1.5 and context size >= 4096 when using llava-v1.6 for "sv - "better results."sv, - GraphRef.CtxSize) - } + if (Base64ImagePos.has_value() || CxtRef.Conf.ImagePath != ""sv) { + // Prompt with image input. Check is llava or mllama case. - // Load the clip model if not loaded. + // First check the projection model is loaded. if (GraphRef.ClipContext == nullptr) { LOG_INFO( true, @@ -1025,58 +1174,77 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (clip_is_qwen2vl(GraphRef.ClipContext)) { GraphRef.VisionModelType = VisionModel::Qwen2VL; LOG_INFO(true, "Qwen2vl model loaded."sv) + } else { + GraphRef.VisionModelType = VisionModel::Llava; } } - // Get image embed. - if (ContainsBase64Image) { - // Load the base64 image from the prompt. - CxtRef.LlavaImageEmbd = llavaLoadBase64ImageFromPrompt( - GraphRef, GraphRef.ClipContext, Prompt); - // Replace the base64 image in the prompt with a placeholder. - auto Res = replaceBase64ImagePlaceholderInPrompt( - Prompt, LlavaPromptImagePlaceholder); - if (Res != ErrNo::Success) { - clip_free(GraphRef.ClipContext); - RET_ERROR(Res, "unable to replace the base64 image in the prompt."sv) + // Prompt with image. + if (GraphRef.ClipContext != nullptr) { + // Llava case. + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: handle llava format prompt."sv) + + // Show some warnings. + if (GraphRef.CtxSize < 4096) { + LOG_INFO( + GraphRef.EnableLog, + "Context size is {}, we recommend context size >= 2048 when "sv + "using llava-v1.5 and context size >= 4096 when using llava-v1.6 "sv + "for better results."sv, + GraphRef.CtxSize) } - } else { - // Load the image from the file. - CxtRef.LlavaImageEmbd = llava_image_embed_make_with_filename( - GraphRef.ClipContext, static_cast(GraphRef.Threads), - GraphRef.ImagePath.c_str()); - } - if (CxtRef.LlavaImageEmbd == nullptr) { - RET_ERROR(ErrNo::InvalidArgument, "unable to load the image."sv) - } - // We split prompt by as placeholder and save the position. - auto PlaceholderPosition = Prompt.find(LlavaPromptImagePlaceholder); - if (PlaceholderPosition == std::string::npos) { - RET_ERROR(ErrNo::InvalidArgument, - "unable to find the placeholder in the llava prompt."sv) + // Get image embed. + // Follow this link for the supported image formats: + // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h + if (Base64ImagePos.has_value()) { + // Extract the payload and image type from the prompt. + auto Payload = extractBase64ImagePayload(Prompt, *Base64ImagePos, + LlavaPromptImagePlaceholder); + if (Payload.has_value()) { + CxtRef.LlavaImageEmbd = llava_image_embed_make_with_bytes( + GraphRef.ClipContext, static_cast(GraphRef.Threads), + Payload->first.data(), static_cast(Payload->first.size())); + } + } else { + // Load the image from the file. + CxtRef.LlavaImageEmbd = llava_image_embed_make_with_filename( + GraphRef.ClipContext, static_cast(GraphRef.Threads), + CxtRef.Conf.ImagePath.c_str()); + } + if (CxtRef.LlavaImageEmbd == nullptr) { + RET_ERROR(ErrNo::InvalidArgument, "llava unable to load the image."sv) + } + + // We split prompt by as placeholder and save the position. + auto PlaceholderPosition = Prompt.find(LlavaPromptImagePlaceholder); + if (PlaceholderPosition == std::string::npos) { + RET_ERROR(ErrNo::InvalidArgument, + "unable to find the placeholder in the llava prompt."sv) + } + std::string PromptBeforeImage = Prompt.substr(0, PlaceholderPosition); + std::string PromptAfterImage = Prompt.substr( + PlaceholderPosition + LlavaPromptImagePlaceholder.length()); + std::vector EmbdInputBeforeImage = + common_tokenize(GraphRef.LlamaContext.get(), PromptBeforeImage, + AddSpecial, ParseSpecial); + // Do not add special token (such as , , ... tokens.) to the + // tokens after the image. + std::vector EmbdInputAfterImage = common_tokenize( + GraphRef.LlamaContext.get(), PromptAfterImage, false, ParseSpecial); + CxtRef.ImagePosition = EmbdInputBeforeImage.size(); + CxtRef.LlamaInputs.reserve(EmbdInputBeforeImage.size() + + EmbdInputAfterImage.size()); + CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.end(), + EmbdInputBeforeImage.begin(), + EmbdInputBeforeImage.end()); + CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.end(), + EmbdInputAfterImage.begin(), + EmbdInputAfterImage.end()); + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: handle llava format prompt...Done"sv) } - std::string PromptBeforeImage = Prompt.substr(0, PlaceholderPosition); - std::string PromptAfterImage = Prompt.substr( - PlaceholderPosition + LlavaPromptImagePlaceholder.length()); - std::vector EmbdInputBeforeImage = - common_tokenize(GraphRef.LlamaContext.get(), PromptBeforeImage, - AddSpecial, ParseSpecial); - // Do not add special token (such as , , ... tokens.) to the - // tokens after the image. - std::vector EmbdInputAfterImage = common_tokenize( - GraphRef.LlamaContext.get(), PromptAfterImage, false, ParseSpecial); - CxtRef.ImagePosition = EmbdInputBeforeImage.size(); - CxtRef.LlamaInputs.reserve(EmbdInputBeforeImage.size() + - EmbdInputAfterImage.size()); - CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.end(), - EmbdInputBeforeImage.begin(), - EmbdInputBeforeImage.end()); - CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.end(), - EmbdInputAfterImage.begin(), - EmbdInputAfterImage.end()); - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: handle llava format prompt...Done"sv) } else { // Text only prompt. LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt"sv) @@ -1086,7 +1254,9 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, "setInput: tokenize text prompt...Done"sv) } CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); - GraphRef.ComputeSingleStarted = false; + + // Maybe currently in the compute_single mode. Reset the computing. + CxtRef.ComputeSingleStarted = false; LOG_DEBUG(GraphRef.EnableDebugLog, "setInput...Done"sv) return ErrNo::Success; @@ -1098,10 +1268,10 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput with Index {}"sv, Index) - // Index 1 is for the metadata of the outputs. + + // Use index 1 for the metadata of the outputs. if (Index == 1) { - std::string Metadata; - buildOutputMetadata(CxtRef, Metadata); + std::string Metadata = buildOutputMetadata(CxtRef); std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); BytesWritten = static_cast(Metadata.length()); LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput with Index {}...Done"sv, @@ -1122,154 +1292,38 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { LOG_DEBUG(GraphRef.EnableDebugLog, "compute") if (GraphRef.Embedding) { - return getEmbedding(Env, ContextId); + return getEmbedding(GraphRef, CxtRef); } - if (CxtRef.LlamaInputs.size() == 0) { - RET_ERROR(ErrNo::InvalidArgument, "llama input is not set!"sv) - } - - // Clear the outputs. - LOG_DEBUG(GraphRef.EnableDebugLog, - "compute: clear the previous output and tokens") - CxtRef.LlamaOutputs.clear(); - CxtRef.LlamaOutputTokens.clear(); - LOG_DEBUG(GraphRef.EnableDebugLog, - "compute: clear the previous output and tokens...Done") - - // Clear the llama context. - llama_kv_cache_clear(GraphRef.LlamaContext.get()); - - // Setup the parameters and sampler. - common_params Params; - setupParams(GraphRef, Params); - struct common_sampler *Sampler = - common_sampler_init(GraphRef.LlamaModel.get(), Params.sampling); - - // Prepare variables; - int32_t NPast = 0; - int32_t NPos = 0; - int64_t NRemain = GraphRef.NPredict; - // Get the context size. - const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); - // Minus 4 for the special tokens. (Such as , , ... tokens.) - const uint64_t MaxTokensListSize = NCtx - 4; - // Return value. - auto ReturnCode = ErrNo::Success; - - // Check if the input is too long. - if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { - RET_ERROR( - ErrNo::PromptTooLong, - "the prompt is too long. Your input has {} tokens. Please reduce it "sv - "to {} tokens."sv, - CxtRef.LlamaInputs.size(), MaxTokensListSize) - } - - // Evaluate input tokens. - if (CxtRef.LlavaImageEmbd != nullptr) { - // Llava format prompt with image data. - std::vector EmbdInputBeforeImage(CxtRef.LlamaInputs.begin(), - CxtRef.LlamaInputs.begin() + - CxtRef.ImagePosition); - std::vector EmbdInputAfterImage(CxtRef.LlamaInputs.begin() + - CxtRef.ImagePosition, - CxtRef.LlamaInputs.end()); - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), - std::move(EmbdInputBeforeImage), NPast, NPos); - if (ReturnCode != ErrNo::Success) { - RET_ERROR(ReturnCode, "failed to evaluate input tokens before image."sv) - } - - bool EvalImageStatus = false; - switch (GraphRef.VisionModelType) { - case VisionModel::Llava: - EvalImageStatus = llava_eval_image_embed( - GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, - static_cast(GraphRef.BatchSize), &NPast); - break; - case VisionModel::Qwen2VL: - auto ImageSize = clip_get_load_image_size(GraphRef.ClipContext); - EvalImageStatus = evaluateQwen2vlImageEmbed( - GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, - static_cast(GraphRef.BatchSize), &NPast, &NPos, ImageSize); - break; - } + // Reset the sampler for a new computation. + common_sampler_reset(CxtRef.LlamaSampler); - if (!EvalImageStatus) { - RET_ERROR(ErrNo::RuntimeError, "failed to evaluate embed image tokens."sv) - } - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), - std::move(EmbdInputAfterImage), NPast, NPos); - if (ReturnCode != ErrNo::Success) { - RET_ERROR(ReturnCode, "failed to evaluate input tokens after image."sv) - } - } else { - // Text only prompt. - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), - std::move(CxtRef.LlamaInputs), NPast, NPos); - if (ReturnCode != ErrNo::Success) { - RET_ERROR(ReturnCode, "failed to evaluate input tokens."sv) - } + // Evaluate the input tokens. + auto ReturnCode = evaluateInput(GraphRef, CxtRef, "compute"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; } // Main prediction loop. LOG_DEBUG(GraphRef.EnableDebugLog, "compute: enter main prediction loop"sv) - while (NRemain > 0) { - // Use idx = -1 to sample the next token. - const llama_token Id = common_sampler_sample( - Sampler, GraphRef.LlamaContext.get(), /* idx */ -1); - common_sampler_accept(Sampler, Id, /* accept_grammar */ true); - --NRemain; - - // Save the output token. - CxtRef.LlamaOutputTokens.emplace_back(Id); - CxtRef.LlamaOutputs += - common_token_to_piece(GraphRef.LlamaContext.get(), Id); - // When setting StreamStdout, we print the output to stdout. - if (GraphRef.StreamStdout) { - fmt::print("{}"sv, - common_token_to_piece(GraphRef.LlamaContext.get(), Id)); - std::fflush(stdout); - } - // Break if reverse prompt is found. - if (!GraphRef.ReversePrompt.empty() && - CxtRef.LlamaOutputs.find(GraphRef.ReversePrompt) != std::string::npos) { - LOG_INFO(GraphRef.EnableLog, "reverse prompt found."sv) - break; - } - // Deal with end of text token. - if (llama_token_is_eog(GraphRef.LlamaModel.get(), - common_sampler_last(Sampler))) { - LOG_INFO(GraphRef.EnableLog, "EOS token found."sv) - break; - } - // Evaluate the output token. - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), {Id}, - NPast, NPos); + int64_t NRemain = CxtRef.Conf.NPredict; + while (NRemain-- > 0) { + ReturnCode = sampleOutput(GraphRef, CxtRef); if (ReturnCode != ErrNo::Success) { break; } } + if (ReturnCode == ErrNo::EndOfSequence) { + ReturnCode = ErrNo::Success; + } LOG_DEBUG(GraphRef.EnableDebugLog, "compute: enter main prediction loop...Done"sv) - // End of main predict loop. + // End of main prediction loop. if (GraphRef.EnableLog) { - common_perf_print(GraphRef.LlamaContext.get(), Sampler); + common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler); } - // We free the contexts here to keep the ggml plugin stateless. - // Users could fully control the contexts by themselves via their prompt. - LOG_DEBUG(GraphRef.EnableDebugLog, - "compute: delete llama sampler to make it stateless"sv) - common_sampler_free(Sampler); - if (CxtRef.LlavaImageEmbd != nullptr) { - llava_image_embed_free(CxtRef.LlavaImageEmbd); - CxtRef.LlavaImageEmbd = nullptr; - } - LOG_DEBUG(GraphRef.EnableDebugLog, - "compute: delete llama sampler to make it stateless...Done"sv) LOG_DEBUG(GraphRef.EnableDebugLog, "compute...Done"sv) return ReturnCode; } @@ -1281,16 +1335,16 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle with Index {}"sv, Index) - // Index 1 is for the metadata of the outputs. + // Use index 1 for the metadata of the outputs. if (Index == 1) { - std::string Metadata; - buildOutputMetadata(CxtRef, Metadata); + std::string Metadata = buildOutputMetadata(CxtRef); std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); BytesWritten = static_cast(Metadata.length()); LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle with Index {}...Done"sv, Index) return ErrNo::Success; } + std::string LastToken = common_token_to_piece( GraphRef.LlamaContext.get(), CxtRef.LlamaOutputTokens.back()); std::copy_n(LastToken.data(), LastToken.length(), OutBuffer.data()); @@ -1307,110 +1361,26 @@ Expect computeSingle(WasiNNEnvironment &Env, LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle"sv) // New compute single token context. - if (!GraphRef.ComputeSingleStarted) { - GraphRef.ComputeSingleStarted = true; - // Check if the input is set before setting up the context. - if (CxtRef.LlamaInputs.size() == 0) { - RET_ERROR(ErrNo::InvalidArgument, "llama input is not set!"sv) - } + auto ReturnCode = ErrNo::Success; + if (!CxtRef.ComputeSingleStarted) { + CxtRef.ComputeSingleStarted = true; - // Clear the outputs. - LOG_DEBUG(GraphRef.EnableDebugLog, - "computeSingle: clear the previous output and tokens"sv) - CxtRef.LlamaOutputs.clear(); - CxtRef.LlamaOutputTokens.clear(); - LOG_DEBUG(GraphRef.EnableDebugLog, - "computeSingle: clear the previous output and tokens...Done"sv) - - // Clear the llama context. - llama_kv_cache_clear(GraphRef.LlamaContext.get()); - - // Setup the parameters and sampler. - common_params Params; - setupParams(GraphRef, Params); - CxtRef.LlamaSampler = - common_sampler_init(GraphRef.LlamaModel.get(), Params.sampling); - CxtRef.LlamaNPast = 0; - CxtRef.LlamaNPos = 0; - - // Get the context size. - const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); - // Minus 4 for the special tokens. (Such as , , ... tokens.) - const uint64_t MaxTokensListSize = NCtx - 4; - // Return value. - auto ReturnCode = ErrNo::Success; - - // Check if the input is too long. - if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { - RET_ERROR( - ErrNo::PromptTooLong, - "the prompt is too long. Your input has {} tokens. Please reduce "sv - "it to {} tokens."sv, - CxtRef.LlamaInputs.size(), MaxTokensListSize) - } + // Reset the sampler for a new computation. + common_sampler_reset(CxtRef.LlamaSampler); - // Evaluate input tokens. - if (CxtRef.LlavaImageEmbd != nullptr) { - // Llava format prompt with image data. - std::vector EmbdInputBeforeImage(CxtRef.LlamaInputs.begin(), - CxtRef.LlamaInputs.begin() + - CxtRef.ImagePosition); - std::vector EmbdInputAfterImage(CxtRef.LlamaInputs.begin() + - CxtRef.ImagePosition, - CxtRef.LlamaInputs.end()); - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), - std::move(EmbdInputBeforeImage), - CxtRef.LlamaNPast, CxtRef.LlamaNPos); - if (ReturnCode != ErrNo::Success) { - RET_ERROR(ReturnCode, "failed to evaluate input tokens before image."sv) - } - bool EvalImageStatus = llava_eval_image_embed( - GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, - static_cast(GraphRef.BatchSize), &CxtRef.LlamaNPast); - if (!EvalImageStatus) { - RET_ERROR(ErrNo::RuntimeError, - "failed to evaluate embed image tokens."sv) - } - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), - std::move(EmbdInputAfterImage), - CxtRef.LlamaNPast, CxtRef.LlamaNPos); - if (ReturnCode != ErrNo::Success) { - RET_ERROR(ReturnCode, "failed to evaluate input tokens after image."sv) - } - } else { - // Text only prompt. - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), - std::move(CxtRef.LlamaInputs), - CxtRef.LlamaNPast, CxtRef.LlamaNPos); - if (ReturnCode != ErrNo::Success) { - RET_ERROR(ReturnCode, "failed to evaluate input tokens."sv) - } + // Evaluate the input tokens. + ReturnCode = evaluateInput(GraphRef, CxtRef, "computeSingle"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; } } // Main prediction process. LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle: enter main prediction process"sv) - auto ReturnCode = ErrNo::Success; - // Use idx = -1 to sample the next token. - const llama_token Id = common_sampler_sample( - CxtRef.LlamaSampler, GraphRef.LlamaContext.get(), /* idx */ -1); - common_sampler_accept(CxtRef.LlamaSampler, Id, /* accept_grammar */ true); - - // Save the output token. - // In single token mode, we do not handle StreamStdout and ReversePrompt. - CxtRef.LlamaOutputTokens.emplace_back(Id); - CxtRef.LlamaOutputs += common_token_to_piece(GraphRef.LlamaContext.get(), Id); - // Deal with end of text token. - if (llama_token_is_eog(GraphRef.LlamaModel.get(), - common_sampler_last(CxtRef.LlamaSampler))) { - ReturnCode = ErrNo::EndOfSequence; - LOG_INFO(GraphRef.EnableLog, "EOS token found."sv) - } - // Evaluate the output token if not EOS. - if (ReturnCode != ErrNo::EndOfSequence) { - ReturnCode = evaluateTokens(GraphRef, GraphRef.LlamaContext.get(), {Id}, - CxtRef.LlamaNPast, CxtRef.LlamaNPos); + ReturnCode = sampleOutput(GraphRef, CxtRef, true); + if (ReturnCode != ErrNo::Success) { + CxtRef.ComputeSingleStarted = false; } LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle: enter main prediction process...Done"sv) @@ -1438,25 +1408,10 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle: clear the previous output and tokens...Done"sv) - // Clear the llama context. - LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle: clear the llama context"sv) - llama_kv_cache_clear(GraphRef.LlamaContext.get()); + // Reset the llama sampler. common_sampler_reset(CxtRef.LlamaSampler); - common_sampler_free(CxtRef.LlamaSampler); - CxtRef.LlamaSampler = nullptr; - if (GraphRef.ClipContext != nullptr) { - clip_free(GraphRef.ClipContext); - GraphRef.ClipContext = nullptr; - } - if (CxtRef.LlavaImageEmbd != nullptr) { - llava_image_embed_free(CxtRef.LlavaImageEmbd); - CxtRef.LlavaImageEmbd = nullptr; - } - LOG_DEBUG(GraphRef.EnableDebugLog, - "finiSingle: clear the llama context...Done"sv) - - // Reset the context variables. - CxtRef.LlamaNPast = 0; + CxtRef.ComputeSingleStarted = false; + CxtRef.NPos = 0; LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle...Done"sv) return ErrNo::Success; @@ -1467,6 +1422,7 @@ Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { const bool IsDebugLog = GraphRef.EnableDebugLog; LOG_DEBUG(IsDebugLog, "unload"sv) + // TODO: Move the resource deallocation into the destructor. if (GraphRef.LlamaModel != nullptr) { LOG_DEBUG(IsDebugLog, "unload: free llama model"sv) GraphRef.LlamaModel.reset(); @@ -1477,8 +1433,15 @@ Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { GraphRef.LlamaContext.reset(); LOG_DEBUG(IsDebugLog, "unload: free llama context...Done"sv) } + if (GraphRef.ClipContext != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free clip context"sv) + clip_free(GraphRef.ClipContext); + GraphRef.ClipContext = nullptr; + LOG_DEBUG(IsDebugLog, "unload: free clip context...Done"sv) + } Env.deleteGraph(GraphId); Env.mdRemoveById(GraphId); + LOG_DEBUG(IsDebugLog, "unload...Done"sv) return ErrNo::Success; } @@ -1488,7 +1451,29 @@ Expect finalizeExecCtx(WasiNNEnvironment &Env, auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); LOG_DEBUG(GraphRef.EnableDebugLog, "finalize_execution_context"sv) + + // TODO: Move the resource deallocation into the destructor. + if (CxtRef.LlavaImageEmbd != nullptr) { + LOG_DEBUG(GraphRef.EnableDebugLog, + "finalize_execution_context: free llava image embed"sv) + llava_image_embed_free(CxtRef.LlavaImageEmbd); + CxtRef.LlavaImageEmbd = nullptr; + LOG_DEBUG(GraphRef.EnableDebugLog, + "finalize_execution_context: free llava image embed...Done"sv) + } + if (CxtRef.LlamaSampler != nullptr) { + LOG_DEBUG(GraphRef.EnableDebugLog, + "finalize_execution_context: free compute_single sampler"sv) + common_sampler_free(CxtRef.LlamaSampler); + CxtRef.LlamaSampler = nullptr; + LOG_DEBUG( + GraphRef.EnableDebugLog, + "finalize_execution_context: free compute_single sampler...Done"sv) + } + llama_batch_free(CxtRef.LlamaBatch); + llama_batch_free(CxtRef.OutputBatch); Env.deleteContext(ContextId); + LOG_DEBUG(GraphRef.EnableDebugLog, "finalize_execution_context...Done"sv) return ErrNo::Success; } diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index abdbba09..0118efa8 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -38,30 +38,38 @@ enum class VisionModel : uint8_t { Qwen2VL = 1, }; -struct Graph { - llama_model_ptr LlamaModel = nullptr; - std::string ModelFilePath; - llama_context_ptr LlamaContext = nullptr; - struct clip_ctx *ClipContext = nullptr; - // Plugin parameters: - bool EnableLog = false; - bool EnableDebugLog = false; +struct LocalConfig { + // Configurations which can be changed in every contexts. + // The graph handles a default config and parsed from metadata when loading. + // The context inherits a copy from graph when creating, and can be modified + // when parsing metadata in set_input. bool StreamStdout = false; - bool Embedding = false; EmbdNormalizeType EmbdNormalize = EmbdNormalizeType::Euclidean; - bool ComputeSingleStarted = false; int64_t NPredict; std::string ReversePrompt; - std::string MMProjModelPath; std::string ImagePath; - VisionModel VisionModelType = VisionModel::Llava; +}; + +struct Graph { + // Plugin parameters: + bool EnableLog = false; + bool EnableDebugLog = false; // Model parameters: int64_t MainGPU = 0; // Use GPU 0 by default int64_t NGPULayers = 0; std::vector TensorSplit; + bool Embedding = false; bool UseMMap = true; bool WarmUp = false; enum llama_split_mode SplitMode = LLAMA_SPLIT_MODE_LAYER; + // Model context: + llama_model_ptr LlamaModel = nullptr; + llama_context_ptr LlamaContext = nullptr; + std::string ModelFilePath; + // Clip context (for llava): + std::string MMProjModelPath; + struct clip_ctx *ClipContext = nullptr; + VisionModel VisionModelType = VisionModel::Llava; // Context parameters: int64_t CtxSize; int64_t BatchSize; @@ -75,23 +83,34 @@ struct Graph { double FrequencyPenalty = 0.00; std::string Grammar; uint64_t Seed = LLAMA_DEFAULT_SEED; + // Configs. + LocalConfig Conf; }; struct Context { public: - Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} + Context(uint32_t GId, Graph &G) noexcept : GraphId(GId), Conf(G.Conf) {} uint32_t GraphId; + // Llama inputs: std::vector LlamaInputs; uint64_t LlamaNInputs = 0; + // Llama outputs: std::string LlamaOutputs; std::vector LlamaOutputTokens; - // Preserve for computing single token - common_sampler *LlamaSampler = nullptr; - int32_t LlamaNPast = 0; - int32_t LlamaNPos = 0; // Preserve for llava struct llava_image_embed *LlavaImageEmbd = nullptr; + // Data for computing: + bool ComputeSingleStarted = false; + struct common_sampler *LlamaSampler = nullptr; + // Handle the batch in the context to prevent from reallocation in every + // computing. + struct llama_batch LlamaBatch; + struct llama_batch OutputBatch; + int64_t CurrentBatchSize = 0; size_t ImagePosition = 0; + int32_t NPos = 0; + // Configs: + LocalConfig Conf; }; #else struct Graph {}; From 3c68df56eb909f4935a971d57526116ca1edff8e Mon Sep 17 00:00:00 2001 From: YiYing He Date: Fri, 10 Jan 2025 05:12:21 +0800 Subject: [PATCH 520/623] [WASI-NN] ggml: fix Qwen2vl models after refactoring. Signed-off-by: YiYing He --- plugins/wasi_nn/wasinn_ggml.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index c4141cd1..7b907810 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -653,13 +653,31 @@ ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, if (NEval > static_cast(GraphRef.BatchSize)) { NEval = static_cast(GraphRef.BatchSize); } + + // LlamaPos for Qwen2VL. + static std::vector LlamaPos; + if (GraphRef.VisionModelType == VisionModel::Qwen2VL) { + LlamaPos.resize(NEval * 4); + std::fill(LlamaPos.begin(), LlamaPos.end(), 0); + for (int J = 0; J < NEval * 3; J++) { + LlamaPos[J] = NPos + (J % NEval); + } + } + // Fill the batch with pos information. fillBatch(Span(Tokens.begin() + I, NEval), GraphRef, Batch, NPos, IsLogits && I + NEval >= static_cast(Tokens.size())); + // Set the LlamaPos for Qwen2VL. + llama_pos *OriginBatchPos = Batch.pos; + if (GraphRef.VisionModelType == VisionModel::Qwen2VL) { + Batch.pos = LlamaPos.data(); + } + // Decode the batch. auto Status = llama_decode(GraphRef.LlamaContext.get(), Batch); + Batch.pos = OriginBatchPos; if (Status == 1) { RET_ERROR(ErrNo::RuntimeError, "failed to llama_decode: try reducing the size of the batch "sv From a49f20a70e8663fca6d4bfc03c50da9e728acbdb Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 13 Jan 2025 13:22:54 +0800 Subject: [PATCH 521/623] [WASI-NN] mlx: fix typos (#3961) [WASI-NN] mlx: fix typo Signed-off-by: hydai --- plugins/wasi_nn/wasinn_mlx.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/wasinn_mlx.cpp b/plugins/wasi_nn/wasinn_mlx.cpp index 3f600228..bd50e47e 100644 --- a/plugins/wasi_nn/wasinn_mlx.cpp +++ b/plugins/wasi_nn/wasinn_mlx.cpp @@ -30,13 +30,13 @@ std::string loadBytesFromFile(const std::string &Path) { return Data; } -enum AnserSataus { +enum AnswerSataus { STOP, WAIT, GO, }; -AnserSataus answerSataus(std::string Text, std::string End) { +AnswerSataus answerSataus(std::string Text, std::string End) { if (endsWith(Text, End)) { return STOP; } @@ -311,7 +311,7 @@ Expect compute(WasiNNEnvironment &Env, // TODO: break when the token is the eos_token_id TokenList.insert(TokenList.end(), Tokens.begin(), Tokens.end()); Answer = GraphRef.Tok->Decode(TokenList); - const AnserSataus Status = answerSataus(Answer, GraphRef.Prmopt.TextEnd); + const AnswerSataus Status = answerSataus(Answer, GraphRef.Prmopt.TextEnd); if (Status == STOP) { break; } From cd9cdf6661295584b7b5ae27916e3210c7a93b29 Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 10 Jan 2025 21:37:07 +0800 Subject: [PATCH 522/623] [WASI-NN] ggml: bump to llama.cpp b4458 Signed-off-by: hydai --- plugins/wasi_nn/wasinn_ggml.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 7b907810..529abc35 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -1092,7 +1092,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, common_sampler_free(CxtRef.LlamaSampler); CxtRef.LlamaSampler = nullptr; } - GraphRef.LlamaModel = llama_model_ptr(llama_load_model_from_file( + GraphRef.LlamaModel = llama_model_ptr(llama_model_load_from_file( GraphRef.ModelFilePath.c_str(), ModelParams)); if (GraphRef.LlamaModel == nullptr) { Env.NNGraph[CxtRef.GraphId].setInvalid(); From 2c8a735d0d99085523c78ed90a9ca247a127127c Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 13 Jan 2025 17:01:31 +0800 Subject: [PATCH 523/623] [WASI-NN] ggml: apply the internal function renaming Signed-off-by: hydai --- plugins/wasi_nn/wasinn_ggml.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 529abc35..30e22cf0 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -574,7 +574,7 @@ bool evaluateQwen2vlImageEmbed(llama_context *LlamaCxt, const struct llava_image_embed *ImageEmbed, int64_t NBatch, int32_t &NPos, struct clip_image_size *ImageSize) { - int NEmbd = llama_n_embd(llama_get_model(LlamaCxt)); + int NEmbd = llama_model_n_embd(llama_get_model(LlamaCxt)); const int PatchSize = 14 * 2; const int Ph = ImageSize->height / PatchSize + (ImageSize->height % PatchSize > 0); @@ -817,8 +817,8 @@ ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, } } // Deal with end of text token. - if (llama_token_is_eog(GraphRef.LlamaModel.get(), - common_sampler_last(CxtRef.LlamaSampler))) { + const llama_vocab *Vocab = llama_model_get_vocab(GraphRef.LlamaModel.get()); + if (llama_vocab_is_eog(Vocab, common_sampler_last(CxtRef.LlamaSampler))) { LOG_INFO(GraphRef.EnableLog, "EOS token found."sv) return ErrNo::EndOfSequence; } @@ -831,9 +831,10 @@ ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding"sv) + const llama_vocab *Vocab = llama_model_get_vocab(GraphRef.LlamaModel.get()); // Add SEP if not present. if (CxtRef.LlamaInputs.size() > 0 && - CxtRef.LlamaInputs.back() != llama_token_sep(GraphRef.LlamaModel.get())) { + CxtRef.LlamaInputs.back() != llama_vocab_sep(Vocab)) { LOG_WARN( "last token in the prompt is not SEP, "sv "'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF "sv @@ -856,7 +857,7 @@ Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { } // Main prediction loop. - const int32_t NEmbd = llama_n_embd(GraphRef.LlamaModel.get()); + const int32_t NEmbd = llama_model_n_embd(GraphRef.LlamaModel.get()); std::vector Embeddings(NEmbd); for (int I = 0; I < CxtRef.LlamaBatch.n_tokens; I++) { @@ -1110,7 +1111,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, GraphRef.LlamaContext.reset(); common_params Params; setupCommonParams(GraphRef, Params); - GraphRef.LlamaContext = llama_context_ptr(llama_new_context_with_model( + GraphRef.LlamaContext = llama_context_ptr(llama_init_from_model( GraphRef.LlamaModel.get(), common_context_params_to_llama(Params))); if (GraphRef.LlamaContext == nullptr) { Env.NNGraph[CxtRef.GraphId].setInvalid(); From 29518dee04887367d92580622a0fdd30465a45b1 Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 13 Jan 2025 21:57:55 +0800 Subject: [PATCH 524/623] [WASI-NN] ggml: Use the cached image embed instead to reduce costs (#3964) In this PR, we changed the logic for handling the image embd for the llava and qwen2vl models. 1. There is a new option call `always-regenerate-image-embd`, whose default value is `false`. If this option is set to `true`, it will recompute the image embedding whatever the previous image embedding is consumed or not. Otherwise, the following rules will apply. 2. When the base64 payload or the image path is provided, it will check if there is an existing computed image embd. If so, the image embd will never recompute it to reduce time costs unless the cached image embd is consumed by the `compute` function. 3. The reason we have to apply this workaround is that llama.cpp has disabled all GPU backends for the Clip model. It is quite slow because the entire workload is running solely on the CPU. Signed-off-by: hydai Signed-off-by: hydai --- plugins/wasi_nn/wasinn_ggml.cpp | 210 +++++++++++++++++++++----------- plugins/wasi_nn/wasinn_ggml.h | 1 + 2 files changed, 138 insertions(+), 73 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 30e22cf0..c5523f14 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -44,6 +44,9 @@ namespace { // Macro for logging warning message. #define LOG_WARN(...) spdlog::warn("[WASI-NN] GGML backend: "sv __VA_ARGS__); +// Macro for logging error message. +#define LOG_ERROR(...) spdlog::error("[WASI-NN] GGML backend: "sv __VA_ARGS__); + // Macro for logging error message and return. #define RET_ERROR(Error, ...) \ spdlog::error("[WASI-NN] GGML backend: "sv __VA_ARGS__); \ @@ -134,6 +137,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, // batch-size: int64_t // ubatch-size: int64_t // threads: int64_t + // [local-config] always-regenerate-image-embd: bool // Sampling parameters (used by the llama sampling context): // temp: double // top-p: double @@ -245,20 +249,18 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, std::string_view SplitMode; auto Err = Doc["split-mode"].get().get(SplitMode); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the split-mode option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the split-mode option."sv) } - if (SplitMode == "none") { + if (SplitMode == "none"sv) { GraphRef.SplitMode = LLAMA_SPLIT_MODE_NONE; - } else if (SplitMode == "layer") { + } else if (SplitMode == "layer"sv) { GraphRef.SplitMode = LLAMA_SPLIT_MODE_LAYER; - } else if (SplitMode == "row") { + } else if (SplitMode == "row"sv) { GraphRef.SplitMode = LLAMA_SPLIT_MODE_ROW; } else { - spdlog::error("[WASI-NN] GGML backend: Invalid split-mode option: {}"sv, - SplitMode); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, + "Unknown split-mode: {}. Valid: none, layer, row."sv, SplitMode) } } if (Doc.at_key("mmproj").error() == simdjson::SUCCESS) { @@ -361,9 +363,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, if (Doc.at_key("seed").error() == simdjson::SUCCESS) { auto Err = Doc["seed"].get().get(GraphRef.Seed); if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the seed option."sv); - return ErrNo::InvalidArgument; + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the seed option."sv) } } @@ -400,6 +400,14 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, } ConfRef.ImagePath = ImagePath; } + if (Doc.at_key("always-regenerate-image-embd").error() == simdjson::SUCCESS) { + auto Err = Doc["always-regenerate-image-embd"].get().get( + ConfRef.AlwaysRegenerateImageEmbd); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the always-regenerate-image-embd option."sv) + } + } // Check if the model parameters are updated. if (IsModelUpdated && PrevNGPULayers != GraphRef.NGPULayers) { @@ -447,13 +455,13 @@ findBase64ImagePayload(std::string_view Prompt, // Find `;base64,` (skip the image type part) auto PayloadPos = Prompt.find(Base64ImageBytesPrefix, BeginTagPos); if (PayloadPos == std::string::npos) { - LOG_DEBUG(IsDebugLog, "Cannot locate the payload."sv) + LOG_DEBUG(IsDebugLog, "base64: Cannot locate the payload."sv) return std::nullopt; } // Find `">` auto EndTagPos = Prompt.find(Base64ImageTagSuffix, PayloadPos); if (EndTagPos == std::string::npos) { - LOG_DEBUG(IsDebugLog, "Base64 image tag unclosed."sv) + LOG_DEBUG(IsDebugLog, "base64: image tag unclosed."sv) return std::nullopt; } return std::make_tuple(BeginTagPos, PayloadPos, EndTagPos); @@ -481,7 +489,7 @@ extractBase64ImagePayload(std::string &Prompt, base64::decode(Payload.begin(), Payload.end(), ImageBytes.begin()); } catch (const base64_error &E) { RET_ERROR(std::make_pair(std::vector(), ""), - "Error when base64::decode: {}"sv, E.what()) + "base64: Error when calling base64::decode: {}"sv, E.what()) } // Replace the base64 image with the placeholder. @@ -639,10 +647,11 @@ ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, // End the inference if the context is full. uint32_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); if (NPos + static_cast(Tokens.size()) > NCtx) { - LOG_INFO(GraphRef.EnableLog, - "the context if full ({} / {} tokens). Please increase your "sv - "context size."sv, - NPos + static_cast(Tokens.size()), NCtx) + LOG_INFO( + GraphRef.EnableLog, + "evaluateTokens: the context if full ({} / {} tokens). Please increase your "sv + "context size."sv, + NPos + static_cast(Tokens.size()), NCtx) return ErrNo::ContextFull; } @@ -679,13 +688,15 @@ ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, auto Status = llama_decode(GraphRef.LlamaContext.get(), Batch); Batch.pos = OriginBatchPos; if (Status == 1) { - RET_ERROR(ErrNo::RuntimeError, - "failed to llama_decode: try reducing the size of the batch "sv - "or increasing the size of context."sv) + RET_ERROR( + ErrNo::RuntimeError, + "evaluateTokens: failed to llama_decode: try reducing the size of the batch "sv + "or increasing the size of context."sv) } else if (Status < 0) { - RET_ERROR(ErrNo::RuntimeError, - "failed to llama_decode: internal fatal error. Please open "sv - "an issue on GitHub."sv) + RET_ERROR( + ErrNo::RuntimeError, + "evaluateTokens: failed to llama_decode: internal fatal error. Please open "sv + "an issue on GitHub."sv) } } @@ -745,15 +756,23 @@ ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, bool EvalImageStatus = false; switch (GraphRef.VisionModelType) { case VisionModel::Llava: + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: Eval llava image embd"sv, + LogPrefix) EvalImageStatus = llava_eval_image_embed( GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, static_cast(GraphRef.BatchSize), &CxtRef.NPos); + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: Eval llava image embd...done"sv, + LogPrefix) break; case VisionModel::Qwen2VL: + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: Eval Qwen2VL image embd"sv, + LogPrefix) auto ImageSize = clip_get_load_image_size(GraphRef.ClipContext); EvalImageStatus = evaluateQwen2vlImageEmbed( GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, static_cast(GraphRef.BatchSize), CxtRef.NPos, ImageSize); + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: Eval Qwen2VL image embd...done"sv, + LogPrefix) break; } @@ -783,6 +802,7 @@ ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, CxtRef.Conf.ImagePath = ""sv; if (CxtRef.LlavaImageEmbd != nullptr) { + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: ImageEmbd consumed"sv, LogPrefix) llava_image_embed_free(CxtRef.LlavaImageEmbd); CxtRef.LlavaImageEmbd = nullptr; } @@ -812,14 +832,14 @@ ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, if (!CxtRef.Conf.ReversePrompt.empty() && CxtRef.LlamaOutputs.find(CxtRef.Conf.ReversePrompt) != std::string::npos) { - LOG_INFO(GraphRef.EnableLog, "reverse prompt found."sv) + LOG_INFO(GraphRef.EnableLog, "sampleOutput: reverse prompt found."sv) return ErrNo::EndOfSequence; } } // Deal with end of text token. const llama_vocab *Vocab = llama_model_get_vocab(GraphRef.LlamaModel.get()); if (llama_vocab_is_eog(Vocab, common_sampler_last(CxtRef.LlamaSampler))) { - LOG_INFO(GraphRef.EnableLog, "EOS token found."sv) + LOG_INFO(GraphRef.EnableLog, "sampleOutput: EOS token found."sv) return ErrNo::EndOfSequence; } // Evaluate the output token. @@ -836,7 +856,7 @@ Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { if (CxtRef.LlamaInputs.size() > 0 && CxtRef.LlamaInputs.back() != llama_vocab_sep(Vocab)) { LOG_WARN( - "last token in the prompt is not SEP, "sv + "getEmbedding: last token in the prompt is not SEP, "sv "'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF "sv "header."sv) } @@ -845,7 +865,7 @@ Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { if (static_cast(CxtRef.LlamaInputs.size()) > GraphRef.BatchSize) { RET_ERROR( ErrNo::PromptTooLong, - "the prompt is too long. Your input has {} tokens exceeds batch "sv + "getEmbedding: the prompt is too long. Your input has {} tokens exceeds batch "sv "size {}. Please reduce the input size or increase your batch-size."sv, CxtRef.LlamaInputs.size(), GraphRef.BatchSize) } @@ -871,9 +891,7 @@ Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { if (Embd == nullptr) { Embd = llama_get_embeddings_ith(GraphRef.LlamaContext.get(), I); if (Embd == nullptr) { - spdlog::error( - "[WASI-NN] GGML backend: failed to get embeddings for token {}"sv, - I); + LOG_ERROR("getEmbedding: failed to get embeddings for token {}"sv, I); continue; } } @@ -944,7 +962,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, auto Res = parseMetadata(GraphRef, GraphRef.Conf, Metadata); if (Res != ErrNo::Success) { Env.deleteGraph(GId); - RET_ERROR(Res, "Failed to parse metadata."sv) + RET_ERROR(Res, "load: Failed to parse metadata."sv) } } @@ -972,7 +990,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, if (!TempFile) { Env.deleteGraph(GId); RET_ERROR(ErrNo::InvalidArgument, - "Failed to create the temporary file. Currently, our "sv + "load: Failed to create the temporary file. Currently, our "sv "workaround involves creating a temporary model file named "sv "\"ggml-model.bin\" and passing this filename as a "sv "parameter to the ggml llama library."sv) @@ -988,7 +1006,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, if (!std::filesystem::exists( std::filesystem::u8path(GraphRef.ModelFilePath))) { Env.deleteGraph(GId); - RET_ERROR(ErrNo::ModelNotFound, "model file not found."sv) + RET_ERROR(ErrNo::ModelNotFound, "load: model file not found."sv) } // Initialize ggml parameters. @@ -1005,11 +1023,11 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.LlamaContext = std::move(LlamaInit.context); if (GraphRef.LlamaModel == nullptr) { Env.deleteGraph(GId); - RET_ERROR(ErrNo::InvalidArgument, "unable to init model."sv) + RET_ERROR(ErrNo::InvalidArgument, "load: unable to init model."sv) } if (GraphRef.LlamaContext == nullptr) { Env.deleteGraph(GId); - RET_ERROR(ErrNo::InvalidArgument, "unable to init context."sv) + RET_ERROR(ErrNo::InvalidArgument, "load: unable to init context."sv) } LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize ggml model with given parameters...Done"sv) @@ -1067,7 +1085,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, parseMetadata(GraphRef, CxtRef.Conf, Metadata, &IsModelParamsUpdated, &IsContextParamsUpdated, &IsSamplerParamsUpdated); if (Res != ErrNo::Success) { - RET_ERROR(Res, "failed to parse metadata."sv) + RET_ERROR(Res, "setInput: failed to parse metadata."sv) } #ifndef __APPLE__ @@ -1080,7 +1098,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // The llama model may be nullptr if set_input with updated model params // last time. Therefore besides the model params updated, we should // reload the llama model if the model is nullptr. - LOG_INFO(GraphRef.EnableLog, "Reload model due to parameters change."sv) + LOG_INFO(GraphRef.EnableLog, + "setInput: Reload model due to parameters change."sv) llama_model_params ModelParams = llama_model_default_params(); ModelParams.n_gpu_layers = static_cast(GraphRef.NGPULayers); GraphRef.LlamaModel.reset(); @@ -1097,7 +1116,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, GraphRef.ModelFilePath.c_str(), ModelParams)); if (GraphRef.LlamaModel == nullptr) { Env.NNGraph[CxtRef.GraphId].setInvalid(); - RET_ERROR(ErrNo::InvalidArgument, "unable to init model."sv) + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init model."sv) } } } @@ -1107,7 +1126,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // reloaded. if (IsContextParamsUpdated || GraphRef.LlamaContext == nullptr) { LOG_INFO(GraphRef.EnableLog, - "Reload llama context due to parameters change."sv) + "setInput: Reload llama context due to parameters change."sv) GraphRef.LlamaContext.reset(); common_params Params; setupCommonParams(GraphRef, Params); @@ -1115,7 +1134,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, GraphRef.LlamaModel.get(), common_context_params_to_llama(Params))); if (GraphRef.LlamaContext == nullptr) { Env.NNGraph[CxtRef.GraphId].setInvalid(); - RET_ERROR(ErrNo::InvalidArgument, "unable to init context."sv) + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init context."sv) } } @@ -1123,7 +1142,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // reallocated. if (IsSamplerParamsUpdated || CxtRef.LlamaSampler == nullptr) { LOG_INFO(GraphRef.EnableLog, - "Reallocate llama sampler due to parameters change."sv) + "setInput: Reallocate llama sampler due to parameters change."sv) if (CxtRef.LlamaSampler) { common_sampler_free(CxtRef.LlamaSampler); } @@ -1133,7 +1152,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, common_sampler_init(GraphRef.LlamaModel.get(), CommonSampling); if (GraphRef.LlamaContext == nullptr) { Env.NNGraph[CxtRef.GraphId].setInvalid(); - RET_ERROR(ErrNo::InvalidArgument, "unable to init sampler."sv) + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init sampler."sv) } } @@ -1152,9 +1171,10 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Check the graph is valid after reloading during previous set_input. if (!Env.NNGraph[CxtRef.GraphId].isReady()) { - RET_ERROR(ErrNo::InvalidArgument, - "Graph is invalid. Please reload again by passing metadata "sv - "in set_input or unload graph."sv) + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: Graph is invalid. Please reload again by passing metadata "sv + "in set_input or unload graph."sv) } // Clear the llama context. @@ -1168,10 +1188,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); CxtRef.LlamaInputs.clear(); - if (CxtRef.LlavaImageEmbd != nullptr) { - llava_image_embed_free(CxtRef.LlavaImageEmbd); - CxtRef.LlavaImageEmbd = nullptr; - } + auto Base64ImagePos = findBase64ImagePayload(Prompt); if (Base64ImagePos.has_value() || CxtRef.Conf.ImagePath != ""sv) { @@ -1181,20 +1198,22 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (GraphRef.ClipContext == nullptr) { LOG_INFO( true, - "Load the clip model. Because llama.cpp disabled the GPU support "sv + "setInput: Load the clip model. Because llama.cpp disabled the GPU support "sv "for CLIP, the step of loading images in CLIP can only use the "sv "CPU, which may result in reduced efficiency. (You can refer to "sv "PR https://github.com/ggerganov/llama.cpp/pull/10896)"sv) GraphRef.ClipContext = clip_model_load(GraphRef.MMProjModelPath.c_str(), GraphRef.EnableLog ? 1 : 0); if (GraphRef.ClipContext == nullptr) { - RET_ERROR(ErrNo::InvalidArgument, "unable to load the clip model."sv) + RET_ERROR(ErrNo::InvalidArgument, + "setInput: unable to load the clip model."sv) } if (clip_is_qwen2vl(GraphRef.ClipContext)) { GraphRef.VisionModelType = VisionModel::Qwen2VL; - LOG_INFO(true, "Qwen2vl model loaded."sv) + LOG_INFO(true, "setInput: Qwen2vl model loaded."sv) } else { GraphRef.VisionModelType = VisionModel::Llava; + LOG_INFO(true, "setInput: Llava model loaded."sv) } } @@ -1208,7 +1227,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (GraphRef.CtxSize < 4096) { LOG_INFO( GraphRef.EnableLog, - "Context size is {}, we recommend context size >= 2048 when "sv + "setInput: Context size is {}, we recommend context size >= 2048 when "sv "using llava-v1.5 and context size >= 4096 when using llava-v1.6 "sv "for better results."sv, GraphRef.CtxSize) @@ -1218,29 +1237,74 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Follow this link for the supported image formats: // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h if (Base64ImagePos.has_value()) { + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: Compute image embd from the base64 image."sv) // Extract the payload and image type from the prompt. auto Payload = extractBase64ImagePayload(Prompt, *Base64ImagePos, LlavaPromptImagePlaceholder); if (Payload.has_value()) { - CxtRef.LlavaImageEmbd = llava_image_embed_make_with_bytes( - GraphRef.ClipContext, static_cast(GraphRef.Threads), - Payload->first.data(), static_cast(Payload->first.size())); + // Only regenerate the image embedding if the + // always-regenerate-image-embd is on or the image embedding is not + // yet computed. + if (CxtRef.LlavaImageEmbd == nullptr || + CxtRef.Conf.AlwaysRegenerateImageEmbd) { + // Free existing image embedding if regeneration is needed + if (CxtRef.LlavaImageEmbd != nullptr) { + llava_image_embed_free(CxtRef.LlavaImageEmbd); + CxtRef.LlavaImageEmbd = nullptr; + } + + // Create a new image embedding + CxtRef.LlavaImageEmbd = llava_image_embed_make_with_bytes( + GraphRef.ClipContext, static_cast(GraphRef.Threads), + Payload->first.data(), static_cast(Payload->first.size())); + } else { + LOG_DEBUG( + GraphRef.EnableDebugLog, + "setInput: Previous image embd is not yet consumed. Use the cached base64 image embd instead of computing a new one"sv) + } } + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: Compute image embd from the base64 image...Done"sv) } else { - // Load the image from the file. - CxtRef.LlavaImageEmbd = llava_image_embed_make_with_filename( - GraphRef.ClipContext, static_cast(GraphRef.Threads), - CxtRef.Conf.ImagePath.c_str()); + // Only regenerate the image embedding if the + // always-regenerate-image-embd is on or the image embedding is not yet + // computed. + if (CxtRef.LlavaImageEmbd == nullptr || + CxtRef.Conf.AlwaysRegenerateImageEmbd) { + // Free existing image embedding if regeneration is needed + if (CxtRef.LlavaImageEmbd != nullptr) { + llava_image_embed_free(CxtRef.LlavaImageEmbd); + CxtRef.LlavaImageEmbd = nullptr; + } + + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: Compute image embd from file: {}"sv, + CxtRef.Conf.ImagePath) + // Load the image from the file. + CxtRef.LlavaImageEmbd = llava_image_embed_make_with_filename( + GraphRef.ClipContext, static_cast(GraphRef.Threads), + CxtRef.Conf.ImagePath.c_str()); + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: Compute image embd from file: {}...Done"sv, + CxtRef.Conf.ImagePath) + } else { + LOG_DEBUG( + GraphRef.EnableDebugLog, + "setInput: Previous image embd is not yet consumed. Use the cached image embd instead of computing a new one"sv) + } } if (CxtRef.LlavaImageEmbd == nullptr) { - RET_ERROR(ErrNo::InvalidArgument, "llava unable to load the image."sv) + RET_ERROR(ErrNo::InvalidArgument, + "setInput: llava unable to load the image."sv) } // We split prompt by as placeholder and save the position. auto PlaceholderPosition = Prompt.find(LlavaPromptImagePlaceholder); if (PlaceholderPosition == std::string::npos) { - RET_ERROR(ErrNo::InvalidArgument, - "unable to find the placeholder in the llava prompt."sv) + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: unable to find the placeholder in the llava prompt."sv) } std::string PromptBeforeImage = Prompt.substr(0, PlaceholderPosition); std::string PromptAfterImage = Prompt.substr( @@ -1286,22 +1350,22 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput with Index {}"sv, Index) + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}"sv, Index) // Use index 1 for the metadata of the outputs. if (Index == 1) { std::string Metadata = buildOutputMetadata(CxtRef); std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); BytesWritten = static_cast(Metadata.length()); - LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput with Index {}...Done"sv, - Index) + LOG_DEBUG(GraphRef.EnableDebugLog, + "getOutput: with Index {} a.k.a Metadata ...Done"sv, Index) return ErrNo::Success; } std::copy_n(CxtRef.LlamaOutputs.data(), CxtRef.LlamaOutputs.length(), OutBuffer.data()); BytesWritten = static_cast(CxtRef.LlamaOutputs.length()); - LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput with Index {}...Done"sv, Index) + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}...Done"sv, Index) return ErrNo::Success; } @@ -1352,15 +1416,15 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, uint32_t &BytesWritten) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle with Index {}"sv, Index) + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: with Index {}"sv, Index) // Use index 1 for the metadata of the outputs. if (Index == 1) { std::string Metadata = buildOutputMetadata(CxtRef); std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); BytesWritten = static_cast(Metadata.length()); - LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle with Index {}...Done"sv, - Index) + LOG_DEBUG(GraphRef.EnableDebugLog, + "getOutputSingle: with Index {} a.k.a Metadata...Done"sv, Index) return ErrNo::Success; } @@ -1368,7 +1432,7 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, GraphRef.LlamaContext.get(), CxtRef.LlamaOutputTokens.back()); std::copy_n(LastToken.data(), LastToken.length(), OutBuffer.data()); BytesWritten = static_cast(LastToken.length()); - LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle with Index {}...Done"sv, + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: with Index {}...Done"sv, Index) return ErrNo::Success; } diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 0118efa8..fb34c545 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -48,6 +48,7 @@ struct LocalConfig { int64_t NPredict; std::string ReversePrompt; std::string ImagePath; + bool AlwaysRegenerateImageEmbd = false; }; struct Graph { From b2f94bdf18a142063ccedbda69f99104b45e27d5 Mon Sep 17 00:00:00 2001 From: YiYing He Date: Fri, 10 Jan 2025 05:03:22 +0800 Subject: [PATCH 525/623] [Plugin] Use `stb_image` to replace the libpng and libjpeg. Signed-off-by: YiYing He --- plugins/wasmedge_image/CMakeLists.txt | 14 +- plugins/wasmedge_image/image_func.cpp | 190 +++---- plugins/wasmedge_image/image_func.h | 9 + plugins/wasmedge_image/image_module.cpp | 1 + .../plugins/wasmedge_image/wasmedge_image.cpp | 467 +++++++++++++++++- 5 files changed, 582 insertions(+), 99 deletions(-) diff --git a/plugins/wasmedge_image/CMakeLists.txt b/plugins/wasmedge_image/CMakeLists.txt index af130881..fd24e7eb 100644 --- a/plugins/wasmedge_image/CMakeLists.txt +++ b/plugins/wasmedge_image/CMakeLists.txt @@ -19,19 +19,13 @@ target_include_directories(wasmedgePluginWasmEdgeImage ${CMAKE_CURRENT_SOURCE_DIR} ) -# Need libjpeg, libpng, zlib, and boost. -find_package(ZLIB REQUIRED) -wasmedge_setup_jpeg() -wasmedge_setup_png() -wasmedge_setup_boost() - +# Need stb_image. +wasmedge_setup_stb_image() target_link_libraries(wasmedgePluginWasmEdgeImage PUBLIC - Boost::boost - wasmedgeDepsJPEG - wasmedgeDepsPNG - z + wasmedgeDepsSTBImage ) + if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasmedgePluginWasmEdgeImage PRIVATE diff --git a/plugins/wasmedge_image/image_func.cpp b/plugins/wasmedge_image/image_func.cpp index 785a8163..216e8e82 100644 --- a/plugins/wasmedge_image/image_func.cpp +++ b/plugins/wasmedge_image/image_func.cpp @@ -6,12 +6,10 @@ #include "common/span.h" #include "common/spdlog.h" -#include -#include -#include -#include -#include -#include +#define STB_IMAGE_IMPLEMENTATION +#include +#define STB_IMAGE_RESIZE_IMPLEMENTATION +#include #include #include @@ -23,58 +21,81 @@ namespace WasmEdgeImage { namespace { -// Helper function to decode and resize image. -template -bool decodeImgToSize(Span Buf, uint32_t W, uint32_t H, - Span DstBuf) { - std::stringstream ImgStream; - ImgStream.write(Buf.data(), Buf.size()); - Image Img; - try { - boost::gil::read_and_convert_image(ImgStream, Img, FormatTag()); - } catch (std::exception const &e) { - spdlog::error("[WasmEdge-Image] Decode image fail: {}"sv, e.what()); +bool decodeImgToSize(Span Buf, uint32_t W, uint32_t H, + DataType OutType, Span DstBuf) noexcept { + // Specify the target data format. + bool IsRGB = true; + bool IsU8 = true; + switch (OutType) { + case DataType::BGR8: + IsRGB = false; + [[fallthrough]]; + case DataType::RGB8: + break; + case DataType::BGR32F: + IsRGB = false; + [[fallthrough]]; + case DataType::RGB32F: + IsU8 = false; + break; + default: return false; } - uint32_t C = boost::gil::num_channels::value; - typename Image::view_t ImgView = boost::gil::interleaved_view( - W, H, reinterpret_cast(DstBuf.data()), - W * C * sizeof(char)); - boost::gil::resize_view(boost::gil::const_view(Img), ImgView, - boost::gil::bilinear_sampler()); - return true; -} - -// Helper function to normalize image. -void normalizeImg(Span SrcBuf, Span DstBuf) { - for (uint32_t I = 0; I < DstBuf.size(); I++) { - DstBuf[I] = static_cast(SrcBuf[I]) / 255.0; + // Load and decode the image from buffer. + union RawImagePtr { + uint8_t *U8; + float *F32; + }; + RawImagePtr RawImg; + RawImg.U8 = nullptr; + int IW, IH, IC; + if (IsU8) { + RawImg.U8 = stbi_load_from_memory(Buf.data(), Buf.size(), &IW, &IH, &IC, 3); + } else { + RawImg.F32 = + stbi_loadf_from_memory(Buf.data(), Buf.size(), &IW, &IH, &IC, 3); + } + if (RawImg.U8 == nullptr) { + spdlog::error("[WasmEdge-Image] Load image failed."sv); + return false; } -} -// Template to decode and resize image to the target format. -template -uint32_t readBufToImg(Span InBuf, uint32_t W, uint32_t H, - Span OutBuf) { - if (unlikely(!decodeImgToSize(InBuf, W, H, OutBuf))) { - return static_cast(ErrNo::Fail); + // Resize. + if (unlikely(DstBuf.size() < + W * H * 3 * (IsU8 ? sizeof(uint8_t) : sizeof(float)))) { + spdlog::error("[WasmEdge-Image] Output buffer size {} not enough. "sv + "At least need {} bytes."sv, + DstBuf.size(), + W * H * 3 * (IsU8 ? sizeof(uint8_t) : sizeof(float))); + return false; + } + if (IsU8) { + stbir_resize_uint8_linear(RawImg.U8, IW, IH, 0, DstBuf.data(), + static_cast(W), static_cast(H), 0, + STBIR_RGB); + } else { + stbir_resize_float_linear( + RawImg.F32, IW, IH, 0, reinterpret_cast(DstBuf.data()), + static_cast(W), static_cast(H), 0, STBIR_RGB); } - return static_cast(ErrNo::Success); -} -// Template to decode and resize image to the target format. -template -uint32_t readBufToFlattenImg(Span InBuf, uint32_t W, uint32_t H, - Span OutBuf) { - std::vector ImgData(3 * W * H); - if (unlikely(!decodeImgToSize( - InBuf, W, H, Span(ImgData.data(), ImgData.size())))) { - return static_cast(ErrNo::Fail); + // Handle BGR case. + if (!IsRGB) { + if (IsU8) { + for (uint32_t I = 0; I < W * H; I++) { + std::swap(DstBuf[I * 3], DstBuf[I * 3 + 2]); + } + } else { + auto F32DstBuf = Span(reinterpret_cast(DstBuf.data()), + DstBuf.size() / sizeof(float)); + for (uint32_t I = 0; I < W * H; I++) { + std::swap(F32DstBuf[I * 3], F32DstBuf[I * 3 + 2]); + } + } } - normalizeImg(ImgData, Span(reinterpret_cast(OutBuf.data()), - OutBuf.size() / sizeof(float))); - return static_cast(ErrNo::Success); + stbi_image_free(RawImg.U8); + return true; } #define MEMINST_CHECK(Out, CallFrame, Index) \ @@ -102,31 +123,18 @@ Expect LoadJPG::body(const Runtime::CallingFrame &Frame, MEMINST_CHECK(MemInst, Frame, 0) // Check the input image buffer. - MEM_SPAN_CHECK(ImgBufSpan, MemInst, char, InImgBufPtr, InImgBufLen, + MEM_SPAN_CHECK(ImgBufSpan, MemInst, uint8_t, InImgBufPtr, InImgBufLen, "Failed when accessing the input image buffer memory."sv) // Check the output decoded image buffer. - MEM_SPAN_CHECK(OutBufSpan, MemInst, char, OutBufPtr, OutBufLen, + MEM_SPAN_CHECK(OutBufSpan, MemInst, uint8_t, OutBufPtr, OutBufLen, "Failed when accessing the output image data buffer memory."sv) - switch (static_cast(OutType)) { - case DataType::RGB8: - return readBufToImg( - ImgBufSpan, OutImgW, OutImgH, OutBufSpan); - case DataType::BGR8: - return readBufToImg( - ImgBufSpan, OutImgW, OutImgH, OutBufSpan); - case DataType::RGB32F: - return readBufToFlattenImg( - ImgBufSpan, OutImgW, OutImgH, OutBufSpan); - case DataType::BGR32F: - return readBufToFlattenImg( - ImgBufSpan, OutImgW, OutImgH, OutBufSpan); - break; - default: - spdlog::error("[WasmEdge-Image] Invalid output data format."sv); + if (unlikely(!decodeImgToSize(ImgBufSpan, OutImgW, OutImgH, + static_cast(OutType), OutBufSpan))) { return static_cast(ErrNo::Fail); } + return static_cast(ErrNo::Success); } Expect LoadPNG::body(const Runtime::CallingFrame &Frame, @@ -138,31 +146,41 @@ Expect LoadPNG::body(const Runtime::CallingFrame &Frame, MEMINST_CHECK(MemInst, Frame, 0) // Check the input image buffer. - MEM_SPAN_CHECK(ImgBufSpan, MemInst, char, InImgBufPtr, InImgBufLen, + MEM_SPAN_CHECK(ImgBufSpan, MemInst, uint8_t, InImgBufPtr, InImgBufLen, "Failed when accessing the input image buffer memory."sv) // Check the output decoded image buffer. - MEM_SPAN_CHECK(OutBufSpan, MemInst, char, OutBufPtr, OutBufLen, + MEM_SPAN_CHECK(OutBufSpan, MemInst, uint8_t, OutBufPtr, OutBufLen, "Failed when accessing the output image data buffer memory."sv) - switch (static_cast(OutType)) { - case DataType::RGB8: - return readBufToImg( - ImgBufSpan, OutImgW, OutImgH, OutBufSpan); - case DataType::BGR8: - return readBufToImg( - ImgBufSpan, OutImgW, OutImgH, OutBufSpan); - case DataType::RGB32F: - return readBufToFlattenImg( - ImgBufSpan, OutImgW, OutImgH, OutBufSpan); - case DataType::BGR32F: - return readBufToFlattenImg( - ImgBufSpan, OutImgW, OutImgH, OutBufSpan); - break; - default: - spdlog::error("[WasmEdge-Image] Invalid output data format."sv); + if (unlikely(!decodeImgToSize(ImgBufSpan, OutImgW, OutImgH, + static_cast(OutType), OutBufSpan))) { + return static_cast(ErrNo::Fail); + } + return static_cast(ErrNo::Success); +} + +Expect LoadImage::body(const Runtime::CallingFrame &Frame, + uint32_t InImgBufPtr, uint32_t InImgBufLen, + uint32_t OutImgW, uint32_t OutImgH, + uint32_t OutType, uint32_t OutBufPtr, + uint32_t OutBufLen) { + // Check memory instance from module. + MEMINST_CHECK(MemInst, Frame, 0) + + // Check the input image buffer. + MEM_SPAN_CHECK(ImgBufSpan, MemInst, uint8_t, InImgBufPtr, InImgBufLen, + "Failed when accessing the input image buffer memory."sv) + + // Check the output decoded image buffer. + MEM_SPAN_CHECK(OutBufSpan, MemInst, uint8_t, OutBufPtr, OutBufLen, + "Failed when accessing the output image data buffer memory."sv) + + if (unlikely(!decodeImgToSize(ImgBufSpan, OutImgW, OutImgH, + static_cast(OutType), OutBufSpan))) { return static_cast(ErrNo::Fail); } + return static_cast(ErrNo::Success); } } // namespace WasmEdgeImage diff --git a/plugins/wasmedge_image/image_func.h b/plugins/wasmedge_image/image_func.h index 9b18c3d2..4b6bc14f 100644 --- a/plugins/wasmedge_image/image_func.h +++ b/plugins/wasmedge_image/image_func.h @@ -29,6 +29,15 @@ class LoadPNG : public Func { uint32_t OutBufPtr, uint32_t OutBufLen); }; +class LoadImage : public Func { +public: + LoadImage(ImgEnv &HostEnv) : Func(HostEnv) {} + Expect body(const Runtime::CallingFrame &Frame, + uint32_t InImgBufPtr, uint32_t InImgBufLen, + uint32_t OutImgW, uint32_t OutImgH, uint32_t OutType, + uint32_t OutBufPtr, uint32_t OutBufLen); +}; + } // namespace WasmEdgeImage } // namespace Host } // namespace WasmEdge diff --git a/plugins/wasmedge_image/image_module.cpp b/plugins/wasmedge_image/image_module.cpp index 745cbb56..65e71479 100644 --- a/plugins/wasmedge_image/image_module.cpp +++ b/plugins/wasmedge_image/image_module.cpp @@ -13,6 +13,7 @@ WasmEdgeImageModule::WasmEdgeImageModule() : Runtime::Instance::ModuleInstance("wasmedge_image") { addHostFunc("load_jpg", std::make_unique(Env)); addHostFunc("load_png", std::make_unique(Env)); + addHostFunc("load_image", std::make_unique(Env)); } } // namespace Host diff --git a/test/plugins/wasmedge_image/wasmedge_image.cpp b/test/plugins/wasmedge_image/wasmedge_image.cpp index be767b4e..4f449727 100644 --- a/test/plugins/wasmedge_image/wasmedge_image.cpp +++ b/test/plugins/wasmedge_image/wasmedge_image.cpp @@ -14,6 +14,8 @@ #include #include +using WasmEdge::Host::WasmEdgeImage::ErrNo; + namespace { template @@ -40,17 +42,476 @@ std::unique_ptr createModule() { return {}; } -} // namespace +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, uint32_t Cnt, uint8_t C = 0) noexcept { + std::fill_n(MemInst.getPointer(Offset), Cnt, C); +} + +void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, + uint32_t Offset, + const std::vector &Payload) noexcept { + uint8_t *Buf = MemInst.getPointer(Offset); + std::copy_n(Payload.data(), Payload.size(), Buf); +} + +// Test image: #FF0000 png file, 30x30 image size, 158 bytes. +std::vector TestRedPNG = { + 0x89U, 0x50U, 0x4EU, 0x47U, 0x0DU, 0x0AU, 0x1AU, 0x0AU, 0x00U, 0x00U, 0x00U, + 0x0DU, 0x49U, 0x48U, 0x44U, 0x52U, 0x00U, 0x00U, 0x00U, 0x1EU, 0x00U, 0x00U, + 0x00U, 0x1EU, 0x08U, 0x06U, 0x00U, 0x00U, 0x00U, 0x3BU, 0x30U, 0xAEU, 0xA2U, + 0x00U, 0x00U, 0x00U, 0x01U, 0x73U, 0x52U, 0x47U, 0x42U, 0x00U, 0xAEU, 0xCEU, + 0x1CU, 0xE9U, 0x00U, 0x00U, 0x00U, 0x04U, 0x67U, 0x41U, 0x4DU, 0x41U, 0x00U, + 0x00U, 0xB1U, 0x8FU, 0x0BU, 0xFCU, 0x61U, 0x05U, 0x00U, 0x00U, 0x00U, 0x09U, + 0x70U, 0x48U, 0x59U, 0x73U, 0x00U, 0x00U, 0x16U, 0x25U, 0x00U, 0x00U, 0x16U, + 0x25U, 0x01U, 0x49U, 0x52U, 0x24U, 0xF0U, 0x00U, 0x00U, 0x00U, 0x33U, 0x49U, + 0x44U, 0x41U, 0x54U, 0x48U, 0x4BU, 0xEDU, 0xCDU, 0xA1U, 0x01U, 0x00U, 0x00U, + 0x0CU, 0x83U, 0xB0U, 0xFEU, 0xFFU, 0x74U, 0xE7U, 0x77U, 0x00U, 0x35U, 0x88U, + 0x18U, 0x0CU, 0x69U, 0xD2U, 0x85U, 0xFCU, 0x40U, 0x71U, 0x8CU, 0x71U, 0x8CU, + 0x71U, 0x8CU, 0x71U, 0x8CU, 0x71U, 0x8CU, 0x71U, 0x8CU, 0x71U, 0x8CU, 0x71U, + 0x8CU, 0x71U, 0x8CU, 0x99U, 0x8DU, 0x0FU, 0xD5U, 0x6CU, 0x01U, 0x62U, 0x5DU, + 0xE8U, 0xB5U, 0x3DU, 0x00U, 0x00U, 0x00U, 0x00U, 0x49U, 0x45U, 0x4EU, 0x44U, + 0xAEU, 0x42U, 0x60U, 0x82U}; -// TODO: unit tests for every functions. +// Test image: #FF0000 jpg file, 30x30 image size, 647 bytes. +std::vector TestRedJPG = { + 0xFFU, 0xD8U, 0xFFU, 0xE0U, 0x00U, 0x10U, 0x4AU, 0x46U, 0x49U, 0x46U, 0x00U, + 0x01U, 0x01U, 0x01U, 0x00U, 0x90U, 0x00U, 0x90U, 0x00U, 0x00U, 0xFFU, 0xDBU, + 0x00U, 0x43U, 0x00U, 0x02U, 0x01U, 0x01U, 0x02U, 0x01U, 0x01U, 0x02U, 0x02U, + 0x02U, 0x02U, 0x02U, 0x02U, 0x02U, 0x02U, 0x03U, 0x05U, 0x03U, 0x03U, 0x03U, + 0x03U, 0x03U, 0x06U, 0x04U, 0x04U, 0x03U, 0x05U, 0x07U, 0x06U, 0x07U, 0x07U, + 0x07U, 0x06U, 0x07U, 0x07U, 0x08U, 0x09U, 0x0BU, 0x09U, 0x08U, 0x08U, 0x0AU, + 0x08U, 0x07U, 0x07U, 0x0AU, 0x0DU, 0x0AU, 0x0AU, 0x0BU, 0x0CU, 0x0CU, 0x0CU, + 0x0CU, 0x07U, 0x09U, 0x0EU, 0x0FU, 0x0DU, 0x0CU, 0x0EU, 0x0BU, 0x0CU, 0x0CU, + 0x0CU, 0xFFU, 0xDBU, 0x00U, 0x43U, 0x01U, 0x02U, 0x02U, 0x02U, 0x03U, 0x03U, + 0x03U, 0x06U, 0x03U, 0x03U, 0x06U, 0x0CU, 0x08U, 0x07U, 0x08U, 0x0CU, 0x0CU, + 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, + 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, + 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, + 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0x0CU, + 0x0CU, 0x0CU, 0x0CU, 0x0CU, 0xFFU, 0xC0U, 0x00U, 0x11U, 0x08U, 0x00U, 0x1EU, + 0x00U, 0x1EU, 0x03U, 0x01U, 0x22U, 0x00U, 0x02U, 0x11U, 0x01U, 0x03U, 0x11U, + 0x01U, 0xFFU, 0xC4U, 0x00U, 0x1FU, 0x00U, 0x00U, 0x01U, 0x05U, 0x01U, 0x01U, + 0x01U, 0x01U, 0x01U, 0x01U, 0x00U, 0x00U, 0x00U, 0x00U, 0x00U, 0x00U, 0x00U, + 0x00U, 0x01U, 0x02U, 0x03U, 0x04U, 0x05U, 0x06U, 0x07U, 0x08U, 0x09U, 0x0AU, + 0x0BU, 0xFFU, 0xC4U, 0x00U, 0xB5U, 0x10U, 0x00U, 0x02U, 0x01U, 0x03U, 0x03U, + 0x02U, 0x04U, 0x03U, 0x05U, 0x05U, 0x04U, 0x04U, 0x00U, 0x00U, 0x01U, 0x7DU, + 0x01U, 0x02U, 0x03U, 0x00U, 0x04U, 0x11U, 0x05U, 0x12U, 0x21U, 0x31U, 0x41U, + 0x06U, 0x13U, 0x51U, 0x61U, 0x07U, 0x22U, 0x71U, 0x14U, 0x32U, 0x81U, 0x91U, + 0xA1U, 0x08U, 0x23U, 0x42U, 0xB1U, 0xC1U, 0x15U, 0x52U, 0xD1U, 0xF0U, 0x24U, + 0x33U, 0x62U, 0x72U, 0x82U, 0x09U, 0x0AU, 0x16U, 0x17U, 0x18U, 0x19U, 0x1AU, + 0x25U, 0x26U, 0x27U, 0x28U, 0x29U, 0x2AU, 0x34U, 0x35U, 0x36U, 0x37U, 0x38U, + 0x39U, 0x3AU, 0x43U, 0x44U, 0x45U, 0x46U, 0x47U, 0x48U, 0x49U, 0x4AU, 0x53U, + 0x54U, 0x55U, 0x56U, 0x57U, 0x58U, 0x59U, 0x5AU, 0x63U, 0x64U, 0x65U, 0x66U, + 0x67U, 0x68U, 0x69U, 0x6AU, 0x73U, 0x74U, 0x75U, 0x76U, 0x77U, 0x78U, 0x79U, + 0x7AU, 0x83U, 0x84U, 0x85U, 0x86U, 0x87U, 0x88U, 0x89U, 0x8AU, 0x92U, 0x93U, + 0x94U, 0x95U, 0x96U, 0x97U, 0x98U, 0x99U, 0x9AU, 0xA2U, 0xA3U, 0xA4U, 0xA5U, + 0xA6U, 0xA7U, 0xA8U, 0xA9U, 0xAAU, 0xB2U, 0xB3U, 0xB4U, 0xB5U, 0xB6U, 0xB7U, + 0xB8U, 0xB9U, 0xBAU, 0xC2U, 0xC3U, 0xC4U, 0xC5U, 0xC6U, 0xC7U, 0xC8U, 0xC9U, + 0xCAU, 0xD2U, 0xD3U, 0xD4U, 0xD5U, 0xD6U, 0xD7U, 0xD8U, 0xD9U, 0xDAU, 0xE1U, + 0xE2U, 0xE3U, 0xE4U, 0xE5U, 0xE6U, 0xE7U, 0xE8U, 0xE9U, 0xEAU, 0xF1U, 0xF2U, + 0xF3U, 0xF4U, 0xF5U, 0xF6U, 0xF7U, 0xF8U, 0xF9U, 0xFAU, 0xFFU, 0xC4U, 0x00U, + 0x1FU, 0x01U, 0x00U, 0x03U, 0x01U, 0x01U, 0x01U, 0x01U, 0x01U, 0x01U, 0x01U, + 0x01U, 0x01U, 0x00U, 0x00U, 0x00U, 0x00U, 0x00U, 0x00U, 0x01U, 0x02U, 0x03U, + 0x04U, 0x05U, 0x06U, 0x07U, 0x08U, 0x09U, 0x0AU, 0x0BU, 0xFFU, 0xC4U, 0x00U, + 0xB5U, 0x11U, 0x00U, 0x02U, 0x01U, 0x02U, 0x04U, 0x04U, 0x03U, 0x04U, 0x07U, + 0x05U, 0x04U, 0x04U, 0x00U, 0x01U, 0x02U, 0x77U, 0x00U, 0x01U, 0x02U, 0x03U, + 0x11U, 0x04U, 0x05U, 0x21U, 0x31U, 0x06U, 0x12U, 0x41U, 0x51U, 0x07U, 0x61U, + 0x71U, 0x13U, 0x22U, 0x32U, 0x81U, 0x08U, 0x14U, 0x42U, 0x91U, 0xA1U, 0xB1U, + 0xC1U, 0x09U, 0x23U, 0x33U, 0x52U, 0xF0U, 0x15U, 0x62U, 0x72U, 0xD1U, 0x0AU, + 0x16U, 0x24U, 0x34U, 0xE1U, 0x25U, 0xF1U, 0x17U, 0x18U, 0x19U, 0x1AU, 0x26U, + 0x27U, 0x28U, 0x29U, 0x2AU, 0x35U, 0x36U, 0x37U, 0x38U, 0x39U, 0x3AU, 0x43U, + 0x44U, 0x45U, 0x46U, 0x47U, 0x48U, 0x49U, 0x4AU, 0x53U, 0x54U, 0x55U, 0x56U, + 0x57U, 0x58U, 0x59U, 0x5AU, 0x63U, 0x64U, 0x65U, 0x66U, 0x67U, 0x68U, 0x69U, + 0x6AU, 0x73U, 0x74U, 0x75U, 0x76U, 0x77U, 0x78U, 0x79U, 0x7AU, 0x82U, 0x83U, + 0x84U, 0x85U, 0x86U, 0x87U, 0x88U, 0x89U, 0x8AU, 0x92U, 0x93U, 0x94U, 0x95U, + 0x96U, 0x97U, 0x98U, 0x99U, 0x9AU, 0xA2U, 0xA3U, 0xA4U, 0xA5U, 0xA6U, 0xA7U, + 0xA8U, 0xA9U, 0xAAU, 0xB2U, 0xB3U, 0xB4U, 0xB5U, 0xB6U, 0xB7U, 0xB8U, 0xB9U, + 0xBAU, 0xC2U, 0xC3U, 0xC4U, 0xC5U, 0xC6U, 0xC7U, 0xC8U, 0xC9U, 0xCAU, 0xD2U, + 0xD3U, 0xD4U, 0xD5U, 0xD6U, 0xD7U, 0xD8U, 0xD9U, 0xDAU, 0xE2U, 0xE3U, 0xE4U, + 0xE5U, 0xE6U, 0xE7U, 0xE8U, 0xE9U, 0xEAU, 0xF2U, 0xF3U, 0xF4U, 0xF5U, 0xF6U, + 0xF7U, 0xF8U, 0xF9U, 0xFAU, 0xFFU, 0xDAU, 0x00U, 0x0CU, 0x03U, 0x01U, 0x00U, + 0x02U, 0x11U, 0x03U, 0x11U, 0x00U, 0x3FU, 0x00U, 0xF8U, 0xBEU, 0x8AU, 0x28U, + 0xAFU, 0xE5U, 0x33U, 0xFDU, 0xFCU, 0x0AU, 0x28U, 0xA2U, 0x80U, 0x0AU, 0x28U, + 0xA2U, 0x80U, 0x0AU, 0x28U, 0xA2U, 0x80U, 0x3FU, 0xFFU, 0xD9U}; + +} // namespace TEST(WasmEdgeImageTest, Module) { // Create the wasmedge_image module instance. auto ImgMod = createModule(); ASSERT_TRUE(ImgMod); - EXPECT_EQ(ImgMod->getFuncExportNum(), 2U); + EXPECT_EQ(ImgMod->getFuncExportNum(), 3U); EXPECT_NE(ImgMod->findFuncExports("load_jpg"), nullptr); EXPECT_NE(ImgMod->findFuncExports("load_png"), nullptr); + EXPECT_NE(ImgMod->findFuncExports("load_image"), nullptr); +} + +TEST(WasmEdgeImageTest, LoadJPG) { + // Create the wasmedge_image module instance. + auto ImgMod = createModule(); + ASSERT_TRUE(ImgMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + std::array Errno = {UINT32_C(0)}; + + // Set the target size. + uint32_t TargetW = 50, TargetH = 60; + uint32_t TargetSize = TargetW * TargetH * 3; + // Assume the pixel position (45, 55). + uint32_t Position = 55 * TargetW + 45; + // Input payload offset. + uint32_t InOffset = 0; + // Output image data offset. + uint32_t OutOffset = 1024; + // Output image span. + WasmEdge::Span OutSpanU8; + WasmEdge::Span OutSpanF32; + + // Get the function "load_jpg". + auto *FuncInst = ImgMod->findFuncExports("load_jpg"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Load JPG and resize into 50x60 RGB u8 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 647] as the JPG image payload. + fillMemContent(MemInst, 0, TestRedJPG); + // Get output image span. + OutSpanU8 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedJPG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 0U, // Target type: RGB8. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(uint8_t)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Note: Due to the JPG compression, the R is 254, not 255 here. + EXPECT_EQ(OutSpanU8[Position * 3], UINT8_C(254)); + EXPECT_EQ(OutSpanU8[Position * 3 + 1], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 2], UINT8_C(0)); + + // Test: Load JPG and resize into 50x60 BGR u8 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 158] as the JPG image payload. + fillMemContent(MemInst, 0, TestRedJPG); + // Get output image span. + OutSpanU8 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedJPG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 1U, // Target type: BGR8. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(uint8_t)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Note: Due to the JPG compression, the R is 254, not 255 here. + EXPECT_EQ(OutSpanU8[Position * 3], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 1], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 2], UINT8_C(254)); + + // Test: Load JPG and resize into 50x60 RGB f32 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 647] as the JPG image payload. + fillMemContent(MemInst, 0, TestRedJPG); + // Get output image span. + OutSpanF32 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedJPG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 2U, // Target type: RGB32F. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(float)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Note: Due to the JPG compression, the R is 0.991392851f, not 1.0f here. + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3] - 1.0f) < 0.01f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 1] - 0.0f) < 0.01f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 2] - 0.0f) < 0.01f); + + // Test: Load JPG and resize into 50x60 BGR f32 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 647] as the JPG image payload. + fillMemContent(MemInst, 0, TestRedJPG); + // Get output image span. + OutSpanF32 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedJPG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 3U, // Target type: BGR32F. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(float)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Note: Due to the JPG compression, the R is 0.991392851f, not 1.0f here. + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3] - 0.0f) < 0.01f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 1] - 0.0f) < 0.01f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 2] - 1.0f) < 0.01f); +} + +TEST(WasmEdgeImageTest, LoadPNG) { + // Create the wasmedge_image module instance. + auto ImgMod = createModule(); + ASSERT_TRUE(ImgMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + std::array Errno = {UINT32_C(0)}; + + // Set the target size. + uint32_t TargetW = 50, TargetH = 60; + uint32_t TargetSize = TargetW * TargetH * 3; + // Assume the pixel position (45, 55). + uint32_t Position = 55 * TargetW + 45; + // Input payload offset. + uint32_t InOffset = 0; + // Output image data offset. + uint32_t OutOffset = 1024; + // Output image span. + WasmEdge::Span OutSpanU8; + WasmEdge::Span OutSpanF32; + + // Get the function "load_png". + auto *FuncInst = ImgMod->findFuncExports("load_png"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Load PNG and resize into 50x60 RGB u8 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 158] as the PNG image payload. + fillMemContent(MemInst, 0, TestRedPNG); + // Get output image span. + OutSpanU8 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedPNG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 0U, // Target type: RGB8. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(uint8_t)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(OutSpanU8[Position * 3], UINT8_C(255)); + EXPECT_EQ(OutSpanU8[Position * 3 + 1], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 2], UINT8_C(0)); + + // Test: Load PNG and resize into 50x60 BGR u8 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 158] as the PNG image payload. + fillMemContent(MemInst, 0, TestRedPNG); + // Get output image span. + OutSpanU8 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedPNG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 1U, // Target type: BGR8. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(uint8_t)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_EQ(OutSpanU8[Position * 3], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 1], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 2], UINT8_C(255)); + + // Test: Load PNG and resize into 50x60 RGB f32 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 158] as the PNG image payload. + fillMemContent(MemInst, 0, TestRedPNG); + // Get output image span. + OutSpanF32 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedPNG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 2U, // Target type: RGB32F. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(float)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3] - 1.0f) < 0.00001f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 1] - 0.0f) < 0.00001f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 2] - 0.0f) < 0.00001f); + + // Test: Load PNG and resize into 50x60 BGR f32 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 158] as the PNG image payload. + fillMemContent(MemInst, 0, TestRedPNG); + // Get output image span. + OutSpanF32 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedPNG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 3U, // Target type: BGR32F. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(float)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3] - 0.0f) < 0.00001f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 1] - 0.0f) < 0.00001f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 2] - 1.0f) < 0.00001f); +} + +TEST(WasmEdgeImageTest, LoadImage) { + // Test for the general API. + + // Create the wasmedge_image module instance. + auto ImgMod = createModule(); + ASSERT_TRUE(ImgMod); + + // Create the calling frame with memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(1))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + std::array Errno = {UINT32_C(0)}; + + // Set the target size. + uint32_t TargetW = 50, TargetH = 60; + uint32_t TargetSize = TargetW * TargetH * 3; + // Assume the pixel position (45, 55). + uint32_t Position = 55 * TargetW + 45; + // Input payload offset. + uint32_t InOffset = 0; + // Output image data offset. + uint32_t OutOffset = 1024; + // Output image span. + WasmEdge::Span OutSpanU8; + WasmEdge::Span OutSpanF32; + + // Get the function "load_image". + auto *FuncInst = ImgMod->findFuncExports("load_image"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncInst = dynamic_cast( + FuncInst->getHostFunc()); + + // Test: Load JPG and resize into 50x60 BGR u8 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 647] as the JPG image payload. + fillMemContent(MemInst, 0, TestRedJPG); + // Get output image span. + OutSpanU8 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedJPG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 1U, // Target type: BGR8. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(uint8_t)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + // Note: Due to the JPG compression, the R is 254, not 255 here. + EXPECT_EQ(OutSpanU8[Position * 3], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 1], UINT8_C(0)); + EXPECT_EQ(OutSpanU8[Position * 3 + 2], UINT8_C(254)); + + // Test: Load PNG and resize into 50x60 RGB f32 format. + // Clear the memory[0, 32768]. + fillMemContent(MemInst, 0, 32768); + // Set the memory[0, 647] as the PNG image payload. + fillMemContent(MemInst, 0, TestRedPNG); + // Get output image span. + OutSpanF32 = MemInst.getSpan(OutOffset, TargetSize); + // Run. + EXPECT_TRUE(HostFuncInst.run( + CallFrame, + std::initializer_list{ + InOffset, // Payload offset. + static_cast(TestRedPNG.size()), // Payload size. + TargetW, TargetH, // Target width and height. + 2U, // Target type: RGB32F. + OutOffset, // Output buffer offset. + TargetSize * + static_cast(sizeof(float)) // Output buffer size. + }, + Errno)); + EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3] - 1.0f) < 0.00001f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 1] - 0.0f) < 0.00001f); + EXPECT_TRUE(std::fabs(OutSpanF32[Position * 3 + 2] - 0.0f) < 0.00001f); } GTEST_API_ int main(int argc, char **argv) { From 11539065dc36dcb348493bb95654fd8f0a60c6c1 Mon Sep 17 00:00:00 2001 From: Deveshi Dwivedi <120312681+deveshidwivedi@users.noreply.github.com> Date: Thu, 16 Jan 2025 19:04:22 +0530 Subject: [PATCH 526/623] [WASI-crypto] fix: secretkey_export on rsa with ENCODING_PKCS8 (#3963) * correct PKCS8 export for RSA secret key Signed-off-by: Deveshi Dwivedi * fix clang-format issues Signed-off-by: Deveshi Dwivedi * update RSA key test cases to use PKCS8 format Signed-off-by: Deveshi Dwivedi * fix failing test for pkcs8 encoding Signed-off-by: Deveshi Dwivedi --------- Signed-off-by: Deveshi Dwivedi --- plugins/wasi_crypto/signatures/rsa.cpp | 18 +- test/plugins/wasi_crypto/asymmetric.cpp | 359 ++++++++++++------------ 2 files changed, 198 insertions(+), 179 deletions(-) diff --git a/plugins/wasi_crypto/signatures/rsa.cpp b/plugins/wasi_crypto/signatures/rsa.cpp index 245fe9ed..722c40a4 100644 --- a/plugins/wasi_crypto/signatures/rsa.cpp +++ b/plugins/wasi_crypto/signatures/rsa.cpp @@ -171,7 +171,23 @@ Rsa::SecretKey::exportPem() const noexcept { template WasiCryptoExpect Rsa::SecretKey::exportPkcs8() const noexcept { - return i2dPrivateKey(Ctx.get()); + EVP_PKEY *Key = Ctx.get(); + BioPtr Bio{BIO_new(BIO_s_mem())}; + + opensslCheck(i2d_PKCS8PrivateKey_bio(Bio.get(), Key, nullptr, nullptr, 0, + nullptr, nullptr)); + + BUF_MEM *Mem = nullptr; + opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); + SecretVec Ret(Mem->length); + + if (size_t Size; BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size)) { + ensureOrReturn(Size == Ret.size(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } else { + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + } + + return Ret; } template diff --git a/test/plugins/wasi_crypto/asymmetric.cpp b/test/plugins/wasi_crypto/asymmetric.cpp index 9e9f9c9f..24f621e7 100644 --- a/test/plugins/wasi_crypto/asymmetric.cpp +++ b/test/plugins/wasi_crypto/asymmetric.cpp @@ -253,46 +253,47 @@ TEST_F(WasiCryptoTest, Asymmetric) { "3o5BHs2kyvh9kuSthBY9XZnN\n" "-----END PRIVATE KEY-----\n"_u8}, {__WASI_SECRETKEY_ENCODING_PKCS8, - "308204a40201000282010100b681aab465cbe9ba9b0cad08276a497ccee1" - "7cca726155d6c6d11dbf9a0d12bbc65a91ca887d4d21e4319f492cd50166" - "907a0af10bafcf29f9fda9c72f189979c59bcd1034d4c4dfb35cbb1cc065" - "ad3a198514d93c57bb551189f8b8afd63a488911cfcea69b66bb2c82cdca" - "8f97c6d0461721369045f9f491f65c62d24b622d0ebd4d7c0c18bae0cacb" - "c0006bdf0aa7d946603208b0dcc84abf25d1c9b71ac740123b5453c40248" - "9b6d57b85593ced1dc241be5ebaca00c0cf345d389c8f5af1be4ee84d5ef" - "67b76d034378894f294b02658f32c4cd6cf028e8008b3c7743711d4bef4f" - "fa94dd2ad7a2f5419656e9e88f8ea40b0fb441ab1013de7c34f6b7b90203" - "0100010282010025db3ba773be1a5b44b21a6a2892d96f74123daa58936c" - "14c2e4b980f6d9635b63c7819a3b39927847372bcd27e97f02e1510f57f4" - "8ea13019d4ce14ace6335f98e7ba5f7435f62858b21175e34ab3e5eb4939" - "8bde026caf369621eb5d3dd895172984ab5ecfb93d75fc23e7b2654f2e00" - "7be7bdc3ab602fa4df2f46a84c4eaea3e89a1a12b1a2ea7aa17ce67427b1" - "9fc0a109b4c18f59b3b0ca3b0ca4bc6b5b73c0cb4607f7d5190fc39cc70d" - "fd844447f22f521928e4d6dcb9fcd9e90ba2bcc0999bec13945c5975834d" - "05907f6723faf27635775dce8bfaf2de2eb4ef0c4e7db8fb512a815209b1" - "e6c680119339bc45aa57f91fedb09394c70c0f44c8da0102818100ec8daf" - "36d801648caa03b4ae2447a47d14a72f5bf99455e2b3e2574c4329c0fb96" - "9d536732348ed455159871b248ede49e7c869a15efda1e10363529f2f37f" - "8885c3df6becf5d13073f7709c2eceeccd880175dae362c43c43b81cef77" - "46051aa6e28416f87a6e16a3e5539cf5359ede06d94924109dc698a547f6" - "c2f86c906102818100c5828daf83574c284226ed87d1497a23a5151c0314" - "b8a8225e84e71978f9247759879b00643a45fa30c954a333620974e6c178" - "d131646064992d6c7d08f69c121c213325b99f9c5e56194c1575f084bd9e" - "2021c80367286c9b803df8bdaa0b9366dddaf7d0d47b737e0ad59a622b70" - "094d6a08d7db84998e3ad6bc8e187e465902818100b129d00f1829badfaa" - "b949c99e7c55922434ae408924724a6e84d6f2d3de629d4a891b9ccf3a13" - "baacda96a898690c5a4be4617ee76d1283af8a99b8882f9bd568b1711448" - "8d3615bed493ef35135ec0f3da7c24ea65df286f6365b06738f6bba63f41" - "c45e667b2ad3a6fe3f305aef57c3f35a56fb66df0515cc56e060e4162102" - "818100be4f698e7093b142295bb10c6950cf6b8129b1f0160b4796b65481" - "093e53721bbe1cb7f1cca189c3e536596357a363514cf7a71e8ae5192c55" - "9c3b28cf76303412feba75e342343d81e0a63b178545a21fb6fe55e75182" - "d6038fb226f739de258dbcbfbc816ffbf3f0c327c6b648fc8f3a14ada8b0" - "5038559fc441b2f94ed3210281800f9c765e699b561c54d85c7c66f1ce33" - "7be4c2c692af507cbeb3c6a588bf8c54b23c989301a0ab06b331e3668b92" - "af860a5f50cf00f2d2508b15e6abfa1e9bc303b4ea2f990dc8d52ea360ae" - "d5f30cb1caa823ff6a1eb4f851223be1b7c21a1ee14a4248efeb8db8bdda" - "2e92ea0cde8e411ecda4caf87d92e4ad84163d5d99cd"_u8v}}, + "308204be020100300d06092a864886f70d0101010500048204a8308204a4" + "0201000282010100b681aab465cbe9ba9b0cad08276a497ccee17cca7261" + "55d6c6d11dbf9a0d12bbc65a91ca887d4d21e4319f492cd50166907a0af1" + "0bafcf29f9fda9c72f189979c59bcd1034d4c4dfb35cbb1cc065ad3a1985" + "14d93c57bb551189f8b8afd63a488911cfcea69b66bb2c82cdca8f97c6d0" + "461721369045f9f491f65c62d24b622d0ebd4d7c0c18bae0cacbc0006bdf" + "0aa7d946603208b0dcc84abf25d1c9b71ac740123b5453c402489b6d57b8" + "5593ced1dc241be5ebaca00c0cf345d389c8f5af1be4ee84d5ef67b76d03" + "4378894f294b02658f32c4cd6cf028e8008b3c7743711d4bef4ffa94dd2a" + "d7a2f5419656e9e88f8ea40b0fb441ab1013de7c34f6b7b9020301000102" + "82010025db3ba773be1a5b44b21a6a2892d96f74123daa58936c14c2e4b9" + "80f6d9635b63c7819a3b39927847372bcd27e97f02e1510f57f48ea13019" + "d4ce14ace6335f98e7ba5f7435f62858b21175e34ab3e5eb49398bde026c" + "af369621eb5d3dd895172984ab5ecfb93d75fc23e7b2654f2e007be7bdc3" + "ab602fa4df2f46a84c4eaea3e89a1a12b1a2ea7aa17ce67427b19fc0a109" + "b4c18f59b3b0ca3b0ca4bc6b5b73c0cb4607f7d5190fc39cc70dfd844447" + "f22f521928e4d6dcb9fcd9e90ba2bcc0999bec13945c5975834d05907f67" + "23faf27635775dce8bfaf2de2eb4ef0c4e7db8fb512a815209b1e6c68011" + "9339bc45aa57f91fedb09394c70c0f44c8da0102818100ec8daf36d80164" + "8caa03b4ae2447a47d14a72f5bf99455e2b3e2574c4329c0fb969d536732" + "348ed455159871b248ede49e7c869a15efda1e10363529f2f37f8885c3df" + "6becf5d13073f7709c2eceeccd880175dae362c43c43b81cef7746051aa6" + "e28416f87a6e16a3e5539cf5359ede06d94924109dc698a547f6c2f86c90" + "6102818100c5828daf83574c284226ed87d1497a23a5151c0314b8a8225e" + "84e71978f9247759879b00643a45fa30c954a333620974e6c178d1316460" + "64992d6c7d08f69c121c213325b99f9c5e56194c1575f084bd9e2021c803" + "67286c9b803df8bdaa0b9366dddaf7d0d47b737e0ad59a622b70094d6a08" + "d7db84998e3ad6bc8e187e465902818100b129d00f1829badfaab949c99e" + "7c55922434ae408924724a6e84d6f2d3de629d4a891b9ccf3a13baacda96" + "a898690c5a4be4617ee76d1283af8a99b8882f9bd568b17114488d3615be" + "d493ef35135ec0f3da7c24ea65df286f6365b06738f6bba63f41c45e667b" + "2ad3a6fe3f305aef57c3f35a56fb66df0515cc56e060e4162102818100be" + "4f698e7093b142295bb10c6950cf6b8129b1f0160b4796b65481093e5372" + "1bbe1cb7f1cca189c3e536596357a363514cf7a71e8ae5192c559c3b28cf" + "76303412feba75e342343d81e0a63b178545a21fb6fe55e75182d6038fb2" + "26f739de258dbcbfbc816ffbf3f0c327c6b648fc8f3a14ada8b05038559f" + "c441b2f94ed3210281800f9c765e699b561c54d85c7c66f1ce337be4c2c6" + "92af507cbeb3c6a588bf8c54b23c989301a0ab06b331e3668b92af860a5f" + "50cf00f2d2508b15e6abfa1e9bc303b4ea2f990dc8d52ea360aed5f30cb1" + "caa823ff6a1eb4f851223be1b7c21a1ee14a4248efeb8db8bdda2e92ea0c" + "de8e411ecda4caf87d92e4ad84163d5d99cd"_u8v}}, {}); RsaCheck( "3072"sv, @@ -366,65 +367,66 @@ TEST_F(WasiCryptoTest, Asymmetric) { "gpvbM5ETOLFocm1MM1hcOA==\n" "-----END PRIVATE KEY-----\n"_u8}, {__WASI_SECRETKEY_ENCODING_PKCS8, - "308206e202010002820181009ef69f30260218f18e869be9037146f04b4d" - "e8fff0210bfe7c2dccbaaacea37241632615188c4cae9c57fe21840ce103" - "5a4abd2e9ec034528cbb63f6711f8eb9d67eed6bc807c9c835cbcf64219d" - "aa774e2fd3e8255db326f790e8554c6696e1ecd73e763e4c736585884d4b" - "c2411e6990f8739af952c4f6094bdf7434d4a11a8d9d85408e74167eeaeb" - "7e8fe8f608e27adaefbc6293fdd3ad74387f5af125dfe1d1c15ac479f7ec" - "3bfd6d10f6c03c71efb677758133496cdb73540014b074fb6b83b13caec9" - "4b7866b6fe89469fd98fd55490389ef3a6b59c3edd156f92edadcdd16ff4" - "10c06258976e898c9ed508ad7f651fd03a68efc27c96725c7bbcd1d54c10" - "73a108b8d95987810e0c8cf3f8f9b7f65a1cbae160555ab7f1d70d4ad09e" - "497d371ae2fbe52377b014fabd79ac3de1218e1431c659062026dfba10a4" - "d533c78f643c51139402f6736e0b296e1aed6186951ab39cf9e581b2a960" - "473cd56c5fff709f72459cb43f3b104954e89c01a93572c7293a41e37f21" - "cf96aec1f35f0203010001028201804ac3e40b5955133649bab609da3ca5" - "08cfe24cfc538cb77f7218787a336c0d23e7ed223439df83117d2745b7ad" - "cc00e8fac6bc43f9169d8555fbad0074244b94cc75d6652327c6980bf558" - "0dd861b793758ab9382e9aeb7020705f55ff21214611870b31c20b631b14" - "bb0edfdaf595c041171a0881cb9427c4279369ac8f75566fd4ee9f7660d3" - "53ce5a04a4db051d18a87fe0d1d1eb992ebe1e339472c988eecbdf43f9d4" - "ff28c44b52dec163ccf6a10005ea19b232d50e06093030f98a24fa7dc412" - "c71ffdb98dc63db77f8f33dbd3df192b2771ab705f641c7778474de52ddc" - "6ebdfd01fc0bc795952987061c7fa11378a4b51fdb7508823667464a1e5b" - "f60b8672dd91f4eb22053cd8c923dc0225cabd4271f548ddca7f525a93b8" - "974f4c42acead413837699ca7c9c5dc7860799674bcf096fcd06fc079998" - "1d0c8ad923b9b9686abe319af624eaab7078150cc0ba282bcca6425137c2" - "277ba202a4129de2d97bf66ad105b5d62bf842108c121bda3b0bc08f57a3" - "6b9d931f815314aa010281c100cd6a02b434ac9f4f02522bd6a0b77072f5" - "e40975a285160ce39130498dd9933d582ec8e9ecd346a17d263b35966836" - "ba39f9b629242fda9070de855c47b9ecfc8520ff61923febf5699a561b1e" - "f03d4bb775bb8c7e99edf0f2946b36cb268045f376cc5d93792ca8a3116e" - "075aa3b95ac8d4c302d4266d28baf49a6e73287fab027af8a20dedfc05ed" - "a55c67c98c2b23f61be6194a8f74509027d73a2235da3f6d2abce60d0dd5" - "582790092d9cb07c7ed10c61dde3e187e464e81d74ebf61ee10281c100c6" - "1c33b8e3d8d29a3641189dfed84e78cdca674360bc26ed40850ca1e696e6" - "3395e347bb242ee7763e852c8d84a45156d771b922b26c33cacd34b0e603" - "7860fd640862137034c8424b7d5bbd19fd4e092f9592693c857a56585a0f" - "0c0023d98480dec59f9ff9b340f365418a7138d244d702ca7e23f74d647b" - "ac6499c74744346e76ab7c2cfc52e41ea853b8bd51fbd4ca49bed8550ffc" - "07ba11d94ff3694570f082b90a54e0bc5702b9536653b557fcba9a5db01a" - "f4826c9714723e933d9a3f0281c10085d44d92aec6d0bc1f1cfe26c56afe" - "3e47e99c28220c67435a785b67709d928a630b8826afff2e834410467f31" - "511066e022cb059ee7f69428953179dfd94887750cfc95cb3d0e3443eb23" - "b263c3cd8ba9297159a59a1025ed45b95c679adbf3b71d6d2482526e4028" - "8dd08bc607e959368337d27df9d320b83d68e810eb0fd290b92188235f2b" - "d588f135750120eb727083d8b41d99bd0448074cf83915b0eda5e8344e05" - "af3f9241a45bf675a19e5ad94421f6f8e315303e75e3cb2b789cc10281bf" - "5c16bff431f597f017482b29464d462ce17c34841d7358a4f058e88659a9" - "cb582f54770386ce46c9b046376f9138d0968d8f4f7fd1707aa2cac0b37a" - "3822bcf30c8cd90a301e58f8781ecd86198ea5b79f66e7a8037a08641aad" - "c250d1bfd85cbf8ace52650aad4883db8d9bcd059cb86339e8e6b9d13b28" - "7a54a86ed3334d8111d817dea10aa97d60c2de2801d91a36cfd177e517a5" - "568240fc0b081f1dd029afa31460b913be78b3cb71f91ef02cf64dc773c5" - "68c23fde3a5c46becfb2d30281c073881b2438f5eb61889314cadf68ad57" - "d5d3896e420bce3a2d9546ef1941eb29840dedb8156294aabb1bd1271f61" - "07cbdff1a4eaa5e12ceabdba028af878a0a7f45bc53200a43762feff8784" - "ffa1b4584333d4077bdceb9ec1fb90f34048639cff3c045ea9a339c03449" - "ae2bd609e0e5de03745069088810815b63e1a068eeeece766ff871365097" - "1f9db196de6fe768829f2217086aad9b7d6386e6161dd4be836b391fb6d7" - "361ad458681a72a346f3829bdb33911338b168726d4c33585c38"_u8v}}, + "308206fc020100300d06092a864886f70d0101010500048206e6308206e2" + "02010002820181009ef69f30260218f18e869be9037146f04b4de8fff021" + "0bfe7c2dccbaaacea37241632615188c4cae9c57fe21840ce1035a4abd2e" + "9ec034528cbb63f6711f8eb9d67eed6bc807c9c835cbcf64219daa774e2f" + "d3e8255db326f790e8554c6696e1ecd73e763e4c736585884d4bc2411e69" + "90f8739af952c4f6094bdf7434d4a11a8d9d85408e74167eeaeb7e8fe8f6" + "08e27adaefbc6293fdd3ad74387f5af125dfe1d1c15ac479f7ec3bfd6d10" + "f6c03c71efb677758133496cdb73540014b074fb6b83b13caec94b7866b6" + "fe89469fd98fd55490389ef3a6b59c3edd156f92edadcdd16ff410c06258" + "976e898c9ed508ad7f651fd03a68efc27c96725c7bbcd1d54c1073a108b8" + "d95987810e0c8cf3f8f9b7f65a1cbae160555ab7f1d70d4ad09e497d371a" + "e2fbe52377b014fabd79ac3de1218e1431c659062026dfba10a4d533c78f" + "643c51139402f6736e0b296e1aed6186951ab39cf9e581b2a960473cd56c" + "5fff709f72459cb43f3b104954e89c01a93572c7293a41e37f21cf96aec1" + "f35f0203010001028201804ac3e40b5955133649bab609da3ca508cfe24c" + "fc538cb77f7218787a336c0d23e7ed223439df83117d2745b7adcc00e8fa" + "c6bc43f9169d8555fbad0074244b94cc75d6652327c6980bf5580dd861b7" + "93758ab9382e9aeb7020705f55ff21214611870b31c20b631b14bb0edfda" + "f595c041171a0881cb9427c4279369ac8f75566fd4ee9f7660d353ce5a04" + "a4db051d18a87fe0d1d1eb992ebe1e339472c988eecbdf43f9d4ff28c44b" + "52dec163ccf6a10005ea19b232d50e06093030f98a24fa7dc412c71ffdb9" + "8dc63db77f8f33dbd3df192b2771ab705f641c7778474de52ddc6ebdfd01" + "fc0bc795952987061c7fa11378a4b51fdb7508823667464a1e5bf60b8672" + "dd91f4eb22053cd8c923dc0225cabd4271f548ddca7f525a93b8974f4c42" + "acead413837699ca7c9c5dc7860799674bcf096fcd06fc0799981d0c8ad9" + "23b9b9686abe319af624eaab7078150cc0ba282bcca6425137c2277ba202" + "a4129de2d97bf66ad105b5d62bf842108c121bda3b0bc08f57a36b9d931f" + "815314aa010281c100cd6a02b434ac9f4f02522bd6a0b77072f5e40975a2" + "85160ce39130498dd9933d582ec8e9ecd346a17d263b35966836ba39f9b6" + "29242fda9070de855c47b9ecfc8520ff61923febf5699a561b1ef03d4bb7" + "75bb8c7e99edf0f2946b36cb268045f376cc5d93792ca8a3116e075aa3b9" + "5ac8d4c302d4266d28baf49a6e73287fab027af8a20dedfc05eda55c67c9" + "8c2b23f61be6194a8f74509027d73a2235da3f6d2abce60d0dd558279009" + "2d9cb07c7ed10c61dde3e187e464e81d74ebf61ee10281c100c61c33b8e3" + "d8d29a3641189dfed84e78cdca674360bc26ed40850ca1e696e63395e347" + "bb242ee7763e852c8d84a45156d771b922b26c33cacd34b0e6037860fd64" + "0862137034c8424b7d5bbd19fd4e092f9592693c857a56585a0f0c0023d9" + "8480dec59f9ff9b340f365418a7138d244d702ca7e23f74d647bac6499c7" + "4744346e76ab7c2cfc52e41ea853b8bd51fbd4ca49bed8550ffc07ba11d9" + "4ff3694570f082b90a54e0bc5702b9536653b557fcba9a5db01af4826c97" + "14723e933d9a3f0281c10085d44d92aec6d0bc1f1cfe26c56afe3e47e99c" + "28220c67435a785b67709d928a630b8826afff2e834410467f31511066e0" + "22cb059ee7f69428953179dfd94887750cfc95cb3d0e3443eb23b263c3cd" + "8ba9297159a59a1025ed45b95c679adbf3b71d6d2482526e40288dd08bc6" + "07e959368337d27df9d320b83d68e810eb0fd290b92188235f2bd588f135" + "750120eb727083d8b41d99bd0448074cf83915b0eda5e8344e05af3f9241" + "a45bf675a19e5ad94421f6f8e315303e75e3cb2b789cc10281bf5c16bff4" + "31f597f017482b29464d462ce17c34841d7358a4f058e88659a9cb582f54" + "770386ce46c9b046376f9138d0968d8f4f7fd1707aa2cac0b37a3822bcf3" + "0c8cd90a301e58f8781ecd86198ea5b79f66e7a8037a08641aadc250d1bf" + "d85cbf8ace52650aad4883db8d9bcd059cb86339e8e6b9d13b287a54a86e" + "d3334d8111d817dea10aa97d60c2de2801d91a36cfd177e517a5568240fc" + "0b081f1dd029afa31460b913be78b3cb71f91ef02cf64dc773c568c23fde" + "3a5c46becfb2d30281c073881b2438f5eb61889314cadf68ad57d5d3896e" + "420bce3a2d9546ef1941eb29840dedb8156294aabb1bd1271f6107cbdff1" + "a4eaa5e12ceabdba028af878a0a7f45bc53200a43762feff8784ffa1b458" + "4333d4077bdceb9ec1fb90f34048639cff3c045ea9a339c03449ae2bd609" + "e0e5de03745069088810815b63e1a068eeeece766ff8713650971f9db196" + "de6fe768829f2217086aad9b7d6386e6161dd4be836b391fb6d7361ad458" + "681a72a346f3829bdb33911338b168726d4c33585c38"_u8v}}, {}); RsaCheck( "4096"sv, @@ -517,85 +519,86 @@ TEST_F(WasiCryptoTest, Asymmetric) { "Kah5KuUlkoLBFrMqKJCKqhvP72HZUp8=\n" "-----END PRIVATE KEY-----\n"_u8}, {__WASI_SECRETKEY_ENCODING_PKCS8, - "308209290201000282020100adc2251cad784f58e268ac62892dc48ce710" - "bf112baa773f6902cf5cd231fdddb50f5fc88b5bce0e8493b79280bd613c" - "eda874ac4e6802cd3031545648057e7af5001be5958c188df1cac68a54d5" - "5b2ca508d5413362b0aedf79ac360bcfaf0e2fb64e809dfddca7d9d876cc" - "ab9d5077a009b2f4e410655d6db3b7abd3a9c7586da159b80ab69cdd380a" - "a4d41c790ea2edbde014131a22318d843f1fb9680ae45018420b2dcd4457" - "07907fbf380e4a0906ef4ab52acafc8aae6ae54a5046855a530abb4d7bc5" - "474840468134952ed359e000e4692c1b02d83c37e0ed6bd99f52e608b534" - "38001d3c901e8053996bceeb2994e8f0f5d63206d0335e7faf3d460c555b" - "2514e81abd6980ed25359d7db8115f992b5c428c07351d62b3d182bf4f32" - "6fa65006ccc0dc98467add806aacbd1e037a9d2f1bcc99d8963cf7ab9313" - "c332efb41625ec1c383010f2a5b3e7a87786c946352c9f019116d7d1bdfd" - "14f6e5f72c165c39e327cb25cb50f71509ad3a2379e20aeb616f047bd913" - "3d0721760a20397d692514d1b849c6248e5a4620c481592bc75da34b299d" - "794a6ecbb51dbcf916290633c2000180d73b459b3afaa5a441906c53dea3" - "4b131ac8a4ec3f953a7d3dab91202c9971fa1cf8ed0f682fa5bbe5df6314" - "04695b34e51887881e3efd351a2cedb8e3428c049f851be325a80ae339f5" - "f0f7d5596a60a9ab38aeb143c8b1020301000102820201009fcb41b203dd" - "f6aab95ca5dbe06814afb7f7f09eebd752df1fc593c9bb0c7a791fffc988" - "7690b1092ce76414f90c309685c13bbb124818fb766c8730e9ff13782444" - "3b63818a5b327ec08aa1c0ae8db09afd6a91119e9af9d74ee00ebc01fced" - "40f7996e32ddb9c52b5424bcc8be5db80597a5da0cbaff5d527bca57dcc9" - "f027e47a54362ab411c267bd72241723455094eedf59d6cf5bced2646260" - "dc735040f35fbbace82c33c30d93d7c794d79f4279fc2a6a3db67b55565d" - "bc0c01933923fde68aa6114335f0be1b98cb30bee5636bdba6330a7ab4bd" - "037428086ffcacc6e201f412d7c5531dd53417b3ca0e1936af00e0d06864" - "d990e67bb7561964723b73cac9c0ec9c54a0c4f58d29aed239a32c941789" - "1b3a42023bb68788e697457d4956c5151d5d9c939fd441f032ab65d9fd0f" - "7d8fcac72581e608e068cbb10243ced24ad194ff2a729773c2df569ff254" - "c442937e6a3fc1db2ab635a6847e8bb07d6f6434d3247e2cc943a681e2e6" - "326badef64174062567c4f729b8221054cf3ba026d37062a51782843b5f4" - "0f94d97178a718c4a5f0129f9d743bbe7cd997a30a497a2b856dcdffb524" - "bc5af7cc403995bb48a4386e4c7c3cb5b4bdb6178db096136756db759703" - "76a29dff59fc9fcaba0da8c03ea53d81aa78f3007fdc0212a723a0d4fbc6" - "d31bc8e63e93277ed5c793dd4e5148ee7037bea4b2c9d20744f902820101" - "00e7a5aaff6c7c76b1a50188684664bc94995db54db31940979c11e4ae86" - "553a446f1b46c3653a240649fb647847f3fb8222321468d2516e9143ce56" - "ed4a1c58a61cee9c8ac9abc3a14a8004bca3aa548239ac343d5a566f2086" - "78e9241c83ab856c0ea5f79862d963f08e5e449f3737b30d62e9f12fe1ec" - "e23d615a2441e565aa27b88123a3cb87299061806de75b9c39a1870fcd33" - "8cde881c6b60ccd716b5841cc98b43f6f199fe34906d24f9df662c87973b" - "17740b6526dbd3be95649c8a709e37cbf2f33d442935f847d65c6aa0ae7f" - "e92352257d09f564f45dff127ffaac193764ad45e1e353767e2a3351faba" - "002e9055fc08643dddfaae68685ea1cf4f0282010100c006837d3d94e8cc" - "d5d1cb9d075435531add16ab32c727085969317a79ff2dd3942de70dd2e8" - "e85bc4d24f3e0f5ac5d10c086d7514b49ec8385da94a6fd5189eace884d4" - "1caafd2b251fccd47699e1ca1400265c2027e201d1563f54a215982ef4f2" - "56fc4509bb6cf4f2ddf988e1b018bc5f1c53e88996a8980565c8bfa49045" - "c0f005b09ace500998dff8866e4ff45b50ae2a833711c5b62cbf47239b29" - "8239bac53508e3b1e7500f280e9a3d2b03c5f543fb4d7cab31f3a6095760" - "ac12b5895b845464ffe15acf64478d25d8dbc99d66441ec9c844d3c5809f" - "7435abd7bd4c18e1512dd03f409fd8e25972d71793690a1162581ebff04b" - "e0c22c9d421ae7ff0282010100ca4e716ce9be399b23d496e11ab957c91f" - "b82b63548b355233479d449ea5486203f6fa7223b2074c46c87b126124bf" - "ff030661b4ba19cc4aa9c14741ad7bdc20171c7d32e8b64e004b244afec4" - "a13975121146fe5e2b269a6d56a3a69a109477cdcb6d3f33a300e0bb725a" - "f9eb633a0df21ef4d9634c18a9ed995c345628960568f346356e138e53ca" - "789cc55f4d2ebf5646b2922ff75351e42013465282cebc36b7fe1cb94a7f" - "86f7393b5913b0fb76e0643f835821ab91a862e10b6ff717210bcd071d83" - "397a91d344a6d89b95b4073246d64623df741710b82a6f5e24cf340641ce" - "775594d6084a701d42fcd1d027e082c5d57b2eb0a9710968bdde07cb1902" - "8201007084208dc504b8f8351d8e023fe61eacc863ddc188aa5afbb0704b" - "f6a8ff55d9d0cca8c357def32fba7f44c1677bd7c76b1691147682733b7c" - "939cfba9d5a26c6f827b3a5265fa9c4a4f7cd4cd7dd3687619b4606cd311" - "c1e0e879895cc3ec7d2f37c4b262bd961dfcd5462dcdfff8103668409006" - "0480cbdeefe2b9235a08d3667efdc6829efb14f487dff6dd326a4b0b5652" - "40ba86e6639d8aa2e3812a6c4ce95c5e7c0bf71543baedccc8a1a8cfb831" - "c398bfc99a0255f2e72c54cfded2925e968660fbf20f24d06c808f39a767" - "569c32b41ba606765416ced7e80efe05c44e1bd05658dc8740627416d78c" - "a31ad4047b3535cf26c6659f98074e4a0abbff02820100396f9627e90f47" - "64210dae75f1bbc148b43979a3f6c7a943bde4d27f142f35f53d5de83329" - "271ace78d0a952286820ff8ed87b3bc8e01ba9059d33568c7466d822960d" - "a81dd4259840061a20c160d2cefcd77d168b72b5c02e823c9a9e7808e486" - "97cca544016e901594e47886af97f4b8e0db41dfd32333bb5169a24dbbad" - "0f699fd43a589ccb23634593d2d89ad89b2992a3d2dadbbb22f7263d2ff8" - "170b0d322f6ff7b1b47970f124a05374a3d5c60ece415ac9ef30bcc7530b" - "5bb660b63f0e24a7c933464db5707ef9625ff083d3a2afde6e2f40bcc380" - "7b15a7023e28cd3c23ecbe66098bd0a029a8792ae5259282c116b32a2890" - "8aaa1bcfef61d9529f"_u8v}}, + "30820943020100300d06092a864886f70d01010105000482092d30820929" + "0201000282020100adc2251cad784f58e268ac62892dc48ce710bf112baa" + "773f6902cf5cd231fdddb50f5fc88b5bce0e8493b79280bd613ceda874ac" + "4e6802cd3031545648057e7af5001be5958c188df1cac68a54d55b2ca508" + "d5413362b0aedf79ac360bcfaf0e2fb64e809dfddca7d9d876ccab9d5077" + "a009b2f4e410655d6db3b7abd3a9c7586da159b80ab69cdd380aa4d41c79" + "0ea2edbde014131a22318d843f1fb9680ae45018420b2dcd445707907fbf" + "380e4a0906ef4ab52acafc8aae6ae54a5046855a530abb4d7bc547484046" + "8134952ed359e000e4692c1b02d83c37e0ed6bd99f52e608b53438001d3c" + "901e8053996bceeb2994e8f0f5d63206d0335e7faf3d460c555b2514e81a" + "bd6980ed25359d7db8115f992b5c428c07351d62b3d182bf4f326fa65006" + "ccc0dc98467add806aacbd1e037a9d2f1bcc99d8963cf7ab9313c332efb4" + "1625ec1c383010f2a5b3e7a87786c946352c9f019116d7d1bdfd14f6e5f7" + "2c165c39e327cb25cb50f71509ad3a2379e20aeb616f047bd9133d072176" + "0a20397d692514d1b849c6248e5a4620c481592bc75da34b299d794a6ecb" + "b51dbcf916290633c2000180d73b459b3afaa5a441906c53dea34b131ac8" + "a4ec3f953a7d3dab91202c9971fa1cf8ed0f682fa5bbe5df631404695b34" + "e51887881e3efd351a2cedb8e3428c049f851be325a80ae339f5f0f7d559" + "6a60a9ab38aeb143c8b1020301000102820201009fcb41b203ddf6aab95c" + "a5dbe06814afb7f7f09eebd752df1fc593c9bb0c7a791fffc9887690b109" + "2ce76414f90c309685c13bbb124818fb766c8730e9ff137824443b63818a" + "5b327ec08aa1c0ae8db09afd6a91119e9af9d74ee00ebc01fced40f7996e" + "32ddb9c52b5424bcc8be5db80597a5da0cbaff5d527bca57dcc9f027e47a" + "54362ab411c267bd72241723455094eedf59d6cf5bced2646260dc735040" + "f35fbbace82c33c30d93d7c794d79f4279fc2a6a3db67b55565dbc0c0193" + "3923fde68aa6114335f0be1b98cb30bee5636bdba6330a7ab4bd03742808" + "6ffcacc6e201f412d7c5531dd53417b3ca0e1936af00e0d06864d990e67b" + "b7561964723b73cac9c0ec9c54a0c4f58d29aed239a32c9417891b3a4202" + "3bb68788e697457d4956c5151d5d9c939fd441f032ab65d9fd0f7d8fcac7" + "2581e608e068cbb10243ced24ad194ff2a729773c2df569ff254c442937e" + "6a3fc1db2ab635a6847e8bb07d6f6434d3247e2cc943a681e2e6326badef" + "64174062567c4f729b8221054cf3ba026d37062a51782843b5f40f94d971" + "78a718c4a5f0129f9d743bbe7cd997a30a497a2b856dcdffb524bc5af7cc" + "403995bb48a4386e4c7c3cb5b4bdb6178db096136756db75970376a29dff" + "59fc9fcaba0da8c03ea53d81aa78f3007fdc0212a723a0d4fbc6d31bc8e6" + "3e93277ed5c793dd4e5148ee7037bea4b2c9d20744f90282010100e7a5aa" + "ff6c7c76b1a50188684664bc94995db54db31940979c11e4ae86553a446f" + "1b46c3653a240649fb647847f3fb8222321468d2516e9143ce56ed4a1c58" + "a61cee9c8ac9abc3a14a8004bca3aa548239ac343d5a566f208678e9241c" + "83ab856c0ea5f79862d963f08e5e449f3737b30d62e9f12fe1ece23d615a" + "2441e565aa27b88123a3cb87299061806de75b9c39a1870fcd338cde881c" + "6b60ccd716b5841cc98b43f6f199fe34906d24f9df662c87973b17740b65" + "26dbd3be95649c8a709e37cbf2f33d442935f847d65c6aa0ae7fe9235225" + "7d09f564f45dff127ffaac193764ad45e1e353767e2a3351faba002e9055" + "fc08643dddfaae68685ea1cf4f0282010100c006837d3d94e8ccd5d1cb9d" + "075435531add16ab32c727085969317a79ff2dd3942de70dd2e8e85bc4d2" + "4f3e0f5ac5d10c086d7514b49ec8385da94a6fd5189eace884d41caafd2b" + "251fccd47699e1ca1400265c2027e201d1563f54a215982ef4f256fc4509" + "bb6cf4f2ddf988e1b018bc5f1c53e88996a8980565c8bfa49045c0f005b0" + "9ace500998dff8866e4ff45b50ae2a833711c5b62cbf47239b298239bac5" + "3508e3b1e7500f280e9a3d2b03c5f543fb4d7cab31f3a6095760ac12b589" + "5b845464ffe15acf64478d25d8dbc99d66441ec9c844d3c5809f7435abd7" + "bd4c18e1512dd03f409fd8e25972d71793690a1162581ebff04be0c22c9d" + "421ae7ff0282010100ca4e716ce9be399b23d496e11ab957c91fb82b6354" + "8b355233479d449ea5486203f6fa7223b2074c46c87b126124bfff030661" + "b4ba19cc4aa9c14741ad7bdc20171c7d32e8b64e004b244afec4a1397512" + "1146fe5e2b269a6d56a3a69a109477cdcb6d3f33a300e0bb725af9eb633a" + "0df21ef4d9634c18a9ed995c345628960568f346356e138e53ca789cc55f" + "4d2ebf5646b2922ff75351e42013465282cebc36b7fe1cb94a7f86f7393b" + "5913b0fb76e0643f835821ab91a862e10b6ff717210bcd071d83397a91d3" + "44a6d89b95b4073246d64623df741710b82a6f5e24cf340641ce775594d6" + "084a701d42fcd1d027e082c5d57b2eb0a9710968bdde07cb190282010070" + "84208dc504b8f8351d8e023fe61eacc863ddc188aa5afbb0704bf6a8ff55" + "d9d0cca8c357def32fba7f44c1677bd7c76b1691147682733b7c939cfba9" + "d5a26c6f827b3a5265fa9c4a4f7cd4cd7dd3687619b4606cd311c1e0e879" + "895cc3ec7d2f37c4b262bd961dfcd5462dcdfff81036684090060480cbde" + "efe2b9235a08d3667efdc6829efb14f487dff6dd326a4b0b565240ba86e6" + "639d8aa2e3812a6c4ce95c5e7c0bf71543baedccc8a1a8cfb831c398bfc9" + "9a0255f2e72c54cfded2925e968660fbf20f24d06c808f39a767569c32b4" + "1ba606765416ced7e80efe05c44e1bd05658dc8740627416d78ca31ad404" + "7b3535cf26c6659f98074e4a0abbff02820100396f9627e90f4764210dae" + "75f1bbc148b43979a3f6c7a943bde4d27f142f35f53d5de83329271ace78" + "d0a952286820ff8ed87b3bc8e01ba9059d33568c7466d822960da81dd425" + "9840061a20c160d2cefcd77d168b72b5c02e823c9a9e7808e48697cca544" + "016e901594e47886af97f4b8e0db41dfd32333bb5169a24dbbad0f699fd4" + "3a589ccb23634593d2d89ad89b2992a3d2dadbbb22f7263d2ff8170b0d32" + "2f6ff7b1b47970f124a05374a3d5c60ece415ac9ef30bcc7530b5bb660b6" + "3f0e24a7c933464db5707ef9625ff083d3a2afde6e2f40bcc3807b15a702" + "3e28cd3c23ecbe66098bd0a029a8792ae5259282c116b32a28908aaa1bcf" + "ef61d9529f"_u8v}}, {}); } From c9461875d081600248c6052b8211cfe4eff41de0 Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Mon, 20 Jan 2025 22:31:51 +0800 Subject: [PATCH 527/623] [WASI-NN] Fix wrong function dispatch. (#3979) Signed-off-by: YiYing He --- plugins/wasi_nn/wasinn_whisper.h | 2 ++ plugins/wasi_nn/wasinnfunc.cpp | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/wasinn_whisper.h b/plugins/wasi_nn/wasinn_whisper.h index 2071ef29..c61ede75 100644 --- a/plugins/wasi_nn/wasinn_whisper.h +++ b/plugins/wasi_nn/wasinn_whisper.h @@ -113,4 +113,6 @@ Expect compute(WASINN::WasiNNEnvironment &Env, uint32_t ContextId) noexcept; Expect unload(WASINN::WasiNNEnvironment &Env, uint32_t GraphId) noexcept; +Expect finalizeExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; } // namespace WasmEdge::Host::WASINN::Whisper diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 32c02c03..4ce733c9 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -709,9 +709,9 @@ WasiNNFinalizeExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, switch (Env.NNContext[ContextId].getBackend()) { case WASINN::Backend::GGML: - return WASINN::GGML::unload(Env, ContextId); + return WASINN::GGML::finalizeExecCtx(Env, ContextId); case WASINN::Backend::Whisper: - return WASINN::Whisper::unload(Env, ContextId); + return WASINN::Whisper::finalizeExecCtx(Env, ContextId); default: spdlog::error("[WASI-NN] finalize_execution_context: Only GGML and "sv "Whisper backends support finalize_execution_context."sv); From c338fff893aa3e803522e02d5bbfbf1340ca3bcc Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Tue, 21 Jan 2025 12:27:34 +0800 Subject: [PATCH 528/623] [Misc] Use `string_view` literals suffix for spdlog functions Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_http/func.cpp | 8 +++-- plugins/wasi_nn/wasinn_onnx.cpp | 4 ++- plugins/wasi_nn/wasinn_openvino.cpp | 32 +++++++++-------- plugins/wasi_nn/wasinn_tf.cpp | 4 ++- plugins/wasi_nn/wasinn_tfl.cpp | 39 +++++++++++---------- plugins/wasi_nn/wasinn_whisper.cpp | 39 +++++++++++---------- plugins/wasi_nn/wasinnenv.cpp | 2 ++ plugins/wasi_nn/wasinnfunc.cpp | 16 ++++----- plugins/wasm_bpf/func-bpf-map-operate.cpp | 8 +++-- plugins/wasm_bpf/wasm-bpf.cpp | 6 ++-- plugins/wasmedge_process/processfunc.cpp | 10 +++--- plugins/wasmedge_stablediffusion/sd_env.cpp | 4 ++- test/plugins/wasi_nn/wasi_nn.cpp | 6 ++-- 13 files changed, 100 insertions(+), 78 deletions(-) diff --git a/plugins/wasi_http/func.cpp b/plugins/wasi_http/func.cpp index 245c9942..bb101044 100644 --- a/plugins/wasi_http/func.cpp +++ b/plugins/wasi_http/func.cpp @@ -10,19 +10,21 @@ #include #include +using namespace std::literals; + namespace WasmEdge { namespace Host { Expect WasiHttpPrint::body(std::string S) { - spdlog::info("[WASI-HTTP] print: {}", S); + spdlog::info("[WASI-HTTP] print: {}"sv, S); return {}; } Expect WasiHttpGet::body(std::string URI) { - spdlog::info("[WASI-HTTP] URI: {}", URI); + spdlog::info("[WASI-HTTP] URI: {}"sv, URI); cpr::Response Res = cpr::Get( cpr::Url{URI}, cpr::Authentication{"user", "pass", cpr::AuthMode::BASIC}); - spdlog::info("[WASI-HTTP] status: {}", Res.status_code); + spdlog::info("[WASI-HTTP] status: {}"sv, Res.status_code); return std::move(Res.text); } diff --git a/plugins/wasi_nn/wasinn_onnx.cpp b/plugins/wasi_nn/wasinn_onnx.cpp index 236d3d32..744cc263 100644 --- a/plugins/wasi_nn/wasinn_onnx.cpp +++ b/plugins/wasi_nn/wasinn_onnx.cpp @@ -4,10 +4,12 @@ #include "wasinn_onnx.h" #include "wasinnenv.h" +using namespace std::literals; + namespace WasmEdge::Host::WASINN::ONNX { namespace { Expect reportBackendNotSupported() noexcept { - spdlog::error("[WASI-NN] ONNX backend is not supported."); + spdlog::error("[WASI-NN] ONNX backend is not supported."sv); return WASINN::ErrNo::InvalidArgument; } } // namespace diff --git a/plugins/wasi_nn/wasinn_openvino.cpp b/plugins/wasi_nn/wasinn_openvino.cpp index 744bf7ed..729edd02 100644 --- a/plugins/wasi_nn/wasinn_openvino.cpp +++ b/plugins/wasi_nn/wasinn_openvino.cpp @@ -6,6 +6,8 @@ #include +using namespace std::literals; + namespace WasmEdge::Host::WASINN::OpenVINO { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO Expect load(WASINN::WasiNNEnvironment &Env, @@ -13,7 +15,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, WASINN::Device Device, uint32_t &GraphId) noexcept { // The graph builder length must be 2. if (Builders.size() != 2) { - spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 2", + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 2"sv, Builders.size()); return WASINN::ErrNo::InvalidArgument; } @@ -41,7 +43,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.OpenVINOModel = Env.OpenVINOCore.read_model( ModelString, GraphRef.OpenVINOIWeightTensor); } catch (const std::exception &EX) { - spdlog::error("[WASI-NN] Model Load Exception: {}", EX.what()); + spdlog::error("[WASI-NN] Model Load Exception: {}"sv, EX.what()); Env.deleteGraph(GId); return WASINN::ErrNo::RuntimeError; } @@ -57,7 +59,7 @@ Expect initExecCtx(WASINN::WasiNNEnvironment &Env, // Check the network and the execution network with the graph ID. auto &GraphRef = Env.NNGraph[GraphId].get(); if (GraphRef.OpenVINOModel == nullptr) { - spdlog::error("[WASI-NN] Model for Graph:{} is empty!", GraphId); + spdlog::error("[WASI-NN] Model for Graph:{} is empty!"sv, GraphId); return WASINN::ErrNo::MissingMemory; } // Create context. @@ -73,27 +75,27 @@ Expect setInput(WASINN::WasiNNEnvironment &Env, auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); if (GraphRef.OpenVINOModel == nullptr) { - spdlog::error("[WASI-NN] The founded openvino session is empty"); + spdlog::error("[WASI-NN] The founded openvino session is empty"sv); return WASINN::ErrNo::MissingMemory; } if (Tensor.Dimension.size() > 8) { - spdlog::error("[WASI-NN] Tensor dimension is out of range, expect " - "it under 8-dim, " - "but got {}-dim.", + spdlog::error("[WASI-NN] Tensor dimension is out of range, expect it under " + "8-dim, but got {}-dim."sv, Tensor.Dimension.size()); return WASINN::ErrNo::InvalidArgument; } if (Tensor.RType != WASINN::TensorType::F32) { spdlog::error( - "[WASI-NN] Only F32 inputs and outputs are supported for now."); + "[WASI-NN] Only F32 inputs and outputs are supported for now."sv); return WASINN::ErrNo::InvalidArgument; } // Check the input index. if (GraphRef.OpenVINOModel->inputs().size() <= Index) { - spdlog::error("[WASI-NN] The input index {} exceeds the inputs number {}.", - Index, GraphRef.OpenVINOModel->inputs().size()); + spdlog::error( + "[WASI-NN] The input index {} exceeds the inputs number {}."sv, Index, + GraphRef.OpenVINOModel->inputs().size()); return WASINN::ErrNo::InvalidArgument; } @@ -119,7 +121,7 @@ Expect setInput(WASINN::WasiNNEnvironment &Env, CxtRef.OpenVINOInferRequest = CompiledModel.create_infer_request(); CxtRef.OpenVINOInferRequest.set_input_tensor(Index, InputTensor); } catch (const std::exception &EX) { - spdlog::error("[WASI-NN] Set Input Exception: {}", EX.what()); + spdlog::error("[WASI-NN] Set Input Exception: {}"sv, EX.what()); return WASINN::ErrNo::RuntimeError; } return WASINN::ErrNo::Success; @@ -135,7 +137,7 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, // Check the output index. if (GraphRef.OpenVINOModel->outputs().size() <= Index) { spdlog::error( - "[WASI-NN] The output index {} exceeds the outputs number {}.", Index, + "[WASI-NN] The output index {} exceeds the outputs number {}."sv, Index, GraphRef.OpenVINOModel->outputs().size()); return WASINN::ErrNo::InvalidArgument; } @@ -147,7 +149,7 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, std::copy_n(static_cast(OutputTensor.data()), BytesWritten, OutBuffer.data()); } catch (const std::exception &EX) { - spdlog::error("[WASI-NN] Get Output Exception: {}", EX.what()); + spdlog::error("[WASI-NN] Get Output Exception: {}"sv, EX.what()); return WASINN::ErrNo::RuntimeError; } return WASINN::ErrNo::Success; @@ -159,7 +161,7 @@ Expect compute(WASINN::WasiNNEnvironment &Env, try { CxtRef.OpenVINOInferRequest.infer(); } catch (const std::exception &EX) { - spdlog::error("[WASI-NN] Infer Request Exception: {}", EX.what()); + spdlog::error("[WASI-NN] Infer Request Exception: {}"sv, EX.what()); return WASINN::ErrNo::RuntimeError; } return WASINN::ErrNo::Success; @@ -168,7 +170,7 @@ Expect compute(WASINN::WasiNNEnvironment &Env, namespace { Expect reportBackendNotSupported() noexcept { spdlog::error("[WASI-NN] OpenVINO backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."); + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINO\" to build it."sv); return WASINN::ErrNo::InvalidArgument; } } // namespace diff --git a/plugins/wasi_nn/wasinn_tf.cpp b/plugins/wasi_nn/wasinn_tf.cpp index 2d860429..caf4492b 100644 --- a/plugins/wasi_nn/wasinn_tf.cpp +++ b/plugins/wasi_nn/wasinn_tf.cpp @@ -4,10 +4,12 @@ #include "wasinn_tf.h" #include "wasinnenv.h" +using namespace std::literals; + namespace WasmEdge::Host::WASINN::Tensorflow { namespace { Expect reportBackendNotSupported() noexcept { - spdlog::error("[WASI-NN] Tensorflow backend is not supported."); + spdlog::error("[WASI-NN] Tensorflow backend is not supported."sv); return WASINN::ErrNo::InvalidArgument; } } // namespace diff --git a/plugins/wasi_nn/wasinn_tfl.cpp b/plugins/wasi_nn/wasinn_tfl.cpp index acdda8ee..51aab788 100644 --- a/plugins/wasi_nn/wasinn_tfl.cpp +++ b/plugins/wasi_nn/wasinn_tfl.cpp @@ -8,18 +8,20 @@ #include "tensorflow/lite/c/common.h" #endif +using namespace std::literals; + namespace WasmEdge::Host::WASINN::TensorflowLite { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_TFLITE Expect load(WASINN::WasiNNEnvironment &Env, Span> Builders, WASINN::Device Device, uint32_t &GraphId) noexcept { if ((Device != WASINN::Device::CPU)) { - spdlog::error("[WASI-NN] TensorflowLite Only support CPU target."); + spdlog::error("[WASI-NN] TensorflowLite Only support CPU target."sv); return WASINN::ErrNo::InvalidArgument; } // The graph builder length must be 1. if (Builders.size() != 1) { - spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1", + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 1"sv, Builders.size()); return WASINN::ErrNo::InvalidArgument; } @@ -32,7 +34,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.TFLiteMod = TfLiteModelCreate(GraphRef.TfLiteModData.data(), GraphRef.TfLiteModData.size()); if (unlikely(GraphRef.TFLiteMod == nullptr)) { - spdlog::error("[WASI-NN] Cannot import TFLite model"); + spdlog::error("[WASI-NN] Cannot import TFLite model"sv); Env.deleteGraph(GId); return WASINN::ErrNo::InvalidArgument; } @@ -48,7 +50,7 @@ Expect initExecCtx(WASINN::WasiNNEnvironment &Env, uint32_t &ContextId) noexcept { // Check the network and the execution network with the graph ID. if (Env.NNGraph[GraphId].get().TFLiteMod == nullptr) { - spdlog::error("[WASI-NN] Model for Graph:{} is missing!", GraphId); + spdlog::error("[WASI-NN] Model for Graph:{} is missing!"sv, GraphId); return WASINN::ErrNo::MissingMemory; } @@ -61,7 +63,7 @@ Expect initExecCtx(WASINN::WasiNNEnvironment &Env, CxtRef.TFLiteInterp = TfLiteInterpreterCreate(GraphRef.TFLiteMod, TFLiteOps); TfLiteInterpreterOptionsDelete(TFLiteOps); if (unlikely(CxtRef.TFLiteInterp == nullptr)) { - spdlog::error("[WASI-NN] Cannot create TFLite interpreter."); + spdlog::error("[WASI-NN] Cannot create TFLite interpreter."sv); Env.deleteContext(CId); return WASINN::ErrNo::Busy; } @@ -89,22 +91,23 @@ Expect setInput(WASINN::WasiNNEnvironment &Env, // Check the input data size. const auto HoldTensorByteSize = TfLiteTensorByteSize(HoldTensor); if (HoldTensorByteSize != Tensor.Tensor.size()) { - spdlog::error("[WASI-NN] Expect tensor byte size {}, but got {}", + spdlog::error("[WASI-NN] Expect tensor byte size {}, but got {}"sv, HoldTensorByteSize, Tensor.Tensor.size()); return WASINN::ErrNo::InvalidArgument; } // Check the input tensor dimensions. const auto HoldTensorNumDims = TfLiteTensorNumDims(HoldTensor); if (static_cast(HoldTensorNumDims) != Tensor.Dimension.size()) { - spdlog::error("[WASI-NN] Expect tensor number of dimensions {}, but got {}", - HoldTensorNumDims, Tensor.Dimension.size()); + spdlog::error( + "[WASI-NN] Expect tensor number of dimensions {}, but got {}"sv, + HoldTensorNumDims, Tensor.Dimension.size()); return WASINN::ErrNo::InvalidArgument; } for (uint32_t I = 0; I < Tensor.Dimension.size(); I++) { const auto HoldTensorDim = TfLiteTensorDim(HoldTensor, I); if (static_cast(HoldTensorDim) != Tensor.Dimension[I]) { - spdlog::error("[WASI-NN] Expect tensor dimension[{}] = {}, but got {}", I, - HoldTensorDim, Tensor.Dimension[I]); + spdlog::error("[WASI-NN] Expect tensor dimension[{}] = {}, but got {}"sv, + I, HoldTensorDim, Tensor.Dimension[I]); return WASINN::ErrNo::InvalidArgument; } } @@ -124,20 +127,20 @@ Expect setInput(WASINN::WasiNNEnvironment &Env, LiteType = WASINN::TensorType::I32; break; default: - spdlog::error("[WASI-NN] Unsupported TFLite type: {}", + spdlog::error("[WASI-NN] Unsupported TFLite type: {}"sv, TfLiteTypeGetName(Type)); return WASINN::ErrNo::InvalidArgument; } if (unlikely(LiteType != Tensor.RType)) { - spdlog::error("[WASI-NN] Expect tensor type {}, but got {}", LiteType, + spdlog::error("[WASI-NN] Expect tensor type {}, but got {}"sv, LiteType, Tensor.RType); return WASINN::ErrNo::InvalidArgument; } TfLiteStatus Stat = TfLiteTensorCopyFromBuffer( HoldTensor, Tensor.Tensor.data(), Tensor.Tensor.size()); if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { - spdlog::error("[WASI-NN] Copy tensor memory failed"); + spdlog::error("[WASI-NN] Copy tensor memory failed"sv); return WASINN::ErrNo::Busy; } @@ -152,7 +155,7 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, uint32_t OutCnt = TfLiteInterpreterGetOutputTensorCount(CxtRef.TFLiteInterp); if (Index >= OutCnt) { spdlog::error("[WASI-NN] Invalid index id {} for the input, only {} " - "outputs are allowed", + "outputs are allowed"sv, Index, OutCnt); return WASINN::ErrNo::InvalidArgument; } @@ -161,7 +164,7 @@ Expect getOutput(WASINN::WasiNNEnvironment &Env, const uint32_t BytesToWrite = TfLiteTensorByteSize(HoldTensor); // Check out buffer max size. if (OutBuffer.size() < BytesToWrite) { - spdlog::error("[WASI-NN] Expect out buffer max size {}, but got {}", + spdlog::error("[WASI-NN] Expect out buffer max size {}, but got {}"sv, BytesToWrite, OutBuffer.size()); return WASINN::ErrNo::InvalidArgument; } @@ -175,12 +178,12 @@ Expect compute(WASINN::WasiNNEnvironment &Env, auto &CxtRef = Env.NNContext[ContextId].get(); // Run session if (unlikely(CxtRef.TFLiteInterp == nullptr)) { - spdlog::error("[WASI-NN] Tensorflow Lite context empty"); + spdlog::error("[WASI-NN] Tensorflow Lite context empty"sv); return WASINN::ErrNo::MissingMemory; } TfLiteStatus Stat = TfLiteInterpreterInvoke(CxtRef.TFLiteInterp); if (unlikely(Stat != TfLiteStatus::kTfLiteOk)) { - spdlog::error("[WASI-NN] Invocation failed."); + spdlog::error("[WASI-NN] Invocation failed."sv); return WASINN::ErrNo::Busy; } return WASINN::ErrNo::Success; @@ -190,7 +193,7 @@ namespace { Expect reportBackendNotSupported() noexcept { spdlog::error( "[WASI-NN] TensorflowLite backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."); + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"Tensorflowlite\" to build it."sv); return WASINN::ErrNo::InvalidArgument; } } // namespace diff --git a/plugins/wasi_nn/wasinn_whisper.cpp b/plugins/wasi_nn/wasinn_whisper.cpp index 7bfefcd2..759e147f 100644 --- a/plugins/wasi_nn/wasinn_whisper.cpp +++ b/plugins/wasi_nn/wasinn_whisper.cpp @@ -14,6 +14,8 @@ #include #endif +using namespace std::literals; + namespace WasmEdge::Host::WASINN::Whisper { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER @@ -484,30 +486,31 @@ void setWhisperParams(Context &CxtRef) noexcept { WParam.beam_search.beam_size = ConfigRef.BeamSize; if (ConfigRef.EnableDebugLog) { - spdlog::info("[WASI-NN][Debug] Whisper backend: Config: threads: {}", + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: threads: {}"sv, ConfigRef.ThreadsNum); - spdlog::info("[WASI-NN][Debug] Whisper backend: Config: processors: {}", + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: processors: {}"sv, ConfigRef.ProcessorsNum); - spdlog::info("[WASI-NN][Debug] Whisper backend: Config: max-context: {}", + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: max-context: {}"sv, ConfigRef.MaxTokenContext); - spdlog::info("[WASI-NN][Debug] Whisper backend: Config: offset-t: {}", + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: offset-t: {}"sv, ConfigRef.TimeOffsetMS); - spdlog::info("[WASI-NN][Debug] Whisper backend: Config: duration: {}", + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: duration: {}"sv, ConfigRef.DurationMS); - spdlog::info("[WASI-NN][Debug] Whisper backend: Config: max-len: {}", + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: max-len: {}"sv, ConfigRef.MaxSegmentLength); - spdlog::info("[WASI-NN][Debug] Whisper backend: Config: split-on-word : {}", - ConfigRef.SplitOnWord); - spdlog::info("[WASI-NN][Debug] Whisper backend: Config: translate: {}", + spdlog::info( + "[WASI-NN][Debug] Whisper backend: Config: split-on-word : {}"sv, + ConfigRef.SplitOnWord); + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: translate: {}"sv, ConfigRef.Translate); - spdlog::info("[WASI-NN][Debug] Whisper backend: Config: language: \"{}\"", + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: language: \"{}\""sv, ConfigRef.SpokenLanguage); spdlog::info( - "[WASI-NN][Debug] Whisper backend: Config: detect-language: {}", + "[WASI-NN][Debug] Whisper backend: Config: detect-language: {}"sv, ConfigRef.DetectLanguage); - spdlog::info("[WASI-NN][Debug] Whisper backend: Config: temperature: {}", + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: temperature: {}"sv, ConfigRef.Temperature); - spdlog::info("[WASI-NN][Debug] Whisper backend: Config: prompt: \"{}\"", + spdlog::info("[WASI-NN][Debug] Whisper backend: Config: prompt: \"{}\""sv, ConfigRef.InitialPrompt); } } @@ -969,20 +972,20 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } setWhisperParams(CxtRef); if (CxtRef.WhisperConfig.EnableDebugLog) { - spdlog::info( - "[WASI-NN][Debug] Whisper backend: found Metadata, processing...Done"sv); + spdlog::info("[WASI-NN][Debug] Whisper backend: found Metadata, " + "processing...Done"sv); } return ErrNo::Success; } if (Tensor.Dimension.size() != 2) { spdlog::error("[WASI-NN] Tensor dimension is out of range, expect 2-dim, " - "but got {}-dim.", + "but got {}-dim."sv, Tensor.Dimension.size()); return WASINN::ErrNo::InvalidArgument; } if (Tensor.Dimension[0] != 1) { - spdlog::error("[WASI-NN] Only 1 channel supported for now."); + spdlog::error("[WASI-NN] Only 1 channel supported for now."sv); return WASINN::ErrNo::InvalidArgument; } @@ -1019,7 +1022,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, // Check out buffer max size. if (OutBuffer.size() < CxtRef.Outputs.length()) { - spdlog::error("[WASI-NN] Expect out buffer max size {}, but got {}", + spdlog::error("[WASI-NN] Expect out buffer max size {}, but got {}"sv, CxtRef.Outputs.length(), OutBuffer.size()); return WASINN::ErrNo::InvalidArgument; } diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index c6b91ac6..275c14d1 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -11,6 +11,8 @@ #include #endif +using namespace std::literals; + namespace WasmEdge { namespace Host { diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 4ce733c9..f1171393 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -83,7 +83,7 @@ WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, spdlog::error("[WASI-NN] Unknown device {}."sv, Target); return WASINN::ErrNo::InvalidArgument; } - spdlog::debug("[WASI-NN] Using device: {}.", Device); + spdlog::debug("[WASI-NN] Using device: {}."sv, Device); // Builders' Layout: // | builder-0 | builder-0 len | builder-1 | builder-1 len | ... @@ -259,17 +259,15 @@ WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, #endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNGraph.size() <= GraphId || Env.NNGraph[GraphId].isFinalized()) { - spdlog::error( - "[WASI-NN] init_execution_context: Graph ID {} does not exist or is "sv - "unloaded."sv, - GraphId); + spdlog::error("[WASI-NN] init_execution_context: Graph ID {} does not " + "exist or is unloaded."sv, + GraphId); return WASINN::ErrNo::InvalidArgument; } if (!Env.NNGraph[GraphId].isReady()) { - spdlog::error( - "[WASI-NN] init_execution_context: Graph ID {} is invalid. Please "sv - "reload or unload this graph."sv, - GraphId); + spdlog::error("[WASI-NN] init_execution_context: Graph ID {} is invalid. " + "Please reload or unload this graph."sv, + GraphId); return WASINN::ErrNo::InvalidArgument; } diff --git a/plugins/wasm_bpf/func-bpf-map-operate.cpp b/plugins/wasm_bpf/func-bpf-map-operate.cpp index 73b1cd6f..86ba7bb4 100644 --- a/plugins/wasm_bpf/func-bpf-map-operate.cpp +++ b/plugins/wasm_bpf/func-bpf-map-operate.cpp @@ -8,6 +8,8 @@ extern "C" { #include } +using namespace std::literals; + namespace WasmEdge { namespace Host { @@ -32,8 +34,8 @@ BpfMapOperate::body(const WasmEdge::Runtime::CallingFrame &Frame, int32_t fd, uint32_t info_len = sizeof(map_info); int32_t err; if ((err = bpf_map_get_info_by_fd(fd, &map_info, &info_len)) != 0) { - spdlog::debug("[WasmEdge Wasm_bpf] Invalid map fd found: fd={},err={}", fd, - err); + spdlog::debug("[WasmEdge Wasm_bpf] Invalid map fd found: fd={},err={}"sv, + fd, err); // Invalid map fd return err; } @@ -61,7 +63,7 @@ BpfMapOperate::body(const WasmEdge::Runtime::CallingFrame &Frame, int32_t fd, return bpf_map_delete_elem_flags(fd, key_ptr, flags); } default: // More syscall commands can be allowed here - spdlog::debug("[WasmEdge Wasm_bpf] Invalid map operation", cmd); + spdlog::debug("[WasmEdge Wasm_bpf] Invalid map operation"sv, cmd); return -EINVAL; } } diff --git a/plugins/wasm_bpf/wasm-bpf.cpp b/plugins/wasm_bpf/wasm-bpf.cpp index 0c45fc08..490dbbdd 100644 --- a/plugins/wasm_bpf/wasm-bpf.cpp +++ b/plugins/wasm_bpf/wasm-bpf.cpp @@ -16,6 +16,8 @@ extern "C" { #include } +using namespace std::literals; + static int32_t bpf_buffer_sample(void *ctx, void *data, size_t size); static int32_t libbpf_print_fn(enum libbpf_print_level level, const char *format, va_list args) { @@ -23,7 +25,7 @@ static int32_t libbpf_print_fn(enum libbpf_print_level level, return 0; char buf[DEBUG_PRINT_BUFFER_SIZE]; int32_t len = vsnprintf(buf, sizeof(buf), format, args); - spdlog::debug("[WasmEdge Wasm_bpf] {}", buf); + spdlog::debug("[WasmEdge Wasm_bpf] {}"sv, buf); return len; } @@ -195,7 +197,7 @@ int32_t wasm_bpf_program::attach_bpf_program(const char *name, bpf_object *o = obj.get(); bpf_program *prog = bpf_object__find_program_by_name(o, name); if (!prog) { - spdlog::error("[WasmEdge Wasm_bpf] get prog {} fail", name); + spdlog::error("[WasmEdge Wasm_bpf] get prog {} fail"sv, name); return -1; } // TODO: attach dynamically base on bpf_program__section_name(prog) and diff --git a/plugins/wasmedge_process/processfunc.cpp b/plugins/wasmedge_process/processfunc.cpp index ded6e971..d8eee10a 100644 --- a/plugins/wasmedge_process/processfunc.cpp +++ b/plugins/wasmedge_process/processfunc.cpp @@ -19,6 +19,8 @@ #elif WASMEDGE_OS_WINDOWS #endif +using namespace std::literals; + namespace WasmEdge { namespace Host { @@ -187,13 +189,13 @@ Expect WasmEdgeProcessRun::body(const Runtime::CallingFrame &) { #endif switch (errno) { case EACCES: - spdlog::error("Permission denied."); + spdlog::error("Permission denied."sv); break; case ENOENT: - spdlog::error("Command not found."); + spdlog::error("Command not found."sv); break; default: - spdlog::error("Unknown error."); + spdlog::error("Unknown error."sv); break; } _exit(-1); @@ -297,7 +299,7 @@ Expect WasmEdgeProcessRun::body(const Runtime::CallingFrame &) { Env.TimeOut = Env.DEFAULT_TIMEOUT; return Env.ExitCode; #elif WASMEDGE_OS_WINDOWS - spdlog::error("wasmedge_process doesn't support windows now."); + spdlog::error("wasmedge_process doesn't support windows now."sv); return Unexpect(ErrCode::Value::HostFuncError); #endif } diff --git a/plugins/wasmedge_stablediffusion/sd_env.cpp b/plugins/wasmedge_stablediffusion/sd_env.cpp index f2aa0ca2..438857e4 100644 --- a/plugins/wasmedge_stablediffusion/sd_env.cpp +++ b/plugins/wasmedge_stablediffusion/sd_env.cpp @@ -4,6 +4,8 @@ #include "sd_env.h" #include "sd_module.h" +using namespace std::literals; + namespace WasmEdge { namespace Host { namespace { @@ -79,7 +81,7 @@ void SBLog(enum sd_log_level_t Level, const char *Log, void *) { break; } - spdlog::info("[WasmEdge-StableDiffusion] SD-log: [{}] {}", LevelStr, Log); + spdlog::info("[WasmEdge-StableDiffusion] SD-log: [{}] {}"sv, LevelStr, Log); } } // namespace StableDiffusion diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index e994b75d..bf714876 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -16,6 +16,7 @@ #include #include +using namespace std::literals; using WasmEdge::Host::WASINN::Backend; using WasmEdge::Host::WASINN::Device; using WasmEdge::Host::WASINN::ErrNo; @@ -42,7 +43,6 @@ inline std::unique_ptr dynamicPointerCast(std::unique_ptr &&R) noexcept { std::unique_ptr createModule(std::string_view NNRPCURI = "") { - using namespace std::literals::string_view_literals; WasmEdge::Plugin::Plugin::load( std::filesystem::u8path("../../../plugins/wasi_nn/" WASMEDGE_LIB_PREFIX "wasmedgePluginWasiNN" WASMEDGE_LIB_EXTENSION)); @@ -933,7 +933,7 @@ TEST(WasiNNTest, TFLiteBackend) { std::vector WeightRead = readEntireFile("./wasinn_tflite_fixtures/" "lite-model_aiy_vision_classifier_birds_V1_3.tflite"); - spdlog::info("Read {}", TensorData.size()); + spdlog::info("Read {}"sv, TensorData.size()); std::vector TensorDim{1, 224, 224, 3}; uint32_t BuilderPtr = UINT32_C(0); uint32_t LoadEntryPtr = UINT32_C(0); @@ -3090,4 +3090,4 @@ TEST(WasiNNTest, MLXBackend) { EXPECT_GE(BytesWritten, 50); } } -#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX \ No newline at end of file +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX From e4e8b2d4dd1fa727b88a8db144e7ae6ca0e40ffa Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 23 Jan 2025 17:14:18 +0800 Subject: [PATCH 529/623] [WASI-NN] ggml: support text to speech Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 597 ++++++++++++++++++++++++++++++++ plugins/wasi_nn/wasinn_ggml.h | 6 + 2 files changed, 603 insertions(+) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index c5523f14..a87e629f 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -19,7 +19,9 @@ #include #include +#include #include +#include #include #include #endif @@ -89,6 +91,13 @@ void setupSamplerParams(Graph &GraphRef, Sampling.penalty_freq = static_cast(GraphRef.FrequencyPenalty); Sampling.grammar = GraphRef.Grammar; Sampling.seed = static_cast(GraphRef.Seed); + + if (GraphRef.TextToSpeech) { + Sampling.top_k = 4; + Sampling.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + }; + } } // Setup llama common params from graph. @@ -132,6 +141,10 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, // warmup: bool // split-mode: string, {none,layer,row} // mmproj: string + // TTS parameters: + // tts: bool + // model-vocoder: string + // tts-output-file: string // Context parameters (used by the llama context): // ctx-size: int64_t // batch-size: int64_t @@ -273,6 +286,34 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, GraphRef.MMProjModelPath = MMProjModelPath; } + // The TTS parameters. + if (Doc.at_key("tts").error() == simdjson::SUCCESS) { + auto Err = Doc["tts"].get().get(GraphRef.TextToSpeech); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the tts option."sv) + } + } + if (Doc.at_key("model-vocoder").error() == simdjson::SUCCESS) { + std::string_view VocoderModelPath; + auto Err = + Doc["model-vocoder"].get().get(VocoderModelPath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the model-vocoder option."sv) + } + GraphRef.VocoderModelPath = VocoderModelPath; + } + if (Doc.at_key("tts-output-file").error() == simdjson::SUCCESS) { + std::string_view TTSOutputFilePath; + auto Err = + Doc["tts-output-file"].get().get(TTSOutputFilePath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the tts-output-file option."sv) + } + GraphRef.TTSOutputFilePath = TTSOutputFilePath; + } + // The context parameters. if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { auto Err = Doc["ctx-size"].get().get(GraphRef.CtxSize); @@ -499,6 +540,263 @@ extractBase64ImagePayload(std::string &Prompt, return std::make_pair(ImageBytes, ImageType); } +// TTS function to process the prompt text. +const std::vector TTSVoiceData = llama_tokens{ + 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585, + 152460, 153375, 151670, 198, 74455, 155808, 151669, 151799, 151873, + 151863, 152446, 152372, 152204, 152728, 152229, 152470, 151970, 153413, + 152419, 153334, 153289, 153374, 153199, 152040, 153260, 152721, 152680, + 153297, 152419, 153248, 152400, 152691, 153368, 153437, 151670, 198, + 1722, 155828, 151669, 152607, 152256, 152991, 152299, 152688, 153163, + 153016, 152789, 153198, 152712, 151911, 153107, 152623, 152170, 152395, + 152852, 152207, 152461, 153321, 153309, 151750, 152137, 153340, 152573, + 152267, 153347, 151789, 152681, 153339, 151992, 152512, 151751, 152179, + 153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904, 152311, + 151670, 198, 1499, 155791, 151669, 152276, 152454, 153354, 152544, + 153204, 153272, 152708, 153433, 152319, 153226, 153043, 152325, 153267, + 152622, 151670, 198, 4250, 155797, 151669, 153454, 153342, 151989, + 152458, 153420, 152303, 152271, 152827, 153036, 153196, 151708, 153263, + 152561, 153207, 152213, 152112, 153204, 151722, 152542, 151670, 198, + 19789, 155796, 151669, 153353, 153182, 152345, 152471, 152477, 153014, + 152002, 152191, 151734, 152312, 152810, 152237, 153224, 153169, 153224, + 152244, 153387, 153404, 151670, 198, 16069, 155811, 151669, 152265, + 151946, 151808, 152412, 152363, 152305, 153156, 152733, 152810, 153157, + 152016, 152100, 152069, 153234, 152317, 152589, 152707, 153121, 153341, + 152159, 152114, 153156, 153001, 153504, 153376, 152272, 152433, 152325, + 151941, 151670, 198, 285, 155788, 151669, 152238, 152255, 153427, + 152318, 153009, 152381, 152474, 152680, 152157, 153255, 152324, 151682, + 151670, 198, 32955, 155804, 151669, 153490, 153419, 152364, 152405, + 152682, 152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488, + 153070, 151883, 152890, 152489, 153144, 153375, 152358, 151685, 152494, + 152117, 152740, 151670, 198, 37448, 480, 155840, 151669, 151902, + 152720, 153377, 152027, 152378, 152821, 153207, 153459, 153028, 153068, + 152507, 153255, 152158, 152921, 151958, 152609, 152748, 152822, 152286, + 151714, 152730, 152377, 152353, 152470, 152606, 152162, 152186, 153071, + 152244, 153118, 153375, 153018, 152712, 153098, 152976, 152336, 151843, + 153202, 152297, 151736, 153380, 153502, 152702, 152115, 153181, 152735, + 153277, 153457, 152393, 153112, 152595, 151670, 198, 19098, 155808, + 151669, 152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239, + 153163, 152922, 153402, 152034, 152591, 153438, 152215, 151673, 152005, + 151785, 152642, 151924, 153278, 151805, 151974, 153482, 152718, 152862, + 153347, 151670, 198, 72, 155780, 151669, 151795, 152111, 152746, + 152377, 153471, 152309, 151670, 198, 19016, 155788, 151669, 153181, + 152271, 152190, 152842, 152224, 152701, 152939, 152536, 152091, 151815, + 152733, 151672, 151670, 198, 14689, 155788, 151669, 152291, 152072, + 152942, 151734, 153042, 153504, 152589, 153333, 151839, 151941, 153038, + 153180, 151670, 198, 36996, 8303, 155832, 151669, 152231, 152256, + 152835, 152801, 152985, 153400, 152393, 152818, 152765, 152249, 152600, + 151699, 152302, 152752, 153018, 153009, 151992, 153054, 152847, 153354, + 153228, 152662, 153355, 152532, 153393, 151782, 152458, 152048, 152757, + 152428, 153195, 151906, 153006, 153178, 153250, 152331, 152284, 152780, + 153138, 153319, 151980, 153142, 152418, 152228, 152733, 151670, 198, + 9096, 155801, 151669, 151698, 153321, 152217, 153039, 152935, 153400, + 152122, 152531, 153106, 152169, 152892, 152957, 151851, 152427, 152826, + 152451, 151851, 152901, 152885, 152594, 153446, 153080, 151670, 198, + 14689, 155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191, + 151673, 151690, 151698, 152714, 152846, 152981, 153171, 153384, 153364, + 153188, 153246, 151670, 198, 1055, 155779, 151669, 151869, 152388, + 152711, 153334, 151736, 151670, 198, 1782, 155780, 151669, 153483, + 153240, 152241, 152558, 152697, 153046, 151670, 198, 5804, 1363, + 155820, 151669, 152941, 152764, 152605, 153034, 153434, 153372, 153347, + 151887, 152453, 152758, 152133, 152510, 152694, 152431, 152321, 153088, + 152676, 152223, 152581, 152459, 152015, 152502, 153063, 152712, 153294, + 153451, 153032, 152903, 152859, 152989, 151748, 152669, 152661, 152650, + 152409, 151861, 151670, 198, 300, 7973, 155828, 151669, 153095, + 152469, 152988, 152894, 151819, 152391, 153019, 152058, 153062, 153230, + 151826, 152112, 152306, 152264, 152769, 153390, 152384, 152435, 152790, + 153393, 152983, 152540, 152252, 152034, 153107, 152540, 151919, 151893, + 152558, 152817, 152946, 152956, 152129, 152715, 153131, 153490, 151734, + 152271, 152707, 151734, 153321, 152450, 151670, 198, 8088, 155792, + 151669, 152452, 153497, 153353, 152679, 152533, 152382, 152374, 152611, + 153341, 153163, 152285, 153411, 152495, 153141, 152320, 151670, 198, + 1199, 155781, 151669, 151764, 152360, 153295, 152634, 153342, 152199, + 152271, 151670, 198, 43366, 155799, 151669, 152308, 151682, 152889, + 152016, 152385, 152629, 152495, 151826, 153321, 152958, 152180, 151886, + 153432, 152922, 152128, 153024, 153040, 152593, 152287, 151677, 151670, + 198, 53660, 155808, 151669, 151727, 152092, 152680, 153331, 151699, + 152316, 152938, 152289, 152433, 153384, 151781, 153137, 153259, 152175, + 153213, 152291, 151869, 152691, 152489, 151941, 152049, 152034, 153053, + 152179, 153160, 151676, 153367, 151670, 198, 268, 4123, 480, + 155821, 151669, 152350, 152173, 152536, 151991, 151960, 153144, 153013, + 152358, 152234, 153135, 152291, 153235, 152143, 152583, 152402, 153483, + 152678, 152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825, + 152548, 153442, 152109, 152659, 153325, 152781, 152570, 152957, 151752, + 152265, 153381, 152515, 151670, 198, 437, 155787, 151669, 152957, + 152659, 151975, 152709, 152402, 152836, 152174, 151792, 153409, 153327, + 152990, 151670, 198, 275, 155781, 151669, 152520, 153038, 152067, + 153273, 153185, 152265, 152974, 151670, 198, 94273, 155799, 151669, + 152953, 152938, 153427, 152244, 151920, 153423, 152929, 152367, 153052, + 152129, 152331, 152257, 152987, 152777, 153448, 152408, 151696, 152408, + 152326, 152699, 151670, 198, 385, 16239, 155828, 151669, 152306, + 152268, 153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110, + 152918, 152923, 152467, 152331, 153053, 153330, 151889, 153444, 152234, + 152624, 151779, 152801, 152784, 152139, 152222, 152751, 152512, 153287, + 153141, 153052, 151840, 152589, 152508, 153499, 152109, 152255, 151739, + 152267, 152759, 153318, 153165, 153349, 151670, +}; + +const std::map Ones = { + {0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, + {4, "four"}, {5, "five"}, {6, "six"}, {7, "seven"}, + {8, "eight"}, {9, "nine"}, {10, "ten"}, {11, "eleven"}, + {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"}, {15, "fifteen"}, + {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"}}; + +const std::map Tens = { + {2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"}, + {6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}}; + +// Convert a number less than 1000 to words +std::string convertLessThanThousand(int Num) { + std::string result; + + if (Num >= 100) { + result += Ones.at(Num / 100) + " hundred "; + Num %= 100; + } + + if (Num >= 20) { + result += Tens.at(Num / 10); + if (Num % 10 > 0) { + result += "-" + Ones.at(Num % 10); + } + } else if (Num > 0) { + result += Ones.at(Num); + } + + return result; +} + +std::string numberToWords(const std::string &NumberStr) { + try { + size_t DecimalPos = NumberStr.find('.'); + std::string IntegerPart = NumberStr.substr(0, DecimalPos); + + int IntNumber = std::stoi(IntegerPart); + std::string Result; + + if (IntNumber == 0) { + Result = "zero"; + } else { + if (IntNumber >= 1000000000) { + int billions = IntNumber / 1000000000; + Result += convertLessThanThousand(billions) + " billion "; + IntNumber %= 1000000000; + } + + if (IntNumber >= 1000000) { + int millions = IntNumber / 1000000; + Result += convertLessThanThousand(millions) + " million "; + IntNumber %= 1000000; + } + + if (IntNumber >= 1000) { + int thousands = IntNumber / 1000; + Result += convertLessThanThousand(thousands) + " thousand "; + IntNumber %= 1000; + } + + if (IntNumber > 0) { + Result += convertLessThanThousand(IntNumber); + } + } + + // Handle decimal part + if (DecimalPos != std::string::npos) { + Result += " point"; + std::string decimal_part = NumberStr.substr(DecimalPos + 1); + for (char digit : decimal_part) { + Result += " " + Ones.at(digit - '0'); + } + } + + return Result; + } catch (const std::exception &) { + // Skip if fails + return " "; + } +} + +std::string replaceNumbersWithWords(const std::string &InputText) { + std::regex NumberPattern(R"(\d+(\.\d+)?)"); + std::string Result; + auto It = + std::sregex_iterator(InputText.begin(), InputText.end(), NumberPattern); + auto End = std::sregex_iterator(); + + size_t LastPos = 0; + for (std::sregex_iterator I = It; I != End; ++I) { + const std::smatch &Match = *I; + Result.append(InputText, LastPos, Match.position() - LastPos); + Result.append(numberToWords(Match.str())); + LastPos = Match.position() + Match.length(); + } + Result.append(InputText, LastPos); + + return Result; +} + +// Based on: +// https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39 +// https://github.com/ggerganov/llama.cpp/blob/b4488/examples/tts/tts.cpp#L374 +std::string processTTSPromptText(const std::string &Text) { + std::string ProcessedText = replaceNumbersWithWords(Text); + + std::transform( + ProcessedText.begin(), ProcessedText.end(), ProcessedText.begin(), + [](unsigned char c) { return static_cast(::tolower(c)); }); + + std::regex SpecialChars(R"([-_/,\.\\])"); + ProcessedText = std::regex_replace(ProcessedText, SpecialChars, " "); + + std::regex NonAlpha(R"([^a-z\s])"); + ProcessedText = std::regex_replace(ProcessedText, NonAlpha, ""); + + std::regex MultipleSpaces(R"(\s+)"); + ProcessedText = std::regex_replace(ProcessedText, MultipleSpaces, " "); + + ProcessedText = + std::regex_replace(ProcessedText, std::regex(R"(^\s+|\s+$)"), ""); + + ProcessedText = + std::regex_replace(ProcessedText, std::regex(R"(\s)"), "<|text_sep|>"); + + return ProcessedText; +} + +std::vector processTTSPrompt(Graph &GraphRef, + std::string &Prompt) noexcept { + std::string ProcessedPrompt = processTTSPromptText(Prompt); + std::vector Result, TmpTokens; + Result = common_tokenize(GraphRef.LlamaContext.get(), "<|im_start|>\n", + /* add_special */ true, + /* parse_special */ true); + TmpTokens = common_tokenize( + GraphRef.LlamaContext.get(), + "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<" + "|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_" + "sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_" + "sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_" + "sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>" + "aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>" + "really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>" + "looks<|text_sep|>lovely<|text_sep|>", + /* add_special */ false, + /* parse_special */ true); + Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); + TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), ProcessedPrompt, + /* add_special */ false, + /* parse_special */ true); + Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); + TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), "<|text_end|>\n", + /* add_special */ false, + /* parse_special */ true); + Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); + Result.insert(Result.end(), TTSVoiceData.begin(), TTSVoiceData.end()); + + return Result; +} + // <<<<<<<< Input related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // >>>>>>>> Output related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> @@ -911,6 +1209,257 @@ Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { return ErrNo::Success; } +// TTS related functions. +void fillHannWindow(int Length, bool Periodic, float *Output) { + int Offset = -1; + float Pi = static_cast(std::acos(-1)); + if (Periodic) { + Offset = 0; + } + for (int I = 0; I < Length; I++) { + double Value = + 0.5 * + (1.0 - cosf(static_cast((2.0 * Pi * I) / (Length + Offset)))); + Output[I] = static_cast(Value); + } +} + +void twiddle(float *Real, float *Imag, int K, int N) { + float Pi = static_cast(std::acos(-1)); + float Angle = 2 * Pi * K / N; + *Real = cos(Angle); + *Imag = sin(Angle); +} + +void irfft(int N, const float *InpCplx, float *OutReal) { + int NN = N / 2 + 1; + + std::vector RealInput(NN); + std::vector ImagInput(NN); + for (int I = 0; I < NN; ++I) { + RealInput[I] = InpCplx[2 * I]; + ImagInput[I] = InpCplx[2 * I + 1]; + } + + std::vector RealOutput(N); + std::vector ImagOutput(N); + + for (int K = 0; K < N; ++K) { + RealOutput[K] = 0.0f; + ImagOutput[K] = 0.0f; + for (int M = 0; M < NN; ++M) { + float TwiddleReal; + float TwiddleImag; + + twiddle(&TwiddleReal, &TwiddleImag, K * M, N); + + RealOutput[K] += RealInput[M] * TwiddleReal - ImagInput[M] * TwiddleImag; + ImagOutput[K] += RealInput[M] * TwiddleImag + ImagInput[M] * TwiddleReal; + } + } + + for (int I = 0; I < N; ++I) { + OutReal[I] = RealOutput[I] / NN; + } +} + +void fold(const std::vector &Data, int64_t NOut, int64_t NWin, + int64_t NHop, int64_t NPad, std::vector &Output) { + int64_t OutputHeight = NOut; + int64_t KernelW = NWin; + int64_t StrideW = NHop; + int64_t Width = NOut; + + Output.resize(Width, 0.0f); + + int64_t ColIdx = 0; + for (int64_t WCol = 0; WCol < Width; ++WCol) { + int64_t Start = WCol * StrideW - NPad; + int64_t End = Start + KernelW; + + for (int64_t WIm = Start; WIm < End; ++WIm) { + if (WIm >= 0 && WIm < OutputHeight && + ColIdx < static_cast(Data.size())) { + Output[WIm] += Data[ColIdx]; + } + ColIdx++; + } + } + + Output.resize(NOut - 2 * NPad); +} + +std::vector embdToAudio(const float *Embd, const int NCodes, + const int NEmbd, const int NThread) { + const int NFft = 1280; + const int NHop = 320; + const int NWin = 1280; + const int NPad = (NWin - NHop) / 2; + const int NOut = (NCodes - 1) * NHop + NWin; + + std::vector Hann(NFft); + + fillHannWindow(static_cast(Hann.size()), true, Hann.data()); + + int NSpec = NEmbd * NCodes; + + std::vector E(NSpec); + std::vector S(NSpec); + std::vector ST(NSpec); + + for (int L = 0; L < NCodes; ++L) { + for (int K = 0; K < NEmbd; ++K) { + E[K * NCodes + L] = Embd[L * NEmbd + K]; + } + } + + for (int K = 0; K < NEmbd / 2; ++K) { + for (int L = 0; L < NCodes; ++L) { + float Mag = E[(K)*NCodes + L]; + float Phi = E[(K + NEmbd / 2) * NCodes + L]; + + Mag = exp(Mag); + + if (Mag > 1e2) { + Mag = 1e2; + } + S[2 * (K * NCodes + L) + 0] = Mag * cosf(Phi); + S[2 * (K * NCodes + L) + 1] = Mag * sinf(Phi); + } + } + + for (int L = 0; L < NCodes; ++L) { + for (int K = 0; K < NEmbd / 2; ++K) { + ST[L * NEmbd + 2 * K + 0] = S[2 * (K * NCodes + L) + 0]; + ST[L * NEmbd + 2 * K + 1] = S[2 * (K * NCodes + L) + 1]; + } + } + + std::vector Res(NCodes * NFft); + std::vector Hann2(NCodes * NFft); + + std::vector Workers(NThread); + for (int I = 0; I < NThread; ++I) { + Workers[I] = std::thread([&, I]() { + for (int L = I; L < NCodes; L += NThread) { + irfft(NFft, ST.data() + L * NEmbd, Res.data() + L * NFft); + for (int J = 0; J < NFft; ++J) { + Res[L * NFft + J] *= Hann[J]; + Hann2[L * NFft + J] = Hann[J] * Hann[J]; + } + } + }); + } + for (int I = 0; I < NThread; ++I) { + Workers[I].join(); + } + + std::vector Audio; + std::vector Env; + + fold(Res, NOut, NWin, NHop, NPad, Audio); + fold(Hann2, NOut, NWin, NHop, NPad, Env); // TODO: can be done once + + for (size_t I = 0; I < Audio.size(); ++I) { + Audio[I] /= Env[I]; + } + + return Audio; +} + +struct WavHeader { + char Riff[4] = {'R', 'I', 'F', 'F'}; + uint32_t ChunkSize; + char Wave[4] = {'W', 'A', 'V', 'E'}; + char Fmt[4] = {'f', 'm', 't', ' '}; + uint32_t FmtChunkSize = 16; + uint16_t AudioFormat = 1; // PCM + uint16_t NumChannels = 1; // Mono + uint32_t SampleRate; + uint32_t ByteRate; + uint16_t BlockAlign; + uint16_t BitsPerSample = 16; + char Data[4] = {'d', 'a', 't', 'a'}; + uint32_t DataSize; +}; + +void audioDataToWav(const std::string &Filename, const std::vector &Data, + int SampleRate) { + std::ofstream File(Filename, std::ios::binary); + if (!File) { + LOG_ERROR("audioDataToWav: Failed to open file '{}' for writing"sv, + Filename); + return; + } + + WavHeader Header; + Header.SampleRate = SampleRate; + Header.ByteRate = + Header.SampleRate * Header.NumChannels * (Header.BitsPerSample / 8); + Header.BlockAlign = Header.NumChannels * (Header.BitsPerSample / 8); + Header.DataSize = + static_cast(Data.size() * (Header.BitsPerSample / 8)); + Header.ChunkSize = 36 + Header.DataSize; + + File.write(reinterpret_cast(&Header), sizeof(Header)); + + for (const auto &Sample : Data) { + int16_t PCMSample = + static_cast(std::clamp(Sample * 32767.0, -32768.0, 32767.0)); + File.write(reinterpret_cast(&PCMSample), sizeof(PCMSample)); + } + + File.close(); +} + +// TextToSpeech function, will generate voice data from codes. +ErrNo codesToSpeech(Graph &GraphRef, Context &CxtRef) noexcept { + // Remove all non-audio tokens. + CxtRef.LlamaOutputTokens.erase( + std::remove_if(CxtRef.LlamaOutputTokens.begin(), + CxtRef.LlamaOutputTokens.end(), + [](llama_token T) { return T < 151672 || T > 155772; }), + CxtRef.LlamaOutputTokens.end()); + + // Adjust the token values for audio data. + for (llama_token &Token : CxtRef.LlamaOutputTokens) { + Token -= 151672; + } + + // Put codes into batch. + const uint32_t NCodes = + static_cast(CxtRef.LlamaOutputTokens.size()); + llama_batch TTSBatch = + llama_batch_init(NCodes, /* embd */ 0, /* n_seq_max */ 1); + for (uint32_t I = 0; I < NCodes; ++I) { + common_batch_add(TTSBatch, CxtRef.LlamaOutputTokens[I], I, + /* seq_ids */ {0}, /* logits */ true); + } + if (llama_decode(GraphRef.TTSContext.get(), TTSBatch) != 0) { + RET_ERROR(ErrNo::RuntimeError, "codesToSpeech: fail to eval."sv) + } + llama_batch_free(TTSBatch); + + // Get embeddings. + const int NEmbd = llama_model_n_embd(GraphRef.TTSModel.get()); + const float *Embd = llama_get_embeddings(GraphRef.TTSContext.get()); + + // Embeddings to audio. + std::vector AudioData = + embdToAudio(Embd, NCodes, NEmbd, static_cast(GraphRef.Threads)); + + // Zero out first 0.25 seconds of audio. + const uint32_t SamplingRate = 24000; + for (uint32_t i = 0; i < SamplingRate / 4; ++i) { + AudioData[i] = 0.0f; + } + + // Save .wav file + audioDataToWav(GraphRef.TTSOutputFilePath, AudioData, SamplingRate); + + return ErrNo::Success; +} + // <<<<<<<< Compute related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< } // namespace @@ -1032,6 +1581,25 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize ggml model with given parameters...Done"sv) + // Initialize the TTS related model and context. + if (GraphRef.TextToSpeech) { + LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize TTS model."sv) + Params.model = GraphRef.VocoderModelPath; + Params.embedding = true; + common_init_result TTSInit = common_init_from_params(Params); + GraphRef.TTSModel = std::move(TTSInit.model); + GraphRef.TTSContext = std::move(TTSInit.context); + if (GraphRef.TTSModel == nullptr) { + Env.deleteGraph(GId); + RET_ERROR(ErrNo::InvalidArgument, "load: unable to init TTS model."sv) + } + if (GraphRef.TTSContext == nullptr) { + Env.deleteGraph(GId); + RET_ERROR(ErrNo::InvalidArgument, "load: unable to init TTS context."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize TTS model...Done"sv) + } + // Store the loaded graph. GraphId = GId; Env.NNGraph[GId].setReady(); @@ -1328,6 +1896,11 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: handle llava format prompt...Done"sv) } + } else if (GraphRef.TextToSpeech == true) { + // TTS prompt. + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt"sv) + CxtRef.LlamaInputs = processTTSPrompt(GraphRef, Prompt); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt...Done"sv) } else { // Text only prompt. LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt"sv) @@ -1403,6 +1976,20 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { "compute: enter main prediction loop...Done"sv) // End of main prediction loop. + // TTS: convert output codes to audio file. + if (GraphRef.TextToSpeech) { + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: convert output codes to audio file."sv) + ReturnCode = codesToSpeech(GraphRef, CxtRef); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, + "compute: failed to convert output codes to audio "sv + "file."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: convert output codes to audio file...Done"sv) + } + if (GraphRef.EnableLog) { common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler); } @@ -1522,6 +2109,16 @@ Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { GraphRef.ClipContext = nullptr; LOG_DEBUG(IsDebugLog, "unload: free clip context...Done"sv) } + if (GraphRef.TTSModel != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free TTS model"sv) + GraphRef.TTSModel.reset(); + LOG_DEBUG(IsDebugLog, "unload: free TTS model...Done"sv) + } + if (GraphRef.TTSContext != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free TTS context"sv) + GraphRef.TTSContext.reset(); + LOG_DEBUG(IsDebugLog, "unload: free TTS context...Done"sv) + } Env.deleteGraph(GraphId); Env.mdRemoveById(GraphId); diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index fb34c545..51d788dd 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -71,6 +71,12 @@ struct Graph { std::string MMProjModelPath; struct clip_ctx *ClipContext = nullptr; VisionModel VisionModelType = VisionModel::Llava; + // Text-to-speech: + bool TextToSpeech = false; + std::string VocoderModelPath; + std::string TTSOutputFilePath = "output.wav"; + llama_model_ptr TTSModel = nullptr; + llama_context_ptr TTSContext = nullptr; // Context parameters: int64_t CtxSize; int64_t BatchSize; From 23ab38e1191e4d07ab9a1f575963b4653648c68b Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 5 Feb 2025 14:04:48 +0800 Subject: [PATCH 530/623] [WASI-NN] ggml: apply clang-tidy (#4005) Signed-off-by: hydai --- plugins/wasi_nn/wasinn_ggml.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index a87e629f..24d56e45 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -704,9 +704,9 @@ std::string numberToWords(const std::string &NumberStr) { // Handle decimal part if (DecimalPos != std::string::npos) { Result += " point"; - std::string decimal_part = NumberStr.substr(DecimalPos + 1); - for (char digit : decimal_part) { - Result += " " + Ones.at(digit - '0'); + std::string DecimalPart = NumberStr.substr(DecimalPos + 1); + for (char Digit : DecimalPart) { + Result += " " + Ones.at(Digit - '0'); } } @@ -1450,8 +1450,8 @@ ErrNo codesToSpeech(Graph &GraphRef, Context &CxtRef) noexcept { // Zero out first 0.25 seconds of audio. const uint32_t SamplingRate = 24000; - for (uint32_t i = 0; i < SamplingRate / 4; ++i) { - AudioData[i] = 0.0f; + for (uint32_t I = 0; I < SamplingRate / 4; ++I) { + AudioData[I] = 0.0f; } // Save .wav file From d158af4808564efb06eda55773beec052bcb51d0 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 11 Feb 2025 00:12:53 +0800 Subject: [PATCH 531/623] [WASI-NN] ggml: check the projection model before loading Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 24d56e45..482aa104 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -1762,7 +1762,14 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (Base64ImagePos.has_value() || CxtRef.Conf.ImagePath != ""sv) { // Prompt with image input. Check is llava or mllama case. - // First check the projection model is loaded. + // First check the projection model is given. + if (GraphRef.MMProjModelPath == ""sv) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: the given model does not support image input, so a projection model is required."sv) + } + + // Make sure the projection model is loaded. if (GraphRef.ClipContext == nullptr) { LOG_INFO( true, From 656374cfe51ec3c1e346a0f09707fcd673de2853 Mon Sep 17 00:00:00 2001 From: dm4 Date: Tue, 11 Feb 2025 00:14:17 +0800 Subject: [PATCH 532/623] [WASI-NN] ggml: apply clang-tidy Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 35 +++++++++++++++++---------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 482aa104..2fcb9222 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -55,7 +55,7 @@ namespace { return Error; // Llama logging callback. -void LlamaLogCallback(ggml_log_level LogLevel, const char *LogText, +void llamaLogCallback(ggml_log_level LogLevel, const char *LogText, void *UserData) { Graph &GraphRef = *reinterpret_cast(UserData); if (!GraphRef.EnableLog) { @@ -648,23 +648,23 @@ const std::map Tens = { // Convert a number less than 1000 to words std::string convertLessThanThousand(int Num) { - std::string result; + std::string Result; if (Num >= 100) { - result += Ones.at(Num / 100) + " hundred "; + Result += Ones.at(Num / 100) + " hundred "; Num %= 100; } if (Num >= 20) { - result += Tens.at(Num / 10); + Result += Tens.at(Num / 10); if (Num % 10 > 0) { - result += "-" + Ones.at(Num % 10); + Result += "-" + Ones.at(Num % 10); } } else if (Num > 0) { - result += Ones.at(Num); + Result += Ones.at(Num); } - return result; + return Result; } std::string numberToWords(const std::string &NumberStr) { @@ -679,20 +679,20 @@ std::string numberToWords(const std::string &NumberStr) { Result = "zero"; } else { if (IntNumber >= 1000000000) { - int billions = IntNumber / 1000000000; - Result += convertLessThanThousand(billions) + " billion "; + int Billions = IntNumber / 1000000000; + Result += convertLessThanThousand(Billions) + " billion "; IntNumber %= 1000000000; } if (IntNumber >= 1000000) { - int millions = IntNumber / 1000000; - Result += convertLessThanThousand(millions) + " million "; + int Millions = IntNumber / 1000000; + Result += convertLessThanThousand(Millions) + " million "; IntNumber %= 1000000; } if (IntNumber >= 1000) { - int thousands = IntNumber / 1000; - Result += convertLessThanThousand(thousands) + " thousand "; + int Thousands = IntNumber / 1000; + Result += convertLessThanThousand(Thousands) + " thousand "; IntNumber %= 1000; } @@ -744,7 +744,7 @@ std::string processTTSPromptText(const std::string &Text) { std::transform( ProcessedText.begin(), ProcessedText.end(), ProcessedText.begin(), - [](unsigned char c) { return static_cast(::tolower(c)); }); + [](unsigned char C) { return static_cast(::tolower(C)); }); std::regex SpecialChars(R"([-_/,\.\\])"); ProcessedText = std::regex_replace(ProcessedText, SpecialChars, " "); @@ -990,7 +990,8 @@ ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, ErrNo::RuntimeError, "evaluateTokens: failed to llama_decode: try reducing the size of the batch "sv "or increasing the size of context."sv) - } else if (Status < 0) { + } + if (Status < 0) { RET_ERROR( ErrNo::RuntimeError, "evaluateTokens: failed to llama_decode: internal fatal error. Please open "sv @@ -1065,7 +1066,7 @@ ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, case VisionModel::Qwen2VL: LOG_DEBUG(GraphRef.EnableDebugLog, "{}: Eval Qwen2VL image embd"sv, LogPrefix) - auto ImageSize = clip_get_load_image_size(GraphRef.ClipContext); + auto *ImageSize = clip_get_load_image_size(GraphRef.ClipContext); EvalImageStatus = evaluateQwen2vlImageEmbed( GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, static_cast(GraphRef.BatchSize), CxtRef.NPos, ImageSize); @@ -1501,7 +1502,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.Conf.ImagePath = ""sv; // Set llama log callback. - llama_log_set(LlamaLogCallback, &GraphRef); + llama_log_set(llamaLogCallback, &GraphRef); // If the graph builder length > 1, the data of builder[1] is the metadata. if (Builders.size() > 1) { From af7b58b106265d3076123ff6046ff70e294cd774 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 11 Feb 2025 16:25:44 +0800 Subject: [PATCH 533/623] [WASI-NN] apply WASMEDGE_WASI_NN_VERSION into the plugin's version (#4017) Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 14 +++++++++++++- plugins/wasi_nn/wasinnenv.cpp | 3 ++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 9f0036bf..ee057279 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -22,13 +22,25 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) +# Handle the version of the WASI-NN plugin +string(REPLACE "." ";" WASI_NN_VERSION_LIST ${WASMEDGE_WASI_NN_VERSION}) +list(GET WASI_NN_VERSION_LIST 0 WASI_NN_VERSION_MAJOR) +list(GET WASI_NN_VERSION_LIST 1 WASI_NN_VERSION_MINOR) +list(GET WASI_NN_VERSION_LIST 2 WASI_NN_VERSION_PATCH) + +target_compile_definitions(wasmedgePluginWasiNN PRIVATE + WASI_NN_VERSION_MAJOR=${WASI_NN_VERSION_MAJOR} + WASI_NN_VERSION_MINOR=${WASI_NN_VERSION_MINOR} + WASI_NN_VERSION_PATCH=${WASI_NN_VERSION_PATCH} +) + # This for-each iteration is for the additional sources. # The dependencies are moved into `cmake/WASINNDeps.cmake`. foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) string(TOLOWER ${BACKEND} BACKEND) if(BACKEND STREQUAL "mlx") target_sources(wasmedgePluginWasiNN - PRIVATE + PRIVATE MLX/prompt/prompt.cpp MLX/model/transformer.cpp MLX/model/converter.cpp diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 275c14d1..88647e02 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -180,7 +180,8 @@ Plugin::Plugin::PluginDescriptor Descriptor{ /* Name */ "wasi_nn", /* Description */ "", /* APIVersion */ Plugin::Plugin::CurrentAPIVersion, - /* Version */ {0, 10, 1, 0}, + /* Version */ + {WASI_NN_VERSION_MAJOR, WASI_NN_VERSION_MINOR, WASI_NN_VERSION_PATCH, 0}, /* ModuleCount */ 1, /* ModuleDescriptions */ MD, /* ComponentCount */ 0, From a89dc29c39b3d4e97f71736990570cf56ed1bc7b Mon Sep 17 00:00:00 2001 From: LFsWang <7088579+LFsWang@users.noreply.github.com> Date: Tue, 11 Feb 2025 17:46:38 +0800 Subject: [PATCH 534/623] [WASI-NN] openvino: update to 2025.0.0 (#4016) Signed-off-by: LFsWang --- .../Dockerfile.manylinux2014-build-plugins-deps | 4 ++-- .../docker/Dockerfile.manylinux_2_28-plugins-deps | 4 ++-- utils/docker/Dockerfile.ubuntu-plugins-deps | 4 ++-- utils/wasi-nn/install-openvino.sh | 14 +++++++++++--- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps index 38223ed2..1a8213e1 100644 --- a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps @@ -32,7 +32,7 @@ RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] ENV PKG_CONFIG_PATH /root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} ENV LD_LIBRARY_PATH /root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} -ENV OPENVINO_VERSION "2024.2.0" -ENV OPENVINO_YEAR "2024" +ENV OPENVINO_VERSION "2025.0.0" +ENV OPENVINO_YEAR "2025" RUN yum clean all diff --git a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps index ac6791fb..484ee476 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps @@ -35,8 +35,8 @@ RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] ENV PKG_CONFIG_PATH=/root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} ENV LD_LIBRARY_PATH=/root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} -ENV OPENVINO_VERSION="2024.2.0" -ENV OPENVINO_YEAR="2024" +ENV OPENVINO_VERSION="2025.0.0" +ENV OPENVINO_YEAR="2025" COPY wasi-nn/install-onnxruntime.sh . RUN [ "/bin/bash", "install-onnxruntime.sh" ] diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index 940e2157..1b177c54 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -61,8 +61,8 @@ ARG UBUNTU_VER COPY wasi-nn/install-openvino.sh . ENV OPENVINO_UBUNTU_VERSION=${UBUNTU_VER} -ENV OPENVINO_VERSION="2024.2.0" -ENV OPENVINO_YEAR="2024" +ENV OPENVINO_VERSION="2025.0.0" +ENV OPENVINO_YEAR="2025" RUN [ "/bin/bash", "install-openvino.sh" ] COPY wasi-nn/install-onnxruntime.sh . diff --git a/utils/wasi-nn/install-openvino.sh b/utils/wasi-nn/install-openvino.sh index 9950caca..432f1ac8 100755 --- a/utils/wasi-nn/install-openvino.sh +++ b/utils/wasi-nn/install-openvino.sh @@ -2,13 +2,21 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2024 Second State INC set -e -echo "Installing OpenVINO with version 2024.2.0" + +if [[ ! -v "${OPENVINO_VERSION}" ]]; then + OPENVINO_VERSION="2025.0.0" +fi +if [[ ! -v "${OPENVINO_YEAR}" ]]; then + OPENVINO_YEAR="2025" +fi + +echo "Installing OpenVINO with version ${OPENVINO_VERSION}" KEY_FILE=GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB wget https://apt.repos.intel.com/intel-gpg-keys/$KEY_FILE && \ apt-key add $KEY_FILE && \ rm -f $KEY_FILE UBUNTU_VERSION="ubuntu${OPENVINO_UBUNTU_VERSION:-20}" -echo "deb https://apt.repos.intel.com/openvino/2024 ${UBUNTU_VERSION} main" | tee /etc/apt/sources.list.d/intel-openvino-2024.list +echo "deb https://apt.repos.intel.com/openvino/$OPENVINO_YEAR ${UBUNTU_VERSION} main" | tee /etc/apt/sources.list.d/intel-openvino-$OPENVINO_YEAR.list apt update -apt-get -y install openvino-2024.2.0 +apt-get -y install openvino-$OPENVINO_VERSION ldconfig From 7787040ece0c1aac92bf86a04639c830b5d92e89 Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 17 Feb 2025 16:21:46 +0800 Subject: [PATCH 535/623] [WASI-NN] ggml: add TTS speaker profile support (#4020) [WASI-NN] ggml: add more TTS speakers support Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 184 ++++++++++++++------------------ plugins/wasi_nn/wasinn_ggml.h | 6 ++ 2 files changed, 85 insertions(+), 105 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 2fcb9222..ac16f0d3 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -145,6 +145,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, // tts: bool // model-vocoder: string // tts-output-file: string + // tts-speaker-file: string // Context parameters (used by the llama context): // ctx-size: int64_t // batch-size: int64_t @@ -313,6 +314,16 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, } GraphRef.TTSOutputFilePath = TTSOutputFilePath; } + if (Doc.at_key("tts-speaker-file").error() == simdjson::SUCCESS) { + std::string_view TTSSpeakerFilePath; + auto Err = + Doc["tts-speaker-file"].get().get(TTSSpeakerFilePath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the tts-speaker-file option."sv) + } + GraphRef.TTSSpeakerFilePath = TTSSpeakerFilePath; + } // The context parameters. if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { @@ -541,99 +552,13 @@ extractBase64ImagePayload(std::string &Prompt, } // TTS function to process the prompt text. -const std::vector TTSVoiceData = llama_tokens{ - 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585, - 152460, 153375, 151670, 198, 74455, 155808, 151669, 151799, 151873, - 151863, 152446, 152372, 152204, 152728, 152229, 152470, 151970, 153413, - 152419, 153334, 153289, 153374, 153199, 152040, 153260, 152721, 152680, - 153297, 152419, 153248, 152400, 152691, 153368, 153437, 151670, 198, - 1722, 155828, 151669, 152607, 152256, 152991, 152299, 152688, 153163, - 153016, 152789, 153198, 152712, 151911, 153107, 152623, 152170, 152395, - 152852, 152207, 152461, 153321, 153309, 151750, 152137, 153340, 152573, - 152267, 153347, 151789, 152681, 153339, 151992, 152512, 151751, 152179, - 153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904, 152311, - 151670, 198, 1499, 155791, 151669, 152276, 152454, 153354, 152544, - 153204, 153272, 152708, 153433, 152319, 153226, 153043, 152325, 153267, - 152622, 151670, 198, 4250, 155797, 151669, 153454, 153342, 151989, - 152458, 153420, 152303, 152271, 152827, 153036, 153196, 151708, 153263, - 152561, 153207, 152213, 152112, 153204, 151722, 152542, 151670, 198, - 19789, 155796, 151669, 153353, 153182, 152345, 152471, 152477, 153014, - 152002, 152191, 151734, 152312, 152810, 152237, 153224, 153169, 153224, - 152244, 153387, 153404, 151670, 198, 16069, 155811, 151669, 152265, - 151946, 151808, 152412, 152363, 152305, 153156, 152733, 152810, 153157, - 152016, 152100, 152069, 153234, 152317, 152589, 152707, 153121, 153341, - 152159, 152114, 153156, 153001, 153504, 153376, 152272, 152433, 152325, - 151941, 151670, 198, 285, 155788, 151669, 152238, 152255, 153427, - 152318, 153009, 152381, 152474, 152680, 152157, 153255, 152324, 151682, - 151670, 198, 32955, 155804, 151669, 153490, 153419, 152364, 152405, - 152682, 152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488, - 153070, 151883, 152890, 152489, 153144, 153375, 152358, 151685, 152494, - 152117, 152740, 151670, 198, 37448, 480, 155840, 151669, 151902, - 152720, 153377, 152027, 152378, 152821, 153207, 153459, 153028, 153068, - 152507, 153255, 152158, 152921, 151958, 152609, 152748, 152822, 152286, - 151714, 152730, 152377, 152353, 152470, 152606, 152162, 152186, 153071, - 152244, 153118, 153375, 153018, 152712, 153098, 152976, 152336, 151843, - 153202, 152297, 151736, 153380, 153502, 152702, 152115, 153181, 152735, - 153277, 153457, 152393, 153112, 152595, 151670, 198, 19098, 155808, - 151669, 152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239, - 153163, 152922, 153402, 152034, 152591, 153438, 152215, 151673, 152005, - 151785, 152642, 151924, 153278, 151805, 151974, 153482, 152718, 152862, - 153347, 151670, 198, 72, 155780, 151669, 151795, 152111, 152746, - 152377, 153471, 152309, 151670, 198, 19016, 155788, 151669, 153181, - 152271, 152190, 152842, 152224, 152701, 152939, 152536, 152091, 151815, - 152733, 151672, 151670, 198, 14689, 155788, 151669, 152291, 152072, - 152942, 151734, 153042, 153504, 152589, 153333, 151839, 151941, 153038, - 153180, 151670, 198, 36996, 8303, 155832, 151669, 152231, 152256, - 152835, 152801, 152985, 153400, 152393, 152818, 152765, 152249, 152600, - 151699, 152302, 152752, 153018, 153009, 151992, 153054, 152847, 153354, - 153228, 152662, 153355, 152532, 153393, 151782, 152458, 152048, 152757, - 152428, 153195, 151906, 153006, 153178, 153250, 152331, 152284, 152780, - 153138, 153319, 151980, 153142, 152418, 152228, 152733, 151670, 198, - 9096, 155801, 151669, 151698, 153321, 152217, 153039, 152935, 153400, - 152122, 152531, 153106, 152169, 152892, 152957, 151851, 152427, 152826, - 152451, 151851, 152901, 152885, 152594, 153446, 153080, 151670, 198, - 14689, 155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191, - 151673, 151690, 151698, 152714, 152846, 152981, 153171, 153384, 153364, - 153188, 153246, 151670, 198, 1055, 155779, 151669, 151869, 152388, - 152711, 153334, 151736, 151670, 198, 1782, 155780, 151669, 153483, - 153240, 152241, 152558, 152697, 153046, 151670, 198, 5804, 1363, - 155820, 151669, 152941, 152764, 152605, 153034, 153434, 153372, 153347, - 151887, 152453, 152758, 152133, 152510, 152694, 152431, 152321, 153088, - 152676, 152223, 152581, 152459, 152015, 152502, 153063, 152712, 153294, - 153451, 153032, 152903, 152859, 152989, 151748, 152669, 152661, 152650, - 152409, 151861, 151670, 198, 300, 7973, 155828, 151669, 153095, - 152469, 152988, 152894, 151819, 152391, 153019, 152058, 153062, 153230, - 151826, 152112, 152306, 152264, 152769, 153390, 152384, 152435, 152790, - 153393, 152983, 152540, 152252, 152034, 153107, 152540, 151919, 151893, - 152558, 152817, 152946, 152956, 152129, 152715, 153131, 153490, 151734, - 152271, 152707, 151734, 153321, 152450, 151670, 198, 8088, 155792, - 151669, 152452, 153497, 153353, 152679, 152533, 152382, 152374, 152611, - 153341, 153163, 152285, 153411, 152495, 153141, 152320, 151670, 198, - 1199, 155781, 151669, 151764, 152360, 153295, 152634, 153342, 152199, - 152271, 151670, 198, 43366, 155799, 151669, 152308, 151682, 152889, - 152016, 152385, 152629, 152495, 151826, 153321, 152958, 152180, 151886, - 153432, 152922, 152128, 153024, 153040, 152593, 152287, 151677, 151670, - 198, 53660, 155808, 151669, 151727, 152092, 152680, 153331, 151699, - 152316, 152938, 152289, 152433, 153384, 151781, 153137, 153259, 152175, - 153213, 152291, 151869, 152691, 152489, 151941, 152049, 152034, 153053, - 152179, 153160, 151676, 153367, 151670, 198, 268, 4123, 480, - 155821, 151669, 152350, 152173, 152536, 151991, 151960, 153144, 153013, - 152358, 152234, 153135, 152291, 153235, 152143, 152583, 152402, 153483, - 152678, 152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825, - 152548, 153442, 152109, 152659, 153325, 152781, 152570, 152957, 151752, - 152265, 153381, 152515, 151670, 198, 437, 155787, 151669, 152957, - 152659, 151975, 152709, 152402, 152836, 152174, 151792, 153409, 153327, - 152990, 151670, 198, 275, 155781, 151669, 152520, 153038, 152067, - 153273, 153185, 152265, 152974, 151670, 198, 94273, 155799, 151669, - 152953, 152938, 153427, 152244, 151920, 153423, 152929, 152367, 153052, - 152129, 152331, 152257, 152987, 152777, 153448, 152408, 151696, 152408, - 152326, 152699, 151670, 198, 385, 16239, 155828, 151669, 152306, - 152268, 153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110, - 152918, 152923, 152467, 152331, 153053, 153330, 151889, 153444, 152234, - 152624, 151779, 152801, 152784, 152139, 152222, 152751, 152512, 153287, - 153141, 153052, 151840, 152589, 152508, 153499, 152109, 152255, 151739, - 152267, 152759, 153318, 153165, 153349, 151670, +// clang-format off +const TTSSpeakerProfile TTSDefaultSpeakerProfile = { + // Speaker profile from edwko/OuteTTS (en_female_1.json). + "<|text_start|>uhm<|text_sep|>now<|text_sep|>being<|text_sep|>the<|text_sep|>one<|text_sep|>to<|text_sep|>say<|text_sep|>i<|text_sep|>know<|text_sep|>the<|text_sep|>worst<|text_sep|>of<|text_sep|>you<|text_sep|>and<|text_sep|>ive<|text_sep|>been<|text_sep|>directly<|text_sep|>affected<|text_sep|>by<|text_sep|>people<|text_sep|>like<|text_sep|>you<|text_sep|>but<|text_sep|>its<|text_sep|>a<|text_sep|>clean<|text_sep|>slate<|text_sep|>with<|text_sep|>me<|text_sep|>buddy<|text_sep|>you<|text_sep|>know<|text_sep|>like<|text_sep|>thats<|text_sep|>really<|text_sep|>powerful<|text_sep|>in<|text_sep|>and<|text_sep|>of<|text_sep|>itself<|text_sep|>", + "<|audio_start|>\nuhm<|t_0.36|><|code_start|><|447|><|223|><|967|><|301|><|965|><|827|><|393|><|908|><|764|><|1167|><|711|><|1222|><|324|><|1318|><|806|><|498|><|1198|><|1127|><|1178|><|916|><|1234|><|1411|><|1428|><|706|><|427|><|1605|><|1578|><|code_end|>\nnow<|t_0.36|><|code_start|><|1049|><|327|><|385|><|1070|><|732|><|1480|><|450|><|1025|><|1469|><|174|><|1013|><|1710|><|1674|><|775|><|771|><|251|><|778|><|1400|><|897|><|1487|><|366|><|441|><|1000|><|393|><|271|><|1000|><|768|><|code_end|>\nbeing<|t_0.27|><|code_start|><|926|><|406|><|1457|><|437|><|1231|><|672|><|1785|><|521|><|1179|><|1559|><|198|><|1086|><|733|><|122|><|1344|><|845|><|348|><|1389|><|470|><|1773|><|code_end|>\nthe<|t_0.08|><|code_start|><|1775|><|562|><|768|><|1222|><|768|><|963|><|code_end|>\none<|t_0.21|><|code_start|><|1757|><|744|><|144|><|1610|><|655|><|616|><|1317|><|225|><|1325|><|913|><|1342|><|992|><|1018|><|80|><|1777|><|883|><|code_end|>\nto<|t_0.08|><|code_start|><|487|><|1363|><|1682|><|1426|><|655|><|1483|><|code_end|>\nsay<|t_0.27|><|code_start|><|1644|><|1804|><|731|><|273|><|1592|><|731|><|1523|><|1404|><|984|><|1207|><|430|><|1132|><|1123|><|768|><|1116|><|829|><|1082|><|1095|><|440|><|1162|><|code_end|>\ni<|t_0.33|><|code_start|><|1330|><|335|><|1162|><|1155|><|308|><|1162|><|1150|><|1481|><|612|><|674|><|712|><|1745|><|1188|><|1787|><|1135|><|1275|><|1237|><|1143|><|408|><|1063|><|393|><|927|><|1298|><|132|><|1686|><|code_end|>\nknow<|t_0.27|><|code_start|><|983|><|1677|><|586|><|1528|><|1435|><|835|><|1396|><|706|><|987|><|22|><|1172|><|218|><|1404|><|1001|><|521|><|1389|><|775|><|1416|><|877|><|120|><|code_end|>\nthe<|t_0.16|><|code_start|><|916|><|1756|><|513|><|1245|><|1392|><|89|><|1266|><|12|><|1045|><|1075|><|904|><|35|><|code_end|>\nworst<|t_0.32|><|code_start|><|1607|><|174|><|1231|><|144|><|932|><|490|><|771|><|1504|><|798|><|674|><|364|><|80|><|1314|><|1636|><|449|><|1704|><|713|><|1795|><|968|><|1527|><|1302|><|1529|><|1176|><|795|><|code_end|>\nof<|t_0.12|><|code_start|><|1193|><|1205|><|390|><|1128|><|1091|><|883|><|322|><|377|><|1070|><|code_end|>\nyou<|t_0.17|><|code_start|><|1016|><|1332|><|926|><|281|><|927|><|1368|><|1687|><|918|><|67|><|1638|><|1317|><|1265|><|1770|><|code_end|>\nand<|t_0.28|><|code_start|><|1129|><|1633|><|1373|><|1207|><|405|><|879|><|1030|><|1253|><|1071|><|612|><|724|><|1770|><|665|><|1046|><|1351|><|1450|><|1541|><|1384|><|111|><|1477|><|284|><|code_end|>\nive<|t_0.35|><|code_start|><|674|><|266|><|89|><|1333|><|1183|><|1526|><|1143|><|883|><|1135|><|732|><|827|><|1119|><|594|><|1261|><|1024|><|1347|><|92|><|1392|><|825|><|1710|><|1289|><|1598|><|1070|><|1525|><|1442|><|555|><|code_end|>\nbeen<|t_0.17|><|code_start|><|1461|><|194|><|337|><|1128|><|188|><|892|><|848|><|1280|><|959|><|754|><|231|><|649|><|1304|><|code_end|>\ndirectly<|t_0.87|><|code_start|><|1030|><|353|><|570|><|1331|><|470|><|1832|><|1362|><|1809|><|1383|><|101|><|325|><|1557|><|1242|><|1512|><|180|><|227|><|1242|><|643|><|209|><|464|><|171|><|1219|><|174|><|1723|><|734|><|118|><|1269|><|643|><|209|><|187|><|612|><|1231|><|68|><|567|><|1242|><|505|><|319|><|1268|><|794|><|678|><|40|><|1286|><|470|><|1454|><|199|><|965|><|188|><|300|><|1234|><|1125|><|794|><|1289|><|1224|><|257|><|469|><|1121|><|101|><|823|><|1769|><|1683|><|95|><|255|><|59|><|67|><|832|><|code_end|>\naffected<|t_0.44|><|code_start|><|510|><|873|><|787|><|1228|><|771|><|1428|><|501|><|751|><|696|><|258|><|845|><|1818|><|1112|><|498|><|1111|><|985|><|1073|><|832|><|1427|><|168|><|163|><|447|><|119|><|567|><|1626|><|1820|><|903|><|635|><|1060|><|10|><|1632|><|35|><|1635|><|code_end|>\nby<|t_0.19|><|code_start|><|144|><|144|><|460|><|185|><|1112|><|1044|><|498|><|1192|><|656|><|1333|><|1001|><|1186|><|1186|><|454|><|code_end|>\npeople<|t_0.48|><|code_start|><|1260|><|747|><|351|><|526|><|612|><|1151|><|1262|><|1791|><|344|><|1752|><|1547|><|930|><|1302|><|1703|><|1289|><|92|><|1407|><|1482|><|508|><|1431|><|355|><|1696|><|337|><|199|><|1157|><|223|><|464|><|568|><|845|><|411|><|826|><|718|><|1786|><|545|><|712|><|580|><|code_end|>\nlike<|t_0.32|><|code_start|><|630|><|532|><|526|><|607|><|526|><|839|><|1305|><|660|><|459|><|339|><|717|><|1178|><|1148|><|687|><|149|><|1390|><|229|><|199|><|513|><|712|><|1451|><|731|><|582|><|1551|><|code_end|>\nyou<|t_0.21|><|code_start|><|1389|><|954|><|1781|><|1047|><|1236|><|930|><|809|><|1621|><|1268|><|384|><|242|><|587|><|869|><|816|><|1680|><|405|><|code_end|>\nbut<|t_0.59|><|code_start|><|1089|><|1590|><|908|><|80|><|594|><|1046|><|1706|><|1025|><|1150|><|405|><|548|><|893|><|1285|><|464|><|301|><|939|><|643|><|23|><|285|><|161|><|209|><|453|><|72|><|167|><|417|><|244|><|151|><|643|><|391|><|199|><|651|><|1023|><|337|><|1010|><|54|><|331|><|1167|><|756|><|388|><|934|><|1060|><|18|><|1624|><|1060|><|code_end|>\nits<|t_0.16|><|code_start|><|1102|><|183|><|1199|><|1258|><|1285|><|35|><|659|><|180|><|426|><|1587|><|1733|><|942|><|code_end|>\na<|t_0.04|><|code_start|><|791|><|1012|><|818|><|code_end|>\nclean<|t_0.61|><|code_start|><|1819|><|976|><|163|><|447|><|316|><|223|><|763|><|457|><|1208|><|1808|><|1697|><|1162|><|1660|><|1833|><|1054|><|1734|><|1121|><|1309|><|1643|><|924|><|1677|><|1548|><|869|><|1268|><|223|><|674|><|111|><|792|><|1670|><|912|><|174|><|1554|><|90|><|80|><|1563|><|1621|><|1698|><|1544|><|992|><|988|><|175|><|793|><|1661|><|1026|><|80|><|1761|><|code_end|>\nslate<|t_0.40|><|code_start|><|1802|><|322|><|1689|><|1577|><|1302|><|1552|><|1529|><|1722|><|1580|><|582|><|1642|><|1529|><|1020|><|582|><|1538|><|970|><|437|><|1141|><|1477|><|988|><|335|><|1611|><|922|><|1558|><|1120|><|1189|><|423|><|188|><|171|><|562|><|code_end|>\nwith<|t_0.15|><|code_start|><|963|><|1347|><|1274|><|747|><|1230|><|712|><|1408|><|1290|><|957|><|1279|><|258|><|code_end|>\nme<|t_0.09|><|code_start|><|638|><|1058|><|174|><|1452|><|1038|><|894|><|1571|><|code_end|>\nbuddy<|t_0.32|><|code_start|><|1003|><|130|><|1341|><|938|><|40|><|804|><|167|><|89|><|1456|><|1189|><|1155|><|1171|><|1434|><|1077|><|1029|><|1455|><|1622|><|1037|><|163|><|1411|><|1165|><|1463|><|837|><|1202|><|code_end|>\nyou<|t_0.36|><|code_start|><|1354|><|1165|><|615|><|1588|><|1192|><|1445|><|1033|><|982|><|401|><|1079|><|684|><|1570|><|266|><|31|><|420|><|163|><|893|><|845|><|905|><|1827|><|1804|><|153|><|627|><|243|><|1179|><|298|><|1147|><|code_end|>\nknow<|t_0.19|><|code_start|><|163|><|1542|><|1366|><|698|><|1753|><|206|><|916|><|1499|><|245|><|665|><|600|><|894|><|587|><|1741|><|code_end|>\nlike<|t_0.24|><|code_start|><|1106|><|1280|><|1062|><|1304|><|945|><|809|><|598|><|104|><|1001|><|822|><|965|><|189|><|693|><|1810|><|1293|><|199|><|1277|><|44|><|code_end|>\nthats<|t_0.24|><|code_start|><|121|><|1789|><|1443|><|370|><|1154|><|393|><|1178|><|1200|><|1264|><|424|><|1391|><|381|><|978|><|1346|><|704|><|1808|><|1579|><|1492|><|code_end|>\nreally<|t_0.56|><|code_start|><|1177|><|1761|><|1723|><|1360|><|1413|><|830|><|551|><|193|><|59|><|332|><|598|><|734|><|1684|><|1802|><|60|><|1590|><|353|><|89|><|1636|><|1396|><|893|><|143|><|455|><|1501|><|435|><|1082|><|621|><|1593|><|677|><|474|><|971|><|1513|><|913|><|828|><|1381|><|1148|><|1798|><|1186|><|1443|><|38|><|335|><|883|><|code_end|>\npowerful<|t_0.63|><|code_start|><|1773|><|458|><|1070|><|964|><|826|><|1220|><|1012|><|1738|><|1125|><|669|><|490|><|1169|><|922|><|958|><|1204|><|489|><|1001|><|886|><|1045|><|675|><|1471|><|1652|><|732|><|698|><|1124|><|480|><|897|><|1484|><|1028|><|35|><|594|><|1465|><|505|><|1669|><|436|><|851|><|1288|><|31|><|1501|><|1187|><|394|><|909|><|1541|><|1793|><|1720|><|922|><|840|><|code_end|>\nin<|t_0.16|><|code_start|><|1317|><|523|><|630|><|1343|><|1187|><|719|><|907|><|636|><|111|><|1524|><|188|><|1382|><|code_end|>\nand<|t_0.13|><|code_start|><|1074|><|922|><|1280|><|1496|><|1050|><|832|><|133|><|1435|><|1049|><|1774|><|code_end|>\nof<|t_0.12|><|code_start|><|960|><|1052|><|1192|><|1303|><|1112|><|970|><|417|><|60|><|1155|><|code_end|>\nitself<|t_0.47|><|code_start|><|1682|><|1209|><|1410|><|513|><|1222|><|861|><|167|><|406|><|1551|><|582|><|634|><|1529|><|786|><|1363|><|1578|><|1739|><|873|><|424|><|1041|><|1328|><|955|><|1110|><|1490|><|1424|><|1199|><|988|><|1162|><|1133|><|1193|><|978|><|470|><|832|><|963|><|1251|><|733|><|code_end|>", }; +// clang-format on const std::map Ones = { {0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, @@ -764,25 +689,67 @@ std::string processTTSPromptText(const std::string &Text) { return ProcessedText; } +std::optional +getSpeakerProfileFromFile(const std::string &FilePath) { + std::ifstream JsonFile(FilePath); + if (!JsonFile.is_open()) { + return std::nullopt; + } + nlohmann::json JsonData; + JsonFile >> JsonData; + JsonFile.close(); + + // Initialize the outputs + std::string AudioOutputText = "<|audio_start|>\n"; + std::string TextOutput = "<|text_start|>"; + + // Iterate through each word in the JSON data + for (const auto &WordData : JsonData["words"]) { + std::string Word = WordData["word"]; + double Duration = WordData["duration"]; + std::vector Codes = WordData["codes"]; + + // Create the audio output entry + std::ostringstream WordEntry; + WordEntry << Word << "<|t_" << std::fixed << std::setprecision(2) + << Duration << "|><|code_start|>"; + for (const auto &Code : Codes) { + WordEntry << "<|" << Code << "|>"; + } + WordEntry << "<|code_end|>\n"; + AudioOutputText += WordEntry.str(); + + // Create the text output entry + TextOutput += Word + "<|text_sep|>"; + } + + return TTSSpeakerProfile{TextOutput, AudioOutputText}; +} + std::vector processTTSPrompt(Graph &GraphRef, std::string &Prompt) noexcept { + // Use the custom speaker profile if available. + TTSSpeakerProfile SpeakerProfile = TTSDefaultSpeakerProfile; + if (!GraphRef.TTSSpeakerFilePath.empty()) { + std::optional SpeakerProfileOpt = + getSpeakerProfileFromFile(GraphRef.TTSSpeakerFilePath); + if (SpeakerProfileOpt.has_value()) { + SpeakerProfile = *SpeakerProfileOpt; + } else { + RET_ERROR( + {}, + "processTTSPrompt: Failed to load speaker profile from file: {}"sv, + GraphRef.TTSSpeakerFilePath); + } + } std::string ProcessedPrompt = processTTSPromptText(Prompt); std::vector Result, TmpTokens; Result = common_tokenize(GraphRef.LlamaContext.get(), "<|im_start|>\n", /* add_special */ true, /* parse_special */ true); - TmpTokens = common_tokenize( - GraphRef.LlamaContext.get(), - "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<" - "|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_" - "sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_" - "sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_" - "sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>" - "aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>" - "really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>" - "looks<|text_sep|>lovely<|text_sep|>", - /* add_special */ false, - /* parse_special */ true); + TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), SpeakerProfile.Text, + /* add_special */ false, + /* parse_special */ true); Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), ProcessedPrompt, /* add_special */ false, @@ -792,7 +759,10 @@ std::vector processTTSPrompt(Graph &GraphRef, /* add_special */ false, /* parse_special */ true); Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); - Result.insert(Result.end(), TTSVoiceData.begin(), TTSVoiceData.end()); + TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), SpeakerProfile.Data, + /* add_special */ false, + /* parse_special */ true); + Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); return Result; } @@ -1908,6 +1878,10 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // TTS prompt. LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt"sv) CxtRef.LlamaInputs = processTTSPrompt(GraphRef, Prompt); + if (CxtRef.LlamaInputs.empty()) { + RET_ERROR(ErrNo::InvalidArgument, + "setInput: failed to tokenize tts prompt."sv) + } LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt...Done"sv) } else { // Text only prompt. diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 51d788dd..6ef58053 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -38,6 +38,11 @@ enum class VisionModel : uint8_t { Qwen2VL = 1, }; +struct TTSSpeakerProfile { + std::string Text; + std::string Data; +}; + struct LocalConfig { // Configurations which can be changed in every contexts. // The graph handles a default config and parsed from metadata when loading. @@ -75,6 +80,7 @@ struct Graph { bool TextToSpeech = false; std::string VocoderModelPath; std::string TTSOutputFilePath = "output.wav"; + std::string TTSSpeakerFilePath; llama_model_ptr TTSModel = nullptr; llama_context_ptr TTSContext = nullptr; // Context parameters: From 18da99b4cc79dfc09acac8ef627ee1792e349be8 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Thu, 20 Feb 2025 11:09:13 +0800 Subject: [PATCH 536/623] [WASI-NN] ggml: add parameters (#3995) * [WASI-NN] ggml: add parameters Signed-off-by: grorge * [WASI-NN] ggml: refactor parameter Signed-off-by: grorge * [WASI-NN] ggml: init llama context Signed-off-by: grorge * [WASI-NN] ggml: add context parameter Signed-off-by: grorge * [WASI-NN] ggml: add sampling parameter Signed-off-by: grorge * [WASI-NN] ggml: add other parameter Signed-off-by: grorge * [WASI-NN] ggml: merge parameter Signed-off-by: grorge * [WASI-NN] ggml: add speculative, vocoder parameter Signed-off-by: grorge * [WASI-NN] ggml: merge and add new parameters Signed-off-by: grorge * [WASI-NN] ggml: remove duplicated parameter Signed-off-by: grorge --------- Signed-off-by: grorge Co-authored-by: hydai --- plugins/wasi_nn/wasinn_ggml.cpp | 1928 ++++++++++++++++++++++++++----- plugins/wasi_nn/wasinn_ggml.h | 25 +- 2 files changed, 1633 insertions(+), 320 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index ac16f0d3..b4abbd14 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -3,6 +3,7 @@ #include "wasinn_ggml.h" #include "wasinnenv.h" +#include #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML #include "simdjson.h" @@ -81,40 +82,6 @@ void llamaLogCallback(ggml_log_level LogLevel, const char *LogText, // >>>>>>>> Metadata related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> -// Setup llama sampler params from graph. -void setupSamplerParams(Graph &GraphRef, - common_params_sampling &Sampling) noexcept { - Sampling.temp = static_cast(GraphRef.Temp); - Sampling.top_p = static_cast(GraphRef.TopP); - Sampling.penalty_repeat = static_cast(GraphRef.RepeatPenalty); - Sampling.penalty_present = static_cast(GraphRef.PresencePenalty); - Sampling.penalty_freq = static_cast(GraphRef.FrequencyPenalty); - Sampling.grammar = GraphRef.Grammar; - Sampling.seed = static_cast(GraphRef.Seed); - - if (GraphRef.TextToSpeech) { - Sampling.top_k = 4; - Sampling.samplers = { - COMMON_SAMPLER_TYPE_TOP_K, - }; - } -} - -// Setup llama common params from graph. -void setupCommonParams(Graph &GraphRef, common_params &Params) noexcept { - Params.model = GraphRef.ModelFilePath; - Params.n_gpu_layers = static_cast(GraphRef.NGPULayers); - Params.n_ctx = static_cast(GraphRef.CtxSize); - Params.n_batch = static_cast(GraphRef.BatchSize); - Params.n_ubatch = static_cast(GraphRef.UBatchSize); - Params.warmup = GraphRef.WarmUp; - Params.split_mode = GraphRef.SplitMode; - Params.cpuparams.n_threads = static_cast(GraphRef.Threads); - Params.cpuparams_batch.n_threads = static_cast(GraphRef.Threads); - Params.embedding = GraphRef.Embedding; - setupSamplerParams(GraphRef, Params.sampling); -} - // Parse metadata from json. ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, const std::string &Metadata, bool *IsModelUpdated = nullptr, @@ -127,357 +94,1694 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, if (ParseError) { RET_ERROR(ErrNo::InvalidEncoding, "parse metadata error."sv) } - - // Currently supported metadata: - // Plugin parameters (used by this graph and created contexts): - // enable-log: bool - // enable-debug-log: bool - // Model parameters (need to reload the model if updated): - // main-gpu: int64_t - // n-gpu-layers: int64_t - // tensor-split: string, comma-separated floating number list - // embedding: bool - // use-mmap: bool - // warmup: bool - // split-mode: string, {none,layer,row} - // mmproj: string - // TTS parameters: - // tts: bool - // model-vocoder: string - // tts-output-file: string - // tts-speaker-file: string - // Context parameters (used by the llama context): - // ctx-size: int64_t - // batch-size: int64_t - // ubatch-size: int64_t - // threads: int64_t - // [local-config] always-regenerate-image-embd: bool - // Sampling parameters (used by the llama sampling context): - // temp: double - // top-p: double - // repeat-penalty: double - // presence-penalty: double - // frequency-penalty: double - // grammar: string - // seed: uint64_t - // Config parameters (mutable config at runtime for contexts): - // stream-stdout: bool - // n-predict: int64_t - // reverse-prompt: string - // image: string - - // Get the current llama parameters. - int64_t PrevNGPULayers = GraphRef.NGPULayers; - bool PrevEmbedding = GraphRef.Embedding; - // Get the current sampler parameters. - double PrevTemp = GraphRef.Temp; - double PrevTopP = GraphRef.TopP; - double PrevRepeatPenalty = GraphRef.RepeatPenalty; - double PrevPresencePenalty = GraphRef.PresencePenalty; - double PrevFrequencyPenalty = GraphRef.FrequencyPenalty; - std::string PrevGrammar = GraphRef.Grammar; - uint64_t PrevSeed = GraphRef.Seed; - - // The plugin parameters. - if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { - auto Err = Doc["enable-log"].get().get(GraphRef.EnableLog); + + // Get the current llama parameters. + int64_t PrevNGPULayers = GraphRef.Params.n_gpu_layers; + bool PrevEmbedding = GraphRef.Params.embedding; + // Get the current sampler parameters. + double PrevTemp = GraphRef.Params.sampling.temp; + double PrevTopP = GraphRef.Params.sampling.top_p; + double PrevRepeatPenalty = GraphRef.Params.sampling.penalty_repeat; + double PrevPresencePenalty = GraphRef.Params.sampling.penalty_present; + double PrevFrequencyPenalty = GraphRef.Params.sampling.penalty_freq; + std::string PrevGrammar = GraphRef.Params.sampling.grammar; + uint64_t PrevSeed = GraphRef.Params.sampling.seed; + + // The plugin parameters. + if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-log"].get().get(GraphRef.EnableLog); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the enable-log option."sv) + } + } + if (Doc.at_key("enable-debug-log").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-debug-log"].get().get(GraphRef.EnableDebugLog); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the enable-debug-log option."sv) + } + } + + // The model parameters. + if (Doc.at_key("main-gpu").error() == simdjson::SUCCESS) { + int64_t MainGPU; + auto Err = Doc["main-gpu"].get().get(MainGPU); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the main-gpu option."sv) + } + GraphRef.Params.main_gpu = static_cast(MainGPU); + } + if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { + int64_t NGPULayers; + auto Err = Doc["n-gpu-layers"].get().get(NGPULayers); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-gpu-layers option."sv) + } + GraphRef.Params.n_gpu_layers = static_cast(NGPULayers); + } + if (Doc.at_key("tensor-split").error() == simdjson::SUCCESS) { + // The TensorSplit is a comma-separated list of non-negative values. + // E.g., "3,2" presents 60% of the data to GPU 0 and 40% to GPU 1. + std::string_view TSV; + auto Err = Doc["tensor-split"].get().get(TSV); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the tensor-split option."sv) + } + std::string TS(TSV); + std::replace(TS.begin(), TS.end(), ',', ' '); + std::stringstream SS(TS); + std::memset(GraphRef.Params.tensor_split, 0, + sizeof(GraphRef.Params.tensor_split)); + uint32_t TensorSplitSize = 0; + while (SS.good()) { + float TmpTensor; + SS >> TmpTensor; + GraphRef.Params.tensor_split[TensorSplitSize++] = TmpTensor; + } + size_t NDevices = llama_max_devices(); + if (TensorSplitSize > NDevices) { + RET_ERROR( + ErrNo::InvalidArgument, + "Number of Tensor-Split is larger than MaxDevices, please reduce "sv + "the size of tensor-split."sv) + } + for (size_t Idx = TensorSplitSize; Idx < NDevices; Idx++) { + GraphRef.Params.tensor_split[TensorSplitSize++] = 0.0f; + } + } + if (Doc.at_key("embedding").error() == simdjson::SUCCESS) { + + auto Err = Doc["embedding"].get().get(GraphRef.Params.embedding); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embedding option."sv) + } + } + if (Doc.at_key("split-mode").error() == simdjson::SUCCESS) { + std::string_view SplitMode; + auto Err = Doc["split-mode"].get().get(SplitMode); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the split-mode option."sv) + } + if (SplitMode == "none"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_NONE; + } else if (SplitMode == "layer"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_LAYER; + } else if (SplitMode == "row"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_ROW; + } else { + RET_ERROR(ErrNo::InvalidArgument, + "Unknown split-mode: {}. Valid: none, layer, row."sv, SplitMode) + } + } + if (Doc.at_key("mmproj").error() == simdjson::SUCCESS) { + std::string_view MMProjModelPath; + auto Err = Doc["mmproj"].get().get(MMProjModelPath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mmproj option."sv) + } + GraphRef.Params.mmproj = MMProjModelPath; + } + + // The TTS parameters. + if (Doc.at_key("tts").error() == simdjson::SUCCESS) { + auto Err = Doc["tts"].get().get(GraphRef.TextToSpeech); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the tts option."sv) + } + } + if (Doc.at_key("model-vocoder").error() == simdjson::SUCCESS) { + std::string_view VocoderModelPath; + auto Err = + Doc["model-vocoder"].get().get(VocoderModelPath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the model-vocoder option."sv) + } + GraphRef.Params.vocoder.model = VocoderModelPath; + } + if (Doc.at_key("tts-output-file").error() == simdjson::SUCCESS) { + std::string_view TTSOutputFilePath; + auto Err = + Doc["tts-output-file"].get().get(TTSOutputFilePath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the tts-output-file option."sv) + } + GraphRef.TTSOutputFilePath = TTSOutputFilePath; + } + if (Doc.at_key("tts-speaker-file").error() == simdjson::SUCCESS) { + std::string_view TTSSpeakerFilePath; + auto Err = + Doc["tts-speaker-file"].get().get(TTSSpeakerFilePath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the tts-speaker-file option."sv) + } + GraphRef.TTSSpeakerFilePath = TTSSpeakerFilePath; + } + + // The context parameters. + if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { + int64_t CtxSize; + auto Err = Doc["ctx-size"].get().get(CtxSize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ctx-size option."sv) + } + GraphRef.Params.n_ctx = static_cast(CtxSize); + } + if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { + int64_t BatchSize; + auto Err = Doc["batch-size"].get().get(BatchSize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the batch-size option."sv) + } + GraphRef.Params.n_batch = static_cast(BatchSize); + } + if (Doc.at_key("ubatch-size").error() == simdjson::SUCCESS) { + int64_t UBatchSize; + auto Err = Doc["ubatch-size"].get().get(UBatchSize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ubatch-size option."sv) + } + GraphRef.Params.n_batch = static_cast(UBatchSize); + } + if (Doc.at_key("n-keep").error() == simdjson::SUCCESS) { + int64_t NKeep; + auto Err = Doc["n-keep"].get().get(NKeep); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the n-keep option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.n_keep = static_cast(NKeep); + } + if (Doc.at_key("n-chunks").error() == simdjson::SUCCESS) { + int64_t NChunks; + auto Err = Doc["n-chunks"].get().get(NChunks); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the n-chunks option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.n_chunks = static_cast(NChunks); + } + if (Doc.at_key("n-parallel").error() == simdjson::SUCCESS) { + int64_t NParallel; + auto Err = Doc["n-parallel"].get().get(NParallel); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the n-parallel option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.n_parallel = static_cast(NParallel); + } + if (Doc.at_key("n-sequences").error() == simdjson::SUCCESS) { + int64_t NSequences; + auto Err = Doc["n-sequences"].get().get(NSequences); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the n-sequences option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.n_sequences = static_cast(NSequences); + } + if (Doc.at_key("grp-attn-n").error() == simdjson::SUCCESS) { + int64_t GrpAttnN; + auto Err = Doc["grp-attn-n"].get().get(GrpAttnN); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the grp-attn-n option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.grp_attn_n = static_cast(GrpAttnN); + } + if (Doc.at_key("grp-attn-w").error() == simdjson::SUCCESS) { + int64_t GrpAttnW; + auto Err = Doc["grp-attn-w"].get().get(GrpAttnW); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the grp-attn-w option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.grp_attn_w = static_cast(GrpAttnW); + } + if (Doc.at_key("n-print").error() == simdjson::SUCCESS) { + int64_t NPrint; + auto Err = Doc["n-print"].get().get(NPrint); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the n-print option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.n_print = static_cast(NPrint); + } + if (Doc.at_key("rope-freq-base").error() == simdjson::SUCCESS) { + double RopeFreqBase; + auto Err = Doc["rope-freq-base"].get().get(RopeFreqBase); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the rope-freq-base option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.rope_freq_base = static_cast(RopeFreqBase); + } + if (Doc.at_key("rope-freq-scale").error() == simdjson::SUCCESS) { + double RopeFreqScale; + auto Err = Doc["rope-freq-scale"].get().get(RopeFreqScale); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the rope-freq-scale option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.rope_freq_scale = static_cast(RopeFreqScale); + } + if (Doc.at_key("yarn-ext-factor").error() == simdjson::SUCCESS) { + double YarnExtFactor; + auto Err = Doc["yarn-ext-factor"].get().get(YarnExtFactor); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the yarn-ext-factor option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.yarn_ext_factor = static_cast(YarnExtFactor); + } + if (Doc.at_key("yarn-attn-factor").error() == simdjson::SUCCESS) { + double YarnAttnFactor; + auto Err = Doc["yarn-attn-factor"].get().get(YarnAttnFactor); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the yarn-attn-factor option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.yarn_attn_factor = static_cast(YarnAttnFactor); + } + if (Doc.at_key("yarn-beta-fast").error() == simdjson::SUCCESS) { + double YarnBetaFast; + auto Err = Doc["yarn-beta-fast"].get().get(YarnBetaFast); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the yarn-beta-fast option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.yarn_beta_fast = static_cast(YarnBetaFast); + } + if (Doc.at_key("yarn-beta-slow").error() == simdjson::SUCCESS) { + double YarnBetaSlow; + auto Err = Doc["yarn-beta-slow"].get().get(YarnBetaSlow); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the yarn-beta-slow option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.yarn_beta_slow = static_cast(YarnBetaSlow); + } + if (Doc.at_key("yarn-orig-ctx").error() == simdjson::SUCCESS) { + int64_t YarnOrigCtx; + auto Err = Doc["yarn-orig-ctx"].get().get(YarnOrigCtx); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the yarn-orig-ctx option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.yarn_orig_ctx = static_cast(YarnOrigCtx); + } + if (Doc.at_key("defrag-thold").error() == simdjson::SUCCESS) { + double DefragThold; + auto Err = Doc["defrag-thold"].get().get(DefragThold); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the defrag-thold option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.defrag_thold = static_cast(DefragThold); + } + if (Doc.at_key("mask-valid").error() == simdjson::SUCCESS) { + auto Err = + Doc["mask-valid"].get().get(GraphRef.Params.cpuparams.mask_valid); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the mask-valid option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("priority").error() == simdjson::SUCCESS) { + int64_t Priority; + auto Err = Doc["priority"].get().get(Priority); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the priority option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.cpuparams.priority = + static_cast(Priority); + } + if (Doc.at_key("strict-cpu").error() == simdjson::SUCCESS) { + auto Err = + Doc["strict-cpu"].get().get(GraphRef.Params.cpuparams.strict_cpu); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the strict-cpu option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("poll").error() == simdjson::SUCCESS) { + int64_t Poll; + auto Err = Doc["poll"].get().get(Poll); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the poll option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.cpuparams.poll = static_cast(Poll); + } + if (Doc.at_key("mask-valid-batch").error() == simdjson::SUCCESS) { + auto Err = Doc["mask-valid-batch"].get().get( + GraphRef.Params.cpuparams_batch.mask_valid); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the mask-valid-batch option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("priority-batch").error() == simdjson::SUCCESS) { + int64_t Priority; + auto Err = Doc["priority-batch"].get().get(Priority); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the priority-batch option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.cpuparams_batch.priority = + static_cast(Priority); + } + if (Doc.at_key("strict-cpu-batch").error() == simdjson::SUCCESS) { + auto Err = Doc["strict-cpu-batch"].get().get( + GraphRef.Params.cpuparams_batch.strict_cpu); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the strict-cpu-batch option."sv); + return ErrNo::InvalidArgument; + } + } + if (Doc.at_key("poll-batch").error() == simdjson::SUCCESS) { + int64_t Poll; + auto Err = Doc["poll-batch"].get().get(Poll); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the poll-batch option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.cpuparams_batch.poll = static_cast(Poll); + } + if (Doc.at_key("numa").error() == simdjson::SUCCESS) { + int64_t Numa; + auto Err = Doc["numa"].get().get(Numa); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the numa option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.numa = static_cast(Numa); + } + if (Doc.at_key("rope-scaling-type").error() == simdjson::SUCCESS) { + int64_t RopeScalingType; + auto Err = Doc["rope-scaling-type"].get().get(RopeScalingType); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the rope-scaling-type option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.rope_scaling_type = + static_cast(RopeScalingType); + } + if (Doc.at_key("pooling-type").error() == simdjson::SUCCESS) { + int64_t PoolingType; + auto Err = Doc["pooling-type"].get().get(PoolingType); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the pooling-type option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.pooling_type = + static_cast(PoolingType); + } + if (Doc.at_key("attention-type").error() == simdjson::SUCCESS) { + int64_t AttentionType; + auto Err = Doc["attention-type"].get().get(AttentionType); + if (Err) { + spdlog::error( + "[WASI-NN] GGML backend: Unable to retrieve the attention-type option."sv); + return ErrNo::InvalidArgument; + } + GraphRef.Params.attention_type = + static_cast(AttentionType); + } + if (Doc.at_key("threads").error() == simdjson::SUCCESS) { + int64_t NThreads; + auto Err = Doc["threads"].get().get(NThreads); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the threads option."sv) + } + GraphRef.Params.cpuparams.n_threads = static_cast(NThreads); + } + if (Doc.at_key("threads-batch").error() == simdjson::SUCCESS) { + int64_t NThreadsBatch; + auto Err = Doc["threads-batch"].get().get(NThreadsBatch); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the threads-batch option."sv) + } + GraphRef.Params.cpuparams_batch.n_threads = + static_cast(NThreadsBatch); + } + + // The sampling parameters. + if (Doc.at_key("n-prev").error() == simdjson::SUCCESS) { + int64_t NPrev; + auto Err = Doc["n-prev"].get().get(NPrev); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n_prev option."sv) + } + GraphRef.Params.sampling.n_prev = static_cast(NPrev); + } + if (Doc.at_key("n-probs").error() == simdjson::SUCCESS) { + int64_t NProbs; + auto Err = Doc["n-probs"].get().get(NProbs); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n_probs option."sv) + } + GraphRef.Params.sampling.n_probs = static_cast(NProbs); + } + if (Doc.at_key("min-keep").error() == simdjson::SUCCESS) { + int64_t MinKeep; + auto Err = Doc["min-keep"].get().get(MinKeep); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the min-keep option."sv) + } + GraphRef.Params.sampling.min_keep = static_cast(MinKeep); + } + if (Doc.at_key("top-k").error() == simdjson::SUCCESS) { + int64_t TopK; + auto Err = Doc["top-k"].get().get(TopK); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the top-k option."sv) + } + GraphRef.Params.sampling.top_k = static_cast(TopK); + } + if (Doc.at_key("min-p").error() == simdjson::SUCCESS) { + double MinP; + auto Err = Doc["min-p"].get().get(MinP); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the min-p option."sv) + } + GraphRef.Params.sampling.min_p = static_cast(MinP); + } + if (Doc.at_key("xtc-probability").error() == simdjson::SUCCESS) { + double XtcProbability; + auto Err = Doc["xtc-probability"].get().get(XtcProbability); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the xtc-probability option."sv) + } + GraphRef.Params.sampling.xtc_probability = + static_cast(XtcProbability); + } + if (Doc.at_key("xtc-threshold").error() == simdjson::SUCCESS) { + double XtcThreshold; + auto Err = Doc["xtc-threshold"].get().get(XtcThreshold); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the xtc-threshold option."sv) + } + GraphRef.Params.sampling.xtc_threshold = static_cast(XtcThreshold); + } + if (Doc.at_key("typ-p").error() == simdjson::SUCCESS) { + double TypP; + auto Err = Doc["typ-p"].get().get(TypP); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the typ-p option."sv) + } + GraphRef.Params.sampling.typ_p = static_cast(TypP); + } + if (Doc.at_key("dynatemp-range").error() == simdjson::SUCCESS) { + double DynaTempRange; + auto Err = Doc["dynatemp-range"].get().get(DynaTempRange); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the dynatemp-range option."sv) + } + GraphRef.Params.sampling.dynatemp_range = static_cast(DynaTempRange); + } + if (Doc.at_key("dynatemp-exponent").error() == simdjson::SUCCESS) { + double DynaTempExponent; + auto Err = Doc["dynatemp-exponent"].get().get(DynaTempExponent); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the dynatemp-exponent option."sv) + } + GraphRef.Params.sampling.dynatemp_exponent = + static_cast(DynaTempExponent); + } + if (Doc.at_key("last-n-penalty").error() == simdjson::SUCCESS) { + int64_t LastNPenalty; + auto Err = Doc["last-n-penalty"].get().get(LastNPenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the last-n-penalty option."sv) + } + GraphRef.Params.sampling.penalty_last_n = + static_cast(LastNPenalty); + } + if (Doc.at_key("temp").error() == simdjson::SUCCESS) { + double Temp; + auto Err = Doc["temp"].get().get(Temp); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the temp option."sv) + } + GraphRef.Params.sampling.temp = static_cast(std::max(0.0, Temp)); + } + if (Doc.at_key("top-p").error() == simdjson::SUCCESS) { + double TopP; + auto Err = Doc["top-p"].get().get(TopP); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the top-p option."sv) + } + GraphRef.Params.sampling.top_p = static_cast(std::max(0.0, TopP)); + } + if (Doc.at_key("repeat-penalty").error() == simdjson::SUCCESS) { + double RepeatPenalty; + auto Err = Doc["repeat-penalty"].get().get(RepeatPenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the repeat-penalty option."sv) + } + GraphRef.Params.sampling.penalty_repeat = + static_cast(std::max(0.0, RepeatPenalty)); + } + if (Doc.at_key("presence-penalty").error() == simdjson::SUCCESS) { + double PresencePenalty; + auto Err = Doc["presence-penalty"].get().get(PresencePenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the presence-penalty option."sv) + } + GraphRef.Params.sampling.penalty_present = + static_cast(std::max(0.0, PresencePenalty)); + } + if (Doc.at_key("frequency-penalty").error() == simdjson::SUCCESS) { + double FrequencyPenalty; + auto Err = Doc["frequency-penalty"].get().get(FrequencyPenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the frequency-penalty option."sv) + } + GraphRef.Params.sampling.penalty_freq = + static_cast(std::max(0.0, FrequencyPenalty)); + } + if (Doc.at_key("dry-multipier").error() == simdjson::SUCCESS) { + double DryMultiplier; + auto Err = Doc["dry-multipier"].get().get(DryMultiplier); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the dry-multipier option."sv) + } + GraphRef.Params.sampling.dry_multiplier = static_cast(DryMultiplier); + } + if (Doc.at_key("dry-base").error() == simdjson::SUCCESS) { + double DryBase; + auto Err = Doc["dry-base"].get().get(DryBase); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the dry-base option."sv) + } + GraphRef.Params.sampling.dry_base = static_cast(DryBase); + } + if (Doc.at_key("dry-allowed-length").error() == simdjson::SUCCESS) { + int64_t DryAllowedLength; + auto Err = Doc["dry-allowed-length"].get().get(DryAllowedLength); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the dry-allowed-length option."sv) + } + GraphRef.Params.sampling.dry_allowed_length = + static_cast(DryAllowedLength); + } + if (Doc.at_key("dry-penalty-last-n").error() == simdjson::SUCCESS) { + int64_t DryLastNPenalty; + auto Err = Doc["dry-last-n-penalty"].get().get(DryLastNPenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the dry-last-n-penalty option."sv) + } + GraphRef.Params.sampling.penalty_last_n = + static_cast(DryLastNPenalty); + } + if (Doc.at_key("mirostat").error() == simdjson::SUCCESS) { + int64_t Mirostat; + auto Err = Doc["mirostat"].get().get(Mirostat); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mirostat option."sv) + } + GraphRef.Params.sampling.mirostat = static_cast(Mirostat); + } + if (Doc.at_key("mirostat-eta").error() == simdjson::SUCCESS) { + double MirostatEta; + auto Err = Doc["mirostat-eta"].get().get(MirostatEta); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mirostat-eta option."sv) + } + GraphRef.Params.sampling.mirostat_eta = static_cast(MirostatEta); + } + if (Doc.at_key("ignore-eos").error() == simdjson::SUCCESS) { + auto Err = + Doc["ignore-eos"].get().get(GraphRef.Params.sampling.ignore_eos); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ignore-eos option."sv) + } + } + if (Doc.at_key("no-perf-sampling").error() == simdjson::SUCCESS) { + auto Err = Doc["no-perf-sampling"].get().get( + GraphRef.Params.sampling.no_perf); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the no-perf-sampling option."sv) + } + } + if (Doc.at_key("timing-per-token").error() == simdjson::SUCCESS) { + auto Err = Doc["timing-per-token"].get().get( + GraphRef.Params.sampling.timing_per_token); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the timing-per-token option."sv) + } + } + if (Doc.at_key("grammar").error() == simdjson::SUCCESS) { + std::string_view Grammar; + auto Err = Doc["grammar"].get().get(Grammar); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the grammar option."sv) + } + GraphRef.Params.sampling.grammar = Grammar; + } + if (Doc.at_key("json-schema").error() == simdjson::SUCCESS) { + std::string_view JsonSchema; + auto Err = Doc["json-schema"].get().get(JsonSchema); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the json-schema option."sv) + } + GraphRef.Params.sampling.grammar = + json_schema_to_grammar(nlohmann::ordered_json::parse(JsonSchema)); + } + if (Doc.at_key("seed").error() == simdjson::SUCCESS) { + uint64_t Seed; + auto Err = Doc["seed"].get().get(Seed); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the seed option."sv) + } + GraphRef.Params.sampling.seed = static_cast(Seed); + } + // The speculative parameters. + if (Doc.at_key("n-ctx-speculative").error() == simdjson::SUCCESS) { + int64_t NCtxSpeculative; + auto Err = Doc["n-ctx-speculative"].get().get(NCtxSpeculative); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-ctx-speculative option."sv) + } + GraphRef.Params.speculative.n_ctx = static_cast(NCtxSpeculative); + } + if (Doc.at_key("n-max-speculative").error() == simdjson::SUCCESS) { + int64_t NMaxSpeculative; + auto Err = Doc["n-max-speculative"].get().get(NMaxSpeculative); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-max-speculative option."sv) + } + GraphRef.Params.speculative.n_max = static_cast(NMaxSpeculative); + } + if (Doc.at_key("n-min-speculative").error() == simdjson::SUCCESS) { + int64_t NMinSpeculative; + auto Err = Doc["n-min-speculative"].get().get(NMinSpeculative); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-min-speculative option."sv) + } + GraphRef.Params.speculative.n_min = static_cast(NMinSpeculative); + } + if (Doc.at_key("n-gpu-layers-speculative").error() == simdjson::SUCCESS) { + int64_t NGPULatersinSpeculative; + auto Err = Doc["n-gpu-layers-speculative"].get().get( + NGPULatersinSpeculative); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-gpu-layers-speculative option."sv) + } + GraphRef.Params.speculative.n_gpu_layers = + static_cast(NGPULatersinSpeculative); + } + if (Doc.at_key("p-split-speculative").error() == simdjson::SUCCESS) { + double PSplitSpeculative; + auto Err = Doc["p-split-speculative"].get().get(PSplitSpeculative); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the p-split-speculative option."sv) + } + GraphRef.Params.speculative.p_split = static_cast(PSplitSpeculative); + } + if (Doc.at_key("p-min-speculative").error() == simdjson::SUCCESS) { + double PMinSpeculative; + auto Err = Doc["p-min-speculative"].get().get(PMinSpeculative); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the p-min-speculative option."sv) + } + GraphRef.Params.speculative.p_min = static_cast(PMinSpeculative); + } + // The vocoder parameters. + if (Doc.at_key("hf-repo-vocoder").error() == simdjson::SUCCESS) { + std::string_view HfRepo; + auto Err = Doc["hf-repo-vocoder"].get().get(HfRepo); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hf-repo-vocoder option."sv) + } + GraphRef.Params.vocoder.hf_repo = HfRepo; + } + if (Doc.at_key("hf-file-vocoder").error() == simdjson::SUCCESS) { + std::string_view HfFile; + auto Err = Doc["hf-file-vocoder"].get().get(HfFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hf-file-vocoder option."sv) + } + GraphRef.Params.vocoder.hf_file = HfFile; + } + if (Doc.at_key("model-url-vocoder").error() == simdjson::SUCCESS) { + std::string_view ModelUrlVocoder; + auto Err = + Doc["model-url-vocoder"].get().get(ModelUrlVocoder); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the model-url-vocoder option."sv) + } + GraphRef.Params.vocoder.model_url = ModelUrlVocoder; + } + // The config parameters. + if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { + auto Err = Doc["stream-stdout"].get().get(ConfRef.StreamStdout); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the stream-stdout option."sv) + } + } + if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { + auto Err = Doc["n-predict"].get().get(ConfRef.NPredict); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-predict option."sv) + } + } + if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { + std::string_view ReversePrompt; + auto Err = Doc["reverse-prompt"].get().get(ReversePrompt); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the reverse-prompt option."sv) + } + ConfRef.ReversePrompt = ReversePrompt; + } + if (Doc.at_key("image").error() == simdjson::SUCCESS) { + std::string_view ImagePath; + auto Err = Doc["image"].get().get(ImagePath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the image option."sv) + } + ConfRef.ImagePath = ImagePath; + } + if (Doc.at_key("always-regenerate-image-embd").error() == simdjson::SUCCESS) { + auto Err = Doc["always-regenerate-image-embd"].get().get( + ConfRef.AlwaysRegenerateImageEmbd); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the always-regenerate-image-embd option."sv) + } + } + if (Doc.at_key("model-alias").error() == simdjson::SUCCESS) { + std::string_view ModelAlias; + auto Err = Doc["model-alias"].get().get(ModelAlias); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the model-alias option."sv) + } + GraphRef.Params.model_alias = ModelAlias; + } + if (Doc.at_key("model-url").error() == simdjson::SUCCESS) { + std::string_view ModelUrl; + auto Err = Doc["model-url"].get().get(ModelUrl); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the model-url option."sv) + } + GraphRef.Params.model_url = ModelUrl; + } + if (Doc.at_key("hf-token").error() == simdjson::SUCCESS) { + std::string_view HfToken; + auto Err = Doc["hf-token"].get().get(HfToken); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hf-token option."sv) + } + GraphRef.Params.hf_token = HfToken; + } + if (Doc.at_key("hf-repo").error() == simdjson::SUCCESS) { + std::string_view HfRepo; + auto Err = Doc["hf-repo"].get().get(HfRepo); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hf-repo option."sv) + } + GraphRef.Params.hf_repo = HfRepo; + } + if (Doc.at_key("hf-file").error() == simdjson::SUCCESS) { + std::string_view HfFile; + auto Err = Doc["hf-file"].get().get(HfFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hf-file option."sv) + } + GraphRef.Params.hf_file = HfFile; + } + if (Doc.at_key("prompt-file").error() == simdjson::SUCCESS) { + std::string_view PromptFile; + auto Err = Doc["prompt-file"].get().get(PromptFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the prompt-file option."sv) + } + GraphRef.Params.prompt_file = PromptFile; + } + if (Doc.at_key("path-prompt-cache").error() == simdjson::SUCCESS) { + std::string_view PathPromptCache; + auto Err = + Doc["path-prompt-cache"].get().get(PathPromptCache); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the path-prompt-cache option."sv) + } + GraphRef.Params.path_prompt_cache = PathPromptCache; + } + if (Doc.at_key("input-prefix").error() == simdjson::SUCCESS) { + std::string_view InputPrefix; + auto Err = Doc["input-prefix"].get().get(InputPrefix); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the input-prefix option."sv) + } + GraphRef.Params.input_prefix = InputPrefix; + } + if (Doc.at_key("input-suffix").error() == simdjson::SUCCESS) { + std::string_view InputSuffix; + auto Err = Doc["input-suffix"].get().get(InputSuffix); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the input-suffix option."sv) + } + GraphRef.Params.input_suffix = InputSuffix; + } + if (Doc.at_key("lookup-cache-static").error() == simdjson::SUCCESS) { + std::string_view LookupCacheStatic; + auto Err = Doc["lookup-cache-static"].get().get( + LookupCacheStatic); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the lookup-cache-static option."sv) + } + GraphRef.Params.lookup_cache_static = LookupCacheStatic; + } + if (Doc.at_key("lookup-cache-dynamic").error() == simdjson::SUCCESS) { + std::string_view LookupCacheDynamic; + auto Err = Doc["lookup-cache-dynamic"].get().get( + LookupCacheDynamic); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the lookup-cache-dynamic option."sv) + } + GraphRef.Params.lookup_cache_dynamic = LookupCacheDynamic; + } + if (Doc.at_key("logits-file").error() == simdjson::SUCCESS) { + std::string_view LogitsFile; + auto Err = Doc["logits-file"].get().get(LogitsFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the logits-file option."sv) + } + GraphRef.Params.logits_file = LogitsFile; + } + if (Doc.at_key("lora-init-without-apply").error() == simdjson::SUCCESS) { + auto Err = Doc["lora-init-without-apply"].get().get( + GraphRef.Params.lora_init_without_apply); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the lora-init-without-apply option."sv) + } + } + if (Doc.at_key("verbosity").error() == simdjson::SUCCESS) { + int64_t Verbosity; + auto Err = Doc["verbosity"].get().get(Verbosity); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the verbosity option."sv) + } + GraphRef.Params.verbosity = static_cast(Verbosity); + } + if (Doc.at_key("control-vector-layer-start").error() == simdjson::SUCCESS) { + int64_t ControlVectorLayerStart; + auto Err = Doc["control-vector-layer-start"].get().get( + ControlVectorLayerStart); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the control-vector-layer-start option."sv) + } + GraphRef.Params.control_vector_layer_start = + static_cast(ControlVectorLayerStart); + } + if (Doc.at_key("control-vector-layer-end").error() == simdjson::SUCCESS) { + int64_t ControlVectorLayerEnd; + auto Err = Doc["control-vector-layer-end"].get().get( + ControlVectorLayerEnd); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the control-vector-layer-end option."sv) + } + GraphRef.Params.control_vector_layer_end = + static_cast(ControlVectorLayerEnd); + } + if (Doc.at_key("ppl-stride").error() == simdjson::SUCCESS) { + int64_t PplStride; + auto Err = Doc["ppl-stride"].get().get(PplStride); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ppl-stride option."sv) + } + GraphRef.Params.ppl_stride = static_cast(PplStride); + } + if (Doc.at_key("ppl-output-type").error() == simdjson::SUCCESS) { + int64_t PplOutputType; + auto Err = Doc["ppl-output-type"].get().get(PplOutputType); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ppl-output-type option."sv) + } + GraphRef.Params.ppl_output_type = static_cast(PplOutputType); + } + if (Doc.at_key("hellaswag").error() == simdjson::SUCCESS) { + auto Err = Doc["hellaswag"].get().get(GraphRef.Params.hellaswag); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hellaswag option."sv) + } + } + if (Doc.at_key("hellaswag-tasks").error() == simdjson::SUCCESS) { + uint64_t HellaswagTasks; + auto Err = Doc["hellaswag-tasks"].get().get(HellaswagTasks); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hellaswag-tasks option."sv) + } + GraphRef.Params.hellaswag_tasks = HellaswagTasks; + } + if (Doc.at_key("winogrande").error() == simdjson::SUCCESS) { + auto Err = Doc["winogrande"].get().get(GraphRef.Params.winogrande); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the winogrande option."sv) + } + } + if (Doc.at_key("winogrande-tasks").error() == simdjson::SUCCESS) { + uint64_t WinograndeTasks; + auto Err = Doc["winogrande-tasks"].get().get(WinograndeTasks); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the winogrande-tasks option."sv) + } + GraphRef.Params.winogrande_tasks = WinograndeTasks; + } + if (Doc.at_key("multiple-choice").error() == simdjson::SUCCESS) { + auto Err = + Doc["multiple-choice"].get().get(GraphRef.Params.multiple_choice); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the multiple-choice option."sv) + } + } + if (Doc.at_key("multiple-choice-tasks").error() == simdjson::SUCCESS) { + uint64_t MultipleChoiceTasks; + auto Err = + Doc["multiple-choice-tasks"].get().get(MultipleChoiceTasks); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the multiple-choice-tasks option."sv) + } + GraphRef.Params.multiple_choice_tasks = MultipleChoiceTasks; + } + if (Doc.at_key("kl-divergence").error() == simdjson::SUCCESS) { + auto Err = + Doc["kl-divergence"].get().get(GraphRef.Params.kl_divergence); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the kl-divergence option."sv) + } + } + if (Doc.at_key("usage").error() == simdjson::SUCCESS) { + auto Err = Doc["usage"].get().get(GraphRef.Params.usage); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the usage option."sv) + } + } + if (Doc.at_key("use-color").error() == simdjson::SUCCESS) { + auto Err = Doc["use-color"].get().get(GraphRef.Params.use_color); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the use-color option."sv) + } + } + if (Doc.at_key("special").error() == simdjson::SUCCESS) { + auto Err = Doc["special"].get().get(GraphRef.Params.special); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the special option."sv) + } + } + if (Doc.at_key("interactive").error() == simdjson::SUCCESS) { + auto Err = Doc["interactive"].get().get(GraphRef.Params.interactive); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the interactive option."sv) + } + } + if (Doc.at_key("interactive-first").error() == simdjson::SUCCESS) { + auto Err = Doc["interactive-first"].get().get( + GraphRef.Params.interactive_first); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the interactive-first option."sv) + } + } + if (Doc.at_key("prompt-cache-all").error() == simdjson::SUCCESS) { + auto Err = Doc["prompt-cache-all"].get().get( + GraphRef.Params.prompt_cache_all); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the prompt-cache-all option."sv) + } + } + if (Doc.at_key("prompt-cache-ro").error() == simdjson::SUCCESS) { + auto Err = + Doc["prompt-cache-ro"].get().get(GraphRef.Params.prompt_cache_ro); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the prompt-cache-ro option."sv) + } + } + if (Doc.at_key("escape").error() == simdjson::SUCCESS) { + auto Err = Doc["escape"].get().get(GraphRef.Params.escape); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the escape option."sv) + } + } + if (Doc.at_key("multiline-input").error() == simdjson::SUCCESS) { + auto Err = + Doc["multiline-input"].get().get(GraphRef.Params.multiline_input); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the multiline-input option."sv) + } + } + if (Doc.at_key("simple-io").error() == simdjson::SUCCESS) { + auto Err = Doc["simple-io"].get().get(GraphRef.Params.simple_io); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the simple-io option."sv) + } + } + if (Doc.at_key("cont-batching").error() == simdjson::SUCCESS) { + auto Err = + Doc["cont-batching"].get().get(GraphRef.Params.cont_batching); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cont-batching option."sv) + } + } + if (Doc.at_key("flash-attn").error() == simdjson::SUCCESS) { + auto Err = Doc["flash-attn"].get().get(GraphRef.Params.flash_attn); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the flash-attn option."sv) + } + } + if (Doc.at_key("no-perf").error() == simdjson::SUCCESS) { + auto Err = Doc["no-perf"].get().get(GraphRef.Params.no_perf); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the enable-log option."sv) + "Unable to retrieve the no-perf option."sv) } } - if (Doc.at_key("enable-debug-log").error() == simdjson::SUCCESS) { - auto Err = Doc["enable-debug-log"].get().get(GraphRef.EnableDebugLog); + if (Doc.at_key("ctx-shift").error() == simdjson::SUCCESS) { + auto Err = Doc["ctx-shift"].get().get(GraphRef.Params.ctx_shift); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the enable-debug-log option."sv) + "Unable to retrieve the ctx-shift option."sv) } } - - // The model parameters. - if (Doc.at_key("main-gpu").error() == simdjson::SUCCESS) { - auto Err = Doc["main-gpu"].get().get(GraphRef.MainGPU); + if (Doc.at_key("input-prefix-bos").error() == simdjson::SUCCESS) { + auto Err = Doc["input-prefix-bos"].get().get( + GraphRef.Params.input_prefix_bos); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the main-gpu option."sv) + "Unable to retrieve the input-prefix-bos option."sv) } } - if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { - auto Err = Doc["n-gpu-layers"].get().get(GraphRef.NGPULayers); + if (Doc.at_key("logits-all").error() == simdjson::SUCCESS) { + auto Err = Doc["logits-all"].get().get(GraphRef.Params.logits_all); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-gpu-layers option."sv) + "Unable to retrieve the logits-all option."sv) } } - if (Doc.at_key("tensor-split").error() == simdjson::SUCCESS) { - // The TensorSplit is a comma-separated list of non-negative values. - // E.g., "3,2" presents 60% of the data to GPU 0 and 40% to GPU 1. - std::string_view TSV; - auto Err = Doc["tensor-split"].get().get(TSV); + if (Doc.at_key("use-mlock").error() == simdjson::SUCCESS) { + auto Err = Doc["use-mlock"].get().get(GraphRef.Params.use_mlock); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the tensor-split option."sv) + "Unable to retrieve the use-mlock option."sv) } - std::string TS(TSV); - std::replace(TS.begin(), TS.end(), ',', ' '); - std::stringstream SS(TS); - GraphRef.TensorSplit.clear(); - while (SS.good()) { - float TmpTensor; - SS >> TmpTensor; - GraphRef.TensorSplit.push_back(TmpTensor); + } + if (Doc.at_key("use-mmap").error() == simdjson::SUCCESS) { + auto Err = Doc["use-mmap"].get().get(GraphRef.Params.use_mmap); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the use-mmap option."sv) } - size_t NDevices = llama_max_devices(); - if (GraphRef.TensorSplit.size() > NDevices) { - RET_ERROR( - ErrNo::InvalidArgument, - "Number of Tensor-Split is larger than MaxDevices, please reduce "sv - "the size of tensor-split."sv) + } + if (Doc.at_key("verbose-prompt").error() == simdjson::SUCCESS) { + auto Err = + Doc["verbose-prompt"].get().get(GraphRef.Params.verbose_prompt); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the verbose-prompt option."sv) } - for (size_t Idx = GraphRef.TensorSplit.size(); Idx < NDevices; Idx++) { - GraphRef.TensorSplit.push_back(0.0f); + } + if (Doc.at_key("display-prompt").error() == simdjson::SUCCESS) { + auto Err = + Doc["display-prompt"].get().get(GraphRef.Params.display_prompt); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the display-prompt option."sv) } } - if (Doc.at_key("embedding").error() == simdjson::SUCCESS) { - auto Err = Doc["embedding"].get().get(GraphRef.Embedding); + if (Doc.at_key("dump-kv-cache").error() == simdjson::SUCCESS) { + auto Err = + Doc["dump-kv-cache"].get().get(GraphRef.Params.dump_kv_cache); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the embedding option."sv) + "Unable to retrieve the dump-kv-cache option."sv) } } - if (Doc.at_key("use-mmap").error() == simdjson::SUCCESS) { - auto Err = Doc["use-mmap"].get().get(GraphRef.UseMMap); + if (Doc.at_key("no-kv-offload").error() == simdjson::SUCCESS) { + auto Err = + Doc["no-kv-offload"].get().get(GraphRef.Params.no_kv_offload); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the use-mmap option."sv) + "Unable to retrieve the no-kv-offload option."sv) } } if (Doc.at_key("warmup").error() == simdjson::SUCCESS) { - auto Err = Doc["warmup"].get().get(GraphRef.WarmUp); + auto Err = Doc["warmup"].get().get(GraphRef.Params.warmup); if (Err) { RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the warmup option."sv) } } - if (Doc.at_key("split-mode").error() == simdjson::SUCCESS) { - std::string_view SplitMode; - auto Err = Doc["split-mode"].get().get(SplitMode); + if (Doc.at_key("check-tensors").error() == simdjson::SUCCESS) { + auto Err = + Doc["check-tensors"].get().get(GraphRef.Params.check_tensors); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the split-mode option."sv) + "Unable to retrieve the check-tensors option."sv) } - if (SplitMode == "none"sv) { - GraphRef.SplitMode = LLAMA_SPLIT_MODE_NONE; - } else if (SplitMode == "layer"sv) { - GraphRef.SplitMode = LLAMA_SPLIT_MODE_LAYER; - } else if (SplitMode == "row"sv) { - GraphRef.SplitMode = LLAMA_SPLIT_MODE_ROW; - } else { + } + if (Doc.at_key("cache-type-k").error() == simdjson::SUCCESS) { + int64_t CacheTypeK; + auto Err = Doc["cache-type-k"].get().get(CacheTypeK); + if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unknown split-mode: {}. Valid: none, layer, row."sv, SplitMode) + "Unable to retrieve the cache-type-k option."sv) } + GraphRef.Params.cache_type_k = static_cast(CacheTypeK); } - if (Doc.at_key("mmproj").error() == simdjson::SUCCESS) { - std::string_view MMProjModelPath; - auto Err = Doc["mmproj"].get().get(MMProjModelPath); + if (Doc.at_key("cache-type-v").error() == simdjson::SUCCESS) { + int64_t CacheTypeV; + auto Err = Doc["cache-type-v"].get().get(CacheTypeV); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the mmproj option."sv) + "Unable to retrieve the cache-type-v option."sv) } - GraphRef.MMProjModelPath = MMProjModelPath; + GraphRef.Params.cache_type_v = static_cast(CacheTypeV); } - - // The TTS parameters. - if (Doc.at_key("tts").error() == simdjson::SUCCESS) { - auto Err = Doc["tts"].get().get(GraphRef.TextToSpeech); + if (Doc.at_key("embd-normalize").error() == simdjson::SUCCESS) { + int64_t EmbdNormalize; + auto Err = Doc["embd-normalize"].get().get(EmbdNormalize); if (Err) { - RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the tts option."sv) + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embd-normalize option."sv) } + GraphRef.Params.embd_normalize = static_cast(EmbdNormalize); } - if (Doc.at_key("model-vocoder").error() == simdjson::SUCCESS) { - std::string_view VocoderModelPath; + if (Doc.at_key("embd-out").error() == simdjson::SUCCESS) { + std::string_view EmbdOut; + auto Err = Doc["embd-out"].get().get(EmbdOut); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embd-out option."sv) + } + GraphRef.Params.embd_out = EmbdOut; + } + if (Doc.at_key("embd-sep").error() == simdjson::SUCCESS) { + std::string_view EmbdSep; + auto Err = Doc["embd-sep"].get().get(EmbdSep); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embd-sep option."sv) + } + GraphRef.Params.embd_sep = EmbdSep; + } + if (Doc.at_key("reranking").error() == simdjson::SUCCESS) { + auto Err = Doc["reranking"].get().get(GraphRef.Params.reranking); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the reranking option."sv) + } + } + if (Doc.at_key("port").error() == simdjson::SUCCESS) { + int64_t Port; + auto Err = Doc["port"].get().get(Port); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the port option."sv) + } + GraphRef.Params.port = static_cast(Port); + } + if (Doc.at_key("timeout-read").error() == simdjson::SUCCESS) { + int64_t TimeoutRead; + auto Err = Doc["timeout-read"].get().get(TimeoutRead); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the timeout-read option."sv) + } + GraphRef.Params.timeout_read = static_cast(TimeoutRead); + } + if (Doc.at_key("timeout-write").error() == simdjson::SUCCESS) { + int64_t TimeoutWrite; + auto Err = Doc["timeout-write"].get().get(TimeoutWrite); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the timeout-write option."sv) + } + GraphRef.Params.timeout_write = static_cast(TimeoutWrite); + } + if (Doc.at_key("n-threads-http").error() == simdjson::SUCCESS) { + int64_t NThreadsHttp; + auto Err = Doc["n-threads-http"].get().get(NThreadsHttp); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-threads-http option."sv) + } + GraphRef.Params.n_threads_http = static_cast(NThreadsHttp); + } + if (Doc.at_key("n-cache-reuse").error() == simdjson::SUCCESS) { + int64_t NCacheReuse; + auto Err = Doc["n-cache-reuse"].get().get(NCacheReuse); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-cache-reuse option."sv) + } + GraphRef.Params.n_cache_reuse = static_cast(NCacheReuse); + } + if (Doc.at_key("hostname").error() == simdjson::SUCCESS) { + std::string_view Hostname; + auto Err = Doc["hostname"].get().get(Hostname); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hostname option."sv) + } + GraphRef.Params.hostname = Hostname; + } + if (Doc.at_key("public-path").error() == simdjson::SUCCESS) { + std::string_view PublicPath; + auto Err = Doc["public-path"].get().get(PublicPath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the public-path option."sv) + } + GraphRef.Params.public_path = PublicPath; + } + if (Doc.at_key("chat-template").error() == simdjson::SUCCESS) { + std::string_view ChatTemplate; + auto Err = Doc["chat-template"].get().get(ChatTemplate); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the chat-template option."sv) + } + GraphRef.Params.chat_template = ChatTemplate; + } + if (Doc.at_key("enable-chat-template").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-chat-template"].get().get( + GraphRef.Params.enable_chat_template); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the enable-chat-template option."sv) + } + } + if (Doc.at_key("ssl-file-key").error() == simdjson::SUCCESS) { + std::string_view SslFileKey; + auto Err = Doc["ssl-file-key"].get().get(SslFileKey); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ssl-file-key option."sv) + } + GraphRef.Params.ssl_file_key = SslFileKey; + } + if (Doc.at_key("ssl-file-cert").error() == simdjson::SUCCESS) { + std::string_view SslFileCert; + auto Err = Doc["ssl-file-cert"].get().get(SslFileCert); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ssl-file-cert option."sv) + } + GraphRef.Params.ssl_file_cert = SslFileCert; + } + if (Doc.at_key("webui").error() == simdjson::SUCCESS) { + auto Err = Doc["webui"].get().get(GraphRef.Params.webui); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the webui option."sv) + } + } + if (Doc.at_key("endpoint-slots").error() == simdjson::SUCCESS) { + int64_t EndpointSlots; + auto Err = Doc["endpoint-slots"].get().get(EndpointSlots); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the endpoint-slots option."sv) + } + GraphRef.Params.endpoint_slots = static_cast(EndpointSlots); + } + if (Doc.at_key("endpoint-props").error() == simdjson::SUCCESS) { auto Err = - Doc["model-vocoder"].get().get(VocoderModelPath); + Doc["endpoint-props"].get().get(GraphRef.Params.endpoint_props); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the model-vocoder option."sv) + "Unable to retrieve the endpoint-props option."sv) } - GraphRef.VocoderModelPath = VocoderModelPath; } - if (Doc.at_key("tts-output-file").error() == simdjson::SUCCESS) { - std::string_view TTSOutputFilePath; + if (Doc.at_key("endpoint-metrics").error() == simdjson::SUCCESS) { + auto Err = Doc["endpoint-metrics"].get().get( + GraphRef.Params.endpoint_metrics); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the endpoint-metrics option."sv) + } + } + if (Doc.at_key("log-json").error() == simdjson::SUCCESS) { + auto Err = Doc["log-json"].get().get(GraphRef.Params.log_json); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the log-json option."sv) + } + } + if (Doc.at_key("slot-save-path").error() == simdjson::SUCCESS) { + std::string_view SlotSavePath; + auto Err = Doc["slot-save-path"].get().get(SlotSavePath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the slot-save-path option."sv) + } + GraphRef.Params.slot_save_path = SlotSavePath; + } + if (Doc.at_key("slot-prompt-similarity").error() == simdjson::SUCCESS) { + double SlotPromptSimilarity; auto Err = - Doc["tts-output-file"].get().get(TTSOutputFilePath); + Doc["slot-prompt-similarity"].get().get(SlotPromptSimilarity); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the tts-output-file option."sv) + "Unable to retrieve the slot-prompt-similarity option."sv) } - GraphRef.TTSOutputFilePath = TTSOutputFilePath; + GraphRef.Params.slot_prompt_similarity = + static_cast(SlotPromptSimilarity); } - if (Doc.at_key("tts-speaker-file").error() == simdjson::SUCCESS) { - std::string_view TTSSpeakerFilePath; + if (Doc.at_key("is-pp-shared").error() == simdjson::SUCCESS) { auto Err = - Doc["tts-speaker-file"].get().get(TTSSpeakerFilePath); + Doc["is-pp-shared"].get().get(GraphRef.Params.is_pp_shared); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the tts-speaker-file option."sv) + "Unable to retrieve the is-pp-shared option."sv) } - GraphRef.TTSSpeakerFilePath = TTSSpeakerFilePath; } - - // The context parameters. - if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { - auto Err = Doc["ctx-size"].get().get(GraphRef.CtxSize); + if (Doc.at_key("n-pp").error() == simdjson::SUCCESS) { + int64_t NPP; + auto Err = Doc["n-pp"].get().get(NPP); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the n-pp option."sv) + } + } + if (Doc.at_key("n-tg").error() == simdjson::SUCCESS) { + int64_t NTG; + auto Err = Doc["n-tg"].get().get(NTG); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the n-tg option."sv) + } + } + if (Doc.at_key("n-pl").error() == simdjson::SUCCESS) { + int64_t NPL; + auto Err = Doc["n-pl"].get().get(NPL); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the n-pl option."sv) + } + } + if (Doc.at_key("context-files").error() == simdjson::SUCCESS) { + std::string_view ContextFiles; + auto Err = Doc["context-files"].get().get(ContextFiles); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the ctx-size option."sv) + "Unable to retrieve the context-files option."sv) } } - if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { - auto Err = Doc["batch-size"].get().get(GraphRef.BatchSize); + if (Doc.at_key("chunk-size").error() == simdjson::SUCCESS) { + int64_t ChunkSize; + auto Err = Doc["chunk-size"].get().get(ChunkSize); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the batch-size option."sv) + "Unable to retrieve the chunk-size option."sv) } } - if (Doc.at_key("ubatch-size").error() == simdjson::SUCCESS) { - auto Err = Doc["ubatch-size"].get().get(GraphRef.UBatchSize); + if (Doc.at_key("chunk-separator").error() == simdjson::SUCCESS) { + std::string_view ChunkSeparator; + auto Err = + Doc["chunk-separator"].get().get(ChunkSeparator); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the ubatch-size option."sv) + "Unable to retrieve the chunk-separator option."sv) } } - if (Doc.at_key("threads").error() == simdjson::SUCCESS) { - auto Err = Doc["threads"].get().get(GraphRef.Threads); + if (Doc.at_key("n-junk").error() == simdjson::SUCCESS) { + int64_t NJunk; + auto Err = Doc["n-junk"].get().get(NJunk); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the threads option."sv) + "Unable to retrieve the n-junk option."sv) } } - - // The sampling parameters. - if (Doc.at_key("temp").error() == simdjson::SUCCESS) { - auto Err = Doc["temp"].get().get(GraphRef.Temp); + if (Doc.at_key("i-pos").error() == simdjson::SUCCESS) { + int64_t IPos; + auto Err = Doc["i-pos"].get().get(IPos); if (Err) { - RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the temp option."sv) + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the i-pos option."sv) } - GraphRef.Temp = std::max(0.0, GraphRef.Temp); } - if (Doc.at_key("top-p").error() == simdjson::SUCCESS) { - auto Err = Doc["top-p"].get().get(GraphRef.TopP); + if (Doc.at_key("out-file").error() == simdjson::SUCCESS) { + std::string_view OutFile; + auto Err = Doc["out-file"].get().get(OutFile); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the top-p option."sv) + "Unable to retrieve the out-file option."sv) } } - if (Doc.at_key("repeat-penalty").error() == simdjson::SUCCESS) { - auto Err = Doc["repeat-penalty"].get().get(GraphRef.RepeatPenalty); + if (Doc.at_key("n-out-freq").error() == simdjson::SUCCESS) { + int64_t NOutFreq; + auto Err = Doc["n-out-freq"].get().get(NOutFreq); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the repeat-penalty option."sv) + "Unable to retrieve the n-out-freq option."sv) } } - if (Doc.at_key("presence-penalty").error() == simdjson::SUCCESS) { - auto Err = - Doc["presence-penalty"].get().get(GraphRef.PresencePenalty); + if (Doc.at_key("n-save-freq").error() == simdjson::SUCCESS) { + int64_t NSaveFreq; + auto Err = Doc["n-save-freq"].get().get(NSaveFreq); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the presence-penalty option."sv) + "Unable to retrieve the n-save-freq option."sv) } } - if (Doc.at_key("frequency-penalty").error() == simdjson::SUCCESS) { + if (Doc.at_key("i-chunk").error() == simdjson::SUCCESS) { + int64_t IChunk; + auto Err = Doc["i-chunk"].get().get(IChunk); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the i-chunk option."sv) + } + } + if (Doc.at_key("process-output").error() == simdjson::SUCCESS) { auto Err = - Doc["frequency-penalty"].get().get(GraphRef.FrequencyPenalty); + Doc["process-output"].get().get(GraphRef.Params.process_output); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the frequency-penalty option."sv) + "Unable to retrieve the process-output option."sv) } } - if (Doc.at_key("grammar").error() == simdjson::SUCCESS) { - std::string_view Grammar; - auto Err = Doc["grammar"].get().get(Grammar); + if (Doc.at_key("compute-ppl").error() == simdjson::SUCCESS) { + auto Err = Doc["compute-ppl"].get().get(GraphRef.Params.compute_ppl); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the grammar option."sv) + "Unable to retrieve the compute-ppl option."sv) } - GraphRef.Grammar = Grammar; } - if (Doc.at_key("json-schema").error() == simdjson::SUCCESS) { - std::string_view JsonSchema; - auto Err = Doc["json-schema"].get().get(JsonSchema); + if (Doc.at_key("n-pca-batch").error() == simdjson::SUCCESS) { + int64_t NPCABatch; + auto Err = Doc["n-pca-batch"].get().get(NPCABatch); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the json-schema option."sv) + "Unable to retrieve the n-pca-batch option."sv) } - GraphRef.Grammar = - json_schema_to_grammar(nlohmann::ordered_json::parse(JsonSchema)); } - if (Doc.at_key("seed").error() == simdjson::SUCCESS) { - auto Err = Doc["seed"].get().get(GraphRef.Seed); + if (Doc.at_key("n-pca-iterations").error() == simdjson::SUCCESS) { + int64_t NPCAIterations; + auto Err = Doc["n-pca-iterations"].get().get(NPCAIterations); if (Err) { - RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the seed option."sv) + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-pca-iterations option."sv) } } - - // The config parameters. - if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { - auto Err = Doc["stream-stdout"].get().get(ConfRef.StreamStdout); + if (Doc.at_key("cvector-dimre-method").error() == simdjson::SUCCESS) { + std::string_view CVectorDimreMethod; + auto Err = Doc["cvector-dimre-method"].get().get( + CVectorDimreMethod); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the stream-stdout option."sv) + "Unable to retrieve the cvector-dimre-method option."sv) } } - if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { - auto Err = Doc["n-predict"].get().get(ConfRef.NPredict); + if (Doc.at_key("cvector-outfile").error() == simdjson::SUCCESS) { + std::string_view CVectorOutfile; + auto Err = + Doc["cvector-outfile"].get().get(CVectorOutfile); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-predict option."sv) + "Unable to retrieve the cvector-outfile option."sv) } } - if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { - std::string_view ReversePrompt; - auto Err = Doc["reverse-prompt"].get().get(ReversePrompt); + if (Doc.at_key("cvector-positive-file").error() == simdjson::SUCCESS) { + std::string_view CVectorPositiveFile; + auto Err = Doc["cvector-positive-file"].get().get( + CVectorPositiveFile); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the reverse-prompt option."sv) + "Unable to retrieve the cvector-positive-file option."sv) } - ConfRef.ReversePrompt = ReversePrompt; } - if (Doc.at_key("image").error() == simdjson::SUCCESS) { - std::string_view ImagePath; - auto Err = Doc["image"].get().get(ImagePath); + if (Doc.at_key("cvector-negative-file").error() == simdjson::SUCCESS) { + std::string_view CVectorNegativeFile; + auto Err = Doc["cvector-negative-file"].get().get( + CVectorNegativeFile); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the image option."sv) + "Unable to retrieve the cvector-negative-file option."sv) } - ConfRef.ImagePath = ImagePath; } - if (Doc.at_key("always-regenerate-image-embd").error() == simdjson::SUCCESS) { - auto Err = Doc["always-regenerate-image-embd"].get().get( - ConfRef.AlwaysRegenerateImageEmbd); + if (Doc.at_key("spm-infill").error() == simdjson::SUCCESS) { + auto Err = Doc["spm-infill"].get().get(GraphRef.Params.spm_infill); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the always-regenerate-image-embd option."sv) + "Unable to retrieve the spm-infill option."sv) + } + } + if (Doc.at_key("lora-outfile").error() == simdjson::SUCCESS) { + std::string_view LoraOutfile; + auto Err = Doc["lora-outfile"].get().get(LoraOutfile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the lora-outfile option."sv) + } + GraphRef.Params.lora_outfile = LoraOutfile; + } + if (Doc.at_key("batched-bench-output-jsonl").error() == simdjson::SUCCESS) { + auto Err = Doc["batched-bench-output-jsonl"].get().get( + GraphRef.Params.batched_bench_output_jsonl); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the batched-bench-output-jsonl option."sv) } } + if (GraphRef.TextToSpeech) { + GraphRef.Params.sampling.top_k = 4; + GraphRef.Params.sampling.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + }; + } // Check if the model parameters are updated. - if (IsModelUpdated && PrevNGPULayers != GraphRef.NGPULayers) { + if (IsModelUpdated && PrevNGPULayers != GraphRef.Params.n_gpu_layers) { *IsModelUpdated = true; } // Check if the context parameters are updated. - if (IsContextUpdated && PrevEmbedding != GraphRef.Embedding) { + if (IsContextUpdated && PrevEmbedding != GraphRef.Params.embedding) { *IsContextUpdated = true; } // Check if the sampler parameters are updated. if (IsSamplerUpdated && - (PrevTemp != GraphRef.Temp || PrevTopP != GraphRef.TopP || - PrevRepeatPenalty != GraphRef.RepeatPenalty || - PrevPresencePenalty != GraphRef.PresencePenalty || - PrevFrequencyPenalty != GraphRef.FrequencyPenalty || - PrevGrammar != GraphRef.Grammar || PrevSeed != GraphRef.Seed)) { + (PrevTemp != GraphRef.Params.sampling.temp || + PrevTopP != GraphRef.Params.sampling.top_p || + PrevRepeatPenalty != GraphRef.Params.sampling.penalty_repeat || + PrevPresencePenalty != GraphRef.Params.sampling.penalty_present || + PrevFrequencyPenalty != GraphRef.Params.sampling.penalty_freq || + PrevGrammar != GraphRef.Params.sampling.grammar || + PrevSeed != GraphRef.Params.sampling.seed)) { *IsSamplerUpdated = true; } @@ -824,7 +2128,7 @@ struct llama_batch allocBatch(int64_t NTokens, int64_t Embd = 0, // Fill tokens (smaller than batch size) into a batch with position data. void fillBatch(Span Tokens, Graph &GraphRef, llama_batch &Batch, int &NPos, bool IsLogit = false) { - assuming(GraphRef.BatchSize >= static_cast(Tokens.size())); + assuming(GraphRef.Params.n_batch >= static_cast(Tokens.size())); assuming(Batch.token != nullptr); assuming(Batch.pos != nullptr); assuming(Batch.logits != nullptr); @@ -925,10 +2229,10 @@ ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, // Loop for decode batch. Split tokens into batch size length. for (int I = 0; I < static_cast(Tokens.size()); - I += static_cast(GraphRef.BatchSize)) { + I += static_cast(GraphRef.Params.n_batch)) { int NEval = static_cast(Tokens.size()) - I; - if (NEval > static_cast(GraphRef.BatchSize)) { - NEval = static_cast(GraphRef.BatchSize); + if (NEval > static_cast(GraphRef.Params.n_batch)) { + NEval = static_cast(GraphRef.Params.n_batch); } // LlamaPos for Qwen2VL. @@ -1029,7 +2333,7 @@ ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, LogPrefix) EvalImageStatus = llava_eval_image_embed( GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, - static_cast(GraphRef.BatchSize), &CxtRef.NPos); + static_cast(GraphRef.Params.n_batch), &CxtRef.NPos); LOG_DEBUG(GraphRef.EnableDebugLog, "{}: Eval llava image embd...done"sv, LogPrefix) break; @@ -1039,7 +2343,7 @@ ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, auto *ImageSize = clip_get_load_image_size(GraphRef.ClipContext); EvalImageStatus = evaluateQwen2vlImageEmbed( GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, - static_cast(GraphRef.BatchSize), CxtRef.NPos, ImageSize); + static_cast(GraphRef.Params.n_batch), CxtRef.NPos, ImageSize); LOG_DEBUG(GraphRef.EnableDebugLog, "{}: Eval Qwen2VL image embd...done"sv, LogPrefix) break; @@ -1131,12 +2435,13 @@ Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { } // Check if the input is too long. - if (static_cast(CxtRef.LlamaInputs.size()) > GraphRef.BatchSize) { + if (static_cast(CxtRef.LlamaInputs.size()) > + GraphRef.Params.n_batch) { RET_ERROR( ErrNo::PromptTooLong, "getEmbedding: the prompt is too long. Your input has {} tokens exceeds batch "sv "size {}. Please reduce the input size or increase your batch-size."sv, - CxtRef.LlamaInputs.size(), GraphRef.BatchSize) + CxtRef.LlamaInputs.size(), GraphRef.Params.n_batch) } // Evaluate the input tokens. @@ -1417,7 +2722,8 @@ ErrNo codesToSpeech(Graph &GraphRef, Context &CxtRef) noexcept { // Embeddings to audio. std::vector AudioData = - embdToAudio(Embd, NCodes, NEmbd, static_cast(GraphRef.Threads)); + embdToAudio(Embd, NCodes, NEmbd, + static_cast(GraphRef.Params.cpuparams.n_threads)); // Zero out first 0.25 seconds of audio. const uint32_t SamplingRate = 24000; @@ -1444,26 +2750,53 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize the plugin parameters. GraphRef.EnableLog = false; GraphRef.EnableDebugLog = false; + const common_params CommonParamsDefault; + GraphRef.Params = CommonParamsDefault; + GraphRef.Params.n_keep = 0; + GraphRef.Params.n_chunks = -1; + GraphRef.Params.n_parallel = 1; + GraphRef.Params.grp_attn_n = 1; + GraphRef.Params.grp_attn_w = 512; + GraphRef.Params.n_print = -1; + GraphRef.Params.split_mode = llama_split_mode::LLAMA_SPLIT_MODE_LAYER; // Initialize the model parameters. llama_model_params ModelParamsDefault = llama_model_default_params(); - GraphRef.NGPULayers = ModelParamsDefault.n_gpu_layers; - GraphRef.MMProjModelPath = ""sv; + GraphRef.Params.n_gpu_layers = ModelParamsDefault.n_gpu_layers; + GraphRef.Params.mmproj = ""sv; // Initialize the context parameters. llama_context_params ContextParamsDefault = llama_context_default_params(); - GraphRef.CtxSize = ContextParamsDefault.n_ctx; - GraphRef.BatchSize = ContextParamsDefault.n_batch; - GraphRef.UBatchSize = ContextParamsDefault.n_ubatch; - GraphRef.Threads = ContextParamsDefault.n_threads; + GraphRef.Params.cpuparams.n_threads = ContextParamsDefault.n_threads; + GraphRef.Params.n_ctx = ContextParamsDefault.n_ctx; + GraphRef.Params.n_batch = ContextParamsDefault.n_batch; + GraphRef.Params.n_ubatch = ContextParamsDefault.n_ubatch; + GraphRef.Params.cpuparams.n_threads = ContextParamsDefault.n_threads_batch; + GraphRef.Params.cpuparams_batch.n_threads = + ContextParamsDefault.n_threads_batch; + GraphRef.Params.rope_scaling_type = ContextParamsDefault.rope_scaling_type; + GraphRef.Params.pooling_type = ContextParamsDefault.pooling_type; + GraphRef.Params.attention_type = ContextParamsDefault.attention_type; + GraphRef.Params.rope_freq_base = ContextParamsDefault.rope_freq_base; + GraphRef.Params.rope_freq_scale = ContextParamsDefault.rope_freq_scale; + GraphRef.Params.yarn_ext_factor = ContextParamsDefault.yarn_ext_factor; + GraphRef.Params.yarn_attn_factor = ContextParamsDefault.yarn_attn_factor; + GraphRef.Params.yarn_beta_fast = ContextParamsDefault.yarn_beta_fast; + GraphRef.Params.yarn_beta_slow = ContextParamsDefault.yarn_beta_slow; + GraphRef.Params.yarn_orig_ctx = ContextParamsDefault.yarn_orig_ctx; + GraphRef.Params.defrag_thold = ContextParamsDefault.defrag_thold; + GraphRef.Params.cb_eval = ContextParamsDefault.cb_eval; + GraphRef.Params.cb_eval_user_data = ContextParamsDefault.cb_eval_user_data; + GraphRef.Params.cache_type_k = ContextParamsDefault.type_k; + GraphRef.Params.cache_type_v = ContextParamsDefault.type_v; + GraphRef.Params.logits_all = ContextParamsDefault.logits_all; + GraphRef.Params.embedding = ContextParamsDefault.embeddings; + GraphRef.Params.no_kv_offload = !ContextParamsDefault.offload_kqv; + GraphRef.Params.flash_attn = ContextParamsDefault.flash_attn; + GraphRef.Params.no_perf = ContextParamsDefault.no_perf; + // Initialize the sampling parameters. const common_params_sampling SamplerParamsDefault; - GraphRef.Temp = SamplerParamsDefault.temp; - GraphRef.TopP = SamplerParamsDefault.top_p; - GraphRef.RepeatPenalty = SamplerParamsDefault.penalty_repeat; - GraphRef.PresencePenalty = SamplerParamsDefault.penalty_present; - GraphRef.FrequencyPenalty = SamplerParamsDefault.penalty_freq; - GraphRef.Grammar = SamplerParamsDefault.grammar; + GraphRef.Params.sampling = SamplerParamsDefault; // Initialize the config parameters. - const common_params CommonParamsDefault; GraphRef.Conf.StreamStdout = false; GraphRef.Conf.EmbdNormalize = static_cast(CommonParamsDefault.embd_normalize); @@ -1497,15 +2830,15 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, const std::string_view BinModel(reinterpret_cast(Weight.data()), Weight.size()); if (BinModel.substr(0, 8) == "preload:"sv) { - GraphRef.ModelFilePath = BinModel.substr(8); + GraphRef.Params.model = BinModel.substr(8); } else { LOG_DEBUG(GraphRef.EnableDebugLog, "load: Model path not found in nn-preload, write model into "sv "a tmpfile."sv) // TODO: pass the model directly to ggml. // Write ggml model to file. - GraphRef.ModelFilePath = "ggml-model.bin"sv; - std::ofstream TempFile(GraphRef.ModelFilePath, + GraphRef.Params.model = "ggml-model.bin"sv; + std::ofstream TempFile(GraphRef.Params.model, std::ios::out | std::ios::binary); if (!TempFile) { Env.deleteGraph(GId); @@ -1524,16 +2857,21 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Check if the model exists. if (!std::filesystem::exists( - std::filesystem::u8path(GraphRef.ModelFilePath))) { + std::filesystem::u8path(GraphRef.Params.model))) { Env.deleteGraph(GId); RET_ERROR(ErrNo::ModelNotFound, "load: model file not found."sv) } + GraphRef.Params.model = GraphRef.Params.model; // Initialize ggml parameters. LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize ggml model with given parameters."sv) - common_params Params; - setupCommonParams(GraphRef, Params); + + common_params Params = GraphRef.Params; + Params.cpuparams.n_threads = + static_cast(GraphRef.Params.cpuparams.n_threads); + Params.cpuparams_batch.n_threads = + static_cast(GraphRef.Params.cpuparams.n_threads); llama_backend_init(); llama_numa_init(Params.numa); @@ -1555,7 +2893,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize the TTS related model and context. if (GraphRef.TextToSpeech) { LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize TTS model."sv) - Params.model = GraphRef.VocoderModelPath; + Params.model = GraphRef.Params.vocoder.model; Params.embedding = true; common_init_result TTSInit = common_init_from_params(Params); GraphRef.TTSModel = std::move(TTSInit.model); @@ -1589,17 +2927,15 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, auto &CxtRef = Env.NNContext[ContextId].get(); // Allocate the batch for input string prompt tokens. - CxtRef.LlamaBatch = allocBatch(GraphRef.BatchSize); - CxtRef.CurrentBatchSize = GraphRef.BatchSize; + CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); + CxtRef.CurrentBatchSize = GraphRef.Params.n_batch; // Allocate the batch for output sampling. The batch size is always 1. CxtRef.OutputBatch = allocBatch(1); // Allocate sampler. - common_params_sampling CommonSampling; - setupSamplerParams(GraphRef, CommonSampling); CxtRef.LlamaSampler = - common_sampler_init(GraphRef.LlamaModel.get(), CommonSampling); + common_sampler_init(GraphRef.LlamaModel.get(), GraphRef.Params.sampling); Env.NNContext[ContextId].setReady(); LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx...Done"sv) @@ -1640,7 +2976,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, LOG_INFO(GraphRef.EnableLog, "setInput: Reload model due to parameters change."sv) llama_model_params ModelParams = llama_model_default_params(); - ModelParams.n_gpu_layers = static_cast(GraphRef.NGPULayers); + ModelParams.n_gpu_layers = + static_cast(GraphRef.Params.n_gpu_layers); GraphRef.LlamaModel.reset(); // Due to the model change, the context and sampler should also be // reloaded. The new context and sampler will be created in the next @@ -1652,7 +2989,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, CxtRef.LlamaSampler = nullptr; } GraphRef.LlamaModel = llama_model_ptr(llama_model_load_from_file( - GraphRef.ModelFilePath.c_str(), ModelParams)); + GraphRef.Params.model.c_str(), ModelParams)); if (GraphRef.LlamaModel == nullptr) { Env.NNGraph[CxtRef.GraphId].setInvalid(); RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init model."sv) @@ -1667,10 +3004,9 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, LOG_INFO(GraphRef.EnableLog, "setInput: Reload llama context due to parameters change."sv) GraphRef.LlamaContext.reset(); - common_params Params; - setupCommonParams(GraphRef, Params); GraphRef.LlamaContext = llama_context_ptr(llama_init_from_model( - GraphRef.LlamaModel.get(), common_context_params_to_llama(Params))); + GraphRef.LlamaModel.get(), + common_context_params_to_llama(GraphRef.Params))); if (GraphRef.LlamaContext == nullptr) { Env.NNGraph[CxtRef.GraphId].setInvalid(); RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init context."sv) @@ -1685,10 +3021,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (CxtRef.LlamaSampler) { common_sampler_free(CxtRef.LlamaSampler); } - common_params_sampling CommonSampling; - setupSamplerParams(GraphRef, CommonSampling); - CxtRef.LlamaSampler = - common_sampler_init(GraphRef.LlamaModel.get(), CommonSampling); + CxtRef.LlamaSampler = common_sampler_init(GraphRef.LlamaModel.get(), + GraphRef.Params.sampling); if (GraphRef.LlamaContext == nullptr) { Env.NNGraph[CxtRef.GraphId].setInvalid(); RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init sampler."sv) @@ -1696,10 +3030,10 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } // Check that is batch size changed. - if (CxtRef.CurrentBatchSize != GraphRef.BatchSize) { + if (CxtRef.CurrentBatchSize != GraphRef.Params.n_batch) { llama_batch_free(CxtRef.LlamaBatch); - CxtRef.LlamaBatch = allocBatch(GraphRef.BatchSize); - CxtRef.CurrentBatchSize = GraphRef.BatchSize; + CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); + CxtRef.CurrentBatchSize = GraphRef.Params.n_batch; } Env.NNGraph[CxtRef.GraphId].setReady(); @@ -1733,8 +3067,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (Base64ImagePos.has_value() || CxtRef.Conf.ImagePath != ""sv) { // Prompt with image input. Check is llava or mllama case. - // First check the projection model is given. - if (GraphRef.MMProjModelPath == ""sv) { + // First check the projection model is loaded. + if (GraphRef.Params.mmproj == ""sv) { RET_ERROR( ErrNo::InvalidArgument, "setInput: the given model does not support image input, so a projection model is required."sv) @@ -1748,7 +3082,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, "for CLIP, the step of loading images in CLIP can only use the "sv "CPU, which may result in reduced efficiency. (You can refer to "sv "PR https://github.com/ggerganov/llama.cpp/pull/10896)"sv) - GraphRef.ClipContext = clip_model_load(GraphRef.MMProjModelPath.c_str(), + GraphRef.ClipContext = clip_model_load(GraphRef.Params.mmproj.c_str(), GraphRef.EnableLog ? 1 : 0); if (GraphRef.ClipContext == nullptr) { RET_ERROR(ErrNo::InvalidArgument, @@ -1770,13 +3104,13 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, "setInput: handle llava format prompt."sv) // Show some warnings. - if (GraphRef.CtxSize < 4096) { + if (GraphRef.Params.n_ctx < 4096) { LOG_INFO( GraphRef.EnableLog, "setInput: Context size is {}, we recommend context size >= 2048 when "sv "using llava-v1.5 and context size >= 4096 when using llava-v1.6 "sv "for better results."sv, - GraphRef.CtxSize) + GraphRef.Params.n_ctx) } // Get image embed. @@ -1802,7 +3136,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Create a new image embedding CxtRef.LlavaImageEmbd = llava_image_embed_make_with_bytes( - GraphRef.ClipContext, static_cast(GraphRef.Threads), + GraphRef.ClipContext, + static_cast(GraphRef.Params.cpuparams.n_threads), Payload->first.data(), static_cast(Payload->first.size())); } else { LOG_DEBUG( @@ -1829,7 +3164,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, CxtRef.Conf.ImagePath) // Load the image from the file. CxtRef.LlavaImageEmbd = llava_image_embed_make_with_filename( - GraphRef.ClipContext, static_cast(GraphRef.Threads), + GraphRef.ClipContext, + static_cast(GraphRef.Params.cpuparams.n_threads), CxtRef.Conf.ImagePath.c_str()); LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: Compute image embd from file: {}...Done"sv, @@ -1929,7 +3265,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); LOG_DEBUG(GraphRef.EnableDebugLog, "compute") - if (GraphRef.Embedding) { + if (GraphRef.Params.embedding) { return getEmbedding(GraphRef, CxtRef); } diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 6ef58053..51ac6624 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -60,42 +60,19 @@ struct Graph { // Plugin parameters: bool EnableLog = false; bool EnableDebugLog = false; - // Model parameters: - int64_t MainGPU = 0; // Use GPU 0 by default - int64_t NGPULayers = 0; - std::vector TensorSplit; - bool Embedding = false; - bool UseMMap = true; - bool WarmUp = false; - enum llama_split_mode SplitMode = LLAMA_SPLIT_MODE_LAYER; + common_params Params; // Model context: llama_model_ptr LlamaModel = nullptr; llama_context_ptr LlamaContext = nullptr; - std::string ModelFilePath; // Clip context (for llava): - std::string MMProjModelPath; struct clip_ctx *ClipContext = nullptr; VisionModel VisionModelType = VisionModel::Llava; // Text-to-speech: bool TextToSpeech = false; - std::string VocoderModelPath; std::string TTSOutputFilePath = "output.wav"; std::string TTSSpeakerFilePath; llama_model_ptr TTSModel = nullptr; llama_context_ptr TTSContext = nullptr; - // Context parameters: - int64_t CtxSize; - int64_t BatchSize; - int64_t UBatchSize; - int64_t Threads; - // Sampling parameters: - double Temp = 0.80; - double TopP = 0.95; - double RepeatPenalty = 1.10; - double PresencePenalty = 0.00; - double FrequencyPenalty = 0.00; - std::string Grammar; - uint64_t Seed = LLAMA_DEFAULT_SEED; // Configs. LocalConfig Conf; }; From 9d25b67884251acc8e4281c158357c86af1f75ec Mon Sep 17 00:00:00 2001 From: dm4 Date: Mon, 17 Feb 2025 18:11:59 +0800 Subject: [PATCH 537/623] [WASI-NN] ggml: tts writes wav to output buffer Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 55 +++++++++++++++++++++------------ plugins/wasi_nn/wasinn_ggml.h | 2 +- 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index b4abbd14..c5c61f43 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -2392,7 +2392,10 @@ ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, // Save the output token. CxtRef.LlamaOutputTokens.emplace_back(Id); - CxtRef.LlamaOutputs += common_token_to_piece(GraphRef.LlamaContext.get(), Id); + std::string OutputString = + common_token_to_piece(GraphRef.LlamaContext.get(), Id); + CxtRef.LlamaOutputs.insert(CxtRef.LlamaOutputs.end(), OutputString.begin(), + OutputString.end()); // In single token mode, we do not handle StreamStdout and ReversePrompt. if (!IsSingleTokenMode) { // When setting StreamStdout, we print the output to stdout. @@ -2403,8 +2406,8 @@ ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, } // Break if reverse prompt is found. if (!CxtRef.Conf.ReversePrompt.empty() && - CxtRef.LlamaOutputs.find(CxtRef.Conf.ReversePrompt) != - std::string::npos) { + std::string(CxtRef.LlamaOutputs.begin(), CxtRef.LlamaOutputs.end()) + .find(CxtRef.Conf.ReversePrompt) != std::string::npos) { LOG_INFO(GraphRef.EnableLog, "sampleOutput: reverse prompt found."sv) return ErrNo::EndOfSequence; } @@ -2475,7 +2478,10 @@ Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { static_cast(CxtRef.Conf.EmbdNormalize)); } - buildOutputEmbedding(CxtRef.LlamaOutputs, NEmbd, Embeddings.data()); + std::string EmbeddingString; + buildOutputEmbedding(EmbeddingString, NEmbd, Embeddings.data()); + CxtRef.LlamaOutputs = + std::vector(EmbeddingString.begin(), EmbeddingString.end()); if (GraphRef.EnableLog) { common_perf_print(GraphRef.LlamaContext.get(), /* Sampler */ nullptr); @@ -2659,15 +2665,9 @@ struct WavHeader { uint32_t DataSize; }; -void audioDataToWav(const std::string &Filename, const std::vector &Data, - int SampleRate) { - std::ofstream File(Filename, std::ios::binary); - if (!File) { - LOG_ERROR("audioDataToWav: Failed to open file '{}' for writing"sv, - Filename); - return; - } - +std::vector audioDataToWav(const std::vector &Data, + int SampleRate) { + std::vector WavData; WavHeader Header; Header.SampleRate = SampleRate; Header.ByteRate = @@ -2677,15 +2677,17 @@ void audioDataToWav(const std::string &Filename, const std::vector &Data, static_cast(Data.size() * (Header.BitsPerSample / 8)); Header.ChunkSize = 36 + Header.DataSize; - File.write(reinterpret_cast(&Header), sizeof(Header)); + WavData.insert(WavData.end(), reinterpret_cast(&Header), + reinterpret_cast(&Header) + sizeof(Header)); for (const auto &Sample : Data) { int16_t PCMSample = static_cast(std::clamp(Sample * 32767.0, -32768.0, 32767.0)); - File.write(reinterpret_cast(&PCMSample), sizeof(PCMSample)); + WavData.insert(WavData.end(), reinterpret_cast(&PCMSample), + reinterpret_cast(&PCMSample) + sizeof(PCMSample)); } - File.close(); + return WavData; } // TextToSpeech function, will generate voice data from codes. @@ -2731,8 +2733,21 @@ ErrNo codesToSpeech(Graph &GraphRef, Context &CxtRef) noexcept { AudioData[I] = 0.0f; } - // Save .wav file - audioDataToWav(GraphRef.TTSOutputFilePath, AudioData, SamplingRate); + // Convert audio data to wav and put it into output buffer. + CxtRef.LlamaOutputs = audioDataToWav(AudioData, SamplingRate); + + // Save .wav file if path is provided. + if (!GraphRef.TTSOutputFilePath.empty()) { + std::ofstream File(GraphRef.TTSOutputFilePath, std::ios::binary); + if (!File) { + RET_ERROR(ErrNo::RuntimeError, + "codesToSpeech: Failed to open file '{}' for writing"sv, + GraphRef.TTSOutputFilePath); + } + File.write(reinterpret_cast(CxtRef.LlamaOutputs.data()), + CxtRef.LlamaOutputs.size()); + File.close(); + } return ErrNo::Success; } @@ -3253,9 +3268,9 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, return ErrNo::Success; } - std::copy_n(CxtRef.LlamaOutputs.data(), CxtRef.LlamaOutputs.length(), + std::copy_n(CxtRef.LlamaOutputs.data(), CxtRef.LlamaOutputs.size(), OutBuffer.data()); - BytesWritten = static_cast(CxtRef.LlamaOutputs.length()); + BytesWritten = static_cast(CxtRef.LlamaOutputs.size()); LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}...Done"sv, Index) return ErrNo::Success; } diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 51ac6624..b9122037 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -85,7 +85,7 @@ struct Context { std::vector LlamaInputs; uint64_t LlamaNInputs = 0; // Llama outputs: - std::string LlamaOutputs; + std::vector LlamaOutputs; std::vector LlamaOutputTokens; // Preserve for llava struct llava_image_embed *LlavaImageEmbd = nullptr; From a504b3b0172eb339bad2b26fe610179d46bcbe60 Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 21 Feb 2025 23:53:56 +0800 Subject: [PATCH 538/623] [WASI-NN] ggml: fix the warmup regression Signed-off-by: hydai --- plugins/wasi_nn/wasinn_ggml.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index c5c61f43..662b26aa 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -2778,6 +2778,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, llama_model_params ModelParamsDefault = llama_model_default_params(); GraphRef.Params.n_gpu_layers = ModelParamsDefault.n_gpu_layers; GraphRef.Params.mmproj = ""sv; + GraphRef.Params.warmup = false; // Initialize the context parameters. llama_context_params ContextParamsDefault = llama_context_default_params(); GraphRef.Params.cpuparams.n_threads = ContextParamsDefault.n_threads; From b2b53e73cfa8ada666987b188099b345701b6b00 Mon Sep 17 00:00:00 2001 From: hydai Date: Sat, 22 Feb 2025 01:05:51 +0800 Subject: [PATCH 539/623] [WASI-NN] whisper: fix the modified test file path Signed-off-by: hydai --- test/plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 432c008b..9fdb3868 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -106,7 +106,7 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) MD5=4279db3d7b18d9f6e4d5817a16af4f09 ) download( - https://github.com/second-state/WasmEdge-WASINN-examples/raw/master/whisper-basic/test.wav + https://github.com/second-state/WasmEdge-WASINN-examples/raw/master/wasmedge-ggml/whisper/test.wav ${CMAKE_CURRENT_BINARY_DIR}/wasinn_whisper_fixtures/test.wav MD5=6cf3f7af1ebbd6b29c373e526b548dba ) From 5a1da39f11fea933240b1577a1e4c7db1439f863 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Tue, 25 Feb 2025 21:07:22 +0800 Subject: [PATCH 540/623] [CMake] Remove post-build copy for stable-diffusion Signed-off-by: Yi Huang --- .../wasmedge_stablediffusion/CMakeLists.txt | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/plugins/wasmedge_stablediffusion/CMakeLists.txt b/plugins/wasmedge_stablediffusion/CMakeLists.txt index c716e06a..1e3038b9 100644 --- a/plugins/wasmedge_stablediffusion/CMakeLists.txt +++ b/plugins/wasmedge_stablediffusion/CMakeLists.txt @@ -99,25 +99,16 @@ else() target_compile_options( stable-diffusion PRIVATE - -Wno-unused-function - -Wno-unused-variable - -Wno-unused-parameter - -Wno-missing-field-initializers + -Wno-unused-function + -Wno-unused-variable + -Wno-unused-parameter + -Wno-missing-field-initializers -Wno-deprecated-declarations -Wno-braced-scalar-init -Wno-unused-value -Wno-uninitialized -Wno-format -Wno-enum-compare - ) -endif() - -if(WASMEDGE_PLUGIN_STABLEDIFFUSION_METAL) - add_custom_command( - TARGET wasmedgePluginWasmEdgeStableDiffusion - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${stable-diffusion_SOURCE_DIR}/ggml/src/ggml-metal/ggml-metal.metal ggml-metal.metal - COMMAND ${CMAKE_COMMAND} -E copy ${stable-diffusion_SOURCE_DIR}/ggml/src/ggml-common.h ggml-common.h ) endif() @@ -125,4 +116,4 @@ install( TARGETS wasmedgePluginWasmEdgeStableDiffusion DESTINATION ${CMAKE_INSTALL_LIBDIR}/wasmedge COMPONENT WasmEdge -) \ No newline at end of file +) From 0f627a564e08257ed1ea289efa07af64ecd2097a Mon Sep 17 00:00:00 2001 From: LFsWang <7088579+LFsWang@users.noreply.github.com> Date: Thu, 6 Mar 2025 20:38:55 +0800 Subject: [PATCH 541/623] [WASI-NN] add dependency installer for openvino-genai (#4032) Signed-off-by: LFsWang --- utils/docker/Dockerfile.ubuntu-plugins-deps | 10 +++++- utils/wasi-nn/install-openvino-genai.sh | 36 +++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 utils/wasi-nn/install-openvino-genai.sh diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index 1b177c54..c2f48bf8 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -60,10 +60,17 @@ RUN [ "/bin/bash", "install-pytorch.sh" ] ARG UBUNTU_VER COPY wasi-nn/install-openvino.sh . +COPY wasi-nn/install-openvino-genai.sh . ENV OPENVINO_UBUNTU_VERSION=${UBUNTU_VER} ENV OPENVINO_VERSION="2025.0.0" ENV OPENVINO_YEAR="2025" +ENV OPENVINOGEN_VERSION="2025.0.0.0" +ENV OPENVINOGEN_YEAR="2025.0" +ENV OPENVINOGEN_DIRNAME="openvino_genai" RUN [ "/bin/bash", "install-openvino.sh" ] +RUN [ "/bin/bash", "install-openvino-genai.sh" ] +RUN [ "ls", "-al" ] +RUN [ "/bin/bash", "-c", "echo \"source ./openvino_genai/setupvars.sh\" >> .bashrc" ] COPY wasi-nn/install-onnxruntime.sh . RUN [ "/bin/bash", "install-onnxruntime.sh" ] @@ -76,6 +83,7 @@ RUN rm -f \ install-ffmpeg-v6.0.sh \ install-pytorch.sh \ install-openvino.sh \ - install-onnxruntime.sh + install-onnxruntime.sh \ + install-openvino-genai.sh RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/wasi-nn/install-openvino-genai.sh b/utils/wasi-nn/install-openvino-genai.sh new file mode 100644 index 00000000..a877adac --- /dev/null +++ b/utils/wasi-nn/install-openvino-genai.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# SPDX-FileCopyrightText: 2019-2024 Second State INC +set -e + +if [[ ! -v "${OPENVINOGEN_VERSION}" ]]; then + OPENVINOGEN_VERSION="2025.0.0.0" +fi +if [[ ! -v "${OPENVINOGEN_YEAR}" ]]; then + OPENVINOGEN_YEAR="2025.0" +fi +if [[ ! -v "${OPENVINOGEN_DIRNAME}" ]]; then + OPENVINOGEN_DIRNAME="openvino_genai" +fi + +if [[ ! -v "${OPENVINO_UBUNTU_VERSION}" ]]; then + source /etc/os-release + OPENVINO_UBUNTU_VERSION="${VERSION_ID::2}" +fi + +UBUNTU_VERSION="ubuntu${OPENVINO_UBUNTU_VERSION:-20}" +OPENVINOGEN_TGZ_NAME="openvino_genai_${UBUNTU_VERSION}_${OPENVINOGEN_VERSION}_x86_64" + + +echo "Installing OpenVINO GenAI with version ${OPENVINOGEN_VERSION}" +curl -L https://storage.openvinotoolkit.org/repositories/openvino_genai/packages/${OPENVINOGEN_YEAR}/linux/${OPENVINOGEN_TGZ_NAME}.tar.gz --output ${OPENVINOGEN_TGZ_NAME}.tgz +tar -xf ${OPENVINOGEN_TGZ_NAME}.tgz +mv ${OPENVINOGEN_TGZ_NAME} $OPENVINOGEN_DIRNAME +./${OPENVINOGEN_DIRNAME}/install_dependencies/install_openvino_dependencies.sh -y + +rm ${OPENVINOGEN_TGZ_NAME}.tgz + +echo "OpenVINO GenAI installed at `pwd`/$OPENVINOGEN_DIRNAME" +echo "Please source the setupvars.sh script in the OpenVINO GenAI directory to use the OpenVINO GenAI tools." +echo "# source $PWD/$OPENVINOGEN_DIRNAME/setupvars.sh" + From d8db095ebd833823b23237ac8bf5eb48c1a9afe2 Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 12 Mar 2025 18:55:08 +0800 Subject: [PATCH 542/623] [WASI-NN] ggml: bump to b4875; bump wasi-nn plugin to 0.1.15 (#4055) * Starts from b4875: support gemma-3 text-only model * Rename lora_outfile to out_file Signed-off-by: hydai --- plugins/wasi_nn/wasinn_ggml.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 662b26aa..89580475 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -1739,14 +1739,14 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, "Unable to retrieve the spm-infill option."sv) } } - if (Doc.at_key("lora-outfile").error() == simdjson::SUCCESS) { - std::string_view LoraOutfile; - auto Err = Doc["lora-outfile"].get().get(LoraOutfile); + if (Doc.at_key("out-file").error() == simdjson::SUCCESS) { + std::string_view Outfile; + auto Err = Doc["out-file"].get().get(Outfile); if (Err) { RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the lora-outfile option."sv) + "Unable to retrieve the outfile option."sv) } - GraphRef.Params.lora_outfile = LoraOutfile; + GraphRef.Params.out_file = Outfile; } if (Doc.at_key("batched-bench-output-jsonl").error() == simdjson::SUCCESS) { auto Err = Doc["batched-bench-output-jsonl"].get().get( From bc080bf4c06397febf3e440cb78a785018a9b8c3 Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 17 Mar 2025 14:03:16 +0800 Subject: [PATCH 543/623] [WASI-NN] ggml: bump to b4897; bump wasi-nn plugin to 0.1.16 (#4058) Replace llama_kv_cache_clear with llama_kv_self_clear. The API is changed by llama.cpp upstream. Signed-off-by: hydai --- plugins/wasi_nn/wasinn_ggml.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 89580475..c39fe673 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -2294,7 +2294,7 @@ ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, "{}: clear the previous output and tokens...Done"sv, LogPrefix) // Clear the llama context. - llama_kv_cache_clear(GraphRef.LlamaContext.get()); + llama_kv_self_clear(GraphRef.LlamaContext.get()); // Prepare variables; CxtRef.NPos = 0; @@ -3068,7 +3068,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Clear the llama context. LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context"sv) - llama_kv_cache_clear(GraphRef.LlamaContext.get()); + llama_kv_self_clear(GraphRef.LlamaContext.get()); LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context...Done"sv) // Set the input. From 14ccc76237764c5d6169949a130174a68a1015e8 Mon Sep 17 00:00:00 2001 From: LFsWang <7088579+LFsWang@users.noreply.github.com> Date: Tue, 18 Mar 2025 14:21:51 +0800 Subject: [PATCH 544/623] [WASI-NN] Add openvino genai support (#4034) Signed-off-by: LFsWang --- plugins/wasi_nn/CMakeLists.txt | 1 + plugins/wasi_nn/wasinn_openvino_genai.cpp | 247 ++++++++++++++++++++++ plugins/wasi_nn/wasinn_openvino_genai.h | 100 +++++++++ plugins/wasi_nn/wasinnenv.h | 1 + plugins/wasi_nn/wasinntypes.h | 4 +- 5 files changed, 352 insertions(+), 1 deletion(-) create mode 100644 plugins/wasi_nn/wasinn_openvino_genai.cpp create mode 100644 plugins/wasi_nn/wasinn_openvino_genai.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index ee057279..59324375 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -7,6 +7,7 @@ wasmedge_add_library(wasmedgePluginWasiNN wasinnfunc.cpp wasinnmodule.cpp wasinn_openvino.cpp + wasinn_openvino_genai.cpp wasinn_onnx.cpp wasinn_tf.cpp wasinn_torch.cpp diff --git a/plugins/wasi_nn/wasinn_openvino_genai.cpp b/plugins/wasi_nn/wasinn_openvino_genai.cpp new file mode 100644 index 00000000..a92030c3 --- /dev/null +++ b/plugins/wasi_nn/wasinn_openvino_genai.cpp @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "wasinn_openvino_genai.h" +#include "wasinnenv.h" + +#include + +using namespace std::literals; + +namespace WasmEdge::Host::WASINN::OpenVINOGenAI { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINOGENAI + +Expect GetDeviceString(WASINN::Device TargetDevice, + std::string &DeviceString) noexcept { + switch (TargetDevice) { + case Device::CPU: + DeviceString = "CPU"; + break; + case Device::GPU: + DeviceString = "GPU"; + break; + default: + spdlog::error("[WASI-NN] Unsupported device type"sv); + return WASINN::ErrNo::InvalidArgument; + } + return WASINN::ErrNo::Success; +} + +Expect isStringTensor(const TensorData &Tensor) noexcept { + if (Tensor.RType != WASINN::TensorType::U8) { + spdlog::error( + "[WASI-NN] Only STRING (u8) inputs and outputs are supported for " + "now."sv); + return WASINN::ErrNo::InvalidArgument; + } + if (Tensor.Dimension.size() != 1) { + spdlog::error("[WASI-NN] Tensor dimension is out of range, expect it under " + "1-dim, but got {}-dim."sv, + Tensor.Dimension.size()); + return WASINN::ErrNo::InvalidArgument; + } + return WASINN::ErrNo::Success; +} + +Expect +LLMPipelineBackend::SetContextInput(Context &CxtRef, uint32_t Index, + const TensorData &Tensor) { + + if (Index != 0) { + spdlog::error("[WASI-NN] The input index {} is out of range."sv, Index); + return WASINN::ErrNo::InvalidArgument; + } + + if (auto Res = isStringTensor(Tensor); !Res) { + return Res; + } + + try { + CxtRef.StringInput = + std::string(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Set Input Exception: {}"sv, EX.what()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; +} + +Expect LLMPipelineBackend::Generate(Context &CxtRef) { + try { + CxtRef.StringOutput = Model->generate(CxtRef.StringInput); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Generate Exception: {}"sv, EX.what()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; +} + +Expect +LLMPipelineBackend::GetContextOutput(Context &CxtRef, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) { + if (Index != 0) { + spdlog::error("[WASI-NN] The output index {} is out of range."sv, Index); + return WASINN::ErrNo::InvalidArgument; + } + + try { + BytesWritten = CxtRef.StringOutput.size(); + std::copy_n(reinterpret_cast(CxtRef.StringOutput.data()), + BytesWritten, OutBuffer.data()); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Get Output Exception: {}"sv, EX.what()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; +} + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept { + // The graph builder length must be 3. + if (Builders.size() != 3) { + spdlog::error("[WASI-NN] Wrong GraphBuilder Length {:d}, expect 3"sv, + Builders.size()); + return WASINN::ErrNo::InvalidArgument; + } + + // Get the XML and Weight raw buffer. + // Builder-0: Reserved, the string "LLMPipeline" + // Builder-1: Path to the dir model xml/bin files + // Builder-2: Empty for now (reserved for future use) + + // There are 4 types (text or img) x (text or img), we assume the input is 0 + // for now. + auto ModelType = std::string( + reinterpret_cast(Builders[0].data()), Builders[0].size()); + auto ModelPath = std::string( + reinterpret_cast(Builders[1].data()), Builders[1].size()); + // TODO: Support extra model information. (ex. enable kv cache) + [[maybe_unused]] auto ModelExtra = std::string( + reinterpret_cast(Builders[2].data()), Builders[2].size()); + + // Add a new graph. + uint32_t GId = Env.newGraph(Backend::OpenVINOGenAI); + auto &GraphRef = Env.NNGraph[GId].get(); + + // Store device information + GraphRef.TargetDevice = Device; + std::string DeviceString; + if (auto Err = GetDeviceString(Device, DeviceString); + Err != WASINN::ErrNo::Success) { + return Err; + } + + try { + // Create the OpenVINO GenAI Backend. + // Currently, we only support LLMPipeline. + if (ModelType == "LLMPipeline") { + GraphRef.OpenVINOGenAI = + std::make_shared(ModelPath, DeviceString); + } else { + spdlog::error("[WASI-NN] Unsupported model type: {}"sv, ModelType); + return WASINN::ErrNo::InvalidArgument; + } + + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Model Load Exception: {}"sv, EX.what()); + Env.deleteGraph(GId); + return WASINN::ErrNo::RuntimeError; + } + // Store the loaded graph. + GraphId = GId; + Env.NNGraph[GId].setReady(); + return WASINN::ErrNo::Success; +} + +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept { + // Check the network and the execution network with the graph ID. + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphRef.OpenVINOGenAI == nullptr) { + spdlog::error("[WASI-NN] Model for Graph:{} is empty!"sv, GraphId); + return WASINN::ErrNo::MissingMemory; + } + // Create context. + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + Env.NNContext[ContextId].setReady(); + return WASINN::ErrNo::Success; +} + +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + if (GraphRef.OpenVINOGenAI == nullptr) { + spdlog::error("[WASI-NN] The founded openvino genei session is empty"sv); + return WASINN::ErrNo::MissingMemory; + } + + return GraphRef.OpenVINOGenAI->SetContextInput(CxtRef, Index, Tensor); +} + +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + + if (GraphRef.OpenVINOGenAI == nullptr) { + spdlog::error("[WASI-NN] The founded openvino genei session is empty"sv); + return WASINN::ErrNo::MissingMemory; + } + + return GraphRef.OpenVINOGenAI->GetContextOutput(CxtRef, Index, OutBuffer, + BytesWritten); +} + +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + try { + GraphRef.OpenVINOGenAI->Generate(CxtRef); + } catch (const std::exception &EX) { + spdlog::error("[WASI-NN] Infer Request Exception: {}"sv, EX.what()); + return WASINN::ErrNo::RuntimeError; + } + return WASINN::ErrNo::Success; +} +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error( + "[WASI-NN] OpenVINO GenAI backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"OpenVINOGenAI\" to build it."sv); + return WASINN::ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WASINN::WasiNNEnvironment &, + Span>, WASINN::Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WASINN::WasiNNEnvironment &, uint32_t, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WASINN::WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WASINN::WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif +} // namespace WasmEdge::Host::WASINN::OpenVINOGenAI diff --git a/plugins/wasi_nn/wasinn_openvino_genai.h b/plugins/wasi_nn/wasinn_openvino_genai.h new file mode 100644 index 00000000..cf6c673d --- /dev/null +++ b/plugins/wasi_nn/wasinn_openvino_genai.h @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "wasinntypes.h" + +#include "plugin/plugin.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINOGENAI +#include "openvino/openvino.hpp" +#include +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::OpenVINOGenAI { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINOGENAI + +struct Context; +class OpenVINOGenAIBackend { +public: + OpenVINOGenAIBackend() = default; + virtual Expect SetContextInput(Context &CxtRef, uint32_t Index, + const TensorData &Tensor) = 0; + virtual Expect Generate(Context &CxtRef) = 0; + virtual Expect GetContextOutput(Context &CxtRef, + uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) = 0; + virtual ~OpenVINOGenAIBackend() noexcept {} +}; + +class LLMPipelineBackend : public OpenVINOGenAIBackend { +public: + LLMPipelineBackend(std::string Path, std::string Device) { + Model = std::make_shared(Path, Device); + } + ~LLMPipelineBackend() noexcept {} + Expect SetContextInput(Context &CxtRef, uint32_t Index, + const TensorData &Tensor) override; + Expect Generate(Context &CxtRef) override; + Expect GetContextOutput(Context &CxtRef, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) override; + +private: + std::shared_ptr Model; +}; + +struct Graph { + ~Graph() noexcept {} + std::shared_ptr OpenVINOGenAI; + Device TargetDevice = Device::AUTO; +}; + +struct Context { + Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} + ~Context() noexcept {} + uint32_t GraphId; + std::string StringInput; + std::string StringOutput; + + // For image input/output + // ov::Tensor TensorInput; + // ov::Tensor TensorOutput; +}; + +struct Environ { + Environ() noexcept {} + ~Environ() noexcept {} + ov::Core OpenVINOCore; +}; +#else +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; +struct Environ {}; +#endif + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +} // namespace WasmEdge::Host::WASINN::OpenVINOGenAI diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 8d2d575e..95c54279 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -9,6 +9,7 @@ #include "wasinn_neuralspeed.h" #include "wasinn_onnx.h" #include "wasinn_openvino.h" +#include "wasinn_openvino_genai.h" #include "wasinn_piper.h" #include "wasinn_tf.h" #include "wasinn_tfl.h" diff --git a/plugins/wasi_nn/wasinntypes.h b/plugins/wasi_nn/wasinntypes.h index 6bf716b3..ebc95a1f 100644 --- a/plugins/wasi_nn/wasinntypes.h +++ b/plugins/wasi_nn/wasinntypes.h @@ -43,6 +43,7 @@ enum class Backend : uint8_t { MLX = 10, Piper = 11, ChatTTS = 12, + OpenVINOGenAI = 13, }; #define FOR_EACH_BACKEND(F) \ @@ -56,7 +57,8 @@ enum class Backend : uint8_t { F(Whisper) \ F(Piper) \ F(ChatTTS) \ - F(MLX) + F(MLX) \ + F(OpenVINOGenAI) struct TensorData { Span Dimension; From 4e07ded69144c826d164e6816884bed662e15651 Mon Sep 17 00:00:00 2001 From: LFsWang <7088579+LFsWang@users.noreply.github.com> Date: Sat, 29 Mar 2025 18:46:00 +0800 Subject: [PATCH 545/623] [WASI-NN] Fix symbol of openvino genai (#4067) Signed-off-by: LFsWang --- plugins/wasi_nn/wasinn_openvino_genai.cpp | 14 ++++++++++---- plugins/wasi_nn/wasinnenv.cpp | 3 ++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/plugins/wasi_nn/wasinn_openvino_genai.cpp b/plugins/wasi_nn/wasinn_openvino_genai.cpp index a92030c3..4fd945e9 100644 --- a/plugins/wasi_nn/wasinn_openvino_genai.cpp +++ b/plugins/wasi_nn/wasinn_openvino_genai.cpp @@ -29,10 +29,11 @@ Expect GetDeviceString(WASINN::Device TargetDevice, Expect isStringTensor(const TensorData &Tensor) noexcept { if (Tensor.RType != WASINN::TensorType::U8) { - spdlog::error( + spdlog::warn( "[WASI-NN] Only STRING (u8) inputs and outputs are supported for " - "now."sv); - return WASINN::ErrNo::InvalidArgument; + "now. Input Type: {}"sv, + Tensor.RType); + // return WASINN::ErrNo::InvalidArgument; } if (Tensor.Dimension.size() != 1) { spdlog::error("[WASI-NN] Tensor dimension is out of range, expect it under " @@ -69,7 +70,12 @@ LLMPipelineBackend::SetContextInput(Context &CxtRef, uint32_t Index, Expect LLMPipelineBackend::Generate(Context &CxtRef) { try { - CxtRef.StringOutput = Model->generate(CxtRef.StringInput); + // TODO: let the user to set the generation config. + spdlog::warn("[WASI-NN] The generation config is not supported for now."sv); + spdlog::warn("[WASI-NN] Maximum token limit is set to 100."sv); + ov::genai::GenerationConfig config; + config.max_new_tokens = 100; + CxtRef.StringOutput = Model->generate(CxtRef.StringInput, config); } catch (const std::exception &EX) { spdlog::error("[WASI-NN] Generate Exception: {}"sv, EX.what()); return WASINN::ErrNo::RuntimeError; diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index 88647e02..e09d3780 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -37,7 +37,8 @@ std::map BackendMap = { {"whisper"sv, Backend::Whisper}, {"mlx"sv, Backend::MLX}, {"piper"sv, Backend::Piper}, - {"chattts"sv, Backend::ChatTTS}}; + {"chattts"sv, Backend::ChatTTS}, + {"openvinogenai"sv, Backend::OpenVINOGenAI}}; std::map DeviceMap = {{"cpu"sv, Device::CPU}, {"gpu"sv, Device::GPU}, From 5878ab633f7c6128f94c2e3ac18c85d8cf255d0e Mon Sep 17 00:00:00 2001 From: LFsWang <7088579+LFsWang@users.noreply.github.com> Date: Mon, 31 Mar 2025 11:44:02 +0800 Subject: [PATCH 546/623] [WASI-NN] Update TensorType Index (#4069) Signed-off-by: LFsWang --- plugins/wasi_nn/wasinntypes.h | 15 ++++++++++++++- test/plugins/wasi_nn/wasi_nn.cpp | 8 ++++---- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/plugins/wasi_nn/wasinntypes.h b/plugins/wasi_nn/wasinntypes.h index ebc95a1f..d6e37bd1 100644 --- a/plugins/wasi_nn/wasinntypes.h +++ b/plugins/wasi_nn/wasinntypes.h @@ -26,7 +26,14 @@ enum class ErrNo : uint32_t { ModelNotFound = 103, // Model Not Found. }; -enum class TensorType : uint8_t { F16 = 0, F32 = 1, U8 = 2, I32 = 3 }; +enum class TensorType : uint8_t { + F16 = 0, + F32 = 1, + F64 = 2, + U8 = 3, + I32 = 4, + I64 = 5 +}; enum class Device : uint32_t { CPU = 0, GPU = 1, TPU = 2, AUTO = 3 }; @@ -82,12 +89,18 @@ struct fmt::formatter case WasmEdge::Host::WASINN::TensorType::F32: Name = "F32"sv; break; + case WasmEdge::Host::WASINN::TensorType::F64: + Name = "F64"sv; + break; case WasmEdge::Host::WASINN::TensorType::U8: Name = "U8"sv; break; case WasmEdge::Host::WASINN::TensorType::I32: Name = "I32"sv; break; + case WasmEdge::Host::WASINN::TensorType::I64: + Name = "I64"sv; + break; default: Name = "Unknown"sv; } diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index bf714876..9feeb8ca 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -1174,7 +1174,7 @@ TEST(WasiNNTest, TFLiteBackend) { writeFatPointer(MemInst, StorePtr, static_cast(TensorDim.size()), BuilderPtr); // Tensor type U8 - writeUInt32(MemInst, UINT32_C(2), BuilderPtr); + writeUInt32(MemInst, UINT32_C(3), BuilderPtr); writeFatPointer(MemInst, StorePtr + static_cast(TensorDim.size()) * 4, static_cast(TensorData.size()), BuilderPtr); @@ -2376,7 +2376,7 @@ TEST(WasiNNTest, PiperBackend) { // Piper WASI-NN set_input tests. SetInputEntryPtr = BuilderPtr; writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, 2, BuilderPtr); + writeUInt32(MemInst, 3, BuilderPtr); writeFatPointer(MemInst, StorePtr + TensorDim.size() * sizeof(decltype(TensorDim)::value_type), @@ -2500,7 +2500,7 @@ TEST(WasiNNTest, PiperBackend) { TensorData = {Text.begin(), Text.end()}; SetInputEntryPtr = BuilderPtr; writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, 2, BuilderPtr); + writeUInt32(MemInst, 3, BuilderPtr); writeFatPointer(MemInst, StorePtr + TensorDim.size() * sizeof(decltype(TensorDim)::value_type), @@ -2549,7 +2549,7 @@ TEST(WasiNNTest, PiperBackend) { TensorData = {Text.begin(), Text.end()}; SetInputEntryPtr = BuilderPtr; writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, 2, BuilderPtr); + writeUInt32(MemInst, 3, BuilderPtr); writeFatPointer(MemInst, StorePtr + TensorDim.size() * sizeof(decltype(TensorDim)::value_type), From 8df8e9afef84c750ba3c804bf6a91f1444be5bc6 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 8 Apr 2025 12:52:28 +0800 Subject: [PATCH 547/623] [WASI-NN] ggml: bump to b5074; bump wasi-nn plugin to 0.1.18 (#4078) * There is no model, model_url, hf_repo, and hf_file. Instead, all these options are now inside the common_params_model structure. * Support llama4-text only * Disable CURL dependency Signed-off-by: hydai --- plugins/wasi_nn/wasinn_ggml.cpp | 34 ++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index c39fe673..e52b28d9 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -206,7 +206,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the mmproj option."sv) } - GraphRef.Params.mmproj = MMProjModelPath; + GraphRef.Params.mmproj.path = MMProjModelPath; } // The TTS parameters. @@ -224,7 +224,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the model-vocoder option."sv) } - GraphRef.Params.vocoder.model = VocoderModelPath; + GraphRef.Params.vocoder.model.path = VocoderModelPath; } if (Doc.at_key("tts-output-file").error() == simdjson::SUCCESS) { std::string_view TTSOutputFilePath; @@ -888,7 +888,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the hf-repo-vocoder option."sv) } - GraphRef.Params.vocoder.hf_repo = HfRepo; + GraphRef.Params.vocoder.model.hf_repo = HfRepo; } if (Doc.at_key("hf-file-vocoder").error() == simdjson::SUCCESS) { std::string_view HfFile; @@ -897,7 +897,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the hf-file-vocoder option."sv) } - GraphRef.Params.vocoder.hf_file = HfFile; + GraphRef.Params.vocoder.model.hf_file = HfFile; } if (Doc.at_key("model-url-vocoder").error() == simdjson::SUCCESS) { std::string_view ModelUrlVocoder; @@ -907,7 +907,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the model-url-vocoder option."sv) } - GraphRef.Params.vocoder.model_url = ModelUrlVocoder; + GraphRef.Params.vocoder.model.url = ModelUrlVocoder; } // The config parameters. if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { @@ -966,7 +966,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the model-url option."sv) } - GraphRef.Params.model_url = ModelUrl; + GraphRef.Params.model.url = ModelUrl; } if (Doc.at_key("hf-token").error() == simdjson::SUCCESS) { std::string_view HfToken; @@ -984,7 +984,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the hf-repo option."sv) } - GraphRef.Params.hf_repo = HfRepo; + GraphRef.Params.model.hf_repo = HfRepo; } if (Doc.at_key("hf-file").error() == simdjson::SUCCESS) { std::string_view HfFile; @@ -993,7 +993,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the hf-file option."sv) } - GraphRef.Params.hf_file = HfFile; + GraphRef.Params.model.hf_file = HfFile; } if (Doc.at_key("prompt-file").error() == simdjson::SUCCESS) { std::string_view PromptFile; @@ -2777,7 +2777,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize the model parameters. llama_model_params ModelParamsDefault = llama_model_default_params(); GraphRef.Params.n_gpu_layers = ModelParamsDefault.n_gpu_layers; - GraphRef.Params.mmproj = ""sv; + GraphRef.Params.mmproj.path = ""sv; GraphRef.Params.warmup = false; // Initialize the context parameters. llama_context_params ContextParamsDefault = llama_context_default_params(); @@ -2846,15 +2846,15 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, const std::string_view BinModel(reinterpret_cast(Weight.data()), Weight.size()); if (BinModel.substr(0, 8) == "preload:"sv) { - GraphRef.Params.model = BinModel.substr(8); + GraphRef.Params.model.path = BinModel.substr(8); } else { LOG_DEBUG(GraphRef.EnableDebugLog, "load: Model path not found in nn-preload, write model into "sv "a tmpfile."sv) // TODO: pass the model directly to ggml. // Write ggml model to file. - GraphRef.Params.model = "ggml-model.bin"sv; - std::ofstream TempFile(GraphRef.Params.model, + GraphRef.Params.model.path = "ggml-model.bin"sv; + std::ofstream TempFile(GraphRef.Params.model.path, std::ios::out | std::ios::binary); if (!TempFile) { Env.deleteGraph(GId); @@ -2873,7 +2873,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Check if the model exists. if (!std::filesystem::exists( - std::filesystem::u8path(GraphRef.Params.model))) { + std::filesystem::u8path(GraphRef.Params.model.path))) { Env.deleteGraph(GId); RET_ERROR(ErrNo::ModelNotFound, "load: model file not found."sv) } @@ -3005,7 +3005,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, CxtRef.LlamaSampler = nullptr; } GraphRef.LlamaModel = llama_model_ptr(llama_model_load_from_file( - GraphRef.Params.model.c_str(), ModelParams)); + GraphRef.Params.model.path.c_str(), ModelParams)); if (GraphRef.LlamaModel == nullptr) { Env.NNGraph[CxtRef.GraphId].setInvalid(); RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init model."sv) @@ -3084,7 +3084,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Prompt with image input. Check is llava or mllama case. // First check the projection model is loaded. - if (GraphRef.Params.mmproj == ""sv) { + if (GraphRef.Params.mmproj.path == ""sv) { RET_ERROR( ErrNo::InvalidArgument, "setInput: the given model does not support image input, so a projection model is required."sv) @@ -3098,8 +3098,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, "for CLIP, the step of loading images in CLIP can only use the "sv "CPU, which may result in reduced efficiency. (You can refer to "sv "PR https://github.com/ggerganov/llama.cpp/pull/10896)"sv) - GraphRef.ClipContext = clip_model_load(GraphRef.Params.mmproj.c_str(), - GraphRef.EnableLog ? 1 : 0); + GraphRef.ClipContext = clip_model_load( + GraphRef.Params.mmproj.path.c_str(), GraphRef.EnableLog ? 1 : 0); if (GraphRef.ClipContext == nullptr) { RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to load the clip model."sv) From b4f799b9204647170495397036dafa6c258e31f7 Mon Sep 17 00:00:00 2001 From: PeterD1524 <53310459+PeterD1524@users.noreply.github.com> Date: Sat, 12 Apr 2025 11:52:12 +0800 Subject: [PATCH 548/623] build(docker): install ChatTTS in ubuntu-plugins-deps for WASI-NN ChatTTS CI (#4080) build(docker): install ChatTTS in ubuntu-plugins-deps for WASI-NN ChatTTS CI Signed-off-by: PeterD1524 --- utils/docker/Dockerfile.ubuntu-plugins-deps | 6 +++++- utils/wasi-nn/install-chattts.sh | 24 +++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 utils/wasi-nn/install-chattts.sh diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index c2f48bf8..4aaaffcc 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -75,6 +75,9 @@ RUN [ "/bin/bash", "-c", "echo \"source ./openvino_genai/setupvars.sh\" >> .bash COPY wasi-nn/install-onnxruntime.sh . RUN [ "/bin/bash", "install-onnxruntime.sh" ] +COPY wasi-nn/install-chattts.sh . +RUN [ "/bin/bash", "install-chattts.sh" ] + ### cleanup FROM deps-all AS clean-apt @@ -84,6 +87,7 @@ RUN rm -f \ install-pytorch.sh \ install-openvino.sh \ install-onnxruntime.sh \ - install-openvino-genai.sh + install-openvino-genai.sh \ + install-chattts.sh RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/wasi-nn/install-chattts.sh b/utils/wasi-nn/install-chattts.sh new file mode 100644 index 00000000..f406d460 --- /dev/null +++ b/utils/wasi-nn/install-chattts.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2024 Second State INC + +set -e + +apt-get update +apt-get -y install python3 python3-dev + +# Use latest pip +python3 -m venv chattts_venv +source chattts_venv/bin/activate +pip install --upgrade pip + +# Install PyTorch CPU version to save space +pip --python /usr/bin/python3 install --break-system-packages --index-url https://download.pytorch.org/whl/cpu 'torch<=2.6.0' 'torchaudio<=2.6.0' +pip --python /usr/bin/python3 install --break-system-packages chattts==0.2.3 + +# Remove wheel cache +pip --python /usr/bin/python3 cache purge + +# Clean up +deactivate +rm -rf chattts_venv From 407b4c0bf14fe9cb4d2d96ccff36cfb5127a7940 Mon Sep 17 00:00:00 2001 From: PeterD1524 Date: Sat, 12 Apr 2025 00:10:06 +0800 Subject: [PATCH 549/623] fix(wasi-nn/ChatTTS): update compute function to be compatible with v0.2.1 In chattts v0.1.1, ChatTTS.Chat.infer returns a Python list. In chattts v0.2.1, ChatTTS.Chat.infer returns a numpy.ndarray. Use PyObject_GetItem (the equivalent of the Python expression o[index]) instead of PyList_GetItem (works only for list). Signed-off-by: PeterD1524 --- plugins/wasi_nn/wasinn_chattts.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/wasinn_chattts.cpp b/plugins/wasi_nn/wasinn_chattts.cpp index f15fdb75..5937d735 100644 --- a/plugins/wasi_nn/wasinn_chattts.cpp +++ b/plugins/wasi_nn/wasinn_chattts.cpp @@ -292,8 +292,11 @@ Expect compute(WasiNNEnvironment &Env, Py_XDECREF(Kwargs); } if (Result != nullptr) { - PyObject *Wav0 = PyList_GetItem(Result, 0); + PyObject *Index = PyLong_FromLong(0); + PyObject *Wav0 = PyObject_GetItem(Result, Index); + Py_XDECREF(Index); PyObject *BytesObj = PyObject_CallMethod(Wav0, "tobytes", nullptr); + Py_XDECREF(Wav0); char *Bytes = PyBytes_AsString(BytesObj); Py_ssize_t size = PyBytes_Size(BytesObj); CxtRef.Outputs = std::vector(Bytes, Bytes + size); From 685289ca9e4d198cad8bff7b118d07a508bd251e Mon Sep 17 00:00:00 2001 From: PeterD1524 Date: Sat, 12 Apr 2025 00:18:58 +0800 Subject: [PATCH 550/623] test(wasi-nn/ChatTTS): update for TensorType change https://www.github.com/WasmEdge/WasmEdge/pull/4069 Signed-off-by: PeterD1524 --- test/plugins/wasi_nn/wasi_nn.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 9feeb8ca..40c8f8e7 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -20,6 +20,7 @@ using namespace std::literals; using WasmEdge::Host::WASINN::Backend; using WasmEdge::Host::WASINN::Device; using WasmEdge::Host::WASINN::ErrNo; +using WasmEdge::Host::WASINN::TensorType; #if defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_TORCH) || \ @@ -2730,7 +2731,7 @@ TEST(WasiNNTest, ChatTTSBackend) { // ChatTTS WASI-NN set_input tests. SetInputEntryPtr = BuilderPtr; writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(2), BuilderPtr); + writeUInt32(MemInst, static_cast(TensorType::U8), BuilderPtr); writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, TensorData.size(), BuilderPtr); writeBinaries(MemInst, TensorDim, StorePtr); @@ -2777,7 +2778,7 @@ TEST(WasiNNTest, ChatTTSBackend) { // Test: setInput -- set metadata successfully. SetInputEntryPtr = BuilderPtr; writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); - writeUInt32(MemInst, UINT32_C(2), BuilderPtr); + writeUInt32(MemInst, static_cast(TensorType::U8), BuilderPtr); writeFatPointer(MemInst, StorePtr + TensorDim.size() * 4, ConfigData.size(), BuilderPtr); writeBinaries(MemInst, TensorDim, StorePtr); From 194a9d5837f63a913fe264d62511cbf045a3d68c Mon Sep 17 00:00:00 2001 From: varunrmallya <100590632+varun-r-mallya@users.noreply.github.com> Date: Fri, 18 Apr 2025 03:45:48 +0530 Subject: [PATCH 551/623] fix: lint install-openvino-genai.sh (#4087) Signed-off-by: varun-r-mallya --- utils/wasi-nn/install-openvino-genai.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/utils/wasi-nn/install-openvino-genai.sh b/utils/wasi-nn/install-openvino-genai.sh index a877adac..2be3691e 100644 --- a/utils/wasi-nn/install-openvino-genai.sh +++ b/utils/wasi-nn/install-openvino-genai.sh @@ -33,4 +33,3 @@ rm ${OPENVINOGEN_TGZ_NAME}.tgz echo "OpenVINO GenAI installed at `pwd`/$OPENVINOGEN_DIRNAME" echo "Please source the setupvars.sh script in the OpenVINO GenAI directory to use the OpenVINO GenAI tools." echo "# source $PWD/$OPENVINOGEN_DIRNAME/setupvars.sh" - From 5eeb4dd945bf41b3d5f1c3a2d78f9bb74b459549 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Tue, 29 Apr 2025 01:47:48 +0800 Subject: [PATCH 552/623] feat(WASI-NN/MLX): support gemma3 for mlx plugin (#4085) * feat(WASI-NN/MLX): add get gemma3 model Signed-off-by: grorge * fix(WASI-NN/MLX): llm ci test failed Signed-off-by: grorge * feat(WASI-NN/MLX): support gemma3 for mlx plugin Signed-off-by: grorge * refactor(WASI-NN/MLX): add default parameter Signed-off-by: grorge * feat(WASI-NN/MLX): add eos token Signed-off-by: grorge * feat(WASI-NN/MLX): add timer Signed-off-by: grorge --------- Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 11 +- plugins/wasi_nn/MLX/mlx/activations.cpp | 7 + plugins/wasi_nn/MLX/mlx/activations.h | 2 +- plugins/wasi_nn/MLX/mlx/base.cpp | 5 + plugins/wasi_nn/MLX/mlx/convolution.cpp | 20 + plugins/wasi_nn/MLX/mlx/convolution.h | 42 ++ plugins/wasi_nn/MLX/mlx/normalization.cpp | 13 + plugins/wasi_nn/MLX/mlx/normalization.h | 19 + plugins/wasi_nn/MLX/mlx/pooling.cpp | 180 ++++++ plugins/wasi_nn/MLX/mlx/pooling.h | 45 ++ plugins/wasi_nn/MLX/mlx/quantized.cpp | 9 +- plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp | 243 ++++++++ plugins/wasi_nn/MLX/model/gemma3/gemma3.h | 66 +++ plugins/wasi_nn/MLX/model/gemma3/language.cpp | 342 +++++++++++ plugins/wasi_nn/MLX/model/gemma3/language.h | 126 ++++ plugins/wasi_nn/MLX/model/gemma3/vision.cpp | 277 +++++++++ plugins/wasi_nn/MLX/model/gemma3/vision.h | 114 ++++ .../wasi_nn/MLX/model/{ => llm}/registry.cpp | 8 +- .../wasi_nn/MLX/model/{ => llm}/registry.h | 6 +- plugins/wasi_nn/MLX/model/llm/transformer.cpp | 282 +++++++++ plugins/wasi_nn/MLX/model/llm/transformer.h | 231 ++++++++ plugins/wasi_nn/MLX/model/transformer.cpp | 3 + plugins/wasi_nn/MLX/model/transformer.h | 3 + plugins/wasi_nn/MLX/model/vlm_base.cpp | 538 ++++++++++++++++++ plugins/wasi_nn/MLX/model/vlm_base.h | 177 ++++++ plugins/wasi_nn/MLX/model/vlm_sampling.cpp | 30 + plugins/wasi_nn/MLX/model/vlm_sampling.h | 13 + plugins/wasi_nn/wasinn_mlx.cpp | 296 +++++++--- plugins/wasi_nn/wasinn_mlx.h | 24 +- 29 files changed, 3022 insertions(+), 110 deletions(-) create mode 100644 plugins/wasi_nn/MLX/mlx/convolution.cpp create mode 100644 plugins/wasi_nn/MLX/mlx/convolution.h create mode 100644 plugins/wasi_nn/MLX/mlx/pooling.cpp create mode 100644 plugins/wasi_nn/MLX/mlx/pooling.h create mode 100644 plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp create mode 100644 plugins/wasi_nn/MLX/model/gemma3/gemma3.h create mode 100644 plugins/wasi_nn/MLX/model/gemma3/language.cpp create mode 100644 plugins/wasi_nn/MLX/model/gemma3/language.h create mode 100644 plugins/wasi_nn/MLX/model/gemma3/vision.cpp create mode 100644 plugins/wasi_nn/MLX/model/gemma3/vision.h rename plugins/wasi_nn/MLX/model/{ => llm}/registry.cpp (94%) rename plugins/wasi_nn/MLX/model/{ => llm}/registry.h (94%) create mode 100644 plugins/wasi_nn/MLX/model/llm/transformer.cpp create mode 100644 plugins/wasi_nn/MLX/model/llm/transformer.h create mode 100644 plugins/wasi_nn/MLX/model/vlm_base.cpp create mode 100644 plugins/wasi_nn/MLX/model/vlm_base.h create mode 100644 plugins/wasi_nn/MLX/model/vlm_sampling.cpp create mode 100644 plugins/wasi_nn/MLX/model/vlm_sampling.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 59324375..896f1de1 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -43,17 +43,24 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) target_sources(wasmedgePluginWasiNN PRIVATE MLX/prompt/prompt.cpp - MLX/model/transformer.cpp + MLX/model/llm/transformer.cpp + MLX/model/llm/registry.cpp + MLX/model/gemma3/language.cpp + MLX/model/gemma3/vision.cpp + MLX/model/gemma3/gemma3.cpp MLX/model/converter.cpp MLX/model/utils.cpp - MLX/model/registry.cpp + MLX/model/vlm_base.cpp + MLX/model/vlm_sampling.cpp MLX/mlx/base.cpp MLX/mlx/linear.cpp + MLX/mlx/convolution.cpp MLX/mlx/positional_encoding.cpp MLX/mlx/activations.cpp MLX/mlx/embedding.cpp MLX/mlx/normalization.cpp MLX/mlx/transformer.cpp + MLX/mlx/pooling.cpp MLX/mlx/quantized.cpp ) endif() diff --git a/plugins/wasi_nn/MLX/mlx/activations.cpp b/plugins/wasi_nn/MLX/mlx/activations.cpp index 64b7202e..f94695e2 100644 --- a/plugins/wasi_nn/MLX/mlx/activations.cpp +++ b/plugins/wasi_nn/MLX/mlx/activations.cpp @@ -16,5 +16,12 @@ mx::array gelu(mx::array X) { mx::array silu(mx::array X) { return X * mx::sigmoid(X); } +mx::array geluApprox(mx::array X) { + return mx::array({0.5}, X.dtype()) * X * + (mx::array({1}, X.dtype()) + + mx::tanh(mx::array({std::sqrt(2.0 / M_PI)}, X.dtype()) * + (X + mx::array({0.044715}, X.dtype()) * + mx::power(X, mx::array({3}, X.dtype()))))); +} } // namespace mlx::core } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/activations.h b/plugins/wasi_nn/MLX/mlx/activations.h index b6d75127..15d0dfd4 100644 --- a/plugins/wasi_nn/MLX/mlx/activations.h +++ b/plugins/wasi_nn/MLX/mlx/activations.h @@ -13,6 +13,6 @@ namespace mlx::core { mx::array gelu(mx::array X); mx::array silu(mx::array X); - +mx::array geluApprox(mx::array X); } // namespace mlx::core } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/base.cpp b/plugins/wasi_nn/MLX/mlx/base.cpp index 9afc8b6f..636a78f1 100644 --- a/plugins/wasi_nn/MLX/mlx/base.cpp +++ b/plugins/wasi_nn/MLX/mlx/base.cpp @@ -23,6 +23,11 @@ void Module::update(std::unordered_map Parameters) { std::shared_ptr Module::toQuantized(int GroupSize, int Bits) { for (auto &[K, V] : Submodules) { const auto OldModule = V; + auto Weights = V->Parameters.find("weight"); + if (Weights != V->Parameters.end() && + Weights->second.shape().back() % GroupSize != 0) { + continue; + } V = V->toQuantized(GroupSize, Bits); } return shared_from_this(); diff --git a/plugins/wasi_nn/MLX/mlx/convolution.cpp b/plugins/wasi_nn/MLX/mlx/convolution.cpp new file mode 100644 index 00000000..c7947707 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/convolution.cpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "convolution.h" +#include +#include +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +mx::array Conv2d::forward(mx::array Input) { + auto Y = mx::conv2d(Input, Parameters.at("weight"), Stride, Padding, Dilation, + Groups); + if (Parameters.find("bias") != Parameters.end()) { + Y = Y + Parameters.at("bias"); + } + return Y; +} + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/convolution.h b/plugins/wasi_nn/MLX/mlx/convolution.h new file mode 100644 index 00000000..9db61de1 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/convolution.h @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "base.h" + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +class Conv2d : public nn::Module { + std::pair Padding; + std::pair Stride; + std::pair Dilation; + int Groups; + +public: + Conv2d(int InChannels, int OutChannels, int KernelSize, + std::pair Stride = {1, 1}, + std::pair Padding = {0, 0}, + std::pair Dilation = {1, 1}, int Groups = 1, + bool Bias = true) + : Padding(Padding), Stride(Stride), Dilation(Dilation), Groups(Groups) { + + if (InChannels % Groups != 0) { + // InChannels must be divisible by Groups + assumingUnreachable(); + } + double Scale = std::sqrt(1.0 / (InChannels * KernelSize * KernelSize)); + registerParameter("weight", mx::random::uniform(-Scale, Scale, + {OutChannels, InChannels, + KernelSize, KernelSize})); + if (Bias) { + registerParameter("bias", mx::zeros({OutChannels})); + } + } + mx::array forward(mx::array Input); +}; + +} // namespace mlx::core::nn + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/normalization.cpp b/plugins/wasi_nn/MLX/mlx/normalization.cpp index 501bd5f0..dfc95205 100644 --- a/plugins/wasi_nn/MLX/mlx/normalization.cpp +++ b/plugins/wasi_nn/MLX/mlx/normalization.cpp @@ -10,5 +10,18 @@ mx::array RMSNorm::forward(mx::array Input) { return mx::fast::rms_norm(Input, Parameters.at("weight"), Eps); } +mx::array LayerNorm::forward(mx::array Input) { + std::optional Weight = {}; + std::optional Bias = {}; + if (Parameters.find("weight") != Parameters.end()) { + Weight = Parameters.at("weight"); + } + if (Parameters.find("bias") != Parameters.end()) { + Bias = Parameters.at("bias"); + } + + return mx::fast::layer_norm(Input, Weight, Bias, Eps); +} + } // namespace mlx::core::nn } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/normalization.h b/plugins/wasi_nn/MLX/mlx/normalization.h index a09ccddd..81ee7e3e 100644 --- a/plugins/wasi_nn/MLX/mlx/normalization.h +++ b/plugins/wasi_nn/MLX/mlx/normalization.h @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC +#pragma once + #include "mlx/base.h" namespace WasmEdge::Host::WASINN::MLX { @@ -13,9 +15,26 @@ class RMSNorm : public nn::Module { RMSNorm(int Dims, float Eps = 1e-5) : Eps(Eps) { registerParameter("weight", mx::ones({Dims})); } + mx::array forward(mx::array Input); +}; + +class LayerNorm : public nn::Module { + int Dims; + float Eps; +public: + LayerNorm(int Dims, float Eps = 1e-5, bool Affine = true, bool Bias = true) + : Dims(Dims), Eps(Eps) { + if (Affine) { + registerParameter("weight", mx::ones({Dims})); + if (Bias) { + registerParameter("bias", mx::zeros({Dims})); + } + } + } mx::array forward(mx::array Input); }; } // namespace mlx::core::nn + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/pooling.cpp b/plugins/wasi_nn/MLX/mlx/pooling.cpp new file mode 100644 index 00000000..4c5ff591 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/pooling.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "pooling.h" +#include "base.h" +#include "spdlog/spdlog.h" +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace mlx::core::nn { + +namespace { +std::vector valueOrList(const int Value, int Len) { + return std::vector(Len, Value); +} + +std::vector> +makePaddingPairs(const std::vector &Pads) { + std::vector> Pairs; + for (int P : Pads) { + Pairs.push_back({P, P}); + } + return Pairs; +} + +std::vector> +makePadShape(const std::vector> &Padding) { + std::vector> PadShape; + PadShape.push_back({0, 0}); + for (auto &P : Padding) { + PadShape.push_back(P); + } + PadShape.push_back({0, 0}); + return PadShape; +} + +} // namespace + +mx::array nonOverlappingSlidingWindows(const mx::array &X, + const std::vector &Shape, + const std::vector &WindowShape) { + std::vector NewShape; + NewShape.push_back(Shape[0]); + for (size_t I = 1; I < std::min(Shape.size(), WindowShape.size() + 1); I++) { + int S = Shape[I]; + int W = WindowShape[I - 1]; + NewShape.push_back(S / W); + NewShape.push_back(W); + } + NewShape.push_back(Shape.back()); + int LastAxis = NewShape.size() - 1; + std::vector AxisOrder; + AxisOrder.push_back(0); + for (int I = 1; I < LastAxis; I += 2) { + AxisOrder.push_back(I); + } + for (int I = 2; I < LastAxis; I += 2) { + AxisOrder.push_back(I); + } + AxisOrder.push_back(LastAxis); + return transpose(reshape(X, NewShape), AxisOrder); +} + +mx::array slidingWindows(const mx::array &X, + const std::vector &WindowShape, + const std::vector &WindowStrides) { + if (X.ndim() < 3) { + spdlog::error( + "To extract sliding windows at least 1 spatial dimension (3 total) is " + "needed but the input only has " + + std::to_string(X.ndim()) + " dimensions."); + assumingUnreachable(); + } + std::vector Shape = X.shape(); + size_t SpatialDimsCount = Shape.size() - 2; + if (SpatialDimsCount != WindowShape.size() || + WindowShape.size() != WindowStrides.size()) { + + spdlog::error( + "The window shapes and strides must have the same number of spatial " + "dimensions as the signal."); + assumingUnreachable(); + } + bool UseNonOverlap = true; + for (size_t I = 0; I < SpatialDimsCount; I++) { + if (!(WindowShape[I] == WindowStrides[I] && + (Shape[I + 1] % WindowShape[I] == 0))) { + UseNonOverlap = false; + break; + } + } + if (UseNonOverlap) + return nonOverlappingSlidingWindows(X, Shape, WindowShape); + size_t N = Shape.size(); + std::vector BaseStrides(N); + BaseStrides[N - 1] = 1; + for (int I = N - 2; I >= 0; I--) { + BaseStrides[I] = Shape[I + 1] * BaseStrides[I + 1]; + } + std::vector FinalShape; + FinalShape.push_back(Shape[0]); + for (size_t I = 0; I < SpatialDimsCount; I++) { + int OutDim = (Shape[I + 1] - WindowShape[I]) / WindowStrides[I] + 1; + FinalShape.push_back(OutDim); + } + FinalShape.insert(FinalShape.end(), WindowShape.begin(), WindowShape.end()); + FinalShape.push_back(Shape.back()); + std::vector FinalStrides; + FinalStrides.push_back(BaseStrides[0]); + for (size_t I = 1; I < BaseStrides.size() - 1; I++) { + FinalStrides.push_back(BaseStrides[I] * WindowStrides[I - 1]); + } + for (size_t I = 1; I < BaseStrides.size(); I++) { + FinalStrides.push_back(BaseStrides[I]); + } + return mx::as_strided(X, FinalShape, FinalStrides, 0); +} + +Pool::Pool( + const std::function &)> + &PoolingFunction, + const std::vector &KernelSize, const std::vector &Stride, + const std::vector> &Padding, int PaddingValue) + : PoolingFunction(PoolingFunction), KernelSize(KernelSize), Stride(Stride), + Padding(Padding), PaddingValue(PaddingValue) { + int N = KernelSize.size(); + for (int I = -(N + 1); I < -1; I++) { + Axes.push_back(I); + } +} + +mx::array Pool::forward(const mx::array &X) { + mx::array Out = X; + bool NeedPad = false; + for (auto &P : Padding) { + if (P.first > 0) { + NeedPad = true; + break; + } + } + if (NeedPad) { + std::vector> PadShape = makePadShape(Padding); + Out = mx::pad(Out, PadShape, mx::array(PaddingValue)); + } + Out = slidingWindows(Out, KernelSize, Stride); + return PoolingFunction(Out, Axes); +} + +Pool2d::Pool2d( + const std::function &)> + &PoolingFunction, + int PaddingValue, const std::vector &KernelSize, + const std::optional> &StrideOpt, + const std::optional> &PaddingOpt) + : Pool(PoolingFunction, + KernelSize.size() == 1 ? valueOrList(KernelSize[0], 2) : KernelSize, + (StrideOpt.has_value() + ? (StrideOpt.value().size() == 1 + ? valueOrList(StrideOpt.value()[0], 2) + : StrideOpt.value()) + : (KernelSize.size() == 1 ? valueOrList(KernelSize[0], 2) + : KernelSize)), + makePaddingPairs(PaddingOpt.has_value() ? PaddingOpt.value() + : valueOrList(0, 2)), + PaddingValue) {} + +AvgPool2d::AvgPool2d(const std::vector &KernelSize, + const std::optional> &Stride, + const std::optional> &Padding) + : Pool2d( + [](const mx::array &A, const std::vector &Axis) -> mx::array { + return mx::mean(A, Axis, false); + }, + 0, KernelSize, Stride, Padding) {} + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/pooling.h b/plugins/wasi_nn/MLX/mlx/pooling.h new file mode 100644 index 00000000..1e14d0f4 --- /dev/null +++ b/plugins/wasi_nn/MLX/mlx/pooling.h @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "base.h" +namespace WasmEdge::Host::WASINN::MLX { + +namespace mlx::core::nn { +class Pool : public nn::Module { +public: + Pool(const std::function &)> &PoolingFunction, + const std::vector &KernelSize, const std::vector &Stride, + const std::vector> &Padding, int PaddingValue); + mx::array forward(const mx::array &X); + +protected: + std::function &)> + PoolingFunction; + std::vector KernelSize; + std::vector Stride; + std::vector> Padding; + int PaddingValue; + std::vector Axes; +}; + +class Pool2d : public Pool { +public: + Pool2d(const std::function &)> &PoolingFunction, + int PaddingValue, const std::vector &KernelSize, + const std::optional> &Stride, + const std::optional> &Padding); +}; + +class AvgPool2d : public Pool2d { +public: + AvgPool2d(const std::vector &KernelSize, + const std::optional> &Stride = std::nullopt, + const std::optional> &Padding = std::nullopt); +}; + +} // namespace mlx::core::nn +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/quantized.cpp b/plugins/wasi_nn/MLX/mlx/quantized.cpp index dee36264..02de23a4 100644 --- a/plugins/wasi_nn/MLX/mlx/quantized.cpp +++ b/plugins/wasi_nn/MLX/mlx/quantized.cpp @@ -19,8 +19,7 @@ mx::array QuantizedEmbedding::forward(mx::array Input) { mx::dequantize(take(Parameters.at("weight"), Input, 0), take(Parameters.at("scales"), Input, 0), take(Parameters.at("biases"), Input, 0), GroupSize, Bits); - S.emplace_back(-1); - return reshape(Out, {S}); + return Out; } mx::array QuantizedLinear::forward(mx::array Input) { @@ -53,10 +52,12 @@ std::shared_ptr QuantizedLinear::fromLinear(std::shared_ptr LinearModule, int GroupSize, int Bits) { auto LinearShape = LinearModule->Parameters.at("weight").shape(); + auto OutputDims = LinearShape[0]; + auto InputDims = LinearShape[1]; const bool EnableBias = LinearModule->Parameters.find("bias") != LinearModule->Parameters.end(); - auto QuantizedModel = std::make_shared(QuantizedLinear( - LinearShape[0], LinearShape[1], EnableBias, GroupSize, Bits)); + auto QuantizedModel = std::make_shared( + QuantizedLinear(InputDims, OutputDims, EnableBias, GroupSize, Bits)); auto Quantized = mx::quantize(LinearModule->Parameters.at("weight"), GroupSize, Bits); QuantizedModel->Parameters.insert_or_assign("weight", std::get<0>(Quantized)); diff --git a/plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp b/plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp new file mode 100644 index 00000000..e1ea9a24 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "model/gemma3/gemma3.h" +#include "language.h" +#include "mlx/embedding.h" +#include "vision.h" +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace gemma3 { + +ModelConfig ModelConfig::fromDict(const simdjson::dom::object &Obj) { + ModelConfig Config; + auto ModelTypeResult = Obj["model_type"].get_string(); + if (!ModelTypeResult.error()) { + Config.ModelType = std::string(ModelTypeResult.value()); + } + auto VocabResult = Obj["vocab_size"].get_int64(); + if (!VocabResult.error()) { + Config.VocabSize = static_cast(VocabResult.value()); + } + auto IgnoreResult = Obj["ignore_index"].get_int64(); + if (!IgnoreResult.error()) { + Config.IgnoreIndex = static_cast(IgnoreResult.value()); + } + auto ImageTokenResult = Obj["image_token_index"].get_int64(); + if (!ImageTokenResult.error()) { + Config.ImageTokenIndex = static_cast(ImageTokenResult.value()); + } + auto HiddenSizeResult = Obj["hidden_size"].get_int64(); + if (!HiddenSizeResult.error()) { + Config.HiddenSize = static_cast(HiddenSizeResult.value()); + } + auto PadTokenResult = Obj["pad_token_id"].get_int64(); + if (!PadTokenResult.error()) { + Config.PadTokenId = static_cast(PadTokenResult.value()); + } + return Config; +} + +Gemma3MultiModalProjector::Gemma3MultiModalProjector( + const ModelConfig &Config) { + registerModule("mm_soft_emb_norm", + std::make_shared(Config.VisionConfig.HiddenSize, + Config.VisionConfig.LayerNormEps)); + registerParameter( + "mm_input_projection_weight", + mx::ones({Config.VisionConfig.HiddenSize, Config.TextConfig.HiddenSize})); + PatchesPerImage = + Config.VisionConfig.ImageSize / Config.VisionConfig.PatchSize; + TokensPerSide = static_cast( + std::sqrt(static_cast(Config.TextConfig.MmTokensPerImage))); + KernelSize = PatchesPerImage / TokensPerSide; + AvgPool = + nn::AvgPool2d(std::vector{KernelSize}, std::vector{KernelSize}); +} + +mx::array Gemma3MultiModalProjector::forward(const mx::array &X) { + int B = X.shape()[0]; + int L = X.shape()[2]; + mx::array ReshapedVisionOutputs = transpose(X, {0, 2, 1}); + ReshapedVisionOutputs = + reshape(ReshapedVisionOutputs, {B, L, PatchesPerImage, PatchesPerImage}); + ReshapedVisionOutputs = transpose(ReshapedVisionOutputs, {0, 2, 3, 1}); + mx::array PooledVisionOutputs = AvgPool.forward(ReshapedVisionOutputs); + PooledVisionOutputs = transpose(PooledVisionOutputs, {0, 3, 1, 2}); + PooledVisionOutputs = flatten(PooledVisionOutputs, 2); + PooledVisionOutputs = transpose(PooledVisionOutputs, {0, 2, 1}); + mx::array NormedVisionOutputs = + std::dynamic_pointer_cast(Submodules["mm_soft_emb_norm"]) + ->forward(PooledVisionOutputs); + mx::array ProjectedVisionOutputs = mx::einsum( + "btm,md->btd", + std::vector( + {NormedVisionOutputs, Parameters.at("mm_input_projection_weight")})); + return astype(ProjectedVisionOutputs, X.dtype()); +} + +Model::Model(const ModelConfig &Config) : Config(Config) { + registerModule("vision_tower", + std::make_shared(Config.VisionConfig)); + registerModule("language_model", + std::make_shared(Config.TextConfig)); + registerModule("multi_modal_projector", + std::make_shared(Config)); + ModelType = Config.ModelType; +} + +std::pair +Model::getInputEmbeddings(const mx::array &InputIds, + const mx::array &PixelValues, const mx::array &Mask) { + if (PixelValues.size() == 0) { + mx::array Embeds = std::dynamic_pointer_cast( + Submodules["language_model"] + ->Submodules["model"] + ->Submodules["embed_tokens"]) + ->forward(InputIds); + return {Embeds, mx::array({})}; + } + mx::array InputsEmbeds = + std::dynamic_pointer_cast(Submodules["language_model"] + ->Submodules["model"] + ->Submodules["embed_tokens"]) + ->forward(InputIds); + mx::array HiddenState = mx::array({}), Temp1 = mx::array({}), + Temp2 = mx::array({}); + std::tie(HiddenState, Temp1, Temp2) = + std::dynamic_pointer_cast(Submodules["vision_tower"]) + ->forward(astype(transpose(PixelValues, {0, 2, 3, 1}), + InputsEmbeds.dtype()), + true); + auto NewShape = HiddenState.shape(); + NewShape.insert(NewShape.begin(), 1); + mx::array ImageFeatures = + astype(reshape(HiddenState, NewShape), PixelValues.dtype()); + ImageFeatures = std::dynamic_pointer_cast( + Submodules["multi_modal_projector"]) + ->forward(ImageFeatures); + return _prepareInputsForMultimodal(ImageFeatures, InputsEmbeds, InputIds, + Mask); +} + +std::pair Model::_prepareInputsForMultimodal( + const mx::array &ImageFeatures, const mx::array &InputsEmbeds, + const mx::array &InputIds, const mx::array &AttentionMask) { + int EmbedDim = ImageFeatures.shape().back(); + int BatchSize = InputIds.shape()[0]; + int SequenceLength = InputIds.shape()[1]; + mx::array ScaledImageFeatures = + ImageFeatures / std::pow(Config.HiddenSize, 0.5); + mx::array FinalEmbedding = mx::zeros({BatchSize, SequenceLength, EmbedDim}); + int PadTokenId = Config.PadTokenId; + mx::array TextMask = + (InputIds != Config.ImageTokenIndex) & (InputIds != PadTokenId); + mx::array ImageMask = (InputIds == Config.ImageTokenIndex); + mx::array PadMask = (InputIds == PadTokenId); + mx::array TextMaskExpanded = expand_dims(TextMask, -1); + TextMaskExpanded = repeat(TextMaskExpanded, EmbedDim, -1); + mx::array PadMaskExpanded = expand_dims(PadMask, -1); + PadMaskExpanded = repeat(PadMaskExpanded, EmbedDim, -1); + FinalEmbedding = mx::where(TextMaskExpanded, InputsEmbeds, FinalEmbedding); + FinalEmbedding = mx::where(PadMaskExpanded, mx::zeros_like(FinalEmbedding), + FinalEmbedding); + int PadSize = FinalEmbedding.shape()[1] - ScaledImageFeatures.shape()[1]; + ScaledImageFeatures = + mx::pad(ScaledImageFeatures, {{0, 0}, {0, PadSize}, {0, 0}}); + mx::array ImageMaskExpanded = expand_dims(ImageMask, -1); + ImageMaskExpanded = repeat(ImageMaskExpanded, EmbedDim, -1); + FinalEmbedding = + mx::where(ImageMaskExpanded, ScaledImageFeatures, FinalEmbedding); + FinalEmbedding = mx::where(PadMaskExpanded, mx::zeros_like(FinalEmbedding), + FinalEmbedding); + mx::array AttentionMaskExpanded1 = expand_dims(AttentionMask, 1); + mx::array AttentionMaskExpanded2 = expand_dims(AttentionMask, 2); + mx::array FinalAttentionMask4d = + AttentionMaskExpanded1 * AttentionMaskExpanded2; + FinalAttentionMask4d = expand_dims(FinalAttentionMask4d, 1); + return {FinalEmbedding, FinalAttentionMask4d}; +} + +std::tuple> Model::forward( + const mx::array &InputIds, const mx::array &PixelValues, + const mx::array &Mask, + const std::optional>> &Cache) { + auto Pair = getInputEmbeddings(InputIds, PixelValues, Mask); + mx::array InputEmbeds = Pair.first; + + // TODO: Waiting for upstream fix Mask + auto Logits = std::dynamic_pointer_cast( + Submodules["language_model"]) + ->forward(InputIds, InputEmbeds, std::nullopt, Cache); + return Logits; +} + +std::shared_ptr Model::fromPretrained(const std::string &ModelPath) { + std::filesystem::path Path(ModelPath); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto Error = Parser.load((Path / "config.json").string()).get(Doc); + if (Error) { + spdlog::error("Could not open config.json"); + assumingUnreachable(); + } + auto Obj = Doc.get_object(); + ModelConfig ModelConfigObj = ModelConfig::fromDict(Obj.value()); + ModelConfigObj.VisionConfig = + VisionConfig::fromDict(Obj["vision_config"].get_object().value()); + ModelConfigObj.TextConfig = + TextConfig::fromDict(Obj["text_config"].get_object().value()); + auto Model = std::make_shared(gemma3::Model(ModelConfigObj)); + auto QuantResult = Obj["quantization"].get_object(); + if (!QuantResult.error()) { + auto GroupSize = + static_cast(QuantResult.value()["group_size"].get_int64()); + auto Bits = static_cast(QuantResult.value()["bits"].get_int64()); + spdlog::info("Quantizing model to {} bits, {} group size.", Bits, + GroupSize); + Model = std::dynamic_pointer_cast( + Model->toQuantized(GroupSize, Bits)); + } + std::vector WeightFiles; + for (auto &P : std::filesystem::directory_iterator(Path)) { + if (P.path().extension() == ".safetensors") + WeightFiles.push_back(P.path()); + } + if (WeightFiles.empty()) { + spdlog::error("[WASI-NN] MLX backend: No safetensors found in {}."sv, + Path.string()); + assumingUnreachable(); + } + std::unordered_map Weights; + for (auto &Wf : WeightFiles) { + auto W = mx::load_safetensors(Wf.string()); + Weights.insert(W.first.begin(), W.first.end()); + } + Weights = Model->sanitize(Weights); + Weights = gemma3::VisionModel(ModelConfigObj.VisionConfig).sanitize(Weights); + Model->update(Weights); + return Model; +} + +std::unordered_map +Model::sanitize(const std::unordered_map &Weights) { + std::unordered_map Sanitized; + for (auto &Pair : Weights) { + std::string Key = Pair.first; + if (Key.find("vision_tower") == std::string::npos) { + size_t Pos = Key.find("vision_model"); + if (Pos != std::string::npos) + Key.replace(Pos, std::string("vision_model").length(), "vision_tower"); + } + Sanitized.insert({Key, Pair.second}); + } + return Sanitized; +} + +} // namespace gemma3 +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/gemma3/gemma3.h b/plugins/wasi_nn/MLX/model/gemma3/gemma3.h new file mode 100644 index 00000000..77ba7346 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/gemma3/gemma3.h @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" +#include "mlx/pooling.h" +#include "model/gemma3/language.h" +#include "model/gemma3/vision.h" +#include +#include +#include +#include +#include +namespace WasmEdge::Host::WASINN::MLX { + +namespace nn = mlx::core::nn; +namespace gemma3 { + +struct ModelConfig { + TextConfig TextConfig; + VisionConfig VisionConfig; + std::string ModelType = "gemma3"; + int VocabSize = 262208; + int IgnoreIndex = -100; + int ImageTokenIndex = 262144; + int HiddenSize = 2048; + int PadTokenId = 0; + static ModelConfig fromDict(const simdjson::dom::object &Obj); +}; + +class Gemma3MultiModalProjector : public nn::Module { +public: + Gemma3MultiModalProjector(const ModelConfig &Config); + mx::array forward(const mx::array &X); + +private: + nn::AvgPool2d AvgPool = nn::AvgPool2d({0}); + int PatchesPerImage; + int TokensPerSide; + int KernelSize; +}; + +class Model : public vlm::Module { +public: + Model(const ModelConfig &Config); + std::pair + getInputEmbeddings(const mx::array &InputIds, const mx::array &PixelValues, + const mx::array &Mask); + std::pair _prepareInputsForMultimodal( + const mx::array &ImageFeatures, const mx::array &InputsEmbeds, + const mx::array &InputIds, const mx::array &AttentionMask); + std::tuple> forward( + const mx::array &InputIds, const mx::array &PixelValues, + const mx::array &Mask, + const std::optional>> &Cache = + std::nullopt) override; + static std::shared_ptr fromPretrained(const std::string &ModelPath); + std::unordered_map + sanitize(const std::unordered_map &Weights); + ModelConfig Config; + std::string ModelType; +}; + +} // namespace gemma3 +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/gemma3/language.cpp b/plugins/wasi_nn/MLX/model/gemma3/language.cpp new file mode 100644 index 00000000..138b783a --- /dev/null +++ b/plugins/wasi_nn/MLX/model/gemma3/language.cpp @@ -0,0 +1,342 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "model/gemma3/language.h" +#include "mlx/activations.h" +#include "mlx/base.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "mlx/normalization.h" +#include "mlx/positional_encoding.h" +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace gemma3 { + +TextConfig TextConfig::fromDict(const simdjson::dom::object &Obj) { + TextConfig Config; + auto SResult = Obj["model_type"].get_string(); + if (!SResult.error()) + Config.ModelType = std::string(SResult.value()); + auto IResult = Obj["hidden_size"].get_int64(); + if (!IResult.error()) + Config.HiddenSize = static_cast(IResult.value()); + IResult = Obj["num_hidden_layers"].get_int64(); + if (!IResult.error()) + Config.NumHiddenLayers = static_cast(IResult.value()); + IResult = Obj["intermediate_size"].get_int64(); + if (!IResult.error()) + Config.IntermediateSize = static_cast(IResult.value()); + IResult = Obj["num_attention_heads"].get_int64(); + if (!IResult.error()) + Config.NumAttentionHeads = static_cast(IResult.value()); + IResult = Obj["head_dim"].get_int64(); + if (!IResult.error()) + Config.HeadDim = static_cast(IResult.value()); + auto DResult = Obj["rms_norm_eps"].get_double(); + if (!DResult.error()) + Config.RmsNormEps = static_cast(DResult.value()); + IResult = Obj["vocab_size"].get_int64(); + if (!IResult.error()) + Config.VocabSize = static_cast(IResult.value()); + IResult = Obj["num_key_value_heads"].get_int64(); + if (!IResult.error()) + Config.NumKeyValueHeads = static_cast(IResult.value()); + DResult = Obj["rope_global_base_freq"].get_double(); + if (!DResult.error()) + Config.RopeGlobalBaseFreq = static_cast(DResult.value()); + DResult = Obj["rope_local_base_freq"].get_double(); + if (!DResult.error()) + Config.RopeLocalBaseFreq = static_cast(DResult.value()); + auto BResult = Obj["rope_traditional"].get_bool(); + if (!BResult.error()) + Config.RopeTraditional = BResult.value(); + DResult = Obj["query_pre_attn_scalar"].get_double(); + if (!DResult.error()) + Config.QueryPreAttnScalar = static_cast(DResult.value()); + IResult = Obj["sliding_window"].get_int64(); + if (!IResult.error()) + Config.SlidingWindow = static_cast(IResult.value()); + IResult = Obj["mm_tokens_per_image"].get_int64(); + if (!IResult.error()) + Config.MmTokensPerImage = static_cast(IResult.value()); + IResult = Obj["sliding_window_pattern"].get_int64(); + if (!IResult.error()) + Config.SlidingWindowPattern = static_cast(IResult.value()); + return Config; +} + +RMSNorm::RMSNorm(int Dims, float Eps) : Eps(Eps) { + registerParameter("weight", mx::ones({Dims})); +} + +mx::array RMSNorm::forward(const mx::array &X) { + return mx::fast::rms_norm(X, + mx::array({1.0}, Parameters.at("weight").dtype()) + + Parameters.at("weight"), + Eps); +} + +Attention::Attention(const TextConfig &Config, int LayerIdx) + : NHeads(Config.NumAttentionHeads), NKVHeads(Config.NumKeyValueHeads), + Repeats(Config.NumAttentionHeads / Config.NumKeyValueHeads), + HeadDim(Config.HeadDim), LayerIdx(LayerIdx), + Scale(std::pow(Config.QueryPreAttnScalar, -0.5)), + QNorm(HeadDim, Config.RmsNormEps), KNorm(HeadDim, Config.RmsNormEps) { + registerModule("q_proj", std::make_shared( + Config.HiddenSize, NHeads * HeadDim, false)); + registerModule("k_proj", std::make_shared( + Config.HiddenSize, NKVHeads * HeadDim, false)); + registerModule("v_proj", std::make_shared( + Config.HiddenSize, NKVHeads * HeadDim, false)); + registerModule("o_proj", std::make_shared( + NHeads * HeadDim, Config.HiddenSize, false)); + registerModule("q_norm", + std::make_shared(HeadDim, Config.RmsNormEps)); + registerModule("k_norm", + std::make_shared(HeadDim, Config.RmsNormEps)); + IsSliding = ((LayerIdx + 1) % Config.SlidingWindowPattern) != 0; + registerModule("rope", std::make_shared( + HeadDim, Config.RopeTraditional, + (IsSliding ? Config.RopeLocalBaseFreq + : Config.RopeGlobalBaseFreq))); +} + +mx::array Attention::forward( + const mx::array &X, const std::optional &Mask, + const std::optional> &Cache) { + auto Shape = X.shape(); + int B = Shape[0], L = Shape[1]; + mx::array Queries = + std::dynamic_pointer_cast(Submodules["q_proj"])->forward(X); + mx::array Keys = + std::dynamic_pointer_cast(Submodules["k_proj"])->forward(X); + mx::array Values = + std::dynamic_pointer_cast(Submodules["v_proj"])->forward(X); + Queries = transpose(reshape(Queries, {B, L, NHeads, -1}), {0, 2, 1, 3}); + Keys = transpose(reshape(Keys, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); + Values = transpose(reshape(Values, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); + Queries = std::dynamic_pointer_cast(Submodules["q_norm"]) + ->forward(Queries); + Keys = std::dynamic_pointer_cast(Submodules["k_norm"]) + ->forward(Keys); + if (Cache.has_value()) { + Queries = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Queries, Cache.value()->Offset); + Keys = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Keys, Cache.value()->Offset); + std::tie(Keys, Values) = Cache.value()->updateAndFetch(Keys, Values); + } else { + Queries = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Queries); + Keys = + std::dynamic_pointer_cast(Submodules["rope"])->forward(Keys); + } + if (Mask.has_value() && + Mask.value().shape().back() != Keys.shape().at(Keys.shape().size() - 2)) { + mx::array M = + take(Mask.value(), -Keys.shape().at(Keys.shape().size() - 2), -1); + return std::dynamic_pointer_cast(Submodules["o_proj"]) + ->forward(transpose(reshape(mx::fast::scaled_dot_product_attention( + Queries, Keys, Values, Scale, M), + {B, L, -1}), + {0, 2, 1})); + } + auto Output = (Mask.has_value() + ? mx::fast::scaled_dot_product_attention( + Queries, Keys, Values, Scale, Mask.value()) + : mx::fast::scaled_dot_product_attention(Queries, Keys, + Values, Scale)); + Output = reshape(transpose(Output, {0, 2, 1, 3}), {B, L, -1}); + return std::dynamic_pointer_cast(Submodules["o_proj"]) + ->forward(Output); +} + +MLP::MLP(int Dim, int HiddenDim) { + registerModule("gate_proj", + std::make_shared(Dim, HiddenDim, false)); + registerModule("down_proj", + std::make_shared(HiddenDim, Dim, false)); + registerModule("up_proj", + std::make_shared(Dim, HiddenDim, false)); +} + +mx::array MLP::forward(const mx::array &X) { + mx::array A = std::dynamic_pointer_cast(Submodules["gate_proj"]) + ->forward(X); + A = mlx::core::geluApprox(A); + mx::array B = + std::dynamic_pointer_cast(Submodules["up_proj"])->forward(X); + A = A * B; + return std::dynamic_pointer_cast(Submodules["down_proj"]) + ->forward(A); +} + +TransformerBlock::TransformerBlock(const TextConfig &Config, int LayerIdx) + : NumAttentionHeads(Config.NumAttentionHeads), + HiddenSize(Config.HiddenSize) { + registerModule("self_attn", std::make_shared(Config, LayerIdx)); + registerModule( + "mlp", std::make_shared(Config.HiddenSize, Config.IntermediateSize)); + registerModule("input_layernorm", std::make_shared( + Config.HiddenSize, Config.RmsNormEps)); + registerModule( + "post_attention_layernorm", + std::make_shared(Config.HiddenSize, Config.RmsNormEps)); + registerModule( + "pre_feedforward_layernorm", + std::make_shared(Config.HiddenSize, Config.RmsNormEps)); + registerModule( + "post_feedforward_layernorm", + std::make_shared(Config.HiddenSize, Config.RmsNormEps)); +} + +mx::array TransformerBlock::forward( + const mx::array &Input, const std::optional &Mask, + const std::optional> &Cache) { + auto X = Input; + // Clip the input to avoid overflow in float16, but it make more memory usage + // if (X.dtype() == mx::bfloat16) { + // X = mx::clip(X, mx::array{-65504}, mx::array{65504}); + // } + mx::array R = std::dynamic_pointer_cast(Submodules["self_attn"]) + ->forward(std::dynamic_pointer_cast( + Submodules["input_layernorm"]) + ->forward(X), + Mask, Cache); + mx::array H = std::dynamic_pointer_cast( + Submodules["post_attention_layernorm"]) + ->forward(R); + // if (H.dtype() == mx::bfloat16) { + // H = mx::clip(astype(X, mx::float32) + astype(H, mx::float32), + // mx::array{-65504}, mx::array{65504}); + // } else { + H = X + H; + // } + R = std::dynamic_pointer_cast(Submodules["mlp"]) + ->forward(std::dynamic_pointer_cast( + Submodules["pre_feedforward_layernorm"]) + ->forward(H)); + auto Out = std::dynamic_pointer_cast( + Submodules["post_feedforward_layernorm"]) + ->forward(R); + // if (Out.dtype() == mx::bfloat16) { + // Out = clip(astype(H, mx::float32) + astype(Out, mx::float32), + // mx::array{-65504}, mx::array{65504}); + // } else { + Out = H + Out; + // } + return Out; +} + +Gemma3Model::Gemma3Model(const TextConfig &Config) : Config(Config) { + if (Config.VocabSize <= 0) { + assumingUnreachable(); + } + registerModule("embed_tokens", std::make_shared( + Config.VocabSize, Config.HiddenSize)); + for (int I = 0; I < Config.NumHiddenLayers; I++) { + Layers.push_back(std::make_shared(Config, I)); + } + registerLayer("layers", Layers); + registerModule("norm", std::make_shared(Config.HiddenSize, + Config.RmsNormEps)); +} + +mx::array Gemma3Model::forward( + const mx::array &Inputs, const std::optional &InputsEmbeds, + const std::optional &Mask, + const std::optional>> &Cache) { + mx::array H = + InputsEmbeds.has_value() + ? InputsEmbeds.value() + : std::dynamic_pointer_cast(Submodules["embed_tokens"]) + ->forward(Inputs); + H = H * astype(mx::array(std::pow(Config.HiddenSize, 0.5), mx::bfloat16), + H.dtype()); + std::vector> CacheValue = + Cache.has_value() ? Cache.value() + : std::vector>( + Config.NumHiddenLayers, nullptr); + std::optional FullMask = std::nullopt; + std::optional SlidingWindowMask = std::nullopt; + if (!Mask.has_value()) { + int J = Config.SlidingWindowPattern; + FullMask = vlm::createAttentionMask( + H, std::vector>( + CacheValue.begin() + J - 1, CacheValue.begin() + J)); + SlidingWindowMask = vlm::createAttentionMask(H, CacheValue); + } + for (size_t I = 0; I < Layers.size(); I++) { + bool IsGlobal = + (I % Config.SlidingWindowPattern == Config.SlidingWindowPattern - 1); + std::optional MaskLocal = std::nullopt; + if (!Mask.has_value() && IsGlobal) { + MaskLocal = FullMask; + } else if (!Mask.has_value()) { + MaskLocal = SlidingWindowMask; + } else { + MaskLocal = Mask.value(); + } + H = dynamic_cast(Layers[I].get()) + ->forward(H, MaskLocal, CacheValue[I]); + } + return std::dynamic_pointer_cast(Submodules["norm"]) + ->forward(H); +} + +LanguageModel::LanguageModel(const TextConfig &Config) : Config(Config) { + registerModule("model", std::make_shared(Config)); + registerModule("lm_head", std::make_shared( + Config.HiddenSize, Config.VocabSize, false)); +} + +std::tuple> LanguageModel::forward( + const mx::array &Inputs, const std::optional &InputsEmbeds, + const std::optional &Mask, + const std::optional>> &Cache) { + mx::array Out = std::dynamic_pointer_cast(Submodules["model"]) + ->forward(Inputs, InputsEmbeds, Mask, Cache); + Out = std::dynamic_pointer_cast(Submodules["lm_head"]) + ->forward(Out); + return std::tuple>{Out, {}}; +} + +std::unordered_map LanguageModel::sanitize( + const std::unordered_map &Weights) { + std::unordered_map Sanitized; + if (Weights.find("lm_head.weight") == Weights.end()) + Sanitized.insert({"language_model.lm_head.weight", + Weights.at("language_model.model.embed_tokens.weight")}); + for (auto &Pair : Weights) { + if (Pair.first.find("self_attn.rotary_emb.inv_freq") == std::string::npos) + Sanitized.insert({Pair.first, Pair.second}); + } + return Sanitized; +} + +int LanguageModel::headDim() const { return Config.HeadDim; } + +int LanguageModel::nKvHeads() const { return Config.NumKeyValueHeads; } +int LanguageModel::layers() const { return Config.NumHiddenLayers; } + +std::vector> LanguageModel::makeCache() { + std::vector> Caches; + for (int I = 0; I < Config.NumHiddenLayers; I++) { + if (I % Config.SlidingWindowPattern == Config.SlidingWindowPattern - 1) { + Caches.emplace_back(std::make_shared( + Config.HeadDim, Config.NumKeyValueHeads)); + } else { + Caches.emplace_back( + std::make_shared(Config.SlidingWindow, 0)); + } + } + return Caches; +} + +} // namespace gemma3 +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/gemma3/language.h b/plugins/wasi_nn/MLX/model/gemma3/language.h new file mode 100644 index 00000000..483a17d4 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/gemma3/language.h @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "../vlm_base.h" +#include "simdjson.h" +#include +#include +#include +#include +#include +namespace WasmEdge::Host::WASINN::MLX { +namespace nn = mlx::core::nn; +namespace gemma3 { + +struct TextConfig { + std::string ModelType = "gemma3_text"; + int HiddenSize = 2560; + int NumHiddenLayers = 34; + int IntermediateSize = 10240; + int NumAttentionHeads = 8; + int HeadDim = 256; + float RmsNormEps = 1.0e-6; + int VocabSize = 262208; + int NumKeyValueHeads = 4; + float RopeGlobalBaseFreq = 1000000.0f; + float RopeLocalBaseFreq = 10000.0f; + bool RopeTraditional = false; + float QueryPreAttnScalar = 256; + int SlidingWindow = 1024; + std::optional< + std::unordered_map>>> + RopeScaling; + int MmTokensPerImage = 256; + int SlidingWindowPattern = 6; + static TextConfig fromDict(const simdjson::dom::object &Obj); +}; + +class RMSNorm : public nn::Module { +public: + RMSNorm(int Dims, float Eps = 1e-5); + mx::array forward(const mx::array &X); + +private: + float Eps; +}; + +class Attention : public nn::Module { +public: + Attention(const TextConfig &Config, int LayerIdx); + mx::array forward(const mx::array &X, + const std::optional &Mask = std::nullopt, + const std::optional> + &Cache = std::nullopt); + +private: + int NHeads; + int NKVHeads; + int Repeats; + int HeadDim; + int LayerIdx; + float Scale; + RMSNorm QNorm; + RMSNorm KNorm; + bool IsSliding; +}; + +class MLP : public nn::Module { +public: + MLP(int Dim, int HiddenDim); + mx::array forward(const mx::array &X); +}; + +class TransformerBlock : public nn::Module { +public: + TransformerBlock(const TextConfig &Config, int LayerIdx); + mx::array forward(const mx::array &X, + const std::optional &Mask = std::nullopt, + const std::optional> + &Cache = std::nullopt); + +private: + int NumAttentionHeads; + int HiddenSize; +}; + +class Gemma3Model : public nn::Module { +public: + Gemma3Model(const TextConfig &Config); + mx::array forward( + const mx::array &Inputs, + const std::optional &InputsEmbeds = std::nullopt, + const std::optional &Mask = std::nullopt, + const std::optional>> &Cache = + std::nullopt); + std::vector> Layers; + TextConfig Config; +}; + +class LanguageModel : public vlm::LanguageModel { +public: + LanguageModel(const TextConfig &Config); + std::tuple> forward( + const mx::array &Inputs, + const std::optional>> &Cache = + std::nullopt) override { + return forward(Inputs, std::nullopt, std::nullopt, Cache); + } + std::tuple> forward( + const mx::array &Inputs, + const std::optional &InputsEmbeds = std::nullopt, + const std::optional &Mask = std::nullopt, + const std::optional>> &Cache = + std::nullopt); + std::unordered_map + sanitize(const std::unordered_map &Weights); + int headDim() const override; + int nKvHeads() const override; + int layers() const override; + std::vector> makeCache() override; + TextConfig Config; +}; + +} // namespace gemma3 +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/gemma3/vision.cpp b/plugins/wasi_nn/MLX/model/gemma3/vision.cpp new file mode 100644 index 00000000..ff6e64ce --- /dev/null +++ b/plugins/wasi_nn/MLX/model/gemma3/vision.cpp @@ -0,0 +1,277 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "model/gemma3/vision.h" +#include "mlx/activations.h" +#include "mlx/convolution.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "mlx/normalization.h" +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace gemma3 { + +VisionConfig VisionConfig::fromDict(const simdjson::dom::object &Obj) { + VisionConfig Config; + auto ModelTypeResult = Obj["model_type"].get_string(); + if (!ModelTypeResult.error()) { + Config.ModelType = std::string(ModelTypeResult.value()); + } + auto NumHiddenLayersResult = Obj["num_hidden_layers"].get_int64(); + if (!NumHiddenLayersResult.error()) { + Config.NumHiddenLayers = static_cast(NumHiddenLayersResult.value()); + } + auto HiddenSizeResult = Obj["hidden_size"].get_int64(); + if (!HiddenSizeResult.error()) { + Config.HiddenSize = static_cast(HiddenSizeResult.value()); + } + auto IntermediateSizeResult = Obj["intermediate_size"].get_int64(); + if (!IntermediateSizeResult.error()) { + Config.IntermediateSize = static_cast(IntermediateSizeResult.value()); + } + auto NumAttentionHeadsResult = Obj["num_attention_heads"].get_int64(); + if (!NumAttentionHeadsResult.error()) { + Config.NumAttentionHeads = + static_cast(NumAttentionHeadsResult.value()); + } + auto PatchSizeResult = Obj["patch_size"].get_int64(); + if (!PatchSizeResult.error()) { + Config.PatchSize = static_cast(PatchSizeResult.value()); + } + auto ImageSizeResult = Obj["image_size"].get_int64(); + if (!ImageSizeResult.error()) { + Config.ImageSize = static_cast(ImageSizeResult.value()); + } + auto NumChannelsResult = Obj["num_channels"].get_int64(); + if (!NumChannelsResult.error()) { + Config.NumChannels = static_cast(NumChannelsResult.value()); + } + auto LayerNormEpsResult = Obj["layer_norm_eps"].get_double(); + if (!LayerNormEpsResult.error()) { + Config.LayerNormEps = static_cast(LayerNormEpsResult.value()); + } + return Config; +} + +bool checkArrayShape(const mx::array &Arr) { + auto Shape = Arr.shape(); + if (Shape.size() != 4) + return false; + int OutChannels = Shape[0]; + int KH = Shape[1]; + int KW = Shape[2]; + return (OutChannels >= KH) && (OutChannels >= KW) && (KH == KW); +} + +VisionAttention::VisionAttention(int Dims, int NumHeads, + std::optional QueryInputDims, + std::optional KeyInputDims, + std::optional ValueInputDims, + std::optional ValueDims, + std::optional ValueOutputDims, bool Bias) + : NumHeads(NumHeads) { + if (Dims % NumHeads != 0) { + spdlog::error( + "[WASI-NN] MLX backend: Dims must be divisible by NumHeads"sv); + assumingUnreachable(); + } + int QInput = QueryInputDims.value_or(Dims); + int KInput = KeyInputDims.value_or(Dims); + int VInput = ValueInputDims.value_or(KInput); + int VDim = ValueDims.value_or(Dims); + int VOutDim = ValueOutputDims.value_or(Dims); + int HeadDim = Dims / NumHeads; + Scale = std::pow(HeadDim, -0.5); + registerModule("q_proj", std::make_shared(QInput, Dims, Bias)); + registerModule("k_proj", std::make_shared(KInput, Dims, Bias)); + registerModule("v_proj", std::make_shared(VInput, VDim, Bias)); + registerModule("out_proj", std::make_shared(VDim, VOutDim, Bias)); +} + +mx::array VisionAttention::forward(const mx::array &X, + const std::optional &Mask) { + mx::array Queries = + std::dynamic_pointer_cast(Submodules["q_proj"])->forward(X); + mx::array Keys = + std::dynamic_pointer_cast(Submodules["k_proj"])->forward(X); + mx::array Values = + std::dynamic_pointer_cast(Submodules["v_proj"])->forward(X); + int B = Queries.shape()[0]; + int L = Queries.shape()[1]; + Queries = transpose(reshape(Queries, {B, L, NumHeads, -1}), {0, 2, 1, 3}); + int S = Keys.shape()[1]; + Keys = transpose(reshape(Keys, {B, S, NumHeads, -1}), {0, 2, 1, 3}); + Values = transpose(reshape(Values, {B, S, NumHeads, -1}), {0, 2, 1, 3}); + mx::array Output = + Mask.has_value() ? mx::fast::scaled_dot_product_attention( + Queries, Keys, Values, Scale, Mask.value()) + : mx::fast::scaled_dot_product_attention(Queries, Keys, + Values, Scale); + Output = reshape(transpose(Output, {0, 2, 1, 3}), {B, L, -1}); + return std::dynamic_pointer_cast(Submodules["out_proj"]) + ->forward(Output); +} + +VisionMLP::VisionMLP(const VisionConfig &Config) { + registerModule("fc1", std::make_shared( + Config.HiddenSize, Config.IntermediateSize, true)); + registerModule("fc2", std::make_shared(Config.IntermediateSize, + Config.HiddenSize, true)); +} + +mx::array VisionMLP::forward(const mx::array &X) { + mx::array Out = + std::dynamic_pointer_cast(Submodules["fc1"])->forward(X); + Out = mlx::core::geluApprox(Out); + Out = std::dynamic_pointer_cast(Submodules["fc2"])->forward(Out); + return Out; +} + +EncoderLayer::EncoderLayer(const VisionConfig &Config) { + EmbedDim = Config.HiddenSize; + registerModule("self_attn", std::make_shared( + Config.HiddenSize, Config.NumAttentionHeads, + std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, true)); + registerModule("layer_norm1", std::make_shared( + EmbedDim, Config.LayerNormEps)); + registerModule("mlp", std::make_shared(Config)); + registerModule("layer_norm2", std::make_shared( + EmbedDim, Config.LayerNormEps)); +} + +mx::array EncoderLayer::forward(const mx::array &X, + const std::optional &Mask) { + mx::array R = + std::dynamic_pointer_cast(Submodules["self_attn"]) + ->forward(std::dynamic_pointer_cast( + Submodules["layer_norm1"]) + ->forward(X), + Mask); + mx::array H = X + R; + R = std::dynamic_pointer_cast(Submodules["mlp"]) + ->forward(std::dynamic_pointer_cast( + Submodules["layer_norm2"]) + ->forward(H)); + return H + R; +} + +Encoder::Encoder(const VisionConfig &Config) { + for (int I = 0; I < Config.NumHiddenLayers; I++) { + Layers.push_back(std::make_shared(Config)); + } + registerLayer("layers", Layers); +} + +std::pair> +Encoder::forward(const mx::array &X, + const std::optional &OutputHiddenStates, + const std::optional &Mask) { + std::vector EncoderStates; + if (OutputHiddenStates.has_value() && OutputHiddenStates.value()) + EncoderStates.push_back(X); + mx::array Out = X; + mx::array H = X; + for (auto &L : Layers) { + Out = L->forward(Out, Mask); + if (OutputHiddenStates.has_value() && OutputHiddenStates.value()) + EncoderStates.push_back(Out); + H = take(Out, 0, 0); + } + return {H, EncoderStates}; +} + +VisionEmbeddings::VisionEmbeddings(const VisionConfig &Config) + : Config(Config) { + EmbedDim = Config.HiddenSize; + ImageSize = Config.ImageSize; + PatchSize = Config.PatchSize; + registerModule("patch_embedding", std::make_shared(nn::Conv2d( + Config.NumChannels, EmbedDim, PatchSize, + {PatchSize, PatchSize}))); + NumPatches = (ImageSize / PatchSize) * (ImageSize / PatchSize); + NumPositions = NumPatches; + registerModule("position_embedding", + std::make_shared(NumPositions, EmbedDim)); +} + +mx::array VisionEmbeddings::forward(const mx::array &X) { + mx::array PatchEmbeddings = + std::dynamic_pointer_cast(Submodules["patch_embedding"]) + ->forward(X); + PatchEmbeddings = mx::flatten(PatchEmbeddings, 1, 2); + std::vector PositionIdsShapeVec; + for (int I = 0; I < NumPositions; I++) { + PositionIdsShapeVec.emplace_back(I); + } + mx::array PositionIds = + mx::array(PositionIdsShapeVec.data(), {1, NumPositions}); + mx::array Embeddings = PatchEmbeddings; + Embeddings = Embeddings + std::dynamic_pointer_cast( + Submodules["position_embedding"]) + ->forward(PositionIds); + return Embeddings; +} + +SigLipVisionModel::SigLipVisionModel(const VisionConfig &Config) { + registerModule("embeddings", std::make_shared(Config)); + registerModule("encoder", std::make_shared(Config)); + registerModule("post_layernorm", + std::make_shared(Config.HiddenSize)); +} + +std::tuple +SigLipVisionModel::forward(const mx::array &X, + const std::optional &OutputHiddenStates) { + mx::array Emb = + std::dynamic_pointer_cast(Submodules["embeddings"]) + ->forward(X); + auto EncoderOutputs = + std::dynamic_pointer_cast(Submodules["encoder"]) + ->forward(Emb, OutputHiddenStates, std::nullopt); + mx::array PoolerOutput = + std::dynamic_pointer_cast(Submodules["post_layernorm"]) + ->forward(std::get<0>(EncoderOutputs)); + return {PoolerOutput, Emb, std::get<1>(EncoderOutputs).back()}; +} + +VisionModel::VisionModel(const VisionConfig &Config) { + ModelType = Config.ModelType; + if (ModelType != "siglip_vision_model" && ModelType != "gemma3" && + ModelType != "gemma3_vision") { + spdlog::error("[WASI-NN] MLX backend: Unsupported model type: {}"sv, + ModelType); + assumingUnreachable(); + } + registerModule("vision_model", std::make_shared(Config)); +} + +std::tuple +VisionModel::forward(const mx::array &X, + const std::optional &OutputHiddenStates) { + return std::dynamic_pointer_cast( + Submodules["vision_model"]) + ->forward(X, OutputHiddenStates); +} + +std::unordered_map VisionModel::sanitize( + const std::unordered_map &Weights) { + std::unordered_map SanitizedWeights; + for (auto &Pair : Weights) { + if (Pair.first.find("patch_embedding.weight") != std::string::npos) { + if (checkArrayShape(Pair.second)) + SanitizedWeights.insert({Pair.first, Pair.second}); + else + SanitizedWeights.insert( + {Pair.first, transpose(Pair.second, {0, 2, 3, 1})}); + } else { + SanitizedWeights.insert({Pair.first, Pair.second}); + } + } + return SanitizedWeights; +} + +} // namespace gemma3 +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/gemma3/vision.h b/plugins/wasi_nn/MLX/model/gemma3/vision.h new file mode 100644 index 00000000..91a421f6 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/gemma3/vision.h @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "mlx/base.h" +#include "simdjson.h" +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace nn = mlx::core::nn; +namespace gemma3 { + +struct VisionConfig { + std::string ModelType = "siglip_vision_model"; + int NumHiddenLayers = 27; + int HiddenSize = 1152; + int IntermediateSize = 4304; + int NumAttentionHeads = 16; + int PatchSize = 14; + int ImageSize = 896; + int NumChannels = 3; + float LayerNormEps = 1e-6f; + static VisionConfig fromDict(const simdjson::dom::object &Obj); +}; + +bool checkArrayShape(const mx::array &Arr); + +class VisionAttention : public nn::Module { +public: + VisionAttention(int Dims, int NumHeads, + std::optional QueryInputDims = std::nullopt, + std::optional KeyInputDims = std::nullopt, + std::optional ValueInputDims = std::nullopt, + std::optional ValueDims = std::nullopt, + std::optional ValueOutputDims = std::nullopt, + bool Bias = true); + mx::array forward(const mx::array &X, + const std::optional &Mask = std::nullopt); + +private: + int NumHeads; + float Scale; +}; + +class VisionMLP : public nn::Module { +public: + VisionMLP(const VisionConfig &Config); + mx::array forward(const mx::array &X); +}; + +class EncoderLayer : public nn::Module { +public: + EncoderLayer(const VisionConfig &Config); + mx::array forward(const mx::array &X, + const std::optional &Mask = std::nullopt); + +private: + int EmbedDim; +}; + +class Encoder : public nn::Module { +public: + Encoder(const VisionConfig &Config); + std::pair> + forward(const mx::array &X, + const std::optional &OutputHiddenStates = std::nullopt, + const std::optional &Mask = std::nullopt); + +private: + std::vector> Layers; +}; + +class VisionEmbeddings : public nn::Module { +public: + VisionEmbeddings(const VisionConfig &Config); + mx::array forward(const mx::array &X); + +private: + VisionConfig Config; + int EmbedDim; + int ImageSize; + int PatchSize; + int NumPatches; + int NumPositions; +}; + +class SigLipVisionModel : public nn::Module { +public: + SigLipVisionModel(const VisionConfig &Config); + std::tuple + forward(const mx::array &X, + const std::optional &OutputHiddenStates = std::nullopt); +}; + +class VisionModel : public nn::Module { +public: + VisionModel(const VisionConfig &Config); + std::tuple + forward(const mx::array &X, + const std::optional &OutputHiddenStates = std::nullopt); + std::unordered_map + sanitize(const std::unordered_map &Weights); + +private: + std::string ModelType; +}; + +} // namespace gemma3 +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/registry.cpp b/plugins/wasi_nn/MLX/model/llm/registry.cpp similarity index 94% rename from plugins/wasi_nn/MLX/model/registry.cpp rename to plugins/wasi_nn/MLX/model/llm/registry.cpp index 0c6f2757..2761b5f8 100644 --- a/plugins/wasi_nn/MLX/model/registry.cpp +++ b/plugins/wasi_nn/MLX/model/llm/registry.cpp @@ -1,11 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC -#include "model/registry.h" -#include "model/transformer.h" - +#include "registry.h" +#include "transformer.h" namespace WasmEdge::Host::WASINN::MLX { - +namespace llm { std::shared_ptr llama38b(int VocabSize, float NormEps, float RopeTheta, bool RopeTraditional) { return std::make_shared(Transformer( @@ -29,4 +28,5 @@ std::shared_ptr tinyLlama11BChatV10(int VocabSize, float NormEps, std::vector{4}, NormEps, {}, RopeTraditional, RopeTheta)); } +} // namespace llm } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/registry.h b/plugins/wasi_nn/MLX/model/llm/registry.h similarity index 94% rename from plugins/wasi_nn/MLX/model/registry.h rename to plugins/wasi_nn/MLX/model/llm/registry.h index bd1a908d..359c7369 100644 --- a/plugins/wasi_nn/MLX/model/registry.h +++ b/plugins/wasi_nn/MLX/model/llm/registry.h @@ -3,10 +3,9 @@ #pragma once -#include "model/transformer.h" - +#include "transformer.h" namespace WasmEdge::Host::WASINN::MLX { - +namespace llm { std::shared_ptr llama38b(int VocabSize = 32000, float NormEps = 1e-5, float RopeTheta = 10000.0, @@ -22,4 +21,5 @@ std::shared_ptr tinyLlama11BChatV10(int VocabSize = 32000, float RopeTheta = 10000.0, bool RopeTraditional = false); +} // namespace llm } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/llm/transformer.cpp b/plugins/wasi_nn/MLX/model/llm/transformer.cpp new file mode 100644 index 00000000..e7ac5b94 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/llm/transformer.cpp @@ -0,0 +1,282 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "mlx/transformer.h" +#include "../utils.h" +#include "mlx/base.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "transformer.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace llm { + +mx::array RMSNorm::forward(mx::array Input) { + return mx::fast::rms_norm(Input, 1.0 + Parameters.at("weight"), Eps); +} + +std::tuple> +Attention::forward(mx::array Input, std::optional Mask, + std::optional> KVCache) { + const auto &[B, L, D] = + std::tie(Input.shape()[0], Input.shape()[1], Input.shape()[2]); + mx::array Queries = + std::dynamic_pointer_cast(Submodules["q_proj"]) + ->forward(Input); + mx::array Keys = std::dynamic_pointer_cast(Submodules["k_proj"]) + ->forward(Input); + mx::array Values = std::dynamic_pointer_cast(Submodules["v_proj"]) + ->forward(Input); + Queries = transpose(reshape(Queries, {B, L, NHeads, -1}), {0, 2, 1, 3}); + Keys = transpose(reshape(Keys, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); + Values = transpose(reshape(Values, {B, L, NKVHeads, -1}), {0, 2, 1, 3}); + + if (NormQKProj) { + Queries = std::dynamic_pointer_cast(Submodules["q_norm"]) + ->forward(Queries); + Keys = std::dynamic_pointer_cast(Submodules["k_norm"]) + ->forward(Keys); + } + if (KVCache) { + const auto &[KeyCache, ValueCache] = *KVCache; + Queries = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Queries, KeyCache.shape(2)); + Keys = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Keys, KeyCache.shape(2)); + Keys = mx::concatenate({KeyCache, Keys}, 2); + Values = mx::concatenate({ValueCache, Values}, 2); + } else { + Queries = std::dynamic_pointer_cast(Submodules["rope"]) + ->forward(Queries); + Keys = + std::dynamic_pointer_cast(Submodules["rope"])->forward(Keys); + } + mx::array Output = + Mask.has_value() ? mx::fast::scaled_dot_product_attention( + Queries, Keys, Values, Scale, Mask.value()) + : mx::fast::scaled_dot_product_attention(Queries, Keys, + Values, Scale); + + Output = reshape(transpose(Output, {0, 2, 1, 3}), {B, L, -1}); + return {std::dynamic_pointer_cast(Submodules["o_proj"]) + ->forward(Output), + {Keys, Values}}; +} + +mx::array MLP::forward(mx::array Input) { + if (Gemma) { + return std::dynamic_pointer_cast(Submodules["down_proj"]) + ->forward(mlx::core::gelu(std::dynamic_pointer_cast( + Submodules["gate_proj"]) + ->forward(Input)) * + std::dynamic_pointer_cast(Submodules["up_proj"]) + ->forward(Input)); + } + return std::dynamic_pointer_cast(Submodules["down_proj"]) + ->forward(mlx::core::silu(std::dynamic_pointer_cast( + Submodules["gate_proj"]) + ->forward(Input)) * + std::dynamic_pointer_cast(Submodules["up_proj"]) + ->forward(Input)); +} + +std::tuple> +TransformerBlock::forward( + mx::array Input, std::optional Mask, + std::optional> KVCachePar) { + mx::array NormOutput({}); + if (!Gemma) { + NormOutput = + std::dynamic_pointer_cast(Submodules["attention_norm"]) + ->forward(Input); + } else { + NormOutput = + std::dynamic_pointer_cast(Submodules["attention_norm"]) + ->forward(Input); + } + auto [R, KVCache] = + std::dynamic_pointer_cast(Submodules["attention"]) + ->forward(NormOutput, Mask, KVCachePar); + auto H = Input + R; + if (!Gemma) { + R = std::dynamic_pointer_cast(Submodules["mlp"]) + ->forward( + std::dynamic_pointer_cast(Submodules["mlp_norm"]) + ->forward(H)); + } else { + R = std::dynamic_pointer_cast(Submodules["mlp"]) + ->forward(std::dynamic_pointer_cast(Submodules["mlp_norm"]) + ->forward(H)); + } + return {H + R, KVCache}; +} + +std::tuple>>> +Transformer::embed( + mx::array Input, + std::optional>> KVCachePar, + bool Norm) { + mx::array H = + std::dynamic_pointer_cast(Submodules["token_embed"]) + ->forward(Input); + if (Gemma) { + H = H * (pow(Dim, 0.5)); + } + std::optional Mask; + if (H.shape()[1] > 1) { + Mask = nn::MultiHeadAttention::createAdditiveCausalMask(H.shape()[1]); + Mask = astype(*Mask, H.dtype()); + } + std::vector> KVCache; + KVCache.reserve(Layers.size()); + for (size_t Idx = 0; Idx < Layers.size(); Idx++) { + std::tuple> Result = { + mx::array({}), {mx::array({}), mx::array({})}}; + if (KVCachePar) { + Result = Layers[Idx]->forward(H, Mask, (*KVCachePar)[Idx]); + } else { + Result = Layers[Idx]->forward(H, Mask, {}); + } + H = std::get<0>(Result); + KVCache.emplace_back(std::get<1>(Result)); + } + if (Norm) { + if (!Gemma) { + return {std::dynamic_pointer_cast(Submodules["norm"]) + ->forward(H), + KVCache}; + } + return {std::dynamic_pointer_cast(Submodules["norm"])->forward(H), + KVCache}; + } + return {H, KVCache}; +} + +std::tuple>>> +Transformer::forward( + mx::array Input, + std::optional>> KVCachePar) { + auto [X, KVCache] = embed(Input, KVCachePar, true); + mx::array Out({}); + if (EmbedAsHead) { + Out = std::dynamic_pointer_cast(Submodules["token_embed"]) + ->asLinear(X); + } else { + Out = std::dynamic_pointer_cast(Submodules["head"])->forward(X); + } + return {Out, KVCache}; +} + +std::tuple>>> +Transformer::stepGenerate(mx::array Input, std::optional Temp) { + // Reshape Input to input[:, None] + std::vector ReshapeDim = Input.shape(); + ReshapeDim.insert(ReshapeDim.begin(), 1); + auto [Logits, KVCache] = forward(reshape(Input, ReshapeDim)); + const int H = Logits.shape()[1] - 1; + // take logits[:, -1, :] + Logits = take(Logits, mx::array({H}), 1); + ReshapeDim = Logits.shape(); + ReshapeDim.erase(ReshapeDim.begin() + 1); + Logits = reshape(Logits, ReshapeDim); + mx::array Y({}); + if (Temp == 0) { + Y = mx::argmax(Logits, -1); + } else { + Y = mx::random::categorical(Logits * (1.0 / *Temp)); + } + return {Y, KVCache}; +} + +std::tuple>>> +Transformer::nextStepGenerate( + mx::array Y, std::optional Temp, + std::optional>> KVCachePar) { + // Reshape Y to y[:, None] + std::vector ReshapeDim = Y.shape(); + ReshapeDim.insert(ReshapeDim.begin() + 1, 1); + auto [Logits, KVCache] = forward(reshape(Y, ReshapeDim), KVCachePar); + Logits = squeeze(Logits, 1); + mx::array NextY({}); + if (Temp == 0) { + NextY = mx::argmax(Logits, -1); + } else { + NextY = mx::random::categorical(Logits * (1.0 / *Temp)); + } + return {NextY, KVCache}; +} + +enum AnserSataus { + STOP, + WAIT, + GO, +}; + +AnserSataus answerSataus(std::string Text, std::string End) { + if (endsWith(Text, End)) { + return STOP; + } + for (int Idx = 1; Idx < static_cast(End.size()); Idx++) { + if (endsWith(Text, End.substr(0, Idx))) { + return WAIT; + } + } + return GO; +} + +Transformer::LLMOutput +Transformer::generate(const std::string &Prompt, const BasePrompt &ModelPrompt, + const int MaxToken, const bool Verbose, + const std::unique_ptr &Tok) { + const std::vector Ids = Tok->Encode(Prompt); + mx::array Token = + mx::array(Ids.data(), {static_cast(Ids.size())}, mx::int32); + std::vector TokenList; + int TokenCount = 0; + int Skip = 0; + std::string Answer; + auto [Y, KVCache] = this->stepGenerate(Token, 0.1); + while (true) { + TokenCount++; + if (TokenCount > MaxToken) { + break; + } + eval(Y); + std::vector Tokens; + auto *Data = Y.data(); + for (int Idx = 0; Idx < static_cast(Y.size()); Idx++) { + Tokens.emplace_back(Data[Idx]); + } + // TODO: break when the token is the eos_token_id + TokenList.insert(TokenList.end(), Tokens.begin(), Tokens.end()); + Answer = Tok->Decode(TokenList); + const AnserSataus Status = answerSataus(Answer, ModelPrompt.TextEnd); + if (Status == STOP) { + break; + } + if (Status == GO) { + if (Verbose) { + spdlog::info("[WASI-NN] MLX backend: {}"sv, Answer.substr(Skip)); + } + Skip = Answer.size(); + } + auto [NY, NKVCache] = this->nextStepGenerate(Y, 0.1, KVCache); + Y = NY, KVCache = NKVCache; + } + return {Answer, TokenList}; +} + +} // namespace llm +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/llm/transformer.h b/plugins/wasi_nn/MLX/model/llm/transformer.h new file mode 100644 index 00000000..1065e31a --- /dev/null +++ b/plugins/wasi_nn/MLX/model/llm/transformer.h @@ -0,0 +1,231 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once +#include "mlx/activations.h" +#include "mlx/base.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "mlx/normalization.h" +#include "mlx/positional_encoding.h" +#include "prompt/prompt.h" +#include +#include +#include +#include +#include +#include +#include +#include +namespace WasmEdge::Host::WASINN::MLX { +namespace nn = mlx::core::nn; +namespace llm { + +class RMSNorm : public nn::Module { + float Eps; + +public: + RMSNorm(int Dims, float Eps = 1e-5) : Eps(Eps) { + registerParameter("weight", mx::ones({Dims})); + } + mx::array forward(mx::array Input); +}; +class Attention : public nn::Module { + + int NHeads; + int NKVHeads; + bool NormQKProj; + double Scale; + +public: + Attention(int Dim, int NHeads, int NKVHeads, + std::optional HeadDimPar = 0, bool RopeTraditional = false, + float RopeTheta = 1000, + std::optional> + RopeScaling = {}, + bool NormQKProj = false, float AttentionNormEps = 1e-6) + : NHeads(NHeads), NKVHeads(NKVHeads), NormQKProj(NormQKProj) { + int HeadDim; + if (HeadDimPar) { + HeadDim = *HeadDimPar; + } else { + HeadDim = Dim / NHeads; + } + Scale = pow(HeadDim, -0.5); + registerModule("q_proj", std::make_shared( + nn::Linear(Dim, NHeads * HeadDim, false))); + registerModule("k_proj", std::make_shared( + nn::Linear(Dim, NKVHeads * HeadDim, false))); + registerModule("v_proj", std::make_shared( + nn::Linear(Dim, NKVHeads * HeadDim, false))); + registerModule("o_proj", std::make_shared( + nn::Linear(NHeads * HeadDim, Dim, false))); + + if (NormQKProj) { + registerModule("q_norm", std::make_shared( + nn::RMSNorm(HeadDim, AttentionNormEps))); + registerModule("k_norm", std::make_shared( + nn::RMSNorm(HeadDim, AttentionNormEps))); + } + float RopeScale; + if (RopeScaling && (*RopeScaling)["type"] == "linear") { + RopeScale = 1 / stof((*RopeScaling)["factor"]); + } else { + RopeScale = 1; + } + + registerModule("rope", + std::make_shared(nn::RoPE(HeadDim, RopeTraditional, + RopeTheta, RopeScale))); + } + std::tuple> + forward(mx::array Input, std::optional Mask = {}, + std::optional> KVCache = {}); +}; +class MLP : public nn::Module { + bool Gemma; + +public: + MLP(int Dim, int HiddenDim, bool Gemma = false) : Gemma(Gemma) { + registerModule("gate_proj", std::make_shared( + nn::Linear(Dim, HiddenDim, false))); + registerModule("down_proj", std::make_shared( + nn::Linear(HiddenDim, Dim, false))); + registerModule("up_proj", std::make_shared( + nn::Linear(Dim, HiddenDim, false))); + } + mx::array forward(mx::array Input); +}; +class TransformerBlock : public nn::Module { + bool Gemma; + +public: + TransformerBlock(int Dim, int NHeads, int NKVHeads, int HiddenDim, + float NormEps, std::optional HeadDim = {}, + bool RopeTraditional = false, float RopeTheta = 1000, + std::optional> + RopeScaling = {}, + bool NormQKProj = false, float AttentionNormEps = 1e-6, + bool Gemma = false) + : Gemma(Gemma) { + registerModule("attention", + std::make_shared(Attention( + Dim, NHeads, NKVHeads, HeadDim, RopeTraditional, + RopeTheta, RopeScaling, NormQKProj, AttentionNormEps))); + registerModule("mlp", std::make_shared(MLP(Dim, HiddenDim, Gemma))); + if (!Gemma) { + registerModule("attention_norm", + std::make_shared(nn::RMSNorm(Dim, NormEps))); + registerModule("mlp_norm", + std::make_shared(nn::RMSNorm(Dim, NormEps))); + } else { + registerModule("attention_norm", + std::make_shared(RMSNorm(Dim, NormEps))); + registerModule("mlp_norm", + std::make_shared(RMSNorm(Dim, NormEps))); + } + } + std::tuple> + forward(mx::array Input, std::optional Mask = {}, + std::optional> KVCachePar = {}); +}; +class Transformer : public nn::Module { + int Dim; + std::optional> HiddenDim; + bool Gemma; + bool EmbedAsHead; + std::vector> Layers{}; + +public: + Transformer( + int Dim, std::optional> HiddenDim, int VocabSize, + int NLayers, std::optional> NHeads, + std::optional> NKVHeads = {}, float NormEps = 1e-5, + std::optional HeadDim = {}, bool RopeTraditional = false, + float RopeTheta = 1000, + std::optional>> + RopeScaling = {}, + bool NormQKProj = false, float AttentionNormEps = 1e-6, + bool Gemma = false, bool EmbedAsHeadPar = false) + : Dim(Dim), HiddenDim(HiddenDim), Gemma(Gemma), + EmbedAsHead(EmbedAsHeadPar) { + if (VocabSize <= 0) { + spdlog::error("VocabSize must be greater than 0."); + assumingUnreachable(); + } + EmbedAsHead = Gemma ? true : EmbedAsHead; + if (!NKVHeads) { + NKVHeads = NHeads; + } + registerModule("token_embed", std::make_shared( + nn::Embedding(VocabSize, Dim))); + if (HiddenDim->size() == 1) { + while (static_cast(HiddenDim->size()) < NLayers) { + HiddenDim->emplace_back((*HiddenDim)[0]); + } + } + if (NHeads->size() == 1) { + while (static_cast(NHeads->size()) < NLayers) { + NHeads->emplace_back((*NHeads)[0]); + } + } + if (NKVHeads->size() == 1) { + while (static_cast(NKVHeads->size()) < NLayers) { + NKVHeads->emplace_back((*NKVHeads)[0]); + } + } + Layers.reserve(NLayers); + for (int Idx = 0; Idx < NLayers; Idx++) { + if (RopeScaling) { + Layers.push_back(std::make_shared(TransformerBlock( + Dim, (*NHeads)[Idx], (*NKVHeads)[Idx], (*HiddenDim)[Idx], NormEps, + HeadDim, RopeTraditional, RopeTheta, (*RopeScaling)[Idx], + NormQKProj, AttentionNormEps, Gemma))); + } else { + Layers.push_back(std::make_shared(TransformerBlock( + Dim, (*NHeads)[Idx], (*NKVHeads)[Idx], (*HiddenDim)[Idx], NormEps, + HeadDim, RopeTraditional, RopeTheta, {}, NormQKProj, + AttentionNormEps, Gemma))); + } + } + registerLayer("layers", Layers); + if (!Gemma) { + registerModule("norm", + std::make_shared(nn::RMSNorm(Dim, NormEps))); + } else { + registerModule("norm", std::make_shared(RMSNorm(Dim, NormEps))); + } + if (!EmbedAsHead) { + registerModule("head", std::make_shared( + nn::Linear(Dim, VocabSize, false))); + } + } + struct LLMOutput { + std::string Answer; + std::vector TokenList; + }; + std::tuple>>> + embed(mx::array Input, + std::optional>> + KVCachePar = {}, + bool Norm = false); + std::tuple>>> + forward(mx::array Input, + std::optional>> + KVCachePar = {}); + std::tuple>>> + stepGenerate(mx::array Input, std::optional Temp = 0.0); + std::tuple>>> + nextStepGenerate(mx::array Y, std::optional Temp = 0.0, + std::optional>> + KVCachePar = {}); + LLMOutput generate(const std::string &Prompt, const BasePrompt &ModelPrompt, + const int MaxToken, const bool Verbose, + const std::unique_ptr &Tok); +}; +} // namespace llm +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/transformer.cpp b/plugins/wasi_nn/MLX/model/transformer.cpp index 683eb968..79ca927c 100644 --- a/plugins/wasi_nn/MLX/model/transformer.cpp +++ b/plugins/wasi_nn/MLX/model/transformer.cpp @@ -6,6 +6,7 @@ #include "mlx/embedding.h" #include "mlx/linear.h" #include "model/transformer.h" +namespace WasmEdge::Host::WASINN::MLX { #include #include @@ -215,3 +216,5 @@ Transformer::nextGenerate( } } // namespace WasmEdge::Host::WASINN::MLX + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/transformer.h b/plugins/wasi_nn/MLX/model/transformer.h index d51be4d1..e555a8ea 100644 --- a/plugins/wasi_nn/MLX/model/transformer.h +++ b/plugins/wasi_nn/MLX/model/transformer.h @@ -9,6 +9,7 @@ #include "mlx/linear.h" #include "mlx/normalization.h" #include "mlx/positional_encoding.h" +namespace WasmEdge::Host::WASINN::MLX { #include #include @@ -236,3 +237,5 @@ class Transformer : public nn::Module { }; } // namespace WasmEdge::Host::WASINN::MLX + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/vlm_base.cpp b/plugins/wasi_nn/MLX/model/vlm_base.cpp new file mode 100644 index 00000000..7e30518e --- /dev/null +++ b/plugins/wasi_nn/MLX/model/vlm_base.cpp @@ -0,0 +1,538 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "vlm_base.h" +#include "common/errcode.h" +#include "model/vlm_sampling.h" +#include "spdlog/spdlog.h" +#include +#include +#include +#include +#include +#include +#include +#include +namespace WasmEdge::Host::WASINN::MLX { +namespace vlm { + +// BaseCache implementation +std::vector BaseCache::getState() const { return {}; } + +void BaseCache::setState(const std::vector &State) { + if (!State.empty()) { + spdlog::error( + "[WASI-NN] MLX backend: This cache has no state but a state was set."sv); + assumingUnreachable(); + } +} + +std::string BaseCache::getMetaState() const { return ""; } + +void BaseCache::setMetaState(const std::string &Value) { + if (!Value.empty()) { + spdlog::error( + "[WASI-NN] MLX backend: This cache has no meta_state but a meta_state was set."sv); + assumingUnreachable(); + } +} + +bool BaseCache::isTrimmable() const { return false; } + +int BaseCache::trim(int) { return 0; } + +// KVCache implementation +KVCache::KVCache(int HeadDim, int NKVHeads, int Step) + : NKVHeads(NKVHeads), KHeadDim(HeadDim), VHeadDim(HeadDim), Step(Step) {} + +KVCache::KVCache(std::pair HeadDims, int NKVHeads, int Step) + : NKVHeads(NKVHeads), KHeadDim(HeadDims.first), VHeadDim(HeadDims.second), + Step(Step) {} + +std::tuple +KVCache::updateAndFetch(const mx::array &NewKeys, const mx::array &NewValues) { + update(NewKeys, NewValues); + return fetch(); +} + +std::tuple KVCache::fetch() const { + // self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + mx::array Indices = mx::arange(Offset); + mx::array KeysSlice = take(Keys, Indices, -2); + mx::array ValuesSlice = take(Values, Indices, -2); + return {KeysSlice, ValuesSlice}; +} + +void KVCache::update(const mx::array &NewKeys, const mx::array &NewValues) { + int Prev = Offset; + std::vector NewShape = NewKeys.shape(); + int NewLen = NewShape[2]; + + if (Keys.size() == 0 || (Prev + NewLen) > Keys.shape()[2]) { + int NSteps = (Step + NewLen - 1) / Step; + int NewCapacity = NSteps * Step; + std::vector KShape = {1, NKVHeads, NewCapacity, KHeadDim}; + std::vector VShape = {1, NKVHeads, NewCapacity, VHeadDim}; + mx::array NewK = mx::zeros(KShape, NewKeys.dtype()); + mx::array NewV = mx::zeros(VShape, NewValues.dtype()); + if (Keys.size() != 0) { + if (Prev % Step != 0) { + mx::array Indices = mx::arange(Prev); + Keys = take(Keys, Indices, -2); + Values = take(Values, Indices, -2); + } + Keys = mx::concatenate({Keys, NewK}, 2); + Values = mx::concatenate({Values, NewV}, 2); + } else { + Keys = NewK; + Values = NewV; + } + } + + Offset += NewLen; + // self.keys[..., prev : self.offset, :] = keys + // self.values[..., prev : self.offset, :] = values + auto End = NewKeys.shape(); + std::vector Start(End.size(), 0); + std::vector Stride(End.size(), 1); + Start[End.size() - 2] = Prev; + End[End.size() - 2] = Offset; + Keys = mx::slice_update(Keys, NewKeys, Start, End, Stride); + Values = mx::slice_update(Values, NewValues, Start, End, Stride); +} + +std::vector KVCache::getState() const { + if (Offset == Keys.shape()[2]) { + return {Keys, Values}; + } + mx::array Indices = mx::arange(Offset); + mx::array KeysSlice = take(Keys, Indices, -2); + mx::array ValuesSlice = take(Values, Indices, -2); + return {KeysSlice, ValuesSlice}; +} + +void KVCache::setState(const std::vector &State) { + if (State.size() != 2) { + spdlog::error( + "[WASI-NN] MLX backend: KVCache state must contain exactly two arrays"sv); + assumingUnreachable(); + } + Keys = State[0]; + Values = State[1]; + Offset = Keys.shape()[2]; +} + +bool KVCache::isTrimmable() const { return true; } + +int KVCache::trim(int N) { + N = std::min(Offset, N); + Offset -= N; + return N; +} + +// RotatingKVCache implementation +RotatingKVCache::RotatingKVCache(int MaxSize, int Keep, int Step) + : KVCache(0, 0, Step), Keep(Keep), MaxSize(MaxSize), Idx(0) {} + +mx::array RotatingKVCache::trim(int TrimSize, const mx::array &V, + std::optional Append) { + std::vector ToCat; + + if (TrimSize > 0) { + mx::array KeepIndices = mx::arange(Keep); + mx::array TrimIndices = mx::arange(TrimSize + Keep, V.shape()[2]); + mx::array KeepPart = take(V, KeepIndices, -2); + mx::array TrimPart = take(V, TrimIndices, -2); + ToCat = {KeepPart, TrimPart}; + } else { + ToCat = {V}; + } + + if (Append.has_value()) { + ToCat.push_back(Append.value()); + } + + return mx::concatenate(ToCat, 2); +} + +mx::array RotatingKVCache::temporalOrder(const mx::array &V) { + if (Idx == V.shape()[2]) { + return V; + } + if (Idx < Offset) { + mx::array KeepIndices = mx::arange(Keep); + mx::array IdxToEndIndices = mx::arange(Idx, V.shape()[2]); + mx::array KeepToIdxIndices = mx::arange(Keep, Idx); + + mx::array KeepPart = take(V, KeepIndices, -2); + mx::array IdxToEndPart = take(V, IdxToEndIndices, -2); + mx::array KeepToIdxPart = take(V, KeepToIdxIndices, -2); + + return mx::concatenate({KeepPart, IdxToEndPart, KeepToIdxPart}, 2); + } + mx::array IdxIndices = mx::arange(Idx); + return take(V, IdxIndices, -2); +} + +std::tuple +RotatingKVCache::updateConcat(const mx::array &NewKeys, + const mx::array &NewValues) { + if (Keys.size() == 0) { + Keys = NewKeys; + Values = NewValues; + } else { + // Put the keys/values in temporal order to preserve context + Keys = temporalOrder(Keys); + Values = temporalOrder(Values); + + // The largest size is MaxSize + S to ensure every token gets at least + // MaxSize context + int TrimSize = Idx - MaxSize; + Keys = trim(TrimSize, Keys, NewKeys); + Values = trim(TrimSize, Values, NewValues); + } + + Offset += NewKeys.shape()[2]; + Idx = Keys.shape()[2]; + return {Keys, Values}; +} + +std::tuple +RotatingKVCache::updateInPlace(const mx::array &NewKeys, + const mx::array &NewValues) { + // May not have hit the max size yet, so potentially keep growing the cache + std::vector KeysShape = NewKeys.shape(); + int B = KeysShape[0]; + int NKVHeads = KeysShape[1]; + int S = KeysShape[2]; + int KHeadDim = KeysShape[3]; + int VHeadDim = NewValues.shape()[3]; + + int Prev = Offset; + if (Keys.size() == 0 || + (Prev >= Keys.shape()[2] && Keys.shape()[2] < MaxSize)) { + int NewSize = std::min(Step, MaxSize - Prev); + std::vector KShape = {B, NKVHeads, NewSize, KHeadDim}; + std::vector VShape = {B, NKVHeads, NewSize, VHeadDim}; + + mx::array NewK = mx::zeros(KShape, NewKeys.dtype()); + mx::array NewV = mx::zeros(VShape, NewValues.dtype()); + + if (Keys.size() != 0) { + Keys = mx::concatenate({Keys, NewK}, 2); + Values = mx::concatenate({Values, NewV}, 2); + } else { + Keys = NewK; + Values = NewV; + } + Idx = Prev; + } + + // Trim if needed + int TrimSize = Keys.shape()[2] - MaxSize; + if (TrimSize > 0) { + Keys = trim(TrimSize, Keys); + Values = trim(TrimSize, Values); + Idx = MaxSize; + } + + // Rotate + if (Idx == MaxSize) { + Idx = Keep; + } + + // Assign + std::vector KeysEnd = Keys.shape(); + std::vector Start(KeysEnd.size(), 0); + std::vector ValuesEnd = Values.shape(); + std::vector Stride(KeysEnd.size(), 1); + Start[Start.size() - 2] = Idx; + KeysEnd[KeysEnd.size() - 2] = Idx + S; + ValuesEnd[KeysEnd.size() - 2] = Idx + S; + Keys = mx::slice_update(Keys, NewKeys, Start, KeysEnd, Stride); + Values = mx::slice_update(Values, NewValues, Start, ValuesEnd, Stride); + + Offset += S; + Idx += S; + + // If the buffer is not full, slice off the end + if (Offset < MaxSize) { + mx::array OffsetIndices = mx::arange(Offset); + mx::array KeysSlice = take(Keys, OffsetIndices, -2); + mx::array ValuesSlice = take(Values, OffsetIndices, -2); + return {KeysSlice, ValuesSlice}; + } + + return {Keys, Values}; +} + +std::tuple +RotatingKVCache::updateAndFetch(const mx::array &NewKeys, + const mx::array &NewValues) { + if (NewKeys.shape()[2] == 1) { + return updateInPlace(NewKeys, NewValues); + } + return updateConcat(NewKeys, NewValues); +} + +std::string RotatingKVCache::getMetaState() const { + return std::to_string(Keep) + "," + std::to_string(MaxSize) + "," + + std::to_string(Step) + "," + std::to_string(Offset) + "," + + std::to_string(Idx); +} + +void RotatingKVCache::setMetaState(const std::string &Value) { + std::stringstream SS(Value); + std::string Item; + std::vector Values; + + while (std::getline(SS, Item, ',')) { + Values.push_back(std::stoi(Item)); + } + + if (Values.size() == 5) { + Keep = Values[0]; + MaxSize = Values[1]; + Step = Values[2]; + Offset = Values[3]; + Idx = Values[4]; + } else { + spdlog::error("[WASI-NN] MLX backend: Invalid meta state format."sv); + assumingUnreachable(); + } +} + +bool RotatingKVCache::isTrimmable() const { return Offset < MaxSize; } + +int RotatingKVCache::trim(int N) { + N = std::min(Offset, N); + Offset -= N; + Idx -= N; + return N; +} + +mx::array createAdditiveCausalMask(int N, int Offset) { + auto Rinds = mx::arange(Offset + N); + mx::array Linds = mx::array({}); + if (Offset) { + Linds = mx::arange(Offset, Offset + N); + } else { + Linds = Rinds; + } + // mask = linds[:, None] < rinds[None] + return mx::less(mx::expand_dims(Linds, 1), mx::expand_dims(Rinds, 0)) * -1e9; +} + +std::optional createAttentionMask( + mx::array H, + std::optional>> Cache) { + int T = H.shape()[1]; + std::optional Mask = std::nullopt; + if (T > 1) { + int Offset = 0; + if (Cache.has_value() && Cache.value().size() > 0 && + Cache.value()[0] != nullptr) { + auto C = Cache.value()[0]; + auto RotCache = std::dynamic_pointer_cast(C); + if (RotCache) { + Offset = std::min(RotCache->MaxSize - 1, RotCache->Offset); + } else { + Offset = C->Offset; + } + } + Mask = createAdditiveCausalMask(T, Offset); + Mask = mx::astype(Mask.value(), H.dtype()); + } + return Mask; +} + +std::vector Module::generate( + const std::string &Prompt, std::optional Image, bool Verbose, + std::map> + Kwargs) { + + if (Verbose) { + spdlog::info("=========="sv); + if (Image.has_value()) { + spdlog::info("Files: {}\n"sv, Image.value()); + } else if (Kwargs.count("Video") > 0) { + /* Print video path */ + } + spdlog::info("Prompt: {}"sv, Prompt); + } + + std::string Text = ""; + // stream generate + mx::array InputIds = mx::array({}); + mx::array PixelValues = mx::array({}); + mx::array Mask = mx::array({}); + + // For pixel_values + // int ImageTokenIndex; + // auto ImageTokenIndexIt = Kwargs.find("image_token_index"); + // if (ImageTokenIndexIt != Kwargs.end()) { + // if (auto *ImageTokenIndexPtr = + // std::get_if(&ImageTokenIndexIt->second)) { + // ImageTokenIndex = *ImageTokenIndexPtr; + // } else { + // assumingUnreachable(); + // } + // } else { + // assumingUnreachable(); + // } + if (Kwargs.count("pixel_values") == 0) { + spdlog::error("Not implemented"); + assumingUnreachable(); + } else { + InputIds = *std::get_if(&Kwargs.find("input_ids")->second); + PixelValues = *std::get_if(&Kwargs.find("pixel_values")->second); + Mask = *std::get_if(&Kwargs.find("mask")->second); + } + // Generate_state + // Initialize cache + std::vector> Cache; + int MaxTokens = 256; + float Temperature = 0.0f; + std::optional RepetitionPenalty = std::nullopt; + size_t RepetitionContextSize = 20; + float TopP = 1.0f; + std::map LogitBias = {}; + auto LanguageModel = std::dynamic_pointer_cast( + this->Submodules["language_model"]); + + auto Sample = [&](mx::array Logits) -> std::tuple { + if (!LogitBias.empty()) { + for (const auto &[Index, Value] : LogitBias) { + Logits = + scatter_add_axis(Logits, mx::array({Index}), mx::array({Value}), 1); + } + } + + mx::array LogProbs = Logits - mx::logsumexp(Logits, -1); + mx::array Token = mx::array({}); + if (Temperature == 0.0f) { + Token = mx::argmax(Logits, -1); + } else { + if (TopP > 0 and TopP < 1.0) { + Token = topPSampling(Logits, TopP, Temperature); + } else { + Token = mx::random::categorical(Logits / Temperature); + } + } + + return {Token, LogProbs}; + }; + + if (RepetitionPenalty.has_value() && RepetitionPenalty < 0) { + spdlog::error("Repetition penalty must be greater than 0"); + assumingUnreachable(); + } + + auto MakeCache = LanguageModel->makeCache(); + Cache.insert(Cache.begin(), MakeCache.begin(), MakeCache.end()); + + // Initialize repetition context + auto FlatternInputIdsShape = reshape(InputIds, {-1}); + mx::async_eval(FlatternInputIdsShape); + std::vector RepetitionContext(FlatternInputIdsShape.data(), + FlatternInputIdsShape.data() + + InputIds.size()); + if (RepetitionContext.size() > RepetitionContextSize) { + RepetitionContext.erase(RepetitionContext.begin(), + RepetitionContext.end() - RepetitionContextSize); + } + + auto Step = [&](mx::array Y) -> std::tuple { + std::vector NewShape = Y.shape(); + NewShape.insert(NewShape.begin(), 1); + // TODO: handle decoder_input_ids + auto Outputs = std::dynamic_pointer_cast( + this->Submodules["language_model"]) + ->forward(reshape(Y, NewShape), Cache); + mx::array Logits = std::get<0>(Outputs); + Logits = take(Logits, Logits.shape()[1] - 1, 1); + mx::array LogProbs = mx::array({}); + if (RepetitionPenalty.has_value()) { + if (RepetitionContext.size() > 0) { + auto Indices = mx::array(RepetitionContext.data(), + {static_cast(RepetitionContext.size())}); + auto SelectedLogits = take(Logits, Indices, 1); + SelectedLogits = where(SelectedLogits < 0, + SelectedLogits * RepetitionPenalty.value(), + SelectedLogits / RepetitionPenalty.value()); + put_along_axis(Logits, Indices, SelectedLogits, 1); + } + std::tie(Y, LogProbs) = Sample(Logits); + RepetitionContext.emplace_back(Y.item()); + } else { + std::tie(Y, LogProbs) = Sample(Logits); + } + if (RepetitionContext.size() > RepetitionContextSize) { + RepetitionContext.erase(RepetitionContext.begin(), + RepetitionContext.end() - RepetitionContextSize); + } + return {Y, squeeze(LogProbs, 0)}; + }; + + // Perform the first step + auto Outputs = this->forward(InputIds, PixelValues, Mask, Cache); + mx::array Logits = std::get<0>(Outputs); + Logits = take(Logits, Logits.shape()[1] - 1, 1); + auto [Y, LogProbs] = Sample(Logits); + mx::async_eval(Y); + // TODO: handle cross_attention_states, encoder_outputs + // End generate_state + + std::optional Result; + GenerationResult LastResponse; + // auto Tic = std::chrono::system_clock::now().time_since_epoch(); + std::vector TokenList; + int N = 0; + while (true) { + // for statistic + // int PromptTPS; + // if (N == 0) { + // PromptTPS = + // std::chrono::duration_cast(Tic).count(); Tic = + // std::chrono::system_clock::now().time_since_epoch(); + // } + // TODO: if token = eos_token_id break + auto Response = GenerationResult(); + N++; + if (N >= MaxTokens) { + break; + } + auto [NextY, NextLogProbs] = Step(Y); + mx::async_eval(NextY); + // TODO: handle decoder_input_ids + auto Token = Y.item(); + if (Token == 1 || Token == 106) { + break; + } + TokenList.emplace_back(Token); + Y = NextY; + LogProbs = NextLogProbs; + } + return TokenList; + // end stream generate + // TODO: waiting for processor + if (Verbose) { + spdlog::info("\n=========="sv); + if (Text.empty()) { + spdlog::info("No text generated for this prompt"sv); + return {}; + // return Text; + } + + spdlog::info("Prompt: {} tokends, {} tokens-per-sec"sv, + LastResponse.PromptTokens, LastResponse.PromptTps); + spdlog::info("Generation: {} tokens {} tokens-per-sec"sv, + LastResponse.GenerationTokens, LastResponse.GenerationTps); + spdlog::info("Peak memory: {} GB"sv); + } + + // return Text; +} +} // namespace vlm + +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/vlm_base.h b/plugins/wasi_nn/MLX/model/vlm_base.h new file mode 100644 index 00000000..58bee647 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/vlm_base.h @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once +#include "mlx/base.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +namespace vlm { +class BaseCache { +public: + int Offset = 0; + + virtual ~BaseCache() = default; + + virtual std::tuple + updateAndFetch(const mx::array &NewKeys, const mx::array &NewValues) = 0; + + virtual std::vector getState() const; + virtual void setState(const std::vector &State); + + virtual std::string getMetaState() const; + virtual void setMetaState(const std::string &Value); + + virtual bool isTrimmable() const; + virtual int trim(int N); + virtual std::string getType() const { return "BaseCache"; } +}; + +class KVCache : public BaseCache { +public: + int NKVHeads; + int KHeadDim; + int VHeadDim; + mx::array Keys = mx::array({}); + mx::array Values = mx::array({}); + int Step; + + KVCache(int HeadDim, int NKVHeads, int Step = 256); + KVCache(std::pair HeadDims, int NKVHeads, int Step = 256); + virtual std::tuple + updateAndFetch(const mx::array &NewKeys, const mx::array &NewValues) override; + + std::tuple fetch() const; + + void update(const mx::array &NewKeys, const mx::array &NewValues); + + std::vector getState() const override; + void setState(const std::vector &State) override; + + bool isTrimmable() const override; + int trim(int N) override; + std::string getType() const override { return "KVCache"; } +}; + +class RotatingKVCache : public KVCache { +public: + int Keep; + int MaxSize; + int Idx; + + RotatingKVCache(int MaxSize = -1, int Keep = 0, int Step = 256); + + std::tuple + updateAndFetch(const mx::array &NewKeys, const mx::array &NewValues) override; + + std::tuple updateInPlace(const mx::array &NewKeys, + const mx::array &NewValues); + + std::tuple updateConcat(const mx::array &NewKeys, + const mx::array &NewValues); + + mx::array trim(int TrimSize, const mx::array &V, + std::optional Append = std::nullopt); + + mx::array temporalOrder(const mx::array &V); + + std::string getMetaState() const override; + void setMetaState(const std::string &Value) override; + + bool isTrimmable() const override; + int trim(int N) override; + std::string getType() const override { return "RotatingKVCache"; } +}; + +class Module : public mlx::core::nn::Module { +public: + virtual std::tuple> forward( + const mx::array &InputIds, const mx::array &PixelValues, + const mx::array &Mask, + const std::optional>> &Cache = + std::nullopt) = 0; + struct GenerationResult { + std::string Text; + int Token; + std::vector LogProbs; + int PromptTokens; + int GenerationTokens; + float PromptTps; + float GenerationTps; + float PeakMemory; + }; + + // Add this struct to hold generation state + struct StreamGenerationState { + void *Model; + void *Processor; + mx::array PromptTokens; + mx::array InputIds; + mx::array PixelValues; + mx::array Mask; + mx::array CurrentToken; + std::vector CurrentLogProbs; + int TokenCount; + double StartTime; + double PromptTime; + float PromptTps; + bool IsComplete; + std::map Kwargs; + + // Generation parameters + int MaxTokens = 256; + float Temperature = 0.0; + std::optional RepetitionPenalty = std::nullopt; + std::optional RepetitionContextSize = 20; + float TopP = 1.0; + std::map LogitBias; + + // State for generate_step + std::vector RepetitionContext; + std::vector Cache; // Will hold appropriate cache objects + mx::array CrossAttentionStates; + mx::array EncoderOutputs; + }; + + std::vector generate( + const std::string &Prompt = {}, + std::optional Image = std::nullopt, bool Verbose = false, + std::map> + Kwargs = {}); +}; +class LanguageModel : public mlx::core::nn::Module { +public: + virtual int headDim() const = 0; + virtual int nKvHeads() const = 0; + virtual int layers() const = 0; + virtual std::vector> makeCache() { + std::vector> Cache; + int HeadDim = headDim(); + auto KVHeads = nKvHeads(); + for (int I = 0; I < layers(); ++I) { + Cache.emplace_back(std::make_shared(HeadDim, KVHeads)); + } + return Cache; + } + virtual std::tuple> forward( + const mx::array &Inputs, + const std::optional>> &Cache = + std::nullopt) = 0; +}; + +std::optional createAttentionMask( + mx::array H, + std::optional>> = std::nullopt); + +mx::array createAdditiveCausalMask(int N, int Offset = 0); + +} // namespace vlm +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/vlm_sampling.cpp b/plugins/wasi_nn/MLX/model/vlm_sampling.cpp new file mode 100644 index 00000000..c4f9f41c --- /dev/null +++ b/plugins/wasi_nn/MLX/model/vlm_sampling.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "model/vlm_sampling.h" +namespace WasmEdge::Host::WASINN::MLX { +namespace vlm { + +mx::array topPSampling(const mx::array &Logits, float TopP, float Temperature) { + + mx::array WorkingLogits = Logits; + + if (WorkingLogits.dtype() == mx::bfloat16) { + WorkingLogits = astype(WorkingLogits, mx::float32); + } + + mx::array Probs = mx::softmax(WorkingLogits / Temperature, -1); + mx::array SortedIndices = mx::argsort(Probs, -1); + mx::array SqueezedIndices = mx::squeeze(SortedIndices, 0); + mx::array SortedProbs = mx::take(Probs, SqueezedIndices, -1); + mx::array CumulativeProbs = mx::cumsum(SortedProbs, -1); + mx::array TopProbs = mx::where(CumulativeProbs > 1.0f - TopP, SortedProbs, + mx::zeros_like(SortedProbs)); + mx::array SortedToken = mx::random::categorical(mx::log(TopProbs)); + mx::array Token = mx::take(SqueezedIndices, SortedToken); + + return Token; +} + +} // namespace vlm +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/vlm_sampling.h b/plugins/wasi_nn/MLX/model/vlm_sampling.h new file mode 100644 index 00000000..a5fb1592 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/vlm_sampling.h @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once +#include "mlx/base.h" +#include +namespace WasmEdge::Host::WASINN::MLX { +namespace vlm { + +mx::array topPSampling(const mx::array &Logits, float TopP, float Temperature); + +} // namespace vlm +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/wasinn_mlx.cpp b/plugins/wasi_nn/wasinn_mlx.cpp index bd50e47e..868c2aab 100644 --- a/plugins/wasi_nn/wasinn_mlx.cpp +++ b/plugins/wasi_nn/wasinn_mlx.cpp @@ -5,10 +5,16 @@ #include "wasinnenv.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX + +#include "MLX/mlx/base.h" #include "MLX/model/converter.h" -#include "MLX/model/registry.h" +#include "MLX/model/gemma3/gemma3.h" +#include "MLX/model/llm/registry.h" +#include "MLX/model/llm/transformer.h" #include "MLX/model/utils.h" #include "MLX/prompt/prompt.h" +#include +#include #include #endif @@ -61,7 +67,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, // Parse metadata. if (Builders.size() <= 1) { spdlog::error( - "[WASI-NN] MLX backend: Lack model weight or required metadata (tokenizer, model_type)."sv); + "[WASI-NN] MLX backend: Lack model weight or required metadata (model_type)."sv); Env.deleteGraph(GId); return ErrNo::InvalidArgument; } @@ -112,11 +118,6 @@ Expect load(WASINN::WasiNNEnvironment &Env, return ErrNo::InvalidArgument; } TokenizerPath = TokenizerPathView; - } else { - spdlog::error( - "[WASI-NN] MLX backend: Unable to retrieve the tokenizer option."sv); - Env.deleteGraph(GId); - return ErrNo::InvalidArgument; } if (Doc.at_key("max_token").error() == simdjson::SUCCESS) { uint64_t MaxToken; @@ -149,46 +150,12 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.GroupSize = GroupSize; } - // Load tokenizer. - if (!TokenizerPath.empty()) { - auto Bytes = loadBytesFromFile(TokenizerPath); - if (Bytes.empty()) { - spdlog::error("[WASI-NN] MLX backend: Load tokenizer failed."sv); - Env.deleteGraph(GId); - return ErrNo::InvalidArgument; - } - GraphRef.Tok = tokenizers::Tokenizer::FromBlobJSON(Bytes); - } else { - spdlog::error("[WASI-NN] MLX backend: Tokenizer path not found."sv); - Env.deleteGraph(GId); - return ErrNo::InvalidArgument; - } - - // Create Model. - if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { - GraphRef.Model = tinyLlama11BChatV10(); - GraphRef.Prmopt = TinyLLaMAPrompt(); - } else if (GraphRef.ModelType == "llama_3_8b") { - GraphRef.Model = llama38b(); - GraphRef.Prmopt = LLaMA3Prompt(); - } else if (GraphRef.ModelType == "llama_2_7b_chat_hf") { - GraphRef.Model = llama27bChat(); - GraphRef.Prmopt = LLaMA2Prompt(); - } else { - spdlog::error("[WASI-NN] MLX backend: Model type not supported."sv); - Env.deleteGraph(GId); - return ErrNo::InvalidArgument; - } - - if (GraphRef.QBits != 0 && GraphRef.GroupSize != 0 && GraphRef.IsQuantized) { - GraphRef.Model->toQuantized(GraphRef.GroupSize, GraphRef.QBits); - } - + std::unordered_map Weights; // Handle the model path. for (size_t Idx = 0; Idx < Builders.size() - 1; Idx++) { - auto Weight = Builders[Idx]; - const std::string BinModel(reinterpret_cast(Weight.data()), - Weight.size()); + auto WeightData = Builders[Idx]; + const std::string BinModel(reinterpret_cast(WeightData.data()), + WeightData.size()); spdlog::info("[WASI-NN] MLX BinModel: {}"sv, BinModel.size()); if (BinModel.size() == 0) { Env.deleteGraph(GId); @@ -220,19 +187,102 @@ Expect load(WASINN::WasiNNEnvironment &Env, "[WASI-NN][Debug] MLX backend: Write model into a tmpfile...Done"sv); } } - // Load weight. + if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { - GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); + auto Weight = llamaToMlxllm(ModelFilePath); + Weights.insert(Weight.begin(), Weight.end()); } else if (GraphRef.ModelType == "llama_3_8b") { - GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); + auto Weight = llamaToMlxllm(ModelFilePath); + Weights.insert(Weight.begin(), Weight.end()); } else if (GraphRef.ModelType == "llama_2_7b_chat_hf") { - GraphRef.Model->update(llamaToMlxllm(ModelFilePath)); + auto Weight = llamaToMlxllm(ModelFilePath); + Weights.insert(Weight.begin(), Weight.end()); + } else if (GraphRef.ModelType == "gemma3") { + auto Weight = mx::load_safetensors(ModelFilePath); + Weights.insert(Weight.first.begin(), Weight.first.end()); } else { - spdlog::error("[WASI-NN] MLX backend: Model type not supported."sv); + spdlog::error("[WASI-NN] MLX backend: Model type {} not supported."sv, + GraphRef.ModelType); Env.deleteGraph(GId); return ErrNo::InvalidArgument; } } + // Create Model. + if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { + GraphRef.Model = llm::tinyLlama11BChatV10(); + GraphRef.Prmopt = TinyLLaMAPrompt(); + GraphRef.ModelArch = "llm"; + } else if (GraphRef.ModelType == "llama_3_8b") { + GraphRef.Model = llm::llama38b(); + GraphRef.Prmopt = LLaMA3Prompt(); + GraphRef.ModelArch = "llm"; + } else if (GraphRef.ModelType == "llama_2_7b_chat_hf") { + GraphRef.Model = llm::llama27bChat(); + GraphRef.Prmopt = LLaMA2Prompt(); + GraphRef.ModelArch = "llm"; + } else if (GraphRef.ModelType == "gemma3") { + auto Obj = Doc.get_object(); + gemma3::ModelConfig ModelConfigObj = + gemma3::ModelConfig::fromDict(Obj.value()); + ModelConfigObj.VisionConfig = + (Obj.at_key("vision_config").error() == simdjson::SUCCESS) + ? gemma3::VisionConfig::fromDict( + Obj["vision_config"].get_object().value()) + : gemma3::VisionConfig(); + ModelConfigObj.TextConfig = + (Obj.at_key("text_config").error() == simdjson::SUCCESS) + ? gemma3::TextConfig::fromDict( + Obj["text_config"].get_object().value()) + : gemma3::TextConfig(); + GraphRef.Model = std::dynamic_pointer_cast( + std::make_shared(gemma3::Model(ModelConfigObj))); + Weights = std::dynamic_pointer_cast(GraphRef.Model) + ->sanitize(Weights); + Weights = + gemma3::VisionModel(ModelConfigObj.VisionConfig).sanitize(Weights); + GraphRef.ModelArch = "vlm"; + } else { + spdlog::error( + "[WASI-NN] MLX backend: Model architecture {} not supported."sv, + GraphRef.ModelArch); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + + // Load tokenizer. + if (!TokenizerPath.empty()) { + auto Bytes = loadBytesFromFile(TokenizerPath); + if (Bytes.empty()) { + spdlog::error("[WASI-NN] MLX backend: Load tokenizer failed."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + GraphRef.Tok = tokenizers::Tokenizer::FromBlobJSON(Bytes); + } else if (GraphRef.ModelArch == "llm") { + spdlog::error("[WASI-NN] MLX backend: Tokenizer path not found."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + + if (GraphRef.QBits != 0 && GraphRef.GroupSize != 0 && GraphRef.IsQuantized) { + GraphRef.Model->toQuantized(GraphRef.GroupSize, GraphRef.QBits); + } + + // Load weight. + if (GraphRef.ModelType == "tiny_llama_1.1B_chat_v1.0") { + GraphRef.Model->update(Weights); + } else if (GraphRef.ModelType == "llama_3_8b") { + GraphRef.Model->update(Weights); + } else if (GraphRef.ModelType == "llama_2_7b_chat_hf") { + GraphRef.Model->update(Weights); + } else if (GraphRef.ModelType == "gemma3") { + GraphRef.Model->update(Weights); + } else { + spdlog::error("[WASI-NN] MLX backend: Model type {} not supported."sv, + GraphRef.ModelType); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } if (GraphRef.QBits != 0 && GraphRef.GroupSize != 0 && !GraphRef.IsQuantized) { GraphRef.Model->toQuantized(GraphRef.GroupSize, GraphRef.QBits); @@ -247,19 +297,55 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, uint32_t &ContextId) noexcept { ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); Env.NNContext[ContextId].setReady(); + auto &GraphRef = Env.NNGraph[GraphId].get(); + auto &CxtRef = Env.NNContext[ContextId].get(); + if (GraphRef.ModelArch == "llm") { + CxtRef.Inputs = LLMInput(); + } else if (GraphRef.ModelArch == "vlm") { + CxtRef.Inputs = VLMInput(); + } else { + spdlog::error( + "[WASI-NN] MLX backend: Model architecture {} not supported."sv, + GraphRef.ModelArch); + Env.deleteContext(ContextId); + return ErrNo::InvalidArgument; + } return ErrNo::Success; } Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, - uint32_t, const TensorData &Tensor) noexcept { + uint32_t Index, + const TensorData &Tensor) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN] MLX backend: setInput"sv); } - CxtRef.Inputs = - std::string(reinterpret_cast(Tensor.Tensor.data()), - Tensor.Tensor.size()); + + if (GraphRef.ModelArch == "llm") { + std::get(CxtRef.Inputs).Prompt = + std::string(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + } else if (GraphRef.ModelArch == "vlm") { + auto TensorPath = + std::string(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + if (Index == 0) { + std::get(CxtRef.Inputs).Prompt = mx::load(TensorPath); + } else if (Index == 1) { + std::get(CxtRef.Inputs).Pixel = mx::load(TensorPath); + } else if (Index == 2) { + std::get(CxtRef.Inputs).Mask = mx::load(TensorPath); + } else { + spdlog::error("[WASI-NN] MLX backend: Index out of range."sv); + return ErrNo::InvalidArgument; + } + } else { + spdlog::error( + "[WASI-NN] MLX backend: Model architecture {} not supported."sv, + GraphRef.ModelArch); + return ErrNo::InvalidArgument; + } return WASINN::ErrNo::Success; } @@ -271,10 +357,34 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN] MLX backend: getOutput"sv); } - std::string StringTmp(reinterpret_cast(CxtRef.Outputs.data()), - CxtRef.Outputs.size()); - std::copy_n(StringTmp.data(), StringTmp.length(), OutBuffer.data()); - BytesWritten = StringTmp.length(); + if (GraphRef.ModelArch == "llm") { + auto *Output = std::get_if(&CxtRef.Outputs); + if (Output != nullptr) { + std::copy_n(Output->Answer.data(), Output->Answer.length(), + OutBuffer.data()); + BytesWritten = Output->Answer.length(); + } else { + spdlog::error("[WASI-NN] MLX backend: No output found."sv); + return ErrNo::InvalidArgument; + } + } else if (GraphRef.ModelArch == "vlm") { + auto *Output = std::get_if(&CxtRef.Outputs); + if (Output != nullptr) { + auto Answer = Output->Answer; + std::string OutputPath = "Answer.npy"; + mx::save(OutputPath, Answer); + std::copy_n(OutputPath.data(), OutputPath.size(), OutBuffer.data()); + BytesWritten = OutputPath.size(); + } else { + spdlog::error("[WASI-NN] MLX backend: No output found."sv); + return ErrNo::InvalidArgument; + } + } else { + spdlog::error( + "[WASI-NN] MLX backend: Model architecture {} not supported."sv, + GraphRef.ModelArch); + return ErrNo::InvalidArgument; + } return WASINN::ErrNo::Success; } @@ -282,46 +392,48 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - if (GraphRef.Tok == nullptr) { + + const auto Start{std::chrono::steady_clock::now()}; + size_t TokenListSize = 0; + + if (GraphRef.ModelArch == "llm" && GraphRef.Tok == nullptr) { spdlog::error("[WASI-NN] MLX backend: Tokenizer not loaded."sv); return ErrNo::InvalidArgument; } if (GraphRef.EnableDebugLog) { spdlog::info("[WASI-NN] MLX backend: compute"sv); } - const std::vector Ids = GraphRef.Tok->Encode(CxtRef.Inputs); - auto Token = - mx::array(Ids.data(), {static_cast(Ids.size())}, mx::int32); - std::vector TokenList; - std::string Answer; - int32_t Skip = 0; - uint64_t TokenCount = 0; - auto [Y, KVCache] = GraphRef.Model->generate(Token, 0.1); - while (true) { - TokenCount++; - if (TokenCount > GraphRef.MaxToken) { - break; - } - eval(Y); - std::vector Tokens; - auto *Data = Y.data(); - for (int Idx = 0; Idx < static_cast(Y.size()); Idx++) { - Tokens.emplace_back(Data[Idx]); - } - // TODO: break when the token is the eos_token_id - TokenList.insert(TokenList.end(), Tokens.begin(), Tokens.end()); - Answer = GraphRef.Tok->Decode(TokenList); - const AnswerSataus Status = answerSataus(Answer, GraphRef.Prmopt.TextEnd); - if (Status == STOP) { - break; - } - if (Status == GO) { - CxtRef.Outputs += Answer.substr(Skip); - Skip = Answer.size(); - } - auto [NY, NKVCache] = - GraphRef.Model->nextGenerate(Y, GraphRef.Temp, KVCache); - Y = NY, KVCache = NKVCache; + if (GraphRef.ModelArch == "llm") { + auto Result = + std::dynamic_pointer_cast(GraphRef.Model) + ->generate(std::get(CxtRef.Inputs).Prompt, + GraphRef.Prmopt, GraphRef.MaxToken, false, GraphRef.Tok); + CxtRef.Outputs = LLMOutput({Result.Answer}); + TokenListSize = Result.TokenList.size(); + } else if (GraphRef.ModelArch == "vlm") { + auto &Input = std::get(CxtRef.Inputs); + std::map> + Kwargs; + Kwargs.insert({"input_ids", Input.Prompt}); + Kwargs.insert({"pixel_values", Input.Pixel}); + Kwargs.insert({"mask", Input.Mask}); + auto TokenList = std::dynamic_pointer_cast(GraphRef.Model) + ->generate({}, std::nullopt, false, Kwargs); + auto TokenArray = + mx::array(TokenList.data(), {static_cast(TokenList.size())}); + CxtRef.Outputs = VLMOutput({TokenArray}); + TokenListSize = TokenList.size(); + } else { + spdlog::error( + "[WASI-NN] MLX backend: Model architecture {} not supported."sv, + GraphRef.ModelArch); + return ErrNo::InvalidArgument; + } + const auto End{std::chrono::steady_clock::now()}; + const std::chrono::duration ElapsedSeconds{End - Start}; + if (GraphRef.EnableDebugLog) { + spdlog::info("Elapsed time: {} s. TPS: {}.", ElapsedSeconds.count(), + TokenListSize / ElapsedSeconds.count()); } return WASINN::ErrNo::Success; } diff --git a/plugins/wasi_nn/wasinn_mlx.h b/plugins/wasi_nn/wasinn_mlx.h index e2cd9b4e..89aa9212 100644 --- a/plugins/wasi_nn/wasinn_mlx.h +++ b/plugins/wasi_nn/wasinn_mlx.h @@ -10,8 +10,9 @@ #include #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX +#include "MLX/mlx/base.h" #include "MLX/mlx/transformer.h" -#include "MLX/model/transformer.h" +#include "MLX/model/llm/transformer.h" #include "MLX/prompt/prompt.h" #include @@ -24,10 +25,25 @@ struct WasiNNEnvironment; namespace WasmEdge::Host::WASINN::MLX { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX +struct LLMInput { + std::string Prompt = {}; +}; +struct LLMOutput { + std::string Answer = {}; +}; +struct VLMInput { + mx::array Prompt = mx::array({}); + mx::array Pixel = mx::array({}); + mx::array Mask = mx::array({}); +}; +struct VLMOutput { + mx::array Answer = mx::array({}); +}; struct Graph { std::string ModelType; + std::string ModelArch; std::unique_ptr Tok = nullptr; - std::shared_ptr Model; + std::shared_ptr Model; double Temp = 0.0; bool EnableDebugLog = false; bool IsQuantized = false; @@ -39,8 +55,8 @@ struct Graph { struct Context { Context(uint32_t Gid, Graph &) noexcept : GraphId(Gid) {} uint32_t GraphId; - std::string Inputs; - std::string Outputs; + std::variant Inputs; + std::variant Outputs; }; #else struct Graph {}; From 93c426b9ab49ef6fa6cdb2e3692da87b2c29a533 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 15 May 2025 15:35:40 +0800 Subject: [PATCH 553/623] feat(wasi-nn): use the new libmtmd for multimodal models (#4112) Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 450 +++++++++++--------------------- plugins/wasi_nn/wasinn_ggml.h | 14 +- 2 files changed, 161 insertions(+), 303 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index e52b28d9..1dee2ade 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -1287,13 +1288,6 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, "Unable to retrieve the input-prefix-bos option."sv) } } - if (Doc.at_key("logits-all").error() == simdjson::SUCCESS) { - auto Err = Doc["logits-all"].get().get(GraphRef.Params.logits_all); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the logits-all option."sv) - } - } if (Doc.at_key("use-mlock").error() == simdjson::SUCCESS) { auto Err = Doc["use-mlock"].get().get(GraphRef.Params.use_mlock); if (Err) { @@ -1795,7 +1789,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, const std::string_view Base64ImageTagPrefix = ""sv; -const std::string_view LlavaPromptImagePlaceholder = ""sv; +const std::string_view VisionPromptImagePlaceholder = ""sv; // Get base64 image position if found in prompt. std::optional> @@ -2149,69 +2143,6 @@ void fillBatch(Span Tokens, Graph &GraphRef, NPos += static_cast(Tokens.size()); } -// Evaluate Qwen2vl image embedding. -bool evaluateQwen2vlImageEmbed(llama_context *LlamaCxt, - const struct llava_image_embed *ImageEmbed, - int64_t NBatch, int32_t &NPos, - struct clip_image_size *ImageSize) { - int NEmbd = llama_model_n_embd(llama_get_model(LlamaCxt)); - const int PatchSize = 14 * 2; - const int Ph = - ImageSize->height / PatchSize + (ImageSize->height % PatchSize > 0); - const int Pw = - ImageSize->width / PatchSize + (ImageSize->width % PatchSize > 0); - const int ImgTokens = ImageEmbed->n_image_pos; - std::vector MRopePos; - MRopePos.resize(ImgTokens * 4); - - int32_t StPosId = NPos; - for (int Y = 0; Y < Ph; Y++) { - for (int X = 0; X < Pw; X++) { - int I = Y * Pw + X; - MRopePos[I] = StPosId; - MRopePos[I + ImgTokens] = StPosId + Y; - MRopePos[I + ImgTokens * 2] = StPosId + X; - MRopePos[I + ImgTokens * 3] = 0; - } - } - - int32_t Processed = 0; - std::vector BatchMRopePos; - BatchMRopePos.resize(ImgTokens * 4); - - for (int64_t I = 0; I < ImgTokens; I += NBatch) { - int64_t NEval = ImgTokens - I; - if (NEval > NBatch) { - NEval = NBatch; - } - - std::fill(BatchMRopePos.begin(), BatchMRopePos.end(), 0); - std::copy_n(&MRopePos[Processed], NEval, BatchMRopePos.data()); - std::copy_n(&MRopePos[ImgTokens * 1 + Processed], NEval, - &BatchMRopePos[NEval * 1]); - std::copy_n(&MRopePos[ImgTokens * 2 + Processed], NEval, - &BatchMRopePos[NEval * 2]); - std::copy_n(&MRopePos[ImgTokens * 3 + Processed], NEval, - &BatchMRopePos[NEval * 3]); - - llama_batch Batch = { - static_cast(NEval), // n_tokens - nullptr, // token - (ImageEmbed->embed + I * NEmbd), // embed - BatchMRopePos.data(), // pos - nullptr, // n_seq_id - nullptr, // seq_id - nullptr, // logits - }; - if (llama_decode(LlamaCxt, Batch)) { - RET_ERROR(false, "evaluateQwen2vlImageEmbed: fail to eval."sv) - } - NPos += static_cast(NEval); - Processed += static_cast(NEval); - } - return true; -} - // Evaluate tokens. Construct the tokens into batch and decode. ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, llama_batch &Batch, int &NPos, @@ -2235,30 +2166,13 @@ ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, NEval = static_cast(GraphRef.Params.n_batch); } - // LlamaPos for Qwen2VL. - static std::vector LlamaPos; - if (GraphRef.VisionModelType == VisionModel::Qwen2VL) { - LlamaPos.resize(NEval * 4); - std::fill(LlamaPos.begin(), LlamaPos.end(), 0); - for (int J = 0; J < NEval * 3; J++) { - LlamaPos[J] = NPos + (J % NEval); - } - } - // Fill the batch with pos information. fillBatch(Span(Tokens.begin() + I, NEval), GraphRef, Batch, NPos, IsLogits && I + NEval >= static_cast(Tokens.size())); - // Set the LlamaPos for Qwen2VL. - llama_pos *OriginBatchPos = Batch.pos; - if (GraphRef.VisionModelType == VisionModel::Qwen2VL) { - Batch.pos = LlamaPos.data(); - } - // Decode the batch. auto Status = llama_decode(GraphRef.LlamaContext.get(), Batch); - Batch.pos = OriginBatchPos; if (Status == 1) { RET_ERROR( ErrNo::RuntimeError, @@ -2314,71 +2228,14 @@ ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, } // Evaluate input tokens. - if (CxtRef.LlavaImageEmbd != nullptr) { - // Llava format prompt with image data. - ReturnCode = - evaluateTokens(Span(CxtRef.LlamaInputs.begin(), - CxtRef.ImagePosition), - GraphRef, CxtRef.LlamaBatch, CxtRef.NPos); - if (ReturnCode != ErrNo::Success) { - RET_ERROR(ReturnCode, - "{}: failed to evaluate input tokens before image."sv, - LogPrefix) - } - - bool EvalImageStatus = false; - switch (GraphRef.VisionModelType) { - case VisionModel::Llava: - LOG_DEBUG(GraphRef.EnableDebugLog, "{}: Eval llava image embd"sv, - LogPrefix) - EvalImageStatus = llava_eval_image_embed( - GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, - static_cast(GraphRef.Params.n_batch), &CxtRef.NPos); - LOG_DEBUG(GraphRef.EnableDebugLog, "{}: Eval llava image embd...done"sv, - LogPrefix) - break; - case VisionModel::Qwen2VL: - LOG_DEBUG(GraphRef.EnableDebugLog, "{}: Eval Qwen2VL image embd"sv, - LogPrefix) - auto *ImageSize = clip_get_load_image_size(GraphRef.ClipContext); - EvalImageStatus = evaluateQwen2vlImageEmbed( - GraphRef.LlamaContext.get(), CxtRef.LlavaImageEmbd, - static_cast(GraphRef.Params.n_batch), CxtRef.NPos, ImageSize); - LOG_DEBUG(GraphRef.EnableDebugLog, "{}: Eval Qwen2VL image embd...done"sv, - LogPrefix) - break; - } - - if (!EvalImageStatus) { - RET_ERROR(ErrNo::RuntimeError, - "{}: failed to evaluate embed image tokens."sv, LogPrefix) - } - ReturnCode = - evaluateTokens(Span( - CxtRef.LlamaInputs.begin() + CxtRef.ImagePosition, - CxtRef.LlamaInputs.size() - CxtRef.ImagePosition), - GraphRef, CxtRef.LlamaBatch, CxtRef.NPos, true); - if (ReturnCode != ErrNo::Success) { - RET_ERROR(ReturnCode, - "{}: failed to evaluate input tokens after image."sv, LogPrefix) - } - } else { - // Text only prompt. - ReturnCode = - evaluateTokens(Span(CxtRef.LlamaInputs.begin(), - CxtRef.LlamaInputs.size()), - GraphRef, CxtRef.LlamaBatch, CxtRef.NPos, true); - if (ReturnCode != ErrNo::Success) { - RET_ERROR(ReturnCode, "{}: failed to evaluate input tokens."sv, LogPrefix) - } + ReturnCode = + evaluateTokens(Span(CxtRef.LlamaInputs.begin(), + CxtRef.LlamaInputs.size()), + GraphRef, CxtRef.LlamaBatch, CxtRef.NPos, true); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, "{}: failed to evaluate input tokens."sv, LogPrefix) } - CxtRef.Conf.ImagePath = ""sv; - if (CxtRef.LlavaImageEmbd != nullptr) { - LOG_DEBUG(GraphRef.EnableDebugLog, "{}: ImageEmbd consumed"sv, LogPrefix) - llava_image_embed_free(CxtRef.LlavaImageEmbd); - CxtRef.LlavaImageEmbd = nullptr; - } return ErrNo::Success; } @@ -2803,7 +2660,6 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.Params.cb_eval_user_data = ContextParamsDefault.cb_eval_user_data; GraphRef.Params.cache_type_k = ContextParamsDefault.type_k; GraphRef.Params.cache_type_v = ContextParamsDefault.type_v; - GraphRef.Params.logits_all = ContextParamsDefault.logits_all; GraphRef.Params.embedding = ContextParamsDefault.embeddings; GraphRef.Params.no_kv_offload = !ContextParamsDefault.offload_kqv; GraphRef.Params.flash_attn = ContextParamsDefault.flash_attn; @@ -3081,9 +2937,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, auto Base64ImagePos = findBase64ImagePayload(Prompt); if (Base64ImagePos.has_value() || CxtRef.Conf.ImagePath != ""sv) { - // Prompt with image input. Check is llava or mllama case. - - // First check the projection model is loaded. + // First check the projection model is given. if (GraphRef.Params.mmproj.path == ""sv) { RET_ERROR( ErrNo::InvalidArgument, @@ -3091,140 +2945,117 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } // Make sure the projection model is loaded. - if (GraphRef.ClipContext == nullptr) { - LOG_INFO( - true, - "setInput: Load the clip model. Because llama.cpp disabled the GPU support "sv - "for CLIP, the step of loading images in CLIP can only use the "sv - "CPU, which may result in reduced efficiency. (You can refer to "sv - "PR https://github.com/ggerganov/llama.cpp/pull/10896)"sv) - GraphRef.ClipContext = clip_model_load( - GraphRef.Params.mmproj.path.c_str(), GraphRef.EnableLog ? 1 : 0); - if (GraphRef.ClipContext == nullptr) { - RET_ERROR(ErrNo::InvalidArgument, - "setInput: unable to load the clip model."sv) - } - if (clip_is_qwen2vl(GraphRef.ClipContext)) { - GraphRef.VisionModelType = VisionModel::Qwen2VL; - LOG_INFO(true, "setInput: Qwen2vl model loaded."sv) + if (GraphRef.VisionContext == nullptr) { + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: initialize mtmd context."sv) + // Initialize the mtmd context. + mtmd_context_params VisionContextParams = mtmd_context_params_default(); + std::string VisionPromptImagePlaceholderStr(VisionPromptImagePlaceholder); + VisionContextParams.image_marker = + VisionPromptImagePlaceholderStr.c_str(); + VisionContextParams.use_gpu = GraphRef.Params.mmproj_use_gpu; + VisionContextParams.n_threads = GraphRef.Params.cpuparams.n_threads; + VisionContextParams.print_timings = + GraphRef.EnableLog || GraphRef.EnableDebugLog; + if (GraphRef.EnableDebugLog) { + VisionContextParams.verbosity = GGML_LOG_LEVEL_DEBUG; + } else if (GraphRef.EnableLog) { + VisionContextParams.verbosity = GGML_LOG_LEVEL_INFO; } else { - GraphRef.VisionModelType = VisionModel::Llava; - LOG_INFO(true, "setInput: Llava model loaded."sv) + VisionContextParams.verbosity = GGML_LOG_LEVEL_NONE; + } + GraphRef.VisionContext.reset( + mtmd_init_from_file(GraphRef.Params.mmproj.path.c_str(), + GraphRef.LlamaModel.get(), VisionContextParams)); + if (GraphRef.VisionContext == nullptr) { + RET_ERROR(ErrNo::InvalidArgument, + "setInput: unable to load the mmproj model."sv) } + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: initialize mtmd context...Done"sv) } - // Prompt with image. - if (GraphRef.ClipContext != nullptr) { - // Llava case. - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: handle llava format prompt."sv) - - // Show some warnings. - if (GraphRef.Params.n_ctx < 4096) { - LOG_INFO( - GraphRef.EnableLog, - "setInput: Context size is {}, we recommend context size >= 2048 when "sv - "using llava-v1.5 and context size >= 4096 when using llava-v1.6 "sv - "for better results."sv, - GraphRef.Params.n_ctx) - } + // Show some warnings for context size. + if (GraphRef.Params.n_ctx < 4096) { + LOG_INFO( + GraphRef.EnableLog, + "setInput: Context size is {}, we recommend context size >= 4096 when using multimodal models for better results"sv, + GraphRef.Params.n_ctx) + } - // Get image embed. - // Follow this link for the supported image formats: - // https://github.com/ggerganov/llama.cpp/blob/master/common/stb_image.h + // Get the image bitmaps. + // Follow this link for the supported image formats: + // https://github.com/ggml-org/llama.cpp/blob/master/common/stb_image.h + mtmd::bitmaps Bitmaps; + if (GraphRef.VisionContext != nullptr) { if (Base64ImagePos.has_value()) { + // Load the image bitmap from the base64 image. LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: Compute image embd from the base64 image."sv) + "setInput: load the image bitmap from the base64 image."sv) // Extract the payload and image type from the prompt. - auto Payload = extractBase64ImagePayload(Prompt, *Base64ImagePos, - LlavaPromptImagePlaceholder); + std::optional, std::string>> Payload = + extractBase64ImagePayload(Prompt, *Base64ImagePos, + VisionPromptImagePlaceholder); if (Payload.has_value()) { - // Only regenerate the image embedding if the - // always-regenerate-image-embd is on or the image embedding is not - // yet computed. - if (CxtRef.LlavaImageEmbd == nullptr || - CxtRef.Conf.AlwaysRegenerateImageEmbd) { - // Free existing image embedding if regeneration is needed - if (CxtRef.LlavaImageEmbd != nullptr) { - llava_image_embed_free(CxtRef.LlavaImageEmbd); - CxtRef.LlavaImageEmbd = nullptr; - } - - // Create a new image embedding - CxtRef.LlavaImageEmbd = llava_image_embed_make_with_bytes( - GraphRef.ClipContext, - static_cast(GraphRef.Params.cpuparams.n_threads), - Payload->first.data(), static_cast(Payload->first.size())); - } else { - LOG_DEBUG( - GraphRef.EnableDebugLog, - "setInput: Previous image embd is not yet consumed. Use the cached base64 image embd instead of computing a new one"sv) + // Create the new image bitmap. + mtmd::bitmap Bitmap(mtmd_helper_bitmap_init_from_buf( + Payload->first.data(), Payload->first.size())); + if (Bitmap.ptr == nullptr) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: unable to load the image from base64 paylaod."sv) } + Bitmaps.entries.push_back(std::move(Bitmap)); } LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: Compute image embd from the base64 image...Done"sv) } else { - // Only regenerate the image embedding if the - // always-regenerate-image-embd is on or the image embedding is not yet - // computed. - if (CxtRef.LlavaImageEmbd == nullptr || - CxtRef.Conf.AlwaysRegenerateImageEmbd) { - // Free existing image embedding if regeneration is needed - if (CxtRef.LlavaImageEmbd != nullptr) { - llava_image_embed_free(CxtRef.LlavaImageEmbd); - CxtRef.LlavaImageEmbd = nullptr; - } - - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: Compute image embd from file: {}"sv, - CxtRef.Conf.ImagePath) - // Load the image from the file. - CxtRef.LlavaImageEmbd = llava_image_embed_make_with_filename( - GraphRef.ClipContext, - static_cast(GraphRef.Params.cpuparams.n_threads), - CxtRef.Conf.ImagePath.c_str()); - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: Compute image embd from file: {}...Done"sv, - CxtRef.Conf.ImagePath) - } else { - LOG_DEBUG( - GraphRef.EnableDebugLog, - "setInput: Previous image embd is not yet consumed. Use the cached image embd instead of computing a new one"sv) + // Load the image from the file. + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: load the image bitmap from file: {}"sv, + CxtRef.Conf.ImagePath) + mtmd::bitmap Bitmap( + mtmd_helper_bitmap_init_from_file(CxtRef.Conf.ImagePath.c_str())); + if (Bitmap.ptr == nullptr) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: unable to load the image bitmap from file: {}."sv, + CxtRef.Conf.ImagePath) } + Bitmaps.entries.push_back(std::move(Bitmap)); + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: load the image bitmap from file: {}...Done"sv, + CxtRef.Conf.ImagePath) } - if (CxtRef.LlavaImageEmbd == nullptr) { - RET_ERROR(ErrNo::InvalidArgument, - "setInput: llava unable to load the image."sv) - } + } - // We split prompt by as placeholder and save the position. - auto PlaceholderPosition = Prompt.find(LlavaPromptImagePlaceholder); - if (PlaceholderPosition == std::string::npos) { - RET_ERROR( - ErrNo::InvalidArgument, - "setInput: unable to find the placeholder in the llava prompt."sv) - } - std::string PromptBeforeImage = Prompt.substr(0, PlaceholderPosition); - std::string PromptAfterImage = Prompt.substr( - PlaceholderPosition + LlavaPromptImagePlaceholder.length()); - std::vector EmbdInputBeforeImage = - common_tokenize(GraphRef.LlamaContext.get(), PromptBeforeImage, - AddSpecial, ParseSpecial); - // Do not add special token (such as , , ... tokens.) to the - // tokens after the image. - std::vector EmbdInputAfterImage = common_tokenize( - GraphRef.LlamaContext.get(), PromptAfterImage, false, ParseSpecial); - CxtRef.ImagePosition = EmbdInputBeforeImage.size(); - CxtRef.LlamaInputs.reserve(EmbdInputBeforeImage.size() + - EmbdInputAfterImage.size()); - CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.end(), - EmbdInputBeforeImage.begin(), - EmbdInputBeforeImage.end()); - CxtRef.LlamaInputs.insert(CxtRef.LlamaInputs.end(), - EmbdInputAfterImage.begin(), - EmbdInputAfterImage.end()); - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: handle llava format prompt...Done"sv) + // Tokenize the prompt. + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize the mtmd prompt"sv) + GraphRef.VisionInputChunks.reset(mtmd_input_chunks_init()); + mtmd_input_text MtmdText; + MtmdText.text = Prompt.c_str(); + MtmdText.add_special = AddSpecial; + MtmdText.parse_special = ParseSpecial; + std::vector BitmapsPtr = Bitmaps.c_ptr(); + int32_t Res = mtmd_tokenize(GraphRef.VisionContext.get(), + GraphRef.VisionInputChunks.get(), &MtmdText, + BitmapsPtr.data(), BitmapsPtr.size()); + if (Res != 0) { + RET_ERROR(ErrNo::InvalidArgument, + "setInput: unable to tokenize the mtmd prompt."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: tokenize the mtmd prompt...Done"sv) + + // Get the number of input tokens (for the metadata). + CxtRef.LlamaNInputs = 0; + for (size_t ChunkIndex = 0; + ChunkIndex < mtmd_input_chunks_size(GraphRef.VisionInputChunks.get()); + ++ChunkIndex) { + size_t NTokens = 0; + const mtmd_input_chunk *Chunk = + mtmd_input_chunks_get(GraphRef.VisionInputChunks.get(), ChunkIndex); + mtmd_input_chunk_get_tokens_text(Chunk, &NTokens); + CxtRef.LlamaNInputs += NTokens; } } else if (GraphRef.TextToSpeech == true) { // TTS prompt. @@ -3235,6 +3066,9 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, "setInput: failed to tokenize tts prompt."sv) } LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt...Done"sv) + + // Get the number of input tokens (for the metadata). + CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); } else { // Text only prompt. LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt"sv) @@ -3242,8 +3076,10 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, AddSpecial, ParseSpecial); LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt...Done"sv) + + // Get the number of input tokens (for the metadata). + CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); } - CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); // Maybe currently in the compute_single mode. Reset the computing. CxtRef.ComputeSingleStarted = false; @@ -3289,9 +3125,26 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { common_sampler_reset(CxtRef.LlamaSampler); // Evaluate the input tokens. - auto ReturnCode = evaluateInput(GraphRef, CxtRef, "compute"sv); - if (ReturnCode != ErrNo::Success) { - return ReturnCode; + ErrNo ReturnCode = ErrNo::Success; + if (GraphRef.VisionContext == nullptr) { + // Text only prompt. + ReturnCode = evaluateInput(GraphRef, CxtRef, "compute"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + } else { + // Multimodal prompt. + llama_pos NewNPos; + int32_t Res = mtmd_helper_eval_chunks( + GraphRef.VisionContext.get(), GraphRef.LlamaContext.get(), + GraphRef.VisionInputChunks.get(), CxtRef.NPos, + /* seq_id */ 0, static_cast(CxtRef.CurrentBatchSize), + /* logits_last */ true, &NewNPos); + CxtRef.NPos = NewNPos; + if (Res != 0) { + RET_ERROR(ErrNo::InvalidArgument, + "compute: unable to eval the mtmd prompt."sv) + } } // Main prediction loop. @@ -3373,9 +3226,25 @@ Expect computeSingle(WasiNNEnvironment &Env, common_sampler_reset(CxtRef.LlamaSampler); // Evaluate the input tokens. - ReturnCode = evaluateInput(GraphRef, CxtRef, "computeSingle"sv); - if (ReturnCode != ErrNo::Success) { - return ReturnCode; + if (GraphRef.VisionContext == nullptr) { + // Text only prompt. + ReturnCode = evaluateInput(GraphRef, CxtRef, "compute"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + } else { + // Multimodal prompt. + llama_pos NewNPos; + int32_t Res = mtmd_helper_eval_chunks( + GraphRef.VisionContext.get(), GraphRef.LlamaContext.get(), + GraphRef.VisionInputChunks.get(), CxtRef.NPos, + /* seq_id */ 0, static_cast(CxtRef.CurrentBatchSize), + /* logits_last */ true, &NewNPos); + CxtRef.NPos = NewNPos; + if (Res != 0) { + RET_ERROR(ErrNo::InvalidArgument, + "compute: unable to eval the mtmd prompt."sv) + } } } @@ -3437,11 +3306,15 @@ Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { GraphRef.LlamaContext.reset(); LOG_DEBUG(IsDebugLog, "unload: free llama context...Done"sv) } - if (GraphRef.ClipContext != nullptr) { - LOG_DEBUG(IsDebugLog, "unload: free clip context"sv) - clip_free(GraphRef.ClipContext); - GraphRef.ClipContext = nullptr; - LOG_DEBUG(IsDebugLog, "unload: free clip context...Done"sv) + if (GraphRef.VisionContext != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free mtmd context"sv) + GraphRef.VisionContext.reset(); + LOG_DEBUG(IsDebugLog, "unload: free mtmd context...Done"sv) + } + if (GraphRef.VisionInputChunks != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free mtmd chunks"sv) + GraphRef.VisionInputChunks.reset(); + LOG_DEBUG(IsDebugLog, "unload: free mtmd chunks...Done"sv) } if (GraphRef.TTSModel != nullptr) { LOG_DEBUG(IsDebugLog, "unload: free TTS model"sv) @@ -3466,15 +3339,6 @@ Expect finalizeExecCtx(WasiNNEnvironment &Env, auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); LOG_DEBUG(GraphRef.EnableDebugLog, "finalize_execution_context"sv) - // TODO: Move the resource deallocation into the destructor. - if (CxtRef.LlavaImageEmbd != nullptr) { - LOG_DEBUG(GraphRef.EnableDebugLog, - "finalize_execution_context: free llava image embed"sv) - llava_image_embed_free(CxtRef.LlavaImageEmbd); - CxtRef.LlavaImageEmbd = nullptr; - LOG_DEBUG(GraphRef.EnableDebugLog, - "finalize_execution_context: free llava image embed...Done"sv) - } if (CxtRef.LlamaSampler != nullptr) { LOG_DEBUG(GraphRef.EnableDebugLog, "finalize_execution_context: free compute_single sampler"sv) diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index b9122037..7f45ac13 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #endif @@ -33,11 +34,6 @@ enum class EmbdNormalizeType : int32_t { PNorm = 3, }; -enum class VisionModel : uint8_t { - Llava = 0, - Qwen2VL = 1, -}; - struct TTSSpeakerProfile { std::string Text; std::string Data; @@ -64,9 +60,9 @@ struct Graph { // Model context: llama_model_ptr LlamaModel = nullptr; llama_context_ptr LlamaContext = nullptr; - // Clip context (for llava): - struct clip_ctx *ClipContext = nullptr; - VisionModel VisionModelType = VisionModel::Llava; + // Multimodal context: + mtmd::context_ptr VisionContext = nullptr; + mtmd::input_chunks_ptr VisionInputChunks = nullptr; // Text-to-speech: bool TextToSpeech = false; std::string TTSOutputFilePath = "output.wav"; @@ -87,8 +83,6 @@ struct Context { // Llama outputs: std::vector LlamaOutputs; std::vector LlamaOutputTokens; - // Preserve for llava - struct llava_image_embed *LlavaImageEmbd = nullptr; // Data for computing: bool ComputeSingleStarted = false; struct common_sampler *LlamaSampler = nullptr; From efaca71b180b6d1e14499758c307bd13d50e8fe8 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Wed, 21 May 2025 15:48:48 +0800 Subject: [PATCH 554/623] feat(WASI-NN/MLX): support quantized gemma3 model (#4099) * feat(WASI-NN/MLX): support quantized gemma3 model Signed-off-by: grorge * fix(WASI-NN/MLX): typo and log prefix Signed-off-by: grorge --------- Signed-off-by: grorge --- plugins/wasi_nn/MLX/mlx/base.cpp | 29 ++++++++++++++------- plugins/wasi_nn/MLX/mlx/base.h | 7 +++-- plugins/wasi_nn/MLX/mlx/embedding.cpp | 4 ++- plugins/wasi_nn/MLX/mlx/embedding.h | 8 ++++-- plugins/wasi_nn/MLX/mlx/linear.cpp | 4 ++- plugins/wasi_nn/MLX/mlx/linear.h | 7 +++-- plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp | 24 ++++++++++------- plugins/wasi_nn/wasinn_mlx.cpp | 29 ++++++++++++++++++++- 8 files changed, 84 insertions(+), 28 deletions(-) diff --git a/plugins/wasi_nn/MLX/mlx/base.cpp b/plugins/wasi_nn/MLX/mlx/base.cpp index 636a78f1..4bff15f9 100644 --- a/plugins/wasi_nn/MLX/mlx/base.cpp +++ b/plugins/wasi_nn/MLX/mlx/base.cpp @@ -20,15 +20,25 @@ void Module::update(std::unordered_map Parameters) { } } -std::shared_ptr Module::toQuantized(int GroupSize, int Bits) { +std::shared_ptr Module::toQuantized( + int GroupSize, int Bits, const std::string &Prefix, + const std::unordered_map &Parameters) { + auto NewPrefix = Prefix + Name + (Prefix.empty() && Name.empty() ? "" : "."); for (auto &[K, V] : Submodules) { - const auto OldModule = V; - auto Weights = V->Parameters.find("weight"); - if (Weights != V->Parameters.end() && - Weights->second.shape().back() % GroupSize != 0) { - continue; + if (V->hasQuantize()) { + auto Weights = V->Parameters.find("weight"); + if (Weights != V->Parameters.end() && !Parameters.empty()) { + if (Parameters.count(NewPrefix + V->Name + ".scales") == 0) { + continue; + } + } + if (Weights != V->Parameters.end() && + Weights->second.shape().back() % GroupSize != 0) { + continue; + } } - V = V->toQuantized(GroupSize, Bits); + V = V->toQuantized(GroupSize, Bits, + Prefix + Name + (Name.empty() ? "" : "."), Parameters); } return shared_from_this(); } @@ -60,12 +70,13 @@ void Module::apply(std::string Key, mx::array Value) { std::unordered_map Module::getWeigts(const std::string &Prefix) { std::unordered_map Weights; + auto NewPrefix = Prefix + Name; for (auto &[K, V] : Submodules) { - auto Subweights = V->getWeigts(Prefix + Name + "."); + auto Subweights = V->getWeigts(NewPrefix + (NewPrefix.empty() ? "" : ".")); Weights.insert(Subweights.begin(), Subweights.end()); } for (auto &[K, V] : Parameters) { - Weights.insert({Prefix + Name + "." + K, V}); + Weights.insert({NewPrefix + (NewPrefix.empty() ? "" : ".") + K, V}); } return Weights; } diff --git a/plugins/wasi_nn/MLX/mlx/base.h b/plugins/wasi_nn/MLX/mlx/base.h index 24ce0e42..3332d434 100644 --- a/plugins/wasi_nn/MLX/mlx/base.h +++ b/plugins/wasi_nn/MLX/mlx/base.h @@ -36,8 +36,11 @@ class Module : public std::enable_shared_from_this { std::unordered_map getWeigts(const std::string &Prefix = "model"); - virtual std::shared_ptr toQuantized(int GroupSize = 64, - int Bits = 4); + virtual std::shared_ptr toQuantized( + int GroupSize = 64, int Bits = 4, const std::string &Prefix = "", + const std::unordered_map &Parameters = {}); + + virtual bool hasQuantize() { return false; } void update(std::unordered_map Parameters); diff --git a/plugins/wasi_nn/MLX/mlx/embedding.cpp b/plugins/wasi_nn/MLX/mlx/embedding.cpp index 8ceb7f57..be9367ef 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.cpp +++ b/plugins/wasi_nn/MLX/mlx/embedding.cpp @@ -18,7 +18,9 @@ mx::array Embedding::asLinear(mx::array Input) { return matmul(Input, transpose(Parameters.at("weight"))); } -std::shared_ptr Embedding::toQuantized(int GroupSize, int Bits) { +std::shared_ptr +Embedding::toQuantized(int GroupSize, int Bits, const std::string &, + const std::unordered_map &) { auto QuantModel = QuantizedEmbedding::fromEmbedding( std::dynamic_pointer_cast(shared_from_this()), GroupSize, Bits); diff --git a/plugins/wasi_nn/MLX/mlx/embedding.h b/plugins/wasi_nn/MLX/mlx/embedding.h index 778dd1bc..6e6b88d5 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.h +++ b/plugins/wasi_nn/MLX/mlx/embedding.h @@ -28,8 +28,12 @@ class Embedding : public Module { mx::array asLinear(mx::array Input); - std::shared_ptr toQuantized(int GroupSize = 64, - int Bits = 4) override; + std::shared_ptr + toQuantized(int GroupSize = 64, int Bits = 4, const std::string &Prefix = "", + const std::unordered_map &Parameters = {}) + override; + + virtual bool hasQuantize() override { return true; } }; } // namespace mlx::core::nn diff --git a/plugins/wasi_nn/MLX/mlx/linear.cpp b/plugins/wasi_nn/MLX/mlx/linear.cpp index 99f3ad3f..1de1ab79 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.cpp +++ b/plugins/wasi_nn/MLX/mlx/linear.cpp @@ -17,7 +17,9 @@ mx::array Linear::forward(mx::array Input) { return matmul(Input, transpose(Parameters.at("weight"))); } -std::shared_ptr Linear::toQuantized(int GroupSize, int Bits) { +std::shared_ptr +Linear::toQuantized(int GroupSize, int Bits, const std::string &, + const std::unordered_map &) { auto QuantModel = QuantizedLinear::fromLinear( std::dynamic_pointer_cast(shared_from_this()), GroupSize, Bits); QuantModel->Name = Name; diff --git a/plugins/wasi_nn/MLX/mlx/linear.h b/plugins/wasi_nn/MLX/mlx/linear.h index 5afa7a19..6d880776 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.h +++ b/plugins/wasi_nn/MLX/mlx/linear.h @@ -34,9 +34,12 @@ class Linear : public Module { } virtual mx::array forward(mx::array Input); + std::shared_ptr + toQuantized(int GroupSize = 64, int Bits = 4, const std::string &Prefix = "", + const std::unordered_map &Parameters = {}) + override; - std::shared_ptr toQuantized(int GroupSize = 64, - int Bits = 4) override; + virtual bool hasQuantize() override { return true; } }; } // namespace mlx::core::nn diff --git a/plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp b/plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp index e1ea9a24..e54efa39 100644 --- a/plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp +++ b/plugins/wasi_nn/MLX/model/gemma3/gemma3.cpp @@ -193,16 +193,6 @@ std::shared_ptr Model::fromPretrained(const std::string &ModelPath) { ModelConfigObj.TextConfig = TextConfig::fromDict(Obj["text_config"].get_object().value()); auto Model = std::make_shared(gemma3::Model(ModelConfigObj)); - auto QuantResult = Obj["quantization"].get_object(); - if (!QuantResult.error()) { - auto GroupSize = - static_cast(QuantResult.value()["group_size"].get_int64()); - auto Bits = static_cast(QuantResult.value()["bits"].get_int64()); - spdlog::info("Quantizing model to {} bits, {} group size.", Bits, - GroupSize); - Model = std::dynamic_pointer_cast( - Model->toQuantized(GroupSize, Bits)); - } std::vector WeightFiles; for (auto &P : std::filesystem::directory_iterator(Path)) { if (P.path().extension() == ".safetensors") @@ -220,6 +210,17 @@ std::shared_ptr Model::fromPretrained(const std::string &ModelPath) { } Weights = Model->sanitize(Weights); Weights = gemma3::VisionModel(ModelConfigObj.VisionConfig).sanitize(Weights); + auto QuantResult = Obj["quantization"].get_object(); + if (!QuantResult.error()) { + auto GroupSize = + static_cast(QuantResult.value()["group_size"].get_int64()); + auto Bits = static_cast(QuantResult.value()["bits"].get_int64()); + spdlog::info( + "[WASI-NN] MLX backend: Quantizing model to {} bits, {} group size."sv, + Bits, GroupSize); + Model = std::dynamic_pointer_cast( + Model->toQuantized(GroupSize, Bits, "", Weights)); + } Model->update(Weights); return Model; } @@ -234,6 +235,9 @@ Model::sanitize(const std::unordered_map &Weights) { if (Pos != std::string::npos) Key.replace(Pos, std::string("vision_model").length(), "vision_tower"); } + if (Key.find("model") == 0) { + Key.replace(0, std::string("model").length(), ""); + } Sanitized.insert({Key, Pair.second}); } return Sanitized; diff --git a/plugins/wasi_nn/wasinn_mlx.cpp b/plugins/wasi_nn/wasinn_mlx.cpp index 868c2aab..fd6d0966 100644 --- a/plugins/wasi_nn/wasinn_mlx.cpp +++ b/plugins/wasi_nn/wasinn_mlx.cpp @@ -149,6 +149,25 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.QBits = QBits; GraphRef.GroupSize = GroupSize; } + if (Doc.at_key("quantization").error() == simdjson::SUCCESS) { + auto QuantResult = Doc["quantization"].get_object(); + auto Err = QuantResult.value()["group_size"].get().get( + GraphRef.GroupSize); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the group size from quantization option."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + Err = QuantResult.value()["bits"].get().get(GraphRef.QBits); + if (Err) { + spdlog::error( + "[WASI-NN] MLX backend: Unable to retrieve the group size from quantization option."sv); + Env.deleteGraph(GId); + return ErrNo::InvalidArgument; + } + GraphRef.IsQuantized = true; + } std::unordered_map Weights; // Handle the model path. @@ -265,7 +284,11 @@ Expect load(WASINN::WasiNNEnvironment &Env, } if (GraphRef.QBits != 0 && GraphRef.GroupSize != 0 && GraphRef.IsQuantized) { - GraphRef.Model->toQuantized(GraphRef.GroupSize, GraphRef.QBits); + spdlog::info( + "[WASI-NN] MLX backend: load Quantized model with q_bits: {} and group_size: {}"sv, + GraphRef.QBits, GraphRef.GroupSize); + GraphRef.Model->toQuantized(GraphRef.GroupSize, GraphRef.QBits, "", + Weights); } // Load weight. @@ -285,6 +308,9 @@ Expect load(WASINN::WasiNNEnvironment &Env, } if (GraphRef.QBits != 0 && GraphRef.GroupSize != 0 && !GraphRef.IsQuantized) { + spdlog::info( + "[WASI-NN] MLX backend: Quantize model with q_bits: {} and group_size: {}"sv, + GraphRef.QBits, GraphRef.GroupSize); GraphRef.Model->toQuantized(GraphRef.GroupSize, GraphRef.QBits); } @@ -432,6 +458,7 @@ Expect compute(WasiNNEnvironment &Env, const auto End{std::chrono::steady_clock::now()}; const std::chrono::duration ElapsedSeconds{End - Start}; if (GraphRef.EnableDebugLog) { + spdlog::info("[WASI-NN] MLX backend: Generate {} tokens."sv, TokenListSize); spdlog::info("Elapsed time: {} s. TPS: {}.", ElapsedSeconds.count(), TokenListSize / ElapsedSeconds.count()); } From 1178a4bc5137bae95cdba3b0ed41cb65732233bd Mon Sep 17 00:00:00 2001 From: dm4 Date: Sat, 24 May 2025 00:11:19 +0800 Subject: [PATCH 555/623] feat(wasi-nn): bump ggml to b5463; bump wasi-nn plugin to 0.1.22 (#4119) Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 15 +++------------ plugins/wasi_nn/wasinn_ggml.h | 1 - 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 1dee2ade..bc017e30 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -8,14 +8,12 @@ #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML #include "simdjson.h" #include -#include #include #include #include #include #include #include -#include #include #include @@ -1318,14 +1316,6 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, "Unable to retrieve the display-prompt option."sv) } } - if (Doc.at_key("dump-kv-cache").error() == simdjson::SUCCESS) { - auto Err = - Doc["dump-kv-cache"].get().get(GraphRef.Params.dump_kv_cache); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the dump-kv-cache option."sv) - } - } if (Doc.at_key("no-kv-offload").error() == simdjson::SUCCESS) { auto Err = Doc["no-kv-offload"].get().get(GraphRef.Params.no_kv_offload); @@ -2950,7 +2940,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Initialize the mtmd context. mtmd_context_params VisionContextParams = mtmd_context_params_default(); std::string VisionPromptImagePlaceholderStr(VisionPromptImagePlaceholder); - VisionContextParams.image_marker = + VisionContextParams.media_marker = VisionPromptImagePlaceholderStr.c_str(); VisionContextParams.use_gpu = GraphRef.Params.mmproj_use_gpu; VisionContextParams.n_threads = GraphRef.Params.cpuparams.n_threads; @@ -2968,7 +2958,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, GraphRef.LlamaModel.get(), VisionContextParams)); if (GraphRef.VisionContext == nullptr) { RET_ERROR(ErrNo::InvalidArgument, - "setInput: unable to load the mmproj model."sv) + "setInput: unable to load the mmproj model {}."sv, + GraphRef.Params.mmproj.path) } LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: initialize mtmd context...Done"sv) diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index 7f45ac13..e7991353 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -11,7 +11,6 @@ #include #include #include -#include #include #include #endif From b4f1c2fd1de758cf147f4ecf7ca4c10bfad4ef71 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 3 Jun 2025 00:33:56 +0800 Subject: [PATCH 556/623] feat(WASI-NN,ggml): bump llama.cpp b5575 Signed-off-by: hydai --- plugins/wasi_nn/wasinn_ggml.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index bc017e30..ce73d61d 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -11,9 +11,10 @@ #include #include #include +#include #include -#include #include +#include #include #include @@ -2989,7 +2990,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, if (Payload.has_value()) { // Create the new image bitmap. mtmd::bitmap Bitmap(mtmd_helper_bitmap_init_from_buf( - Payload->first.data(), Payload->first.size())); + GraphRef.VisionContext.get(), Payload->first.data(), + Payload->first.size())); if (Bitmap.ptr == nullptr) { RET_ERROR( ErrNo::InvalidArgument, @@ -3004,8 +3006,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: load the image bitmap from file: {}"sv, CxtRef.Conf.ImagePath) - mtmd::bitmap Bitmap( - mtmd_helper_bitmap_init_from_file(CxtRef.Conf.ImagePath.c_str())); + mtmd::bitmap Bitmap(mtmd_helper_bitmap_init_from_file( + GraphRef.VisionContext.get(), CxtRef.Conf.ImagePath.c_str())); if (Bitmap.ptr == nullptr) { RET_ERROR( ErrNo::InvalidArgument, From 2ef8369861db48f9cd9fb492fa8ed16bdc86bdb7 Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 4 Jun 2025 14:59:03 +0800 Subject: [PATCH 557/623] chore(cmake): seperate the version between wasi_nn and wasi_nn_rpc Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 896f1de1..e7575aca 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -23,6 +23,9 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) +set(WASMEDGE_WASI_NN_VERSION "0.1.23" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") + # Handle the version of the WASI-NN plugin string(REPLACE "." ";" WASI_NN_VERSION_LIST ${WASMEDGE_WASI_NN_VERSION}) list(GET WASI_NN_VERSION_LIST 0 WASI_NN_VERSION_MAJOR) From 172d2923fdd7d419499bc98d0d67667dc5206fa3 Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 4 Jun 2025 21:30:39 +0800 Subject: [PATCH 558/623] chore(docker): bump the llvm from 12 to 18 on ubuntu 20.04 Signed-off-by: hydai --- utils/docker/Dockerfile.ubuntu-base | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/utils/docker/Dockerfile.ubuntu-base b/utils/docker/Dockerfile.ubuntu-base index a6ed6278..8ef2aa42 100644 --- a/utils/docker/Dockerfile.ubuntu-base +++ b/utils/docker/Dockerfile.ubuntu-base @@ -21,17 +21,21 @@ FROM base AS deps-20 RUN curl -sSf https://apt.kitware.com/kitware-archive.sh | sh RUN apt-get install -y cmake +RUN wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - \ + && echo "deb http://apt.llvm.org/focal llvm-toolchain-focal-18 main" | tee /etc/apt/sources.list.d/llvm.list + RUN apt-get install -y \ - llvm-12-dev \ - liblld-12-dev \ - clang-12 + llvm-18-dev \ + liblld-18-dev \ + libpolly-18-dev \ + clang-18 -RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-12 100 && \ - update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-12 100 && \ - update-alternatives --install /usr/bin/llvm-strip llvm-strip /usr/bin/llvm-strip-12 100 +RUN update-alternatives --install /usr/bin/clang clang /usr/bin/clang-18 100 && \ + update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-18 100 && \ + update-alternatives --install /usr/bin/llvm-strip llvm-strip /usr/bin/llvm-strip-18 100 -ENV CC=/usr/bin/clang-12 -ENV CXX=/usr/bin/clang++-12 +ENV CC=/usr/bin/clang-18 +ENV CXX=/usr/bin/clang++-18 ### deps for ubuntu 22.04 ### FROM base AS deps-22 From fad68b24727bdfe3216bb5544e7b4be04a345b7f Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 5 Jun 2025 16:46:04 +0800 Subject: [PATCH 559/623] fix(wasi-nn): clear the context before mtmd evaluation (#4143) Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/wasinn_ggml.cpp | 34 ++++++++++++++++----------------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index e7575aca..fe6894b3 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -23,7 +23,7 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) -set(WASMEDGE_WASI_NN_VERSION "0.1.23" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_VERSION "0.1.24" CACHE STRING "WasmEdge WASI-NN library version") set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") # Handle the version of the WASI-NN plugin diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index ce73d61d..7fba0fbc 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -2181,6 +2181,17 @@ ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, return ErrNo::Success; } +// Clear the context and reset the sampler. +void clearContext(Graph &GraphRef, Context &CxtRef) noexcept { + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext"sv) + llama_kv_self_clear(GraphRef.LlamaContext.get()); + common_sampler_reset(CxtRef.LlamaSampler); + CxtRef.NPos = 0; + CxtRef.LlamaOutputs.clear(); + CxtRef.LlamaOutputTokens.clear(); + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext...Done"sv) +} + // Evaluate the input tokens. Clean all inputs if succeeded. ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, std::string_view LogPrefix) noexcept { @@ -2190,19 +2201,6 @@ ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, LogPrefix) } - // Clear the outputs. - LOG_DEBUG(GraphRef.EnableDebugLog, - "{}: clear the previous output and tokens"sv, LogPrefix) - CxtRef.LlamaOutputs.clear(); - CxtRef.LlamaOutputTokens.clear(); - LOG_DEBUG(GraphRef.EnableDebugLog, - "{}: clear the previous output and tokens...Done"sv, LogPrefix) - - // Clear the llama context. - llama_kv_self_clear(GraphRef.LlamaContext.get()); - - // Prepare variables; - CxtRef.NPos = 0; // Get the context size. const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); // Minus 4 for the special tokens. (Such as , , ... tokens.) @@ -3110,13 +3108,13 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); LOG_DEBUG(GraphRef.EnableDebugLog, "compute") + // Clear the context and reset the sampler. + clearContext(GraphRef, CxtRef); + if (GraphRef.Params.embedding) { return getEmbedding(GraphRef, CxtRef); } - // Reset the sampler for a new computation. - common_sampler_reset(CxtRef.LlamaSampler); - // Evaluate the input tokens. ErrNo ReturnCode = ErrNo::Success; if (GraphRef.VisionContext == nullptr) { @@ -3215,8 +3213,8 @@ Expect computeSingle(WasiNNEnvironment &Env, if (!CxtRef.ComputeSingleStarted) { CxtRef.ComputeSingleStarted = true; - // Reset the sampler for a new computation. - common_sampler_reset(CxtRef.LlamaSampler); + // Clear the context and reset the sampler. + clearContext(GraphRef, CxtRef); // Evaluate the input tokens. if (GraphRef.VisionContext == nullptr) { From e0848061791451076efe353122ba9020ab796255 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 5 Jun 2025 17:22:56 +0800 Subject: [PATCH 560/623] feat(wasi-nn): bump ggml to b5593 and wasi-nn plugin to 0.1.25 (#4144) Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index fe6894b3..2ea5b88d 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -23,7 +23,7 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) -set(WASMEDGE_WASI_NN_VERSION "0.1.24" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_VERSION "0.1.25" CACHE STRING "WasmEdge WASI-NN library version") set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") # Handle the version of the WASI-NN plugin From 340ddccc3c83311047012786df897587a6d513bc Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 12 Jun 2025 15:19:29 +0800 Subject: [PATCH 561/623] fix(wasi-nn): fix n_ubatch assignment (#4163) Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 7fba0fbc..a11ca8ec 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -273,7 +273,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the ubatch-size option."sv) } - GraphRef.Params.n_batch = static_cast(UBatchSize); + GraphRef.Params.n_ubatch = static_cast(UBatchSize); } if (Doc.at_key("n-keep").error() == simdjson::SUCCESS) { int64_t NKeep; From 9e16864f3c8009ce84913b8a9c3aa8663e0a9850 Mon Sep 17 00:00:00 2001 From: dm4 Date: Thu, 12 Jun 2025 16:03:39 +0800 Subject: [PATCH 562/623] feat(wasi-nn): bump ggml to b5640 and wasi-nn plugin to 0.1.26 (#4164) Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/wasinn_ggml.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 2ea5b88d..3f1d81eb 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -23,7 +23,7 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) -set(WASMEDGE_WASI_NN_VERSION "0.1.25" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_VERSION "0.1.26" CACHE STRING "WasmEdge WASI-NN library version") set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") # Handle the version of the WASI-NN plugin diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index a11ca8ec..fbe59054 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -2184,7 +2184,7 @@ ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, // Clear the context and reset the sampler. void clearContext(Graph &GraphRef, Context &CxtRef) noexcept { LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext"sv) - llama_kv_self_clear(GraphRef.LlamaContext.get()); + llama_memory_clear(llama_get_memory(GraphRef.LlamaContext.get()), true); common_sampler_reset(CxtRef.LlamaSampler); CxtRef.NPos = 0; CxtRef.LlamaOutputs.clear(); @@ -2913,7 +2913,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, // Clear the llama context. LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context"sv) - llama_kv_self_clear(GraphRef.LlamaContext.get()); + llama_memory_clear(llama_get_memory(GraphRef.LlamaContext.get()), true); LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context...Done"sv) // Set the input. From ab46abb1596a72f94057f358b2ae995a146a110c Mon Sep 17 00:00:00 2001 From: "Wang-Yang, Li" <7088579+LFsWang@users.noreply.github.com> Date: Tue, 17 Jun 2025 01:46:55 +0800 Subject: [PATCH 563/623] fix(wasi_nn): update input tensor shape and simplify model compilation (#4185) Signed-off-by: Sylveon --- plugins/wasi_nn/wasinn_openvino.cpp | 16 ++-------------- test/plugins/wasi_nn/wasi_nn.cpp | 22 +++++++++++++++++++++- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/plugins/wasi_nn/wasinn_openvino.cpp b/plugins/wasi_nn/wasinn_openvino.cpp index 729edd02..8f359943 100644 --- a/plugins/wasi_nn/wasinn_openvino.cpp +++ b/plugins/wasi_nn/wasinn_openvino.cpp @@ -101,23 +101,11 @@ Expect setInput(WASINN::WasiNNEnvironment &Env, try { ov::element::Type InputType = ov::element::f32; - ov::Shape InputShape = {1, 224, 224, 3}; + ov::Shape InputShape = {1, 3, 224, 224}; ov::Tensor InputTensor = ov::Tensor(InputType, InputShape, Tensor.Tensor.data()); - const ov::Layout InputLayout{"NHWC"}; - ov::preprocess::PrePostProcessor PPP(GraphRef.OpenVINOModel); - PPP.input() - .tensor() - .set_shape(InputShape) - .set_element_type(InputType) - .set_layout(InputLayout); - PPP.input().preprocess().resize( - ov::preprocess::ResizeAlgorithm::RESIZE_LINEAR); - PPP.input().model().set_layout("NCHW"); - PPP.output().tensor().set_element_type(ov::element::f32); - auto model = PPP.build(); ov::CompiledModel CompiledModel = - Env.OpenVINOCore.compile_model(model, "CPU"); + Env.OpenVINOCore.compile_model(GraphRef.OpenVINOModel, "CPU"); CxtRef.OpenVINOInferRequest = CompiledModel.create_infer_request(); CxtRef.OpenVINOInferRequest.set_input_tensor(Index, InputTensor); } catch (const std::exception &EX) { diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 40c8f8e7..50e286ce 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -130,13 +130,33 @@ TEST(WasiNNTest, OpenVINOBackend) { WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); // Load the files. - std::vector TensorData = + std::vector TensorDataLegecy = readEntireFile("./wasinn_openvino_fixtures/tensor-1x224x224x3-f32.bgr"); std::vector XmlRead = readEntireFile("./wasinn_openvino_fixtures/mobilenet.xml"); std::vector WeightRead = readEntireFile("./wasinn_openvino_fixtures/mobilenet.bin"); + // Convert the NHWC to NCHW format. + // For historical reasons, the OpenVINO model expects the input tensor in + // NCHW format, while the input tensor is in NHWC format. + // https://github.com/intel/openvino-rs/blob/v0.3.3/crates/openvino/tests/fixtures/mobilenet/build.sh#L39 + // https://github.com/intel/openvino-rs/blob/v0.8.0/crates/openvino/tests/classify-mobilenet.rs#L34 + ASSERT_EQ(TensorDataLegecy.size(), 3 * 224 * 224 * 4); + std::vector TensorData(TensorDataLegecy.size()); + + for (size_t C = 0; C < 3; ++C) { + for (size_t H = 0; H < 224; ++H) { + for (size_t W = 0; W < 224; ++W) { + size_t Loc = H * 224 + W; + for (size_t B = 0; B < 4; ++B) { + TensorData[(C * 224 * 224 + Loc) * 4 + B] = + TensorDataLegecy[(3 * Loc + C) * 4 + B]; + } + } + } + } + std::vector TensorDim{1, 3, 224, 224}; uint32_t BuilderPtr = UINT32_C(0); uint32_t LoadEntryPtr = UINT32_C(0); From deb659f1b49b630804172cf3dfac6a667e3d45a0 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Fri, 20 Jun 2025 23:21:44 +0800 Subject: [PATCH 564/623] feat(WASI-NN/mlx): support tensor input (#4189) * feat(WASI-NN/mlx): support tensor input Signed-off-by: grorge * fix(WASI-NN/MLX): rewrite eerror message Signed-off-by: grorge --------- Signed-off-by: grorge --- plugins/wasi_nn/wasinn_mlx.cpp | 176 ++++++++++++++++++++++++++++++--- 1 file changed, 165 insertions(+), 11 deletions(-) diff --git a/plugins/wasi_nn/wasinn_mlx.cpp b/plugins/wasi_nn/wasinn_mlx.cpp index fd6d0966..637f49a8 100644 --- a/plugins/wasi_nn/wasinn_mlx.cpp +++ b/plugins/wasi_nn/wasinn_mlx.cpp @@ -36,6 +36,165 @@ std::string loadBytesFromFile(const std::string &Path) { return Data; } +mx::array fromBytes(const Span &Bytes) { + if (Bytes.size() < 9) { + spdlog::error( + "[WASI-NN] MLX backend: Tensor data must be at least 9 bytes long, current size: {}."sv, + Bytes.size()); + return mx::array({0.0f}); + } + + size_t Offset = 0; + + uint8_t RtypeValue = Bytes[Offset]; + Offset += 1; + + uint32_t DimBufLen; + std::memcpy(&DimBufLen, &Bytes[Offset], 4); + Offset += 4; + + std::vector Shape; + for (size_t I = 0; I < DimBufLen; I += 4) { + uint32_t Dim; + std::memcpy(&Dim, &Bytes[Offset + I], 4); + Shape.push_back(static_cast(Dim)); + } + Offset += DimBufLen; + + uint32_t DataBufLen; + std::memcpy(&DataBufLen, &Bytes[Offset], 4); + Offset += 4; + + const void *DataPtr = &Bytes[Offset]; + switch (RtypeValue) { + case 0: { // F16 + return mx::array(static_cast(DataPtr), Shape, + mx::float16); + } + case 1: { // F32 + return mx::array(static_cast(DataPtr), Shape, mx::float32); + } + case 2: { // F64 + return mx::array(static_cast(DataPtr), Shape, mx::float64); + } + case 3: { // U8 + return mx::array(static_cast(DataPtr), Shape, mx::uint8); + } + case 4: { // I32 + return mx::array(static_cast(DataPtr), Shape, mx::int32); + } + case 5: { // I64 + return mx::array(static_cast(DataPtr), Shape, mx::int64); + } + default: + spdlog::error("[WASI-NN] MLX backend: Unsupported rtype: {}", RtypeValue); + return mx::array({0.0f}); + } +} + +std::vector toBytes(const mx::array &Arr) { + std::vector Result; + + uint8_t RtypeValue; + switch (Arr.dtype()) { + case mx::float16: + RtypeValue = 0; + break; + case mx::float32: + RtypeValue = 1; + break; + case mx::float64: + RtypeValue = 2; + break; + case mx::uint8: + RtypeValue = 3; + break; + case mx::int32: + RtypeValue = 4; + break; + case mx::int64: + RtypeValue = 5; + break; + default: + spdlog::error( + "[WASI-NN] MLX backend: Unsupported dtype to convert to Processor Tensor"sv); + return Result; + } + Result.push_back(RtypeValue); + + std::vector DimBuf; + auto Shape = Arr.shape(); + for (int Dim : Shape) { + uint32_t DimData = static_cast(Dim); + const uint8_t *DimBytes = reinterpret_cast(&DimData); + DimBuf.insert(DimBuf.end(), DimBytes, DimBytes + 4); + } + + uint32_t DimBufLen = static_cast(DimBuf.size()); + const uint8_t *DimLenBytes = reinterpret_cast(&DimBufLen); + Result.insert(Result.end(), DimLenBytes, DimLenBytes + 4); + + Result.insert(Result.end(), DimBuf.begin(), DimBuf.end()); + + std::vector DataBuf; + mx::eval(Arr); + + switch (Arr.dtype()) { + case mx::float16: { + auto *Data = Arr.data(); + size_t ByteSize = Arr.nbytes(); + const uint8_t *DataBytes = reinterpret_cast(Data); + DataBuf.insert(DataBuf.end(), DataBytes, DataBytes + ByteSize); + break; + } + case mx::float32: { + auto *Data = Arr.data(); + size_t ByteSize = Arr.nbytes(); + const uint8_t *DataBytes = reinterpret_cast(Data); + DataBuf.insert(DataBuf.end(), DataBytes, DataBytes + ByteSize); + break; + } + case mx::float64: { + auto *Data = Arr.data(); + size_t ByteSize = Arr.nbytes(); + const uint8_t *DataBytes = reinterpret_cast(Data); + DataBuf.insert(DataBuf.end(), DataBytes, DataBytes + ByteSize); + break; + } + case mx::uint8: { + auto *Data = Arr.data(); + size_t ByteSize = Arr.nbytes(); + DataBuf.insert(DataBuf.end(), Data, Data + ByteSize); + break; + } + case mx::int32: { + auto *Data = Arr.data(); + size_t ByteSize = Arr.nbytes(); + const uint8_t *DataBytes = reinterpret_cast(Data); + DataBuf.insert(DataBuf.end(), DataBytes, DataBytes + ByteSize); + break; + } + case mx::int64: { + auto *Data = Arr.data(); + size_t ByteSize = Arr.nbytes(); + const uint8_t *DataBytes = reinterpret_cast(Data); + DataBuf.insert(DataBuf.end(), DataBytes, DataBytes + ByteSize); + break; + } + default: + spdlog::error("[WASI-NN] MLX backend: Unsupported MLX dtype for conversion " + "to Rust Tensor"sv); + break; + } + + uint32_t DataBufLen = static_cast(DataBuf.size()); + const uint8_t *DataLenBytes = reinterpret_cast(&DataBufLen); + Result.insert(Result.end(), DataLenBytes, DataLenBytes + 4); + Result.insert(Result.end(), DataBuf.begin(), DataBuf.end()); + + return Result; +} + enum AnswerSataus { STOP, WAIT, @@ -353,15 +512,12 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, std::string(reinterpret_cast(Tensor.Tensor.data()), Tensor.Tensor.size()); } else if (GraphRef.ModelArch == "vlm") { - auto TensorPath = - std::string(reinterpret_cast(Tensor.Tensor.data()), - Tensor.Tensor.size()); if (Index == 0) { - std::get(CxtRef.Inputs).Prompt = mx::load(TensorPath); + std::get(CxtRef.Inputs).Prompt = fromBytes(Tensor.Tensor); } else if (Index == 1) { - std::get(CxtRef.Inputs).Pixel = mx::load(TensorPath); + std::get(CxtRef.Inputs).Pixel = fromBytes(Tensor.Tensor); } else if (Index == 2) { - std::get(CxtRef.Inputs).Mask = mx::load(TensorPath); + std::get(CxtRef.Inputs).Mask = fromBytes(Tensor.Tensor); } else { spdlog::error("[WASI-NN] MLX backend: Index out of range."sv); return ErrNo::InvalidArgument; @@ -396,11 +552,9 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, } else if (GraphRef.ModelArch == "vlm") { auto *Output = std::get_if(&CxtRef.Outputs); if (Output != nullptr) { - auto Answer = Output->Answer; - std::string OutputPath = "Answer.npy"; - mx::save(OutputPath, Answer); - std::copy_n(OutputPath.data(), OutputPath.size(), OutBuffer.data()); - BytesWritten = OutputPath.size(); + auto OutputBytes = toBytes(Output->Answer); + std::copy_n(OutputBytes.data(), OutputBytes.size(), OutBuffer.data()); + BytesWritten = OutputBytes.size(); } else { spdlog::error("[WASI-NN] MLX backend: No output found."sv); return ErrNo::InvalidArgument; From 551fa564ce4a24a094461a598b8616adc5b3eab2 Mon Sep 17 00:00:00 2001 From: "Wang-Yang, Li" <7088579+LFsWang@users.noreply.github.com> Date: Tue, 24 Jun 2025 14:58:50 +0800 Subject: [PATCH 565/623] feat(wasi-nn/openvino): add device string retrieval and improve input tensor handling (#4199) Signed-off-by: Sylveon --- plugins/wasi_nn/wasinn_openvino.cpp | 30 +++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/wasinn_openvino.cpp b/plugins/wasi_nn/wasinn_openvino.cpp index 8f359943..ab5ee6cd 100644 --- a/plugins/wasi_nn/wasinn_openvino.cpp +++ b/plugins/wasi_nn/wasinn_openvino.cpp @@ -10,6 +10,25 @@ using namespace std::literals; namespace WasmEdge::Host::WASINN::OpenVINO { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_OPENVINO + +static Expect +GetDeviceString(WASINN::Device TargetDevice, + std::string &DeviceString) noexcept { + switch (TargetDevice) { + case Device::AUTO: + case Device::CPU: + DeviceString = "CPU"; + break; + case Device::GPU: + DeviceString = "GPU"; + break; + default: + spdlog::error("[WASI-NN] Unsupported device type"sv); + return WASINN::ErrNo::InvalidArgument; + } + return WASINN::ErrNo::Success; +} + Expect load(WASINN::WasiNNEnvironment &Env, Span> Builders, WASINN::Device Device, uint32_t &GraphId) noexcept { @@ -101,11 +120,18 @@ Expect setInput(WASINN::WasiNNEnvironment &Env, try { ov::element::Type InputType = ov::element::f32; - ov::Shape InputShape = {1, 3, 224, 224}; + ov::Shape InputShape(Tensor.Dimension.data(), + Tensor.Dimension.data() + Tensor.Dimension.size()); ov::Tensor InputTensor = ov::Tensor(InputType, InputShape, Tensor.Tensor.data()); + std::string Device; + if (!GetDeviceString(GraphRef.TargetDevice, Device)) { + spdlog::error("[WASI-NN] Failed to get device string for OpenVINO."sv); + return WASINN::ErrNo::InvalidArgument; + } + ov::CompiledModel CompiledModel = - Env.OpenVINOCore.compile_model(GraphRef.OpenVINOModel, "CPU"); + Env.OpenVINOCore.compile_model(GraphRef.OpenVINOModel, Device); CxtRef.OpenVINOInferRequest = CompiledModel.create_infer_request(); CxtRef.OpenVINOInferRequest.set_input_tensor(Index, InputTensor); } catch (const std::exception &EX) { From 26a468ab24078f6e26500134d3a7e8351da44a47 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 1 Jul 2025 16:51:35 +0800 Subject: [PATCH 566/623] chore(bpf): reorder the dual licenses to avoid the false positive license issue (#4217) Signed-off-by: hydai --- test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c | 2 +- test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c b/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c index 1049aa5f..0e8facb0 100644 --- a/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c +++ b/test/plugins/wasm_bpf/assets/bpf-sources/simple_map.bpf.c @@ -1,4 +1,4 @@ -// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +// SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC #define SEC(name) __attribute__((section(name), used)) diff --git a/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c b/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c index 0a0b0e8d..1e170cc8 100644 --- a/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c +++ b/test/plugins/wasm_bpf/assets/bpf-sources/simple_ringbuf.bpf.c @@ -1,4 +1,4 @@ -// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +// SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0 // SPDX-FileCopyrightText: 2019-2024 Second State INC #define SEC(name) __attribute__((section(name), used)) From df7103b61f885484950a7af52651d388353cf522 Mon Sep 17 00:00:00 2001 From: Karan Lokchandani <135950363+PhantomInTheWire@users.noreply.github.com> Date: Thu, 3 Jul 2025 09:44:17 +0530 Subject: [PATCH 567/623] fix(wasi-nn, ggml): Empty generation returned if n_predict is -1 or -2 (#4208) Signed-off-by: Karan --- plugins/wasi_nn/wasinn_ggml.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index fbe59054..3c6987a9 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -744,7 +744,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, GraphRef.Params.sampling.dry_allowed_length = static_cast(DryAllowedLength); } - if (Doc.at_key("dry-penalty-last-n").error() == simdjson::SUCCESS) { + if (Doc.at_key("dry-last-n-penalty").error() == simdjson::SUCCESS) { int64_t DryLastNPenalty; auto Err = Doc["dry-last-n-penalty"].get().get(DryLastNPenalty); if (Err) { @@ -2260,7 +2260,9 @@ ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, } // Deal with end of text token. const llama_vocab *Vocab = llama_model_get_vocab(GraphRef.LlamaModel.get()); - if (llama_vocab_is_eog(Vocab, common_sampler_last(CxtRef.LlamaSampler))) { + // Only stop on EOS if GraphRef.Params.sampling.ignore_eos is false. + if (!GraphRef.Params.sampling.ignore_eos && + llama_vocab_is_eog(Vocab, common_sampler_last(CxtRef.LlamaSampler))) { LOG_INFO(GraphRef.EnableLog, "sampleOutput: EOS token found."sv) return ErrNo::EndOfSequence; } @@ -3140,8 +3142,10 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { // Main prediction loop. LOG_DEBUG(GraphRef.EnableDebugLog, "compute: enter main prediction loop"sv) - int64_t NRemain = CxtRef.Conf.NPredict; - while (NRemain-- > 0) { + int64_t NPredict = + CxtRef.Conf.NPredict < 0 ? INT32_MAX : CxtRef.Conf.NPredict; + + while (NPredict-- > 0) { ReturnCode = sampleOutput(GraphRef, CxtRef); if (ReturnCode != ErrNo::Success) { break; From 9231acb0c6eb9bf0261cd335e41b14e123cb670f Mon Sep 17 00:00:00 2001 From: Yi Date: Fri, 4 Jul 2025 20:53:21 +0800 Subject: [PATCH 568/623] docker(ubuntu): simplify image names for plugins (#4235) fix(docker): simplify names for plugins Signed-off-by: Yi Huang --- utils/docker/docker-bake.ubuntu.hcl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/utils/docker/docker-bake.ubuntu.hcl b/utils/docker/docker-bake.ubuntu.hcl index f37c51da..b2012d5f 100644 --- a/utils/docker/docker-bake.ubuntu.hcl +++ b/utils/docker/docker-bake.ubuntu.hcl @@ -76,7 +76,11 @@ function "tags-backports" { function "tags-simplified" { params = [target, ubuntu, toolchain] - result = target == "base" && toolchain == "clang" ? "ubuntu-${ubuntu}" : "" + result = toolchain == "clang" ? join("-", compact([ + "ubuntu", + ubuntu, + target == "plugins" ? "plugins" : "", + ])) : "" } function "tags" { From 61530acc7a98b8f4fcf10d1dc7ffcbc09b0bc190 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 8 Jul 2025 15:19:09 +0800 Subject: [PATCH 569/623] feat(WASI-NN,ggml): bump llama.cpp b5835 (#4249) The reranking option is changed in this PR[1]. [1]: https://github.com/ggml-org/llama.cpp/pull/14208 Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/wasinn_ggml.cpp | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 3f1d81eb..d4cfc39a 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -23,7 +23,7 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) -set(WASMEDGE_WASI_NN_VERSION "0.1.26" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_VERSION "0.1.27" CACHE STRING "WasmEdge WASI-NN library version") set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") # Handle the version of the WASI-NN plugin diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 3c6987a9..7a9c7aee 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -1386,7 +1386,10 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, GraphRef.Params.embd_sep = EmbdSep; } if (Doc.at_key("reranking").error() == simdjson::SUCCESS) { - auto Err = Doc["reranking"].get().get(GraphRef.Params.reranking); + bool Reranking = false; + auto Err = Doc["reranking"].get().get(Reranking); + GraphRef.Params.embedding = true; + GraphRef.Params.pooling_type = LLAMA_POOLING_TYPE_RANK; if (Err) { RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the reranking option."sv) From d5f46f9c30902ccded807ced5e1bdddae9c0b33d Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 15 Jul 2025 13:37:26 +0800 Subject: [PATCH 570/623] feat(WASI-NN,ggml): bump llama.cpp b5896 (#4263) Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index d4cfc39a..746eab98 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -23,7 +23,7 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) -set(WASMEDGE_WASI_NN_VERSION "0.1.27" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_VERSION "0.1.28" CACHE STRING "WasmEdge WASI-NN library version") set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") # Handle the version of the WASI-NN plugin From 0e10b36d455a58515f69c069aa75a234e1dcc888 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 22 Jul 2025 15:24:50 +0800 Subject: [PATCH 571/623] chore(lineguard): apply lineguard for all files (#4265) Signed-off-by: hydai --- plugins/wasm_bpf/README.md | 26 +++--- plugins/wasmedge_ffmpeg/CMakeLists.txt | 4 +- test/plugins/wasi_crypto/CMakeLists.txt | 2 +- test/plugins/wasmedge_ffmpeg/CMakeLists.txt | 2 +- thirdparty/wasi_crypto/api.hpp | 94 ++++++++++----------- 5 files changed, 64 insertions(+), 64 deletions(-) diff --git a/plugins/wasm_bpf/README.md b/plugins/wasm_bpf/README.md index 18438062..fed20dcd 100644 --- a/plugins/wasm_bpf/README.md +++ b/plugins/wasm_bpf/README.md @@ -93,18 +93,18 @@ After building, you can find the plug-in `./build/plugins/wasm_bpf/libwasmedgePl Set `WASMEDGE_PLUGIN_PATH=./build/plugins/wasm_bpf/` and run wasmedge: ```console -# WASMEDGE_PLUGIN_PATH=./build/plugins/wasm_bpf/ ./build/tools/wasmedge/wasmedge execve.wasm - -[289150] node -> /bin/sh -c which ps -[289151] sh -> which ps -[289152] node -> /bin/sh -c /usr/bin/ps -ax -o pid=,ppid=,pcpu=,pmem=,c -[289153] sh -> /usr/bin/ps -ax -o pid=,ppid=,pcpu=,pmem=,command= -[289154] node -> /bin/sh -c "/root/.vscode-server-insiders/bin/96a795cc -[289155] sh -> /root/.vscode-server-insiders/bin/96a795cc0 245632 245678 289148 -[289156] cpuUsage.sh -> sed -n s/^cpu\s//p /proc/stat -[289157] cpuUsage.sh -> cat /proc/245632/stat -[289158] cpuUsage.sh -> cat /proc/245678/stat -[289159] cpuUsage.sh -> cat /proc/289148/stat -[289160] cpuUsage.sh -> sleep 1 +# WASMEDGE_PLUGIN_PATH=./build/plugins/wasm_bpf/ ./build/tools/wasmedge/wasmedge execve.wasm + +[289150] node -> /bin/sh -c which ps +[289151] sh -> which ps +[289152] node -> /bin/sh -c /usr/bin/ps -ax -o pid=,ppid=,pcpu=,pmem=,c +[289153] sh -> /usr/bin/ps -ax -o pid=,ppid=,pcpu=,pmem=,command= +[289154] node -> /bin/sh -c "/root/.vscode-server-insiders/bin/96a795cc +[289155] sh -> /root/.vscode-server-insiders/bin/96a795cc0 245632 245678 289148 +[289156] cpuUsage.sh -> sed -n s/^cpu\s//p /proc/stat +[289157] cpuUsage.sh -> cat /proc/245632/stat +[289158] cpuUsage.sh -> cat /proc/245678/stat +[289159] cpuUsage.sh -> cat /proc/289148/stat +[289160] cpuUsage.sh -> sleep 1 ^C ``` diff --git a/plugins/wasmedge_ffmpeg/CMakeLists.txt b/plugins/wasmedge_ffmpeg/CMakeLists.txt index 0a1ff4a8..ad7b867b 100644 --- a/plugins/wasmedge_ffmpeg/CMakeLists.txt +++ b/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -14,7 +14,7 @@ pkg_check_modules(LIBAV REQUIRED IMPORTED_TARGET wasmedge_add_library(wasmedgePluginWasmEdgeFFmpeg SHARED - + avcodec/avCodecContext.cpp avcodec/avCodec.cpp avcodec/avCodecParameters.cpp @@ -52,7 +52,7 @@ wasmedge_add_library(wasmedgePluginWasmEdgeFFmpeg swscale/swscale_func.cpp swscale/module.cpp - + ffmpeg_env.cpp ) diff --git a/test/plugins/wasi_crypto/CMakeLists.txt b/test/plugins/wasi_crypto/CMakeLists.txt index 8935d066..69fc8482 100644 --- a/test/plugins/wasi_crypto/CMakeLists.txt +++ b/test/plugins/wasi_crypto/CMakeLists.txt @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 2019-2024 Second State INC -wasmedge_add_executable(wasiCryptoTests +wasmedge_add_executable(wasiCryptoTests aeads.cpp asymmetric.cpp common.cpp diff --git a/test/plugins/wasmedge_ffmpeg/CMakeLists.txt b/test/plugins/wasmedge_ffmpeg/CMakeLists.txt index f1c0ca15..1f580014 100644 --- a/test/plugins/wasmedge_ffmpeg/CMakeLists.txt +++ b/test/plugins/wasmedge_ffmpeg/CMakeLists.txt @@ -28,7 +28,7 @@ wasmedge_add_executable(wasmedgeFFmpegTests avutil/avPixfmt.cpp swresample/swresample_func.cpp - + swscale/swscale_func.cpp utils.cpp diff --git a/thirdparty/wasi_crypto/api.hpp b/thirdparty/wasi_crypto/api.hpp index 71bed3d2..c5eb6ddb 100644 --- a/thirdparty/wasi_crypto/api.hpp +++ b/thirdparty/wasi_crypto/api.hpp @@ -68,7 +68,7 @@ enum __wasi_crypto_errno_e_t : uint16_t { /** * An error occurred when trying to during a conversion from a host type to a guest type. - * + * * Only an internal bug can throw this error. */ __WASI_CRYPTO_ERRNO_GUEST_ERROR = 1, @@ -105,14 +105,14 @@ enum __wasi_crypto_errno_e_t : uint16_t { /** * An invalid or incompatible key was supplied. - * + * * The key may not be valid, or was generated for a different algorithm or parameters set. */ __WASI_CRYPTO_ERRNO_INVALID_KEY = 8, /** * The currently selected algorithm doesn't support the requested output length. - * + * * This error is thrown by non-extensible hash functions, when requesting an output size larger than they produce out of a single block. */ __WASI_CRYPTO_ERRNO_INVALID_LENGTH = 9, @@ -124,18 +124,18 @@ enum __wasi_crypto_errno_e_t : uint16_t { /** * A secure random numbers generator is not available. - * + * * The requested operation requires random numbers, but the host cannot securely generate them at the moment. */ __WASI_CRYPTO_ERRNO_RNG_ERROR = 11, /** * An error was returned by the underlying cryptography library. - * + * * The host may be running out of memory, parameters may be incompatible with the chosen implementation of an algorithm or another unexpected error may have happened. - * + * * Ideally, the specification should provide enough details and guidance to make this error impossible to ever be thrown. - * + * * Realistically, the WASI crypto module cannot possibly cover all possible error types implementations can return, especially since some of these may be language-specific. * This error can thus be thrown when other error types are not suitable, and when the original error comes from the cryptographic primitives themselves and not from the WASI module. */ @@ -163,26 +163,26 @@ enum __wasi_crypto_errno_e_t : uint16_t { /** * An internal error occurred. - * + * * This error is reserved to internal consistency checks, and must only be sent if the internal state of the host remains safe after an inconsistency was detected. */ __WASI_CRYPTO_ERRNO_INTERNAL_ERROR = 17, /** * Too many handles are currently open, and a new one cannot be created. - * + * * Implementations are free to represent handles as they want, and to enforce limits to limit resources usage. */ __WASI_CRYPTO_ERRNO_TOO_MANY_HANDLES = 18, /** * A key was provided, but the chosen algorithm doesn't support keys. - * + * * This is returned by symmetric operations. - * + * * Many hash functions, in particular, do not support keys without being used in particular constructions. * Blindly ignoring a key provided by mistake while trying to open a context for such as function could cause serious security vulnerabilities. - * + * * These functions must refuse to create the context and return this error instead. */ __WASI_CRYPTO_ERRNO_KEY_NOT_SUPPORTED = 19, @@ -194,16 +194,16 @@ enum __wasi_crypto_errno_e_t : uint16_t { /** * The provided authentication tag is invalid or incompatible with the current algorithm. - * + * * This error is returned by decryption functions and tag verification functions. - * + * * Unlike `verification_failed`, this error code is returned when the tag cannot possibly verify for any input. */ __WASI_CRYPTO_ERRNO_INVALID_TAG = 21, /** * The requested operation is incompatible with the current scheme. - * + * * For example, the `symmetric_state_encrypt()` function cannot complete if the selected construction is a key derivation function. * This error code will be returned instead. */ @@ -211,9 +211,9 @@ enum __wasi_crypto_errno_e_t : uint16_t { /** * A nonce is required. - * + * * Most encryption schemes require a nonce. - * + * * In the absence of a nonce, the WASI cryptography module can automatically generate one, if that can be done safely. The nonce can be retrieved later with the `symmetric_state_option_get()` function using the `nonce` parameter. * If automatically generating a nonce cannot be done safely, the module never falls back to an insecure option and requests an explicit nonce by throwing that error. */ @@ -226,7 +226,7 @@ enum __wasi_crypto_errno_e_t : uint16_t { /** * The named option was not set. - * + * * The caller tried to read the value of an option that was not set. * This error is used to make the distinction between an empty option, and an option that was not set and left to its default value. */ @@ -234,30 +234,30 @@ enum __wasi_crypto_errno_e_t : uint16_t { /** * A key or key pair matching the requested identifier cannot be found using the supplied information. - * + * * This error is returned by a secrets manager via the `keypair_from_id()` function. */ __WASI_CRYPTO_ERRNO_NOT_FOUND = 26, /** * The algorithm requires parameters that haven't been set. - * + * * Non-generic options are required and must be given by building an `options` set and giving that object to functions instantiating that algorithm. */ __WASI_CRYPTO_ERRNO_PARAMETERS_MISSING = 27, /** * A requested computation is not done yet, and additional calls to the function are required. - * + * * Some functions, such as functions generating key pairs and password stretching functions, can take a long time to complete. - * + * * In order to avoid a host call to be blocked for too long, these functions can return prematurely, requiring additional calls with the same parameters until they complete. */ __WASI_CRYPTO_ERRNO_IN_PROGRESS = 28, /** * Multiple keys have been provided, but they do not share the same type. - * + * * This error is returned when trying to build a key pair from a public key and a secret key that were created for different and incompatible algorithms. */ __WASI_CRYPTO_ERRNO_INCOMPATIBLE_KEYS = 29, @@ -399,7 +399,7 @@ static_assert(alignof(__wasi_algorithm_type_e_t) == 2, "witx calculated align"); /** * Version of a managed key. - * + * * A version can be an arbitrary `u64` integer, with the exception of some reserved values. */ using __wasi_version_t = uint64_t; @@ -425,12 +425,12 @@ static_assert(alignof(__wasi_timestamp_t) == 8, "witx calculated align"); /** * Handle for functions returning output whose size may be large or not known in advance. - * + * * An `array_output` object contains a host-allocated byte array. - * + * * A guest can get the size of that array after a function returns in order to then allocate a buffer of the correct size. * In addition, the content of such an object can be consumed by a guest in a streaming fashion. - * + * * An `array_output` handle is automatically closed after its full content has been consumed. */ using __wasi_array_output_t = int32_t; @@ -440,9 +440,9 @@ static_assert(alignof(__wasi_array_output_t) == 4, "witx calculated align"); /** * A set of options. - * + * * This type is used to set non-default parameters. - * + * * The exact set of allowed options depends on the algorithm being used. */ using __wasi_options_t = int32_t; @@ -452,7 +452,7 @@ static_assert(alignof(__wasi_options_t) == 4, "witx calculated align"); /** * A handle to the optional secrets management facilities offered by a host. - * + * * This is used to generate, retrieve and invalidate managed keys. */ using __wasi_secrets_manager_t = int32_t; @@ -470,9 +470,9 @@ static_assert(alignof(__wasi_keypair_t) == 4, "witx calculated align"); /** * A state to absorb data to be signed. - * + * * After a signature has been computed or verified, the state remains valid for further operations. - * + * * A subsequent signature would sign all the data accumulated since the creation of the state object. */ using __wasi_signature_state_t = int32_t; @@ -514,7 +514,7 @@ static_assert(alignof(__wasi_signature_verification_state_t) == 4, "witx calcula /** * A state to perform symmetric operations. - * + * * The state is not reset nor invalidated after an option has been performed. * Incremental updates and sessions are thus supported. */ @@ -525,9 +525,9 @@ static_assert(alignof(__wasi_symmetric_state_t) == 4, "witx calculated align"); /** * A symmetric key. - * + * * The key can be imported from raw bytes, or can be a reference to a managed key. - * + * * If it was imported, the host will wipe it from memory as soon as the handle is closed. */ using __wasi_symmetric_key_t = int32_t; @@ -537,13 +537,13 @@ static_assert(alignof(__wasi_symmetric_key_t) == 4, "witx calculated align"); /** * An authentication tag. - * + * * This is an object returned by functions computing authentication tags. - * + * * A tag can be compared against another tag (directly supplied as raw bytes) in constant time with the `symmetric_tag_verify()` function. - * + * * This object type can't be directly created from raw bytes. They are only returned by functions computing MACs. - * + * * The host is responsible for securely wiping them from memory on close. */ using __wasi_symmetric_tag_t = int32_t; @@ -565,7 +565,7 @@ static_assert(alignof(__wasi_opt_options_u_e_t) == 1, "witx calculated align"); /** * An optional options set. - * + * * This union simulates an `Option` type to make the `options` parameter of some functions optional. */ union __wasi_opt_options_u_t { @@ -594,7 +594,7 @@ static_assert(alignof(__wasi_opt_symmetric_key_u_e_t) == 1, "witx calculated ali /** * An optional symmetric key. - * + * * This union simulates an `Option` type to make the `symmetric_key` parameter of some functions optional. */ union __wasi_opt_symmetric_key_u_t { @@ -616,7 +616,7 @@ static_assert(alignof(__wasi_u64_t) == 8, "witx calculated align"); /** * `$kx_keypair` is just an alias for `$keypair` - * + * * However, bindings may want to define a specialized type `kx_keypair` as a super class of `keypair`. */ using __wasi_kx_keypair_t = __wasi_keypair_t; @@ -626,7 +626,7 @@ static_assert(alignof(__wasi_kx_keypair_t) == 4, "witx calculated align"); /** * `$kx_publickey` is just an alias for `$publickey` - * + * * However, bindings may want to define a specialized type `kx_publickey` as a super class of `publickey`, with additional methods such as `dh`. */ using __wasi_kx_publickey_t = __wasi_publickey_t; @@ -636,7 +636,7 @@ static_assert(alignof(__wasi_kx_publickey_t) == 4, "witx calculated align"); /** * `$kx_secretkey` is just an alias for `$secretkey` - * + * * However, bindings may want to define a specialized type `kx_secretkey` as a super class of `secretkeykey`, with additional methods such as `dh`. */ using __wasi_kx_secretkey_t = __wasi_secretkey_t; @@ -646,7 +646,7 @@ static_assert(alignof(__wasi_kx_secretkey_t) == 4, "witx calculated align"); /** * `$signature_keypair` is just an alias for `$keypair` - * + * * However, bindings may want to define a specialized type `signature_keypair` as a super class of `keypair`, with additional methods such as `sign`. */ using __wasi_signature_keypair_t = __wasi_keypair_t; @@ -656,7 +656,7 @@ static_assert(alignof(__wasi_signature_keypair_t) == 4, "witx calculated align") /** * `$signature_publickey` is just an alias for `$publickey` - * + * * However, bindings may want to define a specialized type `signature_publickey` as a super class of `publickey`, with additional methods such as `verify`. */ using __wasi_signature_publickey_t = __wasi_publickey_t; @@ -666,7 +666,7 @@ static_assert(alignof(__wasi_signature_publickey_t) == 4, "witx calculated align /** * `$signature_secretkey` is just an alias for `$secretkey` - * + * * However, bindings may want to define a specialized type `signature_secretkey` as a super class of `secretkey`. */ using __wasi_signature_secretkey_t = __wasi_secretkey_t; From 2a31f284dd1f0c5491f6480d229a5d1e7e4d2d41 Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 6 Aug 2025 11:40:31 +0800 Subject: [PATCH 572/623] feat(WASI-NN,ggml): bump llama.cpp b6097, support gpt-oss (#4300) Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 746eab98..639851ff 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -23,7 +23,7 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) -set(WASMEDGE_WASI_NN_VERSION "0.1.28" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_VERSION "0.1.29" CACHE STRING "WasmEdge WASI-NN library version") set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") # Handle the version of the WASI-NN plugin From ed250fea25ae70b0583a1d139778c20a2017f0ae Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 13 Aug 2025 14:16:48 +0800 Subject: [PATCH 573/623] feat(wasi-nn): remove llama_context_default_params, use default common_params value Signed-off-by: dm4 --- plugins/wasi_nn/wasinn_ggml.cpp | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 7a9c7aee..a5fcd040 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -2630,34 +2630,6 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.Params.n_gpu_layers = ModelParamsDefault.n_gpu_layers; GraphRef.Params.mmproj.path = ""sv; GraphRef.Params.warmup = false; - // Initialize the context parameters. - llama_context_params ContextParamsDefault = llama_context_default_params(); - GraphRef.Params.cpuparams.n_threads = ContextParamsDefault.n_threads; - GraphRef.Params.n_ctx = ContextParamsDefault.n_ctx; - GraphRef.Params.n_batch = ContextParamsDefault.n_batch; - GraphRef.Params.n_ubatch = ContextParamsDefault.n_ubatch; - GraphRef.Params.cpuparams.n_threads = ContextParamsDefault.n_threads_batch; - GraphRef.Params.cpuparams_batch.n_threads = - ContextParamsDefault.n_threads_batch; - GraphRef.Params.rope_scaling_type = ContextParamsDefault.rope_scaling_type; - GraphRef.Params.pooling_type = ContextParamsDefault.pooling_type; - GraphRef.Params.attention_type = ContextParamsDefault.attention_type; - GraphRef.Params.rope_freq_base = ContextParamsDefault.rope_freq_base; - GraphRef.Params.rope_freq_scale = ContextParamsDefault.rope_freq_scale; - GraphRef.Params.yarn_ext_factor = ContextParamsDefault.yarn_ext_factor; - GraphRef.Params.yarn_attn_factor = ContextParamsDefault.yarn_attn_factor; - GraphRef.Params.yarn_beta_fast = ContextParamsDefault.yarn_beta_fast; - GraphRef.Params.yarn_beta_slow = ContextParamsDefault.yarn_beta_slow; - GraphRef.Params.yarn_orig_ctx = ContextParamsDefault.yarn_orig_ctx; - GraphRef.Params.defrag_thold = ContextParamsDefault.defrag_thold; - GraphRef.Params.cb_eval = ContextParamsDefault.cb_eval; - GraphRef.Params.cb_eval_user_data = ContextParamsDefault.cb_eval_user_data; - GraphRef.Params.cache_type_k = ContextParamsDefault.type_k; - GraphRef.Params.cache_type_v = ContextParamsDefault.type_v; - GraphRef.Params.embedding = ContextParamsDefault.embeddings; - GraphRef.Params.no_kv_offload = !ContextParamsDefault.offload_kqv; - GraphRef.Params.flash_attn = ContextParamsDefault.flash_attn; - GraphRef.Params.no_perf = ContextParamsDefault.no_perf; // Initialize the sampling parameters. const common_params_sampling SamplerParamsDefault; @@ -2666,7 +2638,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.Conf.StreamStdout = false; GraphRef.Conf.EmbdNormalize = static_cast(CommonParamsDefault.embd_normalize); - GraphRef.Conf.NPredict = ContextParamsDefault.n_ctx; + GraphRef.Conf.NPredict = GraphRef.Params.n_ctx; GraphRef.Conf.ReversePrompt = ""sv; GraphRef.Conf.ImagePath = ""sv; From 6c4d9793a532cd35c76d98a0c8f78243cbb9ad7c Mon Sep 17 00:00:00 2001 From: dm4 Date: Wed, 13 Aug 2025 14:18:15 +0800 Subject: [PATCH 574/623] feat(wasi-nn): bump ggml to b6191 and wasi-nn plugin to 0.1.30 Signed-off-by: dm4 --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/wasinn_ggml.cpp | 43 ++++++++++++++++++++++++++++++++- plugins/wasi_nn/wasinn_ggml.h | 3 +++ 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 639851ff..ca885f4b 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -23,7 +23,7 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) -set(WASMEDGE_WASI_NN_VERSION "0.1.29" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_VERSION "0.1.30" CACHE STRING "WasmEdge WASI-NN library version") set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") # Handle the version of the WASI-NN plugin diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index a5fcd040..aca618eb 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -142,6 +142,32 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, } GraphRef.Params.n_gpu_layers = static_cast(NGPULayers); } + if (Doc.at_key("cpu-moe").error() == simdjson::SUCCESS) { + bool CpuMoe; + auto Err = Doc["cpu-moe"].get().get(CpuMoe); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cpu-moe option."sv) + } + if (CpuMoe) { + GraphRef.TensorBuftOverrides.push_back("\\.ffn_(up|down|gate)_exps"); + } + } + if (Doc.at_key("n-cpu-moe").error() == simdjson::SUCCESS) { + int64_t NCpuMoe; + auto Err = Doc["n-cpu-moe"].get().get(NCpuMoe); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-cpu-moe option."sv) + } + if (NCpuMoe < 0) { + RET_ERROR(ErrNo::InvalidArgument, "Invalid n-cpu-moe value."sv) + } + for (int I = 0; I < NCpuMoe; I++) { + GraphRef.TensorBuftOverrides.push_back( + string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", I)); + } + } if (Doc.at_key("tensor-split").error() == simdjson::SUCCESS) { // The TensorSplit is a comma-separated list of non-negative values. // E.g., "3,2" presents 60% of the data to GPU 0 and 40% to GPU 1. @@ -1745,6 +1771,15 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, } } + // The tensor buffer overrides should terminated with empty pattern. + if (!GraphRef.TensorBuftOverrides.empty()) { + for (const std::string &Override : GraphRef.TensorBuftOverrides) { + GraphRef.Params.tensor_buft_overrides.push_back( + {Override.c_str(), ggml_backend_cpu_buffer_type()}); + } + GraphRef.Params.tensor_buft_overrides.push_back({nullptr, nullptr}); + } + if (GraphRef.TextToSpeech) { GraphRef.Params.sampling.top_k = 4; GraphRef.Params.sampling.samplers = { @@ -2616,7 +2651,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Initialize the plugin parameters. GraphRef.EnableLog = false; GraphRef.EnableDebugLog = false; - const common_params CommonParamsDefault; + common_params CommonParamsDefault; + CommonParamsDefault.lr.init(); GraphRef.Params = CommonParamsDefault; GraphRef.Params.n_keep = 0; GraphRef.Params.n_chunks = -1; @@ -3296,6 +3332,11 @@ Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { GraphRef.TTSContext.reset(); LOG_DEBUG(IsDebugLog, "unload: free TTS context...Done"sv) } + if (!GraphRef.TensorBuftOverrides.empty()) { + LOG_DEBUG(IsDebugLog, "unload: free tensor buffer overrides"sv) + GraphRef.TensorBuftOverrides.clear(); + LOG_DEBUG(IsDebugLog, "unload: free tensor buffer overrides...Done"sv) + } Env.deleteGraph(GraphId); Env.mdRemoveById(GraphId); diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/wasinn_ggml.h index e7991353..fb48b060 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/wasinn_ggml.h @@ -13,6 +13,8 @@ #include #include #include + +#include #endif namespace WasmEdge::Host::WASINN { @@ -56,6 +58,7 @@ struct Graph { bool EnableLog = false; bool EnableDebugLog = false; common_params Params; + std::list TensorBuftOverrides; // Model context: llama_model_ptr LlamaModel = nullptr; llama_context_ptr LlamaContext = nullptr; From 97f53675ceefad8bedf5be5f9aaf279b99a52f22 Mon Sep 17 00:00:00 2001 From: Khush Agrawal <150333865+Khushmagrawal@users.noreply.github.com> Date: Thu, 28 Aug 2025 11:37:01 +0530 Subject: [PATCH 575/623] feat(WASI-NN,bitnet): add BitNet backend support (#4253) * feat(plugins/bitnet): add BitNet backend support Signed-off-by: Khush Agrawal * ci: add BitNet backend to workflow Signed-off-by: Khush Agrawal --------- Signed-off-by: Khush Agrawal --- plugins/wasi_nn/CMakeLists.txt | 1 + plugins/wasi_nn/bitnet.patch | 39 + plugins/wasi_nn/wasinn_bitnet.cpp | 2433 +++++++++++++++++++++++++++ plugins/wasi_nn/wasinn_bitnet.h | 140 ++ plugins/wasi_nn/wasinnenv.cpp | 4 +- plugins/wasi_nn/wasinnenv.h | 1 + plugins/wasi_nn/wasinnfunc.cpp | 30 +- plugins/wasi_nn/wasinntypes.h | 4 +- test/plugins/wasi_nn/CMakeLists.txt | 13 + test/plugins/wasi_nn/wasi_nn.cpp | 352 +++- 10 files changed, 3006 insertions(+), 11 deletions(-) create mode 100644 plugins/wasi_nn/bitnet.patch create mode 100644 plugins/wasi_nn/wasinn_bitnet.cpp create mode 100644 plugins/wasi_nn/wasinn_bitnet.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index ca885f4b..95e3fc5d 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -18,6 +18,7 @@ wasmedge_add_library(wasmedgePluginWasiNN wasinn_whisper.cpp wasinn_chattts.cpp wasinn_mlx.cpp + wasinn_bitnet.cpp ) include(WASINNDeps) diff --git a/plugins/wasi_nn/bitnet.patch b/plugins/wasi_nn/bitnet.patch new file mode 100644 index 00000000..b9e506d5 --- /dev/null +++ b/plugins/wasi_nn/bitnet.patch @@ -0,0 +1,39 @@ +diff --git a/utils/codegen_tl1.py b/utils/codegen_tl1.py +index 4c2e7dd..6600a4d 100644 +--- a/utils/codegen_tl1.py ++++ b/utils/codegen_tl1.py +@@ -14,7 +14,9 @@ static void * aligned_malloc(size_t size) {{\n\ + return _aligned_malloc(size, 64);\n\ + #else\n\ + void * ptr = nullptr;\n\ +- posix_memalign(&ptr, 64, size);\n\ ++ if (posix_memalign(&ptr, 64, size) != 0) {{\n\ ++ return nullptr;\n\ ++ }}\n\ + return ptr;\n\ + #endif\n\ + }}\n\ +@@ -201,10 +203,10 @@ def gen_body_core_code(bm, by): + int8x16_t vec_v_{0}_right_tmp1 = vqtbl1q_s8(vec_lut[{1} * k + {5}], vec_a{0}_bot);\n\ + int8x16x2_t vec_v_left_{0} = vzipq_s8(vec_v_{0}_left_tmp1, vec_v_{0}_left_tmp0);\n\ + int8x16x2_t vec_v_right_{0} = vzipq_s8(vec_v_{0}_right_tmp1, vec_v_{0}_right_tmp0);\n\ +- vec_c[{6}] += vec_v_left_{0}.val[0];\n\ +- vec_c[{6}] += vec_v_right_{0}.val[0];\n\ +- vec_c[{7}] += vec_v_left_{0}.val[1];\n\ +- vec_c[{7}] += vec_v_right_{0}.val[1];\n\ ++ vec_c[{6}] = vaddq_s16(vec_c[{6}], vmovl_s8(vget_low_s8(vec_v_left_{0}.val[0])));\n\ ++ vec_c[{6}] = vaddq_s16(vec_c[{6}], vmovl_s8(vget_low_s8(vec_v_right_{0}.val[0])));\n\ ++ vec_c[{7}] = vaddq_s16(vec_c[{7}], vmovl_s8(vget_low_s8(vec_v_left_{0}.val[1])));\n\ ++ vec_c[{7}] = vaddq_s16(vec_c[{7}], vmovl_s8(vget_low_s8(vec_v_right_{0}.val[1])));\n\ + ".format(i, 2 * by // 2, (4 * i) % (2 * by // 2), (4 * i + 1) % (2 * by // 2), (4 * i + 2) % (2 * by // 2), (4 * i + 3) % (2 * by // 2), (i * 2) // (by // 2) * 2 + 0, (i * 2) // (by // 2) * 2 + 1) + + all_code = "".join([all_code, core_code]) +@@ -232,7 +234,7 @@ inline void tbl_impl_{0}(int32_t* c, int8_t* lut, uint8_t* a) {{\n\ + #ifdef __ARM_NEON\n\ + const int KK = BBK{0} / 2;\n\ + const uint8x16_t vec_mask = vdupq_n_u8(0x0f);\n\ +- const int8x16_t vec_zero = vdupq_n_s16(0x0000);\n\ ++ const int16x8_t vec_zero = vdupq_n_s16(0x0000);\n\ + int8x16_t vec_lut[2 * KK];\n\ + ".format(pre, BM, BK) + diff --git a/plugins/wasi_nn/wasinn_bitnet.cpp b/plugins/wasi_nn/wasinn_bitnet.cpp new file mode 100644 index 00000000..74530fc2 --- /dev/null +++ b/plugins/wasi_nn/wasinn_bitnet.cpp @@ -0,0 +1,2433 @@ +#include "wasinn_bitnet.h" +#include "wasinnenv.h" +#include + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET +#include "simdjson.h" +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::BitNet { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET + +namespace { + +// Macro for logging debug message. +#define LOG_DEBUG(Debug, ...) \ + if (Debug) { \ + spdlog::info("[WASI-NN][Debug] BitNet backend: "sv __VA_ARGS__); \ + } + +// Macro for logging info message. +#define LOG_INFO(Info, ...) \ + if (Info) { \ + spdlog::info("[WASI-NN] BitNet backend: "sv __VA_ARGS__); \ + } + +// Macro for logging warning message. +#define LOG_WARN(...) spdlog::warn("[WASI-NN] BitNet backend: "sv __VA_ARGS__); + +// Macro for logging error message. +#define LOG_ERROR(...) \ + spdlog::error("[WASI-NN] BitNet backend: "sv __VA_ARGS__); + +// Macro for logging error message and return. +#define RET_ERROR(Error, ...) \ + spdlog::error("[WASI-NN] BitNet backend: "sv __VA_ARGS__); \ + return Error; + +// Llama logging callback. +void llamaLogCallback(ggml_log_level LogLevel, const char *LogText, + void *UserData) { + Graph &GraphRef = *reinterpret_cast(UserData); + if (!GraphRef.EnableLog) { + return; + } + std::string Text(LogText); + // Remove the trailing newlines. + Text = Text.erase(Text.find_last_not_of("\n") + 1); + // Skip for "." + if (Text == ".") { + return; + } + if (LogLevel == GGML_LOG_LEVEL_ERROR) { + spdlog::error("[WASI-NN] BitNet.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_WARN) { + spdlog::warn("[WASI-NN] BitNet.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_INFO) { + spdlog::info("[WASI-NN] BitNet.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_DEBUG) { + spdlog::debug("[WASI-NN] BitNet.cpp: {}"sv, Text); + } +} + +// >>>>>>>> Metadata related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// Helper function to parse comma-separated string into vector. +void stringToList(const std::string &Raw, std::vector &Out) { + std::string Copy = Raw; + std::replace(Copy.begin(), Copy.end(), ',', ' '); + std::stringstream SS(Copy); + Out.clear(); + while (SS.good()) { + int TmpInt; + SS >> TmpInt; + Out.push_back(TmpInt); + } +} + +// Parse metadata from json. +ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, + const std::string &Metadata, bool *IsModelUpdated = nullptr, + bool *IsContextUpdated = nullptr, + bool *IsSamplerUpdated = nullptr) noexcept { + // Parse metadata from the json. + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + RET_ERROR(ErrNo::InvalidEncoding, "parse metadata error."sv) + } + + // Get the current llama parameters. + int64_t PrevNGPULayers = GraphRef.Params.n_gpu_layers; + int64_t PrevMainGpu = GraphRef.Params.main_gpu; + int64_t PrevThreads = GraphRef.Params.cpuparams.n_threads; + bool PrevFlashAttn = GraphRef.Params.flash_attn; + int64_t PrevCtxSize = GraphRef.Params.n_ctx; + bool PrevEmbedding = GraphRef.Params.embedding; + // Get the current sampler parameters. + double PrevTemp = GraphRef.Params.sparams.temp; + double PrevTopP = GraphRef.Params.sparams.top_p; + double PrevRepeatPenalty = GraphRef.Params.sparams.penalty_repeat; + double PrevPresencePenalty = GraphRef.Params.sparams.penalty_present; + double PrevFrequencyPenalty = GraphRef.Params.sparams.penalty_freq; + std::string PrevGrammar = GraphRef.Params.sparams.grammar; + uint64_t PrevSeed = GraphRef.Params.sparams.seed; + + // The plugin parameters. + if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-log"].get().get(GraphRef.EnableLog); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the enable-log option."sv) + } + } + if (Doc.at_key("enable-debug-log").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-debug-log"].get().get(GraphRef.EnableDebugLog); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the enable-debug-log option."sv) + } + } + + // The model parameters. + if (Doc.at_key("main-gpu").error() == simdjson::SUCCESS) { + int64_t MainGPU; + auto Err = Doc["main-gpu"].get().get(MainGPU); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the main-gpu option."sv) + } + GraphRef.Params.main_gpu = static_cast(MainGPU); + } + if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { + int64_t NGPULayers; + auto Err = Doc["n-gpu-layers"].get().get(NGPULayers); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-gpu-layers option."sv) + } + GraphRef.Params.n_gpu_layers = static_cast(NGPULayers); + } + if (Doc.at_key("tensor-split").error() == simdjson::SUCCESS) { + // The TensorSplit is a comma-separated list of non-negative values. + // E.g., "3,2" presents 60% of the data to GPU 0 and 40% to GPU 1. + + // helper function `stringToList` cannot be used here since tensor-split + // needs a fixed-size array with validation checks. + std::string_view TSV; + auto Err = Doc["tensor-split"].get().get(TSV); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the tensor-split option."sv) + } + std::string TS(TSV); + std::replace(TS.begin(), TS.end(), ',', ' '); + std::stringstream SS(TS); + std::memset(GraphRef.Params.tensor_split, 0, + sizeof(GraphRef.Params.tensor_split)); + uint32_t TensorSplitSize = 0; + while (SS.good()) { + float TmpTensor; + SS >> TmpTensor; + GraphRef.Params.tensor_split[TensorSplitSize++] = TmpTensor; + } + size_t NDevices = llama_max_devices(); + if (TensorSplitSize > NDevices) { + RET_ERROR( + ErrNo::InvalidArgument, + "Number of Tensor-Split is larger than MaxDevices, please reduce "sv + "the size of tensor-split."sv) + } + for (size_t Idx = TensorSplitSize; Idx < NDevices; Idx++) { + GraphRef.Params.tensor_split[TensorSplitSize++] = 0.0f; + } + } + if (Doc.at_key("embedding").error() == simdjson::SUCCESS) { + auto Err = Doc["embedding"].get().get(GraphRef.Params.embedding); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embedding option."sv) + } + } + if (Doc.at_key("split-mode").error() == simdjson::SUCCESS) { + std::string_view SplitMode; + auto Err = Doc["split-mode"].get().get(SplitMode); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the split-mode option."sv) + } + if (SplitMode == "none"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_NONE; + } else if (SplitMode == "layer"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_LAYER; + } else if (SplitMode == "row"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_ROW; + } else { + RET_ERROR(ErrNo::InvalidArgument, + "Unknown split-mode: {}. Valid: none, layer, row."sv, SplitMode) + } + } + + // The context parameters. + if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { + int64_t CtxSize; + auto Err = Doc["ctx-size"].get().get(CtxSize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ctx-size option."sv) + } + GraphRef.Params.n_ctx = static_cast(CtxSize); + } + if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { + int64_t BatchSize; + auto Err = Doc["batch-size"].get().get(BatchSize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the batch-size option."sv) + } + GraphRef.Params.n_batch = static_cast(BatchSize); + } + if (Doc.at_key("ubatch-size").error() == simdjson::SUCCESS) { + int64_t UBatchSize; + auto Err = Doc["ubatch-size"].get().get(UBatchSize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ubatch-size option."sv) + } + GraphRef.Params.n_ubatch = static_cast(UBatchSize); + } + if (Doc.at_key("n-keep").error() == simdjson::SUCCESS) { + int64_t NKeep; + auto Err = Doc["n-keep"].get().get(NKeep); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-keep option."sv) + } + GraphRef.Params.n_keep = static_cast(NKeep); + } + if (Doc.at_key("n-chunks").error() == simdjson::SUCCESS) { + int64_t NChunks; + auto Err = Doc["n-chunks"].get().get(NChunks); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-chunks option."sv) + } + GraphRef.Params.n_chunks = static_cast(NChunks); + } + if (Doc.at_key("n-parallel").error() == simdjson::SUCCESS) { + int64_t NParallel; + auto Err = Doc["n-parallel"].get().get(NParallel); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-parallel option."sv) + } + GraphRef.Params.n_parallel = static_cast(NParallel); + } + if (Doc.at_key("n-sequences").error() == simdjson::SUCCESS) { + int64_t NSequences; + auto Err = Doc["n-sequences"].get().get(NSequences); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-sequences option."sv) + } + GraphRef.Params.n_sequences = static_cast(NSequences); + } + if (Doc.at_key("grp-attn-n").error() == simdjson::SUCCESS) { + int64_t GrpAttnN; + auto Err = Doc["grp-attn-n"].get().get(GrpAttnN); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the grp-attn-n option."sv) + } + GraphRef.Params.grp_attn_n = static_cast(GrpAttnN); + } + if (Doc.at_key("grp-attn-w").error() == simdjson::SUCCESS) { + int64_t GrpAttnW; + auto Err = Doc["grp-attn-w"].get().get(GrpAttnW); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the grp-attn-w option."sv) + } + GraphRef.Params.grp_attn_w = static_cast(GrpAttnW); + } + if (Doc.at_key("n-print").error() == simdjson::SUCCESS) { + int64_t NPrint; + auto Err = Doc["n-print"].get().get(NPrint); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-print option."sv) + } + GraphRef.Params.n_print = static_cast(NPrint); + } + if (Doc.at_key("rope-freq-base").error() == simdjson::SUCCESS) { + double RopeFreqBase; + auto Err = Doc["rope-freq-base"].get().get(RopeFreqBase); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the rope-freq-base option."sv) + } + GraphRef.Params.rope_freq_base = static_cast(RopeFreqBase); + } + if (Doc.at_key("rope-freq-scale").error() == simdjson::SUCCESS) { + double RopeFreqScale; + auto Err = Doc["rope-freq-scale"].get().get(RopeFreqScale); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the rope-freq-scale option."sv) + } + GraphRef.Params.rope_freq_scale = static_cast(RopeFreqScale); + } + if (Doc.at_key("yarn-ext-factor").error() == simdjson::SUCCESS) { + double YarnExtFactor; + auto Err = Doc["yarn-ext-factor"].get().get(YarnExtFactor); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the yarn-ext-factor option."sv) + } + GraphRef.Params.yarn_ext_factor = static_cast(YarnExtFactor); + } + if (Doc.at_key("yarn-attn-factor").error() == simdjson::SUCCESS) { + double YarnAttnFactor; + auto Err = Doc["yarn-attn-factor"].get().get(YarnAttnFactor); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the yarn-attn-factor option."sv) + } + GraphRef.Params.yarn_attn_factor = static_cast(YarnAttnFactor); + } + if (Doc.at_key("yarn-beta-fast").error() == simdjson::SUCCESS) { + double YarnBetaFast; + auto Err = Doc["yarn-beta-fast"].get().get(YarnBetaFast); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the yarn-beta-fast option."sv) + } + GraphRef.Params.yarn_beta_fast = static_cast(YarnBetaFast); + } + if (Doc.at_key("yarn-beta-slow").error() == simdjson::SUCCESS) { + double YarnBetaSlow; + auto Err = Doc["yarn-beta-slow"].get().get(YarnBetaSlow); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the yarn-beta-slow option."sv) + } + GraphRef.Params.yarn_beta_slow = static_cast(YarnBetaSlow); + } + if (Doc.at_key("yarn-orig-ctx").error() == simdjson::SUCCESS) { + int64_t YarnOrigCtx; + auto Err = Doc["yarn-orig-ctx"].get().get(YarnOrigCtx); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the yarn-orig-ctx option."sv) + } + GraphRef.Params.yarn_orig_ctx = static_cast(YarnOrigCtx); + } + if (Doc.at_key("defrag-thold").error() == simdjson::SUCCESS) { + double DefragThold; + auto Err = Doc["defrag-thold"].get().get(DefragThold); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the defrag-thold option."sv) + } + GraphRef.Params.defrag_thold = static_cast(DefragThold); + } + if (Doc.at_key("mask-valid").error() == simdjson::SUCCESS) { + auto Err = + Doc["mask-valid"].get().get(GraphRef.Params.cpuparams.mask_valid); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mask-valid option."sv) + } + } + if (Doc.at_key("priority").error() == simdjson::SUCCESS) { + int64_t Priority; + auto Err = Doc["priority"].get().get(Priority); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the priority option."sv) + } + GraphRef.Params.cpuparams.priority = + static_cast(Priority); + } + if (Doc.at_key("strict-cpu").error() == simdjson::SUCCESS) { + auto Err = + Doc["strict-cpu"].get().get(GraphRef.Params.cpuparams.strict_cpu); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the strict-cpu option."sv) + } + } + if (Doc.at_key("poll").error() == simdjson::SUCCESS) { + int64_t Poll; + auto Err = Doc["poll"].get().get(Poll); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the poll option."sv) + } + GraphRef.Params.cpuparams.poll = static_cast(Poll); + } + if (Doc.at_key("mask-valid-batch").error() == simdjson::SUCCESS) { + auto Err = Doc["mask-valid-batch"].get().get( + GraphRef.Params.cpuparams_batch.mask_valid); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mask-valid-batch option."sv) + } + } + if (Doc.at_key("priority-batch").error() == simdjson::SUCCESS) { + int64_t Priority; + auto Err = Doc["priority-batch"].get().get(Priority); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the priority-batch option."sv) + } + GraphRef.Params.cpuparams_batch.priority = + static_cast(Priority); + } + if (Doc.at_key("strict-cpu-batch").error() == simdjson::SUCCESS) { + auto Err = Doc["strict-cpu-batch"].get().get( + GraphRef.Params.cpuparams_batch.strict_cpu); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the strict-cpu-batch option."sv) + } + } + if (Doc.at_key("poll-batch").error() == simdjson::SUCCESS) { + int64_t Poll; + auto Err = Doc["poll-batch"].get().get(Poll); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the poll-batch option."sv) + } + GraphRef.Params.cpuparams_batch.poll = static_cast(Poll); + } + if (Doc.at_key("numa").error() == simdjson::SUCCESS) { + int64_t Numa; + auto Err = Doc["numa"].get().get(Numa); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the numa option."sv) + } + GraphRef.Params.numa = static_cast(Numa); + } + if (Doc.at_key("rope-scaling-type").error() == simdjson::SUCCESS) { + int64_t RopeScalingType; + auto Err = Doc["rope-scaling-type"].get().get(RopeScalingType); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the rope-scaling-type option."sv) + } + GraphRef.Params.rope_scaling_type = + static_cast(RopeScalingType); + } + if (Doc.at_key("pooling-type").error() == simdjson::SUCCESS) { + int64_t PoolingType; + auto Err = Doc["pooling-type"].get().get(PoolingType); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the pooling-type option."sv) + } + GraphRef.Params.pooling_type = + static_cast(PoolingType); + } + if (Doc.at_key("attention-type").error() == simdjson::SUCCESS) { + int64_t AttentionType; + auto Err = Doc["attention-type"].get().get(AttentionType); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the attention-type option."sv) + } + GraphRef.Params.attention_type = + static_cast(AttentionType); + } + if (Doc.at_key("threads").error() == simdjson::SUCCESS) { + int64_t NThreads; + auto Err = Doc["threads"].get().get(NThreads); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the threads option."sv) + } + GraphRef.Params.cpuparams.n_threads = static_cast(NThreads); + } + if (Doc.at_key("threads-batch").error() == simdjson::SUCCESS) { + int64_t NThreadsBatch; + auto Err = Doc["threads-batch"].get().get(NThreadsBatch); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the threads-batch option."sv) + } + GraphRef.Params.cpuparams_batch.n_threads = + static_cast(NThreadsBatch); + } + + // The sampling parameters. + if (Doc.at_key("n-prev").error() == simdjson::SUCCESS) { + int64_t NPrev; + auto Err = Doc["n-prev"].get().get(NPrev); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n_prev option."sv) + } + GraphRef.Params.sparams.n_prev = static_cast(NPrev); + } + if (Doc.at_key("top-k").error() == simdjson::SUCCESS) { + int64_t TopK; + auto Err = Doc["top-k"].get().get(TopK); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the top-k option."sv) + } + GraphRef.Params.sparams.top_k = static_cast(TopK); + } + if (Doc.at_key("min-p").error() == simdjson::SUCCESS) { + double MinP; + auto Err = Doc["min-p"].get().get(MinP); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the min-p option."sv) + } + GraphRef.Params.sparams.min_p = static_cast(MinP); + } + if (Doc.at_key("typ-p").error() == simdjson::SUCCESS) { + double TypP; + auto Err = Doc["typ-p"].get().get(TypP); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the typ-p option."sv) + } + GraphRef.Params.sparams.typ_p = static_cast(TypP); + } + if (Doc.at_key("penalize-nl").error() == simdjson::SUCCESS) { + auto Err = + Doc["penalize-nl"].get().get(GraphRef.Params.sparams.penalize_nl); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the penalize-nl option."sv) + } + } + if (Doc.at_key("tfs").error() == simdjson::SUCCESS) { + double TfsZ; + auto Err = Doc["tfs"].get().get(TfsZ); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the tfs option."sv) + } + GraphRef.Params.sparams.tfs_z = static_cast(TfsZ); + } + if (Doc.at_key("dynatemp-range").error() == simdjson::SUCCESS) { + double DynaTempRange; + auto Err = Doc["dynatemp-range"].get().get(DynaTempRange); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the dynatemp-range option."sv) + } + GraphRef.Params.sparams.dynatemp_range = static_cast(DynaTempRange); + } + if (Doc.at_key("dynatemp-exponent").error() == simdjson::SUCCESS) { + double DynaTempExponent; + auto Err = Doc["dynatemp-exponent"].get().get(DynaTempExponent); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the dynatemp-exponent option."sv) + } + GraphRef.Params.sparams.dynatemp_exponent = + static_cast(DynaTempExponent); + } + if (Doc.at_key("last-n-penalty").error() == simdjson::SUCCESS) { + int64_t LastNPenalty; + auto Err = Doc["last-n-penalty"].get().get(LastNPenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the last-n-penalty option."sv) + } + GraphRef.Params.sparams.penalty_last_n = static_cast(LastNPenalty); + } + if (Doc.at_key("temp").error() == simdjson::SUCCESS) { + double Temp; + auto Err = Doc["temp"].get().get(Temp); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the temp option."sv) + } + GraphRef.Params.sparams.temp = static_cast(std::max(0.0, Temp)); + } + if (Doc.at_key("top-p").error() == simdjson::SUCCESS) { + double TopP; + auto Err = Doc["top-p"].get().get(TopP); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the top-p option."sv) + } + GraphRef.Params.sparams.top_p = static_cast(std::max(0.0, TopP)); + } + if (Doc.at_key("repeat-penalty").error() == simdjson::SUCCESS) { + double RepeatPenalty; + auto Err = Doc["repeat-penalty"].get().get(RepeatPenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the repeat-penalty option."sv) + } + GraphRef.Params.sparams.penalty_repeat = + static_cast(std::max(0.0, RepeatPenalty)); + } + if (Doc.at_key("presence-penalty").error() == simdjson::SUCCESS) { + double PresencePenalty; + auto Err = Doc["presence-penalty"].get().get(PresencePenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the presence-penalty option."sv) + } + GraphRef.Params.sparams.penalty_present = + static_cast(std::max(0.0, PresencePenalty)); + } + if (Doc.at_key("frequency-penalty").error() == simdjson::SUCCESS) { + double FrequencyPenalty; + auto Err = Doc["frequency-penalty"].get().get(FrequencyPenalty); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the frequency-penalty option."sv) + } + GraphRef.Params.sparams.penalty_freq = + static_cast(std::max(0.0, FrequencyPenalty)); + } + if (Doc.at_key("mirostat").error() == simdjson::SUCCESS) { + int64_t Mirostat; + auto Err = Doc["mirostat"].get().get(Mirostat); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mirostat option."sv) + } + GraphRef.Params.sparams.mirostat = static_cast(Mirostat); + } + if (Doc.at_key("mirostat-eta").error() == simdjson::SUCCESS) { + double MirostatEta; + auto Err = Doc["mirostat-eta"].get().get(MirostatEta); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mirostat-eta option."sv) + } + GraphRef.Params.sparams.mirostat_eta = static_cast(MirostatEta); + } + if (Doc.at_key("mirostat-ent").error() == simdjson::SUCCESS) { + double MirostatEnt; + auto Err = Doc["mirostat-ent"].get().get(MirostatEnt); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the mirostat-ent option."sv) + } + GraphRef.Params.sparams.mirostat_tau = static_cast(MirostatEnt); + } + if (Doc.at_key("ignore-eos").error() == simdjson::SUCCESS) { + auto Err = + Doc["ignore-eos"].get().get(GraphRef.Params.sparams.ignore_eos); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ignore-eos option."sv) + } + } + if (Doc.at_key("no-perf-sampling").error() == simdjson::SUCCESS) { + auto Err = Doc["no-perf-sampling"].get().get( + GraphRef.Params.sparams.no_perf); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the no-perf-sampling option."sv) + } + } + if (Doc.at_key("grammar").error() == simdjson::SUCCESS) { + std::string_view Grammar; + auto Err = Doc["grammar"].get().get(Grammar); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the grammar option."sv) + } + GraphRef.Params.sparams.grammar = Grammar; + } + if (Doc.at_key("seed").error() == simdjson::SUCCESS) { + uint64_t Seed; + auto Err = Doc["seed"].get().get(Seed); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the seed option."sv) + } + GraphRef.Params.sparams.seed = static_cast(Seed); + } + // The speculative parameters. + if (Doc.at_key("n-gpu-layers-draft").error() == simdjson::SUCCESS) { + int64_t NGPULayersDraft; + auto Err = Doc["n-gpu-layers-draft"].get().get(NGPULayersDraft); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-gpu-layers-draft option."sv) + } + GraphRef.Params.n_gpu_layers_draft = static_cast(NGPULayersDraft); + } + if (Doc.at_key("p-split").error() == simdjson::SUCCESS) { + double PSplit; + auto Err = Doc["p-split"].get().get(PSplit); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the p-split option."sv) + } + GraphRef.Params.p_split = static_cast(PSplit); + } + + // The config parameters. + if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { + auto Err = Doc["stream-stdout"].get().get(ConfRef.StreamStdout); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the stream-stdout option."sv) + } + } + if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { + auto Err = Doc["n-predict"].get().get(ConfRef.NPredict); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-predict option."sv) + } + } + if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { + std::string_view ReversePrompt; + auto Err = Doc["reverse-prompt"].get().get(ReversePrompt); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the reverse-prompt option."sv) + } + ConfRef.ReversePrompt = ReversePrompt; + } + if (Doc.at_key("model-alias").error() == simdjson::SUCCESS) { + std::string_view ModelAlias; + auto Err = Doc["model-alias"].get().get(ModelAlias); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the model-alias option."sv) + } + GraphRef.Params.model_alias = ModelAlias; + } + if (Doc.at_key("model-url").error() == simdjson::SUCCESS) { + std::string_view ModelUrl; + auto Err = Doc["model-url"].get().get(ModelUrl); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the model-url option."sv) + } + GraphRef.Params.model_url = ModelUrl; + } + if (Doc.at_key("hf-token").error() == simdjson::SUCCESS) { + std::string_view HfToken; + auto Err = Doc["hf-token"].get().get(HfToken); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hf-token option."sv) + } + GraphRef.Params.hf_token = HfToken; + } + if (Doc.at_key("hf-repo").error() == simdjson::SUCCESS) { + std::string_view HfRepo; + auto Err = Doc["hf-repo"].get().get(HfRepo); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hf-repo option."sv) + } + GraphRef.Params.hf_repo = HfRepo; + } + if (Doc.at_key("hf-file").error() == simdjson::SUCCESS) { + std::string_view HfFile; + auto Err = Doc["hf-file"].get().get(HfFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hf-file option."sv) + } + GraphRef.Params.hf_file = HfFile; + } + if (Doc.at_key("prompt-file").error() == simdjson::SUCCESS) { + std::string_view PromptFile; + auto Err = Doc["prompt-file"].get().get(PromptFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the prompt-file option."sv) + } + GraphRef.Params.prompt_file = PromptFile; + } + if (Doc.at_key("path-prompt-cache").error() == simdjson::SUCCESS) { + std::string_view PathPromptCache; + auto Err = + Doc["path-prompt-cache"].get().get(PathPromptCache); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the path-prompt-cache option."sv) + } + GraphRef.Params.path_prompt_cache = PathPromptCache; + } + if (Doc.at_key("input-prefix").error() == simdjson::SUCCESS) { + std::string_view InputPrefix; + auto Err = Doc["input-prefix"].get().get(InputPrefix); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the input-prefix option."sv) + } + GraphRef.Params.input_prefix = InputPrefix; + } + if (Doc.at_key("input-suffix").error() == simdjson::SUCCESS) { + std::string_view InputSuffix; + auto Err = Doc["input-suffix"].get().get(InputSuffix); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the input-suffix option."sv) + } + GraphRef.Params.input_suffix = InputSuffix; + } + if (Doc.at_key("lookup-cache-static").error() == simdjson::SUCCESS) { + std::string_view LookupCacheStatic; + auto Err = Doc["lookup-cache-static"].get().get( + LookupCacheStatic); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the lookup-cache-static option."sv) + } + GraphRef.Params.lookup_cache_static = LookupCacheStatic; + } + if (Doc.at_key("lookup-cache-dynamic").error() == simdjson::SUCCESS) { + std::string_view LookupCacheDynamic; + auto Err = Doc["lookup-cache-dynamic"].get().get( + LookupCacheDynamic); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the lookup-cache-dynamic option."sv) + } + GraphRef.Params.lookup_cache_dynamic = LookupCacheDynamic; + } + if (Doc.at_key("logits-file").error() == simdjson::SUCCESS) { + std::string_view LogitsFile; + auto Err = Doc["logits-file"].get().get(LogitsFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the logits-file option."sv) + } + GraphRef.Params.logits_file = LogitsFile; + } + if (Doc.at_key("lora-init-without-apply").error() == simdjson::SUCCESS) { + auto Err = Doc["lora-init-without-apply"].get().get( + GraphRef.Params.lora_init_without_apply); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the lora-init-without-apply option."sv) + } + } + if (Doc.at_key("verbosity").error() == simdjson::SUCCESS) { + int64_t Verbosity; + auto Err = Doc["verbosity"].get().get(Verbosity); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the verbosity option."sv) + } + GraphRef.Params.verbosity = static_cast(Verbosity); + } + if (Doc.at_key("control-vector-layer-start").error() == simdjson::SUCCESS) { + int64_t ControlVectorLayerStart; + auto Err = Doc["control-vector-layer-start"].get().get( + ControlVectorLayerStart); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the control-vector-layer-start option."sv) + } + GraphRef.Params.control_vector_layer_start = + static_cast(ControlVectorLayerStart); + } + if (Doc.at_key("control-vector-layer-end").error() == simdjson::SUCCESS) { + int64_t ControlVectorLayerEnd; + auto Err = Doc["control-vector-layer-end"].get().get( + ControlVectorLayerEnd); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the control-vector-layer-end option."sv) + } + GraphRef.Params.control_vector_layer_end = + static_cast(ControlVectorLayerEnd); + } + if (Doc.at_key("ppl-stride").error() == simdjson::SUCCESS) { + int64_t PplStride; + auto Err = Doc["ppl-stride"].get().get(PplStride); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ppl-stride option."sv) + } + GraphRef.Params.ppl_stride = static_cast(PplStride); + } + if (Doc.at_key("ppl-output-type").error() == simdjson::SUCCESS) { + int64_t PplOutputType; + auto Err = Doc["ppl-output-type"].get().get(PplOutputType); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ppl-output-type option."sv) + } + GraphRef.Params.ppl_output_type = static_cast(PplOutputType); + } + if (Doc.at_key("hellaswag").error() == simdjson::SUCCESS) { + auto Err = Doc["hellaswag"].get().get(GraphRef.Params.hellaswag); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hellaswag option."sv) + } + } + if (Doc.at_key("hellaswag-tasks").error() == simdjson::SUCCESS) { + uint64_t HellaswagTasks; + auto Err = Doc["hellaswag-tasks"].get().get(HellaswagTasks); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hellaswag-tasks option."sv) + } + GraphRef.Params.hellaswag_tasks = HellaswagTasks; + } + if (Doc.at_key("winogrande").error() == simdjson::SUCCESS) { + auto Err = Doc["winogrande"].get().get(GraphRef.Params.winogrande); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the winogrande option."sv) + } + } + if (Doc.at_key("winogrande-tasks").error() == simdjson::SUCCESS) { + uint64_t WinograndeTasks; + auto Err = Doc["winogrande-tasks"].get().get(WinograndeTasks); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the winogrande-tasks option."sv) + } + GraphRef.Params.winogrande_tasks = WinograndeTasks; + } + if (Doc.at_key("multiple-choice").error() == simdjson::SUCCESS) { + auto Err = + Doc["multiple-choice"].get().get(GraphRef.Params.multiple_choice); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the multiple-choice option."sv) + } + } + if (Doc.at_key("multiple-choice-tasks").error() == simdjson::SUCCESS) { + uint64_t MultipleChoiceTasks; + auto Err = + Doc["multiple-choice-tasks"].get().get(MultipleChoiceTasks); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the multiple-choice-tasks option."sv) + } + GraphRef.Params.multiple_choice_tasks = MultipleChoiceTasks; + } + if (Doc.at_key("kl-divergence").error() == simdjson::SUCCESS) { + auto Err = + Doc["kl-divergence"].get().get(GraphRef.Params.kl_divergence); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the kl-divergence option."sv) + } + } + if (Doc.at_key("usage").error() == simdjson::SUCCESS) { + auto Err = Doc["usage"].get().get(GraphRef.Params.usage); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the usage option."sv) + } + } + if (Doc.at_key("use-color").error() == simdjson::SUCCESS) { + auto Err = Doc["use-color"].get().get(GraphRef.Params.use_color); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the use-color option."sv) + } + } + if (Doc.at_key("special").error() == simdjson::SUCCESS) { + auto Err = Doc["special"].get().get(GraphRef.Params.special); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the special option."sv) + } + } + if (Doc.at_key("interactive").error() == simdjson::SUCCESS) { + auto Err = Doc["interactive"].get().get(GraphRef.Params.interactive); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the interactive option."sv) + } + } + if (Doc.at_key("interactive-first").error() == simdjson::SUCCESS) { + auto Err = Doc["interactive-first"].get().get( + GraphRef.Params.interactive_first); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the interactive-first option."sv) + } + } + if (Doc.at_key("prompt-cache-all").error() == simdjson::SUCCESS) { + auto Err = Doc["prompt-cache-all"].get().get( + GraphRef.Params.prompt_cache_all); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the prompt-cache-all option."sv) + } + } + if (Doc.at_key("prompt-cache-ro").error() == simdjson::SUCCESS) { + auto Err = + Doc["prompt-cache-ro"].get().get(GraphRef.Params.prompt_cache_ro); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the prompt-cache-ro option."sv) + } + } + if (Doc.at_key("escape").error() == simdjson::SUCCESS) { + auto Err = Doc["escape"].get().get(GraphRef.Params.escape); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the escape option."sv) + } + } + if (Doc.at_key("multiline-input").error() == simdjson::SUCCESS) { + auto Err = + Doc["multiline-input"].get().get(GraphRef.Params.multiline_input); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the multiline-input option."sv) + } + } + if (Doc.at_key("simple-io").error() == simdjson::SUCCESS) { + auto Err = Doc["simple-io"].get().get(GraphRef.Params.simple_io); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the simple-io option."sv) + } + } + if (Doc.at_key("cont-batching").error() == simdjson::SUCCESS) { + auto Err = + Doc["cont-batching"].get().get(GraphRef.Params.cont_batching); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cont-batching option."sv) + } + } + if (Doc.at_key("flash-attn").error() == simdjson::SUCCESS) { + auto Err = Doc["flash-attn"].get().get(GraphRef.Params.flash_attn); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the flash-attn option."sv) + } + } + if (Doc.at_key("no-perf").error() == simdjson::SUCCESS) { + auto Err = Doc["no-perf"].get().get(GraphRef.Params.no_perf); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the no-perf option."sv) + } + } + if (Doc.at_key("ctx-shift").error() == simdjson::SUCCESS) { + auto Err = Doc["ctx-shift"].get().get(GraphRef.Params.ctx_shift); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ctx-shift option."sv) + } + } + if (Doc.at_key("input-prefix-bos").error() == simdjson::SUCCESS) { + auto Err = Doc["input-prefix-bos"].get().get( + GraphRef.Params.input_prefix_bos); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the input-prefix-bos option."sv) + } + } + if (Doc.at_key("use-mlock").error() == simdjson::SUCCESS) { + auto Err = Doc["use-mlock"].get().get(GraphRef.Params.use_mlock); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the use-mlock option."sv) + } + } + if (Doc.at_key("use-mmap").error() == simdjson::SUCCESS) { + auto Err = Doc["use-mmap"].get().get(GraphRef.Params.use_mmap); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the use-mmap option."sv) + } + } + if (Doc.at_key("verbose-prompt").error() == simdjson::SUCCESS) { + auto Err = + Doc["verbose-prompt"].get().get(GraphRef.Params.verbose_prompt); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the verbose-prompt option."sv) + } + } + if (Doc.at_key("display-prompt").error() == simdjson::SUCCESS) { + auto Err = + Doc["display-prompt"].get().get(GraphRef.Params.display_prompt); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the display-prompt option."sv) + } + } + if (Doc.at_key("no-kv-offload").error() == simdjson::SUCCESS) { + auto Err = + Doc["no-kv-offload"].get().get(GraphRef.Params.no_kv_offload); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the no-kv-offload option."sv) + } + } + if (Doc.at_key("warmup").error() == simdjson::SUCCESS) { + auto Err = Doc["warmup"].get().get(GraphRef.Params.warmup); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the warmup option."sv) + } + } + if (Doc.at_key("check-tensors").error() == simdjson::SUCCESS) { + auto Err = + Doc["check-tensors"].get().get(GraphRef.Params.check_tensors); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the check-tensors option."sv) + } + } + if (Doc.at_key("cache-type-k").error() == simdjson::SUCCESS) { + std::string_view CacheTypeK; + auto Err = Doc["cache-type-k"].get().get(CacheTypeK); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cache-type-k option."sv) + } + GraphRef.Params.cache_type_k = CacheTypeK; + } + if (Doc.at_key("cache-type-v").error() == simdjson::SUCCESS) { + std::string_view CacheTypeV; + auto Err = Doc["cache-type-v"].get().get(CacheTypeV); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cache-type-v option."sv) + } + GraphRef.Params.cache_type_v = CacheTypeV; + } + if (Doc.at_key("embd-normalize").error() == simdjson::SUCCESS) { + int64_t EmbdNormalize; + auto Err = Doc["embd-normalize"].get().get(EmbdNormalize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embd-normalize option."sv) + } + GraphRef.Params.embd_normalize = static_cast(EmbdNormalize); + } + if (Doc.at_key("embd-out").error() == simdjson::SUCCESS) { + std::string_view EmbdOut; + auto Err = Doc["embd-out"].get().get(EmbdOut); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embd-out option."sv) + } + GraphRef.Params.embd_out = EmbdOut; + } + if (Doc.at_key("embd-sep").error() == simdjson::SUCCESS) { + std::string_view EmbdSep; + auto Err = Doc["embd-sep"].get().get(EmbdSep); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the embd-sep option."sv) + } + GraphRef.Params.embd_sep = EmbdSep; + } + if (Doc.at_key("reranking").error() == simdjson::SUCCESS) { + auto Err = Doc["reranking"].get().get(GraphRef.Params.reranking); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the reranking option."sv) + } + } + if (Doc.at_key("port").error() == simdjson::SUCCESS) { + int64_t Port; + auto Err = Doc["port"].get().get(Port); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the port option."sv) + } + GraphRef.Params.port = static_cast(Port); + } + if (Doc.at_key("timeout-read").error() == simdjson::SUCCESS) { + int64_t TimeoutRead; + auto Err = Doc["timeout-read"].get().get(TimeoutRead); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the timeout-read option."sv) + } + GraphRef.Params.timeout_read = static_cast(TimeoutRead); + } + if (Doc.at_key("timeout-write").error() == simdjson::SUCCESS) { + int64_t TimeoutWrite; + auto Err = Doc["timeout-write"].get().get(TimeoutWrite); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the timeout-write option."sv) + } + GraphRef.Params.timeout_write = static_cast(TimeoutWrite); + } + if (Doc.at_key("n-threads-http").error() == simdjson::SUCCESS) { + int64_t NThreadsHttp; + auto Err = Doc["n-threads-http"].get().get(NThreadsHttp); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-threads-http option."sv) + } + GraphRef.Params.n_threads_http = static_cast(NThreadsHttp); + } + if (Doc.at_key("n-cache-reuse").error() == simdjson::SUCCESS) { + int64_t NCacheReuse; + auto Err = Doc["n-cache-reuse"].get().get(NCacheReuse); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-cache-reuse option."sv) + } + GraphRef.Params.n_cache_reuse = static_cast(NCacheReuse); + } + if (Doc.at_key("hostname").error() == simdjson::SUCCESS) { + std::string_view Hostname; + auto Err = Doc["hostname"].get().get(Hostname); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the hostname option."sv) + } + GraphRef.Params.hostname = Hostname; + } + if (Doc.at_key("public-path").error() == simdjson::SUCCESS) { + std::string_view PublicPath; + auto Err = Doc["public-path"].get().get(PublicPath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the public-path option."sv) + } + GraphRef.Params.public_path = PublicPath; + } + if (Doc.at_key("chat-template").error() == simdjson::SUCCESS) { + std::string_view ChatTemplate; + auto Err = Doc["chat-template"].get().get(ChatTemplate); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the chat-template option."sv) + } + GraphRef.Params.chat_template = ChatTemplate; + } + if (Doc.at_key("enable-chat-template").error() == simdjson::SUCCESS) { + auto Err = Doc["enable-chat-template"].get().get( + GraphRef.Params.enable_chat_template); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the enable-chat-template option."sv) + } + } + if (Doc.at_key("ssl-file-key").error() == simdjson::SUCCESS) { + std::string_view SslFileKey; + auto Err = Doc["ssl-file-key"].get().get(SslFileKey); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ssl-file-key option."sv) + } + GraphRef.Params.ssl_file_key = SslFileKey; + } + if (Doc.at_key("ssl-file-cert").error() == simdjson::SUCCESS) { + std::string_view SslFileCert; + auto Err = Doc["ssl-file-cert"].get().get(SslFileCert); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the ssl-file-cert option."sv) + } + GraphRef.Params.ssl_file_cert = SslFileCert; + } + if (Doc.at_key("webui").error() == simdjson::SUCCESS) { + auto Err = Doc["webui"].get().get(GraphRef.Params.webui); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the webui option."sv) + } + } + if (Doc.at_key("endpoint-slots").error() == simdjson::SUCCESS) { + int64_t EndpointSlots; + auto Err = Doc["endpoint-slots"].get().get(EndpointSlots); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the endpoint-slots option."sv) + } + GraphRef.Params.endpoint_slots = static_cast(EndpointSlots); + } + if (Doc.at_key("endpoint-props").error() == simdjson::SUCCESS) { + auto Err = + Doc["endpoint-props"].get().get(GraphRef.Params.endpoint_props); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the endpoint-props option."sv) + } + } + if (Doc.at_key("endpoint-metrics").error() == simdjson::SUCCESS) { + auto Err = Doc["endpoint-metrics"].get().get( + GraphRef.Params.endpoint_metrics); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the endpoint-metrics option."sv) + } + } + if (Doc.at_key("log-json").error() == simdjson::SUCCESS) { + auto Err = Doc["log-json"].get().get(GraphRef.Params.log_json); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the log-json option."sv) + } + } + if (Doc.at_key("slot-save-path").error() == simdjson::SUCCESS) { + std::string_view SlotSavePath; + auto Err = Doc["slot-save-path"].get().get(SlotSavePath); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the slot-save-path option."sv) + } + GraphRef.Params.slot_save_path = SlotSavePath; + } + if (Doc.at_key("slot-prompt-similarity").error() == simdjson::SUCCESS) { + double SlotPromptSimilarity; + auto Err = + Doc["slot-prompt-similarity"].get().get(SlotPromptSimilarity); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the slot-prompt-similarity option."sv) + } + GraphRef.Params.slot_prompt_similarity = + static_cast(SlotPromptSimilarity); + } + if (Doc.at_key("is-pp-shared").error() == simdjson::SUCCESS) { + auto Err = + Doc["is-pp-shared"].get().get(GraphRef.Params.is_pp_shared); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the is-pp-shared option."sv) + } + } + if (Doc.at_key("n-pp").error() == simdjson::SUCCESS) { + std::string_view NPP; + auto Err = Doc["n-pp"].get().get(NPP); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the n-pp option."sv) + } + stringToList(std::string(NPP), GraphRef.Params.n_pp); + } + if (Doc.at_key("n-tg").error() == simdjson::SUCCESS) { + std::string_view NTG; + auto Err = Doc["n-tg"].get().get(NTG); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the n-tg option."sv) + } + stringToList(std::string(NTG), GraphRef.Params.n_tg); + } + if (Doc.at_key("n-pl").error() == simdjson::SUCCESS) { + std::string_view NPL; + auto Err = Doc["n-pl"].get().get(NPL); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the n-pl option."sv) + } + stringToList(std::string(NPL), GraphRef.Params.n_pl); + } + if (Doc.at_key("context-files").error() == simdjson::SUCCESS) { + std::string_view ContextFiles; + auto Err = Doc["context-files"].get().get(ContextFiles); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the context-files option."sv) + } + } + if (Doc.at_key("chunk-size").error() == simdjson::SUCCESS) { + int64_t ChunkSize; + auto Err = Doc["chunk-size"].get().get(ChunkSize); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the chunk-size option."sv) + } + GraphRef.Params.chunk_size = static_cast(ChunkSize); + } + if (Doc.at_key("chunk-separator").error() == simdjson::SUCCESS) { + std::string_view ChunkSeparator; + auto Err = + Doc["chunk-separator"].get().get(ChunkSeparator); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the chunk-separator option."sv) + } + GraphRef.Params.chunk_separator = ChunkSeparator; + } + if (Doc.at_key("n-junk").error() == simdjson::SUCCESS) { + int64_t NJunk; + auto Err = Doc["n-junk"].get().get(NJunk); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-junk option."sv) + } + GraphRef.Params.n_junk = static_cast(NJunk); + } + if (Doc.at_key("i-pos").error() == simdjson::SUCCESS) { + int64_t IPos; + auto Err = Doc["i-pos"].get().get(IPos); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the i-pos option."sv) + } + GraphRef.Params.i_pos = static_cast(IPos); + } + if (Doc.at_key("n-out-freq").error() == simdjson::SUCCESS) { + int64_t NOutFreq; + auto Err = Doc["n-out-freq"].get().get(NOutFreq); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-out-freq option."sv) + } + GraphRef.Params.n_out_freq = static_cast(NOutFreq); + } + if (Doc.at_key("n-save-freq").error() == simdjson::SUCCESS) { + int64_t NSaveFreq; + auto Err = Doc["n-save-freq"].get().get(NSaveFreq); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-save-freq option."sv) + } + GraphRef.Params.n_save_freq = static_cast(NSaveFreq); + } + if (Doc.at_key("i-chunk").error() == simdjson::SUCCESS) { + int64_t IChunk; + auto Err = Doc["i-chunk"].get().get(IChunk); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the i-chunk option."sv) + } + GraphRef.Params.i_chunk = static_cast(IChunk); + } + if (Doc.at_key("process-output").error() == simdjson::SUCCESS) { + auto Err = + Doc["process-output"].get().get(GraphRef.Params.process_output); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the process-output option."sv) + } + } + if (Doc.at_key("compute-ppl").error() == simdjson::SUCCESS) { + auto Err = Doc["compute-ppl"].get().get(GraphRef.Params.compute_ppl); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the compute-ppl option."sv) + } + } + if (Doc.at_key("n-pca-batch").error() == simdjson::SUCCESS) { + int64_t NPCABatch; + auto Err = Doc["n-pca-batch"].get().get(NPCABatch); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-pca-batch option."sv) + } + GraphRef.Params.n_pca_batch = static_cast(NPCABatch); + } + if (Doc.at_key("n-pca-iterations").error() == simdjson::SUCCESS) { + int64_t NPCAIterations; + auto Err = Doc["n-pca-iterations"].get().get(NPCAIterations); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the n-pca-iterations option."sv) + } + GraphRef.Params.n_pca_iterations = static_cast(NPCAIterations); + } + if (Doc.at_key("cvector-dimre-method").error() == simdjson::SUCCESS) { + std::string_view CVectorDimreMethod; + auto Err = Doc["cvector-dimre-method"].get().get( + CVectorDimreMethod); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cvector-dimre-method option."sv) + } + if (CVectorDimreMethod == "pca") { + GraphRef.Params.cvector_dimre_method = DIMRE_METHOD_PCA; + } else if (CVectorDimreMethod == "mean") { + GraphRef.Params.cvector_dimre_method = DIMRE_METHOD_MEAN; + } else { + RET_ERROR( + ErrNo::InvalidArgument, + "Invalid value for cvector-dimre-method: must be 'pca' or 'mean'."sv) + } + } + if (Doc.at_key("cvector-outfile").error() == simdjson::SUCCESS) { + std::string_view CVectorOutfile; + auto Err = + Doc["cvector-outfile"].get().get(CVectorOutfile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cvector-outfile option."sv) + } + GraphRef.Params.cvector_outfile = CVectorOutfile; + } + if (Doc.at_key("cvector-positive-file").error() == simdjson::SUCCESS) { + std::string_view CVectorPositiveFile; + auto Err = Doc["cvector-positive-file"].get().get( + CVectorPositiveFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cvector-positive-file option."sv) + } + GraphRef.Params.cvector_positive_file = CVectorPositiveFile; + } + if (Doc.at_key("cvector-negative-file").error() == simdjson::SUCCESS) { + std::string_view CVectorNegativeFile; + auto Err = Doc["cvector-negative-file"].get().get( + CVectorNegativeFile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the cvector-negative-file option."sv) + } + GraphRef.Params.cvector_negative_file = CVectorNegativeFile; + } + if (Doc.at_key("spm-infill").error() == simdjson::SUCCESS) { + auto Err = Doc["spm-infill"].get().get(GraphRef.Params.spm_infill); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the spm-infill option."sv) + } + } + if (Doc.at_key("out-file").error() == simdjson::SUCCESS) { + std::string_view Outfile; + auto Err = Doc["out-file"].get().get(Outfile); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the out-file option."sv) + } + GraphRef.Params.out_file = Outfile; + } + if (Doc.at_key("batched-bench-output-jsonl").error() == simdjson::SUCCESS) { + auto Err = Doc["batched-bench-output-jsonl"].get().get( + GraphRef.Params.batched_bench_output_jsonl); + if (Err) { + RET_ERROR(ErrNo::InvalidArgument, + "Unable to retrieve the batched-bench-output-jsonl option."sv) + } + } + + // Check if the model parameters are updated. + if (IsModelUpdated && (PrevNGPULayers != GraphRef.Params.n_gpu_layers || + PrevMainGpu != GraphRef.Params.main_gpu)) { + *IsModelUpdated = true; + } + + // Check if the context parameters are updated. + if (IsContextUpdated && (PrevCtxSize != GraphRef.Params.n_ctx || + PrevThreads != GraphRef.Params.cpuparams.n_threads || + PrevFlashAttn != GraphRef.Params.flash_attn || + PrevEmbedding != GraphRef.Params.embedding)) { + *IsContextUpdated = true; + } + + // Check if the sampler parameters are updated. + if (IsSamplerUpdated && + (PrevTemp != GraphRef.Params.sparams.temp || + PrevTopP != GraphRef.Params.sparams.top_p || + PrevRepeatPenalty != GraphRef.Params.sparams.penalty_repeat || + PrevPresencePenalty != GraphRef.Params.sparams.penalty_present || + PrevFrequencyPenalty != GraphRef.Params.sparams.penalty_freq || + PrevGrammar != GraphRef.Params.sparams.grammar || + PrevSeed != GraphRef.Params.sparams.seed)) { + *IsSamplerUpdated = true; + } + + return ErrNo::Success; +} + +// <<<<<<<< Metadata related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + +// >>>>>>>> Output related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// Generate output metadata. +std::string buildOutputMetadata(Context &CxtRef) noexcept { + return fmt::format(R"({{"input_tokens": {}, )" + R"("output_tokens": {}, )" + R"("llama_build_number": {}, )" + R"("llama_commit": "{}"}})"sv, + CxtRef.LlamaNInputs, CxtRef.LlamaOutputTokens.size(), + LLAMA_BUILD_NUMBER, LLAMA_COMMIT); +} + +// Generate output embedding. +void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, + const float *Embeddings) noexcept { + // Embedding vector format + // | Content | + // | ----------------------------------- | + // | '{"number_embedding": ' | + // | n_embedding | + // | ', "embedding": ' | + // | '[' | + // | n_embedding*(embedding value %.10f) | + // | (n_embedding-1)*(',') | + // | ']' | + // | '}' | + Embedding = + fmt::format(R"({{"n_embedding": {}, )" + R"("embedding": [{:.10}]}})"sv, + NEmbd, fmt::join(Embeddings, Embeddings + NEmbd, ","sv)); +} +// <<<<<<<< Output related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + +// >>>>>>>> Compute related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + +// Helper to init a llama batch. +struct llama_batch allocBatch(int64_t NTokens, int64_t Embd = 0, + int32_t NSeqMax = 1) noexcept { + struct llama_batch Batch = llama_batch_init( + /* n_tokens_alloc */ static_cast(NTokens), + /* embd */ static_cast(Embd), + /* n_seq_max */ static_cast(NSeqMax)); + std::fill(Batch.n_seq_id, Batch.n_seq_id + NTokens, + static_cast(NSeqMax)); + for (int64_t I = 0; I < NTokens; I++) { + std::fill(Batch.seq_id[I], Batch.seq_id[I] + NSeqMax, 0); + } + std::fill(Batch.logits, Batch.logits + NTokens, false); + return Batch; +} + +// Fill tokens (smaller than batch size) into a batch with position data. +void fillBatch(Span Tokens, Graph &GraphRef, + llama_batch &Batch, int &NPos, bool IsLogit = false) { + assuming(GraphRef.Params.n_batch >= static_cast(Tokens.size())); + assuming(Batch.token != nullptr); + assuming(Batch.pos != nullptr); + assuming(Batch.logits != nullptr); + // Fill the batch with pos information. + Batch.n_tokens = static_cast(Tokens.size()); + for (uint32_t I = 0; I < Tokens.size(); I++) { + Batch.token[I] = Tokens[I]; + Batch.pos[I] = NPos + I; + Batch.logits[I] = false; + } + + // Logits of sampling or end of inputs. + if (IsLogit) { + Batch.logits[Tokens.size() - 1] = true; + } + + // Move the position. + NPos += static_cast(Tokens.size()); +} + +// Evaluate tokens. Construct the tokens into batch and decode. +ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, + llama_batch &Batch, int &NPos, + bool IsLogits = false) noexcept { + // End the inference if the context is full. + uint32_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); + if (NPos + static_cast(Tokens.size()) > NCtx) { + LOG_INFO( + GraphRef.EnableLog, + "evaluateTokens: the context if full ({} / {} tokens). Please increase your "sv + "context size."sv, + NPos + static_cast(Tokens.size()), NCtx) + return ErrNo::ContextFull; + } + + // Loop for decode batch. Split tokens into batch size length. + for (int I = 0; I < static_cast(Tokens.size()); + I += static_cast(GraphRef.Params.n_batch)) { + int NEval = static_cast(Tokens.size()) - I; + if (NEval > static_cast(GraphRef.Params.n_batch)) { + NEval = static_cast(GraphRef.Params.n_batch); + } + + // Fill the batch with pos information. + fillBatch(Span(Tokens.begin() + I, NEval), GraphRef, + Batch, NPos, + IsLogits && I + NEval >= static_cast(Tokens.size())); + + // Decode the batch. + auto Status = llama_decode(GraphRef.LlamaContext.get(), Batch); + if (Status == 1) { + RET_ERROR( + ErrNo::RuntimeError, + "evaluateTokens: failed to llama_decode: try reducing the size of the batch "sv + "or increasing the size of context."sv) + } + if (Status < 0) { + RET_ERROR( + ErrNo::RuntimeError, + "evaluateTokens: failed to llama_decode: internal fatal error. Please open "sv + "an issue on GitHub."sv) + } + } + + return ErrNo::Success; +} + +// Clear the context and reset the sampler. +void clearContext(Graph &GraphRef, Context &CxtRef) noexcept { + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext"sv) + llama_kv_cache_clear(GraphRef.LlamaContext.get()); + common_sampler_reset(CxtRef.LlamaSampler.get()); + CxtRef.NPos = 0; + CxtRef.LlamaOutputTokens.clear(); + CxtRef.LlamaOutputs.clear(); + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext...Done"sv) +} + +// Evaluate the input tokens. Clean all inputs if succeeded. +ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, + std::string_view LogPrefix) noexcept { + // Check if the input is set before setting up the context. + if (CxtRef.LlamaInputs.size() == 0) { + RET_ERROR(ErrNo::InvalidArgument, "{}: llama input is not set!"sv, + LogPrefix) + } + + // Get the context size. + const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); + // Minus 4 for the special tokens. (Such as , , ... tokens.) + const uint64_t MaxTokensListSize = NCtx - 4; + // Return value. + auto ReturnCode = ErrNo::Success; + + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { + RET_ERROR(ErrNo::PromptTooLong, + "{}: the prompt is too long. Your input has {} tokens. "sv + "Please reduce it to {} tokens."sv, + LogPrefix, CxtRef.LlamaInputs.size(), MaxTokensListSize) + } + + // Evaluate input tokens. + ReturnCode = + evaluateTokens(Span(CxtRef.LlamaInputs.begin(), + CxtRef.LlamaInputs.size()), + GraphRef, CxtRef.LlamaBatch, CxtRef.NPos, true); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, "{}: failed to evaluate input tokens."sv, LogPrefix) + } + + return ErrNo::Success; +} + +// Sample and get the output token. +ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, + bool IsSingleTokenMode = false) noexcept { + // Use idx = -1 to sample the next token. + const llama_token Id = common_sampler_sample( + CxtRef.LlamaSampler.get(), GraphRef.LlamaContext.get(), /* idx */ -1); + common_sampler_accept(CxtRef.LlamaSampler.get(), Id, + /* accept_grammar */ true); + + // Save the output token. + CxtRef.LlamaOutputTokens.emplace_back(Id); + std::string OutputString = + common_token_to_piece(GraphRef.LlamaContext.get(), Id); + CxtRef.LlamaOutputs.insert(CxtRef.LlamaOutputs.end(), OutputString.begin(), + OutputString.end()); + // In single token mode, we do not handle StreamStdout and ReversePrompt. + if (!IsSingleTokenMode) { + // When setting StreamStdout, we print the output to stdout. + if (CxtRef.Conf.StreamStdout) { + fmt::print("{}"sv, + common_token_to_piece(GraphRef.LlamaContext.get(), Id)); + std::fflush(stdout); + } + // Break if reverse prompt is found. + if (!CxtRef.Conf.ReversePrompt.empty() && + std::string(CxtRef.LlamaOutputs.begin(), CxtRef.LlamaOutputs.end()) + .find(CxtRef.Conf.ReversePrompt) != std::string::npos) { + LOG_INFO(GraphRef.EnableLog, "sampleOutput: reverse prompt found."sv) + return ErrNo::EndOfSequence; + } + } + + // Deal with end of text token. + // Only stop on EOS if GraphRef.Params.sparams.ignore_eos is false. + if (!GraphRef.Params.sparams.ignore_eos) { + if (llama_token_is_eog(GraphRef.LlamaModel.get(), Id)) { + LOG_INFO(GraphRef.EnableLog, "sampleOutput: EOS token found."sv) + return ErrNo::EndOfSequence; + } + } + // Evaluate the output token. + return evaluateTokens(Span(&Id, 1), GraphRef, + CxtRef.OutputBatch, CxtRef.NPos, true); +} + +// TODO: Merge into compute. +Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { + LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding"sv) + + const llama_token SepTokenId = llama_token_sep(GraphRef.LlamaModel.get()); + if (SepTokenId > -1) { + if (CxtRef.LlamaInputs.size() > 0 && + CxtRef.LlamaInputs.back() != SepTokenId) { + LOG_WARN( + "getEmbedding: last token in the prompt is not SEP, "sv + "'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF "sv + "header."sv) + } + } + + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > + GraphRef.Params.n_batch) { + RET_ERROR( + ErrNo::PromptTooLong, + "getEmbedding: the prompt is too long. Your input has {} tokens exceeds batch "sv + "size {}. Please reduce the input size or increase your batch-size."sv, + CxtRef.LlamaInputs.size(), GraphRef.Params.n_batch) + } + + // Evaluate the input tokens. + auto ReturnCode = evaluateInput(GraphRef, CxtRef, "getEmbedding"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + + // Main prediction loop. + const struct llama_model *LlamaModel = + llama_get_model(GraphRef.LlamaContext.get()); + const int32_t NEmbd = llama_n_embd(LlamaModel); + std::vector Embeddings(NEmbd); + + for (int I = 0; I < CxtRef.LlamaBatch.n_tokens; I++) { + if (!CxtRef.LlamaBatch.logits[I]) { + continue; + } + + // Try to get sequence embeddings. + auto *Embd = llama_get_embeddings_seq(GraphRef.LlamaContext.get(), + CxtRef.LlamaBatch.seq_id[I][0]); + if (Embd == nullptr) { + Embd = llama_get_embeddings_ith(GraphRef.LlamaContext.get(), I); + if (Embd == nullptr) { + LOG_ERROR("getEmbedding: failed to get embeddings for token {}"sv, I); + continue; + } + } + + // Normalize the embeddings. + common_embd_normalize(Embd, Embeddings.data(), NEmbd, + static_cast(CxtRef.Conf.EmbdNormalize)); + } + + std::string EmbeddingString; + buildOutputEmbedding(EmbeddingString, NEmbd, Embeddings.data()); + CxtRef.LlamaOutputs = + std::vector(EmbeddingString.begin(), EmbeddingString.end()); + + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), /* Sampler */ nullptr); + } + + LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding...Done"sv) + return ErrNo::Success; +} + +// <<<<<<<< Compute related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + +} // namespace + +Expect load(WasiNNEnvironment &Env, Span> Builders, + [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { + if (Builders.empty()) { + RET_ERROR(ErrNo::InvalidArgument, + "Invalid builders size, builders size must be > 0."); + } + + // Add a graph + const uint32_t GId = Env.newGraph(Backend::BitNet); + auto &GraphRef = Env.NNGraph[GId].get(); + + // Initialize the plugin parameters. + GraphRef.EnableLog = false; + GraphRef.EnableDebugLog = false; + const common_params CommonParamsDefault; + GraphRef.Params = CommonParamsDefault; + GraphRef.Params.n_keep = 0; + GraphRef.Params.n_chunks = -1; + GraphRef.Params.n_parallel = 1; + GraphRef.Params.grp_attn_n = 1; + GraphRef.Params.grp_attn_w = 512; + GraphRef.Params.n_print = -1; + GraphRef.Params.split_mode = llama_split_mode::LLAMA_SPLIT_MODE_LAYER; + // Initialize the model parameters. + llama_model_params ModelParamsDefault = llama_model_default_params(); + GraphRef.Params.n_gpu_layers = ModelParamsDefault.n_gpu_layers; + GraphRef.Params.mmproj = ""sv; + GraphRef.Params.warmup = false; + // Initialize the context parameters. + llama_context_params ContextParamsDefault = llama_context_default_params(); + GraphRef.Params.n_ctx = ContextParamsDefault.n_ctx; + GraphRef.Params.n_batch = ContextParamsDefault.n_batch; + GraphRef.Params.n_ubatch = ContextParamsDefault.n_ubatch; + GraphRef.Params.cpuparams.n_threads = ContextParamsDefault.n_threads_batch; + GraphRef.Params.cpuparams_batch.n_threads = + ContextParamsDefault.n_threads_batch; + GraphRef.Params.rope_scaling_type = ContextParamsDefault.rope_scaling_type; + GraphRef.Params.pooling_type = ContextParamsDefault.pooling_type; + GraphRef.Params.attention_type = ContextParamsDefault.attention_type; + GraphRef.Params.rope_freq_base = ContextParamsDefault.rope_freq_base; + GraphRef.Params.rope_freq_scale = ContextParamsDefault.rope_freq_scale; + GraphRef.Params.yarn_ext_factor = ContextParamsDefault.yarn_ext_factor; + GraphRef.Params.yarn_attn_factor = ContextParamsDefault.yarn_attn_factor; + GraphRef.Params.yarn_beta_fast = ContextParamsDefault.yarn_beta_fast; + GraphRef.Params.yarn_beta_slow = ContextParamsDefault.yarn_beta_slow; + GraphRef.Params.yarn_orig_ctx = ContextParamsDefault.yarn_orig_ctx; + GraphRef.Params.defrag_thold = ContextParamsDefault.defrag_thold; + GraphRef.Params.cb_eval = ContextParamsDefault.cb_eval; + GraphRef.Params.cb_eval_user_data = ContextParamsDefault.cb_eval_user_data; + GraphRef.Params.embedding = ContextParamsDefault.embeddings; + GraphRef.Params.no_kv_offload = !ContextParamsDefault.offload_kqv; + GraphRef.Params.flash_attn = ContextParamsDefault.flash_attn; + GraphRef.Params.no_perf = ContextParamsDefault.no_perf; + + // Initialize the sampling parameters. + const common_sampler_params SamplerParamsDefault; + GraphRef.Params.sparams = SamplerParamsDefault; + + // Initialize the config parameters. + GraphRef.Conf.StreamStdout = false; + GraphRef.Conf.EmbdNormalize = + static_cast(CommonParamsDefault.embd_normalize); + GraphRef.Conf.NPredict = ContextParamsDefault.n_ctx; + GraphRef.Conf.ReversePrompt = ""sv; + + // Set llama log callback. + llama_log_set(llamaLogCallback, &GraphRef); + LOG_DEBUG(GraphRef.EnableDebugLog, "load start."sv) + + // If the graph builder length > 1, the data of builder[1] is the metadata. + if (Builders.size() > 1) { + const std::string Metadata( + reinterpret_cast(Builders[1].data()), Builders[1].size()); + // Ignore context or model updates when initializing the graph. + auto Res = parseMetadata(GraphRef, GraphRef.Conf, Metadata); + if (Res != ErrNo::Success) { + Env.deleteGraph(GId); + RET_ERROR(Res, "load: Failed to parse metadata."sv); + } + } + + LOG_INFO(GraphRef.EnableLog, "LLAMA_COMMIT {}"sv, LLAMA_COMMIT) + LOG_INFO(GraphRef.EnableLog, "LLAMA_BUILD_NUMBER {}"sv, LLAMA_BUILD_NUMBER) + + LOG_DEBUG(GraphRef.EnableDebugLog, "load: handling model path."sv) + const auto &Weight = Builders[0]; + const std::string_view BinModel(reinterpret_cast(Weight.data()), + Weight.size()); + + if (BinModel.substr(0, 8) == "preload:"sv) { + GraphRef.Params.model = std::string(BinModel.substr(8)); + } else { + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: Model path not found in nn-preload, write model into "sv + "a tmpfile."sv) + GraphRef.Params.model = "bitnet-model.bin"sv; + std::ofstream TempFile(GraphRef.Params.model, + std::ios::out | std::ios::binary | std::ios::trunc); + if (!TempFile) { + Env.deleteGraph(GId); + RET_ERROR(ErrNo::InvalidArgument, "Failed to create temp model file."sv) + } + TempFile.write(BinModel.data(), BinModel.size()); + TempFile.close(); + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: Write model into a tmpfile...Done"sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, "load: handling model path...Done"sv) + + // Check if the model exists. + if (!std::filesystem::exists( + std::filesystem::u8path(GraphRef.Params.model))) { + Env.deleteGraph(GId); + RET_ERROR(ErrNo::ModelNotFound, + "load: Model file not found at path: '{}'."sv, + GraphRef.Params.model) + } + + LOG_INFO(GraphRef.EnableLog, "load: Loading model from '{}'."sv, + GraphRef.Params.model) + + // Initialize model parameters. + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: initialize model with given parameters."sv) + + llama_backend_init(); + llama_numa_init(GraphRef.Params.numa); + + // Initialize the llama model and context. + common_init_result LlamaInit = common_init_from_params(GraphRef.Params); + GraphRef.LlamaModel.reset(LlamaInit.model); + GraphRef.LlamaContext.reset(LlamaInit.context); + + if (GraphRef.LlamaModel == nullptr) { + Env.deleteGraph(GId); + RET_ERROR(ErrNo::InvalidArgument, "load: Unable to init model."sv) + } + if (GraphRef.LlamaContext == nullptr) { + Env.deleteGraph(GId); + RET_ERROR(ErrNo::InvalidArgument, "load: Unable to init context."sv) + } + + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: initialize model with given parameters...Done"sv) + + // Store the loaded graph. + GraphId = GId; + Env.NNGraph[GId].setReady(); + + LOG_DEBUG(GraphRef.EnableDebugLog, "load...Done"sv) + return ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx"sv) + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + auto &CxtRef = Env.NNContext[ContextId].get(); + + LOG_INFO(GraphRef.EnableLog, "llama_system_info: {}"sv, + llama_print_system_info()) + + // Allocate the batch for input string prompt tokens. + CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); + CxtRef.CurrentBatchSize = GraphRef.Params.n_batch; + + // Allocate the batch for single-token output sampling. + CxtRef.OutputBatch = allocBatch(1); + + // Allocate the sampler + CxtRef.LlamaSampler.reset( + common_sampler_init(GraphRef.LlamaModel.get(), GraphRef.Params.sparams)); + + Env.NNContext[ContextId].setReady(); + LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx...Done"sv) + return ErrNo::Success; +} + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput"sv) + + // Handle Metadata at Index 1 + if (Index == 1) { + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: found Metadata, processing"sv) + bool IsModelUpdated = false; + bool IsContextUpdated = false; + bool IsSamplerUpdated = false; + const std::string Metadata( + reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + + auto Res = parseMetadata(GraphRef, CxtRef.Conf, Metadata, &IsModelUpdated, + &IsContextUpdated, &IsSamplerUpdated); + if (Res != ErrNo::Success) { + RET_ERROR(Res, "setInput: failed to parse metadata."sv) + } + + if (IsModelUpdated || GraphRef.LlamaModel == nullptr) { + // The llama model may be nullptr if set_input with updated model params + // last time. Therefore besides the model params updated, we should + // reload the llama model if the model is nullptr. + LOG_INFO(GraphRef.EnableLog, + "setInput: Reloading model due to parameter change"sv) + + // Prepare model parameters for the reload. + llama_model_params ModelParams = llama_model_default_params(); + ModelParams.n_gpu_layers = + static_cast(GraphRef.Params.n_gpu_layers); + ModelParams.main_gpu = static_cast(GraphRef.Params.main_gpu); + + // Free all resources that depend on the old model. + GraphRef.LlamaModel.reset(); + + // Due to the model change, the context and sampler should also be + // reloaded. The new context and sampler will be created in the next + // block. + GraphRef.LlamaContext.reset(); + if (CxtRef.LlamaSampler) { + CxtRef.LlamaSampler.reset(); + CxtRef.LlamaSampler = nullptr; + } + + // Attempt to load the model from file with new parameters. + GraphRef.LlamaModel.reset(llama_load_model_from_file( + GraphRef.Params.model.c_str(), ModelParams)); + if (GraphRef.LlamaModel == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init model."sv) + } + } + + // Reload context if its parameters changed OR if it was cleared by a model + // reload. + if (IsContextUpdated || GraphRef.LlamaContext == nullptr) { + LOG_INFO(GraphRef.EnableLog, + "setInput: Reloading llama context due to parameter change."sv) + GraphRef.LlamaContext.reset(); + llama_context_params CtxParams = + common_context_params_to_llama(GraphRef.Params); + GraphRef.LlamaContext.reset( + llama_new_context_with_model(GraphRef.LlamaModel.get(), CtxParams)); + if (GraphRef.LlamaContext == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init context."sv) + } + } + + // Re-initialize sampler if its parameters changed OR if it was cleared. + if (IsSamplerUpdated || CxtRef.LlamaSampler == nullptr) { + LOG_INFO(GraphRef.EnableLog, + "setInput: Re-initializing sampler due to parameter change."sv); + CxtRef.LlamaSampler.reset(common_sampler_init(GraphRef.LlamaModel.get(), + GraphRef.Params.sparams)); + if (CxtRef.LlamaSampler == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init sampler."sv) + } + } + + // Re-allocate batch if the batch size changed. + if (CxtRef.CurrentBatchSize != GraphRef.Params.n_batch) { + LOG_INFO(GraphRef.EnableLog, + "Re-allocating batch due to n_batch change."); + llama_batch_free(CxtRef.LlamaBatch); + CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); + if (!CxtRef.LlamaBatch.token) { + RET_ERROR(ErrNo::InvalidArgument, "Failed to re-allocate llama_batch."); + } + CxtRef.CurrentBatchSize = GraphRef.Params.n_batch; + } + + Env.NNGraph[CxtRef.GraphId].setReady(); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: metadata processing...Done"); + return ErrNo::Success; + } + + if (Index != 0) { + RET_ERROR(ErrNo::InvalidArgument, + "Only prompt (index 0) and metadata (index 1) are supported."); + } + + // Check the graph is valid after reloading during previous set_input. + if (!Env.NNGraph[CxtRef.GraphId].isReady()) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: Graph is invalid. Please reload again by passing metadata "sv + "in set_input or unload graph."sv) + } + + LOG_DEBUG(GraphRef.EnableLog, "setInput: Clearing KV cache for new prompt."sv) + llama_kv_cache_clear(GraphRef.LlamaContext.get()); + LOG_DEBUG(GraphRef.EnableLog, + "setInput: Clearing KV cache for new prompt...done"sv) + + // Check tensor type. + if (Tensor.RType != TensorType::U8) { + RET_ERROR(ErrNo::InvalidArgument, + "Input tensor must be a UTF-8 string (U8)."); + } + + // Tokenize the new prompt. + const std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt"sv) + CxtRef.LlamaInputs = + common_tokenize(GraphRef.LlamaContext.get(), Prompt, + llama_add_bos_token(GraphRef.LlamaModel.get()), true); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt...Done"sv) + + // Get the number of input tokens (for the metadata). + CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); + + // Reset state for the compute loop. + CxtRef.ComputeSingleStarted = false; + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput...Done"sv) + return ErrNo::Success; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}"sv, Index) + + // Handle Metadata Output at Index 1 + if (Index == 1) { + const std::string Metadata = buildOutputMetadata(CxtRef); + const size_t BytesToCopy = + std::min(static_cast(OutBuffer.size()), Metadata.length()); + std::copy_n(Metadata.data(), BytesToCopy, OutBuffer.data()); + BytesWritten = static_cast(Metadata.length()); + + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: Metadata (Index 1)...Done"sv) + return ErrNo::Success; + } + + const size_t BytesToCopy = std::min(static_cast(OutBuffer.size()), + CxtRef.LlamaOutputs.size()); + std::copy_n(CxtRef.LlamaOutputs.data(), BytesToCopy, OutBuffer.data()); + BytesWritten = CxtRef.LlamaOutputs.size(); + + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: Text (Index 0)...Done"sv) + return ErrNo::Success; +} + +Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "compute"); + + // Clear the context and reset the sampler. + clearContext(GraphRef, CxtRef); + + if (GraphRef.Params.embedding) { + return getEmbedding(GraphRef, CxtRef); + } + + // Evaluate the input tokens. + auto ReturnCode = evaluateInput(GraphRef, CxtRef, "compute"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + + // Main prediction loop. + LOG_DEBUG(GraphRef.EnableDebugLog, "compute: enter main prediction loop"sv) + int64_t NPredict = CxtRef.Conf.NPredict; + if (NPredict < 0) { + NPredict = INT32_MAX; + } + + while (NPredict > 0) { + ReturnCode = sampleOutput(GraphRef, CxtRef); + if (ReturnCode != ErrNo::Success) { + break; + } + NPredict--; + } + + if (ReturnCode == ErrNo::EndOfSequence || ReturnCode == ErrNo::ContextFull) { + LOG_INFO(GraphRef.EnableLog, "compute finished with status: {}."sv, + static_cast(ReturnCode)) + return ErrNo::Success; + } + + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: enter main prediction loop...Done"sv) + + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler.get()); + } + + LOG_DEBUG(GraphRef.EnableDebugLog, "compute...Done") + return ReturnCode; +} + +Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: with Index {}"sv, Index) + + // Metadata Output at Index 1 + if (Index == 1) { + const std::string Metadata = buildOutputMetadata(CxtRef); + const size_t BytesToCopy = + std::min(static_cast(OutBuffer.size()), Metadata.length()); + std::copy_n(Metadata.data(), BytesToCopy, OutBuffer.data()); + BytesWritten = static_cast(Metadata.length()); + + LOG_DEBUG(GraphRef.EnableDebugLog, + "getOutputSingle: Metadata (Index 1)...Done"sv) + return ErrNo::Success; + } + + if (CxtRef.LlamaOutputTokens.empty()) { + BytesWritten = 0; + return ErrNo::Success; + } + + const std::string LastTokenStr = common_token_to_piece( + GraphRef.LlamaContext.get(), CxtRef.LlamaOutputTokens.back()); + + const size_t BytesToCopy = + std::min(static_cast(OutBuffer.size()), LastTokenStr.length()); + std::copy_n(LastTokenStr.data(), BytesToCopy, OutBuffer.data()); + BytesWritten = LastTokenStr.length(); + + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: Text (Index 0)...Done"sv) + return ErrNo::Success; +} + +Expect computeSingle(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle"sv) + + auto ReturnCode = ErrNo::Success; + if (!CxtRef.ComputeSingleStarted) { + // Clear the context and reset the sampler. + clearContext(GraphRef, CxtRef); + ReturnCode = evaluateInput(GraphRef, CxtRef, "computeSingle"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + + CxtRef.ComputeSingleStarted = true; + } + + // Main prediction process. + LOG_DEBUG(GraphRef.EnableDebugLog, + "computeSingle: enter main prediction process"sv) + ReturnCode = sampleOutput(GraphRef, CxtRef, /* IsSingleTokenMode */ true); + if (ReturnCode != ErrNo::Success) { + CxtRef.ComputeSingleStarted = false; + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "computeSingle: enter main prediction process...Done"sv) + // End of main predict process. + + LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle...Done"sv) + return ReturnCode; +} + +Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle"); + + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler.get()); + } + + // Reset the llama sampler. + common_sampler_reset(CxtRef.LlamaSampler.get()); + + // Clear the outputs. + LOG_DEBUG(GraphRef.EnableDebugLog, + "finiSingle: clear the previous output and tokens"sv) + CxtRef.LlamaOutputs.clear(); + CxtRef.LlamaOutputTokens.clear(); + LOG_DEBUG(GraphRef.EnableDebugLog, + "finiSingle: clear the previous output and tokens...Done"sv) + + CxtRef.NPos = 0; + CxtRef.ComputeSingleStarted = false; + + LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle...Done"sv) + return ErrNo::Success; +} + +Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { + + auto &GraphRef = Env.NNGraph[GraphId].get(); + if (GraphId >= Env.NNGraph.size() || GraphRef.LlamaModel == nullptr) { + return ErrNo::Success; + } + + LOG_DEBUG(GraphRef.EnableDebugLog, "unload"sv) + + Env.deleteGraph(GraphId); + + LOG_DEBUG(GraphRef.EnableDebugLog, "unload...Done"sv) + return ErrNo::Success; +} + +Expect finalizeExecCtx(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "finalizeExecCtx"sv) + + CxtRef.LlamaSampler.reset(); + llama_batch_free(CxtRef.LlamaBatch); + llama_batch_free(CxtRef.OutputBatch); + Env.deleteContext(ContextId); + + LOG_DEBUG(GraphRef.EnableDebugLog, "finalizeExecCtx...Done"sv) + return ErrNo::Success; +} + +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] BitNet backend is not built. Please build with " + "-DWASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET=ON."sv); + return ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WasiNNEnvironment &, Span>, Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WasiNNEnvironment &, uint32_t, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect getOutputSingle(WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect computeSingle(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect finiSingle(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect unload(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect finalizeExecCtx(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} + +#endif +} // namespace WasmEdge::Host::WASINN::BitNet \ No newline at end of file diff --git a/plugins/wasi_nn/wasinn_bitnet.h b/plugins/wasi_nn/wasinn_bitnet.h new file mode 100644 index 00000000..3361c40d --- /dev/null +++ b/plugins/wasi_nn/wasinn_bitnet.h @@ -0,0 +1,140 @@ +#pragma once + +#include "plugin/plugin.h" +#include "wasinntypes.h" +#include + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET +#include +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} + +namespace WasmEdge::Host::WASINN::BitNet { + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET + +struct LlamaModelDeleter { + void operator()(llama_model *Ptr) const { + if (Ptr) { + llama_free_model(Ptr); + } + } +}; +struct LlamaContextDeleter { + void operator()(llama_context *Ptr) const { + if (Ptr) { + llama_free(Ptr); + } + } +}; +struct CommonSamplerDeleter { + void operator()(common_sampler *Ptr) const { + if (Ptr) { + common_sampler_free(Ptr); + } + } +}; + +using LlamaModelPtr = std::unique_ptr; +using LlamaContextPtr = std::unique_ptr; +using CommonSamplerPtr = std::unique_ptr; + +enum class EmbdNormalizeType : int32_t { + // From: https://github.com/ggerganov/llama.cpp/blob/master/common/common.h + None = -1, + MaxAbsolute = 0, + Taxicab = 1, + Euclidean = 2, + PNorm = 3, +}; + +struct LocalConfig { + int64_t NPredict = -1; + bool StreamStdout = false; + std::string ReversePrompt; + EmbdNormalizeType EmbdNormalize = EmbdNormalizeType::Euclidean; +}; + +struct Graph { + bool EnableLog = false; + bool EnableDebugLog = false; + common_params Params; + LlamaModelPtr LlamaModel = nullptr; + LlamaContextPtr LlamaContext = nullptr; + LocalConfig Conf; +}; + +struct Context { +public: + Context(uint32_t GId, Graph &G) noexcept : GraphId(GId), Conf(G.Conf) {} + + uint32_t GraphId; + bool ComputeSingleStarted = false; + + int32_t NPos = 0; + std::vector LlamaInputs; + uint64_t LlamaNInputs = 0; + std::vector LlamaOutputTokens; + std::vector LlamaOutputs; + CommonSamplerPtr LlamaSampler = nullptr; + int64_t CurrentBatchSize = 0; + struct llama_batch LlamaBatch; + struct llama_batch OutputBatch; + + LocalConfig Conf; +}; + +#else + +struct Graph {}; +struct Context { + Context(uint32_t, Graph &) noexcept {} +}; +#endif + +struct Environ {}; + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; + +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; + +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; + +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; + +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; + +Expect getOutputSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; + +Expect computeSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; + +Expect finiSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; + +Expect finalizeExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; + +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; + +} // namespace WasmEdge::Host::WASINN::BitNet \ No newline at end of file diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp index e09d3780..6f4a1c2e 100644 --- a/plugins/wasi_nn/wasinnenv.cpp +++ b/plugins/wasi_nn/wasinnenv.cpp @@ -38,7 +38,8 @@ std::map BackendMap = { {"mlx"sv, Backend::MLX}, {"piper"sv, Backend::Piper}, {"chattts"sv, Backend::ChatTTS}, - {"openvinogenai"sv, Backend::OpenVINOGenAI}}; + {"openvinogenai"sv, Backend::OpenVINOGenAI}, + {"bitnet"sv, Backend::BitNet}}; std::map DeviceMap = {{"cpu"sv, Device::CPU}, {"gpu"sv, Device::GPU}, @@ -110,6 +111,7 @@ WasiNNEnvironment::WasiNNEnvironment() noexcept { auto Device = DeviceMap.find(Target); if (Backend != BackendMap.end() && Device != DeviceMap.end()) { if (Backend->second == Backend::GGML || + Backend->second == Backend::BitNet || (Backend->second == Backend::PyTorch && Encode == "pytorchaoti"sv)) { // In GGML, we only support loading one model from nn-preload // config. To handle paths on Windows that contains `:` in the diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 95c54279..956b83bd 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -3,6 +3,7 @@ #pragma once +#include "wasinn_bitnet.h" #include "wasinn_chattts.h" #include "wasinn_ggml.h" #include "wasinn_mlx.h" diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index f1171393..390045ca 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -499,9 +499,13 @@ Expect WasiNNGetOutputSingle::bodyImpl( case WASINN::Backend::GGML: return WASINN::GGML::getOutputSingle(Env, ContextId, Index, OutBuffer, *BytesWritten); + case WASINN::Backend::BitNet: + return WASINN::BitNet::getOutputSingle(Env, ContextId, Index, OutBuffer, + *BytesWritten); default: - spdlog::error("[WASI-NN] get_output_single: Only GGML backend supports "sv - "get_output_single."sv); + spdlog::error( + "[WASI-NN] get_output_single: Only GGML and BitNet backend supports "sv + "get_output_single."sv); return WASINN::ErrNo::InvalidArgument; } } @@ -601,9 +605,12 @@ WasiNNComputeSingle::bodyImpl(const Runtime::CallingFrame &Frame, switch (Env.NNContext[ContextId].getBackend()) { case WASINN::Backend::GGML: return WASINN::GGML::computeSingle(Env, ContextId); + case WASINN::Backend::BitNet: + return WASINN::BitNet::computeSingle(Env, ContextId); default: - spdlog::error("[WASI-NN] compute_single: Only GGML backend supports "sv - "compute_single."sv); + spdlog::error( + "[WASI-NN] compute_single: Only GGML and BitNet backend supports "sv + "compute_single."sv); return WASINN::ErrNo::InvalidArgument; } } @@ -642,9 +649,11 @@ WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, switch (Env.NNContext[ContextId].getBackend()) { case WASINN::Backend::GGML: return WASINN::GGML::finiSingle(Env, ContextId); + case WASINN::Backend::BitNet: + return WASINN::BitNet::finiSingle(Env, ContextId); default: spdlog::error( - "[WASI-NN] fini_single: Only GGML backend supports fini_single."sv); + "[WASI-NN] fini_single: Only GGML and BitNet backend supports fini_single."sv); return WASINN::ErrNo::InvalidArgument; } } @@ -675,8 +684,10 @@ Expect WasiNNUnload::bodyImpl(const Runtime::CallingFrame &Frame, return WASINN::Whisper::unload(Env, GraphId); case WASINN::Backend::ChatTTS: return WASINN::ChatTTS::unload(Env, GraphId); + case WASINN::Backend::BitNet: + return WASINN::BitNet::unload(Env, GraphId); default: - spdlog::error("[WASI-NN] unlaod: Only GGML, Whisper, and ChatTTS "sv + spdlog::error("[WASI-NN] unload: Only GGML, Whisper, ChatTTS and BitNet "sv "backends support unload."sv); return WASINN::ErrNo::InvalidArgument; } @@ -710,9 +721,12 @@ WasiNNFinalizeExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, return WASINN::GGML::finalizeExecCtx(Env, ContextId); case WASINN::Backend::Whisper: return WASINN::Whisper::finalizeExecCtx(Env, ContextId); + case WASINN::Backend::BitNet: + return WASINN::BitNet::finalizeExecCtx(Env, ContextId); default: - spdlog::error("[WASI-NN] finalize_execution_context: Only GGML and "sv - "Whisper backends support finalize_execution_context."sv); + spdlog::error( + "[WASI-NN] finalize_execution_context: Only GGML, BitNet and "sv + "Whisper backends support finalize_execution_context."sv); return WASINN::ErrNo::InvalidArgument; } } diff --git a/plugins/wasi_nn/wasinntypes.h b/plugins/wasi_nn/wasinntypes.h index d6e37bd1..d82873d3 100644 --- a/plugins/wasi_nn/wasinntypes.h +++ b/plugins/wasi_nn/wasinntypes.h @@ -51,6 +51,7 @@ enum class Backend : uint8_t { Piper = 11, ChatTTS = 12, OpenVINOGenAI = 13, + BitNet = 14, }; #define FOR_EACH_BACKEND(F) \ @@ -65,7 +66,8 @@ enum class Backend : uint8_t { F(Piper) \ F(ChatTTS) \ F(MLX) \ - F(OpenVINOGenAI) + F(OpenVINOGenAI) \ + F(BitNet) struct TensorData { Span Dimension; diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 9fdb3868..66f37d78 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -125,6 +125,19 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) target_compile_options(wasiNNTests PUBLIC -Wno-unused-parameter ) + + elseif(BACKEND STREQUAL "bitnet") + message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_bitnet_fixtures") + download( + https://huggingface.co/microsoft/bitnet-b1.58-2B-4T-gguf/resolve/main/ggml-model-i2_s.gguf + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_bitnet_fixtures/ggml-model-i2_s.gguf + MD5=65cb04366e4d02ccd78b4b7b48c84b3b + ) + if(NOT CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + target_compile_options(wasiNNTests PUBLIC + -Wno-unused-function + ) + endif() else() # Add the other backend test files fetching here. endif() diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 50e286ce..df73f023 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -29,7 +29,8 @@ using WasmEdge::Host::WASINN::TensorType; defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_WHISPER) || \ defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_CHATTTS) || \ - defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX) + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX) || \ + defined(WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET) namespace { template @@ -3112,3 +3113,352 @@ TEST(WasiNNTest, MLXBackend) { } } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET +TEST(WasiNNTest, BitNetBackend) { + // Create the wasi_nn module instance. + auto NNMod = createModule(); + ASSERT_TRUE(NNMod); + + // Create the calling frame with a memory instance. + WasmEdge::Runtime::Instance::ModuleInstance Mod(""); + Mod.addHostMemory( + "memory", std::make_unique( + WasmEdge::AST::MemoryType(60000))); + auto *MemInstPtr = Mod.findMemoryExports("memory"); + ASSERT_TRUE(MemInstPtr != nullptr); + auto &MemInst = *MemInstPtr; + WasmEdge::Runtime::CallingFrame CallFrame(nullptr, &Mod); + + // --- Get host functions --- + // Get the function "load". + auto *FuncInst = NNMod->findFuncExports("load"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncLoad = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "init_execution_context". + FuncInst = NNMod->findFuncExports("init_execution_context"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncInit = dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "set_input". + FuncInst = NNMod->findFuncExports("set_input"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncSetInput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "get_output". + FuncInst = NNMod->findFuncExports("get_output"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncGetOutput = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute". + FuncInst = NNMod->findFuncExports("compute"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncCompute = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "unload". + FuncInst = NNMod->findFuncExports("unload"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncUnload = + dynamic_cast(FuncInst->getHostFunc()); + // Get the function "compute_single". + FuncInst = NNMod->findFuncExports("compute_single"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncComputeSingle = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "get_output_single". + FuncInst = NNMod->findFuncExports("get_output_single"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncGetOutputSingle = + dynamic_cast( + FuncInst->getHostFunc()); + // Get the function "fini_single". + FuncInst = NNMod->findFuncExports("fini_single"); + ASSERT_NE(FuncInst, nullptr); + auto &HostFuncFiniSingle = + dynamic_cast(FuncInst->getHostFunc()); + + // --- Test Data & Pointer Setup --- + const std::string ModelPath = "./wasinn_bitnet_fixtures/ggml-model-i2_s.gguf"; + const std::string ModelPreloadStr = "preload:" + ModelPath; + const std::string MetadataStr = R"({"n-predict": 128})"; + const std::string Prompt = "Once upon a time, "; + + uint32_t BuilderPtr = 0; + uint32_t LoadEntryPtr = 0; + uint32_t SetInputEntryPtr = 0; + uint32_t StorePtr = 65536; + const uint32_t OutBoundPtr = UINT32_C(61000) * UINT32_C(65536); + std::array Errno = {0}; + uint32_t GraphId = 0; + uint32_t CtxId = 0; + + // BitNet WASI-NN load tests + // Test: load -- empty builder array. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, 0, static_cast(Backend::BitNet), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- graph builder ptr out of bounds. + { + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + OutBoundPtr, 1, static_cast(Backend::BitNet), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- model bin ptr out of bounds. + { + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, OutBoundPtr, 10, BuilderPtr); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, 1, static_cast(Backend::BitNet), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: load -- invalid metadata encoding. + { + const std::string InvalidMetadataStr = R"({"n-predict": "not-a-number")"; + std::vector ModelPreloadVec(ModelPreloadStr.begin(), + ModelPreloadStr.end()); + std::vector InvalidMetadataVec(InvalidMetadataStr.begin(), + InvalidMetadataStr.end()); + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, ModelPreloadVec.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr + ModelPreloadVec.size(), + InvalidMetadataVec.size(), BuilderPtr); + writeBinaries(MemInst, ModelPreloadVec, StorePtr); + writeBinaries(MemInst, InvalidMetadataVec, + StorePtr + ModelPreloadVec.size()); + EXPECT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, 2, static_cast(Backend::BitNet), + static_cast(Device::CPU), BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidEncoding)); + } + // Test: load -- load successfully. + { + std::vector ModelPreloadVec(ModelPreloadStr.begin(), + ModelPreloadStr.end()); + std::vector MetadataVec(MetadataStr.begin(), MetadataStr.end()); + BuilderPtr = LoadEntryPtr; + writeFatPointer(MemInst, StorePtr, ModelPreloadVec.size(), BuilderPtr); + writeFatPointer(MemInst, StorePtr + ModelPreloadVec.size(), + MetadataVec.size(), BuilderPtr); + writeBinaries(MemInst, ModelPreloadVec, StorePtr); + writeBinaries(MemInst, MetadataVec, + StorePtr + ModelPreloadVec.size()); + StorePtr += ModelPreloadVec.size() + MetadataVec.size(); + ASSERT_TRUE(HostFuncLoad.run( + CallFrame, + std::initializer_list{ + LoadEntryPtr, 2, static_cast(Backend::BitNet), + static_cast(Device::CPU), BuilderPtr}, + Errno)) + << "Load failed. Ensure model file exists at: " << ModelPath; + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + GraphId = *MemInst.getPointer(BuilderPtr); + BuilderPtr += 4; + } + + // BitNet WASI-NN init_execution_context tests + // Test: init_execution_context -- graph id invalid. + { + EXPECT_TRUE(HostFuncInit.run( + CallFrame, std::initializer_list{999, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: init_execution_context -- init context successfully. + { + ASSERT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{GraphId, BuilderPtr}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + CtxId = *MemInst.getPointer(BuilderPtr); + BuilderPtr += 4; + } + + // BitNet WASI-NN set_input tests + SetInputEntryPtr = BuilderPtr; + { + std::vector PromptData(Prompt.begin(), Prompt.end()); + std::vector PromptDim = { + static_cast(PromptData.size())}; + + writeFatPointer(MemInst, StorePtr, PromptDim.size(), BuilderPtr); + writeUInt32(MemInst, static_cast(TensorType::U8), BuilderPtr); + writeFatPointer(MemInst, StorePtr + PromptDim.size() * 4, PromptData.size(), + BuilderPtr); + writeBinaries(MemInst, PromptDim, StorePtr); + writeBinaries(MemInst, PromptData, + StorePtr + PromptDim.size() * 4); + } + // Test: set_input -- invalid context id. + { + EXPECT_TRUE(HostFuncSetInput.run( + CallFrame, + std::initializer_list{999, 0, SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: set_input -- invalid tensor index. + { + EXPECT_TRUE(HostFuncSetInput.run( + CallFrame, + std::initializer_list{CtxId, 2, SetInputEntryPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: set_input -- set input successfully. + { + ASSERT_TRUE(HostFuncSetInput.run( + CallFrame, + std::initializer_list{CtxId, 0, SetInputEntryPtr}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // BitNet WASI-NN compute and get_output tests + // Test: compute -- invalid context ID. + { + EXPECT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{999}, Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: compute -- compute successfully. + { + ASSERT_TRUE(HostFuncCompute.run( + CallFrame, std::initializer_list{CtxId}, Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + // Test: get_output -- output buffer pointer out of bounds. + { + EXPECT_TRUE( + HostFuncGetOutput.run(CallFrame, + std::initializer_list{ + CtxId, 0, OutBoundPtr, 5, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- bytes written pointer out of bounds. + { + EXPECT_TRUE( + HostFuncGetOutput.run(CallFrame, + std::initializer_list{ + CtxId, 0, StorePtr, 5, OutBoundPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: get_output -- get output successfully. + { + uint32_t BytesNeeded = 0; + ASSERT_TRUE( + HostFuncGetOutput.run(CallFrame, + std::initializer_list{ + CtxId, 0, StorePtr, 0, BuilderPtr}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + BytesNeeded = *MemInst.getPointer(BuilderPtr); + EXPECT_GT(BytesNeeded, 10); + } + + // BitNet compute_single tests + { + std::string FullStreamedOutput = ""; + const int MaxStreamTokens = 20; + + // Test: set_input -- set prompt to start a new streaming sequence. + { + ASSERT_TRUE( + HostFuncSetInput.run(CallFrame, + std::initializer_list{ + CtxId, 0, SetInputEntryPtr}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + + // Test: compute_single and get_output_single in a loop. + for (int i = 0; i < MaxStreamTokens; ++i) { + ASSERT_TRUE(HostFuncComputeSingle.run( + CallFrame, std::initializer_list{CtxId}, + Errno)); + if (Errno[0].get() == + static_cast(ErrNo::EndOfSequence)) { + break; + } + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + + uint32_t SingleTokenBytes = 0; + ASSERT_TRUE(HostFuncGetOutputSingle.run( + CallFrame, + std::initializer_list{CtxId, 0, StorePtr, 32, + BuilderPtr}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + SingleTokenBytes = *MemInst.getPointer(BuilderPtr); + if (SingleTokenBytes > 0) { + auto TokenSpan = *MemInst.getBytes(StorePtr, SingleTokenBytes); + FullStreamedOutput += std::string( + reinterpret_cast(TokenSpan.data()), TokenSpan.size()); + } + } + EXPECT_GT(FullStreamedOutput.length(), 10); + + // Test: fini_single -- finalize the streaming session. + { + ASSERT_TRUE(HostFuncFiniSingle.run( + CallFrame, std::initializer_list{CtxId}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + } + } + + // BitNet WASI-NN unload tests. + // Test: unload -- invalid graph id. + { + EXPECT_TRUE(HostFuncUnload.run( + CallFrame, std::initializer_list{999}, Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } + // Test: unload -- unload successfully and verify. + { + ASSERT_TRUE(HostFuncUnload.run( + CallFrame, std::initializer_list{GraphId}, + Errno)); + ASSERT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); + + EXPECT_TRUE(HostFuncInit.run( + CallFrame, + std::initializer_list{GraphId, BuilderPtr}, + Errno)); + EXPECT_EQ(Errno[0].get(), + static_cast(ErrNo::InvalidArgument)); + } +} +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET \ No newline at end of file From 0e38fd53d9e64762a50769758e2269b6b3c2f0d5 Mon Sep 17 00:00:00 2001 From: Vishruth Thimmaiah Date: Thu, 28 Aug 2025 13:44:56 +0530 Subject: [PATCH 576/623] feat: support WasmEdge on the s390x platform (#4251) * feat: enable building the interpreter on s390x hardware Implements support for running the interpreter mode on s390x hardware. Signed-off-by: vishruth-thimmaiah * feat: add support for wasinn-ggml backend on s390x Signed-off-by: vishruth-thimmaiah * feat: aot and jit support on s390x Signed-off-by: vishruth-thimmaiah --------- Signed-off-by: vishruth-thimmaiah --- plugins/wasi_nn/wasinn_ggml.cpp | 29 ++++++++++++++++------------- plugins/wasi_nn/wasinnenv.h | 2 +- plugins/wasi_nn/wasinnfunc.cpp | 23 ++++++++++++++--------- test/plugins/wasi_nn/CMakeLists.txt | 13 +++++++++++-- test/plugins/wasi_nn/wasi_nn.cpp | 10 ++++++---- 5 files changed, 48 insertions(+), 29 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index aca618eb..f72181b8 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: 2019-2024 Second State INC #include "wasinn_ggml.h" +#include "common/types.h" #include "wasinnenv.h" #include @@ -2645,8 +2646,8 @@ ErrNo codesToSpeech(Graph &GraphRef, Context &CxtRef) noexcept { Expect load(WasiNNEnvironment &Env, Span> Builders, [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { // Add a new graph. - uint32_t GId = Env.newGraph(Backend::GGML); - auto &GraphRef = Env.NNGraph[GId].get(); + EndianValue GId = Env.newGraph(Backend::GGML); + auto &GraphRef = Env.NNGraph[GId.raw()].get(); // Initialize the plugin parameters. GraphRef.EnableLog = false; @@ -2688,7 +2689,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Ignore context or model updates when initializing the graph. auto Res = parseMetadata(GraphRef, GraphRef.Conf, Metadata); if (Res != ErrNo::Success) { - Env.deleteGraph(GId); + Env.deleteGraph(GId.raw()); RET_ERROR(Res, "load: Failed to parse metadata."sv) } } @@ -2715,7 +2716,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, std::ofstream TempFile(GraphRef.Params.model.path, std::ios::out | std::ios::binary); if (!TempFile) { - Env.deleteGraph(GId); + Env.deleteGraph(GId.raw()); RET_ERROR(ErrNo::InvalidArgument, "load: Failed to create the temporary file. Currently, our "sv "workaround involves creating a temporary model file named "sv @@ -2732,7 +2733,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Check if the model exists. if (!std::filesystem::exists( std::filesystem::u8path(GraphRef.Params.model.path))) { - Env.deleteGraph(GId); + Env.deleteGraph(GId.raw()); RET_ERROR(ErrNo::ModelNotFound, "load: model file not found."sv) } GraphRef.Params.model = GraphRef.Params.model; @@ -2754,11 +2755,11 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.LlamaModel = std::move(LlamaInit.model); GraphRef.LlamaContext = std::move(LlamaInit.context); if (GraphRef.LlamaModel == nullptr) { - Env.deleteGraph(GId); + Env.deleteGraph(GId.raw()); RET_ERROR(ErrNo::InvalidArgument, "load: unable to init model."sv) } if (GraphRef.LlamaContext == nullptr) { - Env.deleteGraph(GId); + Env.deleteGraph(GId.raw()); RET_ERROR(ErrNo::InvalidArgument, "load: unable to init context."sv) } LOG_DEBUG(GraphRef.EnableDebugLog, @@ -2773,19 +2774,19 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.TTSModel = std::move(TTSInit.model); GraphRef.TTSContext = std::move(TTSInit.context); if (GraphRef.TTSModel == nullptr) { - Env.deleteGraph(GId); + Env.deleteGraph(GId.raw()); RET_ERROR(ErrNo::InvalidArgument, "load: unable to init TTS model."sv) } if (GraphRef.TTSContext == nullptr) { - Env.deleteGraph(GId); + Env.deleteGraph(GId.raw()); RET_ERROR(ErrNo::InvalidArgument, "load: unable to init TTS context."sv) } LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize TTS model...Done"sv) } // Store the loaded graph. - GraphId = GId; - Env.NNGraph[GId].setReady(); + GraphId = GId.le(); + Env.NNGraph[GId.raw()].setReady(); LOG_DEBUG(GraphRef.EnableDebugLog, "load...Done"sv) return ErrNo::Success; @@ -2812,6 +2813,7 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, common_sampler_init(GraphRef.LlamaModel.get(), GraphRef.Params.sampling); Env.NNContext[ContextId].setReady(); + ContextId = EndianValue(ContextId).le(); LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx...Done"sv) return ErrNo::Success; } @@ -3111,7 +3113,8 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, std::copy_n(CxtRef.LlamaOutputs.data(), CxtRef.LlamaOutputs.size(), OutBuffer.data()); - BytesWritten = static_cast(CxtRef.LlamaOutputs.size()); + BytesWritten = + EndianValue(static_cast(CxtRef.LlamaOutputs.size())).le(); LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}...Done"sv, Index) return ErrNo::Success; } @@ -3211,7 +3214,7 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, std::string LastToken = common_token_to_piece( GraphRef.LlamaContext.get(), CxtRef.LlamaOutputTokens.back()); std::copy_n(LastToken.data(), LastToken.length(), OutBuffer.data()); - BytesWritten = static_cast(LastToken.length()); + BytesWritten = EndianValue(static_cast(LastToken.length())).le(); LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: with Index {}...Done"sv, Index) return ErrNo::Success; diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 956b83bd..e4e0a39d 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -234,7 +234,7 @@ struct WasiNNEnvironment : bool mdGet(std::string Name, uint32_t &GraphId) noexcept { std::shared_lock Lock(MdMutex); if (auto It = MdMap.find(Name); It != MdMap.end()) { - GraphId = static_cast(It->second); + GraphId = EndianValue(static_cast(It->second)).le(); return true; } return false; diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 390045ca..64056c5c 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -103,8 +103,9 @@ WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, Builders.reserve(BuilderLen); for (size_t I = 0; I < WasiBuilders.size(); ++I) { const auto &WasiBuilder = WasiBuilders[I]; - auto Builder = MemInst->getSpan(WasiBuilder.Ptr, WasiBuilder.Len); - if (unlikely(Builder.size() != WasiBuilder.Len)) { + auto Builder = MemInst->getSpan(EndianValue(WasiBuilder.Ptr).le(), + EndianValue(WasiBuilder.Len).le()); + if (unlikely(Builder.size() != EndianValue(WasiBuilder.Len).le())) { spdlog::error("[WASI-NN] Failed when accessing the Builder[{}] memory."sv, I); return WASINN::ErrNo::InvalidArgument; @@ -308,20 +309,24 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ContextId, } WASINN::TensorData Tensor; - Tensor.Dimension = MemInst->getSpan(WasiTensor->DimensionPtr, - WasiTensor->DimensionLen); - if (unlikely(Tensor.Dimension.size() != WasiTensor->DimensionLen)) { + Tensor.Dimension = + MemInst->getSpan(EndianValue(WasiTensor->DimensionPtr).le(), + EndianValue(WasiTensor->DimensionLen).le()); + if (unlikely(Tensor.Dimension.size() != + EndianValue(WasiTensor->DimensionLen).le())) { spdlog::error("[WASI-NN] Failed when accessing the Dimension memory."sv); return WASINN::ErrNo::InvalidArgument; } Tensor.Tensor = - MemInst->getSpan(WasiTensor->TensorPtr, WasiTensor->TensorLen); - if (unlikely(Tensor.Tensor.size() != WasiTensor->TensorLen)) { + MemInst->getSpan(EndianValue(WasiTensor->TensorPtr).le(), + EndianValue(WasiTensor->TensorLen).le()); + if (unlikely(Tensor.Tensor.size() != + EndianValue(WasiTensor->TensorLen).le())) { spdlog::error("[WASI-NN] Failed when accessing the TensorData memory."sv); return WASINN::ErrNo::InvalidArgument; } - switch (const auto RType = - static_cast(WasiTensor->RType)) { + switch (const auto RType = static_cast( + EndianValue(WasiTensor->RType).le())) { case WASINN::TensorType::F16: case WASINN::TensorType::F32: case WASINN::TensorType::U8: diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 66f37d78..d39dfdb6 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -60,11 +60,20 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) ) elseif(BACKEND STREQUAL "ggml") message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures") - download( + if (CMAKE_SYSTEM_PROCESSOR MATCHES "s390x") + # Use a big endian model for s390x + download( + https://huggingface.co/taronaeo/Granite-3.0-1B-A400M-Instruct-BE-GGUF/resolve/main/granite-3.0-1b-a400m-instruct-be.Q2_K.gguf + ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/granite-3.gguf + MD5=2520fd8a702468942fdd1595b9dca9a2 + ) + else() + download( https://huggingface.co/TheBloke/orca_mini_v3_7B-GGUF/resolve/main/orca_mini_v3_7b.Q2_K.gguf ${CMAKE_CURRENT_BINARY_DIR}/wasinn_ggml_fixtures/orca_mini.gguf MD5=f895f00678bfbf89f70d6d25f20a7b5f - ) + ) + endif() if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") target_compile_options(wasiNNTests PUBLIC /wd4067 # unexpected tokens following preprocessor directive - expected a newline diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index df73f023..a24c0f35 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include using namespace std::literals; @@ -85,8 +86,7 @@ void writeBinaries(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, void writeUInt32(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, uint32_t Value, uint32_t &Ptr) { - uint32_t *BufPtr = MemInst.getPointer(Ptr); - *BufPtr = Value; + MemInst.storeValue(Value, Ptr); Ptr += 4; } @@ -1319,8 +1319,10 @@ TEST(WasiNNTest, GGMLBackend) { // Load the files. std::string Prompt = "Once upon a time, "; std::vector TensorData(Prompt.begin(), Prompt.end()); - std::vector WeightRead = - readEntireFile("./wasinn_ggml_fixtures/orca_mini.gguf"); + std::string Model = WasmEdge::Endian::native == WasmEdge::Endian::little + ? "./wasinn_ggml_fixtures/orca_mini.gguf" + : "./wasinn_ggml_fixtures/granite-3.gguf"; + std::vector WeightRead = readEntireFile(Model); std::vector TensorDim{1}; uint32_t BuilderPtr = UINT32_C(0); From 51202ff082ecc0cb3be4df8417f59cb676f8e28f Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 28 Aug 2025 17:53:02 +0800 Subject: [PATCH 577/623] feat(WASI-NN,ggml): bump llama.cpp b6301, support MiniCPM-V4.5 (#4336) Remove the deprecated DefragThold option. Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/wasinn_ggml.cpp | 10 ---------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 95e3fc5d..f835b292 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -24,7 +24,7 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) -set(WASMEDGE_WASI_NN_VERSION "0.1.30" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_VERSION "0.1.31" CACHE STRING "WasmEdge WASI-NN library version") set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") # Handle the version of the WASI-NN plugin diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index f72181b8..30de8cf9 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -442,16 +442,6 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, } GraphRef.Params.yarn_orig_ctx = static_cast(YarnOrigCtx); } - if (Doc.at_key("defrag-thold").error() == simdjson::SUCCESS) { - double DefragThold; - auto Err = Doc["defrag-thold"].get().get(DefragThold); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the defrag-thold option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.defrag_thold = static_cast(DefragThold); - } if (Doc.at_key("mask-valid").error() == simdjson::SUCCESS) { auto Err = Doc["mask-valid"].get().get(GraphRef.Params.cpuparams.mask_valid); From 40133743ac1ef9df360579d519edab0554d758e8 Mon Sep 17 00:00:00 2001 From: hydai Date: Tue, 2 Sep 2025 14:08:24 +0800 Subject: [PATCH 578/623] feat(WASI-NN,ggml): bump llama.cpp b6343 (#4343) The type of flash attention is changed from bool to string. Previous, true => on, false => off. Currently, it's a mapping relationship: * "on", "enabled" => LLAMA_FLASH_ATTN_TYPE_ENABLED * "off", "disabled" => LLAMA_FLASH_ATTN_TYPE_DISABLED * "auto" => LLAMA_FLASH_ATTN_TYPE_AUTO Signed-off-by: hydai --- plugins/wasi_nn/wasinn_ggml.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 30de8cf9..7483377f 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -1276,11 +1276,22 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, } } if (Doc.at_key("flash-attn").error() == simdjson::SUCCESS) { - auto Err = Doc["flash-attn"].get().get(GraphRef.Params.flash_attn); + std::string_view FlashAttn; + auto Err = Doc["flash-attn"].get().get(FlashAttn); if (Err) { RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the flash-attn option."sv) } + if (FlashAttn == "on"sv || FlashAttn == "enabled"sv) { + GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; + } else if (FlashAttn == "off"sv || FlashAttn == "disabled"sv) { + GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; + } else if (FlashAttn == "auto"sv) { + GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; + } else { + RET_ERROR(ErrNo::InvalidArgument, + "The flash-attn option must be one of: on, off, auto."sv) + } } if (Doc.at_key("no-perf").error() == simdjson::SUCCESS) { auto Err = Doc["no-perf"].get().get(GraphRef.Params.no_perf); From 413b94620b7cb98b68bbabc4895eae5d66d002ae Mon Sep 17 00:00:00 2001 From: hydai Date: Mon, 8 Sep 2025 16:03:06 +0800 Subject: [PATCH 579/623] feat(WASI-NN,ggml): bump llama.cpp b6399, fixed embedding issue (#4350) Signed-off-by: hydai --- plugins/wasi_nn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index f835b292..14724c5a 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -24,7 +24,7 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) -set(WASMEDGE_WASI_NN_VERSION "0.1.31" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_VERSION "0.1.32" CACHE STRING "WasmEdge WASI-NN library version") set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") # Handle the version of the WASI-NN plugin From d8ba9ae9be6c884713ae746bdd07a9928c255fcb Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Tue, 9 Sep 2025 16:39:27 +0800 Subject: [PATCH 580/623] fix(WASI-NN/ChatTTS): fix chatTTS CI failed (#4349) Signed-off-by: grorge --- utils/wasi-nn/install-chattts.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/wasi-nn/install-chattts.sh b/utils/wasi-nn/install-chattts.sh index f406d460..6361cae6 100644 --- a/utils/wasi-nn/install-chattts.sh +++ b/utils/wasi-nn/install-chattts.sh @@ -14,7 +14,7 @@ pip install --upgrade pip # Install PyTorch CPU version to save space pip --python /usr/bin/python3 install --break-system-packages --index-url https://download.pytorch.org/whl/cpu 'torch<=2.6.0' 'torchaudio<=2.6.0' -pip --python /usr/bin/python3 install --break-system-packages chattts==0.2.3 +pip --python /usr/bin/python3 install --break-system-packages chattts==0.2.4, transformers==4.46.3 # Remove wheel cache pip --python /usr/bin/python3 cache purge From 0cb85d8313167356cc6ad4737ccfdaf6622bf704 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Fri, 12 Sep 2025 14:39:39 +0800 Subject: [PATCH 581/623] feat(wasi-nn,MLX): support whisper for MLX backend (#4322) Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 4 + plugins/wasi_nn/MLX/mlx/activations.cpp | 7 +- plugins/wasi_nn/MLX/mlx/base.cpp | 2 +- plugins/wasi_nn/MLX/mlx/convolution.cpp | 9 + plugins/wasi_nn/MLX/mlx/convolution.h | 22 +- plugins/wasi_nn/MLX/model/utils.cpp | 17 +- plugins/wasi_nn/MLX/model/utils.h | 51 + .../wasi_nn/MLX/model/whisper/decoding.cpp | 883 ++++++++++++++++++ plugins/wasi_nn/MLX/model/whisper/decoding.h | 196 ++++ .../wasi_nn/MLX/model/whisper/tokenizer.cpp | 783 ++++++++++++++++ plugins/wasi_nn/MLX/model/whisper/tokenizer.h | 121 +++ plugins/wasi_nn/MLX/model/whisper/whisper.cpp | 461 +++++++++ plugins/wasi_nn/MLX/model/whisper/whisper.h | 130 +++ .../wasi_nn/MLX/model/whisper_transcribe.cpp | 819 ++++++++++++++++ .../wasi_nn/MLX/model/whisper_transcribe.h | 124 +++ plugins/wasi_nn/wasinn_mlx.cpp | 54 +- plugins/wasi_nn/wasinn_mlx.h | 8 +- 17 files changed, 3668 insertions(+), 23 deletions(-) create mode 100644 plugins/wasi_nn/MLX/model/whisper/decoding.cpp create mode 100644 plugins/wasi_nn/MLX/model/whisper/decoding.h create mode 100644 plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp create mode 100644 plugins/wasi_nn/MLX/model/whisper/tokenizer.h create mode 100644 plugins/wasi_nn/MLX/model/whisper/whisper.cpp create mode 100644 plugins/wasi_nn/MLX/model/whisper/whisper.h create mode 100644 plugins/wasi_nn/MLX/model/whisper_transcribe.cpp create mode 100644 plugins/wasi_nn/MLX/model/whisper_transcribe.h diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 14724c5a..faba2ce1 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -56,6 +56,10 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) MLX/model/utils.cpp MLX/model/vlm_base.cpp MLX/model/vlm_sampling.cpp + MLX/model/whisper/whisper.cpp + MLX/model/whisper/tokenizer.cpp + MLX/model/whisper/decoding.cpp + MLX/model/whisper_transcribe.cpp MLX/mlx/base.cpp MLX/mlx/linear.cpp MLX/mlx/convolution.cpp diff --git a/plugins/wasi_nn/MLX/mlx/activations.cpp b/plugins/wasi_nn/MLX/mlx/activations.cpp index f94695e2..137a9e35 100644 --- a/plugins/wasi_nn/MLX/mlx/activations.cpp +++ b/plugins/wasi_nn/MLX/mlx/activations.cpp @@ -11,7 +11,12 @@ namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core { mx::array gelu(mx::array X) { - return X * (1 + mx::erf(X / std::sqrt(2.0))) / 2.0; + // auto Result = X * (1 + mx::erf(X / std::sqrt(2))) / 2; + auto Result = X * + (mx::array({1}, X.dtype()) + + mx::erf(X / mx::array({std::sqrt(2)}, X.dtype()))) / + mx::array({2}, X.dtype()); + return Result; } mx::array silu(mx::array X) { return X * mx::sigmoid(X); } diff --git a/plugins/wasi_nn/MLX/mlx/base.cpp b/plugins/wasi_nn/MLX/mlx/base.cpp index 4bff15f9..4535f003 100644 --- a/plugins/wasi_nn/MLX/mlx/base.cpp +++ b/plugins/wasi_nn/MLX/mlx/base.cpp @@ -54,7 +54,7 @@ void Module::apply(std::string Key, mx::array Value) { } else { std::string LayerName = SplitKey[0]; SplitKey.erase(SplitKey.begin()); - if (LayerName == "layers") { + if (LayerName == "layers" || LayerName == "blocks") { LayerName += "." + SplitKey[0]; SplitKey.erase(SplitKey.begin()); } diff --git a/plugins/wasi_nn/MLX/mlx/convolution.cpp b/plugins/wasi_nn/MLX/mlx/convolution.cpp index c7947707..4e38a69b 100644 --- a/plugins/wasi_nn/MLX/mlx/convolution.cpp +++ b/plugins/wasi_nn/MLX/mlx/convolution.cpp @@ -7,6 +7,15 @@ namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { +mx::array Conv1d::forward(mx::array Input) { + auto Y = mx::conv1d(Input, Parameters.at("weight"), Stride, Padding, Dilation, + Groups); + if (Parameters.find("bias") != Parameters.end()) { + Y = Y + Parameters.at("bias"); + } + return Y; +} + mx::array Conv2d::forward(mx::array Input) { auto Y = mx::conv2d(Input, Parameters.at("weight"), Stride, Padding, Dilation, Groups); diff --git a/plugins/wasi_nn/MLX/mlx/convolution.h b/plugins/wasi_nn/MLX/mlx/convolution.h index 9db61de1..9e824dcb 100644 --- a/plugins/wasi_nn/MLX/mlx/convolution.h +++ b/plugins/wasi_nn/MLX/mlx/convolution.h @@ -8,6 +8,27 @@ namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { +class Conv1d : public nn::Module { + int Padding; + int Stride; + int Dilation; + int Groups; + +public: + Conv1d(int InChannels, int OutChannels, int KernelSize, int Stride = 1, + int Padding = 0, int Dilation = 1, int Groups = 1, bool Bias = true) + : Padding(Padding), Stride(Stride), Dilation(Dilation), Groups(Groups) { + double Scale = std::sqrt(1.0 / (InChannels * KernelSize)); + registerParameter( + "weight", mx::random::uniform(-Scale, Scale, + {OutChannels, InChannels, KernelSize})); + if (Bias) { + registerParameter("bias", mx::zeros({OutChannels})); + } + } + mx::array forward(mx::array Input); +}; + class Conv2d : public nn::Module { std::pair Padding; std::pair Stride; @@ -38,5 +59,4 @@ class Conv2d : public nn::Module { }; } // namespace mlx::core::nn - } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/utils.cpp b/plugins/wasi_nn/MLX/model/utils.cpp index 0674cf17..3fe41c57 100644 --- a/plugins/wasi_nn/MLX/model/utils.cpp +++ b/plugins/wasi_nn/MLX/model/utils.cpp @@ -2,7 +2,7 @@ // SPDX-FileCopyrightText: 2019-2024 Second State INC #include "model/utils.h" - +#include #include namespace WasmEdge::Host::WASINN::MLX { @@ -59,4 +59,19 @@ void saveWeights(const mx::array &Weights, const std::string &Path) { } } +std::string loadBytesFromFile(const std::string &Path) { + std::ifstream Fs(Path, std::ios::in | std::ios::binary); + if (Fs.fail()) { + spdlog::error("[WASI-NN] MLX backend: Cannot open {}."sv, Path); + return ""; + } + std::string Data; + Fs.seekg(0, std::ios::end); + const size_t Size = static_cast(Fs.tellg()); + Fs.seekg(0, std::ios::beg); + Data.resize(Size); + Fs.read(Data.data(), Size); + return Data; +} + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/utils.h b/plugins/wasi_nn/MLX/model/utils.h index 9d8118ca..288b734d 100644 --- a/plugins/wasi_nn/MLX/model/utils.h +++ b/plugins/wasi_nn/MLX/model/utils.h @@ -24,4 +24,55 @@ void saveWeights(const std::unordered_map &Weights, void saveWeights(const mx::array &Weights, const std::string &Path); +std::string loadBytesFromFile(const std::string &Path); + +void fillPlaceholders(std::ostringstream &Oss, const std::string &Fmt, + size_t &Pos); + +template std::string toString(const T &Value) { + std::ostringstream Oss; + Oss << Value; + return Oss.str(); +} + +template std::string toString(const std::vector &Vec) { + std::ostringstream Oss; + Oss << "["; + for (size_t I = 0; I < Vec.size(); I++) { + Oss << toString(Vec[I]); + if (I + 1 < Vec.size()) { + Oss << ", "; + } + } + Oss << "]"; + return Oss.str(); +} + +template +void fillPlaceholders(std::ostringstream &Oss, const std::string &Fmt, + size_t &Pos, T &&Value, Args &&...args) { + auto PlaceholderPos = Fmt.find("{}", Pos); + if (PlaceholderPos == std::string::npos) { + Oss << Fmt.substr(Pos); + return; + } + Oss << Fmt.substr(Pos, PlaceholderPos - Pos); + Oss << toString(Value); + Pos = PlaceholderPos + 2; + fillPlaceholders(Oss, Fmt, Pos, std::forward(args)...); +} + +template +std::string formatStr(const std::string &Fmt, Args &&...args) { + std::ostringstream Oss; + size_t Pos = 0; + fillPlaceholders(Oss, Fmt, Pos, std::forward(args)...); + return Oss.str(); +} + +template void debug(const std::string &fmt, Args &&...args) { + std::cout << "[DEBUG] " << formatStr(fmt, std::forward(args)...) + << std::endl; +} + } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper/decoding.cpp b/plugins/wasi_nn/MLX/model/whisper/decoding.cpp new file mode 100644 index 00000000..4ad093e5 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper/decoding.cpp @@ -0,0 +1,883 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#include "decoding.h" +#include "mlx/base.h" +#include "tokenizer.h" +#include "whisper.h" +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace whisper { + +// Utility function implementation +float compressionRatio(const std::string &Text) { + if (Text.empty()) + return 1.0f; + + std::vector Input(Text.begin(), Text.end()); + uLongf CompressedSize = compressBound(Input.size()); + std::vector Compressed(CompressedSize); + + int Result = + compress(Compressed.data(), &CompressedSize, Input.data(), Input.size()); + + if (Result != Z_OK) + return 1.0f; + + return static_cast(Input.size()) / static_cast(CompressedSize); +} + +// Language detection implementation +std::pair>> +detectLanguage(std::shared_ptr Model, const mx::array &Mel, + std::shared_ptr Tokenizer) { + + if (!Tokenizer) { + Tokenizer = getTokenizer(Model->isMultilingual(), Model->numLanguages()); + } + + if (!Tokenizer->Language || + std::find(Tokenizer->SotSequence.begin(), Tokenizer->SotSequence.end(), + Tokenizer->languageToken()) == Tokenizer->SotSequence.end()) { + throw std::runtime_error( + "This model doesn't have language tokens so it can't perform lang id"); + } + bool Single = Mel.ndim() == 2; + mx::array MelArray = Single ? mx::expand_dims(Mel, 0) : Mel; + // Skip encoder forward pass if already-encoded audio features were given + if (MelArray.shape(-2) != Model->Dims.NAudioCtx || + MelArray.shape(-1) != Model->Dims.NAudioState) { + MelArray = Model->embedAudio(MelArray); + } + // Forward pass using a single token, start of transcript + int NAudio = MelArray.shape(0); + std::vector> TokensVec(NAudio, {Tokenizer->getSot()}); + mx::array Tokens = mx::array(TokensVec[0].data(), {NAudio, 1}, mx::int32); + + mx::array Logits = Model->logits(Tokens, MelArray); + Logits = mx::take(Logits, mx::array({0}), 1); // [:, 0] + // Collect detected languages; suppress all non-language tokens + mx::array MaskArray = mx::full( + {Logits.shape(-1)}, -std::numeric_limits::infinity(), mx::float32); + auto LangTokens = Tokenizer->getAllLanguageTokens(); + mx::array LangTokensArray = mx::array( + LangTokens.data(), {static_cast(LangTokens.size())}, mx::int32); + MaskArray = + mx::scatter(MaskArray, LangTokensArray, + mx::zeros({static_cast(LangTokens.size()), 1}), 0); + + Logits = Logits + MaskArray; + mx::array LanguageTokens = mx::argmax(Logits, -1); + mx::array LanguageTokenProbs = mx::softmax(Logits, -1); + LanguageTokenProbs = mx::take(LanguageTokenProbs, 0, 0); + + std::vector> LanguageProbs; + auto LangCodes = Tokenizer->getAllLanguageCodes(); + + for (int I = 0; I < NAudio; ++I) { + std::map Probs; + for (size_t J = 0; J < LangTokens.size() && J < LangCodes.size(); ++J) { + int TokenId = LangTokens[J]; + std::string LangCode = LangCodes[J]; + + mx::array ProbArray = + mx::take(mx::take(LanguageTokenProbs, mx::array({0}), 0), + mx::array({TokenId}), -1); + ProbArray = mx::squeeze(ProbArray); + + float Prob = ProbArray.item(); + Probs[LangCode] = Prob; + } + LanguageProbs.push_back(Probs); + } + + if (Single) { + LanguageTokens = mx::take(LanguageTokens, mx::array({0}), 0); + LanguageProbs = {LanguageProbs[0]}; + } + + return {LanguageTokens, LanguageProbs}; +} + +// Inference class implementation +Inference::Inference(std::shared_ptr Model) : Model(Model) { reset(); } + +mx::array Inference::logits(const mx::array &Tokens, + const mx::array &AudioFeatures) { + auto [LogitsOutput, NewKvCache, _] = + std::dynamic_pointer_cast(Model->Submodules.at("decoder")) + ->forward(Tokens, AudioFeatures, KvCache); + KvCache = NewKvCache; + return mx::astype(LogitsOutput, mx::float32); +} + +void Inference::rearrangeKvCache(const std::vector &SourceIndices) { + // TODO: Implement KV cache rearrangement for beam search + assumingUnreachable(); +} + +void Inference::reset() { KvCache = std::nullopt; } + +// GreedyDecoder implementation +GreedyDecoder::GreedyDecoder(float Temperature, int Eot) + : Temperature(Temperature), Eot(Eot) {} + +void GreedyDecoder::reset() { + // Nothing to reset for greedy decoder +} + +std::tuple +GreedyDecoder::update(const mx::array &Tokens, const mx::array &Logits, + const mx::array &SumLogprobs) { + + // Sample next tokens + mx::array NextTokens = mx::array({}); + if (Temperature == 0.0f) { + NextTokens = mx::argmax(Logits, -1); + } else { + NextTokens = mx::random::categorical(Logits / Temperature); + } + + // Compute logprobs + mx::array Logprobs = Logits - mx::logsumexp(Logits, -1, false); + mx::array CurrentLogprobs = + mx::take(Logprobs, mx::arange(Logprobs.shape(0)), 0); + CurrentLogprobs = take(CurrentLogprobs, NextTokens, 1); + mx::array NewSumLogprobs = + SumLogprobs + + CurrentLogprobs * (take(Tokens, Tokens.shape(1) - 1, 1) != Eot); + + // Extend tokens + mx::array NewTokens = + mx::concatenate({Tokens, mx::expand_dims(NextTokens, -1)}, -1); + + // Check if all sequences are complete + bool Completed = mx::all(mx::equal(mx::take(NewTokens, mx::array({-1}), 1), + mx::array({Eot}))) + .item(); + + return {NewTokens, Completed, NewSumLogprobs}; +} + +std::pair +GreedyDecoder::finalize(const mx::array &Tokens, const mx::array &SumLogprobs) { + + // Make sure each sequence has at least one EOT token at the end + std::vector> PadWidths = {{0, 0}, {0, 0}, {0, 1}}; + mx::array PaddedTokens = + mx::pad(Tokens, PadWidths, mx::array(Eot, mx::int32)); + + return {PaddedTokens, SumLogprobs}; +} + +// SuppressBlank implementation +SuppressBlank::SuppressBlank(std::shared_ptr Tokenizer, + int SampleBegin, int NVocab) + : SampleBegin(SampleBegin), Mask(mx::zeros({NVocab}, mx::float32)) { + Name = "SuppressBlank"; + std::vector MaskVec(NVocab, 0.0f); + + // Suppress space and EOT tokens + auto SpaceTokens = Tokenizer->encode(" "); + for (int Token : SpaceTokens) { + if (Token >= 0 && Token < NVocab) { + MaskVec[Token] = -std::numeric_limits::infinity(); + } + } + + int EotToken = Tokenizer->getEot(); + if (EotToken < NVocab) { + MaskVec[EotToken] = -std::numeric_limits::infinity(); + } + + Mask = mx::array(MaskVec.data(), {NVocab}, mx::float32); +} + +mx::array SuppressBlank::apply(const mx::array &Logits, + const mx::array &Tokens) { + if (Tokens.shape(1) == SampleBegin) { + return Logits + Mask; + } + return Logits; +} + +// SuppressTokens implementation +SuppressTokens::SuppressTokens(const std::vector &SuppressTokens, + int NVocab) + : Mask(mx::zeros({NVocab}, mx::float32)) { + Name = "SuppressTokens"; + std::vector MaskVec(NVocab, 0.0f); + for (int Token : SuppressTokens) { + MaskVec[Token] = -std::numeric_limits::infinity(); + } + + Mask = mx::array(MaskVec.data(), {NVocab}, mx::float32); +} + +mx::array SuppressTokens::apply(const mx::array &Logits, + const mx::array &Tokens) { + return Logits + Mask; +} + +// ApplyTimestampRules implementation +ApplyTimestampRules::ApplyTimestampRules( + std::shared_ptr Tokenizer, int SampleBegin, + std::optional MaxInitialTimestampIndex) + : Tokenizer(Tokenizer), SampleBegin(SampleBegin), + MaxInitialTimestampIndex(MaxInitialTimestampIndex) { + Name = "ApplyTimestampRules"; +} + +mx::array ApplyTimestampRules::apply(const mx::array &Logits, + const mx::array &Tokens) { + auto LogitsShape = Logits.shape(); + std::vector MaskVec(Logits.size(), 0.0f); + + // Suppress <|notimestamps|> which is handled by without_timestamps + if (Tokenizer->getNoTimestamps()) { + for (int I = 0; I < LogitsShape[0]; ++I) { + MaskVec[I * LogitsShape[1] + Tokenizer->getNoTimestamps()] = + -std::numeric_limits::infinity(); + } + } + + mx::eval(Tokens); + std::vector> TokensList(Tokens.shape(0)); + for (int K = 0; K < Tokens.shape(0); ++K) { + TokensList[K].resize(Tokens.shape(1)); + for (int J = 0; J < Tokens.shape(1); ++J) { + mx::array TokenVal = mx::take(mx::take(Tokens, K, 0), J, 0); + TokensList[K][J] = TokenVal.item(); + } + } + + // Timestamps have to appear in pairs, except directly before EOT; mask logits + // accordingly + for (int K = 0; K < static_cast(TokensList.size()); ++K) { + // seq = tokens[k][self.sample_begin :] + std::vector Seq(TokensList[K].begin() + SampleBegin, + TokensList[K].end()); + + bool LastWasTimestamp = + Seq.size() >= 1 && + Seq[Seq.size() - 1] >= Tokenizer->getTimestampBegin(); + bool PenultimateWasTimestamp = + Seq.size() < 2 || Seq[Seq.size() - 2] >= Tokenizer->getTimestampBegin(); + + if (LastWasTimestamp) { + if (PenultimateWasTimestamp) { + // Has to be non-timestamp + for (int I = Tokenizer->getTimestampBegin(); I < LogitsShape[1]; ++I) { + MaskVec[K * LogitsShape[1] + I] = + -std::numeric_limits::infinity(); + } + } else { + // Cannot be normal text tokens + for (int I = 0; I < Tokenizer->getEot(); ++I) { + MaskVec[K * LogitsShape[1] + I] = + -std::numeric_limits::infinity(); + } + } + } + + // Find timestamps in sequence and enforce monotonicity + std::vector Timestamps; + for (size_t I = 0; I < Seq.size(); ++I) { + if (Seq[I] > Tokenizer->getTimestampBegin()) { + Timestamps.push_back(Seq[I]); + } + } + + if (!Timestamps.empty()) { + // Timestamps shouldn't decrease; forbid timestamp tokens smaller than the + // last Also force each segment to have a nonzero length, to prevent + // infinite looping + int LastTimestamp = Timestamps.back(); + if (LastTimestamp == 0 || PenultimateWasTimestamp) { + LastTimestamp += 1; + } + for (int I = Tokenizer->getTimestampBegin(); I < LastTimestamp; ++I) { + MaskVec[K * LogitsShape[1] + I] = + -std::numeric_limits::infinity(); + } + } + } + + if (static_cast(TokensList[0].size()) == SampleBegin) { + // Suppress generating non-timestamp tokens at the beginning + for (int I = 0; I < LogitsShape[0]; ++I) { + for (int J = 0; J < Tokenizer->getTimestampBegin(); ++J) { + MaskVec[I * LogitsShape[1] + J] = + -std::numeric_limits::infinity(); + } + } + + // Apply the `max_initial_timestamp` option + if (MaxInitialTimestampIndex) { + int LastAllowed = + Tokenizer->getTimestampBegin() + *MaxInitialTimestampIndex; + for (int I = 0; I < LogitsShape[0]; ++I) { + for (int J = LastAllowed + 1; J < LogitsShape[1]; ++J) { + MaskVec[I * LogitsShape[1] + J] = + -std::numeric_limits::infinity(); + } + } + } + } + + // If sum of probability over timestamps is above any other token, sample + // timestamp + mx::array MaskArray = mx::array(MaskVec.data(), LogitsShape, mx::float32); + mx::array Logprobs = Logits - mx::logsumexp(Logits, -1, true); + + // Calculate timestamp logprob: sum of probabilities for all timestamp tokens + mx::array TimestampLogprob = + mx::logsumexp(mx::slice(Logprobs, {0, Tokenizer->getTimestampBegin()}, + {LogitsShape[0], LogitsShape[1]}), + -1, true); + + // Calculate max text token logprob: max probability among non-timestamp + // tokens + mx::array MaxTextTokenLogprob = + mx::max(mx::slice(Logprobs, {0, 0}, + {LogitsShape[0], Tokenizer->getTimestampBegin()}), + -1, true); + + // Where timestamp probability > max text probability, suppress text tokens + mx::array TimestampCondition = + mx::greater(TimestampLogprob, MaxTextTokenLogprob); + + for (int I = 0; I < LogitsShape[0]; ++I) { + bool ShouldSuppressText = mx::take(TimestampCondition, I, 0).item(); + if (ShouldSuppressText) { + for (int J = 0; J < Tokenizer->getTimestampBegin(); ++J) { + MaskVec[I * LogitsShape[1] + J] = + -std::numeric_limits::infinity(); + } + } + } + + MaskArray = mx::array(MaskVec.data(), LogitsShape, mx::float32); + return Logits + MaskArray; +} + +// MaximumLikelihoodRanker implementation +MaximumLikelihoodRanker::MaximumLikelihoodRanker( + std::optional LengthPenalty) + : LengthPenalty(LengthPenalty) {} + +std::vector MaximumLikelihoodRanker::rank( + const std::vector>> &Tokens, + const std::vector> &SumLogprobs) { + + std::vector Selected; + + for (size_t I = 0; I < Tokens.size(); ++I) { + std::vector Scores; + + for (size_t J = 0; J < Tokens[I].size(); ++J) { + int Length = Tokens[I][J].size(); + float Logprob = SumLogprobs[I][J]; + + float Penalty; + if (LengthPenalty) { + Penalty = std::pow(Length, *LengthPenalty); + } else { + Penalty = Length; + } + + Scores.push_back(Logprob / Penalty); + } + + auto MaxIterator = std::max_element(Scores.begin(), Scores.end()); + Selected.push_back(std::distance(Scores.begin(), MaxIterator)); + } + + return Selected; +} + +// DecodingTask implementation - Constructor and helper methods +DecodingTask::DecodingTask(std::shared_ptr Model, + const DecodingOptions &Options) + : Model(Model), Options(verifyOptions(Options)) { + + std::string Language = Options.Language.value_or("en"); + Tokenizer = whisper::getTokenizer( + Model->isMultilingual(), Model->numLanguages(), Language, Options.Task); + + NGroup = Options.BeamSize.value_or(Options.BestOf.value_or(1)); + NCtx = Model->Dims.NTextCtx; + SampleLen = Options.SampleLen.value_or(NCtx / 2); + + // Handle SOT sequence with without_timestamps logic + SotSequence = Tokenizer->SotSequence; + if (Options.WithoutTimestamps) { + SotSequence = Tokenizer->getSotSequenceIncludingNotimestamps(); + } + + InitialTokens = getInitialTokens(); + SampleBegin = InitialTokens.size(); + + // Find SOT index + auto SotToken = Tokenizer->getSot(); + auto Iterator = + std::find(InitialTokens.begin(), InitialTokens.end(), SotToken); + SotIndex = std::distance(InitialTokens.begin(), Iterator); + + // Initialize components + Inference = std::make_unique(Model); + SequenceRanker = + std::make_unique(Options.LengthPenalty); + + if (Options.BeamSize && *Options.BeamSize > 1) { + throw std::runtime_error("Beam search decoder is not yet implemented"); + } + Decoder = + std::make_unique(Options.Temperature, Tokenizer->getEot()); + + LogitFilters.clear(); + + if (Options.SuppressBlank) { + LogitFilters.push_back(std::make_unique( + Tokenizer, SampleBegin, Model->Dims.NVocab)); + } + + if (Options.SuppressTokens) { + auto SuppressTokens = getSuppressTokens(); + LogitFilters.push_back(std::make_unique( + SuppressTokens, Model->Dims.NVocab)); + } + + if (!Options.WithoutTimestamps) { + std::optional MaxInitialTimestampIndex; + if (Options.MaxInitialTimestamp) { + float Precision = + 30.0f / Model->Dims.NAudioCtx; // CHUNK_LENGTH / n_audio_ctx + MaxInitialTimestampIndex = static_cast( + std::round(*Options.MaxInitialTimestamp / Precision)); + } + LogitFilters.push_back(std::make_unique( + Tokenizer, SampleBegin, MaxInitialTimestampIndex)); + } +} + +DecodingOptions DecodingTask::verifyOptions(const DecodingOptions &Options) { + DecodingOptions Result = Options; + + // Check beam_size and best_of conflicts + if (Result.BeamSize && Result.BestOf) { + throw std::runtime_error("beam_size and best_of can't be given together"); + } + + // Check temperature = 0 with best_of + if (Result.Temperature == 0.0f && Result.BestOf) { + throw std::runtime_error( + "best_of with greedy sampling (T=0) is not compatible"); + } + + // Check patience requires beam_size + if (Result.Patience && !Result.BeamSize) { + throw std::runtime_error("patience requires beam_size to be given"); + } + + // Check length_penalty range + if (Result.LengthPenalty && + (*Result.LengthPenalty < 0.0f || *Result.LengthPenalty > 1.0f)) { + throw std::runtime_error( + "length_penalty (alpha) should be a value between 0 and 1"); + } + + return Result; +} + +std::vector DecodingTask::getInitialTokens() { + std::vector Tokens = SotSequence; + + if (Options.Prefix) { + std::vector PrefixTokens; + if (std::holds_alternative(*Options.Prefix)) { + std::string PrefixStr = std::get(*Options.Prefix); + PrefixTokens = Tokenizer->encode(" " + PrefixStr); + } else { + PrefixTokens = std::get>(*Options.Prefix); + } + + if (SampleLen > 0) { + int MaxPrefixLen = NCtx / 2 - SampleLen; + if (static_cast(PrefixTokens.size()) > MaxPrefixLen) { + PrefixTokens = std::vector(PrefixTokens.end() - MaxPrefixLen, + PrefixTokens.end()); + } + } + + Tokens.insert(Tokens.end(), PrefixTokens.begin(), PrefixTokens.end()); + } + + if (Options.Prompt) { + std::vector PromptTokens; + if (std::holds_alternative(*Options.Prompt)) { + std::string PromptStr = std::get(*Options.Prompt); + PromptTokens = Tokenizer->encode(" " + PromptStr); + } else { + PromptTokens = std::get>(*Options.Prompt); + } + + int MaxPromptLen = NCtx / 2 - 1; + if (static_cast(PromptTokens.size()) > MaxPromptLen) { + PromptTokens = std::vector(PromptTokens.end() - MaxPromptLen, + PromptTokens.end()); + } + + std::vector NewTokens; + NewTokens.push_back(Tokenizer->getSotPrev()); + NewTokens.insert(NewTokens.end(), PromptTokens.begin(), PromptTokens.end()); + NewTokens.insert(NewTokens.end(), Tokens.begin(), Tokens.end()); + Tokens = NewTokens; + } + + return Tokens; +} + +std::vector DecodingTask::getSuppressTokens() { + std::vector SuppressTokens; + + if (Options.SuppressTokens) { + if (std::holds_alternative(*Options.SuppressTokens)) { + std::string TokensString = std::get(*Options.SuppressTokens); + std::istringstream Iss(TokensString); + std::string TokenString; + while (std::getline(Iss, TokenString, ',')) { + if (!TokenString.empty()) { + SuppressTokens.push_back(std::stoi(TokenString)); + } + } + } else { + SuppressTokens = std::get>(*Options.SuppressTokens); + } + } + + auto Iterator = std::find(SuppressTokens.begin(), SuppressTokens.end(), -1); + if (Iterator != SuppressTokens.end()) { + SuppressTokens.erase(std::remove_if(SuppressTokens.begin(), + SuppressTokens.end(), + [](int Token) { return Token < 0; }), + SuppressTokens.end()); + auto NonSpeechTokens = Tokenizer->getNonSpeechTokens(); + SuppressTokens.insert(SuppressTokens.end(), NonSpeechTokens.begin(), + NonSpeechTokens.end()); + } else if (!Options.SuppressTokens || SuppressTokens.empty()) { + SuppressTokens.clear(); + } else { + assumingUnreachable(); + } + + SuppressTokens.push_back(Tokenizer->getTranscribe()); + SuppressTokens.push_back(Tokenizer->getTranslate()); + SuppressTokens.push_back(Tokenizer->getSot()); + SuppressTokens.push_back(Tokenizer->getSotPrev()); + SuppressTokens.push_back(Tokenizer->getSotLm()); + + SuppressTokens.push_back(Tokenizer->getNoSpeech()); + + std::sort(SuppressTokens.begin(), SuppressTokens.end()); + SuppressTokens.erase( + std::unique(SuppressTokens.begin(), SuppressTokens.end()), + SuppressTokens.end()); + + return SuppressTokens; +} + +mx::array DecodingTask::getAudioFeatures(const mx::array &Mel) { + bool Single = Mel.ndim() == 2; + mx::array MelArray = Single ? mx::expand_dims(Mel, 0) : Mel; + + mx::array AudioFeatures = MelArray; + + // Skip encoder forward pass if already-encoded audio features were given + if (AudioFeatures.shape(-2) != Model->Dims.NAudioCtx || + AudioFeatures.shape(-1) != Model->Dims.NAudioState) { + AudioFeatures = Model->embedAudio(AudioFeatures); + } + + return AudioFeatures; +} + +std::pair, + std::optional>>> +DecodingTask::detectLanguage(const mx::array &AudioFeatures, + mx::array &Tokens) { + + std::vector Languages(AudioFeatures.shape(0), + Options.Language.value_or("en")); + std::optional>> LangProbs; + + if (!Options.Language || Options.Task == "lang_id") { + // Call the global detectLanguage function + auto [DetectedLanguageTokens, Probabilities] = + whisper::detectLanguage(Model, AudioFeatures, Tokenizer); + + Languages.clear(); + for (const auto &ProbsMap : Probabilities) { + auto MaxIterator = std::max_element( + ProbsMap.begin(), ProbsMap.end(), + [](const auto &A, const auto &B) { return A.second < B.second; }); + Languages.push_back(MaxIterator->first); + } + + LangProbs = Probabilities; + + if (!Options.Language) { + std::vector Start = {0, SotIndex + 1}; + std::vector End = {Tokens.shape()[0], SotIndex + 2}; + + Tokens = slice_update(Tokens, DetectedLanguageTokens, Start, End); + } + } + + return {Languages, LangProbs}; +} + +std::tuple +DecodingTask::mainLoop(const mx::array &AudioFeatures, + const mx::array &Tokens) { + + int NBatch = Tokens.shape(0); + mx::array CurrentTokens = Tokens; + mx::array SumLogprobs = mx::zeros({NBatch}, mx::float32); + bool Completed = false; + + auto StepFunction = [&](const mx::array &Inputs, const mx::array &AudioFeats, + const mx::array &TokSeq, const mx::array &SumLogp) + -> std::tuple { + mx::array PreLogits = Inference->logits(Inputs, AudioFeats); + mx::array Logits = take(PreLogits, PreLogits.shape(1) - 1, 1); + for (const auto &Filter : LogitFilters) { + Logits = Filter->apply(Logits, TokSeq); + } + auto [NextTokens, CompletedFlag, NextSumLogprobs] = + Decoder->update(TokSeq, Logits, SumLogp); + return std::make_tuple(NextTokens, CompletedFlag, NextSumLogprobs, + PreLogits); + }; + + auto [NextTokens, CompletedFlag, NextSumLogprobs, PreLogits] = + StepFunction(CurrentTokens, AudioFeatures, CurrentTokens, SumLogprobs); + + CurrentTokens = NextTokens; + SumLogprobs = NextSumLogprobs; + Completed = CompletedFlag; + + mx::array NoSpeechProbs = mx::zeros({NBatch}, mx::float32); + if (Tokenizer->getNoSpeech() != -1) { + auto ProbsAtSot = mx::softmax(mx::take(PreLogits, SotIndex, 1), -1); + NoSpeechProbs = mx::take(ProbsAtSot, Tokenizer->getNoSpeech(), 1); + } else { + NoSpeechProbs = mx::full({NBatch}, std::numeric_limits::quiet_NaN(), + mx::float32); + } + + mx::eval(CurrentTokens, SumLogprobs, NoSpeechProbs); + for (int I = 1; I < SampleLen; ++I) { + mx::array Inputs = + take(CurrentTokens, mx::array({CurrentTokens.shape(1) - 1}), 1); + + if (CurrentTokens.shape(-1) > NCtx) { + break; + } + auto [NextToks, NextCompleted, NextSumLogp, _] = + StepFunction(Inputs, AudioFeatures, CurrentTokens, SumLogprobs); + mx::eval(NextToks, NextSumLogp); + + if (Completed) { + break; + } + + CurrentTokens = NextToks; + Completed = NextCompleted; + SumLogprobs = NextSumLogp; + } + + return std::make_tuple(CurrentTokens, SumLogprobs, NoSpeechProbs); +} + +std::vector DecodingTask::run(const mx::array &Mel) { + Inference->reset(); + Decoder->reset(); + int NAudio = Mel.shape(0); + + mx::array AudioFeatures = getAudioFeatures(Mel); + + mx::array Tokens = + mx::array(InitialTokens.data(), {static_cast(InitialTokens.size())}, + mx::int32); + Tokens = mx::broadcast_to(Tokens, + {NAudio, static_cast(InitialTokens.size())}); + auto [Languages, LangProbs] = detectLanguage(AudioFeatures, Tokens); + + if (Options.Task == "lang_id") { + std::vector Results; + for (int I = 0; I < NAudio; ++I) { + DecodingResult Result; + Result.AudioFeatures = mx::take(AudioFeatures, mx::array({I}), 0); + Result.Language = Languages[I]; + if (LangProbs) { + Result.LanguageProbs = (*LangProbs)[I]; + } + Results.push_back(Result); + } + return Results; + } + + if (NGroup > 1) { + // tokens = tokens[:, None, :] + Tokens = mx::expand_dims(Tokens, 1); + + // tokens = mx.broadcast_to(tokens, [n_audio, self.n_group, + // len(self.initial_tokens)]) + std::vector NewShape = {NAudio, NGroup, + static_cast(InitialTokens.size())}; + Tokens = mx::broadcast_to(Tokens, NewShape); + + // tokens = tokens.reshape(n_audio * self.n_group, len(self.initial_tokens)) + Tokens = mx::reshape( + Tokens, {NAudio * NGroup, static_cast(InitialTokens.size())}); + } + // Call the main sampling loop + auto [TokensResult, SumLogprobs, NoSpeechProbs] = + mainLoop(AudioFeatures, Tokens); + + // Reshape the tensors to have (n_audio, n_group) as the first two dimensions + AudioFeatures = + mx::take(AudioFeatures, mx::arange(0, AudioFeatures.shape(0), NGroup), 0); + NoSpeechProbs = + mx::take(NoSpeechProbs, mx::arange(0, NoSpeechProbs.shape(0), NGroup), 0); + + // Ensure shapes are consistent + if (AudioFeatures.shape(0) != NoSpeechProbs.shape(0) || + AudioFeatures.shape(0) != NAudio) { + throw std::runtime_error( + "Inconsistent audio features and no_speech_probs shapes"); + } + + TokensResult = mx::reshape(TokensResult, {NAudio, NGroup, -1}); + SumLogprobs = mx::reshape(SumLogprobs, {NAudio, NGroup}); + // Get the final candidates for each group, and slice between the first + // sampled token and EOT + auto [FinalizedTokens, FinalizedLogprobs] = + Decoder->finalize(TokensResult, SumLogprobs); + + // tokens[..., self.sample_begin:] + std::vector SliceStart(FinalizedTokens.ndim(), 0); + std::vector SliceEnd = FinalizedTokens.shape(); + SliceStart[FinalizedTokens.ndim() - 1] = SampleBegin; + FinalizedTokens = mx::slice(FinalizedTokens, SliceStart, SliceEnd); + + mx::eval(FinalizedTokens, FinalizedLogprobs, NoSpeechProbs); + + // Convert tokens to nested vectors and handle EOT + std::vector>> TokensList(NAudio); + std::vector> SumLogprobsList(NAudio); + + for (int I = 0; I < NAudio; ++I) { + TokensList[I].resize(NGroup); + SumLogprobsList[I].resize(NGroup); + + for (int J = 0; J < NGroup; ++J) { + // Extract tokens until EOT + for (int K = 0; K < FinalizedTokens.shape(2); ++K) { + int Dim1 = FinalizedTokens.shape(1) * FinalizedTokens.shape(2); + int Dim2 = FinalizedTokens.shape(2); + mx::array TokenArray = + mx::take(FinalizedTokens, mx::array({I * Dim1 + J * Dim2 + K})); + int Token = TokenArray.item(); + if (Token == Tokenizer->getEot()) + break; + TokensList[I][J].push_back(Token); + } + + // Extract sum_logprobs + mx::array LogprobArray = mx::take( + FinalizedLogprobs, mx::array({I * FinalizedLogprobs.shape(1) + J})); + SumLogprobsList[I][J] = LogprobArray.item(); + } + } + + std::vector Selected = SequenceRanker->rank(TokensList, SumLogprobsList); + + // Extract final results + std::vector> FinalTokens(NAudio); + std::vector Texts(NAudio); + std::vector FinalSumLogprobs(NAudio); + std::vector AvgLogprobs(NAudio); + + for (int I = 0; I < NAudio; ++I) { + int SelectedIdx = Selected[I]; + FinalTokens[I] = TokensList[I][SelectedIdx]; + Texts[I] = Tokenizer->decode(FinalTokens[I]); + + Texts[I].erase(0, Texts[I].find_first_not_of(" \t\n\r\f\v")); + Texts[I].erase(Texts[I].find_last_not_of(" \t\n\r\f\v") + 1); + + FinalSumLogprobs[I] = SumLogprobsList[I][SelectedIdx]; + AvgLogprobs[I] = FinalSumLogprobs[I] / (FinalTokens[I].size() + 1); + } + + if (Texts.size() != Languages.size() || + Languages.size() != FinalTokens.size() || + FinalTokens.size() != AvgLogprobs.size() || + static_cast(AvgLogprobs.size()) != NAudio) { + throw std::runtime_error("inconsistent result lengths"); + } + + // Create final results + std::vector Results; + for (int I = 0; I < NAudio; ++I) { + DecodingResult Result; + Result.AudioFeatures = mx::take(AudioFeatures, I, 0); + Result.Language = Languages[I]; + Result.Tokens = FinalTokens[I]; + Result.Text = Texts[I]; + Result.AvgLogprob = AvgLogprobs[I]; + + mx::array NoSpeechArray = mx::take(NoSpeechProbs, mx::array({I}), 0); + Result.NoSpeechProb = NoSpeechArray.item(); + + Result.Temperature = Options.Temperature; + Result.CompressionRatio = compressionRatio(Result.Text); + + if (LangProbs) { + Result.LanguageProbs = (*LangProbs)[I]; + } + Results.push_back(Result); + } + + return Results; +} + +// Main decode function +std::variant> +decode(std::shared_ptr Model, const mx::array &Mel, + const DecodingOptions &Options) { + auto MelArray = Mel; + if (Mel.ndim() == 2) { + auto NewShape = Mel.shape(); + NewShape.insert(NewShape.begin(), 1); + MelArray = reshape(Mel, NewShape); + } + auto Results = DecodingTask(Model, Options).run(MelArray); + + if (Results.size() == 1) { + return Results[0]; + } + return Results; +} + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/whisper/decoding.h b/plugins/wasi_nn/MLX/model/whisper/decoding.h new file mode 100644 index 00000000..aa563253 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper/decoding.h @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#pragma once + +#include "whisper.h" +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace whisper { + +class Tokenizer; +class Whisper; + +float compressionRatio(const std::string &Text); + +std::pair>> +detectLanguage(std::shared_ptr Model, const mx::array &Mel, + std::shared_ptr Tokenizer = nullptr); + +struct DecodingOptions { + std::string Task = "transcribe"; + std::optional Language = std::nullopt; + float Temperature = 0.0f; + std::optional SampleLen = std::nullopt; + std::optional BestOf = std::nullopt; + std::optional BeamSize = std::nullopt; + std::optional Patience = std::nullopt; + std::optional LengthPenalty = std::nullopt; + std::optional>> Prompt = + std::nullopt; + std::optional>> Prefix = + std::nullopt; + std::optional>> SuppressTokens = + "-1"; + bool SuppressBlank = true; + bool WithoutTimestamps = false; + std::optional MaxInitialTimestamp = 1.0f; + bool Fp16 = true; +}; + +struct DecodingResult { + mx::array AudioFeatures; + std::string Language; + std::optional> LanguageProbs = std::nullopt; + std::vector Tokens; + std::string Text = ""; + float AvgLogprob = std::numeric_limits::quiet_NaN(); + float NoSpeechProb = std::numeric_limits::quiet_NaN(); + float Temperature = std::numeric_limits::quiet_NaN(); + float CompressionRatio = std::numeric_limits::quiet_NaN(); + DecodingResult() : AudioFeatures(mx::array({})) {} +}; + +class Inference { +public: + explicit Inference(std::shared_ptr Model); + mx::array logits(const mx::array &Tokens, const mx::array &AudioFeatures); + void rearrangeKvCache(const std::vector &SourceIndices); + void reset(); + +private: + std::shared_ptr Model; + std::optional< + std::vector>, + std::optional>>>> + KvCache = std::nullopt; +}; + +class SequenceRanker { +public: + virtual ~SequenceRanker() = default; + virtual std::vector + rank(const std::vector>> &Tokens, + const std::vector> &SumLogprobs) = 0; +}; + +class MaximumLikelihoodRanker : public SequenceRanker { +public: + explicit MaximumLikelihoodRanker(std::optional LengthPenalty); + std::vector + rank(const std::vector>> &Tokens, + const std::vector> &SumLogprobs) override; + +private: + std::optional LengthPenalty; +}; + +class TokenDecoder { +public: + virtual ~TokenDecoder() = default; + virtual void reset() = 0; + virtual std::tuple + update(const mx::array &Tokens, const mx::array &Logits, + const mx::array &SumLogprobs) = 0; + virtual std::pair + finalize(const mx::array &Tokens, const mx::array &SumLogprobs) = 0; +}; + +class GreedyDecoder : public TokenDecoder { +public: + GreedyDecoder(float Temperature, int Eot); + void reset() override; + std::tuple + update(const mx::array &Tokens, const mx::array &Logits, + const mx::array &SumLogprobs) override; + std::pair + finalize(const mx::array &Tokens, const mx::array &SumLogprobs) override; + +private: + float Temperature; + int Eot; +}; + +class LogitFilter { +public: + virtual ~LogitFilter() = default; + virtual mx::array apply(const mx::array &Logits, const mx::array &Tokens) = 0; + std::string Name = "LogitFilter"; +}; + +class SuppressBlank : public LogitFilter { +public: + SuppressBlank(std::shared_ptr Tokenizer, int SampleBegin, + int NVocab); + mx::array apply(const mx::array &Logits, const mx::array &Tokens) override; + +private: + int SampleBegin; + mx::array Mask; +}; + +class SuppressTokens : public LogitFilter { +public: + SuppressTokens(const std::vector &SuppressTokens, int NVocab); + mx::array apply(const mx::array &Logits, const mx::array &Tokens) override; + +private: + mx::array Mask; +}; + +class ApplyTimestampRules : public LogitFilter { +public: + ApplyTimestampRules(std::shared_ptr Tokenizer, int SampleBegin, + std::optional MaxInitialTimestampIndex); + mx::array apply(const mx::array &Logits, const mx::array &Tokens) override; + +private: + std::shared_ptr Tokenizer; + int SampleBegin; + std::optional MaxInitialTimestampIndex; +}; + +class DecodingTask { +public: + DecodingTask(std::shared_ptr Model, const DecodingOptions &Options); + std::vector run(const mx::array &Mel); + +private: + DecodingOptions verifyOptions(const DecodingOptions &Options); + std::vector getInitialTokens(); + std::vector getSuppressTokens(); + mx::array getAudioFeatures(const mx::array &Mel); + std::pair, + std::optional>>> + detectLanguage(const mx::array &AudioFeatures, mx::array &Tokens); + std::tuple + mainLoop(const mx::array &AudioFeatures, const mx::array &Tokens); + std::shared_ptr Model; + std::shared_ptr Tokenizer; + DecodingOptions Options; + int NGroup; + int NCtx; + int SampleLen; + std::vector SotSequence; + std::vector InitialTokens; + int SampleBegin; + int SotIndex; + std::unique_ptr Inference; + std::unique_ptr SequenceRanker; + std::unique_ptr Decoder; + std::vector> LogitFilters; +}; + +std::variant> +decode(std::shared_ptr Model, const mx::array &Mel, + const DecodingOptions &Options = DecodingOptions()); + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp b/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp new file mode 100644 index 00000000..8eb68df3 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp @@ -0,0 +1,783 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#include "tokenizer.h" +#include "mlx/base.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace whisper { + +const std::vector> LANGUAGES = { + {"en", "english"}, {"zh", "chinese"}, {"de", "german"}, + {"es", "spanish"}, {"ru", "russian"}, {"ko", "korean"}, + {"fr", "french"}, {"ja", "japanese"}, {"pt", "portuguese"}, + {"tr", "turkish"}, {"pl", "polish"}, {"ca", "catalan"}, + {"nl", "dutch"}, {"ar", "arabic"}, {"sv", "swedish"}, + {"it", "italian"}, {"id", "indonesian"}, {"hi", "hindi"}, + {"fi", "finnish"}, {"vi", "vietnamese"}, {"he", "hebrew"}, + {"uk", "ukrainian"}, {"el", "greek"}, {"ms", "malay"}, + {"cs", "czech"}, {"ro", "romanian"}, {"da", "danish"}, + {"hu", "hungarian"}, {"ta", "tamil"}, {"no", "norwegian"}, + {"th", "thai"}, {"ur", "urdu"}, {"hr", "croatian"}, + {"bg", "bulgarian"}, {"lt", "lithuanian"}, {"la", "latin"}, + {"mi", "maori"}, {"ml", "malayalam"}, {"cy", "welsh"}, + {"sk", "slovak"}, {"te", "telugu"}, {"fa", "persian"}, + {"lv", "latvian"}, {"bn", "bengali"}, {"sr", "serbian"}, + {"az", "azerbaijani"}, {"sl", "slovenian"}, {"kn", "kannada"}, + {"et", "estonian"}, {"mk", "macedonian"}, {"br", "breton"}, + {"eu", "basque"}, {"is", "icelandic"}, {"hy", "armenian"}, + {"ne", "nepali"}, {"mn", "mongolian"}, {"bs", "bosnian"}, + {"kk", "kazakh"}, {"sq", "albanian"}, {"sw", "swahili"}, + {"gl", "galician"}, {"mr", "marathi"}, {"pa", "punjabi"}, + {"si", "sinhala"}, {"km", "khmer"}, {"sn", "shona"}, + {"yo", "yoruba"}, {"so", "somali"}, {"af", "afrikaans"}, + {"oc", "occitan"}, {"ka", "georgian"}, {"be", "belarusian"}, + {"tg", "tajik"}, {"sd", "sindhi"}, {"gu", "gujarati"}, + {"am", "amharic"}, {"yi", "yiddish"}, {"lo", "lao"}, + {"uz", "uzbek"}, {"fo", "faroese"}, {"ht", "haitian creole"}, + {"ps", "pashto"}, {"tk", "turkmen"}, {"nn", "nynorsk"}, + {"mt", "maltese"}, {"sa", "sanskrit"}, {"lb", "luxembourgish"}, + {"my", "myanmar"}, {"bo", "tibetan"}, {"tl", "tagalog"}, + {"mg", "malagasy"}, {"as", "assamese"}, {"tt", "tatar"}, + {"haw", "hawaiian"}, {"ln", "lingala"}, {"ha", "hausa"}, + {"ba", "bashkir"}, {"jw", "javanese"}, {"su", "sundanese"}, + {"yue", "cantonese"}}; + +// Helper function to find language by code +std::string findLanguageByCode(const std::string &Code) { + for (const auto &[LangCode, Language] : LANGUAGES) { + if (LangCode == Code) { + return Language; + } + } + return ""; +} + +// Helper function to check if language code exists +bool languageCodeExists(const std::string &Code) { + for (const auto &[LangCode, Language] : LANGUAGES) { + if (LangCode == Code) { + return true; + } + } + return false; +} + +const std::unordered_map ToLanguageCode = []() { + std::unordered_map ToLanguageCodeMap; + + // Add language to code mappings + for (const auto &[Code, Language] : LANGUAGES) { + ToLanguageCodeMap[Language] = Code; + } + + // Add aliases + ToLanguageCodeMap["burmese"] = "my"; + ToLanguageCodeMap["valencian"] = "ca"; + ToLanguageCodeMap["flemish"] = "nl"; + ToLanguageCodeMap["haitian"] = "ht"; + ToLanguageCodeMap["letzeburgesch"] = "lb"; + ToLanguageCodeMap["pushto"] = "ps"; + ToLanguageCodeMap["panjabi"] = "pa"; + ToLanguageCodeMap["moldavian"] = "ro"; + ToLanguageCodeMap["moldovan"] = "ro"; + ToLanguageCodeMap["sinhalese"] = "si"; + ToLanguageCodeMap["castilian"] = "es"; + ToLanguageCodeMap["mandarin"] = "zh"; + + return ToLanguageCodeMap; +}(); + +namespace { +// Modern base64 decode implementation +// Based on RFC 4648 standard with better error handling +std::string base64Decode(const std::string &Encoded) { + // Standard base64 alphabet lookup table + static const std::array DecodeTable = []() { + std::array Table{}; + // Initialize all values to -1 (invalid) + std::fill(Table.begin(), Table.end(), -1); + + // Set valid characters + for (int I = 0; I < 26; ++I) { + Table['A' + I] = I; // A-Z: 0-25 + Table['a' + I] = I + 26; // a-z: 26-51 + } + for (int I = 0; I < 10; ++I) { + Table['0' + I] = I + 52; // 0-9: 52-61 + } + Table['+'] = 62; + Table['/'] = 63; + // Padding character + Table['='] = -2; + + return Table; + }(); + + if (Encoded.empty() || Encoded == "=") { + return {}; + } + + // Remove whitespace and validate length + std::string Cleaned; + Cleaned.reserve(Encoded.size()); + for (char C : Encoded) { + if (C != ' ' && C != '\t' && C != '\r' && C != '\n') { + Cleaned.push_back(C); + } + } + + if (Cleaned.size() % 4 != 0) { + throw std::invalid_argument("Invalid base64 string length"); + } + + std::string Result; + Result.reserve((Cleaned.size() * 3) / 4); + + for (size_t I = 0; I < Cleaned.size(); I += 4) { + std::array Values{}; + int PaddingCount = 0; + + // Decode 4 characters at a time + for (int J = 0; J < 4; ++J) { + unsigned char C = static_cast(Cleaned[I + J]); + int Val = DecodeTable[C]; + + if (Val == -1) { + throw std::invalid_argument("Invalid base64 character: " + + std::to_string(C)); + } + if (Val == -2) { // Padding + Values[J] = 0; + ++PaddingCount; + } else { + if (PaddingCount > 0) { + throw std::invalid_argument("Invalid padding in base64 string"); + } + Values[J] = Val; + } + } + + // Convert 4 base64 chars to 3 bytes + int Combined = + (Values[0] << 18) | (Values[1] << 12) | (Values[2] << 6) | Values[3]; + + // Extract bytes based on padding + if (PaddingCount == 0) { + Result.push_back(static_cast((Combined >> 16) & 0xFF)); + Result.push_back(static_cast((Combined >> 8) & 0xFF)); + Result.push_back(static_cast(Combined & 0xFF)); + } else if (PaddingCount == 1) { + Result.push_back(static_cast((Combined >> 16) & 0xFF)); + Result.push_back(static_cast((Combined >> 8) & 0xFF)); + } else if (PaddingCount == 2) { + Result.push_back(static_cast((Combined >> 16) & 0xFF)); + } else { + throw std::invalid_argument("Invalid padding count in base64 string"); + } + } + + return Result; +} +} // namespace + +// Encoding implementation +Encoding::Encoding(const std::string &Name, int ExplicitNVocab, + const std::string &PatStr, + const std::unordered_map &MergeableRanks, + const std::unordered_map &SpecialTokens) + : Name(Name), PatStr(PatStr), MergeableRanks(MergeableRanks), + SpecialTokens(SpecialTokens) { + + EotToken = 50257; // Default EOT token + + // Build reverse mapping + for (const auto &[Token, Id] : MergeableRanks) { + TokenToString[Id] = Token; + } + for (const auto &[Token, Id] : SpecialTokens) { + TokenToString[Id] = Token; + SpecialTokensSet.insert(Token); + } +} + +std::vector Encoding::encode(const std::string &Text) const { + std::vector Result; + + if (Text.empty()) { + return Result; + } + + // Handle the simplified cases for symbols that are commonly in vocab + // First try to encode the text as a single token + auto DirectIt = MergeableRanks.find(Text); + if (DirectIt != MergeableRanks.end()) { + Result.push_back(DirectIt->second); + return Result; + } + + // For complex text, we need a better approach than just splitting by spaces + // This is a more sophisticated approach that handles individual characters + // and symbols + + size_t I = 0; + while (I < Text.length()) { + // Try to find the longest matching token starting at position I + std::string LongestMatch; + int LongestMatchToken = -1; + + // Look for longest match from current position + for (size_t Len = std::min(Text.length() - I, size_t(10)); Len > 0; --Len) { + std::string Candidate = Text.substr(I, Len); + auto It = MergeableRanks.find(Candidate); + if (It != MergeableRanks.end()) { + LongestMatch = Candidate; + LongestMatchToken = It->second; + break; // Take the first (longest) match + } + } + + if (LongestMatchToken != -1) { + Result.push_back(LongestMatchToken); + I += LongestMatch.length(); + } else { + // If no match found, try single character + std::string SingleChar = Text.substr(I, 1); + auto CharIt = MergeableRanks.find(SingleChar); + if (CharIt != MergeableRanks.end()) { + Result.push_back(CharIt->second); + } else { + // Handle UTF-8 sequences - try to find multi-byte character + size_t CharLen = 1; + unsigned char FirstByte = static_cast(Text[I]); + if ((FirstByte & 0x80) != 0) { + // UTF-8 multi-byte character + if ((FirstByte & 0xE0) == 0xC0) + CharLen = 2; + else if ((FirstByte & 0xF0) == 0xE0) + CharLen = 3; + else if ((FirstByte & 0xF8) == 0xF0) + CharLen = 4; + } + + if (I + CharLen <= Text.length()) { + std::string MultiByteChar = Text.substr(I, CharLen); + auto MultiIt = MergeableRanks.find(MultiByteChar); + if (MultiIt != MergeableRanks.end()) { + Result.push_back(MultiIt->second); + I += CharLen; + continue; + } + } + + // If we still can't find it, encode as bytes + for (size_t J = 0; J < CharLen && I + J < Text.length(); ++J) { + unsigned char Byte = static_cast(Text[I + J]); + // Try to find byte encoding - tiktoken often has byte-level tokens + std::string ByteStr(1, static_cast(Byte)); + auto ByteIt = MergeableRanks.find(ByteStr); + if (ByteIt != MergeableRanks.end()) { + Result.push_back(ByteIt->second); + } + } + I += CharLen; + } + } + } + + return Result; +} + +std::string Encoding::decode(const std::vector &TokenIds) const { + std::string Result; + for (int Id : TokenIds) { + auto It = TokenToString.find(Id); + if (It != TokenToString.end()) { + Result += It->second; + } + } + return Result; +} + +int Encoding::encodeSingleToken(const std::string &Token) const { + auto It = SpecialTokens.find(Token); + if (It != SpecialTokens.end()) { + return It->second; + } + + auto MergeIt = MergeableRanks.find(Token); + if (MergeIt != MergeableRanks.end()) { + return MergeIt->second; + } + + throw std::runtime_error("Token not found: " + Token); +} +// Tokenizer implementation +Tokenizer::Tokenizer(std::unique_ptr EncodingPtr, int NumLanguages, + const std::optional &Language, + const std::optional &Task) + : EncodingPtr(std::move(EncodingPtr)), NumLanguages(NumLanguages), + Language(Language), Task(Task) { + + for (const auto &Special : this->EncodingPtr->SpecialTokensSet) { + int SpecialToken = this->EncodingPtr->encodeSingleToken(Special); + SpecialTokens[Special] = SpecialToken; + } + + // Build SOT sequence + int Sot = SpecialTokens["<|startoftranscript|>"]; + int Translate = SpecialTokens["<|translate|>"]; + int Transcribe = SpecialTokens["<|transcribe|>"]; + std::vector Langs; + for (const auto &[Code, Name] : LANGUAGES) { + Langs.push_back(Code); + if (static_cast(Langs.size()) >= NumLanguages) + break; + } + + SotSequence = {Sot}; + if (Language.has_value()) { + auto LangIt = std::find(Langs.begin(), Langs.end(), Language.value()); + if (LangIt != Langs.end()) { + int LangIndex = std::distance(Langs.begin(), LangIt); + SotSequence.push_back(Sot + 1 + LangIndex); + } + } + if (Task.has_value()) { + int TaskToken = (Task.value() == "transcribe") ? Transcribe : Translate; + SotSequence.push_back(TaskToken); + } +} + +std::vector Tokenizer::encode(const std::string &Text) const { + return EncodingPtr->encode(Text); +} + +std::string Tokenizer::decode(const std::vector &TokenIds) const { + std::vector FilteredTokens; + int TimestampBegin = getTimestampBegin(); + + for (int Token : TokenIds) { + if (Token < TimestampBegin) { + FilteredTokens.push_back(Token); + } + } + + return EncodingPtr->decode(FilteredTokens); +} + +std::string +Tokenizer::decodeWithTimestamps(const std::vector &TokenIds) const { + return EncodingPtr->decode(TokenIds); +} + +// Cached property implementations +int Tokenizer::getEot() const { + if (!CachedEot.has_value()) { + CachedEot = EncodingPtr->EotToken; + } + return CachedEot.value(); +} + +int Tokenizer::getTranscribe() const { + if (!CachedTranscribe.has_value()) { + CachedTranscribe = SpecialTokens.at("<|transcribe|>"); + } + return CachedTranscribe.value(); +} + +int Tokenizer::getTranslate() const { + if (!CachedTranslate.has_value()) { + CachedTranslate = SpecialTokens.at("<|translate|>"); + } + return CachedTranslate.value(); +} + +int Tokenizer::getSot() const { + if (!CachedSot.has_value()) { + CachedSot = SpecialTokens.at("<|startoftranscript|>"); + } + return CachedSot.value(); +} + +int Tokenizer::getSotLm() const { + if (!CachedSotLm.has_value()) { + CachedSotLm = SpecialTokens.at("<|startoflm|>"); + } + return CachedSotLm.value(); +} + +int Tokenizer::getSotPrev() const { + if (!CachedSotPrev.has_value()) { + CachedSotPrev = SpecialTokens.at("<|startofprev|>"); + } + return CachedSotPrev.value(); +} + +int Tokenizer::getNoSpeech() const { + if (!CachedNoSpeech.has_value()) { + CachedNoSpeech = SpecialTokens.at("<|nospeech|>"); + } + return CachedNoSpeech.value(); +} + +int Tokenizer::getNoTimestamps() const { + if (!CachedNoTimestamps.has_value()) { + CachedNoTimestamps = SpecialTokens.at("<|notimestamps|>"); + } + return CachedNoTimestamps.value(); +} + +int Tokenizer::getTimestampBegin() const { + if (!CachedTimestampBegin.has_value()) { + CachedTimestampBegin = SpecialTokens.at("<|0.00|>"); + } + return CachedTimestampBegin.value(); +} + +int Tokenizer::languageToken() const { + if (!Language.has_value()) { + throw std::runtime_error( + "This tokenizer does not have language token configured"); + } + return toLanguageToken(Language.value()); +} + +int Tokenizer::toLanguageToken(const std::string &Language) const { + std::string TokenName = "<|" + Language + "|>"; + auto It = SpecialTokens.find(TokenName); + if (It != SpecialTokens.end()) { + return It->second; + } + throw std::runtime_error("Language " + Language + " not found in tokenizer."); +} + +std::vector Tokenizer::getAllLanguageTokens() const { + if (!CachedAllLanguageTokens.has_value()) { + std::vector Result; + + std::vector SortedLanguageTokens; + for (const auto &[Token, TokenId] : SpecialTokens) { + std::string Stripped = Token; + const std::string CharsToStrip = "<|>"; + while (!Stripped.empty() && + CharsToStrip.find(Stripped.front()) != std::string::npos) { + Stripped.erase(0, 1); + } + while (!Stripped.empty() && + CharsToStrip.find(Stripped.back()) != std::string::npos) { + Stripped.pop_back(); + } + if (languageCodeExists(Stripped)) { + SortedLanguageTokens.push_back(Token); + } + } + + std::sort(SortedLanguageTokens.begin(), SortedLanguageTokens.end()); + + for (const std::string &Token : SortedLanguageTokens) { + auto It = SpecialTokens.find(Token); + if (It != SpecialTokens.end()) { + Result.push_back(It->second); + } + } + if (static_cast(Result.size()) > NumLanguages) { + Result.resize(NumLanguages); + } + + CachedAllLanguageTokens = Result; + } + return CachedAllLanguageTokens.value(); +} + +std::vector Tokenizer::getAllLanguageCodes() const { + if (!CachedAllLanguageCodes.has_value()) { + std::vector Result; + auto LanguageTokens = getAllLanguageTokens(); + + for (int Token : LanguageTokens) { + std::string Decoded = EncodingPtr->decode({Token}); + // Remove <| and |> from decoded token + if (Decoded.size() > 4 && Decoded.substr(0, 2) == "<|" && + Decoded.substr(Decoded.size() - 2) == "|>") { + Result.push_back(Decoded.substr(2, Decoded.size() - 4)); + } + } + + CachedAllLanguageCodes = Result; + } + return CachedAllLanguageCodes.value(); +} + +std::vector Tokenizer::getSotSequenceIncludingNotimestamps() const { + if (!CachedSotSequenceIncludingNotimestamps.has_value()) { + std::vector Result = SotSequence; + Result.push_back(getNoTimestamps()); + CachedSotSequenceIncludingNotimestamps = Result; + } + return CachedSotSequenceIncludingNotimestamps.value(); +} + +std::vector Tokenizer::getNonSpeechTokens() const { + if (!CachedNonSpeechTokens.has_value()) { + std::unordered_set ResultSet; + + std::vector Symbols = {"\"", "#", "(", ")", "*", "+", "/", + ":", ";", "<", "=", ">", "@", "[", + "\\", "]", "^", "_", "`", "{", "|", + "}", "~", "「", "」", "『", "』"}; + + std::vector AdditionalSymbols = { + "<<", ">>", "<<<", ">>>", "--", "---", "-(", "-[", "('", "(\"", + "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", "♪♪♪"}; + Symbols.insert(Symbols.end(), AdditionalSymbols.begin(), + AdditionalSymbols.end()); + + std::vector Miscellaneous = {"♩", "♪", "♫", "♬", + "♭", "♮", "♯"}; + + auto DashTokens = EncodingPtr->encode(" -"); + if (!DashTokens.empty()) { + ResultSet.insert(DashTokens[0]); + } + + auto QuoteTokens = EncodingPtr->encode(" '"); + if (!QuoteTokens.empty()) { + ResultSet.insert(QuoteTokens[0]); + } + + // Process all symbols + std::vector AllSymbols = Symbols; + AllSymbols.insert(AllSymbols.end(), Miscellaneous.begin(), + Miscellaneous.end()); + + for (const std::string &Symbol : AllSymbols) { + try { + // Try encoding the symbol directly + auto DirectTokens = EncodingPtr->encode(Symbol); + bool IsMiscellaneous = + std::find(Miscellaneous.begin(), Miscellaneous.end(), Symbol) != + Miscellaneous.end(); + + if (DirectTokens.size() == 1 || IsMiscellaneous) { + if (!DirectTokens.empty()) { + ResultSet.insert(DirectTokens[0]); + } + } + + // Try encoding the symbol with a space prefix + auto SpacedTokens = EncodingPtr->encode(" " + Symbol); + if (SpacedTokens.size() == 1 || IsMiscellaneous) { + if (!SpacedTokens.empty()) { + ResultSet.insert(SpacedTokens[0]); + } + } + } catch (...) { + // Ignore encoding errors for individual symbols + } + } + + // Convert set to sorted vector + std::vector Result(ResultSet.begin(), ResultSet.end()); + std::sort(Result.begin(), Result.end()); + + CachedNonSpeechTokens = Result; + } + return CachedNonSpeechTokens.value(); +} + +std::pair, std::vector>> +Tokenizer::splitToWordTokens(const std::vector &Tokens) const { + std::unordered_set NoSpaceLanguages = {"zh", "ja", "th", + "lo", "my", "yue"}; + + if (Language.has_value() && NoSpaceLanguages.count(Language.value())) { + return splitTokensOnUnicode(Tokens); + } + + return splitTokensOnSpaces(Tokens); +} + +std::pair, std::vector>> +Tokenizer::splitTokensOnUnicode(const std::vector &Tokens) const { + std::string DecodedFull = decodeWithTimestamps(Tokens); + const std::string ReplacementChar = "\uFFFD"; + + std::vector Words; + std::vector> WordTokens; + std::vector CurrentTokens; + size_t UnicodeOffset = 0; + + for (int Token : Tokens) { + CurrentTokens.push_back(Token); + std::string Decoded = decodeWithTimestamps(CurrentTokens); + + bool ValidUnicode = + (Decoded.find(ReplacementChar) == std::string::npos) || + (UnicodeOffset + Decoded.find(ReplacementChar) < DecodedFull.size() && + DecodedFull.substr(UnicodeOffset + Decoded.find(ReplacementChar), + ReplacementChar.length()) == ReplacementChar); + + if (ValidUnicode) { + Words.push_back(Decoded); + WordTokens.push_back(CurrentTokens); + CurrentTokens.clear(); + UnicodeOffset += Decoded.size(); + } + } + + return {Words, WordTokens}; +} + +std::pair, std::vector>> +Tokenizer::splitTokensOnSpaces(const std::vector &Tokens) const { + auto [Subwords, SubwordTokensList] = splitTokensOnUnicode(Tokens); + + std::vector Words; + std::vector> WordTokens; + + for (size_t I = 0; I < Subwords.size(); ++I) { + const std::string &Subword = Subwords[I]; + const std::vector &SubwordTokens = SubwordTokensList[I]; + + bool Special = !SubwordTokens.empty() && SubwordTokens[0] >= getEot(); + bool WithSpace = !Subword.empty() && Subword[0] == ' '; + + // Check if it's punctuation + std::string Trimmed = Subword; + Trimmed.erase(0, Trimmed.find_first_not_of(" \t\n\r")); + Trimmed.erase(Trimmed.find_last_not_of(" \t\n\r") + 1); + bool Punctuation = Trimmed.size() == 1 && std::ispunct(Trimmed[0]); + + if (Special || WithSpace || Punctuation || Words.empty()) { + Words.push_back(Subword); + WordTokens.push_back(SubwordTokens); + } else { + Words.back() += Subword; + WordTokens.back().insert(WordTokens.back().end(), SubwordTokens.begin(), + SubwordTokens.end()); + } + } + + return {Words, WordTokens}; +} + +// Factory functions +std::unique_ptr getEncoding(const std::string &Name, + int NumLanguages) { + std::string VocabPath = "assets/" + Name + ".tiktoken"; + + std::unordered_map Ranks; + std::ifstream File(VocabPath); + + if (!File.is_open()) { + throw std::runtime_error("Failed to open vocab file: " + VocabPath); + } + + std::string Line; + while (std::getline(File, Line)) { + if (Line.empty()) + continue; + + std::istringstream Iss(Line); + std::string Token, RankStr; + if (Iss >> Token >> RankStr) { + std::string DecodedToken = base64Decode(Token); + Ranks[DecodedToken] = std::stoi(RankStr); + } + } + File.close(); + + int NVocab = Ranks.size(); + std::unordered_map SpecialTokens; + + // Build special tokens + std::vector Specials = {"<|endoftext|>", + "<|startoftranscript|>"}; + + for (const auto &[Code, Name] : LANGUAGES) { + Specials.push_back("<|" + Code + "|>"); + if (static_cast(Specials.size()) >= NumLanguages + 2) + break; + } + + Specials.insert(Specials.end(), + {"<|translate|>", "<|transcribe|>", "<|startoflm|>", + "<|startofprev|>", "<|nospeech|>", "<|notimestamps|>"}); + + // Add timestamp tokens + for (int I = 0; I < 1501; ++I) { + std::ostringstream Oss; + Oss << "<|" << std::fixed << std::setprecision(2) << (I * 0.02) << "|>"; + Specials.push_back(Oss.str()); + } + + for (const std::string &Token : Specials) { + SpecialTokens[Token] = NVocab++; + } + + std::string PatStr = + R"('s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+)"; + + return std::make_unique(VocabPath, NVocab, PatStr, Ranks, + SpecialTokens); +} + +std::unique_ptr +getTokenizer(bool Multilingual, int NumLanguages, + const std::optional &Language, + const std::optional &Task) { + std::optional ProcessedLanguage = Language; + + if (ProcessedLanguage.has_value()) { + std::string Lang = ProcessedLanguage.value(); + std::transform(Lang.begin(), Lang.end(), Lang.begin(), ::tolower); + + if (!languageCodeExists(Lang)) { + auto It = ToLanguageCode.find(Lang); + if (It != ToLanguageCode.end()) { + ProcessedLanguage = It->second; + } else { + throw std::runtime_error("Unsupported language: " + Lang); + } + } else { + ProcessedLanguage = Lang; + } + } + + std::string EncodingName; + std::optional FinalLanguage; + std::optional FinalTask; + + if (Multilingual) { + EncodingName = "multilingual"; + FinalLanguage = ProcessedLanguage.value_or("en"); + FinalTask = Task.value_or("transcribe"); + } else { + EncodingName = "gpt2"; + FinalLanguage = std::nullopt; + FinalTask = std::nullopt; + } + + auto Encoding = getEncoding(EncodingName, NumLanguages); + + return std::make_unique(std::move(Encoding), NumLanguages, + FinalLanguage, FinalTask); +} + +std::unique_ptr +createWhisperTokenizer(const std::optional &Language, + const std::optional &Task) { + return getTokenizer(true, 99, Language, Task); +} + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/whisper/tokenizer.h b/plugins/wasi_nn/MLX/model/whisper/tokenizer.h new file mode 100644 index 00000000..9c068e09 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper/tokenizer.h @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace whisper { + +extern const std::vector> LANGUAGES; +extern const std::unordered_map ToLanguageCode; + +std::string findLanguageByCode(const std::string &Code); +bool languageCodeExists(const std::string &Code); + +class Encoding { +public: + Encoding(const std::string &Name, int ExplicitNVocab, + const std::string &PatStr, + const std::unordered_map &MergeableRanks, + const std::unordered_map &SpecialTokens); + + std::vector encode(const std::string &Text) const; + std::string decode(const std::vector &TokenIds) const; + int encodeSingleToken(const std::string &Token) const; + + int EotToken; + std::unordered_set SpecialTokensSet; + +private: + std::string Name; + std::string PatStr; + std::unordered_map MergeableRanks; + std::unordered_map SpecialTokens; + std::unordered_map TokenToString; +}; + +class Tokenizer { +public: + Tokenizer(std::unique_ptr EncodingPtr, int NumLanguages, + const std::optional &Language = std::nullopt, + const std::optional &Task = std::nullopt); + + std::vector encode(const std::string &Text) const; + std::string decode(const std::vector &TokenIds) const; + std::string decodeWithTimestamps(const std::vector &TokenIds) const; + + // Cached properties + int getEot() const; + int getTranscribe() const; + int getTranslate() const; + int getSot() const; + int getSotLm() const; + int getSotPrev() const; + int getNoSpeech() const; + int getNoTimestamps() const; + int getTimestampBegin() const; + int languageToken() const; + int toLanguageToken(const std::string &Language) const; + std::vector getAllLanguageTokens() const; + std::vector getAllLanguageCodes() const; + std::vector getSotSequenceIncludingNotimestamps() const; + std::vector getNonSpeechTokens() const; + + std::pair, std::vector>> + splitToWordTokens(const std::vector &Tokens) const; + + // Public members + std::unique_ptr EncodingPtr; + int NumLanguages; + std::optional Language; + std::optional Task; + std::vector SotSequence; + std::unordered_map SpecialTokens; + +private: + std::pair, std::vector>> + splitTokensOnUnicode(const std::vector &Tokens) const; + + std::pair, std::vector>> + splitTokensOnSpaces(const std::vector &Tokens) const; + + // Cached values + mutable std::optional CachedEot; + mutable std::optional CachedTranscribe; + mutable std::optional CachedTranslate; + mutable std::optional CachedSot; + mutable std::optional CachedSotLm; + mutable std::optional CachedSotPrev; + mutable std::optional CachedNoSpeech; + mutable std::optional CachedNoTimestamps; + mutable std::optional CachedTimestampBegin; + mutable std::optional> CachedAllLanguageTokens; + mutable std::optional> CachedAllLanguageCodes; + mutable std::optional> + CachedSotSequenceIncludingNotimestamps; + mutable std::optional> CachedNonSpeechTokens; +}; + +std::unique_ptr getEncoding(const std::string &Name = "gpt2", + int NumLanguages = 99); + +std::unique_ptr +getTokenizer(bool Multilingual, int NumLanguages = 99, + const std::optional &Language = std::nullopt, + const std::optional &Task = std::nullopt); + +// Helper function to create a tokenizer for whisper transcription +std::unique_ptr createWhisperTokenizer( + const std::optional &Language = std::nullopt, + const std::optional &Task = std::nullopt); + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper/whisper.cpp b/plugins/wasi_nn/MLX/model/whisper/whisper.cpp new file mode 100644 index 00000000..2e42f09d --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper/whisper.cpp @@ -0,0 +1,461 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#include "whisper.h" +#include "mlx/activations.h" +#include "mlx/base.h" +#include "mlx/convolution.h" +#include "mlx/embedding.h" +#include "mlx/linear.h" +#include "mlx/normalization.h" +#include "mlx/transformer.h" +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +namespace whisper { + +mx::array sinusoids(int Length, int Channels, float MaxTimescale) { + assert(Channels % 2 == 0); + float LogTimescaleIncrement = std::log(MaxTimescale) / (Channels / 2 - 1); + mx::array InvTimescales = + mx::exp(-LogTimescaleIncrement * mx::arange(Channels / 2)); + mx::array LengthArray = mx::arange(Length); + mx::array LengthReshaped = reshape(LengthArray, {Length, 1}); + mx::array InvTimescalesReshaped = reshape(InvTimescales, {1, -1}); + mx::array ScaledTime = LengthReshaped * InvTimescalesReshaped; + return mx::concatenate({mx::sin(ScaledTime), mx::cos(ScaledTime)}, + /*axis=*/1); +} + +ModelDimensions ModelDimensions::fromDict(const simdjson::dom::object &Obj) { + ModelDimensions Dims; + if (auto Val = Obj["n_mels"]; !Val.error()) { + Dims.NMels = Val.get(); + } + if (auto Val = Obj["n_audio_ctx"]; !Val.error()) { + Dims.NAudioCtx = Val.get(); + } + if (auto Val = Obj["n_audio_state"]; !Val.error()) { + Dims.NAudioState = Val.get(); + } + if (auto Val = Obj["n_audio_head"]; !Val.error()) { + Dims.NAudioHead = Val.get(); + } + if (auto Val = Obj["n_audio_layer"]; !Val.error()) { + Dims.NAudioLayer = Val.get(); + } + if (auto Val = Obj["n_vocab"]; !Val.error()) { + Dims.NVocab = Val.get(); + } + if (auto Val = Obj["n_text_ctx"]; !Val.error()) { + Dims.NTextCtx = Val.get(); + } + if (auto Val = Obj["n_text_state"]; !Val.error()) { + Dims.NTextState = Val.get(); + } + if (auto Val = Obj["n_text_head"]; !Val.error()) { + Dims.NTextHead = Val.get(); + } + if (auto Val = Obj["n_text_layer"]; !Val.error()) { + Dims.NTextLayer = Val.get(); + } + return Dims; +} + +// MultiHeadAttention implementation +MultiHeadAttention::MultiHeadAttention(int NState, int NHead) : NHead(NHead) { + auto Query = std::make_shared(NState, NState); + auto Key = std::make_shared(NState, NState, /*bias=*/false); + auto Value = std::make_shared(NState, NState); + auto Out = std::make_shared(NState, NState); + + registerModule("query", Query); + registerModule("key", Key); + registerModule("value", Value); + registerModule("out", Out); +} + +std::tuple, mx::array> +MultiHeadAttention::forward( + const mx::array &X, const std::optional &Xa, + const std::optional &Mask, + const std::optional> &KvCache) { + + auto Query = std::dynamic_pointer_cast(Submodules["query"]); + auto Key = std::dynamic_pointer_cast(Submodules["key"]); + auto Value = std::dynamic_pointer_cast(Submodules["value"]); + auto Out = std::dynamic_pointer_cast(Submodules["out"]); + + mx::array Q = Query->forward(X); + mx::array K = mx::array({}), V = mx::array({}); + + if (!Xa.has_value()) { + // Self-attention + K = Key->forward(X); + V = Value->forward(X); + if (KvCache.has_value()) { + K = mx::concatenate({KvCache->first, K}, 1); + V = mx::concatenate({KvCache->second, V}, 1); + } + } else if (!KvCache.has_value()) { + // Cross-attention without cache + K = Key->forward(*Xa); + V = Value->forward(*Xa); + } else { + // Cross-attention with cache + K = KvCache->first; + V = KvCache->second; + } + auto [Wv, Qk] = qkvAttention(Q, K, V, Mask); + return {Out->forward(Wv), {K, V}, Qk}; +} + +std::pair +MultiHeadAttention::qkvAttention(const mx::array &Q, const mx::array &K, + const mx::array &V, + const std::optional &Mask) { + + auto Shape = Q.shape(); + int NBatch = Shape[0]; + int NCtx = Shape[1]; + int NState = Shape[2]; + + double Scale = std::pow(NState / NHead, -0.25); + Scale = std::round(Scale * 1000000) / 1000000; + + // Reshape and transpose for multi-head attention + mx::array QReshaped = reshape(Q, {Q.shape(0), Q.shape(1), NHead, -1}); + QReshaped = transpose(QReshaped, {0, 2, 1, 3}) * Scale; + + mx::array KReshaped = reshape(K, {K.shape(0), K.shape(1), NHead, -1}); + KReshaped = transpose(KReshaped, {0, 2, 3, 1}) * Scale; + + mx::array VReshaped = reshape(V, {V.shape(0), V.shape(1), NHead, -1}); + VReshaped = transpose(VReshaped, {0, 2, 1, 3}); + mx::array Qk = mx::matmul(QReshaped, KReshaped); + + if (Mask.has_value()) { + Qk = Qk + slice(*Mask, {0, 0}, {NCtx, NCtx}); + } + mx::array W = mx::softmax(Qk, /*axis=*/-1, /*precise=*/true); + mx::array Out = transpose(mx::matmul(W, VReshaped), {0, 2, 1, 3}); + Out = reshape(Out, {NBatch, NCtx, NState}); + return {Out, Qk}; +} + +// ResidualAttentionBlock implementation +ResidualAttentionBlock::ResidualAttentionBlock(int NState, int NHead, + bool CrossAttention) + : HasCrossAttention(CrossAttention) { + + auto Attn = std::make_shared(NState, NHead); + auto AttnLn = std::make_shared(NState); + + registerModule("attn", Attn); + registerModule("attn_ln", AttnLn); + + if (CrossAttention) { + auto CrossAttn = std::make_shared(NState, NHead); + auto CrossAttnLn = std::make_shared(NState); + registerModule("cross_attn", CrossAttn); + registerModule("cross_attn_ln", CrossAttnLn); + } + + int NMlp = NState * 4; + auto Mlp1 = std::make_shared(NState, NMlp); + auto Mlp2 = std::make_shared(NMlp, NState); + auto MlpLn = std::make_shared(NState); + + registerModule("mlp1", Mlp1); + registerModule("mlp2", Mlp2); + registerModule("mlp_ln", MlpLn); +} + +std::tuple>, + std::optional>>, + std::optional> +ResidualAttentionBlock::forward( + const mx::array &X, const std::optional &Xa, + const std::optional &Mask, + const std::optional< + std::pair>, + std::optional>>> &KvCache) { + + auto Attn = std::dynamic_pointer_cast(Submodules["attn"]); + auto AttnLn = std::dynamic_pointer_cast(Submodules["attn_ln"]); + auto Mlp1 = std::dynamic_pointer_cast(Submodules["mlp1"]); + auto Mlp2 = std::dynamic_pointer_cast(Submodules["mlp2"]); + auto MlpLn = std::dynamic_pointer_cast(Submodules["mlp_ln"]); + + std::optional> Kv = + KvCache ? KvCache->first : std::nullopt; + std::optional> CrossKv = + KvCache ? KvCache->second : std::nullopt; + auto [Y, NewKv, _] = + Attn->forward(AttnLn->forward(X), std::nullopt, Mask, Kv); + mx::array Result = X + Y; + + std::optional CrossQk = std::nullopt; + std::optional> NewCrossKv = std::nullopt; + + if (HasCrossAttention) { + auto CrossAttn = + std::dynamic_pointer_cast(Submodules["cross_attn"]); + auto CrossAttnLn = + std::dynamic_pointer_cast(Submodules["cross_attn_ln"]); + + auto [CrossY, TempCrossKv, TempCrossQk] = CrossAttn->forward( + CrossAttnLn->forward(Result), Xa, std::nullopt, CrossKv); + Result = Result + CrossY; + NewCrossKv = TempCrossKv; + CrossQk = TempCrossQk; + } + Result = Result + Mlp2->forward( + mlx::core::gelu(Mlp1->forward(MlpLn->forward(Result)))); + return {Result, {NewKv, NewCrossKv}, CrossQk}; +} + +// AudioEncoder implementation +AudioEncoder::AudioEncoder(int NMels, int NCtx, int NState, int NHead, + int NLayer, mx::Dtype Dtype) { + auto Conv1 = std::make_shared(NMels, NState, 3, 1, 1); + auto Conv2 = std::make_shared(NState, NState, 3, 2, 1); + + registerModule("conv1", Conv1); + registerModule("conv2", Conv2); + + PositionalEmbedding = astype(sinusoids(NCtx, NState), Dtype); + + for (int I = 0; I < NLayer; ++I) { + auto Block = std::make_shared(NState, NHead); + Blocks.push_back(Block); + } + registerLayer("blocks", Blocks); + + auto LnPost = std::make_shared(NState); + registerModule("ln_post", LnPost); +} + +mx::array AudioEncoder::forward(const mx::array &X) { + auto Conv1 = std::dynamic_pointer_cast(Submodules["conv1"]); + auto Conv2 = std::dynamic_pointer_cast(Submodules["conv2"]); + auto LnPost = std::dynamic_pointer_cast(Submodules["ln_post"]); + mx::array Result = Conv1->forward(X); + Result = mlx::core::gelu(Result); + Result = Conv2->forward(Result); + Result = mlx::core::gelu(Result); + assert(Result.shape()[1] == PositionalEmbedding.shape()[0] && + Result.shape()[2] == PositionalEmbedding.shape()[1]); + + Result = Result + PositionalEmbedding; + for (auto &Block : Blocks) { + auto [NewResult, _, __] = Block->forward(Result); + Result = NewResult; + } + return LnPost->forward(Result); +} + +// TextDecoder implementation +TextDecoder::TextDecoder(int NVocab, int NCtx, int NState, int NHead, + int NLayer, mx::Dtype Dtype) { + auto TokenEmbedding = std::make_shared(NVocab, NState); + registerModule("token_embedding", TokenEmbedding); + + registerParameter("positional_embedding", mx::zeros({NCtx, NState})); + + for (int I = 0; I < NLayer; ++I) { + auto Block = std::make_shared( + NState, NHead, /*cross_attention=*/true); + Blocks.push_back(Block); + } + registerLayer("blocks", Blocks); + + auto Ln = std::make_shared(NState); + registerModule("ln", Ln); + + Mask = astype(nn::MultiHeadAttention::createAdditiveCausalMask(NCtx), Dtype); +} + +std::tuple< + mx::array, + std::vector>, + std::optional>>>, + std::vector>> +TextDecoder::forward( + const mx::array &X, const mx::array &Xa, + const std::optional< + std::vector>, + std::optional>>>> + &KvCache) { + + auto TokenEmbedding = + std::dynamic_pointer_cast(Submodules["token_embedding"]); + auto Ln = std::dynamic_pointer_cast(Submodules["ln"]); + + int Offset = 0; + if (KvCache.has_value() && !KvCache->empty() && + KvCache->at(0).first.has_value() && + KvCache->at(0).first->first.shape(1) > 0) { + Offset = KvCache->at(0).first->first.shape(1); + } + std::vector Start(Parameters.at("positional_embedding").shape().size(), + 0); + std::vector End = Parameters.at("positional_embedding").shape(); + Start[0] = Offset; + End[0] = Offset + X.shape(-1); + + mx::array Result = TokenEmbedding->forward(X) + + slice(Parameters.at("positional_embedding"), Start, End); + + std::vector>, + std::optional>>> + NewKvCache; + std::vector> CrossQk; + + if (!KvCache.has_value()) { + NewKvCache.resize(Blocks.size()); + for (auto &Item : NewKvCache) { + Item = {std::nullopt, std::nullopt}; + } + } else { + NewKvCache = *KvCache; + } + + CrossQk.resize(Blocks.size()); + + for (size_t I = 0; I < Blocks.size(); ++I) { + auto [NewResult, UpdatedCache, BlockCrossQk] = + Blocks[I]->forward(Result, Xa, Mask, NewKvCache[I]); + Result = NewResult; + NewKvCache[I] = UpdatedCache; + CrossQk[I] = BlockCrossQk; + } + Result = Ln->forward(Result); + mx::array Logits = TokenEmbedding->asLinear(Result); + + return {Logits, NewKvCache, CrossQk}; +} + +Whisper::Whisper(const ModelDimensions &Dims, mx::Dtype Dtype) : Dims(Dims) { + auto Encoder = std::make_shared( + Dims.NMels, Dims.NAudioCtx, Dims.NAudioState, Dims.NAudioHead, + Dims.NAudioLayer, Dtype); + + auto Decoder = + std::make_shared(Dims.NVocab, Dims.NTextCtx, Dims.NTextState, + Dims.NTextHead, Dims.NTextLayer, Dtype); + + registerModule("encoder", Encoder); + registerModule("decoder", Decoder); + registerParameter("alignment_heads", mx::array({})); + // // Initialize alignment heads (use last half of decoder layers by default) + // std::vector> AllHeads( + // Dims.NTextLayer, std::vector(Dims.NTextHead, false)); + // for (int I = Dims.NTextLayer / 2; I < Dims.NTextLayer; ++I) { + // std::fill(AllHeads[I].begin(), AllHeads[I].end(), true); + // } + + // // Find all True positions and create the alignment heads array + // // Equivalent to: self.alignment_heads = + // // mx.array(np.asarray(all_heads.nonzero()).T) + // std::vector> NonzeroIndices; + // for (int Layer = 0; Layer < Dims.NTextLayer; ++Layer) { + // for (int Head = 0; Head < Dims.NTextHead; ++Head) { + // if (AllHeads[Layer][Head]) { + // NonzeroIndices.push_back({Layer, Head}); + // } + // } + // } + + // // Convert to mx::array format [N, 2] where N is number of True positions + // if (!NonzeroIndices.empty()) { + // std::vector FlatIndices; + // FlatIndices.reserve(NonzeroIndices.size() * 2); + // for (const auto &Index : NonzeroIndices) { + // FlatIndices.push_back(Index[0]); // Layer index + // FlatIndices.push_back(Index[1]); // Head index + // } + // AlignmentHeads = + // mx::array(FlatIndices.data(), + // {static_cast(NonzeroIndices.size()), 2}, mx::int64); + // } else { + // // Create empty array with correct shape + // AlignmentHeads = mx::zeros({0, 2}, mx::int64); + // } +} + +mx::array Whisper::forward(const mx::array &Mel, const mx::array &Tokens) { + auto Encoder = std::dynamic_pointer_cast(Submodules["encoder"]); + auto Decoder = std::dynamic_pointer_cast(Submodules["decoder"]); + + mx::array AudioFeatures = Encoder->forward(Mel); + auto [Logits, _, __] = Decoder->forward(Tokens, AudioFeatures); + return Logits; +} + +mx::array Whisper::embedAudio(const mx::array &Mel) { + auto Encoder = std::dynamic_pointer_cast(Submodules["encoder"]); + return Encoder->forward(Mel); +} + +mx::array Whisper::logits(const mx::array &Tokens, + const mx::array &AudioFeatures) { + auto Decoder = std::dynamic_pointer_cast(Submodules["decoder"]); + auto [Logits, _, __] = Decoder->forward(Tokens, AudioFeatures); + return Logits; +} + +std::pair>> +Whisper::forwardWithCrossQk(const mx::array &Mel, const mx::array &Tokens) { + auto Encoder = std::dynamic_pointer_cast(Submodules["encoder"]); + auto Decoder = std::dynamic_pointer_cast(Submodules["decoder"]); + + mx::array AudioFeatures = Encoder->forward(Mel); + auto [Logits, _, CrossQk] = Decoder->forward(Tokens, AudioFeatures); + return {Logits, CrossQk}; +} + +bool Whisper::isMultilingual() const { return Dims.NVocab >= 51865; } + +int Whisper::numLanguages() const { + return Dims.NVocab - 51765 - static_cast(isMultilingual()); +} + +std::shared_ptr Whisper::fromPretrained(const std::string &ModelPath) { + std::filesystem::path Path(ModelPath); + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto Error = Parser.load((Path / "config.json").string()).get(Doc); + if (Error) { + spdlog::error("Could not open config.json at {}", Path.string()); + assumingUnreachable(); + } + auto Obj = Doc.get_object(); + ModelDimensions DefaultDims = ModelDimensions::fromDict(Obj.value()); + auto Model = std::make_shared(DefaultDims); + std::vector WeightFiles; + for (auto &P : std::filesystem::directory_iterator(Path)) { + if (P.path().extension() == ".safetensors") + WeightFiles.push_back(P.path()); + } + if (WeightFiles.empty()) { + spdlog::error("No safetensors found in {}.", Path.string()); + assumingUnreachable(); + } + std::unordered_map Weights; + for (auto &Wf : WeightFiles) { + auto W = mx::load_safetensors(Wf.string()); + Weights.insert(W.first.begin(), W.first.end()); + } + Model->update(Weights); + + return Model; +} + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/whisper/whisper.h b/plugins/wasi_nn/MLX/model/whisper/whisper.h new file mode 100644 index 00000000..7a4b0922 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper/whisper.h @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#pragma once + +#include "mlx/base.h" +#include "simdjson.h" +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { + +namespace nn = mlx::core::nn; + +namespace whisper { + +struct ModelDimensions { + int NMels = 80; + int NAudioCtx = 1500; + int NAudioState = 768; + int NAudioHead = 12; + int NAudioLayer = 12; + int NVocab = 51864; + int NTextCtx = 448; + int NTextState = 768; + int NTextHead = 12; + int NTextLayer = 12; + + static ModelDimensions fromDict(const simdjson::dom::object &Obj); +}; + +// Utility function for positional embeddings +mx::array sinusoids(int Length, int Channels, float MaxTimescale = 10000.0f); + +class MultiHeadAttention : public nn::Module { +public: + MultiHeadAttention(int NState, int NHead); + + std::tuple, mx::array> + forward(const mx::array &X, const std::optional &Xa = std::nullopt, + const std::optional &Mask = std::nullopt, + const std::optional> &KvCache = + std::nullopt); + +private: + std::pair + qkvAttention(const mx::array &Q, const mx::array &K, const mx::array &V, + const std::optional &Mask = std::nullopt); + + int NHead; +}; + +class ResidualAttentionBlock : public nn::Module { +public: + ResidualAttentionBlock(int NState, int NHead, bool CrossAttention = false); + + std::tuple>, + std::optional>>, + std::optional> + forward(const mx::array &X, const std::optional &Xa = std::nullopt, + const std::optional &Mask = std::nullopt, + const std::optional< + std::pair>, + std::optional>>> + &KvCache = std::nullopt); + +private: + bool HasCrossAttention; +}; + +class AudioEncoder : public nn::Module { +public: + AudioEncoder(int NMels, int NCtx, int NState, int NHead, int NLayer, + mx::Dtype Dtype = mx::float16); + + mx::array forward(const mx::array &X); + +private: + mx::array PositionalEmbedding = mx::array({}); + std::vector> Blocks; +}; + +class TextDecoder : public nn::Module { +public: + TextDecoder(int NVocab, int NCtx, int NState, int NHead, int NLayer, + mx::Dtype Dtype = mx::float16); + + std::tuple< + mx::array, + std::vector>, + std::optional>>>, + std::vector>> + forward(const mx::array &X, const mx::array &Xa, + const std::optional>, + std::optional>>>> + &KvCache = std::nullopt); + +private: + mx::array Mask = mx::array({}); + std::vector> Blocks; +}; + +class Whisper : public nn::Module { +public: + Whisper(const ModelDimensions &Dims, mx::Dtype Dtype = mx::float16); + + mx::array forward(const mx::array &Mel, const mx::array &Tokens); + + mx::array embedAudio(const mx::array &Mel); + mx::array logits(const mx::array &Tokens, const mx::array &AudioFeatures); + std::pair>> + forwardWithCrossQk(const mx::array &Mel, const mx::array &Tokens); + + bool isMultilingual() const; + int numLanguages() const; + + static std::shared_ptr fromPretrained(const std::string &ModelPath); + + ModelDimensions Dims; + mx::array AlignmentHeads = mx::array({}); +}; + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp b/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp new file mode 100644 index 00000000..652fd0bc --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp @@ -0,0 +1,819 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#include "whisper_transcribe.h" +#include "mlx/base.h" +#include "spdlog/spdlog.h" +#include "whisper/tokenizer.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace whisper { + +mx::array loadAudio(const std::string &FilePath, int SampleRate) { + int Channels = 1; + std::vector Cmd = {"ffmpeg", "-nostdin", + "-threads", "0", + "-i", FilePath, + "-f", "s16le", + "-ac", std::to_string(Channels), + "-acodec", "pcm_s16le", + "-ar", std::to_string(SampleRate), + "-v", "quiet", + "-"}; + + // Build command string + std::stringstream CmdStream; + for (size_t i = 0; i < Cmd.size(); ++i) { + if (i > 0) + CmdStream << " "; + CmdStream << Cmd[i]; + } + std::string CmdString = CmdStream.str(); + + // Execute ffmpeg command + FILE *Pipe = popen(CmdString.c_str(), "r"); + if (!Pipe) { + throw std::runtime_error("Failed to execute ffmpeg command"); + } + + // Read raw audio data + std::vector AudioData; + int16_t Buffer[4096]; + size_t BytesRead; + + while ((BytesRead = fread(Buffer, sizeof(int16_t), 4096, Pipe)) > 0) { + AudioData.insert(AudioData.end(), Buffer, Buffer + BytesRead); + } + + int ExitCode = pclose(Pipe); + if (ExitCode != 0) { + throw std::runtime_error("ffmpeg command failed with exit code: " + + std::to_string(ExitCode)); + } + + if (AudioData.empty()) { + throw std::runtime_error("No audio data loaded from file: " + FilePath); + } + + mx::array AudioArray = mx::array( + AudioData.data(), {static_cast(AudioData.size())}, mx::int16); + + if (Channels > 1) { + int Frames = AudioData.size() / Channels; + AudioArray = mx::reshape(AudioArray, {Frames, Channels}); + } + + mx::array FloatAudio = mx::astype(AudioArray, mx::float32) / 32768.0f; + + return FloatAudio; +} + +mx::array padOrTrim(const mx::array &Array, int Length, int Axis) { + auto Shape = Array.shape(); + int ActualAxis = Axis < 0 ? Shape.size() + Axis : Axis; + + if (Shape[ActualAxis] > Length) { + std::vector Start(Shape.size(), 0); + std::vector End = Shape; + std::vector Strides(Shape.size(), 1); + End[ActualAxis] = Length; + return mx::slice(Array, Start, End, Strides); + } + + if (Shape[ActualAxis] < Length) { + std::vector> PadWidths(Shape.size(), {0, 0}); + PadWidths[ActualAxis] = {0, Length - Shape[ActualAxis]}; + return mx::pad(Array, PadWidths); + } + + return Array; +} + +mx::array hanningWindow(int Size) { + mx::array N = mx::arange(Size + 1); + mx::array Window = 0.5f * (1.0f - mx::cos(2.0f * M_PI * N / Size)); + std::vector Start = {0}; + std::vector End = {-1}; + std::vector Strides = {1}; + return mx::slice(Window, Start, End, Strides); +} + +mx::array stft(const mx::array &X, const mx::array &Window, int NPerseg = 256, + int NOverlap = 0, int NfFt = 0, + const std::string &PadMode = "reflect") { + if (NfFt == 0) + NfFt = NPerseg; + + auto Pad = [](const mx::array &X, int Padding, const std::string &PadMode) { + if (PadMode == "constant") { + std::vector> PadWidths = {{Padding, Padding}}; + return mx::pad(X, PadWidths); + } + if (PadMode == "reflect") { + // x[1:padding+1][::-1] + std::vector Start(X.shape().size(), 0); + std::vector End = X.shape(); + std::vector Strides(X.shape().size(), 1); + + Start[0] = 1; + End[0] = Padding + 1; + auto Prefix = mx::slice(X, Start, End, Strides); + Start[0] = -(Padding + 1); + End[0] = -1; + auto Suffix = mx::slice(X, Start, End, Strides); + + return mx::concatenate({Prefix, X, Suffix}); + } + throw std::runtime_error("Invalid pad mode: " + PadMode); + }; + + // Pad the signal + int Padding = NPerseg / 2; + auto Padded = Pad(X, Padding, PadMode); + std::vector Strides = {static_cast(NOverlap), 1}; + int T = (Padded.shape(0) - NPerseg + NOverlap) / NOverlap; + auto Shape = std::vector{T, NfFt}; + auto StridedX = mx::as_strided(Padded, Shape, Strides, 0); + + return mx::fft::rfft(StridedX * Window); +} + +mx::array melFilters(int NMels) { + // Load precomputed mel filters from file + if (NMels != 80 && NMels != 128) { + spdlog::error("Unsupported number of mel filters: " + + std::to_string(NMels)); + assumingUnreachable(); + } + std::string FileName = "assets/mel_filters_" + std::to_string(NMels) + ".npy"; + return mx::load(FileName); +} + +mx::array logMelSpectrogram(const mx::array &Audio, int NMels, int Padding) { + auto PaddedAudio = Audio; + if (Padding > 0) { + std::vector> PadWidths = {{0, Padding}}; + PaddedAudio = mx::pad(Audio, PadWidths); + } + + auto Window = hanningWindow(DefaultNFft); + auto Freqs = stft(PaddedAudio, Window, DefaultNFft, DefaultHopLength, + DefaultNFft, "reflect"); + // freqs[:-1, :].abs().square() + std::vector Start = {0, 0}; + std::vector End = {static_cast(Freqs.shape(0) - 1), + static_cast(Freqs.shape(1))}; + std::vector Strides = {1, 1}; + auto FreqsSliced = mx::slice(Freqs, Start, End, Strides); + auto Magnitudes = mx::square(mx::abs(FreqsSliced)); + // Apply mel filters + auto Filters = melFilters(NMels); + auto MelSpec = mx::matmul(Magnitudes, mx::transpose(Filters)); + auto LogSpec = + mx::log10(mx::maximum(MelSpec, mx::full(MelSpec.shape(), 1e-10f))); + LogSpec = mx::maximum(LogSpec, mx::max(LogSpec) - 8.0f); + LogSpec = (LogSpec + 4.0f) / 4.0f; + return LogSpec; +} + +// Utility functions +std::string formatTimestamp(float Seconds) { + assert(Seconds >= 0); + int Milliseconds = static_cast(std::round(Seconds * 1000.0f)); + + int Hours = Milliseconds / 3600000; + Milliseconds -= Hours * 3600000; + + int Minutes = Milliseconds / 60000; + Milliseconds -= Minutes * 60000; + + int Secs = Milliseconds / 1000; + Milliseconds -= Secs * 1000; + + std::stringstream Ss; + if (Hours > 0) { + Ss << std::setfill('0') << std::setw(2) << Hours << ":"; + } + Ss << std::setfill('0') << std::setw(2) << Minutes << ":" << std::setfill('0') + << std::setw(2) << Secs << "." << std::setfill('0') << std::setw(3) + << Milliseconds; + + return Ss.str(); +} + +std::optional getEnd(const std::vector &Segments) { + for (auto It = Segments.rbegin(); It != Segments.rend(); ++It) { + for (auto WordIt = It->Words.rbegin(); WordIt != It->Words.rend(); + ++WordIt) { + return WordIt->End; + } + if (!It->Words.empty()) { + return It->End; + } + } + return Segments.empty() ? std::nullopt + : std::make_optional(Segments.back().End); +} + +// Decoding functions +DecodingResult +decodeWithFallback(std::shared_ptr Model, + const mx::array &MelSegment, + const DecodingOptions &DecodeOptions, + const std::vector &Temperatures, + std::unique_ptr &Tokenizer, + std::optional CompressionRatioThreshold, + std::optional LogprobThreshold, + std::optional NoSpeechThreshold) { + + DecodingResult Result; + Result.AudioFeatures = MelSegment; + Result.Language = ""; + Result.Tokens = std::vector(); + Result.Text = ""; + + for (float Temp : Temperatures) { + // Simple greedy decoding for temperature = 0, sampling otherwise + DecodingOptions Options = DecodeOptions; + Options.Temperature = Temp; + + // This is a simplified decode implementation + // In practice, you'd implement the full beam search/sampling logic + std::vector Tokens; + + // Start with SOT sequence + auto SotSequence = Tokenizer->SotSequence; + Tokens.insert(Tokens.end(), SotSequence.begin(), SotSequence.end()); + + // Generate tokens + mx::array CurrentTokens = mx::array( + Tokens.data(), {1, static_cast(Tokens.size())}, mx::int32); + Result = std::get(decode(Model, MelSegment, Options)); + + // Check fallback conditions + bool NeedsFallback = false; + if (CompressionRatioThreshold && + Result.CompressionRatio > *CompressionRatioThreshold) { + NeedsFallback = true; + } + if (LogprobThreshold && Result.AvgLogprob < *LogprobThreshold) { + NeedsFallback = true; + } + if (NoSpeechThreshold && Result.NoSpeechProb > *NoSpeechThreshold) { + NeedsFallback = false; + } + + if (!NeedsFallback) { + break; + } + } + + return Result; +} + +// Word-level timestamp functions +void addWordTimestamps(std::vector &Segments, + std::shared_ptr Model, + std::unique_ptr &Tokenizer, + const mx::array &Mel, int NumFrames, + const std::string &PrependPunctuations, + const std::string &AppendPunctuations, + float LastSpeechTimestamp) { + + // This is a simplified implementation + // Full implementation would use cross-attention patterns and DTW + for (auto &Segment : Segments) { + if (!Segment.Tokens.empty()) { + float Duration = Segment.End - Segment.Start; + float TimePerToken = Duration / Segment.Tokens.size(); + + // Simple uniform distribution of word timestamps + std::string Text = Segment.Text; + std::istringstream Iss(Text); + std::string Word; + int WordIdx = 0; + + while (Iss >> Word && + static_cast(WordIdx) < Segment.Tokens.size()) { + WordInfo WordInfo; + WordInfo.Word = Word; + WordInfo.Start = Segment.Start + WordIdx * TimePerToken; + WordInfo.End = Segment.Start + (WordIdx + 1) * TimePerToken; + WordInfo.Probability = 0.8f; // Default probability + + Segment.Words.push_back(WordInfo); + WordIdx++; + } + } + } +} + +// Anomaly detection functions +float wordAnomalyScore(const WordInfo &Word) { + float Score = 0.0f; + + if (Word.Probability < 0.15f) { + Score += 1.0f; + } + + float Duration = Word.End - Word.Start; + if (Duration < 0.133f) { + Score += (0.133f - Duration) * 15.0f; + } + if (Duration > 2.0f) { + Score += Duration - 2.0f; + } + + return Score; +} + +bool isSegmentAnomaly(const std::optional &Segment) { + if (!Segment || Segment->Words.empty()) { + return false; + } + + // Filter out punctuation words + std::vector FilteredWords; + std::string Punctuation = "\"'?([{-\"'.,!?:\")]},"; + + for (const auto &Word : Segment->Words) { + if (Punctuation.find(Word.Word) == std::string::npos) { + FilteredWords.push_back(Word); + } + } + + if (FilteredWords.size() > 8) { + FilteredWords.resize(8); + } + + float TotalScore = 0.0f; + for (const auto &Word : FilteredWords) { + TotalScore += wordAnomalyScore(Word); + } + + return TotalScore >= 3.0f || TotalScore + 0.01f >= FilteredWords.size(); +} + +std::optional +nextWordsSegment(const std::vector &Segments) { + for (const auto &Segment : Segments) { + if (!Segment.Words.empty()) { + return Segment; + } + } + return std::nullopt; +} + +// Main transcribe function +TranscribeResult +transcribe(const std::variant &Audio, + std::shared_ptr Model, std::optional Verbose, + std::variant> Temperature, + std::optional CompressionRatioThreshold, + std::optional LogprobThreshold, + std::optional NoSpeechThreshold, bool ConditionOnPreviousText, + std::optional InitialPrompt, bool WordTimestamps, + const std::string &PrependPunctuations, + const std::string &AppendPunctuations, + std::variant> ClipTimestamps, + std::optional HallucinationSilenceThreshold, + const DecodingOptions &DecodeOptions) { + + // Get audio array + mx::array AudioArray = + std::holds_alternative(Audio) + ? loadAudio(std::get(Audio), DefaultSampleRate) + : std::get(Audio); + + // Get dtype + mx::Dtype Dtype = DecodeOptions.Fp16 ? mx::float16 : mx::float32; + + // Generate mel spectrogram + auto Mel = logMelSpectrogram(AudioArray, Model->Dims.NMels, DefaultNSamples); + int ContentFrames = Mel.shape(-2) - DefaultNFrames; + float ContentDuration = + static_cast(ContentFrames * DefaultHopLength) / DefaultSampleRate; + + // Language detection + DecodingOptions Options = DecodeOptions; + if (!Options.Language) { + if (!Model->isMultilingual()) { + Options.Language = "en"; + } else { + if (Verbose.value_or(false)) { + std::cout << "Detecting language using up to the first 30 seconds.\n"; + } + + auto MelSegment = padOrTrim(Mel, DefaultNFrames, -2); + MelSegment = mx::astype(MelSegment, Dtype); + auto [LangTokens, LangProbs] = detectLanguage(Model, MelSegment); + + // Find most probable language from the detection results + if (!LangProbs.empty() && !LangProbs[0].empty()) { + // Find language with highest probability + auto MaxIterator = std::max_element( + LangProbs[0].begin(), LangProbs[0].end(), + [](const auto &A, const auto &B) { return A.second < B.second; }); + Options.Language = MaxIterator->first; + } else { + Options.Language = "en"; // Default fallback only if detection failed + } + + if (Verbose.value_or(false)) { + std::string LangName = findLanguageByCode(*Options.Language); + if (LangName.empty()) { + LangName = *Options.Language; // fallback to code if not found + } + std::cout << "Detected language: " << LangName << std::endl; + } + } + } + + auto Task = Options.Task; + auto Tokenizer = getTokenizer(Model->isMultilingual(), Model->numLanguages(), + *Options.Language, Task); + + // Parse clip timestamps + std::vector ClipTimes; + if (std::holds_alternative(ClipTimestamps)) { + std::string ClipStr = std::get(ClipTimestamps); + if (!ClipStr.empty()) { + std::stringstream Ss(ClipStr); + std::string Item; + while (std::getline(Ss, Item, ',')) { + ClipTimes.push_back(std::stof(Item)); + } + } + } else { + ClipTimes = std::get>(ClipTimestamps); + } + + // Set up seek points + std::vector SeekPoints; + for (float Ts : ClipTimes) { + SeekPoints.push_back( + static_cast(std::round(Ts * DefaultFramesPerSecond))); + } + if (SeekPoints.empty()) { + SeekPoints.push_back(0); + } + if (SeekPoints.size() % 2 == 1) { + SeekPoints.push_back(ContentFrames); + } else { + SeekPoints.back() = std::min(ContentFrames, SeekPoints.back()); + } + + // Create seek clips + std::vector> SeekClips; + for (size_t I = 0; I < SeekPoints.size(); I += 2) { + SeekClips.emplace_back(SeekPoints[I], SeekPoints[I + 1]); + } + + int Seek = SeekClips[0].first; + // time_precision + + std::vector AllTokens; + std::vector AllSegments; + // prompt_reset_since + int PromptResetSince = 0; + + if (InitialPrompt) { + auto PromptTokens = Tokenizer->encode(" " + *InitialPrompt); + AllTokens.insert(AllTokens.end(), PromptTokens.begin(), PromptTokens.end()); + } + + std::vector Temperatures; + if (std::holds_alternative(Temperature)) { + Temperatures = {std::get(Temperature)}; + } else { + Temperatures = std::get>(Temperature); + } + + const int InputStride = DefaultNFrames / Model->Dims.NAudioCtx; // 2 for tiny + const float TimePrecision = + static_cast(InputStride * DefaultHopLength) / DefaultSampleRate; + const int FramesPerSecond = DefaultFramesPerSecond; + + // Processing loop + float LastSpeechTimestamp = 0.0f; + for (const auto &[SeekClipStart, SeekClipEnd] : SeekClips) { + while (Seek < SeekClipEnd) { + float TimeOffset = + static_cast(Seek * DefaultHopLength) / DefaultSampleRate; + float WindowEndTime = + static_cast((Seek + DefaultNFrames) * DefaultHopLength) / + DefaultSampleRate; + int SegmentSize = + std::min({DefaultNFrames, ContentFrames - Seek, SeekClipEnd - Seek}); + + // Extract mel segment + std::vector Start(Mel.shape().size(), 0); + std::vector End = Mel.shape(); + Start[0] = Seek; + End[0] = Seek + SegmentSize; + auto MelSegment = mx::slice(Mel, Start, End); + MelSegment = padOrTrim(MelSegment, DefaultNFrames, -2); + MelSegment = mx::astype(MelSegment, Dtype); + + // Decode segment + // Provide prompt tokens since last reset + { + std::vector PromptSlice; + if (PromptResetSince >= 0 && + PromptResetSince <= static_cast(AllTokens.size())) { + PromptSlice.assign(AllTokens.begin() + PromptResetSince, + AllTokens.end()); + } + Options.Prompt = PromptSlice; // empty slice allowed (adds sot_prev) + } + + auto Result = decodeWithFallback(Model, MelSegment, Options, Temperatures, + Tokenizer, CompressionRatioThreshold, + LogprobThreshold, NoSpeechThreshold); + + // Voice activity check and fast-forward if silence + if (NoSpeechThreshold) { + bool ShouldSkip = Result.NoSpeechProb > *NoSpeechThreshold; + if (LogprobThreshold && Result.AvgLogprob > *LogprobThreshold) { + ShouldSkip = false; + } + if (ShouldSkip) { + Seek += SegmentSize; + continue; + } + } + + // Prepare tokens for timestamp analysis + const auto &TokVec = Result.Tokens; + std::vector TimestampMask(TokVec.size(), false); + int TsBegin = Tokenizer->getTimestampBegin(); + for (size_t I = 0; I < TokVec.size(); ++I) { + TimestampMask[I] = TokVec[I] >= TsBegin; + } + + bool SingleTimestampEnding = false; + if (TimestampMask.size() >= 2) { + SingleTimestampEnding = + (!TimestampMask[TimestampMask.size() - 2] && TimestampMask.back()); + } + + // Find consecutive timestamp pairs + std::vector ConsecutiveIdx; + for (int I = 1; I < static_cast(TimestampMask.size()); ++I) { + if (TimestampMask[I - 1] && TimestampMask[I]) { + ConsecutiveIdx.push_back(I); + } + } + + std::vector CurrentSegments; + auto NewSegmentFromSlice = [&](int LastSliceIdx, int CurrentSliceIdx, + float StartBase, float EndBase) { + std::vector SliceTokens; + SliceTokens.insert(SliceTokens.end(), TokVec.begin() + LastSliceIdx, + TokVec.begin() + CurrentSliceIdx); + if (SliceTokens.empty()) + return; // nothing to add + int StartTsPos = SliceTokens.front() - TsBegin; + int EndTsPos = SliceTokens.back() - TsBegin; + TranscribeSegment S; + S.Id = static_cast(AllSegments.size() + CurrentSegments.size()); + S.Seek = Seek; + S.Start = StartBase + StartTsPos * TimePrecision; + S.End = StartBase + EndTsPos * TimePrecision; + // text only from tokens < eot + std::vector TextTokens; + int Eot = Tokenizer->getEot(); + for (int Tok : SliceTokens) { + if (Tok < Eot) + TextTokens.push_back(Tok); + } + S.Text = Tokenizer->decode(TextTokens); + S.Tokens.assign(SliceTokens.begin(), SliceTokens.end()); + S.Temperature = Result.Temperature; + S.AvgLogprob = Result.AvgLogprob; + S.CompressionRatio = Result.CompressionRatio; + S.NoSpeechProb = Result.NoSpeechProb; + CurrentSegments.push_back(std::move(S)); + }; + + if (!ConsecutiveIdx.empty()) { + std::vector Slices = ConsecutiveIdx; + if (SingleTimestampEnding) { + Slices.push_back(static_cast(TokVec.size())); + } + + int LastSlice = 0; + for (int Cur : Slices) { + NewSegmentFromSlice(LastSlice, Cur, TimeOffset, TimeOffset); + LastSlice = Cur; + } + + if (SingleTimestampEnding) { + // no speech after the last timestamp + Seek += SegmentSize; + } else { + int LastTimestampPos = TokVec[Slices.back() - 1] - TsBegin; + Seek += LastTimestampPos * InputStride; + } + } else { + // No consecutive timestamp tokens + float SegmentDuration = + static_cast(SegmentSize * DefaultHopLength) / + DefaultSampleRate; + int LastTsIndex = -1; + for (int I = static_cast(TokVec.size()) - 1; I >= 0; --I) { + if (TimestampMask[I]) { + LastTsIndex = I; + break; + } + } + if (LastTsIndex != -1 && TokVec[LastTsIndex] != TsBegin) { + int LastTsPos = TokVec[LastTsIndex] - TsBegin; + SegmentDuration = LastTsPos * TimePrecision; + } + + // Create one segment for the whole window + TranscribeSegment S; + S.Id = static_cast(AllSegments.size()); + S.Seek = Seek; + S.Start = TimeOffset; + S.End = TimeOffset + SegmentDuration; + // text from tokens < eot + std::vector TextTokens; + int Eot = Tokenizer->getEot(); + for (int Tok : TokVec) { + if (Tok < Eot) + TextTokens.push_back(Tok); + } + S.Text = Tokenizer->decode(TextTokens); + S.Tokens = TokVec; + S.Temperature = Result.Temperature; + S.AvgLogprob = Result.AvgLogprob; + S.CompressionRatio = Result.CompressionRatio; + S.NoSpeechProb = Result.NoSpeechProb; + CurrentSegments.push_back(std::move(S)); + + Seek += SegmentSize; + } + + // Word-level timestamps and hallucination handling + if (WordTimestamps) { + addWordTimestamps(CurrentSegments, Model, Tokenizer, MelSegment, + SegmentSize, PrependPunctuations, AppendPunctuations, + LastSpeechTimestamp); + + if (!SingleTimestampEnding) { + auto LastWordEndOpt = getEnd(CurrentSegments); + if (LastWordEndOpt && *LastWordEndOpt > TimeOffset) { + Seek = static_cast( + std::round((*LastWordEndOpt) * FramesPerSecond)); + } + } + + if (HallucinationSilenceThreshold) { + float Threshold = *HallucinationSilenceThreshold; + if (!SingleTimestampEnding) { + auto LastWordEndOpt = getEnd(CurrentSegments); + if (LastWordEndOpt && *LastWordEndOpt > TimeOffset) { + float Remaining = WindowEndTime - *LastWordEndOpt; + if (Remaining > Threshold) { + Seek = static_cast( + std::round((*LastWordEndOpt) * FramesPerSecond)); + } else { + // keep default seek progression + } + } + } + + // if first segment might be a hallucination, skip leading silence + auto FirstWithWords = nextWordsSegment(CurrentSegments); + if (FirstWithWords && isSegmentAnomaly(FirstWithWords)) { + float Gap = FirstWithWords->Start - TimeOffset; + if (Gap > Threshold) { + int NewSeek = static_cast(std::round( + static_cast(Seek) + Gap * FramesPerSecond)); + Seek = NewSeek; + // restart loop + continue; + } + } + + // skip silence before any possible hallucination surrounded by + // silence + float HalLastEnd = LastSpeechTimestamp; + for (size_t Si = 0; Si < CurrentSegments.size(); ++Si) { + auto &Seg = CurrentSegments[Si]; + if (Seg.Words.empty()) + continue; + std::optional SegOpt = Seg; + if (isSegmentAnomaly(SegOpt)) { + std::optional NextSeg = std::nullopt; + for (size_t J = Si + 1; J < CurrentSegments.size(); ++J) { + if (!CurrentSegments[J].Words.empty()) { + NextSeg = CurrentSegments[J]; + break; + } + } + float HalNextStart = + NextSeg ? NextSeg->Words.front().Start + : TimeOffset + static_cast(SegmentSize * + DefaultHopLength) / + DefaultSampleRate; + bool SilenceBefore = (Seg.Start - HalLastEnd > Threshold) || + (Seg.Start < Threshold) || + (Seg.Start - TimeOffset < 2.0f); + bool SilenceAfter = (HalNextStart - Seg.End > Threshold) || + isSegmentAnomaly(NextSeg) || + (WindowEndTime - Seg.End < 2.0f); + if (SilenceBefore && SilenceAfter) { + Seek = static_cast(std::round( + std::max(TimeOffset + 1.0f, Seg.Start) * FramesPerSecond)); + float RemainingContent = ContentDuration - Seg.End; + if (RemainingContent < Threshold) { + Seek = ContentFrames; + } + // drop subsequent segments + CurrentSegments.resize(Si); + break; + } + } + HalLastEnd = Seg.End; + } + } + + // update last_speech_timestamp + auto LastWordEndOpt = getEnd(CurrentSegments); + if (LastWordEndOpt) { + LastSpeechTimestamp = *LastWordEndOpt; + } + } + + // Verbose output for each segment + if (Verbose.value_or(false)) { + for (const auto &S : CurrentSegments) { + std::cout << "[" << formatTimestamp(S.Start) << " --> " + << formatTimestamp(S.End) << "] " << S.Text << std::endl; + } + } + + // Clean instantaneous or empty segments + for (auto &S : CurrentSegments) { + auto Trimmed = S.Text; + // trim whitespace + Trimmed.erase(0, Trimmed.find_first_not_of(" \t\n\r\f\v")); + if (!Trimmed.empty()) + Trimmed.erase(Trimmed.find_last_not_of(" \t\n\r\f\v") + 1); + if (S.Start == S.End || Trimmed.empty()) { + S.Text.clear(); + S.Tokens.clear(); + S.Words.clear(); + } + } + + // Add to results and accumulate tokens + for (auto &S : CurrentSegments) { + S.Id = static_cast(AllSegments.size()); + AllSegments.push_back(S); + AllTokens.insert(AllTokens.end(), S.Tokens.begin(), S.Tokens.end()); + } + + // Prompt reset logic + if (!ConditionOnPreviousText || Result.Temperature > 0.5f) { + PromptResetSince = static_cast(AllTokens.size()); + } + } + } + + // Prepare final result + TranscribeResult FinalResult; + FinalResult.Segments = AllSegments; + FinalResult.Language = *Options.Language; + + // Concatenate all text + std::stringstream TextStream; + for (const auto &Segment : AllSegments) { + TextStream << Segment.Text; + } + FinalResult.Text = TextStream.str(); + + return FinalResult; +} + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file diff --git a/plugins/wasi_nn/MLX/model/whisper_transcribe.h b/plugins/wasi_nn/MLX/model/whisper_transcribe.h new file mode 100644 index 00000000..5fdcd374 --- /dev/null +++ b/plugins/wasi_nn/MLX/model/whisper_transcribe.h @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2025 Second State INC + +#pragma once + +#include "whisper/decoding.h" +#include "whisper/tokenizer.h" +#include "whisper/whisper.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace WasmEdge::Host::WASINN::MLX { +namespace whisper { + +// Audio processing constants +constexpr int DefaultSampleRate = 16000; +constexpr int DefaultNFft = 400; +constexpr int DefaultHopLength = 160; +constexpr int DefaultChunkLength = 30; +constexpr int DefaultNSamples = + DefaultChunkLength * DefaultSampleRate; // 480000 samples +constexpr int DefaultNFrames = + DefaultNSamples / DefaultHopLength; // 3000 frames +constexpr int DefaultFramesPerSecond = DefaultSampleRate / DefaultHopLength; +constexpr int DefaultNSamplesPerToken = DefaultHopLength * 2; + +extern const std::vector> LANGUAGES; + +// Word information for word-level timestamps +struct WordInfo { + float Start; + float End; + std::string Word; + float Probability; +}; + +// Segment information +struct TranscribeSegment { + int Id; + int Seek; + float Start; + float End; + std::string Text; + std::vector Tokens; + float Temperature; + float AvgLogprob; + float CompressionRatio; + float NoSpeechProb; + std::vector Words; +}; + +// Complete transcribe result +struct TranscribeResult { + std::string Text; + std::vector Segments; + std::string Language; +}; + +// Audio processing functions +mx::array loadAudio(const std::string &FilePath, + int SampleRate = DefaultSampleRate); +mx::array padOrTrim(const mx::array &Array, int Length = DefaultNSamples, + int Axis = -1); +mx::array logMelSpectrogram(const mx::array &Audio, int NMels = 80, + int Padding = 0); + +// Utility functions +std::string formatTimestamp(float Seconds); +std::optional getEnd(const std::vector &Segments); + +// Decoding functions +DecodingResult +decodeWithFallback(std::shared_ptr Model, + const mx::array &MelSegment, + const DecodingOptions &DecodeOptions, + const std::vector &Temperatures, + std::unique_ptr &Tokenizer, + std::optional CompressionRatioThreshold, + std::optional LogprobThreshold, + std::optional NoSpeechThreshold); + +// Word-level timestamp functions +void addWordTimestamps(std::vector &Segments, + std::shared_ptr Model, + std::unique_ptr &Tokenizer, + const mx::array &Mel, int NumFrames, + const std::string &PrependPunctuations, + const std::string &AppendPunctuations, + float LastSpeechTimestamp = 0.0f); + +// Anomaly detection functions +float wordAnomalyScore(const WordInfo &Word); +bool isSegmentAnomaly(const std::optional &Segment); +std::optional +nextWordsSegment(const std::vector &Segments); + +// Main transcribe function +TranscribeResult +transcribe(const std::variant &Audio, + std::shared_ptr Model, + std::optional Verbose = std::nullopt, + std::variant> Temperature = + std::vector{0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f}, + std::optional CompressionRatioThreshold = 2.4f, + std::optional LogprobThreshold = -1.0f, + std::optional NoSpeechThreshold = 0.6f, + bool ConditionOnPreviousText = true, + std::optional InitialPrompt = std::nullopt, + bool WordTimestamps = false, + const std::string &PrependPunctuations = + "\"'“¿([{-\"'.。,,!!??::”)]}、", + const std::string &AppendPunctuations = "\"'.,!?:\")]},", + std::variant> ClipTimestamps = "0", + std::optional HallucinationSilenceThreshold = std::nullopt, + const DecodingOptions &DecodeOptions = DecodingOptions()); + +} // namespace whisper +} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file diff --git a/plugins/wasi_nn/wasinn_mlx.cpp b/plugins/wasi_nn/wasinn_mlx.cpp index 637f49a8..c6cd47a8 100644 --- a/plugins/wasi_nn/wasinn_mlx.cpp +++ b/plugins/wasi_nn/wasinn_mlx.cpp @@ -12,6 +12,8 @@ #include "MLX/model/llm/registry.h" #include "MLX/model/llm/transformer.h" #include "MLX/model/utils.h" +#include "MLX/model/whisper/whisper.h" +#include "MLX/model/whisper_transcribe.h" #include "MLX/prompt/prompt.h" #include #include @@ -21,20 +23,6 @@ namespace WasmEdge::Host::WASINN::MLX { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX -std::string loadBytesFromFile(const std::string &Path) { - std::ifstream Fs(Path, std::ios::in | std::ios::binary); - if (Fs.fail()) { - spdlog::error("[WASI-NN] MLX backend: Cannot open {}."sv, Path); - return ""; - } - std::string Data; - Fs.seekg(0, std::ios::end); - const size_t Size = static_cast(Fs.tellg()); - Fs.seekg(0, std::ios::beg); - Data.resize(Size); - Fs.read(Data.data(), Size); - return Data; -} mx::array fromBytes(const Span &Bytes) { if (Bytes.size() < 9) { @@ -378,6 +366,9 @@ Expect load(WASINN::WasiNNEnvironment &Env, } else if (GraphRef.ModelType == "gemma3") { auto Weight = mx::load_safetensors(ModelFilePath); Weights.insert(Weight.first.begin(), Weight.first.end()); + } else if (GraphRef.ModelType == "whisper") { + auto Weight = mx::load_safetensors(ModelFilePath); + Weights.insert(Weight.first.begin(), Weight.first.end()); } else { spdlog::error("[WASI-NN] MLX backend: Model type {} not supported."sv, GraphRef.ModelType); @@ -419,10 +410,16 @@ Expect load(WASINN::WasiNNEnvironment &Env, Weights = gemma3::VisionModel(ModelConfigObj.VisionConfig).sanitize(Weights); GraphRef.ModelArch = "vlm"; + } else if (GraphRef.ModelType == "whisper") { + auto Obj = Doc.get_object(); + whisper::ModelDimensions DefaultDims = + whisper::ModelDimensions::fromDict(Obj.value()); + GraphRef.Model = std::dynamic_pointer_cast( + std::make_shared(DefaultDims)); + GraphRef.ModelArch = "whisper"; } else { - spdlog::error( - "[WASI-NN] MLX backend: Model architecture {} not supported."sv, - GraphRef.ModelArch); + spdlog::error("[WASI-NN] MLX backend: Model type {} not supported."sv, + GraphRef.ModelType); Env.deleteGraph(GId); return ErrNo::InvalidArgument; } @@ -459,6 +456,8 @@ Expect load(WASINN::WasiNNEnvironment &Env, GraphRef.Model->update(Weights); } else if (GraphRef.ModelType == "gemma3") { GraphRef.Model->update(Weights); + } else if (GraphRef.ModelType == "whisper") { + GraphRef.Model->update(Weights); } else { spdlog::error("[WASI-NN] MLX backend: Model type {} not supported."sv, GraphRef.ModelType); @@ -488,6 +487,8 @@ Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, CxtRef.Inputs = LLMInput(); } else if (GraphRef.ModelArch == "vlm") { CxtRef.Inputs = VLMInput(); + } else if (GraphRef.ModelArch == "whisper") { + CxtRef.Inputs = WhisperInput(); } else { spdlog::error( "[WASI-NN] MLX backend: Model architecture {} not supported."sv, @@ -522,6 +523,10 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, spdlog::error("[WASI-NN] MLX backend: Index out of range."sv); return ErrNo::InvalidArgument; } + } else if (GraphRef.ModelArch == "whisper") { + std::get(CxtRef.Inputs).Audio = + std::string(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); } else { spdlog::error( "[WASI-NN] MLX backend: Model architecture {} not supported."sv, @@ -559,6 +564,17 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, spdlog::error("[WASI-NN] MLX backend: No output found."sv); return ErrNo::InvalidArgument; } + } else if (GraphRef.ModelArch == "whisper") { + auto *Output = std::get_if(&CxtRef.Outputs); + if (Output != nullptr) { + std::string Text = Output->Text; + std::copy_n(Text.data(), Text.length(), OutBuffer.data()); + BytesWritten = Text.length(); + } else { + spdlog::error("[WASI-NN] MLX backend: No output found."sv); + return ErrNo::InvalidArgument; + } + } else { spdlog::error( "[WASI-NN] MLX backend: Model architecture {} not supported."sv, @@ -603,6 +619,10 @@ Expect compute(WasiNNEnvironment &Env, mx::array(TokenList.data(), {static_cast(TokenList.size())}); CxtRef.Outputs = VLMOutput({TokenArray}); TokenListSize = TokenList.size(); + } else if (GraphRef.ModelArch == "whisper") { + CxtRef.Outputs = whisper::transcribe( + std::get(CxtRef.Inputs).Audio, + std::dynamic_pointer_cast(GraphRef.Model), false); } else { spdlog::error( "[WASI-NN] MLX backend: Model architecture {} not supported."sv, diff --git a/plugins/wasi_nn/wasinn_mlx.h b/plugins/wasi_nn/wasinn_mlx.h index 89aa9212..ab519bbf 100644 --- a/plugins/wasi_nn/wasinn_mlx.h +++ b/plugins/wasi_nn/wasinn_mlx.h @@ -14,6 +14,7 @@ #include "MLX/mlx/transformer.h" #include "MLX/model/llm/transformer.h" #include "MLX/prompt/prompt.h" +#include #include #include @@ -39,6 +40,9 @@ struct VLMInput { struct VLMOutput { mx::array Answer = mx::array({}); }; +struct WhisperInput { + std::string Audio; +}; struct Graph { std::string ModelType; std::string ModelArch; @@ -55,8 +59,8 @@ struct Graph { struct Context { Context(uint32_t Gid, Graph &) noexcept : GraphId(Gid) {} uint32_t GraphId; - std::variant Inputs; - std::variant Outputs; + std::variant Inputs; + std::variant Outputs; }; #else struct Graph {}; From fe36e028d3dab7b05091fd274231a49fd79eadb6 Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 12 Sep 2025 17:08:20 +0800 Subject: [PATCH 582/623] chore(lint): remove trailing spaces and add missing newline at end of file Signed-off-by: hydai --- plugins/wasi_nn/MLX/model/whisper/decoding.cpp | 2 +- plugins/wasi_nn/MLX/model/whisper/decoding.h | 2 +- plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp | 2 +- plugins/wasi_nn/MLX/model/whisper/whisper.cpp | 2 +- plugins/wasi_nn/MLX/model/whisper/whisper.h | 2 +- plugins/wasi_nn/MLX/model/whisper_transcribe.cpp | 2 +- plugins/wasi_nn/MLX/model/whisper_transcribe.h | 2 +- plugins/wasi_nn/wasinn_bitnet.cpp | 2 +- plugins/wasi_nn/wasinn_bitnet.h | 2 +- test/plugins/wasi_nn/wasi_nn.cpp | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/plugins/wasi_nn/MLX/model/whisper/decoding.cpp b/plugins/wasi_nn/MLX/model/whisper/decoding.cpp index 4ad093e5..e0a93ff4 100644 --- a/plugins/wasi_nn/MLX/model/whisper/decoding.cpp +++ b/plugins/wasi_nn/MLX/model/whisper/decoding.cpp @@ -880,4 +880,4 @@ decode(std::shared_ptr Model, const mx::array &Mel, } } // namespace whisper -} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper/decoding.h b/plugins/wasi_nn/MLX/model/whisper/decoding.h index aa563253..6e8b55cd 100644 --- a/plugins/wasi_nn/MLX/model/whisper/decoding.h +++ b/plugins/wasi_nn/MLX/model/whisper/decoding.h @@ -193,4 +193,4 @@ decode(std::shared_ptr Model, const mx::array &Mel, const DecodingOptions &Options = DecodingOptions()); } // namespace whisper -} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp b/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp index 8eb68df3..42d76c6a 100644 --- a/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp +++ b/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp @@ -780,4 +780,4 @@ createWhisperTokenizer(const std::optional &Language, } } // namespace whisper -} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper/whisper.cpp b/plugins/wasi_nn/MLX/model/whisper/whisper.cpp index 2e42f09d..ccc79190 100644 --- a/plugins/wasi_nn/MLX/model/whisper/whisper.cpp +++ b/plugins/wasi_nn/MLX/model/whisper/whisper.cpp @@ -458,4 +458,4 @@ std::shared_ptr Whisper::fromPretrained(const std::string &ModelPath) { } } // namespace whisper -} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper/whisper.h b/plugins/wasi_nn/MLX/model/whisper/whisper.h index 7a4b0922..2587a37f 100644 --- a/plugins/wasi_nn/MLX/model/whisper/whisper.h +++ b/plugins/wasi_nn/MLX/model/whisper/whisper.h @@ -127,4 +127,4 @@ class Whisper : public nn::Module { }; } // namespace whisper -} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp b/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp index 652fd0bc..fa9a0660 100644 --- a/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp +++ b/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp @@ -816,4 +816,4 @@ transcribe(const std::variant &Audio, } } // namespace whisper -} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/model/whisper_transcribe.h b/plugins/wasi_nn/MLX/model/whisper_transcribe.h index 5fdcd374..9948f606 100644 --- a/plugins/wasi_nn/MLX/model/whisper_transcribe.h +++ b/plugins/wasi_nn/MLX/model/whisper_transcribe.h @@ -121,4 +121,4 @@ transcribe(const std::variant &Audio, const DecodingOptions &DecodeOptions = DecodingOptions()); } // namespace whisper -} // namespace WasmEdge::Host::WASINN::MLX \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/wasinn_bitnet.cpp b/plugins/wasi_nn/wasinn_bitnet.cpp index 74530fc2..867fbea0 100644 --- a/plugins/wasi_nn/wasinn_bitnet.cpp +++ b/plugins/wasi_nn/wasinn_bitnet.cpp @@ -2430,4 +2430,4 @@ Expect finalizeExecCtx(WasiNNEnvironment &, uint32_t) noexcept { } #endif -} // namespace WasmEdge::Host::WASINN::BitNet \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::BitNet diff --git a/plugins/wasi_nn/wasinn_bitnet.h b/plugins/wasi_nn/wasinn_bitnet.h index 3361c40d..6fba1f39 100644 --- a/plugins/wasi_nn/wasinn_bitnet.h +++ b/plugins/wasi_nn/wasinn_bitnet.h @@ -137,4 +137,4 @@ Expect finalizeExecCtx(WASINN::WasiNNEnvironment &Env, Expect unload(WASINN::WasiNNEnvironment &Env, uint32_t GraphId) noexcept; -} // namespace WasmEdge::Host::WASINN::BitNet \ No newline at end of file +} // namespace WasmEdge::Host::WASINN::BitNet diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index a24c0f35..6e3b00c8 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -3463,4 +3463,4 @@ TEST(WasiNNTest, BitNetBackend) { static_cast(ErrNo::InvalidArgument)); } } -#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET \ No newline at end of file +#endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_BITNET From 85a3f30445a08e6a8dc93d6c64540a0d10863841 Mon Sep 17 00:00:00 2001 From: "Shen-Ta Hsieh(BestSteve)" Date: Tue, 16 Sep 2025 14:37:00 +0800 Subject: [PATCH 583/623] test: fix ubsan warnings (#4341) * round nan to max uint64_t * avoid using invalid input to shift number * free unused module instance * avoid unaligned access in unittest Signed-off-by: Shen-Ta Hsieh --- test/plugins/unittest/testplugin.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/plugins/unittest/testplugin.cpp b/test/plugins/unittest/testplugin.cpp index 589df415..95731844 100644 --- a/test/plugins/unittest/testplugin.cpp +++ b/test/plugins/unittest/testplugin.cpp @@ -44,7 +44,7 @@ static Plugin::PluginModule::ModuleDescriptor MD[]{ }, }; -Plugin::Plugin::PluginDescriptor Descriptor{ +const Plugin::Plugin::PluginDescriptor Descriptor{ /* Name */ "wasmedge_plugintest_cpp", /* Description */ "", /* APIVersion */ Plugin::Plugin::CurrentAPIVersion, From 37f27025a3a68975141cf0b1e938b47e3b69ca04 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Fri, 29 Aug 2025 13:55:07 +0800 Subject: [PATCH 584/623] feat(plugin): add WasmEdgeIOstream API (#4197) * feat(Plugin): add WasmEdgeIOstream API Signed-off-by: grorge * feat(plugin): add IOstream implement Signed-off-by: grorge * feat(WASI-NN/MLX): replace std io with WasmEdge io Signed-off-by: grorge * fix(Plugin): fix the segment fault on the non-WASI module Signed-off-by: grorge * ci(plugin): enable ci test in dev/wasi_vfs Signed-off-by: grorge * feat(Plugin): replace STD IO with WasmEdge IO Signed-off-by: grorge * fix(Plugin): string type Signed-off-by: grorge * refactor(WASI-NN): add set environ function Signed-off-by: grorge * refactor(Plugin): move IOFstream namespace Signed-off-by: grorge * test(Plugin): add vfs_io test Signed-off-by: grorge * test(Plugin): add unit test to cover more functions Signed-off-by: grorge * feat(Plugin): add the check file permission function Signed-off-by: grorge * refactor: move file check function to environ Signed-off-by: grorge * fix(Plugin): store Environ instead of currentframe Signed-off-by: grorge * fix(Plugin): change IFStream function signature Signed-off-by: grorge * fix(Plugin): remove eof check Signed-off-by: grorge * feat(Plugin): remove __wasi_fileflags_t Signed-off-by: grorge * fix(Plugin): add close file Signed-off-by: grorge * docs: add comment for FD usage Signed-off-by: grorge --------- Signed-off-by: grorge --- plugins/wasi_nn/wasinn_ggml.cpp | 23 +++++++++++++---------- plugins/wasi_nn/wasinn_mlx.cpp | 6 ++++-- plugins/wasi_nn/wasinn_whisper.cpp | 25 +++++++++++++------------ plugins/wasi_nn/wasinnenv.h | 16 ++++++++++++++++ plugins/wasi_nn/wasinnfunc.cpp | 12 ++++++++++++ 5 files changed, 58 insertions(+), 24 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 7483377f..6d0cc0ab 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -3,6 +3,7 @@ #include "wasinn_ggml.h" #include "common/types.h" +#include "host/wasi/vfs_io.h" #include "wasinnenv.h" #include @@ -2019,8 +2020,8 @@ std::string processTTSPromptText(const std::string &Text) { } std::optional -getSpeakerProfileFromFile(const std::string &FilePath) { - std::ifstream JsonFile(FilePath); +getSpeakerProfileFromFile(const std::string &FilePath, WasiNNEnvironment &Env) { + WasmEdge::FStream::IFStream JsonFile(FilePath, Env.getEnv()); if (!JsonFile.is_open()) { return std::nullopt; } @@ -2055,13 +2056,14 @@ getSpeakerProfileFromFile(const std::string &FilePath) { return TTSSpeakerProfile{TextOutput, AudioOutputText}; } -std::vector processTTSPrompt(Graph &GraphRef, +std::vector processTTSPrompt(WasiNNEnvironment &Env, + Graph &GraphRef, std::string &Prompt) noexcept { // Use the custom speaker profile if available. TTSSpeakerProfile SpeakerProfile = TTSDefaultSpeakerProfile; if (!GraphRef.TTSSpeakerFilePath.empty()) { std::optional SpeakerProfileOpt = - getSpeakerProfileFromFile(GraphRef.TTSSpeakerFilePath); + getSpeakerProfileFromFile(GraphRef.TTSSpeakerFilePath, Env); if (SpeakerProfileOpt.has_value()) { SpeakerProfile = *SpeakerProfileOpt; } else { @@ -2579,7 +2581,8 @@ std::vector audioDataToWav(const std::vector &Data, } // TextToSpeech function, will generate voice data from codes. -ErrNo codesToSpeech(Graph &GraphRef, Context &CxtRef) noexcept { +ErrNo codesToSpeech(WasiNNEnvironment &Env, Graph &GraphRef, + Context &CxtRef) noexcept { // Remove all non-audio tokens. CxtRef.LlamaOutputTokens.erase( std::remove_if(CxtRef.LlamaOutputTokens.begin(), @@ -2626,7 +2629,7 @@ ErrNo codesToSpeech(Graph &GraphRef, Context &CxtRef) noexcept { // Save .wav file if path is provided. if (!GraphRef.TTSOutputFilePath.empty()) { - std::ofstream File(GraphRef.TTSOutputFilePath, std::ios::binary); + WasmEdge::FStream::OFStream File(GraphRef.TTSOutputFilePath, Env.getEnv()); if (!File) { RET_ERROR(ErrNo::RuntimeError, "codesToSpeech: Failed to open file '{}' for writing"sv, @@ -2714,8 +2717,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // TODO: pass the model directly to ggml. // Write ggml model to file. GraphRef.Params.model.path = "ggml-model.bin"sv; - std::ofstream TempFile(GraphRef.Params.model.path, - std::ios::out | std::ios::binary); + WasmEdge::FStream::OFStream TempFile(GraphRef.Params.model.path, + Env.getEnv()); if (!TempFile) { Env.deleteGraph(GId.raw()); RET_ERROR(ErrNo::InvalidArgument, @@ -3067,7 +3070,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } else if (GraphRef.TextToSpeech == true) { // TTS prompt. LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt"sv) - CxtRef.LlamaInputs = processTTSPrompt(GraphRef, Prompt); + CxtRef.LlamaInputs = processTTSPrompt(Env, GraphRef, Prompt); if (CxtRef.LlamaInputs.empty()) { RET_ERROR(ErrNo::InvalidArgument, "setInput: failed to tokenize tts prompt."sv) @@ -3177,7 +3180,7 @@ Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { if (GraphRef.TextToSpeech) { LOG_DEBUG(GraphRef.EnableDebugLog, "compute: convert output codes to audio file."sv) - ReturnCode = codesToSpeech(GraphRef, CxtRef); + ReturnCode = codesToSpeech(Env, GraphRef, CxtRef); if (ReturnCode != ErrNo::Success) { RET_ERROR(ReturnCode, "compute: failed to convert output codes to audio "sv diff --git a/plugins/wasi_nn/wasinn_mlx.cpp b/plugins/wasi_nn/wasinn_mlx.cpp index c6cd47a8..66d2a43a 100644 --- a/plugins/wasi_nn/wasinn_mlx.cpp +++ b/plugins/wasi_nn/wasinn_mlx.cpp @@ -19,6 +19,8 @@ #include #include + +#include "host/wasi/vfs_io.h" #endif namespace WasmEdge::Host::WASINN::MLX { @@ -339,7 +341,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, // Write model to file. // TODO: handle different model format. ModelFilePath = "MLX" + std::to_string(Idx) + ".safetensors"; - std::ofstream TempFile(ModelFilePath, std::ios::out | std::ios::binary); + WasmEdge::FStream::OFStream TempFile(ModelFilePath, Env.getEnv()); if (!TempFile) { spdlog::error( "[WASI-NN] MLX backend: Failed to create the temporary file. "sv); @@ -426,7 +428,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, // Load tokenizer. if (!TokenizerPath.empty()) { - auto Bytes = loadBytesFromFile(TokenizerPath); + auto Bytes = loadBytesFromFile(TokenizerPath, Env.getEnv()); if (Bytes.empty()) { spdlog::error("[WASI-NN] MLX backend: Load tokenizer failed."sv); Env.deleteGraph(GId); diff --git a/plugins/wasi_nn/wasinn_whisper.cpp b/plugins/wasi_nn/wasinn_whisper.cpp index 759e147f..6d8a608c 100644 --- a/plugins/wasi_nn/wasinn_whisper.cpp +++ b/plugins/wasi_nn/wasinn_whisper.cpp @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: 2019-2024 Second State INC #include "wasinn_whisper.h" +#include "host/wasi/vfs_io.h" #include "wasinnenv.h" #include #include @@ -74,10 +75,10 @@ estimateDiarizationSpeaker(const std::vector> PCMF32s, return Speaker; } -bool outputSrt(whisper_context *Ctx, const std::string &Fname, - const Config &Params, +bool outputSrt(WasiNNEnvironment &Env, whisper_context *Ctx, + const std::string &Fname, const Config &Params, const std::vector> &PCMF32s) { - std::ofstream Fout(Fname); + WasmEdge::FStream::OFStream Fout(Fname, Env.getEnv()); if (!Fout.is_open()) { spdlog::error("[WASI-NN] Whisper backend: failed to open {} for writing."sv, Fname); @@ -103,10 +104,10 @@ bool outputSrt(whisper_context *Ctx, const std::string &Fname, return true; } -static bool outputLrc(whisper_context *Ctx, const std::string &Fname, - const Config &Params, +static bool outputLrc(WasiNNEnvironment &Env, whisper_context *Ctx, + const std::string &Fname, const Config &Params, const std::vector> &PCMF32s) { - std::ofstream Fout(Fname); + WasmEdge::FStream::OFStream Fout(Fname, Env.getEnv()); if (!Fout.is_open()) { spdlog::error("[WASI-NN] Whisper backend: failed to open {} for writing."sv, Fname); @@ -157,10 +158,10 @@ std::string escapeDoubleQuotesAndBackslashes(const std::string &Str) { return Escaped; } -bool outputJson(whisper_context *Ctx, const std::string &Fname, - const Config &Params, +bool outputJson(WasiNNEnvironment &Env, whisper_context *Ctx, + const std::string &Fname, const Config &Params, const std::vector> &PCMF32s, bool Full) { - std::ofstream Fout(Fname); + WasmEdge::FStream::OFStream Fout(Fname, Env.getEnv()); int Indent = 0; auto Doindent = [&]() { @@ -1037,19 +1038,19 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, if (CxtRef.WhisperConfig.OutputSrt) { const auto Fname = CxtRef.WhisperConfig.FileName + ".srt"; - outputSrt(GraphRef.WhisperCtx, Fname, CxtRef.WhisperConfig, + outputSrt(Env, GraphRef.WhisperCtx, Fname, CxtRef.WhisperConfig, CxtRef.InputPCMs); } if (CxtRef.WhisperConfig.OutputLrc) { const auto Fname = CxtRef.WhisperConfig.FileName + ".lrc"; - outputLrc(GraphRef.WhisperCtx, Fname, CxtRef.WhisperConfig, + outputLrc(Env, GraphRef.WhisperCtx, Fname, CxtRef.WhisperConfig, CxtRef.InputPCMs); } if (CxtRef.WhisperConfig.OutputJson) { const auto Fname = CxtRef.WhisperConfig.FileName + ".json"; - outputJson(GraphRef.WhisperCtx, Fname, CxtRef.WhisperConfig, + outputJson(Env, GraphRef.WhisperCtx, Fname, CxtRef.WhisperConfig, CxtRef.InputPCMs, CxtRef.WhisperConfig.OutputJsonFull); } diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index e4e0a39d..143537d2 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -18,8 +18,12 @@ #include "wasinn_whisper.h" #include "wasinntypes.h" +#include "host/wasi/environ.h" + #include "common/spdlog.h" +#include "host/wasi/wasimodule.h" #include "plugin/plugin.h" +#include "runtime/callingframe.h" #include #include @@ -370,6 +374,18 @@ struct WasiNNEnvironment : static PO::Option NNRPCURI; // For RPC client mode std::shared_ptr NNRPCChannel; #endif + + const Host::WASI::Environ *getEnv() const noexcept { return Environ; } + void setEnviron(const Runtime::CallingFrame *CurrentFrame) noexcept { + auto *WasiModule = CurrentFrame->getWASIModule(); + if (WasiModule != nullptr) { + Environ = dynamic_cast(WasiModule) + ->getEnv(); + } + } + +private: + const Host::WASI::Environ *Environ = nullptr; }; } // namespace WASINN diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index 64056c5c..c9f2f91d 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -53,6 +53,7 @@ Expect WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, uint32_t BuilderLen, uint32_t RawEncoding, uint32_t Target, uint32_t GraphIdPtr) { + Env.setEnviron(&Frame); #ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNRPCChannel != nullptr) { // TODO: implement RPC for Load @@ -119,6 +120,7 @@ WasiNNLoad::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t BuilderPtr, Expect WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, uint32_t NameLen, uint32_t GraphIdPtr) { + Env.setEnviron(&Frame); auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -168,6 +170,7 @@ WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, Expect WasiNNLoadByNameWithConfig::bodyImpl( const Runtime::CallingFrame &Frame, uint32_t NamePtr, uint32_t NameLen, uint32_t ConfigPtr, uint32_t ConfigLen, uint32_t GraphIdPtr) { + Env.setEnviron(&Frame); auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -230,6 +233,7 @@ Expect WasiNNLoadByNameWithConfig::bodyImpl( Expect WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t GraphId, uint32_t ContextPtr) { + Env.setEnviron(&Frame); auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -287,6 +291,7 @@ WasiNNInitExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, Expect WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ContextId, uint32_t Index, uint32_t TensorPtr) { + Env.setEnviron(&Frame); auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -388,6 +393,7 @@ WasiNNGetOutput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ContextId, uint32_t Index, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { + Env.setEnviron(&Frame); auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -452,6 +458,7 @@ Expect WasiNNGetOutputSingle::bodyImpl( const Runtime::CallingFrame &Frame, uint32_t ContextId, uint32_t Index, uint32_t OutBufferPtr, uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr) { + Env.setEnviron(&Frame); auto *MemInst = Frame.getMemoryByIndex(0); if (MemInst == nullptr) { return Unexpect(ErrCode::Value::HostFuncError); @@ -518,6 +525,7 @@ Expect WasiNNGetOutputSingle::bodyImpl( Expect WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ContextId) { + Env.setEnviron(&Frame); #ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNRPCChannel != nullptr) { auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( @@ -570,6 +578,7 @@ WasiNNCompute::bodyImpl(const Runtime::CallingFrame &Frame, Expect WasiNNComputeSingle::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ContextId) { + Env.setEnviron(&Frame); #ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNRPCChannel != nullptr) { auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( @@ -623,6 +632,7 @@ WasiNNComputeSingle::bodyImpl(const Runtime::CallingFrame &Frame, Expect WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ContextId) { + Env.setEnviron(&Frame); #ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNRPCChannel != nullptr) { auto Stub = wasi_ephemeral_nn::GraphExecutionContextResource::NewStub( @@ -665,6 +675,7 @@ WasiNNFiniSingle::bodyImpl(const Runtime::CallingFrame &Frame, Expect WasiNNUnload::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t GraphId) { + Env.setEnviron(&Frame); #ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNRPCChannel != nullptr) { // TODO: implement RPC for unload @@ -701,6 +712,7 @@ Expect WasiNNUnload::bodyImpl(const Runtime::CallingFrame &Frame, Expect WasiNNFinalizeExecCtx::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ContextId) { + Env.setEnviron(&Frame); #ifdef WASMEDGE_BUILD_WASI_NN_RPC if (Env.NNRPCChannel != nullptr) { // TODO: implement RPC for finalize_execution_context From a0e3e729020d1e6248624db9242e9fd8f1dc9253 Mon Sep 17 00:00:00 2001 From: grorge Date: Wed, 24 Sep 2025 18:31:40 +0800 Subject: [PATCH 585/623] fix(WASI-NN/MLX): change loadBytesFromFile ifstream Signed-off-by: grorge --- plugins/wasi_nn/MLX/model/utils.cpp | 6 ++++-- plugins/wasi_nn/MLX/model/utils.h | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/MLX/model/utils.cpp b/plugins/wasi_nn/MLX/model/utils.cpp index 3fe41c57..4681b429 100644 --- a/plugins/wasi_nn/MLX/model/utils.cpp +++ b/plugins/wasi_nn/MLX/model/utils.cpp @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: 2019-2024 Second State INC #include "model/utils.h" +#include "host/wasi/vfs_io.h" #include #include @@ -59,8 +60,9 @@ void saveWeights(const mx::array &Weights, const std::string &Path) { } } -std::string loadBytesFromFile(const std::string &Path) { - std::ifstream Fs(Path, std::ios::in | std::ios::binary); +std::string loadBytesFromFile(const std::string &Path, + const Host::WASI::Environ *Env) { + WasmEdge::FStream::IFStream Fs(Path, Env); if (Fs.fail()) { spdlog::error("[WASI-NN] MLX backend: Cannot open {}."sv, Path); return ""; diff --git a/plugins/wasi_nn/MLX/model/utils.h b/plugins/wasi_nn/MLX/model/utils.h index 288b734d..390413fd 100644 --- a/plugins/wasi_nn/MLX/model/utils.h +++ b/plugins/wasi_nn/MLX/model/utils.h @@ -5,6 +5,7 @@ #include "mlx/base.h" +#include "host/wasi/vfs_io.h" #include #include #include @@ -24,7 +25,8 @@ void saveWeights(const std::unordered_map &Weights, void saveWeights(const mx::array &Weights, const std::string &Path); -std::string loadBytesFromFile(const std::string &Path); +std::string loadBytesFromFile(const std::string &Path, + const Host::WASI::Environ *Env); void fillPlaceholders(std::ostringstream &Oss, const std::string &Fmt, size_t &Pos); From 8a91c285cb7f6ceb1875597372581032243f2ea3 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Fri, 26 Sep 2025 19:56:05 +0800 Subject: [PATCH 586/623] refactor(WASI-NN/GGML): parse metadata (#4370) * refactor(WASI-NN/GGML): parse metadata Signed-off-by: grorge * feat(WASI-NN/GGML): change macro to template Signed-off-by: grorge --------- Signed-off-by: grorge --- plugins/wasi_nn/wasinn_ggml.cpp | 2192 ++++++++----------------------- 1 file changed, 528 insertions(+), 1664 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 6d0cc0ab..356179e0 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -57,6 +57,63 @@ namespace { spdlog::error("[WASI-NN] GGML backend: "sv __VA_ARGS__); \ return Error; +// JSON parsing helper template +template +T getJsonValue(const simdjson::dom::element &Doc, std::string_view Key) { + if (Doc.at_key(Key).error() == simdjson::SUCCESS) { + T Value{}; + auto Err = Doc[Key].get().get(Value); + if (Err) { + std::string Msg = fmt::format("Unable to retrieve the {} option.", Key); + spdlog::error("[WASI-NN] GGML backend: {}", Msg); + throw ErrNo(ErrNo::InvalidArgument); + } + return Value; + } + throw ErrNo(ErrNo::NotFound); +} + +template +void parseJsonAuto(const simdjson::dom::element &Doc, std::string_view Key, + T &Var) { + try { + auto Result = getJsonValue(Doc, Key); + Var = Result; + } catch (ErrNo E) { + if (E != ErrNo::NotFound) { + throw E; + } + } +} + +template +void parseJsonWithCastAuto(const simdjson::dom::element &Doc, + std::string_view Key, ToType &Var) { + try { + auto Result = getJsonValue(Doc, Key); + Var = static_cast(Result); + } catch (ErrNo E) { + if (E != ErrNo::NotFound) { + throw E; + } + } +} + +template +void parseJsonWithProcessorAuto(const simdjson::dom::element &Doc, + std::string_view Key, Processor Proc) { + try { + auto Result = getJsonValue(Doc, Key); + if (!Proc(Result)) { + throw ErrNo{ErrNo::InvalidArgument}; + } + } catch (ErrNo E) { + if (E != ErrNo::NotFound) { + throw E; + } + } +} + // Llama logging callback. void llamaLogCallback(ggml_log_level LogLevel, const char *LogText, void *UserData) { @@ -109,1671 +166,478 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, std::string PrevGrammar = GraphRef.Params.sampling.grammar; uint64_t PrevSeed = GraphRef.Params.sampling.seed; - // The plugin parameters. - if (Doc.at_key("enable-log").error() == simdjson::SUCCESS) { - auto Err = Doc["enable-log"].get().get(GraphRef.EnableLog); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the enable-log option."sv) - } - } - if (Doc.at_key("enable-debug-log").error() == simdjson::SUCCESS) { - auto Err = Doc["enable-debug-log"].get().get(GraphRef.EnableDebugLog); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the enable-debug-log option."sv) - } - } - - // The model parameters. - if (Doc.at_key("main-gpu").error() == simdjson::SUCCESS) { - int64_t MainGPU; - auto Err = Doc["main-gpu"].get().get(MainGPU); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the main-gpu option."sv) - } - GraphRef.Params.main_gpu = static_cast(MainGPU); - } - if (Doc.at_key("n-gpu-layers").error() == simdjson::SUCCESS) { - int64_t NGPULayers; - auto Err = Doc["n-gpu-layers"].get().get(NGPULayers); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-gpu-layers option."sv) - } - GraphRef.Params.n_gpu_layers = static_cast(NGPULayers); - } - if (Doc.at_key("cpu-moe").error() == simdjson::SUCCESS) { - bool CpuMoe; - auto Err = Doc["cpu-moe"].get().get(CpuMoe); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the cpu-moe option."sv) - } - if (CpuMoe) { - GraphRef.TensorBuftOverrides.push_back("\\.ffn_(up|down|gate)_exps"); - } - } - if (Doc.at_key("n-cpu-moe").error() == simdjson::SUCCESS) { - int64_t NCpuMoe; - auto Err = Doc["n-cpu-moe"].get().get(NCpuMoe); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-cpu-moe option."sv) - } - if (NCpuMoe < 0) { - RET_ERROR(ErrNo::InvalidArgument, "Invalid n-cpu-moe value."sv) - } - for (int I = 0; I < NCpuMoe; I++) { - GraphRef.TensorBuftOverrides.push_back( - string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", I)); - } - } - if (Doc.at_key("tensor-split").error() == simdjson::SUCCESS) { - // The TensorSplit is a comma-separated list of non-negative values. - // E.g., "3,2" presents 60% of the data to GPU 0 and 40% to GPU 1. - std::string_view TSV; - auto Err = Doc["tensor-split"].get().get(TSV); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the tensor-split option."sv) - } - std::string TS(TSV); - std::replace(TS.begin(), TS.end(), ',', ' '); - std::stringstream SS(TS); - std::memset(GraphRef.Params.tensor_split, 0, - sizeof(GraphRef.Params.tensor_split)); - uint32_t TensorSplitSize = 0; - while (SS.good()) { - float TmpTensor; - SS >> TmpTensor; - GraphRef.Params.tensor_split[TensorSplitSize++] = TmpTensor; - } - size_t NDevices = llama_max_devices(); - if (TensorSplitSize > NDevices) { - RET_ERROR( - ErrNo::InvalidArgument, - "Number of Tensor-Split is larger than MaxDevices, please reduce "sv - "the size of tensor-split."sv) - } - for (size_t Idx = TensorSplitSize; Idx < NDevices; Idx++) { - GraphRef.Params.tensor_split[TensorSplitSize++] = 0.0f; - } - } - if (Doc.at_key("embedding").error() == simdjson::SUCCESS) { - - auto Err = Doc["embedding"].get().get(GraphRef.Params.embedding); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the embedding option."sv) - } - } - if (Doc.at_key("split-mode").error() == simdjson::SUCCESS) { - std::string_view SplitMode; - auto Err = Doc["split-mode"].get().get(SplitMode); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the split-mode option."sv) - } - if (SplitMode == "none"sv) { - GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_NONE; - } else if (SplitMode == "layer"sv) { - GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_LAYER; - } else if (SplitMode == "row"sv) { - GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_ROW; - } else { - RET_ERROR(ErrNo::InvalidArgument, - "Unknown split-mode: {}. Valid: none, layer, row."sv, SplitMode) - } - } - if (Doc.at_key("mmproj").error() == simdjson::SUCCESS) { - std::string_view MMProjModelPath; - auto Err = Doc["mmproj"].get().get(MMProjModelPath); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the mmproj option."sv) - } - GraphRef.Params.mmproj.path = MMProjModelPath; - } - - // The TTS parameters. - if (Doc.at_key("tts").error() == simdjson::SUCCESS) { - auto Err = Doc["tts"].get().get(GraphRef.TextToSpeech); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the tts option."sv) - } - } - if (Doc.at_key("model-vocoder").error() == simdjson::SUCCESS) { - std::string_view VocoderModelPath; - auto Err = - Doc["model-vocoder"].get().get(VocoderModelPath); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the model-vocoder option."sv) - } - GraphRef.Params.vocoder.model.path = VocoderModelPath; - } - if (Doc.at_key("tts-output-file").error() == simdjson::SUCCESS) { - std::string_view TTSOutputFilePath; - auto Err = - Doc["tts-output-file"].get().get(TTSOutputFilePath); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the tts-output-file option."sv) - } - GraphRef.TTSOutputFilePath = TTSOutputFilePath; - } - if (Doc.at_key("tts-speaker-file").error() == simdjson::SUCCESS) { - std::string_view TTSSpeakerFilePath; - auto Err = - Doc["tts-speaker-file"].get().get(TTSSpeakerFilePath); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the tts-speaker-file option."sv) - } - GraphRef.TTSSpeakerFilePath = TTSSpeakerFilePath; - } - - // The context parameters. - if (Doc.at_key("ctx-size").error() == simdjson::SUCCESS) { - int64_t CtxSize; - auto Err = Doc["ctx-size"].get().get(CtxSize); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the ctx-size option."sv) - } - GraphRef.Params.n_ctx = static_cast(CtxSize); - } - if (Doc.at_key("batch-size").error() == simdjson::SUCCESS) { - int64_t BatchSize; - auto Err = Doc["batch-size"].get().get(BatchSize); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the batch-size option."sv) - } - GraphRef.Params.n_batch = static_cast(BatchSize); - } - if (Doc.at_key("ubatch-size").error() == simdjson::SUCCESS) { - int64_t UBatchSize; - auto Err = Doc["ubatch-size"].get().get(UBatchSize); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the ubatch-size option."sv) - } - GraphRef.Params.n_ubatch = static_cast(UBatchSize); - } - if (Doc.at_key("n-keep").error() == simdjson::SUCCESS) { - int64_t NKeep; - auto Err = Doc["n-keep"].get().get(NKeep); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the n-keep option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.n_keep = static_cast(NKeep); - } - if (Doc.at_key("n-chunks").error() == simdjson::SUCCESS) { - int64_t NChunks; - auto Err = Doc["n-chunks"].get().get(NChunks); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the n-chunks option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.n_chunks = static_cast(NChunks); - } - if (Doc.at_key("n-parallel").error() == simdjson::SUCCESS) { - int64_t NParallel; - auto Err = Doc["n-parallel"].get().get(NParallel); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the n-parallel option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.n_parallel = static_cast(NParallel); - } - if (Doc.at_key("n-sequences").error() == simdjson::SUCCESS) { - int64_t NSequences; - auto Err = Doc["n-sequences"].get().get(NSequences); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the n-sequences option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.n_sequences = static_cast(NSequences); - } - if (Doc.at_key("grp-attn-n").error() == simdjson::SUCCESS) { - int64_t GrpAttnN; - auto Err = Doc["grp-attn-n"].get().get(GrpAttnN); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the grp-attn-n option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.grp_attn_n = static_cast(GrpAttnN); - } - if (Doc.at_key("grp-attn-w").error() == simdjson::SUCCESS) { - int64_t GrpAttnW; - auto Err = Doc["grp-attn-w"].get().get(GrpAttnW); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the grp-attn-w option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.grp_attn_w = static_cast(GrpAttnW); - } - if (Doc.at_key("n-print").error() == simdjson::SUCCESS) { - int64_t NPrint; - auto Err = Doc["n-print"].get().get(NPrint); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the n-print option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.n_print = static_cast(NPrint); - } - if (Doc.at_key("rope-freq-base").error() == simdjson::SUCCESS) { - double RopeFreqBase; - auto Err = Doc["rope-freq-base"].get().get(RopeFreqBase); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the rope-freq-base option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.rope_freq_base = static_cast(RopeFreqBase); - } - if (Doc.at_key("rope-freq-scale").error() == simdjson::SUCCESS) { - double RopeFreqScale; - auto Err = Doc["rope-freq-scale"].get().get(RopeFreqScale); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the rope-freq-scale option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.rope_freq_scale = static_cast(RopeFreqScale); - } - if (Doc.at_key("yarn-ext-factor").error() == simdjson::SUCCESS) { - double YarnExtFactor; - auto Err = Doc["yarn-ext-factor"].get().get(YarnExtFactor); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the yarn-ext-factor option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.yarn_ext_factor = static_cast(YarnExtFactor); - } - if (Doc.at_key("yarn-attn-factor").error() == simdjson::SUCCESS) { - double YarnAttnFactor; - auto Err = Doc["yarn-attn-factor"].get().get(YarnAttnFactor); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the yarn-attn-factor option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.yarn_attn_factor = static_cast(YarnAttnFactor); - } - if (Doc.at_key("yarn-beta-fast").error() == simdjson::SUCCESS) { - double YarnBetaFast; - auto Err = Doc["yarn-beta-fast"].get().get(YarnBetaFast); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the yarn-beta-fast option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.yarn_beta_fast = static_cast(YarnBetaFast); - } - if (Doc.at_key("yarn-beta-slow").error() == simdjson::SUCCESS) { - double YarnBetaSlow; - auto Err = Doc["yarn-beta-slow"].get().get(YarnBetaSlow); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the yarn-beta-slow option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.yarn_beta_slow = static_cast(YarnBetaSlow); - } - if (Doc.at_key("yarn-orig-ctx").error() == simdjson::SUCCESS) { - int64_t YarnOrigCtx; - auto Err = Doc["yarn-orig-ctx"].get().get(YarnOrigCtx); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the yarn-orig-ctx option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.yarn_orig_ctx = static_cast(YarnOrigCtx); - } - if (Doc.at_key("mask-valid").error() == simdjson::SUCCESS) { - auto Err = - Doc["mask-valid"].get().get(GraphRef.Params.cpuparams.mask_valid); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the mask-valid option."sv); - return ErrNo::InvalidArgument; - } - } - if (Doc.at_key("priority").error() == simdjson::SUCCESS) { - int64_t Priority; - auto Err = Doc["priority"].get().get(Priority); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the priority option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.cpuparams.priority = - static_cast(Priority); - } - if (Doc.at_key("strict-cpu").error() == simdjson::SUCCESS) { - auto Err = - Doc["strict-cpu"].get().get(GraphRef.Params.cpuparams.strict_cpu); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the strict-cpu option."sv); - return ErrNo::InvalidArgument; - } - } - if (Doc.at_key("poll").error() == simdjson::SUCCESS) { - int64_t Poll; - auto Err = Doc["poll"].get().get(Poll); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the poll option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.cpuparams.poll = static_cast(Poll); - } - if (Doc.at_key("mask-valid-batch").error() == simdjson::SUCCESS) { - auto Err = Doc["mask-valid-batch"].get().get( - GraphRef.Params.cpuparams_batch.mask_valid); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the mask-valid-batch option."sv); - return ErrNo::InvalidArgument; - } - } - if (Doc.at_key("priority-batch").error() == simdjson::SUCCESS) { - int64_t Priority; - auto Err = Doc["priority-batch"].get().get(Priority); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the priority-batch option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.cpuparams_batch.priority = - static_cast(Priority); - } - if (Doc.at_key("strict-cpu-batch").error() == simdjson::SUCCESS) { - auto Err = Doc["strict-cpu-batch"].get().get( - GraphRef.Params.cpuparams_batch.strict_cpu); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the strict-cpu-batch option."sv); - return ErrNo::InvalidArgument; - } - } - if (Doc.at_key("poll-batch").error() == simdjson::SUCCESS) { - int64_t Poll; - auto Err = Doc["poll-batch"].get().get(Poll); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the poll-batch option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.cpuparams_batch.poll = static_cast(Poll); - } - if (Doc.at_key("numa").error() == simdjson::SUCCESS) { - int64_t Numa; - auto Err = Doc["numa"].get().get(Numa); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the numa option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.numa = static_cast(Numa); - } - if (Doc.at_key("rope-scaling-type").error() == simdjson::SUCCESS) { - int64_t RopeScalingType; - auto Err = Doc["rope-scaling-type"].get().get(RopeScalingType); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the rope-scaling-type option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.rope_scaling_type = - static_cast(RopeScalingType); - } - if (Doc.at_key("pooling-type").error() == simdjson::SUCCESS) { - int64_t PoolingType; - auto Err = Doc["pooling-type"].get().get(PoolingType); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the pooling-type option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.pooling_type = - static_cast(PoolingType); - } - if (Doc.at_key("attention-type").error() == simdjson::SUCCESS) { - int64_t AttentionType; - auto Err = Doc["attention-type"].get().get(AttentionType); - if (Err) { - spdlog::error( - "[WASI-NN] GGML backend: Unable to retrieve the attention-type option."sv); - return ErrNo::InvalidArgument; - } - GraphRef.Params.attention_type = - static_cast(AttentionType); - } - if (Doc.at_key("threads").error() == simdjson::SUCCESS) { - int64_t NThreads; - auto Err = Doc["threads"].get().get(NThreads); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the threads option."sv) - } - GraphRef.Params.cpuparams.n_threads = static_cast(NThreads); - } - if (Doc.at_key("threads-batch").error() == simdjson::SUCCESS) { - int64_t NThreadsBatch; - auto Err = Doc["threads-batch"].get().get(NThreadsBatch); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the threads-batch option."sv) - } - GraphRef.Params.cpuparams_batch.n_threads = - static_cast(NThreadsBatch); - } - - // The sampling parameters. - if (Doc.at_key("n-prev").error() == simdjson::SUCCESS) { - int64_t NPrev; - auto Err = Doc["n-prev"].get().get(NPrev); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n_prev option."sv) - } - GraphRef.Params.sampling.n_prev = static_cast(NPrev); - } - if (Doc.at_key("n-probs").error() == simdjson::SUCCESS) { - int64_t NProbs; - auto Err = Doc["n-probs"].get().get(NProbs); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n_probs option."sv) - } - GraphRef.Params.sampling.n_probs = static_cast(NProbs); - } - if (Doc.at_key("min-keep").error() == simdjson::SUCCESS) { - int64_t MinKeep; - auto Err = Doc["min-keep"].get().get(MinKeep); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the min-keep option."sv) - } - GraphRef.Params.sampling.min_keep = static_cast(MinKeep); - } - if (Doc.at_key("top-k").error() == simdjson::SUCCESS) { - int64_t TopK; - auto Err = Doc["top-k"].get().get(TopK); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the top-k option."sv) - } - GraphRef.Params.sampling.top_k = static_cast(TopK); - } - if (Doc.at_key("min-p").error() == simdjson::SUCCESS) { - double MinP; - auto Err = Doc["min-p"].get().get(MinP); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the min-p option."sv) - } - GraphRef.Params.sampling.min_p = static_cast(MinP); - } - if (Doc.at_key("xtc-probability").error() == simdjson::SUCCESS) { - double XtcProbability; - auto Err = Doc["xtc-probability"].get().get(XtcProbability); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the xtc-probability option."sv) - } - GraphRef.Params.sampling.xtc_probability = - static_cast(XtcProbability); - } - if (Doc.at_key("xtc-threshold").error() == simdjson::SUCCESS) { - double XtcThreshold; - auto Err = Doc["xtc-threshold"].get().get(XtcThreshold); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the xtc-threshold option."sv) - } - GraphRef.Params.sampling.xtc_threshold = static_cast(XtcThreshold); - } - if (Doc.at_key("typ-p").error() == simdjson::SUCCESS) { - double TypP; - auto Err = Doc["typ-p"].get().get(TypP); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the typ-p option."sv) - } - GraphRef.Params.sampling.typ_p = static_cast(TypP); - } - if (Doc.at_key("dynatemp-range").error() == simdjson::SUCCESS) { - double DynaTempRange; - auto Err = Doc["dynatemp-range"].get().get(DynaTempRange); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the dynatemp-range option."sv) - } - GraphRef.Params.sampling.dynatemp_range = static_cast(DynaTempRange); - } - if (Doc.at_key("dynatemp-exponent").error() == simdjson::SUCCESS) { - double DynaTempExponent; - auto Err = Doc["dynatemp-exponent"].get().get(DynaTempExponent); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the dynatemp-exponent option."sv) - } - GraphRef.Params.sampling.dynatemp_exponent = - static_cast(DynaTempExponent); - } - if (Doc.at_key("last-n-penalty").error() == simdjson::SUCCESS) { - int64_t LastNPenalty; - auto Err = Doc["last-n-penalty"].get().get(LastNPenalty); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the last-n-penalty option."sv) - } - GraphRef.Params.sampling.penalty_last_n = - static_cast(LastNPenalty); - } - if (Doc.at_key("temp").error() == simdjson::SUCCESS) { - double Temp; - auto Err = Doc["temp"].get().get(Temp); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the temp option."sv) - } - GraphRef.Params.sampling.temp = static_cast(std::max(0.0, Temp)); - } - if (Doc.at_key("top-p").error() == simdjson::SUCCESS) { - double TopP; - auto Err = Doc["top-p"].get().get(TopP); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the top-p option."sv) - } - GraphRef.Params.sampling.top_p = static_cast(std::max(0.0, TopP)); - } - if (Doc.at_key("repeat-penalty").error() == simdjson::SUCCESS) { - double RepeatPenalty; - auto Err = Doc["repeat-penalty"].get().get(RepeatPenalty); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the repeat-penalty option."sv) - } - GraphRef.Params.sampling.penalty_repeat = - static_cast(std::max(0.0, RepeatPenalty)); + try { + parseJsonAuto(Doc, "enable-log", GraphRef.EnableLog); + parseJsonAuto(Doc, "enable-debug-log", GraphRef.EnableDebugLog); + + parseJsonWithCastAuto(Doc, "main-gpu", GraphRef.Params.main_gpu); + parseJsonWithCastAuto(Doc, "n-gpu-layers", + GraphRef.Params.n_gpu_layers); + + parseJsonWithProcessorAuto(Doc, "cpu-moe", + [&GraphRef](const bool &CpuMoe) -> bool { + if (CpuMoe) { + GraphRef.TensorBuftOverrides.push_back( + "\\.ffn_(up|down|gate)_exps"); + } + return true; + }); + + parseJsonWithProcessorAuto( + Doc, "n-cpu-moe", [&GraphRef](const int64_t &NCpuMoe) -> bool { + if (NCpuMoe < 0) { + spdlog::error("[WASI-NN] GGML backend: Invalid n-cpu-moe value."); + return false; + } + for (int I = 0; I < NCpuMoe; I++) { + GraphRef.TensorBuftOverrides.push_back( + string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", I)); + } + return true; + }); + parseJsonWithProcessorAuto( + Doc, "tensor-split", [&GraphRef](const std::string_view &TSV) -> bool { + // The TensorSplit is a comma-separated list of non-negative values. + // E.g., "3,2" presents 60% of the data to GPU 0 and 40% to GPU 1. + std::string TS(TSV); + std::replace(TS.begin(), TS.end(), ',', ' '); + std::stringstream SS(TS); + std::memset(GraphRef.Params.tensor_split, 0, + sizeof(GraphRef.Params.tensor_split)); + uint32_t TensorSplitSize = 0; + while (SS.good()) { + float TmpTensor; + SS >> TmpTensor; + GraphRef.Params.tensor_split[TensorSplitSize++] = TmpTensor; + } + size_t NDevices = llama_max_devices(); + if (TensorSplitSize > NDevices) { + spdlog::error( + "[WASI-NN] GGML backend: Number of Tensor-Split is larger than " + "MaxDevices, please reduce the size of tensor-split."); + return false; + } + for (size_t Idx = TensorSplitSize; Idx < NDevices; Idx++) { + GraphRef.Params.tensor_split[TensorSplitSize++] = 0.0f; + } + return true; + }); + parseJsonAuto(Doc, "embedding", GraphRef.Params.embedding); + parseJsonWithProcessorAuto( + Doc, "split-mode", + [&GraphRef](const std::string_view &SplitMode) -> bool { + if (SplitMode == "none"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_NONE; + } else if (SplitMode == "layer"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_LAYER; + } else if (SplitMode == "row"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_ROW; + } else { + spdlog::error("[WASI-NN] GGML backend: Unknown split-mode: " + "{}. Valid: none, layer, row.", + SplitMode); + return false; + } + return true; + }); + parseJsonWithCastAuto(Doc, "mmproj", + GraphRef.Params.mmproj.path); + // The TTS parameters using macros + parseJsonAuto(Doc, "tts", GraphRef.TextToSpeech); + parseJsonWithCastAuto(Doc, "model-vocoder", + GraphRef.Params.vocoder.model.path); + parseJsonWithCastAuto(Doc, "tts-output-file", + GraphRef.TTSOutputFilePath); + + parseJsonWithCastAuto(Doc, "tts-speaker-file", + GraphRef.TTSSpeakerFilePath); + + // The context parameters + parseJsonWithCastAuto(Doc, "ctx-size", GraphRef.Params.n_ctx); + parseJsonWithCastAuto(Doc, "batch-size", GraphRef.Params.n_batch); + parseJsonWithCastAuto(Doc, "ubatch-size", + GraphRef.Params.n_ubatch); + parseJsonWithCastAuto(Doc, "n-keep", GraphRef.Params.n_keep); + parseJsonWithCastAuto(Doc, "n-chunks", GraphRef.Params.n_chunks); + parseJsonWithCastAuto(Doc, "n-parallel", + GraphRef.Params.n_parallel); + parseJsonWithCastAuto(Doc, "n-sequences", + GraphRef.Params.n_sequences); + parseJsonWithCastAuto(Doc, "grp-attn-n", + GraphRef.Params.grp_attn_n); + parseJsonWithCastAuto(Doc, "grp-attn-w", + GraphRef.Params.grp_attn_w); + parseJsonWithCastAuto(Doc, "n-print", GraphRef.Params.n_print); + parseJsonWithCastAuto(Doc, "rope-freq-base", + GraphRef.Params.rope_freq_base); + parseJsonWithCastAuto(Doc, "rope-freq-scale", + GraphRef.Params.rope_freq_scale); + parseJsonWithCastAuto(Doc, "yarn-ext-factor", + GraphRef.Params.yarn_ext_factor); + parseJsonWithCastAuto(Doc, "yarn-attn-factor", + GraphRef.Params.yarn_attn_factor); + parseJsonWithCastAuto(Doc, "yarn-beta-fast", + GraphRef.Params.yarn_beta_fast); + parseJsonWithCastAuto(Doc, "yarn-beta-slow", + GraphRef.Params.yarn_beta_slow); + parseJsonWithCastAuto(Doc, "yarn-orig-ctx", + GraphRef.Params.yarn_orig_ctx); + parseJsonAuto(Doc, "mask-valid", + GraphRef.Params.cpuparams.mask_valid); + parseJsonWithCastAuto(Doc, "priority", + GraphRef.Params.cpuparams.priority); + parseJsonAuto(Doc, "strict-cpu", + GraphRef.Params.cpuparams.strict_cpu); + parseJsonWithCastAuto(Doc, "poll", GraphRef.Params.cpuparams.poll); + parseJsonAuto(Doc, "mask-valid-batch", + GraphRef.Params.cpuparams_batch.mask_valid); + parseJsonWithProcessorAuto( + Doc, "priority-batch", [&GraphRef](const int64_t &Priority) -> bool { + GraphRef.Params.cpuparams_batch.priority = + static_cast(Priority); + return true; + }); + parseJsonAuto(Doc, "strict-cpu-batch", + GraphRef.Params.cpuparams_batch.strict_cpu); + parseJsonWithCastAuto(Doc, "poll-batch", + GraphRef.Params.cpuparams_batch.poll); + + parseJsonWithCastAuto(Doc, "numa", GraphRef.Params.numa); + + parseJsonWithCastAuto(Doc, "rope-scaling-type", + GraphRef.Params.rope_scaling_type); + parseJsonWithCastAuto(Doc, "pooling-type", + GraphRef.Params.pooling_type); + parseJsonWithCastAuto(Doc, "attention-type", + GraphRef.Params.attention_type); + parseJsonWithProcessorAuto( + Doc, "threads", [&GraphRef](const int64_t &NThreads) -> bool { + GraphRef.Params.cpuparams.n_threads = static_cast(NThreads); + return true; + }); + parseJsonWithCastAuto(Doc, "threads-batch", + GraphRef.Params.cpuparams_batch.n_threads); + parseJsonWithCastAuto(Doc, "n-prev", + GraphRef.Params.sampling.n_prev); + parseJsonWithCastAuto(Doc, "n-probs", + GraphRef.Params.sampling.n_probs); + parseJsonWithCastAuto(Doc, "min-keep", + GraphRef.Params.sampling.min_keep); + parseJsonWithCastAuto(Doc, "top-k", + GraphRef.Params.sampling.top_k); + parseJsonWithCastAuto(Doc, "min-p", GraphRef.Params.sampling.min_p); + parseJsonWithCastAuto(Doc, "xtc-probability", + GraphRef.Params.sampling.xtc_probability); + parseJsonWithCastAuto(Doc, "xtc-threshold", + GraphRef.Params.sampling.xtc_threshold); + parseJsonWithCastAuto(Doc, "typ-p", GraphRef.Params.sampling.typ_p); + parseJsonWithCastAuto(Doc, "dynatemp-range", + GraphRef.Params.sampling.dynatemp_range); + parseJsonWithCastAuto(Doc, "dynatemp-exponent", + GraphRef.Params.sampling.dynatemp_exponent); + parseJsonWithCastAuto(Doc, "last-n-penalty", + GraphRef.Params.sampling.penalty_last_n); + parseJsonWithProcessorAuto( + Doc, "temp", [&GraphRef](const double &Temp) -> bool { + GraphRef.Params.sampling.temp = + static_cast(std::max(0.0, Temp)); + return true; + }); + + parseJsonWithProcessorAuto( + Doc, "top-p", [&GraphRef](const double &TopP) -> bool { + GraphRef.Params.sampling.top_p = + static_cast(std::max(0.0, TopP)); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "repeat-penalty", + [&GraphRef](const double &RepeatPenalty) -> bool { + GraphRef.Params.sampling.penalty_repeat = + static_cast(std::max(0.0, RepeatPenalty)); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "presence-penalty", + [&GraphRef](const double &PresencePenalty) -> bool { + GraphRef.Params.sampling.penalty_present = + static_cast(std::max(0.0, PresencePenalty)); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "frequency-penalty", + [&GraphRef](const double &FrequencyPenalty) -> bool { + GraphRef.Params.sampling.penalty_freq = + static_cast(std::max(0.0, FrequencyPenalty)); + return true; + }); + parseJsonWithCastAuto(Doc, "dry-multipier", + GraphRef.Params.sampling.dry_multiplier); + parseJsonWithCastAuto(Doc, "dry-base", + GraphRef.Params.sampling.dry_base); + parseJsonWithCastAuto(Doc, "dry-allowed-length", + GraphRef.Params.sampling.dry_allowed_length); + parseJsonWithCastAuto(Doc, "dry-last-n-penalty", + GraphRef.Params.sampling.penalty_last_n); + parseJsonWithCastAuto(Doc, "mirostat", + GraphRef.Params.sampling.mirostat); + parseJsonWithCastAuto(Doc, "mirostat-eta", + GraphRef.Params.sampling.mirostat_eta); + parseJsonAuto(Doc, "ignore-eos", GraphRef.Params.sampling.ignore_eos); + parseJsonAuto(Doc, "no-perf-sampling", + GraphRef.Params.sampling.no_perf); + parseJsonAuto(Doc, "timing-per-token", + GraphRef.Params.sampling.timing_per_token); + + parseJsonWithCastAuto(Doc, "grammar", + GraphRef.Params.sampling.grammar); + parseJsonWithProcessorAuto( + Doc, "json-schema", + [&GraphRef](const std::string_view &JsonSchema) -> bool { + GraphRef.Params.sampling.grammar = + json_schema_to_grammar(nlohmann::ordered_json::parse(JsonSchema)); + return true; + }); + parseJsonWithCastAuto(Doc, "seed", GraphRef.Params.sampling.seed); + + // The speculative parameters. + parseJsonWithCastAuto(Doc, "n-ctx-speculative", + GraphRef.Params.speculative.n_ctx); + parseJsonWithCastAuto(Doc, "n-max-speculative", + GraphRef.Params.speculative.n_max); + parseJsonWithCastAuto(Doc, "n-min-speculative", + GraphRef.Params.speculative.n_min); + parseJsonWithCastAuto(Doc, "n-gpu-layers-speculative", + GraphRef.Params.speculative.n_gpu_layers); + parseJsonWithCastAuto(Doc, "p-split-speculative", + GraphRef.Params.speculative.p_split); + parseJsonWithCastAuto(Doc, "p-min-speculative", + GraphRef.Params.speculative.p_min); + // The vocoder parameters. + parseJsonWithCastAuto( + Doc, "hf-repo-vocoder", GraphRef.Params.vocoder.model.hf_repo); + parseJsonWithCastAuto( + Doc, "hf-file-vocoder", GraphRef.Params.vocoder.model.hf_file); + parseJsonWithCastAuto(Doc, "model-url-vocoder", + GraphRef.Params.vocoder.model.url); + + // The config parameters. + parseJsonAuto(Doc, "stream-stdout", ConfRef.StreamStdout); + parseJsonAuto(Doc, "n-predict", ConfRef.NPredict); + parseJsonWithCastAuto(Doc, "reverse-prompt", + ConfRef.ReversePrompt); + parseJsonWithCastAuto(Doc, "image", ConfRef.ImagePath); + parseJsonAuto(Doc, "always-regenerate-image-embd", + ConfRef.AlwaysRegenerateImageEmbd); + parseJsonWithCastAuto(Doc, "model-alias", + GraphRef.Params.model_alias); + parseJsonWithCastAuto(Doc, "model-url", + GraphRef.Params.model.url); + parseJsonWithCastAuto(Doc, "hf-token", + GraphRef.Params.hf_token); + parseJsonWithCastAuto(Doc, "hf-repo", + GraphRef.Params.model.hf_repo); + parseJsonWithCastAuto(Doc, "hf-file", + GraphRef.Params.model.hf_file); + parseJsonWithCastAuto(Doc, "prompt-file", + GraphRef.Params.prompt_file); + parseJsonWithCastAuto(Doc, "path-prompt-cache", + GraphRef.Params.path_prompt_cache); + parseJsonWithCastAuto(Doc, "input-prefix", + GraphRef.Params.input_prefix); + parseJsonWithCastAuto(Doc, "input-suffix", + GraphRef.Params.input_suffix); + parseJsonWithCastAuto( + Doc, "lookup-cache-static", GraphRef.Params.lookup_cache_static); + parseJsonWithCastAuto( + Doc, "lookup-cache-dynamic", GraphRef.Params.lookup_cache_dynamic); + parseJsonWithCastAuto(Doc, "logits-file", + GraphRef.Params.logits_file); + parseJsonAuto(Doc, "lora-init-without-apply", + GraphRef.Params.lora_init_without_apply); + parseJsonWithCastAuto(Doc, "verbosity", GraphRef.Params.verbosity); + parseJsonWithCastAuto(Doc, "control-vector-layer-start", + GraphRef.Params.control_vector_layer_start); + parseJsonWithCastAuto(Doc, "control-vector-layer-end", + GraphRef.Params.control_vector_layer_end); + parseJsonWithCastAuto(Doc, "ppl-stride", + GraphRef.Params.ppl_stride); + parseJsonWithCastAuto(Doc, "ppl-output-type", + GraphRef.Params.ppl_output_type); + parseJsonAuto(Doc, "hellaswag", GraphRef.Params.hellaswag); + parseJsonWithCastAuto(Doc, "hellaswag-tasks", + GraphRef.Params.hellaswag_tasks); + parseJsonAuto(Doc, "winogrande", GraphRef.Params.winogrande); + parseJsonWithCastAuto(Doc, "winogrande-tasks", + GraphRef.Params.winogrande_tasks); + parseJsonAuto(Doc, "multiple-choice", + GraphRef.Params.multiple_choice); + parseJsonWithCastAuto( + Doc, "multiple-choice-tasks", GraphRef.Params.multiple_choice_tasks); + parseJsonAuto(Doc, "kl-divergence", GraphRef.Params.kl_divergence); + parseJsonAuto(Doc, "usage", GraphRef.Params.usage); + parseJsonAuto(Doc, "use-color", GraphRef.Params.use_color); + parseJsonAuto(Doc, "special", GraphRef.Params.special); + parseJsonAuto(Doc, "interactive", GraphRef.Params.interactive); + parseJsonAuto(Doc, "interactive-first", + GraphRef.Params.interactive_first); + parseJsonAuto(Doc, "prompt-cache-all", + GraphRef.Params.prompt_cache_all); + parseJsonAuto(Doc, "prompt-cache-ro", + GraphRef.Params.prompt_cache_ro); + parseJsonAuto(Doc, "escape", GraphRef.Params.escape); + parseJsonAuto(Doc, "multiline-input", + GraphRef.Params.multiline_input); + parseJsonAuto(Doc, "simple-io", GraphRef.Params.simple_io); + parseJsonAuto(Doc, "cont-batching", GraphRef.Params.cont_batching); + parseJsonWithProcessorAuto( + Doc, "flash-attn", + [&GraphRef](const std::string_view &FlashAttn) -> bool { + if (FlashAttn == "on"sv || FlashAttn == "enabled"sv) { + GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; + } else if (FlashAttn == "off"sv || FlashAttn == "disabled"sv) { + GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; + } else if (FlashAttn == "auto"sv) { + GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; + } else { + spdlog::error( + "[WASI-NN] GGML backend: The flash-attn option must be " + "one of: on, off, auto."); + return false; + } + return true; + }); + parseJsonAuto(Doc, "no-perf", GraphRef.Params.no_perf); + parseJsonAuto(Doc, "ctx-shift", GraphRef.Params.ctx_shift); + parseJsonAuto(Doc, "input-prefix-bos", + GraphRef.Params.input_prefix_bos); + parseJsonAuto(Doc, "use-mlock", GraphRef.Params.use_mlock); + parseJsonAuto(Doc, "use-mmap", GraphRef.Params.use_mmap); + parseJsonAuto(Doc, "verbose-prompt", GraphRef.Params.verbose_prompt); + parseJsonAuto(Doc, "display-prompt", GraphRef.Params.display_prompt); + parseJsonAuto(Doc, "no-kv-offload", GraphRef.Params.no_kv_offload); + parseJsonAuto(Doc, "warmup", GraphRef.Params.warmup); + parseJsonAuto(Doc, "check-tensors", GraphRef.Params.check_tensors); + parseJsonWithCastAuto(Doc, "cache-type-k", + GraphRef.Params.cache_type_k); + parseJsonWithCastAuto(Doc, "cache-type-v", + GraphRef.Params.cache_type_v); + + parseJsonWithCastAuto(Doc, "embd-normalize", + GraphRef.Params.embd_normalize); + parseJsonWithCastAuto(Doc, "embd-out", + GraphRef.Params.embd_out); + + parseJsonWithCastAuto(Doc, "embd-sep", + GraphRef.Params.embd_sep); + parseJsonWithProcessorAuto( + Doc, "reranking", [&GraphRef](const bool &) -> bool { + GraphRef.Params.embedding = true; + GraphRef.Params.pooling_type = LLAMA_POOLING_TYPE_RANK; + return true; + }); + parseJsonWithCastAuto(Doc, "port", GraphRef.Params.port); + parseJsonWithCastAuto(Doc, "timeout-read", + GraphRef.Params.timeout_read); + parseJsonWithCastAuto(Doc, "timeout-write", + GraphRef.Params.timeout_write); + parseJsonWithCastAuto(Doc, "n-threads-http", + GraphRef.Params.n_threads_http); + parseJsonWithCastAuto(Doc, "n-cache-reuse", + GraphRef.Params.n_cache_reuse); + parseJsonWithCastAuto(Doc, "hostname", + GraphRef.Params.hostname); + parseJsonWithCastAuto(Doc, "public-path", + GraphRef.Params.public_path); + parseJsonWithCastAuto(Doc, "chat-template", + GraphRef.Params.chat_template); + parseJsonAuto(Doc, "enable-chat-template", + GraphRef.Params.enable_chat_template); + parseJsonWithCastAuto(Doc, "ssl-file-key", + GraphRef.Params.ssl_file_key); + parseJsonWithCastAuto(Doc, "ssl-file-cert", + GraphRef.Params.ssl_file_cert); + parseJsonAuto(Doc, "webui", GraphRef.Params.webui); + parseJsonAuto(Doc, "endpoint-slots", GraphRef.Params.endpoint_slots); + parseJsonAuto(Doc, "endpoint-props", GraphRef.Params.endpoint_props); + parseJsonAuto(Doc, "endpoint-metrics", + GraphRef.Params.endpoint_metrics); + parseJsonAuto(Doc, "log-json", GraphRef.Params.log_json); + + // Slot parameters + parseJsonWithCastAuto(Doc, "slot-save-path", + GraphRef.Params.slot_save_path); + + parseJsonWithCastAuto(Doc, "slot-prompt-similarity", + GraphRef.Params.slot_prompt_similarity); + parseJsonAuto(Doc, "is-pp-shared", GraphRef.Params.is_pp_shared); + parseJsonWithProcessorAuto( + Doc, "n-pp", [&GraphRef](const int64_t &NPP) -> bool { + GraphRef.Params.n_pp.emplace_back(NPP); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "n-tg", [&GraphRef](const int64_t &NTG) -> bool { + GraphRef.Params.n_tg.emplace_back(NTG); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "n-pl", [&GraphRef](const int64_t &NPL) -> bool { + GraphRef.Params.n_pl.emplace_back(NPL); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "context-files", + [&GraphRef](const std::string_view &ContextFile) -> bool { + GraphRef.Params.context_files.emplace_back(ContextFile); + return true; + }); + parseJsonWithCastAuto(Doc, "chunk-size", + GraphRef.Params.chunk_size); + parseJsonWithCastAuto(Doc, "chunk-separator", + GraphRef.Params.chunk_separator); + parseJsonWithCastAuto(Doc, "n-junk", GraphRef.Params.n_junk); + parseJsonWithCastAuto(Doc, "i-pos", GraphRef.Params.i_pos); + parseJsonWithCastAuto(Doc, "out-file", + GraphRef.Params.out_file); + parseJsonWithCastAuto(Doc, "n-out-freq", + GraphRef.Params.n_out_freq); + parseJsonWithCastAuto(Doc, "n-save-freq", + GraphRef.Params.n_save_freq); + parseJsonWithCastAuto(Doc, "i-chunk", GraphRef.Params.i_chunk); + parseJsonAuto(Doc, "process-output", GraphRef.Params.process_output); + parseJsonAuto(Doc, "compute-ppl", GraphRef.Params.compute_ppl); + parseJsonWithCastAuto(Doc, "n-pca-batch", + GraphRef.Params.n_pca_batch); + parseJsonWithCastAuto(Doc, "n-pca-iterations", + GraphRef.Params.n_pca_iterations); + parseJsonWithProcessorAuto( + Doc, "cvector-dimre-method", + [&GraphRef](const std::string_view &Method) -> bool { + if (Method == "DIMRE_METHOD_PCA") { + GraphRef.Params.cvector_dimre_method = + dimre_method::DIMRE_METHOD_PCA; + return true; + } + if (Method == "DIMRE_METHOD_MEAN") { + GraphRef.Params.cvector_dimre_method = + dimre_method::DIMRE_METHOD_MEAN; + return true; + } + return false; + }); + parseJsonWithCastAuto( + Doc, "cvector-positive-file", GraphRef.Params.cvector_positive_file); + parseJsonWithCastAuto( + Doc, "cvector-negative-file", GraphRef.Params.cvector_negative_file); + parseJsonAuto(Doc, "spm-infill", GraphRef.Params.spm_infill); + parseJsonWithCastAuto(Doc, "out-file", + GraphRef.Params.out_file); + parseJsonAuto(Doc, "batched-bench-output-jsonl", + GraphRef.Params.batched_bench_output_jsonl); + } catch (const ErrNo &Error) { + return Error; } - if (Doc.at_key("presence-penalty").error() == simdjson::SUCCESS) { - double PresencePenalty; - auto Err = Doc["presence-penalty"].get().get(PresencePenalty); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the presence-penalty option."sv) - } - GraphRef.Params.sampling.penalty_present = - static_cast(std::max(0.0, PresencePenalty)); - } - if (Doc.at_key("frequency-penalty").error() == simdjson::SUCCESS) { - double FrequencyPenalty; - auto Err = Doc["frequency-penalty"].get().get(FrequencyPenalty); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the frequency-penalty option."sv) - } - GraphRef.Params.sampling.penalty_freq = - static_cast(std::max(0.0, FrequencyPenalty)); - } - if (Doc.at_key("dry-multipier").error() == simdjson::SUCCESS) { - double DryMultiplier; - auto Err = Doc["dry-multipier"].get().get(DryMultiplier); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the dry-multipier option."sv) - } - GraphRef.Params.sampling.dry_multiplier = static_cast(DryMultiplier); - } - if (Doc.at_key("dry-base").error() == simdjson::SUCCESS) { - double DryBase; - auto Err = Doc["dry-base"].get().get(DryBase); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the dry-base option."sv) - } - GraphRef.Params.sampling.dry_base = static_cast(DryBase); - } - if (Doc.at_key("dry-allowed-length").error() == simdjson::SUCCESS) { - int64_t DryAllowedLength; - auto Err = Doc["dry-allowed-length"].get().get(DryAllowedLength); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the dry-allowed-length option."sv) - } - GraphRef.Params.sampling.dry_allowed_length = - static_cast(DryAllowedLength); - } - if (Doc.at_key("dry-last-n-penalty").error() == simdjson::SUCCESS) { - int64_t DryLastNPenalty; - auto Err = Doc["dry-last-n-penalty"].get().get(DryLastNPenalty); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the dry-last-n-penalty option."sv) - } - GraphRef.Params.sampling.penalty_last_n = - static_cast(DryLastNPenalty); - } - if (Doc.at_key("mirostat").error() == simdjson::SUCCESS) { - int64_t Mirostat; - auto Err = Doc["mirostat"].get().get(Mirostat); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the mirostat option."sv) - } - GraphRef.Params.sampling.mirostat = static_cast(Mirostat); - } - if (Doc.at_key("mirostat-eta").error() == simdjson::SUCCESS) { - double MirostatEta; - auto Err = Doc["mirostat-eta"].get().get(MirostatEta); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the mirostat-eta option."sv) - } - GraphRef.Params.sampling.mirostat_eta = static_cast(MirostatEta); - } - if (Doc.at_key("ignore-eos").error() == simdjson::SUCCESS) { - auto Err = - Doc["ignore-eos"].get().get(GraphRef.Params.sampling.ignore_eos); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the ignore-eos option."sv) - } - } - if (Doc.at_key("no-perf-sampling").error() == simdjson::SUCCESS) { - auto Err = Doc["no-perf-sampling"].get().get( - GraphRef.Params.sampling.no_perf); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the no-perf-sampling option."sv) - } - } - if (Doc.at_key("timing-per-token").error() == simdjson::SUCCESS) { - auto Err = Doc["timing-per-token"].get().get( - GraphRef.Params.sampling.timing_per_token); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the timing-per-token option."sv) - } - } - if (Doc.at_key("grammar").error() == simdjson::SUCCESS) { - std::string_view Grammar; - auto Err = Doc["grammar"].get().get(Grammar); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the grammar option."sv) - } - GraphRef.Params.sampling.grammar = Grammar; - } - if (Doc.at_key("json-schema").error() == simdjson::SUCCESS) { - std::string_view JsonSchema; - auto Err = Doc["json-schema"].get().get(JsonSchema); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the json-schema option."sv) - } - GraphRef.Params.sampling.grammar = - json_schema_to_grammar(nlohmann::ordered_json::parse(JsonSchema)); - } - if (Doc.at_key("seed").error() == simdjson::SUCCESS) { - uint64_t Seed; - auto Err = Doc["seed"].get().get(Seed); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the seed option."sv) - } - GraphRef.Params.sampling.seed = static_cast(Seed); - } - // The speculative parameters. - if (Doc.at_key("n-ctx-speculative").error() == simdjson::SUCCESS) { - int64_t NCtxSpeculative; - auto Err = Doc["n-ctx-speculative"].get().get(NCtxSpeculative); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-ctx-speculative option."sv) - } - GraphRef.Params.speculative.n_ctx = static_cast(NCtxSpeculative); - } - if (Doc.at_key("n-max-speculative").error() == simdjson::SUCCESS) { - int64_t NMaxSpeculative; - auto Err = Doc["n-max-speculative"].get().get(NMaxSpeculative); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-max-speculative option."sv) - } - GraphRef.Params.speculative.n_max = static_cast(NMaxSpeculative); - } - if (Doc.at_key("n-min-speculative").error() == simdjson::SUCCESS) { - int64_t NMinSpeculative; - auto Err = Doc["n-min-speculative"].get().get(NMinSpeculative); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-min-speculative option."sv) - } - GraphRef.Params.speculative.n_min = static_cast(NMinSpeculative); - } - if (Doc.at_key("n-gpu-layers-speculative").error() == simdjson::SUCCESS) { - int64_t NGPULatersinSpeculative; - auto Err = Doc["n-gpu-layers-speculative"].get().get( - NGPULatersinSpeculative); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-gpu-layers-speculative option."sv) - } - GraphRef.Params.speculative.n_gpu_layers = - static_cast(NGPULatersinSpeculative); - } - if (Doc.at_key("p-split-speculative").error() == simdjson::SUCCESS) { - double PSplitSpeculative; - auto Err = Doc["p-split-speculative"].get().get(PSplitSpeculative); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the p-split-speculative option."sv) - } - GraphRef.Params.speculative.p_split = static_cast(PSplitSpeculative); - } - if (Doc.at_key("p-min-speculative").error() == simdjson::SUCCESS) { - double PMinSpeculative; - auto Err = Doc["p-min-speculative"].get().get(PMinSpeculative); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the p-min-speculative option."sv) - } - GraphRef.Params.speculative.p_min = static_cast(PMinSpeculative); - } - // The vocoder parameters. - if (Doc.at_key("hf-repo-vocoder").error() == simdjson::SUCCESS) { - std::string_view HfRepo; - auto Err = Doc["hf-repo-vocoder"].get().get(HfRepo); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the hf-repo-vocoder option."sv) - } - GraphRef.Params.vocoder.model.hf_repo = HfRepo; - } - if (Doc.at_key("hf-file-vocoder").error() == simdjson::SUCCESS) { - std::string_view HfFile; - auto Err = Doc["hf-file-vocoder"].get().get(HfFile); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the hf-file-vocoder option."sv) - } - GraphRef.Params.vocoder.model.hf_file = HfFile; - } - if (Doc.at_key("model-url-vocoder").error() == simdjson::SUCCESS) { - std::string_view ModelUrlVocoder; - auto Err = - Doc["model-url-vocoder"].get().get(ModelUrlVocoder); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the model-url-vocoder option."sv) - } - GraphRef.Params.vocoder.model.url = ModelUrlVocoder; - } - // The config parameters. - if (Doc.at_key("stream-stdout").error() == simdjson::SUCCESS) { - auto Err = Doc["stream-stdout"].get().get(ConfRef.StreamStdout); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the stream-stdout option."sv) - } - } - if (Doc.at_key("n-predict").error() == simdjson::SUCCESS) { - auto Err = Doc["n-predict"].get().get(ConfRef.NPredict); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-predict option."sv) - } - } - if (Doc.at_key("reverse-prompt").error() == simdjson::SUCCESS) { - std::string_view ReversePrompt; - auto Err = Doc["reverse-prompt"].get().get(ReversePrompt); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the reverse-prompt option."sv) - } - ConfRef.ReversePrompt = ReversePrompt; - } - if (Doc.at_key("image").error() == simdjson::SUCCESS) { - std::string_view ImagePath; - auto Err = Doc["image"].get().get(ImagePath); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the image option."sv) - } - ConfRef.ImagePath = ImagePath; - } - if (Doc.at_key("always-regenerate-image-embd").error() == simdjson::SUCCESS) { - auto Err = Doc["always-regenerate-image-embd"].get().get( - ConfRef.AlwaysRegenerateImageEmbd); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the always-regenerate-image-embd option."sv) - } - } - if (Doc.at_key("model-alias").error() == simdjson::SUCCESS) { - std::string_view ModelAlias; - auto Err = Doc["model-alias"].get().get(ModelAlias); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the model-alias option."sv) - } - GraphRef.Params.model_alias = ModelAlias; - } - if (Doc.at_key("model-url").error() == simdjson::SUCCESS) { - std::string_view ModelUrl; - auto Err = Doc["model-url"].get().get(ModelUrl); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the model-url option."sv) - } - GraphRef.Params.model.url = ModelUrl; - } - if (Doc.at_key("hf-token").error() == simdjson::SUCCESS) { - std::string_view HfToken; - auto Err = Doc["hf-token"].get().get(HfToken); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the hf-token option."sv) - } - GraphRef.Params.hf_token = HfToken; - } - if (Doc.at_key("hf-repo").error() == simdjson::SUCCESS) { - std::string_view HfRepo; - auto Err = Doc["hf-repo"].get().get(HfRepo); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the hf-repo option."sv) - } - GraphRef.Params.model.hf_repo = HfRepo; - } - if (Doc.at_key("hf-file").error() == simdjson::SUCCESS) { - std::string_view HfFile; - auto Err = Doc["hf-file"].get().get(HfFile); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the hf-file option."sv) - } - GraphRef.Params.model.hf_file = HfFile; - } - if (Doc.at_key("prompt-file").error() == simdjson::SUCCESS) { - std::string_view PromptFile; - auto Err = Doc["prompt-file"].get().get(PromptFile); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the prompt-file option."sv) - } - GraphRef.Params.prompt_file = PromptFile; - } - if (Doc.at_key("path-prompt-cache").error() == simdjson::SUCCESS) { - std::string_view PathPromptCache; - auto Err = - Doc["path-prompt-cache"].get().get(PathPromptCache); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the path-prompt-cache option."sv) - } - GraphRef.Params.path_prompt_cache = PathPromptCache; - } - if (Doc.at_key("input-prefix").error() == simdjson::SUCCESS) { - std::string_view InputPrefix; - auto Err = Doc["input-prefix"].get().get(InputPrefix); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the input-prefix option."sv) - } - GraphRef.Params.input_prefix = InputPrefix; - } - if (Doc.at_key("input-suffix").error() == simdjson::SUCCESS) { - std::string_view InputSuffix; - auto Err = Doc["input-suffix"].get().get(InputSuffix); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the input-suffix option."sv) - } - GraphRef.Params.input_suffix = InputSuffix; - } - if (Doc.at_key("lookup-cache-static").error() == simdjson::SUCCESS) { - std::string_view LookupCacheStatic; - auto Err = Doc["lookup-cache-static"].get().get( - LookupCacheStatic); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the lookup-cache-static option."sv) - } - GraphRef.Params.lookup_cache_static = LookupCacheStatic; - } - if (Doc.at_key("lookup-cache-dynamic").error() == simdjson::SUCCESS) { - std::string_view LookupCacheDynamic; - auto Err = Doc["lookup-cache-dynamic"].get().get( - LookupCacheDynamic); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the lookup-cache-dynamic option."sv) - } - GraphRef.Params.lookup_cache_dynamic = LookupCacheDynamic; - } - if (Doc.at_key("logits-file").error() == simdjson::SUCCESS) { - std::string_view LogitsFile; - auto Err = Doc["logits-file"].get().get(LogitsFile); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the logits-file option."sv) - } - GraphRef.Params.logits_file = LogitsFile; - } - if (Doc.at_key("lora-init-without-apply").error() == simdjson::SUCCESS) { - auto Err = Doc["lora-init-without-apply"].get().get( - GraphRef.Params.lora_init_without_apply); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the lora-init-without-apply option."sv) - } - } - if (Doc.at_key("verbosity").error() == simdjson::SUCCESS) { - int64_t Verbosity; - auto Err = Doc["verbosity"].get().get(Verbosity); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the verbosity option."sv) - } - GraphRef.Params.verbosity = static_cast(Verbosity); - } - if (Doc.at_key("control-vector-layer-start").error() == simdjson::SUCCESS) { - int64_t ControlVectorLayerStart; - auto Err = Doc["control-vector-layer-start"].get().get( - ControlVectorLayerStart); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the control-vector-layer-start option."sv) - } - GraphRef.Params.control_vector_layer_start = - static_cast(ControlVectorLayerStart); - } - if (Doc.at_key("control-vector-layer-end").error() == simdjson::SUCCESS) { - int64_t ControlVectorLayerEnd; - auto Err = Doc["control-vector-layer-end"].get().get( - ControlVectorLayerEnd); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the control-vector-layer-end option."sv) - } - GraphRef.Params.control_vector_layer_end = - static_cast(ControlVectorLayerEnd); - } - if (Doc.at_key("ppl-stride").error() == simdjson::SUCCESS) { - int64_t PplStride; - auto Err = Doc["ppl-stride"].get().get(PplStride); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the ppl-stride option."sv) - } - GraphRef.Params.ppl_stride = static_cast(PplStride); - } - if (Doc.at_key("ppl-output-type").error() == simdjson::SUCCESS) { - int64_t PplOutputType; - auto Err = Doc["ppl-output-type"].get().get(PplOutputType); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the ppl-output-type option."sv) - } - GraphRef.Params.ppl_output_type = static_cast(PplOutputType); - } - if (Doc.at_key("hellaswag").error() == simdjson::SUCCESS) { - auto Err = Doc["hellaswag"].get().get(GraphRef.Params.hellaswag); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the hellaswag option."sv) - } - } - if (Doc.at_key("hellaswag-tasks").error() == simdjson::SUCCESS) { - uint64_t HellaswagTasks; - auto Err = Doc["hellaswag-tasks"].get().get(HellaswagTasks); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the hellaswag-tasks option."sv) - } - GraphRef.Params.hellaswag_tasks = HellaswagTasks; - } - if (Doc.at_key("winogrande").error() == simdjson::SUCCESS) { - auto Err = Doc["winogrande"].get().get(GraphRef.Params.winogrande); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the winogrande option."sv) - } - } - if (Doc.at_key("winogrande-tasks").error() == simdjson::SUCCESS) { - uint64_t WinograndeTasks; - auto Err = Doc["winogrande-tasks"].get().get(WinograndeTasks); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the winogrande-tasks option."sv) - } - GraphRef.Params.winogrande_tasks = WinograndeTasks; - } - if (Doc.at_key("multiple-choice").error() == simdjson::SUCCESS) { - auto Err = - Doc["multiple-choice"].get().get(GraphRef.Params.multiple_choice); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the multiple-choice option."sv) - } - } - if (Doc.at_key("multiple-choice-tasks").error() == simdjson::SUCCESS) { - uint64_t MultipleChoiceTasks; - auto Err = - Doc["multiple-choice-tasks"].get().get(MultipleChoiceTasks); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the multiple-choice-tasks option."sv) - } - GraphRef.Params.multiple_choice_tasks = MultipleChoiceTasks; - } - if (Doc.at_key("kl-divergence").error() == simdjson::SUCCESS) { - auto Err = - Doc["kl-divergence"].get().get(GraphRef.Params.kl_divergence); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the kl-divergence option."sv) - } - } - if (Doc.at_key("usage").error() == simdjson::SUCCESS) { - auto Err = Doc["usage"].get().get(GraphRef.Params.usage); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the usage option."sv) - } - } - if (Doc.at_key("use-color").error() == simdjson::SUCCESS) { - auto Err = Doc["use-color"].get().get(GraphRef.Params.use_color); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the use-color option."sv) - } - } - if (Doc.at_key("special").error() == simdjson::SUCCESS) { - auto Err = Doc["special"].get().get(GraphRef.Params.special); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the special option."sv) - } - } - if (Doc.at_key("interactive").error() == simdjson::SUCCESS) { - auto Err = Doc["interactive"].get().get(GraphRef.Params.interactive); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the interactive option."sv) - } - } - if (Doc.at_key("interactive-first").error() == simdjson::SUCCESS) { - auto Err = Doc["interactive-first"].get().get( - GraphRef.Params.interactive_first); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the interactive-first option."sv) - } - } - if (Doc.at_key("prompt-cache-all").error() == simdjson::SUCCESS) { - auto Err = Doc["prompt-cache-all"].get().get( - GraphRef.Params.prompt_cache_all); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the prompt-cache-all option."sv) - } - } - if (Doc.at_key("prompt-cache-ro").error() == simdjson::SUCCESS) { - auto Err = - Doc["prompt-cache-ro"].get().get(GraphRef.Params.prompt_cache_ro); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the prompt-cache-ro option."sv) - } - } - if (Doc.at_key("escape").error() == simdjson::SUCCESS) { - auto Err = Doc["escape"].get().get(GraphRef.Params.escape); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the escape option."sv) - } - } - if (Doc.at_key("multiline-input").error() == simdjson::SUCCESS) { - auto Err = - Doc["multiline-input"].get().get(GraphRef.Params.multiline_input); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the multiline-input option."sv) - } - } - if (Doc.at_key("simple-io").error() == simdjson::SUCCESS) { - auto Err = Doc["simple-io"].get().get(GraphRef.Params.simple_io); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the simple-io option."sv) - } - } - if (Doc.at_key("cont-batching").error() == simdjson::SUCCESS) { - auto Err = - Doc["cont-batching"].get().get(GraphRef.Params.cont_batching); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the cont-batching option."sv) - } - } - if (Doc.at_key("flash-attn").error() == simdjson::SUCCESS) { - std::string_view FlashAttn; - auto Err = Doc["flash-attn"].get().get(FlashAttn); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the flash-attn option."sv) - } - if (FlashAttn == "on"sv || FlashAttn == "enabled"sv) { - GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; - } else if (FlashAttn == "off"sv || FlashAttn == "disabled"sv) { - GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; - } else if (FlashAttn == "auto"sv) { - GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; - } else { - RET_ERROR(ErrNo::InvalidArgument, - "The flash-attn option must be one of: on, off, auto."sv) - } - } - if (Doc.at_key("no-perf").error() == simdjson::SUCCESS) { - auto Err = Doc["no-perf"].get().get(GraphRef.Params.no_perf); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the no-perf option."sv) - } - } - if (Doc.at_key("ctx-shift").error() == simdjson::SUCCESS) { - auto Err = Doc["ctx-shift"].get().get(GraphRef.Params.ctx_shift); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the ctx-shift option."sv) - } - } - if (Doc.at_key("input-prefix-bos").error() == simdjson::SUCCESS) { - auto Err = Doc["input-prefix-bos"].get().get( - GraphRef.Params.input_prefix_bos); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the input-prefix-bos option."sv) - } - } - if (Doc.at_key("use-mlock").error() == simdjson::SUCCESS) { - auto Err = Doc["use-mlock"].get().get(GraphRef.Params.use_mlock); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the use-mlock option."sv) - } - } - if (Doc.at_key("use-mmap").error() == simdjson::SUCCESS) { - auto Err = Doc["use-mmap"].get().get(GraphRef.Params.use_mmap); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the use-mmap option."sv) - } - } - if (Doc.at_key("verbose-prompt").error() == simdjson::SUCCESS) { - auto Err = - Doc["verbose-prompt"].get().get(GraphRef.Params.verbose_prompt); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the verbose-prompt option."sv) - } - } - if (Doc.at_key("display-prompt").error() == simdjson::SUCCESS) { - auto Err = - Doc["display-prompt"].get().get(GraphRef.Params.display_prompt); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the display-prompt option."sv) - } - } - if (Doc.at_key("no-kv-offload").error() == simdjson::SUCCESS) { - auto Err = - Doc["no-kv-offload"].get().get(GraphRef.Params.no_kv_offload); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the no-kv-offload option."sv) - } - } - if (Doc.at_key("warmup").error() == simdjson::SUCCESS) { - auto Err = Doc["warmup"].get().get(GraphRef.Params.warmup); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the warmup option."sv) - } - } - if (Doc.at_key("check-tensors").error() == simdjson::SUCCESS) { - auto Err = - Doc["check-tensors"].get().get(GraphRef.Params.check_tensors); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the check-tensors option."sv) - } - } - if (Doc.at_key("cache-type-k").error() == simdjson::SUCCESS) { - int64_t CacheTypeK; - auto Err = Doc["cache-type-k"].get().get(CacheTypeK); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the cache-type-k option."sv) - } - GraphRef.Params.cache_type_k = static_cast(CacheTypeK); - } - if (Doc.at_key("cache-type-v").error() == simdjson::SUCCESS) { - int64_t CacheTypeV; - auto Err = Doc["cache-type-v"].get().get(CacheTypeV); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the cache-type-v option."sv) - } - GraphRef.Params.cache_type_v = static_cast(CacheTypeV); - } - if (Doc.at_key("embd-normalize").error() == simdjson::SUCCESS) { - int64_t EmbdNormalize; - auto Err = Doc["embd-normalize"].get().get(EmbdNormalize); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the embd-normalize option."sv) - } - GraphRef.Params.embd_normalize = static_cast(EmbdNormalize); - } - if (Doc.at_key("embd-out").error() == simdjson::SUCCESS) { - std::string_view EmbdOut; - auto Err = Doc["embd-out"].get().get(EmbdOut); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the embd-out option."sv) - } - GraphRef.Params.embd_out = EmbdOut; - } - if (Doc.at_key("embd-sep").error() == simdjson::SUCCESS) { - std::string_view EmbdSep; - auto Err = Doc["embd-sep"].get().get(EmbdSep); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the embd-sep option."sv) - } - GraphRef.Params.embd_sep = EmbdSep; - } - if (Doc.at_key("reranking").error() == simdjson::SUCCESS) { - bool Reranking = false; - auto Err = Doc["reranking"].get().get(Reranking); - GraphRef.Params.embedding = true; - GraphRef.Params.pooling_type = LLAMA_POOLING_TYPE_RANK; - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the reranking option."sv) - } - } - if (Doc.at_key("port").error() == simdjson::SUCCESS) { - int64_t Port; - auto Err = Doc["port"].get().get(Port); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the port option."sv) - } - GraphRef.Params.port = static_cast(Port); - } - if (Doc.at_key("timeout-read").error() == simdjson::SUCCESS) { - int64_t TimeoutRead; - auto Err = Doc["timeout-read"].get().get(TimeoutRead); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the timeout-read option."sv) - } - GraphRef.Params.timeout_read = static_cast(TimeoutRead); - } - if (Doc.at_key("timeout-write").error() == simdjson::SUCCESS) { - int64_t TimeoutWrite; - auto Err = Doc["timeout-write"].get().get(TimeoutWrite); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the timeout-write option."sv) - } - GraphRef.Params.timeout_write = static_cast(TimeoutWrite); - } - if (Doc.at_key("n-threads-http").error() == simdjson::SUCCESS) { - int64_t NThreadsHttp; - auto Err = Doc["n-threads-http"].get().get(NThreadsHttp); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-threads-http option."sv) - } - GraphRef.Params.n_threads_http = static_cast(NThreadsHttp); - } - if (Doc.at_key("n-cache-reuse").error() == simdjson::SUCCESS) { - int64_t NCacheReuse; - auto Err = Doc["n-cache-reuse"].get().get(NCacheReuse); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-cache-reuse option."sv) - } - GraphRef.Params.n_cache_reuse = static_cast(NCacheReuse); - } - if (Doc.at_key("hostname").error() == simdjson::SUCCESS) { - std::string_view Hostname; - auto Err = Doc["hostname"].get().get(Hostname); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the hostname option."sv) - } - GraphRef.Params.hostname = Hostname; - } - if (Doc.at_key("public-path").error() == simdjson::SUCCESS) { - std::string_view PublicPath; - auto Err = Doc["public-path"].get().get(PublicPath); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the public-path option."sv) - } - GraphRef.Params.public_path = PublicPath; - } - if (Doc.at_key("chat-template").error() == simdjson::SUCCESS) { - std::string_view ChatTemplate; - auto Err = Doc["chat-template"].get().get(ChatTemplate); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the chat-template option."sv) - } - GraphRef.Params.chat_template = ChatTemplate; - } - if (Doc.at_key("enable-chat-template").error() == simdjson::SUCCESS) { - auto Err = Doc["enable-chat-template"].get().get( - GraphRef.Params.enable_chat_template); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the enable-chat-template option."sv) - } - } - if (Doc.at_key("ssl-file-key").error() == simdjson::SUCCESS) { - std::string_view SslFileKey; - auto Err = Doc["ssl-file-key"].get().get(SslFileKey); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the ssl-file-key option."sv) - } - GraphRef.Params.ssl_file_key = SslFileKey; - } - if (Doc.at_key("ssl-file-cert").error() == simdjson::SUCCESS) { - std::string_view SslFileCert; - auto Err = Doc["ssl-file-cert"].get().get(SslFileCert); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the ssl-file-cert option."sv) - } - GraphRef.Params.ssl_file_cert = SslFileCert; - } - if (Doc.at_key("webui").error() == simdjson::SUCCESS) { - auto Err = Doc["webui"].get().get(GraphRef.Params.webui); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the webui option."sv) - } - } - if (Doc.at_key("endpoint-slots").error() == simdjson::SUCCESS) { - int64_t EndpointSlots; - auto Err = Doc["endpoint-slots"].get().get(EndpointSlots); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the endpoint-slots option."sv) - } - GraphRef.Params.endpoint_slots = static_cast(EndpointSlots); - } - if (Doc.at_key("endpoint-props").error() == simdjson::SUCCESS) { - auto Err = - Doc["endpoint-props"].get().get(GraphRef.Params.endpoint_props); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the endpoint-props option."sv) - } - } - if (Doc.at_key("endpoint-metrics").error() == simdjson::SUCCESS) { - auto Err = Doc["endpoint-metrics"].get().get( - GraphRef.Params.endpoint_metrics); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the endpoint-metrics option."sv) - } - } - if (Doc.at_key("log-json").error() == simdjson::SUCCESS) { - auto Err = Doc["log-json"].get().get(GraphRef.Params.log_json); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the log-json option."sv) - } - } - if (Doc.at_key("slot-save-path").error() == simdjson::SUCCESS) { - std::string_view SlotSavePath; - auto Err = Doc["slot-save-path"].get().get(SlotSavePath); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the slot-save-path option."sv) - } - GraphRef.Params.slot_save_path = SlotSavePath; - } - if (Doc.at_key("slot-prompt-similarity").error() == simdjson::SUCCESS) { - double SlotPromptSimilarity; - auto Err = - Doc["slot-prompt-similarity"].get().get(SlotPromptSimilarity); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the slot-prompt-similarity option."sv) - } - GraphRef.Params.slot_prompt_similarity = - static_cast(SlotPromptSimilarity); - } - if (Doc.at_key("is-pp-shared").error() == simdjson::SUCCESS) { - auto Err = - Doc["is-pp-shared"].get().get(GraphRef.Params.is_pp_shared); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the is-pp-shared option."sv) - } - } - if (Doc.at_key("n-pp").error() == simdjson::SUCCESS) { - int64_t NPP; - auto Err = Doc["n-pp"].get().get(NPP); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the n-pp option."sv) - } - } - if (Doc.at_key("n-tg").error() == simdjson::SUCCESS) { - int64_t NTG; - auto Err = Doc["n-tg"].get().get(NTG); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the n-tg option."sv) - } - } - if (Doc.at_key("n-pl").error() == simdjson::SUCCESS) { - int64_t NPL; - auto Err = Doc["n-pl"].get().get(NPL); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, "Unable to retrieve the n-pl option."sv) - } - } - if (Doc.at_key("context-files").error() == simdjson::SUCCESS) { - std::string_view ContextFiles; - auto Err = Doc["context-files"].get().get(ContextFiles); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the context-files option."sv) - } - } - if (Doc.at_key("chunk-size").error() == simdjson::SUCCESS) { - int64_t ChunkSize; - auto Err = Doc["chunk-size"].get().get(ChunkSize); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the chunk-size option."sv) - } - } - if (Doc.at_key("chunk-separator").error() == simdjson::SUCCESS) { - std::string_view ChunkSeparator; - auto Err = - Doc["chunk-separator"].get().get(ChunkSeparator); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the chunk-separator option."sv) - } - } - if (Doc.at_key("n-junk").error() == simdjson::SUCCESS) { - int64_t NJunk; - auto Err = Doc["n-junk"].get().get(NJunk); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-junk option."sv) - } - } - if (Doc.at_key("i-pos").error() == simdjson::SUCCESS) { - int64_t IPos; - auto Err = Doc["i-pos"].get().get(IPos); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the i-pos option."sv) - } - } - if (Doc.at_key("out-file").error() == simdjson::SUCCESS) { - std::string_view OutFile; - auto Err = Doc["out-file"].get().get(OutFile); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the out-file option."sv) - } - } - if (Doc.at_key("n-out-freq").error() == simdjson::SUCCESS) { - int64_t NOutFreq; - auto Err = Doc["n-out-freq"].get().get(NOutFreq); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-out-freq option."sv) - } - } - if (Doc.at_key("n-save-freq").error() == simdjson::SUCCESS) { - int64_t NSaveFreq; - auto Err = Doc["n-save-freq"].get().get(NSaveFreq); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-save-freq option."sv) - } - } - if (Doc.at_key("i-chunk").error() == simdjson::SUCCESS) { - int64_t IChunk; - auto Err = Doc["i-chunk"].get().get(IChunk); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the i-chunk option."sv) - } - } - if (Doc.at_key("process-output").error() == simdjson::SUCCESS) { - auto Err = - Doc["process-output"].get().get(GraphRef.Params.process_output); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the process-output option."sv) - } - } - if (Doc.at_key("compute-ppl").error() == simdjson::SUCCESS) { - auto Err = Doc["compute-ppl"].get().get(GraphRef.Params.compute_ppl); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the compute-ppl option."sv) - } - } - if (Doc.at_key("n-pca-batch").error() == simdjson::SUCCESS) { - int64_t NPCABatch; - auto Err = Doc["n-pca-batch"].get().get(NPCABatch); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-pca-batch option."sv) - } - } - if (Doc.at_key("n-pca-iterations").error() == simdjson::SUCCESS) { - int64_t NPCAIterations; - auto Err = Doc["n-pca-iterations"].get().get(NPCAIterations); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the n-pca-iterations option."sv) - } - } - if (Doc.at_key("cvector-dimre-method").error() == simdjson::SUCCESS) { - std::string_view CVectorDimreMethod; - auto Err = Doc["cvector-dimre-method"].get().get( - CVectorDimreMethod); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the cvector-dimre-method option."sv) - } - } - if (Doc.at_key("cvector-outfile").error() == simdjson::SUCCESS) { - std::string_view CVectorOutfile; - auto Err = - Doc["cvector-outfile"].get().get(CVectorOutfile); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the cvector-outfile option."sv) - } - } - if (Doc.at_key("cvector-positive-file").error() == simdjson::SUCCESS) { - std::string_view CVectorPositiveFile; - auto Err = Doc["cvector-positive-file"].get().get( - CVectorPositiveFile); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the cvector-positive-file option."sv) - } - } - if (Doc.at_key("cvector-negative-file").error() == simdjson::SUCCESS) { - std::string_view CVectorNegativeFile; - auto Err = Doc["cvector-negative-file"].get().get( - CVectorNegativeFile); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the cvector-negative-file option."sv) - } - } - if (Doc.at_key("spm-infill").error() == simdjson::SUCCESS) { - auto Err = Doc["spm-infill"].get().get(GraphRef.Params.spm_infill); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the spm-infill option."sv) - } - } - if (Doc.at_key("out-file").error() == simdjson::SUCCESS) { - std::string_view Outfile; - auto Err = Doc["out-file"].get().get(Outfile); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the outfile option."sv) - } - GraphRef.Params.out_file = Outfile; - } - if (Doc.at_key("batched-bench-output-jsonl").error() == simdjson::SUCCESS) { - auto Err = Doc["batched-bench-output-jsonl"].get().get( - GraphRef.Params.batched_bench_output_jsonl); - if (Err) { - RET_ERROR(ErrNo::InvalidArgument, - "Unable to retrieve the batched-bench-output-jsonl option."sv) - } - } - // The tensor buffer overrides should terminated with empty pattern. if (!GraphRef.TensorBuftOverrides.empty()) { for (const std::string &Override : GraphRef.TensorBuftOverrides) { From 0cf844f677de0319e46c88c09e60bf06e6e9e775 Mon Sep 17 00:00:00 2001 From: cmd05 <63466463+cmd05@users.noreply.github.com> Date: Tue, 7 Oct 2025 18:19:32 +0530 Subject: [PATCH 587/623] feat(plugin): bump the dependencies of wasmedge-ffmpeg, ffmpeg to 7.1 (#4327) * feat: bump the dependencies of wasmedge-ffmpeg, ffmpeg to 7.1 Signed-off-by: cmd05 * build: bump the dependencies of wasmedge-ffmpeg, ffmpeg to 7.1 Signed-off-by: cmd05 --------- Signed-off-by: cmd05 --- plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp | 6 ++-- .../avcodec/avCodecContext.cpp | 22 ++++---------- .../wasmedge_ffmpeg/avcodec/avCodecContext.h | 8 ----- plugins/wasmedge_ffmpeg/avcodec/module.cpp | 3 -- plugins/wasmedge_ffmpeg/avutil/avFrame.cpp | 20 +++---------- plugins/wasmedge_ffmpeg/avutil/avFrame.h | 14 --------- .../wasmedge_ffmpeg/avutil/avutil_func.cpp | 29 +++++++++++++++---- plugins/wasmedge_ffmpeg/avutil/module.cpp | 4 --- plugins/wasmedge_ffmpeg/bindings.h | 19 ------------ .../swresample/swresample_func.cpp | 20 ++++++++++--- .../wasmedge_ffmpeg/avcodec/avCodecCtx.cpp | 18 ------------ .../wasmedge_ffmpeg/avutil/avFrame.cpp | 26 ----------------- ...ockerfile.manylinux2014-build-plugins-deps | 8 ++--- .../Dockerfile.manylinux_2_28-plugins-deps | 8 ++--- utils/docker/Dockerfile.ubuntu-plugins-deps | 12 ++++---- ...fmpeg-v6.0.sh => install-ffmpeg-v7.1.1.sh} | 6 ++-- 16 files changed, 69 insertions(+), 154 deletions(-) rename utils/ffmpeg/{install-ffmpeg-v6.0.sh => install-ffmpeg-v7.1.1.sh} (59%) mode change 100755 => 100644 diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp b/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp index 3ba24d47..565a161a 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodec.cpp @@ -179,7 +179,7 @@ AVCodecSupportedSampleRatesIter::body(const Runtime::CallingFrame &, Expect AVCodecChannelLayoutIsNull::body(const Runtime::CallingFrame &, uint32_t AvCodecId) { FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); - if (AvCodec->channel_layouts == nullptr) { + if (AvCodec->ch_layouts == nullptr) { return 1; } return 0; @@ -190,7 +190,7 @@ Expect AVCodecChannelLayoutIter::body(const Runtime::CallingFrame &, uint32_t Idx) { FFMPEG_PTR_FETCH(AvCodec, AvCodecId, const AVCodec); - const uint64_t *ChannelLayout = AvCodec->channel_layouts; + const AVChannelLayout *ChannelLayout = AvCodec->ch_layouts; if (ChannelLayout == nullptr) { return 0; } @@ -201,7 +201,7 @@ Expect AVCodecChannelLayoutIter::body(const Runtime::CallingFrame &, Curr++; } - return FFmpegUtils::ChannelLayout::intoChannelLayoutID(*ChannelLayout); + return FFmpegUtils::ChannelLayout::intoChannelLayoutID(ChannelLayout->u.mask); } Expect AVCodecSampleFmtsIsNull::body(const Runtime::CallingFrame &, diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp index ad6de103..c50316fa 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.cpp @@ -120,7 +120,7 @@ Expect AVCodecCtxChannelLayout::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); // Deprecated method - uint64_t const AvChannel = AvCodecCtx->channel_layout; + uint64_t const AvChannel = AvCodecCtx->ch_layout.u.mask; return FFmpegUtils::ChannelLayout::intoChannelLayoutID(AvChannel); } @@ -130,7 +130,7 @@ Expect AVCodecCtxSetChannelLayout::body(const Runtime::CallingFrame &, FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); uint64_t const AvChannel = FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); - AvCodecCtx->channel_layout = AvChannel; + av_channel_layout_from_mask(&AvCodecCtx->ch_layout, AvChannel); return static_cast(ErrNo::Success); } @@ -566,14 +566,14 @@ Expect AVCodecCtxCodec::body(const Runtime::CallingFrame &Frame, Expect AVCodecCtxChannels::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); - return AvCodecCtx->channels; + return AvCodecCtx->ch_layout.nb_channels; } Expect AVCodecCtxSetChannels::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Channels) { FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); - AvCodecCtx->channels = Channels; + AvCodecCtx->ch_layout.nb_channels = Channels; return static_cast(ErrNo::Success); } @@ -658,7 +658,7 @@ Expect AVCodecCtxSetSliceCount::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId, int32_t Value) { FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); - AvCodecCtx->slice_count = Value; + AvCodecCtx->slices = Value; return static_cast(ErrNo::Success); } @@ -687,7 +687,7 @@ AVCodecCtxChromaSampleLocation::body(const Runtime::CallingFrame &, Expect AVCodecCtxFrameNumber::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); - return AvCodecCtx->frame_number; + return AvCodecCtx->frame_num; } Expect AVCodecCtxBlockAlign::body(const Runtime::CallingFrame &, @@ -720,16 +720,6 @@ Expect AVCodecCtxHasBFrames::body(const Runtime::CallingFrame &, return AvCodecCtx->has_b_frames; } -Expect -AVCodecCtxSetRequestChannelLayout::body(const Runtime::CallingFrame &, - uint32_t AvCodecCtxId, - uint64_t ChannelLayoutId) { - FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); - AvCodecCtx->request_channel_layout = - FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); - return static_cast(ErrNo::Success); -} - Expect AVCodecCtxActiveThreadType::body(const Runtime::CallingFrame &, uint32_t AvCodecCtxId) { FFMPEG_PTR_FETCH(AvCodecCtx, AvCodecCtxId, AVCodecContext); diff --git a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h index 88d86558..d1696074 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h +++ b/plugins/wasmedge_ffmpeg/avcodec/avCodecContext.h @@ -636,14 +636,6 @@ class AVCodecCtxHasBFrames : public HostFunction { uint32_t AvCodecCtxId); }; -class AVCodecCtxSetRequestChannelLayout - : public HostFunction { -public: - using HostFunction::HostFunction; - Expect body(const Runtime::CallingFrame &Frame, - uint32_t AvCodecCtxId, uint64_t ChannelLayoutId); -}; - class AVCodecCtxActiveThreadType : public HostFunction { public: diff --git a/plugins/wasmedge_ffmpeg/avcodec/module.cpp b/plugins/wasmedge_ffmpeg/avcodec/module.cpp index 5f082ba8..ddae5636 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/module.cpp +++ b/plugins/wasmedge_ffmpeg/avcodec/module.cpp @@ -250,9 +250,6 @@ WasmEdgeFFmpegAVCodecModule::WasmEdgeFFmpegAVCodecModule( std::make_unique(Env)); addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_has_b_frames", std::make_unique(Env)); - addHostFunc( - "wasmedge_ffmpeg_avcodec_avcodeccontext_set_request_channel_layout", - std::make_unique(Env)); addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_active_thread_type", std::make_unique(Env)); addHostFunc("wasmedge_ffmpeg_avcodec_avcodeccontext_set_thread_type", diff --git a/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp b/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp index 2dc1449b..6f0fab2d 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp @@ -137,7 +137,7 @@ Expect AVFrameSetChannelLayout::body(const Runtime::CallingFrame &, FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); uint64_t const ChannelLayout = FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutID); - AvFrame->channel_layout = ChannelLayout; + av_channel_layout_from_mask(&AvFrame->ch_layout, ChannelLayout); return static_cast(ErrNo::Success); } @@ -171,20 +171,20 @@ Expect AVFrameSetSampleRate::body(const Runtime::CallingFrame &, Expect AVFrameChannels::body(const Runtime::CallingFrame &, uint32_t FrameId) { FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); - return AvFrame->channels; + return AvFrame->ch_layout.nb_channels; } Expect AVFrameSetChannels::body(const Runtime::CallingFrame &, uint32_t FrameId, int32_t Channels) { FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); - AvFrame->channels = Channels; + AvFrame->ch_layout.nb_channels = Channels; return static_cast(ErrNo::Success); } Expect AVFrameChannelLayout::body(const Runtime::CallingFrame &, uint32_t FrameId) { FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); - uint64_t const ChannelLayout = AvFrame->channel_layout; + uint64_t const ChannelLayout = AvFrame->ch_layout.u.mask; return FFmpegUtils::ChannelLayout::intoChannelLayoutID(ChannelLayout); } @@ -286,18 +286,6 @@ Expect AVFrameChromaLocation::body(const Runtime::CallingFrame &, return FFmpegUtils::ChromaLocation::fromAVChromaLocation(AvChromaLocation); } -Expect AVFrameCodedPictureNumber::body(const Runtime::CallingFrame &, - uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); - return AvFrame->coded_picture_number; -} - -Expect AVFrameDisplayPictureNumber::body(const Runtime::CallingFrame &, - uint32_t FrameId) { - FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); - return AvFrame->display_picture_number; -} - Expect AVFrameRepeatPict::body(const Runtime::CallingFrame &, uint32_t FrameId) { FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); diff --git a/plugins/wasmedge_ffmpeg/avutil/avFrame.h b/plugins/wasmedge_ffmpeg/avutil/avFrame.h index 9e4a3639..f8c87458 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avFrame.h +++ b/plugins/wasmedge_ffmpeg/avutil/avFrame.h @@ -241,20 +241,6 @@ class AVFrameChromaLocation : public HostFunction { Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); }; -class AVFrameCodedPictureNumber - : public HostFunction { -public: - using HostFunction::HostFunction; - Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); -}; - -class AVFrameDisplayPictureNumber - : public HostFunction { -public: - using HostFunction::HostFunction; - Expect body(const Runtime::CallingFrame &Frame, uint32_t FrameId); -}; - class AVFrameRepeatPict : public HostFunction { public: using HostFunction::HostFunction; diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp b/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp index 6dfeaf01..673a3e77 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp @@ -59,15 +59,25 @@ AVGetChannelLayoutNbChannels::body(const Runtime::CallingFrame &, uint64_t ChannelLayoutId) { uint64_t const ChannelLayout = FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); - return av_get_channel_layout_nb_channels(ChannelLayout); + + AVChannelLayout TmpChLayout; + av_channel_layout_from_mask(&TmpChLayout, ChannelLayout); + int32_t ChannelLayoutNbChannels = TmpChLayout.nb_channels; + av_channel_layout_uninit(&TmpChLayout); + + return ChannelLayoutNbChannels; } Expect AVGetChannelLayoutNameLen::body(const Runtime::CallingFrame &, uint64_t ChannelLayoutId) { uint64_t const ChannelLayout = FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); - const char *ChName = av_get_channel_name(ChannelLayout); - if (ChName == nullptr) { + char ChName[16] = {0}; // bufsize based on AVChannelCustom.name + // mask ChannelLayout to AVChannel before passing + int Code = + av_channel_name(ChName, 16, static_cast(ChannelLayout >> 1)); + + if (Code < 0) { return 0; } return strlen(ChName); @@ -82,7 +92,9 @@ Expect AVGetChannelLayoutName::body(const Runtime::CallingFrame &Frame, uint64_t const ChannelLayout = FFmpegUtils::ChannelLayout::fromChannelLayoutID(ChannelLayoutId); - const char *ChName = av_get_channel_name(ChannelLayout); + char ChName[16] = {0}; // bufsize based on AVChannelCustom.name + // mask ChannelLayout to AVChannel before passing + av_channel_name(ChName, 16, static_cast(ChannelLayout >> 1)); std::copy_n(ChName, NameLen, NameBuf.data()); return static_cast(ErrNo::Success); @@ -97,8 +109,13 @@ Expect AVGetChannelLayoutMask::body(const Runtime::CallingFrame &, Expect AVGetDefaultChannelLayout::body(const Runtime::CallingFrame &, int32_t Number) { - uint64_t const ChannelLayout = av_get_default_channel_layout(Number); - return FFmpegUtils::ChannelLayout::intoChannelLayoutID(ChannelLayout); + AVChannelLayout TmpChLayout; + av_channel_layout_default(&TmpChLayout, Number); + uint64_t DefaultChannelLayout = + FFmpegUtils::ChannelLayout::intoChannelLayoutID(TmpChLayout.u.mask); + av_channel_layout_uninit(&TmpChLayout); + + return DefaultChannelLayout; } Expect AVUtilConfigurationLength::body(const Runtime::CallingFrame &) { diff --git a/plugins/wasmedge_ffmpeg/avutil/module.cpp b/plugins/wasmedge_ffmpeg/avutil/module.cpp index 2dd207c3..948ff3ec 100644 --- a/plugins/wasmedge_ffmpeg/avutil/module.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/module.cpp @@ -114,10 +114,6 @@ WasmEdgeFFmpegAVUtilModule::WasmEdgeFFmpegAVUtilModule( std::make_unique(Env)); addHostFunc("wasmedge_ffmpeg_avutil_av_frame_chroma_location", std::make_unique(Env)); - addHostFunc("wasmedge_ffmpeg_avutil_av_frame_coded_picture_number", - std::make_unique(Env)); - addHostFunc("wasmedge_ffmpeg_avutil_av_frame_display_picture_number", - std::make_unique(Env)); addHostFunc("wasmedge_ffmpeg_avutil_av_frame_repeat_pict", std::make_unique(Env)); addHostFunc("wasmedge_ffmpeg_avutil_av_frame_flags", diff --git a/plugins/wasmedge_ffmpeg/bindings.h b/plugins/wasmedge_ffmpeg/bindings.h index ac8fb6fb..22b21370 100644 --- a/plugins/wasmedge_ffmpeg/bindings.h +++ b/plugins/wasmedge_ffmpeg/bindings.h @@ -460,8 +460,6 @@ class CodecID { return AV_CODEC_ID_012V; case 197: return AV_CODEC_ID_AVUI; - case 198: - return AV_CODEC_ID_AYUV; case 199: return AV_CODEC_ID_TARGA_Y216; case 200: @@ -1516,8 +1514,6 @@ class CodecID { return 196; case AV_CODEC_ID_AVUI: return 197; - case AV_CODEC_ID_AYUV: - return 198; case AV_CODEC_ID_TARGA_Y216: return 199; case AV_CODEC_ID_V308: @@ -2518,8 +2514,6 @@ class PixFmt { return 172; case AV_PIX_FMT_VIDEOTOOLBOX: return 173; - case AV_PIX_FMT_XVMC: - return 174; // case AV_PIX_FMT_RGB32: // IF format is this type, based on // endianness, it resolves to big endian or small endian. // return 175; // The Switch case contains both big and @@ -3123,8 +3117,6 @@ class PixFmt { return AV_PIX_FMT_AYUV64BE; case 173: return AV_PIX_FMT_VIDEOTOOLBOX; - case 174: - return AV_PIX_FMT_XVMC; case 175: return AV_PIX_FMT_RGB32; case 176: @@ -3483,7 +3475,6 @@ class ChannelLayout { const static uint64_t SURROUND_DIRECT_LEFT = 1ULL << 22; const static uint64_t SURROUND_DIRECT_RIGHT = 1ULL << 23; const static uint64_t LOW_FREQUENCY_2 = 1ULL << 24; - const static uint64_t NATIVE = 1ULL << 25; const static uint64_t MONO = 1ULL << 26; const static uint64_t STEREO = 1ULL << 27; @@ -3593,9 +3584,6 @@ class ChannelLayout { if (ChannelLayout & LOW_FREQUENCY_2) { Channel |= AV_CH_LOW_FREQUENCY_2; } - if (ChannelLayout & NATIVE) { - Channel |= AV_CH_LAYOUT_NATIVE; - } if (ChannelLayout & MONO) { Channel |= AV_CH_LAYOUT_MONO; } @@ -3767,9 +3755,6 @@ class ChannelLayout { } // Channel Mask C; - if ((ChannelLayout & AV_CH_LAYOUT_NATIVE) == AV_CH_LAYOUT_NATIVE) { - Channel |= NATIVE; - } if ((ChannelLayout & AV_CH_LAYOUT_MONO) == AV_CH_LAYOUT_MONO) { Channel |= MONO; } @@ -4117,8 +4102,6 @@ class OptionType { return AV_OPT_TYPE_DURATION; case 15: return AV_OPT_TYPE_COLOR; - case 16: - return AV_OPT_TYPE_CHANNEL_LAYOUT; case 17: return AV_OPT_TYPE_UINT64; case 18: @@ -4164,8 +4147,6 @@ class OptionType { return 14; case AV_OPT_TYPE_COLOR: return 15; - case AV_OPT_TYPE_CHANNEL_LAYOUT: - return 16; case AV_OPT_TYPE_UINT64: return 17; case AV_OPT_TYPE_BOOL: diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp b/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp index 3c81578c..f26e33ca 100644 --- a/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp +++ b/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp @@ -49,10 +49,22 @@ SWRAllocSetOpts::body(const Runtime::CallingFrame &Frame, uint32_t SwrCtxPtr, FFmpegUtils::ChannelLayout::fromChannelLayoutID(InChLayoutId); AVSampleFormat const InSampleFmt = FFmpegUtils::SampleFmt::fromSampleID(InSampleFmtId); - CurrSwrCtx = swr_alloc_set_opts( - ExistSWRContext, OutChLayout, OutSampleFmt, OutSampleRate, InChLayout, - InSampleFmt, InSampleRate, LogOffset, - nullptr); // Always being used as null in rust sdk. + + AVChannelLayout AVOutChLayout; + av_channel_layout_from_mask(&AVOutChLayout, OutChLayout); + + AVChannelLayout AVInChLayout; + av_channel_layout_from_mask(&AVInChLayout, InChLayout); + + swr_alloc_set_opts2(&ExistSWRContext, &AVOutChLayout, OutSampleFmt, + OutSampleRate, &AVInChLayout, InSampleFmt, InSampleRate, + LogOffset, + nullptr); // Always being used as null in rust sdk. + CurrSwrCtx = ExistSWRContext; + + av_channel_layout_uninit(&AVOutChLayout); + av_channel_layout_uninit(&AVInChLayout); + FFMPEG_PTR_STORE(CurrSwrCtx, SwrCtxId); return static_cast(ErrNo::Success); } diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp index d7adca94..3407f4e7 100644 --- a/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avCodecCtx.cpp @@ -1553,24 +1553,6 @@ TEST_F(FFmpegTest, AVCodecCtx) { ASSERT_TRUE(Result[0].get() > 0); } - FuncInst = AVCodecMod->findFuncExports( - "wasmedge_ffmpeg_avcodec_avcodeccontext_set_request_channel_layout"); - EXPECT_NE(FuncInst, nullptr); - EXPECT_TRUE(FuncInst->isHostFunction()); - - auto &HostFuncAVCodecCtxSetRequestChannelLayout = dynamic_cast< - WasmEdge::Host::WasmEdgeFFmpeg::AVcodec::AVCodecCtxSetRequestChannelLayout - &>(FuncInst->getHostFunc()); - - { - EXPECT_TRUE(HostFuncAVCodecCtxSetRequestChannelLayout.run( - CallFrame, - std::initializer_list{AVCodecCtxId, - ChannelLayoutId}, - Result)); - EXPECT_EQ(Result[0].get(), static_cast(ErrNo::Success)); - } - FuncInst = AVCodecMod->findFuncExports( "wasmedge_ffmpeg_avcodec_avcodeccontext_active_thread_type"); EXPECT_NE(FuncInst, nullptr); diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp index b30fe2ab..dd20f625 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp @@ -266,32 +266,6 @@ TEST_F(FFmpegTest, AVFrame) { EXPECT_EQ(Result[0].get(), 1); } - FuncInst = AVUtilMod->findFuncExports( - "wasmedge_ffmpeg_avutil_av_frame_coded_picture_number"); - auto &HostAVFrameCodedPictureNumber = dynamic_cast< - WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameCodedPictureNumber &>( - FuncInst->getHostFunc()); - - { - HostAVFrameCodedPictureNumber.run( - CallFrame, std::initializer_list{AVFrameId}, - Result); - EXPECT_EQ(Result[0].get(), 0); - } - - FuncInst = AVUtilMod->findFuncExports( - "wasmedge_ffmpeg_avutil_av_frame_display_picture_number"); - auto &HostAVFrameDisplayPictureNumber = dynamic_cast< - WasmEdge::Host::WasmEdgeFFmpeg::AVUtil::AVFrameDisplayPictureNumber &>( - FuncInst->getHostFunc()); - - { - HostAVFrameDisplayPictureNumber.run( - CallFrame, std::initializer_list{AVFrameId}, - Result); - EXPECT_EQ(Result[0].get(), 0); - } - FuncInst = AVUtilMod->findFuncExports("wasmedge_ffmpeg_avutil_av_frame_repeat_pict"); auto &HostAVFrameRepeatPict = diff --git a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps index 1a8213e1..aa479a0d 100644 --- a/utils/docker/Dockerfile.manylinux2014-build-plugins-deps +++ b/utils/docker/Dockerfile.manylinux2014-build-plugins-deps @@ -27,10 +27,10 @@ COPY wasi-crypto/build-openssl.sh . ENV OPENSSL_ROOT_DIR "/root/openssl-1.1.1n/openssl" RUN [ "/bin/bash", "build-openssl.sh" ] -COPY ffmpeg/install-ffmpeg-v6.0.sh . -RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] -ENV PKG_CONFIG_PATH /root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -ENV LD_LIBRARY_PATH /root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} +COPY ffmpeg/install-ffmpeg-v7.1.1.sh . +RUN [ "/bin/bash", "install-ffmpeg-v7.1.1.sh" ] +ENV PKG_CONFIG_PATH /root/FFmpeg-n7.1.1/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV LD_LIBRARY_PATH /root/FFmpeg-n7.1.1/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} ENV OPENVINO_VERSION "2025.0.0" ENV OPENVINO_YEAR "2025" diff --git a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps index 484ee476..3249c688 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps @@ -30,10 +30,10 @@ COPY wasi-crypto/build-openssl.sh . ENV OpenSSL_DIR="/root/openssl-1.1.1n/openssl" RUN [ "/bin/bash", "build-openssl.sh" ] -COPY ffmpeg/install-ffmpeg-v6.0.sh . -RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] -ENV PKG_CONFIG_PATH=/root/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -ENV LD_LIBRARY_PATH=/root/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} +COPY ffmpeg/install-ffmpeg-v7.1.1.sh . +RUN [ "/bin/bash", "install-ffmpeg-v7.1.1.sh" ] +ENV PKG_CONFIG_PATH=/root/FFmpeg-n7.1.1/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV LD_LIBRARY_PATH=/root/FFmpeg-n7.1.1/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} ENV OPENVINO_VERSION="2025.0.0" ENV OPENVINO_YEAR="2025" diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index 4aaaffcc..460b0170 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -30,15 +30,15 @@ RUN apt-get install -y \ libswresample-dev \ libswscale-dev -# FFmpeg 6.0 (ubuntu 20.04, 22.04) +# FFmpeg 7.1.1 (ubuntu 20.04, 22.04) FROM base AS deps-20 WORKDIR /usr/local -COPY ffmpeg/install-ffmpeg-v6.0.sh . -RUN [ "/bin/bash", "install-ffmpeg-v6.0.sh" ] -ENV PKG_CONFIG_PATH=/usr/local/FFmpeg-n6.0/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} -ENV LD_LIBRARY_PATH=/usr/local/FFmpeg-n6.0/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} +COPY ffmpeg/install-ffmpeg-v7.1.1.sh . +RUN [ "/bin/bash", "install-ffmpeg-v7.1.1.sh" ] +ENV PKG_CONFIG_PATH=/usr/local/FFmpeg-n7.1.1/output/lib/pkgconfig${PKG_CONFIG_PATH:+:${PKG_CONFIG_PATH}} +ENV LD_LIBRARY_PATH=/usr/local/FFmpeg-n7.1.1/output/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}} FROM deps-20 AS deps-22 @@ -83,7 +83,7 @@ FROM deps-all AS clean-apt RUN rm -f \ install-opencvmini.sh \ - install-ffmpeg-v6.0.sh \ + install-ffmpeg-v7.1.1.sh \ install-pytorch.sh \ install-openvino.sh \ install-onnxruntime.sh \ diff --git a/utils/ffmpeg/install-ffmpeg-v6.0.sh b/utils/ffmpeg/install-ffmpeg-v7.1.1.sh old mode 100755 new mode 100644 similarity index 59% rename from utils/ffmpeg/install-ffmpeg-v6.0.sh rename to utils/ffmpeg/install-ffmpeg-v7.1.1.sh index d02dc733..4e049948 --- a/utils/ffmpeg/install-ffmpeg-v6.0.sh +++ b/utils/ffmpeg/install-ffmpeg-v7.1.1.sh @@ -1,12 +1,12 @@ #!/usr/bin/env bash set -e -curl -sL https://github.com/FFmpeg/FFmpeg/archive/refs/tags/n6.0.zip -o ffmpeg.zip +curl -sL https://github.com/FFmpeg/FFmpeg/archive/refs/tags/n7.1.1.zip -o ffmpeg.zip unzip ffmpeg.zip -mkdir -p FFmpeg-n6.0/output -cd FFmpeg-n6.0 +mkdir -p FFmpeg-n7.1.1/output +cd FFmpeg-n7.1.1 ./configure --prefix=$(pwd)/output --enable-gpl --enable-nonfree --enable-shared --disable-static make && make install cd .. From 352acdcda735a60877265361e429d94f7777ea57 Mon Sep 17 00:00:00 2001 From: Shen-Ta Hsieh Date: Sun, 12 Oct 2025 05:53:37 +0000 Subject: [PATCH 588/623] fix(WASI-NN): use binary mode for binary file Signed-off-by: Shen-Ta Hsieh --- plugins/wasi_nn/wasinn_ggml.cpp | 9 ++++++--- plugins/wasi_nn/wasinn_mlx.cpp | 4 +++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp index 356179e0..52575667 100644 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ b/plugins/wasi_nn/wasinn_ggml.cpp @@ -1493,7 +1493,9 @@ ErrNo codesToSpeech(WasiNNEnvironment &Env, Graph &GraphRef, // Save .wav file if path is provided. if (!GraphRef.TTSOutputFilePath.empty()) { - WasmEdge::FStream::OFStream File(GraphRef.TTSOutputFilePath, Env.getEnv()); + WasmEdge::FStream::OFStream File(GraphRef.TTSOutputFilePath, + std::ios_base::out | std::ios_base::binary, + Env.getEnv()); if (!File) { RET_ERROR(ErrNo::RuntimeError, "codesToSpeech: Failed to open file '{}' for writing"sv, @@ -1581,8 +1583,9 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // TODO: pass the model directly to ggml. // Write ggml model to file. GraphRef.Params.model.path = "ggml-model.bin"sv; - WasmEdge::FStream::OFStream TempFile(GraphRef.Params.model.path, - Env.getEnv()); + WasmEdge::FStream::OFStream TempFile( + GraphRef.Params.model.path, std::ios_base::out | std::ios_base::binary, + Env.getEnv()); if (!TempFile) { Env.deleteGraph(GId.raw()); RET_ERROR(ErrNo::InvalidArgument, diff --git a/plugins/wasi_nn/wasinn_mlx.cpp b/plugins/wasi_nn/wasinn_mlx.cpp index 66d2a43a..4ca60e58 100644 --- a/plugins/wasi_nn/wasinn_mlx.cpp +++ b/plugins/wasi_nn/wasinn_mlx.cpp @@ -341,7 +341,9 @@ Expect load(WASINN::WasiNNEnvironment &Env, // Write model to file. // TODO: handle different model format. ModelFilePath = "MLX" + std::to_string(Idx) + ".safetensors"; - WasmEdge::FStream::OFStream TempFile(ModelFilePath, Env.getEnv()); + WasmEdge::FStream::OFStream TempFile( + ModelFilePath, std::ios_base::out | std::ios_base::binary, + Env.getEnv()); if (!TempFile) { spdlog::error( "[WASI-NN] MLX backend: Failed to create the temporary file. "sv); From 1ba45514f64511439a9c96843ef945a2c6a35725 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Tue, 14 Oct 2025 19:39:58 +0800 Subject: [PATCH 589/623] refactor(WASI-NN/GGML): refactor function structure (#4383) Signed-off-by: grorge --- plugins/wasi_nn/CMakeLists.txt | 16 +- .../wasi_nn/GGML/compute/compute_engine.cpp | 144 ++ .../GGML/compute/inference_manager.cpp | 265 ++ .../wasi_nn/GGML/compute/inference_manager.h | 16 + plugins/wasi_nn/GGML/core/ggml_core.cpp | 376 +++ plugins/wasi_nn/GGML/core/ggml_core.h | 47 + .../{wasinn_ggml.h => GGML/core/ggml_type.h} | 72 +- plugins/wasi_nn/GGML/core/input_processor.cpp | 294 +++ .../wasi_nn/GGML/core/output_generator.cpp | 72 + .../wasi_nn/GGML/metadata/metadata_parser.cpp | 552 ++++ .../wasi_nn/GGML/metadata/metadata_parser.h | 77 + plugins/wasi_nn/GGML/tts/tts_core.cpp | 480 ++++ plugins/wasi_nn/GGML/tts/tts_core.h | 50 + plugins/wasi_nn/GGML/utils.cpp | 86 + plugins/wasi_nn/GGML/utils.h | 23 + plugins/wasi_nn/wasinn_ggml.cpp | 2289 ----------------- plugins/wasi_nn/wasinnenv.h | 2 +- 17 files changed, 2529 insertions(+), 2332 deletions(-) create mode 100644 plugins/wasi_nn/GGML/compute/compute_engine.cpp create mode 100644 plugins/wasi_nn/GGML/compute/inference_manager.cpp create mode 100644 plugins/wasi_nn/GGML/compute/inference_manager.h create mode 100644 plugins/wasi_nn/GGML/core/ggml_core.cpp create mode 100644 plugins/wasi_nn/GGML/core/ggml_core.h rename plugins/wasi_nn/{wasinn_ggml.h => GGML/core/ggml_type.h} (58%) create mode 100644 plugins/wasi_nn/GGML/core/input_processor.cpp create mode 100644 plugins/wasi_nn/GGML/core/output_generator.cpp create mode 100644 plugins/wasi_nn/GGML/metadata/metadata_parser.cpp create mode 100644 plugins/wasi_nn/GGML/metadata/metadata_parser.h create mode 100644 plugins/wasi_nn/GGML/tts/tts_core.cpp create mode 100644 plugins/wasi_nn/GGML/tts/tts_core.h create mode 100644 plugins/wasi_nn/GGML/utils.cpp create mode 100644 plugins/wasi_nn/GGML/utils.h delete mode 100644 plugins/wasi_nn/wasinn_ggml.cpp diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index faba2ce1..2770a556 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -12,7 +12,7 @@ wasmedge_add_library(wasmedgePluginWasiNN wasinn_tf.cpp wasinn_torch.cpp wasinn_tfl.cpp - wasinn_ggml.cpp + GGML/core/ggml_core.cpp wasinn_neuralspeed.cpp wasinn_piper.cpp wasinn_whisper.cpp @@ -72,6 +72,20 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) MLX/mlx/quantized.cpp ) endif() + + if(BACKEND STREQUAL "ggml") + target_sources(wasmedgePluginWasiNN + PRIVATE + GGML/core/ggml_core.cpp + GGML/core/input_processor.cpp + GGML/core/output_generator.cpp + GGML/metadata/metadata_parser.cpp + GGML/compute/compute_engine.cpp + GGML/compute/inference_manager.cpp + GGML/tts/tts_core.cpp + GGML/utils.cpp + ) + endif() endforeach() target_compile_options(wasmedgePluginWasiNN diff --git a/plugins/wasi_nn/GGML/compute/compute_engine.cpp b/plugins/wasi_nn/GGML/compute/compute_engine.cpp new file mode 100644 index 00000000..5f569d0c --- /dev/null +++ b/plugins/wasi_nn/GGML/compute/compute_engine.cpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "GGML/core/ggml_core.h" +#include "GGML/tts/tts_core.h" +#include "inference_manager.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML + +Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "compute") + + // Clear the context and reset the sampler. + clearContext(GraphRef, CxtRef); + + if (GraphRef.Params.embedding) { + return getEmbedding(GraphRef, CxtRef); + } + + // Evaluate the input tokens. + ErrNo ReturnCode = ErrNo::Success; + if (GraphRef.VisionContext == nullptr) { + // Text only prompt. + ReturnCode = evaluateInput(GraphRef, CxtRef, "compute"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + } else { + // Multimodal prompt. + llama_pos NewNPos; + int32_t Res = mtmd_helper_eval_chunks( + GraphRef.VisionContext.get(), GraphRef.LlamaContext.get(), + GraphRef.VisionInputChunks.get(), CxtRef.NPos, + /* seq_id */ 0, static_cast(CxtRef.CurrentBatchSize), + /* logits_last */ true, &NewNPos); + CxtRef.NPos = NewNPos; + if (Res != 0) { + RET_ERROR(ErrNo::InvalidArgument, + "compute: unable to eval the mtmd prompt."sv) + } + } + + // Main prediction loop. + LOG_DEBUG(GraphRef.EnableDebugLog, "compute: enter main prediction loop"sv) + int64_t NPredict = + CxtRef.Conf.NPredict < 0 ? INT32_MAX : CxtRef.Conf.NPredict; + + while (NPredict-- > 0) { + ReturnCode = sampleOutput(GraphRef, CxtRef); + if (ReturnCode != ErrNo::Success) { + break; + } + } + if (ReturnCode == ErrNo::EndOfSequence) { + ReturnCode = ErrNo::Success; + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: enter main prediction loop...Done"sv) + // End of main prediction loop. + + // TTS: convert output codes to audio file. + if (GraphRef.TextToSpeech) { + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: convert output codes to audio file."sv) + ReturnCode = codesToSpeech(Env, GraphRef, CxtRef); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, + "compute: failed to convert output codes to audio "sv + "file."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "compute: convert output codes to audio file...Done"sv) + } + + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler); + } + + LOG_DEBUG(GraphRef.EnableDebugLog, "compute...Done"sv) + return ReturnCode; +} + +Expect computeSingle(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle"sv) + + // New compute single token context. + auto ReturnCode = ErrNo::Success; + if (!CxtRef.ComputeSingleStarted) { + CxtRef.ComputeSingleStarted = true; + + // Clear the context and reset the sampler. + clearContext(GraphRef, CxtRef); + + // Evaluate the input tokens. + if (GraphRef.VisionContext == nullptr) { + // Text only prompt. + ReturnCode = evaluateInput(GraphRef, CxtRef, "compute"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + } else { + // Multimodal prompt. + llama_pos NewNPos; + int32_t Res = mtmd_helper_eval_chunks( + GraphRef.VisionContext.get(), GraphRef.LlamaContext.get(), + GraphRef.VisionInputChunks.get(), CxtRef.NPos, + /* seq_id */ 0, static_cast(CxtRef.CurrentBatchSize), + /* logits_last */ true, &NewNPos); + CxtRef.NPos = NewNPos; + if (Res != 0) { + RET_ERROR(ErrNo::InvalidArgument, + "compute: unable to eval the mtmd prompt."sv) + } + } + } + + // Main prediction process. + LOG_DEBUG(GraphRef.EnableDebugLog, + "computeSingle: enter main prediction process"sv) + ReturnCode = sampleOutput(GraphRef, CxtRef, true); + if (ReturnCode != ErrNo::Success) { + CxtRef.ComputeSingleStarted = false; + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "computeSingle: enter main prediction process...Done"sv) + // End of main predict process. + + LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle...Done"sv) + return ReturnCode; +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/compute/inference_manager.cpp b/plugins/wasi_nn/GGML/compute/inference_manager.cpp new file mode 100644 index 00000000..468b098c --- /dev/null +++ b/plugins/wasi_nn/GGML/compute/inference_manager.cpp @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "inference_manager.h" +#include "GGML/core/ggml_core.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { + +// Fill tokens (smaller than batch size) into a batch with position data. +void fillBatch(Span Tokens, Graph &GraphRef, + llama_batch &Batch, int &NPos, bool IsLogit = false) { + assuming(GraphRef.Params.n_batch >= static_cast(Tokens.size())); + assuming(Batch.token != nullptr); + assuming(Batch.pos != nullptr); + assuming(Batch.logits != nullptr); + // Fill the batch with pos information. + Batch.n_tokens = static_cast(Tokens.size()); + for (uint32_t I = 0; I < Tokens.size(); I++) { + Batch.token[I] = Tokens[I]; + Batch.pos[I] = NPos + I; + Batch.logits[I] = false; + } + + // Logits of sampling or end of inputs. + if (IsLogit) { + Batch.logits[Tokens.size() - 1] = true; + } + + // Move the position. + NPos += static_cast(Tokens.size()); +} + +// Evaluate tokens. Construct the tokens into batch and decode. +ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, + llama_batch &Batch, int &NPos, + bool IsLogits = false) noexcept { + // End the inference if the context is full. + uint32_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); + if (NPos + static_cast(Tokens.size()) > NCtx) { + LOG_INFO( + GraphRef.EnableLog, + "evaluateTokens: the context if full ({} / {} tokens). Please increase your "sv + "context size."sv, + NPos + static_cast(Tokens.size()), NCtx) + return ErrNo::ContextFull; + } + + // Loop for decode batch. Split tokens into batch size length. + for (int I = 0; I < static_cast(Tokens.size()); + I += static_cast(GraphRef.Params.n_batch)) { + int NEval = static_cast(Tokens.size()) - I; + if (NEval > static_cast(GraphRef.Params.n_batch)) { + NEval = static_cast(GraphRef.Params.n_batch); + } + + // Fill the batch with pos information. + fillBatch(Span(Tokens.begin() + I, NEval), GraphRef, + Batch, NPos, + IsLogits && I + NEval >= static_cast(Tokens.size())); + + // Decode the batch. + auto Status = llama_decode(GraphRef.LlamaContext.get(), Batch); + if (Status == 1) { + RET_ERROR( + ErrNo::RuntimeError, + "evaluateTokens: failed to llama_decode: try reducing the size of the batch "sv + "or increasing the size of context."sv) + } + if (Status < 0) { + RET_ERROR( + ErrNo::RuntimeError, + "evaluateTokens: failed to llama_decode: fatal error. Please open "sv + "an issue on GitHub."sv) + } + } + + return ErrNo::Success; +} +// Generate output embedding. +void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, + const float *Embeddings) noexcept { + // Embedding vector format + // | Content | + // | ----------------------------------- | + // | '{"number_embedding": ' | + // | n_embedding | + // | ', "embedding": ' | + // | '[' | + // | n_embedding*(embedding value %.10f) | + // | (n_embedding-1)*(',') | + // | ']' | + // | '}' | + Embedding = + fmt::format(R"({{"n_embedding": {}, )" + R"("embedding": [{:.10}]}})"sv, + NEmbd, fmt::join(Embeddings, Embeddings + NEmbd, ","sv)); +} +} // namespace + +// Evaluate the input tokens. Clean all inputs if succeeded. +ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, + std::string_view LogPrefix) noexcept { + // Check if the input is set before setting up the context. + if (CxtRef.LlamaInputs.size() == 0) { + RET_ERROR(ErrNo::InvalidArgument, "{}: llama input is not set!"sv, + LogPrefix) + } + + // Get the context size. + const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); + // Minus 4 for the special tokens. (Such as , , ... tokens.) + const uint64_t MaxTokensListSize = NCtx - 4; + // Return value. + auto ReturnCode = ErrNo::Success; + + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { + RET_ERROR(ErrNo::PromptTooLong, + "{}: the prompt is too long. Your input has {} tokens. "sv + "Please reduce it to {} tokens."sv, + LogPrefix, CxtRef.LlamaInputs.size(), MaxTokensListSize) + } + + // Evaluate input tokens. + ReturnCode = + evaluateTokens(Span(CxtRef.LlamaInputs.begin(), + CxtRef.LlamaInputs.size()), + GraphRef, CxtRef.LlamaBatch, CxtRef.NPos, true); + if (ReturnCode != ErrNo::Success) { + RET_ERROR(ReturnCode, "{}: failed to evaluate input tokens."sv, LogPrefix) + } + + return ErrNo::Success; +} + +// Clear the context and reset the sampler. +void clearContext(Graph &GraphRef, Context &CxtRef) noexcept { + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext"sv) + llama_memory_clear(llama_get_memory(GraphRef.LlamaContext.get()), true); + common_sampler_reset(CxtRef.LlamaSampler); + CxtRef.NPos = 0; + CxtRef.LlamaOutputs.clear(); + CxtRef.LlamaOutputTokens.clear(); + LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext...Done"sv) +} + +// TODO: Merge into compute. +Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { + LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding"sv) + + const llama_vocab *Vocab = llama_model_get_vocab(GraphRef.LlamaModel.get()); + // Add SEP if not present. + if (CxtRef.LlamaInputs.size() > 0 && + CxtRef.LlamaInputs.back() != llama_vocab_sep(Vocab)) { + LOG_WARN( + "getEmbedding: last token in the prompt is not SEP, "sv + "'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF "sv + "header."sv) + } + + // Check if the input is too long. + if (static_cast(CxtRef.LlamaInputs.size()) > + GraphRef.Params.n_batch) { + RET_ERROR( + ErrNo::PromptTooLong, + "getEmbedding: the prompt is too long. Your input has {} tokens exceeds batch "sv + "size {}. Please reduce the input size or increase your batch-size."sv, + CxtRef.LlamaInputs.size(), GraphRef.Params.n_batch) + } + + // Evaluate the input tokens. + auto ReturnCode = evaluateInput(GraphRef, CxtRef, "getEmbedding"sv); + if (ReturnCode != ErrNo::Success) { + return ReturnCode; + } + + // Main prediction loop. + const int32_t NEmbd = llama_model_n_embd(GraphRef.LlamaModel.get()); + std::vector Embeddings(NEmbd); + + for (int I = 0; I < CxtRef.LlamaBatch.n_tokens; I++) { + if (!CxtRef.LlamaBatch.logits[I]) { + continue; + } + + // Try to get sequence embeddings. + auto *Embd = llama_get_embeddings_seq(GraphRef.LlamaContext.get(), + CxtRef.LlamaBatch.seq_id[I][0]); + if (Embd == nullptr) { + Embd = llama_get_embeddings_ith(GraphRef.LlamaContext.get(), I); + if (Embd == nullptr) { + LOG_ERROR("getEmbedding: failed to get embeddings for token {}"sv, I); + continue; + } + } + + // Normalize the embeddings. + common_embd_normalize(Embd, Embeddings.data(), NEmbd, + static_cast(CxtRef.Conf.EmbdNormalize)); + } + + std::string EmbeddingString; + buildOutputEmbedding(EmbeddingString, NEmbd, Embeddings.data()); + CxtRef.LlamaOutputs = + std::vector(EmbeddingString.begin(), EmbeddingString.end()); + + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), /* Sampler */ nullptr); + } + + LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding...Done"sv) + return ErrNo::Success; +} + +// Sample and get the output token. +ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, + bool IsSingleTokenMode) noexcept { + // Use idx = -1 to sample the next token. + const llama_token Id = common_sampler_sample( + CxtRef.LlamaSampler, GraphRef.LlamaContext.get(), /* idx */ -1); + common_sampler_accept(CxtRef.LlamaSampler, Id, /* accept_grammar */ true); + + // Save the output token. + CxtRef.LlamaOutputTokens.emplace_back(Id); + std::string OutputString = + common_token_to_piece(GraphRef.LlamaContext.get(), Id); + CxtRef.LlamaOutputs.insert(CxtRef.LlamaOutputs.end(), OutputString.begin(), + OutputString.end()); + // In single token mode, we do not handle StreamStdout and ReversePrompt. + if (!IsSingleTokenMode) { + // When setting StreamStdout, we print the output to stdout. + if (CxtRef.Conf.StreamStdout) { + fmt::print("{}"sv, + common_token_to_piece(GraphRef.LlamaContext.get(), Id)); + std::fflush(stdout); + } + // Break if reverse prompt is found. + if (!CxtRef.Conf.ReversePrompt.empty() && + std::string(CxtRef.LlamaOutputs.begin(), CxtRef.LlamaOutputs.end()) + .find(CxtRef.Conf.ReversePrompt) != std::string::npos) { + LOG_INFO(GraphRef.EnableLog, "sampleOutput: reverse prompt found."sv) + return ErrNo::EndOfSequence; + } + } + // Deal with end of text token. + const llama_vocab *Vocab = llama_model_get_vocab(GraphRef.LlamaModel.get()); + // Only stop on EOS if GraphRef.Params.sampling.ignore_eos is false. + if (!GraphRef.Params.sampling.ignore_eos && + llama_vocab_is_eog(Vocab, common_sampler_last(CxtRef.LlamaSampler))) { + LOG_INFO(GraphRef.EnableLog, "sampleOutput: EOS token found."sv) + return ErrNo::EndOfSequence; + } + // Evaluate the output token. + return evaluateTokens(Span(&Id, 1), GraphRef, + CxtRef.OutputBatch, CxtRef.NPos, true); +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/compute/inference_manager.h b/plugins/wasi_nn/GGML/compute/inference_manager.h new file mode 100644 index 00000000..b876249c --- /dev/null +++ b/plugins/wasi_nn/GGML/compute/inference_manager.h @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once +#include "GGML/core/ggml_core.h" + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +void clearContext(Graph &GraphRef, Context &CxtRef) noexcept; +Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept; +ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, + std::string_view LogPrefix) noexcept; +ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, + bool IsSingleTokenMode = false) noexcept; +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/core/ggml_core.cpp b/plugins/wasi_nn/GGML/core/ggml_core.cpp new file mode 100644 index 00000000..28f7a323 --- /dev/null +++ b/plugins/wasi_nn/GGML/core/ggml_core.cpp @@ -0,0 +1,376 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "ggml_core.h" +#include "GGML/utils.h" +#include "common/types.h" +#include "host/wasi/vfs_io.h" +#include "wasinnenv.h" +#include + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include "GGML/metadata/metadata_parser.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { + +// Llama logging callback. +void llamaLogCallback(ggml_log_level LogLevel, const char *LogText, + void *UserData) { + Graph &GraphRef = *reinterpret_cast(UserData); + if (!GraphRef.EnableLog) { + return; + } + std::string Text(LogText); + // Remove the trailing newlines. + Text = Text.erase(Text.find_last_not_of("\n") + 1); + // Skip for "." + if (Text == ".") { + return; + } + if (LogLevel == GGML_LOG_LEVEL_ERROR) { + spdlog::error("[WASI-NN] llama.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_WARN) { + spdlog::warn("[WASI-NN] llama.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_INFO) { + spdlog::info("[WASI-NN] llama.cpp: {}"sv, Text); + } else if (LogLevel == GGML_LOG_LEVEL_DEBUG) { + spdlog::debug("[WASI-NN] llama.cpp: {}"sv, Text); + } +} +} // namespace + +Expect load(WasiNNEnvironment &Env, Span> Builders, + [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { + // Add a new graph. + EndianValue GId = Env.newGraph(Backend::GGML); + auto &GraphRef = Env.NNGraph[GId.raw()].get(); + + // Initialize the plugin parameters. + GraphRef.EnableLog = false; + GraphRef.EnableDebugLog = false; + common_params CommonParamsDefault; + CommonParamsDefault.lr.init(); + GraphRef.Params = CommonParamsDefault; + GraphRef.Params.n_keep = 0; + GraphRef.Params.n_chunks = -1; + GraphRef.Params.n_parallel = 1; + GraphRef.Params.grp_attn_n = 1; + GraphRef.Params.grp_attn_w = 512; + GraphRef.Params.n_print = -1; + GraphRef.Params.split_mode = llama_split_mode::LLAMA_SPLIT_MODE_LAYER; + // Initialize the model parameters. + llama_model_params ModelParamsDefault = llama_model_default_params(); + GraphRef.Params.n_gpu_layers = ModelParamsDefault.n_gpu_layers; + GraphRef.Params.mmproj.path = ""sv; + GraphRef.Params.warmup = false; + + // Initialize the sampling parameters. + const common_params_sampling SamplerParamsDefault; + GraphRef.Params.sampling = SamplerParamsDefault; + // Initialize the config parameters. + GraphRef.Conf.StreamStdout = false; + GraphRef.Conf.EmbdNormalize = + static_cast(CommonParamsDefault.embd_normalize); + GraphRef.Conf.NPredict = GraphRef.Params.n_ctx; + GraphRef.Conf.ReversePrompt = ""sv; + GraphRef.Conf.ImagePath = ""sv; + + // Set llama log callback. + llama_log_set(llamaLogCallback, &GraphRef); + + // If the graph builder length > 1, the data of builder[1] is the metadata. + if (Builders.size() > 1) { + const std::string Metadata(reinterpret_cast(Builders[1].data()), + Builders[1].size()); + // Ignore context or model updates when initializing the graph. + auto Res = parseMetadata(GraphRef, GraphRef.Conf, Metadata); + if (Res != ErrNo::Success) { + Env.deleteGraph(GId.raw()); + RET_ERROR(Res, "load: Failed to parse metadata."sv) + } + } + + // Logging. + LOG_DEBUG(GraphRef.EnableDebugLog, "load"sv) + LOG_INFO(GraphRef.EnableLog, "LLAMA_COMMIT {}"sv, LLAMA_COMMIT) + LOG_INFO(GraphRef.EnableLog, "LLAMA_BUILD_NUMBER {}"sv, LLAMA_BUILD_NUMBER) + + // Handle the model path. + LOG_DEBUG(GraphRef.EnableDebugLog, "load: handling model path."sv) + auto Weight = Builders[0]; + const std::string_view BinModel(reinterpret_cast(Weight.data()), + Weight.size()); + if (BinModel.substr(0, 8) == "preload:"sv) { + GraphRef.Params.model.path = BinModel.substr(8); + } else { + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: Model path not found in nn-preload, write model into "sv + "a tmpfile."sv) + // TODO: pass the model directly to ggml. + // Write ggml model to file. + GraphRef.Params.model.path = "ggml-model.bin"sv; + WasmEdge::FStream::OFStream TempFile( + GraphRef.Params.model.path, std::ios_base::out | std::ios_base::binary, + Env.getEnv()); + if (!TempFile) { + Env.deleteGraph(GId.raw()); + RET_ERROR(ErrNo::InvalidArgument, + "load: Failed to create the temporary file. Currently, our "sv + "workaround involves creating a temporary model file named "sv + "\"ggml-model.bin\" and passing this filename as a "sv + "parameter to the ggml llama library."sv) + } + TempFile.write(BinModel.data(), BinModel.size()); + TempFile.close(); + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: Write model into a tmpfile...Done"sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, "load: handling model path...Done"sv) + + // Check if the model exists. + if (!std::filesystem::exists( + std::filesystem::u8path(GraphRef.Params.model.path))) { + Env.deleteGraph(GId.raw()); + RET_ERROR(ErrNo::ModelNotFound, "load: model file not found."sv) + } + GraphRef.Params.model = GraphRef.Params.model; + + // Initialize ggml parameters. + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: initialize ggml model with given parameters."sv) + + common_params Params = GraphRef.Params; + Params.cpuparams.n_threads = + static_cast(GraphRef.Params.cpuparams.n_threads); + Params.cpuparams_batch.n_threads = + static_cast(GraphRef.Params.cpuparams.n_threads); + llama_backend_init(); + llama_numa_init(Params.numa); + + // Initialize the llama model and context. + common_init_result LlamaInit = common_init_from_params(Params); + GraphRef.LlamaModel = std::move(LlamaInit.model); + GraphRef.LlamaContext = std::move(LlamaInit.context); + if (GraphRef.LlamaModel == nullptr) { + Env.deleteGraph(GId.raw()); + RET_ERROR(ErrNo::InvalidArgument, "load: unable to init model."sv) + } + if (GraphRef.LlamaContext == nullptr) { + Env.deleteGraph(GId.raw()); + RET_ERROR(ErrNo::InvalidArgument, "load: unable to init context."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "load: initialize ggml model with given parameters...Done"sv) + + // Initialize the TTS related model and context. + if (GraphRef.TextToSpeech) { + LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize TTS model."sv) + Params.model = GraphRef.Params.vocoder.model; + Params.embedding = true; + common_init_result TTSInit = common_init_from_params(Params); + GraphRef.TTSModel = std::move(TTSInit.model); + GraphRef.TTSContext = std::move(TTSInit.context); + if (GraphRef.TTSModel == nullptr) { + Env.deleteGraph(GId.raw()); + RET_ERROR(ErrNo::InvalidArgument, "load: unable to init TTS model."sv) + } + if (GraphRef.TTSContext == nullptr) { + Env.deleteGraph(GId.raw()); + RET_ERROR(ErrNo::InvalidArgument, "load: unable to init TTS context."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize TTS model...Done"sv) + } + + // Store the loaded graph. + GraphId = GId.le(); + Env.NNGraph[GId.raw()].setReady(); + + LOG_DEBUG(GraphRef.EnableDebugLog, "load...Done"sv) + return ErrNo::Success; +} + +Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, + uint32_t &ContextId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx"sv) + ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); + LOG_INFO(GraphRef.EnableLog, "llama_system_info: {}"sv, + llama_print_system_info()) + + auto &CxtRef = Env.NNContext[ContextId].get(); + // Allocate the batch for input string prompt tokens. + CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); + CxtRef.CurrentBatchSize = GraphRef.Params.n_batch; + + // Allocate the batch for output sampling. The batch size is always 1. + CxtRef.OutputBatch = allocBatch(1); + + // Allocate sampler. + CxtRef.LlamaSampler = + common_sampler_init(GraphRef.LlamaModel.get(), GraphRef.Params.sampling); + + Env.NNContext[ContextId].setReady(); + ContextId = EndianValue(ContextId).le(); + LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx...Done"sv) + return ErrNo::Success; +} + +Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle"sv) + + // Logging for the llama timings. + if (GraphRef.EnableLog) { + common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler); + } + + // Clear the outputs. + LOG_DEBUG(GraphRef.EnableDebugLog, + "finiSingle: clear the previous output and tokens"sv) + CxtRef.LlamaOutputs.clear(); + CxtRef.LlamaOutputTokens.clear(); + LOG_DEBUG(GraphRef.EnableDebugLog, + "finiSingle: clear the previous output and tokens...Done"sv) + + // Reset the llama sampler. + common_sampler_reset(CxtRef.LlamaSampler); + CxtRef.ComputeSingleStarted = false; + CxtRef.NPos = 0; + + LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle...Done"sv) + return ErrNo::Success; +} + +Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { + auto &GraphRef = Env.NNGraph[GraphId].get(); + const bool IsDebugLog = GraphRef.EnableDebugLog; + LOG_DEBUG(IsDebugLog, "unload"sv) + + // TODO: Move the resource deallocation into the destructor. + if (GraphRef.LlamaModel != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free llama model"sv) + GraphRef.LlamaModel.reset(); + LOG_DEBUG(IsDebugLog, "unload: free llama model...Done"sv) + } + if (GraphRef.LlamaContext != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free llama context"sv) + GraphRef.LlamaContext.reset(); + LOG_DEBUG(IsDebugLog, "unload: free llama context...Done"sv) + } + if (GraphRef.VisionContext != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free mtmd context"sv) + GraphRef.VisionContext.reset(); + LOG_DEBUG(IsDebugLog, "unload: free mtmd context...Done"sv) + } + if (GraphRef.VisionInputChunks != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free mtmd chunks"sv) + GraphRef.VisionInputChunks.reset(); + LOG_DEBUG(IsDebugLog, "unload: free mtmd chunks...Done"sv) + } + if (GraphRef.TTSModel != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free TTS model"sv) + GraphRef.TTSModel.reset(); + LOG_DEBUG(IsDebugLog, "unload: free TTS model...Done"sv) + } + if (GraphRef.TTSContext != nullptr) { + LOG_DEBUG(IsDebugLog, "unload: free TTS context"sv) + GraphRef.TTSContext.reset(); + LOG_DEBUG(IsDebugLog, "unload: free TTS context...Done"sv) + } + if (!GraphRef.TensorBuftOverrides.empty()) { + LOG_DEBUG(IsDebugLog, "unload: free tensor buffer overrides"sv) + GraphRef.TensorBuftOverrides.clear(); + LOG_DEBUG(IsDebugLog, "unload: free tensor buffer overrides...Done"sv) + } + Env.deleteGraph(GraphId); + Env.mdRemoveById(GraphId); + + LOG_DEBUG(IsDebugLog, "unload...Done"sv) + return ErrNo::Success; +} + +Expect finalizeExecCtx(WasiNNEnvironment &Env, + uint32_t ContextId) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "finalize_execution_context"sv) + + if (CxtRef.LlamaSampler != nullptr) { + LOG_DEBUG(GraphRef.EnableDebugLog, + "finalize_execution_context: free compute_single sampler"sv) + common_sampler_free(CxtRef.LlamaSampler); + CxtRef.LlamaSampler = nullptr; + LOG_DEBUG( + GraphRef.EnableDebugLog, + "finalize_execution_context: free compute_single sampler...Done"sv) + } + llama_batch_free(CxtRef.LlamaBatch); + llama_batch_free(CxtRef.OutputBatch); + Env.deleteContext(ContextId); + + LOG_DEBUG(GraphRef.EnableDebugLog, "finalize_execution_context...Done"sv) + return ErrNo::Success; +} + +#else +namespace { +Expect reportBackendNotSupported() noexcept { + spdlog::error("[WASI-NN] ggml backend is not built. use " + "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"ggml\" to build it."sv); + return ErrNo::InvalidArgument; +} +} // namespace + +Expect load(WasiNNEnvironment &, Span>, Device, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect initExecCtx(WasiNNEnvironment &, uint32_t, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect setInput(WasiNNEnvironment &, uint32_t, uint32_t, + const TensorData &) noexcept { + return reportBackendNotSupported(); +} +Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, + uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect compute(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect getOutputSingle(WasiNNEnvironment &, uint32_t, uint32_t, + Span, uint32_t &) noexcept { + return reportBackendNotSupported(); +} +Expect computeSingle(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect finiSingle(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect unload(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +Expect finalizeExecCtx(WasiNNEnvironment &, uint32_t) noexcept { + return reportBackendNotSupported(); +} +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/core/ggml_core.h b/plugins/wasi_nn/GGML/core/ggml_core.h new file mode 100644 index 00000000..68671572 --- /dev/null +++ b/plugins/wasi_nn/GGML/core/ggml_core.h @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "ggml_type.h" +#include "plugin/plugin.h" +#include "wasinnenv.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#endif + +namespace WasmEdge::Host::WASINN { +struct WasiNNEnvironment; +} // namespace WasmEdge::Host::WASINN +namespace WasmEdge::Host::WASINN::GGML { + +Expect load(WASINN::WasiNNEnvironment &Env, + Span> Builders, + WASINN::Device Device, uint32_t &GraphId) noexcept; +Expect initExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId, + uint32_t &ContextId) noexcept; +Expect finiSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +Expect setInput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + const TensorData &Tensor) noexcept; +Expect getOutput(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect getOutputSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId, uint32_t Index, + Span OutBuffer, + uint32_t &BytesWritten) noexcept; +Expect compute(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +Expect computeSingle(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; +Expect unload(WASINN::WasiNNEnvironment &Env, + uint32_t GraphId) noexcept; +Expect finalizeExecCtx(WASINN::WasiNNEnvironment &Env, + uint32_t ContextId) noexcept; + +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/wasinn_ggml.h b/plugins/wasi_nn/GGML/core/ggml_type.h similarity index 58% rename from plugins/wasi_nn/wasinn_ggml.h rename to plugins/wasi_nn/GGML/core/ggml_type.h index fb48b060..66da8786 100644 --- a/plugins/wasi_nn/wasinn_ggml.h +++ b/plugins/wasi_nn/GGML/core/ggml_type.h @@ -3,28 +3,22 @@ #pragma once -#include "wasinntypes.h" - #include "plugin/plugin.h" +#include "wasinntypes.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML -#include +#include "wasinntypes.h" +#include #include #include #include #include -#include #endif -namespace WasmEdge::Host::WASINN { -struct WasiNNEnvironment; -} - namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML - enum class EmbdNormalizeType : int32_t { // Follow: // https://github.com/ggerganov/llama.cpp/blob/0bf2d10c5514ff61b99897a4a5054f846e384e1e/common/common.h#L312 @@ -35,11 +29,6 @@ enum class EmbdNormalizeType : int32_t { PNorm = 3, }; -struct TTSSpeakerProfile { - std::string Text; - std::string Data; -}; - struct LocalConfig { // Configurations which can be changed in every contexts. // The graph handles a default config and parsed from metadata when loading. @@ -107,31 +96,32 @@ struct Context { struct Environ {}; -Expect load(WASINN::WasiNNEnvironment &Env, - Span> Builders, - WASINN::Device Device, uint32_t &GraphId) noexcept; -Expect initExecCtx(WASINN::WasiNNEnvironment &Env, - uint32_t GraphId, - uint32_t &ContextId) noexcept; -Expect setInput(WASINN::WasiNNEnvironment &Env, - uint32_t ContextId, uint32_t Index, - const TensorData &Tensor) noexcept; -Expect getOutput(WASINN::WasiNNEnvironment &Env, - uint32_t ContextId, uint32_t Index, - Span OutBuffer, - uint32_t &BytesWritten) noexcept; -Expect getOutputSingle(WASINN::WasiNNEnvironment &Env, - uint32_t ContextId, uint32_t Index, - Span OutBuffer, - uint32_t &BytesWritten) noexcept; -Expect compute(WASINN::WasiNNEnvironment &Env, - uint32_t ContextId) noexcept; -Expect computeSingle(WASINN::WasiNNEnvironment &Env, - uint32_t ContextId) noexcept; -Expect finiSingle(WASINN::WasiNNEnvironment &Env, - uint32_t ContextId) noexcept; -Expect unload(WASINN::WasiNNEnvironment &Env, - uint32_t GraphId) noexcept; -Expect finalizeExecCtx(WASINN::WasiNNEnvironment &Env, - uint32_t ContextId) noexcept; +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { + +// Macro for logging debug message. +#define LOG_DEBUG(Debug, ...) \ + if (Debug) { \ + spdlog::info("[WASI-NN][Debug] GGML backend: "sv __VA_ARGS__); \ + } + +// Macro for logging info message. +#define LOG_INFO(Info, ...) \ + if (Info) { \ + spdlog::info("[WASI-NN] GGML backend: "sv __VA_ARGS__); \ + } + +// Macro for logging warning message. +#define LOG_WARN(...) spdlog::warn("[WASI-NN] GGML backend: "sv __VA_ARGS__); + +// Macro for logging error message. +#define LOG_ERROR(...) spdlog::error("[WASI-NN] GGML backend: "sv __VA_ARGS__); + +// Macro for logging error message and return. +#define RET_ERROR(Error, ...) \ + spdlog::error("[WASI-NN] GGML backend: "sv __VA_ARGS__); \ + return Error; +} // namespace +#endif + } // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/core/input_processor.cpp b/plugins/wasi_nn/GGML/core/input_processor.cpp new file mode 100644 index 00000000..0a71d1cc --- /dev/null +++ b/plugins/wasi_nn/GGML/core/input_processor.cpp @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "GGML/metadata/metadata_parser.h" +#include "GGML/tts/tts_core.h" +#include "GGML/utils.h" +#include "ggml_core.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML + +Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, const TensorData &Tensor) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput"sv) + + // Use index 1 for metadata. + if (Index == 1) { + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: found Metadata, processing"sv) + bool IsModelParamsUpdated = false; + bool IsContextParamsUpdated = false; + bool IsSamplerParamsUpdated = false; + const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + auto Res = + parseMetadata(GraphRef, CxtRef.Conf, Metadata, &IsModelParamsUpdated, + &IsContextParamsUpdated, &IsSamplerParamsUpdated); + if (Res != ErrNo::Success) { + RET_ERROR(Res, "setInput: failed to parse metadata."sv) + } + +#ifndef __APPLE__ + // XXX: Due to the limitation of WASI-NN proposal, this is a workaround + // for non-macOS devices. However, if the model params is updated in + // Config stage, then, we don't encourage to use this to avoid the model + // reloading. + { + if (IsModelParamsUpdated || GraphRef.LlamaModel == nullptr) { + // The llama model may be nullptr if set_input with updated model params + // last time. Therefore besides the model params updated, we should + // reload the llama model if the model is nullptr. + LOG_INFO(GraphRef.EnableLog, + "setInput: Reload model due to parameters change."sv) + llama_model_params ModelParams = llama_model_default_params(); + ModelParams.n_gpu_layers = + static_cast(GraphRef.Params.n_gpu_layers); + GraphRef.LlamaModel.reset(); + // Due to the model change, the context and sampler should also be + // reloaded. The new context and sampler will be created in the next + // block. + GraphRef.LlamaContext.reset(); + if (CxtRef.LlamaSampler) { + // TODO: Trigger the sampler in other contexts to reallocate. + common_sampler_free(CxtRef.LlamaSampler); + CxtRef.LlamaSampler = nullptr; + } + GraphRef.LlamaModel = llama_model_ptr(llama_model_load_from_file( + GraphRef.Params.model.path.c_str(), ModelParams)); + if (GraphRef.LlamaModel == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init model."sv) + } + } + } +#endif + + // Some changes of context parameters will require the context to be + // reloaded. + if (IsContextParamsUpdated || GraphRef.LlamaContext == nullptr) { + LOG_INFO(GraphRef.EnableLog, + "setInput: Reload llama context due to parameters change."sv) + GraphRef.LlamaContext.reset(); + GraphRef.LlamaContext = llama_context_ptr(llama_init_from_model( + GraphRef.LlamaModel.get(), + common_context_params_to_llama(GraphRef.Params))); + if (GraphRef.LlamaContext == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init context."sv) + } + } + + // Some changes of sampling parameters will require the sampler to be + // reallocated. + if (IsSamplerParamsUpdated || CxtRef.LlamaSampler == nullptr) { + LOG_INFO(GraphRef.EnableLog, + "setInput: Reallocate llama sampler due to parameters change."sv) + if (CxtRef.LlamaSampler) { + common_sampler_free(CxtRef.LlamaSampler); + } + CxtRef.LlamaSampler = common_sampler_init(GraphRef.LlamaModel.get(), + GraphRef.Params.sampling); + if (GraphRef.LlamaContext == nullptr) { + Env.NNGraph[CxtRef.GraphId].setInvalid(); + RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init sampler."sv) + } + } + + // Check that is batch size changed. + if (CxtRef.CurrentBatchSize != GraphRef.Params.n_batch) { + llama_batch_free(CxtRef.LlamaBatch); + CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); + CxtRef.CurrentBatchSize = GraphRef.Params.n_batch; + } + + Env.NNGraph[CxtRef.GraphId].setReady(); + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: found Metadata, processing...Done"sv) + return ErrNo::Success; + } + + // Check the graph is valid after reloading during previous set_input. + if (!Env.NNGraph[CxtRef.GraphId].isReady()) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: Graph is invalid. Please reload again by passing metadata "sv + "in set_input or unload graph."sv) + } + + // Clear the llama context. + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context"sv) + llama_memory_clear(llama_get_memory(GraphRef.LlamaContext.get()), true); + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context...Done"sv) + + // Set the input. + const bool AddSpecial = true; + const bool ParseSpecial = true; + std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); + CxtRef.LlamaInputs.clear(); + + auto Base64ImagePos = findBase64ImagePayload(Prompt); + + if (Base64ImagePos.has_value() || CxtRef.Conf.ImagePath != ""sv) { + // First check the projection model is given. + if (GraphRef.Params.mmproj.path == ""sv) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: the given model does not support image input, so a projection model is required."sv) + } + + // Make sure the projection model is loaded. + if (GraphRef.VisionContext == nullptr) { + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: initialize mtmd context."sv) + // Initialize the mtmd context. + mtmd_context_params VisionContextParams = mtmd_context_params_default(); + std::string VisionPromptImagePlaceholderStr(VisionPromptImagePlaceholder); + VisionContextParams.media_marker = + VisionPromptImagePlaceholderStr.c_str(); + VisionContextParams.use_gpu = GraphRef.Params.mmproj_use_gpu; + VisionContextParams.n_threads = GraphRef.Params.cpuparams.n_threads; + VisionContextParams.print_timings = + GraphRef.EnableLog || GraphRef.EnableDebugLog; + if (GraphRef.EnableDebugLog) { + VisionContextParams.verbosity = GGML_LOG_LEVEL_DEBUG; + } else if (GraphRef.EnableLog) { + VisionContextParams.verbosity = GGML_LOG_LEVEL_INFO; + } else { + VisionContextParams.verbosity = GGML_LOG_LEVEL_NONE; + } + GraphRef.VisionContext.reset( + mtmd_init_from_file(GraphRef.Params.mmproj.path.c_str(), + GraphRef.LlamaModel.get(), VisionContextParams)); + if (GraphRef.VisionContext == nullptr) { + RET_ERROR(ErrNo::InvalidArgument, + "setInput: unable to load the mmproj model {}."sv, + GraphRef.Params.mmproj.path) + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: initialize mtmd context...Done"sv) + } + + // Show some warnings for context size. + if (GraphRef.Params.n_ctx < 4096) { + LOG_INFO( + GraphRef.EnableLog, + "setInput: Context size is {}, we recommend context size >= 4096 when using multimodal models for better results"sv, + GraphRef.Params.n_ctx) + } + + // Get the image bitmaps. + // Follow this link for the supported image formats: + // https://github.com/ggml-org/llama.cpp/blob/master/common/stb_image.h + mtmd::bitmaps Bitmaps; + if (GraphRef.VisionContext != nullptr) { + if (Base64ImagePos.has_value()) { + // Load the image bitmap from the base64 image. + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: load the image bitmap from the base64 image."sv) + // Extract the payload and image type from the prompt. + std::optional, std::string>> Payload = + extractBase64ImagePayload(Prompt, *Base64ImagePos, + VisionPromptImagePlaceholder); + if (Payload.has_value()) { + // Create the new image bitmap. + mtmd::bitmap Bitmap(mtmd_helper_bitmap_init_from_buf( + GraphRef.VisionContext.get(), Payload->first.data(), + Payload->first.size())); + if (Bitmap.ptr == nullptr) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: unable to load the image from base64 paylaod."sv) + } + Bitmaps.entries.push_back(std::move(Bitmap)); + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: Compute image embd from the base64 image...Done"sv) + } else { + // Load the image from the file. + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: load the image bitmap from file: {}"sv, + CxtRef.Conf.ImagePath) + mtmd::bitmap Bitmap(mtmd_helper_bitmap_init_from_file( + GraphRef.VisionContext.get(), CxtRef.Conf.ImagePath.c_str())); + if (Bitmap.ptr == nullptr) { + RET_ERROR( + ErrNo::InvalidArgument, + "setInput: unable to load the image bitmap from file: {}."sv, + CxtRef.Conf.ImagePath) + } + Bitmaps.entries.push_back(std::move(Bitmap)); + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: load the image bitmap from file: {}...Done"sv, + CxtRef.Conf.ImagePath) + } + } + + // Tokenize the prompt. + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize the mtmd prompt"sv) + GraphRef.VisionInputChunks.reset(mtmd_input_chunks_init()); + mtmd_input_text MtmdText; + MtmdText.text = Prompt.c_str(); + MtmdText.add_special = AddSpecial; + MtmdText.parse_special = ParseSpecial; + std::vector BitmapsPtr = Bitmaps.c_ptr(); + int32_t Res = mtmd_tokenize(GraphRef.VisionContext.get(), + GraphRef.VisionInputChunks.get(), &MtmdText, + BitmapsPtr.data(), BitmapsPtr.size()); + if (Res != 0) { + RET_ERROR(ErrNo::InvalidArgument, + "setInput: unable to tokenize the mtmd prompt."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: tokenize the mtmd prompt...Done"sv) + + // Get the number of input tokens (for the metadata). + CxtRef.LlamaNInputs = 0; + for (size_t ChunkIndex = 0; + ChunkIndex < mtmd_input_chunks_size(GraphRef.VisionInputChunks.get()); + ++ChunkIndex) { + size_t NTokens = 0; + const mtmd_input_chunk *Chunk = + mtmd_input_chunks_get(GraphRef.VisionInputChunks.get(), ChunkIndex); + mtmd_input_chunk_get_tokens_text(Chunk, &NTokens); + CxtRef.LlamaNInputs += NTokens; + } + } else if (GraphRef.TextToSpeech == true) { + // TTS prompt. + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt"sv) + CxtRef.LlamaInputs = processTTSPrompt(Env, GraphRef, Prompt); + if (CxtRef.LlamaInputs.empty()) { + RET_ERROR(ErrNo::InvalidArgument, + "setInput: failed to tokenize tts prompt."sv) + } + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt...Done"sv) + + // Get the number of input tokens (for the metadata). + CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); + } else { + // Text only prompt. + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt"sv) + CxtRef.LlamaInputs = common_tokenize(GraphRef.LlamaContext.get(), Prompt, + AddSpecial, ParseSpecial); + LOG_DEBUG(GraphRef.EnableDebugLog, + "setInput: tokenize text prompt...Done"sv) + + // Get the number of input tokens (for the metadata). + CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); + } + + // Maybe currently in the compute_single mode. Reset the computing. + CxtRef.ComputeSingleStarted = false; + + LOG_DEBUG(GraphRef.EnableDebugLog, "setInput...Done"sv) + return ErrNo::Success; +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/core/output_generator.cpp b/plugins/wasi_nn/GGML/core/output_generator.cpp new file mode 100644 index 00000000..9cd6667c --- /dev/null +++ b/plugins/wasi_nn/GGML/core/output_generator.cpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "ggml_core.h" + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { +// Generate output metadata. +std::string buildOutputMetadata(Context &CxtRef) noexcept { + return fmt::format(R"({{"input_tokens": {}, )" + R"("output_tokens": {}, )" + R"("llama_build_number": {}, )" + R"("llama_commit": "{}"}})"sv, + CxtRef.LlamaNInputs, CxtRef.LlamaOutputTokens.size(), + LLAMA_BUILD_NUMBER, LLAMA_COMMIT); +} +} // namespace + +Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: with Index {}"sv, Index) + + // Use index 1 for the metadata of the outputs. + if (Index == 1) { + std::string Metadata = buildOutputMetadata(CxtRef); + std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); + BytesWritten = static_cast(Metadata.length()); + LOG_DEBUG(GraphRef.EnableDebugLog, + "getOutputSingle: with Index {} a.k.a Metadata...Done"sv, Index) + return ErrNo::Success; + } + + std::string LastToken = common_token_to_piece( + GraphRef.LlamaContext.get(), CxtRef.LlamaOutputTokens.back()); + std::copy_n(LastToken.data(), LastToken.length(), OutBuffer.data()); + BytesWritten = EndianValue(static_cast(LastToken.length())).le(); + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: with Index {}...Done"sv, + Index) + return ErrNo::Success; +} + +Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, + uint32_t Index, Span OutBuffer, + uint32_t &BytesWritten) noexcept { + auto &CxtRef = Env.NNContext[ContextId].get(); + auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}"sv, Index) + + // Use index 1 for the metadata of the outputs. + if (Index == 1) { + std::string Metadata = buildOutputMetadata(CxtRef); + std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); + BytesWritten = static_cast(Metadata.length()); + LOG_DEBUG(GraphRef.EnableDebugLog, + "getOutput: with Index {} a.k.a Metadata ...Done"sv, Index) + return ErrNo::Success; + } + + std::copy_n(CxtRef.LlamaOutputs.data(), CxtRef.LlamaOutputs.size(), + OutBuffer.data()); + BytesWritten = + EndianValue(static_cast(CxtRef.LlamaOutputs.size())).le(); + LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}...Done"sv, Index) + return ErrNo::Success; +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp b/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp new file mode 100644 index 00000000..587ceb13 --- /dev/null +++ b/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp @@ -0,0 +1,552 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "metadata_parser.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +// Parse metadata from json. +ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, + const std::string &Metadata, bool *IsModelUpdated, + bool *IsContextUpdated, bool *IsSamplerUpdated) noexcept { + // Parse metadata from the json. + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + auto ParseError = Parser.parse(Metadata).get(Doc); + if (ParseError) { + RET_ERROR(ErrNo::InvalidEncoding, "parse metadata error."sv) + } + + // Get the current llama parameters. + int64_t PrevNGPULayers = GraphRef.Params.n_gpu_layers; + bool PrevEmbedding = GraphRef.Params.embedding; + // Get the current sampler parameters. + double PrevTemp = GraphRef.Params.sampling.temp; + double PrevTopP = GraphRef.Params.sampling.top_p; + double PrevRepeatPenalty = GraphRef.Params.sampling.penalty_repeat; + double PrevPresencePenalty = GraphRef.Params.sampling.penalty_present; + double PrevFrequencyPenalty = GraphRef.Params.sampling.penalty_freq; + std::string PrevGrammar = GraphRef.Params.sampling.grammar; + uint64_t PrevSeed = GraphRef.Params.sampling.seed; + + try { + parseJsonAuto(Doc, "enable-log", GraphRef.EnableLog); + parseJsonAuto(Doc, "enable-debug-log", GraphRef.EnableDebugLog); + + parseJsonWithCastAuto(Doc, "main-gpu", GraphRef.Params.main_gpu); + parseJsonWithCastAuto(Doc, "n-gpu-layers", + GraphRef.Params.n_gpu_layers); + + parseJsonWithProcessorAuto(Doc, "cpu-moe", + [&GraphRef](const bool &CpuMoe) -> bool { + if (CpuMoe) { + GraphRef.TensorBuftOverrides.push_back( + "\\.ffn_(up|down|gate)_exps"); + } + return true; + }); + + parseJsonWithProcessorAuto( + Doc, "n-cpu-moe", [&GraphRef](const int64_t &NCpuMoe) -> bool { + if (NCpuMoe < 0) { + spdlog::error("[WASI-NN] GGML backend: Invalid n-cpu-moe value."); + return false; + } + for (int I = 0; I < NCpuMoe; I++) { + GraphRef.TensorBuftOverrides.push_back( + string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", I)); + } + return true; + }); + parseJsonWithProcessorAuto( + Doc, "tensor-split", [&GraphRef](const std::string_view &TSV) -> bool { + // The TensorSplit is a comma-separated list of non-negative values. + // E.g., "3,2" presents 60% of the data to GPU 0 and 40% to GPU 1. + std::string TS(TSV); + std::replace(TS.begin(), TS.end(), ',', ' '); + std::stringstream SS(TS); + std::memset(GraphRef.Params.tensor_split, 0, + sizeof(GraphRef.Params.tensor_split)); + uint32_t TensorSplitSize = 0; + while (SS.good()) { + float TmpTensor; + SS >> TmpTensor; + GraphRef.Params.tensor_split[TensorSplitSize++] = TmpTensor; + } + size_t NDevices = llama_max_devices(); + if (TensorSplitSize > NDevices) { + spdlog::error( + "[WASI-NN] GGML backend: Number of Tensor-Split is larger than " + "MaxDevices, please reduce the size of tensor-split."); + return false; + } + for (size_t Idx = TensorSplitSize; Idx < NDevices; Idx++) { + GraphRef.Params.tensor_split[TensorSplitSize++] = 0.0f; + } + return true; + }); + parseJsonAuto(Doc, "embedding", GraphRef.Params.embedding); + parseJsonWithProcessorAuto( + Doc, "split-mode", + [&GraphRef](const std::string_view &SplitMode) -> bool { + if (SplitMode == "none"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_NONE; + } else if (SplitMode == "layer"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_LAYER; + } else if (SplitMode == "row"sv) { + GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_ROW; + } else { + spdlog::error("[WASI-NN] GGML backend: Unknown split-mode: " + "{}. Valid: none, layer, row.", + SplitMode); + return false; + } + return true; + }); + parseJsonWithCastAuto(Doc, "mmproj", + GraphRef.Params.mmproj.path); + // The TTS parameters using macros + parseJsonAuto(Doc, "tts", GraphRef.TextToSpeech); + parseJsonWithCastAuto(Doc, "model-vocoder", + GraphRef.Params.vocoder.model.path); + parseJsonWithCastAuto(Doc, "tts-output-file", + GraphRef.TTSOutputFilePath); + + parseJsonWithCastAuto(Doc, "tts-speaker-file", + GraphRef.TTSSpeakerFilePath); + + // The context parameters + parseJsonWithCastAuto(Doc, "ctx-size", GraphRef.Params.n_ctx); + parseJsonWithCastAuto(Doc, "batch-size", GraphRef.Params.n_batch); + parseJsonWithCastAuto(Doc, "ubatch-size", + GraphRef.Params.n_ubatch); + parseJsonWithCastAuto(Doc, "n-keep", GraphRef.Params.n_keep); + parseJsonWithCastAuto(Doc, "n-chunks", GraphRef.Params.n_chunks); + parseJsonWithCastAuto(Doc, "n-parallel", + GraphRef.Params.n_parallel); + parseJsonWithCastAuto(Doc, "n-sequences", + GraphRef.Params.n_sequences); + parseJsonWithCastAuto(Doc, "grp-attn-n", + GraphRef.Params.grp_attn_n); + parseJsonWithCastAuto(Doc, "grp-attn-w", + GraphRef.Params.grp_attn_w); + parseJsonWithCastAuto(Doc, "n-print", GraphRef.Params.n_print); + parseJsonWithCastAuto(Doc, "rope-freq-base", + GraphRef.Params.rope_freq_base); + parseJsonWithCastAuto(Doc, "rope-freq-scale", + GraphRef.Params.rope_freq_scale); + parseJsonWithCastAuto(Doc, "yarn-ext-factor", + GraphRef.Params.yarn_ext_factor); + parseJsonWithCastAuto(Doc, "yarn-attn-factor", + GraphRef.Params.yarn_attn_factor); + parseJsonWithCastAuto(Doc, "yarn-beta-fast", + GraphRef.Params.yarn_beta_fast); + parseJsonWithCastAuto(Doc, "yarn-beta-slow", + GraphRef.Params.yarn_beta_slow); + parseJsonWithCastAuto(Doc, "yarn-orig-ctx", + GraphRef.Params.yarn_orig_ctx); + parseJsonAuto(Doc, "mask-valid", + GraphRef.Params.cpuparams.mask_valid); + parseJsonWithCastAuto(Doc, "priority", + GraphRef.Params.cpuparams.priority); + parseJsonAuto(Doc, "strict-cpu", + GraphRef.Params.cpuparams.strict_cpu); + parseJsonWithCastAuto(Doc, "poll", GraphRef.Params.cpuparams.poll); + parseJsonAuto(Doc, "mask-valid-batch", + GraphRef.Params.cpuparams_batch.mask_valid); + parseJsonWithProcessorAuto( + Doc, "priority-batch", [&GraphRef](const int64_t &Priority) -> bool { + GraphRef.Params.cpuparams_batch.priority = + static_cast(Priority); + return true; + }); + parseJsonAuto(Doc, "strict-cpu-batch", + GraphRef.Params.cpuparams_batch.strict_cpu); + parseJsonWithCastAuto(Doc, "poll-batch", + GraphRef.Params.cpuparams_batch.poll); + + parseJsonWithCastAuto(Doc, "numa", GraphRef.Params.numa); + + parseJsonWithCastAuto(Doc, "rope-scaling-type", + GraphRef.Params.rope_scaling_type); + parseJsonWithCastAuto(Doc, "pooling-type", + GraphRef.Params.pooling_type); + parseJsonWithCastAuto(Doc, "attention-type", + GraphRef.Params.attention_type); + parseJsonWithProcessorAuto( + Doc, "threads", [&GraphRef](const int64_t &NThreads) -> bool { + GraphRef.Params.cpuparams.n_threads = static_cast(NThreads); + return true; + }); + parseJsonWithCastAuto(Doc, "threads-batch", + GraphRef.Params.cpuparams_batch.n_threads); + parseJsonWithCastAuto(Doc, "n-prev", + GraphRef.Params.sampling.n_prev); + parseJsonWithCastAuto(Doc, "n-probs", + GraphRef.Params.sampling.n_probs); + parseJsonWithCastAuto(Doc, "min-keep", + GraphRef.Params.sampling.min_keep); + parseJsonWithCastAuto(Doc, "top-k", + GraphRef.Params.sampling.top_k); + parseJsonWithCastAuto(Doc, "min-p", GraphRef.Params.sampling.min_p); + parseJsonWithCastAuto(Doc, "xtc-probability", + GraphRef.Params.sampling.xtc_probability); + parseJsonWithCastAuto(Doc, "xtc-threshold", + GraphRef.Params.sampling.xtc_threshold); + parseJsonWithCastAuto(Doc, "typ-p", GraphRef.Params.sampling.typ_p); + parseJsonWithCastAuto(Doc, "dynatemp-range", + GraphRef.Params.sampling.dynatemp_range); + parseJsonWithCastAuto(Doc, "dynatemp-exponent", + GraphRef.Params.sampling.dynatemp_exponent); + parseJsonWithCastAuto(Doc, "last-n-penalty", + GraphRef.Params.sampling.penalty_last_n); + parseJsonWithProcessorAuto( + Doc, "temp", [&GraphRef](const double &Temp) -> bool { + GraphRef.Params.sampling.temp = + static_cast(std::max(0.0, Temp)); + return true; + }); + + parseJsonWithProcessorAuto( + Doc, "top-p", [&GraphRef](const double &TopP) -> bool { + GraphRef.Params.sampling.top_p = + static_cast(std::max(0.0, TopP)); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "repeat-penalty", + [&GraphRef](const double &RepeatPenalty) -> bool { + GraphRef.Params.sampling.penalty_repeat = + static_cast(std::max(0.0, RepeatPenalty)); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "presence-penalty", + [&GraphRef](const double &PresencePenalty) -> bool { + GraphRef.Params.sampling.penalty_present = + static_cast(std::max(0.0, PresencePenalty)); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "frequency-penalty", + [&GraphRef](const double &FrequencyPenalty) -> bool { + GraphRef.Params.sampling.penalty_freq = + static_cast(std::max(0.0, FrequencyPenalty)); + return true; + }); + parseJsonWithCastAuto(Doc, "dry-multipier", + GraphRef.Params.sampling.dry_multiplier); + parseJsonWithCastAuto(Doc, "dry-base", + GraphRef.Params.sampling.dry_base); + parseJsonWithCastAuto(Doc, "dry-allowed-length", + GraphRef.Params.sampling.dry_allowed_length); + parseJsonWithCastAuto(Doc, "dry-last-n-penalty", + GraphRef.Params.sampling.penalty_last_n); + parseJsonWithCastAuto(Doc, "mirostat", + GraphRef.Params.sampling.mirostat); + parseJsonWithCastAuto(Doc, "mirostat-eta", + GraphRef.Params.sampling.mirostat_eta); + parseJsonAuto(Doc, "ignore-eos", GraphRef.Params.sampling.ignore_eos); + parseJsonAuto(Doc, "no-perf-sampling", + GraphRef.Params.sampling.no_perf); + parseJsonAuto(Doc, "timing-per-token", + GraphRef.Params.sampling.timing_per_token); + + parseJsonWithCastAuto(Doc, "grammar", + GraphRef.Params.sampling.grammar); + parseJsonWithProcessorAuto( + Doc, "json-schema", + [&GraphRef](const std::string_view &JsonSchema) -> bool { + GraphRef.Params.sampling.grammar = + json_schema_to_grammar(nlohmann::ordered_json::parse(JsonSchema)); + return true; + }); + parseJsonWithCastAuto(Doc, "seed", GraphRef.Params.sampling.seed); + + // The speculative parameters. + parseJsonWithCastAuto(Doc, "n-ctx-speculative", + GraphRef.Params.speculative.n_ctx); + parseJsonWithCastAuto(Doc, "n-max-speculative", + GraphRef.Params.speculative.n_max); + parseJsonWithCastAuto(Doc, "n-min-speculative", + GraphRef.Params.speculative.n_min); + parseJsonWithCastAuto(Doc, "n-gpu-layers-speculative", + GraphRef.Params.speculative.n_gpu_layers); + parseJsonWithCastAuto(Doc, "p-split-speculative", + GraphRef.Params.speculative.p_split); + parseJsonWithCastAuto(Doc, "p-min-speculative", + GraphRef.Params.speculative.p_min); + // The vocoder parameters. + parseJsonWithCastAuto( + Doc, "hf-repo-vocoder", GraphRef.Params.vocoder.model.hf_repo); + parseJsonWithCastAuto( + Doc, "hf-file-vocoder", GraphRef.Params.vocoder.model.hf_file); + parseJsonWithCastAuto(Doc, "model-url-vocoder", + GraphRef.Params.vocoder.model.url); + + // The config parameters. + parseJsonAuto(Doc, "stream-stdout", ConfRef.StreamStdout); + parseJsonAuto(Doc, "n-predict", ConfRef.NPredict); + parseJsonWithCastAuto(Doc, "reverse-prompt", + ConfRef.ReversePrompt); + parseJsonWithCastAuto(Doc, "image", ConfRef.ImagePath); + parseJsonAuto(Doc, "always-regenerate-image-embd", + ConfRef.AlwaysRegenerateImageEmbd); + parseJsonWithCastAuto(Doc, "model-alias", + GraphRef.Params.model_alias); + parseJsonWithCastAuto(Doc, "model-url", + GraphRef.Params.model.url); + parseJsonWithCastAuto(Doc, "hf-token", + GraphRef.Params.hf_token); + parseJsonWithCastAuto(Doc, "hf-repo", + GraphRef.Params.model.hf_repo); + parseJsonWithCastAuto(Doc, "hf-file", + GraphRef.Params.model.hf_file); + parseJsonWithCastAuto(Doc, "prompt-file", + GraphRef.Params.prompt_file); + parseJsonWithCastAuto(Doc, "path-prompt-cache", + GraphRef.Params.path_prompt_cache); + parseJsonWithCastAuto(Doc, "input-prefix", + GraphRef.Params.input_prefix); + parseJsonWithCastAuto(Doc, "input-suffix", + GraphRef.Params.input_suffix); + parseJsonWithCastAuto( + Doc, "lookup-cache-static", GraphRef.Params.lookup_cache_static); + parseJsonWithCastAuto( + Doc, "lookup-cache-dynamic", GraphRef.Params.lookup_cache_dynamic); + parseJsonWithCastAuto(Doc, "logits-file", + GraphRef.Params.logits_file); + parseJsonAuto(Doc, "lora-init-without-apply", + GraphRef.Params.lora_init_without_apply); + parseJsonWithCastAuto(Doc, "verbosity", GraphRef.Params.verbosity); + parseJsonWithCastAuto(Doc, "control-vector-layer-start", + GraphRef.Params.control_vector_layer_start); + parseJsonWithCastAuto(Doc, "control-vector-layer-end", + GraphRef.Params.control_vector_layer_end); + parseJsonWithCastAuto(Doc, "ppl-stride", + GraphRef.Params.ppl_stride); + parseJsonWithCastAuto(Doc, "ppl-output-type", + GraphRef.Params.ppl_output_type); + parseJsonAuto(Doc, "hellaswag", GraphRef.Params.hellaswag); + parseJsonWithCastAuto(Doc, "hellaswag-tasks", + GraphRef.Params.hellaswag_tasks); + parseJsonAuto(Doc, "winogrande", GraphRef.Params.winogrande); + parseJsonWithCastAuto(Doc, "winogrande-tasks", + GraphRef.Params.winogrande_tasks); + parseJsonAuto(Doc, "multiple-choice", + GraphRef.Params.multiple_choice); + parseJsonWithCastAuto( + Doc, "multiple-choice-tasks", GraphRef.Params.multiple_choice_tasks); + parseJsonAuto(Doc, "kl-divergence", GraphRef.Params.kl_divergence); + parseJsonAuto(Doc, "usage", GraphRef.Params.usage); + parseJsonAuto(Doc, "use-color", GraphRef.Params.use_color); + parseJsonAuto(Doc, "special", GraphRef.Params.special); + parseJsonAuto(Doc, "interactive", GraphRef.Params.interactive); + parseJsonAuto(Doc, "interactive-first", + GraphRef.Params.interactive_first); + parseJsonAuto(Doc, "prompt-cache-all", + GraphRef.Params.prompt_cache_all); + parseJsonAuto(Doc, "prompt-cache-ro", + GraphRef.Params.prompt_cache_ro); + parseJsonAuto(Doc, "escape", GraphRef.Params.escape); + parseJsonAuto(Doc, "multiline-input", + GraphRef.Params.multiline_input); + parseJsonAuto(Doc, "simple-io", GraphRef.Params.simple_io); + parseJsonAuto(Doc, "cont-batching", GraphRef.Params.cont_batching); + parseJsonWithProcessorAuto( + Doc, "flash-attn", + [&GraphRef](const std::string_view &FlashAttn) -> bool { + if (FlashAttn == "on"sv || FlashAttn == "enabled"sv) { + GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; + } else if (FlashAttn == "off"sv || FlashAttn == "disabled"sv) { + GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; + } else if (FlashAttn == "auto"sv) { + GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; + } else { + spdlog::error( + "[WASI-NN] GGML backend: The flash-attn option must be " + "one of: on, off, auto."); + return false; + } + return true; + }); + parseJsonAuto(Doc, "no-perf", GraphRef.Params.no_perf); + parseJsonAuto(Doc, "ctx-shift", GraphRef.Params.ctx_shift); + parseJsonAuto(Doc, "input-prefix-bos", + GraphRef.Params.input_prefix_bos); + parseJsonAuto(Doc, "use-mlock", GraphRef.Params.use_mlock); + parseJsonAuto(Doc, "use-mmap", GraphRef.Params.use_mmap); + parseJsonAuto(Doc, "verbose-prompt", GraphRef.Params.verbose_prompt); + parseJsonAuto(Doc, "display-prompt", GraphRef.Params.display_prompt); + parseJsonAuto(Doc, "no-kv-offload", GraphRef.Params.no_kv_offload); + parseJsonAuto(Doc, "warmup", GraphRef.Params.warmup); + parseJsonAuto(Doc, "check-tensors", GraphRef.Params.check_tensors); + parseJsonWithCastAuto(Doc, "cache-type-k", + GraphRef.Params.cache_type_k); + parseJsonWithCastAuto(Doc, "cache-type-v", + GraphRef.Params.cache_type_v); + + parseJsonWithCastAuto(Doc, "embd-normalize", + GraphRef.Params.embd_normalize); + parseJsonWithCastAuto(Doc, "embd-out", + GraphRef.Params.embd_out); + + parseJsonWithCastAuto(Doc, "embd-sep", + GraphRef.Params.embd_sep); + parseJsonWithProcessorAuto( + Doc, "reranking", [&GraphRef](const bool &) -> bool { + GraphRef.Params.embedding = true; + GraphRef.Params.pooling_type = LLAMA_POOLING_TYPE_RANK; + return true; + }); + parseJsonWithCastAuto(Doc, "port", GraphRef.Params.port); + parseJsonWithCastAuto(Doc, "timeout-read", + GraphRef.Params.timeout_read); + parseJsonWithCastAuto(Doc, "timeout-write", + GraphRef.Params.timeout_write); + parseJsonWithCastAuto(Doc, "n-threads-http", + GraphRef.Params.n_threads_http); + parseJsonWithCastAuto(Doc, "n-cache-reuse", + GraphRef.Params.n_cache_reuse); + parseJsonWithCastAuto(Doc, "hostname", + GraphRef.Params.hostname); + parseJsonWithCastAuto(Doc, "public-path", + GraphRef.Params.public_path); + parseJsonWithCastAuto(Doc, "chat-template", + GraphRef.Params.chat_template); + parseJsonAuto(Doc, "enable-chat-template", + GraphRef.Params.enable_chat_template); + parseJsonWithCastAuto(Doc, "ssl-file-key", + GraphRef.Params.ssl_file_key); + parseJsonWithCastAuto(Doc, "ssl-file-cert", + GraphRef.Params.ssl_file_cert); + parseJsonAuto(Doc, "webui", GraphRef.Params.webui); + parseJsonAuto(Doc, "endpoint-slots", GraphRef.Params.endpoint_slots); + parseJsonAuto(Doc, "endpoint-props", GraphRef.Params.endpoint_props); + parseJsonAuto(Doc, "endpoint-metrics", + GraphRef.Params.endpoint_metrics); + parseJsonAuto(Doc, "log-json", GraphRef.Params.log_json); + + // Slot parameters + parseJsonWithCastAuto(Doc, "slot-save-path", + GraphRef.Params.slot_save_path); + + parseJsonWithCastAuto(Doc, "slot-prompt-similarity", + GraphRef.Params.slot_prompt_similarity); + parseJsonAuto(Doc, "is-pp-shared", GraphRef.Params.is_pp_shared); + parseJsonWithProcessorAuto( + Doc, "n-pp", [&GraphRef](const int64_t &NPP) -> bool { + GraphRef.Params.n_pp.emplace_back(NPP); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "n-tg", [&GraphRef](const int64_t &NTG) -> bool { + GraphRef.Params.n_tg.emplace_back(NTG); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "n-pl", [&GraphRef](const int64_t &NPL) -> bool { + GraphRef.Params.n_pl.emplace_back(NPL); + return true; + }); + parseJsonWithProcessorAuto( + Doc, "context-files", + [&GraphRef](const std::string_view &ContextFile) -> bool { + GraphRef.Params.context_files.emplace_back(ContextFile); + return true; + }); + parseJsonWithCastAuto(Doc, "chunk-size", + GraphRef.Params.chunk_size); + parseJsonWithCastAuto(Doc, "chunk-separator", + GraphRef.Params.chunk_separator); + parseJsonWithCastAuto(Doc, "n-junk", GraphRef.Params.n_junk); + parseJsonWithCastAuto(Doc, "i-pos", GraphRef.Params.i_pos); + parseJsonWithCastAuto(Doc, "out-file", + GraphRef.Params.out_file); + parseJsonWithCastAuto(Doc, "n-out-freq", + GraphRef.Params.n_out_freq); + parseJsonWithCastAuto(Doc, "n-save-freq", + GraphRef.Params.n_save_freq); + parseJsonWithCastAuto(Doc, "i-chunk", GraphRef.Params.i_chunk); + parseJsonAuto(Doc, "process-output", GraphRef.Params.process_output); + parseJsonAuto(Doc, "compute-ppl", GraphRef.Params.compute_ppl); + parseJsonWithCastAuto(Doc, "n-pca-batch", + GraphRef.Params.n_pca_batch); + parseJsonWithCastAuto(Doc, "n-pca-iterations", + GraphRef.Params.n_pca_iterations); + parseJsonWithProcessorAuto( + Doc, "cvector-dimre-method", + [&GraphRef](const std::string_view &Method) -> bool { + if (Method == "DIMRE_METHOD_PCA") { + GraphRef.Params.cvector_dimre_method = + dimre_method::DIMRE_METHOD_PCA; + return true; + } + if (Method == "DIMRE_METHOD_MEAN") { + GraphRef.Params.cvector_dimre_method = + dimre_method::DIMRE_METHOD_MEAN; + return true; + } + return false; + }); + parseJsonWithCastAuto( + Doc, "cvector-positive-file", GraphRef.Params.cvector_positive_file); + parseJsonWithCastAuto( + Doc, "cvector-negative-file", GraphRef.Params.cvector_negative_file); + parseJsonAuto(Doc, "spm-infill", GraphRef.Params.spm_infill); + parseJsonWithCastAuto(Doc, "out-file", + GraphRef.Params.out_file); + parseJsonAuto(Doc, "batched-bench-output-jsonl", + GraphRef.Params.batched_bench_output_jsonl); + } catch (const ErrNo &Error) { + return Error; + } + // The tensor buffer overrides should terminated with empty pattern. + if (!GraphRef.TensorBuftOverrides.empty()) { + for (const std::string &Override : GraphRef.TensorBuftOverrides) { + GraphRef.Params.tensor_buft_overrides.push_back( + {Override.c_str(), ggml_backend_cpu_buffer_type()}); + } + GraphRef.Params.tensor_buft_overrides.push_back({nullptr, nullptr}); + } + + if (GraphRef.TextToSpeech) { + GraphRef.Params.sampling.top_k = 4; + GraphRef.Params.sampling.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + }; + } + // Check if the model parameters are updated. + if (IsModelUpdated && PrevNGPULayers != GraphRef.Params.n_gpu_layers) { + *IsModelUpdated = true; + } + + // Check if the context parameters are updated. + if (IsContextUpdated && PrevEmbedding != GraphRef.Params.embedding) { + *IsContextUpdated = true; + } + + // Check if the sampler parameters are updated. + if (IsSamplerUpdated && + (PrevTemp != GraphRef.Params.sampling.temp || + PrevTopP != GraphRef.Params.sampling.top_p || + PrevRepeatPenalty != GraphRef.Params.sampling.penalty_repeat || + PrevPresencePenalty != GraphRef.Params.sampling.penalty_present || + PrevFrequencyPenalty != GraphRef.Params.sampling.penalty_freq || + PrevGrammar != GraphRef.Params.sampling.grammar || + PrevSeed != GraphRef.Params.sampling.seed)) { + *IsSamplerUpdated = true; + } + + return ErrNo::Success; +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/metadata/metadata_parser.h b/plugins/wasi_nn/GGML/metadata/metadata_parser.h new file mode 100644 index 00000000..40971146 --- /dev/null +++ b/plugins/wasi_nn/GGML/metadata/metadata_parser.h @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#pragma once + +#include "GGML/core/ggml_core.h" +#include "simdjson.h" +#include "wasinntypes.h" + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { + +// JSON parsing helper template +template +T getJsonValue(const simdjson::dom::element &Doc, std::string_view Key) { + if (Doc.at_key(Key).error() == simdjson::SUCCESS) { + T Value{}; + auto Err = Doc[Key].get().get(Value); + if (Err) { + std::string Msg = fmt::format("Unable to retrieve the {} option.", Key); + spdlog::error("[WASI-NN] GGML backend: {}", Msg); + throw ErrNo(ErrNo::InvalidArgument); + } + return Value; + } + throw ErrNo(ErrNo::NotFound); +} + +template +void parseJsonAuto(const simdjson::dom::element &Doc, std::string_view Key, + T &Var) { + try { + auto Result = getJsonValue(Doc, Key); + Var = Result; + } catch (ErrNo E) { + if (E != ErrNo::NotFound) { + throw E; + } + } +} + +template +void parseJsonWithCastAuto(const simdjson::dom::element &Doc, + std::string_view Key, ToType &Var) { + try { + auto Result = getJsonValue(Doc, Key); + Var = static_cast(Result); + } catch (ErrNo E) { + if (E != ErrNo::NotFound) { + throw E; + } + } +} + +template +void parseJsonWithProcessorAuto(const simdjson::dom::element &Doc, + std::string_view Key, Processor Proc) { + try { + auto Result = getJsonValue(Doc, Key); + if (!Proc(Result)) { + throw ErrNo{ErrNo::InvalidArgument}; + } + } catch (ErrNo E) { + if (E != ErrNo::NotFound) { + throw E; + } + } +} + +} // namespace +ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, + const std::string &Metadata, bool *IsModelUpdated = nullptr, + bool *IsContextUpdated = nullptr, + bool *IsSamplerUpdated = nullptr) noexcept; +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/tts/tts_core.cpp b/plugins/wasi_nn/GGML/tts/tts_core.cpp new file mode 100644 index 00000000..3872b5b3 --- /dev/null +++ b/plugins/wasi_nn/GGML/tts/tts_core.cpp @@ -0,0 +1,480 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "tts_core.h" +#include "host/wasi/vfs_io.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#include +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { + +// TTS related functions. +void fillHannWindow(int Length, bool Periodic, float *Output) { + int Offset = -1; + float Pi = static_cast(std::acos(-1)); + if (Periodic) { + Offset = 0; + } + for (int I = 0; I < Length; I++) { + double Value = + 0.5 * + (1.0 - cosf(static_cast((2.0 * Pi * I) / (Length + Offset)))); + Output[I] = static_cast(Value); + } +} + +void twiddle(float *Real, float *Imag, int K, int N) { + float Pi = static_cast(std::acos(-1)); + float Angle = 2 * Pi * K / N; + *Real = cos(Angle); + *Imag = sin(Angle); +} + +void irfft(int N, const float *InpCplx, float *OutReal) { + int NN = N / 2 + 1; + + std::vector RealInput(NN); + std::vector ImagInput(NN); + for (int I = 0; I < NN; ++I) { + RealInput[I] = InpCplx[2 * I]; + ImagInput[I] = InpCplx[2 * I + 1]; + } + + std::vector RealOutput(N); + std::vector ImagOutput(N); + + for (int K = 0; K < N; ++K) { + RealOutput[K] = 0.0f; + ImagOutput[K] = 0.0f; + for (int M = 0; M < NN; ++M) { + float TwiddleReal; + float TwiddleImag; + + twiddle(&TwiddleReal, &TwiddleImag, K * M, N); + + RealOutput[K] += RealInput[M] * TwiddleReal - ImagInput[M] * TwiddleImag; + ImagOutput[K] += RealInput[M] * TwiddleImag + ImagInput[M] * TwiddleReal; + } + } + + for (int I = 0; I < N; ++I) { + OutReal[I] = RealOutput[I] / NN; + } +} + +void fold(const std::vector &Data, int64_t NOut, int64_t NWin, + int64_t NHop, int64_t NPad, std::vector &Output) { + int64_t OutputHeight = NOut; + int64_t KernelW = NWin; + int64_t StrideW = NHop; + int64_t Width = NOut; + + Output.resize(Width, 0.0f); + + int64_t ColIdx = 0; + for (int64_t WCol = 0; WCol < Width; ++WCol) { + int64_t Start = WCol * StrideW - NPad; + int64_t End = Start + KernelW; + + for (int64_t WIm = Start; WIm < End; ++WIm) { + if (WIm >= 0 && WIm < OutputHeight && + ColIdx < static_cast(Data.size())) { + Output[WIm] += Data[ColIdx]; + } + ColIdx++; + } + } + + Output.resize(NOut - 2 * NPad); +} + +std::vector embdToAudio(const float *Embd, const int NCodes, + const int NEmbd, const int NThread) { + const int NFft = 1280; + const int NHop = 320; + const int NWin = 1280; + const int NPad = (NWin - NHop) / 2; + const int NOut = (NCodes - 1) * NHop + NWin; + + std::vector Hann(NFft); + + fillHannWindow(static_cast(Hann.size()), true, Hann.data()); + + int NSpec = NEmbd * NCodes; + + std::vector E(NSpec); + std::vector S(NSpec); + std::vector ST(NSpec); + + for (int L = 0; L < NCodes; ++L) { + for (int K = 0; K < NEmbd; ++K) { + E[K * NCodes + L] = Embd[L * NEmbd + K]; + } + } + + for (int K = 0; K < NEmbd / 2; ++K) { + for (int L = 0; L < NCodes; ++L) { + float Mag = E[(K)*NCodes + L]; + float Phi = E[(K + NEmbd / 2) * NCodes + L]; + + Mag = exp(Mag); + + if (Mag > 1e2) { + Mag = 1e2; + } + S[2 * (K * NCodes + L) + 0] = Mag * cosf(Phi); + S[2 * (K * NCodes + L) + 1] = Mag * sinf(Phi); + } + } + + for (int L = 0; L < NCodes; ++L) { + for (int K = 0; K < NEmbd / 2; ++K) { + ST[L * NEmbd + 2 * K + 0] = S[2 * (K * NCodes + L) + 0]; + ST[L * NEmbd + 2 * K + 1] = S[2 * (K * NCodes + L) + 1]; + } + } + + std::vector Res(NCodes * NFft); + std::vector Hann2(NCodes * NFft); + + std::vector Workers(NThread); + for (int I = 0; I < NThread; ++I) { + Workers[I] = std::thread([&, I]() { + for (int L = I; L < NCodes; L += NThread) { + irfft(NFft, ST.data() + L * NEmbd, Res.data() + L * NFft); + for (int J = 0; J < NFft; ++J) { + Res[L * NFft + J] *= Hann[J]; + Hann2[L * NFft + J] = Hann[J] * Hann[J]; + } + } + }); + } + for (int I = 0; I < NThread; ++I) { + Workers[I].join(); + } + + std::vector Audio; + std::vector Env; + + fold(Res, NOut, NWin, NHop, NPad, Audio); + fold(Hann2, NOut, NWin, NHop, NPad, Env); // TODO: can be done once + + for (size_t I = 0; I < Audio.size(); ++I) { + Audio[I] /= Env[I]; + } + + return Audio; +} + +struct WavHeader { + char Riff[4] = {'R', 'I', 'F', 'F'}; + uint32_t ChunkSize; + char Wave[4] = {'W', 'A', 'V', 'E'}; + char Fmt[4] = {'f', 'm', 't', ' '}; + uint32_t FmtChunkSize = 16; + uint16_t AudioFormat = 1; // PCM + uint16_t NumChannels = 1; // Mono + uint32_t SampleRate; + uint32_t ByteRate; + uint16_t BlockAlign; + uint16_t BitsPerSample = 16; + char Data[4] = {'d', 'a', 't', 'a'}; + uint32_t DataSize; +}; + +std::vector audioDataToWav(const std::vector &Data, + int SampleRate) { + std::vector WavData; + WavHeader Header; + Header.SampleRate = SampleRate; + Header.ByteRate = + Header.SampleRate * Header.NumChannels * (Header.BitsPerSample / 8); + Header.BlockAlign = Header.NumChannels * (Header.BitsPerSample / 8); + Header.DataSize = + static_cast(Data.size() * (Header.BitsPerSample / 8)); + Header.ChunkSize = 36 + Header.DataSize; + + WavData.insert(WavData.end(), reinterpret_cast(&Header), + reinterpret_cast(&Header) + sizeof(Header)); + + for (const auto &Sample : Data) { + int16_t PCMSample = + static_cast(std::clamp(Sample * 32767.0, -32768.0, 32767.0)); + WavData.insert(WavData.end(), reinterpret_cast(&PCMSample), + reinterpret_cast(&PCMSample) + sizeof(PCMSample)); + } + + return WavData; +} + +// Convert a number less than 1000 to words +std::string convertLessThanThousand(int Num) { + std::string Result; + + if (Num >= 100) { + Result += Ones.at(Num / 100) + " hundred "; + Num %= 100; + } + + if (Num >= 20) { + Result += Tens.at(Num / 10); + if (Num % 10 > 0) { + Result += "-" + Ones.at(Num % 10); + } + } else if (Num > 0) { + Result += Ones.at(Num); + } + + return Result; +} + +std::string numberToWords(const std::string &NumberStr) { + try { + size_t DecimalPos = NumberStr.find('.'); + std::string IntegerPart = NumberStr.substr(0, DecimalPos); + + int IntNumber = std::stoi(IntegerPart); + std::string Result; + + if (IntNumber == 0) { + Result = "zero"; + } else { + if (IntNumber >= 1000000000) { + int Billions = IntNumber / 1000000000; + Result += convertLessThanThousand(Billions) + " billion "; + IntNumber %= 1000000000; + } + + if (IntNumber >= 1000000) { + int Millions = IntNumber / 1000000; + Result += convertLessThanThousand(Millions) + " million "; + IntNumber %= 1000000; + } + + if (IntNumber >= 1000) { + int Thousands = IntNumber / 1000; + Result += convertLessThanThousand(Thousands) + " thousand "; + IntNumber %= 1000; + } + + if (IntNumber > 0) { + Result += convertLessThanThousand(IntNumber); + } + } + + // Handle decimal part + if (DecimalPos != std::string::npos) { + Result += " point"; + std::string DecimalPart = NumberStr.substr(DecimalPos + 1); + for (char Digit : DecimalPart) { + Result += " " + Ones.at(Digit - '0'); + } + } + + return Result; + } catch (const std::exception &) { + // Skip if fails + return " "; + } +} + +std::string replaceNumbersWithWords(const std::string &InputText) { + std::regex NumberPattern(R"(\d+(\.\d+)?)"); + std::string Result; + auto It = + std::sregex_iterator(InputText.begin(), InputText.end(), NumberPattern); + auto End = std::sregex_iterator(); + + size_t LastPos = 0; + for (std::sregex_iterator I = It; I != End; ++I) { + const std::smatch &Match = *I; + Result.append(InputText, LastPos, Match.position() - LastPos); + Result.append(numberToWords(Match.str())); + LastPos = Match.position() + Match.length(); + } + Result.append(InputText, LastPos); + + return Result; +} + +std::vector processTTSPrompt(WasiNNEnvironment &Env, + Graph &GraphRef, + std::string &Prompt) noexcept { + // Use the custom speaker profile if available. + TTSSpeakerProfile SpeakerProfile = TTSDefaultSpeakerProfile; + if (!GraphRef.TTSSpeakerFilePath.empty()) { + std::optional SpeakerProfileOpt = + getSpeakerProfileFromFile(GraphRef.TTSSpeakerFilePath, Env); + if (SpeakerProfileOpt.has_value()) { + SpeakerProfile = *SpeakerProfileOpt; + } else { + RET_ERROR( + {}, + "processTTSPrompt: Failed to load speaker profile from file: {}"sv, + GraphRef.TTSSpeakerFilePath); + } + } + std::string ProcessedPrompt = processTTSPromptText(Prompt); + std::vector Result, TmpTokens; + Result = common_tokenize(GraphRef.LlamaContext.get(), "<|im_start|>\n", + /* add_special */ true, + /* parse_special */ true); + TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), SpeakerProfile.Text, + /* add_special */ false, + /* parse_special */ true); + Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); + TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), ProcessedPrompt, + /* add_special */ false, + /* parse_special */ true); + Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); + TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), "<|text_end|>\n", + /* add_special */ false, + /* parse_special */ true); + Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); + TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), SpeakerProfile.Data, + /* add_special */ false, + /* parse_special */ true); + Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); + + return Result; +} + +// Based on: +// https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39 +// https://github.com/ggerganov/llama.cpp/blob/b4488/examples/tts/tts.cpp#L374 +std::string processTTSPromptText(const std::string &Text) { + std::string ProcessedText = replaceNumbersWithWords(Text); + + std::transform( + ProcessedText.begin(), ProcessedText.end(), ProcessedText.begin(), + [](unsigned char C) { return static_cast(::tolower(C)); }); + + std::regex SpecialChars(R"([-_/,\.\\])"); + ProcessedText = std::regex_replace(ProcessedText, SpecialChars, " "); + + std::regex NonAlpha(R"([^a-z\s])"); + ProcessedText = std::regex_replace(ProcessedText, NonAlpha, ""); + + std::regex MultipleSpaces(R"(\s+)"); + ProcessedText = std::regex_replace(ProcessedText, MultipleSpaces, " "); + + ProcessedText = + std::regex_replace(ProcessedText, std::regex(R"(^\s+|\s+$)"), ""); + + ProcessedText = + std::regex_replace(ProcessedText, std::regex(R"(\s)"), "<|text_sep|>"); + + return ProcessedText; +} + +std::optional +getSpeakerProfileFromFile(const std::string &FilePath, WasiNNEnvironment &Env) { + WasmEdge::FStream::IFStream JsonFile(FilePath, Env.getEnv()); + if (!JsonFile.is_open()) { + return std::nullopt; + } + nlohmann::json JsonData; + JsonFile >> JsonData; + JsonFile.close(); + + // Initialize the outputs + std::string AudioOutputText = "<|audio_start|>\n"; + std::string TextOutput = "<|text_start|>"; + + // Iterate through each word in the JSON data + for (const auto &WordData : JsonData["words"]) { + std::string Word = WordData["word"]; + double Duration = WordData["duration"]; + std::vector Codes = WordData["codes"]; + + // Create the audio output entry + std::ostringstream WordEntry; + WordEntry << Word << "<|t_" << std::fixed << std::setprecision(2) + << Duration << "|><|code_start|>"; + for (const auto &Code : Codes) { + WordEntry << "<|" << Code << "|>"; + } + WordEntry << "<|code_end|>\n"; + AudioOutputText += WordEntry.str(); + + // Create the text output entry + TextOutput += Word + "<|text_sep|>"; + } + + return TTSSpeakerProfile{TextOutput, AudioOutputText}; +} + +// TextToSpeech function, will generate voice data from codes. +ErrNo codesToSpeech(WasiNNEnvironment &Env, Graph &GraphRef, + Context &CxtRef) noexcept { + // Remove all non-audio tokens. + CxtRef.LlamaOutputTokens.erase( + std::remove_if(CxtRef.LlamaOutputTokens.begin(), + CxtRef.LlamaOutputTokens.end(), + [](llama_token T) { return T < 151672 || T > 155772; }), + CxtRef.LlamaOutputTokens.end()); + + // Adjust the token values for audio data. + for (llama_token &Token : CxtRef.LlamaOutputTokens) { + Token -= 151672; + } + + // Put codes into batch. + const uint32_t NCodes = + static_cast(CxtRef.LlamaOutputTokens.size()); + llama_batch TTSBatch = + llama_batch_init(NCodes, /* embd */ 0, /* n_seq_max */ 1); + for (uint32_t I = 0; I < NCodes; ++I) { + common_batch_add(TTSBatch, CxtRef.LlamaOutputTokens[I], I, + /* seq_ids */ {0}, /* logits */ true); + } + if (llama_decode(GraphRef.TTSContext.get(), TTSBatch) != 0) { + RET_ERROR(ErrNo::RuntimeError, "codesToSpeech: fail to eval."sv) + } + llama_batch_free(TTSBatch); + + // Get embeddings. + const int NEmbd = llama_model_n_embd(GraphRef.TTSModel.get()); + const float *Embd = llama_get_embeddings(GraphRef.TTSContext.get()); + + // Embeddings to audio. + std::vector AudioData = + embdToAudio(Embd, NCodes, NEmbd, + static_cast(GraphRef.Params.cpuparams.n_threads)); + + // Zero out first 0.25 seconds of audio. + const uint32_t SamplingRate = 24000; + for (uint32_t I = 0; I < SamplingRate / 4; ++I) { + AudioData[I] = 0.0f; + } + + // Convert audio data to wav and put it into output buffer. + CxtRef.LlamaOutputs = audioDataToWav(AudioData, SamplingRate); + + // Save .wav file if path is provided. + if (!GraphRef.TTSOutputFilePath.empty()) { + WasmEdge::FStream::OFStream File(GraphRef.TTSOutputFilePath, + std::ios_base::out | std::ios_base::binary, + Env.getEnv()); + if (!File) { + RET_ERROR(ErrNo::RuntimeError, + "codesToSpeech: Failed to open file '{}' for writing"sv, + GraphRef.TTSOutputFilePath); + } + File.write(reinterpret_cast(CxtRef.LlamaOutputs.data()), + CxtRef.LlamaOutputs.size()); + File.close(); + } + + return ErrNo::Success; +} + +} // namespace +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/tts/tts_core.h b/plugins/wasi_nn/GGML/tts/tts_core.h new file mode 100644 index 00000000..b5c9ae65 --- /dev/null +++ b/plugins/wasi_nn/GGML/tts/tts_core.h @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "GGML/core/ggml_core.h" + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +namespace { + +struct TTSSpeakerProfile { + std::string Text; + std::string Data; +}; + +// TTS function to process the prompt text. +// clang-format off +const TTSSpeakerProfile TTSDefaultSpeakerProfile = { + // Speaker profile from edwko/OuteTTS (en_female_1.json). + "<|text_start|>uhm<|text_sep|>now<|text_sep|>being<|text_sep|>the<|text_sep|>one<|text_sep|>to<|text_sep|>say<|text_sep|>i<|text_sep|>know<|text_sep|>the<|text_sep|>worst<|text_sep|>of<|text_sep|>you<|text_sep|>and<|text_sep|>ive<|text_sep|>been<|text_sep|>directly<|text_sep|>affected<|text_sep|>by<|text_sep|>people<|text_sep|>like<|text_sep|>you<|text_sep|>but<|text_sep|>its<|text_sep|>a<|text_sep|>clean<|text_sep|>slate<|text_sep|>with<|text_sep|>me<|text_sep|>buddy<|text_sep|>you<|text_sep|>know<|text_sep|>like<|text_sep|>thats<|text_sep|>really<|text_sep|>powerful<|text_sep|>in<|text_sep|>and<|text_sep|>of<|text_sep|>itself<|text_sep|>", + "<|audio_start|>\nuhm<|t_0.36|><|code_start|><|447|><|223|><|967|><|301|><|965|><|827|><|393|><|908|><|764|><|1167|><|711|><|1222|><|324|><|1318|><|806|><|498|><|1198|><|1127|><|1178|><|916|><|1234|><|1411|><|1428|><|706|><|427|><|1605|><|1578|><|code_end|>\nnow<|t_0.36|><|code_start|><|1049|><|327|><|385|><|1070|><|732|><|1480|><|450|><|1025|><|1469|><|174|><|1013|><|1710|><|1674|><|775|><|771|><|251|><|778|><|1400|><|897|><|1487|><|366|><|441|><|1000|><|393|><|271|><|1000|><|768|><|code_end|>\nbeing<|t_0.27|><|code_start|><|926|><|406|><|1457|><|437|><|1231|><|672|><|1785|><|521|><|1179|><|1559|><|198|><|1086|><|733|><|122|><|1344|><|845|><|348|><|1389|><|470|><|1773|><|code_end|>\nthe<|t_0.08|><|code_start|><|1775|><|562|><|768|><|1222|><|768|><|963|><|code_end|>\none<|t_0.21|><|code_start|><|1757|><|744|><|144|><|1610|><|655|><|616|><|1317|><|225|><|1325|><|913|><|1342|><|992|><|1018|><|80|><|1777|><|883|><|code_end|>\nto<|t_0.08|><|code_start|><|487|><|1363|><|1682|><|1426|><|655|><|1483|><|code_end|>\nsay<|t_0.27|><|code_start|><|1644|><|1804|><|731|><|273|><|1592|><|731|><|1523|><|1404|><|984|><|1207|><|430|><|1132|><|1123|><|768|><|1116|><|829|><|1082|><|1095|><|440|><|1162|><|code_end|>\ni<|t_0.33|><|code_start|><|1330|><|335|><|1162|><|1155|><|308|><|1162|><|1150|><|1481|><|612|><|674|><|712|><|1745|><|1188|><|1787|><|1135|><|1275|><|1237|><|1143|><|408|><|1063|><|393|><|927|><|1298|><|132|><|1686|><|code_end|>\nknow<|t_0.27|><|code_start|><|983|><|1677|><|586|><|1528|><|1435|><|835|><|1396|><|706|><|987|><|22|><|1172|><|218|><|1404|><|1001|><|521|><|1389|><|775|><|1416|><|877|><|120|><|code_end|>\nthe<|t_0.16|><|code_start|><|916|><|1756|><|513|><|1245|><|1392|><|89|><|1266|><|12|><|1045|><|1075|><|904|><|35|><|code_end|>\nworst<|t_0.32|><|code_start|><|1607|><|174|><|1231|><|144|><|932|><|490|><|771|><|1504|><|798|><|674|><|364|><|80|><|1314|><|1636|><|449|><|1704|><|713|><|1795|><|968|><|1527|><|1302|><|1529|><|1176|><|795|><|code_end|>\nof<|t_0.12|><|code_start|><|1193|><|1205|><|390|><|1128|><|1091|><|883|><|322|><|377|><|1070|><|code_end|>\nyou<|t_0.17|><|code_start|><|1016|><|1332|><|926|><|281|><|927|><|1368|><|1687|><|918|><|67|><|1638|><|1317|><|1265|><|1770|><|code_end|>\nand<|t_0.28|><|code_start|><|1129|><|1633|><|1373|><|1207|><|405|><|879|><|1030|><|1253|><|1071|><|612|><|724|><|1770|><|665|><|1046|><|1351|><|1450|><|1541|><|1384|><|111|><|1477|><|284|><|code_end|>\nive<|t_0.35|><|code_start|><|674|><|266|><|89|><|1333|><|1183|><|1526|><|1143|><|883|><|1135|><|732|><|827|><|1119|><|594|><|1261|><|1024|><|1347|><|92|><|1392|><|825|><|1710|><|1289|><|1598|><|1070|><|1525|><|1442|><|555|><|code_end|>\nbeen<|t_0.17|><|code_start|><|1461|><|194|><|337|><|1128|><|188|><|892|><|848|><|1280|><|959|><|754|><|231|><|649|><|1304|><|code_end|>\ndirectly<|t_0.87|><|code_start|><|1030|><|353|><|570|><|1331|><|470|><|1832|><|1362|><|1809|><|1383|><|101|><|325|><|1557|><|1242|><|1512|><|180|><|227|><|1242|><|643|><|209|><|464|><|171|><|1219|><|174|><|1723|><|734|><|118|><|1269|><|643|><|209|><|187|><|612|><|1231|><|68|><|567|><|1242|><|505|><|319|><|1268|><|794|><|678|><|40|><|1286|><|470|><|1454|><|199|><|965|><|188|><|300|><|1234|><|1125|><|794|><|1289|><|1224|><|257|><|469|><|1121|><|101|><|823|><|1769|><|1683|><|95|><|255|><|59|><|67|><|832|><|code_end|>\naffected<|t_0.44|><|code_start|><|510|><|873|><|787|><|1228|><|771|><|1428|><|501|><|751|><|696|><|258|><|845|><|1818|><|1112|><|498|><|1111|><|985|><|1073|><|832|><|1427|><|168|><|163|><|447|><|119|><|567|><|1626|><|1820|><|903|><|635|><|1060|><|10|><|1632|><|35|><|1635|><|code_end|>\nby<|t_0.19|><|code_start|><|144|><|144|><|460|><|185|><|1112|><|1044|><|498|><|1192|><|656|><|1333|><|1001|><|1186|><|1186|><|454|><|code_end|>\npeople<|t_0.48|><|code_start|><|1260|><|747|><|351|><|526|><|612|><|1151|><|1262|><|1791|><|344|><|1752|><|1547|><|930|><|1302|><|1703|><|1289|><|92|><|1407|><|1482|><|508|><|1431|><|355|><|1696|><|337|><|199|><|1157|><|223|><|464|><|568|><|845|><|411|><|826|><|718|><|1786|><|545|><|712|><|580|><|code_end|>\nlike<|t_0.32|><|code_start|><|630|><|532|><|526|><|607|><|526|><|839|><|1305|><|660|><|459|><|339|><|717|><|1178|><|1148|><|687|><|149|><|1390|><|229|><|199|><|513|><|712|><|1451|><|731|><|582|><|1551|><|code_end|>\nyou<|t_0.21|><|code_start|><|1389|><|954|><|1781|><|1047|><|1236|><|930|><|809|><|1621|><|1268|><|384|><|242|><|587|><|869|><|816|><|1680|><|405|><|code_end|>\nbut<|t_0.59|><|code_start|><|1089|><|1590|><|908|><|80|><|594|><|1046|><|1706|><|1025|><|1150|><|405|><|548|><|893|><|1285|><|464|><|301|><|939|><|643|><|23|><|285|><|161|><|209|><|453|><|72|><|167|><|417|><|244|><|151|><|643|><|391|><|199|><|651|><|1023|><|337|><|1010|><|54|><|331|><|1167|><|756|><|388|><|934|><|1060|><|18|><|1624|><|1060|><|code_end|>\nits<|t_0.16|><|code_start|><|1102|><|183|><|1199|><|1258|><|1285|><|35|><|659|><|180|><|426|><|1587|><|1733|><|942|><|code_end|>\na<|t_0.04|><|code_start|><|791|><|1012|><|818|><|code_end|>\nclean<|t_0.61|><|code_start|><|1819|><|976|><|163|><|447|><|316|><|223|><|763|><|457|><|1208|><|1808|><|1697|><|1162|><|1660|><|1833|><|1054|><|1734|><|1121|><|1309|><|1643|><|924|><|1677|><|1548|><|869|><|1268|><|223|><|674|><|111|><|792|><|1670|><|912|><|174|><|1554|><|90|><|80|><|1563|><|1621|><|1698|><|1544|><|992|><|988|><|175|><|793|><|1661|><|1026|><|80|><|1761|><|code_end|>\nslate<|t_0.40|><|code_start|><|1802|><|322|><|1689|><|1577|><|1302|><|1552|><|1529|><|1722|><|1580|><|582|><|1642|><|1529|><|1020|><|582|><|1538|><|970|><|437|><|1141|><|1477|><|988|><|335|><|1611|><|922|><|1558|><|1120|><|1189|><|423|><|188|><|171|><|562|><|code_end|>\nwith<|t_0.15|><|code_start|><|963|><|1347|><|1274|><|747|><|1230|><|712|><|1408|><|1290|><|957|><|1279|><|258|><|code_end|>\nme<|t_0.09|><|code_start|><|638|><|1058|><|174|><|1452|><|1038|><|894|><|1571|><|code_end|>\nbuddy<|t_0.32|><|code_start|><|1003|><|130|><|1341|><|938|><|40|><|804|><|167|><|89|><|1456|><|1189|><|1155|><|1171|><|1434|><|1077|><|1029|><|1455|><|1622|><|1037|><|163|><|1411|><|1165|><|1463|><|837|><|1202|><|code_end|>\nyou<|t_0.36|><|code_start|><|1354|><|1165|><|615|><|1588|><|1192|><|1445|><|1033|><|982|><|401|><|1079|><|684|><|1570|><|266|><|31|><|420|><|163|><|893|><|845|><|905|><|1827|><|1804|><|153|><|627|><|243|><|1179|><|298|><|1147|><|code_end|>\nknow<|t_0.19|><|code_start|><|163|><|1542|><|1366|><|698|><|1753|><|206|><|916|><|1499|><|245|><|665|><|600|><|894|><|587|><|1741|><|code_end|>\nlike<|t_0.24|><|code_start|><|1106|><|1280|><|1062|><|1304|><|945|><|809|><|598|><|104|><|1001|><|822|><|965|><|189|><|693|><|1810|><|1293|><|199|><|1277|><|44|><|code_end|>\nthats<|t_0.24|><|code_start|><|121|><|1789|><|1443|><|370|><|1154|><|393|><|1178|><|1200|><|1264|><|424|><|1391|><|381|><|978|><|1346|><|704|><|1808|><|1579|><|1492|><|code_end|>\nreally<|t_0.56|><|code_start|><|1177|><|1761|><|1723|><|1360|><|1413|><|830|><|551|><|193|><|59|><|332|><|598|><|734|><|1684|><|1802|><|60|><|1590|><|353|><|89|><|1636|><|1396|><|893|><|143|><|455|><|1501|><|435|><|1082|><|621|><|1593|><|677|><|474|><|971|><|1513|><|913|><|828|><|1381|><|1148|><|1798|><|1186|><|1443|><|38|><|335|><|883|><|code_end|>\npowerful<|t_0.63|><|code_start|><|1773|><|458|><|1070|><|964|><|826|><|1220|><|1012|><|1738|><|1125|><|669|><|490|><|1169|><|922|><|958|><|1204|><|489|><|1001|><|886|><|1045|><|675|><|1471|><|1652|><|732|><|698|><|1124|><|480|><|897|><|1484|><|1028|><|35|><|594|><|1465|><|505|><|1669|><|436|><|851|><|1288|><|31|><|1501|><|1187|><|394|><|909|><|1541|><|1793|><|1720|><|922|><|840|><|code_end|>\nin<|t_0.16|><|code_start|><|1317|><|523|><|630|><|1343|><|1187|><|719|><|907|><|636|><|111|><|1524|><|188|><|1382|><|code_end|>\nand<|t_0.13|><|code_start|><|1074|><|922|><|1280|><|1496|><|1050|><|832|><|133|><|1435|><|1049|><|1774|><|code_end|>\nof<|t_0.12|><|code_start|><|960|><|1052|><|1192|><|1303|><|1112|><|970|><|417|><|60|><|1155|><|code_end|>\nitself<|t_0.47|><|code_start|><|1682|><|1209|><|1410|><|513|><|1222|><|861|><|167|><|406|><|1551|><|582|><|634|><|1529|><|786|><|1363|><|1578|><|1739|><|873|><|424|><|1041|><|1328|><|955|><|1110|><|1490|><|1424|><|1199|><|988|><|1162|><|1133|><|1193|><|978|><|470|><|832|><|963|><|1251|><|733|><|code_end|>", +}; +// clang-format on + +const std::map Ones = { + {0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, + {4, "four"}, {5, "five"}, {6, "six"}, {7, "seven"}, + {8, "eight"}, {9, "nine"}, {10, "ten"}, {11, "eleven"}, + {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"}, {15, "fifteen"}, + {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"}}; + +const std::map Tens = { + {2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"}, + {6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}}; + +std::string processTTSPromptText(const std::string &Text); +std::optional +getSpeakerProfileFromFile(const std::string &FilePath, WasiNNEnvironment &Env); + +std::vector embdToAudio(const float *Embd, const int NCodes, + const int NEmbd, const int NThread); +std::vector audioDataToWav(const std::vector &Data, + int SampleRate); +} // namespace +std::vector processTTSPrompt(WasiNNEnvironment &Env, + Graph &GraphRef, + std::string &Prompt) noexcept; +ErrNo codesToSpeech(WasiNNEnvironment &Env, Graph &GraphRef, + Context &CxtRef) noexcept; +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/utils.cpp b/plugins/wasi_nn/GGML/utils.cpp new file mode 100644 index 00000000..72e9650e --- /dev/null +++ b/plugins/wasi_nn/GGML/utils.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC + +#include "utils.h" + +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +#include +#endif + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +// Helper to init a llama batch. +struct llama_batch allocBatch(int64_t NTokens, int64_t Embd, + int32_t NSeqMax) noexcept { + struct llama_batch Batch = llama_batch_init( + /* n_tokens_alloc */ static_cast(NTokens), + /* embd */ static_cast(Embd), + /* n_seq_max */ static_cast(NSeqMax)); + std::fill(Batch.n_seq_id, Batch.n_seq_id + NTokens, + static_cast(NSeqMax)); + for (int64_t I = 0; I < NTokens; I++) { + std::fill(Batch.seq_id[I], Batch.seq_id[I] + NSeqMax, 0); + } + std::fill(Batch.logits, Batch.logits + NTokens, false); + return Batch; +} + +// Get base64 image position if found in prompt. +std::optional> +findBase64ImagePayload(std::string_view Prompt, bool IsDebugLog) noexcept { + // Find `` + auto EndTagPos = Prompt.find(Base64ImageTagSuffix, PayloadPos); + if (EndTagPos == std::string::npos) { + LOG_DEBUG(IsDebugLog, "base64: image tag unclosed."sv) + return std::nullopt; + } + return std::make_tuple(BeginTagPos, PayloadPos, EndTagPos); +} + +// Extract base64 image payload and image type. Replace it with placeholder. +std::optional, std::string>> +extractBase64ImagePayload(std::string &Prompt, + std::tuple ImagePos, + const std::string_view Placeholder) noexcept { + // Locate the payload and image type. + size_t BeginTagPos = std::get<0>(ImagePos); + size_t TypePos = std::get<0>(ImagePos) + Base64ImageTagPrefix.size(); + size_t PayloadPos = std::get<1>(ImagePos); + size_t BeginBytePos = std::get<1>(ImagePos) + Base64ImageBytesPrefix.size(); + size_t EndTagPos = std::get<2>(ImagePos); + std::string_view Payload = + std::string_view(Prompt).substr(BeginBytePos, EndTagPos - BeginBytePos); + std::string ImageType = Prompt.substr(TypePos, PayloadPos - TypePos); + + // Decode the base64 payload. + auto RequiredBytes = base64::required_encode_size(Payload.size()); + std::vector ImageBytes(RequiredBytes); + try { + base64::decode(Payload.begin(), Payload.end(), ImageBytes.begin()); + } catch (const base64_error &E) { + RET_ERROR(std::make_pair(std::vector(), ""), + "base64: Error when calling base64::decode: {}"sv, E.what()) + } + + // Replace the base64 image with the placeholder. + Prompt.replace(BeginTagPos, + EndTagPos - BeginTagPos + Base64ImageTagSuffix.size(), + Placeholder); + return std::make_pair(ImageBytes, ImageType); +} + +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/utils.h b/plugins/wasi_nn/GGML/utils.h new file mode 100644 index 00000000..943b22ba --- /dev/null +++ b/plugins/wasi_nn/GGML/utils.h @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2019-2024 Second State INC +#pragma once +#include "GGML/core/ggml_core.h" + +namespace WasmEdge::Host::WASINN::GGML { +#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML +const std::string_view Base64ImageTagPrefix = ""sv; +const std::string_view VisionPromptImagePlaceholder = ""sv; + +struct llama_batch allocBatch(int64_t NTokens, int64_t Embd = 0, + int32_t NSeqMax = 1) noexcept; +std::optional> +findBase64ImagePayload(std::string_view Prompt, + bool IsDebugLog = false) noexcept; +std::optional, std::string>> +extractBase64ImagePayload(std::string &Prompt, + std::tuple ImagePos, + const std::string_view Placeholder) noexcept; +#endif +} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/wasinn_ggml.cpp b/plugins/wasi_nn/wasinn_ggml.cpp deleted file mode 100644 index 52575667..00000000 --- a/plugins/wasi_nn/wasinn_ggml.cpp +++ /dev/null @@ -1,2289 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: 2019-2024 Second State INC - -#include "wasinn_ggml.h" -#include "common/types.h" -#include "host/wasi/vfs_io.h" -#include "wasinnenv.h" -#include - -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML -#include "simdjson.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#endif - -namespace WasmEdge::Host::WASINN::GGML { -#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML - -namespace { - -// Macro for logging debug message. -#define LOG_DEBUG(Debug, ...) \ - if (Debug) { \ - spdlog::info("[WASI-NN][Debug] GGML backend: "sv __VA_ARGS__); \ - } - -// Macro for logging info message. -#define LOG_INFO(Info, ...) \ - if (Info) { \ - spdlog::info("[WASI-NN] GGML backend: "sv __VA_ARGS__); \ - } - -// Macro for logging warning message. -#define LOG_WARN(...) spdlog::warn("[WASI-NN] GGML backend: "sv __VA_ARGS__); - -// Macro for logging error message. -#define LOG_ERROR(...) spdlog::error("[WASI-NN] GGML backend: "sv __VA_ARGS__); - -// Macro for logging error message and return. -#define RET_ERROR(Error, ...) \ - spdlog::error("[WASI-NN] GGML backend: "sv __VA_ARGS__); \ - return Error; - -// JSON parsing helper template -template -T getJsonValue(const simdjson::dom::element &Doc, std::string_view Key) { - if (Doc.at_key(Key).error() == simdjson::SUCCESS) { - T Value{}; - auto Err = Doc[Key].get().get(Value); - if (Err) { - std::string Msg = fmt::format("Unable to retrieve the {} option.", Key); - spdlog::error("[WASI-NN] GGML backend: {}", Msg); - throw ErrNo(ErrNo::InvalidArgument); - } - return Value; - } - throw ErrNo(ErrNo::NotFound); -} - -template -void parseJsonAuto(const simdjson::dom::element &Doc, std::string_view Key, - T &Var) { - try { - auto Result = getJsonValue(Doc, Key); - Var = Result; - } catch (ErrNo E) { - if (E != ErrNo::NotFound) { - throw E; - } - } -} - -template -void parseJsonWithCastAuto(const simdjson::dom::element &Doc, - std::string_view Key, ToType &Var) { - try { - auto Result = getJsonValue(Doc, Key); - Var = static_cast(Result); - } catch (ErrNo E) { - if (E != ErrNo::NotFound) { - throw E; - } - } -} - -template -void parseJsonWithProcessorAuto(const simdjson::dom::element &Doc, - std::string_view Key, Processor Proc) { - try { - auto Result = getJsonValue(Doc, Key); - if (!Proc(Result)) { - throw ErrNo{ErrNo::InvalidArgument}; - } - } catch (ErrNo E) { - if (E != ErrNo::NotFound) { - throw E; - } - } -} - -// Llama logging callback. -void llamaLogCallback(ggml_log_level LogLevel, const char *LogText, - void *UserData) { - Graph &GraphRef = *reinterpret_cast(UserData); - if (!GraphRef.EnableLog) { - return; - } - std::string Text(LogText); - // Remove the trailing newlines. - Text = Text.erase(Text.find_last_not_of("\n") + 1); - // Skip for "." - if (Text == ".") { - return; - } - if (LogLevel == GGML_LOG_LEVEL_ERROR) { - spdlog::error("[WASI-NN] llama.cpp: {}"sv, Text); - } else if (LogLevel == GGML_LOG_LEVEL_WARN) { - spdlog::warn("[WASI-NN] llama.cpp: {}"sv, Text); - } else if (LogLevel == GGML_LOG_LEVEL_INFO) { - spdlog::info("[WASI-NN] llama.cpp: {}"sv, Text); - } else if (LogLevel == GGML_LOG_LEVEL_DEBUG) { - spdlog::debug("[WASI-NN] llama.cpp: {}"sv, Text); - } -} - -// >>>>>>>> Metadata related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> - -// Parse metadata from json. -ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, - const std::string &Metadata, bool *IsModelUpdated = nullptr, - bool *IsContextUpdated = nullptr, - bool *IsSamplerUpdated = nullptr) noexcept { - // Parse metadata from the json. - simdjson::dom::parser Parser; - simdjson::dom::element Doc; - auto ParseError = Parser.parse(Metadata).get(Doc); - if (ParseError) { - RET_ERROR(ErrNo::InvalidEncoding, "parse metadata error."sv) - } - - // Get the current llama parameters. - int64_t PrevNGPULayers = GraphRef.Params.n_gpu_layers; - bool PrevEmbedding = GraphRef.Params.embedding; - // Get the current sampler parameters. - double PrevTemp = GraphRef.Params.sampling.temp; - double PrevTopP = GraphRef.Params.sampling.top_p; - double PrevRepeatPenalty = GraphRef.Params.sampling.penalty_repeat; - double PrevPresencePenalty = GraphRef.Params.sampling.penalty_present; - double PrevFrequencyPenalty = GraphRef.Params.sampling.penalty_freq; - std::string PrevGrammar = GraphRef.Params.sampling.grammar; - uint64_t PrevSeed = GraphRef.Params.sampling.seed; - - try { - parseJsonAuto(Doc, "enable-log", GraphRef.EnableLog); - parseJsonAuto(Doc, "enable-debug-log", GraphRef.EnableDebugLog); - - parseJsonWithCastAuto(Doc, "main-gpu", GraphRef.Params.main_gpu); - parseJsonWithCastAuto(Doc, "n-gpu-layers", - GraphRef.Params.n_gpu_layers); - - parseJsonWithProcessorAuto(Doc, "cpu-moe", - [&GraphRef](const bool &CpuMoe) -> bool { - if (CpuMoe) { - GraphRef.TensorBuftOverrides.push_back( - "\\.ffn_(up|down|gate)_exps"); - } - return true; - }); - - parseJsonWithProcessorAuto( - Doc, "n-cpu-moe", [&GraphRef](const int64_t &NCpuMoe) -> bool { - if (NCpuMoe < 0) { - spdlog::error("[WASI-NN] GGML backend: Invalid n-cpu-moe value."); - return false; - } - for (int I = 0; I < NCpuMoe; I++) { - GraphRef.TensorBuftOverrides.push_back( - string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", I)); - } - return true; - }); - parseJsonWithProcessorAuto( - Doc, "tensor-split", [&GraphRef](const std::string_view &TSV) -> bool { - // The TensorSplit is a comma-separated list of non-negative values. - // E.g., "3,2" presents 60% of the data to GPU 0 and 40% to GPU 1. - std::string TS(TSV); - std::replace(TS.begin(), TS.end(), ',', ' '); - std::stringstream SS(TS); - std::memset(GraphRef.Params.tensor_split, 0, - sizeof(GraphRef.Params.tensor_split)); - uint32_t TensorSplitSize = 0; - while (SS.good()) { - float TmpTensor; - SS >> TmpTensor; - GraphRef.Params.tensor_split[TensorSplitSize++] = TmpTensor; - } - size_t NDevices = llama_max_devices(); - if (TensorSplitSize > NDevices) { - spdlog::error( - "[WASI-NN] GGML backend: Number of Tensor-Split is larger than " - "MaxDevices, please reduce the size of tensor-split."); - return false; - } - for (size_t Idx = TensorSplitSize; Idx < NDevices; Idx++) { - GraphRef.Params.tensor_split[TensorSplitSize++] = 0.0f; - } - return true; - }); - parseJsonAuto(Doc, "embedding", GraphRef.Params.embedding); - parseJsonWithProcessorAuto( - Doc, "split-mode", - [&GraphRef](const std::string_view &SplitMode) -> bool { - if (SplitMode == "none"sv) { - GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_NONE; - } else if (SplitMode == "layer"sv) { - GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_LAYER; - } else if (SplitMode == "row"sv) { - GraphRef.Params.split_mode = LLAMA_SPLIT_MODE_ROW; - } else { - spdlog::error("[WASI-NN] GGML backend: Unknown split-mode: " - "{}. Valid: none, layer, row.", - SplitMode); - return false; - } - return true; - }); - parseJsonWithCastAuto(Doc, "mmproj", - GraphRef.Params.mmproj.path); - // The TTS parameters using macros - parseJsonAuto(Doc, "tts", GraphRef.TextToSpeech); - parseJsonWithCastAuto(Doc, "model-vocoder", - GraphRef.Params.vocoder.model.path); - parseJsonWithCastAuto(Doc, "tts-output-file", - GraphRef.TTSOutputFilePath); - - parseJsonWithCastAuto(Doc, "tts-speaker-file", - GraphRef.TTSSpeakerFilePath); - - // The context parameters - parseJsonWithCastAuto(Doc, "ctx-size", GraphRef.Params.n_ctx); - parseJsonWithCastAuto(Doc, "batch-size", GraphRef.Params.n_batch); - parseJsonWithCastAuto(Doc, "ubatch-size", - GraphRef.Params.n_ubatch); - parseJsonWithCastAuto(Doc, "n-keep", GraphRef.Params.n_keep); - parseJsonWithCastAuto(Doc, "n-chunks", GraphRef.Params.n_chunks); - parseJsonWithCastAuto(Doc, "n-parallel", - GraphRef.Params.n_parallel); - parseJsonWithCastAuto(Doc, "n-sequences", - GraphRef.Params.n_sequences); - parseJsonWithCastAuto(Doc, "grp-attn-n", - GraphRef.Params.grp_attn_n); - parseJsonWithCastAuto(Doc, "grp-attn-w", - GraphRef.Params.grp_attn_w); - parseJsonWithCastAuto(Doc, "n-print", GraphRef.Params.n_print); - parseJsonWithCastAuto(Doc, "rope-freq-base", - GraphRef.Params.rope_freq_base); - parseJsonWithCastAuto(Doc, "rope-freq-scale", - GraphRef.Params.rope_freq_scale); - parseJsonWithCastAuto(Doc, "yarn-ext-factor", - GraphRef.Params.yarn_ext_factor); - parseJsonWithCastAuto(Doc, "yarn-attn-factor", - GraphRef.Params.yarn_attn_factor); - parseJsonWithCastAuto(Doc, "yarn-beta-fast", - GraphRef.Params.yarn_beta_fast); - parseJsonWithCastAuto(Doc, "yarn-beta-slow", - GraphRef.Params.yarn_beta_slow); - parseJsonWithCastAuto(Doc, "yarn-orig-ctx", - GraphRef.Params.yarn_orig_ctx); - parseJsonAuto(Doc, "mask-valid", - GraphRef.Params.cpuparams.mask_valid); - parseJsonWithCastAuto(Doc, "priority", - GraphRef.Params.cpuparams.priority); - parseJsonAuto(Doc, "strict-cpu", - GraphRef.Params.cpuparams.strict_cpu); - parseJsonWithCastAuto(Doc, "poll", GraphRef.Params.cpuparams.poll); - parseJsonAuto(Doc, "mask-valid-batch", - GraphRef.Params.cpuparams_batch.mask_valid); - parseJsonWithProcessorAuto( - Doc, "priority-batch", [&GraphRef](const int64_t &Priority) -> bool { - GraphRef.Params.cpuparams_batch.priority = - static_cast(Priority); - return true; - }); - parseJsonAuto(Doc, "strict-cpu-batch", - GraphRef.Params.cpuparams_batch.strict_cpu); - parseJsonWithCastAuto(Doc, "poll-batch", - GraphRef.Params.cpuparams_batch.poll); - - parseJsonWithCastAuto(Doc, "numa", GraphRef.Params.numa); - - parseJsonWithCastAuto(Doc, "rope-scaling-type", - GraphRef.Params.rope_scaling_type); - parseJsonWithCastAuto(Doc, "pooling-type", - GraphRef.Params.pooling_type); - parseJsonWithCastAuto(Doc, "attention-type", - GraphRef.Params.attention_type); - parseJsonWithProcessorAuto( - Doc, "threads", [&GraphRef](const int64_t &NThreads) -> bool { - GraphRef.Params.cpuparams.n_threads = static_cast(NThreads); - return true; - }); - parseJsonWithCastAuto(Doc, "threads-batch", - GraphRef.Params.cpuparams_batch.n_threads); - parseJsonWithCastAuto(Doc, "n-prev", - GraphRef.Params.sampling.n_prev); - parseJsonWithCastAuto(Doc, "n-probs", - GraphRef.Params.sampling.n_probs); - parseJsonWithCastAuto(Doc, "min-keep", - GraphRef.Params.sampling.min_keep); - parseJsonWithCastAuto(Doc, "top-k", - GraphRef.Params.sampling.top_k); - parseJsonWithCastAuto(Doc, "min-p", GraphRef.Params.sampling.min_p); - parseJsonWithCastAuto(Doc, "xtc-probability", - GraphRef.Params.sampling.xtc_probability); - parseJsonWithCastAuto(Doc, "xtc-threshold", - GraphRef.Params.sampling.xtc_threshold); - parseJsonWithCastAuto(Doc, "typ-p", GraphRef.Params.sampling.typ_p); - parseJsonWithCastAuto(Doc, "dynatemp-range", - GraphRef.Params.sampling.dynatemp_range); - parseJsonWithCastAuto(Doc, "dynatemp-exponent", - GraphRef.Params.sampling.dynatemp_exponent); - parseJsonWithCastAuto(Doc, "last-n-penalty", - GraphRef.Params.sampling.penalty_last_n); - parseJsonWithProcessorAuto( - Doc, "temp", [&GraphRef](const double &Temp) -> bool { - GraphRef.Params.sampling.temp = - static_cast(std::max(0.0, Temp)); - return true; - }); - - parseJsonWithProcessorAuto( - Doc, "top-p", [&GraphRef](const double &TopP) -> bool { - GraphRef.Params.sampling.top_p = - static_cast(std::max(0.0, TopP)); - return true; - }); - parseJsonWithProcessorAuto( - Doc, "repeat-penalty", - [&GraphRef](const double &RepeatPenalty) -> bool { - GraphRef.Params.sampling.penalty_repeat = - static_cast(std::max(0.0, RepeatPenalty)); - return true; - }); - parseJsonWithProcessorAuto( - Doc, "presence-penalty", - [&GraphRef](const double &PresencePenalty) -> bool { - GraphRef.Params.sampling.penalty_present = - static_cast(std::max(0.0, PresencePenalty)); - return true; - }); - parseJsonWithProcessorAuto( - Doc, "frequency-penalty", - [&GraphRef](const double &FrequencyPenalty) -> bool { - GraphRef.Params.sampling.penalty_freq = - static_cast(std::max(0.0, FrequencyPenalty)); - return true; - }); - parseJsonWithCastAuto(Doc, "dry-multipier", - GraphRef.Params.sampling.dry_multiplier); - parseJsonWithCastAuto(Doc, "dry-base", - GraphRef.Params.sampling.dry_base); - parseJsonWithCastAuto(Doc, "dry-allowed-length", - GraphRef.Params.sampling.dry_allowed_length); - parseJsonWithCastAuto(Doc, "dry-last-n-penalty", - GraphRef.Params.sampling.penalty_last_n); - parseJsonWithCastAuto(Doc, "mirostat", - GraphRef.Params.sampling.mirostat); - parseJsonWithCastAuto(Doc, "mirostat-eta", - GraphRef.Params.sampling.mirostat_eta); - parseJsonAuto(Doc, "ignore-eos", GraphRef.Params.sampling.ignore_eos); - parseJsonAuto(Doc, "no-perf-sampling", - GraphRef.Params.sampling.no_perf); - parseJsonAuto(Doc, "timing-per-token", - GraphRef.Params.sampling.timing_per_token); - - parseJsonWithCastAuto(Doc, "grammar", - GraphRef.Params.sampling.grammar); - parseJsonWithProcessorAuto( - Doc, "json-schema", - [&GraphRef](const std::string_view &JsonSchema) -> bool { - GraphRef.Params.sampling.grammar = - json_schema_to_grammar(nlohmann::ordered_json::parse(JsonSchema)); - return true; - }); - parseJsonWithCastAuto(Doc, "seed", GraphRef.Params.sampling.seed); - - // The speculative parameters. - parseJsonWithCastAuto(Doc, "n-ctx-speculative", - GraphRef.Params.speculative.n_ctx); - parseJsonWithCastAuto(Doc, "n-max-speculative", - GraphRef.Params.speculative.n_max); - parseJsonWithCastAuto(Doc, "n-min-speculative", - GraphRef.Params.speculative.n_min); - parseJsonWithCastAuto(Doc, "n-gpu-layers-speculative", - GraphRef.Params.speculative.n_gpu_layers); - parseJsonWithCastAuto(Doc, "p-split-speculative", - GraphRef.Params.speculative.p_split); - parseJsonWithCastAuto(Doc, "p-min-speculative", - GraphRef.Params.speculative.p_min); - // The vocoder parameters. - parseJsonWithCastAuto( - Doc, "hf-repo-vocoder", GraphRef.Params.vocoder.model.hf_repo); - parseJsonWithCastAuto( - Doc, "hf-file-vocoder", GraphRef.Params.vocoder.model.hf_file); - parseJsonWithCastAuto(Doc, "model-url-vocoder", - GraphRef.Params.vocoder.model.url); - - // The config parameters. - parseJsonAuto(Doc, "stream-stdout", ConfRef.StreamStdout); - parseJsonAuto(Doc, "n-predict", ConfRef.NPredict); - parseJsonWithCastAuto(Doc, "reverse-prompt", - ConfRef.ReversePrompt); - parseJsonWithCastAuto(Doc, "image", ConfRef.ImagePath); - parseJsonAuto(Doc, "always-regenerate-image-embd", - ConfRef.AlwaysRegenerateImageEmbd); - parseJsonWithCastAuto(Doc, "model-alias", - GraphRef.Params.model_alias); - parseJsonWithCastAuto(Doc, "model-url", - GraphRef.Params.model.url); - parseJsonWithCastAuto(Doc, "hf-token", - GraphRef.Params.hf_token); - parseJsonWithCastAuto(Doc, "hf-repo", - GraphRef.Params.model.hf_repo); - parseJsonWithCastAuto(Doc, "hf-file", - GraphRef.Params.model.hf_file); - parseJsonWithCastAuto(Doc, "prompt-file", - GraphRef.Params.prompt_file); - parseJsonWithCastAuto(Doc, "path-prompt-cache", - GraphRef.Params.path_prompt_cache); - parseJsonWithCastAuto(Doc, "input-prefix", - GraphRef.Params.input_prefix); - parseJsonWithCastAuto(Doc, "input-suffix", - GraphRef.Params.input_suffix); - parseJsonWithCastAuto( - Doc, "lookup-cache-static", GraphRef.Params.lookup_cache_static); - parseJsonWithCastAuto( - Doc, "lookup-cache-dynamic", GraphRef.Params.lookup_cache_dynamic); - parseJsonWithCastAuto(Doc, "logits-file", - GraphRef.Params.logits_file); - parseJsonAuto(Doc, "lora-init-without-apply", - GraphRef.Params.lora_init_without_apply); - parseJsonWithCastAuto(Doc, "verbosity", GraphRef.Params.verbosity); - parseJsonWithCastAuto(Doc, "control-vector-layer-start", - GraphRef.Params.control_vector_layer_start); - parseJsonWithCastAuto(Doc, "control-vector-layer-end", - GraphRef.Params.control_vector_layer_end); - parseJsonWithCastAuto(Doc, "ppl-stride", - GraphRef.Params.ppl_stride); - parseJsonWithCastAuto(Doc, "ppl-output-type", - GraphRef.Params.ppl_output_type); - parseJsonAuto(Doc, "hellaswag", GraphRef.Params.hellaswag); - parseJsonWithCastAuto(Doc, "hellaswag-tasks", - GraphRef.Params.hellaswag_tasks); - parseJsonAuto(Doc, "winogrande", GraphRef.Params.winogrande); - parseJsonWithCastAuto(Doc, "winogrande-tasks", - GraphRef.Params.winogrande_tasks); - parseJsonAuto(Doc, "multiple-choice", - GraphRef.Params.multiple_choice); - parseJsonWithCastAuto( - Doc, "multiple-choice-tasks", GraphRef.Params.multiple_choice_tasks); - parseJsonAuto(Doc, "kl-divergence", GraphRef.Params.kl_divergence); - parseJsonAuto(Doc, "usage", GraphRef.Params.usage); - parseJsonAuto(Doc, "use-color", GraphRef.Params.use_color); - parseJsonAuto(Doc, "special", GraphRef.Params.special); - parseJsonAuto(Doc, "interactive", GraphRef.Params.interactive); - parseJsonAuto(Doc, "interactive-first", - GraphRef.Params.interactive_first); - parseJsonAuto(Doc, "prompt-cache-all", - GraphRef.Params.prompt_cache_all); - parseJsonAuto(Doc, "prompt-cache-ro", - GraphRef.Params.prompt_cache_ro); - parseJsonAuto(Doc, "escape", GraphRef.Params.escape); - parseJsonAuto(Doc, "multiline-input", - GraphRef.Params.multiline_input); - parseJsonAuto(Doc, "simple-io", GraphRef.Params.simple_io); - parseJsonAuto(Doc, "cont-batching", GraphRef.Params.cont_batching); - parseJsonWithProcessorAuto( - Doc, "flash-attn", - [&GraphRef](const std::string_view &FlashAttn) -> bool { - if (FlashAttn == "on"sv || FlashAttn == "enabled"sv) { - GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; - } else if (FlashAttn == "off"sv || FlashAttn == "disabled"sv) { - GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; - } else if (FlashAttn == "auto"sv) { - GraphRef.Params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; - } else { - spdlog::error( - "[WASI-NN] GGML backend: The flash-attn option must be " - "one of: on, off, auto."); - return false; - } - return true; - }); - parseJsonAuto(Doc, "no-perf", GraphRef.Params.no_perf); - parseJsonAuto(Doc, "ctx-shift", GraphRef.Params.ctx_shift); - parseJsonAuto(Doc, "input-prefix-bos", - GraphRef.Params.input_prefix_bos); - parseJsonAuto(Doc, "use-mlock", GraphRef.Params.use_mlock); - parseJsonAuto(Doc, "use-mmap", GraphRef.Params.use_mmap); - parseJsonAuto(Doc, "verbose-prompt", GraphRef.Params.verbose_prompt); - parseJsonAuto(Doc, "display-prompt", GraphRef.Params.display_prompt); - parseJsonAuto(Doc, "no-kv-offload", GraphRef.Params.no_kv_offload); - parseJsonAuto(Doc, "warmup", GraphRef.Params.warmup); - parseJsonAuto(Doc, "check-tensors", GraphRef.Params.check_tensors); - parseJsonWithCastAuto(Doc, "cache-type-k", - GraphRef.Params.cache_type_k); - parseJsonWithCastAuto(Doc, "cache-type-v", - GraphRef.Params.cache_type_v); - - parseJsonWithCastAuto(Doc, "embd-normalize", - GraphRef.Params.embd_normalize); - parseJsonWithCastAuto(Doc, "embd-out", - GraphRef.Params.embd_out); - - parseJsonWithCastAuto(Doc, "embd-sep", - GraphRef.Params.embd_sep); - parseJsonWithProcessorAuto( - Doc, "reranking", [&GraphRef](const bool &) -> bool { - GraphRef.Params.embedding = true; - GraphRef.Params.pooling_type = LLAMA_POOLING_TYPE_RANK; - return true; - }); - parseJsonWithCastAuto(Doc, "port", GraphRef.Params.port); - parseJsonWithCastAuto(Doc, "timeout-read", - GraphRef.Params.timeout_read); - parseJsonWithCastAuto(Doc, "timeout-write", - GraphRef.Params.timeout_write); - parseJsonWithCastAuto(Doc, "n-threads-http", - GraphRef.Params.n_threads_http); - parseJsonWithCastAuto(Doc, "n-cache-reuse", - GraphRef.Params.n_cache_reuse); - parseJsonWithCastAuto(Doc, "hostname", - GraphRef.Params.hostname); - parseJsonWithCastAuto(Doc, "public-path", - GraphRef.Params.public_path); - parseJsonWithCastAuto(Doc, "chat-template", - GraphRef.Params.chat_template); - parseJsonAuto(Doc, "enable-chat-template", - GraphRef.Params.enable_chat_template); - parseJsonWithCastAuto(Doc, "ssl-file-key", - GraphRef.Params.ssl_file_key); - parseJsonWithCastAuto(Doc, "ssl-file-cert", - GraphRef.Params.ssl_file_cert); - parseJsonAuto(Doc, "webui", GraphRef.Params.webui); - parseJsonAuto(Doc, "endpoint-slots", GraphRef.Params.endpoint_slots); - parseJsonAuto(Doc, "endpoint-props", GraphRef.Params.endpoint_props); - parseJsonAuto(Doc, "endpoint-metrics", - GraphRef.Params.endpoint_metrics); - parseJsonAuto(Doc, "log-json", GraphRef.Params.log_json); - - // Slot parameters - parseJsonWithCastAuto(Doc, "slot-save-path", - GraphRef.Params.slot_save_path); - - parseJsonWithCastAuto(Doc, "slot-prompt-similarity", - GraphRef.Params.slot_prompt_similarity); - parseJsonAuto(Doc, "is-pp-shared", GraphRef.Params.is_pp_shared); - parseJsonWithProcessorAuto( - Doc, "n-pp", [&GraphRef](const int64_t &NPP) -> bool { - GraphRef.Params.n_pp.emplace_back(NPP); - return true; - }); - parseJsonWithProcessorAuto( - Doc, "n-tg", [&GraphRef](const int64_t &NTG) -> bool { - GraphRef.Params.n_tg.emplace_back(NTG); - return true; - }); - parseJsonWithProcessorAuto( - Doc, "n-pl", [&GraphRef](const int64_t &NPL) -> bool { - GraphRef.Params.n_pl.emplace_back(NPL); - return true; - }); - parseJsonWithProcessorAuto( - Doc, "context-files", - [&GraphRef](const std::string_view &ContextFile) -> bool { - GraphRef.Params.context_files.emplace_back(ContextFile); - return true; - }); - parseJsonWithCastAuto(Doc, "chunk-size", - GraphRef.Params.chunk_size); - parseJsonWithCastAuto(Doc, "chunk-separator", - GraphRef.Params.chunk_separator); - parseJsonWithCastAuto(Doc, "n-junk", GraphRef.Params.n_junk); - parseJsonWithCastAuto(Doc, "i-pos", GraphRef.Params.i_pos); - parseJsonWithCastAuto(Doc, "out-file", - GraphRef.Params.out_file); - parseJsonWithCastAuto(Doc, "n-out-freq", - GraphRef.Params.n_out_freq); - parseJsonWithCastAuto(Doc, "n-save-freq", - GraphRef.Params.n_save_freq); - parseJsonWithCastAuto(Doc, "i-chunk", GraphRef.Params.i_chunk); - parseJsonAuto(Doc, "process-output", GraphRef.Params.process_output); - parseJsonAuto(Doc, "compute-ppl", GraphRef.Params.compute_ppl); - parseJsonWithCastAuto(Doc, "n-pca-batch", - GraphRef.Params.n_pca_batch); - parseJsonWithCastAuto(Doc, "n-pca-iterations", - GraphRef.Params.n_pca_iterations); - parseJsonWithProcessorAuto( - Doc, "cvector-dimre-method", - [&GraphRef](const std::string_view &Method) -> bool { - if (Method == "DIMRE_METHOD_PCA") { - GraphRef.Params.cvector_dimre_method = - dimre_method::DIMRE_METHOD_PCA; - return true; - } - if (Method == "DIMRE_METHOD_MEAN") { - GraphRef.Params.cvector_dimre_method = - dimre_method::DIMRE_METHOD_MEAN; - return true; - } - return false; - }); - parseJsonWithCastAuto( - Doc, "cvector-positive-file", GraphRef.Params.cvector_positive_file); - parseJsonWithCastAuto( - Doc, "cvector-negative-file", GraphRef.Params.cvector_negative_file); - parseJsonAuto(Doc, "spm-infill", GraphRef.Params.spm_infill); - parseJsonWithCastAuto(Doc, "out-file", - GraphRef.Params.out_file); - parseJsonAuto(Doc, "batched-bench-output-jsonl", - GraphRef.Params.batched_bench_output_jsonl); - } catch (const ErrNo &Error) { - return Error; - } - // The tensor buffer overrides should terminated with empty pattern. - if (!GraphRef.TensorBuftOverrides.empty()) { - for (const std::string &Override : GraphRef.TensorBuftOverrides) { - GraphRef.Params.tensor_buft_overrides.push_back( - {Override.c_str(), ggml_backend_cpu_buffer_type()}); - } - GraphRef.Params.tensor_buft_overrides.push_back({nullptr, nullptr}); - } - - if (GraphRef.TextToSpeech) { - GraphRef.Params.sampling.top_k = 4; - GraphRef.Params.sampling.samplers = { - COMMON_SAMPLER_TYPE_TOP_K, - }; - } - // Check if the model parameters are updated. - if (IsModelUpdated && PrevNGPULayers != GraphRef.Params.n_gpu_layers) { - *IsModelUpdated = true; - } - - // Check if the context parameters are updated. - if (IsContextUpdated && PrevEmbedding != GraphRef.Params.embedding) { - *IsContextUpdated = true; - } - - // Check if the sampler parameters are updated. - if (IsSamplerUpdated && - (PrevTemp != GraphRef.Params.sampling.temp || - PrevTopP != GraphRef.Params.sampling.top_p || - PrevRepeatPenalty != GraphRef.Params.sampling.penalty_repeat || - PrevPresencePenalty != GraphRef.Params.sampling.penalty_present || - PrevFrequencyPenalty != GraphRef.Params.sampling.penalty_freq || - PrevGrammar != GraphRef.Params.sampling.grammar || - PrevSeed != GraphRef.Params.sampling.seed)) { - *IsSamplerUpdated = true; - } - - return ErrNo::Success; -} - -// <<<<<<<< Metadata related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< - -// >>>>>>>> Input related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> - -const std::string_view Base64ImageTagPrefix = ""sv; -const std::string_view VisionPromptImagePlaceholder = ""sv; - -// Get base64 image position if found in prompt. -std::optional> -findBase64ImagePayload(std::string_view Prompt, - bool IsDebugLog = false) noexcept { - // Find `` - auto EndTagPos = Prompt.find(Base64ImageTagSuffix, PayloadPos); - if (EndTagPos == std::string::npos) { - LOG_DEBUG(IsDebugLog, "base64: image tag unclosed."sv) - return std::nullopt; - } - return std::make_tuple(BeginTagPos, PayloadPos, EndTagPos); -} - -// Extract base64 image payload and image type. Replace it with placeholder. -std::optional, std::string>> -extractBase64ImagePayload(std::string &Prompt, - std::tuple ImagePos, - const std::string_view Placeholder) noexcept { - // Locate the payload and image type. - size_t BeginTagPos = std::get<0>(ImagePos); - size_t TypePos = std::get<0>(ImagePos) + Base64ImageTagPrefix.size(); - size_t PayloadPos = std::get<1>(ImagePos); - size_t BeginBytePos = std::get<1>(ImagePos) + Base64ImageBytesPrefix.size(); - size_t EndTagPos = std::get<2>(ImagePos); - std::string_view Payload = - std::string_view(Prompt).substr(BeginBytePos, EndTagPos - BeginBytePos); - std::string ImageType = Prompt.substr(TypePos, PayloadPos - TypePos); - - // Decode the base64 payload. - auto RequiredBytes = base64::required_encode_size(Payload.size()); - std::vector ImageBytes(RequiredBytes); - try { - base64::decode(Payload.begin(), Payload.end(), ImageBytes.begin()); - } catch (const base64_error &E) { - RET_ERROR(std::make_pair(std::vector(), ""), - "base64: Error when calling base64::decode: {}"sv, E.what()) - } - - // Replace the base64 image with the placeholder. - Prompt.replace(BeginTagPos, - EndTagPos - BeginTagPos + Base64ImageTagSuffix.size(), - Placeholder); - return std::make_pair(ImageBytes, ImageType); -} - -// TTS function to process the prompt text. -// clang-format off -const TTSSpeakerProfile TTSDefaultSpeakerProfile = { - // Speaker profile from edwko/OuteTTS (en_female_1.json). - "<|text_start|>uhm<|text_sep|>now<|text_sep|>being<|text_sep|>the<|text_sep|>one<|text_sep|>to<|text_sep|>say<|text_sep|>i<|text_sep|>know<|text_sep|>the<|text_sep|>worst<|text_sep|>of<|text_sep|>you<|text_sep|>and<|text_sep|>ive<|text_sep|>been<|text_sep|>directly<|text_sep|>affected<|text_sep|>by<|text_sep|>people<|text_sep|>like<|text_sep|>you<|text_sep|>but<|text_sep|>its<|text_sep|>a<|text_sep|>clean<|text_sep|>slate<|text_sep|>with<|text_sep|>me<|text_sep|>buddy<|text_sep|>you<|text_sep|>know<|text_sep|>like<|text_sep|>thats<|text_sep|>really<|text_sep|>powerful<|text_sep|>in<|text_sep|>and<|text_sep|>of<|text_sep|>itself<|text_sep|>", - "<|audio_start|>\nuhm<|t_0.36|><|code_start|><|447|><|223|><|967|><|301|><|965|><|827|><|393|><|908|><|764|><|1167|><|711|><|1222|><|324|><|1318|><|806|><|498|><|1198|><|1127|><|1178|><|916|><|1234|><|1411|><|1428|><|706|><|427|><|1605|><|1578|><|code_end|>\nnow<|t_0.36|><|code_start|><|1049|><|327|><|385|><|1070|><|732|><|1480|><|450|><|1025|><|1469|><|174|><|1013|><|1710|><|1674|><|775|><|771|><|251|><|778|><|1400|><|897|><|1487|><|366|><|441|><|1000|><|393|><|271|><|1000|><|768|><|code_end|>\nbeing<|t_0.27|><|code_start|><|926|><|406|><|1457|><|437|><|1231|><|672|><|1785|><|521|><|1179|><|1559|><|198|><|1086|><|733|><|122|><|1344|><|845|><|348|><|1389|><|470|><|1773|><|code_end|>\nthe<|t_0.08|><|code_start|><|1775|><|562|><|768|><|1222|><|768|><|963|><|code_end|>\none<|t_0.21|><|code_start|><|1757|><|744|><|144|><|1610|><|655|><|616|><|1317|><|225|><|1325|><|913|><|1342|><|992|><|1018|><|80|><|1777|><|883|><|code_end|>\nto<|t_0.08|><|code_start|><|487|><|1363|><|1682|><|1426|><|655|><|1483|><|code_end|>\nsay<|t_0.27|><|code_start|><|1644|><|1804|><|731|><|273|><|1592|><|731|><|1523|><|1404|><|984|><|1207|><|430|><|1132|><|1123|><|768|><|1116|><|829|><|1082|><|1095|><|440|><|1162|><|code_end|>\ni<|t_0.33|><|code_start|><|1330|><|335|><|1162|><|1155|><|308|><|1162|><|1150|><|1481|><|612|><|674|><|712|><|1745|><|1188|><|1787|><|1135|><|1275|><|1237|><|1143|><|408|><|1063|><|393|><|927|><|1298|><|132|><|1686|><|code_end|>\nknow<|t_0.27|><|code_start|><|983|><|1677|><|586|><|1528|><|1435|><|835|><|1396|><|706|><|987|><|22|><|1172|><|218|><|1404|><|1001|><|521|><|1389|><|775|><|1416|><|877|><|120|><|code_end|>\nthe<|t_0.16|><|code_start|><|916|><|1756|><|513|><|1245|><|1392|><|89|><|1266|><|12|><|1045|><|1075|><|904|><|35|><|code_end|>\nworst<|t_0.32|><|code_start|><|1607|><|174|><|1231|><|144|><|932|><|490|><|771|><|1504|><|798|><|674|><|364|><|80|><|1314|><|1636|><|449|><|1704|><|713|><|1795|><|968|><|1527|><|1302|><|1529|><|1176|><|795|><|code_end|>\nof<|t_0.12|><|code_start|><|1193|><|1205|><|390|><|1128|><|1091|><|883|><|322|><|377|><|1070|><|code_end|>\nyou<|t_0.17|><|code_start|><|1016|><|1332|><|926|><|281|><|927|><|1368|><|1687|><|918|><|67|><|1638|><|1317|><|1265|><|1770|><|code_end|>\nand<|t_0.28|><|code_start|><|1129|><|1633|><|1373|><|1207|><|405|><|879|><|1030|><|1253|><|1071|><|612|><|724|><|1770|><|665|><|1046|><|1351|><|1450|><|1541|><|1384|><|111|><|1477|><|284|><|code_end|>\nive<|t_0.35|><|code_start|><|674|><|266|><|89|><|1333|><|1183|><|1526|><|1143|><|883|><|1135|><|732|><|827|><|1119|><|594|><|1261|><|1024|><|1347|><|92|><|1392|><|825|><|1710|><|1289|><|1598|><|1070|><|1525|><|1442|><|555|><|code_end|>\nbeen<|t_0.17|><|code_start|><|1461|><|194|><|337|><|1128|><|188|><|892|><|848|><|1280|><|959|><|754|><|231|><|649|><|1304|><|code_end|>\ndirectly<|t_0.87|><|code_start|><|1030|><|353|><|570|><|1331|><|470|><|1832|><|1362|><|1809|><|1383|><|101|><|325|><|1557|><|1242|><|1512|><|180|><|227|><|1242|><|643|><|209|><|464|><|171|><|1219|><|174|><|1723|><|734|><|118|><|1269|><|643|><|209|><|187|><|612|><|1231|><|68|><|567|><|1242|><|505|><|319|><|1268|><|794|><|678|><|40|><|1286|><|470|><|1454|><|199|><|965|><|188|><|300|><|1234|><|1125|><|794|><|1289|><|1224|><|257|><|469|><|1121|><|101|><|823|><|1769|><|1683|><|95|><|255|><|59|><|67|><|832|><|code_end|>\naffected<|t_0.44|><|code_start|><|510|><|873|><|787|><|1228|><|771|><|1428|><|501|><|751|><|696|><|258|><|845|><|1818|><|1112|><|498|><|1111|><|985|><|1073|><|832|><|1427|><|168|><|163|><|447|><|119|><|567|><|1626|><|1820|><|903|><|635|><|1060|><|10|><|1632|><|35|><|1635|><|code_end|>\nby<|t_0.19|><|code_start|><|144|><|144|><|460|><|185|><|1112|><|1044|><|498|><|1192|><|656|><|1333|><|1001|><|1186|><|1186|><|454|><|code_end|>\npeople<|t_0.48|><|code_start|><|1260|><|747|><|351|><|526|><|612|><|1151|><|1262|><|1791|><|344|><|1752|><|1547|><|930|><|1302|><|1703|><|1289|><|92|><|1407|><|1482|><|508|><|1431|><|355|><|1696|><|337|><|199|><|1157|><|223|><|464|><|568|><|845|><|411|><|826|><|718|><|1786|><|545|><|712|><|580|><|code_end|>\nlike<|t_0.32|><|code_start|><|630|><|532|><|526|><|607|><|526|><|839|><|1305|><|660|><|459|><|339|><|717|><|1178|><|1148|><|687|><|149|><|1390|><|229|><|199|><|513|><|712|><|1451|><|731|><|582|><|1551|><|code_end|>\nyou<|t_0.21|><|code_start|><|1389|><|954|><|1781|><|1047|><|1236|><|930|><|809|><|1621|><|1268|><|384|><|242|><|587|><|869|><|816|><|1680|><|405|><|code_end|>\nbut<|t_0.59|><|code_start|><|1089|><|1590|><|908|><|80|><|594|><|1046|><|1706|><|1025|><|1150|><|405|><|548|><|893|><|1285|><|464|><|301|><|939|><|643|><|23|><|285|><|161|><|209|><|453|><|72|><|167|><|417|><|244|><|151|><|643|><|391|><|199|><|651|><|1023|><|337|><|1010|><|54|><|331|><|1167|><|756|><|388|><|934|><|1060|><|18|><|1624|><|1060|><|code_end|>\nits<|t_0.16|><|code_start|><|1102|><|183|><|1199|><|1258|><|1285|><|35|><|659|><|180|><|426|><|1587|><|1733|><|942|><|code_end|>\na<|t_0.04|><|code_start|><|791|><|1012|><|818|><|code_end|>\nclean<|t_0.61|><|code_start|><|1819|><|976|><|163|><|447|><|316|><|223|><|763|><|457|><|1208|><|1808|><|1697|><|1162|><|1660|><|1833|><|1054|><|1734|><|1121|><|1309|><|1643|><|924|><|1677|><|1548|><|869|><|1268|><|223|><|674|><|111|><|792|><|1670|><|912|><|174|><|1554|><|90|><|80|><|1563|><|1621|><|1698|><|1544|><|992|><|988|><|175|><|793|><|1661|><|1026|><|80|><|1761|><|code_end|>\nslate<|t_0.40|><|code_start|><|1802|><|322|><|1689|><|1577|><|1302|><|1552|><|1529|><|1722|><|1580|><|582|><|1642|><|1529|><|1020|><|582|><|1538|><|970|><|437|><|1141|><|1477|><|988|><|335|><|1611|><|922|><|1558|><|1120|><|1189|><|423|><|188|><|171|><|562|><|code_end|>\nwith<|t_0.15|><|code_start|><|963|><|1347|><|1274|><|747|><|1230|><|712|><|1408|><|1290|><|957|><|1279|><|258|><|code_end|>\nme<|t_0.09|><|code_start|><|638|><|1058|><|174|><|1452|><|1038|><|894|><|1571|><|code_end|>\nbuddy<|t_0.32|><|code_start|><|1003|><|130|><|1341|><|938|><|40|><|804|><|167|><|89|><|1456|><|1189|><|1155|><|1171|><|1434|><|1077|><|1029|><|1455|><|1622|><|1037|><|163|><|1411|><|1165|><|1463|><|837|><|1202|><|code_end|>\nyou<|t_0.36|><|code_start|><|1354|><|1165|><|615|><|1588|><|1192|><|1445|><|1033|><|982|><|401|><|1079|><|684|><|1570|><|266|><|31|><|420|><|163|><|893|><|845|><|905|><|1827|><|1804|><|153|><|627|><|243|><|1179|><|298|><|1147|><|code_end|>\nknow<|t_0.19|><|code_start|><|163|><|1542|><|1366|><|698|><|1753|><|206|><|916|><|1499|><|245|><|665|><|600|><|894|><|587|><|1741|><|code_end|>\nlike<|t_0.24|><|code_start|><|1106|><|1280|><|1062|><|1304|><|945|><|809|><|598|><|104|><|1001|><|822|><|965|><|189|><|693|><|1810|><|1293|><|199|><|1277|><|44|><|code_end|>\nthats<|t_0.24|><|code_start|><|121|><|1789|><|1443|><|370|><|1154|><|393|><|1178|><|1200|><|1264|><|424|><|1391|><|381|><|978|><|1346|><|704|><|1808|><|1579|><|1492|><|code_end|>\nreally<|t_0.56|><|code_start|><|1177|><|1761|><|1723|><|1360|><|1413|><|830|><|551|><|193|><|59|><|332|><|598|><|734|><|1684|><|1802|><|60|><|1590|><|353|><|89|><|1636|><|1396|><|893|><|143|><|455|><|1501|><|435|><|1082|><|621|><|1593|><|677|><|474|><|971|><|1513|><|913|><|828|><|1381|><|1148|><|1798|><|1186|><|1443|><|38|><|335|><|883|><|code_end|>\npowerful<|t_0.63|><|code_start|><|1773|><|458|><|1070|><|964|><|826|><|1220|><|1012|><|1738|><|1125|><|669|><|490|><|1169|><|922|><|958|><|1204|><|489|><|1001|><|886|><|1045|><|675|><|1471|><|1652|><|732|><|698|><|1124|><|480|><|897|><|1484|><|1028|><|35|><|594|><|1465|><|505|><|1669|><|436|><|851|><|1288|><|31|><|1501|><|1187|><|394|><|909|><|1541|><|1793|><|1720|><|922|><|840|><|code_end|>\nin<|t_0.16|><|code_start|><|1317|><|523|><|630|><|1343|><|1187|><|719|><|907|><|636|><|111|><|1524|><|188|><|1382|><|code_end|>\nand<|t_0.13|><|code_start|><|1074|><|922|><|1280|><|1496|><|1050|><|832|><|133|><|1435|><|1049|><|1774|><|code_end|>\nof<|t_0.12|><|code_start|><|960|><|1052|><|1192|><|1303|><|1112|><|970|><|417|><|60|><|1155|><|code_end|>\nitself<|t_0.47|><|code_start|><|1682|><|1209|><|1410|><|513|><|1222|><|861|><|167|><|406|><|1551|><|582|><|634|><|1529|><|786|><|1363|><|1578|><|1739|><|873|><|424|><|1041|><|1328|><|955|><|1110|><|1490|><|1424|><|1199|><|988|><|1162|><|1133|><|1193|><|978|><|470|><|832|><|963|><|1251|><|733|><|code_end|>", -}; -// clang-format on - -const std::map Ones = { - {0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, - {4, "four"}, {5, "five"}, {6, "six"}, {7, "seven"}, - {8, "eight"}, {9, "nine"}, {10, "ten"}, {11, "eleven"}, - {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"}, {15, "fifteen"}, - {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"}}; - -const std::map Tens = { - {2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"}, - {6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}}; - -// Convert a number less than 1000 to words -std::string convertLessThanThousand(int Num) { - std::string Result; - - if (Num >= 100) { - Result += Ones.at(Num / 100) + " hundred "; - Num %= 100; - } - - if (Num >= 20) { - Result += Tens.at(Num / 10); - if (Num % 10 > 0) { - Result += "-" + Ones.at(Num % 10); - } - } else if (Num > 0) { - Result += Ones.at(Num); - } - - return Result; -} - -std::string numberToWords(const std::string &NumberStr) { - try { - size_t DecimalPos = NumberStr.find('.'); - std::string IntegerPart = NumberStr.substr(0, DecimalPos); - - int IntNumber = std::stoi(IntegerPart); - std::string Result; - - if (IntNumber == 0) { - Result = "zero"; - } else { - if (IntNumber >= 1000000000) { - int Billions = IntNumber / 1000000000; - Result += convertLessThanThousand(Billions) + " billion "; - IntNumber %= 1000000000; - } - - if (IntNumber >= 1000000) { - int Millions = IntNumber / 1000000; - Result += convertLessThanThousand(Millions) + " million "; - IntNumber %= 1000000; - } - - if (IntNumber >= 1000) { - int Thousands = IntNumber / 1000; - Result += convertLessThanThousand(Thousands) + " thousand "; - IntNumber %= 1000; - } - - if (IntNumber > 0) { - Result += convertLessThanThousand(IntNumber); - } - } - - // Handle decimal part - if (DecimalPos != std::string::npos) { - Result += " point"; - std::string DecimalPart = NumberStr.substr(DecimalPos + 1); - for (char Digit : DecimalPart) { - Result += " " + Ones.at(Digit - '0'); - } - } - - return Result; - } catch (const std::exception &) { - // Skip if fails - return " "; - } -} - -std::string replaceNumbersWithWords(const std::string &InputText) { - std::regex NumberPattern(R"(\d+(\.\d+)?)"); - std::string Result; - auto It = - std::sregex_iterator(InputText.begin(), InputText.end(), NumberPattern); - auto End = std::sregex_iterator(); - - size_t LastPos = 0; - for (std::sregex_iterator I = It; I != End; ++I) { - const std::smatch &Match = *I; - Result.append(InputText, LastPos, Match.position() - LastPos); - Result.append(numberToWords(Match.str())); - LastPos = Match.position() + Match.length(); - } - Result.append(InputText, LastPos); - - return Result; -} - -// Based on: -// https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39 -// https://github.com/ggerganov/llama.cpp/blob/b4488/examples/tts/tts.cpp#L374 -std::string processTTSPromptText(const std::string &Text) { - std::string ProcessedText = replaceNumbersWithWords(Text); - - std::transform( - ProcessedText.begin(), ProcessedText.end(), ProcessedText.begin(), - [](unsigned char C) { return static_cast(::tolower(C)); }); - - std::regex SpecialChars(R"([-_/,\.\\])"); - ProcessedText = std::regex_replace(ProcessedText, SpecialChars, " "); - - std::regex NonAlpha(R"([^a-z\s])"); - ProcessedText = std::regex_replace(ProcessedText, NonAlpha, ""); - - std::regex MultipleSpaces(R"(\s+)"); - ProcessedText = std::regex_replace(ProcessedText, MultipleSpaces, " "); - - ProcessedText = - std::regex_replace(ProcessedText, std::regex(R"(^\s+|\s+$)"), ""); - - ProcessedText = - std::regex_replace(ProcessedText, std::regex(R"(\s)"), "<|text_sep|>"); - - return ProcessedText; -} - -std::optional -getSpeakerProfileFromFile(const std::string &FilePath, WasiNNEnvironment &Env) { - WasmEdge::FStream::IFStream JsonFile(FilePath, Env.getEnv()); - if (!JsonFile.is_open()) { - return std::nullopt; - } - nlohmann::json JsonData; - JsonFile >> JsonData; - JsonFile.close(); - - // Initialize the outputs - std::string AudioOutputText = "<|audio_start|>\n"; - std::string TextOutput = "<|text_start|>"; - - // Iterate through each word in the JSON data - for (const auto &WordData : JsonData["words"]) { - std::string Word = WordData["word"]; - double Duration = WordData["duration"]; - std::vector Codes = WordData["codes"]; - - // Create the audio output entry - std::ostringstream WordEntry; - WordEntry << Word << "<|t_" << std::fixed << std::setprecision(2) - << Duration << "|><|code_start|>"; - for (const auto &Code : Codes) { - WordEntry << "<|" << Code << "|>"; - } - WordEntry << "<|code_end|>\n"; - AudioOutputText += WordEntry.str(); - - // Create the text output entry - TextOutput += Word + "<|text_sep|>"; - } - - return TTSSpeakerProfile{TextOutput, AudioOutputText}; -} - -std::vector processTTSPrompt(WasiNNEnvironment &Env, - Graph &GraphRef, - std::string &Prompt) noexcept { - // Use the custom speaker profile if available. - TTSSpeakerProfile SpeakerProfile = TTSDefaultSpeakerProfile; - if (!GraphRef.TTSSpeakerFilePath.empty()) { - std::optional SpeakerProfileOpt = - getSpeakerProfileFromFile(GraphRef.TTSSpeakerFilePath, Env); - if (SpeakerProfileOpt.has_value()) { - SpeakerProfile = *SpeakerProfileOpt; - } else { - RET_ERROR( - {}, - "processTTSPrompt: Failed to load speaker profile from file: {}"sv, - GraphRef.TTSSpeakerFilePath); - } - } - std::string ProcessedPrompt = processTTSPromptText(Prompt); - std::vector Result, TmpTokens; - Result = common_tokenize(GraphRef.LlamaContext.get(), "<|im_start|>\n", - /* add_special */ true, - /* parse_special */ true); - TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), SpeakerProfile.Text, - /* add_special */ false, - /* parse_special */ true); - Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); - TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), ProcessedPrompt, - /* add_special */ false, - /* parse_special */ true); - Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); - TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), "<|text_end|>\n", - /* add_special */ false, - /* parse_special */ true); - Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); - TmpTokens = common_tokenize(GraphRef.LlamaContext.get(), SpeakerProfile.Data, - /* add_special */ false, - /* parse_special */ true); - Result.insert(Result.end(), TmpTokens.begin(), TmpTokens.end()); - - return Result; -} - -// <<<<<<<< Input related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< - -// >>>>>>>> Output related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> - -// Generate output metadata. -std::string buildOutputMetadata(Context &CxtRef) noexcept { - return fmt::format(R"({{"input_tokens": {}, )" - R"("output_tokens": {}, )" - R"("llama_build_number": {}, )" - R"("llama_commit": "{}"}})"sv, - CxtRef.LlamaNInputs, CxtRef.LlamaOutputTokens.size(), - LLAMA_BUILD_NUMBER, LLAMA_COMMIT); -} - -// Generate output embedding. -void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, - const float *Embeddings) noexcept { - // Embedding vector format - // | Content | - // | ----------------------------------- | - // | '{"number_embedding": ' | - // | n_embedding | - // | ', "embedding": ' | - // | '[' | - // | n_embedding*(embedding value %.10f) | - // | (n_embedding-1)*(',') | - // | ']' | - // | '}' | - Embedding = - fmt::format(R"({{"n_embedding": {}, )" - R"("embedding": [{:.10}]}})"sv, - NEmbd, fmt::join(Embeddings, Embeddings + NEmbd, ","sv)); -} - -// <<<<<<<< Output related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< - -// >>>>>>>> Compute related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> - -// Helper to init a llama batch. -struct llama_batch allocBatch(int64_t NTokens, int64_t Embd = 0, - int32_t NSeqMax = 1) noexcept { - struct llama_batch Batch = llama_batch_init( - /* n_tokens_alloc */ static_cast(NTokens), - /* embd */ static_cast(Embd), - /* n_seq_max */ static_cast(NSeqMax)); - std::fill(Batch.n_seq_id, Batch.n_seq_id + NTokens, - static_cast(NSeqMax)); - for (int64_t I = 0; I < NTokens; I++) { - std::fill(Batch.seq_id[I], Batch.seq_id[I] + NSeqMax, 0); - } - std::fill(Batch.logits, Batch.logits + NTokens, false); - return Batch; -} - -// Fill tokens (smaller than batch size) into a batch with position data. -void fillBatch(Span Tokens, Graph &GraphRef, - llama_batch &Batch, int &NPos, bool IsLogit = false) { - assuming(GraphRef.Params.n_batch >= static_cast(Tokens.size())); - assuming(Batch.token != nullptr); - assuming(Batch.pos != nullptr); - assuming(Batch.logits != nullptr); - // Fill the batch with pos information. - Batch.n_tokens = static_cast(Tokens.size()); - for (uint32_t I = 0; I < Tokens.size(); I++) { - Batch.token[I] = Tokens[I]; - Batch.pos[I] = NPos + I; - Batch.logits[I] = false; - } - - // Logits of sampling or end of inputs. - if (IsLogit) { - Batch.logits[Tokens.size() - 1] = true; - } - - // Move the position. - NPos += static_cast(Tokens.size()); -} - -// Evaluate tokens. Construct the tokens into batch and decode. -ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, - llama_batch &Batch, int &NPos, - bool IsLogits = false) noexcept { - // End the inference if the context is full. - uint32_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); - if (NPos + static_cast(Tokens.size()) > NCtx) { - LOG_INFO( - GraphRef.EnableLog, - "evaluateTokens: the context if full ({} / {} tokens). Please increase your "sv - "context size."sv, - NPos + static_cast(Tokens.size()), NCtx) - return ErrNo::ContextFull; - } - - // Loop for decode batch. Split tokens into batch size length. - for (int I = 0; I < static_cast(Tokens.size()); - I += static_cast(GraphRef.Params.n_batch)) { - int NEval = static_cast(Tokens.size()) - I; - if (NEval > static_cast(GraphRef.Params.n_batch)) { - NEval = static_cast(GraphRef.Params.n_batch); - } - - // Fill the batch with pos information. - fillBatch(Span(Tokens.begin() + I, NEval), GraphRef, - Batch, NPos, - IsLogits && I + NEval >= static_cast(Tokens.size())); - - // Decode the batch. - auto Status = llama_decode(GraphRef.LlamaContext.get(), Batch); - if (Status == 1) { - RET_ERROR( - ErrNo::RuntimeError, - "evaluateTokens: failed to llama_decode: try reducing the size of the batch "sv - "or increasing the size of context."sv) - } - if (Status < 0) { - RET_ERROR( - ErrNo::RuntimeError, - "evaluateTokens: failed to llama_decode: internal fatal error. Please open "sv - "an issue on GitHub."sv) - } - } - - return ErrNo::Success; -} - -// Clear the context and reset the sampler. -void clearContext(Graph &GraphRef, Context &CxtRef) noexcept { - LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext"sv) - llama_memory_clear(llama_get_memory(GraphRef.LlamaContext.get()), true); - common_sampler_reset(CxtRef.LlamaSampler); - CxtRef.NPos = 0; - CxtRef.LlamaOutputs.clear(); - CxtRef.LlamaOutputTokens.clear(); - LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext...Done"sv) -} - -// Evaluate the input tokens. Clean all inputs if succeeded. -ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, - std::string_view LogPrefix) noexcept { - // Check if the input is set before setting up the context. - if (CxtRef.LlamaInputs.size() == 0) { - RET_ERROR(ErrNo::InvalidArgument, "{}: llama input is not set!"sv, - LogPrefix) - } - - // Get the context size. - const uint64_t NCtx = llama_n_ctx(GraphRef.LlamaContext.get()); - // Minus 4 for the special tokens. (Such as , , ... tokens.) - const uint64_t MaxTokensListSize = NCtx - 4; - // Return value. - auto ReturnCode = ErrNo::Success; - - // Check if the input is too long. - if (static_cast(CxtRef.LlamaInputs.size()) > MaxTokensListSize) { - RET_ERROR(ErrNo::PromptTooLong, - "{}: the prompt is too long. Your input has {} tokens. "sv - "Please reduce it to {} tokens."sv, - LogPrefix, CxtRef.LlamaInputs.size(), MaxTokensListSize) - } - - // Evaluate input tokens. - ReturnCode = - evaluateTokens(Span(CxtRef.LlamaInputs.begin(), - CxtRef.LlamaInputs.size()), - GraphRef, CxtRef.LlamaBatch, CxtRef.NPos, true); - if (ReturnCode != ErrNo::Success) { - RET_ERROR(ReturnCode, "{}: failed to evaluate input tokens."sv, LogPrefix) - } - - return ErrNo::Success; -} - -// Sample and get the output token. -ErrNo sampleOutput(Graph &GraphRef, Context &CxtRef, - bool IsSingleTokenMode = false) noexcept { - // Use idx = -1 to sample the next token. - const llama_token Id = common_sampler_sample( - CxtRef.LlamaSampler, GraphRef.LlamaContext.get(), /* idx */ -1); - common_sampler_accept(CxtRef.LlamaSampler, Id, /* accept_grammar */ true); - - // Save the output token. - CxtRef.LlamaOutputTokens.emplace_back(Id); - std::string OutputString = - common_token_to_piece(GraphRef.LlamaContext.get(), Id); - CxtRef.LlamaOutputs.insert(CxtRef.LlamaOutputs.end(), OutputString.begin(), - OutputString.end()); - // In single token mode, we do not handle StreamStdout and ReversePrompt. - if (!IsSingleTokenMode) { - // When setting StreamStdout, we print the output to stdout. - if (CxtRef.Conf.StreamStdout) { - fmt::print("{}"sv, - common_token_to_piece(GraphRef.LlamaContext.get(), Id)); - std::fflush(stdout); - } - // Break if reverse prompt is found. - if (!CxtRef.Conf.ReversePrompt.empty() && - std::string(CxtRef.LlamaOutputs.begin(), CxtRef.LlamaOutputs.end()) - .find(CxtRef.Conf.ReversePrompt) != std::string::npos) { - LOG_INFO(GraphRef.EnableLog, "sampleOutput: reverse prompt found."sv) - return ErrNo::EndOfSequence; - } - } - // Deal with end of text token. - const llama_vocab *Vocab = llama_model_get_vocab(GraphRef.LlamaModel.get()); - // Only stop on EOS if GraphRef.Params.sampling.ignore_eos is false. - if (!GraphRef.Params.sampling.ignore_eos && - llama_vocab_is_eog(Vocab, common_sampler_last(CxtRef.LlamaSampler))) { - LOG_INFO(GraphRef.EnableLog, "sampleOutput: EOS token found."sv) - return ErrNo::EndOfSequence; - } - // Evaluate the output token. - return evaluateTokens(Span(&Id, 1), GraphRef, - CxtRef.OutputBatch, CxtRef.NPos, true); -} - -// TODO: Merge into compute. -Expect getEmbedding(Graph &GraphRef, Context &CxtRef) noexcept { - LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding"sv) - - const llama_vocab *Vocab = llama_model_get_vocab(GraphRef.LlamaModel.get()); - // Add SEP if not present. - if (CxtRef.LlamaInputs.size() > 0 && - CxtRef.LlamaInputs.back() != llama_vocab_sep(Vocab)) { - LOG_WARN( - "getEmbedding: last token in the prompt is not SEP, "sv - "'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF "sv - "header."sv) - } - - // Check if the input is too long. - if (static_cast(CxtRef.LlamaInputs.size()) > - GraphRef.Params.n_batch) { - RET_ERROR( - ErrNo::PromptTooLong, - "getEmbedding: the prompt is too long. Your input has {} tokens exceeds batch "sv - "size {}. Please reduce the input size or increase your batch-size."sv, - CxtRef.LlamaInputs.size(), GraphRef.Params.n_batch) - } - - // Evaluate the input tokens. - auto ReturnCode = evaluateInput(GraphRef, CxtRef, "getEmbedding"sv); - if (ReturnCode != ErrNo::Success) { - return ReturnCode; - } - - // Main prediction loop. - const int32_t NEmbd = llama_model_n_embd(GraphRef.LlamaModel.get()); - std::vector Embeddings(NEmbd); - - for (int I = 0; I < CxtRef.LlamaBatch.n_tokens; I++) { - if (!CxtRef.LlamaBatch.logits[I]) { - continue; - } - - // Try to get sequence embeddings. - auto *Embd = llama_get_embeddings_seq(GraphRef.LlamaContext.get(), - CxtRef.LlamaBatch.seq_id[I][0]); - if (Embd == nullptr) { - Embd = llama_get_embeddings_ith(GraphRef.LlamaContext.get(), I); - if (Embd == nullptr) { - LOG_ERROR("getEmbedding: failed to get embeddings for token {}"sv, I); - continue; - } - } - - // Normalize the embeddings. - common_embd_normalize(Embd, Embeddings.data(), NEmbd, - static_cast(CxtRef.Conf.EmbdNormalize)); - } - - std::string EmbeddingString; - buildOutputEmbedding(EmbeddingString, NEmbd, Embeddings.data()); - CxtRef.LlamaOutputs = - std::vector(EmbeddingString.begin(), EmbeddingString.end()); - - if (GraphRef.EnableLog) { - common_perf_print(GraphRef.LlamaContext.get(), /* Sampler */ nullptr); - } - - LOG_DEBUG(GraphRef.EnableDebugLog, "getEmbedding...Done"sv) - return ErrNo::Success; -} - -// TTS related functions. -void fillHannWindow(int Length, bool Periodic, float *Output) { - int Offset = -1; - float Pi = static_cast(std::acos(-1)); - if (Periodic) { - Offset = 0; - } - for (int I = 0; I < Length; I++) { - double Value = - 0.5 * - (1.0 - cosf(static_cast((2.0 * Pi * I) / (Length + Offset)))); - Output[I] = static_cast(Value); - } -} - -void twiddle(float *Real, float *Imag, int K, int N) { - float Pi = static_cast(std::acos(-1)); - float Angle = 2 * Pi * K / N; - *Real = cos(Angle); - *Imag = sin(Angle); -} - -void irfft(int N, const float *InpCplx, float *OutReal) { - int NN = N / 2 + 1; - - std::vector RealInput(NN); - std::vector ImagInput(NN); - for (int I = 0; I < NN; ++I) { - RealInput[I] = InpCplx[2 * I]; - ImagInput[I] = InpCplx[2 * I + 1]; - } - - std::vector RealOutput(N); - std::vector ImagOutput(N); - - for (int K = 0; K < N; ++K) { - RealOutput[K] = 0.0f; - ImagOutput[K] = 0.0f; - for (int M = 0; M < NN; ++M) { - float TwiddleReal; - float TwiddleImag; - - twiddle(&TwiddleReal, &TwiddleImag, K * M, N); - - RealOutput[K] += RealInput[M] * TwiddleReal - ImagInput[M] * TwiddleImag; - ImagOutput[K] += RealInput[M] * TwiddleImag + ImagInput[M] * TwiddleReal; - } - } - - for (int I = 0; I < N; ++I) { - OutReal[I] = RealOutput[I] / NN; - } -} - -void fold(const std::vector &Data, int64_t NOut, int64_t NWin, - int64_t NHop, int64_t NPad, std::vector &Output) { - int64_t OutputHeight = NOut; - int64_t KernelW = NWin; - int64_t StrideW = NHop; - int64_t Width = NOut; - - Output.resize(Width, 0.0f); - - int64_t ColIdx = 0; - for (int64_t WCol = 0; WCol < Width; ++WCol) { - int64_t Start = WCol * StrideW - NPad; - int64_t End = Start + KernelW; - - for (int64_t WIm = Start; WIm < End; ++WIm) { - if (WIm >= 0 && WIm < OutputHeight && - ColIdx < static_cast(Data.size())) { - Output[WIm] += Data[ColIdx]; - } - ColIdx++; - } - } - - Output.resize(NOut - 2 * NPad); -} - -std::vector embdToAudio(const float *Embd, const int NCodes, - const int NEmbd, const int NThread) { - const int NFft = 1280; - const int NHop = 320; - const int NWin = 1280; - const int NPad = (NWin - NHop) / 2; - const int NOut = (NCodes - 1) * NHop + NWin; - - std::vector Hann(NFft); - - fillHannWindow(static_cast(Hann.size()), true, Hann.data()); - - int NSpec = NEmbd * NCodes; - - std::vector E(NSpec); - std::vector S(NSpec); - std::vector ST(NSpec); - - for (int L = 0; L < NCodes; ++L) { - for (int K = 0; K < NEmbd; ++K) { - E[K * NCodes + L] = Embd[L * NEmbd + K]; - } - } - - for (int K = 0; K < NEmbd / 2; ++K) { - for (int L = 0; L < NCodes; ++L) { - float Mag = E[(K)*NCodes + L]; - float Phi = E[(K + NEmbd / 2) * NCodes + L]; - - Mag = exp(Mag); - - if (Mag > 1e2) { - Mag = 1e2; - } - S[2 * (K * NCodes + L) + 0] = Mag * cosf(Phi); - S[2 * (K * NCodes + L) + 1] = Mag * sinf(Phi); - } - } - - for (int L = 0; L < NCodes; ++L) { - for (int K = 0; K < NEmbd / 2; ++K) { - ST[L * NEmbd + 2 * K + 0] = S[2 * (K * NCodes + L) + 0]; - ST[L * NEmbd + 2 * K + 1] = S[2 * (K * NCodes + L) + 1]; - } - } - - std::vector Res(NCodes * NFft); - std::vector Hann2(NCodes * NFft); - - std::vector Workers(NThread); - for (int I = 0; I < NThread; ++I) { - Workers[I] = std::thread([&, I]() { - for (int L = I; L < NCodes; L += NThread) { - irfft(NFft, ST.data() + L * NEmbd, Res.data() + L * NFft); - for (int J = 0; J < NFft; ++J) { - Res[L * NFft + J] *= Hann[J]; - Hann2[L * NFft + J] = Hann[J] * Hann[J]; - } - } - }); - } - for (int I = 0; I < NThread; ++I) { - Workers[I].join(); - } - - std::vector Audio; - std::vector Env; - - fold(Res, NOut, NWin, NHop, NPad, Audio); - fold(Hann2, NOut, NWin, NHop, NPad, Env); // TODO: can be done once - - for (size_t I = 0; I < Audio.size(); ++I) { - Audio[I] /= Env[I]; - } - - return Audio; -} - -struct WavHeader { - char Riff[4] = {'R', 'I', 'F', 'F'}; - uint32_t ChunkSize; - char Wave[4] = {'W', 'A', 'V', 'E'}; - char Fmt[4] = {'f', 'm', 't', ' '}; - uint32_t FmtChunkSize = 16; - uint16_t AudioFormat = 1; // PCM - uint16_t NumChannels = 1; // Mono - uint32_t SampleRate; - uint32_t ByteRate; - uint16_t BlockAlign; - uint16_t BitsPerSample = 16; - char Data[4] = {'d', 'a', 't', 'a'}; - uint32_t DataSize; -}; - -std::vector audioDataToWav(const std::vector &Data, - int SampleRate) { - std::vector WavData; - WavHeader Header; - Header.SampleRate = SampleRate; - Header.ByteRate = - Header.SampleRate * Header.NumChannels * (Header.BitsPerSample / 8); - Header.BlockAlign = Header.NumChannels * (Header.BitsPerSample / 8); - Header.DataSize = - static_cast(Data.size() * (Header.BitsPerSample / 8)); - Header.ChunkSize = 36 + Header.DataSize; - - WavData.insert(WavData.end(), reinterpret_cast(&Header), - reinterpret_cast(&Header) + sizeof(Header)); - - for (const auto &Sample : Data) { - int16_t PCMSample = - static_cast(std::clamp(Sample * 32767.0, -32768.0, 32767.0)); - WavData.insert(WavData.end(), reinterpret_cast(&PCMSample), - reinterpret_cast(&PCMSample) + sizeof(PCMSample)); - } - - return WavData; -} - -// TextToSpeech function, will generate voice data from codes. -ErrNo codesToSpeech(WasiNNEnvironment &Env, Graph &GraphRef, - Context &CxtRef) noexcept { - // Remove all non-audio tokens. - CxtRef.LlamaOutputTokens.erase( - std::remove_if(CxtRef.LlamaOutputTokens.begin(), - CxtRef.LlamaOutputTokens.end(), - [](llama_token T) { return T < 151672 || T > 155772; }), - CxtRef.LlamaOutputTokens.end()); - - // Adjust the token values for audio data. - for (llama_token &Token : CxtRef.LlamaOutputTokens) { - Token -= 151672; - } - - // Put codes into batch. - const uint32_t NCodes = - static_cast(CxtRef.LlamaOutputTokens.size()); - llama_batch TTSBatch = - llama_batch_init(NCodes, /* embd */ 0, /* n_seq_max */ 1); - for (uint32_t I = 0; I < NCodes; ++I) { - common_batch_add(TTSBatch, CxtRef.LlamaOutputTokens[I], I, - /* seq_ids */ {0}, /* logits */ true); - } - if (llama_decode(GraphRef.TTSContext.get(), TTSBatch) != 0) { - RET_ERROR(ErrNo::RuntimeError, "codesToSpeech: fail to eval."sv) - } - llama_batch_free(TTSBatch); - - // Get embeddings. - const int NEmbd = llama_model_n_embd(GraphRef.TTSModel.get()); - const float *Embd = llama_get_embeddings(GraphRef.TTSContext.get()); - - // Embeddings to audio. - std::vector AudioData = - embdToAudio(Embd, NCodes, NEmbd, - static_cast(GraphRef.Params.cpuparams.n_threads)); - - // Zero out first 0.25 seconds of audio. - const uint32_t SamplingRate = 24000; - for (uint32_t I = 0; I < SamplingRate / 4; ++I) { - AudioData[I] = 0.0f; - } - - // Convert audio data to wav and put it into output buffer. - CxtRef.LlamaOutputs = audioDataToWav(AudioData, SamplingRate); - - // Save .wav file if path is provided. - if (!GraphRef.TTSOutputFilePath.empty()) { - WasmEdge::FStream::OFStream File(GraphRef.TTSOutputFilePath, - std::ios_base::out | std::ios_base::binary, - Env.getEnv()); - if (!File) { - RET_ERROR(ErrNo::RuntimeError, - "codesToSpeech: Failed to open file '{}' for writing"sv, - GraphRef.TTSOutputFilePath); - } - File.write(reinterpret_cast(CxtRef.LlamaOutputs.data()), - CxtRef.LlamaOutputs.size()); - File.close(); - } - - return ErrNo::Success; -} - -// <<<<<<<< Compute related functions <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< - -} // namespace - -Expect load(WasiNNEnvironment &Env, Span> Builders, - [[maybe_unused]] Device Device, uint32_t &GraphId) noexcept { - // Add a new graph. - EndianValue GId = Env.newGraph(Backend::GGML); - auto &GraphRef = Env.NNGraph[GId.raw()].get(); - - // Initialize the plugin parameters. - GraphRef.EnableLog = false; - GraphRef.EnableDebugLog = false; - common_params CommonParamsDefault; - CommonParamsDefault.lr.init(); - GraphRef.Params = CommonParamsDefault; - GraphRef.Params.n_keep = 0; - GraphRef.Params.n_chunks = -1; - GraphRef.Params.n_parallel = 1; - GraphRef.Params.grp_attn_n = 1; - GraphRef.Params.grp_attn_w = 512; - GraphRef.Params.n_print = -1; - GraphRef.Params.split_mode = llama_split_mode::LLAMA_SPLIT_MODE_LAYER; - // Initialize the model parameters. - llama_model_params ModelParamsDefault = llama_model_default_params(); - GraphRef.Params.n_gpu_layers = ModelParamsDefault.n_gpu_layers; - GraphRef.Params.mmproj.path = ""sv; - GraphRef.Params.warmup = false; - - // Initialize the sampling parameters. - const common_params_sampling SamplerParamsDefault; - GraphRef.Params.sampling = SamplerParamsDefault; - // Initialize the config parameters. - GraphRef.Conf.StreamStdout = false; - GraphRef.Conf.EmbdNormalize = - static_cast(CommonParamsDefault.embd_normalize); - GraphRef.Conf.NPredict = GraphRef.Params.n_ctx; - GraphRef.Conf.ReversePrompt = ""sv; - GraphRef.Conf.ImagePath = ""sv; - - // Set llama log callback. - llama_log_set(llamaLogCallback, &GraphRef); - - // If the graph builder length > 1, the data of builder[1] is the metadata. - if (Builders.size() > 1) { - const std::string Metadata(reinterpret_cast(Builders[1].data()), - Builders[1].size()); - // Ignore context or model updates when initializing the graph. - auto Res = parseMetadata(GraphRef, GraphRef.Conf, Metadata); - if (Res != ErrNo::Success) { - Env.deleteGraph(GId.raw()); - RET_ERROR(Res, "load: Failed to parse metadata."sv) - } - } - - // Logging. - LOG_DEBUG(GraphRef.EnableDebugLog, "load"sv) - LOG_INFO(GraphRef.EnableLog, "LLAMA_COMMIT {}"sv, LLAMA_COMMIT) - LOG_INFO(GraphRef.EnableLog, "LLAMA_BUILD_NUMBER {}"sv, LLAMA_BUILD_NUMBER) - - // Handle the model path. - LOG_DEBUG(GraphRef.EnableDebugLog, "load: handling model path."sv) - auto Weight = Builders[0]; - const std::string_view BinModel(reinterpret_cast(Weight.data()), - Weight.size()); - if (BinModel.substr(0, 8) == "preload:"sv) { - GraphRef.Params.model.path = BinModel.substr(8); - } else { - LOG_DEBUG(GraphRef.EnableDebugLog, - "load: Model path not found in nn-preload, write model into "sv - "a tmpfile."sv) - // TODO: pass the model directly to ggml. - // Write ggml model to file. - GraphRef.Params.model.path = "ggml-model.bin"sv; - WasmEdge::FStream::OFStream TempFile( - GraphRef.Params.model.path, std::ios_base::out | std::ios_base::binary, - Env.getEnv()); - if (!TempFile) { - Env.deleteGraph(GId.raw()); - RET_ERROR(ErrNo::InvalidArgument, - "load: Failed to create the temporary file. Currently, our "sv - "workaround involves creating a temporary model file named "sv - "\"ggml-model.bin\" and passing this filename as a "sv - "parameter to the ggml llama library."sv) - } - TempFile.write(BinModel.data(), BinModel.size()); - TempFile.close(); - LOG_DEBUG(GraphRef.EnableDebugLog, - "load: Write model into a tmpfile...Done"sv) - } - LOG_DEBUG(GraphRef.EnableDebugLog, "load: handling model path...Done"sv) - - // Check if the model exists. - if (!std::filesystem::exists( - std::filesystem::u8path(GraphRef.Params.model.path))) { - Env.deleteGraph(GId.raw()); - RET_ERROR(ErrNo::ModelNotFound, "load: model file not found."sv) - } - GraphRef.Params.model = GraphRef.Params.model; - - // Initialize ggml parameters. - LOG_DEBUG(GraphRef.EnableDebugLog, - "load: initialize ggml model with given parameters."sv) - - common_params Params = GraphRef.Params; - Params.cpuparams.n_threads = - static_cast(GraphRef.Params.cpuparams.n_threads); - Params.cpuparams_batch.n_threads = - static_cast(GraphRef.Params.cpuparams.n_threads); - llama_backend_init(); - llama_numa_init(Params.numa); - - // Initialize the llama model and context. - common_init_result LlamaInit = common_init_from_params(Params); - GraphRef.LlamaModel = std::move(LlamaInit.model); - GraphRef.LlamaContext = std::move(LlamaInit.context); - if (GraphRef.LlamaModel == nullptr) { - Env.deleteGraph(GId.raw()); - RET_ERROR(ErrNo::InvalidArgument, "load: unable to init model."sv) - } - if (GraphRef.LlamaContext == nullptr) { - Env.deleteGraph(GId.raw()); - RET_ERROR(ErrNo::InvalidArgument, "load: unable to init context."sv) - } - LOG_DEBUG(GraphRef.EnableDebugLog, - "load: initialize ggml model with given parameters...Done"sv) - - // Initialize the TTS related model and context. - if (GraphRef.TextToSpeech) { - LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize TTS model."sv) - Params.model = GraphRef.Params.vocoder.model; - Params.embedding = true; - common_init_result TTSInit = common_init_from_params(Params); - GraphRef.TTSModel = std::move(TTSInit.model); - GraphRef.TTSContext = std::move(TTSInit.context); - if (GraphRef.TTSModel == nullptr) { - Env.deleteGraph(GId.raw()); - RET_ERROR(ErrNo::InvalidArgument, "load: unable to init TTS model."sv) - } - if (GraphRef.TTSContext == nullptr) { - Env.deleteGraph(GId.raw()); - RET_ERROR(ErrNo::InvalidArgument, "load: unable to init TTS context."sv) - } - LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize TTS model...Done"sv) - } - - // Store the loaded graph. - GraphId = GId.le(); - Env.NNGraph[GId.raw()].setReady(); - - LOG_DEBUG(GraphRef.EnableDebugLog, "load...Done"sv) - return ErrNo::Success; -} - -Expect initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId, - uint32_t &ContextId) noexcept { - auto &GraphRef = Env.NNGraph[GraphId].get(); - LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx"sv) - ContextId = Env.newContext(GraphId, Env.NNGraph[GraphId]); - LOG_INFO(GraphRef.EnableLog, "llama_system_info: {}"sv, - llama_print_system_info()) - - auto &CxtRef = Env.NNContext[ContextId].get(); - // Allocate the batch for input string prompt tokens. - CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); - CxtRef.CurrentBatchSize = GraphRef.Params.n_batch; - - // Allocate the batch for output sampling. The batch size is always 1. - CxtRef.OutputBatch = allocBatch(1); - - // Allocate sampler. - CxtRef.LlamaSampler = - common_sampler_init(GraphRef.LlamaModel.get(), GraphRef.Params.sampling); - - Env.NNContext[ContextId].setReady(); - ContextId = EndianValue(ContextId).le(); - LOG_DEBUG(GraphRef.EnableDebugLog, "initExecCtx...Done"sv) - return ErrNo::Success; -} - -Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, - uint32_t Index, const TensorData &Tensor) noexcept { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - LOG_DEBUG(GraphRef.EnableDebugLog, "setInput"sv) - - // Use index 1 for metadata. - if (Index == 1) { - LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: found Metadata, processing"sv) - bool IsModelParamsUpdated = false; - bool IsContextParamsUpdated = false; - bool IsSamplerParamsUpdated = false; - const std::string Metadata(reinterpret_cast(Tensor.Tensor.data()), - Tensor.Tensor.size()); - auto Res = - parseMetadata(GraphRef, CxtRef.Conf, Metadata, &IsModelParamsUpdated, - &IsContextParamsUpdated, &IsSamplerParamsUpdated); - if (Res != ErrNo::Success) { - RET_ERROR(Res, "setInput: failed to parse metadata."sv) - } - -#ifndef __APPLE__ - // XXX: Due to the limitation of WASI-NN proposal, this is a workaround - // for non-macOS devices. However, if the model params is updated in - // Config stage, then, we don't encourage to use this to avoid the model - // reloading. - { - if (IsModelParamsUpdated || GraphRef.LlamaModel == nullptr) { - // The llama model may be nullptr if set_input with updated model params - // last time. Therefore besides the model params updated, we should - // reload the llama model if the model is nullptr. - LOG_INFO(GraphRef.EnableLog, - "setInput: Reload model due to parameters change."sv) - llama_model_params ModelParams = llama_model_default_params(); - ModelParams.n_gpu_layers = - static_cast(GraphRef.Params.n_gpu_layers); - GraphRef.LlamaModel.reset(); - // Due to the model change, the context and sampler should also be - // reloaded. The new context and sampler will be created in the next - // block. - GraphRef.LlamaContext.reset(); - if (CxtRef.LlamaSampler) { - // TODO: Trigger the sampler in other contexts to reallocate. - common_sampler_free(CxtRef.LlamaSampler); - CxtRef.LlamaSampler = nullptr; - } - GraphRef.LlamaModel = llama_model_ptr(llama_model_load_from_file( - GraphRef.Params.model.path.c_str(), ModelParams)); - if (GraphRef.LlamaModel == nullptr) { - Env.NNGraph[CxtRef.GraphId].setInvalid(); - RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init model."sv) - } - } - } -#endif - - // Some changes of context parameters will require the context to be - // reloaded. - if (IsContextParamsUpdated || GraphRef.LlamaContext == nullptr) { - LOG_INFO(GraphRef.EnableLog, - "setInput: Reload llama context due to parameters change."sv) - GraphRef.LlamaContext.reset(); - GraphRef.LlamaContext = llama_context_ptr(llama_init_from_model( - GraphRef.LlamaModel.get(), - common_context_params_to_llama(GraphRef.Params))); - if (GraphRef.LlamaContext == nullptr) { - Env.NNGraph[CxtRef.GraphId].setInvalid(); - RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init context."sv) - } - } - - // Some changes of sampling parameters will require the sampler to be - // reallocated. - if (IsSamplerParamsUpdated || CxtRef.LlamaSampler == nullptr) { - LOG_INFO(GraphRef.EnableLog, - "setInput: Reallocate llama sampler due to parameters change."sv) - if (CxtRef.LlamaSampler) { - common_sampler_free(CxtRef.LlamaSampler); - } - CxtRef.LlamaSampler = common_sampler_init(GraphRef.LlamaModel.get(), - GraphRef.Params.sampling); - if (GraphRef.LlamaContext == nullptr) { - Env.NNGraph[CxtRef.GraphId].setInvalid(); - RET_ERROR(ErrNo::InvalidArgument, "setInput: unable to init sampler."sv) - } - } - - // Check that is batch size changed. - if (CxtRef.CurrentBatchSize != GraphRef.Params.n_batch) { - llama_batch_free(CxtRef.LlamaBatch); - CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); - CxtRef.CurrentBatchSize = GraphRef.Params.n_batch; - } - - Env.NNGraph[CxtRef.GraphId].setReady(); - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: found Metadata, processing...Done"sv) - return ErrNo::Success; - } - - // Check the graph is valid after reloading during previous set_input. - if (!Env.NNGraph[CxtRef.GraphId].isReady()) { - RET_ERROR( - ErrNo::InvalidArgument, - "setInput: Graph is invalid. Please reload again by passing metadata "sv - "in set_input or unload graph."sv) - } - - // Clear the llama context. - LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context"sv) - llama_memory_clear(llama_get_memory(GraphRef.LlamaContext.get()), true); - LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: clear llama context...Done"sv) - - // Set the input. - const bool AddSpecial = true; - const bool ParseSpecial = true; - std::string Prompt(reinterpret_cast(Tensor.Tensor.data()), - Tensor.Tensor.size()); - CxtRef.LlamaInputs.clear(); - - auto Base64ImagePos = findBase64ImagePayload(Prompt); - - if (Base64ImagePos.has_value() || CxtRef.Conf.ImagePath != ""sv) { - // First check the projection model is given. - if (GraphRef.Params.mmproj.path == ""sv) { - RET_ERROR( - ErrNo::InvalidArgument, - "setInput: the given model does not support image input, so a projection model is required."sv) - } - - // Make sure the projection model is loaded. - if (GraphRef.VisionContext == nullptr) { - LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: initialize mtmd context."sv) - // Initialize the mtmd context. - mtmd_context_params VisionContextParams = mtmd_context_params_default(); - std::string VisionPromptImagePlaceholderStr(VisionPromptImagePlaceholder); - VisionContextParams.media_marker = - VisionPromptImagePlaceholderStr.c_str(); - VisionContextParams.use_gpu = GraphRef.Params.mmproj_use_gpu; - VisionContextParams.n_threads = GraphRef.Params.cpuparams.n_threads; - VisionContextParams.print_timings = - GraphRef.EnableLog || GraphRef.EnableDebugLog; - if (GraphRef.EnableDebugLog) { - VisionContextParams.verbosity = GGML_LOG_LEVEL_DEBUG; - } else if (GraphRef.EnableLog) { - VisionContextParams.verbosity = GGML_LOG_LEVEL_INFO; - } else { - VisionContextParams.verbosity = GGML_LOG_LEVEL_NONE; - } - GraphRef.VisionContext.reset( - mtmd_init_from_file(GraphRef.Params.mmproj.path.c_str(), - GraphRef.LlamaModel.get(), VisionContextParams)); - if (GraphRef.VisionContext == nullptr) { - RET_ERROR(ErrNo::InvalidArgument, - "setInput: unable to load the mmproj model {}."sv, - GraphRef.Params.mmproj.path) - } - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: initialize mtmd context...Done"sv) - } - - // Show some warnings for context size. - if (GraphRef.Params.n_ctx < 4096) { - LOG_INFO( - GraphRef.EnableLog, - "setInput: Context size is {}, we recommend context size >= 4096 when using multimodal models for better results"sv, - GraphRef.Params.n_ctx) - } - - // Get the image bitmaps. - // Follow this link for the supported image formats: - // https://github.com/ggml-org/llama.cpp/blob/master/common/stb_image.h - mtmd::bitmaps Bitmaps; - if (GraphRef.VisionContext != nullptr) { - if (Base64ImagePos.has_value()) { - // Load the image bitmap from the base64 image. - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: load the image bitmap from the base64 image."sv) - // Extract the payload and image type from the prompt. - std::optional, std::string>> Payload = - extractBase64ImagePayload(Prompt, *Base64ImagePos, - VisionPromptImagePlaceholder); - if (Payload.has_value()) { - // Create the new image bitmap. - mtmd::bitmap Bitmap(mtmd_helper_bitmap_init_from_buf( - GraphRef.VisionContext.get(), Payload->first.data(), - Payload->first.size())); - if (Bitmap.ptr == nullptr) { - RET_ERROR( - ErrNo::InvalidArgument, - "setInput: unable to load the image from base64 paylaod."sv) - } - Bitmaps.entries.push_back(std::move(Bitmap)); - } - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: Compute image embd from the base64 image...Done"sv) - } else { - // Load the image from the file. - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: load the image bitmap from file: {}"sv, - CxtRef.Conf.ImagePath) - mtmd::bitmap Bitmap(mtmd_helper_bitmap_init_from_file( - GraphRef.VisionContext.get(), CxtRef.Conf.ImagePath.c_str())); - if (Bitmap.ptr == nullptr) { - RET_ERROR( - ErrNo::InvalidArgument, - "setInput: unable to load the image bitmap from file: {}."sv, - CxtRef.Conf.ImagePath) - } - Bitmaps.entries.push_back(std::move(Bitmap)); - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: load the image bitmap from file: {}...Done"sv, - CxtRef.Conf.ImagePath) - } - } - - // Tokenize the prompt. - LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize the mtmd prompt"sv) - GraphRef.VisionInputChunks.reset(mtmd_input_chunks_init()); - mtmd_input_text MtmdText; - MtmdText.text = Prompt.c_str(); - MtmdText.add_special = AddSpecial; - MtmdText.parse_special = ParseSpecial; - std::vector BitmapsPtr = Bitmaps.c_ptr(); - int32_t Res = mtmd_tokenize(GraphRef.VisionContext.get(), - GraphRef.VisionInputChunks.get(), &MtmdText, - BitmapsPtr.data(), BitmapsPtr.size()); - if (Res != 0) { - RET_ERROR(ErrNo::InvalidArgument, - "setInput: unable to tokenize the mtmd prompt."sv) - } - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: tokenize the mtmd prompt...Done"sv) - - // Get the number of input tokens (for the metadata). - CxtRef.LlamaNInputs = 0; - for (size_t ChunkIndex = 0; - ChunkIndex < mtmd_input_chunks_size(GraphRef.VisionInputChunks.get()); - ++ChunkIndex) { - size_t NTokens = 0; - const mtmd_input_chunk *Chunk = - mtmd_input_chunks_get(GraphRef.VisionInputChunks.get(), ChunkIndex); - mtmd_input_chunk_get_tokens_text(Chunk, &NTokens); - CxtRef.LlamaNInputs += NTokens; - } - } else if (GraphRef.TextToSpeech == true) { - // TTS prompt. - LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt"sv) - CxtRef.LlamaInputs = processTTSPrompt(Env, GraphRef, Prompt); - if (CxtRef.LlamaInputs.empty()) { - RET_ERROR(ErrNo::InvalidArgument, - "setInput: failed to tokenize tts prompt."sv) - } - LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt...Done"sv) - - // Get the number of input tokens (for the metadata). - CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); - } else { - // Text only prompt. - LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt"sv) - CxtRef.LlamaInputs = common_tokenize(GraphRef.LlamaContext.get(), Prompt, - AddSpecial, ParseSpecial); - LOG_DEBUG(GraphRef.EnableDebugLog, - "setInput: tokenize text prompt...Done"sv) - - // Get the number of input tokens (for the metadata). - CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); - } - - // Maybe currently in the compute_single mode. Reset the computing. - CxtRef.ComputeSingleStarted = false; - - LOG_DEBUG(GraphRef.EnableDebugLog, "setInput...Done"sv) - return ErrNo::Success; -} - -Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, - uint32_t Index, Span OutBuffer, - uint32_t &BytesWritten) noexcept { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}"sv, Index) - - // Use index 1 for the metadata of the outputs. - if (Index == 1) { - std::string Metadata = buildOutputMetadata(CxtRef); - std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); - BytesWritten = static_cast(Metadata.length()); - LOG_DEBUG(GraphRef.EnableDebugLog, - "getOutput: with Index {} a.k.a Metadata ...Done"sv, Index) - return ErrNo::Success; - } - - std::copy_n(CxtRef.LlamaOutputs.data(), CxtRef.LlamaOutputs.size(), - OutBuffer.data()); - BytesWritten = - EndianValue(static_cast(CxtRef.LlamaOutputs.size())).le(); - LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}...Done"sv, Index) - return ErrNo::Success; -} - -Expect compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - LOG_DEBUG(GraphRef.EnableDebugLog, "compute") - - // Clear the context and reset the sampler. - clearContext(GraphRef, CxtRef); - - if (GraphRef.Params.embedding) { - return getEmbedding(GraphRef, CxtRef); - } - - // Evaluate the input tokens. - ErrNo ReturnCode = ErrNo::Success; - if (GraphRef.VisionContext == nullptr) { - // Text only prompt. - ReturnCode = evaluateInput(GraphRef, CxtRef, "compute"sv); - if (ReturnCode != ErrNo::Success) { - return ReturnCode; - } - } else { - // Multimodal prompt. - llama_pos NewNPos; - int32_t Res = mtmd_helper_eval_chunks( - GraphRef.VisionContext.get(), GraphRef.LlamaContext.get(), - GraphRef.VisionInputChunks.get(), CxtRef.NPos, - /* seq_id */ 0, static_cast(CxtRef.CurrentBatchSize), - /* logits_last */ true, &NewNPos); - CxtRef.NPos = NewNPos; - if (Res != 0) { - RET_ERROR(ErrNo::InvalidArgument, - "compute: unable to eval the mtmd prompt."sv) - } - } - - // Main prediction loop. - LOG_DEBUG(GraphRef.EnableDebugLog, "compute: enter main prediction loop"sv) - int64_t NPredict = - CxtRef.Conf.NPredict < 0 ? INT32_MAX : CxtRef.Conf.NPredict; - - while (NPredict-- > 0) { - ReturnCode = sampleOutput(GraphRef, CxtRef); - if (ReturnCode != ErrNo::Success) { - break; - } - } - if (ReturnCode == ErrNo::EndOfSequence) { - ReturnCode = ErrNo::Success; - } - LOG_DEBUG(GraphRef.EnableDebugLog, - "compute: enter main prediction loop...Done"sv) - // End of main prediction loop. - - // TTS: convert output codes to audio file. - if (GraphRef.TextToSpeech) { - LOG_DEBUG(GraphRef.EnableDebugLog, - "compute: convert output codes to audio file."sv) - ReturnCode = codesToSpeech(Env, GraphRef, CxtRef); - if (ReturnCode != ErrNo::Success) { - RET_ERROR(ReturnCode, - "compute: failed to convert output codes to audio "sv - "file."sv) - } - LOG_DEBUG(GraphRef.EnableDebugLog, - "compute: convert output codes to audio file...Done"sv) - } - - if (GraphRef.EnableLog) { - common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler); - } - - LOG_DEBUG(GraphRef.EnableDebugLog, "compute...Done"sv) - return ReturnCode; -} - -Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, - uint32_t Index, Span OutBuffer, - uint32_t &BytesWritten) noexcept { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: with Index {}"sv, Index) - - // Use index 1 for the metadata of the outputs. - if (Index == 1) { - std::string Metadata = buildOutputMetadata(CxtRef); - std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); - BytesWritten = static_cast(Metadata.length()); - LOG_DEBUG(GraphRef.EnableDebugLog, - "getOutputSingle: with Index {} a.k.a Metadata...Done"sv, Index) - return ErrNo::Success; - } - - std::string LastToken = common_token_to_piece( - GraphRef.LlamaContext.get(), CxtRef.LlamaOutputTokens.back()); - std::copy_n(LastToken.data(), LastToken.length(), OutBuffer.data()); - BytesWritten = EndianValue(static_cast(LastToken.length())).le(); - LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: with Index {}...Done"sv, - Index) - return ErrNo::Success; -} - -Expect computeSingle(WasiNNEnvironment &Env, - uint32_t ContextId) noexcept { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle"sv) - - // New compute single token context. - auto ReturnCode = ErrNo::Success; - if (!CxtRef.ComputeSingleStarted) { - CxtRef.ComputeSingleStarted = true; - - // Clear the context and reset the sampler. - clearContext(GraphRef, CxtRef); - - // Evaluate the input tokens. - if (GraphRef.VisionContext == nullptr) { - // Text only prompt. - ReturnCode = evaluateInput(GraphRef, CxtRef, "compute"sv); - if (ReturnCode != ErrNo::Success) { - return ReturnCode; - } - } else { - // Multimodal prompt. - llama_pos NewNPos; - int32_t Res = mtmd_helper_eval_chunks( - GraphRef.VisionContext.get(), GraphRef.LlamaContext.get(), - GraphRef.VisionInputChunks.get(), CxtRef.NPos, - /* seq_id */ 0, static_cast(CxtRef.CurrentBatchSize), - /* logits_last */ true, &NewNPos); - CxtRef.NPos = NewNPos; - if (Res != 0) { - RET_ERROR(ErrNo::InvalidArgument, - "compute: unable to eval the mtmd prompt."sv) - } - } - } - - // Main prediction process. - LOG_DEBUG(GraphRef.EnableDebugLog, - "computeSingle: enter main prediction process"sv) - ReturnCode = sampleOutput(GraphRef, CxtRef, true); - if (ReturnCode != ErrNo::Success) { - CxtRef.ComputeSingleStarted = false; - } - LOG_DEBUG(GraphRef.EnableDebugLog, - "computeSingle: enter main prediction process...Done"sv) - // End of main predict process. - - LOG_DEBUG(GraphRef.EnableDebugLog, "computeSingle...Done"sv) - return ReturnCode; -} - -Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle"sv) - - // Logging for the llama timings. - if (GraphRef.EnableLog) { - common_perf_print(GraphRef.LlamaContext.get(), CxtRef.LlamaSampler); - } - - // Clear the outputs. - LOG_DEBUG(GraphRef.EnableDebugLog, - "finiSingle: clear the previous output and tokens"sv) - CxtRef.LlamaOutputs.clear(); - CxtRef.LlamaOutputTokens.clear(); - LOG_DEBUG(GraphRef.EnableDebugLog, - "finiSingle: clear the previous output and tokens...Done"sv) - - // Reset the llama sampler. - common_sampler_reset(CxtRef.LlamaSampler); - CxtRef.ComputeSingleStarted = false; - CxtRef.NPos = 0; - - LOG_DEBUG(GraphRef.EnableDebugLog, "finiSingle...Done"sv) - return ErrNo::Success; -} - -Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { - auto &GraphRef = Env.NNGraph[GraphId].get(); - const bool IsDebugLog = GraphRef.EnableDebugLog; - LOG_DEBUG(IsDebugLog, "unload"sv) - - // TODO: Move the resource deallocation into the destructor. - if (GraphRef.LlamaModel != nullptr) { - LOG_DEBUG(IsDebugLog, "unload: free llama model"sv) - GraphRef.LlamaModel.reset(); - LOG_DEBUG(IsDebugLog, "unload: free llama model...Done"sv) - } - if (GraphRef.LlamaContext != nullptr) { - LOG_DEBUG(IsDebugLog, "unload: free llama context"sv) - GraphRef.LlamaContext.reset(); - LOG_DEBUG(IsDebugLog, "unload: free llama context...Done"sv) - } - if (GraphRef.VisionContext != nullptr) { - LOG_DEBUG(IsDebugLog, "unload: free mtmd context"sv) - GraphRef.VisionContext.reset(); - LOG_DEBUG(IsDebugLog, "unload: free mtmd context...Done"sv) - } - if (GraphRef.VisionInputChunks != nullptr) { - LOG_DEBUG(IsDebugLog, "unload: free mtmd chunks"sv) - GraphRef.VisionInputChunks.reset(); - LOG_DEBUG(IsDebugLog, "unload: free mtmd chunks...Done"sv) - } - if (GraphRef.TTSModel != nullptr) { - LOG_DEBUG(IsDebugLog, "unload: free TTS model"sv) - GraphRef.TTSModel.reset(); - LOG_DEBUG(IsDebugLog, "unload: free TTS model...Done"sv) - } - if (GraphRef.TTSContext != nullptr) { - LOG_DEBUG(IsDebugLog, "unload: free TTS context"sv) - GraphRef.TTSContext.reset(); - LOG_DEBUG(IsDebugLog, "unload: free TTS context...Done"sv) - } - if (!GraphRef.TensorBuftOverrides.empty()) { - LOG_DEBUG(IsDebugLog, "unload: free tensor buffer overrides"sv) - GraphRef.TensorBuftOverrides.clear(); - LOG_DEBUG(IsDebugLog, "unload: free tensor buffer overrides...Done"sv) - } - Env.deleteGraph(GraphId); - Env.mdRemoveById(GraphId); - - LOG_DEBUG(IsDebugLog, "unload...Done"sv) - return ErrNo::Success; -} - -Expect finalizeExecCtx(WasiNNEnvironment &Env, - uint32_t ContextId) noexcept { - auto &CxtRef = Env.NNContext[ContextId].get(); - auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - LOG_DEBUG(GraphRef.EnableDebugLog, "finalize_execution_context"sv) - - if (CxtRef.LlamaSampler != nullptr) { - LOG_DEBUG(GraphRef.EnableDebugLog, - "finalize_execution_context: free compute_single sampler"sv) - common_sampler_free(CxtRef.LlamaSampler); - CxtRef.LlamaSampler = nullptr; - LOG_DEBUG( - GraphRef.EnableDebugLog, - "finalize_execution_context: free compute_single sampler...Done"sv) - } - llama_batch_free(CxtRef.LlamaBatch); - llama_batch_free(CxtRef.OutputBatch); - Env.deleteContext(ContextId); - - LOG_DEBUG(GraphRef.EnableDebugLog, "finalize_execution_context...Done"sv) - return ErrNo::Success; -} - -#else -namespace { -Expect reportBackendNotSupported() noexcept { - spdlog::error("[WASI-NN] ggml backend is not built. use " - "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"ggml\" to build it."sv); - return ErrNo::InvalidArgument; -} -} // namespace - -Expect load(WasiNNEnvironment &, Span>, Device, - uint32_t &) noexcept { - return reportBackendNotSupported(); -} -Expect initExecCtx(WasiNNEnvironment &, uint32_t, uint32_t &) noexcept { - return reportBackendNotSupported(); -} -Expect setInput(WasiNNEnvironment &, uint32_t, uint32_t, - const TensorData &) noexcept { - return reportBackendNotSupported(); -} -Expect getOutput(WasiNNEnvironment &, uint32_t, uint32_t, Span, - uint32_t &) noexcept { - return reportBackendNotSupported(); -} -Expect compute(WasiNNEnvironment &, uint32_t) noexcept { - return reportBackendNotSupported(); -} -Expect getOutputSingle(WasiNNEnvironment &, uint32_t, uint32_t, - Span, uint32_t &) noexcept { - return reportBackendNotSupported(); -} -Expect computeSingle(WasiNNEnvironment &, uint32_t) noexcept { - return reportBackendNotSupported(); -} -Expect finiSingle(WasiNNEnvironment &, uint32_t) noexcept { - return reportBackendNotSupported(); -} -Expect unload(WasiNNEnvironment &, uint32_t) noexcept { - return reportBackendNotSupported(); -} -Expect finalizeExecCtx(WasiNNEnvironment &, uint32_t) noexcept { - return reportBackendNotSupported(); -} - -#endif -} // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index 143537d2..c83f0c17 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -3,9 +3,9 @@ #pragma once +#include "GGML/core/ggml_core.h" #include "wasinn_bitnet.h" #include "wasinn_chattts.h" -#include "wasinn_ggml.h" #include "wasinn_mlx.h" #include "wasinn_neuralspeed.h" #include "wasinn_onnx.h" From cad375398048df4dfb311d7daffdfc809ccd5e6d Mon Sep 17 00:00:00 2001 From: "Wang-Yang, Li" <7088579+LFsWang@users.noreply.github.com> Date: Fri, 7 Nov 2025 02:15:03 +0800 Subject: [PATCH 590/623] feat(WASI-NN,ggml): bump llama.cpp b6923 and adjust seed type (#4407) Signed-off-by: LFsWang --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/GGML/metadata/metadata_parser.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 2770a556..722b9d1c 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -24,7 +24,7 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) -set(WASMEDGE_WASI_NN_VERSION "0.1.32" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_VERSION "0.1.33" CACHE STRING "WasmEdge WASI-NN library version") set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") # Handle the version of the WASI-NN plugin diff --git a/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp b/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp index 587ceb13..397e6036 100644 --- a/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp +++ b/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp @@ -34,7 +34,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, double PrevPresencePenalty = GraphRef.Params.sampling.penalty_present; double PrevFrequencyPenalty = GraphRef.Params.sampling.penalty_freq; std::string PrevGrammar = GraphRef.Params.sampling.grammar; - uint64_t PrevSeed = GraphRef.Params.sampling.seed; + uint32_t PrevSeed = GraphRef.Params.sampling.seed; try { parseJsonAuto(Doc, "enable-log", GraphRef.EnableLog); @@ -268,7 +268,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, json_schema_to_grammar(nlohmann::ordered_json::parse(JsonSchema)); return true; }); - parseJsonWithCastAuto(Doc, "seed", GraphRef.Params.sampling.seed); + parseJsonWithCastAuto(Doc, "seed", GraphRef.Params.sampling.seed); // The speculative parameters. parseJsonWithCastAuto(Doc, "n-ctx-speculative", From b7af2b711de59a816d229f0f77364c7a6a866ece Mon Sep 17 00:00:00 2001 From: "Wang-Yang, Li" <7088579+LFsWang@users.noreply.github.com> Date: Thu, 20 Nov 2025 22:55:58 +0800 Subject: [PATCH 591/623] fix(wasi_nn): update WASI-NN version to 0.1.34 and llama GIT_TAG to b7090 (#4421) Signed-off-by: LFsWang --- plugins/wasi_nn/CMakeLists.txt | 2 +- plugins/wasi_nn/GGML/core/ggml_type.h | 1 + plugins/wasi_nn/GGML/core/input_processor.cpp | 7 ------- 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 722b9d1c..d9c40e83 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -24,7 +24,7 @@ wasmedge_add_library(wasmedgePluginWasiNN include(WASINNDeps) wasmedge_setup_wasinn_target(wasmedgePluginWasiNN PLUGINLIB) -set(WASMEDGE_WASI_NN_VERSION "0.1.33" CACHE STRING "WasmEdge WASI-NN library version") +set(WASMEDGE_WASI_NN_VERSION "0.1.34" CACHE STRING "WasmEdge WASI-NN library version") set(WASMEDGE_WASI_NN_SOVERSION "0" CACHE STRING "WasmEdge WASI-NN library soversion") # Handle the version of the WASI-NN plugin diff --git a/plugins/wasi_nn/GGML/core/ggml_type.h b/plugins/wasi_nn/GGML/core/ggml_type.h index 66da8786..36e6f482 100644 --- a/plugins/wasi_nn/GGML/core/ggml_type.h +++ b/plugins/wasi_nn/GGML/core/ggml_type.h @@ -8,6 +8,7 @@ #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML #include "wasinntypes.h" +#include #include #include #include diff --git a/plugins/wasi_nn/GGML/core/input_processor.cpp b/plugins/wasi_nn/GGML/core/input_processor.cpp index 0a71d1cc..fc1eaf91 100644 --- a/plugins/wasi_nn/GGML/core/input_processor.cpp +++ b/plugins/wasi_nn/GGML/core/input_processor.cpp @@ -156,13 +156,6 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, VisionContextParams.n_threads = GraphRef.Params.cpuparams.n_threads; VisionContextParams.print_timings = GraphRef.EnableLog || GraphRef.EnableDebugLog; - if (GraphRef.EnableDebugLog) { - VisionContextParams.verbosity = GGML_LOG_LEVEL_DEBUG; - } else if (GraphRef.EnableLog) { - VisionContextParams.verbosity = GGML_LOG_LEVEL_INFO; - } else { - VisionContextParams.verbosity = GGML_LOG_LEVEL_NONE; - } GraphRef.VisionContext.reset( mtmd_init_from_file(GraphRef.Params.mmproj.path.c_str(), GraphRef.LlamaModel.get(), VisionContextParams)); From 14a6caa44dfa429a118f65b5b73df71639486ca9 Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 10 Dec 2025 17:07:32 +0800 Subject: [PATCH 592/623] fix(WASI-NN,ggml): keep size_t for the type match on macOS (#4437) Signed-off-by: hydai --- plugins/wasi_nn/GGML/metadata/metadata_parser.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp b/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp index 397e6036..ac9af83f 100644 --- a/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp +++ b/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp @@ -335,14 +335,14 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, parseJsonWithCastAuto(Doc, "ppl-output-type", GraphRef.Params.ppl_output_type); parseJsonAuto(Doc, "hellaswag", GraphRef.Params.hellaswag); - parseJsonWithCastAuto(Doc, "hellaswag-tasks", - GraphRef.Params.hellaswag_tasks); + parseJsonWithCastAuto(Doc, "hellaswag-tasks", + GraphRef.Params.hellaswag_tasks); parseJsonAuto(Doc, "winogrande", GraphRef.Params.winogrande); - parseJsonWithCastAuto(Doc, "winogrande-tasks", - GraphRef.Params.winogrande_tasks); + parseJsonWithCastAuto(Doc, "winogrande-tasks", + GraphRef.Params.winogrande_tasks); parseJsonAuto(Doc, "multiple-choice", GraphRef.Params.multiple_choice); - parseJsonWithCastAuto( + parseJsonWithCastAuto( Doc, "multiple-choice-tasks", GraphRef.Params.multiple_choice_tasks); parseJsonAuto(Doc, "kl-divergence", GraphRef.Params.kl_divergence); parseJsonAuto(Doc, "usage", GraphRef.Params.usage); From 5a9657f1c0b10cca7dc5b5ca6e026719d4fc1414 Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 24 Dec 2025 14:52:01 +0800 Subject: [PATCH 593/623] fix(WASI-NN,ggml): moving the function definitions to prevent macos linker issue (#4463) Signed-off-by: hydai --- plugins/wasi_nn/GGML/tts/tts_core.cpp | 3 ++- plugins/wasi_nn/GGML/tts/tts_core.h | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/plugins/wasi_nn/GGML/tts/tts_core.cpp b/plugins/wasi_nn/GGML/tts/tts_core.cpp index 3872b5b3..78535eb2 100644 --- a/plugins/wasi_nn/GGML/tts/tts_core.cpp +++ b/plugins/wasi_nn/GGML/tts/tts_core.cpp @@ -303,6 +303,8 @@ std::string replaceNumbersWithWords(const std::string &InputText) { return Result; } +} // namespace + std::vector processTTSPrompt(WasiNNEnvironment &Env, Graph &GraphRef, std::string &Prompt) noexcept { @@ -475,6 +477,5 @@ ErrNo codesToSpeech(WasiNNEnvironment &Env, Graph &GraphRef, return ErrNo::Success; } -} // namespace #endif } // namespace WasmEdge::Host::WASINN::GGML diff --git a/plugins/wasi_nn/GGML/tts/tts_core.h b/plugins/wasi_nn/GGML/tts/tts_core.h index b5c9ae65..f8ff281c 100644 --- a/plugins/wasi_nn/GGML/tts/tts_core.h +++ b/plugins/wasi_nn/GGML/tts/tts_core.h @@ -32,15 +32,16 @@ const std::map Tens = { {2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"}, {6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}}; -std::string processTTSPromptText(const std::string &Text); -std::optional -getSpeakerProfileFromFile(const std::string &FilePath, WasiNNEnvironment &Env); - std::vector embdToAudio(const float *Embd, const int NCodes, const int NEmbd, const int NThread); std::vector audioDataToWav(const std::vector &Data, int SampleRate); } // namespace + +std::string processTTSPromptText(const std::string &Text); +std::optional +getSpeakerProfileFromFile(const std::string &FilePath, WasiNNEnvironment &Env); + std::vector processTTSPrompt(WasiNNEnvironment &Env, Graph &GraphRef, std::string &Prompt) noexcept; From 75d063b1adfc35ce365ac3b5f8d0fd18317fd2fe Mon Sep 17 00:00:00 2001 From: Sankalp Jha Date: Wed, 24 Dec 2025 21:39:11 +0530 Subject: [PATCH 594/623] feat(docker): add standalone alpine-base image build pipeline (#4466) Signed-off-by: blackdragoon26 --- utils/docker/Dockerfile.alpine-base | 62 ++++++++++++++++++++++++ utils/docker/docker-bake.alpine-base.hcl | 12 +++++ 2 files changed, 74 insertions(+) create mode 100644 utils/docker/Dockerfile.alpine-base create mode 100644 utils/docker/docker-bake.alpine-base.hcl diff --git a/utils/docker/Dockerfile.alpine-base b/utils/docker/Dockerfile.alpine-base new file mode 100644 index 00000000..c1e587ec --- /dev/null +++ b/utils/docker/Dockerfile.alpine-base @@ -0,0 +1,62 @@ +# syntax=docker/dockerfile:1.5-labs + +ARG ALPINE_VERSION=3.23 + +# ---------------------------------------------------- +# STAGE 1: Builder +# Compiles LLD from source, linking against system LLVM +# ---------------------------------------------------- +FROM alpine:${ALPINE_VERSION} AS builder + +# Use a specific tag for reproducibility +# Pinned version to match Alpine 3.23's system LLVM (v19) +ARG LLVM_VERSION=llvmorg-19.1.5 + +# Build Dependencies: +# - build-base: Compiler toolchain (gcc, g++, make, libc-dev) required for compilation. +# - cmake: The build system used by LLVM. +# - ninja: The build generator (faster than make). +# - python3: Required by LLVM build scripts and configuration utilities. +# - git: Required to clone the LLVM repository. +# - linux-headers: Ensures system header compatibility for C++ standard library builds. +# - llvm19-dev: Provides LLVM development headers for version 19. +# - llvm19-static: Provides static LLVM libraries for version 19. +# - zstd-static: Provides static zstd compression libraries. +# - zlib-static: Provides static zlib compression libraries. +RUN apk add --no-cache build-base cmake ninja python3 git linux-headers \ + llvm19-dev llvm19-static zstd-static zlib-static + + +# The system CMake files require testing libraries that are missing in the Alpine package. +# We create empty dummy archives to satisfy the dependency check. +RUN LLVM_LIB_DIR=/usr/lib/llvm19/lib && \ + for f in libLLVMTestingAnnotations.a libLLVMTestingSupport.a libllvm_gtest.a libllvm_gtest_main.a; do \ + ar rc "$LLVM_LIB_DIR/$f"; \ + done + +# Clone LLVM Project +WORKDIR /src +# Shallow clone specific tag +RUN git clone --depth 1 -b ${LLVM_VERSION} https://github.com/llvm/llvm-project.git + + +# Configure and Build LLD +# We treat LLD as a standalone project and point it to the system LLVM +WORKDIR /src/llvm-project/lld +RUN cmake -B build -G Ninja \ + -DCMAKE_BUILD_TYPE=MinSizeRel \ + -DCMAKE_INSTALL_PREFIX=/usr/local \ + -DLLVM_DIR=/usr/lib/llvm19/lib/cmake/llvm \ + -DLLVM_BUILD_STATIC=ON \ + -DLLVM_INCLUDE_TESTS=OFF \ + -DLLVM_INCLUDE_EXAMPLES=OFF && \ + cmake --build build --target lldELF lldCommon lldMachO lldCOFF lldWasm lldMinGW + +# ---------------------------------------------------- +# STAGE 2: Final Image (Contains only the compiled static libraries) +# ---------------------------------------------------- +FROM alpine:${ALPINE_VERSION} + +# Copy artifacts to a standard library path +COPY --from=builder /src/llvm-project/lld/build/lib/liblld*.a /usr/local/lib/ + diff --git a/utils/docker/docker-bake.alpine-base.hcl b/utils/docker/docker-bake.alpine-base.hcl new file mode 100644 index 00000000..e492c9ef --- /dev/null +++ b/utils/docker/docker-bake.alpine-base.hcl @@ -0,0 +1,12 @@ +group "default" { + targets = ["alpine-base"] +} + +target "alpine-base" { + dockerfile = "Dockerfile.alpine-base" + context = "utils/docker" + # Build for both major architectures at once + platforms = ["linux/amd64", "linux/arm64"] + # Tag it clearly so we can use it later + tags = ["wasmedge/wasmedge:alpine-base-3.23"] +} From ac8060492d645877f51c8ae6a13d3161a5ec3e16 Mon Sep 17 00:00:00 2001 From: Sankalp Jha Date: Fri, 26 Dec 2025 09:21:04 +0530 Subject: [PATCH 595/623] Upgrade Alpine base to 3.23 and restore static LLD support (#4440) build(docker): upgrade Alpine base image from 3.16 to 3.23 This commit upgrades the Alpine static build image to version 3.23 to resolve security issues associated with the EOL Alpine 3.16. Since Alpine 3.17+ removed the 'lld-static' package, this commit introduces a multi-stage build strategy to restore static linking support: 1. Native Builder Stage: Adds an 'lld-builder' stage to compile LLD static libraries from source (pinned to release/19.x). This uses native Alpine compilers to verify artifacts before copying them to the cross- compilation stage, avoiding sysroot header issues. 2. Dependency Updates: Adds 'zstd-static' to the build environment as it is now a required dependency for the system LLVM package. 3. Build Optimization: Mocks unused 'LLVMTesting' static libraries during configuration to satisfy CMake checks without increasing build time. Closes: #3866 Signed-off-by: blackdragoon26 --- utils/docker/Dockerfile.alpine-static | 36 +++++++++++++++++++++------ 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/utils/docker/Dockerfile.alpine-static b/utils/docker/Dockerfile.alpine-static index 36e936af..37e144df 100644 --- a/utils/docker/Dockerfile.alpine-static +++ b/utils/docker/Dockerfile.alpine-static @@ -1,10 +1,17 @@ # syntax=docker/dockerfile:1.5-labs ARG XX_VERSION=1.2.1 -ARG ALPINE_VERSION=3.16 -# alpine 3.16 ships with llvm 13. -# alpine 3.17 and 3.18 do not ship lld-static. +ARG ALPINE_VERSION=3.23 + +# ---------------------------------------------------- +# STAGE 1: Import Pre-Built LLD Libraries +# ---------------------------------------------------- +FROM wasmedge/wasmedge:alpine-base-3.23 AS lld-artifact + +# ---------------------------------------------------- +# STAGE 2: Main WasmEdge Build +# ---------------------------------------------------- FROM --platform=$BUILDPLATFORM tonistiigi/xx:${XX_VERSION} AS xx FROM --platform=$BUILDPLATFORM alpine:${ALPINE_VERSION} AS base COPY --from=xx / / @@ -41,16 +48,29 @@ FROM base AS config ARG TARGETPLATFORM RUN apk add git llvm-dev +# We add zstd-static because the SYSTEM LLVM (installed via apk) depends on it. RUN xx-apk add \ g++ \ llvm-dev llvm-static \ - lld lld-dev lld-static \ - zlib-dev zlib-static + lld lld-dev \ + zlib-dev zlib-static \ + zstd-static # Fix LLVMConfig.cmake to account for the sysroot RUN sed -i 's|/usr/lib/llvm|${CMAKE_SYSROOT}usr/lib/llvm|' $(xx-info sysroot)usr/lib/cmake/llvm*/LLVMConfig.cmake -# In cmake/Helper.cmake we assume that lld is installed alongside llvm, so copy files over -RUN cp $(xx-info sysroot)usr/lib/liblld*.a $(xx-info sysroot)usr/lib/llvm*/lib/ + +# Patch: Create dummy files for missing testing libraries to satisfy CMake +RUN for d in $(xx-info sysroot)usr/lib/llvm*/lib; do \ + for f in libLLVMTestingAnnotations.a libLLVMTestingSupport.a libllvm_gtest.a libllvm_gtest_main.a; do \ + ar rc "$d/$f"; \ + done \ + done + +# Inject the Native LLD libraries from Stage 1 (The Base Image) +# Note: We changed the source path to /usr/local/lib because that's where we put them in the base image +COPY --from=lld-artifact /usr/local/lib/liblld*.a /tmp/lld-libs/ +RUN cp /tmp/lld-libs/*.a $(xx-info sysroot)usr/lib/llvm*/lib/ + # Use the native version of llvm-config RUN ! xx-info is-cross || cp -f /usr/lib/llvm*/bin/llvm-config $(xx-info sysroot)usr/lib/llvm*/bin/llvm-config @@ -70,7 +90,7 @@ RUN --mount=type=bind,target=/src,source=. \ -DWASMEDGE_BUILD_TOOLS=OFF \ -DWASMEDGE_BUILD_PLUGINS=OFF \ -DWASMEDGE_BUILD_EXAMPLE=OFF \ - # Link llvm statically + # link llvm statically -DWASMEDGE_LINK_LLVM_STATIC=ON \ -DWASMEDGE_LINK_TOOLS_STATIC=ON \ # Disable extra dependencies From 230c4948989c1942d66b728a96226c5e000c8c6d Mon Sep 17 00:00:00 2001 From: Divyansh Khatri Date: Fri, 9 Jan 2026 09:24:56 +0000 Subject: [PATCH 596/623] feat(docker): add libpiper to plugin-deps image Signed-off-by: Divyansh Khatri --- .../Dockerfile.manylinux_2_28-plugins-deps | 3 ++ utils/docker/Dockerfile.ubuntu-base | 1 + utils/docker/Dockerfile.ubuntu-plugins-deps | 6 +++- utils/wasi-nn/install-libpiper.sh | 34 +++++++++++++++++++ 4 files changed, 43 insertions(+), 1 deletion(-) create mode 100755 utils/wasi-nn/install-libpiper.sh diff --git a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps index 3249c688..e5c74d27 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps @@ -41,4 +41,7 @@ ENV OPENVINO_YEAR="2025" COPY wasi-nn/install-onnxruntime.sh . RUN [ "/bin/bash", "install-onnxruntime.sh" ] +COPY wasi-nn/install-libpiper.sh . +RUN [ "/bin/bash", "install-libpiper.sh" ] + RUN yum clean all diff --git a/utils/docker/Dockerfile.ubuntu-base b/utils/docker/Dockerfile.ubuntu-base index 8ef2aa42..4d4cb463 100644 --- a/utils/docker/Dockerfile.ubuntu-base +++ b/utils/docker/Dockerfile.ubuntu-base @@ -40,6 +40,7 @@ ENV CXX=/usr/bin/clang++-18 ### deps for ubuntu 22.04 ### FROM base AS deps-22 +RUN curl -sSf https://apt.kitware.com/kitware-archive.sh | sh RUN apt-get install -y cmake RUN apt-get install -y \ diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index 460b0170..86c5749e 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -75,6 +75,9 @@ RUN [ "/bin/bash", "-c", "echo \"source ./openvino_genai/setupvars.sh\" >> .bash COPY wasi-nn/install-onnxruntime.sh . RUN [ "/bin/bash", "install-onnxruntime.sh" ] +COPY wasi-nn/install-libpiper.sh . +RUN [ "/bin/bash", "install-libpiper.sh" ] + COPY wasi-nn/install-chattts.sh . RUN [ "/bin/bash", "install-chattts.sh" ] @@ -88,6 +91,7 @@ RUN rm -f \ install-openvino.sh \ install-onnxruntime.sh \ install-openvino-genai.sh \ - install-chattts.sh + install-chattts.sh \ + install-libpiper.sh RUN rm -rf /var/lib/apt/lists/* diff --git a/utils/wasi-nn/install-libpiper.sh b/utils/wasi-nn/install-libpiper.sh new file mode 100755 index 00000000..b31cc7ef --- /dev/null +++ b/utils/wasi-nn/install-libpiper.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 2019-2026 Second State INC + +set -e + +PIPER_REPO="https://github.com/OHF-Voice/piper1-gpl.git" +PIPER_COMMIT="32b95f8c1f0dc0ce27a6acd1143de331f61af777" +PIPER_INSTALL_TO="/usr/local" + +case "$(uname -m)" in + 'x86_64') ;; + 'aarch64') ;; + *) + echo "Unsupported architecture for libpiper: $(uname -m)" >&2 + exit 1 + ;; +esac + +git clone --depth 1 "${PIPER_REPO}" piper-source +cd piper-source +git fetch --depth 1 origin "${PIPER_COMMIT}" +git checkout FETCH_HEAD + +cd libpiper + +cmake -Bbuild -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX="${PIPER_INSTALL_TO}" +cmake --build build +cmake --install build + +cd ../.. +rm -rf piper-source + +ldconfig From 383e49f2754933f3e841c8b987fa5c6ae9e93af5 Mon Sep 17 00:00:00 2001 From: Vishal Malyan <146833908+vishal2005025@users.noreply.github.com> Date: Mon, 12 Jan 2026 18:09:23 +0530 Subject: [PATCH 597/623] fix(utils/ffmpeg): remove duplicate directory creation check (#4484) Signed-off-by: vishal2005025 --- utils/ffmpeg/download-ffmpeg-sample-video.sh | 3 --- 1 file changed, 3 deletions(-) diff --git a/utils/ffmpeg/download-ffmpeg-sample-video.sh b/utils/ffmpeg/download-ffmpeg-sample-video.sh index 5b05b080..151642b9 100644 --- a/utils/ffmpeg/download-ffmpeg-sample-video.sh +++ b/utils/ffmpeg/download-ffmpeg-sample-video.sh @@ -12,9 +12,6 @@ fi if [ ! -d $TODIR ]; then mkdir $TODIR fi -if [ ! -d $TODIR ]; then - mkdir $TODIR -fi if [ ! -f $TODIR/sample_video.mp4 ]; then curl -sL $SAMPLE_VIDEO -o $TODIR/sample_video.mp4 From 7c37b077ef574da98fc86df532435e5a6df0037f Mon Sep 17 00:00:00 2001 From: Vishal Malyan <146833908+vishal2005025@users.noreply.github.com> Date: Mon, 12 Jan 2026 20:44:12 +0530 Subject: [PATCH 598/623] fix(plugin/zlib): standardize missing memory instance handling (#4506) Signed-off-by: vishal2005025 --- plugins/wasmedge_zlib/zlibfunc.cpp | 197 +++++------------------------ 1 file changed, 34 insertions(+), 163 deletions(-) diff --git a/plugins/wasmedge_zlib/zlibfunc.cpp b/plugins/wasmedge_zlib/zlibfunc.cpp index e2f7d3ef..73752350 100644 --- a/plugins/wasmedge_zlib/zlibfunc.cpp +++ b/plugins/wasmedge_zlib/zlibfunc.cpp @@ -8,6 +8,13 @@ namespace WasmEdge { namespace Host { +#define MEMINST_CHECK(Out, CallFrame, Index) \ + auto *Out = CallFrame.getMemoryByIndex(Index); \ + if (unlikely(Out == nullptr)) { \ + spdlog::error("[WasmEdge-Zlib] Memory instance not found."sv); \ + return Unexpect(ErrCode::Value::HostFuncError); \ + } + constexpr bool CheckSize(int32_t StreamSize) { return (StreamSize == static_cast(sizeof(WasmZStream))); @@ -20,13 +27,7 @@ auto SyncRun(const std::string_view &Msg, WasmEdgeZlibEnvironment &Env, uint32_t ZStreamPtr, const Runtime::CallingFrame &Frame, T Callback) -> Expect { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [{}-SyncRun] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv, - Msg); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) WasmZStream *ModuleZStream = MemInst->getPointer(ZStreamPtr); const auto HostZStreamIt = Env.ZStreamMap.find(ZStreamPtr); @@ -264,12 +265,7 @@ Expect WasmEdgeZlibDeflateSetDictionary::body( const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, uint32_t DictionaryPtr, uint32_t DictLength) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflateSetDictionary] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) const auto *Dictionary = MemInst->getPointer(DictionaryPtr); @@ -284,12 +280,7 @@ Expect WasmEdgeZlibDeflateGetDictionary::body( const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, uint32_t DictionaryPtr, uint32_t DictLengthPtr) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflateGetDictionary] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Dictionary = MemInst->getPointer(DictionaryPtr); auto *DictLength = MemInst->getPointer(DictLengthPtr); @@ -386,12 +377,7 @@ WasmEdgeZlibDeflatePending::body(const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, uint32_t PendingPtr, uint32_t BitsPtr) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflatePending] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Pending = MemInst->getPointer(PendingPtr); auto *Bits = MemInst->getPointer(BitsPtr); @@ -468,12 +454,7 @@ Expect WasmEdgeZlibInflateSetDictionary::body( const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, uint32_t DictionaryPtr, uint32_t DictLength) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateSetDictionary] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Dictionary = MemInst->getPointer(DictionaryPtr); @@ -488,12 +469,7 @@ Expect WasmEdgeZlibInflateGetDictionary::body( const Runtime::CallingFrame &Frame, uint32_t ZStreamPtr, uint32_t DictionaryPtr, uint32_t DictLengthPtr) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateGetDictionary] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Dictionary = MemInst->getPointer(DictionaryPtr); auto *DictLength = MemInst->getPointer(DictLengthPtr); @@ -623,12 +599,7 @@ WasmEdgeZlibInflateBackInit::body(const Runtime::CallingFrame &Frame, HostZStream->opaque = Z_NULL; // ignore opaque since zmalloc and zfree was ignored - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateBackInit] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Window = MemInst->getPointer(WindowPtr); @@ -671,12 +642,7 @@ Expect WasmEdgeZlibCompress::body(const Runtime::CallingFrame &Frame, uint32_t DestLenPtr, uint32_t SourcePtr, uint32_t SourceLen) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibCompress] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Dest = MemInst->getPointer(DestPtr); auto *DestLen = MemInst->getPointer(DestLenPtr); @@ -695,12 +661,7 @@ Expect WasmEdgeZlibCompress2::body(const Runtime::CallingFrame &Frame, uint32_t DestLenPtr, uint32_t SourcePtr, uint32_t SourceLen, int32_t Level) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibCompress2] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Dest = MemInst->getPointer(DestPtr); auto *DestLen = MemInst->getPointer(DestLenPtr); @@ -724,12 +685,7 @@ Expect WasmEdgeZlibUncompress::body(const Runtime::CallingFrame &Frame, uint32_t DestLenPtr, uint32_t SourcePtr, uint32_t SourceLen) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibUncompress] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Dest = MemInst->getPointer(DestPtr); auto *DestLen = MemInst->getPointer(DestLenPtr); @@ -747,12 +703,7 @@ Expect WasmEdgeZlibUncompress2::body(const Runtime::CallingFrame &Frame, uint32_t DestPtr, uint32_t DestLenPtr, uint32_t SourcePtr, uint32_t SourceLenPtr) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibUncompress2] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Dest = MemInst->getPointer(DestPtr); auto *DestLen = MemInst->getPointer(DestLenPtr); @@ -771,12 +722,7 @@ WasmEdgeZlibUncompress2::body(const Runtime::CallingFrame &Frame, Expect WasmEdgeZlibGZOpen::body(const Runtime::CallingFrame &Frame, uint32_t PathPtr, uint32_t ModePtr) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZOpen] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Path = MemInst->getPointer(PathPtr); auto *Mode = MemInst->getPointer(ModePtr); @@ -795,12 +741,7 @@ Expect WasmEdgeZlibGZOpen::body(const Runtime::CallingFrame &Frame, Expect WasmEdgeZlibGZDOpen::body(const Runtime::CallingFrame &Frame, int32_t FD, uint32_t ModePtr) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZDOpen] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Mode = MemInst->getPointer(ModePtr); @@ -851,12 +792,7 @@ Expect WasmEdgeZlibGZRead::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZRead] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Buf = MemInst->getPointer(BufPtr); @@ -873,12 +809,7 @@ Expect WasmEdgeZlibGZFread::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZFread] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Buf = MemInst->getPointer(BufPtr); @@ -895,12 +826,7 @@ Expect WasmEdgeZlibGZWrite::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZWrite] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Buf = MemInst->getPointer(BufPtr); @@ -917,12 +843,7 @@ Expect WasmEdgeZlibGZFwrite::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZFwrite] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Buf = MemInst->getPointer(BufPtr); @@ -938,12 +859,7 @@ Expect WasmEdgeZlibGZPuts::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibGZPuts] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *String = MemInst->getPointer(StringPtr); @@ -1136,12 +1052,7 @@ Expect WasmEdgeZlibGZClearerr::body(const Runtime::CallingFrame &, Expect WasmEdgeZlibAdler32::body(const Runtime::CallingFrame &Frame, uint32_t Adler, uint32_t BufPtr, uint32_t Len) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibAdler32] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Buf = MemInst->getPointer(BufPtr); @@ -1151,12 +1062,7 @@ Expect WasmEdgeZlibAdler32::body(const Runtime::CallingFrame &Frame, Expect WasmEdgeZlibAdler32_z::body(const Runtime::CallingFrame &Frame, uint32_t Adler, uint32_t BufPtr, uint32_t Len) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibAdler32_z] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Buf = MemInst->getPointer(BufPtr); @@ -1173,12 +1079,7 @@ Expect WasmEdgeZlibAdler32Combine::body(const Runtime::CallingFrame &, Expect WasmEdgeZlibCRC32::body(const Runtime::CallingFrame &Frame, uint32_t CRC, uint32_t BufPtr, uint32_t Len) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibCRC32] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Buf = MemInst->getPointer(BufPtr); @@ -1188,12 +1089,7 @@ Expect WasmEdgeZlibCRC32::body(const Runtime::CallingFrame &Frame, Expect WasmEdgeZlibCRC32_z::body(const Runtime::CallingFrame &Frame, uint32_t CRC, uint32_t BufPtr, uint32_t Len) { - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibCRC32_z] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) auto *Buf = MemInst->getPointer(BufPtr); @@ -1213,12 +1109,7 @@ WasmEdgeZlibDeflateInit_::body(const Runtime::CallingFrame &Frame, if (!CheckSize(StreamSize)) return static_cast(Z_VERSION_ERROR); - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflateInit_] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); auto HostZStream = std::make_unique(); @@ -1253,12 +1144,7 @@ WasmEdgeZlibInflateInit_::body(const Runtime::CallingFrame &Frame, if (!CheckSize(StreamSize)) return static_cast(Z_VERSION_ERROR); - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateInit_] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); auto HostZStream = std::make_unique(); @@ -1292,12 +1178,7 @@ Expect WasmEdgeZlibDeflateInit2_::body( if (!CheckSize(StreamSize)) return static_cast(Z_VERSION_ERROR); - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibDeflateInit2_] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); auto HostZStream = std::make_unique(); @@ -1330,12 +1211,7 @@ WasmEdgeZlibInflateInit2_::body(const Runtime::CallingFrame &Frame, if (!CheckSize(StreamSize)) return static_cast(Z_VERSION_ERROR); - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateInit2_] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); auto HostZStream = std::make_unique(); @@ -1367,12 +1243,7 @@ Expect WasmEdgeZlibInflateBackInit_::body( if (!CheckSize(StreamSize)) return static_cast(Z_VERSION_ERROR); - auto *MemInst = Frame.getMemoryByIndex(0); - if (MemInst == nullptr) { - spdlog::error("[WasmEdge-Zlib] [WasmEdgeZlibInflateBackInit_] "sv - "Frame.getMemoryByIndex(0) returned nullptr."sv); - return Unexpect(ErrCode::Value::HostFuncError); - } + MEMINST_CHECK(MemInst, Frame, 0) const auto *WasmZlibVersion = MemInst->getPointer(VersionPtr); auto *Window = MemInst->getPointer(WindowPtr); From 12bfc1ae36d82e2df11dc54daa5d73f1f522dfa1 Mon Sep 17 00:00:00 2001 From: Vishal Malyan <146833908+vishal2005025@users.noreply.github.com> Date: Thu, 15 Jan 2026 17:59:12 +0530 Subject: [PATCH 599/623] fix(plugins/ffmpeg): clamp guest-provided lengths when copying FFmpeg C strings (#4478) fix(ffmpeg): clamp guest-provided lengths when copying FFmpeg C strings Signed-off-by: vishal2005025 --- .../wasmedge_ffmpeg/avcodec/avcodec_func.cpp | 8 +++++-- .../avdevice/avDevice_func.cpp | 8 +++++-- .../avfilter/avfilter_func.cpp | 12 +++++++--- .../avformat/avformat_func.cpp | 8 +++++-- .../wasmedge_ffmpeg/avutil/avutil_func.cpp | 8 +++++-- plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp | 23 +++++++++++++++---- .../swresample/swresample_func.cpp | 8 +++++-- .../wasmedge_ffmpeg/swscale/swscale_func.cpp | 8 +++++-- 8 files changed, 63 insertions(+), 20 deletions(-) diff --git a/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp index 5158f4b3..e7fb71bf 100644 --- a/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp +++ b/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp @@ -291,7 +291,9 @@ Expect AVCodecConfiguration::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); const char *Config = avcodec_configuration(); - std::copy_n(Config, ConfigLen, ConfigBuf.data()); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); return static_cast(ErrNo::Success); } @@ -306,7 +308,9 @@ Expect AVCodecLicense::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); const char *License = avcodec_license(); - std::copy_n(License, LicenseLen, LicenseBuf.data()); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); return static_cast(ErrNo::Success); } diff --git a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp index 0e19d8aa..926c618f 100644 --- a/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp +++ b/plugins/wasmedge_ffmpeg/avdevice/avDevice_func.cpp @@ -94,7 +94,9 @@ Expect AVDeviceConfiguration::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); const char *Config = avdevice_configuration(); - std::copy_n(Config, ConfigLen, ConfigBuf.data()); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); return static_cast(ErrNo::Success); } @@ -110,7 +112,9 @@ Expect AVDeviceLicense::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); const char *License = avdevice_license(); - std::copy_n(License, LicenseLen, LicenseBuf.data()); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); return static_cast(ErrNo::Success); } diff --git a/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp index 73ad99ca..8c46d501 100644 --- a/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp +++ b/plugins/wasmedge_ffmpeg/avfilter/avfilter_func.cpp @@ -132,7 +132,9 @@ Expect AVFilterConfiguration::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); const char *Config = avfilter_configuration(); - std::copy_n(Config, ConfigLen, ConfigBuf.data()); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); return static_cast(ErrNo::Success); } @@ -148,7 +150,9 @@ Expect AVFilterLicense::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); const char *License = avfilter_license(); - std::copy_n(License, LicenseLen, LicenseBuf.data()); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); return static_cast(ErrNo::Success); } @@ -212,7 +216,9 @@ Expect AVFilterPadGetName::body(const Runtime::CallingFrame &Frame, FFMPEG_PTR_FETCH(FilterPad, FilterPadId, AVFilterPad); const char *Name = avfilter_pad_get_name(FilterPad, Idx); - std::copy_n(Name, NameLen, NameBuf.data()); + auto Actual = std::strlen(Name); + auto N = std::min(NameLen, static_cast(Actual + 1)); + std::copy_n(Name, N, NameBuf.data()); return static_cast(ErrNo::Success); } diff --git a/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp b/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp index 35637087..331235d1 100644 --- a/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp +++ b/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp @@ -337,7 +337,9 @@ Expect AVFormatConfiguration::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); const char *Config = avformat_configuration(); - std::copy_n(Config, ConfigLen, ConfigBuf.data()); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); return static_cast(ErrNo::Success); } @@ -353,7 +355,9 @@ Expect AVFormatLicense::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); const char *License = avformat_license(); - std::copy_n(License, LicenseLen, LicenseBuf.data()); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); return static_cast(ErrNo::Success); } diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp b/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp index 673a3e77..da812afc 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/avutil_func.cpp @@ -130,7 +130,9 @@ Expect AVUtilConfiguration::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); const char *Config = avutil_configuration(); - std::copy_n(Config, ConfigLen, ConfigBuf.data()); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); return static_cast(ErrNo::Success); } @@ -145,7 +147,9 @@ Expect AVUtilLicense::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); const char *License = avutil_license(); - std::copy_n(License, LicenseLen, LicenseBuf.data()); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); return static_cast(ErrNo::Success); } diff --git a/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp b/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp index 639a9241..8886459f 100644 --- a/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/pixfmt.cpp @@ -57,7 +57,9 @@ Expect AVColorRangeName::body(const Runtime::CallingFrame &Frame, AVColorRange const ColorRange = static_cast(RangeId); const char *RangeName = av_color_range_name(ColorRange); - std::copy_n(RangeName, RangeLength, RangeNameBuf.data()); + auto Actual = std::strlen(RangeName); + auto N = std::min(RangeLength, static_cast(Actual + 1)); + std::copy_n(RangeName, N, RangeNameBuf.data()); return static_cast(ErrNo::Success); } @@ -80,7 +82,10 @@ Expect AVColorTransferName::body(const Runtime::CallingFrame &Frame, AVColorTransferCharacteristic const Characteristic = static_cast(TransferId); const char *TransferName = av_color_transfer_name(Characteristic); - std::copy_n(TransferName, TransferLength, TransferNameBuf.data()); + auto Actual = std::strlen(TransferName); + auto N = + std::min(TransferLength, static_cast(Actual + 1)); + std::copy_n(TransferName, N, TransferNameBuf.data()); return static_cast(ErrNo::Success); } @@ -101,7 +106,9 @@ Expect AVColorSpaceName::body(const Runtime::CallingFrame &Frame, AVColorSpace const ColorSpace = static_cast(ColorSpaceId); const char *ColorSpaceName = av_color_space_name(ColorSpace); - std::copy_n(ColorSpaceName, ColorSpaceLen, ColorSpaceBuf.data()); + auto Actual = std::strlen(ColorSpaceName); + auto N = std::min(ColorSpaceLen, static_cast(Actual + 1)); + std::copy_n(ColorSpaceName, N, ColorSpaceBuf.data()); return static_cast(ErrNo::Success); } @@ -124,7 +131,10 @@ Expect AVColorPrimariesName::body(const Runtime::CallingFrame &Frame, AVColorPrimaries const ColorPrimaries = FFmpegUtils::ColorPrimaries::intoAVColorPrimaries(ColorPrimariesId); const char *PrimariesName = av_color_primaries_name(ColorPrimaries); - std::copy_n(PrimariesName, ColorPrimariesLen, ColorPrimariesBuf.data()); + auto Actual = std::strlen(PrimariesName); + auto N = + std::min(ColorPrimariesLen, static_cast(Actual + 1)); + std::copy_n(PrimariesName, N, ColorPrimariesBuf.data()); return static_cast(ErrNo::Success); } @@ -149,7 +159,10 @@ Expect AVPixelFormatName::body(const Runtime::CallingFrame &Frame, FFmpegUtils::PixFmt::intoAVPixFmt(PixFormatId); const AVPixFmtDescriptor *PixFmtDescriptor = av_pix_fmt_desc_get(PixFormat); const char *PixFormatName = PixFmtDescriptor->name; - std::copy_n(PixFormatName, PixFormatNameLen, PixFormatBuf.data()); + auto Actual = std::strlen(PixFormatName); + auto N = + std::min(PixFormatNameLen, static_cast(Actual + 1)); + std::copy_n(PixFormatName, N, PixFormatBuf.data()); return static_cast(ErrNo::Success); } diff --git a/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp b/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp index f26e33ca..07244499 100644 --- a/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp +++ b/plugins/wasmedge_ffmpeg/swresample/swresample_func.cpp @@ -108,7 +108,9 @@ SWResampleConfiguration::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); const char *Config = swresample_configuration(); - std::copy_n(Config, ConfigLen, ConfigBuf.data()); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); return static_cast(ErrNo::Success); } @@ -124,7 +126,9 @@ Expect SWResampleLicense::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); const char *License = swresample_license(); - std::copy_n(License, LicenseLen, LicenseBuf.data()); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); return static_cast(ErrNo::Success); } diff --git a/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp b/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp index 5da058cc..90107b8b 100644 --- a/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp +++ b/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp @@ -259,7 +259,9 @@ Expect SwscaleConfiguration::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(ConfigBuf, MemInst, char, ConfigPtr, ConfigLen, ""); const char *Config = swscale_configuration(); - std::copy_n(Config, ConfigLen, ConfigBuf.data()); + auto Actual = std::strlen(Config); + auto N = std::min(ConfigLen, static_cast(Actual + 1)); + std::copy_n(Config, N, ConfigBuf.data()); return static_cast(ErrNo::Success); } @@ -274,7 +276,9 @@ Expect SwscaleLicense::body(const Runtime::CallingFrame &Frame, MEM_SPAN_CHECK(LicenseBuf, MemInst, char, LicensePtr, LicenseLen, ""); const char *License = swscale_license(); - std::copy_n(License, LicenseLen, LicenseBuf.data()); + auto Actual = std::strlen(License); + auto N = std::min(LicenseLen, static_cast(Actual + 1)); + std::copy_n(License, N, LicenseBuf.data()); return static_cast(ErrNo::Success); } From 1fa657b6339f2b3b124e6a085357693061fd7cfb Mon Sep 17 00:00:00 2001 From: Parship Chowdhury Date: Mon, 19 Jan 2026 14:58:43 +0530 Subject: [PATCH 600/623] fix(WASI-NN/bitnet): out-of-bounds array access in unload() function (#4531) fix: out-of-bounds array access in unload() function Signed-off-by: Parship Chowdhury --- plugins/wasi_nn/wasinn_bitnet.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/wasinn_bitnet.cpp b/plugins/wasi_nn/wasinn_bitnet.cpp index 867fbea0..e184def0 100644 --- a/plugins/wasi_nn/wasinn_bitnet.cpp +++ b/plugins/wasi_nn/wasinn_bitnet.cpp @@ -2356,8 +2356,11 @@ Expect finiSingle(WasiNNEnvironment &Env, uint32_t ContextId) noexcept { Expect unload(WasiNNEnvironment &Env, uint32_t GraphId) noexcept { + if (GraphId >= Env.NNGraph.size()) { + return ErrNo::Success; + } auto &GraphRef = Env.NNGraph[GraphId].get(); - if (GraphId >= Env.NNGraph.size() || GraphRef.LlamaModel == nullptr) { + if (GraphRef.LlamaModel == nullptr) { return ErrNo::Success; } From 599cdf85fd3de9c2826441e2d32aef83aeb7c58c Mon Sep 17 00:00:00 2001 From: Divyansh Khatri <146909065+Divyansh200102@users.noreply.github.com> Date: Mon, 19 Jan 2026 21:50:58 +0000 Subject: [PATCH 601/623] fix(docker): build libpiper statically in install script (#4535) Signed-off-by: Divyansh Khatri --- utils/wasi-nn/install-libpiper.sh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/utils/wasi-nn/install-libpiper.sh b/utils/wasi-nn/install-libpiper.sh index b31cc7ef..d50fd653 100755 --- a/utils/wasi-nn/install-libpiper.sh +++ b/utils/wasi-nn/install-libpiper.sh @@ -24,7 +24,11 @@ git checkout FETCH_HEAD cd libpiper -cmake -Bbuild -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX="${PIPER_INSTALL_TO}" +cmake -Bbuild \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX="${PIPER_INSTALL_TO}" \ + -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON cmake --build build cmake --install build From 6ad7b039141d43e2f9eca92519a026ddb6ad948b Mon Sep 17 00:00:00 2001 From: Divyansh Khatri <146909065+Divyansh200102@users.noreply.github.com> Date: Mon, 26 Jan 2026 11:06:18 +0000 Subject: [PATCH 602/623] refactor(WASI-Crypto): deduplicate factory functions and OpenSSL BIO helpers (#4545) * refactor: use template for wasi_crypto modules * Replace repetitive manual factory functions with a generic createModule template. * Reduce boilerplate code in ctx.cpp. * Simplify the registration logic. * refactor: cleanup includes and deduplicate OpenSSL helpers * plugins/wasi_crypto/common/array_output.h: Remove unused headers , , and . * plugins/wasi_crypto/common/array_output.cpp: Remove unused header . * plugins/wasi_crypto/utils/evp_wrapper.cpp: Deduplicate repetitive BIO setup logic using new helper functions createBioFromSpan and writeKeyToBio. --------- Signed-off-by: Divyansh Khatri --- plugins/wasi_crypto/common/array_output.cpp | 1 - plugins/wasi_crypto/common/array_output.h | 3 - plugins/wasi_crypto/ctx.cpp | 32 ++--- plugins/wasi_crypto/utils/evp_wrapper.cpp | 132 +++++++------------- 4 files changed, 53 insertions(+), 115 deletions(-) diff --git a/plugins/wasi_crypto/common/array_output.cpp b/plugins/wasi_crypto/common/array_output.cpp index 8f25327e..2085201c 100644 --- a/plugins/wasi_crypto/common/array_output.cpp +++ b/plugins/wasi_crypto/common/array_output.cpp @@ -4,7 +4,6 @@ #include "common/array_output.h" #include -#include namespace WasmEdge { namespace Host { diff --git a/plugins/wasi_crypto/common/array_output.h b/plugins/wasi_crypto/common/array_output.h index f5fef5c8..1c00b752 100644 --- a/plugins/wasi_crypto/common/array_output.h +++ b/plugins/wasi_crypto/common/array_output.h @@ -19,10 +19,7 @@ #include "common/span.h" -#include -#include #include -#include #include #include diff --git a/plugins/wasi_crypto/ctx.cpp b/plugins/wasi_crypto/ctx.cpp index 7c7d6811..fd5f4233 100644 --- a/plugins/wasi_crypto/ctx.cpp +++ b/plugins/wasi_crypto/ctx.cpp @@ -13,26 +13,10 @@ namespace Host { namespace { -Runtime::Instance::ModuleInstance *createAsymmetricCommon( - const Plugin::PluginModule::ModuleDescriptor *) noexcept { - return new WasiCryptoAsymmetricCommonModule( - WasiCrypto::Context::getInstance()); -} -Runtime::Instance::ModuleInstance * -createCommon(const Plugin::PluginModule::ModuleDescriptor *) noexcept { - return new WasiCryptoCommonModule(WasiCrypto::Context::getInstance()); -} -Runtime::Instance::ModuleInstance * -createKx(const Plugin::PluginModule::ModuleDescriptor *) noexcept { - return new WasiCryptoKxModule(WasiCrypto::Context::getInstance()); -} -Runtime::Instance::ModuleInstance * -createSignatures(const Plugin::PluginModule::ModuleDescriptor *) noexcept { - return new WasiCryptoSignaturesModule(WasiCrypto::Context::getInstance()); -} +template Runtime::Instance::ModuleInstance * -createSymmetric(const Plugin::PluginModule::ModuleDescriptor *) noexcept { - return new WasiCryptoSymmetricModule(WasiCrypto::Context::getInstance()); +createModule(const Plugin::PluginModule::ModuleDescriptor *) noexcept { + return new T(WasiCrypto::Context::getInstance()); } Plugin::Plugin::PluginDescriptor Descriptor{ @@ -46,27 +30,27 @@ Plugin::Plugin::PluginDescriptor Descriptor{ { .Name = "wasi_crypto_asymmetric_common", .Description = "", - .Create = createAsymmetricCommon, + .Create = createModule, }, { .Name = "wasi_crypto_common", .Description = "", - .Create = createCommon, + .Create = createModule, }, { .Name = "wasi_crypto_kx", .Description = "", - .Create = createKx, + .Create = createModule, }, { .Name = "wasi_crypto_signatures", .Description = "", - .Create = createSignatures, + .Create = createModule, }, { .Name = "wasi_crypto_symmetric", .Description = "", - .Create = createSymmetric, + .Create = createModule, }, }, .AddOptions = nullptr, diff --git a/plugins/wasi_crypto/utils/evp_wrapper.cpp b/plugins/wasi_crypto/utils/evp_wrapper.cpp index a28f9665..84842220 100644 --- a/plugins/wasi_crypto/utils/evp_wrapper.cpp +++ b/plugins/wasi_crypto/utils/evp_wrapper.cpp @@ -13,133 +13,91 @@ namespace WasmEdge { namespace Host { namespace WasiCrypto { -EVP_PKEY *pemReadPUBKEY(Span Encoded) { - BioPtr Bio{BIO_new(BIO_s_mem())}; +namespace { - if (size_t Size; - BIO_write_ex(Bio.get(), Encoded.data(), Encoded.size(), &Size)) { - if (Size != Encoded.size()) { - return nullptr; - } - } else { +BioPtr createBioFromSpan(Span Encoded) { + BioPtr Bio{BIO_new(BIO_s_mem())}; + if (!Bio) { return nullptr; } - return PEM_read_bio_PUBKEY(Bio.get(), nullptr, nullptr, nullptr); + size_t Size; + if (!BIO_write_ex(Bio.get(), Encoded.data(), Encoded.size(), &Size) || + Size != Encoded.size()) { + return nullptr; + } + return Bio; } -WasiCryptoExpect> pemWritePUBKEY(EVP_PKEY *Key) { +template +WasiCryptoExpect writeKeyToBio(EVP_PKEY *Key, WriteFunc &&Func) { BioPtr Bio{BIO_new(BIO_s_mem())}; - opensslCheck(PEM_write_bio_PUBKEY(Bio.get(), Key)); + opensslCheck(Func(Bio.get(), Key)); BUF_MEM *Mem = nullptr; opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); - std::vector Ret(Mem->length); - if (size_t Size; BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size)) { - ensureOrReturn(Size == Ret.size(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); - } else { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + T Ret(Mem->length); + size_t Size; + if (BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size) && + Size == Ret.size()) { + return Ret; } - - return Ret; + return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); } -EVP_PKEY *pemReadPrivateKey(Span Encoded) { - BioPtr Bio{BIO_new(BIO_s_mem())}; +} // namespace - if (size_t Size; - BIO_write_ex(Bio.get(), Encoded.data(), Encoded.size(), &Size)) { - if (Size != Encoded.size()) { - return nullptr; - } - } else { +EVP_PKEY *pemReadPUBKEY(Span Encoded) { + auto Bio = createBioFromSpan(Encoded); + if (!Bio) { return nullptr; } - - return PEM_read_bio_PrivateKey(Bio.get(), nullptr, nullptr, nullptr); + return PEM_read_bio_PUBKEY(Bio.get(), nullptr, nullptr, nullptr); } -WasiCryptoExpect pemWritePrivateKey(EVP_PKEY *Key) { - BioPtr Bio{BIO_new(BIO_s_mem())}; - opensslCheck(PEM_write_bio_PrivateKey(Bio.get(), Key, nullptr, nullptr, 0, - nullptr, nullptr)); - - BUF_MEM *Mem = nullptr; - opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); - SecretVec Ret(Mem->length); +WasiCryptoExpect> pemWritePUBKEY(EVP_PKEY *Key) { + return writeKeyToBio>( + Key, [](BIO *B, EVP_PKEY *K) { return PEM_write_bio_PUBKEY(B, K); }); +} - if (size_t Size; BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size)) { - ensureOrReturn(Size == Ret.size(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); - } else { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); +EVP_PKEY *pemReadPrivateKey(Span Encoded) { + auto Bio = createBioFromSpan(Encoded); + if (!Bio) { + return nullptr; } + return PEM_read_bio_PrivateKey(Bio.get(), nullptr, nullptr, nullptr); +} - return Ret; +WasiCryptoExpect pemWritePrivateKey(EVP_PKEY *Key) { + return writeKeyToBio(Key, [](BIO *B, EVP_PKEY *K) { + return PEM_write_bio_PrivateKey(B, K, nullptr, nullptr, 0, nullptr, + nullptr); + }); } EVP_PKEY *d2iPUBKEY(Span Encoded) { - BioPtr Bio{BIO_new(BIO_s_mem())}; - - if (size_t Size; - BIO_write_ex(Bio.get(), Encoded.data(), Encoded.size(), &Size)) { - if (Size != Encoded.size()) { - return nullptr; - } - } else { + auto Bio = createBioFromSpan(Encoded); + if (!Bio) { return nullptr; } - return d2i_PUBKEY_bio(Bio.get(), nullptr); } WasiCryptoExpect> i2dPUBKEY(EVP_PKEY *Key) { - BioPtr Bio{BIO_new(BIO_s_mem())}; - opensslCheck(i2d_PUBKEY_bio(Bio.get(), Key)); - - BUF_MEM *Mem = nullptr; - opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); - std::vector Ret(Mem->length); - - if (size_t Size; BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size)) { - ensureOrReturn(Size == Ret.size(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); - } else { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); - } - - return Ret; + return writeKeyToBio>(Key, i2d_PUBKEY_bio); } EVP_PKEY *d2iPrivateKey(Span Encoded) { - BioPtr Bio{BIO_new(BIO_s_mem())}; - - if (size_t Size; - BIO_write_ex(Bio.get(), Encoded.data(), Encoded.size(), &Size)) { - if (Size != Encoded.size()) { - return nullptr; - } - } else { + auto Bio = createBioFromSpan(Encoded); + if (!Bio) { return nullptr; } - return d2i_PrivateKey_bio(Bio.get(), nullptr); } WasiCryptoExpect i2dPrivateKey(EVP_PKEY *Key) { - BioPtr Bio{BIO_new(BIO_s_mem())}; - opensslCheck(i2d_PrivateKey_bio(Bio.get(), Key)); - - BUF_MEM *Mem = nullptr; - opensslCheck(BIO_get_mem_ptr(Bio.get(), &Mem)); - SecretVec Ret(Mem->length); - - if (size_t Size; BIO_read_ex(Bio.get(), Ret.data(), Ret.size(), &Size)) { - ensureOrReturn(Size == Ret.size(), __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); - } else { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); - } - - return Ret; + return writeKeyToBio(Key, i2d_PrivateKey_bio); } ECDSA_SIG *d2iEcdsaSig(Span Encoded) { From 868aa365e98fafdb369ca3d2e1dc15a262a0046e Mon Sep 17 00:00:00 2001 From: Divyansh Khatri <146909065+Divyansh200102@users.noreply.github.com> Date: Wed, 28 Jan 2026 04:51:08 +0000 Subject: [PATCH 603/623] fix(docker): build libpiper statically (#4565) - utils/docker: Update Dockerfiles to COPY patch from the correct build context. - utils/wasi-nn/install-libpiper.sh: Fix patch path logic and enable parallel build. - utils/wasi-nn/libpiper.patch: Force static library creation and disable shared libs. Signed-off-by: Divyansh Khatri --- .../Dockerfile.manylinux_2_28-plugins-deps | 1 + utils/docker/Dockerfile.ubuntu-plugins-deps | 1 + utils/wasi-nn/install-libpiper.sh | 13 ++++++++++- utils/wasi-nn/libpiper.patch | 22 +++++++++++++++++++ 4 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 utils/wasi-nn/libpiper.patch diff --git a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps index e5c74d27..b08f89de 100644 --- a/utils/docker/Dockerfile.manylinux_2_28-plugins-deps +++ b/utils/docker/Dockerfile.manylinux_2_28-plugins-deps @@ -41,6 +41,7 @@ ENV OPENVINO_YEAR="2025" COPY wasi-nn/install-onnxruntime.sh . RUN [ "/bin/bash", "install-onnxruntime.sh" ] +COPY wasi-nn/libpiper.patch . COPY wasi-nn/install-libpiper.sh . RUN [ "/bin/bash", "install-libpiper.sh" ] diff --git a/utils/docker/Dockerfile.ubuntu-plugins-deps b/utils/docker/Dockerfile.ubuntu-plugins-deps index 86c5749e..a85945d3 100644 --- a/utils/docker/Dockerfile.ubuntu-plugins-deps +++ b/utils/docker/Dockerfile.ubuntu-plugins-deps @@ -75,6 +75,7 @@ RUN [ "/bin/bash", "-c", "echo \"source ./openvino_genai/setupvars.sh\" >> .bash COPY wasi-nn/install-onnxruntime.sh . RUN [ "/bin/bash", "install-onnxruntime.sh" ] +COPY wasi-nn/libpiper.patch . COPY wasi-nn/install-libpiper.sh . RUN [ "/bin/bash", "install-libpiper.sh" ] diff --git a/utils/wasi-nn/install-libpiper.sh b/utils/wasi-nn/install-libpiper.sh index d50fd653..bc2476a8 100755 --- a/utils/wasi-nn/install-libpiper.sh +++ b/utils/wasi-nn/install-libpiper.sh @@ -4,6 +4,14 @@ set -e +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +PATCH_PATH="${SCRIPT_DIR}/libpiper.patch" + +if [ ! -f "${PATCH_PATH}" ]; then + echo "Error: libpiper.patch not found at ${PATCH_PATH}" + exit 1 +fi + PIPER_REPO="https://github.com/OHF-Voice/piper1-gpl.git" PIPER_COMMIT="32b95f8c1f0dc0ce27a6acd1143de331f61af777" PIPER_INSTALL_TO="/usr/local" @@ -17,11 +25,14 @@ case "$(uname -m)" in ;; esac +rm -rf piper-source git clone --depth 1 "${PIPER_REPO}" piper-source cd piper-source git fetch --depth 1 origin "${PIPER_COMMIT}" git checkout FETCH_HEAD +cp "${PATCH_PATH}" . +patch -p1 < libpiper.patch cd libpiper cmake -Bbuild \ @@ -29,7 +40,7 @@ cmake -Bbuild \ -DCMAKE_INSTALL_PREFIX="${PIPER_INSTALL_TO}" \ -DBUILD_SHARED_LIBS=OFF \ -DCMAKE_POSITION_INDEPENDENT_CODE=ON -cmake --build build +cmake --build build --parallel $(nproc) cmake --install build cd ../.. diff --git a/utils/wasi-nn/libpiper.patch b/utils/wasi-nn/libpiper.patch new file mode 100644 index 00000000..e74cb4bd --- /dev/null +++ b/utils/wasi-nn/libpiper.patch @@ -0,0 +1,22 @@ +diff --git a/libpiper/CMakeLists.txt b/libpiper/CMakeLists.txt +index 57349c1..2e7ddd8 100644 +--- a/libpiper/CMakeLists.txt ++++ b/libpiper/CMakeLists.txt +@@ -2,6 +2,8 @@ + cmake_minimum_required(VERSION 3.26) + project(piper LANGUAGES C CXX) + ++option(BUILD_SHARED_LIBS "Build using shared libraries" OFF) ++ + set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD_REQUIRED ON) + +@@ -112,7 +114,7 @@ endif() + + # ---- libpiper --- + +-add_library(piper SHARED ++add_library(piper + ${CMAKE_CURRENT_SOURCE_DIR}/src/piper.cpp + ) + From d1bf6e2e1ce34c28efee7889d9bf3029ae01d803 Mon Sep 17 00:00:00 2001 From: Khushi-Singh Date: Wed, 28 Jan 2026 21:21:40 +0530 Subject: [PATCH 604/623] feat(wasi-nn, ggml): support HIP backend for llama.cpp (#4552) feat(wasi-nn): support HIP backend for llama.cpp Signed-off-by: Khushi --- plugins/wasi_nn/CMakeLists.txt | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index d9c40e83..529f6403 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -38,7 +38,6 @@ target_compile_definitions(wasmedgePluginWasiNN PRIVATE WASI_NN_VERSION_MINOR=${WASI_NN_VERSION_MINOR} WASI_NN_VERSION_PATCH=${WASI_NN_VERSION_PATCH} ) - # This for-each iteration is for the additional sources. # The dependencies are moved into `cmake/WASINNDeps.cmake`. foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) @@ -85,6 +84,19 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) GGML/tts/tts_core.cpp GGML/utils.cpp ) + if(WASMEDGE_PLUGIN_WASI_NN_GGML_LLAMA_HIP) + find_package(hip REQUIRED) + find_package(hipblas REQUIRED) + set(GGML_HIP ON CACHE BOOL "Build GGML with HIP" FORCE) + if(DEFINED ENV{GPU_TARGETS}) + set(GPU_TARGETS $ENV{GPU_TARGETS} CACHE STRING "HIP GPU targets") + else() + set(GPU_TARGETS "gfx90c" CACHE STRING "Default for Vega iGPU") + endif() + message(STATUS "Enabling HIP for ggml backend with targets: ${GPU_TARGETS}") + target_compile_definitions(wasmedgePluginWasiNN PRIVATE GGML_HIP=${GGML_HIP} GPU_TARGETS=${GPU_TARGETS}) + target_link_libraries(wasmedgePluginWasiNN PRIVATE hip::host roc::hipblas) + endif() endif() endforeach() From b7173dbf3764d5d73472e930ac0e1eac70750457 Mon Sep 17 00:00:00 2001 From: Divyansh Khatri <146909065+Divyansh200102@users.noreply.github.com> Date: Mon, 9 Feb 2026 04:25:28 +0000 Subject: [PATCH 605/623] fix(docker): manually link espeak-ng and ucd static libs for Piper backend (#4622) Signed-off-by: Divyansh Khatri --- utils/wasi-nn/install-libpiper.sh | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/utils/wasi-nn/install-libpiper.sh b/utils/wasi-nn/install-libpiper.sh index bc2476a8..fbcb3a9d 100755 --- a/utils/wasi-nn/install-libpiper.sh +++ b/utils/wasi-nn/install-libpiper.sh @@ -43,6 +43,26 @@ cmake -Bbuild \ cmake --build build --parallel $(nproc) cmake --install build +echo "Copying espeak-ng static libraries to ${PIPER_INSTALL_TO}/lib..." + +if [ -f "build/espeak_ng-install/lib/libespeak-ng.a" ]; then + cp "build/espeak_ng-install/lib/libespeak-ng.a" "${PIPER_INSTALL_TO}/lib/" +else + find build -name "libespeak-ng.a" -exec cp {} "${PIPER_INSTALL_TO}/lib/" \; -quit +fi + +find build -name "libucd.a" -exec cp {} "${PIPER_INSTALL_TO}/lib/" \; -quit + +if [ ! -f "${PIPER_INSTALL_TO}/lib/libespeak-ng.a" ]; then + echo "Error: Failed to install libespeak-ng.a" + exit 1 +fi + +if [ ! -f "${PIPER_INSTALL_TO}/lib/libucd.a" ]; then + echo "Error: Failed to install libucd.a" + exit 1 +fi + cd ../.. rm -rf piper-source From e4234eabd0569d22674962143aa228a87a234962 Mon Sep 17 00:00:00 2001 From: Divyansh Khatri <146909065+Divyansh200102@users.noreply.github.com> Date: Tue, 10 Feb 2026 08:35:25 +0000 Subject: [PATCH 606/623] fix(wasi-nn): install espeak-ng-data for Piper backend tests in CI (#4634) fix: install espeak-ng-data for Piper backend tests in CI Signed-off-by: Divyansh Khatri --- utils/wasi-nn/install-libpiper.sh | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/utils/wasi-nn/install-libpiper.sh b/utils/wasi-nn/install-libpiper.sh index fbcb3a9d..4a4ee475 100755 --- a/utils/wasi-nn/install-libpiper.sh +++ b/utils/wasi-nn/install-libpiper.sh @@ -63,6 +63,17 @@ if [ ! -f "${PIPER_INSTALL_TO}/lib/libucd.a" ]; then exit 1 fi +echo "Installing espeak-ng-data to ${PIPER_INSTALL_TO}/share..." + +if [ -d "build/espeak_ng-install/share/espeak-ng-data" ]; then + mkdir -p "${PIPER_INSTALL_TO}/share" + cp -r "build/espeak_ng-install/share/espeak-ng-data" "${PIPER_INSTALL_TO}/share/" + echo "Espeak-ng-data installed successfully." +else + echo "Error: espeak-ng-data directory not found in build tree!" + exit 1 +fi + cd ../.. rm -rf piper-source From d0a5c275ad40855aac39ffff0598ebe7910cae9a Mon Sep 17 00:00:00 2001 From: Divyansh Khatri <146909065+Divyansh200102@users.noreply.github.com> Date: Wed, 11 Feb 2026 03:57:41 +0000 Subject: [PATCH 607/623] feat(WASI-NN, piper): update the WASI-NN Piper plugin due to upstream changes (#4443) feat(wasi-nn): upgrade Piper backend to use new upstream Signed-off-by: Divyansh Khatri --- plugins/wasi_nn/CMakeLists.txt | 31 ++ plugins/wasi_nn/piper.patch | 472 ---------------------------- plugins/wasi_nn/wasinn_piper.cpp | 406 +++++++++++------------- plugins/wasi_nn/wasinn_piper.h | 36 ++- test/plugins/wasi_nn/CMakeLists.txt | 60 +++- test/plugins/wasi_nn/wasi_nn.cpp | 13 +- utils/build_libpiper.sh | 18 ++ 7 files changed, 304 insertions(+), 732 deletions(-) delete mode 100644 plugins/wasi_nn/piper.patch create mode 100644 utils/build_libpiper.sh diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index 529f6403..dc3d8b5e 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -98,6 +98,37 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) target_link_libraries(wasmedgePluginWasiNN PRIVATE hip::host roc::hipblas) endif() endif() + if(BACKEND STREQUAL "piper") + if(DEFINED PIPER_ROOT) + find_library(ESPEAK_NG_LIB + NAMES espeak-ng libespeak-ng + PATHS /usr/local/lib /usr/local/lib64 + NO_DEFAULT_PATH + ) + if (NOT ESPEAK_NG_LIB) + find_library(ESPEAK_NG_LIB NAMES espeak-ng libespeak-ng) + endif() + + find_library(UCD_LIB + NAMES ucd libucd + PATHS /usr/local/lib /usr/local/lib64 + NO_DEFAULT_PATH + ) + if (NOT UCD_LIB) + find_library(UCD_LIB NAMES ucd libucd) + endif() + + set(ESPEAK_TARGETS ${ESPEAK_NG_LIB} ${UCD_LIB}) + else() + set(ESPEAK_TARGETS "") + endif() + + target_link_libraries(wasmedgePluginWasiNN + PRIVATE + onnxruntime + ${ESPEAK_TARGETS} + ) + endif() endforeach() target_compile_options(wasmedgePluginWasiNN diff --git a/plugins/wasi_nn/piper.patch b/plugins/wasi_nn/piper.patch deleted file mode 100644 index bca50ca7..00000000 --- a/plugins/wasi_nn/piper.patch +++ /dev/null @@ -1,472 +0,0 @@ -diff --git a/CMakeLists.txt b/CMakeLists.txt -index f96ec44..6a2d6c4 100644 ---- a/CMakeLists.txt -+++ b/CMakeLists.txt -@@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.13) - - project(piper C CXX) - -+option(BUILD_SHARED_LIBS "Build using shared libraries" OFF) -+ - file(READ "${CMAKE_CURRENT_LIST_DIR}/VERSION" piper_version) - - set(CMAKE_CXX_STANDARD 17) -@@ -13,11 +15,13 @@ if(MSVC) - add_compile_options("$<$:/utf-8>") - elseif(NOT APPLE) - # Linux flags -- string(APPEND CMAKE_CXX_FLAGS " -Wall -Wextra -Wl,-rpath,'$ORIGIN'") -+ string(APPEND CMAKE_CXX_FLAGS " -Wall -Wextra") -+ list(APPEND CMAKE_BUILD_RPATH "$ORIGIN") -+ list(APPEND CMAKE_INSTALL_RPATH "$ORIGIN") - string(APPEND CMAKE_C_FLAGS " -Wall -Wextra") - endif() - --add_executable(piper src/cpp/main.cpp src/cpp/piper.cpp) -+add_library(piper src/cpp/piper.cpp) - add_executable(test_piper src/cpp/test.cpp src/cpp/piper.cpp) - - # NOTE: external project prefix are shortened because of path length restrictions on Windows -@@ -25,7 +29,21 @@ add_executable(test_piper src/cpp/test.cpp src/cpp/piper.cpp) - - # ---- fmt --- - --if(NOT DEFINED FMT_DIR) -+set(fmt_FOUND FALSE) -+ -+if(NOT fmt_FOUND AND TARGET "fmt::fmt") -+ list(APPEND FMT_LINK_LIBRARIES "fmt::fmt") -+ set(fmt_FOUND TRUE) -+endif() -+ -+if(NOT fmt_FOUND AND NOT DEFINED FMT_DIR) -+ find_package(fmt) -+ if(fmt_FOUND) -+ list(APPEND FMT_LINK_LIBRARIES "fmt::fmt") -+ endif() -+endif() -+ -+if(NOT fmt_FOUND AND NOT DEFINED FMT_DIR) - set(FMT_VERSION "10.0.0") - set(FMT_DIR "${CMAKE_CURRENT_BINARY_DIR}/fi") - -@@ -41,11 +59,33 @@ if(NOT DEFINED FMT_DIR) - add_dependencies(test_piper fmt_external) - endif() - -+if(NOT fmt_FOUND AND DEFINED FMT_DIR) -+ list(APPEND FMT_LINK_LIBRARIES "fmt") -+ list(APPEND FMT_LINK_DIRECTORIES "${FMT_DIR}/lib") -+ list(APPEND FMT_INCLUDE_DIRECTORIES "${FMT_DIR}/include") -+ set(fmt_FOUND TRUE) -+endif() -+ - # ---- spdlog --- - --if(NOT DEFINED SPDLOG_DIR) -+set(spdlog_FOUND FALSE) -+ -+if(NOT spdlog_FOUND AND TARGET "spdlog::spdlog") -+ list(APPEND SPDLOG_LINK_LIBRARIES "spdlog::spdlog") -+ set(spdlog_FOUND TRUE) -+endif() -+ -+if(NOT spdlog_FOUND AND NOT DEFINED SPDLOG_DIR) -+ find_package(spdlog) -+ if(spdlog_FOUND) -+ list(APPEND SPDLOG_LINK_LIBRARIES "spdlog::spdlog") -+ endif() -+endif() -+ -+if(NOT spdlog_FOUND AND NOT DEFINED SPDLOG_DIR) - set(SPDLOG_DIR "${CMAKE_CURRENT_BINARY_DIR}/si") - set(SPDLOG_VERSION "1.12.0") -+ include(ExternalProject) - ExternalProject_Add( - spdlog_external - PREFIX "${CMAKE_CURRENT_BINARY_DIR}/s" -@@ -56,81 +96,81 @@ if(NOT DEFINED SPDLOG_DIR) - add_dependencies(test_piper spdlog_external) - endif() - -+if(NOT spdlog_FOUND AND DEFINED SPDLOG_DIR) -+ list(APPEND SPDLOG_LINK_LIBRARIES "spdlog") -+ list(APPEND SPDLOG_LINK_DIRECTORIES "${SPDLOG_DIR}/lib") -+ list(APPEND SPDLOG_INCLUDE_DIRECTORIES "${SPDLOG_DIR}/include") -+ set(spdlog_FOUND TRUE) -+endif() -+ - # ---- piper-phonemize --- - --if(NOT DEFINED PIPER_PHONEMIZE_DIR) -- set(PIPER_PHONEMIZE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pi") -- ExternalProject_Add( -- piper_phonemize_external -- PREFIX "${CMAKE_CURRENT_BINARY_DIR}/p" -- URL "https://github.com/rhasspy/piper-phonemize/archive/refs/heads/master.zip" -- CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${PIPER_PHONEMIZE_DIR} -- ) -- add_dependencies(piper piper_phonemize_external) -- add_dependencies(test_piper piper_phonemize_external) --endif() -+include(FetchContent) -+find_program(GIT_CMD git REQUIRED) -+FetchContent_Declare( -+ piper_phonemize -+ GIT_REPOSITORY "https://github.com/rhasspy/piper-phonemize.git" -+ GIT_TAG "bfc2e7549957829b0227c66a305d11cc88167bda" # master -+ UPDATE_DISCONNECTED TRUE -+ PATCH_COMMAND "${GIT_CMD}" "apply" "${CMAKE_CURRENT_SOURCE_DIR}/piper-phonemize.patch" -+) -+FetchContent_MakeAvailable(piper_phonemize) - - # ---- Declare executable ---- - - if((NOT MSVC) AND (NOT APPLE)) - # Linux flags -- string(APPEND CMAKE_CXX_FLAGS " -Wall -Wextra -Wl,-rpath,'$ORIGIN'") -+ string(APPEND CMAKE_CXX_FLAGS " -Wall -Wextra") -+ list(APPEND CMAKE_BUILD_RPATH "$ORIGIN") -+ list(APPEND CMAKE_INSTALL_RPATH "$ORIGIN") - string(APPEND CMAKE_C_FLAGS " -Wall -Wextra") -- target_link_libraries(piper -static-libgcc -static-libstdc++) -+ target_link_libraries(piper PRIVATE -static-libgcc -static-libstdc++) - - set(PIPER_EXTRA_LIBRARIES "pthread") - endif() - --target_link_libraries(piper -- fmt -- spdlog -+target_link_libraries(piper PRIVATE -+ ${FMT_LINK_LIBRARIES} -+ ${SPDLOG_LINK_LIBRARIES} - espeak-ng -- piper_phonemize - onnxruntime - ${PIPER_EXTRA_LIBRARIES} -+ PUBLIC piper_phonemize - ) - --target_link_directories(piper PUBLIC -- ${FMT_DIR}/lib -- ${SPDLOG_DIR}/lib -- ${PIPER_PHONEMIZE_DIR}/lib --) -- --target_include_directories(piper PUBLIC -- ${FMT_DIR}/include -- ${SPDLOG_DIR}/include -- ${PIPER_PHONEMIZE_DIR}/include -+target_link_directories(piper PRIVATE -+ ${FMT_LINK_DIRECTORIES} -+ ${SPDLOG_LINK_DIRECTORIES} - ) - --target_compile_definitions(piper PUBLIC _PIPER_VERSION=${piper_version}) -+set(PIPER_INTERFACE_INCLUDE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/include") -+file(COPY src/cpp/piper.hpp src/cpp/json.hpp DESTINATION "${PIPER_INTERFACE_INCLUDE_DIRECTORY}") - --# ---- Declare test ---- --include(CTest) --enable_testing() --add_test( -- NAME test_piper -- COMMAND test_piper "${CMAKE_SOURCE_DIR}/etc/test_voice.onnx" "${PIPER_PHONEMIZE_DIR}/share/espeak-ng-data" "${CMAKE_CURRENT_BINARY_DIR}/test.wav" -+target_include_directories(piper PRIVATE -+ ${FMT_INCLUDE_DIRECTORIES} -+ ${SPDLOG_INCLUDE_DIRECTORIES} -+ INTERFACE "${PIPER_INTERFACE_INCLUDE_DIRECTORY}" - ) - -+target_compile_definitions(piper PRIVATE _PIPER_VERSION=${piper_version}) -+ - target_compile_features(test_piper PUBLIC cxx_std_17) - - target_include_directories( - test_piper PUBLIC -- ${FMT_DIR}/include -- ${SPDLOG_DIR}/include -- ${PIPER_PHONEMIZE_DIR}/include -+ ${FMT_INCLUDE_DIRECTORIES} -+ ${SPDLOG_INCLUDE_DIRECTORIES} - ) - - target_link_directories( - test_piper PUBLIC -- ${FMT_DIR}/lib -- ${SPDLOG_DIR}/lib -- ${PIPER_PHONEMIZE_DIR}/lib -+ ${FMT_LINK_DIRECTORIES} -+ ${SPDLOG_LINK_DIRECTORIES} - ) - - target_link_libraries(test_piper PUBLIC -- fmt -- spdlog -+ ${FMT_LINK_LIBRARIES} -+ ${SPDLOG_LINK_LIBRARIES} - espeak-ng - piper_phonemize - onnxruntime -@@ -141,32 +181,3 @@ target_link_libraries(test_piper PUBLIC - install( - TARGETS piper - DESTINATION ${CMAKE_INSTALL_PREFIX}) -- --# Dependencies --install( -- DIRECTORY ${PIPER_PHONEMIZE_DIR}/bin/ -- DESTINATION ${CMAKE_INSTALL_PREFIX} -- USE_SOURCE_PERMISSIONS # keep +x -- FILES_MATCHING -- PATTERN "piper_phonemize" -- PATTERN "espeak-ng" -- PATTERN "*.dll" --) -- --install( -- DIRECTORY ${PIPER_PHONEMIZE_DIR}/lib/ -- DESTINATION ${CMAKE_INSTALL_PREFIX} -- FILES_MATCHING -- PATTERN "*.dll" -- PATTERN "*.so*" --) -- --install( -- DIRECTORY ${PIPER_PHONEMIZE_DIR}/share/espeak-ng-data -- DESTINATION ${CMAKE_INSTALL_PREFIX} --) -- --install( -- FILES ${PIPER_PHONEMIZE_DIR}/share/libtashkeel_model.ort -- DESTINATION ${CMAKE_INSTALL_PREFIX} --) -diff --git a/VERSION b/VERSION -index 26aaba0..867e524 100644 ---- a/VERSION -+++ b/VERSION -@@ -1 +1 @@ --1.2.0 -+1.2.0 -\ No newline at end of file -diff --git a/piper-phonemize.patch b/piper-phonemize.patch -new file mode 100644 -index 0000000..c4676cb ---- /dev/null -+++ b/piper-phonemize.patch -@@ -0,0 +1,214 @@ -+diff --git a/CMakeLists.txt b/CMakeLists.txt -+index ec7b501..39275a6 100644 -+--- a/CMakeLists.txt -++++ b/CMakeLists.txt -+@@ -10,6 +10,8 @@ project( -+ LANGUAGES CXX -+ ) -+ -++option(BUILD_SHARED_LIBS "Build using shared libraries" ON) -++ -+ if(MSVC) -+ # Force compiler to use UTF-8 for IPA constants -+ add_compile_options("$<$:/utf-8>") -+@@ -17,12 +19,14 @@ if(MSVC) -+ -+ elseif(NOT APPLE) -+ # Linux flags -+- string(APPEND CMAKE_CXX_FLAGS " -Wall -Wextra -Wl,-rpath,'$ORIGIN'") -++ string(APPEND CMAKE_CXX_FLAGS " -Wall -Wextra") -++ list(APPEND CMAKE_BUILD_RPATH "$ORIGIN") -++ list(APPEND CMAKE_INSTALL_RPATH "$ORIGIN") -+ string(APPEND CMAKE_C_FLAGS " -Wall -Wextra") -+ endif() -+ -+ add_library( -+- piper_phonemize SHARED -++ piper_phonemize -+ src/phonemize.cpp -+ src/phoneme_ids.cpp -+ src/tashkeel.cpp -+@@ -36,12 +40,33 @@ set_target_properties(piper_phonemize PROPERTIES -+ -+ # ---- onnxruntime --- -+ -+-# Look for onnxruntime files in /lib -+-if(NOT DEFINED ONNXRUNTIME_DIR) -+- if(NOT DEFINED ONNXRUNTIME_VERSION) -+- set(ONNXRUNTIME_VERSION "1.14.1") -++set(onnxruntime_FOUND FALSE) -++ -++if(NOT DEFINED ONNXRUNTIME_VERSION) -++ set(ONNXRUNTIME_VERSION "1.14.1") -++endif() -++ -++if(NOT onnxruntime_FOUND AND NOT DEFINED ONNXRUNTIME_DIR) -++ find_package(onnxruntime "${ONNXRUNTIME_VERSION}") -++ if(onnxruntime_FOUND) -++ list(APPEND ONNXRUNTIME_LINK_LIBRARIES "onnxruntime::onnxruntime") -+ endif() -++endif() -++ -++if(NOT onnxruntime_FOUND AND NOT DEFINED ONNXRUNTIME_DIR) -++ find_library(ONNXRUNTIME_LIBRARY onnxruntime) -++ if(NOT "${ONNXRUNTIME_LIBRARY}" STREQUAL "ONNXRUNTIME_LIBRARY-NOTFOUND") -++ find_path(ONNXRUNTIME_PATH "onnxruntime_cxx_api.h" PATH_SUFFIXES "onnxruntime") -++ if(NOT "${ONNXRUNTIME_PATH}" STREQUAL "ONNXRUNTIME_PATH-NOTFOUND") -++ list(APPEND ONNXRUNTIME_INCLUDE_DIRECTORIES "${ONNXRUNTIME_PATH}") -++ list(APPEND ONNXRUNTIME_LINK_LIBRARIES "${ONNXRUNTIME_LIBRARY}") -++ set(onnxruntime_FOUND TRUE) -++ endif() -++ endif() -++endif() -+ -++# Look for onnxruntime files in /lib -++if(NOT onnxruntime_FOUND AND NOT DEFINED ONNXRUNTIME_DIR) -+ if(WIN32) -+ # Windows x86-64 -+ set(ONNXRUNTIME_PREFIX "onnxruntime-win-x64-${ONNXRUNTIME_VERSION}") -+@@ -95,19 +120,31 @@ if(NOT DEFINED ONNXRUNTIME_DIR) -+ endif() -+ endif() -+ -++if(NOT onnxruntime_FOUND AND DEFINED ONNXRUNTIME_DIR) -++ list(APPEND ONNXRUNTIME_INCLUDE_DIRECTORIES "${ONNXRUNTIME_DIR}/include") -++ list(APPEND ONNXRUNTIME_LINK_DIRECTORIES "${ONNXRUNTIME_DIR}/lib") -++ list(APPEND ONNXRUNTIME_LINK_LIBRARIES "onnxruntime") -++ set(onnxruntime_FOUND TRUE) -++endif() -++ -+ # ---- espeak-ng --- -+ -+ if(NOT DEFINED ESPEAK_NG_DIR) -+ set(ESPEAK_NG_DIR "${CMAKE_CURRENT_BINARY_DIR}/ei") -+ -++ find_program(GIT_PROGRAM "git" REQUIRED) -+ include(ExternalProject) -+ ExternalProject_Add( -+ espeak_ng_external -+ PREFIX "${CMAKE_CURRENT_BINARY_DIR}/e" -+- URL "https://github.com/rhasspy/espeak-ng/archive/0f65aa301e0d6bae5e172cc74197d32a6182200f.zip" -++ GIT_REPOSITORY "https://github.com/rhasspy/espeak-ng" -++ GIT_TAG "0f65aa301e0d6bae5e172cc74197d32a6182200f" -++ GIT_PROGRESS TRUE -++ UPDATE_DISCONNECTED TRUE -++ PATCH_COMMAND "${GIT_PROGRAM}" "apply" "${CMAKE_CURRENT_SOURCE_DIR}/espeak-ng.patch" -+ CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ESPEAK_NG_DIR} -+ CMAKE_ARGS -DUSE_ASYNC:BOOL=OFF -+- CMAKE_ARGS -DBUILD_SHARED_LIBS:BOOL=ON -++ CMAKE_ARGS "-DBUILD_SHARED_LIBS:BOOL=${BUILD_SHARED_LIBS}" -+ CMAKE_ARGS -DUSE_MBROLA:BOOL=OFF -+ CMAKE_ARGS -DUSE_LIBSONIC:BOOL=OFF -+ CMAKE_ARGS -DUSE_LIBPCAUDIO:BOOL=OFF -+@@ -116,6 +153,8 @@ if(NOT DEFINED ESPEAK_NG_DIR) -+ CMAKE_ARGS -DEXTRA_cmn:BOOL=ON -+ CMAKE_ARGS -DEXTRA_ru:BOOL=ON -+ CMAKE_ARGS -DCMAKE_C_FLAGS="-D_FILE_OFFSET_BITS=64" -++ CMAKE_ARGS "-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${CMAKE_POSITION_INDEPENDENT_CODE}" -++ USES_TERMINAL_DOWNLOAD TRUE -+ ) -+ add_dependencies(piper_phonemize espeak_ng_external) -+ endif() -+@@ -123,23 +162,27 @@ endif() -+ -+ # ---- Declare library ---- -+ -++set(PIPER_PHONEMIZE_INTERFACE_INCLUDE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/include") -++file(COPY "src/" DESTINATION "${PIPER_PHONEMIZE_INTERFACE_INCLUDE_DIRECTORY}/piper-phonemize" FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp") -++ -+ target_include_directories( -+ piper_phonemize PUBLIC -+ "$" -+ ${ESPEAK_NG_DIR}/include -+- ${ONNXRUNTIME_DIR}/include -++ ${ONNXRUNTIME_INCLUDE_DIRECTORIES} -++ INTERFACE "${PIPER_PHONEMIZE_INTERFACE_INCLUDE_DIRECTORY}" -+ ) -+ -+ target_link_directories( -+ piper_phonemize PUBLIC -+ ${ESPEAK_NG_DIR}/lib -+- ${ONNXRUNTIME_DIR}/lib -++ ${ONNXRUNTIME_LINK_DIRECTORIES} -+ ) -+ -+ target_link_libraries( -+ piper_phonemize -+ espeak-ng -+- onnxruntime -++ ${ONNXRUNTIME_LINK_LIBRARIES} -+ ) -+ -+ target_compile_features(piper_phonemize PUBLIC cxx_std_17) -+@@ -173,12 +216,13 @@ target_link_libraries(piper_phonemize_exe PUBLIC -+ # ---- Declare test ---- -+ -+ include(CTest) -+-enable_testing() -+ add_executable(test_piper_phonemize src/test.cpp src/phoneme_ids.cpp) -+-add_test( -+- NAME test_piper_phonemize -+- COMMAND test_piper_phonemize "${ESPEAK_NG_DIR}/share/espeak-ng-data" "${CMAKE_SOURCE_DIR}/etc/libtashkeel_model.ort" -+-) -++if(BUILD_TESTING) -++ add_test( -++ NAME test_piper_phonemize -++ COMMAND test_piper_phonemize "${ESPEAK_NG_DIR}/share/espeak-ng-data" "${CMAKE_SOURCE_DIR}/etc/libtashkeel_model.ort" -++ ) -++endif() -+ -+ target_compile_features(test_piper_phonemize PUBLIC cxx_std_17) -+ -+@@ -207,7 +251,7 @@ install( -+ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) -+ -+ install( -+- DIRECTORY ${CMAKE_SOURCE_DIR}/src/ -++ DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/src/" -+ DESTINATION include/piper-phonemize -+ FILES_MATCHING -+ PATTERN "*.h" -+@@ -218,7 +262,7 @@ install( -+ ARCHIVE DESTINATION ${CMAKE_INSTALL_BINDIR}) -+ -+ install( -+- FILES ${CMAKE_SOURCE_DIR}/etc/libtashkeel_model.ort -++ FILES "${CMAKE_CURRENT_SOURCE_DIR}/etc/libtashkeel_model.ort" -+ TYPE DATA) -+ -+ # Dependencies -+@@ -226,10 +270,12 @@ install( -+ DIRECTORY ${ESPEAK_NG_DIR}/ -+ DESTINATION ${CMAKE_INSTALL_PREFIX}) -+ -+-install( -+- DIRECTORY ${ONNXRUNTIME_DIR}/include/ -+- DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) -++if(DEFINED ONNXRUNTIME_DIR) -++ install( -++ DIRECTORY ${ONNXRUNTIME_DIR}/include/ -++ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) -+ -+-install( -+- DIRECTORY ${ONNXRUNTIME_DIR}/lib/ -+- DESTINATION ${CMAKE_INSTALL_LIBDIR}) -++ install( -++ DIRECTORY ${ONNXRUNTIME_DIR}/lib/ -++ DESTINATION ${CMAKE_INSTALL_LIBDIR}) -++endif() -+diff --git a/espeak-ng.patch b/espeak-ng.patch -+new file mode 100644 -+index 0000000..a51d146 -+--- /dev/null -++++ b/espeak-ng.patch -+@@ -0,0 +1,10 @@ -++diff --git a/src/ucd-tools/CMakeLists.txt b/src/ucd-tools/CMakeLists.txt -++index 2050c114..4bd7d17e 100644 -++--- a/src/ucd-tools/CMakeLists.txt -+++++ b/src/ucd-tools/CMakeLists.txt -++@@ -1,4 +1,4 @@ -++-add_library(ucd STATIC -+++add_library(ucd OBJECT -++ src/case.c -++ src/categories.c -++ src/ctype.c diff --git a/plugins/wasi_nn/wasinn_piper.cpp b/plugins/wasi_nn/wasinn_piper.cpp index c9bf9e0b..433bc147 100644 --- a/plugins/wasi_nn/wasinn_piper.cpp +++ b/plugins/wasi_nn/wasinn_piper.cpp @@ -2,16 +2,19 @@ // SPDX-FileCopyrightText: 2019-2024 Second State INC #include "wasinn_piper.h" +#include "common/errcode.h" +#include "common/span.h" #include "wasinnenv.h" +#include "wasinntypes.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER #include "simdjson.h" +#include +#include #include #include -#include -#include -#include +#include #include #include #include @@ -23,6 +26,46 @@ namespace WasmEdge::Host::WASINN::Piper { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER +namespace { + +// helper function to write WAV header +void writeWavHeader(int SampleRate, int16_t NumChannels, int32_t NumSamples, + std::vector &OutputBuffer) { + int32_t const ByteRate = SampleRate * NumChannels * sizeof(int16_t); + int32_t const DataSize = NumSamples * NumChannels * sizeof(int16_t); + int32_t const RiffSize = 36 + DataSize; + + auto PushU32 = [&](int32_t Val) { + OutputBuffer.push_back(Val & 0xFF); + OutputBuffer.push_back((Val >> 8) & 0xFF); + OutputBuffer.push_back((Val >> 16) & 0xFF); + OutputBuffer.push_back((Val >> 24) & 0xFF); + }; + auto PushU16 = [&](int16_t Val) { + OutputBuffer.push_back(Val & 0xFF); + OutputBuffer.push_back((Val >> 8) & 0xFF); + }; + auto PushStr = [&](const char *Str) { + for (int I = 0; I < 4; I++) + OutputBuffer.push_back(Str[I]); + }; + + PushStr("RIFF"); + PushU32(RiffSize); + PushStr("WAVE"); + PushStr("fmt "); + PushU32(16); + PushU16(1); + PushU16(NumChannels); + PushU32(SampleRate); + PushU32(ByteRate); + PushU16(NumChannels * sizeof(int16_t)); + PushU16(16); + PushStr("data"); + PushU32(DataSize); +} + +} // namespace template std::tuple getOption(simdjson::dom::object &Object, std::string_view Key, T &Result) { @@ -51,8 +94,7 @@ WASINN::ErrNo getOptionalOption(simdjson::dom::object &Object, } WASINN::ErrNo parseSynthesisConfig(SynthesisConfig &SynthesisConfig, - simdjson::dom::object &Object, - const bool JsonInput) { + simdjson::dom::object &Object) { { auto Value = std::optional{}; if (auto Err = getOptionalOption(Object, "output_type", Value); @@ -72,17 +114,14 @@ WASINN::ErrNo parseSynthesisConfig(SynthesisConfig &SynthesisConfig, } } } - if (JsonInput) { - if (auto Err = - getOptionalOption(Object, "speaker_id", SynthesisConfig.SpeakerId); + { + auto SpeakerId = std::optional{}; + if (auto Err = getOptionalOption(Object, "speaker_id", SpeakerId); Err != WASINN::ErrNo::Success) { return Err; } - } else { - if (auto Err = - getOptionalOption(Object, "speaker", SynthesisConfig.SpeakerId); - Err != WASINN::ErrNo::Success) { - return Err; + if (SpeakerId.has_value()) { + SynthesisConfig.SpeakerId = static_cast(SpeakerId.value()); } } if (auto Err = getOptionalOption(Object, "noise_scale", @@ -100,45 +139,8 @@ WASINN::ErrNo parseSynthesisConfig(SynthesisConfig &SynthesisConfig, Err != WASINN::ErrNo::Success) { return Err; } - if (auto Err = getOptionalOption( - Object, "sentence_silence", SynthesisConfig.SentenceSilenceSeconds); - Err != WASINN::ErrNo::Success) { - return Err; - } - { - auto PhonemeSilence = std::optional{}; - if (auto Err = getOptionalOption(Object, "phoneme_silence", PhonemeSilence); - Err != WASINN::ErrNo::Success) { - return Err; - } - if (PhonemeSilence) { - for (auto [Key, Value] : PhonemeSilence.value()) { - auto PhonemeStr = std::string{Key}; - if (!piper::isSingleCodepoint(PhonemeStr)) { - spdlog::error( - "[WASI-NN] Piper backend: Phoneme '{}' is not a single codepoint (phoneme_silence)."sv, - PhonemeStr); - return WASINN::ErrNo::InvalidArgument; - } - auto Seconds = Value.get_double(); - if (auto Error = Seconds.error()) { - spdlog::error( - "[WASI-NN] Piper backend: Failed to get silence seconds for phoneme '{}' as a double: {}"sv, - PhonemeStr, simdjson::error_message(Error)); - return WASINN::ErrNo::InvalidArgument; - } - if (!SynthesisConfig.PhonemeSilenceSeconds) { - SynthesisConfig.PhonemeSilenceSeconds.emplace(); - } - auto Phoneme = piper::getCodepoint(PhonemeStr); - SynthesisConfig.PhonemeSilenceSeconds.value()[Phoneme] = - Seconds.value(); - } - } - } return WASINN::ErrNo::Success; } - WASINN::ErrNo parseRunConfig(RunConfig &RunConfig, const std::string &String) noexcept { simdjson::dom::parser Parser; @@ -193,8 +195,7 @@ WASINN::ErrNo parseRunConfig(RunConfig &RunConfig, return WASINN::ErrNo::InvalidArgument; } - if (auto Err = - parseSynthesisConfig(RunConfig.DefaultSynthesisConfig, Object, false); + if (auto Err = parseSynthesisConfig(RunConfig.DefaultSynthesisConfig, Object); Err != WASINN::ErrNo::Success) { return Err; } @@ -208,16 +209,6 @@ WASINN::ErrNo parseRunConfig(RunConfig &RunConfig, RunConfig.ESpeakDataPath = std::filesystem::u8path(Path.value()); } } - { - auto Path = std::optional{}; - if (auto Err = getOptionalOption(Object, "tashkeel_model", Path); - Err != WASINN::ErrNo::Success) { - return Err; - } - if (Path) { - RunConfig.TashkeelModelPath = std::filesystem::u8path(Path.value()); - } - } if (auto Err = std::get<0>(getOption(Object, "json_input", RunConfig.JsonInput)); Err != WASINN::ErrNo::Success) { @@ -226,38 +217,19 @@ WASINN::ErrNo parseRunConfig(RunConfig &RunConfig, return WASINN::ErrNo::Success; } -void updateSynthesisConfig(SynthesisConfig &SynthesisConfig, - piper::SynthesisConfig &PiperSynthesisConfig, - const bool ForceOverwritePhonemeSilenceSeconds) { +void updatePiperOptions(const SynthesisConfig &SynthesisConfig, + piper_synthesize_options &Options) { + if (SynthesisConfig.SpeakerId) { + Options.speaker_id = SynthesisConfig.SpeakerId.value(); + } if (SynthesisConfig.NoiseScale) { - PiperSynthesisConfig.noiseScale = SynthesisConfig.NoiseScale.value(); + Options.noise_scale = SynthesisConfig.NoiseScale.value(); } if (SynthesisConfig.LengthScale) { - PiperSynthesisConfig.lengthScale = SynthesisConfig.LengthScale.value(); + Options.length_scale = SynthesisConfig.LengthScale.value(); } if (SynthesisConfig.NoiseW) { - PiperSynthesisConfig.noiseW = SynthesisConfig.NoiseW.value(); - } - if (SynthesisConfig.SentenceSilenceSeconds) { - PiperSynthesisConfig.sentenceSilenceSeconds = - SynthesisConfig.SentenceSilenceSeconds.value(); - } - if (ForceOverwritePhonemeSilenceSeconds) { - PiperSynthesisConfig.phonemeSilenceSeconds = - SynthesisConfig.PhonemeSilenceSeconds; - } else if (SynthesisConfig.PhonemeSilenceSeconds) { - if (!PiperSynthesisConfig.phonemeSilenceSeconds) { - // Overwrite - PiperSynthesisConfig.phonemeSilenceSeconds = - SynthesisConfig.PhonemeSilenceSeconds; - } else { - // Merge - for (const auto &[Phoneme, SilenceSeconds] : - *SynthesisConfig.PhonemeSilenceSeconds) { - PiperSynthesisConfig.phonemeSilenceSeconds->try_emplace(Phoneme, - SilenceSeconds); - } - } + Options.noise_w_scale = SynthesisConfig.NoiseW.value(); } } @@ -273,7 +245,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, } // Add a new graph. - uint32_t GId = Env.newGraph(Backend::Piper); + uint32_t const GId = Env.newGraph(Backend::Piper); auto &GraphRef = Env.NNGraph[GId].get(); GraphRef.Config = std::make_unique(); auto String = std::string{Builders[0].begin(), Builders[0].end()}; @@ -284,72 +256,24 @@ Expect load(WASINN::WasiNNEnvironment &Env, return Res; } - GraphRef.PiperConfig = std::make_unique(); - GraphRef.Voice = std::make_unique(); - piper::loadVoice(*GraphRef.PiperConfig, GraphRef.Config->ModelPath.string(), - GraphRef.Config->ModelConfigPath.string(), *GraphRef.Voice, - GraphRef.Config->DefaultSynthesisConfig.SpeakerId); + std::string EspeakPath = ""; + if (GraphRef.Config->ESpeakDataPath) { + EspeakPath = GraphRef.Config->ESpeakDataPath->string(); + } - if (GraphRef.Voice->phonemizeConfig.phonemeType == - piper::PhonemeType::eSpeakPhonemes) { - if (!GraphRef.Config->ESpeakDataPath) { - spdlog::error( - "[WASI-NN] Piper backend: espeak-ng data directory is required for eSpeakPhonemes"sv); - Env.deleteGraph(GId); - return WASINN::ErrNo::InvalidArgument; - } - if (!std::filesystem::exists(GraphRef.Config->ESpeakDataPath.value())) { - spdlog::error( - "[WASI-NN] Piper backend: espeak-ng data directory doesn't exist"sv); - Env.deleteGraph(GId); - return WASINN::ErrNo::InvalidArgument; - } - // User provided path - GraphRef.PiperConfig->eSpeakDataPath = - GraphRef.Config->ESpeakDataPath->string(); - } else { - // Not using eSpeak - GraphRef.PiperConfig->useESpeak = false; + piper_synthesizer *Synth = + piper_create(GraphRef.Config->ModelPath.string().c_str(), + GraphRef.Config->ModelConfigPath.string().c_str(), + EspeakPath.empty() ? nullptr : EspeakPath.c_str()); + + if (!Synth) { + spdlog::error( + "[WASI-NN] Piper backend: Failed to create piper synthesizer."sv); + Env.deleteGraph(GId); + return WASINN::ErrNo::InvalidArgument; } - // Enable libtashkeel for Arabic - if (GraphRef.Voice->phonemizeConfig.eSpeak.voice == "ar") { - if (!GraphRef.Config->TashkeelModelPath) { - spdlog::error( - "[WASI-NN] Piper backend: libtashkeel ort model is required for Arabic"sv); - Env.deleteGraph(GId); - return WASINN::ErrNo::InvalidArgument; - } - if (!std::filesystem::exists(GraphRef.Config->TashkeelModelPath.value())) { - spdlog::error( - "[WASI-NN] Piper backend: libtashkeel ort model doesn't exist"sv); - Env.deleteGraph(GId); - return WASINN::ErrNo::InvalidArgument; - } - GraphRef.PiperConfig->useTashkeel = true; - // User provided path - GraphRef.PiperConfig->tashkeelModelPath = - GraphRef.Config->TashkeelModelPath->string(); - } - - piper::initialize(*GraphRef.PiperConfig); - - // Update the default config - updateSynthesisConfig(GraphRef.Config->DefaultSynthesisConfig, - GraphRef.Voice->synthesisConfig, false); - // Copy back the result - GraphRef.Config->DefaultSynthesisConfig.NoiseScale = - GraphRef.Voice->synthesisConfig.noiseScale; - GraphRef.Config->DefaultSynthesisConfig.LengthScale = - GraphRef.Voice->synthesisConfig.lengthScale; - GraphRef.Config->DefaultSynthesisConfig.NoiseW = - GraphRef.Voice->synthesisConfig.noiseW; - GraphRef.Config->DefaultSynthesisConfig.SentenceSilenceSeconds = - GraphRef.Voice->synthesisConfig.sentenceSilenceSeconds; - GraphRef.Config->DefaultSynthesisConfig.PhonemeSilenceSeconds = - GraphRef.Voice->synthesisConfig.phonemeSilenceSeconds; - - // Store the loaded graph. + GraphRef.Synth = std::unique_ptr(Synth); GraphId = GId; Env.NNGraph[GId].setReady(); return WASINN::ErrNo::Success; @@ -371,74 +295,43 @@ Expect setInput(WASINN::WasiNNEnvironment &Env, spdlog::error("[WASI-NN] Piper backend: Input index must be 0."sv); return WASINN::ErrNo::InvalidArgument; } - if (!(Tensor.Dimension.size() == 1 && Tensor.Dimension[0] == 1)) { + if (Tensor.Dimension.size() != 1) { spdlog::error( - "[WASI-NN] Piper backend: Input tensor dimension must be [1]."sv); + "[WASI-NN] Piper backend: Input tensor dimension must be 1D."sv); return WASINN::ErrNo::InvalidArgument; } auto &CxtRef = Env.NNContext[ContextId].get(); auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); - auto Line = std::string{Tensor.Tensor.begin(), Tensor.Tensor.end()}; + CxtRef.Line = + std::string(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); if (GraphRef.Config->JsonInput) { simdjson::dom::parser Parser; simdjson::dom::element Doc; - if (auto Error = Parser.parse(Line).get(Doc)) { - spdlog::error("[WASI-NN] Piper backend: Parse json input error: {}"sv, - simdjson::error_message(Error)); - return WASINN::ErrNo::InvalidEncoding; - } - simdjson::dom::object Object; - if (auto Error = Doc.get(Object)) { - spdlog::error( - "[WASI-NN] Piper backend: The json input is not an object: {}"sv, - simdjson::error_message(Error)); + simdjson::padded_string const PaddedInput(CxtRef.Line.value()); + + if (Parser.parse(PaddedInput).get(Doc) != simdjson::SUCCESS) { + spdlog::error("[WASI-NN] Piper backend: Failed to parse JSON input."sv); return WASINN::ErrNo::InvalidArgument; } - // Text is required - auto Text = std::string_view{}; - if (auto Error = Object["text"].get(Text)) { - spdlog::error( - "[WASI-NN] Piper backend: Unable to retrieve required \"text\" from json input: {}"sv, - simdjson::error_message(Error)); + simdjson::dom::object JsonObj; + if (Doc.get(JsonObj) != simdjson::SUCCESS) { + spdlog::error("[WASI-NN] Piper backend: JSON input is not an object."sv); return WASINN::ErrNo::InvalidArgument; } - Line = Text; - // Parse override config - auto JsonInputSynthesisConfig = SynthesisConfig{}; - if (auto Err = parseSynthesisConfig(JsonInputSynthesisConfig, Object, true); + SynthesisConfig NewConfig; + if (auto Err = parseSynthesisConfig(NewConfig, JsonObj); Err != WASINN::ErrNo::Success) { return Err; } - if (!JsonInputSynthesisConfig.SpeakerId) { - auto SpeakerName = std::optional{}; - if (auto Err = getOptionalOption(Object, "speaker", SpeakerName); - Err != WASINN::ErrNo::Success) { - return Err; - } - if (SpeakerName) { - // Resolve to id using speaker id map - auto Name = std::string{SpeakerName.value()}; - if (GraphRef.Voice->modelConfig.speakerIdMap && - GraphRef.Voice->modelConfig.speakerIdMap->count(Name) > 0) { - JsonInputSynthesisConfig.SpeakerId = - GraphRef.Voice->modelConfig.speakerIdMap.value()[Name]; - } else { - spdlog::warn("[WASI-NN] Piper backend: No speaker named: {}"sv, Name); - } - } - } - if (!CxtRef.JsonInputSynthesisConfig) { - CxtRef.JsonInputSynthesisConfig = - std::make_unique>(); - } - *CxtRef.JsonInputSynthesisConfig = JsonInputSynthesisConfig; + CxtRef.JsonInputSynthesisConfig = + std::make_unique>(NewConfig); } - CxtRef.Line = Line; return WASINN::ErrNo::Success; } @@ -487,42 +380,109 @@ Expect compute(WASINN::WasiNNEnvironment &Env, return WASINN::ErrNo::InvalidArgument; } + std::string TextToSpeak = CxtRef.Line.value(); + + if (GraphRef.Config->JsonInput) { + simdjson::dom::parser Parser; + simdjson::dom::element Doc; + simdjson::padded_string const PaddedInput(TextToSpeak); + + if (Parser.parse(PaddedInput).get(Doc) != simdjson::SUCCESS) { + spdlog::error("[WASI-NN] Piper backend: Failed to parse JSON input."sv); + return WASINN::ErrNo::InvalidArgument; + } + + simdjson::dom::object JsonObj; + if (Doc.get(JsonObj) != simdjson::SUCCESS) { + spdlog::error("[WASI-NN] Piper backend: JSON input is not an object."sv); + return WASINN::ErrNo::InvalidArgument; + } + + std::string_view TmpText; + if (JsonObj["text"].get(TmpText) != simdjson::SUCCESS) { + spdlog::error( + "[WASI-NN] Piper backend: JSON input must contain 'text' field."sv); + return WASINN::ErrNo::InvalidArgument; + } + TextToSpeak = std::string(TmpText); + + SynthesisConfig NewConfig; + if (parseSynthesisConfig(NewConfig, JsonObj) == WASINN::ErrNo::Success) { + CxtRef.JsonInputSynthesisConfig = + std::make_unique>(NewConfig); + } + } + + piper_synthesize_options Options = + piper_default_synthesize_options(GraphRef.Synth.get()); + updatePiperOptions(GraphRef.Config->DefaultSynthesisConfig, Options); + auto OutputType = SynthesisConfigOutputType::OUTPUT_WAV; if (GraphRef.Config->DefaultSynthesisConfig.OutputType) { OutputType = GraphRef.Config->DefaultSynthesisConfig.OutputType.value(); } - // Override config if (CxtRef.JsonInputSynthesisConfig && CxtRef.JsonInputSynthesisConfig->has_value()) { - updateSynthesisConfig(CxtRef.JsonInputSynthesisConfig->value(), - GraphRef.Voice->synthesisConfig, false); + updatePiperOptions(CxtRef.JsonInputSynthesisConfig->value(), Options); if (CxtRef.JsonInputSynthesisConfig->value().OutputType) { OutputType = CxtRef.JsonInputSynthesisConfig->value().OutputType.value(); } } - auto Result = piper::SynthesisResult{}; + int const Res = piper_synthesize_start(GraphRef.Synth.get(), + TextToSpeak.c_str(), &Options); + if (Res != PIPER_OK) { + spdlog::error("[WASI-NN] Piper backend: piper_synthesize_start failed."sv); + return WASINN::ErrNo::RuntimeError; + } + + std::vector AudioBuffer; + piper_audio_chunk Chunk; + int SampleRate = 0; + constexpr float MaxWavValue = 32767.0f; + + while (piper_synthesize_next(GraphRef.Synth.get(), &Chunk) != PIPER_DONE) { + if (Chunk.num_samples == 0) { + continue; + } + SampleRate = Chunk.sample_rate; + size_t OriginalSize = AudioBuffer.size(); + AudioBuffer.resize(OriginalSize + Chunk.num_samples); + + for (size_t I = 0; I < Chunk.num_samples; I++) { + float Sample = Chunk.samples[I]; + Sample = std::clamp(Sample, -1.0f, 1.0f); + AudioBuffer[OriginalSize + I] = + static_cast(Sample * MaxWavValue); + } + } + + CxtRef.Output.emplace(); + constexpr int DefaultPiperSampleRate = 22050; if (OutputType == SynthesisConfigOutputType::OUTPUT_WAV) { - auto AudioFile = - std::stringstream{std::ios::binary | std::ios::in | std::ios::out}; - piper::textToWavFile(*GraphRef.PiperConfig, *GraphRef.Voice, - CxtRef.Line.value(), AudioFile, Result); - auto String = AudioFile.str(); - CxtRef.Output = std::vector{String.begin(), String.end()}; - } else if (OutputType == SynthesisConfigOutputType::OUTPUT_RAW) { - auto AudioBuffer = std::vector{}; - piper::textToAudio(*GraphRef.PiperConfig, *GraphRef.Voice, - CxtRef.Line.value(), AudioBuffer, Result, nullptr); - CxtRef.Output = std::vector( - sizeof(decltype(AudioBuffer)::value_type) * AudioBuffer.size()); - std::memcpy(CxtRef.Output->data(), AudioBuffer.data(), - CxtRef.Output->size()); - } - - // Restore config (json_input) - updateSynthesisConfig(GraphRef.Config->DefaultSynthesisConfig, - GraphRef.Voice->synthesisConfig, true); + if (SampleRate == 0) { + SampleRate = DefaultPiperSampleRate; + } + size_t const TotalSize = 44 + (AudioBuffer.size() * sizeof(int16_t)); + CxtRef.Output->reserve(TotalSize); + + writeWavHeader(SampleRate, 1, static_cast(AudioBuffer.size()), + *CxtRef.Output); + + const uint8_t *RawData = + reinterpret_cast(AudioBuffer.data()); + CxtRef.Output->insert(CxtRef.Output->end(), RawData, + RawData + (AudioBuffer.size() * sizeof(int16_t))); + + } else { + size_t const TotalSize = AudioBuffer.size() * sizeof(int16_t); + CxtRef.Output->resize(TotalSize); + const uint8_t *RawData = + reinterpret_cast(AudioBuffer.data()); + std::copy_n(RawData, TotalSize, CxtRef.Output->data()); + } + return WASINN::ErrNo::Success; } #else diff --git a/plugins/wasi_nn/wasinn_piper.h b/plugins/wasi_nn/wasinn_piper.h index 1dec1455..900abef5 100644 --- a/plugins/wasi_nn/wasinn_piper.h +++ b/plugins/wasi_nn/wasinn_piper.h @@ -3,14 +3,14 @@ #pragma once -#include "wasinntypes.h" - #include "plugin/plugin.h" +#include "wasinntypes.h" #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER -#include +#include #include +#include #include #include #include @@ -19,7 +19,7 @@ namespace WasmEdge::Host::WASINN { struct WasiNNEnvironment; -} +} // namespace WasmEdge::Host::WASINN namespace WasmEdge::Host::WASINN::Piper { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER @@ -30,7 +30,7 @@ struct SynthesisConfig { std::optional OutputType; // Numerical id of the default speaker (multi-speaker voices) - std::optional SpeakerId; + std::optional SpeakerId; // Amount of noise to add during audio generation std::optional NoiseScale; @@ -41,12 +41,10 @@ struct SynthesisConfig { // Variation in phoneme lengths std::optional NoiseW; - // Seconds of silence to add after each sentence - std::optional SentenceSilenceSeconds; - - // Seconds of extra silence to insert after a single phoneme - std::optional> PhonemeSilenceSeconds; + // NOTE: Phoneme/Sentence silence configuration is not exposed + // in the new upstream C API (piper.h) and has been removed. }; + struct RunConfig { // Path to .onnx voice file std::filesystem::path ModelPath; @@ -57,25 +55,29 @@ struct RunConfig { // Path to espeak-ng data directory std::optional ESpeakDataPath; - // Path to libtashkeel ort model - // https://github.com/mush42/libtashkeel/ - std::optional TashkeelModelPath; - // input is JSON with format: // { // "text": str, (required) // "speaker_id": int, (optional) - // "speaker": str, (optional) + // "output_type": str, (optional, "wav" or "raw") // } // including options in SynthesisConfig bool JsonInput = false; SynthesisConfig DefaultSynthesisConfig; }; + +// Custom deleter for the piper_synthesizer +struct PiperDeleter { + void operator()(piper_synthesizer *P) const { + if (P) + piper_free(P); + } +}; + struct Graph { std::unique_ptr Config; - std::unique_ptr PiperConfig; - std::unique_ptr Voice; + std::unique_ptr Synth; }; struct Context { Context(uint32_t GId, Graph &) noexcept : GraphId(GId) {} diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index d39dfdb6..0a7512e9 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -88,24 +88,58 @@ foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) elseif(BACKEND STREQUAL "piper") message(STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures") download( - https://github.com/rhasspy/piper/raw/master/etc/test_voice.onnx + https://github.com/OHF-Voice/piper1-gpl/raw/v1.3.0/tests/test_voice.onnx ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures/test_voice.onnx - SHA256=937682595755bbb3ee9f131b8a4b2b1ba2fac9b26431fcd7aa48cff0f7382838 + SHA256=1c8bbb420741358f0a356bb83eaae1b4161fbb5974f6941e10eb5a1725d78994 ) download( - https://github.com/rhasspy/piper/raw/master/etc/test_voice.onnx.json + https://github.com/OHF-Voice/piper1-gpl/raw/v1.3.0/tests/test_voice.onnx.json ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures/test_voice.onnx.json - SHA256=f3e0b906861cc2fb8a50e12ceca263afe226ff9688f60e9d4ef943d4f047a513 + SHA256=ccd28e02c334fbcfc94a86c8f86f1d7dbb5bffc844af9f22243a1f9f7840db1b ) - download( - https://github.com/rhasspy/piper/releases/download/2023.11.14-2/piper_linux_x86_64.tar.gz - ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures/piper_linux_x86_64.tar.gz - SHA256=a50cb45f355b7af1f6d758c1b360717877ba0a398cc8cbe6d2a7a3a26e225992 - ) - file(ARCHIVE_EXTRACT - INPUT ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures/piper_linux_x86_64.tar.gz - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures - PATTERNS piper/espeak-ng-data + set(ESPEAK_SOURCE_DIR "") + + if (DEFINED PIPER_ROOT) + set(ESPEAK_SOURCE_DIR "${PIPER_ROOT}/espeak-ng-data") + elseif(EXISTS "${CMAKE_BINARY_DIR}/espeak_ng-install/share/espeak-ng-data") + set(ESPEAK_SOURCE_DIR "${CMAKE_BINARY_DIR}/espeak_ng-install/share/espeak-ng-data") + endif() + + if (EXISTS "${ESPEAK_SOURCE_DIR}") + file( + COPY ${ESPEAK_SOURCE_DIR} + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/wasinn_piper_fixtures + ) + else() + message(WARNING "Could not find espeak-ng-data at ${ESPEAK_SOURCE_DIR}") + endif() + if(DEFINED PIPER_ROOT) + find_library(ESPEAK_NG_LIB + NAMES espeak-ng libespeak-ng + PATHS /usr/local/lib /usr/local/lib64 + NO_DEFAULT_PATH + ) + if (NOT ESPEAK_NG_LIB) + find_library(ESPEAK_NG_LIB NAMES espeak-ng libespeak-ng) + endif() + find_library(UCD_LIB + NAMES ucd libucd + PATHS /usr/local/lib /usr/local/lib64 + NO_DEFAULT_PATH + ) + if (NOT UCD_LIB) + find_library(UCD_LIB NAMES ucd libucd) + endif() + set(ESPEAK_TARGETS ${ESPEAK_NG_LIB} ${UCD_LIB}) + + else() + set(ESPEAK_TARGETS "") + endif() + + target_link_libraries(wasiNNTests + PRIVATE + onnxruntime + ${ESPEAK_TARGETS} ) elseif(BACKEND STREQUAL "whisper") message( STATUS "Download ML artifacts to ${CMAKE_CURRENT_BINARY_DIR}/wasinn_whisper_fixtures") diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index 6e3b00c8..ad580eb3 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -2519,8 +2519,7 @@ TEST(WasiNNTest, PiperBackend) { // First json input with parameters overridden Text = "{\"text\": \"This is a test.\", \"noise_scale\": 0.0, " - "\"length_scale\": 2.0, \"noise_w\": 0.0, \"sentence_silence\": 1.0, " - "\"phoneme_silence\": {\"t\": 0.0}}"; + "\"length_scale\": 2.0, \"noise_w\": 0.0}"; TensorData = {Text.begin(), Text.end()}; SetInputEntryPtr = BuilderPtr; writeFatPointer(MemInst, StorePtr, TensorDim.size(), BuilderPtr); @@ -2563,8 +2562,8 @@ TEST(WasiNNTest, PiperBackend) { Errno)); EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); auto BytesWritten = *MemInst.getPointer(BuilderPtr); - // Should output more than 50000 bytes. - EXPECT_GE(BytesWritten, 50000); + // Should output more than 40000 bytes. + EXPECT_GE(BytesWritten, 40000); } // Second json input to check if one-time overriding is working properly @@ -2613,9 +2612,9 @@ TEST(WasiNNTest, PiperBackend) { EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::Success)); auto BytesWritten = *MemInst.getPointer(BuilderPtr); EXPECT_GE(BytesWritten, 30000); - // Should output less than 40000 bytes. - EXPECT_LT(BytesWritten, 40000); - EXPECT_EQ(BytesWritten, 34048); + // Should output less than 50000 bytes. + EXPECT_LT(BytesWritten, 50000); + EXPECT_EQ(BytesWritten, 44100); } } #endif // WASMEDGE_PLUGIN_WASI_NN_BACKEND_PIPER diff --git a/utils/build_libpiper.sh b/utils/build_libpiper.sh new file mode 100644 index 00000000..05f4ab50 --- /dev/null +++ b/utils/build_libpiper.sh @@ -0,0 +1,18 @@ +#!/bin/bash +set -e + +echo "::group::Build and install libpiper" +rm -rf piper-source + +git clone https://github.com/OHF-Voice/piper1-gpl piper-source +cd piper-source/libpiper +git checkout 32b95f8c1f0dc0ce27a6acd1143de331f61af777 +cmake -Bbuild-deps \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX="$PWD/install" \ + -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON + +cmake --build build-deps +cmake --install build-deps +echo "::endgroup::" From 7952a72372b9f2e7aafce104095411035a5a5daf Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Thu, 26 Feb 2026 14:57:28 +0800 Subject: [PATCH 608/623] fix(plugin/process): correct timeout unit conversion (#4678) * fix(plugin): correct timeout unit conversion in process plugin The timeout calculation divided microseconds by 1000000 (converting to seconds) instead of 1000 (converting to milliseconds). Since the seconds part is already multiplied by 1000 to convert to milliseconds, the microseconds part should also be in milliseconds. This gave the timeout only second-level granularity instead of the intended millisecond-level granularity, making sub-second timeouts effectively zero. Signed-off-by: Yi LIU * test(plugin): add process plugin timeout precision test Signed-off-by: Yi LIU --------- Signed-off-by: Yi LIU --- plugins/wasmedge_process/processfunc.cpp | 2 +- .../wasmedge_process/wasmedge_process.cpp | 45 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/plugins/wasmedge_process/processfunc.cpp b/plugins/wasmedge_process/processfunc.cpp index d8eee10a..5b9f3026 100644 --- a/plugins/wasmedge_process/processfunc.cpp +++ b/plugins/wasmedge_process/processfunc.cpp @@ -228,7 +228,7 @@ Expect WasmEdgeProcessRun::body(const Runtime::CallingFrame &) { while (true) { gettimeofday(&TCurr, NULL); if ((TCurr.tv_sec - TStart.tv_sec) * 1000U + - (TCurr.tv_usec - TStart.tv_usec) / 1000000U > + (TCurr.tv_usec - TStart.tv_usec) / 1000U > Env.TimeOut) { // Over timeout. Interrupt child process. kill(PID, SIGKILL); diff --git a/test/plugins/wasmedge_process/wasmedge_process.cpp b/test/plugins/wasmedge_process/wasmedge_process.cpp index d74e245f..6791c9d8 100644 --- a/test/plugins/wasmedge_process/wasmedge_process.cpp +++ b/test/plugins/wasmedge_process/wasmedge_process.cpp @@ -8,6 +8,8 @@ #include #include +#include +#include #include #include #include @@ -376,6 +378,49 @@ TEST(WasmEdgeProcessTest, Run) { ProcMod->getEnv().StdOut.end(), OutStr.begin())); } +TEST(WasmEdgeProcessTest, TimeoutPrecision) { + // Create the wasmedge_process module instance. + auto ProcMod = createModule(); + ASSERT_TRUE(ProcMod); + + // Get the function "wasmedge_process_run". + auto *FuncInst = ProcMod->findFuncExports("wasmedge_process_run"); + EXPECT_NE(FuncInst, nullptr); + EXPECT_TRUE(FuncInst->isHostFunction()); + auto &HostFuncRun = dynamic_cast( + FuncInst->getHostFunc()); + + // Return value. + std::array RetVal; + + // Run "sleep 2" with a 500ms timeout. + // With the fix, the process should be killed after ~500ms. + // With the old bug (/1000000U instead of /1000U), the microsecond + // component was effectively zero, so timeout only fired at whole-second + // boundaries (i.e., at ~1000ms instead of ~500ms). + ProcMod->getEnv().AllowedAll = true; + ProcMod->getEnv().Name = "sleep"; + ProcMod->getEnv().Args.push_back("2"); + ProcMod->getEnv().TimeOut = 500; + + auto Start = std::chrono::steady_clock::now(); + EXPECT_TRUE(HostFuncRun.run(DummyCallFrame, {}, RetVal)); + auto End = std::chrono::steady_clock::now(); + auto ElapsedMs = + std::chrono::duration_cast(End - Start) + .count(); + + // The process should have been killed due to timeout. + EXPECT_EQ(RetVal[0].get(), static_cast(ETIMEDOUT)); + + // With the fix, elapsed time should be close to 500ms (the timeout value). + // Allow generous margins for CI/scheduling variance, but the key assertion + // is that it finishes well before the 2-second sleep completes and does not + // overshoot to ~1000ms (the old buggy whole-second boundary). + EXPECT_GE(ElapsedMs, 400); + EXPECT_LE(ElapsedMs, 900); +} + TEST(WasmEdgeProcessTest, GetExitCode) { // Create the wasmedge_process module instance. auto ProcMod = createModule(); From 2efbc2b14eea90e887e755a8ae9b8073a9a2108f Mon Sep 17 00:00:00 2001 From: Yi-Ying He Date: Fri, 10 Apr 2026 16:38:50 +0800 Subject: [PATCH 609/623] fix(test): disable LTO in tensorflow plugin tests of manylinux (#4766) Signed-off-by: YiYing He --- test/plugins/wasmedge_tensorflow/CMakeLists.txt | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/plugins/wasmedge_tensorflow/CMakeLists.txt b/test/plugins/wasmedge_tensorflow/CMakeLists.txt index 9c0f6823..2104f3b3 100644 --- a/test/plugins/wasmedge_tensorflow/CMakeLists.txt +++ b/test/plugins/wasmedge_tensorflow/CMakeLists.txt @@ -5,6 +5,15 @@ wasmedge_add_executable(wasmedgeTensorflowTests wasmedge_tensorflow.cpp ) +# On manylinux_2_28 with gcc-toolset and CXX11_ABI=0, LTO can cause unresolved +# internal libstdc++ symbols from libstdc++_nonshared.a (cow-fs_path.o) when +# linking against the TensorFlow C++ shared libraries. +if(CMAKE_COMPILER_IS_GNUCXX AND NOT WASMEDGE_USE_CXX11_ABI) + set_target_properties(wasmedgeTensorflowTests PROPERTIES + INTERPROCEDURAL_OPTIMIZATION OFF + ) +endif() + add_dependencies(wasmedgeTensorflowTests wasmedgePluginWasmEdgeTensorflow ) From 182fb0f101609c4c99a456283fd148b6b1802083 Mon Sep 17 00:00:00 2001 From: grorge Date: Sat, 11 Apr 2026 16:51:25 +0800 Subject: [PATCH 610/623] feat(ggml): upgrade version to b8757 Assisted-by: Codex (OpenAI) Signed-off-by: grorge --- plugins/wasi_nn/GGML/core/ggml_core.cpp | 18 ++++++----- .../wasi_nn/GGML/metadata/metadata_parser.cpp | 31 +++++++++++++------ test/plugins/wasi_nn/CMakeLists.txt | 7 +++++ 3 files changed, 40 insertions(+), 16 deletions(-) diff --git a/plugins/wasi_nn/GGML/core/ggml_core.cpp b/plugins/wasi_nn/GGML/core/ggml_core.cpp index 28f7a323..ca2b519f 100644 --- a/plugins/wasi_nn/GGML/core/ggml_core.cpp +++ b/plugins/wasi_nn/GGML/core/ggml_core.cpp @@ -87,7 +87,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, GraphRef.Conf.StreamStdout = false; GraphRef.Conf.EmbdNormalize = static_cast(CommonParamsDefault.embd_normalize); - GraphRef.Conf.NPredict = GraphRef.Params.n_ctx; + GraphRef.Conf.NPredict = GraphRef.Params.n_predict; GraphRef.Conf.ReversePrompt = ""sv; GraphRef.Conf.ImagePath = ""sv; @@ -164,13 +164,15 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, llama_numa_init(Params.numa); // Initialize the llama model and context. - common_init_result LlamaInit = common_init_from_params(Params); - GraphRef.LlamaModel = std::move(LlamaInit.model); - GraphRef.LlamaContext = std::move(LlamaInit.context); + llama_model_params ModelParams = common_model_params_to_llama(Params); + GraphRef.LlamaModel = llama_model_ptr( + llama_model_load_from_file(Params.model.path.c_str(), ModelParams)); if (GraphRef.LlamaModel == nullptr) { Env.deleteGraph(GId.raw()); RET_ERROR(ErrNo::InvalidArgument, "load: unable to init model."sv) } + GraphRef.LlamaContext = llama_context_ptr(llama_init_from_model( + GraphRef.LlamaModel.get(), common_context_params_to_llama(Params))); if (GraphRef.LlamaContext == nullptr) { Env.deleteGraph(GId.raw()); RET_ERROR(ErrNo::InvalidArgument, "load: unable to init context."sv) @@ -183,13 +185,15 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, LOG_DEBUG(GraphRef.EnableDebugLog, "load: initialize TTS model."sv) Params.model = GraphRef.Params.vocoder.model; Params.embedding = true; - common_init_result TTSInit = common_init_from_params(Params); - GraphRef.TTSModel = std::move(TTSInit.model); - GraphRef.TTSContext = std::move(TTSInit.context); + llama_model_params TTSModelParams = common_model_params_to_llama(Params); + GraphRef.TTSModel = llama_model_ptr( + llama_model_load_from_file(Params.model.path.c_str(), TTSModelParams)); if (GraphRef.TTSModel == nullptr) { Env.deleteGraph(GId.raw()); RET_ERROR(ErrNo::InvalidArgument, "load: unable to init TTS model."sv) } + GraphRef.TTSContext = llama_context_ptr(llama_init_from_model( + GraphRef.TTSModel.get(), common_context_params_to_llama(Params))); if (GraphRef.TTSContext == nullptr) { Env.deleteGraph(GId.raw()); RET_ERROR(ErrNo::InvalidArgument, "load: unable to init TTS context."sv) diff --git a/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp b/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp index ac9af83f..02e643c4 100644 --- a/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp +++ b/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp @@ -33,7 +33,8 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, double PrevRepeatPenalty = GraphRef.Params.sampling.penalty_repeat; double PrevPresencePenalty = GraphRef.Params.sampling.penalty_present; double PrevFrequencyPenalty = GraphRef.Params.sampling.penalty_freq; - std::string PrevGrammar = GraphRef.Params.sampling.grammar; + std::string PrevGrammar = + common_grammar_value(GraphRef.Params.sampling.grammar); uint32_t PrevSeed = GraphRef.Params.sampling.seed; try { @@ -259,13 +260,19 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, parseJsonAuto(Doc, "timing-per-token", GraphRef.Params.sampling.timing_per_token); - parseJsonWithCastAuto(Doc, "grammar", - GraphRef.Params.sampling.grammar); + parseJsonWithProcessorAuto( + Doc, "grammar", [&GraphRef](const std::string_view &Grammar) -> bool { + GraphRef.Params.sampling.grammar = + common_grammar(COMMON_GRAMMAR_TYPE_USER, std::string(Grammar)); + return true; + }); parseJsonWithProcessorAuto( Doc, "json-schema", [&GraphRef](const std::string_view &JsonSchema) -> bool { GraphRef.Params.sampling.grammar = - json_schema_to_grammar(nlohmann::ordered_json::parse(JsonSchema)); + common_grammar(COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT, + json_schema_to_grammar( + nlohmann::ordered_json::parse(JsonSchema))); return true; }); parseJsonWithCastAuto(Doc, "seed", GraphRef.Params.sampling.seed); @@ -299,8 +306,12 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, parseJsonWithCastAuto(Doc, "image", ConfRef.ImagePath); parseJsonAuto(Doc, "always-regenerate-image-embd", ConfRef.AlwaysRegenerateImageEmbd); - parseJsonWithCastAuto(Doc, "model-alias", - GraphRef.Params.model_alias); + parseJsonWithProcessorAuto( + Doc, "model-alias", + [&GraphRef](const std::string_view &ModelAlias) -> bool { + GraphRef.Params.model_alias.emplace(ModelAlias); + return true; + }); parseJsonWithCastAuto(Doc, "model-url", GraphRef.Params.model.url); parseJsonWithCastAuto(Doc, "hf-token", @@ -318,9 +329,11 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, parseJsonWithCastAuto(Doc, "input-suffix", GraphRef.Params.input_suffix); parseJsonWithCastAuto( - Doc, "lookup-cache-static", GraphRef.Params.lookup_cache_static); + Doc, "lookup-cache-static", + GraphRef.Params.speculative.lookup_cache_static); parseJsonWithCastAuto( - Doc, "lookup-cache-dynamic", GraphRef.Params.lookup_cache_dynamic); + Doc, "lookup-cache-dynamic", + GraphRef.Params.speculative.lookup_cache_dynamic); parseJsonWithCastAuto(Doc, "logits-file", GraphRef.Params.logits_file); parseJsonAuto(Doc, "lora-init-without-apply", @@ -540,7 +553,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, PrevRepeatPenalty != GraphRef.Params.sampling.penalty_repeat || PrevPresencePenalty != GraphRef.Params.sampling.penalty_present || PrevFrequencyPenalty != GraphRef.Params.sampling.penalty_freq || - PrevGrammar != GraphRef.Params.sampling.grammar || + PrevGrammar != common_grammar_value(GraphRef.Params.sampling.grammar) || PrevSeed != GraphRef.Params.sampling.seed)) { *IsSamplerUpdated = true; } diff --git a/test/plugins/wasi_nn/CMakeLists.txt b/test/plugins/wasi_nn/CMakeLists.txt index 0a7512e9..23cb850d 100644 --- a/test/plugins/wasi_nn/CMakeLists.txt +++ b/test/plugins/wasi_nn/CMakeLists.txt @@ -204,6 +204,13 @@ target_link_libraries(wasiNNTests ${GTEST_BOTH_LIBRARIES} ) +if(TARGET build_info) + target_link_libraries(wasiNNTests + PRIVATE + build_info + ) +endif() + # Link to the WasmEdge library if(WASMEDGE_LINK_PLUGINS_STATIC) target_link_libraries(wasiNNTests From cef09caddd411b9e103e67ba1279054714622589 Mon Sep 17 00:00:00 2001 From: Han-Wen Tsao Date: Sat, 18 Apr 2026 11:43:35 +0800 Subject: [PATCH 611/623] feat(ggml): redirect multi model log (#4801) Signed-off-by: grorge --- plugins/wasi_nn/GGML/core/ggml_core.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/wasi_nn/GGML/core/ggml_core.cpp b/plugins/wasi_nn/GGML/core/ggml_core.cpp index ca2b519f..2815f1cb 100644 --- a/plugins/wasi_nn/GGML/core/ggml_core.cpp +++ b/plugins/wasi_nn/GGML/core/ggml_core.cpp @@ -93,6 +93,7 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Set llama log callback. llama_log_set(llamaLogCallback, &GraphRef); + mtmd_helper_log_set(llamaLogCallback, &GraphRef); // If the graph builder length > 1, the data of builder[1] is the metadata. if (Builders.size() > 1) { From 85a6cffdf49e805cf848adea73a2edc9d67d8435 Mon Sep 17 00:00:00 2001 From: LuaLighter Date: Fri, 24 Apr 2026 20:37:05 +0800 Subject: [PATCH 612/623] fix(typo): some typo in comments (#4810) chore: fix some minor issues Signed-off-by: purelualight --- plugins/wasm_bpf/wasm-bpf.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/wasm_bpf/wasm-bpf.cpp b/plugins/wasm_bpf/wasm-bpf.cpp index 490dbbdd..1dd9b25f 100644 --- a/plugins/wasm_bpf/wasm-bpf.cpp +++ b/plugins/wasm_bpf/wasm-bpf.cpp @@ -173,7 +173,7 @@ std::unique_ptr bpf_buffer__new(bpf_map *events) { int32_t wasm_bpf_program::bpf_map_fd_by_name(const char *name) { return bpf_object__find_map_fd_by_name(obj.get(), name); } -/// \brief load all bpf programs and maps in a object file. +/// \brief load all bpf programs and maps in an object file. int32_t wasm_bpf_program::load_bpf_object(const void *obj_buf, size_t obj_buf_sz) { auto object = bpf_object__open_mem(obj_buf, obj_buf_sz, nullptr); From b40f20b41e0dddfef84fa0c1f87a8c0551d9fc16 Mon Sep 17 00:00:00 2001 From: hydai Date: Fri, 8 May 2026 10:55:06 +0800 Subject: [PATCH 613/623] chore: fix comment grammar sweep (#4850) Assisted-by: OpenAI Codex (GPT-5.5) Signed-off-by: hydai Co-authored-by: YiYing He --- plugins/CMakeLists.txt | 16 +++---- plugins/wasi_crypto/asymmetric_common/ecdsa.h | 2 +- .../wasi_crypto/asymmetric_common/publickey.h | 2 +- plugins/wasi_crypto/common/array_output.h | 9 ++-- plugins/wasi_crypto/common/options.h | 2 +- plugins/wasi_crypto/kx/dh/ecdsa.h | 2 +- plugins/wasi_crypto/kx/dh/x25519.h | 2 +- plugins/wasi_crypto/signatures/ecdsa.cpp | 6 +-- plugins/wasi_crypto/signatures/ecdsa.h | 2 +- plugins/wasi_crypto/signatures/eddsa.cpp | 4 +- plugins/wasi_crypto/signatures/eddsa.h | 2 +- plugins/wasi_crypto/signatures/rsa.h | 2 +- plugins/wasi_crypto/signatures/signstate.h | 2 +- plugins/wasi_crypto/symmetric/aeads.h | 17 ++++--- plugins/wasi_crypto/symmetric/hash.h | 12 ++--- plugins/wasi_crypto/symmetric/kdf.h | 12 ++--- plugins/wasi_crypto/symmetric/mac.h | 8 ++-- plugins/wasi_crypto/symmetric/state.cpp | 2 +- plugins/wasi_crypto/symmetric/state.h | 4 +- plugins/wasi_crypto/symmetric/tag.h | 4 +- plugins/wasi_crypto/utils/evp_wrapper.h | 17 +++---- plugins/wasi_crypto/utils/handles_manager.h | 23 +++++----- plugins/wasi_crypto/utils/optional.h | 2 +- plugins/wasi_nn/CMakeLists.txt | 4 +- .../GGML/compute/inference_manager.cpp | 10 ++--- plugins/wasi_nn/GGML/core/ggml_core.cpp | 3 +- plugins/wasi_nn/GGML/core/ggml_type.h | 16 +++---- plugins/wasi_nn/GGML/core/input_processor.cpp | 32 +++++++------- .../wasi_nn/GGML/core/output_generator.cpp | 4 +- .../wasi_nn/GGML/metadata/metadata_parser.cpp | 4 +- plugins/wasi_nn/GGML/tts/tts_core.cpp | 6 +-- plugins/wasi_nn/GGML/utils.cpp | 2 +- plugins/wasi_nn/MLX/mlx/convolution.h | 2 +- .../wasi_nn/MLX/model/whisper/decoding.cpp | 4 +- .../wasi_nn/MLX/model/whisper_transcribe.cpp | 8 ++-- plugins/wasi_nn/wasinn_bitnet.cpp | 28 ++++++------ plugins/wasi_nn/wasinn_openvino_genai.cpp | 2 +- plugins/wasi_nn/wasinn_piper.cpp | 2 +- plugins/wasi_nn/wasinn_piper.h | 2 +- plugins/wasi_nn/wasinn_whisper.cpp | 6 +-- plugins/wasi_nn/wasinn_whisper.h | 4 +- plugins/wasi_nn/wasinnenv.h | 13 +++--- plugins/wasi_nn/wasinnfunc.cpp | 10 ++--- plugins/wasm_bpf/CMakeLists.txt | 5 ++- plugins/wasm_bpf/README.md | 19 ++++---- plugins/wasm_bpf/bpf-api.h | 32 +++++++------- plugins/wasm_bpf/func-attach-bpf-program.h | 2 +- plugins/wasm_bpf/func-bpf-buffer-poll.h | 14 +++--- plugins/wasm_bpf/func-bpf-map-fd-by-name.h | 4 +- plugins/wasm_bpf/func-bpf-map-operate.h | 4 +- plugins/wasm_bpf/func-close-bpf-object.h | 4 +- plugins/wasm_bpf/func-load-bpf-object.h | 10 ++--- plugins/wasm_bpf/state.h | 2 +- plugins/wasm_bpf/util.h | 4 +- plugins/wasm_bpf/wasm-bpf.cpp | 24 +++++----- .../wasmedge_ffmpeg/avutil/avDictionary.cpp | 4 +- plugins/wasmedge_ffmpeg/avutil/avFrame.cpp | 2 +- plugins/wasmedge_ffmpeg/avutil/avutil_func.h | 2 +- plugins/wasmedge_ocr/ocr_env.h | 4 +- plugins/wasmedge_opencvmini/opencvmini_func.h | 2 +- plugins/wasmedge_process/processenv.h | 2 +- plugins/wasmedge_process/processfunc.cpp | 6 +-- plugins/wasmedge_stablediffusion/sd_func.cpp | 4 +- .../wasmedge_tensorflow/tensorflow_func.cpp | 4 +- plugins/wasmedge_zlib/zlibenv.h | 2 +- plugins/wasmedge_zlib/zlibfunc.cpp | 4 +- plugins/wasmedge_zlib/zlibfunc.h | 2 +- test/plugins/CMakeLists.txt | 14 +++--- test/plugins/unittest/testplugin.c | 4 +- test/plugins/wasi_crypto/helper.h | 2 +- test/plugins/wasi_logging/wasi_logging.cpp | 2 +- test/plugins/wasi_nn/wasi_nn.cpp | 44 +++++++++---------- test/plugins/wasm_bpf/assets/README.md | 9 ++-- test/plugins/wasm_bpf/simple_map_test.cpp | 30 ++++++------- test/plugins/wasm_bpf/simple_ringbuf_test.cpp | 18 ++++---- test/plugins/wasm_bpf/wasm_bpf.cpp | 36 +++++++-------- .../wasmedge_ffmpeg/avcodec/avcodec_func.cpp | 4 +- .../avformat/avformat_func.cpp | 12 ++--- .../wasmedge_ffmpeg/avutil/avDictionary.cpp | 4 +- .../wasmedge_ffmpeg/swscale/swscale_func.cpp | 19 ++++---- test/plugins/wasmedge_ffmpeg/utils.h | 2 +- .../wasmedge_opencvmini.cpp | 2 +- .../wasmedge_process/wasmedge_process.cpp | 8 ++-- .../wasmedge_stablediffusion.cpp | 2 +- .../wasmedge_tensorflow.cpp | 2 +- .../wasmedge_tensorflowlite.cpp | 2 +- utils/docker/Dockerfile.ubuntu2104_armv7l | 2 +- utils/ffmpeg/download-ffmpeg-sample-video.sh | 2 +- utils/wasi-crypto/build-openssl.sh | 2 +- 89 files changed, 357 insertions(+), 343 deletions(-) diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt index b4d89734..69a0ec9d 100644 --- a/plugins/CMakeLists.txt +++ b/plugins/CMakeLists.txt @@ -14,7 +14,7 @@ endif() # WASI plug-in: WASI-Logging proposal. if(WASMEDGE_PLUGIN_WASI_LOGGING) # BUILTIN-PLUGIN: Add the wasi-logging plugin here after the new plugin - # architecture ready in 0.15.0. + # architecture is ready in 0.15.0. endif() # WASI plug-in: WASI-NN proposal with backends. @@ -29,7 +29,7 @@ endif() # WasmEdge plug-in: wasm-bpf. if(WASMEDGE_PLUGIN_WASM_BPF) - # Only Linux systems support wasm_bpf now. + # wasm_bpf is currently supported only on Linux systems. if(CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasm_bpf) else() @@ -44,7 +44,7 @@ endif() # WasmEdge plug-in: Image. if(WASMEDGE_PLUGIN_IMAGE) - # Only Linux and MacOS support wasmedge_image now. + # wasmedge_image is currently supported only on Linux and macOS. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_image) else() @@ -64,7 +64,7 @@ endif() # WasmEdge plug-in: OpenCV-mini. if(WASMEDGE_PLUGIN_OPENCVMINI) - # Only Linux and MacOS support wasmedge_opencvmini now. + # wasmedge_opencvmini is currently supported only on Linux and macOS. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_opencvmini) else() @@ -74,7 +74,7 @@ endif() # WasmEdge plug-in: Process. if(WASMEDGE_PLUGIN_PROCESS) - # Only Linux systems support wasmedge_process now. + # wasmedge_process is currently supported only on Linux systems. if(CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasmedge_process) else() @@ -84,7 +84,7 @@ endif() # WasmEdge plug-in: Stable-diffusion. if(WASMEDGE_PLUGIN_STABLEDIFFUSION) - # Only Linux and MacOS support wasmedge_stablediffusion now. + # wasmedge_stablediffusion is currently supported only on Linux and macOS. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_stablediffusion) else() @@ -94,7 +94,7 @@ endif() # WasmEdge plug-in: TensorFlow. if(WASMEDGE_PLUGIN_TENSORFLOW) - # Only Linux and MacOS support wasmedge_tensorflow now. + # wasmedge_tensorflow is currently supported only on Linux and macOS. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_tensorflow) else() @@ -104,7 +104,7 @@ endif() # WasmEdge plug-in: TensorFlow-Lite. if(WASMEDGE_PLUGIN_TENSORFLOWLITE) - # Only Linux and MacOS support wasmedge_tensorflowlite now. + # wasmedge_tensorflowlite is currently supported only on Linux and macOS. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_tensorflowlite) else() diff --git a/plugins/wasi_crypto/asymmetric_common/ecdsa.h b/plugins/wasi_crypto/asymmetric_common/ecdsa.h index ee351b8a..85990ca3 100644 --- a/plugins/wasi_crypto/asymmetric_common/ecdsa.h +++ b/plugins/wasi_crypto/asymmetric_common/ecdsa.h @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file contains the definition of ecdsa algorithm. +/// This file contains the definition of the ECDSA algorithm. /// //===----------------------------------------------------------------------===// diff --git a/plugins/wasi_crypto/asymmetric_common/publickey.h b/plugins/wasi_crypto/asymmetric_common/publickey.h index 5e50efdd..6148dcbb 100644 --- a/plugins/wasi_crypto/asymmetric_common/publickey.h +++ b/plugins/wasi_crypto/asymmetric_common/publickey.h @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file contains the asymmetric common PubicKey of wasi-crypto. +/// This file contains the asymmetric common PublicKey of wasi-crypto. /// //===----------------------------------------------------------------------===// #pragma once diff --git a/plugins/wasi_crypto/common/array_output.h b/plugins/wasi_crypto/common/array_output.h index 1c00b752..4ca43e27 100644 --- a/plugins/wasi_crypto/common/array_output.h +++ b/plugins/wasi_crypto/common/array_output.h @@ -31,7 +31,7 @@ namespace Common { /// Functions returning arrays whose size is not constant or too large to be /// safely allocated on the stack return a handle to an ArrayOutput type. /// -/// More detail: +/// More details: /// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#array-outputs class ArrayOutput { public: @@ -44,11 +44,12 @@ class ArrayOutput { ArrayOutput(SecretVec &&Data) noexcept : Data(std::move(Data)) {} - /// Copy the content to the @param Buf buffer. - /// Multiple calls are possible, the total number of bytes to be read is + /// Copy the contents to the @param Buf buffer. + /// Multiple calls are possible, and the total number of bytes to be read is /// guaranteed to always match the data size. /// - /// @returns the number of bytes read. If all pull, return true. + /// @returns a tuple of `{bytes read, all-pulled flag}`. The flag is true if + /// all data has been pulled. std::tuple pull(Span Buf) noexcept; /// Return ArrayOutput data size. diff --git a/plugins/wasi_crypto/common/options.h b/plugins/wasi_crypto/common/options.h index e07578c1..2225d513 100644 --- a/plugins/wasi_crypto/common/options.h +++ b/plugins/wasi_crypto/common/options.h @@ -32,7 +32,7 @@ namespace Common { /// Some functions support options. For example, options can be used to access /// the features that are only relevant to specific ciphers and hash functions. /// -/// Options are represented as a (key, value) map, keys are strings. They are +/// Options are represented as a (key, value) map with string keys. They are /// attached to a context, such as a cipher state. Applications can set, and /// also read the value associated with a key in order to either get the default /// value or obtain the runtime information. diff --git a/plugins/wasi_crypto/kx/dh/ecdsa.h b/plugins/wasi_crypto/kx/dh/ecdsa.h index f41cd63b..ef3648db 100644 --- a/plugins/wasi_crypto/kx/dh/ecdsa.h +++ b/plugins/wasi_crypto/kx/dh/ecdsa.h @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file contains the definition of ecdsa algorithm. +/// This file contains the definition of the ECDSA algorithm. /// //===----------------------------------------------------------------------===// diff --git a/plugins/wasi_crypto/kx/dh/x25519.h b/plugins/wasi_crypto/kx/dh/x25519.h index 806d6dd7..ad3de81a 100644 --- a/plugins/wasi_crypto/kx/dh/x25519.h +++ b/plugins/wasi_crypto/kx/dh/x25519.h @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file contains the definition of x25519 algorithm. +/// This file contains the definition of the X25519 algorithm. /// //===----------------------------------------------------------------------===// diff --git a/plugins/wasi_crypto/signatures/ecdsa.cpp b/plugins/wasi_crypto/signatures/ecdsa.cpp index e7b2162e..79c53f1d 100644 --- a/plugins/wasi_crypto/signatures/ecdsa.cpp +++ b/plugins/wasi_crypto/signatures/ecdsa.cpp @@ -84,10 +84,10 @@ template WasiCryptoExpect::Signature> Ecdsa::SignState::sign() noexcept { size_t Size; - // For ecdsa, OpenSSL produce a der format signatures which means the size is - // not fixed. Here is an answer talk about it: + // For ECDSA, OpenSSL produces DER-formatted signatures, which means the size + // is not fixed. For more context, see: // https://bitcoin.stackexchange.com/questions/77191/what-is-the-maximum-size-of-a-der-encoded-ecdsa-signature - // So instead of fixing size, just read. + // So instead of fixing the size, just read it. std::scoped_lock Lock{Ctx->Mutex}; opensslCheck(EVP_DigestSignFinal(Ctx->RawCtx.get(), nullptr, &Size)); diff --git a/plugins/wasi_crypto/signatures/ecdsa.h b/plugins/wasi_crypto/signatures/ecdsa.h index 41c0d27c..f3d6d35f 100644 --- a/plugins/wasi_crypto/signatures/ecdsa.h +++ b/plugins/wasi_crypto/signatures/ecdsa.h @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file contains the definition of ecdsa algorithm. +/// This file contains the definition of the ECDSA algorithm. /// //===----------------------------------------------------------------------===// diff --git a/plugins/wasi_crypto/signatures/eddsa.cpp b/plugins/wasi_crypto/signatures/eddsa.cpp index d1a44f24..08f15260 100644 --- a/plugins/wasi_crypto/signatures/eddsa.cpp +++ b/plugins/wasi_crypto/signatures/eddsa.cpp @@ -201,8 +201,8 @@ WasiCryptoExpect> Eddsa::Signature::exportData( WasiCryptoExpect Eddsa::SignState::update(Span Input) noexcept { - // Notice: Ecdsa is oneshot in OpenSSL, we need a cache for updating instead - // of calling `EVP_DigestSignUpdate`. + // Notice: EdDSA is one-shot in OpenSSL, so we need a cache for updating + // instead of calling `EVP_DigestSignUpdate`. std::scoped_lock Lock{Ctx->Mutex}; Ctx->Data.insert(Ctx->Data.end(), Input.begin(), Input.end()); diff --git a/plugins/wasi_crypto/signatures/eddsa.h b/plugins/wasi_crypto/signatures/eddsa.h index 99674408..3ba516dc 100644 --- a/plugins/wasi_crypto/signatures/eddsa.h +++ b/plugins/wasi_crypto/signatures/eddsa.h @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file contains the definition of eddsa algorithm. +/// This file contains the definition of the EdDSA algorithm. /// //===----------------------------------------------------------------------===// diff --git a/plugins/wasi_crypto/signatures/rsa.h b/plugins/wasi_crypto/signatures/rsa.h index 878324ae..81631080 100644 --- a/plugins/wasi_crypto/signatures/rsa.h +++ b/plugins/wasi_crypto/signatures/rsa.h @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file contains the declaration of rsa and related classes. +/// This file contains the declaration of RSA and related classes. /// //===----------------------------------------------------------------------===// diff --git a/plugins/wasi_crypto/signatures/signstate.h b/plugins/wasi_crypto/signatures/signstate.h index 5ed5f61d..6cc733ae 100644 --- a/plugins/wasi_crypto/signatures/signstate.h +++ b/plugins/wasi_crypto/signatures/signstate.h @@ -25,7 +25,7 @@ namespace Host { namespace WasiCrypto { namespace Signatures { -/// Signatures computation. +/// Signature computation. /// /// More detailed: /// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#signature-creation diff --git a/plugins/wasi_crypto/symmetric/aeads.h b/plugins/wasi_crypto/symmetric/aeads.h index 4e580485..d44123c5 100644 --- a/plugins/wasi_crypto/symmetric/aeads.h +++ b/plugins/wasi_crypto/symmetric/aeads.h @@ -26,7 +26,7 @@ namespace Host { namespace WasiCrypto { namespace Symmetric { -/// Aeads invalid operations, every Aeads state should inherent from this class. +/// Aeads invalid operations. Every Aeads state should inherit from this class. /// /// More detailed: /// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#aeads @@ -77,8 +77,8 @@ template class Cipher { class State : public AEADsState { public: /// There are four inputs for authenticated encryption: - /// @param[in] Key The secret key for encrypt - /// @param[in] OptOption `Must` contain an Nonce (Initialization vector). + /// @param[in] Key The secret key for encryption. + /// @param[in] OptOption `Must` contain a Nonce (initialization vector). static WasiCryptoExpect open(const Key &Key, OptionalRef OptOption) noexcept; @@ -106,7 +106,7 @@ template class Cipher { WasiCryptoExpect maxTagLen() const noexcept; /// Check Out.size() == Data.size() + maxTagLen(), then call - /// encryptUnchecked(Out, Data) or return error if not equal. + /// encryptUnchecked(Out, Data), or return an error if they are not equal. /// /// @param Out The encrypted data text /// @param Data The data to be encrypted @@ -116,7 +116,8 @@ template class Cipher { Span Data) noexcept; /// Check Out.size() == Data.size(), then call - /// encryptDetachedUnchecked(Out, Data) or error if not equal + /// encryptDetachedUnchecked(Out, Data), or return an error if they are not + /// equal. /// /// @param Out The encrypted data text /// @param Data The data to be encrypted @@ -126,7 +127,8 @@ template class Cipher { Span Data) noexcept; /// Check Out.size() = Data.size() + maxTagLen(), then call - /// decryptDetachedUnchecked(Out, Data) or error if not equal + /// decryptDetachedUnchecked(Out, Data), or return an error if they are not + /// equal. /// /// @param Out The decrypted data text /// @param Data The data to be decrypted @@ -136,7 +138,8 @@ template class Cipher { Span Data) noexcept; /// Check Out.size() == Data.size(), then call - /// encryptDetachedUnchecked(Out, Data) or error if not equal + /// encryptDetachedUnchecked(Out, Data), or return an error if they are not + /// equal. /// /// @param Out The decrypted data text /// @param Data The data to be decrypted diff --git a/plugins/wasi_crypto/symmetric/hash.h b/plugins/wasi_crypto/symmetric/hash.h index 59257a42..7baf01b0 100644 --- a/plugins/wasi_crypto/symmetric/hash.h +++ b/plugins/wasi_crypto/symmetric/hash.h @@ -26,8 +26,8 @@ namespace Host { namespace WasiCrypto { namespace Symmetric { -/// Hash never have key, just a placement, every hash key should inherent from -/// this class. +/// Hash never has a key; it is only a placeholder. Every hash key should +/// inherit from this class. /// /// More detailed: /// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#hash-functions @@ -44,16 +44,16 @@ template class HashKey { SecretVec exportData() const noexcept { assumingUnreachable(); } }; -/// Hash invalid operations, every hash state should inherent from this class. +/// Hash invalid operations. Every hash state should inherit from this class. template class HashState { public: - /// Current hash not support any options. + /// The current hash does not support any options. WasiCryptoExpect optionsGet(std::string_view, Span) const noexcept { return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); } - /// Current hash not support any options. + /// The current hash does not support any options. WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); } @@ -97,7 +97,7 @@ template class HashState { template class Sha2 { public: - /// In fact, sha2 key will never produce. This design is for removing the + /// In fact, sha2 keys are never produced. This design removes the /// forwarding declaration. class Key : public HashKey {}; diff --git a/plugins/wasi_crypto/symmetric/kdf.h b/plugins/wasi_crypto/symmetric/kdf.h index 4834f385..e17f77b2 100644 --- a/plugins/wasi_crypto/symmetric/kdf.h +++ b/plugins/wasi_crypto/symmetric/kdf.h @@ -28,20 +28,20 @@ namespace Host { namespace WasiCrypto { namespace Symmetric { -/// Expand invalid operations, every expand state should inherent from this +/// Expand invalid operations. Every expand state should inherit from this /// class. /// /// More detailed: /// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#key-derivation-using-extract-and-expand template class ExpandState { public: - /// Current kdf not support any options. + /// The current kdf does not support any options. WasiCryptoExpect optionsGet(std::string_view, Span) const noexcept { return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); } - /// Current kdf not support any options. + /// The current kdf does not support any options. WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); } @@ -83,17 +83,17 @@ template class ExpandState { } }; -/// Extract invalid operations, every extract state should inherent from this +/// Extract invalid operations. Every extract state should inherit from this /// class. template class ExtractState { public: - /// Current kdf not support any options. + /// The current kdf does not support any options. WasiCryptoExpect optionsGet(std::string_view, Span) const noexcept { return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); } - /// Current kdf not support any options. + /// The current kdf does not support any options. WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); } diff --git a/plugins/wasi_crypto/symmetric/mac.h b/plugins/wasi_crypto/symmetric/mac.h index 57fa5d43..fad820b9 100644 --- a/plugins/wasi_crypto/symmetric/mac.h +++ b/plugins/wasi_crypto/symmetric/mac.h @@ -31,19 +31,19 @@ namespace Host { namespace WasiCrypto { namespace Symmetric { -/// Mac invalid operation, every mac state should inherent from this class +/// Mac invalid operations. Every mac state should inherit from this class. /// /// More detailed: /// https://github.com/WebAssembly/wasi-crypto/blob/main/docs/wasi-crypto.md#message-authentication-codes template class MacState { public: - /// Current mac not support any options. + /// The current mac does not support any options. WasiCryptoExpect optionsGet(std::string_view, Span) const noexcept { return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); } - /// Current mac not support any options. + /// The current mac does not support any options. WasiCryptoExpect optionsGetU64(std::string_view) const noexcept { return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_UNSUPPORTED_OPTION); } @@ -119,7 +119,7 @@ template class Hmac { WasiCryptoExpect absorb(Span Data) noexcept; /// Authenticates the input received up to the function call. - /// If the finalization is required, the implementation MUST duplicate the + /// If finalization is required, the implementation MUST duplicate the /// internal state and apply the finalization on the copy, leaving the state /// unchanged from the guest perspective. /// diff --git a/plugins/wasi_crypto/symmetric/state.cpp b/plugins/wasi_crypto/symmetric/state.cpp index c759d219..e532b599 100644 --- a/plugins/wasi_crypto/symmetric/state.cpp +++ b/plugins/wasi_crypto/symmetric/state.cpp @@ -55,7 +55,7 @@ openState(Algorithm Alg, OptionalRef OptKeyVariant, [OptOptions](const auto &Key) -> WasiCryptoExpect { using InKeyType = std::decay_t; if constexpr (!std::is_same_v) { - // Key type not same. + // Key types do not match. return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_KEY); } else { // Key type fitted. diff --git a/plugins/wasi_crypto/symmetric/state.h b/plugins/wasi_crypto/symmetric/state.h index 9759753e..f681f7ce 100644 --- a/plugins/wasi_crypto/symmetric/state.h +++ b/plugins/wasi_crypto/symmetric/state.h @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file contains the symmetric state related classes, and provide a +/// This file contains the symmetric state related classes, and provides a /// unified interface which can be used to implement the algorithm operations. /// //===----------------------------------------------------------------------===// @@ -27,7 +27,7 @@ namespace Host { namespace WasiCrypto { namespace Symmetric { -/// State created from key, and performs symmetric operations with using the +/// State created from a key, and performs symmetric operations using the /// underlying algorithms. /// /// More detail: diff --git a/plugins/wasi_crypto/symmetric/tag.h b/plugins/wasi_crypto/symmetric/tag.h index 9daa458b..6ecc26a4 100644 --- a/plugins/wasi_crypto/symmetric/tag.h +++ b/plugins/wasi_crypto/symmetric/tag.h @@ -25,7 +25,7 @@ namespace Host { namespace WasiCrypto { namespace Symmetric { -/// Authentication tag, that can be verified without channels using the provided +/// Authentication tag that can be verified without channels using the provided /// APIs. Very small and no streaming. /// /// More detail: @@ -42,7 +42,7 @@ class Tag { size_t len() const noexcept { return Data.size(); } /// The function MUST return `__WASI_CRYPTO_ERRNO_INVALID_TAG` if the - /// tags don't match. + /// tags do not match. WasiCryptoExpect verify(Span RawTag) const noexcept; WasiCryptoExpect pull(Span Raw) const noexcept; diff --git a/plugins/wasi_crypto/utils/evp_wrapper.h b/plugins/wasi_crypto/utils/evp_wrapper.h index 893fb0e5..9827e6f8 100644 --- a/plugins/wasi_crypto/utils/evp_wrapper.h +++ b/plugins/wasi_crypto/utils/evp_wrapper.h @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file contains the definition of OpenSSL evp relative function. +/// This file contains the definitions of OpenSSL EVP-related functions. /// //===----------------------------------------------------------------------===// @@ -55,7 +55,7 @@ using EcdsaSigPtr = OpenSSLUniquePtr; using RsaPtr = OpenSSLUniquePtr; /// OpenSSL functions always return 1 for success and 0/NULL for failure. This -/// is used to reduce repeating checking. +/// is used to reduce repeated checks. #ifdef NDEBUG #define opensslCheck(Cond) \ do { \ @@ -114,13 +114,13 @@ WasiCryptoExpect> i2dEcdsaSig(ECDSA_SIG *Sig); // Transform raw represent ecdsa ( r | s) to ECDSA_SIG. Need to check `nullptr`. ECDSA_SIG *o2iEcdsaSig(Span Encoded); -// Transform ECDSA_SIG to raw represent ( r | s). +// Transform ECDSA_SIG to raw representation (r | s). WasiCryptoExpect> i2oEcdsaSig(ECDSA_SIG *Sig); -// This is a wrapper for EVP_PKEY, since EVP_PKEY inner use lock to guarantee -// thread-safe `EVP_PKEY_up_ref` (you will find them in crypto/evp/p_lib.c in -// OpenSSL v1.1.1), use shared_ptr for `EVP_PKEY` is wasted. -// It only provide limits function to correct use. +// This is a wrapper for EVP_PKEY. Since EVP_PKEY internally uses locks to +// guarantee thread-safe `EVP_PKEY_up_ref` (you will find them in +// crypto/evp/p_lib.c in OpenSSL v1.1.1), using shared_ptr for `EVP_PKEY` is +// wasteful. It only provides limited functions for correct use. class SharedEvpPkey { public: SharedEvpPkey(EvpPkeyPtr Pkey) noexcept : Pkey(Pkey.release()) {} @@ -128,7 +128,8 @@ class SharedEvpPkey { SharedEvpPkey(const SharedEvpPkey &Rhs) noexcept; SharedEvpPkey(SharedEvpPkey &&Rhs) noexcept; - // Assigning to existing SharedEvpPkey is not thread-safe, delete them. + // Assigning to an existing SharedEvpPkey is not thread-safe, so delete the + // assignment operators. SharedEvpPkey &operator=(const SharedEvpPkey &Rhs) noexcept = delete; SharedEvpPkey &operator=(SharedEvpPkey &&Rhs) noexcept = delete; diff --git a/plugins/wasi_crypto/utils/handles_manager.h b/plugins/wasi_crypto/utils/handles_manager.h index 8eb45603..f7f072ea 100644 --- a/plugins/wasi_crypto/utils/handles_manager.h +++ b/plugins/wasi_crypto/utils/handles_manager.h @@ -32,11 +32,10 @@ namespace detail { /// The Handles Manager base class. /// -/// @tparam HandleType This is the type of handle, notice it must be `32-bit -/// long`. +/// @tparam HandleType This is the handle type. It must be 32 bits wide. /// @tparam ManagerType The managed content type. /// -/// HandlesManager uses handle as index to represent the managed contents. +/// HandlesManager uses a handle as the index to represent the managed contents. /// /// Referenced from: /// https://github.com/WebAssembly/wasi-crypto/blob/main/implementations/hostcalls/rust/src/handles.rs @@ -60,36 +59,36 @@ template class BaseHandlesManager { return {}; } - /// Constructor a new manager. + /// Construct a new manager. template WasiCryptoExpect registerManager(Args &&...Manager) noexcept { std::unique_lock Lock{Mutex}; - // Find a handle that can be used and emplace. + // Find an available handle and emplace it. // Assume that LastHandle is 0x01000000, NextHandle is 0x01000001. auto NextHandle = LastHandle.nextHandle(); while (true) { // Try to emplace NextHandle. if (Map.try_emplace(NextHandle, std::forward(Manager)...).second) { - // If success, emplace which indicate the NextHandle not exists in the - // managed content. Update the last handle and return it. + // If this succeeds, the emplacement indicates that NextHandle does not + // exist in the managed content. Update the last handle and return it. LastHandle = NextHandle; return LastHandle.Handle; } - // Otherwise, the NextHandle Map already exists a content, call NextHandle - // and loop. + // Otherwise, the NextHandle map already contains content. Advance + // NextHandle and loop. NextHandle = NextHandle.nextHandle(); - // If after looping `many times(2^24 - 1)`, we get 0x01000000 again. + // If, after looping many times (2^24 - 1), we get 0x01000000 again, the + // hash map is full. if (NextHandle == LastHandle) { - // It indicates the hashmap is full. return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_TOO_MANY_HANDLES); } } } protected: - /// The handle internal representation as [-TypeID-|------CurrentNumber------] + /// The handle internal representation: [-TypeID-|------CurrentNumber------] union HandleWrapper { static_assert(sizeof(HandleType) == 4, "HandleType must be 4 byte"); HandleWrapper(uint8_t TypeID, uint32_t CurrentNumber) noexcept diff --git a/plugins/wasi_crypto/utils/optional.h b/plugins/wasi_crypto/utils/optional.h index 85d06ad6..0752ed5e 100644 --- a/plugins/wasi_crypto/utils/optional.h +++ b/plugins/wasi_crypto/utils/optional.h @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file contains the definition of the OptionalRef and some helper +/// This file contains the definition of OptionalRef and some helper /// functions. /// //===----------------------------------------------------------------------===// diff --git a/plugins/wasi_nn/CMakeLists.txt b/plugins/wasi_nn/CMakeLists.txt index dc3d8b5e..d7be3bcf 100644 --- a/plugins/wasi_nn/CMakeLists.txt +++ b/plugins/wasi_nn/CMakeLists.txt @@ -38,8 +38,8 @@ target_compile_definitions(wasmedgePluginWasiNN PRIVATE WASI_NN_VERSION_MINOR=${WASI_NN_VERSION_MINOR} WASI_NN_VERSION_PATCH=${WASI_NN_VERSION_PATCH} ) -# This for-each iteration is for the additional sources. -# The dependencies are moved into `cmake/WASINNDeps.cmake`. +# This foreach iteration handles the additional sources. +# The dependencies are moved to `cmake/WASINNDeps.cmake`. foreach(BACKEND ${WASMEDGE_PLUGIN_WASI_NN_BACKEND}) string(TOLOWER ${BACKEND} BACKEND) if(BACKEND STREQUAL "mlx") diff --git a/plugins/wasi_nn/GGML/compute/inference_manager.cpp b/plugins/wasi_nn/GGML/compute/inference_manager.cpp index 468b098c..83dd3b06 100644 --- a/plugins/wasi_nn/GGML/compute/inference_manager.cpp +++ b/plugins/wasi_nn/GGML/compute/inference_manager.cpp @@ -12,7 +12,7 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML namespace { -// Fill tokens (smaller than batch size) into a batch with position data. +// Fill a batch with tokens (smaller than batch size) and position data. void fillBatch(Span Tokens, Graph &GraphRef, llama_batch &Batch, int &NPos, bool IsLogit = false) { assuming(GraphRef.Params.n_batch >= static_cast(Tokens.size())); @@ -27,7 +27,7 @@ void fillBatch(Span Tokens, Graph &GraphRef, Batch.logits[I] = false; } - // Logits of sampling or end of inputs. + // Logits for sampling or the end of inputs. if (IsLogit) { Batch.logits[Tokens.size() - 1] = true; } @@ -36,7 +36,7 @@ void fillBatch(Span Tokens, Graph &GraphRef, NPos += static_cast(Tokens.size()); } -// Evaluate tokens. Construct the tokens into batch and decode. +// Evaluate tokens. Construct the batch from tokens and decode. ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, llama_batch &Batch, int &NPos, bool IsLogits = false) noexcept { @@ -51,7 +51,7 @@ ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, return ErrNo::ContextFull; } - // Loop for decode batch. Split tokens into batch size length. + // Loop for decoding batches. Split tokens by batch size. for (int I = 0; I < static_cast(Tokens.size()); I += static_cast(GraphRef.Params.n_batch)) { int NEval = static_cast(Tokens.size()) - I; @@ -103,7 +103,7 @@ void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, } } // namespace -// Evaluate the input tokens. Clean all inputs if succeeded. +// Evaluate the input tokens. Clear all inputs on success. ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, std::string_view LogPrefix) noexcept { // Check if the input is set before setting up the context. diff --git a/plugins/wasi_nn/GGML/core/ggml_core.cpp b/plugins/wasi_nn/GGML/core/ggml_core.cpp index 2815f1cb..ce422a21 100644 --- a/plugins/wasi_nn/GGML/core/ggml_core.cpp +++ b/plugins/wasi_nn/GGML/core/ggml_core.cpp @@ -95,7 +95,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, llama_log_set(llamaLogCallback, &GraphRef); mtmd_helper_log_set(llamaLogCallback, &GraphRef); - // If the graph builder length > 1, the data of builder[1] is the metadata. + // If the graph builder length is greater than 1, builder[1] contains the + // metadata. if (Builders.size() > 1) { const std::string Metadata(reinterpret_cast(Builders[1].data()), Builders[1].size()); diff --git a/plugins/wasi_nn/GGML/core/ggml_type.h b/plugins/wasi_nn/GGML/core/ggml_type.h index 36e6f482..cda10b0e 100644 --- a/plugins/wasi_nn/GGML/core/ggml_type.h +++ b/plugins/wasi_nn/GGML/core/ggml_type.h @@ -31,10 +31,10 @@ enum class EmbdNormalizeType : int32_t { }; struct LocalConfig { - // Configurations which can be changed in every contexts. - // The graph handles a default config and parsed from metadata when loading. - // The context inherits a copy from graph when creating, and can be modified - // when parsing metadata in set_input. + // Configuration values that can be changed in every context. + // The graph holds a default config parsed from metadata during loading. + // The context inherits a copy from the graph during creation and can be + // modified when parsing metadata in set_input. bool StreamStdout = false; EmbdNormalizeType EmbdNormalize = EmbdNormalizeType::Euclidean; int64_t NPredict; @@ -75,11 +75,11 @@ struct Context { // Llama outputs: std::vector LlamaOutputs; std::vector LlamaOutputTokens; - // Data for computing: + // Data for computation: bool ComputeSingleStarted = false; struct common_sampler *LlamaSampler = nullptr; - // Handle the batch in the context to prevent from reallocation in every - // computing. + // Handle the batch in the context to prevent reallocation during every + // computation. struct llama_batch LlamaBatch; struct llama_batch OutputBatch; int64_t CurrentBatchSize = 0; @@ -118,7 +118,7 @@ namespace { // Macro for logging error message. #define LOG_ERROR(...) spdlog::error("[WASI-NN] GGML backend: "sv __VA_ARGS__); -// Macro for logging error message and return. +// Macro for logging an error message and returning. #define RET_ERROR(Error, ...) \ spdlog::error("[WASI-NN] GGML backend: "sv __VA_ARGS__); \ return Error; diff --git a/plugins/wasi_nn/GGML/core/input_processor.cpp b/plugins/wasi_nn/GGML/core/input_processor.cpp index fc1eaf91..12f57de1 100644 --- a/plugins/wasi_nn/GGML/core/input_processor.cpp +++ b/plugins/wasi_nn/GGML/core/input_processor.cpp @@ -36,14 +36,14 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } #ifndef __APPLE__ - // XXX: Due to the limitation of WASI-NN proposal, this is a workaround - // for non-macOS devices. However, if the model params is updated in - // Config stage, then, we don't encourage to use this to avoid the model - // reloading. + // XXX: Because of the limitation in the WASI-NN proposal, this is a + // workaround for non-macOS devices. However, if the model params are + // updated in the configuration stage, we do not recommend using this to + // avoid reloading the model. { if (IsModelParamsUpdated || GraphRef.LlamaModel == nullptr) { - // The llama model may be nullptr if set_input with updated model params - // last time. Therefore besides the model params updated, we should + // The llama model may be nullptr if set_input updated the model params + // last time. Therefore, in addition to updated model params, we should // reload the llama model if the model is nullptr. LOG_INFO(GraphRef.EnableLog, "setInput: Reload model due to parameters change."sv) @@ -70,7 +70,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } #endif - // Some changes of context parameters will require the context to be + // Some changes to context parameters will require the context to be // reloaded. if (IsContextParamsUpdated || GraphRef.LlamaContext == nullptr) { LOG_INFO(GraphRef.EnableLog, @@ -85,7 +85,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } } - // Some changes of sampling parameters will require the sampler to be + // Some changes to sampling parameters will require the sampler to be // reallocated. if (IsSamplerParamsUpdated || CxtRef.LlamaSampler == nullptr) { LOG_INFO(GraphRef.EnableLog, @@ -101,7 +101,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } } - // Check that is batch size changed. + // Check whether the batch size changed. if (CxtRef.CurrentBatchSize != GraphRef.Params.n_batch) { llama_batch_free(CxtRef.LlamaBatch); CxtRef.LlamaBatch = allocBatch(GraphRef.Params.n_batch); @@ -114,7 +114,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, return ErrNo::Success; } - // Check the graph is valid after reloading during previous set_input. + // Check that the graph is valid after reloading during the previous + // set_input. if (!Env.NNGraph[CxtRef.GraphId].isReady()) { RET_ERROR( ErrNo::InvalidArgument, @@ -137,7 +138,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, auto Base64ImagePos = findBase64ImagePayload(Prompt); if (Base64ImagePos.has_value() || CxtRef.Conf.ImagePath != ""sv) { - // First check the projection model is given. + // First check whether the projection model is provided. if (GraphRef.Params.mmproj.path == ""sv) { RET_ERROR( ErrNo::InvalidArgument, @@ -241,7 +242,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize the mtmd prompt...Done"sv) - // Get the number of input tokens (for the metadata). + // Get the number of input tokens for the metadata. CxtRef.LlamaNInputs = 0; for (size_t ChunkIndex = 0; ChunkIndex < mtmd_input_chunks_size(GraphRef.VisionInputChunks.get()); @@ -262,7 +263,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize tts prompt...Done"sv) - // Get the number of input tokens (for the metadata). + // Get the number of input tokens for the metadata. CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); } else { // Text only prompt. @@ -272,11 +273,12 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt...Done"sv) - // Get the number of input tokens (for the metadata). + // Get the number of input tokens for the metadata. CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); } - // Maybe currently in the compute_single mode. Reset the computing. + // The context may currently be in compute_single mode. Reset the compute + // state. CxtRef.ComputeSingleStarted = false; LOG_DEBUG(GraphRef.EnableDebugLog, "setInput...Done"sv) diff --git a/plugins/wasi_nn/GGML/core/output_generator.cpp b/plugins/wasi_nn/GGML/core/output_generator.cpp index 9cd6667c..f5c5dda2 100644 --- a/plugins/wasi_nn/GGML/core/output_generator.cpp +++ b/plugins/wasi_nn/GGML/core/output_generator.cpp @@ -24,7 +24,7 @@ Expect getOutputSingle(WasiNNEnvironment &Env, uint32_t ContextId, auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); LOG_DEBUG(GraphRef.EnableDebugLog, "getOutputSingle: with Index {}"sv, Index) - // Use index 1 for the metadata of the outputs. + // Use index 1 for output metadata. if (Index == 1) { std::string Metadata = buildOutputMetadata(CxtRef); std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); @@ -50,7 +50,7 @@ Expect getOutput(WasiNNEnvironment &Env, uint32_t ContextId, auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get(); LOG_DEBUG(GraphRef.EnableDebugLog, "getOutput: with Index {}"sv, Index) - // Use index 1 for the metadata of the outputs. + // Use index 1 for output metadata. if (Index == 1) { std::string Metadata = buildOutputMetadata(CxtRef); std::copy_n(Metadata.data(), Metadata.length(), OutBuffer.data()); diff --git a/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp b/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp index 02e643c4..fc98904b 100644 --- a/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp +++ b/plugins/wasi_nn/GGML/metadata/metadata_parser.cpp @@ -12,7 +12,7 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML -// Parse metadata from json. +// Parse metadata from JSON. ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, const std::string &Metadata, bool *IsModelUpdated, bool *IsContextUpdated, bool *IsSamplerUpdated) noexcept { @@ -521,7 +521,7 @@ ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, } catch (const ErrNo &Error) { return Error; } - // The tensor buffer overrides should terminated with empty pattern. + // The tensor buffer overrides should be terminated with an empty pattern. if (!GraphRef.TensorBuftOverrides.empty()) { for (const std::string &Override : GraphRef.TensorBuftOverrides) { GraphRef.Params.tensor_buft_overrides.push_back( diff --git a/plugins/wasi_nn/GGML/tts/tts_core.cpp b/plugins/wasi_nn/GGML/tts/tts_core.cpp index 78535eb2..86d332bd 100644 --- a/plugins/wasi_nn/GGML/tts/tts_core.cpp +++ b/plugins/wasi_nn/GGML/tts/tts_core.cpp @@ -385,7 +385,7 @@ getSpeakerProfileFromFile(const std::string &FilePath, WasiNNEnvironment &Env) { JsonFile >> JsonData; JsonFile.close(); - // Initialize the outputs + // Initialize the outputs. std::string AudioOutputText = "<|audio_start|>\n"; std::string TextOutput = "<|text_start|>"; @@ -412,7 +412,7 @@ getSpeakerProfileFromFile(const std::string &FilePath, WasiNNEnvironment &Env) { return TTSSpeakerProfile{TextOutput, AudioOutputText}; } -// TextToSpeech function, will generate voice data from codes. +// TextToSpeech function that generates voice data from codes. ErrNo codesToSpeech(WasiNNEnvironment &Env, Graph &GraphRef, Context &CxtRef) noexcept { // Remove all non-audio tokens. @@ -456,7 +456,7 @@ ErrNo codesToSpeech(WasiNNEnvironment &Env, Graph &GraphRef, AudioData[I] = 0.0f; } - // Convert audio data to wav and put it into output buffer. + // Convert audio data to WAV and put it into the output buffer. CxtRef.LlamaOutputs = audioDataToWav(AudioData, SamplingRate); // Save .wav file if path is provided. diff --git a/plugins/wasi_nn/GGML/utils.cpp b/plugins/wasi_nn/GGML/utils.cpp index 72e9650e..29df8ab6 100644 --- a/plugins/wasi_nn/GGML/utils.cpp +++ b/plugins/wasi_nn/GGML/utils.cpp @@ -9,7 +9,7 @@ namespace WasmEdge::Host::WASINN::GGML { #ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_GGML -// Helper to init a llama batch. +// Helper to initialize a llama batch. struct llama_batch allocBatch(int64_t NTokens, int64_t Embd, int32_t NSeqMax) noexcept { struct llama_batch Batch = llama_batch_init( diff --git a/plugins/wasi_nn/MLX/mlx/convolution.h b/plugins/wasi_nn/MLX/mlx/convolution.h index 9e824dcb..db475a5f 100644 --- a/plugins/wasi_nn/MLX/mlx/convolution.h +++ b/plugins/wasi_nn/MLX/mlx/convolution.h @@ -44,7 +44,7 @@ class Conv2d : public nn::Module { : Padding(Padding), Stride(Stride), Dilation(Dilation), Groups(Groups) { if (InChannels % Groups != 0) { - // InChannels must be divisible by Groups + // InChannels must be divisible by Groups. assumingUnreachable(); } double Scale = std::sqrt(1.0 / (InChannels * KernelSize * KernelSize)); diff --git a/plugins/wasi_nn/MLX/model/whisper/decoding.cpp b/plugins/wasi_nn/MLX/model/whisper/decoding.cpp index e0a93ff4..2be9c624 100644 --- a/plugins/wasi_nn/MLX/model/whisper/decoding.cpp +++ b/plugins/wasi_nn/MLX/model/whisper/decoding.cpp @@ -332,8 +332,8 @@ mx::array ApplyTimestampRules::apply(const mx::array &Logits, } } - // If sum of probability over timestamps is above any other token, sample - // timestamp + // If the sum of probabilities over timestamps is above any other token, + // sample the timestamp. mx::array MaskArray = mx::array(MaskVec.data(), LogitsShape, mx::float32); mx::array Logprobs = Logits - mx::logsumexp(Logits, -1, true); diff --git a/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp b/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp index fa9a0660..546952f6 100644 --- a/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp +++ b/plugins/wasi_nn/MLX/model/whisper_transcribe.cpp @@ -155,7 +155,7 @@ mx::array stft(const mx::array &X, const mx::array &Window, int NPerseg = 256, } mx::array melFilters(int NMels) { - // Load precomputed mel filters from file + // Load precomputed mel filters from a file. if (NMels != 80 && NMels != 128) { spdlog::error("Unsupported number of mel filters: " + std::to_string(NMels)); @@ -287,7 +287,7 @@ decodeWithFallback(std::shared_ptr Model, return Result; } -// Word-level timestamp functions +// Word-level timestamp functions. void addWordTimestamps(std::vector &Segments, std::shared_ptr Model, std::unique_ptr &Tokenizer, @@ -296,8 +296,8 @@ void addWordTimestamps(std::vector &Segments, const std::string &AppendPunctuations, float LastSpeechTimestamp) { - // This is a simplified implementation - // Full implementation would use cross-attention patterns and DTW + // This is a simplified implementation. + // A full implementation would use cross-attention patterns and DTW. for (auto &Segment : Segments) { if (!Segment.Tokens.empty()) { float Duration = Segment.End - Segment.Start; diff --git a/plugins/wasi_nn/wasinn_bitnet.cpp b/plugins/wasi_nn/wasinn_bitnet.cpp index e184def0..15564a70 100644 --- a/plugins/wasi_nn/wasinn_bitnet.cpp +++ b/plugins/wasi_nn/wasinn_bitnet.cpp @@ -38,7 +38,7 @@ namespace { #define LOG_ERROR(...) \ spdlog::error("[WASI-NN] BitNet backend: "sv __VA_ARGS__); -// Macro for logging error message and return. +// Macro for logging an error message and returning. #define RET_ERROR(Error, ...) \ spdlog::error("[WASI-NN] BitNet backend: "sv __VA_ARGS__); \ return Error; @@ -83,7 +83,7 @@ void stringToList(const std::string &Raw, std::vector &Out) { } } -// Parse metadata from json. +// Parse metadata from JSON. ErrNo parseMetadata(Graph &GraphRef, LocalConfig &ConfRef, const std::string &Metadata, bool *IsModelUpdated = nullptr, bool *IsContextUpdated = nullptr, @@ -1603,7 +1603,7 @@ void buildOutputEmbedding(std::string &Embedding, int32_t NEmbd, // >>>>>>>> Compute related functions >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> -// Helper to init a llama batch. +// Helper to initialize a llama batch. struct llama_batch allocBatch(int64_t NTokens, int64_t Embd = 0, int32_t NSeqMax = 1) noexcept { struct llama_batch Batch = llama_batch_init( @@ -1619,7 +1619,7 @@ struct llama_batch allocBatch(int64_t NTokens, int64_t Embd = 0, return Batch; } -// Fill tokens (smaller than batch size) into a batch with position data. +// Fill a batch with tokens (smaller than batch size) and position data. void fillBatch(Span Tokens, Graph &GraphRef, llama_batch &Batch, int &NPos, bool IsLogit = false) { assuming(GraphRef.Params.n_batch >= static_cast(Tokens.size())); @@ -1634,7 +1634,7 @@ void fillBatch(Span Tokens, Graph &GraphRef, Batch.logits[I] = false; } - // Logits of sampling or end of inputs. + // Logits for sampling or the end of inputs. if (IsLogit) { Batch.logits[Tokens.size() - 1] = true; } @@ -1643,7 +1643,7 @@ void fillBatch(Span Tokens, Graph &GraphRef, NPos += static_cast(Tokens.size()); } -// Evaluate tokens. Construct the tokens into batch and decode. +// Evaluate tokens. Construct the batch from tokens and decode. ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, llama_batch &Batch, int &NPos, bool IsLogits = false) noexcept { @@ -1658,7 +1658,7 @@ ErrNo evaluateTokens(Span Tokens, Graph &GraphRef, return ErrNo::ContextFull; } - // Loop for decode batch. Split tokens into batch size length. + // Loop for decoding batches. Split tokens by batch size. for (int I = 0; I < static_cast(Tokens.size()); I += static_cast(GraphRef.Params.n_batch)) { int NEval = static_cast(Tokens.size()) - I; @@ -1701,7 +1701,7 @@ void clearContext(Graph &GraphRef, Context &CxtRef) noexcept { LOG_DEBUG(GraphRef.EnableDebugLog, "{}: clearContext...Done"sv) } -// Evaluate the input tokens. Clean all inputs if succeeded. +// Evaluate the input tokens. Clear all inputs on success. ErrNo evaluateInput(Graph &GraphRef, Context &CxtRef, std::string_view LogPrefix) noexcept { // Check if the input is set before setting up the context. @@ -1926,7 +1926,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, llama_log_set(llamaLogCallback, &GraphRef); LOG_DEBUG(GraphRef.EnableDebugLog, "load start."sv) - // If the graph builder length > 1, the data of builder[1] is the metadata. + // If the graph builder length is greater than 1, builder[1] contains the + // metadata. if (Builders.size() > 1) { const std::string Metadata( reinterpret_cast(Builders[1].data()), Builders[1].size()); @@ -2059,8 +2060,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, } if (IsModelUpdated || GraphRef.LlamaModel == nullptr) { - // The llama model may be nullptr if set_input with updated model params - // last time. Therefore besides the model params updated, we should + // The llama model may be nullptr if set_input updated the model params + // last time. Therefore, in addition to updated model params, we should // reload the llama model if the model is nullptr. LOG_INFO(GraphRef.EnableLog, "setInput: Reloading model due to parameter change"sv) @@ -2142,7 +2143,8 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, "Only prompt (index 0) and metadata (index 1) are supported."); } - // Check the graph is valid after reloading during previous set_input. + // Check that the graph is valid after reloading during the previous + // set_input. if (!Env.NNGraph[CxtRef.GraphId].isReady()) { RET_ERROR( ErrNo::InvalidArgument, @@ -2170,7 +2172,7 @@ Expect setInput(WasiNNEnvironment &Env, uint32_t ContextId, llama_add_bos_token(GraphRef.LlamaModel.get()), true); LOG_DEBUG(GraphRef.EnableDebugLog, "setInput: tokenize text prompt...Done"sv) - // Get the number of input tokens (for the metadata). + // Get the number of input tokens for the metadata. CxtRef.LlamaNInputs = CxtRef.LlamaInputs.size(); // Reset state for the compute loop. diff --git a/plugins/wasi_nn/wasinn_openvino_genai.cpp b/plugins/wasi_nn/wasinn_openvino_genai.cpp index 4fd945e9..a7053ec9 100644 --- a/plugins/wasi_nn/wasinn_openvino_genai.cpp +++ b/plugins/wasi_nn/wasinn_openvino_genai.cpp @@ -118,7 +118,7 @@ Expect load(WASINN::WasiNNEnvironment &Env, // Builder-1: Path to the dir model xml/bin files // Builder-2: Empty for now (reserved for future use) - // There are 4 types (text or img) x (text or img), we assume the input is 0 + // There are 4 types (text or img) x (text or img); we assume the input is 0 // for now. auto ModelType = std::string( reinterpret_cast(Builders[0].data()), Builders[0].size()); diff --git a/plugins/wasi_nn/wasinn_piper.cpp b/plugins/wasi_nn/wasinn_piper.cpp index 433bc147..7e71efc9 100644 --- a/plugins/wasi_nn/wasinn_piper.cpp +++ b/plugins/wasi_nn/wasinn_piper.cpp @@ -28,7 +28,7 @@ namespace WasmEdge::Host::WASINN::Piper { namespace { -// helper function to write WAV header +// Helper function to write the WAV header. void writeWavHeader(int SampleRate, int16_t NumChannels, int32_t NumSamples, std::vector &OutputBuffer) { int32_t const ByteRate = SampleRate * NumChannels * sizeof(int16_t); diff --git a/plugins/wasi_nn/wasinn_piper.h b/plugins/wasi_nn/wasinn_piper.h index 900abef5..e10e3150 100644 --- a/plugins/wasi_nn/wasinn_piper.h +++ b/plugins/wasi_nn/wasinn_piper.h @@ -35,7 +35,7 @@ struct SynthesisConfig { // Amount of noise to add during audio generation std::optional NoiseScale; - // Speed of speaking (1 = normal, < 1 is faster, > 1 is slower) + // Speech speed (1 = normal, < 1 is faster, > 1 is slower) std::optional LengthScale; // Variation in phoneme lengths diff --git a/plugins/wasi_nn/wasinn_whisper.cpp b/plugins/wasi_nn/wasinn_whisper.cpp index 6d8a608c..f82bd29d 100644 --- a/plugins/wasi_nn/wasinn_whisper.cpp +++ b/plugins/wasi_nn/wasinn_whisper.cpp @@ -341,8 +341,7 @@ bool checkAudioRIFF(const std::string_view Buf, const std::string_view Format) { bool loadWAV(Span Buf, std::vector &PCMF32, std::vector> &PCMF32s, bool Stereo) { - // Not to use the helper function in examples of whisper.cpp to prevent from - // copy. + // Do not use the helper function from whisper.cpp examples to avoid copying. drwav WAV; const uint32_t ConstSampleRate = 16000; @@ -856,7 +855,8 @@ Expect load(WasiNNEnvironment &Env, Span> Builders, // Set whisper log callback. whisper_log_set(WhisperLogCallback, &GraphRef); - // If the graph builder length > 1, the data of builder[1] is the metadata. + // If the graph builder length is greater than 1, builder[1] contains the + // metadata. if (Builders.size() > 1) { const std::string Metadata(reinterpret_cast(Builders[1].data()), Builders[1].size()); diff --git a/plugins/wasi_nn/wasinn_whisper.h b/plugins/wasi_nn/wasinn_whisper.h index c61ede75..64745040 100644 --- a/plugins/wasi_nn/wasinn_whisper.h +++ b/plugins/wasi_nn/wasinn_whisper.h @@ -80,8 +80,8 @@ struct Context { // mono-channel F32 PCM input. std::vector InputPCM; std::vector> InputPCMs; - // Whisper config. Inherit from the graph and accept metadata when setting - // input. + // Whisper config. Inherited from the graph and updated from metadata when + // setting input. Config WhisperConfig; whisper_full_params WhisperParams; // Recognition outputs. diff --git a/plugins/wasi_nn/wasinnenv.h b/plugins/wasi_nn/wasinnenv.h index c83f0c17..d9123069 100644 --- a/plugins/wasi_nn/wasinnenv.h +++ b/plugins/wasi_nn/wasinnenv.h @@ -141,10 +141,11 @@ class Graph { Impl; // Graph status. // Uninitialized: A new graph in monostate. - // Invalid: The graph loaded failed in set_input with metadata. Can be - // reload with a new metadata in set_input. - // Finalized: The graph being deleted, but there are contexts linked. This - // graph ID will be released once the contexts are deleted. + // Invalid: The graph failed to load in set_input with metadata. It can be + // reloaded with new metadata in set_input. + // Finalized: The graph is being deleted, but there are linked contexts. + // This graph ID will be released once the contexts are + // deleted. // Ready: This graph can be used to create a context. enum class Status : uint8_t { Uninitialized, Invalid, Finalized, Ready }; Status Stat; @@ -317,7 +318,7 @@ struct WasiNNEnvironment : auto &G = NNGraph[Id]; G.setFinalized(); if (G.getContextCount() == 0) { - // Checked all contexts are deleted. Release the graph id. + // All contexts are deleted. Release the graph ID. if (Id == NNGraph.size() - 1) { NNGraph.pop_back(); } else { @@ -337,7 +338,7 @@ struct WasiNNEnvironment : auto &G = NNGraph[GId]; G.decreaseContext(); if (G.getContextCount() == 0 && G.isFinalized()) { - // Checked all contexts are deleted. Release the graph id. + // All contexts are deleted. Release the graph ID. if (GId == NNGraph.size() - 1) { NNGraph.pop_back(); } else { diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index c9f2f91d..b8a65f4f 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -133,7 +133,7 @@ WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, return WASINN::ErrNo::InvalidArgument; } - // Get the name of model + // Get the model name. auto Name = MemInst->getPointer(NamePtr); if (unlikely(Name == nullptr)) { spdlog::error("[WASI-NN] Failed when accessing the return Name memory."sv); @@ -158,7 +158,7 @@ WasiNNLoadByName::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t NamePtr, } #endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC - // Get the model + // Get the model. std::string ModelName(reinterpret_cast(Name), NameLen); if (Env.mdGet(ModelName, *GraphId)) { return WASINN::ErrNo::Success; @@ -183,14 +183,14 @@ Expect WasiNNLoadByNameWithConfig::bodyImpl( return WASINN::ErrNo::InvalidArgument; } - // Get the name of model + // Get the model name. auto Name = MemInst->getPointer(NamePtr); if (unlikely(Name == nullptr)) { spdlog::error("[WASI-NN] Failed when accessing the return Name memory."sv); return WASINN::ErrNo::InvalidArgument; } - // Get the config of model + // Get the model config. auto Config = MemInst->getPointer(ConfigPtr); if (unlikely(Config == nullptr)) { spdlog::error( @@ -218,7 +218,7 @@ Expect WasiNNLoadByNameWithConfig::bodyImpl( } #endif // ifdef WASMEDGE_BUILD_WASI_NN_RPC - // Get the model + // Get the model. std::string ModelName(reinterpret_cast(Name), NameLen); std::vector ModelConfig(reinterpret_cast(Config), reinterpret_cast(Config) + diff --git a/plugins/wasm_bpf/CMakeLists.txt b/plugins/wasm_bpf/CMakeLists.txt index dc9ba54a..17ab9a80 100644 --- a/plugins/wasm_bpf/CMakeLists.txt +++ b/plugins/wasm_bpf/CMakeLists.txt @@ -104,7 +104,7 @@ if(NOT ${LIBBPF_FOUND}) set(LIBBPF_SOURCE "fetch-content") endif() -# If we cannot find libbpf.. +# If we cannot find libbpf. if(NOT ${LIBBPF_FOUND}) message(FATAL_ERROR "Could not find libbpf") endif() @@ -145,7 +145,8 @@ if("${LIBBPF_SOURCE}" STREQUAL "pkgconf") message(STATUS "Link libbpf dynamically") target_link_libraries(wasmedgePluginWasmBpf PUBLIC ${LIBBPF_LIBRARIES} ${LIBBPF_DEP_LIBRARIES}) else() - # Link libbpf statically if we don't use pkgconf, because under this case libbpf is not installed systemwide + # Link libbpf statically if we don't use pkgconf, because in this case libbpf is + # not installed systemwide. message(STATUS "Link libbpf statically") target_link_libraries(wasmedgePluginWasmBpf PUBLIC ${LIBBPF_LIBRARIES_STATIC} ${LIBBPF_DEP_LIBRARIES}) endif() diff --git a/plugins/wasm_bpf/README.md b/plugins/wasm_bpf/README.md index fed20dcd..95fbcdc5 100644 --- a/plugins/wasm_bpf/README.md +++ b/plugins/wasm_bpf/README.md @@ -1,25 +1,26 @@ # wasm_bpf Plugin -This plugin added six host functions that give you Wasm application access to eBPF. +This plugin adds six host functions that give Wasm applications access to eBPF. -Six functions are listed here. And all of them are in the module `wasm_bpf`, if you loaded this plugin. +The six functions are listed here. All of them are in the `wasm_bpf` module when +this plugin is loaded. ```c -/// lookup a bpf map fd by name. +/// look up a BPF map fd by name. i32 wasm_bpf_map_fd_by_name(u64 obj, u32 name); -/// detach and close a bpf program. +/// detach and close a BPF program. i32 wasm_close_bpf_object(u64 obj); -/// CO-RE load a bpf object into the kernel. +/// CO-RE load a BPF object into the kernel. u64 wasm_load_bpf_object(u32 obj_buf, u32 obj_buf_sz); -/// attach a bpf program to a kernel hook. +/// attach a BPF program to a kernel hook. i32 wasm_attach_bpf_program(u64 obj, u32 name, u32 attach_target); -/// poll a bpf buffer, and call a wasm callback indicated by sample_func. -/// the first time to call this function will open and create a bpf buffer. +/// poll a BPF buffer and call a Wasm callback indicated by sample_func. +/// the first call to this function will open and create a BPF buffer. i32 wasm_bpf_buffer_poll(u64 program, i32 fd, u32 sample_func, u32 ctx, u32 data, i32 max_size, i32 timeout_ms); -/// lookup, update, delete, and get_next_key operations on a bpf map. +/// perform lookup, update, delete, and get_next_key operations on a BPF map. i32 wasm_bpf_map_operate(u64 fd, i32 cmd, u32 key, u32 value, u32 next_key, u64 flags); ``` diff --git a/plugins/wasm_bpf/bpf-api.h b/plugins/wasm_bpf/bpf-api.h index 1970fe69..72140842 100644 --- a/plugins/wasm_bpf/bpf-api.h +++ b/plugins/wasm_bpf/bpf-api.h @@ -28,13 +28,13 @@ extern "C" { namespace WasmEdge { namespace Host { -/// \brief init libbpf callbacks +/// \brief Initialize libbpf callbacks. void init_libbpf(void); typedef int32_t (*bpf_buffer_sample_fn)(void *ctx, void *data, size_t size); -/// An absraction of a bpf ring buffer or perf buffer -/// see https://github.com/iovisor/bcc/blob/master/libbpf-tools/compat.c +/// An abstraction of a BPF ring buffer or perf buffer. +/// See https://github.com/iovisor/bcc/blob/master/libbpf-tools/compat.c. class bpf_buffer { protected: bpf_buffer_sample_fn fn; @@ -47,27 +47,27 @@ class bpf_buffer { uint32_t wasm_buf_ptr; public: - /// sample callback which calls the wasm handler indirectly + /// Sample callback that calls the Wasm handler indirectly. int32_t bpf_buffer_sample(void *data, size_t size); - /// Check if the bpf buffer is valid + /// Check whether the BPF buffer is valid. /// - /// a valid module instance should have only one table and a sample function + /// A valid module instance should have only one table and a sample function. bool is_valid() const; - /// set the wasm callback parameters + /// Set the Wasm callback parameters. void set_callback_params(WasmEdge_ExecutorContext *executor, const WasmEdge_ModuleInstanceContext *module_instance, uint32_t sample_func, void *data, size_t max_size, uint32_t ctx, uint32_t buf_ptr); - /// polling the bpf buffer + /// Poll the BPF buffer. virtual int32_t bpf_buffer__poll(int32_t timeout_ms) = 0; - /// open the bpf buffer map + /// Open the BPF buffer map. virtual int32_t bpf_buffer__open(int32_t fd, bpf_buffer_sample_fn sample_cb, void *ctx) = 0; virtual ~bpf_buffer() noexcept = default; }; -/// bpf program instance +/// BPF program instance. class wasm_bpf_program { std::unique_ptr obj{nullptr, bpf_object__close}; @@ -76,19 +76,19 @@ class wasm_bpf_program { links; public: - /// Find a bpf map fd by name + /// Find a BPF map fd by name. int32_t bpf_map_fd_by_name(const char *name); - /// Load a bpf object from a buffer into the kernel + /// Load a BPF object from a buffer into the kernel. int32_t load_bpf_object(const void *obj_buf, size_t obj_buf_sz); - /// Attach a bpf program to a target (e.g. a kernel function on a kprobe) + /// Attach a BPF program to a target (e.g. a kernel function on a kprobe). int32_t attach_bpf_program(const char *name, const char *attach_target); - /// Poll the bpf buffer to get data from the kernel + /// Poll the BPF buffer to get data from the kernel. int32_t bpf_buffer_poll(WasmEdge_ExecutorContext *executor, const WasmEdge_ModuleInstanceContext *module_instance, int32_t fd, int32_t sample_func, uint32_t ctx, void *buffer_data, size_t max_size, int32_t timeout_ms, uint32_t wasm_buf_ptr); - /// Get the bpf map pointer by fd + /// Get the BPF map pointer by fd. bpf_map *map_ptr_by_fd(int32_t fd); }; @@ -99,7 +99,7 @@ enum bpf_map_cmd { _BPF_MAP_GET_NEXT_KEY, }; -/// Operate on a bpf map. +/// Operate on a BPF map. int32_t bpf_map_operate(int32_t fd, int32_t cmd, void *key, void *value, void *next_key, uint64_t flags); using handle_t = int64_t; diff --git a/plugins/wasm_bpf/func-attach-bpf-program.h b/plugins/wasm_bpf/func-attach-bpf-program.h index 5eda9088..a0f990b7 100644 --- a/plugins/wasm_bpf/func-attach-bpf-program.h +++ b/plugins/wasm_bpf/func-attach-bpf-program.h @@ -13,7 +13,7 @@ namespace WasmEdge { namespace Host { -/// \brief Attach a bpf program to the specified target +/// \brief Attach a BPF program to the specified target. class AttachBpfProgram : public WasmEdge::Runtime::HostFunction { public: diff --git a/plugins/wasm_bpf/func-bpf-buffer-poll.h b/plugins/wasm_bpf/func-bpf-buffer-poll.h index ff33f82e..63af33a7 100644 --- a/plugins/wasm_bpf/func-bpf-buffer-poll.h +++ b/plugins/wasm_bpf/func-bpf-buffer-poll.h @@ -13,17 +13,17 @@ namespace WasmEdge { namespace Host { -/// Perform a bpf buffer poll. If the map is not opened, it will be opened. +/// Perform a BPF buffer poll. If the map is not opened, it will be opened. /// -/// \param fd the map fd for bpf buffer. -/// \param sample_func callback function. When things are polled, it will be +/// \param fd the map fd for the BPF buffer. +/// \param sample_func callback function. When data is polled, it will be /// invoked. -/// \param ctx user customized variable. +/// \param ctx user-customized variable. /// \param data data buffer that will be used to store the polled data. -/// \param max_size How many bytes can be put at data. -/// \param timeout_ms how many milliseconds can be waited. +/// \param max_size how many bytes can be put in data. +/// \param timeout_ms how many milliseconds to wait. /// -/// \return On success, return 0. On error, return error code. +/// \return 0 on success, error code on failure. class BpfBufferPoll : public WasmEdge::Runtime::HostFunction { public: BpfBufferPoll(state_t state) : state(state) {} diff --git a/plugins/wasm_bpf/func-bpf-map-fd-by-name.h b/plugins/wasm_bpf/func-bpf-map-fd-by-name.h index f4d93e1e..e155d96c 100644 --- a/plugins/wasm_bpf/func-bpf-map-fd-by-name.h +++ b/plugins/wasm_bpf/func-bpf-map-fd-by-name.h @@ -13,9 +13,9 @@ namespace WasmEdge { namespace Host { -/// \brief Lookup a map fd by its name. +/// \brief Look up a map fd by its name. /// -/// Map fd is returned if succeed, others if failed. +/// Returns the map fd on success; other values indicate failure. class BpfMapFdByName : public WasmEdge::Runtime::HostFunction { public: BpfMapFdByName(state_t state) : state(state) {} diff --git a/plugins/wasm_bpf/func-bpf-map-operate.h b/plugins/wasm_bpf/func-bpf-map-operate.h index 486f7d9b..c5779e2f 100644 --- a/plugins/wasm_bpf/func-bpf-map-operate.h +++ b/plugins/wasm_bpf/func-bpf-map-operate.h @@ -11,9 +11,9 @@ namespace WasmEdge { namespace Host { -/// Perform bpf map operations on a specified bpf map through map fd. +/// Perform BPF map operations on a specified BPF map through a map fd. /// -/// Return zero if succeed, others if error +/// Returns zero on success; other values indicate errors. class BpfMapOperate : public WasmEdge::Runtime::HostFunction { public: BpfMapOperate(state_t state) : state(state) {} diff --git a/plugins/wasm_bpf/func-close-bpf-object.h b/plugins/wasm_bpf/func-close-bpf-object.h index 1d8d8c12..ee94ba59 100644 --- a/plugins/wasm_bpf/func-close-bpf-object.h +++ b/plugins/wasm_bpf/func-close-bpf-object.h @@ -13,8 +13,8 @@ namespace WasmEdge { namespace Host { -/// \brief Close an opened bpf object. Will remove mapfds from the cache. -/// Return 0 if success. Others represent error codes. +/// \brief Close an opened BPF object and remove map fds from the cache. +/// Returns 0 on success; other values represent error codes. class CloseBpfObject : public WasmEdge::Runtime::HostFunction { public: CloseBpfObject(state_t state) : state(state) {} diff --git a/plugins/wasm_bpf/func-load-bpf-object.h b/plugins/wasm_bpf/func-load-bpf-object.h index 39516908..eb51cd8e 100644 --- a/plugins/wasm_bpf/func-load-bpf-object.h +++ b/plugins/wasm_bpf/func-load-bpf-object.h @@ -13,13 +13,13 @@ namespace WasmEdge { namespace Host { -/// \brief Load a bpf ELF file. +/// \brief Load a BPF ELF file. /// -/// Binary file should be provided through a Wasm Buffer. wasm_bpf will handle -/// the remaining process Call to this function will also cache bpf map fds. +/// A binary file should be provided through a Wasm buffer. wasm_bpf handles +/// the remaining process. Calling this function also caches BPF map fds. /// -/// \return a handle to a bpf program, which is stored in a map in the global -/// state. Return 0 if failed. +/// \return a handle to a BPF program, which is stored in a map in the global +/// state. Returns 0 on failure. class LoadBpfObject : public WasmEdge::Runtime::HostFunction { public: LoadBpfObject(state_t state) : state(state) {} diff --git a/plugins/wasm_bpf/state.h b/plugins/wasm_bpf/state.h index cd9d4720..2b15cd1e 100644 --- a/plugins/wasm_bpf/state.h +++ b/plugins/wasm_bpf/state.h @@ -13,7 +13,7 @@ namespace WasmEdge { namespace Host { struct WasmBpfState { - /// manage bpf programs + /// Manage BPF programs. std::unordered_map> handles; std::shared_mutex lock; ~WasmBpfState() noexcept = default; diff --git a/plugins/wasm_bpf/util.h b/plugins/wasm_bpf/util.h index 4ac79e44..d55161d9 100644 --- a/plugins/wasm_bpf/util.h +++ b/plugins/wasm_bpf/util.h @@ -8,8 +8,8 @@ namespace WasmEdge { namespace Host { -/// \brief read a c string from memory and check if it is null terminated -/// \param memory memory instance from wasm runtime +/// \brief Read a C string from memory and check whether it is null terminated. +/// \param memory memory instance from the wasm runtime. /// \param ptr the wasm32 buffer pointer /// \return WasmEdge::Expect diff --git a/plugins/wasm_bpf/wasm-bpf.cpp b/plugins/wasm_bpf/wasm-bpf.cpp index 1dd9b25f..fffe1125 100644 --- a/plugins/wasm_bpf/wasm-bpf.cpp +++ b/plugins/wasm_bpf/wasm-bpf.cpp @@ -45,7 +45,7 @@ static int32_t bpf_buffer_sample(void *ctx, void *data, size_t size) { namespace WasmEdge { namespace Host { -/// \brief initialize libbpf library +/// \brief Initialize the libbpf library. void init_libbpf(void) { libbpf_set_strict_mode(LIBBPF_STRICT_ALL); libbpf_set_print(libbpf_print_fn); @@ -156,7 +156,7 @@ int32_t bpf_buffer::bpf_buffer_sample(void *data, size_t size) { return WasmEdge_ValueGetI32(invoke_func_result); } -/// \brief create a bpf buffer based on the object map type +/// \brief Create a BPF buffer based on the object map type. std::unique_ptr bpf_buffer__new(bpf_map *events) { bpf_map_type map_type = bpf_map__type(events); switch (map_type) { @@ -173,7 +173,7 @@ std::unique_ptr bpf_buffer__new(bpf_map *events) { int32_t wasm_bpf_program::bpf_map_fd_by_name(const char *name) { return bpf_object__find_map_fd_by_name(obj.get(), name); } -/// \brief load all bpf programs and maps in an object file. +/// \brief Load all BPF programs and maps in an object file. int32_t wasm_bpf_program::load_bpf_object(const void *obj_buf, size_t obj_buf_sz) { auto object = bpf_object__open_mem(obj_buf, obj_buf_sz, nullptr); @@ -184,13 +184,13 @@ int32_t wasm_bpf_program::load_bpf_object(const void *obj_buf, return bpf_object__load(object); } -/// \brief attach a specific bpf program by name and target. +/// \brief Attach a specific BPF program by name and target. int32_t wasm_bpf_program::attach_bpf_program(const char *name, const char *attach_target) { bpf_link *link; if (!attach_target) { - // auto attach base on bpf_program__section_name. The works well for most - // bpf types, include kprobe, uprobe, fentry, lsm, etc. + // Auto-attach based on bpf_program__section_name. This works well for most + // BPF types, including kprobe, uprobe, fentry, lsm, etc. link = bpf_program__attach(bpf_object__find_program_by_name(obj.get(), name)); } else { @@ -200,11 +200,11 @@ int32_t wasm_bpf_program::attach_bpf_program(const char *name, spdlog::error("[WasmEdge Wasm_bpf] get prog {} fail"sv, name); return -1; } - // TODO: attach dynamically base on bpf_program__section_name(prog) and - // attach_target to support more attach type libbpf cannot auto attach. For - // example, if bpf_program__section_name(prog) is "xdp" and attach_target is - // "eth0", or attach sockops to a socket fd. For now, we will try auto - // attach as well. + // TODO: attach dynamically based on bpf_program__section_name(prog) and + // attach_target to support more attach types that libbpf cannot auto + // attach. For example, if bpf_program__section_name(prog) is "xdp" and + // attach_target is "eth0", or attach sockops to a socket fd. For now, we + // will try auto attach as well. link = bpf_program__attach(bpf_object__find_program_by_name(obj.get(), name)); } @@ -227,7 +227,7 @@ bpf_map *wasm_bpf_program::map_ptr_by_fd(int fd) { return nullptr; } -/// polling the buffer, if the buffer is not created, create it. +/// Poll the buffer; if the buffer is not created, create it. int32_t wasm_bpf_program::bpf_buffer_poll( WasmEdge_ExecutorContext *executor, const WasmEdge_ModuleInstanceContext *module_instance, int32_t fd, diff --git a/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp b/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp index 908c47c6..486916c2 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp @@ -89,7 +89,7 @@ Expect AVDictGet::body(const Runtime::CallingFrame &Frame, FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); - // If Dict Not created return (i.e. 0 is passed as AVDictId) + // Return if Dict was not created (i.e. 0 is passed as AVDictId). if (AvDict == nullptr) { return static_cast(ErrNo::InternalError); } @@ -124,7 +124,7 @@ Expect AVDictGetKeyValue::body( FFMPEG_PTR_FETCH(AvDict, DictId, AVDictionary *); - // If Dict Not created return (i.e. 0 is passed as AVDictId) + // Return if Dict was not created (i.e. 0 is passed as AVDictId). if (AvDict == nullptr) { return static_cast(ErrNo::InternalError); } diff --git a/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp b/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp index 6f0fab2d..d63cd248 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp +++ b/plugins/wasmedge_ffmpeg/avutil/avFrame.cpp @@ -266,7 +266,7 @@ AVFrameColorTransferCharacteristic::body(const Runtime::CallingFrame &, FFMPEG_PTR_FETCH(AvFrame, FrameId, AVFrame); AVColorTransferCharacteristic const Characteristic = AvFrame->color_trc; - // Can use the binding as well. Currently, Commented the binding. + // The binding can be used as well. Currently, the binding is commented out. return static_cast(Characteristic); } diff --git a/plugins/wasmedge_ffmpeg/avutil/avutil_func.h b/plugins/wasmedge_ffmpeg/avutil/avutil_func.h index 05a2c7b5..ebdee0e4 100644 --- a/plugins/wasmedge_ffmpeg/avutil/avutil_func.h +++ b/plugins/wasmedge_ffmpeg/avutil/avutil_func.h @@ -34,7 +34,7 @@ class AVLogGetFlags : public HostFunction { Expect body(const Runtime::CallingFrame &Frame); }; -// Option funcs. +// Option functions. class AVOptSetBin : public HostFunction { public: using HostFunction::HostFunction; diff --git a/plugins/wasmedge_ocr/ocr_env.h b/plugins/wasmedge_ocr/ocr_env.h index e4a0ddf4..231672a2 100644 --- a/plugins/wasmedge_ocr/ocr_env.h +++ b/plugins/wasmedge_ocr/ocr_env.h @@ -25,8 +25,8 @@ enum class ErrNo : uint32_t { class OCREnv { public: OCREnv() noexcept { - // check Tesseract API by initializing tesseract-ocr with English, without - // specifying tessdata path + // Check the Tesseract API by initializing tesseract-ocr with English + // without specifying the tessdata path. if (TesseractApi->Init(NULL, "eng")) { spdlog::error( "[WasmEdge-OCR] Error occurred when initializing tesseract."); diff --git a/plugins/wasmedge_opencvmini/opencvmini_func.h b/plugins/wasmedge_opencvmini/opencvmini_func.h index e2c80d76..587cc6a7 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_func.h +++ b/plugins/wasmedge_opencvmini/opencvmini_func.h @@ -80,7 +80,7 @@ class WasmEdgeOpenCVMiniImwrite uint32_t SrcMatKey); }; -/// This is not `cv::normalize`, refers to: +/// This is not `cv::normalize`; refer to: /// https://github.com/WasmEdge/WasmEdge/commit/77051da4995d7318d91a82102a72ce2557151764#diff-3333d926ca87cf4285bfcd6deae45ee310307be66fca8a4ca6f0f8a946743fccR50-R54 class WasmEdgeOpenCVMiniNormalize : public WasmEdgeOpenCVMini { diff --git a/plugins/wasmedge_process/processenv.h b/plugins/wasmedge_process/processenv.h index 07e9aa0b..a8a98d92 100644 --- a/plugins/wasmedge_process/processenv.h +++ b/plugins/wasmedge_process/processenv.h @@ -40,7 +40,7 @@ class WasmEdgeProcessEnvironment { /// Configurations /// Timeout in milliseconds. uint32_t TimeOut = DEFAULT_TIMEOUT; - /// Programs in white list. + /// Programs in the allowlist. std::unordered_set AllowedCmd; /// Flag to allow all programs. bool AllowedAll; diff --git a/plugins/wasmedge_process/processfunc.cpp b/plugins/wasmedge_process/processfunc.cpp index 5b9f3026..8cd5d942 100644 --- a/plugins/wasmedge_process/processfunc.cpp +++ b/plugins/wasmedge_process/processfunc.cpp @@ -100,7 +100,7 @@ Expect WasmEdgeProcessRun::body(const Runtime::CallingFrame &) { Env.StdErr.clear(); Env.ExitCode = static_cast(-1); - // Check white list of commands. + // Check the command allowlist. if (!Env.AllowedAll && Env.AllowedCmd.find(Env.Name) == Env.AllowedCmd.end()) { std::string Msg = "Permission denied: Command \""; @@ -142,7 +142,7 @@ Expect WasmEdgeProcessRun::body(const Runtime::CallingFrame &) { return Env.ExitCode; } - // Create a child process for executing command. + // Create a child process for executing a command. pid_t PID = fork(); if (PID == -1) { // Create process failed. @@ -272,7 +272,7 @@ Expect WasmEdgeProcessRun::body(const Runtime::CallingFrame &) { usleep(Env.DEFAULT_POLLTIME * 1000); } - // Read remained stdout and stderr. + // Read remaining stdout and stderr. do { RBytes = read(FDStdOut[0], Buf, sizeof(Buf)); if (RBytes > 0) { diff --git a/plugins/wasmedge_stablediffusion/sd_func.cpp b/plugins/wasmedge_stablediffusion/sd_func.cpp index f0cca7a3..9ad24fc4 100644 --- a/plugins/wasmedge_stablediffusion/sd_func.cpp +++ b/plugins/wasmedge_stablediffusion/sd_func.cpp @@ -150,7 +150,7 @@ void upscalerModel(const char *UpscaleModelPath, uint32_t UpscaleRepeats, free(CurrentImage.data); CurrentImage = UpscaledImage; } - // Set the final upscaled image as the result + // Set the final upscaled image as the result. Results[I] = CurrentImage; } } @@ -513,7 +513,7 @@ Expect SDImageToImage::body( free(InputImageBuffer); return static_cast(ErrNo::InvalidArgument); } - // Resize image when image size not matches width and height + // Resize image when its size does not match the width and height. if (Height != static_cast(ImageHeight) || Width != static_cast(ImageWidth)) { int ResizedHeight = Height; diff --git a/plugins/wasmedge_tensorflow/tensorflow_func.cpp b/plugins/wasmedge_tensorflow/tensorflow_func.cpp index d86143f4..2b8e6cd3 100644 --- a/plugins/wasmedge_tensorflow/tensorflow_func.cpp +++ b/plugins/wasmedge_tensorflow/tensorflow_func.cpp @@ -138,8 +138,8 @@ Expect CreateSessionSavedModel::body( Tags.reserve(TagsBufLen); TagsArgv.reserve(TagsBufLen); for (size_t I = 0; I < TagSpan.size(); ++I) { - // Should use std::string to copy the tag name here to prevent from no - // null-termination of the tag strings here. + // Use std::string to copy the tag name here and avoid relying on + // null-termination of the tag strings. const auto &Tag = TagSpan[I]; MEM_SV_CHECK(TagNameSV, MemInst, Tag.Ptr, Tag.Len, "Failed when accessing the tag name memory."sv) diff --git a/plugins/wasmedge_zlib/zlibenv.h b/plugins/wasmedge_zlib/zlibenv.h index a677d98a..407772ff 100644 --- a/plugins/wasmedge_zlib/zlibenv.h +++ b/plugins/wasmedge_zlib/zlibenv.h @@ -12,7 +12,7 @@ #include /** - * @brief A struct which maps perfectly to a wasm 32bit z_stream object + * @brief A struct that maps exactly to a 32-bit Wasm z_stream object * */ struct WasmZStream { diff --git a/plugins/wasmedge_zlib/zlibfunc.cpp b/plugins/wasmedge_zlib/zlibfunc.cpp index 73752350..b283f69a 100644 --- a/plugins/wasmedge_zlib/zlibfunc.cpp +++ b/plugins/wasmedge_zlib/zlibfunc.cpp @@ -1117,7 +1117,7 @@ WasmEdgeZlibDeflateInit_::body(const Runtime::CallingFrame &Frame, // ignore wasm custom allocators HostZStream->zalloc = Z_NULL; HostZStream->zfree = Z_NULL; - // ignore opaque since zmalloc and zfree was ignored + // Ignore opaque because zmalloc and zfree are ignored. HostZStream->opaque = Z_NULL; auto It = @@ -1152,7 +1152,7 @@ WasmEdgeZlibInflateInit_::body(const Runtime::CallingFrame &Frame, // ignore wasm custom allocators HostZStream->zalloc = Z_NULL; HostZStream->zfree = Z_NULL; - // ignore opaque since zmalloc and zfree was ignored + // Ignore opaque because zmalloc and zfree are ignored. HostZStream->opaque = Z_NULL; auto It = diff --git a/plugins/wasmedge_zlib/zlibfunc.h b/plugins/wasmedge_zlib/zlibfunc.h index b7ca1f05..21053cbb 100644 --- a/plugins/wasmedge_zlib/zlibfunc.h +++ b/plugins/wasmedge_zlib/zlibfunc.h @@ -114,7 +114,7 @@ class WasmEdgeZlibDeflateTune : public WasmEdgeZlib { }; // https://github.com/emscripten-core/emscripten/issues/17009 -// Using 32bit, because on wasm-side it will be 32bit long +// Use 32-bit because long is 32-bit wide on the Wasm side. class WasmEdgeZlibDeflateBound : public WasmEdgeZlib { public: WasmEdgeZlibDeflateBound(WasmEdgeZlibEnvironment &HostEnv) diff --git a/test/plugins/CMakeLists.txt b/test/plugins/CMakeLists.txt index 457b1653..7ead52f7 100644 --- a/test/plugins/CMakeLists.txt +++ b/test/plugins/CMakeLists.txt @@ -16,7 +16,7 @@ endif() # WasmEdge plug-in: wasm-bpf. if(WASMEDGE_PLUGIN_WASM_BPF) - # Only Linux systems support wasm_bpf now. + # wasm_bpf is currently supported only on Linux systems. if(CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasm_bpf) else() @@ -31,7 +31,7 @@ endif() # WasmEdge plug-in: Image. if(WASMEDGE_PLUGIN_IMAGE) - # Only Linux and MacOS support wasmedge_image now. + # wasmedge_image is currently supported only on Linux and macOS. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_image) else() @@ -46,7 +46,7 @@ endif() # WasmEdge plug-in: OpenCV-mini. if(WASMEDGE_PLUGIN_OPENCVMINI) - # Only Linux and MacOS support wasmedge_opencvmini now. + # wasmedge_opencvmini is currently supported only on Linux and macOS. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_opencvmini) else() @@ -56,7 +56,7 @@ endif() # WasmEdge plug-in: Process. if(WASMEDGE_PLUGIN_PROCESS) - # Only Linux systems support wasmedge_process now. + # wasmedge_process is currently supported only on Linux systems. if(CMAKE_SYSTEM_NAME MATCHES "Linux") add_subdirectory(wasmedge_process) else() @@ -66,7 +66,7 @@ endif() # WasmEdge plug-in: Stable-diffusion. if(WASMEDGE_PLUGIN_STABLEDIFFUSION) - # Only Linux and MacOS support wasmedge_stablediffusion now. + # wasmedge_stablediffusion is currently supported only on Linux and macOS. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_stablediffusion) else() @@ -76,7 +76,7 @@ endif() # WasmEdge plug-in: TensorFlow. if(WASMEDGE_PLUGIN_TENSORFLOW) - # Only Linux and MacOS support wasmedge_tensorflow now. + # wasmedge_tensorflow is currently supported only on Linux and macOS. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_tensorflow) else() @@ -86,7 +86,7 @@ endif() # WasmEdge plug-in: TensorFlow-Lite. if(WASMEDGE_PLUGIN_TENSORFLOWLITE) - # Only Linux and MacOS support wasmedge_tensorflowlite now. + # wasmedge_tensorflowlite is currently supported only on Linux and macOS. if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") add_subdirectory(wasmedge_tensorflowlite) else() diff --git a/test/plugins/unittest/testplugin.c b/test/plugins/unittest/testplugin.c index a2be6472..f9dac55f 100644 --- a/test/plugins/unittest/testplugin.c +++ b/test/plugins/unittest/testplugin.c @@ -72,13 +72,13 @@ CreateTestModule(const struct WasmEdge_ModuleDescriptor *Desc) { ParamTypes[1] = WasmEdge_ValTypeGenI32(); ReturnTypes[0] = WasmEdge_ValTypeGenI32(); - /* Create the "add" function and add into the module instance. */ + /* Create the "add" function and add it to the module instance. */ FType = WasmEdge_FunctionTypeCreate(ParamTypes, 2, ReturnTypes, 1); FuncName = WasmEdge_StringCreateByCString("add"); FuncCxt = WasmEdge_FunctionInstanceCreate(FType, HostFuncAdd, Accumulate, 0); WasmEdge_ModuleInstanceAddFunction(Mod, FuncName, FuncCxt); WasmEdge_StringDelete(FuncName); - /* Create the "sub" function and add into the module instance. */ + /* Create the "sub" function and add it to the module instance. */ FuncName = WasmEdge_StringCreateByCString("sub"); FuncCxt = WasmEdge_FunctionInstanceCreate(FType, HostFuncSub, Accumulate, 0); WasmEdge_ModuleInstanceAddFunction(Mod, FuncName, FuncCxt); diff --git a/test/plugins/wasi_crypto/helper.h b/test/plugins/wasi_crypto/helper.h index 0fd03a50..19b1425f 100644 --- a/test/plugins/wasi_crypto/helper.h +++ b/test/plugins/wasi_crypto/helper.h @@ -137,7 +137,7 @@ class WasiCryptoTest : public ::testing::Test { WasiCryptoExpect optionsSetU64(__wasi_options_t OptionsHandle, std::string_view Name, uint64_t Value); - // Not supported, Buf placing must be on page. + // Not supported, buffer placement must be on a page. // WasiCryptoExpect // optionsSetGuestBuffer(__wasi_options_t OptionsHandle, // std::string_view Name, Span Buf); diff --git a/test/plugins/wasi_logging/wasi_logging.cpp b/test/plugins/wasi_logging/wasi_logging.cpp index 706aeada..b523a177 100644 --- a/test/plugins/wasi_logging/wasi_logging.cpp +++ b/test/plugins/wasi_logging/wasi_logging.cpp @@ -52,7 +52,7 @@ void fillMemContent(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, TEST(WasiLoggingTests, func_log) { using namespace std::literals::string_view_literals; // Create the wasi-logging module instance. - // Here create 2 wasi-logging modules for testing in multiple modules. + // Create two wasi-logging modules for testing multiple modules. auto WasiLoggingMod1 = createModule(); ASSERT_TRUE(WasiLoggingMod1); auto WasiLoggingMod2 = createModule(); diff --git a/test/plugins/wasi_nn/wasi_nn.cpp b/test/plugins/wasi_nn/wasi_nn.cpp index ad580eb3..514a0923 100644 --- a/test/plugins/wasi_nn/wasi_nn.cpp +++ b/test/plugins/wasi_nn/wasi_nn.cpp @@ -105,7 +105,7 @@ std::vector classSort(WasmEdge::Span Array) { std::iota(Indices.begin(), Indices.end(), 0); std::sort(Indices.begin(), Indices.end(), [&Array](size_t Left, size_t Right) -> bool { - // Sort indices according to corresponding array element. + // Sort indices according to the corresponding array elements. return Array[Left] > Array[Right]; }); return Indices; @@ -274,7 +274,7 @@ TEST(WasiNNTest, OpenVINOBackend) { static_cast(ErrNo::InvalidArgument)); } - // Test: load -- wrong builders' length. + // Test: load -- wrong builder count. BuilderPtr = LoadEntryPtr; writeFatPointer(MemInst, StorePtr, static_cast(XmlRead.size()), BuilderPtr); @@ -361,7 +361,7 @@ TEST(WasiNNTest, OpenVINOBackend) { NNGraphTmp.swap(NNMod->getEnv().NNGraph); NNContextTmp.swap(NNMod->getEnv().NNContext); - // Test: init_execution_context -- init context successfully. + // Test: init_execution_context -- initialize context successfully. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -372,7 +372,7 @@ TEST(WasiNNTest, OpenVINOBackend) { BuilderPtr += 4; } - // Test: init_execution_context -- init second context. + // Test: init_execution_context -- initialize the second context. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -677,7 +677,7 @@ TEST(WasiNNTest, PyTorchBackend) { static_cast(ErrNo::InvalidArgument)); } - // Test: load -- wrong builders' length. + // Test: load -- wrong builder count. BuilderPtr = LoadEntryPtr; writeFatPointer(MemInst, StorePtr, static_cast(WeightRead.size()), BuilderPtr); @@ -747,7 +747,7 @@ TEST(WasiNNTest, PyTorchBackend) { NNGraphTmp.emplace_back(Backend::PyTorch); NNGraphTmp.back().setReady(); // Test: init_execution_context -- graph id exceeds. - // TODO: not null test for pytorch now + // TODO: add a non-null test for PyTorch. // NNGraphTmp.swap(NNMod->getEnv().NNGraph); // NNContextTmp.swap(NNMod->getEnv().NNContext); // { @@ -762,7 +762,7 @@ TEST(WasiNNTest, PyTorchBackend) { // NNGraphTmp.swap(NNMod->getEnv().NNGraph); // NNContextTmp.swap(NNMod->getEnv().NNContext); - // Test: init_execution_context -- init context successfully. + // Test: init_execution_context -- initialize context successfully. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -773,7 +773,7 @@ TEST(WasiNNTest, PyTorchBackend) { BuilderPtr += 4; } - // Test: init_execution_context -- init second context. + // Test: init_execution_context -- initialize the second context. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -1056,7 +1056,7 @@ TEST(WasiNNTest, TFLiteBackend) { static_cast(ErrNo::InvalidArgument)); } - // Test: load -- wrong builders' length. + // Test: load -- wrong builder count. BuilderPtr = LoadEntryPtr; writeFatPointer(MemInst, StorePtr, static_cast(WeightRead.size()), BuilderPtr); @@ -1144,7 +1144,7 @@ TEST(WasiNNTest, TFLiteBackend) { NNGraphTmp.swap(NNMod->getEnv().NNGraph); NNContextTmp.swap(NNMod->getEnv().NNContext); - // Test: init_execution_context -- init context successfully. + // Test: init_execution_context -- initialize context successfully. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -1155,7 +1155,7 @@ TEST(WasiNNTest, TFLiteBackend) { BuilderPtr += 4; } - // Test: init_execution_context -- init second context. + // Test: init_execution_context -- initialize the second context. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -1458,7 +1458,7 @@ TEST(WasiNNTest, GGMLBackend) { static_cast(ErrNo::InvalidArgument)); } - // Test: init_execution_context -- init context successfully. + // Test: init_execution_context -- initialize context successfully. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -1695,7 +1695,7 @@ TEST(WasiNNTest, GGMLBackendWithRPC) { EXPECT_NE(Errno[0].get(), static_cast(ErrNo::Success)); } - // Test: init_execution_context -- init context successfully. + // Test: init_execution_context -- initialize context successfully. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -1920,7 +1920,7 @@ TEST(WasiNNTest, GGMLBackendComputeSingleWithRPC) { BuilderPtr += 4; } - // Test: init_execution_context -- init context successfully. + // Test: init_execution_context -- initialize context successfully. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -2147,7 +2147,7 @@ TEST(WasiNNTest, WhisperBackend) { static_cast(ErrNo::InvalidArgument)); } - // Test: init_execution_context -- init second context. + // Test: init_execution_context -- initialize the second context. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -2386,7 +2386,7 @@ TEST(WasiNNTest, PiperBackend) { static_cast(ErrNo::InvalidArgument)); } - // Test: init_execution_context -- init context successfully. + // Test: init_execution_context -- initialize context successfully. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -2506,7 +2506,7 @@ TEST(WasiNNTest, PiperBackend) { BuilderPtr += 4; } - // Test: init_execution_context -- init context successfully. + // Test: init_execution_context -- initialize context successfully. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -2517,7 +2517,7 @@ TEST(WasiNNTest, PiperBackend) { BuilderPtr += 4; } - // First json input with parameters overridden + // First JSON input with parameters overridden. Text = "{\"text\": \"This is a test.\", \"noise_scale\": 0.0, " "\"length_scale\": 2.0, \"noise_w\": 0.0}"; TensorData = {Text.begin(), Text.end()}; @@ -2566,7 +2566,7 @@ TEST(WasiNNTest, PiperBackend) { EXPECT_GE(BytesWritten, 40000); } - // Second json input to check if one-time overriding is working properly + // Second JSON input to check if one-time overriding works properly. Text = "{\"text\": \"This is a test.\", \"output_type\": \"raw\", " "\"noise_scale\": 0.0, \"noise_w\": 0.0}"; TensorData = {Text.begin(), Text.end()}; @@ -2739,7 +2739,7 @@ TEST(WasiNNTest, ChatTTSBackend) { static_cast(ErrNo::InvalidArgument)); } - // Test: init_execution_context -- init second context. + // Test: init_execution_context -- initialize the second context. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -3020,7 +3020,7 @@ TEST(WasiNNTest, MLXBackend) { static_cast(ErrNo::InvalidArgument)); } - // Test: init_execution_context -- init second context. + // Test: init_execution_context -- initialize the second context. { EXPECT_TRUE(HostFuncInit.run( CallFrame, @@ -3288,7 +3288,7 @@ TEST(WasiNNTest, BitNetBackend) { EXPECT_EQ(Errno[0].get(), static_cast(ErrNo::InvalidArgument)); } - // Test: init_execution_context -- init context successfully. + // Test: init_execution_context -- initialize context successfully. { ASSERT_TRUE(HostFuncInit.run( CallFrame, diff --git a/test/plugins/wasm_bpf/assets/README.md b/test/plugins/wasm_bpf/assets/README.md index f2435fff..fcee89d2 100644 --- a/test/plugins/wasm_bpf/assets/README.md +++ b/test/plugins/wasm_bpf/assets/README.md @@ -1,12 +1,13 @@ # wasm_bpf Plugin tests -This file contains bpf programs that will be used during testing. +This file contains BPF programs that will be used during testing. - `bootstrap` and `runqlat`: examples copied from `wasm-bpf`. See [here](https://github.com/eunomia-bpf/wasm-bpf/tree/main/examples) for build instructions. -- `simple_ringbuf`: A simple ebpf program which writes fixed data to a ring buffer -- `simple_map`: A simple ebpf program which stores fixed data to a bpf map +- `simple_ringbuf`: A simple eBPF program that writes fixed data to a ring buffer +- `simple_map`: A simple eBPF program that stores fixed data to a BPF map -The source of `simple_ringbuf` and `simple_map` are listed under `bpf-sources`. Run `make` under that directory to build them. +The sources of `simple_ringbuf` and `simple_map` are listed under +`bpf-sources`. Run `make` under that directory to build them. `libbpf` and `clang` are required to build them. diff --git a/test/plugins/wasm_bpf/simple_map_test.cpp b/test/plugins/wasm_bpf/simple_map_test.cpp index b6c4c00c..f8b84e4e 100644 --- a/test/plugins/wasm_bpf/simple_map_test.cpp +++ b/test/plugins/wasm_bpf/simple_map_test.cpp @@ -53,7 +53,7 @@ static const uint32_t RESULT_VALUE_KEY = 0x7890; TEST(WasmBpfTest, SimpleMapTest) { using namespace std::string_view_literals; - // Test loading and attaching a bpf program, and some operations of maps + // Test loading and attaching a BPF program and some map operations. auto module = dynamic_cast(createModule()); ASSERT_NE(module, nullptr); @@ -72,10 +72,10 @@ TEST(WasmBpfTest, SimpleMapTest) { namespace fs = std::filesystem; auto bpfObject = getAssertsPath() / "simple_map.bpf.o"; - // Ensure the bpf object we need exists + // Ensure the BPF object we need exists. ASSERT_TRUE(fs::exists(bpfObject)); - // Read the bpf object into wasm memory + // Read the BPF object into Wasm memory. std::ifstream bpfObjStream(bpfObject); ASSERT_TRUE(bpfObjStream.is_open()); ASSERT_TRUE(bpfObjStream.good()); @@ -83,15 +83,15 @@ TEST(WasmBpfTest, SimpleMapTest) { (std::istreambuf_iterator(bpfObjStream)), std::istreambuf_iterator()); ASSERT_FALSE(bpfObjectBytes.empty()); - // Offset to put things into memory + // Offset used to place data in memory. uint32_t nextOffset = 1; - // Put the bpf object into memory + // Put the BPF object in memory. const uint32_t bpfObjectMemoryOffset = nextOffset; fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); nextOffset += static_cast(bpfObjectBytes.size()); - // Fill strings that will be used into memory + // Write the strings to memory. std::array strings = { "test_map", // Map name "sched_wakeup", // Program names @@ -202,7 +202,7 @@ TEST(WasmBpfTest, SimpleMapTest) { key, value, 0, 0); }; - // Helper functions to make read & write more convenient + // Helper functions that make reading and writing more convenient. auto readU64 = [&](uint32_t offset) -> uint64_t { const auto *ptr = memoryInstRef.getPointer(offset); EXPECT_NE(ptr, nullptr); @@ -220,8 +220,8 @@ TEST(WasmBpfTest, SimpleMapTest) { *ptr = val; }; - // Generate two numbers, which will be stored in the map and calculated the - // summation by the ebpf program + // Generate two numbers, which will be stored in the map and summed by the + // eBPF program. std::mt19937 randGen; randGen.seed(std::random_device()()); std::uniform_int_distribution intDist(0, @@ -229,7 +229,7 @@ TEST(WasmBpfTest, SimpleMapTest) { uint64_t num1 = intDist(randGen); uint64_t num2 = intDist(randGen); - // Prepare for wasm memory which is used to store numbers + // Prepare Wasm memory to store numbers. const uint32_t numOffset1 = nextOffset; nextOffset += 8; const uint32_t numOffset2 = nextOffset; @@ -257,16 +257,16 @@ TEST(WasmBpfTest, SimpleMapTest) { writeU64(resultOffset, 0); - // Write the add values into the map + // Write the addend values into the map. ASSERT_EQ(mapUpdateElem(mapFd, num1KeyOffset, numOffset1), 0); ASSERT_EQ(mapUpdateElem(mapFd, num2KeyOffset, numOffset2), 0); - // Write the indicating key - // Arbitrary values are correct. We only care the existence of the - // indicating key + // Write the indicator key. + // Arbitrary values are correct. We only care about the existence of the + // indicator key. ASSERT_EQ(mapUpdateElem(mapFd, indicatingKeyOffset, numOffset1), 0); - // Sleep for 1s and wait for the ebpf program to process.. + // Sleep for 1s and wait for the eBPF program to process. std::this_thread::sleep_for(std::chrono::seconds(2)); // Read the result and check it diff --git a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp index 81fc935b..e2fbf94c 100644 --- a/test/plugins/wasm_bpf/simple_ringbuf_test.cpp +++ b/test/plugins/wasm_bpf/simple_ringbuf_test.cpp @@ -69,7 +69,7 @@ class PollCallbackFunction TEST(WasmBpfTest, SimpleRingbuf) { using namespace std::string_view_literals; - // Test loading and attaching a bpf program, and polling buffer + // Test loading and attaching a BPF program and polling a buffer. auto module = dynamic_cast(createModule()); ASSERT_NE(module, nullptr); @@ -88,10 +88,10 @@ TEST(WasmBpfTest, SimpleRingbuf) { namespace fs = std::filesystem; auto bpfObject = getAssertsPath() / "simple_ringbuf.bpf.o"; - // Ensure the bpf object we need exists + // Ensure the BPF object we need exists. ASSERT_TRUE(fs::exists(bpfObject)); - // Read the bpf object into wasm memory + // Read the BPF object into Wasm memory. std::ifstream bpfObjStream(bpfObject); ASSERT_TRUE(bpfObjStream.is_open()); ASSERT_TRUE(bpfObjStream.good()); @@ -99,15 +99,15 @@ TEST(WasmBpfTest, SimpleRingbuf) { (std::istreambuf_iterator(bpfObjStream)), std::istreambuf_iterator()); ASSERT_FALSE(bpfObjectBytes.empty()); - // Offset to put things into memory + // Offset used to place data in memory. uint32_t nextOffset = 1; - // Put the bpf object into memory + // Put the BPF object in memory. const uint32_t bpfObjectMemoryOffset = nextOffset; fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); nextOffset += static_cast(bpfObjectBytes.size()); - // Fill strings that will be used into memory + // Write the strings to memory. std::array strings = { "rb", // Map name "handle_exec", // Program names @@ -184,13 +184,13 @@ TEST(WasmBpfTest, SimpleRingbuf) { auto mapFd = mapFdResult[0].get(); ASSERT_GE(mapFd, 0); - // In the following several steps we will prepare for polling - // Create an instance of the polling callback function + // In the following steps we prepare for polling. + // Create an instance of the polling callback function. moduleInst.addHostFunc("__polling_callback_hostfunc"sv, std::make_unique()); auto *callbackFuncInst = moduleInst.findFuncExports("__polling_callback_hostfunc"); - // Create a function table, and fill the callback function into it + // Create a function table and fill it with the callback function. auto funcTableInst = std::make_unique( WasmEdge::AST::TableType(WasmEdge::TypeCode::FuncRef, 1)); diff --git a/test/plugins/wasm_bpf/wasm_bpf.cpp b/test/plugins/wasm_bpf/wasm_bpf.cpp index 827a491d..34f8b116 100644 --- a/test/plugins/wasm_bpf/wasm_bpf.cpp +++ b/test/plugins/wasm_bpf/wasm_bpf.cpp @@ -134,7 +134,7 @@ class PollCallbackFunction TEST(WasmBpfTest, RunBpfProgramWithPolling) { using namespace std::literals::string_view_literals; - // Test loading and attaching a bpf program, and polling buffer + // Test loading and attaching a BPF program and polling a buffer. auto module = createModule(); ASSERT_TRUE(module); @@ -153,10 +153,10 @@ TEST(WasmBpfTest, RunBpfProgramWithPolling) { namespace fs = std::filesystem; auto bpfObject = getAssertsPath() / "bootstrap.bpf.o"; - // Ensure the bpf object we need exists + // Ensure the BPF object we need exists. EXPECT_TRUE(fs::exists(bpfObject)); - // Read the bpf object into wasm memory + // Read the BPF object into Wasm memory. std::ifstream bpfObjStream(bpfObject); EXPECT_TRUE(bpfObjStream.is_open()); EXPECT_TRUE(bpfObjStream.good()); @@ -165,11 +165,11 @@ TEST(WasmBpfTest, RunBpfProgramWithPolling) { std::istreambuf_iterator()); EXPECT_FALSE(bpfObjectBytes.empty()); - // Fill bpf object into memory + // Fill memory with the BPF object. const uint32_t bpfObjectMemoryOffset = 1; fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); - // Fill `handle_exec`, the bpf function name, into memory + // Write `handle_exec`, the BPF function name, to memory. const uint32_t targetHandleExecNameMemoryOffset = bpfObjectMemoryOffset + static_cast(bpfObjectBytes.size()); const std::string targetHandleExecName("handle_exec"); @@ -181,7 +181,7 @@ TEST(WasmBpfTest, RunBpfProgramWithPolling) { fillMemContent(memoryInstRef, targetHandleExecNameMemoryOffset, targetHandleExecNameBytes); - // Fill `handle_exit`, the bpf function name, into memory + // Write `handle_exit`, the BPF function name, to memory. const uint32_t targetHandleExitNameMemoryOffset = targetHandleExecNameMemoryOffset + static_cast(targetHandleExecNameBytes.size()); @@ -204,7 +204,7 @@ TEST(WasmBpfTest, RunBpfProgramWithPolling) { std::copy(mapName.begin(), mapName.end(), mapNameBytes.begin()); fillMemContent(memoryInstRef, mapNameMemoryOffset, mapNameBytes); - // Prepare a memory area for storing polled things + // Prepare a memory area for storing polled items. const uint32_t bufferPollMemoryOffset = mapNameMemoryOffset + static_cast(mapNameBytes.size()); const uint32_t bufferPollSize = 1024; @@ -280,13 +280,13 @@ TEST(WasmBpfTest, RunBpfProgramWithPolling) { auto mapFd = mapFdResult[0].get(); EXPECT_GE(mapFd, 0); - // In the following several steps we will prepare for polling - // Create an instance of the polling callback function + // In the following steps we prepare for polling. + // Create an instance of the polling callback function. moduleInst.addHostFunc("__polling_callback_hostfunc"sv, std::make_unique()); auto *callbackFuncInst = moduleInst.findFuncExports("__polling_callback_hostfunc"); - // Create a function table, and fill the callback function into it + // Create a function table and fill it with the callback function. auto funcTableInst = std::make_unique( WasmEdge::AST::TableType(WasmEdge::TypeCode::FuncRef, 1)); @@ -348,7 +348,7 @@ struct hist { } __attribute__((packed)); TEST(WasmBpfTest, RunBpfProgramWithMapOperation) { - // Test loading and attaching a bpf program, and polling buffer + // Test loading and attaching a BPF program and polling a buffer. auto module = createModule(); ASSERT_TRUE(module); @@ -366,10 +366,10 @@ TEST(WasmBpfTest, RunBpfProgramWithMapOperation) { namespace fs = std::filesystem; auto bpfObject = getAssertsPath() / "runqlat.bpf.o"; - // Ensure the bpf object we need exists + // Ensure the BPF object we need exists. EXPECT_TRUE(fs::exists(bpfObject)); - // Read the bpf object into wasm memory + // Read the BPF object into Wasm memory. std::ifstream bpfObjStream(bpfObject); EXPECT_TRUE(bpfObjStream.is_open()); EXPECT_TRUE(bpfObjStream.good()); @@ -377,15 +377,15 @@ TEST(WasmBpfTest, RunBpfProgramWithMapOperation) { (std::istreambuf_iterator(bpfObjStream)), std::istreambuf_iterator()); EXPECT_FALSE(bpfObjectBytes.empty()); - // Offset to put things into memory + // Offset used to place data in memory. uint32_t nextOffset = 1; - // Put the bpf object into memory + // Put the BPF object in memory. const uint32_t bpfObjectMemoryOffset = nextOffset; fillMemContent(memoryInstRef, bpfObjectMemoryOffset, bpfObjectBytes); nextOffset += static_cast(bpfObjectBytes.size()); - // Fill strings that will be used into memory + // Write the strings to memory. std::array strings = { "hists", // Map name "sched_wakeup", "sched_wakeup_new", "sched_switch", // Program names @@ -482,7 +482,7 @@ TEST(WasmBpfTest, RunBpfProgramWithMapOperation) { callResult)); return callResult[0].get(); }; - // Three helper functions that will be used + // Three helper functions used below. auto mapGetNextKey = [&](int32_t fd, uint32_t lookupKey, uint32_t nextKey) -> int32_t { // lookupKey is the last element -> returns -1 @@ -508,7 +508,7 @@ TEST(WasmBpfTest, RunBpfProgramWithMapOperation) { 3, // BPF_MAP_DELETE_ELEM key, 0, 0, 0); }; - // Three helper functions to make read & write more convenient + // Three helper functions that make reading and writing more convenient. auto readU32 = [&](uint32_t offset) -> uint32_t { const auto *ptr = memoryInstRef.getPointer(offset); EXPECT_NE(ptr, nullptr); diff --git a/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp b/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp index e85bb769..f82fb2d6 100644 --- a/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/avcodec/avcodec_func.cpp @@ -12,7 +12,7 @@ namespace WasmEdge { namespace Host { namespace WasmEdgeFFmpeg { -// TODO: Commented functions need to be tested. +// TODO: Commented functions need tests. TEST_F(FFmpegTest, AVCodecFunc) { ASSERT_TRUE(AVCodecMod != nullptr); @@ -270,7 +270,7 @@ TEST_F(FFmpegTest, AVCodecFunc) { EXPECT_EQ(Result[0].get(), 0); } - // TODO: Need FormatCtxId To test this func. + // TODO: Need FormatCtxId to test this function. // FuncInst = AVCodecMod->findFuncExports( // "wasmedge_ffmpeg_avcodec_avcodec_parameters_copy"); // EXPECT_NE(FuncInst, nullptr); diff --git a/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp b/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp index 788e7495..809a25d9 100644 --- a/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/avformat/avformat_func.cpp @@ -434,8 +434,8 @@ TEST_F(FFmpegTest, AVOutputFormatFunc) { EXPECT_TRUE(Result[0].get() >= 0); } - // TODO: This test modifies the input file. Unable to test. - // Added test on rust side. + // TODO: This test modifies the input file, so it cannot be tested. + // Added test on the Rust side. // spdlog::info("Testing AVGuessCodec"sv); // uint32_t EmptyStrPtr = UINT32_C(520); // writeUInt32(MemInst, 0, EmptyStrPtr); @@ -472,9 +472,9 @@ TEST_F(FFmpegTest, AVOutputFormatFunc) { EXPECT_EQ(Result[0].get(), -22); } - // Write Header above return invalid argument due to which below test won't - // work. The OutputFormatContext should Be configured using the input format - // context. Test on the Rust side. This is working as expected. + // Writing the header above returns an invalid argument, so the test below + // does not work. The OutputFormatContext should be configured using the input + // format context. Test this on the Rust side. This is working as expected. // FuncInst = AVFormatMod->findFuncExports( // "wasmedge_ffmpeg_avformat_avformat_write_trailer"); @@ -516,7 +516,7 @@ TEST_F(FFmpegTest, AVOutputFormatFunc) { // auto &HostFuncAVChapterDynarrayAdd = dynamic_cast< // WasmEdge::Host::WasmEdgeFFmpeg::AVFormat::AVChapterDynarrayAdd &>( // FuncInst->getHostFunc()); - // // For the give input file, nb_chapter is 0; + // // For the given input file, nb_chapter is 0; // { // uint32_t AvChapterId = readUInt32(MemInst, AvFormatCtxPtr); // uint32_t AvFormatCtxId = readUInt32(MemInst, AvFormatCtxPtr); diff --git a/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp b/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp index bba090c8..779fb236 100644 --- a/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp +++ b/test/plugins/wasmedge_ffmpeg/avutil/avDictionary.cpp @@ -77,7 +77,7 @@ TEST_F(FFmpegTest, AVDictionary) { FuncInst->getHostFunc()); { - // Store the string length of Key and value in below Pointers. + // Store the string lengths of Key and value in the pointers below. uint32_t KeyLenPtr = UINT32_C(56); uint32_t ValueLenPtr = UINT32_C(60); uint32_t DictId = readUInt32(MemInst, DictPtr); @@ -111,7 +111,7 @@ TEST_F(FFmpegTest, AVDictionary) { FuncInst->getHostFunc()); { - // Store the string of Key and value in below Buffer Pointers. + // Store the strings of Key and value in the buffer pointers below. uint32_t KeyBufPtr = UINT32_C(36); uint32_t ValueBufPtr = UINT32_C(40); uint32_t DictId = readUInt32(MemInst, DictPtr); diff --git a/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp b/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp index 2f0fa314..805b2224 100644 --- a/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp +++ b/test/plugins/wasmedge_ffmpeg/swscale/swscale_func.cpp @@ -13,7 +13,7 @@ namespace Host { namespace WasmEdgeFFmpeg { // ============================================================================ -// This test deals with funcs related to SwsContext +// This test deals with functions related to SwsContext. // ============================================================================ TEST_F(FFmpegTest, SwsContext) { @@ -114,13 +114,13 @@ TEST_F(FFmpegTest, SwsContext) { FuncInst->getHostFunc()); { - // AV_PIX_FMT_RGB24 is supported Pixel Format + // AV_PIX_FMT_RGB24 is a supported pixel format. EXPECT_TRUE(HostFuncSwsIsSupportedInput.run( CallFrame, std::initializer_list{RGB24Id}, Result)); ASSERT_TRUE(Result[0].get() > 0); - // AV_PIX_FMT_XVMC is not supported Pixel Format + // AV_PIX_FMT_XVMC is not a supported pixel format. EXPECT_TRUE(HostFuncSwsIsSupportedInput.run( CallFrame, std::initializer_list{XVMCId}, Result)); @@ -136,13 +136,13 @@ TEST_F(FFmpegTest, SwsContext) { FuncInst->getHostFunc()); { - // AV_PIX_FMT_RGB24 is supported Pixel Format + // AV_PIX_FMT_RGB24 is a supported pixel format. EXPECT_TRUE(HostFuncSwsIsSupportedOutput.run( CallFrame, std::initializer_list{RGB24Id}, Result)); ASSERT_TRUE(Result[0].get() > 0); - // AV_PIX_FMT_XVMC is not supported Pixel Format + // AV_PIX_FMT_XVMC is not a supported pixel format. EXPECT_TRUE(HostFuncSwsIsSupportedOutput.run( CallFrame, std::initializer_list{XVMCId}, Result)); @@ -159,7 +159,7 @@ TEST_F(FFmpegTest, SwsContext) { FuncInst->getHostFunc()); { - // AV_PIX_FMT_XVMC is not supported Pixel Format for + // AV_PIX_FMT_XVMC is not a supported pixel format for EXPECT_TRUE(HostFuncSwsIsSupportedEndiannessConversion.run( CallFrame, std::initializer_list{XVMCId}, Result)); @@ -198,7 +198,7 @@ TEST_F(FFmpegTest, SwsContext) { } // ============================================================================ -// This test deals with funcs related to SwsFilter. +// This test deals with functions related to SwsFilter. // ============================================================================ TEST_F(FFmpegTest, SwsFilter) { @@ -319,7 +319,7 @@ TEST_F(FFmpegTest, SwsFilter) { } // ============================================================================ -// This test deals with funcs related to SwsVector. +// This test deals with functions related to SwsVector. // ============================================================================ TEST_F(FFmpegTest, SwsVector) { @@ -453,7 +453,8 @@ TEST_F(FFmpegTest, SwsVector) { } // ============================================================================ -// This test deals with funcs related to Version, Configuration and License +// This test deals with functions related to Version, Configuration, and +// License. // ============================================================================ TEST_F(FFmpegTest, SWScaleVersion) { diff --git a/test/plugins/wasmedge_ffmpeg/utils.h b/test/plugins/wasmedge_ffmpeg/utils.h index ce4d771c..79b00242 100644 --- a/test/plugins/wasmedge_ffmpeg/utils.h +++ b/test/plugins/wasmedge_ffmpeg/utils.h @@ -138,7 +138,7 @@ class FFmpegTest : public ::testing::Test { std::string FileName); void allocPacket(uint32_t PacketPtr); - // Result of Funcs to be stored here. + // Results of Funcs are stored here. std::array Result = {UINT32_C(0)}; // Create the calling frame with memory instance. diff --git a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp index d2905c80..c9b07871 100644 --- a/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp +++ b/test/plugins/wasmedge_opencvmini/wasmedge_opencvmini.cpp @@ -42,7 +42,7 @@ std::unique_ptr createModule() { } // namespace -// TODO: unit tests for every functions. +// TODO: add unit tests for every function. TEST(WasmEdgeOpecvminiTest, Module) { // Create the wasmedge_opencvmini module instance. diff --git a/test/plugins/wasmedge_process/wasmedge_process.cpp b/test/plugins/wasmedge_process/wasmedge_process.cpp index 6791c9d8..ea80f906 100644 --- a/test/plugins/wasmedge_process/wasmedge_process.cpp +++ b/test/plugins/wasmedge_process/wasmedge_process.cpp @@ -331,7 +331,7 @@ TEST(WasmEdgeProcessTest, Run) { // Return value. std::array RetVal; - // Test: Run function failed to run "c++" without allowing all commands. + // Test: Run function fails to run "c++" without allowing all commands. ProcMod->getEnv().AllowedAll = false; ProcMod->getEnv().Name = "c++"; EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); @@ -345,7 +345,7 @@ TEST(WasmEdgeProcessTest, Run) { EXPECT_TRUE(std::equal(ProcMod->getEnv().StdErr.begin(), ProcMod->getEnv().StdErr.end(), ErrStr.begin())); - // Test: Run function successfully to run "c++" with allowing all commands. + // Test: Run function successfully to run "c++" while allowing all commands. ProcMod->getEnv().AllowedAll = true; ProcMod->getEnv().Name = "c++"; EXPECT_TRUE(HostFuncInst.run(DummyCallFrame, {}, RetVal)); @@ -353,7 +353,7 @@ TEST(WasmEdgeProcessTest, Run) { EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 0); EXPECT_TRUE(ProcMod->getEnv().StdErr.size() > 0); - // Test: Run function successfully to run "c++" with allowing this command. + // Test: Run function successfully to run "c++" while allowing this command. ProcMod->getEnv().AllowedAll = false; ProcMod->getEnv().AllowedCmd.insert("c++"); ProcMod->getEnv().Name = "c++"; @@ -362,7 +362,7 @@ TEST(WasmEdgeProcessTest, Run) { EXPECT_TRUE(ProcMod->getEnv().StdOut.size() == 0); EXPECT_TRUE(ProcMod->getEnv().StdErr.size() > 0); - // Test: Run function successfully to run "/bin/echo" with allowing this + // Test: Run function successfully to run "/bin/echo" while allowing this // command. ProcMod->getEnv().AllowedAll = false; ProcMod->getEnv().AllowedCmd.clear(); diff --git a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp index f2114669..e6817916 100644 --- a/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp +++ b/test/plugins/wasmedge_stablediffusion/wasmedge_stablediffusion.cpp @@ -59,7 +59,7 @@ void writeFatPointer(WasmEdge::Runtime::Instance::MemoryInstance &MemInst, writeUInt32(MemInst, PtrSize, Ptr); } -// TODO: unit tests for every functions. +// TODO: add unit tests for every function. TEST(WasmEdgeStableDiffusionTest, ModuleFunctions) { // Create the stable diffusion module instance. diff --git a/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp b/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp index 8ce35675..7d1cc84a 100644 --- a/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp +++ b/test/plugins/wasmedge_tensorflow/wasmedge_tensorflow.cpp @@ -42,7 +42,7 @@ std::unique_ptr createModule() { } } // namespace -// TODO: unit tests for every functions. +// TODO: add unit tests for every function. TEST(WasmEdgeTensorflowTest, Module) { // Create the wasmedge_tensorflow module instance. diff --git a/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp b/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp index 843e5fc5..d8f3fc94 100644 --- a/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp +++ b/test/plugins/wasmedge_tensorflowlite/wasmedge_tensorflowlite.cpp @@ -43,7 +43,7 @@ std::unique_ptr createModule() { } // namespace -// TODO: unit tests for every functions. +// TODO: add unit tests for every function. TEST(WasmEdgeTensorflowLiteTest, Module) { // Create the wasmedge_tensorflowlite module instance. diff --git a/utils/docker/Dockerfile.ubuntu2104_armv7l b/utils/docker/Dockerfile.ubuntu2104_armv7l index 3329e6f2..f1439243 100644 --- a/utils/docker/Dockerfile.ubuntu2104_armv7l +++ b/utils/docker/Dockerfile.ubuntu2104_armv7l @@ -27,7 +27,7 @@ RUN apt update && apt upgrade -y \ wget \ xz-utils -# CMake build from source to avoid compiler_id_detection fails when using QEMU user-mode emulation +# Build CMake from source to avoid compiler_id_detection failures when using QEMU user-mode emulation # See: https://gitlab.kitware.com/cmake/cmake/-/issues/20568 #RUN wget https://github.com/Kitware/CMake/releases/download/v3.29.3/cmake-3.29.3.tar.gz --no-check-certificate && \ diff --git a/utils/ffmpeg/download-ffmpeg-sample-video.sh b/utils/ffmpeg/download-ffmpeg-sample-video.sh index 151642b9..bae80b6e 100644 --- a/utils/ffmpeg/download-ffmpeg-sample-video.sh +++ b/utils/ffmpeg/download-ffmpeg-sample-video.sh @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # SPDX-FileCopyrightText: 2019-2024 Second State INC -# The below video used is sourced from an ffmpeg-libav-tutorial repository. +# The video below is sourced from an ffmpeg-libav-tutorial repository. # Source: https://github.com/leandromoreira/ffmpeg-libav-tutorial/blob/master/LICENSE. TODIR=$1 SAMPLE_VIDEO=https://raw.githubusercontent.com/Hrushi20/rust-ffmpeg/master/assets/bunny.mp4 diff --git a/utils/wasi-crypto/build-openssl.sh b/utils/wasi-crypto/build-openssl.sh index 45d628bf..3adef636 100755 --- a/utils/wasi-crypto/build-openssl.sh +++ b/utils/wasi-crypto/build-openssl.sh @@ -8,7 +8,7 @@ curl -s -L -O --remote-name-all https://www.openssl.org/source/openssl-1.1.1n.ta echo "40dceb51a4f6a5275bde0e6bf20ef4b91bfc32ed57c0552e2e8e15463372b17a openssl-1.1.1n.tar.gz" | sha256sum -c tar -xf openssl-1.1.1n.tar.gz cd ./openssl-1.1.1n -# OpenSSL configure need newer perl +# Configuring OpenSSL requires newer Perl. curl -s -L -O --remote-name-all https://www.cpan.org/src/5.0/perl-5.34.0.tar.gz tar -xf perl-5.34.0.tar.gz cd perl-5.34.0 From 88f6d792f83ccfa962dcc543998e6cd9de4b7dac Mon Sep 17 00:00:00 2001 From: Pranjal Kole <61913668+pranjalkole@users.noreply.github.com> Date: Fri, 29 May 2026 20:20:48 +0530 Subject: [PATCH 614/623] fix(test/plugins/image): remove concrete types to fix devirtualization (#4910) Signed-off-by: Pranjal Kole --- test/plugins/wasmedge_image/wasmedge_image.cpp | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/test/plugins/wasmedge_image/wasmedge_image.cpp b/test/plugins/wasmedge_image/wasmedge_image.cpp index 4f449727..479afbd0 100644 --- a/test/plugins/wasmedge_image/wasmedge_image.cpp +++ b/test/plugins/wasmedge_image/wasmedge_image.cpp @@ -2,8 +2,8 @@ // SPDX-FileCopyrightText: 2019-2024 Second State INC #include "common/defines.h" -#include "image_func.h" #include "image_module.h" +#include "runtime/callingframe.h" #include "runtime/instance/module.h" #include @@ -179,8 +179,7 @@ TEST(WasmEdgeImageTest, LoadJPG) { auto *FuncInst = ImgMod->findFuncExports("load_jpg"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncInst = dynamic_cast( - FuncInst->getHostFunc()); + auto &HostFuncInst = FuncInst->getHostFunc(); // Test: Load JPG and resize into 50x60 RGB u8 format. // Clear the memory[0, 32768]. @@ -320,8 +319,7 @@ TEST(WasmEdgeImageTest, LoadPNG) { auto *FuncInst = ImgMod->findFuncExports("load_png"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncInst = dynamic_cast( - FuncInst->getHostFunc()); + auto &HostFuncInst = FuncInst->getHostFunc(); // Test: Load PNG and resize into 50x60 RGB u8 format. // Clear the memory[0, 32768]. @@ -459,8 +457,7 @@ TEST(WasmEdgeImageTest, LoadImage) { auto *FuncInst = ImgMod->findFuncExports("load_image"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); - auto &HostFuncInst = dynamic_cast( - FuncInst->getHostFunc()); + auto &HostFuncInst = FuncInst->getHostFunc(); // Test: Load JPG and resize into 50x60 BGR u8 format. // Clear the memory[0, 32768]. From 14f997c0910eadf1c699e17a5fb2b079d739cb7d Mon Sep 17 00:00:00 2001 From: Pranjal Kole <61913668+pranjalkole@users.noreply.github.com> Date: Mon, 1 Jun 2026 18:51:55 +0530 Subject: [PATCH 615/623] fix(test/plugins/zlib): remove concrete types to fix devirtualization (#4913) Signed-off-by: Pranjal Kole --- test/plugins/wasmedge_zlib/wasmedge_zlib.cpp | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp index b777abf2..cb6ce197 100644 --- a/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp +++ b/test/plugins/wasmedge_zlib/wasmedge_zlib.cpp @@ -84,38 +84,32 @@ TEST(WasmEdgeZlibTest, DeflateInflateCycle) { auto *FuncInst = ZlibMod->findFuncExports("deflateInit_"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); - auto &DeflateInit_ = dynamic_cast( - FuncInst->getHostFunc()); + auto &DeflateInit_ = FuncInst->getHostFunc(); FuncInst = ZlibMod->findFuncExports("deflate"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); - auto &Deflate = dynamic_cast( - FuncInst->getHostFunc()); + auto &Deflate = FuncInst->getHostFunc(); FuncInst = ZlibMod->findFuncExports("deflateEnd"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); - auto &DeflateEnd = dynamic_cast( - FuncInst->getHostFunc()); + auto &DeflateEnd = FuncInst->getHostFunc(); FuncInst = ZlibMod->findFuncExports("inflateInit_"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); - auto &InflateInit_ = dynamic_cast( - FuncInst->getHostFunc()); + auto &InflateInit_ = FuncInst->getHostFunc(); FuncInst = ZlibMod->findFuncExports("inflate"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); - auto &Inflate = dynamic_cast( - FuncInst->getHostFunc()); + auto &Inflate = FuncInst->getHostFunc(); FuncInst = ZlibMod->findFuncExports("inflateEnd"); EXPECT_NE(FuncInst, nullptr); EXPECT_TRUE(FuncInst->isHostFunction()); - auto &InflateEnd = dynamic_cast( - FuncInst->getHostFunc()); + auto &InflateEnd = FuncInst->getHostFunc(); std::array RetVal; From 16d2e1850eb1fd749774e315154c1e91246a1c5b Mon Sep 17 00:00:00 2001 From: hydai Date: Wed, 3 Jun 2026 22:53:46 +0800 Subject: [PATCH 616/623] chore(lint): apply clang-format-20 Signed-off-by: hydai --- plugins/wasi_crypto/asymmetric_common/keypair.cpp | 2 +- plugins/wasi_crypto/symmetric/ctx.cpp | 6 +++--- plugins/wasi_crypto/symmetric/state.cpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/plugins/wasi_crypto/asymmetric_common/keypair.cpp b/plugins/wasi_crypto/asymmetric_common/keypair.cpp index e3881524..73a602db 100644 --- a/plugins/wasi_crypto/asymmetric_common/keypair.cpp +++ b/plugins/wasi_crypto/asymmetric_common/keypair.cpp @@ -42,7 +42,7 @@ generateKp(AsymmetricCommon::Algorithm Alg, return transposeOptionalRef( OptOptions, [](auto &&Options) noexcept - -> WasiCryptoExpect> { + -> WasiCryptoExpect> { using InOptionsType = std::decay_t; if constexpr (std::is_same_v) { diff --git a/plugins/wasi_crypto/symmetric/ctx.cpp b/plugins/wasi_crypto/symmetric/ctx.cpp index 54263417..c0dab4e3 100644 --- a/plugins/wasi_crypto/symmetric/ctx.cpp +++ b/plugins/wasi_crypto/symmetric/ctx.cpp @@ -202,7 +202,7 @@ Context::symmetricKeyGenerate(Symmetric::Algorithm Alg, auto OptSymmetricOptionsResult = transposeOptionalToRef( *OptOptionsResult, [](const auto &Options) noexcept - -> WasiCryptoExpect> { + -> WasiCryptoExpect> { auto *SymmetricOptions = std::get_if(&Options); if (!SymmetricOptions) { return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); @@ -228,7 +228,7 @@ Context::symmetricStateOpen(Symmetric::Algorithm Alg, auto OptKeyResult = mapAndTransposeOptional(OptKeyHandle, [this](__wasi_symmetric_key_t KeyHandle) noexcept - -> WasiCryptoExpect { + -> WasiCryptoExpect { return SymmetricKeyManager.get(KeyHandle); }); if (!OptKeyResult) { @@ -248,7 +248,7 @@ Context::symmetricStateOpen(Symmetric::Algorithm Alg, auto OptSymmetricOptionsResult = transposeOptionalToRef( *OptOptionsResult, [](const auto &Options) noexcept - -> WasiCryptoExpect> { + -> WasiCryptoExpect> { auto *SymmetricOptions = std::get_if(&Options); if (!SymmetricOptions) { return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_INVALID_HANDLE); diff --git a/plugins/wasi_crypto/symmetric/state.cpp b/plugins/wasi_crypto/symmetric/state.cpp index e532b599..ce358f9d 100644 --- a/plugins/wasi_crypto/symmetric/state.cpp +++ b/plugins/wasi_crypto/symmetric/state.cpp @@ -161,7 +161,7 @@ WasiCryptoExpect stateEncrypt(StateVariant &StateVariant, return checkedAdd(DataSize, TagLen); }) .and_then([Out, Data, &State](size_t ActualDataLen) noexcept - -> WasiCryptoExpect { + -> WasiCryptoExpect { ensureOrReturn(Out.size() == ActualDataLen, __WASI_CRYPTO_ERRNO_INVALID_LENGTH); return State.encrypt(Out, Data); @@ -189,7 +189,7 @@ WasiCryptoExpect stateDecrypt(StateVariant &StateVariant, return checkedAdd(OutSize, TagLen); }) .and_then([Out, Data, &State](size_t ActualOutLen) noexcept - -> WasiCryptoExpect { + -> WasiCryptoExpect { ensureOrReturn(Data.size() == ActualOutLen, __WASI_CRYPTO_ERRNO_INVALID_LENGTH); return State.decrypt(Out, Data); From 8b813c65e5ad0a568af15f26b5a4ad4b13dc684e Mon Sep 17 00:00:00 2001 From: aizu-m Date: Mon, 8 Jun 2026 16:57:01 +0530 Subject: [PATCH 617/623] fix(plugins/opencvmini): bound-check guest buffer length in host funcs (#4923) imdecode/imshow/imwrite/imencode fetched the input buffer with getPointer, which only validates one byte, then read a guest-controlled length past it (host out-of-bounds read). Switch to getSpan(Ptr, Len) and reject a short span, matching the output-buffer check already present in imencode. Signed-off-by: aizu-m --- .../wasmedge_opencvmini/opencvmini_func.cpp | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/plugins/wasmedge_opencvmini/opencvmini_func.cpp b/plugins/wasmedge_opencvmini/opencvmini_func.cpp index 112d2a08..9f31a61f 100644 --- a/plugins/wasmedge_opencvmini/opencvmini_func.cpp +++ b/plugins/wasmedge_opencvmini/opencvmini_func.cpp @@ -24,9 +24,12 @@ WasmEdgeOpenCVMiniImdecode::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - char *Buf = MemInst->getPointer(BufPtr); + auto Buf = MemInst->getSpan(BufPtr, BufLen); + if (unlikely(Buf.size() != BufLen)) { + return Unexpect(ErrCode::Value::HostFuncError); + } - std::vector Content(Buf, Buf + BufLen); + std::vector Content(Buf.begin(), Buf.end()); cv::Mat Img = cv::imdecode(cv::InputArray(Content), cv::IMREAD_COLOR); return Env.insertMat(Img); @@ -44,8 +47,11 @@ Expect WasmEdgeOpenCVMiniImshow::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - char *Buf = MemInst->getPointer(WindowNamePtr); - std::copy_n(Buf, WindowNameLen, std::back_inserter(WindowName)); + auto Buf = MemInst->getSpan(WindowNamePtr, WindowNameLen); + if (unlikely(Buf.size() != WindowNameLen)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + std::copy_n(Buf.data(), WindowNameLen, std::back_inserter(WindowName)); if (auto Img = Env.getMat(MatKey); Img) { cv::imshow(WindowName.c_str(), *Img); @@ -186,8 +192,12 @@ Expect WasmEdgeOpenCVMiniImwrite::body(const Runtime::CallingFrame &Frame, return Unexpect(ErrCode::Value::HostFuncError); } - char *Buf = MemInst->getPointer(TargetFileNamePtr); - std::copy_n(Buf, TargetFileNameLen, std::back_inserter(TargetFileName)); + auto Buf = MemInst->getSpan(TargetFileNamePtr, TargetFileNameLen); + if (unlikely(Buf.size() != TargetFileNameLen)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + std::copy_n(Buf.data(), TargetFileNameLen, + std::back_inserter(TargetFileName)); if (auto Img = Env.getMat(MatKey); Img) { cv::imwrite(TargetFileName.c_str(), *Img); @@ -203,8 +213,11 @@ Expect WasmEdgeOpenCVMiniImencode::body( auto *MemInst = Frame.getMemoryByIndex(0); - char *Buf = MemInst->getPointer(ExtPtr); - std::copy_n(Buf, ExtLen, std::back_inserter(Ext)); + auto Buf = MemInst->getSpan(ExtPtr, ExtLen); + if (unlikely(Buf.size() != ExtLen)) { + return Unexpect(ErrCode::Value::HostFuncError); + } + std::copy_n(Buf.data(), ExtLen, std::back_inserter(Ext)); auto Img = Env.getMat(MatKey); if (!Img) { From afa621dff9f54bc0cf683c194d8f916b4205ad8e Mon Sep 17 00:00:00 2001 From: Yashika Date: Tue, 9 Jun 2026 01:09:03 +0530 Subject: [PATCH 618/623] fix(wasi_nn): reuse validated tensor data in RPC set_input (#4942) WasiNNSetInput already retrieves the tensor data through Tensor.Tensor before constructing the RPC request. This change reuses that existing span when calling RPCTensor.set_data() instead of retrieving the pointer and length again from WasiTensor. The behavior for valid inputs remains unchanged. Signed-off-by: Yashika --- plugins/wasi_nn/wasinnfunc.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/wasi_nn/wasinnfunc.cpp b/plugins/wasi_nn/wasinnfunc.cpp index b8a65f4f..afa7b66a 100644 --- a/plugins/wasi_nn/wasinnfunc.cpp +++ b/plugins/wasi_nn/wasinnfunc.cpp @@ -356,8 +356,8 @@ WasiNNSetInput::bodyImpl(const Runtime::CallingFrame &Frame, uint32_t ContextId, RPCTensor.mutable_dimensions()->Add(Tensor.Dimension.begin(), Tensor.Dimension.end()); RPCTensor.set_ty(wasi_ephemeral_nn::TensorType(Tensor.RType)); - RPCTensor.set_data(MemInst->getPointer(WasiTensor->TensorPtr), - WasiTensor->TensorLen); + RPCTensor.set_data(reinterpret_cast(Tensor.Tensor.data()), + Tensor.Tensor.size()); *Req.mutable_tensor() = RPCTensor; google::protobuf::Empty Res; auto Status = Stub->SetInput(&ClientContext, Req, &Res); From ea1d45f7f630de9205f2f4b825927f12218d81f9 Mon Sep 17 00:00:00 2001 From: Vishal Malyan <146833908+vishal2005025@users.noreply.github.com> Date: Tue, 9 Jun 2026 10:56:38 +0530 Subject: [PATCH 619/623] ci: remove obsolete apt-transport-https and duplicate dpkg-dev (#4646) Signed-off-by: vishal2005025 --- utils/docker/Dockerfile.ci-image-base | 1 - utils/docker/Dockerfile.ubuntu2004_x86_64 | 1 - 2 files changed, 2 deletions(-) diff --git a/utils/docker/Dockerfile.ci-image-base b/utils/docker/Dockerfile.ci-image-base index 15a854af..a857630b 100644 --- a/utils/docker/Dockerfile.ci-image-base +++ b/utils/docker/Dockerfile.ci-image-base @@ -4,7 +4,6 @@ ENV DEBIAN_FRONTEND=noninteractive RUN apt update && apt upgrade -y \ && apt install -y \ - apt-transport-https \ ca-certificates \ curl \ gnupg-agent \ diff --git a/utils/docker/Dockerfile.ubuntu2004_x86_64 b/utils/docker/Dockerfile.ubuntu2004_x86_64 index db96296b..f71ed3c8 100644 --- a/utils/docker/Dockerfile.ubuntu2004_x86_64 +++ b/utils/docker/Dockerfile.ubuntu2004_x86_64 @@ -16,7 +16,6 @@ RUN apt update && apt upgrade -y \ liblld-12-dev \ gcc \ rpm \ - dpkg-dev \ g++ RUN rm -rf /var/lib/apt/lists/* From 7ba76c28f6044f331b0b8b5e6ec255777900fee4 Mon Sep 17 00:00:00 2001 From: SHIGRAF SALIK <140247389+ShigrafS@users.noreply.github.com> Date: Wed, 10 Jun 2026 21:01:25 +0530 Subject: [PATCH 620/623] feat(wasi_crypto): implement managed keypair generation for EDDSA (#4776) * feat(wasi_crypto): implement managed keypair generation for EDDSA - Implemented keypairGenerateManaged for EDDSA algorithms in ctx.cpp. - Added comprehensive positive and negative tests in asymmetric.cpp. - Integrated NOT_IMPLEMENTED verification for unsupported algorithms. - Updated notimplement.cpp to reflect the new implementation. - Removed unrelated formatting and typo fixes to minimize PR noise. Signed-off-by: ShigrafS * fix(test): resolve type mismatch in managed keypair generation test - Use std::optional<__wasi_options_t> instead of __wasi_opt_options_t to match helper function signature. - Pass InvaildHandle directly as an optional value. Signed-off-by: ShigrafS * style(wasi_crypto): format code using clang-format-20 Signed-off-by: SHIGRAF SALIK <140247389+ShigrafS@users.noreply.github.com> * test(wasi_crypto): refactor managed keypair tests to follow project patterns - Define ManagedNegativeCheck lambda for data-driven negative testing. - Integrate managed keypair negative tests into the main Asymmetric suite. - Replace manual EXPECT_EQ with WASI_CRYPTO_EXPECT_FAILURE macro in notimplement.cpp. Assisted-by: Gemini (Google) Signed-off-by: ShigrafS --------- Signed-off-by: ShigrafS Signed-off-by: SHIGRAF SALIK <140247389+ShigrafS@users.noreply.github.com> --- plugins/wasi_crypto/asymmetric_common/ctx.cpp | 10 +++--- test/plugins/wasi_crypto/asymmetric.cpp | 35 +++++++++++++++++++ test/plugins/wasi_crypto/notimplement.cpp | 8 ++--- 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/plugins/wasi_crypto/asymmetric_common/ctx.cpp b/plugins/wasi_crypto/asymmetric_common/ctx.cpp index 4e94a3fd..a610a654 100644 --- a/plugins/wasi_crypto/asymmetric_common/ctx.cpp +++ b/plugins/wasi_crypto/asymmetric_common/ctx.cpp @@ -162,10 +162,12 @@ Context::secretkeyImport(AsymmetricCommon::Algorithm Alg, }); } -WasiCryptoExpect<__wasi_keypair_t> -Context::keypairGenerateManaged(__wasi_secrets_manager_t, - AsymmetricCommon::Algorithm, - __wasi_opt_options_t) noexcept { +WasiCryptoExpect<__wasi_keypair_t> Context::keypairGenerateManaged( + __wasi_secrets_manager_t, AsymmetricCommon::Algorithm Alg, + __wasi_opt_options_t OptOptionsHandle) noexcept { + if (std::holds_alternative(Alg)) { + return keypairGenerate(Alg, OptOptionsHandle); + } return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); } diff --git a/test/plugins/wasi_crypto/asymmetric.cpp b/test/plugins/wasi_crypto/asymmetric.cpp index 24f621e7..2a2b5018 100644 --- a/test/plugins/wasi_crypto/asymmetric.cpp +++ b/test/plugins/wasi_crypto/asymmetric.cpp @@ -28,6 +28,21 @@ TEST_F(WasiCryptoTest, Asymmetric) { WASI_CRYPTO_EXPECT_TRUE(keypairClose(KpHandle)); WASI_CRYPTO_EXPECT_TRUE(publickeyClose(PkHandle)); WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(SkHandle)); + + if (Alg == "Ed25519"sv) { + WASI_CRYPTO_EXPECT_SUCCESS( + KpMHandle, + keypairGenerateManaged(1, AlgType, Alg, std::nullopt)); + WASI_CRYPTO_EXPECT_SUCCESS(PkMHandle, keypairPublickey(KpMHandle)); + WASI_CRYPTO_EXPECT_SUCCESS(SkMHandle, keypairSecretkey(KpMHandle)); + WASI_CRYPTO_EXPECT_TRUE(keypairClose(KpMHandle)); + WASI_CRYPTO_EXPECT_TRUE(publickeyClose(PkMHandle)); + WASI_CRYPTO_EXPECT_TRUE(secretkeyClose(SkMHandle)); + } else { + WASI_CRYPTO_EXPECT_FAILURE( + keypairGenerateManaged(1, AlgType, Alg, std::nullopt), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + } } // Encoding checking. @@ -200,6 +215,18 @@ TEST_F(WasiCryptoTest, Asymmetric) { } }; + auto ManagedNegativeCheck = [this]( + __wasi_secrets_manager_t SmHandle, + __wasi_algorithm_type_e_t AlgType, + std::string_view Alg, + std::optional<__wasi_options_t> OptOptions, + __wasi_crypto_errno_e_t ExpectedError) { + SCOPED_TRACE(Alg); + WASI_CRYPTO_EXPECT_FAILURE( + keypairGenerateManaged(SmHandle, AlgType, Alg, OptOptions), + ExpectedError); + }; + RsaCheck( "2048"sv, {{__WASI_PUBLICKEY_ENCODING_PEM, @@ -600,6 +627,14 @@ TEST_F(WasiCryptoTest, Asymmetric) { "3e28cd3c23ecbe66098bd0a029a8792ae5259282c116b32a28908aaa1bcf" "ef61d9529f"_u8v}}, {}); + + ManagedNegativeCheck(1, __WASI_ALGORITHM_TYPE_SYMMETRIC, "Ed25519"sv, + std::nullopt, __WASI_CRYPTO_ERRNO_INVALID_OPERATION); + ManagedNegativeCheck(1, __WASI_ALGORITHM_TYPE_SIGNATURES, "FooBar"sv, + std::nullopt, __WASI_CRYPTO_ERRNO_UNSUPPORTED_ALGORITHM); + ManagedNegativeCheck(1, __WASI_ALGORITHM_TYPE_SIGNATURES, "Ed25519"sv, + static_cast<__wasi_options_t>(InvaildHandle), + __WASI_CRYPTO_ERRNO_INVALID_HANDLE); } } // namespace WasiCrypto diff --git a/test/plugins/wasi_crypto/notimplement.cpp b/test/plugins/wasi_crypto/notimplement.cpp index 0b255180..f525152e 100644 --- a/test/plugins/wasi_crypto/notimplement.cpp +++ b/test/plugins/wasi_crypto/notimplement.cpp @@ -21,10 +21,10 @@ TEST_F(WasiCryptoTest, NotImplement) { WASI_CRYPTO_EXPECT_FAILURE(symmetricKeyFromId(1, {}, 1), __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); - EXPECT_EQ(keypairGenerateManaged(1, __WASI_ALGORITHM_TYPE_SIGNATURES, - "Ed25519"sv, std::nullopt) - .error(), - __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + WASI_CRYPTO_EXPECT_FAILURE( + keypairGenerateManaged(1, __WASI_ALGORITHM_TYPE_SIGNATURES, + "ECDSA_P256_SHA256"sv, std::nullopt), + __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); WASI_CRYPTO_EXPECT_FAILURE(keypairStoreManaged(1, 1, {}), __WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); WASI_CRYPTO_EXPECT_FAILURE(keypairReplaceManaged(1, 1, 1), From cf9d4fb4861286da7d9f79ad53efc7221a11f95c Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 11 Jun 2026 00:56:18 +0800 Subject: [PATCH 621/623] fix: rename shadowed Env params in LLMC host function constructors Rename the constructor parameters that shadowed the inherited Env member to HostEnv, in both the HostFunction base and its subclasses. Assisted-by: Claude (Anthropic) Signed-off-by: hydai --- plugins/wasmedge_llmc/llmc_base.h | 2 +- plugins/wasmedge_llmc/llmc_func.h | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/plugins/wasmedge_llmc/llmc_base.h b/plugins/wasmedge_llmc/llmc_base.h index 6af0d36a..6b35ffff 100644 --- a/plugins/wasmedge_llmc/llmc_base.h +++ b/plugins/wasmedge_llmc/llmc_base.h @@ -14,7 +14,7 @@ namespace WasmEdgeLLMC { template class HostFunction : public Runtime::HostFunction { public: - HostFunction(LLMCEnv &E) : Runtime::HostFunction(0), Env(E) {} + HostFunction(LLMCEnv &HostEnv) : Runtime::HostFunction(0), Env(HostEnv) {} protected: static constexpr uint32_t castErrNo(ErrNo E) noexcept { diff --git a/plugins/wasmedge_llmc/llmc_func.h b/plugins/wasmedge_llmc/llmc_func.h index 4bae88fb..85c786f5 100644 --- a/plugins/wasmedge_llmc/llmc_func.h +++ b/plugins/wasmedge_llmc/llmc_func.h @@ -16,7 +16,7 @@ namespace WasmEdgeLLMC { class ModelCreate : public HostFunction { public: - explicit ModelCreate(LLMCEnv &Env) : HostFunction(Env) {} + explicit ModelCreate(LLMCEnv &HostEnv) : HostFunction(HostEnv) {} Expect body(const Runtime::CallingFrame &Frame, uint32_t CheckPointPath, uint32_t CheckPointPathLen, @@ -33,7 +33,7 @@ class ModelCreate : public HostFunction { class DataLoaderCreate : public HostFunction { public: - explicit DataLoaderCreate(LLMCEnv &Env) : HostFunction(Env) {} + explicit DataLoaderCreate(LLMCEnv &HostEnv) : HostFunction(HostEnv) {} Expect body(const Runtime::CallingFrame &Frame, uint32_t DataPath, uint32_t DataPathLen, uint32_t B, uint32_t T, @@ -53,7 +53,7 @@ class DataLoaderCreate : public HostFunction { class TokenizerCreate : public HostFunction { public: - explicit TokenizerCreate(LLMCEnv &Env) : HostFunction(Env) {} + explicit TokenizerCreate(LLMCEnv &HostEnv) : HostFunction(HostEnv) {} Expect body(const Runtime::CallingFrame &Frame, uint32_t FilePath, uint32_t FilePathLen, uint32_t TokenizerIdPtr) { @@ -68,7 +68,7 @@ class TokenizerCreate : public HostFunction { class ModelTrain : public HostFunction { public: - explicit ModelTrain(LLMCEnv &Env) : HostFunction(Env) {} + explicit ModelTrain(LLMCEnv &HostEnv) : HostFunction(HostEnv) {} Expect body(const Runtime::CallingFrame &Frame, uint32_t ModelId, uint32_t TrainDataLoaderId, uint32_t ValDataLoaderId, From 2c90c547a7e1803c9ab3846f42486f874522c518 Mon Sep 17 00:00:00 2001 From: hydai Date: Thu, 11 Jun 2026 00:56:26 +0800 Subject: [PATCH 622/623] fix: rename shadowed identifiers in wasi_nn MLX backend Rename parameters and locals that shadowed members, types, or outer-scope bindings across the MLX base module, embedding, linear, and pooling layers, the whisper tokenizer and decoder, and the VLM cache, and give the pooling constructor parameters semantic names. Assisted-by: Claude (Anthropic) Signed-off-by: hydai --- plugins/wasi_nn/MLX/mlx/base.cpp | 19 +++++++-------- plugins/wasi_nn/MLX/mlx/base.h | 8 +++---- plugins/wasi_nn/MLX/mlx/embedding.h | 4 ++-- plugins/wasi_nn/MLX/mlx/linear.h | 4 ++-- plugins/wasi_nn/MLX/mlx/pooling.cpp | 23 ++++++++++--------- plugins/wasi_nn/MLX/mlx/pooling.h | 16 ++++++------- plugins/wasi_nn/MLX/model/vlm_base.cpp | 4 ++-- plugins/wasi_nn/MLX/model/vlm_base.h | 2 +- .../wasi_nn/MLX/model/whisper/decoding.cpp | 5 ++-- plugins/wasi_nn/MLX/model/whisper/decoding.h | 2 +- .../wasi_nn/MLX/model/whisper/tokenizer.cpp | 9 ++++---- plugins/wasi_nn/MLX/model/whisper/tokenizer.h | 2 +- 12 files changed, 51 insertions(+), 47 deletions(-) diff --git a/plugins/wasi_nn/MLX/mlx/base.cpp b/plugins/wasi_nn/MLX/mlx/base.cpp index 4535f003..0030bb17 100644 --- a/plugins/wasi_nn/MLX/mlx/base.cpp +++ b/plugins/wasi_nn/MLX/mlx/base.cpp @@ -9,26 +9,26 @@ namespace WasmEdge::Host::WASINN::MLX { namespace mlx::core::nn { -mx::array &Module::registerParameter(std::string Name, mx::array &&W) { - Parameters.insert({Name, W}); - return Parameters.at(Name); +mx::array &Module::registerParameter(std::string ParamName, mx::array &&W) { + Parameters.insert({ParamName, W}); + return Parameters.at(ParamName); } -void Module::update(std::unordered_map Parameters) { - for (auto &[K, V] : Parameters) { +void Module::update(std::unordered_map NewParameters) { + for (auto &[K, V] : NewParameters) { apply(K, V); } } std::shared_ptr Module::toQuantized( int GroupSize, int Bits, const std::string &Prefix, - const std::unordered_map &Parameters) { + const std::unordered_map &LoadedWeights) { auto NewPrefix = Prefix + Name + (Prefix.empty() && Name.empty() ? "" : "."); for (auto &[K, V] : Submodules) { if (V->hasQuantize()) { auto Weights = V->Parameters.find("weight"); - if (Weights != V->Parameters.end() && !Parameters.empty()) { - if (Parameters.count(NewPrefix + V->Name + ".scales") == 0) { + if (Weights != V->Parameters.end() && !LoadedWeights.empty()) { + if (LoadedWeights.count(NewPrefix + V->Name + ".scales") == 0) { continue; } } @@ -38,7 +38,8 @@ std::shared_ptr Module::toQuantized( } } V = V->toQuantized(GroupSize, Bits, - Prefix + Name + (Name.empty() ? "" : "."), Parameters); + Prefix + Name + (Name.empty() ? "" : "."), + LoadedWeights); } return shared_from_this(); } diff --git a/plugins/wasi_nn/MLX/mlx/base.h b/plugins/wasi_nn/MLX/mlx/base.h index 3332d434..971b4f32 100644 --- a/plugins/wasi_nn/MLX/mlx/base.h +++ b/plugins/wasi_nn/MLX/mlx/base.h @@ -31,20 +31,20 @@ class Module : public std::enable_shared_from_this { virtual ~Module() = default; - mx::array ®isterParameter(std::string Name, mx::array &&W); + mx::array ®isterParameter(std::string ParamName, mx::array &&W); std::unordered_map getWeigts(const std::string &Prefix = "model"); virtual std::shared_ptr toQuantized( int GroupSize = 64, int Bits = 4, const std::string &Prefix = "", - const std::unordered_map &Parameters = {}); + const std::unordered_map &LoadedWeights = {}); virtual bool hasQuantize() { return false; } - void update(std::unordered_map Parameters); + void update(std::unordered_map NewParameters); - void apply(std::string Key, mx::array Parameters); + void apply(std::string Key, mx::array Value); template void registerModule(std::string ModuleName, std::shared_ptr M) { diff --git a/plugins/wasi_nn/MLX/mlx/embedding.h b/plugins/wasi_nn/MLX/mlx/embedding.h index 6e6b88d5..66175d0e 100644 --- a/plugins/wasi_nn/MLX/mlx/embedding.h +++ b/plugins/wasi_nn/MLX/mlx/embedding.h @@ -30,8 +30,8 @@ class Embedding : public Module { std::shared_ptr toQuantized(int GroupSize = 64, int Bits = 4, const std::string &Prefix = "", - const std::unordered_map &Parameters = {}) - override; + const std::unordered_map &LoadedWeights = + {}) override; virtual bool hasQuantize() override { return true; } }; diff --git a/plugins/wasi_nn/MLX/mlx/linear.h b/plugins/wasi_nn/MLX/mlx/linear.h index 6d880776..a8b07360 100644 --- a/plugins/wasi_nn/MLX/mlx/linear.h +++ b/plugins/wasi_nn/MLX/mlx/linear.h @@ -36,8 +36,8 @@ class Linear : public Module { virtual mx::array forward(mx::array Input); std::shared_ptr toQuantized(int GroupSize = 64, int Bits = 4, const std::string &Prefix = "", - const std::unordered_map &Parameters = {}) - override; + const std::unordered_map &LoadedWeights = + {}) override; virtual bool hasQuantize() override { return true; } }; diff --git a/plugins/wasi_nn/MLX/mlx/pooling.cpp b/plugins/wasi_nn/MLX/mlx/pooling.cpp index 4c5ff591..f81349d7 100644 --- a/plugins/wasi_nn/MLX/mlx/pooling.cpp +++ b/plugins/wasi_nn/MLX/mlx/pooling.cpp @@ -151,30 +151,31 @@ mx::array Pool::forward(const mx::array &X) { Pool2d::Pool2d( const std::function &)> - &PoolingFunction, - int PaddingValue, const std::vector &KernelSize, + &PoolingFn, + int PadValue, const std::vector &KernelSizes, const std::optional> &StrideOpt, const std::optional> &PaddingOpt) - : Pool(PoolingFunction, - KernelSize.size() == 1 ? valueOrList(KernelSize[0], 2) : KernelSize, + : Pool(PoolingFn, + KernelSizes.size() == 1 ? valueOrList(KernelSizes[0], 2) + : KernelSizes, (StrideOpt.has_value() ? (StrideOpt.value().size() == 1 ? valueOrList(StrideOpt.value()[0], 2) : StrideOpt.value()) - : (KernelSize.size() == 1 ? valueOrList(KernelSize[0], 2) - : KernelSize)), + : (KernelSizes.size() == 1 ? valueOrList(KernelSizes[0], 2) + : KernelSizes)), makePaddingPairs(PaddingOpt.has_value() ? PaddingOpt.value() : valueOrList(0, 2)), - PaddingValue) {} + PadValue) {} -AvgPool2d::AvgPool2d(const std::vector &KernelSize, - const std::optional> &Stride, - const std::optional> &Padding) +AvgPool2d::AvgPool2d(const std::vector &KernelSizes, + const std::optional> &StrideOpt, + const std::optional> &PaddingOpt) : Pool2d( [](const mx::array &A, const std::vector &Axis) -> mx::array { return mx::mean(A, Axis, false); }, - 0, KernelSize, Stride, Padding) {} + 0, KernelSizes, StrideOpt, PaddingOpt) {} } // namespace mlx::core::nn } // namespace WasmEdge::Host::WASINN::MLX diff --git a/plugins/wasi_nn/MLX/mlx/pooling.h b/plugins/wasi_nn/MLX/mlx/pooling.h index 1e14d0f4..c5a20fc2 100644 --- a/plugins/wasi_nn/MLX/mlx/pooling.h +++ b/plugins/wasi_nn/MLX/mlx/pooling.h @@ -27,18 +27,18 @@ class Pool : public nn::Module { class Pool2d : public Pool { public: - Pool2d(const std::function &)> &PoolingFunction, - int PaddingValue, const std::vector &KernelSize, - const std::optional> &Stride, - const std::optional> &Padding); + Pool2d(const std::function &)> &PoolingFn, + int PadValue, const std::vector &KernelSizes, + const std::optional> &StrideOpt, + const std::optional> &PaddingOpt); }; class AvgPool2d : public Pool2d { public: - AvgPool2d(const std::vector &KernelSize, - const std::optional> &Stride = std::nullopt, - const std::optional> &Padding = std::nullopt); + AvgPool2d(const std::vector &KernelSizes, + const std::optional> &StrideOpt = std::nullopt, + const std::optional> &PaddingOpt = std::nullopt); }; } // namespace mlx::core::nn diff --git a/plugins/wasi_nn/MLX/model/vlm_base.cpp b/plugins/wasi_nn/MLX/model/vlm_base.cpp index 7e30518e..7e3ec447 100644 --- a/plugins/wasi_nn/MLX/model/vlm_base.cpp +++ b/plugins/wasi_nn/MLX/model/vlm_base.cpp @@ -131,8 +131,8 @@ int KVCache::trim(int N) { } // RotatingKVCache implementation -RotatingKVCache::RotatingKVCache(int MaxSize, int Keep, int Step) - : KVCache(0, 0, Step), Keep(Keep), MaxSize(MaxSize), Idx(0) {} +RotatingKVCache::RotatingKVCache(int MaxSize, int Keep, int StepSize) + : KVCache(0, 0, StepSize), Keep(Keep), MaxSize(MaxSize), Idx(0) {} mx::array RotatingKVCache::trim(int TrimSize, const mx::array &V, std::optional Append) { diff --git a/plugins/wasi_nn/MLX/model/vlm_base.h b/plugins/wasi_nn/MLX/model/vlm_base.h index 58bee647..45fc835f 100644 --- a/plugins/wasi_nn/MLX/model/vlm_base.h +++ b/plugins/wasi_nn/MLX/model/vlm_base.h @@ -67,7 +67,7 @@ class RotatingKVCache : public KVCache { int MaxSize; int Idx; - RotatingKVCache(int MaxSize = -1, int Keep = 0, int Step = 256); + RotatingKVCache(int MaxSize = -1, int Keep = 0, int StepSize = 256); std::tuple updateAndFetch(const mx::array &NewKeys, const mx::array &NewValues) override; diff --git a/plugins/wasi_nn/MLX/model/whisper/decoding.cpp b/plugins/wasi_nn/MLX/model/whisper/decoding.cpp index 2be9c624..a387f1ba 100644 --- a/plugins/wasi_nn/MLX/model/whisper/decoding.cpp +++ b/plugins/wasi_nn/MLX/model/whisper/decoding.cpp @@ -468,8 +468,9 @@ DecodingTask::DecodingTask(std::shared_ptr Model, } } -DecodingOptions DecodingTask::verifyOptions(const DecodingOptions &Options) { - DecodingOptions Result = Options; +DecodingOptions +DecodingTask::verifyOptions(const DecodingOptions &InputOptions) { + DecodingOptions Result = InputOptions; // Check beam_size and best_of conflicts if (Result.BeamSize && Result.BestOf) { diff --git a/plugins/wasi_nn/MLX/model/whisper/decoding.h b/plugins/wasi_nn/MLX/model/whisper/decoding.h index 6e8b55cd..84f05d76 100644 --- a/plugins/wasi_nn/MLX/model/whisper/decoding.h +++ b/plugins/wasi_nn/MLX/model/whisper/decoding.h @@ -163,7 +163,7 @@ class DecodingTask { std::vector run(const mx::array &Mel); private: - DecodingOptions verifyOptions(const DecodingOptions &Options); + DecodingOptions verifyOptions(const DecodingOptions &InputOptions); std::vector getInitialTokens(); std::vector getSuppressTokens(); mx::array getAudioFeatures(const mx::array &Mel); diff --git a/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp b/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp index 42d76c6a..4178261f 100644 --- a/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp +++ b/plugins/wasi_nn/MLX/model/whisper/tokenizer.cpp @@ -451,13 +451,14 @@ int Tokenizer::languageToken() const { return toLanguageToken(Language.value()); } -int Tokenizer::toLanguageToken(const std::string &Language) const { - std::string TokenName = "<|" + Language + "|>"; +int Tokenizer::toLanguageToken(const std::string &LanguageCode) const { + std::string TokenName = "<|" + LanguageCode + "|>"; auto It = SpecialTokens.find(TokenName); if (It != SpecialTokens.end()) { return It->second; } - throw std::runtime_error("Language " + Language + " not found in tokenizer."); + throw std::runtime_error("Language " + LanguageCode + + " not found in tokenizer."); } std::vector Tokenizer::getAllLanguageTokens() const { @@ -703,7 +704,7 @@ std::unique_ptr getEncoding(const std::string &Name, std::vector Specials = {"<|endoftext|>", "<|startoftranscript|>"}; - for (const auto &[Code, Name] : LANGUAGES) { + for (const auto &[Code, LanguageName] : LANGUAGES) { Specials.push_back("<|" + Code + "|>"); if (static_cast(Specials.size()) >= NumLanguages + 2) break; diff --git a/plugins/wasi_nn/MLX/model/whisper/tokenizer.h b/plugins/wasi_nn/MLX/model/whisper/tokenizer.h index 9c068e09..23ecf401 100644 --- a/plugins/wasi_nn/MLX/model/whisper/tokenizer.h +++ b/plugins/wasi_nn/MLX/model/whisper/tokenizer.h @@ -63,7 +63,7 @@ class Tokenizer { int getNoTimestamps() const; int getTimestampBegin() const; int languageToken() const; - int toLanguageToken(const std::string &Language) const; + int toLanguageToken(const std::string &LanguageCode) const; std::vector getAllLanguageTokens() const; std::vector getAllLanguageCodes() const; std::vector getSotSequenceIncludingNotimestamps() const; From 18ce2cd89d99f5642158497e0fc72bb287108666 Mon Sep 17 00:00:00 2001 From: Parth Dagia Date: Thu, 11 Jun 2026 20:39:49 +0530 Subject: [PATCH 623/623] feat(wasi_crypto): implement EdDSA public key verification (#4927) Implement Eddsa::PublicKey::verify() using EVP_PKEY_public_check. Part of #2669. Signed-off-by: Parth Dagia --- plugins/wasi_crypto/signatures/eddsa.cpp | 10 +++++++++- test/plugins/wasi_crypto/signatures.cpp | 21 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/plugins/wasi_crypto/signatures/eddsa.cpp b/plugins/wasi_crypto/signatures/eddsa.cpp index 08f15260..f3d36297 100644 --- a/plugins/wasi_crypto/signatures/eddsa.cpp +++ b/plugins/wasi_crypto/signatures/eddsa.cpp @@ -35,7 +35,15 @@ Eddsa::PublicKey::import(Span Encoded, } WasiCryptoExpect Eddsa::PublicKey::verify() const noexcept { - return WasiCryptoUnexpect(__WASI_CRYPTO_ERRNO_NOT_IMPLEMENTED); + EvpPkeyCtxPtr CheckCtx{EVP_PKEY_CTX_new(Ctx.get(), nullptr)}; + ensureOrReturn(CheckCtx, __WASI_CRYPTO_ERRNO_ALGORITHM_FAILURE); + // EVP_PKEY_public_check() returns 1 for a valid key and 0 for an invalid + // one. A negative value means the check is unsupported for this key type + // (e.g. Ed25519 on OpenSSL 1.1.1), so only an explicit 0 is treated as + // invalid. + ensureOrReturn(EVP_PKEY_public_check(CheckCtx.get()) != 0, + __WASI_CRYPTO_ERRNO_INVALID_KEY); + return {}; } WasiCryptoExpect> Eddsa::PublicKey::exportData( diff --git a/test/plugins/wasi_crypto/signatures.cpp b/test/plugins/wasi_crypto/signatures.cpp index 07deb8c4..6d814992 100644 --- a/test/plugins/wasi_crypto/signatures.cpp +++ b/test/plugins/wasi_crypto/signatures.cpp @@ -52,6 +52,27 @@ TEST_F(WasiCryptoTest, Signatures) { SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_3072_SHA512"sv); SigTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "RSA_PSS_4096_SHA512"sv); + // Verify that a generated public key is well-formed. + auto PublicKeyVerifyTest = [this](__wasi_algorithm_type_e_t AlgType, + std::string_view Alg) { + SCOPED_TRACE(Alg); + WASI_CRYPTO_EXPECT_SUCCESS(KpHandle, + keypairGenerate(AlgType, Alg, std::nullopt)); + WASI_CRYPTO_EXPECT_SUCCESS(PkHandle, keypairPublickey(KpHandle)); + WASI_CRYPTO_EXPECT_TRUE(publickeyVerify(PkHandle)); + }; + PublicKeyVerifyTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "Ed25519"sv); + + // A raw public key with an invalid length is rejected on import. + auto PublicKeyImportInvalidTest = [this](__wasi_algorithm_type_e_t AlgType, + std::string_view Alg) { + SCOPED_TRACE(Alg); + WASI_CRYPTO_EXPECT_FAILURE( + publickeyImport(AlgType, Alg, "00"_u8v, __WASI_PUBLICKEY_ENCODING_RAW), + __WASI_CRYPTO_ERRNO_INVALID_KEY); + }; + PublicKeyImportInvalidTest(__WASI_ALGORITHM_TYPE_SIGNATURES, "Ed25519"sv); + auto SigEncodingTest = [this]( std::string_view Alg,